itss: Fix handling URLs without '/' in object name.
[wine] / dlls / itss / tests / protocol.c
1 /*
2  * Copyright 2006 Jacek Caban for CodeWeavers
3  *
4  * This library is free software; you can redistribute it and/or
5  * modify it under the terms of the GNU Lesser General Public
6  * License as published by the Free Software Foundation; either
7  * version 2.1 of the License, or (at your option) any later version.
8  *
9  * This library is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
12  * Lesser General Public License for more details.
13  *
14  * You should have received a copy of the GNU Lesser General Public
15  * License along with this library; if not, write to the Free Software
16  * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
17  */
18
19 #define COBJMACROS
20
21 #include <wine/test.h>
22 #include <stdarg.h>
23
24 #include "windef.h"
25 #include "winbase.h"
26 #include "ole2.h"
27 #include "urlmon.h"
28 #include "shlwapi.h"
29
30 #include "initguid.h"
31
32 #define DEFINE_EXPECT(func) \
33     static BOOL expect_ ## func = FALSE, called_ ## func = FALSE
34
35 #define SET_EXPECT(func) \
36     expect_ ## func = TRUE
37
38 #define CHECK_EXPECT(func) \
39     do { \
40         ok(expect_ ##func, "unexpected call " #func "\n"); \
41         expect_ ## func = FALSE; \
42         called_ ## func = TRUE; \
43     }while(0)
44
45 #define CHECK_EXPECT2(func) \
46     do { \
47         ok(expect_ ##func, "unexpected call " #func  "\n"); \
48         called_ ## func = TRUE; \
49     }while(0)
50
51 #define CHECK_CALLED(func) \
52     do { \
53         ok(called_ ## func, "expected " #func "\n"); \
54         expect_ ## func = called_ ## func = FALSE; \
55     }while(0)
56
57 DEFINE_GUID(CLSID_ITSProtocol,0x9d148291,0xb9c8,0x11d0,0xa4,0xcc,0x00,0x00,0xf8,0x01,0x49,0xf6);
58
59 DEFINE_EXPECT(GetBindInfo);
60 DEFINE_EXPECT(ReportProgress_BEGINDOWNLOADDATA);
61 DEFINE_EXPECT(ReportProgress_SENDINGREQUEST);
62 DEFINE_EXPECT(ReportProgress_MIMETYPEAVAILABLE);
63 DEFINE_EXPECT(ReportProgress_CACHEFILENAMEAVAIABLE);
64 DEFINE_EXPECT(ReportProgress_DIRECTBIND);
65 DEFINE_EXPECT(ReportData);
66 DEFINE_EXPECT(ReportResult);
67
68 static HRESULT expect_hrResult;
69 static IInternetProtocol *read_protocol = NULL;
70
71 static enum {
72     ITS_PROTOCOL,
73     MK_PROTOCOL
74 } test_protocol;
75
76 static HRESULT WINAPI ProtocolSink_QueryInterface(IInternetProtocolSink *iface, REFIID riid, void **ppv)
77 {
78     if(IsEqualGUID(&IID_IUnknown, riid) || IsEqualGUID(&IID_IInternetProtocolSink, riid)) {
79         *ppv = iface;
80         return S_OK;
81     }
82     return E_NOINTERFACE;
83 }
84
85 static ULONG WINAPI ProtocolSink_AddRef(IInternetProtocolSink *iface)
86 {
87     return 2;
88 }
89
90 static ULONG WINAPI ProtocolSink_Release(IInternetProtocolSink *iface)
91 {
92     return 1;
93 }
94
95 static HRESULT WINAPI ProtocolSink_Switch(IInternetProtocolSink *iface, PROTOCOLDATA *pProtocolData)
96 {
97     ok(0, "unexpected call\n");
98     return E_NOTIMPL;
99 }
100
101 static HRESULT WINAPI ProtocolSink_ReportProgress(IInternetProtocolSink *iface, ULONG ulStatusCode,
102         LPCWSTR szStatusText)
103 {
104     static const WCHAR blank_html[] = {'b','l','a','n','k','.','h','t','m','l',0};
105     static const WCHAR text_html[] = {'t','e','x','t','/','h','t','m','l',0};
106     static const WCHAR cache_file[] =
107         {'t','e','s','t','.','c','h','m',':',':','/','b','l','a','n','k','.','h','t','m','l',0};
108
109     switch(ulStatusCode) {
110     case BINDSTATUS_BEGINDOWNLOADDATA:
111         CHECK_EXPECT(ReportProgress_BEGINDOWNLOADDATA);
112         ok(!szStatusText, "szStatusText != NULL\n");
113         break;
114     case BINDSTATUS_SENDINGREQUEST:
115         CHECK_EXPECT(ReportProgress_SENDINGREQUEST);
116         if(test_protocol == ITS_PROTOCOL)
117             ok(!lstrcmpW(szStatusText, blank_html), "unexpected szStatusText\n");
118         else
119             ok(szStatusText == NULL, "szStatusText != NULL\n");
120         break;
121     case BINDSTATUS_MIMETYPEAVAILABLE:
122         CHECK_EXPECT(ReportProgress_MIMETYPEAVAILABLE);
123         ok(!lstrcmpW(szStatusText, text_html), "unexpected szStatusText\n");
124         break;
125     case BINDSTATUS_CACHEFILENAMEAVAILABLE:
126         CHECK_EXPECT(ReportProgress_CACHEFILENAMEAVAIABLE);
127         ok(!lstrcmpW(szStatusText, cache_file), "unexpected szStatusText\n");
128         break;
129     case BINDSTATUS_DIRECTBIND:
130         CHECK_EXPECT(ReportProgress_DIRECTBIND);
131         ok(!szStatusText, "szStatusText != NULL\n");
132         break;
133     default:
134         ok(0, "unexpected ulStatusCode %d\n", ulStatusCode);
135         break;
136     }
137
138     return S_OK;
139 }
140
141 static HRESULT WINAPI ProtocolSink_ReportData(IInternetProtocolSink *iface, DWORD grfBSCF, ULONG ulProgress,
142         ULONG ulProgressMax)
143 {
144     CHECK_EXPECT(ReportData);
145
146     ok(ulProgress == ulProgressMax, "ulProgress != ulProgressMax\n");
147     if(test_protocol == ITS_PROTOCOL)
148         ok(grfBSCF == (BSCF_FIRSTDATANOTIFICATION | BSCF_DATAFULLYAVAILABLE), "grcf = %08x\n", grfBSCF);
149     else
150         ok(grfBSCF == (BSCF_FIRSTDATANOTIFICATION | BSCF_LASTDATANOTIFICATION), "grcf = %08x\n", grfBSCF);
151
152     if(read_protocol) {
153         BYTE buf[100];
154         DWORD cb = 0xdeadbeef;
155         HRESULT hres;
156
157         hres = IInternetProtocol_Read(read_protocol, buf, sizeof(buf), &cb);
158         ok(hres == S_OK, "Read failed: %08x\n", hres);
159         ok(cb == 13, "cb=%u expected 13\n", cb);
160         ok(!memcmp(buf, "<html></html>", 13), "unexpected data\n");
161     }
162
163     return S_OK;
164 }
165
166 static HRESULT WINAPI ProtocolSink_ReportResult(IInternetProtocolSink *iface, HRESULT hrResult,
167         DWORD dwError, LPCWSTR szResult)
168 {
169     CHECK_EXPECT(ReportResult);
170
171     ok(hrResult == expect_hrResult, "expected: %08x got: %08x\n", expect_hrResult, hrResult);
172     ok(dwError == 0, "dwError = %d\n", dwError);
173     ok(!szResult, "szResult != NULL\n");
174
175     return S_OK;
176 }
177
178 static IInternetProtocolSinkVtbl protocol_sink_vtbl = {
179     ProtocolSink_QueryInterface,
180     ProtocolSink_AddRef,
181     ProtocolSink_Release,
182     ProtocolSink_Switch,
183     ProtocolSink_ReportProgress,
184     ProtocolSink_ReportData,
185     ProtocolSink_ReportResult
186 };
187
188 static IInternetProtocolSink protocol_sink = {
189     &protocol_sink_vtbl
190 };
191
192 static HRESULT WINAPI BindInfo_QueryInterface(IInternetBindInfo *iface, REFIID riid, void **ppv)
193 {
194     if(IsEqualGUID(&IID_IUnknown, riid) || IsEqualGUID(&IID_IInternetBindInfo, riid)) {
195         *ppv = iface;
196         return S_OK;
197     }
198     return E_NOINTERFACE;
199 }
200
201 static ULONG WINAPI BindInfo_AddRef(IInternetBindInfo *iface)
202 {
203     return 2;
204 }
205
206 static ULONG WINAPI BindInfo_Release(IInternetBindInfo *iface)
207 {
208     return 1;
209 }
210
211 static HRESULT WINAPI BindInfo_GetBindInfo(IInternetBindInfo *iface, DWORD *grfBINDF, BINDINFO *pbindinfo)
212 {
213     CHECK_EXPECT(GetBindInfo);
214
215     ok(grfBINDF != NULL, "grfBINDF == NULL\n");
216     if(grfBINDF)
217         ok(!*grfBINDF, "*grfBINDF != 0\n");
218     ok(pbindinfo != NULL, "pbindinfo == NULL\n");
219     ok(pbindinfo->cbSize == sizeof(BINDINFO), "wrong size of pbindinfo: %d\n", pbindinfo->cbSize);
220
221     return S_OK;
222 }
223
224 static HRESULT WINAPI BindInfo_GetBindString(IInternetBindInfo *iface, ULONG ulStringType, LPOLESTR *ppwzStr,
225         ULONG cEl, ULONG *pcElFetched)
226 {
227     ok(0, "unexpected call\n");
228     return E_NOTIMPL;
229 }
230
231 static IInternetBindInfoVtbl bind_info_vtbl = {
232     BindInfo_QueryInterface,
233     BindInfo_AddRef,
234     BindInfo_Release,
235     BindInfo_GetBindInfo,
236     BindInfo_GetBindString
237 };
238
239 static IInternetBindInfo bind_info = {
240     &bind_info_vtbl
241 };
242
243 static void test_protocol_fail(IInternetProtocol *protocol, LPCWSTR url, HRESULT expected_hres)
244 {
245     HRESULT hres;
246
247     SET_EXPECT(GetBindInfo);
248     SET_EXPECT(ReportResult);
249
250     expect_hrResult = expected_hres;
251     hres = IInternetProtocol_Start(protocol, url, &protocol_sink, &bind_info, 0, 0);
252     ok(hres == expected_hres, "expected: %08x got: %08x\n", expected_hres, hres);
253
254     CHECK_CALLED(GetBindInfo);
255     CHECK_CALLED(ReportResult);
256 }
257
258 static void protocol_start(IInternetProtocol *protocol, LPCWSTR url)
259 {
260     HRESULT hres;
261
262     SET_EXPECT(GetBindInfo);
263     if(test_protocol == MK_PROTOCOL)
264         SET_EXPECT(ReportProgress_DIRECTBIND);
265     SET_EXPECT(ReportProgress_SENDINGREQUEST);
266     SET_EXPECT(ReportProgress_MIMETYPEAVAILABLE);
267     if(test_protocol == MK_PROTOCOL)
268         SET_EXPECT(ReportProgress_CACHEFILENAMEAVAIABLE);
269     SET_EXPECT(ReportData);
270     if(test_protocol == ITS_PROTOCOL)
271         SET_EXPECT(ReportProgress_BEGINDOWNLOADDATA);
272     SET_EXPECT(ReportResult);
273     expect_hrResult = S_OK;
274
275     hres = IInternetProtocol_Start(protocol, url, &protocol_sink, &bind_info, 0, 0);
276     ok(hres == S_OK, "Start failed: %08x\n", hres);
277
278     CHECK_CALLED(GetBindInfo);
279     if(test_protocol == MK_PROTOCOL)
280         CHECK_CALLED(ReportProgress_DIRECTBIND);
281     CHECK_CALLED(ReportProgress_SENDINGREQUEST);
282     CHECK_CALLED(ReportProgress_MIMETYPEAVAILABLE);
283     if(test_protocol == MK_PROTOCOL)
284         SET_EXPECT(ReportProgress_CACHEFILENAMEAVAIABLE);
285     CHECK_CALLED(ReportData);
286     if(test_protocol == ITS_PROTOCOL)
287         CHECK_CALLED(ReportProgress_BEGINDOWNLOADDATA);
288     CHECK_CALLED(ReportResult);
289 }
290
291 static void test_protocol_url(IClassFactory *factory, LPCWSTR url)
292 {
293     IInternetProtocol *protocol;
294     BYTE buf[512];
295     ULONG cb, ref;
296     HRESULT hres;
297
298     hres = IClassFactory_CreateInstance(factory, NULL, &IID_IInternetProtocol, (void**)&protocol);
299     ok(hres == S_OK, "Could not get IInternetProtocol: %08x\n", hres);
300     if(FAILED(hres))
301         return;
302
303     protocol_start(protocol, url);
304     hres = IInternetProtocol_Read(protocol, buf, sizeof(buf), &cb);
305     ok(hres == S_OK, "Read failed: %08x\n", hres);
306     ok(cb == 13, "cb=%u expected 13\n", cb);
307     ok(!memcmp(buf, "<html></html>", 13), "unexpected data\n");
308     ref = IInternetProtocol_Release(protocol);
309     ok(!ref, "protocol ref=%d\n", ref);
310
311     hres = IClassFactory_CreateInstance(factory, NULL, &IID_IInternetProtocol, (void**)&protocol);
312     ok(hres == S_OK, "Could not get IInternetProtocol: %08x\n", hres);
313     if(FAILED(hres))
314         return;
315
316     cb = 0xdeadbeef;
317     hres = IInternetProtocol_Read(protocol, buf, sizeof(buf), &cb);
318     ok(hres == (test_protocol == ITS_PROTOCOL ? INET_E_DATA_NOT_AVAILABLE : E_FAIL),
319        "Read returned %08x\n", hres);
320     ok(cb == 0xdeadbeef, "cb=%u expected 0xdeadbeef\n", cb);
321
322     protocol_start(protocol, url);
323     hres = IInternetProtocol_Read(protocol, buf, 2, &cb);
324     ok(hres == S_OK, "Read failed: %08x\n", hres);
325     ok(cb == 2, "cb=%u expected 2\n", cb);
326     hres = IInternetProtocol_Read(protocol, buf, sizeof(buf), &cb);
327     ok(hres == S_OK, "Read failed: %08x\n", hres);
328     ok(cb == 11, "cb=%u, expected 11\n", cb);
329     hres = IInternetProtocol_Read(protocol, buf, sizeof(buf), &cb);
330     ok(hres == S_FALSE, "Read failed: %08x expected S_FALSE\n", hres);
331     ok(cb == 0, "cb=%u expected 0\n", cb);
332     hres = IInternetProtocol_UnlockRequest(protocol);
333     ok(hres == S_OK, "UnlockRequest failed: %08x\n", hres);
334     ref = IInternetProtocol_Release(protocol);
335     ok(!ref, "protocol ref=%d\n", ref);
336
337     hres = IClassFactory_CreateInstance(factory, NULL, &IID_IInternetProtocol, (void**)&protocol);
338     ok(hres == S_OK, "Could not get IInternetProtocol: %08x\n", hres);
339     if(FAILED(hres))
340         return;
341
342     protocol_start(protocol, url);
343     hres = IInternetProtocol_Read(protocol, buf, 2, &cb);
344     ok(hres == S_OK, "Read failed: %08x\n", hres);
345     hres = IInternetProtocol_LockRequest(protocol, 0);
346     ok(hres == S_OK, "LockRequest failed: %08x\n", hres);
347     hres = IInternetProtocol_UnlockRequest(protocol);
348     ok(hres == S_OK, "UnlockRequest failed: %08x\n", hres);
349     hres = IInternetProtocol_Read(protocol, buf, sizeof(buf), &cb);
350     ok(hres == S_OK, "Read failed: %08x\n", hres);
351     ok(cb == 11, "cb=%u, expected 11\n", cb);
352     ref = IInternetProtocol_Release(protocol);
353     ok(!ref, "protocol ref=%d\n", ref);
354
355     hres = IClassFactory_CreateInstance(factory, NULL, &IID_IInternetProtocol, (void**)&protocol);
356     ok(hres == S_OK, "Could not get IInternetProtocol: %08x\n", hres);
357     if(FAILED(hres))
358         return;
359
360     protocol_start(protocol, url);
361     hres = IInternetProtocol_LockRequest(protocol, 0);
362     ok(hres == S_OK, "LockRequest failed: %08x\n", hres);
363     hres = IInternetProtocol_Terminate(protocol, 0);
364     ok(hres == S_OK, "Terminate failed: %08x\n", hres);
365     hres = IInternetProtocol_Read(protocol, buf, 2, &cb);
366     ok(hres == S_OK, "Read failed: %08x\n", hres);
367     ok(cb == 2, "cb=%u, expected 2\n", cb);
368     hres = IInternetProtocol_UnlockRequest(protocol);
369     ok(hres == S_OK, "UnlockRequest failed: %08x\n", hres);
370     hres = IInternetProtocol_Read(protocol, buf, 2, &cb);
371     ok(hres == S_OK, "Read failed: %08x\n", hres);
372     ok(cb == 2, "cb=%u, expected 2\n", cb);
373     hres = IInternetProtocol_Terminate(protocol, 0);
374     ok(hres == S_OK, "Terminate failed: %08x\n", hres);
375     hres = IInternetProtocol_Read(protocol, buf, 2, &cb);
376     ok(hres == S_OK, "Read failed: %08x\n", hres);
377     ok(cb == 2, "cb=%u expected 2\n", cb);
378     ref = IInternetProtocol_Release(protocol);
379     ok(!ref, "protocol ref=%d\n", ref);
380
381     hres = IClassFactory_CreateInstance(factory, NULL, &IID_IInternetProtocol, (void**)&read_protocol);
382     ok(hres == S_OK, "Could not get IInternetProtocol: %08x\n", hres);
383     if(FAILED(hres))
384         return;
385
386     protocol_start(read_protocol, url);
387     ref = IInternetProtocol_Release(read_protocol);
388     ok(!ref, "protocol ref=%d\n", ref);
389     read_protocol = NULL;
390 }
391
392 static void test_its_protocol(void)
393 {
394     IClassFactory *factory;
395     IUnknown *unk;
396     ULONG ref;
397     HRESULT hres;
398
399     static const WCHAR blank_url1[] = {'i','t','s',':',
400         't','e','s','t','.','c','h','m',':',':','/','b','l','a','n','k','.','h','t','m','l',0};
401     static const WCHAR blank_url2[] = {'m','S','-','i','T','s',':',
402          't','e','s','t','.','c','h','m',':',':','/','b','l','a','n','k','.','h','t','m','l',0};
403     static const WCHAR blank_url3[] = {'m','k',':','@','M','S','I','T','S','t','o','r','e',':',
404          't','e','s','t','.','c','h','m',':',':','/','b','l','a','n','k','.','h','t','m','l',0};
405     static const WCHAR blank_url4[] = {'i','t','s',':',
406         't','e','s','t','.','c','h','m',':',':','b','l','a','n','k','.','h','t','m','l',0};
407     static const WCHAR wrong_url1[] =
408         {'i','t','s',':','t','e','s','t','.','c','h','m',':',':','/','b','l','a','n','.','h','t','m','l',0};
409     static const WCHAR wrong_url2[] =
410         {'i','t','s',':','t','e','s','.','c','h','m',':',':','b','/','l','a','n','k','.','h','t','m','l',0};
411     static const WCHAR wrong_url3[] =
412         {'i','t','s',':','t','e','s','t','.','c','h','m','/','b','l','a','n','k','.','h','t','m','l',0};
413     static const WCHAR wrong_url4[] = {'m','k',':','@','M','S','I','T','S','t','o','r',':',
414          't','e','s','t','.','c','h','m',':',':','/','b','l','a','n','k','.','h','t','m','l',0};
415     static const WCHAR wrong_url5[] = {'f','i','l','e',':',
416         't','e','s','.','c','h','m',':',':','/','b','l','a','n','k','.','h','t','m','l',0};
417
418     test_protocol = ITS_PROTOCOL;
419
420     hres = CoGetClassObject(&CLSID_ITSProtocol, CLSCTX_INPROC_SERVER, NULL, &IID_IUnknown, (void**)&unk);
421     ok(hres == S_OK, "CoGetClassObject failed: %08x\n", hres);
422     if(!SUCCEEDED(hres))
423         return;
424
425     hres = IUnknown_QueryInterface(unk, &IID_IClassFactory, (void**)&factory);
426     ok(hres == S_OK, "Could not get IClassFactory interface\n");
427     if(SUCCEEDED(hres)) {
428         IInternetProtocol *protocol;
429
430         hres = IClassFactory_CreateInstance(factory, NULL, &IID_IInternetProtocol, (void**)&protocol);
431         ok(hres == S_OK, "Could not get IInternetProtocol: %08x\n", hres);
432         if(SUCCEEDED(hres)) {
433             test_protocol_fail(protocol, wrong_url1, STG_E_FILENOTFOUND);
434             test_protocol_fail(protocol, wrong_url2, STG_E_FILENOTFOUND);
435             test_protocol_fail(protocol, wrong_url3, STG_E_FILENOTFOUND);
436
437             hres = IInternetProtocol_Start(protocol, wrong_url4, &protocol_sink, &bind_info, 0, 0);
438             ok(hres == INET_E_USE_DEFAULT_PROTOCOLHANDLER,
439                "Start failed: %08x, expected INET_E_USE_DEFAULT_PROTOCOLHANDLER\n", hres);
440
441             hres = IInternetProtocol_Start(protocol, wrong_url5, &protocol_sink, &bind_info, 0, 0);
442             ok(hres == INET_E_USE_DEFAULT_PROTOCOLHANDLER,
443                "Start failed: %08x, expected INET_E_USE_DEFAULT_PROTOCOLHANDLER\n", hres);
444
445             ref = IInternetProtocol_Release(protocol);
446             ok(!ref, "protocol ref=%d\n", ref);
447
448             test_protocol_url(factory, blank_url1);
449             test_protocol_url(factory, blank_url2);
450             test_protocol_url(factory, blank_url3);
451             test_protocol_url(factory, blank_url4);
452         }
453
454         IClassFactory_Release(factory);
455     }
456
457     IUnknown_Release(unk);
458 }
459
460 static void test_mk_protocol(void)
461 {
462     IClassFactory *cf;
463     HRESULT hres;
464
465     static const WCHAR blank_url[] = {'m','k',':','@','M','S','I','T','S','t','o','r','e',':',
466          't','e','s','t','.','c','h','m',':',':','/','b','l','a','n','k','.','h','t','m','l',0};
467
468     test_protocol = MK_PROTOCOL;
469
470     hres = CoGetClassObject(&CLSID_MkProtocol, CLSCTX_INPROC_SERVER, NULL, &IID_IClassFactory,
471                             (void**)&cf);
472     ok(hres == S_OK, "CoGetClassObject failed: %08x\n", hres);
473     if(!SUCCEEDED(hres))
474         return;
475
476     test_protocol_url(cf, blank_url);
477
478     IClassFactory_Release(cf);
479 }
480
481 static BOOL create_chm(void)
482 {
483     HANDLE file;
484     HRSRC src;
485     DWORD size;
486
487     file = CreateFileA("test.chm", GENERIC_WRITE, 0, NULL, CREATE_ALWAYS,
488             FILE_ATTRIBUTE_NORMAL, NULL);
489     ok(file != INVALID_HANDLE_VALUE, "Could not create test.chm file\n");
490     if(file == INVALID_HANDLE_VALUE)
491         return FALSE;
492
493     src = FindResourceA(NULL, MAKEINTRESOURCEA(60), MAKEINTRESOURCEA(60));
494
495     WriteFile(file, LoadResource(NULL, src), SizeofResource(NULL, src), &size, NULL);
496     CloseHandle(file);
497
498     return TRUE;
499 }
500
501 static void delete_chm(void)
502 {
503     BOOL ret;
504
505     ret = DeleteFileA("test.chm");
506     ok(ret, "DeleteFileA failed: %d\n", GetLastError());
507 }
508
509 START_TEST(protocol)
510 {
511     OleInitialize(NULL);
512
513     if(!create_chm())
514         return;
515
516     test_its_protocol();
517     test_mk_protocol();
518
519     delete_chm();
520     OleUninitialize();
521 }