]> www.ginac.de Git - ginac.git/blobdiff - ginac/expairseq.cpp
Index renaming issues, sped up simplify_indexed, used defined NC-objects
[ginac.git] / ginac / expairseq.cpp
index 7555bd185d422c011beb9a86f5de8e6ea5b61666..9f580d9e768ff4ac58c5fcaaa0865124ad7069fd 100644 (file)
@@ -34,6 +34,7 @@
 #include "archive.h"
 #include "operators.h"
 #include "utils.h"
+#include "indexed.h"
 
 #if EXPAIRSEQ_USE_HASHTAB
 #include <cmath>
@@ -757,8 +758,15 @@ void expairseq::construct_from_2_ex(const ex &lh, const ex &rh)
                                construct_from_2_ex_via_exvector(lh,rh);
                        } else {
 #endif // EXPAIRSEQ_USE_HASHTAB
-                               construct_from_2_expairseq(ex_to<expairseq>(lh),
-                                                          ex_to<expairseq>(rh));
+                               if(is_a<mul>(lh))
+                               {       
+                                       ex newrh=rename_dummy_indices_uniquely(lh, rh);
+                                       construct_from_2_expairseq(ex_to<expairseq>(lh),
+                                                                  ex_to<expairseq>(newrh));
+                               }
+                               else
+                                       construct_from_2_expairseq(ex_to<expairseq>(lh),
+                                                                  ex_to<expairseq>(rh));
 #if EXPAIRSEQ_USE_HASHTAB
                        }
 #endif // EXPAIRSEQ_USE_HASHTAB
@@ -1008,13 +1016,27 @@ void expairseq::make_flat(const exvector &v)
        seq.reserve(v.size()+noperands-nexpairseqs);
        
        // copy elements and split off numerical part
+       exvector dummy_indices;
        cit = v.begin();
        while (cit!=v.end()) {
                if (ex_to<basic>(*cit).tinfo()==this->tinfo()) {
-                       const expairseq &subseqref = ex_to<expairseq>(*cit);
-                       combine_overall_coeff(subseqref.overall_coeff);
-                       epvector::const_iterator cit_s = subseqref.seq.begin();
-                       while (cit_s!=subseqref.seq.end()) {
+                       const expairseq *subseqref;
+                       ex newfactor;
+                       if(is_a<mul>(*cit))
+                       {
+                               exvector dummies_of_factor = get_all_dummy_indices(*cit);
+                               sort(dummies_of_factor.begin(), dummies_of_factor.end(), ex_is_less());
+                               newfactor = rename_dummy_indices_uniquely(dummy_indices, dummies_of_factor, *cit);
+                               subseqref = &(ex_to<expairseq>(newfactor));
+                               exvector new_dummy_indices;
+                               set_union(dummy_indices.begin(), dummy_indices.end(), dummies_of_factor.begin(), dummies_of_factor.end(), std::back_insert_iterator<exvector>(new_dummy_indices), ex_is_less());
+                               dummy_indices.swap(new_dummy_indices);
+                       }
+                       else
+                               subseqref = &ex_to<expairseq>(*cit);
+                       combine_overall_coeff(subseqref->overall_coeff);
+                       epvector::const_iterator cit_s = subseqref->seq.begin();
+                       while (cit_s!=subseqref->seq.end()) {
                                seq.push_back(*cit_s);
                                ++cit_s;
                        }
@@ -1579,6 +1601,67 @@ std::auto_ptr<epvector> expairseq::evalchildren(int level) const
        return std::auto_ptr<epvector>(0); // signalling nothing has changed
 }
 
+class safe_inserter
+{
+       public:
+               safe_inserter(const ex&, const bool disable_renaming=false);
+               std::auto_ptr<epvector> getseq(){return epv;}
+               void insert_old_pair(const expair &p)
+               {
+                       epv->push_back(p);
+               }
+               void insert_new_pair(const expair &p, const ex &orig_ex);
+       private:
+               std::auto_ptr<epvector> epv;
+               bool dodummies;
+               exvector dummy_indices;
+               void update_dummy_indices(const exvector&);
+};
+
+safe_inserter::safe_inserter(const ex&e, const bool disable_renaming)
+               :epv(new epvector)
+{
+       epv->reserve(e.nops());
+       dodummies=is_a<mul>(e);
+       if(disable_renaming)
+               dodummies=false;
+       if(dodummies) {
+               dummy_indices = get_all_dummy_indices(e);
+               sort(dummy_indices.begin(), dummy_indices.end(), ex_is_less());
+       }
+}
+
+void safe_inserter::update_dummy_indices(const exvector &v)
+{
+       exvector new_dummy_indices;
+       set_union(dummy_indices.begin(), dummy_indices.end(), v.begin(), v.end(),
+               std::back_insert_iterator<exvector>(new_dummy_indices), ex_is_less());
+       dummy_indices.swap(new_dummy_indices);
+}
+
+void safe_inserter::insert_new_pair(const expair &p, const ex &orig_ex)
+{
+       if(!dodummies) {
+               epv->push_back(p);
+               return;
+       }
+       exvector dummies_of_factor = get_all_dummy_indices(p.rest);
+       if(dummies_of_factor.size() == 0) {
+               epv->push_back(p);
+               return;
+       }
+       sort(dummies_of_factor.begin(), dummies_of_factor.end(), ex_is_less());
+       exvector dummies_of_orig_ex = get_all_dummy_indices(orig_ex);
+       sort(dummies_of_orig_ex.begin(), dummies_of_orig_ex.end(), ex_is_less());
+       exvector new_dummy_indices;
+       new_dummy_indices.reserve(dummy_indices.size());
+       set_difference(dummy_indices.begin(), dummy_indices.end(), dummies_of_orig_ex.begin(), dummies_of_orig_ex.end(),
+               std::back_insert_iterator<exvector>(new_dummy_indices), ex_is_less());
+       dummy_indices.swap(new_dummy_indices);
+       ex newfactor = rename_dummy_indices_uniquely(dummy_indices, dummies_of_factor, p.rest);
+       update_dummy_indices(dummies_of_factor);
+       epv -> push_back(expair(newfactor, p.coeff));
+}
 
 /** Member-wise substitute in this sequence.
  *
@@ -1614,22 +1697,27 @@ std::auto_ptr<epvector> expairseq::subschildren(const exmap & m, unsigned option
                        if (!are_ex_trivially_equal(orig_ex, subsed_ex)) {
 
                                // Something changed, copy seq, subs and return it
-                               std::auto_ptr<epvector> s(new epvector);
-                               s->reserve(seq.size());
+                               safe_inserter s(*this, options & subs_options::no_index_renaming);
 
                                // Copy parts of seq which are known not to have changed
-                               s->insert(s->begin(), seq.begin(), cit);
+                               for(epvector::const_iterator i=seq.begin(); i!=cit; ++i)
+                                       s.insert_old_pair(*i);
 
                                // Copy first changed element
-                               s->push_back(split_ex_to_pair(subsed_ex));
+                               s.insert_new_pair(split_ex_to_pair(subsed_ex), orig_ex);
                                ++cit;
 
                                // Copy rest
                                while (cit != last) {
-                                       s->push_back(split_ex_to_pair(recombine_pair_to_ex(*cit).subs(m, options)));
+                                       ex orig_ex = recombine_pair_to_ex(*cit);
+                                       ex subsed_ex = orig_ex.subs(m, options);
+                                       if(are_ex_trivially_equal(orig_ex, subsed_ex))
+                                               s.insert_old_pair(*cit);
+                                       else
+                                               s.insert_new_pair(split_ex_to_pair(subsed_ex), orig_ex);
                                        ++cit;
                                }
-                               return s;
+                               return s.getseq();
                        }
 
                        ++cit;
@@ -1645,23 +1733,27 @@ std::auto_ptr<epvector> expairseq::subschildren(const exmap & m, unsigned option
                        if (!are_ex_trivially_equal(cit->rest, subsed_ex)) {
                        
                                // Something changed, copy seq, subs and return it
-                               std::auto_ptr<epvector> s(new epvector);
-                               s->reserve(seq.size());
+                               safe_inserter s(*this, options & subs_options::no_index_renaming);
 
                                // Copy parts of seq which are known not to have changed
-                               s->insert(s->begin(), seq.begin(), cit);
+                               for(epvector::const_iterator i=seq.begin(); i!=cit; ++i)
+                                       s.insert_old_pair(*i);
                        
                                // Copy first changed element
-                               s->push_back(combine_ex_with_coeff_to_pair(subsed_ex, cit->coeff));
+                               s.insert_new_pair(combine_ex_with_coeff_to_pair(subsed_ex, cit->coeff), cit->rest);
                                ++cit;
 
                                // Copy rest
                                while (cit != last) {
-                                       s->push_back(combine_ex_with_coeff_to_pair(cit->rest.subs(m, options),
-                                                                                  cit->coeff));
+                                       const ex &orig_ex = cit->rest;
+                                       const ex &subsed_ex = cit->rest.subs(m, options);
+                                       if(are_ex_trivially_equal(orig_ex, subsed_ex))
+                                               s.insert_old_pair(*cit);
+                                       else
+                                               s.insert_new_pair(combine_ex_with_coeff_to_pair(subsed_ex, cit->coeff), orig_ex);
                                        ++cit;
                                }
-                               return s;
+                               return s.getseq();
                        }
 
                        ++cit;