oleaut32/tests: Replace realloc() with HeapReAlloc().
[wine] / dlls / urlmon / protocol.c
1 /*
2  * Copyright 2007 Misha Koshelev
3  * Copyright 2009 Jacek Caban for CodeWeavers
4  *
5  * This library is free software; you can redistribute it and/or
6  * modify it under the terms of the GNU Lesser General Public
7  * License as published by the Free Software Foundation; either
8  * version 2.1 of the License, or (at your option) any later version.
9  *
10  * This library is distributed in the hope that it will be useful,
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
13  * Lesser General Public License for more details.
14  *
15  * You should have received a copy of the GNU Lesser General Public
16  * License along with this library; if not, write to the Free Software
17  * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
18  */
19
20 #include "urlmon_main.h"
21
22 #include "wine/debug.h"
23
24 WINE_DEFAULT_DEBUG_CHANNEL(urlmon);
25
26 /* Flags are needed for, among other things, return HRESULTs from the Read function
27  * to conform to native. For example, Read returns:
28  *
29  * 1. E_PENDING if called before the request has completed,
30  *        (flags = 0)
31  * 2. S_FALSE after all data has been read and S_OK has been reported,
32  *        (flags = FLAG_REQUEST_COMPLETE | FLAG_ALL_DATA_READ | FLAG_RESULT_REPORTED)
33  * 3. INET_E_DATA_NOT_AVAILABLE if InternetQueryDataAvailable fails. The first time
34  *    this occurs, INET_E_DATA_NOT_AVAILABLE will also be reported to the sink,
35  *        (flags = FLAG_REQUEST_COMPLETE)
36  *    but upon subsequent calls to Read no reporting will take place, yet
37  *    InternetQueryDataAvailable will still be called, and, on failure,
38  *    INET_E_DATA_NOT_AVAILABLE will still be returned.
39  *        (flags = FLAG_REQUEST_COMPLETE | FLAG_RESULT_REPORTED)
40  *
41  * FLAG_FIRST_DATA_REPORTED and FLAG_LAST_DATA_REPORTED are needed for proper
42  * ReportData reporting. For example, if OnResponse returns S_OK, Continue will
43  * report BSCF_FIRSTDATANOTIFICATION, and when all data has been read Read will
44  * report BSCF_INTERMEDIATEDATANOTIFICATION|BSCF_LASTDATANOTIFICATION. However,
45  * if OnResponse does not return S_OK, Continue will not report data, and Read
46  * will report BSCF_FIRSTDATANOTIFICATION|BSCF_LASTDATANOTIFICATION when all
47  * data has been read.
48  */
49 #define FLAG_REQUEST_COMPLETE         0x0001
50 #define FLAG_FIRST_CONTINUE_COMPLETE  0x0002
51 #define FLAG_FIRST_DATA_REPORTED      0x0004
52 #define FLAG_ALL_DATA_READ            0x0008
53 #define FLAG_LAST_DATA_REPORTED       0x0010
54 #define FLAG_RESULT_REPORTED          0x0020
55
56 static inline HRESULT report_progress(Protocol *protocol, ULONG status_code, LPCWSTR status_text)
57 {
58     return IInternetProtocolSink_ReportProgress(protocol->protocol_sink, status_code, status_text);
59 }
60
61 static inline HRESULT report_result(Protocol *protocol, HRESULT hres)
62 {
63     if (!(protocol->flags & FLAG_RESULT_REPORTED) && protocol->protocol_sink) {
64         protocol->flags |= FLAG_RESULT_REPORTED;
65         IInternetProtocolSink_ReportResult(protocol->protocol_sink, hres, 0, NULL);
66     }
67
68     return hres;
69 }
70
71 static void report_data(Protocol *protocol)
72 {
73     DWORD bscf;
74
75     if((protocol->flags & FLAG_LAST_DATA_REPORTED) || !protocol->protocol_sink)
76         return;
77
78     if(protocol->flags & FLAG_FIRST_DATA_REPORTED) {
79         bscf = BSCF_INTERMEDIATEDATANOTIFICATION;
80     }else {
81         protocol->flags |= FLAG_FIRST_DATA_REPORTED;
82         bscf = BSCF_FIRSTDATANOTIFICATION;
83     }
84
85     if(protocol->flags & FLAG_ALL_DATA_READ && !(protocol->flags & FLAG_LAST_DATA_REPORTED)) {
86         protocol->flags |= FLAG_LAST_DATA_REPORTED;
87         bscf |= BSCF_LASTDATANOTIFICATION;
88     }
89
90     IInternetProtocolSink_ReportData(protocol->protocol_sink, bscf,
91             protocol->current_position+protocol->available_bytes,
92             protocol->content_length);
93 }
94
95 static void all_data_read(Protocol *protocol)
96 {
97     protocol->flags |= FLAG_ALL_DATA_READ;
98
99     report_data(protocol);
100     report_result(protocol, S_OK);
101 }
102
103 static void request_complete(Protocol *protocol, INTERNET_ASYNC_RESULT *ar)
104 {
105     PROTOCOLDATA data;
106
107     if(!ar->dwResult) {
108         WARN("request failed: %d\n", ar->dwError);
109         return;
110     }
111
112     protocol->flags |= FLAG_REQUEST_COMPLETE;
113
114     if(!protocol->request) {
115         TRACE("setting request handle %p\n", (HINTERNET)ar->dwResult);
116         protocol->request = (HINTERNET)ar->dwResult;
117     }
118
119     /* PROTOCOLDATA same as native */
120     memset(&data, 0, sizeof(data));
121     data.dwState = 0xf1000000;
122     if(protocol->flags & FLAG_FIRST_CONTINUE_COMPLETE)
123         data.pData = (LPVOID)BINDSTATUS_ENDDOWNLOADCOMPONENTS;
124     else
125         data.pData = (LPVOID)BINDSTATUS_DOWNLOADINGDATA;
126
127     if (protocol->bindf & BINDF_FROMURLMON)
128         IInternetProtocolSink_Switch(protocol->protocol_sink, &data);
129     else
130         protocol_continue(protocol, &data);
131 }
132
133 static void WINAPI internet_status_callback(HINTERNET internet, DWORD_PTR context,
134         DWORD internet_status, LPVOID status_info, DWORD status_info_len)
135 {
136     Protocol *protocol = (Protocol*)context;
137
138     switch(internet_status) {
139     case INTERNET_STATUS_RESOLVING_NAME:
140         TRACE("%p INTERNET_STATUS_RESOLVING_NAME\n", protocol);
141         report_progress(protocol, BINDSTATUS_FINDINGRESOURCE, (LPWSTR)status_info);
142         break;
143
144     case INTERNET_STATUS_CONNECTING_TO_SERVER:
145         TRACE("%p INTERNET_STATUS_CONNECTING_TO_SERVER\n", protocol);
146         report_progress(protocol, BINDSTATUS_CONNECTING, (LPWSTR)status_info);
147         break;
148
149     case INTERNET_STATUS_SENDING_REQUEST:
150         TRACE("%p INTERNET_STATUS_SENDING_REQUEST\n", protocol);
151         report_progress(protocol, BINDSTATUS_SENDINGREQUEST, (LPWSTR)status_info);
152         break;
153
154     case INTERNET_STATUS_REQUEST_COMPLETE:
155         request_complete(protocol, status_info);
156         break;
157
158     case INTERNET_STATUS_HANDLE_CREATED:
159         TRACE("%p INTERNET_STATUS_HANDLE_CREATED\n", protocol);
160         IInternetProtocol_AddRef(protocol->protocol);
161         break;
162
163     case INTERNET_STATUS_HANDLE_CLOSING:
164         TRACE("%p INTERNET_STATUS_HANDLE_CLOSING\n", protocol);
165
166         if(*(HINTERNET *)status_info == protocol->request) {
167             protocol->request = NULL;
168             if(protocol->protocol_sink) {
169                 IInternetProtocolSink_Release(protocol->protocol_sink);
170                 protocol->protocol_sink = NULL;
171             }
172
173             if(protocol->bind_info.cbSize) {
174                 ReleaseBindInfo(&protocol->bind_info);
175                 memset(&protocol->bind_info, 0, sizeof(protocol->bind_info));
176             }
177         }else if(*(HINTERNET *)status_info == protocol->connection) {
178             protocol->connection = NULL;
179         }
180
181         IInternetProtocol_Release(protocol->protocol);
182         break;
183
184     default:
185         WARN("Unhandled Internet status callback %d\n", internet_status);
186     }
187 }
188
189 HRESULT protocol_start(Protocol *protocol, IInternetProtocol *prot, LPCWSTR url,
190         IInternetProtocolSink *protocol_sink, IInternetBindInfo *bind_info)
191 {
192     LPOLESTR user_agent = NULL;
193     DWORD request_flags;
194     ULONG size = 0;
195     HRESULT hres;
196
197     protocol->protocol = prot;
198
199     IInternetProtocolSink_AddRef(protocol_sink);
200     protocol->protocol_sink = protocol_sink;
201
202     memset(&protocol->bind_info, 0, sizeof(protocol->bind_info));
203     protocol->bind_info.cbSize = sizeof(BINDINFO);
204     hres = IInternetBindInfo_GetBindInfo(bind_info, &protocol->bindf, &protocol->bind_info);
205     if(hres != S_OK) {
206         WARN("GetBindInfo failed: %08x\n", hres);
207         return report_result(protocol, hres);
208     }
209
210     if(!(protocol->bindf & BINDF_FROMURLMON))
211         report_progress(protocol, BINDSTATUS_DIRECTBIND, NULL);
212
213     hres = IInternetBindInfo_GetBindString(bind_info, BINDSTRING_USER_AGENT, &user_agent, 1, &size);
214     if (hres != S_OK || !size) {
215         DWORD len;
216         CHAR null_char = 0;
217         LPSTR user_agenta = NULL;
218
219         len = 0;
220         if ((hres = ObtainUserAgentString(0, &null_char, &len)) != E_OUTOFMEMORY) {
221             WARN("ObtainUserAgentString failed: %08x\n", hres);
222         }else if (!(user_agenta = heap_alloc(len*sizeof(CHAR)))) {
223             WARN("Out of memory\n");
224         }else if ((hres = ObtainUserAgentString(0, user_agenta, &len)) != S_OK) {
225             WARN("ObtainUserAgentString failed: %08x\n", hres);
226         }else {
227             if(!(user_agent = CoTaskMemAlloc((len)*sizeof(WCHAR))))
228                 WARN("Out of memory\n");
229             else
230                 MultiByteToWideChar(CP_ACP, 0, user_agenta, -1, user_agent, len);
231         }
232         heap_free(user_agenta);
233     }
234
235     protocol->internet = InternetOpenW(user_agent, 0, NULL, NULL, INTERNET_FLAG_ASYNC);
236     CoTaskMemFree(user_agent);
237     if(!protocol->internet) {
238         WARN("InternetOpen failed: %d\n", GetLastError());
239         return report_result(protocol, INET_E_NO_SESSION);
240     }
241
242     /* Native does not check for success of next call, so we won't either */
243     InternetSetStatusCallbackW(protocol->internet, internet_status_callback);
244
245     request_flags = INTERNET_FLAG_KEEP_CONNECTION;
246     if(protocol->bindf & BINDF_NOWRITECACHE)
247         request_flags |= INTERNET_FLAG_NO_CACHE_WRITE;
248     if(protocol->bindf & BINDF_NEEDFILE)
249         request_flags |= INTERNET_FLAG_NEED_FILE;
250
251     hres = protocol->vtbl->open_request(protocol, url, request_flags, bind_info);
252     if(FAILED(hres)) {
253         protocol_close_connection(protocol);
254         return report_result(protocol, hres);
255     }
256
257     return S_OK;
258 }
259
260 HRESULT protocol_continue(Protocol *protocol, PROTOCOLDATA *data)
261 {
262     HRESULT hres;
263
264     if (!data) {
265         WARN("Expected pProtocolData to be non-NULL\n");
266         return S_OK;
267     }
268
269     if(!protocol->request) {
270         WARN("Expected request to be non-NULL\n");
271         return S_OK;
272     }
273
274     if(!protocol->protocol_sink) {
275         WARN("Expected IInternetProtocolSink pointer to be non-NULL\n");
276         return S_OK;
277     }
278
279     if(data->pData == (LPVOID)BINDSTATUS_DOWNLOADINGDATA) {
280         hres = protocol->vtbl->start_downloading(protocol);
281         if(FAILED(hres)) {
282             protocol_close_connection(protocol);
283             report_result(protocol, hres);
284             return S_OK;
285         }
286
287         if(protocol->bindf & BINDF_NEEDFILE) {
288             WCHAR cache_file[MAX_PATH];
289             DWORD buflen = sizeof(cache_file);
290
291             if(InternetQueryOptionW(protocol->request, INTERNET_OPTION_DATAFILE_NAME,
292                     cache_file, &buflen)) {
293                 report_progress(protocol, BINDSTATUS_CACHEFILENAMEAVAILABLE, cache_file);
294             }else {
295                 FIXME("Could not get cache file\n");
296             }
297         }
298
299         protocol->flags |= FLAG_FIRST_CONTINUE_COMPLETE;
300     }
301
302     if(data->pData >= (LPVOID)BINDSTATUS_DOWNLOADINGDATA) {
303         BOOL res;
304
305         /* InternetQueryDataAvailable may immediately fork and perform its asynchronous
306          * read, so clear the flag _before_ calling so it does not incorrectly get cleared
307          * after the status callback is called */
308         protocol->flags &= ~FLAG_REQUEST_COMPLETE;
309         res = InternetQueryDataAvailable(protocol->request, &protocol->available_bytes, 0, 0);
310         if(res) {
311             protocol->flags |= FLAG_REQUEST_COMPLETE;
312             report_data(protocol);
313         }else if(GetLastError() != ERROR_IO_PENDING) {
314             protocol->flags |= FLAG_REQUEST_COMPLETE;
315             WARN("InternetQueryDataAvailable failed: %d\n", GetLastError());
316             report_result(protocol, INET_E_DATA_NOT_AVAILABLE);
317         }
318     }
319
320     return S_OK;
321 }
322
323 HRESULT protocol_read(Protocol *protocol, void *buf, ULONG size, ULONG *read_ret)
324 {
325     ULONG read = 0;
326     BOOL res;
327     HRESULT hres = S_FALSE;
328
329     if(!(protocol->flags & FLAG_REQUEST_COMPLETE)) {
330         *read_ret = 0;
331         return E_PENDING;
332     }
333
334     if(protocol->flags & FLAG_ALL_DATA_READ) {
335         *read_ret = 0;
336         return S_FALSE;
337     }
338
339     while(read < size) {
340         if(protocol->available_bytes) {
341             ULONG len;
342
343             res = InternetReadFile(protocol->request, ((BYTE *)buf)+read,
344                     protocol->available_bytes > size-read ? size-read : protocol->available_bytes, &len);
345             if(!res) {
346                 WARN("InternetReadFile failed: %d\n", GetLastError());
347                 hres = INET_E_DOWNLOAD_FAILURE;
348                 report_result(protocol, hres);
349                 break;
350             }
351
352             if(!len) {
353                 all_data_read(protocol);
354                 break;
355             }
356
357             read += len;
358             protocol->current_position += len;
359             protocol->available_bytes -= len;
360         }else {
361             /* InternetQueryDataAvailable may immediately fork and perform its asynchronous
362              * read, so clear the flag _before_ calling so it does not incorrectly get cleared
363              * after the status callback is called */
364             protocol->flags &= ~FLAG_REQUEST_COMPLETE;
365             res = InternetQueryDataAvailable(protocol->request, &protocol->available_bytes, 0, 0);
366             if(!res) {
367                 if (GetLastError() == ERROR_IO_PENDING) {
368                     hres = E_PENDING;
369                 }else {
370                     WARN("InternetQueryDataAvailable failed: %d\n", GetLastError());
371                     hres = INET_E_DATA_NOT_AVAILABLE;
372                     report_result(protocol, hres);
373                 }
374                 break;
375             }
376
377             if(!protocol->available_bytes) {
378                 all_data_read(protocol);
379                 break;
380             }
381         }
382     }
383
384     *read_ret = read;
385
386     if (hres != E_PENDING)
387         protocol->flags |= FLAG_REQUEST_COMPLETE;
388     if(FAILED(hres))
389         return hres;
390
391     return read ? S_OK : S_FALSE;
392 }
393
394 HRESULT protocol_lock_request(Protocol *protocol)
395 {
396     if (!InternetLockRequestFile(protocol->request, &protocol->lock))
397         WARN("InternetLockRequest failed: %d\n", GetLastError());
398
399     return S_OK;
400 }
401
402 HRESULT protocol_unlock_request(Protocol *protocol)
403 {
404     if(!protocol->lock)
405         return S_OK;
406
407     if(!InternetUnlockRequestFile(protocol->lock))
408         WARN("InternetUnlockRequest failed: %d\n", GetLastError());
409     protocol->lock = 0;
410
411     return S_OK;
412 }
413
414 void protocol_close_connection(Protocol *protocol)
415 {
416     protocol->vtbl->close_connection(protocol);
417
418     if(protocol->request)
419         InternetCloseHandle(protocol->request);
420
421     if(protocol->connection)
422         InternetCloseHandle(protocol->connection);
423
424     if(protocol->internet) {
425         InternetCloseHandle(protocol->internet);
426         protocol->internet = 0;
427     }
428
429     protocol->flags = 0;
430 }