]> www.ginac.de Git - ginac.git/blobdiff - ginac/normal.cpp
Avoid unnecessary expansion in sqrfree_yun().
[ginac.git] / ginac / normal.cpp
index a1e2de48c9f598fd09c915dd026e16211d2fde70..f8615499add62a1a8e447a7b8d5f2aaf725b62b0 100644 (file)
@@ -6,7 +6,7 @@
  *  computation, square-free factorization and rational function normalization. */
 
 /*
- *  GiNaC Copyright (C) 1999-2016 Johannes Gutenberg University Mainz, Germany
+ *  GiNaC Copyright (C) 1999-2020 Johannes Gutenberg University Mainz, Germany
  *
  *  This program is free software; you can redistribute it and/or modify
  *  it under the terms of the GNU General Public License as published by
@@ -146,7 +146,7 @@ struct sym_desc {
        /** Maximum number of terms of leading coefficient of symbol in both polynomials */
        size_t max_lcnops;
 
-       /** Commparison operator for sorting */
+       /** Comparison operator for sorting */
        bool operator<(const sym_desc &x) const
        {
                if (max_deg == x.max_deg)
@@ -212,10 +212,10 @@ static void get_symbol_stats(const ex &a, const ex &b, sym_desc_vec &v)
 
 #if 0
        std::clog << "Symbols:\n";
-       it = v.begin(); itend = v.end();
+       auto it = v.begin(), itend = v.end();
        while (it != itend) {
-               std::clog << " " << it->sym << ": deg_a=" << it->deg_a << ", deg_b=" << it->deg_b << ", ldeg_a=" << it->ldeg_a << ", ldeg_b=" << it->ldeg_b << ", max_deg=" << it->max_deg << ", max_lcnops=" << it->max_lcnops << endl;
-               std::clog << "  lcoeff_a=" << a.lcoeff(it->sym) << ", lcoeff_b=" << b.lcoeff(it->sym) << endl;
+               std::clog << " " << it->sym << ": deg_a=" << it->deg_a << ", deg_b=" << it->deg_b << ", ldeg_a=" << it->ldeg_a << ", ldeg_b=" << it->ldeg_b << ", max_deg=" << it->max_deg << ", max_lcnops=" << it->max_lcnops << std::endl;
+               std::clog << "  lcoeff_a=" << a.lcoeff(it->sym) << ", lcoeff_b=" << b.lcoeff(it->sym) << std::endl;
                ++it;
        }
 #endif
@@ -270,9 +270,15 @@ static numeric lcm_of_coefficients_denominators(const ex &e)
  *  @param lcm  LCM to multiply in */
 static ex multiply_lcm(const ex &e, const numeric &lcm)
 {
+       if (lcm.is_equal(*_num1_p))
+               // e * 1 -> e;
+               return e;
+
        if (is_exactly_a<mul>(e)) {
+               // (a*b*...)*lcm -> (a*lcma)*(b*lcmb)*...*(lcm/(lcma*lcmb*...))
                size_t num = e.nops();
-               exvector v; v.reserve(num + 1);
+               exvector v;
+               v.reserve(num + 1);
                numeric lcm_accum = *_num1_p;
                for (size_t i=0; i<num; i++) {
                        numeric op_lcm = lcmcoeff(e.op(i), *_num1_p);
@@ -282,23 +288,24 @@ static ex multiply_lcm(const ex &e, const numeric &lcm)
                v.push_back(lcm / lcm_accum);
                return dynallocate<mul>(v);
        } else if (is_exactly_a<add>(e)) {
+               // (a+b+...)*lcm -> a*lcm+b*lcm+...
                size_t num = e.nops();
-               exvector v; v.reserve(num);
+               exvector v;
+               v.reserve(num);
                for (size_t i=0; i<num; i++)
                        v.push_back(multiply_lcm(e.op(i), lcm));
                return dynallocate<add>(v);
        } else if (is_exactly_a<power>(e)) {
-               if (is_a<symbol>(e.op(0)))
-                       return e * lcm;
-               else {
+               if (!is_a<symbol>(e.op(0))) {
+                       // (b^e)*lcm -> (b*lcm^(1/e))^e if lcm^(1/e) ∈ ℚ (i.e. not a float)
+                       // but not for symbolic b, as evaluation would undo this again
                        numeric root_of_lcm = lcm.power(ex_to<numeric>(e.op(1)).inverse());
                        if (root_of_lcm.is_rational())
                                return pow(multiply_lcm(e.op(0), root_of_lcm), e.op(1));
-                       else
-                               return e * lcm;
                }
-       } else
-               return e * lcm;
+       }
+       // can't recurse down into e
+       return dynallocate<mul>(e, lcm);
 }
 
 
@@ -1463,7 +1470,7 @@ ex gcd(const ex &a, const ex &b, ex *ca, ex *cb, bool check_args, unsigned optio
        }
 
        // Some trivial cases
-       ex aex = a.expand(), bex = b.expand();
+       ex aex = a.expand();
        if (aex.is_zero()) {
                if (ca)
                        *ca = _ex0;
@@ -1471,6 +1478,7 @@ ex gcd(const ex &a, const ex &b, ex *ca, ex *cb, bool check_args, unsigned optio
                        *cb = _ex1;
                return b;
        }
+       ex bex = b.expand();
        if (bex.is_zero()) {
                if (ca)
                        *ca = _ex1;
@@ -1541,7 +1549,7 @@ ex gcd(const ex &a, const ex &b, ex *ca, ex *cb, bool check_args, unsigned optio
 
        // The symbol with least degree which is contained in both polynomials
        // is our main variable
-       sym_desc_vec::iterator vari = sym_stats.begin();
+       auto vari = sym_stats.begin();
        while ((vari != sym_stats.end()) && 
               (((vari->ldeg_b == 0) && (vari->deg_b == 0)) ||
                ((vari->ldeg_a == 0) && (vari->deg_a == 0))))
@@ -1556,8 +1564,7 @@ ex gcd(const ex &a, const ex &b, ex *ca, ex *cb, bool check_args, unsigned optio
                        *cb = b;
                return _ex1;
        }
-       // move symbols which contained only in one of the polynomials
-       // to the end:
+       // move symbol contained only in one of the polynomials to the end:
        rotate(sym_stats.begin(), vari, sym_stats.end());
 
        sym_desc_vec::const_iterator var = sym_stats.begin();
@@ -1669,7 +1676,6 @@ static ex gcd_pf_pow_pow(const ex& a, const ex& b, ex* ca, ex* cb)
                        if (cb)
                                *cb = b;
                        return _ex1;
-                       // XXX: do I need to check for p_gcd = -1?
        }
 
        // there are common factors:
@@ -1699,17 +1705,27 @@ static ex gcd_pf_pow(const ex& a, const ex& b, ex* ca, ex* cb)
        if (p.is_equal(b)) {
                // a = p^n, b = p, gcd = p
                if (ca)
-                       *ca = pow(p, a.op(1) - 1);
+                       *ca = pow(p, exp_a - 1);
                if (cb)
                        *cb = _ex1;
                return p;
-       } 
+       }
+       if (is_a<symbol>(p)) {
+               // Cancel trivial common factor
+               int ldeg_a = ex_to<numeric>(exp_a).to_int();
+               int ldeg_b = b.ldegree(p);
+               int min_ldeg = std::min(ldeg_a, ldeg_b);
+               if (min_ldeg > 0) {
+                       ex common = pow(p, min_ldeg);
+                       return gcd(pow(p, ldeg_a - min_ldeg), (b / common).expand(), ca, cb, false) * common;
+               }
+       }
 
        ex p_co, bpart_co;
        ex p_gcd = gcd(p, b, &p_co, &bpart_co, false);
 
-       // a(x) = p(x)^n, gcd(p, b) = 1 ==> gcd(a, b) = 1
        if (p_gcd.is_equal(_ex1)) {
+               // a(x) = p(x)^n, gcd(p, b) = 1 ==> gcd(a, b) = 1
                if (ca)
                        *ca = a;
                if (cb)
@@ -1778,32 +1794,59 @@ ex lcm(const ex &a, const ex &b, bool check_args)
  *  @param a  multivariate polynomial over Z[X], treated here as univariate
  *            polynomial in x (needs not be expanded).
  *  @param x  variable to factor in
- *  @return   vector of factors sorted in ascending degree */
-static exvector sqrfree_yun(const ex &a, const symbol &x)
+ *  @return   vector of expairs (factor, exponent), sorted by exponents */
+static epvector sqrfree_yun(const ex &a, const symbol &x)
 {
-       exvector res;
        ex w = a;
        ex z = w.diff(x);
        ex g = gcd(w, z);
        if (g.is_zero()) {
-               return res;
+               // manifest zero or hidden zero
+               return {};
        }
        if (g.is_equal(_ex1)) {
-               res.push_back(a);
-               return res;
+               // w(x) and w'(x) share no factors: w(x) is square-free
+               return {expair(a, _ex1)};
        }
-       ex y;
+
+       epvector factors;
+       ex i = 0;  // exponent
        do {
                w = quo(w, g, x);
                if (w.is_zero()) {
-                       return res;
+                       // hidden zero
+                       break;
+               }
+               z = quo(z, g, x) - w.diff(x);
+               i += 1;
+               if (w.is_equal(x)) {
+                       // shortcut for x^n with n ∈ ℕ
+                       i += quo(z, w.diff(x), x);
+                       factors.push_back(expair(w, i));
+                       break;
                }
-               y = quo(z, g, x);
-               z = y - w.diff(x);
                g = gcd(w, z);
-               res.push_back(g);
+               if (!g.is_equal(_ex1)) {
+                       factors.push_back(expair(g, i));
+               }
        } while (!z.is_zero());
-       return res;
+
+       // correct for lost factor
+       // (being based on GCDs, Yun's algorithm only finds factors up to a unit)
+       const ex lost_factor = quo(a, mul{factors}, x);
+       if (lost_factor.is_equal(_ex1)) {
+               // trivial lost factor
+               return factors;
+       }
+       if (!factors.empty() && factors[0].coeff.is_equal(1)) {
+               // multiply factor^1 with lost_factor
+               factors[0].rest *= lost_factor;
+               return factors;
+       }
+       // no factor^1: prepend lost_factor^1 to the results
+       epvector results = {expair(lost_factor, 1)};
+       std::move(factors.begin(), factors.end(), std::back_inserter(results));
+       return results;
 }
 
 
@@ -1868,38 +1911,29 @@ ex sqrfree(const ex &a, const lst &l)
 
        // convert the argument from something in Q[X] to something in Z[X]
        const numeric lcm = lcm_of_coefficients_denominators(a);
-       const ex tmp = multiply_lcm(a,lcm);
+       const ex tmp = multiply_lcm(a, lcm);
 
        // find the factors
-       exvector factors = sqrfree_yun(tmp, x);
+       epvector factors = sqrfree_yun(tmp, x);
+       if (factors.empty()) {
+               // the polynomial was a hidden zero
+               return _ex0;
+       }
 
-       // construct the next list of symbols with the first element popped
-       lst newargs = args;
-       newargs.remove_first();
+       // remove symbol x and proceed recursively with the remaining symbols
+       args.remove_first();
 
        // recurse down the factors in remaining variables
-       if (newargs.nops()>0) {
+       if (args.nops()>0) {
                for (auto & it : factors)
-                       it = sqrfree(it, newargs);
+                       it.rest = sqrfree(it.rest, args);
        }
 
        // Done with recursion, now construct the final result
-       ex result = _ex1;
-       int p = 1;
-       for (auto & it : factors)
-               result *= pow(it, p++);
-
-       // Yun's algorithm does not account for constant factors.  (For univariate
-       // polynomials it works only in the monic case.)  We can correct this by
-       // inserting what has been lost back into the result.  For completeness
-       // we'll also have to recurse down that factor in the remaining variables.
-       if (newargs.nops()>0)
-               result *= sqrfree(quo(tmp, result, x), newargs);
-       else
-               result *= quo(tmp, result, x);
+       ex result = mul(factors);
 
        // Put in the rational overall factor again and return
-       return result * lcm.inverse();
+       return result * lcm.inverse();
 }
 
 
@@ -1915,36 +1949,33 @@ ex sqrfree_parfrac(const ex & a, const symbol & x)
        // Find numerator and denominator
        ex nd = numer_denom(a);
        ex numer = nd.op(0), denom = nd.op(1);
-//clog << "numer = " << numer << ", denom = " << denom << endl;
+//std::clog << "numer = " << numer << ", denom = " << denom << std::endl;
 
        // Convert N(x)/D(x) -> Q(x) + R(x)/D(x), so degree(R) < degree(D)
        ex red_poly = quo(numer, denom, x), red_numer = rem(numer, denom, x).expand();
-//clog << "red_poly = " << red_poly << ", red_numer = " << red_numer << endl;
+//std::clog << "red_poly = " << red_poly << ", red_numer = " << red_numer << std::endl;
 
        // Factorize denominator and compute cofactors
-       exvector yun = sqrfree_yun(denom, x);
-//clog << "yun factors: " << exprseq(yun) << endl;
-       size_t num_yun = yun.size();
-       exvector factor; factor.reserve(num_yun);
-       exvector cofac; cofac.reserve(num_yun);
-       for (size_t i=0; i<num_yun; i++) {
-               if (!yun[i].is_equal(_ex1)) {
-                       for (size_t j=0; j<=i; j++) {
-                               factor.push_back(pow(yun[i], j+1));
-                               ex prod = _ex1;
-                               for (size_t k=0; k<num_yun; k++) {
-                                       if (k == i)
-                                               prod *= pow(yun[k], i-j);
-                                       else
-                                               prod *= pow(yun[k], k+1);
-                               }
-                               cofac.push_back(prod.expand());
+       epvector yun = sqrfree_yun(denom, x);
+       size_t yun_max_exponent = yun.empty() ? 0 : ex_to<numeric>(yun.back().coeff).to_int();
+       exvector factor, cofac;
+       for (size_t i=0; i<yun.size(); i++) {
+               numeric i_exponent = ex_to<numeric>(yun[i].coeff);
+               for (size_t j=0; j<i_exponent; j++) {
+                       factor.push_back(pow(yun[i].rest, j+1));
+                       ex prod = _ex1;
+                       for (size_t k=0; k<yun.size(); k++) {
+                               if (yun[k].coeff == i_exponent)
+                                       prod *= pow(yun[k].rest, i_exponent-1-j);
+                               else
+                                       prod *= pow(yun[k].rest, yun[k].coeff);
                        }
+                       cofac.push_back(prod.expand());
                }
        }
        size_t num_factors = factor.size();
-//clog << "factors  : " << exprseq(factor) << endl;
-//clog << "cofactors: " << exprseq(cofac) << endl;
+//std::clog << "factors  : " << exprseq(factor) << std::endl;
+//std::clog << "cofactors: " << exprseq(cofac) << std::endl;
 
        // Construct coefficient matrix for decomposition
        int max_denom_deg = denom.degree(x);
@@ -1955,8 +1986,8 @@ ex sqrfree_parfrac(const ex & a, const symbol & x)
                        sys(i, j) = cofac[j].coeff(x, i);
                rhs(i, 0) = red_numer.coeff(x, i);
        }
-//clog << "coeffs: " << sys << endl;
-//clog << "rhs   : " << rhs << endl;
+//std::clog << "coeffs: " << sys << std::endl;
+//std::clog << "rhs   : " << rhs << std::endl;
 
        // Solve resulting linear system
        matrix vars(num_factors, 1);
@@ -2391,47 +2422,11 @@ ex ex::to_rational(exmap & repl) const
        return bp->to_rational(repl);
 }
 
-// GiNaC 1.1 compatibility function
-ex ex::to_rational(lst & repl_lst) const
-{
-       // Convert lst to exmap
-       exmap m;
-       for (auto & it : repl_lst)
-               m.insert(std::make_pair(it.op(0), it.op(1)));
-
-       ex ret = bp->to_rational(m);
-
-       // Convert exmap back to lst
-       repl_lst.remove_all();
-       for (auto & it : m)
-               repl_lst.append(it.first == it.second);
-
-       return ret;
-}
-
 ex ex::to_polynomial(exmap & repl) const
 {
        return bp->to_polynomial(repl);
 }
 
-// GiNaC 1.1 compatibility function
-ex ex::to_polynomial(lst & repl_lst) const
-{
-       // Convert lst to exmap
-       exmap m;
-       for (auto & it : repl_lst)
-               m.insert(std::make_pair(it.op(0), it.op(1)));
-
-       ex ret = bp->to_polynomial(m);
-
-       // Convert exmap back to lst
-       repl_lst.remove_all();
-       for (auto & it : m)
-               repl_lst.append(it.first == it.second);
-
-       return ret;
-}
-
 /** Default implementation of ex::to_rational(). This replaces the object with
  *  a temporary symbol. */
 ex basic::to_rational(exmap & repl) const
@@ -2583,7 +2578,7 @@ static ex find_common_factor(const ex & e, ex & factor, exmap & repl)
                                x *= f;
                        }
 
-                       if (i == 0)
+                       if (gc.is_zero())
                                gc = x;
                        else
                                gc = gcd(gc, x);
@@ -2594,6 +2589,9 @@ static ex find_common_factor(const ex & e, ex & factor, exmap & repl)
                if (gc.is_equal(_ex1))
                        return e;
 
+               if (gc.is_zero())
+                       return _ex0;
+
                // The GCD is the factor we pull out
                factor *= gc;