crypt32: Introduce function to encode an array of items as a set.
[wine] / dlls / msi / where.c
1 /*
2  * Implementation of the Microsoft Installer (msi.dll)
3  *
4  * Copyright 2002 Mike McCormack for CodeWeavers
5  *
6  * This library is free software; you can redistribute it and/or
7  * modify it under the terms of the GNU Lesser General Public
8  * License as published by the Free Software Foundation; either
9  * version 2.1 of the License, or (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14  * Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public
17  * License along with this library; if not, write to the Free Software
18  * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
19  */
20
21 #include <stdarg.h>
22
23 #include "windef.h"
24 #include "winbase.h"
25 #include "winerror.h"
26 #include "wine/debug.h"
27 #include "msi.h"
28 #include "msiquery.h"
29 #include "objbase.h"
30 #include "objidl.h"
31 #include "msipriv.h"
32 #include "winnls.h"
33
34 #include "query.h"
35
36 WINE_DEFAULT_DEBUG_CHANNEL(msidb);
37
38
39 /* below is the query interface to a table */
40
41 typedef struct tagMSIWHEREVIEW
42 {
43     MSIVIEW        view;
44     MSIDATABASE   *db;
45     MSIVIEW       *table;
46     UINT           row_count;
47     UINT          *reorder;
48     struct expr   *cond;
49 } MSIWHEREVIEW;
50
51 static UINT WHERE_fetch_int( struct tagMSIVIEW *view, UINT row, UINT col, UINT *val )
52 {
53     MSIWHEREVIEW *wv = (MSIWHEREVIEW*)view;
54
55     TRACE("%p %d %d %p\n", wv, row, col, val );
56
57     if( !wv->table )
58         return ERROR_FUNCTION_FAILED;
59
60     if( row > wv->row_count )
61         return ERROR_NO_MORE_ITEMS;
62
63     row = wv->reorder[ row ];
64
65     return wv->table->ops->fetch_int( wv->table, row, col, val );
66 }
67
68 static UINT WHERE_fetch_stream( struct tagMSIVIEW *view, UINT row, UINT col, IStream **stm )
69 {
70     MSIWHEREVIEW *wv = (MSIWHEREVIEW*)view;
71
72     TRACE("%p %d %d %p\n", wv, row, col, stm );
73
74     if( !wv->table )
75         return ERROR_FUNCTION_FAILED;
76
77     if( row > wv->row_count )
78         return ERROR_NO_MORE_ITEMS;
79
80     row = wv->reorder[ row ];
81
82     return wv->table->ops->fetch_stream( wv->table, row, col, stm );
83 }
84
85 static UINT WHERE_set_row( struct tagMSIVIEW *view, UINT row, MSIRECORD *rec, UINT mask )
86 {
87     MSIWHEREVIEW *wv = (MSIWHEREVIEW*)view;
88
89     TRACE("%p %d %p %08x\n", wv, row, rec, mask );
90
91     if( !wv->table )
92          return ERROR_FUNCTION_FAILED;
93
94     if( row > wv->row_count )
95         return ERROR_NO_MORE_ITEMS;
96
97     row = wv->reorder[ row ];
98
99     return wv->table->ops->set_row( wv->table, row, rec, mask );
100 }
101
102 static INT INT_evaluate_binary( INT lval, UINT op, INT rval )
103 {
104     switch( op )
105     {
106     case OP_EQ:
107         return ( lval == rval );
108     case OP_AND:
109         return ( lval && rval );
110     case OP_OR:
111         return ( lval || rval );
112     case OP_GT:
113         return ( lval > rval );
114     case OP_LT:
115         return ( lval < rval );
116     case OP_LE:
117         return ( lval <= rval );
118     case OP_GE:
119         return ( lval >= rval );
120     case OP_NE:
121         return ( lval != rval );
122     default:
123         ERR("Unknown operator %d\n", op );
124     }
125     return 0;
126 }
127
128 static INT INT_evaluate_unary( INT lval, UINT op )
129 {
130     switch( op )
131     {
132     case OP_ISNULL:
133         return ( !lval );
134     case OP_NOTNULL:
135         return ( lval );
136     default:
137         ERR("Unknown operator %d\n", op );
138     }
139     return 0;
140 }
141
142 static const WCHAR *STRING_evaluate( const string_table *st,
143               MSIVIEW *table, UINT row, const struct expr *expr, const MSIRECORD *record )
144 {
145     UINT val = 0, r;
146
147     switch( expr->type )
148     {
149     case EXPR_COL_NUMBER_STRING:
150         r = table->ops->fetch_int( table, row, expr->u.col_number, &val );
151         if( r != ERROR_SUCCESS )
152             return NULL;
153         return msi_string_lookup_id( st, val );
154
155     case EXPR_SVAL:
156         return expr->u.sval;
157
158     case EXPR_WILDCARD:
159         return MSI_RecordGetString( record, 1 );
160
161     default:
162         ERR("Invalid expression type\n");
163         break;
164     }
165     return NULL;
166 }
167
168 static UINT STRCMP_Evaluate( const string_table *st, MSIVIEW *table, UINT row,
169                              const struct expr *cond, INT *val, const MSIRECORD *record )
170 {
171     int sr;
172     const WCHAR *l_str, *r_str;
173
174     l_str = STRING_evaluate( st, table, row, cond->u.expr.left, record );
175     r_str = STRING_evaluate( st, table, row, cond->u.expr.right, record );
176     if( l_str == r_str )
177         sr = 0;
178     else if( l_str && ! r_str )
179         sr = 1;
180     else if( r_str && ! l_str )
181         sr = -1;
182     else
183         sr = lstrcmpW( l_str, r_str );
184
185     *val = ( cond->u.expr.op == OP_EQ && ( sr == 0 ) ) ||
186            ( cond->u.expr.op == OP_LT && ( sr < 0 ) ) ||
187            ( cond->u.expr.op == OP_GT && ( sr > 0 ) );
188
189     return ERROR_SUCCESS;
190 }
191
192 static UINT WHERE_evaluate( MSIDATABASE *db, MSIVIEW *table, UINT row, 
193                              const struct expr *cond, INT *val, MSIRECORD *record )
194 {
195     UINT r, tval;
196     INT lval, rval;
197
198     if( !cond )
199         return ERROR_SUCCESS;
200
201     switch( cond->type )
202     {
203     case EXPR_COL_NUMBER:
204         r = table->ops->fetch_int( table, row, cond->u.col_number, &tval );
205         *val = tval - 0x8000;
206         return ERROR_SUCCESS;
207
208     case EXPR_COL_NUMBER32:
209         r = table->ops->fetch_int( table, row, cond->u.col_number, &tval );
210         *val = tval - 0x80000000;
211         return r;
212
213     case EXPR_UVAL:
214         *val = cond->u.uval;
215         return ERROR_SUCCESS;
216
217     case EXPR_COMPLEX:
218         r = WHERE_evaluate( db, table, row, cond->u.expr.left, &lval, record );
219         if( r != ERROR_SUCCESS )
220             return r;
221         r = WHERE_evaluate( db, table, row, cond->u.expr.right, &rval, record );
222         if( r != ERROR_SUCCESS )
223             return r;
224         *val = INT_evaluate_binary( lval, cond->u.expr.op, rval );
225         return ERROR_SUCCESS;
226
227     case EXPR_UNARY:
228         r = table->ops->fetch_int( table, row, cond->u.expr.left->u.col_number, &tval );
229         if( r != ERROR_SUCCESS )
230             return r;
231         *val = INT_evaluate_unary( tval, cond->u.expr.op );
232         return ERROR_SUCCESS;
233
234     case EXPR_STRCMP:
235         return STRCMP_Evaluate( db->strings, table, row, cond, val, record );
236
237     case EXPR_WILDCARD:
238         *val = MSI_RecordGetInteger( record, 1 );
239         return ERROR_SUCCESS;
240
241     default:
242         ERR("Invalid expression type\n");
243         break;
244     }
245
246     return ERROR_SUCCESS;
247
248 }
249
250 static UINT WHERE_execute( struct tagMSIVIEW *view, MSIRECORD *record )
251 {
252     MSIWHEREVIEW *wv = (MSIWHEREVIEW*)view;
253     UINT count = 0, r, i;
254     INT val;
255     MSIVIEW *table = wv->table;
256
257     TRACE("%p %p\n", wv, record);
258
259     if( !table )
260          return ERROR_FUNCTION_FAILED;
261
262     r = table->ops->execute( table, record );
263     if( r != ERROR_SUCCESS )
264         return r;
265
266     r = table->ops->get_dimensions( table, &count, NULL );
267     if( r != ERROR_SUCCESS )
268         return r;
269
270     msi_free( wv->reorder );
271     wv->reorder = msi_alloc( count*sizeof(UINT) );
272     if( !wv->reorder )
273         return ERROR_FUNCTION_FAILED;
274
275     wv->row_count = 0;
276     if (wv->cond->type == EXPR_STRCMP)
277     {
278         MSIITERHANDLE handle = NULL;
279         UINT row, value, col;
280         struct expr *col_cond = wv->cond->u.expr.left;
281         struct expr *val_cond = wv->cond->u.expr.right;
282
283         /* swap conditionals */
284         if (col_cond->type != EXPR_COL_NUMBER_STRING)
285         {
286             val_cond = wv->cond->u.expr.left;
287             col_cond = wv->cond->u.expr.right;
288         }
289
290         if ((col_cond->type == EXPR_COL_NUMBER_STRING) && (val_cond->type == EXPR_SVAL))
291         {
292             col = col_cond->u.col_number;
293             /* special case for "" - translate it into nil */
294             if (!val_cond->u.sval[0])
295                 value = 0;
296             else
297             {
298                 r = msi_string2idW(wv->db->strings, val_cond->u.sval, &value);
299                 if (r != ERROR_SUCCESS)
300                 {
301                     TRACE("no id for %s, assuming it doesn't exist in the table\n", debugstr_w(wv->cond->u.expr.right->u.sval));
302                     return ERROR_SUCCESS;
303                 }
304             }
305
306             do
307             {
308                 r = table->ops->find_matching_rows(table, col, value, &row, &handle);
309                 if (r == ERROR_SUCCESS)
310                     wv->reorder[ wv->row_count ++ ] = row;
311             } while (r == ERROR_SUCCESS);
312
313             if (r == ERROR_NO_MORE_ITEMS)
314                 return ERROR_SUCCESS;
315             else
316                 return r;
317         }
318         /* else fallback to slow case */
319     }
320
321     for( i=0; i<count; i++ )
322     {
323         val = 0;
324         r = WHERE_evaluate( wv->db, table, i, wv->cond, &val, record );
325         if( r != ERROR_SUCCESS )
326             return r;
327         if( val )
328             wv->reorder[ wv->row_count ++ ] = i;
329     }
330
331     return ERROR_SUCCESS;
332 }
333
334 static UINT WHERE_close( struct tagMSIVIEW *view )
335 {
336     MSIWHEREVIEW *wv = (MSIWHEREVIEW*)view;
337
338     TRACE("%p\n", wv );
339
340     if( !wv->table )
341          return ERROR_FUNCTION_FAILED;
342
343     msi_free( wv->reorder );
344     wv->reorder = NULL;
345
346     return wv->table->ops->close( wv->table );
347 }
348
349 static UINT WHERE_get_dimensions( struct tagMSIVIEW *view, UINT *rows, UINT *cols )
350 {
351     MSIWHEREVIEW *wv = (MSIWHEREVIEW*)view;
352
353     TRACE("%p %p %p\n", wv, rows, cols );
354
355     if( !wv->table )
356          return ERROR_FUNCTION_FAILED;
357
358     if( rows )
359     {
360         if( !wv->reorder )
361             return ERROR_FUNCTION_FAILED;
362         *rows = wv->row_count;
363     }
364
365     return wv->table->ops->get_dimensions( wv->table, NULL, cols );
366 }
367
368 static UINT WHERE_get_column_info( struct tagMSIVIEW *view,
369                 UINT n, LPWSTR *name, UINT *type )
370 {
371     MSIWHEREVIEW *wv = (MSIWHEREVIEW*)view;
372
373     TRACE("%p %d %p %p\n", wv, n, name, type );
374
375     if( !wv->table )
376          return ERROR_FUNCTION_FAILED;
377
378     return wv->table->ops->get_column_info( wv->table, n, name, type );
379 }
380
381 static UINT WHERE_modify( struct tagMSIVIEW *view, MSIMODIFY eModifyMode,
382                 MSIRECORD *rec )
383 {
384     MSIWHEREVIEW *wv = (MSIWHEREVIEW*)view;
385
386     TRACE("%p %d %p\n", wv, eModifyMode, rec );
387
388     if( !wv->table )
389          return ERROR_FUNCTION_FAILED;
390
391     return wv->table->ops->modify( wv->table, eModifyMode, rec );
392 }
393
394 static UINT WHERE_delete( struct tagMSIVIEW *view )
395 {
396     MSIWHEREVIEW *wv = (MSIWHEREVIEW*)view;
397
398     TRACE("%p\n", wv );
399
400     if( wv->table )
401         wv->table->ops->delete( wv->table );
402     wv->table = 0;
403
404     msi_free( wv->reorder );
405     wv->reorder = NULL;
406     wv->row_count = 0;
407
408     msiobj_release( &wv->db->hdr );
409     msi_free( wv );
410
411     return ERROR_SUCCESS;
412 }
413
414 static UINT WHERE_find_matching_rows( struct tagMSIVIEW *view, UINT col,
415     UINT val, UINT *row, MSIITERHANDLE *handle )
416 {
417     MSIWHEREVIEW *wv = (MSIWHEREVIEW*)view;
418     UINT r;
419
420     TRACE("%p, %d, %u, %p\n", view, col, val, *handle);
421
422     if( !wv->table )
423          return ERROR_FUNCTION_FAILED;
424
425     r = wv->table->ops->find_matching_rows( wv->table, col, val, row, handle );
426
427     if( *row > wv->row_count )
428         return ERROR_NO_MORE_ITEMS;
429
430     *row = wv->reorder[ *row ];
431
432     return r;
433 }
434
435
436 static const MSIVIEWOPS where_ops =
437 {
438     WHERE_fetch_int,
439     WHERE_fetch_stream,
440     WHERE_set_row,
441     NULL,
442     NULL,
443     WHERE_execute,
444     WHERE_close,
445     WHERE_get_dimensions,
446     WHERE_get_column_info,
447     WHERE_modify,
448     WHERE_delete,
449     WHERE_find_matching_rows,
450     NULL,
451     NULL,
452     NULL,
453     NULL,
454 };
455
456 static UINT WHERE_VerifyCondition( MSIDATABASE *db, MSIVIEW *table, struct expr *cond,
457                                    UINT *valid )
458 {
459     UINT r, val = 0;
460
461     switch( cond->type )
462     {
463     case EXPR_COLUMN:
464         r = VIEW_find_column( table, cond->u.column, &val );
465         if( r == ERROR_SUCCESS )
466         {
467             UINT type = 0;
468             r = table->ops->get_column_info( table, val, NULL, &type );
469             if( r == ERROR_SUCCESS )
470             {
471                 if (type&MSITYPE_STRING)
472                     cond->type = EXPR_COL_NUMBER_STRING;
473                 else if ((type&0xff) == 4)
474                     cond->type = EXPR_COL_NUMBER32;
475                 else
476                     cond->type = EXPR_COL_NUMBER;
477                 cond->u.col_number = val;
478                 *valid = 1;
479             }
480             else
481                 *valid = 0;
482         }
483         else
484         {
485             *valid = 0;
486             ERR("Couldn't find column %s\n", debugstr_w( cond->u.column ) );
487         }
488         break;
489     case EXPR_COMPLEX:
490         r = WHERE_VerifyCondition( db, table, cond->u.expr.left, valid );
491         if( r != ERROR_SUCCESS )
492             return r;
493         if( !*valid )
494             return ERROR_SUCCESS;
495         r = WHERE_VerifyCondition( db, table, cond->u.expr.right, valid );
496         if( r != ERROR_SUCCESS )
497             return r;
498
499         /* check the type of the comparison */
500         if( ( cond->u.expr.left->type == EXPR_SVAL ) ||
501             ( cond->u.expr.left->type == EXPR_COL_NUMBER_STRING ) ||
502             ( cond->u.expr.right->type == EXPR_SVAL ) ||
503             ( cond->u.expr.right->type == EXPR_COL_NUMBER_STRING ) )
504         {
505             switch( cond->u.expr.op )
506             {
507             case OP_EQ:
508             case OP_GT:
509             case OP_LT:
510                 break;
511             default:
512                 *valid = FALSE;
513                 return ERROR_INVALID_PARAMETER;
514             }
515
516             /* FIXME: check we're comparing a string to a column */
517
518             cond->type = EXPR_STRCMP;
519         }
520
521         break;
522     case EXPR_UNARY:
523         if ( cond->u.expr.left->type != EXPR_COLUMN )
524         {
525             *valid = FALSE;
526             return ERROR_INVALID_PARAMETER;
527         }
528         r = WHERE_VerifyCondition( db, table, cond->u.expr.left, valid );
529         if( r != ERROR_SUCCESS )
530             return r;
531         break;
532     case EXPR_IVAL:
533         *valid = 1;
534         cond->type = EXPR_UVAL;
535         cond->u.uval = cond->u.ival;
536         break;
537     case EXPR_WILDCARD:
538         *valid = 1;
539         break;
540     case EXPR_SVAL:
541         *valid = 1;
542         break;
543     default:
544         ERR("Invalid expression type\n");
545         *valid = 0;
546         break;
547     }
548
549     return ERROR_SUCCESS;
550 }
551
552 UINT WHERE_CreateView( MSIDATABASE *db, MSIVIEW **view, MSIVIEW *table,
553                        struct expr *cond )
554 {
555     MSIWHEREVIEW *wv = NULL;
556     UINT count = 0, r, valid = 0;
557
558     TRACE("%p\n", table );
559
560     r = table->ops->get_dimensions( table, NULL, &count );
561     if( r != ERROR_SUCCESS )
562     {
563         ERR("can't get table dimensions\n");
564         return r;
565     }
566
567     if( cond )
568     {
569         r = WHERE_VerifyCondition( db, table, cond, &valid );
570         if( r != ERROR_SUCCESS )
571             return r;
572         if( !valid )
573             return ERROR_FUNCTION_FAILED;
574     }
575
576     wv = msi_alloc_zero( sizeof *wv );
577     if( !wv )
578         return ERROR_FUNCTION_FAILED;
579     
580     /* fill the structure */
581     wv->view.ops = &where_ops;
582     msiobj_addref( &db->hdr );
583     wv->db = db;
584     wv->table = table;
585     wv->row_count = 0;
586     wv->reorder = NULL;
587     wv->cond = cond;
588     *view = (MSIVIEW*) wv;
589
590     return ERROR_SUCCESS;
591 }