msxml3: Embed user/password in uri used to create a moniker.
[wine] / dlls / urlmon / protocol.c
1 /*
2  * Copyright 2007 Misha Koshelev
3  * Copyright 2009 Jacek Caban for CodeWeavers
4  *
5  * This library is free software; you can redistribute it and/or
6  * modify it under the terms of the GNU Lesser General Public
7  * License as published by the Free Software Foundation; either
8  * version 2.1 of the License, or (at your option) any later version.
9  *
10  * This library is distributed in the hope that it will be useful,
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
13  * Lesser General Public License for more details.
14  *
15  * You should have received a copy of the GNU Lesser General Public
16  * License along with this library; if not, write to the Free Software
17  * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
18  */
19
20 #include "urlmon_main.h"
21
22 #include "wine/debug.h"
23
24 WINE_DEFAULT_DEBUG_CHANNEL(urlmon);
25
26 static inline HRESULT report_progress(Protocol *protocol, ULONG status_code, LPCWSTR status_text)
27 {
28     return IInternetProtocolSink_ReportProgress(protocol->protocol_sink, status_code, status_text);
29 }
30
31 static inline HRESULT report_result(Protocol *protocol, HRESULT hres)
32 {
33     if (!(protocol->flags & FLAG_RESULT_REPORTED) && protocol->protocol_sink) {
34         protocol->flags |= FLAG_RESULT_REPORTED;
35         IInternetProtocolSink_ReportResult(protocol->protocol_sink, hres, 0, NULL);
36     }
37
38     return hres;
39 }
40
41 static void report_data(Protocol *protocol)
42 {
43     DWORD bscf;
44
45     if((protocol->flags & FLAG_LAST_DATA_REPORTED) || !protocol->protocol_sink)
46         return;
47
48     if(protocol->flags & FLAG_FIRST_DATA_REPORTED) {
49         bscf = BSCF_INTERMEDIATEDATANOTIFICATION;
50     }else {
51         protocol->flags |= FLAG_FIRST_DATA_REPORTED;
52         bscf = BSCF_FIRSTDATANOTIFICATION;
53     }
54
55     if(protocol->flags & FLAG_ALL_DATA_READ && !(protocol->flags & FLAG_LAST_DATA_REPORTED)) {
56         protocol->flags |= FLAG_LAST_DATA_REPORTED;
57         bscf |= BSCF_LASTDATANOTIFICATION;
58     }
59
60     IInternetProtocolSink_ReportData(protocol->protocol_sink, bscf,
61             protocol->current_position+protocol->available_bytes,
62             protocol->content_length);
63 }
64
65 static void all_data_read(Protocol *protocol)
66 {
67     protocol->flags |= FLAG_ALL_DATA_READ;
68
69     report_data(protocol);
70     report_result(protocol, S_OK);
71 }
72
73 static HRESULT start_downloading(Protocol *protocol)
74 {
75     HRESULT hres;
76
77     hres = protocol->vtbl->start_downloading(protocol);
78     if(FAILED(hres)) {
79         protocol_close_connection(protocol);
80         report_result(protocol, hres);
81         return hres;
82     }
83
84     if(protocol->bindf & BINDF_NEEDFILE) {
85         WCHAR cache_file[MAX_PATH];
86         DWORD buflen = sizeof(cache_file);
87
88         if(InternetQueryOptionW(protocol->request, INTERNET_OPTION_DATAFILE_NAME, cache_file, &buflen)) {
89             report_progress(protocol, BINDSTATUS_CACHEFILENAMEAVAILABLE, cache_file);
90         }else {
91             FIXME("Could not get cache file\n");
92         }
93     }
94
95     protocol->flags |= FLAG_FIRST_CONTINUE_COMPLETE;
96     return S_OK;
97 }
98
99 HRESULT protocol_syncbinding(Protocol *protocol)
100 {
101     BOOL res;
102     HRESULT hres;
103
104     protocol->flags |= FLAG_SYNC_READ;
105
106     hres = start_downloading(protocol);
107     if(FAILED(hres))
108         return hres;
109
110     res = InternetQueryDataAvailable(protocol->request, &protocol->query_available, 0, 0);
111     if(res)
112         protocol->available_bytes = protocol->query_available;
113     else
114         WARN("InternetQueryDataAvailable failed: %u\n", GetLastError());
115
116     protocol->flags |= FLAG_FIRST_DATA_REPORTED|FLAG_LAST_DATA_REPORTED;
117     IInternetProtocolSink_ReportData(protocol->protocol_sink, BSCF_LASTDATANOTIFICATION|BSCF_DATAFULLYAVAILABLE,
118             protocol->available_bytes, protocol->content_length);
119     return S_OK;
120 }
121
122 static void request_complete(Protocol *protocol, INTERNET_ASYNC_RESULT *ar)
123 {
124     PROTOCOLDATA data;
125
126     TRACE("(%p)->(%p)\n", protocol, ar);
127
128     /* PROTOCOLDATA same as native */
129     memset(&data, 0, sizeof(data));
130     data.dwState = 0xf1000000;
131
132     if(ar->dwResult) {
133         protocol->flags |= FLAG_REQUEST_COMPLETE;
134
135         if(!protocol->request) {
136             TRACE("setting request handle %p\n", (HINTERNET)ar->dwResult);
137             protocol->request = (HINTERNET)ar->dwResult;
138         }
139
140         if(protocol->flags & FLAG_FIRST_CONTINUE_COMPLETE)
141             data.pData = UlongToPtr(BINDSTATUS_ENDDOWNLOADCOMPONENTS);
142         else
143             data.pData = UlongToPtr(BINDSTATUS_DOWNLOADINGDATA);
144
145     }else {
146         protocol->flags |= FLAG_ERROR;
147         data.pData = UlongToPtr(ar->dwError);
148     }
149
150     if (protocol->bindf & BINDF_FROMURLMON)
151         IInternetProtocolSink_Switch(protocol->protocol_sink, &data);
152     else
153         protocol_continue(protocol, &data);
154 }
155
156 static void WINAPI internet_status_callback(HINTERNET internet, DWORD_PTR context,
157         DWORD internet_status, LPVOID status_info, DWORD status_info_len)
158 {
159     Protocol *protocol = (Protocol*)context;
160
161     switch(internet_status) {
162     case INTERNET_STATUS_RESOLVING_NAME:
163         TRACE("%p INTERNET_STATUS_RESOLVING_NAME\n", protocol);
164         report_progress(protocol, BINDSTATUS_FINDINGRESOURCE, (LPWSTR)status_info);
165         break;
166
167     case INTERNET_STATUS_CONNECTING_TO_SERVER: {
168         WCHAR *info;
169
170         TRACE("%p INTERNET_STATUS_CONNECTING_TO_SERVER %s\n", protocol, (const char*)status_info);
171
172         info = heap_strdupAtoW(status_info);
173         if(!info)
174             return;
175
176         report_progress(protocol, BINDSTATUS_CONNECTING, info);
177         heap_free(info);
178         break;
179     }
180
181     case INTERNET_STATUS_SENDING_REQUEST:
182         TRACE("%p INTERNET_STATUS_SENDING_REQUEST\n", protocol);
183         report_progress(protocol, BINDSTATUS_SENDINGREQUEST, (LPWSTR)status_info);
184         break;
185
186     case INTERNET_STATUS_REDIRECT:
187         TRACE("%p INTERNET_STATUS_REDIRECT\n", protocol);
188         report_progress(protocol, BINDSTATUS_REDIRECTING, (LPWSTR)status_info);
189         break;
190
191     case INTERNET_STATUS_REQUEST_COMPLETE:
192         request_complete(protocol, status_info);
193         break;
194
195     case INTERNET_STATUS_HANDLE_CREATED:
196         TRACE("%p INTERNET_STATUS_HANDLE_CREATED\n", protocol);
197         IInternetProtocol_AddRef(protocol->protocol);
198         break;
199
200     case INTERNET_STATUS_HANDLE_CLOSING:
201         TRACE("%p INTERNET_STATUS_HANDLE_CLOSING\n", protocol);
202
203         if(*(HINTERNET *)status_info == protocol->request) {
204             protocol->request = NULL;
205             if(protocol->protocol_sink) {
206                 IInternetProtocolSink_Release(protocol->protocol_sink);
207                 protocol->protocol_sink = NULL;
208             }
209
210             if(protocol->bind_info.cbSize) {
211                 ReleaseBindInfo(&protocol->bind_info);
212                 memset(&protocol->bind_info, 0, sizeof(protocol->bind_info));
213             }
214         }else if(*(HINTERNET *)status_info == protocol->connection) {
215             protocol->connection = NULL;
216         }
217
218         IInternetProtocol_Release(protocol->protocol);
219         break;
220
221     default:
222         WARN("Unhandled Internet status callback %d\n", internet_status);
223     }
224 }
225
226 static HRESULT write_post_stream(Protocol *protocol)
227 {
228     BYTE buf[0x20000];
229     DWORD written;
230     ULONG size;
231     BOOL res;
232     HRESULT hres;
233
234     protocol->flags &= ~FLAG_REQUEST_COMPLETE;
235
236     while(1) {
237         size = 0;
238         hres = IStream_Read(protocol->post_stream, buf, sizeof(buf), &size);
239         if(FAILED(hres) || !size)
240             break;
241         res = InternetWriteFile(protocol->request, buf, size, &written);
242         if(!res) {
243             FIXME("InternetWriteFile failed: %u\n", GetLastError());
244             hres = E_FAIL;
245             break;
246         }
247     }
248
249     if(SUCCEEDED(hres)) {
250         IStream_Release(protocol->post_stream);
251         protocol->post_stream = NULL;
252
253         hres = protocol->vtbl->end_request(protocol);
254     }
255
256     if(FAILED(hres))
257         return report_result(protocol, hres);
258
259     return S_OK;
260 }
261
262 static HINTERNET create_internet_session(IInternetBindInfo *bind_info)
263 {
264     LPWSTR global_user_agent = NULL;
265     LPOLESTR user_agent = NULL;
266     ULONG size = 0;
267     HINTERNET ret;
268     HRESULT hres;
269
270     hres = IInternetBindInfo_GetBindString(bind_info, BINDSTRING_USER_AGENT, &user_agent, 1, &size);
271     if(hres != S_OK || !size)
272         global_user_agent = get_useragent();
273
274     ret = InternetOpenW(user_agent ? user_agent : global_user_agent, 0, NULL, NULL, INTERNET_FLAG_ASYNC);
275     heap_free(global_user_agent);
276     CoTaskMemFree(user_agent);
277     if(!ret) {
278         WARN("InternetOpen failed: %d\n", GetLastError());
279         return NULL;
280     }
281
282     InternetSetStatusCallbackW(ret, internet_status_callback);
283     return ret;
284 }
285
286 static HINTERNET internet_session;
287
288 HINTERNET get_internet_session(IInternetBindInfo *bind_info)
289 {
290     HINTERNET new_session;
291
292     if(internet_session)
293         return internet_session;
294
295     if(!bind_info)
296         return NULL;
297
298     new_session = create_internet_session(bind_info);
299     if(new_session && InterlockedCompareExchangePointer((void**)&internet_session, new_session, NULL))
300         InternetCloseHandle(new_session);
301
302     return internet_session;
303 }
304
305 HRESULT protocol_start(Protocol *protocol, IInternetProtocol *prot, IUri *uri,
306         IInternetProtocolSink *protocol_sink, IInternetBindInfo *bind_info)
307 {
308     DWORD request_flags;
309     HRESULT hres;
310
311     protocol->protocol = prot;
312
313     IInternetProtocolSink_AddRef(protocol_sink);
314     protocol->protocol_sink = protocol_sink;
315
316     memset(&protocol->bind_info, 0, sizeof(protocol->bind_info));
317     protocol->bind_info.cbSize = sizeof(BINDINFO);
318     hres = IInternetBindInfo_GetBindInfo(bind_info, &protocol->bindf, &protocol->bind_info);
319     if(hres != S_OK) {
320         WARN("GetBindInfo failed: %08x\n", hres);
321         return report_result(protocol, hres);
322     }
323
324     if(!(protocol->bindf & BINDF_FROMURLMON))
325         report_progress(protocol, BINDSTATUS_DIRECTBIND, NULL);
326
327     if(!get_internet_session(bind_info))
328         return report_result(protocol, INET_E_NO_SESSION);
329
330     request_flags = INTERNET_FLAG_KEEP_CONNECTION;
331     if(protocol->bindf & BINDF_NOWRITECACHE)
332         request_flags |= INTERNET_FLAG_NO_CACHE_WRITE;
333     if(protocol->bindf & BINDF_NEEDFILE)
334         request_flags |= INTERNET_FLAG_NEED_FILE;
335
336     hres = protocol->vtbl->open_request(protocol, uri, request_flags, internet_session, bind_info);
337     if(FAILED(hres)) {
338         protocol_close_connection(protocol);
339         return report_result(protocol, hres);
340     }
341
342     return S_OK;
343 }
344
345 HRESULT protocol_continue(Protocol *protocol, PROTOCOLDATA *data)
346 {
347     BOOL is_start;
348     HRESULT hres;
349
350     is_start = !data || data->pData == UlongToPtr(BINDSTATUS_DOWNLOADINGDATA);
351
352     if(!protocol->request) {
353         WARN("Expected request to be non-NULL\n");
354         return S_OK;
355     }
356
357     if(!protocol->protocol_sink) {
358         WARN("Expected IInternetProtocolSink pointer to be non-NULL\n");
359         return S_OK;
360     }
361
362     if(protocol->flags & FLAG_ERROR) {
363         protocol->flags &= ~FLAG_ERROR;
364         protocol->vtbl->on_error(protocol, PtrToUlong(data->pData));
365         return S_OK;
366     }
367
368     if(protocol->post_stream)
369         return write_post_stream(protocol);
370
371     if(is_start) {
372         hres = start_downloading(protocol);
373         if(FAILED(hres))
374             return S_OK;
375     }
376
377     if(!data || data->pData >= UlongToPtr(BINDSTATUS_DOWNLOADINGDATA)) {
378         if(!protocol->available_bytes) {
379             if(protocol->query_available) {
380                 protocol->available_bytes = protocol->query_available;
381             }else {
382                 BOOL res;
383
384                 /* InternetQueryDataAvailable may immediately fork and perform its asynchronous
385                  * read, so clear the flag _before_ calling so it does not incorrectly get cleared
386                  * after the status callback is called */
387                 protocol->flags &= ~FLAG_REQUEST_COMPLETE;
388                 res = InternetQueryDataAvailable(protocol->request, &protocol->query_available, 0, 0);
389                 if(res) {
390                     TRACE("available %u bytes\n", protocol->query_available);
391                     if(!protocol->query_available) {
392                         if(is_start) {
393                             TRACE("empty file\n");
394                             all_data_read(protocol);
395                         }else {
396                             WARN("unexpected end of file?\n");
397                             report_result(protocol, INET_E_DOWNLOAD_FAILURE);
398                         }
399                         return S_OK;
400                     }
401                     protocol->available_bytes = protocol->query_available;
402                 }else if(GetLastError() != ERROR_IO_PENDING) {
403                     protocol->flags |= FLAG_REQUEST_COMPLETE;
404                     WARN("InternetQueryDataAvailable failed: %d\n", GetLastError());
405                     report_result(protocol, INET_E_DATA_NOT_AVAILABLE);
406                     return S_OK;
407                 }
408             }
409
410             protocol->flags |= FLAG_REQUEST_COMPLETE;
411         }
412
413         report_data(protocol);
414     }
415
416     return S_OK;
417 }
418
419 HRESULT protocol_read(Protocol *protocol, void *buf, ULONG size, ULONG *read_ret)
420 {
421     ULONG read = 0;
422     BOOL res;
423     HRESULT hres = S_FALSE;
424
425     if(protocol->flags & FLAG_ALL_DATA_READ) {
426         *read_ret = 0;
427         return S_FALSE;
428     }
429
430     if(!(protocol->flags & FLAG_SYNC_READ) && (!(protocol->flags & FLAG_REQUEST_COMPLETE) || !protocol->available_bytes)) {
431         *read_ret = 0;
432         return E_PENDING;
433     }
434
435     while(read < size && protocol->available_bytes) {
436         ULONG len;
437
438         res = InternetReadFile(protocol->request, ((BYTE *)buf)+read,
439                 protocol->available_bytes > size-read ? size-read : protocol->available_bytes, &len);
440         if(!res) {
441             WARN("InternetReadFile failed: %d\n", GetLastError());
442             hres = INET_E_DOWNLOAD_FAILURE;
443             report_result(protocol, hres);
444             break;
445         }
446
447         if(!len) {
448             all_data_read(protocol);
449             break;
450         }
451
452         read += len;
453         protocol->current_position += len;
454         protocol->available_bytes -= len;
455
456         TRACE("current_position %d, available_bytes %d\n", protocol->current_position, protocol->available_bytes);
457
458         if(!protocol->available_bytes) {
459             /* InternetQueryDataAvailable may immediately fork and perform its asynchronous
460              * read, so clear the flag _before_ calling so it does not incorrectly get cleared
461              * after the status callback is called */
462             protocol->flags &= ~FLAG_REQUEST_COMPLETE;
463             res = InternetQueryDataAvailable(protocol->request, &protocol->query_available, 0, 0);
464             if(!res) {
465                 if (GetLastError() == ERROR_IO_PENDING) {
466                     hres = E_PENDING;
467                 }else {
468                     WARN("InternetQueryDataAvailable failed: %d\n", GetLastError());
469                     hres = INET_E_DATA_NOT_AVAILABLE;
470                     report_result(protocol, hres);
471                 }
472                 break;
473             }
474
475             if(!protocol->query_available) {
476                 all_data_read(protocol);
477                 break;
478             }
479
480             protocol->available_bytes = protocol->query_available;
481         }
482     }
483
484     *read_ret = read;
485
486     if (hres != E_PENDING)
487         protocol->flags |= FLAG_REQUEST_COMPLETE;
488     if(FAILED(hres))
489         return hres;
490
491     return read ? S_OK : S_FALSE;
492 }
493
494 HRESULT protocol_lock_request(Protocol *protocol)
495 {
496     if (!InternetLockRequestFile(protocol->request, &protocol->lock))
497         WARN("InternetLockRequest failed: %d\n", GetLastError());
498
499     return S_OK;
500 }
501
502 HRESULT protocol_unlock_request(Protocol *protocol)
503 {
504     if(!protocol->lock)
505         return S_OK;
506
507     if(!InternetUnlockRequestFile(protocol->lock))
508         WARN("InternetUnlockRequest failed: %d\n", GetLastError());
509     protocol->lock = 0;
510
511     return S_OK;
512 }
513
514 HRESULT protocol_abort(Protocol *protocol, HRESULT reason)
515 {
516     if(!protocol->protocol_sink)
517         return S_OK;
518
519     if(protocol->flags & FLAG_RESULT_REPORTED)
520         return INET_E_RESULT_DISPATCHED;
521
522     report_result(protocol, reason);
523     return S_OK;
524 }
525
526 void protocol_close_connection(Protocol *protocol)
527 {
528     protocol->vtbl->close_connection(protocol);
529
530     if(protocol->request)
531         InternetCloseHandle(protocol->request);
532
533     if(protocol->connection)
534         InternetCloseHandle(protocol->connection);
535
536     if(protocol->post_stream) {
537         IStream_Release(protocol->post_stream);
538         protocol->post_stream = NULL;
539     }
540
541     protocol->flags = 0;
542 }