X-Git-Url: https://www.ginac.de/ginac.git//ginac.git?p=ginac.git;a=blobdiff_plain;f=ginac%2Fnormal.cpp;h=2603086a743b495b84033855ef59c822aca7af97;hp=8e2ec00e7c453265ca75eef852f7e681f234b830;hb=0f7b8280ad89fa88a0cbaab7785a4b7cb06e6a63;hpb=b301f03a61bc9f72b27940ca7fe1f8d0b343a4e2;ds=sidebyside diff --git a/ginac/normal.cpp b/ginac/normal.cpp index 8e2ec00e..2603086a 100644 --- a/ginac/normal.cpp +++ b/ginac/normal.cpp @@ -1684,12 +1684,12 @@ ex sqrfree_parfrac(const ex & a, const symbol & x) * assigned symbol). The symbol and expression are appended to repl, for * a later application of subs(). * @see ex::normal */ -static ex replace_with_symbol(const ex & e, exmap & repl) +static ex replace_with_symbol(const ex & e, exmap & repl, exmap & rev_lookup) { - // Expression already in repl? Then return the assigned symbol - for (exmap::const_iterator it = repl.begin(); it != repl.end(); ++it) - if (it->second.is_equal(e)) - return it->first; + // Expression already replaced? Then return the assigned symbol + exmap::const_iterator it = rev_lookup.find(e); + if (it != rev_lookup.end()) + return it->second; // Otherwise create new symbol and add to list, taking care that the // replacement expression doesn't itself contain symbols from repl, @@ -1697,6 +1697,7 @@ static ex replace_with_symbol(const ex & e, exmap & repl) ex es = (new symbol)->setflag(status_flags::dynallocated); ex e_replaced = e.subs(repl, subs_options::no_pattern); repl[es] = e_replaced; + rev_lookup[e_replaced] = es; return es; } @@ -1732,18 +1733,18 @@ struct normal_map_function : public map_function { /** Default implementation of ex::normal(). It normalizes the children and * replaces the object with a temporary symbol. * @see ex::normal */ -ex basic::normal(exmap & repl, int level) const +ex basic::normal(exmap & repl, exmap & rev_lookup, int level) const { if (nops() == 0) - return (new lst(replace_with_symbol(*this, repl), _ex1))->setflag(status_flags::dynallocated); + return (new lst(replace_with_symbol(*this, repl, rev_lookup), _ex1))->setflag(status_flags::dynallocated); else { if (level == 1) - return (new lst(replace_with_symbol(*this, repl), _ex1))->setflag(status_flags::dynallocated); + return (new lst(replace_with_symbol(*this, repl, rev_lookup), _ex1))->setflag(status_flags::dynallocated); else if (level == -max_recursion_level) throw(std::runtime_error("max recursion level reached")); else { normal_map_function map_normal(level - 1); - return (new lst(replace_with_symbol(map(map_normal), repl), _ex1))->setflag(status_flags::dynallocated); + return (new lst(replace_with_symbol(map(map_normal), repl, rev_lookup), _ex1))->setflag(status_flags::dynallocated); } } } @@ -1751,7 +1752,7 @@ ex basic::normal(exmap & repl, int level) const /** Implementation of ex::normal() for symbols. This returns the unmodified symbol. * @see ex::normal */ -ex symbol::normal(exmap & repl, int level) const +ex symbol::normal(exmap & repl, exmap & rev_lookup, int level) const { return (new lst(*this, _ex1))->setflag(status_flags::dynallocated); } @@ -1761,19 +1762,19 @@ ex symbol::normal(exmap & repl, int level) const * into re+I*im and replaces I and non-rational real numbers with a temporary * symbol. * @see ex::normal */ -ex numeric::normal(exmap & repl, int level) const +ex numeric::normal(exmap & repl, exmap & rev_lookup, int level) const { numeric num = numer(); ex numex = num; if (num.is_real()) { if (!num.is_integer()) - numex = replace_with_symbol(numex, repl); + numex = replace_with_symbol(numex, repl, rev_lookup); } else { // complex numeric re = num.real(), im = num.imag(); - ex re_ex = re.is_rational() ? re : replace_with_symbol(re, repl); - ex im_ex = im.is_rational() ? im : replace_with_symbol(im, repl); - numex = re_ex + im_ex * replace_with_symbol(I, repl); + ex re_ex = re.is_rational() ? re : replace_with_symbol(re, repl, rev_lookup); + ex im_ex = im.is_rational() ? im : replace_with_symbol(im, repl, rev_lookup); + numex = re_ex + im_ex * replace_with_symbol(I, repl, rev_lookup); } // Denominator is always a real integer (see numeric::denom()) @@ -1845,10 +1846,10 @@ static ex frac_cancel(const ex &n, const ex &d) /** Implementation of ex::normal() for a sum. It expands terms and performs * fractional addition. * @see ex::normal */ -ex add::normal(exmap & repl, int level) const +ex add::normal(exmap & repl, exmap & rev_lookup, int level) const { if (level == 1) - return (new lst(replace_with_symbol(*this, repl), _ex1))->setflag(status_flags::dynallocated); + return (new lst(replace_with_symbol(*this, repl, rev_lookup), _ex1))->setflag(status_flags::dynallocated); else if (level == -max_recursion_level) throw(std::runtime_error("max recursion level reached")); @@ -1858,12 +1859,12 @@ ex add::normal(exmap & repl, int level) const dens.reserve(seq.size()+1); epvector::const_iterator it = seq.begin(), itend = seq.end(); while (it != itend) { - ex n = ex_to(recombine_pair_to_ex(*it)).normal(repl, level-1); + ex n = ex_to(recombine_pair_to_ex(*it)).normal(repl, rev_lookup, level-1); nums.push_back(n.op(0)); dens.push_back(n.op(1)); it++; } - ex n = ex_to(overall_coeff).normal(repl, level-1); + ex n = ex_to(overall_coeff).normal(repl, rev_lookup, level-1); nums.push_back(n.op(0)); dens.push_back(n.op(1)); GINAC_ASSERT(nums.size() == dens.size()); @@ -1904,10 +1905,10 @@ ex add::normal(exmap & repl, int level) const /** Implementation of ex::normal() for a product. It cancels common factors * from fractions. * @see ex::normal() */ -ex mul::normal(exmap & repl, int level) const +ex mul::normal(exmap & repl, exmap & rev_lookup, int level) const { if (level == 1) - return (new lst(replace_with_symbol(*this, repl), _ex1))->setflag(status_flags::dynallocated); + return (new lst(replace_with_symbol(*this, repl, rev_lookup), _ex1))->setflag(status_flags::dynallocated); else if (level == -max_recursion_level) throw(std::runtime_error("max recursion level reached")); @@ -1917,12 +1918,12 @@ ex mul::normal(exmap & repl, int level) const ex n; epvector::const_iterator it = seq.begin(), itend = seq.end(); while (it != itend) { - n = ex_to(recombine_pair_to_ex(*it)).normal(repl, level-1); + n = ex_to(recombine_pair_to_ex(*it)).normal(repl, rev_lookup, level-1); num.push_back(n.op(0)); den.push_back(n.op(1)); it++; } - n = ex_to(overall_coeff).normal(repl, level-1); + n = ex_to(overall_coeff).normal(repl, rev_lookup, level-1); num.push_back(n.op(0)); den.push_back(n.op(1)); @@ -1932,20 +1933,20 @@ ex mul::normal(exmap & repl, int level) const } -/** Implementation of ex::normal() for powers. It normalizes the basis, +/** Implementation of ex::normal([B) for powers. It normalizes the basis, * distributes integer exponents to numerator and denominator, and replaces * non-integer powers by temporary symbols. * @see ex::normal */ -ex power::normal(exmap & repl, int level) const +ex power::normal(exmap & repl, exmap & rev_lookup, int level) const { if (level == 1) - return (new lst(replace_with_symbol(*this, repl), _ex1))->setflag(status_flags::dynallocated); + return (new lst(replace_with_symbol(*this, repl, rev_lookup), _ex1))->setflag(status_flags::dynallocated); else if (level == -max_recursion_level) throw(std::runtime_error("max recursion level reached")); // Normalize basis and exponent (exponent gets reassembled) - ex n_basis = ex_to(basis).normal(repl, level-1); - ex n_exponent = ex_to(exponent).normal(repl, level-1); + ex n_basis = ex_to(basis).normal(repl, rev_lookup, level-1); + ex n_exponent = ex_to(exponent).normal(repl, rev_lookup, level-1); n_exponent = n_exponent.op(0) / n_exponent.op(1); if (n_exponent.info(info_flags::integer)) { @@ -1966,32 +1967,32 @@ ex power::normal(exmap & repl, int level) const if (n_exponent.info(info_flags::positive)) { // (a/b)^x -> {sym((a/b)^x), 1} - return (new lst(replace_with_symbol(power(n_basis.op(0) / n_basis.op(1), n_exponent), repl), _ex1))->setflag(status_flags::dynallocated); + return (new lst(replace_with_symbol(power(n_basis.op(0) / n_basis.op(1), n_exponent), repl, rev_lookup), _ex1))->setflag(status_flags::dynallocated); } else if (n_exponent.info(info_flags::negative)) { if (n_basis.op(1).is_equal(_ex1)) { // a^-x -> {1, sym(a^x)} - return (new lst(_ex1, replace_with_symbol(power(n_basis.op(0), -n_exponent), repl)))->setflag(status_flags::dynallocated); + return (new lst(_ex1, replace_with_symbol(power(n_basis.op(0), -n_exponent), repl, rev_lookup)))->setflag(status_flags::dynallocated); } else { // (a/b)^-x -> {sym((b/a)^x), 1} - return (new lst(replace_with_symbol(power(n_basis.op(1) / n_basis.op(0), -n_exponent), repl), _ex1))->setflag(status_flags::dynallocated); + return (new lst(replace_with_symbol(power(n_basis.op(1) / n_basis.op(0), -n_exponent), repl, rev_lookup), _ex1))->setflag(status_flags::dynallocated); } } } // (a/b)^x -> {sym((a/b)^x, 1} - return (new lst(replace_with_symbol(power(n_basis.op(0) / n_basis.op(1), n_exponent), repl), _ex1))->setflag(status_flags::dynallocated); + return (new lst(replace_with_symbol(power(n_basis.op(0) / n_basis.op(1), n_exponent), repl, rev_lookup), _ex1))->setflag(status_flags::dynallocated); } /** Implementation of ex::normal() for pseries. It normalizes each coefficient * and replaces the series by a temporary symbol. * @see ex::normal */ -ex pseries::normal(exmap & repl, int level) const +ex pseries::normal(exmap & repl, exmap & rev_lookup, int level) const { epvector newseq; epvector::const_iterator i = seq.begin(), end = seq.end(); @@ -2002,7 +2003,7 @@ ex pseries::normal(exmap & repl, int level) const ++i; } ex n = pseries(relational(var,point), newseq); - return (new lst(replace_with_symbol(n, repl), _ex1))->setflag(status_flags::dynallocated); + return (new lst(replace_with_symbol(n, repl, rev_lookup), _ex1))->setflag(status_flags::dynallocated); } @@ -2020,9 +2021,9 @@ ex pseries::normal(exmap & repl, int level) const * @return normalized expression */ ex ex::normal(int level) const { - exmap repl; + exmap repl, rev_lookup; - ex e = bp->normal(repl, level); + ex e = bp->normal(repl, rev_lookup, level); GINAC_ASSERT(is_a(e)); // Re-insert replaced symbols @@ -2041,9 +2042,9 @@ ex ex::normal(int level) const * @return numerator */ ex ex::numer() const { - exmap repl; + exmap repl, rev_lookup; - ex e = bp->normal(repl, 0); + ex e = bp->normal(repl, rev_lookup, 0); GINAC_ASSERT(is_a(e)); // Re-insert replaced symbols @@ -2061,9 +2062,9 @@ ex ex::numer() const * @return denominator */ ex ex::denom() const { - exmap repl; + exmap repl, rev_lookup; - ex e = bp->normal(repl, 0); + ex e = bp->normal(repl, rev_lookup, 0); GINAC_ASSERT(is_a(e)); // Re-insert replaced symbols @@ -2081,9 +2082,9 @@ ex ex::denom() const * @return a list [numerator, denominator] */ ex ex::numer_denom() const { - exmap repl; + exmap repl, rev_lookup; - ex e = bp->normal(repl, 0); + ex e = bp->normal(repl, rev_lookup, 0); GINAC_ASSERT(is_a(e)); // Re-insert replaced symbols