Daily bugfix in the polynomial factorization (code didn't catch polynomial "x")
[ginac.git] / ginac / factor.cpp
1 /** @file factor.cpp
2  *
3  *  Polynomial factorization routines.
4  *  Only univariate at the moment and completely non-optimized!
5  */
6
7 /*
8  *  GiNaC Copyright (C) 1999-2008 Johannes Gutenberg University Mainz, Germany
9  *
10  *  This program is free software; you can redistribute it and/or modify
11  *  it under the terms of the GNU General Public License as published by
12  *  the Free Software Foundation; either version 2 of the License, or
13  *  (at your option) any later version.
14  *
15  *  This program is distributed in the hope that it will be useful,
16  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
17  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
18  *  GNU General Public License for more details.
19  *
20  *  You should have received a copy of the GNU General Public License
21  *  along with this program; if not, write to the Free Software
22  *  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
23  */
24
25 #include "factor.h"
26
27 #include "ex.h"
28 #include "numeric.h"
29 #include "operators.h"
30 #include "inifcns.h"
31 #include "symbol.h"
32 #include "relational.h"
33 #include "power.h"
34 #include "mul.h"
35 #include "normal.h"
36 #include "add.h"
37
38 #include <algorithm>
39 #include <list>
40 #include <vector>
41 using namespace std;
42
43 #include <cln/cln.h>
44 using namespace cln;
45
46 //#define DEBUGFACTOR
47
48 #ifdef DEBUGFACTOR
49 #include <ostream>
50 #endif // def DEBUGFACTOR
51
52 namespace GiNaC {
53
54 namespace {
55
// Vector of modular integers. When turned into a polynomial via the
// UniPoly(ring, Vec) constructor, index i holds the coefficient of x^i.
typedef vector<cl_MI> Vec;
// Vector of such vectors (used e.g. for the nullspace basis in berlekamp).
typedef vector<Vec> VecVec;
58
#ifdef DEBUGFACTOR
// Debug output: every entry of the vector, each followed by a single blank.
ostream& operator<<(ostream& o, const Vec& v)
{
	for ( size_t k=0; k<v.size(); ++k ) {
		o << v[k] << " ";
	}
	return o;
}
#endif // def DEBUGFACTOR
69
#ifdef DEBUGFACTOR
// Debug output: one inner vector per line.
ostream& operator<<(ostream& o, const VecVec& v)
{
	for ( size_t k=0; k<v.size(); ++k ) {
		o << v[k] << endl;
	}
	return o;
}
#endif // def DEBUGFACTOR
80
/** A single term c*x^exp of a univariate polynomial over a modular ring. */
struct Term
{
	cl_MI c;          // coefficient
	unsigned int exp; // exponent >=0
};
86
#ifdef DEBUGFACTOR
// Debug output of one term as "(c)x^e"; a constant term is printed as "(c)".
ostream& operator<<(ostream& o, const Term& t)
{
	o << "(" << t.c << ")";
	if ( t.exp ) {
		o << "x^" << t.exp;
	}
	return o;
}
#endif // def DEBUGFACTOR
99
100 struct UniPoly
101 {
102         cl_modint_ring R;
103         list<Term> terms;  // highest exponent first
104
105         UniPoly(const cl_modint_ring& ring) : R(ring) { }
106         UniPoly(const cl_modint_ring& ring, const ex& poly, const ex& x) : R(ring)
107         { 
108                 // assert: poly is in Z[x]
109                 Term t;
110                 for ( int i=poly.degree(x); i>=poly.ldegree(x); --i ) {
111                         int coeff = ex_to<numeric>(poly.coeff(x,i)).to_int();
112                         if ( coeff ) {
113                                 t.c = R->canonhom(coeff);
114                                 if ( !zerop(t.c) ) {
115                                         t.exp = i;
116                                         terms.push_back(t);
117                                 }
118                         }
119                 }
120         }
121         UniPoly(const cl_modint_ring& ring, const Vec& v) : R(ring)
122         {
123                 Term t;
124                 for ( unsigned int i=0; i<v.size(); ++i ) {
125                         if ( !zerop(v[i]) ) {
126                                 t.c = v[i];
127                                 t.exp = i;
128                                 terms.push_front(t);
129                         }
130                 }
131         }
132         unsigned int degree() const
133         {
134                 if ( terms.size() ) {
135                         return terms.front().exp;
136                 }
137                 else {
138                         return 0;
139                 }
140         }
141         bool zero() const { return (terms.size() == 0); }
142         const cl_MI operator[](unsigned int deg) const
143         {
144                 list<Term>::const_iterator i = terms.begin(), end = terms.end();
145                 for ( ; i != end; ++i ) {
146                         if ( i->exp == deg ) {
147                                 return i->c;
148                         }
149                         if ( i->exp < deg ) {
150                                 break;
151                         }
152                 }
153                 return R->zero();
154         }
155         void set(unsigned int deg, const cl_MI& c)
156         {
157                 list<Term>::iterator i = terms.begin(), end = terms.end();
158                 while ( i != end ) {
159                         if ( i->exp == deg ) {
160                                 if ( !zerop(c) ) {
161                                         i->c = c;
162                                 }
163                                 else {
164                                         terms.erase(i);
165                                 }
166                                 return;
167                         }
168                         if ( i->exp < deg ) {
169                                 break;
170                         }
171                         ++i;
172                 }
173                 if ( !zerop(c) ) {
174                         Term t;
175                         t.c = c;
176                         t.exp = deg;
177                         terms.insert(i, t);
178                 }
179         }
180         ex to_ex(const ex& x, bool symmetric = true) const
181         {
182                 ex r;
183                 list<Term>::const_iterator i = terms.begin(), end = terms.end();
184                 if ( symmetric ) {
185                         numeric mod(R->modulus);
186                         numeric halfmod = (mod-1)/2;
187                         for ( ; i != end; ++i ) {
188                                 numeric n(R->retract(i->c));
189                                 if ( n > halfmod ) {
190                                         r += pow(x, i->exp) * (n-mod);
191                                 }
192                                 else {
193                                         r += pow(x, i->exp) * n;
194                                 }
195                         }
196                 }
197                 else {
198                         for ( ; i != end; ++i ) {
199                                 r += pow(x, i->exp) * numeric(R->retract(i->c));
200                         }
201                 }
202                 return r;
203         }
204         void unit_normal()
205         {
206                 if ( terms.size() ) {
207                         if ( terms.front().c != R->one() ) {
208                                 list<Term>::iterator i = terms.begin(), end = terms.end();
209                                 cl_MI cont = i->c;
210                                 i->c = R->one();
211                                 while ( ++i != end ) {
212                                         i->c = div(i->c, cont);
213                                         if ( zerop(i->c) ) {
214                                                 terms.erase(i);
215                                         }
216                                 }
217                         }
218                 }
219         }
220         cl_MI unit() const
221         {
222                 return terms.front().c;
223         }
224         void divide(const cl_MI& x)
225         {
226                 list<Term>::iterator i = terms.begin(), end = terms.end();
227                 for ( ; i != end; ++i ) {
228                         i->c = div(i->c, x);
229                         if ( zerop(i->c) ) {
230                                 terms.erase(i);
231                         }
232                 }
233         }
234         void reduce_exponents(unsigned int prime)
235         {
236                 list<Term>::iterator i = terms.begin(), end = terms.end();
237                 while ( i != end ) {
238                         if ( i->exp > 0 ) {
239                                 // assert: i->exp is multiple of prime
240                                 i->exp /= prime;
241                         }
242                         ++i;
243                 }
244         }
245         void deriv(UniPoly& d) const
246         {
247                 list<Term>::const_iterator i = terms.begin(), end = terms.end();
248                 while ( i != end ) {
249                         if ( i->exp ) {
250                                 cl_MI newc = i->c * i->exp;
251                                 if ( !zerop(newc) ) {
252                                         Term t;
253                                         t.c = newc;
254                                         t.exp = i->exp-1;
255                                         d.terms.push_back(t);
256                                 }
257                         }
258                         ++i;
259                 }
260         }
261         bool operator<(const UniPoly& o) const
262         {
263                 if ( terms.size() != o.terms.size() ) {
264                         return terms.size() < o.terms.size();
265                 }
266                 list<Term>::const_iterator i1 = terms.begin(), end = terms.end();
267                 list<Term>::const_iterator i2 = o.terms.begin();
268                 while ( i1 != end ) {
269                         if ( i1->exp != i2->exp ) {
270                                 return i1->exp < i2->exp;
271                         }
272                         if ( i1->c != i2->c ) {
273                                 return R->retract(i1->c) < R->retract(i2->c);
274                         }
275                         ++i1; ++i2;
276                 }
277                 return true;
278         }
279         bool operator==(const UniPoly& o) const
280         {
281                 if ( terms.size() != o.terms.size() ) {
282                         return false;
283                 }
284                 list<Term>::const_iterator i1 = terms.begin(), end = terms.end();
285                 list<Term>::const_iterator i2 = o.terms.begin();
286                 while ( i1 != end ) {
287                         if ( i1->exp != i2->exp ) {
288                                 return false;
289                         }
290                         if ( i1->c != i2->c ) {
291                                 return false;
292                         }
293                         ++i1; ++i2;
294                 }
295                 return true;
296         }
297         bool operator!=(const UniPoly& o) const
298         {
299                 bool res = !(*this == o);
300                 return res;
301         }
302 };
303
304 static UniPoly operator*(const UniPoly& a, const UniPoly& b)
305 {
306         unsigned int n = a.degree()+b.degree();
307         UniPoly c(a.R);
308         Term t;
309         for ( unsigned int i=0 ; i<=n; ++i ) {
310                 t.c = a.R->zero();
311                 for ( unsigned int j=0 ; j<=i; ++j ) {
312                         t.c = t.c + a[j] * b[i-j];
313                 }
314                 if ( !zerop(t.c) ) {
315                         t.exp = i;
316                         c.terms.push_front(t);
317                 }
318         }
319         return c;
320 }
321
322 static UniPoly operator-(const UniPoly& a, const UniPoly& b)
323 {
324         list<Term>::const_iterator ia = a.terms.begin(), aend = a.terms.end();
325         list<Term>::const_iterator ib = b.terms.begin(), bend = b.terms.end();
326         UniPoly c(a.R);
327         while ( ia != aend && ib != bend ) {
328                 if ( ia->exp > ib->exp ) {
329                         c.terms.push_back(*ia);
330                         ++ia;
331                 }
332                 else if ( ia->exp < ib->exp ) {
333                         c.terms.push_back(*ib);
334                         c.terms.back().c = -c.terms.back().c;
335                         ++ib;
336                 }
337                 else {
338                         Term t;
339                         t.exp = ia->exp;
340                         t.c = ia->c - ib->c;
341                         if ( !zerop(t.c) ) {
342                                 c.terms.push_back(t);
343                         }
344                         ++ia; ++ib;
345                 }
346         }
347         while ( ia != aend ) {
348                 c.terms.push_back(*ia);
349                 ++ia;
350         }
351         while ( ib != bend ) {
352                 c.terms.push_back(*ib);
353                 c.terms.back().c = -c.terms.back().c;
354                 ++ib;
355         }
356         return c;
357 }
358
359 static UniPoly operator-(const UniPoly& a)
360 {
361         list<Term>::const_iterator ia = a.terms.begin(), aend = a.terms.end();
362         UniPoly c(a.R);
363         while ( ia != aend ) {
364                 c.terms.push_back(*ia);
365                 c.terms.back().c = -c.terms.back().c;
366                 ++ia;
367         }
368         return c;
369 }
370
#ifdef DEBUGFACTOR
// Debug output of a polynomial as its terms joined by " + "; "0" if empty.
ostream& operator<<(ostream& o, const UniPoly& t)
{
	if ( t.terms.empty() ) {
		o << "0";
		return o;
	}
	bool first = true;
	for ( list<Term>::const_iterator it = t.terms.begin(); it != t.terms.end(); ++it ) {
		if ( !first ) {
			o << " + ";
		}
		o << *it;
		first = false;
	}
	return o;
}
#endif // def DEBUGFACTOR
388
#ifdef DEBUGFACTOR
// Debug output of a list of polynomials: one per line, enclosed in braces.
ostream& operator<<(ostream& o, const list<UniPoly>& t)
{
	o << "{" << endl;
	for ( list<UniPoly>::const_iterator it = t.begin(); it != t.end(); ++it ) {
		o << *it << endl;
	}
	o << "}" << endl;
	return o;
}
#endif // def DEBUGFACTOR
401
402 typedef vector<UniPoly> UniPolyVec;
403
404 struct UniFactor
405 {
406         UniPoly p;
407         unsigned int exp;
408
409         UniFactor(const cl_modint_ring& ring) : p(ring) { }
410         UniFactor(const UniPoly& p_, unsigned int exp_) : p(p_), exp(exp_) { }
411         bool operator<(const UniFactor& o) const
412         {
413                 return p < o.p;
414         }
415 };
416
417 struct UniFactorVec
418 {
419         vector<UniFactor> factors;
420
421         void unique()
422         {
423                 sort(factors.begin(), factors.end());
424                 if ( factors.size() > 1 ) {
425                         vector<UniFactor>::iterator i = factors.begin();
426                         vector<UniFactor>::const_iterator cmp = factors.begin()+1;
427                         vector<UniFactor>::iterator end = factors.end();
428                         while ( cmp != end ) {
429                                 if ( i->p != cmp->p ) {
430                                         ++i;
431                                         ++cmp;
432                                 }
433                                 else {
434                                         i->exp += cmp->exp;
435                                         ++cmp;
436                                 }
437                         }
438                         if ( i != end-1 ) {
439                                 factors.erase(i+1, end);
440                         }
441                 }
442         }
443 };
444
#ifdef DEBUGFACTOR
// Debug output: one "[ p ]^exp" per line. Every line except the last is
// prefixed with "*", and the last one with a blank.
ostream& operator<<(ostream& o, const UniFactorVec& ufv)
{
	const size_t count = ufv.factors.size();
	for ( size_t k=0; k<count; ++k ) {
		const bool last = (k+1 == count);
		o << (last ? " " : "*");
		o << "[ " << ufv.factors[k].p << " ]^" << ufv.factors[k].exp << endl;
	}
	return o;
}
#endif // def DEBUGFACTOR
460
461 static void rem(const UniPoly& a_, const UniPoly& b, UniPoly& c)
462 {
463         if ( a_.degree() < b.degree() ) {
464                 c = a_;
465                 return;
466         }
467
468         unsigned int k, n;
469         n = b.degree();
470         k = a_.degree() - n;
471
472         if ( n == 0 ) {
473                 c.terms.clear();
474                 return;
475         }
476
477         c = a_;
478         Term termbuf;
479
480         while ( true ) {
481                 cl_MI qk = div(c[n+k], b[n]);
482                 if ( !zerop(qk) ) {
483                         unsigned int j;
484                         for ( unsigned int i=0; i<n; ++i ) {
485                                 j = n + k - 1 - i;
486                                 c.set(j, c[j] - qk*b[j-k]);
487                         }
488                 }
489                 if ( k == 0 ) break;
490                 --k;
491         }
492         list<Term>::iterator i = c.terms.begin(), end = c.terms.end();
493         while ( i != end ) {
494                 if ( i->exp <= n-1 ) {
495                         break;
496                 }
497                 ++i;
498         }
499         c.terms.erase(c.terms.begin(), i);
500 }
501
502 static void div(const UniPoly& a_, const UniPoly& b, UniPoly& q)
503 {
504         if ( a_.degree() < b.degree() ) {
505                 q.terms.clear();
506                 return;
507         }
508
509         unsigned int k, n;
510         n = b.degree();
511         k = a_.degree() - n;
512
513         UniPoly c = a_;
514         Term termbuf;
515
516         while ( true ) {
517                 cl_MI qk = div(c[n+k], b[n]);
518                 if ( !zerop(qk) ) {
519                         Term t;
520                         t.c = qk;
521                         t.exp = k;
522                         q.terms.push_back(t);
523                         unsigned int j;
524                         for ( unsigned int i=0; i<n; ++i ) {
525                                 j = n + k - 1 - i;
526                                 c.set(j, c[j] - qk*b[j-k]);
527                         }
528                 }
529                 if ( k == 0 ) break;
530                 --k;
531         }
532 }
533
534 static void gcd(const UniPoly& a, const UniPoly& b, UniPoly& c)
535 {
536         c = a;
537         c.unit_normal();
538         UniPoly d = b;
539         d.unit_normal();
540
541         if ( c.degree() < d.degree() ) {
542                 gcd(b, a, c);
543                 return;
544         }
545
546         while ( !d.zero() ) {
547                 UniPoly r(a.R);
548                 rem(c, d, r);
549                 c = d;
550                 d = r;
551         }
552         c.unit_normal();
553 }
554
555 static bool is_one(const UniPoly& w)
556 {
557         if ( w.terms.size() == 1 && w[0] == w.R->one() ) {
558                 return true;
559         }
560         return false;
561 }
562
563 static void sqrfree_main(const UniPoly& a, UniFactorVec& fvec)
564 {
565         unsigned int i = 1;
566         UniPoly b(a.R);
567         a.deriv(b);
568         if ( !b.zero() ) {
569                 UniPoly c(a.R), w(a.R);
570                 gcd(a, b, c);
571                 div(a, c, w);
572                 while ( !is_one(w) ) {
573                         UniPoly y(a.R), z(a.R);
574                         gcd(w, c, y);
575                         div(w, y, z);
576                         if ( !is_one(z) ) {
577                                 UniFactor uf(z, i);
578                                 fvec.factors.push_back(uf);
579                         }
580                         ++i;
581                         w = y;
582                         UniPoly cbuf(a.R);
583                         div(c, y, cbuf);
584                         c = cbuf;
585                 }
586                 if ( !is_one(c) ) {
587                         unsigned int prime = cl_I_to_uint(c.R->modulus);
588                         c.reduce_exponents(prime);
589                         unsigned int pos = fvec.factors.size();
590                         sqrfree_main(c, fvec);
591                         for ( unsigned int p=pos; p<fvec.factors.size(); ++p ) {
592                                 fvec.factors[p].exp *= prime;
593                         }
594                         return;
595                 }
596         }
597         else {
598                 unsigned int prime = cl_I_to_uint(a.R->modulus);
599                 UniPoly amod = a;
600                 amod.reduce_exponents(prime);
601                 unsigned int pos = fvec.factors.size();
602                 sqrfree_main(amod, fvec);
603                 for ( unsigned int p=pos; p<fvec.factors.size(); ++p ) {
604                         fvec.factors[p].exp *= prime;
605                 }
606                 return;
607         }
608 }
609
610 static void squarefree(const UniPoly& a, UniFactorVec& fvec)
611 {
612         sqrfree_main(a, fvec);
613         fvec.unique();
614 }
615
616 class Matrix
617 {
618         friend ostream& operator<<(ostream& o, const Matrix& m);
619 public:
620         Matrix(size_t r_, size_t c_, const cl_MI& init) : r(r_), c(c_)
621         {
622                 m.resize(c*r, init);
623         }
624         size_t rowsize() const { return r; }
625         size_t colsize() const { return c; }
626         cl_MI& operator()(size_t row, size_t col) { return m[row*c + col]; }
627         cl_MI operator()(size_t row, size_t col) const { return m[row*c + col]; }
628         void mul_col(size_t col, const cl_MI x)
629         {
630                 Vec::iterator i = m.begin() + col;
631                 for ( size_t rc=0; rc<r; ++rc ) {
632                         *i = *i * x;
633                         i += c;
634                 }
635         }
636         void sub_col(size_t col1, size_t col2, const cl_MI fac)
637         {
638                 Vec::iterator i1 = m.begin() + col1;
639                 Vec::iterator i2 = m.begin() + col2;
640                 for ( size_t rc=0; rc<r; ++rc ) {
641                         *i1 = *i1 - *i2 * fac;
642                         i1 += c;
643                         i2 += c;
644                 }
645         }
646         void switch_col(size_t col1, size_t col2)
647         {
648                 cl_MI buf;
649                 Vec::iterator i1 = m.begin() + col1;
650                 Vec::iterator i2 = m.begin() + col2;
651                 for ( size_t rc=0; rc<r; ++rc ) {
652                         buf = *i1; *i1 = *i2; *i2 = buf;
653                         i1 += c;
654                         i2 += c;
655                 }
656         }
657         bool is_row_zero(size_t row) const
658         {
659                 Vec::const_iterator i = m.begin() + row*c;
660                 for ( size_t cc=0; cc<c; ++cc ) {
661                         if ( !zerop(*i) ) {
662                                 return false;
663                         }
664                         ++i;
665                 }
666                 return true;
667         }
668         void set_row(size_t row, const vector<cl_MI>& newrow)
669         {
670                 Vec::iterator i1 = m.begin() + row*c;
671                 Vec::const_iterator i2 = newrow.begin(), end = newrow.end();
672                 for ( ; i2 != end; ++i1, ++i2 ) {
673                         *i1 = *i2;
674                 }
675         }
676         Vec::const_iterator row_begin(size_t row) const { return m.begin()+row*c; }
677         Vec::const_iterator row_end(size_t row) const { return m.begin()+row*c+r; }
678 private:
679         size_t r, c;
680         Vec m;
681 };
682
#ifdef DEBUGFACTOR
// Debug output of a matrix: entries separated by blanks, a line break after
// every complete row, and one extra line break at the end.
ostream& operator<<(ostream& o, const Matrix& m)
{
	const size_t count = m.m.size();
	for ( size_t k=0; k<count; ++k ) {
		o << m.m[k] << " ";
		if ( (k+1) % m.c == 0 ) {
			o << endl;
		}
	}
	o << endl;
	return o;
}
#endif // def DEBUGFACTOR
698
699 static void q_matrix(const UniPoly& a, Matrix& Q)
700 {
701         unsigned int n = a.degree();
702         unsigned int q = cl_I_to_uint(a.R->modulus);
703 // fast and buggy
704 //      vector<cl_MI> r(n, a.R->zero());
705 //      r[0] = a.R->one();
706 //      Q.set_row(0, r);
707 //      unsigned int max = (n-1) * q;
708 //      for ( size_t m=1; m<=max; ++m ) {
709 //              cl_MI rn_1 = r.back();
710 //              for ( size_t i=n-1; i>0; --i ) {
711 //                      r[i] = r[i-1] - rn_1 * a[i];
712 //              }
713 //              r[0] = -rn_1 * a[0];
714 //              if ( (m % q) == 0 ) {
715 //                      Q.set_row(m/q, r);
716 //              }
717 //      }
718 // slow and (hopefully) correct
719         for ( size_t i=0; i<n; ++i ) {
720                 UniPoly qk(a.R);
721                 qk.set(i*q, a.R->one());
722                 UniPoly r(a.R);
723                 rem(qk, a, r);
724                 Vec rvec;
725                 for ( size_t j=0; j<n; ++j ) {
726                         rvec.push_back(r[j]);
727                 }
728                 Q.set_row(i, rvec);
729         }
730 }
731
/** Column-wise elimination on M - I; every row that is nonzero afterwards
 *  is appended to basis. Intended to produce a nullspace basis of M - I.
 *  M is modified in place and used as scratch space.
 *  NOTE(review): assumes M is square with at least one row, since M(0,0) is
 *  read unconditionally to obtain the ring. TODO confirm all callers
 *  guarantee this. */
static void nullspace(Matrix& M, vector<Vec>& basis)
{
	const size_t n = M.rowsize();
	// the ring's one, taken from the first entry
	const cl_MI one = M(0,0).ring()->one();
	// form M - I
	for ( size_t i=0; i<n; ++i ) {
		M(i,i) = M(i,i) - one;
	}
	// process one row at a time
	for ( size_t r=0; r<n; ++r ) {
		size_t cc = 0;
		// Look for a nonzero entry in row r to pivot on. A column cc < r whose
		// diagonal entry is nonzero is skipped; the chosen column is moved
		// to position r.
		for ( ; cc<n; ++cc ) {
			if ( !zerop(M(r,cc)) ) {
				if ( cc < r ) {
					if ( !zerop(M(cc,cc)) ) {
						continue;
					}
					M.switch_col(cc, r);
				}
				else if ( cc > r ) {
					M.switch_col(cc, r);
				}
				break;
			}
		}
		if ( cc < n ) {
			// scale the pivot column so that M(r,r) == 1, then clear every
			// other entry of row r by column operations
			M.mul_col(r, recip(M(r,r)));
			for ( cc=0; cc<n; ++cc ) {
				if ( cc != r ) {
					M.sub_col(cc, r, M(r,cc));
				}
			}
		}
	}

	// subtract the identity once more and collect the nonzero rows
	for ( size_t i=0; i<n; ++i ) {
		M(i,i) = M(i,i) - one;
	}
	for ( size_t i=0; i<n; ++i ) {
		if ( !M.is_row_zero(i) ) {
			Vec nu(M.row_begin(i), M.row_end(i));
			basis.push_back(nu);
		}
	}
}
775
/** Berlekamp factorization of a over the prime field of its ring.
 *  Builds the Q matrix of a and the basis nu from nullspace(); k is the
 *  number of basis vectors. If k == 1, the function returns without
 *  touching upv. Otherwise it repeatedly splits the current factor *u with
 *  g = gcd(nu[r] - s, *u) for s = 0..q-1, until the list holds k
 *  polynomials of positive degree. Then every polynomial in the list is
 *  appended to upv.
 *  nu[0] is never used (r starts at and resets to 1); presumably it is the
 *  trivial constant solution.
 *  Throws logic_error if a split leaves the cofactor 1.
 *  NOTE(review): assumes a is squarefree and monic. TODO confirm with callers.
 *  NOTE(review): if k factors are never found, ++u can advance past the
 *  end of `factors`. */
static void berlekamp(const UniPoly& a, UniPolyVec& upv)
{
	Matrix Q(a.degree(), a.degree(), a.R->zero());
	q_matrix(a, Q);
	VecVec nu;
	nullspace(Q, nu);
	const unsigned int k = nu.size();
	if ( k == 1 ) {
		return;
	}

	list<UniPoly> factors;
	factors.push_back(a);
	unsigned int size = 1;
	unsigned int r = 1;
	unsigned int q = cl_I_to_uint(a.R->modulus);

	// factor currently being split
	list<UniPoly>::iterator u = factors.begin();

	while ( true ) {
		for ( unsigned int s=0; s<q; ++s ) {
			UniPoly g(a.R);
			// nur = nu[r] - s, read as a polynomial
			UniPoly nur(a.R, nu[r]);
			nur.set(0, nur[0] - cl_MI(a.R, s));
			gcd(nur, *u, g);
			if ( !is_one(g) && g != *u ) {
				// nontrivial split: *u = g * uo
				UniPoly uo(a.R);
				div(*u, g, uo);
				if ( is_one(uo) ) {
					throw logic_error("berlekamp: unexpected divisor.");
				}
				else {
					*u = uo;
				}
				factors.push_back(g);
				// count the factors of positive degree
				size = 0;
				list<UniPoly>::const_iterator i = factors.begin(), end = factors.end();
				while ( i != end ) {
					if ( i->degree() ) ++size; 
					++i;
				}
				if ( size == k ) {
					list<UniPoly>::const_iterator i = factors.begin(), end = factors.end();
					while ( i != end ) {
						upv.push_back(*i++);
					}
					return;
				}
//				if ( u->degree() < nur.degree() ) {
//					break;
//				}
			}
		}
		// next basis vector; after the last one, move on to the next factor
		if ( ++r == k ) {
			r = 1;
			++u;
		}
	}
}
835
836 static void factor_modular(const UniPoly& p, UniPolyVec& upv)
837 {
838         berlekamp(p, upv);
839         return;
840 }
841
842 static void exteuclid(const UniPoly& a, const UniPoly& b, UniPoly& g, UniPoly& s, UniPoly& t)
843 {
844         if ( a.degree() < b.degree() ) {
845                 exteuclid(b, a, g, t, s);
846                 return;
847         }
848         UniPoly c1(a.R), c2(a.R), d1(a.R), d2(a.R), q(a.R), r(a.R), r1(a.R), r2(a.R);
849         UniPoly c = a; c.unit_normal();
850         UniPoly d = b; d.unit_normal();
851         c1.set(0, a.R->one());
852         d2.set(0, a.R->one());
853         while ( !d.zero() ) {
854                 q.terms.clear();
855                 div(c, d, q);
856                 r = c - q * d;
857                 r1 = c1 - q * d1;
858                 r2 = c2 - q * d2;
859                 c = d;
860                 c1 = d1;
861                 c2 = d2;
862                 d = r;
863                 d1 = r1;
864                 d2 = r2;
865         }
866         g = c; g.unit_normal();
867         s = c1;
868         s.divide(a.unit());
869         s.divide(c.unit());
870         t = c2;
871         t.divide(b.unit());
872         t.divide(c.unit());
873 }
874
875 static ex replace_lc(const ex& poly, const ex& x, const ex& lc)
876 {
877         ex r = expand(poly + (lc - poly.lcoeff(x)) * pow(x, poly.degree(x)));
878         return r;
879 }
880
881 static ex hensel_univar(const ex& a_, const ex& x, unsigned int p, const UniPoly& u1_, const UniPoly& w1_, const ex& gamma_ = 0)
882 {
883         ex a = a_;
884         const cl_modint_ring& R = u1_.R;
885
886         // calc bound B
887         ex maxcoeff;
888         for ( int i=a.degree(x); i>=a.ldegree(x); --i ) {
889                 maxcoeff += pow(abs(a.coeff(x, i)),2);
890         }
891         cl_I normmc = ceiling1(the<cl_R>(cln::sqrt(ex_to<numeric>(maxcoeff).to_cl_N())));
892         unsigned int maxdegree = (u1_.degree() > w1_.degree()) ? u1_.degree() : w1_.degree();
893         unsigned int B = cl_I_to_uint(normmc * expt_pos(cl_I(2), maxdegree));
894
895         // step 1
896         ex alpha = a.lcoeff(x);
897         ex gamma = gamma_;
898         if ( gamma == 0 ) {
899                 gamma = alpha;
900         }
901         unsigned int gamma_ui = ex_to<numeric>(abs(gamma)).to_int();
902         a = a * gamma;
903         UniPoly nu1 = u1_;
904         nu1.unit_normal();
905         UniPoly nw1 = w1_;
906         nw1.unit_normal();
907         ex phi;
908         phi = expand(gamma * nu1.to_ex(x));
909         UniPoly u1(R, phi, x);
910         phi = expand(alpha * nw1.to_ex(x));
911         UniPoly w1(R, phi, x);
912
913         // step 2
914         UniPoly s(R), t(R), g(R);
915         exteuclid(u1, w1, g, s, t);
916
917         // step 3
918         ex u = replace_lc(u1.to_ex(x), x, gamma);
919         ex w = replace_lc(w1.to_ex(x), x, alpha);
920         ex e = expand(a - u * w);
921         unsigned int modulus = p;
922
923         // step 4
924         while ( !e.is_zero() && modulus < 2*B*gamma_ui ) {
925                 ex c = e / modulus;
926                 phi = expand(s.to_ex(x)*c);
927                 UniPoly sigmatilde(R, phi, x);
928                 phi = expand(t.to_ex(x)*c);
929                 UniPoly tautilde(R, phi, x);
930                 UniPoly q(R), r(R);
931                 div(sigmatilde, w1, q);
932                 rem(sigmatilde, w1, r);
933                 UniPoly sigma = r;
934                 phi = expand(tautilde.to_ex(x) + q.to_ex(x) * u1.to_ex(x));
935                 UniPoly tau(R, phi, x);
936                 u = expand(u + tau.to_ex(x) * modulus);
937                 w = expand(w + sigma.to_ex(x) * modulus);
938                 e = expand(a - u * w);
939                 modulus = modulus * p;
940         }
941
942         // step 5
943         if ( e.is_zero() ) {
944                 ex delta = u.content(x);
945                 u = u / delta;
946                 w = w / gamma * delta;
947                 return lst(u, w);
948         }
949         else {
950                 return lst();
951         }
952 }
953
/** Return the smallest odd prime strictly greater than p (2 is never returned).
 *
 *  Primes are cached in a function-local static vector in increasing order.
 *  When p lies beyond the cache, the cache is extended until it holds a prime
 *  larger than p. The previous version stopped advancing its candidate in this
 *  case and pushed the same prime forever.
 *  Not thread-safe: the cache is mutated without synchronization.
 */
static unsigned int next_prime(unsigned int p)
{
	static vector<unsigned int> primes;
	if ( primes.empty() ) {
		primes.push_back(3); primes.push_back(5); primes.push_back(7);
	}

	// extend the cache until its last entry exceeds p
	while ( primes.back() <= p ) {
		unsigned int candidate = primes.back();
		bool composite = true;
		while ( composite ) {
			candidate += 2;
			composite = false;
			// Trial division by the cached odd primes up to sqrt(candidate).
			// The cache holds all odd primes below candidate, and candidate is odd.
			// (*q <= candidate / *q) is an overflow-free form of q*q <= candidate.
			for ( vector<unsigned int>::const_iterator q = primes.begin();
			      q != primes.end() && *q <= candidate / *q; ++q ) {
				if ( candidate % *q == 0 ) {
					composite = true;
					break;
				}
			}
		}
		primes.push_back(candidate);
	}

	// the cache is sorted and its last entry exceeds p, so this is always valid
	return *upper_bound(primes.begin(), primes.end(), p);
}
983
/** Enumerates splits of n items into two groups, with item 0 always in the first group.
 *
 *  operator[](i) is 0 if item i is in the first group and 1 if it is in the second.
 *  The initial state puts item 0 alone into the first group and all others
 *  into the second. Each call to next() counts the membership digits
 *  k[1..n-1] down by one in binary. It returns false once the second group
 *  would become empty or every combination has been visited. For n >= 2, every
 *  split with a non-empty second group is produced exactly once.
 */
class Partition
{
public:
	Partition(size_t n_) : n(n_)
	{
		k.assign(n, 1);
		k[0] = 0;     // item 0 is pinned to the first group
		sum = n - 1;  // number of items currently in the second group
	}
	int operator[](size_t i) const { return k[i]; }
	size_t size() const { return n; }
	size_t size_first() const { return n - sum; }
	size_t size_second() const { return sum; }
	bool next()
	{
		size_t pos = n;
		while ( --pos > 0 ) {
			if ( k[pos] == 0 ) {
				// borrow: this digit wraps around from 0 to 1
				k[pos] = 1;
				++sum;
				continue;
			}
			k[pos] = 0;
			--sum;
			return sum != 0;
		}
		return false;
	}
private:
	size_t n, sum;
	vector<int> k;
};
1014
1015 static void split(const UniPolyVec& factors, const Partition& part, UniPoly& a, UniPoly& b)
1016 {
1017         a.set(0, a.R->one());
1018         b.set(0, a.R->one());
1019         for ( size_t i=0; i<part.size(); ++i ) {
1020                 if ( part[i] ) {
1021                         b = b * factors[i];
1022                 }
1023                 else {
1024                         a = a * factors[i];
1025                 }
1026         }
1027 }
1028
1029 struct ModFactors
1030 {
1031         ex poly;
1032         UniPolyVec factors;
1033 };
1034
1035 static ex factor_univariate(const ex& poly, const ex& x)
1036 {
1037         ex unit, cont, prim;
1038         poly.unitcontprim(x, unit, cont, prim);
1039
1040         // determine proper prime
1041         unsigned int p = 3;
1042         cl_modint_ring R = find_modint_ring(p);
1043         while ( true ) {
1044                 if ( irem(ex_to<numeric>(prim.lcoeff(x)), p) != 0 ) {
1045                         UniPoly modpoly(R, prim, x);
1046                         UniFactorVec sqrfree_ufv;
1047                         squarefree(modpoly, sqrfree_ufv);
1048                         if ( sqrfree_ufv.factors.size() == 1 && sqrfree_ufv.factors.front().exp == 1 ) break;
1049                 }
1050                 p = next_prime(p);
1051                 R = find_modint_ring(p);
1052         }
1053
1054         // do modular factorization
1055         UniPoly modpoly(R, prim, x);
1056         UniPolyVec factors;
1057         factor_modular(modpoly, factors);
1058         if ( factors.size() <= 1 ) {
1059                 // irreducible for sure
1060                 return poly;
1061         }
1062
1063         // lift all factor combinations
1064         stack<ModFactors> tocheck;
1065         ModFactors mf;
1066         mf.poly = prim;
1067         mf.factors = factors;
1068         tocheck.push(mf);
1069         ex result = 1;
1070         while ( tocheck.size() ) {
1071                 const size_t n = tocheck.top().factors.size();
1072                 Partition part(n);
1073                 while ( true ) {
1074                         UniPoly a(R), b(R);
1075                         split(tocheck.top().factors, part, a, b);
1076
1077                         ex answer = hensel_univar(tocheck.top().poly, x, p, a, b);
1078                         if ( answer != lst() ) {
1079                                 if ( part.size_first() == 1 ) {
1080                                         if ( part.size_second() == 1 ) {
1081                                                 result *= answer.op(0) * answer.op(1);
1082                                                 tocheck.pop();
1083                                                 break;
1084                                         }
1085                                         result *= answer.op(0);
1086                                         tocheck.top().poly = answer.op(1);
1087                                         for ( size_t i=0; i<n; ++i ) {
1088                                                 if ( part[i] == 0 ) {
1089                                                         tocheck.top().factors.erase(tocheck.top().factors.begin()+i);
1090                                                         break;
1091                                                 }
1092                                         }
1093                                         break;
1094                                 }
1095                                 else if ( part.size_second() == 1 ) {
1096                                         if ( part.size_first() == 1 ) {
1097                                                 result *= answer.op(0) * answer.op(1);
1098                                                 tocheck.pop();
1099                                                 break;
1100                                         }
1101                                         result *= answer.op(1);
1102                                         tocheck.top().poly = answer.op(0);
1103                                         for ( size_t i=0; i<n; ++i ) {
1104                                                 if ( part[i] == 1 ) {
1105                                                         tocheck.top().factors.erase(tocheck.top().factors.begin()+i);
1106                                                         break;
1107                                                 }
1108                                         }
1109                                         break;
1110                                 }
1111                                 else {
1112                                         UniPolyVec newfactors1(part.size_first(), R), newfactors2(part.size_second(), R);
1113                                         UniPolyVec::iterator i1 = newfactors1.begin(), i2 = newfactors2.begin();
1114                                         for ( size_t i=0; i<n; ++i ) {
1115                                                 if ( part[i] ) {
1116                                                         *i2++ = tocheck.top().factors[i];
1117                                                 }
1118                                                 else {
1119                                                         *i1++ = tocheck.top().factors[i];
1120                                                 }
1121                                         }
1122                                         tocheck.top().factors = newfactors1;
1123                                         tocheck.top().poly = answer.op(0);
1124                                         ModFactors mf;
1125                                         mf.factors = newfactors2;
1126                                         mf.poly = answer.op(1);
1127                                         tocheck.push(mf);
1128                                         break;
1129                                 }
1130                         }
1131                         else {
1132                                 if ( !part.next() ) {
1133                                         result *= tocheck.top().poly;
1134                                         tocheck.pop();
1135                                         break;
1136                                 }
1137                         }
1138                 }
1139         }
1140
1141         return unit * cont * result;
1142 }
1143
1144 struct FindSymbolsMap : public map_function {
1145         exset syms;
1146         ex operator()(const ex& e)
1147         {
1148                 if ( is_a<symbol>(e) ) {
1149                         syms.insert(e);
1150                         return e;
1151                 }
1152                 return e.map(*this);
1153         }
1154 };
1155
1156 static ex factor_sqrfree(const ex& poly)
1157 {
1158         // determine all symbols in poly
1159         FindSymbolsMap findsymbols;
1160         findsymbols(poly);
1161         if ( findsymbols.syms.size() == 0 ) {
1162                 return poly;
1163         }
1164
1165         if ( findsymbols.syms.size() == 1 ) {
1166                 const ex& x = *(findsymbols.syms.begin());
1167                 if ( poly.ldegree(x) > 0 ) {
1168                         int ld = poly.ldegree(x);
1169                         ex res = factor_univariate(expand(poly/pow(x, ld)), x);
1170                         return res * pow(x,ld);
1171                 }
1172                 else {
1173                         ex res = factor_univariate(poly, x);
1174                         return res;
1175                 }
1176         }
1177
1178         // multivariate case not yet implemented!
1179         throw runtime_error("multivariate case not yet implemented!");
1180 }
1181
1182 } // anonymous namespace
1183
1184 ex factor(const ex& poly)
1185 {
1186         // determine all symbols in poly
1187         FindSymbolsMap findsymbols;
1188         findsymbols(poly);
1189         if ( findsymbols.syms.size() == 0 ) {
1190                 return poly;
1191         }
1192         lst syms;
1193         exset::const_iterator i=findsymbols.syms.begin(), end=findsymbols.syms.end();
1194         for ( ; i!=end; ++i ) {
1195                 syms.append(*i);
1196         }
1197
1198         // make poly square free
1199         ex sfpoly = sqrfree(poly, syms);
1200
1201         // factorize the square free components
1202         if ( is_a<power>(sfpoly) ) {
1203                 // case: (polynomial)^exponent
1204                 const ex& base = sfpoly.op(0);
1205                 if ( !is_a<add>(base) ) {
1206                         // simple case: (monomial)^exponent
1207                         return sfpoly;
1208                 }
1209                 ex f = factor_sqrfree(base);
1210                 return pow(f, sfpoly.op(1));
1211         }
1212         if ( is_a<mul>(sfpoly) ) {
1213                 ex res = 1;
1214                 for ( size_t i=0; i<sfpoly.nops(); ++i ) {
1215                         const ex& t = sfpoly.op(i);
1216                         if ( is_a<power>(t) ) {
1217                                 const ex& base = t.op(0);
1218                                 if ( !is_a<add>(base) ) {
1219                                         res *= t;
1220                                 }
1221                                 else {
1222                                         ex f = factor_sqrfree(base);
1223                                         res *= pow(f, t.op(1));
1224                                 }
1225                         }
1226                         else if ( is_a<add>(t) ) {
1227                                 ex f = factor_sqrfree(t);
1228                                 res *= f;
1229                         }
1230                         else {
1231                                 res *= t;
1232                         }
1233                 }
1234                 return res;
1235         }
1236         if ( is_a<symbol>(sfpoly) ) {
1237                 return poly;
1238         }
1239         // case: (polynomial)
1240         ex f = factor_sqrfree(sfpoly);
1241         return f;
1242 }
1243
1244 } // namespace GiNaC