ole32/tests: Fix crashes in usrmarshal.
[wine] / dlls / ole32 / git.c
1 /*
2  * Implementation of the StdGlobalInterfaceTable object
3  *
4  * The GlobalInterfaceTable (GIT) object is used to marshal interfaces between
5  * threading apartments (contexts). When you want to pass an interface but not
6  * as a parameter, it wouldn't get marshalled automatically, so you can use this
7  * object to insert the interface into a table, and you get back a cookie.
8  * Then when it's retrieved, it'll be unmarshalled into the right apartment.
9  *
10  * Copyright 2003 Mike Hearn <mike@theoretic.com>
11  *
12  * This library is free software; you can redistribute it and/or
13  * modify it under the terms of the GNU Lesser General Public
14  * License as published by the Free Software Foundation; either
15  * version 2.1 of the License, or (at your option) any later version.
16  *
17  * This library is distributed in the hope that it will be useful,
18  * but WITHOUT ANY WARRANTY; without even the implied warranty of
19  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
20  * Lesser General Public License for more details.
21  *
22  * You should have received a copy of the GNU Lesser General Public
23  * License along with this library; if not, write to the Free Software
24  * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
25  */
26
27 #include <stdarg.h>
28
29 #define COBJMACROS
30 #define NONAMELESSUNION
31 #define NONAMELESSSTRUCT
32
33 #include "windef.h"
34 #include "winbase.h"
35 #include "winuser.h"
36 #include "objbase.h"
37 #include "ole2.h"
38 #include "winerror.h"
39
40 #include "compobj_private.h" 
41
42 #include "wine/list.h"
43 #include "wine/debug.h"
44
45 WINE_DEFAULT_DEBUG_CHANNEL(ole);
46
47 /****************************************************************************
48  * StdGlobalInterfaceTable definition
49  *
50  * This class implements IGlobalInterfaceTable and is a process-wide singleton
51  * used for marshalling interfaces between threading apartments using cookies.
52  */
53
54 /* Each entry in the linked list of GIT entries */
55 typedef struct StdGITEntry
56 {
57   DWORD cookie;
58   IID iid;         /* IID of the interface */
59   IStream* stream; /* Holds the marshalled interface */
60
61   struct list entry;
62 } StdGITEntry;
63
64 /* Class data */
65 typedef struct StdGlobalInterfaceTableImpl
66 {
67   IGlobalInterfaceTable IGlobalInterfaceTable_iface;
68
69   ULONG ref;
70   struct list list;
71   ULONG nextCookie;
72   
73 } StdGlobalInterfaceTableImpl;
74
75 void* StdGlobalInterfaceTableInstance;
76
77 static CRITICAL_SECTION git_section;
78 static CRITICAL_SECTION_DEBUG critsect_debug =
79 {
80     0, 0, &git_section,
81     { &critsect_debug.ProcessLocksList, &critsect_debug.ProcessLocksList },
82       0, 0, { (DWORD_PTR)(__FILE__ ": global interface table") }
83 };
84 static CRITICAL_SECTION git_section = { &critsect_debug, -1, 0, 0, 0, 0 };
85
86
87 static inline StdGlobalInterfaceTableImpl *impl_from_IGlobalInterfaceTable(IGlobalInterfaceTable *iface)
88 {
89   return CONTAINING_RECORD(iface, StdGlobalInterfaceTableImpl, IGlobalInterfaceTable_iface);
90 }
91
92 /** This destroys it again. It should revoke all the held interfaces first **/
93 static void StdGlobalInterfaceTable_Destroy(void* This)
94 {
95   TRACE("(%p)\n", This);
96   FIXME("Revoke held interfaces here\n");
97   
98   HeapFree(GetProcessHeap(), 0, This);
99   StdGlobalInterfaceTableInstance = NULL;
100 }
101
102 /***
103  * A helper function to traverse the list and find the entry that matches the cookie.
104  * Returns NULL if not found. Must be called inside git_section critical section.
105  */
106 static StdGITEntry* StdGlobalInterfaceTable_FindEntry(StdGlobalInterfaceTableImpl* This,
107                 DWORD cookie)
108 {
109   StdGITEntry* e;
110
111   TRACE("This=%p, cookie=0x%x\n", This, cookie);
112
113   LIST_FOR_EACH_ENTRY(e, &This->list, StdGITEntry, entry) {
114     if (e->cookie == cookie)
115       return e;
116   }
117
118   TRACE("Entry not found\n");
119   return NULL;
120 }
121
122 /***
123  * Here's the boring boilerplate stuff for IUnknown
124  */
125
126 static HRESULT WINAPI
127 StdGlobalInterfaceTable_QueryInterface(IGlobalInterfaceTable* iface,
128                REFIID riid, void** ppvObject)
129 {
130   /* Make sure silly coders can't crash us */
131   if (ppvObject == 0) return E_INVALIDARG;
132
133   *ppvObject = 0; /* assume we don't have the interface */
134
135   /* Do we implement that interface? */
136   if (IsEqualIID(&IID_IUnknown, riid) ||
137       IsEqualIID(&IID_IGlobalInterfaceTable, riid))
138     *ppvObject = iface;
139   else
140     return E_NOINTERFACE;
141
142   /* Now inc the refcount */
143   IGlobalInterfaceTable_AddRef(iface);
144   return S_OK;
145 }
146
147 static ULONG WINAPI
148 StdGlobalInterfaceTable_AddRef(IGlobalInterfaceTable* iface)
149 {
150   StdGlobalInterfaceTableImpl* const This = impl_from_IGlobalInterfaceTable(iface);
151
152   /* InterlockedIncrement(&This->ref); */
153   return This->ref;
154 }
155
156 static ULONG WINAPI
157 StdGlobalInterfaceTable_Release(IGlobalInterfaceTable* iface)
158 {
159   StdGlobalInterfaceTableImpl* const This = impl_from_IGlobalInterfaceTable(iface);
160
161   /* InterlockedDecrement(&This->ref); */
162   if (This->ref == 0) {
163     /* Hey ho, it's time to go, so long again 'till next weeks show! */
164     StdGlobalInterfaceTable_Destroy(This);
165     return 0;
166   }
167
168   return This->ref;
169 }
170
171 /***
172  * Now implement the actual IGlobalInterfaceTable interface
173  */
174
175 static HRESULT WINAPI
176 StdGlobalInterfaceTable_RegisterInterfaceInGlobal(
177                IGlobalInterfaceTable* iface, IUnknown* pUnk,
178                REFIID riid, DWORD* pdwCookie)
179 {
180   StdGlobalInterfaceTableImpl* const This = impl_from_IGlobalInterfaceTable(iface);
181   IStream* stream = NULL;
182   HRESULT hres;
183   StdGITEntry* entry;
184   LARGE_INTEGER zero;
185
186   TRACE("iface=%p, pUnk=%p, riid=%s, pdwCookie=0x%p\n", iface, pUnk, debugstr_guid(riid), pdwCookie);
187
188   if (pUnk == NULL) return E_INVALIDARG;
189   
190   /* marshal the interface */
191   TRACE("About to marshal the interface\n");
192
193   hres = CreateStreamOnHGlobal(0, TRUE, &stream);
194   if (hres != S_OK) return hres;
195   hres = CoMarshalInterface(stream, riid, pUnk, MSHCTX_INPROC, NULL, MSHLFLAGS_TABLESTRONG);
196   if (hres != S_OK)
197   {
198     IStream_Release(stream);
199     return hres;
200   }
201
202   zero.QuadPart = 0;
203   IStream_Seek(stream, zero, STREAM_SEEK_SET, NULL);
204
205   entry = HeapAlloc(GetProcessHeap(), 0, sizeof(StdGITEntry));
206   if (entry == NULL) return E_OUTOFMEMORY;
207
208   EnterCriticalSection(&git_section);
209   
210   entry->iid = *riid;
211   entry->stream = stream;
212   entry->cookie = This->nextCookie;
213   This->nextCookie++; /* inc the cookie count */
214
215   /* insert the new entry at the end of the list */
216   list_add_tail(&This->list, &entry->entry);
217
218   /* and return the cookie */
219   *pdwCookie = entry->cookie;
220   
221   LeaveCriticalSection(&git_section);
222   
223   TRACE("Cookie is 0x%x\n", entry->cookie);
224   return S_OK;
225 }
226
227 static HRESULT WINAPI
228 StdGlobalInterfaceTable_RevokeInterfaceFromGlobal(
229                IGlobalInterfaceTable* iface, DWORD dwCookie)
230 {
231   StdGlobalInterfaceTableImpl* This = impl_from_IGlobalInterfaceTable(iface);
232   StdGITEntry* entry;
233   HRESULT hr;
234
235   TRACE("iface=%p, dwCookie=0x%x\n", iface, dwCookie);
236
237   EnterCriticalSection(&git_section);
238
239   entry = StdGlobalInterfaceTable_FindEntry(This, dwCookie);
240   if (entry == NULL) {
241     TRACE("Entry not found\n");
242     LeaveCriticalSection(&git_section);
243     return E_INVALIDARG; /* not found */
244   }
245
246   list_remove(&entry->entry);
247
248   LeaveCriticalSection(&git_section);
249   
250   /* Free the stream */
251   hr = CoReleaseMarshalData(entry->stream);
252   if (hr != S_OK)
253   {
254     WARN("Failed to release marshal data, hr = 0x%08x\n", hr);
255     return hr;
256   }
257   IStream_Release(entry->stream);
258                     
259   HeapFree(GetProcessHeap(), 0, entry);
260   return S_OK;
261 }
262
263 static HRESULT WINAPI
264 StdGlobalInterfaceTable_GetInterfaceFromGlobal(
265                IGlobalInterfaceTable* iface, DWORD dwCookie,
266                REFIID riid, void **ppv)
267 {
268   StdGlobalInterfaceTableImpl* This = impl_from_IGlobalInterfaceTable(iface);
269   StdGITEntry* entry;
270   HRESULT hres;
271   IStream *stream;
272
273   TRACE("dwCookie=0x%x, riid=%s, ppv=%p\n", dwCookie, debugstr_guid(riid), ppv);
274
275   EnterCriticalSection(&git_section);
276
277   entry = StdGlobalInterfaceTable_FindEntry(This, dwCookie);
278   if (entry == NULL) {
279     WARN("Entry for cookie 0x%x not found\n", dwCookie);
280     LeaveCriticalSection(&git_section);
281     return E_INVALIDARG;
282   }
283
284   TRACE("entry=%p\n", entry);
285
286   hres = IStream_Clone(entry->stream, &stream);
287
288   LeaveCriticalSection(&git_section);
289
290   if (hres != S_OK) {
291     WARN("Failed to clone stream with error 0x%08x\n", hres);
292     return hres;
293   }
294
295   /* unmarshal the interface */
296   hres = CoUnmarshalInterface(stream, riid, ppv);
297   IStream_Release(stream);
298
299   if (hres) {
300     WARN("Failed to unmarshal stream\n");
301     return hres;
302   }
303
304   TRACE("ppv=%p\n", *ppv);
305   return S_OK;
306 }
307
308 /* Classfactory definition - despite what MSDN says, some programs need this */
309
310 static HRESULT WINAPI
311 GITCF_QueryInterface(LPCLASSFACTORY iface,REFIID riid, LPVOID *ppv)
312 {
313   *ppv = NULL;
314   if (IsEqualIID(riid,&IID_IUnknown) ||
315       IsEqualIID(riid,&IID_IGlobalInterfaceTable))
316   {
317     *ppv = iface;
318     return S_OK;
319   }
320   return E_NOINTERFACE;
321 }
322
323 static ULONG WINAPI GITCF_AddRef(LPCLASSFACTORY iface)
324 {
325   return 2;
326 }
327
328 static ULONG WINAPI GITCF_Release(LPCLASSFACTORY iface)
329 {
330   return 1;
331 }
332
333 static HRESULT WINAPI
334 GITCF_CreateInstance(LPCLASSFACTORY iface, LPUNKNOWN pUnk,
335                      REFIID riid, LPVOID *ppv)
336 {
337   if (IsEqualIID(riid,&IID_IGlobalInterfaceTable)) {
338     if (StdGlobalInterfaceTableInstance == NULL) 
339       StdGlobalInterfaceTableInstance = StdGlobalInterfaceTable_Construct();
340     return IGlobalInterfaceTable_QueryInterface( (IGlobalInterfaceTable*) StdGlobalInterfaceTableInstance, riid, ppv);
341   }
342
343   FIXME("(%s), not supported.\n",debugstr_guid(riid));
344   return E_NOINTERFACE;
345 }
346
347 static HRESULT WINAPI GITCF_LockServer(LPCLASSFACTORY iface, BOOL fLock)
348 {
349     FIXME("(%d), stub!\n",fLock);
350     return S_OK;
351 }
352
353 static const IClassFactoryVtbl GITClassFactoryVtbl = {
354     GITCF_QueryInterface,
355     GITCF_AddRef,
356     GITCF_Release,
357     GITCF_CreateInstance,
358     GITCF_LockServer
359 };
360
361 static const IClassFactoryVtbl *PGITClassFactoryVtbl = &GITClassFactoryVtbl;
362
363 HRESULT StdGlobalInterfaceTable_GetFactory(LPVOID *ppv)
364 {
365   *ppv = &PGITClassFactoryVtbl;
366   TRACE("Returning GIT classfactory\n");
367   return S_OK;
368 }
369
370 /* Virtual function table */
371 static const IGlobalInterfaceTableVtbl StdGlobalInterfaceTableImpl_Vtbl =
372 {
373   StdGlobalInterfaceTable_QueryInterface,
374   StdGlobalInterfaceTable_AddRef,
375   StdGlobalInterfaceTable_Release,
376   StdGlobalInterfaceTable_RegisterInterfaceInGlobal,
377   StdGlobalInterfaceTable_RevokeInterfaceFromGlobal,
378   StdGlobalInterfaceTable_GetInterfaceFromGlobal
379 };
380
381 /** This function constructs the GIT. It should only be called once **/
382 void* StdGlobalInterfaceTable_Construct(void)
383 {
384   StdGlobalInterfaceTableImpl* newGIT;
385
386   newGIT = HeapAlloc(GetProcessHeap(), 0, sizeof(StdGlobalInterfaceTableImpl));
387   if (newGIT == 0) return newGIT;
388
389   newGIT->IGlobalInterfaceTable_iface.lpVtbl = &StdGlobalInterfaceTableImpl_Vtbl;
390   newGIT->ref = 1;      /* Initialise the reference count */
391   list_init(&newGIT->list);
392   newGIT->nextCookie = 0xf100; /* that's where windows starts, so that's where we start */
393   TRACE("Created the GIT at %p\n", newGIT);
394
395   return (void*)newGIT;
396 }