synced to 1.1 (expand() problem)
authorChristian Bauer <Christian.Bauer@uni-mainz.de>
Thu, 11 Dec 2003 20:34:04 +0000 (20:34 +0000)
committerChristian Bauer <Christian.Bauer@uni-mainz.de>
Thu, 11 Dec 2003 20:34:04 +0000 (20:34 +0000)
check/exam_paranoia.cpp
ginac/mul.cpp
ginac/power.cpp

index c1df502c9282282311ee536513d326219c23be96..1104ebe9ed3c9ffc47c4358e9e40ce0a823e9772 100644 (file)
@@ -370,6 +370,41 @@ static unsigned exam_paranoia15()
        return result;
 }
 
+// Expanding products containing powers of sums could return results that
+// were not fully expanded. Fixed on Dec 10, 2003.
+static unsigned exam_paranoia16()
+{
+       unsigned result = 0;
+       symbol a("a"), b("b"), c("c"), d("d");
+       ex e1, e2, e3;
+
+       e1 = pow(1+a*sqrt(b+c), 2);
+       e2 = e1.expand();
+
+       if (e2.has(pow(a, 2)*(b+c))) {
+               clog << "expand(" << e1 << ") didn't fully expand\n";
+               ++result;
+       }
+
+       e1 = (d*sqrt(a+b)+a*sqrt(c+d))*(b*sqrt(a+b)+a*sqrt(c+d));
+       e2 = e1.expand();
+
+       if (e2.has(pow(a, 2)*(c+d))) {
+               clog << "expand(" << e1 << ") didn't fully expand\n";
+               ++result;
+       }
+
+       e1 = (a+sqrt(b+c))*sqrt(b+c)*(d+sqrt(b+c));
+       e2 = e1.expand();
+
+       if (e2.has(a*(b+c))) {
+               clog << "expand(" << e1 << ") didn't fully expand\n";
+               ++result;
+       }
+
+       return result;
+}
+
 unsigned exam_paranoia()
 {
        unsigned result = 0;
@@ -392,6 +427,7 @@ unsigned exam_paranoia()
        result += exam_paranoia13();  cout << '.' << flush;
        result += exam_paranoia14();  cout << '.' << flush;
        result += exam_paranoia15();  cout << '.' << flush;
+       result += exam_paranoia16();  cout << '.' << flush;
        
        if (!result) {
                cout << " passed " << endl;
index 6f99c515062ff86c64a253ae9ce80adf512680bb..db68b275ffbbc788a476bad5b9b9f814b30229ff 100644 (file)
@@ -835,8 +835,11 @@ ex mul::expand(unsigned options) const
        // not sums
        int number_of_adds = 0;
        ex last_expanded = _ex1;
+
        epvector non_adds;
        non_adds.reserve(expanded_seq.size());
+       bool non_adds_has_sums = false; // Look for sums or powers of sums in the non_adds (we need this later)
+
        epvector::const_iterator cit = expanded_seq.begin(), last = expanded_seq.end();
        while (cit != last) {
                if (is_exactly_a<add>(cit->rest) &&
@@ -886,7 +889,7 @@ ex mul::expand(unsigned options) const
                                        for (epvector::const_iterator i2=add2begin; i2!=add2end; ++i2) {
                                                // Don't push_back expairs which might have a rest that evaluates to a numeric,
                                                // since that would violate an invariant of expairseq:
-                                               const ex rest = (new mul(i1->rest, i2->rest))->setflag(status_flags::dynallocated);
+                                               const ex rest = ex((new mul(i1->rest, i2->rest))->setflag(status_flags::dynallocated)).expand();
                                                if (is_exactly_a<numeric>(rest))
                                                        oc += ex_to<numeric>(rest).mul(ex_to<numeric>(i1->coeff).mul(ex_to<numeric>(i2->coeff)));
                                                else
@@ -901,23 +904,36 @@ ex mul::expand(unsigned options) const
                                last_expanded = cit->rest;
                        }
                } else {
+                       if (is_exactly_a<add>(cit->rest))
+                               non_adds_has_sums = true;
                        non_adds.push_back(*cit);
                }
                ++cit;
        }
-       
+
        // Now the only remaining thing to do is to multiply the factors which
        // were not sums into the "last_expanded" sum
        if (is_exactly_a<add>(last_expanded)) {
                const add & finaladd = ex_to<add>(last_expanded);
-               exvector distrseq;
+
                size_t n = finaladd.nops();
+               exvector distrseq;
                distrseq.reserve(n);
+
                for (size_t i=0; i<n; ++i) {
                        epvector factors = non_adds;
-                       factors.push_back(split_ex_to_pair(finaladd.op(i)));
-                       distrseq.push_back((new mul(factors, overall_coeff))->
-                                           setflag(status_flags::dynallocated | (options == 0 ? status_flags::expanded : 0)));
+                       expair new_factor = split_ex_to_pair(finaladd.op(i).expand());
+                       factors.push_back(new_factor);
+
+                       const mul & term = static_cast<const mul &>((new mul(factors, overall_coeff))->setflag(status_flags::dynallocated));
+
+                       // The new term may have sums in it if e.g. a sqrt() of a sum in
+                       // the non_adds meets a sqrt() of a sum in the factor from
+                       // last_expanded. In this case we should re-expand the term.
+                       if (non_adds_has_sums || is_exactly_a<add>(new_factor.rest))
+                               distrseq.push_back(ex(term).expand());
+                       else
+                               distrseq.push_back(term.setflag(options == 0 ? status_flags::expanded : 0));
                }
                return ((new add(distrseq))->
                        setflag(status_flags::dynallocated | (options == 0 ? status_flags::expanded : 0)));
index b38b733cb575667e0047fdad10383ccd3b53af0d..a45bf6afaf02aea954ebd6f2d6c0fe44476b3465 100644 (file)
@@ -839,6 +839,8 @@ ex power::expand_mul(const mul & m, const numeric & n) const
 
        epvector distrseq;
        distrseq.reserve(m.seq.size());
+       bool need_reexpand = false;
+
        epvector::const_iterator last = m.seq.end();
        epvector::const_iterator cit = m.seq.begin();
        while (cit!=last) {
@@ -847,11 +849,22 @@ ex power::expand_mul(const mul & m, const numeric & n) const
                } else {
                        // it is safe not to call mul::combine_pair_with_coeff_to_pair()
                        // since n is an integer
-                       distrseq.push_back(expair(cit->rest, ex_to<numeric>(cit->coeff).mul(n)));
+                       numeric new_coeff = ex_to<numeric>(cit->coeff).mul(n);
+                       if (is_exactly_a<add>(cit->rest) && new_coeff.is_pos_integer()) {
+                               // this happens when e.g. (a+b)^(1/2) gets squared and
+                               // the resulting product needs to be reexpanded
+                               need_reexpand = true;
+                       }
+                       distrseq.push_back(expair(cit->rest, new_coeff));
                }
                ++cit;
        }
-       return (new mul(distrseq, ex_to<numeric>(m.overall_coeff).power_dyn(n)))->setflag(status_flags::dynallocated);
+
+       const mul & result = static_cast<const mul &>((new mul(distrseq, ex_to<numeric>(m.overall_coeff).power_dyn(n)))->setflag(status_flags::dynallocated));
+       if (need_reexpand)
+               return ex(result).expand();
+       else
+               return result.setflag(status_flags::expanded);
 }
 
 } // namespace GiNaC