Merge git://git.kernel.org/pub/scm/linux/kernel/git/rusty/linux-2.6-for-linus
[linux-2.6] / fs / cifs / asn1.c
1 /*
2  * The ASB.1/BER parsing code is derived from ip_nat_snmp_basic.c which was in
3  * turn derived from the gxsnmp package by Gregory McLean & Jochen Friedrich
4  *
5  * Copyright (c) 2000 RP Internet (www.rpi.net.au).
6  *
7  * This program is free software; you can redistribute it and/or modify
8  * it under the terms of the GNU General Public License as published by
9  * the Free Software Foundation; either version 2 of the License, or
10  * (at your option) any later version.
11  * This program is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14  * GNU General Public License for more details.
15  * You should have received a copy of the GNU General Public License
16  * along with this program; if not, write to the Free Software
17  * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307 USA
18  */
19
20 #include <linux/module.h>
21 #include <linux/types.h>
22 #include <linux/kernel.h>
23 #include <linux/mm.h>
24 #include <linux/slab.h>
25 #include "cifspdu.h"
26 #include "cifsglob.h"
27 #include "cifs_debug.h"
28 #include "cifsproto.h"
29
30 /*****************************************************************************
31  *
32  * Basic ASN.1 decoding routines (gxsnmp author Dirk Wisse)
33  *
34  *****************************************************************************/
35
36 /* Class */
37 #define ASN1_UNI        0       /* Universal */
38 #define ASN1_APL        1       /* Application */
39 #define ASN1_CTX        2       /* Context */
40 #define ASN1_PRV        3       /* Private */
41
42 /* Tag */
43 #define ASN1_EOC        0       /* End Of Contents or N/A */
44 #define ASN1_BOL        1       /* Boolean */
45 #define ASN1_INT        2       /* Integer */
46 #define ASN1_BTS        3       /* Bit String */
47 #define ASN1_OTS        4       /* Octet String */
48 #define ASN1_NUL        5       /* Null */
49 #define ASN1_OJI        6       /* Object Identifier  */
50 #define ASN1_OJD        7       /* Object Description */
51 #define ASN1_EXT        8       /* External */
52 #define ASN1_SEQ        16      /* Sequence */
53 #define ASN1_SET        17      /* Set */
54 #define ASN1_NUMSTR     18      /* Numerical String */
55 #define ASN1_PRNSTR     19      /* Printable String */
56 #define ASN1_TEXSTR     20      /* Teletext String */
57 #define ASN1_VIDSTR     21      /* Video String */
58 #define ASN1_IA5STR     22      /* IA5 String */
59 #define ASN1_UNITIM     23      /* Universal Time */
60 #define ASN1_GENTIM     24      /* General Time */
61 #define ASN1_GRASTR     25      /* Graphical String */
62 #define ASN1_VISSTR     26      /* Visible String */
63 #define ASN1_GENSTR     27      /* General String */
64
65 /* Primitive / Constructed methods*/
66 #define ASN1_PRI        0       /* Primitive */
67 #define ASN1_CON        1       /* Constructed */
68
69 /*
70  * Error codes.
71  */
72 #define ASN1_ERR_NOERROR                0
73 #define ASN1_ERR_DEC_EMPTY              2
74 #define ASN1_ERR_DEC_EOC_MISMATCH       3
75 #define ASN1_ERR_DEC_LENGTH_MISMATCH    4
76 #define ASN1_ERR_DEC_BADVALUE           5
77
78 #define SPNEGO_OID_LEN 7
79 #define NTLMSSP_OID_LEN  10
80 #define KRB5_OID_LEN  7
81 #define MSKRB5_OID_LEN  7
82 static unsigned long SPNEGO_OID[7] = { 1, 3, 6, 1, 5, 5, 2 };
83 static unsigned long NTLMSSP_OID[10] = { 1, 3, 6, 1, 4, 1, 311, 2, 2, 10 };
84 static unsigned long KRB5_OID[7] = { 1, 2, 840, 113554, 1, 2, 2 };
85 static unsigned long MSKRB5_OID[7] = { 1, 2, 840, 48018, 1, 2, 2 };
86
87 /*
88  * ASN.1 context.
89  */
90 struct asn1_ctx {
91         int error;              /* Error condition */
92         unsigned char *pointer; /* Octet just to be decoded */
93         unsigned char *begin;   /* First octet */
94         unsigned char *end;     /* Octet after last octet */
95 };
96
97 /*
98  * Octet string (not null terminated)
99  */
100 struct asn1_octstr {
101         unsigned char *data;
102         unsigned int len;
103 };
104
105 static void
106 asn1_open(struct asn1_ctx *ctx, unsigned char *buf, unsigned int len)
107 {
108         ctx->begin = buf;
109         ctx->end = buf + len;
110         ctx->pointer = buf;
111         ctx->error = ASN1_ERR_NOERROR;
112 }
113
114 static unsigned char
115 asn1_octet_decode(struct asn1_ctx *ctx, unsigned char *ch)
116 {
117         if (ctx->pointer >= ctx->end) {
118                 ctx->error = ASN1_ERR_DEC_EMPTY;
119                 return 0;
120         }
121         *ch = *(ctx->pointer)++;
122         return 1;
123 }
124
125 static unsigned char
126 asn1_tag_decode(struct asn1_ctx *ctx, unsigned int *tag)
127 {
128         unsigned char ch;
129
130         *tag = 0;
131
132         do {
133                 if (!asn1_octet_decode(ctx, &ch))
134                         return 0;
135                 *tag <<= 7;
136                 *tag |= ch & 0x7F;
137         } while ((ch & 0x80) == 0x80);
138         return 1;
139 }
140
141 static unsigned char
142 asn1_id_decode(struct asn1_ctx *ctx,
143                unsigned int *cls, unsigned int *con, unsigned int *tag)
144 {
145         unsigned char ch;
146
147         if (!asn1_octet_decode(ctx, &ch))
148                 return 0;
149
150         *cls = (ch & 0xC0) >> 6;
151         *con = (ch & 0x20) >> 5;
152         *tag = (ch & 0x1F);
153
154         if (*tag == 0x1F) {
155                 if (!asn1_tag_decode(ctx, tag))
156                         return 0;
157         }
158         return 1;
159 }
160
161 static unsigned char
162 asn1_length_decode(struct asn1_ctx *ctx, unsigned int *def, unsigned int *len)
163 {
164         unsigned char ch, cnt;
165
166         if (!asn1_octet_decode(ctx, &ch))
167                 return 0;
168
169         if (ch == 0x80)
170                 *def = 0;
171         else {
172                 *def = 1;
173
174                 if (ch < 0x80)
175                         *len = ch;
176                 else {
177                         cnt = (unsigned char) (ch & 0x7F);
178                         *len = 0;
179
180                         while (cnt > 0) {
181                                 if (!asn1_octet_decode(ctx, &ch))
182                                         return 0;
183                                 *len <<= 8;
184                                 *len |= ch;
185                                 cnt--;
186                         }
187                 }
188         }
189
190         /* don't trust len bigger than ctx buffer */
191         if (*len > ctx->end - ctx->pointer)
192                 return 0;
193
194         return 1;
195 }
196
197 static unsigned char
198 asn1_header_decode(struct asn1_ctx *ctx,
199                    unsigned char **eoc,
200                    unsigned int *cls, unsigned int *con, unsigned int *tag)
201 {
202         unsigned int def = 0;
203         unsigned int len = 0;
204
205         if (!asn1_id_decode(ctx, cls, con, tag))
206                 return 0;
207
208         if (!asn1_length_decode(ctx, &def, &len))
209                 return 0;
210
211         /* primitive shall be definite, indefinite shall be constructed */
212         if (*con == ASN1_PRI && !def)
213                 return 0;
214
215         if (def)
216                 *eoc = ctx->pointer + len;
217         else
218                 *eoc = NULL;
219         return 1;
220 }
221
222 static unsigned char
223 asn1_eoc_decode(struct asn1_ctx *ctx, unsigned char *eoc)
224 {
225         unsigned char ch;
226
227         if (eoc == NULL) {
228                 if (!asn1_octet_decode(ctx, &ch))
229                         return 0;
230
231                 if (ch != 0x00) {
232                         ctx->error = ASN1_ERR_DEC_EOC_MISMATCH;
233                         return 0;
234                 }
235
236                 if (!asn1_octet_decode(ctx, &ch))
237                         return 0;
238
239                 if (ch != 0x00) {
240                         ctx->error = ASN1_ERR_DEC_EOC_MISMATCH;
241                         return 0;
242                 }
243                 return 1;
244         } else {
245                 if (ctx->pointer != eoc) {
246                         ctx->error = ASN1_ERR_DEC_LENGTH_MISMATCH;
247                         return 0;
248                 }
249                 return 1;
250         }
251 }
252
253 /* static unsigned char asn1_null_decode(struct asn1_ctx *ctx,
254                                       unsigned char *eoc)
255 {
256         ctx->pointer = eoc;
257         return 1;
258 }
259
260 static unsigned char asn1_long_decode(struct asn1_ctx *ctx,
261                                       unsigned char *eoc, long *integer)
262 {
263         unsigned char ch;
264         unsigned int len;
265
266         if (!asn1_octet_decode(ctx, &ch))
267                 return 0;
268
269         *integer = (signed char) ch;
270         len = 1;
271
272         while (ctx->pointer < eoc) {
273                 if (++len > sizeof(long)) {
274                         ctx->error = ASN1_ERR_DEC_BADVALUE;
275                         return 0;
276                 }
277
278                 if (!asn1_octet_decode(ctx, &ch))
279                         return 0;
280
281                 *integer <<= 8;
282                 *integer |= ch;
283         }
284         return 1;
285 }
286
287 static unsigned char asn1_uint_decode(struct asn1_ctx *ctx,
288                                       unsigned char *eoc,
289                                       unsigned int *integer)
290 {
291         unsigned char ch;
292         unsigned int len;
293
294         if (!asn1_octet_decode(ctx, &ch))
295                 return 0;
296
297         *integer = ch;
298         if (ch == 0)
299                 len = 0;
300         else
301                 len = 1;
302
303         while (ctx->pointer < eoc) {
304                 if (++len > sizeof(unsigned int)) {
305                         ctx->error = ASN1_ERR_DEC_BADVALUE;
306                         return 0;
307                 }
308
309                 if (!asn1_octet_decode(ctx, &ch))
310                         return 0;
311
312                 *integer <<= 8;
313                 *integer |= ch;
314         }
315         return 1;
316 }
317
318 static unsigned char asn1_ulong_decode(struct asn1_ctx *ctx,
319                                        unsigned char *eoc,
320                                        unsigned long *integer)
321 {
322         unsigned char ch;
323         unsigned int len;
324
325         if (!asn1_octet_decode(ctx, &ch))
326                 return 0;
327
328         *integer = ch;
329         if (ch == 0)
330                 len = 0;
331         else
332                 len = 1;
333
334         while (ctx->pointer < eoc) {
335                 if (++len > sizeof(unsigned long)) {
336                         ctx->error = ASN1_ERR_DEC_BADVALUE;
337                         return 0;
338                 }
339
340                 if (!asn1_octet_decode(ctx, &ch))
341                         return 0;
342
343                 *integer <<= 8;
344                 *integer |= ch;
345         }
346         return 1;
347 }
348
349 static unsigned char
350 asn1_octets_decode(struct asn1_ctx *ctx,
351                    unsigned char *eoc,
352                    unsigned char **octets, unsigned int *len)
353 {
354         unsigned char *ptr;
355
356         *len = 0;
357
358         *octets = kmalloc(eoc - ctx->pointer, GFP_ATOMIC);
359         if (*octets == NULL) {
360                 return 0;
361         }
362
363         ptr = *octets;
364         while (ctx->pointer < eoc) {
365                 if (!asn1_octet_decode(ctx, (unsigned char *) ptr++)) {
366                         kfree(*octets);
367                         *octets = NULL;
368                         return 0;
369                 }
370                 (*len)++;
371         }
372         return 1;
373 } */
374
375 static unsigned char
376 asn1_subid_decode(struct asn1_ctx *ctx, unsigned long *subid)
377 {
378         unsigned char ch;
379
380         *subid = 0;
381
382         do {
383                 if (!asn1_octet_decode(ctx, &ch))
384                         return 0;
385
386                 *subid <<= 7;
387                 *subid |= ch & 0x7F;
388         } while ((ch & 0x80) == 0x80);
389         return 1;
390 }
391
392 static int
393 asn1_oid_decode(struct asn1_ctx *ctx,
394                 unsigned char *eoc, unsigned long **oid, unsigned int *len)
395 {
396         unsigned long subid;
397         unsigned int size;
398         unsigned long *optr;
399
400         size = eoc - ctx->pointer + 1;
401
402         /* first subid actually encodes first two subids */
403         if (size < 2 || size > UINT_MAX/sizeof(unsigned long))
404                 return 0;
405
406         *oid = kmalloc(size * sizeof(unsigned long), GFP_ATOMIC);
407         if (*oid == NULL)
408                 return 0;
409
410         optr = *oid;
411
412         if (!asn1_subid_decode(ctx, &subid)) {
413                 kfree(*oid);
414                 *oid = NULL;
415                 return 0;
416         }
417
418         if (subid < 40) {
419                 optr[0] = 0;
420                 optr[1] = subid;
421         } else if (subid < 80) {
422                 optr[0] = 1;
423                 optr[1] = subid - 40;
424         } else {
425                 optr[0] = 2;
426                 optr[1] = subid - 80;
427         }
428
429         *len = 2;
430         optr += 2;
431
432         while (ctx->pointer < eoc) {
433                 if (++(*len) > size) {
434                         ctx->error = ASN1_ERR_DEC_BADVALUE;
435                         kfree(*oid);
436                         *oid = NULL;
437                         return 0;
438                 }
439
440                 if (!asn1_subid_decode(ctx, optr++)) {
441                         kfree(*oid);
442                         *oid = NULL;
443                         return 0;
444                 }
445         }
446         return 1;
447 }
448
449 static int
450 compare_oid(unsigned long *oid1, unsigned int oid1len,
451             unsigned long *oid2, unsigned int oid2len)
452 {
453         unsigned int i;
454
455         if (oid1len != oid2len)
456                 return 0;
457         else {
458                 for (i = 0; i < oid1len; i++) {
459                         if (oid1[i] != oid2[i])
460                                 return 0;
461                 }
462                 return 1;
463         }
464 }
465
466         /* BB check for endian conversion issues here */
467
468 int
469 decode_negTokenInit(unsigned char *security_blob, int length,
470                     enum securityEnum *secType)
471 {
472         struct asn1_ctx ctx;
473         unsigned char *end;
474         unsigned char *sequence_end;
475         unsigned long *oid = NULL;
476         unsigned int cls, con, tag, oidlen, rc;
477         bool use_ntlmssp = false;
478         bool use_kerberos = false;
479         bool use_mskerberos = false;
480
481         *secType = NTLM; /* BB eventually make Kerberos or NLTMSSP the default*/
482
483         /* cifs_dump_mem(" Received SecBlob ", security_blob, length); */
484
485         asn1_open(&ctx, security_blob, length);
486
487         /* GSSAPI header */
488         if (asn1_header_decode(&ctx, &end, &cls, &con, &tag) == 0) {
489                 cFYI(1, ("Error decoding negTokenInit header"));
490                 return 0;
491         } else if ((cls != ASN1_APL) || (con != ASN1_CON)
492                    || (tag != ASN1_EOC)) {
493                 cFYI(1, ("cls = %d con = %d tag = %d", cls, con, tag));
494                 return 0;
495         }
496
497         /* Check for SPNEGO OID -- remember to free obj->oid */
498         rc = asn1_header_decode(&ctx, &end, &cls, &con, &tag);
499         if (rc) {
500                 if ((tag == ASN1_OJI) && (con == ASN1_PRI) &&
501                     (cls == ASN1_UNI)) {
502                         rc = asn1_oid_decode(&ctx, end, &oid, &oidlen);
503                         if (rc) {
504                                 rc = compare_oid(oid, oidlen, SPNEGO_OID,
505                                                  SPNEGO_OID_LEN);
506                                 kfree(oid);
507                         }
508                 } else
509                         rc = 0;
510         }
511
512         /* SPNEGO OID not present or garbled -- bail out */
513         if (!rc) {
514                 cFYI(1, ("Error decoding negTokenInit header"));
515                 return 0;
516         }
517
518         if (asn1_header_decode(&ctx, &end, &cls, &con, &tag) == 0) {
519                 cFYI(1, ("Error decoding negTokenInit"));
520                 return 0;
521         } else if ((cls != ASN1_CTX) || (con != ASN1_CON)
522                    || (tag != ASN1_EOC)) {
523                 cFYI(1,
524                      ("cls = %d con = %d tag = %d end = %p (%d) exit 0",
525                       cls, con, tag, end, *end));
526                 return 0;
527         }
528
529         if (asn1_header_decode(&ctx, &end, &cls, &con, &tag) == 0) {
530                 cFYI(1, ("Error decoding negTokenInit"));
531                 return 0;
532         } else if ((cls != ASN1_UNI) || (con != ASN1_CON)
533                    || (tag != ASN1_SEQ)) {
534                 cFYI(1,
535                      ("cls = %d con = %d tag = %d end = %p (%d) exit 1",
536                       cls, con, tag, end, *end));
537                 return 0;
538         }
539
540         if (asn1_header_decode(&ctx, &end, &cls, &con, &tag) == 0) {
541                 cFYI(1, ("Error decoding 2nd part of negTokenInit"));
542                 return 0;
543         } else if ((cls != ASN1_CTX) || (con != ASN1_CON)
544                    || (tag != ASN1_EOC)) {
545                 cFYI(1,
546                      ("cls = %d con = %d tag = %d end = %p (%d) exit 0",
547                       cls, con, tag, end, *end));
548                 return 0;
549         }
550
551         if (asn1_header_decode
552             (&ctx, &sequence_end, &cls, &con, &tag) == 0) {
553                 cFYI(1, ("Error decoding 2nd part of negTokenInit"));
554                 return 0;
555         } else if ((cls != ASN1_UNI) || (con != ASN1_CON)
556                    || (tag != ASN1_SEQ)) {
557                 cFYI(1,
558                      ("cls = %d con = %d tag = %d end = %p (%d) exit 1",
559                       cls, con, tag, end, *end));
560                 return 0;
561         }
562
563         while (!asn1_eoc_decode(&ctx, sequence_end)) {
564                 rc = asn1_header_decode(&ctx, &end, &cls, &con, &tag);
565                 if (!rc) {
566                         cFYI(1,
567                              ("Error decoding negTokenInit hdr exit2"));
568                         return 0;
569                 }
570                 if ((tag == ASN1_OJI) && (con == ASN1_PRI)) {
571                         if (asn1_oid_decode(&ctx, end, &oid, &oidlen)) {
572
573                                 cFYI(1, ("OID len = %d oid = 0x%lx 0x%lx "
574                                          "0x%lx 0x%lx", oidlen, *oid,
575                                          *(oid + 1), *(oid + 2), *(oid + 3)));
576
577                                 if (compare_oid(oid, oidlen, MSKRB5_OID,
578                                                 MSKRB5_OID_LEN) &&
579                                                 !use_kerberos)
580                                         use_mskerberos = true;
581                                 else if (compare_oid(oid, oidlen, KRB5_OID,
582                                                      KRB5_OID_LEN) &&
583                                                      !use_mskerberos)
584                                         use_kerberos = true;
585                                 else if (compare_oid(oid, oidlen, NTLMSSP_OID,
586                                                      NTLMSSP_OID_LEN))
587                                         use_ntlmssp = true;
588
589                                 kfree(oid);
590                         }
591                 } else {
592                         cFYI(1, ("Should be an oid what is going on?"));
593                 }
594         }
595
596         if (asn1_header_decode(&ctx, &end, &cls, &con, &tag) == 0) {
597                 cFYI(1, ("Error decoding last part negTokenInit exit3"));
598                 return 0;
599         } else if ((cls != ASN1_CTX) || (con != ASN1_CON)) {
600                 /* tag = 3 indicating mechListMIC */
601                 cFYI(1, ("Exit 4 cls = %d con = %d tag = %d end = %p (%d)",
602                          cls, con, tag, end, *end));
603                 return 0;
604         }
605         if (asn1_header_decode(&ctx, &end, &cls, &con, &tag) == 0) {
606                 cFYI(1, ("Error decoding last part negTokenInit exit5"));
607                 return 0;
608         } else if ((cls != ASN1_UNI) || (con != ASN1_CON)
609                    || (tag != ASN1_SEQ)) {
610                 cFYI(1, ("cls = %d con = %d tag = %d end = %p (%d)",
611                         cls, con, tag, end, *end));
612         }
613
614         if (asn1_header_decode(&ctx, &end, &cls, &con, &tag) == 0) {
615                 cFYI(1, ("Error decoding last part negTokenInit exit 7"));
616                 return 0;
617         } else if ((cls != ASN1_CTX) || (con != ASN1_CON)) {
618                 cFYI(1, ("Exit 8 cls = %d con = %d tag = %d end = %p (%d)",
619                          cls, con, tag, end, *end));
620                 return 0;
621         }
622         if (asn1_header_decode(&ctx, &end, &cls, &con, &tag) == 0) {
623                 cFYI(1, ("Error decoding last part negTokenInit exit9"));
624                 return 0;
625         } else if ((cls != ASN1_UNI) || (con != ASN1_PRI)
626                    || (tag != ASN1_GENSTR)) {
627                 cFYI(1, ("Exit10 cls = %d con = %d tag = %d end = %p (%d)",
628                          cls, con, tag, end, *end));
629                 return 0;
630         }
631         cFYI(1, ("Need to call asn1_octets_decode() function for %s",
632                  ctx.pointer)); /* is this UTF-8 or ASCII? */
633
634         if (use_kerberos)
635                 *secType = Kerberos;
636         else if (use_mskerberos)
637                 *secType = MSKerberos;
638         else if (use_ntlmssp)
639                 *secType = NTLMSSP;
640
641         return 1;
642 }