]> www.ginac.de Git - ginac.git/blobdiff - ginac/normal.cpp
Avoid unnecessary expansion in sqrfree_yun().
[ginac.git] / ginac / normal.cpp
index 5fda185981b13fdd82e800ff842db8f3ec2e1ce5..f8615499add62a1a8e447a7b8d5f2aaf725b62b0 100644 (file)
@@ -6,7 +6,7 @@
  *  computation, square-free factorization and rational function normalization. */
 
 /*
- *  GiNaC Copyright (C) 1999-2019 Johannes Gutenberg University Mainz, Germany
+ *  GiNaC Copyright (C) 1999-2020 Johannes Gutenberg University Mainz, Germany
  *
  *  This program is free software; you can redistribute it and/or modify
  *  it under the terms of the GNU General Public License as published by
@@ -1801,31 +1801,51 @@ static epvector sqrfree_yun(const ex &a, const symbol &x)
        ex z = w.diff(x);
        ex g = gcd(w, z);
        if (g.is_zero()) {
-               return epvector{};
+               // manifest zero or hidden zero
+               return {};
        }
        if (g.is_equal(_ex1)) {
-               return epvector{expair(a, _ex1)};
+               // w(x) and w'(x) share no factors: w(x) is square-free
+               return {expair(a, _ex1)};
        }
-       epvector results;
-       ex exponent = _ex0;
+
+       epvector factors;
+       ex i = 0;  // exponent
        do {
                w = quo(w, g, x);
                if (w.is_zero()) {
-                       return results;
+                       // hidden zero
+                       break;
                }
                z = quo(z, g, x) - w.diff(x);
-               exponent = exponent + 1;
+               i += 1;
                if (w.is_equal(x)) {
                        // shortcut for x^n with n ∈ ℕ
-                       exponent += quo(z, w.diff(x), x);
-                       results.push_back(expair(w, exponent));
+                       i += quo(z, w.diff(x), x);
+                       factors.push_back(expair(w, i));
                        break;
                }
                g = gcd(w, z);
                if (!g.is_equal(_ex1)) {
-                       results.push_back(expair(g, exponent));
+                       factors.push_back(expair(g, i));
                }
        } while (!z.is_zero());
+
+       // correct for lost factor
+       // (being based on GCDs, Yun's algorithm only finds factors up to a unit)
+       const ex lost_factor = quo(a, mul{factors}, x);
+       if (lost_factor.is_equal(_ex1)) {
+               // trivial lost factor
+               return factors;
+       }
+       if (!factors.empty() && factors[0].coeff.is_equal(1)) {
+               // multiply factor^1 with lost_factor
+               factors[0].rest *= lost_factor;
+               return factors;
+       }
+       // no factor^1: prepend lost_factor^1 to the results
+       epvector results = {expair(lost_factor, 1)};
+       std::move(factors.begin(), factors.end(), std::back_inserter(results));
        return results;
 }
 
@@ -1891,10 +1911,14 @@ ex sqrfree(const ex &a, const lst &l)
 
        // convert the argument from something in Q[X] to something in Z[X]
        const numeric lcm = lcm_of_coefficients_denominators(a);
-       const ex tmp = multiply_lcm(a,lcm);
+       const ex tmp = multiply_lcm(a, lcm);
 
        // find the factors
        epvector factors = sqrfree_yun(tmp, x);
+       if (factors.empty()) {
+               // the polynomial was a hidden zero
+               return _ex0;
+       }
 
        // remove symbol x and proceed recursively with the remaining symbols
        args.remove_first();
@@ -1906,21 +1930,10 @@ ex sqrfree(const ex &a, const lst &l)
        }
 
        // Done with recursion, now construct the final result
-       ex result = _ex1;
-       for (auto & it : factors)
-               result *= pow(it.rest, it.coeff);
-
-       // Yun's algorithm does not account for constant factors.  (For univariate
-       // polynomials it works only in the monic case.)  We can correct this by
-       // inserting what has been lost back into the result.  For completeness
-       // we'll also have to recurse down that factor in the remaining variables.
-       if (args.nops()>0)
-               result *= sqrfree(quo(tmp, result, x), args);
-       else
-               result *= quo(tmp, result, x);
+       ex result = mul(factors);
 
        // Put in the rational overall factor again and return
-       return result * lcm.inverse();
+       return result * lcm.inverse();
 }
 
 
@@ -1936,11 +1949,11 @@ ex sqrfree_parfrac(const ex & a, const symbol & x)
        // Find numerator and denominator
        ex nd = numer_denom(a);
        ex numer = nd.op(0), denom = nd.op(1);
-//clog << "numer = " << numer << ", denom = " << denom << endl;
+//std::clog << "numer = " << numer << ", denom = " << denom << std::endl;
 
        // Convert N(x)/D(x) -> Q(x) + R(x)/D(x), so degree(R) < degree(D)
        ex red_poly = quo(numer, denom, x), red_numer = rem(numer, denom, x).expand();
-//clog << "red_poly = " << red_poly << ", red_numer = " << red_numer << endl;
+//std::clog << "red_poly = " << red_poly << ", red_numer = " << red_numer << std::endl;
 
        // Factorize denominator and compute cofactors
        epvector yun = sqrfree_yun(denom, x);
@@ -1961,8 +1974,8 @@ ex sqrfree_parfrac(const ex & a, const symbol & x)
                }
        }
        size_t num_factors = factor.size();
-//clog << "factors  : " << exprseq(factor) << endl;
-//clog << "cofactors: " << exprseq(cofac) << endl;
+//std::clog << "factors  : " << exprseq(factor) << std::endl;
+//std::clog << "cofactors: " << exprseq(cofac) << std::endl;
 
        // Construct coefficient matrix for decomposition
        int max_denom_deg = denom.degree(x);
@@ -1973,8 +1986,8 @@ ex sqrfree_parfrac(const ex & a, const symbol & x)
                        sys(i, j) = cofac[j].coeff(x, i);
                rhs(i, 0) = red_numer.coeff(x, i);
        }
-//clog << "coeffs: " << sys << endl;
-//clog << "rhs   : " << rhs << endl;
+//std::clog << "coeffs: " << sys << std::endl;
+//std::clog << "rhs   : " << rhs << std::endl;
 
        // Solve resulting linear system
        matrix vars(num_factors, 1);
@@ -2565,7 +2578,7 @@ static ex find_common_factor(const ex & e, ex & factor, exmap & repl)
                                x *= f;
                        }
 
-                       if (i == 0)
+                       if (gc.is_zero())
                                gc = x;
                        else
                                gc = gcd(gc, x);
@@ -2576,6 +2589,9 @@ static ex find_common_factor(const ex & e, ex & factor, exmap & repl)
                if (gc.is_equal(_ex1))
                        return e;
 
+               if (gc.is_zero())
+                       return _ex0;
+
                // The GCD is the factor we pull out
                factor *= gc;