]> www.ginac.de Git - ginac.git/blobdiff - ginac/indexed.cpp
- New tinfo mechanism
[ginac.git] / ginac / indexed.cpp
index 64eba6d08cf1cccf6dbf455cbcce4c01356dd3e7..b8b70f7e9d5ae853649c85c709fb895182aaf91d 100644 (file)
@@ -3,7 +3,7 @@
  *  Implementation of GiNaC's indexed expressions. */
 
 /*
- *  GiNaC Copyright (C) 1999-2005 Johannes Gutenberg University Mainz, Germany
+ *  GiNaC Copyright (C) 1999-2006 Johannes Gutenberg University Mainz, Germany
  *
  *  This program is free software; you can redistribute it and/or modify
  *  it under the terms of the GNU General Public License as published by
@@ -53,7 +53,7 @@ GINAC_IMPLEMENT_REGISTERED_CLASS_OPT(indexed, exprseq,
 
 indexed::indexed() : symtree(not_symmetric())
 {
-       tinfo_key = TINFO_indexed;
+       tinfo_key = &indexed::tinfo_static;
 }
 
 //////////
@@ -62,79 +62,79 @@ indexed::indexed() : symtree(not_symmetric())
 
 indexed::indexed(const ex & b) : inherited(b), symtree(not_symmetric())
 {
-       tinfo_key = TINFO_indexed;
+       tinfo_key = &indexed::tinfo_static;
        validate();
 }
 
 indexed::indexed(const ex & b, const ex & i1) : inherited(b, i1), symtree(not_symmetric())
 {
-       tinfo_key = TINFO_indexed;
+       tinfo_key = &indexed::tinfo_static;
        validate();
 }
 
 indexed::indexed(const ex & b, const ex & i1, const ex & i2) : inherited(b, i1, i2), symtree(not_symmetric())
 {
-       tinfo_key = TINFO_indexed;
+       tinfo_key = &indexed::tinfo_static;
        validate();
 }
 
 indexed::indexed(const ex & b, const ex & i1, const ex & i2, const ex & i3) : inherited(b, i1, i2, i3), symtree(not_symmetric())
 {
-       tinfo_key = TINFO_indexed;
+       tinfo_key = &indexed::tinfo_static;
        validate();
 }
 
 indexed::indexed(const ex & b, const ex & i1, const ex & i2, const ex & i3, const ex & i4) : inherited(b, i1, i2, i3, i4), symtree(not_symmetric())
 {
-       tinfo_key = TINFO_indexed;
+       tinfo_key = &indexed::tinfo_static;
        validate();
 }
 
 indexed::indexed(const ex & b, const symmetry & symm, const ex & i1, const ex & i2) : inherited(b, i1, i2), symtree(symm)
 {
-       tinfo_key = TINFO_indexed;
+       tinfo_key = &indexed::tinfo_static;
        validate();
 }
 
 indexed::indexed(const ex & b, const symmetry & symm, const ex & i1, const ex & i2, const ex & i3) : inherited(b, i1, i2, i3), symtree(symm)
 {
-       tinfo_key = TINFO_indexed;
+       tinfo_key = &indexed::tinfo_static;
        validate();
 }
 
 indexed::indexed(const ex & b, const symmetry & symm, const ex & i1, const ex & i2, const ex & i3, const ex & i4) : inherited(b, i1, i2, i3, i4), symtree(symm)
 {
-       tinfo_key = TINFO_indexed;
+       tinfo_key = &indexed::tinfo_static;
        validate();
 }
 
 indexed::indexed(const ex & b, const exvector & v) : inherited(b), symtree(not_symmetric())
 {
        seq.insert(seq.end(), v.begin(), v.end());
-       tinfo_key = TINFO_indexed;
+       tinfo_key = &indexed::tinfo_static;
        validate();
 }
 
 indexed::indexed(const ex & b, const symmetry & symm, const exvector & v) : inherited(b), symtree(symm)
 {
        seq.insert(seq.end(), v.begin(), v.end());
-       tinfo_key = TINFO_indexed;
+       tinfo_key = &indexed::tinfo_static;
        validate();
 }
 
 indexed::indexed(const symmetry & symm, const exprseq & es) : inherited(es), symtree(symm)
 {
-       tinfo_key = TINFO_indexed;
+       tinfo_key = &indexed::tinfo_static;
 }
 
 indexed::indexed(const symmetry & symm, const exvector & v, bool discardable) : inherited(v, discardable), symtree(symm)
 {
-       tinfo_key = TINFO_indexed;
+       tinfo_key = &indexed::tinfo_static;
 }
 
 indexed::indexed(const symmetry & symm, std::auto_ptr<exvector> vp) : inherited(vp), symtree(symm)
 {
-       tinfo_key = TINFO_indexed;
+       tinfo_key = &indexed::tinfo_static;
 }
 
 //////////
@@ -297,7 +297,7 @@ ex indexed::eval(int level) const
                return f * thiscontainer(v);
        }
 
-       if(this->tinfo()==TINFO_indexed && seq.size()==1)
+       if(this->tinfo()==&indexed::tinfo_static && seq.size()==1)
                return base;
 
        // Canonicalize indices according to the symmetry properties
@@ -520,22 +520,6 @@ struct is_summation_idx : public std::unary_function<ex, bool> {
        }
 };
 
-exvector power::get_free_indices() const
-{
-       // Get free indices of basis
-       exvector basis_indices = basis.get_free_indices();
-
-       if (exponent.info(info_flags::even)) {
-               // If the exponent is an even number, then any "free" index that
-               // forms a dummy pair with itself is actually a summation index
-               exvector really_free;
-               std::remove_copy_if(basis_indices.begin(), basis_indices.end(),
-                                   std::back_inserter(really_free), is_summation_idx());
-               return really_free;
-       } else
-               return basis_indices;
-}
-
 exvector integral::get_free_indices() const
 {
        if (a.get_free_indices().size() || b.get_free_indices().size())
@@ -738,6 +722,9 @@ template<class T> ex idx_symmetrization(const ex& r,const exvector& local_dummy_
        return q;
 }
 
+// Forward declaration needed in absence of friend injection, C.f. [namespace.memdef]:
+ex simplify_indexed(const ex & e, exvector & free_indices, exvector & dummy_indices, const scalar_products & sp);
+
 /** Simplify product of indexed expressions (commutative, noncommutative and
  *  simple squares), return list of free indices. */
 ex simplify_indexed_product(const ex & e, exvector & free_indices, exvector & dummy_indices, const scalar_products & sp)
@@ -790,20 +777,12 @@ try_again:
 
                        // At least one dummy index, is it a defined scalar product?
                        bool contracted = false;
-                       if (free.empty()) {
-
-                               // Find minimal dimension of all indices of both factors
-                               exvector::const_iterator dit = ex_to<indexed>(*it1).seq.begin() + 1, ditend = ex_to<indexed>(*it1).seq.end();
-                               ex dim = ex_to<idx>(*dit).get_dim();
-                               ++dit;
-                               for (; dit != ditend; ++dit) {
-                                       dim = minimal_dim(dim, ex_to<idx>(*dit).get_dim());
-                               }
-                               dit = ex_to<indexed>(*it2).seq.begin() + 1;
-                               ditend = ex_to<indexed>(*it2).seq.end();
-                               for (; dit != ditend; ++dit) {
-                                       dim = minimal_dim(dim, ex_to<idx>(*dit).get_dim());
-                               }
+                       if (free.empty() && it1->nops()==2 && it2->nops()==2) {
+
+                               ex dim = minimal_dim(
+                                       ex_to<idx>(it1->op(1)).get_dim(),
+                                       ex_to<idx>(it2->op(1)).get_dim()
+                               );
 
                                // User-defined scalar product?
                                if (sp.is_defined(*it1, *it2, dim)) {
@@ -1348,6 +1327,46 @@ void scalar_products::debugprint() const
        }
 }
 
+exvector get_all_dummy_indices_safely(const ex & e)
+{
+       if (is_a<indexed>(e))
+               return ex_to<indexed>(e).get_dummy_indices();
+       else if (is_a<power>(e) && e.op(1)==2) {
+               return e.op(0).get_free_indices();
+       }       
+       else if (is_a<mul>(e) || is_a<ncmul>(e)) {
+               exvector dummies;
+               exvector free_indices;
+               for (int i=0; i<e.nops(); ++i) {
+                       exvector dummies_of_factor = get_all_dummy_indices_safely(e.op(i));
+                       dummies.insert(dummies.end(), dummies_of_factor.begin(),
+                               dummies_of_factor.end());
+                       exvector free_of_factor = e.op(i).get_free_indices();
+                       free_indices.insert(free_indices.begin(), free_of_factor.begin(),
+                               free_of_factor.end());
+               }
+               exvector free_out, dummy_out;
+               find_free_and_dummy(free_indices.begin(), free_indices.end(), free_out,
+                       dummy_out);
+               dummies.insert(dummies.end(), dummy_out.begin(), dummy_out.end());
+               return dummies;
+       }
+       else if(is_a<add>(e)) {
+               exvector result;
+               for(int i=0; i<e.nops(); ++i) {
+                       exvector dummies_of_term = get_all_dummy_indices_safely(e.op(i));
+                       sort(dummies_of_term.begin(), dummies_of_term.end());
+                       exvector new_vec;
+                       set_union(result.begin(), result.end(), dummies_of_term.begin(),
+                               dummies_of_term.end(), std::back_inserter<exvector>(new_vec),
+                               ex_is_less());
+                       result.swap(new_vec);
+               }
+               return result;
+       }
+       return exvector();
+}
+
 /** Returns all dummy indices from the exvector */
 exvector get_all_dummy_indices(const ex & e)
 {
@@ -1374,12 +1393,12 @@ exvector get_all_dummy_indices(const ex & e)
        return v;
 }
 
-ex rename_dummy_indices_uniquely(const exvector & va, const exvector & vb, const ex & b)
+lst rename_dummy_indices_uniquely(const exvector & va, const exvector & vb)
 {
        exvector common_indices;
        set_intersection(va.begin(), va.end(), vb.begin(), vb.end(), std::back_insert_iterator<exvector>(common_indices), ex_is_less());
        if (common_indices.empty()) {
-               return b;
+               return lst(lst(), lst());
        } else {
                exvector new_indices, old_indices;
                old_indices.reserve(2*common_indices.size());
@@ -1408,17 +1427,57 @@ ex rename_dummy_indices_uniquely(const exvector & va, const exvector & vb, const
                        }
                        ++ip;
                }
-               return b.subs(lst(old_indices.begin(), old_indices.end()), lst(new_indices.begin(), new_indices.end()), subs_options::no_pattern);
+               return lst(lst(old_indices.begin(), old_indices.end()), lst(new_indices.begin(), new_indices.end()));
        }
 }
 
+ex rename_dummy_indices_uniquely(const exvector & va, const exvector & vb, const ex & b)
+{
+       lst indices_subs = rename_dummy_indices_uniquely(va, vb);
+       return (indices_subs.op(0).nops()>0 ? b.subs(ex_to<lst>(indices_subs.op(0)), ex_to<lst>(indices_subs.op(1)), subs_options::no_pattern|subs_options::no_index_renaming) : b);
+}
+
 ex rename_dummy_indices_uniquely(const ex & a, const ex & b)
 {
-       exvector va = get_all_dummy_indices(a);
-       exvector vb = get_all_dummy_indices(b);
-       sort(va.begin(), va.end(), ex_is_less());
-       sort(vb.begin(), vb.end(), ex_is_less());
-       return rename_dummy_indices_uniquely(va, vb, b);
+       exvector va = get_all_dummy_indices_safely(a);
+       if (va.size() > 0) {
+               exvector vb = get_all_dummy_indices_safely(b);
+               if (vb.size() > 0) {
+                       sort(va.begin(), va.end(), ex_is_less());
+                       sort(vb.begin(), vb.end(), ex_is_less());
+                       lst indices_subs = rename_dummy_indices_uniquely(va, vb);
+                       if (indices_subs.op(0).nops() > 0)
+                               return b.subs(ex_to<lst>(indices_subs.op(0)), ex_to<lst>(indices_subs.op(1)), subs_options::no_pattern|subs_options::no_index_renaming);
+               }
+       }
+       return b;
+}
+
+ex rename_dummy_indices_uniquely(exvector & va, const ex & b, bool modify_va)
+{
+       if (va.size() > 0) {
+               exvector vb = get_all_dummy_indices_safely(b);
+               if (vb.size() > 0) {
+                       sort(vb.begin(), vb.end(), ex_is_less());
+                       lst indices_subs = rename_dummy_indices_uniquely(va, vb);
+                       if (indices_subs.op(0).nops() > 0) {
+                               if (modify_va) {
+                                       for (lst::const_iterator i = ex_to<lst>(indices_subs.op(1)).begin(); i != ex_to<lst>(indices_subs.op(1)).end(); ++i)
+                                               va.push_back(*i);
+                                       exvector uncommon_indices;
+                                       set_difference(vb.begin(), vb.end(), indices_subs.op(0).begin(), indices_subs.op(0).end(), std::back_insert_iterator<exvector>(uncommon_indices), ex_is_less());
+                                       exvector::const_iterator ip = uncommon_indices.begin(), ipend = uncommon_indices.end();
+                                       while (ip != ipend) {
+                                               va.push_back(*ip);
+                                               ++ip;
+                                       }
+                                       sort(va.begin(), va.end(), ex_is_less());
+                               }
+                               return b.subs(ex_to<lst>(indices_subs.op(0)), ex_to<lst>(indices_subs.op(1)), subs_options::no_pattern|subs_options::no_index_renaming);
+                       }
+               }
+       }
+       return b;
 }
 
 ex expand_dummy_sum(const ex & e, bool subs_idx)