]> www.ginac.de Git - cln.git/blob - src/modinteger/cl_MI_montgom.h
Replace unused macro with cl_unused.
[cln.git] / src / modinteger / cl_MI_montgom.h
1 // m > 1 odd, Montgomery representation
2
3 namespace cln {
4
5 // We use Montgomery's modular multiplication trick
6 // [Peter L. Montgomery: Modular multiplication without trial division,
7 //  Mathematics of Computation 44 (1985), 519-521.]
8 //
9 // Assume we want to compute modulo M, M odd. V and N will be chosen
10 // so that V*N==1 mod M and that (a,b) --> a*b*V mod M can be more easily
11 // computed than (a,b) --> a*b mod M. Then, we have a ring isomorphism
12 //   (Z/MZ, +, * mod M)  \isomorph  (Z/MZ, +, (a,b) --> a*b*V mod M)
13 //   x mod M             -------->  x*N mod M
14 // It is thus preferrable to use x*N mod M as a "representation" of x mod M,
15 // especially for computations which involve at least several multiplications.
16 //
17 // The precise algorithm to compute a*b*V mod M, given a and b, and the choice
18 // of N and V depend on M and on the hardware. The general idea is this:
19 // Choose N = 2^n, so that division by N is easy. Recall that V == N^-1 mod M.
20 // 1. Given a and b as m-bit numbers (M <= 2^m), compute a*b in full
21 //    precision.
22 // 2. Write a*b = c*N+d, i.e. split it into components c and d.
23 // 3. Now a*b*V = c*N*V+d*V == c+d*V mod M.
24 // 4. Instead of computing d*V mod M
25 //    a. by full multiplication and then division mod M, or
26 //    b. by left shifts: repeated application of
27 //          x := 2*x+(0 or 1); if (x >= M) { x := x-M; }
28 //    we compute
29 //    c. by right shifts (recall that d*V == d*2^-n mod M): repeated application
30 //       of   if (x odd) { x := (x+M)/2; } else { x := x/2; }
31 // Usually one will choose N = 2^m, so that c and d have both m bits.
32 // Several variations are possible: In step 4 one can implement the right
33 // shifts in hardware. Or (for example when N = 2^160 and working on a
34 // 32-bit machine) one can do 32 shift steps at the same time:
35 // Choose M' == M^-1 mod 2^32 and compute n/32 times
36 //       x := (x - ((x mod 2^32) * M' mod 2^32) * M) / 2^32.
37
38 // Here, we choose to use Montgomery representation only if |V| can be chosen
39 // to be very small, and in that case we compute d*V mod M using standard
40 // multiplication and division.
41 // So we choose N = 2^n with 0 < n <= m (the larger n, the better) and hope
42 // that it will yield V with |V| < 2^k. We thus replace a division of
43 // 2m bits by m bits (cost: approx. m^2) by a multiplication of n bits with
44 // k bits (cost: approx. n*k) and a division of max(2m-n,n+k) bits by m bits
45 // (cost: approx. (max(2m-n,n+k)-m)*m). Of course, U*M+V*N=1 implies (roughly)
46 // n+k >= m. It is worth it when
47 //         m^2 > n*k + (n+k-m)*m   and   m^2 > n*k + (m-n)*m
48 // <==>  3*m^2 > (n+m)*(k+m)       and     m > k
49 // <==   3/2*m > k+m                (assume n to be near m)
50 // <==>    m/2 > k .
51 //
52 // How to find N and V:
53 // U*M+V*N=1 means that U = (M mod 2^n)^-1 = U_m mod 2^n, where
54 // U_m := (M mod 2^m)^-1 (2-adic reciprocal). |V| < 2^(m/2) is more or less
55 // equivalent to |V*N| < 2^(n+m/2) <==> |U|*M < 2^(n+m/2) <==> |U| < n-m/2
56 // <==> the most significant m/2 bits of |U| are all equal. So we search
57 // for a bit string of at least m/2+1 equal bits in U_m, which has m bits.
58 // Very easy: take the middle bit of U_m, look how many bits adjacent to it
59 // (at the left and at the right) have the same value. Say these are the
60 // bits n-1,...,n-l. (Choose n and l as large as possible. k = m-l + O(1).)
61 // If l < m/2, forget it. Else fix n and compute V = (1-U*M)/2^n.
62 //
63 // It is now clear that only very few moduli M will allow such a good
64 // choice of N and V, but in these cases the Montgomery multiplication
65 // reduces the multiplication complexity by a large constant factor.
66
67
68 class cl_heap_modint_ring_montgom : public cl_heap_modint_ring {
69         SUBCLASS_cl_heap_modint_ring()
70 public:
71         // Constructor.
72         cl_heap_modint_ring_montgom (const cl_I& M, uintL m, uintL n, const cl_I& V);
73         // Destructor.
74         ~cl_heap_modint_ring_montgom () {}
75         // Additional information.
76         uintL m; // M = 2^m
77         uintL n; // N = 2^n, n <= m
78         cl_I V;
79 };
80
81 static void cl_modint_ring_montgom_destructor (cl_heap* pointer)
82 {
83         (*(cl_heap_modint_ring_montgom*)pointer).~cl_heap_modint_ring_montgom();
84 }
85
86 cl_class cl_class_modint_ring_montgom = {
87         cl_modint_ring_montgom_destructor,
88         cl_class_flags_modint_ring
89 };
90
91 // Assuming 0 <= x < 2^(2m), return  V*x mod M.
92 static inline const cl_I montgom_redc (cl_heap_modint_ring_montgom* R, const cl_I& x)
93 {
94         return mod((x >> R->n) + (R->V * ldb(x,cl_byte(R->n,0))), R->modulus);
95 }
96
97 static const _cl_MI montgom_canonhom (cl_heap_modint_ring* _R, const cl_I& x)
98 {
99         var cl_heap_modint_ring_montgom* R = (cl_heap_modint_ring_montgom*)_R;
100         return _cl_MI(R, mod(x << R->n, R->modulus));
101 }
102
103 static const cl_I montgom_retract (cl_heap_modint_ring* _R, const _cl_MI& x)
104 {
105         var cl_heap_modint_ring_montgom* R = (cl_heap_modint_ring_montgom*)_R;
106         return montgom_redc(R,x.rep);
107 }
108
109 static const _cl_MI montgom_one (cl_heap_modint_ring* _R)
110 {
111         var cl_heap_modint_ring_montgom* R = (cl_heap_modint_ring_montgom*)_R;
112         var cl_I zr = (cl_I)1 << R->n;
113         return _cl_MI(R, R->n == R->m ? zr - R->modulus : zr);
114 }
115
116 static const _cl_MI montgom_mul (cl_heap_modint_ring* _R, const _cl_MI& x, const _cl_MI& y)
117 {
118         var cl_heap_modint_ring_montgom* R = (cl_heap_modint_ring_montgom*)_R;
119         return _cl_MI(R, montgom_redc(R,x.rep * y.rep));
120 }
121
122 static const _cl_MI montgom_square (cl_heap_modint_ring* _R, const _cl_MI& x)
123 {
124         var cl_heap_modint_ring_montgom* R = (cl_heap_modint_ring_montgom*)_R;
125         return _cl_MI(R, montgom_redc(R,square(x.rep)));
126 }
127
128 static const cl_MI_x montgom_recip (cl_heap_modint_ring* _R, const _cl_MI& x)
129 {
130         var cl_heap_modint_ring_montgom* R = (cl_heap_modint_ring_montgom*)_R;
131         var const cl_I& xr = x.rep;
132         var cl_I u, v;
133         var cl_I g = xgcd(xr,R->modulus,&u,&v);
134         // g = gcd(x,M) = x*u+M*v
135         if (eq(g,1))
136                 return cl_MI(R, mod((minusp(u) ? u + R->modulus : u) << (2*R->n), R->modulus));
137         if (zerop(xr))
138                 throw division_by_0_exception();
139         return cl_notify_composite(R,xr);
140 }
141
142 static const cl_MI_x montgom_div (cl_heap_modint_ring* _R, const _cl_MI& x, const _cl_MI& y)
143 {
144         var cl_heap_modint_ring_montgom* R = (cl_heap_modint_ring_montgom*)_R;
145         var const cl_I& yr = y.rep;
146         var cl_I u, v;
147         var cl_I g = xgcd(yr,R->modulus,&u,&v);
148         // g = gcd(y,M) = y*u+M*v
149         if (eq(g,1))
150                 return cl_MI(R, mod((x.rep * (minusp(u) ? u + R->modulus : u)) << R->n, R->modulus));
151         if (zerop(yr))
152                 throw division_by_0_exception();
153         return cl_notify_composite(R,yr);
154 }
155
156 #define montgom_addops std_addops
157 static cl_modint_mulops montgom_mulops = {
158         montgom_one,
159         montgom_canonhom,
160         montgom_mul,
161         montgom_square,
162         std_expt_pos,
163         montgom_recip,
164         montgom_div,
165         std_expt,
166         std_reduce_modulo,
167         montgom_retract
168 };
169
170 // Constructor.
171 inline cl_heap_modint_ring_montgom::cl_heap_modint_ring_montgom (const cl_I& M, uintL _m, uintL _n, const cl_I& _V)
172         : cl_heap_modint_ring (M, &std_setops, &montgom_addops, &montgom_mulops),
173           m (_m), n (_n), V (_V)
174 {
175         type = &cl_class_modint_ring_montgom;
176 }
177
178 static cl_heap_modint_ring* try_make_modint_ring_montgom (const cl_I& M)
179 {
180         if (!oddp(M))
181                 return NULL;
182         var uintC m = integer_length(M);
183         CL_ALLOCA_STACK;
184         var uintC len;
185         var const uintD* M_LSDptr;
186         I_to_NDS_nocopy(M, ,len=,M_LSDptr=,false,);
187         if (lspref(M_LSDptr,len-1)==0) { len--; } // normalize
188         // Compute U as 2-adic inverse of M.
189         var uintD* U_LSDptr;
190         num_stack_alloc(len,,U_LSDptr=);
191         recip2adic(len,M_LSDptr,U_LSDptr);
192         // Look at U's bits.
193         #define U_bit(i) (lspref(U_LSDptr,floor(i,intDsize)) & ((uintD)1 << ((i)%intDsize)))
194         var uintC i_min;
195         var uintC i_max;
196         var uintC i = floor(m,2);
197         var bool negative;
198         if (U_bit(i)) {
199                 for (; --i > 0; )
200                         if (!U_bit(i)) break;
201                 i_min = i+1;
202                 i = floor(m,2);
203                 for (; ++i < m; )
204                         if (!U_bit(i)) break;
205                 i_max = i;
206                 negative = true;
207         } else {
208                 for (; --i > 0; )
209                         if (U_bit(i)) break;
210                 i_min = i+1;
211                 i = floor(m,2);
212                 for (; ++i < m; )
213                         if (U_bit(i)) break;
214                 i_max = i;
215                 negative = false;
216         }
217         #undef U_bit
218         // OK, all the bits i_max-1..i_min of U are equal.
219         if (i_max - i_min <= floor(m,2))
220                 return NULL;
221         var uintC n = i_max;
222         // Turn U (mod 2^n) into a signed integer.
223         if (n % intDsize) {
224                 if (negative)
225                         lspref(U_LSDptr,floor(n,intDsize)) |= (uintD)(-1) << (n % intDsize);
226                 else
227                         lspref(U_LSDptr,floor(n,intDsize)) &= ((uintD)1 << (n % intDsize)) - 1;
228         }
229         var uintC U_len = ceiling(n,intDsize);
230         var cl_I U = DS_to_I(U_LSDptr lspop U_len,U_len);
231         var cl_I V_N = 1 - U*M;
232         if (ldb_test(V_N,cl_byte(n,0)))
233                 throw runtime_exception();
234         var cl_I V = V_N >> n;
235         return new cl_heap_modint_ring_montgom(M,m,n,V);
236 }
237
238 }  // namespace cln