user32: Reimplement UserYield using PeekMessageW.
[wine] / dlls / rpcrt4 / rpc_assoc.c
1 /*
2  * Associations
3  *
4  * Copyright 2007 Robert Shearman (for CodeWeavers)
5  *
6  * This library is free software; you can redistribute it and/or
7  * modify it under the terms of the GNU Lesser General Public
8  * License as published by the Free Software Foundation; either
9  * version 2.1 of the License, or (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14  * Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public
17  * License along with this library; if not, write to the Free Software
18  * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
19  *
20  */
21
22 #include <stdarg.h>
23 #include <assert.h>
24
25 #include "rpc.h"
26 #include "rpcndr.h"
27 #include "winternl.h"
28
29 #include "wine/unicode.h"
30 #include "wine/debug.h"
31
32 #include "rpc_binding.h"
33 #include "rpc_assoc.h"
34 #include "rpc_message.h"
35
36 WINE_DEFAULT_DEBUG_CHANNEL(rpc);
37
38 static CRITICAL_SECTION assoc_list_cs;
39 static CRITICAL_SECTION_DEBUG assoc_list_cs_debug =
40 {
41     0, 0, &assoc_list_cs,
42     { &assoc_list_cs_debug.ProcessLocksList, &assoc_list_cs_debug.ProcessLocksList },
43       0, 0, { (DWORD_PTR)(__FILE__ ": assoc_list_cs") }
44 };
45 static CRITICAL_SECTION assoc_list_cs = { &assoc_list_cs_debug, -1, 0, 0, 0, 0 };
46
47 static struct list client_assoc_list = LIST_INIT(client_assoc_list);
48 static struct list server_assoc_list = LIST_INIT(server_assoc_list);
49
50 static LONG last_assoc_group_id;
51
52 typedef struct _RpcContextHandle
53 {
54     struct list entry;
55     void *user_context;
56     NDR_RUNDOWN rundown_routine;
57     void *ctx_guard;
58     UUID uuid;
59     RTL_RWLOCK rw_lock;
60     unsigned int refs;
61 } RpcContextHandle;
62
63 static void RpcContextHandle_Destroy(RpcContextHandle *context_handle);
64
65 static RPC_STATUS RpcAssoc_Alloc(LPCSTR Protseq, LPCSTR NetworkAddr,
66                                  LPCSTR Endpoint, LPCWSTR NetworkOptions,
67                                  RpcAssoc **assoc_out)
68 {
69     RpcAssoc *assoc;
70     assoc = HeapAlloc(GetProcessHeap(), 0, sizeof(*assoc));
71     if (!assoc)
72         return RPC_S_OUT_OF_RESOURCES;
73     assoc->refs = 1;
74     list_init(&assoc->free_connection_pool);
75     list_init(&assoc->context_handle_list);
76     InitializeCriticalSection(&assoc->cs);
77     assoc->Protseq = RPCRT4_strdupA(Protseq);
78     assoc->NetworkAddr = RPCRT4_strdupA(NetworkAddr);
79     assoc->Endpoint = RPCRT4_strdupA(Endpoint);
80     assoc->NetworkOptions = NetworkOptions ? RPCRT4_strdupW(NetworkOptions) : NULL;
81     assoc->assoc_group_id = 0;
82     list_init(&assoc->entry);
83     *assoc_out = assoc;
84     return RPC_S_OK;
85 }
86
87 static BOOL compare_networkoptions(LPCWSTR opts1, LPCWSTR opts2)
88 {
89     if ((opts1 == NULL) && (opts2 == NULL))
90         return TRUE;
91     if ((opts1 == NULL) || (opts2 == NULL))
92         return FALSE;
93     return !strcmpW(opts1, opts2);
94 }
95
96 RPC_STATUS RPCRT4_GetAssociation(LPCSTR Protseq, LPCSTR NetworkAddr,
97                                  LPCSTR Endpoint, LPCWSTR NetworkOptions,
98                                  RpcAssoc **assoc_out)
99 {
100     RpcAssoc *assoc;
101     RPC_STATUS status;
102
103     EnterCriticalSection(&assoc_list_cs);
104     LIST_FOR_EACH_ENTRY(assoc, &client_assoc_list, RpcAssoc, entry)
105     {
106         if (!strcmp(Protseq, assoc->Protseq) &&
107             !strcmp(NetworkAddr, assoc->NetworkAddr) &&
108             !strcmp(Endpoint, assoc->Endpoint) &&
109             compare_networkoptions(NetworkOptions, assoc->NetworkOptions))
110         {
111             assoc->refs++;
112             *assoc_out = assoc;
113             LeaveCriticalSection(&assoc_list_cs);
114             TRACE("using existing assoc %p\n", assoc);
115             return RPC_S_OK;
116         }
117     }
118
119     status = RpcAssoc_Alloc(Protseq, NetworkAddr, Endpoint, NetworkOptions, &assoc);
120     if (status != RPC_S_OK)
121     {
122         LeaveCriticalSection(&assoc_list_cs);
123         return status;
124     }
125     list_add_head(&client_assoc_list, &assoc->entry);
126     *assoc_out = assoc;
127
128     LeaveCriticalSection(&assoc_list_cs);
129
130     TRACE("new assoc %p\n", assoc);
131
132     return RPC_S_OK;
133 }
134
135 RPC_STATUS RpcServerAssoc_GetAssociation(LPCSTR Protseq, LPCSTR NetworkAddr,
136                                          LPCSTR Endpoint, LPCWSTR NetworkOptions,
137                                          ULONG assoc_gid,
138                                          RpcAssoc **assoc_out)
139 {
140     RpcAssoc *assoc;
141     RPC_STATUS status;
142
143     EnterCriticalSection(&assoc_list_cs);
144     if (assoc_gid)
145     {
146         LIST_FOR_EACH_ENTRY(assoc, &server_assoc_list, RpcAssoc, entry)
147         {
148             /* FIXME: NetworkAddr shouldn't be NULL */
149             if (assoc->assoc_group_id == assoc_gid &&
150                 !strcmp(Protseq, assoc->Protseq) &&
151                 (!NetworkAddr || !assoc->NetworkAddr || !strcmp(NetworkAddr, assoc->NetworkAddr)) &&
152                 !strcmp(Endpoint, assoc->Endpoint) &&
153                 ((!assoc->NetworkOptions == !NetworkOptions) &&
154                  (!NetworkOptions || !strcmpW(NetworkOptions, assoc->NetworkOptions))))
155             {
156                 assoc->refs++;
157                 *assoc_out = assoc;
158                 LeaveCriticalSection(&assoc_list_cs);
159                 TRACE("using existing assoc %p\n", assoc);
160                 return RPC_S_OK;
161             }
162         }
163         *assoc_out = NULL;
164         LeaveCriticalSection(&assoc_list_cs);
165         return RPC_S_NO_CONTEXT_AVAILABLE;
166     }
167
168     status = RpcAssoc_Alloc(Protseq, NetworkAddr, Endpoint, NetworkOptions, &assoc);
169     if (status != RPC_S_OK)
170     {
171         LeaveCriticalSection(&assoc_list_cs);
172         return status;
173     }
174     assoc->assoc_group_id = InterlockedIncrement(&last_assoc_group_id);
175     list_add_head(&server_assoc_list, &assoc->entry);
176     *assoc_out = assoc;
177
178     LeaveCriticalSection(&assoc_list_cs);
179
180     TRACE("new assoc %p\n", assoc);
181
182     return RPC_S_OK;
183 }
184
185 ULONG RpcAssoc_Release(RpcAssoc *assoc)
186 {
187     ULONG refs;
188
189     EnterCriticalSection(&assoc_list_cs);
190     refs = --assoc->refs;
191     if (!refs)
192         list_remove(&assoc->entry);
193     LeaveCriticalSection(&assoc_list_cs);
194
195     if (!refs)
196     {
197         RpcConnection *Connection, *cursor2;
198         RpcContextHandle *context_handle, *context_handle_cursor;
199
200         TRACE("destroying assoc %p\n", assoc);
201
202         LIST_FOR_EACH_ENTRY_SAFE(Connection, cursor2, &assoc->free_connection_pool, RpcConnection, conn_pool_entry)
203         {
204             list_remove(&Connection->conn_pool_entry);
205             RPCRT4_DestroyConnection(Connection);
206         }
207
208         LIST_FOR_EACH_ENTRY_SAFE(context_handle, context_handle_cursor, &assoc->context_handle_list, RpcContextHandle, entry)
209             RpcContextHandle_Destroy(context_handle);
210
211         HeapFree(GetProcessHeap(), 0, assoc->NetworkOptions);
212         HeapFree(GetProcessHeap(), 0, assoc->Endpoint);
213         HeapFree(GetProcessHeap(), 0, assoc->NetworkAddr);
214         HeapFree(GetProcessHeap(), 0, assoc->Protseq);
215
216         DeleteCriticalSection(&assoc->cs);
217
218         HeapFree(GetProcessHeap(), 0, assoc);
219     }
220
221     return refs;
222 }
223
224 #define ROUND_UP(value, alignment) (((value) + ((alignment) - 1)) & ~((alignment)-1))
225
226 static RPC_STATUS RpcAssoc_BindConnection(const RpcAssoc *assoc, RpcConnection *conn,
227                                           const RPC_SYNTAX_IDENTIFIER *InterfaceId,
228                                           const RPC_SYNTAX_IDENTIFIER *TransferSyntax)
229 {
230     RpcPktHdr *hdr;
231     RpcPktHdr *response_hdr;
232     RPC_MESSAGE msg;
233     RPC_STATUS status;
234     unsigned char *auth_data = NULL;
235     ULONG auth_length;
236
237     TRACE("sending bind request to server\n");
238
239     hdr = RPCRT4_BuildBindHeader(NDR_LOCAL_DATA_REPRESENTATION,
240                                  RPC_MAX_PACKET_SIZE, RPC_MAX_PACKET_SIZE,
241                                  assoc->assoc_group_id,
242                                  InterfaceId, TransferSyntax);
243
244     status = RPCRT4_Send(conn, hdr, NULL, 0);
245     RPCRT4_FreeHeader(hdr);
246     if (status != RPC_S_OK)
247         return status;
248
249     status = RPCRT4_ReceiveWithAuth(conn, &response_hdr, &msg, &auth_data, &auth_length);
250     if (status != RPC_S_OK)
251     {
252         ERR("receive failed with error %d\n", status);
253         return status;
254     }
255
256     switch (response_hdr->common.ptype)
257     {
258     case PKT_BIND_ACK:
259     {
260         RpcAddressString *server_address = msg.Buffer;
261         if ((msg.BufferLength >= FIELD_OFFSET(RpcAddressString, string[0])) ||
262             (msg.BufferLength >= ROUND_UP(FIELD_OFFSET(RpcAddressString, string[server_address->length]), 4)))
263         {
264             unsigned short remaining = msg.BufferLength -
265             ROUND_UP(FIELD_OFFSET(RpcAddressString, string[server_address->length]), 4);
266             RpcResultList *results = (RpcResultList*)((ULONG_PTR)server_address +
267                 ROUND_UP(FIELD_OFFSET(RpcAddressString, string[server_address->length]), 4));
268             if ((results->num_results == 1) &&
269                 (remaining >= FIELD_OFFSET(RpcResultList, results[results->num_results])))
270             {
271                 switch (results->results[0].result)
272                 {
273                 case RESULT_ACCEPT:
274                     /* respond to authorization request */
275                     if (auth_length > sizeof(RpcAuthVerifier))
276                         status = RPCRT4_AuthorizeConnection(conn,
277                                                             auth_data + sizeof(RpcAuthVerifier),
278                                                             auth_length);
279                     if (status == RPC_S_OK)
280                     {
281                         conn->assoc_group_id = response_hdr->bind_ack.assoc_gid;
282                         conn->MaxTransmissionSize = response_hdr->bind_ack.max_tsize;
283                         conn->ActiveInterface = *InterfaceId;
284                     }
285                     break;
286                 case RESULT_PROVIDER_REJECTION:
287                     switch (results->results[0].reason)
288                     {
289                     case REASON_ABSTRACT_SYNTAX_NOT_SUPPORTED:
290                         ERR("syntax %s, %d.%d not supported\n",
291                             debugstr_guid(&InterfaceId->SyntaxGUID),
292                             InterfaceId->SyntaxVersion.MajorVersion,
293                             InterfaceId->SyntaxVersion.MinorVersion);
294                         status = RPC_S_UNKNOWN_IF;
295                         break;
296                     case REASON_TRANSFER_SYNTAXES_NOT_SUPPORTED:
297                         ERR("transfer syntax not supported\n");
298                         status = RPC_S_SERVER_UNAVAILABLE;
299                         break;
300                     case REASON_NONE:
301                     default:
302                         status = RPC_S_CALL_FAILED_DNE;
303                     }
304                     break;
305                 case RESULT_USER_REJECTION:
306                 default:
307                     ERR("rejection result %d\n", results->results[0].result);
308                     status = RPC_S_CALL_FAILED_DNE;
309                 }
310             }
311             else
312             {
313                 ERR("incorrect results size\n");
314                 status = RPC_S_CALL_FAILED_DNE;
315             }
316         }
317         else
318         {
319             ERR("bind ack packet too small (%d)\n", msg.BufferLength);
320             status = RPC_S_PROTOCOL_ERROR;
321         }
322         break;
323     }
324     case PKT_BIND_NACK:
325         switch (response_hdr->bind_nack.reject_reason)
326         {
327         case REJECT_LOCAL_LIMIT_EXCEEDED:
328         case REJECT_TEMPORARY_CONGESTION:
329             ERR("server too busy\n");
330             status = RPC_S_SERVER_TOO_BUSY;
331             break;
332         case REJECT_PROTOCOL_VERSION_NOT_SUPPORTED:
333             ERR("protocol version not supported\n");
334             status = RPC_S_PROTOCOL_ERROR;
335             break;
336         case REJECT_UNKNOWN_AUTHN_SERVICE:
337             ERR("unknown authentication service\n");
338             status = RPC_S_UNKNOWN_AUTHN_SERVICE;
339             break;
340         case REJECT_INVALID_CHECKSUM:
341             ERR("invalid checksum\n");
342             status = ERROR_ACCESS_DENIED;
343             break;
344         default:
345             ERR("rejected bind for reason %d\n", response_hdr->bind_nack.reject_reason);
346             status = RPC_S_CALL_FAILED_DNE;
347         }
348         break;
349     default:
350         ERR("wrong packet type received %d\n", response_hdr->common.ptype);
351         status = RPC_S_PROTOCOL_ERROR;
352         break;
353     }
354
355     I_RpcFree(msg.Buffer);
356     RPCRT4_FreeHeader(response_hdr);
357     HeapFree(GetProcessHeap(), 0, auth_data);
358     return status;
359 }
360
361 static RpcConnection *RpcAssoc_GetIdleConnection(RpcAssoc *assoc,
362                                                  const RPC_SYNTAX_IDENTIFIER *InterfaceId,
363                                                  const RPC_SYNTAX_IDENTIFIER *TransferSyntax, const RpcAuthInfo *AuthInfo,
364                                                  const RpcQualityOfService *QOS)
365 {
366     RpcConnection *Connection;
367     EnterCriticalSection(&assoc->cs);
368     /* try to find a compatible connection from the connection pool */
369     LIST_FOR_EACH_ENTRY(Connection, &assoc->free_connection_pool, RpcConnection, conn_pool_entry)
370     {
371         if (!memcmp(&Connection->ActiveInterface, InterfaceId,
372                     sizeof(RPC_SYNTAX_IDENTIFIER)) &&
373             RpcAuthInfo_IsEqual(Connection->AuthInfo, AuthInfo) &&
374             RpcQualityOfService_IsEqual(Connection->QOS, QOS))
375         {
376             list_remove(&Connection->conn_pool_entry);
377             LeaveCriticalSection(&assoc->cs);
378             TRACE("got connection from pool %p\n", Connection);
379             return Connection;
380         }
381     }
382
383     LeaveCriticalSection(&assoc->cs);
384     return NULL;
385 }
386
387 RPC_STATUS RpcAssoc_GetClientConnection(RpcAssoc *assoc,
388                                         const RPC_SYNTAX_IDENTIFIER *InterfaceId,
389                                         const RPC_SYNTAX_IDENTIFIER *TransferSyntax, RpcAuthInfo *AuthInfo,
390                                         RpcQualityOfService *QOS, RpcConnection **Connection)
391 {
392     RpcConnection *NewConnection;
393     RPC_STATUS status;
394
395     *Connection = RpcAssoc_GetIdleConnection(assoc, InterfaceId, TransferSyntax, AuthInfo, QOS);
396     if (*Connection)
397         return RPC_S_OK;
398
399     /* create a new connection */
400     status = RPCRT4_CreateConnection(&NewConnection, FALSE /* is this a server connection? */,
401         assoc->Protseq, assoc->NetworkAddr,
402         assoc->Endpoint, assoc->NetworkOptions,
403         AuthInfo, QOS);
404     if (status != RPC_S_OK)
405         return status;
406
407     NewConnection->assoc = assoc;
408     status = RPCRT4_OpenClientConnection(NewConnection);
409     if (status != RPC_S_OK)
410     {
411         RPCRT4_DestroyConnection(NewConnection);
412         return status;
413     }
414
415     status = RpcAssoc_BindConnection(assoc, NewConnection, InterfaceId, TransferSyntax);
416     if (status != RPC_S_OK)
417     {
418         RPCRT4_DestroyConnection(NewConnection);
419         return status;
420     }
421
422     *Connection = NewConnection;
423
424     return RPC_S_OK;
425 }
426
427 void RpcAssoc_ReleaseIdleConnection(RpcAssoc *assoc, RpcConnection *Connection)
428 {
429     assert(!Connection->server);
430     Connection->async_state = NULL;
431     EnterCriticalSection(&assoc->cs);
432     if (!assoc->assoc_group_id) assoc->assoc_group_id = Connection->assoc_group_id;
433     list_add_head(&assoc->free_connection_pool, &Connection->conn_pool_entry);
434     LeaveCriticalSection(&assoc->cs);
435 }
436
437 RPC_STATUS RpcServerAssoc_AllocateContextHandle(RpcAssoc *assoc, void *CtxGuard,
438                                                 NDR_SCONTEXT *SContext)
439 {
440     RpcContextHandle *context_handle;
441
442     context_handle = HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, sizeof(*context_handle));
443     if (!context_handle)
444         return ERROR_OUTOFMEMORY;
445
446     context_handle->ctx_guard = CtxGuard;
447     RtlInitializeResource(&context_handle->rw_lock);
448     context_handle->refs = 1;
449
450     /* lock here to mirror unmarshall, so we don't need to special-case the
451      * freeing of a non-marshalled context handle */
452     RtlAcquireResourceExclusive(&context_handle->rw_lock, TRUE);
453
454     EnterCriticalSection(&assoc->cs);
455     list_add_tail(&assoc->context_handle_list, &context_handle->entry);
456     LeaveCriticalSection(&assoc->cs);
457
458     *SContext = (NDR_SCONTEXT)context_handle;
459     return RPC_S_OK;
460 }
461
462 BOOL RpcContextHandle_IsGuardCorrect(NDR_SCONTEXT SContext, void *CtxGuard)
463 {
464     RpcContextHandle *context_handle = (RpcContextHandle *)SContext;
465     return context_handle->ctx_guard == CtxGuard;
466 }
467
468 RPC_STATUS RpcServerAssoc_FindContextHandle(RpcAssoc *assoc, const UUID *uuid,
469                                             void *CtxGuard, ULONG Flags, NDR_SCONTEXT *SContext)
470 {
471     RpcContextHandle *context_handle;
472
473     EnterCriticalSection(&assoc->cs);
474     LIST_FOR_EACH_ENTRY(context_handle, &assoc->context_handle_list, RpcContextHandle, entry)
475     {
476         if (RpcContextHandle_IsGuardCorrect((NDR_SCONTEXT)context_handle, CtxGuard) &&
477             !memcmp(&context_handle->uuid, uuid, sizeof(*uuid)))
478         {
479             *SContext = (NDR_SCONTEXT)context_handle;
480             if (context_handle->refs++)
481             {
482                 LeaveCriticalSection(&assoc->cs);
483                 TRACE("found %p\n", context_handle);
484                 RtlAcquireResourceExclusive(&context_handle->rw_lock, TRUE);
485                 return RPC_S_OK;
486             }
487         }
488     }
489     LeaveCriticalSection(&assoc->cs);
490
491     ERR("no context handle found for uuid %s, guard %p\n",
492         debugstr_guid(uuid), CtxGuard);
493     return ERROR_INVALID_HANDLE;
494 }
495
496 RPC_STATUS RpcServerAssoc_UpdateContextHandle(RpcAssoc *assoc,
497                                               NDR_SCONTEXT SContext,
498                                               void *CtxGuard,
499                                               NDR_RUNDOWN rundown_routine)
500 {
501     RpcContextHandle *context_handle = (RpcContextHandle *)SContext;
502     RPC_STATUS status;
503
504     if (!RpcContextHandle_IsGuardCorrect((NDR_SCONTEXT)context_handle, CtxGuard))
505         return ERROR_INVALID_HANDLE;
506
507     EnterCriticalSection(&assoc->cs);
508     if (UuidIsNil(&context_handle->uuid, &status))
509     {
510         /* add a ref for the data being valid */
511         context_handle->refs++;
512         UuidCreate(&context_handle->uuid);
513         context_handle->rundown_routine = rundown_routine;
514         TRACE("allocated uuid %s for context handle %p\n",
515               debugstr_guid(&context_handle->uuid), context_handle);
516     }
517     LeaveCriticalSection(&assoc->cs);
518
519     return RPC_S_OK;
520 }
521
522 void RpcContextHandle_GetUuid(NDR_SCONTEXT SContext, UUID *uuid)
523 {
524     RpcContextHandle *context_handle = (RpcContextHandle *)SContext;
525     *uuid = context_handle->uuid;
526 }
527
528 static void RpcContextHandle_Destroy(RpcContextHandle *context_handle)
529 {
530     TRACE("freeing %p\n", context_handle);
531
532     if (context_handle->user_context && context_handle->rundown_routine)
533     {
534         TRACE("calling rundown routine %p with user context %p\n",
535               context_handle->rundown_routine, context_handle->user_context);
536         context_handle->rundown_routine(context_handle->user_context);
537     }
538
539     RtlDeleteResource(&context_handle->rw_lock);
540
541     HeapFree(GetProcessHeap(), 0, context_handle);
542 }
543
544 unsigned int RpcServerAssoc_ReleaseContextHandle(RpcAssoc *assoc, NDR_SCONTEXT SContext, BOOL release_lock)
545 {
546     RpcContextHandle *context_handle = (RpcContextHandle *)SContext;
547     unsigned int refs;
548
549     if (release_lock)
550         RtlReleaseResource(&context_handle->rw_lock);
551
552     EnterCriticalSection(&assoc->cs);
553     refs = --context_handle->refs;
554     if (!refs)
555         list_remove(&context_handle->entry);
556     LeaveCriticalSection(&assoc->cs);
557
558     if (!refs)
559         RpcContextHandle_Destroy(context_handle);
560
561     return refs;
562 }