rsaenh: Declare some functions static.
[wine] / dlls / rsaenh / mpi.c
1 /*
2  * dlls/rsaenh/mpi.c
3  * Multi Precision Integer functions
4  *
5  * Copyright 2004 Michael Jung
6  * Based on public domain code by Tom St Denis (tomstdenis@iahu.ca)
7  *
8  * This library is free software; you can redistribute it and/or
9  * modify it under the terms of the GNU Lesser General Public
10  * License as published by the Free Software Foundation; either
11  * version 2.1 of the License, or (at your option) any later version.
12  *
13  * This library is distributed in the hope that it will be useful,
14  * but WITHOUT ANY WARRANTY; without even the implied warranty of
15  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
16  * Lesser General Public License for more details.
17  *
18  * You should have received a copy of the GNU Lesser General Public
19  * License along with this library; if not, write to the Free Software
20  * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
21  */
22
23 /*
24  * This file contains code from the LibTomCrypt cryptographic 
25  * library written by Tom St Denis (tomstdenis@iahu.ca). LibTomCrypt
26  * is in the public domain. The code in this file is tailored to
27  * special requirements. Take a look at http://libtomcrypt.org for the
28  * original version. 
29  */
30
31 #include <stdarg.h>
32 #include "tomcrypt.h"
33
34 /* Known optimal configurations
35  CPU                    /Compiler     /MUL CUTOFF/SQR CUTOFF
36 -------------------------------------------------------------
37  Intel P4 Northwood     /GCC v3.4.1   /        88/       128/LTM 0.32 ;-)
38 */
39 static const int KARATSUBA_MUL_CUTOFF = 88,  /* Min. number of digits before Karatsuba multiplication is used. */
40                  KARATSUBA_SQR_CUTOFF = 128; /* Min. number of digits before Karatsuba squaring is used. */
41
42 static void bn_reverse(unsigned char *s, int len);
43 static int s_mp_add(mp_int *a, mp_int *b, mp_int *c);
44 static int s_mp_exptmod (const mp_int * G, const mp_int * X, mp_int * P, mp_int * Y);
45 #define s_mp_mul(a, b, c) s_mp_mul_digs(a, b, c, (a)->used + (b)->used + 1)
46 static int s_mp_mul_digs(const mp_int *a, const mp_int *b, mp_int *c, int digs);
47 static int s_mp_mul_high_digs(const mp_int *a, const mp_int *b, mp_int *c, int digs);
48 static int s_mp_sqr(const mp_int *a, mp_int *b);
49 static int s_mp_sub(const mp_int *a, const mp_int *b, mp_int *c);
50 static int mp_exptmod_fast(const mp_int *G, const mp_int *X, mp_int *P, mp_int *Y, int mode);
51 static int mp_invmod_slow (const mp_int * a, mp_int * b, mp_int * c);
52 static int mp_karatsuba_mul(const mp_int *a, const mp_int *b, mp_int *c);
53 static int mp_karatsuba_sqr(const mp_int *a, mp_int *b);
54
55 /* b = a/2 */
56 static int mp_div_2(const mp_int * a, mp_int * b)
57 {
58   int     x, res, oldused;
59
60   /* copy */
61   if (b->alloc < a->used) {
62     if ((res = mp_grow (b, a->used)) != MP_OKAY) {
63       return res;
64     }
65   }
66
67   oldused = b->used;
68   b->used = a->used;
69   {
70     register mp_digit r, rr, *tmpa, *tmpb;
71
72     /* source alias */
73     tmpa = a->dp + b->used - 1;
74
75     /* dest alias */
76     tmpb = b->dp + b->used - 1;
77
78     /* carry */
79     r = 0;
80     for (x = b->used - 1; x >= 0; x--) {
81       /* get the carry for the next iteration */
82       rr = *tmpa & 1;
83
84       /* shift the current digit, add in carry and store */
85       *tmpb-- = (*tmpa-- >> 1) | (r << (DIGIT_BIT - 1));
86
87       /* forward carry to next iteration */
88       r = rr;
89     }
90
91     /* zero excess digits */
92     tmpb = b->dp + b->used;
93     for (x = b->used; x < oldused; x++) {
94       *tmpb++ = 0;
95     }
96   }
97   b->sign = a->sign;
98   mp_clamp (b);
99   return MP_OKAY;
100 }
101
102 /* computes the modular inverse via binary extended euclidean algorithm, 
103  * that is c = 1/a mod b 
104  *
105  * Based on slow invmod except this is optimized for the case where b is 
106  * odd as per HAC Note 14.64 on pp. 610
107  */
108 static int
109 fast_mp_invmod (const mp_int * a, mp_int * b, mp_int * c)
110 {
111   mp_int  x, y, u, v, B, D;
112   int     res, neg;
113
114   /* 2. [modified] b must be odd   */
115   if (mp_iseven (b) == 1) {
116     return MP_VAL;
117   }
118
119   /* init all our temps */
120   if ((res = mp_init_multi(&x, &y, &u, &v, &B, &D, NULL)) != MP_OKAY) {
121      return res;
122   }
123
124   /* x == modulus, y == value to invert */
125   if ((res = mp_copy (b, &x)) != MP_OKAY) {
126     goto __ERR;
127   }
128
129   /* we need y = |a| */
130   if ((res = mp_abs (a, &y)) != MP_OKAY) {
131     goto __ERR;
132   }
133
134   /* 3. u=x, v=y, A=1, B=0, C=0,D=1 */
135   if ((res = mp_copy (&x, &u)) != MP_OKAY) {
136     goto __ERR;
137   }
138   if ((res = mp_copy (&y, &v)) != MP_OKAY) {
139     goto __ERR;
140   }
141   mp_set (&D, 1);
142
143 top:
144   /* 4.  while u is even do */
145   while (mp_iseven (&u) == 1) {
146     /* 4.1 u = u/2 */
147     if ((res = mp_div_2 (&u, &u)) != MP_OKAY) {
148       goto __ERR;
149     }
150     /* 4.2 if B is odd then */
151     if (mp_isodd (&B) == 1) {
152       if ((res = mp_sub (&B, &x, &B)) != MP_OKAY) {
153         goto __ERR;
154       }
155     }
156     /* B = B/2 */
157     if ((res = mp_div_2 (&B, &B)) != MP_OKAY) {
158       goto __ERR;
159     }
160   }
161
162   /* 5.  while v is even do */
163   while (mp_iseven (&v) == 1) {
164     /* 5.1 v = v/2 */
165     if ((res = mp_div_2 (&v, &v)) != MP_OKAY) {
166       goto __ERR;
167     }
168     /* 5.2 if D is odd then */
169     if (mp_isodd (&D) == 1) {
170       /* D = (D-x)/2 */
171       if ((res = mp_sub (&D, &x, &D)) != MP_OKAY) {
172         goto __ERR;
173       }
174     }
175     /* D = D/2 */
176     if ((res = mp_div_2 (&D, &D)) != MP_OKAY) {
177       goto __ERR;
178     }
179   }
180
181   /* 6.  if u >= v then */
182   if (mp_cmp (&u, &v) != MP_LT) {
183     /* u = u - v, B = B - D */
184     if ((res = mp_sub (&u, &v, &u)) != MP_OKAY) {
185       goto __ERR;
186     }
187
188     if ((res = mp_sub (&B, &D, &B)) != MP_OKAY) {
189       goto __ERR;
190     }
191   } else {
192     /* v - v - u, D = D - B */
193     if ((res = mp_sub (&v, &u, &v)) != MP_OKAY) {
194       goto __ERR;
195     }
196
197     if ((res = mp_sub (&D, &B, &D)) != MP_OKAY) {
198       goto __ERR;
199     }
200   }
201
202   /* if not zero goto step 4 */
203   if (mp_iszero (&u) == 0) {
204     goto top;
205   }
206
207   /* now a = C, b = D, gcd == g*v */
208
209   /* if v != 1 then there is no inverse */
210   if (mp_cmp_d (&v, 1) != MP_EQ) {
211     res = MP_VAL;
212     goto __ERR;
213   }
214
215   /* b is now the inverse */
216   neg = a->sign;
217   while (D.sign == MP_NEG) {
218     if ((res = mp_add (&D, b, &D)) != MP_OKAY) {
219       goto __ERR;
220     }
221   }
222   mp_exch (&D, c);
223   c->sign = neg;
224   res = MP_OKAY;
225
226 __ERR:mp_clear_multi (&x, &y, &u, &v, &B, &D, NULL);
227   return res;
228 }
229
230 /* computes xR**-1 == x (mod N) via Montgomery Reduction
231  *
232  * This is an optimized implementation of montgomery_reduce
233  * which uses the comba method to quickly calculate the columns of the
234  * reduction.
235  *
236  * Based on Algorithm 14.32 on pp.601 of HAC.
237 */
238 static int
239 fast_mp_montgomery_reduce (mp_int * x, const mp_int * n, mp_digit rho)
240 {
241   int     ix, res, olduse;
242   mp_word W[MP_WARRAY];
243
244   /* get old used count */
245   olduse = x->used;
246
247   /* grow a as required */
248   if (x->alloc < n->used + 1) {
249     if ((res = mp_grow (x, n->used + 1)) != MP_OKAY) {
250       return res;
251     }
252   }
253
254   /* first we have to get the digits of the input into
255    * an array of double precision words W[...]
256    */
257   {
258     register mp_word *_W;
259     register mp_digit *tmpx;
260
261     /* alias for the W[] array */
262     _W   = W;
263
264     /* alias for the digits of  x*/
265     tmpx = x->dp;
266
267     /* copy the digits of a into W[0..a->used-1] */
268     for (ix = 0; ix < x->used; ix++) {
269       *_W++ = *tmpx++;
270     }
271
272     /* zero the high words of W[a->used..m->used*2] */
273     for (; ix < n->used * 2 + 1; ix++) {
274       *_W++ = 0;
275     }
276   }
277
278   /* now we proceed to zero successive digits
279    * from the least significant upwards
280    */
281   for (ix = 0; ix < n->used; ix++) {
282     /* mu = ai * m' mod b
283      *
284      * We avoid a double precision multiplication (which isn't required)
285      * by casting the value down to a mp_digit.  Note this requires
286      * that W[ix-1] have  the carry cleared (see after the inner loop)
287      */
288     register mp_digit mu;
289     mu = (mp_digit) (((W[ix] & MP_MASK) * rho) & MP_MASK);
290
291     /* a = a + mu * m * b**i
292      *
293      * This is computed in place and on the fly.  The multiplication
294      * by b**i is handled by offsetting which columns the results
295      * are added to.
296      *
297      * Note the comba method normally doesn't handle carries in the
298      * inner loop In this case we fix the carry from the previous
299      * column since the Montgomery reduction requires digits of the
300      * result (so far) [see above] to work.  This is
301      * handled by fixing up one carry after the inner loop.  The
302      * carry fixups are done in order so after these loops the
303      * first m->used words of W[] have the carries fixed
304      */
305     {
306       register int iy;
307       register mp_digit *tmpn;
308       register mp_word *_W;
309
310       /* alias for the digits of the modulus */
311       tmpn = n->dp;
312
313       /* Alias for the columns set by an offset of ix */
314       _W = W + ix;
315
316       /* inner loop */
317       for (iy = 0; iy < n->used; iy++) {
318           *_W++ += ((mp_word)mu) * ((mp_word)*tmpn++);
319       }
320     }
321
322     /* now fix carry for next digit, W[ix+1] */
323     W[ix + 1] += W[ix] >> ((mp_word) DIGIT_BIT);
324   }
325
326   /* now we have to propagate the carries and
327    * shift the words downward [all those least
328    * significant digits we zeroed].
329    */
330   {
331     register mp_digit *tmpx;
332     register mp_word *_W, *_W1;
333
334     /* nox fix rest of carries */
335
336     /* alias for current word */
337     _W1 = W + ix;
338
339     /* alias for next word, where the carry goes */
340     _W = W + ++ix;
341
342     for (; ix <= n->used * 2 + 1; ix++) {
343       *_W++ += *_W1++ >> ((mp_word) DIGIT_BIT);
344     }
345
346     /* copy out, A = A/b**n
347      *
348      * The result is A/b**n but instead of converting from an
349      * array of mp_word to mp_digit than calling mp_rshd
350      * we just copy them in the right order
351      */
352
353     /* alias for destination word */
354     tmpx = x->dp;
355
356     /* alias for shifted double precision result */
357     _W = W + n->used;
358
359     for (ix = 0; ix < n->used + 1; ix++) {
360       *tmpx++ = (mp_digit)(*_W++ & ((mp_word) MP_MASK));
361     }
362
363     /* zero oldused digits, if the input a was larger than
364      * m->used+1 we'll have to clear the digits
365      */
366     for (; ix < olduse; ix++) {
367       *tmpx++ = 0;
368     }
369   }
370
371   /* set the max used and clamp */
372   x->used = n->used + 1;
373   mp_clamp (x);
374
375   /* if A >= m then A = A - m */
376   if (mp_cmp_mag (x, n) != MP_LT) {
377     return s_mp_sub (x, n, x);
378   }
379   return MP_OKAY;
380 }
381
382 /* Fast (comba) multiplier
383  *
384  * This is the fast column-array [comba] multiplier.  It is 
385  * designed to compute the columns of the product first 
386  * then handle the carries afterwards.  This has the effect 
387  * of making the nested loops that compute the columns very
388  * simple and schedulable on super-scalar processors.
389  *
390  * This has been modified to produce a variable number of 
391  * digits of output so if say only a half-product is required 
392  * you don't have to compute the upper half (a feature 
393  * required for fast Barrett reduction).
394  *
395  * Based on Algorithm 14.12 on pp.595 of HAC.
396  *
397  */
398 static int
399 fast_s_mp_mul_digs (const mp_int * a, const mp_int * b, mp_int * c, int digs)
400 {
401   int     olduse, res, pa, ix, iz;
402   mp_digit W[MP_WARRAY];
403   register mp_word  _W;
404
405   /* grow the destination as required */
406   if (c->alloc < digs) {
407     if ((res = mp_grow (c, digs)) != MP_OKAY) {
408       return res;
409     }
410   }
411
412   /* number of output digits to produce */
413   pa = MIN(digs, a->used + b->used);
414
415   /* clear the carry */
416   _W = 0;
417   for (ix = 0; ix <= pa; ix++) { 
418       int      tx, ty;
419       int      iy;
420       mp_digit *tmpx, *tmpy;
421
422       /* get offsets into the two bignums */
423       ty = MIN(b->used-1, ix);
424       tx = ix - ty;
425
426       /* setup temp aliases */
427       tmpx = a->dp + tx;
428       tmpy = b->dp + ty;
429
430       /* This is the number of times the loop will iterate, essentially it's
431          while (tx++ < a->used && ty-- >= 0) { ... }
432        */
433       iy = MIN(a->used-tx, ty+1);
434
435       /* execute loop */
436       for (iz = 0; iz < iy; ++iz) {
437          _W += ((mp_word)*tmpx++)*((mp_word)*tmpy--);
438       }
439
440       /* store term */
441       W[ix] = ((mp_digit)_W) & MP_MASK;
442
443       /* make next carry */
444       _W = _W >> ((mp_word)DIGIT_BIT);
445   }
446
447   /* setup dest */
448   olduse  = c->used;
449   c->used = digs;
450
451   {
452     register mp_digit *tmpc;
453     tmpc = c->dp;
454     for (ix = 0; ix < digs; ix++) {
455       /* now extract the previous digit [below the carry] */
456       *tmpc++ = W[ix];
457     }
458
459     /* clear unused digits [that existed in the old copy of c] */
460     for (; ix < olduse; ix++) {
461       *tmpc++ = 0;
462     }
463   }
464   mp_clamp (c);
465   return MP_OKAY;
466 }
467
468 /* this is a modified version of fast_s_mul_digs that only produces
469  * output digits *above* digs.  See the comments for fast_s_mul_digs
470  * to see how it works.
471  *
472  * This is used in the Barrett reduction since for one of the multiplications
473  * only the higher digits were needed.  This essentially halves the work.
474  *
475  * Based on Algorithm 14.12 on pp.595 of HAC.
476  */
477 static int
478 fast_s_mp_mul_high_digs (const mp_int * a, const mp_int * b, mp_int * c, int digs)
479 {
480   int     olduse, res, pa, ix, iz;
481   mp_digit W[MP_WARRAY];
482   mp_word  _W;
483
484   /* grow the destination as required */
485   pa = a->used + b->used;
486   if (c->alloc < pa) {
487     if ((res = mp_grow (c, pa)) != MP_OKAY) {
488       return res;
489     }
490   }
491
492   /* number of output digits to produce */
493   pa = a->used + b->used;
494   _W = 0;
495   for (ix = digs; ix <= pa; ix++) { 
496       int      tx, ty, iy;
497       mp_digit *tmpx, *tmpy;
498
499       /* get offsets into the two bignums */
500       ty = MIN(b->used-1, ix);
501       tx = ix - ty;
502
503       /* setup temp aliases */
504       tmpx = a->dp + tx;
505       tmpy = b->dp + ty;
506
507       /* This is the number of times the loop will iterate, essentially it's
508          while (tx++ < a->used && ty-- >= 0) { ... }
509        */
510       iy = MIN(a->used-tx, ty+1);
511
512       /* execute loop */
513       for (iz = 0; iz < iy; iz++) {
514          _W += ((mp_word)*tmpx++)*((mp_word)*tmpy--);
515       }
516
517       /* store term */
518       W[ix] = ((mp_digit)_W) & MP_MASK;
519
520       /* make next carry */
521       _W = _W >> ((mp_word)DIGIT_BIT);
522   }
523
524   /* setup dest */
525   olduse  = c->used;
526   c->used = pa;
527
528   {
529     register mp_digit *tmpc;
530
531     tmpc = c->dp + digs;
532     for (ix = digs; ix <= pa; ix++) {
533       /* now extract the previous digit [below the carry] */
534       *tmpc++ = W[ix];
535     }
536
537     /* clear unused digits [that existed in the old copy of c] */
538     for (; ix < olduse; ix++) {
539       *tmpc++ = 0;
540     }
541   }
542   mp_clamp (c);
543   return MP_OKAY;
544 }
545
546 /* fast squaring
547  *
548  * This is the comba method where the columns of the product
549  * are computed first then the carries are computed.  This
550  * has the effect of making a very simple inner loop that
551  * is executed the most
552  *
553  * W2 represents the outer products and W the inner.
554  *
555  * A further optimizations is made because the inner
556  * products are of the form "A * B * 2".  The *2 part does
557  * not need to be computed until the end which is good
558  * because 64-bit shifts are slow!
559  *
560  * Based on Algorithm 14.16 on pp.597 of HAC.
561  *
562  */
563 /* the jist of squaring...
564
565 you do like mult except the offset of the tmpx [one that starts closer to zero]
566 can't equal the offset of tmpy.  So basically you set up iy like before then you min it with
567 (ty-tx) so that it never happens.  You double all those you add in the inner loop
568
569 After that loop you do the squares and add them in.
570
571 Remove W2 and don't memset W
572
573 */
574
575 static int fast_s_mp_sqr (const mp_int * a, mp_int * b)
576 {
577   int       olduse, res, pa, ix, iz;
578   mp_digit   W[MP_WARRAY], *tmpx;
579   mp_word   W1;
580
581   /* grow the destination as required */
582   pa = a->used + a->used;
583   if (b->alloc < pa) {
584     if ((res = mp_grow (b, pa)) != MP_OKAY) {
585       return res;
586     }
587   }
588
589   /* number of output digits to produce */
590   W1 = 0;
591   for (ix = 0; ix <= pa; ix++) { 
592       int      tx, ty, iy;
593       mp_word  _W;
594       mp_digit *tmpy;
595
596       /* clear counter */
597       _W = 0;
598
599       /* get offsets into the two bignums */
600       ty = MIN(a->used-1, ix);
601       tx = ix - ty;
602
603       /* setup temp aliases */
604       tmpx = a->dp + tx;
605       tmpy = a->dp + ty;
606
607       /* This is the number of times the loop will iterate, essentially it's
608          while (tx++ < a->used && ty-- >= 0) { ... }
609        */
610       iy = MIN(a->used-tx, ty+1);
611
612       /* now for squaring tx can never equal ty 
613        * we halve the distance since they approach at a rate of 2x
614        * and we have to round because odd cases need to be executed
615        */
616       iy = MIN(iy, (ty-tx+1)>>1);
617
618       /* execute loop */
619       for (iz = 0; iz < iy; iz++) {
620          _W += ((mp_word)*tmpx++)*((mp_word)*tmpy--);
621       }
622
623       /* double the inner product and add carry */
624       _W = _W + _W + W1;
625
626       /* even columns have the square term in them */
627       if ((ix&1) == 0) {
628          _W += ((mp_word)a->dp[ix>>1])*((mp_word)a->dp[ix>>1]);
629       }
630
631       /* store it */
632       W[ix] = _W;
633
634       /* make next carry */
635       W1 = _W >> ((mp_word)DIGIT_BIT);
636   }
637
638   /* setup dest */
639   olduse  = b->used;
640   b->used = a->used+a->used;
641
642   {
643     mp_digit *tmpb;
644     tmpb = b->dp;
645     for (ix = 0; ix < pa; ix++) {
646       *tmpb++ = W[ix] & MP_MASK;
647     }
648
649     /* clear unused digits [that existed in the old copy of c] */
650     for (; ix < olduse; ix++) {
651       *tmpb++ = 0;
652     }
653   }
654   mp_clamp (b);
655   return MP_OKAY;
656 }
657
658 /* computes a = 2**b 
659  *
660  * Simple algorithm which zeroes the int, grows it then just sets one bit
661  * as required.
662  */
663 int
664 mp_2expt (mp_int * a, int b)
665 {
666   int     res;
667
668   /* zero a as per default */
669   mp_zero (a);
670
671   /* grow a to accommodate the single bit */
672   if ((res = mp_grow (a, b / DIGIT_BIT + 1)) != MP_OKAY) {
673     return res;
674   }
675
676   /* set the used count of where the bit will go */
677   a->used = b / DIGIT_BIT + 1;
678
679   /* put the single bit in its place */
680   a->dp[b / DIGIT_BIT] = ((mp_digit)1) << (b % DIGIT_BIT);
681
682   return MP_OKAY;
683 }
684
685 /* b = |a| 
686  *
687  * Simple function copies the input and fixes the sign to positive
688  */
689 int
690 mp_abs (const mp_int * a, mp_int * b)
691 {
692   int     res;
693
694   /* copy a to b */
695   if (a != b) {
696      if ((res = mp_copy (a, b)) != MP_OKAY) {
697        return res;
698      }
699   }
700
701   /* force the sign of b to positive */
702   b->sign = MP_ZPOS;
703
704   return MP_OKAY;
705 }
706
707 /* high level addition (handles signs) */
708 int mp_add (mp_int * a, mp_int * b, mp_int * c)
709 {
710   int     sa, sb, res;
711
712   /* get sign of both inputs */
713   sa = a->sign;
714   sb = b->sign;
715
716   /* handle two cases, not four */
717   if (sa == sb) {
718     /* both positive or both negative */
719     /* add their magnitudes, copy the sign */
720     c->sign = sa;
721     res = s_mp_add (a, b, c);
722   } else {
723     /* one positive, the other negative */
724     /* subtract the one with the greater magnitude from */
725     /* the one of the lesser magnitude.  The result gets */
726     /* the sign of the one with the greater magnitude. */
727     if (mp_cmp_mag (a, b) == MP_LT) {
728       c->sign = sb;
729       res = s_mp_sub (b, a, c);
730     } else {
731       c->sign = sa;
732       res = s_mp_sub (a, b, c);
733     }
734   }
735   return res;
736 }
737
738
739 /* single digit addition */
740 int
741 mp_add_d (mp_int * a, mp_digit b, mp_int * c)
742 {
743   int     res, ix, oldused;
744   mp_digit *tmpa, *tmpc, mu;
745
746   /* grow c as required */
747   if (c->alloc < a->used + 1) {
748      if ((res = mp_grow(c, a->used + 1)) != MP_OKAY) {
749         return res;
750      }
751   }
752
753   /* if a is negative and |a| >= b, call c = |a| - b */
754   if (a->sign == MP_NEG && (a->used > 1 || a->dp[0] >= b)) {
755      /* temporarily fix sign of a */
756      a->sign = MP_ZPOS;
757
758      /* c = |a| - b */
759      res = mp_sub_d(a, b, c);
760
761      /* fix sign  */
762      a->sign = c->sign = MP_NEG;
763
764      return res;
765   }
766
767   /* old number of used digits in c */
768   oldused = c->used;
769
770   /* sign always positive */
771   c->sign = MP_ZPOS;
772
773   /* source alias */
774   tmpa    = a->dp;
775
776   /* destination alias */
777   tmpc    = c->dp;
778
779   /* if a is positive */
780   if (a->sign == MP_ZPOS) {
781      /* add digit, after this we're propagating
782       * the carry.
783       */
784      *tmpc   = *tmpa++ + b;
785      mu      = *tmpc >> DIGIT_BIT;
786      *tmpc++ &= MP_MASK;
787
788      /* now handle rest of the digits */
789      for (ix = 1; ix < a->used; ix++) {
790         *tmpc   = *tmpa++ + mu;
791         mu      = *tmpc >> DIGIT_BIT;
792         *tmpc++ &= MP_MASK;
793      }
794      /* set final carry */
795      ix++;
796      *tmpc++  = mu;
797
798      /* setup size */
799      c->used = a->used + 1;
800   } else {
801      /* a was negative and |a| < b */
802      c->used  = 1;
803
804      /* the result is a single digit */
805      if (a->used == 1) {
806         *tmpc++  =  b - a->dp[0];
807      } else {
808         *tmpc++  =  b;
809      }
810
811      /* setup count so the clearing of oldused
812       * can fall through correctly
813       */
814      ix       = 1;
815   }
816
817   /* now zero to oldused */
818   while (ix++ < oldused) {
819      *tmpc++ = 0;
820   }
821   mp_clamp(c);
822
823   return MP_OKAY;
824 }
825
826 /* trim unused digits 
827  *
828  * This is used to ensure that leading zero digits are
829  * trimed and the leading "used" digit will be non-zero
830  * Typically very fast.  Also fixes the sign if there
831  * are no more leading digits
832  */
833 void
834 mp_clamp (mp_int * a)
835 {
836   /* decrease used while the most significant digit is
837    * zero.
838    */
839   while (a->used > 0 && a->dp[a->used - 1] == 0) {
840     --(a->used);
841   }
842
843   /* reset the sign flag if used == 0 */
844   if (a->used == 0) {
845     a->sign = MP_ZPOS;
846   }
847 }
848
849 /* clear one (frees)  */
850 void
851 mp_clear (mp_int * a)
852 {
853   int i;
854
855   /* only do anything if a hasn't been freed previously */
856   if (a->dp != NULL) {
857     /* first zero the digits */
858     for (i = 0; i < a->used; i++) {
859         a->dp[i] = 0;
860     }
861
862     /* free ram */
863     free(a->dp);
864
865     /* reset members to make debugging easier */
866     a->dp    = NULL;
867     a->alloc = a->used = 0;
868     a->sign  = MP_ZPOS;
869   }
870 }
871
872
873 void mp_clear_multi(mp_int *mp, ...) 
874 {
875     mp_int* next_mp = mp;
876     va_list args;
877     va_start(args, mp);
878     while (next_mp != NULL) {
879         mp_clear(next_mp);
880         next_mp = va_arg(args, mp_int*);
881     }
882     va_end(args);
883 }
884
885 /* compare two ints (signed)*/
886 int
887 mp_cmp (const mp_int * a, const mp_int * b)
888 {
889   /* compare based on sign */
890   if (a->sign != b->sign) {
891      if (a->sign == MP_NEG) {
892         return MP_LT;
893      } else {
894         return MP_GT;
895      }
896   }
897   
898   /* compare digits */
899   if (a->sign == MP_NEG) {
900      /* if negative compare opposite direction */
901      return mp_cmp_mag(b, a);
902   } else {
903      return mp_cmp_mag(a, b);
904   }
905 }
906
907 /* compare a digit */
908 int mp_cmp_d(const mp_int * a, mp_digit b)
909 {
910   /* compare based on sign */
911   if (a->sign == MP_NEG) {
912     return MP_LT;
913   }
914
915   /* compare based on magnitude */
916   if (a->used > 1) {
917     return MP_GT;
918   }
919
920   /* compare the only digit of a to b */
921   if (a->dp[0] > b) {
922     return MP_GT;
923   } else if (a->dp[0] < b) {
924     return MP_LT;
925   } else {
926     return MP_EQ;
927   }
928 }
929
930 /* compare maginitude of two ints (unsigned) */
931 int mp_cmp_mag (const mp_int * a, const mp_int * b)
932 {
933   int     n;
934   mp_digit *tmpa, *tmpb;
935
936   /* compare based on # of non-zero digits */
937   if (a->used > b->used) {
938     return MP_GT;
939   }
940   
941   if (a->used < b->used) {
942     return MP_LT;
943   }
944
945   /* alias for a */
946   tmpa = a->dp + (a->used - 1);
947
948   /* alias for b */
949   tmpb = b->dp + (a->used - 1);
950
951   /* compare based on digits  */
952   for (n = 0; n < a->used; ++n, --tmpa, --tmpb) {
953     if (*tmpa > *tmpb) {
954       return MP_GT;
955     }
956
957     if (*tmpa < *tmpb) {
958       return MP_LT;
959     }
960   }
961   return MP_EQ;
962 }
963
964 static const int lnz[16] = { 
965    4, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0
966 };
967
968 /* Counts the number of lsbs which are zero before the first zero bit */
969 int mp_cnt_lsb(const mp_int *a)
970 {
971    int x;
972    mp_digit q, qq;
973
974    /* easy out */
975    if (mp_iszero(a) == 1) {
976       return 0;
977    }
978
979    /* scan lower digits until non-zero */
980    for (x = 0; x < a->used && a->dp[x] == 0; x++);
981    q = a->dp[x];
982    x *= DIGIT_BIT;
983
984    /* now scan this digit until a 1 is found */
985    if ((q & 1) == 0) {
986       do {
987          qq  = q & 15;
988          x  += lnz[qq];
989          q >>= 4;
990       } while (qq == 0);
991    }
992    return x;
993 }
994
995 /* copy, b = a */
996 int
997 mp_copy (const mp_int * a, mp_int * b)
998 {
999   int     res, n;
1000
1001   /* if dst == src do nothing */
1002   if (a == b) {
1003     return MP_OKAY;
1004   }
1005
1006   /* grow dest */
1007   if (b->alloc < a->used) {
1008      if ((res = mp_grow (b, a->used)) != MP_OKAY) {
1009         return res;
1010      }
1011   }
1012
1013   /* zero b and copy the parameters over */
1014   {
1015     register mp_digit *tmpa, *tmpb;
1016
1017     /* pointer aliases */
1018
1019     /* source */
1020     tmpa = a->dp;
1021
1022     /* destination */
1023     tmpb = b->dp;
1024
1025     /* copy all the digits */
1026     for (n = 0; n < a->used; n++) {
1027       *tmpb++ = *tmpa++;
1028     }
1029
1030     /* clear high digits */
1031     for (; n < b->used; n++) {
1032       *tmpb++ = 0;
1033     }
1034   }
1035
1036   /* copy used count and sign */
1037   b->used = a->used;
1038   b->sign = a->sign;
1039   return MP_OKAY;
1040 }
1041
1042 /* returns the number of bits in an int */
1043 int
1044 mp_count_bits (const mp_int * a)
1045 {
1046   int     r;
1047   mp_digit q;
1048
1049   /* shortcut */
1050   if (a->used == 0) {
1051     return 0;
1052   }
1053
1054   /* get number of digits and add that */
1055   r = (a->used - 1) * DIGIT_BIT;
1056   
1057   /* take the last digit and count the bits in it */
1058   q = a->dp[a->used - 1];
1059   while (q > 0) {
1060     ++r;
1061     q >>= ((mp_digit) 1);
1062   }
1063   return r;
1064 }
1065
1066 /* shift right by a certain bit count (store quotient in c, optional remainder in d) */
1067 static int mp_div_2d (const mp_int * a, int b, mp_int * c, mp_int * d)
1068 {
1069   mp_digit D, r, rr;
1070   int     x, res;
1071   mp_int  t;
1072
1073
1074   /* if the shift count is <= 0 then we do no work */
1075   if (b <= 0) {
1076     res = mp_copy (a, c);
1077     if (d != NULL) {
1078       mp_zero (d);
1079     }
1080     return res;
1081   }
1082
1083   if ((res = mp_init (&t)) != MP_OKAY) {
1084     return res;
1085   }
1086
1087   /* get the remainder */
1088   if (d != NULL) {
1089     if ((res = mp_mod_2d (a, b, &t)) != MP_OKAY) {
1090       mp_clear (&t);
1091       return res;
1092     }
1093   }
1094
1095   /* copy */
1096   if ((res = mp_copy (a, c)) != MP_OKAY) {
1097     mp_clear (&t);
1098     return res;
1099   }
1100
1101   /* shift by as many digits in the bit count */
1102   if (b >= DIGIT_BIT) {
1103     mp_rshd (c, b / DIGIT_BIT);
1104   }
1105
1106   /* shift any bit count < DIGIT_BIT */
1107   D = (mp_digit) (b % DIGIT_BIT);
1108   if (D != 0) {
1109     register mp_digit *tmpc, mask, shift;
1110
1111     /* mask */
1112     mask = (((mp_digit)1) << D) - 1;
1113
1114     /* shift for lsb */
1115     shift = DIGIT_BIT - D;
1116
1117     /* alias */
1118     tmpc = c->dp + (c->used - 1);
1119
1120     /* carry */
1121     r = 0;
1122     for (x = c->used - 1; x >= 0; x--) {
1123       /* get the lower  bits of this word in a temp */
1124       rr = *tmpc & mask;
1125
1126       /* shift the current word and mix in the carry bits from the previous word */
1127       *tmpc = (*tmpc >> D) | (r << shift);
1128       --tmpc;
1129
1130       /* set the carry to the carry bits of the current word found above */
1131       r = rr;
1132     }
1133   }
1134   mp_clamp (c);
1135   if (d != NULL) {
1136     mp_exch (&t, d);
1137   }
1138   mp_clear (&t);
1139   return MP_OKAY;
1140 }
1141
1142 /* integer signed division. 
1143  * c*b + d == a [e.g. a/b, c=quotient, d=remainder]
1144  * HAC pp.598 Algorithm 14.20
1145  *
1146  * Note that the description in HAC is horribly 
1147  * incomplete.  For example, it doesn't consider 
1148  * the case where digits are removed from 'x' in 
1149  * the inner loop.  It also doesn't consider the 
1150  * case that y has fewer than three digits, etc..
1151  *
1152  * The overall algorithm is as described as 
1153  * 14.20 from HAC but fixed to treat these cases.
1154 */
1155 static int mp_div (const mp_int * a, const mp_int * b, mp_int * c, mp_int * d)
1156 {
1157   mp_int  q, x, y, t1, t2;
1158   int     res, n, t, i, norm, neg;
1159
1160   /* is divisor zero ? */
1161   if (mp_iszero (b) == 1) {
1162     return MP_VAL;
1163   }
1164
1165   /* if a < b then q=0, r = a */
1166   if (mp_cmp_mag (a, b) == MP_LT) {
1167     if (d != NULL) {
1168       res = mp_copy (a, d);
1169     } else {
1170       res = MP_OKAY;
1171     }
1172     if (c != NULL) {
1173       mp_zero (c);
1174     }
1175     return res;
1176   }
1177
1178   if ((res = mp_init_size (&q, a->used + 2)) != MP_OKAY) {
1179     return res;
1180   }
1181   q.used = a->used + 2;
1182
1183   if ((res = mp_init (&t1)) != MP_OKAY) {
1184     goto __Q;
1185   }
1186
1187   if ((res = mp_init (&t2)) != MP_OKAY) {
1188     goto __T1;
1189   }
1190
1191   if ((res = mp_init_copy (&x, a)) != MP_OKAY) {
1192     goto __T2;
1193   }
1194
1195   if ((res = mp_init_copy (&y, b)) != MP_OKAY) {
1196     goto __X;
1197   }
1198
1199   /* fix the sign */
1200   neg = (a->sign == b->sign) ? MP_ZPOS : MP_NEG;
1201   x.sign = y.sign = MP_ZPOS;
1202
1203   /* normalize both x and y, ensure that y >= b/2, [b == 2**DIGIT_BIT] */
1204   norm = mp_count_bits(&y) % DIGIT_BIT;
1205   if (norm < DIGIT_BIT-1) {
1206      norm = (DIGIT_BIT-1) - norm;
1207      if ((res = mp_mul_2d (&x, norm, &x)) != MP_OKAY) {
1208        goto __Y;
1209      }
1210      if ((res = mp_mul_2d (&y, norm, &y)) != MP_OKAY) {
1211        goto __Y;
1212      }
1213   } else {
1214      norm = 0;
1215   }
1216
1217   /* note hac does 0 based, so if used==5 then its 0,1,2,3,4, e.g. use 4 */
1218   n = x.used - 1;
1219   t = y.used - 1;
1220
1221   /* while (x >= y*b**n-t) do { q[n-t] += 1; x -= y*b**{n-t} } */
1222   if ((res = mp_lshd (&y, n - t)) != MP_OKAY) { /* y = y*b**{n-t} */
1223     goto __Y;
1224   }
1225
1226   while (mp_cmp (&x, &y) != MP_LT) {
1227     ++(q.dp[n - t]);
1228     if ((res = mp_sub (&x, &y, &x)) != MP_OKAY) {
1229       goto __Y;
1230     }
1231   }
1232
1233   /* reset y by shifting it back down */
1234   mp_rshd (&y, n - t);
1235
1236   /* step 3. for i from n down to (t + 1) */
1237   for (i = n; i >= (t + 1); i--) {
1238     if (i > x.used) {
1239       continue;
1240     }
1241
1242     /* step 3.1 if xi == yt then set q{i-t-1} to b-1, 
1243      * otherwise set q{i-t-1} to (xi*b + x{i-1})/yt */
1244     if (x.dp[i] == y.dp[t]) {
1245       q.dp[i - t - 1] = ((((mp_digit)1) << DIGIT_BIT) - 1);
1246     } else {
1247       mp_word tmp;
1248       tmp = ((mp_word) x.dp[i]) << ((mp_word) DIGIT_BIT);
1249       tmp |= ((mp_word) x.dp[i - 1]);
1250       tmp /= ((mp_word) y.dp[t]);
1251       if (tmp > (mp_word) MP_MASK)
1252         tmp = MP_MASK;
1253       q.dp[i - t - 1] = (mp_digit) (tmp & (mp_word) (MP_MASK));
1254     }
1255
1256     /* while (q{i-t-1} * (yt * b + y{t-1})) > 
1257              xi * b**2 + xi-1 * b + xi-2 
1258      
1259        do q{i-t-1} -= 1; 
1260     */
1261     q.dp[i - t - 1] = (q.dp[i - t - 1] + 1) & MP_MASK;
1262     do {
1263       q.dp[i - t - 1] = (q.dp[i - t - 1] - 1) & MP_MASK;
1264
1265       /* find left hand */
1266       mp_zero (&t1);
1267       t1.dp[0] = (t - 1 < 0) ? 0 : y.dp[t - 1];
1268       t1.dp[1] = y.dp[t];
1269       t1.used = 2;
1270       if ((res = mp_mul_d (&t1, q.dp[i - t - 1], &t1)) != MP_OKAY) {
1271         goto __Y;
1272       }
1273
1274       /* find right hand */
1275       t2.dp[0] = (i - 2 < 0) ? 0 : x.dp[i - 2];
1276       t2.dp[1] = (i - 1 < 0) ? 0 : x.dp[i - 1];
1277       t2.dp[2] = x.dp[i];
1278       t2.used = 3;
1279     } while (mp_cmp_mag(&t1, &t2) == MP_GT);
1280
1281     /* step 3.3 x = x - q{i-t-1} * y * b**{i-t-1} */
1282     if ((res = mp_mul_d (&y, q.dp[i - t - 1], &t1)) != MP_OKAY) {
1283       goto __Y;
1284     }
1285
1286     if ((res = mp_lshd (&t1, i - t - 1)) != MP_OKAY) {
1287       goto __Y;
1288     }
1289
1290     if ((res = mp_sub (&x, &t1, &x)) != MP_OKAY) {
1291       goto __Y;
1292     }
1293
1294     /* if x < 0 then { x = x + y*b**{i-t-1}; q{i-t-1} -= 1; } */
1295     if (x.sign == MP_NEG) {
1296       if ((res = mp_copy (&y, &t1)) != MP_OKAY) {
1297         goto __Y;
1298       }
1299       if ((res = mp_lshd (&t1, i - t - 1)) != MP_OKAY) {
1300         goto __Y;
1301       }
1302       if ((res = mp_add (&x, &t1, &x)) != MP_OKAY) {
1303         goto __Y;
1304       }
1305
1306       q.dp[i - t - 1] = (q.dp[i - t - 1] - 1UL) & MP_MASK;
1307     }
1308   }
1309
1310   /* now q is the quotient and x is the remainder 
1311    * [which we have to normalize] 
1312    */
1313   
1314   /* get sign before writing to c */
1315   x.sign = x.used == 0 ? MP_ZPOS : a->sign;
1316
1317   if (c != NULL) {
1318     mp_clamp (&q);
1319     mp_exch (&q, c);
1320     c->sign = neg;
1321   }
1322
1323   if (d != NULL) {
1324     mp_div_2d (&x, norm, &x, NULL);
1325     mp_exch (&x, d);
1326   }
1327
1328   res = MP_OKAY;
1329
1330 __Y:mp_clear (&y);
1331 __X:mp_clear (&x);
1332 __T2:mp_clear (&t2);
1333 __T1:mp_clear (&t1);
1334 __Q:mp_clear (&q);
1335   return res;
1336 }
1337
1338 static int s_is_power_of_two(mp_digit b, int *p)
1339 {
1340    int x;
1341
1342    for (x = 1; x < DIGIT_BIT; x++) {
1343       if (b == (((mp_digit)1)<<x)) {
1344          *p = x;
1345          return 1;
1346       }
1347    }
1348    return 0;
1349 }
1350
1351 /* single digit division (based on routine from MPI) */
1352 static int mp_div_d (const mp_int * a, mp_digit b, mp_int * c, mp_digit * d)
1353 {
1354   mp_int  q;
1355   mp_word w;
1356   mp_digit t;
1357   int     res, ix;
1358
1359   /* cannot divide by zero */
1360   if (b == 0) {
1361      return MP_VAL;
1362   }
1363
1364   /* quick outs */
1365   if (b == 1 || mp_iszero(a) == 1) {
1366      if (d != NULL) {
1367         *d = 0;
1368      }
1369      if (c != NULL) {
1370         return mp_copy(a, c);
1371      }
1372      return MP_OKAY;
1373   }
1374
1375   /* power of two ? */
1376   if (s_is_power_of_two(b, &ix) == 1) {
1377      if (d != NULL) {
1378         *d = a->dp[0] & ((((mp_digit)1)<<ix) - 1);
1379      }
1380      if (c != NULL) {
1381         return mp_div_2d(a, ix, c, NULL);
1382      }
1383      return MP_OKAY;
1384   }
1385
1386   /* no easy answer [c'est la vie].  Just division */
1387   if ((res = mp_init_size(&q, a->used)) != MP_OKAY) {
1388      return res;
1389   }
1390   
1391   q.used = a->used;
1392   q.sign = a->sign;
1393   w = 0;
1394   for (ix = a->used - 1; ix >= 0; ix--) {
1395      w = (w << ((mp_word)DIGIT_BIT)) | ((mp_word)a->dp[ix]);
1396      
1397      if (w >= b) {
1398         t = (mp_digit)(w / b);
1399         w -= ((mp_word)t) * ((mp_word)b);
1400       } else {
1401         t = 0;
1402       }
1403       q.dp[ix] = t;
1404   }
1405
1406   if (d != NULL) {
1407      *d = (mp_digit)w;
1408   }
1409   
1410   if (c != NULL) {
1411      mp_clamp(&q);
1412      mp_exch(&q, c);
1413   }
1414   mp_clear(&q);
1415   
1416   return res;
1417 }
1418
1419 /* reduce "x" in place modulo "n" using the Diminished Radix algorithm.
1420  *
1421  * Based on algorithm from the paper
1422  *
1423  * "Generating Efficient Primes for Discrete Log Cryptosystems"
1424  *                 Chae Hoon Lim, Pil Loong Lee,
1425  *          POSTECH Information Research Laboratories
1426  *
1427  * The modulus must be of a special format [see manual]
1428  *
1429  * Has been modified to use algorithm 7.10 from the LTM book instead
1430  *
1431  * Input x must be in the range 0 <= x <= (n-1)**2
1432  */
1433 static int
1434 mp_dr_reduce (mp_int * x, const mp_int * n, mp_digit k)
1435 {
1436   int      err, i, m;
1437   mp_word  r;
1438   mp_digit mu, *tmpx1, *tmpx2;
1439
1440   /* m = digits in modulus */
1441   m = n->used;
1442
1443   /* ensure that "x" has at least 2m digits */
1444   if (x->alloc < m + m) {
1445     if ((err = mp_grow (x, m + m)) != MP_OKAY) {
1446       return err;
1447     }
1448   }
1449
1450 /* top of loop, this is where the code resumes if
1451  * another reduction pass is required.
1452  */
1453 top:
1454   /* aliases for digits */
1455   /* alias for lower half of x */
1456   tmpx1 = x->dp;
1457
1458   /* alias for upper half of x, or x/B**m */
1459   tmpx2 = x->dp + m;
1460
1461   /* set carry to zero */
1462   mu = 0;
1463
1464   /* compute (x mod B**m) + k * [x/B**m] inline and inplace */
1465   for (i = 0; i < m; i++) {
1466       r         = ((mp_word)*tmpx2++) * ((mp_word)k) + *tmpx1 + mu;
1467       *tmpx1++  = (mp_digit)(r & MP_MASK);
1468       mu        = (mp_digit)(r >> ((mp_word)DIGIT_BIT));
1469   }
1470
1471   /* set final carry */
1472   *tmpx1++ = mu;
1473
1474   /* zero words above m */
1475   for (i = m + 1; i < x->used; i++) {
1476       *tmpx1++ = 0;
1477   }
1478
1479   /* clamp, sub and return */
1480   mp_clamp (x);
1481
1482   /* if x >= n then subtract and reduce again
1483    * Each successive "recursion" makes the input smaller and smaller.
1484    */
1485   if (mp_cmp_mag (x, n) != MP_LT) {
1486     s_mp_sub(x, n, x);
1487     goto top;
1488   }
1489   return MP_OKAY;
1490 }
1491
1492 /* sets the value of "d" required for mp_dr_reduce */
1493 static void mp_dr_setup(const mp_int *a, mp_digit *d)
1494 {
1495    /* the casts are required if DIGIT_BIT is one less than
1496     * the number of bits in a mp_digit [e.g. DIGIT_BIT==31]
1497     */
1498    *d = (mp_digit)((((mp_word)1) << ((mp_word)DIGIT_BIT)) - 
1499         ((mp_word)a->dp[0]));
1500 }
1501
1502 /* swap the elements of two integers, for cases where you can't simply swap the 
1503  * mp_int pointers around
1504  */
1505 void
1506 mp_exch (mp_int * a, mp_int * b)
1507 {
1508   mp_int  t;
1509
1510   t  = *a;
1511   *a = *b;
1512   *b = t;
1513 }
1514
1515 /* this is a shell function that calls either the normal or Montgomery
1516  * exptmod functions.  Originally the call to the montgomery code was
1517  * embedded in the normal function but that wasted a lot of stack space
1518  * for nothing (since 99% of the time the Montgomery code would be called)
1519  */
1520 int mp_exptmod (const mp_int * G, const mp_int * X, mp_int * P, mp_int * Y)
1521 {
1522   int dr;
1523
1524   /* modulus P must be positive */
1525   if (P->sign == MP_NEG) {
1526      return MP_VAL;
1527   }
1528
1529   /* if exponent X is negative we have to recurse */
1530   if (X->sign == MP_NEG) {
1531      mp_int tmpG, tmpX;
1532      int err;
1533
1534      /* first compute 1/G mod P */
1535      if ((err = mp_init(&tmpG)) != MP_OKAY) {
1536         return err;
1537      }
1538      if ((err = mp_invmod(G, P, &tmpG)) != MP_OKAY) {
1539         mp_clear(&tmpG);
1540         return err;
1541      }
1542
1543      /* now get |X| */
1544      if ((err = mp_init(&tmpX)) != MP_OKAY) {
1545         mp_clear(&tmpG);
1546         return err;
1547      }
1548      if ((err = mp_abs(X, &tmpX)) != MP_OKAY) {
1549         mp_clear_multi(&tmpG, &tmpX, NULL);
1550         return err;
1551      }
1552
1553      /* and now compute (1/G)**|X| instead of G**X [X < 0] */
1554      err = mp_exptmod(&tmpG, &tmpX, P, Y);
1555      mp_clear_multi(&tmpG, &tmpX, NULL);
1556      return err;
1557   }
1558
1559   dr = 0;
1560
1561   /* if the modulus is odd or dr != 0 use the fast method */
1562   if (mp_isodd (P) == 1 || dr !=  0) {
1563     return mp_exptmod_fast (G, X, P, Y, dr);
1564   } else {
1565     /* otherwise use the generic Barrett reduction technique */
1566     return s_mp_exptmod (G, X, P, Y);
1567   }
1568 }
1569
1570 /* computes Y == G**X mod P, HAC pp.616, Algorithm 14.85
1571  *
1572  * Uses a left-to-right k-ary sliding window to compute the modular exponentiation.
1573  * The value of k changes based on the size of the exponent.
1574  *
1575  * Uses Montgomery or Diminished Radix reduction [whichever appropriate]
1576  */
1577
1578 int
1579 mp_exptmod_fast (const mp_int * G, const mp_int * X, mp_int * P, mp_int * Y, int redmode)
1580 {
1581   mp_int  M[256], res;
1582   mp_digit buf, mp;
1583   int     err, bitbuf, bitcpy, bitcnt, mode, digidx, x, y, winsize;
1584
1585   /* use a pointer to the reduction algorithm.  This allows us to use
1586    * one of many reduction algorithms without modding the guts of
1587    * the code with if statements everywhere.
1588    */
1589   int     (*redux)(mp_int*,const mp_int*,mp_digit);
1590
1591   /* find window size */
1592   x = mp_count_bits (X);
1593   if (x <= 7) {
1594     winsize = 2;
1595   } else if (x <= 36) {
1596     winsize = 3;
1597   } else if (x <= 140) {
1598     winsize = 4;
1599   } else if (x <= 450) {
1600     winsize = 5;
1601   } else if (x <= 1303) {
1602     winsize = 6;
1603   } else if (x <= 3529) {
1604     winsize = 7;
1605   } else {
1606     winsize = 8;
1607   }
1608
1609   /* init M array */
1610   /* init first cell */
1611   if ((err = mp_init(&M[1])) != MP_OKAY) {
1612      return err;
1613   }
1614
1615   /* now init the second half of the array */
1616   for (x = 1<<(winsize-1); x < (1 << winsize); x++) {
1617     if ((err = mp_init(&M[x])) != MP_OKAY) {
1618       for (y = 1<<(winsize-1); y < x; y++) {
1619         mp_clear (&M[y]);
1620       }
1621       mp_clear(&M[1]);
1622       return err;
1623     }
1624   }
1625
1626   /* determine and setup reduction code */
1627   if (redmode == 0) {
1628      /* now setup montgomery  */
1629      if ((err = mp_montgomery_setup (P, &mp)) != MP_OKAY) {
1630         goto __M;
1631      }
1632
1633      /* automatically pick the comba one if available (saves quite a few calls/ifs) */
1634      if (((P->used * 2 + 1) < MP_WARRAY) &&
1635           P->used < (1 << ((CHAR_BIT * sizeof (mp_word)) - (2 * DIGIT_BIT)))) {
1636         redux = fast_mp_montgomery_reduce;
1637      } else {
1638         /* use slower baseline Montgomery method */
1639         redux = mp_montgomery_reduce;
1640      }
1641   } else if (redmode == 1) {
1642      /* setup DR reduction for moduli of the form B**k - b */
1643      mp_dr_setup(P, &mp);
1644      redux = mp_dr_reduce;
1645   } else {
1646      /* setup DR reduction for moduli of the form 2**k - b */
1647      if ((err = mp_reduce_2k_setup(P, &mp)) != MP_OKAY) {
1648         goto __M;
1649      }
1650      redux = mp_reduce_2k;
1651   }
1652
1653   /* setup result */
1654   if ((err = mp_init (&res)) != MP_OKAY) {
1655     goto __M;
1656   }
1657
1658   /* create M table
1659    *
1660
1661    *
1662    * The first half of the table is not computed though accept for M[0] and M[1]
1663    */
1664
1665   if (redmode == 0) {
1666      /* now we need R mod m */
1667      if ((err = mp_montgomery_calc_normalization (&res, P)) != MP_OKAY) {
1668        goto __RES;
1669      }
1670
1671      /* now set M[1] to G * R mod m */
1672      if ((err = mp_mulmod (G, &res, P, &M[1])) != MP_OKAY) {
1673        goto __RES;
1674      }
1675   } else {
1676      mp_set(&res, 1);
1677      if ((err = mp_mod(G, P, &M[1])) != MP_OKAY) {
1678         goto __RES;
1679      }
1680   }
1681
1682   /* compute the value at M[1<<(winsize-1)] by squaring M[1] (winsize-1) times */
1683   if ((err = mp_copy (&M[1], &M[1 << (winsize - 1)])) != MP_OKAY) {
1684     goto __RES;
1685   }
1686
1687   for (x = 0; x < (winsize - 1); x++) {
1688     if ((err = mp_sqr (&M[1 << (winsize - 1)], &M[1 << (winsize - 1)])) != MP_OKAY) {
1689       goto __RES;
1690     }
1691     if ((err = redux (&M[1 << (winsize - 1)], P, mp)) != MP_OKAY) {
1692       goto __RES;
1693     }
1694   }
1695
1696   /* create upper table */
1697   for (x = (1 << (winsize - 1)) + 1; x < (1 << winsize); x++) {
1698     if ((err = mp_mul (&M[x - 1], &M[1], &M[x])) != MP_OKAY) {
1699       goto __RES;
1700     }
1701     if ((err = redux (&M[x], P, mp)) != MP_OKAY) {
1702       goto __RES;
1703     }
1704   }
1705
1706   /* set initial mode and bit cnt */
1707   mode   = 0;
1708   bitcnt = 1;
1709   buf    = 0;
1710   digidx = X->used - 1;
1711   bitcpy = 0;
1712   bitbuf = 0;
1713
1714   for (;;) {
1715     /* grab next digit as required */
1716     if (--bitcnt == 0) {
1717       /* if digidx == -1 we are out of digits so break */
1718       if (digidx == -1) {
1719         break;
1720       }
1721       /* read next digit and reset bitcnt */
1722       buf    = X->dp[digidx--];
1723       bitcnt = DIGIT_BIT;
1724     }
1725
1726     /* grab the next msb from the exponent */
1727     y     = (buf >> (DIGIT_BIT - 1)) & 1;
1728     buf <<= (mp_digit)1;
1729
1730     /* if the bit is zero and mode == 0 then we ignore it
1731      * These represent the leading zero bits before the first 1 bit
1732      * in the exponent.  Technically this opt is not required but it
1733      * does lower the # of trivial squaring/reductions used
1734      */
1735     if (mode == 0 && y == 0) {
1736       continue;
1737     }
1738
1739     /* if the bit is zero and mode == 1 then we square */
1740     if (mode == 1 && y == 0) {
1741       if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
1742         goto __RES;
1743       }
1744       if ((err = redux (&res, P, mp)) != MP_OKAY) {
1745         goto __RES;
1746       }
1747       continue;
1748     }
1749
1750     /* else we add it to the window */
1751     bitbuf |= (y << (winsize - ++bitcpy));
1752     mode    = 2;
1753
1754     if (bitcpy == winsize) {
1755       /* ok window is filled so square as required and multiply  */
1756       /* square first */
1757       for (x = 0; x < winsize; x++) {
1758         if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
1759           goto __RES;
1760         }
1761         if ((err = redux (&res, P, mp)) != MP_OKAY) {
1762           goto __RES;
1763         }
1764       }
1765
1766       /* then multiply */
1767       if ((err = mp_mul (&res, &M[bitbuf], &res)) != MP_OKAY) {
1768         goto __RES;
1769       }
1770       if ((err = redux (&res, P, mp)) != MP_OKAY) {
1771         goto __RES;
1772       }
1773
1774       /* empty window and reset */
1775       bitcpy = 0;
1776       bitbuf = 0;
1777       mode   = 1;
1778     }
1779   }
1780
1781   /* if bits remain then square/multiply */
1782   if (mode == 2 && bitcpy > 0) {
1783     /* square then multiply if the bit is set */
1784     for (x = 0; x < bitcpy; x++) {
1785       if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
1786         goto __RES;
1787       }
1788       if ((err = redux (&res, P, mp)) != MP_OKAY) {
1789         goto __RES;
1790       }
1791
1792       /* get next bit of the window */
1793       bitbuf <<= 1;
1794       if ((bitbuf & (1 << winsize)) != 0) {
1795         /* then multiply */
1796         if ((err = mp_mul (&res, &M[1], &res)) != MP_OKAY) {
1797           goto __RES;
1798         }
1799         if ((err = redux (&res, P, mp)) != MP_OKAY) {
1800           goto __RES;
1801         }
1802       }
1803     }
1804   }
1805
1806   if (redmode == 0) {
1807      /* fixup result if Montgomery reduction is used
1808       * recall that any value in a Montgomery system is
1809       * actually multiplied by R mod n.  So we have
1810       * to reduce one more time to cancel out the factor
1811       * of R.
1812       */
1813      if ((err = redux(&res, P, mp)) != MP_OKAY) {
1814        goto __RES;
1815      }
1816   }
1817
1818   /* swap res with Y */
1819   mp_exch (&res, Y);
1820   err = MP_OKAY;
1821 __RES:mp_clear (&res);
1822 __M:
1823   mp_clear(&M[1]);
1824   for (x = 1<<(winsize-1); x < (1 << winsize); x++) {
1825     mp_clear (&M[x]);
1826   }
1827   return err;
1828 }
1829
1830 /* Greatest Common Divisor using the binary method */
1831 int mp_gcd (const mp_int * a, const mp_int * b, mp_int * c)
1832 {
1833   mp_int  u, v;
1834   int     k, u_lsb, v_lsb, res;
1835
1836   /* either zero than gcd is the largest */
1837   if (mp_iszero (a) == 1 && mp_iszero (b) == 0) {
1838     return mp_abs (b, c);
1839   }
1840   if (mp_iszero (a) == 0 && mp_iszero (b) == 1) {
1841     return mp_abs (a, c);
1842   }
1843
1844   /* optimized.  At this point if a == 0 then
1845    * b must equal zero too
1846    */
1847   if (mp_iszero (a) == 1) {
1848     mp_zero(c);
1849     return MP_OKAY;
1850   }
1851
1852   /* get copies of a and b we can modify */
1853   if ((res = mp_init_copy (&u, a)) != MP_OKAY) {
1854     return res;
1855   }
1856
1857   if ((res = mp_init_copy (&v, b)) != MP_OKAY) {
1858     goto __U;
1859   }
1860
1861   /* must be positive for the remainder of the algorithm */
1862   u.sign = v.sign = MP_ZPOS;
1863
1864   /* B1.  Find the common power of two for u and v */
1865   u_lsb = mp_cnt_lsb(&u);
1866   v_lsb = mp_cnt_lsb(&v);
1867   k     = MIN(u_lsb, v_lsb);
1868
1869   if (k > 0) {
1870      /* divide the power of two out */
1871      if ((res = mp_div_2d(&u, k, &u, NULL)) != MP_OKAY) {
1872         goto __V;
1873      }
1874
1875      if ((res = mp_div_2d(&v, k, &v, NULL)) != MP_OKAY) {
1876         goto __V;
1877      }
1878   }
1879
1880   /* divide any remaining factors of two out */
1881   if (u_lsb != k) {
1882      if ((res = mp_div_2d(&u, u_lsb - k, &u, NULL)) != MP_OKAY) {
1883         goto __V;
1884      }
1885   }
1886
1887   if (v_lsb != k) {
1888      if ((res = mp_div_2d(&v, v_lsb - k, &v, NULL)) != MP_OKAY) {
1889         goto __V;
1890      }
1891   }
1892
1893   while (mp_iszero(&v) == 0) {
1894      /* make sure v is the largest */
1895      if (mp_cmp_mag(&u, &v) == MP_GT) {
1896         /* swap u and v to make sure v is >= u */
1897         mp_exch(&u, &v);
1898      }
1899      
1900      /* subtract smallest from largest */
1901      if ((res = s_mp_sub(&v, &u, &v)) != MP_OKAY) {
1902         goto __V;
1903      }
1904      
1905      /* Divide out all factors of two */
1906      if ((res = mp_div_2d(&v, mp_cnt_lsb(&v), &v, NULL)) != MP_OKAY) {
1907         goto __V;
1908      } 
1909   } 
1910
1911   /* multiply by 2**k which we divided out at the beginning */
1912   if ((res = mp_mul_2d (&u, k, c)) != MP_OKAY) {
1913      goto __V;
1914   }
1915   c->sign = MP_ZPOS;
1916   res = MP_OKAY;
1917 __V:mp_clear (&u);
1918 __U:mp_clear (&v);
1919   return res;
1920 }
1921
1922 /* get the lower 32-bits of an mp_int */
1923 unsigned long mp_get_int(const mp_int * a)
1924 {
1925   int i;
1926   unsigned long res;
1927
1928   if (a->used == 0) {
1929      return 0;
1930   }
1931
1932   /* get number of digits of the lsb we have to read */
1933   i = MIN(a->used,(int)((sizeof(unsigned long)*CHAR_BIT+DIGIT_BIT-1)/DIGIT_BIT))-1;
1934
1935   /* get most significant digit of result */
1936   res = DIGIT(a,i);
1937    
1938   while (--i >= 0) {
1939     res = (res << DIGIT_BIT) | DIGIT(a,i);
1940   }
1941
1942   /* force result to 32-bits always so it is consistent on non 32-bit platforms */
1943   return res & 0xFFFFFFFFUL;
1944 }
1945
1946 /* grow as required */
1947 int mp_grow (mp_int * a, int size)
1948 {
1949   int     i;
1950   mp_digit *tmp;
1951
1952   /* if the alloc size is smaller alloc more ram */
1953   if (a->alloc < size) {
1954     /* ensure there are always at least MP_PREC digits extra on top */
1955     size += (MP_PREC * 2) - (size % MP_PREC);
1956
1957     /* reallocate the array a->dp
1958      *
1959      * We store the return in a temporary variable
1960      * in case the operation failed we don't want
1961      * to overwrite the dp member of a.
1962      */
1963     tmp = realloc (a->dp, sizeof (mp_digit) * size);
1964     if (tmp == NULL) {
1965       /* reallocation failed but "a" is still valid [can be freed] */
1966       return MP_MEM;
1967     }
1968
1969     /* reallocation succeeded so set a->dp */
1970     a->dp = tmp;
1971
1972     /* zero excess digits */
1973     i        = a->alloc;
1974     a->alloc = size;
1975     for (; i < a->alloc; i++) {
1976       a->dp[i] = 0;
1977     }
1978   }
1979   return MP_OKAY;
1980 }
1981
1982 /* init a new mp_int */
1983 int mp_init (mp_int * a)
1984 {
1985   int i;
1986
1987   /* allocate memory required and clear it */
1988   a->dp = malloc (sizeof (mp_digit) * MP_PREC);
1989   if (a->dp == NULL) {
1990     return MP_MEM;
1991   }
1992
1993   /* set the digits to zero */
1994   for (i = 0; i < MP_PREC; i++) {
1995       a->dp[i] = 0;
1996   }
1997
1998   /* set the used to zero, allocated digits to the default precision
1999    * and sign to positive */
2000   a->used  = 0;
2001   a->alloc = MP_PREC;
2002   a->sign  = MP_ZPOS;
2003
2004   return MP_OKAY;
2005 }
2006
2007 /* creates "a" then copies b into it */
2008 int mp_init_copy (mp_int * a, const mp_int * b)
2009 {
2010   int     res;
2011
2012   if ((res = mp_init (a)) != MP_OKAY) {
2013     return res;
2014   }
2015   return mp_copy (b, a);
2016 }
2017
2018 int mp_init_multi(mp_int *mp, ...) 
2019 {
2020     mp_err res = MP_OKAY;      /* Assume ok until proven otherwise */
2021     int n = 0;                 /* Number of ok inits */
2022     mp_int* cur_arg = mp;
2023     va_list args;
2024
2025     va_start(args, mp);        /* init args to next argument from caller */
2026     while (cur_arg != NULL) {
2027         if (mp_init(cur_arg) != MP_OKAY) {
2028             /* Oops - error! Back-track and mp_clear what we already
2029                succeeded in init-ing, then return error.
2030             */
2031             va_list clean_args;
2032             
2033             /* end the current list */
2034             va_end(args);
2035             
2036             /* now start cleaning up */            
2037             cur_arg = mp;
2038             va_start(clean_args, mp);
2039             while (n--) {
2040                 mp_clear(cur_arg);
2041                 cur_arg = va_arg(clean_args, mp_int*);
2042             }
2043             va_end(clean_args);
2044             res = MP_MEM;
2045             break;
2046         }
2047         n++;
2048         cur_arg = va_arg(args, mp_int*);
2049     }
2050     va_end(args);
2051     return res;                /* Assumed ok, if error flagged above. */
2052 }
2053
2054 /* init an mp_init for a given size */
2055 int mp_init_size (mp_int * a, int size)
2056 {
2057   int x;
2058
2059   /* pad size so there are always extra digits */
2060   size += (MP_PREC * 2) - (size % MP_PREC);    
2061   
2062   /* alloc mem */
2063   a->dp = malloc (sizeof (mp_digit) * size);
2064   if (a->dp == NULL) {
2065     return MP_MEM;
2066   }
2067
2068   /* set the members */
2069   a->used  = 0;
2070   a->alloc = size;
2071   a->sign  = MP_ZPOS;
2072
2073   /* zero the digits */
2074   for (x = 0; x < size; x++) {
2075       a->dp[x] = 0;
2076   }
2077
2078   return MP_OKAY;
2079 }
2080
2081 /* hac 14.61, pp608 */
2082 int mp_invmod (const mp_int * a, mp_int * b, mp_int * c)
2083 {
2084   /* b cannot be negative */
2085   if (b->sign == MP_NEG || mp_iszero(b) == 1) {
2086     return MP_VAL;
2087   }
2088
2089   /* if the modulus is odd we can use a faster routine instead */
2090   if (mp_isodd (b) == 1) {
2091     return fast_mp_invmod (a, b, c);
2092   }
2093   
2094   return mp_invmod_slow(a, b, c);
2095 }
2096
2097 /* hac 14.61, pp608 */
2098 int mp_invmod_slow (const mp_int * a, mp_int * b, mp_int * c)
2099 {
2100   mp_int  x, y, u, v, A, B, C, D;
2101   int     res;
2102
2103   /* b cannot be negative */
2104   if (b->sign == MP_NEG || mp_iszero(b) == 1) {
2105     return MP_VAL;
2106   }
2107
2108   /* init temps */
2109   if ((res = mp_init_multi(&x, &y, &u, &v, 
2110                            &A, &B, &C, &D, NULL)) != MP_OKAY) {
2111      return res;
2112   }
2113
2114   /* x = a, y = b */
2115   if ((res = mp_copy (a, &x)) != MP_OKAY) {
2116     goto __ERR;
2117   }
2118   if ((res = mp_copy (b, &y)) != MP_OKAY) {
2119     goto __ERR;
2120   }
2121
2122   /* 2. [modified] if x,y are both even then return an error! */
2123   if (mp_iseven (&x) == 1 && mp_iseven (&y) == 1) {
2124     res = MP_VAL;
2125     goto __ERR;
2126   }
2127
2128   /* 3. u=x, v=y, A=1, B=0, C=0,D=1 */
2129   if ((res = mp_copy (&x, &u)) != MP_OKAY) {
2130     goto __ERR;
2131   }
2132   if ((res = mp_copy (&y, &v)) != MP_OKAY) {
2133     goto __ERR;
2134   }
2135   mp_set (&A, 1);
2136   mp_set (&D, 1);
2137
2138 top:
2139   /* 4.  while u is even do */
2140   while (mp_iseven (&u) == 1) {
2141     /* 4.1 u = u/2 */
2142     if ((res = mp_div_2 (&u, &u)) != MP_OKAY) {
2143       goto __ERR;
2144     }
2145     /* 4.2 if A or B is odd then */
2146     if (mp_isodd (&A) == 1 || mp_isodd (&B) == 1) {
2147       /* A = (A+y)/2, B = (B-x)/2 */
2148       if ((res = mp_add (&A, &y, &A)) != MP_OKAY) {
2149          goto __ERR;
2150       }
2151       if ((res = mp_sub (&B, &x, &B)) != MP_OKAY) {
2152          goto __ERR;
2153       }
2154     }
2155     /* A = A/2, B = B/2 */
2156     if ((res = mp_div_2 (&A, &A)) != MP_OKAY) {
2157       goto __ERR;
2158     }
2159     if ((res = mp_div_2 (&B, &B)) != MP_OKAY) {
2160       goto __ERR;
2161     }
2162   }
2163
2164   /* 5.  while v is even do */
2165   while (mp_iseven (&v) == 1) {
2166     /* 5.1 v = v/2 */
2167     if ((res = mp_div_2 (&v, &v)) != MP_OKAY) {
2168       goto __ERR;
2169     }
2170     /* 5.2 if C or D is odd then */
2171     if (mp_isodd (&C) == 1 || mp_isodd (&D) == 1) {
2172       /* C = (C+y)/2, D = (D-x)/2 */
2173       if ((res = mp_add (&C, &y, &C)) != MP_OKAY) {
2174          goto __ERR;
2175       }
2176       if ((res = mp_sub (&D, &x, &D)) != MP_OKAY) {
2177          goto __ERR;
2178       }
2179     }
2180     /* C = C/2, D = D/2 */
2181     if ((res = mp_div_2 (&C, &C)) != MP_OKAY) {
2182       goto __ERR;
2183     }
2184     if ((res = mp_div_2 (&D, &D)) != MP_OKAY) {
2185       goto __ERR;
2186     }
2187   }
2188
2189   /* 6.  if u >= v then */
2190   if (mp_cmp (&u, &v) != MP_LT) {
2191     /* u = u - v, A = A - C, B = B - D */
2192     if ((res = mp_sub (&u, &v, &u)) != MP_OKAY) {
2193       goto __ERR;
2194     }
2195
2196     if ((res = mp_sub (&A, &C, &A)) != MP_OKAY) {
2197       goto __ERR;
2198     }
2199
2200     if ((res = mp_sub (&B, &D, &B)) != MP_OKAY) {
2201       goto __ERR;
2202     }
2203   } else {
2204     /* v - v - u, C = C - A, D = D - B */
2205     if ((res = mp_sub (&v, &u, &v)) != MP_OKAY) {
2206       goto __ERR;
2207     }
2208
2209     if ((res = mp_sub (&C, &A, &C)) != MP_OKAY) {
2210       goto __ERR;
2211     }
2212
2213     if ((res = mp_sub (&D, &B, &D)) != MP_OKAY) {
2214       goto __ERR;
2215     }
2216   }
2217
2218   /* if not zero goto step 4 */
2219   if (mp_iszero (&u) == 0)
2220     goto top;
2221
2222   /* now a = C, b = D, gcd == g*v */
2223
2224   /* if v != 1 then there is no inverse */
2225   if (mp_cmp_d (&v, 1) != MP_EQ) {
2226     res = MP_VAL;
2227     goto __ERR;
2228   }
2229
2230   /* if its too low */
2231   while (mp_cmp_d(&C, 0) == MP_LT) {
2232       if ((res = mp_add(&C, b, &C)) != MP_OKAY) {
2233          goto __ERR;
2234       }
2235   }
2236   
2237   /* too big */
2238   while (mp_cmp_mag(&C, b) != MP_LT) {
2239       if ((res = mp_sub(&C, b, &C)) != MP_OKAY) {
2240          goto __ERR;
2241       }
2242   }
2243   
2244   /* C is now the inverse */
2245   mp_exch (&C, c);
2246   res = MP_OKAY;
2247 __ERR:mp_clear_multi (&x, &y, &u, &v, &A, &B, &C, &D, NULL);
2248   return res;
2249 }
2250
2251 /* c = |a| * |b| using Karatsuba Multiplication using 
2252  * three half size multiplications
2253  *
2254  * Let B represent the radix [e.g. 2**DIGIT_BIT] and 
2255  * let n represent half of the number of digits in 
2256  * the min(a,b)
2257  *
2258  * a = a1 * B**n + a0
2259  * b = b1 * B**n + b0
2260  *
2261  * Then, a * b => 
2262    a1b1 * B**2n + ((a1 - a0)(b1 - b0) + a0b0 + a1b1) * B + a0b0
2263  *
2264  * Note that a1b1 and a0b0 are used twice and only need to be 
2265  * computed once.  So in total three half size (half # of 
2266  * digit) multiplications are performed, a0b0, a1b1 and 
2267  * (a1-b1)(a0-b0)
2268  *
2269  * Note that a multiplication of half the digits requires
2270  * 1/4th the number of single precision multiplications so in 
2271  * total after one call 25% of the single precision multiplications 
2272  * are saved.  Note also that the call to mp_mul can end up back 
2273  * in this function if the a0, a1, b0, or b1 are above the threshold.  
2274  * This is known as divide-and-conquer and leads to the famous 
2275  * O(N**lg(3)) or O(N**1.584) work which is asymptotically lower than
2276  * the standard O(N**2) that the baseline/comba methods use.  
2277  * Generally though the overhead of this method doesn't pay off 
2278  * until a certain size (N ~ 80) is reached.
2279  */
2280 int mp_karatsuba_mul (const mp_int * a, const mp_int * b, mp_int * c)
2281 {
2282   mp_int  x0, x1, y0, y1, t1, x0y0, x1y1;
2283   int     B, err;
2284
2285   /* default the return code to an error */
2286   err = MP_MEM;
2287
2288   /* min # of digits */
2289   B = MIN (a->used, b->used);
2290
2291   /* now divide in two */
2292   B = B >> 1;
2293
2294   /* init copy all the temps */
2295   if (mp_init_size (&x0, B) != MP_OKAY)
2296     goto ERR;
2297   if (mp_init_size (&x1, a->used - B) != MP_OKAY)
2298     goto X0;
2299   if (mp_init_size (&y0, B) != MP_OKAY)
2300     goto X1;
2301   if (mp_init_size (&y1, b->used - B) != MP_OKAY)
2302     goto Y0;
2303
2304   /* init temps */
2305   if (mp_init_size (&t1, B * 2) != MP_OKAY)
2306     goto Y1;
2307   if (mp_init_size (&x0y0, B * 2) != MP_OKAY)
2308     goto T1;
2309   if (mp_init_size (&x1y1, B * 2) != MP_OKAY)
2310     goto X0Y0;
2311
2312   /* now shift the digits */
2313   x0.used = y0.used = B;
2314   x1.used = a->used - B;
2315   y1.used = b->used - B;
2316
2317   {
2318     register int x;
2319     register mp_digit *tmpa, *tmpb, *tmpx, *tmpy;
2320
2321     /* we copy the digits directly instead of using higher level functions
2322      * since we also need to shift the digits
2323      */
2324     tmpa = a->dp;
2325     tmpb = b->dp;
2326
2327     tmpx = x0.dp;
2328     tmpy = y0.dp;
2329     for (x = 0; x < B; x++) {
2330       *tmpx++ = *tmpa++;
2331       *tmpy++ = *tmpb++;
2332     }
2333
2334     tmpx = x1.dp;
2335     for (x = B; x < a->used; x++) {
2336       *tmpx++ = *tmpa++;
2337     }
2338
2339     tmpy = y1.dp;
2340     for (x = B; x < b->used; x++) {
2341       *tmpy++ = *tmpb++;
2342     }
2343   }
2344
2345   /* only need to clamp the lower words since by definition the 
2346    * upper words x1/y1 must have a known number of digits
2347    */
2348   mp_clamp (&x0);
2349   mp_clamp (&y0);
2350
2351   /* now calc the products x0y0 and x1y1 */
2352   /* after this x0 is no longer required, free temp [x0==t2]! */
2353   if (mp_mul (&x0, &y0, &x0y0) != MP_OKAY)  
2354     goto X1Y1;          /* x0y0 = x0*y0 */
2355   if (mp_mul (&x1, &y1, &x1y1) != MP_OKAY)
2356     goto X1Y1;          /* x1y1 = x1*y1 */
2357
2358   /* now calc x1-x0 and y1-y0 */
2359   if (mp_sub (&x1, &x0, &t1) != MP_OKAY)
2360     goto X1Y1;          /* t1 = x1 - x0 */
2361   if (mp_sub (&y1, &y0, &x0) != MP_OKAY)
2362     goto X1Y1;          /* t2 = y1 - y0 */
2363   if (mp_mul (&t1, &x0, &t1) != MP_OKAY)
2364     goto X1Y1;          /* t1 = (x1 - x0) * (y1 - y0) */
2365
2366   /* add x0y0 */
2367   if (mp_add (&x0y0, &x1y1, &x0) != MP_OKAY)
2368     goto X1Y1;          /* t2 = x0y0 + x1y1 */
2369   if (mp_sub (&x0, &t1, &t1) != MP_OKAY)
2370     goto X1Y1;          /* t1 = x0y0 + x1y1 - (x1-x0)*(y1-y0) */
2371
2372   /* shift by B */
2373   if (mp_lshd (&t1, B) != MP_OKAY)
2374     goto X1Y1;          /* t1 = (x0y0 + x1y1 - (x1-x0)*(y1-y0))<<B */
2375   if (mp_lshd (&x1y1, B * 2) != MP_OKAY)
2376     goto X1Y1;          /* x1y1 = x1y1 << 2*B */
2377
2378   if (mp_add (&x0y0, &t1, &t1) != MP_OKAY)
2379     goto X1Y1;          /* t1 = x0y0 + t1 */
2380   if (mp_add (&t1, &x1y1, c) != MP_OKAY)
2381     goto X1Y1;          /* t1 = x0y0 + t1 + x1y1 */
2382
2383   /* Algorithm succeeded set the return code to MP_OKAY */
2384   err = MP_OKAY;
2385
2386 X1Y1:mp_clear (&x1y1);
2387 X0Y0:mp_clear (&x0y0);
2388 T1:mp_clear (&t1);
2389 Y1:mp_clear (&y1);
2390 Y0:mp_clear (&y0);
2391 X1:mp_clear (&x1);
2392 X0:mp_clear (&x0);
2393 ERR:
2394   return err;
2395 }
2396
2397 /* Karatsuba squaring, computes b = a*a using three 
2398  * half size squarings
2399  *
2400  * See comments of karatsuba_mul for details.  It 
2401  * is essentially the same algorithm but merely 
2402  * tuned to perform recursive squarings.
2403  */
2404 int mp_karatsuba_sqr (const mp_int * a, mp_int * b)
2405 {
2406   mp_int  x0, x1, t1, t2, x0x0, x1x1;
2407   int     B, err;
2408
2409   err = MP_MEM;
2410
2411   /* min # of digits */
2412   B = a->used;
2413
2414   /* now divide in two */
2415   B = B >> 1;
2416
2417   /* init copy all the temps */
2418   if (mp_init_size (&x0, B) != MP_OKAY)
2419     goto ERR;
2420   if (mp_init_size (&x1, a->used - B) != MP_OKAY)
2421     goto X0;
2422
2423   /* init temps */
2424   if (mp_init_size (&t1, a->used * 2) != MP_OKAY)
2425     goto X1;
2426   if (mp_init_size (&t2, a->used * 2) != MP_OKAY)
2427     goto T1;
2428   if (mp_init_size (&x0x0, B * 2) != MP_OKAY)
2429     goto T2;
2430   if (mp_init_size (&x1x1, (a->used - B) * 2) != MP_OKAY)
2431     goto X0X0;
2432
2433   {
2434     register int x;
2435     register mp_digit *dst, *src;
2436
2437     src = a->dp;
2438
2439     /* now shift the digits */
2440     dst = x0.dp;
2441     for (x = 0; x < B; x++) {
2442       *dst++ = *src++;
2443     }
2444
2445     dst = x1.dp;
2446     for (x = B; x < a->used; x++) {
2447       *dst++ = *src++;
2448     }
2449   }
2450
2451   x0.used = B;
2452   x1.used = a->used - B;
2453
2454   mp_clamp (&x0);
2455
2456   /* now calc the products x0*x0 and x1*x1 */
2457   if (mp_sqr (&x0, &x0x0) != MP_OKAY)
2458     goto X1X1;           /* x0x0 = x0*x0 */
2459   if (mp_sqr (&x1, &x1x1) != MP_OKAY)
2460     goto X1X1;           /* x1x1 = x1*x1 */
2461
2462   /* now calc (x1-x0)**2 */
2463   if (mp_sub (&x1, &x0, &t1) != MP_OKAY)
2464     goto X1X1;           /* t1 = x1 - x0 */
2465   if (mp_sqr (&t1, &t1) != MP_OKAY)
2466     goto X1X1;           /* t1 = (x1 - x0) * (x1 - x0) */
2467
2468   /* add x0y0 */
2469   if (s_mp_add (&x0x0, &x1x1, &t2) != MP_OKAY)
2470     goto X1X1;           /* t2 = x0x0 + x1x1 */
2471   if (mp_sub (&t2, &t1, &t1) != MP_OKAY)
2472     goto X1X1;           /* t1 = x0x0 + x1x1 - (x1-x0)*(x1-x0) */
2473
2474   /* shift by B */
2475   if (mp_lshd (&t1, B) != MP_OKAY)
2476     goto X1X1;           /* t1 = (x0x0 + x1x1 - (x1-x0)*(x1-x0))<<B */
2477   if (mp_lshd (&x1x1, B * 2) != MP_OKAY)
2478     goto X1X1;           /* x1x1 = x1x1 << 2*B */
2479
2480   if (mp_add (&x0x0, &t1, &t1) != MP_OKAY)
2481     goto X1X1;           /* t1 = x0x0 + t1 */
2482   if (mp_add (&t1, &x1x1, b) != MP_OKAY)
2483     goto X1X1;           /* t1 = x0x0 + t1 + x1x1 */
2484
2485   err = MP_OKAY;
2486
2487 X1X1:mp_clear (&x1x1);
2488 X0X0:mp_clear (&x0x0);
2489 T2:mp_clear (&t2);
2490 T1:mp_clear (&t1);
2491 X1:mp_clear (&x1);
2492 X0:mp_clear (&x0);
2493 ERR:
2494   return err;
2495 }
2496
2497 /* computes least common multiple as |a*b|/(a, b) */
2498 int mp_lcm (const mp_int * a, const mp_int * b, mp_int * c)
2499 {
2500   int     res;
2501   mp_int  t1, t2;
2502
2503
2504   if ((res = mp_init_multi (&t1, &t2, NULL)) != MP_OKAY) {
2505     return res;
2506   }
2507
2508   /* t1 = get the GCD of the two inputs */
2509   if ((res = mp_gcd (a, b, &t1)) != MP_OKAY) {
2510     goto __T;
2511   }
2512
2513   /* divide the smallest by the GCD */
2514   if (mp_cmp_mag(a, b) == MP_LT) {
2515      /* store quotient in t2 such that t2 * b is the LCM */
2516      if ((res = mp_div(a, &t1, &t2, NULL)) != MP_OKAY) {
2517         goto __T;
2518      }
2519      res = mp_mul(b, &t2, c);
2520   } else {
2521      /* store quotient in t2 such that t2 * a is the LCM */
2522      if ((res = mp_div(b, &t1, &t2, NULL)) != MP_OKAY) {
2523         goto __T;
2524      }
2525      res = mp_mul(a, &t2, c);
2526   }
2527
2528   /* fix the sign to positive */
2529   c->sign = MP_ZPOS;
2530
2531 __T:
2532   mp_clear_multi (&t1, &t2, NULL);
2533   return res;
2534 }
2535
2536 /* shift left a certain amount of digits */
2537 int mp_lshd (mp_int * a, int b)
2538 {
2539   int     x, res;
2540
2541   /* if its less than zero return */
2542   if (b <= 0) {
2543     return MP_OKAY;
2544   }
2545
2546   /* grow to fit the new digits */
2547   if (a->alloc < a->used + b) {
2548      if ((res = mp_grow (a, a->used + b)) != MP_OKAY) {
2549        return res;
2550      }
2551   }
2552
2553   {
2554     register mp_digit *top, *bottom;
2555
2556     /* increment the used by the shift amount then copy upwards */
2557     a->used += b;
2558
2559     /* top */
2560     top = a->dp + a->used - 1;
2561
2562     /* base */
2563     bottom = a->dp + a->used - 1 - b;
2564
2565     /* much like mp_rshd this is implemented using a sliding window
2566      * except the window goes the otherway around.  Copying from
2567      * the bottom to the top.  see bn_mp_rshd.c for more info.
2568      */
2569     for (x = a->used - 1; x >= b; x--) {
2570       *top-- = *bottom--;
2571     }
2572
2573     /* zero the lower digits */
2574     top = a->dp;
2575     for (x = 0; x < b; x++) {
2576       *top++ = 0;
2577     }
2578   }
2579   return MP_OKAY;
2580 }
2581
2582 /* c = a mod b, 0 <= c < b */
2583 int
2584 mp_mod (const mp_int * a, mp_int * b, mp_int * c)
2585 {
2586   mp_int  t;
2587   int     res;
2588
2589   if ((res = mp_init (&t)) != MP_OKAY) {
2590     return res;
2591   }
2592
2593   if ((res = mp_div (a, b, NULL, &t)) != MP_OKAY) {
2594     mp_clear (&t);
2595     return res;
2596   }
2597
2598   if (t.sign != b->sign) {
2599     res = mp_add (b, &t, c);
2600   } else {
2601     res = MP_OKAY;
2602     mp_exch (&t, c);
2603   }
2604
2605   mp_clear (&t);
2606   return res;
2607 }
2608
2609 /* calc a value mod 2**b */
2610 int
2611 mp_mod_2d (const mp_int * a, int b, mp_int * c)
2612 {
2613   int     x, res;
2614
2615   /* if b is <= 0 then zero the int */
2616   if (b <= 0) {
2617     mp_zero (c);
2618     return MP_OKAY;
2619   }
2620
2621   /* if the modulus is larger than the value than return */
2622   if (b > a->used * DIGIT_BIT) {
2623     res = mp_copy (a, c);
2624     return res;
2625   }
2626
2627   /* copy */
2628   if ((res = mp_copy (a, c)) != MP_OKAY) {
2629     return res;
2630   }
2631
2632   /* zero digits above the last digit of the modulus */
2633   for (x = (b / DIGIT_BIT) + ((b % DIGIT_BIT) == 0 ? 0 : 1); x < c->used; x++) {
2634     c->dp[x] = 0;
2635   }
2636   /* clear the digit that is not completely outside/inside the modulus */
2637   c->dp[b / DIGIT_BIT] &= (1 << ((mp_digit)b % DIGIT_BIT)) - 1;
2638   mp_clamp (c);
2639   return MP_OKAY;
2640 }
2641
2642 int
2643 mp_mod_d (const mp_int * a, mp_digit b, mp_digit * c)
2644 {
2645   return mp_div_d(a, b, NULL, c);
2646 }
2647
2648 /*
2649  * shifts with subtractions when the result is greater than b.
2650  *
2651  * The method is slightly modified to shift B unconditionally up to just under
2652  * the leading bit of b.  This saves a lot of multiple precision shifting.
2653  */
2654 int mp_montgomery_calc_normalization (mp_int * a, const mp_int * b)
2655 {
2656   int     x, bits, res;
2657
2658   /* how many bits of last digit does b use */
2659   bits = mp_count_bits (b) % DIGIT_BIT;
2660
2661
2662   if (b->used > 1) {
2663      if ((res = mp_2expt (a, (b->used - 1) * DIGIT_BIT + bits - 1)) != MP_OKAY) {
2664         return res;
2665      }
2666   } else {
2667      mp_set(a, 1);
2668      bits = 1;
2669   }
2670
2671
2672   /* now compute C = A * B mod b */
2673   for (x = bits - 1; x < DIGIT_BIT; x++) {
2674     if ((res = mp_mul_2 (a, a)) != MP_OKAY) {
2675       return res;
2676     }
2677     if (mp_cmp_mag (a, b) != MP_LT) {
2678       if ((res = s_mp_sub (a, b, a)) != MP_OKAY) {
2679         return res;
2680       }
2681     }
2682   }
2683
2684   return MP_OKAY;
2685 }
2686
2687 /* computes xR**-1 == x (mod N) via Montgomery Reduction */
2688 int
2689 mp_montgomery_reduce (mp_int * x, const mp_int * n, mp_digit rho)
2690 {
2691   int     ix, res, digs;
2692   mp_digit mu;
2693
2694   /* can the fast reduction [comba] method be used?
2695    *
2696    * Note that unlike in mul you're safely allowed *less*
2697    * than the available columns [255 per default] since carries
2698    * are fixed up in the inner loop.
2699    */
2700   digs = n->used * 2 + 1;
2701   if ((digs < MP_WARRAY) &&
2702       n->used <
2703       (1 << ((CHAR_BIT * sizeof (mp_word)) - (2 * DIGIT_BIT)))) {
2704     return fast_mp_montgomery_reduce (x, n, rho);
2705   }
2706
2707   /* grow the input as required */
2708   if (x->alloc < digs) {
2709     if ((res = mp_grow (x, digs)) != MP_OKAY) {
2710       return res;
2711     }
2712   }
2713   x->used = digs;
2714
2715   for (ix = 0; ix < n->used; ix++) {
2716     /* mu = ai * rho mod b
2717      *
2718      * The value of rho must be precalculated via
2719      * montgomery_setup() such that
2720      * it equals -1/n0 mod b this allows the
2721      * following inner loop to reduce the
2722      * input one digit at a time
2723      */
2724     mu = (mp_digit) (((mp_word)x->dp[ix]) * ((mp_word)rho) & MP_MASK);
2725
2726     /* a = a + mu * m * b**i */
2727     {
2728       register int iy;
2729       register mp_digit *tmpn, *tmpx, u;
2730       register mp_word r;
2731
2732       /* alias for digits of the modulus */
2733       tmpn = n->dp;
2734
2735       /* alias for the digits of x [the input] */
2736       tmpx = x->dp + ix;
2737
2738       /* set the carry to zero */
2739       u = 0;
2740
2741       /* Multiply and add in place */
2742       for (iy = 0; iy < n->used; iy++) {
2743         /* compute product and sum */
2744         r       = ((mp_word)mu) * ((mp_word)*tmpn++) +
2745                   ((mp_word) u) + ((mp_word) * tmpx);
2746
2747         /* get carry */
2748         u       = (mp_digit)(r >> ((mp_word) DIGIT_BIT));
2749
2750         /* fix digit */
2751         *tmpx++ = (mp_digit)(r & ((mp_word) MP_MASK));
2752       }
2753       /* At this point the ix'th digit of x should be zero */
2754
2755
2756       /* propagate carries upwards as required*/
2757       while (u) {
2758         *tmpx   += u;
2759         u        = *tmpx >> DIGIT_BIT;
2760         *tmpx++ &= MP_MASK;
2761       }
2762     }
2763   }
2764
2765   /* at this point the n.used'th least
2766    * significant digits of x are all zero
2767    * which means we can shift x to the
2768    * right by n.used digits and the
2769    * residue is unchanged.
2770    */
2771
2772   /* x = x/b**n.used */
2773   mp_clamp(x);
2774   mp_rshd (x, n->used);
2775
2776   /* if x >= n then x = x - n */
2777   if (mp_cmp_mag (x, n) != MP_LT) {
2778     return s_mp_sub (x, n, x);
2779   }
2780
2781   return MP_OKAY;
2782 }
2783
2784 /* setups the montgomery reduction stuff */
2785 int
2786 mp_montgomery_setup (const mp_int * n, mp_digit * rho)
2787 {
2788   mp_digit x, b;
2789
2790 /* fast inversion mod 2**k
2791  *
2792  * Based on the fact that
2793  *
2794  * XA = 1 (mod 2**n)  =>  (X(2-XA)) A = 1 (mod 2**2n)
2795  *                    =>  2*X*A - X*X*A*A = 1
2796  *                    =>  2*(1) - (1)     = 1
2797  */
2798   b = n->dp[0];
2799
2800   if ((b & 1) == 0) {
2801     return MP_VAL;
2802   }
2803
2804   x = (((b + 2) & 4) << 1) + b; /* here x*a==1 mod 2**4 */
2805   x *= 2 - b * x;               /* here x*a==1 mod 2**8 */
2806   x *= 2 - b * x;               /* here x*a==1 mod 2**16 */
2807   x *= 2 - b * x;               /* here x*a==1 mod 2**32 */
2808
2809   /* rho = -1/m mod b */
2810   *rho = (((mp_word)1 << ((mp_word) DIGIT_BIT)) - x) & MP_MASK;
2811
2812   return MP_OKAY;
2813 }
2814
2815 /* high level multiplication (handles sign) */
2816 int mp_mul (const mp_int * a, const mp_int * b, mp_int * c)
2817 {
2818   int     res, neg;
2819   neg = (a->sign == b->sign) ? MP_ZPOS : MP_NEG;
2820
2821   /* use Karatsuba? */
2822   if (MIN (a->used, b->used) >= KARATSUBA_MUL_CUTOFF) {
2823     res = mp_karatsuba_mul (a, b, c);
2824   } else 
2825   {
2826     /* can we use the fast multiplier?
2827      *
2828      * The fast multiplier can be used if the output will 
2829      * have less than MP_WARRAY digits and the number of 
2830      * digits won't affect carry propagation
2831      */
2832     int     digs = a->used + b->used + 1;
2833
2834     if ((digs < MP_WARRAY) &&
2835         MIN(a->used, b->used) <= 
2836         (1 << ((CHAR_BIT * sizeof (mp_word)) - (2 * DIGIT_BIT)))) {
2837       res = fast_s_mp_mul_digs (a, b, c, digs);
2838     } else 
2839       res = s_mp_mul (a, b, c); /* uses s_mp_mul_digs */
2840   }
2841   c->sign = (c->used > 0) ? neg : MP_ZPOS;
2842   return res;
2843 }
2844
2845 /* b = a*2 */
2846 int mp_mul_2(const mp_int * a, mp_int * b)
2847 {
2848   int     x, res, oldused;
2849
2850   /* grow to accommodate result */
2851   if (b->alloc < a->used + 1) {
2852     if ((res = mp_grow (b, a->used + 1)) != MP_OKAY) {
2853       return res;
2854     }
2855   }
2856
2857   oldused = b->used;
2858   b->used = a->used;
2859
2860   {
2861     register mp_digit r, rr, *tmpa, *tmpb;
2862
2863     /* alias for source */
2864     tmpa = a->dp;
2865     
2866     /* alias for dest */
2867     tmpb = b->dp;
2868
2869     /* carry */
2870     r = 0;
2871     for (x = 0; x < a->used; x++) {
2872     
2873       /* get what will be the *next* carry bit from the 
2874        * MSB of the current digit 
2875        */
2876       rr = *tmpa >> ((mp_digit)(DIGIT_BIT - 1));
2877       
2878       /* now shift up this digit, add in the carry [from the previous] */
2879       *tmpb++ = ((*tmpa++ << ((mp_digit)1)) | r) & MP_MASK;
2880       
2881       /* copy the carry that would be from the source 
2882        * digit into the next iteration 
2883        */
2884       r = rr;
2885     }
2886
2887     /* new leading digit? */
2888     if (r != 0) {
2889       /* add a MSB which is always 1 at this point */
2890       *tmpb = 1;
2891       ++(b->used);
2892     }
2893
2894     /* now zero any excess digits on the destination 
2895      * that we didn't write to 
2896      */
2897     tmpb = b->dp + b->used;
2898     for (x = b->used; x < oldused; x++) {
2899       *tmpb++ = 0;
2900     }
2901   }
2902   b->sign = a->sign;
2903   return MP_OKAY;
2904 }
2905
2906 /* shift left by a certain bit count */
2907 int mp_mul_2d (const mp_int * a, int b, mp_int * c)
2908 {
2909   mp_digit d;
2910   int      res;
2911
2912   /* copy */
2913   if (a != c) {
2914      if ((res = mp_copy (a, c)) != MP_OKAY) {
2915        return res;
2916      }
2917   }
2918
2919   if (c->alloc < c->used + b/DIGIT_BIT + 1) {
2920      if ((res = mp_grow (c, c->used + b / DIGIT_BIT + 1)) != MP_OKAY) {
2921        return res;
2922      }
2923   }
2924
2925   /* shift by as many digits in the bit count */
2926   if (b >= DIGIT_BIT) {
2927     if ((res = mp_lshd (c, b / DIGIT_BIT)) != MP_OKAY) {
2928       return res;
2929     }
2930   }
2931
2932   /* shift any bit count < DIGIT_BIT */
2933   d = (mp_digit) (b % DIGIT_BIT);
2934   if (d != 0) {
2935     register mp_digit *tmpc, shift, mask, r, rr;
2936     register int x;
2937
2938     /* bitmask for carries */
2939     mask = (((mp_digit)1) << d) - 1;
2940
2941     /* shift for msbs */
2942     shift = DIGIT_BIT - d;
2943
2944     /* alias */
2945     tmpc = c->dp;
2946
2947     /* carry */
2948     r    = 0;
2949     for (x = 0; x < c->used; x++) {
2950       /* get the higher bits of the current word */
2951       rr = (*tmpc >> shift) & mask;
2952
2953       /* shift the current word and OR in the carry */
2954       *tmpc = ((*tmpc << d) | r) & MP_MASK;
2955       ++tmpc;
2956
2957       /* set the carry to the carry bits of the current word */
2958       r = rr;
2959     }
2960     
2961     /* set final carry */
2962     if (r != 0) {
2963        c->dp[(c->used)++] = r;
2964     }
2965   }
2966   mp_clamp (c);
2967   return MP_OKAY;
2968 }
2969
2970 /* multiply by a digit */
2971 int
2972 mp_mul_d (const mp_int * a, mp_digit b, mp_int * c)
2973 {
2974   mp_digit u, *tmpa, *tmpc;
2975   mp_word  r;
2976   int      ix, res, olduse;
2977
2978   /* make sure c is big enough to hold a*b */
2979   if (c->alloc < a->used + 1) {
2980     if ((res = mp_grow (c, a->used + 1)) != MP_OKAY) {
2981       return res;
2982     }
2983   }
2984
2985   /* get the original destinations used count */
2986   olduse = c->used;
2987
2988   /* set the sign */
2989   c->sign = a->sign;
2990
2991   /* alias for a->dp [source] */
2992   tmpa = a->dp;
2993
2994   /* alias for c->dp [dest] */
2995   tmpc = c->dp;
2996
2997   /* zero carry */
2998   u = 0;
2999
3000   /* compute columns */
3001   for (ix = 0; ix < a->used; ix++) {
3002     /* compute product and carry sum for this term */
3003     r       = ((mp_word) u) + ((mp_word)*tmpa++) * ((mp_word)b);
3004
3005     /* mask off higher bits to get a single digit */
3006     *tmpc++ = (mp_digit) (r & ((mp_word) MP_MASK));
3007
3008     /* send carry into next iteration */
3009     u       = (mp_digit) (r >> ((mp_word) DIGIT_BIT));
3010   }
3011
3012   /* store final carry [if any] */
3013   *tmpc++ = u;
3014
3015   /* now zero digits above the top */
3016   while (ix++ < olduse) {
3017      *tmpc++ = 0;
3018   }
3019
3020   /* set used count */
3021   c->used = a->used + 1;
3022   mp_clamp(c);
3023
3024   return MP_OKAY;
3025 }
3026
3027 /* d = a * b (mod c) */
3028 int
3029 mp_mulmod (const mp_int * a, const mp_int * b, mp_int * c, mp_int * d)
3030 {
3031   int     res;
3032   mp_int  t;
3033
3034   if ((res = mp_init (&t)) != MP_OKAY) {
3035     return res;
3036   }
3037
3038   if ((res = mp_mul (a, b, &t)) != MP_OKAY) {
3039     mp_clear (&t);
3040     return res;
3041   }
3042   res = mp_mod (&t, c, d);
3043   mp_clear (&t);
3044   return res;
3045 }
3046
3047 /* table of first PRIME_SIZE primes */
3048 static const mp_digit __prime_tab[] = {
3049   0x0002, 0x0003, 0x0005, 0x0007, 0x000B, 0x000D, 0x0011, 0x0013,
3050   0x0017, 0x001D, 0x001F, 0x0025, 0x0029, 0x002B, 0x002F, 0x0035,
3051   0x003B, 0x003D, 0x0043, 0x0047, 0x0049, 0x004F, 0x0053, 0x0059,
3052   0x0061, 0x0065, 0x0067, 0x006B, 0x006D, 0x0071, 0x007F, 0x0083,
3053   0x0089, 0x008B, 0x0095, 0x0097, 0x009D, 0x00A3, 0x00A7, 0x00AD,
3054   0x00B3, 0x00B5, 0x00BF, 0x00C1, 0x00C5, 0x00C7, 0x00D3, 0x00DF,
3055   0x00E3, 0x00E5, 0x00E9, 0x00EF, 0x00F1, 0x00FB, 0x0101, 0x0107,
3056   0x010D, 0x010F, 0x0115, 0x0119, 0x011B, 0x0125, 0x0133, 0x0137,
3057
3058   0x0139, 0x013D, 0x014B, 0x0151, 0x015B, 0x015D, 0x0161, 0x0167,
3059   0x016F, 0x0175, 0x017B, 0x017F, 0x0185, 0x018D, 0x0191, 0x0199,
3060   0x01A3, 0x01A5, 0x01AF, 0x01B1, 0x01B7, 0x01BB, 0x01C1, 0x01C9,
3061   0x01CD, 0x01CF, 0x01D3, 0x01DF, 0x01E7, 0x01EB, 0x01F3, 0x01F7,
3062   0x01FD, 0x0209, 0x020B, 0x021D, 0x0223, 0x022D, 0x0233, 0x0239,
3063   0x023B, 0x0241, 0x024B, 0x0251, 0x0257, 0x0259, 0x025F, 0x0265,
3064   0x0269, 0x026B, 0x0277, 0x0281, 0x0283, 0x0287, 0x028D, 0x0293,
3065   0x0295, 0x02A1, 0x02A5, 0x02AB, 0x02B3, 0x02BD, 0x02C5, 0x02CF,
3066
3067   0x02D7, 0x02DD, 0x02E3, 0x02E7, 0x02EF, 0x02F5, 0x02F9, 0x0301,
3068   0x0305, 0x0313, 0x031D, 0x0329, 0x032B, 0x0335, 0x0337, 0x033B,
3069   0x033D, 0x0347, 0x0355, 0x0359, 0x035B, 0x035F, 0x036D, 0x0371,
3070   0x0373, 0x0377, 0x038B, 0x038F, 0x0397, 0x03A1, 0x03A9, 0x03AD,
3071   0x03B3, 0x03B9, 0x03C7, 0x03CB, 0x03D1, 0x03D7, 0x03DF, 0x03E5,
3072   0x03F1, 0x03F5, 0x03FB, 0x03FD, 0x0407, 0x0409, 0x040F, 0x0419,
3073   0x041B, 0x0425, 0x0427, 0x042D, 0x043F, 0x0443, 0x0445, 0x0449,
3074   0x044F, 0x0455, 0x045D, 0x0463, 0x0469, 0x047F, 0x0481, 0x048B,
3075
3076   0x0493, 0x049D, 0x04A3, 0x04A9, 0x04B1, 0x04BD, 0x04C1, 0x04C7,
3077   0x04CD, 0x04CF, 0x04D5, 0x04E1, 0x04EB, 0x04FD, 0x04FF, 0x0503,
3078   0x0509, 0x050B, 0x0511, 0x0515, 0x0517, 0x051B, 0x0527, 0x0529,
3079   0x052F, 0x0551, 0x0557, 0x055D, 0x0565, 0x0577, 0x0581, 0x058F,
3080   0x0593, 0x0595, 0x0599, 0x059F, 0x05A7, 0x05AB, 0x05AD, 0x05B3,
3081   0x05BF, 0x05C9, 0x05CB, 0x05CF, 0x05D1, 0x05D5, 0x05DB, 0x05E7,
3082   0x05F3, 0x05FB, 0x0607, 0x060D, 0x0611, 0x0617, 0x061F, 0x0623,
3083   0x062B, 0x062F, 0x063D, 0x0641, 0x0647, 0x0649, 0x064D, 0x0653
3084 };
3085
3086 /* determines if an integers is divisible by one 
3087  * of the first PRIME_SIZE primes or not
3088  *
3089  * sets result to 0 if not, 1 if yes
3090  */
3091 int mp_prime_is_divisible (const mp_int * a, int *result)
3092 {
3093   int     err, ix;
3094   mp_digit res;
3095
3096   /* default to not */
3097   *result = MP_NO;
3098
3099   for (ix = 0; ix < PRIME_SIZE; ix++) {
3100     /* what is a mod __prime_tab[ix] */
3101     if ((err = mp_mod_d (a, __prime_tab[ix], &res)) != MP_OKAY) {
3102       return err;
3103     }
3104
3105     /* is the residue zero? */
3106     if (res == 0) {
3107       *result = MP_YES;
3108       return MP_OKAY;
3109     }
3110   }
3111
3112   return MP_OKAY;
3113 }
3114
3115 /* performs a variable number of rounds of Miller-Rabin
3116  *
3117  * Probability of error after t rounds is no more than
3118
3119  *
3120  * Sets result to 1 if probably prime, 0 otherwise
3121  */
3122 int mp_prime_is_prime (mp_int * a, int t, int *result)
3123 {
3124   mp_int  b;
3125   int     ix, err, res;
3126
3127   /* default to no */
3128   *result = MP_NO;
3129
3130   /* valid value of t? */
3131   if (t <= 0 || t > PRIME_SIZE) {
3132     return MP_VAL;
3133   }
3134
3135   /* is the input equal to one of the primes in the table? */
3136   for (ix = 0; ix < PRIME_SIZE; ix++) {
3137       if (mp_cmp_d(a, __prime_tab[ix]) == MP_EQ) {
3138          *result = 1;
3139          return MP_OKAY;
3140       }
3141   }
3142
3143   /* first perform trial division */
3144   if ((err = mp_prime_is_divisible (a, &res)) != MP_OKAY) {
3145     return err;
3146   }
3147
3148   /* return if it was trivially divisible */
3149   if (res == MP_YES) {
3150     return MP_OKAY;
3151   }
3152
3153   /* now perform the miller-rabin rounds */
3154   if ((err = mp_init (&b)) != MP_OKAY) {
3155     return err;
3156   }
3157
3158   for (ix = 0; ix < t; ix++) {
3159     /* set the prime */
3160     mp_set (&b, __prime_tab[ix]);
3161
3162     if ((err = mp_prime_miller_rabin (a, &b, &res)) != MP_OKAY) {
3163       goto __B;
3164     }
3165
3166     if (res == MP_NO) {
3167       goto __B;
3168     }
3169   }
3170
3171   /* passed the test */
3172   *result = MP_YES;
3173 __B:mp_clear (&b);
3174   return err;
3175 }
3176
3177 /* Miller-Rabin test of "a" to the base of "b" as described in 
3178  * HAC pp. 139 Algorithm 4.24
3179  *
3180  * Sets result to 0 if definitely composite or 1 if probably prime.
3181  * Randomly the chance of error is no more than 1/4 and often 
3182  * very much lower.
3183  */
3184 int mp_prime_miller_rabin (mp_int * a, const mp_int * b, int *result)
3185 {
3186   mp_int  n1, y, r;
3187   int     s, j, err;
3188
3189   /* default */
3190   *result = MP_NO;
3191
3192   /* ensure b > 1 */
3193   if (mp_cmp_d(b, 1) != MP_GT) {
3194      return MP_VAL;
3195   }     
3196
3197   /* get n1 = a - 1 */
3198   if ((err = mp_init_copy (&n1, a)) != MP_OKAY) {
3199     return err;
3200   }
3201   if ((err = mp_sub_d (&n1, 1, &n1)) != MP_OKAY) {
3202     goto __N1;
3203   }
3204
3205   /* set 2**s * r = n1 */
3206   if ((err = mp_init_copy (&r, &n1)) != MP_OKAY) {
3207     goto __N1;
3208   }
3209
3210   /* count the number of least significant bits
3211    * which are zero
3212    */
3213   s = mp_cnt_lsb(&r);
3214
3215   /* now divide n - 1 by 2**s */
3216   if ((err = mp_div_2d (&r, s, &r, NULL)) != MP_OKAY) {
3217     goto __R;
3218   }
3219
3220   /* compute y = b**r mod a */
3221   if ((err = mp_init (&y)) != MP_OKAY) {
3222     goto __R;
3223   }
3224   if ((err = mp_exptmod (b, &r, a, &y)) != MP_OKAY) {
3225     goto __Y;
3226   }
3227
3228   /* if y != 1 and y != n1 do */
3229   if (mp_cmp_d (&y, 1) != MP_EQ && mp_cmp (&y, &n1) != MP_EQ) {
3230     j = 1;
3231     /* while j <= s-1 and y != n1 */
3232     while ((j <= (s - 1)) && mp_cmp (&y, &n1) != MP_EQ) {
3233       if ((err = mp_sqrmod (&y, a, &y)) != MP_OKAY) {
3234          goto __Y;
3235       }
3236
3237       /* if y == 1 then composite */
3238       if (mp_cmp_d (&y, 1) == MP_EQ) {
3239          goto __Y;
3240       }
3241
3242       ++j;
3243     }
3244
3245     /* if y != n1 then composite */
3246     if (mp_cmp (&y, &n1) != MP_EQ) {
3247       goto __Y;
3248     }
3249   }
3250
3251   /* probably prime now */
3252   *result = MP_YES;
3253 __Y:mp_clear (&y);
3254 __R:mp_clear (&r);
3255 __N1:mp_clear (&n1);
3256   return err;
3257 }
3258
3259 static const struct {
3260    int k, t;
3261 } sizes[] = {
3262 {   128,    28 },
3263 {   256,    16 },
3264 {   384,    10 },
3265 {   512,     7 },
3266 {   640,     6 },
3267 {   768,     5 },
3268 {   896,     4 },
3269 {  1024,     4 }
3270 };
3271
3272 /* returns # of RM trials required for a given bit size */
3273 int mp_prime_rabin_miller_trials(int size)
3274 {
3275    int x;
3276
3277    for (x = 0; x < (int)(sizeof(sizes)/(sizeof(sizes[0]))); x++) {
3278        if (sizes[x].k == size) {
3279           return sizes[x].t;
3280        } else if (sizes[x].k > size) {
3281           return (x == 0) ? sizes[0].t : sizes[x - 1].t;
3282        }
3283    }
3284    return sizes[x-1].t + 1;
3285 }
3286
3287 /* makes a truly random prime of a given size (bits),
3288  *
3289  * Flags are as follows:
3290  * 
3291  *   LTM_PRIME_BBS      - make prime congruent to 3 mod 4
3292  *   LTM_PRIME_SAFE     - make sure (p-1)/2 is prime as well (implies LTM_PRIME_BBS)
3293  *   LTM_PRIME_2MSB_OFF - make the 2nd highest bit zero
3294  *   LTM_PRIME_2MSB_ON  - make the 2nd highest bit one
3295  *
3296  * You have to supply a callback which fills in a buffer with random bytes.  "dat" is a parameter you can
3297  * have passed to the callback (e.g. a state or something).  This function doesn't use "dat" itself
3298  * so it can be NULL
3299  *
3300  */
3301
3302 /* This is possibly the mother of all prime generation functions, muahahahahaha! */
3303 int mp_prime_random_ex(mp_int *a, int t, int size, int flags, ltm_prime_callback cb, void *dat)
3304 {
3305    unsigned char *tmp, maskAND, maskOR_msb, maskOR_lsb;
3306    int res, err, bsize, maskOR_msb_offset;
3307
3308    /* sanity check the input */
3309    if (size <= 1 || t <= 0) {
3310       return MP_VAL;
3311    }
3312
3313    /* LTM_PRIME_SAFE implies LTM_PRIME_BBS */
3314    if (flags & LTM_PRIME_SAFE) {
3315       flags |= LTM_PRIME_BBS;
3316    }
3317
3318    /* calc the byte size */
3319    bsize = (size>>3)+((size&7)?1:0);
3320
3321    /* we need a buffer of bsize bytes */
3322    tmp = malloc(bsize);
3323    if (tmp == NULL) {
3324       return MP_MEM;
3325    }
3326
3327    /* calc the maskAND value for the MSbyte*/
3328    maskAND = ((size&7) == 0) ? 0xFF : (0xFF >> (8 - (size & 7))); 
3329
3330    /* calc the maskOR_msb */
3331    maskOR_msb        = 0;
3332    maskOR_msb_offset = ((size & 7) == 1) ? 1 : 0;
3333    if (flags & LTM_PRIME_2MSB_ON) {
3334       maskOR_msb     |= 1 << ((size - 2) & 7);
3335    } else if (flags & LTM_PRIME_2MSB_OFF) {
3336       maskAND        &= ~(1 << ((size - 2) & 7));
3337    }
3338
3339    /* get the maskOR_lsb */
3340    maskOR_lsb         = 0;
3341    if (flags & LTM_PRIME_BBS) {
3342       maskOR_lsb     |= 3;
3343    }
3344
3345    do {
3346       /* read the bytes */
3347       if (cb(tmp, bsize, dat) != bsize) {
3348          err = MP_VAL;
3349          goto error;
3350       }
3351  
3352       /* work over the MSbyte */
3353       tmp[0]    &= maskAND;
3354       tmp[0]    |= 1 << ((size - 1) & 7);
3355
3356       /* mix in the maskORs */
3357       tmp[maskOR_msb_offset]   |= maskOR_msb;
3358       tmp[bsize-1]             |= maskOR_lsb;
3359
3360       /* read it in */
3361       if ((err = mp_read_unsigned_bin(a, tmp, bsize)) != MP_OKAY)     { goto error; }
3362
3363       /* is it prime? */
3364       if ((err = mp_prime_is_prime(a, t, &res)) != MP_OKAY)           { goto error; }
3365       if (res == MP_NO) {  
3366          continue;
3367       }
3368
3369       if (flags & LTM_PRIME_SAFE) {
3370          /* see if (a-1)/2 is prime */
3371          if ((err = mp_sub_d(a, 1, a)) != MP_OKAY)                    { goto error; }
3372          if ((err = mp_div_2(a, a)) != MP_OKAY)                       { goto error; }
3373  
3374          /* is it prime? */
3375          if ((err = mp_prime_is_prime(a, t, &res)) != MP_OKAY)        { goto error; }
3376       }
3377    } while (res == MP_NO);
3378
3379    if (flags & LTM_PRIME_SAFE) {
3380       /* restore a to the original value */
3381       if ((err = mp_mul_2(a, a)) != MP_OKAY)                          { goto error; }
3382       if ((err = mp_add_d(a, 1, a)) != MP_OKAY)                       { goto error; }
3383    }
3384
3385    err = MP_OKAY;
3386 error:
3387    free(tmp);
3388    return err;
3389 }
3390
3391 /* reads an unsigned char array, assumes the msb is stored first [big endian] */
3392 int
3393 mp_read_unsigned_bin (mp_int * a, const unsigned char *b, int c)
3394 {
3395   int     res;
3396
3397   /* make sure there are at least two digits */
3398   if (a->alloc < 2) {
3399      if ((res = mp_grow(a, 2)) != MP_OKAY) {
3400         return res;
3401      }
3402   }
3403
3404   /* zero the int */
3405   mp_zero (a);
3406
3407   /* read the bytes in */
3408   while (c-- > 0) {
3409     if ((res = mp_mul_2d (a, 8, a)) != MP_OKAY) {
3410       return res;
3411     }
3412
3413       a->dp[0] |= *b++;
3414       a->used += 1;
3415   }
3416   mp_clamp (a);
3417   return MP_OKAY;
3418 }
3419
3420 /* reduces x mod m, assumes 0 < x < m**2, mu is 
3421  * precomputed via mp_reduce_setup.
3422  * From HAC pp.604 Algorithm 14.42
3423  */
3424 int
3425 mp_reduce (mp_int * x, const mp_int * m, const mp_int * mu)
3426 {
3427   mp_int  q;
3428   int     res, um = m->used;
3429
3430   /* q = x */
3431   if ((res = mp_init_copy (&q, x)) != MP_OKAY) {
3432     return res;
3433   }
3434
3435   /* q1 = x / b**(k-1)  */
3436   mp_rshd (&q, um - 1);         
3437
3438   /* according to HAC this optimization is ok */
3439   if (((unsigned long) um) > (((mp_digit)1) << (DIGIT_BIT - 1))) {
3440     if ((res = mp_mul (&q, mu, &q)) != MP_OKAY) {
3441       goto CLEANUP;
3442     }
3443   } else {
3444     if ((res = s_mp_mul_high_digs (&q, mu, &q, um - 1)) != MP_OKAY) {
3445       goto CLEANUP;
3446     }
3447   }
3448
3449   /* q3 = q2 / b**(k+1) */
3450   mp_rshd (&q, um + 1);         
3451
3452   /* x = x mod b**(k+1), quick (no division) */
3453   if ((res = mp_mod_2d (x, DIGIT_BIT * (um + 1), x)) != MP_OKAY) {
3454     goto CLEANUP;
3455   }
3456
3457   /* q = q * m mod b**(k+1), quick (no division) */
3458   if ((res = s_mp_mul_digs (&q, m, &q, um + 1)) != MP_OKAY) {
3459     goto CLEANUP;
3460   }
3461
3462   /* x = x - q */
3463   if ((res = mp_sub (x, &q, x)) != MP_OKAY) {
3464     goto CLEANUP;
3465   }
3466
3467   /* If x < 0, add b**(k+1) to it */
3468   if (mp_cmp_d (x, 0) == MP_LT) {
3469     mp_set (&q, 1);
3470     if ((res = mp_lshd (&q, um + 1)) != MP_OKAY)
3471       goto CLEANUP;
3472     if ((res = mp_add (x, &q, x)) != MP_OKAY)
3473       goto CLEANUP;
3474   }
3475
3476   /* Back off if it's too big */
3477   while (mp_cmp (x, m) != MP_LT) {
3478     if ((res = s_mp_sub (x, m, x)) != MP_OKAY) {
3479       goto CLEANUP;
3480     }
3481   }
3482   
3483 CLEANUP:
3484   mp_clear (&q);
3485
3486   return res;
3487 }
3488
3489 /* reduces a modulo n where n is of the form 2**p - d */
3490 int
3491 mp_reduce_2k(mp_int *a, const mp_int *n, mp_digit d)
3492 {
3493    mp_int q;
3494    int    p, res;
3495    
3496    if ((res = mp_init(&q)) != MP_OKAY) {
3497       return res;
3498    }
3499    
3500    p = mp_count_bits(n);    
3501 top:
3502    /* q = a/2**p, a = a mod 2**p */
3503    if ((res = mp_div_2d(a, p, &q, a)) != MP_OKAY) {
3504       goto ERR;
3505    }
3506    
3507    if (d != 1) {
3508       /* q = q * d */
3509       if ((res = mp_mul_d(&q, d, &q)) != MP_OKAY) { 
3510          goto ERR;
3511       }
3512    }
3513    
3514    /* a = a + q */
3515    if ((res = s_mp_add(a, &q, a)) != MP_OKAY) {
3516       goto ERR;
3517    }
3518    
3519    if (mp_cmp_mag(a, n) != MP_LT) {
3520       s_mp_sub(a, n, a);
3521       goto top;
3522    }
3523    
3524 ERR:
3525    mp_clear(&q);
3526    return res;
3527 }
3528
3529 /* determines the setup value */
3530 int 
3531 mp_reduce_2k_setup(const mp_int *a, mp_digit *d)
3532 {
3533    int res, p;
3534    mp_int tmp;
3535    
3536    if ((res = mp_init(&tmp)) != MP_OKAY) {
3537       return res;
3538    }
3539    
3540    p = mp_count_bits(a);
3541    if ((res = mp_2expt(&tmp, p)) != MP_OKAY) {
3542       mp_clear(&tmp);
3543       return res;
3544    }
3545    
3546    if ((res = s_mp_sub(&tmp, a, &tmp)) != MP_OKAY) {
3547       mp_clear(&tmp);
3548       return res;
3549    }
3550    
3551    *d = tmp.dp[0];
3552    mp_clear(&tmp);
3553    return MP_OKAY;
3554 }
3555
3556 /* pre-calculate the value required for Barrett reduction
3557  * For a given modulus "b" it calulates the value required in "a"
3558  */
3559 int mp_reduce_setup (mp_int * a, const mp_int * b)
3560 {
3561   int     res;
3562
3563   if ((res = mp_2expt (a, b->used * 2 * DIGIT_BIT)) != MP_OKAY) {
3564     return res;
3565   }
3566   return mp_div (a, b, a, NULL);
3567 }
3568
3569 /* shift right a certain amount of digits */
3570 void mp_rshd (mp_int * a, int b)
3571 {
3572   int     x;
3573
3574   /* if b <= 0 then ignore it */
3575   if (b <= 0) {
3576     return;
3577   }
3578
3579   /* if b > used then simply zero it and return */
3580   if (a->used <= b) {
3581     mp_zero (a);
3582     return;
3583   }
3584
3585   {
3586     register mp_digit *bottom, *top;
3587
3588     /* shift the digits down */
3589
3590     /* bottom */
3591     bottom = a->dp;
3592
3593     /* top [offset into digits] */
3594     top = a->dp + b;
3595
3596     /* this is implemented as a sliding window where 
3597      * the window is b-digits long and digits from 
3598      * the top of the window are copied to the bottom
3599      *
3600      * e.g.
3601
3602      b-2 | b-1 | b0 | b1 | b2 | ... | bb |   ---->
3603                  /\                   |      ---->
3604                   \-------------------/      ---->
3605      */
3606     for (x = 0; x < (a->used - b); x++) {
3607       *bottom++ = *top++;
3608     }
3609
3610     /* zero the top digits */
3611     for (; x < a->used; x++) {
3612       *bottom++ = 0;
3613     }
3614   }
3615   
3616   /* remove excess digits */
3617   a->used -= b;
3618 }
3619
3620 /* set to a digit */
3621 void mp_set (mp_int * a, mp_digit b)
3622 {
3623   mp_zero (a);
3624   a->dp[0] = b & MP_MASK;
3625   a->used  = (a->dp[0] != 0) ? 1 : 0;
3626 }
3627
3628 /* set a 32-bit const */
3629 int mp_set_int (mp_int * a, unsigned long b)
3630 {
3631   int     x, res;
3632
3633   mp_zero (a);
3634   
3635   /* set four bits at a time */
3636   for (x = 0; x < 8; x++) {
3637     /* shift the number up four bits */
3638     if ((res = mp_mul_2d (a, 4, a)) != MP_OKAY) {
3639       return res;
3640     }
3641
3642     /* OR in the top four bits of the source */
3643     a->dp[0] |= (b >> 28) & 15;
3644
3645     /* shift the source up to the next four bits */
3646     b <<= 4;
3647
3648     /* ensure that digits are not clamped off */
3649     a->used += 1;
3650   }
3651   mp_clamp (a);
3652   return MP_OKAY;
3653 }
3654
3655 /* shrink a bignum */
3656 int mp_shrink (mp_int * a)
3657 {
3658   mp_digit *tmp;
3659   if (a->alloc != a->used && a->used > 0) {
3660     if ((tmp = realloc (a->dp, sizeof (mp_digit) * a->used)) == NULL) {
3661       return MP_MEM;
3662     }
3663     a->dp    = tmp;
3664     a->alloc = a->used;
3665   }
3666   return MP_OKAY;
3667 }
3668
3669 /* get the size for an signed equivalent */
3670 int mp_signed_bin_size (const mp_int * a)
3671 {
3672   return 1 + mp_unsigned_bin_size (a);
3673 }
3674
3675 /* computes b = a*a */
3676 int
3677 mp_sqr (const mp_int * a, mp_int * b)
3678 {
3679   int     res;
3680
3681 if (a->used >= KARATSUBA_SQR_CUTOFF) {
3682     res = mp_karatsuba_sqr (a, b);
3683   } else 
3684   {
3685     /* can we use the fast comba multiplier? */
3686     if ((a->used * 2 + 1) < MP_WARRAY && 
3687          a->used < 
3688          (1 << (sizeof(mp_word) * CHAR_BIT - 2*DIGIT_BIT - 1))) {
3689       res = fast_s_mp_sqr (a, b);
3690     } else
3691       res = s_mp_sqr (a, b);
3692   }
3693   b->sign = MP_ZPOS;
3694   return res;
3695 }
3696
3697 /* c = a * a (mod b) */
3698 int
3699 mp_sqrmod (const mp_int * a, mp_int * b, mp_int * c)
3700 {
3701   int     res;
3702   mp_int  t;
3703
3704   if ((res = mp_init (&t)) != MP_OKAY) {
3705     return res;
3706   }
3707
3708   if ((res = mp_sqr (a, &t)) != MP_OKAY) {
3709     mp_clear (&t);
3710     return res;
3711   }
3712   res = mp_mod (&t, b, c);
3713   mp_clear (&t);
3714   return res;
3715 }
3716
3717 /* high level subtraction (handles signs) */
3718 int
3719 mp_sub (mp_int * a, mp_int * b, mp_int * c)
3720 {
3721   int     sa, sb, res;
3722
3723   sa = a->sign;
3724   sb = b->sign;
3725
3726   if (sa != sb) {
3727     /* subtract a negative from a positive, OR */
3728     /* subtract a positive from a negative. */
3729     /* In either case, ADD their magnitudes, */
3730     /* and use the sign of the first number. */
3731     c->sign = sa;
3732     res = s_mp_add (a, b, c);
3733   } else {
3734     /* subtract a positive from a positive, OR */
3735     /* subtract a negative from a negative. */
3736     /* First, take the difference between their */
3737     /* magnitudes, then... */
3738     if (mp_cmp_mag (a, b) != MP_LT) {
3739       /* Copy the sign from the first */
3740       c->sign = sa;
3741       /* The first has a larger or equal magnitude */
3742       res = s_mp_sub (a, b, c);
3743     } else {
3744       /* The result has the *opposite* sign from */
3745       /* the first number. */
3746       c->sign = (sa == MP_ZPOS) ? MP_NEG : MP_ZPOS;
3747       /* The second has a larger magnitude */
3748       res = s_mp_sub (b, a, c);
3749     }
3750   }
3751   return res;
3752 }
3753
3754 /* single digit subtraction */
3755 int
3756 mp_sub_d (mp_int * a, mp_digit b, mp_int * c)
3757 {
3758   mp_digit *tmpa, *tmpc, mu;
3759   int       res, ix, oldused;
3760
3761   /* grow c as required */
3762   if (c->alloc < a->used + 1) {
3763      if ((res = mp_grow(c, a->used + 1)) != MP_OKAY) {
3764         return res;
3765      }
3766   }
3767
3768   /* if a is negative just do an unsigned
3769    * addition [with fudged signs]
3770    */
3771   if (a->sign == MP_NEG) {
3772      a->sign = MP_ZPOS;
3773      res     = mp_add_d(a, b, c);
3774      a->sign = c->sign = MP_NEG;
3775      return res;
3776   }
3777
3778   /* setup regs */
3779   oldused = c->used;
3780   tmpa    = a->dp;
3781   tmpc    = c->dp;
3782
3783   /* if a <= b simply fix the single digit */
3784   if ((a->used == 1 && a->dp[0] <= b) || a->used == 0) {
3785      if (a->used == 1) {
3786         *tmpc++ = b - *tmpa;
3787      } else {
3788         *tmpc++ = b;
3789      }
3790      ix      = 1;
3791
3792      /* negative/1digit */
3793      c->sign = MP_NEG;
3794      c->used = 1;
3795   } else {
3796      /* positive/size */
3797      c->sign = MP_ZPOS;
3798      c->used = a->used;
3799
3800      /* subtract first digit */
3801      *tmpc    = *tmpa++ - b;
3802      mu       = *tmpc >> (sizeof(mp_digit) * CHAR_BIT - 1);
3803      *tmpc++ &= MP_MASK;
3804
3805      /* handle rest of the digits */
3806      for (ix = 1; ix < a->used; ix++) {
3807         *tmpc    = *tmpa++ - mu;
3808         mu       = *tmpc >> (sizeof(mp_digit) * CHAR_BIT - 1);
3809         *tmpc++ &= MP_MASK;
3810      }
3811   }
3812
3813   /* zero excess digits */
3814   while (ix++ < oldused) {
3815      *tmpc++ = 0;
3816   }
3817   mp_clamp(c);
3818   return MP_OKAY;
3819 }
3820
3821 /* store in unsigned [big endian] format */
3822 int
3823 mp_to_unsigned_bin (const mp_int * a, unsigned char *b)
3824 {
3825   int     x, res;
3826   mp_int  t;
3827
3828   if ((res = mp_init_copy (&t, a)) != MP_OKAY) {
3829     return res;
3830   }
3831
3832   x = 0;
3833   while (mp_iszero (&t) == 0) {
3834     b[x++] = (unsigned char) (t.dp[0] & 255);
3835     if ((res = mp_div_2d (&t, 8, &t, NULL)) != MP_OKAY) {
3836       mp_clear (&t);
3837       return res;
3838     }
3839   }
3840   bn_reverse (b, x);
3841   mp_clear (&t);
3842   return MP_OKAY;
3843 }
3844
3845 /* get the size for an unsigned equivalent */
3846 int
3847 mp_unsigned_bin_size (const mp_int * a)
3848 {
3849   int     size = mp_count_bits (a);
3850   return (size / 8 + ((size & 7) != 0 ? 1 : 0));
3851 }
3852
3853 /* set to zero */
3854 void
3855 mp_zero (mp_int * a)
3856 {
3857   a->sign = MP_ZPOS;
3858   a->used = 0;
3859   memset (a->dp, 0, sizeof (mp_digit) * a->alloc);
3860 }
3861
3862 /* reverse an array, used for radix code */
3863 static void
3864 bn_reverse (unsigned char *s, int len)
3865 {
3866   int     ix, iy;
3867   unsigned char t;
3868
3869   ix = 0;
3870   iy = len - 1;
3871   while (ix < iy) {
3872     t     = s[ix];
3873     s[ix] = s[iy];
3874     s[iy] = t;
3875     ++ix;
3876     --iy;
3877   }
3878 }
3879
3880 /* low level addition, based on HAC pp.594, Algorithm 14.7 */
3881 static int
3882 s_mp_add (mp_int * a, mp_int * b, mp_int * c)
3883 {
3884   mp_int *x;
3885   int     olduse, res, min, max;
3886
3887   /* find sizes, we let |a| <= |b| which means we have to sort
3888    * them.  "x" will point to the input with the most digits
3889    */
3890   if (a->used > b->used) {
3891     min = b->used;
3892     max = a->used;
3893     x = a;
3894   } else {
3895     min = a->used;
3896     max = b->used;
3897     x = b;
3898   }
3899
3900   /* init result */
3901   if (c->alloc < max + 1) {
3902     if ((res = mp_grow (c, max + 1)) != MP_OKAY) {
3903       return res;
3904     }
3905   }
3906
3907   /* get old used digit count and set new one */
3908   olduse = c->used;
3909   c->used = max + 1;
3910
3911   {
3912     register mp_digit u, *tmpa, *tmpb, *tmpc;
3913     register int i;
3914
3915     /* alias for digit pointers */
3916
3917     /* first input */
3918     tmpa = a->dp;
3919
3920     /* second input */
3921     tmpb = b->dp;
3922
3923     /* destination */
3924     tmpc = c->dp;
3925
3926     /* zero the carry */
3927     u = 0;
3928     for (i = 0; i < min; i++) {
3929       /* Compute the sum at one digit, T[i] = A[i] + B[i] + U */
3930       *tmpc = *tmpa++ + *tmpb++ + u;
3931
3932       /* U = carry bit of T[i] */
3933       u = *tmpc >> ((mp_digit)DIGIT_BIT);
3934
3935       /* take away carry bit from T[i] */
3936       *tmpc++ &= MP_MASK;
3937     }
3938
3939     /* now copy higher words if any, that is in A+B 
3940      * if A or B has more digits add those in 
3941      */
3942     if (min != max) {
3943       for (; i < max; i++) {
3944         /* T[i] = X[i] + U */
3945         *tmpc = x->dp[i] + u;
3946
3947         /* U = carry bit of T[i] */
3948         u = *tmpc >> ((mp_digit)DIGIT_BIT);
3949
3950         /* take away carry bit from T[i] */
3951         *tmpc++ &= MP_MASK;
3952       }
3953     }
3954
3955     /* add carry */
3956     *tmpc++ = u;
3957
3958     /* clear digits above oldused */
3959     for (i = c->used; i < olduse; i++) {
3960       *tmpc++ = 0;
3961     }
3962   }
3963
3964   mp_clamp (c);
3965   return MP_OKAY;
3966 }
3967
3968 static int s_mp_exptmod (const mp_int * G, const mp_int * X, mp_int * P, mp_int * Y)
3969 {
3970   mp_int  M[256], res, mu;
3971   mp_digit buf;
3972   int     err, bitbuf, bitcpy, bitcnt, mode, digidx, x, y, winsize;
3973
3974   /* find window size */
3975   x = mp_count_bits (X);
3976   if (x <= 7) {
3977     winsize = 2;
3978   } else if (x <= 36) {
3979     winsize = 3;
3980   } else if (x <= 140) {
3981     winsize = 4;
3982   } else if (x <= 450) {
3983     winsize = 5;
3984   } else if (x <= 1303) {
3985     winsize = 6;
3986   } else if (x <= 3529) {
3987     winsize = 7;
3988   } else {
3989     winsize = 8;
3990   }
3991
3992   /* init M array */
3993   /* init first cell */
3994   if ((err = mp_init(&M[1])) != MP_OKAY) {
3995      return err; 
3996   }
3997
3998   /* now init the second half of the array */
3999   for (x = 1<<(winsize-1); x < (1 << winsize); x++) {
4000     if ((err = mp_init(&M[x])) != MP_OKAY) {
4001       for (y = 1<<(winsize-1); y < x; y++) {
4002         mp_clear (&M[y]);
4003       }
4004       mp_clear(&M[1]);
4005       return err;
4006     }
4007   }
4008
4009   /* create mu, used for Barrett reduction */
4010   if ((err = mp_init (&mu)) != MP_OKAY) {
4011     goto __M;
4012   }
4013   if ((err = mp_reduce_setup (&mu, P)) != MP_OKAY) {
4014     goto __MU;
4015   }
4016
4017   /* create M table
4018    *
4019    * The M table contains powers of the base, 
4020    * e.g. M[x] = G**x mod P
4021    *
4022    * The first half of the table is not 
4023    * computed though accept for M[0] and M[1]
4024    */
4025   if ((err = mp_mod (G, P, &M[1])) != MP_OKAY) {
4026     goto __MU;
4027   }
4028
4029   /* compute the value at M[1<<(winsize-1)] by squaring 
4030    * M[1] (winsize-1) times 
4031    */
4032   if ((err = mp_copy (&M[1], &M[1 << (winsize - 1)])) != MP_OKAY) {
4033     goto __MU;
4034   }
4035
4036   for (x = 0; x < (winsize - 1); x++) {
4037     if ((err = mp_sqr (&M[1 << (winsize - 1)], 
4038                        &M[1 << (winsize - 1)])) != MP_OKAY) {
4039       goto __MU;
4040     }
4041     if ((err = mp_reduce (&M[1 << (winsize - 1)], P, &mu)) != MP_OKAY) {
4042       goto __MU;
4043     }
4044   }
4045
4046   /* create upper table, that is M[x] = M[x-1] * M[1] (mod P)
4047    * for x = (2**(winsize - 1) + 1) to (2**winsize - 1)
4048    */
4049   for (x = (1 << (winsize - 1)) + 1; x < (1 << winsize); x++) {
4050     if ((err = mp_mul (&M[x - 1], &M[1], &M[x])) != MP_OKAY) {
4051       goto __MU;
4052     }
4053     if ((err = mp_reduce (&M[x], P, &mu)) != MP_OKAY) {
4054       goto __MU;
4055     }
4056   }
4057
4058   /* setup result */
4059   if ((err = mp_init (&res)) != MP_OKAY) {
4060     goto __MU;
4061   }
4062   mp_set (&res, 1);
4063
4064   /* set initial mode and bit cnt */
4065   mode   = 0;
4066   bitcnt = 1;
4067   buf    = 0;
4068   digidx = X->used - 1;
4069   bitcpy = 0;
4070   bitbuf = 0;
4071
4072   for (;;) {
4073     /* grab next digit as required */
4074     if (--bitcnt == 0) {
4075       /* if digidx == -1 we are out of digits */
4076       if (digidx == -1) {
4077         break;
4078       }
4079       /* read next digit and reset the bitcnt */
4080       buf    = X->dp[digidx--];
4081       bitcnt = DIGIT_BIT;
4082     }
4083
4084     /* grab the next msb from the exponent */
4085     y     = (buf >> (mp_digit)(DIGIT_BIT - 1)) & 1;
4086     buf <<= (mp_digit)1;
4087
4088     /* if the bit is zero and mode == 0 then we ignore it
4089      * These represent the leading zero bits before the first 1 bit
4090      * in the exponent.  Technically this opt is not required but it
4091      * does lower the # of trivial squaring/reductions used
4092      */
4093     if (mode == 0 && y == 0) {
4094       continue;
4095     }
4096
4097     /* if the bit is zero and mode == 1 then we square */
4098     if (mode == 1 && y == 0) {
4099       if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
4100         goto __RES;
4101       }
4102       if ((err = mp_reduce (&res, P, &mu)) != MP_OKAY) {
4103         goto __RES;
4104       }
4105       continue;
4106     }
4107
4108     /* else we add it to the window */
4109     bitbuf |= (y << (winsize - ++bitcpy));
4110     mode    = 2;
4111
4112     if (bitcpy == winsize) {
4113       /* ok window is filled so square as required and multiply  */
4114       /* square first */
4115       for (x = 0; x < winsize; x++) {
4116         if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
4117           goto __RES;
4118         }
4119         if ((err = mp_reduce (&res, P, &mu)) != MP_OKAY) {
4120           goto __RES;
4121         }
4122       }
4123
4124       /* then multiply */
4125       if ((err = mp_mul (&res, &M[bitbuf], &res)) != MP_OKAY) {
4126         goto __RES;
4127       }
4128       if ((err = mp_reduce (&res, P, &mu)) != MP_OKAY) {
4129         goto __RES;
4130       }
4131
4132       /* empty window and reset */
4133       bitcpy = 0;
4134       bitbuf = 0;
4135       mode   = 1;
4136     }
4137   }
4138
4139   /* if bits remain then square/multiply */
4140   if (mode == 2 && bitcpy > 0) {
4141     /* square then multiply if the bit is set */
4142     for (x = 0; x < bitcpy; x++) {
4143       if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
4144         goto __RES;
4145       }
4146       if ((err = mp_reduce (&res, P, &mu)) != MP_OKAY) {
4147         goto __RES;
4148       }
4149
4150       bitbuf <<= 1;
4151       if ((bitbuf & (1 << winsize)) != 0) {
4152         /* then multiply */
4153         if ((err = mp_mul (&res, &M[1], &res)) != MP_OKAY) {
4154           goto __RES;
4155         }
4156         if ((err = mp_reduce (&res, P, &mu)) != MP_OKAY) {
4157           goto __RES;
4158         }
4159       }
4160     }
4161   }
4162
4163   mp_exch (&res, Y);
4164   err = MP_OKAY;
4165 __RES:mp_clear (&res);
4166 __MU:mp_clear (&mu);
4167 __M:
4168   mp_clear(&M[1]);
4169   for (x = 1<<(winsize-1); x < (1 << winsize); x++) {
4170     mp_clear (&M[x]);
4171   }
4172   return err;
4173 }
4174
4175 /* multiplies |a| * |b| and only computes up to digs digits of result
4176  * HAC pp. 595, Algorithm 14.12  Modified so you can control how 
4177  * many digits of output are created.
4178  */
4179 static int
4180 s_mp_mul_digs (const mp_int * a, const mp_int * b, mp_int * c, int digs)
4181 {
4182   mp_int  t;
4183   int     res, pa, pb, ix, iy;
4184   mp_digit u;
4185   mp_word r;
4186   mp_digit tmpx, *tmpt, *tmpy;
4187
4188   /* can we use the fast multiplier? */
4189   if (((digs) < MP_WARRAY) &&
4190       MIN (a->used, b->used) < 
4191           (1 << ((CHAR_BIT * sizeof (mp_word)) - (2 * DIGIT_BIT)))) {
4192     return fast_s_mp_mul_digs (a, b, c, digs);
4193   }
4194
4195   if ((res = mp_init_size (&t, digs)) != MP_OKAY) {
4196     return res;
4197   }
4198   t.used = digs;
4199
4200   /* compute the digits of the product directly */
4201   pa = a->used;
4202   for (ix = 0; ix < pa; ix++) {
4203     /* set the carry to zero */
4204     u = 0;
4205
4206     /* limit ourselves to making digs digits of output */
4207     pb = MIN (b->used, digs - ix);
4208
4209     /* setup some aliases */
4210     /* copy of the digit from a used within the nested loop */
4211     tmpx = a->dp[ix];
4212     
4213     /* an alias for the destination shifted ix places */
4214     tmpt = t.dp + ix;
4215     
4216     /* an alias for the digits of b */
4217     tmpy = b->dp;
4218
4219     /* compute the columns of the output and propagate the carry */
4220     for (iy = 0; iy < pb; iy++) {
4221       /* compute the column as a mp_word */
4222       r       = ((mp_word)*tmpt) +
4223                 ((mp_word)tmpx) * ((mp_word)*tmpy++) +
4224                 ((mp_word) u);
4225
4226       /* the new column is the lower part of the result */
4227       *tmpt++ = (mp_digit) (r & ((mp_word) MP_MASK));
4228
4229       /* get the carry word from the result */
4230       u       = (mp_digit) (r >> ((mp_word) DIGIT_BIT));
4231     }
4232     /* set carry if it is placed below digs */
4233     if (ix + iy < digs) {
4234       *tmpt = u;
4235     }
4236   }
4237
4238   mp_clamp (&t);
4239   mp_exch (&t, c);
4240
4241   mp_clear (&t);
4242   return MP_OKAY;
4243 }
4244
4245 /* multiplies |a| * |b| and does not compute the lower digs digits
4246  * [meant to get the higher part of the product]
4247  */
4248 static int
4249 s_mp_mul_high_digs (const mp_int * a, const mp_int * b, mp_int * c, int digs)
4250 {
4251   mp_int  t;
4252   int     res, pa, pb, ix, iy;
4253   mp_digit u;
4254   mp_word r;
4255   mp_digit tmpx, *tmpt, *tmpy;
4256
4257   /* can we use the fast multiplier? */
4258   if (((a->used + b->used + 1) < MP_WARRAY)
4259       && MIN (a->used, b->used) < (1 << ((CHAR_BIT * sizeof (mp_word)) - (2 * DIGIT_BIT)))) {
4260     return fast_s_mp_mul_high_digs (a, b, c, digs);
4261   }
4262
4263   if ((res = mp_init_size (&t, a->used + b->used + 1)) != MP_OKAY) {
4264     return res;
4265   }
4266   t.used = a->used + b->used + 1;
4267
4268   pa = a->used;
4269   pb = b->used;
4270   for (ix = 0; ix < pa; ix++) {
4271     /* clear the carry */
4272     u = 0;
4273
4274     /* left hand side of A[ix] * B[iy] */
4275     tmpx = a->dp[ix];
4276
4277     /* alias to the address of where the digits will be stored */
4278     tmpt = &(t.dp[digs]);
4279
4280     /* alias for where to read the right hand side from */
4281     tmpy = b->dp + (digs - ix);
4282
4283     for (iy = digs - ix; iy < pb; iy++) {
4284       /* calculate the double precision result */
4285       r       = ((mp_word)*tmpt) +
4286                 ((mp_word)tmpx) * ((mp_word)*tmpy++) +
4287                 ((mp_word) u);
4288
4289       /* get the lower part */
4290       *tmpt++ = (mp_digit) (r & ((mp_word) MP_MASK));
4291
4292       /* carry the carry */
4293       u       = (mp_digit) (r >> ((mp_word) DIGIT_BIT));
4294     }
4295     *tmpt = u;
4296   }
4297   mp_clamp (&t);
4298   mp_exch (&t, c);
4299   mp_clear (&t);
4300   return MP_OKAY;
4301 }
4302
4303 /* low level squaring, b = a*a, HAC pp.596-597, Algorithm 14.16 */
4304 static int
4305 s_mp_sqr (const mp_int * a, mp_int * b)
4306 {
4307   mp_int  t;
4308   int     res, ix, iy, pa;
4309   mp_word r;
4310   mp_digit u, tmpx, *tmpt;
4311
4312   pa = a->used;
4313   if ((res = mp_init_size (&t, 2*pa + 1)) != MP_OKAY) {
4314     return res;
4315   }
4316
4317   /* default used is maximum possible size */
4318   t.used = 2*pa + 1;
4319
4320   for (ix = 0; ix < pa; ix++) {
4321     /* first calculate the digit at 2*ix */
4322     /* calculate double precision result */
4323     r = ((mp_word) t.dp[2*ix]) +
4324         ((mp_word)a->dp[ix])*((mp_word)a->dp[ix]);
4325
4326     /* store lower part in result */
4327     t.dp[ix+ix] = (mp_digit) (r & ((mp_word) MP_MASK));
4328
4329     /* get the carry */
4330     u           = (mp_digit)(r >> ((mp_word) DIGIT_BIT));
4331
4332     /* left hand side of A[ix] * A[iy] */
4333     tmpx        = a->dp[ix];
4334
4335     /* alias for where to store the results */
4336     tmpt        = t.dp + (2*ix + 1);
4337     
4338     for (iy = ix + 1; iy < pa; iy++) {
4339       /* first calculate the product */
4340       r       = ((mp_word)tmpx) * ((mp_word)a->dp[iy]);
4341
4342       /* now calculate the double precision result, note we use
4343        * addition instead of *2 since it's easier to optimize
4344        */
4345       r       = ((mp_word) *tmpt) + r + r + ((mp_word) u);
4346
4347       /* store lower part */
4348       *tmpt++ = (mp_digit) (r & ((mp_word) MP_MASK));
4349
4350       /* get carry */
4351       u       = (mp_digit)(r >> ((mp_word) DIGIT_BIT));
4352     }
4353     /* propagate upwards */
4354     while (u != 0) {
4355       r       = ((mp_word) *tmpt) + ((mp_word) u);
4356       *tmpt++ = (mp_digit) (r & ((mp_word) MP_MASK));
4357       u       = (mp_digit)(r >> ((mp_word) DIGIT_BIT));
4358     }
4359   }
4360
4361   mp_clamp (&t);
4362   mp_exch (&t, b);
4363   mp_clear (&t);
4364   return MP_OKAY;
4365 }
4366
4367 /* low level subtraction (assumes |a| > |b|), HAC pp.595 Algorithm 14.9 */
4368 int
4369 s_mp_sub (const mp_int * a, const mp_int * b, mp_int * c)
4370 {
4371   int     olduse, res, min, max;
4372
4373   /* find sizes */
4374   min = b->used;
4375   max = a->used;
4376
4377   /* init result */
4378   if (c->alloc < max) {
4379     if ((res = mp_grow (c, max)) != MP_OKAY) {
4380       return res;
4381     }
4382   }
4383   olduse = c->used;
4384   c->used = max;
4385
4386   {
4387     register mp_digit u, *tmpa, *tmpb, *tmpc;
4388     register int i;
4389
4390     /* alias for digit pointers */
4391     tmpa = a->dp;
4392     tmpb = b->dp;
4393     tmpc = c->dp;
4394
4395     /* set carry to zero */
4396     u = 0;
4397     for (i = 0; i < min; i++) {
4398       /* T[i] = A[i] - B[i] - U */
4399       *tmpc = *tmpa++ - *tmpb++ - u;
4400
4401       /* U = carry bit of T[i]
4402        * Note this saves performing an AND operation since
4403        * if a carry does occur it will propagate all the way to the
4404        * MSB.  As a result a single shift is enough to get the carry
4405        */
4406       u = *tmpc >> ((mp_digit)(CHAR_BIT * sizeof (mp_digit) - 1));
4407
4408       /* Clear carry from T[i] */
4409       *tmpc++ &= MP_MASK;
4410     }
4411
4412     /* now copy higher words if any, e.g. if A has more digits than B  */
4413     for (; i < max; i++) {
4414       /* T[i] = A[i] - U */
4415       *tmpc = *tmpa++ - u;
4416
4417       /* U = carry bit of T[i] */
4418       u = *tmpc >> ((mp_digit)(CHAR_BIT * sizeof (mp_digit) - 1));
4419
4420       /* Clear carry from T[i] */
4421       *tmpc++ &= MP_MASK;
4422     }
4423
4424     /* clear digits above used (since we may not have grown result above) */
4425     for (i = c->used; i < olduse; i++) {
4426       *tmpc++ = 0;
4427     }
4428   }
4429
4430   mp_clamp (c);
4431   return MP_OKAY;
4432 }