msxml3: Added IDispatchEx support for IXMLDOMImplementation.
[wine] / dlls / rpcrt4 / rpc_assoc.c
1 /*
2  * Associations
3  *
4  * Copyright 2007 Robert Shearman (for CodeWeavers)
5  *
6  * This library is free software; you can redistribute it and/or
7  * modify it under the terms of the GNU Lesser General Public
8  * License as published by the Free Software Foundation; either
9  * version 2.1 of the License, or (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14  * Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public
17  * License along with this library; if not, write to the Free Software
18  * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
19  *
20  */
21
22 #include <stdarg.h>
23 #include <assert.h>
24
25 #include "rpc.h"
26 #include "rpcndr.h"
27 #include "winternl.h"
28
29 #include "wine/unicode.h"
30 #include "wine/debug.h"
31
32 #include "rpc_binding.h"
33 #include "rpc_assoc.h"
34 #include "rpc_message.h"
35
WINE_DEFAULT_DEBUG_CHANNEL(rpc);

/* Guards client_assoc_list and server_assoc_list, and the refs field of the
 * associations linked on them (refs is only modified while this lock is held
 * in this file). It is initialized statically, together with its debug info,
 * so no runtime initialization call is needed. */
static CRITICAL_SECTION assoc_list_cs;
static CRITICAL_SECTION_DEBUG assoc_list_cs_debug =
{
    0, 0, &assoc_list_cs,
    { &assoc_list_cs_debug.ProcessLocksList, &assoc_list_cs_debug.ProcessLocksList },
      0, 0, { (DWORD_PTR)(__FILE__ ": assoc_list_cs") }
};
static CRITICAL_SECTION assoc_list_cs = { &assoc_list_cs_debug, -1, 0, 0, 0, 0 };

/* Associations created by RPCRT4_GetAssociation (client side) and by
 * RpcServerAssoc_GetAssociation (server side). */
static struct list client_assoc_list = LIST_INIT(client_assoc_list);
static struct list server_assoc_list = LIST_INIT(server_assoc_list);

/* Last server association group id handed out; bumped with
 * InterlockedIncrement, so the first id is 1 and 0 means "no group". */
static LONG last_assoc_group_id;

/* Server-side context handle, owned by an association. */
typedef struct _RpcContextHandle
{
    struct list entry;              /* link in the owning assoc's context_handle_list */
    void *user_context;             /* passed to rundown_routine on destruction (set outside this file) */
    NDR_RUNDOWN rundown_routine;    /* set when the uuid is first allocated */
    void *ctx_guard;                /* lookups must present the same guard */
    UUID uuid;                      /* nil until RpcServerAssoc_UpdateContextHandle assigns one */
    RTL_RWLOCK rw_lock;             /* taken exclusively on allocate/find, released on release */
    unsigned int refs;              /* modified under the owning assoc's cs */
} RpcContextHandle;

static void RpcContextHandle_Destroy(RpcContextHandle *context_handle);
64
65 static RPC_STATUS RpcAssoc_Alloc(LPCSTR Protseq, LPCSTR NetworkAddr,
66                                  LPCSTR Endpoint, LPCWSTR NetworkOptions,
67                                  RpcAssoc **assoc_out)
68 {
69     RpcAssoc *assoc;
70     assoc = HeapAlloc(GetProcessHeap(), 0, sizeof(*assoc));
71     if (!assoc)
72         return RPC_S_OUT_OF_RESOURCES;
73     assoc->refs = 1;
74     list_init(&assoc->free_connection_pool);
75     list_init(&assoc->context_handle_list);
76     InitializeCriticalSection(&assoc->cs);
77     assoc->cs.DebugInfo->Spare[0] = (DWORD_PTR)(__FILE__ ": RpcAssoc.cs");
78     assoc->Protseq = RPCRT4_strdupA(Protseq);
79     assoc->NetworkAddr = RPCRT4_strdupA(NetworkAddr);
80     assoc->Endpoint = RPCRT4_strdupA(Endpoint);
81     assoc->NetworkOptions = NetworkOptions ? RPCRT4_strdupW(NetworkOptions) : NULL;
82     assoc->assoc_group_id = 0;
83     list_init(&assoc->entry);
84     *assoc_out = assoc;
85     return RPC_S_OK;
86 }
87
88 static BOOL compare_networkoptions(LPCWSTR opts1, LPCWSTR opts2)
89 {
90     if ((opts1 == NULL) && (opts2 == NULL))
91         return TRUE;
92     if ((opts1 == NULL) || (opts2 == NULL))
93         return FALSE;
94     return !strcmpW(opts1, opts2);
95 }
96
97 RPC_STATUS RPCRT4_GetAssociation(LPCSTR Protseq, LPCSTR NetworkAddr,
98                                  LPCSTR Endpoint, LPCWSTR NetworkOptions,
99                                  RpcAssoc **assoc_out)
100 {
101     RpcAssoc *assoc;
102     RPC_STATUS status;
103
104     EnterCriticalSection(&assoc_list_cs);
105     LIST_FOR_EACH_ENTRY(assoc, &client_assoc_list, RpcAssoc, entry)
106     {
107         if (!strcmp(Protseq, assoc->Protseq) &&
108             !strcmp(NetworkAddr, assoc->NetworkAddr) &&
109             !strcmp(Endpoint, assoc->Endpoint) &&
110             compare_networkoptions(NetworkOptions, assoc->NetworkOptions))
111         {
112             assoc->refs++;
113             *assoc_out = assoc;
114             LeaveCriticalSection(&assoc_list_cs);
115             TRACE("using existing assoc %p\n", assoc);
116             return RPC_S_OK;
117         }
118     }
119
120     status = RpcAssoc_Alloc(Protseq, NetworkAddr, Endpoint, NetworkOptions, &assoc);
121     if (status != RPC_S_OK)
122     {
123         LeaveCriticalSection(&assoc_list_cs);
124         return status;
125     }
126     list_add_head(&client_assoc_list, &assoc->entry);
127     *assoc_out = assoc;
128
129     LeaveCriticalSection(&assoc_list_cs);
130
131     TRACE("new assoc %p\n", assoc);
132
133     return RPC_S_OK;
134 }
135
136 RPC_STATUS RpcServerAssoc_GetAssociation(LPCSTR Protseq, LPCSTR NetworkAddr,
137                                          LPCSTR Endpoint, LPCWSTR NetworkOptions,
138                                          ULONG assoc_gid,
139                                          RpcAssoc **assoc_out)
140 {
141     RpcAssoc *assoc;
142     RPC_STATUS status;
143
144     EnterCriticalSection(&assoc_list_cs);
145     if (assoc_gid)
146     {
147         LIST_FOR_EACH_ENTRY(assoc, &server_assoc_list, RpcAssoc, entry)
148         {
149             /* FIXME: NetworkAddr shouldn't be NULL */
150             if (assoc->assoc_group_id == assoc_gid &&
151                 !strcmp(Protseq, assoc->Protseq) &&
152                 (!NetworkAddr || !assoc->NetworkAddr || !strcmp(NetworkAddr, assoc->NetworkAddr)) &&
153                 !strcmp(Endpoint, assoc->Endpoint) &&
154                 ((!assoc->NetworkOptions == !NetworkOptions) &&
155                  (!NetworkOptions || !strcmpW(NetworkOptions, assoc->NetworkOptions))))
156             {
157                 assoc->refs++;
158                 *assoc_out = assoc;
159                 LeaveCriticalSection(&assoc_list_cs);
160                 TRACE("using existing assoc %p\n", assoc);
161                 return RPC_S_OK;
162             }
163         }
164         *assoc_out = NULL;
165         LeaveCriticalSection(&assoc_list_cs);
166         return RPC_S_NO_CONTEXT_AVAILABLE;
167     }
168
169     status = RpcAssoc_Alloc(Protseq, NetworkAddr, Endpoint, NetworkOptions, &assoc);
170     if (status != RPC_S_OK)
171     {
172         LeaveCriticalSection(&assoc_list_cs);
173         return status;
174     }
175     assoc->assoc_group_id = InterlockedIncrement(&last_assoc_group_id);
176     list_add_head(&server_assoc_list, &assoc->entry);
177     *assoc_out = assoc;
178
179     LeaveCriticalSection(&assoc_list_cs);
180
181     TRACE("new assoc %p\n", assoc);
182
183     return RPC_S_OK;
184 }
185
186 ULONG RpcAssoc_Release(RpcAssoc *assoc)
187 {
188     ULONG refs;
189
190     EnterCriticalSection(&assoc_list_cs);
191     refs = --assoc->refs;
192     if (!refs)
193         list_remove(&assoc->entry);
194     LeaveCriticalSection(&assoc_list_cs);
195
196     if (!refs)
197     {
198         RpcConnection *Connection, *cursor2;
199         RpcContextHandle *context_handle, *context_handle_cursor;
200
201         TRACE("destroying assoc %p\n", assoc);
202
203         LIST_FOR_EACH_ENTRY_SAFE(Connection, cursor2, &assoc->free_connection_pool, RpcConnection, conn_pool_entry)
204         {
205             list_remove(&Connection->conn_pool_entry);
206             RPCRT4_DestroyConnection(Connection);
207         }
208
209         LIST_FOR_EACH_ENTRY_SAFE(context_handle, context_handle_cursor, &assoc->context_handle_list, RpcContextHandle, entry)
210             RpcContextHandle_Destroy(context_handle);
211
212         HeapFree(GetProcessHeap(), 0, assoc->NetworkOptions);
213         HeapFree(GetProcessHeap(), 0, assoc->Endpoint);
214         HeapFree(GetProcessHeap(), 0, assoc->NetworkAddr);
215         HeapFree(GetProcessHeap(), 0, assoc->Protseq);
216
217         assoc->cs.DebugInfo->Spare[0] = 0;
218         DeleteCriticalSection(&assoc->cs);
219
220         HeapFree(GetProcessHeap(), 0, assoc);
221     }
222
223     return refs;
224 }
225
226 #define ROUND_UP(value, alignment) (((value) + ((alignment) - 1)) & ~((alignment)-1))
227
228 static RPC_STATUS RpcAssoc_BindConnection(const RpcAssoc *assoc, RpcConnection *conn,
229                                           const RPC_SYNTAX_IDENTIFIER *InterfaceId,
230                                           const RPC_SYNTAX_IDENTIFIER *TransferSyntax)
231 {
232     RpcPktHdr *hdr;
233     RpcPktHdr *response_hdr;
234     RPC_MESSAGE msg;
235     RPC_STATUS status;
236     unsigned char *auth_data = NULL;
237     ULONG auth_length;
238
239     TRACE("sending bind request to server\n");
240
241     hdr = RPCRT4_BuildBindHeader(NDR_LOCAL_DATA_REPRESENTATION,
242                                  RPC_MAX_PACKET_SIZE, RPC_MAX_PACKET_SIZE,
243                                  assoc->assoc_group_id,
244                                  InterfaceId, TransferSyntax);
245
246     status = RPCRT4_Send(conn, hdr, NULL, 0);
247     RPCRT4_FreeHeader(hdr);
248     if (status != RPC_S_OK)
249         return status;
250
251     status = RPCRT4_ReceiveWithAuth(conn, &response_hdr, &msg, &auth_data, &auth_length);
252     if (status != RPC_S_OK)
253     {
254         ERR("receive failed with error %d\n", status);
255         return status;
256     }
257
258     switch (response_hdr->common.ptype)
259     {
260     case PKT_BIND_ACK:
261     {
262         RpcAddressString *server_address = msg.Buffer;
263         if ((msg.BufferLength >= FIELD_OFFSET(RpcAddressString, string[0])) ||
264             (msg.BufferLength >= ROUND_UP(FIELD_OFFSET(RpcAddressString, string[server_address->length]), 4)))
265         {
266             unsigned short remaining = msg.BufferLength -
267             ROUND_UP(FIELD_OFFSET(RpcAddressString, string[server_address->length]), 4);
268             RpcResultList *results = (RpcResultList*)((ULONG_PTR)server_address +
269                 ROUND_UP(FIELD_OFFSET(RpcAddressString, string[server_address->length]), 4));
270             if ((results->num_results == 1) &&
271                 (remaining >= FIELD_OFFSET(RpcResultList, results[results->num_results])))
272             {
273                 switch (results->results[0].result)
274                 {
275                 case RESULT_ACCEPT:
276                     /* respond to authorization request */
277                     if (auth_length > sizeof(RpcAuthVerifier))
278                         status = RPCRT4_ClientConnectionAuth(conn,
279                                                              auth_data + sizeof(RpcAuthVerifier),
280                                                              auth_length);
281                     if (status == RPC_S_OK)
282                     {
283                         conn->assoc_group_id = response_hdr->bind_ack.assoc_gid;
284                         conn->MaxTransmissionSize = response_hdr->bind_ack.max_tsize;
285                         conn->ActiveInterface = *InterfaceId;
286                     }
287                     break;
288                 case RESULT_PROVIDER_REJECTION:
289                     switch (results->results[0].reason)
290                     {
291                     case REASON_ABSTRACT_SYNTAX_NOT_SUPPORTED:
292                         ERR("syntax %s, %d.%d not supported\n",
293                             debugstr_guid(&InterfaceId->SyntaxGUID),
294                             InterfaceId->SyntaxVersion.MajorVersion,
295                             InterfaceId->SyntaxVersion.MinorVersion);
296                         status = RPC_S_UNKNOWN_IF;
297                         break;
298                     case REASON_TRANSFER_SYNTAXES_NOT_SUPPORTED:
299                         ERR("transfer syntax not supported\n");
300                         status = RPC_S_SERVER_UNAVAILABLE;
301                         break;
302                     case REASON_NONE:
303                     default:
304                         status = RPC_S_CALL_FAILED_DNE;
305                     }
306                     break;
307                 case RESULT_USER_REJECTION:
308                 default:
309                     ERR("rejection result %d\n", results->results[0].result);
310                     status = RPC_S_CALL_FAILED_DNE;
311                 }
312             }
313             else
314             {
315                 ERR("incorrect results size\n");
316                 status = RPC_S_CALL_FAILED_DNE;
317             }
318         }
319         else
320         {
321             ERR("bind ack packet too small (%d)\n", msg.BufferLength);
322             status = RPC_S_PROTOCOL_ERROR;
323         }
324         break;
325     }
326     case PKT_BIND_NACK:
327         switch (response_hdr->bind_nack.reject_reason)
328         {
329         case REJECT_LOCAL_LIMIT_EXCEEDED:
330         case REJECT_TEMPORARY_CONGESTION:
331             ERR("server too busy\n");
332             status = RPC_S_SERVER_TOO_BUSY;
333             break;
334         case REJECT_PROTOCOL_VERSION_NOT_SUPPORTED:
335             ERR("protocol version not supported\n");
336             status = RPC_S_PROTOCOL_ERROR;
337             break;
338         case REJECT_UNKNOWN_AUTHN_SERVICE:
339             ERR("unknown authentication service\n");
340             status = RPC_S_UNKNOWN_AUTHN_SERVICE;
341             break;
342         case REJECT_INVALID_CHECKSUM:
343             ERR("invalid checksum\n");
344             status = ERROR_ACCESS_DENIED;
345             break;
346         default:
347             ERR("rejected bind for reason %d\n", response_hdr->bind_nack.reject_reason);
348             status = RPC_S_CALL_FAILED_DNE;
349         }
350         break;
351     default:
352         ERR("wrong packet type received %d\n", response_hdr->common.ptype);
353         status = RPC_S_PROTOCOL_ERROR;
354         break;
355     }
356
357     I_RpcFree(msg.Buffer);
358     RPCRT4_FreeHeader(response_hdr);
359     HeapFree(GetProcessHeap(), 0, auth_data);
360     return status;
361 }
362
363 static RpcConnection *RpcAssoc_GetIdleConnection(RpcAssoc *assoc,
364                                                  const RPC_SYNTAX_IDENTIFIER *InterfaceId,
365                                                  const RPC_SYNTAX_IDENTIFIER *TransferSyntax, const RpcAuthInfo *AuthInfo,
366                                                  const RpcQualityOfService *QOS)
367 {
368     RpcConnection *Connection;
369     EnterCriticalSection(&assoc->cs);
370     /* try to find a compatible connection from the connection pool */
371     LIST_FOR_EACH_ENTRY(Connection, &assoc->free_connection_pool, RpcConnection, conn_pool_entry)
372     {
373         if (!memcmp(&Connection->ActiveInterface, InterfaceId,
374                     sizeof(RPC_SYNTAX_IDENTIFIER)) &&
375             RpcAuthInfo_IsEqual(Connection->AuthInfo, AuthInfo) &&
376             RpcQualityOfService_IsEqual(Connection->QOS, QOS))
377         {
378             list_remove(&Connection->conn_pool_entry);
379             LeaveCriticalSection(&assoc->cs);
380             TRACE("got connection from pool %p\n", Connection);
381             return Connection;
382         }
383     }
384
385     LeaveCriticalSection(&assoc->cs);
386     return NULL;
387 }
388
389 RPC_STATUS RpcAssoc_GetClientConnection(RpcAssoc *assoc,
390                                         const RPC_SYNTAX_IDENTIFIER *InterfaceId,
391                                         const RPC_SYNTAX_IDENTIFIER *TransferSyntax, RpcAuthInfo *AuthInfo,
392                                         RpcQualityOfService *QOS, RpcConnection **Connection)
393 {
394     RpcConnection *NewConnection;
395     RPC_STATUS status;
396
397     *Connection = RpcAssoc_GetIdleConnection(assoc, InterfaceId, TransferSyntax, AuthInfo, QOS);
398     if (*Connection)
399         return RPC_S_OK;
400
401     /* create a new connection */
402     status = RPCRT4_CreateConnection(&NewConnection, FALSE /* is this a server connection? */,
403         assoc->Protseq, assoc->NetworkAddr,
404         assoc->Endpoint, assoc->NetworkOptions,
405         AuthInfo, QOS);
406     if (status != RPC_S_OK)
407         return status;
408
409     NewConnection->assoc = assoc;
410     status = RPCRT4_OpenClientConnection(NewConnection);
411     if (status != RPC_S_OK)
412     {
413         RPCRT4_DestroyConnection(NewConnection);
414         return status;
415     }
416
417     status = RpcAssoc_BindConnection(assoc, NewConnection, InterfaceId, TransferSyntax);
418     if (status != RPC_S_OK)
419     {
420         RPCRT4_DestroyConnection(NewConnection);
421         return status;
422     }
423
424     *Connection = NewConnection;
425
426     return RPC_S_OK;
427 }
428
429 void RpcAssoc_ReleaseIdleConnection(RpcAssoc *assoc, RpcConnection *Connection)
430 {
431     assert(!Connection->server);
432     Connection->async_state = NULL;
433     EnterCriticalSection(&assoc->cs);
434     if (!assoc->assoc_group_id) assoc->assoc_group_id = Connection->assoc_group_id;
435     list_add_head(&assoc->free_connection_pool, &Connection->conn_pool_entry);
436     LeaveCriticalSection(&assoc->cs);
437 }
438
439 RPC_STATUS RpcServerAssoc_AllocateContextHandle(RpcAssoc *assoc, void *CtxGuard,
440                                                 NDR_SCONTEXT *SContext)
441 {
442     RpcContextHandle *context_handle;
443
444     context_handle = HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, sizeof(*context_handle));
445     if (!context_handle)
446         return ERROR_OUTOFMEMORY;
447
448     context_handle->ctx_guard = CtxGuard;
449     RtlInitializeResource(&context_handle->rw_lock);
450     context_handle->refs = 1;
451
452     /* lock here to mirror unmarshall, so we don't need to special-case the
453      * freeing of a non-marshalled context handle */
454     RtlAcquireResourceExclusive(&context_handle->rw_lock, TRUE);
455
456     EnterCriticalSection(&assoc->cs);
457     list_add_tail(&assoc->context_handle_list, &context_handle->entry);
458     LeaveCriticalSection(&assoc->cs);
459
460     *SContext = (NDR_SCONTEXT)context_handle;
461     return RPC_S_OK;
462 }
463
464 BOOL RpcContextHandle_IsGuardCorrect(NDR_SCONTEXT SContext, void *CtxGuard)
465 {
466     RpcContextHandle *context_handle = (RpcContextHandle *)SContext;
467     return context_handle->ctx_guard == CtxGuard;
468 }
469
/* Look up a server context handle on the association by uuid and guard.
 *
 * On success *SContext receives the handle and a reference is added, to be
 * dropped by RpcServerAssoc_ReleaseContextHandle. The handle's rw_lock is
 * then acquired exclusively (wait = TRUE) after assoc->cs has been left, so
 * waiting for the lock does not hold the association lock.
 * Returns ERROR_INVALID_HANDLE if no live handle matches.
 * Flags is not used. */
RPC_STATUS RpcServerAssoc_FindContextHandle(RpcAssoc *assoc, const UUID *uuid,
                                            void *CtxGuard, ULONG Flags, NDR_SCONTEXT *SContext)
{
    RpcContextHandle *context_handle;

    EnterCriticalSection(&assoc->cs);
    LIST_FOR_EACH_ENTRY(context_handle, &assoc->context_handle_list, RpcContextHandle, entry)
    {
        if (RpcContextHandle_IsGuardCorrect((NDR_SCONTEXT)context_handle, CtxGuard) &&
            !memcmp(&context_handle->uuid, uuid, sizeof(*uuid)))
        {
            *SContext = (NDR_SCONTEXT)context_handle;
            /* The handle is accepted only if it still held a reference before
             * this post-increment. RpcServerAssoc_ReleaseContextHandle unlinks
             * handles under assoc->cs when refs reaches 0.
             * NOTE(review): if refs were 0 here, the increment is not undone
             * and *SContext still points at this handle. This looks unreachable,
             * since zero-ref handles are unlinked; confirm. */
            if (context_handle->refs++)
            {
                LeaveCriticalSection(&assoc->cs);
                TRACE("found %p\n", context_handle);
                RtlAcquireResourceExclusive(&context_handle->rw_lock, TRUE);
                return RPC_S_OK;
            }
        }
    }
    LeaveCriticalSection(&assoc->cs);

    ERR("no context handle found for uuid %s, guard %p\n",
        debugstr_guid(uuid), CtxGuard);
    return ERROR_INVALID_HANDLE;
}
497
498 RPC_STATUS RpcServerAssoc_UpdateContextHandle(RpcAssoc *assoc,
499                                               NDR_SCONTEXT SContext,
500                                               void *CtxGuard,
501                                               NDR_RUNDOWN rundown_routine)
502 {
503     RpcContextHandle *context_handle = (RpcContextHandle *)SContext;
504     RPC_STATUS status;
505
506     if (!RpcContextHandle_IsGuardCorrect((NDR_SCONTEXT)context_handle, CtxGuard))
507         return ERROR_INVALID_HANDLE;
508
509     EnterCriticalSection(&assoc->cs);
510     if (UuidIsNil(&context_handle->uuid, &status))
511     {
512         /* add a ref for the data being valid */
513         context_handle->refs++;
514         UuidCreate(&context_handle->uuid);
515         context_handle->rundown_routine = rundown_routine;
516         TRACE("allocated uuid %s for context handle %p\n",
517               debugstr_guid(&context_handle->uuid), context_handle);
518     }
519     LeaveCriticalSection(&assoc->cs);
520
521     return RPC_S_OK;
522 }
523
524 void RpcContextHandle_GetUuid(NDR_SCONTEXT SContext, UUID *uuid)
525 {
526     RpcContextHandle *context_handle = (RpcContextHandle *)SContext;
527     *uuid = context_handle->uuid;
528 }
529
530 static void RpcContextHandle_Destroy(RpcContextHandle *context_handle)
531 {
532     TRACE("freeing %p\n", context_handle);
533
534     if (context_handle->user_context && context_handle->rundown_routine)
535     {
536         TRACE("calling rundown routine %p with user context %p\n",
537               context_handle->rundown_routine, context_handle->user_context);
538         context_handle->rundown_routine(context_handle->user_context);
539     }
540
541     RtlDeleteResource(&context_handle->rw_lock);
542
543     HeapFree(GetProcessHeap(), 0, context_handle);
544 }
545
546 unsigned int RpcServerAssoc_ReleaseContextHandle(RpcAssoc *assoc, NDR_SCONTEXT SContext, BOOL release_lock)
547 {
548     RpcContextHandle *context_handle = (RpcContextHandle *)SContext;
549     unsigned int refs;
550
551     if (release_lock)
552         RtlReleaseResource(&context_handle->rw_lock);
553
554     EnterCriticalSection(&assoc->cs);
555     refs = --context_handle->refs;
556     if (!refs)
557         list_remove(&context_handle->entry);
558     LeaveCriticalSection(&assoc->cs);
559
560     if (!refs)
561         RpcContextHandle_Destroy(context_handle);
562
563     return refs;
564 }