msctf: Add a file version.
[wine] / dlls / rsaenh / mpi.c
1 /*
2  * dlls/rsaenh/mpi.c
3  * Multi Precision Integer functions
4  *
5  * Copyright 2004 Michael Jung
6  * Based on public domain code by Tom St Denis (tomstdenis@iahu.ca)
7  *
8  * This library is free software; you can redistribute it and/or
9  * modify it under the terms of the GNU Lesser General Public
10  * License as published by the Free Software Foundation; either
11  * version 2.1 of the License, or (at your option) any later version.
12  *
13  * This library is distributed in the hope that it will be useful,
14  * but WITHOUT ANY WARRANTY; without even the implied warranty of
15  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
16  * Lesser General Public License for more details.
17  *
18  * You should have received a copy of the GNU Lesser General Public
19  * License along with this library; if not, write to the Free Software
20  * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
21  */
22
23 /*
24  * This file contains code from the LibTomCrypt cryptographic 
25  * library written by Tom St Denis (tomstdenis@iahu.ca). LibTomCrypt
26  * is in the public domain. The code in this file is tailored to
27  * special requirements. Take a look at http://libtomcrypt.org for the
28  * original version. 
29  */
30
31 #include <stdarg.h>
32
33 #include "windef.h"
34 #include "winbase.h"
35 #include "tomcrypt.h"
36
37 /* Known optimal configurations
38  CPU                    /Compiler     /MUL CUTOFF/SQR CUTOFF
39 -------------------------------------------------------------
40  Intel P4 Northwood     /GCC v3.4.1   /        88/       128/LTM 0.32 ;-)
41 */
42 static const int KARATSUBA_MUL_CUTOFF = 88,  /* Min. number of digits before Karatsuba multiplication is used. */
43                  KARATSUBA_SQR_CUTOFF = 128; /* Min. number of digits before Karatsuba squaring is used. */
44
45 static void bn_reverse(unsigned char *s, int len);
46 static int s_mp_add(mp_int *a, mp_int *b, mp_int *c);
47 static int s_mp_exptmod (const mp_int * G, const mp_int * X, mp_int * P, mp_int * Y);
48 #define s_mp_mul(a, b, c) s_mp_mul_digs(a, b, c, (a)->used + (b)->used + 1)
49 static int s_mp_mul_digs(const mp_int *a, const mp_int *b, mp_int *c, int digs);
50 static int s_mp_mul_high_digs(const mp_int *a, const mp_int *b, mp_int *c, int digs);
51 static int s_mp_sqr(const mp_int *a, mp_int *b);
52 static int s_mp_sub(const mp_int *a, const mp_int *b, mp_int *c);
53 static int mp_exptmod_fast(const mp_int *G, const mp_int *X, mp_int *P, mp_int *Y, int mode);
54 static int mp_invmod_slow (const mp_int * a, mp_int * b, mp_int * c);
55 static int mp_karatsuba_mul(const mp_int *a, const mp_int *b, mp_int *c);
56 static int mp_karatsuba_sqr(const mp_int *a, mp_int *b);
57
58 /* grow as required */
59 static int mp_grow (mp_int * a, int size)
60 {
61   int     i;
62   mp_digit *tmp;
63
64   /* if the alloc size is smaller alloc more ram */
65   if (a->alloc < size) {
66     /* ensure there are always at least MP_PREC digits extra on top */
67     size += (MP_PREC * 2) - (size % MP_PREC);
68
69     /* reallocate the array a->dp
70      *
71      * We store the return in a temporary variable
72      * in case the operation failed we don't want
73      * to overwrite the dp member of a.
74      */
75     tmp = HeapReAlloc(GetProcessHeap(), 0, a->dp, sizeof (mp_digit) * size);
76     if (tmp == NULL) {
77       /* reallocation failed but "a" is still valid [can be freed] */
78       return MP_MEM;
79     }
80
81     /* reallocation succeeded so set a->dp */
82     a->dp = tmp;
83
84     /* zero excess digits */
85     i        = a->alloc;
86     a->alloc = size;
87     for (; i < a->alloc; i++) {
88       a->dp[i] = 0;
89     }
90   }
91   return MP_OKAY;
92 }
93
94 /* b = a/2 */
95 static int mp_div_2(const mp_int * a, mp_int * b)
96 {
97   int     x, res, oldused;
98
99   /* copy */
100   if (b->alloc < a->used) {
101     if ((res = mp_grow (b, a->used)) != MP_OKAY) {
102       return res;
103     }
104   }
105
106   oldused = b->used;
107   b->used = a->used;
108   {
109     register mp_digit r, rr, *tmpa, *tmpb;
110
111     /* source alias */
112     tmpa = a->dp + b->used - 1;
113
114     /* dest alias */
115     tmpb = b->dp + b->used - 1;
116
117     /* carry */
118     r = 0;
119     for (x = b->used - 1; x >= 0; x--) {
120       /* get the carry for the next iteration */
121       rr = *tmpa & 1;
122
123       /* shift the current digit, add in carry and store */
124       *tmpb-- = (*tmpa-- >> 1) | (r << (DIGIT_BIT - 1));
125
126       /* forward carry to next iteration */
127       r = rr;
128     }
129
130     /* zero excess digits */
131     tmpb = b->dp + b->used;
132     for (x = b->used; x < oldused; x++) {
133       *tmpb++ = 0;
134     }
135   }
136   b->sign = a->sign;
137   mp_clamp (b);
138   return MP_OKAY;
139 }
140
141 /* swap the elements of two integers, for cases where you can't simply swap the
142  * mp_int pointers around
143  */
144 static void
145 mp_exch (mp_int * a, mp_int * b)
146 {
147   mp_int  t;
148
149   t  = *a;
150   *a = *b;
151   *b = t;
152 }
153
154 /* init a new mp_int */
155 static int mp_init (mp_int * a)
156 {
157   int i;
158
159   /* allocate memory required and clear it */
160   a->dp = HeapAlloc(GetProcessHeap(), 0, sizeof (mp_digit) * MP_PREC);
161   if (a->dp == NULL) {
162     return MP_MEM;
163   }
164
165   /* set the digits to zero */
166   for (i = 0; i < MP_PREC; i++) {
167       a->dp[i] = 0;
168   }
169
170   /* set the used to zero, allocated digits to the default precision
171    * and sign to positive */
172   a->used  = 0;
173   a->alloc = MP_PREC;
174   a->sign  = MP_ZPOS;
175
176   return MP_OKAY;
177 }
178
179 /* init an mp_init for a given size */
180 static int mp_init_size (mp_int * a, int size)
181 {
182   int x;
183
184   /* pad size so there are always extra digits */
185   size += (MP_PREC * 2) - (size % MP_PREC);
186
187   /* alloc mem */
188   a->dp = HeapAlloc(GetProcessHeap(), 0, sizeof (mp_digit) * size);
189   if (a->dp == NULL) {
190     return MP_MEM;
191   }
192
193   /* set the members */
194   a->used  = 0;
195   a->alloc = size;
196   a->sign  = MP_ZPOS;
197
198   /* zero the digits */
199   for (x = 0; x < size; x++) {
200       a->dp[x] = 0;
201   }
202
203   return MP_OKAY;
204 }
205
206 /* clear one (frees)  */
207 static void
208 mp_clear (mp_int * a)
209 {
210   int i;
211
212   /* only do anything if a hasn't been freed previously */
213   if (a->dp != NULL) {
214     /* first zero the digits */
215     for (i = 0; i < a->used; i++) {
216         a->dp[i] = 0;
217     }
218
219     /* free ram */
220     HeapFree(GetProcessHeap(), 0, a->dp);
221
222     /* reset members to make debugging easier */
223     a->dp    = NULL;
224     a->alloc = a->used = 0;
225     a->sign  = MP_ZPOS;
226   }
227 }
228
229 /* set to zero */
230 static void
231 mp_zero (mp_int * a)
232 {
233   a->sign = MP_ZPOS;
234   a->used = 0;
235   memset (a->dp, 0, sizeof (mp_digit) * a->alloc);
236 }
237
238 /* b = |a|
239  *
240  * Simple function copies the input and fixes the sign to positive
241  */
242 static int
243 mp_abs (const mp_int * a, mp_int * b)
244 {
245   int     res;
246
247   /* copy a to b */
248   if (a != b) {
249      if ((res = mp_copy (a, b)) != MP_OKAY) {
250        return res;
251      }
252   }
253
254   /* force the sign of b to positive */
255   b->sign = MP_ZPOS;
256
257   return MP_OKAY;
258 }
259
260 /* computes the modular inverse via binary extended euclidean algorithm, 
261  * that is c = 1/a mod b 
262  *
263  * Based on slow invmod except this is optimized for the case where b is 
264  * odd as per HAC Note 14.64 on pp. 610
265  */
266 static int
267 fast_mp_invmod (const mp_int * a, mp_int * b, mp_int * c)
268 {
269   mp_int  x, y, u, v, B, D;
270   int     res, neg;
271
272   /* 2. [modified] b must be odd   */
273   if (mp_iseven (b) == 1) {
274     return MP_VAL;
275   }
276
277   /* init all our temps */
278   if ((res = mp_init_multi(&x, &y, &u, &v, &B, &D, NULL)) != MP_OKAY) {
279      return res;
280   }
281
282   /* x == modulus, y == value to invert */
283   if ((res = mp_copy (b, &x)) != MP_OKAY) {
284     goto __ERR;
285   }
286
287   /* we need y = |a| */
288   if ((res = mp_abs (a, &y)) != MP_OKAY) {
289     goto __ERR;
290   }
291
292   /* 3. u=x, v=y, A=1, B=0, C=0,D=1 */
293   if ((res = mp_copy (&x, &u)) != MP_OKAY) {
294     goto __ERR;
295   }
296   if ((res = mp_copy (&y, &v)) != MP_OKAY) {
297     goto __ERR;
298   }
299   mp_set (&D, 1);
300
301 top:
302   /* 4.  while u is even do */
303   while (mp_iseven (&u) == 1) {
304     /* 4.1 u = u/2 */
305     if ((res = mp_div_2 (&u, &u)) != MP_OKAY) {
306       goto __ERR;
307     }
308     /* 4.2 if B is odd then */
309     if (mp_isodd (&B) == 1) {
310       if ((res = mp_sub (&B, &x, &B)) != MP_OKAY) {
311         goto __ERR;
312       }
313     }
314     /* B = B/2 */
315     if ((res = mp_div_2 (&B, &B)) != MP_OKAY) {
316       goto __ERR;
317     }
318   }
319
320   /* 5.  while v is even do */
321   while (mp_iseven (&v) == 1) {
322     /* 5.1 v = v/2 */
323     if ((res = mp_div_2 (&v, &v)) != MP_OKAY) {
324       goto __ERR;
325     }
326     /* 5.2 if D is odd then */
327     if (mp_isodd (&D) == 1) {
328       /* D = (D-x)/2 */
329       if ((res = mp_sub (&D, &x, &D)) != MP_OKAY) {
330         goto __ERR;
331       }
332     }
333     /* D = D/2 */
334     if ((res = mp_div_2 (&D, &D)) != MP_OKAY) {
335       goto __ERR;
336     }
337   }
338
339   /* 6.  if u >= v then */
340   if (mp_cmp (&u, &v) != MP_LT) {
341     /* u = u - v, B = B - D */
342     if ((res = mp_sub (&u, &v, &u)) != MP_OKAY) {
343       goto __ERR;
344     }
345
346     if ((res = mp_sub (&B, &D, &B)) != MP_OKAY) {
347       goto __ERR;
348     }
349   } else {
350     /* v - v - u, D = D - B */
351     if ((res = mp_sub (&v, &u, &v)) != MP_OKAY) {
352       goto __ERR;
353     }
354
355     if ((res = mp_sub (&D, &B, &D)) != MP_OKAY) {
356       goto __ERR;
357     }
358   }
359
360   /* if not zero goto step 4 */
361   if (mp_iszero (&u) == 0) {
362     goto top;
363   }
364
365   /* now a = C, b = D, gcd == g*v */
366
367   /* if v != 1 then there is no inverse */
368   if (mp_cmp_d (&v, 1) != MP_EQ) {
369     res = MP_VAL;
370     goto __ERR;
371   }
372
373   /* b is now the inverse */
374   neg = a->sign;
375   while (D.sign == MP_NEG) {
376     if ((res = mp_add (&D, b, &D)) != MP_OKAY) {
377       goto __ERR;
378     }
379   }
380   mp_exch (&D, c);
381   c->sign = neg;
382   res = MP_OKAY;
383
384 __ERR:mp_clear_multi (&x, &y, &u, &v, &B, &D, NULL);
385   return res;
386 }
387
388 /* computes xR**-1 == x (mod N) via Montgomery Reduction
389  *
390  * This is an optimized implementation of montgomery_reduce
391  * which uses the comba method to quickly calculate the columns of the
392  * reduction.
393  *
394  * Based on Algorithm 14.32 on pp.601 of HAC.
395 */
396 static int
397 fast_mp_montgomery_reduce (mp_int * x, const mp_int * n, mp_digit rho)
398 {
399   int     ix, res, olduse;
400   mp_word W[MP_WARRAY];
401
402   /* get old used count */
403   olduse = x->used;
404
405   /* grow a as required */
406   if (x->alloc < n->used + 1) {
407     if ((res = mp_grow (x, n->used + 1)) != MP_OKAY) {
408       return res;
409     }
410   }
411
412   /* first we have to get the digits of the input into
413    * an array of double precision words W[...]
414    */
415   {
416     register mp_word *_W;
417     register mp_digit *tmpx;
418
419     /* alias for the W[] array */
420     _W   = W;
421
422     /* alias for the digits of  x*/
423     tmpx = x->dp;
424
425     /* copy the digits of a into W[0..a->used-1] */
426     for (ix = 0; ix < x->used; ix++) {
427       *_W++ = *tmpx++;
428     }
429
430     /* zero the high words of W[a->used..m->used*2] */
431     for (; ix < n->used * 2 + 1; ix++) {
432       *_W++ = 0;
433     }
434   }
435
436   /* now we proceed to zero successive digits
437    * from the least significant upwards
438    */
439   for (ix = 0; ix < n->used; ix++) {
440     /* mu = ai * m' mod b
441      *
442      * We avoid a double precision multiplication (which isn't required)
443      * by casting the value down to a mp_digit.  Note this requires
444      * that W[ix-1] have  the carry cleared (see after the inner loop)
445      */
446     register mp_digit mu;
447     mu = (mp_digit) (((W[ix] & MP_MASK) * rho) & MP_MASK);
448
449     /* a = a + mu * m * b**i
450      *
451      * This is computed in place and on the fly.  The multiplication
452      * by b**i is handled by offsetting which columns the results
453      * are added to.
454      *
455      * Note the comba method normally doesn't handle carries in the
456      * inner loop In this case we fix the carry from the previous
457      * column since the Montgomery reduction requires digits of the
458      * result (so far) [see above] to work.  This is
459      * handled by fixing up one carry after the inner loop.  The
460      * carry fixups are done in order so after these loops the
461      * first m->used words of W[] have the carries fixed
462      */
463     {
464       register int iy;
465       register mp_digit *tmpn;
466       register mp_word *_W;
467
468       /* alias for the digits of the modulus */
469       tmpn = n->dp;
470
471       /* Alias for the columns set by an offset of ix */
472       _W = W + ix;
473
474       /* inner loop */
475       for (iy = 0; iy < n->used; iy++) {
476           *_W++ += ((mp_word)mu) * ((mp_word)*tmpn++);
477       }
478     }
479
480     /* now fix carry for next digit, W[ix+1] */
481     W[ix + 1] += W[ix] >> ((mp_word) DIGIT_BIT);
482   }
483
484   /* now we have to propagate the carries and
485    * shift the words downward [all those least
486    * significant digits we zeroed].
487    */
488   {
489     register mp_digit *tmpx;
490     register mp_word *_W, *_W1;
491
492     /* nox fix rest of carries */
493
494     /* alias for current word */
495     _W1 = W + ix;
496
497     /* alias for next word, where the carry goes */
498     _W = W + ++ix;
499
500     for (; ix <= n->used * 2 + 1; ix++) {
501       *_W++ += *_W1++ >> ((mp_word) DIGIT_BIT);
502     }
503
504     /* copy out, A = A/b**n
505      *
506      * The result is A/b**n but instead of converting from an
507      * array of mp_word to mp_digit than calling mp_rshd
508      * we just copy them in the right order
509      */
510
511     /* alias for destination word */
512     tmpx = x->dp;
513
514     /* alias for shifted double precision result */
515     _W = W + n->used;
516
517     for (ix = 0; ix < n->used + 1; ix++) {
518       *tmpx++ = (mp_digit)(*_W++ & ((mp_word) MP_MASK));
519     }
520
521     /* zero oldused digits, if the input a was larger than
522      * m->used+1 we'll have to clear the digits
523      */
524     for (; ix < olduse; ix++) {
525       *tmpx++ = 0;
526     }
527   }
528
529   /* set the max used and clamp */
530   x->used = n->used + 1;
531   mp_clamp (x);
532
533   /* if A >= m then A = A - m */
534   if (mp_cmp_mag (x, n) != MP_LT) {
535     return s_mp_sub (x, n, x);
536   }
537   return MP_OKAY;
538 }
539
540 /* Fast (comba) multiplier
541  *
542  * This is the fast column-array [comba] multiplier.  It is 
543  * designed to compute the columns of the product first 
544  * then handle the carries afterwards.  This has the effect 
545  * of making the nested loops that compute the columns very
546  * simple and schedulable on super-scalar processors.
547  *
548  * This has been modified to produce a variable number of 
549  * digits of output so if say only a half-product is required 
550  * you don't have to compute the upper half (a feature 
551  * required for fast Barrett reduction).
552  *
553  * Based on Algorithm 14.12 on pp.595 of HAC.
554  *
555  */
556 static int
557 fast_s_mp_mul_digs (const mp_int * a, const mp_int * b, mp_int * c, int digs)
558 {
559   int     olduse, res, pa, ix, iz;
560   mp_digit W[MP_WARRAY];
561   register mp_word  _W;
562
563   /* grow the destination as required */
564   if (c->alloc < digs) {
565     if ((res = mp_grow (c, digs)) != MP_OKAY) {
566       return res;
567     }
568   }
569
570   /* number of output digits to produce */
571   pa = MIN(digs, a->used + b->used);
572
573   /* clear the carry */
574   _W = 0;
575   for (ix = 0; ix <= pa; ix++) { 
576       int      tx, ty;
577       int      iy;
578       mp_digit *tmpx, *tmpy;
579
580       /* get offsets into the two bignums */
581       ty = MIN(b->used-1, ix);
582       tx = ix - ty;
583
584       /* setup temp aliases */
585       tmpx = a->dp + tx;
586       tmpy = b->dp + ty;
587
588       /* This is the number of times the loop will iterate, essentially it's
589          while (tx++ < a->used && ty-- >= 0) { ... }
590        */
591       iy = MIN(a->used-tx, ty+1);
592
593       /* execute loop */
594       for (iz = 0; iz < iy; ++iz) {
595          _W += ((mp_word)*tmpx++)*((mp_word)*tmpy--);
596       }
597
598       /* store term */
599       W[ix] = ((mp_digit)_W) & MP_MASK;
600
601       /* make next carry */
602       _W = _W >> ((mp_word)DIGIT_BIT);
603   }
604
605   /* setup dest */
606   olduse  = c->used;
607   c->used = digs;
608
609   {
610     register mp_digit *tmpc;
611     tmpc = c->dp;
612     for (ix = 0; ix < digs; ix++) {
613       /* now extract the previous digit [below the carry] */
614       *tmpc++ = W[ix];
615     }
616
617     /* clear unused digits [that existed in the old copy of c] */
618     for (; ix < olduse; ix++) {
619       *tmpc++ = 0;
620     }
621   }
622   mp_clamp (c);
623   return MP_OKAY;
624 }
625
626 /* this is a modified version of fast_s_mul_digs that only produces
627  * output digits *above* digs.  See the comments for fast_s_mul_digs
628  * to see how it works.
629  *
630  * This is used in the Barrett reduction since for one of the multiplications
631  * only the higher digits were needed.  This essentially halves the work.
632  *
633  * Based on Algorithm 14.12 on pp.595 of HAC.
634  */
635 static int
636 fast_s_mp_mul_high_digs (const mp_int * a, const mp_int * b, mp_int * c, int digs)
637 {
638   int     olduse, res, pa, ix, iz;
639   mp_digit W[MP_WARRAY];
640   mp_word  _W;
641
642   /* grow the destination as required */
643   pa = a->used + b->used;
644   if (c->alloc < pa) {
645     if ((res = mp_grow (c, pa)) != MP_OKAY) {
646       return res;
647     }
648   }
649
650   /* number of output digits to produce */
651   pa = a->used + b->used;
652   _W = 0;
653   for (ix = digs; ix <= pa; ix++) { 
654       int      tx, ty, iy;
655       mp_digit *tmpx, *tmpy;
656
657       /* get offsets into the two bignums */
658       ty = MIN(b->used-1, ix);
659       tx = ix - ty;
660
661       /* setup temp aliases */
662       tmpx = a->dp + tx;
663       tmpy = b->dp + ty;
664
665       /* This is the number of times the loop will iterate, essentially it's
666          while (tx++ < a->used && ty-- >= 0) { ... }
667        */
668       iy = MIN(a->used-tx, ty+1);
669
670       /* execute loop */
671       for (iz = 0; iz < iy; iz++) {
672          _W += ((mp_word)*tmpx++)*((mp_word)*tmpy--);
673       }
674
675       /* store term */
676       W[ix] = ((mp_digit)_W) & MP_MASK;
677
678       /* make next carry */
679       _W = _W >> ((mp_word)DIGIT_BIT);
680   }
681
682   /* setup dest */
683   olduse  = c->used;
684   c->used = pa;
685
686   {
687     register mp_digit *tmpc;
688
689     tmpc = c->dp + digs;
690     for (ix = digs; ix <= pa; ix++) {
691       /* now extract the previous digit [below the carry] */
692       *tmpc++ = W[ix];
693     }
694
695     /* clear unused digits [that existed in the old copy of c] */
696     for (; ix < olduse; ix++) {
697       *tmpc++ = 0;
698     }
699   }
700   mp_clamp (c);
701   return MP_OKAY;
702 }
703
704 /* fast squaring
705  *
706  * This is the comba method where the columns of the product
707  * are computed first then the carries are computed.  This
708  * has the effect of making a very simple inner loop that
709  * is executed the most
710  *
711  * W2 represents the outer products and W the inner.
712  *
713  * A further optimizations is made because the inner
714  * products are of the form "A * B * 2".  The *2 part does
715  * not need to be computed until the end which is good
716  * because 64-bit shifts are slow!
717  *
718  * Based on Algorithm 14.16 on pp.597 of HAC.
719  *
720  */
721 /* the jist of squaring...
722
723 you do like mult except the offset of the tmpx [one that starts closer to zero]
724 can't equal the offset of tmpy.  So basically you set up iy like before then you min it with
725 (ty-tx) so that it never happens.  You double all those you add in the inner loop
726
727 After that loop you do the squares and add them in.
728
729 Remove W2 and don't memset W
730
731 */
732
733 static int fast_s_mp_sqr (const mp_int * a, mp_int * b)
734 {
735   int       olduse, res, pa, ix, iz;
736   mp_digit   W[MP_WARRAY], *tmpx;
737   mp_word   W1;
738
739   /* grow the destination as required */
740   pa = a->used + a->used;
741   if (b->alloc < pa) {
742     if ((res = mp_grow (b, pa)) != MP_OKAY) {
743       return res;
744     }
745   }
746
747   /* number of output digits to produce */
748   W1 = 0;
749   for (ix = 0; ix <= pa; ix++) { 
750       int      tx, ty, iy;
751       mp_word  _W;
752       mp_digit *tmpy;
753
754       /* clear counter */
755       _W = 0;
756
757       /* get offsets into the two bignums */
758       ty = MIN(a->used-1, ix);
759       tx = ix - ty;
760
761       /* setup temp aliases */
762       tmpx = a->dp + tx;
763       tmpy = a->dp + ty;
764
765       /* This is the number of times the loop will iterate, essentially it's
766          while (tx++ < a->used && ty-- >= 0) { ... }
767        */
768       iy = MIN(a->used-tx, ty+1);
769
770       /* now for squaring tx can never equal ty 
771        * we halve the distance since they approach at a rate of 2x
772        * and we have to round because odd cases need to be executed
773        */
774       iy = MIN(iy, (ty-tx+1)>>1);
775
776       /* execute loop */
777       for (iz = 0; iz < iy; iz++) {
778          _W += ((mp_word)*tmpx++)*((mp_word)*tmpy--);
779       }
780
781       /* double the inner product and add carry */
782       _W = _W + _W + W1;
783
784       /* even columns have the square term in them */
785       if ((ix&1) == 0) {
786          _W += ((mp_word)a->dp[ix>>1])*((mp_word)a->dp[ix>>1]);
787       }
788
789       /* store it */
790       W[ix] = _W;
791
792       /* make next carry */
793       W1 = _W >> ((mp_word)DIGIT_BIT);
794   }
795
796   /* setup dest */
797   olduse  = b->used;
798   b->used = a->used+a->used;
799
800   {
801     mp_digit *tmpb;
802     tmpb = b->dp;
803     for (ix = 0; ix < pa; ix++) {
804       *tmpb++ = W[ix] & MP_MASK;
805     }
806
807     /* clear unused digits [that existed in the old copy of c] */
808     for (; ix < olduse; ix++) {
809       *tmpb++ = 0;
810     }
811   }
812   mp_clamp (b);
813   return MP_OKAY;
814 }
815
816 /* computes a = 2**b 
817  *
818  * Simple algorithm which zeroes the int, grows it then just sets one bit
819  * as required.
820  */
821 static int
822 mp_2expt (mp_int * a, int b)
823 {
824   int     res;
825
826   /* zero a as per default */
827   mp_zero (a);
828
829   /* grow a to accommodate the single bit */
830   if ((res = mp_grow (a, b / DIGIT_BIT + 1)) != MP_OKAY) {
831     return res;
832   }
833
834   /* set the used count of where the bit will go */
835   a->used = b / DIGIT_BIT + 1;
836
837   /* put the single bit in its place */
838   a->dp[b / DIGIT_BIT] = ((mp_digit)1) << (b % DIGIT_BIT);
839
840   return MP_OKAY;
841 }
842
843 /* high level addition (handles signs) */
844 int mp_add (mp_int * a, mp_int * b, mp_int * c)
845 {
846   int     sa, sb, res;
847
848   /* get sign of both inputs */
849   sa = a->sign;
850   sb = b->sign;
851
852   /* handle two cases, not four */
853   if (sa == sb) {
854     /* both positive or both negative */
855     /* add their magnitudes, copy the sign */
856     c->sign = sa;
857     res = s_mp_add (a, b, c);
858   } else {
859     /* one positive, the other negative */
860     /* subtract the one with the greater magnitude from */
861     /* the one of the lesser magnitude.  The result gets */
862     /* the sign of the one with the greater magnitude. */
863     if (mp_cmp_mag (a, b) == MP_LT) {
864       c->sign = sb;
865       res = s_mp_sub (b, a, c);
866     } else {
867       c->sign = sa;
868       res = s_mp_sub (a, b, c);
869     }
870   }
871   return res;
872 }
873
874
875 /* single digit addition */
876 static int
877 mp_add_d (mp_int * a, mp_digit b, mp_int * c)
878 {
879   int     res, ix, oldused;
880   mp_digit *tmpa, *tmpc, mu;
881
882   /* grow c as required */
883   if (c->alloc < a->used + 1) {
884      if ((res = mp_grow(c, a->used + 1)) != MP_OKAY) {
885         return res;
886      }
887   }
888
889   /* if a is negative and |a| >= b, call c = |a| - b */
890   if (a->sign == MP_NEG && (a->used > 1 || a->dp[0] >= b)) {
891      /* temporarily fix sign of a */
892      a->sign = MP_ZPOS;
893
894      /* c = |a| - b */
895      res = mp_sub_d(a, b, c);
896
897      /* fix sign  */
898      a->sign = c->sign = MP_NEG;
899
900      return res;
901   }
902
903   /* old number of used digits in c */
904   oldused = c->used;
905
906   /* sign always positive */
907   c->sign = MP_ZPOS;
908
909   /* source alias */
910   tmpa    = a->dp;
911
912   /* destination alias */
913   tmpc    = c->dp;
914
915   /* if a is positive */
916   if (a->sign == MP_ZPOS) {
917      /* add digit, after this we're propagating
918       * the carry.
919       */
920      *tmpc   = *tmpa++ + b;
921      mu      = *tmpc >> DIGIT_BIT;
922      *tmpc++ &= MP_MASK;
923
924      /* now handle rest of the digits */
925      for (ix = 1; ix < a->used; ix++) {
926         *tmpc   = *tmpa++ + mu;
927         mu      = *tmpc >> DIGIT_BIT;
928         *tmpc++ &= MP_MASK;
929      }
930      /* set final carry */
931      ix++;
932      *tmpc++  = mu;
933
934      /* setup size */
935      c->used = a->used + 1;
936   } else {
937      /* a was negative and |a| < b */
938      c->used  = 1;
939
940      /* the result is a single digit */
941      if (a->used == 1) {
942         *tmpc++  =  b - a->dp[0];
943      } else {
944         *tmpc++  =  b;
945      }
946
947      /* setup count so the clearing of oldused
948       * can fall through correctly
949       */
950      ix       = 1;
951   }
952
953   /* now zero to oldused */
954   while (ix++ < oldused) {
955      *tmpc++ = 0;
956   }
957   mp_clamp(c);
958
959   return MP_OKAY;
960 }
961
962 /* trim unused digits 
963  *
964  * This is used to ensure that leading zero digits are
965  * trimed and the leading "used" digit will be non-zero
966  * Typically very fast.  Also fixes the sign if there
967  * are no more leading digits
968  */
969 void
970 mp_clamp (mp_int * a)
971 {
972   /* decrease used while the most significant digit is
973    * zero.
974    */
975   while (a->used > 0 && a->dp[a->used - 1] == 0) {
976     --(a->used);
977   }
978
979   /* reset the sign flag if used == 0 */
980   if (a->used == 0) {
981     a->sign = MP_ZPOS;
982   }
983 }
984
985 void mp_clear_multi(mp_int *mp, ...) 
986 {
987     mp_int* next_mp = mp;
988     va_list args;
989     va_start(args, mp);
990     while (next_mp != NULL) {
991         mp_clear(next_mp);
992         next_mp = va_arg(args, mp_int*);
993     }
994     va_end(args);
995 }
996
997 /* compare two ints (signed)*/
998 int
999 mp_cmp (const mp_int * a, const mp_int * b)
1000 {
1001   /* compare based on sign */
1002   if (a->sign != b->sign) {
1003      if (a->sign == MP_NEG) {
1004         return MP_LT;
1005      } else {
1006         return MP_GT;
1007      }
1008   }
1009   
1010   /* compare digits */
1011   if (a->sign == MP_NEG) {
1012      /* if negative compare opposite direction */
1013      return mp_cmp_mag(b, a);
1014   } else {
1015      return mp_cmp_mag(a, b);
1016   }
1017 }
1018
1019 /* compare a digit */
1020 int mp_cmp_d(const mp_int * a, mp_digit b)
1021 {
1022   /* compare based on sign */
1023   if (a->sign == MP_NEG) {
1024     return MP_LT;
1025   }
1026
1027   /* compare based on magnitude */
1028   if (a->used > 1) {
1029     return MP_GT;
1030   }
1031
1032   /* compare the only digit of a to b */
1033   if (a->dp[0] > b) {
1034     return MP_GT;
1035   } else if (a->dp[0] < b) {
1036     return MP_LT;
1037   } else {
1038     return MP_EQ;
1039   }
1040 }
1041
1042 /* compare maginitude of two ints (unsigned) */
1043 int mp_cmp_mag (const mp_int * a, const mp_int * b)
1044 {
1045   int     n;
1046   mp_digit *tmpa, *tmpb;
1047
1048   /* compare based on # of non-zero digits */
1049   if (a->used > b->used) {
1050     return MP_GT;
1051   }
1052   
1053   if (a->used < b->used) {
1054     return MP_LT;
1055   }
1056
1057   /* alias for a */
1058   tmpa = a->dp + (a->used - 1);
1059
1060   /* alias for b */
1061   tmpb = b->dp + (a->used - 1);
1062
1063   /* compare based on digits  */
1064   for (n = 0; n < a->used; ++n, --tmpa, --tmpb) {
1065     if (*tmpa > *tmpb) {
1066       return MP_GT;
1067     }
1068
1069     if (*tmpa < *tmpb) {
1070       return MP_LT;
1071     }
1072   }
1073   return MP_EQ;
1074 }
1075
1076 static const int lnz[16] = { 
1077    4, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0
1078 };
1079
1080 /* Counts the number of lsbs which are zero before the first zero bit */
1081 int mp_cnt_lsb(const mp_int *a)
1082 {
1083    int x;
1084    mp_digit q, qq;
1085
1086    /* easy out */
1087    if (mp_iszero(a) == 1) {
1088       return 0;
1089    }
1090
1091    /* scan lower digits until non-zero */
1092    for (x = 0; x < a->used && a->dp[x] == 0; x++);
1093    q = a->dp[x];
1094    x *= DIGIT_BIT;
1095
1096    /* now scan this digit until a 1 is found */
1097    if ((q & 1) == 0) {
1098       do {
1099          qq  = q & 15;
1100          x  += lnz[qq];
1101          q >>= 4;
1102       } while (qq == 0);
1103    }
1104    return x;
1105 }
1106
1107 /* copy, b = a */
1108 int
1109 mp_copy (const mp_int * a, mp_int * b)
1110 {
1111   int     res, n;
1112
1113   /* if dst == src do nothing */
1114   if (a == b) {
1115     return MP_OKAY;
1116   }
1117
1118   /* grow dest */
1119   if (b->alloc < a->used) {
1120      if ((res = mp_grow (b, a->used)) != MP_OKAY) {
1121         return res;
1122      }
1123   }
1124
1125   /* zero b and copy the parameters over */
1126   {
1127     register mp_digit *tmpa, *tmpb;
1128
1129     /* pointer aliases */
1130
1131     /* source */
1132     tmpa = a->dp;
1133
1134     /* destination */
1135     tmpb = b->dp;
1136
1137     /* copy all the digits */
1138     for (n = 0; n < a->used; n++) {
1139       *tmpb++ = *tmpa++;
1140     }
1141
1142     /* clear high digits */
1143     for (; n < b->used; n++) {
1144       *tmpb++ = 0;
1145     }
1146   }
1147
1148   /* copy used count and sign */
1149   b->used = a->used;
1150   b->sign = a->sign;
1151   return MP_OKAY;
1152 }
1153
1154 /* returns the number of bits in an int */
1155 int
1156 mp_count_bits (const mp_int * a)
1157 {
1158   int     r;
1159   mp_digit q;
1160
1161   /* shortcut */
1162   if (a->used == 0) {
1163     return 0;
1164   }
1165
1166   /* get number of digits and add that */
1167   r = (a->used - 1) * DIGIT_BIT;
1168   
1169   /* take the last digit and count the bits in it */
1170   q = a->dp[a->used - 1];
1171   while (q > 0) {
1172     ++r;
1173     q >>= ((mp_digit) 1);
1174   }
1175   return r;
1176 }
1177
1178 /* calc a value mod 2**b */
1179 static int
1180 mp_mod_2d (const mp_int * a, int b, mp_int * c)
1181 {
1182   int     x, res;
1183
1184   /* if b is <= 0 then zero the int */
1185   if (b <= 0) {
1186     mp_zero (c);
1187     return MP_OKAY;
1188   }
1189
1190   /* if the modulus is larger than the value than return */
1191   if (b > a->used * DIGIT_BIT) {
1192     res = mp_copy (a, c);
1193     return res;
1194   }
1195
1196   /* copy */
1197   if ((res = mp_copy (a, c)) != MP_OKAY) {
1198     return res;
1199   }
1200
1201   /* zero digits above the last digit of the modulus */
1202   for (x = (b / DIGIT_BIT) + ((b % DIGIT_BIT) == 0 ? 0 : 1); x < c->used; x++) {
1203     c->dp[x] = 0;
1204   }
1205   /* clear the digit that is not completely outside/inside the modulus */
1206   c->dp[b / DIGIT_BIT] &= (1 << ((mp_digit)b % DIGIT_BIT)) - 1;
1207   mp_clamp (c);
1208   return MP_OKAY;
1209 }
1210
1211 /* shift right a certain amount of digits */
1212 static void mp_rshd (mp_int * a, int b)
1213 {
1214   int     x;
1215
1216   /* if b <= 0 then ignore it */
1217   if (b <= 0) {
1218     return;
1219   }
1220
1221   /* if b > used then simply zero it and return */
1222   if (a->used <= b) {
1223     mp_zero (a);
1224     return;
1225   }
1226
1227   {
1228     register mp_digit *bottom, *top;
1229
1230     /* shift the digits down */
1231
1232     /* bottom */
1233     bottom = a->dp;
1234
1235     /* top [offset into digits] */
1236     top = a->dp + b;
1237
1238     /* this is implemented as a sliding window where
1239      * the window is b-digits long and digits from
1240      * the top of the window are copied to the bottom
1241      *
1242      * e.g.
1243
1244      b-2 | b-1 | b0 | b1 | b2 | ... | bb |   ---->
1245                  /\                   |      ---->
1246                   \-------------------/      ---->
1247      */
1248     for (x = 0; x < (a->used - b); x++) {
1249       *bottom++ = *top++;
1250     }
1251
1252     /* zero the top digits */
1253     for (; x < a->used; x++) {
1254       *bottom++ = 0;
1255     }
1256   }
1257
1258   /* remove excess digits */
1259   a->used -= b;
1260 }
1261
1262 /* shift right by a certain bit count (store quotient in c, optional remainder in d) */
1263 static int mp_div_2d (const mp_int * a, int b, mp_int * c, mp_int * d)
1264 {
1265   mp_digit D, r, rr;
1266   int     x, res;
1267   mp_int  t;
1268
1269
1270   /* if the shift count is <= 0 then we do no work */
1271   if (b <= 0) {
1272     res = mp_copy (a, c);
1273     if (d != NULL) {
1274       mp_zero (d);
1275     }
1276     return res;
1277   }
1278
1279   if ((res = mp_init (&t)) != MP_OKAY) {
1280     return res;
1281   }
1282
1283   /* get the remainder */
1284   if (d != NULL) {
1285     if ((res = mp_mod_2d (a, b, &t)) != MP_OKAY) {
1286       mp_clear (&t);
1287       return res;
1288     }
1289   }
1290
1291   /* copy */
1292   if ((res = mp_copy (a, c)) != MP_OKAY) {
1293     mp_clear (&t);
1294     return res;
1295   }
1296
1297   /* shift by as many digits in the bit count */
1298   if (b >= DIGIT_BIT) {
1299     mp_rshd (c, b / DIGIT_BIT);
1300   }
1301
1302   /* shift any bit count < DIGIT_BIT */
1303   D = (mp_digit) (b % DIGIT_BIT);
1304   if (D != 0) {
1305     register mp_digit *tmpc, mask, shift;
1306
1307     /* mask */
1308     mask = (((mp_digit)1) << D) - 1;
1309
1310     /* shift for lsb */
1311     shift = DIGIT_BIT - D;
1312
1313     /* alias */
1314     tmpc = c->dp + (c->used - 1);
1315
1316     /* carry */
1317     r = 0;
1318     for (x = c->used - 1; x >= 0; x--) {
1319       /* get the lower  bits of this word in a temp */
1320       rr = *tmpc & mask;
1321
1322       /* shift the current word and mix in the carry bits from the previous word */
1323       *tmpc = (*tmpc >> D) | (r << shift);
1324       --tmpc;
1325
1326       /* set the carry to the carry bits of the current word found above */
1327       r = rr;
1328     }
1329   }
1330   mp_clamp (c);
1331   if (d != NULL) {
1332     mp_exch (&t, d);
1333   }
1334   mp_clear (&t);
1335   return MP_OKAY;
1336 }
1337
1338 /* shift left a certain amount of digits */
1339 static int mp_lshd (mp_int * a, int b)
1340 {
1341   int     x, res;
1342
1343   /* if its less than zero return */
1344   if (b <= 0) {
1345     return MP_OKAY;
1346   }
1347
1348   /* grow to fit the new digits */
1349   if (a->alloc < a->used + b) {
1350      if ((res = mp_grow (a, a->used + b)) != MP_OKAY) {
1351        return res;
1352      }
1353   }
1354
1355   {
1356     register mp_digit *top, *bottom;
1357
1358     /* increment the used by the shift amount then copy upwards */
1359     a->used += b;
1360
1361     /* top */
1362     top = a->dp + a->used - 1;
1363
1364     /* base */
1365     bottom = a->dp + a->used - 1 - b;
1366
1367     /* much like mp_rshd this is implemented using a sliding window
1368      * except the window goes the otherway around.  Copying from
1369      * the bottom to the top.  see bn_mp_rshd.c for more info.
1370      */
1371     for (x = a->used - 1; x >= b; x--) {
1372       *top-- = *bottom--;
1373     }
1374
1375     /* zero the lower digits */
1376     top = a->dp;
1377     for (x = 0; x < b; x++) {
1378       *top++ = 0;
1379     }
1380   }
1381   return MP_OKAY;
1382 }
1383
1384 /* shift left by a certain bit count */
1385 static int mp_mul_2d (const mp_int * a, int b, mp_int * c)
1386 {
1387   mp_digit d;
1388   int      res;
1389
1390   /* copy */
1391   if (a != c) {
1392      if ((res = mp_copy (a, c)) != MP_OKAY) {
1393        return res;
1394      }
1395   }
1396
1397   if (c->alloc < c->used + b/DIGIT_BIT + 1) {
1398      if ((res = mp_grow (c, c->used + b / DIGIT_BIT + 1)) != MP_OKAY) {
1399        return res;
1400      }
1401   }
1402
1403   /* shift by as many digits in the bit count */
1404   if (b >= DIGIT_BIT) {
1405     if ((res = mp_lshd (c, b / DIGIT_BIT)) != MP_OKAY) {
1406       return res;
1407     }
1408   }
1409
1410   /* shift any bit count < DIGIT_BIT */
1411   d = (mp_digit) (b % DIGIT_BIT);
1412   if (d != 0) {
1413     register mp_digit *tmpc, shift, mask, r, rr;
1414     register int x;
1415
1416     /* bitmask for carries */
1417     mask = (((mp_digit)1) << d) - 1;
1418
1419     /* shift for msbs */
1420     shift = DIGIT_BIT - d;
1421
1422     /* alias */
1423     tmpc = c->dp;
1424
1425     /* carry */
1426     r    = 0;
1427     for (x = 0; x < c->used; x++) {
1428       /* get the higher bits of the current word */
1429       rr = (*tmpc >> shift) & mask;
1430
1431       /* shift the current word and OR in the carry */
1432       *tmpc = ((*tmpc << d) | r) & MP_MASK;
1433       ++tmpc;
1434
1435       /* set the carry to the carry bits of the current word */
1436       r = rr;
1437     }
1438
1439     /* set final carry */
1440     if (r != 0) {
1441        c->dp[(c->used)++] = r;
1442     }
1443   }
1444   mp_clamp (c);
1445   return MP_OKAY;
1446 }
1447
1448 /* multiply by a digit */
1449 static int
1450 mp_mul_d (const mp_int * a, mp_digit b, mp_int * c)
1451 {
1452   mp_digit u, *tmpa, *tmpc;
1453   mp_word  r;
1454   int      ix, res, olduse;
1455
1456   /* make sure c is big enough to hold a*b */
1457   if (c->alloc < a->used + 1) {
1458     if ((res = mp_grow (c, a->used + 1)) != MP_OKAY) {
1459       return res;
1460     }
1461   }
1462
1463   /* get the original destinations used count */
1464   olduse = c->used;
1465
1466   /* set the sign */
1467   c->sign = a->sign;
1468
1469   /* alias for a->dp [source] */
1470   tmpa = a->dp;
1471
1472   /* alias for c->dp [dest] */
1473   tmpc = c->dp;
1474
1475   /* zero carry */
1476   u = 0;
1477
1478   /* compute columns */
1479   for (ix = 0; ix < a->used; ix++) {
1480     /* compute product and carry sum for this term */
1481     r       = ((mp_word) u) + ((mp_word)*tmpa++) * ((mp_word)b);
1482
1483     /* mask off higher bits to get a single digit */
1484     *tmpc++ = (mp_digit) (r & ((mp_word) MP_MASK));
1485
1486     /* send carry into next iteration */
1487     u       = (mp_digit) (r >> ((mp_word) DIGIT_BIT));
1488   }
1489
1490   /* store final carry [if any] */
1491   *tmpc++ = u;
1492
1493   /* now zero digits above the top */
1494   while (ix++ < olduse) {
1495      *tmpc++ = 0;
1496   }
1497
1498   /* set used count */
1499   c->used = a->used + 1;
1500   mp_clamp(c);
1501
1502   return MP_OKAY;
1503 }
1504
1505 /* integer signed division. 
1506  * c*b + d == a [e.g. a/b, c=quotient, d=remainder]
1507  * HAC pp.598 Algorithm 14.20
1508  *
1509  * Note that the description in HAC is horribly 
1510  * incomplete.  For example, it doesn't consider 
1511  * the case where digits are removed from 'x' in 
1512  * the inner loop.  It also doesn't consider the 
1513  * case that y has fewer than three digits, etc..
1514  *
1515  * The overall algorithm is as described as 
1516  * 14.20 from HAC but fixed to treat these cases.
1517 */
1518 static int mp_div (const mp_int * a, const mp_int * b, mp_int * c, mp_int * d)
1519 {
1520   mp_int  q, x, y, t1, t2;
1521   int     res, n, t, i, norm, neg;
1522
1523   /* is divisor zero ? */
1524   if (mp_iszero (b) == 1) {
1525     return MP_VAL;
1526   }
1527
1528   /* if a < b then q=0, r = a */
1529   if (mp_cmp_mag (a, b) == MP_LT) {
1530     if (d != NULL) {
1531       res = mp_copy (a, d);
1532     } else {
1533       res = MP_OKAY;
1534     }
1535     if (c != NULL) {
1536       mp_zero (c);
1537     }
1538     return res;
1539   }
1540
1541   if ((res = mp_init_size (&q, a->used + 2)) != MP_OKAY) {
1542     return res;
1543   }
1544   q.used = a->used + 2;
1545
1546   if ((res = mp_init (&t1)) != MP_OKAY) {
1547     goto __Q;
1548   }
1549
1550   if ((res = mp_init (&t2)) != MP_OKAY) {
1551     goto __T1;
1552   }
1553
1554   if ((res = mp_init_copy (&x, a)) != MP_OKAY) {
1555     goto __T2;
1556   }
1557
1558   if ((res = mp_init_copy (&y, b)) != MP_OKAY) {
1559     goto __X;
1560   }
1561
1562   /* fix the sign */
1563   neg = (a->sign == b->sign) ? MP_ZPOS : MP_NEG;
1564   x.sign = y.sign = MP_ZPOS;
1565
1566   /* normalize both x and y, ensure that y >= b/2, [b == 2**DIGIT_BIT] */
1567   norm = mp_count_bits(&y) % DIGIT_BIT;
1568   if (norm < DIGIT_BIT-1) {
1569      norm = (DIGIT_BIT-1) - norm;
1570      if ((res = mp_mul_2d (&x, norm, &x)) != MP_OKAY) {
1571        goto __Y;
1572      }
1573      if ((res = mp_mul_2d (&y, norm, &y)) != MP_OKAY) {
1574        goto __Y;
1575      }
1576   } else {
1577      norm = 0;
1578   }
1579
1580   /* note hac does 0 based, so if used==5 then its 0,1,2,3,4, e.g. use 4 */
1581   n = x.used - 1;
1582   t = y.used - 1;
1583
1584   /* while (x >= y*b**n-t) do { q[n-t] += 1; x -= y*b**{n-t} } */
1585   if ((res = mp_lshd (&y, n - t)) != MP_OKAY) { /* y = y*b**{n-t} */
1586     goto __Y;
1587   }
1588
1589   while (mp_cmp (&x, &y) != MP_LT) {
1590     ++(q.dp[n - t]);
1591     if ((res = mp_sub (&x, &y, &x)) != MP_OKAY) {
1592       goto __Y;
1593     }
1594   }
1595
1596   /* reset y by shifting it back down */
1597   mp_rshd (&y, n - t);
1598
1599   /* step 3. for i from n down to (t + 1) */
1600   for (i = n; i >= (t + 1); i--) {
1601     if (i > x.used) {
1602       continue;
1603     }
1604
1605     /* step 3.1 if xi == yt then set q{i-t-1} to b-1, 
1606      * otherwise set q{i-t-1} to (xi*b + x{i-1})/yt */
1607     if (x.dp[i] == y.dp[t]) {
1608       q.dp[i - t - 1] = ((((mp_digit)1) << DIGIT_BIT) - 1);
1609     } else {
1610       mp_word tmp;
1611       tmp = ((mp_word) x.dp[i]) << ((mp_word) DIGIT_BIT);
1612       tmp |= ((mp_word) x.dp[i - 1]);
1613       tmp /= ((mp_word) y.dp[t]);
1614       if (tmp > (mp_word) MP_MASK)
1615         tmp = MP_MASK;
1616       q.dp[i - t - 1] = (mp_digit) (tmp & (mp_word) (MP_MASK));
1617     }
1618
1619     /* while (q{i-t-1} * (yt * b + y{t-1})) > 
1620              xi * b**2 + xi-1 * b + xi-2 
1621      
1622        do q{i-t-1} -= 1; 
1623     */
1624     q.dp[i - t - 1] = (q.dp[i - t - 1] + 1) & MP_MASK;
1625     do {
1626       q.dp[i - t - 1] = (q.dp[i - t - 1] - 1) & MP_MASK;
1627
1628       /* find left hand */
1629       mp_zero (&t1);
1630       t1.dp[0] = (t - 1 < 0) ? 0 : y.dp[t - 1];
1631       t1.dp[1] = y.dp[t];
1632       t1.used = 2;
1633       if ((res = mp_mul_d (&t1, q.dp[i - t - 1], &t1)) != MP_OKAY) {
1634         goto __Y;
1635       }
1636
1637       /* find right hand */
1638       t2.dp[0] = (i - 2 < 0) ? 0 : x.dp[i - 2];
1639       t2.dp[1] = (i - 1 < 0) ? 0 : x.dp[i - 1];
1640       t2.dp[2] = x.dp[i];
1641       t2.used = 3;
1642     } while (mp_cmp_mag(&t1, &t2) == MP_GT);
1643
1644     /* step 3.3 x = x - q{i-t-1} * y * b**{i-t-1} */
1645     if ((res = mp_mul_d (&y, q.dp[i - t - 1], &t1)) != MP_OKAY) {
1646       goto __Y;
1647     }
1648
1649     if ((res = mp_lshd (&t1, i - t - 1)) != MP_OKAY) {
1650       goto __Y;
1651     }
1652
1653     if ((res = mp_sub (&x, &t1, &x)) != MP_OKAY) {
1654       goto __Y;
1655     }
1656
1657     /* if x < 0 then { x = x + y*b**{i-t-1}; q{i-t-1} -= 1; } */
1658     if (x.sign == MP_NEG) {
1659       if ((res = mp_copy (&y, &t1)) != MP_OKAY) {
1660         goto __Y;
1661       }
1662       if ((res = mp_lshd (&t1, i - t - 1)) != MP_OKAY) {
1663         goto __Y;
1664       }
1665       if ((res = mp_add (&x, &t1, &x)) != MP_OKAY) {
1666         goto __Y;
1667       }
1668
1669       q.dp[i - t - 1] = (q.dp[i - t - 1] - 1UL) & MP_MASK;
1670     }
1671   }
1672
1673   /* now q is the quotient and x is the remainder 
1674    * [which we have to normalize] 
1675    */
1676   
1677   /* get sign before writing to c */
1678   x.sign = x.used == 0 ? MP_ZPOS : a->sign;
1679
1680   if (c != NULL) {
1681     mp_clamp (&q);
1682     mp_exch (&q, c);
1683     c->sign = neg;
1684   }
1685
1686   if (d != NULL) {
1687     mp_div_2d (&x, norm, &x, NULL);
1688     mp_exch (&x, d);
1689   }
1690
1691   res = MP_OKAY;
1692
1693 __Y:mp_clear (&y);
1694 __X:mp_clear (&x);
1695 __T2:mp_clear (&t2);
1696 __T1:mp_clear (&t1);
1697 __Q:mp_clear (&q);
1698   return res;
1699 }
1700
1701 static int s_is_power_of_two(mp_digit b, int *p)
1702 {
1703    int x;
1704
1705    for (x = 1; x < DIGIT_BIT; x++) {
1706       if (b == (((mp_digit)1)<<x)) {
1707          *p = x;
1708          return 1;
1709       }
1710    }
1711    return 0;
1712 }
1713
1714 /* single digit division (based on routine from MPI) */
1715 static int mp_div_d (const mp_int * a, mp_digit b, mp_int * c, mp_digit * d)
1716 {
1717   mp_int  q;
1718   mp_word w;
1719   mp_digit t;
1720   int     res, ix;
1721
1722   /* cannot divide by zero */
1723   if (b == 0) {
1724      return MP_VAL;
1725   }
1726
1727   /* quick outs */
1728   if (b == 1 || mp_iszero(a) == 1) {
1729      if (d != NULL) {
1730         *d = 0;
1731      }
1732      if (c != NULL) {
1733         return mp_copy(a, c);
1734      }
1735      return MP_OKAY;
1736   }
1737
1738   /* power of two ? */
1739   if (s_is_power_of_two(b, &ix) == 1) {
1740      if (d != NULL) {
1741         *d = a->dp[0] & ((((mp_digit)1)<<ix) - 1);
1742      }
1743      if (c != NULL) {
1744         return mp_div_2d(a, ix, c, NULL);
1745      }
1746      return MP_OKAY;
1747   }
1748
1749   /* no easy answer [c'est la vie].  Just division */
1750   if ((res = mp_init_size(&q, a->used)) != MP_OKAY) {
1751      return res;
1752   }
1753   
1754   q.used = a->used;
1755   q.sign = a->sign;
1756   w = 0;
1757   for (ix = a->used - 1; ix >= 0; ix--) {
1758      w = (w << ((mp_word)DIGIT_BIT)) | ((mp_word)a->dp[ix]);
1759      
1760      if (w >= b) {
1761         t = (mp_digit)(w / b);
1762         w -= ((mp_word)t) * ((mp_word)b);
1763       } else {
1764         t = 0;
1765       }
1766       q.dp[ix] = t;
1767   }
1768
1769   if (d != NULL) {
1770      *d = (mp_digit)w;
1771   }
1772   
1773   if (c != NULL) {
1774      mp_clamp(&q);
1775      mp_exch(&q, c);
1776   }
1777   mp_clear(&q);
1778   
1779   return res;
1780 }
1781
1782 /* reduce "x" in place modulo "n" using the Diminished Radix algorithm.
1783  *
1784  * Based on algorithm from the paper
1785  *
1786  * "Generating Efficient Primes for Discrete Log Cryptosystems"
1787  *                 Chae Hoon Lim, Pil Loong Lee,
1788  *          POSTECH Information Research Laboratories
1789  *
1790  * The modulus must be of a special format [see manual]
1791  *
1792  * Has been modified to use algorithm 7.10 from the LTM book instead
1793  *
1794  * Input x must be in the range 0 <= x <= (n-1)**2
1795  */
1796 static int
1797 mp_dr_reduce (mp_int * x, const mp_int * n, mp_digit k)
1798 {
1799   int      err, i, m;
1800   mp_word  r;
1801   mp_digit mu, *tmpx1, *tmpx2;
1802
1803   /* m = digits in modulus */
1804   m = n->used;
1805
1806   /* ensure that "x" has at least 2m digits */
1807   if (x->alloc < m + m) {
1808     if ((err = mp_grow (x, m + m)) != MP_OKAY) {
1809       return err;
1810     }
1811   }
1812
1813 /* top of loop, this is where the code resumes if
1814  * another reduction pass is required.
1815  */
1816 top:
1817   /* aliases for digits */
1818   /* alias for lower half of x */
1819   tmpx1 = x->dp;
1820
1821   /* alias for upper half of x, or x/B**m */
1822   tmpx2 = x->dp + m;
1823
1824   /* set carry to zero */
1825   mu = 0;
1826
1827   /* compute (x mod B**m) + k * [x/B**m] inline and inplace */
1828   for (i = 0; i < m; i++) {
1829       r         = ((mp_word)*tmpx2++) * ((mp_word)k) + *tmpx1 + mu;
1830       *tmpx1++  = (mp_digit)(r & MP_MASK);
1831       mu        = (mp_digit)(r >> ((mp_word)DIGIT_BIT));
1832   }
1833
1834   /* set final carry */
1835   *tmpx1++ = mu;
1836
1837   /* zero words above m */
1838   for (i = m + 1; i < x->used; i++) {
1839       *tmpx1++ = 0;
1840   }
1841
1842   /* clamp, sub and return */
1843   mp_clamp (x);
1844
1845   /* if x >= n then subtract and reduce again
1846    * Each successive "recursion" makes the input smaller and smaller.
1847    */
1848   if (mp_cmp_mag (x, n) != MP_LT) {
1849     s_mp_sub(x, n, x);
1850     goto top;
1851   }
1852   return MP_OKAY;
1853 }
1854
1855 /* sets the value of "d" required for mp_dr_reduce */
1856 static void mp_dr_setup(const mp_int *a, mp_digit *d)
1857 {
1858    /* the casts are required if DIGIT_BIT is one less than
1859     * the number of bits in a mp_digit [e.g. DIGIT_BIT==31]
1860     */
1861    *d = (mp_digit)((((mp_word)1) << ((mp_word)DIGIT_BIT)) - 
1862         ((mp_word)a->dp[0]));
1863 }
1864
1865 /* this is a shell function that calls either the normal or Montgomery
1866  * exptmod functions.  Originally the call to the montgomery code was
1867  * embedded in the normal function but that wasted a lot of stack space
1868  * for nothing (since 99% of the time the Montgomery code would be called)
1869  */
1870 int mp_exptmod (const mp_int * G, const mp_int * X, mp_int * P, mp_int * Y)
1871 {
1872   int dr;
1873
1874   /* modulus P must be positive */
1875   if (P->sign == MP_NEG) {
1876      return MP_VAL;
1877   }
1878
1879   /* if exponent X is negative we have to recurse */
1880   if (X->sign == MP_NEG) {
1881      mp_int tmpG, tmpX;
1882      int err;
1883
1884      /* first compute 1/G mod P */
1885      if ((err = mp_init(&tmpG)) != MP_OKAY) {
1886         return err;
1887      }
1888      if ((err = mp_invmod(G, P, &tmpG)) != MP_OKAY) {
1889         mp_clear(&tmpG);
1890         return err;
1891      }
1892
1893      /* now get |X| */
1894      if ((err = mp_init(&tmpX)) != MP_OKAY) {
1895         mp_clear(&tmpG);
1896         return err;
1897      }
1898      if ((err = mp_abs(X, &tmpX)) != MP_OKAY) {
1899         mp_clear_multi(&tmpG, &tmpX, NULL);
1900         return err;
1901      }
1902
1903      /* and now compute (1/G)**|X| instead of G**X [X < 0] */
1904      err = mp_exptmod(&tmpG, &tmpX, P, Y);
1905      mp_clear_multi(&tmpG, &tmpX, NULL);
1906      return err;
1907   }
1908
1909   dr = 0;
1910
1911   /* if the modulus is odd or dr != 0 use the fast method */
1912   if (mp_isodd (P) == 1 || dr !=  0) {
1913     return mp_exptmod_fast (G, X, P, Y, dr);
1914   } else {
1915     /* otherwise use the generic Barrett reduction technique */
1916     return s_mp_exptmod (G, X, P, Y);
1917   }
1918 }
1919
1920 /* computes Y == G**X mod P, HAC pp.616, Algorithm 14.85
1921  *
1922  * Uses a left-to-right k-ary sliding window to compute the modular exponentiation.
1923  * The value of k changes based on the size of the exponent.
1924  *
1925  * Uses Montgomery or Diminished Radix reduction [whichever appropriate]
1926  */
1927
1928 int
1929 mp_exptmod_fast (const mp_int * G, const mp_int * X, mp_int * P, mp_int * Y, int redmode)
1930 {
1931   mp_int  M[256], res;
1932   mp_digit buf, mp;
1933   int     err, bitbuf, bitcpy, bitcnt, mode, digidx, x, y, winsize;
1934
1935   /* use a pointer to the reduction algorithm.  This allows us to use
1936    * one of many reduction algorithms without modding the guts of
1937    * the code with if statements everywhere.
1938    */
1939   int     (*redux)(mp_int*,const mp_int*,mp_digit);
1940
1941   /* find window size */
1942   x = mp_count_bits (X);
1943   if (x <= 7) {
1944     winsize = 2;
1945   } else if (x <= 36) {
1946     winsize = 3;
1947   } else if (x <= 140) {
1948     winsize = 4;
1949   } else if (x <= 450) {
1950     winsize = 5;
1951   } else if (x <= 1303) {
1952     winsize = 6;
1953   } else if (x <= 3529) {
1954     winsize = 7;
1955   } else {
1956     winsize = 8;
1957   }
1958
1959   /* init M array */
1960   /* init first cell */
1961   if ((err = mp_init(&M[1])) != MP_OKAY) {
1962      return err;
1963   }
1964
1965   /* now init the second half of the array */
1966   for (x = 1<<(winsize-1); x < (1 << winsize); x++) {
1967     if ((err = mp_init(&M[x])) != MP_OKAY) {
1968       for (y = 1<<(winsize-1); y < x; y++) {
1969         mp_clear (&M[y]);
1970       }
1971       mp_clear(&M[1]);
1972       return err;
1973     }
1974   }
1975
1976   /* determine and setup reduction code */
1977   if (redmode == 0) {
1978      /* now setup montgomery  */
1979      if ((err = mp_montgomery_setup (P, &mp)) != MP_OKAY) {
1980         goto __M;
1981      }
1982
1983      /* automatically pick the comba one if available (saves quite a few calls/ifs) */
1984      if (((P->used * 2 + 1) < MP_WARRAY) &&
1985           P->used < (1 << ((CHAR_BIT * sizeof (mp_word)) - (2 * DIGIT_BIT)))) {
1986         redux = fast_mp_montgomery_reduce;
1987      } else {
1988         /* use slower baseline Montgomery method */
1989         redux = mp_montgomery_reduce;
1990      }
1991   } else if (redmode == 1) {
1992      /* setup DR reduction for moduli of the form B**k - b */
1993      mp_dr_setup(P, &mp);
1994      redux = mp_dr_reduce;
1995   } else {
1996      /* setup DR reduction for moduli of the form 2**k - b */
1997      if ((err = mp_reduce_2k_setup(P, &mp)) != MP_OKAY) {
1998         goto __M;
1999      }
2000      redux = mp_reduce_2k;
2001   }
2002
2003   /* setup result */
2004   if ((err = mp_init (&res)) != MP_OKAY) {
2005     goto __M;
2006   }
2007
2008   /* create M table
2009    *
2010
2011    *
2012    * The first half of the table is not computed though accept for M[0] and M[1]
2013    */
2014
2015   if (redmode == 0) {
2016      /* now we need R mod m */
2017      if ((err = mp_montgomery_calc_normalization (&res, P)) != MP_OKAY) {
2018        goto __RES;
2019      }
2020
2021      /* now set M[1] to G * R mod m */
2022      if ((err = mp_mulmod (G, &res, P, &M[1])) != MP_OKAY) {
2023        goto __RES;
2024      }
2025   } else {
2026      mp_set(&res, 1);
2027      if ((err = mp_mod(G, P, &M[1])) != MP_OKAY) {
2028         goto __RES;
2029      }
2030   }
2031
2032   /* compute the value at M[1<<(winsize-1)] by squaring M[1] (winsize-1) times */
2033   if ((err = mp_copy (&M[1], &M[1 << (winsize - 1)])) != MP_OKAY) {
2034     goto __RES;
2035   }
2036
2037   for (x = 0; x < (winsize - 1); x++) {
2038     if ((err = mp_sqr (&M[1 << (winsize - 1)], &M[1 << (winsize - 1)])) != MP_OKAY) {
2039       goto __RES;
2040     }
2041     if ((err = redux (&M[1 << (winsize - 1)], P, mp)) != MP_OKAY) {
2042       goto __RES;
2043     }
2044   }
2045
2046   /* create upper table */
2047   for (x = (1 << (winsize - 1)) + 1; x < (1 << winsize); x++) {
2048     if ((err = mp_mul (&M[x - 1], &M[1], &M[x])) != MP_OKAY) {
2049       goto __RES;
2050     }
2051     if ((err = redux (&M[x], P, mp)) != MP_OKAY) {
2052       goto __RES;
2053     }
2054   }
2055
2056   /* set initial mode and bit cnt */
2057   mode   = 0;
2058   bitcnt = 1;
2059   buf    = 0;
2060   digidx = X->used - 1;
2061   bitcpy = 0;
2062   bitbuf = 0;
2063
2064   for (;;) {
2065     /* grab next digit as required */
2066     if (--bitcnt == 0) {
2067       /* if digidx == -1 we are out of digits so break */
2068       if (digidx == -1) {
2069         break;
2070       }
2071       /* read next digit and reset bitcnt */
2072       buf    = X->dp[digidx--];
2073       bitcnt = DIGIT_BIT;
2074     }
2075
2076     /* grab the next msb from the exponent */
2077     y     = (buf >> (DIGIT_BIT - 1)) & 1;
2078     buf <<= (mp_digit)1;
2079
2080     /* if the bit is zero and mode == 0 then we ignore it
2081      * These represent the leading zero bits before the first 1 bit
2082      * in the exponent.  Technically this opt is not required but it
2083      * does lower the # of trivial squaring/reductions used
2084      */
2085     if (mode == 0 && y == 0) {
2086       continue;
2087     }
2088
2089     /* if the bit is zero and mode == 1 then we square */
2090     if (mode == 1 && y == 0) {
2091       if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
2092         goto __RES;
2093       }
2094       if ((err = redux (&res, P, mp)) != MP_OKAY) {
2095         goto __RES;
2096       }
2097       continue;
2098     }
2099
2100     /* else we add it to the window */
2101     bitbuf |= (y << (winsize - ++bitcpy));
2102     mode    = 2;
2103
2104     if (bitcpy == winsize) {
2105       /* ok window is filled so square as required and multiply  */
2106       /* square first */
2107       for (x = 0; x < winsize; x++) {
2108         if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
2109           goto __RES;
2110         }
2111         if ((err = redux (&res, P, mp)) != MP_OKAY) {
2112           goto __RES;
2113         }
2114       }
2115
2116       /* then multiply */
2117       if ((err = mp_mul (&res, &M[bitbuf], &res)) != MP_OKAY) {
2118         goto __RES;
2119       }
2120       if ((err = redux (&res, P, mp)) != MP_OKAY) {
2121         goto __RES;
2122       }
2123
2124       /* empty window and reset */
2125       bitcpy = 0;
2126       bitbuf = 0;
2127       mode   = 1;
2128     }
2129   }
2130
2131   /* if bits remain then square/multiply */
2132   if (mode == 2 && bitcpy > 0) {
2133     /* square then multiply if the bit is set */
2134     for (x = 0; x < bitcpy; x++) {
2135       if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
2136         goto __RES;
2137       }
2138       if ((err = redux (&res, P, mp)) != MP_OKAY) {
2139         goto __RES;
2140       }
2141
2142       /* get next bit of the window */
2143       bitbuf <<= 1;
2144       if ((bitbuf & (1 << winsize)) != 0) {
2145         /* then multiply */
2146         if ((err = mp_mul (&res, &M[1], &res)) != MP_OKAY) {
2147           goto __RES;
2148         }
2149         if ((err = redux (&res, P, mp)) != MP_OKAY) {
2150           goto __RES;
2151         }
2152       }
2153     }
2154   }
2155
2156   if (redmode == 0) {
2157      /* fixup result if Montgomery reduction is used
2158       * recall that any value in a Montgomery system is
2159       * actually multiplied by R mod n.  So we have
2160       * to reduce one more time to cancel out the factor
2161       * of R.
2162       */
2163      if ((err = redux(&res, P, mp)) != MP_OKAY) {
2164        goto __RES;
2165      }
2166   }
2167
2168   /* swap res with Y */
2169   mp_exch (&res, Y);
2170   err = MP_OKAY;
2171 __RES:mp_clear (&res);
2172 __M:
2173   mp_clear(&M[1]);
2174   for (x = 1<<(winsize-1); x < (1 << winsize); x++) {
2175     mp_clear (&M[x]);
2176   }
2177   return err;
2178 }
2179
2180 /* Greatest Common Divisor using the binary method */
2181 int mp_gcd (const mp_int * a, const mp_int * b, mp_int * c)
2182 {
2183   mp_int  u, v;
2184   int     k, u_lsb, v_lsb, res;
2185
2186   /* either zero than gcd is the largest */
2187   if (mp_iszero (a) == 1 && mp_iszero (b) == 0) {
2188     return mp_abs (b, c);
2189   }
2190   if (mp_iszero (a) == 0 && mp_iszero (b) == 1) {
2191     return mp_abs (a, c);
2192   }
2193
2194   /* optimized.  At this point if a == 0 then
2195    * b must equal zero too
2196    */
2197   if (mp_iszero (a) == 1) {
2198     mp_zero(c);
2199     return MP_OKAY;
2200   }
2201
2202   /* get copies of a and b we can modify */
2203   if ((res = mp_init_copy (&u, a)) != MP_OKAY) {
2204     return res;
2205   }
2206
2207   if ((res = mp_init_copy (&v, b)) != MP_OKAY) {
2208     goto __U;
2209   }
2210
2211   /* must be positive for the remainder of the algorithm */
2212   u.sign = v.sign = MP_ZPOS;
2213
2214   /* B1.  Find the common power of two for u and v */
2215   u_lsb = mp_cnt_lsb(&u);
2216   v_lsb = mp_cnt_lsb(&v);
2217   k     = MIN(u_lsb, v_lsb);
2218
2219   if (k > 0) {
2220      /* divide the power of two out */
2221      if ((res = mp_div_2d(&u, k, &u, NULL)) != MP_OKAY) {
2222         goto __V;
2223      }
2224
2225      if ((res = mp_div_2d(&v, k, &v, NULL)) != MP_OKAY) {
2226         goto __V;
2227      }
2228   }
2229
2230   /* divide any remaining factors of two out */
2231   if (u_lsb != k) {
2232      if ((res = mp_div_2d(&u, u_lsb - k, &u, NULL)) != MP_OKAY) {
2233         goto __V;
2234      }
2235   }
2236
2237   if (v_lsb != k) {
2238      if ((res = mp_div_2d(&v, v_lsb - k, &v, NULL)) != MP_OKAY) {
2239         goto __V;
2240      }
2241   }
2242
2243   while (mp_iszero(&v) == 0) {
2244      /* make sure v is the largest */
2245      if (mp_cmp_mag(&u, &v) == MP_GT) {
2246         /* swap u and v to make sure v is >= u */
2247         mp_exch(&u, &v);
2248      }
2249      
2250      /* subtract smallest from largest */
2251      if ((res = s_mp_sub(&v, &u, &v)) != MP_OKAY) {
2252         goto __V;
2253      }
2254      
2255      /* Divide out all factors of two */
2256      if ((res = mp_div_2d(&v, mp_cnt_lsb(&v), &v, NULL)) != MP_OKAY) {
2257         goto __V;
2258      } 
2259   } 
2260
2261   /* multiply by 2**k which we divided out at the beginning */
2262   if ((res = mp_mul_2d (&u, k, c)) != MP_OKAY) {
2263      goto __V;
2264   }
2265   c->sign = MP_ZPOS;
2266   res = MP_OKAY;
2267 __V:mp_clear (&u);
2268 __U:mp_clear (&v);
2269   return res;
2270 }
2271
2272 /* get the lower 32-bits of an mp_int */
2273 unsigned long mp_get_int(const mp_int * a)
2274 {
2275   int i;
2276   unsigned long res;
2277
2278   if (a->used == 0) {
2279      return 0;
2280   }
2281
2282   /* get number of digits of the lsb we have to read */
2283   i = MIN(a->used,(int)((sizeof(unsigned long)*CHAR_BIT+DIGIT_BIT-1)/DIGIT_BIT))-1;
2284
2285   /* get most significant digit of result */
2286   res = DIGIT(a,i);
2287    
2288   while (--i >= 0) {
2289     res = (res << DIGIT_BIT) | DIGIT(a,i);
2290   }
2291
2292   /* force result to 32-bits always so it is consistent on non 32-bit platforms */
2293   return res & 0xFFFFFFFFUL;
2294 }
2295
2296 /* creates "a" then copies b into it */
2297 int mp_init_copy (mp_int * a, const mp_int * b)
2298 {
2299   int     res;
2300
2301   if ((res = mp_init (a)) != MP_OKAY) {
2302     return res;
2303   }
2304   return mp_copy (b, a);
2305 }
2306
2307 int mp_init_multi(mp_int *mp, ...) 
2308 {
2309     mp_err res = MP_OKAY;      /* Assume ok until proven otherwise */
2310     int n = 0;                 /* Number of ok inits */
2311     mp_int* cur_arg = mp;
2312     va_list args;
2313
2314     va_start(args, mp);        /* init args to next argument from caller */
2315     while (cur_arg != NULL) {
2316         if (mp_init(cur_arg) != MP_OKAY) {
2317             /* Oops - error! Back-track and mp_clear what we already
2318                succeeded in init-ing, then return error.
2319             */
2320             va_list clean_args;
2321             
2322             /* end the current list */
2323             va_end(args);
2324             
2325             /* now start cleaning up */            
2326             cur_arg = mp;
2327             va_start(clean_args, mp);
2328             while (n--) {
2329                 mp_clear(cur_arg);
2330                 cur_arg = va_arg(clean_args, mp_int*);
2331             }
2332             va_end(clean_args);
2333             res = MP_MEM;
2334             break;
2335         }
2336         n++;
2337         cur_arg = va_arg(args, mp_int*);
2338     }
2339     va_end(args);
2340     return res;                /* Assumed ok, if error flagged above. */
2341 }
2342
2343 /* hac 14.61, pp608 */
2344 int mp_invmod (const mp_int * a, mp_int * b, mp_int * c)
2345 {
2346   /* b cannot be negative */
2347   if (b->sign == MP_NEG || mp_iszero(b) == 1) {
2348     return MP_VAL;
2349   }
2350
2351   /* if the modulus is odd we can use a faster routine instead */
2352   if (mp_isodd (b) == 1) {
2353     return fast_mp_invmod (a, b, c);
2354   }
2355   
2356   return mp_invmod_slow(a, b, c);
2357 }
2358
2359 /* hac 14.61, pp608 */
2360 int mp_invmod_slow (const mp_int * a, mp_int * b, mp_int * c)
2361 {
2362   mp_int  x, y, u, v, A, B, C, D;
2363   int     res;
2364
2365   /* b cannot be negative */
2366   if (b->sign == MP_NEG || mp_iszero(b) == 1) {
2367     return MP_VAL;
2368   }
2369
2370   /* init temps */
2371   if ((res = mp_init_multi(&x, &y, &u, &v, 
2372                            &A, &B, &C, &D, NULL)) != MP_OKAY) {
2373      return res;
2374   }
2375
2376   /* x = a, y = b */
2377   if ((res = mp_copy (a, &x)) != MP_OKAY) {
2378     goto __ERR;
2379   }
2380   if ((res = mp_copy (b, &y)) != MP_OKAY) {
2381     goto __ERR;
2382   }
2383
2384   /* 2. [modified] if x,y are both even then return an error! */
2385   if (mp_iseven (&x) == 1 && mp_iseven (&y) == 1) {
2386     res = MP_VAL;
2387     goto __ERR;
2388   }
2389
2390   /* 3. u=x, v=y, A=1, B=0, C=0,D=1 */
2391   if ((res = mp_copy (&x, &u)) != MP_OKAY) {
2392     goto __ERR;
2393   }
2394   if ((res = mp_copy (&y, &v)) != MP_OKAY) {
2395     goto __ERR;
2396   }
2397   mp_set (&A, 1);
2398   mp_set (&D, 1);
2399
2400 top:
2401   /* 4.  while u is even do */
2402   while (mp_iseven (&u) == 1) {
2403     /* 4.1 u = u/2 */
2404     if ((res = mp_div_2 (&u, &u)) != MP_OKAY) {
2405       goto __ERR;
2406     }
2407     /* 4.2 if A or B is odd then */
2408     if (mp_isodd (&A) == 1 || mp_isodd (&B) == 1) {
2409       /* A = (A+y)/2, B = (B-x)/2 */
2410       if ((res = mp_add (&A, &y, &A)) != MP_OKAY) {
2411          goto __ERR;
2412       }
2413       if ((res = mp_sub (&B, &x, &B)) != MP_OKAY) {
2414          goto __ERR;
2415       }
2416     }
2417     /* A = A/2, B = B/2 */
2418     if ((res = mp_div_2 (&A, &A)) != MP_OKAY) {
2419       goto __ERR;
2420     }
2421     if ((res = mp_div_2 (&B, &B)) != MP_OKAY) {
2422       goto __ERR;
2423     }
2424   }
2425
2426   /* 5.  while v is even do */
2427   while (mp_iseven (&v) == 1) {
2428     /* 5.1 v = v/2 */
2429     if ((res = mp_div_2 (&v, &v)) != MP_OKAY) {
2430       goto __ERR;
2431     }
2432     /* 5.2 if C or D is odd then */
2433     if (mp_isodd (&C) == 1 || mp_isodd (&D) == 1) {
2434       /* C = (C+y)/2, D = (D-x)/2 */
2435       if ((res = mp_add (&C, &y, &C)) != MP_OKAY) {
2436          goto __ERR;
2437       }
2438       if ((res = mp_sub (&D, &x, &D)) != MP_OKAY) {
2439          goto __ERR;
2440       }
2441     }
2442     /* C = C/2, D = D/2 */
2443     if ((res = mp_div_2 (&C, &C)) != MP_OKAY) {
2444       goto __ERR;
2445     }
2446     if ((res = mp_div_2 (&D, &D)) != MP_OKAY) {
2447       goto __ERR;
2448     }
2449   }
2450
2451   /* 6.  if u >= v then */
2452   if (mp_cmp (&u, &v) != MP_LT) {
2453     /* u = u - v, A = A - C, B = B - D */
2454     if ((res = mp_sub (&u, &v, &u)) != MP_OKAY) {
2455       goto __ERR;
2456     }
2457
2458     if ((res = mp_sub (&A, &C, &A)) != MP_OKAY) {
2459       goto __ERR;
2460     }
2461
2462     if ((res = mp_sub (&B, &D, &B)) != MP_OKAY) {
2463       goto __ERR;
2464     }
2465   } else {
2466     /* v - v - u, C = C - A, D = D - B */
2467     if ((res = mp_sub (&v, &u, &v)) != MP_OKAY) {
2468       goto __ERR;
2469     }
2470
2471     if ((res = mp_sub (&C, &A, &C)) != MP_OKAY) {
2472       goto __ERR;
2473     }
2474
2475     if ((res = mp_sub (&D, &B, &D)) != MP_OKAY) {
2476       goto __ERR;
2477     }
2478   }
2479
2480   /* if not zero goto step 4 */
2481   if (mp_iszero (&u) == 0)
2482     goto top;
2483
2484   /* now a = C, b = D, gcd == g*v */
2485
2486   /* if v != 1 then there is no inverse */
2487   if (mp_cmp_d (&v, 1) != MP_EQ) {
2488     res = MP_VAL;
2489     goto __ERR;
2490   }
2491
2492   /* if its too low */
2493   while (mp_cmp_d(&C, 0) == MP_LT) {
2494       if ((res = mp_add(&C, b, &C)) != MP_OKAY) {
2495          goto __ERR;
2496       }
2497   }
2498   
2499   /* too big */
2500   while (mp_cmp_mag(&C, b) != MP_LT) {
2501       if ((res = mp_sub(&C, b, &C)) != MP_OKAY) {
2502          goto __ERR;
2503       }
2504   }
2505   
2506   /* C is now the inverse */
2507   mp_exch (&C, c);
2508   res = MP_OKAY;
2509 __ERR:mp_clear_multi (&x, &y, &u, &v, &A, &B, &C, &D, NULL);
2510   return res;
2511 }
2512
2513 /* c = |a| * |b| using Karatsuba Multiplication using 
2514  * three half size multiplications
2515  *
2516  * Let B represent the radix [e.g. 2**DIGIT_BIT] and 
2517  * let n represent half of the number of digits in 
2518  * the min(a,b)
2519  *
2520  * a = a1 * B**n + a0
2521  * b = b1 * B**n + b0
2522  *
2523  * Then, a * b => 
2524    a1b1 * B**2n + ((a1 - a0)(b1 - b0) + a0b0 + a1b1) * B + a0b0
2525  *
2526  * Note that a1b1 and a0b0 are used twice and only need to be 
2527  * computed once.  So in total three half size (half # of 
2528  * digit) multiplications are performed, a0b0, a1b1 and 
2529  * (a1-b1)(a0-b0)
2530  *
2531  * Note that a multiplication of half the digits requires
2532  * 1/4th the number of single precision multiplications so in 
2533  * total after one call 25% of the single precision multiplications 
2534  * are saved.  Note also that the call to mp_mul can end up back 
2535  * in this function if the a0, a1, b0, or b1 are above the threshold.  
2536  * This is known as divide-and-conquer and leads to the famous 
2537  * O(N**lg(3)) or O(N**1.584) work which is asymptotically lower than
2538  * the standard O(N**2) that the baseline/comba methods use.  
2539  * Generally though the overhead of this method doesn't pay off 
2540  * until a certain size (N ~ 80) is reached.
2541  */
2542 int mp_karatsuba_mul (const mp_int * a, const mp_int * b, mp_int * c)
2543 {
2544   mp_int  x0, x1, y0, y1, t1, x0y0, x1y1;
2545   int     B, err;
2546
2547   /* default the return code to an error */
2548   err = MP_MEM;
2549
2550   /* min # of digits */
2551   B = MIN (a->used, b->used);
2552
2553   /* now divide in two */
2554   B = B >> 1;
2555
2556   /* init copy all the temps */
2557   if (mp_init_size (&x0, B) != MP_OKAY)
2558     goto ERR;
2559   if (mp_init_size (&x1, a->used - B) != MP_OKAY)
2560     goto X0;
2561   if (mp_init_size (&y0, B) != MP_OKAY)
2562     goto X1;
2563   if (mp_init_size (&y1, b->used - B) != MP_OKAY)
2564     goto Y0;
2565
2566   /* init temps */
2567   if (mp_init_size (&t1, B * 2) != MP_OKAY)
2568     goto Y1;
2569   if (mp_init_size (&x0y0, B * 2) != MP_OKAY)
2570     goto T1;
2571   if (mp_init_size (&x1y1, B * 2) != MP_OKAY)
2572     goto X0Y0;
2573
2574   /* now shift the digits */
2575   x0.used = y0.used = B;
2576   x1.used = a->used - B;
2577   y1.used = b->used - B;
2578
2579   {
2580     register int x;
2581     register mp_digit *tmpa, *tmpb, *tmpx, *tmpy;
2582
2583     /* we copy the digits directly instead of using higher level functions
2584      * since we also need to shift the digits
2585      */
2586     tmpa = a->dp;
2587     tmpb = b->dp;
2588
2589     tmpx = x0.dp;
2590     tmpy = y0.dp;
2591     for (x = 0; x < B; x++) {
2592       *tmpx++ = *tmpa++;
2593       *tmpy++ = *tmpb++;
2594     }
2595
2596     tmpx = x1.dp;
2597     for (x = B; x < a->used; x++) {
2598       *tmpx++ = *tmpa++;
2599     }
2600
2601     tmpy = y1.dp;
2602     for (x = B; x < b->used; x++) {
2603       *tmpy++ = *tmpb++;
2604     }
2605   }
2606
2607   /* only need to clamp the lower words since by definition the 
2608    * upper words x1/y1 must have a known number of digits
2609    */
2610   mp_clamp (&x0);
2611   mp_clamp (&y0);
2612
2613   /* now calc the products x0y0 and x1y1 */
2614   /* after this x0 is no longer required, free temp [x0==t2]! */
2615   if (mp_mul (&x0, &y0, &x0y0) != MP_OKAY)  
2616     goto X1Y1;          /* x0y0 = x0*y0 */
2617   if (mp_mul (&x1, &y1, &x1y1) != MP_OKAY)
2618     goto X1Y1;          /* x1y1 = x1*y1 */
2619
2620   /* now calc x1-x0 and y1-y0 */
2621   if (mp_sub (&x1, &x0, &t1) != MP_OKAY)
2622     goto X1Y1;          /* t1 = x1 - x0 */
2623   if (mp_sub (&y1, &y0, &x0) != MP_OKAY)
2624     goto X1Y1;          /* t2 = y1 - y0 */
2625   if (mp_mul (&t1, &x0, &t1) != MP_OKAY)
2626     goto X1Y1;          /* t1 = (x1 - x0) * (y1 - y0) */
2627
2628   /* add x0y0 */
2629   if (mp_add (&x0y0, &x1y1, &x0) != MP_OKAY)
2630     goto X1Y1;          /* t2 = x0y0 + x1y1 */
2631   if (mp_sub (&x0, &t1, &t1) != MP_OKAY)
2632     goto X1Y1;          /* t1 = x0y0 + x1y1 - (x1-x0)*(y1-y0) */
2633
2634   /* shift by B */
2635   if (mp_lshd (&t1, B) != MP_OKAY)
2636     goto X1Y1;          /* t1 = (x0y0 + x1y1 - (x1-x0)*(y1-y0))<<B */
2637   if (mp_lshd (&x1y1, B * 2) != MP_OKAY)
2638     goto X1Y1;          /* x1y1 = x1y1 << 2*B */
2639
2640   if (mp_add (&x0y0, &t1, &t1) != MP_OKAY)
2641     goto X1Y1;          /* t1 = x0y0 + t1 */
2642   if (mp_add (&t1, &x1y1, c) != MP_OKAY)
2643     goto X1Y1;          /* t1 = x0y0 + t1 + x1y1 */
2644
2645   /* Algorithm succeeded set the return code to MP_OKAY */
2646   err = MP_OKAY;
2647
2648 X1Y1:mp_clear (&x1y1);
2649 X0Y0:mp_clear (&x0y0);
2650 T1:mp_clear (&t1);
2651 Y1:mp_clear (&y1);
2652 Y0:mp_clear (&y0);
2653 X1:mp_clear (&x1);
2654 X0:mp_clear (&x0);
2655 ERR:
2656   return err;
2657 }
2658
2659 /* Karatsuba squaring, computes b = a*a using three 
2660  * half size squarings
2661  *
2662  * See comments of karatsuba_mul for details.  It 
2663  * is essentially the same algorithm but merely 
2664  * tuned to perform recursive squarings.
2665  */
2666 int mp_karatsuba_sqr (const mp_int * a, mp_int * b)
2667 {
2668   mp_int  x0, x1, t1, t2, x0x0, x1x1;
2669   int     B, err;
2670
2671   err = MP_MEM;
2672
2673   /* min # of digits */
2674   B = a->used;
2675
2676   /* now divide in two */
2677   B = B >> 1;
2678
2679   /* init copy all the temps */
2680   if (mp_init_size (&x0, B) != MP_OKAY)
2681     goto ERR;
2682   if (mp_init_size (&x1, a->used - B) != MP_OKAY)
2683     goto X0;
2684
2685   /* init temps */
2686   if (mp_init_size (&t1, a->used * 2) != MP_OKAY)
2687     goto X1;
2688   if (mp_init_size (&t2, a->used * 2) != MP_OKAY)
2689     goto T1;
2690   if (mp_init_size (&x0x0, B * 2) != MP_OKAY)
2691     goto T2;
2692   if (mp_init_size (&x1x1, (a->used - B) * 2) != MP_OKAY)
2693     goto X0X0;
2694
2695   {
2696     register int x;
2697     register mp_digit *dst, *src;
2698
2699     src = a->dp;
2700
2701     /* now shift the digits */
2702     dst = x0.dp;
2703     for (x = 0; x < B; x++) {
2704       *dst++ = *src++;
2705     }
2706
2707     dst = x1.dp;
2708     for (x = B; x < a->used; x++) {
2709       *dst++ = *src++;
2710     }
2711   }
2712
2713   x0.used = B;
2714   x1.used = a->used - B;
2715
2716   mp_clamp (&x0);
2717
2718   /* now calc the products x0*x0 and x1*x1 */
2719   if (mp_sqr (&x0, &x0x0) != MP_OKAY)
2720     goto X1X1;           /* x0x0 = x0*x0 */
2721   if (mp_sqr (&x1, &x1x1) != MP_OKAY)
2722     goto X1X1;           /* x1x1 = x1*x1 */
2723
2724   /* now calc (x1-x0)**2 */
2725   if (mp_sub (&x1, &x0, &t1) != MP_OKAY)
2726     goto X1X1;           /* t1 = x1 - x0 */
2727   if (mp_sqr (&t1, &t1) != MP_OKAY)
2728     goto X1X1;           /* t1 = (x1 - x0) * (x1 - x0) */
2729
2730   /* add x0y0 */
2731   if (s_mp_add (&x0x0, &x1x1, &t2) != MP_OKAY)
2732     goto X1X1;           /* t2 = x0x0 + x1x1 */
2733   if (mp_sub (&t2, &t1, &t1) != MP_OKAY)
2734     goto X1X1;           /* t1 = x0x0 + x1x1 - (x1-x0)*(x1-x0) */
2735
2736   /* shift by B */
2737   if (mp_lshd (&t1, B) != MP_OKAY)
2738     goto X1X1;           /* t1 = (x0x0 + x1x1 - (x1-x0)*(x1-x0))<<B */
2739   if (mp_lshd (&x1x1, B * 2) != MP_OKAY)
2740     goto X1X1;           /* x1x1 = x1x1 << 2*B */
2741
2742   if (mp_add (&x0x0, &t1, &t1) != MP_OKAY)
2743     goto X1X1;           /* t1 = x0x0 + t1 */
2744   if (mp_add (&t1, &x1x1, b) != MP_OKAY)
2745     goto X1X1;           /* t1 = x0x0 + t1 + x1x1 */
2746
2747   err = MP_OKAY;
2748
2749 X1X1:mp_clear (&x1x1);
2750 X0X0:mp_clear (&x0x0);
2751 T2:mp_clear (&t2);
2752 T1:mp_clear (&t1);
2753 X1:mp_clear (&x1);
2754 X0:mp_clear (&x0);
2755 ERR:
2756   return err;
2757 }
2758
2759 /* computes least common multiple as |a*b|/(a, b) */
2760 int mp_lcm (const mp_int * a, const mp_int * b, mp_int * c)
2761 {
2762   int     res;
2763   mp_int  t1, t2;
2764
2765
2766   if ((res = mp_init_multi (&t1, &t2, NULL)) != MP_OKAY) {
2767     return res;
2768   }
2769
2770   /* t1 = get the GCD of the two inputs */
2771   if ((res = mp_gcd (a, b, &t1)) != MP_OKAY) {
2772     goto __T;
2773   }
2774
2775   /* divide the smallest by the GCD */
2776   if (mp_cmp_mag(a, b) == MP_LT) {
2777      /* store quotient in t2 such that t2 * b is the LCM */
2778      if ((res = mp_div(a, &t1, &t2, NULL)) != MP_OKAY) {
2779         goto __T;
2780      }
2781      res = mp_mul(b, &t2, c);
2782   } else {
2783      /* store quotient in t2 such that t2 * a is the LCM */
2784      if ((res = mp_div(b, &t1, &t2, NULL)) != MP_OKAY) {
2785         goto __T;
2786      }
2787      res = mp_mul(a, &t2, c);
2788   }
2789
2790   /* fix the sign to positive */
2791   c->sign = MP_ZPOS;
2792
2793 __T:
2794   mp_clear_multi (&t1, &t2, NULL);
2795   return res;
2796 }
2797
2798 /* c = a mod b, 0 <= c < b */
2799 int
2800 mp_mod (const mp_int * a, mp_int * b, mp_int * c)
2801 {
2802   mp_int  t;
2803   int     res;
2804
2805   if ((res = mp_init (&t)) != MP_OKAY) {
2806     return res;
2807   }
2808
2809   if ((res = mp_div (a, b, NULL, &t)) != MP_OKAY) {
2810     mp_clear (&t);
2811     return res;
2812   }
2813
2814   if (t.sign != b->sign) {
2815     res = mp_add (b, &t, c);
2816   } else {
2817     res = MP_OKAY;
2818     mp_exch (&t, c);
2819   }
2820
2821   mp_clear (&t);
2822   return res;
2823 }
2824
2825 static int
2826 mp_mod_d (const mp_int * a, mp_digit b, mp_digit * c)
2827 {
2828   return mp_div_d(a, b, NULL, c);
2829 }
2830
2831 /* b = a*2 */
2832 static int mp_mul_2(const mp_int * a, mp_int * b)
2833 {
2834   int     x, res, oldused;
2835
2836   /* grow to accommodate result */
2837   if (b->alloc < a->used + 1) {
2838     if ((res = mp_grow (b, a->used + 1)) != MP_OKAY) {
2839       return res;
2840     }
2841   }
2842
2843   oldused = b->used;
2844   b->used = a->used;
2845
2846   {
2847     register mp_digit r, rr, *tmpa, *tmpb;
2848
2849     /* alias for source */
2850     tmpa = a->dp;
2851
2852     /* alias for dest */
2853     tmpb = b->dp;
2854
2855     /* carry */
2856     r = 0;
2857     for (x = 0; x < a->used; x++) {
2858
2859       /* get what will be the *next* carry bit from the
2860        * MSB of the current digit
2861        */
2862       rr = *tmpa >> ((mp_digit)(DIGIT_BIT - 1));
2863
2864       /* now shift up this digit, add in the carry [from the previous] */
2865       *tmpb++ = ((*tmpa++ << ((mp_digit)1)) | r) & MP_MASK;
2866
2867       /* copy the carry that would be from the source
2868        * digit into the next iteration
2869        */
2870       r = rr;
2871     }
2872
2873     /* new leading digit? */
2874     if (r != 0) {
2875       /* add a MSB which is always 1 at this point */
2876       *tmpb = 1;
2877       ++(b->used);
2878     }
2879
2880     /* now zero any excess digits on the destination
2881      * that we didn't write to
2882      */
2883     tmpb = b->dp + b->used;
2884     for (x = b->used; x < oldused; x++) {
2885       *tmpb++ = 0;
2886     }
2887   }
2888   b->sign = a->sign;
2889   return MP_OKAY;
2890 }
2891
2892 /*
2893  * shifts with subtractions when the result is greater than b.
2894  *
2895  * The method is slightly modified to shift B unconditionally up to just under
2896  * the leading bit of b.  This saves a lot of multiple precision shifting.
2897  */
2898 int mp_montgomery_calc_normalization (mp_int * a, const mp_int * b)
2899 {
2900   int     x, bits, res;
2901
2902   /* how many bits of last digit does b use */
2903   bits = mp_count_bits (b) % DIGIT_BIT;
2904
2905
2906   if (b->used > 1) {
2907      if ((res = mp_2expt (a, (b->used - 1) * DIGIT_BIT + bits - 1)) != MP_OKAY) {
2908         return res;
2909      }
2910   } else {
2911      mp_set(a, 1);
2912      bits = 1;
2913   }
2914
2915
2916   /* now compute C = A * B mod b */
2917   for (x = bits - 1; x < DIGIT_BIT; x++) {
2918     if ((res = mp_mul_2 (a, a)) != MP_OKAY) {
2919       return res;
2920     }
2921     if (mp_cmp_mag (a, b) != MP_LT) {
2922       if ((res = s_mp_sub (a, b, a)) != MP_OKAY) {
2923         return res;
2924       }
2925     }
2926   }
2927
2928   return MP_OKAY;
2929 }
2930
2931 /* computes xR**-1 == x (mod N) via Montgomery Reduction */
2932 int
2933 mp_montgomery_reduce (mp_int * x, const mp_int * n, mp_digit rho)
2934 {
2935   int     ix, res, digs;
2936   mp_digit mu;
2937
2938   /* can the fast reduction [comba] method be used?
2939    *
2940    * Note that unlike in mul you're safely allowed *less*
2941    * than the available columns [255 per default] since carries
2942    * are fixed up in the inner loop.
2943    */
2944   digs = n->used * 2 + 1;
2945   if ((digs < MP_WARRAY) &&
2946       n->used <
2947       (1 << ((CHAR_BIT * sizeof (mp_word)) - (2 * DIGIT_BIT)))) {
2948     return fast_mp_montgomery_reduce (x, n, rho);
2949   }
2950
2951   /* grow the input as required */
2952   if (x->alloc < digs) {
2953     if ((res = mp_grow (x, digs)) != MP_OKAY) {
2954       return res;
2955     }
2956   }
2957   x->used = digs;
2958
2959   for (ix = 0; ix < n->used; ix++) {
2960     /* mu = ai * rho mod b
2961      *
2962      * The value of rho must be precalculated via
2963      * montgomery_setup() such that
2964      * it equals -1/n0 mod b this allows the
2965      * following inner loop to reduce the
2966      * input one digit at a time
2967      */
2968     mu = (mp_digit) (((mp_word)x->dp[ix]) * ((mp_word)rho) & MP_MASK);
2969
2970     /* a = a + mu * m * b**i */
2971     {
2972       register int iy;
2973       register mp_digit *tmpn, *tmpx, u;
2974       register mp_word r;
2975
2976       /* alias for digits of the modulus */
2977       tmpn = n->dp;
2978
2979       /* alias for the digits of x [the input] */
2980       tmpx = x->dp + ix;
2981
2982       /* set the carry to zero */
2983       u = 0;
2984
2985       /* Multiply and add in place */
2986       for (iy = 0; iy < n->used; iy++) {
2987         /* compute product and sum */
2988         r       = ((mp_word)mu) * ((mp_word)*tmpn++) +
2989                   ((mp_word) u) + ((mp_word) * tmpx);
2990
2991         /* get carry */
2992         u       = (mp_digit)(r >> ((mp_word) DIGIT_BIT));
2993
2994         /* fix digit */
2995         *tmpx++ = (mp_digit)(r & ((mp_word) MP_MASK));
2996       }
2997       /* At this point the ix'th digit of x should be zero */
2998
2999
3000       /* propagate carries upwards as required*/
3001       while (u) {
3002         *tmpx   += u;
3003         u        = *tmpx >> DIGIT_BIT;
3004         *tmpx++ &= MP_MASK;
3005       }
3006     }
3007   }
3008
3009   /* at this point the n.used'th least
3010    * significant digits of x are all zero
3011    * which means we can shift x to the
3012    * right by n.used digits and the
3013    * residue is unchanged.
3014    */
3015
3016   /* x = x/b**n.used */
3017   mp_clamp(x);
3018   mp_rshd (x, n->used);
3019
3020   /* if x >= n then x = x - n */
3021   if (mp_cmp_mag (x, n) != MP_LT) {
3022     return s_mp_sub (x, n, x);
3023   }
3024
3025   return MP_OKAY;
3026 }
3027
3028 /* setups the montgomery reduction stuff */
3029 int
3030 mp_montgomery_setup (const mp_int * n, mp_digit * rho)
3031 {
3032   mp_digit x, b;
3033
3034 /* fast inversion mod 2**k
3035  *
3036  * Based on the fact that
3037  *
3038  * XA = 1 (mod 2**n)  =>  (X(2-XA)) A = 1 (mod 2**2n)
3039  *                    =>  2*X*A - X*X*A*A = 1
3040  *                    =>  2*(1) - (1)     = 1
3041  */
3042   b = n->dp[0];
3043
3044   if ((b & 1) == 0) {
3045     return MP_VAL;
3046   }
3047
3048   x = (((b + 2) & 4) << 1) + b; /* here x*a==1 mod 2**4 */
3049   x *= 2 - b * x;               /* here x*a==1 mod 2**8 */
3050   x *= 2 - b * x;               /* here x*a==1 mod 2**16 */
3051   x *= 2 - b * x;               /* here x*a==1 mod 2**32 */
3052
3053   /* rho = -1/m mod b */
3054   *rho = (((mp_word)1 << ((mp_word) DIGIT_BIT)) - x) & MP_MASK;
3055
3056   return MP_OKAY;
3057 }
3058
3059 /* high level multiplication (handles sign) */
3060 int mp_mul (const mp_int * a, const mp_int * b, mp_int * c)
3061 {
3062   int     res, neg;
3063   neg = (a->sign == b->sign) ? MP_ZPOS : MP_NEG;
3064
3065   /* use Karatsuba? */
3066   if (MIN (a->used, b->used) >= KARATSUBA_MUL_CUTOFF) {
3067     res = mp_karatsuba_mul (a, b, c);
3068   } else 
3069   {
3070     /* can we use the fast multiplier?
3071      *
3072      * The fast multiplier can be used if the output will 
3073      * have less than MP_WARRAY digits and the number of 
3074      * digits won't affect carry propagation
3075      */
3076     int     digs = a->used + b->used + 1;
3077
3078     if ((digs < MP_WARRAY) &&
3079         MIN(a->used, b->used) <= 
3080         (1 << ((CHAR_BIT * sizeof (mp_word)) - (2 * DIGIT_BIT)))) {
3081       res = fast_s_mp_mul_digs (a, b, c, digs);
3082     } else 
3083       res = s_mp_mul (a, b, c); /* uses s_mp_mul_digs */
3084   }
3085   c->sign = (c->used > 0) ? neg : MP_ZPOS;
3086   return res;
3087 }
3088
3089 /* d = a * b (mod c) */
3090 int
3091 mp_mulmod (const mp_int * a, const mp_int * b, mp_int * c, mp_int * d)
3092 {
3093   int     res;
3094   mp_int  t;
3095
3096   if ((res = mp_init (&t)) != MP_OKAY) {
3097     return res;
3098   }
3099
3100   if ((res = mp_mul (a, b, &t)) != MP_OKAY) {
3101     mp_clear (&t);
3102     return res;
3103   }
3104   res = mp_mod (&t, c, d);
3105   mp_clear (&t);
3106   return res;
3107 }
3108
3109 /* table of first PRIME_SIZE primes */
3110 static const mp_digit __prime_tab[] = {
3111   0x0002, 0x0003, 0x0005, 0x0007, 0x000B, 0x000D, 0x0011, 0x0013,
3112   0x0017, 0x001D, 0x001F, 0x0025, 0x0029, 0x002B, 0x002F, 0x0035,
3113   0x003B, 0x003D, 0x0043, 0x0047, 0x0049, 0x004F, 0x0053, 0x0059,
3114   0x0061, 0x0065, 0x0067, 0x006B, 0x006D, 0x0071, 0x007F, 0x0083,
3115   0x0089, 0x008B, 0x0095, 0x0097, 0x009D, 0x00A3, 0x00A7, 0x00AD,
3116   0x00B3, 0x00B5, 0x00BF, 0x00C1, 0x00C5, 0x00C7, 0x00D3, 0x00DF,
3117   0x00E3, 0x00E5, 0x00E9, 0x00EF, 0x00F1, 0x00FB, 0x0101, 0x0107,
3118   0x010D, 0x010F, 0x0115, 0x0119, 0x011B, 0x0125, 0x0133, 0x0137,
3119
3120   0x0139, 0x013D, 0x014B, 0x0151, 0x015B, 0x015D, 0x0161, 0x0167,
3121   0x016F, 0x0175, 0x017B, 0x017F, 0x0185, 0x018D, 0x0191, 0x0199,
3122   0x01A3, 0x01A5, 0x01AF, 0x01B1, 0x01B7, 0x01BB, 0x01C1, 0x01C9,
3123   0x01CD, 0x01CF, 0x01D3, 0x01DF, 0x01E7, 0x01EB, 0x01F3, 0x01F7,
3124   0x01FD, 0x0209, 0x020B, 0x021D, 0x0223, 0x022D, 0x0233, 0x0239,
3125   0x023B, 0x0241, 0x024B, 0x0251, 0x0257, 0x0259, 0x025F, 0x0265,
3126   0x0269, 0x026B, 0x0277, 0x0281, 0x0283, 0x0287, 0x028D, 0x0293,
3127   0x0295, 0x02A1, 0x02A5, 0x02AB, 0x02B3, 0x02BD, 0x02C5, 0x02CF,
3128
3129   0x02D7, 0x02DD, 0x02E3, 0x02E7, 0x02EF, 0x02F5, 0x02F9, 0x0301,
3130   0x0305, 0x0313, 0x031D, 0x0329, 0x032B, 0x0335, 0x0337, 0x033B,
3131   0x033D, 0x0347, 0x0355, 0x0359, 0x035B, 0x035F, 0x036D, 0x0371,
3132   0x0373, 0x0377, 0x038B, 0x038F, 0x0397, 0x03A1, 0x03A9, 0x03AD,
3133   0x03B3, 0x03B9, 0x03C7, 0x03CB, 0x03D1, 0x03D7, 0x03DF, 0x03E5,
3134   0x03F1, 0x03F5, 0x03FB, 0x03FD, 0x0407, 0x0409, 0x040F, 0x0419,
3135   0x041B, 0x0425, 0x0427, 0x042D, 0x043F, 0x0443, 0x0445, 0x0449,
3136   0x044F, 0x0455, 0x045D, 0x0463, 0x0469, 0x047F, 0x0481, 0x048B,
3137
3138   0x0493, 0x049D, 0x04A3, 0x04A9, 0x04B1, 0x04BD, 0x04C1, 0x04C7,
3139   0x04CD, 0x04CF, 0x04D5, 0x04E1, 0x04EB, 0x04FD, 0x04FF, 0x0503,
3140   0x0509, 0x050B, 0x0511, 0x0515, 0x0517, 0x051B, 0x0527, 0x0529,
3141   0x052F, 0x0551, 0x0557, 0x055D, 0x0565, 0x0577, 0x0581, 0x058F,
3142   0x0593, 0x0595, 0x0599, 0x059F, 0x05A7, 0x05AB, 0x05AD, 0x05B3,
3143   0x05BF, 0x05C9, 0x05CB, 0x05CF, 0x05D1, 0x05D5, 0x05DB, 0x05E7,
3144   0x05F3, 0x05FB, 0x0607, 0x060D, 0x0611, 0x0617, 0x061F, 0x0623,
3145   0x062B, 0x062F, 0x063D, 0x0641, 0x0647, 0x0649, 0x064D, 0x0653
3146 };
3147
3148 /* determines if an integers is divisible by one 
3149  * of the first PRIME_SIZE primes or not
3150  *
3151  * sets result to 0 if not, 1 if yes
3152  */
3153 static int mp_prime_is_divisible (const mp_int * a, int *result)
3154 {
3155   int     err, ix;
3156   mp_digit res;
3157
3158   /* default to not */
3159   *result = MP_NO;
3160
3161   for (ix = 0; ix < PRIME_SIZE; ix++) {
3162     /* what is a mod __prime_tab[ix] */
3163     if ((err = mp_mod_d (a, __prime_tab[ix], &res)) != MP_OKAY) {
3164       return err;
3165     }
3166
3167     /* is the residue zero? */
3168     if (res == 0) {
3169       *result = MP_YES;
3170       return MP_OKAY;
3171     }
3172   }
3173
3174   return MP_OKAY;
3175 }
3176
3177 /* Miller-Rabin test of "a" to the base of "b" as described in 
3178  * HAC pp. 139 Algorithm 4.24
3179  *
3180  * Sets result to 0 if definitely composite or 1 if probably prime.
3181  * Randomly the chance of error is no more than 1/4 and often 
3182  * very much lower.
3183  */
3184 static int mp_prime_miller_rabin (mp_int * a, const mp_int * b, int *result)
3185 {
3186   mp_int  n1, y, r;
3187   int     s, j, err;
3188
3189   /* default */
3190   *result = MP_NO;
3191
3192   /* ensure b > 1 */
3193   if (mp_cmp_d(b, 1) != MP_GT) {
3194      return MP_VAL;
3195   }     
3196
3197   /* get n1 = a - 1 */
3198   if ((err = mp_init_copy (&n1, a)) != MP_OKAY) {
3199     return err;
3200   }
3201   if ((err = mp_sub_d (&n1, 1, &n1)) != MP_OKAY) {
3202     goto __N1;
3203   }
3204
3205   /* set 2**s * r = n1 */
3206   if ((err = mp_init_copy (&r, &n1)) != MP_OKAY) {
3207     goto __N1;
3208   }
3209
3210   /* count the number of least significant bits
3211    * which are zero
3212    */
3213   s = mp_cnt_lsb(&r);
3214
3215   /* now divide n - 1 by 2**s */
3216   if ((err = mp_div_2d (&r, s, &r, NULL)) != MP_OKAY) {
3217     goto __R;
3218   }
3219
3220   /* compute y = b**r mod a */
3221   if ((err = mp_init (&y)) != MP_OKAY) {
3222     goto __R;
3223   }
3224   if ((err = mp_exptmod (b, &r, a, &y)) != MP_OKAY) {
3225     goto __Y;
3226   }
3227
3228   /* if y != 1 and y != n1 do */
3229   if (mp_cmp_d (&y, 1) != MP_EQ && mp_cmp (&y, &n1) != MP_EQ) {
3230     j = 1;
3231     /* while j <= s-1 and y != n1 */
3232     while ((j <= (s - 1)) && mp_cmp (&y, &n1) != MP_EQ) {
3233       if ((err = mp_sqrmod (&y, a, &y)) != MP_OKAY) {
3234          goto __Y;
3235       }
3236
3237       /* if y == 1 then composite */
3238       if (mp_cmp_d (&y, 1) == MP_EQ) {
3239          goto __Y;
3240       }
3241
3242       ++j;
3243     }
3244
3245     /* if y != n1 then composite */
3246     if (mp_cmp (&y, &n1) != MP_EQ) {
3247       goto __Y;
3248     }
3249   }
3250
3251   /* probably prime now */
3252   *result = MP_YES;
3253 __Y:mp_clear (&y);
3254 __R:mp_clear (&r);
3255 __N1:mp_clear (&n1);
3256   return err;
3257 }
3258
3259 /* performs a variable number of rounds of Miller-Rabin
3260  *
3261  * Probability of error after t rounds is no more than
3262
3263  *
3264  * Sets result to 1 if probably prime, 0 otherwise
3265  */
3266 static int mp_prime_is_prime (mp_int * a, int t, int *result)
3267 {
3268   mp_int  b;
3269   int     ix, err, res;
3270
3271   /* default to no */
3272   *result = MP_NO;
3273
3274   /* valid value of t? */
3275   if (t <= 0 || t > PRIME_SIZE) {
3276     return MP_VAL;
3277   }
3278
3279   /* is the input equal to one of the primes in the table? */
3280   for (ix = 0; ix < PRIME_SIZE; ix++) {
3281       if (mp_cmp_d(a, __prime_tab[ix]) == MP_EQ) {
3282          *result = 1;
3283          return MP_OKAY;
3284       }
3285   }
3286
3287   /* first perform trial division */
3288   if ((err = mp_prime_is_divisible (a, &res)) != MP_OKAY) {
3289     return err;
3290   }
3291
3292   /* return if it was trivially divisible */
3293   if (res == MP_YES) {
3294     return MP_OKAY;
3295   }
3296
3297   /* now perform the miller-rabin rounds */
3298   if ((err = mp_init (&b)) != MP_OKAY) {
3299     return err;
3300   }
3301
3302   for (ix = 0; ix < t; ix++) {
3303     /* set the prime */
3304     mp_set (&b, __prime_tab[ix]);
3305
3306     if ((err = mp_prime_miller_rabin (a, &b, &res)) != MP_OKAY) {
3307       goto __B;
3308     }
3309
3310     if (res == MP_NO) {
3311       goto __B;
3312     }
3313   }
3314
3315   /* passed the test */
3316   *result = MP_YES;
3317 __B:mp_clear (&b);
3318   return err;
3319 }
3320
3321 static const struct {
3322    int k, t;
3323 } sizes[] = {
3324 {   128,    28 },
3325 {   256,    16 },
3326 {   384,    10 },
3327 {   512,     7 },
3328 {   640,     6 },
3329 {   768,     5 },
3330 {   896,     4 },
3331 {  1024,     4 }
3332 };
3333
3334 /* returns # of RM trials required for a given bit size */
3335 int mp_prime_rabin_miller_trials(int size)
3336 {
3337    int x;
3338
3339    for (x = 0; x < (int)(sizeof(sizes)/(sizeof(sizes[0]))); x++) {
3340        if (sizes[x].k == size) {
3341           return sizes[x].t;
3342        } else if (sizes[x].k > size) {
3343           return (x == 0) ? sizes[0].t : sizes[x - 1].t;
3344        }
3345    }
3346    return sizes[x-1].t + 1;
3347 }
3348
3349 /* makes a truly random prime of a given size (bits),
3350  *
3351  * Flags are as follows:
3352  * 
3353  *   LTM_PRIME_BBS      - make prime congruent to 3 mod 4
3354  *   LTM_PRIME_SAFE     - make sure (p-1)/2 is prime as well (implies LTM_PRIME_BBS)
3355  *   LTM_PRIME_2MSB_OFF - make the 2nd highest bit zero
3356  *   LTM_PRIME_2MSB_ON  - make the 2nd highest bit one
3357  *
3358  * You have to supply a callback which fills in a buffer with random bytes.  "dat" is a parameter you can
3359  * have passed to the callback (e.g. a state or something).  This function doesn't use "dat" itself
3360  * so it can be NULL
3361  *
3362  */
3363
3364 /* This is possibly the mother of all prime generation functions, muahahahahaha! */
3365 int mp_prime_random_ex(mp_int *a, int t, int size, int flags, ltm_prime_callback cb, void *dat)
3366 {
3367    unsigned char *tmp, maskAND, maskOR_msb, maskOR_lsb;
3368    int res, err, bsize, maskOR_msb_offset;
3369
3370    /* sanity check the input */
3371    if (size <= 1 || t <= 0) {
3372       return MP_VAL;
3373    }
3374
3375    /* LTM_PRIME_SAFE implies LTM_PRIME_BBS */
3376    if (flags & LTM_PRIME_SAFE) {
3377       flags |= LTM_PRIME_BBS;
3378    }
3379
3380    /* calc the byte size */
3381    bsize = (size>>3)+((size&7)?1:0);
3382
3383    /* we need a buffer of bsize bytes */
3384    tmp = HeapAlloc(GetProcessHeap(), 0, bsize);
3385    if (tmp == NULL) {
3386       return MP_MEM;
3387    }
3388
3389    /* calc the maskAND value for the MSbyte*/
3390    maskAND = ((size&7) == 0) ? 0xFF : (0xFF >> (8 - (size & 7))); 
3391
3392    /* calc the maskOR_msb */
3393    maskOR_msb        = 0;
3394    maskOR_msb_offset = ((size & 7) == 1) ? 1 : 0;
3395    if (flags & LTM_PRIME_2MSB_ON) {
3396       maskOR_msb     |= 1 << ((size - 2) & 7);
3397    } else if (flags & LTM_PRIME_2MSB_OFF) {
3398       maskAND        &= ~(1 << ((size - 2) & 7));
3399    }
3400
3401    /* get the maskOR_lsb */
3402    maskOR_lsb         = 0;
3403    if (flags & LTM_PRIME_BBS) {
3404       maskOR_lsb     |= 3;
3405    }
3406
3407    do {
3408       /* read the bytes */
3409       if (cb(tmp, bsize, dat) != bsize) {
3410          err = MP_VAL;
3411          goto error;
3412       }
3413  
3414       /* work over the MSbyte */
3415       tmp[0]    &= maskAND;
3416       tmp[0]    |= 1 << ((size - 1) & 7);
3417
3418       /* mix in the maskORs */
3419       tmp[maskOR_msb_offset]   |= maskOR_msb;
3420       tmp[bsize-1]             |= maskOR_lsb;
3421
3422       /* read it in */
3423       if ((err = mp_read_unsigned_bin(a, tmp, bsize)) != MP_OKAY)     { goto error; }
3424
3425       /* is it prime? */
3426       if ((err = mp_prime_is_prime(a, t, &res)) != MP_OKAY)           { goto error; }
3427       if (res == MP_NO) {  
3428          continue;
3429       }
3430
3431       if (flags & LTM_PRIME_SAFE) {
3432          /* see if (a-1)/2 is prime */
3433          if ((err = mp_sub_d(a, 1, a)) != MP_OKAY)                    { goto error; }
3434          if ((err = mp_div_2(a, a)) != MP_OKAY)                       { goto error; }
3435  
3436          /* is it prime? */
3437          if ((err = mp_prime_is_prime(a, t, &res)) != MP_OKAY)        { goto error; }
3438       }
3439    } while (res == MP_NO);
3440
3441    if (flags & LTM_PRIME_SAFE) {
3442       /* restore a to the original value */
3443       if ((err = mp_mul_2(a, a)) != MP_OKAY)                          { goto error; }
3444       if ((err = mp_add_d(a, 1, a)) != MP_OKAY)                       { goto error; }
3445    }
3446
3447    err = MP_OKAY;
3448 error:
3449    HeapFree(GetProcessHeap(), 0, tmp);
3450    return err;
3451 }
3452
3453 /* reads an unsigned char array, assumes the msb is stored first [big endian] */
3454 int
3455 mp_read_unsigned_bin (mp_int * a, const unsigned char *b, int c)
3456 {
3457   int     res;
3458
3459   /* make sure there are at least two digits */
3460   if (a->alloc < 2) {
3461      if ((res = mp_grow(a, 2)) != MP_OKAY) {
3462         return res;
3463      }
3464   }
3465
3466   /* zero the int */
3467   mp_zero (a);
3468
3469   /* read the bytes in */
3470   while (c-- > 0) {
3471     if ((res = mp_mul_2d (a, 8, a)) != MP_OKAY) {
3472       return res;
3473     }
3474
3475       a->dp[0] |= *b++;
3476       a->used += 1;
3477   }
3478   mp_clamp (a);
3479   return MP_OKAY;
3480 }
3481
3482 /* reduces x mod m, assumes 0 < x < m**2, mu is 
3483  * precomputed via mp_reduce_setup.
3484  * From HAC pp.604 Algorithm 14.42
3485  */
3486 int
3487 mp_reduce (mp_int * x, const mp_int * m, const mp_int * mu)
3488 {
3489   mp_int  q;
3490   int     res, um = m->used;
3491
3492   /* q = x */
3493   if ((res = mp_init_copy (&q, x)) != MP_OKAY) {
3494     return res;
3495   }
3496
3497   /* q1 = x / b**(k-1)  */
3498   mp_rshd (&q, um - 1);         
3499
3500   /* according to HAC this optimization is ok */
3501   if (((unsigned long) um) > (((mp_digit)1) << (DIGIT_BIT - 1))) {
3502     if ((res = mp_mul (&q, mu, &q)) != MP_OKAY) {
3503       goto CLEANUP;
3504     }
3505   } else {
3506     if ((res = s_mp_mul_high_digs (&q, mu, &q, um - 1)) != MP_OKAY) {
3507       goto CLEANUP;
3508     }
3509   }
3510
3511   /* q3 = q2 / b**(k+1) */
3512   mp_rshd (&q, um + 1);         
3513
3514   /* x = x mod b**(k+1), quick (no division) */
3515   if ((res = mp_mod_2d (x, DIGIT_BIT * (um + 1), x)) != MP_OKAY) {
3516     goto CLEANUP;
3517   }
3518
3519   /* q = q * m mod b**(k+1), quick (no division) */
3520   if ((res = s_mp_mul_digs (&q, m, &q, um + 1)) != MP_OKAY) {
3521     goto CLEANUP;
3522   }
3523
3524   /* x = x - q */
3525   if ((res = mp_sub (x, &q, x)) != MP_OKAY) {
3526     goto CLEANUP;
3527   }
3528
3529   /* If x < 0, add b**(k+1) to it */
3530   if (mp_cmp_d (x, 0) == MP_LT) {
3531     mp_set (&q, 1);
3532     if ((res = mp_lshd (&q, um + 1)) != MP_OKAY)
3533       goto CLEANUP;
3534     if ((res = mp_add (x, &q, x)) != MP_OKAY)
3535       goto CLEANUP;
3536   }
3537
3538   /* Back off if it's too big */
3539   while (mp_cmp (x, m) != MP_LT) {
3540     if ((res = s_mp_sub (x, m, x)) != MP_OKAY) {
3541       goto CLEANUP;
3542     }
3543   }
3544   
3545 CLEANUP:
3546   mp_clear (&q);
3547
3548   return res;
3549 }
3550
3551 /* reduces a modulo n where n is of the form 2**p - d */
3552 int
3553 mp_reduce_2k(mp_int *a, const mp_int *n, mp_digit d)
3554 {
3555    mp_int q;
3556    int    p, res;
3557    
3558    if ((res = mp_init(&q)) != MP_OKAY) {
3559       return res;
3560    }
3561    
3562    p = mp_count_bits(n);    
3563 top:
3564    /* q = a/2**p, a = a mod 2**p */
3565    if ((res = mp_div_2d(a, p, &q, a)) != MP_OKAY) {
3566       goto ERR;
3567    }
3568    
3569    if (d != 1) {
3570       /* q = q * d */
3571       if ((res = mp_mul_d(&q, d, &q)) != MP_OKAY) { 
3572          goto ERR;
3573       }
3574    }
3575    
3576    /* a = a + q */
3577    if ((res = s_mp_add(a, &q, a)) != MP_OKAY) {
3578       goto ERR;
3579    }
3580    
3581    if (mp_cmp_mag(a, n) != MP_LT) {
3582       s_mp_sub(a, n, a);
3583       goto top;
3584    }
3585    
3586 ERR:
3587    mp_clear(&q);
3588    return res;
3589 }
3590
3591 /* determines the setup value */
3592 int 
3593 mp_reduce_2k_setup(const mp_int *a, mp_digit *d)
3594 {
3595    int res, p;
3596    mp_int tmp;
3597    
3598    if ((res = mp_init(&tmp)) != MP_OKAY) {
3599       return res;
3600    }
3601    
3602    p = mp_count_bits(a);
3603    if ((res = mp_2expt(&tmp, p)) != MP_OKAY) {
3604       mp_clear(&tmp);
3605       return res;
3606    }
3607    
3608    if ((res = s_mp_sub(&tmp, a, &tmp)) != MP_OKAY) {
3609       mp_clear(&tmp);
3610       return res;
3611    }
3612    
3613    *d = tmp.dp[0];
3614    mp_clear(&tmp);
3615    return MP_OKAY;
3616 }
3617
3618 /* pre-calculate the value required for Barrett reduction
3619  * For a given modulus "b" it calulates the value required in "a"
3620  */
3621 int mp_reduce_setup (mp_int * a, const mp_int * b)
3622 {
3623   int     res;
3624
3625   if ((res = mp_2expt (a, b->used * 2 * DIGIT_BIT)) != MP_OKAY) {
3626     return res;
3627   }
3628   return mp_div (a, b, a, NULL);
3629 }
3630
3631 /* set to a digit */
3632 void mp_set (mp_int * a, mp_digit b)
3633 {
3634   mp_zero (a);
3635   a->dp[0] = b & MP_MASK;
3636   a->used  = (a->dp[0] != 0) ? 1 : 0;
3637 }
3638
3639 /* set a 32-bit const */
3640 int mp_set_int (mp_int * a, unsigned long b)
3641 {
3642   int     x, res;
3643
3644   mp_zero (a);
3645   
3646   /* set four bits at a time */
3647   for (x = 0; x < 8; x++) {
3648     /* shift the number up four bits */
3649     if ((res = mp_mul_2d (a, 4, a)) != MP_OKAY) {
3650       return res;
3651     }
3652
3653     /* OR in the top four bits of the source */
3654     a->dp[0] |= (b >> 28) & 15;
3655
3656     /* shift the source up to the next four bits */
3657     b <<= 4;
3658
3659     /* ensure that digits are not clamped off */
3660     a->used += 1;
3661   }
3662   mp_clamp (a);
3663   return MP_OKAY;
3664 }
3665
3666 /* shrink a bignum */
3667 int mp_shrink (mp_int * a)
3668 {
3669   mp_digit *tmp;
3670   if (a->alloc != a->used && a->used > 0) {
3671     if ((tmp = HeapReAlloc(GetProcessHeap(), 0, a->dp, sizeof (mp_digit) * a->used)) == NULL) {
3672       return MP_MEM;
3673     }
3674     a->dp    = tmp;
3675     a->alloc = a->used;
3676   }
3677   return MP_OKAY;
3678 }
3679
3680 /* get the size for an signed equivalent */
3681 int mp_signed_bin_size (const mp_int * a)
3682 {
3683   return 1 + mp_unsigned_bin_size (a);
3684 }
3685
3686 /* computes b = a*a */
3687 int
3688 mp_sqr (const mp_int * a, mp_int * b)
3689 {
3690   int     res;
3691
3692 if (a->used >= KARATSUBA_SQR_CUTOFF) {
3693     res = mp_karatsuba_sqr (a, b);
3694   } else 
3695   {
3696     /* can we use the fast comba multiplier? */
3697     if ((a->used * 2 + 1) < MP_WARRAY && 
3698          a->used < 
3699          (1 << (sizeof(mp_word) * CHAR_BIT - 2*DIGIT_BIT - 1))) {
3700       res = fast_s_mp_sqr (a, b);
3701     } else
3702       res = s_mp_sqr (a, b);
3703   }
3704   b->sign = MP_ZPOS;
3705   return res;
3706 }
3707
3708 /* c = a * a (mod b) */
3709 int
3710 mp_sqrmod (const mp_int * a, mp_int * b, mp_int * c)
3711 {
3712   int     res;
3713   mp_int  t;
3714
3715   if ((res = mp_init (&t)) != MP_OKAY) {
3716     return res;
3717   }
3718
3719   if ((res = mp_sqr (a, &t)) != MP_OKAY) {
3720     mp_clear (&t);
3721     return res;
3722   }
3723   res = mp_mod (&t, b, c);
3724   mp_clear (&t);
3725   return res;
3726 }
3727
3728 /* high level subtraction (handles signs) */
3729 int
3730 mp_sub (mp_int * a, mp_int * b, mp_int * c)
3731 {
3732   int     sa, sb, res;
3733
3734   sa = a->sign;
3735   sb = b->sign;
3736
3737   if (sa != sb) {
3738     /* subtract a negative from a positive, OR */
3739     /* subtract a positive from a negative. */
3740     /* In either case, ADD their magnitudes, */
3741     /* and use the sign of the first number. */
3742     c->sign = sa;
3743     res = s_mp_add (a, b, c);
3744   } else {
3745     /* subtract a positive from a positive, OR */
3746     /* subtract a negative from a negative. */
3747     /* First, take the difference between their */
3748     /* magnitudes, then... */
3749     if (mp_cmp_mag (a, b) != MP_LT) {
3750       /* Copy the sign from the first */
3751       c->sign = sa;
3752       /* The first has a larger or equal magnitude */
3753       res = s_mp_sub (a, b, c);
3754     } else {
3755       /* The result has the *opposite* sign from */
3756       /* the first number. */
3757       c->sign = (sa == MP_ZPOS) ? MP_NEG : MP_ZPOS;
3758       /* The second has a larger magnitude */
3759       res = s_mp_sub (b, a, c);
3760     }
3761   }
3762   return res;
3763 }
3764
3765 /* single digit subtraction */
3766 int
3767 mp_sub_d (mp_int * a, mp_digit b, mp_int * c)
3768 {
3769   mp_digit *tmpa, *tmpc, mu;
3770   int       res, ix, oldused;
3771
3772   /* grow c as required */
3773   if (c->alloc < a->used + 1) {
3774      if ((res = mp_grow(c, a->used + 1)) != MP_OKAY) {
3775         return res;
3776      }
3777   }
3778
3779   /* if a is negative just do an unsigned
3780    * addition [with fudged signs]
3781    */
3782   if (a->sign == MP_NEG) {
3783      a->sign = MP_ZPOS;
3784      res     = mp_add_d(a, b, c);
3785      a->sign = c->sign = MP_NEG;
3786      return res;
3787   }
3788
3789   /* setup regs */
3790   oldused = c->used;
3791   tmpa    = a->dp;
3792   tmpc    = c->dp;
3793
3794   /* if a <= b simply fix the single digit */
3795   if ((a->used == 1 && a->dp[0] <= b) || a->used == 0) {
3796      if (a->used == 1) {
3797         *tmpc++ = b - *tmpa;
3798      } else {
3799         *tmpc++ = b;
3800      }
3801      ix      = 1;
3802
3803      /* negative/1digit */
3804      c->sign = MP_NEG;
3805      c->used = 1;
3806   } else {
3807      /* positive/size */
3808      c->sign = MP_ZPOS;
3809      c->used = a->used;
3810
3811      /* subtract first digit */
3812      *tmpc    = *tmpa++ - b;
3813      mu       = *tmpc >> (sizeof(mp_digit) * CHAR_BIT - 1);
3814      *tmpc++ &= MP_MASK;
3815
3816      /* handle rest of the digits */
3817      for (ix = 1; ix < a->used; ix++) {
3818         *tmpc    = *tmpa++ - mu;
3819         mu       = *tmpc >> (sizeof(mp_digit) * CHAR_BIT - 1);
3820         *tmpc++ &= MP_MASK;
3821      }
3822   }
3823
3824   /* zero excess digits */
3825   while (ix++ < oldused) {
3826      *tmpc++ = 0;
3827   }
3828   mp_clamp(c);
3829   return MP_OKAY;
3830 }
3831
3832 /* store in unsigned [big endian] format */
3833 int
3834 mp_to_unsigned_bin (const mp_int * a, unsigned char *b)
3835 {
3836   int     x, res;
3837   mp_int  t;
3838
3839   if ((res = mp_init_copy (&t, a)) != MP_OKAY) {
3840     return res;
3841   }
3842
3843   x = 0;
3844   while (mp_iszero (&t) == 0) {
3845     b[x++] = (unsigned char) (t.dp[0] & 255);
3846     if ((res = mp_div_2d (&t, 8, &t, NULL)) != MP_OKAY) {
3847       mp_clear (&t);
3848       return res;
3849     }
3850   }
3851   bn_reverse (b, x);
3852   mp_clear (&t);
3853   return MP_OKAY;
3854 }
3855
3856 /* get the size for an unsigned equivalent */
3857 int
3858 mp_unsigned_bin_size (const mp_int * a)
3859 {
3860   int     size = mp_count_bits (a);
3861   return (size / 8 + ((size & 7) != 0 ? 1 : 0));
3862 }
3863
3864 /* reverse an array, used for radix code */
3865 static void
3866 bn_reverse (unsigned char *s, int len)
3867 {
3868   int     ix, iy;
3869   unsigned char t;
3870
3871   ix = 0;
3872   iy = len - 1;
3873   while (ix < iy) {
3874     t     = s[ix];
3875     s[ix] = s[iy];
3876     s[iy] = t;
3877     ++ix;
3878     --iy;
3879   }
3880 }
3881
3882 /* low level addition, based on HAC pp.594, Algorithm 14.7 */
3883 static int
3884 s_mp_add (mp_int * a, mp_int * b, mp_int * c)
3885 {
3886   mp_int *x;
3887   int     olduse, res, min, max;
3888
3889   /* find sizes, we let |a| <= |b| which means we have to sort
3890    * them.  "x" will point to the input with the most digits
3891    */
3892   if (a->used > b->used) {
3893     min = b->used;
3894     max = a->used;
3895     x = a;
3896   } else {
3897     min = a->used;
3898     max = b->used;
3899     x = b;
3900   }
3901
3902   /* init result */
3903   if (c->alloc < max + 1) {
3904     if ((res = mp_grow (c, max + 1)) != MP_OKAY) {
3905       return res;
3906     }
3907   }
3908
3909   /* get old used digit count and set new one */
3910   olduse = c->used;
3911   c->used = max + 1;
3912
3913   {
3914     register mp_digit u, *tmpa, *tmpb, *tmpc;
3915     register int i;
3916
3917     /* alias for digit pointers */
3918
3919     /* first input */
3920     tmpa = a->dp;
3921
3922     /* second input */
3923     tmpb = b->dp;
3924
3925     /* destination */
3926     tmpc = c->dp;
3927
3928     /* zero the carry */
3929     u = 0;
3930     for (i = 0; i < min; i++) {
3931       /* Compute the sum at one digit, T[i] = A[i] + B[i] + U */
3932       *tmpc = *tmpa++ + *tmpb++ + u;
3933
3934       /* U = carry bit of T[i] */
3935       u = *tmpc >> ((mp_digit)DIGIT_BIT);
3936
3937       /* take away carry bit from T[i] */
3938       *tmpc++ &= MP_MASK;
3939     }
3940
3941     /* now copy higher words if any, that is in A+B 
3942      * if A or B has more digits add those in 
3943      */
3944     if (min != max) {
3945       for (; i < max; i++) {
3946         /* T[i] = X[i] + U */
3947         *tmpc = x->dp[i] + u;
3948
3949         /* U = carry bit of T[i] */
3950         u = *tmpc >> ((mp_digit)DIGIT_BIT);
3951
3952         /* take away carry bit from T[i] */
3953         *tmpc++ &= MP_MASK;
3954       }
3955     }
3956
3957     /* add carry */
3958     *tmpc++ = u;
3959
3960     /* clear digits above oldused */
3961     for (i = c->used; i < olduse; i++) {
3962       *tmpc++ = 0;
3963     }
3964   }
3965
3966   mp_clamp (c);
3967   return MP_OKAY;
3968 }
3969
3970 static int s_mp_exptmod (const mp_int * G, const mp_int * X, mp_int * P, mp_int * Y)
3971 {
3972   mp_int  M[256], res, mu;
3973   mp_digit buf;
3974   int     err, bitbuf, bitcpy, bitcnt, mode, digidx, x, y, winsize;
3975
3976   /* find window size */
3977   x = mp_count_bits (X);
3978   if (x <= 7) {
3979     winsize = 2;
3980   } else if (x <= 36) {
3981     winsize = 3;
3982   } else if (x <= 140) {
3983     winsize = 4;
3984   } else if (x <= 450) {
3985     winsize = 5;
3986   } else if (x <= 1303) {
3987     winsize = 6;
3988   } else if (x <= 3529) {
3989     winsize = 7;
3990   } else {
3991     winsize = 8;
3992   }
3993
3994   /* init M array */
3995   /* init first cell */
3996   if ((err = mp_init(&M[1])) != MP_OKAY) {
3997      return err; 
3998   }
3999
4000   /* now init the second half of the array */
4001   for (x = 1<<(winsize-1); x < (1 << winsize); x++) {
4002     if ((err = mp_init(&M[x])) != MP_OKAY) {
4003       for (y = 1<<(winsize-1); y < x; y++) {
4004         mp_clear (&M[y]);
4005       }
4006       mp_clear(&M[1]);
4007       return err;
4008     }
4009   }
4010
4011   /* create mu, used for Barrett reduction */
4012   if ((err = mp_init (&mu)) != MP_OKAY) {
4013     goto __M;
4014   }
4015   if ((err = mp_reduce_setup (&mu, P)) != MP_OKAY) {
4016     goto __MU;
4017   }
4018
4019   /* create M table
4020    *
4021    * The M table contains powers of the base, 
4022    * e.g. M[x] = G**x mod P
4023    *
4024    * The first half of the table is not 
4025    * computed though accept for M[0] and M[1]
4026    */
4027   if ((err = mp_mod (G, P, &M[1])) != MP_OKAY) {
4028     goto __MU;
4029   }
4030
4031   /* compute the value at M[1<<(winsize-1)] by squaring 
4032    * M[1] (winsize-1) times 
4033    */
4034   if ((err = mp_copy (&M[1], &M[1 << (winsize - 1)])) != MP_OKAY) {
4035     goto __MU;
4036   }
4037
4038   for (x = 0; x < (winsize - 1); x++) {
4039     if ((err = mp_sqr (&M[1 << (winsize - 1)], 
4040                        &M[1 << (winsize - 1)])) != MP_OKAY) {
4041       goto __MU;
4042     }
4043     if ((err = mp_reduce (&M[1 << (winsize - 1)], P, &mu)) != MP_OKAY) {
4044       goto __MU;
4045     }
4046   }
4047
4048   /* create upper table, that is M[x] = M[x-1] * M[1] (mod P)
4049    * for x = (2**(winsize - 1) + 1) to (2**winsize - 1)
4050    */
4051   for (x = (1 << (winsize - 1)) + 1; x < (1 << winsize); x++) {
4052     if ((err = mp_mul (&M[x - 1], &M[1], &M[x])) != MP_OKAY) {
4053       goto __MU;
4054     }
4055     if ((err = mp_reduce (&M[x], P, &mu)) != MP_OKAY) {
4056       goto __MU;
4057     }
4058   }
4059
4060   /* setup result */
4061   if ((err = mp_init (&res)) != MP_OKAY) {
4062     goto __MU;
4063   }
4064   mp_set (&res, 1);
4065
4066   /* set initial mode and bit cnt */
4067   mode   = 0;
4068   bitcnt = 1;
4069   buf    = 0;
4070   digidx = X->used - 1;
4071   bitcpy = 0;
4072   bitbuf = 0;
4073
4074   for (;;) {
4075     /* grab next digit as required */
4076     if (--bitcnt == 0) {
4077       /* if digidx == -1 we are out of digits */
4078       if (digidx == -1) {
4079         break;
4080       }
4081       /* read next digit and reset the bitcnt */
4082       buf    = X->dp[digidx--];
4083       bitcnt = DIGIT_BIT;
4084     }
4085
4086     /* grab the next msb from the exponent */
4087     y     = (buf >> (mp_digit)(DIGIT_BIT - 1)) & 1;
4088     buf <<= (mp_digit)1;
4089
4090     /* if the bit is zero and mode == 0 then we ignore it
4091      * These represent the leading zero bits before the first 1 bit
4092      * in the exponent.  Technically this opt is not required but it
4093      * does lower the # of trivial squaring/reductions used
4094      */
4095     if (mode == 0 && y == 0) {
4096       continue;
4097     }
4098
4099     /* if the bit is zero and mode == 1 then we square */
4100     if (mode == 1 && y == 0) {
4101       if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
4102         goto __RES;
4103       }
4104       if ((err = mp_reduce (&res, P, &mu)) != MP_OKAY) {
4105         goto __RES;
4106       }
4107       continue;
4108     }
4109
4110     /* else we add it to the window */
4111     bitbuf |= (y << (winsize - ++bitcpy));
4112     mode    = 2;
4113
4114     if (bitcpy == winsize) {
4115       /* ok window is filled so square as required and multiply  */
4116       /* square first */
4117       for (x = 0; x < winsize; x++) {
4118         if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
4119           goto __RES;
4120         }
4121         if ((err = mp_reduce (&res, P, &mu)) != MP_OKAY) {
4122           goto __RES;
4123         }
4124       }
4125
4126       /* then multiply */
4127       if ((err = mp_mul (&res, &M[bitbuf], &res)) != MP_OKAY) {
4128         goto __RES;
4129       }
4130       if ((err = mp_reduce (&res, P, &mu)) != MP_OKAY) {
4131         goto __RES;
4132       }
4133
4134       /* empty window and reset */
4135       bitcpy = 0;
4136       bitbuf = 0;
4137       mode   = 1;
4138     }
4139   }
4140
4141   /* if bits remain then square/multiply */
4142   if (mode == 2 && bitcpy > 0) {
4143     /* square then multiply if the bit is set */
4144     for (x = 0; x < bitcpy; x++) {
4145       if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
4146         goto __RES;
4147       }
4148       if ((err = mp_reduce (&res, P, &mu)) != MP_OKAY) {
4149         goto __RES;
4150       }
4151
4152       bitbuf <<= 1;
4153       if ((bitbuf & (1 << winsize)) != 0) {
4154         /* then multiply */
4155         if ((err = mp_mul (&res, &M[1], &res)) != MP_OKAY) {
4156           goto __RES;
4157         }
4158         if ((err = mp_reduce (&res, P, &mu)) != MP_OKAY) {
4159           goto __RES;
4160         }
4161       }
4162     }
4163   }
4164
4165   mp_exch (&res, Y);
4166   err = MP_OKAY;
4167 __RES:mp_clear (&res);
4168 __MU:mp_clear (&mu);
4169 __M:
4170   mp_clear(&M[1]);
4171   for (x = 1<<(winsize-1); x < (1 << winsize); x++) {
4172     mp_clear (&M[x]);
4173   }
4174   return err;
4175 }
4176
4177 /* multiplies |a| * |b| and only computes up to digs digits of result
4178  * HAC pp. 595, Algorithm 14.12  Modified so you can control how 
4179  * many digits of output are created.
4180  */
4181 static int
4182 s_mp_mul_digs (const mp_int * a, const mp_int * b, mp_int * c, int digs)
4183 {
4184   mp_int  t;
4185   int     res, pa, pb, ix, iy;
4186   mp_digit u;
4187   mp_word r;
4188   mp_digit tmpx, *tmpt, *tmpy;
4189
4190   /* can we use the fast multiplier? */
4191   if (((digs) < MP_WARRAY) &&
4192       MIN (a->used, b->used) < 
4193           (1 << ((CHAR_BIT * sizeof (mp_word)) - (2 * DIGIT_BIT)))) {
4194     return fast_s_mp_mul_digs (a, b, c, digs);
4195   }
4196
4197   if ((res = mp_init_size (&t, digs)) != MP_OKAY) {
4198     return res;
4199   }
4200   t.used = digs;
4201
4202   /* compute the digits of the product directly */
4203   pa = a->used;
4204   for (ix = 0; ix < pa; ix++) {
4205     /* set the carry to zero */
4206     u = 0;
4207
4208     /* limit ourselves to making digs digits of output */
4209     pb = MIN (b->used, digs - ix);
4210
4211     /* setup some aliases */
4212     /* copy of the digit from a used within the nested loop */
4213     tmpx = a->dp[ix];
4214     
4215     /* an alias for the destination shifted ix places */
4216     tmpt = t.dp + ix;
4217     
4218     /* an alias for the digits of b */
4219     tmpy = b->dp;
4220
4221     /* compute the columns of the output and propagate the carry */
4222     for (iy = 0; iy < pb; iy++) {
4223       /* compute the column as a mp_word */
4224       r       = ((mp_word)*tmpt) +
4225                 ((mp_word)tmpx) * ((mp_word)*tmpy++) +
4226                 ((mp_word) u);
4227
4228       /* the new column is the lower part of the result */
4229       *tmpt++ = (mp_digit) (r & ((mp_word) MP_MASK));
4230
4231       /* get the carry word from the result */
4232       u       = (mp_digit) (r >> ((mp_word) DIGIT_BIT));
4233     }
4234     /* set carry if it is placed below digs */
4235     if (ix + iy < digs) {
4236       *tmpt = u;
4237     }
4238   }
4239
4240   mp_clamp (&t);
4241   mp_exch (&t, c);
4242
4243   mp_clear (&t);
4244   return MP_OKAY;
4245 }
4246
4247 /* multiplies |a| * |b| and does not compute the lower digs digits
4248  * [meant to get the higher part of the product]
4249  */
4250 static int
4251 s_mp_mul_high_digs (const mp_int * a, const mp_int * b, mp_int * c, int digs)
4252 {
4253   mp_int  t;
4254   int     res, pa, pb, ix, iy;
4255   mp_digit u;
4256   mp_word r;
4257   mp_digit tmpx, *tmpt, *tmpy;
4258
4259   /* can we use the fast multiplier? */
4260   if (((a->used + b->used + 1) < MP_WARRAY)
4261       && MIN (a->used, b->used) < (1 << ((CHAR_BIT * sizeof (mp_word)) - (2 * DIGIT_BIT)))) {
4262     return fast_s_mp_mul_high_digs (a, b, c, digs);
4263   }
4264
4265   if ((res = mp_init_size (&t, a->used + b->used + 1)) != MP_OKAY) {
4266     return res;
4267   }
4268   t.used = a->used + b->used + 1;
4269
4270   pa = a->used;
4271   pb = b->used;
4272   for (ix = 0; ix < pa; ix++) {
4273     /* clear the carry */
4274     u = 0;
4275
4276     /* left hand side of A[ix] * B[iy] */
4277     tmpx = a->dp[ix];
4278
4279     /* alias to the address of where the digits will be stored */
4280     tmpt = &(t.dp[digs]);
4281
4282     /* alias for where to read the right hand side from */
4283     tmpy = b->dp + (digs - ix);
4284
4285     for (iy = digs - ix; iy < pb; iy++) {
4286       /* calculate the double precision result */
4287       r       = ((mp_word)*tmpt) +
4288                 ((mp_word)tmpx) * ((mp_word)*tmpy++) +
4289                 ((mp_word) u);
4290
4291       /* get the lower part */
4292       *tmpt++ = (mp_digit) (r & ((mp_word) MP_MASK));
4293
4294       /* carry the carry */
4295       u       = (mp_digit) (r >> ((mp_word) DIGIT_BIT));
4296     }
4297     *tmpt = u;
4298   }
4299   mp_clamp (&t);
4300   mp_exch (&t, c);
4301   mp_clear (&t);
4302   return MP_OKAY;
4303 }
4304
4305 /* low level squaring, b = a*a, HAC pp.596-597, Algorithm 14.16 */
4306 static int
4307 s_mp_sqr (const mp_int * a, mp_int * b)
4308 {
4309   mp_int  t;
4310   int     res, ix, iy, pa;
4311   mp_word r;
4312   mp_digit u, tmpx, *tmpt;
4313
4314   pa = a->used;
4315   if ((res = mp_init_size (&t, 2*pa + 1)) != MP_OKAY) {
4316     return res;
4317   }
4318
4319   /* default used is maximum possible size */
4320   t.used = 2*pa + 1;
4321
4322   for (ix = 0; ix < pa; ix++) {
4323     /* first calculate the digit at 2*ix */
4324     /* calculate double precision result */
4325     r = ((mp_word) t.dp[2*ix]) +
4326         ((mp_word)a->dp[ix])*((mp_word)a->dp[ix]);
4327
4328     /* store lower part in result */
4329     t.dp[ix+ix] = (mp_digit) (r & ((mp_word) MP_MASK));
4330
4331     /* get the carry */
4332     u           = (mp_digit)(r >> ((mp_word) DIGIT_BIT));
4333
4334     /* left hand side of A[ix] * A[iy] */
4335     tmpx        = a->dp[ix];
4336
4337     /* alias for where to store the results */
4338     tmpt        = t.dp + (2*ix + 1);
4339     
4340     for (iy = ix + 1; iy < pa; iy++) {
4341       /* first calculate the product */
4342       r       = ((mp_word)tmpx) * ((mp_word)a->dp[iy]);
4343
4344       /* now calculate the double precision result, note we use
4345        * addition instead of *2 since it's easier to optimize
4346        */
4347       r       = ((mp_word) *tmpt) + r + r + ((mp_word) u);
4348
4349       /* store lower part */
4350       *tmpt++ = (mp_digit) (r & ((mp_word) MP_MASK));
4351
4352       /* get carry */
4353       u       = (mp_digit)(r >> ((mp_word) DIGIT_BIT));
4354     }
4355     /* propagate upwards */
4356     while (u != 0) {
4357       r       = ((mp_word) *tmpt) + ((mp_word) u);
4358       *tmpt++ = (mp_digit) (r & ((mp_word) MP_MASK));
4359       u       = (mp_digit)(r >> ((mp_word) DIGIT_BIT));
4360     }
4361   }
4362
4363   mp_clamp (&t);
4364   mp_exch (&t, b);
4365   mp_clear (&t);
4366   return MP_OKAY;
4367 }
4368
4369 /* low level subtraction (assumes |a| > |b|), HAC pp.595 Algorithm 14.9 */
4370 int
4371 s_mp_sub (const mp_int * a, const mp_int * b, mp_int * c)
4372 {
4373   int     olduse, res, min, max;
4374
4375   /* find sizes */
4376   min = b->used;
4377   max = a->used;
4378
4379   /* init result */
4380   if (c->alloc < max) {
4381     if ((res = mp_grow (c, max)) != MP_OKAY) {
4382       return res;
4383     }
4384   }
4385   olduse = c->used;
4386   c->used = max;
4387
4388   {
4389     register mp_digit u, *tmpa, *tmpb, *tmpc;
4390     register int i;
4391
4392     /* alias for digit pointers */
4393     tmpa = a->dp;
4394     tmpb = b->dp;
4395     tmpc = c->dp;
4396
4397     /* set carry to zero */
4398     u = 0;
4399     for (i = 0; i < min; i++) {
4400       /* T[i] = A[i] - B[i] - U */
4401       *tmpc = *tmpa++ - *tmpb++ - u;
4402
4403       /* U = carry bit of T[i]
4404        * Note this saves performing an AND operation since
4405        * if a carry does occur it will propagate all the way to the
4406        * MSB.  As a result a single shift is enough to get the carry
4407        */
4408       u = *tmpc >> ((mp_digit)(CHAR_BIT * sizeof (mp_digit) - 1));
4409
4410       /* Clear carry from T[i] */
4411       *tmpc++ &= MP_MASK;
4412     }
4413
4414     /* now copy higher words if any, e.g. if A has more digits than B  */
4415     for (; i < max; i++) {
4416       /* T[i] = A[i] - U */
4417       *tmpc = *tmpa++ - u;
4418
4419       /* U = carry bit of T[i] */
4420       u = *tmpc >> ((mp_digit)(CHAR_BIT * sizeof (mp_digit) - 1));
4421
4422       /* Clear carry from T[i] */
4423       *tmpc++ &= MP_MASK;
4424     }
4425
4426     /* clear digits above used (since we may not have grown result above) */
4427     for (i = c->used; i < olduse; i++) {
4428       *tmpc++ = 0;
4429     }
4430   }
4431
4432   mp_clamp (c);
4433   return MP_OKAY;
4434 }