]> www.ginac.de Git - ginac.git/blobdiff - ginac/expairseq.cpp
parser: add necessary checks to operator() to stop accepting nonsense.
[ginac.git] / ginac / expairseq.cpp
index 3fb2d991286a97ecbdb454ef4eaacbe85befffb8..67aa4f8ef681bf3453eb83fecc0181076edba70e 100644 (file)
@@ -3,7 +3,7 @@
  *  Implementation of sequences of expression pairs. */
 
 /*
- *  GiNaC Copyright (C) 1999-2006 Johannes Gutenberg University Mainz, Germany
+ *  GiNaC Copyright (C) 1999-2008 Johannes Gutenberg University Mainz, Germany
  *
  *  This program is free software; you can redistribute it and/or modify
  *  it under the terms of the GNU General Public License as published by
@@ -24,6 +24,7 @@
 #include <algorithm>
 #include <string>
 #include <stdexcept>
+#include <iterator>
 
 #include "expairseq.h"
 #include "lst.h"
@@ -267,6 +268,26 @@ void expairseq::do_print_tree(const print_tree & c, unsigned level) const
 
 bool expairseq::info(unsigned inf) const
 {
+       switch(inf) {
+               case info_flags::expanded:
+                       return (flags & status_flags::expanded);
+               case info_flags::has_indices: {
+                       if (flags & status_flags::has_indices)
+                               return true;
+                       else if (flags & status_flags::has_no_indices)
+                               return false;
+                       for (epvector::const_iterator i = seq.begin(); i != seq.end(); ++i) {
+                               if (i->rest.info(info_flags::has_indices)) {
+                                       this->setflag(status_flags::has_indices);
+                                       this->clearflag(status_flags::has_no_indices);
+                                       return true;
+                               }
+                       }
+                       this->clearflag(status_flags::has_indices);
+                       this->setflag(status_flags::has_no_indices);
+                       return false;
+               }
+       }
        return inherited::info(inf);
 }
 
@@ -370,7 +391,7 @@ bool expairseq::is_polynomial(const ex & var) const
        return true;
 }
 
-bool expairseq::match(const ex & pattern, lst & repl_lst) const
+bool expairseq::match(const ex & pattern, exmap & repl_lst) const
 {
        // This differs from basic::match() because we want "a+b+c+d" to
        // match "d+*+b" with "*" being "a+c", and we want to honor commutativity
@@ -406,20 +427,10 @@ bool expairseq::match(const ex & pattern, lst & repl_lst) const
                                continue;
                        exvector::iterator it = ops.begin(), itend = ops.end();
                        while (it != itend) {
-                               lst::const_iterator last_el = repl_lst.end();
-                               --last_el;
                                if (it->match(p, repl_lst)) {
                                        ops.erase(it);
                                        goto found;
                                }
-                               while(true) {
-                                       lst::const_iterator next_el = last_el;
-                                       ++next_el;
-                                       if(next_el == repl_lst.end())
-                                               break;
-                                       else
-                                               repl_lst.remove_last();
-                               }
                                ++it;
                        }
                        return false; // no match found
@@ -437,11 +448,11 @@ found:            ;
                        for (size_t i=0; i<num; i++)
                                vp->push_back(split_ex_to_pair(ops[i]));
                        ex rest = thisexpairseq(vp, default_overall_coeff());
-                       for (lst::const_iterator it = repl_lst.begin(); it != repl_lst.end(); ++it) {
-                               if (it->op(0).is_equal(global_wildcard))
-                                       return rest.is_equal(it->op(1));
+                       for (exmap::const_iterator it = repl_lst.begin(); it != repl_lst.end(); ++it) {
+                               if (it->first.is_equal(global_wildcard))
+                                       return rest.is_equal(it->second);
                        }
-                       repl_lst.append(global_wildcard == rest);
+                       repl_lst[global_wildcard] = rest;
                        return true;
 
                } else {
@@ -791,8 +802,8 @@ void expairseq::construct_from_2_ex(const ex &lh, const ex &rh)
                                construct_from_2_ex_via_exvector(lh,rh);
                        } else {
 #endif // EXPAIRSEQ_USE_HASHTAB
-                               if(is_a<mul>(lh))
-                               {
+                               if (is_a<mul>(lh) && lh.info(info_flags::has_indices) && 
+                                       rh.info(info_flags::has_indices)) {
                                        ex newrh=rename_dummy_indices_uniquely(lh, rh);
                                        construct_from_2_expairseq(ex_to<expairseq>(lh),
                                                                   ex_to<expairseq>(newrh));
@@ -1036,6 +1047,7 @@ void expairseq::make_flat(const exvector &v)
        // and their cumulative number of operands
        int nexpairseqs = 0;
        int noperands = 0;
+       bool do_idx_rename = false;
        
        cit = v.begin();
        while (cit!=v.end()) {
@@ -1043,6 +1055,9 @@ void expairseq::make_flat(const exvector &v)
                        ++nexpairseqs;
                        noperands += ex_to<expairseq>(*cit).seq.size();
                }
+               if (is_a<mul>(*this) && (!do_idx_rename) &&
+                               cit->info(info_flags::has_indices))
+                       do_idx_rename = true;
                ++cit;
        }
        
@@ -1050,7 +1065,7 @@ void expairseq::make_flat(const exvector &v)
        seq.reserve(v.size()+noperands-nexpairseqs);
        
        // copy elements and split off numerical part
-       make_flat_inserter mf(v, this->tinfo() == &mul::tinfo_static);
+       make_flat_inserter mf(v, do_idx_rename);
        cit = v.begin();
        while (cit!=v.end()) {
                if (ex_to<basic>(*cit).tinfo()==this->tinfo()) {
@@ -1084,6 +1099,7 @@ void expairseq::make_flat(const epvector &v, bool do_index_renaming)
        // and their cumulative number of operands
        int nexpairseqs = 0;
        int noperands = 0;
+       bool really_need_rename_inds = false;
        
        cit = v.begin();
        while (cit!=v.end()) {
@@ -1091,8 +1107,12 @@ void expairseq::make_flat(const epvector &v, bool do_index_renaming)
                        ++nexpairseqs;
                        noperands += ex_to<expairseq>(cit->rest).seq.size();
                }
+               if ((!really_need_rename_inds) && is_a<mul>(*this) &&
+                               cit->rest.info(info_flags::has_indices))
+                       really_need_rename_inds = true;
                ++cit;
        }
+       do_index_renaming = do_index_renaming && really_need_rename_inds;
        
        // reserve seq and coeffseq which will hold all operands
        seq.reserve(v.size()+noperands-nexpairseqs);