Merge git://git.infradead.org/mtd-2.6
[linux-2.6] / net / rxrpc / rxkad.c
1 /* Kerberos-based RxRPC security
2  *
3  * Copyright (C) 2007 Red Hat, Inc. All Rights Reserved.
4  * Written by David Howells (dhowells@redhat.com)
5  *
6  * This program is free software; you can redistribute it and/or
7  * modify it under the terms of the GNU General Public License
8  * as published by the Free Software Foundation; either version
9  * 2 of the License, or (at your option) any later version.
10  */
11
12 #include <linux/module.h>
13 #include <linux/net.h>
14 #include <linux/skbuff.h>
15 #include <linux/udp.h>
16 #include <linux/crypto.h>
17 #include <linux/scatterlist.h>
18 #include <linux/ctype.h>
19 #include <net/sock.h>
20 #include <net/af_rxrpc.h>
21 #define rxrpc_debug rxkad_debug
22 #include "ar-internal.h"
23
24 #define RXKAD_VERSION                   2
25 #define MAXKRB5TICKETLEN                1024
26 #define RXKAD_TKT_TYPE_KERBEROS_V5      256
27 #define ANAME_SZ                        40      /* size of authentication name */
28 #define INST_SZ                         40      /* size of principal's instance */
29 #define REALM_SZ                        40      /* size of principal's auth domain */
30 #define SNAME_SZ                        40      /* size of service name */
31
32 unsigned rxrpc_debug;
33 module_param_named(debug, rxrpc_debug, uint, S_IWUSR | S_IRUGO);
34 MODULE_PARM_DESC(debug, "rxkad debugging mask");
35
36 struct rxkad_level1_hdr {
37         __be32  data_size;      /* true data size (excluding padding) */
38 };
39
40 struct rxkad_level2_hdr {
41         __be32  data_size;      /* true data size (excluding padding) */
42         __be32  checksum;       /* decrypted data checksum */
43 };
44
45 MODULE_DESCRIPTION("RxRPC network protocol type-2 security (Kerberos)");
46 MODULE_AUTHOR("Red Hat, Inc.");
47 MODULE_LICENSE("GPL");
48
49 /*
50  * this holds a pinned cipher so that keventd doesn't get called by the cipher
51  * alloc routine, but since we have it to hand, we use it to decrypt RESPONSE
52  * packets
53  */
54 static struct crypto_blkcipher *rxkad_ci;
55 static DEFINE_MUTEX(rxkad_ci_mutex);
56
57 /*
58  * initialise connection security
59  */
60 static int rxkad_init_connection_security(struct rxrpc_connection *conn)
61 {
62         struct rxrpc_key_payload *payload;
63         struct crypto_blkcipher *ci;
64         int ret;
65
66         _enter("{%d},{%x}", conn->debug_id, key_serial(conn->key));
67
68         payload = conn->key->payload.data;
69         conn->security_ix = payload->k.security_index;
70
71         ci = crypto_alloc_blkcipher("pcbc(fcrypt)", 0, CRYPTO_ALG_ASYNC);
72         if (IS_ERR(ci)) {
73                 _debug("no cipher");
74                 ret = PTR_ERR(ci);
75                 goto error;
76         }
77
78         if (crypto_blkcipher_setkey(ci, payload->k.session_key,
79                                     sizeof(payload->k.session_key)) < 0)
80                 BUG();
81
82         switch (conn->security_level) {
83         case RXRPC_SECURITY_PLAIN:
84                 break;
85         case RXRPC_SECURITY_AUTH:
86                 conn->size_align = 8;
87                 conn->security_size = sizeof(struct rxkad_level1_hdr);
88                 conn->header_size += sizeof(struct rxkad_level1_hdr);
89                 break;
90         case RXRPC_SECURITY_ENCRYPT:
91                 conn->size_align = 8;
92                 conn->security_size = sizeof(struct rxkad_level2_hdr);
93                 conn->header_size += sizeof(struct rxkad_level2_hdr);
94                 break;
95         default:
96                 ret = -EKEYREJECTED;
97                 goto error;
98         }
99
100         conn->cipher = ci;
101         ret = 0;
102 error:
103         _leave(" = %d", ret);
104         return ret;
105 }
106
107 /*
108  * prime the encryption state with the invariant parts of a connection's
109  * description
110  */
111 static void rxkad_prime_packet_security(struct rxrpc_connection *conn)
112 {
113         struct rxrpc_key_payload *payload;
114         struct blkcipher_desc desc;
115         struct scatterlist sg[2];
116         struct rxrpc_crypt iv;
117         struct {
118                 __be32 x[4];
119         } tmpbuf __attribute__((aligned(16))); /* must all be in same page */
120
121         _enter("");
122
123         if (!conn->key)
124                 return;
125
126         payload = conn->key->payload.data;
127         memcpy(&iv, payload->k.session_key, sizeof(iv));
128
129         desc.tfm = conn->cipher;
130         desc.info = iv.x;
131         desc.flags = 0;
132
133         tmpbuf.x[0] = conn->epoch;
134         tmpbuf.x[1] = conn->cid;
135         tmpbuf.x[2] = 0;
136         tmpbuf.x[3] = htonl(conn->security_ix);
137
138         sg_init_one(&sg[0], &tmpbuf, sizeof(tmpbuf));
139         sg_init_one(&sg[1], &tmpbuf, sizeof(tmpbuf));
140         crypto_blkcipher_encrypt_iv(&desc, &sg[0], &sg[1], sizeof(tmpbuf));
141
142         memcpy(&conn->csum_iv, &tmpbuf.x[2], sizeof(conn->csum_iv));
143         ASSERTCMP(conn->csum_iv.n[0], ==, tmpbuf.x[2]);
144
145         _leave("");
146 }
147
148 /*
149  * partially encrypt a packet (level 1 security)
150  */
151 static int rxkad_secure_packet_auth(const struct rxrpc_call *call,
152                                     struct sk_buff *skb,
153                                     u32 data_size,
154                                     void *sechdr)
155 {
156         struct rxrpc_skb_priv *sp;
157         struct blkcipher_desc desc;
158         struct rxrpc_crypt iv;
159         struct scatterlist sg[2];
160         struct {
161                 struct rxkad_level1_hdr hdr;
162                 __be32  first;  /* first four bytes of data and padding */
163         } tmpbuf __attribute__((aligned(8))); /* must all be in same page */
164         u16 check;
165
166         sp = rxrpc_skb(skb);
167
168         _enter("");
169
170         check = ntohl(sp->hdr.seq ^ sp->hdr.callNumber);
171         data_size |= (u32) check << 16;
172
173         tmpbuf.hdr.data_size = htonl(data_size);
174         memcpy(&tmpbuf.first, sechdr + 4, sizeof(tmpbuf.first));
175
176         /* start the encryption afresh */
177         memset(&iv, 0, sizeof(iv));
178         desc.tfm = call->conn->cipher;
179         desc.info = iv.x;
180         desc.flags = 0;
181
182         sg_init_one(&sg[0], &tmpbuf, sizeof(tmpbuf));
183         sg_init_one(&sg[1], &tmpbuf, sizeof(tmpbuf));
184         crypto_blkcipher_encrypt_iv(&desc, &sg[0], &sg[1], sizeof(tmpbuf));
185
186         memcpy(sechdr, &tmpbuf, sizeof(tmpbuf));
187
188         _leave(" = 0");
189         return 0;
190 }
191
192 /*
193  * wholly encrypt a packet (level 2 security)
194  */
195 static int rxkad_secure_packet_encrypt(const struct rxrpc_call *call,
196                                         struct sk_buff *skb,
197                                         u32 data_size,
198                                         void *sechdr)
199 {
200         const struct rxrpc_key_payload *payload;
201         struct rxkad_level2_hdr rxkhdr
202                 __attribute__((aligned(8))); /* must be all on one page */
203         struct rxrpc_skb_priv *sp;
204         struct blkcipher_desc desc;
205         struct rxrpc_crypt iv;
206         struct scatterlist sg[16];
207         struct sk_buff *trailer;
208         unsigned len;
209         u16 check;
210         int nsg;
211
212         sp = rxrpc_skb(skb);
213
214         _enter("");
215
216         check = ntohl(sp->hdr.seq ^ sp->hdr.callNumber);
217
218         rxkhdr.data_size = htonl(data_size | (u32) check << 16);
219         rxkhdr.checksum = 0;
220
221         /* encrypt from the session key */
222         payload = call->conn->key->payload.data;
223         memcpy(&iv, payload->k.session_key, sizeof(iv));
224         desc.tfm = call->conn->cipher;
225         desc.info = iv.x;
226         desc.flags = 0;
227
228         sg_init_one(&sg[0], sechdr, sizeof(rxkhdr));
229         sg_init_one(&sg[1], &rxkhdr, sizeof(rxkhdr));
230         crypto_blkcipher_encrypt_iv(&desc, &sg[0], &sg[1], sizeof(rxkhdr));
231
232         /* we want to encrypt the skbuff in-place */
233         nsg = skb_cow_data(skb, 0, &trailer);
234         if (nsg < 0 || nsg > 16)
235                 return -ENOMEM;
236
237         len = data_size + call->conn->size_align - 1;
238         len &= ~(call->conn->size_align - 1);
239
240         sg_init_table(sg, nsg);
241         skb_to_sgvec(skb, sg, 0, len);
242         crypto_blkcipher_encrypt_iv(&desc, sg, sg, len);
243
244         _leave(" = 0");
245         return 0;
246 }
247
248 /*
249  * checksum an RxRPC packet header
250  */
251 static int rxkad_secure_packet(const struct rxrpc_call *call,
252                                 struct sk_buff *skb,
253                                 size_t data_size,
254                                 void *sechdr)
255 {
256         struct rxrpc_skb_priv *sp;
257         struct blkcipher_desc desc;
258         struct rxrpc_crypt iv;
259         struct scatterlist sg[2];
260         struct {
261                 __be32 x[2];
262         } tmpbuf __attribute__((aligned(8))); /* must all be in same page */
263         __be32 x;
264         u32 y;
265         int ret;
266
267         sp = rxrpc_skb(skb);
268
269         _enter("{%d{%x}},{#%u},%zu,",
270                call->debug_id, key_serial(call->conn->key), ntohl(sp->hdr.seq),
271                data_size);
272
273         if (!call->conn->cipher)
274                 return 0;
275
276         ret = key_validate(call->conn->key);
277         if (ret < 0)
278                 return ret;
279
280         /* continue encrypting from where we left off */
281         memcpy(&iv, call->conn->csum_iv.x, sizeof(iv));
282         desc.tfm = call->conn->cipher;
283         desc.info = iv.x;
284         desc.flags = 0;
285
286         /* calculate the security checksum */
287         x = htonl(call->channel << (32 - RXRPC_CIDSHIFT));
288         x |= sp->hdr.seq & cpu_to_be32(0x3fffffff);
289         tmpbuf.x[0] = sp->hdr.callNumber;
290         tmpbuf.x[1] = x;
291
292         sg_init_one(&sg[0], &tmpbuf, sizeof(tmpbuf));
293         sg_init_one(&sg[1], &tmpbuf, sizeof(tmpbuf));
294         crypto_blkcipher_encrypt_iv(&desc, &sg[0], &sg[1], sizeof(tmpbuf));
295
296         y = ntohl(tmpbuf.x[1]);
297         y = (y >> 16) & 0xffff;
298         if (y == 0)
299                 y = 1; /* zero checksums are not permitted */
300         sp->hdr.cksum = htons(y);
301
302         switch (call->conn->security_level) {
303         case RXRPC_SECURITY_PLAIN:
304                 ret = 0;
305                 break;
306         case RXRPC_SECURITY_AUTH:
307                 ret = rxkad_secure_packet_auth(call, skb, data_size, sechdr);
308                 break;
309         case RXRPC_SECURITY_ENCRYPT:
310                 ret = rxkad_secure_packet_encrypt(call, skb, data_size,
311                                                   sechdr);
312                 break;
313         default:
314                 ret = -EPERM;
315                 break;
316         }
317
318         _leave(" = %d [set %hx]", ret, y);
319         return ret;
320 }
321
322 /*
323  * decrypt partial encryption on a packet (level 1 security)
324  */
325 static int rxkad_verify_packet_auth(const struct rxrpc_call *call,
326                                     struct sk_buff *skb,
327                                     u32 *_abort_code)
328 {
329         struct rxkad_level1_hdr sechdr;
330         struct rxrpc_skb_priv *sp;
331         struct blkcipher_desc desc;
332         struct rxrpc_crypt iv;
333         struct scatterlist sg[16];
334         struct sk_buff *trailer;
335         u32 data_size, buf;
336         u16 check;
337         int nsg;
338
339         _enter("");
340
341         sp = rxrpc_skb(skb);
342
343         /* we want to decrypt the skbuff in-place */
344         nsg = skb_cow_data(skb, 0, &trailer);
345         if (nsg < 0 || nsg > 16)
346                 goto nomem;
347
348         sg_init_table(sg, nsg);
349         skb_to_sgvec(skb, sg, 0, 8);
350
351         /* start the decryption afresh */
352         memset(&iv, 0, sizeof(iv));
353         desc.tfm = call->conn->cipher;
354         desc.info = iv.x;
355         desc.flags = 0;
356
357         crypto_blkcipher_decrypt_iv(&desc, sg, sg, 8);
358
359         /* remove the decrypted packet length */
360         if (skb_copy_bits(skb, 0, &sechdr, sizeof(sechdr)) < 0)
361                 goto datalen_error;
362         if (!skb_pull(skb, sizeof(sechdr)))
363                 BUG();
364
365         buf = ntohl(sechdr.data_size);
366         data_size = buf & 0xffff;
367
368         check = buf >> 16;
369         check ^= ntohl(sp->hdr.seq ^ sp->hdr.callNumber);
370         check &= 0xffff;
371         if (check != 0) {
372                 *_abort_code = RXKADSEALEDINCON;
373                 goto protocol_error;
374         }
375
376         /* shorten the packet to remove the padding */
377         if (data_size > skb->len)
378                 goto datalen_error;
379         else if (data_size < skb->len)
380                 skb->len = data_size;
381
382         _leave(" = 0 [dlen=%x]", data_size);
383         return 0;
384
385 datalen_error:
386         *_abort_code = RXKADDATALEN;
387 protocol_error:
388         _leave(" = -EPROTO");
389         return -EPROTO;
390
391 nomem:
392         _leave(" = -ENOMEM");
393         return -ENOMEM;
394 }
395
396 /*
397  * wholly decrypt a packet (level 2 security)
398  */
399 static int rxkad_verify_packet_encrypt(const struct rxrpc_call *call,
400                                        struct sk_buff *skb,
401                                        u32 *_abort_code)
402 {
403         const struct rxrpc_key_payload *payload;
404         struct rxkad_level2_hdr sechdr;
405         struct rxrpc_skb_priv *sp;
406         struct blkcipher_desc desc;
407         struct rxrpc_crypt iv;
408         struct scatterlist _sg[4], *sg;
409         struct sk_buff *trailer;
410         u32 data_size, buf;
411         u16 check;
412         int nsg;
413
414         _enter(",{%d}", skb->len);
415
416         sp = rxrpc_skb(skb);
417
418         /* we want to decrypt the skbuff in-place */
419         nsg = skb_cow_data(skb, 0, &trailer);
420         if (nsg < 0)
421                 goto nomem;
422
423         sg = _sg;
424         if (unlikely(nsg > 4)) {
425                 sg = kmalloc(sizeof(*sg) * nsg, GFP_NOIO);
426                 if (!sg)
427                         goto nomem;
428         }
429
430         sg_init_table(sg, nsg);
431         skb_to_sgvec(skb, sg, 0, skb->len);
432
433         /* decrypt from the session key */
434         payload = call->conn->key->payload.data;
435         memcpy(&iv, payload->k.session_key, sizeof(iv));
436         desc.tfm = call->conn->cipher;
437         desc.info = iv.x;
438         desc.flags = 0;
439
440         crypto_blkcipher_decrypt_iv(&desc, sg, sg, skb->len);
441         if (sg != _sg)
442                 kfree(sg);
443
444         /* remove the decrypted packet length */
445         if (skb_copy_bits(skb, 0, &sechdr, sizeof(sechdr)) < 0)
446                 goto datalen_error;
447         if (!skb_pull(skb, sizeof(sechdr)))
448                 BUG();
449
450         buf = ntohl(sechdr.data_size);
451         data_size = buf & 0xffff;
452
453         check = buf >> 16;
454         check ^= ntohl(sp->hdr.seq ^ sp->hdr.callNumber);
455         check &= 0xffff;
456         if (check != 0) {
457                 *_abort_code = RXKADSEALEDINCON;
458                 goto protocol_error;
459         }
460
461         /* shorten the packet to remove the padding */
462         if (data_size > skb->len)
463                 goto datalen_error;
464         else if (data_size < skb->len)
465                 skb->len = data_size;
466
467         _leave(" = 0 [dlen=%x]", data_size);
468         return 0;
469
470 datalen_error:
471         *_abort_code = RXKADDATALEN;
472 protocol_error:
473         _leave(" = -EPROTO");
474         return -EPROTO;
475
476 nomem:
477         _leave(" = -ENOMEM");
478         return -ENOMEM;
479 }
480
481 /*
482  * verify the security on a received packet
483  */
484 static int rxkad_verify_packet(const struct rxrpc_call *call,
485                                struct sk_buff *skb,
486                                u32 *_abort_code)
487 {
488         struct blkcipher_desc desc;
489         struct rxrpc_skb_priv *sp;
490         struct rxrpc_crypt iv;
491         struct scatterlist sg[2];
492         struct {
493                 __be32 x[2];
494         } tmpbuf __attribute__((aligned(8))); /* must all be in same page */
495         __be32 x;
496         __be16 cksum;
497         u32 y;
498         int ret;
499
500         sp = rxrpc_skb(skb);
501
502         _enter("{%d{%x}},{#%u}",
503                call->debug_id, key_serial(call->conn->key),
504                ntohl(sp->hdr.seq));
505
506         if (!call->conn->cipher)
507                 return 0;
508
509         if (sp->hdr.securityIndex != 2) {
510                 *_abort_code = RXKADINCONSISTENCY;
511                 _leave(" = -EPROTO [not rxkad]");
512                 return -EPROTO;
513         }
514
515         /* continue encrypting from where we left off */
516         memcpy(&iv, call->conn->csum_iv.x, sizeof(iv));
517         desc.tfm = call->conn->cipher;
518         desc.info = iv.x;
519         desc.flags = 0;
520
521         /* validate the security checksum */
522         x = htonl(call->channel << (32 - RXRPC_CIDSHIFT));
523         x |= sp->hdr.seq & cpu_to_be32(0x3fffffff);
524         tmpbuf.x[0] = call->call_id;
525         tmpbuf.x[1] = x;
526
527         sg_init_one(&sg[0], &tmpbuf, sizeof(tmpbuf));
528         sg_init_one(&sg[1], &tmpbuf, sizeof(tmpbuf));
529         crypto_blkcipher_encrypt_iv(&desc, &sg[0], &sg[1], sizeof(tmpbuf));
530
531         y = ntohl(tmpbuf.x[1]);
532         y = (y >> 16) & 0xffff;
533         if (y == 0)
534                 y = 1; /* zero checksums are not permitted */
535
536         cksum = htons(y);
537         if (sp->hdr.cksum != cksum) {
538                 *_abort_code = RXKADSEALEDINCON;
539                 _leave(" = -EPROTO [csum failed]");
540                 return -EPROTO;
541         }
542
543         switch (call->conn->security_level) {
544         case RXRPC_SECURITY_PLAIN:
545                 ret = 0;
546                 break;
547         case RXRPC_SECURITY_AUTH:
548                 ret = rxkad_verify_packet_auth(call, skb, _abort_code);
549                 break;
550         case RXRPC_SECURITY_ENCRYPT:
551                 ret = rxkad_verify_packet_encrypt(call, skb, _abort_code);
552                 break;
553         default:
554                 ret = -ENOANO;
555                 break;
556         }
557
558         _leave(" = %d", ret);
559         return ret;
560 }
561
562 /*
563  * issue a challenge
564  */
565 static int rxkad_issue_challenge(struct rxrpc_connection *conn)
566 {
567         struct rxkad_challenge challenge;
568         struct rxrpc_header hdr;
569         struct msghdr msg;
570         struct kvec iov[2];
571         size_t len;
572         int ret;
573
574         _enter("{%d,%x}", conn->debug_id, key_serial(conn->key));
575
576         ret = key_validate(conn->key);
577         if (ret < 0)
578                 return ret;
579
580         get_random_bytes(&conn->security_nonce, sizeof(conn->security_nonce));
581
582         challenge.version       = htonl(2);
583         challenge.nonce         = htonl(conn->security_nonce);
584         challenge.min_level     = htonl(0);
585         challenge.__padding     = 0;
586
587         msg.msg_name    = &conn->trans->peer->srx.transport.sin;
588         msg.msg_namelen = sizeof(conn->trans->peer->srx.transport.sin);
589         msg.msg_control = NULL;
590         msg.msg_controllen = 0;
591         msg.msg_flags   = 0;
592
593         hdr.epoch       = conn->epoch;
594         hdr.cid         = conn->cid;
595         hdr.callNumber  = 0;
596         hdr.seq         = 0;
597         hdr.type        = RXRPC_PACKET_TYPE_CHALLENGE;
598         hdr.flags       = conn->out_clientflag;
599         hdr.userStatus  = 0;
600         hdr.securityIndex = conn->security_ix;
601         hdr._rsvd       = 0;
602         hdr.serviceId   = conn->service_id;
603
604         iov[0].iov_base = &hdr;
605         iov[0].iov_len  = sizeof(hdr);
606         iov[1].iov_base = &challenge;
607         iov[1].iov_len  = sizeof(challenge);
608
609         len = iov[0].iov_len + iov[1].iov_len;
610
611         hdr.serial = htonl(atomic_inc_return(&conn->serial));
612         _proto("Tx CHALLENGE %%%u", ntohl(hdr.serial));
613
614         ret = kernel_sendmsg(conn->trans->local->socket, &msg, iov, 2, len);
615         if (ret < 0) {
616                 _debug("sendmsg failed: %d", ret);
617                 return -EAGAIN;
618         }
619
620         _leave(" = 0");
621         return 0;
622 }
623
624 /*
625  * send a Kerberos security response
626  */
627 static int rxkad_send_response(struct rxrpc_connection *conn,
628                                struct rxrpc_header *hdr,
629                                struct rxkad_response *resp,
630                                const struct rxkad_key *s2)
631 {
632         struct msghdr msg;
633         struct kvec iov[3];
634         size_t len;
635         int ret;
636
637         _enter("");
638
639         msg.msg_name    = &conn->trans->peer->srx.transport.sin;
640         msg.msg_namelen = sizeof(conn->trans->peer->srx.transport.sin);
641         msg.msg_control = NULL;
642         msg.msg_controllen = 0;
643         msg.msg_flags   = 0;
644
645         hdr->epoch      = conn->epoch;
646         hdr->seq        = 0;
647         hdr->type       = RXRPC_PACKET_TYPE_RESPONSE;
648         hdr->flags      = conn->out_clientflag;
649         hdr->userStatus = 0;
650         hdr->_rsvd      = 0;
651
652         iov[0].iov_base = hdr;
653         iov[0].iov_len  = sizeof(*hdr);
654         iov[1].iov_base = resp;
655         iov[1].iov_len  = sizeof(*resp);
656         iov[2].iov_base = (void *) s2->ticket;
657         iov[2].iov_len  = s2->ticket_len;
658
659         len = iov[0].iov_len + iov[1].iov_len + iov[2].iov_len;
660
661         hdr->serial = htonl(atomic_inc_return(&conn->serial));
662         _proto("Tx RESPONSE %%%u", ntohl(hdr->serial));
663
664         ret = kernel_sendmsg(conn->trans->local->socket, &msg, iov, 3, len);
665         if (ret < 0) {
666                 _debug("sendmsg failed: %d", ret);
667                 return -EAGAIN;
668         }
669
670         _leave(" = 0");
671         return 0;
672 }
673
674 /*
675  * calculate the response checksum
676  */
677 static void rxkad_calc_response_checksum(struct rxkad_response *response)
678 {
679         u32 csum = 1000003;
680         int loop;
681         u8 *p = (u8 *) response;
682
683         for (loop = sizeof(*response); loop > 0; loop--)
684                 csum = csum * 0x10204081 + *p++;
685
686         response->encrypted.checksum = htonl(csum);
687 }
688
689 /*
690  * load a scatterlist with a potentially split-page buffer
691  */
692 static void rxkad_sg_set_buf2(struct scatterlist sg[2],
693                               void *buf, size_t buflen)
694 {
695         int nsg = 1;
696
697         sg_init_table(sg, 2);
698
699         sg_set_buf(&sg[0], buf, buflen);
700         if (sg[0].offset + buflen > PAGE_SIZE) {
701                 /* the buffer was split over two pages */
702                 sg[0].length = PAGE_SIZE - sg[0].offset;
703                 sg_set_buf(&sg[1], buf + sg[0].length, buflen - sg[0].length);
704                 nsg++;
705         }
706
707         sg_mark_end(&sg[nsg - 1]);
708
709         ASSERTCMP(sg[0].length + sg[1].length, ==, buflen);
710 }
711
712 /*
713  * encrypt the response packet
714  */
715 static void rxkad_encrypt_response(struct rxrpc_connection *conn,
716                                    struct rxkad_response *resp,
717                                    const struct rxkad_key *s2)
718 {
719         struct blkcipher_desc desc;
720         struct rxrpc_crypt iv;
721         struct scatterlist sg[2];
722
723         /* continue encrypting from where we left off */
724         memcpy(&iv, s2->session_key, sizeof(iv));
725         desc.tfm = conn->cipher;
726         desc.info = iv.x;
727         desc.flags = 0;
728
729         rxkad_sg_set_buf2(sg, &resp->encrypted, sizeof(resp->encrypted));
730         crypto_blkcipher_encrypt_iv(&desc, sg, sg, sizeof(resp->encrypted));
731 }
732
733 /*
734  * respond to a challenge packet
735  */
736 static int rxkad_respond_to_challenge(struct rxrpc_connection *conn,
737                                       struct sk_buff *skb,
738                                       u32 *_abort_code)
739 {
740         const struct rxrpc_key_payload *payload;
741         struct rxkad_challenge challenge;
742         struct rxkad_response resp
743                 __attribute__((aligned(8))); /* must be aligned for crypto */
744         struct rxrpc_skb_priv *sp;
745         u32 version, nonce, min_level, abort_code;
746         int ret;
747
748         _enter("{%d,%x}", conn->debug_id, key_serial(conn->key));
749
750         if (!conn->key) {
751                 _leave(" = -EPROTO [no key]");
752                 return -EPROTO;
753         }
754
755         ret = key_validate(conn->key);
756         if (ret < 0) {
757                 *_abort_code = RXKADEXPIRED;
758                 return ret;
759         }
760
761         abort_code = RXKADPACKETSHORT;
762         sp = rxrpc_skb(skb);
763         if (skb_copy_bits(skb, 0, &challenge, sizeof(challenge)) < 0)
764                 goto protocol_error;
765
766         version = ntohl(challenge.version);
767         nonce = ntohl(challenge.nonce);
768         min_level = ntohl(challenge.min_level);
769
770         _proto("Rx CHALLENGE %%%u { v=%u n=%u ml=%u }",
771                ntohl(sp->hdr.serial), version, nonce, min_level);
772
773         abort_code = RXKADINCONSISTENCY;
774         if (version != RXKAD_VERSION)
775                 goto protocol_error;
776
777         abort_code = RXKADLEVELFAIL;
778         if (conn->security_level < min_level)
779                 goto protocol_error;
780
781         payload = conn->key->payload.data;
782
783         /* build the response packet */
784         memset(&resp, 0, sizeof(resp));
785
786         resp.version = RXKAD_VERSION;
787         resp.encrypted.epoch = conn->epoch;
788         resp.encrypted.cid = conn->cid;
789         resp.encrypted.securityIndex = htonl(conn->security_ix);
790         resp.encrypted.call_id[0] =
791                 (conn->channels[0] ? conn->channels[0]->call_id : 0);
792         resp.encrypted.call_id[1] =
793                 (conn->channels[1] ? conn->channels[1]->call_id : 0);
794         resp.encrypted.call_id[2] =
795                 (conn->channels[2] ? conn->channels[2]->call_id : 0);
796         resp.encrypted.call_id[3] =
797                 (conn->channels[3] ? conn->channels[3]->call_id : 0);
798         resp.encrypted.inc_nonce = htonl(nonce + 1);
799         resp.encrypted.level = htonl(conn->security_level);
800         resp.kvno = htonl(payload->k.kvno);
801         resp.ticket_len = htonl(payload->k.ticket_len);
802
803         /* calculate the response checksum and then do the encryption */
804         rxkad_calc_response_checksum(&resp);
805         rxkad_encrypt_response(conn, &resp, &payload->k);
806         return rxkad_send_response(conn, &sp->hdr, &resp, &payload->k);
807
808 protocol_error:
809         *_abort_code = abort_code;
810         _leave(" = -EPROTO [%d]", abort_code);
811         return -EPROTO;
812 }
813
814 /*
815  * decrypt the kerberos IV ticket in the response
816  */
817 static int rxkad_decrypt_ticket(struct rxrpc_connection *conn,
818                                 void *ticket, size_t ticket_len,
819                                 struct rxrpc_crypt *_session_key,
820                                 time_t *_expiry,
821                                 u32 *_abort_code)
822 {
823         struct blkcipher_desc desc;
824         struct rxrpc_crypt iv, key;
825         struct scatterlist sg[1];
826         struct in_addr addr;
827         unsigned life;
828         time_t issue, now;
829         bool little_endian;
830         int ret;
831         u8 *p, *q, *name, *end;
832
833         _enter("{%d},{%x}", conn->debug_id, key_serial(conn->server_key));
834
835         *_expiry = 0;
836
837         ret = key_validate(conn->server_key);
838         if (ret < 0) {
839                 switch (ret) {
840                 case -EKEYEXPIRED:
841                         *_abort_code = RXKADEXPIRED;
842                         goto error;
843                 default:
844                         *_abort_code = RXKADNOAUTH;
845                         goto error;
846                 }
847         }
848
849         ASSERT(conn->server_key->payload.data != NULL);
850         ASSERTCMP((unsigned long) ticket & 7UL, ==, 0);
851
852         memcpy(&iv, &conn->server_key->type_data, sizeof(iv));
853
854         desc.tfm = conn->server_key->payload.data;
855         desc.info = iv.x;
856         desc.flags = 0;
857
858         sg_init_one(&sg[0], ticket, ticket_len);
859         crypto_blkcipher_decrypt_iv(&desc, sg, sg, ticket_len);
860
861         p = ticket;
862         end = p + ticket_len;
863
864 #define Z(size)                                         \
865         ({                                              \
866                 u8 *__str = p;                          \
867                 q = memchr(p, 0, end - p);              \
868                 if (!q || q - p > (size))               \
869                         goto bad_ticket;                \
870                 for (; p < q; p++)                      \
871                         if (!isprint(*p))               \
872                                 goto bad_ticket;        \
873                 p++;                                    \
874                 __str;                                  \
875         })
876
877         /* extract the ticket flags */
878         _debug("KIV FLAGS: %x", *p);
879         little_endian = *p & 1;
880         p++;
881
882         /* extract the authentication name */
883         name = Z(ANAME_SZ);
884         _debug("KIV ANAME: %s", name);
885
886         /* extract the principal's instance */
887         name = Z(INST_SZ);
888         _debug("KIV INST : %s", name);
889
890         /* extract the principal's authentication domain */
891         name = Z(REALM_SZ);
892         _debug("KIV REALM: %s", name);
893
894         if (end - p < 4 + 8 + 4 + 2)
895                 goto bad_ticket;
896
897         /* get the IPv4 address of the entity that requested the ticket */
898         memcpy(&addr, p, sizeof(addr));
899         p += 4;
900         _debug("KIV ADDR : "NIPQUAD_FMT, NIPQUAD(addr));
901
902         /* get the session key from the ticket */
903         memcpy(&key, p, sizeof(key));
904         p += 8;
905         _debug("KIV KEY  : %08x %08x", ntohl(key.n[0]), ntohl(key.n[1]));
906         memcpy(_session_key, &key, sizeof(key));
907
908         /* get the ticket's lifetime */
909         life = *p++ * 5 * 60;
910         _debug("KIV LIFE : %u", life);
911
912         /* get the issue time of the ticket */
913         if (little_endian) {
914                 __le32 stamp;
915                 memcpy(&stamp, p, 4);
916                 issue = le32_to_cpu(stamp);
917         } else {
918                 __be32 stamp;
919                 memcpy(&stamp, p, 4);
920                 issue = be32_to_cpu(stamp);
921         }
922         p += 4;
923         now = get_seconds();
924         _debug("KIV ISSUE: %lx [%lx]", issue, now);
925
926         /* check the ticket is in date */
927         if (issue > now) {
928                 *_abort_code = RXKADNOAUTH;
929                 ret = -EKEYREJECTED;
930                 goto error;
931         }
932
933         if (issue < now - life) {
934                 *_abort_code = RXKADEXPIRED;
935                 ret = -EKEYEXPIRED;
936                 goto error;
937         }
938
939         *_expiry = issue + life;
940
941         /* get the service name */
942         name = Z(SNAME_SZ);
943         _debug("KIV SNAME: %s", name);
944
945         /* get the service instance name */
946         name = Z(INST_SZ);
947         _debug("KIV SINST: %s", name);
948
949         ret = 0;
950 error:
951         _leave(" = %d", ret);
952         return ret;
953
954 bad_ticket:
955         *_abort_code = RXKADBADTICKET;
956         ret = -EBADMSG;
957         goto error;
958 }
959
960 /*
961  * decrypt the response packet
962  */
963 static void rxkad_decrypt_response(struct rxrpc_connection *conn,
964                                    struct rxkad_response *resp,
965                                    const struct rxrpc_crypt *session_key)
966 {
967         struct blkcipher_desc desc;
968         struct scatterlist sg[2];
969         struct rxrpc_crypt iv;
970
971         _enter(",,%08x%08x",
972                ntohl(session_key->n[0]), ntohl(session_key->n[1]));
973
974         ASSERT(rxkad_ci != NULL);
975
976         mutex_lock(&rxkad_ci_mutex);
977         if (crypto_blkcipher_setkey(rxkad_ci, session_key->x,
978                                     sizeof(*session_key)) < 0)
979                 BUG();
980
981         memcpy(&iv, session_key, sizeof(iv));
982         desc.tfm = rxkad_ci;
983         desc.info = iv.x;
984         desc.flags = 0;
985
986         rxkad_sg_set_buf2(sg, &resp->encrypted, sizeof(resp->encrypted));
987         crypto_blkcipher_decrypt_iv(&desc, sg, sg, sizeof(resp->encrypted));
988         mutex_unlock(&rxkad_ci_mutex);
989
990         _leave("");
991 }
992
993 /*
994  * verify a response
995  */
996 static int rxkad_verify_response(struct rxrpc_connection *conn,
997                                  struct sk_buff *skb,
998                                  u32 *_abort_code)
999 {
1000         struct rxkad_response response
1001                 __attribute__((aligned(8))); /* must be aligned for crypto */
1002         struct rxrpc_skb_priv *sp;
1003         struct rxrpc_crypt session_key;
1004         time_t expiry;
1005         void *ticket;
1006         u32 abort_code, version, kvno, ticket_len, level;
1007         __be32 csum;
1008         int ret;
1009
1010         _enter("{%d,%x}", conn->debug_id, key_serial(conn->server_key));
1011
1012         abort_code = RXKADPACKETSHORT;
1013         if (skb_copy_bits(skb, 0, &response, sizeof(response)) < 0)
1014                 goto protocol_error;
1015         if (!pskb_pull(skb, sizeof(response)))
1016                 BUG();
1017
1018         version = ntohl(response.version);
1019         ticket_len = ntohl(response.ticket_len);
1020         kvno = ntohl(response.kvno);
1021         sp = rxrpc_skb(skb);
1022         _proto("Rx RESPONSE %%%u { v=%u kv=%u tl=%u }",
1023                ntohl(sp->hdr.serial), version, kvno, ticket_len);
1024
1025         abort_code = RXKADINCONSISTENCY;
1026         if (version != RXKAD_VERSION)
1027                 goto protocol_error;
1028
1029         abort_code = RXKADTICKETLEN;
1030         if (ticket_len < 4 || ticket_len > MAXKRB5TICKETLEN)
1031                 goto protocol_error;
1032
1033         abort_code = RXKADUNKNOWNKEY;
1034         if (kvno >= RXKAD_TKT_TYPE_KERBEROS_V5)
1035                 goto protocol_error;
1036
1037         /* extract the kerberos ticket and decrypt and decode it */
1038         ticket = kmalloc(ticket_len, GFP_NOFS);
1039         if (!ticket)
1040                 return -ENOMEM;
1041
1042         abort_code = RXKADPACKETSHORT;
1043         if (skb_copy_bits(skb, 0, ticket, ticket_len) < 0)
1044                 goto protocol_error_free;
1045
1046         ret = rxkad_decrypt_ticket(conn, ticket, ticket_len, &session_key,
1047                                    &expiry, &abort_code);
1048         if (ret < 0) {
1049                 *_abort_code = abort_code;
1050                 kfree(ticket);
1051                 return ret;
1052         }
1053
1054         /* use the session key from inside the ticket to decrypt the
1055          * response */
1056         rxkad_decrypt_response(conn, &response, &session_key);
1057
1058         abort_code = RXKADSEALEDINCON;
1059         if (response.encrypted.epoch != conn->epoch)
1060                 goto protocol_error_free;
1061         if (response.encrypted.cid != conn->cid)
1062                 goto protocol_error_free;
1063         if (ntohl(response.encrypted.securityIndex) != conn->security_ix)
1064                 goto protocol_error_free;
1065         csum = response.encrypted.checksum;
1066         response.encrypted.checksum = 0;
1067         rxkad_calc_response_checksum(&response);
1068         if (response.encrypted.checksum != csum)
1069                 goto protocol_error_free;
1070
1071         if (ntohl(response.encrypted.call_id[0]) > INT_MAX ||
1072             ntohl(response.encrypted.call_id[1]) > INT_MAX ||
1073             ntohl(response.encrypted.call_id[2]) > INT_MAX ||
1074             ntohl(response.encrypted.call_id[3]) > INT_MAX)
1075                 goto protocol_error_free;
1076
1077         abort_code = RXKADOUTOFSEQUENCE;
1078         if (response.encrypted.inc_nonce != htonl(conn->security_nonce + 1))
1079                 goto protocol_error_free;
1080
1081         abort_code = RXKADLEVELFAIL;
1082         level = ntohl(response.encrypted.level);
1083         if (level > RXRPC_SECURITY_ENCRYPT)
1084                 goto protocol_error_free;
1085         conn->security_level = level;
1086
1087         /* create a key to hold the security data and expiration time - after
1088          * this the connection security can be handled in exactly the same way
1089          * as for a client connection */
1090         ret = rxrpc_get_server_data_key(conn, &session_key, expiry, kvno);
1091         if (ret < 0) {
1092                 kfree(ticket);
1093                 return ret;
1094         }
1095
1096         kfree(ticket);
1097         _leave(" = 0");
1098         return 0;
1099
1100 protocol_error_free:
1101         kfree(ticket);
1102 protocol_error:
1103         *_abort_code = abort_code;
1104         _leave(" = -EPROTO [%d]", abort_code);
1105         return -EPROTO;
1106 }
1107
1108 /*
1109  * clear the connection security
1110  */
1111 static void rxkad_clear(struct rxrpc_connection *conn)
1112 {
1113         _enter("");
1114
1115         if (conn->cipher)
1116                 crypto_free_blkcipher(conn->cipher);
1117 }
1118
1119 /*
1120  * RxRPC Kerberos-based security
1121  */
1122 static struct rxrpc_security rxkad = {
1123         .owner                          = THIS_MODULE,
1124         .name                           = "rxkad",
1125         .security_index                 = RXKAD_VERSION,
1126         .init_connection_security       = rxkad_init_connection_security,
1127         .prime_packet_security          = rxkad_prime_packet_security,
1128         .secure_packet                  = rxkad_secure_packet,
1129         .verify_packet                  = rxkad_verify_packet,
1130         .issue_challenge                = rxkad_issue_challenge,
1131         .respond_to_challenge           = rxkad_respond_to_challenge,
1132         .verify_response                = rxkad_verify_response,
1133         .clear                          = rxkad_clear,
1134 };
1135
1136 static __init int rxkad_init(void)
1137 {
1138         _enter("");
1139
1140         /* pin the cipher we need so that the crypto layer doesn't invoke
1141          * keventd to go get it */
1142         rxkad_ci = crypto_alloc_blkcipher("pcbc(fcrypt)", 0, CRYPTO_ALG_ASYNC);
1143         if (IS_ERR(rxkad_ci))
1144                 return PTR_ERR(rxkad_ci);
1145
1146         return rxrpc_register_security(&rxkad);
1147 }
1148
1149 module_init(rxkad_init);
1150
1151 static __exit void rxkad_exit(void)
1152 {
1153         _enter("");
1154
1155         rxrpc_unregister_security(&rxkad);
1156         crypto_free_blkcipher(rxkad_ci);
1157 }
1158
1159 module_exit(rxkad_exit);