#include "numeric.h"
#include "power.h"
#include "relational.h"
+#include "matrix.h"
#include "pseries.h"
#include "symbol.h"
#include "utils.h"
if (is_ex_exactly_of_type(b, numeric))
return _ex0();
else
- return b;
+ return a;
}
#if FAST_COMPARE
if (a.is_equal(b))
}
+/** Decompose rational function a(x)=N(x)/D(x) into P(x)+n(x)/D(x)
+ * with degree(n, x) < degree(D, x).
+ *
+ * @param a rational function in x
+ * @param x a is a function of x
+ * @return decomposed function. */
+ex decomp_rational(const ex &a, const symbol &x)
+{
+ ex nd = numer_denom(a);
+ ex numer = nd.op(0), denom = nd.op(1);
+ ex q = quo(numer, denom, x);
+ if (is_ex_exactly_of_type(q, fail))
+ return a;
+ else
+ return q + rem(numer, denom, x) / denom;
+}
+
+
/** Pseudo-remainder of polynomials a(x) and b(x) in Z[x].
*
* @param a first polynomial in x (dividend)
} while (!z.is_zero());
return res;
}
+
/** Compute square-free factorization of multivariate polynomial in Q[X].
*
* @param a multivariate polynomial over Q[X]
return result * lcm.inverse();
}
+/** Compute square-free partial fraction decomposition of rational function
+ * a(x).
+ *
+ * @param a rational function over Z[x], treated as univariate polynomial
+ * in x
+ * @param x variable to factor in
+ * @return decomposed rational function */
+ex sqrfree_parfrac(const ex & a, const symbol & x)
+{
+ // Find numerator and denominator
+ ex nd = numer_denom(a);
+ ex numer = nd.op(0), denom = nd.op(1);
+//clog << "numer = " << numer << ", denom = " << denom << endl;
+
+ // Convert N(x)/D(x) -> Q(x) + R(x)/D(x), so degree(R) < degree(D)
+ ex red_poly = quo(numer, denom, x), red_numer = rem(numer, denom, x).expand();
+//clog << "red_poly = " << red_poly << ", red_numer = " << red_numer << endl;
+
+ // Factorize denominator and compute cofactors
+ exvector yun = sqrfree_yun(denom, x);
+//clog << "yun factors: " << exprseq(yun) << endl;
+ int num_yun = yun.size();
+ exvector factor; factor.reserve(num_yun);
+ exvector cofac; cofac.reserve(num_yun);
+ for (unsigned i=0; i<num_yun; i++) {
+ if (!yun[i].is_equal(_ex1())) {
+ for (unsigned j=0; j<=i; j++) {
+ factor.push_back(pow(yun[i], j+1));
+ ex prod = 1;
+ for (unsigned k=0; k<num_yun; k++) {
+ if (k == i)
+ prod *= pow(yun[k], i-j);
+ else
+ prod *= pow(yun[k], k+1);
+ }
+ cofac.push_back(prod.expand());
+ }
+ }
+ }
+ int num_factors = factor.size();
+//clog << "factors : " << exprseq(factor) << endl;
+//clog << "cofactors: " << exprseq(cofac) << endl;
+
+ // Construct coefficient matrix for decomposition
+ int max_denom_deg = denom.degree(x);
+ matrix sys(max_denom_deg + 1, num_factors);
+ matrix rhs(max_denom_deg + 1, 1);
+ for (unsigned i=0; i<=max_denom_deg; i++) {
+ for (unsigned j=0; j<num_factors; j++)
+ sys(i, j) = cofac[j].coeff(x, i);
+ rhs(i, 0) = red_numer.coeff(x, i);
+ }
+//clog << "coeffs: " << sys << endl;
+//clog << "rhs : " << rhs << endl;
+
+ // Solve resulting linear system
+ matrix vars(num_factors, 1);
+ for (unsigned i=0; i<num_factors; i++)
+ vars(i, 0) = symbol();
+ matrix sol = sys.solve(vars, rhs);
+
+ // Sum up decomposed fractions
+ ex sum = 0;
+ for (unsigned i=0; i<num_factors; i++)
+ sum += sol(i, 0) / factor[i];
+
+ return red_poly + sum;
+}
+
/*
* Normal form of rational functions
* the information that (a+b) is the numerator and 3 is the denominator.
*/
+
/** Create a symbol for replacing the expression "e" (or return a previously
* assigned symbol). The symbol is appended to sym_lst and returned, the
* expression is appended to repl_lst.
return es;
}
-/** Default implementation of ex::normal(). It replaces the object with a
- * temporary symbol.
+
+/** Function object to be applied by basic::normal(). */
+struct normal_map_function : public map_function {
+ int level;
+ normal_map_function(int l) : level(l) {}
+ ex operator()(const ex & e) { return normal(e, level); }
+};
+
+/** Default implementation of ex::normal(). It normalizes the children and
+ * replaces the object with a temporary symbol.
* @see ex::normal */
ex basic::normal(lst &sym_lst, lst &repl_lst, int level) const
{
- return (new lst(replace_with_symbol(*this, sym_lst, repl_lst), _ex1()))->setflag(status_flags::dynallocated);
+ if (nops() == 0)
+ return (new lst(replace_with_symbol(*this, sym_lst, repl_lst), _ex1()))->setflag(status_flags::dynallocated);
+ else {
+ if (level == 1)
+ return (new lst(replace_with_symbol(*this, sym_lst, repl_lst), _ex1()))->setflag(status_flags::dynallocated);
+ else if (level == -max_recursion_level)
+ throw(std::runtime_error("max recursion level reached"));
+ else {
+ normal_map_function map_normal(level - 1);
+ return (new lst(replace_with_symbol(map(map_normal), sym_lst, repl_lst), _ex1()))->setflag(status_flags::dynallocated);
+ }
+ }
}
}
-/** Implementation of ex::normal() for relationals. It normalizes both sides.
- * @see ex::normal */
-ex relational::normal(lst &sym_lst, lst &repl_lst, int level) const
-{
- return (new lst(relational(lh.normal(), rh.normal(), o), _ex1()))->setflag(status_flags::dynallocated);
-}
-
-
/** Normalization of rational functions.
* This function converts an expression to its normal form
* "numerator/denominator", where numerator and denominator are (relatively