vbscript: Added more equality expressions parser/compiler implementation.
[wine] / dlls / msi / join.c
1 /*
2  * Implementation of the Microsoft Installer (msi.dll)
3  *
4  * Copyright 2006 Mike McCormack for CodeWeavers
5  *
6  * This library is free software; you can redistribute it and/or
7  * modify it under the terms of the GNU Lesser General Public
8  * License as published by the Free Software Foundation; either
9  * version 2.1 of the License, or (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14  * Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public
17  * License along with this library; if not, write to the Free Software
18  * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
19  */
20
21 #include <stdarg.h>
22
23 #include "windef.h"
24 #include "winbase.h"
25 #include "winerror.h"
26 #include "msi.h"
27 #include "msiquery.h"
28 #include "objbase.h"
29 #include "objidl.h"
30 #include "msipriv.h"
31 #include "query.h"
32
33 #include "wine/debug.h"
34 #include "wine/unicode.h"
35
36 WINE_DEFAULT_DEBUG_CHANNEL(msidb);
37
38 typedef struct tagJOINTABLE
39 {
40     struct list entry;
41     MSIVIEW *view;
42     UINT columns;
43     UINT rows;
44     UINT next_rows;
45 } JOINTABLE;
46
47 typedef struct tagMSIJOINVIEW
48 {
49     MSIVIEW        view;
50     MSIDATABASE   *db;
51     struct list    tables;
52     UINT           columns;
53     UINT           rows;
54 } MSIJOINVIEW;
55
56 static UINT JOIN_fetch_int( struct tagMSIVIEW *view, UINT row, UINT col, UINT *val )
57 {
58     MSIJOINVIEW *jv = (MSIJOINVIEW*)view;
59     JOINTABLE *table;
60     UINT cols = 0;
61     UINT prev_rows = 1;
62
63     if (col == 0 || col > jv->columns)
64          return ERROR_FUNCTION_FAILED;
65
66     if (row >= jv->rows)
67          return ERROR_FUNCTION_FAILED;
68
69     LIST_FOR_EACH_ENTRY(table, &jv->tables, JOINTABLE, entry)
70     {
71         if (col <= cols + table->columns)
72         {
73             row = (row % (jv->rows / table->next_rows)) / prev_rows;
74             col -= cols;
75             break;
76         }
77
78         prev_rows *= table->rows;
79         cols += table->columns;
80     }
81
82     return table->view->ops->fetch_int( table->view, row, col, val );
83 }
84
85 static UINT JOIN_fetch_stream( struct tagMSIVIEW *view, UINT row, UINT col, IStream **stm)
86 {
87     MSIJOINVIEW *jv = (MSIJOINVIEW*)view;
88     JOINTABLE *table;
89     UINT cols = 0;
90     UINT prev_rows = 1;
91
92     TRACE("%p %d %d %p\n", jv, row, col, stm );
93
94     if (col == 0 || col > jv->columns)
95          return ERROR_FUNCTION_FAILED;
96
97     if (row >= jv->rows)
98          return ERROR_FUNCTION_FAILED;
99
100     LIST_FOR_EACH_ENTRY(table, &jv->tables, JOINTABLE, entry)
101     {
102         if (col <= cols + table->columns)
103         {
104             row = (row % (jv->rows / table->next_rows)) / prev_rows;
105             col -= cols;
106             break;
107         }
108
109         prev_rows *= table->rows;
110         cols += table->columns;
111     }
112
113     return table->view->ops->fetch_stream( table->view, row, col, stm );
114 }
115
116 static UINT JOIN_get_row( struct tagMSIVIEW *view, UINT row, MSIRECORD **rec )
117 {
118     MSIJOINVIEW *jv = (MSIJOINVIEW*)view;
119
120     TRACE("%p %d %p\n", jv, row, rec);
121
122     return msi_view_get_row( jv->db, view, row, rec );
123 }
124
125 static UINT JOIN_execute( struct tagMSIVIEW *view, MSIRECORD *record )
126 {
127     MSIJOINVIEW *jv = (MSIJOINVIEW*)view;
128     JOINTABLE *table;
129     UINT r, rows;
130
131     TRACE("%p %p\n", jv, record);
132
133     LIST_FOR_EACH_ENTRY(table, &jv->tables, JOINTABLE, entry)
134     {
135         table->view->ops->execute(table->view, NULL);
136
137         r = table->view->ops->get_dimensions(table->view, &table->rows, NULL);
138         if (r != ERROR_SUCCESS)
139         {
140             ERR("failed to get table dimensions\n");
141             return r;
142         }
143
144         /* each table must have at least one row */
145         if (table->rows == 0)
146         {
147             jv->rows = 0;
148             return ERROR_SUCCESS;
149         }
150
151         if (jv->rows == 0)
152             jv->rows = table->rows;
153         else
154             jv->rows *= table->rows;
155     }
156
157     rows = jv->rows;
158     LIST_FOR_EACH_ENTRY(table, &jv->tables, JOINTABLE, entry)
159     {
160         rows /= table->rows;
161         table->next_rows = rows;
162     }
163
164     return ERROR_SUCCESS;
165 }
166
167 static UINT JOIN_close( struct tagMSIVIEW *view )
168 {
169     MSIJOINVIEW *jv = (MSIJOINVIEW*)view;
170     JOINTABLE *table;
171
172     TRACE("%p\n", jv );
173
174     LIST_FOR_EACH_ENTRY(table, &jv->tables, JOINTABLE, entry)
175     {
176         table->view->ops->close(table->view);
177     }
178
179     return ERROR_SUCCESS;
180 }
181
182 static UINT JOIN_get_dimensions( struct tagMSIVIEW *view, UINT *rows, UINT *cols )
183 {
184     MSIJOINVIEW *jv = (MSIJOINVIEW*)view;
185
186     TRACE("%p %p %p\n", jv, rows, cols );
187
188     if (cols)
189         *cols = jv->columns;
190
191     if (rows)
192         *rows = jv->rows;
193
194     return ERROR_SUCCESS;
195 }
196
197 static UINT JOIN_get_column_info( struct tagMSIVIEW *view, UINT n, LPCWSTR *name,
198                                   UINT *type, BOOL *temporary, LPCWSTR *table_name )
199 {
200     MSIJOINVIEW *jv = (MSIJOINVIEW*)view;
201     JOINTABLE *table;
202     UINT cols = 0;
203
204     TRACE("%p %d %p %p %p %p\n", jv, n, name, type, temporary, table_name );
205
206     if (n == 0 || n > jv->columns)
207         return ERROR_FUNCTION_FAILED;
208
209     LIST_FOR_EACH_ENTRY(table, &jv->tables, JOINTABLE, entry)
210     {
211         if (n <= cols + table->columns)
212             return table->view->ops->get_column_info(table->view, n - cols,
213                                                      name, type, temporary,
214                                                      table_name);
215         cols += table->columns;
216     }
217
218     return ERROR_FUNCTION_FAILED;
219 }
220
221 static UINT join_find_row( MSIJOINVIEW *jv, MSIRECORD *rec, UINT *row )
222 {
223     LPCWSTR str;
224     UINT r, i, id, data;
225
226     str = MSI_RecordGetString( rec, 1 );
227     r = msi_string2idW( jv->db->strings, str, &id );
228     if (r != ERROR_SUCCESS)
229         return r;
230
231     for (i = 0; i < jv->rows; i++)
232     {
233         JOIN_fetch_int( &jv->view, i, 1, &data );
234
235         if (data == id)
236         {
237             *row = i;
238             return ERROR_SUCCESS;
239         }
240     }
241
242     return ERROR_FUNCTION_FAILED;
243 }
244
245 static UINT JOIN_set_row( struct tagMSIVIEW *view, UINT row, MSIRECORD *rec, UINT mask )
246 {
247     MSIJOINVIEW *jv = (MSIJOINVIEW*)view;
248     JOINTABLE *table;
249     UINT i, reduced_mask = 0, r = ERROR_SUCCESS, offset = 0, col_count;
250     MSIRECORD *reduced;
251
252     TRACE("%p %d %p %u %08x\n", jv, row, rec, rec->count, mask );
253
254     if (mask >= 1 << jv->columns)
255         return ERROR_INVALID_PARAMETER;
256
257     LIST_FOR_EACH_ENTRY(table, &jv->tables, JOINTABLE, entry)
258     {
259         r = table->view->ops->get_dimensions( table->view, NULL, &col_count );
260         if (r != ERROR_SUCCESS)
261             return r;
262
263         reduced = MSI_CreateRecord( col_count );
264         if (!reduced)
265             return ERROR_FUNCTION_FAILED;
266
267         for (i = 0; i < col_count; i++)
268         {
269             r = MSI_RecordCopyField( rec, i + offset + 1, reduced, i + 1 );
270             if (r != ERROR_SUCCESS)
271                 break;
272         }
273
274         offset += col_count;
275         reduced_mask = mask >> (jv->columns - offset) & ((1 << col_count) - 1);
276
277         if (r == ERROR_SUCCESS)
278             r = table->view->ops->set_row( table->view, row, reduced, reduced_mask );
279
280         msiobj_release( &reduced->hdr );
281     }
282
283     return r;
284 }
285
286 static UINT join_modify_update( struct tagMSIVIEW *view, MSIRECORD *rec )
287 {
288     MSIJOINVIEW *jv = (MSIJOINVIEW *)view;
289     UINT r, row;
290
291     r = join_find_row( jv, rec, &row );
292     if (r != ERROR_SUCCESS)
293         return r;
294
295     return JOIN_set_row( view, row, rec, (1 << jv->columns) - 1 );
296 }
297
298 static UINT JOIN_modify( struct tagMSIVIEW *view, MSIMODIFY mode, MSIRECORD *rec, UINT row )
299 {
300     UINT r;
301
302     TRACE("%p %d %p %u\n", view, mode, rec, row);
303
304     switch (mode)
305     {
306     case MSIMODIFY_UPDATE:
307         return join_modify_update( view, rec );
308
309     case MSIMODIFY_ASSIGN:
310     case MSIMODIFY_DELETE:
311     case MSIMODIFY_INSERT:
312     case MSIMODIFY_INSERT_TEMPORARY:
313     case MSIMODIFY_MERGE:
314     case MSIMODIFY_REPLACE:
315     case MSIMODIFY_SEEK:
316     case MSIMODIFY_VALIDATE:
317     case MSIMODIFY_VALIDATE_DELETE:
318     case MSIMODIFY_VALIDATE_FIELD:
319     case MSIMODIFY_VALIDATE_NEW:
320         r = ERROR_FUNCTION_FAILED;
321         break;
322
323     case MSIMODIFY_REFRESH:
324         r = ERROR_CALL_NOT_IMPLEMENTED;
325         break;
326
327     default:
328         WARN("%p %d %p %u - unknown mode\n", view, mode, rec, row );
329         r = ERROR_INVALID_PARAMETER;
330         break;
331     }
332
333     return r;
334 }
335
336 static UINT JOIN_delete( struct tagMSIVIEW *view )
337 {
338     MSIJOINVIEW *jv = (MSIJOINVIEW*)view;
339     struct list *item, *cursor;
340
341     TRACE("%p\n", jv );
342
343     LIST_FOR_EACH_SAFE(item, cursor, &jv->tables)
344     {
345         JOINTABLE* table = LIST_ENTRY(item, JOINTABLE, entry);
346
347         list_remove(&table->entry);
348         table->view->ops->delete(table->view);
349         table->view = NULL;
350         msi_free(table);
351     }
352
353     msi_free(jv);
354
355     return ERROR_SUCCESS;
356 }
357
358 static UINT JOIN_find_matching_rows( struct tagMSIVIEW *view, UINT col,
359     UINT val, UINT *row, MSIITERHANDLE *handle )
360 {
361     MSIJOINVIEW *jv = (MSIJOINVIEW*)view;
362     UINT i, row_value;
363
364     TRACE("%p, %d, %u, %p\n", view, col, val, *handle);
365
366     if (col == 0 || col > jv->columns)
367         return ERROR_INVALID_PARAMETER;
368
369     for (i = PtrToUlong(*handle); i < jv->rows; i++)
370     {
371         if (view->ops->fetch_int( view, i, col, &row_value ) != ERROR_SUCCESS)
372             continue;
373
374         if (row_value == val)
375         {
376             *row = i;
377             (*(UINT *)handle) = i + 1;
378             return ERROR_SUCCESS;
379         }
380     }
381
382     return ERROR_NO_MORE_ITEMS;
383 }
384
385 static UINT JOIN_sort(struct tagMSIVIEW *view, column_info *columns)
386 {
387     MSIJOINVIEW *jv = (MSIJOINVIEW *)view;
388     JOINTABLE *table;
389     UINT r;
390
391     TRACE("%p %p\n", view, columns);
392
393     LIST_FOR_EACH_ENTRY(table, &jv->tables, JOINTABLE, entry)
394     {
395         r = table->view->ops->sort(table->view, columns);
396         if (r != ERROR_SUCCESS)
397             return r;
398     }
399
400     return ERROR_SUCCESS;
401 }
402
403 static const MSIVIEWOPS join_ops =
404 {
405     JOIN_fetch_int,
406     JOIN_fetch_stream,
407     JOIN_get_row,
408     NULL,
409     NULL,
410     NULL,
411     JOIN_execute,
412     JOIN_close,
413     JOIN_get_dimensions,
414     JOIN_get_column_info,
415     JOIN_modify,
416     JOIN_delete,
417     JOIN_find_matching_rows,
418     NULL,
419     NULL,
420     NULL,
421     NULL,
422     JOIN_sort,
423     NULL,
424 };
425
426 UINT JOIN_CreateView( MSIDATABASE *db, MSIVIEW **view, LPWSTR tables )
427 {
428     MSIJOINVIEW *jv = NULL;
429     UINT r = ERROR_SUCCESS;
430     JOINTABLE *table;
431     LPWSTR ptr;
432
433     TRACE("%p (%s)\n", jv, debugstr_w(tables) );
434
435     jv = msi_alloc_zero( sizeof *jv );
436     if( !jv )
437         return ERROR_FUNCTION_FAILED;
438
439     /* fill the structure */
440     jv->view.ops = &join_ops;
441     jv->db = db;
442     jv->columns = 0;
443     jv->rows = 0;
444
445     list_init(&jv->tables);
446
447     while (*tables)
448     {
449         if ((ptr = strchrW(tables, ' ')))
450             *ptr = '\0';
451
452         table = msi_alloc(sizeof(JOINTABLE));
453         if (!table)
454         {
455             r = ERROR_OUTOFMEMORY;
456             goto end;
457         }
458
459         r = TABLE_CreateView( db, tables, &table->view );
460         if( r != ERROR_SUCCESS )
461         {
462             WARN("can't create table: %s\n", debugstr_w(tables));
463             msi_free(table);
464             r = ERROR_BAD_QUERY_SYNTAX;
465             goto end;
466         }
467
468         r = table->view->ops->get_dimensions( table->view, NULL,
469                                               &table->columns );
470         if( r != ERROR_SUCCESS )
471         {
472             ERR("can't get table dimensions\n");
473             goto end;
474         }
475
476         jv->columns += table->columns;
477
478         list_add_head( &jv->tables, &table->entry );
479
480         if (!ptr)
481             break;
482
483         tables = ptr + 1;
484     }
485
486     *view = &jv->view;
487     return ERROR_SUCCESS;
488
489 end:
490     jv->view.ops->delete( &jv->view );
491
492     return r;
493 }