ole32: Remove some assertions in the stuctured storage code by
[wine] / dlls / msi / join.c
1 /*
2  * Implementation of the Microsoft Installer (msi.dll)
3  *
4  * Copyright 2006 Mike McCormack for CodeWeavers
5  *
6  * This library is free software; you can redistribute it and/or
7  * modify it under the terms of the GNU Lesser General Public
8  * License as published by the Free Software Foundation; either
9  * version 2.1 of the License, or (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14  * Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public
17  * License along with this library; if not, write to the Free Software
18  * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
19  */
20
21 #include <stdarg.h>
22
23 #include "windef.h"
24 #include "winbase.h"
25 #include "winerror.h"
26 #include "wine/debug.h"
27 #include "msi.h"
28 #include "msiquery.h"
29 #include "objbase.h"
30 #include "objidl.h"
31 #include "msipriv.h"
32 #include "query.h"
33
34 WINE_DEFAULT_DEBUG_CHANNEL(msidb);
35
36 typedef struct tagMSIJOINVIEW
37 {
38     MSIVIEW        view;
39     MSIDATABASE   *db;
40     MSIVIEW       *left, *right;
41     UINT           left_count, right_count;
42     UINT           left_key, right_key;
43     UINT          *pairs;
44     UINT           pair_count;
45 } MSIJOINVIEW;
46
47 static UINT JOIN_fetch_int( struct tagMSIVIEW *view, UINT row, UINT col, UINT *val )
48 {
49     MSIJOINVIEW *jv = (MSIJOINVIEW*)view;
50     MSIVIEW *table;
51
52     TRACE("%p %d %d %p\n", jv, row, col, val );
53
54     if( !jv->left || !jv->right )
55          return ERROR_FUNCTION_FAILED;
56
57     if( (col==0) || (col>(jv->left_count + jv->right_count)) )
58          return ERROR_FUNCTION_FAILED;
59
60     if( row >= jv->pair_count )
61          return ERROR_FUNCTION_FAILED;
62
63     if( col <= jv->left_count )
64     {
65         table = jv->left;
66         row = jv->pairs[ row*2 ];
67     }
68     else
69     {
70         table = jv->right;
71         row = jv->pairs[ row*2 + 1 ];
72         col -= jv->left_count;
73     }
74
75     return table->ops->fetch_int( table, row, col, val );
76 }
77
78 static UINT JOIN_fetch_stream( struct tagMSIVIEW *view, UINT row, UINT col, IStream **stm)
79 {
80     MSIJOINVIEW *jv = (MSIJOINVIEW*)view;
81     MSIVIEW *table;
82
83     TRACE("%p %d %d %p\n", jv, row, col, stm );
84
85     if( !jv->left || !jv->right )
86          return ERROR_FUNCTION_FAILED;
87
88     if( (col==0) || (col>(jv->left_count + jv->right_count)) )
89          return ERROR_FUNCTION_FAILED;
90
91     if( row <= jv->left_count )
92     {
93         table = jv->left;
94         row = jv->pairs[ row*2 ];
95     }
96     else
97     {
98         table = jv->right;
99         row = jv->pairs[ row*2 + 1 ];
100         col -= jv->left_count;
101     }
102
103     return table->ops->fetch_stream( table, row, col, stm );
104 }
105
106 static UINT JOIN_set_int( struct tagMSIVIEW *view, UINT row, UINT col, UINT val )
107 {
108     MSIJOINVIEW *jv = (MSIJOINVIEW*)view;
109
110     TRACE("%p %d %d %04x\n", jv, row, col, val );
111
112     return ERROR_FUNCTION_FAILED;
113 }
114
115 static UINT JOIN_insert_row( struct tagMSIVIEW *view, MSIRECORD *record )
116 {
117     MSIJOINVIEW *jv = (MSIJOINVIEW*)view;
118
119     TRACE("%p %p\n", jv, record );
120
121     return ERROR_FUNCTION_FAILED;
122 }
123
124 static int join_key_compare(const void *l, const void *r)
125 {
126     const UINT *left = l, *right = r;
127     if (left[1] < right[1])
128         return -1;
129     if (left[1] == right[1])
130         return 0;
131     return 1;
132 }
133
134 static UINT join_load_key_column( MSIJOINVIEW *jv, MSIVIEW *table, UINT column,
135                                   UINT **pdata, UINT *pcount )
136 {
137     UINT r, i, count = 0, *data = NULL;
138
139     r = table->ops->get_dimensions( table, &count, NULL );
140     if( r != ERROR_SUCCESS )
141         return r;
142
143     if (!count)
144         goto end;
145
146     data = msi_alloc( count * 2 * sizeof (UINT) );
147     if (!data)
148         return ERROR_SUCCESS;
149
150     for (i=0; i<count; i++)
151     {
152         data[i*2] = i;
153         r = table->ops->fetch_int( table, i, column, &data[i*2+1] );
154         if (r != ERROR_SUCCESS)
155             ERR("fetch data (%u,%u) failed\n", i, column);
156     }
157
158     qsort( data, count, 2 * sizeof (UINT), join_key_compare );
159
160 end:
161     *pdata = data;
162     *pcount = count;
163
164     return ERROR_SUCCESS;
165 }
166
167 static UINT join_match( UINT *ldata, UINT lcount,
168                         UINT *rdata, UINT rcount,
169                         UINT **ppairs, UINT *ppair_count )
170 {
171     UINT *pairs;
172     UINT n, i, j;
173
174     TRACE("left %u right %u\n", rcount, lcount);
175
176     /* there can be at most max(lcount, rcount) matches */
177     if (lcount > rcount)
178         n = lcount;
179     else
180         n = rcount;
181
182     pairs = msi_alloc( n * 2 * sizeof(UINT) );
183     if (!pairs)
184         return ERROR_OUTOFMEMORY;
185
186     for (n=0, i=0, j=0; i<lcount && j<rcount; )
187     {
188         /* values match... store the row numbers */
189         if (ldata[i*2+1] == rdata[j*2+1])
190         {
191             pairs[n*2] = ldata[i*2];
192             pairs[n*2+1] = rdata[j*2];
193             i++;  /* FIXME: assumes primary key on the right */
194             n++;
195             continue;
196         }
197
198         /* values differ... move along */
199         if (ldata[i*2+1] < rdata[j*2+1])
200             i++;
201         else
202             j++;
203     }
204
205     *ppairs = pairs;
206     *ppair_count = n;
207
208     return ERROR_SUCCESS;
209 }
210
211 static UINT JOIN_execute( struct tagMSIVIEW *view, MSIRECORD *record )
212 {
213     MSIJOINVIEW *jv = (MSIJOINVIEW*)view;
214     UINT r, *ldata = NULL, *rdata = NULL, lcount = 0, rcount = 0;
215
216     TRACE("%p %p\n", jv, record);
217
218     if( !jv->left || !jv->right )
219          return ERROR_FUNCTION_FAILED;
220
221     r = jv->left->ops->execute( jv->left, NULL );
222     if (r != ERROR_SUCCESS)
223         return r;
224
225     r = jv->right->ops->execute( jv->right, NULL );
226     if (r != ERROR_SUCCESS)
227         return r;
228
229     r = join_load_key_column( jv, jv->left, jv->left_key, &ldata, &lcount );
230     if (r != ERROR_SUCCESS)
231         return r;
232
233     r = join_load_key_column( jv, jv->right, jv->right_key, &rdata, &rcount );
234     if (r != ERROR_SUCCESS)
235         goto end;
236
237     r = join_match( ldata, lcount, rdata, rcount, &jv->pairs, &jv->pair_count );
238
239 end:
240     msi_free( ldata );
241     msi_free( rdata );
242
243     return r;
244 }
245
246 static UINT JOIN_close( struct tagMSIVIEW *view )
247 {
248     MSIJOINVIEW *jv = (MSIJOINVIEW*)view;
249
250     TRACE("%p\n", jv );
251
252     if( !jv->left || !jv->right )
253         return ERROR_FUNCTION_FAILED;
254
255     jv->left->ops->close( jv->left );
256     jv->right->ops->close( jv->right );
257
258     return ERROR_SUCCESS;
259 }
260
261 static UINT JOIN_get_dimensions( struct tagMSIVIEW *view, UINT *rows, UINT *cols )
262 {
263     MSIJOINVIEW *jv = (MSIJOINVIEW*)view;
264
265     TRACE("%p %p %p\n", jv, rows, cols );
266
267     if( cols )
268         *cols = jv->left_count + jv->right_count;
269
270     if( rows )
271     {
272         if( !jv->left || !jv->right )
273             return ERROR_FUNCTION_FAILED;
274
275         *rows = jv->pair_count;
276     }
277
278     return ERROR_SUCCESS;
279 }
280
281 static UINT JOIN_get_column_info( struct tagMSIVIEW *view,
282                 UINT n, LPWSTR *name, UINT *type )
283 {
284     MSIJOINVIEW *jv = (MSIJOINVIEW*)view;
285
286     TRACE("%p %d %p %p\n", jv, n, name, type );
287
288     if( !jv->left || !jv->right )
289         return ERROR_FUNCTION_FAILED;
290
291     if( (n==0) || (n>(jv->left_count + jv->right_count)) )
292         return ERROR_FUNCTION_FAILED;
293
294     if( n <= jv->left_count )
295         return jv->left->ops->get_column_info( jv->left, n, name, type );
296
297     n = n - jv->left_count;
298
299     return jv->right->ops->get_column_info( jv->right, n, name, type );
300 }
301
302 static UINT JOIN_modify( struct tagMSIVIEW *view, MSIMODIFY eModifyMode,
303                 MSIRECORD *rec )
304 {
305     MSIJOINVIEW *jv = (MSIJOINVIEW*)view;
306
307     TRACE("%p %d %p\n", jv, eModifyMode, rec );
308
309     return ERROR_FUNCTION_FAILED;
310 }
311
312 static UINT JOIN_delete( struct tagMSIVIEW *view )
313 {
314     MSIJOINVIEW *jv = (MSIJOINVIEW*)view;
315
316     TRACE("%p\n", jv );
317
318     if( jv->left )
319         jv->left->ops->delete( jv->left );
320     jv->left = NULL;
321
322     if( jv->right )
323         jv->right->ops->delete( jv->right );
324     jv->right = NULL;
325
326     msi_free( jv->pairs );
327     jv->pairs = NULL;
328
329     msi_free( jv );
330
331     return ERROR_SUCCESS;
332 }
333
334 static UINT JOIN_find_matching_rows( struct tagMSIVIEW *view, UINT col,
335     UINT val, UINT *row, MSIITERHANDLE *handle )
336 {
337     MSIJOINVIEW *jv = (MSIJOINVIEW*)view;
338
339     FIXME("%p, %d, %u, %p\n", jv, col, val, *handle);
340
341     return ERROR_FUNCTION_FAILED;
342 }
343
344 static const MSIVIEWOPS join_ops =
345 {
346     JOIN_fetch_int,
347     JOIN_fetch_stream,
348     JOIN_set_int,
349     JOIN_insert_row,
350     JOIN_execute,
351     JOIN_close,
352     JOIN_get_dimensions,
353     JOIN_get_column_info,
354     JOIN_modify,
355     JOIN_delete,
356     JOIN_find_matching_rows
357 };
358
359 /*
360  * join_check_condition
361  *
362  * This is probably overly strict about what kind of condition we need
363  *  for a join query.
364  */
365 static UINT join_check_condition(MSIJOINVIEW *jv, struct expr *cond)
366 {
367     UINT r;
368
369     /* assume that we have  `KeyColumn` = `SubkeyColumn` */
370     if ( cond->type != EXPR_COMPLEX )
371         return ERROR_FUNCTION_FAILED;
372
373     if ( cond->u.expr.op != OP_EQ )
374         return ERROR_FUNCTION_FAILED;
375
376     if ( cond->u.expr.left->type != EXPR_COLUMN )
377         return ERROR_FUNCTION_FAILED;
378
379     if ( cond->u.expr.right->type != EXPR_COLUMN )
380         return ERROR_FUNCTION_FAILED;
381
382     /* make sure both columns exist */
383     r = VIEW_find_column( jv->left, cond->u.expr.left->u.column, &jv->left_key );
384     if (r != ERROR_SUCCESS)
385         return ERROR_FUNCTION_FAILED;
386
387     r = VIEW_find_column( jv->right, cond->u.expr.right->u.column, &jv->right_key );
388     if (r != ERROR_SUCCESS)
389         return ERROR_FUNCTION_FAILED;
390
391     TRACE("left %s (%u) right %s (%u)\n",
392         debugstr_w(cond->u.expr.left->u.column), jv->left_key,
393         debugstr_w(cond->u.expr.right->u.column), jv->right_key);
394
395     return ERROR_SUCCESS;
396 }
397
398 UINT JOIN_CreateView( MSIDATABASE *db, MSIVIEW **view,
399                       LPCWSTR left, LPCWSTR right,
400                       struct expr *cond )
401 {
402     MSIJOINVIEW *jv = NULL;
403     UINT r = ERROR_SUCCESS;
404
405     TRACE("%p (%s,%s)\n", jv, debugstr_w(left), debugstr_w(right) );
406
407     jv = msi_alloc_zero( sizeof *jv );
408     if( !jv )
409         return ERROR_FUNCTION_FAILED;
410
411     /* fill the structure */
412     jv->view.ops = &join_ops;
413     jv->db = db;
414
415     /* create the tables to join */
416     r = TABLE_CreateView( db, left, &jv->left );
417     if( r != ERROR_SUCCESS )
418     {
419         ERR("can't create left table\n");
420         goto end;
421     }
422
423     r = TABLE_CreateView( db, right, &jv->right );
424     if( r != ERROR_SUCCESS )
425     {
426         ERR("can't create right table\n");
427         goto end;
428     }
429
430     /* get the number of columns in each table */
431     r = jv->left->ops->get_dimensions( jv->left, NULL, &jv->left_count );
432     if( r != ERROR_SUCCESS )
433     {
434         ERR("can't get left table dimensions\n");
435         goto end;
436     }
437
438     r = jv->right->ops->get_dimensions( jv->right, NULL, &jv->right_count );
439     if( r != ERROR_SUCCESS )
440     {
441         ERR("can't get right table dimensions\n");
442         goto end;
443     }
444
445     r = join_check_condition( jv, cond );
446     if( r != ERROR_SUCCESS )
447     {
448         ERR("can't get join condition\n");
449         goto end;
450     }
451
452     *view = &jv->view;
453     return ERROR_SUCCESS;
454
455 end:
456     jv->view.ops->delete( &jv->view );
457
458     return r;
459 }