]> www.ginac.de Git - ginac.git/blobdiff - ginac/mul.cpp
bumped archive version because of new way of specifying symmetry for indexed
[ginac.git] / ginac / mul.cpp
index 8e641caf2ddec551ccfd6f86308649b5690b9d06..04941b3da270398c32db690cbcfb5812c4bfe63c 100644 (file)
@@ -26,6 +26,7 @@
 #include "mul.h"
 #include "add.h"
 #include "power.h"
+#include "matrix.h"
 #include "archive.h"
 #include "debugmsg.h"
 #include "utils.h"
@@ -130,11 +131,11 @@ void mul::print(const print_context & c, unsigned level) const
 {
        debugmsg("mul print", LOGLEVEL_PRINT);
 
-       if (is_of_type(c, print_tree)) {
+       if (is_a<print_tree>(c)) {
 
                inherited::print(c, level);
 
-       } else if (is_of_type(c, print_csrc)) {
+       } else if (is_a<print_csrc>(c)) {
 
                if (precedence() <= level)
                        c.s << "(";
@@ -143,14 +144,14 @@ void mul::print(const print_context & c, unsigned level) const
                        overall_coeff.bp->print(c, precedence());
                        c.s << "*";
                }
-       
+
                // Print arguments, separated by "*" or "/"
                epvector::const_iterator it = seq.begin(), itend = seq.end();
                while (it != itend) {
 
                        // If the first argument is a negative integer power, it gets printed as "1.0/<expr>"
-                       if (it == seq.begin() && ex_to_numeric(it->coeff).is_integer() && it->coeff.compare(_num0()) < 0) {
-                               if (is_of_type(c, print_csrc_cl_N))
+                       if (it == seq.begin() && ex_to<numeric>(it->coeff).is_integer() && it->coeff.compare(_num0()) < 0) {
+                               if (is_a<print_csrc_cl_N>(c))
                                        c.s << "recip(";
                                else
                                        c.s << "1.0/";
@@ -161,13 +162,13 @@ void mul::print(const print_context & c, unsigned level) const
                                it->rest.print(c, precedence());
                        else {
                                // Outer parens around ex needed for broken gcc-2.95 parser:
-                               (ex(power(it->rest, abs(ex_to_numeric(it->coeff))))).print(c, level);
+                               (ex(power(it->rest, abs(ex_to<numeric>(it->coeff))))).print(c, level);
                        }
 
                        // Separator is "/" for negative integer powers, "*" otherwise
                        ++it;
                        if (it != itend) {
-                               if (ex_to_numeric(it->coeff).is_integer() && it->coeff.compare(_num0()) < 0)
+                               if (ex_to<numeric>(it->coeff).is_integer() && it->coeff.compare(_num0()) < 0)
                                        c.s << "/";
                                else
                                        c.s << "*";
@@ -180,7 +181,7 @@ void mul::print(const print_context & c, unsigned level) const
        } else {
 
                if (precedence() <= level) {
-                       if (is_of_type(c, print_latex))
+                       if (is_a<print_latex>(c))
                                c.s << "{(";
                        else
                                c.s << "(";
@@ -189,7 +190,7 @@ void mul::print(const print_context & c, unsigned level) const
                bool first = true;
 
                // First print the overall numeric coefficient
-               numeric coeff = ex_to_numeric(overall_coeff);
+               numeric coeff = ex_to<numeric>(overall_coeff);
                if (coeff.csgn() == -1)
                        c.s << '-';
                if (!coeff.is_equal(_num1()) &&
@@ -205,7 +206,7 @@ void mul::print(const print_context & c, unsigned level) const
                                else
                                        coeff.print(c, precedence());
                        }
-                       if (is_of_type(c, print_latex))
+                       if (is_a<print_latex>(c))
                                c.s << ' ';
                        else
                                c.s << '*';
@@ -215,7 +216,7 @@ void mul::print(const print_context & c, unsigned level) const
                epvector::const_iterator it = seq.begin(), itend = seq.end();
                while (it != itend) {
                        if (!first) {
-                               if (is_of_type(c, print_latex))
+                               if (is_a<print_latex>(c))
                                        c.s << ' ';
                                else
                                        c.s << '*';
@@ -227,7 +228,7 @@ void mul::print(const print_context & c, unsigned level) const
                }
 
                if (precedence() <= level) {
-                       if (is_of_type(c, print_latex))
+                       if (is_a<print_latex>(c))
                                c.s << ")}";
                        else
                                c.s << ")";
@@ -265,8 +266,8 @@ int mul::degree(const ex & s) const
 {
        int deg_sum = 0;
        for (epvector::const_iterator cit=seq.begin(); cit!=seq.end(); ++cit) {
-               if (ex_to_numeric(cit->coeff).is_integer())
-                       deg_sum+=cit->rest.degree(s) * ex_to_numeric(cit->coeff).to_int();
+               if (ex_to<numeric>(cit->coeff).is_integer())
+                       deg_sum+=cit->rest.degree(s) * ex_to<numeric>(cit->coeff).to_int();
        }
        return deg_sum;
 }
@@ -275,8 +276,8 @@ int mul::ldegree(const ex & s) const
 {
        int deg_sum = 0;
        for (epvector::const_iterator cit=seq.begin(); cit!=seq.end(); ++cit) {
-               if (ex_to_numeric(cit->coeff).is_integer())
-                       deg_sum+=cit->rest.ldegree(s) * ex_to_numeric(cit->coeff).to_int();
+               if (ex_to<numeric>(cit->coeff).is_integer())
+                       deg_sum+=cit->rest.ldegree(s) * ex_to<numeric>(cit->coeff).to_int();
        }
        return deg_sum;
 }
@@ -338,7 +339,7 @@ ex mul::eval(int level) const
 #ifdef DO_GINAC_ASSERT
        for (epvector::const_iterator cit=seq.begin(); cit!=seq.end(); ++cit) {
                GINAC_ASSERT((!is_ex_exactly_of_type((*cit).rest,mul)) ||
-                            (!(ex_to_numeric((*cit).coeff).is_integer())));
+                            (!(ex_to<numeric>((*cit).coeff).is_integer())));
                GINAC_ASSERT(!(cit->is_canonical_numeric()));
                if (is_ex_exactly_of_type(recombine_pair_to_ex(*cit),numeric))
                    print(print_tree(std::cerr));
@@ -369,17 +370,17 @@ ex mul::eval(int level) const
                return recombine_pair_to_ex(*(seq.begin()));
        } else if ((seq_size==1) &&
                   is_ex_exactly_of_type((*seq.begin()).rest,add) &&
-                  ex_to_numeric((*seq.begin()).coeff).is_equal(_num1())) {
+                  ex_to<numeric>((*seq.begin()).coeff).is_equal(_num1())) {
                // *(+(x,y,...);c) -> +(*(x,c),*(y,c),...) (c numeric(), no powers of +())
-               const add & addref = ex_to_add((*seq.begin()).rest);
+               const add & addref = ex_to<add>((*seq.begin()).rest);
                epvector distrseq;
                distrseq.reserve(addref.seq.size());
                for (epvector::const_iterator cit=addref.seq.begin(); cit!=addref.seq.end(); ++cit) {
                        distrseq.push_back(addref.combine_pair_with_coeff_to_pair(*cit, overall_coeff));
                }
                return (new add(distrseq,
-                               ex_to_numeric(addref.overall_coeff).
-                               mul_dyn(ex_to_numeric(overall_coeff))))
+                               ex_to<numeric>(addref.overall_coeff).
+                               mul_dyn(ex_to<numeric>(overall_coeff))))
                      ->setflag(status_flags::dynallocated | status_flags::evaluated);
        }
        return this->hold();
@@ -404,6 +405,46 @@ ex mul::evalf(int level) const
        return mul(s,overall_coeff.evalf(level));
 }
 
+ex mul::evalm(void) const
+{
+       // numeric*matrix
+       if (seq.size() == 1 && seq[0].coeff.is_equal(_ex1())
+        && is_ex_of_type(seq[0].rest, matrix))
+               return ex_to<matrix>(seq[0].rest).mul(ex_to<numeric>(overall_coeff));
+
+       // Evaluate children first, look whether there are any matrices at all
+       // (there can be either no matrices or one matrix; if there were more
+       // than one matrix, it would be a non-commutative product)
+       epvector *s = new epvector;
+       s->reserve(seq.size());
+
+       bool have_matrix = false;
+       epvector::iterator the_matrix;
+
+       epvector::const_iterator it = seq.begin(), itend = seq.end();
+       while (it != itend) {
+               const ex &m = recombine_pair_to_ex(*it).evalm();
+               s->push_back(split_ex_to_pair(m));
+               if (is_ex_of_type(m, matrix)) {
+                       have_matrix = true;
+                       the_matrix = s->end() - 1;
+               }
+               it++;
+       }
+
+       if (have_matrix) {
+
+               // The product contained a matrix. We will multiply all other factors
+               // into that matrix.
+               matrix m = ex_to<matrix>(the_matrix->rest);
+               s->erase(the_matrix);
+               ex scalar = (new mul(s, overall_coeff))->setflag(status_flags::dynallocated);
+               return m.mul_scalar(scalar);
+
+       } else
+               return (new mul(s, overall_coeff))->setflag(status_flags::dynallocated);
+}
+
 ex mul::simplify_ncmul(const exvector & v) const
 {
        if (seq.size()==0) {
@@ -505,7 +546,7 @@ ex mul::thisexpairseq(epvector * vp, const ex & oc) const
 expair mul::split_ex_to_pair(const ex & e) const
 {
        if (is_ex_exactly_of_type(e,power)) {
-               const power & powerref = ex_to_power(e);
+               const power & powerref = ex_to<power>(e);
                if (is_ex_exactly_of_type(powerref.exponent,numeric))
                        return expair(powerref.basis,powerref.exponent);
        }
@@ -540,7 +581,7 @@ expair mul::combine_pair_with_coeff_to_pair(const expair & p,
        
 ex mul::recombine_pair_to_ex(const expair & p) const
 {
-       if (ex_to_numeric(p.coeff).is_equal(_num1())) 
+       if (ex_to<numeric>(p.coeff).is_equal(_num1())) 
                return p.rest;
        else
                return power(p.rest,p.coeff);
@@ -549,7 +590,7 @@ ex mul::recombine_pair_to_ex(const expair & p) const
 bool mul::expair_needs_further_processing(epp it)
 {
        if (is_ex_exactly_of_type((*it).rest,mul) &&
-               ex_to_numeric((*it).coeff).is_integer()) {
+               ex_to<numeric>((*it).coeff).is_integer()) {
                // combined pair is product with integer power -> expand it
                *it = split_ex_to_pair(recombine_pair_to_ex(*it));
                return true;
@@ -561,7 +602,7 @@ bool mul::expair_needs_further_processing(epp it)
                        *it = ep;
                        return true;
                }
-               if (ex_to_numeric((*it).coeff).is_equal(_num1())) {
+               if (ex_to<numeric>((*it).coeff).is_equal(_num1())) {
                        // combined pair has coeff 1 and must be moved to the end
                        return true;
                }
@@ -578,7 +619,7 @@ void mul::combine_overall_coeff(const ex & c)
 {
        GINAC_ASSERT(is_ex_exactly_of_type(overall_coeff,numeric));
        GINAC_ASSERT(is_ex_exactly_of_type(c,numeric));
-       overall_coeff = ex_to_numeric(overall_coeff).mul_dyn(ex_to_numeric(c));
+       overall_coeff = ex_to<numeric>(overall_coeff).mul_dyn(ex_to<numeric>(c));
 }
 
 void mul::combine_overall_coeff(const ex & c1, const ex & c2)
@@ -586,7 +627,7 @@ void mul::combine_overall_coeff(const ex & c1, const ex & c2)
        GINAC_ASSERT(is_ex_exactly_of_type(overall_coeff,numeric));
        GINAC_ASSERT(is_ex_exactly_of_type(c1,numeric));
        GINAC_ASSERT(is_ex_exactly_of_type(c2,numeric));
-       overall_coeff = ex_to_numeric(overall_coeff).mul_dyn(ex_to_numeric(c1).power(ex_to_numeric(c2)));
+       overall_coeff = ex_to<numeric>(overall_coeff).mul_dyn(ex_to<numeric>(c1).power(ex_to<numeric>(c2)));
 }
 
 bool mul::can_make_flat(const expair & p) const
@@ -595,7 +636,7 @@ bool mul::can_make_flat(const expair & p) const
        // this assertion will probably fail somewhere
        // it would require a more careful make_flat, obeying the power laws
        // probably should return true only if p.coeff is integer
-       return ex_to_numeric(p.coeff).is_equal(_num1());
+       return ex_to<numeric>(p.coeff).is_equal(_num1());
 }
 
 ex mul::expand(unsigned options) const
@@ -621,8 +662,8 @@ ex mul::expand(unsigned options) const
                        ++number_of_adds;
                        if (is_ex_exactly_of_type(last_expanded,add)) {
                                // expand adds
-                               const add & add1 = ex_to_add(last_expanded);
-                               const add & add2 = ex_to_add((*cit).rest);
+                               const add & add1 = ex_to<add>(last_expanded);
+                               const add & add2 = ex_to<add>((*cit).rest);
                                int n1 = add1.nops();
                                int n2 = add2.nops();
                                exvector distrseq;
@@ -646,7 +687,7 @@ ex mul::expand(unsigned options) const
                delete expanded_seqp;
 
        if (is_ex_exactly_of_type(last_expanded,add)) {
-               add const & finaladd = ex_to_add(last_expanded);
+               add const & finaladd = ex_to<add>(last_expanded);
                exvector distrseq;
                int n = finaladd.nops();
                distrseq.reserve(n);