user32: Remove a unused variable from ES_PASSWORD test.
[wine] / dlls / msi / string.c
1 /*
2  * String Table Functions
3  *
4  * Copyright 2002-2004, Mike McCormack for CodeWeavers
5  * Copyright 2007 Robert Shearman for CodeWeavers
6  *
7  * This library is free software; you can redistribute it and/or
8  * modify it under the terms of the GNU Lesser General Public
9  * License as published by the Free Software Foundation; either
10  * version 2.1 of the License, or (at your option) any later version.
11  *
12  * This library is distributed in the hope that it will be useful,
13  * but WITHOUT ANY WARRANTY; without even the implied warranty of
14  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
15  * Lesser General Public License for more details.
16  *
17  * You should have received a copy of the GNU Lesser General Public
18  * License along with this library; if not, write to the Free Software
19  * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
20  */
21
22 #define COBJMACROS
23
24 #include <stdarg.h>
25 #include <assert.h>
26
27 #include "windef.h"
28 #include "winbase.h"
29 #include "winerror.h"
30 #include "wine/debug.h"
31 #include "wine/unicode.h"
32 #include "msi.h"
33 #include "msiquery.h"
34 #include "objbase.h"
35 #include "objidl.h"
36 #include "msipriv.h"
37 #include "winnls.h"
38
39 #include "query.h"
40
41 WINE_DEFAULT_DEBUG_CHANNEL(msidb);
42
43 #define HASH_SIZE 0x101
44 #define LONG_STR_BYTES 3
45
46 typedef struct _msistring
47 {
48     int hash_next;
49     UINT persistent_refcount;
50     UINT nonpersistent_refcount;
51     LPWSTR str;
52 } msistring;
53
54 struct string_table
55 {
56     UINT maxcount;         /* the number of strings */
57     UINT freeslot;
58     UINT codepage;
59     int hash[HASH_SIZE];
60     msistring *strings; /* an array of strings (in the tree) */
61 };
62
63 static UINT msistring_makehash( const WCHAR *str )
64 {
65     UINT hash = 0;
66
67     if (str==NULL)
68         return hash;
69
70     while( *str )
71     {
72         hash ^= *str++;
73         hash *= 53;
74         hash = (hash<<5) | (hash>>27);
75     }
76     return hash % HASH_SIZE;
77 }
78
79 static string_table *init_stringtable( int entries, UINT codepage )
80 {
81     string_table *st;
82     int i;
83
84     st = msi_alloc( sizeof (string_table) );
85     if( !st )
86         return NULL;    
87     if( entries < 1 )
88         entries = 1;
89     st->strings = msi_alloc_zero( sizeof (msistring) * entries );
90     if( !st->strings )
91     {
92         msi_free( st );
93         return NULL;    
94     }
95     st->maxcount = entries;
96     st->freeslot = 1;
97     st->codepage = codepage;
98
99     for( i=0; i<HASH_SIZE; i++ )
100         st->hash[i] = -1;
101
102     return st;
103 }
104
105 VOID msi_destroy_stringtable( string_table *st )
106 {
107     UINT i;
108
109     for( i=0; i<st->maxcount; i++ )
110     {
111         if( st->strings[i].persistent_refcount ||
112             st->strings[i].nonpersistent_refcount )
113             msi_free( st->strings[i].str );
114     }
115     msi_free( st->strings );
116     msi_free( st );
117 }
118
119 static int st_find_free_entry( string_table *st )
120 {
121     UINT i, sz;
122     msistring *p;
123
124     TRACE("%p\n", st);
125
126     if( st->freeslot )
127     {
128         for( i = st->freeslot; i < st->maxcount; i++ )
129             if( !st->strings[i].persistent_refcount &&
130                 !st->strings[i].nonpersistent_refcount )
131                 return i;
132     }
133     for( i = 1; i < st->maxcount; i++ )
134         if( !st->strings[i].persistent_refcount &&
135             !st->strings[i].nonpersistent_refcount )
136             return i;
137
138     /* dynamically resize */
139     sz = st->maxcount + 1 + st->maxcount/2;
140     p = msi_realloc_zero( st->strings, sz*sizeof(msistring) );
141     if( !p )
142         return -1;
143     st->strings = p;
144     st->freeslot = st->maxcount;
145     st->maxcount = sz;
146     if( st->strings[st->freeslot].persistent_refcount ||
147         st->strings[st->freeslot].nonpersistent_refcount )
148         ERR("oops. expected freeslot to be free...\n");
149     return st->freeslot;
150 }
151
152 static void set_st_entry( string_table *st, UINT n, LPWSTR str, UINT refcount, enum StringPersistence persistence )
153 {
154     UINT hash = msistring_makehash( str );
155
156     if (persistence == StringPersistent)
157     {
158         st->strings[n].persistent_refcount = refcount;
159         st->strings[n].nonpersistent_refcount = 0;
160     }
161     else
162     {
163         st->strings[n].persistent_refcount = 0;
164         st->strings[n].nonpersistent_refcount = refcount;
165     }
166
167     st->strings[n].str = str;
168
169     st->strings[n].hash_next = st->hash[hash];
170     st->hash[hash] = n;
171
172     if( n < st->maxcount )
173         st->freeslot = n + 1;
174 }
175
176 static int msi_addstring( string_table *st, UINT n, const CHAR *data, int len, UINT refcount, enum StringPersistence persistence )
177 {
178     LPWSTR str;
179     int sz;
180
181     if( !data )
182         return 0;
183     if( !data[0] )
184         return 0;
185     if( n > 0 )
186     {
187         if( st->strings[n].persistent_refcount ||
188             st->strings[n].nonpersistent_refcount )
189             return -1;
190     }
191     else
192     {
193         if( ERROR_SUCCESS == msi_string2idA( st, data, &n ) )
194         {
195             if (persistence == StringPersistent)
196                 st->strings[n].persistent_refcount += refcount;
197             else
198                 st->strings[n].nonpersistent_refcount += refcount;
199             return n;
200         }
201         n = st_find_free_entry( st );
202         if( n < 0 )
203             return -1;
204     }
205
206     if( n < 1 )
207     {
208         ERR("invalid index adding %s (%d)\n", debugstr_a( data ), n );
209         return -1;
210     }
211
212     /* allocate a new string */
213     if( len < 0 )
214         len = strlen(data);
215     sz = MultiByteToWideChar( st->codepage, 0, data, len, NULL, 0 );
216     str = msi_alloc( (sz+1)*sizeof(WCHAR) );
217     if( !str )
218         return -1;
219     MultiByteToWideChar( st->codepage, 0, data, len, str, sz );
220     str[sz] = 0;
221
222     set_st_entry( st, n, str, refcount, persistence );
223
224     return n;
225 }
226
227 int msi_addstringW( string_table *st, UINT n, const WCHAR *data, int len, UINT refcount, enum StringPersistence persistence )
228 {
229     LPWSTR str;
230
231     /* TRACE("[%2d] = %s\n", string_no, debugstr_an(data,len) ); */
232
233     if( !data )
234         return 0;
235     if( !data[0] )
236         return 0;
237     if( n > 0 )
238     {
239         if( st->strings[n].persistent_refcount ||
240             st->strings[n].nonpersistent_refcount )
241             return -1;
242     }
243     else
244     {
245         if( ERROR_SUCCESS == msi_string2idW( st, data, &n ) )
246         {
247             if (persistence == StringPersistent)
248                 st->strings[n].persistent_refcount += refcount;
249             else
250                 st->strings[n].nonpersistent_refcount += refcount;
251             return n;
252         }
253         n = st_find_free_entry( st );
254         if( n < 0 )
255             return -1;
256     }
257
258     if( n < 1 )
259     {
260         ERR("invalid index adding %s (%d)\n", debugstr_w( data ), n );
261         return -1;
262     }
263
264     /* allocate a new string */
265     if(len<0)
266         len = strlenW(data);
267     TRACE("%s, n = %d len = %d\n", debugstr_w(data), n, len );
268
269     str = msi_alloc( (len+1)*sizeof(WCHAR) );
270     if( !str )
271         return -1;
272     TRACE("%d\n",__LINE__);
273     memcpy( str, data, len*sizeof(WCHAR) );
274     str[len] = 0;
275
276     set_st_entry( st, n, str, refcount, persistence );
277
278     return n;
279 }
280
281 /* find the string identified by an id - return null if there's none */
282 const WCHAR *msi_string_lookup_id( const string_table *st, UINT id )
283 {
284     static const WCHAR zero[] = { 0 };
285     if( id == 0 )
286         return zero;
287
288     if( id >= st->maxcount )
289         return NULL;
290
291     if( id && !st->strings[id].persistent_refcount && !st->strings[id].nonpersistent_refcount)
292         return NULL;
293
294     return st->strings[id].str;
295 }
296
297 /*
298  *  msi_id2stringW
299  *
300  *  [in] st         - pointer to the string table
301  *  [in] id  - id of the string to retrieve
302  *  [out] buffer    - destination of the string
303  *  [in/out] sz     - number of bytes available in the buffer on input
304  *                    number of bytes used on output
305  *
306  *   The size includes the terminating nul character.  Short buffers
307  *  will be filled, but not nul terminated.
308  */
309 UINT msi_id2stringW( const string_table *st, UINT id, LPWSTR buffer, UINT *sz )
310 {
311     UINT len;
312     const WCHAR *str;
313
314     TRACE("Finding string %d of %d\n", id, st->maxcount);
315
316     str = msi_string_lookup_id( st, id );
317     if( !str )
318         return ERROR_FUNCTION_FAILED;
319
320     len = strlenW( str ) + 1;
321
322     if( !buffer )
323     {
324         *sz = len;
325         return ERROR_SUCCESS;
326     }
327
328     if( *sz < len )
329         *sz = len;
330     memcpy( buffer, str, (*sz)*sizeof(WCHAR) ); 
331     *sz = len;
332
333     return ERROR_SUCCESS;
334 }
335
336 /*
337  *  msi_id2stringA
338  *
339  *  [in] st         - pointer to the string table
340  *  [in] id         - id of the string to retrieve
341  *  [out] buffer    - destination of the UTF8 string
342  *  [in/out] sz     - number of bytes available in the buffer on input
343  *                    number of bytes used on output
344  *
345  *   The size includes the terminating nul character.  Short buffers
346  *  will be filled, but not nul terminated.
347  */
348 UINT msi_id2stringA( const string_table *st, UINT id, LPSTR buffer, UINT *sz )
349 {
350     UINT len;
351     const WCHAR *str;
352     int n;
353
354     TRACE("Finding string %d of %d\n", id, st->maxcount);
355
356     str = msi_string_lookup_id( st, id );
357     if( !str )
358         return ERROR_FUNCTION_FAILED;
359
360     len = WideCharToMultiByte( st->codepage, 0, str, -1, NULL, 0, NULL, NULL );
361
362     if( !buffer )
363     {
364         *sz = len;
365         return ERROR_SUCCESS;
366     }
367
368     if( len > *sz )
369     {
370         n = strlenW( str ) + 1;
371         while( n && (len > *sz) )
372             len = WideCharToMultiByte( st->codepage, 0, 
373                            str, --n, NULL, 0, NULL, NULL );
374     }
375     else
376         n = -1;
377
378     *sz = WideCharToMultiByte( st->codepage, 0, str, n, buffer, len, NULL, NULL );
379
380     return ERROR_SUCCESS;
381 }
382
383 /*
384  *  msi_string2idW
385  *
386  *  [in] st         - pointer to the string table
387  *  [in] str        - string to find in the string table
388  *  [out] id        - id of the string, if found
389  */
390 UINT msi_string2idW( const string_table *st, LPCWSTR str, UINT *id )
391 {
392     UINT n, hash = msistring_makehash( str );
393     msistring *se = st->strings;
394
395     for (n = st->hash[hash]; n != -1; n = st->strings[n].hash_next )
396     {
397         if ((str == se[n].str) || !lstrcmpW(str, se[n].str))
398         {
399             *id = n;
400             return ERROR_SUCCESS;
401         }
402     }
403
404     return ERROR_INVALID_PARAMETER;
405 }
406
407 UINT msi_string2idA( const string_table *st, LPCSTR buffer, UINT *id )
408 {
409     DWORD sz;
410     UINT r = ERROR_INVALID_PARAMETER;
411     LPWSTR str;
412
413     TRACE("Finding string %s in string table\n", debugstr_a(buffer) );
414
415     if( buffer[0] == 0 )
416     {
417         *id = 0;
418         return ERROR_SUCCESS;
419     }
420
421     sz = MultiByteToWideChar( st->codepage, 0, buffer, -1, NULL, 0 );
422     if( sz <= 0 )
423         return r;
424     str = msi_alloc( sz*sizeof(WCHAR) );
425     if( !str )
426         return ERROR_NOT_ENOUGH_MEMORY;
427     MultiByteToWideChar( st->codepage, 0, buffer, -1, str, sz );
428
429     r = msi_string2idW( st, str, id );
430     msi_free( str );
431
432     return r;
433 }
434
435 UINT msi_strcmp( const string_table *st, UINT lval, UINT rval, UINT *res )
436 {
437     const WCHAR *l_str, *r_str;
438
439     l_str = msi_string_lookup_id( st, lval );
440     if( !l_str )
441         return ERROR_INVALID_PARAMETER;
442     
443     r_str = msi_string_lookup_id( st, rval );
444     if( !r_str )
445         return ERROR_INVALID_PARAMETER;
446
447     /* does this do the right thing for all UTF-8 strings? */
448     *res = strcmpW( l_str, r_str );
449
450     return ERROR_SUCCESS;
451 }
452
453 static void string_totalsize( const string_table *st, UINT *datasize, UINT *poolsize )
454 {
455     UINT i, len, max, holesize;
456
457     if( st->strings[0].str || st->strings[0].persistent_refcount || st->strings[0].nonpersistent_refcount)
458         ERR("oops. element 0 has a string\n");
459
460     *poolsize = 4;
461     *datasize = 0;
462     max = 1;
463     holesize = 0;
464     for( i=1; i<st->maxcount; i++ )
465     {
466         if( !st->strings[i].persistent_refcount )
467             continue;
468         if( st->strings[i].str )
469         {
470             TRACE("[%u] = %s\n", i, debugstr_w(st->strings[i].str));
471             len = WideCharToMultiByte( st->codepage, 0,
472                      st->strings[i].str, -1, NULL, 0, NULL, NULL);
473             if( len )
474                 len--;
475             (*datasize) += len;
476             if (len>0xffff)
477                 (*poolsize) += 4;
478             max = i + 1;
479             (*poolsize) += holesize + 4;
480             holesize = 0;
481         }
482         else
483             holesize += 4;
484     }
485     TRACE("data %u pool %u codepage %x\n", *datasize, *poolsize, st->codepage );
486 }
487
488 static const WCHAR szStringData[] = {
489     '_','S','t','r','i','n','g','D','a','t','a',0 };
490 static const WCHAR szStringPool[] = {
491     '_','S','t','r','i','n','g','P','o','o','l',0 };
492
493 HRESULT msi_init_string_table( IStorage *stg )
494 {
495     USHORT zero[2] = { 0, 0 };
496     UINT ret;
497
498     /* create the StringPool stream... add the zero string to it*/
499     ret = write_stream_data(stg, szStringPool, zero, sizeof zero, TRUE);
500     if (ret != ERROR_SUCCESS)
501         return E_FAIL;
502
503     /* create the StringData stream... make it zero length */
504     ret = write_stream_data(stg, szStringData, NULL, 0, TRUE);
505     if (ret != ERROR_SUCCESS)
506         return E_FAIL;
507
508     return S_OK;
509 }
510
511 string_table *msi_load_string_table( IStorage *stg, UINT *bytes_per_strref )
512 {
513     string_table *st = NULL;
514     CHAR *data = NULL;
515     USHORT *pool = NULL;
516     UINT r, datasize = 0, poolsize = 0, codepage;
517     DWORD i, count, offset, len, n, refs;
518
519     static const USHORT large_str_sig[] = { 0x0000, 0x8000 };
520
521     r = read_stream_data( stg, szStringPool, &pool, &poolsize );
522     if( r != ERROR_SUCCESS)
523         goto end;
524     r = read_stream_data( stg, szStringData, (USHORT**)&data, &datasize );
525     if( r != ERROR_SUCCESS)
526         goto end;
527
528     if ( !memcmp(pool, large_str_sig, sizeof(large_str_sig)) )
529         *bytes_per_strref = LONG_STR_BYTES;
530     else
531         *bytes_per_strref = sizeof(USHORT);
532
533     /* FIXME: don't know where the codepage is in large str tables */
534     count = poolsize/4;
535     if( poolsize > 4 && *bytes_per_strref != LONG_STR_BYTES )
536         codepage = pool[0] | ( pool[1] << 16 );
537     else
538         codepage = CP_ACP;
539     st = init_stringtable( count, codepage );
540
541     offset = 0;
542     n = 1;
543     i = 1;
544     while( i<count )
545     {
546         /* the string reference count is always the second word */
547         refs = pool[i*2+1];
548
549         /* empty entries have two zeros, still have a string id */
550         if (pool[i*2] == 0 && refs == 0)
551         {
552             i++;
553             n++;
554             continue;
555         }
556
557         /*
558          * If a string is over 64k, the previous string entry is made null
559          * and its the high word of the length is inserted in the null string's
560          * reference count field.
561          */
562         if( pool[i*2] == 0)
563         {
564             len = (pool[i*2+3] << 16) + pool[i*2+2];
565             i += 2;
566         }
567         else
568         {
569             len = pool[i*2];
570             i += 1;
571         }
572
573         if ( (offset + len) > datasize )
574         {
575             ERR("string table corrupt?\n");
576             break;
577         }
578
579         r = msi_addstring( st, n, data+offset, len, refs, StringPersistent );
580         if( r != n )
581             ERR("Failed to add string %d\n", n );
582         n++;
583         offset += len;
584     }
585
586     if ( datasize != offset )
587         ERR("string table load failed! (%08x != %08x), please report\n", datasize, offset );
588
589     TRACE("Loaded %d strings\n", count);
590
591 end:
592     msi_free( pool );
593     msi_free( data );
594
595     return st;
596 }
597
598 UINT msi_save_string_table( const string_table *st, IStorage *storage )
599 {
600     UINT i, datasize = 0, poolsize = 0, sz, used, r, codepage, n;
601     UINT ret = ERROR_FUNCTION_FAILED;
602     CHAR *data = NULL;
603     USHORT *pool = NULL;
604
605     TRACE("\n");
606
607     /* construct the new table in memory first */
608     string_totalsize( st, &datasize, &poolsize );
609
610     TRACE("%u %u %u\n", st->maxcount, datasize, poolsize );
611
612     pool = msi_alloc( poolsize );
613     if( ! pool )
614     {
615         WARN("Failed to alloc pool %d bytes\n", poolsize );
616         goto err;
617     }
618     data = msi_alloc( datasize );
619     if( ! data )
620     {
621         WARN("Failed to alloc data %d bytes\n", poolsize );
622         goto err;
623     }
624
625     used = 0;
626     codepage = st->codepage;
627     pool[0]=codepage&0xffff;
628     pool[1]=(codepage>>16);
629     n = 1;
630     for( i=1; i<st->maxcount; i++ )
631     {
632         if( !st->strings[i].persistent_refcount )
633             continue;
634         sz = datasize - used;
635         r = msi_id2stringA( st, i, data+used, &sz );
636         if( r != ERROR_SUCCESS )
637         {
638             ERR("failed to fetch string\n");
639             sz = 0;
640         }
641         if( sz && (sz < (datasize - used ) ) )
642             sz--;
643
644         if (sz)
645             pool[ n*2 + 1 ] = st->strings[i].persistent_refcount;
646         else
647             pool[ n*2 + 1 ] = 0;
648         if (sz < 0x10000)
649         {
650             pool[ n*2 ] = sz;
651             n++;
652         }
653         else
654         {
655             pool[ n*2 ] = 0;
656             pool[ n*2 + 2 ] = sz&0xffff;
657             pool[ n*2 + 3 ] = (sz>>16);
658             n += 2;
659         }
660         used += sz;
661         if( used > datasize  )
662         {
663             ERR("oops overran %d >= %d\n", used, datasize);
664             goto err;
665         }
666     }
667
668     if( used != datasize )
669     {
670         ERR("oops used %d != datasize %d\n", used, datasize);
671         goto err;
672     }
673
674     /* write the streams */
675     r = write_stream_data( storage, szStringData, data, datasize, TRUE );
676     TRACE("Wrote StringData r=%08x\n", r);
677     if( r )
678         goto err;
679     r = write_stream_data( storage, szStringPool, pool, poolsize, TRUE );
680     TRACE("Wrote StringPool r=%08x\n", r);
681     if( r )
682         goto err;
683
684     ret = ERROR_SUCCESS;
685
686 err:
687     msi_free( data );
688     msi_free( pool );
689
690     return ret;
691 }