Do not check for a non-NULL pointer before HeapFree'ing it; it's redundant.
[wine] / dlls / rpcrt4 / rpc_message.c
1 /*
2  * RPC messages
3  *
4  * Copyright 2001-2002 Ove Kåven, TransGaming Technologies
5  * Copyright 2004 Filip Navara
6  *
7  * This library is free software; you can redistribute it and/or
8  * modify it under the terms of the GNU Lesser General Public
9  * License as published by the Free Software Foundation; either
10  * version 2.1 of the License, or (at your option) any later version.
11  *
12  * This library is distributed in the hope that it will be useful,
13  * but WITHOUT ANY WARRANTY; without even the implied warranty of
14  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
15  * Lesser General Public License for more details.
16  *
17  * You should have received a copy of the GNU Lesser General Public
18  * License along with this library; if not, write to the Free Software
19  * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
20  *
21  * TODO:
22  *  - figure out whether we *really* got this right
23  *  - check for errors and throw exceptions
24  */
25
26 #include <stdarg.h>
27 #include <stdio.h>
28 #include <string.h>
29
30 #include "windef.h"
31 #include "winbase.h"
32 #include "winerror.h"
33 #include "winreg.h"
34
35 #include "rpc.h"
36 #include "rpcndr.h"
37 #include "rpcdcep.h"
38
39 #include "wine/debug.h"
40
41 #include "rpc_binding.h"
42 #include "rpc_misc.h"
43 #include "rpc_defs.h"
44
45 WINE_DEFAULT_DEBUG_CHANNEL(ole);
46
47 DWORD RPCRT4_GetHeaderSize(RpcPktHdr *Header)
48 {
49   static const DWORD header_sizes[] = {
50     sizeof(Header->request), 0, sizeof(Header->response),
51     sizeof(Header->fault), 0, 0, 0, 0, 0, 0, 0, sizeof(Header->bind),
52     sizeof(Header->bind_ack), sizeof(Header->bind_nack),
53     0, 0, 0, 0, 0
54   };
55   ULONG ret = 0;
56   
57   if (Header->common.ptype < sizeof(header_sizes) / sizeof(header_sizes[0])) {
58     ret = header_sizes[Header->common.ptype];
59     if (ret == 0)
60       FIXME("unhandled packet type\n");
61     if (Header->common.flags & RPC_FLG_OBJECT_UUID)
62       ret += sizeof(UUID);
63   } else {
64     TRACE("invalid packet type\n");
65   }
66
67   return ret;
68 }
69
70 VOID RPCRT4_BuildCommonHeader(RpcPktHdr *Header, unsigned char PacketType,
71                               unsigned long DataRepresentation)
72 {
73   Header->common.rpc_ver = RPC_VER_MAJOR;
74   Header->common.rpc_ver_minor = RPC_VER_MINOR;
75   Header->common.ptype = PacketType;
76   Header->common.drep[0] = LOBYTE(LOWORD(DataRepresentation));
77   Header->common.drep[1] = HIBYTE(LOWORD(DataRepresentation));
78   Header->common.drep[2] = LOBYTE(HIWORD(DataRepresentation));
79   Header->common.drep[3] = HIBYTE(HIWORD(DataRepresentation));
80   Header->common.auth_len = 0;
81   Header->common.call_id = 1;
82   Header->common.flags = 0;
83   /* Flags and fragment length are computed in RPCRT4_Send. */
84 }                              
85
86 RpcPktHdr *RPCRT4_BuildRequestHeader(unsigned long DataRepresentation,
87                                      unsigned long BufferLength,
88                                      unsigned short ProcNum,
89                                      UUID *ObjectUuid)
90 {
91   RpcPktHdr *header;
92   BOOL has_object;
93   RPC_STATUS status;
94
95   has_object = (ObjectUuid != NULL && !UuidIsNil(ObjectUuid, &status));
96   header = HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY,
97                      sizeof(header->request) + (has_object ? sizeof(UUID) : 0));
98   if (header == NULL) {
99     return NULL;
100   }
101
102   RPCRT4_BuildCommonHeader(header, PKT_REQUEST, DataRepresentation);
103   header->common.frag_len = sizeof(header->request);
104   header->request.alloc_hint = BufferLength;
105   header->request.context_id = 0;
106   header->request.opnum = ProcNum;
107   if (has_object) {
108     header->common.flags |= RPC_FLG_OBJECT_UUID;
109     header->common.frag_len += sizeof(UUID);
110     memcpy(&header->request + 1, ObjectUuid, sizeof(UUID));
111   }
112
113   return header;
114 }
115
116 RpcPktHdr *RPCRT4_BuildResponseHeader(unsigned long DataRepresentation,
117                                       unsigned long BufferLength)
118 {
119   RpcPktHdr *header;
120
121   header = HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, sizeof(header->response));
122   if (header == NULL) {
123     return NULL;
124   }
125
126   RPCRT4_BuildCommonHeader(header, PKT_RESPONSE, DataRepresentation);
127   header->common.frag_len = sizeof(header->response);
128   header->response.alloc_hint = BufferLength;
129
130   return header;
131 }
132
133 RpcPktHdr *RPCRT4_BuildFaultHeader(unsigned long DataRepresentation,
134                                    RPC_STATUS Status)
135 {
136   RpcPktHdr *header;
137
138   header = HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, sizeof(header->fault));
139   if (header == NULL) {
140     return NULL;
141   }
142
143   RPCRT4_BuildCommonHeader(header, PKT_FAULT, DataRepresentation);
144   header->common.frag_len = sizeof(header->fault);
145   header->fault.status = Status;
146
147   return header;
148 }
149
150 RpcPktHdr *RPCRT4_BuildBindHeader(unsigned long DataRepresentation,
151                                   unsigned short MaxTransmissionSize,
152                                   unsigned short MaxReceiveSize,
153                                   RPC_SYNTAX_IDENTIFIER *AbstractId,
154                                   RPC_SYNTAX_IDENTIFIER *TransferId)
155 {
156   RpcPktHdr *header;
157
158   header = HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, sizeof(header->bind));
159   if (header == NULL) {
160     return NULL;
161   }
162
163   RPCRT4_BuildCommonHeader(header, PKT_BIND, DataRepresentation);
164   header->common.frag_len = sizeof(header->bind);
165   header->bind.max_tsize = MaxTransmissionSize;
166   header->bind.max_rsize = MaxReceiveSize;
167   header->bind.num_elements = 1;
168   header->bind.num_syntaxes = 1;
169   memcpy(&header->bind.abstract, AbstractId, sizeof(RPC_SYNTAX_IDENTIFIER));
170   memcpy(&header->bind.transfer, TransferId, sizeof(RPC_SYNTAX_IDENTIFIER));
171
172   return header;
173 }
174
175 RpcPktHdr *RPCRT4_BuildBindNackHeader(unsigned long DataRepresentation,
176                                       unsigned char RpcVersion,
177                                       unsigned char RpcVersionMinor)
178 {
179   RpcPktHdr *header;
180
181   header = HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, sizeof(header->bind_nack));
182   if (header == NULL) {
183     return NULL;
184   }
185
186   RPCRT4_BuildCommonHeader(header, PKT_BIND_NACK, DataRepresentation);
187   header->common.frag_len = sizeof(header->bind_nack);
188   header->bind_nack.protocols_count = 1;
189   header->bind_nack.protocols[0].rpc_ver = RpcVersion;
190   header->bind_nack.protocols[0].rpc_ver_minor = RpcVersionMinor;
191
192   return header;
193 }
194
195 RpcPktHdr *RPCRT4_BuildBindAckHeader(unsigned long DataRepresentation,
196                                      unsigned short MaxTransmissionSize,
197                                      unsigned short MaxReceiveSize,
198                                      LPSTR ServerAddress,
199                                      unsigned long Result,
200                                      unsigned long Reason,
201                                      RPC_SYNTAX_IDENTIFIER *TransferId)
202 {
203   RpcPktHdr *header;
204   unsigned long header_size;
205   RpcAddressString *server_address;
206   RpcResults *results;
207   RPC_SYNTAX_IDENTIFIER *transfer_id;
208
209   header_size = sizeof(header->bind_ack) + sizeof(RpcResults) +
210                 sizeof(RPC_SYNTAX_IDENTIFIER) + sizeof(RpcAddressString) +
211                 strlen(ServerAddress);
212
213   header = HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, header_size);
214   if (header == NULL) {
215     return NULL;
216   }
217
218   RPCRT4_BuildCommonHeader(header, PKT_BIND_ACK, DataRepresentation);
219   header->common.frag_len = header_size;
220   header->bind_ack.max_tsize = MaxTransmissionSize;
221   header->bind_ack.max_rsize = MaxReceiveSize;
222   server_address = (RpcAddressString*)(&header->bind_ack + 1);
223   server_address->length = strlen(ServerAddress) + 1;
224   strcpy(server_address->string, ServerAddress);
225   results = (RpcResults*)((ULONG_PTR)server_address + sizeof(RpcAddressString) + server_address->length - 1);
226   results->num_results = 1;
227   results->results[0].result = Result;
228   results->results[0].reason = Reason;
229   transfer_id = (RPC_SYNTAX_IDENTIFIER*)(results + 1);
230   memcpy(transfer_id, TransferId, sizeof(RPC_SYNTAX_IDENTIFIER));
231
232   return header;
233 }
234
235 VOID RPCRT4_FreeHeader(RpcPktHdr *Header)
236 {
237   HeapFree(GetProcessHeap(), 0, Header);
238 }
239
240 /***********************************************************************
241  *           RPCRT4_Send (internal)
242  * 
243  * Transmit a packet over connection in acceptable fragments.
244  */
245 RPC_STATUS RPCRT4_Send(RpcConnection *Connection, RpcPktHdr *Header,
246                        void *Buffer, unsigned int BufferLength)
247 {
248   PUCHAR buffer_pos;
249   DWORD hdr_size, count;
250
251   buffer_pos = Buffer;
252   /* The packet building functions save the packet header size, so we can use it. */
253   hdr_size = Header->common.frag_len;
254   Header->common.flags |= RPC_FLG_FIRST;
255   Header->common.flags &= ~RPC_FLG_LAST;
256   while (!(Header->common.flags & RPC_FLG_LAST)) {    
257     /* decide if we need to split the packet into fragments */
258     if ((BufferLength + hdr_size) <= Connection->MaxTransmissionSize) {
259       Header->common.flags |= RPC_FLG_LAST;
260       Header->common.frag_len = BufferLength + hdr_size;
261     } else {
262       Header->common.frag_len = Connection->MaxTransmissionSize;
263       buffer_pos += Header->common.frag_len - hdr_size;
264       BufferLength -= Header->common.frag_len - hdr_size;
265     }
266
267     /* transmit packet header */
268     if (!WriteFile(Connection->conn, Header, hdr_size, &count, NULL)) {
269       WARN("WriteFile failed with error %ld\n", GetLastError());
270       return GetLastError();
271     }
272
273     /* fragment consisted of header only and is the last one */
274     if (hdr_size == Header->common.frag_len &&
275         Header->common.flags & RPC_FLG_LAST) {
276       return RPC_S_OK;
277     }
278
279     /* send the fragment data */
280     if (!WriteFile(Connection->conn, buffer_pos, Header->common.frag_len - hdr_size, &count, NULL)) {
281       WARN("WriteFile failed with error %ld\n", GetLastError());
282       return GetLastError();
283     }
284
285     Header->common.flags &= ~RPC_FLG_FIRST;
286   }
287
288   return RPC_S_OK;
289 }
290
291 /***********************************************************************
292  *           RPCRT4_Receive (internal)
293  * 
294  * Receive a packet from connection and merge the fragments.
295  */
296 RPC_STATUS RPCRT4_Receive(RpcConnection *Connection, RpcPktHdr **Header,
297                           PRPC_MESSAGE pMsg)
298 {
299   RPC_STATUS status;
300   DWORD dwRead, hdr_length;
301   unsigned short first_flag;
302   unsigned long data_length;
303   unsigned long buffer_length;
304   unsigned char *buffer_ptr;
305   RpcPktCommonHdr common_hdr;
306
307   *Header = NULL;
308
309   TRACE("(%p, %p, %p)\n", Connection, Header, pMsg);
310
311   /* read packet common header */
312   if (!ReadFile(Connection->conn, &common_hdr, sizeof(common_hdr), &dwRead, NULL)) {
313     if (GetLastError() != ERROR_MORE_DATA) {
314       WARN("ReadFile failed with error %ld\n", GetLastError());
315       status = RPC_S_PROTOCOL_ERROR;
316       goto fail;
317     }
318   }
319   if (dwRead != sizeof(common_hdr)) {
320     status = RPC_S_PROTOCOL_ERROR;
321     goto fail;
322   }
323
324   /* verify if the header really makes sense */
325   if (common_hdr.rpc_ver != RPC_VER_MAJOR ||
326       common_hdr.rpc_ver_minor != RPC_VER_MINOR) {
327     WARN("unhandled packet version\n");
328     status = RPC_S_PROTOCOL_ERROR;
329     goto fail;
330   }
331
332   hdr_length = RPCRT4_GetHeaderSize((RpcPktHdr*)&common_hdr);
333   if (hdr_length == 0) {
334     status = RPC_S_PROTOCOL_ERROR;
335     goto fail;
336   }
337
338   *Header = HeapAlloc(GetProcessHeap(), 0, hdr_length);
339   memcpy(*Header, &common_hdr, sizeof(common_hdr));
340
341   /* read the rest of packet header */
342   if (!ReadFile(Connection->conn, &(*Header)->common + 1,
343                 hdr_length - sizeof(common_hdr), &dwRead, NULL)) {
344     if (GetLastError() != ERROR_MORE_DATA) {
345       WARN("ReadFile failed with error %ld\n", GetLastError());
346       status = RPC_S_PROTOCOL_ERROR;
347       goto fail;
348     }
349   }
350   if (dwRead != hdr_length - sizeof(common_hdr)) {
351     status = RPC_S_PROTOCOL_ERROR;
352     goto fail;
353   }
354
355   /* read packet body */
356   switch (common_hdr.ptype) {
357   case PKT_RESPONSE:
358     pMsg->BufferLength = (*Header)->response.alloc_hint;
359     break;
360   case PKT_REQUEST:
361     pMsg->BufferLength = (*Header)->request.alloc_hint;
362     break;
363   default:
364     pMsg->BufferLength = common_hdr.frag_len - hdr_length;
365   }
366   status = I_RpcGetBuffer(pMsg);
367   if (status != RPC_S_OK) goto fail;
368
369   first_flag = RPC_FLG_FIRST;
370   buffer_length = 0;
371   buffer_ptr = pMsg->Buffer;
372   while (buffer_length < pMsg->BufferLength)
373   {
374     data_length = (*Header)->common.frag_len - hdr_length;
375     if (((*Header)->common.flags & RPC_FLG_FIRST) != first_flag ||
376         data_length + buffer_length > pMsg->BufferLength) {
377       TRACE("invalid packet flags or buffer length\n");
378       status = RPC_S_PROTOCOL_ERROR;
379       goto fail;
380     }
381
382     if (data_length == 0) dwRead = 0; else
383     if (!ReadFile(Connection->conn, buffer_ptr, data_length, &dwRead, NULL)) {
384       if (GetLastError() != ERROR_MORE_DATA) {
385         WARN("ReadFile failed with error %ld\n", GetLastError());
386         status = RPC_S_PROTOCOL_ERROR;
387         goto fail;
388       }
389     }
390     if (dwRead != data_length) {
391       status = RPC_S_PROTOCOL_ERROR;
392       goto fail;
393     }
394
395     if (buffer_length == pMsg->BufferLength &&
396         ((*Header)->common.flags & RPC_FLG_LAST) == 0) {
397       status = RPC_S_PROTOCOL_ERROR;
398       goto fail;
399     }
400
401     buffer_length += data_length;
402     if (buffer_length < pMsg->BufferLength) {
403       TRACE("next header\n");
404
405       /* read the header of next packet */
406       if (!ReadFile(Connection->conn, *Header, hdr_length, &dwRead, NULL)) {
407         if (GetLastError() != ERROR_MORE_DATA) {
408           WARN("ReadFile failed with error %ld\n", GetLastError());
409           status = GetLastError();
410           goto fail;
411         }
412       }
413       if (dwRead != hdr_length) {
414         WARN("invalid packet header size (%ld)\n", dwRead);
415         status = RPC_S_PROTOCOL_ERROR;
416         goto fail;
417       }
418
419       buffer_ptr += data_length;
420       first_flag = 0;
421     }
422   }
423
424   /* success */
425   status = RPC_S_OK;
426
427 fail:
428   if (status != RPC_S_OK && *Header) {
429     RPCRT4_FreeHeader(*Header);
430     *Header = NULL;
431   }
432   return status;
433 }
434
435 /***********************************************************************
436  *           I_RpcGetBuffer [RPCRT4.@]
437  */
438 RPC_STATUS WINAPI I_RpcGetBuffer(PRPC_MESSAGE pMsg)
439 {
440   RpcBinding* bind = (RpcBinding*)pMsg->Handle;
441
442   TRACE("(%p): BufferLength=%d\n", pMsg, pMsg->BufferLength);
443   /* FIXME: pfnAllocate? */
444   if (bind->server) {
445     /* it turns out that the original buffer data must still be available
446      * while the RPC server is marshalling a reply, so we should not deallocate
447      * it, we'll leave deallocating the original buffer to the RPC server */
448     pMsg->Buffer = HeapAlloc(GetProcessHeap(), 0, pMsg->BufferLength);
449   } else {
450     HeapFree(GetProcessHeap(), 0, pMsg->Buffer);
451     pMsg->Buffer = HeapAlloc(GetProcessHeap(), 0, pMsg->BufferLength);
452   }
453   TRACE("Buffer=%p\n", pMsg->Buffer);
454   /* FIXME: which errors to return? */
455   return pMsg->Buffer ? S_OK : E_OUTOFMEMORY;
456 }
457
458 /***********************************************************************
459  *           I_RpcFreeBuffer [RPCRT4.@]
460  */
461 RPC_STATUS WINAPI I_RpcFreeBuffer(PRPC_MESSAGE pMsg)
462 {
463   TRACE("(%p) Buffer=%p\n", pMsg, pMsg->Buffer);
464   /* FIXME: pfnFree? */
465   HeapFree(GetProcessHeap(), 0, pMsg->Buffer);
466   pMsg->Buffer = NULL;
467   return S_OK;
468 }
469
470 /***********************************************************************
471  *           I_RpcSend [RPCRT4.@]
472  */
473 RPC_STATUS WINAPI I_RpcSend(PRPC_MESSAGE pMsg)
474 {
475   RpcBinding* bind = (RpcBinding*)pMsg->Handle;
476   RpcConnection* conn;
477   RPC_CLIENT_INTERFACE* cif = NULL;
478   RPC_SERVER_INTERFACE* sif = NULL;
479   RPC_STATUS status;
480   RpcPktHdr *hdr;
481
482   TRACE("(%p)\n", pMsg);
483   if (!bind) return RPC_S_INVALID_BINDING;
484
485   if (bind->server) {
486     sif = pMsg->RpcInterfaceInformation;
487     if (!sif) return RPC_S_INTERFACE_NOT_FOUND; /* ? */
488     status = RPCRT4_OpenBinding(bind, &conn, &sif->TransferSyntax,
489                                 &sif->InterfaceId);
490   } else {
491     cif = pMsg->RpcInterfaceInformation;
492     if (!cif) return RPC_S_INTERFACE_NOT_FOUND; /* ? */
493     status = RPCRT4_OpenBinding(bind, &conn, &cif->TransferSyntax,
494                                 &cif->InterfaceId);
495   }
496
497   if (status != RPC_S_OK) return status;
498
499   if (bind->server) {
500     if (pMsg->RpcFlags & WINE_RPCFLAG_EXCEPTION) {
501       hdr = RPCRT4_BuildFaultHeader(pMsg->DataRepresentation,
502                                     RPC_S_CALL_FAILED);
503     } else {
504       hdr = RPCRT4_BuildResponseHeader(pMsg->DataRepresentation,
505                                        pMsg->BufferLength);
506     }
507   } else {
508     hdr = RPCRT4_BuildRequestHeader(pMsg->DataRepresentation,
509                                     pMsg->BufferLength, pMsg->ProcNum,
510                                     &bind->ObjectUuid);
511   }
512
513   status = RPCRT4_Send(conn, hdr, pMsg->Buffer, pMsg->BufferLength);
514
515   RPCRT4_FreeHeader(hdr);
516
517   /* success */
518   if (!bind->server) {
519     /* save the connection, so the response can be read from it */
520     pMsg->ReservedForRuntime = conn;
521     return RPC_S_OK;
522   }
523   RPCRT4_CloseBinding(bind, conn);
524   status = RPC_S_OK;
525
526   return status;
527 }
528
529 /***********************************************************************
530  *           I_RpcReceive [RPCRT4.@]
531  */
532 RPC_STATUS WINAPI I_RpcReceive(PRPC_MESSAGE pMsg)
533 {
534   RpcBinding* bind = (RpcBinding*)pMsg->Handle;
535   RpcConnection* conn;
536   RPC_CLIENT_INTERFACE* cif = NULL;
537   RPC_SERVER_INTERFACE* sif = NULL;
538   RPC_STATUS status;
539   RpcPktHdr *hdr = NULL;
540
541   TRACE("(%p)\n", pMsg);
542   if (!bind) return RPC_S_INVALID_BINDING;
543
544   if (pMsg->ReservedForRuntime) {
545     conn = pMsg->ReservedForRuntime;
546     pMsg->ReservedForRuntime = NULL;
547   } else {
548     if (bind->server) {
549       sif = pMsg->RpcInterfaceInformation;
550       if (!sif) return RPC_S_INTERFACE_NOT_FOUND; /* ? */
551       status = RPCRT4_OpenBinding(bind, &conn, &sif->TransferSyntax,
552                                   &sif->InterfaceId);
553     } else {
554       cif = pMsg->RpcInterfaceInformation;
555       if (!cif) return RPC_S_INTERFACE_NOT_FOUND; /* ? */
556       status = RPCRT4_OpenBinding(bind, &conn, &cif->TransferSyntax,
557                                   &cif->InterfaceId);
558     }
559     if (status != RPC_S_OK) return status;
560   }
561
562   status = RPCRT4_Receive(conn, &hdr, pMsg);
563   if (status != RPC_S_OK) {
564     WARN("receive failed with error %lx\n", status);
565     goto fail;
566   }
567
568   status = RPC_S_PROTOCOL_ERROR;
569
570   switch (hdr->common.ptype) {
571   case PKT_RESPONSE:
572     if (bind->server) goto fail;
573     break;
574   case PKT_REQUEST:
575     if (!bind->server) goto fail;
576     break;
577   case PKT_FAULT:
578     pMsg->RpcFlags |= WINE_RPCFLAG_EXCEPTION;
579     ERR ("we got fault packet with status %lx\n", hdr->fault.status);
580     status = RPC_S_CALL_FAILED; /* ? */
581     goto fail;
582   default:
583     goto fail;
584   }
585
586   /* success */
587   status = RPC_S_OK;
588
589 fail:
590   if (hdr) {
591     RPCRT4_FreeHeader(hdr);
592   }
593   RPCRT4_CloseBinding(bind, conn);
594   return status;
595 }
596
597 /***********************************************************************
598  *           I_RpcSendReceive [RPCRT4.@]
599  */
600 RPC_STATUS WINAPI I_RpcSendReceive(PRPC_MESSAGE pMsg)
601 {
602   RPC_STATUS status;
603
604   TRACE("(%p)\n", pMsg);
605   status = I_RpcSend(pMsg);
606   if (status == RPC_S_OK)
607     status = I_RpcReceive(pMsg);
608   return status;
609 }