include: Change RPC_STATUS from long to LONG for Win64 compatibility.
[wine] / dlls / rpcrt4 / rpc_assoc.c
1 /*
2  * Associations
3  *
4  * Copyright 2007 Robert Shearman (for CodeWeavers)
5  *
6  * This library is free software; you can redistribute it and/or
7  * modify it under the terms of the GNU Lesser General Public
8  * License as published by the Free Software Foundation; either
9  * version 2.1 of the License, or (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14  * Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public
17  * License along with this library; if not, write to the Free Software
18  * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
19  *
20  */
21
22 #include <stdarg.h>
23 #include <assert.h>
24
25 #include "rpc.h"
26 #include "rpcndr.h"
27 #include "winternl.h"
28
29 #include "wine/unicode.h"
30 #include "wine/debug.h"
31
32 #include "rpc_binding.h"
33 #include "rpc_assoc.h"
34 #include "rpc_message.h"
35
36 WINE_DEFAULT_DEBUG_CHANNEL(rpc);
37
/* Lock taken in this file around every traversal or modification of
 * client_assoc_list / server_assoc_list and around changes to an
 * association's reference count.  It is statically initialised here,
 * together with the debug record that names it. */
static CRITICAL_SECTION assoc_list_cs;
static CRITICAL_SECTION_DEBUG assoc_list_cs_debug =
{
    0, 0, &assoc_list_cs,
    { &assoc_list_cs_debug.ProcessLocksList, &assoc_list_cs_debug.ProcessLocksList },
      0, 0, { (DWORD_PTR)(__FILE__ ": assoc_list_cs") }
};
static CRITICAL_SECTION assoc_list_cs = { &assoc_list_cs_debug, -1, 0, 0, 0, 0 };
46
/* associations created by RPCRT4_GetAssociation (accessed under assoc_list_cs) */
static struct list client_assoc_list = LIST_INIT(client_assoc_list);
/* associations created by RpcServerAssoc_GetAssociation (accessed under assoc_list_cs) */
static struct list server_assoc_list = LIST_INIT(server_assoc_list);

/* counter bumped with InterlockedIncrement to give each new server
 * association its group id */
static LONG last_assoc_group_id;
51
/* Server-side context handle.  Instances are handed out to callers cast to
 * NDR_SCONTEXT and cast back on the way in. */
typedef struct _RpcContextHandle
{
    struct list entry;           /* link in the owning association's context_handle_list */
    void *user_context;          /* passed to rundown_routine on destruction; never assigned
                                  * in this file -- presumably written by callers through the
                                  * NDR_SCONTEXT view (TODO confirm) */
    NDR_RUNDOWN rundown_routine; /* set by RpcServerAssoc_UpdateContextHandle */
    void *ctx_guard;             /* compared against the caller's guard before the handle is used */
    UUID uuid;                   /* nil until RpcServerAssoc_UpdateContextHandle creates one */
    RTL_RWLOCK rw_lock;          /* acquired exclusively on allocate/find, released on release */
    unsigned int refs;           /* reference count, modified under the association's cs */
} RpcContextHandle;

/* defined further down; needed earlier by RpcAssoc_Release */
static void RpcContextHandle_Destroy(RpcContextHandle *context_handle);
64
65 static RPC_STATUS RpcAssoc_Alloc(LPCSTR Protseq, LPCSTR NetworkAddr,
66                                  LPCSTR Endpoint, LPCWSTR NetworkOptions,
67                                  RpcAssoc **assoc_out)
68 {
69     RpcAssoc *assoc;
70     assoc = HeapAlloc(GetProcessHeap(), 0, sizeof(*assoc));
71     if (!assoc)
72         return RPC_S_OUT_OF_RESOURCES;
73     assoc->refs = 1;
74     list_init(&assoc->free_connection_pool);
75     list_init(&assoc->context_handle_list);
76     InitializeCriticalSection(&assoc->cs);
77     assoc->Protseq = RPCRT4_strdupA(Protseq);
78     assoc->NetworkAddr = RPCRT4_strdupA(NetworkAddr);
79     assoc->Endpoint = RPCRT4_strdupA(Endpoint);
80     assoc->NetworkOptions = NetworkOptions ? RPCRT4_strdupW(NetworkOptions) : NULL;
81     assoc->assoc_group_id = 0;
82     list_init(&assoc->entry);
83     *assoc_out = assoc;
84     return RPC_S_OK;
85 }
86
87 RPC_STATUS RPCRT4_GetAssociation(LPCSTR Protseq, LPCSTR NetworkAddr,
88                                  LPCSTR Endpoint, LPCWSTR NetworkOptions,
89                                  RpcAssoc **assoc_out)
90 {
91     RpcAssoc *assoc;
92     RPC_STATUS status;
93
94     EnterCriticalSection(&assoc_list_cs);
95     LIST_FOR_EACH_ENTRY(assoc, &client_assoc_list, RpcAssoc, entry)
96     {
97         if (!strcmp(Protseq, assoc->Protseq) &&
98             !strcmp(NetworkAddr, assoc->NetworkAddr) &&
99             !strcmp(Endpoint, assoc->Endpoint) &&
100             ((!assoc->NetworkOptions && !NetworkOptions) || !strcmpW(NetworkOptions, assoc->NetworkOptions)))
101         {
102             assoc->refs++;
103             *assoc_out = assoc;
104             LeaveCriticalSection(&assoc_list_cs);
105             TRACE("using existing assoc %p\n", assoc);
106             return RPC_S_OK;
107         }
108     }
109
110     status = RpcAssoc_Alloc(Protseq, NetworkAddr, Endpoint, NetworkOptions, &assoc);
111     if (status != RPC_S_OK)
112     {
113         LeaveCriticalSection(&assoc_list_cs);
114         return status;
115     }
116     list_add_head(&client_assoc_list, &assoc->entry);
117     *assoc_out = assoc;
118
119     LeaveCriticalSection(&assoc_list_cs);
120
121     TRACE("new assoc %p\n", assoc);
122
123     return RPC_S_OK;
124 }
125
126 RPC_STATUS RpcServerAssoc_GetAssociation(LPCSTR Protseq, LPCSTR NetworkAddr,
127                                          LPCSTR Endpoint, LPCWSTR NetworkOptions,
128                                          unsigned long assoc_gid,
129                                          RpcAssoc **assoc_out)
130 {
131     RpcAssoc *assoc;
132     RPC_STATUS status;
133
134     EnterCriticalSection(&assoc_list_cs);
135     if (assoc_gid)
136     {
137         LIST_FOR_EACH_ENTRY(assoc, &server_assoc_list, RpcAssoc, entry)
138         {
139             /* FIXME: NetworkAddr shouldn't be NULL */
140             if (assoc->assoc_group_id == assoc_gid &&
141                 !strcmp(Protseq, assoc->Protseq) &&
142                 (!NetworkAddr || !assoc->NetworkAddr || !strcmp(NetworkAddr, assoc->NetworkAddr)) &&
143                 !strcmp(Endpoint, assoc->Endpoint) &&
144                 ((!assoc->NetworkOptions == !NetworkOptions) &&
145                  (!NetworkOptions || !strcmpW(NetworkOptions, assoc->NetworkOptions))))
146             {
147                 assoc->refs++;
148                 *assoc_out = assoc;
149                 LeaveCriticalSection(&assoc_list_cs);
150                 TRACE("using existing assoc %p\n", assoc);
151                 return RPC_S_OK;
152             }
153         }
154         *assoc_out = NULL;
155         LeaveCriticalSection(&assoc_list_cs);
156         return RPC_S_NO_CONTEXT_AVAILABLE;
157     }
158
159     status = RpcAssoc_Alloc(Protseq, NetworkAddr, Endpoint, NetworkOptions, &assoc);
160     if (status != RPC_S_OK)
161     {
162         LeaveCriticalSection(&assoc_list_cs);
163         return status;
164     }
165     assoc->assoc_group_id = InterlockedIncrement(&last_assoc_group_id);
166     list_add_head(&server_assoc_list, &assoc->entry);
167     *assoc_out = assoc;
168
169     LeaveCriticalSection(&assoc_list_cs);
170
171     TRACE("new assoc %p\n", assoc);
172
173     return RPC_S_OK;
174 }
175
176 ULONG RpcAssoc_Release(RpcAssoc *assoc)
177 {
178     ULONG refs;
179
180     EnterCriticalSection(&assoc_list_cs);
181     refs = --assoc->refs;
182     if (!refs)
183         list_remove(&assoc->entry);
184     LeaveCriticalSection(&assoc_list_cs);
185
186     if (!refs)
187     {
188         RpcConnection *Connection, *cursor2;
189         RpcContextHandle *context_handle, *context_handle_cursor;
190
191         TRACE("destroying assoc %p\n", assoc);
192
193         LIST_FOR_EACH_ENTRY_SAFE(Connection, cursor2, &assoc->free_connection_pool, RpcConnection, conn_pool_entry)
194         {
195             list_remove(&Connection->conn_pool_entry);
196             RPCRT4_DestroyConnection(Connection);
197         }
198
199         LIST_FOR_EACH_ENTRY_SAFE(context_handle, context_handle_cursor, &assoc->context_handle_list, RpcContextHandle, entry)
200             RpcContextHandle_Destroy(context_handle);
201
202         HeapFree(GetProcessHeap(), 0, assoc->NetworkOptions);
203         HeapFree(GetProcessHeap(), 0, assoc->Endpoint);
204         HeapFree(GetProcessHeap(), 0, assoc->NetworkAddr);
205         HeapFree(GetProcessHeap(), 0, assoc->Protseq);
206
207         DeleteCriticalSection(&assoc->cs);
208
209         HeapFree(GetProcessHeap(), 0, assoc);
210     }
211
212     return refs;
213 }
214
/* Rounds value up to the next multiple of alignment.  The mask arithmetic is
 * only correct when alignment is a power of two; alignment is evaluated twice. */
#define ROUND_UP(value, alignment) (((value) + ((alignment) - 1)) & ~((alignment)-1))
216
217 static RPC_STATUS RpcAssoc_BindConnection(const RpcAssoc *assoc, RpcConnection *conn,
218                                           const RPC_SYNTAX_IDENTIFIER *InterfaceId,
219                                           const RPC_SYNTAX_IDENTIFIER *TransferSyntax)
220 {
221     RpcPktHdr *hdr;
222     RpcPktHdr *response_hdr;
223     RPC_MESSAGE msg;
224     RPC_STATUS status;
225     unsigned char *auth_data = NULL;
226     unsigned long auth_length;
227
228     TRACE("sending bind request to server\n");
229
230     hdr = RPCRT4_BuildBindHeader(NDR_LOCAL_DATA_REPRESENTATION,
231                                  RPC_MAX_PACKET_SIZE, RPC_MAX_PACKET_SIZE,
232                                  assoc->assoc_group_id,
233                                  InterfaceId, TransferSyntax);
234
235     status = RPCRT4_Send(conn, hdr, NULL, 0);
236     RPCRT4_FreeHeader(hdr);
237     if (status != RPC_S_OK)
238         return status;
239
240     status = RPCRT4_ReceiveWithAuth(conn, &response_hdr, &msg, &auth_data, &auth_length);
241     if (status != RPC_S_OK)
242     {
243         ERR("receive failed with error %d\n", status);
244         return status;
245     }
246
247     switch (response_hdr->common.ptype)
248     {
249     case PKT_BIND_ACK:
250     {
251         RpcAddressString *server_address = msg.Buffer;
252         if ((msg.BufferLength >= FIELD_OFFSET(RpcAddressString, string[0])) ||
253             (msg.BufferLength >= ROUND_UP(FIELD_OFFSET(RpcAddressString, string[server_address->length]), 4)))
254         {
255             unsigned short remaining = msg.BufferLength -
256             ROUND_UP(FIELD_OFFSET(RpcAddressString, string[server_address->length]), 4);
257             RpcResults *results = (RpcResults*)((ULONG_PTR)server_address +
258                                                 ROUND_UP(FIELD_OFFSET(RpcAddressString, string[server_address->length]), 4));
259             if ((results->num_results == 1) && (remaining >= sizeof(*results)))
260             {
261                 switch (results->results[0].result)
262                 {
263                 case RESULT_ACCEPT:
264                     /* respond to authorization request */
265                     if (auth_length > sizeof(RpcAuthVerifier))
266                         status = RPCRT4_AuthorizeConnection(conn,
267                                                             auth_data + sizeof(RpcAuthVerifier),
268                                                             auth_length);
269                     if (status == RPC_S_OK)
270                     {
271                         conn->assoc_group_id = response_hdr->bind_ack.assoc_gid;
272                         conn->MaxTransmissionSize = response_hdr->bind_ack.max_tsize;
273                         conn->ActiveInterface = *InterfaceId;
274                     }
275                     break;
276                 case RESULT_PROVIDER_REJECTION:
277                     switch (results->results[0].reason)
278                     {
279                     case REASON_ABSTRACT_SYNTAX_NOT_SUPPORTED:
280                         ERR("syntax %s, %d.%d not supported\n",
281                             debugstr_guid(&InterfaceId->SyntaxGUID),
282                             InterfaceId->SyntaxVersion.MajorVersion,
283                             InterfaceId->SyntaxVersion.MinorVersion);
284                         status = RPC_S_UNKNOWN_IF;
285                         break;
286                     case REASON_TRANSFER_SYNTAXES_NOT_SUPPORTED:
287                         ERR("transfer syntax not supported\n");
288                         status = RPC_S_SERVER_UNAVAILABLE;
289                         break;
290                     case REASON_NONE:
291                     default:
292                         status = RPC_S_CALL_FAILED_DNE;
293                     }
294                     break;
295                 case RESULT_USER_REJECTION:
296                 default:
297                     ERR("rejection result %d\n", results->results[0].result);
298                     status = RPC_S_CALL_FAILED_DNE;
299                 }
300             }
301             else
302             {
303                 ERR("incorrect results size\n");
304                 status = RPC_S_CALL_FAILED_DNE;
305             }
306         }
307         else
308         {
309             ERR("bind ack packet too small (%d)\n", msg.BufferLength);
310             status = RPC_S_PROTOCOL_ERROR;
311         }
312         break;
313     }
314     case PKT_BIND_NACK:
315         switch (response_hdr->bind_nack.reject_reason)
316         {
317         case REJECT_LOCAL_LIMIT_EXCEEDED:
318         case REJECT_TEMPORARY_CONGESTION:
319             ERR("server too busy\n");
320             status = RPC_S_SERVER_TOO_BUSY;
321             break;
322         case REJECT_PROTOCOL_VERSION_NOT_SUPPORTED:
323             ERR("protocol version not supported\n");
324             status = RPC_S_PROTOCOL_ERROR;
325             break;
326         case REJECT_UNKNOWN_AUTHN_SERVICE:
327             ERR("unknown authentication service\n");
328             status = RPC_S_UNKNOWN_AUTHN_SERVICE;
329             break;
330         case REJECT_INVALID_CHECKSUM:
331             ERR("invalid checksum\n");
332             status = ERROR_ACCESS_DENIED;
333             break;
334         default:
335             ERR("rejected bind for reason %d\n", response_hdr->bind_nack.reject_reason);
336             status = RPC_S_CALL_FAILED_DNE;
337         }
338         break;
339     default:
340         ERR("wrong packet type received %d\n", response_hdr->common.ptype);
341         status = RPC_S_PROTOCOL_ERROR;
342         break;
343     }
344
345     I_RpcFree(msg.Buffer);
346     RPCRT4_FreeHeader(response_hdr);
347     HeapFree(GetProcessHeap(), 0, auth_data);
348     return status;
349 }
350
351 static RpcConnection *RpcAssoc_GetIdleConnection(RpcAssoc *assoc,
352                                                  const RPC_SYNTAX_IDENTIFIER *InterfaceId,
353                                                  const RPC_SYNTAX_IDENTIFIER *TransferSyntax, const RpcAuthInfo *AuthInfo,
354                                                  const RpcQualityOfService *QOS)
355 {
356     RpcConnection *Connection;
357     EnterCriticalSection(&assoc->cs);
358     /* try to find a compatible connection from the connection pool */
359     LIST_FOR_EACH_ENTRY(Connection, &assoc->free_connection_pool, RpcConnection, conn_pool_entry)
360     {
361         if (!memcmp(&Connection->ActiveInterface, InterfaceId,
362                     sizeof(RPC_SYNTAX_IDENTIFIER)) &&
363             RpcAuthInfo_IsEqual(Connection->AuthInfo, AuthInfo) &&
364             RpcQualityOfService_IsEqual(Connection->QOS, QOS))
365         {
366             list_remove(&Connection->conn_pool_entry);
367             LeaveCriticalSection(&assoc->cs);
368             TRACE("got connection from pool %p\n", Connection);
369             return Connection;
370         }
371     }
372
373     LeaveCriticalSection(&assoc->cs);
374     return NULL;
375 }
376
377 RPC_STATUS RpcAssoc_GetClientConnection(RpcAssoc *assoc,
378                                         const RPC_SYNTAX_IDENTIFIER *InterfaceId,
379                                         const RPC_SYNTAX_IDENTIFIER *TransferSyntax, RpcAuthInfo *AuthInfo,
380                                         RpcQualityOfService *QOS, RpcConnection **Connection)
381 {
382     RpcConnection *NewConnection;
383     RPC_STATUS status;
384
385     *Connection = RpcAssoc_GetIdleConnection(assoc, InterfaceId, TransferSyntax, AuthInfo, QOS);
386     if (*Connection)
387         return RPC_S_OK;
388
389     /* create a new connection */
390     status = RPCRT4_CreateConnection(&NewConnection, FALSE /* is this a server connection? */,
391         assoc->Protseq, assoc->NetworkAddr,
392         assoc->Endpoint, assoc->NetworkOptions,
393         AuthInfo, QOS);
394     if (status != RPC_S_OK)
395         return status;
396
397     status = RPCRT4_OpenClientConnection(NewConnection);
398     if (status != RPC_S_OK)
399     {
400         RPCRT4_DestroyConnection(NewConnection);
401         return status;
402     }
403
404     status = RpcAssoc_BindConnection(assoc, NewConnection, InterfaceId, TransferSyntax);
405     if (status != RPC_S_OK)
406     {
407         RPCRT4_DestroyConnection(NewConnection);
408         return status;
409     }
410
411     *Connection = NewConnection;
412
413     return RPC_S_OK;
414 }
415
416 void RpcAssoc_ReleaseIdleConnection(RpcAssoc *assoc, RpcConnection *Connection)
417 {
418     assert(!Connection->server);
419     EnterCriticalSection(&assoc->cs);
420     if (!assoc->assoc_group_id) assoc->assoc_group_id = Connection->assoc_group_id;
421     list_add_head(&assoc->free_connection_pool, &Connection->conn_pool_entry);
422     LeaveCriticalSection(&assoc->cs);
423 }
424
425 RPC_STATUS RpcServerAssoc_AllocateContextHandle(RpcAssoc *assoc, void *CtxGuard,
426                                                 NDR_SCONTEXT *SContext)
427 {
428     RpcContextHandle *context_handle;
429
430     context_handle = HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, sizeof(*context_handle));
431     if (!context_handle)
432         return ERROR_OUTOFMEMORY;
433
434     context_handle->ctx_guard = CtxGuard;
435     RtlInitializeResource(&context_handle->rw_lock);
436     context_handle->refs = 1;
437
438     /* lock here to mirror unmarshall, so we don't need to special-case the
439      * freeing of a non-marshalled context handle */
440     RtlAcquireResourceExclusive(&context_handle->rw_lock, TRUE);
441
442     EnterCriticalSection(&assoc->cs);
443     list_add_tail(&assoc->context_handle_list, &context_handle->entry);
444     LeaveCriticalSection(&assoc->cs);
445
446     *SContext = (NDR_SCONTEXT)context_handle;
447     return RPC_S_OK;
448 }
449
450 BOOL RpcContextHandle_IsGuardCorrect(NDR_SCONTEXT SContext, void *CtxGuard)
451 {
452     RpcContextHandle *context_handle = (RpcContextHandle *)SContext;
453     return context_handle->ctx_guard == CtxGuard;
454 }
455
456 RPC_STATUS RpcServerAssoc_FindContextHandle(RpcAssoc *assoc, const UUID *uuid,
457                                             void *CtxGuard, ULONG Flags, NDR_SCONTEXT *SContext)
458 {
459     RpcContextHandle *context_handle;
460
461     EnterCriticalSection(&assoc->cs);
462     LIST_FOR_EACH_ENTRY(context_handle, &assoc->context_handle_list, RpcContextHandle, entry)
463     {
464         if (RpcContextHandle_IsGuardCorrect((NDR_SCONTEXT)context_handle, CtxGuard) &&
465             !memcmp(&context_handle->uuid, uuid, sizeof(*uuid)))
466         {
467             *SContext = (NDR_SCONTEXT)context_handle;
468             if (context_handle->refs++)
469             {
470                 LeaveCriticalSection(&assoc->cs);
471                 TRACE("found %p\n", context_handle);
472                 RtlAcquireResourceExclusive(&context_handle->rw_lock, TRUE);
473                 return RPC_S_OK;
474             }
475         }
476     }
477     LeaveCriticalSection(&assoc->cs);
478
479     ERR("no context handle found for uuid %s, guard %p\n",
480         debugstr_guid(uuid), CtxGuard);
481     return ERROR_INVALID_HANDLE;
482 }
483
484 RPC_STATUS RpcServerAssoc_UpdateContextHandle(RpcAssoc *assoc,
485                                               NDR_SCONTEXT SContext,
486                                               void *CtxGuard,
487                                               NDR_RUNDOWN rundown_routine)
488 {
489     RpcContextHandle *context_handle = (RpcContextHandle *)SContext;
490     RPC_STATUS status;
491
492     if (!RpcContextHandle_IsGuardCorrect((NDR_SCONTEXT)context_handle, CtxGuard))
493         return ERROR_INVALID_HANDLE;
494
495     EnterCriticalSection(&assoc->cs);
496     if (UuidIsNil(&context_handle->uuid, &status))
497     {
498         /* add a ref for the data being valid */
499         context_handle->refs++;
500         UuidCreate(&context_handle->uuid);
501         context_handle->rundown_routine = rundown_routine;
502         TRACE("allocated uuid %s for context handle %p\n",
503               debugstr_guid(&context_handle->uuid), context_handle);
504     }
505     LeaveCriticalSection(&assoc->cs);
506
507     return RPC_S_OK;
508 }
509
510 void RpcContextHandle_GetUuid(NDR_SCONTEXT SContext, UUID *uuid)
511 {
512     RpcContextHandle *context_handle = (RpcContextHandle *)SContext;
513     *uuid = context_handle->uuid;
514 }
515
516 static void RpcContextHandle_Destroy(RpcContextHandle *context_handle)
517 {
518     TRACE("freeing %p\n", context_handle);
519
520     if (context_handle->user_context && context_handle->rundown_routine)
521     {
522         TRACE("calling rundown routine %p with user context %p\n",
523               context_handle->rundown_routine, context_handle->user_context);
524         context_handle->rundown_routine(context_handle->user_context);
525     }
526
527     RtlDeleteResource(&context_handle->rw_lock);
528
529     HeapFree(GetProcessHeap(), 0, context_handle);
530 }
531
532 unsigned int RpcServerAssoc_ReleaseContextHandle(RpcAssoc *assoc, NDR_SCONTEXT SContext, BOOL release_lock)
533 {
534     RpcContextHandle *context_handle = (RpcContextHandle *)SContext;
535     unsigned int refs;
536
537     if (release_lock)
538         RtlReleaseResource(&context_handle->rw_lock);
539
540     EnterCriticalSection(&assoc->cs);
541     refs = --context_handle->refs;
542     if (!refs)
543         list_remove(&context_handle->entry);
544     LeaveCriticalSection(&assoc->cs);
545
546     if (!refs)
547         RpcContextHandle_Destroy(context_handle);
548
549     return refs;
550 }