winhttp/tests: Fix a test failure on some W2K/XP systems.
[wine] / dlls / rsaenh / mpi.c
1 /*
2  * dlls/rsaenh/mpi.c
3  * Multi Precision Integer functions
4  *
5  * Copyright 2004 Michael Jung
6  * Based on public domain code by Tom St Denis (tomstdenis@iahu.ca)
7  *
8  * This library is free software; you can redistribute it and/or
9  * modify it under the terms of the GNU Lesser General Public
10  * License as published by the Free Software Foundation; either
11  * version 2.1 of the License, or (at your option) any later version.
12  *
13  * This library is distributed in the hope that it will be useful,
14  * but WITHOUT ANY WARRANTY; without even the implied warranty of
15  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
16  * Lesser General Public License for more details.
17  *
18  * You should have received a copy of the GNU Lesser General Public
19  * License along with this library; if not, write to the Free Software
20  * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
21  */
22
23 /*
24  * This file contains code from the LibTomCrypt cryptographic 
25  * library written by Tom St Denis (tomstdenis@iahu.ca). LibTomCrypt
26  * is in the public domain. The code in this file is tailored to
27  * special requirements. Take a look at http://libtomcrypt.org for the
28  * original version. 
29  */
30
31 #include <stdarg.h>
32
33 #include "windef.h"
34 #include "winbase.h"
35 #include "tomcrypt.h"
36
37 /* Known optimal configurations
38  CPU                    /Compiler     /MUL CUTOFF/SQR CUTOFF
39 -------------------------------------------------------------
40  Intel P4 Northwood     /GCC v3.4.1   /        88/       128/LTM 0.32 ;-)
41 */
42 static const int KARATSUBA_MUL_CUTOFF = 88,  /* Min. number of digits before Karatsuba multiplication is used. */
43                  KARATSUBA_SQR_CUTOFF = 128; /* Min. number of digits before Karatsuba squaring is used. */
44
45
46 /* trim unused digits */
47 static void mp_clamp(mp_int *a);
48
49 /* compare |a| to |b| */
50 static int mp_cmp_mag(const mp_int *a, const mp_int *b);
51
52 /* Counts the number of lsbs which are zero before the first zero bit */
53 static int mp_cnt_lsb(const mp_int *a);
54
55 /* computes a = B**n mod b without division or multiplication useful for
56  * normalizing numbers in a Montgomery system.
57  */
58 static int mp_montgomery_calc_normalization(mp_int *a, const mp_int *b);
59
60 /* computes x/R == x (mod N) via Montgomery Reduction */
61 static int mp_montgomery_reduce(mp_int *a, const mp_int *m, mp_digit mp);
62
63 /* setups the montgomery reduction */
64 static int mp_montgomery_setup(const mp_int *a, mp_digit *mp);
65
66 /* Barrett Reduction, computes a (mod b) with a precomputed value c
67  *
68  * Assumes that 0 < a <= b*b, note if 0 > a > -(b*b) then you can merely
69  * compute the reduction as -1 * mp_reduce(mp_abs(a)) [pseudo code].
70  */
71 static int mp_reduce(mp_int *a, const mp_int *b, const mp_int *c);
72
73 /* reduces a modulo b where b is of the form 2**p - k [0 <= a] */
74 static int mp_reduce_2k(mp_int *a, const mp_int *n, mp_digit d);
75
76 /* determines k value for 2k reduction */
77 static int mp_reduce_2k_setup(const mp_int *a, mp_digit *d);
78
79 /* used to setup the Barrett reduction for a given modulus b */
80 static int mp_reduce_setup(mp_int *a, const mp_int *b);
81
82 /* set to a digit */
83 static void mp_set(mp_int *a, mp_digit b);
84
85 /* b = a*a  */
86 static int mp_sqr(const mp_int *a, mp_int *b);
87
88 /* c = a * a (mod b) */
89 static int mp_sqrmod(const mp_int *a, mp_int *b, mp_int *c);
90
91
92 static void bn_reverse(unsigned char *s, int len);
93 static int s_mp_add(mp_int *a, mp_int *b, mp_int *c);
94 static int s_mp_exptmod (const mp_int * G, const mp_int * X, mp_int * P, mp_int * Y);
95 #define s_mp_mul(a, b, c) s_mp_mul_digs(a, b, c, (a)->used + (b)->used + 1)
96 static int s_mp_mul_digs(const mp_int *a, const mp_int *b, mp_int *c, int digs);
97 static int s_mp_mul_high_digs(const mp_int *a, const mp_int *b, mp_int *c, int digs);
98 static int s_mp_sqr(const mp_int *a, mp_int *b);
99 static int s_mp_sub(const mp_int *a, const mp_int *b, mp_int *c);
100 static int mp_exptmod_fast(const mp_int *G, const mp_int *X, mp_int *P, mp_int *Y, int mode);
101 static int mp_invmod_slow (const mp_int * a, mp_int * b, mp_int * c);
102 static int mp_karatsuba_mul(const mp_int *a, const mp_int *b, mp_int *c);
103 static int mp_karatsuba_sqr(const mp_int *a, mp_int *b);
104
105 /* grow as required */
106 static int mp_grow (mp_int * a, int size)
107 {
108   int     i;
109   mp_digit *tmp;
110
111   /* if the alloc size is smaller alloc more ram */
112   if (a->alloc < size) {
113     /* ensure there are always at least MP_PREC digits extra on top */
114     size += (MP_PREC * 2) - (size % MP_PREC);
115
116     /* reallocate the array a->dp
117      *
118      * We store the return in a temporary variable
119      * in case the operation failed we don't want
120      * to overwrite the dp member of a.
121      */
122     tmp = HeapReAlloc(GetProcessHeap(), 0, a->dp, sizeof (mp_digit) * size);
123     if (tmp == NULL) {
124       /* reallocation failed but "a" is still valid [can be freed] */
125       return MP_MEM;
126     }
127
128     /* reallocation succeeded so set a->dp */
129     a->dp = tmp;
130
131     /* zero excess digits */
132     i        = a->alloc;
133     a->alloc = size;
134     for (; i < a->alloc; i++) {
135       a->dp[i] = 0;
136     }
137   }
138   return MP_OKAY;
139 }
140
141 /* b = a/2 */
142 static int mp_div_2(const mp_int * a, mp_int * b)
143 {
144   int     x, res, oldused;
145
146   /* copy */
147   if (b->alloc < a->used) {
148     if ((res = mp_grow (b, a->used)) != MP_OKAY) {
149       return res;
150     }
151   }
152
153   oldused = b->used;
154   b->used = a->used;
155   {
156     register mp_digit r, rr, *tmpa, *tmpb;
157
158     /* source alias */
159     tmpa = a->dp + b->used - 1;
160
161     /* dest alias */
162     tmpb = b->dp + b->used - 1;
163
164     /* carry */
165     r = 0;
166     for (x = b->used - 1; x >= 0; x--) {
167       /* get the carry for the next iteration */
168       rr = *tmpa & 1;
169
170       /* shift the current digit, add in carry and store */
171       *tmpb-- = (*tmpa-- >> 1) | (r << (DIGIT_BIT - 1));
172
173       /* forward carry to next iteration */
174       r = rr;
175     }
176
177     /* zero excess digits */
178     tmpb = b->dp + b->used;
179     for (x = b->used; x < oldused; x++) {
180       *tmpb++ = 0;
181     }
182   }
183   b->sign = a->sign;
184   mp_clamp (b);
185   return MP_OKAY;
186 }
187
188 /* swap the elements of two integers, for cases where you can't simply swap the
189  * mp_int pointers around
190  */
191 static void
192 mp_exch (mp_int * a, mp_int * b)
193 {
194   mp_int  t;
195
196   t  = *a;
197   *a = *b;
198   *b = t;
199 }
200
201 /* init a new mp_int */
202 static int mp_init (mp_int * a)
203 {
204   int i;
205
206   /* allocate memory required and clear it */
207   a->dp = HeapAlloc(GetProcessHeap(), 0, sizeof (mp_digit) * MP_PREC);
208   if (a->dp == NULL) {
209     return MP_MEM;
210   }
211
212   /* set the digits to zero */
213   for (i = 0; i < MP_PREC; i++) {
214       a->dp[i] = 0;
215   }
216
217   /* set the used to zero, allocated digits to the default precision
218    * and sign to positive */
219   a->used  = 0;
220   a->alloc = MP_PREC;
221   a->sign  = MP_ZPOS;
222
223   return MP_OKAY;
224 }
225
226 /* init an mp_init for a given size */
227 static int mp_init_size (mp_int * a, int size)
228 {
229   int x;
230
231   /* pad size so there are always extra digits */
232   size += (MP_PREC * 2) - (size % MP_PREC);
233
234   /* alloc mem */
235   a->dp = HeapAlloc(GetProcessHeap(), 0, sizeof (mp_digit) * size);
236   if (a->dp == NULL) {
237     return MP_MEM;
238   }
239
240   /* set the members */
241   a->used  = 0;
242   a->alloc = size;
243   a->sign  = MP_ZPOS;
244
245   /* zero the digits */
246   for (x = 0; x < size; x++) {
247       a->dp[x] = 0;
248   }
249
250   return MP_OKAY;
251 }
252
253 /* clear one (frees)  */
254 static void
255 mp_clear (mp_int * a)
256 {
257   int i;
258
259   /* only do anything if a hasn't been freed previously */
260   if (a->dp != NULL) {
261     /* first zero the digits */
262     for (i = 0; i < a->used; i++) {
263         a->dp[i] = 0;
264     }
265
266     /* free ram */
267     HeapFree(GetProcessHeap(), 0, a->dp);
268
269     /* reset members to make debugging easier */
270     a->dp    = NULL;
271     a->alloc = a->used = 0;
272     a->sign  = MP_ZPOS;
273   }
274 }
275
276 /* set to zero */
277 static void
278 mp_zero (mp_int * a)
279 {
280   a->sign = MP_ZPOS;
281   a->used = 0;
282   memset (a->dp, 0, sizeof (mp_digit) * a->alloc);
283 }
284
285 /* b = |a|
286  *
287  * Simple function copies the input and fixes the sign to positive
288  */
289 static int
290 mp_abs (const mp_int * a, mp_int * b)
291 {
292   int     res;
293
294   /* copy a to b */
295   if (a != b) {
296      if ((res = mp_copy (a, b)) != MP_OKAY) {
297        return res;
298      }
299   }
300
301   /* force the sign of b to positive */
302   b->sign = MP_ZPOS;
303
304   return MP_OKAY;
305 }
306
307 /* computes the modular inverse via binary extended euclidean algorithm, 
308  * that is c = 1/a mod b 
309  *
310  * Based on slow invmod except this is optimized for the case where b is 
311  * odd as per HAC Note 14.64 on pp. 610
312  */
313 static int
314 fast_mp_invmod (const mp_int * a, mp_int * b, mp_int * c)
315 {
316   mp_int  x, y, u, v, B, D;
317   int     res, neg;
318
319   /* 2. [modified] b must be odd   */
320   if (mp_iseven (b) == 1) {
321     return MP_VAL;
322   }
323
324   /* init all our temps */
325   if ((res = mp_init_multi(&x, &y, &u, &v, &B, &D, NULL)) != MP_OKAY) {
326      return res;
327   }
328
329   /* x == modulus, y == value to invert */
330   if ((res = mp_copy (b, &x)) != MP_OKAY) {
331     goto __ERR;
332   }
333
334   /* we need y = |a| */
335   if ((res = mp_abs (a, &y)) != MP_OKAY) {
336     goto __ERR;
337   }
338
339   /* 3. u=x, v=y, A=1, B=0, C=0,D=1 */
340   if ((res = mp_copy (&x, &u)) != MP_OKAY) {
341     goto __ERR;
342   }
343   if ((res = mp_copy (&y, &v)) != MP_OKAY) {
344     goto __ERR;
345   }
346   mp_set (&D, 1);
347
348 top:
349   /* 4.  while u is even do */
350   while (mp_iseven (&u) == 1) {
351     /* 4.1 u = u/2 */
352     if ((res = mp_div_2 (&u, &u)) != MP_OKAY) {
353       goto __ERR;
354     }
355     /* 4.2 if B is odd then */
356     if (mp_isodd (&B) == 1) {
357       if ((res = mp_sub (&B, &x, &B)) != MP_OKAY) {
358         goto __ERR;
359       }
360     }
361     /* B = B/2 */
362     if ((res = mp_div_2 (&B, &B)) != MP_OKAY) {
363       goto __ERR;
364     }
365   }
366
367   /* 5.  while v is even do */
368   while (mp_iseven (&v) == 1) {
369     /* 5.1 v = v/2 */
370     if ((res = mp_div_2 (&v, &v)) != MP_OKAY) {
371       goto __ERR;
372     }
373     /* 5.2 if D is odd then */
374     if (mp_isodd (&D) == 1) {
375       /* D = (D-x)/2 */
376       if ((res = mp_sub (&D, &x, &D)) != MP_OKAY) {
377         goto __ERR;
378       }
379     }
380     /* D = D/2 */
381     if ((res = mp_div_2 (&D, &D)) != MP_OKAY) {
382       goto __ERR;
383     }
384   }
385
386   /* 6.  if u >= v then */
387   if (mp_cmp (&u, &v) != MP_LT) {
388     /* u = u - v, B = B - D */
389     if ((res = mp_sub (&u, &v, &u)) != MP_OKAY) {
390       goto __ERR;
391     }
392
393     if ((res = mp_sub (&B, &D, &B)) != MP_OKAY) {
394       goto __ERR;
395     }
396   } else {
397     /* v - v - u, D = D - B */
398     if ((res = mp_sub (&v, &u, &v)) != MP_OKAY) {
399       goto __ERR;
400     }
401
402     if ((res = mp_sub (&D, &B, &D)) != MP_OKAY) {
403       goto __ERR;
404     }
405   }
406
407   /* if not zero goto step 4 */
408   if (mp_iszero (&u) == 0) {
409     goto top;
410   }
411
412   /* now a = C, b = D, gcd == g*v */
413
414   /* if v != 1 then there is no inverse */
415   if (mp_cmp_d (&v, 1) != MP_EQ) {
416     res = MP_VAL;
417     goto __ERR;
418   }
419
420   /* b is now the inverse */
421   neg = a->sign;
422   while (D.sign == MP_NEG) {
423     if ((res = mp_add (&D, b, &D)) != MP_OKAY) {
424       goto __ERR;
425     }
426   }
427   mp_exch (&D, c);
428   c->sign = neg;
429   res = MP_OKAY;
430
431 __ERR:mp_clear_multi (&x, &y, &u, &v, &B, &D, NULL);
432   return res;
433 }
434
435 /* computes xR**-1 == x (mod N) via Montgomery Reduction
436  *
437  * This is an optimized implementation of montgomery_reduce
438  * which uses the comba method to quickly calculate the columns of the
439  * reduction.
440  *
441  * Based on Algorithm 14.32 on pp.601 of HAC.
442 */
443 static int
444 fast_mp_montgomery_reduce (mp_int * x, const mp_int * n, mp_digit rho)
445 {
446   int     ix, res, olduse;
447   mp_word W[MP_WARRAY];
448
449   /* get old used count */
450   olduse = x->used;
451
452   /* grow a as required */
453   if (x->alloc < n->used + 1) {
454     if ((res = mp_grow (x, n->used + 1)) != MP_OKAY) {
455       return res;
456     }
457   }
458
459   /* first we have to get the digits of the input into
460    * an array of double precision words W[...]
461    */
462   {
463     register mp_word *_W;
464     register mp_digit *tmpx;
465
466     /* alias for the W[] array */
467     _W   = W;
468
469     /* alias for the digits of  x*/
470     tmpx = x->dp;
471
472     /* copy the digits of a into W[0..a->used-1] */
473     for (ix = 0; ix < x->used; ix++) {
474       *_W++ = *tmpx++;
475     }
476
477     /* zero the high words of W[a->used..m->used*2] */
478     for (; ix < n->used * 2 + 1; ix++) {
479       *_W++ = 0;
480     }
481   }
482
483   /* now we proceed to zero successive digits
484    * from the least significant upwards
485    */
486   for (ix = 0; ix < n->used; ix++) {
487     /* mu = ai * m' mod b
488      *
489      * We avoid a double precision multiplication (which isn't required)
490      * by casting the value down to a mp_digit.  Note this requires
491      * that W[ix-1] have  the carry cleared (see after the inner loop)
492      */
493     register mp_digit mu;
494     mu = (mp_digit) (((W[ix] & MP_MASK) * rho) & MP_MASK);
495
496     /* a = a + mu * m * b**i
497      *
498      * This is computed in place and on the fly.  The multiplication
499      * by b**i is handled by offsetting which columns the results
500      * are added to.
501      *
502      * Note the comba method normally doesn't handle carries in the
503      * inner loop In this case we fix the carry from the previous
504      * column since the Montgomery reduction requires digits of the
505      * result (so far) [see above] to work.  This is
506      * handled by fixing up one carry after the inner loop.  The
507      * carry fixups are done in order so after these loops the
508      * first m->used words of W[] have the carries fixed
509      */
510     {
511       register int iy;
512       register mp_digit *tmpn;
513       register mp_word *_W;
514
515       /* alias for the digits of the modulus */
516       tmpn = n->dp;
517
518       /* Alias for the columns set by an offset of ix */
519       _W = W + ix;
520
521       /* inner loop */
522       for (iy = 0; iy < n->used; iy++) {
523           *_W++ += ((mp_word)mu) * ((mp_word)*tmpn++);
524       }
525     }
526
527     /* now fix carry for next digit, W[ix+1] */
528     W[ix + 1] += W[ix] >> ((mp_word) DIGIT_BIT);
529   }
530
531   /* now we have to propagate the carries and
532    * shift the words downward [all those least
533    * significant digits we zeroed].
534    */
535   {
536     register mp_digit *tmpx;
537     register mp_word *_W, *_W1;
538
539     /* nox fix rest of carries */
540
541     /* alias for current word */
542     _W1 = W + ix;
543
544     /* alias for next word, where the carry goes */
545     _W = W + ++ix;
546
547     for (; ix <= n->used * 2 + 1; ix++) {
548       *_W++ += *_W1++ >> ((mp_word) DIGIT_BIT);
549     }
550
551     /* copy out, A = A/b**n
552      *
553      * The result is A/b**n but instead of converting from an
554      * array of mp_word to mp_digit than calling mp_rshd
555      * we just copy them in the right order
556      */
557
558     /* alias for destination word */
559     tmpx = x->dp;
560
561     /* alias for shifted double precision result */
562     _W = W + n->used;
563
564     for (ix = 0; ix < n->used + 1; ix++) {
565       *tmpx++ = (mp_digit)(*_W++ & ((mp_word) MP_MASK));
566     }
567
568     /* zero oldused digits, if the input a was larger than
569      * m->used+1 we'll have to clear the digits
570      */
571     for (; ix < olduse; ix++) {
572       *tmpx++ = 0;
573     }
574   }
575
576   /* set the max used and clamp */
577   x->used = n->used + 1;
578   mp_clamp (x);
579
580   /* if A >= m then A = A - m */
581   if (mp_cmp_mag (x, n) != MP_LT) {
582     return s_mp_sub (x, n, x);
583   }
584   return MP_OKAY;
585 }
586
587 /* Fast (comba) multiplier
588  *
589  * This is the fast column-array [comba] multiplier.  It is 
590  * designed to compute the columns of the product first 
591  * then handle the carries afterwards.  This has the effect 
592  * of making the nested loops that compute the columns very
593  * simple and schedulable on super-scalar processors.
594  *
595  * This has been modified to produce a variable number of 
596  * digits of output so if say only a half-product is required 
597  * you don't have to compute the upper half (a feature 
598  * required for fast Barrett reduction).
599  *
600  * Based on Algorithm 14.12 on pp.595 of HAC.
601  *
602  */
603 static int
604 fast_s_mp_mul_digs (const mp_int * a, const mp_int * b, mp_int * c, int digs)
605 {
606   int     olduse, res, pa, ix, iz;
607   mp_digit W[MP_WARRAY];
608   register mp_word  _W;
609
610   /* grow the destination as required */
611   if (c->alloc < digs) {
612     if ((res = mp_grow (c, digs)) != MP_OKAY) {
613       return res;
614     }
615   }
616
617   /* number of output digits to produce */
618   pa = MIN(digs, a->used + b->used);
619
620   /* clear the carry */
621   _W = 0;
622   for (ix = 0; ix <= pa; ix++) { 
623       int      tx, ty;
624       int      iy;
625       mp_digit *tmpx, *tmpy;
626
627       /* get offsets into the two bignums */
628       ty = MIN(b->used-1, ix);
629       tx = ix - ty;
630
631       /* setup temp aliases */
632       tmpx = a->dp + tx;
633       tmpy = b->dp + ty;
634
635       /* This is the number of times the loop will iterate, essentially it's
636          while (tx++ < a->used && ty-- >= 0) { ... }
637        */
638       iy = MIN(a->used-tx, ty+1);
639
640       /* execute loop */
641       for (iz = 0; iz < iy; ++iz) {
642          _W += ((mp_word)*tmpx++)*((mp_word)*tmpy--);
643       }
644
645       /* store term */
646       W[ix] = ((mp_digit)_W) & MP_MASK;
647
648       /* make next carry */
649       _W = _W >> ((mp_word)DIGIT_BIT);
650   }
651
652   /* setup dest */
653   olduse  = c->used;
654   c->used = digs;
655
656   {
657     register mp_digit *tmpc;
658     tmpc = c->dp;
659     for (ix = 0; ix < digs; ix++) {
660       /* now extract the previous digit [below the carry] */
661       *tmpc++ = W[ix];
662     }
663
664     /* clear unused digits [that existed in the old copy of c] */
665     for (; ix < olduse; ix++) {
666       *tmpc++ = 0;
667     }
668   }
669   mp_clamp (c);
670   return MP_OKAY;
671 }
672
673 /* this is a modified version of fast_s_mul_digs that only produces
674  * output digits *above* digs.  See the comments for fast_s_mul_digs
675  * to see how it works.
676  *
677  * This is used in the Barrett reduction since for one of the multiplications
678  * only the higher digits were needed.  This essentially halves the work.
679  *
680  * Based on Algorithm 14.12 on pp.595 of HAC.
681  */
682 static int
683 fast_s_mp_mul_high_digs (const mp_int * a, const mp_int * b, mp_int * c, int digs)
684 {
685   int     olduse, res, pa, ix, iz;
686   mp_digit W[MP_WARRAY];
687   mp_word  _W;
688
689   /* grow the destination as required */
690   pa = a->used + b->used;
691   if (c->alloc < pa) {
692     if ((res = mp_grow (c, pa)) != MP_OKAY) {
693       return res;
694     }
695   }
696
697   /* number of output digits to produce */
698   pa = a->used + b->used;
699   _W = 0;
700   for (ix = digs; ix <= pa; ix++) { 
701       int      tx, ty, iy;
702       mp_digit *tmpx, *tmpy;
703
704       /* get offsets into the two bignums */
705       ty = MIN(b->used-1, ix);
706       tx = ix - ty;
707
708       /* setup temp aliases */
709       tmpx = a->dp + tx;
710       tmpy = b->dp + ty;
711
712       /* This is the number of times the loop will iterate, essentially it's
713          while (tx++ < a->used && ty-- >= 0) { ... }
714        */
715       iy = MIN(a->used-tx, ty+1);
716
717       /* execute loop */
718       for (iz = 0; iz < iy; iz++) {
719          _W += ((mp_word)*tmpx++)*((mp_word)*tmpy--);
720       }
721
722       /* store term */
723       W[ix] = ((mp_digit)_W) & MP_MASK;
724
725       /* make next carry */
726       _W = _W >> ((mp_word)DIGIT_BIT);
727   }
728
729   /* setup dest */
730   olduse  = c->used;
731   c->used = pa;
732
733   {
734     register mp_digit *tmpc;
735
736     tmpc = c->dp + digs;
737     for (ix = digs; ix <= pa; ix++) {
738       /* now extract the previous digit [below the carry] */
739       *tmpc++ = W[ix];
740     }
741
742     /* clear unused digits [that existed in the old copy of c] */
743     for (; ix < olduse; ix++) {
744       *tmpc++ = 0;
745     }
746   }
747   mp_clamp (c);
748   return MP_OKAY;
749 }
750
751 /* fast squaring
752  *
753  * This is the comba method where the columns of the product
754  * are computed first then the carries are computed.  This
755  * has the effect of making a very simple inner loop that
756  * is executed the most
757  *
758  * W2 represents the outer products and W the inner.
759  *
760  * A further optimizations is made because the inner
761  * products are of the form "A * B * 2".  The *2 part does
762  * not need to be computed until the end which is good
763  * because 64-bit shifts are slow!
764  *
765  * Based on Algorithm 14.16 on pp.597 of HAC.
766  *
767  */
768 /* the jist of squaring...
769
770 you do like mult except the offset of the tmpx [one that starts closer to zero]
771 can't equal the offset of tmpy.  So basically you set up iy like before then you min it with
772 (ty-tx) so that it never happens.  You double all those you add in the inner loop
773
774 After that loop you do the squares and add them in.
775
776 Remove W2 and don't memset W
777
778 */
779
780 static int fast_s_mp_sqr (const mp_int * a, mp_int * b)
781 {
782   int       olduse, res, pa, ix, iz;
783   mp_digit   W[MP_WARRAY], *tmpx;
784   mp_word   W1;
785
786   /* grow the destination as required */
787   pa = a->used + a->used;
788   if (b->alloc < pa) {
789     if ((res = mp_grow (b, pa)) != MP_OKAY) {
790       return res;
791     }
792   }
793
794   /* number of output digits to produce */
795   W1 = 0;
796   for (ix = 0; ix <= pa; ix++) { 
797       int      tx, ty, iy;
798       mp_word  _W;
799       mp_digit *tmpy;
800
801       /* clear counter */
802       _W = 0;
803
804       /* get offsets into the two bignums */
805       ty = MIN(a->used-1, ix);
806       tx = ix - ty;
807
808       /* setup temp aliases */
809       tmpx = a->dp + tx;
810       tmpy = a->dp + ty;
811
812       /* This is the number of times the loop will iterate, essentially it's
813          while (tx++ < a->used && ty-- >= 0) { ... }
814        */
815       iy = MIN(a->used-tx, ty+1);
816
817       /* now for squaring tx can never equal ty 
818        * we halve the distance since they approach at a rate of 2x
819        * and we have to round because odd cases need to be executed
820        */
821       iy = MIN(iy, (ty-tx+1)>>1);
822
823       /* execute loop */
824       for (iz = 0; iz < iy; iz++) {
825          _W += ((mp_word)*tmpx++)*((mp_word)*tmpy--);
826       }
827
828       /* double the inner product and add carry */
829       _W = _W + _W + W1;
830
831       /* even columns have the square term in them */
832       if ((ix&1) == 0) {
833          _W += ((mp_word)a->dp[ix>>1])*((mp_word)a->dp[ix>>1]);
834       }
835
836       /* store it */
837       W[ix] = _W;
838
839       /* make next carry */
840       W1 = _W >> ((mp_word)DIGIT_BIT);
841   }
842
843   /* setup dest */
844   olduse  = b->used;
845   b->used = a->used+a->used;
846
847   {
848     mp_digit *tmpb;
849     tmpb = b->dp;
850     for (ix = 0; ix < pa; ix++) {
851       *tmpb++ = W[ix] & MP_MASK;
852     }
853
854     /* clear unused digits [that existed in the old copy of c] */
855     for (; ix < olduse; ix++) {
856       *tmpb++ = 0;
857     }
858   }
859   mp_clamp (b);
860   return MP_OKAY;
861 }
862
863 /* computes a = 2**b 
864  *
865  * Simple algorithm which zeroes the int, grows it then just sets one bit
866  * as required.
867  */
868 static int
869 mp_2expt (mp_int * a, int b)
870 {
871   int     res;
872
873   /* zero a as per default */
874   mp_zero (a);
875
876   /* grow a to accommodate the single bit */
877   if ((res = mp_grow (a, b / DIGIT_BIT + 1)) != MP_OKAY) {
878     return res;
879   }
880
881   /* set the used count of where the bit will go */
882   a->used = b / DIGIT_BIT + 1;
883
884   /* put the single bit in its place */
885   a->dp[b / DIGIT_BIT] = ((mp_digit)1) << (b % DIGIT_BIT);
886
887   return MP_OKAY;
888 }
889
890 /* high level addition (handles signs) */
891 int mp_add (mp_int * a, mp_int * b, mp_int * c)
892 {
893   int     sa, sb, res;
894
895   /* get sign of both inputs */
896   sa = a->sign;
897   sb = b->sign;
898
899   /* handle two cases, not four */
900   if (sa == sb) {
901     /* both positive or both negative */
902     /* add their magnitudes, copy the sign */
903     c->sign = sa;
904     res = s_mp_add (a, b, c);
905   } else {
906     /* one positive, the other negative */
907     /* subtract the one with the greater magnitude from */
908     /* the one of the lesser magnitude.  The result gets */
909     /* the sign of the one with the greater magnitude. */
910     if (mp_cmp_mag (a, b) == MP_LT) {
911       c->sign = sb;
912       res = s_mp_sub (b, a, c);
913     } else {
914       c->sign = sa;
915       res = s_mp_sub (a, b, c);
916     }
917   }
918   return res;
919 }
920
921
922 /* single digit addition */
923 static int
924 mp_add_d (mp_int * a, mp_digit b, mp_int * c)
925 {
926   int     res, ix, oldused;
927   mp_digit *tmpa, *tmpc, mu;
928
929   /* grow c as required */
930   if (c->alloc < a->used + 1) {
931      if ((res = mp_grow(c, a->used + 1)) != MP_OKAY) {
932         return res;
933      }
934   }
935
936   /* if a is negative and |a| >= b, call c = |a| - b */
937   if (a->sign == MP_NEG && (a->used > 1 || a->dp[0] >= b)) {
938      /* temporarily fix sign of a */
939      a->sign = MP_ZPOS;
940
941      /* c = |a| - b */
942      res = mp_sub_d(a, b, c);
943
944      /* fix sign  */
945      a->sign = c->sign = MP_NEG;
946
947      return res;
948   }
949
950   /* old number of used digits in c */
951   oldused = c->used;
952
953   /* sign always positive */
954   c->sign = MP_ZPOS;
955
956   /* source alias */
957   tmpa    = a->dp;
958
959   /* destination alias */
960   tmpc    = c->dp;
961
962   /* if a is positive */
963   if (a->sign == MP_ZPOS) {
964      /* add digit, after this we're propagating
965       * the carry.
966       */
967      *tmpc   = *tmpa++ + b;
968      mu      = *tmpc >> DIGIT_BIT;
969      *tmpc++ &= MP_MASK;
970
971      /* now handle rest of the digits */
972      for (ix = 1; ix < a->used; ix++) {
973         *tmpc   = *tmpa++ + mu;
974         mu      = *tmpc >> DIGIT_BIT;
975         *tmpc++ &= MP_MASK;
976      }
977      /* set final carry */
978      ix++;
979      *tmpc++  = mu;
980
981      /* setup size */
982      c->used = a->used + 1;
983   } else {
984      /* a was negative and |a| < b */
985      c->used  = 1;
986
987      /* the result is a single digit */
988      if (a->used == 1) {
989         *tmpc++  =  b - a->dp[0];
990      } else {
991         *tmpc++  =  b;
992      }
993
994      /* setup count so the clearing of oldused
995       * can fall through correctly
996       */
997      ix       = 1;
998   }
999
1000   /* now zero to oldused */
1001   while (ix++ < oldused) {
1002      *tmpc++ = 0;
1003   }
1004   mp_clamp(c);
1005
1006   return MP_OKAY;
1007 }
1008
1009 /* trim unused digits 
1010  *
1011  * This is used to ensure that leading zero digits are
1012  * trimed and the leading "used" digit will be non-zero
1013  * Typically very fast.  Also fixes the sign if there
1014  * are no more leading digits
1015  */
1016 void
1017 mp_clamp (mp_int * a)
1018 {
1019   /* decrease used while the most significant digit is
1020    * zero.
1021    */
1022   while (a->used > 0 && a->dp[a->used - 1] == 0) {
1023     --(a->used);
1024   }
1025
1026   /* reset the sign flag if used == 0 */
1027   if (a->used == 0) {
1028     a->sign = MP_ZPOS;
1029   }
1030 }
1031
1032 void mp_clear_multi(mp_int *mp, ...) 
1033 {
1034     mp_int* next_mp = mp;
1035     va_list args;
1036     va_start(args, mp);
1037     while (next_mp != NULL) {
1038         mp_clear(next_mp);
1039         next_mp = va_arg(args, mp_int*);
1040     }
1041     va_end(args);
1042 }
1043
1044 /* compare two ints (signed)*/
1045 int
1046 mp_cmp (const mp_int * a, const mp_int * b)
1047 {
1048   /* compare based on sign */
1049   if (a->sign != b->sign) {
1050      if (a->sign == MP_NEG) {
1051         return MP_LT;
1052      } else {
1053         return MP_GT;
1054      }
1055   }
1056   
1057   /* compare digits */
1058   if (a->sign == MP_NEG) {
1059      /* if negative compare opposite direction */
1060      return mp_cmp_mag(b, a);
1061   } else {
1062      return mp_cmp_mag(a, b);
1063   }
1064 }
1065
1066 /* compare a digit */
1067 int mp_cmp_d(const mp_int * a, mp_digit b)
1068 {
1069   /* compare based on sign */
1070   if (a->sign == MP_NEG) {
1071     return MP_LT;
1072   }
1073
1074   /* compare based on magnitude */
1075   if (a->used > 1) {
1076     return MP_GT;
1077   }
1078
1079   /* compare the only digit of a to b */
1080   if (a->dp[0] > b) {
1081     return MP_GT;
1082   } else if (a->dp[0] < b) {
1083     return MP_LT;
1084   } else {
1085     return MP_EQ;
1086   }
1087 }
1088
1089 /* compare maginitude of two ints (unsigned) */
1090 int mp_cmp_mag (const mp_int * a, const mp_int * b)
1091 {
1092   int     n;
1093   mp_digit *tmpa, *tmpb;
1094
1095   /* compare based on # of non-zero digits */
1096   if (a->used > b->used) {
1097     return MP_GT;
1098   }
1099   
1100   if (a->used < b->used) {
1101     return MP_LT;
1102   }
1103
1104   /* alias for a */
1105   tmpa = a->dp + (a->used - 1);
1106
1107   /* alias for b */
1108   tmpb = b->dp + (a->used - 1);
1109
1110   /* compare based on digits  */
1111   for (n = 0; n < a->used; ++n, --tmpa, --tmpb) {
1112     if (*tmpa > *tmpb) {
1113       return MP_GT;
1114     }
1115
1116     if (*tmpa < *tmpb) {
1117       return MP_LT;
1118     }
1119   }
1120   return MP_EQ;
1121 }
1122
1123 static const int lnz[16] = { 
1124    4, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0
1125 };
1126
1127 /* Counts the number of lsbs which are zero before the first zero bit */
1128 int mp_cnt_lsb(const mp_int *a)
1129 {
1130    int x;
1131    mp_digit q, qq;
1132
1133    /* easy out */
1134    if (mp_iszero(a) == 1) {
1135       return 0;
1136    }
1137
1138    /* scan lower digits until non-zero */
1139    for (x = 0; x < a->used && a->dp[x] == 0; x++);
1140    q = a->dp[x];
1141    x *= DIGIT_BIT;
1142
1143    /* now scan this digit until a 1 is found */
1144    if ((q & 1) == 0) {
1145       do {
1146          qq  = q & 15;
1147          x  += lnz[qq];
1148          q >>= 4;
1149       } while (qq == 0);
1150    }
1151    return x;
1152 }
1153
1154 /* copy, b = a */
1155 int
1156 mp_copy (const mp_int * a, mp_int * b)
1157 {
1158   int     res, n;
1159
1160   /* if dst == src do nothing */
1161   if (a == b) {
1162     return MP_OKAY;
1163   }
1164
1165   /* grow dest */
1166   if (b->alloc < a->used) {
1167      if ((res = mp_grow (b, a->used)) != MP_OKAY) {
1168         return res;
1169      }
1170   }
1171
1172   /* zero b and copy the parameters over */
1173   {
1174     register mp_digit *tmpa, *tmpb;
1175
1176     /* pointer aliases */
1177
1178     /* source */
1179     tmpa = a->dp;
1180
1181     /* destination */
1182     tmpb = b->dp;
1183
1184     /* copy all the digits */
1185     for (n = 0; n < a->used; n++) {
1186       *tmpb++ = *tmpa++;
1187     }
1188
1189     /* clear high digits */
1190     for (; n < b->used; n++) {
1191       *tmpb++ = 0;
1192     }
1193   }
1194
1195   /* copy used count and sign */
1196   b->used = a->used;
1197   b->sign = a->sign;
1198   return MP_OKAY;
1199 }
1200
1201 /* returns the number of bits in an int */
1202 int
1203 mp_count_bits (const mp_int * a)
1204 {
1205   int     r;
1206   mp_digit q;
1207
1208   /* shortcut */
1209   if (a->used == 0) {
1210     return 0;
1211   }
1212
1213   /* get number of digits and add that */
1214   r = (a->used - 1) * DIGIT_BIT;
1215   
1216   /* take the last digit and count the bits in it */
1217   q = a->dp[a->used - 1];
1218   while (q > 0) {
1219     ++r;
1220     q >>= ((mp_digit) 1);
1221   }
1222   return r;
1223 }
1224
1225 /* calc a value mod 2**b */
1226 static int
1227 mp_mod_2d (const mp_int * a, int b, mp_int * c)
1228 {
1229   int     x, res;
1230
1231   /* if b is <= 0 then zero the int */
1232   if (b <= 0) {
1233     mp_zero (c);
1234     return MP_OKAY;
1235   }
1236
1237   /* if the modulus is larger than the value than return */
1238   if (b > a->used * DIGIT_BIT) {
1239     res = mp_copy (a, c);
1240     return res;
1241   }
1242
1243   /* copy */
1244   if ((res = mp_copy (a, c)) != MP_OKAY) {
1245     return res;
1246   }
1247
1248   /* zero digits above the last digit of the modulus */
1249   for (x = (b / DIGIT_BIT) + ((b % DIGIT_BIT) == 0 ? 0 : 1); x < c->used; x++) {
1250     c->dp[x] = 0;
1251   }
1252   /* clear the digit that is not completely outside/inside the modulus */
1253   c->dp[b / DIGIT_BIT] &= (1 << ((mp_digit)b % DIGIT_BIT)) - 1;
1254   mp_clamp (c);
1255   return MP_OKAY;
1256 }
1257
1258 /* shift right a certain amount of digits */
1259 static void mp_rshd (mp_int * a, int b)
1260 {
1261   int     x;
1262
1263   /* if b <= 0 then ignore it */
1264   if (b <= 0) {
1265     return;
1266   }
1267
1268   /* if b > used then simply zero it and return */
1269   if (a->used <= b) {
1270     mp_zero (a);
1271     return;
1272   }
1273
1274   {
1275     register mp_digit *bottom, *top;
1276
1277     /* shift the digits down */
1278
1279     /* bottom */
1280     bottom = a->dp;
1281
1282     /* top [offset into digits] */
1283     top = a->dp + b;
1284
1285     /* this is implemented as a sliding window where
1286      * the window is b-digits long and digits from
1287      * the top of the window are copied to the bottom
1288      *
1289      * e.g.
1290
1291      b-2 | b-1 | b0 | b1 | b2 | ... | bb |   ---->
1292                  /\                   |      ---->
1293                   \-------------------/      ---->
1294      */
1295     for (x = 0; x < (a->used - b); x++) {
1296       *bottom++ = *top++;
1297     }
1298
1299     /* zero the top digits */
1300     for (; x < a->used; x++) {
1301       *bottom++ = 0;
1302     }
1303   }
1304
1305   /* remove excess digits */
1306   a->used -= b;
1307 }
1308
1309 /* shift right by a certain bit count (store quotient in c, optional remainder in d) */
1310 static int mp_div_2d (const mp_int * a, int b, mp_int * c, mp_int * d)
1311 {
1312   mp_digit D, r, rr;
1313   int     x, res;
1314   mp_int  t;
1315
1316
1317   /* if the shift count is <= 0 then we do no work */
1318   if (b <= 0) {
1319     res = mp_copy (a, c);
1320     if (d != NULL) {
1321       mp_zero (d);
1322     }
1323     return res;
1324   }
1325
1326   if ((res = mp_init (&t)) != MP_OKAY) {
1327     return res;
1328   }
1329
1330   /* get the remainder */
1331   if (d != NULL) {
1332     if ((res = mp_mod_2d (a, b, &t)) != MP_OKAY) {
1333       mp_clear (&t);
1334       return res;
1335     }
1336   }
1337
1338   /* copy */
1339   if ((res = mp_copy (a, c)) != MP_OKAY) {
1340     mp_clear (&t);
1341     return res;
1342   }
1343
1344   /* shift by as many digits in the bit count */
1345   if (b >= DIGIT_BIT) {
1346     mp_rshd (c, b / DIGIT_BIT);
1347   }
1348
1349   /* shift any bit count < DIGIT_BIT */
1350   D = (mp_digit) (b % DIGIT_BIT);
1351   if (D != 0) {
1352     register mp_digit *tmpc, mask, shift;
1353
1354     /* mask */
1355     mask = (((mp_digit)1) << D) - 1;
1356
1357     /* shift for lsb */
1358     shift = DIGIT_BIT - D;
1359
1360     /* alias */
1361     tmpc = c->dp + (c->used - 1);
1362
1363     /* carry */
1364     r = 0;
1365     for (x = c->used - 1; x >= 0; x--) {
1366       /* get the lower  bits of this word in a temp */
1367       rr = *tmpc & mask;
1368
1369       /* shift the current word and mix in the carry bits from the previous word */
1370       *tmpc = (*tmpc >> D) | (r << shift);
1371       --tmpc;
1372
1373       /* set the carry to the carry bits of the current word found above */
1374       r = rr;
1375     }
1376   }
1377   mp_clamp (c);
1378   if (d != NULL) {
1379     mp_exch (&t, d);
1380   }
1381   mp_clear (&t);
1382   return MP_OKAY;
1383 }
1384
1385 /* shift left a certain amount of digits */
1386 static int mp_lshd (mp_int * a, int b)
1387 {
1388   int     x, res;
1389
1390   /* if its less than zero return */
1391   if (b <= 0) {
1392     return MP_OKAY;
1393   }
1394
1395   /* grow to fit the new digits */
1396   if (a->alloc < a->used + b) {
1397      if ((res = mp_grow (a, a->used + b)) != MP_OKAY) {
1398        return res;
1399      }
1400   }
1401
1402   {
1403     register mp_digit *top, *bottom;
1404
1405     /* increment the used by the shift amount then copy upwards */
1406     a->used += b;
1407
1408     /* top */
1409     top = a->dp + a->used - 1;
1410
1411     /* base */
1412     bottom = a->dp + a->used - 1 - b;
1413
1414     /* much like mp_rshd this is implemented using a sliding window
1415      * except the window goes the otherway around.  Copying from
1416      * the bottom to the top.  see bn_mp_rshd.c for more info.
1417      */
1418     for (x = a->used - 1; x >= b; x--) {
1419       *top-- = *bottom--;
1420     }
1421
1422     /* zero the lower digits */
1423     top = a->dp;
1424     for (x = 0; x < b; x++) {
1425       *top++ = 0;
1426     }
1427   }
1428   return MP_OKAY;
1429 }
1430
1431 /* shift left by a certain bit count */
1432 static int mp_mul_2d (const mp_int * a, int b, mp_int * c)
1433 {
1434   mp_digit d;
1435   int      res;
1436
1437   /* copy */
1438   if (a != c) {
1439      if ((res = mp_copy (a, c)) != MP_OKAY) {
1440        return res;
1441      }
1442   }
1443
1444   if (c->alloc < c->used + b/DIGIT_BIT + 1) {
1445      if ((res = mp_grow (c, c->used + b / DIGIT_BIT + 1)) != MP_OKAY) {
1446        return res;
1447      }
1448   }
1449
1450   /* shift by as many digits in the bit count */
1451   if (b >= DIGIT_BIT) {
1452     if ((res = mp_lshd (c, b / DIGIT_BIT)) != MP_OKAY) {
1453       return res;
1454     }
1455   }
1456
1457   /* shift any bit count < DIGIT_BIT */
1458   d = (mp_digit) (b % DIGIT_BIT);
1459   if (d != 0) {
1460     register mp_digit *tmpc, shift, mask, r, rr;
1461     register int x;
1462
1463     /* bitmask for carries */
1464     mask = (((mp_digit)1) << d) - 1;
1465
1466     /* shift for msbs */
1467     shift = DIGIT_BIT - d;
1468
1469     /* alias */
1470     tmpc = c->dp;
1471
1472     /* carry */
1473     r    = 0;
1474     for (x = 0; x < c->used; x++) {
1475       /* get the higher bits of the current word */
1476       rr = (*tmpc >> shift) & mask;
1477
1478       /* shift the current word and OR in the carry */
1479       *tmpc = ((*tmpc << d) | r) & MP_MASK;
1480       ++tmpc;
1481
1482       /* set the carry to the carry bits of the current word */
1483       r = rr;
1484     }
1485
1486     /* set final carry */
1487     if (r != 0) {
1488        c->dp[(c->used)++] = r;
1489     }
1490   }
1491   mp_clamp (c);
1492   return MP_OKAY;
1493 }
1494
1495 /* multiply by a digit */
1496 static int
1497 mp_mul_d (const mp_int * a, mp_digit b, mp_int * c)
1498 {
1499   mp_digit u, *tmpa, *tmpc;
1500   mp_word  r;
1501   int      ix, res, olduse;
1502
1503   /* make sure c is big enough to hold a*b */
1504   if (c->alloc < a->used + 1) {
1505     if ((res = mp_grow (c, a->used + 1)) != MP_OKAY) {
1506       return res;
1507     }
1508   }
1509
1510   /* get the original destinations used count */
1511   olduse = c->used;
1512
1513   /* set the sign */
1514   c->sign = a->sign;
1515
1516   /* alias for a->dp [source] */
1517   tmpa = a->dp;
1518
1519   /* alias for c->dp [dest] */
1520   tmpc = c->dp;
1521
1522   /* zero carry */
1523   u = 0;
1524
1525   /* compute columns */
1526   for (ix = 0; ix < a->used; ix++) {
1527     /* compute product and carry sum for this term */
1528     r       = ((mp_word) u) + ((mp_word)*tmpa++) * ((mp_word)b);
1529
1530     /* mask off higher bits to get a single digit */
1531     *tmpc++ = (mp_digit) (r & ((mp_word) MP_MASK));
1532
1533     /* send carry into next iteration */
1534     u       = (mp_digit) (r >> ((mp_word) DIGIT_BIT));
1535   }
1536
1537   /* store final carry [if any] */
1538   *tmpc++ = u;
1539
1540   /* now zero digits above the top */
1541   while (ix++ < olduse) {
1542      *tmpc++ = 0;
1543   }
1544
1545   /* set used count */
1546   c->used = a->used + 1;
1547   mp_clamp(c);
1548
1549   return MP_OKAY;
1550 }
1551
1552 /* integer signed division. 
1553  * c*b + d == a [e.g. a/b, c=quotient, d=remainder]
1554  * HAC pp.598 Algorithm 14.20
1555  *
1556  * Note that the description in HAC is horribly 
1557  * incomplete.  For example, it doesn't consider 
1558  * the case where digits are removed from 'x' in 
1559  * the inner loop.  It also doesn't consider the 
1560  * case that y has fewer than three digits, etc..
1561  *
1562  * The overall algorithm is as described as 
1563  * 14.20 from HAC but fixed to treat these cases.
1564 */
1565 static int mp_div (const mp_int * a, const mp_int * b, mp_int * c, mp_int * d)
1566 {
1567   mp_int  q, x, y, t1, t2;
1568   int     res, n, t, i, norm, neg;
1569
1570   /* is divisor zero ? */
1571   if (mp_iszero (b) == 1) {
1572     return MP_VAL;
1573   }
1574
1575   /* if a < b then q=0, r = a */
1576   if (mp_cmp_mag (a, b) == MP_LT) {
1577     if (d != NULL) {
1578       res = mp_copy (a, d);
1579     } else {
1580       res = MP_OKAY;
1581     }
1582     if (c != NULL) {
1583       mp_zero (c);
1584     }
1585     return res;
1586   }
1587
1588   if ((res = mp_init_size (&q, a->used + 2)) != MP_OKAY) {
1589     return res;
1590   }
1591   q.used = a->used + 2;
1592
1593   if ((res = mp_init (&t1)) != MP_OKAY) {
1594     goto __Q;
1595   }
1596
1597   if ((res = mp_init (&t2)) != MP_OKAY) {
1598     goto __T1;
1599   }
1600
1601   if ((res = mp_init_copy (&x, a)) != MP_OKAY) {
1602     goto __T2;
1603   }
1604
1605   if ((res = mp_init_copy (&y, b)) != MP_OKAY) {
1606     goto __X;
1607   }
1608
1609   /* fix the sign */
1610   neg = (a->sign == b->sign) ? MP_ZPOS : MP_NEG;
1611   x.sign = y.sign = MP_ZPOS;
1612
1613   /* normalize both x and y, ensure that y >= b/2, [b == 2**DIGIT_BIT] */
1614   norm = mp_count_bits(&y) % DIGIT_BIT;
1615   if (norm < DIGIT_BIT-1) {
1616      norm = (DIGIT_BIT-1) - norm;
1617      if ((res = mp_mul_2d (&x, norm, &x)) != MP_OKAY) {
1618        goto __Y;
1619      }
1620      if ((res = mp_mul_2d (&y, norm, &y)) != MP_OKAY) {
1621        goto __Y;
1622      }
1623   } else {
1624      norm = 0;
1625   }
1626
1627   /* note hac does 0 based, so if used==5 then its 0,1,2,3,4, e.g. use 4 */
1628   n = x.used - 1;
1629   t = y.used - 1;
1630
1631   /* while (x >= y*b**n-t) do { q[n-t] += 1; x -= y*b**{n-t} } */
1632   if ((res = mp_lshd (&y, n - t)) != MP_OKAY) { /* y = y*b**{n-t} */
1633     goto __Y;
1634   }
1635
1636   while (mp_cmp (&x, &y) != MP_LT) {
1637     ++(q.dp[n - t]);
1638     if ((res = mp_sub (&x, &y, &x)) != MP_OKAY) {
1639       goto __Y;
1640     }
1641   }
1642
1643   /* reset y by shifting it back down */
1644   mp_rshd (&y, n - t);
1645
1646   /* step 3. for i from n down to (t + 1) */
1647   for (i = n; i >= (t + 1); i--) {
1648     if (i > x.used) {
1649       continue;
1650     }
1651
1652     /* step 3.1 if xi == yt then set q{i-t-1} to b-1, 
1653      * otherwise set q{i-t-1} to (xi*b + x{i-1})/yt */
1654     if (x.dp[i] == y.dp[t]) {
1655       q.dp[i - t - 1] = ((((mp_digit)1) << DIGIT_BIT) - 1);
1656     } else {
1657       mp_word tmp;
1658       tmp = ((mp_word) x.dp[i]) << ((mp_word) DIGIT_BIT);
1659       tmp |= ((mp_word) x.dp[i - 1]);
1660       tmp /= ((mp_word) y.dp[t]);
1661       if (tmp > (mp_word) MP_MASK)
1662         tmp = MP_MASK;
1663       q.dp[i - t - 1] = (mp_digit) (tmp & (mp_word) (MP_MASK));
1664     }
1665
1666     /* while (q{i-t-1} * (yt * b + y{t-1})) > 
1667              xi * b**2 + xi-1 * b + xi-2 
1668      
1669        do q{i-t-1} -= 1; 
1670     */
1671     q.dp[i - t - 1] = (q.dp[i - t - 1] + 1) & MP_MASK;
1672     do {
1673       q.dp[i - t - 1] = (q.dp[i - t - 1] - 1) & MP_MASK;
1674
1675       /* find left hand */
1676       mp_zero (&t1);
1677       t1.dp[0] = (t - 1 < 0) ? 0 : y.dp[t - 1];
1678       t1.dp[1] = y.dp[t];
1679       t1.used = 2;
1680       if ((res = mp_mul_d (&t1, q.dp[i - t - 1], &t1)) != MP_OKAY) {
1681         goto __Y;
1682       }
1683
1684       /* find right hand */
1685       t2.dp[0] = (i - 2 < 0) ? 0 : x.dp[i - 2];
1686       t2.dp[1] = (i - 1 < 0) ? 0 : x.dp[i - 1];
1687       t2.dp[2] = x.dp[i];
1688       t2.used = 3;
1689     } while (mp_cmp_mag(&t1, &t2) == MP_GT);
1690
1691     /* step 3.3 x = x - q{i-t-1} * y * b**{i-t-1} */
1692     if ((res = mp_mul_d (&y, q.dp[i - t - 1], &t1)) != MP_OKAY) {
1693       goto __Y;
1694     }
1695
1696     if ((res = mp_lshd (&t1, i - t - 1)) != MP_OKAY) {
1697       goto __Y;
1698     }
1699
1700     if ((res = mp_sub (&x, &t1, &x)) != MP_OKAY) {
1701       goto __Y;
1702     }
1703
1704     /* if x < 0 then { x = x + y*b**{i-t-1}; q{i-t-1} -= 1; } */
1705     if (x.sign == MP_NEG) {
1706       if ((res = mp_copy (&y, &t1)) != MP_OKAY) {
1707         goto __Y;
1708       }
1709       if ((res = mp_lshd (&t1, i - t - 1)) != MP_OKAY) {
1710         goto __Y;
1711       }
1712       if ((res = mp_add (&x, &t1, &x)) != MP_OKAY) {
1713         goto __Y;
1714       }
1715
1716       q.dp[i - t - 1] = (q.dp[i - t - 1] - 1UL) & MP_MASK;
1717     }
1718   }
1719
1720   /* now q is the quotient and x is the remainder 
1721    * [which we have to normalize] 
1722    */
1723   
1724   /* get sign before writing to c */
1725   x.sign = x.used == 0 ? MP_ZPOS : a->sign;
1726
1727   if (c != NULL) {
1728     mp_clamp (&q);
1729     mp_exch (&q, c);
1730     c->sign = neg;
1731   }
1732
1733   if (d != NULL) {
1734     mp_div_2d (&x, norm, &x, NULL);
1735     mp_exch (&x, d);
1736   }
1737
1738   res = MP_OKAY;
1739
1740 __Y:mp_clear (&y);
1741 __X:mp_clear (&x);
1742 __T2:mp_clear (&t2);
1743 __T1:mp_clear (&t1);
1744 __Q:mp_clear (&q);
1745   return res;
1746 }
1747
1748 static int s_is_power_of_two(mp_digit b, int *p)
1749 {
1750    int x;
1751
1752    for (x = 1; x < DIGIT_BIT; x++) {
1753       if (b == (((mp_digit)1)<<x)) {
1754          *p = x;
1755          return 1;
1756       }
1757    }
1758    return 0;
1759 }
1760
1761 /* single digit division (based on routine from MPI) */
1762 static int mp_div_d (const mp_int * a, mp_digit b, mp_int * c, mp_digit * d)
1763 {
1764   mp_int  q;
1765   mp_word w;
1766   mp_digit t;
1767   int     res, ix;
1768
1769   /* cannot divide by zero */
1770   if (b == 0) {
1771      return MP_VAL;
1772   }
1773
1774   /* quick outs */
1775   if (b == 1 || mp_iszero(a) == 1) {
1776      if (d != NULL) {
1777         *d = 0;
1778      }
1779      if (c != NULL) {
1780         return mp_copy(a, c);
1781      }
1782      return MP_OKAY;
1783   }
1784
1785   /* power of two ? */
1786   if (s_is_power_of_two(b, &ix) == 1) {
1787      if (d != NULL) {
1788         *d = a->dp[0] & ((((mp_digit)1)<<ix) - 1);
1789      }
1790      if (c != NULL) {
1791         return mp_div_2d(a, ix, c, NULL);
1792      }
1793      return MP_OKAY;
1794   }
1795
1796   /* no easy answer [c'est la vie].  Just division */
1797   if ((res = mp_init_size(&q, a->used)) != MP_OKAY) {
1798      return res;
1799   }
1800   
1801   q.used = a->used;
1802   q.sign = a->sign;
1803   w = 0;
1804   for (ix = a->used - 1; ix >= 0; ix--) {
1805      w = (w << ((mp_word)DIGIT_BIT)) | ((mp_word)a->dp[ix]);
1806      
1807      if (w >= b) {
1808         t = (mp_digit)(w / b);
1809         w -= ((mp_word)t) * ((mp_word)b);
1810       } else {
1811         t = 0;
1812       }
1813       q.dp[ix] = t;
1814   }
1815
1816   if (d != NULL) {
1817      *d = (mp_digit)w;
1818   }
1819   
1820   if (c != NULL) {
1821      mp_clamp(&q);
1822      mp_exch(&q, c);
1823   }
1824   mp_clear(&q);
1825   
1826   return res;
1827 }
1828
1829 /* reduce "x" in place modulo "n" using the Diminished Radix algorithm.
1830  *
1831  * Based on algorithm from the paper
1832  *
1833  * "Generating Efficient Primes for Discrete Log Cryptosystems"
1834  *                 Chae Hoon Lim, Pil Loong Lee,
1835  *          POSTECH Information Research Laboratories
1836  *
1837  * The modulus must be of a special format [see manual]
1838  *
1839  * Has been modified to use algorithm 7.10 from the LTM book instead
1840  *
1841  * Input x must be in the range 0 <= x <= (n-1)**2
1842  */
1843 static int
1844 mp_dr_reduce (mp_int * x, const mp_int * n, mp_digit k)
1845 {
1846   int      err, i, m;
1847   mp_word  r;
1848   mp_digit mu, *tmpx1, *tmpx2;
1849
1850   /* m = digits in modulus */
1851   m = n->used;
1852
1853   /* ensure that "x" has at least 2m digits */
1854   if (x->alloc < m + m) {
1855     if ((err = mp_grow (x, m + m)) != MP_OKAY) {
1856       return err;
1857     }
1858   }
1859
1860 /* top of loop, this is where the code resumes if
1861  * another reduction pass is required.
1862  */
1863 top:
1864   /* aliases for digits */
1865   /* alias for lower half of x */
1866   tmpx1 = x->dp;
1867
1868   /* alias for upper half of x, or x/B**m */
1869   tmpx2 = x->dp + m;
1870
1871   /* set carry to zero */
1872   mu = 0;
1873
1874   /* compute (x mod B**m) + k * [x/B**m] inline and inplace */
1875   for (i = 0; i < m; i++) {
1876       r         = ((mp_word)*tmpx2++) * ((mp_word)k) + *tmpx1 + mu;
1877       *tmpx1++  = (mp_digit)(r & MP_MASK);
1878       mu        = (mp_digit)(r >> ((mp_word)DIGIT_BIT));
1879   }
1880
1881   /* set final carry */
1882   *tmpx1++ = mu;
1883
1884   /* zero words above m */
1885   for (i = m + 1; i < x->used; i++) {
1886       *tmpx1++ = 0;
1887   }
1888
1889   /* clamp, sub and return */
1890   mp_clamp (x);
1891
1892   /* if x >= n then subtract and reduce again
1893    * Each successive "recursion" makes the input smaller and smaller.
1894    */
1895   if (mp_cmp_mag (x, n) != MP_LT) {
1896     s_mp_sub(x, n, x);
1897     goto top;
1898   }
1899   return MP_OKAY;
1900 }
1901
1902 /* sets the value of "d" required for mp_dr_reduce */
1903 static void mp_dr_setup(const mp_int *a, mp_digit *d)
1904 {
1905    /* the casts are required if DIGIT_BIT is one less than
1906     * the number of bits in a mp_digit [e.g. DIGIT_BIT==31]
1907     */
1908    *d = (mp_digit)((((mp_word)1) << ((mp_word)DIGIT_BIT)) - 
1909         ((mp_word)a->dp[0]));
1910 }
1911
1912 /* this is a shell function that calls either the normal or Montgomery
1913  * exptmod functions.  Originally the call to the montgomery code was
1914  * embedded in the normal function but that wasted a lot of stack space
1915  * for nothing (since 99% of the time the Montgomery code would be called)
1916  */
1917 int mp_exptmod (const mp_int * G, const mp_int * X, mp_int * P, mp_int * Y)
1918 {
1919   int dr;
1920
1921   /* modulus P must be positive */
1922   if (P->sign == MP_NEG) {
1923      return MP_VAL;
1924   }
1925
1926   /* if exponent X is negative we have to recurse */
1927   if (X->sign == MP_NEG) {
1928      mp_int tmpG, tmpX;
1929      int err;
1930
1931      /* first compute 1/G mod P */
1932      if ((err = mp_init(&tmpG)) != MP_OKAY) {
1933         return err;
1934      }
1935      if ((err = mp_invmod(G, P, &tmpG)) != MP_OKAY) {
1936         mp_clear(&tmpG);
1937         return err;
1938      }
1939
1940      /* now get |X| */
1941      if ((err = mp_init(&tmpX)) != MP_OKAY) {
1942         mp_clear(&tmpG);
1943         return err;
1944      }
1945      if ((err = mp_abs(X, &tmpX)) != MP_OKAY) {
1946         mp_clear_multi(&tmpG, &tmpX, NULL);
1947         return err;
1948      }
1949
1950      /* and now compute (1/G)**|X| instead of G**X [X < 0] */
1951      err = mp_exptmod(&tmpG, &tmpX, P, Y);
1952      mp_clear_multi(&tmpG, &tmpX, NULL);
1953      return err;
1954   }
1955
1956   dr = 0;
1957
1958   /* if the modulus is odd or dr != 0 use the fast method */
1959   if (mp_isodd (P) == 1 || dr !=  0) {
1960     return mp_exptmod_fast (G, X, P, Y, dr);
1961   } else {
1962     /* otherwise use the generic Barrett reduction technique */
1963     return s_mp_exptmod (G, X, P, Y);
1964   }
1965 }
1966
1967 /* computes Y == G**X mod P, HAC pp.616, Algorithm 14.85
1968  *
1969  * Uses a left-to-right k-ary sliding window to compute the modular exponentiation.
1970  * The value of k changes based on the size of the exponent.
1971  *
1972  * Uses Montgomery or Diminished Radix reduction [whichever appropriate]
1973  */
1974
1975 int
1976 mp_exptmod_fast (const mp_int * G, const mp_int * X, mp_int * P, mp_int * Y, int redmode)
1977 {
1978   mp_int  M[256], res;
1979   mp_digit buf, mp;
1980   int     err, bitbuf, bitcpy, bitcnt, mode, digidx, x, y, winsize;
1981
1982   /* use a pointer to the reduction algorithm.  This allows us to use
1983    * one of many reduction algorithms without modding the guts of
1984    * the code with if statements everywhere.
1985    */
1986   int     (*redux)(mp_int*,const mp_int*,mp_digit);
1987
1988   /* find window size */
1989   x = mp_count_bits (X);
1990   if (x <= 7) {
1991     winsize = 2;
1992   } else if (x <= 36) {
1993     winsize = 3;
1994   } else if (x <= 140) {
1995     winsize = 4;
1996   } else if (x <= 450) {
1997     winsize = 5;
1998   } else if (x <= 1303) {
1999     winsize = 6;
2000   } else if (x <= 3529) {
2001     winsize = 7;
2002   } else {
2003     winsize = 8;
2004   }
2005
2006   /* init M array */
2007   /* init first cell */
2008   if ((err = mp_init(&M[1])) != MP_OKAY) {
2009      return err;
2010   }
2011
2012   /* now init the second half of the array */
2013   for (x = 1<<(winsize-1); x < (1 << winsize); x++) {
2014     if ((err = mp_init(&M[x])) != MP_OKAY) {
2015       for (y = 1<<(winsize-1); y < x; y++) {
2016         mp_clear (&M[y]);
2017       }
2018       mp_clear(&M[1]);
2019       return err;
2020     }
2021   }
2022
2023   /* determine and setup reduction code */
2024   if (redmode == 0) {
2025      /* now setup montgomery  */
2026      if ((err = mp_montgomery_setup (P, &mp)) != MP_OKAY) {
2027         goto __M;
2028      }
2029
2030      /* automatically pick the comba one if available (saves quite a few calls/ifs) */
2031      if (((P->used * 2 + 1) < MP_WARRAY) &&
2032           P->used < (1 << ((CHAR_BIT * sizeof (mp_word)) - (2 * DIGIT_BIT)))) {
2033         redux = fast_mp_montgomery_reduce;
2034      } else {
2035         /* use slower baseline Montgomery method */
2036         redux = mp_montgomery_reduce;
2037      }
2038   } else if (redmode == 1) {
2039      /* setup DR reduction for moduli of the form B**k - b */
2040      mp_dr_setup(P, &mp);
2041      redux = mp_dr_reduce;
2042   } else {
2043      /* setup DR reduction for moduli of the form 2**k - b */
2044      if ((err = mp_reduce_2k_setup(P, &mp)) != MP_OKAY) {
2045         goto __M;
2046      }
2047      redux = mp_reduce_2k;
2048   }
2049
2050   /* setup result */
2051   if ((err = mp_init (&res)) != MP_OKAY) {
2052     goto __M;
2053   }
2054
2055   /* create M table
2056    *
2057
2058    *
2059    * The first half of the table is not computed though accept for M[0] and M[1]
2060    */
2061
2062   if (redmode == 0) {
2063      /* now we need R mod m */
2064      if ((err = mp_montgomery_calc_normalization (&res, P)) != MP_OKAY) {
2065        goto __RES;
2066      }
2067
2068      /* now set M[1] to G * R mod m */
2069      if ((err = mp_mulmod (G, &res, P, &M[1])) != MP_OKAY) {
2070        goto __RES;
2071      }
2072   } else {
2073      mp_set(&res, 1);
2074      if ((err = mp_mod(G, P, &M[1])) != MP_OKAY) {
2075         goto __RES;
2076      }
2077   }
2078
2079   /* compute the value at M[1<<(winsize-1)] by squaring M[1] (winsize-1) times */
2080   if ((err = mp_copy (&M[1], &M[1 << (winsize - 1)])) != MP_OKAY) {
2081     goto __RES;
2082   }
2083
2084   for (x = 0; x < (winsize - 1); x++) {
2085     if ((err = mp_sqr (&M[1 << (winsize - 1)], &M[1 << (winsize - 1)])) != MP_OKAY) {
2086       goto __RES;
2087     }
2088     if ((err = redux (&M[1 << (winsize - 1)], P, mp)) != MP_OKAY) {
2089       goto __RES;
2090     }
2091   }
2092
2093   /* create upper table */
2094   for (x = (1 << (winsize - 1)) + 1; x < (1 << winsize); x++) {
2095     if ((err = mp_mul (&M[x - 1], &M[1], &M[x])) != MP_OKAY) {
2096       goto __RES;
2097     }
2098     if ((err = redux (&M[x], P, mp)) != MP_OKAY) {
2099       goto __RES;
2100     }
2101   }
2102
2103   /* set initial mode and bit cnt */
2104   mode   = 0;
2105   bitcnt = 1;
2106   buf    = 0;
2107   digidx = X->used - 1;
2108   bitcpy = 0;
2109   bitbuf = 0;
2110
2111   for (;;) {
2112     /* grab next digit as required */
2113     if (--bitcnt == 0) {
2114       /* if digidx == -1 we are out of digits so break */
2115       if (digidx == -1) {
2116         break;
2117       }
2118       /* read next digit and reset bitcnt */
2119       buf    = X->dp[digidx--];
2120       bitcnt = DIGIT_BIT;
2121     }
2122
2123     /* grab the next msb from the exponent */
2124     y     = (buf >> (DIGIT_BIT - 1)) & 1;
2125     buf <<= (mp_digit)1;
2126
2127     /* if the bit is zero and mode == 0 then we ignore it
2128      * These represent the leading zero bits before the first 1 bit
2129      * in the exponent.  Technically this opt is not required but it
2130      * does lower the # of trivial squaring/reductions used
2131      */
2132     if (mode == 0 && y == 0) {
2133       continue;
2134     }
2135
2136     /* if the bit is zero and mode == 1 then we square */
2137     if (mode == 1 && y == 0) {
2138       if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
2139         goto __RES;
2140       }
2141       if ((err = redux (&res, P, mp)) != MP_OKAY) {
2142         goto __RES;
2143       }
2144       continue;
2145     }
2146
2147     /* else we add it to the window */
2148     bitbuf |= (y << (winsize - ++bitcpy));
2149     mode    = 2;
2150
2151     if (bitcpy == winsize) {
2152       /* ok window is filled so square as required and multiply  */
2153       /* square first */
2154       for (x = 0; x < winsize; x++) {
2155         if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
2156           goto __RES;
2157         }
2158         if ((err = redux (&res, P, mp)) != MP_OKAY) {
2159           goto __RES;
2160         }
2161       }
2162
2163       /* then multiply */
2164       if ((err = mp_mul (&res, &M[bitbuf], &res)) != MP_OKAY) {
2165         goto __RES;
2166       }
2167       if ((err = redux (&res, P, mp)) != MP_OKAY) {
2168         goto __RES;
2169       }
2170
2171       /* empty window and reset */
2172       bitcpy = 0;
2173       bitbuf = 0;
2174       mode   = 1;
2175     }
2176   }
2177
2178   /* if bits remain then square/multiply */
2179   if (mode == 2 && bitcpy > 0) {
2180     /* square then multiply if the bit is set */
2181     for (x = 0; x < bitcpy; x++) {
2182       if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
2183         goto __RES;
2184       }
2185       if ((err = redux (&res, P, mp)) != MP_OKAY) {
2186         goto __RES;
2187       }
2188
2189       /* get next bit of the window */
2190       bitbuf <<= 1;
2191       if ((bitbuf & (1 << winsize)) != 0) {
2192         /* then multiply */
2193         if ((err = mp_mul (&res, &M[1], &res)) != MP_OKAY) {
2194           goto __RES;
2195         }
2196         if ((err = redux (&res, P, mp)) != MP_OKAY) {
2197           goto __RES;
2198         }
2199       }
2200     }
2201   }
2202
2203   if (redmode == 0) {
2204      /* fixup result if Montgomery reduction is used
2205       * recall that any value in a Montgomery system is
2206       * actually multiplied by R mod n.  So we have
2207       * to reduce one more time to cancel out the factor
2208       * of R.
2209       */
2210      if ((err = redux(&res, P, mp)) != MP_OKAY) {
2211        goto __RES;
2212      }
2213   }
2214
2215   /* swap res with Y */
2216   mp_exch (&res, Y);
2217   err = MP_OKAY;
2218 __RES:mp_clear (&res);
2219 __M:
2220   mp_clear(&M[1]);
2221   for (x = 1<<(winsize-1); x < (1 << winsize); x++) {
2222     mp_clear (&M[x]);
2223   }
2224   return err;
2225 }
2226
2227 /* Greatest Common Divisor using the binary method */
2228 int mp_gcd (const mp_int * a, const mp_int * b, mp_int * c)
2229 {
2230   mp_int  u, v;
2231   int     k, u_lsb, v_lsb, res;
2232
2233   /* either zero than gcd is the largest */
2234   if (mp_iszero (a) == 1 && mp_iszero (b) == 0) {
2235     return mp_abs (b, c);
2236   }
2237   if (mp_iszero (a) == 0 && mp_iszero (b) == 1) {
2238     return mp_abs (a, c);
2239   }
2240
2241   /* optimized.  At this point if a == 0 then
2242    * b must equal zero too
2243    */
2244   if (mp_iszero (a) == 1) {
2245     mp_zero(c);
2246     return MP_OKAY;
2247   }
2248
2249   /* get copies of a and b we can modify */
2250   if ((res = mp_init_copy (&u, a)) != MP_OKAY) {
2251     return res;
2252   }
2253
2254   if ((res = mp_init_copy (&v, b)) != MP_OKAY) {
2255     goto __U;
2256   }
2257
2258   /* must be positive for the remainder of the algorithm */
2259   u.sign = v.sign = MP_ZPOS;
2260
2261   /* B1.  Find the common power of two for u and v */
2262   u_lsb = mp_cnt_lsb(&u);
2263   v_lsb = mp_cnt_lsb(&v);
2264   k     = MIN(u_lsb, v_lsb);
2265
2266   if (k > 0) {
2267      /* divide the power of two out */
2268      if ((res = mp_div_2d(&u, k, &u, NULL)) != MP_OKAY) {
2269         goto __V;
2270      }
2271
2272      if ((res = mp_div_2d(&v, k, &v, NULL)) != MP_OKAY) {
2273         goto __V;
2274      }
2275   }
2276
2277   /* divide any remaining factors of two out */
2278   if (u_lsb != k) {
2279      if ((res = mp_div_2d(&u, u_lsb - k, &u, NULL)) != MP_OKAY) {
2280         goto __V;
2281      }
2282   }
2283
2284   if (v_lsb != k) {
2285      if ((res = mp_div_2d(&v, v_lsb - k, &v, NULL)) != MP_OKAY) {
2286         goto __V;
2287      }
2288   }
2289
2290   while (mp_iszero(&v) == 0) {
2291      /* make sure v is the largest */
2292      if (mp_cmp_mag(&u, &v) == MP_GT) {
2293         /* swap u and v to make sure v is >= u */
2294         mp_exch(&u, &v);
2295      }
2296      
2297      /* subtract smallest from largest */
2298      if ((res = s_mp_sub(&v, &u, &v)) != MP_OKAY) {
2299         goto __V;
2300      }
2301      
2302      /* Divide out all factors of two */
2303      if ((res = mp_div_2d(&v, mp_cnt_lsb(&v), &v, NULL)) != MP_OKAY) {
2304         goto __V;
2305      } 
2306   } 
2307
2308   /* multiply by 2**k which we divided out at the beginning */
2309   if ((res = mp_mul_2d (&u, k, c)) != MP_OKAY) {
2310      goto __V;
2311   }
2312   c->sign = MP_ZPOS;
2313   res = MP_OKAY;
2314 __V:mp_clear (&u);
2315 __U:mp_clear (&v);
2316   return res;
2317 }
2318
2319 /* get the lower 32-bits of an mp_int */
2320 unsigned long mp_get_int(const mp_int * a)
2321 {
2322   int i;
2323   unsigned long res;
2324
2325   if (a->used == 0) {
2326      return 0;
2327   }
2328
2329   /* get number of digits of the lsb we have to read */
2330   i = MIN(a->used,(int)((sizeof(unsigned long)*CHAR_BIT+DIGIT_BIT-1)/DIGIT_BIT))-1;
2331
2332   /* get most significant digit of result */
2333   res = DIGIT(a,i);
2334    
2335   while (--i >= 0) {
2336     res = (res << DIGIT_BIT) | DIGIT(a,i);
2337   }
2338
2339   /* force result to 32-bits always so it is consistent on non 32-bit platforms */
2340   return res & 0xFFFFFFFFUL;
2341 }
2342
2343 /* creates "a" then copies b into it */
2344 int mp_init_copy (mp_int * a, const mp_int * b)
2345 {
2346   int     res;
2347
2348   if ((res = mp_init (a)) != MP_OKAY) {
2349     return res;
2350   }
2351   return mp_copy (b, a);
2352 }
2353
2354 int mp_init_multi(mp_int *mp, ...) 
2355 {
2356     mp_err res = MP_OKAY;      /* Assume ok until proven otherwise */
2357     int n = 0;                 /* Number of ok inits */
2358     mp_int* cur_arg = mp;
2359     va_list args;
2360
2361     va_start(args, mp);        /* init args to next argument from caller */
2362     while (cur_arg != NULL) {
2363         if (mp_init(cur_arg) != MP_OKAY) {
2364             /* Oops - error! Back-track and mp_clear what we already
2365                succeeded in init-ing, then return error.
2366             */
2367             va_list clean_args;
2368             
2369             /* end the current list */
2370             va_end(args);
2371             
2372             /* now start cleaning up */            
2373             cur_arg = mp;
2374             va_start(clean_args, mp);
2375             while (n--) {
2376                 mp_clear(cur_arg);
2377                 cur_arg = va_arg(clean_args, mp_int*);
2378             }
2379             va_end(clean_args);
2380             res = MP_MEM;
2381             break;
2382         }
2383         n++;
2384         cur_arg = va_arg(args, mp_int*);
2385     }
2386     va_end(args);
2387     return res;                /* Assumed ok, if error flagged above. */
2388 }
2389
2390 /* hac 14.61, pp608 */
2391 int mp_invmod (const mp_int * a, mp_int * b, mp_int * c)
2392 {
2393   /* b cannot be negative */
2394   if (b->sign == MP_NEG || mp_iszero(b) == 1) {
2395     return MP_VAL;
2396   }
2397
2398   /* if the modulus is odd we can use a faster routine instead */
2399   if (mp_isodd (b) == 1) {
2400     return fast_mp_invmod (a, b, c);
2401   }
2402   
2403   return mp_invmod_slow(a, b, c);
2404 }
2405
2406 /* hac 14.61, pp608 */
2407 int mp_invmod_slow (const mp_int * a, mp_int * b, mp_int * c)
2408 {
2409   mp_int  x, y, u, v, A, B, C, D;
2410   int     res;
2411
2412   /* b cannot be negative */
2413   if (b->sign == MP_NEG || mp_iszero(b) == 1) {
2414     return MP_VAL;
2415   }
2416
2417   /* init temps */
2418   if ((res = mp_init_multi(&x, &y, &u, &v, 
2419                            &A, &B, &C, &D, NULL)) != MP_OKAY) {
2420      return res;
2421   }
2422
2423   /* x = a, y = b */
2424   if ((res = mp_copy (a, &x)) != MP_OKAY) {
2425     goto __ERR;
2426   }
2427   if ((res = mp_copy (b, &y)) != MP_OKAY) {
2428     goto __ERR;
2429   }
2430
2431   /* 2. [modified] if x,y are both even then return an error! */
2432   if (mp_iseven (&x) == 1 && mp_iseven (&y) == 1) {
2433     res = MP_VAL;
2434     goto __ERR;
2435   }
2436
2437   /* 3. u=x, v=y, A=1, B=0, C=0,D=1 */
2438   if ((res = mp_copy (&x, &u)) != MP_OKAY) {
2439     goto __ERR;
2440   }
2441   if ((res = mp_copy (&y, &v)) != MP_OKAY) {
2442     goto __ERR;
2443   }
2444   mp_set (&A, 1);
2445   mp_set (&D, 1);
2446
2447 top:
2448   /* 4.  while u is even do */
2449   while (mp_iseven (&u) == 1) {
2450     /* 4.1 u = u/2 */
2451     if ((res = mp_div_2 (&u, &u)) != MP_OKAY) {
2452       goto __ERR;
2453     }
2454     /* 4.2 if A or B is odd then */
2455     if (mp_isodd (&A) == 1 || mp_isodd (&B) == 1) {
2456       /* A = (A+y)/2, B = (B-x)/2 */
2457       if ((res = mp_add (&A, &y, &A)) != MP_OKAY) {
2458          goto __ERR;
2459       }
2460       if ((res = mp_sub (&B, &x, &B)) != MP_OKAY) {
2461          goto __ERR;
2462       }
2463     }
2464     /* A = A/2, B = B/2 */
2465     if ((res = mp_div_2 (&A, &A)) != MP_OKAY) {
2466       goto __ERR;
2467     }
2468     if ((res = mp_div_2 (&B, &B)) != MP_OKAY) {
2469       goto __ERR;
2470     }
2471   }
2472
2473   /* 5.  while v is even do */
2474   while (mp_iseven (&v) == 1) {
2475     /* 5.1 v = v/2 */
2476     if ((res = mp_div_2 (&v, &v)) != MP_OKAY) {
2477       goto __ERR;
2478     }
2479     /* 5.2 if C or D is odd then */
2480     if (mp_isodd (&C) == 1 || mp_isodd (&D) == 1) {
2481       /* C = (C+y)/2, D = (D-x)/2 */
2482       if ((res = mp_add (&C, &y, &C)) != MP_OKAY) {
2483          goto __ERR;
2484       }
2485       if ((res = mp_sub (&D, &x, &D)) != MP_OKAY) {
2486          goto __ERR;
2487       }
2488     }
2489     /* C = C/2, D = D/2 */
2490     if ((res = mp_div_2 (&C, &C)) != MP_OKAY) {
2491       goto __ERR;
2492     }
2493     if ((res = mp_div_2 (&D, &D)) != MP_OKAY) {
2494       goto __ERR;
2495     }
2496   }
2497
2498   /* 6.  if u >= v then */
2499   if (mp_cmp (&u, &v) != MP_LT) {
2500     /* u = u - v, A = A - C, B = B - D */
2501     if ((res = mp_sub (&u, &v, &u)) != MP_OKAY) {
2502       goto __ERR;
2503     }
2504
2505     if ((res = mp_sub (&A, &C, &A)) != MP_OKAY) {
2506       goto __ERR;
2507     }
2508
2509     if ((res = mp_sub (&B, &D, &B)) != MP_OKAY) {
2510       goto __ERR;
2511     }
2512   } else {
2513     /* v - v - u, C = C - A, D = D - B */
2514     if ((res = mp_sub (&v, &u, &v)) != MP_OKAY) {
2515       goto __ERR;
2516     }
2517
2518     if ((res = mp_sub (&C, &A, &C)) != MP_OKAY) {
2519       goto __ERR;
2520     }
2521
2522     if ((res = mp_sub (&D, &B, &D)) != MP_OKAY) {
2523       goto __ERR;
2524     }
2525   }
2526
2527   /* if not zero goto step 4 */
2528   if (mp_iszero (&u) == 0)
2529     goto top;
2530
2531   /* now a = C, b = D, gcd == g*v */
2532
2533   /* if v != 1 then there is no inverse */
2534   if (mp_cmp_d (&v, 1) != MP_EQ) {
2535     res = MP_VAL;
2536     goto __ERR;
2537   }
2538
2539   /* if its too low */
2540   while (mp_cmp_d(&C, 0) == MP_LT) {
2541       if ((res = mp_add(&C, b, &C)) != MP_OKAY) {
2542          goto __ERR;
2543       }
2544   }
2545   
2546   /* too big */
2547   while (mp_cmp_mag(&C, b) != MP_LT) {
2548       if ((res = mp_sub(&C, b, &C)) != MP_OKAY) {
2549          goto __ERR;
2550       }
2551   }
2552   
2553   /* C is now the inverse */
2554   mp_exch (&C, c);
2555   res = MP_OKAY;
2556 __ERR:mp_clear_multi (&x, &y, &u, &v, &A, &B, &C, &D, NULL);
2557   return res;
2558 }
2559
2560 /* c = |a| * |b| using Karatsuba Multiplication using 
2561  * three half size multiplications
2562  *
2563  * Let B represent the radix [e.g. 2**DIGIT_BIT] and 
2564  * let n represent half of the number of digits in 
2565  * the min(a,b)
2566  *
2567  * a = a1 * B**n + a0
2568  * b = b1 * B**n + b0
2569  *
2570  * Then, a * b => 
2571    a1b1 * B**2n + ((a1 - a0)(b1 - b0) + a0b0 + a1b1) * B + a0b0
2572  *
2573  * Note that a1b1 and a0b0 are used twice and only need to be 
2574  * computed once.  So in total three half size (half # of 
2575  * digit) multiplications are performed, a0b0, a1b1 and 
2576  * (a1-b1)(a0-b0)
2577  *
2578  * Note that a multiplication of half the digits requires
2579  * 1/4th the number of single precision multiplications so in 
2580  * total after one call 25% of the single precision multiplications 
2581  * are saved.  Note also that the call to mp_mul can end up back 
2582  * in this function if the a0, a1, b0, or b1 are above the threshold.  
2583  * This is known as divide-and-conquer and leads to the famous 
2584  * O(N**lg(3)) or O(N**1.584) work which is asymptotically lower than
2585  * the standard O(N**2) that the baseline/comba methods use.  
2586  * Generally though the overhead of this method doesn't pay off 
2587  * until a certain size (N ~ 80) is reached.
2588  */
2589 int mp_karatsuba_mul (const mp_int * a, const mp_int * b, mp_int * c)
2590 {
2591   mp_int  x0, x1, y0, y1, t1, x0y0, x1y1;
2592   int     B, err;
2593
2594   /* default the return code to an error */
2595   err = MP_MEM;
2596
2597   /* min # of digits */
2598   B = MIN (a->used, b->used);
2599
2600   /* now divide in two */
2601   B = B >> 1;
2602
2603   /* init copy all the temps */
2604   if (mp_init_size (&x0, B) != MP_OKAY)
2605     goto ERR;
2606   if (mp_init_size (&x1, a->used - B) != MP_OKAY)
2607     goto X0;
2608   if (mp_init_size (&y0, B) != MP_OKAY)
2609     goto X1;
2610   if (mp_init_size (&y1, b->used - B) != MP_OKAY)
2611     goto Y0;
2612
2613   /* init temps */
2614   if (mp_init_size (&t1, B * 2) != MP_OKAY)
2615     goto Y1;
2616   if (mp_init_size (&x0y0, B * 2) != MP_OKAY)
2617     goto T1;
2618   if (mp_init_size (&x1y1, B * 2) != MP_OKAY)
2619     goto X0Y0;
2620
2621   /* now shift the digits */
2622   x0.used = y0.used = B;
2623   x1.used = a->used - B;
2624   y1.used = b->used - B;
2625
2626   {
2627     register int x;
2628     register mp_digit *tmpa, *tmpb, *tmpx, *tmpy;
2629
2630     /* we copy the digits directly instead of using higher level functions
2631      * since we also need to shift the digits
2632      */
2633     tmpa = a->dp;
2634     tmpb = b->dp;
2635
2636     tmpx = x0.dp;
2637     tmpy = y0.dp;
2638     for (x = 0; x < B; x++) {
2639       *tmpx++ = *tmpa++;
2640       *tmpy++ = *tmpb++;
2641     }
2642
2643     tmpx = x1.dp;
2644     for (x = B; x < a->used; x++) {
2645       *tmpx++ = *tmpa++;
2646     }
2647
2648     tmpy = y1.dp;
2649     for (x = B; x < b->used; x++) {
2650       *tmpy++ = *tmpb++;
2651     }
2652   }
2653
2654   /* only need to clamp the lower words since by definition the 
2655    * upper words x1/y1 must have a known number of digits
2656    */
2657   mp_clamp (&x0);
2658   mp_clamp (&y0);
2659
2660   /* now calc the products x0y0 and x1y1 */
2661   /* after this x0 is no longer required, free temp [x0==t2]! */
2662   if (mp_mul (&x0, &y0, &x0y0) != MP_OKAY)  
2663     goto X1Y1;          /* x0y0 = x0*y0 */
2664   if (mp_mul (&x1, &y1, &x1y1) != MP_OKAY)
2665     goto X1Y1;          /* x1y1 = x1*y1 */
2666
2667   /* now calc x1-x0 and y1-y0 */
2668   if (mp_sub (&x1, &x0, &t1) != MP_OKAY)
2669     goto X1Y1;          /* t1 = x1 - x0 */
2670   if (mp_sub (&y1, &y0, &x0) != MP_OKAY)
2671     goto X1Y1;          /* t2 = y1 - y0 */
2672   if (mp_mul (&t1, &x0, &t1) != MP_OKAY)
2673     goto X1Y1;          /* t1 = (x1 - x0) * (y1 - y0) */
2674
2675   /* add x0y0 */
2676   if (mp_add (&x0y0, &x1y1, &x0) != MP_OKAY)
2677     goto X1Y1;          /* t2 = x0y0 + x1y1 */
2678   if (mp_sub (&x0, &t1, &t1) != MP_OKAY)
2679     goto X1Y1;          /* t1 = x0y0 + x1y1 - (x1-x0)*(y1-y0) */
2680
2681   /* shift by B */
2682   if (mp_lshd (&t1, B) != MP_OKAY)
2683     goto X1Y1;          /* t1 = (x0y0 + x1y1 - (x1-x0)*(y1-y0))<<B */
2684   if (mp_lshd (&x1y1, B * 2) != MP_OKAY)
2685     goto X1Y1;          /* x1y1 = x1y1 << 2*B */
2686
2687   if (mp_add (&x0y0, &t1, &t1) != MP_OKAY)
2688     goto X1Y1;          /* t1 = x0y0 + t1 */
2689   if (mp_add (&t1, &x1y1, c) != MP_OKAY)
2690     goto X1Y1;          /* t1 = x0y0 + t1 + x1y1 */
2691
2692   /* Algorithm succeeded set the return code to MP_OKAY */
2693   err = MP_OKAY;
2694
2695 X1Y1:mp_clear (&x1y1);
2696 X0Y0:mp_clear (&x0y0);
2697 T1:mp_clear (&t1);
2698 Y1:mp_clear (&y1);
2699 Y0:mp_clear (&y0);
2700 X1:mp_clear (&x1);
2701 X0:mp_clear (&x0);
2702 ERR:
2703   return err;
2704 }
2705
2706 /* Karatsuba squaring, computes b = a*a using three 
2707  * half size squarings
2708  *
2709  * See comments of karatsuba_mul for details.  It 
2710  * is essentially the same algorithm but merely 
2711  * tuned to perform recursive squarings.
2712  */
2713 int mp_karatsuba_sqr (const mp_int * a, mp_int * b)
2714 {
2715   mp_int  x0, x1, t1, t2, x0x0, x1x1;
2716   int     B, err;
2717
2718   err = MP_MEM;
2719
2720   /* min # of digits */
2721   B = a->used;
2722
2723   /* now divide in two */
2724   B = B >> 1;
2725
2726   /* init copy all the temps */
2727   if (mp_init_size (&x0, B) != MP_OKAY)
2728     goto ERR;
2729   if (mp_init_size (&x1, a->used - B) != MP_OKAY)
2730     goto X0;
2731
2732   /* init temps */
2733   if (mp_init_size (&t1, a->used * 2) != MP_OKAY)
2734     goto X1;
2735   if (mp_init_size (&t2, a->used * 2) != MP_OKAY)
2736     goto T1;
2737   if (mp_init_size (&x0x0, B * 2) != MP_OKAY)
2738     goto T2;
2739   if (mp_init_size (&x1x1, (a->used - B) * 2) != MP_OKAY)
2740     goto X0X0;
2741
2742   {
2743     register int x;
2744     register mp_digit *dst, *src;
2745
2746     src = a->dp;
2747
2748     /* now shift the digits */
2749     dst = x0.dp;
2750     for (x = 0; x < B; x++) {
2751       *dst++ = *src++;
2752     }
2753
2754     dst = x1.dp;
2755     for (x = B; x < a->used; x++) {
2756       *dst++ = *src++;
2757     }
2758   }
2759
2760   x0.used = B;
2761   x1.used = a->used - B;
2762
2763   mp_clamp (&x0);
2764
2765   /* now calc the products x0*x0 and x1*x1 */
2766   if (mp_sqr (&x0, &x0x0) != MP_OKAY)
2767     goto X1X1;           /* x0x0 = x0*x0 */
2768   if (mp_sqr (&x1, &x1x1) != MP_OKAY)
2769     goto X1X1;           /* x1x1 = x1*x1 */
2770
2771   /* now calc (x1-x0)**2 */
2772   if (mp_sub (&x1, &x0, &t1) != MP_OKAY)
2773     goto X1X1;           /* t1 = x1 - x0 */
2774   if (mp_sqr (&t1, &t1) != MP_OKAY)
2775     goto X1X1;           /* t1 = (x1 - x0) * (x1 - x0) */
2776
2777   /* add x0y0 */
2778   if (s_mp_add (&x0x0, &x1x1, &t2) != MP_OKAY)
2779     goto X1X1;           /* t2 = x0x0 + x1x1 */
2780   if (mp_sub (&t2, &t1, &t1) != MP_OKAY)
2781     goto X1X1;           /* t1 = x0x0 + x1x1 - (x1-x0)*(x1-x0) */
2782
2783   /* shift by B */
2784   if (mp_lshd (&t1, B) != MP_OKAY)
2785     goto X1X1;           /* t1 = (x0x0 + x1x1 - (x1-x0)*(x1-x0))<<B */
2786   if (mp_lshd (&x1x1, B * 2) != MP_OKAY)
2787     goto X1X1;           /* x1x1 = x1x1 << 2*B */
2788
2789   if (mp_add (&x0x0, &t1, &t1) != MP_OKAY)
2790     goto X1X1;           /* t1 = x0x0 + t1 */
2791   if (mp_add (&t1, &x1x1, b) != MP_OKAY)
2792     goto X1X1;           /* t1 = x0x0 + t1 + x1x1 */
2793
2794   err = MP_OKAY;
2795
2796 X1X1:mp_clear (&x1x1);
2797 X0X0:mp_clear (&x0x0);
2798 T2:mp_clear (&t2);
2799 T1:mp_clear (&t1);
2800 X1:mp_clear (&x1);
2801 X0:mp_clear (&x0);
2802 ERR:
2803   return err;
2804 }
2805
2806 /* computes least common multiple as |a*b|/(a, b) */
2807 int mp_lcm (const mp_int * a, const mp_int * b, mp_int * c)
2808 {
2809   int     res;
2810   mp_int  t1, t2;
2811
2812
2813   if ((res = mp_init_multi (&t1, &t2, NULL)) != MP_OKAY) {
2814     return res;
2815   }
2816
2817   /* t1 = get the GCD of the two inputs */
2818   if ((res = mp_gcd (a, b, &t1)) != MP_OKAY) {
2819     goto __T;
2820   }
2821
2822   /* divide the smallest by the GCD */
2823   if (mp_cmp_mag(a, b) == MP_LT) {
2824      /* store quotient in t2 such that t2 * b is the LCM */
2825      if ((res = mp_div(a, &t1, &t2, NULL)) != MP_OKAY) {
2826         goto __T;
2827      }
2828      res = mp_mul(b, &t2, c);
2829   } else {
2830      /* store quotient in t2 such that t2 * a is the LCM */
2831      if ((res = mp_div(b, &t1, &t2, NULL)) != MP_OKAY) {
2832         goto __T;
2833      }
2834      res = mp_mul(a, &t2, c);
2835   }
2836
2837   /* fix the sign to positive */
2838   c->sign = MP_ZPOS;
2839
2840 __T:
2841   mp_clear_multi (&t1, &t2, NULL);
2842   return res;
2843 }
2844
2845 /* c = a mod b, 0 <= c < b */
2846 int
2847 mp_mod (const mp_int * a, mp_int * b, mp_int * c)
2848 {
2849   mp_int  t;
2850   int     res;
2851
2852   if ((res = mp_init (&t)) != MP_OKAY) {
2853     return res;
2854   }
2855
2856   if ((res = mp_div (a, b, NULL, &t)) != MP_OKAY) {
2857     mp_clear (&t);
2858     return res;
2859   }
2860
2861   if (t.sign != b->sign) {
2862     res = mp_add (b, &t, c);
2863   } else {
2864     res = MP_OKAY;
2865     mp_exch (&t, c);
2866   }
2867
2868   mp_clear (&t);
2869   return res;
2870 }
2871
2872 static int
2873 mp_mod_d (const mp_int * a, mp_digit b, mp_digit * c)
2874 {
2875   return mp_div_d(a, b, NULL, c);
2876 }
2877
2878 /* b = a*2 */
2879 static int mp_mul_2(const mp_int * a, mp_int * b)
2880 {
2881   int     x, res, oldused;
2882
2883   /* grow to accommodate result */
2884   if (b->alloc < a->used + 1) {
2885     if ((res = mp_grow (b, a->used + 1)) != MP_OKAY) {
2886       return res;
2887     }
2888   }
2889
2890   oldused = b->used;
2891   b->used = a->used;
2892
2893   {
2894     register mp_digit r, rr, *tmpa, *tmpb;
2895
2896     /* alias for source */
2897     tmpa = a->dp;
2898
2899     /* alias for dest */
2900     tmpb = b->dp;
2901
2902     /* carry */
2903     r = 0;
2904     for (x = 0; x < a->used; x++) {
2905
2906       /* get what will be the *next* carry bit from the
2907        * MSB of the current digit
2908        */
2909       rr = *tmpa >> ((mp_digit)(DIGIT_BIT - 1));
2910
2911       /* now shift up this digit, add in the carry [from the previous] */
2912       *tmpb++ = ((*tmpa++ << ((mp_digit)1)) | r) & MP_MASK;
2913
2914       /* copy the carry that would be from the source
2915        * digit into the next iteration
2916        */
2917       r = rr;
2918     }
2919
2920     /* new leading digit? */
2921     if (r != 0) {
2922       /* add a MSB which is always 1 at this point */
2923       *tmpb = 1;
2924       ++(b->used);
2925     }
2926
2927     /* now zero any excess digits on the destination
2928      * that we didn't write to
2929      */
2930     tmpb = b->dp + b->used;
2931     for (x = b->used; x < oldused; x++) {
2932       *tmpb++ = 0;
2933     }
2934   }
2935   b->sign = a->sign;
2936   return MP_OKAY;
2937 }
2938
2939 /*
2940  * shifts with subtractions when the result is greater than b.
2941  *
2942  * The method is slightly modified to shift B unconditionally up to just under
2943  * the leading bit of b.  This saves a lot of multiple precision shifting.
2944  */
2945 int mp_montgomery_calc_normalization (mp_int * a, const mp_int * b)
2946 {
2947   int     x, bits, res;
2948
2949   /* how many bits of last digit does b use */
2950   bits = mp_count_bits (b) % DIGIT_BIT;
2951
2952
2953   if (b->used > 1) {
2954      if ((res = mp_2expt (a, (b->used - 1) * DIGIT_BIT + bits - 1)) != MP_OKAY) {
2955         return res;
2956      }
2957   } else {
2958      mp_set(a, 1);
2959      bits = 1;
2960   }
2961
2962
2963   /* now compute C = A * B mod b */
2964   for (x = bits - 1; x < DIGIT_BIT; x++) {
2965     if ((res = mp_mul_2 (a, a)) != MP_OKAY) {
2966       return res;
2967     }
2968     if (mp_cmp_mag (a, b) != MP_LT) {
2969       if ((res = s_mp_sub (a, b, a)) != MP_OKAY) {
2970         return res;
2971       }
2972     }
2973   }
2974
2975   return MP_OKAY;
2976 }
2977
2978 /* computes xR**-1 == x (mod N) via Montgomery Reduction */
2979 int
2980 mp_montgomery_reduce (mp_int * x, const mp_int * n, mp_digit rho)
2981 {
2982   int     ix, res, digs;
2983   mp_digit mu;
2984
2985   /* can the fast reduction [comba] method be used?
2986    *
2987    * Note that unlike in mul you're safely allowed *less*
2988    * than the available columns [255 per default] since carries
2989    * are fixed up in the inner loop.
2990    */
2991   digs = n->used * 2 + 1;
2992   if ((digs < MP_WARRAY) &&
2993       n->used <
2994       (1 << ((CHAR_BIT * sizeof (mp_word)) - (2 * DIGIT_BIT)))) {
2995     return fast_mp_montgomery_reduce (x, n, rho);
2996   }
2997
2998   /* grow the input as required */
2999   if (x->alloc < digs) {
3000     if ((res = mp_grow (x, digs)) != MP_OKAY) {
3001       return res;
3002     }
3003   }
3004   x->used = digs;
3005
3006   for (ix = 0; ix < n->used; ix++) {
3007     /* mu = ai * rho mod b
3008      *
3009      * The value of rho must be precalculated via
3010      * montgomery_setup() such that
3011      * it equals -1/n0 mod b this allows the
3012      * following inner loop to reduce the
3013      * input one digit at a time
3014      */
3015     mu = (mp_digit) (((mp_word)x->dp[ix]) * ((mp_word)rho) & MP_MASK);
3016
3017     /* a = a + mu * m * b**i */
3018     {
3019       register int iy;
3020       register mp_digit *tmpn, *tmpx, u;
3021       register mp_word r;
3022
3023       /* alias for digits of the modulus */
3024       tmpn = n->dp;
3025
3026       /* alias for the digits of x [the input] */
3027       tmpx = x->dp + ix;
3028
3029       /* set the carry to zero */
3030       u = 0;
3031
3032       /* Multiply and add in place */
3033       for (iy = 0; iy < n->used; iy++) {
3034         /* compute product and sum */
3035         r       = ((mp_word)mu) * ((mp_word)*tmpn++) +
3036                   ((mp_word) u) + ((mp_word) * tmpx);
3037
3038         /* get carry */
3039         u       = (mp_digit)(r >> ((mp_word) DIGIT_BIT));
3040
3041         /* fix digit */
3042         *tmpx++ = (mp_digit)(r & ((mp_word) MP_MASK));
3043       }
3044       /* At this point the ix'th digit of x should be zero */
3045
3046
3047       /* propagate carries upwards as required*/
3048       while (u) {
3049         *tmpx   += u;
3050         u        = *tmpx >> DIGIT_BIT;
3051         *tmpx++ &= MP_MASK;
3052       }
3053     }
3054   }
3055
3056   /* at this point the n.used'th least
3057    * significant digits of x are all zero
3058    * which means we can shift x to the
3059    * right by n.used digits and the
3060    * residue is unchanged.
3061    */
3062
3063   /* x = x/b**n.used */
3064   mp_clamp(x);
3065   mp_rshd (x, n->used);
3066
3067   /* if x >= n then x = x - n */
3068   if (mp_cmp_mag (x, n) != MP_LT) {
3069     return s_mp_sub (x, n, x);
3070   }
3071
3072   return MP_OKAY;
3073 }
3074
3075 /* setups the montgomery reduction stuff */
3076 int
3077 mp_montgomery_setup (const mp_int * n, mp_digit * rho)
3078 {
3079   mp_digit x, b;
3080
3081 /* fast inversion mod 2**k
3082  *
3083  * Based on the fact that
3084  *
3085  * XA = 1 (mod 2**n)  =>  (X(2-XA)) A = 1 (mod 2**2n)
3086  *                    =>  2*X*A - X*X*A*A = 1
3087  *                    =>  2*(1) - (1)     = 1
3088  */
3089   b = n->dp[0];
3090
3091   if ((b & 1) == 0) {
3092     return MP_VAL;
3093   }
3094
3095   x = (((b + 2) & 4) << 1) + b; /* here x*a==1 mod 2**4 */
3096   x *= 2 - b * x;               /* here x*a==1 mod 2**8 */
3097   x *= 2 - b * x;               /* here x*a==1 mod 2**16 */
3098   x *= 2 - b * x;               /* here x*a==1 mod 2**32 */
3099
3100   /* rho = -1/m mod b */
3101   *rho = (((mp_word)1 << ((mp_word) DIGIT_BIT)) - x) & MP_MASK;
3102
3103   return MP_OKAY;
3104 }
3105
3106 /* high level multiplication (handles sign) */
3107 int mp_mul (const mp_int * a, const mp_int * b, mp_int * c)
3108 {
3109   int     res, neg;
3110   neg = (a->sign == b->sign) ? MP_ZPOS : MP_NEG;
3111
3112   /* use Karatsuba? */
3113   if (MIN (a->used, b->used) >= KARATSUBA_MUL_CUTOFF) {
3114     res = mp_karatsuba_mul (a, b, c);
3115   } else 
3116   {
3117     /* can we use the fast multiplier?
3118      *
3119      * The fast multiplier can be used if the output will 
3120      * have less than MP_WARRAY digits and the number of 
3121      * digits won't affect carry propagation
3122      */
3123     int     digs = a->used + b->used + 1;
3124
3125     if ((digs < MP_WARRAY) &&
3126         MIN(a->used, b->used) <= 
3127         (1 << ((CHAR_BIT * sizeof (mp_word)) - (2 * DIGIT_BIT)))) {
3128       res = fast_s_mp_mul_digs (a, b, c, digs);
3129     } else 
3130       res = s_mp_mul (a, b, c); /* uses s_mp_mul_digs */
3131   }
3132   c->sign = (c->used > 0) ? neg : MP_ZPOS;
3133   return res;
3134 }
3135
3136 /* d = a * b (mod c) */
3137 int
3138 mp_mulmod (const mp_int * a, const mp_int * b, mp_int * c, mp_int * d)
3139 {
3140   int     res;
3141   mp_int  t;
3142
3143   if ((res = mp_init (&t)) != MP_OKAY) {
3144     return res;
3145   }
3146
3147   if ((res = mp_mul (a, b, &t)) != MP_OKAY) {
3148     mp_clear (&t);
3149     return res;
3150   }
3151   res = mp_mod (&t, c, d);
3152   mp_clear (&t);
3153   return res;
3154 }
3155
3156 /* table of first PRIME_SIZE primes */
3157 static const mp_digit __prime_tab[] = {
3158   0x0002, 0x0003, 0x0005, 0x0007, 0x000B, 0x000D, 0x0011, 0x0013,
3159   0x0017, 0x001D, 0x001F, 0x0025, 0x0029, 0x002B, 0x002F, 0x0035,
3160   0x003B, 0x003D, 0x0043, 0x0047, 0x0049, 0x004F, 0x0053, 0x0059,
3161   0x0061, 0x0065, 0x0067, 0x006B, 0x006D, 0x0071, 0x007F, 0x0083,
3162   0x0089, 0x008B, 0x0095, 0x0097, 0x009D, 0x00A3, 0x00A7, 0x00AD,
3163   0x00B3, 0x00B5, 0x00BF, 0x00C1, 0x00C5, 0x00C7, 0x00D3, 0x00DF,
3164   0x00E3, 0x00E5, 0x00E9, 0x00EF, 0x00F1, 0x00FB, 0x0101, 0x0107,
3165   0x010D, 0x010F, 0x0115, 0x0119, 0x011B, 0x0125, 0x0133, 0x0137,
3166
3167   0x0139, 0x013D, 0x014B, 0x0151, 0x015B, 0x015D, 0x0161, 0x0167,
3168   0x016F, 0x0175, 0x017B, 0x017F, 0x0185, 0x018D, 0x0191, 0x0199,
3169   0x01A3, 0x01A5, 0x01AF, 0x01B1, 0x01B7, 0x01BB, 0x01C1, 0x01C9,
3170   0x01CD, 0x01CF, 0x01D3, 0x01DF, 0x01E7, 0x01EB, 0x01F3, 0x01F7,
3171   0x01FD, 0x0209, 0x020B, 0x021D, 0x0223, 0x022D, 0x0233, 0x0239,
3172   0x023B, 0x0241, 0x024B, 0x0251, 0x0257, 0x0259, 0x025F, 0x0265,
3173   0x0269, 0x026B, 0x0277, 0x0281, 0x0283, 0x0287, 0x028D, 0x0293,
3174   0x0295, 0x02A1, 0x02A5, 0x02AB, 0x02B3, 0x02BD, 0x02C5, 0x02CF,
3175
3176   0x02D7, 0x02DD, 0x02E3, 0x02E7, 0x02EF, 0x02F5, 0x02F9, 0x0301,
3177   0x0305, 0x0313, 0x031D, 0x0329, 0x032B, 0x0335, 0x0337, 0x033B,
3178   0x033D, 0x0347, 0x0355, 0x0359, 0x035B, 0x035F, 0x036D, 0x0371,
3179   0x0373, 0x0377, 0x038B, 0x038F, 0x0397, 0x03A1, 0x03A9, 0x03AD,
3180   0x03B3, 0x03B9, 0x03C7, 0x03CB, 0x03D1, 0x03D7, 0x03DF, 0x03E5,
3181   0x03F1, 0x03F5, 0x03FB, 0x03FD, 0x0407, 0x0409, 0x040F, 0x0419,
3182   0x041B, 0x0425, 0x0427, 0x042D, 0x043F, 0x0443, 0x0445, 0x0449,
3183   0x044F, 0x0455, 0x045D, 0x0463, 0x0469, 0x047F, 0x0481, 0x048B,
3184
3185   0x0493, 0x049D, 0x04A3, 0x04A9, 0x04B1, 0x04BD, 0x04C1, 0x04C7,
3186   0x04CD, 0x04CF, 0x04D5, 0x04E1, 0x04EB, 0x04FD, 0x04FF, 0x0503,
3187   0x0509, 0x050B, 0x0511, 0x0515, 0x0517, 0x051B, 0x0527, 0x0529,
3188   0x052F, 0x0551, 0x0557, 0x055D, 0x0565, 0x0577, 0x0581, 0x058F,
3189   0x0593, 0x0595, 0x0599, 0x059F, 0x05A7, 0x05AB, 0x05AD, 0x05B3,
3190   0x05BF, 0x05C9, 0x05CB, 0x05CF, 0x05D1, 0x05D5, 0x05DB, 0x05E7,
3191   0x05F3, 0x05FB, 0x0607, 0x060D, 0x0611, 0x0617, 0x061F, 0x0623,
3192   0x062B, 0x062F, 0x063D, 0x0641, 0x0647, 0x0649, 0x064D, 0x0653
3193 };
3194
3195 /* determines if an integers is divisible by one 
3196  * of the first PRIME_SIZE primes or not
3197  *
3198  * sets result to 0 if not, 1 if yes
3199  */
3200 static int mp_prime_is_divisible (const mp_int * a, int *result)
3201 {
3202   int     err, ix;
3203   mp_digit res;
3204
3205   /* default to not */
3206   *result = MP_NO;
3207
3208   for (ix = 0; ix < PRIME_SIZE; ix++) {
3209     /* what is a mod __prime_tab[ix] */
3210     if ((err = mp_mod_d (a, __prime_tab[ix], &res)) != MP_OKAY) {
3211       return err;
3212     }
3213
3214     /* is the residue zero? */
3215     if (res == 0) {
3216       *result = MP_YES;
3217       return MP_OKAY;
3218     }
3219   }
3220
3221   return MP_OKAY;
3222 }
3223
3224 /* Miller-Rabin test of "a" to the base of "b" as described in 
3225  * HAC pp. 139 Algorithm 4.24
3226  *
3227  * Sets result to 0 if definitely composite or 1 if probably prime.
3228  * Randomly the chance of error is no more than 1/4 and often 
3229  * very much lower.
3230  */
3231 static int mp_prime_miller_rabin (mp_int * a, const mp_int * b, int *result)
3232 {
3233   mp_int  n1, y, r;
3234   int     s, j, err;
3235
3236   /* default */
3237   *result = MP_NO;
3238
3239   /* ensure b > 1 */
3240   if (mp_cmp_d(b, 1) != MP_GT) {
3241      return MP_VAL;
3242   }     
3243
3244   /* get n1 = a - 1 */
3245   if ((err = mp_init_copy (&n1, a)) != MP_OKAY) {
3246     return err;
3247   }
3248   if ((err = mp_sub_d (&n1, 1, &n1)) != MP_OKAY) {
3249     goto __N1;
3250   }
3251
3252   /* set 2**s * r = n1 */
3253   if ((err = mp_init_copy (&r, &n1)) != MP_OKAY) {
3254     goto __N1;
3255   }
3256
3257   /* count the number of least significant bits
3258    * which are zero
3259    */
3260   s = mp_cnt_lsb(&r);
3261
3262   /* now divide n - 1 by 2**s */
3263   if ((err = mp_div_2d (&r, s, &r, NULL)) != MP_OKAY) {
3264     goto __R;
3265   }
3266
3267   /* compute y = b**r mod a */
3268   if ((err = mp_init (&y)) != MP_OKAY) {
3269     goto __R;
3270   }
3271   if ((err = mp_exptmod (b, &r, a, &y)) != MP_OKAY) {
3272     goto __Y;
3273   }
3274
3275   /* if y != 1 and y != n1 do */
3276   if (mp_cmp_d (&y, 1) != MP_EQ && mp_cmp (&y, &n1) != MP_EQ) {
3277     j = 1;
3278     /* while j <= s-1 and y != n1 */
3279     while ((j <= (s - 1)) && mp_cmp (&y, &n1) != MP_EQ) {
3280       if ((err = mp_sqrmod (&y, a, &y)) != MP_OKAY) {
3281          goto __Y;
3282       }
3283
3284       /* if y == 1 then composite */
3285       if (mp_cmp_d (&y, 1) == MP_EQ) {
3286          goto __Y;
3287       }
3288
3289       ++j;
3290     }
3291
3292     /* if y != n1 then composite */
3293     if (mp_cmp (&y, &n1) != MP_EQ) {
3294       goto __Y;
3295     }
3296   }
3297
3298   /* probably prime now */
3299   *result = MP_YES;
3300 __Y:mp_clear (&y);
3301 __R:mp_clear (&r);
3302 __N1:mp_clear (&n1);
3303   return err;
3304 }
3305
3306 /* performs a variable number of rounds of Miller-Rabin
3307  *
3308  * Probability of error after t rounds is no more than
3309
3310  *
3311  * Sets result to 1 if probably prime, 0 otherwise
3312  */
3313 static int mp_prime_is_prime (mp_int * a, int t, int *result)
3314 {
3315   mp_int  b;
3316   int     ix, err, res;
3317
3318   /* default to no */
3319   *result = MP_NO;
3320
3321   /* valid value of t? */
3322   if (t <= 0 || t > PRIME_SIZE) {
3323     return MP_VAL;
3324   }
3325
3326   /* is the input equal to one of the primes in the table? */
3327   for (ix = 0; ix < PRIME_SIZE; ix++) {
3328       if (mp_cmp_d(a, __prime_tab[ix]) == MP_EQ) {
3329          *result = 1;
3330          return MP_OKAY;
3331       }
3332   }
3333
3334   /* first perform trial division */
3335   if ((err = mp_prime_is_divisible (a, &res)) != MP_OKAY) {
3336     return err;
3337   }
3338
3339   /* return if it was trivially divisible */
3340   if (res == MP_YES) {
3341     return MP_OKAY;
3342   }
3343
3344   /* now perform the miller-rabin rounds */
3345   if ((err = mp_init (&b)) != MP_OKAY) {
3346     return err;
3347   }
3348
3349   for (ix = 0; ix < t; ix++) {
3350     /* set the prime */
3351     mp_set (&b, __prime_tab[ix]);
3352
3353     if ((err = mp_prime_miller_rabin (a, &b, &res)) != MP_OKAY) {
3354       goto __B;
3355     }
3356
3357     if (res == MP_NO) {
3358       goto __B;
3359     }
3360   }
3361
3362   /* passed the test */
3363   *result = MP_YES;
3364 __B:mp_clear (&b);
3365   return err;
3366 }
3367
3368 static const struct {
3369    int k, t;
3370 } sizes[] = {
3371 {   128,    28 },
3372 {   256,    16 },
3373 {   384,    10 },
3374 {   512,     7 },
3375 {   640,     6 },
3376 {   768,     5 },
3377 {   896,     4 },
3378 {  1024,     4 }
3379 };
3380
3381 /* returns # of RM trials required for a given bit size */
3382 int mp_prime_rabin_miller_trials(int size)
3383 {
3384    int x;
3385
3386    for (x = 0; x < (int)(sizeof(sizes)/(sizeof(sizes[0]))); x++) {
3387        if (sizes[x].k == size) {
3388           return sizes[x].t;
3389        } else if (sizes[x].k > size) {
3390           return (x == 0) ? sizes[0].t : sizes[x - 1].t;
3391        }
3392    }
3393    return sizes[x-1].t + 1;
3394 }
3395
3396 /* makes a truly random prime of a given size (bits),
3397  *
3398  * Flags are as follows:
3399  * 
3400  *   LTM_PRIME_BBS      - make prime congruent to 3 mod 4
3401  *   LTM_PRIME_SAFE     - make sure (p-1)/2 is prime as well (implies LTM_PRIME_BBS)
3402  *   LTM_PRIME_2MSB_OFF - make the 2nd highest bit zero
3403  *   LTM_PRIME_2MSB_ON  - make the 2nd highest bit one
3404  *
3405  * You have to supply a callback which fills in a buffer with random bytes.  "dat" is a parameter you can
3406  * have passed to the callback (e.g. a state or something).  This function doesn't use "dat" itself
3407  * so it can be NULL
3408  *
3409  */
3410
3411 /* This is possibly the mother of all prime generation functions, muahahahahaha! */
3412 int mp_prime_random_ex(mp_int *a, int t, int size, int flags, ltm_prime_callback cb, void *dat)
3413 {
3414    unsigned char *tmp, maskAND, maskOR_msb, maskOR_lsb;
3415    int res, err, bsize, maskOR_msb_offset;
3416
3417    /* sanity check the input */
3418    if (size <= 1 || t <= 0) {
3419       return MP_VAL;
3420    }
3421
3422    /* LTM_PRIME_SAFE implies LTM_PRIME_BBS */
3423    if (flags & LTM_PRIME_SAFE) {
3424       flags |= LTM_PRIME_BBS;
3425    }
3426
3427    /* calc the byte size */
3428    bsize = (size>>3)+((size&7)?1:0);
3429
3430    /* we need a buffer of bsize bytes */
3431    tmp = HeapAlloc(GetProcessHeap(), 0, bsize);
3432    if (tmp == NULL) {
3433       return MP_MEM;
3434    }
3435
3436    /* calc the maskAND value for the MSbyte*/
3437    maskAND = ((size&7) == 0) ? 0xFF : (0xFF >> (8 - (size & 7))); 
3438
3439    /* calc the maskOR_msb */
3440    maskOR_msb        = 0;
3441    maskOR_msb_offset = ((size & 7) == 1) ? 1 : 0;
3442    if (flags & LTM_PRIME_2MSB_ON) {
3443       maskOR_msb     |= 1 << ((size - 2) & 7);
3444    } else if (flags & LTM_PRIME_2MSB_OFF) {
3445       maskAND        &= ~(1 << ((size - 2) & 7));
3446    }
3447
3448    /* get the maskOR_lsb */
3449    maskOR_lsb         = 0;
3450    if (flags & LTM_PRIME_BBS) {
3451       maskOR_lsb     |= 3;
3452    }
3453
3454    do {
3455       /* read the bytes */
3456       if (cb(tmp, bsize, dat) != bsize) {
3457          err = MP_VAL;
3458          goto error;
3459       }
3460  
3461       /* work over the MSbyte */
3462       tmp[0]    &= maskAND;
3463       tmp[0]    |= 1 << ((size - 1) & 7);
3464
3465       /* mix in the maskORs */
3466       tmp[maskOR_msb_offset]   |= maskOR_msb;
3467       tmp[bsize-1]             |= maskOR_lsb;
3468
3469       /* read it in */
3470       if ((err = mp_read_unsigned_bin(a, tmp, bsize)) != MP_OKAY)     { goto error; }
3471
3472       /* is it prime? */
3473       if ((err = mp_prime_is_prime(a, t, &res)) != MP_OKAY)           { goto error; }
3474       if (res == MP_NO) {  
3475          continue;
3476       }
3477
3478       if (flags & LTM_PRIME_SAFE) {
3479          /* see if (a-1)/2 is prime */
3480          if ((err = mp_sub_d(a, 1, a)) != MP_OKAY)                    { goto error; }
3481          if ((err = mp_div_2(a, a)) != MP_OKAY)                       { goto error; }
3482  
3483          /* is it prime? */
3484          if ((err = mp_prime_is_prime(a, t, &res)) != MP_OKAY)        { goto error; }
3485       }
3486    } while (res == MP_NO);
3487
3488    if (flags & LTM_PRIME_SAFE) {
3489       /* restore a to the original value */
3490       if ((err = mp_mul_2(a, a)) != MP_OKAY)                          { goto error; }
3491       if ((err = mp_add_d(a, 1, a)) != MP_OKAY)                       { goto error; }
3492    }
3493
3494    err = MP_OKAY;
3495 error:
3496    HeapFree(GetProcessHeap(), 0, tmp);
3497    return err;
3498 }
3499
3500 /* reads an unsigned char array, assumes the msb is stored first [big endian] */
3501 int
3502 mp_read_unsigned_bin (mp_int * a, const unsigned char *b, int c)
3503 {
3504   int     res;
3505
3506   /* make sure there are at least two digits */
3507   if (a->alloc < 2) {
3508      if ((res = mp_grow(a, 2)) != MP_OKAY) {
3509         return res;
3510      }
3511   }
3512
3513   /* zero the int */
3514   mp_zero (a);
3515
3516   /* read the bytes in */
3517   while (c-- > 0) {
3518     if ((res = mp_mul_2d (a, 8, a)) != MP_OKAY) {
3519       return res;
3520     }
3521
3522       a->dp[0] |= *b++;
3523       a->used += 1;
3524   }
3525   mp_clamp (a);
3526   return MP_OKAY;
3527 }
3528
3529 /* reduces x mod m, assumes 0 < x < m**2, mu is 
3530  * precomputed via mp_reduce_setup.
3531  * From HAC pp.604 Algorithm 14.42
3532  */
3533 int
3534 mp_reduce (mp_int * x, const mp_int * m, const mp_int * mu)
3535 {
3536   mp_int  q;
3537   int     res, um = m->used;
3538
3539   /* q = x */
3540   if ((res = mp_init_copy (&q, x)) != MP_OKAY) {
3541     return res;
3542   }
3543
3544   /* q1 = x / b**(k-1)  */
3545   mp_rshd (&q, um - 1);         
3546
3547   /* according to HAC this optimization is ok */
3548   if (((unsigned long) um) > (((mp_digit)1) << (DIGIT_BIT - 1))) {
3549     if ((res = mp_mul (&q, mu, &q)) != MP_OKAY) {
3550       goto CLEANUP;
3551     }
3552   } else {
3553     if ((res = s_mp_mul_high_digs (&q, mu, &q, um - 1)) != MP_OKAY) {
3554       goto CLEANUP;
3555     }
3556   }
3557
3558   /* q3 = q2 / b**(k+1) */
3559   mp_rshd (&q, um + 1);         
3560
3561   /* x = x mod b**(k+1), quick (no division) */
3562   if ((res = mp_mod_2d (x, DIGIT_BIT * (um + 1), x)) != MP_OKAY) {
3563     goto CLEANUP;
3564   }
3565
3566   /* q = q * m mod b**(k+1), quick (no division) */
3567   if ((res = s_mp_mul_digs (&q, m, &q, um + 1)) != MP_OKAY) {
3568     goto CLEANUP;
3569   }
3570
3571   /* x = x - q */
3572   if ((res = mp_sub (x, &q, x)) != MP_OKAY) {
3573     goto CLEANUP;
3574   }
3575
3576   /* If x < 0, add b**(k+1) to it */
3577   if (mp_cmp_d (x, 0) == MP_LT) {
3578     mp_set (&q, 1);
3579     if ((res = mp_lshd (&q, um + 1)) != MP_OKAY)
3580       goto CLEANUP;
3581     if ((res = mp_add (x, &q, x)) != MP_OKAY)
3582       goto CLEANUP;
3583   }
3584
3585   /* Back off if it's too big */
3586   while (mp_cmp (x, m) != MP_LT) {
3587     if ((res = s_mp_sub (x, m, x)) != MP_OKAY) {
3588       goto CLEANUP;
3589     }
3590   }
3591   
3592 CLEANUP:
3593   mp_clear (&q);
3594
3595   return res;
3596 }
3597
3598 /* reduces a modulo n where n is of the form 2**p - d */
3599 int
3600 mp_reduce_2k(mp_int *a, const mp_int *n, mp_digit d)
3601 {
3602    mp_int q;
3603    int    p, res;
3604    
3605    if ((res = mp_init(&q)) != MP_OKAY) {
3606       return res;
3607    }
3608    
3609    p = mp_count_bits(n);    
3610 top:
3611    /* q = a/2**p, a = a mod 2**p */
3612    if ((res = mp_div_2d(a, p, &q, a)) != MP_OKAY) {
3613       goto ERR;
3614    }
3615    
3616    if (d != 1) {
3617       /* q = q * d */
3618       if ((res = mp_mul_d(&q, d, &q)) != MP_OKAY) { 
3619          goto ERR;
3620       }
3621    }
3622    
3623    /* a = a + q */
3624    if ((res = s_mp_add(a, &q, a)) != MP_OKAY) {
3625       goto ERR;
3626    }
3627    
3628    if (mp_cmp_mag(a, n) != MP_LT) {
3629       s_mp_sub(a, n, a);
3630       goto top;
3631    }
3632    
3633 ERR:
3634    mp_clear(&q);
3635    return res;
3636 }
3637
3638 /* determines the setup value */
3639 static int
3640 mp_reduce_2k_setup(const mp_int *a, mp_digit *d)
3641 {
3642    int res, p;
3643    mp_int tmp;
3644    
3645    if ((res = mp_init(&tmp)) != MP_OKAY) {
3646       return res;
3647    }
3648    
3649    p = mp_count_bits(a);
3650    if ((res = mp_2expt(&tmp, p)) != MP_OKAY) {
3651       mp_clear(&tmp);
3652       return res;
3653    }
3654    
3655    if ((res = s_mp_sub(&tmp, a, &tmp)) != MP_OKAY) {
3656       mp_clear(&tmp);
3657       return res;
3658    }
3659    
3660    *d = tmp.dp[0];
3661    mp_clear(&tmp);
3662    return MP_OKAY;
3663 }
3664
3665 /* pre-calculate the value required for Barrett reduction
3666  * For a given modulus "b" it calulates the value required in "a"
3667  */
3668 int mp_reduce_setup (mp_int * a, const mp_int * b)
3669 {
3670   int     res;
3671
3672   if ((res = mp_2expt (a, b->used * 2 * DIGIT_BIT)) != MP_OKAY) {
3673     return res;
3674   }
3675   return mp_div (a, b, a, NULL);
3676 }
3677
3678 /* set to a digit */
3679 void mp_set (mp_int * a, mp_digit b)
3680 {
3681   mp_zero (a);
3682   a->dp[0] = b & MP_MASK;
3683   a->used  = (a->dp[0] != 0) ? 1 : 0;
3684 }
3685
3686 /* set a 32-bit const */
3687 int mp_set_int (mp_int * a, unsigned long b)
3688 {
3689   int     x, res;
3690
3691   mp_zero (a);
3692   
3693   /* set four bits at a time */
3694   for (x = 0; x < 8; x++) {
3695     /* shift the number up four bits */
3696     if ((res = mp_mul_2d (a, 4, a)) != MP_OKAY) {
3697       return res;
3698     }
3699
3700     /* OR in the top four bits of the source */
3701     a->dp[0] |= (b >> 28) & 15;
3702
3703     /* shift the source up to the next four bits */
3704     b <<= 4;
3705
3706     /* ensure that digits are not clamped off */
3707     a->used += 1;
3708   }
3709   mp_clamp (a);
3710   return MP_OKAY;
3711 }
3712
3713 /* shrink a bignum */
3714 int mp_shrink (mp_int * a)
3715 {
3716   mp_digit *tmp;
3717   if (a->alloc != a->used && a->used > 0) {
3718     if ((tmp = HeapReAlloc(GetProcessHeap(), 0, a->dp, sizeof (mp_digit) * a->used)) == NULL) {
3719       return MP_MEM;
3720     }
3721     a->dp    = tmp;
3722     a->alloc = a->used;
3723   }
3724   return MP_OKAY;
3725 }
3726
3727 /* computes b = a*a */
3728 int
3729 mp_sqr (const mp_int * a, mp_int * b)
3730 {
3731   int     res;
3732
3733 if (a->used >= KARATSUBA_SQR_CUTOFF) {
3734     res = mp_karatsuba_sqr (a, b);
3735   } else 
3736   {
3737     /* can we use the fast comba multiplier? */
3738     if ((a->used * 2 + 1) < MP_WARRAY && 
3739          a->used < 
3740          (1 << (sizeof(mp_word) * CHAR_BIT - 2*DIGIT_BIT - 1))) {
3741       res = fast_s_mp_sqr (a, b);
3742     } else
3743       res = s_mp_sqr (a, b);
3744   }
3745   b->sign = MP_ZPOS;
3746   return res;
3747 }
3748
3749 /* c = a * a (mod b) */
3750 int
3751 mp_sqrmod (const mp_int * a, mp_int * b, mp_int * c)
3752 {
3753   int     res;
3754   mp_int  t;
3755
3756   if ((res = mp_init (&t)) != MP_OKAY) {
3757     return res;
3758   }
3759
3760   if ((res = mp_sqr (a, &t)) != MP_OKAY) {
3761     mp_clear (&t);
3762     return res;
3763   }
3764   res = mp_mod (&t, b, c);
3765   mp_clear (&t);
3766   return res;
3767 }
3768
3769 /* high level subtraction (handles signs) */
3770 int
3771 mp_sub (mp_int * a, mp_int * b, mp_int * c)
3772 {
3773   int     sa, sb, res;
3774
3775   sa = a->sign;
3776   sb = b->sign;
3777
3778   if (sa != sb) {
3779     /* subtract a negative from a positive, OR */
3780     /* subtract a positive from a negative. */
3781     /* In either case, ADD their magnitudes, */
3782     /* and use the sign of the first number. */
3783     c->sign = sa;
3784     res = s_mp_add (a, b, c);
3785   } else {
3786     /* subtract a positive from a positive, OR */
3787     /* subtract a negative from a negative. */
3788     /* First, take the difference between their */
3789     /* magnitudes, then... */
3790     if (mp_cmp_mag (a, b) != MP_LT) {
3791       /* Copy the sign from the first */
3792       c->sign = sa;
3793       /* The first has a larger or equal magnitude */
3794       res = s_mp_sub (a, b, c);
3795     } else {
3796       /* The result has the *opposite* sign from */
3797       /* the first number. */
3798       c->sign = (sa == MP_ZPOS) ? MP_NEG : MP_ZPOS;
3799       /* The second has a larger magnitude */
3800       res = s_mp_sub (b, a, c);
3801     }
3802   }
3803   return res;
3804 }
3805
3806 /* single digit subtraction */
3807 int
3808 mp_sub_d (mp_int * a, mp_digit b, mp_int * c)
3809 {
3810   mp_digit *tmpa, *tmpc, mu;
3811   int       res, ix, oldused;
3812
3813   /* grow c as required */
3814   if (c->alloc < a->used + 1) {
3815      if ((res = mp_grow(c, a->used + 1)) != MP_OKAY) {
3816         return res;
3817      }
3818   }
3819
3820   /* if a is negative just do an unsigned
3821    * addition [with fudged signs]
3822    */
3823   if (a->sign == MP_NEG) {
3824      a->sign = MP_ZPOS;
3825      res     = mp_add_d(a, b, c);
3826      a->sign = c->sign = MP_NEG;
3827      return res;
3828   }
3829
3830   /* setup regs */
3831   oldused = c->used;
3832   tmpa    = a->dp;
3833   tmpc    = c->dp;
3834
3835   /* if a <= b simply fix the single digit */
3836   if ((a->used == 1 && a->dp[0] <= b) || a->used == 0) {
3837      if (a->used == 1) {
3838         *tmpc++ = b - *tmpa;
3839      } else {
3840         *tmpc++ = b;
3841      }
3842      ix      = 1;
3843
3844      /* negative/1digit */
3845      c->sign = MP_NEG;
3846      c->used = 1;
3847   } else {
3848      /* positive/size */
3849      c->sign = MP_ZPOS;
3850      c->used = a->used;
3851
3852      /* subtract first digit */
3853      *tmpc    = *tmpa++ - b;
3854      mu       = *tmpc >> (sizeof(mp_digit) * CHAR_BIT - 1);
3855      *tmpc++ &= MP_MASK;
3856
3857      /* handle rest of the digits */
3858      for (ix = 1; ix < a->used; ix++) {
3859         *tmpc    = *tmpa++ - mu;
3860         mu       = *tmpc >> (sizeof(mp_digit) * CHAR_BIT - 1);
3861         *tmpc++ &= MP_MASK;
3862      }
3863   }
3864
3865   /* zero excess digits */
3866   while (ix++ < oldused) {
3867      *tmpc++ = 0;
3868   }
3869   mp_clamp(c);
3870   return MP_OKAY;
3871 }
3872
3873 /* store in unsigned [big endian] format */
3874 int
3875 mp_to_unsigned_bin (const mp_int * a, unsigned char *b)
3876 {
3877   int     x, res;
3878   mp_int  t;
3879
3880   if ((res = mp_init_copy (&t, a)) != MP_OKAY) {
3881     return res;
3882   }
3883
3884   x = 0;
3885   while (mp_iszero (&t) == 0) {
3886     b[x++] = (unsigned char) (t.dp[0] & 255);
3887     if ((res = mp_div_2d (&t, 8, &t, NULL)) != MP_OKAY) {
3888       mp_clear (&t);
3889       return res;
3890     }
3891   }
3892   bn_reverse (b, x);
3893   mp_clear (&t);
3894   return MP_OKAY;
3895 }
3896
3897 /* get the size for an unsigned equivalent */
3898 int
3899 mp_unsigned_bin_size (const mp_int * a)
3900 {
3901   int     size = mp_count_bits (a);
3902   return (size / 8 + ((size & 7) != 0 ? 1 : 0));
3903 }
3904
3905 /* reverse an array, used for radix code */
3906 static void
3907 bn_reverse (unsigned char *s, int len)
3908 {
3909   int     ix, iy;
3910   unsigned char t;
3911
3912   ix = 0;
3913   iy = len - 1;
3914   while (ix < iy) {
3915     t     = s[ix];
3916     s[ix] = s[iy];
3917     s[iy] = t;
3918     ++ix;
3919     --iy;
3920   }
3921 }
3922
3923 /* low level addition, based on HAC pp.594, Algorithm 14.7 */
3924 static int
3925 s_mp_add (mp_int * a, mp_int * b, mp_int * c)
3926 {
3927   mp_int *x;
3928   int     olduse, res, min, max;
3929
3930   /* find sizes, we let |a| <= |b| which means we have to sort
3931    * them.  "x" will point to the input with the most digits
3932    */
3933   if (a->used > b->used) {
3934     min = b->used;
3935     max = a->used;
3936     x = a;
3937   } else {
3938     min = a->used;
3939     max = b->used;
3940     x = b;
3941   }
3942
3943   /* init result */
3944   if (c->alloc < max + 1) {
3945     if ((res = mp_grow (c, max + 1)) != MP_OKAY) {
3946       return res;
3947     }
3948   }
3949
3950   /* get old used digit count and set new one */
3951   olduse = c->used;
3952   c->used = max + 1;
3953
3954   {
3955     register mp_digit u, *tmpa, *tmpb, *tmpc;
3956     register int i;
3957
3958     /* alias for digit pointers */
3959
3960     /* first input */
3961     tmpa = a->dp;
3962
3963     /* second input */
3964     tmpb = b->dp;
3965
3966     /* destination */
3967     tmpc = c->dp;
3968
3969     /* zero the carry */
3970     u = 0;
3971     for (i = 0; i < min; i++) {
3972       /* Compute the sum at one digit, T[i] = A[i] + B[i] + U */
3973       *tmpc = *tmpa++ + *tmpb++ + u;
3974
3975       /* U = carry bit of T[i] */
3976       u = *tmpc >> ((mp_digit)DIGIT_BIT);
3977
3978       /* take away carry bit from T[i] */
3979       *tmpc++ &= MP_MASK;
3980     }
3981
3982     /* now copy higher words if any, that is in A+B 
3983      * if A or B has more digits add those in 
3984      */
3985     if (min != max) {
3986       for (; i < max; i++) {
3987         /* T[i] = X[i] + U */
3988         *tmpc = x->dp[i] + u;
3989
3990         /* U = carry bit of T[i] */
3991         u = *tmpc >> ((mp_digit)DIGIT_BIT);
3992
3993         /* take away carry bit from T[i] */
3994         *tmpc++ &= MP_MASK;
3995       }
3996     }
3997
3998     /* add carry */
3999     *tmpc++ = u;
4000
4001     /* clear digits above oldused */
4002     for (i = c->used; i < olduse; i++) {
4003       *tmpc++ = 0;
4004     }
4005   }
4006
4007   mp_clamp (c);
4008   return MP_OKAY;
4009 }
4010
4011 static int s_mp_exptmod (const mp_int * G, const mp_int * X, mp_int * P, mp_int * Y)
4012 {
4013   mp_int  M[256], res, mu;
4014   mp_digit buf;
4015   int     err, bitbuf, bitcpy, bitcnt, mode, digidx, x, y, winsize;
4016
4017   /* find window size */
4018   x = mp_count_bits (X);
4019   if (x <= 7) {
4020     winsize = 2;
4021   } else if (x <= 36) {
4022     winsize = 3;
4023   } else if (x <= 140) {
4024     winsize = 4;
4025   } else if (x <= 450) {
4026     winsize = 5;
4027   } else if (x <= 1303) {
4028     winsize = 6;
4029   } else if (x <= 3529) {
4030     winsize = 7;
4031   } else {
4032     winsize = 8;
4033   }
4034
4035   /* init M array */
4036   /* init first cell */
4037   if ((err = mp_init(&M[1])) != MP_OKAY) {
4038      return err; 
4039   }
4040
4041   /* now init the second half of the array */
4042   for (x = 1<<(winsize-1); x < (1 << winsize); x++) {
4043     if ((err = mp_init(&M[x])) != MP_OKAY) {
4044       for (y = 1<<(winsize-1); y < x; y++) {
4045         mp_clear (&M[y]);
4046       }
4047       mp_clear(&M[1]);
4048       return err;
4049     }
4050   }
4051
4052   /* create mu, used for Barrett reduction */
4053   if ((err = mp_init (&mu)) != MP_OKAY) {
4054     goto __M;
4055   }
4056   if ((err = mp_reduce_setup (&mu, P)) != MP_OKAY) {
4057     goto __MU;
4058   }
4059
4060   /* create M table
4061    *
4062    * The M table contains powers of the base, 
4063    * e.g. M[x] = G**x mod P
4064    *
4065    * The first half of the table is not 
4066    * computed though accept for M[0] and M[1]
4067    */
4068   if ((err = mp_mod (G, P, &M[1])) != MP_OKAY) {
4069     goto __MU;
4070   }
4071
4072   /* compute the value at M[1<<(winsize-1)] by squaring 
4073    * M[1] (winsize-1) times 
4074    */
4075   if ((err = mp_copy (&M[1], &M[1 << (winsize - 1)])) != MP_OKAY) {
4076     goto __MU;
4077   }
4078
4079   for (x = 0; x < (winsize - 1); x++) {
4080     if ((err = mp_sqr (&M[1 << (winsize - 1)], 
4081                        &M[1 << (winsize - 1)])) != MP_OKAY) {
4082       goto __MU;
4083     }
4084     if ((err = mp_reduce (&M[1 << (winsize - 1)], P, &mu)) != MP_OKAY) {
4085       goto __MU;
4086     }
4087   }
4088
4089   /* create upper table, that is M[x] = M[x-1] * M[1] (mod P)
4090    * for x = (2**(winsize - 1) + 1) to (2**winsize - 1)
4091    */
4092   for (x = (1 << (winsize - 1)) + 1; x < (1 << winsize); x++) {
4093     if ((err = mp_mul (&M[x - 1], &M[1], &M[x])) != MP_OKAY) {
4094       goto __MU;
4095     }
4096     if ((err = mp_reduce (&M[x], P, &mu)) != MP_OKAY) {
4097       goto __MU;
4098     }
4099   }
4100
4101   /* setup result */
4102   if ((err = mp_init (&res)) != MP_OKAY) {
4103     goto __MU;
4104   }
4105   mp_set (&res, 1);
4106
4107   /* set initial mode and bit cnt */
4108   mode   = 0;
4109   bitcnt = 1;
4110   buf    = 0;
4111   digidx = X->used - 1;
4112   bitcpy = 0;
4113   bitbuf = 0;
4114
4115   for (;;) {
4116     /* grab next digit as required */
4117     if (--bitcnt == 0) {
4118       /* if digidx == -1 we are out of digits */
4119       if (digidx == -1) {
4120         break;
4121       }
4122       /* read next digit and reset the bitcnt */
4123       buf    = X->dp[digidx--];
4124       bitcnt = DIGIT_BIT;
4125     }
4126
4127     /* grab the next msb from the exponent */
4128     y     = (buf >> (mp_digit)(DIGIT_BIT - 1)) & 1;
4129     buf <<= (mp_digit)1;
4130
4131     /* if the bit is zero and mode == 0 then we ignore it
4132      * These represent the leading zero bits before the first 1 bit
4133      * in the exponent.  Technically this opt is not required but it
4134      * does lower the # of trivial squaring/reductions used
4135      */
4136     if (mode == 0 && y == 0) {
4137       continue;
4138     }
4139
4140     /* if the bit is zero and mode == 1 then we square */
4141     if (mode == 1 && y == 0) {
4142       if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
4143         goto __RES;
4144       }
4145       if ((err = mp_reduce (&res, P, &mu)) != MP_OKAY) {
4146         goto __RES;
4147       }
4148       continue;
4149     }
4150
4151     /* else we add it to the window */
4152     bitbuf |= (y << (winsize - ++bitcpy));
4153     mode    = 2;
4154
4155     if (bitcpy == winsize) {
4156       /* ok window is filled so square as required and multiply  */
4157       /* square first */
4158       for (x = 0; x < winsize; x++) {
4159         if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
4160           goto __RES;
4161         }
4162         if ((err = mp_reduce (&res, P, &mu)) != MP_OKAY) {
4163           goto __RES;
4164         }
4165       }
4166
4167       /* then multiply */
4168       if ((err = mp_mul (&res, &M[bitbuf], &res)) != MP_OKAY) {
4169         goto __RES;
4170       }
4171       if ((err = mp_reduce (&res, P, &mu)) != MP_OKAY) {
4172         goto __RES;
4173       }
4174
4175       /* empty window and reset */
4176       bitcpy = 0;
4177       bitbuf = 0;
4178       mode   = 1;
4179     }
4180   }
4181
4182   /* if bits remain then square/multiply */
4183   if (mode == 2 && bitcpy > 0) {
4184     /* square then multiply if the bit is set */
4185     for (x = 0; x < bitcpy; x++) {
4186       if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
4187         goto __RES;
4188       }
4189       if ((err = mp_reduce (&res, P, &mu)) != MP_OKAY) {
4190         goto __RES;
4191       }
4192
4193       bitbuf <<= 1;
4194       if ((bitbuf & (1 << winsize)) != 0) {
4195         /* then multiply */
4196         if ((err = mp_mul (&res, &M[1], &res)) != MP_OKAY) {
4197           goto __RES;
4198         }
4199         if ((err = mp_reduce (&res, P, &mu)) != MP_OKAY) {
4200           goto __RES;
4201         }
4202       }
4203     }
4204   }
4205
4206   mp_exch (&res, Y);
4207   err = MP_OKAY;
4208 __RES:mp_clear (&res);
4209 __MU:mp_clear (&mu);
4210 __M:
4211   mp_clear(&M[1]);
4212   for (x = 1<<(winsize-1); x < (1 << winsize); x++) {
4213     mp_clear (&M[x]);
4214   }
4215   return err;
4216 }
4217
4218 /* multiplies |a| * |b| and only computes up to digs digits of result
4219  * HAC pp. 595, Algorithm 14.12  Modified so you can control how 
4220  * many digits of output are created.
4221  */
4222 static int
4223 s_mp_mul_digs (const mp_int * a, const mp_int * b, mp_int * c, int digs)
4224 {
4225   mp_int  t;
4226   int     res, pa, pb, ix, iy;
4227   mp_digit u;
4228   mp_word r;
4229   mp_digit tmpx, *tmpt, *tmpy;
4230
4231   /* can we use the fast multiplier? */
4232   if (((digs) < MP_WARRAY) &&
4233       MIN (a->used, b->used) < 
4234           (1 << ((CHAR_BIT * sizeof (mp_word)) - (2 * DIGIT_BIT)))) {
4235     return fast_s_mp_mul_digs (a, b, c, digs);
4236   }
4237
4238   if ((res = mp_init_size (&t, digs)) != MP_OKAY) {
4239     return res;
4240   }
4241   t.used = digs;
4242
4243   /* compute the digits of the product directly */
4244   pa = a->used;
4245   for (ix = 0; ix < pa; ix++) {
4246     /* set the carry to zero */
4247     u = 0;
4248
4249     /* limit ourselves to making digs digits of output */
4250     pb = MIN (b->used, digs - ix);
4251
4252     /* setup some aliases */
4253     /* copy of the digit from a used within the nested loop */
4254     tmpx = a->dp[ix];
4255     
4256     /* an alias for the destination shifted ix places */
4257     tmpt = t.dp + ix;
4258     
4259     /* an alias for the digits of b */
4260     tmpy = b->dp;
4261
4262     /* compute the columns of the output and propagate the carry */
4263     for (iy = 0; iy < pb; iy++) {
4264       /* compute the column as a mp_word */
4265       r       = ((mp_word)*tmpt) +
4266                 ((mp_word)tmpx) * ((mp_word)*tmpy++) +
4267                 ((mp_word) u);
4268
4269       /* the new column is the lower part of the result */
4270       *tmpt++ = (mp_digit) (r & ((mp_word) MP_MASK));
4271
4272       /* get the carry word from the result */
4273       u       = (mp_digit) (r >> ((mp_word) DIGIT_BIT));
4274     }
4275     /* set carry if it is placed below digs */
4276     if (ix + iy < digs) {
4277       *tmpt = u;
4278     }
4279   }
4280
4281   mp_clamp (&t);
4282   mp_exch (&t, c);
4283
4284   mp_clear (&t);
4285   return MP_OKAY;
4286 }
4287
4288 /* multiplies |a| * |b| and does not compute the lower digs digits
4289  * [meant to get the higher part of the product]
4290  */
4291 static int
4292 s_mp_mul_high_digs (const mp_int * a, const mp_int * b, mp_int * c, int digs)
4293 {
4294   mp_int  t;
4295   int     res, pa, pb, ix, iy;
4296   mp_digit u;
4297   mp_word r;
4298   mp_digit tmpx, *tmpt, *tmpy;
4299
4300   /* can we use the fast multiplier? */
4301   if (((a->used + b->used + 1) < MP_WARRAY)
4302       && MIN (a->used, b->used) < (1 << ((CHAR_BIT * sizeof (mp_word)) - (2 * DIGIT_BIT)))) {
4303     return fast_s_mp_mul_high_digs (a, b, c, digs);
4304   }
4305
4306   if ((res = mp_init_size (&t, a->used + b->used + 1)) != MP_OKAY) {
4307     return res;
4308   }
4309   t.used = a->used + b->used + 1;
4310
4311   pa = a->used;
4312   pb = b->used;
4313   for (ix = 0; ix < pa; ix++) {
4314     /* clear the carry */
4315     u = 0;
4316
4317     /* left hand side of A[ix] * B[iy] */
4318     tmpx = a->dp[ix];
4319
4320     /* alias to the address of where the digits will be stored */
4321     tmpt = &(t.dp[digs]);
4322
4323     /* alias for where to read the right hand side from */
4324     tmpy = b->dp + (digs - ix);
4325
4326     for (iy = digs - ix; iy < pb; iy++) {
4327       /* calculate the double precision result */
4328       r       = ((mp_word)*tmpt) +
4329                 ((mp_word)tmpx) * ((mp_word)*tmpy++) +
4330                 ((mp_word) u);
4331
4332       /* get the lower part */
4333       *tmpt++ = (mp_digit) (r & ((mp_word) MP_MASK));
4334
4335       /* carry the carry */
4336       u       = (mp_digit) (r >> ((mp_word) DIGIT_BIT));
4337     }
4338     *tmpt = u;
4339   }
4340   mp_clamp (&t);
4341   mp_exch (&t, c);
4342   mp_clear (&t);
4343   return MP_OKAY;
4344 }
4345
4346 /* low level squaring, b = a*a, HAC pp.596-597, Algorithm 14.16 */
4347 static int
4348 s_mp_sqr (const mp_int * a, mp_int * b)
4349 {
4350   mp_int  t;
4351   int     res, ix, iy, pa;
4352   mp_word r;
4353   mp_digit u, tmpx, *tmpt;
4354
4355   pa = a->used;
4356   if ((res = mp_init_size (&t, 2*pa + 1)) != MP_OKAY) {
4357     return res;
4358   }
4359
4360   /* default used is maximum possible size */
4361   t.used = 2*pa + 1;
4362
4363   for (ix = 0; ix < pa; ix++) {
4364     /* first calculate the digit at 2*ix */
4365     /* calculate double precision result */
4366     r = ((mp_word) t.dp[2*ix]) +
4367         ((mp_word)a->dp[ix])*((mp_word)a->dp[ix]);
4368
4369     /* store lower part in result */
4370     t.dp[ix+ix] = (mp_digit) (r & ((mp_word) MP_MASK));
4371
4372     /* get the carry */
4373     u           = (mp_digit)(r >> ((mp_word) DIGIT_BIT));
4374
4375     /* left hand side of A[ix] * A[iy] */
4376     tmpx        = a->dp[ix];
4377
4378     /* alias for where to store the results */
4379     tmpt        = t.dp + (2*ix + 1);
4380     
4381     for (iy = ix + 1; iy < pa; iy++) {
4382       /* first calculate the product */
4383       r       = ((mp_word)tmpx) * ((mp_word)a->dp[iy]);
4384
4385       /* now calculate the double precision result, note we use
4386        * addition instead of *2 since it's easier to optimize
4387        */
4388       r       = ((mp_word) *tmpt) + r + r + ((mp_word) u);
4389
4390       /* store lower part */
4391       *tmpt++ = (mp_digit) (r & ((mp_word) MP_MASK));
4392
4393       /* get carry */
4394       u       = (mp_digit)(r >> ((mp_word) DIGIT_BIT));
4395     }
4396     /* propagate upwards */
4397     while (u != 0) {
4398       r       = ((mp_word) *tmpt) + ((mp_word) u);
4399       *tmpt++ = (mp_digit) (r & ((mp_word) MP_MASK));
4400       u       = (mp_digit)(r >> ((mp_word) DIGIT_BIT));
4401     }
4402   }
4403
4404   mp_clamp (&t);
4405   mp_exch (&t, b);
4406   mp_clear (&t);
4407   return MP_OKAY;
4408 }
4409
4410 /* low level subtraction (assumes |a| > |b|), HAC pp.595 Algorithm 14.9 */
4411 int
4412 s_mp_sub (const mp_int * a, const mp_int * b, mp_int * c)
4413 {
4414   int     olduse, res, min, max;
4415
4416   /* find sizes */
4417   min = b->used;
4418   max = a->used;
4419
4420   /* init result */
4421   if (c->alloc < max) {
4422     if ((res = mp_grow (c, max)) != MP_OKAY) {
4423       return res;
4424     }
4425   }
4426   olduse = c->used;
4427   c->used = max;
4428
4429   {
4430     register mp_digit u, *tmpa, *tmpb, *tmpc;
4431     register int i;
4432
4433     /* alias for digit pointers */
4434     tmpa = a->dp;
4435     tmpb = b->dp;
4436     tmpc = c->dp;
4437
4438     /* set carry to zero */
4439     u = 0;
4440     for (i = 0; i < min; i++) {
4441       /* T[i] = A[i] - B[i] - U */
4442       *tmpc = *tmpa++ - *tmpb++ - u;
4443
4444       /* U = carry bit of T[i]
4445        * Note this saves performing an AND operation since
4446        * if a carry does occur it will propagate all the way to the
4447        * MSB.  As a result a single shift is enough to get the carry
4448        */
4449       u = *tmpc >> ((mp_digit)(CHAR_BIT * sizeof (mp_digit) - 1));
4450
4451       /* Clear carry from T[i] */
4452       *tmpc++ &= MP_MASK;
4453     }
4454
4455     /* now copy higher words if any, e.g. if A has more digits than B  */
4456     for (; i < max; i++) {
4457       /* T[i] = A[i] - U */
4458       *tmpc = *tmpa++ - u;
4459
4460       /* U = carry bit of T[i] */
4461       u = *tmpc >> ((mp_digit)(CHAR_BIT * sizeof (mp_digit) - 1));
4462
4463       /* Clear carry from T[i] */
4464       *tmpc++ &= MP_MASK;
4465     }
4466
4467     /* clear digits above used (since we may not have grown result above) */
4468     for (i = c->used; i < olduse; i++) {
4469       *tmpc++ = 0;
4470     }
4471   }
4472
4473   mp_clamp (c);
4474   return MP_OKAY;
4475 }