Avoid fixed length buffers for conditions.
[wine] / dlls / msi / string.c
1 /*
2  * Implementation of the Microsoft Installer (msi.dll)
3  *
4  * Copyright 2002-2004, Mike McCormack for CodeWeavers
5  *
6  * This library is free software; you can redistribute it and/or
7  * modify it under the terms of the GNU Lesser General Public
8  * License as published by the Free Software Foundation; either
9  * version 2.1 of the License, or (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14  * Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public
17  * License along with this library; if not, write to the Free Software
18  * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
19  */
20
21 #include <stdarg.h>
22 #include <assert.h>
23
24 #include "windef.h"
25 #include "winbase.h"
26 #include "winerror.h"
27 #include "wine/debug.h"
28 #include "wine/unicode.h"
29 #include "msi.h"
30 #include "msiquery.h"
31 #include "objbase.h"
32 #include "objidl.h"
33 #include "msipriv.h"
34 #include "winnls.h"
35
36 #include "query.h"
37
38 WINE_DEFAULT_DEBUG_CHANNEL(msi);
39
40 typedef struct _msistring
41 {
42     UINT hash;
43     UINT refcount;
44     LPWSTR str;
45 } msistring;
46
47 struct string_table
48 {
49     UINT maxcount;         /* the number of strings */
50     UINT freeslot;
51     UINT codepage;
52     msistring *strings; /* an array of strings (in the tree) */
53 };
54
55 static UINT msistring_makehash( const WCHAR *str )
56 {
57     UINT hash = 0;
58
59     if (str==NULL)
60         return hash;
61
62     while( *str )
63     {
64         hash ^= *str++;
65         hash *= 53;
66         hash = (hash<<5) | (hash>>27);
67     }
68     return hash;
69 }
70
71 string_table *msi_init_stringtable( int entries, UINT codepage )
72 {
73     string_table *st;
74
75     st = HeapAlloc( GetProcessHeap(), 0, sizeof (string_table) );
76     if( !st )
77         return NULL;    
78     st->strings = HeapAlloc( GetProcessHeap(), HEAP_ZERO_MEMORY,
79                               sizeof (msistring) * entries );
80     if( !st )
81     {
82         HeapFree( GetProcessHeap(), 0, st );
83         return NULL;    
84     }
85     if( entries < 1 )
86         entries = 1;
87     st->maxcount = entries;
88     st->freeslot = 1;
89     st->codepage = codepage;
90
91     return st;
92 }
93
94 VOID msi_destroy_stringtable( string_table *st )
95 {
96     UINT i;
97
98     for( i=0; i<st->maxcount; i++ )
99     {
100         if( st->strings[i].refcount )
101             HeapFree( GetProcessHeap(), 0, st->strings[i].str );
102     }
103     HeapFree( GetProcessHeap(), 0, st->strings );
104     HeapFree( GetProcessHeap(), 0, st );
105 }
106
107 static int st_find_free_entry( string_table *st )
108 {
109     UINT i, sz;
110     msistring *p;
111
112     TRACE("%p\n", st);
113
114     if( st->freeslot )
115     {
116         for( i = st->freeslot; i < st->maxcount; i++ )
117             if( !st->strings[i].refcount )
118                 return i;
119     }
120     for( i = 1; i < st->maxcount; i++ )
121         if( !st->strings[i].refcount )
122             return i;
123
124     /* dynamically resize */
125     sz = st->maxcount + 1 + st->maxcount/2;
126     p = HeapReAlloc( GetProcessHeap(), HEAP_ZERO_MEMORY,
127                      st->strings, sz*sizeof(msistring) );
128     if( !p )
129         return -1;
130     st->strings = p;
131     st->freeslot = st->maxcount;
132     st->maxcount = sz;
133     if( st->strings[st->freeslot].refcount )
134         ERR("oops. expected freeslot to be free...\n");
135     return st->freeslot;
136 }
137
138 static void st_mark_entry_used( string_table *st, UINT n )
139 {
140     if( n >= st->maxcount )
141         return;
142     st->freeslot = n + 1;
143 }
144
145 int msi_addstring( string_table *st, UINT n, const CHAR *data, int len, UINT refcount )
146 {
147     int sz;
148
149     if( !data )
150         return 0;
151     if( !data[0] )
152         return 0;
153     if( n > 0 )
154     {
155         if( st->strings[n].refcount )
156             return -1;
157     }
158     else
159     {
160         if( ERROR_SUCCESS == msi_string2idA( st, data, &n ) )
161         {
162             st->strings[n].refcount++;
163             return n;
164         }
165         n = st_find_free_entry( st );
166         if( n < 0 )
167             return -1;
168     }
169
170     if( n < 1 )
171     {
172         ERR("invalid index adding %s (%d)\n", debugstr_a( data ), n );
173         return -1;
174     }
175
176     /* allocate a new string */
177     if( len < 0 )
178         len = strlen(data);
179     sz = MultiByteToWideChar( st->codepage, 0, data, len, NULL, 0 );
180     st->strings[n].str = HeapAlloc( GetProcessHeap(), 0, (sz+1)*sizeof(WCHAR) );
181     if( !st->strings[n].str )
182         return -1;
183     MultiByteToWideChar( st->codepage, 0, data, len, st->strings[n].str, sz );
184     st->strings[n].str[sz] = 0;
185     st->strings[n].refcount = 1;
186     st->strings[n].hash = msistring_makehash( st->strings[n].str );
187
188     st_mark_entry_used( st, n );
189
190     return n;
191 }
192
193 int msi_addstringW( string_table *st, UINT n, const WCHAR *data, int len, UINT refcount )
194 {
195     /* TRACE("[%2d] = %s\n", string_no, debugstr_an(data,len) ); */
196
197     if( !data )
198         return 0;
199     if( !data[0] )
200         return 0;
201     if( n > 0 )
202     {
203         if( st->strings[n].refcount )
204             return -1;
205     }
206     else
207     {
208         if( ERROR_SUCCESS == msi_string2idW( st, data, &n ) )
209         {
210             st->strings[n].refcount++;
211             return n;
212         }
213         n = st_find_free_entry( st );
214         if( n < 0 )
215             return -1;
216     }
217
218     if( n < 1 )
219     {
220         ERR("invalid index adding %s (%d)\n", debugstr_w( data ), n );
221         return -1;
222     }
223
224     /* allocate a new string */
225     if(len<0)
226         len = strlenW(data);
227     TRACE("%s, n = %d len = %d\n", debugstr_w(data), n, len );
228
229     st->strings[n].str = HeapAlloc( GetProcessHeap(), 0, (len+1)*sizeof(WCHAR) );
230     if( !st->strings[n].str )
231         return -1;
232     TRACE("%d\n",__LINE__);
233     memcpy( st->strings[n].str, data, len*sizeof(WCHAR) );
234     st->strings[n].str[len] = 0;
235     st->strings[n].refcount = 1;
236     st->strings[n].hash = msistring_makehash( st->strings[n].str );
237
238     st_mark_entry_used( st, n );
239
240     return n;
241 }
242
243 /* find the string identified by an id - return null if there's none */
244 const WCHAR *msi_string_lookup_id( string_table *st, UINT id )
245 {
246     static const WCHAR zero[] = { 0 };
247     if( id == 0 )
248         return zero;
249
250     if( id >= st->maxcount )
251         return NULL;
252
253     if( id && !st->strings[id].refcount )
254         return NULL;
255
256     return st->strings[id].str;
257 }
258
259 /*
260  *  msi_id2stringW
261  *
262  *  [in] st         - pointer to the string table
263  *  [in] id  - id of the string to retrieve
264  *  [out] buffer    - destination of the string
265  *  [in/out] sz     - number of bytes available in the buffer on input
266  *                    number of bytes used on output
267  *
268  *   The size includes the terminating nul character.  Short buffers
269  *  will be filled, but not nul terminated.
270  */
271 UINT msi_id2stringW( string_table *st, UINT id, LPWSTR buffer, UINT *sz )
272 {
273     UINT len;
274     const WCHAR *str;
275
276     TRACE("Finding string %d of %d\n", id, st->maxcount);
277
278     str = msi_string_lookup_id( st, id );
279     if( !str )
280         return ERROR_FUNCTION_FAILED;
281
282     len = strlenW( str ) + 1;
283
284     if( !buffer )
285     {
286         *sz = len;
287         return ERROR_SUCCESS;
288     }
289
290     if( *sz < len )
291         *sz = len;
292     memcpy( buffer, str, (*sz)*sizeof(WCHAR) ); 
293     *sz = len;
294
295     return ERROR_SUCCESS;
296 }
297
298 /*
299  *  msi_id2stringA
300  *
301  *  [in] st         - pointer to the string table
302  *  [in] id         - id of the string to retrieve
303  *  [out] buffer    - destination of the UTF8 string
304  *  [in/out] sz     - number of bytes available in the buffer on input
305  *                    number of bytes used on output
306  *
307  *   The size includes the terminating nul character.  Short buffers
308  *  will be filled, but not nul terminated.
309  */
310 UINT msi_id2stringA( string_table *st, UINT id, LPSTR buffer, UINT *sz )
311 {
312     UINT len;
313     const WCHAR *str;
314     int n;
315
316     TRACE("Finding string %d of %d\n", id, st->maxcount);
317
318     str = msi_string_lookup_id( st, id );
319     if( !str )
320         return ERROR_FUNCTION_FAILED;
321
322     len = WideCharToMultiByte( st->codepage, 0, str, -1, NULL, 0, NULL, NULL );
323
324     if( !buffer )
325     {
326         *sz = len;
327         return ERROR_SUCCESS;
328     }
329
330     if( len > *sz )
331     {
332         n = strlenW( str ) + 1;
333         while( n && (len > *sz) )
334             len = WideCharToMultiByte( st->codepage, 0, 
335                            str, --n, NULL, 0, NULL, NULL );
336     }
337     else
338         n = -1;
339
340     *sz = WideCharToMultiByte( st->codepage, 0, str, n, buffer, len, NULL, NULL );
341
342     return ERROR_SUCCESS;
343 }
344
345 /*
346  *  msi_string2idW
347  *
348  *  [in] st         - pointer to the string table
349  *  [in] str        - string to find in the string table
350  *  [out] id        - id of the string, if found
351  */
352 UINT msi_string2idW( string_table *st, LPCWSTR str, UINT *id )
353 {
354     UINT hash;
355     UINT i, r = ERROR_INVALID_PARAMETER;
356
357     hash = msistring_makehash( str );
358     for( i=0; i<st->maxcount; i++ )
359     {
360         if ( (str == NULL && st->strings[i].str == NULL) || 
361             ( ( st->strings[i].hash == hash ) &&
362             !strcmpW( st->strings[i].str, str ) ))
363         {
364             r = ERROR_SUCCESS;
365             *id = i;
366             break;
367         }
368     }
369
370     return r;
371 }
372
373 UINT msi_string2idA( string_table *st, LPCSTR buffer, UINT *id )
374 {
375     DWORD sz;
376     UINT r = ERROR_INVALID_PARAMETER;
377     LPWSTR str;
378
379     TRACE("Finding string %s in string table\n", debugstr_a(buffer) );
380
381     if( buffer[0] == 0 )
382     {
383         *id = 0;
384         return ERROR_SUCCESS;
385     }
386
387     sz = MultiByteToWideChar( st->codepage, 0, buffer, -1, NULL, 0 );
388     if( sz <= 0 )
389         return r;
390     str = HeapAlloc( GetProcessHeap(), 0, sz*sizeof(WCHAR) );
391     if( !str )
392         return ERROR_NOT_ENOUGH_MEMORY;
393     MultiByteToWideChar( st->codepage, 0, buffer, -1, str, sz );
394
395     r = msi_string2idW( st, str, id );
396     HeapFree( GetProcessHeap(), 0, str );
397
398     return r;
399 }
400
401 UINT msi_strcmp( string_table *st, UINT lval, UINT rval, UINT *res )
402 {
403     const WCHAR *l_str, *r_str;
404
405     l_str = msi_string_lookup_id( st, lval );
406     if( !l_str )
407         return ERROR_INVALID_PARAMETER;
408     
409     r_str = msi_string_lookup_id( st, rval );
410     if( !r_str )
411         return ERROR_INVALID_PARAMETER;
412
413     /* does this do the right thing for all UTF-8 strings? */
414     *res = strcmpW( l_str, r_str );
415
416     return ERROR_SUCCESS;
417 }
418
419 UINT msi_string_count( string_table *st )
420 {
421     return st->maxcount;
422 }
423
424 UINT msi_id_refcount( string_table *st, UINT i )
425 {
426     if( i >= st->maxcount )
427         return 0;
428     return st->strings[i].refcount;
429 }
430
431 UINT msi_string_totalsize( string_table *st, UINT *total )
432 {
433     UINT size = 0, i, len;
434
435     if( st->strings[0].str || st->strings[0].refcount )
436         ERR("oops. element 0 has a string\n");
437     *total = 0;
438     for( i=1; i<st->maxcount; i++ )
439     {
440         if( st->strings[i].str )
441         {
442             TRACE("[%u] = %s\n", i, debugstr_w(st->strings[i].str));
443             len = WideCharToMultiByte( st->codepage, 0,
444                      st->strings[i].str, -1, NULL, 0, NULL, NULL);
445             if( len )
446                 len--;
447             size += len;
448             *total = (i+1);
449         }
450     }
451     TRACE("%u/%u strings %u bytes codepage %x\n", *total, st->maxcount, size, st->codepage );
452     return size;
453 }
454
455 UINT msi_string_get_codepage( string_table *st )
456 {
457     return st->codepage;
458 }