msi: Remove redundant null checks before MSI_EvaluateCondition.
[wine] / dlls / msi / join.c
1 /*
2  * Implementation of the Microsoft Installer (msi.dll)
3  *
4  * Copyright 2006 Mike McCormack for CodeWeavers
5  *
6  * This library is free software; you can redistribute it and/or
7  * modify it under the terms of the GNU Lesser General Public
8  * License as published by the Free Software Foundation; either
9  * version 2.1 of the License, or (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14  * Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public
17  * License along with this library; if not, write to the Free Software
18  * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
19  */
20
21 #include <stdarg.h>
22
23 #include "windef.h"
24 #include "winbase.h"
25 #include "winerror.h"
26 #include "wine/debug.h"
27 #include "msi.h"
28 #include "msiquery.h"
29 #include "objbase.h"
30 #include "objidl.h"
31 #include "msipriv.h"
32 #include "query.h"
33
34 WINE_DEFAULT_DEBUG_CHANNEL(msidb);
35
36 typedef struct tagMSIJOINVIEW
37 {
38     MSIVIEW        view;
39     MSIDATABASE   *db;
40     MSIVIEW       *left, *right;
41     UINT           left_count, right_count;
42     UINT           left_key, right_key;
43     UINT          *pairs;
44     UINT           pair_count;
45 } MSIJOINVIEW;
46
47 static UINT JOIN_fetch_int( struct tagMSIVIEW *view, UINT row, UINT col, UINT *val )
48 {
49     MSIJOINVIEW *jv = (MSIJOINVIEW*)view;
50     MSIVIEW *table;
51
52     TRACE("%p %d %d %p\n", jv, row, col, val );
53
54     if( !jv->left || !jv->right )
55          return ERROR_FUNCTION_FAILED;
56
57     if( (col==0) || (col>(jv->left_count + jv->right_count)) )
58          return ERROR_FUNCTION_FAILED;
59
60     if( row >= jv->pair_count )
61          return ERROR_FUNCTION_FAILED;
62
63     if( col <= jv->left_count )
64     {
65         table = jv->left;
66         row = jv->pairs[ row*2 ];
67     }
68     else
69     {
70         table = jv->right;
71         row = jv->pairs[ row*2 + 1 ];
72         col -= jv->left_count;
73     }
74
75     return table->ops->fetch_int( table, row, col, val );
76 }
77
78 static UINT JOIN_fetch_stream( struct tagMSIVIEW *view, UINT row, UINT col, IStream **stm)
79 {
80     MSIJOINVIEW *jv = (MSIJOINVIEW*)view;
81     MSIVIEW *table;
82
83     TRACE("%p %d %d %p\n", jv, row, col, stm );
84
85     if( !jv->left || !jv->right )
86          return ERROR_FUNCTION_FAILED;
87
88     if( (col==0) || (col>(jv->left_count + jv->right_count)) )
89          return ERROR_FUNCTION_FAILED;
90
91     if( row <= jv->left_count )
92     {
93         table = jv->left;
94         row = jv->pairs[ row*2 ];
95     }
96     else
97     {
98         table = jv->right;
99         row = jv->pairs[ row*2 + 1 ];
100         col -= jv->left_count;
101     }
102
103     return table->ops->fetch_stream( table, row, col, stm );
104 }
105
106 static int join_key_compare(const void *l, const void *r)
107 {
108     const UINT *left = l, *right = r;
109     if (left[1] < right[1])
110         return -1;
111     if (left[1] == right[1])
112         return 0;
113     return 1;
114 }
115
116 static UINT join_load_key_column( MSIJOINVIEW *jv, MSIVIEW *table, UINT column,
117                                   UINT **pdata, UINT *pcount )
118 {
119     UINT r, i, count = 0, *data = NULL;
120
121     r = table->ops->get_dimensions( table, &count, NULL );
122     if( r != ERROR_SUCCESS )
123         return r;
124
125     if (!count)
126         goto end;
127
128     data = msi_alloc( count * 2 * sizeof (UINT) );
129     if (!data)
130         return ERROR_SUCCESS;
131
132     for (i=0; i<count; i++)
133     {
134         data[i*2] = i;
135         r = table->ops->fetch_int( table, i, column, &data[i*2+1] );
136         if (r != ERROR_SUCCESS)
137             ERR("fetch data (%u,%u) failed\n", i, column);
138     }
139
140     qsort( data, count, 2 * sizeof (UINT), join_key_compare );
141
142 end:
143     *pdata = data;
144     *pcount = count;
145
146     return ERROR_SUCCESS;
147 }
148
149 static UINT join_match( UINT *ldata, UINT lcount,
150                         UINT *rdata, UINT rcount,
151                         UINT **ppairs, UINT *ppair_count )
152 {
153     UINT *pairs;
154     UINT n, i, j;
155
156     TRACE("left %u right %u\n", rcount, lcount);
157
158     /* there can be at most max(lcount, rcount) matches */
159     if (lcount > rcount)
160         n = lcount;
161     else
162         n = rcount;
163
164     pairs = msi_alloc( n * 2 * sizeof(UINT) );
165     if (!pairs)
166         return ERROR_OUTOFMEMORY;
167
168     for (n=0, i=0, j=0; i<lcount && j<rcount; )
169     {
170         /* values match... store the row numbers */
171         if (ldata[i*2+1] == rdata[j*2+1])
172         {
173             pairs[n*2] = ldata[i*2];
174             pairs[n*2+1] = rdata[j*2];
175             n++;
176
177             if ( ldata[i*2+3] < rdata[j*2+3])
178                 i++;
179             else
180                 j++;
181         }
182
183         /* values differ... move along */
184         else if (ldata[i*2+1] < rdata[j*2+1])
185             i++;
186         else
187             j++;
188     }
189
190     *ppairs = pairs;
191     *ppair_count = n;
192
193     return ERROR_SUCCESS;
194 }
195
196 static UINT JOIN_execute( struct tagMSIVIEW *view, MSIRECORD *record )
197 {
198     MSIJOINVIEW *jv = (MSIJOINVIEW*)view;
199     UINT r, *ldata = NULL, *rdata = NULL, lcount = 0, rcount = 0;
200
201     TRACE("%p %p\n", jv, record);
202
203     if( !jv->left || !jv->right )
204          return ERROR_FUNCTION_FAILED;
205
206     r = jv->left->ops->execute( jv->left, NULL );
207     if (r != ERROR_SUCCESS)
208         return r;
209
210     r = jv->right->ops->execute( jv->right, NULL );
211     if (r != ERROR_SUCCESS)
212         return r;
213
214     r = join_load_key_column( jv, jv->left, jv->left_key, &ldata, &lcount );
215     if (r != ERROR_SUCCESS)
216         return r;
217
218     r = join_load_key_column( jv, jv->right, jv->right_key, &rdata, &rcount );
219     if (r != ERROR_SUCCESS)
220         goto end;
221
222     r = join_match( ldata, lcount, rdata, rcount, &jv->pairs, &jv->pair_count );
223
224 end:
225     msi_free( ldata );
226     msi_free( rdata );
227
228     return r;
229 }
230
231 static UINT JOIN_close( struct tagMSIVIEW *view )
232 {
233     MSIJOINVIEW *jv = (MSIJOINVIEW*)view;
234
235     TRACE("%p\n", jv );
236
237     if( !jv->left || !jv->right )
238         return ERROR_FUNCTION_FAILED;
239
240     jv->left->ops->close( jv->left );
241     jv->right->ops->close( jv->right );
242
243     return ERROR_SUCCESS;
244 }
245
246 static UINT JOIN_get_dimensions( struct tagMSIVIEW *view, UINT *rows, UINT *cols )
247 {
248     MSIJOINVIEW *jv = (MSIJOINVIEW*)view;
249
250     TRACE("%p %p %p\n", jv, rows, cols );
251
252     if( cols )
253         *cols = jv->left_count + jv->right_count;
254
255     if( rows )
256     {
257         if( !jv->left || !jv->right )
258             return ERROR_FUNCTION_FAILED;
259
260         *rows = jv->pair_count;
261     }
262
263     return ERROR_SUCCESS;
264 }
265
266 static UINT JOIN_get_column_info( struct tagMSIVIEW *view,
267                 UINT n, LPWSTR *name, UINT *type )
268 {
269     MSIJOINVIEW *jv = (MSIJOINVIEW*)view;
270
271     TRACE("%p %d %p %p\n", jv, n, name, type );
272
273     if( !jv->left || !jv->right )
274         return ERROR_FUNCTION_FAILED;
275
276     if( (n==0) || (n>(jv->left_count + jv->right_count)) )
277         return ERROR_FUNCTION_FAILED;
278
279     if( n <= jv->left_count )
280         return jv->left->ops->get_column_info( jv->left, n, name, type );
281
282     n = n - jv->left_count;
283
284     return jv->right->ops->get_column_info( jv->right, n, name, type );
285 }
286
287 static UINT JOIN_modify( struct tagMSIVIEW *view, MSIMODIFY eModifyMode,
288                 MSIRECORD *rec )
289 {
290     MSIJOINVIEW *jv = (MSIJOINVIEW*)view;
291
292     TRACE("%p %d %p\n", jv, eModifyMode, rec );
293
294     return ERROR_FUNCTION_FAILED;
295 }
296
297 static UINT JOIN_delete( struct tagMSIVIEW *view )
298 {
299     MSIJOINVIEW *jv = (MSIJOINVIEW*)view;
300
301     TRACE("%p\n", jv );
302
303     if( jv->left )
304         jv->left->ops->delete( jv->left );
305     jv->left = NULL;
306
307     if( jv->right )
308         jv->right->ops->delete( jv->right );
309     jv->right = NULL;
310
311     msi_free( jv->pairs );
312     jv->pairs = NULL;
313
314     msi_free( jv );
315
316     return ERROR_SUCCESS;
317 }
318
319 static UINT JOIN_find_matching_rows( struct tagMSIVIEW *view, UINT col,
320     UINT val, UINT *row, MSIITERHANDLE *handle )
321 {
322     MSIJOINVIEW *jv = (MSIJOINVIEW*)view;
323
324     FIXME("%p, %d, %u, %p\n", jv, col, val, *handle);
325
326     return ERROR_FUNCTION_FAILED;
327 }
328
329 static const MSIVIEWOPS join_ops =
330 {
331     JOIN_fetch_int,
332     JOIN_fetch_stream,
333     NULL,
334     NULL,
335     JOIN_execute,
336     JOIN_close,
337     JOIN_get_dimensions,
338     JOIN_get_column_info,
339     JOIN_modify,
340     JOIN_delete,
341     JOIN_find_matching_rows
342 };
343
344 /*
345  * join_check_condition
346  *
347  * This is probably overly strict about what kind of condition we need
348  *  for a join query.
349  */
350 static UINT join_check_condition(MSIJOINVIEW *jv, struct expr *cond)
351 {
352     UINT r;
353
354     /* assume that we have  `KeyColumn` = `SubkeyColumn` */
355     if ( cond->type != EXPR_COMPLEX )
356         return ERROR_FUNCTION_FAILED;
357
358     if ( cond->u.expr.op != OP_EQ )
359         return ERROR_FUNCTION_FAILED;
360
361     if ( cond->u.expr.left->type != EXPR_COLUMN )
362         return ERROR_FUNCTION_FAILED;
363
364     if ( cond->u.expr.right->type != EXPR_COLUMN )
365         return ERROR_FUNCTION_FAILED;
366
367     /* make sure both columns exist */
368     r = VIEW_find_column( jv->left, cond->u.expr.left->u.column, &jv->left_key );
369     if (r != ERROR_SUCCESS)
370         return ERROR_FUNCTION_FAILED;
371
372     r = VIEW_find_column( jv->right, cond->u.expr.right->u.column, &jv->right_key );
373     if (r != ERROR_SUCCESS)
374         return ERROR_FUNCTION_FAILED;
375
376     TRACE("left %s (%u) right %s (%u)\n",
377         debugstr_w(cond->u.expr.left->u.column), jv->left_key,
378         debugstr_w(cond->u.expr.right->u.column), jv->right_key);
379
380     return ERROR_SUCCESS;
381 }
382
383 UINT JOIN_CreateView( MSIDATABASE *db, MSIVIEW **view,
384                       LPCWSTR left, LPCWSTR right,
385                       struct expr *cond )
386 {
387     MSIJOINVIEW *jv = NULL;
388     UINT r = ERROR_SUCCESS;
389
390     TRACE("%p (%s,%s)\n", jv, debugstr_w(left), debugstr_w(right) );
391
392     jv = msi_alloc_zero( sizeof *jv );
393     if( !jv )
394         return ERROR_FUNCTION_FAILED;
395
396     /* fill the structure */
397     jv->view.ops = &join_ops;
398     jv->db = db;
399
400     /* create the tables to join */
401     r = TABLE_CreateView( db, left, &jv->left );
402     if( r != ERROR_SUCCESS )
403     {
404         ERR("can't create left table\n");
405         goto end;
406     }
407
408     r = TABLE_CreateView( db, right, &jv->right );
409     if( r != ERROR_SUCCESS )
410     {
411         ERR("can't create right table\n");
412         goto end;
413     }
414
415     /* get the number of columns in each table */
416     r = jv->left->ops->get_dimensions( jv->left, NULL, &jv->left_count );
417     if( r != ERROR_SUCCESS )
418     {
419         ERR("can't get left table dimensions\n");
420         goto end;
421     }
422
423     r = jv->right->ops->get_dimensions( jv->right, NULL, &jv->right_count );
424     if( r != ERROR_SUCCESS )
425     {
426         ERR("can't get right table dimensions\n");
427         goto end;
428     }
429
430     r = join_check_condition( jv, cond );
431     if( r != ERROR_SUCCESS )
432     {
433         ERR("can't get join condition\n");
434         goto end;
435     }
436
437     *view = &jv->view;
438     return ERROR_SUCCESS;
439
440 end:
441     jv->view.ops->delete( &jv->view );
442
443     return r;
444 }