]> www.ginac.de Git - ginac.git/blobdiff - ginac/mul.cpp
some more comments and cleanups to mul::expand() and ncmul::expand()
[ginac.git] / ginac / mul.cpp
index 3dfe48ebd387bcf8a62671a34a82b1d63c4da310..92e21b3efa4c427f37e2d949a92185eef1d8aaf4 100644 (file)
@@ -131,11 +131,11 @@ void mul::print(const print_context & c, unsigned level) const
 {
        debugmsg("mul print", LOGLEVEL_PRINT);
 
-       if (is_of_type(c, print_tree)) {
+       if (is_a<print_tree>(c)) {
 
                inherited::print(c, level);
 
-       } else if (is_of_type(c, print_csrc)) {
+       } else if (is_a<print_csrc>(c)) {
 
                if (precedence() <= level)
                        c.s << "(";
@@ -150,8 +150,8 @@ void mul::print(const print_context & c, unsigned level) const
                while (it != itend) {
 
                        // If the first argument is a negative integer power, it gets printed as "1.0/<expr>"
-                       if (it == seq.begin() && ex_to_numeric(it->coeff).is_integer() && it->coeff.compare(_num0()) < 0) {
-                               if (is_of_type(c, print_csrc_cl_N))
+                       if (it == seq.begin() && ex_to<numeric>(it->coeff).is_integer() && it->coeff.compare(_num0()) < 0) {
+                               if (is_a<print_csrc_cl_N>(c))
                                        c.s << "recip(";
                                else
                                        c.s << "1.0/";
@@ -162,13 +162,13 @@ void mul::print(const print_context & c, unsigned level) const
                                it->rest.print(c, precedence());
                        else {
                                // Outer parens around ex needed for broken gcc-2.95 parser:
-                               (ex(power(it->rest, abs(ex_to_numeric(it->coeff))))).print(c, level);
+                               (ex(power(it->rest, abs(ex_to<numeric>(it->coeff))))).print(c, level);
                        }
 
                        // Separator is "/" for negative integer powers, "*" otherwise
                        ++it;
                        if (it != itend) {
-                               if (ex_to_numeric(it->coeff).is_integer() && it->coeff.compare(_num0()) < 0)
+                               if (ex_to<numeric>(it->coeff).is_integer() && it->coeff.compare(_num0()) < 0)
                                        c.s << "/";
                                else
                                        c.s << "*";
@@ -181,7 +181,7 @@ void mul::print(const print_context & c, unsigned level) const
        } else {
 
                if (precedence() <= level) {
-                       if (is_of_type(c, print_latex))
+                       if (is_a<print_latex>(c))
                                c.s << "{(";
                        else
                                c.s << "(";
@@ -190,7 +190,7 @@ void mul::print(const print_context & c, unsigned level) const
                bool first = true;
 
                // First print the overall numeric coefficient
-               numeric coeff = ex_to_numeric(overall_coeff);
+               numeric coeff = ex_to<numeric>(overall_coeff);
                if (coeff.csgn() == -1)
                        c.s << '-';
                if (!coeff.is_equal(_num1()) &&
@@ -206,7 +206,7 @@ void mul::print(const print_context & c, unsigned level) const
                                else
                                        coeff.print(c, precedence());
                        }
-                       if (is_of_type(c, print_latex))
+                       if (is_a<print_latex>(c))
                                c.s << ' ';
                        else
                                c.s << '*';
@@ -216,7 +216,7 @@ void mul::print(const print_context & c, unsigned level) const
                epvector::const_iterator it = seq.begin(), itend = seq.end();
                while (it != itend) {
                        if (!first) {
-                               if (is_of_type(c, print_latex))
+                               if (is_a<print_latex>(c))
                                        c.s << ' ';
                                else
                                        c.s << '*';
@@ -224,11 +224,11 @@ void mul::print(const print_context & c, unsigned level) const
                                first = false;
                        }
                        recombine_pair_to_ex(*it).print(c, precedence());
-                       it++;
+                       ++it;
                }
 
                if (precedence() <= level) {
-                       if (is_of_type(c, print_latex))
+                       if (is_a<print_latex>(c))
                                c.s << ")}";
                        else
                                c.s << ")";
@@ -245,16 +245,20 @@ bool mul::info(unsigned inf) const
                case info_flags::rational_polynomial:
                case info_flags::crational_polynomial:
                case info_flags::rational_function: {
-                       for (epvector::const_iterator i=seq.begin(); i!=seq.end(); ++i) {
+                       epvector::const_iterator i = seq.begin(), end = seq.end();
+                       while (i != end) {
                                if (!(recombine_pair_to_ex(*i).info(inf)))
                                        return false;
+                               ++i;
                        }
                        return overall_coeff.info(inf);
                }
                case info_flags::algebraic: {
-                       for (epvector::const_iterator i=seq.begin(); i!=seq.end(); ++i) {
+                       epvector::const_iterator i = seq.begin(), end = seq.end();
+                       while (i != end) {
                                if ((recombine_pair_to_ex(*i).info(inf)))
                                        return true;
+                               ++i;
                        }
                        return false;
                }
@@ -264,20 +268,26 @@ bool mul::info(unsigned inf) const
 
 int mul::degree(const ex & s) const
 {
+       // Sum up degrees of factors
        int deg_sum = 0;
-       for (epvector::const_iterator cit=seq.begin(); cit!=seq.end(); ++cit) {
-               if (ex_to_numeric(cit->coeff).is_integer())
-                       deg_sum+=cit->rest.degree(s) * ex_to_numeric(cit->coeff).to_int();
+       epvector::const_iterator i = seq.begin(), end = seq.end();
+       while (i != end) {
+               if (ex_to<numeric>(i->coeff).is_integer())
+                       deg_sum += i->rest.degree(s) * ex_to<numeric>(i->coeff).to_int();
+               ++i;
        }
        return deg_sum;
 }
 
 int mul::ldegree(const ex & s) const
 {
+       // Sum up degrees of factors
        int deg_sum = 0;
-       for (epvector::const_iterator cit=seq.begin(); cit!=seq.end(); ++cit) {
-               if (ex_to_numeric(cit->coeff).is_integer())
-                       deg_sum+=cit->rest.ldegree(s) * ex_to_numeric(cit->coeff).to_int();
+       epvector::const_iterator i = seq.begin(), end = seq.end();
+       while (i != end) {
+               if (ex_to<numeric>(i->coeff).is_integer())
+                       deg_sum += i->rest.ldegree(s) * ex_to<numeric>(i->coeff).to_int();
+               ++i;
        }
        return deg_sum;
 }
@@ -290,27 +300,27 @@ ex mul::coeff(const ex & s, int n) const
        if (n==0) {
                // product of individual coeffs
                // if a non-zero power of s is found, the resulting product will be 0
-               epvector::const_iterator it = seq.begin();
-               while (it!=seq.end()) {
-                       coeffseq.push_back(recombine_pair_to_ex(*it).coeff(s,n));
-                       ++it;
+               epvector::const_iterator i = seq.begin(), end = seq.end();
+               while (i != end) {
+                       coeffseq.push_back(recombine_pair_to_ex(*i).coeff(s,n));
+                       ++i;
                }
                coeffseq.push_back(overall_coeff);
                return (new mul(coeffseq))->setflag(status_flags::dynallocated);
        }
        
-       epvector::const_iterator it=seq.begin();
-       bool coeff_found = 0;
-       while (it!=seq.end()) {
-               ex t = recombine_pair_to_ex(*it);
-               ex c = t.coeff(s,n);
+       epvector::const_iterator i = seq.begin(), end = seq.end();
+       bool coeff_found = false;
+       while (i != end) {
+               ex t = recombine_pair_to_ex(*i);
+               ex c = t.coeff(s, n);
                if (!c.is_zero()) {
                        coeffseq.push_back(c);
                        coeff_found = 1;
                } else {
                        coeffseq.push_back(t);
                }
-               ++it;
+               ++i;
        }
        if (coeff_found) {
                coeffseq.push_back(overall_coeff);
@@ -329,26 +339,28 @@ ex mul::eval(int level) const
        
        debugmsg("mul eval",LOGLEVEL_MEMBER_FUNCTION);
        
-       epvector * evaled_seqp = evalchildren(level);
-       if (evaled_seqp!=0) {
+       epvector *evaled_seqp = evalchildren(level);
+       if (evaled_seqp) {
                // do more evaluation later
                return (new mul(evaled_seqp,overall_coeff))->
                           setflag(status_flags::dynallocated);
        }
        
 #ifdef DO_GINAC_ASSERT
-       for (epvector::const_iterator cit=seq.begin(); cit!=seq.end(); ++cit) {
-               GINAC_ASSERT((!is_ex_exactly_of_type((*cit).rest,mul)) ||
-                            (!(ex_to_numeric((*cit).coeff).is_integer())));
-               GINAC_ASSERT(!(cit->is_canonical_numeric()));
-               if (is_ex_exactly_of_type(recombine_pair_to_ex(*cit),numeric))
+       epvector::const_iterator i = seq.begin(), end = seq.end();
+       while (i != end) {
+               GINAC_ASSERT((!is_ex_exactly_of_type(i->rest, mul)) ||
+                            (!(ex_to<numeric>(i->coeff).is_integer())));
+               GINAC_ASSERT(!(i->is_canonical_numeric()));
+               if (is_ex_exactly_of_type(recombine_pair_to_ex(*i), numeric))
                    print(print_tree(std::cerr));
-               GINAC_ASSERT(!is_ex_exactly_of_type(recombine_pair_to_ex(*cit),numeric));
+               GINAC_ASSERT(!is_ex_exactly_of_type(recombine_pair_to_ex(*i), numeric));
                /* for paranoia */
-               expair p = split_ex_to_pair(recombine_pair_to_ex(*cit));
-               GINAC_ASSERT(p.rest.is_equal((*cit).rest));
-               GINAC_ASSERT(p.coeff.is_equal((*cit).coeff));
+               expair p = split_ex_to_pair(recombine_pair_to_ex(*i));
+               GINAC_ASSERT(p.rest.is_equal(i->rest));
+               GINAC_ASSERT(p.coeff.is_equal(i->coeff));
                /* end paranoia */
+               ++i;
        }
 #endif // def DO_GINAC_ASSERT
        
@@ -359,7 +371,7 @@ ex mul::eval(int level) const
        }
        
        int seq_size = seq.size();
-       if (overall_coeff.is_equal(_ex0())) {
+       if (overall_coeff.is_zero()) {
                // *(...,x;0) -> 0
                return _ex0();
        } else if (seq_size==0) {
@@ -370,17 +382,19 @@ ex mul::eval(int level) const
                return recombine_pair_to_ex(*(seq.begin()));
        } else if ((seq_size==1) &&
                   is_ex_exactly_of_type((*seq.begin()).rest,add) &&
-                  ex_to_numeric((*seq.begin()).coeff).is_equal(_num1())) {
+                  ex_to<numeric>((*seq.begin()).coeff).is_equal(_num1())) {
                // *(+(x,y,...);c) -> +(*(x,c),*(y,c),...) (c numeric(), no powers of +())
-               const add & addref = ex_to_add((*seq.begin()).rest);
-               epvector distrseq;
-               distrseq.reserve(addref.seq.size());
-               for (epvector::const_iterator cit=addref.seq.begin(); cit!=addref.seq.end(); ++cit) {
-                       distrseq.push_back(addref.combine_pair_with_coeff_to_pair(*cit, overall_coeff));
+               const add & addref = ex_to<add>((*seq.begin()).rest);
+               epvector *distrseq = new epvector();
+               distrseq->reserve(addref.seq.size());
+               epvector::const_iterator i = addref.seq.begin(), end = addref.seq.end();
+               while (i != end) {
+                       distrseq->push_back(addref.combine_pair_with_coeff_to_pair(*i, overall_coeff));
+                       ++i;
                }
                return (new add(distrseq,
-                               ex_to_numeric(addref.overall_coeff).
-                               mul_dyn(ex_to_numeric(overall_coeff))))
+                               ex_to<numeric>(addref.overall_coeff).
+                               mul_dyn(ex_to<numeric>(overall_coeff))))
                      ->setflag(status_flags::dynallocated | status_flags::evaluated);
        }
        return this->hold();
@@ -394,15 +408,17 @@ ex mul::evalf(int level) const
        if (level==-max_recursion_level)
                throw(std::runtime_error("max recursion level reached"));
        
-       epvector s;
-       s.reserve(seq.size());
-       
+       epvector *s = new epvector();
+       s->reserve(seq.size());
+
        --level;
-       for (epvector::const_iterator it=seq.begin(); it!=seq.end(); ++it) {
-               s.push_back(combine_ex_with_coeff_to_pair((*it).rest.evalf(level),
-                                                         (*it).coeff));
+       epvector::const_iterator i = seq.begin(), end = seq.end();
+       while (i != end) {
+               s->push_back(combine_ex_with_coeff_to_pair(i->rest.evalf(level),
+                                                          i->coeff));
+               ++i;
        }
-       return mul(s,overall_coeff.evalf(level));
+       return mul(s, overall_coeff.evalf(level));
 }
 
 ex mul::evalm(void) const
@@ -410,7 +426,7 @@ ex mul::evalm(void) const
        // numeric*matrix
        if (seq.size() == 1 && seq[0].coeff.is_equal(_ex1())
         && is_ex_of_type(seq[0].rest, matrix))
-               return ex_to_matrix(seq[0].rest).mul(ex_to_numeric(overall_coeff));
+               return ex_to<matrix>(seq[0].rest).mul(ex_to<numeric>(overall_coeff));
 
        // Evaluate children first, look whether there are any matrices at all
        // (there can be either no matrices or one matrix; if there were more
@@ -421,22 +437,22 @@ ex mul::evalm(void) const
        bool have_matrix = false;
        epvector::iterator the_matrix;
 
-       epvector::const_iterator it = seq.begin(), itend = seq.end();
-       while (it != itend) {
-               const ex &m = recombine_pair_to_ex(*it).evalm();
+       epvector::const_iterator i = seq.begin(), end = seq.end();
+       while (i != end) {
+               const ex &m = recombine_pair_to_ex(*i).evalm();
                s->push_back(split_ex_to_pair(m));
                if (is_ex_of_type(m, matrix)) {
                        have_matrix = true;
                        the_matrix = s->end() - 1;
                }
-               it++;
+               ++i;
        }
 
        if (have_matrix) {
 
                // The product contained a matrix. We will multiply all other factors
                // into that matrix.
-               matrix m = ex_to_matrix(the_matrix->rest);
+               matrix m = ex_to<matrix>(the_matrix->rest);
                s->erase(the_matrix);
                ex scalar = (new mul(s, overall_coeff))->setflag(status_flags::dynallocated);
                return m.mul_scalar(scalar);
@@ -447,14 +463,15 @@ ex mul::evalm(void) const
 
 ex mul::simplify_ncmul(const exvector & v) const
 {
-       if (seq.size()==0) {
+       if (seq.empty())
                return inherited::simplify_ncmul(v);
-       }
 
        // Find first noncommutative element and call its simplify_ncmul()
-       for (epvector::const_iterator cit=seq.begin(); cit!=seq.end(); ++cit) {
-               if (cit->rest.return_type() == return_types::noncommutative)
-                       return cit->rest.simplify_ncmul(v);
+       epvector::const_iterator i = seq.begin(), end = seq.end();
+       while (i != end) {
+               if (i->rest.return_type() == return_types::noncommutative)
+                       return i->rest.simplify_ncmul(v);
+               ++i;
        }
        return inherited::simplify_ncmul(v);
 }
@@ -465,15 +482,21 @@ ex mul::simplify_ncmul(const exvector & v) const
  *  @see ex::diff */
 ex mul::derivative(const symbol & s) const
 {
+       unsigned num = seq.size();
        exvector addseq;
-       addseq.reserve(seq.size());
+       addseq.reserve(num);
        
        // D(a*b*c) = D(a)*b*c + a*D(b)*c + a*b*D(c)
-       for (unsigned i=0; i!=seq.size(); ++i) {
-               epvector mulseq = seq;
-               mulseq[i] = split_ex_to_pair(power(seq[i].rest,seq[i].coeff - _ex1()) *
-                                            seq[i].rest.diff(s));
-               addseq.push_back((new mul(mulseq,overall_coeff*seq[i].coeff))->setflag(status_flags::dynallocated));
+       epvector mulseq = seq;
+       epvector::const_iterator i = seq.begin(), end = seq.end();
+       epvector::iterator i2 = mulseq.begin();
+       while (i != end) {
+               expair ep = split_ex_to_pair(power(i->rest, i->coeff - _ex1()) *
+                                            i->rest.diff(s));
+               ep.swap(*i2);
+               addseq.push_back((new mul(mulseq, overall_coeff * i->coeff))->setflag(status_flags::dynallocated));
+               ep.swap(*i2);
+               ++i; ++i2;
        }
        return (new add(addseq))->setflag(status_flags::dynallocated);
 }
@@ -490,30 +513,32 @@ bool mul::is_equal_same_type(const basic & other) const
 
 unsigned mul::return_type(void) const
 {
-       if (seq.size()==0) {
+       if (seq.empty()) {
                // mul without factors: should not happen, but commutes
                return return_types::commutative;
        }
        
-       bool all_commutative = 1;
-       unsigned rt;
-       epvector::const_iterator cit_noncommutative_element; // point to first found nc element
+       bool all_commutative = true;
+       epvector::const_iterator noncommutative_element; // point to first found nc element
        
-       for (epvector::const_iterator cit=seq.begin(); cit!=seq.end(); ++cit) {
-               rt=(*cit).rest.return_type();
-               if (rt==return_types::noncommutative_composite) return rt; // one ncc -> mul also ncc
-               if ((rt==return_types::noncommutative)&&(all_commutative)) {
+       epvector::const_iterator i = seq.begin(), end = seq.end();
+       while (i != end) {
+               unsigned rt = i->rest.return_type();
+               if (rt == return_types::noncommutative_composite)
+                       return rt; // one ncc -> mul also ncc
+               if ((rt == return_types::noncommutative) && (all_commutative)) {
                        // first nc element found, remember position
-                       cit_noncommutative_element = cit;
-                       all_commutative = 0;
+                       noncommutative_element = i;
+                       all_commutative = false;
                }
-               if ((rt==return_types::noncommutative)&&(!all_commutative)) {
+               if ((rt == return_types::noncommutative) && (!all_commutative)) {
                        // another nc element found, compare type_infos
-                       if ((*cit_noncommutative_element).rest.return_type_tinfo()!=(*cit).rest.return_type_tinfo()) {
+                       if (noncommutative_element->rest.return_type_tinfo() != i->rest.return_type_tinfo()) {
                                // diffent types -> mul is ncc
                                return return_types::noncommutative_composite;
                        }
                }
+               ++i;
        }
        // all factors checked
        return all_commutative ? return_types::commutative : return_types::noncommutative;
@@ -521,13 +546,15 @@ unsigned mul::return_type(void) const
    
 unsigned mul::return_type_tinfo(void) const
 {
-       if (seq.size()==0)
+       if (seq.empty())
                return tinfo_key;  // mul without factors: should not happen
        
        // return type_info of first noncommutative element
-       for (epvector::const_iterator cit=seq.begin(); cit!=seq.end(); ++cit) {
-               if ((*cit).rest.return_type()==return_types::noncommutative)
-                       return (*cit).rest.return_type_tinfo();
+       epvector::const_iterator i = seq.begin(), end = seq.end();
+       while (i != end) {
+               if (i->rest.return_type() == return_types::noncommutative)
+                       return i->rest.return_type_tinfo();
+               ++i;
        }
        // no noncommutative element found, should not happen
        return tinfo_key;
@@ -535,18 +562,18 @@ unsigned mul::return_type_tinfo(void) const
 
 ex mul::thisexpairseq(const epvector & v, const ex & oc) const
 {
-       return (new mul(v,oc))->setflag(status_flags::dynallocated);
+       return (new mul(v, oc))->setflag(status_flags::dynallocated);
 }
 
 ex mul::thisexpairseq(epvector * vp, const ex & oc) const
 {
-       return (new mul(vp,oc))->setflag(status_flags::dynallocated);
+       return (new mul(vp, oc))->setflag(status_flags::dynallocated);
 }
 
 expair mul::split_ex_to_pair(const ex & e) const
 {
        if (is_ex_exactly_of_type(e,power)) {
-               const power & powerref = ex_to_power(e);
+               const power & powerref = ex_to<power>(e);
                if (is_ex_exactly_of_type(powerref.exponent,numeric))
                        return expair(powerref.basis,powerref.exponent);
        }
@@ -581,7 +608,7 @@ expair mul::combine_pair_with_coeff_to_pair(const expair & p,
        
 ex mul::recombine_pair_to_ex(const expair & p) const
 {
-       if (ex_to_numeric(p.coeff).is_equal(_num1())) 
+       if (ex_to<numeric>(p.coeff).is_equal(_num1())) 
                return p.rest;
        else
                return power(p.rest,p.coeff);
@@ -590,7 +617,7 @@ ex mul::recombine_pair_to_ex(const expair & p) const
 bool mul::expair_needs_further_processing(epp it)
 {
        if (is_ex_exactly_of_type((*it).rest,mul) &&
-               ex_to_numeric((*it).coeff).is_integer()) {
+               ex_to<numeric>((*it).coeff).is_integer()) {
                // combined pair is product with integer power -> expand it
                *it = split_ex_to_pair(recombine_pair_to_ex(*it));
                return true;
@@ -602,7 +629,7 @@ bool mul::expair_needs_further_processing(epp it)
                        *it = ep;
                        return true;
                }
-               if (ex_to_numeric((*it).coeff).is_equal(_num1())) {
+               if (ex_to<numeric>((*it).coeff).is_equal(_num1())) {
                        // combined pair has coeff 1 and must be moved to the end
                        return true;
                }
@@ -619,7 +646,7 @@ void mul::combine_overall_coeff(const ex & c)
 {
        GINAC_ASSERT(is_ex_exactly_of_type(overall_coeff,numeric));
        GINAC_ASSERT(is_ex_exactly_of_type(c,numeric));
-       overall_coeff = ex_to_numeric(overall_coeff).mul_dyn(ex_to_numeric(c));
+       overall_coeff = ex_to<numeric>(overall_coeff).mul_dyn(ex_to<numeric>(c));
 }
 
 void mul::combine_overall_coeff(const ex & c1, const ex & c2)
@@ -627,7 +654,7 @@ void mul::combine_overall_coeff(const ex & c1, const ex & c2)
        GINAC_ASSERT(is_ex_exactly_of_type(overall_coeff,numeric));
        GINAC_ASSERT(is_ex_exactly_of_type(c1,numeric));
        GINAC_ASSERT(is_ex_exactly_of_type(c2,numeric));
-       overall_coeff = ex_to_numeric(overall_coeff).mul_dyn(ex_to_numeric(c1).power(ex_to_numeric(c2)));
+       overall_coeff = ex_to<numeric>(overall_coeff).mul_dyn(ex_to<numeric>(c1).power(ex_to<numeric>(c2)));
 }
 
 bool mul::can_make_flat(const expair & p) const
@@ -636,47 +663,44 @@ bool mul::can_make_flat(const expair & p) const
        // this assertion will probably fail somewhere
        // it would require a more careful make_flat, obeying the power laws
        // probably should return true only if p.coeff is integer
-       return ex_to_numeric(p.coeff).is_equal(_num1());
+       return ex_to<numeric>(p.coeff).is_equal(_num1());
 }
 
 ex mul::expand(unsigned options) const
 {
-       if (flags & status_flags::expanded)
-               return *this;
-       
-       exvector sub_expanded_seq;
-       
+       // First, expand the children
        epvector * expanded_seqp = expandchildren(options);
-       
-       const epvector & expanded_seq = expanded_seqp==0 ? seq : *expanded_seqp;
-       
+       const epvector & expanded_seq = (expanded_seqp == NULL) ? seq : *expanded_seqp;
+
+       // Now, look for all the factors that are sums and multiply each one out
+       // with the next one that is found while collecting the factors which are
+       // not sums
        int number_of_adds = 0;
+       ex last_expanded = _ex1();
        epvector non_adds;
        non_adds.reserve(expanded_seq.size());
-       epvector::const_iterator cit = expanded_seq.begin();
-       epvector::const_iterator last = expanded_seq.end();
-       ex last_expanded = _ex1();
-       while (cit!=last) {
-               if (is_ex_exactly_of_type((*cit).rest,add) &&
-                       ((*cit).coeff.is_equal(_ex1()))) {
+       epvector::const_iterator cit = expanded_seq.begin(), last = expanded_seq.end();
+       while (cit != last) {
+               if (is_ex_exactly_of_type(cit->rest, add) &&
+                       (cit->coeff.is_equal(_ex1()))) {
                        ++number_of_adds;
-                       if (is_ex_exactly_of_type(last_expanded,add)) {
-                               // expand adds
-                               const add & add1 = ex_to_add(last_expanded);
-                               const add & add2 = ex_to_add((*cit).rest);
+                       if (is_ex_exactly_of_type(last_expanded, add)) {
+                               const add & add1 = ex_to<add>(last_expanded);
+                               const add & add2 = ex_to<add>(cit->rest);
                                int n1 = add1.nops();
                                int n2 = add2.nops();
                                exvector distrseq;
                                distrseq.reserve(n1*n2);
                                for (int i1=0; i1<n1; ++i1) {
                                        for (int i2=0; i2<n2; ++i2) {
-                                               distrseq.push_back(add1.op(i1)*add2.op(i2));
+                                               distrseq.push_back(add1.op(i1) * add2.op(i2));
                                        }
                                }
-                               last_expanded = (new add(distrseq))->setflag(status_flags::dynallocated | status_flags::expanded);
+                               last_expanded = (new add(distrseq))->
+                                                setflag(status_flags::dynallocated | (options == 0 ? status_flags::expanded : 0));
                        } else {
                                non_adds.push_back(split_ex_to_pair(last_expanded));
-                               last_expanded = (*cit).rest;
+                               last_expanded = cit->rest;
                        }
                } else {
                        non_adds.push_back(*cit);
@@ -686,22 +710,25 @@ ex mul::expand(unsigned options) const
        if (expanded_seqp)
                delete expanded_seqp;
 
-       if (is_ex_exactly_of_type(last_expanded,add)) {
-               add const & finaladd = ex_to_add(last_expanded);
+       // Now the only remaining thing to do is to multiply the factors which
+       // were not sums into the "last_expanded" sum
+       if (is_ex_exactly_of_type(last_expanded, add)) {
+               add const & finaladd = ex_to<add>(last_expanded);
                exvector distrseq;
                int n = finaladd.nops();
                distrseq.reserve(n);
                for (int i=0; i<n; ++i) {
                        epvector factors = non_adds;
                        factors.push_back(split_ex_to_pair(finaladd.op(i)));
-                       distrseq.push_back((new mul(factors,overall_coeff))->setflag(status_flags::dynallocated | status_flags::expanded));
+                       distrseq.push_back((new mul(factors, overall_coeff))->
+                                           setflag(status_flags::dynallocated | (options == 0 ? status_flags::expanded : 0)));
                }
                return ((new add(distrseq))->
-                       setflag(status_flags::dynallocated | status_flags::expanded));
+                       setflag(status_flags::dynallocated | (options == 0 ? status_flags::expanded : 0)));
        }
        non_adds.push_back(split_ex_to_pair(last_expanded));
-       return (new mul(non_adds,overall_coeff))->
-               setflag(status_flags::dynallocated | status_flags::expanded);
+       return (new mul(non_adds, overall_coeff))->
+               setflag(status_flags::dynallocated | (options == 0 ? status_flags::expanded : 0));
 }