urlmon: Added mk protocol implementation.
[wine] / dlls / urlmon / tests / protocol.c
1 /*
2  * Copyright 2005 Jacek Caban
3  *
4  * This library is free software; you can redistribute it and/or
5  * modify it under the terms of the GNU Lesser General Public
6  * License as published by the Free Software Foundation; either
7  * version 2.1 of the License, or (at your option) any later version.
8  *
9  * This library is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
12  * Lesser General Public License for more details.
13  *
14  * You should have received a copy of the GNU Lesser General Public
15  * License along with this library; if not, write to the Free Software
16  * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
17  */
18
19 #define COBJMACROS
20 #define CONST_VTABLE
21
22 #include <wine/test.h>
23 #include <stdarg.h>
24
25 #include "windef.h"
26 #include "winbase.h"
27 #include "ole2.h"
28 #include "urlmon.h"
29
30 #include "initguid.h"
31
/*
 * Call-expectation bookkeeping for the mock COM objects in this file.
 *
 * DEFINE_EXPECT(f) declares two file-scope flags:
 *   expect_f - set by the test before it triggers the callback,
 *   called_f - set by the callback when it actually runs.
 */
#define DEFINE_EXPECT(func) \
    static BOOL expect_##func = FALSE, called_##func = FALSE

/* Announce that callback `func` may be invoked. */
#define SET_EXPECT(func) \
    (expect_##func = TRUE)

/* Record a call to `func` that must have been expected; the expectation stays
 * armed, so the callback may legitimately run several times. */
#define CHECK_EXPECT2(func) \
    do { \
        ok(expect_##func, "unexpected call " #func "\n"); \
        called_##func = TRUE; \
    } while(0)

/* Like CHECK_EXPECT2, but consumes the expectation: a second call is reported. */
#define CHECK_EXPECT(func) \
    do { \
        CHECK_EXPECT2(func); \
        expect_##func = FALSE; \
    } while(0)

/* Verify that `func` ran since its expectation was set, then clear both flags. */
#define CHECK_CALLED(func) \
    do { \
        ok(called_##func, "expected " #func "\n"); \
        expect_##func = FALSE; \
        called_##func = FALSE; \
    } while(0)
56
57 DEFINE_EXPECT(GetBindInfo);
58 DEFINE_EXPECT(ReportProgress_MIMETYPEAVAILABLE);
59 DEFINE_EXPECT(ReportProgress_DIRECTBIND);
60 DEFINE_EXPECT(ReportProgress_FINDINGRESOURCE);
61 DEFINE_EXPECT(ReportProgress_CONNECTING);
62 DEFINE_EXPECT(ReportProgress_SENDINGREQUEST);
63 DEFINE_EXPECT(ReportProgress_CACHEFILENAMEAVAILABLE);
64 DEFINE_EXPECT(ReportProgress_VERIFIEDMIMETYPEAVAILABLE);
65 DEFINE_EXPECT(ReportData);
66 DEFINE_EXPECT(ReportResult);
67 DEFINE_EXPECT(GetBindString_ACCEPT_MIMES);
68 DEFINE_EXPECT(GetBindString_USER_AGENT);
69 DEFINE_EXPECT(QueryService_HttpNegotiate);
70 DEFINE_EXPECT(BeginningTransaction);
71 DEFINE_EXPECT(GetRootSecurityId);
72 DEFINE_EXPECT(OnResponse);
73 DEFINE_EXPECT(Switch);
74
75 static const WCHAR wszIndexHtml[] = {'i','n','d','e','x','.','h','t','m','l',0};
76 static const WCHAR index_url[] =
77     {'f','i','l','e',':','i','n','d','e','x','.','h','t','m','l',0};
78
79 static HRESULT expect_hrResult;
80 static LPCWSTR file_name, http_url;
81 static IInternetProtocol *http_protocol = NULL;
82 static BOOL first_data_notif = FALSE;
83 static HWND protocol_hwnd;
84 static int state = 0;
85 static DWORD bindf = 0;
86
87 static enum {
88     FILE_TEST,
89     HTTP_TEST,
90     MK_TEST
91 } tested_protocol;
92
93 static HRESULT WINAPI HttpNegotiate_QueryInterface(IHttpNegotiate2 *iface, REFIID riid, void **ppv)
94 {
95     if(IsEqualGUID(&IID_IUnknown, riid)
96             || IsEqualGUID(&IID_IHttpNegotiate, riid)
97             || IsEqualGUID(&IID_IHttpNegotiate2, riid)) {
98         *ppv = iface;
99         return S_OK;
100     }
101
102     ok(0, "unexpected call\n");
103     return E_NOINTERFACE;
104 }
105
106 static ULONG WINAPI HttpNegotiate_AddRef(IHttpNegotiate2 *iface)
107 {
108     return 2;
109 }
110
111 static ULONG WINAPI HttpNegotiate_Release(IHttpNegotiate2 *iface)
112 {
113     return 1;
114 }
115
116 static HRESULT WINAPI HttpNegotiate_BeginningTransaction(IHttpNegotiate2 *iface, LPCWSTR szURL,
117         LPCWSTR szHeaders, DWORD dwReserved, LPWSTR *pszAdditionalHeaders)
118 {
119     CHECK_EXPECT(BeginningTransaction);
120
121     ok(!lstrcmpW(szURL, http_url), "szURL != http_url\n");
122     ok(!dwReserved, "dwReserved=%d, expected 0\n", dwReserved);
123     ok(pszAdditionalHeaders != NULL, "pszAdditionalHeaders == NULL\n");
124     if(pszAdditionalHeaders)
125         ok(*pszAdditionalHeaders == NULL, "*pszAdditionalHeaders != NULL\n");
126
127     return S_OK;
128 }
129
130 static HRESULT WINAPI HttpNegotiate_OnResponse(IHttpNegotiate2 *iface, DWORD dwResponseCode,
131         LPCWSTR szResponseHeaders, LPCWSTR szRequestHeaders, LPWSTR *pszAdditionalRequestHeaders)
132 {
133     CHECK_EXPECT(OnResponse);
134
135     ok(dwResponseCode == 200, "dwResponseCode=%d, expected 200\n", dwResponseCode);
136     ok(szResponseHeaders != NULL, "szResponseHeaders == NULL\n");
137     ok(szRequestHeaders == NULL, "szRequestHeaders != NULL\n");
138     ok(pszAdditionalRequestHeaders == NULL, "pszAdditionalHeaders != NULL\n");
139
140     return S_OK;
141 }
142
143 static HRESULT WINAPI HttpNegotiate_GetRootSecurityId(IHttpNegotiate2 *iface,
144         BYTE *pbSecurityId, DWORD *pcbSecurityId, DWORD_PTR dwReserved)
145 {
146     static const BYTE sec_id[] = {'h','t','t','p',':','t','e','s','t',1,0,0,0};
147     
148     CHECK_EXPECT(GetRootSecurityId);
149
150     ok(!dwReserved, "dwReserved=%ld, expected 0\n", dwReserved);
151     ok(pbSecurityId != NULL, "pbSecurityId == NULL\n");
152     ok(pcbSecurityId != NULL, "pcbSecurityId == NULL\n");
153
154     if(pcbSecurityId) {
155         ok(*pcbSecurityId == 512, "*pcbSecurityId=%d, expected 512\n", *pcbSecurityId);
156         *pcbSecurityId = sizeof(sec_id);
157     }
158
159     if(pbSecurityId)
160         memcpy(pbSecurityId, sec_id, sizeof(sec_id));
161
162     return E_FAIL;
163 }
164
165 static IHttpNegotiate2Vtbl HttpNegotiateVtbl = {
166     HttpNegotiate_QueryInterface,
167     HttpNegotiate_AddRef,
168     HttpNegotiate_Release,
169     HttpNegotiate_BeginningTransaction,
170     HttpNegotiate_OnResponse,
171     HttpNegotiate_GetRootSecurityId
172 };
173
174 static IHttpNegotiate2 http_negotiate = { &HttpNegotiateVtbl };
175
176 static HRESULT QueryInterface(REFIID,void**);
177
178 static HRESULT WINAPI ServiceProvider_QueryInterface(IServiceProvider *iface, REFIID riid, void **ppv)
179 {
180     return QueryInterface(riid, ppv);
181 }
182
183 static ULONG WINAPI ServiceProvider_AddRef(IServiceProvider *iface)
184 {
185     return 2;
186 }
187
188 static ULONG WINAPI ServiceProvider_Release(IServiceProvider *iface)
189 {
190     return 1;
191 }
192
193 static HRESULT WINAPI ServiceProvider_QueryService(IServiceProvider *iface, REFGUID guidService,
194         REFIID riid, void **ppv)
195 {
196     if(IsEqualGUID(&IID_IHttpNegotiate, guidService) || IsEqualGUID(&IID_IHttpNegotiate2, riid)) {
197         CHECK_EXPECT2(QueryService_HttpNegotiate);
198         return IHttpNegotiate2_QueryInterface(&http_negotiate, riid, ppv);
199     }
200
201     ok(0, "unexpected call\n");
202     return E_FAIL;
203 }
204
205 static const IServiceProviderVtbl ServiceProviderVtbl = {
206     ServiceProvider_QueryInterface,
207     ServiceProvider_AddRef,
208     ServiceProvider_Release,
209     ServiceProvider_QueryService
210 };
211
212 static IServiceProvider service_provider = { &ServiceProviderVtbl };
213
214 static HRESULT WINAPI ProtocolSink_QueryInterface(IInternetProtocolSink *iface, REFIID riid, void **ppv)
215 {
216     return QueryInterface(riid, ppv);
217 }
218
219 static ULONG WINAPI ProtocolSink_AddRef(IInternetProtocolSink *iface)
220 {
221     return 2;
222 }
223
224 static ULONG WINAPI ProtocolSink_Release(IInternetProtocolSink *iface)
225 {
226     return 1;
227 }
228
229 static HRESULT WINAPI ProtocolSink_Switch(IInternetProtocolSink *iface, PROTOCOLDATA *pProtocolData)
230 {
231     CHECK_EXPECT2(Switch);
232     ok(pProtocolData != NULL, "pProtocolData == NULL\n");
233     SendMessageW(protocol_hwnd, WM_USER, 0, (LPARAM)pProtocolData);
234     return S_OK;
235 }
236
237 static HRESULT WINAPI ProtocolSink_ReportProgress(IInternetProtocolSink *iface, ULONG ulStatusCode,
238         LPCWSTR szStatusText)
239 {
240     static const WCHAR text_html[] = {'t','e','x','t','/','h','t','m','l',0};
241     static const WCHAR host[] =
242         {'w','w','w','.','w','i','n','e','h','q','.','o','r','g',0};
243     static const WCHAR wszWineHQIP[] =
244         {'2','0','9','.','4','6','.','2','5','.','1','3','4',0};
245     /* I'm not sure if it's a good idea to hardcode here the IP address... */
246
247     switch(ulStatusCode) {
248     case BINDSTATUS_MIMETYPEAVAILABLE:
249         CHECK_EXPECT(ReportProgress_MIMETYPEAVAILABLE);
250         ok(szStatusText != NULL, "szStatusText == NULL\n");
251         if(szStatusText)
252             ok(!lstrcmpW(szStatusText, text_html), "szStatusText != text/html\n");
253         break;
254     case BINDSTATUS_DIRECTBIND:
255         CHECK_EXPECT2(ReportProgress_DIRECTBIND);
256         ok(szStatusText == NULL, "szStatusText != NULL\n");
257         break;
258     case BINDSTATUS_CACHEFILENAMEAVAILABLE:
259         CHECK_EXPECT(ReportProgress_CACHEFILENAMEAVAILABLE);
260         ok(szStatusText != NULL, "szStatusText == NULL\n");
261         if(szStatusText)
262             ok(!lstrcmpW(szStatusText, file_name), "szStatusText != file_name\n");
263         break;
264     case BINDSTATUS_FINDINGRESOURCE:
265         CHECK_EXPECT(ReportProgress_FINDINGRESOURCE);
266         ok(szStatusText != NULL, "szStatusText == NULL\n");
267         if(szStatusText)
268             ok(!lstrcmpW(szStatusText, host), "szStatustext != \"www.winehq.org\"\n");
269         break;
270     case BINDSTATUS_CONNECTING:
271         CHECK_EXPECT(ReportProgress_CONNECTING);
272         ok(szStatusText != NULL, "szStatusText == NULL\n");
273         if(szStatusText)
274             ok(!lstrcmpW(szStatusText, wszWineHQIP), "Unexpected szStatusText\n");
275         break;
276     case BINDSTATUS_SENDINGREQUEST:
277         CHECK_EXPECT(ReportProgress_SENDINGREQUEST);
278         if(tested_protocol == FILE_TEST) {
279             ok(szStatusText != NULL, "szStatusText == NULL\n");
280             if(szStatusText)
281                 ok(!*szStatusText, "wrong szStatusText\n");
282         }else {
283             ok(szStatusText == NULL, "szStatusText != NULL\n");
284         }
285         break;
286     case BINDSTATUS_VERIFIEDMIMETYPEAVAILABLE:
287         CHECK_EXPECT(ReportProgress_VERIFIEDMIMETYPEAVAILABLE);
288         ok(szStatusText != NULL, "szStatusText == NULL\n");
289         if(szStatusText)
290             ok(!lstrcmpW(szStatusText, text_html), "szStatusText != text/html\n");
291         break;
292     default:
293         ok(0, "Unexpected call %d\n", ulStatusCode);
294     };
295
296     return S_OK;
297 }
298
299 static HRESULT WINAPI ProtocolSink_ReportData(IInternetProtocolSink *iface, DWORD grfBSCF,
300         ULONG ulProgress, ULONG ulProgressMax)
301 {
302     if(tested_protocol == FILE_TEST) {
303         CHECK_EXPECT2(ReportData);
304
305         ok(ulProgress == ulProgressMax, "ulProgress (%d) != ulProgressMax (%d)\n",
306            ulProgress, ulProgressMax);
307         ok(ulProgressMax == 13, "ulProgressMax=%d, expected 13\n", ulProgressMax);
308         ok(grfBSCF == (BSCF_FIRSTDATANOTIFICATION | BSCF_LASTDATANOTIFICATION),
309                 "grcfBSCF = %08x\n", grfBSCF);
310     }else if(tested_protocol == HTTP_TEST) {
311         if(!(grfBSCF & BSCF_LASTDATANOTIFICATION))
312             CHECK_EXPECT(ReportData);
313
314         ok(ulProgress, "ulProgress == 0\n");
315
316         if(first_data_notif) {
317             ok(grfBSCF == BSCF_FIRSTDATANOTIFICATION, "grcfBSCF = %08x\n", grfBSCF);
318             first_data_notif = FALSE;
319         } else {
320             ok(grfBSCF == BSCF_INTERMEDIATEDATANOTIFICATION
321                || grfBSCF == (BSCF_LASTDATANOTIFICATION|BSCF_INTERMEDIATEDATANOTIFICATION),
322                "grcfBSCF = %08x\n", grfBSCF);
323         }
324     }
325     return S_OK;
326 }
327
328 static HRESULT WINAPI ProtocolSink_ReportResult(IInternetProtocolSink *iface, HRESULT hrResult,
329         DWORD dwError, LPCWSTR szResult)
330 {
331     CHECK_EXPECT(ReportResult);
332
333     ok(hrResult == expect_hrResult, "hrResult = %08x, expected: %08x\n",
334             hrResult, expect_hrResult);
335     if(SUCCEEDED(hrResult))
336         ok(dwError == ERROR_SUCCESS, "dwError = %d, expected ERROR_SUCCESS\n", dwError);
337     else
338         ok(dwError != ERROR_SUCCESS, "dwError == ERROR_SUCCESS\n");
339     ok(!szResult, "szResult != NULL\n");
340
341     return S_OK;
342 }
343
344 static IInternetProtocolSinkVtbl protocol_sink_vtbl = {
345     ProtocolSink_QueryInterface,
346     ProtocolSink_AddRef,
347     ProtocolSink_Release,
348     ProtocolSink_Switch,
349     ProtocolSink_ReportProgress,
350     ProtocolSink_ReportData,
351     ProtocolSink_ReportResult
352 };
353
354 static IInternetProtocolSink protocol_sink = { &protocol_sink_vtbl };
355
356 static HRESULT QueryInterface(REFIID riid, void **ppv)
357 {
358     *ppv = NULL;
359
360     if(IsEqualGUID(&IID_IUnknown, riid) || IsEqualGUID(&IID_IInternetProtocolSink, riid))
361         *ppv = &protocol_sink;
362     if(IsEqualGUID(&IID_IServiceProvider, riid))
363         *ppv = &service_provider;
364
365     if(*ppv)
366         return S_OK;
367
368     return E_NOINTERFACE;
369 }
370
371 static HRESULT WINAPI BindInfo_QueryInterface(IInternetBindInfo *iface, REFIID riid, void **ppv)
372 {
373     if(IsEqualGUID(&IID_IUnknown, riid) || IsEqualGUID(&IID_IInternetBindInfo, riid)) {
374         *ppv = iface;
375         return S_OK;
376     }
377     return E_NOINTERFACE;
378 }
379
380 static ULONG WINAPI BindInfo_AddRef(IInternetBindInfo *iface)
381 {
382     return 2;
383 }
384
385 static ULONG WINAPI BindInfo_Release(IInternetBindInfo *iface)
386 {
387     return 1;
388 }
389
390 static HRESULT WINAPI BindInfo_GetBindInfo(IInternetBindInfo *iface, DWORD *grfBINDF, BINDINFO *pbindinfo)
391 {
392     DWORD cbSize;
393
394     CHECK_EXPECT(GetBindInfo);
395
396     ok(grfBINDF != NULL, "grfBINDF == NULL\n");
397     ok(pbindinfo != NULL, "pbindinfo == NULL\n");
398     ok(pbindinfo->cbSize == sizeof(BINDINFO), "wrong size of pbindinfo: %d\n", pbindinfo->cbSize);
399
400     *grfBINDF = bindf;
401     cbSize = pbindinfo->cbSize;
402     memset(pbindinfo, 0, cbSize);
403     pbindinfo->cbSize = cbSize;
404
405     return S_OK;
406 }
407
408 static HRESULT WINAPI BindInfo_GetBindString(IInternetBindInfo *iface, ULONG ulStringType,
409         LPOLESTR *ppwzStr, ULONG cEl, ULONG *pcElFetched)
410 {
411     static const WCHAR acc_mime[] = {'*','/','*',0};
412     static const WCHAR user_agent[] = {'W','i','n','e',0};
413
414     ok(ppwzStr != NULL, "ppwzStr == NULL\n");
415     ok(pcElFetched != NULL, "pcElFetched == NULL\n");
416
417     switch(ulStringType) {
418     case BINDSTRING_ACCEPT_MIMES:
419         CHECK_EXPECT(GetBindString_ACCEPT_MIMES);
420         ok(cEl == 256, "cEl=%d, expected 256\n", cEl);
421         if(pcElFetched) {
422             ok(*pcElFetched == 256, "*pcElFetched=%d, expected 256\n", *pcElFetched);
423             *pcElFetched = 1;
424         }
425         if(ppwzStr) {
426             *ppwzStr = CoTaskMemAlloc(sizeof(acc_mime));
427             memcpy(*ppwzStr, acc_mime, sizeof(acc_mime));
428         }
429         return S_OK;
430     case BINDSTRING_USER_AGENT:
431         CHECK_EXPECT(GetBindString_USER_AGENT);
432         ok(cEl == 1, "cEl=%d, expected 1\n", cEl);
433         if(pcElFetched) {
434             ok(*pcElFetched == 0, "*pcElFetch=%d, expectd 0\n", *pcElFetched);
435             *pcElFetched = 1;
436         }
437         if(ppwzStr) {
438             *ppwzStr = CoTaskMemAlloc(sizeof(user_agent));
439             memcpy(*ppwzStr, user_agent, sizeof(user_agent));
440         }
441         return S_OK;
442     default:
443         ok(0, "unexpected call\n");
444     }
445
446     return E_NOTIMPL;
447 }
448
449 static IInternetBindInfoVtbl bind_info_vtbl = {
450     BindInfo_QueryInterface,
451     BindInfo_AddRef,
452     BindInfo_Release,
453     BindInfo_GetBindInfo,
454     BindInfo_GetBindString
455 };
456
457 static IInternetBindInfo bind_info = { &bind_info_vtbl };
458
459 static void test_priority(IInternetProtocol *protocol)
460 {
461     IInternetPriority *priority;
462     LONG pr;
463     HRESULT hres;
464
465     hres = IInternetProtocol_QueryInterface(protocol, &IID_IInternetPriority,
466                                             (void**)&priority);
467     ok(hres == S_OK, "QueryInterface(IID_IInternetPriority) failed: %08x\n", hres);
468     if(FAILED(hres))
469         return;
470
471     hres = IInternetPriority_GetPriority(priority, &pr);
472     ok(hres == S_OK, "GetPriority failed: %08x\n", hres);
473     ok(pr == 0, "pr=%d, expected 0\n", pr);
474
475     hres = IInternetPriority_SetPriority(priority, 1);
476     ok(hres == S_OK, "SetPriority failed: %08x\n", hres);
477
478     hres = IInternetPriority_GetPriority(priority, &pr);
479     ok(hres == S_OK, "GetPriority failed: %08x\n", hres);
480     ok(pr == 1, "pr=%d, expected 1\n", pr);
481
482     IInternetPriority_Release(priority);
483 }
484
485 static void file_protocol_start(IInternetProtocol *protocol, LPCWSTR url, BOOL is_first)
486 {
487     HRESULT hres;
488
489     SET_EXPECT(GetBindInfo);
490     if(!(bindf & BINDF_FROMURLMON))
491        SET_EXPECT(ReportProgress_DIRECTBIND);
492     if(is_first) {
493         SET_EXPECT(ReportProgress_SENDINGREQUEST);
494         SET_EXPECT(ReportProgress_CACHEFILENAMEAVAILABLE);
495         if(bindf & BINDF_FROMURLMON)
496             SET_EXPECT(ReportProgress_VERIFIEDMIMETYPEAVAILABLE);
497         else
498             SET_EXPECT(ReportProgress_MIMETYPEAVAILABLE);
499     }
500     SET_EXPECT(ReportData);
501     if(is_first)
502         SET_EXPECT(ReportResult);
503
504     expect_hrResult = S_OK;
505
506     hres = IInternetProtocol_Start(protocol, url, &protocol_sink, &bind_info, 0, 0);
507     ok(hres == S_OK, "Start failed: %08x\n", hres);
508
509     CHECK_CALLED(GetBindInfo);
510     if(!(bindf & BINDF_FROMURLMON))
511        CHECK_CALLED(ReportProgress_DIRECTBIND);
512     if(is_first) {
513         CHECK_CALLED(ReportProgress_SENDINGREQUEST);
514         CHECK_CALLED(ReportProgress_CACHEFILENAMEAVAILABLE);
515         if(bindf & BINDF_FROMURLMON)
516             CHECK_CALLED(ReportProgress_VERIFIEDMIMETYPEAVAILABLE);
517         else
518             CHECK_CALLED(ReportProgress_MIMETYPEAVAILABLE);
519     }
520     CHECK_CALLED(ReportData);
521     if(is_first)
522         CHECK_CALLED(ReportResult);
523 }
524
525 static void test_file_protocol_url(LPCWSTR url)
526 {
527     IInternetProtocolInfo *protocol_info;
528     IUnknown *unk;
529     IClassFactory *factory;
530     HRESULT hres;
531
532     hres = CoGetClassObject(&CLSID_FileProtocol, CLSCTX_INPROC_SERVER, NULL,
533             &IID_IUnknown, (void**)&unk);
534     ok(hres == S_OK, "CoGetClassObject failed: %08x\n", hres);
535     if(!SUCCEEDED(hres))
536         return;
537
538     hres = IUnknown_QueryInterface(unk, &IID_IInternetProtocolInfo, (void**)&protocol_info);
539     ok(hres == E_NOINTERFACE,
540             "Could not get IInternetProtocolInfo interface: %08x, expected E_NOINTERFACE\n", hres);
541
542     hres = IUnknown_QueryInterface(unk, &IID_IClassFactory, (void**)&factory);
543     ok(hres == S_OK, "Could not get IClassFactory interface\n");
544     if(SUCCEEDED(hres)) {
545         IInternetProtocol *protocol;
546         BYTE buf[512];
547         ULONG cb;
548         hres = IClassFactory_CreateInstance(factory, NULL, &IID_IInternetProtocol, (void**)&protocol);
549         ok(hres == S_OK, "Could not get IInternetProtocol: %08x\n", hres);
550
551         if(SUCCEEDED(hres)) {
552             file_protocol_start(protocol, url, TRUE);
553             hres = IInternetProtocol_Read(protocol, buf, 2, &cb);
554             ok(hres == S_OK, "Read failed: %08x\n", hres);
555             ok(cb == 2, "cb=%u expected 2\n", cb);
556             hres = IInternetProtocol_Read(protocol, buf, sizeof(buf), &cb);
557             ok(hres == S_FALSE, "Read failed: %08x\n", hres);
558             hres = IInternetProtocol_Read(protocol, buf, sizeof(buf), &cb);
559             ok(hres == S_FALSE, "Read failed: %08x expected S_FALSE\n", hres);
560             ok(cb == 0, "cb=%u expected 0\n", cb);
561             hres = IInternetProtocol_UnlockRequest(protocol);
562             ok(hres == S_OK, "UnlockRequest failed: %08x\n", hres);
563
564             file_protocol_start(protocol, url, FALSE);
565             hres = IInternetProtocol_Read(protocol, buf, 2, &cb);
566             ok(hres == S_FALSE, "Read failed: %08x\n", hres);
567             hres = IInternetProtocol_LockRequest(protocol, 0);
568             ok(hres == S_OK, "LockRequest failed: %08x\n", hres);
569             hres = IInternetProtocol_UnlockRequest(protocol);
570             ok(hres == S_OK, "UnlockRequest failed: %08x\n", hres);
571
572             IInternetProtocol_Release(protocol);
573         }
574
575         hres = IClassFactory_CreateInstance(factory, NULL, &IID_IInternetProtocol, (void**)&protocol);
576         ok(hres == S_OK, "Could not get IInternetProtocol: %08x\n", hres);
577
578         if(SUCCEEDED(hres)) {
579             file_protocol_start(protocol, url, TRUE);
580             hres = IInternetProtocol_LockRequest(protocol, 0);
581             ok(hres == S_OK, "LockRequest failed: %08x\n", hres);
582             hres = IInternetProtocol_Terminate(protocol, 0);
583             ok(hres == S_OK, "Terminate failed: %08x\n", hres);
584             hres = IInternetProtocol_Read(protocol, buf, 2, &cb);
585             ok(hres == S_OK, "Read failed: %08x\n\n", hres);
586             hres = IInternetProtocol_UnlockRequest(protocol);
587             ok(hres == S_OK, "UnlockRequest failed: %08x\n", hres);
588             hres = IInternetProtocol_Read(protocol, buf, 2, &cb);
589             ok(hres == S_OK, "Read failed: %08x\n", hres);
590             hres = IInternetProtocol_Terminate(protocol, 0);
591             ok(hres == S_OK, "Terminate failed: %08x\n", hres);
592
593             IInternetProtocol_Release(protocol);
594         }
595
596         hres = IClassFactory_CreateInstance(factory, NULL, &IID_IInternetProtocol, (void**)&protocol);
597         ok(hres == S_OK, "Could not get IInternetProtocol: %08x\n", hres);
598
599         if(SUCCEEDED(hres)) {
600             file_protocol_start(protocol, url, TRUE);
601             hres = IInternetProtocol_Terminate(protocol, 0);
602             ok(hres == S_OK, "Terminate failed: %08x\n", hres);
603             hres = IInternetProtocol_Read(protocol, buf, 2, &cb);
604             ok(hres == S_OK, "Read failed: %08x\n", hres);
605             ok(cb == 2, "cb=%u expected 2\n", cb);
606
607             IInternetProtocol_Release(protocol);
608         }
609
610         IClassFactory_Release(factory);
611     }
612
613     IUnknown_Release(unk);
614 }
615
616 static void test_file_protocol_fail(void)
617 {
618     IInternetProtocol *protocol;
619     HRESULT hres;
620
621     static const WCHAR index_url2[] =
622         {'f','i','l','e',':','/','/','i','n','d','e','x','.','h','t','m','l',0};
623
624     hres = CoCreateInstance(&CLSID_FileProtocol, NULL, CLSCTX_INPROC_SERVER|CLSCTX_INPROC_HANDLER,
625             &IID_IInternetProtocol, (void**)&protocol);
626     ok(hres == S_OK, "CoCreateInstance failed: %08x\n", hres);
627     if(FAILED(hres))
628         return;
629
630     SET_EXPECT(GetBindInfo);
631     expect_hrResult = MK_E_SYNTAX;
632     hres = IInternetProtocol_Start(protocol, wszIndexHtml, &protocol_sink, &bind_info, 0, 0);
633     ok(hres == MK_E_SYNTAX, "Start failed: %08x, expected MK_E_SYNTAX\n", hres);
634     CHECK_CALLED(GetBindInfo);
635
636     SET_EXPECT(GetBindInfo);
637     if(!(bindf & BINDF_FROMURLMON))
638         SET_EXPECT(ReportProgress_DIRECTBIND);
639     SET_EXPECT(ReportProgress_SENDINGREQUEST);
640     SET_EXPECT(ReportResult);
641     expect_hrResult = INET_E_RESOURCE_NOT_FOUND;
642     hres = IInternetProtocol_Start(protocol, index_url, &protocol_sink, &bind_info, 0, 0);
643     ok(hres == INET_E_RESOURCE_NOT_FOUND,
644             "Start failed: %08x expected INET_E_RESOURCE_NOT_FOUND\n", hres);
645     CHECK_CALLED(GetBindInfo);
646     if(!(bindf & BINDF_FROMURLMON))
647         CHECK_CALLED(ReportProgress_DIRECTBIND);
648     CHECK_CALLED(ReportProgress_SENDINGREQUEST);
649     CHECK_CALLED(ReportResult);
650
651     IInternetProtocol_Release(protocol);
652
653     hres = CoCreateInstance(&CLSID_FileProtocol, NULL, CLSCTX_INPROC_SERVER|CLSCTX_INPROC_HANDLER,
654             &IID_IInternetProtocol, (void**)&protocol);
655     ok(hres == S_OK, "CoCreateInstance failed: %08x\n", hres);
656     if(FAILED(hres))
657         return;
658
659     SET_EXPECT(GetBindInfo);
660     if(!(bindf & BINDF_FROMURLMON))
661         SET_EXPECT(ReportProgress_DIRECTBIND);
662     SET_EXPECT(ReportProgress_SENDINGREQUEST);
663     SET_EXPECT(ReportResult);
664     expect_hrResult = INET_E_RESOURCE_NOT_FOUND;
665
666     hres = IInternetProtocol_Start(protocol, index_url2, &protocol_sink, &bind_info, 0, 0);
667     ok(hres == INET_E_RESOURCE_NOT_FOUND,
668             "Start failed: %08x, expected INET_E_RESOURCE_NOT_FOUND\n", hres);
669     CHECK_CALLED(GetBindInfo);
670     if(!(bindf & BINDF_FROMURLMON))
671         CHECK_CALLED(ReportProgress_DIRECTBIND);
672     CHECK_CALLED(ReportProgress_SENDINGREQUEST);
673     CHECK_CALLED(ReportResult);
674
675     IInternetProtocol_Release(protocol);
676 }
677
678 static void test_file_protocol(void) {
679     WCHAR buf[MAX_PATH];
680     DWORD size;
681     ULONG len;
682     HANDLE file;
683
684     static const WCHAR wszFile[] = {'f','i','l','e',':',0};
685     static const WCHAR wszFile2[] = {'f','i','l','e',':','/','/',0};
686     static const WCHAR wszFile3[] = {'f','i','l','e',':','/','/','/',0};
687     static const char html_doc[] = "<HTML></HTML>";
688
689     tested_protocol = FILE_TEST;
690
691     file = CreateFileW(wszIndexHtml, GENERIC_WRITE, 0, NULL, CREATE_ALWAYS,
692             FILE_ATTRIBUTE_NORMAL, NULL);
693     ok(file != INVALID_HANDLE_VALUE, "CreateFile failed\n");
694     if(file == INVALID_HANDLE_VALUE)
695         return;
696     WriteFile(file, html_doc, sizeof(html_doc)-1, &size, NULL);
697     CloseHandle(file);
698
699     file_name = wszIndexHtml;
700     bindf = 0;
701     test_file_protocol_url(index_url);
702     bindf = BINDF_FROMURLMON;
703     test_file_protocol_url(index_url);
704
705     memcpy(buf, wszFile, sizeof(wszFile));
706     len = sizeof(wszFile)/sizeof(WCHAR)-1;
707     len += GetCurrentDirectoryW(sizeof(buf)/sizeof(WCHAR)-len, buf+len);
708     buf[len++] = '\\';
709     memcpy(buf+len, wszIndexHtml, sizeof(wszIndexHtml));
710
711     file_name = buf + sizeof(wszFile)/sizeof(WCHAR)-1;
712     bindf = 0;
713     test_file_protocol_url(buf);
714     bindf = BINDF_FROMURLMON;
715     test_file_protocol_url(buf);
716
717     memcpy(buf, wszFile2, sizeof(wszFile2));
718     len = sizeof(wszFile2)/sizeof(WCHAR)-1;
719     len += GetCurrentDirectoryW(sizeof(buf)/sizeof(WCHAR)-len, buf+len);
720     buf[len++] = '\\';
721     memcpy(buf+len, wszIndexHtml, sizeof(wszIndexHtml));
722
723     file_name = buf + sizeof(wszFile2)/sizeof(WCHAR)-1;
724     bindf = 0;
725     test_file_protocol_url(buf);
726     bindf = BINDF_FROMURLMON;
727     test_file_protocol_url(buf);
728
729     memcpy(buf, wszFile3, sizeof(wszFile3));
730     len = sizeof(wszFile3)/sizeof(WCHAR)-1;
731     len += GetCurrentDirectoryW(sizeof(buf)/sizeof(WCHAR)-len, buf+len);
732     buf[len++] = '\\';
733     memcpy(buf+len, wszIndexHtml, sizeof(wszIndexHtml));
734
735     file_name = buf + sizeof(wszFile3)/sizeof(WCHAR)-1;
736     bindf = 0;
737     test_file_protocol_url(buf);
738     bindf = BINDF_FROMURLMON;
739     test_file_protocol_url(buf);
740
741     DeleteFileW(wszIndexHtml);
742
743     bindf = 0;
744     test_file_protocol_fail();
745     bindf = BINDF_FROMURLMON;
746     test_file_protocol_fail();
747 }
748
749 static BOOL http_protocol_start(LPCWSTR url, BOOL is_first)
750 {
751     HRESULT hres;
752
753     first_data_notif = TRUE;
754
755     SET_EXPECT(GetBindInfo);
756     SET_EXPECT(GetBindString_USER_AGENT);
757     SET_EXPECT(GetBindString_ACCEPT_MIMES);
758     SET_EXPECT(QueryService_HttpNegotiate);
759     SET_EXPECT(BeginningTransaction);
760     SET_EXPECT(GetRootSecurityId);
761
762     hres = IInternetProtocol_Start(http_protocol, url, &protocol_sink, &bind_info, 0, 0);
763     todo_wine {
764         ok(hres == S_OK, "Start failed: %08x\n", hres);
765     }
766     if(FAILED(hres))
767         return FALSE;
768
769     CHECK_CALLED(GetBindInfo);
770     CHECK_CALLED(GetBindString_USER_AGENT);
771     CHECK_CALLED(GetBindString_ACCEPT_MIMES);
772     CHECK_CALLED(QueryService_HttpNegotiate);
773     CHECK_CALLED(BeginningTransaction);
774     CHECK_CALLED(GetRootSecurityId);
775
776     return TRUE;
777 }
778
779 static void test_http_protocol_url(LPCWSTR url)
780 {
781     IInternetProtocolInfo *protocol_info;
782     IClassFactory *factory;
783     IUnknown *unk;
784     HRESULT hres;
785
786     http_url = url;
787
788     hres = CoGetClassObject(&CLSID_HttpProtocol, CLSCTX_INPROC_SERVER, NULL,
789             &IID_IUnknown, (void**)&unk);
790     ok(hres == S_OK, "CoGetClassObject failed: %08x\n", hres);
791     if(!SUCCEEDED(hres))
792         return;
793
794     hres = IUnknown_QueryInterface(unk, &IID_IInternetProtocolInfo, (void**)&protocol_info);
795     ok(hres == E_NOINTERFACE,
796         "Could not get IInternetProtocolInfo interface: %08x, expected E_NOINTERFACE\n",
797         hres);
798
799     hres = IUnknown_QueryInterface(unk, &IID_IClassFactory, (void**)&factory);
800     ok(hres == S_OK, "Could not get IClassFactory interface\n");
801     IUnknown_Release(unk);
802     if(FAILED(hres))
803         return;
804
805     hres = IClassFactory_CreateInstance(factory, NULL, &IID_IInternetProtocol,
806                                         (void**)&http_protocol);
807     ok(hres == S_OK, "Could not get IInternetProtocol: %08x\n", hres);
808     if(SUCCEEDED(hres)) {
809         BYTE buf[512];
810         DWORD cb;
811         MSG msg;
812
813         bindf = BINDF_ASYNCHRONOUS | BINDF_ASYNCSTORAGE | BINDF_PULLDATA | BINDF_FROMURLMON;
814
815         test_priority(http_protocol);
816
817         SET_EXPECT(ReportProgress_FINDINGRESOURCE);
818         SET_EXPECT(ReportProgress_CONNECTING);
819         SET_EXPECT(ReportProgress_SENDINGREQUEST);
820
821         if(!http_protocol_start(url, TRUE))
822             return;
823
824         hres = IInternetProtocol_Read(http_protocol, buf, 2, &cb);
825         ok(hres == E_PENDING, "Read failed: %08x, expected E_PENDING\n", hres);
826         ok(!cb, "cb=%d, expected 0\n", cb);
827
828         SET_EXPECT(Switch);
829         SET_EXPECT(ReportResult);
830         expect_hrResult = S_OK;
831
832         GetMessageW(&msg, NULL, 0, 0);
833
834         CHECK_CALLED(Switch);
835         CHECK_CALLED(ReportResult);
836
837         IInternetProtocol_Release(http_protocol);
838     }
839
840     IClassFactory_Release(factory);
841 }
842
843 static void test_http_protocol(void)
844 {
845     static const WCHAR winehq_url[] =
846         {'h','t','t','p',':','/','/','w','w','w','.','w','i','n','e','h','q','.',
847             'o','r','g','/','s','i','t','e','/','a','b','o','u','t',0};
848
849     tested_protocol = HTTP_TEST;
850     test_http_protocol_url(winehq_url);
851
852 }
853
854 static LRESULT WINAPI wnd_proc(HWND hwnd, UINT msg, WPARAM wParam, LPARAM lParam)
855 {
856     if(msg == WM_USER) {
857         HRESULT hres;
858         DWORD cb;
859         BYTE buf[3600];
860
861         SET_EXPECT(ReportData);
862         if(!state) {
863             CHECK_CALLED(ReportProgress_FINDINGRESOURCE);
864             CHECK_CALLED(ReportProgress_CONNECTING);
865             CHECK_CALLED(ReportProgress_SENDINGREQUEST);
866
867             SET_EXPECT(OnResponse);
868             SET_EXPECT(ReportProgress_MIMETYPEAVAILABLE);
869         }
870
871         hres = IInternetProtocol_Continue(http_protocol, (PROTOCOLDATA*)lParam);
872         ok(hres == S_OK, "Continue failed: %08x\n", hres);
873
874         CHECK_CALLED(ReportData);
875         if(!state) {
876             CHECK_CALLED(OnResponse);
877             CHECK_CALLED(ReportProgress_MIMETYPEAVAILABLE);
878         }
879
880         do hres = IInternetProtocol_Read(http_protocol, buf, sizeof(buf), &cb);
881         while(cb);
882
883         ok(hres == S_FALSE || hres == E_PENDING, "Read failed: %08x\n", hres);
884
885         if(hres == S_FALSE)
886             PostMessageW(protocol_hwnd, WM_USER+1, 0, 0);
887
888         if(!state) {
889             state = 1;
890
891             hres = IInternetProtocol_LockRequest(http_protocol, 0);
892             ok(hres == S_OK, "LockRequest failed: %08x\n", hres);
893
894             do hres = IInternetProtocol_Read(http_protocol, buf, sizeof(buf), &cb);
895             while(cb);
896             ok(hres == S_FALSE || hres == E_PENDING, "Read failed: %08x\n", hres);
897         }
898     }
899
900     return DefWindowProc(hwnd, msg, wParam, lParam);
901 }
902
903 static HWND create_protocol_window(void)
904 {
905     static const WCHAR wszProtocolWindow[] =
906         {'P','r','o','t','o','c','o','l','W','i','n','d','o','w',0};
907     static WNDCLASSEXW wndclass = {
908         sizeof(WNDCLASSEXW),
909         0,
910         wnd_proc,
911         0, 0, NULL, NULL, NULL, NULL, NULL,
912         wszProtocolWindow,
913         NULL
914     };
915
916     RegisterClassExW(&wndclass);
917     return CreateWindowW(wszProtocolWindow, wszProtocolWindow,
918                          WS_OVERLAPPEDWINDOW, CW_USEDEFAULT, CW_USEDEFAULT, CW_USEDEFAULT,
919                          CW_USEDEFAULT, NULL, NULL, NULL, NULL);
920 }
921
922 static void test_mk_protocol(void)
923 {
924     IInternetProtocolInfo *protocol_info;
925     IInternetProtocol *protocol;
926     IClassFactory *factory;
927     IUnknown *unk;
928     HRESULT hres;
929
930     static const WCHAR wrong_url1[] = {'t','e','s','t',':','@','M','S','I','T','S','t','o','r','e',
931                                        ':',':','/','t','e','s','t','.','h','t','m','l',0};
932     static const WCHAR wrong_url2[] = {'m','k',':','/','t','e','s','t','.','h','t','m','l',0};
933
934     tested_protocol = MK_TEST;
935
936     hres = CoGetClassObject(&CLSID_MkProtocol, CLSCTX_INPROC_SERVER, NULL,
937             &IID_IUnknown, (void**)&unk);
938     ok(hres == S_OK, "CoGetClassObject failed: %08x\n", hres);
939
940     hres = IUnknown_QueryInterface(unk, &IID_IInternetProtocolInfo, (void**)&protocol_info);
941     ok(hres == E_NOINTERFACE,
942         "Could not get IInternetProtocolInfo interface: %08x, expected E_NOINTERFACE\n",
943         hres);
944
945     hres = IUnknown_QueryInterface(unk, &IID_IClassFactory, (void**)&factory);
946     ok(hres == S_OK, "Could not get IClassFactory interface\n");
947     IUnknown_Release(unk);
948     if(FAILED(hres))
949         return;
950
951     hres = IClassFactory_CreateInstance(factory, NULL, &IID_IInternetProtocol,
952                                         (void**)&protocol);
953     IClassFactory_Release(factory);
954     ok(hres == S_OK, "Could not get IInternetProtocol: %08x\n", hres);
955
956     SET_EXPECT(GetBindInfo);
957     hres = IInternetProtocol_Start(protocol, wrong_url1, &protocol_sink, &bind_info, 0, 0);
958     ok(hres == MK_E_SYNTAX, "Start failed: %08x, expected MK_E_SYNTAX\n", hres);
959     CHECK_CALLED(GetBindInfo);
960
961     SET_EXPECT(GetBindInfo);
962     SET_EXPECT(ReportProgress_DIRECTBIND);
963     SET_EXPECT(ReportProgress_SENDINGREQUEST);
964     SET_EXPECT(ReportProgress_MIMETYPEAVAILABLE);
965     SET_EXPECT(ReportResult);
966     expect_hrResult = INET_E_RESOURCE_NOT_FOUND;
967
968     hres = IInternetProtocol_Start(protocol, wrong_url2, &protocol_sink, &bind_info, 0, 0);
969     ok(hres == INET_E_RESOURCE_NOT_FOUND, "Start failed: %08x, expected INET_E_RESOURCE_NOT_FOUND\n", hres);
970
971     CHECK_CALLED(GetBindInfo);
972     CHECK_CALLED(ReportProgress_DIRECTBIND);
973     CHECK_CALLED(ReportProgress_SENDINGREQUEST);
974     CHECK_CALLED(ReportProgress_MIMETYPEAVAILABLE);
975     CHECK_CALLED(ReportResult);
976
977     IInternetProtocol_Release(protocol);
978 }
979
/*
 * Entry point for the urlmon protocol tests.
 *
 * OLE is initialized for the whole run. The helper window (protocol_hwnd) is
 * created before any test because wnd_proc posts WM_USER+1 to it during the
 * http test.
 */
START_TEST(protocol)
{
    OleInitialize(NULL);

    protocol_hwnd = create_protocol_window();

    /* The tests share globals (e.g. tested_protocol, protocol_hwnd), so they
     * run sequentially in this fixed order. */
    test_file_protocol();
    test_http_protocol();
    test_mk_protocol();

    DestroyWindow(protocol_hwnd);

    OleUninitialize();
}
993 }