msi: Represent table data as bytes instead of shorts.
[wine] / dlls / msi / where.c
1 /*
2  * Implementation of the Microsoft Installer (msi.dll)
3  *
4  * Copyright 2002 Mike McCormack for CodeWeavers
5  *
6  * This library is free software; you can redistribute it and/or
7  * modify it under the terms of the GNU Lesser General Public
8  * License as published by the Free Software Foundation; either
9  * version 2.1 of the License, or (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14  * Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public
17  * License along with this library; if not, write to the Free Software
18  * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
19  */
20
21 #include <stdarg.h>
22
23 #include "windef.h"
24 #include "winbase.h"
25 #include "winerror.h"
26 #include "wine/debug.h"
27 #include "msi.h"
28 #include "msiquery.h"
29 #include "objbase.h"
30 #include "objidl.h"
31 #include "msipriv.h"
32 #include "winnls.h"
33
34 #include "query.h"
35
36 WINE_DEFAULT_DEBUG_CHANNEL(msidb);
37
38
39 /* below is the query interface to a table */
40
41 typedef struct tagMSIWHEREVIEW
42 {
43     MSIVIEW        view;
44     MSIDATABASE   *db;
45     MSIVIEW       *table;
46     UINT           row_count;
47     UINT          *reorder;
48     struct expr   *cond;
49 } MSIWHEREVIEW;
50
51 static UINT WHERE_fetch_int( struct tagMSIVIEW *view, UINT row, UINT col, UINT *val )
52 {
53     MSIWHEREVIEW *wv = (MSIWHEREVIEW*)view;
54
55     TRACE("%p %d %d %p\n", wv, row, col, val );
56
57     if( !wv->table )
58         return ERROR_FUNCTION_FAILED;
59
60     if( row > wv->row_count )
61         return ERROR_NO_MORE_ITEMS;
62
63     row = wv->reorder[ row ];
64
65     return wv->table->ops->fetch_int( wv->table, row, col, val );
66 }
67
68 static UINT WHERE_fetch_stream( struct tagMSIVIEW *view, UINT row, UINT col, IStream **stm )
69 {
70     MSIWHEREVIEW *wv = (MSIWHEREVIEW*)view;
71
72     TRACE("%p %d %d %p\n", wv, row, col, stm );
73
74     if( !wv->table )
75         return ERROR_FUNCTION_FAILED;
76
77     if( row > wv->row_count )
78         return ERROR_NO_MORE_ITEMS;
79
80     row = wv->reorder[ row ];
81
82     return wv->table->ops->fetch_stream( wv->table, row, col, stm );
83 }
84
85 static UINT WHERE_set_row( struct tagMSIVIEW *view, UINT row, MSIRECORD *rec, UINT mask )
86 {
87     MSIWHEREVIEW *wv = (MSIWHEREVIEW*)view;
88
89     TRACE("%p %d %p %08x\n", wv, row, rec, mask );
90
91     if( !wv->table )
92          return ERROR_FUNCTION_FAILED;
93
94     if( row > wv->row_count )
95         return ERROR_NO_MORE_ITEMS;
96
97     row = wv->reorder[ row ];
98
99     return wv->table->ops->set_row( wv->table, row, rec, mask );
100 }
101
102 static INT INT_evaluate_binary( INT lval, UINT op, INT rval )
103 {
104     switch( op )
105     {
106     case OP_EQ:
107         return ( lval == rval );
108     case OP_AND:
109         return ( lval && rval );
110     case OP_OR:
111         return ( lval || rval );
112     case OP_GT:
113         return ( lval > rval );
114     case OP_LT:
115         return ( lval < rval );
116     case OP_LE:
117         return ( lval <= rval );
118     case OP_GE:
119         return ( lval >= rval );
120     case OP_NE:
121         return ( lval != rval );
122     default:
123         ERR("Unknown operator %d\n", op );
124     }
125     return 0;
126 }
127
128 static INT INT_evaluate_unary( INT lval, UINT op )
129 {
130     switch( op )
131     {
132     case OP_ISNULL:
133         return ( !lval );
134     case OP_NOTNULL:
135         return ( lval );
136     default:
137         ERR("Unknown operator %d\n", op );
138     }
139     return 0;
140 }
141
142 static const WCHAR *STRING_evaluate( string_table *st,
143               MSIVIEW *table, UINT row, struct expr *expr, MSIRECORD *record )
144 {
145     UINT val = 0, r;
146
147     switch( expr->type )
148     {
149     case EXPR_COL_NUMBER_STRING:
150         r = table->ops->fetch_int( table, row, expr->u.col_number, &val );
151         if( r != ERROR_SUCCESS )
152             return NULL;
153         return msi_string_lookup_id( st, val );
154
155     case EXPR_SVAL:
156         return expr->u.sval;
157
158     case EXPR_WILDCARD:
159         return MSI_RecordGetString( record, 1 );
160
161     default:
162         ERR("Invalid expression type\n");
163         break;
164     }
165     return NULL;
166 }
167
168 static UINT STRCMP_Evaluate( string_table *st, MSIVIEW *table, UINT row, 
169                              struct expr *cond, INT *val, MSIRECORD *record )
170 {
171     int sr;
172     const WCHAR *l_str, *r_str;
173
174     l_str = STRING_evaluate( st, table, row, cond->u.expr.left, record );
175     r_str = STRING_evaluate( st, table, row, cond->u.expr.right, record );
176     if( l_str == r_str )
177         sr = 0;
178     else if( l_str && ! r_str )
179         sr = 1;
180     else if( r_str && ! l_str )
181         sr = -1;
182     else
183         sr = lstrcmpW( l_str, r_str );
184
185     *val = ( cond->u.expr.op == OP_EQ && ( sr == 0 ) ) ||
186            ( cond->u.expr.op == OP_LT && ( sr < 0 ) ) ||
187            ( cond->u.expr.op == OP_GT && ( sr > 0 ) );
188
189     return ERROR_SUCCESS;
190 }
191
192 static UINT WHERE_evaluate( MSIDATABASE *db, MSIVIEW *table, UINT row, 
193                              struct expr *cond, INT *val, MSIRECORD *record )
194 {
195     UINT r, tval;
196     INT lval, rval;
197
198     if( !cond )
199         return ERROR_SUCCESS;
200
201     switch( cond->type )
202     {
203     case EXPR_COL_NUMBER:
204         r = table->ops->fetch_int( table, row, cond->u.col_number, &tval );
205         *val = tval - 0x8000;
206         return ERROR_SUCCESS;
207
208     case EXPR_COL_NUMBER32:
209         r = table->ops->fetch_int( table, row, cond->u.col_number, &tval );
210         *val = tval - 0x80000000;
211         return r;
212
213     case EXPR_UVAL:
214         *val = cond->u.uval;
215         return ERROR_SUCCESS;
216
217     case EXPR_COMPLEX:
218         r = WHERE_evaluate( db, table, row, cond->u.expr.left, &lval, record );
219         if( r != ERROR_SUCCESS )
220             return r;
221         r = WHERE_evaluate( db, table, row, cond->u.expr.right, &rval, record );
222         if( r != ERROR_SUCCESS )
223             return r;
224         *val = INT_evaluate_binary( lval, cond->u.expr.op, rval );
225         return ERROR_SUCCESS;
226
227     case EXPR_UNARY:
228         r = table->ops->fetch_int( table, row, cond->u.expr.left->u.col_number, &tval );
229         if( r != ERROR_SUCCESS )
230             return r;
231         *val = INT_evaluate_unary( tval, cond->u.expr.op );
232         return ERROR_SUCCESS;
233
234     case EXPR_STRCMP:
235         return STRCMP_Evaluate( db->strings, table, row, cond, val, record );
236
237     case EXPR_WILDCARD:
238         *val = MSI_RecordGetInteger( record, 1 );
239         return ERROR_SUCCESS;
240
241     default:
242         ERR("Invalid expression type\n");
243         break;
244     }
245
246     return ERROR_SUCCESS;
247
248 }
249
250 static UINT WHERE_execute( struct tagMSIVIEW *view, MSIRECORD *record )
251 {
252     MSIWHEREVIEW *wv = (MSIWHEREVIEW*)view;
253     UINT count = 0, r, i;
254     INT val;
255     MSIVIEW *table = wv->table;
256
257     TRACE("%p %p\n", wv, record);
258
259     if( !table )
260          return ERROR_FUNCTION_FAILED;
261
262     r = table->ops->execute( table, record );
263     if( r != ERROR_SUCCESS )
264         return r;
265
266     r = table->ops->get_dimensions( table, &count, NULL );
267     if( r != ERROR_SUCCESS )
268         return r;
269
270     msi_free( wv->reorder );
271     wv->reorder = msi_alloc( count*sizeof(UINT) );
272     if( !wv->reorder )
273         return ERROR_FUNCTION_FAILED;
274
275     wv->row_count = 0;
276     if (wv->cond->type == EXPR_STRCMP)
277     {
278         MSIITERHANDLE handle = NULL;
279         UINT row, value, col;
280         struct expr *col_cond = wv->cond->u.expr.left;
281         struct expr *val_cond = wv->cond->u.expr.right;
282
283         /* swap conditionals */
284         if (col_cond->type != EXPR_COL_NUMBER_STRING)
285         {
286             val_cond = wv->cond->u.expr.left;
287             col_cond = wv->cond->u.expr.right;
288         }
289
290         if ((col_cond->type == EXPR_COL_NUMBER_STRING) && (val_cond->type == EXPR_SVAL))
291         {
292             col = col_cond->u.col_number;
293             /* special case for "" - translate it into nil */
294             if (!val_cond->u.sval[0])
295                 value = 0;
296             else
297             {
298                 r = msi_string2idW(wv->db->strings, val_cond->u.sval, &value);
299                 if (r != ERROR_SUCCESS)
300                 {
301                     TRACE("no id for %s, assuming it doesn't exist in the table\n", debugstr_w(wv->cond->u.expr.right->u.sval));
302                     return ERROR_SUCCESS;
303                 }
304             }
305
306             do
307             {
308                 r = table->ops->find_matching_rows(table, col, value, &row, &handle);
309                 if (r == ERROR_SUCCESS)
310                     wv->reorder[ wv->row_count ++ ] = row;
311             } while (r == ERROR_SUCCESS);
312
313             if (r == ERROR_NO_MORE_ITEMS)
314                 return ERROR_SUCCESS;
315             else
316                 return r;
317         }
318         /* else fallback to slow case */
319     }
320
321     for( i=0; i<count; i++ )
322     {
323         val = 0;
324         r = WHERE_evaluate( wv->db, table, i, wv->cond, &val, record );
325         if( r != ERROR_SUCCESS )
326             return r;
327         if( val )
328             wv->reorder[ wv->row_count ++ ] = i;
329     }
330
331     return ERROR_SUCCESS;
332 }
333
334 static UINT WHERE_close( struct tagMSIVIEW *view )
335 {
336     MSIWHEREVIEW *wv = (MSIWHEREVIEW*)view;
337
338     TRACE("%p\n", wv );
339
340     if( !wv->table )
341          return ERROR_FUNCTION_FAILED;
342
343     msi_free( wv->reorder );
344     wv->reorder = NULL;
345
346     return wv->table->ops->close( wv->table );
347 }
348
349 static UINT WHERE_get_dimensions( struct tagMSIVIEW *view, UINT *rows, UINT *cols )
350 {
351     MSIWHEREVIEW *wv = (MSIWHEREVIEW*)view;
352
353     TRACE("%p %p %p\n", wv, rows, cols );
354
355     if( !wv->table )
356          return ERROR_FUNCTION_FAILED;
357
358     if( rows )
359     {
360         if( !wv->reorder )
361             return ERROR_FUNCTION_FAILED;
362         *rows = wv->row_count;
363     }
364
365     return wv->table->ops->get_dimensions( wv->table, NULL, cols );
366 }
367
368 static UINT WHERE_get_column_info( struct tagMSIVIEW *view,
369                 UINT n, LPWSTR *name, UINT *type )
370 {
371     MSIWHEREVIEW *wv = (MSIWHEREVIEW*)view;
372
373     TRACE("%p %d %p %p\n", wv, n, name, type );
374
375     if( !wv->table )
376          return ERROR_FUNCTION_FAILED;
377
378     return wv->table->ops->get_column_info( wv->table, n, name, type );
379 }
380
381 static UINT WHERE_modify( struct tagMSIVIEW *view, MSIMODIFY eModifyMode,
382                 MSIRECORD *rec )
383 {
384     MSIWHEREVIEW *wv = (MSIWHEREVIEW*)view;
385
386     TRACE("%p %d %p\n", wv, eModifyMode, rec );
387
388     if( !wv->table )
389          return ERROR_FUNCTION_FAILED;
390
391     return wv->table->ops->modify( wv->table, eModifyMode, rec );
392 }
393
394 static UINT WHERE_delete( struct tagMSIVIEW *view )
395 {
396     MSIWHEREVIEW *wv = (MSIWHEREVIEW*)view;
397
398     TRACE("%p\n", wv );
399
400     if( wv->table )
401         wv->table->ops->delete( wv->table );
402     wv->table = 0;
403
404     msi_free( wv->reorder );
405     wv->reorder = NULL;
406     wv->row_count = 0;
407
408     msiobj_release( &wv->db->hdr );
409     msi_free( wv );
410
411     return ERROR_SUCCESS;
412 }
413
414 static UINT WHERE_find_matching_rows( struct tagMSIVIEW *view, UINT col,
415     UINT val, UINT *row, MSIITERHANDLE *handle )
416 {
417     MSIWHEREVIEW *wv = (MSIWHEREVIEW*)view;
418     UINT r;
419
420     TRACE("%p, %d, %u, %p\n", view, col, val, *handle);
421
422     if( !wv->table )
423          return ERROR_FUNCTION_FAILED;
424
425     r = wv->table->ops->find_matching_rows( wv->table, col, val, row, handle );
426
427     if( *row > wv->row_count )
428         return ERROR_NO_MORE_ITEMS;
429
430     *row = wv->reorder[ *row ];
431
432     return r;
433 }
434
435
436 static const MSIVIEWOPS where_ops =
437 {
438     WHERE_fetch_int,
439     WHERE_fetch_stream,
440     WHERE_set_row,
441     NULL,
442     WHERE_execute,
443     WHERE_close,
444     WHERE_get_dimensions,
445     WHERE_get_column_info,
446     WHERE_modify,
447     WHERE_delete,
448     WHERE_find_matching_rows
449 };
450
451 static UINT WHERE_VerifyCondition( MSIDATABASE *db, MSIVIEW *table, struct expr *cond,
452                                    UINT *valid )
453 {
454     UINT r, val = 0;
455
456     switch( cond->type )
457     {
458     case EXPR_COLUMN:
459         r = VIEW_find_column( table, cond->u.column, &val );
460         if( r == ERROR_SUCCESS )
461         {
462             UINT type = 0;
463             r = table->ops->get_column_info( table, val, NULL, &type );
464             if( r == ERROR_SUCCESS )
465             {
466                 if (type&MSITYPE_STRING)
467                     cond->type = EXPR_COL_NUMBER_STRING;
468                 else if ((type&0xff) == 4)
469                     cond->type = EXPR_COL_NUMBER32;
470                 else
471                     cond->type = EXPR_COL_NUMBER;
472                 cond->u.col_number = val;
473                 *valid = 1;
474             }
475             else
476                 *valid = 0;
477         }
478         else
479         {
480             *valid = 0;
481             ERR("Couldn't find column %s\n", debugstr_w( cond->u.column ) );
482         }
483         break;
484     case EXPR_COMPLEX:
485         r = WHERE_VerifyCondition( db, table, cond->u.expr.left, valid );
486         if( r != ERROR_SUCCESS )
487             return r;
488         if( !*valid )
489             return ERROR_SUCCESS;
490         r = WHERE_VerifyCondition( db, table, cond->u.expr.right, valid );
491         if( r != ERROR_SUCCESS )
492             return r;
493
494         /* check the type of the comparison */
495         if( ( cond->u.expr.left->type == EXPR_SVAL ) ||
496             ( cond->u.expr.left->type == EXPR_COL_NUMBER_STRING ) ||
497             ( cond->u.expr.right->type == EXPR_SVAL ) ||
498             ( cond->u.expr.right->type == EXPR_COL_NUMBER_STRING ) )
499         {
500             switch( cond->u.expr.op )
501             {
502             case OP_EQ:
503             case OP_GT:
504             case OP_LT:
505                 break;
506             default:
507                 *valid = FALSE;
508                 return ERROR_INVALID_PARAMETER;
509             }
510
511             /* FIXME: check we're comparing a string to a column */
512
513             cond->type = EXPR_STRCMP;
514         }
515
516         break;
517     case EXPR_UNARY:
518         if ( cond->u.expr.left->type != EXPR_COLUMN )
519         {
520             *valid = FALSE;
521             return ERROR_INVALID_PARAMETER;
522         }
523         r = WHERE_VerifyCondition( db, table, cond->u.expr.left, valid );
524         if( r != ERROR_SUCCESS )
525             return r;
526         break;
527     case EXPR_IVAL:
528         *valid = 1;
529         cond->type = EXPR_UVAL;
530         cond->u.uval = cond->u.ival;
531         break;
532     case EXPR_WILDCARD:
533         *valid = 1;
534         break;
535     case EXPR_SVAL:
536         *valid = 1;
537         break;
538     default:
539         ERR("Invalid expression type\n");
540         *valid = 0;
541         break;
542     }
543
544     return ERROR_SUCCESS;
545 }
546
547 UINT WHERE_CreateView( MSIDATABASE *db, MSIVIEW **view, MSIVIEW *table,
548                        struct expr *cond )
549 {
550     MSIWHEREVIEW *wv = NULL;
551     UINT count = 0, r, valid = 0;
552
553     TRACE("%p\n", table );
554
555     r = table->ops->get_dimensions( table, NULL, &count );
556     if( r != ERROR_SUCCESS )
557     {
558         ERR("can't get table dimensions\n");
559         return r;
560     }
561
562     if( cond )
563     {
564         r = WHERE_VerifyCondition( db, table, cond, &valid );
565         if( r != ERROR_SUCCESS )
566             return r;
567         if( !valid )
568             return ERROR_FUNCTION_FAILED;
569     }
570
571     wv = msi_alloc_zero( sizeof *wv );
572     if( !wv )
573         return ERROR_FUNCTION_FAILED;
574     
575     /* fill the structure */
576     wv->view.ops = &where_ops;
577     msiobj_addref( &db->hdr );
578     wv->db = db;
579     wv->table = table;
580     wv->row_count = 0;
581     wv->reorder = NULL;
582     wv->cond = cond;
583     *view = (MSIVIEW*) wv;
584
585     return ERROR_SUCCESS;
586 }