]> www.ginac.de Git - ginac.git/blobdiff - ginac/mul.cpp
synced to 1.1 (expand() problem)
[ginac.git] / ginac / mul.cpp
index 6f99c515062ff86c64a253ae9ce80adf512680bb..db68b275ffbbc788a476bad5b9b9f814b30229ff 100644 (file)
@@ -835,8 +835,11 @@ ex mul::expand(unsigned options) const
        // not sums
        int number_of_adds = 0;
        ex last_expanded = _ex1;
+
        epvector non_adds;
        non_adds.reserve(expanded_seq.size());
+       bool non_adds_has_sums = false; // Look for sums or powers of sums in the non_adds (we need this later)
+
        epvector::const_iterator cit = expanded_seq.begin(), last = expanded_seq.end();
        while (cit != last) {
                if (is_exactly_a<add>(cit->rest) &&
@@ -886,7 +889,7 @@ ex mul::expand(unsigned options) const
                                        for (epvector::const_iterator i2=add2begin; i2!=add2end; ++i2) {
                                                // Don't push_back expairs which might have a rest that evaluates to a numeric,
                                                // since that would violate an invariant of expairseq:
-                                               const ex rest = (new mul(i1->rest, i2->rest))->setflag(status_flags::dynallocated);
+                                               const ex rest = ex((new mul(i1->rest, i2->rest))->setflag(status_flags::dynallocated)).expand();
                                                if (is_exactly_a<numeric>(rest))
                                                        oc += ex_to<numeric>(rest).mul(ex_to<numeric>(i1->coeff).mul(ex_to<numeric>(i2->coeff)));
                                                else
@@ -901,23 +904,36 @@ ex mul::expand(unsigned options) const
                                last_expanded = cit->rest;
                        }
                } else {
+                       if (is_exactly_a<add>(cit->rest))
+                               non_adds_has_sums = true;
                        non_adds.push_back(*cit);
                }
                ++cit;
        }
-       
+
        // Now the only remaining thing to do is to multiply the factors which
        // were not sums into the "last_expanded" sum
        if (is_exactly_a<add>(last_expanded)) {
                const add & finaladd = ex_to<add>(last_expanded);
-               exvector distrseq;
+
                size_t n = finaladd.nops();
+               exvector distrseq;
                distrseq.reserve(n);
+
                for (size_t i=0; i<n; ++i) {
                        epvector factors = non_adds;
-                       factors.push_back(split_ex_to_pair(finaladd.op(i)));
-                       distrseq.push_back((new mul(factors, overall_coeff))->
-                                           setflag(status_flags::dynallocated | (options == 0 ? status_flags::expanded : 0)));
+                       expair new_factor = split_ex_to_pair(finaladd.op(i).expand());
+                       factors.push_back(new_factor);
+
+                       const mul & term = static_cast<const mul &>((new mul(factors, overall_coeff))->setflag(status_flags::dynallocated));
+
+                       // The new term may have sums in it if e.g. a sqrt() of a sum in
+                       // the non_adds meets a sqrt() of a sum in the factor from
+                       // last_expanded. In this case we should re-expand the term.
+                       if (non_adds_has_sums || is_exactly_a<add>(new_factor.rest))
+                               distrseq.push_back(ex(term).expand());
+                       else
+                               distrseq.push_back(term.setflag(options == 0 ? status_flags::expanded : 0));
                }
                return ((new add(distrseq))->
                        setflag(status_flags::dynallocated | (options == 0 ? status_flags::expanded : 0)));