Prevent crashes in I_RpcSend on Windows XP.
[wine] / dlls / rpcrt4 / rpc_message.c
1 /*
2  * RPC messages
3  *
4  * Copyright 2001-2002 Ove Kåven, TransGaming Technologies
5  *
6  * This library is free software; you can redistribute it and/or
7  * modify it under the terms of the GNU Lesser General Public
8  * License as published by the Free Software Foundation; either
9  * version 2.1 of the License, or (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14  * Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public
17  * License along with this library; if not, write to the Free Software
18  * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
19  *
20  * TODO:
21  *  - figure out whether we *really* got this right
22  *  - check for errors and throw exceptions
23  *  - decide if OVERLAPPED_WORKS
24  */
25
26 #include <stdarg.h>
27 #include <stdio.h>
28 #include <string.h>
29
30 #include "windef.h"
31 #include "winbase.h"
32 #include "winerror.h"
33 #include "winreg.h"
34
35 #include "rpc.h"
36 #include "rpcdcep.h"
37
38 #include "wine/debug.h"
39
40 #include "rpc_binding.h"
41 #include "rpc_misc.h"
42 #include "rpc_defs.h"
43
44 WINE_DEFAULT_DEBUG_CHANNEL(ole);
45
46 /***********************************************************************
47  *           I_RpcGetBuffer [RPCRT4.@]
48  */
49 RPC_STATUS WINAPI I_RpcGetBuffer(PRPC_MESSAGE pMsg)
50 {
51   RpcBinding* bind = (RpcBinding*)pMsg->Handle;
52
53   TRACE("(%p): BufferLength=%d\n", pMsg, pMsg->BufferLength);
54   /* FIXME: pfnAllocate? */
55   if (bind->server) {
56     /* it turns out that the original buffer data must still be available
57      * while the RPC server is marshalling a reply, so we should not deallocate
58      * it, we'll leave deallocating the original buffer to the RPC server */
59     pMsg->Buffer = HeapAlloc(GetProcessHeap(), 0, pMsg->BufferLength);
60   } else {
61     if (pMsg->Buffer)
62         HeapFree(GetProcessHeap(), 0, pMsg->Buffer);
63     pMsg->Buffer = HeapAlloc(GetProcessHeap(), 0, pMsg->BufferLength);
64   }
65   TRACE("Buffer=%p\n", pMsg->Buffer);
66   /* FIXME: which errors to return? */
67   return pMsg->Buffer ? S_OK : E_OUTOFMEMORY;
68 }
69
70 /***********************************************************************
71  *           I_RpcFreeBuffer [RPCRT4.@]
72  */
73 RPC_STATUS WINAPI I_RpcFreeBuffer(PRPC_MESSAGE pMsg)
74 {
75   TRACE("(%p) Buffer=%p\n", pMsg, pMsg->Buffer);
76   /* FIXME: pfnFree? */
77   HeapFree(GetProcessHeap(), 0, pMsg->Buffer);
78   pMsg->Buffer = NULL;
79   return S_OK;
80 }
81
82 /***********************************************************************
83  *           I_RpcSend [RPCRT4.@]
84  */
85 RPC_STATUS WINAPI I_RpcSend(PRPC_MESSAGE pMsg)
86 {
87   RpcBinding* bind = (RpcBinding*)pMsg->Handle;
88   RpcConnection* conn;
89   RPC_CLIENT_INTERFACE* cif = NULL;
90   RPC_SERVER_INTERFACE* sif = NULL;
91   UUID* obj;
92   UUID* act;
93   RPC_STATUS status;
94   RpcPktHdr hdr;
95   DWORD count;
96
97   TRACE("(%p)\n", pMsg);
98   if (!bind) return RPC_S_INVALID_BINDING;
99
100   status = RPCRT4_OpenBinding(bind, &conn);
101   if (status != RPC_S_OK) return status;
102
103   obj = &bind->ObjectUuid;
104   act = &bind->ActiveUuid;
105
106   if (bind->server) {
107     sif = pMsg->RpcInterfaceInformation;
108     if (!sif) return RPC_S_INTERFACE_NOT_FOUND; /* ? */
109   } else {
110     cif = pMsg->RpcInterfaceInformation;
111     if (!cif) return RPC_S_INTERFACE_NOT_FOUND; /* ? */
112   }
113
114   /* initialize packet header */
115   memset(&hdr, 0, sizeof(hdr));
116   hdr.rpc_ver = 4;
117   hdr.ptype = bind->server
118               ? ((pMsg->RpcFlags & WINE_RPCFLAG_EXCEPTION) ? PKT_FAULT : PKT_RESPONSE)
119               : PKT_REQUEST;
120   hdr.object = *obj; /* FIXME: IIRC iff no object, the header structure excludes this elt */
121   hdr.if_id = (bind->server) ? sif->InterfaceId.SyntaxGUID : cif->InterfaceId.SyntaxGUID;
122   hdr.if_vers = 
123     (bind->server) ?
124     MAKELONG(sif->InterfaceId.SyntaxVersion.MinorVersion, sif->InterfaceId.SyntaxVersion.MajorVersion) :
125     MAKELONG(cif->InterfaceId.SyntaxVersion.MinorVersion, cif->InterfaceId.SyntaxVersion.MajorVersion);
126   hdr.act_id = *act;
127   hdr.opnum = pMsg->ProcNum;
128   /* only the low-order 3 octets of the DataRepresentation go in the header */
129   hdr.drep[0] = LOBYTE(LOWORD(pMsg->DataRepresentation));
130   hdr.drep[1] = HIBYTE(LOWORD(pMsg->DataRepresentation));
131   hdr.drep[2] = LOBYTE(HIWORD(pMsg->DataRepresentation));
132   hdr.len = pMsg->BufferLength;
133
134   /* transmit packet */
135   if (!WriteFile(conn->conn, &hdr, sizeof(hdr), &count, NULL)) {
136     WARN("WriteFile failed with error %ld\n", GetLastError());
137     status = RPC_S_PROTOCOL_ERROR;
138     goto fail;
139   }
140   
141   if (!pMsg->BufferLength)
142   {
143     status = RPC_S_OK;
144     goto fail;
145   }
146  
147   if (!WriteFile(conn->conn, pMsg->Buffer, pMsg->BufferLength, &count, NULL)) {
148     WARN("WriteFile failed with error %ld\n", GetLastError());
149     status = RPC_S_PROTOCOL_ERROR;
150     goto fail;
151   }
152
153   /* success */
154   if (!bind->server) {
155     /* save the connection, so the response can be read from it */
156     pMsg->ReservedForRuntime = conn;
157     return RPC_S_OK;
158   }
159   RPCRT4_CloseBinding(bind, conn);
160   status = RPC_S_OK;
161 fail:
162
163   return status;
164 }
165
166 /***********************************************************************
167  *           I_RpcReceive [RPCRT4.@]
168  */
169 RPC_STATUS WINAPI I_RpcReceive(PRPC_MESSAGE pMsg)
170 {
171   RpcBinding* bind = (RpcBinding*)pMsg->Handle;
172   RpcConnection* conn;
173   UUID* act;
174   RPC_STATUS status;
175   RpcPktHdr hdr;
176   DWORD dwRead;
177
178   TRACE("(%p)\n", pMsg);
179   if (!bind) return RPC_S_INVALID_BINDING;
180
181   if (pMsg->ReservedForRuntime) {
182     conn = pMsg->ReservedForRuntime;
183     pMsg->ReservedForRuntime = NULL;
184   } else {
185     status = RPCRT4_OpenBinding(bind, &conn);
186     if (status != RPC_S_OK) return status;
187   }
188
189   act = &bind->ActiveUuid;
190
191   for (;;) {
192     /* read packet header */
193 #ifdef OVERLAPPED_WORKS
194     if (!ReadFile(conn->conn, &hdr, sizeof(hdr), &dwRead, &conn->ovl)) {
195       DWORD err = GetLastError();
196       if (err != ERROR_IO_PENDING) {
197         WARN("ReadFile failed with error %ld\n", err);
198         status = RPC_S_PROTOCOL_ERROR;
199         goto fail;
200       }
201       if (!GetOverlappedResult(conn->conn, &conn->ovl, &dwRead, TRUE)) {
202         WARN("ReadFile failed with error %ld\n", GetLastError());
203         status = RPC_S_PROTOCOL_ERROR;
204         goto fail;
205       }
206     }
207 #else
208     if (!ReadFile(conn->conn, &hdr, sizeof(hdr), &dwRead, NULL)) {
209       WARN("ReadFile failed with error %ld\n", GetLastError());
210       status = RPC_S_PROTOCOL_ERROR;
211       goto fail;
212     }
213 #endif
214     if (dwRead != sizeof(hdr)) {
215       status = RPC_S_PROTOCOL_ERROR;
216       goto fail;
217     }
218
219     /* read packet body */
220     pMsg->BufferLength = hdr.len;
221     status = I_RpcGetBuffer(pMsg);
222     if (status != RPC_S_OK) goto fail;
223     if (!pMsg->BufferLength) dwRead = 0; else
224 #ifdef OVERLAPPED_WORKS
225     if (!ReadFile(conn->conn, pMsg->Buffer, hdr.len, &dwRead, &conn->ovl)) {
226       if (GetLastError() != ERROR_IO_PENDING) {
227         WARN("ReadFile failed with error %ld\n", GetLastError());
228         status = RPC_S_PROTOCOL_ERROR;
229         goto fail;
230       }
231       if (!GetOverlappedResult(conn->conn, &conn->ovl, &dwRead, TRUE)) {
232         WARN("ReadFile failed with error %ld\n", GetLastError());
233         status = RPC_S_PROTOCOL_ERROR;
234         goto fail;
235       }
236     }
237 #else
238     if (!ReadFile(conn->conn, pMsg->Buffer, hdr.len, &dwRead, NULL)) {
239       WARN("ReadFile failed with error %ld\n", GetLastError());
240       status = RPC_S_PROTOCOL_ERROR;
241       goto fail;
242     }
243 #endif
244     if (dwRead != hdr.len) {
245       status = RPC_S_PROTOCOL_ERROR;
246       goto fail;
247     }
248
249     status = RPC_S_PROTOCOL_ERROR;
250
251     switch (hdr.ptype) {
252     case PKT_RESPONSE:
253       if (bind->server) goto fail;
254       break;
255     case PKT_REQUEST:
256       if (!bind->server) goto fail;
257       break;
258     case PKT_FAULT:
259       pMsg->RpcFlags |= WINE_RPCFLAG_EXCEPTION;
260       status = RPC_S_CALL_FAILED; /* ? */
261       goto fail;
262     default:
263       goto fail;
264     }
265
266     /* success */
267     status = RPC_S_OK;
268
269     /* FIXME: check destination, etc? */
270     break;
271   }
272 fail:
273   RPCRT4_CloseBinding(bind, conn);
274   return status;
275 }
276
277 /***********************************************************************
278  *           I_RpcSendReceive [RPCRT4.@]
279  */
280 RPC_STATUS WINAPI I_RpcSendReceive(PRPC_MESSAGE pMsg)
281 {
282   RPC_STATUS status;
283
284   TRACE("(%p)\n", pMsg);
285   status = I_RpcSend(pMsg);
286   if (status == RPC_S_OK)
287     status = I_RpcReceive(pMsg);
288   return status;
289 }