]> www.ginac.de Git - ginac.git/blobdiff - ginac/normal.cpp
Speed up special cases of square-free factorization.
[ginac.git] / ginac / normal.cpp
index b4b5b694225f0b557c0cb96d3edb43966ff8b94a..9f8b7b4a406fdb5fe5ee39454e86eb6340f0d86a 100644 (file)
@@ -6,7 +6,7 @@
  *  computation, square-free factorization and rational function normalization. */
 
 /*
- *  GiNaC Copyright (C) 1999-2015 Johannes Gutenberg University Mainz, Germany
+ *  GiNaC Copyright (C) 1999-2018 Johannes Gutenberg University Mainz, Germany
  *
  *  This program is free software; you can redistribute it and/or modify
  *  it under the terms of the GNU General Public License as published by
@@ -146,7 +146,7 @@ struct sym_desc {
        /** Maximum number of terms of leading coefficient of symbol in both polynomials */
        size_t max_lcnops;
 
-       /** Commparison operator for sorting */
+       /** Comparison operator for sorting */
        bool operator<(const sym_desc &x) const
        {
                if (max_deg == x.max_deg)
@@ -270,9 +270,15 @@ static numeric lcm_of_coefficients_denominators(const ex &e)
  *  @param lcm  LCM to multiply in */
 static ex multiply_lcm(const ex &e, const numeric &lcm)
 {
+       if (lcm.is_equal(*_num1_p))
+               // e * 1 -> e;
+               return e;
+
        if (is_exactly_a<mul>(e)) {
+               // (a*b*...)*lcm -> (a*lcma)*(b*lcmb)*...*(lcm/(lcma*lcmb*...))
                size_t num = e.nops();
-               exvector v; v.reserve(num + 1);
+               exvector v;
+               v.reserve(num + 1);
                numeric lcm_accum = *_num1_p;
                for (size_t i=0; i<num; i++) {
                        numeric op_lcm = lcmcoeff(e.op(i), *_num1_p);
@@ -282,18 +288,24 @@ static ex multiply_lcm(const ex &e, const numeric &lcm)
                v.push_back(lcm / lcm_accum);
                return dynallocate<mul>(v);
        } else if (is_exactly_a<add>(e)) {
+               // (a+b+...)*lcm -> a*lcm+b*lcm+...
                size_t num = e.nops();
-               exvector v; v.reserve(num);
+               exvector v;
+               v.reserve(num);
                for (size_t i=0; i<num; i++)
                        v.push_back(multiply_lcm(e.op(i), lcm));
                return dynallocate<add>(v);
        } else if (is_exactly_a<power>(e)) {
-               if (is_a<symbol>(e.op(0)))
-                       return e * lcm;
-               else
-                       return pow(multiply_lcm(e.op(0), lcm.power(ex_to<numeric>(e.op(1)).inverse())), e.op(1));
-       } else
-               return e * lcm;
+               if (!is_a<symbol>(e.op(0))) {
+                       // (b^e)*lcm -> (b*lcm^(1/e))^e if lcm^(1/e) ∈ ℚ (i.e. not a float)
+                       // but not for symbolic b, as evaluation would undo this again
+                       numeric root_of_lcm = lcm.power(ex_to<numeric>(e.op(1)).inverse());
+                       if (root_of_lcm.is_rational())
+                               return pow(multiply_lcm(e.op(0), root_of_lcm), e.op(1));
+               }
+       }
+       // can't recurse down into e
+       return dynallocate<mul>(e, lcm);
 }
 
 
@@ -1536,7 +1548,7 @@ ex gcd(const ex &a, const ex &b, ex *ca, ex *cb, bool check_args, unsigned optio
 
        // The symbol with least degree which is contained in both polynomials
        // is our main variable
-       sym_desc_vec::iterator vari = sym_stats.begin();
+       auto vari = sym_stats.begin();
        while ((vari != sym_stats.end()) && 
               (((vari->ldeg_b == 0) && (vari->deg_b == 0)) ||
                ((vari->ldeg_a == 0) && (vari->deg_a == 0))))
@@ -1771,34 +1783,47 @@ ex lcm(const ex &a, const ex &b, bool check_args)
  *  Yun's algorithm.  Used internally by sqrfree().
  *
  *  @param a  multivariate polynomial over Z[X], treated here as univariate
- *            polynomial in x.
+ *            polynomial in x (needs not be expanded).
  *  @param x  variable to factor in
- *  @return   vector of factors sorted in ascending degree */
-static exvector sqrfree_yun(const ex &a, const symbol &x)
+ *  @return   vector of expairs (factor, exponent), sorted by exponents */
+static epvector sqrfree_yun(const ex &a, const symbol &x)
 {
-       exvector res;
        ex w = a;
        ex z = w.diff(x);
        ex g = gcd(w, z);
+       if (g.is_zero()) {
+               return epvector{};
+       }
        if (g.is_equal(_ex1)) {
-               res.push_back(a);
-               return res;
+               return epvector{expair(a, _ex1)};
        }
-       ex y;
+       epvector results;
+       ex exponent = _ex0;
        do {
                w = quo(w, g, x);
-               y = quo(z, g, x);
-               z = y - w.diff(x);
+               if (w.is_zero()) {
+                       return res;
+               }
+               z = quo(z, g, x) - w.diff(x);
+               exponent = exponent + 1;
+               if (w.is_equal(x)) {
+                       // shortcut for x^n with n ∈ ℕ
+                       exponent += quo(z, w.diff(x), x);
+                       results.push_back(expair(w, exponent));
+                       break;
+               }
                g = gcd(w, z);
-               res.push_back(g);
+               if (!g.is_equal(_ex1)) {
+                       results.push_back(expair(g, exponent));
+               }
        } while (!z.is_zero());
-       return res;
+       return results;
 }
 
 
 /** Compute a square-free factorization of a multivariate polynomial in Q[X].
  *
- *  @param a  multivariate polynomial over Q[X]
+ *  @param a  multivariate polynomial over Q[X] (needs not be expanded)
  *  @param l  lst of variables to factor in, may be left empty for autodetection
  *  @return   a square-free factorization of \p a.
  *
@@ -1833,8 +1858,8 @@ static exvector sqrfree_yun(const ex &a, const symbol &x)
  */
 ex sqrfree(const ex &a, const lst &l)
 {
-       if (is_exactly_a<numeric>(a) ||     // algorithm does not trap a==0
-           is_a<symbol>(a))        // shortcut
+       if (is_exactly_a<numeric>(a) ||
+           is_a<symbol>(a))        // shortcuts
                return a;
 
        // If no lst of variables to factorize in was specified we have to
@@ -1860,30 +1885,28 @@ ex sqrfree(const ex &a, const lst &l)
        const ex tmp = multiply_lcm(a,lcm);
 
        // find the factors
-       exvector factors = sqrfree_yun(tmp, x);
+       epvector factors = sqrfree_yun(tmp, x);
 
-       // construct the next list of symbols with the first element popped
-       lst newargs = args;
-       newargs.remove_first();
+       // remove symbol x and proceed recursively with the remaining symbols
+       args.remove_first();
 
        // recurse down the factors in remaining variables
-       if (newargs.nops()>0) {
+       if (args.nops()>0) {
                for (auto & it : factors)
-                       it = sqrfree(it, newargs);
+                       it.rest = sqrfree(it.rest, args);
        }
 
        // Done with recursion, now construct the final result
        ex result = _ex1;
-       int p = 1;
        for (auto & it : factors)
-               result *= pow(it, p++);
+               result *= pow(it.rest, it.coeff);
 
        // Yun's algorithm does not account for constant factors.  (For univariate
        // polynomials it works only in the monic case.)  We can correct this by
        // inserting what has been lost back into the result.  For completeness
        // we'll also have to recurse down that factor in the remaining variables.
-       if (newargs.nops()>0)
-               result *= sqrfree(quo(tmp, result, x), newargs);
+       if (args.nops()>0)
+               result *= sqrfree(quo(tmp, result, x), args);
        else
                result *= quo(tmp, result, x);
 
@@ -1911,24 +1934,21 @@ ex sqrfree_parfrac(const ex & a, const symbol & x)
 //clog << "red_poly = " << red_poly << ", red_numer = " << red_numer << endl;
 
        // Factorize denominator and compute cofactors
-       exvector yun = sqrfree_yun(denom, x);
-//clog << "yun factors: " << exprseq(yun) << endl;
-       size_t num_yun = yun.size();
-       exvector factor; factor.reserve(num_yun);
-       exvector cofac; cofac.reserve(num_yun);
-       for (size_t i=0; i<num_yun; i++) {
-               if (!yun[i].is_equal(_ex1)) {
-                       for (size_t j=0; j<=i; j++) {
-                               factor.push_back(pow(yun[i], j+1));
-                               ex prod = _ex1;
-                               for (size_t k=0; k<num_yun; k++) {
-                                       if (k == i)
-                                               prod *= pow(yun[k], i-j);
-                                       else
-                                               prod *= pow(yun[k], k+1);
-                               }
-                               cofac.push_back(prod.expand());
+       epvector yun = sqrfree_yun(denom, x);
+       size_t yun_max_exponent = yun.empty() ? 0 : ex_to<numeric>(yun.back().coeff).to_int();
+       exvector factor, cofac;
+       for (size_t i=0; i<yun.size(); i++) {
+               numeric i_exponent = ex_to<numeric>(yun[i].coeff);
+               for (size_t j=0; j<i_exponent; j++) {
+                       factor.push_back(pow(yun[i].rest, j+1));
+                       ex prod = _ex1;
+                       for (size_t k=0; k<yun.size(); k++) {
+                               if (yun[k].coeff == i_exponent)
+                                       prod *= pow(yun[k].rest, i_exponent-1-j);
+                               else
+                                       prod *= pow(yun[k].rest, yun[k].coeff);
                        }
+                       cofac.push_back(prod.expand());
                }
        }
        size_t num_factors = factor.size();
@@ -2024,34 +2044,25 @@ static ex replace_with_symbol(const ex & e, exmap & repl)
 
 /** Function object to be applied by basic::normal(). */
 struct normal_map_function : public map_function {
-       int level;
-       normal_map_function(int l) : level(l) {}
-       ex operator()(const ex & e) override { return normal(e, level); }
+       ex operator()(const ex & e) override { return normal(e); }
 };
 
 /** Default implementation of ex::normal(). It normalizes the children and
  *  replaces the object with a temporary symbol.
  *  @see ex::normal */
-ex basic::normal(exmap & repl, exmap & rev_lookup, int level) const
+ex basic::normal(exmap & repl, exmap & rev_lookup) const
 {
        if (nops() == 0)
                return dynallocate<lst>({replace_with_symbol(*this, repl, rev_lookup), _ex1});
-       else {
-               if (level == 1)
-                       return dynallocate<lst>({replace_with_symbol(*this, repl, rev_lookup), _ex1});
-               else if (level == -max_recursion_level)
-                       throw(std::runtime_error("max recursion level reached"));
-               else {
-                       normal_map_function map_normal(level - 1);
-                       return dynallocate<lst>({replace_with_symbol(map(map_normal), repl, rev_lookup), _ex1});
-               }
-       }
+
+       normal_map_function map_normal;
+       return dynallocate<lst>({replace_with_symbol(map(map_normal), repl, rev_lookup), _ex1});
 }
 
 
 /** Implementation of ex::normal() for symbols. This returns the unmodified symbol.
  *  @see ex::normal */
-ex symbol::normal(exmap & repl, exmap & rev_lookup, int level) const
+ex symbol::normal(exmap & repl, exmap & rev_lookup) const
 {
        return dynallocate<lst>({*this, _ex1});
 }
@@ -2061,7 +2072,7 @@ ex symbol::normal(exmap & repl, exmap & rev_lookup, int level) const
  *  into re+I*im and replaces I and non-rational real numbers with a temporary
  *  symbol.
  *  @see ex::normal */
-ex numeric::normal(exmap & repl, exmap & rev_lookup, int level) const
+ex numeric::normal(exmap & repl, exmap & rev_lookup) const
 {
        numeric num = numer();
        ex numex = num;
@@ -2145,23 +2156,18 @@ static ex frac_cancel(const ex &n, const ex &d)
 /** Implementation of ex::normal() for a sum. It expands terms and performs
  *  fractional addition.
  *  @see ex::normal */
-ex add::normal(exmap & repl, exmap & rev_lookup, int level) const
+ex add::normal(exmap & repl, exmap & rev_lookup) const
 {
-       if (level == 1)
-               return dynallocate<lst>({replace_with_symbol(*this, repl, rev_lookup), _ex1});
-       else if (level == -max_recursion_level)
-               throw(std::runtime_error("max recursion level reached"));
-
        // Normalize children and split each one into numerator and denominator
        exvector nums, dens;
        nums.reserve(seq.size()+1);
        dens.reserve(seq.size()+1);
        for (auto & it : seq) {
-               ex n = ex_to<basic>(recombine_pair_to_ex(it)).normal(repl, rev_lookup, level-1);
+               ex n = ex_to<basic>(recombine_pair_to_ex(it)).normal(repl, rev_lookup);
                nums.push_back(n.op(0));
                dens.push_back(n.op(1));
        }
-       ex n = ex_to<numeric>(overall_coeff).normal(repl, rev_lookup, level-1);
+       ex n = ex_to<numeric>(overall_coeff).normal(repl, rev_lookup);
        nums.push_back(n.op(0));
        dens.push_back(n.op(1));
        GINAC_ASSERT(nums.size() == dens.size());
@@ -2202,23 +2208,18 @@ ex add::normal(exmap & repl, exmap & rev_lookup, int level) const
 /** Implementation of ex::normal() for a product. It cancels common factors
  *  from fractions.
  *  @see ex::normal() */
-ex mul::normal(exmap & repl, exmap & rev_lookup, int level) const
+ex mul::normal(exmap & repl, exmap & rev_lookup) const
 {
-       if (level == 1)
-               return dynallocate<lst>({replace_with_symbol(*this, repl, rev_lookup), _ex1});
-       else if (level == -max_recursion_level)
-               throw(std::runtime_error("max recursion level reached"));
-
        // Normalize children, separate into numerator and denominator
        exvector num; num.reserve(seq.size());
        exvector den; den.reserve(seq.size());
        ex n;
        for (auto & it : seq) {
-               n = ex_to<basic>(recombine_pair_to_ex(it)).normal(repl, rev_lookup, level-1);
+               n = ex_to<basic>(recombine_pair_to_ex(it)).normal(repl, rev_lookup);
                num.push_back(n.op(0));
                den.push_back(n.op(1));
        }
-       n = ex_to<numeric>(overall_coeff).normal(repl, rev_lookup, level-1);
+       n = ex_to<numeric>(overall_coeff).normal(repl, rev_lookup);
        num.push_back(n.op(0));
        den.push_back(n.op(1));
 
@@ -2231,16 +2232,11 @@ ex mul::normal(exmap & repl, exmap & rev_lookup, int level) const
  *  distributes integer exponents to numerator and denominator, and replaces
  *  non-integer powers by temporary symbols.
  *  @see ex::normal */
-ex power::normal(exmap & repl, exmap & rev_lookup, int level) const
+ex power::normal(exmap & repl, exmap & rev_lookup) const
 {
-       if (level == 1)
-               return dynallocate<lst>({replace_with_symbol(*this, repl, rev_lookup), _ex1});
-       else if (level == -max_recursion_level)
-               throw(std::runtime_error("max recursion level reached"));
-
        // Normalize basis and exponent (exponent gets reassembled)
-       ex n_basis = ex_to<basic>(basis).normal(repl, rev_lookup, level-1);
-       ex n_exponent = ex_to<basic>(exponent).normal(repl, rev_lookup, level-1);
+       ex n_basis = ex_to<basic>(basis).normal(repl, rev_lookup);
+       ex n_exponent = ex_to<basic>(exponent).normal(repl, rev_lookup);
        n_exponent = n_exponent.op(0) / n_exponent.op(1);
 
        if (n_exponent.info(info_flags::integer)) {
@@ -2286,7 +2282,7 @@ ex power::normal(exmap & repl, exmap & rev_lookup, int level) const
 /** Implementation of ex::normal() for pseries. It normalizes each coefficient
  *  and replaces the series by a temporary symbol.
  *  @see ex::normal */
-ex pseries::normal(exmap & repl, exmap & rev_lookup, int level) const
+ex pseries::normal(exmap & repl, exmap & rev_lookup) const
 {
        epvector newseq;
        for (auto & it : seq) {
@@ -2309,13 +2305,12 @@ ex pseries::normal(exmap & repl, exmap & rev_lookup, int level) const
  *  expression can be treated as a rational function). normal() is applied
  *  recursively to arguments of functions etc.
  *
- *  @param level maximum depth of recursion
  *  @return normalized expression */
-ex ex::normal(int level) const
+ex ex::normal() const
 {
        exmap repl, rev_lookup;
 
-       ex e = bp->normal(repl, rev_lookup, level);
+       ex e = bp->normal(repl, rev_lookup);
        GINAC_ASSERT(is_a<lst>(e));
 
        // Re-insert replaced symbols
@@ -2336,7 +2331,7 @@ ex ex::numer() const
 {
        exmap repl, rev_lookup;
 
-       ex e = bp->normal(repl, rev_lookup, 0);
+       ex e = bp->normal(repl, rev_lookup);
        GINAC_ASSERT(is_a<lst>(e));
 
        // Re-insert replaced symbols
@@ -2356,7 +2351,7 @@ ex ex::denom() const
 {
        exmap repl, rev_lookup;
 
-       ex e = bp->normal(repl, rev_lookup, 0);
+       ex e = bp->normal(repl, rev_lookup);
        GINAC_ASSERT(is_a<lst>(e));
 
        // Re-insert replaced symbols
@@ -2376,7 +2371,7 @@ ex ex::numer_denom() const
 {
        exmap repl, rev_lookup;
 
-       ex e = bp->normal(repl, rev_lookup, 0);
+       ex e = bp->normal(repl, rev_lookup);
        GINAC_ASSERT(is_a<lst>(e));
 
        // Re-insert replaced symbols
@@ -2405,47 +2400,11 @@ ex ex::to_rational(exmap & repl) const
        return bp->to_rational(repl);
 }
 
-// GiNaC 1.1 compatibility function
-ex ex::to_rational(lst & repl_lst) const
-{
-       // Convert lst to exmap
-       exmap m;
-       for (auto & it : repl_lst)
-               m.insert(std::make_pair(it.op(0), it.op(1)));
-
-       ex ret = bp->to_rational(m);
-
-       // Convert exmap back to lst
-       repl_lst.remove_all();
-       for (auto & it : m)
-               repl_lst.append(it.first == it.second);
-
-       return ret;
-}
-
 ex ex::to_polynomial(exmap & repl) const
 {
        return bp->to_polynomial(repl);
 }
 
-// GiNaC 1.1 compatibility function
-ex ex::to_polynomial(lst & repl_lst) const
-{
-       // Convert lst to exmap
-       exmap m;
-       for (auto & it : repl_lst)
-               m.insert(std::make_pair(it.op(0), it.op(1)));
-
-       ex ret = bp->to_polynomial(m);
-
-       // Convert exmap back to lst
-       repl_lst.remove_all();
-       for (auto & it : m)
-               repl_lst.append(it.first == it.second);
-
-       return ret;
-}
-
 /** Default implementation of ex::to_rational(). This replaces the object with
  *  a temporary symbol. */
 ex basic::to_rational(exmap & repl) const