even better dummy index symmetrization
[ginac.git] / ginac / indexed.cpp
index 9b4092e..18f4457 100644 (file)
@@ -830,10 +830,27 @@ contraction_done:
 
 /** This structure stores the original and symmetrized versions of terms
  *  obtained during the simplification of sums. */
+class terminfo {
+public:
+       terminfo(const ex & orig_, const ex & symm_) : orig(orig_), symm(symm_) {}
+
+       ex orig; /**< original term */
+       ex symm; /**< symmtrized term */
+};
+
+class terminfo_is_less {
+public:
+       bool operator() (const terminfo & ti1, const terminfo & ti2) const
+       {
+               return (ti1.symm.compare(ti2.symm) < 0);
+       }
+};
+
+/** This structure stores the individual symmetrized terms obtained during
+ *  the simplification of sums. */
 class symminfo {
 public:
        symminfo() : num(0) {}
-       ~symminfo() {}
 
        symminfo(const ex & symmterm_, const ex & orig_, unsigned num_) : orig(orig_), num(num_)
        {
@@ -856,17 +873,7 @@ class symminfo_is_less_by_symmterm {
 public:
        bool operator() (const symminfo & si1, const symminfo & si2) const
        {
-               int comp = si1.symmterm.compare(si2.symmterm);
-               if (comp < 0) return true;
-#if 0
-               if (comp > 0) return false;
-               comp = si1.orig.compare(si2.orig);
-               if (comp < 0) return true;
-               if (comp > 0) return false;
-               comp = si1.coeff.compare(si2.coeff);
-               if (comp < 0) return true;
-#endif
-               return false;
+               return (si1.symmterm.compare(si2.symmterm) < 0);
        }
 };
 
@@ -874,17 +881,7 @@ class symminfo_is_less_by_orig {
 public:
        bool operator() (const symminfo & si1, const symminfo & si2) const
        {
-               int comp = si1.orig.compare(si2.orig);
-               if (comp < 0) return true;
-#if 0
-               if (comp > 0) return false;
-               comp = si1.symmterm.compare(si2.symmterm);
-               if (comp < 0) return true;
-               if (comp > 0) return false;
-               comp = si1.coeff.compare(si2.coeff);
-               if (comp < 0) return true;
-#endif
-               return false;
+               return (si1.orig.compare(si2.orig) < 0);
        }
 };
 
@@ -923,7 +920,7 @@ ex simplify_indexed(const ex & e, exvector & free_indices, exvector & dummy_indi
        // free indices in each term
        if (is_ex_exactly_of_type(e_expanded, add)) {
                bool first = true;
-               ex sum = _ex0;
+               ex sum;
                free_indices.clear();
 
                for (unsigned i=0; i<e_expanded.nops(); i++) {
@@ -955,102 +952,128 @@ ex simplify_indexed(const ex & e, exvector & free_indices, exvector & dummy_indi
                        return sum;
                }
 
-               // Symmetrizing over the dummy indices may cancel terms
+               // More than one term and more than one dummy index?
                int num_terms_orig = (is_exactly_a<add>(sum) ? sum.nops() : 1);
-               if (num_terms_orig > 1 && dummy_indices.size() >= 2) {
-
-                       // Construct list of all dummy index symbols
-                       lst dummy_syms;
-                       for (int i=0; i<dummy_indices.size(); i++)
-                               dummy_syms.append(dummy_indices[i].op(0));
-
-                       // Symmetrize each term separately and store the resulting
-                       // terms in a list of symminfo structures
-                       std::vector<symminfo> v;
-                       for (int i=0; i<sum.nops(); i++) {
-                               ex sum_symm = sum.op(i).symmetrize(dummy_syms);
-                               if (sum_symm.is_zero())
-                                       continue;
-                               if (is_exactly_a<add>(sum_symm))
-                                       for (int j=0; j<sum_symm.nops(); j++)
-                                               v.push_back(symminfo(sum_symm.op(j), sum.op(i), sum_symm.nops()));
-                               else
-                                       v.push_back(symminfo(sum_symm, sum.op(i), 1));
-                       }
+               if (num_terms_orig < 2 || dummy_indices.size() < 2)
+                       return sum;
 
-                       // Now add up all the unsymmetrized versions of the terms that
-                       // did not cancel out in the symmetrization
-                       exvector result;
-                       std::sort(v.begin(), v.end(), symminfo_is_less_by_symmterm());
+               // Yes, construct list of all dummy index symbols
+               lst dummy_syms;
+               for (int i=0; i<dummy_indices.size(); i++)
+                       dummy_syms.append(dummy_indices[i].op(0));
+
+               // Chop the sum into terms and symmetrize each one over the dummy
+               // indices
+               std::vector<terminfo> terms;
+               for (unsigned i=0; i<sum.nops(); i++) {
+                       const ex & term = sum.op(i);
+                       ex term_symm = term.symmetrize(dummy_syms);
+                       if (term_symm.is_zero())
+                               continue;
+                       terms.push_back(terminfo(term, term_symm));
+               }
 
-                       std::vector<symminfo> v_pass2;
-                       for (std::vector<symminfo>::const_iterator i=v.begin(); i!=v.end(); ) {
+               // Sort by symmetrized terms
+               std::sort(terms.begin(), terms.end(), terminfo_is_less());
+
+               // Combine equal symmetrized terms
+               std::vector<terminfo> terms_pass2;
+               for (std::vector<terminfo>::const_iterator i=terms.begin(); i!=terms.end(); ) {
+                       unsigned num = 1;
+                       std::vector<terminfo>::const_iterator j = i + 1;
+                       while (j != terms.end() && j->symm == i->symm) {
+                               num++;
+                               j++;
+                       }
+                       terms_pass2.push_back(terminfo(i->orig * num, i->symm * num));
+                       i = j;
+               }
 
-                               // Combine equal terms
-                               std::vector<symminfo>::const_iterator j = i + 1;
-                               if (j != v.end() && j->symmterm == i->symmterm) {
+               // If there is only one term left, we are finished
+               if (terms_pass2.size() == 1)
+                       return terms_pass2[0].orig;
+
+               // Chop the symmetrized terms into subterms
+               std::vector<symminfo> sy;
+               for (std::vector<terminfo>::const_iterator i=terms_pass2.begin(); i!=terms_pass2.end(); ++i) {
+                       if (is_exactly_a<add>(i->symm)) {
+                               unsigned num = i->symm.nops();
+                               for (unsigned j=0; j<num; j++)
+                                       sy.push_back(symminfo(i->symm.op(j), i->orig, num));
+                       } else
+                               sy.push_back(symminfo(i->symm, i->orig, 1));
+               }
 
-                                       // More than one term, collect the coefficients
-                                       ex coeff = i->coeff;
-                                       while (j != v.end() && j->symmterm == i->symmterm) {
-                                               coeff += j->coeff;
-                                               j++;
-                                       }
+               // Sort by symmetrized subterms
+               std::sort(sy.begin(), sy.end(), symminfo_is_less_by_symmterm());
 
-                                       // Add combined term to result
-                                       if (!coeff.is_zero())
-                                               result.push_back(coeff * i->symmterm);
+               // Combine equal symmetrized subterms
+               std::vector<symminfo> sy_pass2;
+               exvector result;
+               for (std::vector<symminfo>::const_iterator i=sy.begin(); i!=sy.end(); ) {
 
-                               } else {
+                       // Combine equal terms
+                       std::vector<symminfo>::const_iterator j = i + 1;
+                       if (j != sy.end() && j->symmterm == i->symmterm) {
 
-                                       // Single term, store for second pass
-                                       v_pass2.push_back(*i);
+                               // More than one term, collect the coefficients
+                               ex coeff = i->coeff;
+                               while (j != sy.end() && j->symmterm == i->symmterm) {
+                                       coeff += j->coeff;
+                                       j++;
                                }
 
-                               i = j;
+                               // Add combined term to result
+                               if (!coeff.is_zero())
+                                       result.push_back(coeff * i->symmterm);
+
+                       } else {
+
+                               // Single term, store for second pass
+                               sy_pass2.push_back(*i);
                        }
 
-                       // Were there any remaining terms that didn't get combined?
-                       if (v_pass2.size() > 0) {
+                       i = j;
+               }
 
-                               // Yes, sort them by their original term
-                               std::sort(v_pass2.begin(), v_pass2.end(), symminfo_is_less_by_orig());
+               // Were there any remaining terms that didn't get combined?
+               if (sy_pass2.size() > 0) {
 
-                               for (std::vector<symminfo>::const_iterator i=v_pass2.begin(); i!=v_pass2.end(); ) {
+                       // Yes, sort by their original terms
+                       std::sort(sy_pass2.begin(), sy_pass2.end(), symminfo_is_less_by_orig());
 
-                                       // How many symmetrized terms of this original term are left?
-                                       unsigned num = 1;
-                                       std::vector<symminfo>::const_iterator j = i + 1;
-                                       while (j != v_pass2.end() && j->orig == i->orig) {
-                                               num++;
-                                               j++;
-                                       }
+                       for (std::vector<symminfo>::const_iterator i=sy_pass2.begin(); i!=sy_pass2.end(); ) {
 
-                                       if (num == i->num) {
+                               // How many symmetrized terms of this original term are left?
+                               unsigned num = 1;
+                               std::vector<symminfo>::const_iterator j = i + 1;
+                               while (j != sy_pass2.end() && j->orig == i->orig) {
+                                       num++;
+                                       j++;
+                               }
 
-                                               // All terms left, then add the original term to the result
-                                               result.push_back(i->orig);
+                               if (num == i->num) {
 
-                                       } else {
+                                       // All terms left, then add the original term to the result
+                                       result.push_back(i->orig);
 
-                                               // Some terms were combined with others, add up the remaining symmetrized terms
-                                               std::vector<symminfo>::const_iterator k;
-                                               for (k=i; k!=j; k++)
-                                                       result.push_back(k->coeff * k->symmterm);
-                                       }
+                               } else {
 
-                                       i = j;
+                                       // Some terms were combined with others, add up the remaining symmetrized terms
+                                       std::vector<symminfo>::const_iterator k;
+                                       for (k=i; k!=j; k++)
+                                               result.push_back(k->coeff * k->symmterm);
                                }
-                       }
 
-                       // Add all resulting terms
-                       ex sum_symm = (new add(result))->setflag(status_flags::dynallocated);
-                       if (sum_symm.is_zero())
-                               free_indices.clear();
-                       return sum_symm;
+                               i = j;
+                       }
                }
 
-               return sum;
+               // Add all resulting terms
+               ex sum_symm = (new add(result))->setflag(status_flags::dynallocated);
+               if (sum_symm.is_zero())
+                       free_indices.clear();
+               return sum_symm;
        }
 
        // Simplification of products