ecf537b010968166d75f5f0488816bbb5caa6251
[ginac.git] / ginac / factor.cpp
1 /** @file factor.cpp
2  *
3  *  Polynomial factorization code (implementation).
4  *
5  *  Algorithms used can be found in
6  *    [W1]  An Improved Multivariate Polynomial Factoring Algorithm,
7  *          P.S.Wang, Mathematics of Computation, Vol. 32, No. 144 (1978) 1215--1231.
8  *    [GCL] Algorithms for Computer Algebra,
9  *          K.O.Geddes, S.R.Czapor, G.Labahn, Springer Verlag, 1992.
10  */
11
12 /*
13  *  GiNaC Copyright (C) 1999-2008 Johannes Gutenberg University Mainz, Germany
14  *
15  *  This program is free software; you can redistribute it and/or modify
16  *  it under the terms of the GNU General Public License as published by
17  *  the Free Software Foundation; either version 2 of the License, or
18  *  (at your option) any later version.
19  *
20  *  This program is distributed in the hope that it will be useful,
21  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
22  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
23  *  GNU General Public License for more details.
24  *
25  *  You should have received a copy of the GNU General Public License
26  *  along with this program; if not, write to the Free Software
27  *  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
28  */
29
30 //#define DEBUGFACTOR
31
32 #include "factor.h"
33
34 #include "ex.h"
35 #include "numeric.h"
36 #include "operators.h"
37 #include "inifcns.h"
38 #include "symbol.h"
39 #include "relational.h"
40 #include "power.h"
41 #include "mul.h"
42 #include "normal.h"
43 #include "add.h"
44
45 #include <algorithm>
46 #include <cmath>
47 #include <limits>
48 #include <list>
49 #include <vector>
50 #ifdef DEBUGFACTOR
51 #include <ostream>
52 #endif
53 using namespace std;
54
55 #include <cln/cln.h>
56 using namespace cln;
57
58 namespace GiNaC {
59
60 #ifdef DEBUGFACTOR
61 #define DCOUT(str) cout << #str << endl
62 #define DCOUTVAR(var) cout << #var << ": " << var << endl
63 #define DCOUT2(str,var) cout << #str << ": " << var << endl
64 #else
65 #define DCOUT(str)
66 #define DCOUTVAR(var)
67 #define DCOUT2(str,var)
68 #endif
69
70 // anonymous namespace to hide all utility functions
71 namespace {
72
73 typedef vector<cl_MI> mvec;
74 #ifdef DEBUGFACTOR
75 ostream& operator<<(ostream& o, const vector<int>& v)
76 {
77         vector<int>::const_iterator i = v.begin(), end = v.end();
78         while ( i != end ) {
79                 o << *i++ << " ";
80         }
81         return o;
82 }
83 ostream& operator<<(ostream& o, const vector<cl_I>& v)
84 {
85         vector<cl_I>::const_iterator i = v.begin(), end = v.end();
86         while ( i != end ) {
87                 o << *i << "[" << i-v.begin() << "]" << " ";
88                 ++i;
89         }
90         return o;
91 }
92 ostream& operator<<(ostream& o, const vector<cl_MI>& v)
93 {
94         vector<cl_MI>::const_iterator i = v.begin(), end = v.end();
95         while ( i != end ) {
96                 o << *i << "[" << i-v.begin() << "]" << " ";
97                 ++i;
98         }
99         return o;
100 }
101 ostream& operator<<(ostream& o, const vector< vector<cl_MI> >& v)
102 {
103         vector< vector<cl_MI> >::const_iterator i = v.begin(), end = v.end();
104         while ( i != end ) {
105                 o << i-v.begin() << ": " << *i << endl;
106                 ++i;
107         }
108         return o;
109 }
110 #endif
111
112 ////////////////////////////////////////////////////////////////////////////////
113 // modular univariate polynomial code
114
115 typedef std::vector<cln::cl_MI> umodpoly;
116 typedef std::vector<cln::cl_I> upoly;
117 typedef vector<umodpoly> upvec;
118
119 // COPY FROM UPOLY.HPP
120
121 // CHANGED size_t -> int !!!
122 template<typename T> static int degree(const T& p)
123 {
124         return p.size() - 1;
125 }
126
127 template<typename T> static typename T::value_type lcoeff(const T& p)
128 {
129         return p[p.size() - 1];
130 }
131
132 static bool normalize_in_field(umodpoly& a)
133 {
134         if (a.size() == 0)
135                 return true;
136         if ( lcoeff(a) == a[0].ring()->one() ) {
137                 return true;
138         }
139
140         const cln::cl_MI lc_1 = recip(lcoeff(a));
141         for (std::size_t k = a.size(); k-- != 0; )
142                 a[k] = a[k]*lc_1;
143         return false;
144 }
145
146 template<typename T> static void
147 canonicalize(T& p, const typename T::size_type hint = std::numeric_limits<typename T::size_type>::max())
148 {
149         if (p.empty())
150                 return;
151
152         std::size_t i = p.size() - 1;
153         // Be fast if the polynomial is already canonicalized
154         if (!zerop(p[i]))
155                 return;
156
157         if (hint < p.size())
158                 i = hint;
159
160         bool is_zero = false;
161         do {
162                 if (!zerop(p[i])) {
163                         ++i;
164                         break;
165                 }
166                 if (i == 0) {
167                         is_zero = true;
168                         break;
169                 }
170                 --i;
171         } while (true);
172
173         if (is_zero) {
174                 p.clear();
175                 return;
176         }
177
178         p.erase(p.begin() + i, p.end());
179 }
180
181 // END COPY FROM UPOLY.HPP
182
183 static void expt_pos(umodpoly& a, unsigned int q)
184 {
185         if ( a.empty() ) return;
186         cl_MI zero = a[0].ring()->zero(); 
187         int deg = degree(a);
188         a.resize(degree(a)*q+1, zero);
189         for ( int i=deg; i>0; --i ) {
190                 a[i*q] = a[i];
191                 a[i] = zero;
192         }
193 }
194
195 template<typename T>
196 static T operator+(const T& a, const T& b)
197 {
198         int sa = a.size();
199         int sb = b.size();
200         if ( sa >= sb ) {
201                 T r(sa);
202                 int i = 0;
203                 for ( ; i<sb; ++i ) {
204                         r[i] = a[i] + b[i];
205                 }
206                 for ( ; i<sa; ++i ) {
207                         r[i] = a[i];
208                 }
209                 canonicalize(r);
210                 return r;
211         }
212         else {
213                 T r(sb);
214                 int i = 0;
215                 for ( ; i<sa; ++i ) {
216                         r[i] = a[i] + b[i];
217                 }
218                 for ( ; i<sb; ++i ) {
219                         r[i] = b[i];
220                 }
221                 canonicalize(r);
222                 return r;
223         }
224 }
225
226 template<typename T>
227 static T operator-(const T& a, const T& b)
228 {
229         int sa = a.size();
230         int sb = b.size();
231         if ( sa >= sb ) {
232                 T r(sa);
233                 int i = 0;
234                 for ( ; i<sb; ++i ) {
235                         r[i] = a[i] - b[i];
236                 }
237                 for ( ; i<sa; ++i ) {
238                         r[i] = a[i];
239                 }
240                 canonicalize(r);
241                 return r;
242         }
243         else {
244                 T r(sb);
245                 int i = 0;
246                 for ( ; i<sa; ++i ) {
247                         r[i] = a[i] - b[i];
248                 }
249                 for ( ; i<sb; ++i ) {
250                         r[i] = -b[i];
251                 }
252                 canonicalize(r);
253                 return r;
254         }
255 }
256
257 static upoly operator*(const upoly& a, const upoly& b)
258 {
259         upoly c;
260         if ( a.empty() || b.empty() ) return c;
261
262         int n = degree(a) + degree(b);
263         c.resize(n+1, 0);
264         for ( int i=0 ; i<=n; ++i ) {
265                 for ( int j=0 ; j<=i; ++j ) {
266                         if ( j > degree(a) || (i-j) > degree(b) ) continue;
267                         c[i] = c[i] + a[j] * b[i-j];
268                 }
269         }
270         canonicalize(c);
271         return c;
272 }
273
274 static umodpoly operator*(const umodpoly& a, const umodpoly& b)
275 {
276         umodpoly c;
277         if ( a.empty() || b.empty() ) return c;
278
279         int n = degree(a) + degree(b);
280         c.resize(n+1, a[0].ring()->zero());
281         for ( int i=0 ; i<=n; ++i ) {
282                 for ( int j=0 ; j<=i; ++j ) {
283                         if ( j > degree(a) || (i-j) > degree(b) ) continue;
284                         c[i] = c[i] + a[j] * b[i-j];
285                 }
286         }
287         canonicalize(c);
288         return c;
289 }
290
291 static upoly operator*(const upoly& a, const cl_I& x)
292 {
293         if ( zerop(x) ) {
294                 upoly r;
295                 return r;
296         }
297         upoly r(a.size());
298         for ( size_t i=0; i<a.size(); ++i ) {
299                 r[i] = a[i] * x;
300         }
301         return r;
302 }
303
304 static upoly operator/(const upoly& a, const cl_I& x)
305 {
306         if ( zerop(x) ) {
307                 upoly r;
308                 return r;
309         }
310         upoly r(a.size());
311         for ( size_t i=0; i<a.size(); ++i ) {
312                 r[i] = exquo(a[i],x);
313         }
314         return r;
315 }
316
317 static umodpoly operator*(const umodpoly& a, const cl_MI& x)
318 {
319         umodpoly r(a.size());
320         for ( size_t i=0; i<a.size(); ++i ) {
321                 r[i] = a[i] * x;
322         }
323         canonicalize(r);
324         return r;
325 }
326
327 static void upoly_from_ex(upoly& up, const ex& e, const ex& x)
328 {
329         // assert: e is in Z[x]
330         int deg = e.degree(x);
331         up.resize(deg+1);
332         int ldeg = e.ldegree(x);
333         for ( ; deg>=ldeg; --deg ) {
334                 up[deg] = the<cl_I>(ex_to<numeric>(e.coeff(x, deg)).to_cl_N());
335         }
336         for ( ; deg>=0; --deg ) {
337                 up[deg] = 0;
338         }
339         canonicalize(up);
340 }
341
342 static void umodpoly_from_upoly(umodpoly& ump, const upoly& e, const cl_modint_ring& R)
343 {
344         int deg = degree(e);
345         ump.resize(deg+1);
346         for ( ; deg>=0; --deg ) {
347                 ump[deg] = R->canonhom(e[deg]);
348         }
349         canonicalize(ump);
350 }
351
352 static void umodpoly_from_ex(umodpoly& ump, const ex& e, const ex& x, const cl_modint_ring& R)
353 {
354         // assert: e is in Z[x]
355         int deg = e.degree(x);
356         ump.resize(deg+1);
357         int ldeg = e.ldegree(x);
358         for ( ; deg>=ldeg; --deg ) {
359                 cl_I coeff = the<cl_I>(ex_to<numeric>(e.coeff(x, deg)).to_cl_N());
360                 ump[deg] = R->canonhom(coeff);
361         }
362         for ( ; deg>=0; --deg ) {
363                 ump[deg] = R->zero();
364         }
365         canonicalize(ump);
366 }
367
368 static void umodpoly_from_ex(umodpoly& ump, const ex& e, const ex& x, const cl_I& modulus)
369 {
370         umodpoly_from_ex(ump, e, x, find_modint_ring(modulus));
371 }
372
373 static ex upoly_to_ex(const upoly& a, const ex& x)
374 {
375         if ( a.empty() ) return 0;
376         ex e;
377         for ( int i=degree(a); i>=0; --i ) {
378                 e += numeric(a[i]) * pow(x, i);
379         }
380         return e;
381 }
382
383 static ex umodpoly_to_ex(const umodpoly& a, const ex& x)
384 {
385         if ( a.empty() ) return 0;
386         cl_modint_ring R = a[0].ring();
387         cl_I mod = R->modulus;
388         cl_I halfmod = (mod-1) >> 1;
389         ex e;
390         for ( int i=degree(a); i>=0; --i ) {
391                 cl_I n = R->retract(a[i]);
392                 if ( n > halfmod ) {
393                         e += numeric(n-mod) * pow(x, i);
394                 } else {
395                         e += numeric(n) * pow(x, i);
396                 }
397         }
398         return e;
399 }
400
401 static upoly umodpoly_to_upoly(const umodpoly& a)
402 {
403         upoly e(a.size());
404         if ( a.empty() ) return e;
405         cl_modint_ring R = a[0].ring();
406         cl_I mod = R->modulus;
407         cl_I halfmod = (mod-1) >> 1;
408         for ( int i=degree(a); i>=0; --i ) {
409                 cl_I n = R->retract(a[i]);
410                 if ( n > halfmod ) {
411                         e[i] = n-mod;
412                 } else {
413                         e[i] = n;
414                 }
415         }
416         return e;
417 }
418
419 static umodpoly umodpoly_to_umodpoly(const umodpoly& a, const cl_modint_ring& R, unsigned int m)
420 {
421         umodpoly e;
422         if ( a.empty() ) return e;
423         cl_modint_ring oldR = a[0].ring();
424         size_t sa = a.size();
425         e.resize(sa+m, R->zero());
426         for ( size_t i=0; i<sa; ++i ) {
427                 e[i+m] = R->canonhom(oldR->retract(a[i]));
428         }
429         canonicalize(e);
430         return e;
431 }
432
433 /** Divides all coefficients of the polynomial a by the integer x.
434  *  All coefficients are supposed to be divisible by x. If they are not, the
435  *  the<cl_I> cast will raise an exception.
436  *
437  *  @param[in,out] a  polynomial of which the coefficients will be reduced by x
438  *  @param[in]     x  integer that divides the coefficients
439  */
440 static void reduce_coeff(umodpoly& a, const cl_I& x)
441 {
442         if ( a.empty() ) return;
443
444         cl_modint_ring R = a[0].ring();
445         umodpoly::iterator i = a.begin(), end = a.end();
446         for ( ; i!=end; ++i ) {
447                 // cln cannot perform this division in the modular field
448                 cl_I c = R->retract(*i);
449                 *i = cl_MI(R, the<cl_I>(c / x));
450         }
451 }
452
453 /** Calculates remainder of a/b.
454  *  Assertion: a and b not empty.
455  *
456  *  @param[in]  a  polynomial dividend
457  *  @param[in]  b  polynomial divisor
458  *  @param[out] r  polynomial remainder
459  */
460 static void rem(const umodpoly& a, const umodpoly& b, umodpoly& r)
461 {
462         int k, n;
463         n = degree(b);
464         k = degree(a) - n;
465         r = a;
466         if ( k < 0 ) return;
467
468         do {
469                 cl_MI qk = div(r[n+k], b[n]);
470                 if ( !zerop(qk) ) {
471                         for ( int i=0; i<n; ++i ) {
472                                 unsigned int j = n + k - 1 - i;
473                                 r[j] = r[j] - qk * b[j-k];
474                         }
475                 }
476         } while ( k-- );
477
478         fill(r.begin()+n, r.end(), a[0].ring()->zero());
479         canonicalize(r);
480 }
481
482 /** Calculates quotient of a/b.
483  *  Assertion: a and b not empty.
484  *
485  *  @param[in]  a  polynomial dividend
486  *  @param[in]  b  polynomial divisor
487  *  @param[out] q  polynomial quotient
488  */
489 static void div(const umodpoly& a, const umodpoly& b, umodpoly& q)
490 {
491         int k, n;
492         n = degree(b);
493         k = degree(a) - n;
494         q.clear();
495         if ( k < 0 ) return;
496
497         umodpoly r = a;
498         q.resize(k+1, a[0].ring()->zero());
499         do {
500                 cl_MI qk = div(r[n+k], b[n]);
501                 if ( !zerop(qk) ) {
502                         q[k] = qk;
503                         for ( int i=0; i<n; ++i ) {
504                                 unsigned int j = n + k - 1 - i;
505                                 r[j] = r[j] - qk * b[j-k];
506                         }
507                 }
508         } while ( k-- );
509
510         canonicalize(q);
511 }
512
513 /** Calculates quotient and remainder of a/b.
514  *  Assertion: a and b not empty.
515  *
516  *  @param[in]  a  polynomial dividend
517  *  @param[in]  b  polynomial divisor
518  *  @param[out] r  polynomial remainder
519  *  @param[out] q  polynomial quotient
520  */
521 static void remdiv(const umodpoly& a, const umodpoly& b, umodpoly& r, umodpoly& q)
522 {
523         int k, n;
524         n = degree(b);
525         k = degree(a) - n;
526         q.clear();
527         r = a;
528         if ( k < 0 ) return;
529
530         q.resize(k+1, a[0].ring()->zero());
531         do {
532                 cl_MI qk = div(r[n+k], b[n]);
533                 if ( !zerop(qk) ) {
534                         q[k] = qk;
535                         for ( int i=0; i<n; ++i ) {
536                                 unsigned int j = n + k - 1 - i;
537                                 r[j] = r[j] - qk * b[j-k];
538                         }
539                 }
540         } while ( k-- );
541
542         fill(r.begin()+n, r.end(), a[0].ring()->zero());
543         canonicalize(r);
544         canonicalize(q);
545 }
546
547 /** Calculates the GCD of polynomial a and b.
548  *
549  *  @param[in]  a  polynomial
550  *  @param[in]  b  polynomial
551  *  @param[out] c  GCD
552  */
553 static void gcd(const umodpoly& a, const umodpoly& b, umodpoly& c)
554 {
555         if ( degree(a) < degree(b) ) return gcd(b, a, c);
556
557         c = a;
558         normalize_in_field(c);
559         umodpoly d = b;
560         normalize_in_field(d);
561         umodpoly r;
562         while ( !d.empty() ) {
563                 rem(c, d, r);
564                 c = d;
565                 d = r;
566         }
567         normalize_in_field(c);
568 }
569
570 /** Calculates the derivative of the polynomial a.
571  *  
572  *  @param[in]  a  polynomial of which to take the derivative
573  *  @param[out] d  result/derivative
574  */
575 static void deriv(const umodpoly& a, umodpoly& d)
576 {
577         d.clear();
578         if ( a.size() <= 1 ) return;
579
580         d.insert(d.begin(), a.begin()+1, a.end());
581         int max = d.size();
582         for ( int i=1; i<max; ++i ) {
583                 d[i] = d[i] * (i+1);
584         }
585         canonicalize(d);
586 }
587
588 static bool unequal_one(const umodpoly& a)
589 {
590         if ( a.empty() ) return true;
591         return ( a.size() != 1 || a[0] != a[0].ring()->one() );
592 }
593
594 static bool equal_one(const umodpoly& a)
595 {
596         return ( a.size() == 1 && a[0] == a[0].ring()->one() );
597 }
598
599 /** Returns true if polynomial a is square free.
600  *
601  *  @param[in] a  polynomial to check
602  *  @return       true if polynomial is square free, false otherwise
603  */
604 static bool squarefree(const umodpoly& a)
605 {
606         umodpoly b;
607         deriv(a, b);
608         if ( b.empty() ) {
609                 return false;
610         }
611         umodpoly c;
612         gcd(a, b, c);
613         return equal_one(c);
614 }
615
616 // END modular univariate polynomial code
617 ////////////////////////////////////////////////////////////////////////////////
618
619 ////////////////////////////////////////////////////////////////////////////////
620 // modular matrix
621
622 class modular_matrix
623 {
624         friend ostream& operator<<(ostream& o, const modular_matrix& m);
625 public:
626         modular_matrix(size_t r_, size_t c_, const cl_MI& init) : r(r_), c(c_)
627         {
628                 m.resize(c*r, init);
629         }
630         size_t rowsize() const { return r; }
631         size_t colsize() const { return c; }
632         cl_MI& operator()(size_t row, size_t col) { return m[row*c + col]; }
633         cl_MI operator()(size_t row, size_t col) const { return m[row*c + col]; }
634         void mul_col(size_t col, const cl_MI x)
635         {
636                 mvec::iterator i = m.begin() + col;
637                 for ( size_t rc=0; rc<r; ++rc ) {
638                         *i = *i * x;
639                         i += c;
640                 }
641         }
642         void sub_col(size_t col1, size_t col2, const cl_MI fac)
643         {
644                 mvec::iterator i1 = m.begin() + col1;
645                 mvec::iterator i2 = m.begin() + col2;
646                 for ( size_t rc=0; rc<r; ++rc ) {
647                         *i1 = *i1 - *i2 * fac;
648                         i1 += c;
649                         i2 += c;
650                 }
651         }
652         void switch_col(size_t col1, size_t col2)
653         {
654                 cl_MI buf;
655                 mvec::iterator i1 = m.begin() + col1;
656                 mvec::iterator i2 = m.begin() + col2;
657                 for ( size_t rc=0; rc<r; ++rc ) {
658                         buf = *i1; *i1 = *i2; *i2 = buf;
659                         i1 += c;
660                         i2 += c;
661                 }
662         }
663         void mul_row(size_t row, const cl_MI x)
664         {
665                 vector<cl_MI>::iterator i = m.begin() + row*c;
666                 for ( size_t cc=0; cc<c; ++cc ) {
667                         *i = *i * x;
668                         ++i;
669                 }
670         }
671         void sub_row(size_t row1, size_t row2, const cl_MI fac)
672         {
673                 vector<cl_MI>::iterator i1 = m.begin() + row1*c;
674                 vector<cl_MI>::iterator i2 = m.begin() + row2*c;
675                 for ( size_t cc=0; cc<c; ++cc ) {
676                         *i1 = *i1 - *i2 * fac;
677                         ++i1;
678                         ++i2;
679                 }
680         }
681         void switch_row(size_t row1, size_t row2)
682         {
683                 cl_MI buf;
684                 vector<cl_MI>::iterator i1 = m.begin() + row1*c;
685                 vector<cl_MI>::iterator i2 = m.begin() + row2*c;
686                 for ( size_t cc=0; cc<c; ++cc ) {
687                         buf = *i1; *i1 = *i2; *i2 = buf;
688                         ++i1;
689                         ++i2;
690                 }
691         }
692         bool is_col_zero(size_t col) const
693         {
694                 mvec::const_iterator i = m.begin() + col;
695                 for ( size_t rr=0; rr<r; ++rr ) {
696                         if ( !zerop(*i) ) {
697                                 return false;
698                         }
699                         i += c;
700                 }
701                 return true;
702         }
703         bool is_row_zero(size_t row) const
704         {
705                 mvec::const_iterator i = m.begin() + row*c;
706                 for ( size_t cc=0; cc<c; ++cc ) {
707                         if ( !zerop(*i) ) {
708                                 return false;
709                         }
710                         ++i;
711                 }
712                 return true;
713         }
714         void set_row(size_t row, const vector<cl_MI>& newrow)
715         {
716                 mvec::iterator i1 = m.begin() + row*c;
717                 mvec::const_iterator i2 = newrow.begin(), end = newrow.end();
718                 for ( ; i2 != end; ++i1, ++i2 ) {
719                         *i1 = *i2;
720                 }
721         }
722         mvec::const_iterator row_begin(size_t row) const { return m.begin()+row*c; }
723         mvec::const_iterator row_end(size_t row) const { return m.begin()+row*c+r; }
724 private:
725         size_t r, c;
726         mvec m;
727 };
728
729 #ifdef DEBUGFACTOR
730 modular_matrix operator*(const modular_matrix& m1, const modular_matrix& m2)
731 {
732         const unsigned int r = m1.rowsize();
733         const unsigned int c = m2.colsize();
734         modular_matrix o(r,c,m1(0,0));
735
736         for ( size_t i=0; i<r; ++i ) {
737                 for ( size_t j=0; j<c; ++j ) {
738                         cl_MI buf;
739                         buf = m1(i,0) * m2(0,j);
740                         for ( size_t k=1; k<c; ++k ) {
741                                 buf = buf + m1(i,k)*m2(k,j);
742                         }
743                         o(i,j) = buf;
744                 }
745         }
746         return o;
747 }
748
749 ostream& operator<<(ostream& o, const modular_matrix& m)
750 {
751         cl_modint_ring R = m(0,0).ring();
752         o << "{";
753         for ( size_t i=0; i<m.rowsize(); ++i ) {
754                 o << "{";
755                 for ( size_t j=0; j<m.colsize()-1; ++j ) {
756                         o << R->retract(m(i,j)) << ",";
757                 }
758                 o << R->retract(m(i,m.colsize()-1)) << "}";
759                 if ( i != m.rowsize()-1 ) {
760                         o << ",";
761                 }
762         }
763         o << "}";
764         return o;
765 }
766 #endif // def DEBUGFACTOR
767
768 // END modular matrix
769 ////////////////////////////////////////////////////////////////////////////////
770
771 static void q_matrix(const umodpoly& a_, modular_matrix& Q)
772 {
773         umodpoly a = a_;
774         normalize_in_field(a);
775
776         int n = degree(a);
777         unsigned int q = cl_I_to_uint(a[0].ring()->modulus);
778         umodpoly r(n, a[0].ring()->zero());
779         r[0] = a[0].ring()->one();
780         Q.set_row(0, r);
781         unsigned int max = (n-1) * q;
782         for ( size_t m=1; m<=max; ++m ) {
783                 cl_MI rn_1 = r.back();
784                 for ( size_t i=n-1; i>0; --i ) {
785                         r[i] = r[i-1] - (rn_1 * a[i]);
786                 }
787                 r[0] = -rn_1 * a[0];
788                 if ( (m % q) == 0 ) {
789                         Q.set_row(m/q, r);
790                 }
791         }
792 }
793
794 static void nullspace(modular_matrix& M, vector<mvec>& basis)
795 {
796         const size_t n = M.rowsize();
797         const cl_MI one = M(0,0).ring()->one();
798         for ( size_t i=0; i<n; ++i ) {
799                 M(i,i) = M(i,i) - one;
800         }
801         for ( size_t r=0; r<n; ++r ) {
802                 size_t cc = 0;
803                 for ( ; cc<n; ++cc ) {
804                         if ( !zerop(M(r,cc)) ) {
805                                 if ( cc < r ) {
806                                         if ( !zerop(M(cc,cc)) ) {
807                                                 continue;
808                                         }
809                                         M.switch_col(cc, r);
810                                 }
811                                 else if ( cc > r ) {
812                                         M.switch_col(cc, r);
813                                 }
814                                 break;
815                         }
816                 }
817                 if ( cc < n ) {
818                         M.mul_col(r, recip(M(r,r)));
819                         for ( cc=0; cc<n; ++cc ) {
820                                 if ( cc != r ) {
821                                         M.sub_col(cc, r, M(r,cc));
822                                 }
823                         }
824                 }
825         }
826
827         for ( size_t i=0; i<n; ++i ) {
828                 M(i,i) = M(i,i) - one;
829         }
830         for ( size_t i=0; i<n; ++i ) {
831                 if ( !M.is_row_zero(i) ) {
832                         mvec nu(M.row_begin(i), M.row_end(i));
833                         basis.push_back(nu);
834                 }
835         }
836 }
837
838 static void berlekamp(const umodpoly& a, upvec& upv)
839 {
840         cl_modint_ring R = a[0].ring();
841         umodpoly one(1, R->one());
842
843         modular_matrix Q(degree(a), degree(a), R->zero());
844         q_matrix(a, Q);
845         vector<mvec> nu;
846         nullspace(Q, nu);
847
848         const unsigned int k = nu.size();
849         if ( k == 1 ) {
850                 return;
851         }
852
853         list<umodpoly> factors;
854         factors.push_back(a);
855         unsigned int size = 1;
856         unsigned int r = 1;
857         unsigned int q = cl_I_to_uint(R->modulus);
858
859         list<umodpoly>::iterator u = factors.begin();
860
861         while ( true ) {
862                 for ( unsigned int s=0; s<q; ++s ) {
863                         umodpoly nur = nu[r];
864                         nur[0] = nur[0] - cl_MI(R, s);
865                         canonicalize(nur);
866                         umodpoly g;
867                         gcd(nur, *u, g);
868                         if ( unequal_one(g) && g != *u ) {
869                                 umodpoly uo;
870                                 div(*u, g, uo);
871                                 if ( equal_one(uo) ) {
872                                         throw logic_error("berlekamp: unexpected divisor.");
873                                 }
874                                 else {
875                                         *u = uo;
876                                 }
877                                 factors.push_back(g);
878                                 size = 0;
879                                 list<umodpoly>::const_iterator i = factors.begin(), end = factors.end();
880                                 while ( i != end ) {
881                                         if ( degree(*i) ) ++size; 
882                                         ++i;
883                                 }
884                                 if ( size == k ) {
885                                         list<umodpoly>::const_iterator i = factors.begin(), end = factors.end();
886                                         while ( i != end ) {
887                                                 upv.push_back(*i++);
888                                         }
889                                         return;
890                                 }
891                         }
892                 }
893                 if ( ++r == k ) {
894                         r = 1;
895                         ++u;
896                 }
897         }
898 }
899
900 static void expt_1_over_p(const umodpoly& a, unsigned int prime, umodpoly& ap)
901 {
902         size_t newdeg = degree(a)/prime;
903         ap.resize(newdeg+1);
904         ap[0] = a[0];
905         for ( size_t i=1; i<=newdeg; ++i ) {
906                 ap[i] = a[i*prime];
907         }
908 }
909
910 static void modsqrfree(const umodpoly& a, upvec& factors, vector<int>& mult)
911 {
912         const unsigned int prime = cl_I_to_uint(a[0].ring()->modulus);
913         int i = 1;
914         umodpoly b;
915         deriv(a, b);
916         if ( b.size() ) {
917                 umodpoly c;
918                 gcd(a, b, c);
919                 umodpoly w;
920                 div(a, c, w);
921                 while ( unequal_one(w) ) {
922                         umodpoly y;
923                         gcd(w, c, y);
924                         umodpoly z;
925                         div(w, y, z);
926                         factors.push_back(z);
927                         mult.push_back(i);
928                         ++i;
929                         w = y;
930                         umodpoly buf;
931                         div(c, y, buf);
932                         c = buf;
933                 }
934                 if ( unequal_one(c) ) {
935                         umodpoly cp;
936                         expt_1_over_p(c, prime, cp);
937                         size_t previ = mult.size();
938                         modsqrfree(cp, factors, mult);
939                         for ( size_t i=previ; i<mult.size(); ++i ) {
940                                 mult[i] *= prime;
941                         }
942                 }
943         }
944         else {
945                 umodpoly ap;
946                 expt_1_over_p(a, prime, ap);
947                 size_t previ = mult.size();
948                 modsqrfree(ap, factors, mult);
949                 for ( size_t i=previ; i<mult.size(); ++i ) {
950                         mult[i] *= prime;
951                 }
952         }
953 }
954
955 static void distinct_degree_factor(const umodpoly& a_, vector<int>& degrees, upvec& ddfactors)
956 {
957         umodpoly a = a_;
958
959         cl_modint_ring R = a[0].ring();
960         int q = cl_I_to_int(R->modulus);
961         int nhalf = degree(a)/2;
962
963         int i = 1;
964         umodpoly w(2);
965         w[0] = R->zero();
966         w[1] = R->one();
967         umodpoly x = w;
968
969         while ( i <= nhalf ) {
970                 expt_pos(w, q);
971                 umodpoly buf;
972                 rem(w, a, buf);
973                 w = buf;
974                 umodpoly wx = w - x;
975                 gcd(a, wx, buf);
976                 if ( unequal_one(buf) ) {
977                         degrees.push_back(i);
978                         ddfactors.push_back(buf);
979                 }
980                 if ( unequal_one(buf) ) {
981                         umodpoly buf2;
982                         div(a, buf, buf2);
983                         a = buf2;
984                         nhalf = degree(a)/2;
985                         rem(w, a, buf);
986                         w = buf;
987                 }
988                 ++i;
989         }
990         if ( unequal_one(a) ) {
991                 degrees.push_back(degree(a));
992                 ddfactors.push_back(a);
993         }
994 }
995
996 static void same_degree_factor(const umodpoly& a, upvec& upv)
997 {
998         cl_modint_ring R = a[0].ring();
999
1000         vector<int> degrees;
1001         upvec ddfactors;
1002         distinct_degree_factor(a, degrees, ddfactors);
1003
1004         for ( size_t i=0; i<degrees.size(); ++i ) {
1005                 if ( degrees[i] == degree(ddfactors[i]) ) {
1006                         upv.push_back(ddfactors[i]);
1007                 }
1008                 else {
1009                         berlekamp(ddfactors[i], upv);
1010                 }
1011         }
1012 }
1013
1014 #define USE_SAME_DEGREE_FACTOR
1015
1016 static void factor_modular(const umodpoly& p, upvec& upv)
1017 {
1018 #ifdef USE_SAME_DEGREE_FACTOR
1019         same_degree_factor(p, upv);
1020 #else
1021         berlekamp(p, upv);
1022 #endif
1023 }
1024
1025 /** Calculates polynomials s and t such that a*s+b*t==1.
1026  *  Assertion: a and b are relatively prime and not zero.
1027  *
1028  *  @param[in]  a  polynomial
1029  *  @param[in]  b  polynomial
1030  *  @param[out] s  polynomial
1031  *  @param[out] t  polynomial
1032  */
1033 static void exteuclid(const umodpoly& a, const umodpoly& b, umodpoly& s, umodpoly& t)
1034 {
1035         if ( degree(a) < degree(b) ) {
1036                 exteuclid(b, a, t, s);
1037                 return;
1038         }
1039
1040         umodpoly one(1, a[0].ring()->one());
1041         umodpoly c = a; normalize_in_field(c);
1042         umodpoly d = b; normalize_in_field(d);
1043         s = one;
1044         t.clear();
1045         umodpoly d1;
1046         umodpoly d2 = one;
1047         umodpoly q;
1048         while ( true ) {
1049                 div(c, d, q);
1050                 umodpoly r = c - q * d;
1051                 umodpoly r1 = s - q * d1;
1052                 umodpoly r2 = t - q * d2;
1053                 c = d;
1054                 s = d1;
1055                 t = d2;
1056                 if ( r.empty() ) break;
1057                 d = r;
1058                 d1 = r1;
1059                 d2 = r2;
1060         }
1061         cl_MI fac = recip(lcoeff(a) * lcoeff(c));
1062         umodpoly::iterator i = s.begin(), end = s.end();
1063         for ( ; i!=end; ++i ) {
1064                 *i = *i * fac;
1065         }
1066         canonicalize(s);
1067         fac = recip(lcoeff(b) * lcoeff(c));
1068         i = t.begin(), end = t.end();
1069         for ( ; i!=end; ++i ) {
1070                 *i = *i * fac;
1071         }
1072         canonicalize(t);
1073 }
1074
1075 static upoly replace_lc(const upoly& poly, const cl_I& lc)
1076 {
1077         if ( poly.empty() ) return poly;
1078         upoly r = poly;
1079         r.back() = lc;
1080         return r;
1081 }
1082
1083 static inline cl_I calc_bound(const ex& a, const ex& x, int maxdeg)
1084 {
1085         cl_I maxcoeff = 0;
1086         cl_R coeff = 0;
1087         for ( int i=a.degree(x); i>=a.ldegree(x); --i ) {
1088                 cl_I aa = abs(the<cl_I>(ex_to<numeric>(a.coeff(x, i)).to_cl_N()));
1089                 if ( aa > maxcoeff ) maxcoeff = aa;
1090                 coeff = coeff + square(aa);
1091         }
1092         cl_I coeffnorm = ceiling1(the<cl_R>(cln::sqrt(coeff)));
1093         cl_I B = coeffnorm * expt_pos(cl_I(2), cl_I(maxdeg));
1094         return ( B > maxcoeff ) ? B : maxcoeff;
1095 }
1096
1097 static inline cl_I calc_bound(const upoly& a, int maxdeg)
1098 {
1099         cl_I maxcoeff = 0;
1100         cl_R coeff = 0;
1101         for ( int i=degree(a); i>=0; --i ) {
1102                 cl_I aa = abs(a[i]);
1103                 if ( aa > maxcoeff ) maxcoeff = aa;
1104                 coeff = coeff + square(aa);
1105         }