Added polynomial factorization (univariate case).
[ginac.git] / ginac / factor.cpp
1 /** @file factor.cpp
2  *
3  *  Polynomial factorization routines.
4  *  Only univariate at the moment and completely non-optimized!
5  */
6
7 /*
8  *  GiNaC Copyright (C) 1999-2008 Johannes Gutenberg University Mainz, Germany
9  *
10  *  This program is free software; you can redistribute it and/or modify
11  *  it under the terms of the GNU General Public License as published by
12  *  the Free Software Foundation; either version 2 of the License, or
13  *  (at your option) any later version.
14  *
15  *  This program is distributed in the hope that it will be useful,
16  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
17  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
18  *  GNU General Public License for more details.
19  *
20  *  You should have received a copy of the GNU General Public License
21  *  along with this program; if not, write to the Free Software
22  *  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
23  */
24
25 #include "factor.h"
26
27 #include "ex.h"
28 #include "numeric.h"
29 #include "operators.h"
30 #include "inifcns.h"
31 #include "symbol.h"
32 #include "relational.h"
33 #include "power.h"
34 #include "mul.h"
35 #include "normal.h"
36 #include "add.h"
37
38 #include <algorithm>
39 #include <list>
40 #include <vector>
41 using namespace std;
42
43 #include <cln/cln.h>
44 using namespace cln;
45
46 //#define DEBUGFACTOR
47
48 #ifdef DEBUGFACTOR
49 #include <ostream>
50 #endif // def DEBUGFACTOR
51
52 namespace GiNaC {
53
54 namespace {
55
56 typedef vector<cl_MI> Vec;
57 typedef vector<Vec> VecVec;
58
59 #ifdef DEBUGFACTOR
60 ostream& operator<<(ostream& o, const Vec& v)
61 {
62         Vec::const_iterator i = v.begin(), end = v.end();
63         while ( i != end ) {
64                 o << *i++ << " ";
65         }
66         return o;
67 }
68 #endif // def DEBUGFACTOR
69
70 #ifdef DEBUGFACTOR
71 ostream& operator<<(ostream& o, const VecVec& v)
72 {
73         VecVec::const_iterator i = v.begin(), end = v.end();
74         while ( i != end ) {
75                 o << *i++ << endl;
76         }
77         return o;
78 }
79 #endif // def DEBUGFACTOR
80
81 struct Term
82 {
83         cl_MI c;          // coefficient
84         unsigned int exp; // exponent >=0
85 };
86
87 #ifdef DEBUGFACTOR
88 ostream& operator<<(ostream& o, const Term& t)
89 {
90         if ( t.exp ) {
91                 o << "(" << t.c << ")x^" << t.exp;
92         }
93         else {
94                 o << "(" << t.c << ")";
95         }
96         return o;
97 }
98 #endif // def DEBUGFACTOR
99
100 struct UniPoly
101 {
102         cl_modint_ring R;
103         list<Term> terms;  // highest exponent first
104
105         UniPoly(const cl_modint_ring& ring) : R(ring) { }
106         UniPoly(const cl_modint_ring& ring, const ex& poly, const ex& x) : R(ring)
107         { 
108                 // assert: poly is in Z[x]
109                 Term t;
110                 for ( int i=poly.degree(x); i>=poly.ldegree(x); --i ) {
111                         int coeff = ex_to<numeric>(poly.coeff(x,i)).to_int();
112                         if ( coeff ) {
113                                 t.c = R->canonhom(coeff);
114                                 if ( !zerop(t.c) ) {
115                                         t.exp = i;
116                                         terms.push_back(t);
117                                 }
118                         }
119                 }
120         }
121         UniPoly(const cl_modint_ring& ring, const Vec& v) : R(ring)
122         {
123                 Term t;
124                 for ( unsigned int i=0; i<v.size(); ++i ) {
125                         if ( !zerop(v[i]) ) {
126                                 t.c = v[i];
127                                 t.exp = i;
128                                 terms.push_front(t);
129                         }
130                 }
131         }
132         unsigned int degree() const
133         {
134                 if ( terms.size() ) {
135                         return terms.front().exp;
136                 }
137                 else {
138                         return 0;
139                 }
140         }
141         bool zero() const { return (terms.size() == 0); }
142         const cl_MI operator[](unsigned int deg) const
143         {
144                 list<Term>::const_iterator i = terms.begin(), end = terms.end();
145                 for ( ; i != end; ++i ) {
146                         if ( i->exp == deg ) {
147                                 return i->c;
148                         }
149                         if ( i->exp < deg ) {
150                                 break;
151                         }
152                 }
153                 return R->zero();
154         }
155         void set(unsigned int deg, const cl_MI& c)
156         {
157                 list<Term>::iterator i = terms.begin(), end = terms.end();
158                 while ( i != end ) {
159                         if ( i->exp == deg ) {
160                                 if ( !zerop(c) ) {
161                                         i->c = c;
162                                 }
163                                 else {
164                                         terms.erase(i);
165                                 }
166                                 return;
167                         }
168                         if ( i->exp < deg ) {
169                                 break;
170                         }
171                         ++i;
172                 }
173                 if ( !zerop(c) ) {
174                         Term t;
175                         t.c = c;
176                         t.exp = deg;
177                         terms.insert(i, t);
178                 }
179         }
180         ex to_ex(const ex& x, bool symmetric = true) const
181         {
182                 ex r;
183                 list<Term>::const_iterator i = terms.begin(), end = terms.end();
184                 if ( symmetric ) {
185                         numeric mod(R->modulus);
186                         numeric halfmod = (mod-1)/2;
187                         for ( ; i != end; ++i ) {
188                                 numeric n(R->retract(i->c));
189                                 if ( n > halfmod ) {
190                                         r += pow(x, i->exp) * (n-mod);
191                                 }
192                                 else {
193                                         r += pow(x, i->exp) * n;
194                                 }
195                         }
196                 }
197                 else {
198                         for ( ; i != end; ++i ) {
199                                 r += pow(x, i->exp) * numeric(R->retract(i->c));
200                         }
201                 }
202                 return r;
203         }
204         void unit_normal()
205         {
206                 if ( terms.size() ) {
207                         if ( terms.front().c != R->one() ) {
208                                 list<Term>::iterator i = terms.begin(), end = terms.end();
209                                 cl_MI cont = i->c;
210                                 i->c = R->one();
211                                 while ( ++i != end ) {
212                                         i->c = div(i->c, cont);
213                                         if ( zerop(i->c) ) {
214                                                 terms.erase(i);
215                                         }
216                                 }
217                         }
218                 }
219         }
220         cl_MI unit() const
221         {
222                 return terms.front().c;
223         }
224         void divide(const cl_MI& x)
225         {
226                 list<Term>::iterator i = terms.begin(), end = terms.end();
227                 for ( ; i != end; ++i ) {
228                         i->c = div(i->c, x);
229                         if ( zerop(i->c) ) {
230                                 terms.erase(i);
231                         }
232                 }
233         }
234         void reduce_exponents(unsigned int prime)
235         {
236                 list<Term>::iterator i = terms.begin(), end = terms.end();
237                 while ( i != end ) {
238                         if ( i->exp > 0 ) {
239                                 // assert: i->exp is multiple of prime
240                                 i->exp /= prime;
241                         }
242                         ++i;
243                 }
244         }
245         void deriv(UniPoly& d) const
246         {
247                 list<Term>::const_iterator i = terms.begin(), end = terms.end();
248                 while ( i != end ) {
249                         if ( i->exp ) {
250                                 cl_MI newc = i->c * i->exp;
251                                 if ( !zerop(newc) ) {
252                                         Term t;
253                                         t.c = newc;
254                                         t.exp = i->exp-1;
255                                         d.terms.push_back(t);
256                                 }
257                         }
258                         ++i;
259                 }
260         }
261         bool operator<(const UniPoly& o) const
262         {
263                 if ( terms.size() != o.terms.size() ) {
264                         return terms.size() < o.terms.size();
265                 }
266                 list<Term>::const_iterator i1 = terms.begin(), end = terms.end();
267                 list<Term>::const_iterator i2 = o.terms.begin();
268                 while ( i1 != end ) {
269                         if ( i1->exp != i2->exp ) {
270                                 return i1->exp < i2->exp;
271                         }
272                         if ( i1->c != i2->c ) {
273                                 return R->retract(i1->c) < R->retract(i2->c);
274                         }
275                         ++i1; ++i2;
276                 }
277                 return true;
278         }
279         bool operator==(const UniPoly& o) const
280         {
281                 if ( terms.size() != o.terms.size() ) {
282                         return false;
283                 }
284                 list<Term>::const_iterator i1 = terms.begin(), end = terms.end();
285                 list<Term>::const_iterator i2 = o.terms.begin();
286                 while ( i1 != end ) {
287                         if ( i1->exp != i2->exp ) {
288                                 return false;
289                         }
290                         if ( i1->c != i2->c ) {
291                                 return false;
292                         }
293                         ++i1; ++i2;
294                 }
295                 return true;
296         }
297         bool operator!=(const UniPoly& o) const
298         {
299                 bool res = !(*this == o);
300                 return res;
301         }
302 };
303
304 static UniPoly operator*(const UniPoly& a, const UniPoly& b)
305 {
306         unsigned int n = a.degree()+b.degree();
307         UniPoly c(a.R);
308         Term t;
309         for ( unsigned int i=0 ; i<=n; ++i ) {
310                 t.c = a.R->zero();
311                 for ( unsigned int j=0 ; j<=i; ++j ) {
312                         t.c = t.c + a[j] * b[i-j];
313                 }
314                 if ( !zerop(t.c) ) {
315                         t.exp = i;
316                         c.terms.push_front(t);
317                 }
318         }
319         return c;
320 }
321
322 static UniPoly operator-(const UniPoly& a, const UniPoly& b)
323 {
324         list<Term>::const_iterator ia = a.terms.begin(), aend = a.terms.end();
325         list<Term>::const_iterator ib = b.terms.begin(), bend = b.terms.end();
326         UniPoly c(a.R);
327         while ( ia != aend && ib != bend ) {
328                 if ( ia->exp > ib->exp ) {
329                         c.terms.push_back(*ia);
330                         ++ia;
331                 }
332                 else if ( ia->exp < ib->exp ) {
333                         c.terms.push_back(*ib);
334                         c.terms.back().c = -c.terms.back().c;
335                         ++ib;
336                 }
337                 else {
338                         Term t;
339                         t.exp = ia->exp;
340                         t.c = ia->c - ib->c;
341                         if ( !zerop(t.c) ) {
342                                 c.terms.push_back(t);
343                         }
344                         ++ia; ++ib;
345                 }
346         }
347         while ( ia != aend ) {
348                 c.terms.push_back(*ia);
349                 ++ia;
350         }
351         while ( ib != bend ) {
352                 c.terms.push_back(*ib);
353                 c.terms.back().c = -c.terms.back().c;
354                 ++ib;
355         }
356         return c;
357 }
358
359 static UniPoly operator-(const UniPoly& a)
360 {
361         list<Term>::const_iterator ia = a.terms.begin(), aend = a.terms.end();
362         UniPoly c(a.R);
363         while ( ia != aend ) {
364                 c.terms.push_back(*ia);
365                 c.terms.back().c = -c.terms.back().c;
366                 ++ia;
367         }
368         return c;
369 }
370
371 #ifdef DEBUGFACTOR
372 ostream& operator<<(ostream& o, const UniPoly& t)
373 {
374         list<Term>::const_iterator i = t.terms.begin(), end = t.terms.end();
375         if ( i == end ) {
376                 o << "0";
377                 return o;
378         }
379         for ( ; i != end; ) {
380                 o << *i++;
381                 if ( i != end ) {
382                         o << " + ";
383                 }
384         }
385         return o;
386 }
387 #endif // def DEBUGFACTOR
388
389 #ifdef DEBUGFACTOR
390 ostream& operator<<(ostream& o, const list<UniPoly>& t)
391 {
392         list<UniPoly>::const_iterator i = t.begin(), end = t.end();
393         o << "{" << endl;
394         for ( ; i != end; ) {
395                 o << *i++ << endl;
396         }
397         o << "}" << endl;
398         return o;
399 }
400 #endif // def DEBUGFACTOR
401
402 typedef vector<UniPoly> UniPolyVec;
403
404 struct UniFactor
405 {
406         UniPoly p;
407         unsigned int exp;
408
409         UniFactor(const cl_modint_ring& ring) : p(ring) { }
410         UniFactor(const UniPoly& p_, unsigned int exp_) : p(p_), exp(exp_) { }
411         bool operator<(const UniFactor& o) const
412         {
413                 return p < o.p;
414         }
415 };
416
417 struct UniFactorVec
418 {
419         vector<UniFactor> factors;
420
421         void unique()
422         {
423                 sort(factors.begin(), factors.end());
424                 if ( factors.size() > 1 ) {
425                         vector<UniFactor>::iterator i = factors.begin();
426                         vector<UniFactor>::const_iterator cmp = factors.begin()+1;
427                         vector<UniFactor>::iterator end = factors.end();
428                         while ( cmp != end ) {
429                                 if ( i->p != cmp->p ) {
430                                         ++i;
431                                         ++cmp;
432                                 }
433                                 else {
434                                         i->exp += cmp->exp;
435                                         ++cmp;
436                                 }
437                         }
438                         if ( i != end-1 ) {
439                                 factors.erase(i+1, end);
440                         }
441                 }
442         }
443 };
444
445 #ifdef DEBUGFACTOR
446 ostream& operator<<(ostream& o, const UniFactorVec& ufv)
447 {
448         for ( size_t i=0; i<ufv.factors.size(); ++i ) {
449                 if ( i != ufv.factors.size()-1 ) {
450                         o << "*";
451                 }
452                 else {
453                         o << " ";
454                 }
455                 o << "[ " << ufv.factors[i].p << " ]^" << ufv.factors[i].exp << endl;
456         }
457         return o;
458 }
459 #endif // def DEBUGFACTOR
460
461 static void rem(const UniPoly& a_, const UniPoly& b, UniPoly& c)
462 {
463         if ( a_.degree() < b.degree() ) {
464                 c = a_;
465                 return;
466         }
467
468         unsigned int k, n;
469         n = b.degree();
470         k = a_.degree() - n;
471
472         if ( n == 0 ) {
473                 c.terms.clear();
474                 return;
475         }
476
477         c = a_;
478         Term termbuf;
479
480         while ( true ) {
481                 cl_MI qk = div(c[n+k], b[n]);
482                 if ( !zerop(qk) ) {
483                         unsigned int j;
484                         for ( unsigned int i=0; i<n; ++i ) {
485                                 j = n + k - 1 - i;
486                                 c.set(j, c[j] - qk*b[j-k]);
487                         }
488                 }
489                 if ( k == 0 ) break;
490                 --k;
491         }
492         list<Term>::iterator i = c.terms.begin(), end = c.terms.end();
493         while ( i != end ) {
494                 if ( i->exp <= n-1 ) {
495                         break;
496                 }
497                 ++i;
498         }
499         c.terms.erase(c.terms.begin(), i);
500 }
501
502 static void div(const UniPoly& a_, const UniPoly& b, UniPoly& q)
503 {
504         if ( a_.degree() < b.degree() ) {
505                 q.terms.clear();
506                 return;
507         }
508
509         unsigned int k, n;
510         n = b.degree();
511         k = a_.degree() - n;
512
513         UniPoly c = a_;
514         Term termbuf;
515
516         while ( true ) {
517                 cl_MI qk = div(c[n+k], b[n]);
518                 if ( !zerop(qk) ) {
519                         Term t;
520                         t.c = qk;
521                         t.exp = k;
522                         q.terms.push_back(t);
523                         unsigned int j;
524                         for ( unsigned int i=0; i<n; ++i ) {
525                                 j = n + k - 1 - i;
526                                 c.set(j, c[j] - qk*b[j-k]);
527                         }
528                 }
529                 if ( k == 0 ) break;
530                 --k;
531         }
532 }
533
534 static void gcd(const UniPoly& a, const UniPoly& b, UniPoly& c)
535 {
536         c = a;
537         c.unit_normal();
538         UniPoly d = b;
539         d.unit_normal();
540
541         if ( c.degree() < d.degree() ) {
542                 gcd(b, a, c);
543                 return;
544         }
545
546         while ( !d.zero() ) {
547                 UniPoly r(a.R);
548                 rem(c, d, r);
549                 c = d;
550                 d = r;
551         }
552         c.unit_normal();
553 }
554
555 static bool is_one(const UniPoly& w)
556 {
557         if ( w.terms.size() == 1 && w[0] == w.R->one() ) {
558                 return true;
559         }
560         return false;
561 }
562
563 static void sqrfree_main(const UniPoly& a, UniFactorVec& fvec)
564 {
565         unsigned int i = 1;
566         UniPoly b(a.R);
567         a.deriv(b);
568         if ( !b.zero() ) {
569                 UniPoly c(a.R), w(a.R);
570                 gcd(a, b, c);
571                 div(a, c, w);
572                 while ( !is_one(w) ) {
573                         UniPoly y(a.R), z(a.R);
574                         gcd(w, c, y);
575                         div(w, y, z);
576                         if ( !is_one(z) ) {
577                                 UniFactor uf(z, i++);
578                                 fvec.factors.push_back(uf);
579                         }
580                         w = y;
581                         UniPoly cbuf(a.R);
582                         div(c, y, cbuf);
583                         c = cbuf;
584                 }
585                 if ( !is_one(c) ) {
586                         unsigned int prime = cl_I_to_uint(c.R->modulus);
587                         c.reduce_exponents(prime);
588                         unsigned int pos = fvec.factors.size();
589                         sqrfree_main(c, fvec);
590                         for ( unsigned int p=pos; p<fvec.factors.size(); ++p ) {
591                                 fvec.factors[p].exp *= prime;
592                         }
593                         return;
594                 }
595         }
596         else {
597                 unsigned int prime = cl_I_to_uint(a.R->modulus);
598                 UniPoly amod = a;
599                 amod.reduce_exponents(prime);
600                 unsigned int pos = fvec.factors.size();
601                 sqrfree_main(amod, fvec);
602                 for ( unsigned int p=pos; p<fvec.factors.size(); ++p ) {
603                         fvec.factors[p].exp *= prime;
604                 }
605                 return;
606         }
607 }
608
609 static void squarefree(const UniPoly& a, UniFactorVec& fvec)
610 {
611         sqrfree_main(a, fvec);
612         fvec.unique();
613 }
614
615 class Matrix
616 {
617         friend ostream& operator<<(ostream& o, const Matrix& m);
618 public:
619         Matrix(size_t r_, size_t c_, const cl_MI& init) : r(r_), c(c_)
620         {
621                 m.resize(c*r, init);
622         }
623         size_t rowsize() const { return r; }
624         size_t colsize() const { return c; }
625         cl_MI& operator()(size_t row, size_t col) { return m[row*c + col]; }
626         cl_MI operator()(size_t row, size_t col) const { return m[row*c + col]; }
627         void mul_col(size_t col, const cl_MI x)
628         {
629                 Vec::iterator i = m.begin() + col;
630                 for ( size_t rc=0; rc<r; ++rc ) {
631                         *i = *i * x;
632                         i += c;
633                 }
634         }
635         void sub_col(size_t col1, size_t col2, const cl_MI fac)
636         {
637                 Vec::iterator i1 = m.begin() + col1;
638                 Vec::iterator i2 = m.begin() + col2;
639                 for ( size_t rc=0; rc<r; ++rc ) {
640                         *i1 = *i1 - *i2 * fac;
641                         i1 += c;
642                         i2 += c;
643                 }
644         }
645         void switch_col(size_t col1, size_t col2)
646         {
647                 cl_MI buf;
648                 Vec::iterator i1 = m.begin() + col1;
649                 Vec::iterator i2 = m.begin() + col2;
650                 for ( size_t rc=0; rc<r; ++rc ) {
651                         buf = *i1; *i1 = *i2; *i2 = buf;
652                         i1 += c;
653                         i2 += c;
654                 }
655         }
656         bool is_row_zero(size_t row) const
657         {
658                 Vec::const_iterator i = m.begin() + row*c;
659                 for ( size_t cc=0; cc<c; ++cc ) {
660                         if ( !zerop(*i) ) {
661                                 return false;
662                         }
663                         ++i;
664                 }
665                 return true;
666         }
667         void set_row(size_t row, const vector<cl_MI>& newrow)
668         {
669                 Vec::iterator i1 = m.begin() + row*c;
670                 Vec::const_iterator i2 = newrow.begin(), end = newrow.end();
671                 for ( ; i2 != end; ++i1, ++i2 ) {
672                         *i1 = *i2;
673                 }
674         }
675         Vec::const_iterator row_begin(size_t row) const { return m.begin()+row*c; }
676         Vec::const_iterator row_end(size_t row) const { return m.begin()+row*c+r; }
677 private:
678         size_t r, c;
679         Vec m;
680 };
681
682 #ifdef DEBUGFACTOR
683 ostream& operator<<(ostream& o, const Matrix& m)
684 {
685         vector<cl_MI>::const_iterator i = m.m.begin(), end = m.m.end();
686         size_t wrap = 1;
687         for ( ; i != end; ++i ) {
688                 o << *i << " ";
689                 if ( !(wrap++ % m.c) ) {
690                         o << endl;
691                 }
692         }
693         o << endl;
694         return o;
695 }
696 #endif // def DEBUGFACTOR
697
698 static void q_matrix(const UniPoly& a, Matrix& Q)
699 {
700         unsigned int n = a.degree();
701         unsigned int q = cl_I_to_uint(a.R->modulus);
702         vector<cl_MI> r(n, a.R->zero());
703         r[0] = a.R->one();
704         Q.set_row(0, r);
705         unsigned int max = (n-1) * q;
706         for ( size_t m=1; m<=max; ++m ) {
707                 cl_MI rn_1 = r.back();
708                 for ( size_t i=n-1; i>0; --i ) {
709                         r[i] = r[i-1] - rn_1 * a[i];
710                 }
711                 r[0] = -rn_1 * a[0];
712                 if ( (m % q) == 0 ) {
713                         Q.set_row(m/q, r);
714                 }
715         }
716 }
717
718 static void nullspace(Matrix& M, vector<Vec>& basis)
719 {
720         const size_t n = M.rowsize();
721         const cl_MI one = M(0,0).ring()->one();
722         for ( size_t i=0; i<n; ++i ) {
723                 M(i,i) = M(i,i) - one;
724         }
725         for ( size_t r=0; r<n; ++r ) {
726                 size_t cc = 0;
727                 for ( ; cc<n; ++cc ) {
728                         if ( !zerop(M(r,cc)) ) {
729                                 if ( cc < r ) {
730                                         if ( !zerop(M(cc,cc)) ) {
731                                                 continue;
732                                         }
733                                         M.switch_col(cc, r);
734                                 }
735                                 else if ( cc > r ) {
736                                         M.switch_col(cc, r);
737                                 }
738                                 break;
739                         }
740                 }
741                 if ( cc < n ) {
742                         M.mul_col(r, recip(M(r,r)));
743                         for ( cc=0; cc<n; ++cc ) {
744                                 if ( cc != r ) {
745                                         M.sub_col(cc, r, M(r,cc));
746                                 }
747                         }
748                 }
749         }
750
751         for ( size_t i=0; i<n; ++i ) {
752                 M(i,i) = M(i,i) - one;
753         }
754         for ( size_t i=0; i<n; ++i ) {
755                 if ( !M.is_row_zero(i) ) {
756                         Vec nu(M.row_begin(i), M.row_end(i));
757                         basis.push_back(nu);
758                 }
759         }
760 }
761
762 static void berlekamp(const UniPoly& a, UniPolyVec& upv)
763 {
764         Matrix Q(a.degree(), a.degree(), a.R->zero());
765         q_matrix(a, Q);
766         VecVec nu;
767         nullspace(Q, nu);
768         const unsigned int k = nu.size();
769         if ( k == 1 ) {
770                 return;
771         }
772
773         list<UniPoly> factors;
774         factors.push_back(a);
775         unsigned int size = 1;
776         unsigned int r = 1;
777         unsigned int q = cl_I_to_uint(a.R->modulus);
778
779         list<UniPoly>::iterator u = factors.begin();
780
781         while ( true ) {
782                 for ( unsigned int s=0; s<q; ++s ) {
783                         UniPoly g(a.R);
784                         UniPoly nur(a.R, nu[r]);
785                         nur.set(0, nur[0] - cl_MI(a.R, s));
786                         gcd(nur, *u, g);
787                         if ( !is_one(g) && g != *u ) {
788                                 UniPoly uo(a.R);
789                                 div(*u, g, uo);
790                                 if ( is_one(uo) ) {
791                                         throw logic_error("berlekamp: unexpected divisor.");
792                                 }
793                                 else {
794                                         *u = uo;
795                                 }
796                                 factors.push_back(g);
797                                 ++size;
798                                 if ( size == k ) {
799                                         list<UniPoly>::const_iterator i = factors.begin(), end = factors.end();
800                                         while ( i != end ) {
801                                                 upv.push_back(*i++);
802                                         }
803                                         return;
804                                 }
805                                 if ( u->degree() < nur.degree() ) {
806                                         break;
807                                 }
808                         }
809                 }
810                 if ( ++r == k ) {
811                         r = 1;
812                         ++u;
813                 }
814         }
815 }
816
817 static void factor_modular(const UniPoly& p, UniPolyVec& upv)
818 {
819         berlekamp(p, upv);
820         return;
821 }
822
823 static void exteuclid(const UniPoly& a, const UniPoly& b, UniPoly& g, UniPoly& s, UniPoly& t)
824 {
825         if ( a.degree() < b.degree() ) {
826                 exteuclid(b, a, g, t, s);
827                 return;
828         }
829         UniPoly c1(a.R), c2(a.R), d1(a.R), d2(a.R), q(a.R), r(a.R), r1(a.R), r2(a.R);
830         UniPoly c = a; c.unit_normal();
831         UniPoly d = b; d.unit_normal();
832         c1.set(0, a.R->one());
833         d2.set(0, a.R->one());
834         while ( !d.zero() ) {
835                 q.terms.clear();
836                 div(c, d, q);
837                 r = c - q * d;
838                 r1 = c1 - q * d1;
839                 r2 = c2 - q * d2;
840                 c = d;
841                 c1 = d1;
842                 c2 = d2;
843                 d = r;
844                 d1 = r1;
845                 d2 = r2;
846         }
847         g = c; g.unit_normal();
848         s = c1;
849         s.divide(a.unit());
850         s.divide(c.unit());
851         t = c2;
852         t.divide(b.unit());
853         t.divide(c.unit());
854 }
855
856 static ex replace_lc(const ex& poly, const ex& x, const ex& lc)
857 {
858         ex r = expand(poly + (lc - poly.lcoeff(x)) * pow(x, poly.degree(x)));
859         return r;
860 }
861
862 static ex hensel_univar(const ex& a_, const ex& x, unsigned int p, const UniPoly& u1_, const UniPoly& w1_, const ex& gamma_ = 0)
863 {
864         ex a = a_;
865         const cl_modint_ring& R = u1_.R;
866
867         // calc bound B
868         ex maxcoeff;
869         for ( int i=a.degree(x); i>=a.ldegree(x); --i ) {
870                 maxcoeff += pow(abs(a.coeff(x, i)),2);
871         }
872         cl_I normmc = ceiling1(the<cl_F>(cln::sqrt(ex_to<numeric>(maxcoeff).to_cl_N())));
873         unsigned int maxdegree = (u1_.degree() > w1_.degree()) ? u1_.degree() : w1_.degree();
874         unsigned int B = cl_I_to_uint(normmc * expt_pos(cl_I(2), maxdegree));
875
876         // step 1
877         ex alpha = a.lcoeff(x);
878         ex gamma = gamma_;
879         if ( gamma == 0 ) {
880                 gamma = alpha;
881         }
882         unsigned int gamma_ui = ex_to<numeric>(abs(gamma)).to_int();
883         a = a * gamma;
884         UniPoly nu1 = u1_;
885         nu1.unit_normal();
886         UniPoly nw1 = w1_;
887         nw1.unit_normal();
888         ex phi;
889         phi = expand(gamma * nu1.to_ex(x));
890         UniPoly u1(R, phi, x);
891         phi = expand(alpha * nw1.to_ex(x));
892         UniPoly w1(R, phi, x);
893
894         // step 2
895         UniPoly s(R), t(R), g(R);
896         exteuclid(u1, w1, g, s, t);
897
898         // step 3
899         ex u = replace_lc(u1.to_ex(x), x, gamma);
900         ex w = replace_lc(w1.to_ex(x), x, alpha);
901         ex e = expand(a - u * w);
902         unsigned int modulus = p;
903
904         // step 4
905         while ( !e.is_zero() && modulus < 2*B*gamma_ui ) {
906                 ex c = e / modulus;
907                 phi = expand(s.to_ex(x)*c);
908                 UniPoly sigmatilde(R, phi, x);
909                 phi = expand(t.to_ex(x)*c);
910                 UniPoly tautilde(R, phi, x);
911                 UniPoly q(R), r(R);
912                 div(sigmatilde, w1, q);
913                 rem(sigmatilde, w1, r);
914                 UniPoly sigma = r;
915                 phi = expand(tautilde.to_ex(x) + q.to_ex(x) * u1.to_ex(x));
916                 UniPoly tau(R, phi, x);
917                 u = expand(u + tau.to_ex(x) * modulus);
918                 w = expand(w + sigma.to_ex(x) * modulus);
919                 e = expand(a - u * w);
920                 modulus = modulus * p;
921         }
922
923         // step 5
924         if ( e.is_zero() ) {
925                 ex delta = u.content(x);
926                 u = u / delta;
927                 w = w / gamma * delta;
928                 return lst(u, w);
929         }
930         else {
931                 return lst();
932         }
933 }
934
/** Returns the smallest odd prime that is strictly greater than p.
 *
 *  The primes found so far are cached in a function-local table that starts
 *  with 3, 5, 7 and is extended on demand; the prime 2 is never returned.
 *  Every prime is stored exactly once, so the table stays strictly
 *  increasing. The table is a plain function-local static without any
 *  locking.
 *
 *  @param p  lower bound (exclusive)
 *  @return smallest odd prime q with q > p
 *  @throws std::overflow_error if no such prime fits into an unsigned int
 */
static unsigned int next_prime(unsigned int p)
{
        static std::vector<unsigned int> primes;
        if ( primes.empty() ) {
                primes.push_back(3); primes.push_back(5); primes.push_back(7);
        }

        // extend the table until its last entry exceeds p
        while ( primes.back() <= p ) {
                unsigned int candidate = primes.back();
                bool is_prime = false;
                while ( !is_prime ) {
                        if ( candidate + 2 < candidate ) {
                                throw std::overflow_error("next_prime: prime does not fit into unsigned int");
                        }
                        candidate += 2;
                        is_prime = true;
                        // trial division by the known primes up to sqrt(candidate);
                        // the quotient form of the bound cannot overflow
                        for ( size_t i=0; i<primes.size() && primes[i] <= candidate / primes[i]; ++i ) {
                                if ( candidate % primes[i] == 0 ) {
                                        is_prime = false;
                                        break;
                                }
                        }
                }
                primes.push_back(candidate);
        }

        // the table is sorted and its last entry is > p, so a result always exists
        return *std::upper_bound(primes.begin(), primes.end(), p);
}
964
/** Enumerates the ways of distributing n items over two groups.
 *
 *  Each item carries a flag: 0 puts it into the first group, 1 into the
 *  second. Item 0 always stays in the first group. The flags of the items
 *  1..n-1 start out all set and are counted down like a binary number whose
 *  least significant digit is the last item.
 */
class Partition
{
public:
        /** Starts with item 0 in the first group and all others in the second. */
        Partition(size_t n_) : n(n_), sum(n_-1), k(n_, 1)
        {
                k[0] = 0;
        }
        /** Group flag (0 or 1) of item i. */
        int operator[](size_t i) const { return k[i]; }
        /** Total number of items. */
        size_t size() const { return n; }
        /** Number of items currently in the first group. */
        size_t size_first() const { return n-sum; }
        /** Number of items currently in the second group. */
        size_t size_second() const { return sum; }
        /** Advances to the next distribution.
         *  @return false once the second group has become empty (or there is
         *          nothing left to count down), true otherwise
         */
        bool next()
        {
                // walk from the last item down to item 1
                for ( size_t pos=n; pos-- > 1; ) {
                        int& flag = k[pos];
                        if ( flag != 0 ) {
                                // digit was set: clear it, done
                                flag -= 1;
                                sum -= 1;
                                return sum != 0;
                        }
                        // digit was clear: set it and borrow from the next one
                        flag += 1;
                        sum += 1;
                }
                return false;
        }
private:
        size_t n, sum;
        vector<int> k;
};
995
996 static void split(const UniPolyVec& factors, const Partition& part, UniPoly& a, UniPoly& b)
997 {
998         a.set(0, a.R->one());
999         b.set(0, a.R->one());
1000         for ( size_t i=0; i<part.size(); ++i ) {
1001                 if ( part[i] ) {
1002                         b = b * factors[i];
1003                 }
1004                 else {
1005                         a = a * factors[i];
1006                 }
1007         }
1008 }
1009
1010 struct ModFactors
1011 {
1012         ex poly;
1013         UniPolyVec factors;
1014 };
1015
1016 static ex factor_univariate(const ex& poly, const ex& x)
1017 {
1018         ex unit, cont, prim;
1019         poly.unitcontprim(x, unit, cont, prim);
1020
1021         // determine proper prime
1022         unsigned int p = 3;
1023         cl_modint_ring R = find_modint_ring(p);
1024         while ( true ) {
1025                 if ( irem(ex_to<numeric>(prim.lcoeff(x)), p) != 0 ) {
1026                         UniPoly modpoly(R, prim, x);
1027                         UniFactorVec sqrfree_ufv;
1028                         squarefree(modpoly, sqrfree_ufv);
1029                         if ( sqrfree_ufv.factors.size() == 1 ) break;
1030                 }
1031                 p = next_prime(p);
1032                 R = find_modint_ring(p);
1033         }
1034
1035         // do modular factorization
1036         UniPoly modpoly(R, prim, x);
1037         UniPolyVec factors;
1038         factor_modular(modpoly, factors);
1039         if ( factors.size() <= 1 ) {
1040                 // irreducible for sure
1041                 return poly;
1042         }
1043
1044         // lift all factor combinations
1045         stack<ModFactors> tocheck;
1046         ModFactors mf;
1047         mf.poly = prim;
1048         mf.factors = factors;
1049         tocheck.push(mf);
1050         ex result = 1;
1051         while ( tocheck.size() ) {
1052                 const size_t n = tocheck.top().factors.size();
1053                 Partition part(n);
1054                 while ( true ) {
1055                         UniPoly a(R), b(R);
1056                         split(tocheck.top().factors, part, a, b);
1057
1058                         ex answer = hensel_univar(tocheck.top().poly, x, p, a, b);
1059                         if ( answer != lst() ) {
1060                                 if ( part.size_first() == 1 ) {
1061                                         if ( part.size_second() == 1 ) {
1062                                                 result *= answer.op(0) * answer.op(1);
1063                                                 tocheck.pop();
1064                                                 break;
1065                                         }
1066                                         result *= answer.op(0);
1067                                         tocheck.top().poly = answer.op(1);
1068                                         for ( size_t i=0; i<n; ++i ) {
1069                                                 if ( part[i] == 0 ) {
1070                                                         tocheck.top().factors.erase(tocheck.top().factors.begin()+i);
1071                                                         break;
1072                                                 }
1073                                         }
1074                                         break;
1075                                 }
1076                                 else if ( part.size_second() == 1 ) {
1077                                         if ( part.size_first() == 1 ) {
1078                                                 result *= answer.op(0) * answer.op(1);
1079                                                 tocheck.pop();
1080                                                 break;
1081                                         }
1082                                         result *= answer.op(1);
1083                                         tocheck.top().poly = answer.op(0);
1084                                         for ( size_t i=0; i<n; ++i ) {
1085                                                 if ( part[i] == 1 ) {
1086                                                         tocheck.top().factors.erase(tocheck.top().factors.begin()+i);
1087                                                         break;
1088                                                 }
1089                                         }
1090                                         break;
1091                                 }
1092                                 else {
1093                                         UniPolyVec newfactors1(part.size_first(), R), newfactors2(part.size_second(), R);
1094                                         UniPolyVec::iterator i1 = newfactors1.begin(), i2 = newfactors2.begin();
1095                                         for ( size_t i=0; i<n; ++i ) {
1096                                                 if ( part[i] ) {
1097                                                         *i2++ = tocheck.top().factors[i];
1098                                                 }
1099                                                 else {
1100                                                         *i1++ = tocheck.top().factors[i];
1101                                                 }
1102                                         }
1103                                         tocheck.top().factors = newfactors1;
1104                                         tocheck.top().poly = answer.op(0);
1105                                         ModFactors mf;
1106                                         mf.factors = newfactors2;
1107                                         mf.poly = answer.op(1);
1108                                         tocheck.push(mf);
1109                                 }
1110                         }
1111                         else {
1112                                 if ( !part.next() ) {
1113                                         result *= tocheck.top().poly;
1114                                         tocheck.pop();
1115                                         break;
1116                                 }
1117                         }
1118                 }
1119         }
1120
1121         return unit * cont * result;
1122 }
1123
1124 struct FindSymbolsMap : public map_function {
1125         exset syms;
1126         ex operator()(const ex& e)
1127         {
1128                 if ( is_a<symbol>(e) ) {
1129                         syms.insert(e);
1130                         return e;
1131                 }
1132                 return e.map(*this);
1133         }
1134 };
1135
1136 static ex factor_sqrfree(const ex& poly)
1137 {
1138         // determine all symbols in poly
1139         FindSymbolsMap findsymbols;
1140         findsymbols(poly);
1141         if ( findsymbols.syms.size() == 0 ) {
1142                 return poly;
1143         }
1144
1145         if ( findsymbols.syms.size() == 1 ) {
1146                 const ex& x = *(findsymbols.syms.begin());
1147                 if ( poly.ldegree(x) > 0 ) {
1148                         int ld = poly.ldegree(x);
1149                         ex res = factor_univariate(expand(poly/pow(x, ld)), x);
1150                         return res * pow(x,ld);
1151                 }
1152                 else {
1153                         ex res = factor_univariate(poly, x);
1154                         return res;
1155                 }
1156         }
1157
1158         // multivariate case not yet implemented!
1159         throw runtime_error("multivariate case not yet implemented!");
1160 }
1161
1162 } // anonymous namespace
1163
1164 ex factor(const ex& poly)
1165 {
1166         // determine all symbols in poly
1167         FindSymbolsMap findsymbols;
1168         findsymbols(poly);
1169         if ( findsymbols.syms.size() == 0 ) {
1170                 return poly;
1171         }
1172         lst syms;
1173         exset::const_iterator i=findsymbols.syms.begin(), end=findsymbols.syms.end();
1174         for ( ; i!=end; ++i ) {
1175                 syms.append(*i);
1176         }
1177
1178         // make poly square free
1179         ex sfpoly = sqrfree(poly, syms);
1180
1181         // factorize the square free components
1182         if ( is_a<power>(sfpoly) ) {
1183                 // case: (polynomial)^exponent
1184                 const ex& base = sfpoly.op(0);
1185                 if ( !is_a<add>(base) ) {
1186                         // simple case: (monomial)^exponent
1187                         return sfpoly;
1188                 }
1189                 ex f = factor_sqrfree(base);
1190                 return pow(f, sfpoly.op(1));
1191         }
1192         if ( is_a<mul>(sfpoly) ) {
1193                 ex res = 1;
1194                 for ( size_t i=0; i<sfpoly.nops(); ++i ) {
1195                         const ex& t = sfpoly.op(i);
1196                         if ( is_a<power>(t) ) {
1197                                 const ex& base = t.op(0);
1198                                 if ( !is_a<add>(base) ) {
1199                                         res *= t;
1200                                 }
1201                                 else {
1202                                         ex f = factor_sqrfree(base);
1203                                         res *= pow(f, t.op(1));
1204                                 }
1205                         }
1206                         else if ( is_a<add>(t) ) {
1207                                 ex f = factor_sqrfree(t);
1208                                 res *= f;
1209                         }
1210                         else {
1211                                 res *= t;
1212                         }
1213                 }
1214                 return res;
1215         }
1216         // case: (polynomial)
1217         ex f = factor_sqrfree(sfpoly);
1218         return f;
1219 }
1220
1221 } // namespace GiNaC