urlmon: Return correct error in get_protocol_handler for unknown protocol types.
[wine] / dlls / urlmon / session.c
1 /*
2  * Copyright 2005-2006 Jacek Caban for CodeWeavers
3  *
4  * This library is free software; you can redistribute it and/or
5  * modify it under the terms of the GNU Lesser General Public
6  * License as published by the Free Software Foundation; either
7  * version 2.1 of the License, or (at your option) any later version.
8  *
9  * This library is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
12  * Lesser General Public License for more details.
13  *
14  * You should have received a copy of the GNU Lesser General Public
15  * License along with this library; if not, write to the Free Software
16  * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
17  */
18
19 #include "urlmon_main.h"
20 #include "winreg.h"
21
22 #include "wine/debug.h"
23
24 WINE_DEFAULT_DEBUG_CHANNEL(urlmon);
25
26 typedef struct name_space {
27     LPWSTR protocol;
28     IClassFactory *cf;
29     CLSID clsid;
30     BOOL urlmon;
31
32     struct name_space *next;
33 } name_space;
34
35 typedef struct mime_filter {
36     IClassFactory *cf;
37     CLSID clsid;
38     LPWSTR mime;
39
40     struct mime_filter *next;
41 } mime_filter;
42
43 static name_space *name_space_list = NULL;
44 static mime_filter *mime_filter_list = NULL;
45
46 static CRITICAL_SECTION session_cs;
47 static CRITICAL_SECTION_DEBUG session_cs_dbg =
48 {
49     0, 0, &session_cs,
50     { &session_cs_dbg.ProcessLocksList, &session_cs_dbg.ProcessLocksList },
51       0, 0, { (DWORD_PTR)(__FILE__ ": session") }
52 };
53 static CRITICAL_SECTION session_cs = { &session_cs_dbg, -1, 0, 0, 0, 0 };
54
55 static const WCHAR internet_settings_keyW[] =
56     {'S','O','F','T','W','A','R','E',
57      '\\','M','i','c','r','o','s','o','f','t',
58      '\\','W','i','n','d','o','w','s',
59      '\\','C','u','r','r','e','n','t','V','e','r','s','i','o','n',
60      '\\','I','n','t','e','r','n','e','t',' ','S','e','t','t','i','n','g','s',0};
61
62 static name_space *find_name_space(LPCWSTR protocol)
63 {
64     name_space *iter;
65
66     for(iter = name_space_list; iter; iter = iter->next) {
67         if(!strcmpW(iter->protocol, protocol))
68             return iter;
69     }
70
71     return NULL;
72 }
73
74 static HRESULT get_protocol_cf(LPCWSTR schema, DWORD schema_len, CLSID *pclsid, IClassFactory **ret)
75 {
76     WCHAR str_clsid[64];
77     HKEY hkey = NULL;
78     DWORD res, type, size;
79     CLSID clsid;
80     LPWSTR wszKey;
81     HRESULT hres;
82
83     static const WCHAR wszProtocolsKey[] =
84         {'P','R','O','T','O','C','O','L','S','\\','H','a','n','d','l','e','r','\\'};
85     static const WCHAR wszCLSID[] = {'C','L','S','I','D',0};
86
87     wszKey = heap_alloc(sizeof(wszProtocolsKey)+(schema_len+1)*sizeof(WCHAR));
88     memcpy(wszKey, wszProtocolsKey, sizeof(wszProtocolsKey));
89     memcpy(wszKey + sizeof(wszProtocolsKey)/sizeof(WCHAR), schema, (schema_len+1)*sizeof(WCHAR));
90
91     res = RegOpenKeyW(HKEY_CLASSES_ROOT, wszKey, &hkey);
92     heap_free(wszKey);
93     if(res != ERROR_SUCCESS) {
94         TRACE("Could not open protocol handler key\n");
95         return MK_E_SYNTAX;
96     }
97     
98     size = sizeof(str_clsid);
99     res = RegQueryValueExW(hkey, wszCLSID, NULL, &type, (LPBYTE)str_clsid, &size);
100     RegCloseKey(hkey);
101     if(res != ERROR_SUCCESS || type != REG_SZ) {
102         WARN("Could not get protocol CLSID res=%d\n", res);
103         return MK_E_SYNTAX;
104     }
105
106     hres = CLSIDFromString(str_clsid, &clsid);
107     if(FAILED(hres)) {
108         WARN("CLSIDFromString failed: %08x\n", hres);
109         return hres;
110     }
111
112     if(pclsid)
113         *pclsid = clsid;
114
115     if(!ret)
116         return S_OK;
117
118     hres = CoGetClassObject(&clsid, CLSCTX_INPROC_SERVER, NULL, &IID_IClassFactory, (void**)ret);
119     return SUCCEEDED(hres) ? S_OK : MK_E_SYNTAX;
120 }
121
122 static HRESULT register_namespace(IClassFactory *cf, REFIID clsid, LPCWSTR protocol, BOOL urlmon_protocol)
123 {
124     name_space *new_name_space;
125
126     new_name_space = heap_alloc(sizeof(name_space));
127
128     if(!urlmon_protocol)
129         IClassFactory_AddRef(cf);
130     new_name_space->cf = cf;
131     new_name_space->clsid = *clsid;
132     new_name_space->urlmon = urlmon_protocol;
133     new_name_space->protocol = heap_strdupW(protocol);
134
135     EnterCriticalSection(&session_cs);
136
137     new_name_space->next = name_space_list;
138     name_space_list = new_name_space;
139
140     LeaveCriticalSection(&session_cs);
141
142     return S_OK;
143 }
144
145 static HRESULT unregister_namespace(IClassFactory *cf, LPCWSTR protocol)
146 {
147     name_space *iter, *last = NULL;
148
149     EnterCriticalSection(&session_cs);
150
151     for(iter = name_space_list; iter; iter = iter->next) {
152         if(iter->cf == cf && !strcmpW(iter->protocol, protocol))
153             break;
154         last = iter;
155     }
156
157     if(iter) {
158         if(last)
159             last->next = iter->next;
160         else
161             name_space_list = iter->next;
162     }
163
164     LeaveCriticalSection(&session_cs);
165
166     if(iter) {
167         if(!iter->urlmon)
168             IClassFactory_Release(iter->cf);
169         heap_free(iter->protocol);
170         heap_free(iter);
171     }
172
173     return S_OK;
174 }
175
176
177 void register_urlmon_namespace(IClassFactory *cf, REFIID clsid, LPCWSTR protocol, BOOL do_register)
178 {
179     if(do_register)
180         register_namespace(cf, clsid, protocol, TRUE);
181     else
182         unregister_namespace(cf, protocol);
183 }
184
185 BOOL is_registered_protocol(LPCWSTR url)
186 {
187     DWORD schema_len;
188     WCHAR schema[64];
189     HRESULT hres;
190
191     hres = CoInternetParseUrl(url, PARSE_SCHEMA, 0, schema, sizeof(schema)/sizeof(schema[0]),
192             &schema_len, 0);
193     if(FAILED(hres))
194         return FALSE;
195
196     return get_protocol_cf(schema, schema_len, NULL, NULL) == S_OK;
197 }
198
199 IInternetProtocolInfo *get_protocol_info(LPCWSTR url)
200 {
201     IInternetProtocolInfo *ret = NULL;
202     IClassFactory *cf;
203     name_space *ns;
204     WCHAR schema[64];
205     DWORD schema_len;
206     HRESULT hres;
207
208     hres = CoInternetParseUrl(url, PARSE_SCHEMA, 0, schema, sizeof(schema)/sizeof(schema[0]),
209             &schema_len, 0);
210     if(FAILED(hres) || !schema_len)
211         return NULL;
212
213     EnterCriticalSection(&session_cs);
214
215     ns = find_name_space(schema);
216     if(ns && !ns->urlmon) {
217         hres = IClassFactory_QueryInterface(ns->cf, &IID_IInternetProtocolInfo, (void**)&ret);
218         if(FAILED(hres))
219             hres = IClassFactory_CreateInstance(ns->cf, NULL, &IID_IInternetProtocolInfo, (void**)&ret);
220     }
221
222     LeaveCriticalSection(&session_cs);
223
224     if(ns && SUCCEEDED(hres))
225         return ret;
226
227     hres = get_protocol_cf(schema, schema_len, NULL, &cf);
228     if(FAILED(hres))
229         return NULL;
230
231     hres = IClassFactory_QueryInterface(cf, &IID_IInternetProtocolInfo, (void**)&ret);
232     if(FAILED(hres))
233         IClassFactory_CreateInstance(cf, NULL, &IID_IInternetProtocolInfo, (void**)&ret);
234     IClassFactory_Release(cf);
235
236     return ret;
237 }
238
239 HRESULT get_protocol_handler(LPCWSTR url, CLSID *clsid, BOOL *urlmon_protocol, IClassFactory **ret)
240 {
241     name_space *ns;
242     WCHAR schema[64];
243     DWORD schema_len;
244     HRESULT hres;
245
246     *ret = NULL;
247
248     hres = CoInternetParseUrl(url, PARSE_SCHEMA, 0, schema, sizeof(schema)/sizeof(schema[0]),
249             &schema_len, 0);
250     if(FAILED(hres) || !schema_len)
251         return schema_len ? hres : MK_E_SYNTAX;
252
253     EnterCriticalSection(&session_cs);
254
255     ns = find_name_space(schema);
256     if(ns) {
257         *ret = ns->cf;
258         IClassFactory_AddRef(*ret);
259         if(clsid)
260             *clsid = ns->clsid;
261         if(urlmon_protocol)
262             *urlmon_protocol = ns->urlmon;
263     }
264
265     LeaveCriticalSection(&session_cs);
266
267     if(*ret)
268         return S_OK;
269
270     if(urlmon_protocol)
271         *urlmon_protocol = FALSE;
272     return get_protocol_cf(schema, schema_len, clsid, ret);
273 }
274
275 IInternetProtocol *get_mime_filter(LPCWSTR mime)
276 {
277     IClassFactory *cf = NULL;
278     IInternetProtocol *ret;
279     mime_filter *iter;
280     HRESULT hres;
281
282     EnterCriticalSection(&session_cs);
283
284     for(iter = mime_filter_list; iter; iter = iter->next) {
285         if(!strcmpW(iter->mime, mime)) {
286             cf = iter->cf;
287             break;
288         }
289     }
290
291     LeaveCriticalSection(&session_cs);
292
293     if(!cf)
294         return NULL;
295
296     hres = IClassFactory_CreateInstance(cf, NULL, &IID_IInternetProtocol, (void**)&ret);
297     if(FAILED(hres)) {
298         WARN("CreateInstance failed: %08x\n", hres);
299         return NULL;
300     }
301
302     return ret;
303 }
304
305 static HRESULT WINAPI InternetSession_QueryInterface(IInternetSession *iface,
306         REFIID riid, void **ppv)
307 {
308     TRACE("(%s %p)\n", debugstr_guid(riid), ppv);
309
310     if(IsEqualGUID(&IID_IUnknown, riid) || IsEqualGUID(&IID_IInternetSession, riid)) {
311         *ppv = iface;
312         IInternetSession_AddRef(iface);
313         return S_OK;
314     }
315
316     *ppv = NULL;
317     return E_NOINTERFACE;
318 }
319
320 static ULONG WINAPI InternetSession_AddRef(IInternetSession *iface)
321 {
322     TRACE("()\n");
323     URLMON_LockModule();
324     return 2;
325 }
326
327 static ULONG WINAPI InternetSession_Release(IInternetSession *iface)
328 {
329     TRACE("()\n");
330     URLMON_UnlockModule();
331     return 1;
332 }
333
334 static HRESULT WINAPI InternetSession_RegisterNameSpace(IInternetSession *iface,
335         IClassFactory *pCF, REFCLSID rclsid, LPCWSTR pwzProtocol, ULONG cPatterns,
336         const LPCWSTR *ppwzPatterns, DWORD dwReserved)
337 {
338     TRACE("(%p %s %s %d %p %d)\n", pCF, debugstr_guid(rclsid), debugstr_w(pwzProtocol),
339           cPatterns, ppwzPatterns, dwReserved);
340
341     if(cPatterns || ppwzPatterns)
342         FIXME("patterns not supported\n");
343     if(dwReserved)
344         WARN("dwReserved = %d\n", dwReserved);
345
346     if(!pCF || !pwzProtocol)
347         return E_INVALIDARG;
348
349     return register_namespace(pCF, rclsid, pwzProtocol, FALSE);
350 }
351
352 static HRESULT WINAPI InternetSession_UnregisterNameSpace(IInternetSession *iface,
353         IClassFactory *pCF, LPCWSTR pszProtocol)
354 {
355     TRACE("(%p %s)\n", pCF, debugstr_w(pszProtocol));
356
357     if(!pCF || !pszProtocol)
358         return E_INVALIDARG;
359
360     return unregister_namespace(pCF, pszProtocol);
361 }
362
363 static HRESULT WINAPI InternetSession_RegisterMimeFilter(IInternetSession *iface,
364         IClassFactory *pCF, REFCLSID rclsid, LPCWSTR pwzType)
365 {
366     mime_filter *filter;
367
368     TRACE("(%p %s %s)\n", pCF, debugstr_guid(rclsid), debugstr_w(pwzType));
369
370     filter = heap_alloc(sizeof(mime_filter));
371
372     IClassFactory_AddRef(pCF);
373     filter->cf = pCF;
374     filter->clsid = *rclsid;
375     filter->mime = heap_strdupW(pwzType);
376
377     EnterCriticalSection(&session_cs);
378
379     filter->next = mime_filter_list;
380     mime_filter_list = filter;
381
382     LeaveCriticalSection(&session_cs);
383
384     return S_OK;
385 }
386
387 static HRESULT WINAPI InternetSession_UnregisterMimeFilter(IInternetSession *iface,
388         IClassFactory *pCF, LPCWSTR pwzType)
389 {
390     mime_filter *iter, *prev = NULL;
391
392     TRACE("(%p %s)\n", pCF, debugstr_w(pwzType));
393
394     EnterCriticalSection(&session_cs);
395
396     for(iter = mime_filter_list; iter; iter = iter->next) {
397         if(iter->cf == pCF && !strcmpW(iter->mime, pwzType))
398             break;
399         prev = iter;
400     }
401
402     if(iter) {
403         if(prev)
404             prev->next = iter->next;
405         else
406             mime_filter_list = iter->next;
407     }
408
409     LeaveCriticalSection(&session_cs);
410
411     if(iter) {
412         IClassFactory_Release(iter->cf);
413         heap_free(iter->mime);
414         heap_free(iter);
415     }
416
417     return S_OK;
418 }
419
420 static HRESULT WINAPI InternetSession_CreateBinding(IInternetSession *iface,
421         LPBC pBC, LPCWSTR szUrl, IUnknown *pUnkOuter, IUnknown **ppUnk,
422         IInternetProtocol **ppOInetProt, DWORD dwOption)
423 {
424     TRACE("(%p %s %p %p %p %08x)\n", pBC, debugstr_w(szUrl), pUnkOuter, ppUnk,
425             ppOInetProt, dwOption);
426
427     if(pBC || pUnkOuter || ppUnk || dwOption)
428         FIXME("Unsupported arguments\n");
429
430     return create_binding_protocol(szUrl, FALSE, ppOInetProt);
431 }
432
433 static HRESULT WINAPI InternetSession_SetSessionOption(IInternetSession *iface,
434         DWORD dwOption, LPVOID pBuffer, DWORD dwBufferLength, DWORD dwReserved)
435 {
436     FIXME("(%08x %p %d %d)\n", dwOption, pBuffer, dwBufferLength, dwReserved);
437     return E_NOTIMPL;
438 }
439
440 static const IInternetSessionVtbl InternetSessionVtbl = {
441     InternetSession_QueryInterface,
442     InternetSession_AddRef,
443     InternetSession_Release,
444     InternetSession_RegisterNameSpace,
445     InternetSession_UnregisterNameSpace,
446     InternetSession_RegisterMimeFilter,
447     InternetSession_UnregisterMimeFilter,
448     InternetSession_CreateBinding,
449     InternetSession_SetSessionOption
450 };
451
452 static IInternetSession InternetSession = { &InternetSessionVtbl };
453
454 /***********************************************************************
455  *           CoInternetGetSession (URLMON.@)
456  *
457  * Create a new internet session and return an IInternetSession interface
458  * representing it.
459  *
460  * PARAMS
461  *    dwSessionMode      [I] Mode for the internet session
462  *    ppIInternetSession [O] Destination for creates IInternetSession object
463  *    dwReserved         [I] Reserved, must be 0.
464  *
465  * RETURNS
466  *    Success: S_OK. ppIInternetSession contains the IInternetSession interface.
467  *    Failure: E_INVALIDARG, if any argument is invalid, or
468  *             E_OUTOFMEMORY if memory allocation fails.
469  */
470 HRESULT WINAPI CoInternetGetSession(DWORD dwSessionMode, IInternetSession **ppIInternetSession,
471         DWORD dwReserved)
472 {
473     TRACE("(%d %p %d)\n", dwSessionMode, ppIInternetSession, dwReserved);
474
475     if(dwSessionMode)
476         ERR("dwSessionMode=%d\n", dwSessionMode);
477     if(dwReserved)
478         ERR("dwReserved=%d\n", dwReserved);
479
480     IInternetSession_AddRef(&InternetSession);
481     *ppIInternetSession = &InternetSession;
482     return S_OK;
483 }
484
485 /**************************************************************************
486  *                 UrlMkGetSessionOption (URLMON.@)
487  */
488 static BOOL get_url_encoding(HKEY root, DWORD *encoding)
489 {
490     DWORD size = sizeof(DWORD), res, type;
491     HKEY hkey;
492
493     static const WCHAR wszUrlEncoding[] = {'U','r','l','E','n','c','o','d','i','n','g',0};
494
495     res = RegOpenKeyW(root, internet_settings_keyW, &hkey);
496     if(res != ERROR_SUCCESS)
497         return FALSE;
498
499     res = RegQueryValueExW(hkey, wszUrlEncoding, NULL, &type, (LPBYTE)encoding, &size);
500     RegCloseKey(hkey);
501
502     return res == ERROR_SUCCESS;
503 }
504
505 static LPWSTR user_agent;
506
507 static void ensure_useragent(void)
508 {
509     DWORD size = sizeof(DWORD), res, type;
510     HKEY hkey;
511
512     static const WCHAR user_agentW[] = {'U','s','e','r',' ','A','g','e','n','t',0};
513
514     if(user_agent)
515         return;
516
517     res = RegOpenKeyW(HKEY_CURRENT_USER, internet_settings_keyW, &hkey);
518     if(res != ERROR_SUCCESS)
519         return;
520
521     res = RegQueryValueExW(hkey, user_agentW, NULL, &type, NULL, &size);
522     if(res == ERROR_SUCCESS && type == REG_SZ) {
523         user_agent = heap_alloc(size);
524         res = RegQueryValueExW(hkey, user_agentW, NULL, &type, (LPBYTE)user_agent, &size);
525         if(res != ERROR_SUCCESS) {
526             heap_free(user_agent);
527             user_agent = NULL;
528         }
529     }else {
530         WARN("Could not find User Agent value: %u\n", res);
531     }
532
533     RegCloseKey(hkey);
534 }
535
536 LPWSTR get_useragent(void)
537 {
538     LPWSTR ret;
539
540     ensure_useragent();
541
542     EnterCriticalSection(&session_cs);
543     ret = heap_strdupW(user_agent);
544     LeaveCriticalSection(&session_cs);
545
546     return ret;
547 }
548
549 HRESULT WINAPI UrlMkGetSessionOption(DWORD dwOption, LPVOID pBuffer, DWORD dwBufferLength,
550                                      DWORD* pdwBufferLength, DWORD dwReserved)
551 {
552     TRACE("(%x, %p, %d, %p)\n", dwOption, pBuffer, dwBufferLength, pdwBufferLength);
553
554     if(dwReserved)
555         WARN("dwReserved = %d\n", dwReserved);
556
557     switch(dwOption) {
558     case URLMON_OPTION_USERAGENT: {
559         HRESULT hres = E_OUTOFMEMORY;
560         DWORD size;
561
562         if(!pdwBufferLength)
563             return E_INVALIDARG;
564
565         EnterCriticalSection(&session_cs);
566
567         ensure_useragent();
568         if(user_agent) {
569             size = WideCharToMultiByte(CP_ACP, 0, user_agent, -1, NULL, 0, NULL, NULL);
570             *pdwBufferLength = size;
571             if(size <= dwBufferLength) {
572                 if(pBuffer)
573                     WideCharToMultiByte(CP_ACP, 0, user_agent, -1, pBuffer, size, NULL, NULL);
574                 else
575                     hres = E_INVALIDARG;
576             }
577         }
578
579         LeaveCriticalSection(&session_cs);
580
581         /* Tests prove that we have to return E_OUTOFMEMORY on success. */
582         return hres;
583     }
584     case URLMON_OPTION_URL_ENCODING: {
585         DWORD encoding = 0;
586
587         if(!pBuffer || dwBufferLength < sizeof(DWORD) || !pdwBufferLength)
588             return E_INVALIDARG;
589
590         if(!get_url_encoding(HKEY_CURRENT_USER, &encoding))
591             get_url_encoding(HKEY_LOCAL_MACHINE, &encoding);
592
593         *pdwBufferLength = sizeof(DWORD);
594         *(DWORD*)pBuffer = encoding ? URL_ENCODING_DISABLE_UTF8 : URL_ENCODING_ENABLE_UTF8;
595         return S_OK;
596     }
597     default:
598         FIXME("unsupported option %x\n", dwOption);
599     }
600
601     return E_INVALIDARG;
602 }
603
604 /**************************************************************************
605  *                 UrlMkSetSessionOption (URLMON.@)
606  */
607 HRESULT WINAPI UrlMkSetSessionOption(DWORD dwOption, LPVOID pBuffer, DWORD dwBufferLength,
608         DWORD Reserved)
609 {
610     TRACE("(%x %p %x)\n", dwOption, pBuffer, dwBufferLength);
611
612     switch(dwOption) {
613     case URLMON_OPTION_USERAGENT: {
614         LPWSTR new_user_agent;
615         char *buf = pBuffer;
616         DWORD len, size;
617
618         if(!pBuffer || !dwBufferLength)
619             return E_INVALIDARG;
620
621         for(len=0; len<dwBufferLength && buf[len]; len++);
622
623         TRACE("Setting user agent %s\n", debugstr_an(buf, len));
624
625         size = MultiByteToWideChar(CP_ACP, 0, buf, len, NULL, 0);
626         new_user_agent = heap_alloc((size+1)*sizeof(WCHAR));
627         if(!new_user_agent)
628             return E_OUTOFMEMORY;
629         MultiByteToWideChar(CP_ACP, 0, buf, len, new_user_agent, size);
630         new_user_agent[size] = 0;
631
632         EnterCriticalSection(&session_cs);
633
634         heap_free(user_agent);
635         user_agent = new_user_agent;
636
637         LeaveCriticalSection(&session_cs);
638         break;
639     }
640     default:
641         FIXME("Unknown option %x\n", dwOption);
642         return E_INVALIDARG;
643     }
644
645     return S_OK;
646 }
647
648 /**************************************************************************
649  *                 ObtainUserAgentString (URLMON.@)
650  */
651 HRESULT WINAPI ObtainUserAgentString(DWORD dwOption, LPSTR pcszUAOut, DWORD *cbSize)
652 {
653     DWORD size;
654     HRESULT hres = E_FAIL;
655
656     TRACE("(%d %p %p)\n", dwOption, pcszUAOut, cbSize);
657
658     if(!pcszUAOut || !cbSize)
659         return E_INVALIDARG;
660
661     EnterCriticalSection(&session_cs);
662
663     ensure_useragent();
664     if(user_agent) {
665         size = WideCharToMultiByte(CP_ACP, 0, user_agent, -1, NULL, 0, NULL, NULL);
666
667         if(size <= *cbSize) {
668             WideCharToMultiByte(CP_ACP, 0, user_agent, -1, pcszUAOut, *cbSize, NULL, NULL);
669             hres = S_OK;
670         }else {
671             hres = E_OUTOFMEMORY;
672         }
673
674         *cbSize = size;
675     }
676
677     LeaveCriticalSection(&session_cs);
678     return hres;
679 }
680
681 void free_session(void)
682 {
683     heap_free(user_agent);
684 }