Merge branch 'for_linus' of git://git.kernel.org/pub/scm/linux/kernel/git/jack/linux...
[linux-2.6] / fs / cifs / asn1.c
1 /*
2  * The ASB.1/BER parsing code is derived from ip_nat_snmp_basic.c which was in
3  * turn derived from the gxsnmp package by Gregory McLean & Jochen Friedrich
4  *
5  * Copyright (c) 2000 RP Internet (www.rpi.net.au).
6  *
7  * This program is free software; you can redistribute it and/or modify
8  * it under the terms of the GNU General Public License as published by
9  * the Free Software Foundation; either version 2 of the License, or
10  * (at your option) any later version.
11  * This program is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14  * GNU General Public License for more details.
15  * You should have received a copy of the GNU General Public License
16  * along with this program; if not, write to the Free Software
17  * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307 USA
18  */
19
20 #include <linux/module.h>
21 #include <linux/types.h>
22 #include <linux/kernel.h>
23 #include <linux/mm.h>
24 #include <linux/slab.h>
25 #include "cifspdu.h"
26 #include "cifsglob.h"
27 #include "cifs_debug.h"
28 #include "cifsproto.h"
29
30 /*****************************************************************************
31  *
32  * Basic ASN.1 decoding routines (gxsnmp author Dirk Wisse)
33  *
34  *****************************************************************************/
35
36 /* Class */
37 #define ASN1_UNI        0       /* Universal */
38 #define ASN1_APL        1       /* Application */
39 #define ASN1_CTX        2       /* Context */
40 #define ASN1_PRV        3       /* Private */
41
42 /* Tag */
43 #define ASN1_EOC        0       /* End Of Contents or N/A */
44 #define ASN1_BOL        1       /* Boolean */
45 #define ASN1_INT        2       /* Integer */
46 #define ASN1_BTS        3       /* Bit String */
47 #define ASN1_OTS        4       /* Octet String */
48 #define ASN1_NUL        5       /* Null */
49 #define ASN1_OJI        6       /* Object Identifier  */
50 #define ASN1_OJD        7       /* Object Description */
51 #define ASN1_EXT        8       /* External */
52 #define ASN1_ENUM       10      /* Enumerated */
53 #define ASN1_SEQ        16      /* Sequence */
54 #define ASN1_SET        17      /* Set */
55 #define ASN1_NUMSTR     18      /* Numerical String */
56 #define ASN1_PRNSTR     19      /* Printable String */
57 #define ASN1_TEXSTR     20      /* Teletext String */
58 #define ASN1_VIDSTR     21      /* Video String */
59 #define ASN1_IA5STR     22      /* IA5 String */
60 #define ASN1_UNITIM     23      /* Universal Time */
61 #define ASN1_GENTIM     24      /* General Time */
62 #define ASN1_GRASTR     25      /* Graphical String */
63 #define ASN1_VISSTR     26      /* Visible String */
64 #define ASN1_GENSTR     27      /* General String */
65
66 /* Primitive / Constructed methods*/
67 #define ASN1_PRI        0       /* Primitive */
68 #define ASN1_CON        1       /* Constructed */
69
70 /*
71  * Error codes.
72  */
73 #define ASN1_ERR_NOERROR                0
74 #define ASN1_ERR_DEC_EMPTY              2
75 #define ASN1_ERR_DEC_EOC_MISMATCH       3
76 #define ASN1_ERR_DEC_LENGTH_MISMATCH    4
77 #define ASN1_ERR_DEC_BADVALUE           5
78
79 #define SPNEGO_OID_LEN 7
80 #define NTLMSSP_OID_LEN  10
81 #define KRB5_OID_LEN  7
82 #define KRB5U2U_OID_LEN  8
83 #define MSKRB5_OID_LEN  7
84 static unsigned long SPNEGO_OID[7] = { 1, 3, 6, 1, 5, 5, 2 };
85 static unsigned long NTLMSSP_OID[10] = { 1, 3, 6, 1, 4, 1, 311, 2, 2, 10 };
86 static unsigned long KRB5_OID[7] = { 1, 2, 840, 113554, 1, 2, 2 };
87 static unsigned long KRB5U2U_OID[8] = { 1, 2, 840, 113554, 1, 2, 2, 3 };
88 static unsigned long MSKRB5_OID[7] = { 1, 2, 840, 48018, 1, 2, 2 };
89
90 /*
91  * ASN.1 context.
92  */
93 struct asn1_ctx {
94         int error;              /* Error condition */
95         unsigned char *pointer; /* Octet just to be decoded */
96         unsigned char *begin;   /* First octet */
97         unsigned char *end;     /* Octet after last octet */
98 };
99
100 /*
101  * Octet string (not null terminated)
102  */
103 struct asn1_octstr {
104         unsigned char *data;
105         unsigned int len;
106 };
107
108 static void
109 asn1_open(struct asn1_ctx *ctx, unsigned char *buf, unsigned int len)
110 {
111         ctx->begin = buf;
112         ctx->end = buf + len;
113         ctx->pointer = buf;
114         ctx->error = ASN1_ERR_NOERROR;
115 }
116
117 static unsigned char
118 asn1_octet_decode(struct asn1_ctx *ctx, unsigned char *ch)
119 {
120         if (ctx->pointer >= ctx->end) {
121                 ctx->error = ASN1_ERR_DEC_EMPTY;
122                 return 0;
123         }
124         *ch = *(ctx->pointer)++;
125         return 1;
126 }
127
128 #if 0 /* will be needed later by spnego decoding/encoding of ntlmssp */
129 static unsigned char
130 asn1_enum_decode(struct asn1_ctx *ctx, __le32 *val)
131 {
132         unsigned char ch;
133
134         if (ctx->pointer >= ctx->end) {
135                 ctx->error = ASN1_ERR_DEC_EMPTY;
136                 return 0;
137         }
138
139         ch = *(ctx->pointer)++; /* ch has 0xa, ptr points to lenght octet */
140         if ((ch) == ASN1_ENUM)  /* if ch value is ENUM, 0xa */
141                 *val = *(++(ctx->pointer)); /* value has enum value */
142         else
143                 return 0;
144
145         ctx->pointer++;
146         return 1;
147 }
148 #endif
149
150 static unsigned char
151 asn1_tag_decode(struct asn1_ctx *ctx, unsigned int *tag)
152 {
153         unsigned char ch;
154
155         *tag = 0;
156
157         do {
158                 if (!asn1_octet_decode(ctx, &ch))
159                         return 0;
160                 *tag <<= 7;
161                 *tag |= ch & 0x7F;
162         } while ((ch & 0x80) == 0x80);
163         return 1;
164 }
165
166 static unsigned char
167 asn1_id_decode(struct asn1_ctx *ctx,
168                unsigned int *cls, unsigned int *con, unsigned int *tag)
169 {
170         unsigned char ch;
171
172         if (!asn1_octet_decode(ctx, &ch))
173                 return 0;
174
175         *cls = (ch & 0xC0) >> 6;
176         *con = (ch & 0x20) >> 5;
177         *tag = (ch & 0x1F);
178
179         if (*tag == 0x1F) {
180                 if (!asn1_tag_decode(ctx, tag))
181                         return 0;
182         }
183         return 1;
184 }
185
186 static unsigned char
187 asn1_length_decode(struct asn1_ctx *ctx, unsigned int *def, unsigned int *len)
188 {
189         unsigned char ch, cnt;
190
191         if (!asn1_octet_decode(ctx, &ch))
192                 return 0;
193
194         if (ch == 0x80)
195                 *def = 0;
196         else {
197                 *def = 1;
198
199                 if (ch < 0x80)
200                         *len = ch;
201                 else {
202                         cnt = (unsigned char) (ch & 0x7F);
203                         *len = 0;
204
205                         while (cnt > 0) {
206                                 if (!asn1_octet_decode(ctx, &ch))
207                                         return 0;
208                                 *len <<= 8;
209                                 *len |= ch;
210                                 cnt--;
211                         }
212                 }
213         }
214
215         /* don't trust len bigger than ctx buffer */
216         if (*len > ctx->end - ctx->pointer)
217                 return 0;
218
219         return 1;
220 }
221
222 static unsigned char
223 asn1_header_decode(struct asn1_ctx *ctx,
224                    unsigned char **eoc,
225                    unsigned int *cls, unsigned int *con, unsigned int *tag)
226 {
227         unsigned int def = 0;
228         unsigned int len = 0;
229
230         if (!asn1_id_decode(ctx, cls, con, tag))
231                 return 0;
232
233         if (!asn1_length_decode(ctx, &def, &len))
234                 return 0;
235
236         /* primitive shall be definite, indefinite shall be constructed */
237         if (*con == ASN1_PRI && !def)
238                 return 0;
239
240         if (def)
241                 *eoc = ctx->pointer + len;
242         else
243                 *eoc = NULL;
244         return 1;
245 }
246
247 static unsigned char
248 asn1_eoc_decode(struct asn1_ctx *ctx, unsigned char *eoc)
249 {
250         unsigned char ch;
251
252         if (eoc == NULL) {
253                 if (!asn1_octet_decode(ctx, &ch))
254                         return 0;
255
256                 if (ch != 0x00) {
257                         ctx->error = ASN1_ERR_DEC_EOC_MISMATCH;
258                         return 0;
259                 }
260
261                 if (!asn1_octet_decode(ctx, &ch))
262                         return 0;
263
264                 if (ch != 0x00) {
265                         ctx->error = ASN1_ERR_DEC_EOC_MISMATCH;
266                         return 0;
267                 }
268                 return 1;
269         } else {
270                 if (ctx->pointer != eoc) {
271                         ctx->error = ASN1_ERR_DEC_LENGTH_MISMATCH;
272                         return 0;
273                 }
274                 return 1;
275         }
276 }
277
278 /* static unsigned char asn1_null_decode(struct asn1_ctx *ctx,
279                                       unsigned char *eoc)
280 {
281         ctx->pointer = eoc;
282         return 1;
283 }
284
285 static unsigned char asn1_long_decode(struct asn1_ctx *ctx,
286                                       unsigned char *eoc, long *integer)
287 {
288         unsigned char ch;
289         unsigned int len;
290
291         if (!asn1_octet_decode(ctx, &ch))
292                 return 0;
293
294         *integer = (signed char) ch;
295         len = 1;
296
297         while (ctx->pointer < eoc) {
298                 if (++len > sizeof(long)) {
299                         ctx->error = ASN1_ERR_DEC_BADVALUE;
300                         return 0;
301                 }
302
303                 if (!asn1_octet_decode(ctx, &ch))
304                         return 0;
305
306                 *integer <<= 8;
307                 *integer |= ch;
308         }
309         return 1;
310 }
311
312 static unsigned char asn1_uint_decode(struct asn1_ctx *ctx,
313                                       unsigned char *eoc,
314                                       unsigned int *integer)
315 {
316         unsigned char ch;
317         unsigned int len;
318
319         if (!asn1_octet_decode(ctx, &ch))
320                 return 0;
321
322         *integer = ch;
323         if (ch == 0)
324                 len = 0;
325         else
326                 len = 1;
327
328         while (ctx->pointer < eoc) {
329                 if (++len > sizeof(unsigned int)) {
330                         ctx->error = ASN1_ERR_DEC_BADVALUE;
331                         return 0;
332                 }
333
334                 if (!asn1_octet_decode(ctx, &ch))
335                         return 0;
336
337                 *integer <<= 8;
338                 *integer |= ch;
339         }
340         return 1;
341 }
342
343 static unsigned char asn1_ulong_decode(struct asn1_ctx *ctx,
344                                        unsigned char *eoc,
345                                        unsigned long *integer)
346 {
347         unsigned char ch;
348         unsigned int len;
349
350         if (!asn1_octet_decode(ctx, &ch))
351                 return 0;
352
353         *integer = ch;
354         if (ch == 0)
355                 len = 0;
356         else
357                 len = 1;
358
359         while (ctx->pointer < eoc) {
360                 if (++len > sizeof(unsigned long)) {
361                         ctx->error = ASN1_ERR_DEC_BADVALUE;
362                         return 0;
363                 }
364
365                 if (!asn1_octet_decode(ctx, &ch))
366                         return 0;
367
368                 *integer <<= 8;
369                 *integer |= ch;
370         }
371         return 1;
372 }
373
374 static unsigned char
375 asn1_octets_decode(struct asn1_ctx *ctx,
376                    unsigned char *eoc,
377                    unsigned char **octets, unsigned int *len)
378 {
379         unsigned char *ptr;
380
381         *len = 0;
382
383         *octets = kmalloc(eoc - ctx->pointer, GFP_ATOMIC);
384         if (*octets == NULL) {
385                 return 0;
386         }
387
388         ptr = *octets;
389         while (ctx->pointer < eoc) {
390                 if (!asn1_octet_decode(ctx, (unsigned char *) ptr++)) {
391                         kfree(*octets);
392                         *octets = NULL;
393                         return 0;
394                 }
395                 (*len)++;
396         }
397         return 1;
398 } */
399
400 static unsigned char
401 asn1_subid_decode(struct asn1_ctx *ctx, unsigned long *subid)
402 {
403         unsigned char ch;
404
405         *subid = 0;
406
407         do {
408                 if (!asn1_octet_decode(ctx, &ch))
409                         return 0;
410
411                 *subid <<= 7;
412                 *subid |= ch & 0x7F;
413         } while ((ch & 0x80) == 0x80);
414         return 1;
415 }
416
417 static int
418 asn1_oid_decode(struct asn1_ctx *ctx,
419                 unsigned char *eoc, unsigned long **oid, unsigned int *len)
420 {
421         unsigned long subid;
422         unsigned int size;
423         unsigned long *optr;
424
425         size = eoc - ctx->pointer + 1;
426
427         /* first subid actually encodes first two subids */
428         if (size < 2 || size > UINT_MAX/sizeof(unsigned long))
429                 return 0;
430
431         *oid = kmalloc(size * sizeof(unsigned long), GFP_ATOMIC);
432         if (*oid == NULL)
433                 return 0;
434
435         optr = *oid;
436
437         if (!asn1_subid_decode(ctx, &subid)) {
438                 kfree(*oid);
439                 *oid = NULL;
440                 return 0;
441         }
442
443         if (subid < 40) {
444                 optr[0] = 0;
445                 optr[1] = subid;
446         } else if (subid < 80) {
447                 optr[0] = 1;
448                 optr[1] = subid - 40;
449         } else {
450                 optr[0] = 2;
451                 optr[1] = subid - 80;
452         }
453
454         *len = 2;
455         optr += 2;
456
457         while (ctx->pointer < eoc) {
458                 if (++(*len) > size) {
459                         ctx->error = ASN1_ERR_DEC_BADVALUE;
460                         kfree(*oid);
461                         *oid = NULL;
462                         return 0;
463                 }
464
465                 if (!asn1_subid_decode(ctx, optr++)) {
466                         kfree(*oid);
467                         *oid = NULL;
468                         return 0;
469                 }
470         }
471         return 1;
472 }
473
474 static int
475 compare_oid(unsigned long *oid1, unsigned int oid1len,
476             unsigned long *oid2, unsigned int oid2len)
477 {
478         unsigned int i;
479
480         if (oid1len != oid2len)
481                 return 0;
482         else {
483                 for (i = 0; i < oid1len; i++) {
484                         if (oid1[i] != oid2[i])
485                                 return 0;
486                 }
487                 return 1;
488         }
489 }
490
491         /* BB check for endian conversion issues here */
492
493 int
494 decode_negTokenInit(unsigned char *security_blob, int length,
495                     enum securityEnum *secType)
496 {
497         struct asn1_ctx ctx;
498         unsigned char *end;
499         unsigned char *sequence_end;
500         unsigned long *oid = NULL;
501         unsigned int cls, con, tag, oidlen, rc;
502         bool use_ntlmssp = false;
503         bool use_kerberos = false;
504         bool use_kerberosu2u = false;
505         bool use_mskerberos = false;
506
507         /* cifs_dump_mem(" Received SecBlob ", security_blob, length); */
508
509         asn1_open(&ctx, security_blob, length);
510
511         /* GSSAPI header */
512         if (asn1_header_decode(&ctx, &end, &cls, &con, &tag) == 0) {
513                 cFYI(1, ("Error decoding negTokenInit header"));
514                 return 0;
515         } else if ((cls != ASN1_APL) || (con != ASN1_CON)
516                    || (tag != ASN1_EOC)) {
517                 cFYI(1, ("cls = %d con = %d tag = %d", cls, con, tag));
518                 return 0;
519         }
520
521         /* Check for SPNEGO OID -- remember to free obj->oid */
522         rc = asn1_header_decode(&ctx, &end, &cls, &con, &tag);
523         if (rc) {
524                 if ((tag == ASN1_OJI) && (con == ASN1_PRI) &&
525                     (cls == ASN1_UNI)) {
526                         rc = asn1_oid_decode(&ctx, end, &oid, &oidlen);
527                         if (rc) {
528                                 rc = compare_oid(oid, oidlen, SPNEGO_OID,
529                                                  SPNEGO_OID_LEN);
530                                 kfree(oid);
531                         }
532                 } else
533                         rc = 0;
534         }
535
536         /* SPNEGO OID not present or garbled -- bail out */
537         if (!rc) {
538                 cFYI(1, ("Error decoding negTokenInit header"));
539                 return 0;
540         }
541
542         /* SPNEGO */
543         if (asn1_header_decode(&ctx, &end, &cls, &con, &tag) == 0) {
544                 cFYI(1, ("Error decoding negTokenInit"));
545                 return 0;
546         } else if ((cls != ASN1_CTX) || (con != ASN1_CON)
547                    || (tag != ASN1_EOC)) {
548                 cFYI(1,
549                      ("cls = %d con = %d tag = %d end = %p (%d) exit 0",
550                       cls, con, tag, end, *end));
551                 return 0;
552         }
553
554         /* negTokenInit */
555         if (asn1_header_decode(&ctx, &end, &cls, &con, &tag) == 0) {
556                 cFYI(1, ("Error decoding negTokenInit"));
557                 return 0;
558         } else if ((cls != ASN1_UNI) || (con != ASN1_CON)
559                    || (tag != ASN1_SEQ)) {
560                 cFYI(1,
561                      ("cls = %d con = %d tag = %d end = %p (%d) exit 1",
562                       cls, con, tag, end, *end));
563                 return 0;
564         }
565
566         /* sequence */
567         if (asn1_header_decode(&ctx, &end, &cls, &con, &tag) == 0) {
568                 cFYI(1, ("Error decoding 2nd part of negTokenInit"));
569                 return 0;
570         } else if ((cls != ASN1_CTX) || (con != ASN1_CON)
571                    || (tag != ASN1_EOC)) {
572                 cFYI(1,
573                      ("cls = %d con = %d tag = %d end = %p (%d) exit 0",
574                       cls, con, tag, end, *end));
575                 return 0;
576         }
577
578         /* sequence of */
579         if (asn1_header_decode
580             (&ctx, &sequence_end, &cls, &con, &tag) == 0) {
581                 cFYI(1, ("Error decoding 2nd part of negTokenInit"));
582                 return 0;
583         } else if ((cls != ASN1_UNI) || (con != ASN1_CON)
584                    || (tag != ASN1_SEQ)) {
585                 cFYI(1,
586                      ("cls = %d con = %d tag = %d end = %p (%d) exit 1",
587                       cls, con, tag, end, *end));
588                 return 0;
589         }
590
591         /* list of security mechanisms */
592         while (!asn1_eoc_decode(&ctx, sequence_end)) {
593                 rc = asn1_header_decode(&ctx, &end, &cls, &con, &tag);
594                 if (!rc) {
595                         cFYI(1,
596                              ("Error decoding negTokenInit hdr exit2"));
597                         return 0;
598                 }
599                 if ((tag == ASN1_OJI) && (con == ASN1_PRI)) {
600                         if (asn1_oid_decode(&ctx, end, &oid, &oidlen)) {
601
602                                 cFYI(1, ("OID len = %d oid = 0x%lx 0x%lx "
603                                          "0x%lx 0x%lx", oidlen, *oid,
604                                          *(oid + 1), *(oid + 2), *(oid + 3)));
605
606                                 if (compare_oid(oid, oidlen, MSKRB5_OID,
607                                                 MSKRB5_OID_LEN) &&
608                                                 !use_mskerberos)
609                                         use_mskerberos = true;
610                                 else if (compare_oid(oid, oidlen, KRB5U2U_OID,
611                                                      KRB5U2U_OID_LEN) &&
612                                                      !use_kerberosu2u)
613                                         use_kerberosu2u = true;
614                                 else if (compare_oid(oid, oidlen, KRB5_OID,
615                                                      KRB5_OID_LEN) &&
616                                                      !use_kerberos)
617                                         use_kerberos = true;
618                                 else if (compare_oid(oid, oidlen, NTLMSSP_OID,
619                                                      NTLMSSP_OID_LEN))
620                                         use_ntlmssp = true;
621
622                                 kfree(oid);
623                         }
624                 } else {
625                         cFYI(1, ("Should be an oid what is going on?"));
626                 }
627         }
628
629         /* mechlistMIC */
630         if (asn1_header_decode(&ctx, &end, &cls, &con, &tag) == 0) {
631                 /* Check if we have reached the end of the blob, but with
632                    no mechListMic (e.g. NTLMSSP instead of KRB5) */
633                 if (ctx.error == ASN1_ERR_DEC_EMPTY)
634                         goto decode_negtoken_exit;
635                 cFYI(1, ("Error decoding last part negTokenInit exit3"));
636                 return 0;
637         } else if ((cls != ASN1_CTX) || (con != ASN1_CON)) {
638                 /* tag = 3 indicating mechListMIC */
639                 cFYI(1, ("Exit 4 cls = %d con = %d tag = %d end = %p (%d)",
640                          cls, con, tag, end, *end));
641                 return 0;
642         }
643
644         /* sequence */
645         if (asn1_header_decode(&ctx, &end, &cls, &con, &tag) == 0) {
646                 cFYI(1, ("Error decoding last part negTokenInit exit5"));
647                 return 0;
648         } else if ((cls != ASN1_UNI) || (con != ASN1_CON)
649                    || (tag != ASN1_SEQ)) {
650                 cFYI(1, ("cls = %d con = %d tag = %d end = %p (%d)",
651                         cls, con, tag, end, *end));
652         }
653
654         /* sequence of */
655         if (asn1_header_decode(&ctx, &end, &cls, &con, &tag) == 0) {
656                 cFYI(1, ("Error decoding last part negTokenInit exit 7"));
657                 return 0;
658         } else if ((cls != ASN1_CTX) || (con != ASN1_CON)) {
659                 cFYI(1, ("Exit 8 cls = %d con = %d tag = %d end = %p (%d)",
660                          cls, con, tag, end, *end));
661                 return 0;
662         }
663
664         /* general string */
665         if (asn1_header_decode(&ctx, &end, &cls, &con, &tag) == 0) {
666                 cFYI(1, ("Error decoding last part negTokenInit exit9"));
667                 return 0;
668         } else if ((cls != ASN1_UNI) || (con != ASN1_PRI)
669                    || (tag != ASN1_GENSTR)) {
670                 cFYI(1, ("Exit10 cls = %d con = %d tag = %d end = %p (%d)",
671                          cls, con, tag, end, *end));
672                 return 0;
673         }
674         cFYI(1, ("Need to call asn1_octets_decode() function for %s",
675                  ctx.pointer)); /* is this UTF-8 or ASCII? */
676 decode_negtoken_exit:
677         if (use_kerberos)
678                 *secType = Kerberos;
679         else if (use_mskerberos)
680                 *secType = MSKerberos;
681         else if (use_ntlmssp)
682                 *secType = RawNTLMSSP;
683
684         return 1;
685 }