Fix signed/unsigned comparison warnings.
[wine] / dlls / advapi32 / tests / crypt.c
1 /*
2  * Unit tests for crypt functions
3  *
4  * Copyright (c) 2004 Michael Jung
5  *
6  * This library is free software; you can redistribute it and/or
7  * modify it under the terms of the GNU Lesser General Public
8  * License as published by the Free Software Foundation; either
9  * version 2.1 of the License, or (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14  * Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public
17  * License along with this library; if not, write to the Free Software
18  * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
19  */
20
21 #include <stdio.h>
22
23 #include "wine/test.h"
24 #include "windef.h"
25 #include "winbase.h"
26 #include "wincrypt.h"
27 #include "winerror.h"
28 #include "winreg.h"
29
30 static const char szRsaBaseProv[] = MS_DEF_PROV_A;
31 static const char szNonExistentProv[] = "Wine Non Existent Cryptographic Provider v11.2";
32 static const char szKeySet[] = "wine_test_keyset";
33 static const char szBadKeySet[] = "wine_test_bad_keyset";
34 #define NON_DEF_PROV_TYPE 999
35
36 static void init_environment(void)
37 {
38         HCRYPTPROV hProv;
39         
40         /* Ensure that container "wine_test_keyset" does exist */
41         if (!CryptAcquireContext(&hProv, szKeySet, szRsaBaseProv, PROV_RSA_FULL, 0))
42         {
43                 CryptAcquireContext(&hProv, szKeySet, szRsaBaseProv, PROV_RSA_FULL, CRYPT_NEWKEYSET);
44         }
45         CryptReleaseContext(hProv, 0);
46
47         /* Ensure that container "wine_test_keyset" does exist in default PROV_RSA_FULL type provider */
48         if (!CryptAcquireContext(&hProv, szKeySet, NULL, PROV_RSA_FULL, 0))
49         {
50                 CryptAcquireContext(&hProv, szKeySet, NULL, PROV_RSA_FULL, CRYPT_NEWKEYSET);
51         }
52         CryptReleaseContext(hProv, 0);
53
54         /* Ensure that container "wine_test_bad_keyset" does not exist. */
55         if (CryptAcquireContext(&hProv, szBadKeySet, szRsaBaseProv, PROV_RSA_FULL, 0))
56         {
57                 CryptReleaseContext(hProv, 0);
58                 CryptAcquireContext(&hProv, szBadKeySet, szRsaBaseProv, PROV_RSA_FULL, CRYPT_DELETEKEYSET);
59         }
60 }
61
62 static void clean_up_environment(void)
63 {
64         HCRYPTPROV hProv;
65
66         /* Remove container "wine_test_keyset" */
67         if (CryptAcquireContext(&hProv, szKeySet, szRsaBaseProv, PROV_RSA_FULL, 0))
68         {
69                 CryptReleaseContext(hProv, 0);
70                 CryptAcquireContext(&hProv, szKeySet, szRsaBaseProv, PROV_RSA_FULL, CRYPT_DELETEKEYSET);
71         }
72
73         /* Remove container "wine_test_keyset" from default PROV_RSA_FULL type provider */
74         if (CryptAcquireContext(&hProv, szKeySet, NULL, PROV_RSA_FULL, 0))
75         {
76                 CryptReleaseContext(hProv, 0);
77                 CryptAcquireContext(&hProv, szKeySet, NULL, PROV_RSA_FULL, CRYPT_DELETEKEYSET);
78         }
79 }
80
81 static void test_acquire_context(void)
82 {
83         BOOL result;
84         HCRYPTPROV hProv;
85
86         /* Provoke all kinds of error conditions (which are easy to provoke). 
87          * The order of the error tests seems to match Windows XP's rsaenh.dll CSP,
88          * but since this is likely to change between CSP versions, we don't check
89          * this. Please don't change the order of tests. */
90         result = CryptAcquireContext(&hProv, NULL, NULL, 0, 0);
91         ok(!result && GetLastError()==NTE_BAD_PROV_TYPE, "%ld\n", GetLastError());
92         
93         result = CryptAcquireContext(&hProv, NULL, NULL, 1000, 0);
94         ok(!result && GetLastError()==NTE_BAD_PROV_TYPE, "%ld\n", GetLastError());
95
96         result = CryptAcquireContext(&hProv, NULL, NULL, NON_DEF_PROV_TYPE, 0);
97         ok(!result && GetLastError()==NTE_PROV_TYPE_NOT_DEF, "%ld\n", GetLastError());
98         
99         result = CryptAcquireContext(&hProv, szKeySet, szNonExistentProv, PROV_RSA_FULL, 0);
100         ok(!result && GetLastError()==NTE_KEYSET_NOT_DEF, "%ld\n", GetLastError());
101
102         result = CryptAcquireContext(&hProv, szKeySet, szRsaBaseProv, NON_DEF_PROV_TYPE, 0);
103         ok(!result && GetLastError()==NTE_PROV_TYPE_NO_MATCH, "%ld\n", GetLastError());
104         
105         result = CryptAcquireContext(NULL, szKeySet, szRsaBaseProv, PROV_RSA_FULL, 0);
106         ok(!result && GetLastError()==ERROR_INVALID_PARAMETER, "%ld\n", GetLastError());
107         
108         /* Last not least, try to really acquire a context. */
109         result = CryptAcquireContext(&hProv, szKeySet, szRsaBaseProv, PROV_RSA_FULL, 0);
110         ok(result, "%ld\n", GetLastError());
111
112         if (GetLastError() == ERROR_SUCCESS) 
113                 CryptReleaseContext(hProv, 0);
114
115         /* Try again, witch an empty ("\0") szProvider parameter */
116         result = CryptAcquireContext(&hProv, szKeySet, "", PROV_RSA_FULL, 0);
117         ok(result, "%ld\n", GetLastError());
118
119         if (GetLastError() == ERROR_SUCCESS)
120                 CryptReleaseContext(hProv, 0);
121 }
122
123 static BOOL FindProvRegVals(DWORD dwIndex, DWORD *pdwProvType, LPSTR *pszProvName, 
124                             DWORD *pcbProvName, DWORD *pdwProvCount)
125 {
126         HKEY hKey;
127         HKEY subkey;
128         DWORD size = sizeof(DWORD);
129         
130         if (RegOpenKey(HKEY_LOCAL_MACHINE, "Software\\Microsoft\\Cryptography\\Defaults\\Provider", &hKey))
131                 return FALSE;
132         
133         RegQueryInfoKey(hKey, NULL, NULL, NULL, pdwProvCount, pcbProvName, 
134                                  NULL, NULL, NULL, NULL, NULL, NULL);
135         (*pcbProvName)++;
136         
137         if (!(*pszProvName = ((LPSTR)LocalAlloc(LMEM_ZEROINIT, *pcbProvName))))
138                 return FALSE;
139         
140         RegEnumKeyEx(hKey, dwIndex, *pszProvName, pcbProvName, NULL, NULL, NULL, NULL);
141         (*pcbProvName)++;
142
143         RegOpenKey(hKey, *pszProvName, &subkey);
144         RegQueryValueEx(subkey, "Type", NULL, NULL, (BYTE*)pdwProvType, &size);
145         
146         RegCloseKey(subkey);
147         RegCloseKey(hKey);
148         
149         return TRUE;
150 }
151
152 static void test_enum_providers(void)
153 {
154         /* expected results */
155         CHAR *pszProvName = NULL;
156         DWORD cbName;
157         DWORD dwType;
158         DWORD provCount;
159         DWORD dwIndex = 0;
160         
161         /* actual results */
162         CHAR *provider = NULL;
163         DWORD providerLen;
164         DWORD type;
165         DWORD count;
166         BOOL result;
167         DWORD notNull = 5;
168         DWORD notZeroFlags = 5;
169         
170         if (!FindProvRegVals(dwIndex, &dwType, &pszProvName, &cbName, &provCount))
171                 return;
172         
173         /* check pdwReserved flag for NULL */
174         result = CryptEnumProviders(dwIndex, &notNull, 0, &type, NULL, &providerLen);
175         ok(!result && GetLastError()==ERROR_INVALID_PARAMETER, "%ld\n", GetLastError());
176         
177         /* check dwFlags == 0 */
178         result = CryptEnumProviders(dwIndex, NULL, notZeroFlags, &type, NULL, &providerLen);
179         ok(!result && GetLastError()==NTE_BAD_FLAGS, "%ld\n", GetLastError());
180         
181         /* alloc provider to half the size required
182          * cbName holds the size required */
183         providerLen = cbName / 2;
184         if (!(provider = ((LPSTR)LocalAlloc(LMEM_ZEROINIT, providerLen))))
185                 return;
186
187         result = CryptEnumProviders(dwIndex, NULL, 0, &type, provider, &providerLen);
188         ok(!result && GetLastError()==ERROR_MORE_DATA, "expected %i, got %ld\n",
189                 ERROR_MORE_DATA, GetLastError());
190
191         LocalFree(provider);
192
193         /* loop through the providers to get the number of providers 
194          * after loop ends, count should be provCount + 1 so subtract 1
195          * to get actual number of providers */
196         count = 0;
197         while(CryptEnumProviders(count++, NULL, 0, &type, NULL, &providerLen))
198                 ;
199         count--;
200         ok(count==provCount, "expected %i, got %i\n", (int)provCount, (int)count);
201         
202         /* loop past the actual number of providers to get the error
203          * ERROR_NO_MORE_ITEMS */
204         for (count = 0; count < provCount + 1; count++)
205                 result = CryptEnumProviders(count, NULL, 0, &type, NULL, &providerLen);
206         ok(!result && GetLastError()==ERROR_NO_MORE_ITEMS, "expected %i, got %ld\n", 
207                         ERROR_NO_MORE_ITEMS, GetLastError());
208         
209         /* check expected versus actual values returned */
210         result = CryptEnumProviders(dwIndex, NULL, 0, &type, NULL, &providerLen);
211         ok(result && providerLen==cbName, "expected %i, got %i\n", (int)cbName, (int)providerLen);
212         if (!(provider = ((LPSTR)LocalAlloc(LMEM_ZEROINIT, providerLen))))
213                 return;
214                 
215         result = CryptEnumProviders(dwIndex, NULL, 0, &type, provider, &providerLen);
216         ok(result && type==dwType, "expected %ld, got %ld\n", 
217                 dwType, type);
218         ok(result && !strcmp(pszProvName, provider), "expected %s, got %s\n", pszProvName, provider);
219         ok(result && cbName==providerLen, "expected %ld, got %ld\n", 
220                 cbName, providerLen);
221 }
222
223 static BOOL FindProvTypesRegVals(DWORD dwIndex, DWORD *pdwProvType, LPSTR *pszTypeName, 
224                                  DWORD *pcbTypeName, DWORD *pdwTypeCount)
225 {
226         HKEY hKey;
227         HKEY hSubKey;
228         PSTR ch;
229         
230         if (RegOpenKey(HKEY_LOCAL_MACHINE, "Software\\Microsoft\\Cryptography\\Defaults\\Provider Types", &hKey))
231                 return FALSE;
232         
233         RegQueryInfoKey(hKey, NULL, NULL, NULL, pdwTypeCount, pcbTypeName, NULL,
234                         NULL, NULL, NULL, NULL, NULL);
235         (*pcbTypeName)++;
236         
237         if (!(*pszTypeName = ((LPSTR)LocalAlloc(LMEM_ZEROINIT, *pcbTypeName))))
238                 return FALSE;
239         
240         RegEnumKeyEx(hKey, dwIndex, *pszTypeName, pcbTypeName, NULL, NULL, NULL, NULL);
241         (*pcbTypeName)++;
242         ch = *pszTypeName + strlen(*pszTypeName);
243         /* Convert "Type 000" to 0, etc/ */
244         *pdwProvType = *(--ch) - '0';
245         *pdwProvType += (*(--ch) - '0') * 10;
246         *pdwProvType += (*(--ch) - '0') * 100;
247         
248         RegOpenKey(hKey, *pszTypeName, &hSubKey);
249         LocalFree(*pszTypeName);
250         
251         RegQueryValueEx(hSubKey, "TypeName", NULL, NULL, NULL, pcbTypeName);
252         if (!(*pszTypeName = ((LPSTR)LocalAlloc(LMEM_ZEROINIT, *pcbTypeName))))
253                 return FALSE;
254         
255         RegQueryValueEx(hSubKey, "TypeName", NULL, NULL, *pszTypeName, pcbTypeName);
256         
257         RegCloseKey(hSubKey);
258         RegCloseKey(hKey);
259         
260         return TRUE;
261 }
262
263 static void test_enum_provider_types()
264 {
265         /* expected values */
266         DWORD dwProvType;
267         LPSTR pszTypeName = NULL;
268         DWORD cbTypeName;
269         DWORD dwTypeCount;
270         
271         /* actual values */
272         DWORD index = 0;
273         DWORD provType;
274         LPSTR typeName = NULL;
275         DWORD typeNameSize;
276         DWORD typeCount;
277         DWORD result;
278         DWORD notNull = 5;
279         DWORD notZeroFlags = 5;
280         
281         if (!FindProvTypesRegVals(index, &dwProvType, &pszTypeName, &cbTypeName, &dwTypeCount))
282                 return;
283         
284         /* check pdwReserved for NULL */
285         result = CryptEnumProviderTypes(index, &notNull, 0, &provType, typeName, &typeNameSize);
286         ok(!result && GetLastError()==ERROR_INVALID_PARAMETER, "expected %i, got %ld\n", 
287                 ERROR_INVALID_PARAMETER, GetLastError());
288         
289         /* check dwFlags == zero */
290         result = CryptEnumProviderTypes(index, NULL, notZeroFlags, &provType, typeName, &typeNameSize);
291         ok(!result && GetLastError()==NTE_BAD_FLAGS, "expected %i, got %ld\n",
292                 ERROR_INVALID_PARAMETER, GetLastError());
293         
294         /* alloc provider type to half the size required
295          * cbTypeName holds the size required */
296         typeNameSize = cbTypeName / 2;
297         if (!(typeName = ((LPSTR)LocalAlloc(LMEM_ZEROINIT, typeNameSize))))
298                 return;
299
300         result = CryptEnumProviderTypes(index, NULL, 0, &provType, typeName, &typeNameSize);
301         ok(!result && GetLastError()==ERROR_MORE_DATA, "expected %i, got %ld\n",
302                 ERROR_MORE_DATA, GetLastError());
303         
304         LocalFree(typeName);
305         
306         /* loop through the provider types to get the number of provider types 
307          * after loop ends, count should be dwTypeCount + 1 so subtract 1
308          * to get actual number of provider types */
309         typeCount = 0;
310         while(CryptEnumProviderTypes(typeCount++, NULL, 0, &provType, NULL, &typeNameSize))
311                 ;
312         typeCount--;
313         ok(typeCount==dwTypeCount, "expected %ld, got %ld\n", dwTypeCount, typeCount);
314         
315         /* loop past the actual number of provider types to get the error
316          * ERROR_NO_MORE_ITEMS */
317         for (typeCount = 0; typeCount < dwTypeCount + 1; typeCount++)
318                 result = CryptEnumProviderTypes(typeCount, NULL, 0, &provType, NULL, &typeNameSize);
319         ok(!result && GetLastError()==ERROR_NO_MORE_ITEMS, "expected %i, got %ld\n", 
320                         ERROR_NO_MORE_ITEMS, GetLastError());
321         
322
323         /* check expected versus actual values returned */
324         result = CryptEnumProviderTypes(index, NULL, 0, &provType, NULL, &typeNameSize);
325         ok(result && typeNameSize==cbTypeName, "expected %ld, got %ld\n", cbTypeName, typeNameSize);
326         if (!(typeName = ((LPSTR)LocalAlloc(LMEM_ZEROINIT, typeNameSize))))
327                 return;
328                 
329         result = CryptEnumProviderTypes(index, NULL, 0, &provType, typeName, &typeNameSize);
330         ok(result && provType==dwProvType, "expected %ld, got %ld\n", dwProvType, provType);
331         ok(result && !strcmp(pszTypeName, typeName), "expected %s, got %s\n", pszTypeName, typeName);
332         ok(result && typeNameSize==cbTypeName, "expected %ld, got %ld\n", cbTypeName, typeNameSize);
333 }
334
335 static BOOL FindDfltProvRegVals(DWORD dwProvType, DWORD dwFlags, LPSTR *pszProvName, DWORD *pcbProvName)
336 {
337         HKEY hKey;
338         PSTR keyname;
339         PSTR ptr;
340         DWORD user = dwFlags & CRYPT_USER_DEFAULT;
341         
342         LPSTR MACHINESTR = "Software\\Microsoft\\Cryptography\\Defaults\\Provider Types\\Type XXX";
343         LPSTR USERSTR = "Software\\Microsoft\\Cryptography\\Provider Type XXX";
344         
345         keyname = LocalAlloc(LMEM_ZEROINIT, (user ? strlen(USERSTR) : strlen(MACHINESTR)) + 1);
346         if (keyname)
347         {
348                 user ? strcpy(keyname, USERSTR) : strcpy(keyname, MACHINESTR);
349                 ptr = keyname + strlen(keyname);
350                 *(--ptr) = (dwProvType % 10) + '0';
351                 *(--ptr) = ((dwProvType / 10) % 10) + '0';
352                 *(--ptr) = (dwProvType / 100) + '0';
353         } else
354                 return FALSE;
355         
356         if (RegOpenKey((dwFlags & CRYPT_USER_DEFAULT) ?  HKEY_CURRENT_USER : HKEY_LOCAL_MACHINE ,keyname, &hKey))
357         {
358                 LocalFree(keyname);
359                 return FALSE;
360         }
361         LocalFree(keyname);
362         
363         if (RegQueryValueEx(hKey, "Name", NULL, NULL, *pszProvName, pcbProvName))
364         {
365                 if (GetLastError() != ERROR_MORE_DATA)
366                         SetLastError(NTE_PROV_TYPE_ENTRY_BAD);
367                 return FALSE;
368         }
369         
370         if (!(*pszProvName = LocalAlloc(LMEM_ZEROINIT, *pcbProvName)))
371                 return FALSE;
372         
373         if (RegQueryValueEx(hKey, "Name", NULL, NULL, *pszProvName, pcbProvName))
374         {
375                 if (GetLastError() != ERROR_MORE_DATA)
376                         SetLastError(NTE_PROV_TYPE_ENTRY_BAD);
377                 return FALSE;
378         }
379         
380         RegCloseKey(hKey);
381         
382         return TRUE;
383 }
384
385 static void test_get_default_provider()
386 {
387         /* expected results */
388         DWORD dwProvType = PROV_RSA_FULL;
389         DWORD dwFlags = CRYPT_MACHINE_DEFAULT;
390         LPSTR pszProvName = NULL;
391         DWORD cbProvName;
392         
393         /* actual results */
394         DWORD provType = PROV_RSA_FULL;
395         DWORD flags = CRYPT_MACHINE_DEFAULT;
396         LPSTR provName = NULL;
397         DWORD provNameSize;
398         DWORD result;
399         DWORD notNull = 5;
400         
401         FindDfltProvRegVals(dwProvType, dwFlags, &pszProvName, &cbProvName);
402         
403         /* check pdwReserved for NULL */
404         result = CryptGetDefaultProvider(provType, &notNull, flags, provName, &provNameSize);
405         ok(!result && GetLastError()==ERROR_INVALID_PARAMETER, "expected %i, got %ld\n",
406                 ERROR_INVALID_PARAMETER, GetLastError());
407         
408         /* check for invalid flag */
409         flags = 0xdeadbeef;
410         result = CryptGetDefaultProvider(provType, NULL, flags, provName, &provNameSize);
411         ok(!result && GetLastError()==NTE_BAD_FLAGS, "expected %ld, got %ld\n",
412                 NTE_BAD_FLAGS, GetLastError());
413         flags = CRYPT_MACHINE_DEFAULT;
414         
415         /* check for invalid prov type */
416         provType = 0xdeadbeef;
417         result = CryptGetDefaultProvider(provType, NULL, flags, provName, &provNameSize);
418         ok(!result && GetLastError()==NTE_BAD_PROV_TYPE, "expected %ld, got %ld\n",
419                 NTE_BAD_PROV_TYPE, GetLastError());
420         provType = PROV_RSA_FULL;
421         
422         SetLastError(0);
423         
424         /* alloc provName to half the size required
425          * cbProvName holds the size required */
426         provNameSize = cbProvName / 2;
427         if (!(provName = LocalAlloc(LMEM_ZEROINIT, provNameSize)))
428                 return;
429         
430         result = CryptGetDefaultProvider(provType, NULL, flags, provName, &provNameSize);
431         ok(!result && GetLastError()==ERROR_MORE_DATA, "expected %i, got %ld\n",
432                 ERROR_MORE_DATA, GetLastError());
433                 
434         LocalFree(provName);
435         
436         /* check expected versus actual values returned */
437         result = CryptGetDefaultProvider(provType, NULL, flags, NULL, &provNameSize);
438         ok(result && provNameSize==cbProvName, "expected %ld, got %ld\n", cbProvName, provNameSize);
439         provNameSize = cbProvName;
440         
441         if (!(provName = LocalAlloc(LMEM_ZEROINIT, provNameSize)))
442                 return;
443         
444         result = CryptGetDefaultProvider(provType, NULL, flags, provName, &provNameSize);
445         ok(result && !strcmp(pszProvName, provName), "expected %s, got %s\n", pszProvName, provName);
446         ok(result && provNameSize==cbProvName, "expected %ld, got %ld\n", cbProvName, provNameSize);
447 }
448
449 static void test_set_provider_ex()
450 {
451         DWORD result;
452         DWORD notNull = 5;
453         
454         /* results */
455         LPSTR pszProvName = NULL;
456         DWORD cbProvName;
457         
458         /* check pdwReserved for NULL */
459         result = CryptSetProviderEx(MS_DEF_PROV, PROV_RSA_FULL, &notNull, CRYPT_MACHINE_DEFAULT);
460         ok(!result && GetLastError()==ERROR_INVALID_PARAMETER, "expected %i, got %ld\n",
461                 ERROR_INVALID_PARAMETER, GetLastError());
462
463         /* remove the default provider and then set it to MS_DEF_PROV/PROV_RSA_FULL */
464         result = CryptSetProviderEx(MS_DEF_PROV, PROV_RSA_FULL, NULL, CRYPT_MACHINE_DEFAULT | CRYPT_DELETE_DEFAULT);
465         ok(result, "%ld\n", GetLastError());
466
467         result = CryptSetProviderEx(MS_DEF_PROV, PROV_RSA_FULL, NULL, CRYPT_MACHINE_DEFAULT);
468         ok(result, "%ld\n", GetLastError());
469         
470         /* call CryptGetDefaultProvider to see if they match */
471         result = CryptGetDefaultProvider(PROV_RSA_FULL, NULL, CRYPT_MACHINE_DEFAULT, NULL, &cbProvName);
472         if (!(pszProvName = LocalAlloc(LMEM_ZEROINIT, cbProvName)))
473                 return;
474
475         result = CryptGetDefaultProvider(PROV_RSA_FULL, NULL, CRYPT_MACHINE_DEFAULT, pszProvName, &cbProvName);
476         ok(result && !strcmp(MS_DEF_PROV, pszProvName), "expected %s, got %s\n", MS_DEF_PROV, pszProvName);
477         ok(result && cbProvName==(strlen(MS_DEF_PROV) + 1), "expected %i, got %ld\n", (strlen(MS_DEF_PROV) + 1), cbProvName);
478 }
479
480 START_TEST(crypt)
481 {
482         init_environment();
483         test_acquire_context();
484         clean_up_environment();
485         
486         test_enum_providers();
487         test_enum_provider_types();
488         test_get_default_provider();
489         test_set_provider_ex();
490         test_set_provider_ex();
491 }