]> www.ginac.de Git - cln.git/blob - src/base/digitseq/cl_DS_mul_fftp3m.h
14a277c6a5ea56cb74943c4b5ceaa236ca1fc86a
[cln.git] / src / base / digitseq / cl_DS_mul_fftp3m.h
1 // Fast integer multiplication using FFT in a modular ring.
2 // Bruno Haible 5.5.1996, 30.6.1996, 20.8.1996
3
4 // FFT in the complex domain has the drawback that it needs careful round-off
5 // error analysis. So here we choose another field of characteristic 0: Q_p.
6 // Since Q_p contains exactly the (p-1)th roots of unity, we choose
7 // p == 1 mod N and have the Nth roots of unity (N = 2^n) in Q_p and
8 // even in Z_p. Actually, we compute in Z/(p^m Z).
9
10 // All operations the FFT algorithm needs is addition, subtraction,
11 // multiplication, multiplication by the Nth root of unity and division
12 // by N. Hence we can use the domain Z/(p^m Z) even if p is not a prime!
13
14 // We want to compute the convolution of N 32-bit words. The resulting
15 // words are < (2^32)^2 * N. To avoid computing with numbers greater than
16 // 32 bits, we compute in Z/pZ for three different primes p in parallel,
17 // i.e. we compute in the ring (Z / p1 Z) x (Z / p2 Z) x (Z / p3 Z). We choose
18 // p1 = 3*2^30+1, p2 = 15*2^27+1, p3 = 7*2^26+1.
19 // Because of p1*p2*p3 >= 2^91 >= (2^32)^2 * N, the chinese remainder theorem
20 // will faithfully combine 3 32-bit words to a word < (2^32)^2 * N.
21
22 // Furthermore we use Montgomery's modular multiplication trick
23 // [Peter L. Montgomery: Modular multiplication without trial division,
24 //  Mathematics of Computation 44 (1985), 519-521.]
25 //
26 // Assume we want to compute modulo M, M odd. V and N will be chosen
27 // so that V*N==1 mod M and that (a,b) --> a*b*V mod M can be more easily
28 // computed than (a,b) --> a*b mod M. Then, we have a ring isomorphism
29 //   (Z/MZ, +, * mod M)  \isomorph  (Z/MZ, +, (a,b) --> a*b*V mod M)
30 //   x mod M             -------->  x*N mod M
31 // It is thus preferrable to use x*N mod M as a "representation" of x mod M,
32 // especially for computations which involve at least several multiplications.
33 //
34 // The precise algorithm to compute a*b*V mod M, given a and b, and the choice
35 // of N and V depend on M and on the hardware. The general idea is this:
36 // Choose N = 2^n, so that division by N is easy. Recall that V == N^-1 mod M.
37 // 1. Given a and b as m-bit numbers (M <= 2^m), compute a*b in full
38 //    precision.
39 // 2. Write a*b = c*N+d, i.e. split it into components c and d.
40 // 3. Now a*b*V = c*N*V+d*V == c+d*V mod M.
41 // 4. Instead of computing d*V mod M
42 //    a. by full multiplication and then division mod M, or
43 //    b. by left shifts: repeated application of
44 //          x := 2*x+(0 or 1); if (x >= M) { x := x-M; }
45 //    we compute
46 //    c. by right shifts (recall that d*V == d*2^-n mod M): repeated application
47 //       of   if (x odd) { x := (x+M)/2; } else { x := x/2; }
48 // Usually one will choose N = 2^m, so that c and d have both m bits.
49 // Several variations are possible: In step 4 one can implement the right
50 // shifts in hardware. Or (for example when N = 2^160 and working on a
51 // 32-bit machine) one can do 32 shift steps at the same time:
52 // Choose M' == M^-1 mod 2^32 and compute n/32 times
53 //       x := (x - ((x mod 2^32) * M' mod 2^32) * M) / 2^32.
54 //
55 // Here, we deal with moduli M = p_i = j*2^k+1. These form of primes comes
56 // in because we need 2^n-th roots of unity mod M. But is also comes handy
57 // for Montgomery multiplication: Instead of choosing N = 2^32 (which makes
58 // up for very easy splitting in step 1) and V = j^2*2^(2*k-32), we better
59 // choose N = 2^k and V = -j. The algorithm now goes like this (recall that
60 // M is an m-bit number and j is an (m-k)-bit number):
61 // 1. Compute a*b in full precision, as a 2*m <= 64 bit number.
62 // 2. Split a*b = c*N+d, with c an (2m-k)-bit number and d an k-bit number.
63 // 3. a*b*V == c+d*V mod M.
64 // 4. Compute c mod M by splitting off the leading (m-k+1) bits of c and
65 //    using table lookup; the remainder (c mod 2^(m-1)) is already reduced
66 //    mod M.
67 //    Compute d*|V| the standard way; |V| has only few bits. d*|V| is
68 //    already reduced mod M, because d*|V| < j*2^k < M.
69
70 // In order to get best performance, we carefully choose the primes so that
71 // a. the table of size 2^(m-k+1) doesn't get too large,
72 // b. multiplication by V is easy.
73 // Here is a list of the interesting primes < 2^32:
74 //
75 //                       U*M+V*N = 1
76 //   prime       bits    N=2^n  V    U    2m<=n+32 ?
77 //     M           m       n
78 //
79 //   3*2^30+1     32      30    -3   1        n       (*)
80 //
81 //  13*2^28+1     32      28   -13   1        n
82 //
83 //  15*2^27+1     31      27   -15   1        n       (*)
84 //  17*2^27+1     32      27   -17   1        n
85 //  29*2^27+1     32      27   -29   1        n
86 //
87 //   7*2^26+1     29      26    -7   1        y       (*)
88 //  27*2^26+1     31      26   -27   1        n
89 //  37*2^26+1     32      26   -37   1        n
90 //  43*2^26+1     32      26   -43   1        n
91 //
92 //   5*2^25+1     28      25    -5   1        y
93 //  33*2^25+1     31      25   -33   1        n
94 //  51*2^25+1     31      25   -51   1        n
95 //  63*2^25+1     31      25   -63   1        n
96 //  81*2^25+1     32      25   -81   1        n
97 // 125*2^25+1     32      25  -125   1        n
98 //
99 //  45*2^24+1     30      24   -45   1        n
100 //  73*2^24+1     31      24   -73   1        n
101 // 127*2^24+1     31      24  -127   1        n
102 // 151*2^24+1     32      24  -151   1        n
103 // 157*2^24+1     32      24  -157   1        n
104 // 171*2^24+1     32      24  -171   1        n
105 // 193*2^24+1     32      24  -193   1        n
106 // 235*2^24+1     32      24  -235   1        n
107 // 243*2^24+1     32      24  -243   1        n
108 //
109 //  45*2^23+1     29      23   -45   1        n
110 // ...
111 //
112 // The inequality 2m<=n+32 would mean that c fits in a 32-bit word, but that's
113 // actually irrelevant because we can fetch the most significant bits of c
114 // before actually computing c.
115 // We choose the primes marked with an asterisk.
116
117
118 #if !(intDsize==32)
119 #error "fft mod p implemented only for intDsize==32"
120 #endif
121
122 // Avoid clash with fftp3
123 #define p1 fftp3m_p1
124 #define p2 fftp3m_p2
125 #define p3 fftp3m_p3
126 #define n1 fftp3m_n1
127 #define n2 fftp3m_n2
128 #define n3 fftp3m_n3
129
130 static const uint32 p1 = 1+(3<<30); // = 3221225473
131 static const uint32 p2 = 1+(15<<27); // = 2013265921
132 static const uint32 p3 = 1+(7<<26); // = 469762049
133 static const uint32 n1 = 30; // Montgomery: represent x mod p1 as x*2^n1 mod p1
134 static const uint32 n2 = 27; // Montgomery: represent x mod p2 as x*2^n2 mod p2
135 static const uint32 n3 = 26; // Montgomery: represent x mod p3 as x*2^n3 mod p3
136
137 typedef struct {
138         uint32 w1; // remainder mod p1
139         uint32 w2; // remainder mod p2
140         uint32 w3; // remainder mod p3
141 } fftp3m_word;
142
143 static const fftp3m_word fftp3m_roots_of_1 [26+1] =
144   // roots_of_1[n] is a (2^n)th root of unity in our ring.
145   // (Also roots_of_1[n-1] = roots_of_1[n]^2, but we don't need this.)
146   {
147     #if 0 // in standard representation
148     {          1,          1,          1 },
149     { 3221225472, 2013265920,  469762048 },
150     { 1013946479,  284861408,   19610091 },
151     { 1031213943,  211723194,   26623616 },
152     {  694614138,   78945800,  111570435 },
153     {  347220834,  772607190,  135956445 },
154     {  680684264,  288289890,  181505383 },
155     { 1109768284,  112574482,  145518049 },
156     {  602134989,  928726468,  109721424 },
157     { 1080308101,  875419223,    2847903 },
158     {  381653707,  510575142,  110273149 },
159     {  902453688,  193023072,   65701394 },
160     { 1559299664,  313561437,  181642641 },
161     {  254499731,  121307056,   82315502 },
162     { 1376063215,   20899142,  142137197 },
163     { 1284040478,  956809618,  207661045 },
164     {  336664489,  317295870,  194405005 },
165     {  894491787,  785393806,    2821902 },
166     {  795860341,  738526384,  230963948 },
167     {   23880336,  956561758,   59211404 },
168     {  790585193,  352904935,   95374542 },
169     {  877386874,  836313293,  153165757 },
170     { 1510644826,  971592443,   74027009 },
171     {  353060343,  692611595,   24417505 },
172     {  716717815,  791167605,   26032760 },
173     { 1020271667,  751686895,  150976424 },
174     {  139914905,  477826617,   71902965 }
175     #else // in Montgomery representation
176     { 1073741824,  134217728,   67108864 },
177     { 2147483649, 1879048193,  402653185 },
178     { 1809501489, 1054751064,  265634015 },
179     { 2877487492, 1193844673,  331740947 },
180     { 2989687427,  665825587,  252496823 },
181     { 3105485195, 1961758775,  114795379 },
182     { 1920588894, 1994046595,  175397252 },
183     {  703819063,  932019131,  314756028 },
184     { 3020513810, 1682915367,   51434375 },
185     {  713639124, 1015380543,  133810885 },
186     {  946523922, 1576574394,  454008742 },
187     { 2920407577, 1597744532,  191940679 },
188     { 1627717094, 1589708641,  309595372 },
189     { 2062650405,  126130591,  189567235 },
190     {  615054086,  267042180,  382347871 },
191     { 1719470156, 1681043157,  238769593 },
192     {  961520328, 1992112863,  240663313 },
193     { 2923061544,   81858141,  402250056 },
194     {  808455044,  487635820,  302549471 },
195     { 3213265361, 1681059681,  461303277 },
196     { 1883955251, 1318650285,  254810522 },
197     {  781279533, 1017987605,  179445770 },
198     {  570193549, 1008968995,  459186762 },
199     { 3103538692,  624914534,  466273834 },
200     {  834835886, 1960521414,  331825355 },
201     { 1807393093, 1292064821,  246867396 },
202     { 2100845347, 1578757629,   56837012 }
203     #endif
204   };
205
206 // Define this for (cheap) consistency checks.
207 //#define DEBUG_FFTP3M
208
209 // Define this for extensive consistency checks.
210 //#define DEBUG_FFTP3M_OPERATIONS
211
212 // Define the algorithm of the backward FFT:
213 // Either FORWARD (a normal FFT followed by a permutation)
214 // or     RECIPROOT (an FFT with reciprocal root of unity)
215 // or     CLEVER (an FFT with reciprocal root of unity but clever computation
216 //                of the reciprocals).
217 // Drawback of FORWARD: the permutation pass.
218 // Drawback of RECIPROOT: need all the powers of the root, not only half of them.
219 #define FORWARD   42
220 #define RECIPROOT 43
221 #define CLEVER    44
222 #define FFTP3M_BACKWARD CLEVER
223
224 #ifdef DEBUG_FFTP3M_OPERATIONS
225 #define check_fftp3m_word(x)  if ((x.w1 >= p1) || (x.w2 >= p2) || (x.w3 >= p3)) cl_abort()
226 #else
227 #define check_fftp3m_word(x)
228 #endif
229
230 // r := 0 mod p
231 static inline void zerop3m (fftp3m_word& r)
232 {
233         r.w1 = 0;
234         r.w2 = 0;
235         r.w3 = 0;
236 }
237
238 // r := x mod p
239 static inline void setp3m (uint32 x, fftp3m_word& r)
240 {
241         var uint32 hi;
242         var uint32 lo;
243         hi = x >> (32-n1); lo = x << n1; divu_6432_3232(hi,lo,p1, ,r.w1=);
244         hi = x >> (32-n2); lo = x << n2; divu_6432_3232(hi,lo,p2, ,r.w2=);
245         hi = x >> (32-n3); lo = x << n3; divu_6432_3232(hi,lo,p3, ,r.w3=);
246 }
247
248 // Chinese remainder theorem:
249 // (Z / p1 Z) x (Z / p2 Z) x (Z / p3 Z) == Z / p1*p2*p3 Z = Z / P Z.
250 // Return r as an integer >= 0, < p1*p2*p3, as 3-digit-sequence res.
251 // This routine also does the "de-Montgomerizing".
252 static void combinep3m (const fftp3m_word& r, uintD* resLSDptr)
253 {
254         check_fftp3m_word(r);
255         // Compute e1 * v1 * r.w1 + e2 * v2 * r.w2 + e3 * v3 * r.w3 where
256         // vi == 2^-ni mod pi, and the idempotents ei are found as:
257         // xgcd(pi,p/pi) = 1 = ui*pi + vi*P/pi, ei = 1 - ui*pi.
258         // e1 = 1709008312966733882383995583
259         // e2 = 2781580629833601225216537109
260         // e3 = 1602397205945693664242711343
261         // e1*v1 = 965961209845827124691257285
262         // e2*v2 = 927193593718183024654651603
263         // e3*v3 = 969191855872201893987508667
264         // We will have 0 <= e1*v1 * r.w1 + e2*v2 * r.w2 + e3*v3 * r.w3 <
265         // < e1*v1 * p1 + e2*v2 * p2 + e3*v3 * p3 < 3 * 2^32 * p1*p2*p3 < 2^128.
266         // The sum of the products fits in 4 digits, we divide by p1*p2*p3
267         // as a 3-digit sequence, thus getting the remainder.
268         #if 0
269         #if CL_DS_BIG_ENDIAN_P
270         var const uintD p123 [3] = { 0x09D80000, 0x7C200001, 0x54000001 };
271         var const uintD e1v1 [3] = { 0x031F063E, 0x1CD1F37E, 0x20E0C7C5 };
272         var const uintD e2v2 [3] = { 0x02FEF4E1, 0x6E62C875, 0x788590D3 };
273         var const uintD e3v3 [3] = { 0x0321B25B, 0xC8DB371B, 0xF0E861BB };
274         #else
275         var const uintD p123 [3] = { 0x54000001, 0x7C200001, 0x09D80000 };
276         var const uintD e1v1 [3] = { 0x20E0C7C5, 0x1CD1F37E, 0x031F063E };
277         var const uintD e2v2 [3] = { 0x788590D3, 0x6E62C875, 0x02FEF4E1 };
278         var const uintD e3v3 [3] = { 0xF0E861BB, 0xC8DB371B, 0x0321B25B };
279         #endif
280         #else
281         // The final division step requires a shift left by 4 bits in order
282         // to normalize p1*p2*p3. We combine this shift left with the
283         // multiplications. Note that since e1v1 + e2v2 + e3v3 < p1*p2*p3,
284         // there is no risk of overflow.
285         #if CL_DS_BIG_ENDIAN_P
286         var const uintD p123 [3] = { 0x9D800007, 0xC2000015, 0x40000010 };
287         var const uintD e1v1 [3] = { 0x31F063E1, 0xCD1F37E2, 0x0E0C7C50 };
288         var const uintD e2v2 [3] = { 0x2FEF4E16, 0xE62C8757, 0x88590D30 };
289         var const uintD e3v3 [3] = { 0x321B25BC, 0x8DB371BF, 0x0E861BB0 };
290         #else
291         var const uintD p123 [3] = { 0x40000010, 0xC2000015, 0x9D800007 };
292         var const uintD e1v1 [3] = { 0x0E0C7C50, 0xCD1F37E2, 0x31F063E1 };
293         var const uintD e2v2 [3] = { 0x88590D30, 0xE62C8757, 0x2FEF4E16 };
294         var const uintD e3v3 [3] = { 0x0E861BB0, 0x8DB371BF, 0x321B25BC };
295         #endif
296         #endif
297         var uintD sum [4];
298         var uintD* const sumLSDptr = arrayLSDptr(sum,4);
299         mulu_loop_lsp(r.w1,arrayLSDptr(e1v1,3), sumLSDptr,3);
300         lspref(sumLSDptr,3) += muluadd_loop_lsp(r.w2,arrayLSDptr(e2v2,3), sumLSDptr,3);
301         lspref(sumLSDptr,3) += muluadd_loop_lsp(r.w3,arrayLSDptr(e3v3,3), sumLSDptr,3);
302         #if 0
303         {CL_ALLOCA_STACK;
304          var DS q;
305          var DS r;
306          UDS_divide(arrayMSDptr(sum,4),4,arrayLSDptr(sum,4),
307                     arrayMSDptr(p123,3),3,arrayLSDptr(p123,3),
308                     &q,&r
309                    );
310          ASSERT(q.len <= 1)
311          ASSERT(r.len <= 3)
312          copy_loop_lsp(r.LSDptr,arrayLSDptr(sum,4),r.len);
313          DS_clear_loop(arrayMSDptr(sum,4) mspop 1,3-r.len,arrayLSDptr(sum,4) lspop r.len);
314         }
315         #else
316         // Division wie UDS_divide mit a_len=4, b_len=3.
317         {
318                 var uintD q_stern;
319                 var uintD c1;
320                 #if HAVE_DD
321                   divuD(highlowDD(lspref(sumLSDptr,3),lspref(sumLSDptr,2)),lspref(arrayLSDptr(p123,3),2), q_stern=,c1=);
322                   { var uintDD c2 = highlowDD(c1,lspref(sumLSDptr,1));
323                     var uintDD c3 = muluD(lspref(arrayLSDptr(p123,3),1),q_stern);
324                     if (c3 > c2)
325                       { q_stern = q_stern-1;
326                         if (c3-c2 > highlowDD(lspref(arrayLSDptr(p123,3),2),lspref(arrayLSDptr(p123,3),1)))
327                           { q_stern = q_stern-1; }
328                   }   }
329                 #else
330                   divuD(lspref(sumLSDptr,3),lspref(sumLSDptr,2),lspref(arrayLSDptr(p123,3),2), q_stern=,c1=);
331                   { var uintD c2lo = lspref(sumLSDptr,1);
332                     var uintD c3hi;
333                     var uintD c3lo;
334                     muluD(lspref(arrayLSDptr(p123,3),1),q_stern, c3hi=,c3lo=);
335                     if ((c3hi > c1) || ((c3hi == c1) && (c3lo > c2lo)))
336                       { q_stern = q_stern-1;
337                         c3hi -= c1; if (c3lo < c2lo) { c3hi--; }; c3lo -= c2lo;
338                         if ((c3hi > lspref(arrayLSDptr(p123,3),2)) || ((c3hi == lspref(arrayLSDptr(p123,3),2)) && (c3lo > lspref(arrayLSDptr(p123,3),1))))
339                           { q_stern = q_stern-1; }
340                    }   }
341                 #endif
342                 if (!(q_stern==0))
343                   { var uintD carry = mulusub_loop_lsp(q_stern,arrayLSDptr(p123,3),sumLSDptr,3);
344                     if (carry > lspref(sumLSDptr,3))
345                       { q_stern = q_stern-1;
346                         addto_loop_lsp(arrayLSDptr(p123,3),sumLSDptr,3);
347                   }   }
348         }
349         #endif
350         #ifdef DEBUG_FFTP3M_OPERATIONS
351         if (compare_loop_msp(sumLSDptr lspop 3,arrayMSDptr(p123,3),3) >= 0)
352                 cl_abort();
353         #endif
354         // Renormalize the division's remainder: shift right by 4 bits.
355         shiftrightcopy_loop_msp(sumLSDptr lspop 3,resLSDptr lspop 3,3,4,0);
356 }
357
358 // r := (a + b) mod p
359 static inline void addp3m (const fftp3m_word& a, const fftp3m_word& b, fftp3m_word& r)
360 {
361         var uint32 x;
362
363         check_fftp3m_word(a); check_fftp3m_word(b);
364         // Add single 32-bit words mod pi.
365         if (((x = (a.w1 + b.w1)) < b.w1) || (x >= p1))
366                 x -= p1;
367         r.w1 = x;
368         if ((x = (a.w2 + b.w2)) >= p2) // x doesn't overflow since p2 <= 2^31
369                 x -= p2;
370         r.w2 = x;
371         if ((x = (a.w3 + b.w3)) >= p3) // x doesn't overflow since p3 <= 2^31
372                 x -= p3;
373         r.w3 = x;
374         check_fftp3m_word(r);
375 }
376
377 // r := (a - b) mod p
378 static inline void subp3m (const fftp3m_word& a, const fftp3m_word& b, fftp3m_word& r)
379 {
380         check_fftp3m_word(a); check_fftp3m_word(b);
381         // Subtract single 32-bit words mod pi.
382         r.w1 = (a.w1 < b.w1 ? a.w1-b.w1+p1 : a.w1-b.w1);
383         r.w2 = (a.w2 < b.w2 ? a.w2-b.w2+p2 : a.w2-b.w2);
384         r.w3 = (a.w3 < b.w3 ? a.w3-b.w3+p3 : a.w3-b.w3);
385         check_fftp3m_word(r);
386 }
387
388 // r := (a * b) mod p
389 static void mulp3m (const fftp3m_word& a, const fftp3m_word& b, fftp3m_word& res)
390 {
391         check_fftp3m_word(a); check_fftp3m_word(b);
392         // Multiplication à la Montgomery:
393         #define mul_mod_p(aw,bw,result_zuweisung,p,m,n,j,js,table)  \
394         {       /* table[i] == i*2^(m-1) mod p for 0 <= i < 2^(m-n+1) */\
395                 var uint32 hi;                                          \
396                 var uint32 lo;                                          \
397                 mulu32(aw,bw, hi=,lo=);                                 \
398                 /* hi has 2m-32 bits */                                 \
399                 var const int l = (m-1)-(32-n);                         \
400                 var uint32 r = table[hi>>l];                            \
401                 hi = ((hi << (32-l)) >> (n-l)) | (lo >> n);             \
402                 /* hi = c mod 2^(m-1), has m-1 bits */                  \
403                 lo = lo & (bit(n)-1);                                   \
404                 /* lo = d, has n bits */                                \
405                 lo = (lo << js) - lo;                                   \
406                 /* lo = d*|V|, has m bits */                            \
407                 /* Finally compute (r + hi - lo) mod p. */              \
408                 if (m < 32) {                                           \
409                         r += hi;                                        \
410                         if (r >= p)                                     \
411                                 { r = r - p; }                          \
412                 } else {                                                \
413                         if (((r += hi) < hi) || (r >= p))               \
414                                 { r = r - p; }                          \
415                 }                                                       \
416                 r = (r < lo ? r-lo+p : r-lo);                           \
417                 /* ifdef DEBUG_FFTP3M_OPERATIONS *                      \
418                 var uint32 tmp;                                         \
419                 mulu32(aw,bw, hi=,lo=);                                 \
420                 divu_6432_3232(hi,lo,p, ,tmp=);                         \
421                 mulu32(tmp,j, hi=, lo=);                                \
422                 divu_6432_3232(hi,lo,p, ,tmp=);                         \
423                 if (tmp != 0) { tmp = p-tmp; }                          \
424                 if (tmp != r)                                           \
425                         cl_abort();                                     \
426                  * endif DEBUG_FFTP3M_OPERATIONS */                     \
427                 result_zuweisung r;                                     \
428         }
429         // p1 = 3*2^30+1, n1 = 30, j1 = 3 = 2^2-1
430         static uint32 table1 [8] =
431           {          0, 2147483648, 1073741823, 3221225471,
432             2147483646, 1073741821, 3221225469, 2147483644
433           };
434         mul_mod_p(a.w1,b.w1,res.w1=,p1,32,30,3,2,table1);
435         // p2 = 15*2^27+1, n2 = 27, j2 = 15 = 2^4-1
436         static uint32 table2 [32] =
437           {          0, 1073741824,  134217727, 1207959551,
438              268435454, 1342177278,  402653181, 1476395005,
439              536870908, 1610612732,  671088635, 1744830459,
440              805306362, 1879048186,  939524089, 2013265913,
441             1073741816,  134217719, 1207959543,  268435446,
442             1342177270,  402653173, 1476394997,  536870900,
443             1610612724,  671088627, 1744830451,  805306354,
444             1879048178,  939524081, 2013265905, 1073741808
445           };
446         mul_mod_p(a.w2,b.w2,res.w2=,p2,31,27,15,4,table2);
447         // p3 = 7*2^26+1, n3 = 26, j3 = 7 = 2^3-1
448         static uint32 table3 [16] =
449           {          0,  268435456,   67108863,  335544319,
450              134217726,  402653182,  201326589,  469762045,
451              268435452,   67108859,  335544315,  134217722,
452              402653178,  201326585,  469762041,  268435448
453           };
454         mul_mod_p(a.w3,b.w3,res.w3=,p3,29,26,7,3,table3);
455         #undef mul_mod_p
456         check_fftp3m_word(res);
457 }
458 #ifdef DEBUG_FFTP3M_OPERATIONS
459 static void mulp3m_doublecheck (const fftp3m_word& a, const fftp3m_word& b, fftp3m_word& r)
460 {
461         fftp3m_word zero, ma, mb, or;
462         zerop3m(zero);
463         subp3m(zero,a, ma);
464         subp3m(zero,b, mb);
465         mulp3m(ma,mb, or);
466         mulp3m(a,b, r);
467         if (!((r.w1 == or.w1) && (r.w2 == or.w2) && (r.w3 == or.w3)))
468                 cl_abort();
469 }
470 #define mulp3m mulp3m_doublecheck
471 #endif /* DEBUG_FFTP3M_OPERATIONS */
472
473 // b := (a / 2) mod p
474 static inline void shiftp3m (const fftp3m_word& a, fftp3m_word& b)
475 {
476         check_fftp3m_word(a);
477         b.w1 = (a.w1 & 1 ? (a.w1 >> 1) + (p1 >> 1) + 1 : (a.w1 >> 1));
478         b.w2 = (a.w2 & 1 ? (a.w2 >> 1) + (p2 >> 1) + 1 : (a.w2 >> 1));
479         b.w3 = (a.w3 & 1 ? (a.w3 >> 1) + (p3 >> 1) + 1 : (a.w3 >> 1));
480         check_fftp3m_word(b);
481 }
482
483 #ifndef _BIT_REVERSE
484 #define _BIT_REVERSE
485 // Reverse an n-bit number x. n>0.
486 static uintL bit_reverse (uintL n, uintL x)
487 {
488         var uintL y = 0;
489         do {
490                 y <<= 1;
491                 y |= (x & 1);
492                 x >>= 1;
493         } while (!(--n == 0));
494         return y;
495 }
496 #endif
497
498 // Compute an convolution mod p using FFT: z[0..N-1] := x[0..N-1] * y[0..N-1].
499 static void fftp3m_convolution (const uintL n, const uintL N, // N = 2^n
500                                 fftp3m_word * x, // N words
501                                 fftp3m_word * y, // N words
502                                 fftp3m_word * z  // N words result
503                                )
504 {
505         CL_ALLOCA_STACK;
506         #if (FFTP3M_BACKWARD == RECIPROOT) || defined(DEBUG_FFTP3M)
507         var fftp3m_word* const w = cl_alloc_array(fftp3m_word,N);
508         #else
509         var fftp3m_word* const w = cl_alloc_array(fftp3m_word,(N>>1)+1);
510         #endif
511         var uintL i;
512         // Initialize w[i] to w^i, w a primitive N-th root of unity.
513         w[0] = fftp3m_roots_of_1[0];
514         w[1] = fftp3m_roots_of_1[n];
515         #if (FFTP3M_BACKWARD == RECIPROOT) || defined(DEBUG_FFTP3M)
516         for (i = 2; i < N; i++)
517                 mulp3m(w[i-1],fftp3m_roots_of_1[n], w[i]);
518         #else // need only half of the roots
519         for (i = 2; i < N>>1; i++)
520                 mulp3m(w[i-1],fftp3m_roots_of_1[n], w[i]);
521         #endif
522         #ifdef DEBUG_FFTP3M
523         // Check that w is really a primitive N-th root of unity.
524         {
525                 var fftp3m_word w_N;
526                 mulp3m(w[N-1],fftp3m_roots_of_1[n], w_N);
527                 if (!(   w_N.w1 == (uint32)1<<n1
528                       && w_N.w2 == (uint32)1<<n2
529                       && w_N.w3 == (uint32)1<<n3))
530                         cl_abort();
531                 w_N = w[N>>1];
532                 if (!(   w_N.w1 == p1-((uint32)1<<n1)
533                       && w_N.w2 == p2-((uint32)1<<n2)
534                       && w_N.w3 == p3-((uint32)1<<n3)))
535                         cl_abort();
536         }
537         #endif
538         var bool squaring = (x == y);
539         // Do an FFT of length N on x.
540         {
541                 var sintL l;
542                 /* l = n-1 */ {
543                         var const uintL tmax = N>>1; // tmax = 2^(n-1)
544                         for (var uintL t = 0; t < tmax; t++) {
545                                 var uintL i1 = t;
546                                 var uintL i2 = i1 + tmax;
547                                 // Butterfly: replace (x(i1),x(i2)) by
548                                 // (x(i1) + x(i2), x(i1) - x(i2)).
549                                 var fftp3m_word tmp;
550                                 tmp = x[i2];
551                                 subp3m(x[i1],tmp, x[i2]);
552                                 addp3m(x[i1],tmp, x[i1]);
553                         }
554                 }
555                 for (l = n-2; l>=0; l--) {
556                         var const uintL smax = (uintL)1 << (n-1-l);
557                         var const uintL tmax = (uintL)1 << l;
558                         for (var uintL s = 0; s < smax; s++) {
559                                 var uintL exp = bit_reverse(n-1-l,s) << l;
560                                 for (var uintL t = 0; t < tmax; t++) {
561                                         var uintL i1 = (s << (l+1)) + t;
562                                         var uintL i2 = i1 + tmax;
563                                         // Butterfly: replace (x(i1),x(i2)) by
564                                         // (x(i1) + w^exp*x(i2), x(i1) - w^exp*x(i2)).
565                                         var fftp3m_word tmp;
566                                         mulp3m(x[i2],w[exp], tmp);
567                                         subp3m(x[i1],tmp, x[i2]);
568                                         addp3m(x[i1],tmp, x[i1]);
569                                 }
570                         }
571                 }
572         }
573         // Do an FFT of length N on y.
574         if (!squaring) {
575                 var sintL l;
576                 /* l = n-1 */ {
577                         var uintL const tmax = N>>1; // tmax = 2^(n-1)
578                         for (var uintL t = 0; t < tmax; t++) {
579                                 var uintL i1 = t;
580                                 var uintL i2 = i1 + tmax;
581                                 // Butterfly: replace (y(i1),y(i2)) by
582                                 // (y(i1) + y(i2), y(i1) - y(i2)).
583                                 var fftp3m_word tmp;
584                                 tmp = y[i2];
585                                 subp3m(y[i1],tmp, y[i2]);
586                                 addp3m(y[i1],tmp, y[i1]);
587                         }
588                 }
589                 for (l = n-2; l>=0; l--) {
590                         var const uintL smax = (uintL)1 << (n-1-l);
591                         var const uintL tmax = (uintL)1 << l;
592                         for (var uintL s = 0; s < smax; s++) {
593                                 var uintL exp = bit_reverse(n-1-l,s) << l;
594                                 for (var uintL t = 0; t < tmax; t++) {
595                                         var uintL i1 = (s << (l+1)) + t;
596                                         var uintL i2 = i1 + tmax;
597                                         // Butterfly: replace (y(i1),y(i2)) by
598                                         // (y(i1) + w^exp*y(i2), y(i1) - w^exp*y(i2)).
599                                         var fftp3m_word tmp;
600                                         mulp3m(y[i2],w[exp], tmp);
601                                         subp3m(y[i1],tmp, y[i2]);
602                                         addp3m(y[i1],tmp, y[i1]);
603                                 }
604                         }
605                 }
606         }
607         // Multiply the transformed vectors into z.
608         for (i = 0; i < N; i++)
609                 mulp3m(x[i],y[i], z[i]);
610         // Undo an FFT of length N on z.
611         {
612                 var uintL l;
613                 for (l = 0; l < n-1; l++) {
614                         var const uintL smax = (uintL)1 << (n-1-l);
615                         var const uintL tmax = (uintL)1 << l;
616                         #if FFTP3M_BACKWARD != CLEVER
617                         for (var uintL s = 0; s < smax; s++) {
618                                 var uintL exp = bit_reverse(n-1-l,s) << l;
619                                 #if FFTP3M_BACKWARD == RECIPROOT
620                                 if (exp > 0)
621                                         exp = N - exp; // negate exp (use w^-1 instead of w)
622                                 #endif
623                                 for (var uintL t = 0; t < tmax; t++) {
624                                         var uintL i1 = (s << (l+1)) + t;
625                                         var uintL i2 = i1 + tmax;
626                                         // Inverse Butterfly: replace (z(i1),z(i2)) by
627                                         // ((z(i1)+z(i2))/2, (z(i1)-z(i2))/(2*w^exp)).
628                                         var fftp3m_word sum;
629                                         var fftp3m_word diff;
630                                         addp3m(z[i1],z[i2], sum);
631                                         subp3m(z[i1],z[i2], diff);
632                                         shiftp3m(sum, z[i1]);
633                                         mulp3m(diff,w[exp], diff); shiftp3m(diff, z[i2]);
634                                 }
635                         }
636                         #else // FFTP3M_BACKWARD == CLEVER: clever handling of negative exponents
637                         /* s = 0, exp = 0 */ {
638                                 for (var uintL t = 0; t < tmax; t++) {
639                                         var uintL i1 = t;
640                                         var uintL i2 = i1 + tmax;
641                                         // Inverse Butterfly: replace (z(i1),z(i2)) by
642                                         // ((z(i1)+z(i2))/2, (z(i1)-z(i2))/(2*w^exp)),
643                                         // with exp <-- 0.
644                                         var fftp3m_word sum;
645                                         var fftp3m_word diff;
646                                         addp3m(z[i1],z[i2], sum);
647                                         subp3m(z[i1],z[i2], diff);
648                                         shiftp3m(sum, z[i1]);
649                                         shiftp3m(diff, z[i2]);
650                                 }
651                         }
652                         for (var uintL s = 1; s < smax; s++) {
653                                 var uintL exp = bit_reverse(n-1-l,s) << l;
654                                 exp = (N>>1) - exp; // negate exp (use w^-1 instead of w)
655                                 for (var uintL t = 0; t < tmax; t++) {
656                                         var uintL i1 = (s << (l+1)) + t;
657                                         var uintL i2 = i1 + tmax;
658                                         // Inverse Butterfly: replace (z(i1),z(i2)) by
659                                         // ((z(i1)+z(i2))/2, (z(i1)-z(i2))/(2*w^exp)),
660                                         // with exp <-- (N/2 - exp).
661                                         var fftp3m_word sum;
662                                         var fftp3m_word diff;
663                                         addp3m(z[i1],z[i2], sum);
664                                         subp3m(z[i2],z[i1], diff); // note that w^(N/2) = -1
665                                         shiftp3m(sum, z[i1]);
666                                         mulp3m(diff,w[exp], diff); shiftp3m(diff, z[i2]);
667                                 }
668                         }
669                         #endif
670                 }
671                 /* l = n-1 */ {
672                         var const uintL tmax = N>>1; // tmax = 2^(n-1)
673                         for (var uintL t = 0; t < tmax; t++) {
674                                 var uintL i1 = t;
675                                 var uintL i2 = i1 + tmax;
676                                 // Inverse Butterfly: replace (z(i1),z(i2)) by
677                                 // ((z(i1)+z(i2))/2, (z(i1)-z(i2))/2).
678                                 var fftp3m_word sum;
679                                 var fftp3m_word diff;
680                                 addp3m(z[i1],z[i2], sum);
681                                 subp3m(z[i1],z[i2], diff);
682                                 shiftp3m(sum, z[i1]);
683                                 shiftp3m(diff, z[i2]);
684                         }
685                 }
686         }
687         #if FFTP3M_BACKWARD == FORWARD
688         // Swap z[i] and z[N-i] for 0 < i < N/2.
689         for (i = (N>>1)-1; i > 0; i--) {
690                 var fftp3m_word tmp = z[i];
691                 z[i] = z[N-i];
692                 z[N-i] = tmp;
693         }
694         #endif
695 }
696
697 static void mulu_fft_modp3m (const uintD* sourceptr1, uintC len1,
698                              const uintD* sourceptr2, uintC len2,
699                              uintD* destptr)
700 // Es ist 2 <= len1 <= len2.
701 {
702         // Methode:
703         // source1 ist ein Stück der Länge N1, source2 ein oder mehrere Stücke
704         // der Länge N2, mit N1+N2 <= N, wobei N Zweierpotenz ist.
705         // sum(i=0..N-1, x_i b^i) * sum(i=0..N-1, y_i b^i) wird errechnet,
706         // indem man die beiden Polynome
707         // sum(i=0..N-1, x_i T^i), sum(i=0..N-1, y_i T^i)
708         // multipliziert, und zwar durch Fourier-Transformation (s.o.).
709         var uint32 n;
710         integerlength32(len1-1, n=); // 2^(n-1) < len1 <= 2^n
711         var uintL len = (uintL)1 << n; // kleinste Zweierpotenz >= len1
712         // Wählt man N = len, so hat man ceiling(len2/(len-len1+1)) * FFT(len).
713         // Wählt man N = 2*len, so hat man ceiling(len2/(2*len-len1+1)) * FFT(2*len).
714         // Wir wählen das billigere von beiden:
715         // Bei ceiling(len2/(len-len1+1)) <= 2 * ceiling(len2/(2*len-len1+1))
716         // nimmt man N = len, bei ....... > ........ dagegen N = 2*len.
717         // (Wahl von N = 4*len oder mehr bringt nur in Extremfällen etwas.)
718         if (len2 > 2 * (len-len1+1) * (len2 <= (2*len-len1+1) ? 1 : ceiling(len2,(2*len-len1+1)))) {
719                 n = n+1;
720                 len = len << 1;
721         }
722         var const uintL N = len; // N = 2^n
723         CL_ALLOCA_STACK;
724         var fftp3m_word* const x = cl_alloc_array(fftp3m_word,N);
725         var fftp3m_word* const y = cl_alloc_array(fftp3m_word,N);
726         #ifdef DEBUG_FFTP3M
727         var fftp3m_word* const z = cl_alloc_array(fftp3m_word,N);
728         #else
729         var fftp3m_word* const z = x; // put z in place of x - saves memory
730         #endif
731         var uintD* const tmpprod = cl_alloc_array(uintD,len1+1);
732         var uintP i;
733         var uintL destlen = len1+len2;
734         clear_loop_lsp(destptr,destlen);
735         do {
736                 var uintL len2p; // length of a piece of source2
737                 len2p = N - len1 + 1;
738                 if (len2p > len2)
739                         len2p = len2;
740                 // len2p = min(N-len1+1,len2).
741                 if (len2p == 1) {
742                         // cheap case
743                         var uintD* tmpptr = arrayLSDptr(tmpprod,len1+1);
744                         mulu_loop_lsp(lspref(sourceptr2,0),sourceptr1,tmpptr,len1);
745                         if (addto_loop_lsp(tmpptr,destptr,len1+1))
746                                 if (inc_loop_lsp(destptr lspop (len1+1),destlen-(len1+1)))
747                                         cl_abort();
748                 } else {
749                         var uintL destlenp = len1 + len2p - 1;
750                         // destlenp = min(N,destlen-1).
751                         var bool squaring = ((sourceptr1 == sourceptr2) && (len1 == len2p));
752                         // Fill factor x.
753                         {
754                                 for (i = 0; i < len1; i++)
755                                         setp3m(lspref(sourceptr1,i), x[i]);
756                                 for (i = len1; i < N; i++)
757                                         zerop3m(x[i]);
758                         }
759                         // Fill factor y.
760                         if (!squaring) {
761                                 for (i = 0; i < len2p; i++)
762                                         setp3m(lspref(sourceptr2,i), y[i]);
763                                 for (i = len2p; i < N; i++)
764                                         zerop3m(y[i]);
765                         }
766                         // Multiply.
767                         if (!squaring)
768                                 fftp3m_convolution(n,N, &x[0], &y[0], &z[0]);
769                         else
770                                 fftp3m_convolution(n,N, &x[0], &x[0], &z[0]);
771                         // Add result to destptr[-destlen..-1]:
772                         {
773                                 var uintD* ptr = destptr;
774                                 // ac2|ac1|ac0 are an accumulator.
775                                 var uint32 ac0 = 0;
776                                 var uint32 ac1 = 0;
777                                 var uint32 ac2 = 0;
778                                 var uint32 tmp;
779                                 for (i = 0; i < destlenp; i++) {
780                                         // Convert z[i] to a 3-digit number.
781                                         var uintD z_i[3];
782                                         combinep3m(z[i],arrayLSDptr(z_i,3));
783                                         #ifdef DEBUG_FFTP3M
784                                         if (!(arrayLSref(z_i,3,2) < N))
785                                                 cl_abort();
786                                         #endif
787                                         // Add z[i] to the accumulator.
788                                         tmp = arrayLSref(z_i,3,0);
789                                         if ((ac0 += tmp) < tmp) {
790                                                 if (++ac1 == 0)
791                                                         ++ac2;
792                                         }
793                                         tmp = arrayLSref(z_i,3,1);
794                                         if ((ac1 += tmp) < tmp)
795                                                 ++ac2;
796                                         tmp = arrayLSref(z_i,3,2);
797                                         ac2 += tmp;
798                                         // Add the accumulator's least significant word to destptr:
799                                         tmp = lspref(ptr,0);
800                                         if ((ac0 += tmp) < tmp) {
801                                                 if (++ac1 == 0)
802                                                         ++ac2;
803                                         }
804                                         lspref(ptr,0) = ac0;
805                                         lsshrink(ptr);
806                                         ac0 = ac1;
807                                         ac1 = ac2;
808                                         ac2 = 0;
809                                 }
810                                 // ac2 = 0.
811                                 if (ac1 > 0) {
812                                         if (!((i += 2) <= destlen))
813                                                 cl_abort();
814                                         tmp = lspref(ptr,0);
815                                         if ((ac0 += tmp) < tmp)
816                                                 ++ac1;
817                                         lspref(ptr,0) = ac0;
818                                         lsshrink(ptr);
819                                         tmp = lspref(ptr,0);
820                                         ac1 += tmp;
821                                         lspref(ptr,0) = ac1;
822                                         lsshrink(ptr);
823                                         if (ac1 < tmp)
824                                                 if (inc_loop_lsp(ptr,destlen-i))
825                                                         cl_abort();
826                                 } else if (ac0 > 0) {
827                                         if (!((i += 1) <= destlen))
828                                                 cl_abort();
829                                         tmp = lspref(ptr,0);
830                                         ac0 += tmp;
831                                         lspref(ptr,0) = ac0;
832                                         lsshrink(ptr);
833                                         if (ac0 < tmp)
834                                                 if (inc_loop_lsp(ptr,destlen-i))
835                                                         cl_abort();
836                                 }
837                         }
838                         #ifdef DEBUG_FFTP3M
839                         // If destlenp < N, check that the remaining z[i] are 0.
840                         for (i = destlenp; i < N; i++)
841                                 if (z[i].w1 > 0 || z[i].w2 > 0 || z[i].w3 > 0)
842                                         cl_abort();
843                         #endif
844                 }
845                 // Decrement len2.
846                 destptr = destptr lspop len2p;
847                 destlen -= len2p;
848                 sourceptr2 = sourceptr2 lspop len2p;
849                 len2 -= len2p;
850         } while (len2 > 0);
851 }
852
853 #undef n3
854 #undef n2
855 #undef n1
856 #undef p3
857 #undef p2
858 #undef p1