Merge branch 'master'
[linux-2.6] / fs / cifs / asn1.c
/*
 * The ASN.1/BER parsing code is derived from ip_nat_snmp_basic.c which was in
 * turn derived from the gxsnmp package by Gregory McLean & Jochen Friedrich
 *
 * Copyright (c) 2000 RP Internet (www.rpi.net.au).
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307 USA
 */
19
20 #include <linux/config.h>
21 #include <linux/module.h>
22 #include <linux/types.h>
23 #include <linux/kernel.h>
24 #include <linux/mm.h>
25 #include <linux/slab.h>
26 #include "cifspdu.h"
27 #include "cifsglob.h"
28 #include "cifs_debug.h"
29 #include "cifsproto.h"
30
31 /*****************************************************************************
32  *
33  * Basic ASN.1 decoding routines (gxsnmp author Dirk Wisse)
34  *
35  *****************************************************************************/
36
/* Identifier-octet class (bits 8-7 of the identifier octet). */
enum {
	ASN1_UNI = 0,		/* Universal */
	ASN1_APL = 1,		/* Application */
	ASN1_CTX = 2,		/* Context */
	ASN1_PRV = 3		/* Private */
};

/* Universal-class tag numbers. */
enum {
	ASN1_EOC = 0,		/* End Of Contents or N/A */
	ASN1_BOL = 1,		/* Boolean */
	ASN1_INT = 2,		/* Integer */
	ASN1_BTS = 3,		/* Bit String */
	ASN1_OTS = 4,		/* Octet String */
	ASN1_NUL = 5,		/* Null */
	ASN1_OJI = 6,		/* Object Identifier */
	ASN1_OJD = 7,		/* Object Description */
	ASN1_EXT = 8,		/* External */
	ASN1_SEQ = 16,		/* Sequence */
	ASN1_SET = 17,		/* Set */
	ASN1_NUMSTR = 18,	/* Numerical String */
	ASN1_PRNSTR = 19,	/* Printable String */
	ASN1_TEXSTR = 20,	/* Teletext String */
	ASN1_VIDSTR = 21,	/* Video String */
	ASN1_IA5STR = 22,	/* IA5 String */
	ASN1_UNITIM = 23,	/* Universal Time */
	ASN1_GENTIM = 24,	/* General Time */
	ASN1_GRASTR = 25,	/* Graphical String */
	ASN1_VISSTR = 26,	/* Visible String */
	ASN1_GENSTR = 27	/* General String */
};

/* Encoding method (bit 6 of the identifier octet). */
enum {
	ASN1_PRI = 0,		/* Primitive */
	ASN1_CON = 1		/* Constructed */
};

/* Values stored in asn1_ctx.error by the decoders. */
enum {
	ASN1_ERR_NOERROR = 0,
	ASN1_ERR_DEC_EMPTY = 2,
	ASN1_ERR_DEC_EOC_MISMATCH = 3,
	ASN1_ERR_DEC_LENGTH_MISMATCH = 4,
	ASN1_ERR_DEC_BADVALUE = 5
};

/* Mechanism OIDs, stored as arrays of decoded sub-identifiers. */
enum {
	SPNEGO_OID_LEN = 7,
	NTLMSSP_OID_LEN = 10
};
static unsigned long SPNEGO_OID[SPNEGO_OID_LEN] = { 1, 3, 6, 1, 5, 5, 2 };
static unsigned long NTLMSSP_OID[NTLMSSP_OID_LEN] =
	{ 1, 3, 6, 1, 4, 1, 311, 2, 2, 10 };
83
/*
 * Decoder state: a read cursor over the octet range [begin, end).
 */
struct asn1_ctx {
	int error;		/* ASN1_ERR_* code from the last failing decoder */
	unsigned char *pointer;	/* next octet to be consumed */
	unsigned char *begin;	/* start of the encoded buffer */
	unsigned char *end;	/* one past the final octet of the buffer */
};
93
/*
 * A run of raw octets: len bytes starting at data.
 * The data is not NUL-terminated.
 */
struct asn1_octstr {
	unsigned char *data;	/* first octet */
	unsigned int len;	/* number of octets at data */
};
101
102 static void
103 asn1_open(struct asn1_ctx *ctx, unsigned char *buf, unsigned int len)
104 {
105         ctx->begin = buf;
106         ctx->end = buf + len;
107         ctx->pointer = buf;
108         ctx->error = ASN1_ERR_NOERROR;
109 }
110
111 static unsigned char
112 asn1_octet_decode(struct asn1_ctx *ctx, unsigned char *ch)
113 {
114         if (ctx->pointer >= ctx->end) {
115                 ctx->error = ASN1_ERR_DEC_EMPTY;
116                 return 0;
117         }
118         *ch = *(ctx->pointer)++;
119         return 1;
120 }
121
/*
 * asn1_tag_decode - decode a high tag number (base-128, big-endian)
 * @ctx: decoding context
 * @tag: receives the tag number (partial value if input runs out)
 *
 * Each octet contributes its low 7 bits.  A set high bit means another
 * octet follows.  Returns 1 on success and 0 if the buffer is exhausted.
 */
static unsigned char
asn1_tag_decode(struct asn1_ctx *ctx, unsigned int *tag)
{
	unsigned char octet;

	*tag = 0;
	for (;;) {
		if (!asn1_octet_decode(ctx, &octet))
			return 0;
		*tag = (*tag << 7) | (octet & 0x7F);
		if ((octet & 0x80) == 0)
			return 1;
	}
}
137
/*
 * asn1_id_decode - decode an identifier octet (plus any high-tag octets)
 * @ctx: decoding context
 * @cls: receives the class (bits 8-7)
 * @con: receives the primitive/constructed flag (bit 6)
 * @tag: receives the tag number
 *
 * Returns 1 on success and 0 if the buffer runs out.
 */
static unsigned char
asn1_id_decode(struct asn1_ctx *ctx,
	       unsigned int *cls, unsigned int *con, unsigned int *tag)
{
	unsigned char id;

	if (!asn1_octet_decode(ctx, &id))
		return 0;

	*cls = id >> 6;
	*con = (id >> 5) & 1;
	*tag = id & 0x1F;

	/* all five low bits set: the real tag number follows in base-128 */
	return (*tag != 0x1F) ? 1 : asn1_tag_decode(ctx, tag);
}
157
158 static unsigned char
159 asn1_length_decode(struct asn1_ctx *ctx, unsigned int *def, unsigned int *len)
160 {
161         unsigned char ch, cnt;
162
163         if (!asn1_octet_decode(ctx, &ch))
164                 return 0;
165
166         if (ch == 0x80)
167                 *def = 0;
168         else {
169                 *def = 1;
170
171                 if (ch < 0x80)
172                         *len = ch;
173                 else {
174                         cnt = (unsigned char) (ch & 0x7F);
175                         *len = 0;
176
177                         while (cnt > 0) {
178                                 if (!asn1_octet_decode(ctx, &ch))
179                                         return 0;
180                                 *len <<= 8;
181                                 *len |= ch;
182                                 cnt--;
183                         }
184                 }
185         }
186         return 1;
187 }
188
189 static unsigned char
190 asn1_header_decode(struct asn1_ctx *ctx,
191                    unsigned char **eoc,
192                    unsigned int *cls, unsigned int *con, unsigned int *tag)
193 {
194         unsigned int def = 0; 
195         unsigned int len = 0;
196
197         if (!asn1_id_decode(ctx, cls, con, tag))
198                 return 0;
199
200         if (!asn1_length_decode(ctx, &def, &len))
201                 return 0;
202
203         if (def)
204                 *eoc = ctx->pointer + len;
205         else
206                 *eoc = NULL;
207         return 1;
208 }
209
210 static unsigned char
211 asn1_eoc_decode(struct asn1_ctx *ctx, unsigned char *eoc)
212 {
213         unsigned char ch;
214
215         if (eoc == NULL) {
216                 if (!asn1_octet_decode(ctx, &ch))
217                         return 0;
218
219                 if (ch != 0x00) {
220                         ctx->error = ASN1_ERR_DEC_EOC_MISMATCH;
221                         return 0;
222                 }
223
224                 if (!asn1_octet_decode(ctx, &ch))
225                         return 0;
226
227                 if (ch != 0x00) {
228                         ctx->error = ASN1_ERR_DEC_EOC_MISMATCH;
229                         return 0;
230                 }
231                 return 1;
232         } else {
233                 if (ctx->pointer != eoc) {
234                         ctx->error = ASN1_ERR_DEC_LENGTH_MISMATCH;
235                         return 0;
236                 }
237                 return 1;
238         }
239 }
240
241 /* static unsigned char asn1_null_decode(struct asn1_ctx *ctx,
242                                       unsigned char *eoc)
243 {
244         ctx->pointer = eoc;
245         return 1;
246 }
247
248 static unsigned char asn1_long_decode(struct asn1_ctx *ctx,
249                                       unsigned char *eoc, long *integer)
250 {
251         unsigned char ch;
252         unsigned int len;
253
254         if (!asn1_octet_decode(ctx, &ch))
255                 return 0;
256
257         *integer = (signed char) ch;
258         len = 1;
259
260         while (ctx->pointer < eoc) {
261                 if (++len > sizeof(long)) {
262                         ctx->error = ASN1_ERR_DEC_BADVALUE;
263                         return 0;
264                 }
265
266                 if (!asn1_octet_decode(ctx, &ch))
267                         return 0;
268
269                 *integer <<= 8;
270                 *integer |= ch;
271         }
272         return 1;
273 }
274
275 static unsigned char asn1_uint_decode(struct asn1_ctx *ctx,
276                                       unsigned char *eoc,
277                                       unsigned int *integer)
278 {
279         unsigned char ch;
280         unsigned int len;
281
282         if (!asn1_octet_decode(ctx, &ch))
283                 return 0;
284
285         *integer = ch;
286         if (ch == 0)
287                 len = 0;
288         else
289                 len = 1;
290
291         while (ctx->pointer < eoc) {
292                 if (++len > sizeof(unsigned int)) {
293                         ctx->error = ASN1_ERR_DEC_BADVALUE;
294                         return 0;
295                 }
296
297                 if (!asn1_octet_decode(ctx, &ch))
298                         return 0;
299
300                 *integer <<= 8;
301                 *integer |= ch;
302         }
303         return 1;
304 }
305
306 static unsigned char asn1_ulong_decode(struct asn1_ctx *ctx,
307                                        unsigned char *eoc,
308                                        unsigned long *integer)
309 {
310         unsigned char ch;
311         unsigned int len;
312
313         if (!asn1_octet_decode(ctx, &ch))
314                 return 0;
315
316         *integer = ch;
317         if (ch == 0)
318                 len = 0;
319         else
320                 len = 1;
321
322         while (ctx->pointer < eoc) {
323                 if (++len > sizeof(unsigned long)) {
324                         ctx->error = ASN1_ERR_DEC_BADVALUE;
325                         return 0;
326                 }
327
328                 if (!asn1_octet_decode(ctx, &ch))
329                         return 0;
330
331                 *integer <<= 8;
332                 *integer |= ch;
333         }
334         return 1;
335
336
337 static unsigned char
338 asn1_octets_decode(struct asn1_ctx *ctx,
339                    unsigned char *eoc,
340                    unsigned char **octets, unsigned int *len)
341 {
342         unsigned char *ptr;
343
344         *len = 0;
345
346         *octets = kmalloc(eoc - ctx->pointer, GFP_ATOMIC);
347         if (*octets == NULL) {
348                 return 0;
349         }
350
351         ptr = *octets;
352         while (ctx->pointer < eoc) {
353                 if (!asn1_octet_decode(ctx, (unsigned char *) ptr++)) {
354                         kfree(*octets);
355                         *octets = NULL;
356                         return 0;
357                 }
358                 (*len)++;
359         }
360         return 1;
361 } */
362
/*
 * asn1_subid_decode - decode one base-128 OID sub-identifier
 * @ctx: decoding context
 * @subid: receives the value (partial value if input runs out)
 *
 * Returns 1 on success and 0 if the buffer is exhausted first.
 */
static unsigned char
asn1_subid_decode(struct asn1_ctx *ctx, unsigned long *subid)
{
	unsigned char octet;
	int more;

	*subid = 0;
	do {
		if (!asn1_octet_decode(ctx, &octet))
			return 0;
		*subid = (*subid << 7) | (octet & 0x7F);
		more = octet & 0x80;	/* high bit: another octet follows */
	} while (more);

	return 1;
}
379
380 static int 
381 asn1_oid_decode(struct asn1_ctx *ctx,
382                 unsigned char *eoc, unsigned long **oid, unsigned int *len)
383 {
384         unsigned long subid;
385         unsigned int size;
386         unsigned long *optr;
387
388         size = eoc - ctx->pointer + 1;
389         *oid = kmalloc(size * sizeof (unsigned long), GFP_ATOMIC);
390         if (*oid == NULL) {
391                 return 0;
392         }
393
394         optr = *oid;
395
396         if (!asn1_subid_decode(ctx, &subid)) {
397                 kfree(*oid);
398                 *oid = NULL;
399                 return 0;
400         }
401
402         if (subid < 40) {
403                 optr[0] = 0;
404                 optr[1] = subid;
405         } else if (subid < 80) {
406                 optr[0] = 1;
407                 optr[1] = subid - 40;
408         } else {
409                 optr[0] = 2;
410                 optr[1] = subid - 80;
411         }
412
413         *len = 2;
414         optr += 2;
415
416         while (ctx->pointer < eoc) {
417                 if (++(*len) > size) {
418                         ctx->error = ASN1_ERR_DEC_BADVALUE;
419                         kfree(*oid);
420                         *oid = NULL;
421                         return 0;
422                 }
423
424                 if (!asn1_subid_decode(ctx, optr++)) {
425                         kfree(*oid);
426                         *oid = NULL;
427                         return 0;
428                 }
429         }
430         return 1;
431 }
432
/*
 * compare_oid - test two OIDs for equality
 * @oid1: first OID, @oid1len sub-identifiers long (read only)
 * @oid1len: number of entries in @oid1
 * @oid2: second OID, @oid2len sub-identifiers long (read only)
 * @oid2len: number of entries in @oid2
 *
 * Returns 1 if both OIDs have the same length and identical sub-identifiers,
 * 0 otherwise.
 */
static int
compare_oid(const unsigned long *oid1, unsigned int oid1len,
	    const unsigned long *oid2, unsigned int oid2len)
{
	unsigned int i;

	if (oid1len != oid2len)
		return 0;

	for (i = 0; i < oid1len; i++) {
		if (oid1[i] != oid2[i])
			return 0;
	}
	return 1;
}
449
450         /* BB check for endian conversion issues here */
451
452 int
453 decode_negTokenInit(unsigned char *security_blob, int length,
454                     enum securityEnum *secType)
455 {
456         struct asn1_ctx ctx;
457         unsigned char *end;
458         unsigned char *sequence_end;
459         unsigned long *oid = NULL;
460         unsigned int cls, con, tag, oidlen, rc;
461         int use_ntlmssp = FALSE;
462
463         *secType = NTLM; /* BB eventually make Kerberos or NLTMSSP the default */
464
465         /* cifs_dump_mem(" Received SecBlob ", security_blob, length); */
466
467         asn1_open(&ctx, security_blob, length);
468
469         if (asn1_header_decode(&ctx, &end, &cls, &con, &tag) == 0) {
470                 cFYI(1, ("Error decoding negTokenInit header "));
471                 return 0;
472         } else if ((cls != ASN1_APL) || (con != ASN1_CON)
473                    || (tag != ASN1_EOC)) {
474                 cFYI(1, ("cls = %d con = %d tag = %d", cls, con, tag));
475                 return 0;
476         } else {
477                 /*      remember to free obj->oid */
478                 rc = asn1_header_decode(&ctx, &end, &cls, &con, &tag);
479                 if (rc) {
480                         if ((tag == ASN1_OJI) && (cls == ASN1_PRI)) {
481                                 rc = asn1_oid_decode(&ctx, end, &oid, &oidlen);
482                                 if (rc) {
483                                         rc = compare_oid(oid, oidlen,
484                                                          SPNEGO_OID,
485                                                          SPNEGO_OID_LEN);
486                                         kfree(oid);
487                                 }
488                         } else
489                                 rc = 0;
490                 }
491
492                 if (!rc) {
493                         cFYI(1, ("Error decoding negTokenInit header"));
494                         return 0;
495                 }
496
497                 if (asn1_header_decode(&ctx, &end, &cls, &con, &tag) == 0) {
498                         cFYI(1, ("Error decoding negTokenInit "));
499                         return 0;
500                 } else if ((cls != ASN1_CTX) || (con != ASN1_CON)
501                            || (tag != ASN1_EOC)) {
502                         cFYI(1,("cls = %d con = %d tag = %d end = %p (%d) exit 0",
503                               cls, con, tag, end, *end));
504                         return 0;
505                 }
506
507                 if (asn1_header_decode(&ctx, &end, &cls, &con, &tag) == 0) {
508                         cFYI(1, ("Error decoding negTokenInit "));
509                         return 0;
510                 } else if ((cls != ASN1_UNI) || (con != ASN1_CON)
511                            || (tag != ASN1_SEQ)) {
512                         cFYI(1,("cls = %d con = %d tag = %d end = %p (%d) exit 1",
513                               cls, con, tag, end, *end));
514                         return 0;
515                 }
516
517                 if (asn1_header_decode(&ctx, &end, &cls, &con, &tag) == 0) {
518                         cFYI(1, ("Error decoding 2nd part of negTokenInit "));
519                         return 0;
520                 } else if ((cls != ASN1_CTX) || (con != ASN1_CON)
521                            || (tag != ASN1_EOC)) {
522                         cFYI(1,
523                              ("cls = %d con = %d tag = %d end = %p (%d) exit 0",
524                               cls, con, tag, end, *end));
525                         return 0;
526                 }
527
528                 if (asn1_header_decode
529                     (&ctx, &sequence_end, &cls, &con, &tag) == 0) {
530                         cFYI(1, ("Error decoding 2nd part of negTokenInit "));
531                         return 0;
532                 } else if ((cls != ASN1_UNI) || (con != ASN1_CON)
533                            || (tag != ASN1_SEQ)) {
534                         cFYI(1,
535                              ("cls = %d con = %d tag = %d end = %p (%d) exit 1",
536                               cls, con, tag, end, *end));
537                         return 0;
538                 }
539
540                 while (!asn1_eoc_decode(&ctx, sequence_end)) {
541                         rc = asn1_header_decode(&ctx, &end, &cls, &con, &tag);
542                         if (!rc) {
543                                 cFYI(1,
544                                      ("Error 1 decoding negTokenInit header exit 2"));
545                                 return 0;
546                         }
547                         if ((tag == ASN1_OJI) && (con == ASN1_PRI)) {
548                                 rc = asn1_oid_decode(&ctx, end, &oid, &oidlen);
549                                 if(rc) {                
550                                         cFYI(1,
551                                           ("OID len = %d oid = 0x%lx 0x%lx 0x%lx 0x%lx",
552                                            oidlen, *oid, *(oid + 1), *(oid + 2),
553                                            *(oid + 3)));
554                                         rc = compare_oid(oid, oidlen, NTLMSSP_OID,
555                                                  NTLMSSP_OID_LEN);
556                                         if(oid)
557                                                 kfree(oid);
558                                         if (rc)
559                                                 use_ntlmssp = TRUE;
560                                 }
561                         } else {
562                                 cFYI(1,("This should be an oid what is going on? "));
563                         }
564                 }
565
566                 if (asn1_header_decode(&ctx, &end, &cls, &con, &tag) == 0) {
567                         cFYI(1,
568                              ("Error decoding last part of negTokenInit exit 3"));
569                         return 0;
570                 } else if ((cls != ASN1_CTX) || (con != ASN1_CON)) {    /* tag = 3 indicating mechListMIC */
571                         cFYI(1,
572                              ("Exit 4 cls = %d con = %d tag = %d end = %p (%d)",
573                               cls, con, tag, end, *end));
574                         return 0;
575                 }
576                 if (asn1_header_decode(&ctx, &end, &cls, &con, &tag) == 0) {
577                         cFYI(1,
578                              ("Error decoding last part of negTokenInit exit 5"));
579                         return 0;
580                 } else if ((cls != ASN1_UNI) || (con != ASN1_CON)
581                            || (tag != ASN1_SEQ)) {
582                         cFYI(1,
583                              ("Exit 6 cls = %d con = %d tag = %d end = %p (%d)",
584                               cls, con, tag, end, *end));
585                 }
586
587                 if (asn1_header_decode(&ctx, &end, &cls, &con, &tag) == 0) {
588                         cFYI(1,
589                              ("Error decoding last part of negTokenInit exit 7"));
590                         return 0;
591                 } else if ((cls != ASN1_CTX) || (con != ASN1_CON)) {
592                         cFYI(1,
593                              ("Exit 8 cls = %d con = %d tag = %d end = %p (%d)",
594                               cls, con, tag, end, *end));
595                         return 0;
596                 }
597                 if (asn1_header_decode(&ctx, &end, &cls, &con, &tag) == 0) {
598                         cFYI(1,
599                              ("Error decoding last part of negTokenInit exit 9"));
600                         return 0;
601                 } else if ((cls != ASN1_UNI) || (con != ASN1_PRI)
602                            || (tag != ASN1_GENSTR)) {
603                         cFYI(1,
604                              ("Exit 10 cls = %d con = %d tag = %d end = %p (%d)",
605                               cls, con, tag, end, *end));
606                         return 0;
607                 }
608                 cFYI(1, ("Need to call asn1_octets_decode() function for this %s", ctx.pointer));       /* is this UTF-8 or ASCII? */
609         }
610
611         /* if (use_kerberos) 
612            *secType = Kerberos 
613            else */
614         if (use_ntlmssp) {
615                 *secType = NTLMSSP;
616         }
617
618         return 1;
619 }