]> www.ginac.de Git - ginac.git/blobdiff - ginac/normal.cpp
- normal() respects the "level" parameter to limit the recursion depth
[ginac.git] / ginac / normal.cpp
index f93ecc8621b9e85e1082e9a100c366f9f8fa7ce8..074759df4db92413e902b5ec6182af426dde48b6 100644 (file)
@@ -489,6 +489,57 @@ ex prem(const ex &a, const ex &b, const symbol &x, bool check_args)
 }
 
 
 }
 
 
+/** Sparse pseudo-remainder of polynomials a(x) and b(x) in Z[x].
+ *
+ *  @param a  first polynomial in x (dividend)
+ *  @param b  second polynomial in x (divisor)
+ *  @param x  a and b are polynomials in x
+ *  @param check_args  check whether a and b are polynomials with rational
+ *         coefficients (defaults to "true")
+ *  @return sparse pseudo-remainder of a(x) and b(x) in Z[x] */
+
+ex sprem(const ex &a, const ex &b, const symbol &x, bool check_args)
+{
+    if (b.is_zero())
+        throw(std::overflow_error("prem: division by zero"));
+    if (is_ex_exactly_of_type(a, numeric)) {
+        if (is_ex_exactly_of_type(b, numeric))
+            return _ex0();
+        else
+            return b;
+    }
+    if (check_args && (!a.info(info_flags::rational_polynomial) || !b.info(info_flags::rational_polynomial)))
+        throw(std::invalid_argument("prem: arguments must be polynomials over the rationals"));
+
+    // Polynomial long division
+    ex r = a.expand();
+    ex eb = b.expand();
+    int rdeg = r.degree(x);
+    int bdeg = eb.degree(x);
+    ex blcoeff;
+    if (bdeg <= rdeg) {
+        blcoeff = eb.coeff(x, bdeg);
+        if (bdeg == 0)
+            eb = _ex0();
+        else
+            eb -= blcoeff * power(x, bdeg);
+    } else
+        blcoeff = _ex1();
+
+    while (rdeg >= bdeg && !r.is_zero()) {
+        ex rlcoeff = r.coeff(x, rdeg);
+        ex term = (power(x, rdeg - bdeg) * eb * rlcoeff).expand();
+        if (rdeg == 0)
+            r = _ex0();
+        else
+            r -= rlcoeff * power(x, rdeg);
+        r = (blcoeff * r).expand() - term;
+        rdeg = r.degree(x);
+    }
+    return r;
+}
+
+
 /** Exact polynomial division of a(X) by b(X) in Q[X].
  *  
  *  @param a  first multivariate polynomial (dividend)
 /** Exact polynomial division of a(X) by b(X) in Q[X].
  *  
  *  @param a  first multivariate polynomial (dividend)
@@ -1797,6 +1848,11 @@ static ex frac_cancel(const ex &n, const ex &d)
  *  @see ex::normal */
 ex add::normal(lst &sym_lst, lst &repl_lst, int level) const
 {
  *  @see ex::normal */
 ex add::normal(lst &sym_lst, lst &repl_lst, int level) const
 {
+       if (level == 1)
+               return (new lst(*this, _ex1()))->setflag(status_flags::dynallocated);
+       else if (level == -max_recursion_level)
+        throw(std::runtime_error("max recursion level reached"));
+
     // Normalize and expand children, chop into summands
     exvector o;
     o.reserve(seq.size()+1);
     // Normalize and expand children, chop into summands
     exvector o;
     o.reserve(seq.size()+1);
@@ -1873,6 +1929,11 @@ ex add::normal(lst &sym_lst, lst &repl_lst, int level) const
  *  @see ex::normal() */
 ex mul::normal(lst &sym_lst, lst &repl_lst, int level) const
 {
  *  @see ex::normal() */
 ex mul::normal(lst &sym_lst, lst &repl_lst, int level) const
 {
+       if (level == 1)
+               return (new lst(*this, _ex1()))->setflag(status_flags::dynallocated);
+       else if (level == -max_recursion_level)
+        throw(std::runtime_error("max recursion level reached"));
+
     // Normalize children, separate into numerator and denominator
        ex num = _ex1();
        ex den = _ex1(); 
     // Normalize children, separate into numerator and denominator
        ex num = _ex1();
        ex den = _ex1(); 
@@ -1899,6 +1960,11 @@ ex mul::normal(lst &sym_lst, lst &repl_lst, int level) const
  *  @see ex::normal */
 ex power::normal(lst &sym_lst, lst &repl_lst, int level) const
 {
  *  @see ex::normal */
 ex power::normal(lst &sym_lst, lst &repl_lst, int level) const
 {
+       if (level == 1)
+               return (new lst(*this, _ex1()))->setflag(status_flags::dynallocated);
+       else if (level == -max_recursion_level)
+        throw(std::runtime_error("max recursion level reached"));
+
        // Normalize basis
     ex n = basis.bp->normal(sym_lst, repl_lst, level-1);
 
        // Normalize basis
     ex n = basis.bp->normal(sym_lst, repl_lst, level-1);