X-Git-Url: https://www.ginac.de/ginac.git//ginac.git?p=ginac.git;a=blobdiff_plain;f=ginac%2Fnormal.cpp;h=3a8a82fc358512ea89614af13046a421ca76473b;hp=cd2ad8eca5a0b29024bc142e81c19866f031bea7;hb=5ef801553eb39aed7bd2df9dd1aff9d752c3ea9d;hpb=7bc96470ee0dd5c59a8ea1a29b74a781668606a1 diff --git a/ginac/normal.cpp b/ginac/normal.cpp index cd2ad8ec..3a8a82fc 100644 --- a/ginac/normal.cpp +++ b/ginac/normal.cpp @@ -1681,25 +1681,22 @@ ex sqrfree_parfrac(const ex & a, const symbol & x) /** Create a symbol for replacing the expression "e" (or return a previously - * assigned symbol). The symbol is appended to sym_lst and returned, the - * expression is appended to repl_lst. + * assigned symbol). The symbol and expression are appended to repl, for + * a later application of subs(). * @see ex::normal */ -static ex replace_with_symbol(const ex &e, lst &sym_lst, lst &repl_lst) +static ex replace_with_symbol(const ex & e, exmap & repl) { - // Expression already in repl_lst? Then return the assigned symbol - lst::const_iterator its, itr; - for (its = sym_lst.begin(), itr = repl_lst.begin(); itr != repl_lst.end(); ++its, ++itr) - if (itr->is_equal(e)) - return *its; + // Expression already in repl? Then return the assigned symbol + for (exmap::const_iterator it = repl.begin(); it != repl.end(); ++it) + if (it->second.is_equal(e)) + return it->first; // Otherwise create new symbol and add to list, taking care that the - // replacement expression doesn't contain symbols from the sym_lst + // replacement expression doesn't itself contain symbols from repl, // because subs() is not recursive - symbol s; - ex es(s); - ex e_replaced = e.subs(sym_lst, repl_lst); - sym_lst.append(es); - repl_lst.append(e_replaced); + ex es = (new symbol)->setflag(status_flags::dynallocated); + ex e_replaced = e.subs(repl); + repl[es] = e_replaced; return es; } @@ -1708,7 +1705,7 @@ static ex replace_with_symbol(const ex &e, lst &sym_lst, lst &repl_lst) * to repl_lst and the symbol is returned. * @see basic::to_rational * @see basic::to_polynomial */ -static ex replace_with_symbol(const ex &e, lst &repl_lst) +static ex replace_with_symbol(const ex & e, lst & repl_lst) { // Expression already in repl_lst? Then return the assigned symbol for (lst::const_iterator it = repl_lst.begin(); it != repl_lst.end(); ++it) @@ -1716,10 +1713,9 @@ static ex replace_with_symbol(const ex &e, lst &repl_lst) return it->op(0); // Otherwise create new symbol and add to list, taking care that the - // replacement expression doesn't contain symbols from the sym_lst + // replacement expression doesn't itself contain symbols from the repl_lst, // because subs() is not recursive - symbol s; - ex es(s); + ex es = (new symbol)->setflag(status_flags::dynallocated); ex e_replaced = e.subs(repl_lst); repl_lst.append(es == e_replaced); return es; @@ -1736,18 +1732,18 @@ struct normal_map_function : public map_function { /** Default implementation of ex::normal(). It normalizes the children and * replaces the object with a temporary symbol. * @see ex::normal */ -ex basic::normal(lst &sym_lst, lst &repl_lst, int level) const +ex basic::normal(exmap & repl, int level) const { if (nops() == 0) - return (new lst(replace_with_symbol(*this, sym_lst, repl_lst), _ex1))->setflag(status_flags::dynallocated); + return (new lst(replace_with_symbol(*this, repl), _ex1))->setflag(status_flags::dynallocated); else { if (level == 1) - return (new lst(replace_with_symbol(*this, sym_lst, repl_lst), _ex1))->setflag(status_flags::dynallocated); + return (new lst(replace_with_symbol(*this, repl), _ex1))->setflag(status_flags::dynallocated); else if (level == -max_recursion_level) throw(std::runtime_error("max recursion level reached")); else { normal_map_function map_normal(level - 1); - return (new lst(replace_with_symbol(map(map_normal), sym_lst, repl_lst), _ex1))->setflag(status_flags::dynallocated); + return (new lst(replace_with_symbol(map(map_normal), repl), _ex1))->setflag(status_flags::dynallocated); } } } @@ -1755,7 +1751,7 @@ ex basic::normal(lst &sym_lst, lst &repl_lst, int level) const /** Implementation of ex::normal() for symbols. This returns the unmodified symbol. * @see ex::normal */ -ex symbol::normal(lst &sym_lst, lst &repl_lst, int level) const +ex symbol::normal(exmap & repl, int level) const { return (new lst(*this, _ex1))->setflag(status_flags::dynallocated); } @@ -1765,19 +1761,19 @@ ex symbol::normal(lst &sym_lst, lst &repl_lst, int level) const * into re+I*im and replaces I and non-rational real numbers with a temporary * symbol. * @see ex::normal */ -ex numeric::normal(lst &sym_lst, lst &repl_lst, int level) const +ex numeric::normal(exmap & repl, int level) const { numeric num = numer(); ex numex = num; if (num.is_real()) { if (!num.is_integer()) - numex = replace_with_symbol(numex, sym_lst, repl_lst); + numex = replace_with_symbol(numex, repl); } else { // complex numeric re = num.real(), im = num.imag(); - ex re_ex = re.is_rational() ? re : replace_with_symbol(re, sym_lst, repl_lst); - ex im_ex = im.is_rational() ? im : replace_with_symbol(im, sym_lst, repl_lst); - numex = re_ex + im_ex * replace_with_symbol(I, sym_lst, repl_lst); + ex re_ex = re.is_rational() ? re : replace_with_symbol(re, repl); + ex im_ex = im.is_rational() ? im : replace_with_symbol(im, repl); + numex = re_ex + im_ex * replace_with_symbol(I, repl); } // Denominator is always a real integer (see numeric::denom()) @@ -1849,10 +1845,10 @@ static ex frac_cancel(const ex &n, const ex &d) /** Implementation of ex::normal() for a sum. It expands terms and performs * fractional addition. * @see ex::normal */ -ex add::normal(lst &sym_lst, lst &repl_lst, int level) const +ex add::normal(exmap & repl, int level) const { if (level == 1) - return (new lst(replace_with_symbol(*this, sym_lst, repl_lst), _ex1))->setflag(status_flags::dynallocated); + return (new lst(replace_with_symbol(*this, repl), _ex1))->setflag(status_flags::dynallocated); else if (level == -max_recursion_level) throw(std::runtime_error("max recursion level reached")); @@ -1862,12 +1858,12 @@ ex add::normal(lst &sym_lst, lst &repl_lst, int level) const dens.reserve(seq.size()+1); epvector::const_iterator it = seq.begin(), itend = seq.end(); while (it != itend) { - ex n = ex_to(recombine_pair_to_ex(*it)).normal(sym_lst, repl_lst, level-1); + ex n = ex_to(recombine_pair_to_ex(*it)).normal(repl, level-1); nums.push_back(n.op(0)); dens.push_back(n.op(1)); it++; } - ex n = ex_to(overall_coeff).normal(sym_lst, repl_lst, level-1); + ex n = ex_to(overall_coeff).normal(repl, level-1); nums.push_back(n.op(0)); dens.push_back(n.op(1)); GINAC_ASSERT(nums.size() == dens.size()); @@ -1908,10 +1904,10 @@ ex add::normal(lst &sym_lst, lst &repl_lst, int level) const /** Implementation of ex::normal() for a product. It cancels common factors * from fractions. * @see ex::normal() */ -ex mul::normal(lst &sym_lst, lst &repl_lst, int level) const +ex mul::normal(exmap & repl, int level) const { if (level == 1) - return (new lst(replace_with_symbol(*this, sym_lst, repl_lst), _ex1))->setflag(status_flags::dynallocated); + return (new lst(replace_with_symbol(*this, repl), _ex1))->setflag(status_flags::dynallocated); else if (level == -max_recursion_level) throw(std::runtime_error("max recursion level reached")); @@ -1921,12 +1917,12 @@ ex mul::normal(lst &sym_lst, lst &repl_lst, int level) const ex n; epvector::const_iterator it = seq.begin(), itend = seq.end(); while (it != itend) { - n = ex_to(recombine_pair_to_ex(*it)).normal(sym_lst, repl_lst, level-1); + n = ex_to(recombine_pair_to_ex(*it)).normal(repl, level-1); num.push_back(n.op(0)); den.push_back(n.op(1)); it++; } - n = ex_to(overall_coeff).normal(sym_lst, repl_lst, level-1); + n = ex_to(overall_coeff).normal(repl, level-1); num.push_back(n.op(0)); den.push_back(n.op(1)); @@ -1940,16 +1936,16 @@ ex mul::normal(lst &sym_lst, lst &repl_lst, int level) const * distributes integer exponents to numerator and denominator, and replaces * non-integer powers by temporary symbols. * @see ex::normal */ -ex power::normal(lst &sym_lst, lst &repl_lst, int level) const +ex power::normal(exmap & repl, int level) const { if (level == 1) - return (new lst(replace_with_symbol(*this, sym_lst, repl_lst), _ex1))->setflag(status_flags::dynallocated); + return (new lst(replace_with_symbol(*this, repl), _ex1))->setflag(status_flags::dynallocated); else if (level == -max_recursion_level) throw(std::runtime_error("max recursion level reached")); // Normalize basis and exponent (exponent gets reassembled) - ex n_basis = ex_to(basis).normal(sym_lst, repl_lst, level-1); - ex n_exponent = ex_to(exponent).normal(sym_lst, repl_lst, level-1); + ex n_basis = ex_to(basis).normal(repl, level-1); + ex n_exponent = ex_to(exponent).normal(repl, level-1); n_exponent = n_exponent.op(0) / n_exponent.op(1); if (n_exponent.info(info_flags::integer)) { @@ -1970,32 +1966,32 @@ ex power::normal(lst &sym_lst, lst &repl_lst, int level) const if (n_exponent.info(info_flags::positive)) { // (a/b)^x -> {sym((a/b)^x), 1} - return (new lst(replace_with_symbol(power(n_basis.op(0) / n_basis.op(1), n_exponent), sym_lst, repl_lst), _ex1))->setflag(status_flags::dynallocated); + return (new lst(replace_with_symbol(power(n_basis.op(0) / n_basis.op(1), n_exponent), repl), _ex1))->setflag(status_flags::dynallocated); } else if (n_exponent.info(info_flags::negative)) { if (n_basis.op(1).is_equal(_ex1)) { // a^-x -> {1, sym(a^x)} - return (new lst(_ex1, replace_with_symbol(power(n_basis.op(0), -n_exponent), sym_lst, repl_lst)))->setflag(status_flags::dynallocated); + return (new lst(_ex1, replace_with_symbol(power(n_basis.op(0), -n_exponent), repl)))->setflag(status_flags::dynallocated); } else { // (a/b)^-x -> {sym((b/a)^x), 1} - return (new lst(replace_with_symbol(power(n_basis.op(1) / n_basis.op(0), -n_exponent), sym_lst, repl_lst), _ex1))->setflag(status_flags::dynallocated); + return (new lst(replace_with_symbol(power(n_basis.op(1) / n_basis.op(0), -n_exponent), repl), _ex1))->setflag(status_flags::dynallocated); } } } // (a/b)^x -> {sym((a/b)^x, 1} - return (new lst(replace_with_symbol(power(n_basis.op(0) / n_basis.op(1), n_exponent), sym_lst, repl_lst), _ex1))->setflag(status_flags::dynallocated); + return (new lst(replace_with_symbol(power(n_basis.op(0) / n_basis.op(1), n_exponent), repl), _ex1))->setflag(status_flags::dynallocated); } /** Implementation of ex::normal() for pseries. It normalizes each coefficient * and replaces the series by a temporary symbol. * @see ex::normal */ -ex pseries::normal(lst &sym_lst, lst &repl_lst, int level) const +ex pseries::normal(exmap & repl, int level) const { epvector newseq; epvector::const_iterator i = seq.begin(), end = seq.end(); @@ -2006,7 +2002,7 @@ ex pseries::normal(lst &sym_lst, lst &repl_lst, int level) const ++i; } ex n = pseries(relational(var,point), newseq); - return (new lst(replace_with_symbol(n, sym_lst, repl_lst), _ex1))->setflag(status_flags::dynallocated); + return (new lst(replace_with_symbol(n, repl), _ex1))->setflag(status_flags::dynallocated); } @@ -2024,14 +2020,14 @@ ex pseries::normal(lst &sym_lst, lst &repl_lst, int level) const * @return normalized expression */ ex ex::normal(int level) const { - lst sym_lst, repl_lst; + exmap repl; - ex e = bp->normal(sym_lst, repl_lst, level); + ex e = bp->normal(repl, level); GINAC_ASSERT(is_a(e)); // Re-insert replaced symbols - if (sym_lst.nops() > 0) - e = e.subs(sym_lst, repl_lst); + if (!repl.empty()) + e = e.subs(repl); // Convert {numerator, denominator} form back to fraction return e.op(0) / e.op(1); @@ -2045,16 +2041,16 @@ ex ex::normal(int level) const * @return numerator */ ex ex::numer() const { - lst sym_lst, repl_lst; + exmap repl; - ex e = bp->normal(sym_lst, repl_lst, 0); + ex e = bp->normal(repl, 0); GINAC_ASSERT(is_a(e)); // Re-insert replaced symbols - if (sym_lst.nops() > 0) - return e.op(0).subs(sym_lst, repl_lst); - else + if (repl.empty()) return e.op(0); + else + return e.op(0).subs(repl); } /** Get denominator of an expression. If the expression is not of the normal @@ -2065,16 +2061,16 @@ ex ex::numer() const * @return denominator */ ex ex::denom() const { - lst sym_lst, repl_lst; + exmap repl; - ex e = bp->normal(sym_lst, repl_lst, 0); + ex e = bp->normal(repl, 0); GINAC_ASSERT(is_a(e)); // Re-insert replaced symbols - if (sym_lst.nops() > 0) - return e.op(1).subs(sym_lst, repl_lst); - else + if (repl.empty()) return e.op(1); + else + return e.op(1).subs(repl); } /** Get numerator and denominator of an expression. If the expresison is not @@ -2085,16 +2081,16 @@ ex ex::denom() const * @return a list [numerator, denominator] */ ex ex::numer_denom() const { - lst sym_lst, repl_lst; + exmap repl; - ex e = bp->normal(sym_lst, repl_lst, 0); + ex e = bp->normal(repl, 0); GINAC_ASSERT(is_a(e)); // Re-insert replaced symbols - if (sym_lst.nops() > 0) - return e.subs(sym_lst, repl_lst); - else + if (repl.empty()) return e; + else + return e.subs(repl); }