user32/tests: Fix some window test failures on various Windows platforms.
[wine] / dlls / rsaenh / mpi.c
1 /*
2  * dlls/rsaenh/mpi.c
3  * Multi Precision Integer functions
4  *
5  * Copyright 2004 Michael Jung
6  * Based on public domain code by Tom St Denis (tomstdenis@iahu.ca)
7  *
8  * This library is free software; you can redistribute it and/or
9  * modify it under the terms of the GNU Lesser General Public
10  * License as published by the Free Software Foundation; either
11  * version 2.1 of the License, or (at your option) any later version.
12  *
13  * This library is distributed in the hope that it will be useful,
14  * but WITHOUT ANY WARRANTY; without even the implied warranty of
15  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
16  * Lesser General Public License for more details.
17  *
18  * You should have received a copy of the GNU Lesser General Public
19  * License along with this library; if not, write to the Free Software
20  * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
21  */
22
23 /*
24  * This file contains code from the LibTomCrypt cryptographic 
25  * library written by Tom St Denis (tomstdenis@iahu.ca). LibTomCrypt
26  * is in the public domain. The code in this file is tailored to
27  * special requirements. Take a look at http://libtomcrypt.org for the
28  * original version. 
29  */
30
31 #include <stdarg.h>
32 #include "tomcrypt.h"
33
34 /* Known optimal configurations
35  CPU                    /Compiler     /MUL CUTOFF/SQR CUTOFF
36 -------------------------------------------------------------
37  Intel P4 Northwood     /GCC v3.4.1   /        88/       128/LTM 0.32 ;-)
38 */
39 static const int KARATSUBA_MUL_CUTOFF = 88,  /* Min. number of digits before Karatsuba multiplication is used. */
40                  KARATSUBA_SQR_CUTOFF = 128; /* Min. number of digits before Karatsuba squaring is used. */
41
42 static void bn_reverse(unsigned char *s, int len);
43 static int s_mp_add(mp_int *a, mp_int *b, mp_int *c);
44 static int s_mp_exptmod (const mp_int * G, const mp_int * X, mp_int * P, mp_int * Y);
45 #define s_mp_mul(a, b, c) s_mp_mul_digs(a, b, c, (a)->used + (b)->used + 1)
46 static int s_mp_mul_digs(const mp_int *a, const mp_int *b, mp_int *c, int digs);
47 static int s_mp_mul_high_digs(const mp_int *a, const mp_int *b, mp_int *c, int digs);
48 static int s_mp_sqr(const mp_int *a, mp_int *b);
49 static int s_mp_sub(const mp_int *a, const mp_int *b, mp_int *c);
50 static int mp_exptmod_fast(const mp_int *G, const mp_int *X, mp_int *P, mp_int *Y, int mode);
51 static int mp_invmod_slow (const mp_int * a, mp_int * b, mp_int * c);
52 static int mp_karatsuba_mul(const mp_int *a, const mp_int *b, mp_int *c);
53 static int mp_karatsuba_sqr(const mp_int *a, mp_int *b);
54
55 /* grow as required */
56 static int mp_grow (mp_int * a, int size)
57 {
58   int     i;
59   mp_digit *tmp;
60
61   /* if the alloc size is smaller alloc more ram */
62   if (a->alloc < size) {
63     /* ensure there are always at least MP_PREC digits extra on top */
64     size += (MP_PREC * 2) - (size % MP_PREC);
65
66     /* reallocate the array a->dp
67      *
68      * We store the return in a temporary variable
69      * in case the operation failed we don't want
70      * to overwrite the dp member of a.
71      */
72     tmp = realloc (a->dp, sizeof (mp_digit) * size);
73     if (tmp == NULL) {
74       /* reallocation failed but "a" is still valid [can be freed] */
75       return MP_MEM;
76     }
77
78     /* reallocation succeeded so set a->dp */
79     a->dp = tmp;
80
81     /* zero excess digits */
82     i        = a->alloc;
83     a->alloc = size;
84     for (; i < a->alloc; i++) {
85       a->dp[i] = 0;
86     }
87   }
88   return MP_OKAY;
89 }
90
91 /* b = a/2 */
92 static int mp_div_2(const mp_int * a, mp_int * b)
93 {
94   int     x, res, oldused;
95
96   /* copy */
97   if (b->alloc < a->used) {
98     if ((res = mp_grow (b, a->used)) != MP_OKAY) {
99       return res;
100     }
101   }
102
103   oldused = b->used;
104   b->used = a->used;
105   {
106     register mp_digit r, rr, *tmpa, *tmpb;
107
108     /* source alias */
109     tmpa = a->dp + b->used - 1;
110
111     /* dest alias */
112     tmpb = b->dp + b->used - 1;
113
114     /* carry */
115     r = 0;
116     for (x = b->used - 1; x >= 0; x--) {
117       /* get the carry for the next iteration */
118       rr = *tmpa & 1;
119
120       /* shift the current digit, add in carry and store */
121       *tmpb-- = (*tmpa-- >> 1) | (r << (DIGIT_BIT - 1));
122
123       /* forward carry to next iteration */
124       r = rr;
125     }
126
127     /* zero excess digits */
128     tmpb = b->dp + b->used;
129     for (x = b->used; x < oldused; x++) {
130       *tmpb++ = 0;
131     }
132   }
133   b->sign = a->sign;
134   mp_clamp (b);
135   return MP_OKAY;
136 }
137
138 /* swap the elements of two integers, for cases where you can't simply swap the
139  * mp_int pointers around
140  */
141 static void
142 mp_exch (mp_int * a, mp_int * b)
143 {
144   mp_int  t;
145
146   t  = *a;
147   *a = *b;
148   *b = t;
149 }
150
151 /* init a new mp_int */
152 static int mp_init (mp_int * a)
153 {
154   int i;
155
156   /* allocate memory required and clear it */
157   a->dp = malloc (sizeof (mp_digit) * MP_PREC);
158   if (a->dp == NULL) {
159     return MP_MEM;
160   }
161
162   /* set the digits to zero */
163   for (i = 0; i < MP_PREC; i++) {
164       a->dp[i] = 0;
165   }
166
167   /* set the used to zero, allocated digits to the default precision
168    * and sign to positive */
169   a->used  = 0;
170   a->alloc = MP_PREC;
171   a->sign  = MP_ZPOS;
172
173   return MP_OKAY;
174 }
175
176 /* init an mp_init for a given size */
177 static int mp_init_size (mp_int * a, int size)
178 {
179   int x;
180
181   /* pad size so there are always extra digits */
182   size += (MP_PREC * 2) - (size % MP_PREC);
183
184   /* alloc mem */
185   a->dp = malloc (sizeof (mp_digit) * size);
186   if (a->dp == NULL) {
187     return MP_MEM;
188   }
189
190   /* set the members */
191   a->used  = 0;
192   a->alloc = size;
193   a->sign  = MP_ZPOS;
194
195   /* zero the digits */
196   for (x = 0; x < size; x++) {
197       a->dp[x] = 0;
198   }
199
200   return MP_OKAY;
201 }
202
203 /* clear one (frees)  */
204 static void
205 mp_clear (mp_int * a)
206 {
207   int i;
208
209   /* only do anything if a hasn't been freed previously */
210   if (a->dp != NULL) {
211     /* first zero the digits */
212     for (i = 0; i < a->used; i++) {
213         a->dp[i] = 0;
214     }
215
216     /* free ram */
217     free(a->dp);
218
219     /* reset members to make debugging easier */
220     a->dp    = NULL;
221     a->alloc = a->used = 0;
222     a->sign  = MP_ZPOS;
223   }
224 }
225
226 /* set to zero */
227 static void
228 mp_zero (mp_int * a)
229 {
230   a->sign = MP_ZPOS;
231   a->used = 0;
232   memset (a->dp, 0, sizeof (mp_digit) * a->alloc);
233 }
234
235 /* b = |a|
236  *
237  * Simple function copies the input and fixes the sign to positive
238  */
239 static int
240 mp_abs (const mp_int * a, mp_int * b)
241 {
242   int     res;
243
244   /* copy a to b */
245   if (a != b) {
246      if ((res = mp_copy (a, b)) != MP_OKAY) {
247        return res;
248      }
249   }
250
251   /* force the sign of b to positive */
252   b->sign = MP_ZPOS;
253
254   return MP_OKAY;
255 }
256
257 /* computes the modular inverse via binary extended euclidean algorithm, 
258  * that is c = 1/a mod b 
259  *
260  * Based on slow invmod except this is optimized for the case where b is 
261  * odd as per HAC Note 14.64 on pp. 610
262  */
263 static int
264 fast_mp_invmod (const mp_int * a, mp_int * b, mp_int * c)
265 {
266   mp_int  x, y, u, v, B, D;
267   int     res, neg;
268
269   /* 2. [modified] b must be odd   */
270   if (mp_iseven (b) == 1) {
271     return MP_VAL;
272   }
273
274   /* init all our temps */
275   if ((res = mp_init_multi(&x, &y, &u, &v, &B, &D, NULL)) != MP_OKAY) {
276      return res;
277   }
278
279   /* x == modulus, y == value to invert */
280   if ((res = mp_copy (b, &x)) != MP_OKAY) {
281     goto __ERR;
282   }
283
284   /* we need y = |a| */
285   if ((res = mp_abs (a, &y)) != MP_OKAY) {
286     goto __ERR;
287   }
288
289   /* 3. u=x, v=y, A=1, B=0, C=0,D=1 */
290   if ((res = mp_copy (&x, &u)) != MP_OKAY) {
291     goto __ERR;
292   }
293   if ((res = mp_copy (&y, &v)) != MP_OKAY) {
294     goto __ERR;
295   }
296   mp_set (&D, 1);
297
298 top:
299   /* 4.  while u is even do */
300   while (mp_iseven (&u) == 1) {
301     /* 4.1 u = u/2 */
302     if ((res = mp_div_2 (&u, &u)) != MP_OKAY) {
303       goto __ERR;
304     }
305     /* 4.2 if B is odd then */
306     if (mp_isodd (&B) == 1) {
307       if ((res = mp_sub (&B, &x, &B)) != MP_OKAY) {
308         goto __ERR;
309       }
310     }
311     /* B = B/2 */
312     if ((res = mp_div_2 (&B, &B)) != MP_OKAY) {
313       goto __ERR;
314     }
315   }
316
317   /* 5.  while v is even do */
318   while (mp_iseven (&v) == 1) {
319     /* 5.1 v = v/2 */
320     if ((res = mp_div_2 (&v, &v)) != MP_OKAY) {
321       goto __ERR;
322     }
323     /* 5.2 if D is odd then */
324     if (mp_isodd (&D) == 1) {
325       /* D = (D-x)/2 */
326       if ((res = mp_sub (&D, &x, &D)) != MP_OKAY) {
327         goto __ERR;
328       }
329     }
330     /* D = D/2 */
331     if ((res = mp_div_2 (&D, &D)) != MP_OKAY) {
332       goto __ERR;
333     }
334   }
335
336   /* 6.  if u >= v then */
337   if (mp_cmp (&u, &v) != MP_LT) {
338     /* u = u - v, B = B - D */
339     if ((res = mp_sub (&u, &v, &u)) != MP_OKAY) {
340       goto __ERR;
341     }
342
343     if ((res = mp_sub (&B, &D, &B)) != MP_OKAY) {
344       goto __ERR;
345     }
346   } else {
347     /* v - v - u, D = D - B */
348     if ((res = mp_sub (&v, &u, &v)) != MP_OKAY) {
349       goto __ERR;
350     }
351
352     if ((res = mp_sub (&D, &B, &D)) != MP_OKAY) {
353       goto __ERR;
354     }
355   }
356
357   /* if not zero goto step 4 */
358   if (mp_iszero (&u) == 0) {
359     goto top;
360   }
361
362   /* now a = C, b = D, gcd == g*v */
363
364   /* if v != 1 then there is no inverse */
365   if (mp_cmp_d (&v, 1) != MP_EQ) {
366     res = MP_VAL;
367     goto __ERR;
368   }
369
370   /* b is now the inverse */
371   neg = a->sign;
372   while (D.sign == MP_NEG) {
373     if ((res = mp_add (&D, b, &D)) != MP_OKAY) {
374       goto __ERR;
375     }
376   }
377   mp_exch (&D, c);
378   c->sign = neg;
379   res = MP_OKAY;
380
381 __ERR:mp_clear_multi (&x, &y, &u, &v, &B, &D, NULL);
382   return res;
383 }
384
385 /* computes xR**-1 == x (mod N) via Montgomery Reduction
386  *
387  * This is an optimized implementation of montgomery_reduce
388  * which uses the comba method to quickly calculate the columns of the
389  * reduction.
390  *
391  * Based on Algorithm 14.32 on pp.601 of HAC.
392 */
393 static int
394 fast_mp_montgomery_reduce (mp_int * x, const mp_int * n, mp_digit rho)
395 {
396   int     ix, res, olduse;
397   mp_word W[MP_WARRAY];
398
399   /* get old used count */
400   olduse = x->used;
401
402   /* grow a as required */
403   if (x->alloc < n->used + 1) {
404     if ((res = mp_grow (x, n->used + 1)) != MP_OKAY) {
405       return res;
406     }
407   }
408
409   /* first we have to get the digits of the input into
410    * an array of double precision words W[...]
411    */
412   {
413     register mp_word *_W;
414     register mp_digit *tmpx;
415
416     /* alias for the W[] array */
417     _W   = W;
418
419     /* alias for the digits of  x*/
420     tmpx = x->dp;
421
422     /* copy the digits of a into W[0..a->used-1] */
423     for (ix = 0; ix < x->used; ix++) {
424       *_W++ = *tmpx++;
425     }
426
427     /* zero the high words of W[a->used..m->used*2] */
428     for (; ix < n->used * 2 + 1; ix++) {
429       *_W++ = 0;
430     }
431   }
432
433   /* now we proceed to zero successive digits
434    * from the least significant upwards
435    */
436   for (ix = 0; ix < n->used; ix++) {
437     /* mu = ai * m' mod b
438      *
439      * We avoid a double precision multiplication (which isn't required)
440      * by casting the value down to a mp_digit.  Note this requires
441      * that W[ix-1] have  the carry cleared (see after the inner loop)
442      */
443     register mp_digit mu;
444     mu = (mp_digit) (((W[ix] & MP_MASK) * rho) & MP_MASK);
445
446     /* a = a + mu * m * b**i
447      *
448      * This is computed in place and on the fly.  The multiplication
449      * by b**i is handled by offsetting which columns the results
450      * are added to.
451      *
452      * Note the comba method normally doesn't handle carries in the
453      * inner loop In this case we fix the carry from the previous
454      * column since the Montgomery reduction requires digits of the
455      * result (so far) [see above] to work.  This is
456      * handled by fixing up one carry after the inner loop.  The
457      * carry fixups are done in order so after these loops the
458      * first m->used words of W[] have the carries fixed
459      */
460     {
461       register int iy;
462       register mp_digit *tmpn;
463       register mp_word *_W;
464
465       /* alias for the digits of the modulus */
466       tmpn = n->dp;
467
468       /* Alias for the columns set by an offset of ix */
469       _W = W + ix;
470
471       /* inner loop */
472       for (iy = 0; iy < n->used; iy++) {
473           *_W++ += ((mp_word)mu) * ((mp_word)*tmpn++);
474       }
475     }
476
477     /* now fix carry for next digit, W[ix+1] */
478     W[ix + 1] += W[ix] >> ((mp_word) DIGIT_BIT);
479   }
480
481   /* now we have to propagate the carries and
482    * shift the words downward [all those least
483    * significant digits we zeroed].
484    */
485   {
486     register mp_digit *tmpx;
487     register mp_word *_W, *_W1;
488
489     /* nox fix rest of carries */
490
491     /* alias for current word */
492     _W1 = W + ix;
493
494     /* alias for next word, where the carry goes */
495     _W = W + ++ix;
496
497     for (; ix <= n->used * 2 + 1; ix++) {
498       *_W++ += *_W1++ >> ((mp_word) DIGIT_BIT);
499     }
500
501     /* copy out, A = A/b**n
502      *
503      * The result is A/b**n but instead of converting from an
504      * array of mp_word to mp_digit than calling mp_rshd
505      * we just copy them in the right order
506      */
507
508     /* alias for destination word */
509     tmpx = x->dp;
510
511     /* alias for shifted double precision result */
512     _W = W + n->used;
513
514     for (ix = 0; ix < n->used + 1; ix++) {
515       *tmpx++ = (mp_digit)(*_W++ & ((mp_word) MP_MASK));
516     }
517
518     /* zero oldused digits, if the input a was larger than
519      * m->used+1 we'll have to clear the digits
520      */
521     for (; ix < olduse; ix++) {
522       *tmpx++ = 0;
523     }
524   }
525
526   /* set the max used and clamp */
527   x->used = n->used + 1;
528   mp_clamp (x);
529
530   /* if A >= m then A = A - m */
531   if (mp_cmp_mag (x, n) != MP_LT) {
532     return s_mp_sub (x, n, x);
533   }
534   return MP_OKAY;
535 }
536
537 /* Fast (comba) multiplier
538  *
539  * This is the fast column-array [comba] multiplier.  It is 
540  * designed to compute the columns of the product first 
541  * then handle the carries afterwards.  This has the effect 
542  * of making the nested loops that compute the columns very
543  * simple and schedulable on super-scalar processors.
544  *
545  * This has been modified to produce a variable number of 
546  * digits of output so if say only a half-product is required 
547  * you don't have to compute the upper half (a feature 
548  * required for fast Barrett reduction).
549  *
550  * Based on Algorithm 14.12 on pp.595 of HAC.
551  *
552  */
553 static int
554 fast_s_mp_mul_digs (const mp_int * a, const mp_int * b, mp_int * c, int digs)
555 {
556   int     olduse, res, pa, ix, iz;
557   mp_digit W[MP_WARRAY];
558   register mp_word  _W;
559
560   /* grow the destination as required */
561   if (c->alloc < digs) {
562     if ((res = mp_grow (c, digs)) != MP_OKAY) {
563       return res;
564     }
565   }
566
567   /* number of output digits to produce */
568   pa = MIN(digs, a->used + b->used);
569
570   /* clear the carry */
571   _W = 0;
572   for (ix = 0; ix <= pa; ix++) { 
573       int      tx, ty;
574       int      iy;
575       mp_digit *tmpx, *tmpy;
576
577       /* get offsets into the two bignums */
578       ty = MIN(b->used-1, ix);
579       tx = ix - ty;
580
581       /* setup temp aliases */
582       tmpx = a->dp + tx;
583       tmpy = b->dp + ty;
584
585       /* This is the number of times the loop will iterate, essentially it's
586          while (tx++ < a->used && ty-- >= 0) { ... }
587        */
588       iy = MIN(a->used-tx, ty+1);
589
590       /* execute loop */
591       for (iz = 0; iz < iy; ++iz) {
592          _W += ((mp_word)*tmpx++)*((mp_word)*tmpy--);
593       }
594
595       /* store term */
596       W[ix] = ((mp_digit)_W) & MP_MASK;
597
598       /* make next carry */
599       _W = _W >> ((mp_word)DIGIT_BIT);
600   }
601
602   /* setup dest */
603   olduse  = c->used;
604   c->used = digs;
605
606   {
607     register mp_digit *tmpc;
608     tmpc = c->dp;
609     for (ix = 0; ix < digs; ix++) {
610       /* now extract the previous digit [below the carry] */
611       *tmpc++ = W[ix];
612     }
613
614     /* clear unused digits [that existed in the old copy of c] */
615     for (; ix < olduse; ix++) {
616       *tmpc++ = 0;
617     }
618   }
619   mp_clamp (c);
620   return MP_OKAY;
621 }
622
623 /* this is a modified version of fast_s_mul_digs that only produces
624  * output digits *above* digs.  See the comments for fast_s_mul_digs
625  * to see how it works.
626  *
627  * This is used in the Barrett reduction since for one of the multiplications
628  * only the higher digits were needed.  This essentially halves the work.
629  *
630  * Based on Algorithm 14.12 on pp.595 of HAC.
631  */
632 static int
633 fast_s_mp_mul_high_digs (const mp_int * a, const mp_int * b, mp_int * c, int digs)
634 {
635   int     olduse, res, pa, ix, iz;
636   mp_digit W[MP_WARRAY];
637   mp_word  _W;
638
639   /* grow the destination as required */
640   pa = a->used + b->used;
641   if (c->alloc < pa) {
642     if ((res = mp_grow (c, pa)) != MP_OKAY) {
643       return res;
644     }
645   }
646
647   /* number of output digits to produce */
648   pa = a->used + b->used;
649   _W = 0;
650   for (ix = digs; ix <= pa; ix++) { 
651       int      tx, ty, iy;
652       mp_digit *tmpx, *tmpy;
653
654       /* get offsets into the two bignums */
655       ty = MIN(b->used-1, ix);
656       tx = ix - ty;
657
658       /* setup temp aliases */
659       tmpx = a->dp + tx;
660       tmpy = b->dp + ty;
661
662       /* This is the number of times the loop will iterate, essentially it's
663          while (tx++ < a->used && ty-- >= 0) { ... }
664        */
665       iy = MIN(a->used-tx, ty+1);
666
667       /* execute loop */
668       for (iz = 0; iz < iy; iz++) {
669          _W += ((mp_word)*tmpx++)*((mp_word)*tmpy--);
670       }
671
672       /* store term */
673       W[ix] = ((mp_digit)_W) & MP_MASK;
674
675       /* make next carry */
676       _W = _W >> ((mp_word)DIGIT_BIT);
677   }
678
679   /* setup dest */
680   olduse  = c->used;
681   c->used = pa;
682
683   {
684     register mp_digit *tmpc;
685
686     tmpc = c->dp + digs;
687     for (ix = digs; ix <= pa; ix++) {
688       /* now extract the previous digit [below the carry] */
689       *tmpc++ = W[ix];
690     }
691
692     /* clear unused digits [that existed in the old copy of c] */
693     for (; ix < olduse; ix++) {
694       *tmpc++ = 0;
695     }
696   }
697   mp_clamp (c);
698   return MP_OKAY;
699 }
700
701 /* fast squaring
702  *
703  * This is the comba method where the columns of the product
704  * are computed first then the carries are computed.  This
705  * has the effect of making a very simple inner loop that
706  * is executed the most
707  *
708  * W2 represents the outer products and W the inner.
709  *
710  * A further optimizations is made because the inner
711  * products are of the form "A * B * 2".  The *2 part does
712  * not need to be computed until the end which is good
713  * because 64-bit shifts are slow!
714  *
715  * Based on Algorithm 14.16 on pp.597 of HAC.
716  *
717  */
718 /* the jist of squaring...
719
720 you do like mult except the offset of the tmpx [one that starts closer to zero]
721 can't equal the offset of tmpy.  So basically you set up iy like before then you min it with
722 (ty-tx) so that it never happens.  You double all those you add in the inner loop
723
724 After that loop you do the squares and add them in.
725
726 Remove W2 and don't memset W
727
728 */
729
730 static int fast_s_mp_sqr (const mp_int * a, mp_int * b)
731 {
732   int       olduse, res, pa, ix, iz;
733   mp_digit   W[MP_WARRAY], *tmpx;
734   mp_word   W1;
735
736   /* grow the destination as required */
737   pa = a->used + a->used;
738   if (b->alloc < pa) {
739     if ((res = mp_grow (b, pa)) != MP_OKAY) {
740       return res;
741     }
742   }
743
744   /* number of output digits to produce */
745   W1 = 0;
746   for (ix = 0; ix <= pa; ix++) { 
747       int      tx, ty, iy;
748       mp_word  _W;
749       mp_digit *tmpy;
750
751       /* clear counter */
752       _W = 0;
753
754       /* get offsets into the two bignums */
755       ty = MIN(a->used-1, ix);
756       tx = ix - ty;
757
758       /* setup temp aliases */
759       tmpx = a->dp + tx;
760       tmpy = a->dp + ty;
761
762       /* This is the number of times the loop will iterate, essentially it's
763          while (tx++ < a->used && ty-- >= 0) { ... }
764        */
765       iy = MIN(a->used-tx, ty+1);
766
767       /* now for squaring tx can never equal ty 
768        * we halve the distance since they approach at a rate of 2x
769        * and we have to round because odd cases need to be executed
770        */
771       iy = MIN(iy, (ty-tx+1)>>1);
772
773       /* execute loop */
774       for (iz = 0; iz < iy; iz++) {
775          _W += ((mp_word)*tmpx++)*((mp_word)*tmpy--);
776       }
777
778       /* double the inner product and add carry */
779       _W = _W + _W + W1;
780
781       /* even columns have the square term in them */
782       if ((ix&1) == 0) {
783          _W += ((mp_word)a->dp[ix>>1])*((mp_word)a->dp[ix>>1]);
784       }
785
786       /* store it */
787       W[ix] = _W;
788
789       /* make next carry */
790       W1 = _W >> ((mp_word)DIGIT_BIT);
791   }
792
793   /* setup dest */
794   olduse  = b->used;
795   b->used = a->used+a->used;
796
797   {
798     mp_digit *tmpb;
799     tmpb = b->dp;
800     for (ix = 0; ix < pa; ix++) {
801       *tmpb++ = W[ix] & MP_MASK;
802     }
803
804     /* clear unused digits [that existed in the old copy of c] */
805     for (; ix < olduse; ix++) {
806       *tmpb++ = 0;
807     }
808   }
809   mp_clamp (b);
810   return MP_OKAY;
811 }
812
813 /* computes a = 2**b 
814  *
815  * Simple algorithm which zeroes the int, grows it then just sets one bit
816  * as required.
817  */
818 static int
819 mp_2expt (mp_int * a, int b)
820 {
821   int     res;
822
823   /* zero a as per default */
824   mp_zero (a);
825
826   /* grow a to accommodate the single bit */
827   if ((res = mp_grow (a, b / DIGIT_BIT + 1)) != MP_OKAY) {
828     return res;
829   }
830
831   /* set the used count of where the bit will go */
832   a->used = b / DIGIT_BIT + 1;
833
834   /* put the single bit in its place */
835   a->dp[b / DIGIT_BIT] = ((mp_digit)1) << (b % DIGIT_BIT);
836
837   return MP_OKAY;
838 }
839
840 /* high level addition (handles signs) */
841 int mp_add (mp_int * a, mp_int * b, mp_int * c)
842 {
843   int     sa, sb, res;
844
845   /* get sign of both inputs */
846   sa = a->sign;
847   sb = b->sign;
848
849   /* handle two cases, not four */
850   if (sa == sb) {
851     /* both positive or both negative */
852     /* add their magnitudes, copy the sign */
853     c->sign = sa;
854     res = s_mp_add (a, b, c);
855   } else {
856     /* one positive, the other negative */
857     /* subtract the one with the greater magnitude from */
858     /* the one of the lesser magnitude.  The result gets */
859     /* the sign of the one with the greater magnitude. */
860     if (mp_cmp_mag (a, b) == MP_LT) {
861       c->sign = sb;
862       res = s_mp_sub (b, a, c);
863     } else {
864       c->sign = sa;
865       res = s_mp_sub (a, b, c);
866     }
867   }
868   return res;
869 }
870
871
872 /* single digit addition */
873 static int
874 mp_add_d (mp_int * a, mp_digit b, mp_int * c)
875 {
876   int     res, ix, oldused;
877   mp_digit *tmpa, *tmpc, mu;
878
879   /* grow c as required */
880   if (c->alloc < a->used + 1) {
881      if ((res = mp_grow(c, a->used + 1)) != MP_OKAY) {
882         return res;
883      }
884   }
885
886   /* if a is negative and |a| >= b, call c = |a| - b */
887   if (a->sign == MP_NEG && (a->used > 1 || a->dp[0] >= b)) {
888      /* temporarily fix sign of a */
889      a->sign = MP_ZPOS;
890
891      /* c = |a| - b */
892      res = mp_sub_d(a, b, c);
893
894      /* fix sign  */
895      a->sign = c->sign = MP_NEG;
896
897      return res;
898   }
899
900   /* old number of used digits in c */
901   oldused = c->used;
902
903   /* sign always positive */
904   c->sign = MP_ZPOS;
905
906   /* source alias */
907   tmpa    = a->dp;
908
909   /* destination alias */
910   tmpc    = c->dp;
911
912   /* if a is positive */
913   if (a->sign == MP_ZPOS) {
914      /* add digit, after this we're propagating
915       * the carry.
916       */
917      *tmpc   = *tmpa++ + b;
918      mu      = *tmpc >> DIGIT_BIT;
919      *tmpc++ &= MP_MASK;
920
921      /* now handle rest of the digits */
922      for (ix = 1; ix < a->used; ix++) {
923         *tmpc   = *tmpa++ + mu;
924         mu      = *tmpc >> DIGIT_BIT;
925         *tmpc++ &= MP_MASK;
926      }
927      /* set final carry */
928      ix++;
929      *tmpc++  = mu;
930
931      /* setup size */
932      c->used = a->used + 1;
933   } else {
934      /* a was negative and |a| < b */
935      c->used  = 1;
936
937      /* the result is a single digit */
938      if (a->used == 1) {
939         *tmpc++  =  b - a->dp[0];
940      } else {
941         *tmpc++  =  b;
942      }
943
944      /* setup count so the clearing of oldused
945       * can fall through correctly
946       */
947      ix       = 1;
948   }
949
950   /* now zero to oldused */
951   while (ix++ < oldused) {
952      *tmpc++ = 0;
953   }
954   mp_clamp(c);
955
956   return MP_OKAY;
957 }
958
959 /* trim unused digits 
960  *
961  * This is used to ensure that leading zero digits are
962  * trimed and the leading "used" digit will be non-zero
963  * Typically very fast.  Also fixes the sign if there
964  * are no more leading digits
965  */
966 void
967 mp_clamp (mp_int * a)
968 {
969   /* decrease used while the most significant digit is
970    * zero.
971    */
972   while (a->used > 0 && a->dp[a->used - 1] == 0) {
973     --(a->used);
974   }
975
976   /* reset the sign flag if used == 0 */
977   if (a->used == 0) {
978     a->sign = MP_ZPOS;
979   }
980 }
981
982 void mp_clear_multi(mp_int *mp, ...) 
983 {
984     mp_int* next_mp = mp;
985     va_list args;
986     va_start(args, mp);
987     while (next_mp != NULL) {
988         mp_clear(next_mp);
989         next_mp = va_arg(args, mp_int*);
990     }
991     va_end(args);
992 }
993
994 /* compare two ints (signed)*/
995 int
996 mp_cmp (const mp_int * a, const mp_int * b)
997 {
998   /* compare based on sign */
999   if (a->sign != b->sign) {
1000      if (a->sign == MP_NEG) {
1001         return MP_LT;
1002      } else {
1003         return MP_GT;
1004      }
1005   }
1006   
1007   /* compare digits */
1008   if (a->sign == MP_NEG) {
1009      /* if negative compare opposite direction */
1010      return mp_cmp_mag(b, a);
1011   } else {
1012      return mp_cmp_mag(a, b);
1013   }
1014 }
1015
1016 /* compare a digit */
1017 int mp_cmp_d(const mp_int * a, mp_digit b)
1018 {
1019   /* compare based on sign */
1020   if (a->sign == MP_NEG) {
1021     return MP_LT;
1022   }
1023
1024   /* compare based on magnitude */
1025   if (a->used > 1) {
1026     return MP_GT;
1027   }
1028
1029   /* compare the only digit of a to b */
1030   if (a->dp[0] > b) {
1031     return MP_GT;
1032   } else if (a->dp[0] < b) {
1033     return MP_LT;
1034   } else {
1035     return MP_EQ;
1036   }
1037 }
1038
1039 /* compare maginitude of two ints (unsigned) */
1040 int mp_cmp_mag (const mp_int * a, const mp_int * b)
1041 {
1042   int     n;
1043   mp_digit *tmpa, *tmpb;
1044
1045   /* compare based on # of non-zero digits */
1046   if (a->used > b->used) {
1047     return MP_GT;
1048   }
1049   
1050   if (a->used < b->used) {
1051     return MP_LT;
1052   }
1053
1054   /* alias for a */
1055   tmpa = a->dp + (a->used - 1);
1056
1057   /* alias for b */
1058   tmpb = b->dp + (a->used - 1);
1059
1060   /* compare based on digits  */
1061   for (n = 0; n < a->used; ++n, --tmpa, --tmpb) {
1062     if (*tmpa > *tmpb) {
1063       return MP_GT;
1064     }
1065
1066     if (*tmpa < *tmpb) {
1067       return MP_LT;
1068     }
1069   }
1070   return MP_EQ;
1071 }
1072
1073 static const int lnz[16] = { 
1074    4, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0
1075 };
1076
1077 /* Counts the number of lsbs which are zero before the first zero bit */
1078 int mp_cnt_lsb(const mp_int *a)
1079 {
1080    int x;
1081    mp_digit q, qq;
1082
1083    /* easy out */
1084    if (mp_iszero(a) == 1) {
1085       return 0;
1086    }
1087
1088    /* scan lower digits until non-zero */
1089    for (x = 0; x < a->used && a->dp[x] == 0; x++);
1090    q = a->dp[x];
1091    x *= DIGIT_BIT;
1092
1093    /* now scan this digit until a 1 is found */
1094    if ((q & 1) == 0) {
1095       do {
1096          qq  = q & 15;
1097          x  += lnz[qq];
1098          q >>= 4;
1099       } while (qq == 0);
1100    }
1101    return x;
1102 }
1103
1104 /* copy, b = a */
1105 int
1106 mp_copy (const mp_int * a, mp_int * b)
1107 {
1108   int     res, n;
1109
1110   /* if dst == src do nothing */
1111   if (a == b) {
1112     return MP_OKAY;
1113   }
1114
1115   /* grow dest */
1116   if (b->alloc < a->used) {
1117      if ((res = mp_grow (b, a->used)) != MP_OKAY) {
1118         return res;
1119      }
1120   }
1121
1122   /* zero b and copy the parameters over */
1123   {
1124     register mp_digit *tmpa, *tmpb;
1125
1126     /* pointer aliases */
1127
1128     /* source */
1129     tmpa = a->dp;
1130
1131     /* destination */
1132     tmpb = b->dp;
1133
1134     /* copy all the digits */
1135     for (n = 0; n < a->used; n++) {
1136       *tmpb++ = *tmpa++;
1137     }
1138
1139     /* clear high digits */
1140     for (; n < b->used; n++) {
1141       *tmpb++ = 0;
1142     }
1143   }
1144
1145   /* copy used count and sign */
1146   b->used = a->used;
1147   b->sign = a->sign;
1148   return MP_OKAY;
1149 }
1150
1151 /* returns the number of bits in an int */
1152 int
1153 mp_count_bits (const mp_int * a)
1154 {
1155   int     r;
1156   mp_digit q;
1157
1158   /* shortcut */
1159   if (a->used == 0) {
1160     return 0;
1161   }
1162
1163   /* get number of digits and add that */
1164   r = (a->used - 1) * DIGIT_BIT;
1165   
1166   /* take the last digit and count the bits in it */
1167   q = a->dp[a->used - 1];
1168   while (q > 0) {
1169     ++r;
1170     q >>= ((mp_digit) 1);
1171   }
1172   return r;
1173 }
1174
1175 /* calc a value mod 2**b */
1176 static int
1177 mp_mod_2d (const mp_int * a, int b, mp_int * c)
1178 {
1179   int     x, res;
1180
1181   /* if b is <= 0 then zero the int */
1182   if (b <= 0) {
1183     mp_zero (c);
1184     return MP_OKAY;
1185   }
1186
1187   /* if the modulus is larger than the value than return */
1188   if (b > a->used * DIGIT_BIT) {
1189     res = mp_copy (a, c);
1190     return res;
1191   }
1192
1193   /* copy */
1194   if ((res = mp_copy (a, c)) != MP_OKAY) {
1195     return res;
1196   }
1197
1198   /* zero digits above the last digit of the modulus */
1199   for (x = (b / DIGIT_BIT) + ((b % DIGIT_BIT) == 0 ? 0 : 1); x < c->used; x++) {
1200     c->dp[x] = 0;
1201   }
1202   /* clear the digit that is not completely outside/inside the modulus */
1203   c->dp[b / DIGIT_BIT] &= (1 << ((mp_digit)b % DIGIT_BIT)) - 1;
1204   mp_clamp (c);
1205   return MP_OKAY;
1206 }
1207
1208 /* shift right a certain amount of digits */
1209 static void mp_rshd (mp_int * a, int b)
1210 {
1211   int     x;
1212
1213   /* if b <= 0 then ignore it */
1214   if (b <= 0) {
1215     return;
1216   }
1217
1218   /* if b > used then simply zero it and return */
1219   if (a->used <= b) {
1220     mp_zero (a);
1221     return;
1222   }
1223
1224   {
1225     register mp_digit *bottom, *top;
1226
1227     /* shift the digits down */
1228
1229     /* bottom */
1230     bottom = a->dp;
1231
1232     /* top [offset into digits] */
1233     top = a->dp + b;
1234
1235     /* this is implemented as a sliding window where
1236      * the window is b-digits long and digits from
1237      * the top of the window are copied to the bottom
1238      *
1239      * e.g.
1240
1241      b-2 | b-1 | b0 | b1 | b2 | ... | bb |   ---->
1242                  /\                   |      ---->
1243                   \-------------------/      ---->
1244      */
1245     for (x = 0; x < (a->used - b); x++) {
1246       *bottom++ = *top++;
1247     }
1248
1249     /* zero the top digits */
1250     for (; x < a->used; x++) {
1251       *bottom++ = 0;
1252     }
1253   }
1254
1255   /* remove excess digits */
1256   a->used -= b;
1257 }
1258
1259 /* shift right by a certain bit count (store quotient in c, optional remainder in d) */
1260 static int mp_div_2d (const mp_int * a, int b, mp_int * c, mp_int * d)
1261 {
1262   mp_digit D, r, rr;
1263   int     x, res;
1264   mp_int  t;
1265
1266
1267   /* if the shift count is <= 0 then we do no work */
1268   if (b <= 0) {
1269     res = mp_copy (a, c);
1270     if (d != NULL) {
1271       mp_zero (d);
1272     }
1273     return res;
1274   }
1275
1276   if ((res = mp_init (&t)) != MP_OKAY) {
1277     return res;
1278   }
1279
1280   /* get the remainder */
1281   if (d != NULL) {
1282     if ((res = mp_mod_2d (a, b, &t)) != MP_OKAY) {
1283       mp_clear (&t);
1284       return res;
1285     }
1286   }
1287
1288   /* copy */
1289   if ((res = mp_copy (a, c)) != MP_OKAY) {
1290     mp_clear (&t);
1291     return res;
1292   }
1293
1294   /* shift by as many digits in the bit count */
1295   if (b >= DIGIT_BIT) {
1296     mp_rshd (c, b / DIGIT_BIT);
1297   }
1298
1299   /* shift any bit count < DIGIT_BIT */
1300   D = (mp_digit) (b % DIGIT_BIT);
1301   if (D != 0) {
1302     register mp_digit *tmpc, mask, shift;
1303
1304     /* mask */
1305     mask = (((mp_digit)1) << D) - 1;
1306
1307     /* shift for lsb */
1308     shift = DIGIT_BIT - D;
1309
1310     /* alias */
1311     tmpc = c->dp + (c->used - 1);
1312
1313     /* carry */
1314     r = 0;
1315     for (x = c->used - 1; x >= 0; x--) {
1316       /* get the lower  bits of this word in a temp */
1317       rr = *tmpc & mask;
1318
1319       /* shift the current word and mix in the carry bits from the previous word */
1320       *tmpc = (*tmpc >> D) | (r << shift);
1321       --tmpc;
1322
1323       /* set the carry to the carry bits of the current word found above */
1324       r = rr;
1325     }
1326   }
1327   mp_clamp (c);
1328   if (d != NULL) {
1329     mp_exch (&t, d);
1330   }
1331   mp_clear (&t);
1332   return MP_OKAY;
1333 }
1334
1335 /* shift left a certain amount of digits */
1336 static int mp_lshd (mp_int * a, int b)
1337 {
1338   int     x, res;
1339
1340   /* if its less than zero return */
1341   if (b <= 0) {
1342     return MP_OKAY;
1343   }
1344
1345   /* grow to fit the new digits */
1346   if (a->alloc < a->used + b) {
1347      if ((res = mp_grow (a, a->used + b)) != MP_OKAY) {
1348        return res;
1349      }
1350   }
1351
1352   {
1353     register mp_digit *top, *bottom;
1354
1355     /* increment the used by the shift amount then copy upwards */
1356     a->used += b;
1357
1358     /* top */
1359     top = a->dp + a->used - 1;
1360
1361     /* base */
1362     bottom = a->dp + a->used - 1 - b;
1363
1364     /* much like mp_rshd this is implemented using a sliding window
1365      * except the window goes the otherway around.  Copying from
1366      * the bottom to the top.  see bn_mp_rshd.c for more info.
1367      */
1368     for (x = a->used - 1; x >= b; x--) {
1369       *top-- = *bottom--;
1370     }
1371
1372     /* zero the lower digits */
1373     top = a->dp;
1374     for (x = 0; x < b; x++) {
1375       *top++ = 0;
1376     }
1377   }
1378   return MP_OKAY;
1379 }
1380
1381 /* shift left by a certain bit count */
1382 static int mp_mul_2d (const mp_int * a, int b, mp_int * c)
1383 {
1384   mp_digit d;
1385   int      res;
1386
1387   /* copy */
1388   if (a != c) {
1389      if ((res = mp_copy (a, c)) != MP_OKAY) {
1390        return res;
1391      }
1392   }
1393
1394   if (c->alloc < c->used + b/DIGIT_BIT + 1) {
1395      if ((res = mp_grow (c, c->used + b / DIGIT_BIT + 1)) != MP_OKAY) {
1396        return res;
1397      }
1398   }
1399
1400   /* shift by as many digits in the bit count */
1401   if (b >= DIGIT_BIT) {
1402     if ((res = mp_lshd (c, b / DIGIT_BIT)) != MP_OKAY) {
1403       return res;
1404     }
1405   }
1406
1407   /* shift any bit count < DIGIT_BIT */
1408   d = (mp_digit) (b % DIGIT_BIT);
1409   if (d != 0) {
1410     register mp_digit *tmpc, shift, mask, r, rr;
1411     register int x;
1412
1413     /* bitmask for carries */
1414     mask = (((mp_digit)1) << d) - 1;
1415
1416     /* shift for msbs */
1417     shift = DIGIT_BIT - d;
1418
1419     /* alias */
1420     tmpc = c->dp;
1421
1422     /* carry */
1423     r    = 0;
1424     for (x = 0; x < c->used; x++) {
1425       /* get the higher bits of the current word */
1426       rr = (*tmpc >> shift) & mask;
1427
1428       /* shift the current word and OR in the carry */
1429       *tmpc = ((*tmpc << d) | r) & MP_MASK;
1430       ++tmpc;
1431
1432       /* set the carry to the carry bits of the current word */
1433       r = rr;
1434     }
1435
1436     /* set final carry */
1437     if (r != 0) {
1438        c->dp[(c->used)++] = r;
1439     }
1440   }
1441   mp_clamp (c);
1442   return MP_OKAY;
1443 }
1444
1445 /* multiply by a digit */
1446 static int
1447 mp_mul_d (const mp_int * a, mp_digit b, mp_int * c)
1448 {
1449   mp_digit u, *tmpa, *tmpc;
1450   mp_word  r;
1451   int      ix, res, olduse;
1452
1453   /* make sure c is big enough to hold a*b */
1454   if (c->alloc < a->used + 1) {
1455     if ((res = mp_grow (c, a->used + 1)) != MP_OKAY) {
1456       return res;
1457     }
1458   }
1459
1460   /* get the original destinations used count */
1461   olduse = c->used;
1462
1463   /* set the sign */
1464   c->sign = a->sign;
1465
1466   /* alias for a->dp [source] */
1467   tmpa = a->dp;
1468
1469   /* alias for c->dp [dest] */
1470   tmpc = c->dp;
1471
1472   /* zero carry */
1473   u = 0;
1474
1475   /* compute columns */
1476   for (ix = 0; ix < a->used; ix++) {
1477     /* compute product and carry sum for this term */
1478     r       = ((mp_word) u) + ((mp_word)*tmpa++) * ((mp_word)b);
1479
1480     /* mask off higher bits to get a single digit */
1481     *tmpc++ = (mp_digit) (r & ((mp_word) MP_MASK));
1482
1483     /* send carry into next iteration */
1484     u       = (mp_digit) (r >> ((mp_word) DIGIT_BIT));
1485   }
1486
1487   /* store final carry [if any] */
1488   *tmpc++ = u;
1489
1490   /* now zero digits above the top */
1491   while (ix++ < olduse) {
1492      *tmpc++ = 0;
1493   }
1494
1495   /* set used count */
1496   c->used = a->used + 1;
1497   mp_clamp(c);
1498
1499   return MP_OKAY;
1500 }
1501
1502 /* integer signed division. 
1503  * c*b + d == a [e.g. a/b, c=quotient, d=remainder]
1504  * HAC pp.598 Algorithm 14.20
1505  *
1506  * Note that the description in HAC is horribly 
1507  * incomplete.  For example, it doesn't consider 
1508  * the case where digits are removed from 'x' in 
1509  * the inner loop.  It also doesn't consider the 
1510  * case that y has fewer than three digits, etc..
1511  *
1512  * The overall algorithm is as described as 
1513  * 14.20 from HAC but fixed to treat these cases.
1514 */
1515 static int mp_div (const mp_int * a, const mp_int * b, mp_int * c, mp_int * d)
1516 {
1517   mp_int  q, x, y, t1, t2;
1518   int     res, n, t, i, norm, neg;
1519
1520   /* is divisor zero ? */
1521   if (mp_iszero (b) == 1) {
1522     return MP_VAL;
1523   }
1524
1525   /* if a < b then q=0, r = a */
1526   if (mp_cmp_mag (a, b) == MP_LT) {
1527     if (d != NULL) {
1528       res = mp_copy (a, d);
1529     } else {
1530       res = MP_OKAY;
1531     }
1532     if (c != NULL) {
1533       mp_zero (c);
1534     }
1535     return res;
1536   }
1537
1538   if ((res = mp_init_size (&q, a->used + 2)) != MP_OKAY) {
1539     return res;
1540   }
1541   q.used = a->used + 2;
1542
1543   if ((res = mp_init (&t1)) != MP_OKAY) {
1544     goto __Q;
1545   }
1546
1547   if ((res = mp_init (&t2)) != MP_OKAY) {
1548     goto __T1;
1549   }
1550
1551   if ((res = mp_init_copy (&x, a)) != MP_OKAY) {
1552     goto __T2;
1553   }
1554
1555   if ((res = mp_init_copy (&y, b)) != MP_OKAY) {
1556     goto __X;
1557   }
1558
1559   /* fix the sign */
1560   neg = (a->sign == b->sign) ? MP_ZPOS : MP_NEG;
1561   x.sign = y.sign = MP_ZPOS;
1562
1563   /* normalize both x and y, ensure that y >= b/2, [b == 2**DIGIT_BIT] */
1564   norm = mp_count_bits(&y) % DIGIT_BIT;
1565   if (norm < DIGIT_BIT-1) {
1566      norm = (DIGIT_BIT-1) - norm;
1567      if ((res = mp_mul_2d (&x, norm, &x)) != MP_OKAY) {
1568        goto __Y;
1569      }
1570      if ((res = mp_mul_2d (&y, norm, &y)) != MP_OKAY) {
1571        goto __Y;
1572      }
1573   } else {
1574      norm = 0;
1575   }
1576
1577   /* note hac does 0 based, so if used==5 then its 0,1,2,3,4, e.g. use 4 */
1578   n = x.used - 1;
1579   t = y.used - 1;
1580
1581   /* while (x >= y*b**n-t) do { q[n-t] += 1; x -= y*b**{n-t} } */
1582   if ((res = mp_lshd (&y, n - t)) != MP_OKAY) { /* y = y*b**{n-t} */
1583     goto __Y;
1584   }
1585
1586   while (mp_cmp (&x, &y) != MP_LT) {
1587     ++(q.dp[n - t]);
1588     if ((res = mp_sub (&x, &y, &x)) != MP_OKAY) {
1589       goto __Y;
1590     }
1591   }
1592
1593   /* reset y by shifting it back down */
1594   mp_rshd (&y, n - t);
1595
1596   /* step 3. for i from n down to (t + 1) */
1597   for (i = n; i >= (t + 1); i--) {
1598     if (i > x.used) {
1599       continue;
1600     }
1601
1602     /* step 3.1 if xi == yt then set q{i-t-1} to b-1, 
1603      * otherwise set q{i-t-1} to (xi*b + x{i-1})/yt */
1604     if (x.dp[i] == y.dp[t]) {
1605       q.dp[i - t - 1] = ((((mp_digit)1) << DIGIT_BIT) - 1);
1606     } else {
1607       mp_word tmp;
1608       tmp = ((mp_word) x.dp[i]) << ((mp_word) DIGIT_BIT);
1609       tmp |= ((mp_word) x.dp[i - 1]);
1610       tmp /= ((mp_word) y.dp[t]);
1611       if (tmp > (mp_word) MP_MASK)
1612         tmp = MP_MASK;
1613       q.dp[i - t - 1] = (mp_digit) (tmp & (mp_word) (MP_MASK));
1614     }
1615
1616     /* while (q{i-t-1} * (yt * b + y{t-1})) > 
1617              xi * b**2 + xi-1 * b + xi-2 
1618      
1619        do q{i-t-1} -= 1; 
1620     */
1621     q.dp[i - t - 1] = (q.dp[i - t - 1] + 1) & MP_MASK;
1622     do {
1623       q.dp[i - t - 1] = (q.dp[i - t - 1] - 1) & MP_MASK;
1624
1625       /* find left hand */
1626       mp_zero (&t1);
1627       t1.dp[0] = (t - 1 < 0) ? 0 : y.dp[t - 1];
1628       t1.dp[1] = y.dp[t];
1629       t1.used = 2;
1630       if ((res = mp_mul_d (&t1, q.dp[i - t - 1], &t1)) != MP_OKAY) {
1631         goto __Y;
1632       }
1633
1634       /* find right hand */
1635       t2.dp[0] = (i - 2 < 0) ? 0 : x.dp[i - 2];
1636       t2.dp[1] = (i - 1 < 0) ? 0 : x.dp[i - 1];
1637       t2.dp[2] = x.dp[i];
1638       t2.used = 3;
1639     } while (mp_cmp_mag(&t1, &t2) == MP_GT);
1640
1641     /* step 3.3 x = x - q{i-t-1} * y * b**{i-t-1} */
1642     if ((res = mp_mul_d (&y, q.dp[i - t - 1], &t1)) != MP_OKAY) {
1643       goto __Y;
1644     }
1645
1646     if ((res = mp_lshd (&t1, i - t - 1)) != MP_OKAY) {
1647       goto __Y;
1648     }
1649
1650     if ((res = mp_sub (&x, &t1, &x)) != MP_OKAY) {
1651       goto __Y;
1652     }
1653
1654     /* if x < 0 then { x = x + y*b**{i-t-1}; q{i-t-1} -= 1; } */
1655     if (x.sign == MP_NEG) {
1656       if ((res = mp_copy (&y, &t1)) != MP_OKAY) {
1657         goto __Y;
1658       }
1659       if ((res = mp_lshd (&t1, i - t - 1)) != MP_OKAY) {
1660         goto __Y;
1661       }
1662       if ((res = mp_add (&x, &t1, &x)) != MP_OKAY) {
1663         goto __Y;
1664       }
1665
1666       q.dp[i - t - 1] = (q.dp[i - t - 1] - 1UL) & MP_MASK;
1667     }
1668   }
1669
1670   /* now q is the quotient and x is the remainder 
1671    * [which we have to normalize] 
1672    */
1673   
1674   /* get sign before writing to c */
1675   x.sign = x.used == 0 ? MP_ZPOS : a->sign;
1676
1677   if (c != NULL) {
1678     mp_clamp (&q);
1679     mp_exch (&q, c);
1680     c->sign = neg;
1681   }
1682
1683   if (d != NULL) {
1684     mp_div_2d (&x, norm, &x, NULL);
1685     mp_exch (&x, d);
1686   }
1687
1688   res = MP_OKAY;
1689
1690 __Y:mp_clear (&y);
1691 __X:mp_clear (&x);
1692 __T2:mp_clear (&t2);
1693 __T1:mp_clear (&t1);
1694 __Q:mp_clear (&q);
1695   return res;
1696 }
1697
1698 static int s_is_power_of_two(mp_digit b, int *p)
1699 {
1700    int x;
1701
1702    for (x = 1; x < DIGIT_BIT; x++) {
1703       if (b == (((mp_digit)1)<<x)) {
1704          *p = x;
1705          return 1;
1706       }
1707    }
1708    return 0;
1709 }
1710
1711 /* single digit division (based on routine from MPI) */
1712 static int mp_div_d (const mp_int * a, mp_digit b, mp_int * c, mp_digit * d)
1713 {
1714   mp_int  q;
1715   mp_word w;
1716   mp_digit t;
1717   int     res, ix;
1718
1719   /* cannot divide by zero */
1720   if (b == 0) {
1721      return MP_VAL;
1722   }
1723
1724   /* quick outs */
1725   if (b == 1 || mp_iszero(a) == 1) {
1726      if (d != NULL) {
1727         *d = 0;
1728      }
1729      if (c != NULL) {
1730         return mp_copy(a, c);
1731      }
1732      return MP_OKAY;
1733   }
1734
1735   /* power of two ? */
1736   if (s_is_power_of_two(b, &ix) == 1) {
1737      if (d != NULL) {
1738         *d = a->dp[0] & ((((mp_digit)1)<<ix) - 1);
1739      }
1740      if (c != NULL) {
1741         return mp_div_2d(a, ix, c, NULL);
1742      }
1743      return MP_OKAY;
1744   }
1745
1746   /* no easy answer [c'est la vie].  Just division */
1747   if ((res = mp_init_size(&q, a->used)) != MP_OKAY) {
1748      return res;
1749   }
1750   
1751   q.used = a->used;
1752   q.sign = a->sign;
1753   w = 0;
1754   for (ix = a->used - 1; ix >= 0; ix--) {
1755      w = (w << ((mp_word)DIGIT_BIT)) | ((mp_word)a->dp[ix]);
1756      
1757      if (w >= b) {
1758         t = (mp_digit)(w / b);
1759         w -= ((mp_word)t) * ((mp_word)b);
1760       } else {
1761         t = 0;
1762       }
1763       q.dp[ix] = t;
1764   }
1765
1766   if (d != NULL) {
1767      *d = (mp_digit)w;
1768   }
1769   
1770   if (c != NULL) {
1771      mp_clamp(&q);
1772      mp_exch(&q, c);
1773   }
1774   mp_clear(&q);
1775   
1776   return res;
1777 }
1778
1779 /* reduce "x" in place modulo "n" using the Diminished Radix algorithm.
1780  *
1781  * Based on algorithm from the paper
1782  *
1783  * "Generating Efficient Primes for Discrete Log Cryptosystems"
1784  *                 Chae Hoon Lim, Pil Loong Lee,
1785  *          POSTECH Information Research Laboratories
1786  *
1787  * The modulus must be of a special format [see manual]
1788  *
1789  * Has been modified to use algorithm 7.10 from the LTM book instead
1790  *
1791  * Input x must be in the range 0 <= x <= (n-1)**2
1792  */
1793 static int
1794 mp_dr_reduce (mp_int * x, const mp_int * n, mp_digit k)
1795 {
1796   int      err, i, m;
1797   mp_word  r;
1798   mp_digit mu, *tmpx1, *tmpx2;
1799
1800   /* m = digits in modulus */
1801   m = n->used;
1802
1803   /* ensure that "x" has at least 2m digits */
1804   if (x->alloc < m + m) {
1805     if ((err = mp_grow (x, m + m)) != MP_OKAY) {
1806       return err;
1807     }
1808   }
1809
1810 /* top of loop, this is where the code resumes if
1811  * another reduction pass is required.
1812  */
1813 top:
1814   /* aliases for digits */
1815   /* alias for lower half of x */
1816   tmpx1 = x->dp;
1817
1818   /* alias for upper half of x, or x/B**m */
1819   tmpx2 = x->dp + m;
1820
1821   /* set carry to zero */
1822   mu = 0;
1823
1824   /* compute (x mod B**m) + k * [x/B**m] inline and inplace */
1825   for (i = 0; i < m; i++) {
1826       r         = ((mp_word)*tmpx2++) * ((mp_word)k) + *tmpx1 + mu;
1827       *tmpx1++  = (mp_digit)(r & MP_MASK);
1828       mu        = (mp_digit)(r >> ((mp_word)DIGIT_BIT));
1829   }
1830
1831   /* set final carry */
1832   *tmpx1++ = mu;
1833
1834   /* zero words above m */
1835   for (i = m + 1; i < x->used; i++) {
1836       *tmpx1++ = 0;
1837   }
1838
1839   /* clamp, sub and return */
1840   mp_clamp (x);
1841
1842   /* if x >= n then subtract and reduce again
1843    * Each successive "recursion" makes the input smaller and smaller.
1844    */
1845   if (mp_cmp_mag (x, n) != MP_LT) {
1846     s_mp_sub(x, n, x);
1847     goto top;
1848   }
1849   return MP_OKAY;
1850 }
1851
1852 /* sets the value of "d" required for mp_dr_reduce */
1853 static void mp_dr_setup(const mp_int *a, mp_digit *d)
1854 {
1855    /* the casts are required if DIGIT_BIT is one less than
1856     * the number of bits in a mp_digit [e.g. DIGIT_BIT==31]
1857     */
1858    *d = (mp_digit)((((mp_word)1) << ((mp_word)DIGIT_BIT)) - 
1859         ((mp_word)a->dp[0]));
1860 }
1861
1862 /* this is a shell function that calls either the normal or Montgomery
1863  * exptmod functions.  Originally the call to the montgomery code was
1864  * embedded in the normal function but that wasted a lot of stack space
1865  * for nothing (since 99% of the time the Montgomery code would be called)
1866  */
1867 int mp_exptmod (const mp_int * G, const mp_int * X, mp_int * P, mp_int * Y)
1868 {
1869   int dr;
1870
1871   /* modulus P must be positive */
1872   if (P->sign == MP_NEG) {
1873      return MP_VAL;
1874   }
1875
1876   /* if exponent X is negative we have to recurse */
1877   if (X->sign == MP_NEG) {
1878      mp_int tmpG, tmpX;
1879      int err;
1880
1881      /* first compute 1/G mod P */
1882      if ((err = mp_init(&tmpG)) != MP_OKAY) {
1883         return err;
1884      }
1885      if ((err = mp_invmod(G, P, &tmpG)) != MP_OKAY) {
1886         mp_clear(&tmpG);
1887         return err;
1888      }
1889
1890      /* now get |X| */
1891      if ((err = mp_init(&tmpX)) != MP_OKAY) {
1892         mp_clear(&tmpG);
1893         return err;
1894      }
1895      if ((err = mp_abs(X, &tmpX)) != MP_OKAY) {
1896         mp_clear_multi(&tmpG, &tmpX, NULL);
1897         return err;
1898      }
1899
1900      /* and now compute (1/G)**|X| instead of G**X [X < 0] */
1901      err = mp_exptmod(&tmpG, &tmpX, P, Y);
1902      mp_clear_multi(&tmpG, &tmpX, NULL);
1903      return err;
1904   }
1905
1906   dr = 0;
1907
1908   /* if the modulus is odd or dr != 0 use the fast method */
1909   if (mp_isodd (P) == 1 || dr !=  0) {
1910     return mp_exptmod_fast (G, X, P, Y, dr);
1911   } else {
1912     /* otherwise use the generic Barrett reduction technique */
1913     return s_mp_exptmod (G, X, P, Y);
1914   }
1915 }
1916
1917 /* computes Y == G**X mod P, HAC pp.616, Algorithm 14.85
1918  *
1919  * Uses a left-to-right k-ary sliding window to compute the modular exponentiation.
1920  * The value of k changes based on the size of the exponent.
1921  *
1922  * Uses Montgomery or Diminished Radix reduction [whichever appropriate]
1923  */
1924
1925 int
1926 mp_exptmod_fast (const mp_int * G, const mp_int * X, mp_int * P, mp_int * Y, int redmode)
1927 {
1928   mp_int  M[256], res;
1929   mp_digit buf, mp;
1930   int     err, bitbuf, bitcpy, bitcnt, mode, digidx, x, y, winsize;
1931
1932   /* use a pointer to the reduction algorithm.  This allows us to use
1933    * one of many reduction algorithms without modding the guts of
1934    * the code with if statements everywhere.
1935    */
1936   int     (*redux)(mp_int*,const mp_int*,mp_digit);
1937
1938   /* find window size */
1939   x = mp_count_bits (X);
1940   if (x <= 7) {
1941     winsize = 2;
1942   } else if (x <= 36) {
1943     winsize = 3;
1944   } else if (x <= 140) {
1945     winsize = 4;
1946   } else if (x <= 450) {
1947     winsize = 5;
1948   } else if (x <= 1303) {
1949     winsize = 6;
1950   } else if (x <= 3529) {
1951     winsize = 7;
1952   } else {
1953     winsize = 8;
1954   }
1955
1956   /* init M array */
1957   /* init first cell */
1958   if ((err = mp_init(&M[1])) != MP_OKAY) {
1959      return err;
1960   }
1961
1962   /* now init the second half of the array */
1963   for (x = 1<<(winsize-1); x < (1 << winsize); x++) {
1964     if ((err = mp_init(&M[x])) != MP_OKAY) {
1965       for (y = 1<<(winsize-1); y < x; y++) {
1966         mp_clear (&M[y]);
1967       }
1968       mp_clear(&M[1]);
1969       return err;
1970     }
1971   }
1972
1973   /* determine and setup reduction code */
1974   if (redmode == 0) {
1975      /* now setup montgomery  */
1976      if ((err = mp_montgomery_setup (P, &mp)) != MP_OKAY) {
1977         goto __M;
1978      }
1979
1980      /* automatically pick the comba one if available (saves quite a few calls/ifs) */
1981      if (((P->used * 2 + 1) < MP_WARRAY) &&
1982           P->used < (1 << ((CHAR_BIT * sizeof (mp_word)) - (2 * DIGIT_BIT)))) {
1983         redux = fast_mp_montgomery_reduce;
1984      } else {
1985         /* use slower baseline Montgomery method */
1986         redux = mp_montgomery_reduce;
1987      }
1988   } else if (redmode == 1) {
1989      /* setup DR reduction for moduli of the form B**k - b */
1990      mp_dr_setup(P, &mp);
1991      redux = mp_dr_reduce;
1992   } else {
1993      /* setup DR reduction for moduli of the form 2**k - b */
1994      if ((err = mp_reduce_2k_setup(P, &mp)) != MP_OKAY) {
1995         goto __M;
1996      }
1997      redux = mp_reduce_2k;
1998   }
1999
2000   /* setup result */
2001   if ((err = mp_init (&res)) != MP_OKAY) {
2002     goto __M;
2003   }
2004
2005   /* create M table
2006    *
2007
2008    *
2009    * The first half of the table is not computed though accept for M[0] and M[1]
2010    */
2011
2012   if (redmode == 0) {
2013      /* now we need R mod m */
2014      if ((err = mp_montgomery_calc_normalization (&res, P)) != MP_OKAY) {
2015        goto __RES;
2016      }
2017
2018      /* now set M[1] to G * R mod m */
2019      if ((err = mp_mulmod (G, &res, P, &M[1])) != MP_OKAY) {
2020        goto __RES;
2021      }
2022   } else {
2023      mp_set(&res, 1);
2024      if ((err = mp_mod(G, P, &M[1])) != MP_OKAY) {
2025         goto __RES;
2026      }
2027   }
2028
2029   /* compute the value at M[1<<(winsize-1)] by squaring M[1] (winsize-1) times */
2030   if ((err = mp_copy (&M[1], &M[1 << (winsize - 1)])) != MP_OKAY) {
2031     goto __RES;
2032   }
2033
2034   for (x = 0; x < (winsize - 1); x++) {
2035     if ((err = mp_sqr (&M[1 << (winsize - 1)], &M[1 << (winsize - 1)])) != MP_OKAY) {
2036       goto __RES;
2037     }
2038     if ((err = redux (&M[1 << (winsize - 1)], P, mp)) != MP_OKAY) {
2039       goto __RES;
2040     }
2041   }
2042
2043   /* create upper table */
2044   for (x = (1 << (winsize - 1)) + 1; x < (1 << winsize); x++) {
2045     if ((err = mp_mul (&M[x - 1], &M[1], &M[x])) != MP_OKAY) {
2046       goto __RES;
2047     }
2048     if ((err = redux (&M[x], P, mp)) != MP_OKAY) {
2049       goto __RES;
2050     }
2051   }
2052
2053   /* set initial mode and bit cnt */
2054   mode   = 0;
2055   bitcnt = 1;
2056   buf    = 0;
2057   digidx = X->used - 1;
2058   bitcpy = 0;
2059   bitbuf = 0;
2060
2061   for (;;) {
2062     /* grab next digit as required */
2063     if (--bitcnt == 0) {
2064       /* if digidx == -1 we are out of digits so break */
2065       if (digidx == -1) {
2066         break;
2067       }
2068       /* read next digit and reset bitcnt */
2069       buf    = X->dp[digidx--];
2070       bitcnt = DIGIT_BIT;
2071     }
2072
2073     /* grab the next msb from the exponent */
2074     y     = (buf >> (DIGIT_BIT - 1)) & 1;
2075     buf <<= (mp_digit)1;
2076
2077     /* if the bit is zero and mode == 0 then we ignore it
2078      * These represent the leading zero bits before the first 1 bit
2079      * in the exponent.  Technically this opt is not required but it
2080      * does lower the # of trivial squaring/reductions used
2081      */
2082     if (mode == 0 && y == 0) {
2083       continue;
2084     }
2085
2086     /* if the bit is zero and mode == 1 then we square */
2087     if (mode == 1 && y == 0) {
2088       if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
2089         goto __RES;
2090       }
2091       if ((err = redux (&res, P, mp)) != MP_OKAY) {
2092         goto __RES;
2093       }
2094       continue;
2095     }
2096
2097     /* else we add it to the window */
2098     bitbuf |= (y << (winsize - ++bitcpy));
2099     mode    = 2;
2100
2101     if (bitcpy == winsize) {
2102       /* ok window is filled so square as required and multiply  */
2103       /* square first */
2104       for (x = 0; x < winsize; x++) {
2105         if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
2106           goto __RES;
2107         }
2108         if ((err = redux (&res, P, mp)) != MP_OKAY) {
2109           goto __RES;
2110         }
2111       }
2112
2113       /* then multiply */
2114       if ((err = mp_mul (&res, &M[bitbuf], &res)) != MP_OKAY) {
2115         goto __RES;
2116       }
2117       if ((err = redux (&res, P, mp)) != MP_OKAY) {
2118         goto __RES;
2119       }
2120
2121       /* empty window and reset */
2122       bitcpy = 0;
2123       bitbuf = 0;
2124       mode   = 1;
2125     }
2126   }
2127
2128   /* if bits remain then square/multiply */
2129   if (mode == 2 && bitcpy > 0) {
2130     /* square then multiply if the bit is set */
2131     for (x = 0; x < bitcpy; x++) {
2132       if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
2133         goto __RES;
2134       }
2135       if ((err = redux (&res, P, mp)) != MP_OKAY) {
2136         goto __RES;
2137       }
2138
2139       /* get next bit of the window */
2140       bitbuf <<= 1;
2141       if ((bitbuf & (1 << winsize)) != 0) {
2142         /* then multiply */
2143         if ((err = mp_mul (&res, &M[1], &res)) != MP_OKAY) {
2144           goto __RES;
2145         }
2146         if ((err = redux (&res, P, mp)) != MP_OKAY) {
2147           goto __RES;
2148         }
2149       }
2150     }
2151   }
2152
2153   if (redmode == 0) {
2154      /* fixup result if Montgomery reduction is used
2155       * recall that any value in a Montgomery system is
2156       * actually multiplied by R mod n.  So we have
2157       * to reduce one more time to cancel out the factor
2158       * of R.
2159       */
2160      if ((err = redux(&res, P, mp)) != MP_OKAY) {
2161        goto __RES;
2162      }
2163   }
2164
2165   /* swap res with Y */
2166   mp_exch (&res, Y);
2167   err = MP_OKAY;
2168 __RES:mp_clear (&res);
2169 __M:
2170   mp_clear(&M[1]);
2171   for (x = 1<<(winsize-1); x < (1 << winsize); x++) {
2172     mp_clear (&M[x]);
2173   }
2174   return err;
2175 }
2176
2177 /* Greatest Common Divisor using the binary method */
2178 int mp_gcd (const mp_int * a, const mp_int * b, mp_int * c)
2179 {
2180   mp_int  u, v;
2181   int     k, u_lsb, v_lsb, res;
2182
2183   /* either zero than gcd is the largest */
2184   if (mp_iszero (a) == 1 && mp_iszero (b) == 0) {
2185     return mp_abs (b, c);
2186   }
2187   if (mp_iszero (a) == 0 && mp_iszero (b) == 1) {
2188     return mp_abs (a, c);
2189   }
2190
2191   /* optimized.  At this point if a == 0 then
2192    * b must equal zero too
2193    */
2194   if (mp_iszero (a) == 1) {
2195     mp_zero(c);
2196     return MP_OKAY;
2197   }
2198
2199   /* get copies of a and b we can modify */
2200   if ((res = mp_init_copy (&u, a)) != MP_OKAY) {
2201     return res;
2202   }
2203
2204   if ((res = mp_init_copy (&v, b)) != MP_OKAY) {
2205     goto __U;
2206   }
2207
2208   /* must be positive for the remainder of the algorithm */
2209   u.sign = v.sign = MP_ZPOS;
2210
2211   /* B1.  Find the common power of two for u and v */
2212   u_lsb = mp_cnt_lsb(&u);
2213   v_lsb = mp_cnt_lsb(&v);
2214   k     = MIN(u_lsb, v_lsb);
2215
2216   if (k > 0) {
2217      /* divide the power of two out */
2218      if ((res = mp_div_2d(&u, k, &u, NULL)) != MP_OKAY) {
2219         goto __V;
2220      }
2221
2222      if ((res = mp_div_2d(&v, k, &v, NULL)) != MP_OKAY) {
2223         goto __V;
2224      }
2225   }
2226
2227   /* divide any remaining factors of two out */
2228   if (u_lsb != k) {
2229      if ((res = mp_div_2d(&u, u_lsb - k, &u, NULL)) != MP_OKAY) {
2230         goto __V;
2231      }
2232   }
2233
2234   if (v_lsb != k) {
2235      if ((res = mp_div_2d(&v, v_lsb - k, &v, NULL)) != MP_OKAY) {
2236         goto __V;
2237      }
2238   }
2239
2240   while (mp_iszero(&v) == 0) {
2241      /* make sure v is the largest */
2242      if (mp_cmp_mag(&u, &v) == MP_GT) {
2243         /* swap u and v to make sure v is >= u */
2244         mp_exch(&u, &v);
2245      }
2246      
2247      /* subtract smallest from largest */
2248      if ((res = s_mp_sub(&v, &u, &v)) != MP_OKAY) {
2249         goto __V;
2250      }
2251      
2252      /* Divide out all factors of two */
2253      if ((res = mp_div_2d(&v, mp_cnt_lsb(&v), &v, NULL)) != MP_OKAY) {
2254         goto __V;
2255      } 
2256   } 
2257
2258   /* multiply by 2**k which we divided out at the beginning */
2259   if ((res = mp_mul_2d (&u, k, c)) != MP_OKAY) {
2260      goto __V;
2261   }
2262   c->sign = MP_ZPOS;
2263   res = MP_OKAY;
2264 __V:mp_clear (&u);
2265 __U:mp_clear (&v);
2266   return res;
2267 }
2268
2269 /* get the lower 32-bits of an mp_int */
2270 unsigned long mp_get_int(const mp_int * a)
2271 {
2272   int i;
2273   unsigned long res;
2274
2275   if (a->used == 0) {
2276      return 0;
2277   }
2278
2279   /* get number of digits of the lsb we have to read */
2280   i = MIN(a->used,(int)((sizeof(unsigned long)*CHAR_BIT+DIGIT_BIT-1)/DIGIT_BIT))-1;
2281
2282   /* get most significant digit of result */
2283   res = DIGIT(a,i);
2284    
2285   while (--i >= 0) {
2286     res = (res << DIGIT_BIT) | DIGIT(a,i);
2287   }
2288
2289   /* force result to 32-bits always so it is consistent on non 32-bit platforms */
2290   return res & 0xFFFFFFFFUL;
2291 }
2292
2293 /* creates "a" then copies b into it */
2294 int mp_init_copy (mp_int * a, const mp_int * b)
2295 {
2296   int     res;
2297
2298   if ((res = mp_init (a)) != MP_OKAY) {
2299     return res;
2300   }
2301   return mp_copy (b, a);
2302 }
2303
2304 int mp_init_multi(mp_int *mp, ...) 
2305 {
2306     mp_err res = MP_OKAY;      /* Assume ok until proven otherwise */
2307     int n = 0;                 /* Number of ok inits */
2308     mp_int* cur_arg = mp;
2309     va_list args;
2310
2311     va_start(args, mp);        /* init args to next argument from caller */
2312     while (cur_arg != NULL) {
2313         if (mp_init(cur_arg) != MP_OKAY) {
2314             /* Oops - error! Back-track and mp_clear what we already
2315                succeeded in init-ing, then return error.
2316             */
2317             va_list clean_args;
2318             
2319             /* end the current list */
2320             va_end(args);
2321             
2322             /* now start cleaning up */            
2323             cur_arg = mp;
2324             va_start(clean_args, mp);
2325             while (n--) {
2326                 mp_clear(cur_arg);
2327                 cur_arg = va_arg(clean_args, mp_int*);
2328             }
2329             va_end(clean_args);
2330             res = MP_MEM;
2331             break;
2332         }
2333         n++;
2334         cur_arg = va_arg(args, mp_int*);
2335     }
2336     va_end(args);
2337     return res;                /* Assumed ok, if error flagged above. */
2338 }
2339
2340 /* hac 14.61, pp608 */
2341 int mp_invmod (const mp_int * a, mp_int * b, mp_int * c)
2342 {
2343   /* b cannot be negative */
2344   if (b->sign == MP_NEG || mp_iszero(b) == 1) {
2345     return MP_VAL;
2346   }
2347
2348   /* if the modulus is odd we can use a faster routine instead */
2349   if (mp_isodd (b) == 1) {
2350     return fast_mp_invmod (a, b, c);
2351   }
2352   
2353   return mp_invmod_slow(a, b, c);
2354 }
2355
2356 /* hac 14.61, pp608 */
2357 int mp_invmod_slow (const mp_int * a, mp_int * b, mp_int * c)
2358 {
2359   mp_int  x, y, u, v, A, B, C, D;
2360   int     res;
2361
2362   /* b cannot be negative */
2363   if (b->sign == MP_NEG || mp_iszero(b) == 1) {
2364     return MP_VAL;
2365   }
2366
2367   /* init temps */
2368   if ((res = mp_init_multi(&x, &y, &u, &v, 
2369                            &A, &B, &C, &D, NULL)) != MP_OKAY) {
2370      return res;
2371   }
2372
2373   /* x = a, y = b */
2374   if ((res = mp_copy (a, &x)) != MP_OKAY) {
2375     goto __ERR;
2376   }
2377   if ((res = mp_copy (b, &y)) != MP_OKAY) {
2378     goto __ERR;
2379   }
2380
2381   /* 2. [modified] if x,y are both even then return an error! */
2382   if (mp_iseven (&x) == 1 && mp_iseven (&y) == 1) {
2383     res = MP_VAL;
2384     goto __ERR;
2385   }
2386
2387   /* 3. u=x, v=y, A=1, B=0, C=0,D=1 */
2388   if ((res = mp_copy (&x, &u)) != MP_OKAY) {
2389     goto __ERR;
2390   }
2391   if ((res = mp_copy (&y, &v)) != MP_OKAY) {
2392     goto __ERR;
2393   }
2394   mp_set (&A, 1);
2395   mp_set (&D, 1);
2396
2397 top:
2398   /* 4.  while u is even do */
2399   while (mp_iseven (&u) == 1) {
2400     /* 4.1 u = u/2 */
2401     if ((res = mp_div_2 (&u, &u)) != MP_OKAY) {
2402       goto __ERR;
2403     }
2404     /* 4.2 if A or B is odd then */
2405     if (mp_isodd (&A) == 1 || mp_isodd (&B) == 1) {
2406       /* A = (A+y)/2, B = (B-x)/2 */
2407       if ((res = mp_add (&A, &y, &A)) != MP_OKAY) {
2408          goto __ERR;
2409       }
2410       if ((res = mp_sub (&B, &x, &B)) != MP_OKAY) {
2411          goto __ERR;
2412       }
2413     }
2414     /* A = A/2, B = B/2 */
2415     if ((res = mp_div_2 (&A, &A)) != MP_OKAY) {
2416       goto __ERR;
2417     }
2418     if ((res = mp_div_2 (&B, &B)) != MP_OKAY) {
2419       goto __ERR;
2420     }
2421   }
2422
2423   /* 5.  while v is even do */
2424   while (mp_iseven (&v) == 1) {
2425     /* 5.1 v = v/2 */
2426     if ((res = mp_div_2 (&v, &v)) != MP_OKAY) {
2427       goto __ERR;
2428     }
2429     /* 5.2 if C or D is odd then */
2430     if (mp_isodd (&C) == 1 || mp_isodd (&D) == 1) {
2431       /* C = (C+y)/2, D = (D-x)/2 */
2432       if ((res = mp_add (&C, &y, &C)) != MP_OKAY) {
2433          goto __ERR;
2434       }
2435       if ((res = mp_sub (&D, &x, &D)) != MP_OKAY) {
2436          goto __ERR;
2437       }
2438     }
2439     /* C = C/2, D = D/2 */
2440     if ((res = mp_div_2 (&C, &C)) != MP_OKAY) {
2441       goto __ERR;
2442     }
2443     if ((res = mp_div_2 (&D, &D)) != MP_OKAY) {
2444       goto __ERR;
2445     }
2446   }
2447
2448   /* 6.  if u >= v then */
2449   if (mp_cmp (&u, &v) != MP_LT) {
2450     /* u = u - v, A = A - C, B = B - D */
2451     if ((res = mp_sub (&u, &v, &u)) != MP_OKAY) {
2452       goto __ERR;
2453     }
2454
2455     if ((res = mp_sub (&A, &C, &A)) != MP_OKAY) {
2456       goto __ERR;
2457     }
2458
2459     if ((res = mp_sub (&B, &D, &B)) != MP_OKAY) {
2460       goto __ERR;
2461     }
2462   } else {
2463     /* v - v - u, C = C - A, D = D - B */
2464     if ((res = mp_sub (&v, &u, &v)) != MP_OKAY) {
2465       goto __ERR;
2466     }
2467
2468     if ((res = mp_sub (&C, &A, &C)) != MP_OKAY) {
2469       goto __ERR;
2470     }
2471
2472     if ((res = mp_sub (&D, &B, &D)) != MP_OKAY) {
2473       goto __ERR;
2474     }
2475   }
2476
2477   /* if not zero goto step 4 */
2478   if (mp_iszero (&u) == 0)
2479     goto top;
2480
2481   /* now a = C, b = D, gcd == g*v */
2482
2483   /* if v != 1 then there is no inverse */
2484   if (mp_cmp_d (&v, 1) != MP_EQ) {
2485     res = MP_VAL;
2486     goto __ERR;
2487   }
2488
2489   /* if its too low */
2490   while (mp_cmp_d(&C, 0) == MP_LT) {
2491       if ((res = mp_add(&C, b, &C)) != MP_OKAY) {
2492          goto __ERR;
2493       }
2494   }
2495   
2496   /* too big */
2497   while (mp_cmp_mag(&C, b) != MP_LT) {
2498       if ((res = mp_sub(&C, b, &C)) != MP_OKAY) {
2499          goto __ERR;
2500       }
2501   }
2502   
2503   /* C is now the inverse */
2504   mp_exch (&C, c);
2505   res = MP_OKAY;
2506 __ERR:mp_clear_multi (&x, &y, &u, &v, &A, &B, &C, &D, NULL);
2507   return res;
2508 }
2509
2510 /* c = |a| * |b| using Karatsuba Multiplication using 
2511  * three half size multiplications
2512  *
2513  * Let B represent the radix [e.g. 2**DIGIT_BIT] and 
2514  * let n represent half of the number of digits in 
2515  * the min(a,b)
2516  *
2517  * a = a1 * B**n + a0
2518  * b = b1 * B**n + b0
2519  *
2520  * Then, a * b => 
2521    a1b1 * B**2n + ((a1 - a0)(b1 - b0) + a0b0 + a1b1) * B + a0b0
2522  *
2523  * Note that a1b1 and a0b0 are used twice and only need to be 
2524  * computed once.  So in total three half size (half # of 
2525  * digit) multiplications are performed, a0b0, a1b1 and 
2526  * (a1-b1)(a0-b0)
2527  *
2528  * Note that a multiplication of half the digits requires
2529  * 1/4th the number of single precision multiplications so in 
2530  * total after one call 25% of the single precision multiplications 
2531  * are saved.  Note also that the call to mp_mul can end up back 
2532  * in this function if the a0, a1, b0, or b1 are above the threshold.  
2533  * This is known as divide-and-conquer and leads to the famous 
2534  * O(N**lg(3)) or O(N**1.584) work which is asymptotically lower than
2535  * the standard O(N**2) that the baseline/comba methods use.  
2536  * Generally though the overhead of this method doesn't pay off 
2537  * until a certain size (N ~ 80) is reached.
2538  */
2539 int mp_karatsuba_mul (const mp_int * a, const mp_int * b, mp_int * c)
2540 {
2541   mp_int  x0, x1, y0, y1, t1, x0y0, x1y1;
2542   int     B, err;
2543
2544   /* default the return code to an error */
2545   err = MP_MEM;
2546
2547   /* min # of digits */
2548   B = MIN (a->used, b->used);
2549
2550   /* now divide in two */
2551   B = B >> 1;
2552
2553   /* init copy all the temps */
2554   if (mp_init_size (&x0, B) != MP_OKAY)
2555     goto ERR;
2556   if (mp_init_size (&x1, a->used - B) != MP_OKAY)
2557     goto X0;
2558   if (mp_init_size (&y0, B) != MP_OKAY)
2559     goto X1;
2560   if (mp_init_size (&y1, b->used - B) != MP_OKAY)
2561     goto Y0;
2562
2563   /* init temps */
2564   if (mp_init_size (&t1, B * 2) != MP_OKAY)
2565     goto Y1;
2566   if (mp_init_size (&x0y0, B * 2) != MP_OKAY)
2567     goto T1;
2568   if (mp_init_size (&x1y1, B * 2) != MP_OKAY)
2569     goto X0Y0;
2570
2571   /* now shift the digits */
2572   x0.used = y0.used = B;
2573   x1.used = a->used - B;
2574   y1.used = b->used - B;
2575
2576   {
2577     register int x;
2578     register mp_digit *tmpa, *tmpb, *tmpx, *tmpy;
2579
2580     /* we copy the digits directly instead of using higher level functions
2581      * since we also need to shift the digits
2582      */
2583     tmpa = a->dp;
2584     tmpb = b->dp;
2585
2586     tmpx = x0.dp;
2587     tmpy = y0.dp;
2588     for (x = 0; x < B; x++) {
2589       *tmpx++ = *tmpa++;
2590       *tmpy++ = *tmpb++;
2591     }
2592
2593     tmpx = x1.dp;
2594     for (x = B; x < a->used; x++) {
2595       *tmpx++ = *tmpa++;
2596     }
2597
2598     tmpy = y1.dp;
2599     for (x = B; x < b->used; x++) {
2600       *tmpy++ = *tmpb++;
2601     }
2602   }
2603
2604   /* only need to clamp the lower words since by definition the 
2605    * upper words x1/y1 must have a known number of digits
2606    */
2607   mp_clamp (&x0);
2608   mp_clamp (&y0);
2609
2610   /* now calc the products x0y0 and x1y1 */
2611   /* after this x0 is no longer required, free temp [x0==t2]! */
2612   if (mp_mul (&x0, &y0, &x0y0) != MP_OKAY)  
2613     goto X1Y1;          /* x0y0 = x0*y0 */
2614   if (mp_mul (&x1, &y1, &x1y1) != MP_OKAY)
2615     goto X1Y1;          /* x1y1 = x1*y1 */
2616
2617   /* now calc x1-x0 and y1-y0 */
2618   if (mp_sub (&x1, &x0, &t1) != MP_OKAY)
2619     goto X1Y1;          /* t1 = x1 - x0 */
2620   if (mp_sub (&y1, &y0, &x0) != MP_OKAY)
2621     goto X1Y1;          /* t2 = y1 - y0 */
2622   if (mp_mul (&t1, &x0, &t1) != MP_OKAY)
2623     goto X1Y1;          /* t1 = (x1 - x0) * (y1 - y0) */
2624
2625   /* add x0y0 */
2626   if (mp_add (&x0y0, &x1y1, &x0) != MP_OKAY)
2627     goto X1Y1;          /* t2 = x0y0 + x1y1 */
2628   if (mp_sub (&x0, &t1, &t1) != MP_OKAY)
2629     goto X1Y1;          /* t1 = x0y0 + x1y1 - (x1-x0)*(y1-y0) */
2630
2631   /* shift by B */
2632   if (mp_lshd (&t1, B) != MP_OKAY)
2633     goto X1Y1;          /* t1 = (x0y0 + x1y1 - (x1-x0)*(y1-y0))<<B */
2634   if (mp_lshd (&x1y1, B * 2) != MP_OKAY)
2635     goto X1Y1;          /* x1y1 = x1y1 << 2*B */
2636
2637   if (mp_add (&x0y0, &t1, &t1) != MP_OKAY)
2638     goto X1Y1;          /* t1 = x0y0 + t1 */
2639   if (mp_add (&t1, &x1y1, c) != MP_OKAY)
2640     goto X1Y1;          /* t1 = x0y0 + t1 + x1y1 */
2641
2642   /* Algorithm succeeded set the return code to MP_OKAY */
2643   err = MP_OKAY;
2644
2645 X1Y1:mp_clear (&x1y1);
2646 X0Y0:mp_clear (&x0y0);
2647 T1:mp_clear (&t1);
2648 Y1:mp_clear (&y1);
2649 Y0:mp_clear (&y0);
2650 X1:mp_clear (&x1);
2651 X0:mp_clear (&x0);
2652 ERR:
2653   return err;
2654 }
2655
2656 /* Karatsuba squaring, computes b = a*a using three 
2657  * half size squarings
2658  *
2659  * See comments of karatsuba_mul for details.  It 
2660  * is essentially the same algorithm but merely 
2661  * tuned to perform recursive squarings.
2662  */
2663 int mp_karatsuba_sqr (const mp_int * a, mp_int * b)
2664 {
2665   mp_int  x0, x1, t1, t2, x0x0, x1x1;
2666   int     B, err;
2667
2668   err = MP_MEM;
2669
2670   /* min # of digits */
2671   B = a->used;
2672
2673   /* now divide in two */
2674   B = B >> 1;
2675
2676   /* init copy all the temps */
2677   if (mp_init_size (&x0, B) != MP_OKAY)
2678     goto ERR;
2679   if (mp_init_size (&x1, a->used - B) != MP_OKAY)
2680     goto X0;
2681
2682   /* init temps */
2683   if (mp_init_size (&t1, a->used * 2) != MP_OKAY)
2684     goto X1;
2685   if (mp_init_size (&t2, a->used * 2) != MP_OKAY)
2686     goto T1;
2687   if (mp_init_size (&x0x0, B * 2) != MP_OKAY)
2688     goto T2;
2689   if (mp_init_size (&x1x1, (a->used - B) * 2) != MP_OKAY)
2690     goto X0X0;
2691
2692   {
2693     register int x;
2694     register mp_digit *dst, *src;
2695
2696     src = a->dp;
2697
2698     /* now shift the digits */
2699     dst = x0.dp;
2700     for (x = 0; x < B; x++) {
2701       *dst++ = *src++;
2702     }
2703
2704     dst = x1.dp;
2705     for (x = B; x < a->used; x++) {
2706       *dst++ = *src++;
2707     }
2708   }
2709
2710   x0.used = B;
2711   x1.used = a->used - B;
2712
2713   mp_clamp (&x0);
2714
2715   /* now calc the products x0*x0 and x1*x1 */
2716   if (mp_sqr (&x0, &x0x0) != MP_OKAY)
2717     goto X1X1;           /* x0x0 = x0*x0 */
2718   if (mp_sqr (&x1, &x1x1) != MP_OKAY)
2719     goto X1X1;           /* x1x1 = x1*x1 */
2720
2721   /* now calc (x1-x0)**2 */
2722   if (mp_sub (&x1, &x0, &t1) != MP_OKAY)
2723     goto X1X1;           /* t1 = x1 - x0 */
2724   if (mp_sqr (&t1, &t1) != MP_OKAY)
2725     goto X1X1;           /* t1 = (x1 - x0) * (x1 - x0) */
2726
2727   /* add x0y0 */
2728   if (s_mp_add (&x0x0, &x1x1, &t2) != MP_OKAY)
2729     goto X1X1;           /* t2 = x0x0 + x1x1 */
2730   if (mp_sub (&t2, &t1, &t1) != MP_OKAY)
2731     goto X1X1;           /* t1 = x0x0 + x1x1 - (x1-x0)*(x1-x0) */
2732
2733   /* shift by B */
2734   if (mp_lshd (&t1, B) != MP_OKAY)
2735     goto X1X1;           /* t1 = (x0x0 + x1x1 - (x1-x0)*(x1-x0))<<B */
2736   if (mp_lshd (&x1x1, B * 2) != MP_OKAY)
2737     goto X1X1;           /* x1x1 = x1x1 << 2*B */
2738
2739   if (mp_add (&x0x0, &t1, &t1) != MP_OKAY)
2740     goto X1X1;           /* t1 = x0x0 + t1 */
2741   if (mp_add (&t1, &x1x1, b) != MP_OKAY)
2742     goto X1X1;           /* t1 = x0x0 + t1 + x1x1 */
2743
2744   err = MP_OKAY;
2745
2746 X1X1:mp_clear (&x1x1);
2747 X0X0:mp_clear (&x0x0);
2748 T2:mp_clear (&t2);
2749 T1:mp_clear (&t1);
2750 X1:mp_clear (&x1);
2751 X0:mp_clear (&x0);
2752 ERR:
2753   return err;
2754 }
2755
2756 /* computes least common multiple as |a*b|/(a, b) */
2757 int mp_lcm (const mp_int * a, const mp_int * b, mp_int * c)
2758 {
2759   int     res;
2760   mp_int  t1, t2;
2761
2762
2763   if ((res = mp_init_multi (&t1, &t2, NULL)) != MP_OKAY) {
2764     return res;
2765   }
2766
2767   /* t1 = get the GCD of the two inputs */
2768   if ((res = mp_gcd (a, b, &t1)) != MP_OKAY) {
2769     goto __T;
2770   }
2771
2772   /* divide the smallest by the GCD */
2773   if (mp_cmp_mag(a, b) == MP_LT) {
2774      /* store quotient in t2 such that t2 * b is the LCM */
2775      if ((res = mp_div(a, &t1, &t2, NULL)) != MP_OKAY) {
2776         goto __T;
2777      }
2778      res = mp_mul(b, &t2, c);
2779   } else {
2780      /* store quotient in t2 such that t2 * a is the LCM */
2781      if ((res = mp_div(b, &t1, &t2, NULL)) != MP_OKAY) {
2782         goto __T;
2783      }
2784      res = mp_mul(a, &t2, c);
2785   }
2786
2787   /* fix the sign to positive */
2788   c->sign = MP_ZPOS;
2789
2790 __T:
2791   mp_clear_multi (&t1, &t2, NULL);
2792   return res;
2793 }
2794
2795 /* c = a mod b, 0 <= c < b */
2796 int
2797 mp_mod (const mp_int * a, mp_int * b, mp_int * c)
2798 {
2799   mp_int  t;
2800   int     res;
2801
2802   if ((res = mp_init (&t)) != MP_OKAY) {
2803     return res;
2804   }
2805
2806   if ((res = mp_div (a, b, NULL, &t)) != MP_OKAY) {
2807     mp_clear (&t);
2808     return res;
2809   }
2810
2811   if (t.sign != b->sign) {
2812     res = mp_add (b, &t, c);
2813   } else {
2814     res = MP_OKAY;
2815     mp_exch (&t, c);
2816   }
2817
2818   mp_clear (&t);
2819   return res;
2820 }
2821
2822 static int
2823 mp_mod_d (const mp_int * a, mp_digit b, mp_digit * c)
2824 {
2825   return mp_div_d(a, b, NULL, c);
2826 }
2827
2828 /* b = a*2 */
2829 static int mp_mul_2(const mp_int * a, mp_int * b)
2830 {
2831   int     x, res, oldused;
2832
2833   /* grow to accommodate result */
2834   if (b->alloc < a->used + 1) {
2835     if ((res = mp_grow (b, a->used + 1)) != MP_OKAY) {
2836       return res;
2837     }
2838   }
2839
2840   oldused = b->used;
2841   b->used = a->used;
2842
2843   {
2844     register mp_digit r, rr, *tmpa, *tmpb;
2845
2846     /* alias for source */
2847     tmpa = a->dp;
2848
2849     /* alias for dest */
2850     tmpb = b->dp;
2851
2852     /* carry */
2853     r = 0;
2854     for (x = 0; x < a->used; x++) {
2855
2856       /* get what will be the *next* carry bit from the
2857        * MSB of the current digit
2858        */
2859       rr = *tmpa >> ((mp_digit)(DIGIT_BIT - 1));
2860
2861       /* now shift up this digit, add in the carry [from the previous] */
2862       *tmpb++ = ((*tmpa++ << ((mp_digit)1)) | r) & MP_MASK;
2863
2864       /* copy the carry that would be from the source
2865        * digit into the next iteration
2866        */
2867       r = rr;
2868     }
2869
2870     /* new leading digit? */
2871     if (r != 0) {
2872       /* add a MSB which is always 1 at this point */
2873       *tmpb = 1;
2874       ++(b->used);
2875     }
2876
2877     /* now zero any excess digits on the destination
2878      * that we didn't write to
2879      */
2880     tmpb = b->dp + b->used;
2881     for (x = b->used; x < oldused; x++) {
2882       *tmpb++ = 0;
2883     }
2884   }
2885   b->sign = a->sign;
2886   return MP_OKAY;
2887 }
2888
2889 /*
2890  * shifts with subtractions when the result is greater than b.
2891  *
2892  * The method is slightly modified to shift B unconditionally up to just under
2893  * the leading bit of b.  This saves a lot of multiple precision shifting.
2894  */
2895 int mp_montgomery_calc_normalization (mp_int * a, const mp_int * b)
2896 {
2897   int     x, bits, res;
2898
2899   /* how many bits of last digit does b use */
2900   bits = mp_count_bits (b) % DIGIT_BIT;
2901
2902
2903   if (b->used > 1) {
2904      if ((res = mp_2expt (a, (b->used - 1) * DIGIT_BIT + bits - 1)) != MP_OKAY) {
2905         return res;
2906      }
2907   } else {
2908      mp_set(a, 1);
2909      bits = 1;
2910   }
2911
2912
2913   /* now compute C = A * B mod b */
2914   for (x = bits - 1; x < DIGIT_BIT; x++) {
2915     if ((res = mp_mul_2 (a, a)) != MP_OKAY) {
2916       return res;
2917     }
2918     if (mp_cmp_mag (a, b) != MP_LT) {
2919       if ((res = s_mp_sub (a, b, a)) != MP_OKAY) {
2920         return res;
2921       }
2922     }
2923   }
2924
2925   return MP_OKAY;
2926 }
2927
2928 /* computes xR**-1 == x (mod N) via Montgomery Reduction */
2929 int
2930 mp_montgomery_reduce (mp_int * x, const mp_int * n, mp_digit rho)
2931 {
2932   int     ix, res, digs;
2933   mp_digit mu;
2934
2935   /* can the fast reduction [comba] method be used?
2936    *
2937    * Note that unlike in mul you're safely allowed *less*
2938    * than the available columns [255 per default] since carries
2939    * are fixed up in the inner loop.
2940    */
2941   digs = n->used * 2 + 1;
2942   if ((digs < MP_WARRAY) &&
2943       n->used <
2944       (1 << ((CHAR_BIT * sizeof (mp_word)) - (2 * DIGIT_BIT)))) {
2945     return fast_mp_montgomery_reduce (x, n, rho);
2946   }
2947
2948   /* grow the input as required */
2949   if (x->alloc < digs) {
2950     if ((res = mp_grow (x, digs)) != MP_OKAY) {
2951       return res;
2952     }
2953   }
2954   x->used = digs;
2955
2956   for (ix = 0; ix < n->used; ix++) {
2957     /* mu = ai * rho mod b
2958      *
2959      * The value of rho must be precalculated via
2960      * montgomery_setup() such that
2961      * it equals -1/n0 mod b this allows the
2962      * following inner loop to reduce the
2963      * input one digit at a time
2964      */
2965     mu = (mp_digit) (((mp_word)x->dp[ix]) * ((mp_word)rho) & MP_MASK);
2966
2967     /* a = a + mu * m * b**i */
2968     {
2969       register int iy;
2970       register mp_digit *tmpn, *tmpx, u;
2971       register mp_word r;
2972
2973       /* alias for digits of the modulus */
2974       tmpn = n->dp;
2975
2976       /* alias for the digits of x [the input] */
2977       tmpx = x->dp + ix;
2978
2979       /* set the carry to zero */
2980       u = 0;
2981
2982       /* Multiply and add in place */
2983       for (iy = 0; iy < n->used; iy++) {
2984         /* compute product and sum */
2985         r       = ((mp_word)mu) * ((mp_word)*tmpn++) +
2986                   ((mp_word) u) + ((mp_word) * tmpx);
2987
2988         /* get carry */
2989         u       = (mp_digit)(r >> ((mp_word) DIGIT_BIT));
2990
2991         /* fix digit */
2992         *tmpx++ = (mp_digit)(r & ((mp_word) MP_MASK));
2993       }
2994       /* At this point the ix'th digit of x should be zero */
2995
2996
2997       /* propagate carries upwards as required*/
2998       while (u) {
2999         *tmpx   += u;
3000         u        = *tmpx >> DIGIT_BIT;
3001         *tmpx++ &= MP_MASK;
3002       }
3003     }
3004   }
3005
3006   /* at this point the n.used'th least
3007    * significant digits of x are all zero
3008    * which means we can shift x to the
3009    * right by n.used digits and the
3010    * residue is unchanged.
3011    */
3012
3013   /* x = x/b**n.used */
3014   mp_clamp(x);
3015   mp_rshd (x, n->used);
3016
3017   /* if x >= n then x = x - n */
3018   if (mp_cmp_mag (x, n) != MP_LT) {
3019     return s_mp_sub (x, n, x);
3020   }
3021
3022   return MP_OKAY;
3023 }
3024
3025 /* setups the montgomery reduction stuff */
3026 int
3027 mp_montgomery_setup (const mp_int * n, mp_digit * rho)
3028 {
3029   mp_digit x, b;
3030
3031 /* fast inversion mod 2**k
3032  *
3033  * Based on the fact that
3034  *
3035  * XA = 1 (mod 2**n)  =>  (X(2-XA)) A = 1 (mod 2**2n)
3036  *                    =>  2*X*A - X*X*A*A = 1
3037  *                    =>  2*(1) - (1)     = 1
3038  */
3039   b = n->dp[0];
3040
3041   if ((b & 1) == 0) {
3042     return MP_VAL;
3043   }
3044
3045   x = (((b + 2) & 4) << 1) + b; /* here x*a==1 mod 2**4 */
3046   x *= 2 - b * x;               /* here x*a==1 mod 2**8 */
3047   x *= 2 - b * x;               /* here x*a==1 mod 2**16 */
3048   x *= 2 - b * x;               /* here x*a==1 mod 2**32 */
3049
3050   /* rho = -1/m mod b */
3051   *rho = (((mp_word)1 << ((mp_word) DIGIT_BIT)) - x) & MP_MASK;
3052
3053   return MP_OKAY;
3054 }
3055
3056 /* high level multiplication (handles sign) */
3057 int mp_mul (const mp_int * a, const mp_int * b, mp_int * c)
3058 {
3059   int     res, neg;
3060   neg = (a->sign == b->sign) ? MP_ZPOS : MP_NEG;
3061
3062   /* use Karatsuba? */
3063   if (MIN (a->used, b->used) >= KARATSUBA_MUL_CUTOFF) {
3064     res = mp_karatsuba_mul (a, b, c);
3065   } else 
3066   {
3067     /* can we use the fast multiplier?
3068      *
3069      * The fast multiplier can be used if the output will 
3070      * have less than MP_WARRAY digits and the number of 
3071      * digits won't affect carry propagation
3072      */
3073     int     digs = a->used + b->used + 1;
3074
3075     if ((digs < MP_WARRAY) &&
3076         MIN(a->used, b->used) <= 
3077         (1 << ((CHAR_BIT * sizeof (mp_word)) - (2 * DIGIT_BIT)))) {
3078       res = fast_s_mp_mul_digs (a, b, c, digs);
3079     } else 
3080       res = s_mp_mul (a, b, c); /* uses s_mp_mul_digs */
3081   }
3082   c->sign = (c->used > 0) ? neg : MP_ZPOS;
3083   return res;
3084 }
3085
3086 /* d = a * b (mod c) */
3087 int
3088 mp_mulmod (const mp_int * a, const mp_int * b, mp_int * c, mp_int * d)
3089 {
3090   int     res;
3091   mp_int  t;
3092
3093   if ((res = mp_init (&t)) != MP_OKAY) {
3094     return res;
3095   }
3096
3097   if ((res = mp_mul (a, b, &t)) != MP_OKAY) {
3098     mp_clear (&t);
3099     return res;
3100   }
3101   res = mp_mod (&t, c, d);
3102   mp_clear (&t);
3103   return res;
3104 }
3105
3106 /* table of first PRIME_SIZE primes */
3107 static const mp_digit __prime_tab[] = {
3108   0x0002, 0x0003, 0x0005, 0x0007, 0x000B, 0x000D, 0x0011, 0x0013,
3109   0x0017, 0x001D, 0x001F, 0x0025, 0x0029, 0x002B, 0x002F, 0x0035,
3110   0x003B, 0x003D, 0x0043, 0x0047, 0x0049, 0x004F, 0x0053, 0x0059,
3111   0x0061, 0x0065, 0x0067, 0x006B, 0x006D, 0x0071, 0x007F, 0x0083,
3112   0x0089, 0x008B, 0x0095, 0x0097, 0x009D, 0x00A3, 0x00A7, 0x00AD,
3113   0x00B3, 0x00B5, 0x00BF, 0x00C1, 0x00C5, 0x00C7, 0x00D3, 0x00DF,
3114   0x00E3, 0x00E5, 0x00E9, 0x00EF, 0x00F1, 0x00FB, 0x0101, 0x0107,
3115   0x010D, 0x010F, 0x0115, 0x0119, 0x011B, 0x0125, 0x0133, 0x0137,
3116
3117   0x0139, 0x013D, 0x014B, 0x0151, 0x015B, 0x015D, 0x0161, 0x0167,
3118   0x016F, 0x0175, 0x017B, 0x017F, 0x0185, 0x018D, 0x0191, 0x0199,
3119   0x01A3, 0x01A5, 0x01AF, 0x01B1, 0x01B7, 0x01BB, 0x01C1, 0x01C9,
3120   0x01CD, 0x01CF, 0x01D3, 0x01DF, 0x01E7, 0x01EB, 0x01F3, 0x01F7,
3121   0x01FD, 0x0209, 0x020B, 0x021D, 0x0223, 0x022D, 0x0233, 0x0239,
3122   0x023B, 0x0241, 0x024B, 0x0251, 0x0257, 0x0259, 0x025F, 0x0265,
3123   0x0269, 0x026B, 0x0277, 0x0281, 0x0283, 0x0287, 0x028D, 0x0293,
3124   0x0295, 0x02A1, 0x02A5, 0x02AB, 0x02B3, 0x02BD, 0x02C5, 0x02CF,
3125
3126   0x02D7, 0x02DD, 0x02E3, 0x02E7, 0x02EF, 0x02F5, 0x02F9, 0x0301,
3127   0x0305, 0x0313, 0x031D, 0x0329, 0x032B, 0x0335, 0x0337, 0x033B,
3128   0x033D, 0x0347, 0x0355, 0x0359, 0x035B, 0x035F, 0x036D, 0x0371,
3129   0x0373, 0x0377, 0x038B, 0x038F, 0x0397, 0x03A1, 0x03A9, 0x03AD,
3130   0x03B3, 0x03B9, 0x03C7, 0x03CB, 0x03D1, 0x03D7, 0x03DF, 0x03E5,
3131   0x03F1, 0x03F5, 0x03FB, 0x03FD, 0x0407, 0x0409, 0x040F, 0x0419,
3132   0x041B, 0x0425, 0x0427, 0x042D, 0x043F, 0x0443, 0x0445, 0x0449,
3133   0x044F, 0x0455, 0x045D, 0x0463, 0x0469, 0x047F, 0x0481, 0x048B,
3134
3135   0x0493, 0x049D, 0x04A3, 0x04A9, 0x04B1, 0x04BD, 0x04C1, 0x04C7,
3136   0x04CD, 0x04CF, 0x04D5, 0x04E1, 0x04EB, 0x04FD, 0x04FF, 0x0503,
3137   0x0509, 0x050B, 0x0511, 0x0515, 0x0517, 0x051B, 0x0527, 0x0529,
3138   0x052F, 0x0551, 0x0557, 0x055D, 0x0565, 0x0577, 0x0581, 0x058F,
3139   0x0593, 0x0595, 0x0599, 0x059F, 0x05A7, 0x05AB, 0x05AD, 0x05B3,
3140   0x05BF, 0x05C9, 0x05CB, 0x05CF, 0x05D1, 0x05D5, 0x05DB, 0x05E7,
3141   0x05F3, 0x05FB, 0x0607, 0x060D, 0x0611, 0x0617, 0x061F, 0x0623,
3142   0x062B, 0x062F, 0x063D, 0x0641, 0x0647, 0x0649, 0x064D, 0x0653
3143 };
3144
3145 /* determines if an integers is divisible by one 
3146  * of the first PRIME_SIZE primes or not
3147  *
3148  * sets result to 0 if not, 1 if yes
3149  */
3150 static int mp_prime_is_divisible (const mp_int * a, int *result)
3151 {
3152   int     err, ix;
3153   mp_digit res;
3154
3155   /* default to not */
3156   *result = MP_NO;
3157
3158   for (ix = 0; ix < PRIME_SIZE; ix++) {
3159     /* what is a mod __prime_tab[ix] */
3160     if ((err = mp_mod_d (a, __prime_tab[ix], &res)) != MP_OKAY) {
3161       return err;
3162     }
3163
3164     /* is the residue zero? */
3165     if (res == 0) {
3166       *result = MP_YES;
3167       return MP_OKAY;
3168     }
3169   }
3170
3171   return MP_OKAY;
3172 }
3173
3174 /* Miller-Rabin test of "a" to the base of "b" as described in 
3175  * HAC pp. 139 Algorithm 4.24
3176  *
3177  * Sets result to 0 if definitely composite or 1 if probably prime.
3178  * Randomly the chance of error is no more than 1/4 and often 
3179  * very much lower.
3180  */
3181 static int mp_prime_miller_rabin (mp_int * a, const mp_int * b, int *result)
3182 {
3183   mp_int  n1, y, r;
3184   int     s, j, err;
3185
3186   /* default */
3187   *result = MP_NO;
3188
3189   /* ensure b > 1 */
3190   if (mp_cmp_d(b, 1) != MP_GT) {
3191      return MP_VAL;
3192   }     
3193
3194   /* get n1 = a - 1 */
3195   if ((err = mp_init_copy (&n1, a)) != MP_OKAY) {
3196     return err;
3197   }
3198   if ((err = mp_sub_d (&n1, 1, &n1)) != MP_OKAY) {
3199     goto __N1;
3200   }
3201
3202   /* set 2**s * r = n1 */
3203   if ((err = mp_init_copy (&r, &n1)) != MP_OKAY) {
3204     goto __N1;
3205   }
3206
3207   /* count the number of least significant bits
3208    * which are zero
3209    */
3210   s = mp_cnt_lsb(&r);
3211
3212   /* now divide n - 1 by 2**s */
3213   if ((err = mp_div_2d (&r, s, &r, NULL)) != MP_OKAY) {
3214     goto __R;
3215   }
3216
3217   /* compute y = b**r mod a */
3218   if ((err = mp_init (&y)) != MP_OKAY) {
3219     goto __R;
3220   }
3221   if ((err = mp_exptmod (b, &r, a, &y)) != MP_OKAY) {
3222     goto __Y;
3223   }
3224
3225   /* if y != 1 and y != n1 do */
3226   if (mp_cmp_d (&y, 1) != MP_EQ && mp_cmp (&y, &n1) != MP_EQ) {
3227     j = 1;
3228     /* while j <= s-1 and y != n1 */
3229     while ((j <= (s - 1)) && mp_cmp (&y, &n1) != MP_EQ) {
3230       if ((err = mp_sqrmod (&y, a, &y)) != MP_OKAY) {
3231          goto __Y;
3232       }
3233
3234       /* if y == 1 then composite */
3235       if (mp_cmp_d (&y, 1) == MP_EQ) {
3236          goto __Y;
3237       }
3238
3239       ++j;
3240     }
3241
3242     /* if y != n1 then composite */
3243     if (mp_cmp (&y, &n1) != MP_EQ) {
3244       goto __Y;
3245     }
3246   }
3247
3248   /* probably prime now */
3249   *result = MP_YES;
3250 __Y:mp_clear (&y);
3251 __R:mp_clear (&r);
3252 __N1:mp_clear (&n1);
3253   return err;
3254 }
3255
3256 /* performs a variable number of rounds of Miller-Rabin
3257  *
3258  * Probability of error after t rounds is no more than
3259
3260  *
3261  * Sets result to 1 if probably prime, 0 otherwise
3262  */
3263 static int mp_prime_is_prime (mp_int * a, int t, int *result)
3264 {
3265   mp_int  b;
3266   int     ix, err, res;
3267
3268   /* default to no */
3269   *result = MP_NO;
3270
3271   /* valid value of t? */
3272   if (t <= 0 || t > PRIME_SIZE) {
3273     return MP_VAL;
3274   }
3275
3276   /* is the input equal to one of the primes in the table? */
3277   for (ix = 0; ix < PRIME_SIZE; ix++) {
3278       if (mp_cmp_d(a, __prime_tab[ix]) == MP_EQ) {
3279          *result = 1;
3280          return MP_OKAY;
3281       }
3282   }
3283
3284   /* first perform trial division */
3285   if ((err = mp_prime_is_divisible (a, &res)) != MP_OKAY) {
3286     return err;
3287   }
3288
3289   /* return if it was trivially divisible */
3290   if (res == MP_YES) {
3291     return MP_OKAY;
3292   }
3293
3294   /* now perform the miller-rabin rounds */
3295   if ((err = mp_init (&b)) != MP_OKAY) {
3296     return err;
3297   }
3298
3299   for (ix = 0; ix < t; ix++) {
3300     /* set the prime */
3301     mp_set (&b, __prime_tab[ix]);
3302
3303     if ((err = mp_prime_miller_rabin (a, &b, &res)) != MP_OKAY) {
3304       goto __B;
3305     }
3306
3307     if (res == MP_NO) {
3308       goto __B;
3309     }
3310   }
3311
3312   /* passed the test */
3313   *result = MP_YES;
3314 __B:mp_clear (&b);
3315   return err;
3316 }
3317
3318 static const struct {
3319    int k, t;
3320 } sizes[] = {
3321 {   128,    28 },
3322 {   256,    16 },
3323 {   384,    10 },
3324 {   512,     7 },
3325 {   640,     6 },
3326 {   768,     5 },
3327 {   896,     4 },
3328 {  1024,     4 }
3329 };
3330
3331 /* returns # of RM trials required for a given bit size */
3332 int mp_prime_rabin_miller_trials(int size)
3333 {
3334    int x;
3335
3336    for (x = 0; x < (int)(sizeof(sizes)/(sizeof(sizes[0]))); x++) {
3337        if (sizes[x].k == size) {
3338           return sizes[x].t;
3339        } else if (sizes[x].k > size) {
3340           return (x == 0) ? sizes[0].t : sizes[x - 1].t;
3341        }
3342    }
3343    return sizes[x-1].t + 1;
3344 }
3345
3346 /* makes a truly random prime of a given size (bits),
3347  *
3348  * Flags are as follows:
3349  * 
3350  *   LTM_PRIME_BBS      - make prime congruent to 3 mod 4
3351  *   LTM_PRIME_SAFE     - make sure (p-1)/2 is prime as well (implies LTM_PRIME_BBS)
3352  *   LTM_PRIME_2MSB_OFF - make the 2nd highest bit zero
3353  *   LTM_PRIME_2MSB_ON  - make the 2nd highest bit one
3354  *
3355  * You have to supply a callback which fills in a buffer with random bytes.  "dat" is a parameter you can
3356  * have passed to the callback (e.g. a state or something).  This function doesn't use "dat" itself
3357  * so it can be NULL
3358  *
3359  */
3360
3361 /* This is possibly the mother of all prime generation functions, muahahahahaha! */
3362 int mp_prime_random_ex(mp_int *a, int t, int size, int flags, ltm_prime_callback cb, void *dat)
3363 {
3364    unsigned char *tmp, maskAND, maskOR_msb, maskOR_lsb;
3365    int res, err, bsize, maskOR_msb_offset;
3366
3367    /* sanity check the input */
3368    if (size <= 1 || t <= 0) {
3369       return MP_VAL;
3370    }
3371
3372    /* LTM_PRIME_SAFE implies LTM_PRIME_BBS */
3373    if (flags & LTM_PRIME_SAFE) {
3374       flags |= LTM_PRIME_BBS;
3375    }
3376
3377    /* calc the byte size */
3378    bsize = (size>>3)+((size&7)?1:0);
3379
3380    /* we need a buffer of bsize bytes */
3381    tmp = malloc(bsize);
3382    if (tmp == NULL) {
3383       return MP_MEM;
3384    }
3385
3386    /* calc the maskAND value for the MSbyte*/
3387    maskAND = ((size&7) == 0) ? 0xFF : (0xFF >> (8 - (size & 7))); 
3388
3389    /* calc the maskOR_msb */
3390    maskOR_msb        = 0;
3391    maskOR_msb_offset = ((size & 7) == 1) ? 1 : 0;
3392    if (flags & LTM_PRIME_2MSB_ON) {
3393       maskOR_msb     |= 1 << ((size - 2) & 7);
3394    } else if (flags & LTM_PRIME_2MSB_OFF) {
3395       maskAND        &= ~(1 << ((size - 2) & 7));
3396    }
3397
3398    /* get the maskOR_lsb */
3399    maskOR_lsb         = 0;
3400    if (flags & LTM_PRIME_BBS) {
3401       maskOR_lsb     |= 3;
3402    }
3403
3404    do {
3405       /* read the bytes */
3406       if (cb(tmp, bsize, dat) != bsize) {
3407          err = MP_VAL;
3408          goto error;
3409       }
3410  
3411       /* work over the MSbyte */
3412       tmp[0]    &= maskAND;
3413       tmp[0]    |= 1 << ((size - 1) & 7);
3414
3415       /* mix in the maskORs */
3416       tmp[maskOR_msb_offset]   |= maskOR_msb;
3417       tmp[bsize-1]             |= maskOR_lsb;
3418
3419       /* read it in */
3420       if ((err = mp_read_unsigned_bin(a, tmp, bsize)) != MP_OKAY)     { goto error; }
3421
3422       /* is it prime? */
3423       if ((err = mp_prime_is_prime(a, t, &res)) != MP_OKAY)           { goto error; }
3424       if (res == MP_NO) {  
3425          continue;
3426       }
3427
3428       if (flags & LTM_PRIME_SAFE) {
3429          /* see if (a-1)/2 is prime */
3430          if ((err = mp_sub_d(a, 1, a)) != MP_OKAY)                    { goto error; }
3431          if ((err = mp_div_2(a, a)) != MP_OKAY)                       { goto error; }
3432  
3433          /* is it prime? */
3434          if ((err = mp_prime_is_prime(a, t, &res)) != MP_OKAY)        { goto error; }
3435       }
3436    } while (res == MP_NO);
3437
3438    if (flags & LTM_PRIME_SAFE) {
3439       /* restore a to the original value */
3440       if ((err = mp_mul_2(a, a)) != MP_OKAY)                          { goto error; }
3441       if ((err = mp_add_d(a, 1, a)) != MP_OKAY)                       { goto error; }
3442    }
3443
3444    err = MP_OKAY;
3445 error:
3446    free(tmp);
3447    return err;
3448 }
3449
3450 /* reads an unsigned char array, assumes the msb is stored first [big endian] */
3451 int
3452 mp_read_unsigned_bin (mp_int * a, const unsigned char *b, int c)
3453 {
3454   int     res;
3455
3456   /* make sure there are at least two digits */
3457   if (a->alloc < 2) {
3458      if ((res = mp_grow(a, 2)) != MP_OKAY) {
3459         return res;
3460      }
3461   }
3462
3463   /* zero the int */
3464   mp_zero (a);
3465
3466   /* read the bytes in */
3467   while (c-- > 0) {
3468     if ((res = mp_mul_2d (a, 8, a)) != MP_OKAY) {
3469       return res;
3470     }
3471
3472       a->dp[0] |= *b++;
3473       a->used += 1;
3474   }
3475   mp_clamp (a);
3476   return MP_OKAY;
3477 }
3478
3479 /* reduces x mod m, assumes 0 < x < m**2, mu is 
3480  * precomputed via mp_reduce_setup.
3481  * From HAC pp.604 Algorithm 14.42
3482  */
3483 int
3484 mp_reduce (mp_int * x, const mp_int * m, const mp_int * mu)
3485 {
3486   mp_int  q;
3487   int     res, um = m->used;
3488
3489   /* q = x */
3490   if ((res = mp_init_copy (&q, x)) != MP_OKAY) {
3491     return res;
3492   }
3493
3494   /* q1 = x / b**(k-1)  */
3495   mp_rshd (&q, um - 1);         
3496
3497   /* according to HAC this optimization is ok */
3498   if (((unsigned long) um) > (((mp_digit)1) << (DIGIT_BIT - 1))) {
3499     if ((res = mp_mul (&q, mu, &q)) != MP_OKAY) {
3500       goto CLEANUP;
3501     }
3502   } else {
3503     if ((res = s_mp_mul_high_digs (&q, mu, &q, um - 1)) != MP_OKAY) {
3504       goto CLEANUP;
3505     }
3506   }
3507
3508   /* q3 = q2 / b**(k+1) */
3509   mp_rshd (&q, um + 1);         
3510
3511   /* x = x mod b**(k+1), quick (no division) */
3512   if ((res = mp_mod_2d (x, DIGIT_BIT * (um + 1), x)) != MP_OKAY) {
3513     goto CLEANUP;
3514   }
3515
3516   /* q = q * m mod b**(k+1), quick (no division) */
3517   if ((res = s_mp_mul_digs (&q, m, &q, um + 1)) != MP_OKAY) {
3518     goto CLEANUP;
3519   }
3520
3521   /* x = x - q */
3522   if ((res = mp_sub (x, &q, x)) != MP_OKAY) {
3523     goto CLEANUP;
3524   }
3525
3526   /* If x < 0, add b**(k+1) to it */
3527   if (mp_cmp_d (x, 0) == MP_LT) {
3528     mp_set (&q, 1);
3529     if ((res = mp_lshd (&q, um + 1)) != MP_OKAY)
3530       goto CLEANUP;
3531     if ((res = mp_add (x, &q, x)) != MP_OKAY)
3532       goto CLEANUP;
3533   }
3534
3535   /* Back off if it's too big */
3536   while (mp_cmp (x, m) != MP_LT) {
3537     if ((res = s_mp_sub (x, m, x)) != MP_OKAY) {
3538       goto CLEANUP;
3539     }
3540   }
3541   
3542 CLEANUP:
3543   mp_clear (&q);
3544
3545   return res;
3546 }
3547
3548 /* reduces a modulo n where n is of the form 2**p - d */
3549 int
3550 mp_reduce_2k(mp_int *a, const mp_int *n, mp_digit d)
3551 {
3552    mp_int q;
3553    int    p, res;
3554    
3555    if ((res = mp_init(&q)) != MP_OKAY) {
3556       return res;
3557    }
3558    
3559    p = mp_count_bits(n);    
3560 top:
3561    /* q = a/2**p, a = a mod 2**p */
3562    if ((res = mp_div_2d(a, p, &q, a)) != MP_OKAY) {
3563       goto ERR;
3564    }
3565    
3566    if (d != 1) {
3567       /* q = q * d */
3568       if ((res = mp_mul_d(&q, d, &q)) != MP_OKAY) { 
3569          goto ERR;
3570       }
3571    }
3572    
3573    /* a = a + q */
3574    if ((res = s_mp_add(a, &q, a)) != MP_OKAY) {
3575       goto ERR;
3576    }
3577    
3578    if (mp_cmp_mag(a, n) != MP_LT) {
3579       s_mp_sub(a, n, a);
3580       goto top;
3581    }
3582    
3583 ERR:
3584    mp_clear(&q);
3585    return res;
3586 }
3587
3588 /* determines the setup value */
3589 int 
3590 mp_reduce_2k_setup(const mp_int *a, mp_digit *d)
3591 {
3592    int res, p;
3593    mp_int tmp;
3594    
3595    if ((res = mp_init(&tmp)) != MP_OKAY) {
3596       return res;
3597    }
3598    
3599    p = mp_count_bits(a);
3600    if ((res = mp_2expt(&tmp, p)) != MP_OKAY) {
3601       mp_clear(&tmp);
3602       return res;
3603    }
3604    
3605    if ((res = s_mp_sub(&tmp, a, &tmp)) != MP_OKAY) {
3606       mp_clear(&tmp);
3607       return res;
3608    }
3609    
3610    *d = tmp.dp[0];
3611    mp_clear(&tmp);
3612    return MP_OKAY;
3613 }
3614
3615 /* pre-calculate the value required for Barrett reduction
3616  * For a given modulus "b" it calulates the value required in "a"
3617  */
3618 int mp_reduce_setup (mp_int * a, const mp_int * b)
3619 {
3620   int     res;
3621
3622   if ((res = mp_2expt (a, b->used * 2 * DIGIT_BIT)) != MP_OKAY) {
3623     return res;
3624   }
3625   return mp_div (a, b, a, NULL);
3626 }
3627
3628 /* set to a digit */
3629 void mp_set (mp_int * a, mp_digit b)
3630 {
3631   mp_zero (a);
3632   a->dp[0] = b & MP_MASK;
3633   a->used  = (a->dp[0] != 0) ? 1 : 0;
3634 }
3635
3636 /* set a 32-bit const */
3637 int mp_set_int (mp_int * a, unsigned long b)
3638 {
3639   int     x, res;
3640
3641   mp_zero (a);
3642   
3643   /* set four bits at a time */
3644   for (x = 0; x < 8; x++) {
3645     /* shift the number up four bits */
3646     if ((res = mp_mul_2d (a, 4, a)) != MP_OKAY) {
3647       return res;
3648     }
3649
3650     /* OR in the top four bits of the source */
3651     a->dp[0] |= (b >> 28) & 15;
3652
3653     /* shift the source up to the next four bits */
3654     b <<= 4;
3655
3656     /* ensure that digits are not clamped off */
3657     a->used += 1;
3658   }
3659   mp_clamp (a);
3660   return MP_OKAY;
3661 }
3662
3663 /* shrink a bignum */
3664 int mp_shrink (mp_int * a)
3665 {
3666   mp_digit *tmp;
3667   if (a->alloc != a->used && a->used > 0) {
3668     if ((tmp = realloc (a->dp, sizeof (mp_digit) * a->used)) == NULL) {
3669       return MP_MEM;
3670     }
3671     a->dp    = tmp;
3672     a->alloc = a->used;
3673   }
3674   return MP_OKAY;
3675 }
3676
3677 /* get the size for an signed equivalent */
3678 int mp_signed_bin_size (const mp_int * a)
3679 {
3680   return 1 + mp_unsigned_bin_size (a);
3681 }
3682
3683 /* computes b = a*a */
3684 int
3685 mp_sqr (const mp_int * a, mp_int * b)
3686 {
3687   int     res;
3688
3689 if (a->used >= KARATSUBA_SQR_CUTOFF) {
3690     res = mp_karatsuba_sqr (a, b);
3691   } else 
3692   {
3693     /* can we use the fast comba multiplier? */
3694     if ((a->used * 2 + 1) < MP_WARRAY && 
3695          a->used < 
3696          (1 << (sizeof(mp_word) * CHAR_BIT - 2*DIGIT_BIT - 1))) {
3697       res = fast_s_mp_sqr (a, b);
3698     } else
3699       res = s_mp_sqr (a, b);
3700   }
3701   b->sign = MP_ZPOS;
3702   return res;
3703 }
3704
3705 /* c = a * a (mod b) */
3706 int
3707 mp_sqrmod (const mp_int * a, mp_int * b, mp_int * c)
3708 {
3709   int     res;
3710   mp_int  t;
3711
3712   if ((res = mp_init (&t)) != MP_OKAY) {
3713     return res;
3714   }
3715
3716   if ((res = mp_sqr (a, &t)) != MP_OKAY) {
3717     mp_clear (&t);
3718     return res;
3719   }
3720   res = mp_mod (&t, b, c);
3721   mp_clear (&t);
3722   return res;
3723 }
3724
3725 /* high level subtraction (handles signs) */
3726 int
3727 mp_sub (mp_int * a, mp_int * b, mp_int * c)
3728 {
3729   int     sa, sb, res;
3730
3731   sa = a->sign;
3732   sb = b->sign;
3733
3734   if (sa != sb) {
3735     /* subtract a negative from a positive, OR */
3736     /* subtract a positive from a negative. */
3737     /* In either case, ADD their magnitudes, */
3738     /* and use the sign of the first number. */
3739     c->sign = sa;
3740     res = s_mp_add (a, b, c);
3741   } else {
3742     /* subtract a positive from a positive, OR */
3743     /* subtract a negative from a negative. */
3744     /* First, take the difference between their */
3745     /* magnitudes, then... */
3746     if (mp_cmp_mag (a, b) != MP_LT) {
3747       /* Copy the sign from the first */
3748       c->sign = sa;
3749       /* The first has a larger or equal magnitude */
3750       res = s_mp_sub (a, b, c);
3751     } else {
3752       /* The result has the *opposite* sign from */
3753       /* the first number. */
3754       c->sign = (sa == MP_ZPOS) ? MP_NEG : MP_ZPOS;
3755       /* The second has a larger magnitude */
3756       res = s_mp_sub (b, a, c);
3757     }
3758   }
3759   return res;
3760 }
3761
3762 /* single digit subtraction */
3763 int
3764 mp_sub_d (mp_int * a, mp_digit b, mp_int * c)
3765 {
3766   mp_digit *tmpa, *tmpc, mu;
3767   int       res, ix, oldused;
3768
3769   /* grow c as required */
3770   if (c->alloc < a->used + 1) {
3771      if ((res = mp_grow(c, a->used + 1)) != MP_OKAY) {
3772         return res;
3773      }
3774   }
3775
3776   /* if a is negative just do an unsigned
3777    * addition [with fudged signs]
3778    */
3779   if (a->sign == MP_NEG) {
3780      a->sign = MP_ZPOS;
3781      res     = mp_add_d(a, b, c);
3782      a->sign = c->sign = MP_NEG;
3783      return res;
3784   }
3785
3786   /* setup regs */
3787   oldused = c->used;
3788   tmpa    = a->dp;
3789   tmpc    = c->dp;
3790
3791   /* if a <= b simply fix the single digit */
3792   if ((a->used == 1 && a->dp[0] <= b) || a->used == 0) {
3793      if (a->used == 1) {
3794         *tmpc++ = b - *tmpa;
3795      } else {
3796         *tmpc++ = b;
3797      }
3798      ix      = 1;
3799
3800      /* negative/1digit */
3801      c->sign = MP_NEG;
3802      c->used = 1;
3803   } else {
3804      /* positive/size */
3805      c->sign = MP_ZPOS;
3806      c->used = a->used;
3807
3808      /* subtract first digit */
3809      *tmpc    = *tmpa++ - b;
3810      mu       = *tmpc >> (sizeof(mp_digit) * CHAR_BIT - 1);
3811      *tmpc++ &= MP_MASK;
3812
3813      /* handle rest of the digits */
3814      for (ix = 1; ix < a->used; ix++) {
3815         *tmpc    = *tmpa++ - mu;
3816         mu       = *tmpc >> (sizeof(mp_digit) * CHAR_BIT - 1);
3817         *tmpc++ &= MP_MASK;
3818      }
3819   }
3820
3821   /* zero excess digits */
3822   while (ix++ < oldused) {
3823      *tmpc++ = 0;
3824   }
3825   mp_clamp(c);
3826   return MP_OKAY;
3827 }
3828
3829 /* store in unsigned [big endian] format */
3830 int
3831 mp_to_unsigned_bin (const mp_int * a, unsigned char *b)
3832 {
3833   int     x, res;
3834   mp_int  t;
3835
3836   if ((res = mp_init_copy (&t, a)) != MP_OKAY) {
3837     return res;
3838   }
3839
3840   x = 0;
3841   while (mp_iszero (&t) == 0) {
3842     b[x++] = (unsigned char) (t.dp[0] & 255);
3843     if ((res = mp_div_2d (&t, 8, &t, NULL)) != MP_OKAY) {
3844       mp_clear (&t);
3845       return res;
3846     }
3847   }
3848   bn_reverse (b, x);
3849   mp_clear (&t);
3850   return MP_OKAY;
3851 }
3852
3853 /* get the size for an unsigned equivalent */
3854 int
3855 mp_unsigned_bin_size (const mp_int * a)
3856 {
3857   int     size = mp_count_bits (a);
3858   return (size / 8 + ((size & 7) != 0 ? 1 : 0));
3859 }
3860
3861 /* reverse an array, used for radix code */
3862 static void
3863 bn_reverse (unsigned char *s, int len)
3864 {
3865   int     ix, iy;
3866   unsigned char t;
3867
3868   ix = 0;
3869   iy = len - 1;
3870   while (ix < iy) {
3871     t     = s[ix];
3872     s[ix] = s[iy];
3873     s[iy] = t;
3874     ++ix;
3875     --iy;
3876   }
3877 }
3878
3879 /* low level addition, based on HAC pp.594, Algorithm 14.7 */
3880 static int
3881 s_mp_add (mp_int * a, mp_int * b, mp_int * c)
3882 {
3883   mp_int *x;
3884   int     olduse, res, min, max;
3885
3886   /* find sizes, we let |a| <= |b| which means we have to sort
3887    * them.  "x" will point to the input with the most digits
3888    */
3889   if (a->used > b->used) {
3890     min = b->used;
3891     max = a->used;
3892     x = a;
3893   } else {
3894     min = a->used;
3895     max = b->used;
3896     x = b;
3897   }
3898
3899   /* init result */
3900   if (c->alloc < max + 1) {
3901     if ((res = mp_grow (c, max + 1)) != MP_OKAY) {
3902       return res;
3903     }
3904   }
3905
3906   /* get old used digit count and set new one */
3907   olduse = c->used;
3908   c->used = max + 1;
3909
3910   {
3911     register mp_digit u, *tmpa, *tmpb, *tmpc;
3912     register int i;
3913
3914     /* alias for digit pointers */
3915
3916     /* first input */
3917     tmpa = a->dp;
3918
3919     /* second input */
3920     tmpb = b->dp;
3921
3922     /* destination */
3923     tmpc = c->dp;
3924
3925     /* zero the carry */
3926     u = 0;
3927     for (i = 0; i < min; i++) {
3928       /* Compute the sum at one digit, T[i] = A[i] + B[i] + U */
3929       *tmpc = *tmpa++ + *tmpb++ + u;
3930
3931       /* U = carry bit of T[i] */
3932       u = *tmpc >> ((mp_digit)DIGIT_BIT);
3933
3934       /* take away carry bit from T[i] */
3935       *tmpc++ &= MP_MASK;
3936     }
3937
3938     /* now copy higher words if any, that is in A+B 
3939      * if A or B has more digits add those in 
3940      */
3941     if (min != max) {
3942       for (; i < max; i++) {
3943         /* T[i] = X[i] + U */
3944         *tmpc = x->dp[i] + u;
3945
3946         /* U = carry bit of T[i] */
3947         u = *tmpc >> ((mp_digit)DIGIT_BIT);
3948
3949         /* take away carry bit from T[i] */
3950         *tmpc++ &= MP_MASK;
3951       }
3952     }
3953
3954     /* add carry */
3955     *tmpc++ = u;
3956
3957     /* clear digits above oldused */
3958     for (i = c->used; i < olduse; i++) {
3959       *tmpc++ = 0;
3960     }
3961   }
3962
3963   mp_clamp (c);
3964   return MP_OKAY;
3965 }
3966
3967 static int s_mp_exptmod (const mp_int * G, const mp_int * X, mp_int * P, mp_int * Y)
3968 {
3969   mp_int  M[256], res, mu;
3970   mp_digit buf;
3971   int     err, bitbuf, bitcpy, bitcnt, mode, digidx, x, y, winsize;
3972
3973   /* find window size */
3974   x = mp_count_bits (X);
3975   if (x <= 7) {
3976     winsize = 2;
3977   } else if (x <= 36) {
3978     winsize = 3;
3979   } else if (x <= 140) {
3980     winsize = 4;
3981   } else if (x <= 450) {
3982     winsize = 5;
3983   } else if (x <= 1303) {
3984     winsize = 6;
3985   } else if (x <= 3529) {
3986     winsize = 7;
3987   } else {
3988     winsize = 8;
3989   }
3990
3991   /* init M array */
3992   /* init first cell */
3993   if ((err = mp_init(&M[1])) != MP_OKAY) {
3994      return err; 
3995   }
3996
3997   /* now init the second half of the array */
3998   for (x = 1<<(winsize-1); x < (1 << winsize); x++) {
3999     if ((err = mp_init(&M[x])) != MP_OKAY) {
4000       for (y = 1<<(winsize-1); y < x; y++) {
4001         mp_clear (&M[y]);
4002       }
4003       mp_clear(&M[1]);
4004       return err;
4005     }
4006   }
4007
4008   /* create mu, used for Barrett reduction */
4009   if ((err = mp_init (&mu)) != MP_OKAY) {
4010     goto __M;
4011   }
4012   if ((err = mp_reduce_setup (&mu, P)) != MP_OKAY) {
4013     goto __MU;
4014   }
4015
4016   /* create M table
4017    *
4018    * The M table contains powers of the base, 
4019    * e.g. M[x] = G**x mod P
4020    *
4021    * The first half of the table is not 
4022    * computed though accept for M[0] and M[1]
4023    */
4024   if ((err = mp_mod (G, P, &M[1])) != MP_OKAY) {
4025     goto __MU;
4026   }
4027
4028   /* compute the value at M[1<<(winsize-1)] by squaring 
4029    * M[1] (winsize-1) times 
4030    */
4031   if ((err = mp_copy (&M[1], &M[1 << (winsize - 1)])) != MP_OKAY) {
4032     goto __MU;
4033   }
4034
4035   for (x = 0; x < (winsize - 1); x++) {
4036     if ((err = mp_sqr (&M[1 << (winsize - 1)], 
4037                        &M[1 << (winsize - 1)])) != MP_OKAY) {
4038       goto __MU;
4039     }
4040     if ((err = mp_reduce (&M[1 << (winsize - 1)], P, &mu)) != MP_OKAY) {
4041       goto __MU;
4042     }
4043   }
4044
4045   /* create upper table, that is M[x] = M[x-1] * M[1] (mod P)
4046    * for x = (2**(winsize - 1) + 1) to (2**winsize - 1)
4047    */
4048   for (x = (1 << (winsize - 1)) + 1; x < (1 << winsize); x++) {
4049     if ((err = mp_mul (&M[x - 1], &M[1], &M[x])) != MP_OKAY) {
4050       goto __MU;
4051     }
4052     if ((err = mp_reduce (&M[x], P, &mu)) != MP_OKAY) {
4053       goto __MU;
4054     }
4055   }
4056
4057   /* setup result */
4058   if ((err = mp_init (&res)) != MP_OKAY) {
4059     goto __MU;
4060   }
4061   mp_set (&res, 1);
4062
4063   /* set initial mode and bit cnt */
4064   mode   = 0;
4065   bitcnt = 1;
4066   buf    = 0;
4067   digidx = X->used - 1;
4068   bitcpy = 0;
4069   bitbuf = 0;
4070
4071   for (;;) {
4072     /* grab next digit as required */
4073     if (--bitcnt == 0) {
4074       /* if digidx == -1 we are out of digits */
4075       if (digidx == -1) {
4076         break;
4077       }
4078       /* read next digit and reset the bitcnt */
4079       buf    = X->dp[digidx--];
4080       bitcnt = DIGIT_BIT;
4081     }
4082
4083     /* grab the next msb from the exponent */
4084     y     = (buf >> (mp_digit)(DIGIT_BIT - 1)) & 1;
4085     buf <<= (mp_digit)1;
4086
4087     /* if the bit is zero and mode == 0 then we ignore it
4088      * These represent the leading zero bits before the first 1 bit
4089      * in the exponent.  Technically this opt is not required but it
4090      * does lower the # of trivial squaring/reductions used
4091      */
4092     if (mode == 0 && y == 0) {
4093       continue;
4094     }
4095
4096     /* if the bit is zero and mode == 1 then we square */
4097     if (mode == 1 && y == 0) {
4098       if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
4099         goto __RES;
4100       }
4101       if ((err = mp_reduce (&res, P, &mu)) != MP_OKAY) {
4102         goto __RES;
4103       }
4104       continue;
4105     }
4106
4107     /* else we add it to the window */
4108     bitbuf |= (y << (winsize - ++bitcpy));
4109     mode    = 2;
4110
4111     if (bitcpy == winsize) {
4112       /* ok window is filled so square as required and multiply  */
4113       /* square first */
4114       for (x = 0; x < winsize; x++) {
4115         if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
4116           goto __RES;
4117         }
4118         if ((err = mp_reduce (&res, P, &mu)) != MP_OKAY) {
4119           goto __RES;
4120         }
4121       }
4122
4123       /* then multiply */
4124       if ((err = mp_mul (&res, &M[bitbuf], &res)) != MP_OKAY) {
4125         goto __RES;
4126       }
4127       if ((err = mp_reduce (&res, P, &mu)) != MP_OKAY) {
4128         goto __RES;
4129       }
4130
4131       /* empty window and reset */
4132       bitcpy = 0;
4133       bitbuf = 0;
4134       mode   = 1;
4135     }
4136   }
4137
4138   /* if bits remain then square/multiply */
4139   if (mode == 2 && bitcpy > 0) {
4140     /* square then multiply if the bit is set */
4141     for (x = 0; x < bitcpy; x++) {
4142       if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
4143         goto __RES;
4144       }
4145       if ((err = mp_reduce (&res, P, &mu)) != MP_OKAY) {
4146         goto __RES;
4147       }
4148
4149       bitbuf <<= 1;
4150       if ((bitbuf & (1 << winsize)) != 0) {
4151         /* then multiply */
4152         if ((err = mp_mul (&res, &M[1], &res)) != MP_OKAY) {
4153           goto __RES;
4154         }
4155         if ((err = mp_reduce (&res, P, &mu)) != MP_OKAY) {
4156           goto __RES;
4157         }
4158       }
4159     }
4160   }
4161
4162   mp_exch (&res, Y);
4163   err = MP_OKAY;
4164 __RES:mp_clear (&res);
4165 __MU:mp_clear (&mu);
4166 __M:
4167   mp_clear(&M[1]);
4168   for (x = 1<<(winsize-1); x < (1 << winsize); x++) {
4169     mp_clear (&M[x]);
4170   }
4171   return err;
4172 }
4173
4174 /* multiplies |a| * |b| and only computes up to digs digits of result
4175  * HAC pp. 595, Algorithm 14.12  Modified so you can control how 
4176  * many digits of output are created.
4177  */
4178 static int
4179 s_mp_mul_digs (const mp_int * a, const mp_int * b, mp_int * c, int digs)
4180 {
4181   mp_int  t;
4182   int     res, pa, pb, ix, iy;
4183   mp_digit u;
4184   mp_word r;
4185   mp_digit tmpx, *tmpt, *tmpy;
4186
4187   /* can we use the fast multiplier? */
4188   if (((digs) < MP_WARRAY) &&
4189       MIN (a->used, b->used) < 
4190           (1 << ((CHAR_BIT * sizeof (mp_word)) - (2 * DIGIT_BIT)))) {
4191     return fast_s_mp_mul_digs (a, b, c, digs);
4192   }
4193
4194   if ((res = mp_init_size (&t, digs)) != MP_OKAY) {
4195     return res;
4196   }
4197   t.used = digs;
4198
4199   /* compute the digits of the product directly */
4200   pa = a->used;
4201   for (ix = 0; ix < pa; ix++) {
4202     /* set the carry to zero */
4203     u = 0;
4204
4205     /* limit ourselves to making digs digits of output */
4206     pb = MIN (b->used, digs - ix);
4207
4208     /* setup some aliases */
4209     /* copy of the digit from a used within the nested loop */
4210     tmpx = a->dp[ix];
4211     
4212     /* an alias for the destination shifted ix places */
4213     tmpt = t.dp + ix;
4214     
4215     /* an alias for the digits of b */
4216     tmpy = b->dp;
4217
4218     /* compute the columns of the output and propagate the carry */
4219     for (iy = 0; iy < pb; iy++) {
4220       /* compute the column as a mp_word */
4221       r       = ((mp_word)*tmpt) +
4222                 ((mp_word)tmpx) * ((mp_word)*tmpy++) +
4223                 ((mp_word) u);
4224
4225       /* the new column is the lower part of the result */
4226       *tmpt++ = (mp_digit) (r & ((mp_word) MP_MASK));
4227
4228       /* get the carry word from the result */
4229       u       = (mp_digit) (r >> ((mp_word) DIGIT_BIT));
4230     }
4231     /* set carry if it is placed below digs */
4232     if (ix + iy < digs) {
4233       *tmpt = u;
4234     }
4235   }
4236
4237   mp_clamp (&t);
4238   mp_exch (&t, c);
4239
4240   mp_clear (&t);
4241   return MP_OKAY;
4242 }
4243
4244 /* multiplies |a| * |b| and does not compute the lower digs digits
4245  * [meant to get the higher part of the product]
4246  */
4247 static int
4248 s_mp_mul_high_digs (const mp_int * a, const mp_int * b, mp_int * c, int digs)
4249 {
4250   mp_int  t;
4251   int     res, pa, pb, ix, iy;
4252   mp_digit u;
4253   mp_word r;
4254   mp_digit tmpx, *tmpt, *tmpy;
4255
4256   /* can we use the fast multiplier? */
4257   if (((a->used + b->used + 1) < MP_WARRAY)
4258       && MIN (a->used, b->used) < (1 << ((CHAR_BIT * sizeof (mp_word)) - (2 * DIGIT_BIT)))) {
4259     return fast_s_mp_mul_high_digs (a, b, c, digs);
4260   }
4261
4262   if ((res = mp_init_size (&t, a->used + b->used + 1)) != MP_OKAY) {
4263     return res;
4264   }
4265   t.used = a->used + b->used + 1;
4266
4267   pa = a->used;
4268   pb = b->used;
4269   for (ix = 0; ix < pa; ix++) {
4270     /* clear the carry */
4271     u = 0;
4272
4273     /* left hand side of A[ix] * B[iy] */
4274     tmpx = a->dp[ix];
4275
4276     /* alias to the address of where the digits will be stored */
4277     tmpt = &(t.dp[digs]);
4278
4279     /* alias for where to read the right hand side from */
4280     tmpy = b->dp + (digs - ix);
4281
4282     for (iy = digs - ix; iy < pb; iy++) {
4283       /* calculate the double precision result */
4284       r       = ((mp_word)*tmpt) +
4285                 ((mp_word)tmpx) * ((mp_word)*tmpy++) +
4286                 ((mp_word) u);
4287
4288       /* get the lower part */
4289       *tmpt++ = (mp_digit) (r & ((mp_word) MP_MASK));
4290
4291       /* carry the carry */
4292       u       = (mp_digit) (r >> ((mp_word) DIGIT_BIT));
4293     }
4294     *tmpt = u;
4295   }
4296   mp_clamp (&t);
4297   mp_exch (&t, c);
4298   mp_clear (&t);
4299   return MP_OKAY;
4300 }
4301
4302 /* low level squaring, b = a*a, HAC pp.596-597, Algorithm 14.16 */
4303 static int
4304 s_mp_sqr (const mp_int * a, mp_int * b)
4305 {
4306   mp_int  t;
4307   int     res, ix, iy, pa;
4308   mp_word r;
4309   mp_digit u, tmpx, *tmpt;
4310
4311   pa = a->used;
4312   if ((res = mp_init_size (&t, 2*pa + 1)) != MP_OKAY) {
4313     return res;
4314   }
4315
4316   /* default used is maximum possible size */
4317   t.used = 2*pa + 1;
4318
4319   for (ix = 0; ix < pa; ix++) {
4320     /* first calculate the digit at 2*ix */
4321     /* calculate double precision result */
4322     r = ((mp_word) t.dp[2*ix]) +
4323         ((mp_word)a->dp[ix])*((mp_word)a->dp[ix]);
4324
4325     /* store lower part in result */
4326     t.dp[ix+ix] = (mp_digit) (r & ((mp_word) MP_MASK));
4327
4328     /* get the carry */
4329     u           = (mp_digit)(r >> ((mp_word) DIGIT_BIT));
4330
4331     /* left hand side of A[ix] * A[iy] */
4332     tmpx        = a->dp[ix];
4333
4334     /* alias for where to store the results */
4335     tmpt        = t.dp + (2*ix + 1);
4336     
4337     for (iy = ix + 1; iy < pa; iy++) {
4338       /* first calculate the product */
4339       r       = ((mp_word)tmpx) * ((mp_word)a->dp[iy]);
4340
4341       /* now calculate the double precision result, note we use
4342        * addition instead of *2 since it's easier to optimize
4343        */
4344       r       = ((mp_word) *tmpt) + r + r + ((mp_word) u);
4345
4346       /* store lower part */
4347       *tmpt++ = (mp_digit) (r & ((mp_word) MP_MASK));
4348
4349       /* get carry */
4350       u       = (mp_digit)(r >> ((mp_word) DIGIT_BIT));
4351     }
4352     /* propagate upwards */
4353     while (u != 0) {
4354       r       = ((mp_word) *tmpt) + ((mp_word) u);
4355       *tmpt++ = (mp_digit) (r & ((mp_word) MP_MASK));
4356       u       = (mp_digit)(r >> ((mp_word) DIGIT_BIT));
4357     }
4358   }
4359
4360   mp_clamp (&t);
4361   mp_exch (&t, b);
4362   mp_clear (&t);
4363   return MP_OKAY;
4364 }
4365
4366 /* low level subtraction (assumes |a| > |b|), HAC pp.595 Algorithm 14.9 */
4367 int
4368 s_mp_sub (const mp_int * a, const mp_int * b, mp_int * c)
4369 {
4370   int     olduse, res, min, max;
4371
4372   /* find sizes */
4373   min = b->used;
4374   max = a->used;
4375
4376   /* init result */
4377   if (c->alloc < max) {
4378     if ((res = mp_grow (c, max)) != MP_OKAY) {
4379       return res;
4380     }
4381   }
4382   olduse = c->used;
4383   c->used = max;
4384
4385   {
4386     register mp_digit u, *tmpa, *tmpb, *tmpc;
4387     register int i;
4388
4389     /* alias for digit pointers */
4390     tmpa = a->dp;
4391     tmpb = b->dp;
4392     tmpc = c->dp;
4393
4394     /* set carry to zero */
4395     u = 0;
4396     for (i = 0; i < min; i++) {
4397       /* T[i] = A[i] - B[i] - U */
4398       *tmpc = *tmpa++ - *tmpb++ - u;
4399
4400       /* U = carry bit of T[i]
4401        * Note this saves performing an AND operation since
4402        * if a carry does occur it will propagate all the way to the
4403        * MSB.  As a result a single shift is enough to get the carry
4404        */
4405       u = *tmpc >> ((mp_digit)(CHAR_BIT * sizeof (mp_digit) - 1));
4406
4407       /* Clear carry from T[i] */
4408       *tmpc++ &= MP_MASK;
4409     }
4410
4411     /* now copy higher words if any, e.g. if A has more digits than B  */
4412     for (; i < max; i++) {
4413       /* T[i] = A[i] - U */
4414       *tmpc = *tmpa++ - u;
4415
4416       /* U = carry bit of T[i] */
4417       u = *tmpc >> ((mp_digit)(CHAR_BIT * sizeof (mp_digit) - 1));
4418
4419       /* Clear carry from T[i] */
4420       *tmpc++ &= MP_MASK;
4421     }
4422
4423     /* clear digits above used (since we may not have grown result above) */
4424     for (i = c->used; i < olduse; i++) {
4425       *tmpc++ = 0;
4426     }
4427   }
4428
4429   mp_clamp (c);
4430   return MP_OKAY;
4431 }