msi: Make the WriteEnvironmentStrings handling of [~] a bit better.
[wine] / dlls / msi / where.c
1 /*
2  * Implementation of the Microsoft Installer (msi.dll)
3  *
4  * Copyright 2002 Mike McCormack for CodeWeavers
5  *
6  * This library is free software; you can redistribute it and/or
7  * modify it under the terms of the GNU Lesser General Public
8  * License as published by the Free Software Foundation; either
9  * version 2.1 of the License, or (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14  * Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public
17  * License along with this library; if not, write to the Free Software
18  * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
19  */
20
21 #include <stdarg.h>
22
23 #include "windef.h"
24 #include "winbase.h"
25 #include "winerror.h"
26 #include "wine/debug.h"
27 #include "msi.h"
28 #include "msiquery.h"
29 #include "objbase.h"
30 #include "objidl.h"
31 #include "msipriv.h"
32 #include "winnls.h"
33
34 #include "query.h"
35
36 WINE_DEFAULT_DEBUG_CHANNEL(msidb);
37
38
39 /* below is the query interface to a table */
40
41 typedef struct tagMSIWHEREVIEW
42 {
43     MSIVIEW        view;
44     MSIDATABASE   *db;
45     MSIVIEW       *table;
46     UINT           row_count;
47     UINT          *reorder;
48     struct expr   *cond;
49     UINT           rec_index;
50 } MSIWHEREVIEW;
51
52 static UINT WHERE_fetch_int( struct tagMSIVIEW *view, UINT row, UINT col, UINT *val )
53 {
54     MSIWHEREVIEW *wv = (MSIWHEREVIEW*)view;
55
56     TRACE("%p %d %d %p\n", wv, row, col, val );
57
58     if( !wv->table )
59         return ERROR_FUNCTION_FAILED;
60
61     if( row > wv->row_count )
62         return ERROR_NO_MORE_ITEMS;
63
64     row = wv->reorder[ row ];
65
66     return wv->table->ops->fetch_int( wv->table, row, col, val );
67 }
68
69 static UINT WHERE_fetch_stream( struct tagMSIVIEW *view, UINT row, UINT col, IStream **stm )
70 {
71     MSIWHEREVIEW *wv = (MSIWHEREVIEW*)view;
72
73     TRACE("%p %d %d %p\n", wv, row, col, stm );
74
75     if( !wv->table )
76         return ERROR_FUNCTION_FAILED;
77
78     if( row > wv->row_count )
79         return ERROR_NO_MORE_ITEMS;
80
81     row = wv->reorder[ row ];
82
83     return wv->table->ops->fetch_stream( wv->table, row, col, stm );
84 }
85
86 static UINT WHERE_get_row( struct tagMSIVIEW *view, UINT row, MSIRECORD **rec )
87 {
88     MSIWHEREVIEW *wv = (MSIWHEREVIEW *)view;
89
90     TRACE("%p %d %p\n", wv, row, rec );
91
92     if (!wv->table)
93         return ERROR_FUNCTION_FAILED;
94
95     if (row > wv->row_count)
96         return ERROR_NO_MORE_ITEMS;
97
98     row = wv->reorder[row];
99
100     return wv->table->ops->get_row(view, row, rec);
101 }
102
103 static UINT WHERE_set_row( struct tagMSIVIEW *view, UINT row, MSIRECORD *rec, UINT mask )
104 {
105     MSIWHEREVIEW *wv = (MSIWHEREVIEW*)view;
106
107     TRACE("%p %d %p %08x\n", wv, row, rec, mask );
108
109     if( !wv->table )
110          return ERROR_FUNCTION_FAILED;
111
112     if( row > wv->row_count )
113         return ERROR_NO_MORE_ITEMS;
114
115     row = wv->reorder[ row ];
116
117     return wv->table->ops->set_row( wv->table, row, rec, mask );
118 }
119
120 static INT INT_evaluate_binary( INT lval, UINT op, INT rval )
121 {
122     switch( op )
123     {
124     case OP_EQ:
125         return ( lval == rval );
126     case OP_AND:
127         return ( lval && rval );
128     case OP_OR:
129         return ( lval || rval );
130     case OP_GT:
131         return ( lval > rval );
132     case OP_LT:
133         return ( lval < rval );
134     case OP_LE:
135         return ( lval <= rval );
136     case OP_GE:
137         return ( lval >= rval );
138     case OP_NE:
139         return ( lval != rval );
140     default:
141         ERR("Unknown operator %d\n", op );
142     }
143     return 0;
144 }
145
146 static INT INT_evaluate_unary( INT lval, UINT op )
147 {
148     switch( op )
149     {
150     case OP_ISNULL:
151         return ( !lval );
152     case OP_NOTNULL:
153         return ( lval );
154     default:
155         ERR("Unknown operator %d\n", op );
156     }
157     return 0;
158 }
159
160 static const WCHAR *STRING_evaluate( MSIWHEREVIEW *wv, UINT row,
161                                      const struct expr *expr,
162                                      const MSIRECORD *record )
163 {
164     UINT val = 0, r;
165
166     switch( expr->type )
167     {
168     case EXPR_COL_NUMBER_STRING:
169         r = wv->table->ops->fetch_int( wv->table, row, expr->u.col_number, &val );
170         if( r != ERROR_SUCCESS )
171             return NULL;
172         return msi_string_lookup_id( wv->db->strings, val );
173
174     case EXPR_SVAL:
175         return expr->u.sval;
176
177     case EXPR_WILDCARD:
178         return MSI_RecordGetString( record, ++wv->rec_index );
179
180     default:
181         ERR("Invalid expression type\n");
182         break;
183     }
184     return NULL;
185 }
186
187 static UINT STRCMP_Evaluate( MSIWHEREVIEW *wv, UINT row, const struct expr *cond,
188                              INT *val, const MSIRECORD *record )
189 {
190     int sr;
191     const WCHAR *l_str, *r_str;
192
193     l_str = STRING_evaluate( wv, row, cond->u.expr.left, record );
194     r_str = STRING_evaluate( wv, row, cond->u.expr.right, record );
195     if( l_str == r_str ||
196         ((!l_str || !*l_str) && (!r_str || !*r_str)) )
197         sr = 0;
198     else if( l_str && ! r_str )
199         sr = 1;
200     else if( r_str && ! l_str )
201         sr = -1;
202     else
203         sr = lstrcmpW( l_str, r_str );
204
205     *val = ( cond->u.expr.op == OP_EQ && ( sr == 0 ) ) ||
206            ( cond->u.expr.op == OP_LT && ( sr < 0 ) ) ||
207            ( cond->u.expr.op == OP_GT && ( sr > 0 ) );
208
209     return ERROR_SUCCESS;
210 }
211
212 static UINT WHERE_evaluate( MSIWHEREVIEW *wv, UINT row,
213                             struct expr *cond, INT *val, MSIRECORD *record )
214 {
215     UINT r, tval;
216     INT lval, rval;
217
218     if( !cond )
219         return ERROR_SUCCESS;
220
221     switch( cond->type )
222     {
223     case EXPR_COL_NUMBER:
224         r = wv->table->ops->fetch_int( wv->table, row, cond->u.col_number, &tval );
225         *val = tval - 0x8000;
226         return ERROR_SUCCESS;
227
228     case EXPR_COL_NUMBER32:
229         r = wv->table->ops->fetch_int( wv->table, row, cond->u.col_number, &tval );
230         *val = tval - 0x80000000;
231         return r;
232
233     case EXPR_UVAL:
234         *val = cond->u.uval;
235         return ERROR_SUCCESS;
236
237     case EXPR_COMPLEX:
238         r = WHERE_evaluate( wv, row, cond->u.expr.left, &lval, record );
239         if( r != ERROR_SUCCESS )
240             return r;
241         r = WHERE_evaluate( wv, row, cond->u.expr.right, &rval, record );
242         if( r != ERROR_SUCCESS )
243             return r;
244         *val = INT_evaluate_binary( lval, cond->u.expr.op, rval );
245         return ERROR_SUCCESS;
246
247     case EXPR_UNARY:
248         r = wv->table->ops->fetch_int( wv->table, row, cond->u.expr.left->u.col_number, &tval );
249         if( r != ERROR_SUCCESS )
250             return r;
251         *val = INT_evaluate_unary( tval, cond->u.expr.op );
252         return ERROR_SUCCESS;
253
254     case EXPR_STRCMP:
255         return STRCMP_Evaluate( wv, row, cond, val, record );
256
257     case EXPR_WILDCARD:
258         *val = MSI_RecordGetInteger( record, ++wv->rec_index );
259         return ERROR_SUCCESS;
260
261     default:
262         ERR("Invalid expression type\n");
263         break;
264     }
265
266     return ERROR_SUCCESS;
267
268 }
269
270 static UINT WHERE_execute( struct tagMSIVIEW *view, MSIRECORD *record )
271 {
272     MSIWHEREVIEW *wv = (MSIWHEREVIEW*)view;
273     UINT count = 0, r, i;
274     INT val;
275     MSIVIEW *table = wv->table;
276
277     TRACE("%p %p\n", wv, record);
278
279     if( !table )
280          return ERROR_FUNCTION_FAILED;
281
282     r = table->ops->execute( table, record );
283     if( r != ERROR_SUCCESS )
284         return r;
285
286     r = table->ops->get_dimensions( table, &count, NULL );
287     if( r != ERROR_SUCCESS )
288         return r;
289
290     msi_free( wv->reorder );
291     wv->reorder = msi_alloc( count*sizeof(UINT) );
292     if( !wv->reorder )
293         return ERROR_FUNCTION_FAILED;
294
295     wv->row_count = 0;
296     if (wv->cond->type == EXPR_STRCMP)
297     {
298         MSIITERHANDLE handle = NULL;
299         UINT row, value, col;
300         struct expr *col_cond = wv->cond->u.expr.left;
301         struct expr *val_cond = wv->cond->u.expr.right;
302
303         /* swap conditionals */
304         if (col_cond->type != EXPR_COL_NUMBER_STRING)
305         {
306             val_cond = wv->cond->u.expr.left;
307             col_cond = wv->cond->u.expr.right;
308         }
309
310         if ((col_cond->type == EXPR_COL_NUMBER_STRING) && (val_cond->type == EXPR_SVAL))
311         {
312             col = col_cond->u.col_number;
313             /* special case for "" - translate it into nil */
314             if (!val_cond->u.sval[0])
315                 value = 0;
316             else
317             {
318                 r = msi_string2idW(wv->db->strings, val_cond->u.sval, &value);
319                 if (r != ERROR_SUCCESS)
320                 {
321                     TRACE("no id for %s, assuming it doesn't exist in the table\n", debugstr_w(wv->cond->u.expr.right->u.sval));
322                     return ERROR_SUCCESS;
323                 }
324             }
325
326             do
327             {
328                 r = table->ops->find_matching_rows(table, col, value, &row, &handle);
329                 if (r == ERROR_SUCCESS)
330                     wv->reorder[ wv->row_count ++ ] = row;
331             } while (r == ERROR_SUCCESS);
332
333             if (r == ERROR_NO_MORE_ITEMS)
334                 return ERROR_SUCCESS;
335             else
336                 return r;
337         }
338         /* else fallback to slow case */
339     }
340
341     for( i=0; i<count; i++ )
342     {
343         val = 0;
344         wv->rec_index = 0;
345         r = WHERE_evaluate( wv, i, wv->cond, &val, record );
346         if( r != ERROR_SUCCESS )
347             return r;
348         if( val )
349             wv->reorder[ wv->row_count ++ ] = i;
350     }
351
352     return ERROR_SUCCESS;
353 }
354
355 static UINT WHERE_close( struct tagMSIVIEW *view )
356 {
357     MSIWHEREVIEW *wv = (MSIWHEREVIEW*)view;
358
359     TRACE("%p\n", wv );
360
361     if( !wv->table )
362          return ERROR_FUNCTION_FAILED;
363
364     msi_free( wv->reorder );
365     wv->reorder = NULL;
366
367     return wv->table->ops->close( wv->table );
368 }
369
370 static UINT WHERE_get_dimensions( struct tagMSIVIEW *view, UINT *rows, UINT *cols )
371 {
372     MSIWHEREVIEW *wv = (MSIWHEREVIEW*)view;
373
374     TRACE("%p %p %p\n", wv, rows, cols );
375
376     if( !wv->table )
377          return ERROR_FUNCTION_FAILED;
378
379     if( rows )
380     {
381         if( !wv->reorder )
382             return ERROR_FUNCTION_FAILED;
383         *rows = wv->row_count;
384     }
385
386     return wv->table->ops->get_dimensions( wv->table, NULL, cols );
387 }
388
389 static UINT WHERE_get_column_info( struct tagMSIVIEW *view,
390                 UINT n, LPWSTR *name, UINT *type )
391 {
392     MSIWHEREVIEW *wv = (MSIWHEREVIEW*)view;
393
394     TRACE("%p %d %p %p\n", wv, n, name, type );
395
396     if( !wv->table )
397          return ERROR_FUNCTION_FAILED;
398
399     return wv->table->ops->get_column_info( wv->table, n, name, type );
400 }
401
402 static UINT WHERE_modify( struct tagMSIVIEW *view, MSIMODIFY eModifyMode,
403                           MSIRECORD *rec, UINT row )
404 {
405     MSIWHEREVIEW *wv = (MSIWHEREVIEW*)view;
406
407     TRACE("%p %d %p\n", wv, eModifyMode, rec );
408
409     if( !wv->table )
410          return ERROR_FUNCTION_FAILED;
411
412     return wv->table->ops->modify( wv->table, eModifyMode, rec, row );
413 }
414
415 static UINT WHERE_delete( struct tagMSIVIEW *view )
416 {
417     MSIWHEREVIEW *wv = (MSIWHEREVIEW*)view;
418
419     TRACE("%p\n", wv );
420
421     if( wv->table )
422         wv->table->ops->delete( wv->table );
423     wv->table = 0;
424
425     msi_free( wv->reorder );
426     wv->reorder = NULL;
427     wv->row_count = 0;
428
429     msiobj_release( &wv->db->hdr );
430     msi_free( wv );
431
432     return ERROR_SUCCESS;
433 }
434
435 static UINT WHERE_find_matching_rows( struct tagMSIVIEW *view, UINT col,
436     UINT val, UINT *row, MSIITERHANDLE *handle )
437 {
438     MSIWHEREVIEW *wv = (MSIWHEREVIEW*)view;
439     UINT r;
440
441     TRACE("%p, %d, %u, %p\n", view, col, val, *handle);
442
443     if( !wv->table )
444          return ERROR_FUNCTION_FAILED;
445
446     r = wv->table->ops->find_matching_rows( wv->table, col, val, row, handle );
447
448     if( *row > wv->row_count )
449         return ERROR_NO_MORE_ITEMS;
450
451     *row = wv->reorder[ *row ];
452
453     return r;
454 }
455
456
457 static const MSIVIEWOPS where_ops =
458 {
459     WHERE_fetch_int,
460     WHERE_fetch_stream,
461     WHERE_get_row,
462     WHERE_set_row,
463     NULL,
464     NULL,
465     WHERE_execute,
466     WHERE_close,
467     WHERE_get_dimensions,
468     WHERE_get_column_info,
469     WHERE_modify,
470     WHERE_delete,
471     WHERE_find_matching_rows,
472     NULL,
473     NULL,
474     NULL,
475     NULL,
476 };
477
478 static UINT WHERE_VerifyCondition( MSIDATABASE *db, MSIVIEW *table, struct expr *cond,
479                                    UINT *valid )
480 {
481     UINT r, val = 0;
482
483     switch( cond->type )
484     {
485     case EXPR_COLUMN:
486         r = VIEW_find_column( table, cond->u.column, &val );
487         if( r == ERROR_SUCCESS )
488         {
489             UINT type = 0;
490             r = table->ops->get_column_info( table, val, NULL, &type );
491             if( r == ERROR_SUCCESS )
492             {
493                 if (type&MSITYPE_STRING)
494                     cond->type = EXPR_COL_NUMBER_STRING;
495                 else if ((type&0xff) == 4)
496                     cond->type = EXPR_COL_NUMBER32;
497                 else
498                     cond->type = EXPR_COL_NUMBER;
499                 cond->u.col_number = val;
500                 *valid = 1;
501             }
502             else
503                 *valid = 0;
504         }
505         else
506         {
507             *valid = 0;
508             ERR("Couldn't find column %s\n", debugstr_w( cond->u.column ) );
509         }
510         break;
511     case EXPR_COMPLEX:
512         r = WHERE_VerifyCondition( db, table, cond->u.expr.left, valid );
513         if( r != ERROR_SUCCESS )
514             return r;
515         if( !*valid )
516             return ERROR_SUCCESS;
517         r = WHERE_VerifyCondition( db, table, cond->u.expr.right, valid );
518         if( r != ERROR_SUCCESS )
519             return r;
520
521         /* check the type of the comparison */
522         if( ( cond->u.expr.left->type == EXPR_SVAL ) ||
523             ( cond->u.expr.left->type == EXPR_COL_NUMBER_STRING ) ||
524             ( cond->u.expr.right->type == EXPR_SVAL ) ||
525             ( cond->u.expr.right->type == EXPR_COL_NUMBER_STRING ) )
526         {
527             switch( cond->u.expr.op )
528             {
529             case OP_EQ:
530             case OP_GT:
531             case OP_LT:
532                 break;
533             default:
534                 *valid = FALSE;
535                 return ERROR_INVALID_PARAMETER;
536             }
537
538             /* FIXME: check we're comparing a string to a column */
539
540             cond->type = EXPR_STRCMP;
541         }
542
543         break;
544     case EXPR_UNARY:
545         if ( cond->u.expr.left->type != EXPR_COLUMN )
546         {
547             *valid = FALSE;
548             return ERROR_INVALID_PARAMETER;
549         }
550         r = WHERE_VerifyCondition( db, table, cond->u.expr.left, valid );
551         if( r != ERROR_SUCCESS )
552             return r;
553         break;
554     case EXPR_IVAL:
555         *valid = 1;
556         cond->type = EXPR_UVAL;
557         cond->u.uval = cond->u.ival;
558         break;
559     case EXPR_WILDCARD:
560         *valid = 1;
561         break;
562     case EXPR_SVAL:
563         *valid = 1;
564         break;
565     default:
566         ERR("Invalid expression type\n");
567         *valid = 0;
568         break;
569     }
570
571     return ERROR_SUCCESS;
572 }
573
574 UINT WHERE_CreateView( MSIDATABASE *db, MSIVIEW **view, MSIVIEW *table,
575                        struct expr *cond )
576 {
577     MSIWHEREVIEW *wv = NULL;
578     UINT count = 0, r, valid = 0;
579
580     TRACE("%p\n", table );
581
582     r = table->ops->get_dimensions( table, NULL, &count );
583     if( r != ERROR_SUCCESS )
584     {
585         ERR("can't get table dimensions\n");
586         return r;
587     }
588
589     if( cond )
590     {
591         r = WHERE_VerifyCondition( db, table, cond, &valid );
592         if( r != ERROR_SUCCESS )
593             return r;
594         if( !valid )
595             return ERROR_FUNCTION_FAILED;
596     }
597
598     wv = msi_alloc_zero( sizeof *wv );
599     if( !wv )
600         return ERROR_FUNCTION_FAILED;
601     
602     /* fill the structure */
603     wv->view.ops = &where_ops;
604     msiobj_addref( &db->hdr );
605     wv->db = db;
606     wv->table = table;
607     wv->row_count = 0;
608     wv->reorder = NULL;
609     wv->cond = cond;
610     wv->rec_index = 0;
611     *view = (MSIVIEW*) wv;
612
613     return ERROR_SUCCESS;
614 }