version/tests: Fix a copy-paste mistake.
[wine] / dlls / rsaenh / mpi.c
1 /*
2  * dlls/rsaenh/mpi.c
3  * Multi Precision Integer functions
4  *
5  * Copyright 2004 Michael Jung
6  * Based on public domain code by Tom St Denis (tomstdenis@iahu.ca)
7  *
8  * This library is free software; you can redistribute it and/or
9  * modify it under the terms of the GNU Lesser General Public
10  * License as published by the Free Software Foundation; either
11  * version 2.1 of the License, or (at your option) any later version.
12  *
13  * This library is distributed in the hope that it will be useful,
14  * but WITHOUT ANY WARRANTY; without even the implied warranty of
15  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
16  * Lesser General Public License for more details.
17  *
18  * You should have received a copy of the GNU Lesser General Public
19  * License along with this library; if not, write to the Free Software
20  * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
21  */
22
23 /*
24  * This file contains code from the LibTomCrypt cryptographic 
25  * library written by Tom St Denis (tomstdenis@iahu.ca). LibTomCrypt
26  * is in the public domain. The code in this file is tailored to
27  * special requirements. Take a look at http://libtomcrypt.org for the
28  * original version. 
29  */
30
31 #include <stdarg.h>
32 #include "tomcrypt.h"
33
34 /* Known optimal configurations
35  CPU                    /Compiler     /MUL CUTOFF/SQR CUTOFF
36 -------------------------------------------------------------
37  Intel P4 Northwood     /GCC v3.4.1   /        88/       128/LTM 0.32 ;-)
38 */
39 static const int KARATSUBA_MUL_CUTOFF = 88,  /* Min. number of digits before Karatsuba multiplication is used. */
40                  KARATSUBA_SQR_CUTOFF = 128; /* Min. number of digits before Karatsuba squaring is used. */
41
42 static void bn_reverse(unsigned char *s, int len);
43 static int s_mp_add(mp_int *a, mp_int *b, mp_int *c);
44 static int s_mp_exptmod (const mp_int * G, const mp_int * X, mp_int * P, mp_int * Y);
45 #define s_mp_mul(a, b, c) s_mp_mul_digs(a, b, c, (a)->used + (b)->used + 1)
46 static int s_mp_mul_digs(const mp_int *a, const mp_int *b, mp_int *c, int digs);
47 static int s_mp_mul_high_digs(const mp_int *a, const mp_int *b, mp_int *c, int digs);
48 static int s_mp_sqr(const mp_int *a, mp_int *b);
49 static int s_mp_sub(const mp_int *a, const mp_int *b, mp_int *c);
50 static int mp_exptmod_fast(const mp_int *G, const mp_int *X, mp_int *P, mp_int *Y, int mode);
51 static int mp_invmod_slow (const mp_int * a, mp_int * b, mp_int * c);
52 static int mp_karatsuba_mul(const mp_int *a, const mp_int *b, mp_int *c);
53 static int mp_karatsuba_sqr(const mp_int *a, mp_int *b);
54
55 /* computes the modular inverse via binary extended euclidean algorithm, 
56  * that is c = 1/a mod b 
57  *
58  * Based on slow invmod except this is optimized for the case where b is 
59  * odd as per HAC Note 14.64 on pp. 610
60  */
61 static int
62 fast_mp_invmod (const mp_int * a, mp_int * b, mp_int * c)
63 {
64   mp_int  x, y, u, v, B, D;
65   int     res, neg;
66
67   /* 2. [modified] b must be odd   */
68   if (mp_iseven (b) == 1) {
69     return MP_VAL;
70   }
71
72   /* init all our temps */
73   if ((res = mp_init_multi(&x, &y, &u, &v, &B, &D, NULL)) != MP_OKAY) {
74      return res;
75   }
76
77   /* x == modulus, y == value to invert */
78   if ((res = mp_copy (b, &x)) != MP_OKAY) {
79     goto __ERR;
80   }
81
82   /* we need y = |a| */
83   if ((res = mp_abs (a, &y)) != MP_OKAY) {
84     goto __ERR;
85   }
86
87   /* 3. u=x, v=y, A=1, B=0, C=0,D=1 */
88   if ((res = mp_copy (&x, &u)) != MP_OKAY) {
89     goto __ERR;
90   }
91   if ((res = mp_copy (&y, &v)) != MP_OKAY) {
92     goto __ERR;
93   }
94   mp_set (&D, 1);
95
96 top:
97   /* 4.  while u is even do */
98   while (mp_iseven (&u) == 1) {
99     /* 4.1 u = u/2 */
100     if ((res = mp_div_2 (&u, &u)) != MP_OKAY) {
101       goto __ERR;
102     }
103     /* 4.2 if B is odd then */
104     if (mp_isodd (&B) == 1) {
105       if ((res = mp_sub (&B, &x, &B)) != MP_OKAY) {
106         goto __ERR;
107       }
108     }
109     /* B = B/2 */
110     if ((res = mp_div_2 (&B, &B)) != MP_OKAY) {
111       goto __ERR;
112     }
113   }
114
115   /* 5.  while v is even do */
116   while (mp_iseven (&v) == 1) {
117     /* 5.1 v = v/2 */
118     if ((res = mp_div_2 (&v, &v)) != MP_OKAY) {
119       goto __ERR;
120     }
121     /* 5.2 if D is odd then */
122     if (mp_isodd (&D) == 1) {
123       /* D = (D-x)/2 */
124       if ((res = mp_sub (&D, &x, &D)) != MP_OKAY) {
125         goto __ERR;
126       }
127     }
128     /* D = D/2 */
129     if ((res = mp_div_2 (&D, &D)) != MP_OKAY) {
130       goto __ERR;
131     }
132   }
133
134   /* 6.  if u >= v then */
135   if (mp_cmp (&u, &v) != MP_LT) {
136     /* u = u - v, B = B - D */
137     if ((res = mp_sub (&u, &v, &u)) != MP_OKAY) {
138       goto __ERR;
139     }
140
141     if ((res = mp_sub (&B, &D, &B)) != MP_OKAY) {
142       goto __ERR;
143     }
144   } else {
145     /* v - v - u, D = D - B */
146     if ((res = mp_sub (&v, &u, &v)) != MP_OKAY) {
147       goto __ERR;
148     }
149
150     if ((res = mp_sub (&D, &B, &D)) != MP_OKAY) {
151       goto __ERR;
152     }
153   }
154
155   /* if not zero goto step 4 */
156   if (mp_iszero (&u) == 0) {
157     goto top;
158   }
159
160   /* now a = C, b = D, gcd == g*v */
161
162   /* if v != 1 then there is no inverse */
163   if (mp_cmp_d (&v, 1) != MP_EQ) {
164     res = MP_VAL;
165     goto __ERR;
166   }
167
168   /* b is now the inverse */
169   neg = a->sign;
170   while (D.sign == MP_NEG) {
171     if ((res = mp_add (&D, b, &D)) != MP_OKAY) {
172       goto __ERR;
173     }
174   }
175   mp_exch (&D, c);
176   c->sign = neg;
177   res = MP_OKAY;
178
179 __ERR:mp_clear_multi (&x, &y, &u, &v, &B, &D, NULL);
180   return res;
181 }
182
183 /* computes xR**-1 == x (mod N) via Montgomery Reduction
184  *
185  * This is an optimized implementation of montgomery_reduce
186  * which uses the comba method to quickly calculate the columns of the
187  * reduction.
188  *
189  * Based on Algorithm 14.32 on pp.601 of HAC.
190 */
191 static int
192 fast_mp_montgomery_reduce (mp_int * x, const mp_int * n, mp_digit rho)
193 {
194   int     ix, res, olduse;
195   mp_word W[MP_WARRAY];
196
197   /* get old used count */
198   olduse = x->used;
199
200   /* grow a as required */
201   if (x->alloc < n->used + 1) {
202     if ((res = mp_grow (x, n->used + 1)) != MP_OKAY) {
203       return res;
204     }
205   }
206
207   /* first we have to get the digits of the input into
208    * an array of double precision words W[...]
209    */
210   {
211     register mp_word *_W;
212     register mp_digit *tmpx;
213
214     /* alias for the W[] array */
215     _W   = W;
216
217     /* alias for the digits of  x*/
218     tmpx = x->dp;
219
220     /* copy the digits of a into W[0..a->used-1] */
221     for (ix = 0; ix < x->used; ix++) {
222       *_W++ = *tmpx++;
223     }
224
225     /* zero the high words of W[a->used..m->used*2] */
226     for (; ix < n->used * 2 + 1; ix++) {
227       *_W++ = 0;
228     }
229   }
230
231   /* now we proceed to zero successive digits
232    * from the least significant upwards
233    */
234   for (ix = 0; ix < n->used; ix++) {
235     /* mu = ai * m' mod b
236      *
237      * We avoid a double precision multiplication (which isn't required)
238      * by casting the value down to a mp_digit.  Note this requires
239      * that W[ix-1] have  the carry cleared (see after the inner loop)
240      */
241     register mp_digit mu;
242     mu = (mp_digit) (((W[ix] & MP_MASK) * rho) & MP_MASK);
243
244     /* a = a + mu * m * b**i
245      *
246      * This is computed in place and on the fly.  The multiplication
247      * by b**i is handled by offsetting which columns the results
248      * are added to.
249      *
250      * Note the comba method normally doesn't handle carries in the
251      * inner loop In this case we fix the carry from the previous
252      * column since the Montgomery reduction requires digits of the
253      * result (so far) [see above] to work.  This is
254      * handled by fixing up one carry after the inner loop.  The
255      * carry fixups are done in order so after these loops the
256      * first m->used words of W[] have the carries fixed
257      */
258     {
259       register int iy;
260       register mp_digit *tmpn;
261       register mp_word *_W;
262
263       /* alias for the digits of the modulus */
264       tmpn = n->dp;
265
266       /* Alias for the columns set by an offset of ix */
267       _W = W + ix;
268
269       /* inner loop */
270       for (iy = 0; iy < n->used; iy++) {
271           *_W++ += ((mp_word)mu) * ((mp_word)*tmpn++);
272       }
273     }
274
275     /* now fix carry for next digit, W[ix+1] */
276     W[ix + 1] += W[ix] >> ((mp_word) DIGIT_BIT);
277   }
278
279   /* now we have to propagate the carries and
280    * shift the words downward [all those least
281    * significant digits we zeroed].
282    */
283   {
284     register mp_digit *tmpx;
285     register mp_word *_W, *_W1;
286
287     /* nox fix rest of carries */
288
289     /* alias for current word */
290     _W1 = W + ix;
291
292     /* alias for next word, where the carry goes */
293     _W = W + ++ix;
294
295     for (; ix <= n->used * 2 + 1; ix++) {
296       *_W++ += *_W1++ >> ((mp_word) DIGIT_BIT);
297     }
298
299     /* copy out, A = A/b**n
300      *
301      * The result is A/b**n but instead of converting from an
302      * array of mp_word to mp_digit than calling mp_rshd
303      * we just copy them in the right order
304      */
305
306     /* alias for destination word */
307     tmpx = x->dp;
308
309     /* alias for shifted double precision result */
310     _W = W + n->used;
311
312     for (ix = 0; ix < n->used + 1; ix++) {
313       *tmpx++ = (mp_digit)(*_W++ & ((mp_word) MP_MASK));
314     }
315
316     /* zero oldused digits, if the input a was larger than
317      * m->used+1 we'll have to clear the digits
318      */
319     for (; ix < olduse; ix++) {
320       *tmpx++ = 0;
321     }
322   }
323
324   /* set the max used and clamp */
325   x->used = n->used + 1;
326   mp_clamp (x);
327
328   /* if A >= m then A = A - m */
329   if (mp_cmp_mag (x, n) != MP_LT) {
330     return s_mp_sub (x, n, x);
331   }
332   return MP_OKAY;
333 }
334
335 /* Fast (comba) multiplier
336  *
337  * This is the fast column-array [comba] multiplier.  It is 
338  * designed to compute the columns of the product first 
339  * then handle the carries afterwards.  This has the effect 
340  * of making the nested loops that compute the columns very
341  * simple and schedulable on super-scalar processors.
342  *
343  * This has been modified to produce a variable number of 
344  * digits of output so if say only a half-product is required 
345  * you don't have to compute the upper half (a feature 
346  * required for fast Barrett reduction).
347  *
348  * Based on Algorithm 14.12 on pp.595 of HAC.
349  *
350  */
351 static int
352 fast_s_mp_mul_digs (const mp_int * a, const mp_int * b, mp_int * c, int digs)
353 {
354   int     olduse, res, pa, ix, iz;
355   mp_digit W[MP_WARRAY];
356   register mp_word  _W;
357
358   /* grow the destination as required */
359   if (c->alloc < digs) {
360     if ((res = mp_grow (c, digs)) != MP_OKAY) {
361       return res;
362     }
363   }
364
365   /* number of output digits to produce */
366   pa = MIN(digs, a->used + b->used);
367
368   /* clear the carry */
369   _W = 0;
370   for (ix = 0; ix <= pa; ix++) { 
371       int      tx, ty;
372       int      iy;
373       mp_digit *tmpx, *tmpy;
374
375       /* get offsets into the two bignums */
376       ty = MIN(b->used-1, ix);
377       tx = ix - ty;
378
379       /* setup temp aliases */
380       tmpx = a->dp + tx;
381       tmpy = b->dp + ty;
382
383       /* This is the number of times the loop will iterate, essentially it's
384          while (tx++ < a->used && ty-- >= 0) { ... }
385        */
386       iy = MIN(a->used-tx, ty+1);
387
388       /* execute loop */
389       for (iz = 0; iz < iy; ++iz) {
390          _W += ((mp_word)*tmpx++)*((mp_word)*tmpy--);
391       }
392
393       /* store term */
394       W[ix] = ((mp_digit)_W) & MP_MASK;
395
396       /* make next carry */
397       _W = _W >> ((mp_word)DIGIT_BIT);
398   }
399
400   /* setup dest */
401   olduse  = c->used;
402   c->used = digs;
403
404   {
405     register mp_digit *tmpc;
406     tmpc = c->dp;
407     for (ix = 0; ix < digs; ix++) {
408       /* now extract the previous digit [below the carry] */
409       *tmpc++ = W[ix];
410     }
411
412     /* clear unused digits [that existed in the old copy of c] */
413     for (; ix < olduse; ix++) {
414       *tmpc++ = 0;
415     }
416   }
417   mp_clamp (c);
418   return MP_OKAY;
419 }
420
421 /* this is a modified version of fast_s_mul_digs that only produces
422  * output digits *above* digs.  See the comments for fast_s_mul_digs
423  * to see how it works.
424  *
425  * This is used in the Barrett reduction since for one of the multiplications
426  * only the higher digits were needed.  This essentially halves the work.
427  *
428  * Based on Algorithm 14.12 on pp.595 of HAC.
429  */
430 static int
431 fast_s_mp_mul_high_digs (const mp_int * a, const mp_int * b, mp_int * c, int digs)
432 {
433   int     olduse, res, pa, ix, iz;
434   mp_digit W[MP_WARRAY];
435   mp_word  _W;
436
437   /* grow the destination as required */
438   pa = a->used + b->used;
439   if (c->alloc < pa) {
440     if ((res = mp_grow (c, pa)) != MP_OKAY) {
441       return res;
442     }
443   }
444
445   /* number of output digits to produce */
446   pa = a->used + b->used;
447   _W = 0;
448   for (ix = digs; ix <= pa; ix++) { 
449       int      tx, ty, iy;
450       mp_digit *tmpx, *tmpy;
451
452       /* get offsets into the two bignums */
453       ty = MIN(b->used-1, ix);
454       tx = ix - ty;
455
456       /* setup temp aliases */
457       tmpx = a->dp + tx;
458       tmpy = b->dp + ty;
459
460       /* This is the number of times the loop will iterate, essentially it's
461          while (tx++ < a->used && ty-- >= 0) { ... }
462        */
463       iy = MIN(a->used-tx, ty+1);
464
465       /* execute loop */
466       for (iz = 0; iz < iy; iz++) {
467          _W += ((mp_word)*tmpx++)*((mp_word)*tmpy--);
468       }
469
470       /* store term */
471       W[ix] = ((mp_digit)_W) & MP_MASK;
472
473       /* make next carry */
474       _W = _W >> ((mp_word)DIGIT_BIT);
475   }
476
477   /* setup dest */
478   olduse  = c->used;
479   c->used = pa;
480
481   {
482     register mp_digit *tmpc;
483
484     tmpc = c->dp + digs;
485     for (ix = digs; ix <= pa; ix++) {
486       /* now extract the previous digit [below the carry] */
487       *tmpc++ = W[ix];
488     }
489
490     /* clear unused digits [that existed in the old copy of c] */
491     for (; ix < olduse; ix++) {
492       *tmpc++ = 0;
493     }
494   }
495   mp_clamp (c);
496   return MP_OKAY;
497 }
498
499 /* fast squaring
500  *
501  * This is the comba method where the columns of the product
502  * are computed first then the carries are computed.  This
503  * has the effect of making a very simple inner loop that
504  * is executed the most
505  *
506  * W2 represents the outer products and W the inner.
507  *
508  * A further optimizations is made because the inner
509  * products are of the form "A * B * 2".  The *2 part does
510  * not need to be computed until the end which is good
511  * because 64-bit shifts are slow!
512  *
513  * Based on Algorithm 14.16 on pp.597 of HAC.
514  *
515  */
516 /* the jist of squaring...
517
518 you do like mult except the offset of the tmpx [one that starts closer to zero]
519 can't equal the offset of tmpy.  So basically you set up iy like before then you min it with
520 (ty-tx) so that it never happens.  You double all those you add in the inner loop
521
522 After that loop you do the squares and add them in.
523
524 Remove W2 and don't memset W
525
526 */
527
528 static int fast_s_mp_sqr (const mp_int * a, mp_int * b)
529 {
530   int       olduse, res, pa, ix, iz;
531   mp_digit   W[MP_WARRAY], *tmpx;
532   mp_word   W1;
533
534   /* grow the destination as required */
535   pa = a->used + a->used;
536   if (b->alloc < pa) {
537     if ((res = mp_grow (b, pa)) != MP_OKAY) {
538       return res;
539     }
540   }
541
542   /* number of output digits to produce */
543   W1 = 0;
544   for (ix = 0; ix <= pa; ix++) { 
545       int      tx, ty, iy;
546       mp_word  _W;
547       mp_digit *tmpy;
548
549       /* clear counter */
550       _W = 0;
551
552       /* get offsets into the two bignums */
553       ty = MIN(a->used-1, ix);
554       tx = ix - ty;
555
556       /* setup temp aliases */
557       tmpx = a->dp + tx;
558       tmpy = a->dp + ty;
559
560       /* This is the number of times the loop will iterate, essentially it's
561          while (tx++ < a->used && ty-- >= 0) { ... }
562        */
563       iy = MIN(a->used-tx, ty+1);
564
565       /* now for squaring tx can never equal ty 
566        * we halve the distance since they approach at a rate of 2x
567        * and we have to round because odd cases need to be executed
568        */
569       iy = MIN(iy, (ty-tx+1)>>1);
570
571       /* execute loop */
572       for (iz = 0; iz < iy; iz++) {
573          _W += ((mp_word)*tmpx++)*((mp_word)*tmpy--);
574       }
575
576       /* double the inner product and add carry */
577       _W = _W + _W + W1;
578
579       /* even columns have the square term in them */
580       if ((ix&1) == 0) {
581          _W += ((mp_word)a->dp[ix>>1])*((mp_word)a->dp[ix>>1]);
582       }
583
584       /* store it */
585       W[ix] = _W;
586
587       /* make next carry */
588       W1 = _W >> ((mp_word)DIGIT_BIT);
589   }
590
591   /* setup dest */
592   olduse  = b->used;
593   b->used = a->used+a->used;
594
595   {
596     mp_digit *tmpb;
597     tmpb = b->dp;
598     for (ix = 0; ix < pa; ix++) {
599       *tmpb++ = W[ix] & MP_MASK;
600     }
601
602     /* clear unused digits [that existed in the old copy of c] */
603     for (; ix < olduse; ix++) {
604       *tmpb++ = 0;
605     }
606   }
607   mp_clamp (b);
608   return MP_OKAY;
609 }
610
611 /* computes a = 2**b 
612  *
613  * Simple algorithm which zeroes the int, grows it then just sets one bit
614  * as required.
615  */
616 int
617 mp_2expt (mp_int * a, int b)
618 {
619   int     res;
620
621   /* zero a as per default */
622   mp_zero (a);
623
624   /* grow a to accommodate the single bit */
625   if ((res = mp_grow (a, b / DIGIT_BIT + 1)) != MP_OKAY) {
626     return res;
627   }
628
629   /* set the used count of where the bit will go */
630   a->used = b / DIGIT_BIT + 1;
631
632   /* put the single bit in its place */
633   a->dp[b / DIGIT_BIT] = ((mp_digit)1) << (b % DIGIT_BIT);
634
635   return MP_OKAY;
636 }
637
638 /* b = |a| 
639  *
640  * Simple function copies the input and fixes the sign to positive
641  */
642 int
643 mp_abs (const mp_int * a, mp_int * b)
644 {
645   int     res;
646
647   /* copy a to b */
648   if (a != b) {
649      if ((res = mp_copy (a, b)) != MP_OKAY) {
650        return res;
651      }
652   }
653
654   /* force the sign of b to positive */
655   b->sign = MP_ZPOS;
656
657   return MP_OKAY;
658 }
659
660 /* high level addition (handles signs) */
661 int mp_add (mp_int * a, mp_int * b, mp_int * c)
662 {
663   int     sa, sb, res;
664
665   /* get sign of both inputs */
666   sa = a->sign;
667   sb = b->sign;
668
669   /* handle two cases, not four */
670   if (sa == sb) {
671     /* both positive or both negative */
672     /* add their magnitudes, copy the sign */
673     c->sign = sa;
674     res = s_mp_add (a, b, c);
675   } else {
676     /* one positive, the other negative */
677     /* subtract the one with the greater magnitude from */
678     /* the one of the lesser magnitude.  The result gets */
679     /* the sign of the one with the greater magnitude. */
680     if (mp_cmp_mag (a, b) == MP_LT) {
681       c->sign = sb;
682       res = s_mp_sub (b, a, c);
683     } else {
684       c->sign = sa;
685       res = s_mp_sub (a, b, c);
686     }
687   }
688   return res;
689 }
690
691
692 /* single digit addition */
693 int
694 mp_add_d (mp_int * a, mp_digit b, mp_int * c)
695 {
696   int     res, ix, oldused;
697   mp_digit *tmpa, *tmpc, mu;
698
699   /* grow c as required */
700   if (c->alloc < a->used + 1) {
701      if ((res = mp_grow(c, a->used + 1)) != MP_OKAY) {
702         return res;
703      }
704   }
705
706   /* if a is negative and |a| >= b, call c = |a| - b */
707   if (a->sign == MP_NEG && (a->used > 1 || a->dp[0] >= b)) {
708      /* temporarily fix sign of a */
709      a->sign = MP_ZPOS;
710
711      /* c = |a| - b */
712      res = mp_sub_d(a, b, c);
713
714      /* fix sign  */
715      a->sign = c->sign = MP_NEG;
716
717      return res;
718   }
719
720   /* old number of used digits in c */
721   oldused = c->used;
722
723   /* sign always positive */
724   c->sign = MP_ZPOS;
725
726   /* source alias */
727   tmpa    = a->dp;
728
729   /* destination alias */
730   tmpc    = c->dp;
731
732   /* if a is positive */
733   if (a->sign == MP_ZPOS) {
734      /* add digit, after this we're propagating
735       * the carry.
736       */
737      *tmpc   = *tmpa++ + b;
738      mu      = *tmpc >> DIGIT_BIT;
739      *tmpc++ &= MP_MASK;
740
741      /* now handle rest of the digits */
742      for (ix = 1; ix < a->used; ix++) {
743         *tmpc   = *tmpa++ + mu;
744         mu      = *tmpc >> DIGIT_BIT;
745         *tmpc++ &= MP_MASK;
746      }
747      /* set final carry */
748      ix++;
749      *tmpc++  = mu;
750
751      /* setup size */
752      c->used = a->used + 1;
753   } else {
754      /* a was negative and |a| < b */
755      c->used  = 1;
756
757      /* the result is a single digit */
758      if (a->used == 1) {
759         *tmpc++  =  b - a->dp[0];
760      } else {
761         *tmpc++  =  b;
762      }
763
764      /* setup count so the clearing of oldused
765       * can fall through correctly
766       */
767      ix       = 1;
768   }
769
770   /* now zero to oldused */
771   while (ix++ < oldused) {
772      *tmpc++ = 0;
773   }
774   mp_clamp(c);
775
776   return MP_OKAY;
777 }
778
779 /* trim unused digits 
780  *
781  * This is used to ensure that leading zero digits are
782  * trimed and the leading "used" digit will be non-zero
783  * Typically very fast.  Also fixes the sign if there
784  * are no more leading digits
785  */
786 void
787 mp_clamp (mp_int * a)
788 {
789   /* decrease used while the most significant digit is
790    * zero.
791    */
792   while (a->used > 0 && a->dp[a->used - 1] == 0) {
793     --(a->used);
794   }
795
796   /* reset the sign flag if used == 0 */
797   if (a->used == 0) {
798     a->sign = MP_ZPOS;
799   }
800 }
801
802 /* clear one (frees)  */
803 void
804 mp_clear (mp_int * a)
805 {
806   int i;
807
808   /* only do anything if a hasn't been freed previously */
809   if (a->dp != NULL) {
810     /* first zero the digits */
811     for (i = 0; i < a->used; i++) {
812         a->dp[i] = 0;
813     }
814
815     /* free ram */
816     free(a->dp);
817
818     /* reset members to make debugging easier */
819     a->dp    = NULL;
820     a->alloc = a->used = 0;
821     a->sign  = MP_ZPOS;
822   }
823 }
824
825
826 void mp_clear_multi(mp_int *mp, ...) 
827 {
828     mp_int* next_mp = mp;
829     va_list args;
830     va_start(args, mp);
831     while (next_mp != NULL) {
832         mp_clear(next_mp);
833         next_mp = va_arg(args, mp_int*);
834     }
835     va_end(args);
836 }
837
838 /* compare two ints (signed)*/
839 int
840 mp_cmp (const mp_int * a, const mp_int * b)
841 {
842   /* compare based on sign */
843   if (a->sign != b->sign) {
844      if (a->sign == MP_NEG) {
845         return MP_LT;
846      } else {
847         return MP_GT;
848      }
849   }
850   
851   /* compare digits */
852   if (a->sign == MP_NEG) {
853      /* if negative compare opposite direction */
854      return mp_cmp_mag(b, a);
855   } else {
856      return mp_cmp_mag(a, b);
857   }
858 }
859
860 /* compare a digit */
861 int mp_cmp_d(const mp_int * a, mp_digit b)
862 {
863   /* compare based on sign */
864   if (a->sign == MP_NEG) {
865     return MP_LT;
866   }
867
868   /* compare based on magnitude */
869   if (a->used > 1) {
870     return MP_GT;
871   }
872
873   /* compare the only digit of a to b */
874   if (a->dp[0] > b) {
875     return MP_GT;
876   } else if (a->dp[0] < b) {
877     return MP_LT;
878   } else {
879     return MP_EQ;
880   }
881 }
882
883 /* compare maginitude of two ints (unsigned) */
884 int mp_cmp_mag (const mp_int * a, const mp_int * b)
885 {
886   int     n;
887   mp_digit *tmpa, *tmpb;
888
889   /* compare based on # of non-zero digits */
890   if (a->used > b->used) {
891     return MP_GT;
892   }
893   
894   if (a->used < b->used) {
895     return MP_LT;
896   }
897
898   /* alias for a */
899   tmpa = a->dp + (a->used - 1);
900
901   /* alias for b */
902   tmpb = b->dp + (a->used - 1);
903
904   /* compare based on digits  */
905   for (n = 0; n < a->used; ++n, --tmpa, --tmpb) {
906     if (*tmpa > *tmpb) {
907       return MP_GT;
908     }
909
910     if (*tmpa < *tmpb) {
911       return MP_LT;
912     }
913   }
914   return MP_EQ;
915 }
916
917 static const int lnz[16] = { 
918    4, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0
919 };
920
921 /* Counts the number of lsbs which are zero before the first zero bit */
922 int mp_cnt_lsb(const mp_int *a)
923 {
924    int x;
925    mp_digit q, qq;
926
927    /* easy out */
928    if (mp_iszero(a) == 1) {
929       return 0;
930    }
931
932    /* scan lower digits until non-zero */
933    for (x = 0; x < a->used && a->dp[x] == 0; x++);
934    q = a->dp[x];
935    x *= DIGIT_BIT;
936
937    /* now scan this digit until a 1 is found */
938    if ((q & 1) == 0) {
939       do {
940          qq  = q & 15;
941          x  += lnz[qq];
942          q >>= 4;
943       } while (qq == 0);
944    }
945    return x;
946 }
947
948 /* copy, b = a */
949 int
950 mp_copy (const mp_int * a, mp_int * b)
951 {
952   int     res, n;
953
954   /* if dst == src do nothing */
955   if (a == b) {
956     return MP_OKAY;
957   }
958
959   /* grow dest */
960   if (b->alloc < a->used) {
961      if ((res = mp_grow (b, a->used)) != MP_OKAY) {
962         return res;
963      }
964   }
965
966   /* zero b and copy the parameters over */
967   {
968     register mp_digit *tmpa, *tmpb;
969
970     /* pointer aliases */
971
972     /* source */
973     tmpa = a->dp;
974
975     /* destination */
976     tmpb = b->dp;
977
978     /* copy all the digits */
979     for (n = 0; n < a->used; n++) {
980       *tmpb++ = *tmpa++;
981     }
982
983     /* clear high digits */
984     for (; n < b->used; n++) {
985       *tmpb++ = 0;
986     }
987   }
988
989   /* copy used count and sign */
990   b->used = a->used;
991   b->sign = a->sign;
992   return MP_OKAY;
993 }
994
995 /* returns the number of bits in an int */
996 int
997 mp_count_bits (const mp_int * a)
998 {
999   int     r;
1000   mp_digit q;
1001
1002   /* shortcut */
1003   if (a->used == 0) {
1004     return 0;
1005   }
1006
1007   /* get number of digits and add that */
1008   r = (a->used - 1) * DIGIT_BIT;
1009   
1010   /* take the last digit and count the bits in it */
1011   q = a->dp[a->used - 1];
1012   while (q > ((mp_digit) 0)) {
1013     ++r;
1014     q >>= ((mp_digit) 1);
1015   }
1016   return r;
1017 }
1018
1019 /* integer signed division. 
1020  * c*b + d == a [e.g. a/b, c=quotient, d=remainder]
1021  * HAC pp.598 Algorithm 14.20
1022  *
1023  * Note that the description in HAC is horribly 
1024  * incomplete.  For example, it doesn't consider 
1025  * the case where digits are removed from 'x' in 
1026  * the inner loop.  It also doesn't consider the 
1027  * case that y has fewer than three digits, etc..
1028  *
1029  * The overall algorithm is as described as 
1030  * 14.20 from HAC but fixed to treat these cases.
1031 */
1032 int mp_div (const mp_int * a, const mp_int * b, mp_int * c, mp_int * d)
1033 {
1034   mp_int  q, x, y, t1, t2;
1035   int     res, n, t, i, norm, neg;
1036
1037   /* is divisor zero ? */
1038   if (mp_iszero (b) == 1) {
1039     return MP_VAL;
1040   }
1041
1042   /* if a < b then q=0, r = a */
1043   if (mp_cmp_mag (a, b) == MP_LT) {
1044     if (d != NULL) {
1045       res = mp_copy (a, d);
1046     } else {
1047       res = MP_OKAY;
1048     }
1049     if (c != NULL) {
1050       mp_zero (c);
1051     }
1052     return res;
1053   }
1054
1055   if ((res = mp_init_size (&q, a->used + 2)) != MP_OKAY) {
1056     return res;
1057   }
1058   q.used = a->used + 2;
1059
1060   if ((res = mp_init (&t1)) != MP_OKAY) {
1061     goto __Q;
1062   }
1063
1064   if ((res = mp_init (&t2)) != MP_OKAY) {
1065     goto __T1;
1066   }
1067
1068   if ((res = mp_init_copy (&x, a)) != MP_OKAY) {
1069     goto __T2;
1070   }
1071
1072   if ((res = mp_init_copy (&y, b)) != MP_OKAY) {
1073     goto __X;
1074   }
1075
1076   /* fix the sign */
1077   neg = (a->sign == b->sign) ? MP_ZPOS : MP_NEG;
1078   x.sign = y.sign = MP_ZPOS;
1079
1080   /* normalize both x and y, ensure that y >= b/2, [b == 2**DIGIT_BIT] */
1081   norm = mp_count_bits(&y) % DIGIT_BIT;
1082   if (norm < DIGIT_BIT-1) {
1083      norm = (DIGIT_BIT-1) - norm;
1084      if ((res = mp_mul_2d (&x, norm, &x)) != MP_OKAY) {
1085        goto __Y;
1086      }
1087      if ((res = mp_mul_2d (&y, norm, &y)) != MP_OKAY) {
1088        goto __Y;
1089      }
1090   } else {
1091      norm = 0;
1092   }
1093
1094   /* note hac does 0 based, so if used==5 then its 0,1,2,3,4, e.g. use 4 */
1095   n = x.used - 1;
1096   t = y.used - 1;
1097
1098   /* while (x >= y*b**n-t) do { q[n-t] += 1; x -= y*b**{n-t} } */
1099   if ((res = mp_lshd (&y, n - t)) != MP_OKAY) { /* y = y*b**{n-t} */
1100     goto __Y;
1101   }
1102
1103   while (mp_cmp (&x, &y) != MP_LT) {
1104     ++(q.dp[n - t]);
1105     if ((res = mp_sub (&x, &y, &x)) != MP_OKAY) {
1106       goto __Y;
1107     }
1108   }
1109
1110   /* reset y by shifting it back down */
1111   mp_rshd (&y, n - t);
1112
1113   /* step 3. for i from n down to (t + 1) */
1114   for (i = n; i >= (t + 1); i--) {
1115     if (i > x.used) {
1116       continue;
1117     }
1118
1119     /* step 3.1 if xi == yt then set q{i-t-1} to b-1, 
1120      * otherwise set q{i-t-1} to (xi*b + x{i-1})/yt */
1121     if (x.dp[i] == y.dp[t]) {
1122       q.dp[i - t - 1] = ((((mp_digit)1) << DIGIT_BIT) - 1);
1123     } else {
1124       mp_word tmp;
1125       tmp = ((mp_word) x.dp[i]) << ((mp_word) DIGIT_BIT);
1126       tmp |= ((mp_word) x.dp[i - 1]);
1127       tmp /= ((mp_word) y.dp[t]);
1128       if (tmp > (mp_word) MP_MASK)
1129         tmp = MP_MASK;
1130       q.dp[i - t - 1] = (mp_digit) (tmp & (mp_word) (MP_MASK));
1131     }
1132
1133     /* while (q{i-t-1} * (yt * b + y{t-1})) > 
1134              xi * b**2 + xi-1 * b + xi-2 
1135      
1136        do q{i-t-1} -= 1; 
1137     */
1138     q.dp[i - t - 1] = (q.dp[i - t - 1] + 1) & MP_MASK;
1139     do {
1140       q.dp[i - t - 1] = (q.dp[i - t - 1] - 1) & MP_MASK;
1141
1142       /* find left hand */
1143       mp_zero (&t1);
1144       t1.dp[0] = (t - 1 < 0) ? 0 : y.dp[t - 1];
1145       t1.dp[1] = y.dp[t];
1146       t1.used = 2;
1147       if ((res = mp_mul_d (&t1, q.dp[i - t - 1], &t1)) != MP_OKAY) {
1148         goto __Y;
1149       }
1150
1151       /* find right hand */
1152       t2.dp[0] = (i - 2 < 0) ? 0 : x.dp[i - 2];
1153       t2.dp[1] = (i - 1 < 0) ? 0 : x.dp[i - 1];
1154       t2.dp[2] = x.dp[i];
1155       t2.used = 3;
1156     } while (mp_cmp_mag(&t1, &t2) == MP_GT);
1157
1158     /* step 3.3 x = x - q{i-t-1} * y * b**{i-t-1} */
1159     if ((res = mp_mul_d (&y, q.dp[i - t - 1], &t1)) != MP_OKAY) {
1160       goto __Y;
1161     }
1162
1163     if ((res = mp_lshd (&t1, i - t - 1)) != MP_OKAY) {
1164       goto __Y;
1165     }
1166
1167     if ((res = mp_sub (&x, &t1, &x)) != MP_OKAY) {
1168       goto __Y;
1169     }
1170
1171     /* if x < 0 then { x = x + y*b**{i-t-1}; q{i-t-1} -= 1; } */
1172     if (x.sign == MP_NEG) {
1173       if ((res = mp_copy (&y, &t1)) != MP_OKAY) {
1174         goto __Y;
1175       }
1176       if ((res = mp_lshd (&t1, i - t - 1)) != MP_OKAY) {
1177         goto __Y;
1178       }
1179       if ((res = mp_add (&x, &t1, &x)) != MP_OKAY) {
1180         goto __Y;
1181       }
1182
1183       q.dp[i - t - 1] = (q.dp[i - t - 1] - 1UL) & MP_MASK;
1184     }
1185   }
1186
1187   /* now q is the quotient and x is the remainder 
1188    * [which we have to normalize] 
1189    */
1190   
1191   /* get sign before writing to c */
1192   x.sign = x.used == 0 ? MP_ZPOS : a->sign;
1193
1194   if (c != NULL) {
1195     mp_clamp (&q);
1196     mp_exch (&q, c);
1197     c->sign = neg;
1198   }
1199
1200   if (d != NULL) {
1201     mp_div_2d (&x, norm, &x, NULL);
1202     mp_exch (&x, d);
1203   }
1204
1205   res = MP_OKAY;
1206
1207 __Y:mp_clear (&y);
1208 __X:mp_clear (&x);
1209 __T2:mp_clear (&t2);
1210 __T1:mp_clear (&t1);
1211 __Q:mp_clear (&q);
1212   return res;
1213 }
1214
1215 /* b = a/2 */
1216 int mp_div_2(const mp_int * a, mp_int * b)
1217 {
1218   int     x, res, oldused;
1219
1220   /* copy */
1221   if (b->alloc < a->used) {
1222     if ((res = mp_grow (b, a->used)) != MP_OKAY) {
1223       return res;
1224     }
1225   }
1226
1227   oldused = b->used;
1228   b->used = a->used;
1229   {
1230     register mp_digit r, rr, *tmpa, *tmpb;
1231
1232     /* source alias */
1233     tmpa = a->dp + b->used - 1;
1234
1235     /* dest alias */
1236     tmpb = b->dp + b->used - 1;
1237
1238     /* carry */
1239     r = 0;
1240     for (x = b->used - 1; x >= 0; x--) {
1241       /* get the carry for the next iteration */
1242       rr = *tmpa & 1;
1243
1244       /* shift the current digit, add in carry and store */
1245       *tmpb-- = (*tmpa-- >> 1) | (r << (DIGIT_BIT - 1));
1246
1247       /* forward carry to next iteration */
1248       r = rr;
1249     }
1250
1251     /* zero excess digits */
1252     tmpb = b->dp + b->used;
1253     for (x = b->used; x < oldused; x++) {
1254       *tmpb++ = 0;
1255     }
1256   }
1257   b->sign = a->sign;
1258   mp_clamp (b);
1259   return MP_OKAY;
1260 }
1261
1262 /* shift right by a certain bit count (store quotient in c, optional remainder in d) */
1263 int mp_div_2d (const mp_int * a, int b, mp_int * c, mp_int * d)
1264 {
1265   mp_digit D, r, rr;
1266   int     x, res;
1267   mp_int  t;
1268
1269
1270   /* if the shift count is <= 0 then we do no work */
1271   if (b <= 0) {
1272     res = mp_copy (a, c);
1273     if (d != NULL) {
1274       mp_zero (d);
1275     }
1276     return res;
1277   }
1278
1279   if ((res = mp_init (&t)) != MP_OKAY) {
1280     return res;
1281   }
1282
1283   /* get the remainder */
1284   if (d != NULL) {
1285     if ((res = mp_mod_2d (a, b, &t)) != MP_OKAY) {
1286       mp_clear (&t);
1287       return res;
1288     }
1289   }
1290
1291   /* copy */
1292   if ((res = mp_copy (a, c)) != MP_OKAY) {
1293     mp_clear (&t);
1294     return res;
1295   }
1296
1297   /* shift by as many digits in the bit count */
1298   if (b >= DIGIT_BIT) {
1299     mp_rshd (c, b / DIGIT_BIT);
1300   }
1301
1302   /* shift any bit count < DIGIT_BIT */
1303   D = (mp_digit) (b % DIGIT_BIT);
1304   if (D != 0) {
1305     register mp_digit *tmpc, mask, shift;
1306
1307     /* mask */
1308     mask = (((mp_digit)1) << D) - 1;
1309
1310     /* shift for lsb */
1311     shift = DIGIT_BIT - D;
1312
1313     /* alias */
1314     tmpc = c->dp + (c->used - 1);
1315
1316     /* carry */
1317     r = 0;
1318     for (x = c->used - 1; x >= 0; x--) {
1319       /* get the lower  bits of this word in a temp */
1320       rr = *tmpc & mask;
1321
1322       /* shift the current word and mix in the carry bits from the previous word */
1323       *tmpc = (*tmpc >> D) | (r << shift);
1324       --tmpc;
1325
1326       /* set the carry to the carry bits of the current word found above */
1327       r = rr;
1328     }
1329   }
1330   mp_clamp (c);
1331   if (d != NULL) {
1332     mp_exch (&t, d);
1333   }
1334   mp_clear (&t);
1335   return MP_OKAY;
1336 }
1337
1338 static int s_is_power_of_two(mp_digit b, int *p)
1339 {
1340    int x;
1341
1342    for (x = 1; x < DIGIT_BIT; x++) {
1343       if (b == (((mp_digit)1)<<x)) {
1344          *p = x;
1345          return 1;
1346       }
1347    }
1348    return 0;
1349 }
1350
1351 /* single digit division (based on routine from MPI) */
1352 int mp_div_d (const mp_int * a, mp_digit b, mp_int * c, mp_digit * d)
1353 {
1354   mp_int  q;
1355   mp_word w;
1356   mp_digit t;
1357   int     res, ix;
1358
1359   /* cannot divide by zero */
1360   if (b == 0) {
1361      return MP_VAL;
1362   }
1363
1364   /* quick outs */
1365   if (b == 1 || mp_iszero(a) == 1) {
1366      if (d != NULL) {
1367         *d = 0;
1368      }
1369      if (c != NULL) {
1370         return mp_copy(a, c);
1371      }
1372      return MP_OKAY;
1373   }
1374
1375   /* power of two ? */
1376   if (s_is_power_of_two(b, &ix) == 1) {
1377      if (d != NULL) {
1378         *d = a->dp[0] & ((((mp_digit)1)<<ix) - 1);
1379      }
1380      if (c != NULL) {
1381         return mp_div_2d(a, ix, c, NULL);
1382      }
1383      return MP_OKAY;
1384   }
1385
1386   /* no easy answer [c'est la vie].  Just division */
1387   if ((res = mp_init_size(&q, a->used)) != MP_OKAY) {
1388      return res;
1389   }
1390   
1391   q.used = a->used;
1392   q.sign = a->sign;
1393   w = 0;
1394   for (ix = a->used - 1; ix >= 0; ix--) {
1395      w = (w << ((mp_word)DIGIT_BIT)) | ((mp_word)a->dp[ix]);
1396      
1397      if (w >= b) {
1398         t = (mp_digit)(w / b);
1399         w -= ((mp_word)t) * ((mp_word)b);
1400       } else {
1401         t = 0;
1402       }
1403       q.dp[ix] = t;
1404   }
1405
1406   if (d != NULL) {
1407      *d = (mp_digit)w;
1408   }
1409   
1410   if (c != NULL) {
1411      mp_clamp(&q);
1412      mp_exch(&q, c);
1413   }
1414   mp_clear(&q);
1415   
1416   return res;
1417 }
1418
1419 /* reduce "x" in place modulo "n" using the Diminished Radix algorithm.
1420  *
1421  * Based on algorithm from the paper
1422  *
1423  * "Generating Efficient Primes for Discrete Log Cryptosystems"
1424  *                 Chae Hoon Lim, Pil Loong Lee,
1425  *          POSTECH Information Research Laboratories
1426  *
1427  * The modulus must be of a special format [see manual]
1428  *
1429  * Has been modified to use algorithm 7.10 from the LTM book instead
1430  *
1431  * Input x must be in the range 0 <= x <= (n-1)**2
1432  */
1433 int
1434 mp_dr_reduce (mp_int * x, const mp_int * n, mp_digit k)
1435 {
1436   int      err, i, m;
1437   mp_word  r;
1438   mp_digit mu, *tmpx1, *tmpx2;
1439
1440   /* m = digits in modulus */
1441   m = n->used;
1442
1443   /* ensure that "x" has at least 2m digits */
1444   if (x->alloc < m + m) {
1445     if ((err = mp_grow (x, m + m)) != MP_OKAY) {
1446       return err;
1447     }
1448   }
1449
1450 /* top of loop, this is where the code resumes if
1451  * another reduction pass is required.
1452  */
1453 top:
1454   /* aliases for digits */
1455   /* alias for lower half of x */
1456   tmpx1 = x->dp;
1457
1458   /* alias for upper half of x, or x/B**m */
1459   tmpx2 = x->dp + m;
1460
1461   /* set carry to zero */
1462   mu = 0;
1463
1464   /* compute (x mod B**m) + k * [x/B**m] inline and inplace */
1465   for (i = 0; i < m; i++) {
1466       r         = ((mp_word)*tmpx2++) * ((mp_word)k) + *tmpx1 + mu;
1467       *tmpx1++  = (mp_digit)(r & MP_MASK);
1468       mu        = (mp_digit)(r >> ((mp_word)DIGIT_BIT));
1469   }
1470
1471   /* set final carry */
1472   *tmpx1++ = mu;
1473
1474   /* zero words above m */
1475   for (i = m + 1; i < x->used; i++) {
1476       *tmpx1++ = 0;
1477   }
1478
1479   /* clamp, sub and return */
1480   mp_clamp (x);
1481
1482   /* if x >= n then subtract and reduce again
1483    * Each successive "recursion" makes the input smaller and smaller.
1484    */
1485   if (mp_cmp_mag (x, n) != MP_LT) {
1486     s_mp_sub(x, n, x);
1487     goto top;
1488   }
1489   return MP_OKAY;
1490 }
1491
1492 /* determines the setup value */
1493 void mp_dr_setup(const mp_int *a, mp_digit *d)
1494 {
1495    /* the casts are required if DIGIT_BIT is one less than
1496     * the number of bits in a mp_digit [e.g. DIGIT_BIT==31]
1497     */
1498    *d = (mp_digit)((((mp_word)1) << ((mp_word)DIGIT_BIT)) - 
1499         ((mp_word)a->dp[0]));
1500 }
1501
1502 /* swap the elements of two integers, for cases where you can't simply swap the 
1503  * mp_int pointers around
1504  */
1505 void
1506 mp_exch (mp_int * a, mp_int * b)
1507 {
1508   mp_int  t;
1509
1510   t  = *a;
1511   *a = *b;
1512   *b = t;
1513 }
1514
1515 /* this is a shell function that calls either the normal or Montgomery
1516  * exptmod functions.  Originally the call to the montgomery code was
1517  * embedded in the normal function but that wasted a lot of stack space
1518  * for nothing (since 99% of the time the Montgomery code would be called)
1519  */
1520 int mp_exptmod (const mp_int * G, const mp_int * X, mp_int * P, mp_int * Y)
1521 {
1522   int dr;
1523
1524   /* modulus P must be positive */
1525   if (P->sign == MP_NEG) {
1526      return MP_VAL;
1527   }
1528
1529   /* if exponent X is negative we have to recurse */
1530   if (X->sign == MP_NEG) {
1531      mp_int tmpG, tmpX;
1532      int err;
1533
1534      /* first compute 1/G mod P */
1535      if ((err = mp_init(&tmpG)) != MP_OKAY) {
1536         return err;
1537      }
1538      if ((err = mp_invmod(G, P, &tmpG)) != MP_OKAY) {
1539         mp_clear(&tmpG);
1540         return err;
1541      }
1542
1543      /* now get |X| */
1544      if ((err = mp_init(&tmpX)) != MP_OKAY) {
1545         mp_clear(&tmpG);
1546         return err;
1547      }
1548      if ((err = mp_abs(X, &tmpX)) != MP_OKAY) {
1549         mp_clear_multi(&tmpG, &tmpX, NULL);
1550         return err;
1551      }
1552
1553      /* and now compute (1/G)**|X| instead of G**X [X < 0] */
1554      err = mp_exptmod(&tmpG, &tmpX, P, Y);
1555      mp_clear_multi(&tmpG, &tmpX, NULL);
1556      return err;
1557   }
1558
1559   dr = 0;
1560
1561   /* if the modulus is odd or dr != 0 use the fast method */
1562   if (mp_isodd (P) == 1 || dr !=  0) {
1563     return mp_exptmod_fast (G, X, P, Y, dr);
1564   } else {
1565     /* otherwise use the generic Barrett reduction technique */
1566     return s_mp_exptmod (G, X, P, Y);
1567   }
1568 }
1569
1570 /* computes Y == G**X mod P, HAC pp.616, Algorithm 14.85
1571  *
1572  * Uses a left-to-right k-ary sliding window to compute the modular exponentiation.
1573  * The value of k changes based on the size of the exponent.
1574  *
1575  * Uses Montgomery or Diminished Radix reduction [whichever appropriate]
1576  */
1577
1578 int
1579 mp_exptmod_fast (const mp_int * G, const mp_int * X, mp_int * P, mp_int * Y, int redmode)
1580 {
1581   mp_int  M[256], res;
1582   mp_digit buf, mp;
1583   int     err, bitbuf, bitcpy, bitcnt, mode, digidx, x, y, winsize;
1584
1585   /* use a pointer to the reduction algorithm.  This allows us to use
1586    * one of many reduction algorithms without modding the guts of
1587    * the code with if statements everywhere.
1588    */
1589   int     (*redux)(mp_int*,const mp_int*,mp_digit);
1590
1591   /* find window size */
1592   x = mp_count_bits (X);
1593   if (x <= 7) {
1594     winsize = 2;
1595   } else if (x <= 36) {
1596     winsize = 3;
1597   } else if (x <= 140) {
1598     winsize = 4;
1599   } else if (x <= 450) {
1600     winsize = 5;
1601   } else if (x <= 1303) {
1602     winsize = 6;
1603   } else if (x <= 3529) {
1604     winsize = 7;
1605   } else {
1606     winsize = 8;
1607   }
1608
1609   /* init M array */
1610   /* init first cell */
1611   if ((err = mp_init(&M[1])) != MP_OKAY) {
1612      return err;
1613   }
1614
1615   /* now init the second half of the array */
1616   for (x = 1<<(winsize-1); x < (1 << winsize); x++) {
1617     if ((err = mp_init(&M[x])) != MP_OKAY) {
1618       for (y = 1<<(winsize-1); y < x; y++) {
1619         mp_clear (&M[y]);
1620       }
1621       mp_clear(&M[1]);
1622       return err;
1623     }
1624   }
1625
1626   /* determine and setup reduction code */
1627   if (redmode == 0) {
1628      /* now setup montgomery  */
1629      if ((err = mp_montgomery_setup (P, &mp)) != MP_OKAY) {
1630         goto __M;
1631      }
1632
1633      /* automatically pick the comba one if available (saves quite a few calls/ifs) */
1634      if (((P->used * 2 + 1) < MP_WARRAY) &&
1635           P->used < (1 << ((CHAR_BIT * sizeof (mp_word)) - (2 * DIGIT_BIT)))) {
1636         redux = fast_mp_montgomery_reduce;
1637      } else {
1638         /* use slower baseline Montgomery method */
1639         redux = mp_montgomery_reduce;
1640      }
1641   } else if (redmode == 1) {
1642      /* setup DR reduction for moduli of the form B**k - b */
1643      mp_dr_setup(P, &mp);
1644      redux = mp_dr_reduce;
1645   } else {
1646      /* setup DR reduction for moduli of the form 2**k - b */
1647      if ((err = mp_reduce_2k_setup(P, &mp)) != MP_OKAY) {
1648         goto __M;
1649      }
1650      redux = mp_reduce_2k;
1651   }
1652
1653   /* setup result */
1654   if ((err = mp_init (&res)) != MP_OKAY) {
1655     goto __M;
1656   }
1657
1658   /* create M table
1659    *
1660
1661    *
1662    * The first half of the table is not computed though accept for M[0] and M[1]
1663    */
1664
1665   if (redmode == 0) {
1666      /* now we need R mod m */
1667      if ((err = mp_montgomery_calc_normalization (&res, P)) != MP_OKAY) {
1668        goto __RES;
1669      }
1670
1671      /* now set M[1] to G * R mod m */
1672      if ((err = mp_mulmod (G, &res, P, &M[1])) != MP_OKAY) {
1673        goto __RES;
1674      }
1675   } else {
1676      mp_set(&res, 1);
1677      if ((err = mp_mod(G, P, &M[1])) != MP_OKAY) {
1678         goto __RES;
1679      }
1680   }
1681
1682   /* compute the value at M[1<<(winsize-1)] by squaring M[1] (winsize-1) times */
1683   if ((err = mp_copy (&M[1], &M[1 << (winsize - 1)])) != MP_OKAY) {
1684     goto __RES;
1685   }
1686
1687   for (x = 0; x < (winsize - 1); x++) {
1688     if ((err = mp_sqr (&M[1 << (winsize - 1)], &M[1 << (winsize - 1)])) != MP_OKAY) {
1689       goto __RES;
1690     }
1691     if ((err = redux (&M[1 << (winsize - 1)], P, mp)) != MP_OKAY) {
1692       goto __RES;
1693     }
1694   }
1695
1696   /* create upper table */
1697   for (x = (1 << (winsize - 1)) + 1; x < (1 << winsize); x++) {
1698     if ((err = mp_mul (&M[x - 1], &M[1], &M[x])) != MP_OKAY) {
1699       goto __RES;
1700     }
1701     if ((err = redux (&M[x], P, mp)) != MP_OKAY) {
1702       goto __RES;
1703     }
1704   }
1705
1706   /* set initial mode and bit cnt */
1707   mode   = 0;
1708   bitcnt = 1;
1709   buf    = 0;
1710   digidx = X->used - 1;
1711   bitcpy = 0;
1712   bitbuf = 0;
1713
1714   for (;;) {
1715     /* grab next digit as required */
1716     if (--bitcnt == 0) {
1717       /* if digidx == -1 we are out of digits so break */
1718       if (digidx == -1) {
1719         break;
1720       }
1721       /* read next digit and reset bitcnt */
1722       buf    = X->dp[digidx--];
1723       bitcnt = DIGIT_BIT;
1724     }
1725
1726     /* grab the next msb from the exponent */
1727     y     = (buf >> (DIGIT_BIT - 1)) & 1;
1728     buf <<= (mp_digit)1;
1729
1730     /* if the bit is zero and mode == 0 then we ignore it
1731      * These represent the leading zero bits before the first 1 bit
1732      * in the exponent.  Technically this opt is not required but it
1733      * does lower the # of trivial squaring/reductions used
1734      */
1735     if (mode == 0 && y == 0) {
1736       continue;
1737     }
1738
1739     /* if the bit is zero and mode == 1 then we square */
1740     if (mode == 1 && y == 0) {
1741       if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
1742         goto __RES;
1743       }
1744       if ((err = redux (&res, P, mp)) != MP_OKAY) {
1745         goto __RES;
1746       }
1747       continue;
1748     }
1749
1750     /* else we add it to the window */
1751     bitbuf |= (y << (winsize - ++bitcpy));
1752     mode    = 2;
1753
1754     if (bitcpy == winsize) {
1755       /* ok window is filled so square as required and multiply  */
1756       /* square first */
1757       for (x = 0; x < winsize; x++) {
1758         if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
1759           goto __RES;
1760         }
1761         if ((err = redux (&res, P, mp)) != MP_OKAY) {
1762           goto __RES;
1763         }
1764       }
1765
1766       /* then multiply */
1767       if ((err = mp_mul (&res, &M[bitbuf], &res)) != MP_OKAY) {
1768         goto __RES;
1769       }
1770       if ((err = redux (&res, P, mp)) != MP_OKAY) {
1771         goto __RES;
1772       }
1773
1774       /* empty window and reset */
1775       bitcpy = 0;
1776       bitbuf = 0;
1777       mode   = 1;
1778     }
1779   }
1780
1781   /* if bits remain then square/multiply */
1782   if (mode == 2 && bitcpy > 0) {
1783     /* square then multiply if the bit is set */
1784     for (x = 0; x < bitcpy; x++) {
1785       if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
1786         goto __RES;
1787       }
1788       if ((err = redux (&res, P, mp)) != MP_OKAY) {
1789         goto __RES;
1790       }
1791
1792       /* get next bit of the window */
1793       bitbuf <<= 1;
1794       if ((bitbuf & (1 << winsize)) != 0) {
1795         /* then multiply */
1796         if ((err = mp_mul (&res, &M[1], &res)) != MP_OKAY) {
1797           goto __RES;
1798         }
1799         if ((err = redux (&res, P, mp)) != MP_OKAY) {
1800           goto __RES;
1801         }
1802       }
1803     }
1804   }
1805
1806   if (redmode == 0) {
1807      /* fixup result if Montgomery reduction is used
1808       * recall that any value in a Montgomery system is
1809       * actually multiplied by R mod n.  So we have
1810       * to reduce one more time to cancel out the factor
1811       * of R.
1812       */
1813      if ((err = redux(&res, P, mp)) != MP_OKAY) {
1814        goto __RES;
1815      }
1816   }
1817
1818   /* swap res with Y */
1819   mp_exch (&res, Y);
1820   err = MP_OKAY;
1821 __RES:mp_clear (&res);
1822 __M:
1823   mp_clear(&M[1]);
1824   for (x = 1<<(winsize-1); x < (1 << winsize); x++) {
1825     mp_clear (&M[x]);
1826   }
1827   return err;
1828 }
1829
1830 /* Greatest Common Divisor using the binary method */
1831 int mp_gcd (const mp_int * a, const mp_int * b, mp_int * c)
1832 {
1833   mp_int  u, v;
1834   int     k, u_lsb, v_lsb, res;
1835
1836   /* either zero than gcd is the largest */
1837   if (mp_iszero (a) == 1 && mp_iszero (b) == 0) {
1838     return mp_abs (b, c);
1839   }
1840   if (mp_iszero (a) == 0 && mp_iszero (b) == 1) {
1841     return mp_abs (a, c);
1842   }
1843
1844   /* optimized.  At this point if a == 0 then
1845    * b must equal zero too
1846    */
1847   if (mp_iszero (a) == 1) {
1848     mp_zero(c);
1849     return MP_OKAY;
1850   }
1851
1852   /* get copies of a and b we can modify */
1853   if ((res = mp_init_copy (&u, a)) != MP_OKAY) {
1854     return res;
1855   }
1856
1857   if ((res = mp_init_copy (&v, b)) != MP_OKAY) {
1858     goto __U;
1859   }
1860
1861   /* must be positive for the remainder of the algorithm */
1862   u.sign = v.sign = MP_ZPOS;
1863
1864   /* B1.  Find the common power of two for u and v */
1865   u_lsb = mp_cnt_lsb(&u);
1866   v_lsb = mp_cnt_lsb(&v);
1867   k     = MIN(u_lsb, v_lsb);
1868
1869   if (k > 0) {
1870      /* divide the power of two out */
1871      if ((res = mp_div_2d(&u, k, &u, NULL)) != MP_OKAY) {
1872         goto __V;
1873      }
1874
1875      if ((res = mp_div_2d(&v, k, &v, NULL)) != MP_OKAY) {
1876         goto __V;
1877      }
1878   }
1879
1880   /* divide any remaining factors of two out */
1881   if (u_lsb != k) {
1882      if ((res = mp_div_2d(&u, u_lsb - k, &u, NULL)) != MP_OKAY) {
1883         goto __V;
1884      }
1885   }
1886
1887   if (v_lsb != k) {
1888      if ((res = mp_div_2d(&v, v_lsb - k, &v, NULL)) != MP_OKAY) {
1889         goto __V;
1890      }
1891   }
1892
1893   while (mp_iszero(&v) == 0) {
1894      /* make sure v is the largest */
1895      if (mp_cmp_mag(&u, &v) == MP_GT) {
1896         /* swap u and v to make sure v is >= u */
1897         mp_exch(&u, &v);
1898      }
1899      
1900      /* subtract smallest from largest */
1901      if ((res = s_mp_sub(&v, &u, &v)) != MP_OKAY) {
1902         goto __V;
1903      }
1904      
1905      /* Divide out all factors of two */
1906      if ((res = mp_div_2d(&v, mp_cnt_lsb(&v), &v, NULL)) != MP_OKAY) {
1907         goto __V;
1908      } 
1909   } 
1910
1911   /* multiply by 2**k which we divided out at the beginning */
1912   if ((res = mp_mul_2d (&u, k, c)) != MP_OKAY) {
1913      goto __V;
1914   }
1915   c->sign = MP_ZPOS;
1916   res = MP_OKAY;
1917 __V:mp_clear (&u);
1918 __U:mp_clear (&v);
1919   return res;
1920 }
1921
1922 /* get the lower 32-bits of an mp_int */
1923 unsigned long mp_get_int(const mp_int * a)
1924 {
1925   int i;
1926   unsigned long res;
1927
1928   if (a->used == 0) {
1929      return 0;
1930   }
1931
1932   /* get number of digits of the lsb we have to read */
1933   i = MIN(a->used,(int)((sizeof(unsigned long)*CHAR_BIT+DIGIT_BIT-1)/DIGIT_BIT))-1;
1934
1935   /* get most significant digit of result */
1936   res = DIGIT(a,i);
1937    
1938   while (--i >= 0) {
1939     res = (res << DIGIT_BIT) | DIGIT(a,i);
1940   }
1941
1942   /* force result to 32-bits always so it is consistent on non 32-bit platforms */
1943   return res & 0xFFFFFFFFUL;
1944 }
1945
1946 /* grow as required */
1947 int mp_grow (mp_int * a, int size)
1948 {
1949   int     i;
1950   mp_digit *tmp;
1951
1952   /* if the alloc size is smaller alloc more ram */
1953   if (a->alloc < size) {
1954     /* ensure there are always at least MP_PREC digits extra on top */
1955     size += (MP_PREC * 2) - (size % MP_PREC);
1956
1957     /* reallocate the array a->dp
1958      *
1959      * We store the return in a temporary variable
1960      * in case the operation failed we don't want
1961      * to overwrite the dp member of a.
1962      */
1963     tmp = realloc (a->dp, sizeof (mp_digit) * size);
1964     if (tmp == NULL) {
1965       /* reallocation failed but "a" is still valid [can be freed] */
1966       return MP_MEM;
1967     }
1968
1969     /* reallocation succeeded so set a->dp */
1970     a->dp = tmp;
1971
1972     /* zero excess digits */
1973     i        = a->alloc;
1974     a->alloc = size;
1975     for (; i < a->alloc; i++) {
1976       a->dp[i] = 0;
1977     }
1978   }
1979   return MP_OKAY;
1980 }
1981
1982 /* init a new mp_int */
1983 int mp_init (mp_int * a)
1984 {
1985   int i;
1986
1987   /* allocate memory required and clear it */
1988   a->dp = malloc (sizeof (mp_digit) * MP_PREC);
1989   if (a->dp == NULL) {
1990     return MP_MEM;
1991   }
1992
1993   /* set the digits to zero */
1994   for (i = 0; i < MP_PREC; i++) {
1995       a->dp[i] = 0;
1996   }
1997
1998   /* set the used to zero, allocated digits to the default precision
1999    * and sign to positive */
2000   a->used  = 0;
2001   a->alloc = MP_PREC;
2002   a->sign  = MP_ZPOS;
2003
2004   return MP_OKAY;
2005 }
2006
2007 /* creates "a" then copies b into it */
2008 int mp_init_copy (mp_int * a, const mp_int * b)
2009 {
2010   int     res;
2011
2012   if ((res = mp_init (a)) != MP_OKAY) {
2013     return res;
2014   }
2015   return mp_copy (b, a);
2016 }
2017
2018 int mp_init_multi(mp_int *mp, ...) 
2019 {
2020     mp_err res = MP_OKAY;      /* Assume ok until proven otherwise */
2021     int n = 0;                 /* Number of ok inits */
2022     mp_int* cur_arg = mp;
2023     va_list args;
2024
2025     va_start(args, mp);        /* init args to next argument from caller */
2026     while (cur_arg != NULL) {
2027         if (mp_init(cur_arg) != MP_OKAY) {
2028             /* Oops - error! Back-track and mp_clear what we already
2029                succeeded in init-ing, then return error.
2030             */
2031             va_list clean_args;
2032             
2033             /* end the current list */
2034             va_end(args);
2035             
2036             /* now start cleaning up */            
2037             cur_arg = mp;
2038             va_start(clean_args, mp);
2039             while (n--) {
2040                 mp_clear(cur_arg);
2041                 cur_arg = va_arg(clean_args, mp_int*);
2042             }
2043             va_end(clean_args);
2044             res = MP_MEM;
2045             break;
2046         }
2047         n++;
2048         cur_arg = va_arg(args, mp_int*);
2049     }
2050     va_end(args);
2051     return res;                /* Assumed ok, if error flagged above. */
2052 }
2053
2054 /* init an mp_init for a given size */
2055 int mp_init_size (mp_int * a, int size)
2056 {
2057   int x;
2058
2059   /* pad size so there are always extra digits */
2060   size += (MP_PREC * 2) - (size % MP_PREC);    
2061   
2062   /* alloc mem */
2063   a->dp = malloc (sizeof (mp_digit) * size);
2064   if (a->dp == NULL) {
2065     return MP_MEM;
2066   }
2067
2068   /* set the members */
2069   a->used  = 0;
2070   a->alloc = size;
2071   a->sign  = MP_ZPOS;
2072
2073   /* zero the digits */
2074   for (x = 0; x < size; x++) {
2075       a->dp[x] = 0;
2076   }
2077
2078   return MP_OKAY;
2079 }
2080
2081 /* hac 14.61, pp608 */
2082 int mp_invmod (const mp_int * a, mp_int * b, mp_int * c)
2083 {
2084   /* b cannot be negative */
2085   if (b->sign == MP_NEG || mp_iszero(b) == 1) {
2086     return MP_VAL;
2087   }
2088
2089   /* if the modulus is odd we can use a faster routine instead */
2090   if (mp_isodd (b) == 1) {
2091     return fast_mp_invmod (a, b, c);
2092   }
2093   
2094   return mp_invmod_slow(a, b, c);
2095 }
2096
2097 /* hac 14.61, pp608 */
2098 int mp_invmod_slow (const mp_int * a, mp_int * b, mp_int * c)
2099 {
2100   mp_int  x, y, u, v, A, B, C, D;
2101   int     res;
2102
2103   /* b cannot be negative */
2104   if (b->sign == MP_NEG || mp_iszero(b) == 1) {
2105     return MP_VAL;
2106   }
2107
2108   /* init temps */
2109   if ((res = mp_init_multi(&x, &y, &u, &v, 
2110                            &A, &B, &C, &D, NULL)) != MP_OKAY) {
2111      return res;
2112   }
2113
2114   /* x = a, y = b */
2115   if ((res = mp_copy (a, &x)) != MP_OKAY) {
2116     goto __ERR;
2117   }
2118   if ((res = mp_copy (b, &y)) != MP_OKAY) {
2119     goto __ERR;
2120   }
2121
2122   /* 2. [modified] if x,y are both even then return an error! */
2123   if (mp_iseven (&x) == 1 && mp_iseven (&y) == 1) {
2124     res = MP_VAL;
2125     goto __ERR;
2126   }
2127
2128   /* 3. u=x, v=y, A=1, B=0, C=0,D=1 */
2129   if ((res = mp_copy (&x, &u)) != MP_OKAY) {
2130     goto __ERR;
2131   }
2132   if ((res = mp_copy (&y, &v)) != MP_OKAY) {
2133     goto __ERR;
2134   }
2135   mp_set (&A, 1);
2136   mp_set (&D, 1);
2137
2138 top:
2139   /* 4.  while u is even do */
2140   while (mp_iseven (&u) == 1) {
2141     /* 4.1 u = u/2 */
2142     if ((res = mp_div_2 (&u, &u)) != MP_OKAY) {
2143       goto __ERR;
2144     }
2145     /* 4.2 if A or B is odd then */
2146     if (mp_isodd (&A) == 1 || mp_isodd (&B) == 1) {
2147       /* A = (A+y)/2, B = (B-x)/2 */
2148       if ((res = mp_add (&A, &y, &A)) != MP_OKAY) {
2149          goto __ERR;
2150       }
2151       if ((res = mp_sub (&B, &x, &B)) != MP_OKAY) {
2152          goto __ERR;
2153       }
2154     }
2155     /* A = A/2, B = B/2 */
2156     if ((res = mp_div_2 (&A, &A)) != MP_OKAY) {
2157       goto __ERR;
2158     }
2159     if ((res = mp_div_2 (&B, &B)) != MP_OKAY) {
2160       goto __ERR;
2161     }
2162   }
2163
2164   /* 5.  while v is even do */
2165   while (mp_iseven (&v) == 1) {
2166     /* 5.1 v = v/2 */
2167     if ((res = mp_div_2 (&v, &v)) != MP_OKAY) {
2168       goto __ERR;
2169     }
2170     /* 5.2 if C or D is odd then */
2171     if (mp_isodd (&C) == 1 || mp_isodd (&D) == 1) {
2172       /* C = (C+y)/2, D = (D-x)/2 */
2173       if ((res = mp_add (&C, &y, &C)) != MP_OKAY) {
2174          goto __ERR;
2175       }
2176       if ((res = mp_sub (&D, &x, &D)) != MP_OKAY) {
2177          goto __ERR;
2178       }
2179     }
2180     /* C = C/2, D = D/2 */
2181     if ((res = mp_div_2 (&C, &C)) != MP_OKAY) {
2182       goto __ERR;
2183     }
2184     if ((res = mp_div_2 (&D, &D)) != MP_OKAY) {
2185       goto __ERR;
2186     }
2187   }
2188
2189   /* 6.  if u >= v then */
2190   if (mp_cmp (&u, &v) != MP_LT) {
2191     /* u = u - v, A = A - C, B = B - D */
2192     if ((res = mp_sub (&u, &v, &u)) != MP_OKAY) {
2193       goto __ERR;
2194     }
2195
2196     if ((res = mp_sub (&A, &C, &A)) != MP_OKAY) {
2197       goto __ERR;
2198     }
2199
2200     if ((res = mp_sub (&B, &D, &B)) != MP_OKAY) {
2201       goto __ERR;
2202     }
2203   } else {
2204     /* v - v - u, C = C - A, D = D - B */
2205     if ((res = mp_sub (&v, &u, &v)) != MP_OKAY) {
2206       goto __ERR;
2207     }
2208
2209     if ((res = mp_sub (&C, &A, &C)) != MP_OKAY) {
2210       goto __ERR;
2211     }
2212
2213     if ((res = mp_sub (&D, &B, &D)) != MP_OKAY) {
2214       goto __ERR;
2215     }
2216   }
2217
2218   /* if not zero goto step 4 */
2219   if (mp_iszero (&u) == 0)
2220     goto top;
2221
2222   /* now a = C, b = D, gcd == g*v */
2223
2224   /* if v != 1 then there is no inverse */
2225   if (mp_cmp_d (&v, 1) != MP_EQ) {
2226     res = MP_VAL;
2227     goto __ERR;
2228   }
2229
2230   /* if its too low */
2231   while (mp_cmp_d(&C, 0) == MP_LT) {
2232       if ((res = mp_add(&C, b, &C)) != MP_OKAY) {
2233          goto __ERR;
2234       }
2235   }
2236   
2237   /* too big */
2238   while (mp_cmp_mag(&C, b) != MP_LT) {
2239       if ((res = mp_sub(&C, b, &C)) != MP_OKAY) {
2240          goto __ERR;
2241       }
2242   }
2243   
2244   /* C is now the inverse */
2245   mp_exch (&C, c);
2246   res = MP_OKAY;
2247 __ERR:mp_clear_multi (&x, &y, &u, &v, &A, &B, &C, &D, NULL);
2248   return res;
2249 }
2250
2251 /* c = |a| * |b| using Karatsuba Multiplication using 
2252  * three half size multiplications
2253  *
2254  * Let B represent the radix [e.g. 2**DIGIT_BIT] and 
2255  * let n represent half of the number of digits in 
2256  * the min(a,b)
2257  *
2258  * a = a1 * B**n + a0
2259  * b = b1 * B**n + b0
2260  *
2261  * Then, a * b => 
2262    a1b1 * B**2n + ((a1 - a0)(b1 - b0) + a0b0 + a1b1) * B + a0b0
2263  *
2264  * Note that a1b1 and a0b0 are used twice and only need to be 
2265  * computed once.  So in total three half size (half # of 
2266  * digit) multiplications are performed, a0b0, a1b1 and 
2267  * (a1-b1)(a0-b0)
2268  *
2269  * Note that a multiplication of half the digits requires
2270  * 1/4th the number of single precision multiplications so in 
2271  * total after one call 25% of the single precision multiplications 
2272  * are saved.  Note also that the call to mp_mul can end up back 
2273  * in this function if the a0, a1, b0, or b1 are above the threshold.  
2274  * This is known as divide-and-conquer and leads to the famous 
2275  * O(N**lg(3)) or O(N**1.584) work which is asymptotically lower than
2276  * the standard O(N**2) that the baseline/comba methods use.  
2277  * Generally though the overhead of this method doesn't pay off 
2278  * until a certain size (N ~ 80) is reached.
2279  */
2280 int mp_karatsuba_mul (const mp_int * a, const mp_int * b, mp_int * c)
2281 {
2282   mp_int  x0, x1, y0, y1, t1, x0y0, x1y1;
2283   int     B, err;
2284
2285   /* default the return code to an error */
2286   err = MP_MEM;
2287
2288   /* min # of digits */
2289   B = MIN (a->used, b->used);
2290
2291   /* now divide in two */
2292   B = B >> 1;
2293
2294   /* init copy all the temps */
2295   if (mp_init_size (&x0, B) != MP_OKAY)
2296     goto ERR;
2297   if (mp_init_size (&x1, a->used - B) != MP_OKAY)
2298     goto X0;
2299   if (mp_init_size (&y0, B) != MP_OKAY)
2300     goto X1;
2301   if (mp_init_size (&y1, b->used - B) != MP_OKAY)
2302     goto Y0;
2303
2304   /* init temps */
2305   if (mp_init_size (&t1, B * 2) != MP_OKAY)
2306     goto Y1;
2307   if (mp_init_size (&x0y0, B * 2) != MP_OKAY)
2308     goto T1;
2309   if (mp_init_size (&x1y1, B * 2) != MP_OKAY)
2310     goto X0Y0;
2311
2312   /* now shift the digits */
2313   x0.used = y0.used = B;
2314   x1.used = a->used - B;
2315   y1.used = b->used - B;
2316
2317   {
2318     register int x;
2319     register mp_digit *tmpa, *tmpb, *tmpx, *tmpy;
2320
2321     /* we copy the digits directly instead of using higher level functions
2322      * since we also need to shift the digits
2323      */
2324     tmpa = a->dp;
2325     tmpb = b->dp;
2326
2327     tmpx = x0.dp;
2328     tmpy = y0.dp;
2329     for (x = 0; x < B; x++) {
2330       *tmpx++ = *tmpa++;
2331       *tmpy++ = *tmpb++;
2332     }
2333
2334     tmpx = x1.dp;
2335     for (x = B; x < a->used; x++) {
2336       *tmpx++ = *tmpa++;
2337     }
2338
2339     tmpy = y1.dp;
2340     for (x = B; x < b->used; x++) {
2341       *tmpy++ = *tmpb++;
2342     }
2343   }
2344
2345   /* only need to clamp the lower words since by definition the 
2346    * upper words x1/y1 must have a known number of digits
2347    */
2348   mp_clamp (&x0);
2349   mp_clamp (&y0);
2350
2351   /* now calc the products x0y0 and x1y1 */
2352   /* after this x0 is no longer required, free temp [x0==t2]! */
2353   if (mp_mul (&x0, &y0, &x0y0) != MP_OKAY)  
2354     goto X1Y1;          /* x0y0 = x0*y0 */
2355   if (mp_mul (&x1, &y1, &x1y1) != MP_OKAY)
2356     goto X1Y1;          /* x1y1 = x1*y1 */
2357
2358   /* now calc x1-x0 and y1-y0 */
2359   if (mp_sub (&x1, &x0, &t1) != MP_OKAY)
2360     goto X1Y1;          /* t1 = x1 - x0 */
2361   if (mp_sub (&y1, &y0, &x0) != MP_OKAY)
2362     goto X1Y1;          /* t2 = y1 - y0 */
2363   if (mp_mul (&t1, &x0, &t1) != MP_OKAY)
2364     goto X1Y1;          /* t1 = (x1 - x0) * (y1 - y0) */
2365
2366   /* add x0y0 */
2367   if (mp_add (&x0y0, &x1y1, &x0) != MP_OKAY)
2368     goto X1Y1;          /* t2 = x0y0 + x1y1 */
2369   if (mp_sub (&x0, &t1, &t1) != MP_OKAY)
2370     goto X1Y1;          /* t1 = x0y0 + x1y1 - (x1-x0)*(y1-y0) */
2371
2372   /* shift by B */
2373   if (mp_lshd (&t1, B) != MP_OKAY)
2374     goto X1Y1;          /* t1 = (x0y0 + x1y1 - (x1-x0)*(y1-y0))<<B */
2375   if (mp_lshd (&x1y1, B * 2) != MP_OKAY)
2376     goto X1Y1;          /* x1y1 = x1y1 << 2*B */
2377
2378   if (mp_add (&x0y0, &t1, &t1) != MP_OKAY)
2379     goto X1Y1;          /* t1 = x0y0 + t1 */
2380   if (mp_add (&t1, &x1y1, c) != MP_OKAY)
2381     goto X1Y1;          /* t1 = x0y0 + t1 + x1y1 */
2382
2383   /* Algorithm succeeded set the return code to MP_OKAY */
2384   err = MP_OKAY;
2385
2386 X1Y1:mp_clear (&x1y1);
2387 X0Y0:mp_clear (&x0y0);
2388 T1:mp_clear (&t1);
2389 Y1:mp_clear (&y1);
2390 Y0:mp_clear (&y0);
2391 X1:mp_clear (&x1);
2392 X0:mp_clear (&x0);
2393 ERR:
2394   return err;
2395 }
2396
2397 /* Karatsuba squaring, computes b = a*a using three 
2398  * half size squarings
2399  *
2400  * See comments of karatsuba_mul for details.  It 
2401  * is essentially the same algorithm but merely 
2402  * tuned to perform recursive squarings.
2403  */
2404 int mp_karatsuba_sqr (const mp_int * a, mp_int * b)
2405 {
2406   mp_int  x0, x1, t1, t2, x0x0, x1x1;
2407   int     B, err;
2408
2409   err = MP_MEM;
2410
2411   /* min # of digits */
2412   B = a->used;
2413
2414   /* now divide in two */
2415   B = B >> 1;
2416
2417   /* init copy all the temps */
2418   if (mp_init_size (&x0, B) != MP_OKAY)
2419     goto ERR;
2420   if (mp_init_size (&x1, a->used - B) != MP_OKAY)
2421     goto X0;
2422
2423   /* init temps */
2424   if (mp_init_size (&t1, a->used * 2) != MP_OKAY)
2425     goto X1;
2426   if (mp_init_size (&t2, a->used * 2) != MP_OKAY)
2427     goto T1;
2428   if (mp_init_size (&x0x0, B * 2) != MP_OKAY)
2429     goto T2;
2430   if (mp_init_size (&x1x1, (a->used - B) * 2) != MP_OKAY)
2431     goto X0X0;
2432
2433   {
2434     register int x;
2435     register mp_digit *dst, *src;
2436
2437     src = a->dp;
2438
2439     /* now shift the digits */
2440     dst = x0.dp;
2441     for (x = 0; x < B; x++) {
2442       *dst++ = *src++;
2443     }
2444
2445     dst = x1.dp;
2446     for (x = B; x < a->used; x++) {
2447       *dst++ = *src++;
2448     }
2449   }
2450
2451   x0.used = B;
2452   x1.used = a->used - B;
2453
2454   mp_clamp (&x0);
2455
2456   /* now calc the products x0*x0 and x1*x1 */
2457   if (mp_sqr (&x0, &x0x0) != MP_OKAY)
2458     goto X1X1;           /* x0x0 = x0*x0 */
2459   if (mp_sqr (&x1, &x1x1) != MP_OKAY)
2460     goto X1X1;           /* x1x1 = x1*x1 */
2461
2462   /* now calc (x1-x0)**2 */
2463   if (mp_sub (&x1, &x0, &t1) != MP_OKAY)
2464     goto X1X1;           /* t1 = x1 - x0 */
2465   if (mp_sqr (&t1, &t1) != MP_OKAY)
2466     goto X1X1;           /* t1 = (x1 - x0) * (x1 - x0) */
2467
2468   /* add x0y0 */
2469   if (s_mp_add (&x0x0, &x1x1, &t2) != MP_OKAY)
2470     goto X1X1;           /* t2 = x0x0 + x1x1 */
2471   if (mp_sub (&t2, &t1, &t1) != MP_OKAY)
2472     goto X1X1;           /* t1 = x0x0 + x1x1 - (x1-x0)*(x1-x0) */
2473
2474   /* shift by B */
2475   if (mp_lshd (&t1, B) != MP_OKAY)
2476     goto X1X1;           /* t1 = (x0x0 + x1x1 - (x1-x0)*(x1-x0))<<B */
2477   if (mp_lshd (&x1x1, B * 2) != MP_OKAY)
2478     goto X1X1;           /* x1x1 = x1x1 << 2*B */
2479
2480   if (mp_add (&x0x0, &t1, &t1) != MP_OKAY)
2481     goto X1X1;           /* t1 = x0x0 + t1 */
2482   if (mp_add (&t1, &x1x1, b) != MP_OKAY)
2483     goto X1X1;           /* t1 = x0x0 + t1 + x1x1 */
2484
2485   err = MP_OKAY;
2486
2487 X1X1:mp_clear (&x1x1);
2488 X0X0:mp_clear (&x0x0);
2489 T2:mp_clear (&t2);
2490 T1:mp_clear (&t1);
2491 X1:mp_clear (&x1);
2492 X0:mp_clear (&x0);
2493 ERR:
2494   return err;
2495 }
2496
2497 /* computes least common multiple as |a*b|/(a, b) */
2498 int mp_lcm (const mp_int * a, const mp_int * b, mp_int * c)
2499 {
2500   int     res;
2501   mp_int  t1, t2;
2502
2503
2504   if ((res = mp_init_multi (&t1, &t2, NULL)) != MP_OKAY) {
2505     return res;
2506   }
2507
2508   /* t1 = get the GCD of the two inputs */
2509   if ((res = mp_gcd (a, b, &t1)) != MP_OKAY) {
2510     goto __T;
2511   }
2512
2513   /* divide the smallest by the GCD */
2514   if (mp_cmp_mag(a, b) == MP_LT) {
2515      /* store quotient in t2 such that t2 * b is the LCM */
2516      if ((res = mp_div(a, &t1, &t2, NULL)) != MP_OKAY) {
2517         goto __T;
2518      }
2519      res = mp_mul(b, &t2, c);
2520   } else {
2521      /* store quotient in t2 such that t2 * a is the LCM */
2522      if ((res = mp_div(b, &t1, &t2, NULL)) != MP_OKAY) {
2523         goto __T;
2524      }
2525      res = mp_mul(a, &t2, c);
2526   }
2527
2528   /* fix the sign to positive */
2529   c->sign = MP_ZPOS;
2530
2531 __T:
2532   mp_clear_multi (&t1, &t2, NULL);
2533   return res;
2534 }
2535
2536 /* shift left a certain amount of digits */
2537 int mp_lshd (mp_int * a, int b)
2538 {
2539   int     x, res;
2540
2541   /* if its less than zero return */
2542   if (b <= 0) {
2543     return MP_OKAY;
2544   }
2545
2546   /* grow to fit the new digits */
2547   if (a->alloc < a->used + b) {
2548      if ((res = mp_grow (a, a->used + b)) != MP_OKAY) {
2549        return res;
2550      }
2551   }
2552
2553   {
2554     register mp_digit *top, *bottom;
2555
2556     /* increment the used by the shift amount then copy upwards */
2557     a->used += b;
2558
2559     /* top */
2560     top = a->dp + a->used - 1;
2561
2562     /* base */
2563     bottom = a->dp + a->used - 1 - b;
2564
2565     /* much like mp_rshd this is implemented using a sliding window
2566      * except the window goes the otherway around.  Copying from
2567      * the bottom to the top.  see bn_mp_rshd.c for more info.
2568      */
2569     for (x = a->used - 1; x >= b; x--) {
2570       *top-- = *bottom--;
2571     }
2572
2573     /* zero the lower digits */
2574     top = a->dp;
2575     for (x = 0; x < b; x++) {
2576       *top++ = 0;
2577     }
2578   }
2579   return MP_OKAY;
2580 }
2581
2582 /* c = a mod b, 0 <= c < b */
2583 int
2584 mp_mod (const mp_int * a, mp_int * b, mp_int * c)
2585 {
2586   mp_int  t;
2587   int     res;
2588
2589   if ((res = mp_init (&t)) != MP_OKAY) {
2590     return res;
2591   }
2592
2593   if ((res = mp_div (a, b, NULL, &t)) != MP_OKAY) {
2594     mp_clear (&t);
2595     return res;
2596   }
2597
2598   if (t.sign != b->sign) {
2599     res = mp_add (b, &t, c);
2600   } else {
2601     res = MP_OKAY;
2602     mp_exch (&t, c);
2603   }
2604
2605   mp_clear (&t);
2606   return res;
2607 }
2608
2609 /* calc a value mod 2**b */
2610 int
2611 mp_mod_2d (const mp_int * a, int b, mp_int * c)
2612 {
2613   int     x, res;
2614
2615   /* if b is <= 0 then zero the int */
2616   if (b <= 0) {
2617     mp_zero (c);
2618     return MP_OKAY;
2619   }
2620
2621   /* if the modulus is larger than the value than return */
2622   if (b > a->used * DIGIT_BIT) {
2623     res = mp_copy (a, c);
2624     return res;
2625   }
2626
2627   /* copy */
2628   if ((res = mp_copy (a, c)) != MP_OKAY) {
2629     return res;
2630   }
2631
2632   /* zero digits above the last digit of the modulus */
2633   for (x = (b / DIGIT_BIT) + ((b % DIGIT_BIT) == 0 ? 0 : 1); x < c->used; x++) {
2634     c->dp[x] = 0;
2635   }
2636   /* clear the digit that is not completely outside/inside the modulus */
2637   c->dp[b / DIGIT_BIT] &= (1 << ((mp_digit)b % DIGIT_BIT)) - 1;
2638   mp_clamp (c);
2639   return MP_OKAY;
2640 }
2641
2642 int
2643 mp_mod_d (const mp_int * a, mp_digit b, mp_digit * c)
2644 {
2645   return mp_div_d(a, b, NULL, c);
2646 }
2647
2648 /*
2649  * shifts with subtractions when the result is greater than b.
2650  *
2651  * The method is slightly modified to shift B unconditionally up to just under
2652  * the leading bit of b.  This saves a lot of multiple precision shifting.
2653  */
2654 int mp_montgomery_calc_normalization (mp_int * a, const mp_int * b)
2655 {
2656   int     x, bits, res;
2657
2658   /* how many bits of last digit does b use */
2659   bits = mp_count_bits (b) % DIGIT_BIT;
2660
2661
2662   if (b->used > 1) {
2663      if ((res = mp_2expt (a, (b->used - 1) * DIGIT_BIT + bits - 1)) != MP_OKAY) {
2664         return res;
2665      }
2666   } else {
2667      mp_set(a, 1);
2668      bits = 1;
2669   }
2670
2671
2672   /* now compute C = A * B mod b */
2673   for (x = bits - 1; x < DIGIT_BIT; x++) {
2674     if ((res = mp_mul_2 (a, a)) != MP_OKAY) {
2675       return res;
2676     }
2677     if (mp_cmp_mag (a, b) != MP_LT) {
2678       if ((res = s_mp_sub (a, b, a)) != MP_OKAY) {
2679         return res;
2680       }
2681     }
2682   }
2683
2684   return MP_OKAY;
2685 }
2686
2687 /* computes xR**-1 == x (mod N) via Montgomery Reduction */
2688 int
2689 mp_montgomery_reduce (mp_int * x, const mp_int * n, mp_digit rho)
2690 {
2691   int     ix, res, digs;
2692   mp_digit mu;
2693
2694   /* can the fast reduction [comba] method be used?
2695    *
2696    * Note that unlike in mul you're safely allowed *less*
2697    * than the available columns [255 per default] since carries
2698    * are fixed up in the inner loop.
2699    */
2700   digs = n->used * 2 + 1;
2701   if ((digs < MP_WARRAY) &&
2702       n->used <
2703       (1 << ((CHAR_BIT * sizeof (mp_word)) - (2 * DIGIT_BIT)))) {
2704     return fast_mp_montgomery_reduce (x, n, rho);
2705   }
2706
2707   /* grow the input as required */
2708   if (x->alloc < digs) {
2709     if ((res = mp_grow (x, digs)) != MP_OKAY) {
2710       return res;
2711     }
2712   }
2713   x->used = digs;
2714
2715   for (ix = 0; ix < n->used; ix++) {
2716     /* mu = ai * rho mod b
2717      *
2718      * The value of rho must be precalculated via
2719      * montgomery_setup() such that
2720      * it equals -1/n0 mod b this allows the
2721      * following inner loop to reduce the
2722      * input one digit at a time
2723      */
2724     mu = (mp_digit) (((mp_word)x->dp[ix]) * ((mp_word)rho) & MP_MASK);
2725
2726     /* a = a + mu * m * b**i */
2727     {
2728       register int iy;
2729       register mp_digit *tmpn, *tmpx, u;
2730       register mp_word r;
2731
2732       /* alias for digits of the modulus */
2733       tmpn = n->dp;
2734
2735       /* alias for the digits of x [the input] */
2736       tmpx = x->dp + ix;
2737
2738       /* set the carry to zero */
2739       u = 0;
2740
2741       /* Multiply and add in place */
2742       for (iy = 0; iy < n->used; iy++) {
2743         /* compute product and sum */
2744         r       = ((mp_word)mu) * ((mp_word)*tmpn++) +
2745                   ((mp_word) u) + ((mp_word) * tmpx);
2746
2747         /* get carry */
2748         u       = (mp_digit)(r >> ((mp_word) DIGIT_BIT));
2749
2750         /* fix digit */
2751         *tmpx++ = (mp_digit)(r & ((mp_word) MP_MASK));
2752       }
2753       /* At this point the ix'th digit of x should be zero */
2754
2755
2756       /* propagate carries upwards as required*/
2757       while (u) {
2758         *tmpx   += u;
2759         u        = *tmpx >> DIGIT_BIT;
2760         *tmpx++ &= MP_MASK;
2761       }
2762     }
2763   }
2764
2765   /* at this point the n.used'th least
2766    * significant digits of x are all zero
2767    * which means we can shift x to the
2768    * right by n.used digits and the
2769    * residue is unchanged.
2770    */
2771
2772   /* x = x/b**n.used */
2773   mp_clamp(x);
2774   mp_rshd (x, n->used);
2775
2776   /* if x >= n then x = x - n */
2777   if (mp_cmp_mag (x, n) != MP_LT) {
2778     return s_mp_sub (x, n, x);
2779   }
2780
2781   return MP_OKAY;
2782 }
2783
2784 /* setups the montgomery reduction stuff */
2785 int
2786 mp_montgomery_setup (const mp_int * n, mp_digit * rho)
2787 {
2788   mp_digit x, b;
2789
2790 /* fast inversion mod 2**k
2791  *
2792  * Based on the fact that
2793  *
2794  * XA = 1 (mod 2**n)  =>  (X(2-XA)) A = 1 (mod 2**2n)
2795  *                    =>  2*X*A - X*X*A*A = 1
2796  *                    =>  2*(1) - (1)     = 1
2797  */
2798   b = n->dp[0];
2799
2800   if ((b & 1) == 0) {
2801     return MP_VAL;
2802   }
2803
2804   x = (((b + 2) & 4) << 1) + b; /* here x*a==1 mod 2**4 */
2805   x *= 2 - b * x;               /* here x*a==1 mod 2**8 */
2806   x *= 2 - b * x;               /* here x*a==1 mod 2**16 */
2807   x *= 2 - b * x;               /* here x*a==1 mod 2**32 */
2808
2809   /* rho = -1/m mod b */
2810   *rho = (((mp_word)1 << ((mp_word) DIGIT_BIT)) - x) & MP_MASK;
2811
2812   return MP_OKAY;
2813 }
2814
2815 /* high level multiplication (handles sign) */
2816 int mp_mul (const mp_int * a, const mp_int * b, mp_int * c)
2817 {
2818   int     res, neg;
2819   neg = (a->sign == b->sign) ? MP_ZPOS : MP_NEG;
2820
2821   /* use Karatsuba? */
2822   if (MIN (a->used, b->used) >= KARATSUBA_MUL_CUTOFF) {
2823     res = mp_karatsuba_mul (a, b, c);
2824   } else 
2825   {
2826     /* can we use the fast multiplier?
2827      *
2828      * The fast multiplier can be used if the output will 
2829      * have less than MP_WARRAY digits and the number of 
2830      * digits won't affect carry propagation
2831      */
2832     int     digs = a->used + b->used + 1;
2833
2834     if ((digs < MP_WARRAY) &&
2835         MIN(a->used, b->used) <= 
2836         (1 << ((CHAR_BIT * sizeof (mp_word)) - (2 * DIGIT_BIT)))) {
2837       res = fast_s_mp_mul_digs (a, b, c, digs);
2838     } else 
2839       res = s_mp_mul (a, b, c); /* uses s_mp_mul_digs */
2840   }
2841   c->sign = (c->used > 0) ? neg : MP_ZPOS;
2842   return res;
2843 }
2844
2845 /* b = a*2 */
2846 int mp_mul_2(const mp_int * a, mp_int * b)
2847 {
2848   int     x, res, oldused;
2849
2850   /* grow to accommodate result */
2851   if (b->alloc < a->used + 1) {
2852     if ((res = mp_grow (b, a->used + 1)) != MP_OKAY) {
2853       return res;
2854     }
2855   }
2856
2857   oldused = b->used;
2858   b->used = a->used;
2859
2860   {
2861     register mp_digit r, rr, *tmpa, *tmpb;
2862
2863     /* alias for source */
2864     tmpa = a->dp;
2865     
2866     /* alias for dest */
2867     tmpb = b->dp;
2868
2869     /* carry */
2870     r = 0;
2871     for (x = 0; x < a->used; x++) {
2872     
2873       /* get what will be the *next* carry bit from the 
2874        * MSB of the current digit 
2875        */
2876       rr = *tmpa >> ((mp_digit)(DIGIT_BIT - 1));
2877       
2878       /* now shift up this digit, add in the carry [from the previous] */
2879       *tmpb++ = ((*tmpa++ << ((mp_digit)1)) | r) & MP_MASK;
2880       
2881       /* copy the carry that would be from the source 
2882        * digit into the next iteration 
2883        */
2884       r = rr;
2885     }
2886
2887     /* new leading digit? */
2888     if (r != 0) {
2889       /* add a MSB which is always 1 at this point */
2890       *tmpb = 1;
2891       ++(b->used);
2892     }
2893
2894     /* now zero any excess digits on the destination 
2895      * that we didn't write to 
2896      */
2897     tmpb = b->dp + b->used;
2898     for (x = b->used; x < oldused; x++) {
2899       *tmpb++ = 0;
2900     }
2901   }
2902   b->sign = a->sign;
2903   return MP_OKAY;
2904 }
2905
2906 /* shift left by a certain bit count */
2907 int mp_mul_2d (const mp_int * a, int b, mp_int * c)
2908 {
2909   mp_digit d;
2910   int      res;
2911
2912   /* copy */
2913   if (a != c) {
2914      if ((res = mp_copy (a, c)) != MP_OKAY) {
2915        return res;
2916      }
2917   }
2918
2919   if (c->alloc < c->used + b/DIGIT_BIT + 1) {
2920      if ((res = mp_grow (c, c->used + b / DIGIT_BIT + 1)) != MP_OKAY) {
2921        return res;
2922      }
2923   }
2924
2925   /* shift by as many digits in the bit count */
2926   if (b >= DIGIT_BIT) {
2927     if ((res = mp_lshd (c, b / DIGIT_BIT)) != MP_OKAY) {
2928       return res;
2929     }
2930   }
2931
2932   /* shift any bit count < DIGIT_BIT */
2933   d = (mp_digit) (b % DIGIT_BIT);
2934   if (d != 0) {
2935     register mp_digit *tmpc, shift, mask, r, rr;
2936     register int x;
2937
2938     /* bitmask for carries */
2939     mask = (((mp_digit)1) << d) - 1;
2940
2941     /* shift for msbs */
2942     shift = DIGIT_BIT - d;
2943
2944     /* alias */
2945     tmpc = c->dp;
2946
2947     /* carry */
2948     r    = 0;
2949     for (x = 0; x < c->used; x++) {
2950       /* get the higher bits of the current word */
2951       rr = (*tmpc >> shift) & mask;
2952
2953       /* shift the current word and OR in the carry */
2954       *tmpc = ((*tmpc << d) | r) & MP_MASK;
2955       ++tmpc;
2956
2957       /* set the carry to the carry bits of the current word */
2958       r = rr;
2959     }
2960     
2961     /* set final carry */
2962     if (r != 0) {
2963        c->dp[(c->used)++] = r;
2964     }
2965   }
2966   mp_clamp (c);
2967   return MP_OKAY;
2968 }
2969
2970 /* multiply by a digit */
2971 int
2972 mp_mul_d (const mp_int * a, mp_digit b, mp_int * c)
2973 {
2974   mp_digit u, *tmpa, *tmpc;
2975   mp_word  r;
2976   int      ix, res, olduse;
2977
2978   /* make sure c is big enough to hold a*b */
2979   if (c->alloc < a->used + 1) {
2980     if ((res = mp_grow (c, a->used + 1)) != MP_OKAY) {
2981       return res;
2982     }
2983   }
2984
2985   /* get the original destinations used count */
2986   olduse = c->used;
2987
2988   /* set the sign */
2989   c->sign = a->sign;
2990
2991   /* alias for a->dp [source] */
2992   tmpa = a->dp;
2993
2994   /* alias for c->dp [dest] */
2995   tmpc = c->dp;
2996
2997   /* zero carry */
2998   u = 0;
2999
3000   /* compute columns */
3001   for (ix = 0; ix < a->used; ix++) {
3002     /* compute product and carry sum for this term */
3003     r       = ((mp_word) u) + ((mp_word)*tmpa++) * ((mp_word)b);
3004
3005     /* mask off higher bits to get a single digit */
3006     *tmpc++ = (mp_digit) (r & ((mp_word) MP_MASK));
3007
3008     /* send carry into next iteration */
3009     u       = (mp_digit) (r >> ((mp_word) DIGIT_BIT));
3010   }
3011
3012   /* store final carry [if any] */
3013   *tmpc++ = u;
3014
3015   /* now zero digits above the top */
3016   while (ix++ < olduse) {
3017      *tmpc++ = 0;
3018   }
3019
3020   /* set used count */
3021   c->used = a->used + 1;
3022   mp_clamp(c);
3023
3024   return MP_OKAY;
3025 }
3026
3027 /* d = a * b (mod c) */
3028 int
3029 mp_mulmod (const mp_int * a, const mp_int * b, mp_int * c, mp_int * d)
3030 {
3031   int     res;
3032   mp_int  t;
3033
3034   if ((res = mp_init (&t)) != MP_OKAY) {
3035     return res;
3036   }
3037
3038   if ((res = mp_mul (a, b, &t)) != MP_OKAY) {
3039     mp_clear (&t);
3040     return res;
3041   }
3042   res = mp_mod (&t, c, d);
3043   mp_clear (&t);
3044   return res;
3045 }
3046
3047 /* table of first PRIME_SIZE primes */
3048 static const mp_digit __prime_tab[] = {
3049   0x0002, 0x0003, 0x0005, 0x0007, 0x000B, 0x000D, 0x0011, 0x0013,
3050   0x0017, 0x001D, 0x001F, 0x0025, 0x0029, 0x002B, 0x002F, 0x0035,
3051   0x003B, 0x003D, 0x0043, 0x0047, 0x0049, 0x004F, 0x0053, 0x0059,
3052   0x0061, 0x0065, 0x0067, 0x006B, 0x006D, 0x0071, 0x007F, 0x0083,
3053   0x0089, 0x008B, 0x0095, 0x0097, 0x009D, 0x00A3, 0x00A7, 0x00AD,
3054   0x00B3, 0x00B5, 0x00BF, 0x00C1, 0x00C5, 0x00C7, 0x00D3, 0x00DF,
3055   0x00E3, 0x00E5, 0x00E9, 0x00EF, 0x00F1, 0x00FB, 0x0101, 0x0107,
3056   0x010D, 0x010F, 0x0115, 0x0119, 0x011B, 0x0125, 0x0133, 0x0137,
3057
3058   0x0139, 0x013D, 0x014B, 0x0151, 0x015B, 0x015D, 0x0161, 0x0167,
3059   0x016F, 0x0175, 0x017B, 0x017F, 0x0185, 0x018D, 0x0191, 0x0199,
3060   0x01A3, 0x01A5, 0x01AF, 0x01B1, 0x01B7, 0x01BB, 0x01C1, 0x01C9,
3061   0x01CD, 0x01CF, 0x01D3, 0x01DF, 0x01E7, 0x01EB, 0x01F3, 0x01F7,
3062   0x01FD, 0x0209, 0x020B, 0x021D, 0x0223, 0x022D, 0x0233, 0x0239,
3063   0x023B, 0x0241, 0x024B, 0x0251, 0x0257, 0x0259, 0x025F, 0x0265,
3064   0x0269, 0x026B, 0x0277, 0x0281, 0x0283, 0x0287, 0x028D, 0x0293,
3065   0x0295, 0x02A1, 0x02A5, 0x02AB, 0x02B3, 0x02BD, 0x02C5, 0x02CF,
3066
3067   0x02D7, 0x02DD, 0x02E3, 0x02E7, 0x02EF, 0x02F5, 0x02F9, 0x0301,
3068   0x0305, 0x0313, 0x031D, 0x0329, 0x032B, 0x0335, 0x0337, 0x033B,
3069   0x033D, 0x0347, 0x0355, 0x0359, 0x035B, 0x035F, 0x036D, 0x0371,
3070   0x0373, 0x0377, 0x038B, 0x038F, 0x0397, 0x03A1, 0x03A9, 0x03AD,
3071   0x03B3, 0x03B9, 0x03C7, 0x03CB, 0x03D1, 0x03D7, 0x03DF, 0x03E5,
3072   0x03F1, 0x03F5, 0x03FB, 0x03FD, 0x0407, 0x0409, 0x040F, 0x0419,
3073   0x041B, 0x0425, 0x0427, 0x042D, 0x043F, 0x0443, 0x0445, 0x0449,
3074   0x044F, 0x0455, 0x045D, 0x0463, 0x0469, 0x047F, 0x0481, 0x048B,
3075
3076   0x0493, 0x049D, 0x04A3, 0x04A9, 0x04B1, 0x04BD, 0x04C1, 0x04C7,
3077   0x04CD, 0x04CF, 0x04D5, 0x04E1, 0x04EB, 0x04FD, 0x04FF, 0x0503,
3078   0x0509, 0x050B, 0x0511, 0x0515, 0x0517, 0x051B, 0x0527, 0x0529,
3079   0x052F, 0x0551, 0x0557, 0x055D, 0x0565, 0x0577, 0x0581, 0x058F,
3080   0x0593, 0x0595, 0x0599, 0x059F, 0x05A7, 0x05AB, 0x05AD, 0x05B3,
3081   0x05BF, 0x05C9, 0x05CB, 0x05CF, 0x05D1, 0x05D5, 0x05DB, 0x05E7,
3082   0x05F3, 0x05FB, 0x0607, 0x060D, 0x0611, 0x0617, 0x061F, 0x0623,
3083   0x062B, 0x062F, 0x063D, 0x0641, 0x0647, 0x0649, 0x064D, 0x0653
3084 };
3085
3086 /* determines if an integers is divisible by one 
3087  * of the first PRIME_SIZE primes or not
3088  *
3089  * sets result to 0 if not, 1 if yes
3090  */
3091 int mp_prime_is_divisible (const mp_int * a, int *result)
3092 {
3093   int     err, ix;
3094   mp_digit res;
3095
3096   /* default to not */
3097   *result = MP_NO;
3098
3099   for (ix = 0; ix < PRIME_SIZE; ix++) {
3100     /* what is a mod __prime_tab[ix] */
3101     if ((err = mp_mod_d (a, __prime_tab[ix], &res)) != MP_OKAY) {
3102       return err;
3103     }
3104
3105     /* is the residue zero? */
3106     if (res == 0) {
3107       *result = MP_YES;
3108       return MP_OKAY;
3109     }
3110   }
3111
3112   return MP_OKAY;
3113 }
3114
3115 /* performs a variable number of rounds of Miller-Rabin
3116  *
3117  * Probability of error after t rounds is no more than
3118
3119  *
3120  * Sets result to 1 if probably prime, 0 otherwise
3121  */
3122 int mp_prime_is_prime (mp_int * a, int t, int *result)
3123 {
3124   mp_int  b;
3125   int     ix, err, res;
3126
3127   /* default to no */
3128   *result = MP_NO;
3129
3130   /* valid value of t? */
3131   if (t <= 0 || t > PRIME_SIZE) {
3132     return MP_VAL;
3133   }
3134
3135   /* is the input equal to one of the primes in the table? */
3136   for (ix = 0; ix < PRIME_SIZE; ix++) {
3137       if (mp_cmp_d(a, __prime_tab[ix]) == MP_EQ) {
3138          *result = 1;
3139          return MP_OKAY;
3140       }
3141   }
3142
3143   /* first perform trial division */
3144   if ((err = mp_prime_is_divisible (a, &res)) != MP_OKAY) {
3145     return err;
3146   }
3147
3148   /* return if it was trivially divisible */
3149   if (res == MP_YES) {
3150     return MP_OKAY;
3151   }
3152
3153   /* now perform the miller-rabin rounds */
3154   if ((err = mp_init (&b)) != MP_OKAY) {
3155     return err;
3156   }
3157
3158   for (ix = 0; ix < t; ix++) {
3159     /* set the prime */
3160     mp_set (&b, __prime_tab[ix]);
3161
3162     if ((err = mp_prime_miller_rabin (a, &b, &res)) != MP_OKAY) {
3163       goto __B;
3164     }
3165
3166     if (res == MP_NO) {
3167       goto __B;
3168     }
3169   }
3170
3171   /* passed the test */
3172   *result = MP_YES;
3173 __B:mp_clear (&b);
3174   return err;
3175 }
3176
3177 /* Miller-Rabin test of "a" to the base of "b" as described in 
3178  * HAC pp. 139 Algorithm 4.24
3179  *
3180  * Sets result to 0 if definitely composite or 1 if probably prime.
3181  * Randomly the chance of error is no more than 1/4 and often 
3182  * very much lower.
3183  */
3184 int mp_prime_miller_rabin (mp_int * a, const mp_int * b, int *result)
3185 {
3186   mp_int  n1, y, r;
3187   int     s, j, err;
3188
3189   /* default */
3190   *result = MP_NO;
3191
3192   /* ensure b > 1 */
3193   if (mp_cmp_d(b, 1) != MP_GT) {
3194      return MP_VAL;
3195   }     
3196
3197   /* get n1 = a - 1 */
3198   if ((err = mp_init_copy (&n1, a)) != MP_OKAY) {
3199     return err;
3200   }
3201   if ((err = mp_sub_d (&n1, 1, &n1)) != MP_OKAY) {
3202     goto __N1;
3203   }
3204
3205   /* set 2**s * r = n1 */
3206   if ((err = mp_init_copy (&r, &n1)) != MP_OKAY) {
3207     goto __N1;
3208   }
3209
3210   /* count the number of least significant bits
3211    * which are zero
3212    */
3213   s = mp_cnt_lsb(&r);
3214
3215   /* now divide n - 1 by 2**s */
3216   if ((err = mp_div_2d (&r, s, &r, NULL)) != MP_OKAY) {
3217     goto __R;
3218   }
3219
3220   /* compute y = b**r mod a */
3221   if ((err = mp_init (&y)) != MP_OKAY) {
3222     goto __R;
3223   }
3224   if ((err = mp_exptmod (b, &r, a, &y)) != MP_OKAY) {
3225     goto __Y;
3226   }
3227
3228   /* if y != 1 and y != n1 do */
3229   if (mp_cmp_d (&y, 1) != MP_EQ && mp_cmp (&y, &n1) != MP_EQ) {
3230     j = 1;
3231     /* while j <= s-1 and y != n1 */
3232     while ((j <= (s - 1)) && mp_cmp (&y, &n1) != MP_EQ) {
3233       if ((err = mp_sqrmod (&y, a, &y)) != MP_OKAY) {
3234          goto __Y;
3235       }
3236
3237       /* if y == 1 then composite */
3238       if (mp_cmp_d (&y, 1) == MP_EQ) {
3239          goto __Y;
3240       }
3241
3242       ++j;
3243     }
3244
3245     /* if y != n1 then composite */
3246     if (mp_cmp (&y, &n1) != MP_EQ) {
3247       goto __Y;
3248     }
3249   }
3250
3251   /* probably prime now */
3252   *result = MP_YES;
3253 __Y:mp_clear (&y);
3254 __R:mp_clear (&r);
3255 __N1:mp_clear (&n1);
3256   return err;
3257 }
3258
3259 static const struct {
3260    int k, t;
3261 } sizes[] = {
3262 {   128,    28 },
3263 {   256,    16 },
3264 {   384,    10 },
3265 {   512,     7 },
3266 {   640,     6 },
3267 {   768,     5 },
3268 {   896,     4 },
3269 {  1024,     4 }
3270 };
3271
3272 /* returns # of RM trials required for a given bit size */
3273 int mp_prime_rabin_miller_trials(int size)
3274 {
3275    int x;
3276
3277    for (x = 0; x < (int)(sizeof(sizes)/(sizeof(sizes[0]))); x++) {
3278        if (sizes[x].k == size) {
3279           return sizes[x].t;
3280        } else if (sizes[x].k > size) {
3281           return (x == 0) ? sizes[0].t : sizes[x - 1].t;
3282        }
3283    }
3284    return sizes[x-1].t + 1;
3285 }
3286
3287 /* makes a truly random prime of a given size (bits),
3288  *
3289  * Flags are as follows:
3290  * 
3291  *   LTM_PRIME_BBS      - make prime congruent to 3 mod 4
3292  *   LTM_PRIME_SAFE     - make sure (p-1)/2 is prime as well (implies LTM_PRIME_BBS)
3293  *   LTM_PRIME_2MSB_OFF - make the 2nd highest bit zero
3294  *   LTM_PRIME_2MSB_ON  - make the 2nd highest bit one
3295  *
3296  * You have to supply a callback which fills in a buffer with random bytes.  "dat" is a parameter you can
3297  * have passed to the callback (e.g. a state or something).  This function doesn't use "dat" itself
3298  * so it can be NULL
3299  *
3300  */
3301
3302 /* This is possibly the mother of all prime generation functions, muahahahahaha! */
3303 int mp_prime_random_ex(mp_int *a, int t, int size, int flags, ltm_prime_callback cb, void *dat)
3304 {
3305    unsigned char *tmp, maskAND, maskOR_msb, maskOR_lsb;
3306    int res, err, bsize, maskOR_msb_offset;
3307
3308    /* sanity check the input */
3309    if (size <= 1 || t <= 0) {
3310       return MP_VAL;
3311    }
3312
3313    /* LTM_PRIME_SAFE implies LTM_PRIME_BBS */
3314    if (flags & LTM_PRIME_SAFE) {
3315       flags |= LTM_PRIME_BBS;
3316    }
3317
3318    /* calc the byte size */
3319    bsize = (size>>3)+((size&7)?1:0);
3320
3321    /* we need a buffer of bsize bytes */
3322    tmp = malloc(bsize);
3323    if (tmp == NULL) {
3324       return MP_MEM;
3325    }
3326
3327    /* calc the maskAND value for the MSbyte*/
3328    maskAND = ((size&7) == 0) ? 0xFF : (0xFF >> (8 - (size & 7))); 
3329
3330    /* calc the maskOR_msb */
3331    maskOR_msb        = 0;
3332    maskOR_msb_offset = ((size & 7) == 1) ? 1 : 0;
3333    if (flags & LTM_PRIME_2MSB_ON) {
3334       maskOR_msb     |= 1 << ((size - 2) & 7);
3335    } else if (flags & LTM_PRIME_2MSB_OFF) {
3336       maskAND        &= ~(1 << ((size - 2) & 7));
3337    }
3338
3339    /* get the maskOR_lsb */
3340    maskOR_lsb         = 0;
3341    if (flags & LTM_PRIME_BBS) {
3342       maskOR_lsb     |= 3;
3343    }
3344
3345    do {
3346       /* read the bytes */
3347       if (cb(tmp, bsize, dat) != bsize) {
3348          err = MP_VAL;
3349          goto error;
3350       }
3351  
3352       /* work over the MSbyte */
3353       tmp[0]    &= maskAND;
3354       tmp[0]    |= 1 << ((size - 1) & 7);
3355
3356       /* mix in the maskORs */
3357       tmp[maskOR_msb_offset]   |= maskOR_msb;
3358       tmp[bsize-1]             |= maskOR_lsb;
3359
3360       /* read it in */
3361       if ((err = mp_read_unsigned_bin(a, tmp, bsize)) != MP_OKAY)     { goto error; }
3362
3363       /* is it prime? */
3364       if ((err = mp_prime_is_prime(a, t, &res)) != MP_OKAY)           { goto error; }
3365       if (res == MP_NO) {  
3366          continue;
3367       }
3368
3369       if (flags & LTM_PRIME_SAFE) {
3370          /* see if (a-1)/2 is prime */
3371          if ((err = mp_sub_d(a, 1, a)) != MP_OKAY)                    { goto error; }
3372          if ((err = mp_div_2(a, a)) != MP_OKAY)                       { goto error; }
3373  
3374          /* is it prime? */
3375          if ((err = mp_prime_is_prime(a, t, &res)) != MP_OKAY)        { goto error; }
3376       }
3377    } while (res == MP_NO);
3378
3379    if (flags & LTM_PRIME_SAFE) {
3380       /* restore a to the original value */
3381       if ((err = mp_mul_2(a, a)) != MP_OKAY)                          { goto error; }
3382       if ((err = mp_add_d(a, 1, a)) != MP_OKAY)                       { goto error; }
3383    }
3384
3385    err = MP_OKAY;
3386 error:
3387    free(tmp);
3388    return err;
3389 }
3390
3391 /* reads an unsigned char array, assumes the msb is stored first [big endian] */
3392 int
3393 mp_read_unsigned_bin (mp_int * a, const unsigned char *b, int c)
3394 {
3395   int     res;
3396
3397   /* make sure there are at least two digits */
3398   if (a->alloc < 2) {
3399      if ((res = mp_grow(a, 2)) != MP_OKAY) {
3400         return res;
3401      }
3402   }
3403
3404   /* zero the int */
3405   mp_zero (a);
3406
3407   /* read the bytes in */
3408   while (c-- > 0) {
3409     if ((res = mp_mul_2d (a, 8, a)) != MP_OKAY) {
3410       return res;
3411     }
3412
3413       a->dp[0] |= *b++;
3414       a->used += 1;
3415   }
3416   mp_clamp (a);
3417   return MP_OKAY;
3418 }
3419
3420 /* reduces x mod m, assumes 0 < x < m**2, mu is 
3421  * precomputed via mp_reduce_setup.
3422  * From HAC pp.604 Algorithm 14.42
3423  */
3424 int
3425 mp_reduce (mp_int * x, const mp_int * m, const mp_int * mu)
3426 {
3427   mp_int  q;
3428   int     res, um = m->used;
3429
3430   /* q = x */
3431   if ((res = mp_init_copy (&q, x)) != MP_OKAY) {
3432     return res;
3433   }
3434
3435   /* q1 = x / b**(k-1)  */
3436   mp_rshd (&q, um - 1);         
3437
3438   /* according to HAC this optimization is ok */
3439   if (((unsigned long) um) > (((mp_digit)1) << (DIGIT_BIT - 1))) {
3440     if ((res = mp_mul (&q, mu, &q)) != MP_OKAY) {
3441       goto CLEANUP;
3442     }
3443   } else {
3444     if ((res = s_mp_mul_high_digs (&q, mu, &q, um - 1)) != MP_OKAY) {
3445       goto CLEANUP;
3446     }
3447   }
3448
3449   /* q3 = q2 / b**(k+1) */
3450   mp_rshd (&q, um + 1);         
3451
3452   /* x = x mod b**(k+1), quick (no division) */
3453   if ((res = mp_mod_2d (x, DIGIT_BIT * (um + 1), x)) != MP_OKAY) {
3454     goto CLEANUP;
3455   }
3456
3457   /* q = q * m mod b**(k+1), quick (no division) */
3458   if ((res = s_mp_mul_digs (&q, m, &q, um + 1)) != MP_OKAY) {
3459     goto CLEANUP;
3460   }
3461
3462   /* x = x - q */
3463   if ((res = mp_sub (x, &q, x)) != MP_OKAY) {
3464     goto CLEANUP;
3465   }
3466
3467   /* If x < 0, add b**(k+1) to it */
3468   if (mp_cmp_d (x, 0) == MP_LT) {
3469     mp_set (&q, 1);
3470     if ((res = mp_lshd (&q, um + 1)) != MP_OKAY)
3471       goto CLEANUP;
3472     if ((res = mp_add (x, &q, x)) != MP_OKAY)
3473       goto CLEANUP;
3474   }
3475
3476   /* Back off if it's too big */
3477   while (mp_cmp (x, m) != MP_LT) {
3478     if ((res = s_mp_sub (x, m, x)) != MP_OKAY) {
3479       goto CLEANUP;
3480     }
3481   }
3482   
3483 CLEANUP:
3484   mp_clear (&q);
3485
3486   return res;
3487 }
3488
3489 /* reduces a modulo n where n is of the form 2**p - d */
3490 int
3491 mp_reduce_2k(mp_int *a, const mp_int *n, mp_digit d)
3492 {
3493    mp_int q;
3494    int    p, res;
3495    
3496    if ((res = mp_init(&q)) != MP_OKAY) {
3497       return res;
3498    }
3499    
3500    p = mp_count_bits(n);    
3501 top:
3502    /* q = a/2**p, a = a mod 2**p */
3503    if ((res = mp_div_2d(a, p, &q, a)) != MP_OKAY) {
3504       goto ERR;
3505    }
3506    
3507    if (d != 1) {
3508       /* q = q * d */
3509       if ((res = mp_mul_d(&q, d, &q)) != MP_OKAY) { 
3510          goto ERR;
3511       }
3512    }
3513    
3514    /* a = a + q */
3515    if ((res = s_mp_add(a, &q, a)) != MP_OKAY) {
3516       goto ERR;
3517    }
3518    
3519    if (mp_cmp_mag(a, n) != MP_LT) {
3520       s_mp_sub(a, n, a);
3521       goto top;
3522    }
3523    
3524 ERR:
3525    mp_clear(&q);
3526    return res;
3527 }
3528
3529 /* determines the setup value */
3530 int 
3531 mp_reduce_2k_setup(const mp_int *a, mp_digit *d)
3532 {
3533    int res, p;
3534    mp_int tmp;
3535    
3536    if ((res = mp_init(&tmp)) != MP_OKAY) {
3537       return res;
3538    }
3539    
3540    p = mp_count_bits(a);
3541    if ((res = mp_2expt(&tmp, p)) != MP_OKAY) {
3542       mp_clear(&tmp);
3543       return res;
3544    }
3545    
3546    if ((res = s_mp_sub(&tmp, a, &tmp)) != MP_OKAY) {
3547       mp_clear(&tmp);
3548       return res;
3549    }
3550    
3551    *d = tmp.dp[0];
3552    mp_clear(&tmp);
3553    return MP_OKAY;
3554 }
3555
3556 /* pre-calculate the value required for Barrett reduction
3557  * For a given modulus "b" it calulates the value required in "a"
3558  */
3559 int mp_reduce_setup (mp_int * a, const mp_int * b)
3560 {
3561   int     res;
3562
3563   if ((res = mp_2expt (a, b->used * 2 * DIGIT_BIT)) != MP_OKAY) {
3564     return res;
3565   }
3566   return mp_div (a, b, a, NULL);
3567 }
3568
3569 /* shift right a certain amount of digits */
3570 void mp_rshd (mp_int * a, int b)
3571 {
3572   int     x;
3573
3574   /* if b <= 0 then ignore it */
3575   if (b <= 0) {
3576     return;
3577   }
3578
3579   /* if b > used then simply zero it and return */
3580   if (a->used <= b) {
3581     mp_zero (a);
3582     return;
3583   }
3584
3585   {
3586     register mp_digit *bottom, *top;
3587
3588     /* shift the digits down */
3589
3590     /* bottom */
3591     bottom = a->dp;
3592
3593     /* top [offset into digits] */
3594     top = a->dp + b;
3595
3596     /* this is implemented as a sliding window where 
3597      * the window is b-digits long and digits from 
3598      * the top of the window are copied to the bottom
3599      *
3600      * e.g.
3601
3602      b-2 | b-1 | b0 | b1 | b2 | ... | bb |   ---->
3603                  /\                   |      ---->
3604                   \-------------------/      ---->
3605      */
3606     for (x = 0; x < (a->used - b); x++) {
3607       *bottom++ = *top++;
3608     }
3609
3610     /* zero the top digits */
3611     for (; x < a->used; x++) {
3612       *bottom++ = 0;
3613     }
3614   }
3615   
3616   /* remove excess digits */
3617   a->used -= b;
3618 }
3619
3620 /* set to a digit */
3621 void mp_set (mp_int * a, mp_digit b)
3622 {
3623   mp_zero (a);
3624   a->dp[0] = b & MP_MASK;
3625   a->used  = (a->dp[0] != 0) ? 1 : 0;
3626 }
3627
3628 /* set a 32-bit const */
3629 int mp_set_int (mp_int * a, unsigned long b)
3630 {
3631   int     x, res;
3632
3633   mp_zero (a);
3634   
3635   /* set four bits at a time */
3636   for (x = 0; x < 8; x++) {
3637     /* shift the number up four bits */
3638     if ((res = mp_mul_2d (a, 4, a)) != MP_OKAY) {
3639       return res;
3640     }
3641
3642     /* OR in the top four bits of the source */
3643     a->dp[0] |= (b >> 28) & 15;
3644
3645     /* shift the source up to the next four bits */
3646     b <<= 4;
3647
3648     /* ensure that digits are not clamped off */
3649     a->used += 1;
3650   }
3651   mp_clamp (a);
3652   return MP_OKAY;
3653 }
3654
3655 /* shrink a bignum */
3656 int mp_shrink (mp_int * a)
3657 {
3658   mp_digit *tmp;
3659   if (a->alloc != a->used && a->used > 0) {
3660     if ((tmp = realloc (a->dp, sizeof (mp_digit) * a->used)) == NULL) {
3661       return MP_MEM;
3662     }
3663     a->dp    = tmp;
3664     a->alloc = a->used;
3665   }
3666   return MP_OKAY;
3667 }
3668
3669 /* get the size for an signed equivalent */
3670 int mp_signed_bin_size (const mp_int * a)
3671 {
3672   return 1 + mp_unsigned_bin_size (a);
3673 }
3674
3675 /* computes b = a*a */
3676 int
3677 mp_sqr (const mp_int * a, mp_int * b)
3678 {
3679   int     res;
3680
3681 if (a->used >= KARATSUBA_SQR_CUTOFF) {
3682     res = mp_karatsuba_sqr (a, b);
3683   } else 
3684   {
3685     /* can we use the fast comba multiplier? */
3686     if ((a->used * 2 + 1) < MP_WARRAY && 
3687          a->used < 
3688          (1 << (sizeof(mp_word) * CHAR_BIT - 2*DIGIT_BIT - 1))) {
3689       res = fast_s_mp_sqr (a, b);
3690     } else
3691       res = s_mp_sqr (a, b);
3692   }
3693   b->sign = MP_ZPOS;
3694   return res;
3695 }
3696
3697 /* c = a * a (mod b) */
3698 int
3699 mp_sqrmod (const mp_int * a, mp_int * b, mp_int * c)
3700 {
3701   int     res;
3702   mp_int  t;
3703
3704   if ((res = mp_init (&t)) != MP_OKAY) {
3705     return res;
3706   }
3707
3708   if ((res = mp_sqr (a, &t)) != MP_OKAY) {
3709     mp_clear (&t);
3710     return res;
3711   }
3712   res = mp_mod (&t, b, c);
3713   mp_clear (&t);
3714   return res;
3715 }
3716
3717 /* high level subtraction (handles signs) */
3718 int
3719 mp_sub (mp_int * a, mp_int * b, mp_int * c)
3720 {
3721   int     sa, sb, res;
3722
3723   sa = a->sign;
3724   sb = b->sign;
3725
3726   if (sa != sb) {
3727     /* subtract a negative from a positive, OR */
3728     /* subtract a positive from a negative. */
3729     /* In either case, ADD their magnitudes, */
3730     /* and use the sign of the first number. */
3731     c->sign = sa;
3732     res = s_mp_add (a, b, c);
3733   } else {
3734     /* subtract a positive from a positive, OR */
3735     /* subtract a negative from a negative. */
3736     /* First, take the difference between their */
3737     /* magnitudes, then... */
3738     if (mp_cmp_mag (a, b) != MP_LT) {
3739       /* Copy the sign from the first */
3740       c->sign = sa;
3741       /* The first has a larger or equal magnitude */
3742       res = s_mp_sub (a, b, c);
3743     } else {
3744       /* The result has the *opposite* sign from */
3745       /* the first number. */
3746       c->sign = (sa == MP_ZPOS) ? MP_NEG : MP_ZPOS;
3747       /* The second has a larger magnitude */
3748       res = s_mp_sub (b, a, c);
3749     }
3750   }
3751   return res;
3752 }
3753
3754 /* single digit subtraction */
3755 int
3756 mp_sub_d (mp_int * a, mp_digit b, mp_int * c)
3757 {
3758   mp_digit *tmpa, *tmpc, mu;
3759   int       res, ix, oldused;
3760
3761   /* grow c as required */
3762   if (c->alloc < a->used + 1) {
3763      if ((res = mp_grow(c, a->used + 1)) != MP_OKAY) {
3764         return res;
3765      }
3766   }
3767
3768   /* if a is negative just do an unsigned
3769    * addition [with fudged signs]
3770    */
3771   if (a->sign == MP_NEG) {
3772      a->sign = MP_ZPOS;
3773      res     = mp_add_d(a, b, c);
3774      a->sign = c->sign = MP_NEG;
3775      return res;
3776   }
3777
3778   /* setup regs */
3779   oldused = c->used;
3780   tmpa    = a->dp;
3781   tmpc    = c->dp;
3782
3783   /* if a <= b simply fix the single digit */
3784   if ((a->used == 1 && a->dp[0] <= b) || a->used == 0) {
3785      if (a->used == 1) {
3786         *tmpc++ = b - *tmpa;
3787      } else {
3788         *tmpc++ = b;
3789      }
3790      ix      = 1;
3791
3792      /* negative/1digit */
3793      c->sign = MP_NEG;
3794      c->used = 1;
3795   } else {
3796      /* positive/size */
3797      c->sign = MP_ZPOS;
3798      c->used = a->used;
3799
3800      /* subtract first digit */
3801      *tmpc    = *tmpa++ - b;
3802      mu       = *tmpc >> (sizeof(mp_digit) * CHAR_BIT - 1);
3803      *tmpc++ &= MP_MASK;
3804
3805      /* handle rest of the digits */
3806      for (ix = 1; ix < a->used; ix++) {
3807         *tmpc    = *tmpa++ - mu;
3808         mu       = *tmpc >> (sizeof(mp_digit) * CHAR_BIT - 1);
3809         *tmpc++ &= MP_MASK;
3810      }
3811   }
3812
3813   /* zero excess digits */
3814   while (ix++ < oldused) {
3815      *tmpc++ = 0;
3816   }
3817   mp_clamp(c);
3818   return MP_OKAY;
3819 }
3820
3821 /* store in unsigned [big endian] format */
3822 int
3823 mp_to_unsigned_bin (const mp_int * a, unsigned char *b)
3824 {
3825   int     x, res;
3826   mp_int  t;
3827
3828   if ((res = mp_init_copy (&t, a)) != MP_OKAY) {
3829     return res;
3830   }
3831
3832   x = 0;
3833   while (mp_iszero (&t) == 0) {
3834     b[x++] = (unsigned char) (t.dp[0] & 255);
3835     if ((res = mp_div_2d (&t, 8, &t, NULL)) != MP_OKAY) {
3836       mp_clear (&t);
3837       return res;
3838     }
3839   }
3840   bn_reverse (b, x);
3841   mp_clear (&t);
3842   return MP_OKAY;
3843 }
3844
3845 /* get the size for an unsigned equivalent */
3846 int
3847 mp_unsigned_bin_size (const mp_int * a)
3848 {
3849   int     size = mp_count_bits (a);
3850   return (size / 8 + ((size & 7) != 0 ? 1 : 0));
3851 }
3852
3853 /* set to zero */
3854 void
3855 mp_zero (mp_int * a)
3856 {
3857   a->sign = MP_ZPOS;
3858   a->used = 0;
3859   memset (a->dp, 0, sizeof (mp_digit) * a->alloc);
3860 }
3861
3862 /* reverse an array, used for radix code */
3863 static void
3864 bn_reverse (unsigned char *s, int len)
3865 {
3866   int     ix, iy;
3867   unsigned char t;
3868
3869   ix = 0;
3870   iy = len - 1;
3871   while (ix < iy) {
3872     t     = s[ix];
3873     s[ix] = s[iy];
3874     s[iy] = t;
3875     ++ix;
3876     --iy;
3877   }
3878 }
3879
3880 /* low level addition, based on HAC pp.594, Algorithm 14.7 */
3881 static int
3882 s_mp_add (mp_int * a, mp_int * b, mp_int * c)
3883 {
3884   mp_int *x;
3885   int     olduse, res, min, max;
3886
3887   /* find sizes, we let |a| <= |b| which means we have to sort
3888    * them.  "x" will point to the input with the most digits
3889    */
3890   if (a->used > b->used) {
3891     min = b->used;
3892     max = a->used;
3893     x = a;
3894   } else {
3895     min = a->used;
3896     max = b->used;
3897     x = b;
3898   }
3899
3900   /* init result */
3901   if (c->alloc < max + 1) {
3902     if ((res = mp_grow (c, max + 1)) != MP_OKAY) {
3903       return res;
3904     }
3905   }
3906
3907   /* get old used digit count and set new one */
3908   olduse = c->used;
3909   c->used = max + 1;
3910
3911   {
3912     register mp_digit u, *tmpa, *tmpb, *tmpc;
3913     register int i;
3914
3915     /* alias for digit pointers */
3916
3917     /* first input */
3918     tmpa = a->dp;
3919
3920     /* second input */
3921     tmpb = b->dp;
3922
3923     /* destination */
3924     tmpc = c->dp;
3925
3926     /* zero the carry */
3927     u = 0;
3928     for (i = 0; i < min; i++) {
3929       /* Compute the sum at one digit, T[i] = A[i] + B[i] + U */
3930       *tmpc = *tmpa++ + *tmpb++ + u;
3931
3932       /* U = carry bit of T[i] */
3933       u = *tmpc >> ((mp_digit)DIGIT_BIT);
3934
3935       /* take away carry bit from T[i] */
3936       *tmpc++ &= MP_MASK;
3937     }
3938
3939     /* now copy higher words if any, that is in A+B 
3940      * if A or B has more digits add those in 
3941      */
3942     if (min != max) {
3943       for (; i < max; i++) {
3944         /* T[i] = X[i] + U */
3945         *tmpc = x->dp[i] + u;
3946
3947         /* U = carry bit of T[i] */
3948         u = *tmpc >> ((mp_digit)DIGIT_BIT);
3949
3950         /* take away carry bit from T[i] */
3951         *tmpc++ &= MP_MASK;
3952       }
3953     }
3954
3955     /* add carry */
3956     *tmpc++ = u;
3957
3958     /* clear digits above oldused */
3959     for (i = c->used; i < olduse; i++) {
3960       *tmpc++ = 0;
3961     }
3962   }
3963
3964   mp_clamp (c);
3965   return MP_OKAY;
3966 }
3967
3968 static int s_mp_exptmod (const mp_int * G, const mp_int * X, mp_int * P, mp_int * Y)
3969 {
3970   mp_int  M[256], res, mu;
3971   mp_digit buf;
3972   int     err, bitbuf, bitcpy, bitcnt, mode, digidx, x, y, winsize;
3973
3974   /* find window size */
3975   x = mp_count_bits (X);
3976   if (x <= 7) {
3977     winsize = 2;
3978   } else if (x <= 36) {
3979     winsize = 3;
3980   } else if (x <= 140) {
3981     winsize = 4;
3982   } else if (x <= 450) {
3983     winsize = 5;
3984   } else if (x <= 1303) {
3985     winsize = 6;
3986   } else if (x <= 3529) {
3987     winsize = 7;
3988   } else {
3989     winsize = 8;
3990   }
3991
3992   /* init M array */
3993   /* init first cell */
3994   if ((err = mp_init(&M[1])) != MP_OKAY) {
3995      return err; 
3996   }
3997
3998   /* now init the second half of the array */
3999   for (x = 1<<(winsize-1); x < (1 << winsize); x++) {
4000     if ((err = mp_init(&M[x])) != MP_OKAY) {
4001       for (y = 1<<(winsize-1); y < x; y++) {
4002         mp_clear (&M[y]);
4003       }
4004       mp_clear(&M[1]);
4005       return err;
4006     }
4007   }
4008
4009   /* create mu, used for Barrett reduction */
4010   if ((err = mp_init (&mu)) != MP_OKAY) {
4011     goto __M;
4012   }
4013   if ((err = mp_reduce_setup (&mu, P)) != MP_OKAY) {
4014     goto __MU;
4015   }
4016
4017   /* create M table
4018    *
4019    * The M table contains powers of the base, 
4020    * e.g. M[x] = G**x mod P
4021    *
4022    * The first half of the table is not 
4023    * computed though accept for M[0] and M[1]
4024    */
4025   if ((err = mp_mod (G, P, &M[1])) != MP_OKAY) {
4026     goto __MU;
4027   }
4028
4029   /* compute the value at M[1<<(winsize-1)] by squaring 
4030    * M[1] (winsize-1) times 
4031    */
4032   if ((err = mp_copy (&M[1], &M[1 << (winsize - 1)])) != MP_OKAY) {
4033     goto __MU;
4034   }
4035
4036   for (x = 0; x < (winsize - 1); x++) {
4037     if ((err = mp_sqr (&M[1 << (winsize - 1)], 
4038                        &M[1 << (winsize - 1)])) != MP_OKAY) {
4039       goto __MU;
4040     }
4041     if ((err = mp_reduce (&M[1 << (winsize - 1)], P, &mu)) != MP_OKAY) {
4042       goto __MU;
4043     }
4044   }
4045
4046   /* create upper table, that is M[x] = M[x-1] * M[1] (mod P)
4047    * for x = (2**(winsize - 1) + 1) to (2**winsize - 1)
4048    */
4049   for (x = (1 << (winsize - 1)) + 1; x < (1 << winsize); x++) {
4050     if ((err = mp_mul (&M[x - 1], &M[1], &M[x])) != MP_OKAY) {
4051       goto __MU;
4052     }
4053     if ((err = mp_reduce (&M[x], P, &mu)) != MP_OKAY) {
4054       goto __MU;
4055     }
4056   }
4057
4058   /* setup result */
4059   if ((err = mp_init (&res)) != MP_OKAY) {
4060     goto __MU;
4061   }
4062   mp_set (&res, 1);
4063
4064   /* set initial mode and bit cnt */
4065   mode   = 0;
4066   bitcnt = 1;
4067   buf    = 0;
4068   digidx = X->used - 1;
4069   bitcpy = 0;
4070   bitbuf = 0;
4071
4072   for (;;) {
4073     /* grab next digit as required */
4074     if (--bitcnt == 0) {
4075       /* if digidx == -1 we are out of digits */
4076       if (digidx == -1) {
4077         break;
4078       }
4079       /* read next digit and reset the bitcnt */
4080       buf    = X->dp[digidx--];
4081       bitcnt = DIGIT_BIT;
4082     }
4083
4084     /* grab the next msb from the exponent */
4085     y     = (buf >> (mp_digit)(DIGIT_BIT - 1)) & 1;
4086     buf <<= (mp_digit)1;
4087
4088     /* if the bit is zero and mode == 0 then we ignore it
4089      * These represent the leading zero bits before the first 1 bit
4090      * in the exponent.  Technically this opt is not required but it
4091      * does lower the # of trivial squaring/reductions used
4092      */
4093     if (mode == 0 && y == 0) {
4094       continue;
4095     }
4096
4097     /* if the bit is zero and mode == 1 then we square */
4098     if (mode == 1 && y == 0) {
4099       if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
4100         goto __RES;
4101       }
4102       if ((err = mp_reduce (&res, P, &mu)) != MP_OKAY) {
4103         goto __RES;
4104       }
4105       continue;
4106     }
4107
4108     /* else we add it to the window */
4109     bitbuf |= (y << (winsize - ++bitcpy));
4110     mode    = 2;
4111
4112     if (bitcpy == winsize) {
4113       /* ok window is filled so square as required and multiply  */
4114       /* square first */
4115       for (x = 0; x < winsize; x++) {
4116         if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
4117           goto __RES;
4118         }
4119         if ((err = mp_reduce (&res, P, &mu)) != MP_OKAY) {
4120           goto __RES;
4121         }
4122       }
4123
4124       /* then multiply */
4125       if ((err = mp_mul (&res, &M[bitbuf], &res)) != MP_OKAY) {
4126         goto __RES;
4127       }
4128       if ((err = mp_reduce (&res, P, &mu)) != MP_OKAY) {
4129         goto __RES;
4130       }
4131
4132       /* empty window and reset */
4133       bitcpy = 0;
4134       bitbuf = 0;
4135       mode   = 1;
4136     }
4137   }
4138
4139   /* if bits remain then square/multiply */
4140   if (mode == 2 && bitcpy > 0) {
4141     /* square then multiply if the bit is set */
4142     for (x = 0; x < bitcpy; x++) {
4143       if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
4144         goto __RES;
4145       }
4146       if ((err = mp_reduce (&res, P, &mu)) != MP_OKAY) {
4147         goto __RES;
4148       }
4149
4150       bitbuf <<= 1;
4151       if ((bitbuf & (1 << winsize)) != 0) {
4152         /* then multiply */
4153         if ((err = mp_mul (&res, &M[1], &res)) != MP_OKAY) {
4154           goto __RES;
4155         }
4156         if ((err = mp_reduce (&res, P, &mu)) != MP_OKAY) {
4157           goto __RES;
4158         }
4159       }
4160     }
4161   }
4162
4163   mp_exch (&res, Y);
4164   err = MP_OKAY;
4165 __RES:mp_clear (&res);
4166 __MU:mp_clear (&mu);
4167 __M:
4168   mp_clear(&M[1]);
4169   for (x = 1<<(winsize-1); x < (1 << winsize); x++) {
4170     mp_clear (&M[x]);
4171   }
4172   return err;
4173 }
4174
4175 /* multiplies |a| * |b| and only computes up to digs digits of result
4176  * HAC pp. 595, Algorithm 14.12  Modified so you can control how 
4177  * many digits of output are created.
4178  */
4179 static int
4180 s_mp_mul_digs (const mp_int * a, const mp_int * b, mp_int * c, int digs)
4181 {
4182   mp_int  t;
4183   int     res, pa, pb, ix, iy;
4184   mp_digit u;
4185   mp_word r;
4186   mp_digit tmpx, *tmpt, *tmpy;
4187
4188   /* can we use the fast multiplier? */
4189   if (((digs) < MP_WARRAY) &&
4190       MIN (a->used, b->used) < 
4191           (1 << ((CHAR_BIT * sizeof (mp_word)) - (2 * DIGIT_BIT)))) {
4192     return fast_s_mp_mul_digs (a, b, c, digs);
4193   }
4194
4195   if ((res = mp_init_size (&t, digs)) != MP_OKAY) {
4196     return res;
4197   }
4198   t.used = digs;
4199
4200   /* compute the digits of the product directly */
4201   pa = a->used;
4202   for (ix = 0; ix < pa; ix++) {
4203     /* set the carry to zero */
4204     u = 0;
4205
4206     /* limit ourselves to making digs digits of output */
4207     pb = MIN (b->used, digs - ix);
4208
4209     /* setup some aliases */
4210     /* copy of the digit from a used within the nested loop */
4211     tmpx = a->dp[ix];
4212     
4213     /* an alias for the destination shifted ix places */
4214     tmpt = t.dp + ix;
4215     
4216     /* an alias for the digits of b */
4217     tmpy = b->dp;
4218
4219     /* compute the columns of the output and propagate the carry */
4220     for (iy = 0; iy < pb; iy++) {
4221       /* compute the column as a mp_word */
4222       r       = ((mp_word)*tmpt) +
4223                 ((mp_word)tmpx) * ((mp_word)*tmpy++) +
4224                 ((mp_word) u);
4225
4226       /* the new column is the lower part of the result */
4227       *tmpt++ = (mp_digit) (r & ((mp_word) MP_MASK));
4228
4229       /* get the carry word from the result */
4230       u       = (mp_digit) (r >> ((mp_word) DIGIT_BIT));
4231     }
4232     /* set carry if it is placed below digs */
4233     if (ix + iy < digs) {
4234       *tmpt = u;
4235     }
4236   }
4237
4238   mp_clamp (&t);
4239   mp_exch (&t, c);
4240
4241   mp_clear (&t);
4242   return MP_OKAY;
4243 }
4244
4245 /* multiplies |a| * |b| and does not compute the lower digs digits
4246  * [meant to get the higher part of the product]
4247  */
4248 static int
4249 s_mp_mul_high_digs (const mp_int * a, const mp_int * b, mp_int * c, int digs)
4250 {
4251   mp_int  t;
4252   int     res, pa, pb, ix, iy;
4253   mp_digit u;
4254   mp_word r;
4255   mp_digit tmpx, *tmpt, *tmpy;
4256
4257   /* can we use the fast multiplier? */
4258   if (((a->used + b->used + 1) < MP_WARRAY)
4259       && MIN (a->used, b->used) < (1 << ((CHAR_BIT * sizeof (mp_word)) - (2 * DIGIT_BIT)))) {
4260     return fast_s_mp_mul_high_digs (a, b, c, digs);
4261   }
4262
4263   if ((res = mp_init_size (&t, a->used + b->used + 1)) != MP_OKAY) {
4264     return res;
4265   }
4266   t.used = a->used + b->used + 1;
4267
4268   pa = a->used;
4269   pb = b->used;
4270   for (ix = 0; ix < pa; ix++) {
4271     /* clear the carry */
4272     u = 0;
4273
4274     /* left hand side of A[ix] * B[iy] */
4275     tmpx = a->dp[ix];
4276
4277     /* alias to the address of where the digits will be stored */
4278     tmpt = &(t.dp[digs]);
4279
4280     /* alias for where to read the right hand side from */
4281     tmpy = b->dp + (digs - ix);
4282
4283     for (iy = digs - ix; iy < pb; iy++) {
4284       /* calculate the double precision result */
4285       r       = ((mp_word)*tmpt) +
4286                 ((mp_word)tmpx) * ((mp_word)*tmpy++) +
4287                 ((mp_word) u);
4288
4289       /* get the lower part */
4290       *tmpt++ = (mp_digit) (r & ((mp_word) MP_MASK));
4291
4292       /* carry the carry */
4293       u       = (mp_digit) (r >> ((mp_word) DIGIT_BIT));
4294     }
4295     *tmpt = u;
4296   }
4297   mp_clamp (&t);
4298   mp_exch (&t, c);
4299   mp_clear (&t);
4300   return MP_OKAY;
4301 }
4302
4303 /* low level squaring, b = a*a, HAC pp.596-597, Algorithm 14.16 */
4304 static int
4305 s_mp_sqr (const mp_int * a, mp_int * b)
4306 {
4307   mp_int  t;
4308   int     res, ix, iy, pa;
4309   mp_word r;
4310   mp_digit u, tmpx, *tmpt;
4311
4312   pa = a->used;
4313   if ((res = mp_init_size (&t, 2*pa + 1)) != MP_OKAY) {
4314     return res;
4315   }
4316
4317   /* default used is maximum possible size */
4318   t.used = 2*pa + 1;
4319
4320   for (ix = 0; ix < pa; ix++) {
4321     /* first calculate the digit at 2*ix */
4322     /* calculate double precision result */
4323     r = ((mp_word) t.dp[2*ix]) +
4324         ((mp_word)a->dp[ix])*((mp_word)a->dp[ix]);
4325
4326     /* store lower part in result */
4327     t.dp[ix+ix] = (mp_digit) (r & ((mp_word) MP_MASK));
4328
4329     /* get the carry */
4330     u           = (mp_digit)(r >> ((mp_word) DIGIT_BIT));
4331
4332     /* left hand side of A[ix] * A[iy] */
4333     tmpx        = a->dp[ix];
4334
4335     /* alias for where to store the results */
4336     tmpt        = t.dp + (2*ix + 1);
4337     
4338     for (iy = ix + 1; iy < pa; iy++) {
4339       /* first calculate the product */
4340       r       = ((mp_word)tmpx) * ((mp_word)a->dp[iy]);
4341
4342       /* now calculate the double precision result, note we use
4343        * addition instead of *2 since it's easier to optimize
4344        */
4345       r       = ((mp_word) *tmpt) + r + r + ((mp_word) u);
4346
4347       /* store lower part */
4348       *tmpt++ = (mp_digit) (r & ((mp_word) MP_MASK));
4349
4350       /* get carry */
4351       u       = (mp_digit)(r >> ((mp_word) DIGIT_BIT));
4352     }
4353     /* propagate upwards */
4354     while (u != ((mp_digit) 0)) {
4355       r       = ((mp_word) *tmpt) + ((mp_word) u);
4356       *tmpt++ = (mp_digit) (r & ((mp_word) MP_MASK));
4357       u       = (mp_digit)(r >> ((mp_word) DIGIT_BIT));
4358     }
4359   }
4360
4361   mp_clamp (&t);
4362   mp_exch (&t, b);
4363   mp_clear (&t);
4364   return MP_OKAY;
4365 }
4366
4367 /* low level subtraction (assumes |a| > |b|), HAC pp.595 Algorithm 14.9 */
4368 int
4369 s_mp_sub (const mp_int * a, const mp_int * b, mp_int * c)
4370 {
4371   int     olduse, res, min, max;
4372
4373   /* find sizes */
4374   min = b->used;
4375   max = a->used;
4376
4377   /* init result */
4378   if (c->alloc < max) {
4379     if ((res = mp_grow (c, max)) != MP_OKAY) {
4380       return res;
4381     }
4382   }
4383   olduse = c->used;
4384   c->used = max;
4385
4386   {
4387     register mp_digit u, *tmpa, *tmpb, *tmpc;
4388     register int i;
4389
4390     /* alias for digit pointers */
4391     tmpa = a->dp;
4392     tmpb = b->dp;
4393     tmpc = c->dp;
4394
4395     /* set carry to zero */
4396     u = 0;
4397     for (i = 0; i < min; i++) {
4398       /* T[i] = A[i] - B[i] - U */
4399       *tmpc = *tmpa++ - *tmpb++ - u;
4400
4401       /* U = carry bit of T[i]
4402        * Note this saves performing an AND operation since
4403        * if a carry does occur it will propagate all the way to the
4404        * MSB.  As a result a single shift is enough to get the carry
4405        */
4406       u = *tmpc >> ((mp_digit)(CHAR_BIT * sizeof (mp_digit) - 1));
4407
4408       /* Clear carry from T[i] */
4409       *tmpc++ &= MP_MASK;
4410     }
4411
4412     /* now copy higher words if any, e.g. if A has more digits than B  */
4413     for (; i < max; i++) {
4414       /* T[i] = A[i] - U */
4415       *tmpc = *tmpa++ - u;
4416
4417       /* U = carry bit of T[i] */
4418       u = *tmpc >> ((mp_digit)(CHAR_BIT * sizeof (mp_digit) - 1));
4419
4420       /* Clear carry from T[i] */
4421       *tmpc++ &= MP_MASK;
4422     }
4423
4424     /* clear digits above used (since we may not have grown result above) */
4425     for (i = c->used; i < olduse; i++) {
4426       *tmpc++ = 0;
4427     }
4428   }
4429
4430   mp_clamp (c);
4431   return MP_OKAY;
4432 }