]> www.ginac.de Git - ginac.git/blobdiff - ginac/add.cpp
- do something about the mad cast disease.
[ginac.git] / ginac / add.cpp
index 92cb325fe04c00db04e0875c55774c9ed7856f38..647de1722c7065c2d735f7093b187880d42c4203 100644 (file)
@@ -25,6 +25,7 @@
 
 #include "add.h"
 #include "mul.h"
+#include "matrix.h"
 #include "archive.h"
 #include "debugmsg.h"
 #include "utils.h"
@@ -115,11 +116,11 @@ void add::print(const print_context & c, unsigned level) const
 {
        debugmsg("add print", LOGLEVEL_PRINT);
 
-       if (is_of_type(c, print_tree)) {
+       if (is_a<print_tree>(c)) {
 
                inherited::print(c, level);
 
-       } else if (is_of_type(c, print_csrc)) {
+       } else if (is_a<print_csrc>(c)) {
 
                if (precedence() <= level)
                        c.s << "(";
@@ -134,15 +135,15 @@ void add::print(const print_context & c, unsigned level) const
                        } else if (it->coeff.compare(_num_1()) == 0) {
                                c.s << "-";
                                it->rest.bp->print(c, precedence());
-                       } else if (ex_to_numeric(it->coeff).numer().compare(_num1()) == 0) {
+                       } else if (ex_to<numeric>(it->coeff).numer().compare(_num1()) == 0) {
                                it->rest.bp->print(c, precedence());
                                c.s << "/";
-                               ex_to_numeric(it->coeff).denom().print(c, precedence());
-                       } else if (ex_to_numeric(it->coeff).numer().compare(_num_1()) == 0) {
+                               ex_to<numeric>(it->coeff).denom().print(c, precedence());
+                       } else if (ex_to<numeric>(it->coeff).numer().compare(_num_1()) == 0) {
                                c.s << "-";
                                it->rest.bp->print(c, precedence());
                                c.s << "/";
-                               ex_to_numeric(it->coeff).denom().print(c, precedence());
+                               ex_to<numeric>(it->coeff).denom().print(c, precedence());
                        } else {
                                it->coeff.bp->print(c, precedence());
                                c.s << "*";
@@ -151,7 +152,7 @@ void add::print(const print_context & c, unsigned level) const
                
                        // Separator is "+", except if the following expression would have a leading minus sign
                        it++;
-                       if (it != itend && !(it->coeff.compare(_num0()) < 0 || (it->coeff.compare(_num1()) == 0 && is_ex_exactly_of_type(it->rest, numeric) && it->rest.compare(_num0()) < 0)))
+                       if (it != itend && !(it->coeff.compare(_num0()) < 0 || (it->coeff.compare(_num1()) == 0 && is_exactly_a<numeric>(it->rest) && it->rest.compare(_num0()) < 0)))
                                c.s << "+";
                }
        
@@ -167,7 +168,7 @@ void add::print(const print_context & c, unsigned level) const
        } else {
 
                if (precedence() <= level) {
-                       if (is_of_type(c, print_latex))
+                       if (is_a<print_latex>(c))
                                c.s << "{(";
                        else
                                c.s << "(";
@@ -178,7 +179,7 @@ void add::print(const print_context & c, unsigned level) const
 
                // First print the overall numeric coefficient, if present
                if (!overall_coeff.is_zero()) {
-                       if (!is_of_type(c, print_tree))
+                       if (!is_a<print_tree>(c))
                                overall_coeff.print(c, 0);
                        else
                                overall_coeff.print(c, precedence());
@@ -188,7 +189,7 @@ void add::print(const print_context & c, unsigned level) const
                // Then proceed with the remaining factors
                epvector::const_iterator it = seq.begin(), itend = seq.end();
                while (it != itend) {
-                       coeff = ex_to_numeric(it->coeff);
+                       coeff = ex_to<numeric>(it->coeff);
                        if (!first) {
                                if (coeff.csgn() == -1) c.s << '-'; else c.s << '+';
                        } else {
@@ -208,7 +209,7 @@ void add::print(const print_context & c, unsigned level) const
                                        else
                                                coeff.print(c, precedence());
                                }
-                               if (is_of_type(c, print_latex))
+                               if (is_a<print_latex>(c))
                                        c.s << ' ';
                                else
                                        c.s << '*';
@@ -218,7 +219,7 @@ void add::print(const print_context & c, unsigned level) const
                }
 
                if (precedence() <= level) {
-                       if (is_of_type(c, print_latex))
+                       if (is_a<print_latex>(c))
                                c.s << ")}";
                        else
                                c.s << ")";
@@ -336,6 +337,38 @@ ex add::eval(int level) const
        return this->hold();
 }
 
+ex add::evalm(void) const
+{
+       // Evaluate children first and add up all matrices. Stop if there's one
+       // term that is not a matrix.
+       epvector *s = new epvector;
+       s->reserve(seq.size());
+
+       bool all_matrices = true;
+       bool first_term = true;
+       matrix sum;
+
+       epvector::const_iterator it = seq.begin(), itend = seq.end();
+       while (it != itend) {
+               const ex &m = recombine_pair_to_ex(*it).evalm();
+               s->push_back(split_ex_to_pair(m));
+               if (is_ex_of_type(m, matrix)) {
+                       if (first_term) {
+                               sum = ex_to<matrix>(m);
+                               first_term = false;
+                       } else
+                               sum = sum.add(ex_to<matrix>(m));
+               } else
+                       all_matrices = false;
+               it++;
+       }
+
+       if (all_matrices)
+               return sum + overall_coeff;
+       else
+               return (new add(s, overall_coeff))->setflag(status_flags::dynallocated);
+}
+
 ex add::simplify_ncmul(const exvector & v) const
 {
        if (seq.size()==0) {
@@ -393,7 +426,7 @@ ex add::thisexpairseq(epvector * vp, const ex & oc) const
 expair add::split_ex_to_pair(const ex & e) const
 {
        if (is_ex_exactly_of_type(e,mul)) {
-               const mul &mulref = ex_to_mul(e);
+               const mul &mulref(ex_to<mul>(e));
                ex numfactor = mulref.overall_coeff;
                mul *mulcopyp = new mul(mulref);
                mulcopyp->overall_coeff = _ex1();
@@ -410,7 +443,7 @@ expair add::combine_ex_with_coeff_to_pair(const ex & e,
 {
        GINAC_ASSERT(is_ex_exactly_of_type(c, numeric));
        if (is_ex_exactly_of_type(e, mul)) {
-               const mul &mulref = ex_to_mul(e);
+               const mul &mulref(ex_to<mul>(e));
                ex numfactor = mulref.overall_coeff;
                mul *mulcopyp = new mul(mulref);
                mulcopyp->overall_coeff = _ex1();
@@ -422,11 +455,11 @@ expair add::combine_ex_with_coeff_to_pair(const ex & e,
                else if (are_ex_trivially_equal(numfactor, _ex1()))
                        return expair(*mulcopyp, c);
                else
-                       return expair(*mulcopyp, ex_to_numeric(numfactor).mul_dyn(ex_to_numeric(c)));
+                       return expair(*mulcopyp, ex_to<numeric>(numfactor).mul_dyn(ex_to<numeric>(c)));
        } else if (is_ex_exactly_of_type(e, numeric)) {
                if (are_ex_trivially_equal(c, _ex1()))
                        return expair(e, _ex1());
-               return expair(ex_to_numeric(e).mul_dyn(ex_to_numeric(c)), _ex1());
+               return expair(ex_to<numeric>(e).mul_dyn(ex_to<numeric>(c)), _ex1());
        }
        return expair(e, c);
 }
@@ -438,16 +471,16 @@ expair add::combine_pair_with_coeff_to_pair(const expair & p,
        GINAC_ASSERT(is_ex_exactly_of_type(c,numeric));
 
        if (is_ex_exactly_of_type(p.rest,numeric)) {
-               GINAC_ASSERT(ex_to_numeric(p.coeff).is_equal(_num1())); // should be normalized
-               return expair(ex_to_numeric(p.rest).mul_dyn(ex_to_numeric(c)),_ex1());
+               GINAC_ASSERT(ex_to<numeric>(p.coeff).is_equal(_num1())); // should be normalized
+               return expair(ex_to<numeric>(p.rest).mul_dyn(ex_to<numeric>(c)),_ex1());
        }
 
-       return expair(p.rest,ex_to_numeric(p.coeff).mul_dyn(ex_to_numeric(c)));
+       return expair(p.rest,ex_to<numeric>(p.coeff).mul_dyn(ex_to<numeric>(c)));
 }
        
 ex add::recombine_pair_to_ex(const expair & p) const
 {
-       if (ex_to_numeric(p.coeff).is_equal(_num1()))
+       if (ex_to<numeric>(p.coeff).is_equal(_num1()))
                return p.rest;
        else
                return p.rest*p.coeff;