dplayx: Introduce impl_from_IDirectPlayLobby3A().
[wine] / dlls / wininet / netconnection.c
1 /*
2  * Wininet - networking layer. Uses unix sockets.
3  *
4  * Copyright 2002 TransGaming Technologies Inc.
5  * Copyright 2013 Jacek Caban for CodeWeavers
6  *
7  * David Hammerton
8  *
9  * This library is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public
11  * License as published by the Free Software Foundation; either
12  * version 2.1 of the License, or (at your option) any later version.
13  *
14  * This library is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with this library; if not, write to the Free Software
21  * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
22  */
23
24 #include "config.h"
25 #include "wine/port.h"
26
27 #define NONAMELESSUNION
28
29 #if defined(__MINGW32__) || defined (_MSC_VER)
30 #include <ws2tcpip.h>
31 #endif
32
33 #include <sys/types.h>
34 #ifdef HAVE_POLL_H
35 #include <poll.h>
36 #endif
37 #ifdef HAVE_SYS_POLL_H
38 # include <sys/poll.h>
39 #endif
40 #ifdef HAVE_SYS_TIME_H
41 # include <sys/time.h>
42 #endif
43 #ifdef HAVE_SYS_SOCKET_H
44 # include <sys/socket.h>
45 #endif
46 #ifdef HAVE_SYS_FILIO_H
47 # include <sys/filio.h>
48 #endif
49 #ifdef HAVE_UNISTD_H
50 # include <unistd.h>
51 #endif
52 #ifdef HAVE_SYS_IOCTL_H
53 # include <sys/ioctl.h>
54 #endif
55 #include <time.h>
56 #ifdef HAVE_NETDB_H
57 # include <netdb.h>
58 #endif
59 #ifdef HAVE_NETINET_IN_H
60 # include <netinet/in.h>
61 #endif
62 #ifdef HAVE_NETINET_TCP_H
63 # include <netinet/tcp.h>
64 #endif
65
66 #include <stdarg.h>
67 #include <stdlib.h>
68 #include <string.h>
69 #include <stdio.h>
70 #include <errno.h>
71 #include <assert.h>
72
73 #include "wine/library.h"
74 #include "windef.h"
75 #include "winbase.h"
76 #include "wininet.h"
77 #include "winerror.h"
78
79 #include "wine/debug.h"
80 #include "internet.h"
81
/* Use the WS_-prefixed winsock names to avoid conflicts with the Unix socket
 * headers; we only need winsock2.h for the error codes anyway. */
84 #define USE_WS_PREFIX
85 #include "winsock2.h"
86
87 #define RESPONSE_TIMEOUT        30            /* FROM internet.c */
88
89
90 WINE_DEFAULT_DEBUG_CHANNEL(wininet);
91
92 /* FIXME!!!!!!
93  *    This should use winsock - To use winsock the functions will have to change a bit
94  *        as they are designed for unix sockets.
95  */
96
/*
 * Verify the certificate presented by the server during the TLS handshake.
 *
 * A chain is built for cert (store is passed as the additional store to
 * CertGetCertificateChain) requesting the server-authentication usage. The
 * chain trust errors are mapped onto WinINet error codes, honouring the
 * SECURITY_FLAG_IGNORE_* bits in conn->security_flags, and then the SSL chain
 * policy (including the host-name check) is evaluated.
 *
 * When conn->mask_errors is set, detected problems are recorded as
 * _SECURITY_FLAG_* bits in conn->security_flags and checking continues; if an
 * error was already recorded, ERROR_INTERNET_SEC_CERT_ERRORS is reported
 * instead of the more specific code. Without mask_errors the first
 * non-ignored problem ends the checks.
 *
 * Returns ERROR_SUCCESS and caches the chain in conn->server->cert_chain (if
 * none is cached yet), or returns an error code after releasing both the new
 * chain and any cached chain. With mask_errors, the error bits are also
 * copied to conn->server->security_flags on failure.
 */
static DWORD netconn_verify_cert(netconn_t *conn, PCCERT_CONTEXT cert, HCERTSTORE store)
{
    BOOL ret;
    CERT_CHAIN_PARA chainPara = { sizeof(chainPara), { 0 } };
    PCCERT_CHAIN_CONTEXT chain;
    char oid_server_auth[] = szOID_PKIX_KP_SERVER_AUTH;
    char *server_auth[] = { oid_server_auth };
    DWORD err = ERROR_SUCCESS, errors;

    /* Chain trust errors that get a specific WinINet mapping below; any
     * other error bit is reported as an invalid certificate. */
    static const DWORD supportedErrors =
        CERT_TRUST_IS_NOT_TIME_VALID |
        CERT_TRUST_IS_UNTRUSTED_ROOT |
        CERT_TRUST_IS_PARTIAL_CHAIN |
        CERT_TRUST_IS_NOT_VALID_FOR_USAGE;

    TRACE("verifying %s\n", debugstr_w(conn->server->name));

    /* Ask for a chain that is valid for TLS server authentication. */
    chainPara.RequestedUsage.Usage.cUsageIdentifier = 1;
    chainPara.RequestedUsage.Usage.rgpszUsageIdentifier = server_auth;
    if (!(ret = CertGetCertificateChain(NULL, cert, NULL, store, &chainPara, 0, NULL, &chain))) {
        TRACE("failed\n");
        return GetLastError();
    }

    /* Working copy of the trust errors; bits are cleared as they are handled. */
    errors = chain->TrustStatus.dwErrorStatus;

    /* do { } while(0) so that a non-masked error can 'break' out early. */
    do {
        /* This seems strange, but that's what tests show */
        if(errors & CERT_TRUST_IS_PARTIAL_CHAIN) {
            WARN("ERROR_INTERNET_SEC_CERT_REV_FAILED\n");
            err = ERROR_INTERNET_SEC_CERT_REV_FAILED;
            if(conn->mask_errors)
                conn->security_flags |= _SECURITY_FLAG_CERT_REV_FAILED;
            if(!(conn->security_flags & SECURITY_FLAG_IGNORE_REVOCATION))
                break;
        }

        /* Error bits without a specific mapping. */
        if (chain->TrustStatus.dwErrorStatus & ~supportedErrors) {
            WARN("error status %x\n", chain->TrustStatus.dwErrorStatus & ~supportedErrors);
            err = conn->mask_errors && err ? ERROR_INTERNET_SEC_CERT_ERRORS : ERROR_INTERNET_SEC_INVALID_CERT;
            errors &= supportedErrors;
            if(!conn->mask_errors)
                break;
            WARN("unknown error flags\n");
        }

        if(errors & CERT_TRUST_IS_NOT_TIME_VALID) {
            WARN("CERT_TRUST_IS_NOT_TIME_VALID\n");
            if(!(conn->security_flags & SECURITY_FLAG_IGNORE_CERT_DATE_INVALID)) {
                err = conn->mask_errors && err ? ERROR_INTERNET_SEC_CERT_ERRORS : ERROR_INTERNET_SEC_CERT_DATE_INVALID;
                if(!conn->mask_errors)
                    break;
                conn->security_flags |= _SECURITY_FLAG_CERT_INVALID_DATE;
            }
            errors &= ~CERT_TRUST_IS_NOT_TIME_VALID;
        }

        if(errors & CERT_TRUST_IS_UNTRUSTED_ROOT) {
            WARN("CERT_TRUST_IS_UNTRUSTED_ROOT\n");
            if(!(conn->security_flags & SECURITY_FLAG_IGNORE_UNKNOWN_CA)) {
                err = conn->mask_errors && err ? ERROR_INTERNET_SEC_CERT_ERRORS : ERROR_INTERNET_INVALID_CA;
                if(!conn->mask_errors)
                    break;
                conn->security_flags |= _SECURITY_FLAG_CERT_INVALID_CA;
            }
            errors &= ~CERT_TRUST_IS_UNTRUSTED_ROOT;
        }

        /* A partial chain is also treated as an unknown CA (in addition to
         * the revocation-failure handling at the top of the loop). */
        if(errors & CERT_TRUST_IS_PARTIAL_CHAIN) {
            WARN("CERT_TRUST_IS_PARTIAL_CHAIN\n");
            if(!(conn->security_flags & SECURITY_FLAG_IGNORE_UNKNOWN_CA)) {
                err = conn->mask_errors && err ? ERROR_INTERNET_SEC_CERT_ERRORS : ERROR_INTERNET_INVALID_CA;
                if(!conn->mask_errors)
                    break;
                conn->security_flags |= _SECURITY_FLAG_CERT_INVALID_CA;
            }
            errors &= ~CERT_TRUST_IS_PARTIAL_CHAIN;
        }

        if(errors & CERT_TRUST_IS_NOT_VALID_FOR_USAGE) {
            WARN("CERT_TRUST_IS_NOT_VALID_FOR_USAGE\n");
            if(!(conn->security_flags & SECURITY_FLAG_IGNORE_WRONG_USAGE)) {
                err = conn->mask_errors && err ? ERROR_INTERNET_SEC_CERT_ERRORS : ERROR_INTERNET_SEC_INVALID_CERT;
                if(!conn->mask_errors)
                    break;
                WARN("CERT_TRUST_IS_NOT_VALID_FOR_USAGE, unknown error flags\n");
            }
            errors &= ~CERT_TRUST_IS_NOT_VALID_FOR_USAGE;
        }

        /* Reaching here with only the revocation failure set means the
         * caller asked to ignore revocation problems. */
        if(err == ERROR_INTERNET_SEC_CERT_REV_FAILED) {
            assert(conn->security_flags & SECURITY_FLAG_IGNORE_REVOCATION);
            err = ERROR_SUCCESS;
        }
    }while(0);

    /* Run the SSL policy check (host name etc.) when no fatal error was found,
     * or always when errors are being masked so every problem is recorded. */
    if(!err || conn->mask_errors) {
        CERT_CHAIN_POLICY_PARA policyPara;
        SSL_EXTRA_CERT_CHAIN_POLICY_PARA sslExtraPolicyPara;
        CERT_CHAIN_POLICY_STATUS policyStatus;
        CERT_CHAIN_CONTEXT chainCopy;

        /* Clear chain->TrustStatus.dwErrorStatus so
         * CertVerifyCertificateChainPolicy will verify additional checks
         * rather than stopping with an existing, ignored error.
         */
        memcpy(&chainCopy, chain, sizeof(chainCopy));
        chainCopy.TrustStatus.dwErrorStatus = 0;
        sslExtraPolicyPara.u.cbSize = sizeof(sslExtraPolicyPara);
        sslExtraPolicyPara.dwAuthType = AUTHTYPE_SERVER;
        sslExtraPolicyPara.pwszServerName = conn->server->name;
        sslExtraPolicyPara.fdwChecks = conn->security_flags;
        policyPara.cbSize = sizeof(policyPara);
        policyPara.dwFlags = 0;
        policyPara.pvExtraPolicyPara = &sslExtraPolicyPara;
        ret = CertVerifyCertificateChainPolicy(CERT_CHAIN_POLICY_SSL,
                &chainCopy, &policyPara, &policyStatus);
        /* Any error in the policy status indicates that the
         * policy couldn't be verified.
         */
        if(ret) {
            if(policyStatus.dwError == CERT_E_CN_NO_MATCH) {
                WARN("CERT_E_CN_NO_MATCH\n");
                if(conn->mask_errors)
                    conn->security_flags |= _SECURITY_FLAG_CERT_INVALID_CN;
                err = conn->mask_errors && err ? ERROR_INTERNET_SEC_CERT_ERRORS : ERROR_INTERNET_SEC_CERT_CN_INVALID;
            }else if(policyStatus.dwError) {
                WARN("policyStatus.dwError %x\n", policyStatus.dwError);
                if(conn->mask_errors)
                    WARN("unknown error flags for policy status %x\n", policyStatus.dwError);
                err = conn->mask_errors && err ? ERROR_INTERNET_SEC_CERT_ERRORS : ERROR_INTERNET_SEC_INVALID_CERT;
            }
        }else {
            err = GetLastError();
        }
    }

    if(err) {
        WARN("failed %u\n", err);
        CertFreeCertificateChain(chain);
        /* Drop any chain cached by an earlier successful verification. */
        if(conn->server->cert_chain) {
            CertFreeCertificateChain(conn->server->cert_chain);
            conn->server->cert_chain = NULL;
        }
        /* Remember the masked error bits on the server object. */
        if(conn->mask_errors)
            conn->server->security_flags |= conn->security_flags & _SECURITY_ERROR_FLAGS_MASK;
        return err;
    }

    /* FIXME: Reuse cached chain */
    if(conn->server->cert_chain)
        CertFreeCertificateChain(chain);
    else
        conn->server->cert_chain = chain;
    return ERROR_SUCCESS;
}
253
/* Process-wide Schannel client credentials. cred_handle holds the default
 * credentials. compat_cred_handle, valid when have_compat_cred_handle is set,
 * has the TLS 1.1+ protocols disabled and is used by
 * netcon_secure_connect_setup() in compat mode. Both are acquired by
 * ensure_cred_handle() and released by NETCON_unload(). */
static SecHandle cred_handle, compat_cred_handle;
static BOOL cred_handle_initialized, have_compat_cred_handle;

/* Serializes the lazy credential initialization in ensure_cred_handle(). */
static CRITICAL_SECTION init_sechandle_cs;
static CRITICAL_SECTION_DEBUG init_sechandle_cs_debug = {
    0, 0, &init_sechandle_cs,
    { &init_sechandle_cs_debug.ProcessLocksList,
      &init_sechandle_cs_debug.ProcessLocksList },
    0, 0, { (DWORD_PTR)(__FILE__ ": init_sechandle_cs") }
};
static CRITICAL_SECTION init_sechandle_cs = { &init_sechandle_cs_debug, -1, 0, 0, 0, 0 };
265
266 static BOOL ensure_cred_handle(void)
267 {
268     SECURITY_STATUS res = SEC_E_OK;
269
270     EnterCriticalSection(&init_sechandle_cs);
271
272     if(!cred_handle_initialized) {
273         SCHANNEL_CRED cred = {SCHANNEL_CRED_VERSION};
274         SecPkgCred_SupportedProtocols prots;
275
276         res = AcquireCredentialsHandleW(NULL, (WCHAR*)UNISP_NAME_W, SECPKG_CRED_OUTBOUND, NULL, &cred,
277                 NULL, NULL, &cred_handle, NULL);
278         if(res == SEC_E_OK) {
279             res = QueryCredentialsAttributesA(&cred_handle, SECPKG_ATTR_SUPPORTED_PROTOCOLS, &prots);
280             if(res != SEC_E_OK || (prots.grbitProtocol & SP_PROT_TLS1_1PLUS_CLIENT)) {
281                 cred.grbitEnabledProtocols = prots.grbitProtocol & ~SP_PROT_TLS1_1PLUS_CLIENT;
282                 res = AcquireCredentialsHandleW(NULL, (WCHAR*)UNISP_NAME_W, SECPKG_CRED_OUTBOUND, NULL, &cred,
283                        NULL, NULL, &compat_cred_handle, NULL);
284                 have_compat_cred_handle = res == SEC_E_OK;
285             }
286         }
287
288         cred_handle_initialized = res == SEC_E_OK;
289     }
290
291     LeaveCriticalSection(&init_sechandle_cs);
292
293     if(res != SEC_E_OK) {
294         WARN("Failed: %08x\n", res);
295         return FALSE;
296     }
297
298     return TRUE;
299 }
300
301 static DWORD create_netconn_socket(server_t *server, netconn_t *netconn, DWORD timeout)
302 {
303     int result;
304     ULONG flag;
305
306     assert(server->addr_len);
307     result = netconn->socket = socket(server->addr.ss_family, SOCK_STREAM, 0);
308     if(result != -1) {
309         flag = 1;
310         ioctlsocket(netconn->socket, FIONBIO, &flag);
311         result = connect(netconn->socket, (struct sockaddr*)&server->addr, server->addr_len);
312         if(result == -1)
313         {
314             if (sock_get_error(errno) == WSAEINPROGRESS) {
315                 struct pollfd pfd;
316                 int res;
317
318                 pfd.fd = netconn->socket;
319                 pfd.events = POLLOUT;
320                 res = poll(&pfd, 1, timeout);
321                 if (!res)
322                 {
323                     closesocket(netconn->socket);
324                     return ERROR_INTERNET_CANNOT_CONNECT;
325                 }
326                 else if (res > 0)
327                 {
328                     int err;
329                     socklen_t len = sizeof(err);
330                     if (!getsockopt(netconn->socket, SOL_SOCKET, SO_ERROR, (void *)&err, &len) && !err)
331                         result = 0;
332                 }
333             }
334         }
335         if(result == -1)
336             closesocket(netconn->socket);
337         else {
338             flag = 0;
339             ioctlsocket(netconn->socket, FIONBIO, &flag);
340         }
341     }
342     if(result == -1)
343         return ERROR_INTERNET_CANNOT_CONNECT;
344
345 #ifdef TCP_NODELAY
346     flag = 1;
347     result = setsockopt(netconn->socket, IPPROTO_TCP, TCP_NODELAY, (void*)&flag, sizeof(flag));
348     if(result < 0)
349         WARN("setsockopt(TCP_NODELAY) failed\n");
350 #endif
351
352     return ERROR_SUCCESS;
353 }
354
355 DWORD create_netconn(BOOL useSSL, server_t *server, DWORD security_flags, BOOL mask_errors, DWORD timeout, netconn_t **ret)
356 {
357     netconn_t *netconn;
358     int result;
359
360     netconn = heap_alloc_zero(sizeof(*netconn));
361     if(!netconn)
362         return ERROR_OUTOFMEMORY;
363
364     netconn->socket = -1;
365     netconn->security_flags = security_flags | server->security_flags;
366     netconn->mask_errors = mask_errors;
367     list_init(&netconn->pool_entry);
368
369     result = create_netconn_socket(server, netconn, timeout);
370     if (result != ERROR_SUCCESS) {
371         heap_free(netconn);
372         return result;
373     }
374
375     server_addref(server);
376     netconn->server = server;
377     *ret = netconn;
378     return result;
379 }
380
381 void free_netconn(netconn_t *netconn)
382 {
383     server_release(netconn->server);
384
385     if (netconn->secure) {
386         heap_free(netconn->peek_msg_mem);
387         netconn->peek_msg_mem = NULL;
388         netconn->peek_msg = NULL;
389         netconn->peek_len = 0;
390         heap_free(netconn->ssl_buf);
391         netconn->ssl_buf = NULL;
392         heap_free(netconn->extra_buf);
393         netconn->extra_buf = NULL;
394         netconn->extra_len = 0;
395         DeleteSecurityContext(&netconn->ssl_ctx);
396     }
397
398     closesocket(netconn->socket);
399     heap_free(netconn);
400 }
401
402 void NETCON_unload(void)
403 {
404     if(cred_handle_initialized)
405         FreeCredentialsHandle(&cred_handle);
406     if(have_compat_cred_handle)
407         FreeCredentialsHandle(&compat_cred_handle);
408     DeleteCriticalSection(&init_sechandle_cs);
409 }
410
411 /* translate a unix error code into a winsock one */
412 int sock_get_error( int err )
413 {
414 #if !defined(__MINGW32__) && !defined (_MSC_VER)
415     switch (err)
416     {
417         case EINTR:             return WSAEINTR;
418         case EBADF:             return WSAEBADF;
419         case EPERM:
420         case EACCES:            return WSAEACCES;
421         case EFAULT:            return WSAEFAULT;
422         case EINVAL:            return WSAEINVAL;
423         case EMFILE:            return WSAEMFILE;
424         case EWOULDBLOCK:       return WSAEWOULDBLOCK;
425         case EINPROGRESS:       return WSAEINPROGRESS;
426         case EALREADY:          return WSAEALREADY;
427         case ENOTSOCK:          return WSAENOTSOCK;
428         case EDESTADDRREQ:      return WSAEDESTADDRREQ;
429         case EMSGSIZE:          return WSAEMSGSIZE;
430         case EPROTOTYPE:        return WSAEPROTOTYPE;
431         case ENOPROTOOPT:       return WSAENOPROTOOPT;
432         case EPROTONOSUPPORT:   return WSAEPROTONOSUPPORT;
433         case ESOCKTNOSUPPORT:   return WSAESOCKTNOSUPPORT;
434         case EOPNOTSUPP:        return WSAEOPNOTSUPP;
435         case EPFNOSUPPORT:      return WSAEPFNOSUPPORT;
436         case EAFNOSUPPORT:      return WSAEAFNOSUPPORT;
437         case EADDRINUSE:        return WSAEADDRINUSE;
438         case EADDRNOTAVAIL:     return WSAEADDRNOTAVAIL;
439         case ENETDOWN:          return WSAENETDOWN;
440         case ENETUNREACH:       return WSAENETUNREACH;
441         case ENETRESET:         return WSAENETRESET;
442         case ECONNABORTED:      return WSAECONNABORTED;
443         case EPIPE:
444         case ECONNRESET:        return WSAECONNRESET;
445         case ENOBUFS:           return WSAENOBUFS;
446         case EISCONN:           return WSAEISCONN;
447         case ENOTCONN:          return WSAENOTCONN;
448         case ESHUTDOWN:         return WSAESHUTDOWN;
449         case ETOOMANYREFS:      return WSAETOOMANYREFS;
450         case ETIMEDOUT:         return WSAETIMEDOUT;
451         case ECONNREFUSED:      return WSAECONNREFUSED;
452         case ELOOP:             return WSAELOOP;
453         case ENAMETOOLONG:      return WSAENAMETOOLONG;
454         case EHOSTDOWN:         return WSAEHOSTDOWN;
455         case EHOSTUNREACH:      return WSAEHOSTUNREACH;
456         case ENOTEMPTY:         return WSAENOTEMPTY;
457 #ifdef EPROCLIM
458         case EPROCLIM:          return WSAEPROCLIM;
459 #endif
460 #ifdef EUSERS
461         case EUSERS:            return WSAEUSERS;
462 #endif
463 #ifdef EDQUOT
464         case EDQUOT:            return WSAEDQUOT;
465 #endif
466 #ifdef ESTALE
467         case ESTALE:            return WSAESTALE;
468 #endif
469 #ifdef EREMOTE
470         case EREMOTE:           return WSAEREMOTE;
471 #endif
472     default: errno=err; perror("sock_set_error"); return WSAEFAULT;
473     }
474 #endif
475     return err;
476 }
477
478 static DWORD netcon_secure_connect_setup(netconn_t *connection, BOOL compat_mode)
479 {
480     SecBuffer out_buf = {0, SECBUFFER_TOKEN, NULL}, in_bufs[2] = {{0, SECBUFFER_TOKEN}, {0, SECBUFFER_EMPTY}};
481     SecBufferDesc out_desc = {SECBUFFER_VERSION, 1, &out_buf}, in_desc = {SECBUFFER_VERSION, 2, in_bufs};
482     SecHandle *cred = &cred_handle;
483     BYTE *read_buf;
484     SIZE_T read_buf_size = 2048;
485     ULONG attrs = 0;
486     CtxtHandle ctx;
487     SSIZE_T size;
488     int bits;
489     const CERT_CONTEXT *cert;
490     SECURITY_STATUS status;
491     DWORD res = ERROR_SUCCESS;
492
493     const DWORD isc_req_flags = ISC_REQ_ALLOCATE_MEMORY|ISC_REQ_USE_SESSION_KEY|ISC_REQ_CONFIDENTIALITY
494         |ISC_REQ_SEQUENCE_DETECT|ISC_REQ_REPLAY_DETECT|ISC_REQ_MANUAL_CRED_VALIDATION;
495
496     if(!ensure_cred_handle())
497         return FALSE;
498
499     if(compat_mode) {
500         if(!have_compat_cred_handle)
501             return ERROR_INTERNET_SECURITY_CHANNEL_ERROR;
502         cred = &compat_cred_handle;
503     }
504
505     read_buf = heap_alloc(read_buf_size);
506     if(!read_buf)
507         return ERROR_OUTOFMEMORY;
508
509     status = InitializeSecurityContextW(cred, NULL, connection->server->name, isc_req_flags, 0, 0, NULL, 0,
510             &ctx, &out_desc, &attrs, NULL);
511
512     assert(status != SEC_E_OK);
513
514     while(status == SEC_I_CONTINUE_NEEDED || status == SEC_E_INCOMPLETE_MESSAGE) {
515         if(out_buf.cbBuffer) {
516             assert(status == SEC_I_CONTINUE_NEEDED);
517
518             TRACE("sending %u bytes\n", out_buf.cbBuffer);
519
520             size = send(connection->socket, out_buf.pvBuffer, out_buf.cbBuffer, 0);
521             if(size != out_buf.cbBuffer) {
522                 ERR("send failed\n");
523                 status = ERROR_INTERNET_SECURITY_CHANNEL_ERROR;
524                 break;
525             }
526
527             FreeContextBuffer(out_buf.pvBuffer);
528             out_buf.pvBuffer = NULL;
529             out_buf.cbBuffer = 0;
530         }
531
532         if(status == SEC_I_CONTINUE_NEEDED) {
533             assert(in_bufs[1].cbBuffer < read_buf_size);
534
535             memmove(read_buf, (BYTE*)in_bufs[0].pvBuffer+in_bufs[0].cbBuffer-in_bufs[1].cbBuffer, in_bufs[1].cbBuffer);
536             in_bufs[0].cbBuffer = in_bufs[1].cbBuffer;
537
538             in_bufs[1].BufferType = SECBUFFER_EMPTY;
539             in_bufs[1].cbBuffer = 0;
540             in_bufs[1].pvBuffer = NULL;
541         }
542
543         assert(in_bufs[0].BufferType == SECBUFFER_TOKEN);
544         assert(in_bufs[1].BufferType == SECBUFFER_EMPTY);
545
546         if(in_bufs[0].cbBuffer + 1024 > read_buf_size) {
547             BYTE *new_read_buf;
548
549             new_read_buf = heap_realloc(read_buf, read_buf_size + 1024);
550             if(!new_read_buf) {
551                 status = E_OUTOFMEMORY;
552                 break;
553             }
554
555             in_bufs[0].pvBuffer = read_buf = new_read_buf;
556             read_buf_size += 1024;
557         }
558
559         size = recv(connection->socket, read_buf+in_bufs[0].cbBuffer, read_buf_size-in_bufs[0].cbBuffer, 0);
560         if(size < 1) {
561             WARN("recv error\n");
562             res = ERROR_INTERNET_SECURITY_CHANNEL_ERROR;
563             break;
564         }
565
566         TRACE("recv %lu bytes\n", size);
567
568         in_bufs[0].cbBuffer += size;
569         in_bufs[0].pvBuffer = read_buf;
570         status = InitializeSecurityContextW(cred, &ctx, connection->server->name,  isc_req_flags, 0, 0, &in_desc,
571                 0, NULL, &out_desc, &attrs, NULL);
572         TRACE("InitializeSecurityContext ret %08x\n", status);
573
574         if(status == SEC_E_OK) {
575             if(in_bufs[1].BufferType == SECBUFFER_EXTRA)
576                 FIXME("SECBUFFER_EXTRA not supported\n");
577
578             status = QueryContextAttributesW(&ctx, SECPKG_ATTR_STREAM_SIZES, &connection->ssl_sizes);
579             if(status != SEC_E_OK) {
580                 WARN("Could not get sizes\n");
581                 break;
582             }
583
584             status = QueryContextAttributesW(&ctx, SECPKG_ATTR_REMOTE_CERT_CONTEXT, (void*)&cert);
585             if(status == SEC_E_OK) {
586                 res = netconn_verify_cert(connection, cert, cert->hCertStore);
587                 CertFreeCertificateContext(cert);
588                 if(res != ERROR_SUCCESS) {
589                     WARN("cert verify failed: %u\n", res);
590                     break;
591                 }
592             }else {
593                 WARN("Could not get cert\n");
594                 break;
595             }
596
597             connection->ssl_buf = heap_alloc(connection->ssl_sizes.cbHeader + connection->ssl_sizes.cbMaximumMessage
598                     + connection->ssl_sizes.cbTrailer);
599             if(!connection->ssl_buf) {
600                 res = GetLastError();
601                 break;
602             }
603         }
604     }
605
606
607     if(status != SEC_E_OK || res != ERROR_SUCCESS) {
608         WARN("Failed to initialize security context failed: %08x\n", status);
609         heap_free(connection->ssl_buf);
610         connection->ssl_buf = NULL;
611         DeleteSecurityContext(&ctx);
612         return res ? res : ERROR_INTERNET_SECURITY_CHANNEL_ERROR;
613     }
614
615
616     TRACE("established SSL connection\n");
617     connection->ssl_ctx = ctx;
618
619     connection->secure = TRUE;
620     connection->security_flags |= SECURITY_FLAG_SECURE;
621
622     bits = NETCON_GetCipherStrength(connection);
623     if (bits >= 128)
624         connection->security_flags |= SECURITY_FLAG_STRENGTH_STRONG;
625     else if (bits >= 56)
626         connection->security_flags |= SECURITY_FLAG_STRENGTH_MEDIUM;
627     else
628         connection->security_flags |= SECURITY_FLAG_STRENGTH_WEAK;
629
630     if(connection->mask_errors)
631         connection->server->security_flags = connection->security_flags;
632     return ERROR_SUCCESS;
633 }
634
635 /******************************************************************************
636  * NETCON_secure_connect
637  * Initiates a secure connection over an existing plaintext connection.
638  */
639 DWORD NETCON_secure_connect(netconn_t *connection, server_t *server)
640 {
641     DWORD res;
642
643     /* can't connect if we are already connected */
644     if(connection->secure) {
645         ERR("already connected\n");
646         return ERROR_INTERNET_CANNOT_CONNECT;
647     }
648
649     if(server != connection->server) {
650         server_release(connection->server);
651         server_addref(server);
652         connection->server = server;
653     }
654
655     /* connect with given TLS options */
656     res = netcon_secure_connect_setup(connection, FALSE);
657     if (res == ERROR_SUCCESS)
658         return res;
659
660     /* FIXME: when got version alert and FIN from server */
661     /* fallback to connect without TLSv1.1/TLSv1.2        */
662     if (res == ERROR_INTERNET_SECURITY_CHANNEL_ERROR && have_compat_cred_handle)
663     {
664         closesocket(connection->socket);
665         res = create_netconn_socket(connection->server, connection, 500);
666         if (res != ERROR_SUCCESS)
667             return res;
668         res = netcon_secure_connect_setup(connection, TRUE);
669     }
670     return res;
671 }
672
673 static BOOL send_ssl_chunk(netconn_t *conn, const void *msg, size_t size)
674 {
675     SecBuffer bufs[4] = {
676         {conn->ssl_sizes.cbHeader, SECBUFFER_STREAM_HEADER, conn->ssl_buf},
677         {size,  SECBUFFER_DATA, conn->ssl_buf+conn->ssl_sizes.cbHeader},
678         {conn->ssl_sizes.cbTrailer, SECBUFFER_STREAM_TRAILER, conn->ssl_buf+conn->ssl_sizes.cbHeader+size},
679         {0, SECBUFFER_EMPTY, NULL}
680     };
681     SecBufferDesc buf_desc = {SECBUFFER_VERSION, sizeof(bufs)/sizeof(*bufs), bufs};
682     SECURITY_STATUS res;
683
684     memcpy(bufs[1].pvBuffer, msg, size);
685     res = EncryptMessage(&conn->ssl_ctx, 0, &buf_desc, 0);
686     if(res != SEC_E_OK) {
687         WARN("EncryptMessage failed\n");
688         return FALSE;
689     }
690
691     if(send(conn->socket, conn->ssl_buf, bufs[0].cbBuffer+bufs[1].cbBuffer+bufs[2].cbBuffer, 0) < 1) {
692         WARN("send failed\n");
693         return FALSE;
694     }
695
696     return TRUE;
697 }
698
699 /******************************************************************************
700  * NETCON_send
701  * Basically calls 'send()' unless we should use SSL
702  * number of chars send is put in *sent
703  */
704 DWORD NETCON_send(netconn_t *connection, const void *msg, size_t len, int flags,
705                 int *sent /* out */)
706 {
707     if(!connection->secure)
708     {
709         *sent = send(connection->socket, msg, len, flags);
710         if (*sent == -1)
711             return sock_get_error(errno);
712         return ERROR_SUCCESS;
713     }
714     else
715     {
716         const BYTE *ptr = msg;
717         size_t chunk_size;
718
719         *sent = 0;
720
721         while(len) {
722             chunk_size = min(len, connection->ssl_sizes.cbMaximumMessage);
723             if(!send_ssl_chunk(connection, ptr, chunk_size))
724                 return ERROR_INTERNET_SECURITY_CHANNEL_ERROR;
725
726             *sent += chunk_size;
727             ptr += chunk_size;
728             len -= chunk_size;
729         }
730
731         return ERROR_SUCCESS;
732     }
733 }
734
/*
 * Read and decrypt one TLS record from conn.
 *
 * Ciphertext comes from conn->extra_buf when an earlier call left bytes
 * there, otherwise from recv() on the socket. Up to buf_size bytes of the
 * decrypted data are copied to buf and their count stored in *ret_size.
 * Decrypted data that does not fit is kept in conn->peek_msg/peek_len, and
 * ciphertext following the record (SECBUFFER_EXTRA) is kept in
 * conn->extra_buf/extra_len for the next call.
 *
 * *eof is set to TRUE when recv() reports end of stream or DecryptMessage
 * returns SEC_I_CONTEXT_EXPIRED; in the recv() case *ret_size is not written,
 * so callers must test *eof first. Returns FALSE on socket, decryption or
 * allocation failure.
 */
static BOOL read_ssl_chunk(netconn_t *conn, void *buf, SIZE_T buf_size, SIZE_T *ret_size, BOOL *eof)
{
    const SIZE_T ssl_buf_size = conn->ssl_sizes.cbHeader+conn->ssl_sizes.cbMaximumMessage+conn->ssl_sizes.cbTrailer;
    SecBuffer bufs[4];
    SecBufferDesc buf_desc = {SECBUFFER_VERSION, sizeof(bufs)/sizeof(*bufs), bufs};
    SSIZE_T size, buf_len;
    int i;
    SECURITY_STATUS res;

    assert(conn->extra_len < ssl_buf_size);

    if(conn->extra_len) {
        /* Start from the ciphertext left over by the previous record. */
        memcpy(conn->ssl_buf, conn->extra_buf, conn->extra_len);
        buf_len = conn->extra_len;
        conn->extra_len = 0;
        heap_free(conn->extra_buf);
        conn->extra_buf = NULL;
    }else {
        /* extra_len is always 0 in this branch, so this reads into the
         * start of ssl_buf. */
        buf_len = recv(conn->socket, conn->ssl_buf+conn->extra_len, ssl_buf_size-conn->extra_len, 0);
        if(buf_len < 0) {
            WARN("recv failed\n");
            return FALSE;
        }

        if(!buf_len) {
            *eof = TRUE;
            return TRUE;
        }
    }

    *ret_size = 0;
    *eof = FALSE;

    /* Decrypt in place, reading more from the socket while the record is
     * incomplete. */
    do {
        memset(bufs, 0, sizeof(bufs));
        bufs[0].BufferType = SECBUFFER_DATA;
        bufs[0].cbBuffer = buf_len;
        bufs[0].pvBuffer = conn->ssl_buf;

        res = DecryptMessage(&conn->ssl_ctx, &buf_desc, 0, NULL);
        switch(res) {
        case SEC_E_OK:
            break;
        case SEC_I_CONTEXT_EXPIRED:
            TRACE("context expired\n");
            *eof = TRUE;
            return TRUE;
        case SEC_E_INCOMPLETE_MESSAGE:
            assert(buf_len < ssl_buf_size);

            /* Append more ciphertext after what is already buffered. */
            size = recv(conn->socket, conn->ssl_buf+buf_len, ssl_buf_size-buf_len, 0);
            if(size < 1)
                return FALSE;

            buf_len += size;
            continue;
        default:
            WARN("failed: %08x\n", res);
            return FALSE;
        }
    } while(res != SEC_E_OK);

    /* Copy decrypted data to the caller; stash anything that does not fit. */
    for(i=0; i < sizeof(bufs)/sizeof(*bufs); i++) {
        if(bufs[i].BufferType == SECBUFFER_DATA) {
            size = min(buf_size, bufs[i].cbBuffer);
            memcpy(buf, bufs[i].pvBuffer, size);
            if(size < bufs[i].cbBuffer) {
                assert(!conn->peek_len);
                conn->peek_msg_mem = conn->peek_msg = heap_alloc(bufs[i].cbBuffer - size);
                if(!conn->peek_msg)
                    return FALSE;
                conn->peek_len = bufs[i].cbBuffer-size;
                memcpy(conn->peek_msg, (char*)bufs[i].pvBuffer+size, conn->peek_len);
            }

            *ret_size = size;
        }
    }

    /* Save ciphertext belonging to the next record(s) for the next call. */
    for(i=0; i < sizeof(bufs)/sizeof(*bufs); i++) {
        if(bufs[i].BufferType == SECBUFFER_EXTRA) {
            conn->extra_buf = heap_alloc(bufs[i].cbBuffer);
            if(!conn->extra_buf)
                return FALSE;

            conn->extra_len = bufs[i].cbBuffer;
            memcpy(conn->extra_buf, bufs[i].pvBuffer, conn->extra_len);
        }
    }

    return TRUE;
}
827
828 /******************************************************************************
829  * NETCON_recv
830  * Basically calls 'recv()' unless we should use SSL
831  * number of chars received is put in *recvd
832  */
833 DWORD NETCON_recv(netconn_t *connection, void *buf, size_t len, int flags, int *recvd)
834 {
835     *recvd = 0;
836     if (!len)
837         return ERROR_SUCCESS;
838
839     if (!connection->secure)
840     {
841         *recvd = recv(connection->socket, buf, len, flags);
842         return *recvd == -1 ? sock_get_error(errno) :  ERROR_SUCCESS;
843     }
844     else
845     {
846         SIZE_T size = 0, cread;
847         BOOL res, eof;
848
849         if(connection->peek_msg) {
850             size = min(len, connection->peek_len);
851             memcpy(buf, connection->peek_msg, size);
852             connection->peek_len -= size;
853             connection->peek_msg += size;
854
855             if(!connection->peek_len) {
856                 heap_free(connection->peek_msg_mem);
857                 connection->peek_msg_mem = connection->peek_msg = NULL;
858             }
859             /* check if we have enough data from the peek buffer */
860             if(!(flags & MSG_WAITALL) || size == len) {
861                 *recvd = size;
862                 return ERROR_SUCCESS;
863             }
864         }
865
866         do {
867             res = read_ssl_chunk(connection, (BYTE*)buf+size, len-size, &cread, &eof);
868             if(!res) {
869                 WARN("read_ssl_chunk failed\n");
870                 if(!size)
871                     return ERROR_INTERNET_CONNECTION_ABORTED;
872                 break;
873             }
874
875             if(eof) {
876                 TRACE("EOF\n");
877                 break;
878             }
879
880             size += cread;
881         }while(!size || ((flags & MSG_WAITALL) && size < len));
882
883         TRACE("received %ld bytes\n", size);
884         *recvd = size;
885         return ERROR_SUCCESS;
886     }
887 }
888
889 /******************************************************************************
890  * NETCON_query_data_available
891  * Returns the number of bytes of peeked data plus the number of bytes of
892  * queued, but unread data.
893  */
894 BOOL NETCON_query_data_available(netconn_t *connection, DWORD *available)
895 {
896     *available = 0;
897
898     if(!connection->secure)
899     {
900 #ifdef FIONREAD
901         ULONG unread;
902         int retval = ioctlsocket(connection->socket, FIONREAD, &unread);
903         if (!retval)
904         {
905             TRACE("%d bytes of queued, but unread data\n", unread);
906             *available += unread;
907         }
908 #endif
909     }
910     else
911     {
912         *available = connection->peek_len;
913     }
914     return TRUE;
915 }
916
917 BOOL NETCON_is_alive(netconn_t *netconn)
918 {
919 #ifdef MSG_DONTWAIT
920     ssize_t len;
921     BYTE b;
922
923     len = recv(netconn->socket, &b, 1, MSG_PEEK|MSG_DONTWAIT);
924     return len == 1 || (len == -1 && errno == EWOULDBLOCK);
925 #elif defined(__MINGW32__) || defined(_MSC_VER)
926     ULONG mode;
927     int len;
928     char b;
929
930     mode = 1;
931     if(!ioctlsocket(netconn->socket, FIONBIO, &mode))
932         return FALSE;
933
934     len = recv(netconn->socket, &b, 1, MSG_PEEK);
935
936     mode = 0;
937     if(!ioctlsocket(netconn->socket, FIONBIO, &mode))
938         return FALSE;
939
940     return len == 1 || (len == -1 && errno == WSAEWOULDBLOCK);
941 #else
942     FIXME("not supported on this platform\n");
943     return TRUE;
944 #endif
945 }
946
947 LPCVOID NETCON_GetCert(netconn_t *connection)
948 {
949     const CERT_CONTEXT *ret;
950     SECURITY_STATUS res;
951
952     if (!connection->secure)
953         return NULL;
954
955     res = QueryContextAttributesW(&connection->ssl_ctx, SECPKG_ATTR_REMOTE_CERT_CONTEXT, (void*)&ret);
956     return res == SEC_E_OK ? ret : NULL;
957 }
958
959 int NETCON_GetCipherStrength(netconn_t *connection)
960 {
961     SecPkgContext_ConnectionInfo conn_info;
962     SECURITY_STATUS res;
963
964     if (!connection->secure)
965         return 0;
966
967     res = QueryContextAttributesW(&connection->ssl_ctx, SECPKG_ATTR_CONNECTION_INFO, (void*)&conn_info);
968     if(res != SEC_E_OK)
969         WARN("QueryContextAttributesW failed: %08x\n", res);
970     return res == SEC_E_OK ? conn_info.dwCipherStrength : 0;
971 }
972
973 DWORD NETCON_set_timeout(netconn_t *connection, BOOL send, DWORD value)
974 {
975     int result;
976     struct timeval tv;
977
978     /* value is in milliseconds, convert to struct timeval */
979     if (value == INFINITE)
980     {
981         tv.tv_sec = 0;
982         tv.tv_usec = 0;
983     }
984     else
985     {
986         tv.tv_sec = value / 1000;
987         tv.tv_usec = (value % 1000) * 1000;
988     }
989     result = setsockopt(connection->socket, SOL_SOCKET,
990                         send ? SO_SNDTIMEO : SO_RCVTIMEO, (void*)&tv,
991                         sizeof(tv));
992     if (result == -1)
993     {
994         WARN("setsockopt failed (%s)\n", strerror(errno));
995         return sock_get_error(errno);
996     }
997     return ERROR_SUCCESS;
998 }