[CRYPTO] cbc: Require block size to be a power of 2
[linux-2.6] / crypto / pcbc.c
1 /*
2  * PCBC: Propagating Cipher Block Chaining mode
3  *
4  * Copyright (C) 2006 Red Hat, Inc. All Rights Reserved.
5  * Written by David Howells (dhowells@redhat.com)
6  *
7  * Derived from cbc.c
8  * - Copyright (c) 2006 Herbert Xu <herbert@gondor.apana.org.au>
9  *
10  * This program is free software; you can redistribute it and/or modify it
11  * under the terms of the GNU General Public License as published by the Free
12  * Software Foundation; either version 2 of the License, or (at your option)
13  * any later version.
14  *
15  */
16
17 #include <crypto/algapi.h>
18 #include <linux/err.h>
19 #include <linux/init.h>
20 #include <linux/kernel.h>
21 #include <linux/module.h>
22 #include <linux/scatterlist.h>
23 #include <linux/slab.h>
24
25 struct crypto_pcbc_ctx {
26         struct crypto_cipher *child;
27         void (*xor)(u8 *dst, const u8 *src, unsigned int bs);
28 };
29
30 static int crypto_pcbc_setkey(struct crypto_tfm *parent, const u8 *key,
31                               unsigned int keylen)
32 {
33         struct crypto_pcbc_ctx *ctx = crypto_tfm_ctx(parent);
34         struct crypto_cipher *child = ctx->child;
35         int err;
36
37         crypto_cipher_clear_flags(child, CRYPTO_TFM_REQ_MASK);
38         crypto_cipher_set_flags(child, crypto_tfm_get_flags(parent) &
39                                 CRYPTO_TFM_REQ_MASK);
40         err = crypto_cipher_setkey(child, key, keylen);
41         crypto_tfm_set_flags(parent, crypto_cipher_get_flags(child) &
42                              CRYPTO_TFM_RES_MASK);
43         return err;
44 }
45
46 static int crypto_pcbc_encrypt_segment(struct blkcipher_desc *desc,
47                                        struct blkcipher_walk *walk,
48                                        struct crypto_cipher *tfm,
49                                        void (*xor)(u8 *, const u8 *,
50                                                    unsigned int))
51 {
52         void (*fn)(struct crypto_tfm *, u8 *, const u8 *) =
53                 crypto_cipher_alg(tfm)->cia_encrypt;
54         int bsize = crypto_cipher_blocksize(tfm);
55         unsigned int nbytes = walk->nbytes;
56         u8 *src = walk->src.virt.addr;
57         u8 *dst = walk->dst.virt.addr;
58         u8 *iv = walk->iv;
59
60         do {
61                 xor(iv, src, bsize);
62                 fn(crypto_cipher_tfm(tfm), dst, iv);
63                 memcpy(iv, dst, bsize);
64                 xor(iv, src, bsize);
65
66                 src += bsize;
67                 dst += bsize;
68         } while ((nbytes -= bsize) >= bsize);
69
70         return nbytes;
71 }
72
73 static int crypto_pcbc_encrypt_inplace(struct blkcipher_desc *desc,
74                                        struct blkcipher_walk *walk,
75                                        struct crypto_cipher *tfm,
76                                        void (*xor)(u8 *, const u8 *,
77                                                    unsigned int))
78 {
79         void (*fn)(struct crypto_tfm *, u8 *, const u8 *) =
80                 crypto_cipher_alg(tfm)->cia_encrypt;
81         int bsize = crypto_cipher_blocksize(tfm);
82         unsigned int nbytes = walk->nbytes;
83         u8 *src = walk->src.virt.addr;
84         u8 *iv = walk->iv;
85         u8 tmpbuf[bsize];
86
87         do {
88                 memcpy(tmpbuf, src, bsize);
89                 xor(iv, tmpbuf, bsize);
90                 fn(crypto_cipher_tfm(tfm), src, iv);
91                 memcpy(iv, src, bsize);
92                 xor(iv, tmpbuf, bsize);
93
94                 src += bsize;
95         } while ((nbytes -= bsize) >= bsize);
96
97         memcpy(walk->iv, iv, bsize);
98
99         return nbytes;
100 }
101
102 static int crypto_pcbc_encrypt(struct blkcipher_desc *desc,
103                                struct scatterlist *dst, struct scatterlist *src,
104                                unsigned int nbytes)
105 {
106         struct blkcipher_walk walk;
107         struct crypto_blkcipher *tfm = desc->tfm;
108         struct crypto_pcbc_ctx *ctx = crypto_blkcipher_ctx(tfm);
109         struct crypto_cipher *child = ctx->child;
110         void (*xor)(u8 *, const u8 *, unsigned int bs) = ctx->xor;
111         int err;
112
113         blkcipher_walk_init(&walk, dst, src, nbytes);
114         err = blkcipher_walk_virt(desc, &walk);
115
116         while ((nbytes = walk.nbytes)) {
117                 if (walk.src.virt.addr == walk.dst.virt.addr)
118                         nbytes = crypto_pcbc_encrypt_inplace(desc, &walk, child,
119                                                              xor);
120                 else
121                         nbytes = crypto_pcbc_encrypt_segment(desc, &walk, child,
122                                                              xor);
123                 err = blkcipher_walk_done(desc, &walk, nbytes);
124         }
125
126         return err;
127 }
128
129 static int crypto_pcbc_decrypt_segment(struct blkcipher_desc *desc,
130                                        struct blkcipher_walk *walk,
131                                        struct crypto_cipher *tfm,
132                                        void (*xor)(u8 *, const u8 *,
133                                                    unsigned int))
134 {
135         void (*fn)(struct crypto_tfm *, u8 *, const u8 *) =
136                 crypto_cipher_alg(tfm)->cia_decrypt;
137         int bsize = crypto_cipher_blocksize(tfm);
138         unsigned int nbytes = walk->nbytes;
139         u8 *src = walk->src.virt.addr;
140         u8 *dst = walk->dst.virt.addr;
141         u8 *iv = walk->iv;
142
143         do {
144                 fn(crypto_cipher_tfm(tfm), dst, src);
145                 xor(dst, iv, bsize);
146                 memcpy(iv, src, bsize);
147                 xor(iv, dst, bsize);
148
149                 src += bsize;
150                 dst += bsize;
151         } while ((nbytes -= bsize) >= bsize);
152
153         memcpy(walk->iv, iv, bsize);
154
155         return nbytes;
156 }
157
158 static int crypto_pcbc_decrypt_inplace(struct blkcipher_desc *desc,
159                                        struct blkcipher_walk *walk,
160                                        struct crypto_cipher *tfm,
161                                        void (*xor)(u8 *, const u8 *,
162                                                    unsigned int))
163 {
164         void (*fn)(struct crypto_tfm *, u8 *, const u8 *) =
165                 crypto_cipher_alg(tfm)->cia_decrypt;
166         int bsize = crypto_cipher_blocksize(tfm);
167         unsigned int nbytes = walk->nbytes;
168         u8 *src = walk->src.virt.addr;
169         u8 *iv = walk->iv;
170         u8 tmpbuf[bsize];
171
172         do {
173                 memcpy(tmpbuf, src, bsize);
174                 fn(crypto_cipher_tfm(tfm), src, src);
175                 xor(src, iv, bsize);
176                 memcpy(iv, tmpbuf, bsize);
177                 xor(iv, src, bsize);
178
179                 src += bsize;
180         } while ((nbytes -= bsize) >= bsize);
181
182         memcpy(walk->iv, iv, bsize);
183
184         return nbytes;
185 }
186
187 static int crypto_pcbc_decrypt(struct blkcipher_desc *desc,
188                                struct scatterlist *dst, struct scatterlist *src,
189                                unsigned int nbytes)
190 {
191         struct blkcipher_walk walk;
192         struct crypto_blkcipher *tfm = desc->tfm;
193         struct crypto_pcbc_ctx *ctx = crypto_blkcipher_ctx(tfm);
194         struct crypto_cipher *child = ctx->child;
195         void (*xor)(u8 *, const u8 *, unsigned int bs) = ctx->xor;
196         int err;
197
198         blkcipher_walk_init(&walk, dst, src, nbytes);
199         err = blkcipher_walk_virt(desc, &walk);
200
201         while ((nbytes = walk.nbytes)) {
202                 if (walk.src.virt.addr == walk.dst.virt.addr)
203                         nbytes = crypto_pcbc_decrypt_inplace(desc, &walk, child,
204                                                              xor);
205                 else
206                         nbytes = crypto_pcbc_decrypt_segment(desc, &walk, child,
207                                                              xor);
208                 err = blkcipher_walk_done(desc, &walk, nbytes);
209         }
210
211         return err;
212 }
213
214 static void xor_byte(u8 *a, const u8 *b, unsigned int bs)
215 {
216         do {
217                 *a++ ^= *b++;
218         } while (--bs);
219 }
220
221 static void xor_quad(u8 *dst, const u8 *src, unsigned int bs)
222 {
223         u32 *a = (u32 *)dst;
224         u32 *b = (u32 *)src;
225
226         do {
227                 *a++ ^= *b++;
228         } while ((bs -= 4));
229 }
230
231 static void xor_64(u8 *a, const u8 *b, unsigned int bs)
232 {
233         ((u32 *)a)[0] ^= ((u32 *)b)[0];
234         ((u32 *)a)[1] ^= ((u32 *)b)[1];
235 }
236
237 static void xor_128(u8 *a, const u8 *b, unsigned int bs)
238 {
239         ((u32 *)a)[0] ^= ((u32 *)b)[0];
240         ((u32 *)a)[1] ^= ((u32 *)b)[1];
241         ((u32 *)a)[2] ^= ((u32 *)b)[2];
242         ((u32 *)a)[3] ^= ((u32 *)b)[3];
243 }
244
245 static int crypto_pcbc_init_tfm(struct crypto_tfm *tfm)
246 {
247         struct crypto_instance *inst = (void *)tfm->__crt_alg;
248         struct crypto_spawn *spawn = crypto_instance_ctx(inst);
249         struct crypto_pcbc_ctx *ctx = crypto_tfm_ctx(tfm);
250         struct crypto_cipher *cipher;
251
252         switch (crypto_tfm_alg_blocksize(tfm)) {
253         case 8:
254                 ctx->xor = xor_64;
255                 break;
256
257         case 16:
258                 ctx->xor = xor_128;
259                 break;
260
261         default:
262                 if (crypto_tfm_alg_blocksize(tfm) % 4)
263                         ctx->xor = xor_byte;
264                 else
265                         ctx->xor = xor_quad;
266         }
267
268         cipher = crypto_spawn_cipher(spawn);
269         if (IS_ERR(cipher))
270                 return PTR_ERR(cipher);
271
272         ctx->child = cipher;
273         return 0;
274 }
275
276 static void crypto_pcbc_exit_tfm(struct crypto_tfm *tfm)
277 {
278         struct crypto_pcbc_ctx *ctx = crypto_tfm_ctx(tfm);
279         crypto_free_cipher(ctx->child);
280 }
281
282 static struct crypto_instance *crypto_pcbc_alloc(struct rtattr **tb)
283 {
284         struct crypto_instance *inst;
285         struct crypto_alg *alg;
286         int err;
287
288         err = crypto_check_attr_type(tb, CRYPTO_ALG_TYPE_BLKCIPHER);
289         if (err)
290                 return ERR_PTR(err);
291
292         alg = crypto_get_attr_alg(tb, CRYPTO_ALG_TYPE_CIPHER,
293                                   CRYPTO_ALG_TYPE_MASK);
294         if (IS_ERR(alg))
295                 return ERR_PTR(PTR_ERR(alg));
296
297         inst = crypto_alloc_instance("pcbc", alg);
298         if (IS_ERR(inst))
299                 goto out_put_alg;
300
301         inst->alg.cra_flags = CRYPTO_ALG_TYPE_BLKCIPHER;
302         inst->alg.cra_priority = alg->cra_priority;
303         inst->alg.cra_blocksize = alg->cra_blocksize;
304         inst->alg.cra_alignmask = alg->cra_alignmask;
305         inst->alg.cra_type = &crypto_blkcipher_type;
306
307         if (!(alg->cra_blocksize % 4))
308                 inst->alg.cra_alignmask |= 3;
309         inst->alg.cra_blkcipher.ivsize = alg->cra_blocksize;
310         inst->alg.cra_blkcipher.min_keysize = alg->cra_cipher.cia_min_keysize;
311         inst->alg.cra_blkcipher.max_keysize = alg->cra_cipher.cia_max_keysize;
312
313         inst->alg.cra_ctxsize = sizeof(struct crypto_pcbc_ctx);
314
315         inst->alg.cra_init = crypto_pcbc_init_tfm;
316         inst->alg.cra_exit = crypto_pcbc_exit_tfm;
317
318         inst->alg.cra_blkcipher.setkey = crypto_pcbc_setkey;
319         inst->alg.cra_blkcipher.encrypt = crypto_pcbc_encrypt;
320         inst->alg.cra_blkcipher.decrypt = crypto_pcbc_decrypt;
321
322 out_put_alg:
323         crypto_mod_put(alg);
324         return inst;
325 }
326
/*
 * Template .free hook: drop the instance's spawn, then free the instance
 * itself.
 */
static void crypto_pcbc_free(struct crypto_instance *inst)
{
	struct crypto_spawn *spawn = crypto_instance_ctx(inst);

	crypto_drop_spawn(spawn);
	kfree(inst);
}
332
333 static struct crypto_template crypto_pcbc_tmpl = {
334         .name = "pcbc",
335         .alloc = crypto_pcbc_alloc,
336         .free = crypto_pcbc_free,
337         .module = THIS_MODULE,
338 };
339
340 static int __init crypto_pcbc_module_init(void)
341 {
342         return crypto_register_template(&crypto_pcbc_tmpl);
343 }
344
345 static void __exit crypto_pcbc_module_exit(void)
346 {
347         crypto_unregister_template(&crypto_pcbc_tmpl);
348 }
349
350 module_init(crypto_pcbc_module_init);
351 module_exit(crypto_pcbc_module_exit);
352
353 MODULE_LICENSE("GPL");
354 MODULE_DESCRIPTION("PCBC block cipher algorithm");