msi: Represent table data as bytes instead of shorts.
[wine] / dlls / msi / string.c
1 /*
2  * String Table Functions
3  *
4  * Copyright 2002-2004, Mike McCormack for CodeWeavers
5  * Copyright 2007 Robert Shearman for CodeWeavers
6  *
7  * This library is free software; you can redistribute it and/or
8  * modify it under the terms of the GNU Lesser General Public
9  * License as published by the Free Software Foundation; either
10  * version 2.1 of the License, or (at your option) any later version.
11  *
12  * This library is distributed in the hope that it will be useful,
13  * but WITHOUT ANY WARRANTY; without even the implied warranty of
14  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
15  * Lesser General Public License for more details.
16  *
17  * You should have received a copy of the GNU Lesser General Public
18  * License along with this library; if not, write to the Free Software
19  * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
20  */
21
22 #define COBJMACROS
23
24 #include <stdarg.h>
25 #include <assert.h>
26
27 #include "windef.h"
28 #include "winbase.h"
29 #include "winerror.h"
30 #include "wine/debug.h"
31 #include "wine/unicode.h"
32 #include "msi.h"
33 #include "msiquery.h"
34 #include "objbase.h"
35 #include "objidl.h"
36 #include "msipriv.h"
37 #include "winnls.h"
38
39 #include "query.h"
40
41 WINE_DEFAULT_DEBUG_CHANNEL(msidb);
42
43 #define HASH_SIZE 0x101
44
45 typedef struct _msistring
46 {
47     int hash_next;
48     UINT persistent_refcount;
49     UINT nonpersistent_refcount;
50     LPWSTR str;
51 } msistring;
52
53 struct string_table
54 {
55     UINT maxcount;         /* the number of strings */
56     UINT freeslot;
57     UINT codepage;
58     int hash[HASH_SIZE];
59     msistring *strings; /* an array of strings (in the tree) */
60 };
61
62 static UINT msistring_makehash( const WCHAR *str )
63 {
64     UINT hash = 0;
65
66     if (str==NULL)
67         return hash;
68
69     while( *str )
70     {
71         hash ^= *str++;
72         hash *= 53;
73         hash = (hash<<5) | (hash>>27);
74     }
75     return hash % HASH_SIZE;
76 }
77
78 static string_table *init_stringtable( int entries, UINT codepage )
79 {
80     string_table *st;
81     int i;
82
83     st = msi_alloc( sizeof (string_table) );
84     if( !st )
85         return NULL;    
86     if( entries < 1 )
87         entries = 1;
88     st->strings = msi_alloc_zero( sizeof (msistring) * entries );
89     if( !st->strings )
90     {
91         msi_free( st );
92         return NULL;    
93     }
94     st->maxcount = entries;
95     st->freeslot = 1;
96     st->codepage = codepage;
97
98     for( i=0; i<HASH_SIZE; i++ )
99         st->hash[i] = -1;
100
101     return st;
102 }
103
104 VOID msi_destroy_stringtable( string_table *st )
105 {
106     UINT i;
107
108     for( i=0; i<st->maxcount; i++ )
109     {
110         if( st->strings[i].persistent_refcount ||
111             st->strings[i].nonpersistent_refcount )
112             msi_free( st->strings[i].str );
113     }
114     msi_free( st->strings );
115     msi_free( st );
116 }
117
118 static int st_find_free_entry( string_table *st )
119 {
120     UINT i, sz;
121     msistring *p;
122
123     TRACE("%p\n", st);
124
125     if( st->freeslot )
126     {
127         for( i = st->freeslot; i < st->maxcount; i++ )
128             if( !st->strings[i].persistent_refcount &&
129                 !st->strings[i].nonpersistent_refcount )
130                 return i;
131     }
132     for( i = 1; i < st->maxcount; i++ )
133         if( !st->strings[i].persistent_refcount &&
134             !st->strings[i].nonpersistent_refcount )
135             return i;
136
137     /* dynamically resize */
138     sz = st->maxcount + 1 + st->maxcount/2;
139     p = msi_realloc_zero( st->strings, sz*sizeof(msistring) );
140     if( !p )
141         return -1;
142     st->strings = p;
143     st->freeslot = st->maxcount;
144     st->maxcount = sz;
145     if( st->strings[st->freeslot].persistent_refcount ||
146         st->strings[st->freeslot].nonpersistent_refcount )
147         ERR("oops. expected freeslot to be free...\n");
148     return st->freeslot;
149 }
150
151 static void set_st_entry( string_table *st, UINT n, LPWSTR str, UINT refcount, enum StringPersistence persistence )
152 {
153     UINT hash = msistring_makehash( str );
154
155     if (persistence == StringPersistent)
156     {
157         st->strings[n].persistent_refcount = refcount;
158         st->strings[n].nonpersistent_refcount = 0;
159     }
160     else
161     {
162         st->strings[n].persistent_refcount = 0;
163         st->strings[n].nonpersistent_refcount = refcount;
164     }
165
166     st->strings[n].str = str;
167
168     st->strings[n].hash_next = st->hash[hash];
169     st->hash[hash] = n;
170
171     if( n < st->maxcount )
172         st->freeslot = n + 1;
173 }
174
175 static int msi_addstring( string_table *st, UINT n, const CHAR *data, int len, UINT refcount, enum StringPersistence persistence )
176 {
177     LPWSTR str;
178     int sz;
179
180     if( !data )
181         return 0;
182     if( !data[0] )
183         return 0;
184     if( n > 0 )
185     {
186         if( st->strings[n].persistent_refcount ||
187             st->strings[n].nonpersistent_refcount )
188             return -1;
189     }
190     else
191     {
192         if( ERROR_SUCCESS == msi_string2idA( st, data, &n ) )
193         {
194             if (persistence == StringPersistent)
195                 st->strings[n].persistent_refcount += refcount;
196             else
197                 st->strings[n].nonpersistent_refcount += refcount;
198             return n;
199         }
200         n = st_find_free_entry( st );
201         if( n < 0 )
202             return -1;
203     }
204
205     if( n < 1 )
206     {
207         ERR("invalid index adding %s (%d)\n", debugstr_a( data ), n );
208         return -1;
209     }
210
211     /* allocate a new string */
212     if( len < 0 )
213         len = strlen(data);
214     sz = MultiByteToWideChar( st->codepage, 0, data, len, NULL, 0 );
215     str = msi_alloc( (sz+1)*sizeof(WCHAR) );
216     if( !str )
217         return -1;
218     MultiByteToWideChar( st->codepage, 0, data, len, str, sz );
219     str[sz] = 0;
220
221     set_st_entry( st, n, str, refcount, persistence );
222
223     return n;
224 }
225
226 int msi_addstringW( string_table *st, UINT n, const WCHAR *data, int len, UINT refcount, enum StringPersistence persistence )
227 {
228     LPWSTR str;
229
230     /* TRACE("[%2d] = %s\n", string_no, debugstr_an(data,len) ); */
231
232     if( !data )
233         return 0;
234     if( !data[0] )
235         return 0;
236     if( n > 0 )
237     {
238         if( st->strings[n].persistent_refcount ||
239             st->strings[n].nonpersistent_refcount )
240             return -1;
241     }
242     else
243     {
244         if( ERROR_SUCCESS == msi_string2idW( st, data, &n ) )
245         {
246             if (persistence == StringPersistent)
247                 st->strings[n].persistent_refcount += refcount;
248             else
249                 st->strings[n].nonpersistent_refcount += refcount;
250             return n;
251         }
252         n = st_find_free_entry( st );
253         if( n < 0 )
254             return -1;
255     }
256
257     if( n < 1 )
258     {
259         ERR("invalid index adding %s (%d)\n", debugstr_w( data ), n );
260         return -1;
261     }
262
263     /* allocate a new string */
264     if(len<0)
265         len = strlenW(data);
266     TRACE("%s, n = %d len = %d\n", debugstr_w(data), n, len );
267
268     str = msi_alloc( (len+1)*sizeof(WCHAR) );
269     if( !str )
270         return -1;
271     TRACE("%d\n",__LINE__);
272     memcpy( str, data, len*sizeof(WCHAR) );
273     str[len] = 0;
274
275     set_st_entry( st, n, str, refcount, persistence );
276
277     return n;
278 }
279
280 /* find the string identified by an id - return null if there's none */
281 const WCHAR *msi_string_lookup_id( string_table *st, UINT id )
282 {
283     static const WCHAR zero[] = { 0 };
284     if( id == 0 )
285         return zero;
286
287     if( id >= st->maxcount )
288         return NULL;
289
290     if( id && !st->strings[id].persistent_refcount && !st->strings[id].nonpersistent_refcount)
291         return NULL;
292
293     return st->strings[id].str;
294 }
295
296 /*
297  *  msi_id2stringW
298  *
299  *  [in] st         - pointer to the string table
300  *  [in] id  - id of the string to retrieve
301  *  [out] buffer    - destination of the string
302  *  [in/out] sz     - number of bytes available in the buffer on input
303  *                    number of bytes used on output
304  *
305  *   The size includes the terminating nul character.  Short buffers
306  *  will be filled, but not nul terminated.
307  */
308 UINT msi_id2stringW( string_table *st, UINT id, LPWSTR buffer, UINT *sz )
309 {
310     UINT len;
311     const WCHAR *str;
312
313     TRACE("Finding string %d of %d\n", id, st->maxcount);
314
315     str = msi_string_lookup_id( st, id );
316     if( !str )
317         return ERROR_FUNCTION_FAILED;
318
319     len = strlenW( str ) + 1;
320
321     if( !buffer )
322     {
323         *sz = len;
324         return ERROR_SUCCESS;
325     }
326
327     if( *sz < len )
328         *sz = len;
329     memcpy( buffer, str, (*sz)*sizeof(WCHAR) ); 
330     *sz = len;
331
332     return ERROR_SUCCESS;
333 }
334
335 /*
336  *  msi_id2stringA
337  *
338  *  [in] st         - pointer to the string table
339  *  [in] id         - id of the string to retrieve
340  *  [out] buffer    - destination of the UTF8 string
341  *  [in/out] sz     - number of bytes available in the buffer on input
342  *                    number of bytes used on output
343  *
344  *   The size includes the terminating nul character.  Short buffers
345  *  will be filled, but not nul terminated.
346  */
347 UINT msi_id2stringA( string_table *st, UINT id, LPSTR buffer, UINT *sz )
348 {
349     UINT len;
350     const WCHAR *str;
351     int n;
352
353     TRACE("Finding string %d of %d\n", id, st->maxcount);
354
355     str = msi_string_lookup_id( st, id );
356     if( !str )
357         return ERROR_FUNCTION_FAILED;
358
359     len = WideCharToMultiByte( st->codepage, 0, str, -1, NULL, 0, NULL, NULL );
360
361     if( !buffer )
362     {
363         *sz = len;
364         return ERROR_SUCCESS;
365     }
366
367     if( len > *sz )
368     {
369         n = strlenW( str ) + 1;
370         while( n && (len > *sz) )
371             len = WideCharToMultiByte( st->codepage, 0, 
372                            str, --n, NULL, 0, NULL, NULL );
373     }
374     else
375         n = -1;
376
377     *sz = WideCharToMultiByte( st->codepage, 0, str, n, buffer, len, NULL, NULL );
378
379     return ERROR_SUCCESS;
380 }
381
382 /*
383  *  msi_string2idW
384  *
385  *  [in] st         - pointer to the string table
386  *  [in] str        - string to find in the string table
387  *  [out] id        - id of the string, if found
388  */
389 UINT msi_string2idW( string_table *st, LPCWSTR str, UINT *id )
390 {
391     UINT n, hash = msistring_makehash( str );
392     msistring *se = st->strings;
393
394     for (n = st->hash[hash]; n != -1; n = st->strings[n].hash_next )
395     {
396         if ((str == se[n].str) || !lstrcmpW(str, se[n].str))
397         {
398             *id = n;
399             return ERROR_SUCCESS;
400         }
401     }
402
403     return ERROR_INVALID_PARAMETER;
404 }
405
406 UINT msi_string2idA( string_table *st, LPCSTR buffer, UINT *id )
407 {
408     DWORD sz;
409     UINT r = ERROR_INVALID_PARAMETER;
410     LPWSTR str;
411
412     TRACE("Finding string %s in string table\n", debugstr_a(buffer) );
413
414     if( buffer[0] == 0 )
415     {
416         *id = 0;
417         return ERROR_SUCCESS;
418     }
419
420     sz = MultiByteToWideChar( st->codepage, 0, buffer, -1, NULL, 0 );
421     if( sz <= 0 )
422         return r;
423     str = msi_alloc( sz*sizeof(WCHAR) );
424     if( !str )
425         return ERROR_NOT_ENOUGH_MEMORY;
426     MultiByteToWideChar( st->codepage, 0, buffer, -1, str, sz );
427
428     r = msi_string2idW( st, str, id );
429     msi_free( str );
430
431     return r;
432 }
433
434 UINT msi_strcmp( string_table *st, UINT lval, UINT rval, UINT *res )
435 {
436     const WCHAR *l_str, *r_str;
437
438     l_str = msi_string_lookup_id( st, lval );
439     if( !l_str )
440         return ERROR_INVALID_PARAMETER;
441     
442     r_str = msi_string_lookup_id( st, rval );
443     if( !r_str )
444         return ERROR_INVALID_PARAMETER;
445
446     /* does this do the right thing for all UTF-8 strings? */
447     *res = strcmpW( l_str, r_str );
448
449     return ERROR_SUCCESS;
450 }
451
452 static void string_totalsize( string_table *st, UINT *datasize, UINT *poolsize )
453 {
454     UINT i, len, max, holesize;
455
456     if( st->strings[0].str || st->strings[0].persistent_refcount || st->strings[0].nonpersistent_refcount)
457         ERR("oops. element 0 has a string\n");
458
459     *poolsize = 4;
460     *datasize = 0;
461     max = 1;
462     holesize = 0;
463     for( i=1; i<st->maxcount; i++ )
464     {
465         if( !st->strings[i].persistent_refcount )
466             continue;
467         if( st->strings[i].str )
468         {
469             TRACE("[%u] = %s\n", i, debugstr_w(st->strings[i].str));
470             len = WideCharToMultiByte( st->codepage, 0,
471                      st->strings[i].str, -1, NULL, 0, NULL, NULL);
472             if( len )
473                 len--;
474             (*datasize) += len;
475             if (len>0xffff)
476                 (*poolsize) += 4;
477             max = i + 1;
478             (*poolsize) += holesize + 4;
479             holesize = 0;
480         }
481         else
482             holesize += 4;
483     }
484     TRACE("data %u pool %u codepage %x\n", *datasize, *poolsize, st->codepage );
485 }
486
487 static const WCHAR szStringData[] = {
488     '_','S','t','r','i','n','g','D','a','t','a',0 };
489 static const WCHAR szStringPool[] = {
490     '_','S','t','r','i','n','g','P','o','o','l',0 };
491
492 HRESULT msi_init_string_table( IStorage *stg )
493 {
494     USHORT zero[2] = { 0, 0 };
495     UINT ret;
496
497     /* create the StringPool stream... add the zero string to it*/
498     ret = write_stream_data(stg, szStringPool, zero, sizeof zero, TRUE);
499     if (ret != ERROR_SUCCESS)
500         return E_FAIL;
501
502     /* create the StringData stream... make it zero length */
503     ret = write_stream_data(stg, szStringData, NULL, 0, TRUE);
504     if (ret != ERROR_SUCCESS)
505         return E_FAIL;
506
507     return S_OK;
508 }
509
510 string_table *msi_load_string_table( IStorage *stg )
511 {
512     string_table *st = NULL;
513     CHAR *data = NULL;
514     USHORT *pool = NULL;
515     UINT r, datasize = 0, poolsize = 0, codepage;
516     DWORD i, count, offset, len, n, refs;
517
518     r = read_stream_data( stg, szStringPool, &pool, &poolsize );
519     if( r != ERROR_SUCCESS)
520         goto end;
521     r = read_stream_data( stg, szStringData, (USHORT**)&data, &datasize );
522     if( r != ERROR_SUCCESS)
523         goto end;
524
525     count = poolsize/4;
526     if( poolsize > 4 )
527         codepage = pool[0] | ( pool[1] << 16 );
528     else
529         codepage = CP_ACP;
530     st = init_stringtable( count, codepage );
531
532     offset = 0;
533     n = 1;
534     i = 1;
535     while( i<count )
536     {
537         /* the string reference count is always the second word */
538         refs = pool[i*2+1];
539
540         /* empty entries have two zeros, still have a string id */
541         if (pool[i*2] == 0 && refs == 0)
542         {
543             i++;
544             n++;
545             continue;
546         }
547
548         /*
549          * If a string is over 64k, the previous string entry is made null
550          * and its the high word of the length is inserted in the null string's
551          * reference count field.
552          */
553         if( pool[i*2] == 0)
554         {
555             len = (pool[i*2+3] << 16) + pool[i*2+2];
556             i += 2;
557         }
558         else
559         {
560             len = pool[i*2];
561             i += 1;
562         }
563
564         if ( (offset + len) > datasize )
565         {
566             ERR("string table corrupt?\n");
567             break;
568         }
569
570         r = msi_addstring( st, n, data+offset, len, refs, StringPersistent );
571         if( r != n )
572             ERR("Failed to add string %d\n", n );
573         n++;
574         offset += len;
575     }
576
577     if ( datasize != offset )
578         ERR("string table load failed! (%08x != %08x), please report\n", datasize, offset );
579
580     TRACE("Loaded %d strings\n", count);
581
582 end:
583     msi_free( pool );
584     msi_free( data );
585
586     return st;
587 }
588
589 UINT msi_save_string_table( string_table *st, IStorage *storage )
590 {
591     UINT i, datasize = 0, poolsize = 0, sz, used, r, codepage, n;
592     UINT ret = ERROR_FUNCTION_FAILED;
593     CHAR *data = NULL;
594     USHORT *pool = NULL;
595
596     TRACE("\n");
597
598     /* construct the new table in memory first */
599     string_totalsize( st, &datasize, &poolsize );
600
601     TRACE("%u %u %u\n", st->maxcount, datasize, poolsize );
602
603     pool = msi_alloc( poolsize );
604     if( ! pool )
605     {
606         WARN("Failed to alloc pool %d bytes\n", poolsize );
607         goto err;
608     }
609     data = msi_alloc( datasize );
610     if( ! data )
611     {
612         WARN("Failed to alloc data %d bytes\n", poolsize );
613         goto err;
614     }
615
616     used = 0;
617     codepage = st->codepage;
618     pool[0]=codepage&0xffff;
619     pool[1]=(codepage>>16);
620     n = 1;
621     for( i=1; i<st->maxcount; i++ )
622     {
623         if( !st->strings[i].persistent_refcount )
624             continue;
625         sz = datasize - used;
626         r = msi_id2stringA( st, i, data+used, &sz );
627         if( r != ERROR_SUCCESS )
628         {
629             ERR("failed to fetch string\n");
630             sz = 0;
631         }
632         if( sz && (sz < (datasize - used ) ) )
633             sz--;
634
635         if (sz)
636             pool[ n*2 + 1 ] = st->strings[i].persistent_refcount;
637         else
638             pool[ n*2 + 1 ] = 0;
639         if (sz < 0x10000)
640         {
641             pool[ n*2 ] = sz;
642             n++;
643         }
644         else
645         {
646             pool[ n*2 ] = 0;
647             pool[ n*2 + 2 ] = sz&0xffff;
648             pool[ n*2 + 3 ] = (sz>>16);
649             n += 2;
650         }
651         used += sz;
652         if( used > datasize  )
653         {
654             ERR("oops overran %d >= %d\n", used, datasize);
655             goto err;
656         }
657     }
658
659     if( used != datasize )
660     {
661         ERR("oops used %d != datasize %d\n", used, datasize);
662         goto err;
663     }
664
665     /* write the streams */
666     r = write_stream_data( storage, szStringData, data, datasize, TRUE );
667     TRACE("Wrote StringData r=%08x\n", r);
668     if( r )
669         goto err;
670     r = write_stream_data( storage, szStringPool, pool, poolsize, TRUE );
671     TRACE("Wrote StringPool r=%08x\n", r);
672     if( r )
673         goto err;
674
675     ret = ERROR_SUCCESS;
676
677 err:
678     msi_free( data );
679     msi_free( pool );
680
681     return ret;
682 }