X-Git-Url: https://www.ginac.de/ginac.git//ginac.git?p=ginac.git;a=blobdiff_plain;f=ginac%2Fpower.cpp;h=c3f348da095c4d4f0f824de7dd3ae99332a1338b;hp=a459150dbe4d5a692bf0a47e05ee13e8285f97a8;hb=55b0f861ce3676061b8f531c97fd34046875581d;hpb=487e5659efe401683eee0381b0d23f967ffffc3c diff --git a/ginac/power.cpp b/ginac/power.cpp index a459150d..b9090f81 100644 --- a/ginac/power.cpp +++ b/ginac/power.cpp @@ -3,7 +3,7 @@ * Implementation of GiNaC's symbolic exponentiation (basis^exponent). */ /* - * GiNaC Copyright (C) 1999 Johannes Gutenberg University Mainz, Germany + * GiNaC Copyright (C) 1999-2015 Johannes Gutenberg University Mainz, Germany * * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by @@ -17,419 +17,899 @@ * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software - * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA */ -#include -#include -#include - #include "power.h" #include "expairseq.h" #include "add.h" #include "mul.h" +#include "ncmul.h" #include "numeric.h" -#include "relational.h" +#include "constant.h" +#include "operators.h" +#include "inifcns.h" // for log() in power::derivative() +#include "matrix.h" +#include "indexed.h" #include "symbol.h" +#include "lst.h" +#include "archive.h" +#include "utils.h" +#include "relational.h" +#include "compiler.h" -typedef vector intvector; +#include +#include +#include +#include + +namespace GiNaC { + +GINAC_IMPLEMENT_REGISTERED_CLASS_OPT(power, basic, + print_func(&power::do_print_dflt). + print_func(&power::do_print_latex). + print_func(&power::do_print_csrc). + print_func(&power::do_print_python). + print_func(&power::do_print_python_repr). + print_func(&power::do_print_csrc_cl_N)) ////////// -// default constructor, destructor, copy constructor assignment operator and helpers +// default constructor ////////// -// public +power::power() { } -power::power() : basic(TINFO_power) +////////// +// other constructors +////////// + +// all inlined + +////////// +// archiving +////////// + +void power::read_archive(const archive_node &n, lst &sym_lst) { - debugmsg("power default constructor",LOGLEVEL_CONSTRUCT); + inherited::read_archive(n, sym_lst); + n.find_ex("basis", basis, sym_lst); + n.find_ex("exponent", exponent, sym_lst); } -power::~power() +void power::archive(archive_node &n) const { - debugmsg("power destructor",LOGLEVEL_DESTRUCT); - destroy(0); + inherited::archive(n); + n.add_ex("basis", basis); + n.add_ex("exponent", exponent); } -power::power(power const & other) +////////// +// functions overriding virtual functions from base classes +////////// + +// public + +void power::print_power(const print_context & c, const char *powersymbol, const char *openbrace, const char *closebrace, unsigned level) const { - debugmsg("power copy constructor",LOGLEVEL_CONSTRUCT); - copy(other); + // Ordinary output of powers using '^' or '**' + if (precedence() <= level) + c.s << openbrace << '('; + basis.print(c, precedence()); + c.s << powersymbol; + c.s << openbrace; + exponent.print(c, precedence()); + c.s << closebrace; + if (precedence() <= level) + c.s << ')' << closebrace; } -power const & power::operator=(power const & other) +void power::do_print_dflt(const print_dflt & c, unsigned level) const { - debugmsg("power operator=",LOGLEVEL_ASSIGNMENT); - if (this != &other) { - destroy(1); - copy(other); - } - return *this; -} + if (exponent.is_equal(_ex1_2)) { -// protected + // Square roots are printed in a special way + c.s << "sqrt("; + basis.print(c); + c.s << ')'; -void power::copy(power const & other) -{ - basic::copy(other); - basis=other.basis; - exponent=other.exponent; + } else + print_power(c, "^", "", "", level); } -void power::destroy(bool call_parent) +void power::do_print_latex(const print_latex & c, unsigned level) const { - if (call_parent) basic::destroy(call_parent); -} + if (is_exactly_a(exponent) && ex_to(exponent).is_negative()) { -////////// -// other constructors -////////// + // Powers with negative numeric exponents are printed as fractions + c.s << "\\frac{1}{"; + power(basis, -exponent).eval().print(c); + c.s << '}'; -// public + } else if (exponent.is_equal(_ex1_2)) { + + // Square roots are printed in a special way + c.s << "\\sqrt{"; + basis.print(c); + c.s << '}'; -power::power(ex const & lh, ex const & rh) : basic(TINFO_power), basis(lh), exponent(rh) + } else + print_power(c, "^", "{", "}", level); +} + +static void print_sym_pow(const print_context & c, const symbol &x, int exp) { - debugmsg("power constructor from ex,ex",LOGLEVEL_CONSTRUCT); - ASSERT(basis.return_type()==return_types::commutative); + // Optimal output of integer powers of symbols to aid compiler CSE. + // C.f. ISO/IEC 14882:2011, section 1.9 [intro execution], paragraph 15 + // to learn why such a parenthesation is really necessary. + if (exp == 1) { + x.print(c); + } else if (exp == 2) { + x.print(c); + c.s << "*"; + x.print(c); + } else if (exp & 1) { + x.print(c); + c.s << "*"; + print_sym_pow(c, x, exp-1); + } else { + c.s << "("; + print_sym_pow(c, x, exp >> 1); + c.s << ")*("; + print_sym_pow(c, x, exp >> 1); + c.s << ")"; + } } -power::power(ex const & lh, numeric const & rh) : basic(TINFO_power), basis(lh), exponent(rh) +void power::do_print_csrc_cl_N(const print_csrc_cl_N& c, unsigned level) const { - debugmsg("power constructor from ex,numeric",LOGLEVEL_CONSTRUCT); - ASSERT(basis.return_type()==return_types::commutative); + if (exponent.is_equal(_ex_1)) { + c.s << "recip("; + basis.print(c); + c.s << ')'; + return; + } + c.s << "expt("; + basis.print(c); + c.s << ", "; + exponent.print(c); + c.s << ')'; } -////////// -// functions overriding virtual functions from bases classes -////////// +void power::do_print_csrc(const print_csrc & c, unsigned level) const +{ + // Integer powers of symbols are printed in a special, optimized way + if (exponent.info(info_flags::integer) && + (is_a(basis) || is_a(basis))) { + int exp = ex_to(exponent).to_int(); + if (exp > 0) + c.s << '('; + else { + exp = -exp; + c.s << "1.0/("; + } + print_sym_pow(c, ex_to(basis), exp); + c.s << ')'; + + // ^-1 is printed as "1.0/" or with the recip() function of CLN + } else if (exponent.is_equal(_ex_1)) { + c.s << "1.0/("; + basis.print(c); + c.s << ')'; + + // Otherwise, use the pow() function + } else { + c.s << "pow("; + basis.print(c); + c.s << ','; + exponent.print(c); + c.s << ')'; + } +} -// public +void power::do_print_python(const print_python & c, unsigned level) const +{ + print_power(c, "**", "", "", level); +} -basic * power::duplicate() const +void power::do_print_python_repr(const print_python_repr & c, unsigned level) const { - debugmsg("power duplicate",LOGLEVEL_DUPLICATE); - return new power(*this); + c.s << class_name() << '('; + basis.print(c); + c.s << ','; + exponent.print(c); + c.s << ')'; } bool power::info(unsigned inf) const { - if (inf==info_flags::polynomial || inf==info_flags::integer_polynomial || inf==info_flags::rational_polynomial) { - return exponent.info(info_flags::nonnegint); - } else if (inf==info_flags::rational_function) { - return exponent.info(info_flags::integer); - } else { - return basic::info(inf); - } + switch (inf) { + case info_flags::polynomial: + case info_flags::integer_polynomial: + case info_flags::cinteger_polynomial: + case info_flags::rational_polynomial: + case info_flags::crational_polynomial: + return exponent.info(info_flags::nonnegint) && + basis.info(inf); + case info_flags::rational_function: + return exponent.info(info_flags::integer) && + basis.info(inf); + case info_flags::algebraic: + return !exponent.info(info_flags::integer) || + basis.info(inf); + case info_flags::expanded: + return (flags & status_flags::expanded); + case info_flags::positive: + return basis.info(info_flags::positive) && exponent.info(info_flags::real); + case info_flags::nonnegative: + return (basis.info(info_flags::positive) && exponent.info(info_flags::real)) || + (basis.info(info_flags::real) && exponent.info(info_flags::integer) && exponent.info(info_flags::even)); + case info_flags::has_indices: { + if (flags & status_flags::has_indices) + return true; + else if (flags & status_flags::has_no_indices) + return false; + else if (basis.info(info_flags::has_indices)) { + setflag(status_flags::has_indices); + clearflag(status_flags::has_no_indices); + return true; + } else { + clearflag(status_flags::has_indices); + setflag(status_flags::has_no_indices); + return false; + } + } + } + return inherited::info(inf); } -int power::nops() const +size_t power::nops() const { - return 2; + return 2; } -ex & power::let_op(int const i) +ex power::op(size_t i) const { - ASSERT(i>=0); - ASSERT(i<2); + GINAC_ASSERT(i<2); - return i==0 ? basis : exponent; + return i==0 ? basis : exponent; } -int power::degree(symbol const & s) const +ex power::map(map_function & f) const { - if (is_exactly_of_type(*exponent.bp,numeric)) { - if ((*basis.bp).compare(s)==0) - return ex_to_numeric(exponent).to_int(); - else - return basis.degree(s) * ex_to_numeric(exponent).to_int(); - } - return 0; + const ex &mapped_basis = f(basis); + const ex &mapped_exponent = f(exponent); + + if (!are_ex_trivially_equal(basis, mapped_basis) + || !are_ex_trivially_equal(exponent, mapped_exponent)) + return (new power(mapped_basis, mapped_exponent))->setflag(status_flags::dynallocated); + else + return *this; } -int power::ldegree(symbol const & s) const +bool power::is_polynomial(const ex & var) const { - if (is_exactly_of_type(*exponent.bp,numeric)) { - if ((*basis.bp).compare(s)==0) - return ex_to_numeric(exponent).to_int(); - else - return basis.ldegree(s) * ex_to_numeric(exponent).to_int(); - } - return 0; + if (basis.is_polynomial(var)) { + if (basis.has(var)) + // basis is non-constant polynomial in var + return exponent.info(info_flags::nonnegint); + else + // basis is constant in var + return !exponent.has(var); + } + // basis is a non-polynomial function of var + return false; } -ex power::coeff(symbol const & s, int const n) const +int power::degree(const ex & s) const { - if ((*basis.bp).compare(s)!=0) { - // basis not equal to s - if (n==0) { - return *this; - } else { - return exZERO(); - } - } else if (is_exactly_of_type(*exponent.bp,numeric)&& - (static_cast(*exponent.bp).compare(numeric(n))==0)) { - return exONE(); - } + if (is_equal(ex_to(s))) + return 1; + else if (is_exactly_a(exponent) && ex_to(exponent).is_integer()) { + if (basis.is_equal(s)) + return ex_to(exponent).to_int(); + else + return basis.degree(s) * ex_to(exponent).to_int(); + } else if (basis.has(s)) + throw(std::runtime_error("power::degree(): undefined degree because of non-integer exponent")); + else + return 0; +} - return exZERO(); +int power::ldegree(const ex & s) const +{ + if (is_equal(ex_to(s))) + return 1; + else if (is_exactly_a(exponent) && ex_to(exponent).is_integer()) { + if (basis.is_equal(s)) + return ex_to(exponent).to_int(); + else + return basis.ldegree(s) * ex_to(exponent).to_int(); + } else if (basis.has(s)) + throw(std::runtime_error("power::ldegree(): undefined degree because of non-integer exponent")); + else + return 0; } +ex power::coeff(const ex & s, int n) const +{ + if (is_equal(ex_to(s))) + return n==1 ? _ex1 : _ex0; + else if (!basis.is_equal(s)) { + // basis not equal to s + if (n == 0) + return *this; + else + return _ex0; + } else { + // basis equal to s + if (is_exactly_a(exponent) && ex_to(exponent).is_integer()) { + // integer exponent + int int_exp = ex_to(exponent).to_int(); + if (n == int_exp) + return _ex1; + else + return _ex0; + } else { + // non-integer exponents are treated as zero + if (n == 0) + return *this; + else + return _ex0; + } + } +} + +/** Perform automatic term rewriting rules in this class. In the following + * x, x1, x2,... stand for a symbolic variables of type ex and c, c1, c2... + * stand for such expressions that contain a plain number. + * - ^(x,0) -> 1 (also handles ^(0,0)) + * - ^(x,1) -> x + * - ^(0,c) -> 0 or exception (depending on the real part of c) + * - ^(1,x) -> 1 + * - ^(c1,c2) -> *(c1^n,c1^(c2-n)) (so that 0<(c2-n)<1, try to evaluate roots, possibly in numerator and denominator of c1) + * - ^(^(x,c1),c2) -> ^(x,c1*c2) if x is positive and c1 is real. + * - ^(^(x,c1),c2) -> ^(x,c1*c2) (c2 integer or -1 < c1 <= 1 or (c1=-1 and c2>0), case c1=1 should not happen, see below!) + * - ^(*(x,y,z),c) -> *(x^c,y^c,z^c) (if c integer) + * - ^(*(x,c1),c2) -> ^(x,c2)*c1^c2 (c1>0) + * - ^(*(x,c1),c2) -> ^(-x,c2)*c1^c2 (c1<0) + * + * @param level cut-off in recursive evaluation */ ex power::eval(int level) const { - // simplifications: ^(x,0) -> 1 (0^0 handled here) - // ^(x,1) -> x - // ^(0,x) -> 0 (except if x is real and negative, in which case an exception is thrown) - // ^(1,x) -> 1 - // ^(c1,c2) -> *(c1^n,c1^(c2-n)) (c1, c2 numeric(), 0<(c2-n)<1 except if c1,c2 are rational, but c1^c2 is not) - // ^(^(x,c1),c2) -> ^(x,c1*c2) (c1, c2 numeric(), c2 integer or -1 < c1 <= 1, case c1=1 should not happen, see below!) - // ^(*(x,y,z),c1) -> *(x^c1,y^c1,z^c1) (c1 integer) - // ^(*(x,c1),c2) -> ^(x,c2)*c1^c2 (c1, c2 numeric(), c1>0) - // ^(*(x,c1),c2) -> ^(-x,c2)*c1^c2 (c1, c2 numeric(), c1<0) - - debugmsg("power eval",LOGLEVEL_MEMBER_FUNCTION); - - if ((level==1)&&(flags & status_flags::evaluated)) { - return *this; - } else if (level == -max_recursion_level) { - throw(std::runtime_error("max recursion level reached")); - } - - ex const & ebasis = level==1 ? basis : basis.eval(level-1); - ex const & eexponent = level==1 ? exponent : exponent.eval(level-1); - - bool basis_is_numerical=0; - bool exponent_is_numerical=0; - numeric * num_basis; - numeric * num_exponent; - - if (is_exactly_of_type(*ebasis.bp,numeric)) { - basis_is_numerical=1; - num_basis=static_cast(ebasis.bp); - } - if (is_exactly_of_type(*eexponent.bp,numeric)) { - exponent_is_numerical=1; - num_exponent=static_cast(eexponent.bp); - } - - // ^(x,0) -> 1 (0^0 also handled here) - if (eexponent.is_zero()) - return exONE(); - - // ^(x,1) -> x - if (eexponent.is_equal(exONE())) - return ebasis; - - // ^(0,x) -> 0 (except if x is real and negative) - if (ebasis.is_zero()) { - if (exponent_is_numerical && num_exponent->is_negative()) { - throw(std::overflow_error("power::eval(): division by zero")); - } else - return exZERO(); - } - - // ^(1,x) -> 1 - if (ebasis.is_equal(exONE())) - return exONE(); - - if (basis_is_numerical && exponent_is_numerical) { - // ^(c1,c2) -> c1^c2 (c1, c2 numeric(), - // except if c1,c2 are rational, but c1^c2 is not) - bool basis_is_rational = num_basis->is_rational(); - bool exponent_is_rational = num_exponent->is_rational(); - numeric res = (*num_basis).power(*num_exponent); - - if ((!basis_is_rational || !exponent_is_rational) - || res.is_rational()) { - return res; - } - ASSERT(!num_exponent->is_integer()); // has been handled by now - // ^(c1,n/m) -> *(c1^q,c1^(n/m-q)), 0<(n/m-h)<1, q integer - if (basis_is_rational && exponent_is_rational - && num_exponent->is_real() - && !num_exponent->is_integer()) { - numeric r, q, n, m; - n = num_exponent->numer(); - m = num_exponent->denom(); - q = iquo(n, m, r); - if (r.is_negative()) { - r = r.add(m); - q = q.sub(numONE()); - } - if (q.is_zero()) // the exponent was in the allowed range 0<(n/m)<1 - return this->hold(); - else { - epvector res(2); - res.push_back(expair(ebasis,r.div(m))); - res.push_back(expair(ex(num_basis->power(q)),exONE())); - return (new mul(res))->setflag(status_flags::dynallocated | status_flags::evaluated); - /*return mul(num_basis->power(q), - power(ex(*num_basis),ex(r.div(m)))).hold(); - */ - /* return (new mul(num_basis->power(q), - power(*num_basis,r.div(m)).hold()))->setflag(status_flags::dynallocated | status_flags::evaluated); - */ - } - } - } - - // ^(^(x,c1),c2) -> ^(x,c1*c2) - // (c1, c2 numeric(), c2 integer or -1 < c1 <= 1, - // case c1=1 should not happen, see below!) - if (exponent_is_numerical && is_ex_exactly_of_type(ebasis,power)) { - power const & sub_power=ex_to_power(ebasis); - ex const & sub_basis=sub_power.basis; - ex const & sub_exponent=sub_power.exponent; - if (is_ex_exactly_of_type(sub_exponent,numeric)) { - numeric const & num_sub_exponent=ex_to_numeric(sub_exponent); - ASSERT(num_sub_exponent!=numeric(1)); - if (num_exponent->is_integer() || abs(num_sub_exponent)<1) { - return power(sub_basis,num_sub_exponent.mul(*num_exponent)); - } - } - } - - // ^(*(x,y,z),c1) -> *(x^c1,y^c1,z^c1) (c1 integer) - if (exponent_is_numerical && num_exponent->is_integer() && - is_ex_exactly_of_type(ebasis,mul)) { - return expand_mul(ex_to_mul(ebasis), *num_exponent); - } - - // ^(*(...,x;c1),c2) -> ^(*(...,x;1),c2)*c1^c2 (c1, c2 numeric(), c1>0) - // ^(*(...,x,c1),c2) -> ^(*(...,x;-1),c2)*(-c1)^c2 (c1, c2 numeric(), c1<0) - if (exponent_is_numerical && is_ex_exactly_of_type(ebasis,mul)) { - ASSERT(!num_exponent->is_integer()); // should have been handled above - mul const & mulref=ex_to_mul(ebasis); - if (!mulref.overall_coeff.is_equal(exONE())) { - numeric const & num_coeff=ex_to_numeric(mulref.overall_coeff); - if (num_coeff.is_real()) { - if (num_coeff.is_positive()>0) { - mul * mulp=new mul(mulref); - mulp->overall_coeff=exONE(); - mulp->clearflag(status_flags::evaluated); - mulp->clearflag(status_flags::hash_calculated); - return (new mul(power(*mulp,exponent), - power(num_coeff,*num_exponent)))-> - setflag(status_flags::dynallocated); - } else { - ASSERT(num_coeff.compare(numZERO())<0); - if (num_coeff.compare(numMINUSONE())!=0) { - mul * mulp=new mul(mulref); - mulp->overall_coeff=exMINUSONE(); - mulp->clearflag(status_flags::evaluated); - mulp->clearflag(status_flags::hash_calculated); - return (new mul(power(*mulp,exponent), - power(abs(num_coeff),*num_exponent)))-> - setflag(status_flags::dynallocated); - } - } - } - } - } - - if (are_ex_trivially_equal(ebasis,basis) && - are_ex_trivially_equal(eexponent,exponent)) { - return this->hold(); - } - return (new power(ebasis, eexponent))->setflag(status_flags::dynallocated | - status_flags::evaluated); + if ((level==1) && (flags & status_flags::evaluated)) + return *this; + else if (level == -max_recursion_level) + throw(std::runtime_error("max recursion level reached")); + + const ex & ebasis = level==1 ? basis : basis.eval(level-1); + const ex & eexponent = level==1 ? exponent : exponent.eval(level-1); + + const numeric *num_basis = nullptr; + const numeric *num_exponent = nullptr; + + if (is_exactly_a(ebasis)) { + num_basis = &ex_to(ebasis); + } + if (is_exactly_a(eexponent)) { + num_exponent = &ex_to(eexponent); + } + + // ^(x,0) -> 1 (0^0 also handled here) + if (eexponent.is_zero()) { + if (ebasis.is_zero()) + throw (std::domain_error("power::eval(): pow(0,0) is undefined")); + else + return _ex1; + } + + // ^(x,1) -> x + if (eexponent.is_equal(_ex1)) + return ebasis; + + // ^(0,c1) -> 0 or exception (depending on real value of c1) + if ( ebasis.is_zero() && num_exponent ) { + if ((num_exponent->real()).is_zero()) + throw (std::domain_error("power::eval(): pow(0,I) is undefined")); + else if ((num_exponent->real()).is_negative()) + throw (pole_error("power::eval(): division by zero",1)); + else + return _ex0; + } + + // ^(1,x) -> 1 + if (ebasis.is_equal(_ex1)) + return _ex1; + + // power of a function calculated by separate rules defined for this function + if (is_exactly_a(ebasis)) + return ex_to(ebasis).power(eexponent); + + // Turn (x^c)^d into x^(c*d) in the case that x is positive and c is real. + if (is_exactly_a(ebasis) && ebasis.op(0).info(info_flags::positive) && ebasis.op(1).info(info_flags::real)) + return power(ebasis.op(0), ebasis.op(1) * eexponent); + + if ( num_exponent ) { + + // ^(c1,c2) -> c1^c2 (c1, c2 numeric(), + // except if c1,c2 are rational, but c1^c2 is not) + if ( num_basis ) { + const bool basis_is_crational = num_basis->is_crational(); + const bool exponent_is_crational = num_exponent->is_crational(); + if (!basis_is_crational || !exponent_is_crational) { + // return a plain float + return (new numeric(num_basis->power(*num_exponent)))->setflag(status_flags::dynallocated | + status_flags::evaluated | + status_flags::expanded); + } + + const numeric res = num_basis->power(*num_exponent); + if (res.is_crational()) { + return res; + } + GINAC_ASSERT(!num_exponent->is_integer()); // has been handled by now + + // ^(c1,n/m) -> *(c1^q,c1^(n/m-q)), 0<(n/m-q)<1, q integer + if (basis_is_crational && exponent_is_crational + && num_exponent->is_real() + && !num_exponent->is_integer()) { + const numeric n = num_exponent->numer(); + const numeric m = num_exponent->denom(); + numeric r; + numeric q = iquo(n, m, r); + if (r.is_negative()) { + r += m; + --q; + } + if (q.is_zero()) { // the exponent was in the allowed range 0<(n/m)<1 + if (num_basis->is_rational() && !num_basis->is_integer()) { + // try it for numerator and denominator separately, in order to + // partially simplify things like (5/8)^(1/3) -> 1/2*5^(1/3) + const numeric bnum = num_basis->numer(); + const numeric bden = num_basis->denom(); + const numeric res_bnum = bnum.power(*num_exponent); + const numeric res_bden = bden.power(*num_exponent); + if (res_bnum.is_integer()) + return (new mul(power(bden,-*num_exponent),res_bnum))->setflag(status_flags::dynallocated | status_flags::evaluated); + if (res_bden.is_integer()) + return (new mul(power(bnum,*num_exponent),res_bden.inverse()))->setflag(status_flags::dynallocated | status_flags::evaluated); + } + return this->hold(); + } else { + // assemble resulting product, but allowing for a re-evaluation, + // because otherwise we'll end up with something like + // (7/8)^(4/3) -> 7/8*(1/2*7^(1/3)) + // instead of 7/16*7^(1/3). + ex prod = power(*num_basis,r.div(m)); + return prod*power(*num_basis,q); + } + } + } + + // ^(^(x,c1),c2) -> ^(x,c1*c2) + // (c1, c2 numeric(), c2 integer or -1 < c1 <= 1 or (c1=-1 and c2>0), + // case c1==1 should not happen, see below!) + if (is_exactly_a(ebasis)) { + const power & sub_power = ex_to(ebasis); + const ex & sub_basis = sub_power.basis; + const ex & sub_exponent = sub_power.exponent; + if (is_exactly_a(sub_exponent)) { + const numeric & num_sub_exponent = ex_to(sub_exponent); + GINAC_ASSERT(num_sub_exponent!=numeric(1)); + if (num_exponent->is_integer() || (abs(num_sub_exponent) - (*_num1_p)).is_negative() || + (num_sub_exponent == *_num_1_p && num_exponent->is_positive())) { + return power(sub_basis,num_sub_exponent.mul(*num_exponent)); + } + } + } + + // ^(*(x,y,z),c1) -> *(x^c1,y^c1,z^c1) (c1 integer) + if (num_exponent->is_integer() && is_exactly_a(ebasis)) { + return expand_mul(ex_to(ebasis), *num_exponent, 0); + } + + // (2*x + 6*y)^(-4) -> 1/16*(x + 3*y)^(-4) + if (num_exponent->is_integer() && is_exactly_a(ebasis)) { + numeric icont = ebasis.integer_content(); + const numeric lead_coeff = + ex_to(ex_to(ebasis).seq.begin()->coeff).div(icont); + + const bool canonicalizable = lead_coeff.is_integer(); + const bool unit_normal = lead_coeff.is_pos_integer(); + if (canonicalizable && (! unit_normal)) + icont = icont.mul(*_num_1_p); + + if (canonicalizable && (icont != *_num1_p)) { + const add& addref = ex_to(ebasis); + add* addp = new add(addref); + addp->setflag(status_flags::dynallocated); + addp->clearflag(status_flags::hash_calculated); + addp->overall_coeff = ex_to(addp->overall_coeff).div_dyn(icont); + for (auto & i : addp->seq) + i.coeff = ex_to(i.coeff).div_dyn(icont); + + const numeric c = icont.power(*num_exponent); + if (likely(c != *_num1_p)) + return (new mul(power(*addp, *num_exponent), c))->setflag(status_flags::dynallocated); + else + return power(*addp, *num_exponent); + } + } + + // ^(*(...,x;c1),c2) -> *(^(*(...,x;1),c2),c1^c2) (c1, c2 numeric(), c1>0) + // ^(*(...,x;c1),c2) -> *(^(*(...,x;-1),c2),(-c1)^c2) (c1, c2 numeric(), c1<0) + if (is_exactly_a(ebasis)) { + GINAC_ASSERT(!num_exponent->is_integer()); // should have been handled above + const mul & mulref = ex_to(ebasis); + if (!mulref.overall_coeff.is_equal(_ex1)) { + const numeric & num_coeff = ex_to(mulref.overall_coeff); + if (num_coeff.is_real()) { + if (num_coeff.is_positive()) { + mul *mulp = new mul(mulref); + mulp->overall_coeff = _ex1; + mulp->setflag(status_flags::dynallocated); + mulp->clearflag(status_flags::evaluated); + mulp->clearflag(status_flags::hash_calculated); + return (new mul(power(*mulp,exponent), + power(num_coeff,*num_exponent)))->setflag(status_flags::dynallocated); + } else { + GINAC_ASSERT(num_coeff.compare(*_num0_p)<0); + if (!num_coeff.is_equal(*_num_1_p)) { + mul *mulp = new mul(mulref); + mulp->overall_coeff = _ex_1; + mulp->setflag(status_flags::dynallocated); + mulp->clearflag(status_flags::evaluated); + mulp->clearflag(status_flags::hash_calculated); + return (new mul(power(*mulp,exponent), + power(abs(num_coeff),*num_exponent)))->setflag(status_flags::dynallocated); + } + } + } + } + } + + // ^(nc,c1) -> ncmul(nc,nc,...) (c1 positive integer, unless nc is a matrix) + if (num_exponent->is_pos_integer() && + ebasis.return_type() != return_types::commutative && + !is_a(ebasis)) { + return ncmul(exvector(num_exponent->to_int(), ebasis)); + } + } + + if (are_ex_trivially_equal(ebasis,basis) && + are_ex_trivially_equal(eexponent,exponent)) { + return this->hold(); + } + return (new power(ebasis, eexponent))->setflag(status_flags::dynallocated | + status_flags::evaluated); } ex power::evalf(int level) const { - debugmsg("power evalf",LOGLEVEL_MEMBER_FUNCTION); + ex ebasis; + ex eexponent; + + if (level==1) { + ebasis = basis; + eexponent = exponent; + } else if (level == -max_recursion_level) { + throw(std::runtime_error("max recursion level reached")); + } else { + ebasis = basis.evalf(level-1); + if (!is_exactly_a(exponent)) + eexponent = exponent.evalf(level-1); + else + eexponent = exponent; + } + + return power(ebasis,eexponent); +} + +ex power::evalm() const +{ + const ex ebasis = basis.evalm(); + const ex eexponent = exponent.evalm(); + if (is_a(ebasis)) { + if (is_exactly_a(eexponent)) { + return (new matrix(ex_to(ebasis).pow(eexponent)))->setflag(status_flags::dynallocated); + } + } + return (new power(ebasis, eexponent))->setflag(status_flags::dynallocated); +} + +bool power::has(const ex & other, unsigned options) const +{ + if (!(options & has_options::algebraic)) + return basic::has(other, options); + if (!is_a(other)) + return basic::has(other, options); + if (!exponent.info(info_flags::integer) || + !other.op(1).info(info_flags::integer)) + return basic::has(other, options); + if (exponent.info(info_flags::posint) && + other.op(1).info(info_flags::posint) && + ex_to(exponent) > ex_to(other.op(1)) && + basis.match(other.op(0))) + return true; + if (exponent.info(info_flags::negint) && + other.op(1).info(info_flags::negint) && + ex_to(exponent) < ex_to(other.op(1)) && + basis.match(other.op(0))) + return true; + return basic::has(other, options); +} - ex ebasis; - ex eexponent; - - if (level==1) { - ebasis=basis; - eexponent=exponent; - } else if (level == -max_recursion_level) { - throw(std::runtime_error("max recursion level reached")); - } else { - ebasis=basis.evalf(level-1); - eexponent=exponent.evalf(level-1); - } +// from mul.cpp +extern bool tryfactsubs(const ex &, const ex &, int &, exmap&); + +ex power::subs(const exmap & m, unsigned options) const +{ + const ex &subsed_basis = basis.subs(m, options); + const ex &subsed_exponent = exponent.subs(m, options); + + if (!are_ex_trivially_equal(basis, subsed_basis) + || !are_ex_trivially_equal(exponent, subsed_exponent)) + return power(subsed_basis, subsed_exponent).subs_one_level(m, options); + + if (!(options & subs_options::algebraic)) + return subs_one_level(m, options); + + for (auto & it : m) { + int nummatches = std::numeric_limits::max(); + exmap repls; + if (tryfactsubs(*this, it.first, nummatches, repls)) { + ex anum = it.second.subs(repls, subs_options::no_pattern); + ex aden = it.first.subs(repls, subs_options::no_pattern); + ex result = (*this)*power(anum/aden, nummatches); + return (ex_to(result)).subs_one_level(m, options); + } + } + + return subs_one_level(m, options); +} - return power(ebasis,eexponent); +ex power::eval_ncmul(const exvector & v) const +{ + return inherited::eval_ncmul(v); } -ex power::subs(lst const & ls, lst const & lr) const +ex power::conjugate() const { - ex const & subsed_basis=basis.subs(ls,lr); - ex const & subsed_exponent=exponent.subs(ls,lr); + // conjugate(pow(x,y))==pow(conjugate(x),conjugate(y)) unless on the + // branch cut which runs along the negative real axis. + if (basis.info(info_flags::positive)) { + ex newexponent = exponent.conjugate(); + if (are_ex_trivially_equal(exponent, newexponent)) { + return *this; + } + return (new power(basis, newexponent))->setflag(status_flags::dynallocated); + } + if (exponent.info(info_flags::integer)) { + ex newbasis = basis.conjugate(); + if (are_ex_trivially_equal(basis, newbasis)) { + return *this; + } + return (new power(newbasis, exponent))->setflag(status_flags::dynallocated); + } + return conjugate_function(*this).hold(); +} - if (are_ex_trivially_equal(basis,subsed_basis)&& - are_ex_trivially_equal(exponent,subsed_exponent)) { - return *this; - } - - return power(subsed_basis, subsed_exponent); +ex power::real_part() const +{ + // basis == a+I*b, exponent == c+I*d + const ex a = basis.real_part(); + const ex c = exponent.real_part(); + if (basis.is_equal(a) && exponent.is_equal(c)) { + // Re(a^c) + return *this; + } + + const ex b = basis.imag_part(); + if (exponent.info(info_flags::integer)) { + // Re((a+I*b)^c) w/ c ∈ ℤ + long N = ex_to(c).to_long(); + // Use real terms in Binomial expansion to construct + // Re(expand(power(a+I*b, N))). + long NN = N > 0 ? N : -N; + ex numer = N > 0 ? _ex1 : power(power(a,2) + power(b,2), NN); + ex result = 0; + for (long n = 0; n <= NN; n += 2) { + ex term = binomial(NN, n) * power(a, NN-n) * power(b, n) / numer; + if (n % 4 == 0) { + result += term; // sign: I^n w/ n == 4*m + } else { + result -= term; // sign: I^n w/ n == 4*m+2 + } + } + return result; + } + + // Re((a+I*b)^(c+I*d)) + const ex d = exponent.imag_part(); + return power(abs(basis),c)*exp(-d*atan2(b,a))*cos(c*atan2(b,a)+d*log(abs(basis))); } -ex power::simplify_ncmul(exvector const & v) const +ex power::imag_part() const { - return basic::simplify_ncmul(v); + const ex a = basis.real_part(); + const ex c = exponent.real_part(); + if (basis.is_equal(a) && exponent.is_equal(c)) { + // Im(a^c) + return 0; + } + + const ex b = basis.imag_part(); + if (exponent.info(info_flags::integer)) { + // Im((a+I*b)^c) w/ c ∈ ℤ + long N = ex_to(c).to_long(); + // Use imaginary terms in Binomial expansion to construct + // Im(expand(power(a+I*b, N))). + long p = N > 0 ? 1 : 3; // modulus for positive sign + long NN = N > 0 ? N : -N; + ex numer = N > 0 ? _ex1 : power(power(a,2) + power(b,2), NN); + ex result = 0; + for (long n = 1; n <= NN; n += 2) { + ex term = binomial(NN, n) * power(a, NN-n) * power(b, n) / numer; + if (n % 4 == p) { + result += term; // sign: I^n w/ n == 4*m+p + } else { + result -= term; // sign: I^n w/ n == 4*m+2+p + } + } + return result; + } + + // Im((a+I*b)^(c+I*d)) + const ex d = exponent.imag_part(); + return power(abs(basis),c)*exp(-d*atan2(b,a))*sin(c*atan2(b,a)+d*log(abs(basis))); } // protected -int power::compare_same_type(basic const & other) const +/** Implementation of ex::diff() for a power. + * @see ex::diff */ +ex power::derivative(const symbol & s) const { - ASSERT(is_exactly_of_type(other, power)); - power const & o=static_cast(const_cast(other)); + if (is_a(exponent)) { + // D(b^r) = r * b^(r-1) * D(b) (faster than the formula below) + epvector newseq; + newseq.reserve(2); + newseq.push_back(expair(basis, exponent - _ex1)); + newseq.push_back(expair(basis.diff(s), _ex1)); + return mul(std::move(newseq), exponent); + } else { + // D(b^e) = b^e * (D(e)*ln(b) + e*D(b)/b) + return mul(*this, + add(mul(exponent.diff(s), log(basis)), + mul(mul(exponent, basis.diff(s)), power(basis, _ex_1)))); + } +} - int cmpval; - cmpval=basis.compare(o.basis); - if (cmpval==0) { - return exponent.compare(o.exponent); - } - return cmpval; +int power::compare_same_type(const basic & other) const +{ + GINAC_ASSERT(is_exactly_a(other)); + const power &o = static_cast(other); + + int cmpval = basis.compare(o.basis); + if (cmpval) + return cmpval; + else + return exponent.compare(o.exponent); } -unsigned power::return_type(void) const +unsigned power::return_type() const { - return basis.return_type(); + return basis.return_type(); } - -unsigned power::return_type_tinfo(void) const + +return_type_t power::return_type_tinfo() const { - return basis.return_type_tinfo(); + return basis.return_type_tinfo(); } ex power::expand(unsigned options) const { - ex expanded_basis=basis.expand(options); - - if (!is_ex_exactly_of_type(exponent,numeric)|| - !ex_to_numeric(exponent).is_integer()) { - if (are_ex_trivially_equal(basis,expanded_basis)) { - return this->hold(); - } else { - return (new power(expanded_basis,exponent))-> - setflag(status_flags::dynallocated); - } - } - - // integer numeric exponent - numeric const & num_exponent=ex_to_numeric(exponent); - int int_exponent = num_exponent.to_int(); - - if (int_exponent > 0 && is_ex_exactly_of_type(expanded_basis,add)) { - return expand_add(ex_to_add(expanded_basis), int_exponent); - } - - if (is_ex_exactly_of_type(expanded_basis,mul)) { - return expand_mul(ex_to_mul(expanded_basis), num_exponent); - } - - // cannot expand further - if (are_ex_trivially_equal(basis,expanded_basis)) { - return this->hold(); - } else { - return (new power(expanded_basis,exponent))-> - setflag(status_flags::dynallocated); - } + if (is_a(basis) && exponent.info(info_flags::integer)) { + // A special case worth optimizing. + setflag(status_flags::expanded); + return *this; + } + + // (x*p)^c -> x^c * p^c, if p>0 + // makes sense before expanding the basis + if (is_exactly_a(basis) && !basis.info(info_flags::indefinite)) { + const mul &m = ex_to(basis); + exvector prodseq; + epvector powseq; + prodseq.reserve(m.seq.size() + 1); + powseq.reserve(m.seq.size() + 1); + bool possign = true; + + // search for positive/negative factors + for (auto & cit : m.seq) { + ex e=m.recombine_pair_to_ex(cit); + if (e.info(info_flags::positive)) + prodseq.push_back(pow(e, exponent).expand(options)); + else if (e.info(info_flags::negative)) { + prodseq.push_back(pow(-e, exponent).expand(options)); + possign = !possign; + } else + powseq.push_back(cit); + } + + // take care on the numeric coefficient + ex coeff=(possign? _ex1 : _ex_1); + if (m.overall_coeff.info(info_flags::positive) && m.overall_coeff != _ex1) + prodseq.push_back(power(m.overall_coeff, exponent)); + else if (m.overall_coeff.info(info_flags::negative) && m.overall_coeff != _ex_1) + prodseq.push_back(power(-m.overall_coeff, exponent)); + else + coeff *= m.overall_coeff; + + // If positive/negative factors are found, then extract them. + // In either case we set a flag to avoid the second run on a part + // which does not have positive/negative terms. + if (prodseq.size() > 0) { + ex newbasis = coeff*mul(std::move(powseq)); + ex_to(newbasis).setflag(status_flags::purely_indefinite); + return ((new mul(std::move(prodseq)))->setflag(status_flags::dynallocated)*(new power(newbasis, exponent))->setflag(status_flags::dynallocated).expand(options)).expand(options); + } else + ex_to(basis).setflag(status_flags::purely_indefinite); + } + + const ex expanded_basis = basis.expand(options); + const ex expanded_exponent = exponent.expand(options); + + // x^(a+b) -> x^a * x^b + if (is_exactly_a(expanded_exponent)) { + const add &a = ex_to(expanded_exponent); + exvector distrseq; + distrseq.reserve(a.seq.size() + 1); + for (auto & cit : a.seq) { + distrseq.push_back(power(expanded_basis, a.recombine_pair_to_ex(cit))); + } + + // Make sure that e.g. (x+y)^(2+a) expands the (x+y)^2 factor + if (ex_to(a.overall_coeff).is_integer()) { + const numeric &num_exponent = ex_to(a.overall_coeff); + long int_exponent = num_exponent.to_int(); + if (int_exponent > 0 && is_exactly_a(expanded_basis)) + distrseq.push_back(expand_add(ex_to(expanded_basis), int_exponent, options)); + else + distrseq.push_back(power(expanded_basis, a.overall_coeff)); + } else + distrseq.push_back(power(expanded_basis, a.overall_coeff)); + + // Make sure that e.g. (x+y)^(1+a) -> x*(x+y)^a + y*(x+y)^a + ex r = (new mul(distrseq))->setflag(status_flags::dynallocated); + return r.expand(options); + } + + if (!is_exactly_a(expanded_exponent) || + !ex_to(expanded_exponent).is_integer()) { + if (are_ex_trivially_equal(basis,expanded_basis) && are_ex_trivially_equal(exponent,expanded_exponent)) { + return this->hold(); + } else { + return (new power(expanded_basis,expanded_exponent))->setflag(status_flags::dynallocated | (options == 0 ? status_flags::expanded : 0)); + } + } + + // integer numeric exponent + const numeric & num_exponent = ex_to(expanded_exponent); + long int_exponent = num_exponent.to_long(); + + // (x+y)^n, n>0 + if (int_exponent > 0 && is_exactly_a(expanded_basis)) + return expand_add(ex_to(expanded_basis), int_exponent, options); + + // (x*y)^n -> x^n * y^n + if (is_exactly_a(expanded_basis)) + return expand_mul(ex_to(expanded_basis), num_exponent, options, true); + + // cannot expand further + if (are_ex_trivially_equal(basis,expanded_basis) && are_ex_trivially_equal(exponent,expanded_exponent)) + return this->hold(); + else + return (new power(expanded_basis,expanded_exponent))->setflag(status_flags::dynallocated | (options == 0 ? status_flags::expanded : 0)); } ////////// @@ -442,278 +922,449 @@ ex power::expand(unsigned options) const // non-virtual functions in this class ////////// -ex power::expand_add(add const & a, int const n) const -{ - // expand a^n where a is an add and n is an integer - - if (n==2) { - return expand_add_2(a); - } - - int m=a.nops(); - exvector sum; - sum.reserve((n+1)*(m-1)); - intvector k(m-1); - intvector k_cum(m-1); // k_cum[l]:=sum(i=0,l,k[l]); - intvector upper_limit(m-1); - int l; - - for (int l=0; lsetflag(status_flags::dynallocated)); - - // increment k[] - l=m-2; - while ((l>=0)&&((++k[l])>upper_limit[l])) { - k[l]=0; - l--; - } - if (l<0) break; - - // recalc k_cum[] and upper_limit[] - if (l==0) { - k_cum[0]=k[0]; - } else { - k_cum[l]=k_cum[l-1]+k[l]; - } - for (int i=l+1; isetflag(status_flags::dynallocated); -} +namespace { // anonymous namespace for power::expand_add() helpers -/* -ex power::expand_add_2(add const & a) const -{ - // special case: expand a^2 where a is an add - - epvector sum; - sum.reserve((a.seq.size()*(a.seq.size()+1))/2); - epvector::const_iterator last=a.seq.end(); - - for (epvector::const_iterator cit0=a.seq.begin(); cit0!=last; ++cit0) { - ex const & b=a.recombine_pair_to_ex(*cit0); - ASSERT(!is_ex_exactly_of_type(b,add)); - ASSERT(!is_ex_exactly_of_type(b,power)|| - !is_ex_exactly_of_type(ex_to_power(b).exponent,numeric)|| - !ex_to_numeric(ex_to_power(b).exponent).is_pos_integer()); - if (is_ex_exactly_of_type(b,mul)) { - sum.push_back(a.split_ex_to_pair(expand_mul(ex_to_mul(b),numTWO()))); - } else { - sum.push_back(a.split_ex_to_pair((new power(b,exTWO()))-> - setflag(status_flags::dynallocated))); - } - for (epvector::const_iterator cit1=cit0+1; cit1!=last; ++cit1) { - sum.push_back(a.split_ex_to_pair((new mul(a.recombine_pair_to_ex(*cit0), - a.recombine_pair_to_ex(*cit1)))-> - setflag(status_flags::dynallocated), - exTWO())); - } - } - - ASSERT(sum.size()==(a.seq.size()*(a.seq.size()+1))/2); - - return (new add(sum))->setflag(status_flags::dynallocated); -} -*/ - -ex power::expand_add_2(add const & a) const -{ - // special case: expand a^2 where a is an add - - epvector sum; - unsigned a_nops=a.nops(); - sum.reserve((a_nops*(a_nops+1))/2); - epvector::const_iterator last=a.seq.end(); - - // power(+(x,...,z;c),2)=power(+(x,...,z;0),2)+2*c*+(x,...,z;0)+c*c - // first part: ignore overall_coeff and expand other terms - for (epvector::const_iterator cit0=a.seq.begin(); cit0!=last; ++cit0) { - ex const & r=(*cit0).rest; - ex const & c=(*cit0).coeff; - - ASSERT(!is_ex_exactly_of_type(r,add)); - ASSERT(!is_ex_exactly_of_type(r,power)|| - !is_ex_exactly_of_type(ex_to_power(r).exponent,numeric)|| - !ex_to_numeric(ex_to_power(r).exponent).is_pos_integer()|| - !is_ex_exactly_of_type(ex_to_power(r).basis,add)|| - !is_ex_exactly_of_type(ex_to_power(r).basis,mul)|| - !is_ex_exactly_of_type(ex_to_power(r).basis,power)); - - if (are_ex_trivially_equal(c,exONE())) { - if (is_ex_exactly_of_type(r,mul)) { - sum.push_back(expair(expand_mul(ex_to_mul(r),numTWO()),exONE())); - } else { - sum.push_back(expair((new power(r,exTWO()))->setflag(status_flags::dynallocated), - exONE())); - } - } else { - if (is_ex_exactly_of_type(r,mul)) { - sum.push_back(expair(expand_mul(ex_to_mul(r),numTWO()), - ex_to_numeric(c).power_dyn(numTWO()))); - } else { - sum.push_back(expair((new power(r,exTWO()))->setflag(status_flags::dynallocated), - ex_to_numeric(c).power_dyn(numTWO()))); - } - } - - for (epvector::const_iterator cit1=cit0+1; cit1!=last; ++cit1) { - ex const & r1=(*cit1).rest; - ex const & c1=(*cit1).coeff; - sum.push_back(a.combine_ex_with_coeff_to_pair((new mul(r,r1))->setflag(status_flags::dynallocated), - numTWO().mul(ex_to_numeric(c)).mul_dyn(ex_to_numeric(c1)))); - } - } - - ASSERT(sum.size()==(a.seq.size()*(a.seq.size()+1))/2); - - // second part: add terms coming from overall_factor (if != 0) - if (!a.overall_coeff.is_equal(exZERO())) { - for (epvector::const_iterator cit=a.seq.begin(); cit!=a.seq.end(); ++cit) { - sum.push_back(a.combine_pair_with_coeff_to_pair(*cit,ex_to_numeric(a.overall_coeff).mul_dyn(numTWO()))); - } - sum.push_back(expair(ex_to_numeric(a.overall_coeff).power_dyn(numTWO()),exONE())); - } - - ASSERT(sum.size()==(a_nops*(a_nops+1))/2); - - return (new add(sum))->setflag(status_flags::dynallocated); -} - -ex power::expand_mul(mul const & m, numeric const & n) const -{ - // expand m^n where m is a mul and n is and integer - - if (n.is_equal(numZERO())) { - return exONE(); - } - - epvector distrseq; - distrseq.reserve(m.seq.size()); - epvector::const_iterator last=m.seq.end(); - epvector::const_iterator cit=m.seq.begin(); - while (cit!=last) { - if (is_ex_exactly_of_type((*cit).rest,numeric)) { - distrseq.push_back(m.combine_pair_with_coeff_to_pair(*cit,n)); - } else { - // it is safe not to call mul::combine_pair_with_coeff_to_pair() - // since n is an integer - distrseq.push_back(expair((*cit).rest, - ex_to_numeric((*cit).coeff).mul(n))); - } - ++cit; - } - return (new mul(distrseq,ex_to_numeric(m.overall_coeff).power_dyn(n))) - ->setflag(status_flags::dynallocated); +/** Helper class to generate all bounded combinatorial partitions of an integer + * n with exactly m parts (including zero parts) in non-decreasing order. + */ +class partition_generator { +private: + // Partitions n into m parts, not including zero parts. + // (Cf. OEIS sequence A008284; implementation adapted from Jörg Arndt's + // FXT library) + struct mpartition2 + { + // partition: x[1] + x[2] + ... + x[m] = n and sentinel x[0] == 0 + std::vector x; + int n; // n>0 + int m; // 0 partition; // current partition +public: + partition_generator(unsigned n_, unsigned m_) + : mpgen(n_, 1), m(m_), partition(m_) + { } + // returns current partition in non-decreasing order, padded with zeros + const std::vector& current() const + { + for (int i = 0; i < m - mpgen.m; ++i) + partition[i] = 0; // pad with zeros + + for (int i = m - mpgen.m; i < m; ++i) + partition[i] = mpgen.x[i - m + mpgen.m + 1]; + + return partition; + } + bool next() + { + if (!mpgen.next_partition()) { + if (mpgen.m == m || mpgen.m == mpgen.n) + return false; // current is last + // increment number of parts + mpgen = mpartition2(mpgen.n, mpgen.m + 1); + } + return true; + } +}; + +/** Helper class to generate all compositions of a partition of an integer n, + * starting with the compositions which has non-decreasing order. + */ +class composition_generator { +private: + // Generates all distinct permutations of a multiset. + // (Based on Aaron Williams' algorithm 1 from "Loopless Generation of + // Multiset Permutations using a Constant Number of Variables by Prefix + // Shifts." ) + struct coolmulti { + // element of singly linked list + struct element { + int value; + element* next; + element(int val, element* n) + : value(val), next(n) {} + ~element() + { // recurses down to the end of the singly linked list + delete next; + } + }; + element *head, *i, *after_i; + // NB: Partition must be sorted in non-decreasing order. + explicit coolmulti(const std::vector& partition) + : head(nullptr), i(nullptr), after_i(nullptr) + { + for (unsigned n = 0; n < partition.size(); ++n) { + head = new element(partition[n], head); + if (n <= 1) + i = head; + } + after_i = i->next; + } + ~coolmulti() + { // deletes singly linked list + delete head; + } + void next_permutation() + { + element *before_k; + if (after_i->next != nullptr && i->value >= after_i->next->value) + before_k = after_i; + else + before_k = i; + element *k = before_k->next; + before_k->next = k->next; + k->next = head; + if (k->value < head->value) + i = k; + after_i = i->next; + head = k; + } + bool finished() const + { + return after_i->next == nullptr && after_i->value >= head->value; + } + } cmgen; + bool atend; // needed for simplifying iteration over permutations + bool trivial; // likewise, true if all elements are equal + mutable std::vector composition; // current compositions +public: + explicit composition_generator(const std::vector& partition) + : cmgen(partition), atend(false), trivial(true), composition(partition.size()) + { + for (unsigned i=1; i& current() const + { + coolmulti::element* it = cmgen.head; + size_t i = 0; + while (it != nullptr) { + composition[i] = it->value; + it = it->next; + ++i; + } + return composition; + } + bool next() + { + // This ugly contortion is needed because the original coolmulti + // algorithm requires code duplication of the payload procedure, + // one before the loop and one inside it. + if (trivial || atend) + return false; + cmgen.next_permutation(); + atend = cmgen.finished(); + return true; + } +}; + +/** Helper function to compute the multinomial coefficient n!/(p1!*p2!*...*pk!) + * where n = p1+p2+...+pk, i.e. p is a partition of n. + */ +const numeric +multinomial_coefficient(const std::vector & p) +{ + numeric n = 0, d = 1; + for (auto & it : p) { + n += numeric(it); + d *= factorial(numeric(it)); + } + return factorial(numeric(n)) / d; } -/* -ex power::expand_commutative_3(ex const & basis, numeric const & exponent, - unsigned options) const -{ - // obsolete +} // anonymous namespace - exvector distrseq; - epvector splitseq; +/** expand a^n where a is an add and n is a positive integer. + * @see power::expand */ +ex power::expand_add(const add & a, long n, unsigned options) const +{ + // The special case power(+(x,...y;x),2) can be optimized better. + if (n==2) + return expand_add_2(a, options); + + // method: + // + // Consider base as the sum of all symbolic terms and the overall numeric + // coefficient and apply the binomial theorem: + // S = power(+(x,...,z;c),n) + // = power(+(+(x,...,z;0);c),n) + // = sum(binomial(n,k)*power(+(x,...,z;0),k)*c^(n-k), k=1..n) + c^n + // Then, apply the multinomial theorem to expand all power(+(x,...,z;0),k): + // The multinomial theorem is computed by an outer loop over all + // partitions of the exponent and an inner loop over all compositions of + // that partition. This method makes the expansion a combinatorial + // problem and allows us to directly construct the expanded sum and also + // to re-use the multinomial coefficients (since they depend only on the + // partition, not on the composition). + // + // multinomial power(+(x,y,z;0),3) example: + // partition : compositions : multinomial coefficient + // [0,0,3] : [3,0,0],[0,3,0],[0,0,3] : 3!/(3!*0!*0!) = 1 + // [0,1,2] : [2,1,0],[1,2,0],[2,0,1],... : 3!/(2!*1!*0!) = 3 + // [1,1,1] : [1,1,1] : 3!/(1!*1!*1!) = 6 + // => (x + y + z)^3 = + // x^3 + y^3 + z^3 + // + 3*x^2*y + 3*x*y^2 + 3*y^2*z + 3*y*z^2 + 3*x*z^2 + 3*x^2*z + // + 6*x*y*z + // + // multinomial power(+(x,y,z;0),4) example: + // partition : compositions : multinomial coefficient + // [0,0,4] : [4,0,0],[0,4,0],[0,0,4] : 4!/(4!*0!*0!) = 1 + // [0,1,3] : [3,1,0],[1,3,0],[3,0,1],... : 4!/(3!*1!*0!) = 4 + // [0,2,2] : [2,2,0],[2,0,2],[0,2,2] : 4!/(2!*2!*0!) = 6 + // [1,1,2] : [2,1,1],[1,2,1],[1,1,2] : 4!/(2!*1!*1!) = 12 + // (no [1,1,1,1] partition since it has too many parts) + // => (x + y + z)^4 = + // x^4 + y^4 + z^4 + // + 4*x^3*y + 4*x*y^3 + 4*y^3*z + 4*y*z^3 + 4*x*z^3 + 4*x^3*z + // + 6*x^2*y^2 + 6*y^2*z^2 + 6*x^2*z^2 + // + 12*x^2*y*z + 12*x*y^2*z + 12*x*y*z^2 + // + // Summary: + // r = 0 + // for k from 0 to n: + // f = c^(n-k)*binomial(n,k) + // for p in all partitions of n with m parts (including zero parts): + // h = f * multinomial coefficient of p + // for c in all compositions of p: + // t = 1 + // for e in all elements of c: + // t = t * a[e]^e + // r = r + h*t + // return r + + epvector result; + // The number of terms will be the number of combinatorial compositions, + // i.e. the number of unordered arrangements of m nonnegative integers + // which sum up to n. It is frequently written as C_n(m) and directly + // related with binomial coefficients: binomial(n+m-1,m-1). + size_t result_size = binomial(numeric(n+a.nops()-1), numeric(a.nops()-1)).to_long(); + if (!a.overall_coeff.is_zero()) { + // the result's overall_coeff is one of the terms + --result_size; + } + result.reserve(result_size); + + // Iterate over all terms in binomial expansion of + // S = power(+(x,...,z;c),n) + // = sum(binomial(n,k)*power(+(x,...,z;0),k)*c^(n-k), k=1..n) + c^n + for (int k = 1; k <= n; ++k) { + numeric binomial_coefficient; // binomial(n,k)*c^(n-k) + if (a.overall_coeff.is_zero()) { + // degenerate case with zero overall_coeff: + // apply multinomial theorem directly to power(+(x,...z;0),n) + binomial_coefficient = 1; + if (k < n) { + continue; + } + } else { + binomial_coefficient = binomial(numeric(n), numeric(k)) * pow(ex_to(a.overall_coeff), numeric(n-k)); + } + + // Multinomial expansion of power(+(x,...,z;0),k)*c^(n-k): + // Iterate over all partitions of k with exactly as many parts as + // there are symbolic terms in the basis (including zero parts). + partition_generator partitions(k, a.seq.size()); + do { + const std::vector& partition = partitions.current(); + const numeric coeff = multinomial_coefficient(partition) * binomial_coefficient; + + // Iterate over all compositions of the current partition. + composition_generator compositions(partition); + do { + const std::vector& exponent = compositions.current(); + exvector term; + term.reserve(n); + numeric factor = coeff; + for (unsigned i = 0; i < exponent.size(); ++i) { + const ex & r = a.seq[i].rest; + GINAC_ASSERT(!is_exactly_a(r)); + GINAC_ASSERT(!is_exactly_a(r) || + !is_exactly_a(ex_to(r).exponent) || + !ex_to(ex_to(r).exponent).is_pos_integer() || + !is_exactly_a(ex_to(r).basis) || + !is_exactly_a(ex_to(r).basis) || + !is_exactly_a(ex_to(r).basis)); + GINAC_ASSERT(is_exactly_a(a.seq[i].coeff)); + const numeric & c = ex_to(a.seq[i].coeff); + if (exponent[i] == 0) { + // optimize away + } else if (exponent[i] == 1) { + // optimized + term.push_back(r); + if (c != *_num1_p) + factor = factor.mul(c); + } else { // general case exponent[i] > 1 + term.push_back((new power(r, exponent[i]))->setflag(status_flags::dynallocated)); + if (c != *_num1_p) + factor = factor.mul(c.power(exponent[i])); + } + } + result.push_back(a.combine_ex_with_coeff_to_pair(mul(term).expand(options), factor)); + } while (compositions.next()); + } while (partitions.next()); + } + + GINAC_ASSERT(result.size() == result_size); + + if (a.overall_coeff.is_zero()) { + return (new add(std::move(result)))->setflag(status_flags::dynallocated | + status_flags::expanded); + } else { + return (new add(std::move(result), ex_to(a.overall_coeff).power(n)))->setflag(status_flags::dynallocated | + status_flags::expanded); + } +} - add const & addref=static_cast(*basis.bp); - splitseq=addref.seq; - splitseq.pop_back(); - ex first_operands=add(splitseq); - ex last_operand=addref.recombine_pair_to_ex(*(addref.seq.end()-1)); - - int n=exponent.to_int(); - for (int k=0; k<=n; k++) { - distrseq.push_back(binomial(n,k)*power(first_operands,numeric(k))* - power(last_operand,numeric(n-k))); - } - return ex((new add(distrseq))->setflag(status_flags::sub_expanded | - status_flags::expanded | - status_flags::dynallocated )). - expand(options); +/** Special case of power::expand_add. Expands a^2 where a is an add. + * @see power::expand_add */ +ex power::expand_add_2(const add & a, unsigned options) const +{ + epvector result; + size_t result_size = (a.nops() * (a.nops()+1)) / 2; + if (!a.overall_coeff.is_zero()) { + // the result's overall_coeff is one of the terms + --result_size; + } + result.reserve(result_size); + + epvector::const_iterator last = a.seq.end(); + + // power(+(x,...,z;c),2)=power(+(x,...,z;0),2)+2*c*+(x,...,z;0)+c*c + // first part: ignore overall_coeff and expand other terms + for (epvector::const_iterator cit0=a.seq.begin(); cit0!=last; ++cit0) { + const ex & r = cit0->rest; + const ex & c = cit0->coeff; + + GINAC_ASSERT(!is_exactly_a(r)); + GINAC_ASSERT(!is_exactly_a(r) || + !is_exactly_a(ex_to(r).exponent) || + !ex_to(ex_to(r).exponent).is_pos_integer() || + !is_exactly_a(ex_to(r).basis) || + !is_exactly_a(ex_to(r).basis) || + !is_exactly_a(ex_to(r).basis)); + + if (c.is_equal(_ex1)) { + if (is_exactly_a(r)) { + result.push_back(a.combine_ex_with_coeff_to_pair(expand_mul(ex_to(r), *_num2_p, options, true), + _ex1)); + } else { + result.push_back(a.combine_ex_with_coeff_to_pair((new power(r,_ex2))->setflag(status_flags::dynallocated), + _ex1)); + } + } else { + if (is_exactly_a(r)) { + result.push_back(a.combine_ex_with_coeff_to_pair(expand_mul(ex_to(r), *_num2_p, options, true), + ex_to(c).power_dyn(*_num2_p))); + } else { + result.push_back(a.combine_ex_with_coeff_to_pair((new power(r,_ex2))->setflag(status_flags::dynallocated), + ex_to(c).power_dyn(*_num2_p))); + } + } + + for (epvector::const_iterator cit1=cit0+1; cit1!=last; ++cit1) { + const ex & r1 = cit1->rest; + const ex & c1 = cit1->coeff; + result.push_back(a.combine_ex_with_coeff_to_pair(mul(r,r1).expand(options), + _num2_p->mul(ex_to(c)).mul_dyn(ex_to(c1)))); + } + } + + // second part: add terms coming from overall_coeff (if != 0) + if (!a.overall_coeff.is_zero()) { + for (auto & i : a.seq) + result.push_back(a.combine_pair_with_coeff_to_pair(i, ex_to(a.overall_coeff).mul_dyn(*_num2_p))); + } + + GINAC_ASSERT(result.size() == result_size); + + if (a.overall_coeff.is_zero()) { + return (new add(std::move(result)))->setflag(status_flags::dynallocated | + status_flags::expanded); + } else { + return (new add(std::move(result), ex_to(a.overall_coeff).power(2)))->setflag(status_flags::dynallocated | + status_flags::expanded); + } } -*/ -/* -ex power::expand_noncommutative(ex const & basis, numeric const & exponent, - unsigned options) const +/** Expand factors of m in m^n where m is a mul and n is an integer. + * @see power::expand */ +ex power::expand_mul(const mul & m, const numeric & n, unsigned options, bool from_expand) const { - ex rest_power=ex(power(basis,exponent.add(numMINUSONE()))). - expand(options | expand_options::internal_do_not_expand_power_operands); - - return ex(mul(rest_power,basis),0). - expand(options | expand_options::internal_do_not_expand_mul_operands); + GINAC_ASSERT(n.is_integer()); + + if (n.is_zero()) { + return _ex1; + } + + // do not bother to rename indices if there are no any. + if (!(options & expand_options::expand_rename_idx) && + m.info(info_flags::has_indices)) + options |= expand_options::expand_rename_idx; + // Leave it to multiplication since dummy indices have to be renamed + if ((options & expand_options::expand_rename_idx) && + (get_all_dummy_indices(m).size() > 0) && n.is_positive()) { + ex result = m; + exvector va = get_all_dummy_indices(m); + sort(va.begin(), va.end(), ex_is_less()); + + for (int i=1; i < n.to_int(); i++) + result *= rename_dummy_indices_uniquely(va, m); + return result; + } + + epvector distrseq; + distrseq.reserve(m.seq.size()); + bool need_reexpand = false; + + for (auto & cit : m.seq) { + expair p = m.combine_pair_with_coeff_to_pair(cit, n); + if (from_expand && is_exactly_a(cit.rest) && ex_to(p.coeff).is_pos_integer()) { + // this happens when e.g. (a+b)^(1/2) gets squared and + // the resulting product needs to be reexpanded + need_reexpand = true; + } + distrseq.push_back(p); + } + + const mul & result = static_cast((new mul(distrseq, ex_to(m.overall_coeff).power_dyn(n)))->setflag(status_flags::dynallocated)); + if (need_reexpand) + return ex(result).expand(options); + if (from_expand) + return result.setflag(status_flags::expanded); + return result; } -*/ - -////////// -// static member variables -////////// -// protected - -unsigned power::precedence=60; - -////////// -// global constants -////////// +GINAC_BIND_UNARCHIVER(power); -const power some_power; -type_info const & typeid_power=typeid(some_power); +} // namespace GiNaC