]> www.ginac.de Git - ginac.git/blobdiff - ginac/indexed.cpp
improved dummy index symmetrization in sums [Chris Dams]
[ginac.git] / ginac / indexed.cpp
index d9b2f474fc8266e7e96a87ee483c1f5eace86352..a52d6d77dc52b58b3e44f19531faac6ee7cc749e 100644 (file)
@@ -811,6 +811,68 @@ contraction_done:
                return r;
 }
 
+/** This structure stores the original and symmetrized versions of terms
+ *  obtained during the simplification of sums. */
+class symminfo {
+public:
+       symminfo() {}
+       ~symminfo() {}
+
+       symminfo(const ex & symmterm_, const ex & orig_)
+       {
+               if (is_a<mul>(orig_)) {
+                       ex tmp = orig_.op(orig_.nops()-1);
+                       orig = orig_ / tmp;
+               } else 
+                       orig = orig_;
+
+               if (is_a<mul>(symmterm_)) {
+                       coeff = symmterm_.op(symmterm_.nops()-1);
+                       symmterm = symmterm_ / coeff;
+               } else {
+                       coeff = 1;
+                       symmterm = symmterm_;
+               }
+       }
+
+       symminfo(const symminfo & other)
+       {
+               symmterm = other.symmterm;
+               coeff = other.coeff;
+               orig = other.orig;
+       }
+
+       const symminfo & operator=(const symminfo & other)
+       {
+               if (this != &other) {
+                       symmterm = other.symmterm;
+                       coeff = other.coeff;
+                       orig = other.orig;
+               }
+               return *this;
+       }
+
+       ex symmterm;
+       ex coeff;
+       ex orig;
+};
+
+class symminfo_is_less {
+public:
+       bool operator() (const symminfo & si1, const symminfo & si2)
+       {
+               int comp = si1.symmterm.compare(si2.symmterm);
+               if (comp < 0) return true;
+               if (comp > 0) return false;
+               comp = si1.orig.compare(si2.orig);
+               if (comp < 0) return true;
+               if (comp > 0) return false;
+               comp = si1.coeff.compare(si2.coeff);
+               if (comp < 0) return true;
+               return false;
+       }
+};
+
 /** Simplify indexed expression, return list of free indices. */
 ex simplify_indexed(const ex & e, exvector & free_indices, exvector & dummy_indices, const scalar_products & sp)
 {
@@ -868,6 +930,7 @@ ex simplify_indexed(const ex & e, exvector & free_indices, exvector & dummy_indi
                        }
                }
 
+               // If the sum turns out to be zero, we are finished
                if (sum.is_zero()) {
                        free_indices.clear();
                        return sum;
@@ -876,17 +939,39 @@ ex simplify_indexed(const ex & e, exvector & free_indices, exvector & dummy_indi
                // Symmetrizing over the dummy indices may cancel terms
                int num_terms_orig = (is_a<add>(sum) ? sum.nops() : 1);
                if (num_terms_orig > 1 && dummy_indices.size() >= 2) {
+
+                       // Construct list of all dummy index symbols
                        lst dummy_syms;
                        for (int i=0; i<dummy_indices.size(); i++)
                                dummy_syms.append(dummy_indices[i].op(0));
-                       ex sum_symm = sum.symmetrize(dummy_syms);
-                       if (sum_symm.is_zero()) {
-                               free_indices.clear();
-                               return _ex0;
+
+                       // Symmetrize each term separately and store the resulting
+                       // terms in a list of symminfo structures
+                       std::vector<symminfo> v;
+                       for (int i=0; i<sum.nops(); i++) {
+                               ex sum_symm = sum.op(i).symmetrize(dummy_syms);
+                               if (is_a<add>(sum_symm))
+                                       for (int j=0; j<sum_symm.nops(); j++)
+                                               v.push_back(symminfo(sum_symm.op(j), sum.op(i)));
+                               else
+                                       v.push_back(symminfo(sum_symm, sum.op(i)));
                        }
-                       int num_terms = (is_a<add>(sum_symm) ? sum_symm.nops() : 1);
-                       if (num_terms < num_terms_orig)
-                               return sum_symm;
+
+                       // Now add up all the unsymmetrized versions of the terms that
+                       // did not cancel out in the symmetrization
+                       exvector result;
+                       std::sort(v.begin(), v.end(), symminfo_is_less());
+                       for (std::vector<symminfo>::iterator i=v.begin(); i!=v.end(); ) {
+                               std::vector<symminfo>::iterator j = i;
+                               for (j++; j!=v.end() && i->symmterm == j->symmterm; j++) ;
+                               for (std::vector<symminfo>::iterator k=i; k!=j; k++)
+                                       result.push_back((k->coeff)*(i->orig));
+                               i = j;
+                       }
+                       ex sum_symm = (new add(result))->setflag(status_flags::dynallocated);
+                       if (sum_symm.is_zero())
+                               free_indices.clear();
+                       return sum_symm;
                }
 
                return sum;