]> www.ginac.de Git - ginac.git/blobdiff - ginac/add.cpp
- do something about the mad cast disease.
[ginac.git] / ginac / add.cpp
index 1b7b2719d3a57c63f104e9f684f62d1983c60f09..647de1722c7065c2d735f7093b187880d42c4203 100644 (file)
@@ -25,6 +25,7 @@
 
 #include "add.h"
 #include "mul.h"
+#include "matrix.h"
 #include "archive.h"
 #include "debugmsg.h"
 #include "utils.h"
@@ -115,13 +116,13 @@ void add::print(const print_context & c, unsigned level) const
 {
        debugmsg("add print", LOGLEVEL_PRINT);
 
-       if (is_of_type(c, print_tree)) {
+       if (is_a<print_tree>(c)) {
 
                inherited::print(c, level);
 
-       } else if (is_of_type(c, print_csrc)) {
+       } else if (is_a<print_csrc>(c)) {
 
-               if (precedence <= level)
+               if (precedence() <= level)
                        c.s << "(";
        
                // Print arguments, separated by "+"
@@ -130,44 +131,44 @@ void add::print(const print_context & c, unsigned level) const
                
                        // If the coefficient is -1, it is replaced by a single minus sign
                        if (it->coeff.compare(_num1()) == 0) {
-                               it->rest.bp->print(c, precedence);
+                               it->rest.bp->print(c, precedence());
                        } else if (it->coeff.compare(_num_1()) == 0) {
                                c.s << "-";
-                               it->rest.bp->print(c, precedence);
-                       } else if (ex_to_numeric(it->coeff).numer().compare(_num1()) == 0) {
-                               it->rest.bp->print(c, precedence);
+                               it->rest.bp->print(c, precedence());
+                       } else if (ex_to<numeric>(it->coeff).numer().compare(_num1()) == 0) {
+                               it->rest.bp->print(c, precedence());
                                c.s << "/";
-                               ex_to_numeric(it->coeff).denom().print(c, precedence);
-                       } else if (ex_to_numeric(it->coeff).numer().compare(_num_1()) == 0) {
+                               ex_to<numeric>(it->coeff).denom().print(c, precedence());
+                       } else if (ex_to<numeric>(it->coeff).numer().compare(_num_1()) == 0) {
                                c.s << "-";
-                               it->rest.bp->print(c, precedence);
+                               it->rest.bp->print(c, precedence());
                                c.s << "/";
-                               ex_to_numeric(it->coeff).denom().print(c, precedence);
+                               ex_to<numeric>(it->coeff).denom().print(c, precedence());
                        } else {
-                               it->coeff.bp->print(c, precedence);
+                               it->coeff.bp->print(c, precedence());
                                c.s << "*";
-                               it->rest.bp->print(c, precedence);
+                               it->rest.bp->print(c, precedence());
                        }
                
                        // Separator is "+", except if the following expression would have a leading minus sign
                        it++;
-                       if (it != itend && !(it->coeff.compare(_num0()) < 0 || (it->coeff.compare(_num1()) == 0 && is_ex_exactly_of_type(it->rest, numeric) && it->rest.compare(_num0()) < 0)))
+                       if (it != itend && !(it->coeff.compare(_num0()) < 0 || (it->coeff.compare(_num1()) == 0 && is_exactly_a<numeric>(it->rest) && it->rest.compare(_num0()) < 0)))
                                c.s << "+";
                }
        
                if (!overall_coeff.is_zero()) {
                        if (overall_coeff.info(info_flags::positive))
                                c.s << '+';
-                       overall_coeff.bp->print(c, precedence);
+                       overall_coeff.bp->print(c, precedence());
                }
        
-               if (precedence <= level)
+               if (precedence() <= level)
                        c.s << ")";
 
        } else {
 
-               if (precedence <= level) {
-                       if (is_of_type(c, print_latex))
+               if (precedence() <= level) {
+                       if (is_a<print_latex>(c))
                                c.s << "{(";
                        else
                                c.s << "(";
@@ -178,17 +179,17 @@ void add::print(const print_context & c, unsigned level) const
 
                // First print the overall numeric coefficient, if present
                if (!overall_coeff.is_zero()) {
-                       if (!is_of_type(c, print_tree))
+                       if (!is_a<print_tree>(c))
                                overall_coeff.print(c, 0);
                        else
-                               overall_coeff.print(c, precedence);
+                               overall_coeff.print(c, precedence());
                        first = false;
                }
 
                // Then proceed with the remaining factors
                epvector::const_iterator it = seq.begin(), itend = seq.end();
                while (it != itend) {
-                       coeff = ex_to_numeric(it->coeff);
+                       coeff = ex_to<numeric>(it->coeff);
                        if (!first) {
                                if (coeff.csgn() == -1) c.s << '-'; else c.s << '+';
                        } else {
@@ -204,21 +205,21 @@ void add::print(const print_context & c, unsigned level) const
                                                coeff.print(c);
                                } else {
                                        if (coeff.csgn() == -1)
-                                               (-coeff).print(c, precedence);
+                                               (-coeff).print(c, precedence());
                                        else
-                                               coeff.print(c, precedence);
+                                               coeff.print(c, precedence());
                                }
-                               if (is_of_type(c, print_latex))
+                               if (is_a<print_latex>(c))
                                        c.s << ' ';
                                else
                                        c.s << '*';
                        }
-                       it->rest.print(c, precedence);
+                       it->rest.print(c, precedence());
                        it++;
                }
 
-               if (precedence <= level) {
-                       if (is_of_type(c, print_latex))
+               if (precedence() <= level) {
+                       if (is_a<print_latex>(c))
                                c.s << ")}";
                        else
                                c.s << ")";
@@ -284,18 +285,16 @@ int add::ldegree(const ex & s) const
 ex add::coeff(const ex & s, int n) const
 {
        epvector coeffseq;
-       coeffseq.reserve(seq.size());
-
+       
        epvector::const_iterator it=seq.begin();
        while (it!=seq.end()) {
-               coeffseq.push_back(combine_ex_with_coeff_to_pair((*it).rest.coeff(s,n),
-                                                                (*it).coeff));
+               ex restcoeff = it->rest.coeff(s,n);
+               if (!restcoeff.is_zero())
+                       coeffseq.push_back(combine_ex_with_coeff_to_pair(restcoeff,it->coeff));
                ++it;
        }
-       if (n==0) {
-               return (new add(coeffseq,overall_coeff))->setflag(status_flags::dynallocated);
-       }
-       return (new add(coeffseq))->setflag(status_flags::dynallocated);
+       
+       return (new add(coeffseq, n==0 ? overall_coeff : default_overall_coeff()))->setflag(status_flags::dynallocated);
 }
 
 ex add::eval(int level) const
@@ -338,6 +337,38 @@ ex add::eval(int level) const
        return this->hold();
 }
 
+ex add::evalm(void) const
+{
+       // Evaluate children first and add up all matrices. Stop if there's one
+       // term that is not a matrix.
+       epvector *s = new epvector;
+       s->reserve(seq.size());
+
+       bool all_matrices = true;
+       bool first_term = true;
+       matrix sum;
+
+       epvector::const_iterator it = seq.begin(), itend = seq.end();
+       while (it != itend) {
+               const ex &m = recombine_pair_to_ex(*it).evalm();
+               s->push_back(split_ex_to_pair(m));
+               if (is_ex_of_type(m, matrix)) {
+                       if (first_term) {
+                               sum = ex_to<matrix>(m);
+                               first_term = false;
+                       } else
+                               sum = sum.add(ex_to<matrix>(m));
+               } else
+                       all_matrices = false;
+               it++;
+       }
+
+       if (all_matrices)
+               return sum + overall_coeff;
+       else
+               return (new add(s, overall_coeff))->setflag(status_flags::dynallocated);
+}
+
 ex add::simplify_ncmul(const exvector & v) const
 {
        if (seq.size()==0) {
@@ -395,7 +426,7 @@ ex add::thisexpairseq(epvector * vp, const ex & oc) const
 expair add::split_ex_to_pair(const ex & e) const
 {
        if (is_ex_exactly_of_type(e,mul)) {
-               const mul &mulref = ex_to_mul(e);
+               const mul &mulref(ex_to<mul>(e));
                ex numfactor = mulref.overall_coeff;
                mul *mulcopyp = new mul(mulref);
                mulcopyp->overall_coeff = _ex1();
@@ -412,7 +443,7 @@ expair add::combine_ex_with_coeff_to_pair(const ex & e,
 {
        GINAC_ASSERT(is_ex_exactly_of_type(c, numeric));
        if (is_ex_exactly_of_type(e, mul)) {
-               const mul &mulref = ex_to_mul(e);
+               const mul &mulref(ex_to<mul>(e));
                ex numfactor = mulref.overall_coeff;
                mul *mulcopyp = new mul(mulref);
                mulcopyp->overall_coeff = _ex1();
@@ -424,11 +455,11 @@ expair add::combine_ex_with_coeff_to_pair(const ex & e,
                else if (are_ex_trivially_equal(numfactor, _ex1()))
                        return expair(*mulcopyp, c);
                else
-                       return expair(*mulcopyp, ex_to_numeric(numfactor).mul_dyn(ex_to_numeric(c)));
+                       return expair(*mulcopyp, ex_to<numeric>(numfactor).mul_dyn(ex_to<numeric>(c)));
        } else if (is_ex_exactly_of_type(e, numeric)) {
                if (are_ex_trivially_equal(c, _ex1()))
                        return expair(e, _ex1());
-               return expair(ex_to_numeric(e).mul_dyn(ex_to_numeric(c)), _ex1());
+               return expair(ex_to<numeric>(e).mul_dyn(ex_to<numeric>(c)), _ex1());
        }
        return expair(e, c);
 }
@@ -440,16 +471,16 @@ expair add::combine_pair_with_coeff_to_pair(const expair & p,
        GINAC_ASSERT(is_ex_exactly_of_type(c,numeric));
 
        if (is_ex_exactly_of_type(p.rest,numeric)) {
-               GINAC_ASSERT(ex_to_numeric(p.coeff).is_equal(_num1())); // should be normalized
-               return expair(ex_to_numeric(p.rest).mul_dyn(ex_to_numeric(c)),_ex1());
+               GINAC_ASSERT(ex_to<numeric>(p.coeff).is_equal(_num1())); // should be normalized
+               return expair(ex_to<numeric>(p.rest).mul_dyn(ex_to<numeric>(c)),_ex1());
        }
 
-       return expair(p.rest,ex_to_numeric(p.coeff).mul_dyn(ex_to_numeric(c)));
+       return expair(p.rest,ex_to<numeric>(p.coeff).mul_dyn(ex_to<numeric>(c)));
 }
        
 ex add::recombine_pair_to_ex(const expair & p) const
 {
-       if (ex_to_numeric(p.coeff).is_equal(_num1()))
+       if (ex_to<numeric>(p.coeff).is_equal(_num1()))
                return p.rest;
        else
                return p.rest*p.coeff;
@@ -470,12 +501,4 @@ ex add::expand(unsigned options) const
        return (new add(vp,overall_coeff))->setflag(status_flags::expanded | status_flags::dynallocated);
 }
 
-//////////
-// static member variables
-//////////
-
-// protected
-
-unsigned add::precedence = 40;
-
 } // namespace GiNaC