From 87d731b215909cc8ab8ecdb8c05fcd717bf63fd2 Mon Sep 17 00:00:00 2001 From: Christian Bauer Date: Thu, 11 Dec 2003 20:34:04 +0000 Subject: [PATCH] synced to 1.1 (expand() problem) --- check/exam_paranoia.cpp | 36 ++++++++++++++++++++++++++++++++++++ ginac/mul.cpp | 28 ++++++++++++++++++++++------ ginac/power.cpp | 17 +++++++++++++++-- 3 files changed, 73 insertions(+), 8 deletions(-) diff --git a/check/exam_paranoia.cpp b/check/exam_paranoia.cpp index c1df502c..1104ebe9 100644 --- a/check/exam_paranoia.cpp +++ b/check/exam_paranoia.cpp @@ -370,6 +370,41 @@ static unsigned exam_paranoia15() return result; } +// Expanding products containing powers of sums could return results that +// were not fully expanded. Fixed on Dec 10, 2003. +static unsigned exam_paranoia16() +{ + unsigned result = 0; + symbol a("a"), b("b"), c("c"), d("d"); + ex e1, e2, e3; + + e1 = pow(1+a*sqrt(b+c), 2); + e2 = e1.expand(); + + if (e2.has(pow(a, 2)*(b+c))) { + clog << "expand(" << e1 << ") didn't fully expand\n"; + ++result; + } + + e1 = (d*sqrt(a+b)+a*sqrt(c+d))*(b*sqrt(a+b)+a*sqrt(c+d)); + e2 = e1.expand(); + + if (e2.has(pow(a, 2)*(c+d))) { + clog << "expand(" << e1 << ") didn't fully expand\n"; + ++result; + } + + e1 = (a+sqrt(b+c))*sqrt(b+c)*(d+sqrt(b+c)); + e2 = e1.expand(); + + if (e2.has(a*(b+c))) { + clog << "expand(" << e1 << ") didn't fully expand\n"; + ++result; + } + + return result; +} + unsigned exam_paranoia() { unsigned result = 0; @@ -392,6 +427,7 @@ unsigned exam_paranoia() result += exam_paranoia13(); cout << '.' << flush; result += exam_paranoia14(); cout << '.' << flush; result += exam_paranoia15(); cout << '.' << flush; + result += exam_paranoia16(); cout << '.' << flush; if (!result) { cout << " passed " << endl; diff --git a/ginac/mul.cpp b/ginac/mul.cpp index 6f99c515..db68b275 100644 --- a/ginac/mul.cpp +++ b/ginac/mul.cpp @@ -835,8 +835,11 @@ ex mul::expand(unsigned options) const // not sums int number_of_adds = 0; ex last_expanded = _ex1; + epvector non_adds; non_adds.reserve(expanded_seq.size()); + bool non_adds_has_sums = false; // Look for sums or powers of sums in the non_adds (we need this later) + epvector::const_iterator cit = expanded_seq.begin(), last = expanded_seq.end(); while (cit != last) { if (is_exactly_a(cit->rest) && @@ -886,7 +889,7 @@ ex mul::expand(unsigned options) const for (epvector::const_iterator i2=add2begin; i2!=add2end; ++i2) { // Don't push_back expairs which might have a rest that evaluates to a numeric, // since that would violate an invariant of expairseq: - const ex rest = (new mul(i1->rest, i2->rest))->setflag(status_flags::dynallocated); + const ex rest = ex((new mul(i1->rest, i2->rest))->setflag(status_flags::dynallocated)).expand(); if (is_exactly_a(rest)) oc += ex_to(rest).mul(ex_to(i1->coeff).mul(ex_to(i2->coeff))); else @@ -901,23 +904,36 @@ ex mul::expand(unsigned options) const last_expanded = cit->rest; } } else { + if (is_exactly_a(cit->rest)) + non_adds_has_sums = true; non_adds.push_back(*cit); } ++cit; } - + // Now the only remaining thing to do is to multiply the factors which // were not sums into the "last_expanded" sum if (is_exactly_a(last_expanded)) { const add & finaladd = ex_to(last_expanded); - exvector distrseq; + size_t n = finaladd.nops(); + exvector distrseq; distrseq.reserve(n); + for (size_t i=0; i - setflag(status_flags::dynallocated | (options == 0 ? status_flags::expanded : 0))); + expair new_factor = split_ex_to_pair(finaladd.op(i).expand()); + factors.push_back(new_factor); + + const mul & term = static_cast((new mul(factors, overall_coeff))->setflag(status_flags::dynallocated)); + + // The new term may have sums in it if e.g. a sqrt() of a sum in + // the non_adds meets a sqrt() of a sum in the factor from + // last_expanded. In this case we should re-expand the term. + if (non_adds_has_sums || is_exactly_a(new_factor.rest)) + distrseq.push_back(ex(term).expand()); + else + distrseq.push_back(term.setflag(options == 0 ? status_flags::expanded : 0)); } return ((new add(distrseq))-> setflag(status_flags::dynallocated | (options == 0 ? status_flags::expanded : 0))); diff --git a/ginac/power.cpp b/ginac/power.cpp index b38b733c..a45bf6af 100644 --- a/ginac/power.cpp +++ b/ginac/power.cpp @@ -839,6 +839,8 @@ ex power::expand_mul(const mul & m, const numeric & n) const epvector distrseq; distrseq.reserve(m.seq.size()); + bool need_reexpand = false; + epvector::const_iterator last = m.seq.end(); epvector::const_iterator cit = m.seq.begin(); while (cit!=last) { @@ -847,11 +849,22 @@ ex power::expand_mul(const mul & m, const numeric & n) const } else { // it is safe not to call mul::combine_pair_with_coeff_to_pair() // since n is an integer - distrseq.push_back(expair(cit->rest, ex_to(cit->coeff).mul(n))); + numeric new_coeff = ex_to(cit->coeff).mul(n); + if (is_exactly_a(cit->rest) && new_coeff.is_pos_integer()) { + // this happens when e.g. (a+b)^(1/2) gets squared and + // the resulting product needs to be reexpanded + need_reexpand = true; + } + distrseq.push_back(expair(cit->rest, new_coeff)); } ++cit; } - return (new mul(distrseq, ex_to(m.overall_coeff).power_dyn(n)))->setflag(status_flags::dynallocated); + + const mul & result = static_cast((new mul(distrseq, ex_to(m.overall_coeff).power_dyn(n)))->setflag(status_flags::dynallocated)); + if (need_reexpand) + return ex(result).expand(); + else + return result.setflag(status_flags::expanded); } } // namespace GiNaC -- 2.44.0