Changed debugging facilities in factor.cpp.
[ginac.git] / ginac / factor.cpp
1 /** @file factor.cpp
2  *
3  *  Polynomial factorization code (implementation).
4  *
5  *  Algorithms used can be found in
6  *    [W1]  An Improved Multivariate Polynomial Factoring Algorithm,
7  *          P.S.Wang, Mathematics of Computation, Vol. 32, No. 144 (1978) 1215--1231.
8  *    [GCL] Algorithms for Computer Algebra,
9  *          K.O.Geddes, S.R.Czapor, G.Labahn, Springer Verlag, 1992.
10  */
11
12 /*
13  *  GiNaC Copyright (C) 1999-2008 Johannes Gutenberg University Mainz, Germany
14  *
15  *  This program is free software; you can redistribute it and/or modify
16  *  it under the terms of the GNU General Public License as published by
17  *  the Free Software Foundation; either version 2 of the License, or
18  *  (at your option) any later version.
19  *
20  *  This program is distributed in the hope that it will be useful,
21  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
22  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
23  *  GNU General Public License for more details.
24  *
25  *  You should have received a copy of the GNU General Public License
26  *  along with this program; if not, write to the Free Software
27  *  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
28  */
29
30 //#define DEBUGFACTOR
31
32 #include "factor.h"
33
34 #include "ex.h"
35 #include "numeric.h"
36 #include "operators.h"
37 #include "inifcns.h"
38 #include "symbol.h"
39 #include "relational.h"
40 #include "power.h"
41 #include "mul.h"
42 #include "normal.h"
43 #include "add.h"
44
45 #include <algorithm>
46 #include <cmath>
47 #include <limits>
48 #include <list>
49 #include <vector>
50 #ifdef DEBUGFACTOR
51 #include <ostream>
52 #endif
53 using namespace std;
54
55 #include <cln/cln.h>
56 using namespace cln;
57
58 namespace GiNaC {
59
#ifdef DEBUGFACTOR
// Debug tracing helpers: print a stringified token, a variable together with
// its name, or a label together with a value, each followed by a newline.
#define DCOUT(str) std::cout << #str << std::endl
#define DCOUTVAR(var) std::cout << #var << ": " << var << std::endl
#define DCOUT2(str,var) std::cout << #str << ": " << var << std::endl
#else
// Without DEBUGFACTOR the tracing statements expand to nothing.
#define DCOUT(str)
#define DCOUTVAR(var)
#define DCOUT2(str,var)
#endif
69
70 // anonymous namespace to hide all utility functions
71 namespace {
72
73 typedef vector<cl_MI> mvec;
74 #ifdef DEBUGFACTOR
75 ostream& operator<<(ostream& o, const vector<cl_MI>& v)
76 {
77         vector<cl_MI>::const_iterator i = v.begin(), end = v.end();
78         while ( i != end ) {
79                 o << *i++ << " ";
80         }
81         return o;
82 }
83 ostream& operator<<(ostream& o, const vector< vector<cl_MI> >& v)
84 {
85         vector< vector<cl_MI> >::const_iterator i = v.begin(), end = v.end();
86         while ( i != end ) {
87                 o << *i++ << endl;
88         }
89         return o;
90 }
91 #endif
92
93 ////////////////////////////////////////////////////////////////////////////////
94 // modular univariate polynomial code
95
96 //typedef cl_UP_MI umod;
97 typedef std::vector<cln::cl_MI> umodpoly;
98 //typedef vector<umod> umodvec;
99 typedef vector<umodpoly> upvec;
100
101 // COPY FROM UPOLY.HPP
102
// Changed from upoly.hpp: degree() returns int instead of size_t, so that the empty (zero) polynomial gets degree -1.
/** Degree of a dense coefficient vector (index i holds the coefficient of x^i).
 *  The empty vector, i.e. the zero polynomial, has degree -1.
 */
template<typename T> static int degree(const T& p)
{
	// Convert before subtracting: size()-1 on an empty container wraps around
	// to the maximum size_t, and narrowing that to int is implementation-defined.
	return static_cast<int>(p.size()) - 1;
}
108
/** Leading coefficient: the last entry of the coefficient vector.
 *  p must not be empty.
 */
template<typename T> static typename T::value_type lcoeff(const T& p)
{
	const typename T::size_type last = p.size() - 1;
	return p[last];
}
113
114 static bool normalize_in_field(umodpoly& a)
115 {
116         if (a.size() == 0)
117                 return true;
118         if ( lcoeff(a) == a[0].ring()->one() ) {
119                 return true;
120         }
121
122         const cln::cl_MI lc_1 = recip(lcoeff(a));
123         for (std::size_t k = a.size(); k-- != 0; )
124                 a[k] = a[k]*lc_1;
125         return false;
126 }
127
/** Removes zero coefficients at the top of p; an all-zero p becomes empty.
 *
 *  @param[in,out] p     polynomial to canonicalize
 *  @param[in]     hint  only used if the top coefficient is zero: if hint is
 *                       smaller than p.size(), the downward search for the
 *                       leading nonzero coefficient starts at index hint and
 *                       the entries above it are dropped without inspection
 */
template<typename T> static void
canonicalize(T& p, const typename T::size_type hint = std::numeric_limits<typename T::size_type>::max())
{
	if (p.empty())
		return;

	const std::size_t last = p.size() - 1;
	// Fast path: the leading coefficient is nonzero already.
	if (!zerop(p[last]))
		return;

	std::size_t pos = (hint < p.size()) ? hint : last;
	std::size_t keep = 0;  // number of coefficients that survive
	while (true) {
		if (!zerop(p[pos])) {
			keep = pos + 1;
			break;
		}
		if (pos == 0)
			break;
		--pos;
	}
	p.erase(p.begin() + keep, p.end());
}
162
163 // END COPY FROM UPOLY.HPP
164
165 static void expt_pos(const umodpoly& a, unsigned int q, umodpoly& b)
166 {
167         throw runtime_error("expt_pos: not implemented!");
168         // code below is not correct!
169 //      b.clear();
170 //      if ( a.empty() ) return;
171 //      b.resize(degree(a)*q+1, a[0].ring()->zero());
172 //      cl_MI norm = recip(a[0]);
173 //      umodpoly an = a;
174 //      for ( size_t i=0; i<an.size(); ++i ) {
175 //              an[i] = an[i] * norm;
176 //      }
177 //      b[0] = a[0].ring()->one();
178 //      for ( size_t i=1; i<b.size(); ++i ) {
179 //              for ( size_t j=1; j<i; ++j ) {
180 //                      b[i] = b[i] + ((i-j+1)*q-i-1) * a[i-j] * b[j-1];
181 //              }
182 //              b[i] = b[i] / i;
183 //      }
184 //      cl_MI corr = expt_pos(a[0], q);
185 //      for ( size_t i=0; i<b.size(); ++i ) {
186 //              b[i] = b[i] * corr;
187 //      }
188 }
189
190 static umodpoly operator+(const umodpoly& a, const umodpoly& b)
191 {
192         int sa = a.size();
193         int sb = b.size();
194         if ( sa >= sb ) {
195                 umodpoly r(sa);
196                 int i = 0;
197                 for ( ; i<sb; ++i ) {
198                         r[i] = a[i] + b[i];
199                 }
200                 for ( ; i<sa; ++i ) {
201                         r[i] = a[i];
202                 }
203                 canonicalize(r);
204                 return r;
205         }
206         else {
207                 umodpoly r(sb);
208                 int i = 0;
209                 for ( ; i<sa; ++i ) {
210                         r[i] = a[i] + b[i];
211                 }
212                 for ( ; i<sb; ++i ) {
213                         r[i] = b[i];
214                 }
215                 canonicalize(r);
216                 return r;
217         }
218 }
219
220 static umodpoly operator-(const umodpoly& a, const umodpoly& b)
221 {
222         int sa = a.size();
223         int sb = b.size();
224         if ( sa >= sb ) {
225                 umodpoly r(sa);
226                 int i = 0;
227                 for ( ; i<sb; ++i ) {
228                         r[i] = a[i] - b[i];
229                 }
230                 for ( ; i<sa; ++i ) {
231                         r[i] = a[i];
232                 }
233                 canonicalize(r);
234                 return r;
235         }
236         else {
237                 umodpoly r(sb);
238                 int i = 0;
239                 for ( ; i<sa; ++i ) {
240                         r[i] = a[i] - b[i];
241                 }
242                 for ( ; i<sb; ++i ) {
243                         r[i] = -b[i];
244                 }
245                 canonicalize(r);
246                 return r;
247         }
248 }
249
250 static umodpoly operator*(const umodpoly& a, const umodpoly& b)
251 {
252         umodpoly c;
253         if ( a.empty() || b.empty() ) return c;
254
255         int n = degree(a) + degree(b);
256         c.resize(n+1, a[0].ring()->zero());
257         for ( int i=0 ; i<=n; ++i ) {
258                 for ( int j=0 ; j<=i; ++j ) {
259                         if ( j > degree(a) || (i-j) > degree(b) ) continue; // TODO optimize!
260                         c[i] = c[i] + a[j] * b[i-j];
261                 }
262         }
263         canonicalize(c);
264         return c;
265 }
266
267 static umodpoly operator*(const umodpoly& a, const cl_MI& x)
268 {
269         umodpoly r(a.size());
270         for ( size_t i=0; i<a.size(); ++i ) {
271                 r[i] = a[i] * x;
272         }
273         canonicalize(r);
274         return r;
275 }
276
277 static void umodpoly_from_ex(umodpoly& ump, const ex& e, const ex& x, const cl_modint_ring& R)
278 {
279         // assert: e is in Z[x]
280         int deg = e.degree(x);
281         ump.resize(deg+1);
282         int ldeg = e.ldegree(x);
283         for ( ; deg>=ldeg; --deg ) {
284                 cl_I coeff = the<cl_I>(ex_to<numeric>(e.coeff(x, deg)).to_cl_N());
285                 ump[deg] = R->canonhom(coeff);
286         }
287         for ( ; deg>=0; --deg ) {
288                 ump[deg] = R->zero();
289         }
290         canonicalize(ump);
291 }
292
293 static void umodpoly_from_ex(umodpoly& ump, const ex& e, const ex& x, const cl_I& modulus)
294 {
295         umodpoly_from_ex(ump, e, x, find_modint_ring(modulus));
296 }
297
298 static ex umod_to_ex(const umodpoly& a, const ex& x)
299 {
300         if ( a.empty() ) return 0;
301         cl_modint_ring R = a[0].ring();
302         cl_I mod = R->modulus;
303         cl_I halfmod = (mod-1) >> 1;
304         ex e;
305         for ( int i=degree(a); i>=0; --i ) {
306                 cl_I n = R->retract(a[i]);
307                 if ( n > halfmod ) {
308                         e += numeric(n-mod) * pow(x, i);
309                 } else {
310                         e += numeric(n) * pow(x, i);
311                 }
312         }
313         return e;
314 }
315
316 /** Divides all coefficients of the polynomial a by the integer x.
317  *  All coefficients are supposed to be divisible by x. If they are not, the
318  *  the<cl_I> cast will raise an exception.
319  *
320  *  @param[in,out] a  polynomial of which the coefficients will be reduced by x
321  *  @param[in]     x  integer that divides the coefficients
322  */
323 static void reduce_coeff(umodpoly& a, const cl_I& x)
324 {
325         if ( a.empty() ) return;
326
327         cl_modint_ring R = a[0].ring();
328         umodpoly::iterator i = a.begin(), end = a.end();
329         for ( ; i!=end; ++i ) {
330                 // cln cannot perform this division in the modular field
331                 cl_I c = R->retract(*i);
332                 *i = cl_MI(R, the<cl_I>(c / x));
333         }
334 }
335
336 /** Calculates remainder of a/b.
337  *  Assertion: a and b not empty.
338  *
339  *  @param[in]  a  polynomial dividend
340  *  @param[in]  b  polynomial divisor
341  *  @param[out] r  polynomial remainder
342  */
343 static void rem(const umodpoly& a, const umodpoly& b, umodpoly& r)
344 {
345         int k, n;
346         n = degree(b);
347         k = degree(a) - n;
348         r = a;
349         if ( k < 0 ) return;
350
351         do {
352                 cl_MI qk = div(r[n+k], b[n]);
353                 if ( !zerop(qk) ) {
354                         for ( int i=0; i<n; ++i ) {
355                                 unsigned int j = n + k - 1 - i;
356                                 r[j] = r[j] - qk * b[j-k];
357                         }
358                 }
359         } while ( k-- );
360
361         fill(r.begin()+n, r.end(), a[0].ring()->zero());
362         canonicalize(r);
363 }
364
365 /** Calculates quotient of a/b.
366  *  Assertion: a and b not empty.
367  *
368  *  @param[in]  a  polynomial dividend
369  *  @param[in]  b  polynomial divisor
370  *  @param[out] q  polynomial quotient
371  */
372 static void div(const umodpoly& a, const umodpoly& b, umodpoly& q)
373 {
374         int k, n;
375         n = degree(b);
376         k = degree(a) - n;
377         q.clear();
378         if ( k < 0 ) return;
379
380         umodpoly r = a;
381         q.resize(k+1, a[0].ring()->zero());
382         do {
383                 cl_MI qk = div(r[n+k], b[n]);
384                 if ( !zerop(qk) ) {
385                         q[k] = qk;
386                         for ( int i=0; i<n; ++i ) {
387                                 unsigned int j = n + k - 1 - i;
388                                 r[j] = r[j] - qk * b[j-k];
389                         }
390                 }
391         } while ( k-- );
392
393         canonicalize(q);
394 }
395
396 /** Calculates quotient and remainder of a/b.
397  *  Assertion: a and b not empty.
398  *
399  *  @param[in]  a  polynomial dividend
400  *  @param[in]  b  polynomial divisor
401  *  @param[out] r  polynomial remainder
402  *  @param[out] q  polynomial quotient
403  */
404 static void remdiv(const umodpoly& a, const umodpoly& b, umodpoly& r, umodpoly& q)
405 {
406         int k, n;
407         n = degree(b);
408         k = degree(a) - n;
409         q.clear();
410         r = a;
411         if ( k < 0 ) return;
412
413         q.resize(k+1, a[0].ring()->zero());
414         do {
415                 cl_MI qk = div(r[n+k], b[n]);
416                 if ( !zerop(qk) ) {
417                         q[k] = qk;
418                         for ( int i=0; i<n; ++i ) {
419                                 unsigned int j = n + k - 1 - i;
420                                 r[j] = r[j] - qk * b[j-k];
421                         }
422                 }
423         } while ( k-- );
424
425         fill(r.begin()+n, r.end(), a[0].ring()->zero());
426         canonicalize(r);
427         canonicalize(q);
428 }
429
430 /** Calculates the GCD of polynomial a and b.
431  *
432  *  @param[in]  a  polynomial
433  *  @param[in]  b  polynomial
434  *  @param[out] c  GCD
435  */
436 static void gcd(const umodpoly& a, const umodpoly& b, umodpoly& c)
437 {
438         if ( degree(a) < degree(b) ) return gcd(b, a, c);
439
440         c = a;
441         normalize_in_field(c);
442         umodpoly d = b;
443         normalize_in_field(d);
444         umodpoly r;
445         while ( !d.empty() ) {
446                 rem(c, d, r);
447                 c = d;
448                 d = r;
449         }
450         normalize_in_field(c);
451 }
452
453 /** Calculates the derivative of the polynomial a.
454  *  
455  *  @param[in]  a  polynomial of which to take the derivative
456  *  @param[out] d  result/derivative
457  */
458 static void deriv(const umodpoly& a, umodpoly& d)
459 {
460         d.clear();
461         if ( a.size() <= 1 ) return;
462
463         d.insert(d.begin(), a.begin()+1, a.end());
464         int max = d.size();
465         for ( int i=1; i<max; ++i ) {
466                 d[i] = d[i] * (i+1);
467         }
468         canonicalize(d);
469 }
470
471 static bool unequal_one(const umodpoly& a)
472 {
473         if ( a.empty() ) return true;
474         return ( a.size() != 1 || a[0] != a[0].ring()->one() );
475 }
476
477 static bool equal_one(const umodpoly& a)
478 {
479         return ( a.size() == 1 && a[0] == a[0].ring()->one() );
480 }
481
482 /** Returns true if polynomial a is square free.
483  *
484  *  @param[in] a  polynomial to check
485  *  @return       true if polynomial is square free, false otherwise
486  */
487 static bool squarefree(const umodpoly& a)
488 {
489         umodpoly b;
490         deriv(a, b);
491         if ( b.empty() ) {
492                 return true;
493         }
494         umodpoly c;
495         gcd(a, b, c);
496         return equal_one(c);
497 }
498
499 // END modular univariate polynomial code
500 ////////////////////////////////////////////////////////////////////////////////
501
502 ////////////////////////////////////////////////////////////////////////////////
503 // modular matrix
504
505 class modular_matrix
506 {
507         friend ostream& operator<<(ostream& o, const modular_matrix& m);
508 public:
509         modular_matrix(size_t r_, size_t c_, const cl_MI& init) : r(r_), c(c_)
510         {
511                 m.resize(c*r, init);
512         }
513         size_t rowsize() const { return r; }
514         size_t colsize() const { return c; }
515         cl_MI& operator()(size_t row, size_t col) { return m[row*c + col]; }
516         cl_MI operator()(size_t row, size_t col) const { return m[row*c + col]; }
517         void mul_col(size_t col, const cl_MI x)
518         {
519                 mvec::iterator i = m.begin() + col;
520                 for ( size_t rc=0; rc<r; ++rc ) {
521                         *i = *i * x;
522                         i += c;
523                 }
524         }
525         void sub_col(size_t col1, size_t col2, const cl_MI fac)
526         {
527                 mvec::iterator i1 = m.begin() + col1;
528                 mvec::iterator i2 = m.begin() + col2;
529                 for ( size_t rc=0; rc<r; ++rc ) {
530                         *i1 = *i1 - *i2 * fac;
531                         i1 += c;
532                         i2 += c;
533                 }
534         }
535         void switch_col(size_t col1, size_t col2)
536         {
537                 cl_MI buf;
538                 mvec::iterator i1 = m.begin() + col1;
539                 mvec::iterator i2 = m.begin() + col2;
540                 for ( size_t rc=0; rc<r; ++rc ) {
541                         buf = *i1; *i1 = *i2; *i2 = buf;
542                         i1 += c;
543                         i2 += c;
544                 }
545         }
546         void mul_row(size_t row, const cl_MI x)
547         {
548                 vector<cl_MI>::iterator i = m.begin() + row*c;
549                 for ( size_t cc=0; cc<c; ++cc ) {
550                         *i = *i * x;
551                         ++i;
552                 }
553         }
554         void sub_row(size_t row1, size_t row2, const cl_MI fac)
555         {
556                 vector<cl_MI>::iterator i1 = m.begin() + row1*c;
557                 vector<cl_MI>::iterator i2 = m.begin() + row2*c;
558                 for ( size_t cc=0; cc<c; ++cc ) {
559                         *i1 = *i1 - *i2 * fac;
560                         ++i1;
561                         ++i2;
562                 }
563         }
564         void switch_row(size_t row1, size_t row2)
565         {
566                 cl_MI buf;
567                 vector<cl_MI>::iterator i1 = m.begin() + row1*c;
568                 vector<cl_MI>::iterator i2 = m.begin() + row2*c;
569                 for ( size_t cc=0; cc<c; ++cc ) {
570                         buf = *i1; *i1 = *i2; *i2 = buf;
571                         ++i1;
572                         ++i2;
573                 }
574         }
575         bool is_col_zero(size_t col) const
576         {
577                 mvec::const_iterator i = m.begin() + col;
578                 for ( size_t rr=0; rr<r; ++rr ) {
579                         if ( !zerop(*i) ) {
580                                 return false;
581                         }
582                         i += c;
583                 }
584                 return true;
585         }
586         bool is_row_zero(size_t row) const
587         {
588                 mvec::const_iterator i = m.begin() + row*c;
589                 for ( size_t cc=0; cc<c; ++cc ) {
590                         if ( !zerop(*i) ) {
591                                 return false;
592                         }
593                         ++i;
594                 }
595                 return true;
596         }
597         void set_row(size_t row, const vector<cl_MI>& newrow)
598         {
599                 mvec::iterator i1 = m.begin() + row*c;
600                 mvec::const_iterator i2 = newrow.begin(), end = newrow.end();
601                 for ( ; i2 != end; ++i1, ++i2 ) {
602                         *i1 = *i2;
603                 }
604         }
605         mvec::const_iterator row_begin(size_t row) const { return m.begin()+row*c; }
606         mvec::const_iterator row_end(size_t row) const { return m.begin()+row*c+r; }
607 private:
608         size_t r, c;
609         mvec m;
610 };
611
#ifdef DEBUGFACTOR
/** Matrix product m1*m2 (debugging aid).
 *  @throws logic_error if m1.colsize() != m2.rowsize()
 */
modular_matrix operator*(const modular_matrix& m1, const modular_matrix& m2)
{
	if ( m1.colsize() != m2.rowsize() ) {
		throw logic_error("modular_matrix::operator*: dimension mismatch.");
	}
	const size_t r = m1.rowsize();
	const size_t c = m2.colsize();
	const size_t inner = m1.colsize();  // dimension summed over
	modular_matrix o(r,c,m1(0,0));

	for ( size_t i=0; i<r; ++i ) {
		for ( size_t j=0; j<c; ++j ) {
			cl_MI buf;
			buf = m1(i,0) * m2(0,j);
			for ( size_t k=1; k<inner; ++k ) {
				buf = buf + m1(i,k)*m2(k,j);
			}
			o(i,j) = buf;
		}
	}
	return o;
}

/** Debug output: the matrix entries, one matrix row per line. */
ostream& operator<<(ostream& o, const modular_matrix& m)
{
	vector<cl_MI>::const_iterator i = m.m.begin(), end = m.m.end();
	size_t wrap = 1;
	for ( ; i != end; ++i ) {
		o << *i << " ";
		if ( !(wrap++ % m.c) ) {
			o << endl;
		}
	}
	o << endl;
	return o;
}
#endif // def DEBUGFACTOR
646
647 // END modular matrix
648 ////////////////////////////////////////////////////////////////////////////////
649
650 static void q_matrix(const umodpoly& a, modular_matrix& Q)
651 {
652         int n = degree(a);
653         unsigned int q = cl_I_to_uint(a[0].ring()->modulus);
654 // fast and buggy
655 //      vector<cl_MI> r(n, a.R->zero());
656 //      r[0] = a.R->one();
657 //      Q.set_row(0, r);
658 //      unsigned int max = (n-1) * q;
659 //      for ( size_t m=1; m<=max; ++m ) {
660 //              cl_MI rn_1 = r.back();
661 //              for ( size_t i=n-1; i>0; --i ) {
662 //                      r[i] = r[i-1] - rn_1 * a[i];
663 //              }
664 //              r[0] = -rn_1 * a[0];
665 //              if ( (m % q) == 0 ) {
666 //                      Q.set_row(m/q, r);
667 //              }
668 //      }
669 // slow and (hopefully) correct
670         cl_MI one = a[0].ring()->one();
671         cl_MI zero = a[0].ring()->zero();
672         for ( int i=0; i<n; ++i ) {
673                 umodpoly qk(i*q+1, zero);
674                 qk[i*q] = one;
675                 umodpoly r;
676                 rem(qk, a, r);
677                 mvec rvec(n, zero);
678                 for ( int j=0; j<=degree(r); ++j ) {
679                         rvec[j] = r[j];
680                 }
681                 Q.set_row(i, rvec);
682         }
683 }
684
685 static void nullspace(modular_matrix& M, vector<mvec>& basis)
686 {
687         const size_t n = M.rowsize();
688         const cl_MI one = M(0,0).ring()->one();
689         for ( size_t i=0; i<n; ++i ) {
690                 M(i,i) = M(i,i) - one;
691         }
692         for ( size_t r=0; r<n; ++r ) {
693                 size_t cc = 0;
694                 for ( ; cc<n; ++cc ) {
695                         if ( !zerop(M(r,cc)) ) {
696                                 if ( cc < r ) {
697                                         if ( !zerop(M(cc,cc)) ) {
698                                                 continue;
699                                         }
700                                         M.switch_col(cc, r);
701                                 }
702                                 else if ( cc > r ) {
703                                         M.switch_col(cc, r);
704                                 }
705                                 break;
706                         }
707                 }
708                 if ( cc < n ) {
709                         M.mul_col(r, recip(M(r,r)));
710                         for ( cc=0; cc<n; ++cc ) {
711                                 if ( cc != r ) {
712                                         M.sub_col(cc, r, M(r,cc));
713                                 }
714                         }
715                 }
716         }
717
718         for ( size_t i=0; i<n; ++i ) {
719                 M(i,i) = M(i,i) - one;
720         }
721         for ( size_t i=0; i<n; ++i ) {
722                 if ( !M.is_row_zero(i) ) {
723                         mvec nu(M.row_begin(i), M.row_end(i));
724                         basis.push_back(nu);
725                 }
726         }
727 }
728
729 static void berlekamp(const umodpoly& a, upvec& upv)
730 {
731         cl_modint_ring R = a[0].ring();
732         umodpoly one(1, R->one());
733
734         modular_matrix Q(degree(a), degree(a), R->zero());
735         q_matrix(a, Q);
736         vector<mvec> nu;
737         nullspace(Q, nu);
738         const unsigned int k = nu.size();
739         if ( k == 1 ) {
740                 return;
741         }
742
743         list<umodpoly> factors;
744         factors.push_back(a);
745         unsigned int size = 1;
746         unsigned int r = 1;
747         unsigned int q = cl_I_to_uint(R->modulus);
748
749         list<umodpoly>::iterator u = factors.begin();
750
751         while ( true ) {
752                 for ( unsigned int s=0; s<q; ++s ) {
753                         umodpoly nur = nu[r];
754                         nur[0] = nur[0] - cl_MI(R, s);
755                         canonicalize(nur);
756                         umodpoly g;
757                         gcd(nur, *u, g);
758                         if ( unequal_one(g) && g != *u ) {
759                                 umodpoly uo;
760                                 div(*u, g, uo);
761                                 if ( equal_one(uo) ) {
762                                         throw logic_error("berlekamp: unexpected divisor.");
763                                 }
764                                 else {
765                                         *u = uo;
766                                 }
767                                 factors.push_back(g);
768                                 size = 0;
769                                 list<umodpoly>::const_iterator i = factors.begin(), end = factors.end();
770                                 while ( i != end ) {
771                                         if ( degree(*i) ) ++size; 
772                                         ++i;
773                                 }
774                                 if ( size == k ) {
775                                         list<umodpoly>::const_iterator i = factors.begin(), end = factors.end();
776                                         while ( i != end ) {
777                                                 upv.push_back(*i++);
778                                         }
779                                         return;
780                                 }
781                         }
782                 }
783                 if ( ++r == k ) {
784                         r = 1;
785                         ++u;
786                 }
787         }
788 }
789
790 static void rem_xq(int q, const umodpoly& b, umodpoly& c)
791 {
792         cl_modint_ring R = b[0].ring();
793
794         int n = degree(b);
795         if ( n > q ) {
796                 c.resize(q+1, R->zero());
797                 c[q] = R->one();
798                 return;
799         }
800
801         c.clear();
802         c.resize(n+1, R->zero());
803         int k = q-n;
804         c[n] = R->one();
805
806         int ofs = 0;
807         do {
808                 cl_MI qk = div(c[n-ofs], b[n]);
809                 if ( !zerop(qk) ) {
810                         for ( int i=1; i<=n; ++i ) {
811                                 c[n-i+ofs] = c[n-i] - qk * b[n-i];
812                         }
813                         ofs = ofs ? 0 : 1;
814                 }
815         } while ( k-- );
816
817         if ( ofs ) {
818                 c.pop_back();
819         }
820         else {
821                 c.erase(c.begin());
822         }
823         canonicalize(c);
824 }
825
826 static void distinct_degree_factor(const umodpoly& a_, upvec& result)
827 {
828         umodpoly a = a_;
829
830         cl_modint_ring R = a[0].ring();
831         int q = cl_I_to_int(R->modulus);
832         int n = degree(a);
833         size_t nhalf = n/2;
834
835         size_t i = 1;
836         umodpoly w(1, R->one());
837         umodpoly x = w;
838
839         upvec ai;
840
841         while ( i <= nhalf ) {
842                 expt_pos(w, q, w);
843                 rem(w, a, w);
844
845                 umodpoly buf;
846                 gcd(a, w-x, buf);
847                 ai.push_back(buf);
848
849                 if ( unequal_one(ai.back()) ) {
850                         div(a, ai.back(), a);
851                         rem(w, a, w);
852                 }
853
854                 ++i;
855         }
856
857         result = ai;
858 }
859
860 static void same_degree_factor(const umodpoly& a, upvec& result)
861 {
862         cl_modint_ring R = a[0].ring();
863         int deg = degree(a);
864
865         upvec buf;
866         distinct_degree_factor(a, buf);
867         int degsum = 0;
868
869         for ( size_t i=0; i<buf.size(); ++i ) {
870                 if ( unequal_one(buf[i]) ) {
871                         degsum += degree(buf[i]);
872                         upvec upv;
873                         berlekamp(buf[i], upv);
874                         for ( size_t j=0; j<upv.size(); ++j ) {
875                                 result.push_back(upv[j]);
876                         }
877                 }
878         }
879
880         if ( degsum < deg ) {
881                 result.push_back(a);
882         }
883 }
884
885 static void distinct_degree_factor_BSGS(const umodpoly& a, upvec& result)
886 {
887         cl_modint_ring R = a[0].ring();
888         int q = cl_I_to_int(R->modulus);
889         int n = degree(a);
890
891         cl_N pm = 0.3;
892         int l = cl_I_to_int(ceiling1(the<cl_F>(expt(n, pm))));
893         upvec h(l+1);
894         umodpoly qk(1, R->one());
895         h[0] = qk;
896         for ( int i=1; i<=l; ++i ) {
897                 expt_pos(h[i-1], q, qk);
898                 rem(qk, a, h[i]);
899         }
900
901         int m = std::ceil(((double)n)/2/l);
902         upvec H(m);
903         int ql = std::pow(q, l);
904         H[0] = h[l];
905         for ( int i=1; i<m; ++i ) {
906                 expt_pos(H[i-1], ql, qk);
907                 rem(qk, a, H[i]);
908         }
909
910         upvec I(m);
911         umodpoly one(1, R->one());
912         for ( int i=0; i<m; ++i ) {
913                 I[i] = one;
914                 for ( int j=0; j<l; ++j ) {
915                         I[i] = I[i] * (H[i] - h[j]);
916                 }
917                 rem(I[i], a, I[i]);
918         }
919
920         upvec F(m, one);
921         umodpoly f = a;
922         for ( int i=0; i<m; ++i ) {
923                 umodpoly g;
924                 gcd(f, I[i], g); 
925                 if ( g == one ) continue;
926                 F[i] = g;
927                 div(f, g, f);
928         }
929
930         result.resize(n, one);
931         if ( unequal_one(f) ) {
932                 result[n] = f;
933         }
934         for ( int i=0; i<m; ++i ) {
935                 umodpoly f = F[i];
936                 for ( int j=l-1; j>=0; --j ) {
937                         umodpoly g;
938                         gcd(f, H[i]-h[j], g);
939                         result[l*(i+1)-j-1] = g;
940                         div(f, g, f);
941                 }
942         }
943 }
944
945 static void cantor_zassenhaus(const umodpoly& a, upvec& result)
946 {
947 }
948
949 static void factor_modular(const umodpoly& p, upvec& upv)
950 {
951         //same_degree_factor(p, upv);
952         berlekamp(p, upv);
953         return;
954 }
955
/** Extended Euclidean algorithm for univariate modular polynomials.
 *
 *  Computes g, the gcd of @a a and @a b normalized with normalize_in_field(),
 *  together with cofactors s (for a) and t (for b). When deg(a) < deg(b) the
 *  call is redone with the arguments swapped (s and t swapped accordingly),
 *  so the main loop always starts with the higher-degree polynomial.
 *
 *  @param a  first polynomial; must be non-empty (a[0] supplies the ring)
 *  @param b  second polynomial
 *  @param g  output: normalized gcd
 *  @param s  output: cofactor of a
 *  @param t  output: cofactor of b
 *
 *  NOTE(review): s and t are first divided by lc(a)*lc(c) resp. lc(b)*lc(c).
 *  If normalize_in_field() makes polynomials monic, this gives
 *  s*a + t*b == g. They are then multiplied by g once more, which only
 *  preserves that identity when g == 1. The callers seen here
 *  (hensel_univar, eea_lift) throw unless g == 1. Confirm this extra factor
 *  is intended before using s, t for a non-trivial gcd.
 */
static void exteuclid(const umodpoly& a, const umodpoly& b, umodpoly& g, umodpoly& s, umodpoly& t)
{
        // make deg(a) >= deg(b); swapping a and b also swaps the roles of s and t
        if ( degree(a) < degree(b) ) {
                exteuclid(b, a, g, t, s);
                return;
        }
        umodpoly one(1, a[0].ring()->one());
        umodpoly c = a; normalize_in_field(c);
        umodpoly d = b; normalize_in_field(d);
        // Loop invariants: c == c1*cn + c2*dn and d == d1*cn + d2*dn, where
        // cn, dn are the normalized inputs (an empty umodpoly apparently
        // stands for the zero polynomial).
        umodpoly c1 = one;
        umodpoly c2;
        umodpoly d1;
        umodpoly d2 = one;
        while ( !d.empty() ) {
                umodpoly q;
                div(c, d, q);              // q: quotient of c by d
                umodpoly r = c - q * d;    // remainder
                umodpoly r1 = c1 - q * d1;
                umodpoly r2 = c2 - q * d2;
                c = d;
                c1 = d1;
                c2 = d2;
                d = r;
                d1 = r1;
                d2 = r2;
        }
        g = c; normalize_in_field(g);
        // undo the normalization of a and of the final remainder c on the cofactors
        s = c1;
        for ( int i=0; i<=degree(s); ++i ) {
                s[i] = s[i] * recip(a[degree(a)] * c[degree(c)]);
        }
        canonicalize(s);
        s = s * g;
        t = c2;
        for ( int i=0; i<=degree(t); ++i ) {
                t[i] = t[i] * recip(b[degree(b)] * c[degree(c)]);
        }
        canonicalize(t);
        t = t * g;
}
996
997 static ex replace_lc(const ex& poly, const ex& x, const ex& lc)
998 {
999         ex r = expand(poly + (lc - poly.lcoeff(x)) * pow(x, poly.degree(x)));
1000         return r;
1001 }
1002
/** Univariate Hensel lifting of a two-factor modular factorization.
 *
 *  Lifts modular factors @a u1_, @a w1_ (mod @a p) of the integer polynomial
 *  @a a_ to integer polynomials u, w. The step numbering presumably follows
 *  the univariate Hensel lifting algorithm in [GCL].
 *  - Each iteration raises the modulus by one factor of p.
 *  - It stops when the error a*gamma - u*w vanishes or the modulus reaches
 *    2*B*|gamma|, where B = ceil(||a||_2) * 2^max(deg u1_, deg w1_).
 *
 *  @param a_      integer polynomial in x
 *  @param x       the variable
 *  @param p       modulus of u1_ and w1_ (presumably prime)
 *  @param u1_     first modular factor; must be non-empty (supplies the ring)
 *  @param w1_     second modular factor
 *  @param gamma_  leading coefficient to impose on u; 0 means lcoeff(a_)
 *  @return        lst(u, w) after removing the content of u and the extra
 *                 gamma factor from w, or an empty lst() if the error did not
 *                 become zero within the bound.
 *  @throws logic_error if u1 and w1 are not coprime mod p.
 */
static ex hensel_univar(const ex& a_, const ex& x, unsigned int p, const umodpoly& u1_, const umodpoly& w1_, const ex& gamma_ = 0)
{
        ex a = a_;
        const cl_modint_ring& R = u1_[0].ring();

        // calc bound B
        // maxcoeff = sum of squared absolute coefficients (squared 2-norm of a)
        ex maxcoeff;
        for ( int i=a.degree(x); i>=a.ldegree(x); --i ) {
                maxcoeff += pow(abs(a.coeff(x, i)),2);
        }
        cl_I normmc = ceiling1(the<cl_R>(cln::sqrt(ex_to<numeric>(maxcoeff).to_cl_N())));
        cl_I maxdegree = (degree(u1_) > degree(w1_)) ? degree(u1_) : degree(w1_);
        cl_I B = normmc * expt_pos(cl_I(2), maxdegree);

        // step 1
        // Impose leading coefficient gamma on u and alpha = lc(a) on w; a is
        // multiplied by gamma so that a == u*w can hold with these
        // leading coefficients.
        ex alpha = a.lcoeff(x);
        ex gamma = gamma_;
        if ( gamma == 0 ) {
                gamma = alpha;
        }
        numeric gamma_ui = ex_to<numeric>(abs(gamma));
        a = a * gamma;
        umodpoly nu1 = u1_;
        normalize_in_field(nu1);
        umodpoly nw1 = w1_;
        normalize_in_field(nw1);
        ex phi;
        phi = gamma * umod_to_ex(nu1, x);
        umodpoly u1;
        umodpoly_from_ex(u1, phi, x, R);
        phi = alpha * umod_to_ex(nw1, x);
        umodpoly w1;
        umodpoly_from_ex(w1, phi, x, R);

        // step 2
        // cofactors s, t for u1, w1 mod p
        umodpoly g;
        umodpoly s;
        umodpoly t;
        exteuclid(u1, w1, g, s, t);
        if ( unequal_one(g) ) {
                throw logic_error("gcd(u1,w1) != 1");
        }

        // step 3
        // integer images with the imposed leading coefficients, and the error e
        ex u = replace_lc(umod_to_ex(u1, x), x, gamma);
        ex w = replace_lc(umod_to_ex(w1, x), x, alpha);
        ex e = expand(a - u * w);
        numeric modulus = p;
        const numeric maxmodulus = 2*numeric(B)*gamma_ui;

        // step 4
        // linear lifting: each pass corrects u, w by multiples of the current
        // modulus, recomputes e and multiplies the modulus by p
        while ( !e.is_zero() && modulus < maxmodulus ) {
                ex c = e / modulus;
                phi = expand(umod_to_ex(s, x) * c);
                umodpoly sigmatilde;
                umodpoly_from_ex(sigmatilde, phi, x, R);
                phi = expand(umod_to_ex(t, x) * c);
                umodpoly tautilde;
                umodpoly_from_ex(tautilde, phi, x, R);
                umodpoly r, q;
                remdiv(sigmatilde, w1, r, q);
                umodpoly sigma = r;
                phi = expand(umod_to_ex(tautilde, x) + umod_to_ex(q, x) * umod_to_ex(u1, x));
                umodpoly tau;
                umodpoly_from_ex(tau, phi, x, R);
                u = expand(u + umod_to_ex(tau, x) * modulus);
                w = expand(w + umod_to_ex(sigma, x) * modulus);
                e = expand(a - u * w);
                modulus = modulus * p;
        }

        // step 5
        // success: strip the content of u and the extra gamma factor from w
        if ( e.is_zero() ) {
                ex delta = u.content(x);
                u = u / delta;
                w = w / gamma * delta;
                return lst(u, w);
        }
        else {
                return lst();
        }
}
1085
/** Returns the smallest odd prime strictly greater than @a p.
 *
 *  Primes are cached in a function-local static vector, seeded with 3, 5, 7
 *  and kept sorted. When @a p is at or beyond the largest cached prime, the
 *  cache is extended by testing odd candidates one after another. Each
 *  candidate is trial-divided by the cached primes up to its square root;
 *  the cache always holds every odd prime below the candidate, so this test
 *  is exact. Every prime found is appended, and each candidate is tested
 *  exactly once, so the loop always makes progress. Otherwise the answer
 *  is looked up in the cache.
 *
 *  The cache is shared, unsynchronized state, so this function is not
 *  thread-safe.
 *
 *  @throws logic_error only if the cache lookup unexpectedly fails.
 */
static unsigned int next_prime(unsigned int p)
{
        static vector<unsigned int> primes;
        if ( primes.empty() ) {
                primes.push_back(3); primes.push_back(5); primes.push_back(7);
        }
        if ( p >= primes.back() ) {
                unsigned int candidate = primes.back() + 2;
                while ( true ) {
                        bool isprime = true;
                        // primes[i] <= candidate/primes[i]  <=>  primes[i]^2 <= candidate,
                        // written this way to avoid overflow
                        for ( size_t i=0; i<primes.size() && primes[i] <= candidate/primes[i]; ++i ) {
                                if ( candidate % primes[i] == 0 ) {
                                        isprime = false;
                                        break;
                                }
                        }
                        if ( isprime ) {
                                primes.push_back(candidate);
                                if ( candidate > p ) {
                                        return candidate;
                                }
                        }
                        candidate += 2;
                }
        }
        // p is below the largest cached prime: binary search the sorted cache
        vector<unsigned int>::const_iterator it = upper_bound(primes.begin(), primes.end(), p);
        if ( it != primes.end() ) {
                return *it;
        }
        throw logic_error("next_prime: should not reach this point!");
}
1115
/** Enumerates two-way splittings of n indices, used to try combinations of
 *  modular factors.
 *
 *  Flag k[i] == 0 puts index i into the first group; k[i] == 1 puts it into
 *  the second. Index 0 always stays in the first group. Initially every
 *  other index is in the second group. Each call to next() counts the flags
 *  k[1..n-1] down like a binary number.
 */
class Partition
{
public:
        Partition(size_t n_) : n(n_), sum(n_-1), k(n_, 1)
        {
                k[0] = 0;
        }
        /// Group flag of index i (0: first group, 1: second group).
        int operator[](size_t i) const { return k[i]; }
        /// Total number of indices.
        size_t size() const { return n; }
        /// Number of indices currently in the first group.
        size_t size_first() const { return n-sum; }
        /// Number of indices currently in the second group.
        size_t size_second() const { return sum; }
#ifdef DEBUGFACTOR
        /// Prints all flags to cout.
        void get() const
        {
                for ( vector<int>::const_iterator it=k.begin(); it!=k.end(); ++it ) {
                        cout << *it << " ";
                }
                cout << endl;
        }
#endif
        /** Advances to the next splitting.
         *  @return false when the second group has become empty (no further
         *          useful splitting), true otherwise.
         */
        bool next()
        {
                for ( size_t pos=n; pos-- > 1; ) {
                        if ( k[pos] == 0 ) {
                                // borrow from the next lower position
                                k[pos] = 1;
                                ++sum;
                                continue;
                        }
                        k[pos] = 0;
                        --sum;
                        return sum != 0;
                }
                return false;
        }
private:
        size_t n, sum;
        vector<int> k;
};
1155
1156 static void split(const upvec& factors, const Partition& part, umodpoly& a, umodpoly& b)
1157 {
1158         umodpoly one(1, factors.front()[0].ring()->one());
1159         a = one;
1160         b = one;
1161         for ( size_t i=0; i<part.size(); ++i ) {
1162                 if ( part[i] ) {
1163                         b = b * factors[i];
1164                 }
1165                 else {
1166                         a = a * factors[i];
1167                 }
1168         }
1169 }
1170
/** Work item for the factor-combination search in factor_univariate():
 *  an integer polynomial together with the modular factors assigned to it.
 */
struct ModFactors
{
        ex poly;        // integer polynomial still to be split
        upvec factors;  // modular factors (mod the chosen prime) belonging to poly
};
1176
/** Factors a univariate polynomial over the integers.
 *
 *  Outline:
 *  1. Split @a poly into unit, content and primitive part.
 *  2. Try successive primes that do not divide the leading coefficient and
 *     for which the image of the primitive part is squarefree. Factor the
 *     image modularly and keep the prime giving the fewest modular factors.
 *     The search stops at the first prime that does not strictly reduce
 *     that count.
 *  3. Hensel-lift two-way splittings of the modular factors (enumerated with
 *     Partition) until every lifted piece is accepted.
 *
 *  If some prime yields at most one modular factor, @a poly is returned
 *  unchanged; the content is not split off in that case.
 *
 *  Assumes the leading coefficient of the primitive part is a numeric
 *  (i.e. poly has integer coefficients), since it is converted with
 *  ex_to<numeric>.
 *
 *  @param poly  polynomial to factor
 *  @param x     its variable
 *  @return      unit * content * product of the factors found
 */
static ex factor_univariate(const ex& poly, const ex& x)
{
        ex unit, cont, prim;
        poly.unitcontprim(x, unit, cont, prim);

        // determine proper prime and minimize number of modular factors
        unsigned int p = 3, lastp = 3;
        cl_modint_ring R;
        unsigned int trials = 0;
        unsigned int minfactors = 0;
        numeric lcoeff = ex_to<numeric>(prim.lcoeff(x));
        upvec factors;
        while ( trials < 2 ) {
                // next prime not dividing lcoeff with a squarefree image of prim
                while ( true ) {
                        p = next_prime(p);
                        if ( irem(lcoeff, p) != 0 ) {
                                R = find_modint_ring(p);
                                umodpoly modpoly;
                                umodpoly_from_ex(modpoly, prim, x, R);
                                if ( squarefree(modpoly) ) break;
                        }
                }

                // do modular factorization
                umodpoly modpoly;
                umodpoly_from_ex(modpoly, prim, x, R);
                upvec trialfactors;
                factor_modular(modpoly, trialfactors);
                if ( trialfactors.size() <= 1 ) {
                        // irreducible for sure
                        return poly;
                }

                // keep the prime with the fewest modular factors; trials is
                // reset on improvement, so the loop ends after one prime
                // that does not improve the count
                if ( minfactors == 0 || trialfactors.size() < minfactors ) {
                        factors = trialfactors;
                        minfactors = factors.size();
                        lastp = p;
                        trials = 1;
                }
                else {
                        ++trials;
                }
        }
        p = lastp;
        // NOTE(review): R and UPR are not used below this point; these two
        // lines look like dead code.
        R = find_modint_ring(p);
        cl_univpoly_modint_ring UPR = find_univpoly_ring(R);

        // lift all factor combinations
        stack<ModFactors> tocheck;
        ModFactors mf;
        mf.poly = prim;
        mf.factors = factors;
        tocheck.push(mf);
        ex result = 1;
        while ( tocheck.size() ) {
                const size_t n = tocheck.top().factors.size();
                Partition part(n);
                while ( true ) {
                        umodpoly a, b;
                        split(tocheck.top().factors, part, a, b);

                        // an empty lst() means the lifting for this split failed
                        ex answer = hensel_univar(tocheck.top().poly, x, p, a, b);
                        if ( answer != lst() ) {
                                if ( part.size_first() == 1 ) {
                                        if ( part.size_second() == 1 ) {
                                                // both pieces come from a single modular factor each
                                                result *= answer.op(0) * answer.op(1);
                                                tocheck.pop();
                                                break;
                                        }
                                        // first piece comes from a single modular factor:
                                        // accept it and continue with the second piece
                                        result *= answer.op(0);
                                        tocheck.top().poly = answer.op(1);
                                        for ( size_t i=0; i<n; ++i ) {
                                                if ( part[i] == 0 ) {
                                                        tocheck.top().factors.erase(tocheck.top().factors.begin()+i);
                                                        break;
                                                }
                                        }
                                        break;
                                }
                                else if ( part.size_second() == 1 ) {
                                        if ( part.size_first() == 1 ) {
                                                result *= answer.op(0) * answer.op(1);
                                                tocheck.pop();
                                                break;
                                        }
                                        // second piece comes from a single modular factor:
                                        // accept it and continue with the first piece
                                        result *= answer.op(1);
                                        tocheck.top().poly = answer.op(0);
                                        for ( size_t i=0; i<n; ++i ) {
                                                if ( part[i] == 1 ) {
                                                        tocheck.top().factors.erase(tocheck.top().factors.begin()+i);
                                                        break;
                                                }
                                        }
                                        break;
                                }
                                else {
                                        // both pieces still have several modular factors:
                                        // replace the current item by the first piece and
                                        // push the second piece as a new item
                                        upvec newfactors1(part.size_first()), newfactors2(part.size_second());
                                        upvec::iterator i1 = newfactors1.begin(), i2 = newfactors2.begin();
                                        for ( size_t i=0; i<n; ++i ) {
                                                if ( part[i] ) {
                                                        *i2++ = tocheck.top().factors[i];
                                                }
                                                else {
                                                        *i1++ = tocheck.top().factors[i];
                                                }
                                        }
                                        tocheck.top().factors = newfactors1;
                                        tocheck.top().poly = answer.op(0);
                                        ModFactors mf;
                                        mf.factors = newfactors2;
                                        mf.poly = answer.op(1);
                                        tocheck.push(mf);
                                        break;
                                }
                        }
                        else {
                                // no splitting lifts: treat the remaining polynomial as one factor
                                if ( !part.next() ) {
                                        result *= tocheck.top().poly;
                                        tocheck.pop();
                                        break;
                                }
                        }
                }
        }

        return unit * cont * result;
}
1304
/** An evaluation point "x == evalpoint" used to reduce a multivariate
 *  problem to fewer variables by substitution.
 */
struct EvalPoint
{
        ex x;           // variable that is substituted
        int evalpoint;  // integer value substituted for x
};
1310
// forward declaration: multiterm_eea_lift() calls multivar_diophant(), whose
// definition appears further down in this file
vector<ex> multivar_diophant(const vector<ex>& a_, const ex& x, const ex& c, const vector<EvalPoint>& I, unsigned int d, unsigned int p, unsigned int k);
1313
1314 upvec multiterm_eea_lift(const upvec& a, const ex& x, unsigned int p, unsigned int k)
1315 {
1316         const size_t r = a.size();
1317         cl_modint_ring R = find_modint_ring(expt_pos(cl_I(p),k));
1318         upvec q(r-1);
1319         q[r-2] = a[r-1];
1320         for ( size_t j=r-2; j>=1; --j ) {
1321                 q[j-1] = a[j] * q[j];
1322         }
1323         umodpoly beta(1, R->one());
1324         upvec s;
1325         for ( size_t j=1; j<r; ++j ) {
1326                 vector<ex> mdarg(2);
1327                 mdarg[0] = umod_to_ex(q[j-1], x);
1328                 mdarg[1] = umod_to_ex(a[j-1], x);
1329                 vector<EvalPoint> empty;
1330                 vector<ex> exsigma = multivar_diophant(mdarg, x, umod_to_ex(beta, x), empty, 0, p, k);
1331                 umodpoly sigma1;
1332                 umodpoly_from_ex(sigma1, exsigma[0], x, R);
1333                 umodpoly sigma2;
1334                 umodpoly_from_ex(sigma2, exsigma[1], x, R);
1335                 beta = sigma1;
1336                 s.push_back(sigma2);
1337         }
1338         s.push_back(beta);
1339         return s;
1340 }
1341
1342 /**
1343  *  Assert: a not empty.
1344  */
1345 void change_modulus(const cl_modint_ring& R, umodpoly& a)
1346 {
1347         if ( a.empty() ) return;
1348         cl_modint_ring oldR = a[0].ring();
1349         umodpoly::iterator i = a.begin(), end = a.end();
1350         for ( ; i!=end; ++i ) {
1351                 *i = R->canonhom(oldR->retract(*i));
1352         }
1353         canonicalize(a);
1354 }
1355
/** Lifts an extended Euclidean solution from mod p to mod p^k.
 *
 *  Computes s_, t_ (coefficients in Z/p^k) such that, presumably,
 *  s_*a + t_*b == 1 (mod p^k).
 *  1. Solve the problem mod p with exteuclid().
 *  2. Correct s, t linearly, one power of p per iteration (k-1 iterations).
 *
 *  @param a, b  polynomials; presumably with coefficients in Z/p^k, since
 *               they are combined with s, t after those are moved to that
 *               ring (TODO confirm against callers)
 *  @param x     not used by this function
 *  @param p     prime
 *  @param k     target exponent
 *  @param s_    output: cofactor of a
 *  @param t_    output: cofactor of b
 *  @throws logic_error if a and b are not coprime mod p.
 */
void eea_lift(const umodpoly& a, const umodpoly& b, const ex& x, unsigned int p, unsigned int k, umodpoly& s_, umodpoly& t_)
{
        // images of a and b modulo p
        cl_modint_ring R = find_modint_ring(p);
        umodpoly amod = a;
        change_modulus(R, amod);
        umodpoly bmod = b;
        change_modulus(R, bmod);

        // solution mod p
        umodpoly g;
        umodpoly smod;
        umodpoly tmod;
        exteuclid(amod, bmod, g, smod, tmod);
        if ( unequal_one(g) ) {
                throw logic_error("gcd(amod,bmod) != 1");
        }

        // starting values in Z/p^k
        cl_modint_ring Rpk = find_modint_ring(expt_pos(cl_I(p),k));
        umodpoly s = smod;
        change_modulus(Rpk, s);
        umodpoly t = tmod;
        change_modulus(Rpk, t);

        cl_I modulus(p);
        umodpoly one(1, Rpk->one());
        for ( size_t j=1; j<k; ++j ) {
                // error of the current solution; reduce_coeff presumably
                // divides its coefficients by the current modulus
                umodpoly e = one - a * s - b * t;
                reduce_coeff(e, modulus);
                umodpoly c = e;
                change_modulus(R, c);
                // correction terms computed mod p
                umodpoly sigmabar = smod * c;
                umodpoly taubar = tmod * c;
                umodpoly sigma, q;
                remdiv(sigmabar, bmod, sigma, q);
                umodpoly tau = taubar + q * amod;
                // add the corrections, scaled by the current modulus, in Z/p^k
                umodpoly sadd = sigma;
                change_modulus(Rpk, sadd);
                cl_MI modmodulus(Rpk, modulus);
                s = s + sadd * modmodulus;
                umodpoly tadd = tau;
                change_modulus(Rpk, tadd);
                t = t + tadd * modmodulus;
                modulus = modulus * p;
        }

        s_ = s; t_ = t;
}
1402
1403 upvec univar_diophant(const upvec& a, const ex& x, unsigned int m, unsigned int p, unsigned int k)
1404 {
1405         cl_modint_ring R = find_modint_ring(expt_pos(cl_I(p),k));
1406
1407         const size_t r = a.size();
1408         upvec result;
1409         if ( r > 2 ) {
1410                 upvec s = multiterm_eea_lift(a, x, p, k);
1411                 for ( size_t j=0; j<r; ++j ) {
1412                         ex phi = expand(pow(x,m) * umod_to_ex(s[j], x));
1413                         umodpoly bmod;
1414                         umodpoly_from_ex(bmod, phi, x, R);
1415                         umodpoly buf;
1416                         rem(bmod, a[j], buf);
1417                         result.push_back(buf);
1418                 }
1419         }
1420         else {
1421                 umodpoly s;
1422                 umodpoly t;
1423                 eea_lift(a[1], a[0], x, p, k, s, t);
1424                 ex phi = expand(pow(x,m) * umod_to_ex(s, x));
1425                 umodpoly bmod;
1426                 umodpoly_from_ex(bmod, phi, x, R);
1427                 umodpoly buf, q;
1428                 remdiv(bmod, a[0], buf, q);
1429                 result.push_back(buf);
1430                 phi = expand(pow(x,m) * umod_to_ex(t, x));
1431                 umodpoly t1mod;
1432                 umodpoly_from_ex(t1mod, phi, x, R);
1433                 umodpoly buf2 = t1mod + q * a[1];
1434                 result.push_back(buf2);
1435         }
1436
1437         return result;
1438 }
1439
1440 struct make_modular_map : public map_function {
1441         cl_modint_ring R;
1442         make_modular_map(const cl_modint_ring& R_) : R(R_) { }
1443         ex operator()(const ex& e)
1444         {
1445                 if ( is_a<add>(e) || is_a<mul>(e) ) {
1446                         return e.map(*this);
1447                 }
1448                 else if ( is_a<numeric>(e) ) {
1449                         numeric mod(R->modulus);
1450                         numeric halfmod = (mod-1)/2;
1451                         cl_MI emod = R->canonhom(the<cl_I>(ex_to<numeric>(e).to_cl_N()));
1452                         numeric n(R->retract(emod));
1453                         if ( n > halfmod ) {
1454                                 return n-mod;
1455                         }
1456                         else {
1457                                 return n;
1458                         }
1459                 }
1460                 return e;
1461         }
1462 };
1463
1464 static ex make_modular(const ex& e, const cl_modint_ring& R)
1465 {
1466         make_modular_map map(R);
1467         return map(e.expand());
1468 }
1469
1470 vector<ex> multivar_diophant(const vector<ex>& a_, const ex& x, const ex& c, const vector<EvalPoint>& I, unsigned int d, unsigned int p, unsigned int k)
1471 {
1472         vector<ex> a = a_;
1473
1474         const cl_modint_ring R = find_modint_ring(expt_pos(cl_I(p),k));
1475         const size_t r = a.size();
1476         const size_t nu = I.size() + 1;
1477
1478         vector<ex> sigma;
1479         if ( nu > 1 ) {
1480                 ex xnu = I.back().x;
1481                 int alphanu = I.back().evalpoint;
1482
1483                 ex A = 1;
1484                 for ( size_t i=0; i<r; ++i ) {
1485                         A *= a[i];
1486                 }
1487                 vector<ex> b(r);
1488                 for ( size_t i=0; i<r; ++i ) {
1489                         b[i] = normal(A / a[i]);
1490                 }
1491
1492                 vector<ex> anew = a;
1493                 for ( size_t i=0; i<r; ++i ) {
1494                         anew[i] = anew[i].subs(xnu == alphanu);
1495                 }
1496                 ex cnew = c.subs(xnu == alphanu);
1497                 vector<EvalPoint> Inew = I;
1498                 Inew.pop_back();
1499                 sigma = multivar_diophant(anew, x, cnew, Inew, d, p, k);
1500
1501                 ex buf = c;
1502                 for ( size_t i=0; i<r; ++i ) {
1503                         buf -= sigma[i] * b[i];
1504                 }
1505                 ex e = make_modular(buf, R);
1506
1507                 ex monomial = 1;
1508                 for ( size_t m=1; m<=d; ++m ) {
1509                         while ( !e.is_zero() && e.has(xnu) ) {
1510                                 monomial *= (xnu - alphanu);
1511                                 monomial = expand(monomial);
1512                                 ex cm = e.diff(ex_to<symbol>(xnu), m).subs(xnu==alphanu) / factorial(m);
1513                                 cm = make_modular(cm, R);
1514                                 if ( !cm.is_zero() ) {
1515                                         vector<ex> delta_s = multivar_diophant(anew, x, cm, Inew, d, p, k);
1516                                         ex buf = e;
1517                                         for ( size_t j=0; j<delta_s.size(); ++j ) {
1518                                                 delta_s[j] *= monomial;
1519                                                 sigma[j] += delta_s[j];
1520                                                 buf -= delta_s[j] * b[j];
1521                                         }
1522                                         e = make_modular(buf, R);
1523                                 }
1524                         }
1525                 }
1526         }
1527         else {
1528                 upvec amod;
1529                 for ( size_t i=0; i<a.size(); ++i ) {
1530                         umodpoly up;
1531                         umodpoly_from_ex(up, a[i], x, R);
1532                         amod.push_back(up);
1533                 }
1534
1535                 sigma.insert(sigma.begin(), r, 0);
1536                 size_t nterms;
1537                 ex z;
1538                 if ( is_a<add>(c) ) {
1539                         nterms = c.nops();
1540                         z = c.op(0);
1541                 }
1542                 else {
1543                         nterms = 1;
1544                         z = c;
1545                 }
1546                 for ( size_t i=0; i<nterms; ++i ) {
1547                         int m = z.degree(x);
1548                         cl_I cm = the<cl_I>(ex_to<numeric>(z.lcoeff(x)).to_cl_N());
1549                         upvec delta_s = univar_diophant(amod, x, m, p, k);
1550                         cl_MI modcm;
1551                         cl_I poscm = cm;
1552                         while ( poscm < 0 ) {
1553                                 poscm = poscm + expt_pos(cl_I(p),k);
1554                         }
1555                         modcm = cl_MI(R, poscm);
1556                         for ( size_t j=0; j<delta_s.size(); ++j ) {
1557                                 delta_s[j] = delta_s[j] * modcm;
1558                                 sigma[j] = sigma[j] + umod_to_ex(delta_s[j], x);
1559                         }
1560                         if ( nterms > 1 ) {
1561                                 z = c.op(i+1);
1562                         }
1563                 }
1564         }
1565
1566         for ( size_t i=0; i<sigma.size(); ++i ) {
1567                 sigma[i] = make_modular(sigma[i], R);
1568         }
1569
1570         return sigma;
1571 }
1572
#ifdef DEBUGFACTOR
/** Debug output: prints every evaluation point as "(x==value) ". */
ostream& operator<<(ostream& o, const vector<EvalPoint>& v)
{
        for ( vector<EvalPoint>::const_iterator it=v.begin(); it!=v.end(); ++it ) {
                o << "(" << it->x << "==" << it->evalpoint << ") ";
        }
        return o;
}
#endif // def DEBUGFACTOR
1582
/** Multivariate Hensel lifting modulo p^l.
 *
 *  Lifts the univariate modular factors @a u of @a a (taken at the evaluation
 *  points @a I) back to factors in all variables, one secondary variable at a
 *  time. Where lcU[m] != 1, factor m gets lcU[m] (evaluated at the remaining
 *  points) imposed as its leading coefficient in @a x before each lifting
 *  stage. Corrections are obtained from Taylor coefficients of the error via
 *  multivar_diophant().
 *
 *  @param a    polynomial to factor
 *  @param x    main variable
 *  @param I    evaluation points of the secondary variables; must be
 *              non-empty (I.front() is accessed)
 *  @param p    prime
 *  @param l    exponent: coefficients are kept modulo p^l
 *  @param u    univariate modular factors
 *  @param lcU  leading coefficients to impose (1 means none)
 *  @return     lst of lifted factors if their product equals @a a exactly,
 *              otherwise an empty lst().
 */
ex hensel_multivar(const ex& a, const ex& x, const vector<EvalPoint>& I, unsigned int p, const cl_I& l, const upvec& u, const vector<ex>& lcU)
{
        const size_t nu = I.size() + 1;
        const cl_modint_ring R = find_modint_ring(expt_pos(cl_I(p),l));

        // A[nu-1] = a; A[j-2] is A[j-1] with I[j-2].x substituted, reduced mod p^l
        vector<ex> A(nu);
        A[nu-1] = a;

        for ( size_t j=nu; j>=2; --j ) {
                // NOTE(review): this local x shadows the parameter x inside this loop
                ex x = I[j-2].x;
                int alpha = I[j-2].evalpoint;
                A[j-2] = A[j-1].subs(x==alpha);
                A[j-2] = make_modular(A[j-2], R);
        }

        // maximal degree of a in any of the secondary variables
        int maxdeg = a.degree(I.front().x);
        for ( size_t i=1; i<I.size(); ++i ) {
                int maxdeg2 = a.degree(I[i].x);
                if ( maxdeg2 > maxdeg ) maxdeg = maxdeg2;
        }

        const size_t n = u.size();
        vector<ex> U(n);
        for ( size_t i=0; i<n; ++i ) {
                U[i] = umod_to_ex(u[i], x);
        }

        // lift one secondary variable per stage j
        for ( size_t j=2; j<=nu; ++j ) {
                vector<ex> U1 = U;   // factors before this stage, used for the Diophantine solves
                ex monomial = 1;
                // impose the prescribed leading coefficients, with the
                // variables not yet lifted replaced by their evaluation points
                for ( size_t m=0; m<n; ++m) {
                        if ( lcU[m] != 1 ) {
                                ex coef = lcU[m];
                                for ( size_t i=j-1; i<nu-1; ++i ) {
                                        coef = coef.subs(I[i].x == I[i].evalpoint);
                                }
                                coef = make_modular(coef, R);
                                int deg = U[m].degree(x);
                                U[m] = U[m] - U[m].lcoeff(x) * pow(x,deg) + coef * pow(x,deg);
                        }
                }
                ex Uprod = 1;
                for ( size_t i=0; i<n; ++i ) {
                        Uprod *= U[i];
                }
                ex e = expand(A[j-1] - Uprod);

                // evaluation points of the variables already lifted
                vector<EvalPoint> newI;
                for ( size_t i=1; i<=j-2; ++i ) {
                        newI.push_back(I[i-1]);
                }

                // Taylor correction in xj around alphaj
                ex xj = I[j-2].x;
                int alphaj = I[j-2].evalpoint;
                size_t deg = A[j-1].degree(xj);
                for ( size_t k=1; k<=deg; ++k ) {
                        if ( !e.is_zero() ) {
                                monomial *= (xj - alphaj);
                                monomial = expand(monomial);
                                ex dif = e.diff(ex_to<symbol>(xj), k);
                                ex c = dif.subs(xj==alphaj) / factorial(k);
                                if ( !c.is_zero() ) {
                                        vector<ex> deltaU = multivar_diophant(U1, x, c, newI, maxdeg, p, cl_I_to_uint(l));
                                        for ( size_t i=0; i<n; ++i ) {
                                                deltaU[i] *= monomial;
                                                U[i] += deltaU[i];
                                                U[i] = make_modular(U[i], R);
                                        }
                                        ex Uprod = 1;
                                        for ( size_t i=0; i<n; ++i ) {
                                                Uprod *= U[i];
                                        }
                                        e = A[j-1] - Uprod;
                                        e = make_modular(e, R);
                                }
                        }
                }
        }

        // accept only if the lifted factors reproduce a exactly
        ex acand = 1;
        for ( size_t i=0; i<U.size(); ++i ) {
                acand *= U[i];
        }
        if ( expand(a-acand).is_zero() ) {
                lst res;
                for ( size_t i=0; i<U.size(); ++i ) {
                        res.append(U[i]);
                }
                return res;
        }
        else {
                // NOTE(review): res is unused here
                lst res;
                return lst();
        }
}
1678
1679 static ex put_factors_into_lst(const ex& e)
1680 {
1681         lst result;
1682
1683         if ( is_a<numeric>(e) ) {
1684                 result.append(e);
1685                 return result;
1686         }
1687         if ( is_a<power>(e) ) {
1688                 result.append(1);
1689                 result.append(e.op(0));
1690                 result.append(e.op(1));
1691                 return result;
1692         }
1693         if ( is_a<symbol>(e) || is_a<add>(e) ) {
1694                 result.append(1);
1695                 result.append(e);
1696                 result.append(1);
1697                 return result;
1698         }
1699         if ( is_a<mul>(e) ) {
1700                 ex nfac = 1;
1701                 for ( size_t i=0; i<e.nops(); ++i ) {
1702                         ex op = e.op(i);
1703                         if ( is_a<numeric>(op) ) {
1704                                 nfac = op;
1705                         }
1706                         if ( is_a<power>(op) ) {
1707                                 result.append(op.op(0));
1708                                 result.append(op.op(1));
1709                         }
1710                         if ( is_a<symbol>(op) || is_a<add>(op) ) {
1711                                 result.append(op);
1712                                 result.append(1);
1713                         }
1714                 }
1715                 result.prepend(nfac);
1716                 return result;
1717         }
1718         throw runtime_error("put_factors_into_lst: bad term.");
1719 }
1720
#ifdef DEBUGFACTOR
/** Debug output: prints every number followed by a single space. */
ostream& operator<<(ostream& o, const vector<numeric>& v)
{
        vector<numeric>::const_iterator it = v.begin();
        const vector<numeric>::const_iterator stop = v.end();
        while ( it != stop ) {
                o << *it << " ";
                ++it;
        }
        return o;
}
#endif // def DEBUGFACTOR
1730
1731 static bool checkdivisors(const lst& f, vector<numeric>& d)
1732 {
1733         const int k = f.nops()-2;
1734         numeric q, r;
1735         d[0] = ex_to<numeric>(f.op(0) * f.op(f.nops()-1));
1736         if ( d[0] == 1 && k == 1 && abs(f.op(1)) != 1 ) {
1737                 return false;
1738         }
1739         for ( int i=1; i<=k; ++i ) {
1740                 q = ex_to<numeric>(abs(f.op(i)));
1741                 for ( int j=i-1; j>=0; --j ) {
1742                         r = d[j];
1743                         do {
1744                                 r = gcd(r, q);
1745                                 q = q/r;
1746                         } while ( r != 1 );
1747                         if ( q == 1 ) {
1748                                 return true;
1749                         }
1750                 }
1751                 d[i] = q;
1752         }
1753         return false;
1754 }
1755
1756 static bool generate_set(const ex& u, const ex& vn, const exset& syms, const ex& f, const numeric& modulus, vector<numeric>& a, vector<numeric>& d)
1757 {
1758         // computation of d is actually not necessary
1759         const ex& x = *syms.begin();
1760         bool trying = true;
1761         do {
1762                 ex u0 = u;
1763                 ex vna = vn;
1764                 ex vnatry;
1765                 exset::const_iterator s = syms.begin();
1766                 ++s;
1767                 for ( size_t i=0; i<a.size(); ++i ) {
1768                         do {
1769                                 a[i] = mod(numeric(rand()), 2*modulus) - modulus;
1770                                 vnatry = vna.subs(*s == a[i]);
1771                         } while ( vnatry == 0 );
1772                         vna = vnatry;
1773                         u0 = u0.subs(*s == a[i]);
1774                         ++s;
1775                 }
1776                 if ( gcd(u0,u0.diff(ex_to<symbol>(x))) != 1 ) {
1777                         continue;
1778                 }
1779                 if ( is_a<numeric>(vn) ) {
1780                         trying = false;
1781                 }
1782                 else {
1783                         lst fnum;
1784                         lst::const_iterator i = ex_to<lst>(f).begin();
1785                         fnum.append(*i++);
1786                         bool problem = false;
1787                         while ( i!=ex_to<lst>(f).end() ) {
1788                                 ex fs = *i;
1789                                 if ( !is_a<numeric>(fs) ) {
1790                                         s = syms.begin();
1791                                         ++s;
1792                                         for ( size_t j=0; j<a.size(); ++j ) {
1793                                                 fs = fs.subs(*s == a[j]);
1794                                                 ++s;
1795                                         }
1796                                         if ( abs(fs) == 1 ) {
1797                                                 problem = true;
1798                                                 break;
1799                                         }
1800                                 }
1801                                 fnum.append(fs);
1802                                 ++i; ++i;
1803                         }
1804                         if ( problem ) {
1805                                 return true;
1806                         }
1807                         ex con = u0.content(x);
1808                         fnum.append(con);
1809                         trying = checkdivisors(fnum, d);
1810                 }
1811         } while ( trying );
1812         return false;
1813 }
1814
/** Factorize a multivariate polynomial over the integers (after [W1]).
 *
 *  Outline of what the code below does:
 *   1. split off the content w.r.t. the main variable x (first symbol of syms);
 *      a non-numeric content is factored recursively via factor();
 *   2. factor the leading coefficient vn of the primitive part pp;
 *   3. repeatedly evaluate pp at random integer points for all other
 *      variables (generate_set()), factor the univariate image u and keep
 *      going until maxtrials evaluations reproduced the smallest factor count;
 *   4. distribute the factors of vn onto the univariate factors (D, C);
 *   5. lift the univariate factorization with hensel_multivar() modulo
 *      prime^l. If that yields an empty list, everything from step 3 is retried.
 *
 *  @param poly  polynomial to factor; presumably square-free, since callers
 *               pass square-free components (TODO confirm)
 *  @param syms  all symbols occurring in poly; the first one is the main variable
 *  @return the factored polynomial, or poly itself when the univariate images
 *          do not split into more than one factor
 */
static ex factor_multivariate(const ex& poly, const exset& syms)
{
	exset::const_iterator s;
	const ex& x = *syms.begin();

	/* make polynomial primitive */
	ex p = poly.expand().collect(x);
	ex cont = p.lcoeff(x);
	// gcd of all coefficients w.r.t. x; stop early once it has dropped to 1
	for ( numeric i=p.degree(x)-1; i>=p.ldegree(x); --i ) {
		cont = gcd(cont, p.coeff(x,ex_to<numeric>(i).to_int()));
		if ( cont == 1 ) break;
	}
	ex pp = expand(normal(p / cont));
	// a polynomial content is factored separately
	if ( !is_a<numeric>(cont) ) {
		return factor(cont) * factor(pp);
	}

	/* factor leading coefficient */
	pp = pp.collect(x);
	ex vn = pp.lcoeff(x);
	pp = pp.expand();
	// vnlst: op(0) is the numeric prefactor, followed by (factor, exponent) pairs
	ex vnlst;
	if ( is_a<numeric>(vn) ) {
		vnlst = lst(vn);
	}
	else {
		ex vnfactors = factor(vn);
		vnlst = put_factors_into_lst(vnfactors);
	}

	const numeric maxtrials = 3;
	// half-width of the interval generate_set() samples evaluation values from
	numeric modulus = (vnlst.nops()-1 > 3) ? vnlst.nops()-1 : 3;
	// smallest number of univariate factors seen so far (-1: none yet)
	numeric minimalr = -1;
	// a[i]: evaluation value for the (i+1)-th symbol of syms
	vector<numeric> a(syms.size()-1, 0);
	// scratch space for generate_set()/checkdivisors()
	vector<numeric> d((vnlst.nops()-1)/2+1, 0);

	while ( true ) {
		numeric trialcount = 0;
		ex u, delta;
		unsigned int prime = 3;
		size_t factor_count = 0;
		ex ufac;
		ex ufaclst;
		// evaluate and factor until maxtrials evaluations have reproduced the
		// smallest factor count; u, prime, ufaclst then belong to the last trial
		while ( trialcount < maxtrials ) {
			bool problem = generate_set(pp, vn, syms, vnlst, modulus, a, d);
			if ( problem ) {
				// a factor of vn evaluated to +-1: widen the sampling interval
				++modulus;
				continue;
			}
			// u = univariate image of pp at the point a
			u = pp;
			s = syms.begin();
			++s;
			for ( size_t i=0; i<a.size(); ++i ) {
				u = u.subs(*s == a[i]);
				++s;
			}
			delta = u.content(x);

			// determine proper prime: starting at 3 and stepping with next_prime(),
			// take the first one not dividing lcoeff(u) for which u stays
			// square-free modulo that prime
			prime = 3;
			cl_modint_ring R = find_modint_ring(prime);
			while ( true ) {
				if ( irem(ex_to<numeric>(u.lcoeff(x)), prime) != 0 ) {
					umodpoly modpoly;
					umodpoly_from_ex(modpoly, u, x, R);
					if ( squarefree(modpoly) ) break;
				}
				prime = next_prime(prime);
				R = find_modint_ring(prime);
			}

			ufac = factor(u);
			ufaclst = put_factors_into_lst(ufac);
			factor_count = (ufaclst.nops()-1)/2;

			// veto the evaluation if some pair of factors u_i, u_j has a
			// nontrivial gcd modulo prime
			upvec tryu;
			for ( size_t i=0; i<(ufaclst.nops()-1)/2; ++i ) {
				umodpoly newu;
				umodpoly_from_ex(newu, ufaclst.op(i*2+1), x, R);
				tryu.push_back(newu);
			}
			bool veto = false;
			// NOTE(review): tryu.size()-1 wraps around (size_t) if tryu is empty;
			// presumably unreachable because u keeps a positive degree in x -- confirm.
			for ( size_t i=0; i<tryu.size()-1; ++i ) {
				for ( size_t j=i+1; j<tryu.size(); ++j ) {
					umodpoly tryg;
					gcd(tryu[i], tryu[j], tryg);
					if ( unequal_one(tryg) ) {
						veto = true;
						goto escape_quickly;
					}
				}
			}
			escape_quickly: ;
			if ( veto ) {
				continue;
			}

			// irreducible image: poly cannot split
			if ( factor_count <= 1 ) {
				return poly;
			}

			// track the minimal factor count; only repeats of the minimum count as
			// a trial, a new minimum restarts the count
			if ( minimalr < 0 ) {
				minimalr = factor_count;
			}
			else if ( minimalr == factor_count ) {
				++trialcount;
				++modulus;
			}
			else if ( minimalr > factor_count ) {
				minimalr = factor_count;
				trialcount = 0;
			}
			if ( minimalr <= 1 ) {
				return poly;
			}
		}

		// ftilde[0]: numeric prefactor of vn; ftilde[i]: i-th factor of vn
		// evaluated at a
		vector<numeric> ftilde((vnlst.nops()-1)/2+1);
		ftilde[0] = ex_to<numeric>(vnlst.op(0));
		for ( size_t i=1; i<ftilde.size(); ++i ) {
			ex ft = vnlst.op((i-1)*2+1);
			s = syms.begin();
			++s;
			for ( size_t j=0; j<a.size(); ++j ) {
				ft = ft.subs(*s == a[j]);
				++s;
			}
			ftilde[i] = ex_to<numeric>(ft);
		}

		// Assign factors of vn to the univariate factors by comparing the
		// leading coefficient of each u_i (prefac) with the values ftilde;
		// D[i-1] accumulates the factors of vn assigned to u_i.
		vector<bool> used_flag((vnlst.nops()-1)/2+1, false);
		vector<ex> D(factor_count, 1);
		for ( size_t i=0; i<=factor_count; ++i ) {
			numeric prefac;
			if ( i == 0 ) {
				// i == 0 only rescales by the numeric prefactor of u
				prefac = ex_to<numeric>(ufaclst.op(0));
				ftilde[0] = ftilde[0] / prefac;
				vnlst.let_op(0) = vnlst.op(0) / prefac;
				continue;
			}
			else {
				prefac = ex_to<numeric>(ufaclst.op(2*(i-1)+1).lcoeff(x));
			}
			for ( size_t j=(vnlst.nops()-1)/2+1; j>0; --j ) {
				if ( abs(ftilde[j-1]) == 1 ) {
					used_flag[j-1] = true;
					continue;
				}
				numeric g = gcd(prefac, ftilde[j-1]);
				if ( g != 1 ) {
					prefac = prefac / g;
					// NOTE(review): g divides ftilde[j-1], so this quotient is 0 unless
					// |g| == |ftilde[j-1]|; count is thus 0 or 1 -- confirm intended.
					numeric count = abs(iquo(g, ftilde[j-1]));
					used_flag[j-1] = true;
					if ( i > 0 ) {
						// j == 1 refers to the numeric prefactor vnlst.op(0)
						if ( j == 1 ) {
							D[i-1] = D[i-1] * pow(vnlst.op(0), count);
						}
						else {
							D[i-1] = D[i-1] * pow(vnlst.op(2*(j-2)+1), count);
						}
					}
					else {
						// NOTE(review): unreachable, i == 0 already hit `continue` above
						ftilde[j-1] = ftilde[j-1] / prefac;
						break;
					}
					// cancels the loop's --j: try the same ftilde entry again
					++j;
				}
			}
		}

		// every factor of vn has to be accounted for, otherwise start over
		bool some_factor_unused = false;
		for ( size_t i=0; i<used_flag.size(); ++i ) {
			if ( !used_flag[i] ) {
				some_factor_unused = true;
				break;
			}
		}
		if ( some_factor_unused ) {
			continue;
		}

		// C[i]: leading coefficient to impose on the i-th lifted factor
		vector<ex> C(factor_count);
		if ( delta == 1 ) {
			for ( size_t i=0; i<D.size(); ++i ) {
				ex Dtilde = D[i];
				s = syms.begin();
				++s;
				for ( size_t j=0; j<a.size(); ++j ) {
					Dtilde = Dtilde.subs(*s == a[j]);
					++s;
				}
				C[i] = D[i] * (ufaclst.op(2*i+1).lcoeff(x) / Dtilde);
			}
		}
		else {
			// NOTE(review): this branch looks unfinished:
			//  - for i == 0 it uses the numeric prefactor ufaclst.op(0), while the
			//    delta == 1 branch pairs C[i] with ufaclst.op(2*i+1);
			//  - Dtilde[i] is ex::operator[], i.e. the i-th operand of Dtilde, not
			//    Dtilde itself (presumably throws when Dtilde is a number);
			//  - the updated ui is never used: uvec below is built from ufaclst.
			for ( size_t i=0; i<D.size(); ++i ) {
				ex Dtilde = D[i];
				s = syms.begin();
				++s;
				for ( size_t j=0; j<a.size(); ++j ) {
					Dtilde = Dtilde.subs(*s == a[j]);
					++s;
				}
				ex ui;
				if ( i == 0 ) {
					ui = ufaclst.op(0);
				}
				else {
					ui = ufaclst.op(2*(i-1)+1);
				}
				while ( true ) {
					// this local d shadows the vector<numeric> d declared above
					ex d = gcd(ui.lcoeff(x), Dtilde);
					C[i] = D[i] * ( ui.lcoeff(x) / d );
					ui = ui * ( Dtilde[i] / d );
					delta = delta / ( Dtilde[i] / d );
					if ( delta == 1 ) break;
					ui = delta * ui;
					C[i] = delta * C[i];
					pp = pp * pow(delta, D.size()-1);
				}
			}
		}

		// evaluation point in the form expected by hensel_multivar()
		EvalPoint ep;
		vector<EvalPoint> epv;
		s = syms.begin();
		++s;
		for ( size_t i=0; i<a.size(); ++i ) {
			ep.x = *s++;
			ep.evalpoint = a[i].to_int();
			epv.push_back(ep);
		}

		// calc bound B
		// B = ceil(sqrt(sum of squared coefficients of u)) * 2^maxdegree, with
		// maxdegree the largest degree of a univariate factor; l is the smallest
		// exponent with prime^l >= B
		ex maxcoeff;
		for ( int i=u.degree(x); i>=u.ldegree(x); --i ) {
			maxcoeff += pow(abs(u.coeff(x, i)),2);
		}
		cl_I normmc = ceiling1(the<cl_R>(cln::sqrt(ex_to<numeric>(maxcoeff).to_cl_N())));
		unsigned int maxdegree = 0;
		for ( size_t i=0; i<factor_count; ++i ) {
			if ( ufaclst[2*i+1].degree(x) > (int)maxdegree ) {
				maxdegree = ufaclst[2*i+1].degree(x);
			}
		}
		cl_I B = normmc * expt_pos(cl_I(2), maxdegree);
		cl_I l = 1;
		cl_I pl = prime;
		while ( pl < B ) {
			l = l + 1;
			pl = pl * prime;
		}

		// univariate factors reduced modulo prime^l, input to the lifting
		upvec uvec;
		cl_modint_ring R = find_modint_ring(expt_pos(cl_I(prime),l));
		for ( size_t i=0; i<(ufaclst.nops()-1)/2; ++i ) {
			umodpoly newu;
			umodpoly_from_ex(newu, ufaclst.op(i*2+1), x, R);
			uvec.push_back(newu);
		}

		ex res = hensel_multivar(ufaclst.op(0)*pp, x, epv, prime, l, uvec, C);
		// an empty list from hensel_multivar() restarts the outer loop
		if ( res != lst() ) {
			ex result = cont * ufaclst.op(0);
			for ( size_t i=0; i<res.nops(); ++i ) {
				result *= res.op(i).content(x) * res.op(i).unit(x);
				result *= res.op(i).primpart(x);
			}
			return result;
		}
	}
}
2088
2089 struct find_symbols_map : public map_function {
2090         exset syms;
2091         ex operator()(const ex& e)
2092         {
2093                 if ( is_a<symbol>(e) ) {
2094                         syms.insert(e);
2095                         return e;
2096                 }
2097                 return e.map(*this);
2098         }
2099 };
2100
2101 static ex factor_sqrfree(const ex& poly)
2102 {
2103         // determine all symbols in poly
2104         find_symbols_map findsymbols;
2105         findsymbols(poly);
2106         if ( findsymbols.syms.size() == 0 ) {
2107                 return poly;
2108         }
2109
2110         if ( findsymbols.syms.size() == 1 ) {
2111                 // univariate case
2112                 const ex& x = *(findsymbols.syms.begin());
2113                 if ( poly.ldegree(x) > 0 ) {
2114                         int ld = poly.ldegree(x);
2115                         ex res = factor_univariate(expand(poly/pow(x, ld)), x);
2116                         return res * pow(x,ld);
2117                 }
2118                 else {
2119                         ex res = factor_univariate(poly, x);
2120                         return res;
2121                 }
2122         }
2123
2124         // multivariate case
2125         ex res = factor_multivariate(poly, findsymbols.syms);
2126         return res;
2127 }
2128
2129 struct apply_factor_map : public map_function {
2130         unsigned options;
2131         apply_factor_map(unsigned options_) : options(options_) { }
2132         ex operator()(const ex& e)
2133         {
2134                 if ( e.info(info_flags::polynomial) ) {
2135                         return factor(e, options);
2136                 }
2137                 if ( is_a<add>(e) ) {
2138                         ex s1, s2;
2139                         for ( size_t i=0; i<e.nops(); ++i ) {
2140                                 if ( e.op(i).info(info_flags::polynomial) ) {
2141                                         s1 += e.op(i);
2142                                 }
2143                                 else {
2144                                         s2 += e.op(i);
2145                                 }
2146                         }
2147                         s1 = s1.eval();
2148                         s2 = s2.eval();
2149                         return factor(s1, options) + s2.map(*this);
2150                 }
2151                 return e.map(*this);
2152         }
2153 };
2154
2155 } // anonymous namespace
2156
2157 ex factor(const ex& poly, unsigned options)
2158 {
2159         // check arguments
2160         if ( !poly.info(info_flags::polynomial) ) {
2161                 if ( options & factor_options::all ) {
2162                         options &= ~factor_options::all;
2163                         apply_factor_map factor_map(options);
2164                         return factor_map(poly);
2165                 }
2166                 return poly;
2167         }
2168
2169         // determine all symbols in poly
2170         find_symbols_map findsymbols;
2171         findsymbols(poly);
2172         if ( findsymbols.syms.size() == 0 ) {
2173                 return poly;
2174         }
2175         lst syms;
2176         exset::const_iterator i=findsymbols.syms.begin(), end=findsymbols.syms.end();
2177         for ( ; i!=end; ++i ) {
2178                 syms.append(*i);
2179         }
2180
2181         // make poly square free
2182         ex sfpoly = sqrfree(poly, syms);
2183
2184         // factorize the square free components
2185         if ( is_a<power>(sfpoly) ) {
2186                 // case: (polynomial)^exponent
2187                 const ex& base = sfpoly.op(0);
2188                 if ( !is_a<add>(base) ) {
2189                         // simple case: (monomial)^exponent
2190                         return sfpoly;
2191                 }
2192                 ex f = factor_sqrfree(base);
2193                 return pow(f, sfpoly.op(1));
2194         }
2195         if ( is_a<mul>(sfpoly) ) {
2196                 // case: multiple factors
2197                 ex res = 1;
2198                 for ( size_t i=0; i<sfpoly.nops(); ++i ) {
2199                         const ex& t = sfpoly.op(i);
2200                         if ( is_a<power>(t) ) {
2201                                 const ex& base = t.op(0);
2202                                 if ( !is_a<add>(base) ) {
2203                                         res *= t;
2204                                 }
2205                                 else {
2206                                         ex f = factor_sqrfree(base);
2207                                         res *= pow(f, t.op(1));
2208                                 }
2209                         }
2210                         else if ( is_a<add>(t) ) {
2211                                 ex f = factor_sqrfree(t);
2212                                 res *= f;
2213                         }
2214                         else {
2215                                 res *= t;
2216                         }
2217                 }
2218                 return res;
2219         }
2220         if ( is_a<symbol>(sfpoly) ) {
2221                 return poly;
2222         }
2223         // case: (polynomial)
2224         ex f = factor_sqrfree(sfpoly);
2225         return f;
2226 }
2227
2228 } // namespace GiNaC
2229
2230 #ifdef DEBUGFACTOR
2231 #include "test.h"
2232 #endif