ddraw: Get rid of ddcomimpl.h.
[wine] / dlls / rsaenh / mpi.c
1 /*
2  * dlls/rsaenh/mpi.c
3  * Multi Precision Integer functions
4  *
5  * Copyright 2004 Michael Jung
6  * Based on public domain code by Tom St Denis (tomstdenis@iahu.ca)
7  *
8  * This library is free software; you can redistribute it and/or
9  * modify it under the terms of the GNU Lesser General Public
10  * License as published by the Free Software Foundation; either
11  * version 2.1 of the License, or (at your option) any later version.
12  *
13  * This library is distributed in the hope that it will be useful,
14  * but WITHOUT ANY WARRANTY; without even the implied warranty of
15  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
16  * Lesser General Public License for more details.
17  *
18  * You should have received a copy of the GNU Lesser General Public
19  * License along with this library; if not, write to the Free Software
20  * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
21  */
22
23 /*
24  * This file contains code from the LibTomCrypt cryptographic 
25  * library written by Tom St Denis (tomstdenis@iahu.ca). LibTomCrypt
26  * is in the public domain. The code in this file is tailored to
27  * special requirements. Take a look at http://libtomcrypt.org for the
28  * original version. 
29  */
30
31 #include <stdarg.h>
32 #include "tomcrypt.h"
33
34 /* Known optimal configurations
35  CPU                    /Compiler     /MUL CUTOFF/SQR CUTOFF
36 -------------------------------------------------------------
37  Intel P4 Northwood     /GCC v3.4.1   /        88/       128/LTM 0.32 ;-)
38 */
39 static const int KARATSUBA_MUL_CUTOFF = 88,  /* Min. number of digits before Karatsuba multiplication is used. */
40                  KARATSUBA_SQR_CUTOFF = 128; /* Min. number of digits before Karatsuba squaring is used. */
41
42 static void bn_reverse(unsigned char *s, int len);
43 static int s_mp_add(mp_int *a, mp_int *b, mp_int *c);
44 static int s_mp_exptmod (const mp_int * G, const mp_int * X, mp_int * P, mp_int * Y);
45 #define s_mp_mul(a, b, c) s_mp_mul_digs(a, b, c, (a)->used + (b)->used + 1)
46 static int s_mp_mul_digs(const mp_int *a, const mp_int *b, mp_int *c, int digs);
47 static int s_mp_mul_high_digs(const mp_int *a, const mp_int *b, mp_int *c, int digs);
48 static int s_mp_sqr(const mp_int *a, mp_int *b);
49 static int s_mp_sub(const mp_int *a, const mp_int *b, mp_int *c);
50 static int mp_exptmod_fast(const mp_int *G, const mp_int *X, mp_int *P, mp_int *Y, int mode);
51 static int mp_invmod_slow (const mp_int * a, mp_int * b, mp_int * c);
52 static int mp_karatsuba_mul(const mp_int *a, const mp_int *b, mp_int *c);
53 static int mp_karatsuba_sqr(const mp_int *a, mp_int *b);
54
55 /* grow as required */
56 static int mp_grow (mp_int * a, int size)
57 {
58   int     i;
59   mp_digit *tmp;
60
61   /* if the alloc size is smaller alloc more ram */
62   if (a->alloc < size) {
63     /* ensure there are always at least MP_PREC digits extra on top */
64     size += (MP_PREC * 2) - (size % MP_PREC);
65
66     /* reallocate the array a->dp
67      *
68      * We store the return in a temporary variable
69      * in case the operation failed we don't want
70      * to overwrite the dp member of a.
71      */
72     tmp = realloc (a->dp, sizeof (mp_digit) * size);
73     if (tmp == NULL) {
74       /* reallocation failed but "a" is still valid [can be freed] */
75       return MP_MEM;
76     }
77
78     /* reallocation succeeded so set a->dp */
79     a->dp = tmp;
80
81     /* zero excess digits */
82     i        = a->alloc;
83     a->alloc = size;
84     for (; i < a->alloc; i++) {
85       a->dp[i] = 0;
86     }
87   }
88   return MP_OKAY;
89 }
90
91 /* b = a/2 */
92 static int mp_div_2(const mp_int * a, mp_int * b)
93 {
94   int     x, res, oldused;
95
96   /* copy */
97   if (b->alloc < a->used) {
98     if ((res = mp_grow (b, a->used)) != MP_OKAY) {
99       return res;
100     }
101   }
102
103   oldused = b->used;
104   b->used = a->used;
105   {
106     register mp_digit r, rr, *tmpa, *tmpb;
107
108     /* source alias */
109     tmpa = a->dp + b->used - 1;
110
111     /* dest alias */
112     tmpb = b->dp + b->used - 1;
113
114     /* carry */
115     r = 0;
116     for (x = b->used - 1; x >= 0; x--) {
117       /* get the carry for the next iteration */
118       rr = *tmpa & 1;
119
120       /* shift the current digit, add in carry and store */
121       *tmpb-- = (*tmpa-- >> 1) | (r << (DIGIT_BIT - 1));
122
123       /* forward carry to next iteration */
124       r = rr;
125     }
126
127     /* zero excess digits */
128     tmpb = b->dp + b->used;
129     for (x = b->used; x < oldused; x++) {
130       *tmpb++ = 0;
131     }
132   }
133   b->sign = a->sign;
134   mp_clamp (b);
135   return MP_OKAY;
136 }
137
138 /* swap the elements of two integers, for cases where you can't simply swap the
139  * mp_int pointers around
140  */
141 static void
142 mp_exch (mp_int * a, mp_int * b)
143 {
144   mp_int  t;
145
146   t  = *a;
147   *a = *b;
148   *b = t;
149 }
150
151 /* init a new mp_int */
152 static int mp_init (mp_int * a)
153 {
154   int i;
155
156   /* allocate memory required and clear it */
157   a->dp = malloc (sizeof (mp_digit) * MP_PREC);
158   if (a->dp == NULL) {
159     return MP_MEM;
160   }
161
162   /* set the digits to zero */
163   for (i = 0; i < MP_PREC; i++) {
164       a->dp[i] = 0;
165   }
166
167   /* set the used to zero, allocated digits to the default precision
168    * and sign to positive */
169   a->used  = 0;
170   a->alloc = MP_PREC;
171   a->sign  = MP_ZPOS;
172
173   return MP_OKAY;
174 }
175
176 /* init an mp_init for a given size */
177 static int mp_init_size (mp_int * a, int size)
178 {
179   int x;
180
181   /* pad size so there are always extra digits */
182   size += (MP_PREC * 2) - (size % MP_PREC);
183
184   /* alloc mem */
185   a->dp = malloc (sizeof (mp_digit) * size);
186   if (a->dp == NULL) {
187     return MP_MEM;
188   }
189
190   /* set the members */
191   a->used  = 0;
192   a->alloc = size;
193   a->sign  = MP_ZPOS;
194
195   /* zero the digits */
196   for (x = 0; x < size; x++) {
197       a->dp[x] = 0;
198   }
199
200   return MP_OKAY;
201 }
202
203 /* clear one (frees)  */
204 static void
205 mp_clear (mp_int * a)
206 {
207   int i;
208
209   /* only do anything if a hasn't been freed previously */
210   if (a->dp != NULL) {
211     /* first zero the digits */
212     for (i = 0; i < a->used; i++) {
213         a->dp[i] = 0;
214     }
215
216     /* free ram */
217     free(a->dp);
218
219     /* reset members to make debugging easier */
220     a->dp    = NULL;
221     a->alloc = a->used = 0;
222     a->sign  = MP_ZPOS;
223   }
224 }
225
226 /* set to zero */
227 static void
228 mp_zero (mp_int * a)
229 {
230   a->sign = MP_ZPOS;
231   a->used = 0;
232   memset (a->dp, 0, sizeof (mp_digit) * a->alloc);
233 }
234
235 /* computes the modular inverse via binary extended euclidean algorithm, 
236  * that is c = 1/a mod b 
237  *
238  * Based on slow invmod except this is optimized for the case where b is 
239  * odd as per HAC Note 14.64 on pp. 610
240  */
241 static int
242 fast_mp_invmod (const mp_int * a, mp_int * b, mp_int * c)
243 {
244   mp_int  x, y, u, v, B, D;
245   int     res, neg;
246
247   /* 2. [modified] b must be odd   */
248   if (mp_iseven (b) == 1) {
249     return MP_VAL;
250   }
251
252   /* init all our temps */
253   if ((res = mp_init_multi(&x, &y, &u, &v, &B, &D, NULL)) != MP_OKAY) {
254      return res;
255   }
256
257   /* x == modulus, y == value to invert */
258   if ((res = mp_copy (b, &x)) != MP_OKAY) {
259     goto __ERR;
260   }
261
262   /* we need y = |a| */
263   if ((res = mp_abs (a, &y)) != MP_OKAY) {
264     goto __ERR;
265   }
266
267   /* 3. u=x, v=y, A=1, B=0, C=0,D=1 */
268   if ((res = mp_copy (&x, &u)) != MP_OKAY) {
269     goto __ERR;
270   }
271   if ((res = mp_copy (&y, &v)) != MP_OKAY) {
272     goto __ERR;
273   }
274   mp_set (&D, 1);
275
276 top:
277   /* 4.  while u is even do */
278   while (mp_iseven (&u) == 1) {
279     /* 4.1 u = u/2 */
280     if ((res = mp_div_2 (&u, &u)) != MP_OKAY) {
281       goto __ERR;
282     }
283     /* 4.2 if B is odd then */
284     if (mp_isodd (&B) == 1) {
285       if ((res = mp_sub (&B, &x, &B)) != MP_OKAY) {
286         goto __ERR;
287       }
288     }
289     /* B = B/2 */
290     if ((res = mp_div_2 (&B, &B)) != MP_OKAY) {
291       goto __ERR;
292     }
293   }
294
295   /* 5.  while v is even do */
296   while (mp_iseven (&v) == 1) {
297     /* 5.1 v = v/2 */
298     if ((res = mp_div_2 (&v, &v)) != MP_OKAY) {
299       goto __ERR;
300     }
301     /* 5.2 if D is odd then */
302     if (mp_isodd (&D) == 1) {
303       /* D = (D-x)/2 */
304       if ((res = mp_sub (&D, &x, &D)) != MP_OKAY) {
305         goto __ERR;
306       }
307     }
308     /* D = D/2 */
309     if ((res = mp_div_2 (&D, &D)) != MP_OKAY) {
310       goto __ERR;
311     }
312   }
313
314   /* 6.  if u >= v then */
315   if (mp_cmp (&u, &v) != MP_LT) {
316     /* u = u - v, B = B - D */
317     if ((res = mp_sub (&u, &v, &u)) != MP_OKAY) {
318       goto __ERR;
319     }
320
321     if ((res = mp_sub (&B, &D, &B)) != MP_OKAY) {
322       goto __ERR;
323     }
324   } else {
325     /* v - v - u, D = D - B */
326     if ((res = mp_sub (&v, &u, &v)) != MP_OKAY) {
327       goto __ERR;
328     }
329
330     if ((res = mp_sub (&D, &B, &D)) != MP_OKAY) {
331       goto __ERR;
332     }
333   }
334
335   /* if not zero goto step 4 */
336   if (mp_iszero (&u) == 0) {
337     goto top;
338   }
339
340   /* now a = C, b = D, gcd == g*v */
341
342   /* if v != 1 then there is no inverse */
343   if (mp_cmp_d (&v, 1) != MP_EQ) {
344     res = MP_VAL;
345     goto __ERR;
346   }
347
348   /* b is now the inverse */
349   neg = a->sign;
350   while (D.sign == MP_NEG) {
351     if ((res = mp_add (&D, b, &D)) != MP_OKAY) {
352       goto __ERR;
353     }
354   }
355   mp_exch (&D, c);
356   c->sign = neg;
357   res = MP_OKAY;
358
359 __ERR:mp_clear_multi (&x, &y, &u, &v, &B, &D, NULL);
360   return res;
361 }
362
363 /* computes xR**-1 == x (mod N) via Montgomery Reduction
364  *
365  * This is an optimized implementation of montgomery_reduce
366  * which uses the comba method to quickly calculate the columns of the
367  * reduction.
368  *
369  * Based on Algorithm 14.32 on pp.601 of HAC.
370 */
371 static int
372 fast_mp_montgomery_reduce (mp_int * x, const mp_int * n, mp_digit rho)
373 {
374   int     ix, res, olduse;
375   mp_word W[MP_WARRAY];
376
377   /* get old used count */
378   olduse = x->used;
379
380   /* grow a as required */
381   if (x->alloc < n->used + 1) {
382     if ((res = mp_grow (x, n->used + 1)) != MP_OKAY) {
383       return res;
384     }
385   }
386
387   /* first we have to get the digits of the input into
388    * an array of double precision words W[...]
389    */
390   {
391     register mp_word *_W;
392     register mp_digit *tmpx;
393
394     /* alias for the W[] array */
395     _W   = W;
396
397     /* alias for the digits of  x*/
398     tmpx = x->dp;
399
400     /* copy the digits of a into W[0..a->used-1] */
401     for (ix = 0; ix < x->used; ix++) {
402       *_W++ = *tmpx++;
403     }
404
405     /* zero the high words of W[a->used..m->used*2] */
406     for (; ix < n->used * 2 + 1; ix++) {
407       *_W++ = 0;
408     }
409   }
410
411   /* now we proceed to zero successive digits
412    * from the least significant upwards
413    */
414   for (ix = 0; ix < n->used; ix++) {
415     /* mu = ai * m' mod b
416      *
417      * We avoid a double precision multiplication (which isn't required)
418      * by casting the value down to a mp_digit.  Note this requires
419      * that W[ix-1] have  the carry cleared (see after the inner loop)
420      */
421     register mp_digit mu;
422     mu = (mp_digit) (((W[ix] & MP_MASK) * rho) & MP_MASK);
423
424     /* a = a + mu * m * b**i
425      *
426      * This is computed in place and on the fly.  The multiplication
427      * by b**i is handled by offsetting which columns the results
428      * are added to.
429      *
430      * Note the comba method normally doesn't handle carries in the
431      * inner loop In this case we fix the carry from the previous
432      * column since the Montgomery reduction requires digits of the
433      * result (so far) [see above] to work.  This is
434      * handled by fixing up one carry after the inner loop.  The
435      * carry fixups are done in order so after these loops the
436      * first m->used words of W[] have the carries fixed
437      */
438     {
439       register int iy;
440       register mp_digit *tmpn;
441       register mp_word *_W;
442
443       /* alias for the digits of the modulus */
444       tmpn = n->dp;
445
446       /* Alias for the columns set by an offset of ix */
447       _W = W + ix;
448
449       /* inner loop */
450       for (iy = 0; iy < n->used; iy++) {
451           *_W++ += ((mp_word)mu) * ((mp_word)*tmpn++);
452       }
453     }
454
455     /* now fix carry for next digit, W[ix+1] */
456     W[ix + 1] += W[ix] >> ((mp_word) DIGIT_BIT);
457   }
458
459   /* now we have to propagate the carries and
460    * shift the words downward [all those least
461    * significant digits we zeroed].
462    */
463   {
464     register mp_digit *tmpx;
465     register mp_word *_W, *_W1;
466
467     /* nox fix rest of carries */
468
469     /* alias for current word */
470     _W1 = W + ix;
471
472     /* alias for next word, where the carry goes */
473     _W = W + ++ix;
474
475     for (; ix <= n->used * 2 + 1; ix++) {
476       *_W++ += *_W1++ >> ((mp_word) DIGIT_BIT);
477     }
478
479     /* copy out, A = A/b**n
480      *
481      * The result is A/b**n but instead of converting from an
482      * array of mp_word to mp_digit than calling mp_rshd
483      * we just copy them in the right order
484      */
485
486     /* alias for destination word */
487     tmpx = x->dp;
488
489     /* alias for shifted double precision result */
490     _W = W + n->used;
491
492     for (ix = 0; ix < n->used + 1; ix++) {
493       *tmpx++ = (mp_digit)(*_W++ & ((mp_word) MP_MASK));
494     }
495
496     /* zero oldused digits, if the input a was larger than
497      * m->used+1 we'll have to clear the digits
498      */
499     for (; ix < olduse; ix++) {
500       *tmpx++ = 0;
501     }
502   }
503
504   /* set the max used and clamp */
505   x->used = n->used + 1;
506   mp_clamp (x);
507
508   /* if A >= m then A = A - m */
509   if (mp_cmp_mag (x, n) != MP_LT) {
510     return s_mp_sub (x, n, x);
511   }
512   return MP_OKAY;
513 }
514
515 /* Fast (comba) multiplier
516  *
517  * This is the fast column-array [comba] multiplier.  It is 
518  * designed to compute the columns of the product first 
519  * then handle the carries afterwards.  This has the effect 
520  * of making the nested loops that compute the columns very
521  * simple and schedulable on super-scalar processors.
522  *
523  * This has been modified to produce a variable number of 
524  * digits of output so if say only a half-product is required 
525  * you don't have to compute the upper half (a feature 
526  * required for fast Barrett reduction).
527  *
528  * Based on Algorithm 14.12 on pp.595 of HAC.
529  *
530  */
531 static int
532 fast_s_mp_mul_digs (const mp_int * a, const mp_int * b, mp_int * c, int digs)
533 {
534   int     olduse, res, pa, ix, iz;
535   mp_digit W[MP_WARRAY];
536   register mp_word  _W;
537
538   /* grow the destination as required */
539   if (c->alloc < digs) {
540     if ((res = mp_grow (c, digs)) != MP_OKAY) {
541       return res;
542     }
543   }
544
545   /* number of output digits to produce */
546   pa = MIN(digs, a->used + b->used);
547
548   /* clear the carry */
549   _W = 0;
550   for (ix = 0; ix <= pa; ix++) { 
551       int      tx, ty;
552       int      iy;
553       mp_digit *tmpx, *tmpy;
554
555       /* get offsets into the two bignums */
556       ty = MIN(b->used-1, ix);
557       tx = ix - ty;
558
559       /* setup temp aliases */
560       tmpx = a->dp + tx;
561       tmpy = b->dp + ty;
562
563       /* This is the number of times the loop will iterate, essentially it's
564          while (tx++ < a->used && ty-- >= 0) { ... }
565        */
566       iy = MIN(a->used-tx, ty+1);
567
568       /* execute loop */
569       for (iz = 0; iz < iy; ++iz) {
570          _W += ((mp_word)*tmpx++)*((mp_word)*tmpy--);
571       }
572
573       /* store term */
574       W[ix] = ((mp_digit)_W) & MP_MASK;
575
576       /* make next carry */
577       _W = _W >> ((mp_word)DIGIT_BIT);
578   }
579
580   /* setup dest */
581   olduse  = c->used;
582   c->used = digs;
583
584   {
585     register mp_digit *tmpc;
586     tmpc = c->dp;
587     for (ix = 0; ix < digs; ix++) {
588       /* now extract the previous digit [below the carry] */
589       *tmpc++ = W[ix];
590     }
591
592     /* clear unused digits [that existed in the old copy of c] */
593     for (; ix < olduse; ix++) {
594       *tmpc++ = 0;
595     }
596   }
597   mp_clamp (c);
598   return MP_OKAY;
599 }
600
601 /* this is a modified version of fast_s_mul_digs that only produces
602  * output digits *above* digs.  See the comments for fast_s_mul_digs
603  * to see how it works.
604  *
605  * This is used in the Barrett reduction since for one of the multiplications
606  * only the higher digits were needed.  This essentially halves the work.
607  *
608  * Based on Algorithm 14.12 on pp.595 of HAC.
609  */
610 static int
611 fast_s_mp_mul_high_digs (const mp_int * a, const mp_int * b, mp_int * c, int digs)
612 {
613   int     olduse, res, pa, ix, iz;
614   mp_digit W[MP_WARRAY];
615   mp_word  _W;
616
617   /* grow the destination as required */
618   pa = a->used + b->used;
619   if (c->alloc < pa) {
620     if ((res = mp_grow (c, pa)) != MP_OKAY) {
621       return res;
622     }
623   }
624
625   /* number of output digits to produce */
626   pa = a->used + b->used;
627   _W = 0;
628   for (ix = digs; ix <= pa; ix++) { 
629       int      tx, ty, iy;
630       mp_digit *tmpx, *tmpy;
631
632       /* get offsets into the two bignums */
633       ty = MIN(b->used-1, ix);
634       tx = ix - ty;
635
636       /* setup temp aliases */
637       tmpx = a->dp + tx;
638       tmpy = b->dp + ty;
639
640       /* This is the number of times the loop will iterate, essentially it's
641          while (tx++ < a->used && ty-- >= 0) { ... }
642        */
643       iy = MIN(a->used-tx, ty+1);
644
645       /* execute loop */
646       for (iz = 0; iz < iy; iz++) {
647          _W += ((mp_word)*tmpx++)*((mp_word)*tmpy--);
648       }
649
650       /* store term */
651       W[ix] = ((mp_digit)_W) & MP_MASK;
652
653       /* make next carry */
654       _W = _W >> ((mp_word)DIGIT_BIT);
655   }
656
657   /* setup dest */
658   olduse  = c->used;
659   c->used = pa;
660
661   {
662     register mp_digit *tmpc;
663
664     tmpc = c->dp + digs;
665     for (ix = digs; ix <= pa; ix++) {
666       /* now extract the previous digit [below the carry] */
667       *tmpc++ = W[ix];
668     }
669
670     /* clear unused digits [that existed in the old copy of c] */
671     for (; ix < olduse; ix++) {
672       *tmpc++ = 0;
673     }
674   }
675   mp_clamp (c);
676   return MP_OKAY;
677 }
678
679 /* fast squaring
680  *
681  * This is the comba method where the columns of the product
682  * are computed first then the carries are computed.  This
683  * has the effect of making a very simple inner loop that
684  * is executed the most
685  *
686  * W2 represents the outer products and W the inner.
687  *
688  * A further optimizations is made because the inner
689  * products are of the form "A * B * 2".  The *2 part does
690  * not need to be computed until the end which is good
691  * because 64-bit shifts are slow!
692  *
693  * Based on Algorithm 14.16 on pp.597 of HAC.
694  *
695  */
696 /* the jist of squaring...
697
698 you do like mult except the offset of the tmpx [one that starts closer to zero]
699 can't equal the offset of tmpy.  So basically you set up iy like before then you min it with
700 (ty-tx) so that it never happens.  You double all those you add in the inner loop
701
702 After that loop you do the squares and add them in.
703
704 Remove W2 and don't memset W
705
706 */
707
708 static int fast_s_mp_sqr (const mp_int * a, mp_int * b)
709 {
710   int       olduse, res, pa, ix, iz;
711   mp_digit   W[MP_WARRAY], *tmpx;
712   mp_word   W1;
713
714   /* grow the destination as required */
715   pa = a->used + a->used;
716   if (b->alloc < pa) {
717     if ((res = mp_grow (b, pa)) != MP_OKAY) {
718       return res;
719     }
720   }
721
722   /* number of output digits to produce */
723   W1 = 0;
724   for (ix = 0; ix <= pa; ix++) { 
725       int      tx, ty, iy;
726       mp_word  _W;
727       mp_digit *tmpy;
728
729       /* clear counter */
730       _W = 0;
731
732       /* get offsets into the two bignums */
733       ty = MIN(a->used-1, ix);
734       tx = ix - ty;
735
736       /* setup temp aliases */
737       tmpx = a->dp + tx;
738       tmpy = a->dp + ty;
739
740       /* This is the number of times the loop will iterate, essentially it's
741          while (tx++ < a->used && ty-- >= 0) { ... }
742        */
743       iy = MIN(a->used-tx, ty+1);
744
745       /* now for squaring tx can never equal ty 
746        * we halve the distance since they approach at a rate of 2x
747        * and we have to round because odd cases need to be executed
748        */
749       iy = MIN(iy, (ty-tx+1)>>1);
750
751       /* execute loop */
752       for (iz = 0; iz < iy; iz++) {
753          _W += ((mp_word)*tmpx++)*((mp_word)*tmpy--);
754       }
755
756       /* double the inner product and add carry */
757       _W = _W + _W + W1;
758
759       /* even columns have the square term in them */
760       if ((ix&1) == 0) {
761          _W += ((mp_word)a->dp[ix>>1])*((mp_word)a->dp[ix>>1]);
762       }
763
764       /* store it */
765       W[ix] = _W;
766
767       /* make next carry */
768       W1 = _W >> ((mp_word)DIGIT_BIT);
769   }
770
771   /* setup dest */
772   olduse  = b->used;
773   b->used = a->used+a->used;
774
775   {
776     mp_digit *tmpb;
777     tmpb = b->dp;
778     for (ix = 0; ix < pa; ix++) {
779       *tmpb++ = W[ix] & MP_MASK;
780     }
781
782     /* clear unused digits [that existed in the old copy of c] */
783     for (; ix < olduse; ix++) {
784       *tmpb++ = 0;
785     }
786   }
787   mp_clamp (b);
788   return MP_OKAY;
789 }
790
791 /* computes a = 2**b 
792  *
793  * Simple algorithm which zeroes the int, grows it then just sets one bit
794  * as required.
795  */
796 int
797 mp_2expt (mp_int * a, int b)
798 {
799   int     res;
800
801   /* zero a as per default */
802   mp_zero (a);
803
804   /* grow a to accommodate the single bit */
805   if ((res = mp_grow (a, b / DIGIT_BIT + 1)) != MP_OKAY) {
806     return res;
807   }
808
809   /* set the used count of where the bit will go */
810   a->used = b / DIGIT_BIT + 1;
811
812   /* put the single bit in its place */
813   a->dp[b / DIGIT_BIT] = ((mp_digit)1) << (b % DIGIT_BIT);
814
815   return MP_OKAY;
816 }
817
818 /* b = |a| 
819  *
820  * Simple function copies the input and fixes the sign to positive
821  */
822 int
823 mp_abs (const mp_int * a, mp_int * b)
824 {
825   int     res;
826
827   /* copy a to b */
828   if (a != b) {
829      if ((res = mp_copy (a, b)) != MP_OKAY) {
830        return res;
831      }
832   }
833
834   /* force the sign of b to positive */
835   b->sign = MP_ZPOS;
836
837   return MP_OKAY;
838 }
839
840 /* high level addition (handles signs) */
841 int mp_add (mp_int * a, mp_int * b, mp_int * c)
842 {
843   int     sa, sb, res;
844
845   /* get sign of both inputs */
846   sa = a->sign;
847   sb = b->sign;
848
849   /* handle two cases, not four */
850   if (sa == sb) {
851     /* both positive or both negative */
852     /* add their magnitudes, copy the sign */
853     c->sign = sa;
854     res = s_mp_add (a, b, c);
855   } else {
856     /* one positive, the other negative */
857     /* subtract the one with the greater magnitude from */
858     /* the one of the lesser magnitude.  The result gets */
859     /* the sign of the one with the greater magnitude. */
860     if (mp_cmp_mag (a, b) == MP_LT) {
861       c->sign = sb;
862       res = s_mp_sub (b, a, c);
863     } else {
864       c->sign = sa;
865       res = s_mp_sub (a, b, c);
866     }
867   }
868   return res;
869 }
870
871
872 /* single digit addition */
873 int
874 mp_add_d (mp_int * a, mp_digit b, mp_int * c)
875 {
876   int     res, ix, oldused;
877   mp_digit *tmpa, *tmpc, mu;
878
879   /* grow c as required */
880   if (c->alloc < a->used + 1) {
881      if ((res = mp_grow(c, a->used + 1)) != MP_OKAY) {
882         return res;
883      }
884   }
885
886   /* if a is negative and |a| >= b, call c = |a| - b */
887   if (a->sign == MP_NEG && (a->used > 1 || a->dp[0] >= b)) {
888      /* temporarily fix sign of a */
889      a->sign = MP_ZPOS;
890
891      /* c = |a| - b */
892      res = mp_sub_d(a, b, c);
893
894      /* fix sign  */
895      a->sign = c->sign = MP_NEG;
896
897      return res;
898   }
899
900   /* old number of used digits in c */
901   oldused = c->used;
902
903   /* sign always positive */
904   c->sign = MP_ZPOS;
905
906   /* source alias */
907   tmpa    = a->dp;
908
909   /* destination alias */
910   tmpc    = c->dp;
911
912   /* if a is positive */
913   if (a->sign == MP_ZPOS) {
914      /* add digit, after this we're propagating
915       * the carry.
916       */
917      *tmpc   = *tmpa++ + b;
918      mu      = *tmpc >> DIGIT_BIT;
919      *tmpc++ &= MP_MASK;
920
921      /* now handle rest of the digits */
922      for (ix = 1; ix < a->used; ix++) {
923         *tmpc   = *tmpa++ + mu;
924         mu      = *tmpc >> DIGIT_BIT;
925         *tmpc++ &= MP_MASK;
926      }
927      /* set final carry */
928      ix++;
929      *tmpc++  = mu;
930
931      /* setup size */
932      c->used = a->used + 1;
933   } else {
934      /* a was negative and |a| < b */
935      c->used  = 1;
936
937      /* the result is a single digit */
938      if (a->used == 1) {
939         *tmpc++  =  b - a->dp[0];
940      } else {
941         *tmpc++  =  b;
942      }
943
944      /* setup count so the clearing of oldused
945       * can fall through correctly
946       */
947      ix       = 1;
948   }
949
950   /* now zero to oldused */
951   while (ix++ < oldused) {
952      *tmpc++ = 0;
953   }
954   mp_clamp(c);
955
956   return MP_OKAY;
957 }
958
959 /* trim unused digits 
960  *
961  * This is used to ensure that leading zero digits are
962  * trimed and the leading "used" digit will be non-zero
963  * Typically very fast.  Also fixes the sign if there
964  * are no more leading digits
965  */
966 void
967 mp_clamp (mp_int * a)
968 {
969   /* decrease used while the most significant digit is
970    * zero.
971    */
972   while (a->used > 0 && a->dp[a->used - 1] == 0) {
973     --(a->used);
974   }
975
976   /* reset the sign flag if used == 0 */
977   if (a->used == 0) {
978     a->sign = MP_ZPOS;
979   }
980 }
981
982 void mp_clear_multi(mp_int *mp, ...) 
983 {
984     mp_int* next_mp = mp;
985     va_list args;
986     va_start(args, mp);
987     while (next_mp != NULL) {
988         mp_clear(next_mp);
989         next_mp = va_arg(args, mp_int*);
990     }
991     va_end(args);
992 }
993
994 /* compare two ints (signed)*/
995 int
996 mp_cmp (const mp_int * a, const mp_int * b)
997 {
998   /* compare based on sign */
999   if (a->sign != b->sign) {
1000      if (a->sign == MP_NEG) {
1001         return MP_LT;
1002      } else {
1003         return MP_GT;
1004      }
1005   }
1006   
1007   /* compare digits */
1008   if (a->sign == MP_NEG) {
1009      /* if negative compare opposite direction */
1010      return mp_cmp_mag(b, a);
1011   } else {
1012      return mp_cmp_mag(a, b);
1013   }
1014 }
1015
1016 /* compare a digit */
1017 int mp_cmp_d(const mp_int * a, mp_digit b)
1018 {
1019   /* compare based on sign */
1020   if (a->sign == MP_NEG) {
1021     return MP_LT;
1022   }
1023
1024   /* compare based on magnitude */
1025   if (a->used > 1) {
1026     return MP_GT;
1027   }
1028
1029   /* compare the only digit of a to b */
1030   if (a->dp[0] > b) {
1031     return MP_GT;
1032   } else if (a->dp[0] < b) {
1033     return MP_LT;
1034   } else {
1035     return MP_EQ;
1036   }
1037 }
1038
1039 /* compare maginitude of two ints (unsigned) */
1040 int mp_cmp_mag (const mp_int * a, const mp_int * b)
1041 {
1042   int     n;
1043   mp_digit *tmpa, *tmpb;
1044
1045   /* compare based on # of non-zero digits */
1046   if (a->used > b->used) {
1047     return MP_GT;
1048   }
1049   
1050   if (a->used < b->used) {
1051     return MP_LT;
1052   }
1053
1054   /* alias for a */
1055   tmpa = a->dp + (a->used - 1);
1056
1057   /* alias for b */
1058   tmpb = b->dp + (a->used - 1);
1059
1060   /* compare based on digits  */
1061   for (n = 0; n < a->used; ++n, --tmpa, --tmpb) {
1062     if (*tmpa > *tmpb) {
1063       return MP_GT;
1064     }
1065
1066     if (*tmpa < *tmpb) {
1067       return MP_LT;
1068     }
1069   }
1070   return MP_EQ;
1071 }
1072
1073 static const int lnz[16] = { 
1074    4, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0
1075 };
1076
1077 /* Counts the number of lsbs which are zero before the first zero bit */
1078 int mp_cnt_lsb(const mp_int *a)
1079 {
1080    int x;
1081    mp_digit q, qq;
1082
1083    /* easy out */
1084    if (mp_iszero(a) == 1) {
1085       return 0;
1086    }
1087
1088    /* scan lower digits until non-zero */
1089    for (x = 0; x < a->used && a->dp[x] == 0; x++);
1090    q = a->dp[x];
1091    x *= DIGIT_BIT;
1092
1093    /* now scan this digit until a 1 is found */
1094    if ((q & 1) == 0) {
1095       do {
1096          qq  = q & 15;
1097          x  += lnz[qq];
1098          q >>= 4;
1099       } while (qq == 0);
1100    }
1101    return x;
1102 }
1103
1104 /* copy, b = a */
1105 int
1106 mp_copy (const mp_int * a, mp_int * b)
1107 {
1108   int     res, n;
1109
1110   /* if dst == src do nothing */
1111   if (a == b) {
1112     return MP_OKAY;
1113   }
1114
1115   /* grow dest */
1116   if (b->alloc < a->used) {
1117      if ((res = mp_grow (b, a->used)) != MP_OKAY) {
1118         return res;
1119      }
1120   }
1121
1122   /* zero b and copy the parameters over */
1123   {
1124     register mp_digit *tmpa, *tmpb;
1125
1126     /* pointer aliases */
1127
1128     /* source */
1129     tmpa = a->dp;
1130
1131     /* destination */
1132     tmpb = b->dp;
1133
1134     /* copy all the digits */
1135     for (n = 0; n < a->used; n++) {
1136       *tmpb++ = *tmpa++;
1137     }
1138
1139     /* clear high digits */
1140     for (; n < b->used; n++) {
1141       *tmpb++ = 0;
1142     }
1143   }
1144
1145   /* copy used count and sign */
1146   b->used = a->used;
1147   b->sign = a->sign;
1148   return MP_OKAY;
1149 }
1150
1151 /* returns the number of bits in an int */
1152 int
1153 mp_count_bits (const mp_int * a)
1154 {
1155   int     r;
1156   mp_digit q;
1157
1158   /* shortcut */
1159   if (a->used == 0) {
1160     return 0;
1161   }
1162
1163   /* get number of digits and add that */
1164   r = (a->used - 1) * DIGIT_BIT;
1165   
1166   /* take the last digit and count the bits in it */
1167   q = a->dp[a->used - 1];
1168   while (q > 0) {
1169     ++r;
1170     q >>= ((mp_digit) 1);
1171   }
1172   return r;
1173 }
1174
1175 /* calc a value mod 2**b */
1176 static int
1177 mp_mod_2d (const mp_int * a, int b, mp_int * c)
1178 {
1179   int     x, res;
1180
1181   /* if b is <= 0 then zero the int */
1182   if (b <= 0) {
1183     mp_zero (c);
1184     return MP_OKAY;
1185   }
1186
1187   /* if the modulus is larger than the value than return */
1188   if (b > a->used * DIGIT_BIT) {
1189     res = mp_copy (a, c);
1190     return res;
1191   }
1192
1193   /* copy */
1194   if ((res = mp_copy (a, c)) != MP_OKAY) {
1195     return res;
1196   }
1197
1198   /* zero digits above the last digit of the modulus */
1199   for (x = (b / DIGIT_BIT) + ((b % DIGIT_BIT) == 0 ? 0 : 1); x < c->used; x++) {
1200     c->dp[x] = 0;
1201   }
1202   /* clear the digit that is not completely outside/inside the modulus */
1203   c->dp[b / DIGIT_BIT] &= (1 << ((mp_digit)b % DIGIT_BIT)) - 1;
1204   mp_clamp (c);
1205   return MP_OKAY;
1206 }
1207
1208 /* shift right by a certain bit count (store quotient in c, optional remainder in d) */
1209 static int mp_div_2d (const mp_int * a, int b, mp_int * c, mp_int * d)
1210 {
1211   mp_digit D, r, rr;
1212   int     x, res;
1213   mp_int  t;
1214
1215
1216   /* if the shift count is <= 0 then we do no work */
1217   if (b <= 0) {
1218     res = mp_copy (a, c);
1219     if (d != NULL) {
1220       mp_zero (d);
1221     }
1222     return res;
1223   }
1224
1225   if ((res = mp_init (&t)) != MP_OKAY) {
1226     return res;
1227   }
1228
1229   /* get the remainder */
1230   if (d != NULL) {
1231     if ((res = mp_mod_2d (a, b, &t)) != MP_OKAY) {
1232       mp_clear (&t);
1233       return res;
1234     }
1235   }
1236
1237   /* copy */
1238   if ((res = mp_copy (a, c)) != MP_OKAY) {
1239     mp_clear (&t);
1240     return res;
1241   }
1242
1243   /* shift by as many digits in the bit count */
1244   if (b >= DIGIT_BIT) {
1245     mp_rshd (c, b / DIGIT_BIT);
1246   }
1247
1248   /* shift any bit count < DIGIT_BIT */
1249   D = (mp_digit) (b % DIGIT_BIT);
1250   if (D != 0) {
1251     register mp_digit *tmpc, mask, shift;
1252
1253     /* mask */
1254     mask = (((mp_digit)1) << D) - 1;
1255
1256     /* shift for lsb */
1257     shift = DIGIT_BIT - D;
1258
1259     /* alias */
1260     tmpc = c->dp + (c->used - 1);
1261
1262     /* carry */
1263     r = 0;
1264     for (x = c->used - 1; x >= 0; x--) {
1265       /* get the lower  bits of this word in a temp */
1266       rr = *tmpc & mask;
1267
1268       /* shift the current word and mix in the carry bits from the previous word */
1269       *tmpc = (*tmpc >> D) | (r << shift);
1270       --tmpc;
1271
1272       /* set the carry to the carry bits of the current word found above */
1273       r = rr;
1274     }
1275   }
1276   mp_clamp (c);
1277   if (d != NULL) {
1278     mp_exch (&t, d);
1279   }
1280   mp_clear (&t);
1281   return MP_OKAY;
1282 }
1283
1284 /* shift left a certain amount of digits */
1285 static int mp_lshd (mp_int * a, int b)
1286 {
1287   int     x, res;
1288
1289   /* if its less than zero return */
1290   if (b <= 0) {
1291     return MP_OKAY;
1292   }
1293
1294   /* grow to fit the new digits */
1295   if (a->alloc < a->used + b) {
1296      if ((res = mp_grow (a, a->used + b)) != MP_OKAY) {
1297        return res;
1298      }
1299   }
1300
1301   {
1302     register mp_digit *top, *bottom;
1303
1304     /* increment the used by the shift amount then copy upwards */
1305     a->used += b;
1306
1307     /* top */
1308     top = a->dp + a->used - 1;
1309
1310     /* base */
1311     bottom = a->dp + a->used - 1 - b;
1312
1313     /* much like mp_rshd this is implemented using a sliding window
1314      * except the window goes the otherway around.  Copying from
1315      * the bottom to the top.  see bn_mp_rshd.c for more info.
1316      */
1317     for (x = a->used - 1; x >= b; x--) {
1318       *top-- = *bottom--;
1319     }
1320
1321     /* zero the lower digits */
1322     top = a->dp;
1323     for (x = 0; x < b; x++) {
1324       *top++ = 0;
1325     }
1326   }
1327   return MP_OKAY;
1328 }
1329
1330 /* shift left by a certain bit count */
1331 static int mp_mul_2d (const mp_int * a, int b, mp_int * c)
1332 {
1333   mp_digit d;
1334   int      res;
1335
1336   /* copy */
1337   if (a != c) {
1338      if ((res = mp_copy (a, c)) != MP_OKAY) {
1339        return res;
1340      }
1341   }
1342
1343   if (c->alloc < c->used + b/DIGIT_BIT + 1) {
1344      if ((res = mp_grow (c, c->used + b / DIGIT_BIT + 1)) != MP_OKAY) {
1345        return res;
1346      }
1347   }
1348
1349   /* shift by as many digits in the bit count */
1350   if (b >= DIGIT_BIT) {
1351     if ((res = mp_lshd (c, b / DIGIT_BIT)) != MP_OKAY) {
1352       return res;
1353     }
1354   }
1355
1356   /* shift any bit count < DIGIT_BIT */
1357   d = (mp_digit) (b % DIGIT_BIT);
1358   if (d != 0) {
1359     register mp_digit *tmpc, shift, mask, r, rr;
1360     register int x;
1361
1362     /* bitmask for carries */
1363     mask = (((mp_digit)1) << d) - 1;
1364
1365     /* shift for msbs */
1366     shift = DIGIT_BIT - d;
1367
1368     /* alias */
1369     tmpc = c->dp;
1370
1371     /* carry */
1372     r    = 0;
1373     for (x = 0; x < c->used; x++) {
1374       /* get the higher bits of the current word */
1375       rr = (*tmpc >> shift) & mask;
1376
1377       /* shift the current word and OR in the carry */
1378       *tmpc = ((*tmpc << d) | r) & MP_MASK;
1379       ++tmpc;
1380
1381       /* set the carry to the carry bits of the current word */
1382       r = rr;
1383     }
1384
1385     /* set final carry */
1386     if (r != 0) {
1387        c->dp[(c->used)++] = r;
1388     }
1389   }
1390   mp_clamp (c);
1391   return MP_OKAY;
1392 }
1393
1394 /* multiply by a digit */
1395 static int
1396 mp_mul_d (const mp_int * a, mp_digit b, mp_int * c)
1397 {
1398   mp_digit u, *tmpa, *tmpc;
1399   mp_word  r;
1400   int      ix, res, olduse;
1401
1402   /* make sure c is big enough to hold a*b */
1403   if (c->alloc < a->used + 1) {
1404     if ((res = mp_grow (c, a->used + 1)) != MP_OKAY) {
1405       return res;
1406     }
1407   }
1408
1409   /* get the original destinations used count */
1410   olduse = c->used;
1411
1412   /* set the sign */
1413   c->sign = a->sign;
1414
1415   /* alias for a->dp [source] */
1416   tmpa = a->dp;
1417
1418   /* alias for c->dp [dest] */
1419   tmpc = c->dp;
1420
1421   /* zero carry */
1422   u = 0;
1423
1424   /* compute columns */
1425   for (ix = 0; ix < a->used; ix++) {
1426     /* compute product and carry sum for this term */
1427     r       = ((mp_word) u) + ((mp_word)*tmpa++) * ((mp_word)b);
1428
1429     /* mask off higher bits to get a single digit */
1430     *tmpc++ = (mp_digit) (r & ((mp_word) MP_MASK));
1431
1432     /* send carry into next iteration */
1433     u       = (mp_digit) (r >> ((mp_word) DIGIT_BIT));
1434   }
1435
1436   /* store final carry [if any] */
1437   *tmpc++ = u;
1438
1439   /* now zero digits above the top */
1440   while (ix++ < olduse) {
1441      *tmpc++ = 0;
1442   }
1443
1444   /* set used count */
1445   c->used = a->used + 1;
1446   mp_clamp(c);
1447
1448   return MP_OKAY;
1449 }
1450
1451 /* integer signed division. 
1452  * c*b + d == a [e.g. a/b, c=quotient, d=remainder]
1453  * HAC pp.598 Algorithm 14.20
1454  *
1455  * Note that the description in HAC is horribly 
1456  * incomplete.  For example, it doesn't consider 
1457  * the case where digits are removed from 'x' in 
1458  * the inner loop.  It also doesn't consider the 
1459  * case that y has fewer than three digits, etc..
1460  *
1461  * The overall algorithm is as described as 
1462  * 14.20 from HAC but fixed to treat these cases.
1463 */
1464 static int mp_div (const mp_int * a, const mp_int * b, mp_int * c, mp_int * d)
1465 {
1466   mp_int  q, x, y, t1, t2;
1467   int     res, n, t, i, norm, neg;
1468
1469   /* is divisor zero ? */
1470   if (mp_iszero (b) == 1) {
1471     return MP_VAL;
1472   }
1473
1474   /* if a < b then q=0, r = a */
1475   if (mp_cmp_mag (a, b) == MP_LT) {
1476     if (d != NULL) {
1477       res = mp_copy (a, d);
1478     } else {
1479       res = MP_OKAY;
1480     }
1481     if (c != NULL) {
1482       mp_zero (c);
1483     }
1484     return res;
1485   }
1486
1487   if ((res = mp_init_size (&q, a->used + 2)) != MP_OKAY) {
1488     return res;
1489   }
1490   q.used = a->used + 2;
1491
1492   if ((res = mp_init (&t1)) != MP_OKAY) {
1493     goto __Q;
1494   }
1495
1496   if ((res = mp_init (&t2)) != MP_OKAY) {
1497     goto __T1;
1498   }
1499
1500   if ((res = mp_init_copy (&x, a)) != MP_OKAY) {
1501     goto __T2;
1502   }
1503
1504   if ((res = mp_init_copy (&y, b)) != MP_OKAY) {
1505     goto __X;
1506   }
1507
1508   /* fix the sign */
1509   neg = (a->sign == b->sign) ? MP_ZPOS : MP_NEG;
1510   x.sign = y.sign = MP_ZPOS;
1511
1512   /* normalize both x and y, ensure that y >= b/2, [b == 2**DIGIT_BIT] */
1513   norm = mp_count_bits(&y) % DIGIT_BIT;
1514   if (norm < DIGIT_BIT-1) {
1515      norm = (DIGIT_BIT-1) - norm;
1516      if ((res = mp_mul_2d (&x, norm, &x)) != MP_OKAY) {
1517        goto __Y;
1518      }
1519      if ((res = mp_mul_2d (&y, norm, &y)) != MP_OKAY) {
1520        goto __Y;
1521      }
1522   } else {
1523      norm = 0;
1524   }
1525
1526   /* note hac does 0 based, so if used==5 then its 0,1,2,3,4, e.g. use 4 */
1527   n = x.used - 1;
1528   t = y.used - 1;
1529
1530   /* while (x >= y*b**n-t) do { q[n-t] += 1; x -= y*b**{n-t} } */
1531   if ((res = mp_lshd (&y, n - t)) != MP_OKAY) { /* y = y*b**{n-t} */
1532     goto __Y;
1533   }
1534
1535   while (mp_cmp (&x, &y) != MP_LT) {
1536     ++(q.dp[n - t]);
1537     if ((res = mp_sub (&x, &y, &x)) != MP_OKAY) {
1538       goto __Y;
1539     }
1540   }
1541
1542   /* reset y by shifting it back down */
1543   mp_rshd (&y, n - t);
1544
1545   /* step 3. for i from n down to (t + 1) */
1546   for (i = n; i >= (t + 1); i--) {
1547     if (i > x.used) {
1548       continue;
1549     }
1550
1551     /* step 3.1 if xi == yt then set q{i-t-1} to b-1, 
1552      * otherwise set q{i-t-1} to (xi*b + x{i-1})/yt */
1553     if (x.dp[i] == y.dp[t]) {
1554       q.dp[i - t - 1] = ((((mp_digit)1) << DIGIT_BIT) - 1);
1555     } else {
1556       mp_word tmp;
1557       tmp = ((mp_word) x.dp[i]) << ((mp_word) DIGIT_BIT);
1558       tmp |= ((mp_word) x.dp[i - 1]);
1559       tmp /= ((mp_word) y.dp[t]);
1560       if (tmp > (mp_word) MP_MASK)
1561         tmp = MP_MASK;
1562       q.dp[i - t - 1] = (mp_digit) (tmp & (mp_word) (MP_MASK));
1563     }
1564
1565     /* while (q{i-t-1} * (yt * b + y{t-1})) > 
1566              xi * b**2 + xi-1 * b + xi-2 
1567      
1568        do q{i-t-1} -= 1; 
1569     */
1570     q.dp[i - t - 1] = (q.dp[i - t - 1] + 1) & MP_MASK;
1571     do {
1572       q.dp[i - t - 1] = (q.dp[i - t - 1] - 1) & MP_MASK;
1573
1574       /* find left hand */
1575       mp_zero (&t1);
1576       t1.dp[0] = (t - 1 < 0) ? 0 : y.dp[t - 1];
1577       t1.dp[1] = y.dp[t];
1578       t1.used = 2;
1579       if ((res = mp_mul_d (&t1, q.dp[i - t - 1], &t1)) != MP_OKAY) {
1580         goto __Y;
1581       }
1582
1583       /* find right hand */
1584       t2.dp[0] = (i - 2 < 0) ? 0 : x.dp[i - 2];
1585       t2.dp[1] = (i - 1 < 0) ? 0 : x.dp[i - 1];
1586       t2.dp[2] = x.dp[i];
1587       t2.used = 3;
1588     } while (mp_cmp_mag(&t1, &t2) == MP_GT);
1589
1590     /* step 3.3 x = x - q{i-t-1} * y * b**{i-t-1} */
1591     if ((res = mp_mul_d (&y, q.dp[i - t - 1], &t1)) != MP_OKAY) {
1592       goto __Y;
1593     }
1594
1595     if ((res = mp_lshd (&t1, i - t - 1)) != MP_OKAY) {
1596       goto __Y;
1597     }
1598
1599     if ((res = mp_sub (&x, &t1, &x)) != MP_OKAY) {
1600       goto __Y;
1601     }
1602
1603     /* if x < 0 then { x = x + y*b**{i-t-1}; q{i-t-1} -= 1; } */
1604     if (x.sign == MP_NEG) {
1605       if ((res = mp_copy (&y, &t1)) != MP_OKAY) {
1606         goto __Y;
1607       }
1608       if ((res = mp_lshd (&t1, i - t - 1)) != MP_OKAY) {
1609         goto __Y;
1610       }
1611       if ((res = mp_add (&x, &t1, &x)) != MP_OKAY) {
1612         goto __Y;
1613       }
1614
1615       q.dp[i - t - 1] = (q.dp[i - t - 1] - 1UL) & MP_MASK;
1616     }
1617   }
1618
1619   /* now q is the quotient and x is the remainder 
1620    * [which we have to normalize] 
1621    */
1622   
1623   /* get sign before writing to c */
1624   x.sign = x.used == 0 ? MP_ZPOS : a->sign;
1625
1626   if (c != NULL) {
1627     mp_clamp (&q);
1628     mp_exch (&q, c);
1629     c->sign = neg;
1630   }
1631
1632   if (d != NULL) {
1633     mp_div_2d (&x, norm, &x, NULL);
1634     mp_exch (&x, d);
1635   }
1636
1637   res = MP_OKAY;
1638
1639 __Y:mp_clear (&y);
1640 __X:mp_clear (&x);
1641 __T2:mp_clear (&t2);
1642 __T1:mp_clear (&t1);
1643 __Q:mp_clear (&q);
1644   return res;
1645 }
1646
1647 static int s_is_power_of_two(mp_digit b, int *p)
1648 {
1649    int x;
1650
1651    for (x = 1; x < DIGIT_BIT; x++) {
1652       if (b == (((mp_digit)1)<<x)) {
1653          *p = x;
1654          return 1;
1655       }
1656    }
1657    return 0;
1658 }
1659
1660 /* single digit division (based on routine from MPI) */
1661 static int mp_div_d (const mp_int * a, mp_digit b, mp_int * c, mp_digit * d)
1662 {
1663   mp_int  q;
1664   mp_word w;
1665   mp_digit t;
1666   int     res, ix;
1667
1668   /* cannot divide by zero */
1669   if (b == 0) {
1670      return MP_VAL;
1671   }
1672
1673   /* quick outs */
1674   if (b == 1 || mp_iszero(a) == 1) {
1675      if (d != NULL) {
1676         *d = 0;
1677      }
1678      if (c != NULL) {
1679         return mp_copy(a, c);
1680      }
1681      return MP_OKAY;
1682   }
1683
1684   /* power of two ? */
1685   if (s_is_power_of_two(b, &ix) == 1) {
1686      if (d != NULL) {
1687         *d = a->dp[0] & ((((mp_digit)1)<<ix) - 1);
1688      }
1689      if (c != NULL) {
1690         return mp_div_2d(a, ix, c, NULL);
1691      }
1692      return MP_OKAY;
1693   }
1694
1695   /* no easy answer [c'est la vie].  Just division */
1696   if ((res = mp_init_size(&q, a->used)) != MP_OKAY) {
1697      return res;
1698   }
1699   
1700   q.used = a->used;
1701   q.sign = a->sign;
1702   w = 0;
1703   for (ix = a->used - 1; ix >= 0; ix--) {
1704      w = (w << ((mp_word)DIGIT_BIT)) | ((mp_word)a->dp[ix]);
1705      
1706      if (w >= b) {
1707         t = (mp_digit)(w / b);
1708         w -= ((mp_word)t) * ((mp_word)b);
1709       } else {
1710         t = 0;
1711       }
1712       q.dp[ix] = t;
1713   }
1714
1715   if (d != NULL) {
1716      *d = (mp_digit)w;
1717   }
1718   
1719   if (c != NULL) {
1720      mp_clamp(&q);
1721      mp_exch(&q, c);
1722   }
1723   mp_clear(&q);
1724   
1725   return res;
1726 }
1727
1728 /* reduce "x" in place modulo "n" using the Diminished Radix algorithm.
1729  *
1730  * Based on algorithm from the paper
1731  *
1732  * "Generating Efficient Primes for Discrete Log Cryptosystems"
1733  *                 Chae Hoon Lim, Pil Loong Lee,
1734  *          POSTECH Information Research Laboratories
1735  *
1736  * The modulus must be of a special format [see manual]
1737  *
1738  * Has been modified to use algorithm 7.10 from the LTM book instead
1739  *
1740  * Input x must be in the range 0 <= x <= (n-1)**2
1741  */
1742 static int
1743 mp_dr_reduce (mp_int * x, const mp_int * n, mp_digit k)
1744 {
1745   int      err, i, m;
1746   mp_word  r;
1747   mp_digit mu, *tmpx1, *tmpx2;
1748
1749   /* m = digits in modulus */
1750   m = n->used;
1751
1752   /* ensure that "x" has at least 2m digits */
1753   if (x->alloc < m + m) {
1754     if ((err = mp_grow (x, m + m)) != MP_OKAY) {
1755       return err;
1756     }
1757   }
1758
1759 /* top of loop, this is where the code resumes if
1760  * another reduction pass is required.
1761  */
1762 top:
1763   /* aliases for digits */
1764   /* alias for lower half of x */
1765   tmpx1 = x->dp;
1766
1767   /* alias for upper half of x, or x/B**m */
1768   tmpx2 = x->dp + m;
1769
1770   /* set carry to zero */
1771   mu = 0;
1772
1773   /* compute (x mod B**m) + k * [x/B**m] inline and inplace */
1774   for (i = 0; i < m; i++) {
1775       r         = ((mp_word)*tmpx2++) * ((mp_word)k) + *tmpx1 + mu;
1776       *tmpx1++  = (mp_digit)(r & MP_MASK);
1777       mu        = (mp_digit)(r >> ((mp_word)DIGIT_BIT));
1778   }
1779
1780   /* set final carry */
1781   *tmpx1++ = mu;
1782
1783   /* zero words above m */
1784   for (i = m + 1; i < x->used; i++) {
1785       *tmpx1++ = 0;
1786   }
1787
1788   /* clamp, sub and return */
1789   mp_clamp (x);
1790
1791   /* if x >= n then subtract and reduce again
1792    * Each successive "recursion" makes the input smaller and smaller.
1793    */
1794   if (mp_cmp_mag (x, n) != MP_LT) {
1795     s_mp_sub(x, n, x);
1796     goto top;
1797   }
1798   return MP_OKAY;
1799 }
1800
1801 /* sets the value of "d" required for mp_dr_reduce */
1802 static void mp_dr_setup(const mp_int *a, mp_digit *d)
1803 {
1804    /* the casts are required if DIGIT_BIT is one less than
1805     * the number of bits in a mp_digit [e.g. DIGIT_BIT==31]
1806     */
1807    *d = (mp_digit)((((mp_word)1) << ((mp_word)DIGIT_BIT)) - 
1808         ((mp_word)a->dp[0]));
1809 }
1810
1811 /* this is a shell function that calls either the normal or Montgomery
1812  * exptmod functions.  Originally the call to the montgomery code was
1813  * embedded in the normal function but that wasted a lot of stack space
1814  * for nothing (since 99% of the time the Montgomery code would be called)
1815  */
1816 int mp_exptmod (const mp_int * G, const mp_int * X, mp_int * P, mp_int * Y)
1817 {
1818   int dr;
1819
1820   /* modulus P must be positive */
1821   if (P->sign == MP_NEG) {
1822      return MP_VAL;
1823   }
1824
1825   /* if exponent X is negative we have to recurse */
1826   if (X->sign == MP_NEG) {
1827      mp_int tmpG, tmpX;
1828      int err;
1829
1830      /* first compute 1/G mod P */
1831      if ((err = mp_init(&tmpG)) != MP_OKAY) {
1832         return err;
1833      }
1834      if ((err = mp_invmod(G, P, &tmpG)) != MP_OKAY) {
1835         mp_clear(&tmpG);
1836         return err;
1837      }
1838
1839      /* now get |X| */
1840      if ((err = mp_init(&tmpX)) != MP_OKAY) {
1841         mp_clear(&tmpG);
1842         return err;
1843      }
1844      if ((err = mp_abs(X, &tmpX)) != MP_OKAY) {
1845         mp_clear_multi(&tmpG, &tmpX, NULL);
1846         return err;
1847      }
1848
1849      /* and now compute (1/G)**|X| instead of G**X [X < 0] */
1850      err = mp_exptmod(&tmpG, &tmpX, P, Y);
1851      mp_clear_multi(&tmpG, &tmpX, NULL);
1852      return err;
1853   }
1854
1855   dr = 0;
1856
1857   /* if the modulus is odd or dr != 0 use the fast method */
1858   if (mp_isodd (P) == 1 || dr !=  0) {
1859     return mp_exptmod_fast (G, X, P, Y, dr);
1860   } else {
1861     /* otherwise use the generic Barrett reduction technique */
1862     return s_mp_exptmod (G, X, P, Y);
1863   }
1864 }
1865
1866 /* computes Y == G**X mod P, HAC pp.616, Algorithm 14.85
1867  *
1868  * Uses a left-to-right k-ary sliding window to compute the modular exponentiation.
1869  * The value of k changes based on the size of the exponent.
1870  *
1871  * Uses Montgomery or Diminished Radix reduction [whichever appropriate]
1872  */
1873
1874 int
1875 mp_exptmod_fast (const mp_int * G, const mp_int * X, mp_int * P, mp_int * Y, int redmode)
1876 {
1877   mp_int  M[256], res;
1878   mp_digit buf, mp;
1879   int     err, bitbuf, bitcpy, bitcnt, mode, digidx, x, y, winsize;
1880
1881   /* use a pointer to the reduction algorithm.  This allows us to use
1882    * one of many reduction algorithms without modding the guts of
1883    * the code with if statements everywhere.
1884    */
1885   int     (*redux)(mp_int*,const mp_int*,mp_digit);
1886
1887   /* find window size */
1888   x = mp_count_bits (X);
1889   if (x <= 7) {
1890     winsize = 2;
1891   } else if (x <= 36) {
1892     winsize = 3;
1893   } else if (x <= 140) {
1894     winsize = 4;
1895   } else if (x <= 450) {
1896     winsize = 5;
1897   } else if (x <= 1303) {
1898     winsize = 6;
1899   } else if (x <= 3529) {
1900     winsize = 7;
1901   } else {
1902     winsize = 8;
1903   }
1904
1905   /* init M array */
1906   /* init first cell */
1907   if ((err = mp_init(&M[1])) != MP_OKAY) {
1908      return err;
1909   }
1910
1911   /* now init the second half of the array */
1912   for (x = 1<<(winsize-1); x < (1 << winsize); x++) {
1913     if ((err = mp_init(&M[x])) != MP_OKAY) {
1914       for (y = 1<<(winsize-1); y < x; y++) {
1915         mp_clear (&M[y]);
1916       }
1917       mp_clear(&M[1]);
1918       return err;
1919     }
1920   }
1921
1922   /* determine and setup reduction code */
1923   if (redmode == 0) {
1924      /* now setup montgomery  */
1925      if ((err = mp_montgomery_setup (P, &mp)) != MP_OKAY) {
1926         goto __M;
1927      }
1928
1929      /* automatically pick the comba one if available (saves quite a few calls/ifs) */
1930      if (((P->used * 2 + 1) < MP_WARRAY) &&
1931           P->used < (1 << ((CHAR_BIT * sizeof (mp_word)) - (2 * DIGIT_BIT)))) {
1932         redux = fast_mp_montgomery_reduce;
1933      } else {
1934         /* use slower baseline Montgomery method */
1935         redux = mp_montgomery_reduce;
1936      }
1937   } else if (redmode == 1) {
1938      /* setup DR reduction for moduli of the form B**k - b */
1939      mp_dr_setup(P, &mp);
1940      redux = mp_dr_reduce;
1941   } else {
1942      /* setup DR reduction for moduli of the form 2**k - b */
1943      if ((err = mp_reduce_2k_setup(P, &mp)) != MP_OKAY) {
1944         goto __M;
1945      }
1946      redux = mp_reduce_2k;
1947   }
1948
1949   /* setup result */
1950   if ((err = mp_init (&res)) != MP_OKAY) {
1951     goto __M;
1952   }
1953
1954   /* create M table
1955    *
1956
1957    *
1958    * The first half of the table is not computed though accept for M[0] and M[1]
1959    */
1960
1961   if (redmode == 0) {
1962      /* now we need R mod m */
1963      if ((err = mp_montgomery_calc_normalization (&res, P)) != MP_OKAY) {
1964        goto __RES;
1965      }
1966
1967      /* now set M[1] to G * R mod m */
1968      if ((err = mp_mulmod (G, &res, P, &M[1])) != MP_OKAY) {
1969        goto __RES;
1970      }
1971   } else {
1972      mp_set(&res, 1);
1973      if ((err = mp_mod(G, P, &M[1])) != MP_OKAY) {
1974         goto __RES;
1975      }
1976   }
1977
1978   /* compute the value at M[1<<(winsize-1)] by squaring M[1] (winsize-1) times */
1979   if ((err = mp_copy (&M[1], &M[1 << (winsize - 1)])) != MP_OKAY) {
1980     goto __RES;
1981   }
1982
1983   for (x = 0; x < (winsize - 1); x++) {
1984     if ((err = mp_sqr (&M[1 << (winsize - 1)], &M[1 << (winsize - 1)])) != MP_OKAY) {
1985       goto __RES;
1986     }
1987     if ((err = redux (&M[1 << (winsize - 1)], P, mp)) != MP_OKAY) {
1988       goto __RES;
1989     }
1990   }
1991
1992   /* create upper table */
1993   for (x = (1 << (winsize - 1)) + 1; x < (1 << winsize); x++) {
1994     if ((err = mp_mul (&M[x - 1], &M[1], &M[x])) != MP_OKAY) {
1995       goto __RES;
1996     }
1997     if ((err = redux (&M[x], P, mp)) != MP_OKAY) {
1998       goto __RES;
1999     }
2000   }
2001
2002   /* set initial mode and bit cnt */
2003   mode   = 0;
2004   bitcnt = 1;
2005   buf    = 0;
2006   digidx = X->used - 1;
2007   bitcpy = 0;
2008   bitbuf = 0;
2009
2010   for (;;) {
2011     /* grab next digit as required */
2012     if (--bitcnt == 0) {
2013       /* if digidx == -1 we are out of digits so break */
2014       if (digidx == -1) {
2015         break;
2016       }
2017       /* read next digit and reset bitcnt */
2018       buf    = X->dp[digidx--];
2019       bitcnt = DIGIT_BIT;
2020     }
2021
2022     /* grab the next msb from the exponent */
2023     y     = (buf >> (DIGIT_BIT - 1)) & 1;
2024     buf <<= (mp_digit)1;
2025
2026     /* if the bit is zero and mode == 0 then we ignore it
2027      * These represent the leading zero bits before the first 1 bit
2028      * in the exponent.  Technically this opt is not required but it
2029      * does lower the # of trivial squaring/reductions used
2030      */
2031     if (mode == 0 && y == 0) {
2032       continue;
2033     }
2034
2035     /* if the bit is zero and mode == 1 then we square */
2036     if (mode == 1 && y == 0) {
2037       if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
2038         goto __RES;
2039       }
2040       if ((err = redux (&res, P, mp)) != MP_OKAY) {
2041         goto __RES;
2042       }
2043       continue;
2044     }
2045
2046     /* else we add it to the window */
2047     bitbuf |= (y << (winsize - ++bitcpy));
2048     mode    = 2;
2049
2050     if (bitcpy == winsize) {
2051       /* ok window is filled so square as required and multiply  */
2052       /* square first */
2053       for (x = 0; x < winsize; x++) {
2054         if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
2055           goto __RES;
2056         }
2057         if ((err = redux (&res, P, mp)) != MP_OKAY) {
2058           goto __RES;
2059         }
2060       }
2061
2062       /* then multiply */
2063       if ((err = mp_mul (&res, &M[bitbuf], &res)) != MP_OKAY) {
2064         goto __RES;
2065       }
2066       if ((err = redux (&res, P, mp)) != MP_OKAY) {
2067         goto __RES;
2068       }
2069
2070       /* empty window and reset */
2071       bitcpy = 0;
2072       bitbuf = 0;
2073       mode   = 1;
2074     }
2075   }
2076
2077   /* if bits remain then square/multiply */
2078   if (mode == 2 && bitcpy > 0) {
2079     /* square then multiply if the bit is set */
2080     for (x = 0; x < bitcpy; x++) {
2081       if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
2082         goto __RES;
2083       }
2084       if ((err = redux (&res, P, mp)) != MP_OKAY) {
2085         goto __RES;
2086       }
2087
2088       /* get next bit of the window */
2089       bitbuf <<= 1;
2090       if ((bitbuf & (1 << winsize)) != 0) {
2091         /* then multiply */
2092         if ((err = mp_mul (&res, &M[1], &res)) != MP_OKAY) {
2093           goto __RES;
2094         }
2095         if ((err = redux (&res, P, mp)) != MP_OKAY) {
2096           goto __RES;
2097         }
2098       }
2099     }
2100   }
2101
2102   if (redmode == 0) {
2103      /* fixup result if Montgomery reduction is used
2104       * recall that any value in a Montgomery system is
2105       * actually multiplied by R mod n.  So we have
2106       * to reduce one more time to cancel out the factor
2107       * of R.
2108       */
2109      if ((err = redux(&res, P, mp)) != MP_OKAY) {
2110        goto __RES;
2111      }
2112   }
2113
2114   /* swap res with Y */
2115   mp_exch (&res, Y);
2116   err = MP_OKAY;
2117 __RES:mp_clear (&res);
2118 __M:
2119   mp_clear(&M[1]);
2120   for (x = 1<<(winsize-1); x < (1 << winsize); x++) {
2121     mp_clear (&M[x]);
2122   }
2123   return err;
2124 }
2125
2126 /* Greatest Common Divisor using the binary method */
2127 int mp_gcd (const mp_int * a, const mp_int * b, mp_int * c)
2128 {
2129   mp_int  u, v;
2130   int     k, u_lsb, v_lsb, res;
2131
2132   /* either zero than gcd is the largest */
2133   if (mp_iszero (a) == 1 && mp_iszero (b) == 0) {
2134     return mp_abs (b, c);
2135   }
2136   if (mp_iszero (a) == 0 && mp_iszero (b) == 1) {
2137     return mp_abs (a, c);
2138   }
2139
2140   /* optimized.  At this point if a == 0 then
2141    * b must equal zero too
2142    */
2143   if (mp_iszero (a) == 1) {
2144     mp_zero(c);
2145     return MP_OKAY;
2146   }
2147
2148   /* get copies of a and b we can modify */
2149   if ((res = mp_init_copy (&u, a)) != MP_OKAY) {
2150     return res;
2151   }
2152
2153   if ((res = mp_init_copy (&v, b)) != MP_OKAY) {
2154     goto __U;
2155   }
2156
2157   /* must be positive for the remainder of the algorithm */
2158   u.sign = v.sign = MP_ZPOS;
2159
2160   /* B1.  Find the common power of two for u and v */
2161   u_lsb = mp_cnt_lsb(&u);
2162   v_lsb = mp_cnt_lsb(&v);
2163   k     = MIN(u_lsb, v_lsb);
2164
2165   if (k > 0) {
2166      /* divide the power of two out */
2167      if ((res = mp_div_2d(&u, k, &u, NULL)) != MP_OKAY) {
2168         goto __V;
2169      }
2170
2171      if ((res = mp_div_2d(&v, k, &v, NULL)) != MP_OKAY) {
2172         goto __V;
2173      }
2174   }
2175
2176   /* divide any remaining factors of two out */
2177   if (u_lsb != k) {
2178      if ((res = mp_div_2d(&u, u_lsb - k, &u, NULL)) != MP_OKAY) {
2179         goto __V;
2180      }
2181   }
2182
2183   if (v_lsb != k) {
2184      if ((res = mp_div_2d(&v, v_lsb - k, &v, NULL)) != MP_OKAY) {
2185         goto __V;
2186      }
2187   }
2188
2189   while (mp_iszero(&v) == 0) {
2190      /* make sure v is the largest */
2191      if (mp_cmp_mag(&u, &v) == MP_GT) {
2192         /* swap u and v to make sure v is >= u */
2193         mp_exch(&u, &v);
2194      }
2195      
2196      /* subtract smallest from largest */
2197      if ((res = s_mp_sub(&v, &u, &v)) != MP_OKAY) {
2198         goto __V;
2199      }
2200      
2201      /* Divide out all factors of two */
2202      if ((res = mp_div_2d(&v, mp_cnt_lsb(&v), &v, NULL)) != MP_OKAY) {
2203         goto __V;
2204      } 
2205   } 
2206
2207   /* multiply by 2**k which we divided out at the beginning */
2208   if ((res = mp_mul_2d (&u, k, c)) != MP_OKAY) {
2209      goto __V;
2210   }
2211   c->sign = MP_ZPOS;
2212   res = MP_OKAY;
2213 __V:mp_clear (&u);
2214 __U:mp_clear (&v);
2215   return res;
2216 }
2217
2218 /* get the lower 32-bits of an mp_int */
2219 unsigned long mp_get_int(const mp_int * a)
2220 {
2221   int i;
2222   unsigned long res;
2223
2224   if (a->used == 0) {
2225      return 0;
2226   }
2227
2228   /* get number of digits of the lsb we have to read */
2229   i = MIN(a->used,(int)((sizeof(unsigned long)*CHAR_BIT+DIGIT_BIT-1)/DIGIT_BIT))-1;
2230
2231   /* get most significant digit of result */
2232   res = DIGIT(a,i);
2233    
2234   while (--i >= 0) {
2235     res = (res << DIGIT_BIT) | DIGIT(a,i);
2236   }
2237
2238   /* force result to 32-bits always so it is consistent on non 32-bit platforms */
2239   return res & 0xFFFFFFFFUL;
2240 }
2241
2242 /* creates "a" then copies b into it */
2243 int mp_init_copy (mp_int * a, const mp_int * b)
2244 {
2245   int     res;
2246
2247   if ((res = mp_init (a)) != MP_OKAY) {
2248     return res;
2249   }
2250   return mp_copy (b, a);
2251 }
2252
2253 int mp_init_multi(mp_int *mp, ...) 
2254 {
2255     mp_err res = MP_OKAY;      /* Assume ok until proven otherwise */
2256     int n = 0;                 /* Number of ok inits */
2257     mp_int* cur_arg = mp;
2258     va_list args;
2259
2260     va_start(args, mp);        /* init args to next argument from caller */
2261     while (cur_arg != NULL) {
2262         if (mp_init(cur_arg) != MP_OKAY) {
2263             /* Oops - error! Back-track and mp_clear what we already
2264                succeeded in init-ing, then return error.
2265             */
2266             va_list clean_args;
2267             
2268             /* end the current list */
2269             va_end(args);
2270             
2271             /* now start cleaning up */            
2272             cur_arg = mp;
2273             va_start(clean_args, mp);
2274             while (n--) {
2275                 mp_clear(cur_arg);
2276                 cur_arg = va_arg(clean_args, mp_int*);
2277             }
2278             va_end(clean_args);
2279             res = MP_MEM;
2280             break;
2281         }
2282         n++;
2283         cur_arg = va_arg(args, mp_int*);
2284     }
2285     va_end(args);
2286     return res;                /* Assumed ok, if error flagged above. */
2287 }
2288
2289 /* hac 14.61, pp608 */
2290 int mp_invmod (const mp_int * a, mp_int * b, mp_int * c)
2291 {
2292   /* b cannot be negative */
2293   if (b->sign == MP_NEG || mp_iszero(b) == 1) {
2294     return MP_VAL;
2295   }
2296
2297   /* if the modulus is odd we can use a faster routine instead */
2298   if (mp_isodd (b) == 1) {
2299     return fast_mp_invmod (a, b, c);
2300   }
2301   
2302   return mp_invmod_slow(a, b, c);
2303 }
2304
2305 /* hac 14.61, pp608 */
2306 int mp_invmod_slow (const mp_int * a, mp_int * b, mp_int * c)
2307 {
2308   mp_int  x, y, u, v, A, B, C, D;
2309   int     res;
2310
2311   /* b cannot be negative */
2312   if (b->sign == MP_NEG || mp_iszero(b) == 1) {
2313     return MP_VAL;
2314   }
2315
2316   /* init temps */
2317   if ((res = mp_init_multi(&x, &y, &u, &v, 
2318                            &A, &B, &C, &D, NULL)) != MP_OKAY) {
2319      return res;
2320   }
2321
2322   /* x = a, y = b */
2323   if ((res = mp_copy (a, &x)) != MP_OKAY) {
2324     goto __ERR;
2325   }
2326   if ((res = mp_copy (b, &y)) != MP_OKAY) {
2327     goto __ERR;
2328   }
2329
2330   /* 2. [modified] if x,y are both even then return an error! */
2331   if (mp_iseven (&x) == 1 && mp_iseven (&y) == 1) {
2332     res = MP_VAL;
2333     goto __ERR;
2334   }
2335
2336   /* 3. u=x, v=y, A=1, B=0, C=0,D=1 */
2337   if ((res = mp_copy (&x, &u)) != MP_OKAY) {
2338     goto __ERR;
2339   }
2340   if ((res = mp_copy (&y, &v)) != MP_OKAY) {
2341     goto __ERR;
2342   }
2343   mp_set (&A, 1);
2344   mp_set (&D, 1);
2345
2346 top:
2347   /* 4.  while u is even do */
2348   while (mp_iseven (&u) == 1) {
2349     /* 4.1 u = u/2 */
2350     if ((res = mp_div_2 (&u, &u)) != MP_OKAY) {
2351       goto __ERR;
2352     }
2353     /* 4.2 if A or B is odd then */
2354     if (mp_isodd (&A) == 1 || mp_isodd (&B) == 1) {
2355       /* A = (A+y)/2, B = (B-x)/2 */
2356       if ((res = mp_add (&A, &y, &A)) != MP_OKAY) {
2357          goto __ERR;
2358       }
2359       if ((res = mp_sub (&B, &x, &B)) != MP_OKAY) {
2360          goto __ERR;
2361       }
2362     }
2363     /* A = A/2, B = B/2 */
2364     if ((res = mp_div_2 (&A, &A)) != MP_OKAY) {
2365       goto __ERR;
2366     }
2367     if ((res = mp_div_2 (&B, &B)) != MP_OKAY) {
2368       goto __ERR;
2369     }
2370   }
2371
2372   /* 5.  while v is even do */
2373   while (mp_iseven (&v) == 1) {
2374     /* 5.1 v = v/2 */
2375     if ((res = mp_div_2 (&v, &v)) != MP_OKAY) {
2376       goto __ERR;
2377     }
2378     /* 5.2 if C or D is odd then */
2379     if (mp_isodd (&C) == 1 || mp_isodd (&D) == 1) {
2380       /* C = (C+y)/2, D = (D-x)/2 */
2381       if ((res = mp_add (&C, &y, &C)) != MP_OKAY) {
2382          goto __ERR;
2383       }
2384       if ((res = mp_sub (&D, &x, &D)) != MP_OKAY) {
2385          goto __ERR;
2386       }
2387     }
2388     /* C = C/2, D = D/2 */
2389     if ((res = mp_div_2 (&C, &C)) != MP_OKAY) {
2390       goto __ERR;
2391     }
2392     if ((res = mp_div_2 (&D, &D)) != MP_OKAY) {
2393       goto __ERR;
2394     }
2395   }
2396
2397   /* 6.  if u >= v then */
2398   if (mp_cmp (&u, &v) != MP_LT) {
2399     /* u = u - v, A = A - C, B = B - D */
2400     if ((res = mp_sub (&u, &v, &u)) != MP_OKAY) {
2401       goto __ERR;
2402     }
2403
2404     if ((res = mp_sub (&A, &C, &A)) != MP_OKAY) {
2405       goto __ERR;
2406     }
2407
2408     if ((res = mp_sub (&B, &D, &B)) != MP_OKAY) {
2409       goto __ERR;
2410     }
2411   } else {
2412     /* v - v - u, C = C - A, D = D - B */
2413     if ((res = mp_sub (&v, &u, &v)) != MP_OKAY) {
2414       goto __ERR;
2415     }
2416
2417     if ((res = mp_sub (&C, &A, &C)) != MP_OKAY) {
2418       goto __ERR;
2419     }
2420
2421     if ((res = mp_sub (&D, &B, &D)) != MP_OKAY) {
2422       goto __ERR;
2423     }
2424   }
2425
2426   /* if not zero goto step 4 */
2427   if (mp_iszero (&u) == 0)
2428     goto top;
2429
2430   /* now a = C, b = D, gcd == g*v */
2431
2432   /* if v != 1 then there is no inverse */
2433   if (mp_cmp_d (&v, 1) != MP_EQ) {
2434     res = MP_VAL;
2435     goto __ERR;
2436   }
2437
2438   /* if its too low */
2439   while (mp_cmp_d(&C, 0) == MP_LT) {
2440       if ((res = mp_add(&C, b, &C)) != MP_OKAY) {
2441          goto __ERR;
2442       }
2443   }
2444   
2445   /* too big */
2446   while (mp_cmp_mag(&C, b) != MP_LT) {
2447       if ((res = mp_sub(&C, b, &C)) != MP_OKAY) {
2448          goto __ERR;
2449       }
2450   }
2451   
2452   /* C is now the inverse */
2453   mp_exch (&C, c);
2454   res = MP_OKAY;
2455 __ERR:mp_clear_multi (&x, &y, &u, &v, &A, &B, &C, &D, NULL);
2456   return res;
2457 }
2458
2459 /* c = |a| * |b| using Karatsuba Multiplication using 
2460  * three half size multiplications
2461  *
2462  * Let B represent the radix [e.g. 2**DIGIT_BIT] and 
2463  * let n represent half of the number of digits in 
2464  * the min(a,b)
2465  *
2466  * a = a1 * B**n + a0
2467  * b = b1 * B**n + b0
2468  *
2469  * Then, a * b => 
2470    a1b1 * B**2n + ((a1 - a0)(b1 - b0) + a0b0 + a1b1) * B + a0b0
2471  *
2472  * Note that a1b1 and a0b0 are used twice and only need to be 
2473  * computed once.  So in total three half size (half # of 
2474  * digit) multiplications are performed, a0b0, a1b1 and 
2475  * (a1-b1)(a0-b0)
2476  *
2477  * Note that a multiplication of half the digits requires
2478  * 1/4th the number of single precision multiplications so in 
2479  * total after one call 25% of the single precision multiplications 
2480  * are saved.  Note also that the call to mp_mul can end up back 
2481  * in this function if the a0, a1, b0, or b1 are above the threshold.  
2482  * This is known as divide-and-conquer and leads to the famous 
2483  * O(N**lg(3)) or O(N**1.584) work which is asymptotically lower than
2484  * the standard O(N**2) that the baseline/comba methods use.  
2485  * Generally though the overhead of this method doesn't pay off 
2486  * until a certain size (N ~ 80) is reached.
2487  */
2488 int mp_karatsuba_mul (const mp_int * a, const mp_int * b, mp_int * c)
2489 {
2490   mp_int  x0, x1, y0, y1, t1, x0y0, x1y1;
2491   int     B, err;
2492
2493   /* default the return code to an error */
2494   err = MP_MEM;
2495
2496   /* min # of digits */
2497   B = MIN (a->used, b->used);
2498
2499   /* now divide in two */
2500   B = B >> 1;
2501
2502   /* init copy all the temps */
2503   if (mp_init_size (&x0, B) != MP_OKAY)
2504     goto ERR;
2505   if (mp_init_size (&x1, a->used - B) != MP_OKAY)
2506     goto X0;
2507   if (mp_init_size (&y0, B) != MP_OKAY)
2508     goto X1;
2509   if (mp_init_size (&y1, b->used - B) != MP_OKAY)
2510     goto Y0;
2511
2512   /* init temps */
2513   if (mp_init_size (&t1, B * 2) != MP_OKAY)
2514     goto Y1;
2515   if (mp_init_size (&x0y0, B * 2) != MP_OKAY)
2516     goto T1;
2517   if (mp_init_size (&x1y1, B * 2) != MP_OKAY)
2518     goto X0Y0;
2519
2520   /* now shift the digits */
2521   x0.used = y0.used = B;
2522   x1.used = a->used - B;
2523   y1.used = b->used - B;
2524
2525   {
2526     register int x;
2527     register mp_digit *tmpa, *tmpb, *tmpx, *tmpy;
2528
2529     /* we copy the digits directly instead of using higher level functions
2530      * since we also need to shift the digits
2531      */
2532     tmpa = a->dp;
2533     tmpb = b->dp;
2534
2535     tmpx = x0.dp;
2536     tmpy = y0.dp;
2537     for (x = 0; x < B; x++) {
2538       *tmpx++ = *tmpa++;
2539       *tmpy++ = *tmpb++;
2540     }
2541
2542     tmpx = x1.dp;
2543     for (x = B; x < a->used; x++) {
2544       *tmpx++ = *tmpa++;
2545     }
2546
2547     tmpy = y1.dp;
2548     for (x = B; x < b->used; x++) {
2549       *tmpy++ = *tmpb++;
2550     }
2551   }
2552
2553   /* only need to clamp the lower words since by definition the 
2554    * upper words x1/y1 must have a known number of digits
2555    */
2556   mp_clamp (&x0);
2557   mp_clamp (&y0);
2558
2559   /* now calc the products x0y0 and x1y1 */
2560   /* after this x0 is no longer required, free temp [x0==t2]! */
2561   if (mp_mul (&x0, &y0, &x0y0) != MP_OKAY)  
2562     goto X1Y1;          /* x0y0 = x0*y0 */
2563   if (mp_mul (&x1, &y1, &x1y1) != MP_OKAY)
2564     goto X1Y1;          /* x1y1 = x1*y1 */
2565
2566   /* now calc x1-x0 and y1-y0 */
2567   if (mp_sub (&x1, &x0, &t1) != MP_OKAY)
2568     goto X1Y1;          /* t1 = x1 - x0 */
2569   if (mp_sub (&y1, &y0, &x0) != MP_OKAY)
2570     goto X1Y1;          /* t2 = y1 - y0 */
2571   if (mp_mul (&t1, &x0, &t1) != MP_OKAY)
2572     goto X1Y1;          /* t1 = (x1 - x0) * (y1 - y0) */
2573
2574   /* add x0y0 */
2575   if (mp_add (&x0y0, &x1y1, &x0) != MP_OKAY)
2576     goto X1Y1;          /* t2 = x0y0 + x1y1 */
2577   if (mp_sub (&x0, &t1, &t1) != MP_OKAY)
2578     goto X1Y1;          /* t1 = x0y0 + x1y1 - (x1-x0)*(y1-y0) */
2579
2580   /* shift by B */
2581   if (mp_lshd (&t1, B) != MP_OKAY)
2582     goto X1Y1;          /* t1 = (x0y0 + x1y1 - (x1-x0)*(y1-y0))<<B */
2583   if (mp_lshd (&x1y1, B * 2) != MP_OKAY)
2584     goto X1Y1;          /* x1y1 = x1y1 << 2*B */
2585
2586   if (mp_add (&x0y0, &t1, &t1) != MP_OKAY)
2587     goto X1Y1;          /* t1 = x0y0 + t1 */
2588   if (mp_add (&t1, &x1y1, c) != MP_OKAY)
2589     goto X1Y1;          /* t1 = x0y0 + t1 + x1y1 */
2590
2591   /* Algorithm succeeded set the return code to MP_OKAY */
2592   err = MP_OKAY;
2593
2594 X1Y1:mp_clear (&x1y1);
2595 X0Y0:mp_clear (&x0y0);
2596 T1:mp_clear (&t1);
2597 Y1:mp_clear (&y1);
2598 Y0:mp_clear (&y0);
2599 X1:mp_clear (&x1);
2600 X0:mp_clear (&x0);
2601 ERR:
2602   return err;
2603 }
2604
2605 /* Karatsuba squaring, computes b = a*a using three 
2606  * half size squarings
2607  *
2608  * See comments of karatsuba_mul for details.  It 
2609  * is essentially the same algorithm but merely 
2610  * tuned to perform recursive squarings.
2611  */
2612 int mp_karatsuba_sqr (const mp_int * a, mp_int * b)
2613 {
2614   mp_int  x0, x1, t1, t2, x0x0, x1x1;
2615   int     B, err;
2616
2617   err = MP_MEM;
2618
2619   /* min # of digits */
2620   B = a->used;
2621
2622   /* now divide in two */
2623   B = B >> 1;
2624
2625   /* init copy all the temps */
2626   if (mp_init_size (&x0, B) != MP_OKAY)
2627     goto ERR;
2628   if (mp_init_size (&x1, a->used - B) != MP_OKAY)
2629     goto X0;
2630
2631   /* init temps */
2632   if (mp_init_size (&t1, a->used * 2) != MP_OKAY)
2633     goto X1;
2634   if (mp_init_size (&t2, a->used * 2) != MP_OKAY)
2635     goto T1;
2636   if (mp_init_size (&x0x0, B * 2) != MP_OKAY)
2637     goto T2;
2638   if (mp_init_size (&x1x1, (a->used - B) * 2) != MP_OKAY)
2639     goto X0X0;
2640
2641   {
2642     register int x;
2643     register mp_digit *dst, *src;
2644
2645     src = a->dp;
2646
2647     /* now shift the digits */
2648     dst = x0.dp;
2649     for (x = 0; x < B; x++) {
2650       *dst++ = *src++;
2651     }
2652
2653     dst = x1.dp;
2654     for (x = B; x < a->used; x++) {
2655       *dst++ = *src++;
2656     }
2657   }
2658
2659   x0.used = B;
2660   x1.used = a->used - B;
2661
2662   mp_clamp (&x0);
2663
2664   /* now calc the products x0*x0 and x1*x1 */
2665   if (mp_sqr (&x0, &x0x0) != MP_OKAY)
2666     goto X1X1;           /* x0x0 = x0*x0 */
2667   if (mp_sqr (&x1, &x1x1) != MP_OKAY)
2668     goto X1X1;           /* x1x1 = x1*x1 */
2669
2670   /* now calc (x1-x0)**2 */
2671   if (mp_sub (&x1, &x0, &t1) != MP_OKAY)
2672     goto X1X1;           /* t1 = x1 - x0 */
2673   if (mp_sqr (&t1, &t1) != MP_OKAY)
2674     goto X1X1;           /* t1 = (x1 - x0) * (x1 - x0) */
2675
2676   /* add x0y0 */
2677   if (s_mp_add (&x0x0, &x1x1, &t2) != MP_OKAY)
2678     goto X1X1;           /* t2 = x0x0 + x1x1 */
2679   if (mp_sub (&t2, &t1, &t1) != MP_OKAY)
2680     goto X1X1;           /* t1 = x0x0 + x1x1 - (x1-x0)*(x1-x0) */
2681
2682   /* shift by B */
2683   if (mp_lshd (&t1, B) != MP_OKAY)
2684     goto X1X1;           /* t1 = (x0x0 + x1x1 - (x1-x0)*(x1-x0))<<B */
2685   if (mp_lshd (&x1x1, B * 2) != MP_OKAY)
2686     goto X1X1;           /* x1x1 = x1x1 << 2*B */
2687
2688   if (mp_add (&x0x0, &t1, &t1) != MP_OKAY)
2689     goto X1X1;           /* t1 = x0x0 + t1 */
2690   if (mp_add (&t1, &x1x1, b) != MP_OKAY)
2691     goto X1X1;           /* t1 = x0x0 + t1 + x1x1 */
2692
2693   err = MP_OKAY;
2694
2695 X1X1:mp_clear (&x1x1);
2696 X0X0:mp_clear (&x0x0);
2697 T2:mp_clear (&t2);
2698 T1:mp_clear (&t1);
2699 X1:mp_clear (&x1);
2700 X0:mp_clear (&x0);
2701 ERR:
2702   return err;
2703 }
2704
2705 /* computes least common multiple as |a*b|/(a, b) */
2706 int mp_lcm (const mp_int * a, const mp_int * b, mp_int * c)
2707 {
2708   int     res;
2709   mp_int  t1, t2;
2710
2711
2712   if ((res = mp_init_multi (&t1, &t2, NULL)) != MP_OKAY) {
2713     return res;
2714   }
2715
2716   /* t1 = get the GCD of the two inputs */
2717   if ((res = mp_gcd (a, b, &t1)) != MP_OKAY) {
2718     goto __T;
2719   }
2720
2721   /* divide the smallest by the GCD */
2722   if (mp_cmp_mag(a, b) == MP_LT) {
2723      /* store quotient in t2 such that t2 * b is the LCM */
2724      if ((res = mp_div(a, &t1, &t2, NULL)) != MP_OKAY) {
2725         goto __T;
2726      }
2727      res = mp_mul(b, &t2, c);
2728   } else {
2729      /* store quotient in t2 such that t2 * a is the LCM */
2730      if ((res = mp_div(b, &t1, &t2, NULL)) != MP_OKAY) {
2731         goto __T;
2732      }
2733      res = mp_mul(a, &t2, c);
2734   }
2735
2736   /* fix the sign to positive */
2737   c->sign = MP_ZPOS;
2738
2739 __T:
2740   mp_clear_multi (&t1, &t2, NULL);
2741   return res;
2742 }
2743
2744 /* c = a mod b, 0 <= c < b */
2745 int
2746 mp_mod (const mp_int * a, mp_int * b, mp_int * c)
2747 {
2748   mp_int  t;
2749   int     res;
2750
2751   if ((res = mp_init (&t)) != MP_OKAY) {
2752     return res;
2753   }
2754
2755   if ((res = mp_div (a, b, NULL, &t)) != MP_OKAY) {
2756     mp_clear (&t);
2757     return res;
2758   }
2759
2760   if (t.sign != b->sign) {
2761     res = mp_add (b, &t, c);
2762   } else {
2763     res = MP_OKAY;
2764     mp_exch (&t, c);
2765   }
2766
2767   mp_clear (&t);
2768   return res;
2769 }
2770
2771 static int
2772 mp_mod_d (const mp_int * a, mp_digit b, mp_digit * c)
2773 {
2774   return mp_div_d(a, b, NULL, c);
2775 }
2776
2777 /* b = a*2 */
2778 static int mp_mul_2(const mp_int * a, mp_int * b)
2779 {
2780   int     x, res, oldused;
2781
2782   /* grow to accommodate result */
2783   if (b->alloc < a->used + 1) {
2784     if ((res = mp_grow (b, a->used + 1)) != MP_OKAY) {
2785       return res;
2786     }
2787   }
2788
2789   oldused = b->used;
2790   b->used = a->used;
2791
2792   {
2793     register mp_digit r, rr, *tmpa, *tmpb;
2794
2795     /* alias for source */
2796     tmpa = a->dp;
2797
2798     /* alias for dest */
2799     tmpb = b->dp;
2800
2801     /* carry */
2802     r = 0;
2803     for (x = 0; x < a->used; x++) {
2804
2805       /* get what will be the *next* carry bit from the
2806        * MSB of the current digit
2807        */
2808       rr = *tmpa >> ((mp_digit)(DIGIT_BIT - 1));
2809
2810       /* now shift up this digit, add in the carry [from the previous] */
2811       *tmpb++ = ((*tmpa++ << ((mp_digit)1)) | r) & MP_MASK;
2812
2813       /* copy the carry that would be from the source
2814        * digit into the next iteration
2815        */
2816       r = rr;
2817     }
2818
2819     /* new leading digit? */
2820     if (r != 0) {
2821       /* add a MSB which is always 1 at this point */
2822       *tmpb = 1;
2823       ++(b->used);
2824     }
2825
2826     /* now zero any excess digits on the destination
2827      * that we didn't write to
2828      */
2829     tmpb = b->dp + b->used;
2830     for (x = b->used; x < oldused; x++) {
2831       *tmpb++ = 0;
2832     }
2833   }
2834   b->sign = a->sign;
2835   return MP_OKAY;
2836 }
2837
2838 /*
2839  * shifts with subtractions when the result is greater than b.
2840  *
2841  * The method is slightly modified to shift B unconditionally up to just under
2842  * the leading bit of b.  This saves a lot of multiple precision shifting.
2843  */
2844 int mp_montgomery_calc_normalization (mp_int * a, const mp_int * b)
2845 {
2846   int     x, bits, res;
2847
2848   /* how many bits of last digit does b use */
2849   bits = mp_count_bits (b) % DIGIT_BIT;
2850
2851
2852   if (b->used > 1) {
2853      if ((res = mp_2expt (a, (b->used - 1) * DIGIT_BIT + bits - 1)) != MP_OKAY) {
2854         return res;
2855      }
2856   } else {
2857      mp_set(a, 1);
2858      bits = 1;
2859   }
2860
2861
2862   /* now compute C = A * B mod b */
2863   for (x = bits - 1; x < DIGIT_BIT; x++) {
2864     if ((res = mp_mul_2 (a, a)) != MP_OKAY) {
2865       return res;
2866     }
2867     if (mp_cmp_mag (a, b) != MP_LT) {
2868       if ((res = s_mp_sub (a, b, a)) != MP_OKAY) {
2869         return res;
2870       }
2871     }
2872   }
2873
2874   return MP_OKAY;
2875 }
2876
2877 /* computes xR**-1 == x (mod N) via Montgomery Reduction */
2878 int
2879 mp_montgomery_reduce (mp_int * x, const mp_int * n, mp_digit rho)
2880 {
2881   int     ix, res, digs;
2882   mp_digit mu;
2883
2884   /* can the fast reduction [comba] method be used?
2885    *
2886    * Note that unlike in mul you're safely allowed *less*
2887    * than the available columns [255 per default] since carries
2888    * are fixed up in the inner loop.
2889    */
2890   digs = n->used * 2 + 1;
2891   if ((digs < MP_WARRAY) &&
2892       n->used <
2893       (1 << ((CHAR_BIT * sizeof (mp_word)) - (2 * DIGIT_BIT)))) {
2894     return fast_mp_montgomery_reduce (x, n, rho);
2895   }
2896
2897   /* grow the input as required */
2898   if (x->alloc < digs) {
2899     if ((res = mp_grow (x, digs)) != MP_OKAY) {
2900       return res;
2901     }
2902   }
2903   x->used = digs;
2904
2905   for (ix = 0; ix < n->used; ix++) {
2906     /* mu = ai * rho mod b
2907      *
2908      * The value of rho must be precalculated via
2909      * montgomery_setup() such that
2910      * it equals -1/n0 mod b this allows the
2911      * following inner loop to reduce the
2912      * input one digit at a time
2913      */
2914     mu = (mp_digit) (((mp_word)x->dp[ix]) * ((mp_word)rho) & MP_MASK);
2915
2916     /* a = a + mu * m * b**i */
2917     {
2918       register int iy;
2919       register mp_digit *tmpn, *tmpx, u;
2920       register mp_word r;
2921
2922       /* alias for digits of the modulus */
2923       tmpn = n->dp;
2924
2925       /* alias for the digits of x [the input] */
2926       tmpx = x->dp + ix;
2927
2928       /* set the carry to zero */
2929       u = 0;
2930
2931       /* Multiply and add in place */
2932       for (iy = 0; iy < n->used; iy++) {
2933         /* compute product and sum */
2934         r       = ((mp_word)mu) * ((mp_word)*tmpn++) +
2935                   ((mp_word) u) + ((mp_word) * tmpx);
2936
2937         /* get carry */
2938         u       = (mp_digit)(r >> ((mp_word) DIGIT_BIT));
2939
2940         /* fix digit */
2941         *tmpx++ = (mp_digit)(r & ((mp_word) MP_MASK));
2942       }
2943       /* At this point the ix'th digit of x should be zero */
2944
2945
2946       /* propagate carries upwards as required*/
2947       while (u) {
2948         *tmpx   += u;
2949         u        = *tmpx >> DIGIT_BIT;
2950         *tmpx++ &= MP_MASK;
2951       }
2952     }
2953   }
2954
2955   /* at this point the n.used'th least
2956    * significant digits of x are all zero
2957    * which means we can shift x to the
2958    * right by n.used digits and the
2959    * residue is unchanged.
2960    */
2961
2962   /* x = x/b**n.used */
2963   mp_clamp(x);
2964   mp_rshd (x, n->used);
2965
2966   /* if x >= n then x = x - n */
2967   if (mp_cmp_mag (x, n) != MP_LT) {
2968     return s_mp_sub (x, n, x);
2969   }
2970
2971   return MP_OKAY;
2972 }
2973
2974 /* setups the montgomery reduction stuff */
2975 int
2976 mp_montgomery_setup (const mp_int * n, mp_digit * rho)
2977 {
2978   mp_digit x, b;
2979
2980 /* fast inversion mod 2**k
2981  *
2982  * Based on the fact that
2983  *
2984  * XA = 1 (mod 2**n)  =>  (X(2-XA)) A = 1 (mod 2**2n)
2985  *                    =>  2*X*A - X*X*A*A = 1
2986  *                    =>  2*(1) - (1)     = 1
2987  */
2988   b = n->dp[0];
2989
2990   if ((b & 1) == 0) {
2991     return MP_VAL;
2992   }
2993
2994   x = (((b + 2) & 4) << 1) + b; /* here x*a==1 mod 2**4 */
2995   x *= 2 - b * x;               /* here x*a==1 mod 2**8 */
2996   x *= 2 - b * x;               /* here x*a==1 mod 2**16 */
2997   x *= 2 - b * x;               /* here x*a==1 mod 2**32 */
2998
2999   /* rho = -1/m mod b */
3000   *rho = (((mp_word)1 << ((mp_word) DIGIT_BIT)) - x) & MP_MASK;
3001
3002   return MP_OKAY;
3003 }
3004
3005 /* high level multiplication (handles sign) */
3006 int mp_mul (const mp_int * a, const mp_int * b, mp_int * c)
3007 {
3008   int     res, neg;
3009   neg = (a->sign == b->sign) ? MP_ZPOS : MP_NEG;
3010
3011   /* use Karatsuba? */
3012   if (MIN (a->used, b->used) >= KARATSUBA_MUL_CUTOFF) {
3013     res = mp_karatsuba_mul (a, b, c);
3014   } else 
3015   {
3016     /* can we use the fast multiplier?
3017      *
3018      * The fast multiplier can be used if the output will 
3019      * have less than MP_WARRAY digits and the number of 
3020      * digits won't affect carry propagation
3021      */
3022     int     digs = a->used + b->used + 1;
3023
3024     if ((digs < MP_WARRAY) &&
3025         MIN(a->used, b->used) <= 
3026         (1 << ((CHAR_BIT * sizeof (mp_word)) - (2 * DIGIT_BIT)))) {
3027       res = fast_s_mp_mul_digs (a, b, c, digs);
3028     } else 
3029       res = s_mp_mul (a, b, c); /* uses s_mp_mul_digs */
3030   }
3031   c->sign = (c->used > 0) ? neg : MP_ZPOS;
3032   return res;
3033 }
3034
3035 /* d = a * b (mod c) */
3036 int
3037 mp_mulmod (const mp_int * a, const mp_int * b, mp_int * c, mp_int * d)
3038 {
3039   int     res;
3040   mp_int  t;
3041
3042   if ((res = mp_init (&t)) != MP_OKAY) {
3043     return res;
3044   }
3045
3046   if ((res = mp_mul (a, b, &t)) != MP_OKAY) {
3047     mp_clear (&t);
3048     return res;
3049   }
3050   res = mp_mod (&t, c, d);
3051   mp_clear (&t);
3052   return res;
3053 }
3054
3055 /* table of first PRIME_SIZE primes */
3056 static const mp_digit __prime_tab[] = {
3057   0x0002, 0x0003, 0x0005, 0x0007, 0x000B, 0x000D, 0x0011, 0x0013,
3058   0x0017, 0x001D, 0x001F, 0x0025, 0x0029, 0x002B, 0x002F, 0x0035,
3059   0x003B, 0x003D, 0x0043, 0x0047, 0x0049, 0x004F, 0x0053, 0x0059,
3060   0x0061, 0x0065, 0x0067, 0x006B, 0x006D, 0x0071, 0x007F, 0x0083,
3061   0x0089, 0x008B, 0x0095, 0x0097, 0x009D, 0x00A3, 0x00A7, 0x00AD,
3062   0x00B3, 0x00B5, 0x00BF, 0x00C1, 0x00C5, 0x00C7, 0x00D3, 0x00DF,
3063   0x00E3, 0x00E5, 0x00E9, 0x00EF, 0x00F1, 0x00FB, 0x0101, 0x0107,
3064   0x010D, 0x010F, 0x0115, 0x0119, 0x011B, 0x0125, 0x0133, 0x0137,
3065
3066   0x0139, 0x013D, 0x014B, 0x0151, 0x015B, 0x015D, 0x0161, 0x0167,
3067   0x016F, 0x0175, 0x017B, 0x017F, 0x0185, 0x018D, 0x0191, 0x0199,
3068   0x01A3, 0x01A5, 0x01AF, 0x01B1, 0x01B7, 0x01BB, 0x01C1, 0x01C9,
3069   0x01CD, 0x01CF, 0x01D3, 0x01DF, 0x01E7, 0x01EB, 0x01F3, 0x01F7,
3070   0x01FD, 0x0209, 0x020B, 0x021D, 0x0223, 0x022D, 0x0233, 0x0239,
3071   0x023B, 0x0241, 0x024B, 0x0251, 0x0257, 0x0259, 0x025F, 0x0265,
3072   0x0269, 0x026B, 0x0277, 0x0281, 0x0283, 0x0287, 0x028D, 0x0293,
3073   0x0295, 0x02A1, 0x02A5, 0x02AB, 0x02B3, 0x02BD, 0x02C5, 0x02CF,
3074
3075   0x02D7, 0x02DD, 0x02E3, 0x02E7, 0x02EF, 0x02F5, 0x02F9, 0x0301,
3076   0x0305, 0x0313, 0x031D, 0x0329, 0x032B, 0x0335, 0x0337, 0x033B,
3077   0x033D, 0x0347, 0x0355, 0x0359, 0x035B, 0x035F, 0x036D, 0x0371,
3078   0x0373, 0x0377, 0x038B, 0x038F, 0x0397, 0x03A1, 0x03A9, 0x03AD,
3079   0x03B3, 0x03B9, 0x03C7, 0x03CB, 0x03D1, 0x03D7, 0x03DF, 0x03E5,
3080   0x03F1, 0x03F5, 0x03FB, 0x03FD, 0x0407, 0x0409, 0x040F, 0x0419,
3081   0x041B, 0x0425, 0x0427, 0x042D, 0x043F, 0x0443, 0x0445, 0x0449,
3082   0x044F, 0x0455, 0x045D, 0x0463, 0x0469, 0x047F, 0x0481, 0x048B,
3083
3084   0x0493, 0x049D, 0x04A3, 0x04A9, 0x04B1, 0x04BD, 0x04C1, 0x04C7,
3085   0x04CD, 0x04CF, 0x04D5, 0x04E1, 0x04EB, 0x04FD, 0x04FF, 0x0503,
3086   0x0509, 0x050B, 0x0511, 0x0515, 0x0517, 0x051B, 0x0527, 0x0529,
3087   0x052F, 0x0551, 0x0557, 0x055D, 0x0565, 0x0577, 0x0581, 0x058F,
3088   0x0593, 0x0595, 0x0599, 0x059F, 0x05A7, 0x05AB, 0x05AD, 0x05B3,
3089   0x05BF, 0x05C9, 0x05CB, 0x05CF, 0x05D1, 0x05D5, 0x05DB, 0x05E7,
3090   0x05F3, 0x05FB, 0x0607, 0x060D, 0x0611, 0x0617, 0x061F, 0x0623,
3091   0x062B, 0x062F, 0x063D, 0x0641, 0x0647, 0x0649, 0x064D, 0x0653
3092 };
3093
3094 /* determines if an integers is divisible by one 
3095  * of the first PRIME_SIZE primes or not
3096  *
3097  * sets result to 0 if not, 1 if yes
3098  */
3099 int mp_prime_is_divisible (const mp_int * a, int *result)
3100 {
3101   int     err, ix;
3102   mp_digit res;
3103
3104   /* default to not */
3105   *result = MP_NO;
3106
3107   for (ix = 0; ix < PRIME_SIZE; ix++) {
3108     /* what is a mod __prime_tab[ix] */
3109     if ((err = mp_mod_d (a, __prime_tab[ix], &res)) != MP_OKAY) {
3110       return err;
3111     }
3112
3113     /* is the residue zero? */
3114     if (res == 0) {
3115       *result = MP_YES;
3116       return MP_OKAY;
3117     }
3118   }
3119
3120   return MP_OKAY;
3121 }
3122
3123 /* performs a variable number of rounds of Miller-Rabin
3124  *
3125  * Probability of error after t rounds is no more than
3126
3127  *
3128  * Sets result to 1 if probably prime, 0 otherwise
3129  */
3130 int mp_prime_is_prime (mp_int * a, int t, int *result)
3131 {
3132   mp_int  b;
3133   int     ix, err, res;
3134
3135   /* default to no */
3136   *result = MP_NO;
3137
3138   /* valid value of t? */
3139   if (t <= 0 || t > PRIME_SIZE) {
3140     return MP_VAL;
3141   }
3142
3143   /* is the input equal to one of the primes in the table? */
3144   for (ix = 0; ix < PRIME_SIZE; ix++) {
3145       if (mp_cmp_d(a, __prime_tab[ix]) == MP_EQ) {
3146          *result = 1;
3147          return MP_OKAY;
3148       }
3149   }
3150
3151   /* first perform trial division */
3152   if ((err = mp_prime_is_divisible (a, &res)) != MP_OKAY) {
3153     return err;
3154   }
3155
3156   /* return if it was trivially divisible */
3157   if (res == MP_YES) {
3158     return MP_OKAY;
3159   }
3160
3161   /* now perform the miller-rabin rounds */
3162   if ((err = mp_init (&b)) != MP_OKAY) {
3163     return err;
3164   }
3165
3166   for (ix = 0; ix < t; ix++) {
3167     /* set the prime */
3168     mp_set (&b, __prime_tab[ix]);
3169
3170     if ((err = mp_prime_miller_rabin (a, &b, &res)) != MP_OKAY) {
3171       goto __B;
3172     }
3173
3174     if (res == MP_NO) {
3175       goto __B;
3176     }
3177   }
3178
3179   /* passed the test */
3180   *result = MP_YES;
3181 __B:mp_clear (&b);
3182   return err;
3183 }
3184
3185 /* Miller-Rabin test of "a" to the base of "b" as described in 
3186  * HAC pp. 139 Algorithm 4.24
3187  *
3188  * Sets result to 0 if definitely composite or 1 if probably prime.
3189  * Randomly the chance of error is no more than 1/4 and often 
3190  * very much lower.
3191  */
3192 int mp_prime_miller_rabin (mp_int * a, const mp_int * b, int *result)
3193 {
3194   mp_int  n1, y, r;
3195   int     s, j, err;
3196
3197   /* default */
3198   *result = MP_NO;
3199
3200   /* ensure b > 1 */
3201   if (mp_cmp_d(b, 1) != MP_GT) {
3202      return MP_VAL;
3203   }     
3204
3205   /* get n1 = a - 1 */
3206   if ((err = mp_init_copy (&n1, a)) != MP_OKAY) {
3207     return err;
3208   }
3209   if ((err = mp_sub_d (&n1, 1, &n1)) != MP_OKAY) {
3210     goto __N1;
3211   }
3212
3213   /* set 2**s * r = n1 */
3214   if ((err = mp_init_copy (&r, &n1)) != MP_OKAY) {
3215     goto __N1;
3216   }
3217
3218   /* count the number of least significant bits
3219    * which are zero
3220    */
3221   s = mp_cnt_lsb(&r);
3222
3223   /* now divide n - 1 by 2**s */
3224   if ((err = mp_div_2d (&r, s, &r, NULL)) != MP_OKAY) {
3225     goto __R;
3226   }
3227
3228   /* compute y = b**r mod a */
3229   if ((err = mp_init (&y)) != MP_OKAY) {
3230     goto __R;
3231   }
3232   if ((err = mp_exptmod (b, &r, a, &y)) != MP_OKAY) {
3233     goto __Y;
3234   }
3235
3236   /* if y != 1 and y != n1 do */
3237   if (mp_cmp_d (&y, 1) != MP_EQ && mp_cmp (&y, &n1) != MP_EQ) {
3238     j = 1;
3239     /* while j <= s-1 and y != n1 */
3240     while ((j <= (s - 1)) && mp_cmp (&y, &n1) != MP_EQ) {
3241       if ((err = mp_sqrmod (&y, a, &y)) != MP_OKAY) {
3242          goto __Y;
3243       }
3244
3245       /* if y == 1 then composite */
3246       if (mp_cmp_d (&y, 1) == MP_EQ) {
3247          goto __Y;
3248       }
3249
3250       ++j;
3251     }
3252
3253     /* if y != n1 then composite */
3254     if (mp_cmp (&y, &n1) != MP_EQ) {
3255       goto __Y;
3256     }
3257   }
3258
3259   /* probably prime now */
3260   *result = MP_YES;
3261 __Y:mp_clear (&y);
3262 __R:mp_clear (&r);
3263 __N1:mp_clear (&n1);
3264   return err;
3265 }
3266
3267 static const struct {
3268    int k, t;
3269 } sizes[] = {
3270 {   128,    28 },
3271 {   256,    16 },
3272 {   384,    10 },
3273 {   512,     7 },
3274 {   640,     6 },
3275 {   768,     5 },
3276 {   896,     4 },
3277 {  1024,     4 }
3278 };
3279
3280 /* returns # of RM trials required for a given bit size */
3281 int mp_prime_rabin_miller_trials(int size)
3282 {
3283    int x;
3284
3285    for (x = 0; x < (int)(sizeof(sizes)/(sizeof(sizes[0]))); x++) {
3286        if (sizes[x].k == size) {
3287           return sizes[x].t;
3288        } else if (sizes[x].k > size) {
3289           return (x == 0) ? sizes[0].t : sizes[x - 1].t;
3290        }
3291    }
3292    return sizes[x-1].t + 1;
3293 }
3294
3295 /* makes a truly random prime of a given size (bits),
3296  *
3297  * Flags are as follows:
3298  * 
3299  *   LTM_PRIME_BBS      - make prime congruent to 3 mod 4
3300  *   LTM_PRIME_SAFE     - make sure (p-1)/2 is prime as well (implies LTM_PRIME_BBS)
3301  *   LTM_PRIME_2MSB_OFF - make the 2nd highest bit zero
3302  *   LTM_PRIME_2MSB_ON  - make the 2nd highest bit one
3303  *
3304  * You have to supply a callback which fills in a buffer with random bytes.  "dat" is a parameter you can
3305  * have passed to the callback (e.g. a state or something).  This function doesn't use "dat" itself
3306  * so it can be NULL
3307  *
3308  */
3309
3310 /* This is possibly the mother of all prime generation functions, muahahahahaha! */
3311 int mp_prime_random_ex(mp_int *a, int t, int size, int flags, ltm_prime_callback cb, void *dat)
3312 {
3313    unsigned char *tmp, maskAND, maskOR_msb, maskOR_lsb;
3314    int res, err, bsize, maskOR_msb_offset;
3315
3316    /* sanity check the input */
3317    if (size <= 1 || t <= 0) {
3318       return MP_VAL;
3319    }
3320
3321    /* LTM_PRIME_SAFE implies LTM_PRIME_BBS */
3322    if (flags & LTM_PRIME_SAFE) {
3323       flags |= LTM_PRIME_BBS;
3324    }
3325
3326    /* calc the byte size */
3327    bsize = (size>>3)+((size&7)?1:0);
3328
3329    /* we need a buffer of bsize bytes */
3330    tmp = malloc(bsize);
3331    if (tmp == NULL) {
3332       return MP_MEM;
3333    }
3334
3335    /* calc the maskAND value for the MSbyte*/
3336    maskAND = ((size&7) == 0) ? 0xFF : (0xFF >> (8 - (size & 7))); 
3337
3338    /* calc the maskOR_msb */
3339    maskOR_msb        = 0;
3340    maskOR_msb_offset = ((size & 7) == 1) ? 1 : 0;
3341    if (flags & LTM_PRIME_2MSB_ON) {
3342       maskOR_msb     |= 1 << ((size - 2) & 7);
3343    } else if (flags & LTM_PRIME_2MSB_OFF) {
3344       maskAND        &= ~(1 << ((size - 2) & 7));
3345    }
3346
3347    /* get the maskOR_lsb */
3348    maskOR_lsb         = 0;
3349    if (flags & LTM_PRIME_BBS) {
3350       maskOR_lsb     |= 3;
3351    }
3352
3353    do {
3354       /* read the bytes */
3355       if (cb(tmp, bsize, dat) != bsize) {
3356          err = MP_VAL;
3357          goto error;
3358       }
3359  
3360       /* work over the MSbyte */
3361       tmp[0]    &= maskAND;
3362       tmp[0]    |= 1 << ((size - 1) & 7);
3363
3364       /* mix in the maskORs */
3365       tmp[maskOR_msb_offset]   |= maskOR_msb;
3366       tmp[bsize-1]             |= maskOR_lsb;
3367
3368       /* read it in */
3369       if ((err = mp_read_unsigned_bin(a, tmp, bsize)) != MP_OKAY)     { goto error; }
3370
3371       /* is it prime? */
3372       if ((err = mp_prime_is_prime(a, t, &res)) != MP_OKAY)           { goto error; }
3373       if (res == MP_NO) {  
3374          continue;
3375       }
3376
3377       if (flags & LTM_PRIME_SAFE) {
3378          /* see if (a-1)/2 is prime */
3379          if ((err = mp_sub_d(a, 1, a)) != MP_OKAY)                    { goto error; }
3380          if ((err = mp_div_2(a, a)) != MP_OKAY)                       { goto error; }
3381  
3382          /* is it prime? */
3383          if ((err = mp_prime_is_prime(a, t, &res)) != MP_OKAY)        { goto error; }
3384       }
3385    } while (res == MP_NO);
3386
3387    if (flags & LTM_PRIME_SAFE) {
3388       /* restore a to the original value */
3389       if ((err = mp_mul_2(a, a)) != MP_OKAY)                          { goto error; }
3390       if ((err = mp_add_d(a, 1, a)) != MP_OKAY)                       { goto error; }
3391    }
3392
3393    err = MP_OKAY;
3394 error:
3395    free(tmp);
3396    return err;
3397 }
3398
3399 /* reads an unsigned char array, assumes the msb is stored first [big endian] */
3400 int
3401 mp_read_unsigned_bin (mp_int * a, const unsigned char *b, int c)
3402 {
3403   int     res;
3404
3405   /* make sure there are at least two digits */
3406   if (a->alloc < 2) {
3407      if ((res = mp_grow(a, 2)) != MP_OKAY) {
3408         return res;
3409      }
3410   }
3411
3412   /* zero the int */
3413   mp_zero (a);
3414
3415   /* read the bytes in */
3416   while (c-- > 0) {
3417     if ((res = mp_mul_2d (a, 8, a)) != MP_OKAY) {
3418       return res;
3419     }
3420
3421       a->dp[0] |= *b++;
3422       a->used += 1;
3423   }
3424   mp_clamp (a);
3425   return MP_OKAY;
3426 }
3427
3428 /* reduces x mod m, assumes 0 < x < m**2, mu is 
3429  * precomputed via mp_reduce_setup.
3430  * From HAC pp.604 Algorithm 14.42
3431  */
3432 int
3433 mp_reduce (mp_int * x, const mp_int * m, const mp_int * mu)
3434 {
3435   mp_int  q;
3436   int     res, um = m->used;
3437
3438   /* q = x */
3439   if ((res = mp_init_copy (&q, x)) != MP_OKAY) {
3440     return res;
3441   }
3442
3443   /* q1 = x / b**(k-1)  */
3444   mp_rshd (&q, um - 1);         
3445
3446   /* according to HAC this optimization is ok */
3447   if (((unsigned long) um) > (((mp_digit)1) << (DIGIT_BIT - 1))) {
3448     if ((res = mp_mul (&q, mu, &q)) != MP_OKAY) {
3449       goto CLEANUP;
3450     }
3451   } else {
3452     if ((res = s_mp_mul_high_digs (&q, mu, &q, um - 1)) != MP_OKAY) {
3453       goto CLEANUP;
3454     }
3455   }
3456
3457   /* q3 = q2 / b**(k+1) */
3458   mp_rshd (&q, um + 1);         
3459
3460   /* x = x mod b**(k+1), quick (no division) */
3461   if ((res = mp_mod_2d (x, DIGIT_BIT * (um + 1), x)) != MP_OKAY) {
3462     goto CLEANUP;
3463   }
3464
3465   /* q = q * m mod b**(k+1), quick (no division) */
3466   if ((res = s_mp_mul_digs (&q, m, &q, um + 1)) != MP_OKAY) {
3467     goto CLEANUP;
3468   }
3469
3470   /* x = x - q */
3471   if ((res = mp_sub (x, &q, x)) != MP_OKAY) {
3472     goto CLEANUP;
3473   }
3474
3475   /* If x < 0, add b**(k+1) to it */
3476   if (mp_cmp_d (x, 0) == MP_LT) {
3477     mp_set (&q, 1);
3478     if ((res = mp_lshd (&q, um + 1)) != MP_OKAY)
3479       goto CLEANUP;
3480     if ((res = mp_add (x, &q, x)) != MP_OKAY)
3481       goto CLEANUP;
3482   }
3483
3484   /* Back off if it's too big */
3485   while (mp_cmp (x, m) != MP_LT) {
3486     if ((res = s_mp_sub (x, m, x)) != MP_OKAY) {
3487       goto CLEANUP;
3488     }
3489   }
3490   
3491 CLEANUP:
3492   mp_clear (&q);
3493
3494   return res;
3495 }
3496
3497 /* reduces a modulo n where n is of the form 2**p - d */
3498 int
3499 mp_reduce_2k(mp_int *a, const mp_int *n, mp_digit d)
3500 {
3501    mp_int q;
3502    int    p, res;
3503    
3504    if ((res = mp_init(&q)) != MP_OKAY) {
3505       return res;
3506    }
3507    
3508    p = mp_count_bits(n);    
3509 top:
3510    /* q = a/2**p, a = a mod 2**p */
3511    if ((res = mp_div_2d(a, p, &q, a)) != MP_OKAY) {
3512       goto ERR;
3513    }
3514    
3515    if (d != 1) {
3516       /* q = q * d */
3517       if ((res = mp_mul_d(&q, d, &q)) != MP_OKAY) { 
3518          goto ERR;
3519       }
3520    }
3521    
3522    /* a = a + q */
3523    if ((res = s_mp_add(a, &q, a)) != MP_OKAY) {
3524       goto ERR;
3525    }
3526    
3527    if (mp_cmp_mag(a, n) != MP_LT) {
3528       s_mp_sub(a, n, a);
3529       goto top;
3530    }
3531    
3532 ERR:
3533    mp_clear(&q);
3534    return res;
3535 }
3536
3537 /* determines the setup value */
3538 int 
3539 mp_reduce_2k_setup(const mp_int *a, mp_digit *d)
3540 {
3541    int res, p;
3542    mp_int tmp;
3543    
3544    if ((res = mp_init(&tmp)) != MP_OKAY) {
3545       return res;
3546    }
3547    
3548    p = mp_count_bits(a);
3549    if ((res = mp_2expt(&tmp, p)) != MP_OKAY) {
3550       mp_clear(&tmp);
3551       return res;
3552    }
3553    
3554    if ((res = s_mp_sub(&tmp, a, &tmp)) != MP_OKAY) {
3555       mp_clear(&tmp);
3556       return res;
3557    }
3558    
3559    *d = tmp.dp[0];
3560    mp_clear(&tmp);
3561    return MP_OKAY;
3562 }
3563
3564 /* pre-calculate the value required for Barrett reduction
3565  * For a given modulus "b" it calulates the value required in "a"
3566  */
3567 int mp_reduce_setup (mp_int * a, const mp_int * b)
3568 {
3569   int     res;
3570
3571   if ((res = mp_2expt (a, b->used * 2 * DIGIT_BIT)) != MP_OKAY) {
3572     return res;
3573   }
3574   return mp_div (a, b, a, NULL);
3575 }
3576
3577 /* shift right a certain amount of digits */
3578 void mp_rshd (mp_int * a, int b)
3579 {
3580   int     x;
3581
3582   /* if b <= 0 then ignore it */
3583   if (b <= 0) {
3584     return;
3585   }
3586
3587   /* if b > used then simply zero it and return */
3588   if (a->used <= b) {
3589     mp_zero (a);
3590     return;
3591   }
3592
3593   {
3594     register mp_digit *bottom, *top;
3595
3596     /* shift the digits down */
3597
3598     /* bottom */
3599     bottom = a->dp;
3600
3601     /* top [offset into digits] */
3602     top = a->dp + b;
3603
3604     /* this is implemented as a sliding window where 
3605      * the window is b-digits long and digits from 
3606      * the top of the window are copied to the bottom
3607      *
3608      * e.g.
3609
3610      b-2 | b-1 | b0 | b1 | b2 | ... | bb |   ---->
3611                  /\                   |      ---->
3612                   \-------------------/      ---->
3613      */
3614     for (x = 0; x < (a->used - b); x++) {
3615       *bottom++ = *top++;
3616     }
3617
3618     /* zero the top digits */
3619     for (; x < a->used; x++) {
3620       *bottom++ = 0;
3621     }
3622   }
3623   
3624   /* remove excess digits */
3625   a->used -= b;
3626 }
3627
3628 /* set to a digit */
3629 void mp_set (mp_int * a, mp_digit b)
3630 {
3631   mp_zero (a);
3632   a->dp[0] = b & MP_MASK;
3633   a->used  = (a->dp[0] != 0) ? 1 : 0;
3634 }
3635
3636 /* set a 32-bit const */
3637 int mp_set_int (mp_int * a, unsigned long b)
3638 {
3639   int     x, res;
3640
3641   mp_zero (a);
3642   
3643   /* set four bits at a time */
3644   for (x = 0; x < 8; x++) {
3645     /* shift the number up four bits */
3646     if ((res = mp_mul_2d (a, 4, a)) != MP_OKAY) {
3647       return res;
3648     }
3649
3650     /* OR in the top four bits of the source */
3651     a->dp[0] |= (b >> 28) & 15;
3652
3653     /* shift the source up to the next four bits */
3654     b <<= 4;
3655
3656     /* ensure that digits are not clamped off */
3657     a->used += 1;
3658   }
3659   mp_clamp (a);
3660   return MP_OKAY;
3661 }
3662
3663 /* shrink a bignum */
3664 int mp_shrink (mp_int * a)
3665 {
3666   mp_digit *tmp;
3667   if (a->alloc != a->used && a->used > 0) {
3668     if ((tmp = realloc (a->dp, sizeof (mp_digit) * a->used)) == NULL) {
3669       return MP_MEM;
3670     }
3671     a->dp    = tmp;
3672     a->alloc = a->used;
3673   }
3674   return MP_OKAY;
3675 }
3676
3677 /* get the size for an signed equivalent */
3678 int mp_signed_bin_size (const mp_int * a)
3679 {
3680   return 1 + mp_unsigned_bin_size (a);
3681 }
3682
3683 /* computes b = a*a */
3684 int
3685 mp_sqr (const mp_int * a, mp_int * b)
3686 {
3687   int     res;
3688
3689 if (a->used >= KARATSUBA_SQR_CUTOFF) {
3690     res = mp_karatsuba_sqr (a, b);
3691   } else 
3692   {
3693     /* can we use the fast comba multiplier? */
3694     if ((a->used * 2 + 1) < MP_WARRAY && 
3695          a->used < 
3696          (1 << (sizeof(mp_word) * CHAR_BIT - 2*DIGIT_BIT - 1))) {
3697       res = fast_s_mp_sqr (a, b);
3698     } else
3699       res = s_mp_sqr (a, b);
3700   }
3701   b->sign = MP_ZPOS;
3702   return res;
3703 }
3704
3705 /* c = a * a (mod b) */
3706 int
3707 mp_sqrmod (const mp_int * a, mp_int * b, mp_int * c)
3708 {
3709   int     res;
3710   mp_int  t;
3711
3712   if ((res = mp_init (&t)) != MP_OKAY) {
3713     return res;
3714   }
3715
3716   if ((res = mp_sqr (a, &t)) != MP_OKAY) {
3717     mp_clear (&t);
3718     return res;
3719   }
3720   res = mp_mod (&t, b, c);
3721   mp_clear (&t);
3722   return res;
3723 }
3724
3725 /* high level subtraction (handles signs) */
3726 int
3727 mp_sub (mp_int * a, mp_int * b, mp_int * c)
3728 {
3729   int     sa, sb, res;
3730
3731   sa = a->sign;
3732   sb = b->sign;
3733
3734   if (sa != sb) {
3735     /* subtract a negative from a positive, OR */
3736     /* subtract a positive from a negative. */
3737     /* In either case, ADD their magnitudes, */
3738     /* and use the sign of the first number. */
3739     c->sign = sa;
3740     res = s_mp_add (a, b, c);
3741   } else {
3742     /* subtract a positive from a positive, OR */
3743     /* subtract a negative from a negative. */
3744     /* First, take the difference between their */
3745     /* magnitudes, then... */
3746     if (mp_cmp_mag (a, b) != MP_LT) {
3747       /* Copy the sign from the first */
3748       c->sign = sa;
3749       /* The first has a larger or equal magnitude */
3750       res = s_mp_sub (a, b, c);
3751     } else {
3752       /* The result has the *opposite* sign from */
3753       /* the first number. */
3754       c->sign = (sa == MP_ZPOS) ? MP_NEG : MP_ZPOS;
3755       /* The second has a larger magnitude */
3756       res = s_mp_sub (b, a, c);
3757     }
3758   }
3759   return res;
3760 }
3761
3762 /* single digit subtraction */
3763 int
3764 mp_sub_d (mp_int * a, mp_digit b, mp_int * c)
3765 {
3766   mp_digit *tmpa, *tmpc, mu;
3767   int       res, ix, oldused;
3768
3769   /* grow c as required */
3770   if (c->alloc < a->used + 1) {
3771      if ((res = mp_grow(c, a->used + 1)) != MP_OKAY) {
3772         return res;
3773      }
3774   }
3775
3776   /* if a is negative just do an unsigned
3777    * addition [with fudged signs]
3778    */
3779   if (a->sign == MP_NEG) {
3780      a->sign = MP_ZPOS;
3781      res     = mp_add_d(a, b, c);
3782      a->sign = c->sign = MP_NEG;
3783      return res;
3784   }
3785
3786   /* setup regs */
3787   oldused = c->used;
3788   tmpa    = a->dp;
3789   tmpc    = c->dp;
3790
3791   /* if a <= b simply fix the single digit */
3792   if ((a->used == 1 && a->dp[0] <= b) || a->used == 0) {
3793      if (a->used == 1) {
3794         *tmpc++ = b - *tmpa;
3795      } else {
3796         *tmpc++ = b;
3797      }
3798      ix      = 1;
3799
3800      /* negative/1digit */
3801      c->sign = MP_NEG;
3802      c->used = 1;
3803   } else {
3804      /* positive/size */
3805      c->sign = MP_ZPOS;
3806      c->used = a->used;
3807
3808      /* subtract first digit */
3809      *tmpc    = *tmpa++ - b;
3810      mu       = *tmpc >> (sizeof(mp_digit) * CHAR_BIT - 1);
3811      *tmpc++ &= MP_MASK;
3812
3813      /* handle rest of the digits */
3814      for (ix = 1; ix < a->used; ix++) {
3815         *tmpc    = *tmpa++ - mu;
3816         mu       = *tmpc >> (sizeof(mp_digit) * CHAR_BIT - 1);
3817         *tmpc++ &= MP_MASK;
3818      }
3819   }
3820
3821   /* zero excess digits */
3822   while (ix++ < oldused) {
3823      *tmpc++ = 0;
3824   }
3825   mp_clamp(c);
3826   return MP_OKAY;
3827 }
3828
3829 /* store in unsigned [big endian] format */
3830 int
3831 mp_to_unsigned_bin (const mp_int * a, unsigned char *b)
3832 {
3833   int     x, res;
3834   mp_int  t;
3835
3836   if ((res = mp_init_copy (&t, a)) != MP_OKAY) {
3837     return res;
3838   }
3839
3840   x = 0;
3841   while (mp_iszero (&t) == 0) {
3842     b[x++] = (unsigned char) (t.dp[0] & 255);
3843     if ((res = mp_div_2d (&t, 8, &t, NULL)) != MP_OKAY) {
3844       mp_clear (&t);
3845       return res;
3846     }
3847   }
3848   bn_reverse (b, x);
3849   mp_clear (&t);
3850   return MP_OKAY;
3851 }
3852
3853 /* get the size for an unsigned equivalent */
3854 int
3855 mp_unsigned_bin_size (const mp_int * a)
3856 {
3857   int     size = mp_count_bits (a);
3858   return (size / 8 + ((size & 7) != 0 ? 1 : 0));
3859 }
3860
3861 /* reverse an array, used for radix code */
3862 static void
3863 bn_reverse (unsigned char *s, int len)
3864 {
3865   int     ix, iy;
3866   unsigned char t;
3867
3868   ix = 0;
3869   iy = len - 1;
3870   while (ix < iy) {
3871     t     = s[ix];
3872     s[ix] = s[iy];
3873     s[iy] = t;
3874     ++ix;
3875     --iy;
3876   }
3877 }
3878
3879 /* low level addition, based on HAC pp.594, Algorithm 14.7 */
3880 static int
3881 s_mp_add (mp_int * a, mp_int * b, mp_int * c)
3882 {
3883   mp_int *x;
3884   int     olduse, res, min, max;
3885
3886   /* find sizes, we let |a| <= |b| which means we have to sort
3887    * them.  "x" will point to the input with the most digits
3888    */
3889   if (a->used > b->used) {
3890     min = b->used;
3891     max = a->used;
3892     x = a;
3893   } else {
3894     min = a->used;
3895     max = b->used;
3896     x = b;
3897   }
3898
3899   /* init result */
3900   if (c->alloc < max + 1) {
3901     if ((res = mp_grow (c, max + 1)) != MP_OKAY) {
3902       return res;
3903     }
3904   }
3905
3906   /* get old used digit count and set new one */
3907   olduse = c->used;
3908   c->used = max + 1;
3909
3910   {
3911     register mp_digit u, *tmpa, *tmpb, *tmpc;
3912     register int i;
3913
3914     /* alias for digit pointers */
3915
3916     /* first input */
3917     tmpa = a->dp;
3918
3919     /* second input */
3920     tmpb = b->dp;
3921
3922     /* destination */
3923     tmpc = c->dp;
3924
3925     /* zero the carry */
3926     u = 0;
3927     for (i = 0; i < min; i++) {
3928       /* Compute the sum at one digit, T[i] = A[i] + B[i] + U */
3929       *tmpc = *tmpa++ + *tmpb++ + u;
3930
3931       /* U = carry bit of T[i] */
3932       u = *tmpc >> ((mp_digit)DIGIT_BIT);
3933
3934       /* take away carry bit from T[i] */
3935       *tmpc++ &= MP_MASK;
3936     }
3937
3938     /* now copy higher words if any, that is in A+B 
3939      * if A or B has more digits add those in 
3940      */
3941     if (min != max) {
3942       for (; i < max; i++) {
3943         /* T[i] = X[i] + U */
3944         *tmpc = x->dp[i] + u;
3945
3946         /* U = carry bit of T[i] */
3947         u = *tmpc >> ((mp_digit)DIGIT_BIT);
3948
3949         /* take away carry bit from T[i] */
3950         *tmpc++ &= MP_MASK;
3951       }
3952     }
3953
3954     /* add carry */
3955     *tmpc++ = u;
3956
3957     /* clear digits above oldused */
3958     for (i = c->used; i < olduse; i++) {
3959       *tmpc++ = 0;
3960     }
3961   }
3962
3963   mp_clamp (c);
3964   return MP_OKAY;
3965 }
3966
3967 static int s_mp_exptmod (const mp_int * G, const mp_int * X, mp_int * P, mp_int * Y)
3968 {
3969   mp_int  M[256], res, mu;
3970   mp_digit buf;
3971   int     err, bitbuf, bitcpy, bitcnt, mode, digidx, x, y, winsize;
3972
3973   /* find window size */
3974   x = mp_count_bits (X);
3975   if (x <= 7) {
3976     winsize = 2;
3977   } else if (x <= 36) {
3978     winsize = 3;
3979   } else if (x <= 140) {
3980     winsize = 4;
3981   } else if (x <= 450) {
3982     winsize = 5;
3983   } else if (x <= 1303) {
3984     winsize = 6;
3985   } else if (x <= 3529) {
3986     winsize = 7;
3987   } else {
3988     winsize = 8;
3989   }
3990
3991   /* init M array */
3992   /* init first cell */
3993   if ((err = mp_init(&M[1])) != MP_OKAY) {
3994      return err; 
3995   }
3996
3997   /* now init the second half of the array */
3998   for (x = 1<<(winsize-1); x < (1 << winsize); x++) {
3999     if ((err = mp_init(&M[x])) != MP_OKAY) {
4000       for (y = 1<<(winsize-1); y < x; y++) {
4001         mp_clear (&M[y]);
4002       }
4003       mp_clear(&M[1]);
4004       return err;
4005     }
4006   }
4007
4008   /* create mu, used for Barrett reduction */
4009   if ((err = mp_init (&mu)) != MP_OKAY) {
4010     goto __M;
4011   }
4012   if ((err = mp_reduce_setup (&mu, P)) != MP_OKAY) {
4013     goto __MU;
4014   }
4015
4016   /* create M table
4017    *
4018    * The M table contains powers of the base, 
4019    * e.g. M[x] = G**x mod P
4020    *
4021    * The first half of the table is not 
4022    * computed though accept for M[0] and M[1]
4023    */
4024   if ((err = mp_mod (G, P, &M[1])) != MP_OKAY) {
4025     goto __MU;
4026   }
4027
4028   /* compute the value at M[1<<(winsize-1)] by squaring 
4029    * M[1] (winsize-1) times 
4030    */
4031   if ((err = mp_copy (&M[1], &M[1 << (winsize - 1)])) != MP_OKAY) {
4032     goto __MU;
4033   }
4034
4035   for (x = 0; x < (winsize - 1); x++) {
4036     if ((err = mp_sqr (&M[1 << (winsize - 1)], 
4037                        &M[1 << (winsize - 1)])) != MP_OKAY) {
4038       goto __MU;
4039     }
4040     if ((err = mp_reduce (&M[1 << (winsize - 1)], P, &mu)) != MP_OKAY) {
4041       goto __MU;
4042     }
4043   }
4044
4045   /* create upper table, that is M[x] = M[x-1] * M[1] (mod P)
4046    * for x = (2**(winsize - 1) + 1) to (2**winsize - 1)
4047    */
4048   for (x = (1 << (winsize - 1)) + 1; x < (1 << winsize); x++) {
4049     if ((err = mp_mul (&M[x - 1], &M[1], &M[x])) != MP_OKAY) {
4050       goto __MU;
4051     }
4052     if ((err = mp_reduce (&M[x], P, &mu)) != MP_OKAY) {
4053       goto __MU;
4054     }
4055   }
4056
4057   /* setup result */
4058   if ((err = mp_init (&res)) != MP_OKAY) {
4059     goto __MU;
4060   }
4061   mp_set (&res, 1);
4062
4063   /* set initial mode and bit cnt */
4064   mode   = 0;
4065   bitcnt = 1;
4066   buf    = 0;
4067   digidx = X->used - 1;
4068   bitcpy = 0;
4069   bitbuf = 0;
4070
4071   for (;;) {
4072     /* grab next digit as required */
4073     if (--bitcnt == 0) {
4074       /* if digidx == -1 we are out of digits */
4075       if (digidx == -1) {
4076         break;
4077       }
4078       /* read next digit and reset the bitcnt */
4079       buf    = X->dp[digidx--];
4080       bitcnt = DIGIT_BIT;
4081     }
4082
4083     /* grab the next msb from the exponent */
4084     y     = (buf >> (mp_digit)(DIGIT_BIT - 1)) & 1;
4085     buf <<= (mp_digit)1;
4086
4087     /* if the bit is zero and mode == 0 then we ignore it
4088      * These represent the leading zero bits before the first 1 bit
4089      * in the exponent.  Technically this opt is not required but it
4090      * does lower the # of trivial squaring/reductions used
4091      */
4092     if (mode == 0 && y == 0) {
4093       continue;
4094     }
4095
4096     /* if the bit is zero and mode == 1 then we square */
4097     if (mode == 1 && y == 0) {
4098       if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
4099         goto __RES;
4100       }
4101       if ((err = mp_reduce (&res, P, &mu)) != MP_OKAY) {
4102         goto __RES;
4103       }
4104       continue;
4105     }
4106
4107     /* else we add it to the window */
4108     bitbuf |= (y << (winsize - ++bitcpy));
4109     mode    = 2;
4110
4111     if (bitcpy == winsize) {
4112       /* ok window is filled so square as required and multiply  */
4113       /* square first */
4114       for (x = 0; x < winsize; x++) {
4115         if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
4116           goto __RES;
4117         }
4118         if ((err = mp_reduce (&res, P, &mu)) != MP_OKAY) {
4119           goto __RES;
4120         }
4121       }
4122
4123       /* then multiply */
4124       if ((err = mp_mul (&res, &M[bitbuf], &res)) != MP_OKAY) {
4125         goto __RES;
4126       }
4127       if ((err = mp_reduce (&res, P, &mu)) != MP_OKAY) {
4128         goto __RES;
4129       }
4130
4131       /* empty window and reset */
4132       bitcpy = 0;
4133       bitbuf = 0;
4134       mode   = 1;
4135     }
4136   }
4137
4138   /* if bits remain then square/multiply */
4139   if (mode == 2 && bitcpy > 0) {
4140     /* square then multiply if the bit is set */
4141     for (x = 0; x < bitcpy; x++) {
4142       if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
4143         goto __RES;
4144       }
4145       if ((err = mp_reduce (&res, P, &mu)) != MP_OKAY) {
4146         goto __RES;
4147       }
4148
4149       bitbuf <<= 1;
4150       if ((bitbuf & (1 << winsize)) != 0) {
4151         /* then multiply */
4152         if ((err = mp_mul (&res, &M[1], &res)) != MP_OKAY) {
4153           goto __RES;
4154         }
4155         if ((err = mp_reduce (&res, P, &mu)) != MP_OKAY) {
4156           goto __RES;
4157         }
4158       }
4159     }
4160   }
4161
4162   mp_exch (&res, Y);
4163   err = MP_OKAY;
4164 __RES:mp_clear (&res);
4165 __MU:mp_clear (&mu);
4166 __M:
4167   mp_clear(&M[1]);
4168   for (x = 1<<(winsize-1); x < (1 << winsize); x++) {
4169     mp_clear (&M[x]);
4170   }
4171   return err;
4172 }
4173
4174 /* multiplies |a| * |b| and only computes up to digs digits of result
4175  * HAC pp. 595, Algorithm 14.12  Modified so you can control how 
4176  * many digits of output are created.
4177  */
4178 static int
4179 s_mp_mul_digs (const mp_int * a, const mp_int * b, mp_int * c, int digs)
4180 {
4181   mp_int  t;
4182   int     res, pa, pb, ix, iy;
4183   mp_digit u;
4184   mp_word r;
4185   mp_digit tmpx, *tmpt, *tmpy;
4186
4187   /* can we use the fast multiplier? */
4188   if (((digs) < MP_WARRAY) &&
4189       MIN (a->used, b->used) < 
4190           (1 << ((CHAR_BIT * sizeof (mp_word)) - (2 * DIGIT_BIT)))) {
4191     return fast_s_mp_mul_digs (a, b, c, digs);
4192   }
4193
4194   if ((res = mp_init_size (&t, digs)) != MP_OKAY) {
4195     return res;
4196   }
4197   t.used = digs;
4198
4199   /* compute the digits of the product directly */
4200   pa = a->used;
4201   for (ix = 0; ix < pa; ix++) {
4202     /* set the carry to zero */
4203     u = 0;
4204
4205     /* limit ourselves to making digs digits of output */
4206     pb = MIN (b->used, digs - ix);
4207
4208     /* setup some aliases */
4209     /* copy of the digit from a used within the nested loop */
4210     tmpx = a->dp[ix];
4211     
4212     /* an alias for the destination shifted ix places */
4213     tmpt = t.dp + ix;
4214     
4215     /* an alias for the digits of b */
4216     tmpy = b->dp;
4217
4218     /* compute the columns of the output and propagate the carry */
4219     for (iy = 0; iy < pb; iy++) {
4220       /* compute the column as a mp_word */
4221       r       = ((mp_word)*tmpt) +
4222                 ((mp_word)tmpx) * ((mp_word)*tmpy++) +
4223                 ((mp_word) u);
4224
4225       /* the new column is the lower part of the result */
4226       *tmpt++ = (mp_digit) (r & ((mp_word) MP_MASK));
4227
4228       /* get the carry word from the result */
4229       u       = (mp_digit) (r >> ((mp_word) DIGIT_BIT));
4230     }
4231     /* set carry if it is placed below digs */
4232     if (ix + iy < digs) {
4233       *tmpt = u;
4234     }
4235   }
4236
4237   mp_clamp (&t);
4238   mp_exch (&t, c);
4239
4240   mp_clear (&t);
4241   return MP_OKAY;
4242 }
4243
4244 /* multiplies |a| * |b| and does not compute the lower digs digits
4245  * [meant to get the higher part of the product]
4246  */
4247 static int
4248 s_mp_mul_high_digs (const mp_int * a, const mp_int * b, mp_int * c, int digs)
4249 {
4250   mp_int  t;
4251   int     res, pa, pb, ix, iy;
4252   mp_digit u;
4253   mp_word r;
4254   mp_digit tmpx, *tmpt, *tmpy;
4255
4256   /* can we use the fast multiplier? */
4257   if (((a->used + b->used + 1) < MP_WARRAY)
4258       && MIN (a->used, b->used) < (1 << ((CHAR_BIT * sizeof (mp_word)) - (2 * DIGIT_BIT)))) {
4259     return fast_s_mp_mul_high_digs (a, b, c, digs);
4260   }
4261
4262   if ((res = mp_init_size (&t, a->used + b->used + 1)) != MP_OKAY) {
4263     return res;
4264   }
4265   t.used = a->used + b->used + 1;
4266
4267   pa = a->used;
4268   pb = b->used;
4269   for (ix = 0; ix < pa; ix++) {
4270     /* clear the carry */
4271     u = 0;
4272
4273     /* left hand side of A[ix] * B[iy] */
4274     tmpx = a->dp[ix];
4275
4276     /* alias to the address of where the digits will be stored */
4277     tmpt = &(t.dp[digs]);
4278
4279     /* alias for where to read the right hand side from */
4280     tmpy = b->dp + (digs - ix);
4281
4282     for (iy = digs - ix; iy < pb; iy++) {
4283       /* calculate the double precision result */
4284       r       = ((mp_word)*tmpt) +
4285                 ((mp_word)tmpx) * ((mp_word)*tmpy++) +
4286                 ((mp_word) u);
4287
4288       /* get the lower part */
4289       *tmpt++ = (mp_digit) (r & ((mp_word) MP_MASK));
4290
4291       /* carry the carry */
4292       u       = (mp_digit) (r >> ((mp_word) DIGIT_BIT));
4293     }
4294     *tmpt = u;
4295   }
4296   mp_clamp (&t);
4297   mp_exch (&t, c);
4298   mp_clear (&t);
4299   return MP_OKAY;
4300 }
4301
4302 /* low level squaring, b = a*a, HAC pp.596-597, Algorithm 14.16 */
4303 static int
4304 s_mp_sqr (const mp_int * a, mp_int * b)
4305 {
4306   mp_int  t;
4307   int     res, ix, iy, pa;
4308   mp_word r;
4309   mp_digit u, tmpx, *tmpt;
4310
4311   pa = a->used;
4312   if ((res = mp_init_size (&t, 2*pa + 1)) != MP_OKAY) {
4313     return res;
4314   }
4315
4316   /* default used is maximum possible size */
4317   t.used = 2*pa + 1;
4318
4319   for (ix = 0; ix < pa; ix++) {
4320     /* first calculate the digit at 2*ix */
4321     /* calculate double precision result */
4322     r = ((mp_word) t.dp[2*ix]) +
4323         ((mp_word)a->dp[ix])*((mp_word)a->dp[ix]);
4324
4325     /* store lower part in result */
4326     t.dp[ix+ix] = (mp_digit) (r & ((mp_word) MP_MASK));
4327
4328     /* get the carry */
4329     u           = (mp_digit)(r >> ((mp_word) DIGIT_BIT));
4330
4331     /* left hand side of A[ix] * A[iy] */
4332     tmpx        = a->dp[ix];
4333
4334     /* alias for where to store the results */
4335     tmpt        = t.dp + (2*ix + 1);
4336     
4337     for (iy = ix + 1; iy < pa; iy++) {
4338       /* first calculate the product */
4339       r       = ((mp_word)tmpx) * ((mp_word)a->dp[iy]);
4340
4341       /* now calculate the double precision result, note we use
4342        * addition instead of *2 since it's easier to optimize
4343        */
4344       r       = ((mp_word) *tmpt) + r + r + ((mp_word) u);
4345
4346       /* store lower part */
4347       *tmpt++ = (mp_digit) (r & ((mp_word) MP_MASK));
4348
4349       /* get carry */
4350       u       = (mp_digit)(r >> ((mp_word) DIGIT_BIT));
4351     }
4352     /* propagate upwards */
4353     while (u != 0) {
4354       r       = ((mp_word) *tmpt) + ((mp_word) u);
4355       *tmpt++ = (mp_digit) (r & ((mp_word) MP_MASK));
4356       u       = (mp_digit)(r >> ((mp_word) DIGIT_BIT));
4357     }
4358   }
4359
4360   mp_clamp (&t);
4361   mp_exch (&t, b);
4362   mp_clear (&t);
4363   return MP_OKAY;
4364 }
4365
4366 /* low level subtraction (assumes |a| > |b|), HAC pp.595 Algorithm 14.9 */
4367 int
4368 s_mp_sub (const mp_int * a, const mp_int * b, mp_int * c)
4369 {
4370   int     olduse, res, min, max;
4371
4372   /* find sizes */
4373   min = b->used;
4374   max = a->used;
4375
4376   /* init result */
4377   if (c->alloc < max) {
4378     if ((res = mp_grow (c, max)) != MP_OKAY) {
4379       return res;
4380     }
4381   }
4382   olduse = c->used;
4383   c->used = max;
4384
4385   {
4386     register mp_digit u, *tmpa, *tmpb, *tmpc;
4387     register int i;
4388
4389     /* alias for digit pointers */
4390     tmpa = a->dp;
4391     tmpb = b->dp;
4392     tmpc = c->dp;
4393
4394     /* set carry to zero */
4395     u = 0;
4396     for (i = 0; i < min; i++) {
4397       /* T[i] = A[i] - B[i] - U */
4398       *tmpc = *tmpa++ - *tmpb++ - u;
4399
4400       /* U = carry bit of T[i]
4401        * Note this saves performing an AND operation since
4402        * if a carry does occur it will propagate all the way to the
4403        * MSB.  As a result a single shift is enough to get the carry
4404        */
4405       u = *tmpc >> ((mp_digit)(CHAR_BIT * sizeof (mp_digit) - 1));
4406
4407       /* Clear carry from T[i] */
4408       *tmpc++ &= MP_MASK;
4409     }
4410
4411     /* now copy higher words if any, e.g. if A has more digits than B  */
4412     for (; i < max; i++) {
4413       /* T[i] = A[i] - U */
4414       *tmpc = *tmpa++ - u;
4415
4416       /* U = carry bit of T[i] */
4417       u = *tmpc >> ((mp_digit)(CHAR_BIT * sizeof (mp_digit) - 1));
4418
4419       /* Clear carry from T[i] */
4420       *tmpc++ &= MP_MASK;
4421     }
4422
4423     /* clear digits above used (since we may not have grown result above) */
4424     for (i = c->used; i < olduse; i++) {
4425       *tmpc++ = 0;
4426     }
4427   }
4428
4429   mp_clamp (c);
4430   return MP_OKAY;
4431 }