urlmon: Update winehq.org IP.
[wine] / dlls / urlmon / tests / protocol.c
1 /*
2  * Copyright 2005 Jacek Caban
3  *
4  * This library is free software; you can redistribute it and/or
5  * modify it under the terms of the GNU Lesser General Public
6  * License as published by the Free Software Foundation; either
7  * version 2.1 of the License, or (at your option) any later version.
8  *
9  * This library is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
12  * Lesser General Public License for more details.
13  *
14  * You should have received a copy of the GNU Lesser General Public
15  * License along with this library; if not, write to the Free Software
16  * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
17  */
18
19 #define COBJMACROS
20 #define CONST_VTABLE
21
22 #include <wine/test.h>
23 #include <stdarg.h>
24
25 #include "windef.h"
26 #include "winbase.h"
27 #include "ole2.h"
28 #include "urlmon.h"
29
30 #include "initguid.h"
31
/*
 * Callback expectation tracking.  Every tracked callback "func" owns two
 * flags: expect_func, raised by the test before it triggers the callback,
 * and called_func, raised by the callback itself when it runs.
 */
#define DEFINE_EXPECT(func) \
    static BOOL expect_ ## func = FALSE, called_ ## func = FALSE

/* Allow the callback "func" to be invoked. */
#define SET_EXPECT(func) \
    (expect_ ## func = TRUE)

/* Record a call of "func" that may happen only once: report it if it was not
 * expected, then drop the expectation so that a second call is reported. */
#define CHECK_EXPECT(func) \
    do { \
        ok(expect_ ## func, "unexpected call " #func "\n"); \
        expect_ ## func = FALSE; \
        called_ ## func = TRUE; \
    } while(0)

/* Same as CHECK_EXPECT, but keeps the expectation so repeated calls pass. */
#define CHECK_EXPECT2(func) \
    do { \
        ok(expect_ ## func, "unexpected call " #func "\n"); \
        called_ ## func = TRUE; \
    } while(0)

/* Verify that "func" ran, then reset both of its flags. */
#define CHECK_CALLED(func) \
    do { \
        ok(called_ ## func, "expected " #func "\n"); \
        called_ ## func = FALSE; \
        expect_ ## func = FALSE; \
    } while(0)
56
57 DEFINE_EXPECT(GetBindInfo);
58 DEFINE_EXPECT(ReportProgress_MIMETYPEAVAILABLE);
59 DEFINE_EXPECT(ReportProgress_DIRECTBIND);
60 DEFINE_EXPECT(ReportProgress_FINDINGRESOURCE);
61 DEFINE_EXPECT(ReportProgress_CONNECTING);
62 DEFINE_EXPECT(ReportProgress_SENDINGREQUEST);
63 DEFINE_EXPECT(ReportProgress_CACHEFILENAMEAVAILABLE);
64 DEFINE_EXPECT(ReportProgress_VERIFIEDMIMETYPEAVAILABLE);
65 DEFINE_EXPECT(ReportData);
66 DEFINE_EXPECT(ReportResult);
67 DEFINE_EXPECT(GetBindString_ACCEPT_MIMES);
68 DEFINE_EXPECT(GetBindString_USER_AGENT);
69 DEFINE_EXPECT(QueryService_HttpNegotiate);
70 DEFINE_EXPECT(BeginningTransaction);
71 DEFINE_EXPECT(GetRootSecurityId);
72 DEFINE_EXPECT(OnResponse);
73 DEFINE_EXPECT(Switch);
74
75 static const WCHAR wszIndexHtml[] = {'i','n','d','e','x','.','h','t','m','l',0};
76 static const WCHAR index_url[] =
77     {'f','i','l','e',':','i','n','d','e','x','.','h','t','m','l',0};
78
79 static HRESULT expect_hrResult;
80 static LPCWSTR file_name, http_url;
81 static IInternetProtocol *http_protocol = NULL;
82 static BOOL first_data_notif = FALSE;
83 static HWND protocol_hwnd;
84 static int state = 0;
85 static DWORD bindf = 0;
86
87 static enum {
88     FILE_TEST,
89     HTTP_TEST
90 } tested_protocol;
91
92 static HRESULT WINAPI HttpNegotiate_QueryInterface(IHttpNegotiate2 *iface, REFIID riid, void **ppv)
93 {
94     if(IsEqualGUID(&IID_IUnknown, riid)
95             || IsEqualGUID(&IID_IHttpNegotiate, riid)
96             || IsEqualGUID(&IID_IHttpNegotiate2, riid)) {
97         *ppv = iface;
98         return S_OK;
99     }
100
101     ok(0, "unexpected call\n");
102     return E_NOINTERFACE;
103 }
104
105 static ULONG WINAPI HttpNegotiate_AddRef(IHttpNegotiate2 *iface)
106 {
107     return 2;
108 }
109
110 static ULONG WINAPI HttpNegotiate_Release(IHttpNegotiate2 *iface)
111 {
112     return 1;
113 }
114
115 static HRESULT WINAPI HttpNegotiate_BeginningTransaction(IHttpNegotiate2 *iface, LPCWSTR szURL,
116         LPCWSTR szHeaders, DWORD dwReserved, LPWSTR *pszAdditionalHeaders)
117 {
118     CHECK_EXPECT(BeginningTransaction);
119
120     ok(!lstrcmpW(szURL, http_url), "szURL != http_url\n");
121
122     ok(szHeaders != NULL, "szHeaders == NULL\n");
123     if(szHeaders) {
124         static const WCHAR header[] =
125             {'A','c','c','e','p','t','-','E','n','c','o','d','i','n','g',':',
126              ' ','g','z','i','p',',',' ','d','e','f','l','a','t','e',0};
127         ok(!lstrcmpW(header, szHeaders), "Unexpected szHeaders\n");
128     }
129
130     ok(!dwReserved, "dwReserved=%d, expected 0\n", dwReserved);
131     ok(pszAdditionalHeaders != NULL, "pszAdditionalHeaders == NULL\n");
132     if(pszAdditionalHeaders)
133         ok(*pszAdditionalHeaders == NULL, "*pszAdditionalHeaders != NULL\n");
134
135     return S_OK;
136 }
137
138 static HRESULT WINAPI HttpNegotiate_OnResponse(IHttpNegotiate2 *iface, DWORD dwResponseCode,
139         LPCWSTR szResponseHeaders, LPCWSTR szRequestHeaders, LPWSTR *pszAdditionalRequestHeaders)
140 {
141     CHECK_EXPECT(OnResponse);
142
143     ok(dwResponseCode == 200, "dwResponseCode=%d, expected 200\n", dwResponseCode);
144     ok(szResponseHeaders != NULL, "szResponseHeaders == NULL\n");
145     ok(szRequestHeaders == NULL, "szRequestHeaders != NULL\n");
146     ok(pszAdditionalRequestHeaders == NULL, "pszAdditionalHeaders != NULL\n");
147
148     return S_OK;
149 }
150
151 static HRESULT WINAPI HttpNegotiate_GetRootSecurityId(IHttpNegotiate2 *iface,
152         BYTE *pbSecurityId, DWORD *pcbSecurityId, DWORD_PTR dwReserved)
153 {
154     static const BYTE sec_id[] = {'h','t','t','p',':','t','e','s','t',1,0,0,0};
155     
156     CHECK_EXPECT(GetRootSecurityId);
157
158     ok(!dwReserved, "dwReserved=%ld, expected 0\n", dwReserved);
159     ok(pbSecurityId != NULL, "pbSecurityId == NULL\n");
160     ok(pcbSecurityId != NULL, "pcbSecurityId == NULL\n");
161
162     if(pcbSecurityId) {
163         ok(*pcbSecurityId == 512, "*pcbSecurityId=%d, expected 512\n", *pcbSecurityId);
164         *pcbSecurityId = sizeof(sec_id);
165     }
166
167     if(pbSecurityId)
168         memcpy(pbSecurityId, sec_id, sizeof(sec_id));
169
170     return E_FAIL;
171 }
172
173 static IHttpNegotiate2Vtbl HttpNegotiateVtbl = {
174     HttpNegotiate_QueryInterface,
175     HttpNegotiate_AddRef,
176     HttpNegotiate_Release,
177     HttpNegotiate_BeginningTransaction,
178     HttpNegotiate_OnResponse,
179     HttpNegotiate_GetRootSecurityId
180 };
181
182 static IHttpNegotiate2 http_negotiate = { &HttpNegotiateVtbl };
183
184 static HRESULT QueryInterface(REFIID,void**);
185
186 static HRESULT WINAPI ServiceProvider_QueryInterface(IServiceProvider *iface, REFIID riid, void **ppv)
187 {
188     return QueryInterface(riid, ppv);
189 }
190
191 static ULONG WINAPI ServiceProvider_AddRef(IServiceProvider *iface)
192 {
193     return 2;
194 }
195
196 static ULONG WINAPI ServiceProvider_Release(IServiceProvider *iface)
197 {
198     return 1;
199 }
200
201 static HRESULT WINAPI ServiceProvider_QueryService(IServiceProvider *iface, REFGUID guidService,
202         REFIID riid, void **ppv)
203 {
204     if(IsEqualGUID(&IID_IHttpNegotiate, guidService) || IsEqualGUID(&IID_IHttpNegotiate2, riid)) {
205         CHECK_EXPECT2(QueryService_HttpNegotiate);
206         return IHttpNegotiate2_QueryInterface(&http_negotiate, riid, ppv);
207     }
208
209     ok(0, "unexpected call\n");
210     return E_FAIL;
211 }
212
213 static const IServiceProviderVtbl ServiceProviderVtbl = {
214     ServiceProvider_QueryInterface,
215     ServiceProvider_AddRef,
216     ServiceProvider_Release,
217     ServiceProvider_QueryService
218 };
219
220 static IServiceProvider service_provider = { &ServiceProviderVtbl };
221
222 static HRESULT WINAPI ProtocolSink_QueryInterface(IInternetProtocolSink *iface, REFIID riid, void **ppv)
223 {
224     return QueryInterface(riid, ppv);
225 }
226
227 static ULONG WINAPI ProtocolSink_AddRef(IInternetProtocolSink *iface)
228 {
229     return 2;
230 }
231
232 static ULONG WINAPI ProtocolSink_Release(IInternetProtocolSink *iface)
233 {
234     return 1;
235 }
236
237 static HRESULT WINAPI ProtocolSink_Switch(IInternetProtocolSink *iface, PROTOCOLDATA *pProtocolData)
238 {
239     CHECK_EXPECT2(Switch);
240     ok(pProtocolData != NULL, "pProtocolData == NULL\n");
241     SendMessageW(protocol_hwnd, WM_USER, 0, (LPARAM)pProtocolData);
242     return S_OK;
243 }
244
245 static HRESULT WINAPI ProtocolSink_ReportProgress(IInternetProtocolSink *iface, ULONG ulStatusCode,
246         LPCWSTR szStatusText)
247 {
248     static const WCHAR text_html[] = {'t','e','x','t','/','h','t','m','l',0};
249     static const WCHAR host[] =
250         {'w','w','w','.','w','i','n','e','h','q','.','o','r','g',0};
251     static const WCHAR wszWineHQIP[] =
252         {'2','0','9','.','4','6','.','2','5','.','1','3','4',0};
253     /* I'm not sure if it's a good idea to hardcode here the IP address... */
254
255     switch(ulStatusCode) {
256     case BINDSTATUS_MIMETYPEAVAILABLE:
257         CHECK_EXPECT(ReportProgress_MIMETYPEAVAILABLE);
258         ok(szStatusText != NULL, "szStatusText == NULL\n");
259         if(szStatusText)
260             ok(!lstrcmpW(szStatusText, text_html), "szStatusText != text/html\n");
261         break;
262     case BINDSTATUS_DIRECTBIND:
263         CHECK_EXPECT2(ReportProgress_DIRECTBIND);
264         ok(szStatusText == NULL, "szStatusText != NULL\n");
265         break;
266     case BINDSTATUS_CACHEFILENAMEAVAILABLE:
267         CHECK_EXPECT(ReportProgress_CACHEFILENAMEAVAILABLE);
268         ok(szStatusText != NULL, "szStatusText == NULL\n");
269         if(szStatusText)
270             ok(!lstrcmpW(szStatusText, file_name), "szStatusText != file_name\n");
271         break;
272     case BINDSTATUS_FINDINGRESOURCE:
273         CHECK_EXPECT(ReportProgress_FINDINGRESOURCE);
274         ok(szStatusText != NULL, "szStatusText == NULL\n");
275         if(szStatusText)
276             ok(!lstrcmpW(szStatusText, host), "szStatustext != \"www.winehq.org\"\n");
277         break;
278     case BINDSTATUS_CONNECTING:
279         CHECK_EXPECT(ReportProgress_CONNECTING);
280         ok(szStatusText != NULL, "szStatusText == NULL\n");
281         if(szStatusText)
282             ok(!lstrcmpW(szStatusText, wszWineHQIP), "Unexpected szStatusText\n");
283         break;
284     case BINDSTATUS_SENDINGREQUEST:
285         CHECK_EXPECT(ReportProgress_SENDINGREQUEST);
286         if(tested_protocol == FILE_TEST) {
287             ok(szStatusText != NULL, "szStatusText == NULL\n");
288             if(szStatusText)
289                 ok(!*szStatusText, "wrong szStatusText\n");
290         }else {
291             ok(szStatusText == NULL, "szStatusText != NULL\n");
292         }
293         break;
294     case BINDSTATUS_VERIFIEDMIMETYPEAVAILABLE:
295         CHECK_EXPECT(ReportProgress_VERIFIEDMIMETYPEAVAILABLE);
296         ok(szStatusText != NULL, "szStatusText == NULL\n");
297         if(szStatusText)
298             ok(!lstrcmpW(szStatusText, text_html), "szStatusText != text/html\n");
299         break;
300     default:
301         ok(0, "Unexpected call %d\n", ulStatusCode);
302     };
303
304     return S_OK;
305 }
306
307 static HRESULT WINAPI ProtocolSink_ReportData(IInternetProtocolSink *iface, DWORD grfBSCF,
308         ULONG ulProgress, ULONG ulProgressMax)
309 {
310     if(tested_protocol == FILE_TEST) {
311         CHECK_EXPECT2(ReportData);
312
313         ok(ulProgress == ulProgressMax, "ulProgress (%d) != ulProgressMax (%d)\n",
314            ulProgress, ulProgressMax);
315         ok(ulProgressMax == 13, "ulProgressMax=%d, expected 13\n", ulProgressMax);
316         ok(grfBSCF == (BSCF_FIRSTDATANOTIFICATION | BSCF_LASTDATANOTIFICATION),
317                 "grcfBSCF = %08x\n", grfBSCF);
318     }else if(tested_protocol == HTTP_TEST) {
319         if(!(grfBSCF & BSCF_LASTDATANOTIFICATION))
320             CHECK_EXPECT(ReportData);
321
322         ok(ulProgress, "ulProgress == 0\n");
323
324         if(first_data_notif) {
325             ok(grfBSCF == BSCF_FIRSTDATANOTIFICATION, "grcfBSCF = %08x\n", grfBSCF);
326             first_data_notif = FALSE;
327         } else {
328             ok(grfBSCF == BSCF_INTERMEDIATEDATANOTIFICATION
329                || grfBSCF == (BSCF_LASTDATANOTIFICATION|BSCF_INTERMEDIATEDATANOTIFICATION),
330                "grcfBSCF = %08x\n", grfBSCF);
331         }
332     }
333     return S_OK;
334 }
335
336 static HRESULT WINAPI ProtocolSink_ReportResult(IInternetProtocolSink *iface, HRESULT hrResult,
337         DWORD dwError, LPCWSTR szResult)
338 {
339     CHECK_EXPECT(ReportResult);
340
341     ok(hrResult == expect_hrResult, "hrResult = %08x, expected: %08x\n",
342             hrResult, expect_hrResult);
343     if(SUCCEEDED(hrResult))
344         ok(dwError == ERROR_SUCCESS, "dwError = %d, expected ERROR_SUCCESS\n", dwError);
345     else
346         ok(dwError != ERROR_SUCCESS, "dwError == ERROR_SUCCESS\n");
347     ok(!szResult, "szResult != NULL\n");
348
349     return S_OK;
350 }
351
352 static IInternetProtocolSinkVtbl protocol_sink_vtbl = {
353     ProtocolSink_QueryInterface,
354     ProtocolSink_AddRef,
355     ProtocolSink_Release,
356     ProtocolSink_Switch,
357     ProtocolSink_ReportProgress,
358     ProtocolSink_ReportData,
359     ProtocolSink_ReportResult
360 };
361
362 static IInternetProtocolSink protocol_sink = { &protocol_sink_vtbl };
363
364 static HRESULT QueryInterface(REFIID riid, void **ppv)
365 {
366     *ppv = NULL;
367
368     if(IsEqualGUID(&IID_IUnknown, riid) || IsEqualGUID(&IID_IInternetProtocolSink, riid))
369         *ppv = &protocol_sink;
370     if(IsEqualGUID(&IID_IServiceProvider, riid))
371         *ppv = &service_provider;
372
373     if(*ppv)
374         return S_OK;
375
376     return E_NOINTERFACE;
377 }
378
379 static HRESULT WINAPI BindInfo_QueryInterface(IInternetBindInfo *iface, REFIID riid, void **ppv)
380 {
381     if(IsEqualGUID(&IID_IUnknown, riid) || IsEqualGUID(&IID_IInternetBindInfo, riid)) {
382         *ppv = iface;
383         return S_OK;
384     }
385     return E_NOINTERFACE;
386 }
387
388 static ULONG WINAPI BindInfo_AddRef(IInternetBindInfo *iface)
389 {
390     return 2;
391 }
392
393 static ULONG WINAPI BindInfo_Release(IInternetBindInfo *iface)
394 {
395     return 1;
396 }
397
398 static HRESULT WINAPI BindInfo_GetBindInfo(IInternetBindInfo *iface, DWORD *grfBINDF, BINDINFO *pbindinfo)
399 {
400     DWORD cbSize;
401
402     CHECK_EXPECT(GetBindInfo);
403
404     ok(grfBINDF != NULL, "grfBINDF == NULL\n");
405     ok(pbindinfo != NULL, "pbindinfo == NULL\n");
406     ok(pbindinfo->cbSize == sizeof(BINDINFO), "wrong size of pbindinfo: %d\n", pbindinfo->cbSize);
407
408     *grfBINDF = bindf;
409     cbSize = pbindinfo->cbSize;
410     memset(pbindinfo, 0, cbSize);
411     pbindinfo->cbSize = cbSize;
412
413     return S_OK;
414 }
415
416 static HRESULT WINAPI BindInfo_GetBindString(IInternetBindInfo *iface, ULONG ulStringType,
417         LPOLESTR *ppwzStr, ULONG cEl, ULONG *pcElFetched)
418 {
419     static const WCHAR acc_mime[] = {'*','/','*',0};
420     static const WCHAR user_agent[] = {'W','i','n','e',0};
421
422     ok(ppwzStr != NULL, "ppwzStr == NULL\n");
423     ok(pcElFetched != NULL, "pcElFetched == NULL\n");
424
425     switch(ulStringType) {
426     case BINDSTRING_ACCEPT_MIMES:
427         CHECK_EXPECT(GetBindString_ACCEPT_MIMES);
428         ok(cEl == 256, "cEl=%d, expected 256\n", cEl);
429         if(pcElFetched) {
430             ok(*pcElFetched == 256, "*pcElFetched=%d, expected 256\n", *pcElFetched);
431             *pcElFetched = 1;
432         }
433         if(ppwzStr) {
434             *ppwzStr = CoTaskMemAlloc(sizeof(acc_mime));
435             memcpy(*ppwzStr, acc_mime, sizeof(acc_mime));
436         }
437         return S_OK;
438     case BINDSTRING_USER_AGENT:
439         CHECK_EXPECT(GetBindString_USER_AGENT);
440         ok(cEl == 1, "cEl=%d, expected 1\n", cEl);
441         if(pcElFetched) {
442             ok(*pcElFetched == 0, "*pcElFetch=%d, expectd 0\n", *pcElFetched);
443             *pcElFetched = 1;
444         }
445         if(ppwzStr) {
446             *ppwzStr = CoTaskMemAlloc(sizeof(user_agent));
447             memcpy(*ppwzStr, user_agent, sizeof(user_agent));
448         }
449         return S_OK;
450     default:
451         ok(0, "unexpected call\n");
452     }
453
454     return E_NOTIMPL;
455 }
456
457 static IInternetBindInfoVtbl bind_info_vtbl = {
458     BindInfo_QueryInterface,
459     BindInfo_AddRef,
460     BindInfo_Release,
461     BindInfo_GetBindInfo,
462     BindInfo_GetBindString
463 };
464
465 static IInternetBindInfo bind_info = { &bind_info_vtbl };
466
467 static void test_priority(IInternetProtocol *protocol)
468 {
469     IInternetPriority *priority;
470     LONG pr;
471     HRESULT hres;
472
473     hres = IInternetProtocol_QueryInterface(protocol, &IID_IInternetPriority,
474                                             (void**)&priority);
475     ok(hres == S_OK, "QueryInterface(IID_IInternetPriority) failed: %08x\n", hres);
476     if(FAILED(hres))
477         return;
478
479     hres = IInternetPriority_GetPriority(priority, &pr);
480     ok(hres == S_OK, "GetPriority failed: %08x\n", hres);
481     ok(pr == 0, "pr=%d, expected 0\n", pr);
482
483     hres = IInternetPriority_SetPriority(priority, 1);
484     ok(hres == S_OK, "SetPriority failed: %08x\n", hres);
485
486     hres = IInternetPriority_GetPriority(priority, &pr);
487     ok(hres == S_OK, "GetPriority failed: %08x\n", hres);
488     ok(pr == 1, "pr=%d, expected 1\n", pr);
489
490     IInternetPriority_Release(priority);
491 }
492
493 static void file_protocol_start(IInternetProtocol *protocol, LPCWSTR url, BOOL is_first)
494 {
495     HRESULT hres;
496
497     SET_EXPECT(GetBindInfo);
498     if(!(bindf & BINDF_FROMURLMON))
499        SET_EXPECT(ReportProgress_DIRECTBIND);
500     if(is_first) {
501         SET_EXPECT(ReportProgress_SENDINGREQUEST);
502         SET_EXPECT(ReportProgress_CACHEFILENAMEAVAILABLE);
503         if(bindf & BINDF_FROMURLMON)
504             SET_EXPECT(ReportProgress_VERIFIEDMIMETYPEAVAILABLE);
505         else
506             SET_EXPECT(ReportProgress_MIMETYPEAVAILABLE);
507     }
508     SET_EXPECT(ReportData);
509     if(is_first)
510         SET_EXPECT(ReportResult);
511
512     expect_hrResult = S_OK;
513
514     hres = IInternetProtocol_Start(protocol, url, &protocol_sink, &bind_info, 0, 0);
515     ok(hres == S_OK, "Start failed: %08x\n", hres);
516
517     CHECK_CALLED(GetBindInfo);
518     if(!(bindf & BINDF_FROMURLMON))
519        CHECK_CALLED(ReportProgress_DIRECTBIND);
520     if(is_first) {
521         CHECK_CALLED(ReportProgress_SENDINGREQUEST);
522         CHECK_CALLED(ReportProgress_CACHEFILENAMEAVAILABLE);
523         if(bindf & BINDF_FROMURLMON)
524             CHECK_CALLED(ReportProgress_VERIFIEDMIMETYPEAVAILABLE);
525         else
526             CHECK_CALLED(ReportProgress_MIMETYPEAVAILABLE);
527     }
528     CHECK_CALLED(ReportData);
529     if(is_first)
530         CHECK_CALLED(ReportResult);
531 }
532
533 static void test_file_protocol_url(LPCWSTR url)
534 {
535     IInternetProtocolInfo *protocol_info;
536     IUnknown *unk;
537     IClassFactory *factory;
538     HRESULT hres;
539
540     hres = CoGetClassObject(&CLSID_FileProtocol, CLSCTX_INPROC_SERVER, NULL,
541             &IID_IUnknown, (void**)&unk);
542     ok(hres == S_OK, "CoGetClassObject failed: %08x\n", hres);
543     if(!SUCCEEDED(hres))
544         return;
545
546     hres = IUnknown_QueryInterface(unk, &IID_IInternetProtocolInfo, (void**)&protocol_info);
547     ok(hres == E_NOINTERFACE,
548             "Could not get IInternetProtocolInfo interface: %08x, expected E_NOINTERFACE\n", hres);
549
550     hres = IUnknown_QueryInterface(unk, &IID_IClassFactory, (void**)&factory);
551     ok(hres == S_OK, "Could not get IClassFactory interface\n");
552     if(SUCCEEDED(hres)) {
553         IInternetProtocol *protocol;
554         BYTE buf[512];
555         ULONG cb;
556         hres = IClassFactory_CreateInstance(factory, NULL, &IID_IInternetProtocol, (void**)&protocol);
557         ok(hres == S_OK, "Could not get IInternetProtocol: %08x\n", hres);
558
559         if(SUCCEEDED(hres)) {
560             file_protocol_start(protocol, url, TRUE);
561             hres = IInternetProtocol_Read(protocol, buf, 2, &cb);
562             ok(hres == S_OK, "Read failed: %08x\n", hres);
563             ok(cb == 2, "cb=%u expected 2\n", cb);
564             hres = IInternetProtocol_Read(protocol, buf, sizeof(buf), &cb);
565             ok(hres == S_FALSE, "Read failed: %08x\n", hres);
566             hres = IInternetProtocol_Read(protocol, buf, sizeof(buf), &cb);
567             ok(hres == S_FALSE, "Read failed: %08x expected S_FALSE\n", hres);
568             ok(cb == 0, "cb=%u expected 0\n", cb);
569             hres = IInternetProtocol_UnlockRequest(protocol);
570             ok(hres == S_OK, "UnlockRequest failed: %08x\n", hres);
571
572             file_protocol_start(protocol, url, FALSE);
573             hres = IInternetProtocol_Read(protocol, buf, 2, &cb);
574             ok(hres == S_FALSE, "Read failed: %08x\n", hres);
575             hres = IInternetProtocol_LockRequest(protocol, 0);
576             ok(hres == S_OK, "LockRequest failed: %08x\n", hres);
577             hres = IInternetProtocol_UnlockRequest(protocol);
578             ok(hres == S_OK, "UnlockRequest failed: %08x\n", hres);
579
580             IInternetProtocol_Release(protocol);
581         }
582
583         hres = IClassFactory_CreateInstance(factory, NULL, &IID_IInternetProtocol, (void**)&protocol);
584         ok(hres == S_OK, "Could not get IInternetProtocol: %08x\n", hres);
585
586         if(SUCCEEDED(hres)) {
587             file_protocol_start(protocol, url, TRUE);
588             hres = IInternetProtocol_LockRequest(protocol, 0);
589             ok(hres == S_OK, "LockRequest failed: %08x\n", hres);
590             hres = IInternetProtocol_Terminate(protocol, 0);
591             ok(hres == S_OK, "Terminate failed: %08x\n", hres);
592             hres = IInternetProtocol_Read(protocol, buf, 2, &cb);
593             ok(hres == S_OK, "Read failed: %08x\n\n", hres);
594             hres = IInternetProtocol_UnlockRequest(protocol);
595             ok(hres == S_OK, "UnlockRequest failed: %08x\n", hres);
596             hres = IInternetProtocol_Read(protocol, buf, 2, &cb);
597             ok(hres == S_OK, "Read failed: %08x\n", hres);
598             hres = IInternetProtocol_Terminate(protocol, 0);
599             ok(hres == S_OK, "Terminate failed: %08x\n", hres);
600
601             IInternetProtocol_Release(protocol);
602         }
603
604         hres = IClassFactory_CreateInstance(factory, NULL, &IID_IInternetProtocol, (void**)&protocol);
605         ok(hres == S_OK, "Could not get IInternetProtocol: %08x\n", hres);
606
607         if(SUCCEEDED(hres)) {
608             file_protocol_start(protocol, url, TRUE);
609             hres = IInternetProtocol_Terminate(protocol, 0);
610             ok(hres == S_OK, "Terminate failed: %08x\n", hres);
611             hres = IInternetProtocol_Read(protocol, buf, 2, &cb);
612             ok(hres == S_OK, "Read failed: %08x\n", hres);
613             ok(cb == 2, "cb=%u expected 2\n", cb);
614
615             IInternetProtocol_Release(protocol);
616         }
617
618         IClassFactory_Release(factory);
619     }
620
621     IUnknown_Release(unk);
622 }
623
624 static void test_file_protocol_fail(void)
625 {
626     IInternetProtocol *protocol;
627     HRESULT hres;
628
629     static const WCHAR index_url2[] =
630         {'f','i','l','e',':','/','/','i','n','d','e','x','.','h','t','m','l',0};
631
632     hres = CoCreateInstance(&CLSID_FileProtocol, NULL, CLSCTX_INPROC_SERVER|CLSCTX_INPROC_HANDLER,
633             &IID_IInternetProtocol, (void**)&protocol);
634     ok(hres == S_OK, "CoCreateInstance failed: %08x\n", hres);
635     if(FAILED(hres))
636         return;
637
638     SET_EXPECT(GetBindInfo);
639     expect_hrResult = MK_E_SYNTAX;
640     hres = IInternetProtocol_Start(protocol, wszIndexHtml, &protocol_sink, &bind_info, 0, 0);
641     ok(hres == MK_E_SYNTAX, "Start failed: %08x, expected MK_E_SYNTAX\n", hres);
642     CHECK_CALLED(GetBindInfo);
643
644     SET_EXPECT(GetBindInfo);
645     if(!(bindf & BINDF_FROMURLMON))
646         SET_EXPECT(ReportProgress_DIRECTBIND);
647     SET_EXPECT(ReportProgress_SENDINGREQUEST);
648     SET_EXPECT(ReportResult);
649     expect_hrResult = INET_E_RESOURCE_NOT_FOUND;
650     hres = IInternetProtocol_Start(protocol, index_url, &protocol_sink, &bind_info, 0, 0);
651     ok(hres == INET_E_RESOURCE_NOT_FOUND,
652             "Start failed: %08x expected INET_E_RESOURCE_NOT_FOUND\n", hres);
653     CHECK_CALLED(GetBindInfo);
654     if(!(bindf & BINDF_FROMURLMON))
655         CHECK_CALLED(ReportProgress_DIRECTBIND);
656     CHECK_CALLED(ReportProgress_SENDINGREQUEST);
657     CHECK_CALLED(ReportResult);
658
659     IInternetProtocol_Release(protocol);
660
661     hres = CoCreateInstance(&CLSID_FileProtocol, NULL, CLSCTX_INPROC_SERVER|CLSCTX_INPROC_HANDLER,
662             &IID_IInternetProtocol, (void**)&protocol);
663     ok(hres == S_OK, "CoCreateInstance failed: %08x\n", hres);
664     if(FAILED(hres))
665         return;
666
667     SET_EXPECT(GetBindInfo);
668     if(!(bindf & BINDF_FROMURLMON))
669         SET_EXPECT(ReportProgress_DIRECTBIND);
670     SET_EXPECT(ReportProgress_SENDINGREQUEST);
671     SET_EXPECT(ReportResult);
672     expect_hrResult = INET_E_RESOURCE_NOT_FOUND;
673
674     hres = IInternetProtocol_Start(protocol, index_url2, &protocol_sink, &bind_info, 0, 0);
675     ok(hres == INET_E_RESOURCE_NOT_FOUND,
676             "Start failed: %08x, expected INET_E_RESOURCE_NOT_FOUND\n", hres);
677     CHECK_CALLED(GetBindInfo);
678     if(!(bindf & BINDF_FROMURLMON))
679         CHECK_CALLED(ReportProgress_DIRECTBIND);
680     CHECK_CALLED(ReportProgress_SENDINGREQUEST);
681     CHECK_CALLED(ReportResult);
682
683     IInternetProtocol_Release(protocol);
684 }
685
686 static void test_file_protocol(void) {
687     WCHAR buf[MAX_PATH];
688     DWORD size;
689     ULONG len;
690     HANDLE file;
691
692     static const WCHAR wszFile[] = {'f','i','l','e',':',0};
693     static const WCHAR wszFile2[] = {'f','i','l','e',':','/','/',0};
694     static const WCHAR wszFile3[] = {'f','i','l','e',':','/','/','/',0};
695     static const char html_doc[] = "<HTML></HTML>";
696
697     tested_protocol = FILE_TEST;
698
699     file = CreateFileW(wszIndexHtml, GENERIC_WRITE, 0, NULL, CREATE_ALWAYS,
700             FILE_ATTRIBUTE_NORMAL, NULL);
701     ok(file != INVALID_HANDLE_VALUE, "CreateFile failed\n");
702     if(file == INVALID_HANDLE_VALUE)
703         return;
704     WriteFile(file, html_doc, sizeof(html_doc)-1, &size, NULL);
705     CloseHandle(file);
706
707     file_name = wszIndexHtml;
708     bindf = 0;
709     test_file_protocol_url(index_url);
710     bindf = BINDF_FROMURLMON;
711     test_file_protocol_url(index_url);
712
713     memcpy(buf, wszFile, sizeof(wszFile));
714     len = sizeof(wszFile)/sizeof(WCHAR)-1;
715     len += GetCurrentDirectoryW(sizeof(buf)/sizeof(WCHAR)-len, buf+len);
716     buf[len++] = '\\';
717     memcpy(buf+len, wszIndexHtml, sizeof(wszIndexHtml));
718
719     file_name = buf + sizeof(wszFile)/sizeof(WCHAR)-1;
720     bindf = 0;
721     test_file_protocol_url(buf);
722     bindf = BINDF_FROMURLMON;
723     test_file_protocol_url(buf);
724
725     memcpy(buf, wszFile2, sizeof(wszFile2));
726     len = sizeof(wszFile2)/sizeof(WCHAR)-1;
727     len += GetCurrentDirectoryW(sizeof(buf)/sizeof(WCHAR)-len, buf+len);
728     buf[len++] = '\\';
729     memcpy(buf+len, wszIndexHtml, sizeof(wszIndexHtml));
730
731     file_name = buf + sizeof(wszFile2)/sizeof(WCHAR)-1;
732     bindf = 0;
733     test_file_protocol_url(buf);
734     bindf = BINDF_FROMURLMON;
735     test_file_protocol_url(buf);
736
737     memcpy(buf, wszFile3, sizeof(wszFile3));
738     len = sizeof(wszFile3)/sizeof(WCHAR)-1;
739     len += GetCurrentDirectoryW(sizeof(buf)/sizeof(WCHAR)-len, buf+len);
740     buf[len++] = '\\';
741     memcpy(buf+len, wszIndexHtml, sizeof(wszIndexHtml));
742
743     file_name = buf + sizeof(wszFile3)/sizeof(WCHAR)-1;
744     bindf = 0;
745     test_file_protocol_url(buf);
746     bindf = BINDF_FROMURLMON;
747     test_file_protocol_url(buf);
748
749     DeleteFileW(wszIndexHtml);
750
751     bindf = 0;
752     test_file_protocol_fail();
753     bindf = BINDF_FROMURLMON;
754     test_file_protocol_fail();
755 }
756
757 static BOOL http_protocol_start(LPCWSTR url, BOOL is_first)
758 {
759     HRESULT hres;
760
761     first_data_notif = TRUE;
762
763     SET_EXPECT(GetBindInfo);
764     SET_EXPECT(GetBindString_USER_AGENT);
765     SET_EXPECT(GetBindString_ACCEPT_MIMES);
766     SET_EXPECT(QueryService_HttpNegotiate);
767     SET_EXPECT(BeginningTransaction);
768     SET_EXPECT(GetRootSecurityId);
769
770     hres = IInternetProtocol_Start(http_protocol, url, &protocol_sink, &bind_info, 0, 0);
771     todo_wine {
772         ok(hres == S_OK, "Start failed: %08x\n", hres);
773     }
774     if(FAILED(hres))
775         return FALSE;
776
777     CHECK_CALLED(GetBindInfo);
778     CHECK_CALLED(GetBindString_USER_AGENT);
779     CHECK_CALLED(GetBindString_ACCEPT_MIMES);
780     CHECK_CALLED(QueryService_HttpNegotiate);
781     CHECK_CALLED(BeginningTransaction);
782     CHECK_CALLED(GetRootSecurityId);
783
784     return TRUE;
785 }
786
787 static void test_http_protocol_url(LPCWSTR url)
788 {
789     IInternetProtocolInfo *protocol_info;
790     IClassFactory *factory;
791     IUnknown *unk;
792     HRESULT hres;
793
794     http_url = url;
795
796     hres = CoGetClassObject(&CLSID_HttpProtocol, CLSCTX_INPROC_SERVER, NULL,
797             &IID_IUnknown, (void**)&unk);
798     ok(hres == S_OK, "CoGetClassObject failed: %08x\n", hres);
799     if(!SUCCEEDED(hres))
800         return;
801
802     hres = IUnknown_QueryInterface(unk, &IID_IInternetProtocolInfo, (void**)&protocol_info);
803     ok(hres == E_NOINTERFACE,
804         "Could not get IInternetProtocolInfo interface: %08x, expected E_NOINTERFACE\n",
805         hres);
806
807     hres = IUnknown_QueryInterface(unk, &IID_IClassFactory, (void**)&factory);
808     ok(hres == S_OK, "Could not get IClassFactory interface\n");
809     IUnknown_Release(unk);
810     if(FAILED(hres))
811         return;
812
813     hres = IClassFactory_CreateInstance(factory, NULL, &IID_IInternetProtocol,
814                                         (void**)&http_protocol);
815     ok(hres == S_OK, "Could not get IInternetProtocol: %08x\n", hres);
816     if(SUCCEEDED(hres)) {
817         BYTE buf[512];
818         DWORD cb;
819         MSG msg;
820
821         bindf = BINDF_ASYNCHRONOUS | BINDF_ASYNCSTORAGE | BINDF_PULLDATA | BINDF_FROMURLMON;
822
823         test_priority(http_protocol);
824
825         SET_EXPECT(ReportProgress_FINDINGRESOURCE);
826         SET_EXPECT(ReportProgress_CONNECTING);
827         SET_EXPECT(ReportProgress_SENDINGREQUEST);
828
829         if(!http_protocol_start(url, TRUE))
830             return;
831
832         hres = IInternetProtocol_Read(http_protocol, buf, 2, &cb);
833         ok(hres == E_PENDING, "Read failed: %08x, expected E_PENDING\n", hres);
834         ok(!cb, "cb=%d, expected 0\n", cb);
835
836         SET_EXPECT(Switch);
837         SET_EXPECT(ReportResult);
838         expect_hrResult = S_OK;
839
840         GetMessageW(&msg, NULL, 0, 0);
841
842         CHECK_CALLED(Switch);
843         CHECK_CALLED(ReportResult);
844
845         IInternetProtocol_Release(http_protocol);
846     }
847
848     IClassFactory_Release(factory);
849 }
850
851 static void test_http_protocol(void)
852 {
853     static const WCHAR winehq_url[] =
854         {'h','t','t','p',':','/','/','w','w','w','.','w','i','n','e','h','q','.',
855             'o','r','g','/','s','i','t','e','/','a','b','o','u','t',0};
856
857     tested_protocol = HTTP_TEST;
858     test_http_protocol_url(winehq_url);
859
860 }
861
862 static LRESULT WINAPI wnd_proc(HWND hwnd, UINT msg, WPARAM wParam, LPARAM lParam)
863 {
864     if(msg == WM_USER) {
865         HRESULT hres;
866         DWORD cb;
867         BYTE buf[3600];
868
869         SET_EXPECT(ReportData);
870         if(!state) {
871             CHECK_CALLED(ReportProgress_FINDINGRESOURCE);
872             CHECK_CALLED(ReportProgress_CONNECTING);
873             CHECK_CALLED(ReportProgress_SENDINGREQUEST);
874
875             SET_EXPECT(OnResponse);
876             SET_EXPECT(ReportProgress_MIMETYPEAVAILABLE);
877         }
878
879         hres = IInternetProtocol_Continue(http_protocol, (PROTOCOLDATA*)lParam);
880         ok(hres == S_OK, "Continue failed: %08x\n", hres);
881
882         CHECK_CALLED(ReportData);
883         if(!state) {
884             CHECK_CALLED(OnResponse);
885             CHECK_CALLED(ReportProgress_MIMETYPEAVAILABLE);
886         }
887
888         do hres = IInternetProtocol_Read(http_protocol, buf, sizeof(buf), &cb);
889         while(cb);
890
891         ok(hres == S_FALSE || hres == E_PENDING, "Read failed: %08x\n", hres);
892
893         if(hres == S_FALSE)
894             PostMessageW(protocol_hwnd, WM_USER+1, 0, 0);
895
896         if(!state) {
897             state = 1;
898
899             hres = IInternetProtocol_LockRequest(http_protocol, 0);
900             ok(hres == S_OK, "LockRequest failed: %08x\n", hres);
901
902             do hres = IInternetProtocol_Read(http_protocol, buf, sizeof(buf), &cb);
903             while(cb);
904             ok(hres == S_FALSE || hres == E_PENDING, "Read failed: %08x\n", hres);
905         }
906     }
907
908     return DefWindowProc(hwnd, msg, wParam, lParam);
909 }
910
911 static HWND create_protocol_window(void)
912 {
913     static const WCHAR wszProtocolWindow[] =
914         {'P','r','o','t','o','c','o','l','W','i','n','d','o','w',0};
915     static WNDCLASSEXW wndclass = {
916         sizeof(WNDCLASSEXW),
917         0,
918         wnd_proc,
919         0, 0, NULL, NULL, NULL, NULL, NULL,
920         wszProtocolWindow,
921         NULL
922     };
923
924     RegisterClassExW(&wndclass);
925     return CreateWindowW(wszProtocolWindow, wszProtocolWindow,
926                          WS_OVERLAPPEDWINDOW, CW_USEDEFAULT, CW_USEDEFAULT, CW_USEDEFAULT,
927                          CW_USEDEFAULT, NULL, NULL, NULL, NULL);
928 }
929
930 START_TEST(protocol)
931 {
932     OleInitialize(NULL);
933
934     protocol_hwnd = create_protocol_window();
935
936     test_file_protocol();
937     test_http_protocol();
938
939     DestroyWindow(protocol_hwnd);
940
941     OleUninitialize();
942 }