crypt32: Implement querying computed hash of a decoded hash message.
[wine] / dlls / crypt32 / msg.c
1 /*
2  * Copyright 2007 Juan Lang
3  *
4  * This library is free software; you can redistribute it and/or
5  * modify it under the terms of the GNU Lesser General Public
6  * License as published by the Free Software Foundation; either
7  * version 2.1 of the License, or (at your option) any later version.
8  *
9  * This library is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
12  * Lesser General Public License for more details.
13  *
14  * You should have received a copy of the GNU Lesser General Public
15  * License along with this library; if not, write to the Free Software
16  * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
17  */
18 #include <stdarg.h>
19 #include "windef.h"
20 #include "winbase.h"
21 #include "wincrypt.h"
22 #include "snmp.h"
23
24 #include "wine/debug.h"
25 #include "wine/exception.h"
26 #include "crypt32_private.h"
27
28 WINE_DEFAULT_DEBUG_CHANNEL(crypt);
29
30 /* Called when a message's ref count reaches zero.  Free any message-specific
31  * data here.
32  */
33 typedef void (*CryptMsgCloseFunc)(HCRYPTMSG msg);
34
35 typedef BOOL (*CryptMsgGetParamFunc)(HCRYPTMSG hCryptMsg, DWORD dwParamType,
36  DWORD dwIndex, void *pvData, DWORD *pcbData);
37
38 typedef BOOL (*CryptMsgUpdateFunc)(HCRYPTMSG hCryptMsg, const BYTE *pbData,
39  DWORD cbData, BOOL fFinal);
40
41 typedef enum _CryptMsgState {
42     MsgStateInit,
43     MsgStateUpdated,
44     MsgStateFinalized
45 } CryptMsgState;
46
47 typedef struct _CryptMsgBase
48 {
49     LONG                 ref;
50     DWORD                open_flags;
51     BOOL                 streamed;
52     CMSG_STREAM_INFO     stream_info;
53     CryptMsgState        state;
54     CryptMsgCloseFunc    close;
55     CryptMsgUpdateFunc   update;
56     CryptMsgGetParamFunc get_param;
57 } CryptMsgBase;
58
59 static inline void CryptMsgBase_Init(CryptMsgBase *msg, DWORD dwFlags,
60  PCMSG_STREAM_INFO pStreamInfo, CryptMsgCloseFunc close,
61  CryptMsgGetParamFunc get_param, CryptMsgUpdateFunc update)
62 {
63     msg->ref = 1;
64     msg->open_flags = dwFlags;
65     if (pStreamInfo)
66     {
67         msg->streamed = TRUE;
68         memcpy(&msg->stream_info, pStreamInfo, sizeof(msg->stream_info));
69     }
70     else
71     {
72         msg->streamed = FALSE;
73         memset(&msg->stream_info, 0, sizeof(msg->stream_info));
74     }
75     msg->close = close;
76     msg->get_param = get_param;
77     msg->update = update;
78     msg->state = MsgStateInit;
79 }
80
81 typedef struct _CDataEncodeMsg
82 {
83     CryptMsgBase base;
84     DWORD        bare_content_len;
85     LPBYTE       bare_content;
86 } CDataEncodeMsg;
87
88 static const BYTE empty_data_content[] = { 0x04,0x00 };
89
90 static void CDataEncodeMsg_Close(HCRYPTMSG hCryptMsg)
91 {
92     CDataEncodeMsg *msg = (CDataEncodeMsg *)hCryptMsg;
93
94     if (msg->bare_content != empty_data_content)
95         LocalFree(msg->bare_content);
96 }
97
98 static WINAPI BOOL CRYPT_EncodeContentLength(DWORD dwCertEncodingType,
99  LPCSTR lpszStructType, const void *pvStructInfo, DWORD dwFlags,
100  PCRYPT_ENCODE_PARA pEncodePara, BYTE *pbEncoded, DWORD *pcbEncoded)
101 {
102     const CDataEncodeMsg *msg = (const CDataEncodeMsg *)pvStructInfo;
103     DWORD lenBytes;
104     BOOL ret = TRUE;
105
106     /* Trick:  report bytes needed based on total message length, even though
107      * the message isn't available yet.  The caller will use the length
108      * reported here to encode its length.
109      */
110     CRYPT_EncodeLen(msg->base.stream_info.cbContent, NULL, &lenBytes);
111     if (!pbEncoded)
112         *pcbEncoded = 1 + lenBytes + msg->base.stream_info.cbContent;
113     else
114     {
115         if ((ret = CRYPT_EncodeEnsureSpace(dwFlags, pEncodePara, pbEncoded,
116          pcbEncoded, 1 + lenBytes)))
117         {
118             if (dwFlags & CRYPT_ENCODE_ALLOC_FLAG)
119                 pbEncoded = *(BYTE **)pbEncoded;
120             *pbEncoded++ = ASN_OCTETSTRING;
121             CRYPT_EncodeLen(msg->base.stream_info.cbContent, pbEncoded,
122              &lenBytes);
123         }
124     }
125     return ret;
126 }
127
128 static BOOL CRYPT_EncodeDataContentInfoHeader(CDataEncodeMsg *msg,
129  CRYPT_DATA_BLOB *header)
130 {
131     BOOL ret;
132
133     if (msg->base.streamed && msg->base.stream_info.cbContent == 0xffffffff)
134     {
135         FIXME("unimplemented for indefinite-length encoding\n");
136         header->cbData = 0;
137         header->pbData = NULL;
138         ret = TRUE;
139     }
140     else
141     {
142         struct AsnConstructedItem constructed = { 0, msg,
143          CRYPT_EncodeContentLength };
144         struct AsnEncodeSequenceItem items[2] = {
145          { szOID_RSA_data, CRYPT_AsnEncodeOid, 0 },
146          { &constructed,   CRYPT_AsnEncodeConstructed, 0 },
147         };
148
149         ret = CRYPT_AsnEncodeSequence(X509_ASN_ENCODING, items,
150          sizeof(items) / sizeof(items[0]), CRYPT_ENCODE_ALLOC_FLAG, NULL,
151          (LPBYTE)&header->pbData, &header->cbData);
152         if (ret)
153         {
154             /* Trick:  subtract the content length from the reported length,
155              * as the actual content hasn't come yet.
156              */
157             header->cbData -= msg->base.stream_info.cbContent;
158         }
159     }
160     return ret;
161 }
162
163 static BOOL CDataEncodeMsg_Update(HCRYPTMSG hCryptMsg, const BYTE *pbData,
164  DWORD cbData, BOOL fFinal)
165 {
166     CDataEncodeMsg *msg = (CDataEncodeMsg *)hCryptMsg;
167     BOOL ret = FALSE;
168
169     if (msg->base.streamed)
170     {
171         __TRY
172         {
173             if (msg->base.state != MsgStateUpdated)
174             {
175                 CRYPT_DATA_BLOB header;
176
177                 ret = CRYPT_EncodeDataContentInfoHeader(msg, &header);
178                 if (ret)
179                 {
180                     ret = msg->base.stream_info.pfnStreamOutput(
181                      msg->base.stream_info.pvArg, header.pbData, header.cbData,
182                      FALSE);
183                     LocalFree(header.pbData);
184                 }
185             }
186             if (!fFinal)
187                 ret = msg->base.stream_info.pfnStreamOutput(
188                  msg->base.stream_info.pvArg, (BYTE *)pbData, cbData,
189                  FALSE);
190             else
191             {
192                 if (msg->base.stream_info.cbContent == 0xffffffff)
193                 {
194                     BYTE indefinite_trailer[6] = { 0 };
195
196                     ret = msg->base.stream_info.pfnStreamOutput(
197                      msg->base.stream_info.pvArg, (BYTE *)pbData, cbData,
198                      FALSE);
199                     if (ret)
200                         ret = msg->base.stream_info.pfnStreamOutput(
201                          msg->base.stream_info.pvArg, indefinite_trailer,
202                          sizeof(indefinite_trailer), TRUE);
203                 }
204                 else
205                     ret = msg->base.stream_info.pfnStreamOutput(
206                      msg->base.stream_info.pvArg, (BYTE *)pbData, cbData, TRUE);
207             }
208         }
209         __EXCEPT_PAGE_FAULT
210         {
211             SetLastError(STATUS_ACCESS_VIOLATION);
212         }
213         __ENDTRY;
214     }
215     else
216     {
217         if (!fFinal)
218         {
219             if (msg->base.open_flags & CMSG_DETACHED_FLAG)
220                 SetLastError(E_INVALIDARG);
221             else
222                 SetLastError(CRYPT_E_MSG_ERROR);
223         }
224         else
225         {
226             if (!cbData)
227                 SetLastError(E_INVALIDARG);
228             else
229             {
230                 CRYPT_DATA_BLOB blob = { cbData, (LPBYTE)pbData };
231
232                 /* non-streamed data messages don't allow non-final updates,
233                  * don't bother checking whether data already exist, they can't.
234                  */
235                 ret = CryptEncodeObjectEx(X509_ASN_ENCODING, X509_OCTET_STRING,
236                  &blob, CRYPT_ENCODE_ALLOC_FLAG, NULL, &msg->bare_content,
237                  &msg->bare_content_len);
238             }
239         }
240     }
241     return ret;
242 }
243
244 static BOOL CRYPT_CopyParam(void *pvData, DWORD *pcbData, const BYTE *src,
245  DWORD len)
246 {
247     BOOL ret = TRUE;
248
249     if (!pvData)
250         *pcbData = len;
251     else if (*pcbData < len)
252     {
253         *pcbData = len;
254         SetLastError(ERROR_MORE_DATA);
255         ret = FALSE;
256     }
257     else
258     {
259         *pcbData = len;
260         memcpy(pvData, src, len);
261     }
262     return ret;
263 }
264
265 static BOOL CDataEncodeMsg_GetParam(HCRYPTMSG hCryptMsg, DWORD dwParamType,
266  DWORD dwIndex, void *pvData, DWORD *pcbData)
267 {
268     CDataEncodeMsg *msg = (CDataEncodeMsg *)hCryptMsg;
269     BOOL ret = FALSE;
270
271     switch (dwParamType)
272     {
273     case CMSG_CONTENT_PARAM:
274         if (msg->base.streamed)
275             SetLastError(E_INVALIDARG);
276         else
277         {
278             CRYPT_CONTENT_INFO info;
279             char rsa_data[] = "1.2.840.113549.1.7.1";
280
281             info.pszObjId = rsa_data;
282             info.Content.cbData = msg->bare_content_len;
283             info.Content.pbData = msg->bare_content;
284             ret = CryptEncodeObject(X509_ASN_ENCODING, PKCS_CONTENT_INFO, &info,
285              pvData, pcbData);
286         }
287         break;
288     case CMSG_BARE_CONTENT_PARAM:
289         if (msg->base.streamed)
290             SetLastError(E_INVALIDARG);
291         else
292             ret = CRYPT_CopyParam(pvData, pcbData, msg->bare_content,
293              msg->bare_content_len);
294         break;
295     default:
296         SetLastError(CRYPT_E_INVALID_MSG_TYPE);
297     }
298     return ret;
299 }
300
301 static HCRYPTMSG CDataEncodeMsg_Open(DWORD dwFlags, const void *pvMsgEncodeInfo,
302  LPSTR pszInnerContentObjID, PCMSG_STREAM_INFO pStreamInfo)
303 {
304     CDataEncodeMsg *msg;
305
306     if (pvMsgEncodeInfo)
307     {
308         SetLastError(E_INVALIDARG);
309         return NULL;
310     }
311     msg = CryptMemAlloc(sizeof(CDataEncodeMsg));
312     if (msg)
313     {
314         CryptMsgBase_Init((CryptMsgBase *)msg, dwFlags, pStreamInfo,
315          CDataEncodeMsg_Close, CDataEncodeMsg_GetParam, CDataEncodeMsg_Update);
316         msg->bare_content_len = sizeof(empty_data_content);
317         msg->bare_content = (LPBYTE)empty_data_content;
318     }
319     return (HCRYPTMSG)msg;
320 }
321
322 typedef struct _CHashEncodeMsg
323 {
324     CryptMsgBase    base;
325     HCRYPTPROV      prov;
326     HCRYPTHASH      hash;
327     CRYPT_DATA_BLOB data;
328 } CHashEncodeMsg;
329
330 static void CHashEncodeMsg_Close(HCRYPTMSG hCryptMsg)
331 {
332     CHashEncodeMsg *msg = (CHashEncodeMsg *)hCryptMsg;
333
334     CryptMemFree(msg->data.pbData);
335     CryptDestroyHash(msg->hash);
336     if (msg->base.open_flags & CMSG_CRYPT_RELEASE_CONTEXT_FLAG)
337         CryptReleaseContext(msg->prov, 0);
338 }
339
340 static BOOL CRYPT_EncodePKCSDigestedData(CHashEncodeMsg *msg, void *pvData,
341  DWORD *pcbData)
342 {
343     BOOL ret;
344     ALG_ID algID;
345     DWORD size = sizeof(algID);
346
347     ret = CryptGetHashParam(msg->hash, HP_ALGID, (BYTE *)&algID, &size, 0);
348     if (ret)
349     {
350         CRYPT_DIGESTED_DATA digestedData = { 0 };
351         char oid_rsa_data[] = szOID_RSA_data;
352
353         digestedData.version = CMSG_HASHED_DATA_PKCS_1_5_VERSION;
354         digestedData.DigestAlgorithm.pszObjId = (LPSTR)CertAlgIdToOID(algID);
355         /* FIXME: what about digestedData.DigestAlgorithm.Parameters? */
356         /* Quirk:  OID is only encoded messages if an update has happened */
357         if (msg->base.state != MsgStateInit)
358             digestedData.ContentInfo.pszObjId = oid_rsa_data;
359         if (!(msg->base.open_flags & CMSG_DETACHED_FLAG) && msg->data.cbData)
360         {
361             ret = CRYPT_AsnEncodeOctets(0, NULL, &msg->data,
362              CRYPT_ENCODE_ALLOC_FLAG, NULL,
363              (LPBYTE)&digestedData.ContentInfo.Content.pbData,
364              &digestedData.ContentInfo.Content.cbData);
365         }
366         if (msg->base.state == MsgStateFinalized)
367         {
368             size = sizeof(DWORD);
369             ret = CryptGetHashParam(msg->hash, HP_HASHSIZE,
370              (LPBYTE)&digestedData.hash.cbData, &size, 0);
371             if (ret)
372             {
373                 digestedData.hash.pbData = CryptMemAlloc(
374                  digestedData.hash.cbData);
375                 ret = CryptGetHashParam(msg->hash, HP_HASHVAL,
376                  digestedData.hash.pbData, &digestedData.hash.cbData, 0);
377             }
378         }
379         if (ret)
380             ret = CRYPT_AsnEncodePKCSDigestedData(&digestedData, pvData,
381              pcbData);
382         CryptMemFree(digestedData.hash.pbData);
383         LocalFree(digestedData.ContentInfo.Content.pbData);
384     }
385     return ret;
386 }
387
388 static BOOL CHashEncodeMsg_GetParam(HCRYPTMSG hCryptMsg, DWORD dwParamType,
389  DWORD dwIndex, void *pvData, DWORD *pcbData)
390 {
391     CHashEncodeMsg *msg = (CHashEncodeMsg *)hCryptMsg;
392     BOOL ret = FALSE;
393
394     TRACE("(%p, %d, %d, %p, %p)\n", hCryptMsg, dwParamType, dwIndex,
395      pvData, pcbData);
396
397     switch (dwParamType)
398     {
399     case CMSG_BARE_CONTENT_PARAM:
400         if (msg->base.streamed)
401             SetLastError(E_INVALIDARG);
402         else
403             ret = CRYPT_EncodePKCSDigestedData(msg, pvData, pcbData);
404         break;
405     case CMSG_CONTENT_PARAM:
406     {
407         CRYPT_CONTENT_INFO info;
408
409         ret = CryptMsgGetParam(hCryptMsg, CMSG_BARE_CONTENT_PARAM, 0, NULL,
410          &info.Content.cbData);
411         if (ret)
412         {
413             info.Content.pbData = CryptMemAlloc(info.Content.cbData);
414             if (info.Content.pbData)
415             {
416                 ret = CryptMsgGetParam(hCryptMsg, CMSG_BARE_CONTENT_PARAM, 0,
417                  info.Content.pbData, &info.Content.cbData);
418                 if (ret)
419                 {
420                     char oid_rsa_hashed[] = szOID_RSA_hashedData;
421
422                     info.pszObjId = oid_rsa_hashed;
423                     ret = CryptEncodeObjectEx(X509_ASN_ENCODING,
424                      PKCS_CONTENT_INFO, &info, 0, NULL, pvData, pcbData);
425                 }
426                 CryptMemFree(info.Content.pbData);
427             }
428             else
429                 ret = FALSE;
430         }
431         break;
432     }
433     case CMSG_COMPUTED_HASH_PARAM:
434         ret = CryptGetHashParam(msg->hash, HP_HASHVAL, (BYTE *)pvData, pcbData,
435          0);
436         break;
437     case CMSG_VERSION_PARAM:
438         if (msg->base.state != MsgStateFinalized)
439             SetLastError(CRYPT_E_MSG_ERROR);
440         else
441         {
442             DWORD version = CMSG_HASHED_DATA_PKCS_1_5_VERSION;
443
444             /* Since the data are always encoded as octets, the version is
445              * always 0 (see rfc3852, section 7)
446              */
447             ret = CRYPT_CopyParam(pvData, pcbData, (const BYTE *)&version,
448              sizeof(version));
449         }
450         break;
451     default:
452         ret = FALSE;
453     }
454     return ret;
455 }
456
457 static BOOL CHashEncodeMsg_Update(HCRYPTMSG hCryptMsg, const BYTE *pbData,
458  DWORD cbData, BOOL fFinal)
459 {
460     CHashEncodeMsg *msg = (CHashEncodeMsg *)hCryptMsg;
461     BOOL ret = FALSE;
462
463     TRACE("(%p, %p, %d, %d)\n", hCryptMsg, pbData, cbData, fFinal);
464
465     if (msg->base.streamed || (msg->base.open_flags & CMSG_DETACHED_FLAG))
466     {
467         /* Doesn't do much, as stream output is never called, and you
468          * can't get the content.
469          */
470         ret = CryptHashData(msg->hash, pbData, cbData, 0);
471     }
472     else
473     {
474         if (!fFinal)
475             SetLastError(CRYPT_E_MSG_ERROR);
476         else
477         {
478             ret = CryptHashData(msg->hash, pbData, cbData, 0);
479             if (ret)
480             {
481                 msg->data.pbData = CryptMemAlloc(cbData);
482                 if (msg->data.pbData)
483                 {
484                     memcpy(msg->data.pbData + msg->data.cbData, pbData, cbData);
485                     msg->data.cbData += cbData;
486                 }
487                 else
488                     ret = FALSE;
489             }
490         }
491     }
492     return ret;
493 }
494
495 static HCRYPTMSG CHashEncodeMsg_Open(DWORD dwFlags, const void *pvMsgEncodeInfo,
496  LPSTR pszInnerContentObjID, PCMSG_STREAM_INFO pStreamInfo)
497 {
498     CHashEncodeMsg *msg;
499     const CMSG_HASHED_ENCODE_INFO *info =
500      (const CMSG_HASHED_ENCODE_INFO *)pvMsgEncodeInfo;
501     HCRYPTPROV prov;
502     ALG_ID algID;
503
504     if (info->cbSize != sizeof(CMSG_HASHED_ENCODE_INFO))
505     {
506         SetLastError(E_INVALIDARG);
507         return NULL;
508     }
509     if (!(algID = CertOIDToAlgId(info->HashAlgorithm.pszObjId)))
510     {
511         SetLastError(CRYPT_E_UNKNOWN_ALGO);
512         return NULL;
513     }
514     if (info->hCryptProv)
515         prov = info->hCryptProv;
516     else
517     {
518         prov = CRYPT_GetDefaultProvider();
519         dwFlags &= ~CMSG_CRYPT_RELEASE_CONTEXT_FLAG;
520     }
521     msg = CryptMemAlloc(sizeof(CHashEncodeMsg));
522     if (msg)
523     {
524         CryptMsgBase_Init((CryptMsgBase *)msg, dwFlags, pStreamInfo,
525          CHashEncodeMsg_Close, CHashEncodeMsg_GetParam, CHashEncodeMsg_Update);
526         msg->prov = prov;
527         msg->data.cbData = 0;
528         msg->data.pbData = NULL;
529         if (!CryptCreateHash(prov, algID, 0, 0, &msg->hash))
530         {
531             CryptMsgClose(msg);
532             msg = NULL;
533         }
534     }
535     return (HCRYPTMSG)msg;
536 }
537
538 static inline const char *MSG_TYPE_STR(DWORD type)
539 {
540     switch (type)
541     {
542 #define _x(x) case (x): return #x
543         _x(CMSG_DATA);
544         _x(CMSG_SIGNED);
545         _x(CMSG_ENVELOPED);
546         _x(CMSG_SIGNED_AND_ENVELOPED);
547         _x(CMSG_HASHED);
548         _x(CMSG_ENCRYPTED);
549 #undef _x
550         default:
551             return wine_dbg_sprintf("unknown (%d)", type);
552     }
553 }
554
555 HCRYPTMSG WINAPI CryptMsgOpenToEncode(DWORD dwMsgEncodingType, DWORD dwFlags,
556  DWORD dwMsgType, const void *pvMsgEncodeInfo, LPSTR pszInnerContentObjID,
557  PCMSG_STREAM_INFO pStreamInfo)
558 {
559     HCRYPTMSG msg = NULL;
560
561     TRACE("(%08x, %08x, %08x, %p, %s, %p)\n", dwMsgEncodingType, dwFlags,
562      dwMsgType, pvMsgEncodeInfo, debugstr_a(pszInnerContentObjID), pStreamInfo);
563
564     if (GET_CMSG_ENCODING_TYPE(dwMsgEncodingType) != PKCS_7_ASN_ENCODING)
565     {
566         SetLastError(E_INVALIDARG);
567         return NULL;
568     }
569     switch (dwMsgType)
570     {
571     case CMSG_DATA:
572         msg = CDataEncodeMsg_Open(dwFlags, pvMsgEncodeInfo,
573          pszInnerContentObjID, pStreamInfo);
574         break;
575     case CMSG_HASHED:
576         msg = CHashEncodeMsg_Open(dwFlags, pvMsgEncodeInfo,
577          pszInnerContentObjID, pStreamInfo);
578         break;
579     case CMSG_SIGNED:
580     case CMSG_ENVELOPED:
581         FIXME("unimplemented for type %s\n", MSG_TYPE_STR(dwMsgType));
582         break;
583     case CMSG_SIGNED_AND_ENVELOPED:
584     case CMSG_ENCRYPTED:
585         /* defined but invalid, fall through */
586     default:
587         SetLastError(CRYPT_E_INVALID_MSG_TYPE);
588     }
589     return msg;
590 }
591
592 typedef struct _CDecodeMsg
593 {
594     CryptMsgBase           base;
595     DWORD                  type;
596     HCRYPTPROV             crypt_prov;
597     HCRYPTHASH             hash;
598     CRYPT_DATA_BLOB        msg_data;
599     PCONTEXT_PROPERTY_LIST properties;
600 } CDecodeMsg;
601
602 static void CDecodeMsg_Close(HCRYPTMSG hCryptMsg)
603 {
604     CDecodeMsg *msg = (CDecodeMsg *)hCryptMsg;
605
606     if (msg->base.open_flags & CMSG_CRYPT_RELEASE_CONTEXT_FLAG)
607         CryptReleaseContext(msg->crypt_prov, 0);
608     CryptDestroyHash(msg->hash);
609     CryptMemFree(msg->msg_data.pbData);
610     ContextPropertyList_Free(msg->properties);
611 }
612
613 static BOOL CDecodeMsg_CopyData(CDecodeMsg *msg, const BYTE *pbData,
614  DWORD cbData)
615 {
616     BOOL ret = TRUE;
617
618     if (cbData)
619     {
620         if (msg->msg_data.cbData)
621             msg->msg_data.pbData = CryptMemRealloc(msg->msg_data.pbData,
622              msg->msg_data.cbData + cbData);
623         else
624             msg->msg_data.pbData = CryptMemAlloc(cbData);
625         if (msg->msg_data.pbData)
626         {
627             memcpy(msg->msg_data.pbData + msg->msg_data.cbData, pbData, cbData);
628             msg->msg_data.cbData += cbData;
629         }
630         else
631             ret = FALSE;
632     }
633     return ret;
634 }
635
636 static BOOL CDecodeMsg_DecodeDataContent(CDecodeMsg *msg, CRYPT_DER_BLOB *blob)
637 {
638     BOOL ret;
639     CRYPT_DATA_BLOB *data;
640     DWORD size;
641
642     ret = CryptDecodeObjectEx(X509_ASN_ENCODING, X509_OCTET_STRING,
643      blob->pbData, blob->cbData, CRYPT_DECODE_ALLOC_FLAG, NULL, (LPBYTE)&data,
644      &size);
645     if (ret)
646     {
647         ret = ContextPropertyList_SetProperty(msg->properties,
648          CMSG_CONTENT_PARAM, data->pbData, data->cbData);
649         LocalFree(data);
650     }
651     return ret;
652 }
653
654 static void CDecodeMsg_SaveAlgorithmID(CDecodeMsg *msg, DWORD param,
655  const CRYPT_ALGORITHM_IDENTIFIER *id)
656 {
657     static const BYTE nullParams[] = { ASN_NULL, 0 };
658     CRYPT_ALGORITHM_IDENTIFIER *copy;
659     DWORD len = sizeof(CRYPT_ALGORITHM_IDENTIFIER);
660
661     /* Linearize algorithm id */
662     len += strlen(id->pszObjId) + 1;
663     len += id->Parameters.cbData;
664     copy = CryptMemAlloc(len);
665     if (copy)
666     {
667         copy->pszObjId =
668          (LPSTR)((BYTE *)copy + sizeof(CRYPT_ALGORITHM_IDENTIFIER));
669         strcpy(copy->pszObjId, id->pszObjId);
670         copy->Parameters.pbData = (BYTE *)copy->pszObjId + strlen(id->pszObjId)
671          + 1;
672         /* Trick:  omit NULL parameters */
673         if (id->Parameters.cbData == sizeof(nullParams) &&
674          !memcmp(id->Parameters.pbData, nullParams, sizeof(nullParams)))
675         {
676             copy->Parameters.cbData = 0;
677             len -= sizeof(nullParams);
678         }
679         else
680             copy->Parameters.cbData = id->Parameters.cbData;
681         if (copy->Parameters.cbData)
682             memcpy(copy->Parameters.pbData, id->Parameters.pbData,
683              id->Parameters.cbData);
684         ContextPropertyList_SetProperty(msg->properties, param, (BYTE *)copy,
685          len);
686         CryptMemFree(copy);
687     }
688 }
689
690 static inline void CRYPT_FixUpAlgorithmID(CRYPT_ALGORITHM_IDENTIFIER *id)
691 {
692     id->pszObjId = (LPSTR)((BYTE *)id + sizeof(CRYPT_ALGORITHM_IDENTIFIER));
693     id->Parameters.pbData = (BYTE *)id->pszObjId + strlen(id->pszObjId) + 1;
694 }
695
696 /* Decodes the content in blob as the type given, and updates the value
697  * (type, parameters, etc.) of msg based on what blob contains.
698  * It doesn't just use msg's type, to allow a recursive call from an implicitly
699  * typed message once the outer content info has been decoded.
700  */
701 static BOOL CDecodeMsg_DecodeContent(CDecodeMsg *msg, CRYPT_DER_BLOB *blob,
702  DWORD type)
703 {
704     BOOL ret;
705     DWORD size;
706
707     switch (type)
708     {
709     case CMSG_DATA:
710         if ((ret = CDecodeMsg_DecodeDataContent(msg, blob)))
711             msg->type = CMSG_DATA;
712         break;
713     case CMSG_HASHED:
714     {
715         CRYPT_DIGESTED_DATA *digestedData;
716
717         ret = CRYPT_AsnDecodePKCSDigestedData(blob->pbData, blob->cbData,
718          CRYPT_DECODE_ALLOC_FLAG, NULL, (CRYPT_DIGESTED_DATA *)&digestedData,
719          &size);
720         if (ret)
721         {
722             msg->type = CMSG_HASHED;
723             ContextPropertyList_SetProperty(msg->properties,
724              CMSG_VERSION_PARAM, (const BYTE *)&digestedData->version,
725              sizeof(digestedData->version));
726             CDecodeMsg_SaveAlgorithmID(msg, CMSG_HASH_ALGORITHM_PARAM,
727              &digestedData->DigestAlgorithm);
728             ContextPropertyList_SetProperty(msg->properties,
729              CMSG_INNER_CONTENT_TYPE_PARAM,
730              (const BYTE *)digestedData->ContentInfo.pszObjId,
731              digestedData->ContentInfo.pszObjId ?
732              strlen(digestedData->ContentInfo.pszObjId) + 1 : 0);
733             if (digestedData->ContentInfo.Content.cbData)
734                 CDecodeMsg_DecodeDataContent(msg,
735                  &digestedData->ContentInfo.Content);
736             else
737                 ContextPropertyList_SetProperty(msg->properties,
738                  CMSG_CONTENT_PARAM, NULL, 0);
739             ContextPropertyList_SetProperty(msg->properties,
740              CMSG_HASH_DATA_PARAM, digestedData->hash.pbData,
741              digestedData->hash.cbData);
742             LocalFree(digestedData);
743         }
744         break;
745     }
746     case CMSG_ENVELOPED:
747     case CMSG_SIGNED:
748         FIXME("unimplemented for type %s\n", MSG_TYPE_STR(type));
749         ret = TRUE;
750         break;
751     default:
752     {
753         CRYPT_CONTENT_INFO *info;
754
755         ret = CryptDecodeObjectEx(X509_ASN_ENCODING, PKCS_CONTENT_INFO,
756          msg->msg_data.pbData, msg->msg_data.cbData, CRYPT_DECODE_ALLOC_FLAG,
757          NULL, (LPBYTE)&info, &size);
758         if (ret)
759         {
760             if (!strcmp(info->pszObjId, szOID_RSA_data))
761                 ret = CDecodeMsg_DecodeContent(msg, &info->Content, CMSG_DATA);
762             else if (!strcmp(info->pszObjId, szOID_RSA_digestedData))
763                 ret = CDecodeMsg_DecodeContent(msg, &info->Content,
764                  CMSG_HASHED);
765             else if (!strcmp(info->pszObjId, szOID_RSA_envelopedData))
766                 ret = CDecodeMsg_DecodeContent(msg, &info->Content,
767                  CMSG_ENVELOPED);
768             else if (!strcmp(info->pszObjId, szOID_RSA_signedData))
769                 ret = CDecodeMsg_DecodeContent(msg, &info->Content,
770                  CMSG_SIGNED);
771             else
772             {
773                 SetLastError(CRYPT_E_INVALID_MSG_TYPE);
774                 ret = FALSE;
775             }
776             LocalFree(info);
777         }
778     }
779     }
780     return ret;
781 }
782
783 static BOOL CDecodeMsg_Update(HCRYPTMSG hCryptMsg, const BYTE *pbData,
784  DWORD cbData, BOOL fFinal)
785 {
786     CDecodeMsg *msg = (CDecodeMsg *)hCryptMsg;
787     BOOL ret = FALSE;
788
789     TRACE("(%p, %p, %d, %d)\n", hCryptMsg, pbData, cbData, fFinal);
790
791     if (msg->base.streamed)
792     {
793         ret = CDecodeMsg_CopyData(msg, pbData, cbData);
794         FIXME("(%p, %p, %d, %d): streamed update stub\n", hCryptMsg, pbData,
795          cbData, fFinal);
796     }
797     else
798     {
799         if (!fFinal)
800             SetLastError(CRYPT_E_MSG_ERROR);
801         else
802         {
803             ret = CDecodeMsg_CopyData(msg, pbData, cbData);
804             if (ret)
805                 ret = CDecodeMsg_DecodeContent(msg, &msg->msg_data, msg->type);
806
807         }
808     }
809     return ret;
810 }
811
812 static BOOL CDecodeMsg_GetParam(HCRYPTMSG hCryptMsg, DWORD dwParamType,
813  DWORD dwIndex, void *pvData, DWORD *pcbData)
814 {
815     CDecodeMsg *msg = (CDecodeMsg *)hCryptMsg;
816     BOOL ret = FALSE;
817
818     switch (dwParamType)
819     {
820     case CMSG_TYPE_PARAM:
821         ret = CRYPT_CopyParam(pvData, pcbData, (const BYTE *)&msg->type,
822          sizeof(msg->type));
823         break;
824     case CMSG_HASH_ALGORITHM_PARAM:
825     {
826         CRYPT_DATA_BLOB blob;
827
828         ret = ContextPropertyList_FindProperty(msg->properties, dwParamType,
829          &blob);
830         if (ret)
831         {
832             ret = CRYPT_CopyParam(pvData, pcbData, blob.pbData, blob.cbData);
833             if (ret && pvData)
834                 CRYPT_FixUpAlgorithmID((CRYPT_ALGORITHM_IDENTIFIER *)pvData);
835         }
836         else
837             SetLastError(CRYPT_E_INVALID_MSG_TYPE);
838         break;
839     }
840     case CMSG_COMPUTED_HASH_PARAM:
841         if (!msg->hash)
842         {
843             CRYPT_ALGORITHM_IDENTIFIER *hashAlgoID = NULL;
844             DWORD size = 0;
845             ALG_ID algID = 0;
846
847             CryptMsgGetParam(msg, CMSG_HASH_ALGORITHM_PARAM, 0, NULL, &size);
848             hashAlgoID = CryptMemAlloc(size);
849             ret = CryptMsgGetParam(msg, CMSG_HASH_ALGORITHM_PARAM, 0,
850              hashAlgoID, &size);
851             if (ret)
852                 algID = CertOIDToAlgId(hashAlgoID->pszObjId);
853             ret = CryptCreateHash(msg->crypt_prov, algID, 0, 0, &msg->hash);
854             if (ret)
855             {
856                 CRYPT_DATA_BLOB content;
857
858                 ret = ContextPropertyList_FindProperty(msg->properties,
859                  CMSG_CONTENT_PARAM, &content);
860                 if (ret)
861                     ret = CryptHashData(msg->hash, content.pbData,
862                      content.cbData, 0);
863             }
864             CryptMemFree(hashAlgoID);
865         }
866         else
867             ret = TRUE;
868         if (ret)
869             ret = CryptGetHashParam(msg->hash, HP_HASHVAL, pvData, pcbData, 0);
870         break;
871     default:
872     {
873         CRYPT_DATA_BLOB blob;
874
875         ret = ContextPropertyList_FindProperty(msg->properties, dwParamType,
876          &blob);
877         if (ret)
878             ret = CRYPT_CopyParam(pvData, pcbData, blob.pbData, blob.cbData);
879         else
880             SetLastError(CRYPT_E_INVALID_MSG_TYPE);
881     }
882     }
883     return ret;
884 }
885
886 HCRYPTMSG WINAPI CryptMsgOpenToDecode(DWORD dwMsgEncodingType, DWORD dwFlags,
887  DWORD dwMsgType, HCRYPTPROV hCryptProv, PCERT_INFO pRecipientInfo,
888  PCMSG_STREAM_INFO pStreamInfo)
889 {
890     CDecodeMsg *msg;
891
892     TRACE("(%08x, %08x, %08x, %08lx, %p, %p)\n", dwMsgEncodingType,
893      dwFlags, dwMsgType, hCryptProv, pRecipientInfo, pStreamInfo);
894
895     if (GET_CMSG_ENCODING_TYPE(dwMsgEncodingType) != PKCS_7_ASN_ENCODING)
896     {
897         SetLastError(E_INVALIDARG);
898         return NULL;
899     }
900     msg = CryptMemAlloc(sizeof(CDecodeMsg));
901     if (msg)
902     {
903         CryptMsgBase_Init((CryptMsgBase *)msg, dwFlags, pStreamInfo,
904          CDecodeMsg_Close, CDecodeMsg_GetParam, CDecodeMsg_Update);
905         msg->type = dwMsgType;
906         if (hCryptProv)
907             msg->crypt_prov = hCryptProv;
908         else
909         {
910             msg->crypt_prov = CRYPT_GetDefaultProvider();
911             msg->base.open_flags &= ~CMSG_CRYPT_RELEASE_CONTEXT_FLAG;
912         }
913         msg->hash = 0;
914         msg->msg_data.cbData = 0;
915         msg->msg_data.pbData = NULL;
916         msg->properties = ContextPropertyList_Create();
917     }
918     return msg;
919 }
920
921 HCRYPTMSG WINAPI CryptMsgDuplicate(HCRYPTMSG hCryptMsg)
922 {
923     TRACE("(%p)\n", hCryptMsg);
924
925     if (hCryptMsg)
926     {
927         CryptMsgBase *msg = (CryptMsgBase *)hCryptMsg;
928
929         InterlockedIncrement(&msg->ref);
930     }
931     return hCryptMsg;
932 }
933
934 BOOL WINAPI CryptMsgClose(HCRYPTMSG hCryptMsg)
935 {
936     TRACE("(%p)\n", hCryptMsg);
937
938     if (hCryptMsg)
939     {
940         CryptMsgBase *msg = (CryptMsgBase *)hCryptMsg;
941
942         if (InterlockedDecrement(&msg->ref) == 0)
943         {
944             TRACE("freeing %p\n", msg);
945             if (msg->close)
946                 msg->close(msg);
947             CryptMemFree(msg);
948         }
949     }
950     return TRUE;
951 }
952
953 BOOL WINAPI CryptMsgUpdate(HCRYPTMSG hCryptMsg, const BYTE *pbData,
954  DWORD cbData, BOOL fFinal)
955 {
956     CryptMsgBase *msg = (CryptMsgBase *)hCryptMsg;
957     BOOL ret = FALSE;
958
959     TRACE("(%p, %p, %d, %d)\n", hCryptMsg, pbData, cbData, fFinal);
960
961     if (msg->state == MsgStateFinalized)
962         SetLastError(CRYPT_E_MSG_ERROR);
963     else
964     {
965         ret = msg->update(hCryptMsg, pbData, cbData, fFinal);
966         msg->state = MsgStateUpdated;
967         if (fFinal)
968             msg->state = MsgStateFinalized;
969     }
970     return ret;
971 }
972
973 BOOL WINAPI CryptMsgGetParam(HCRYPTMSG hCryptMsg, DWORD dwParamType,
974  DWORD dwIndex, void *pvData, DWORD *pcbData)
975 {
976     CryptMsgBase *msg = (CryptMsgBase *)hCryptMsg;
977
978     TRACE("(%p, %d, %d, %p, %p)\n", hCryptMsg, dwParamType, dwIndex,
979      pvData, pcbData);
980     return msg->get_param(hCryptMsg, dwParamType, dwIndex, pvData, pcbData);
981 }