Make the SQL insert query work.
[wine] / dlls / msi / where.c
1 /*
2  * Implementation of the Microsoft Installer (msi.dll)
3  *
4  * Copyright 2002 Mike McCormack for CodeWeavers
5  *
6  * This library is free software; you can redistribute it and/or
7  * modify it under the terms of the GNU Lesser General Public
8  * License as published by the Free Software Foundation; either
9  * version 2.1 of the License, or (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14  * Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public
17  * License along with this library; if not, write to the Free Software
18  * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
19  */
20
21 #include <stdarg.h>
22
23 #include "windef.h"
24 #include "winbase.h"
25 #include "winerror.h"
26 #include "wine/debug.h"
27 #include "wine/unicode.h"
28 #include "msi.h"
29 #include "msiquery.h"
30 #include "objbase.h"
31 #include "objidl.h"
32 #include "msipriv.h"
33 #include "winnls.h"
34
35 #include "query.h"
36
37 WINE_DEFAULT_DEBUG_CHANNEL(msi);
38
39
40 /* below is the query interface to a table */
41
42 typedef struct tagMSIWHEREVIEW
43 {
44     MSIVIEW        view;
45     MSIDATABASE   *db;
46     MSIVIEW       *table;
47     UINT           row_count;
48     UINT          *reorder;
49     struct expr   *cond;
50 } MSIWHEREVIEW;
51
52 static UINT WHERE_fetch_int( struct tagMSIVIEW *view, UINT row, UINT col, UINT *val )
53 {
54     MSIWHEREVIEW *wv = (MSIWHEREVIEW*)view;
55
56     TRACE("%p %d %d %p\n", wv, row, col, val );
57
58     if( !wv->table )
59         return ERROR_FUNCTION_FAILED;
60
61     if( row > wv->row_count )
62         return ERROR_NO_MORE_ITEMS;
63
64     row = wv->reorder[ row ];
65
66     return wv->table->ops->fetch_int( wv->table, row, col, val );
67 }
68
69 static UINT INT_evaluate( UINT lval, UINT op, UINT rval )
70 {
71     switch( op )
72     {
73     case OP_EQ:
74         return ( lval == rval );
75     case OP_AND:
76         return ( lval && rval );
77     case OP_OR:
78         return ( lval || rval );
79     case OP_GT:
80         return ( lval > rval );
81     case OP_LT:
82         return ( lval < rval );
83     case OP_LE:
84         return ( lval <= rval );
85     case OP_GE:
86         return ( lval >= rval );
87     case OP_NE:
88         return ( lval != rval );
89     case OP_ISNULL:
90         return ( !lval );
91     case OP_NOTNULL:
92         return ( lval );
93     default:
94         ERR("Unknown operator %d\n", op );
95     }
96     return 0;
97 }
98
99 static const WCHAR *STRING_evaluate( string_table *st,
100               MSIVIEW *table, UINT row, struct expr *expr, MSIHANDLE record )
101 {
102     UINT val = 0, r;
103
104     switch( expr->type )
105     {
106     case EXPR_COL_NUMBER:
107         r = table->ops->fetch_int( table, row, expr->u.col_number, &val );
108         if( r != ERROR_SUCCESS )
109             return NULL;
110         return msi_string_lookup_id( st, val );
111
112     case EXPR_SVAL:
113         return expr->u.sval;
114
115     case EXPR_WILDCARD:
116         return MSI_RecordGetString( record, 1 );
117
118     default:
119         ERR("Invalid expression type\n");
120         break;
121     }
122     return NULL;
123 }
124
125 static UINT STRCMP_Evaluate( string_table *st, MSIVIEW *table, UINT row, 
126                              struct expr *cond, UINT *val, MSIHANDLE record )
127 {
128     int sr;
129     const WCHAR *l_str, *r_str;
130
131     l_str = STRING_evaluate( st, table, row, cond->u.expr.left, record );
132     r_str = STRING_evaluate( st, table, row, cond->u.expr.right, record );
133     if( l_str == r_str )
134         sr = 0;
135     else if( l_str && ! r_str )
136         sr = 1;
137     else if( r_str && ! l_str )
138         sr = -1;
139     else
140         sr = strcmpW( l_str, r_str );
141
142     *val = ( cond->u.expr.op == OP_EQ && ( sr == 0 ) ) ||
143            ( cond->u.expr.op == OP_LT && ( sr < 0 ) ) ||
144            ( cond->u.expr.op == OP_GT && ( sr > 0 ) );
145
146     return ERROR_SUCCESS;
147 }
148
149 static UINT WHERE_evaluate( MSIDATABASE *db, MSIVIEW *table, UINT row, 
150                              struct expr *cond, UINT *val, MSIHANDLE record )
151 {
152     UINT r, lval, rval;
153
154     if( !cond )
155         return ERROR_SUCCESS;
156
157     switch( cond->type )
158     {
159     case EXPR_COL_NUMBER:
160         return table->ops->fetch_int( table, row, cond->u.col_number, val );
161
162     case EXPR_UVAL:
163         *val = cond->u.uval;
164         return ERROR_SUCCESS;
165
166     case EXPR_COMPLEX:
167         r = WHERE_evaluate( db, table, row, cond->u.expr.left, &lval, record );
168         if( r != ERROR_SUCCESS )
169             return r;
170         r = WHERE_evaluate( db, table, row, cond->u.expr.right, &rval, record );
171         if( r != ERROR_SUCCESS )
172             return r;
173         *val = INT_evaluate( lval, cond->u.expr.op, rval );
174         return ERROR_SUCCESS;
175
176     case EXPR_STRCMP:
177         return STRCMP_Evaluate( db->strings, table, row, cond, val, record );
178
179     case EXPR_WILDCARD:
180         *val = MsiRecordGetInteger( record, 1 );
181         return ERROR_SUCCESS;
182
183     default:
184         ERR("Invalid expression type\n");
185         break;
186     } 
187
188     return ERROR_SUCCESS;
189
190 }
191
192 static UINT WHERE_execute( struct tagMSIVIEW *view, MSIHANDLE record )
193 {
194     MSIWHEREVIEW *wv = (MSIWHEREVIEW*)view;
195     UINT count = 0, r, val, i;
196     MSIVIEW *table = wv->table;
197
198     TRACE("%p %ld\n", wv, record);
199
200     if( !table )
201          return ERROR_FUNCTION_FAILED;
202
203     r = table->ops->execute( table, record );
204     if( r != ERROR_SUCCESS )
205         return r;
206
207     r = table->ops->get_dimensions( table, &count, NULL );
208     if( r != ERROR_SUCCESS )
209         return r;
210
211     wv->reorder = HeapAlloc( GetProcessHeap(), 0, count*sizeof(UINT) );
212     if( !wv->reorder )
213         return ERROR_FUNCTION_FAILED;
214
215     for( i=0; i<count; i++ )
216     {
217         val = 0;
218         r = WHERE_evaluate( wv->db, table, i, wv->cond, &val, record );
219         if( r != ERROR_SUCCESS )
220             return r;
221         if( val )
222             wv->reorder[ wv->row_count ++ ] = i;
223     }
224
225     return ERROR_SUCCESS;
226 }
227
228 static UINT WHERE_close( struct tagMSIVIEW *view )
229 {
230     MSIWHEREVIEW *wv = (MSIWHEREVIEW*)view;
231
232     TRACE("%p\n", wv );
233
234     if( !wv->table )
235          return ERROR_FUNCTION_FAILED;
236
237     if( wv->reorder )
238         HeapFree( GetProcessHeap(), 0, wv->reorder );
239     wv->reorder = NULL;
240
241     return wv->table->ops->close( wv->table );
242 }
243
244 static UINT WHERE_get_dimensions( struct tagMSIVIEW *view, UINT *rows, UINT *cols )
245 {
246     MSIWHEREVIEW *wv = (MSIWHEREVIEW*)view;
247
248     TRACE("%p %p %p\n", wv, rows, cols );
249
250     if( !wv->table )
251          return ERROR_FUNCTION_FAILED;
252
253     if( rows )
254     {
255         if( !wv->reorder )
256             return ERROR_FUNCTION_FAILED;
257         *rows = wv->row_count;
258     }
259
260     return wv->table->ops->get_dimensions( wv->table, NULL, cols );
261 }
262
263 static UINT WHERE_get_column_info( struct tagMSIVIEW *view,
264                 UINT n, LPWSTR *name, UINT *type )
265 {
266     MSIWHEREVIEW *wv = (MSIWHEREVIEW*)view;
267
268     TRACE("%p %d %p %p\n", wv, n, name, type );
269
270     if( !wv->table )
271          return ERROR_FUNCTION_FAILED;
272
273     return wv->table->ops->get_column_info( wv->table, n, name, type );
274 }
275
276 static UINT WHERE_modify( struct tagMSIVIEW *view, MSIMODIFY eModifyMode, MSIHANDLE hrec)
277 {
278     MSIWHEREVIEW *wv = (MSIWHEREVIEW*)view;
279
280     TRACE("%p %d %ld\n", wv, eModifyMode, hrec );
281
282     if( !wv->table )
283          return ERROR_FUNCTION_FAILED;
284
285     return wv->table->ops->modify( wv->table, eModifyMode, hrec );
286 }
287
288 static UINT WHERE_delete( struct tagMSIVIEW *view )
289 {
290     MSIWHEREVIEW *wv = (MSIWHEREVIEW*)view;
291
292     TRACE("%p\n", wv );
293
294     if( wv->table )
295         wv->table->ops->delete( wv->table );
296
297     if( wv->reorder )
298         HeapFree( GetProcessHeap(), 0, wv->reorder );
299     wv->reorder = NULL;
300     wv->row_count = 0;
301
302     if( wv->cond )
303         delete_expr( wv->cond );
304
305     HeapFree( GetProcessHeap(), 0, wv );
306
307     return ERROR_SUCCESS;
308 }
309
310
311 MSIVIEWOPS where_ops =
312 {
313     WHERE_fetch_int,
314     NULL,
315     NULL,
316     WHERE_execute,
317     WHERE_close,
318     WHERE_get_dimensions,
319     WHERE_get_column_info,
320     WHERE_modify,
321     WHERE_delete
322 };
323
324 UINT WHERE_CreateView( MSIDATABASE *db, MSIVIEW **view, MSIVIEW *table )
325 {
326     MSIWHEREVIEW *wv = NULL;
327     UINT count = 0, r;
328
329     TRACE("%p\n", wv );
330
331     r = table->ops->get_dimensions( table, NULL, &count );
332     if( r != ERROR_SUCCESS )
333     {
334         ERR("can't get table dimensions\n");
335         return r;
336     }
337
338     wv = HeapAlloc( GetProcessHeap(), HEAP_ZERO_MEMORY, sizeof *wv );
339     if( !wv )
340         return ERROR_FUNCTION_FAILED;
341     
342     /* fill the structure */
343     wv->view.ops = &where_ops;
344     wv->db = db;
345     wv->table = table;
346     wv->row_count = 0;
347     wv->reorder = NULL;
348     wv->cond = NULL;
349     *view = (MSIVIEW*) wv;
350
351     return ERROR_SUCCESS;
352 }
353
354 static UINT WHERE_VerifyCondition( MSIDATABASE *db, MSIVIEW *table, struct expr *cond,
355                                    UINT *valid )
356 {
357     UINT r, val = 0;
358
359     switch( cond->type )
360     {
361     case EXPR_COLUMN:
362         r = VIEW_find_column( table, cond->u.column, &val );
363         if( r == ERROR_SUCCESS )
364         {
365             *valid = 1;
366             cond->type = EXPR_COL_NUMBER;
367             cond->u.col_number = val;
368         }
369         else
370         {
371             *valid = 0;
372             ERR("Couldn't find column %s\n", debugstr_w( cond->u.column ) );
373         }
374         break;
375     case EXPR_COMPLEX:
376         r = WHERE_VerifyCondition( db, table, cond->u.expr.left, valid );
377         if( r != ERROR_SUCCESS )
378             return r;
379         if( !*valid )
380             return ERROR_SUCCESS;
381         r = WHERE_VerifyCondition( db, table, cond->u.expr.right, valid );
382         if( r != ERROR_SUCCESS )
383             return r;
384
385         /* check the type of the comparison */
386         if( ( cond->u.expr.left->type == EXPR_SVAL ) ||
387             ( cond->u.expr.right->type == EXPR_SVAL ) )
388         {
389             switch( cond->u.expr.op )
390             {
391             case OP_EQ:
392             case OP_GT:
393             case OP_LT:
394                 break;
395             default:
396                 *valid = FALSE;
397                 return ERROR_INVALID_PARAMETER;
398             }
399
400             /* FIXME: check we're comparing a string to a column */
401
402             cond->type = EXPR_STRCMP;
403         }
404
405         break;
406     case EXPR_IVAL:
407         *valid = 1;
408         cond->type = EXPR_UVAL;
409         cond->u.uval = cond->u.ival + (1<<15);
410         break;
411     case EXPR_WILDCARD:
412         *valid = 1;
413         break;
414     case EXPR_SVAL:
415         *valid = 1;
416         break;
417     default:
418         ERR("Invalid expression type\n");
419         *valid = 0;
420         break;
421     } 
422
423     return ERROR_SUCCESS;
424 }
425
426 UINT WHERE_AddCondition( MSIVIEW *view, struct expr *cond )
427 {
428     MSIWHEREVIEW *wv = (MSIWHEREVIEW *) view;
429     UINT r, valid = 0;
430
431     if( wv->view.ops != &where_ops )
432         return ERROR_FUNCTION_FAILED;
433     if( !wv->table )
434         return ERROR_INVALID_PARAMETER;
435     
436     if( !cond )
437         return ERROR_SUCCESS;
438
439     TRACE("Adding condition\n");
440
441     r = WHERE_VerifyCondition( wv->db, wv->table, cond, &valid );
442     if( r != ERROR_SUCCESS )
443         ERR("condition evaluation failed\n");
444
445     TRACE("condition is %s\n", valid ? "valid" : "invalid" );
446     if( !valid )
447     {
448         delete_expr( cond );
449         return ERROR_FUNCTION_FAILED;
450     }
451
452     wv->cond = cond;
453
454     return ERROR_SUCCESS;
455 }