* Implementation of GiNaC's indexed expressions. */
/*
- * GiNaC Copyright (C) 1999-2002 Johannes Gutenberg University Mainz, Germany
+ * GiNaC Copyright (C) 1999-2003 Johannes Gutenberg University Mainz, Germany
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
*/
#include <iostream>
+#include <sstream>
#include <stdexcept>
#include "indexed.h"
// global functions
//////////
+struct idx_is_equal_ignore_dim : public std::binary_function<ex, ex, bool> {
+ bool operator() (const ex &lh, const ex &rh) const
+ {
+ if (lh.is_equal(rh))
+ return true;
+ else
+ try {
+ // Replacing the dimension might cause an error (e.g. with
+ // index classes that only work in a fixed number of dimensions)
+ return lh.is_equal(ex_to<idx>(rh).replace_dim(ex_to<idx>(lh).get_dim()));
+ } catch (...) {
+ return false;
+ }
+ }
+};
+
/** Check whether two sorted index vectors are consistent (i.e. equal). */
static bool indices_consistent(const exvector & v1, const exvector & v2)
{
if (v1.size() != v2.size())
return false;
- return equal(v1.begin(), v1.end(), v2.begin(), ex_is_equal());
+ return equal(v1.begin(), v1.end(), v2.begin(), idx_is_equal_ignore_dim());
}
exvector indexed::get_indices(void) const
* obtained during the simplification of sums. */
class symminfo {
public:
- symminfo() {}
+ symminfo() : num(0) {}
~symminfo() {}
- symminfo(const ex & symmterm_, const ex & orig_)
+ symminfo(const ex & symmterm_, const ex & orig_, unsigned num_) : orig(orig_), num(num_)
{
- if (is_exactly_a<mul>(orig_) && is_exactly_a<numeric>(orig_.op(orig_.nops()-1))) {
- ex tmp = orig_.op(orig_.nops()-1);
- orig = orig_ / tmp;
- } else
- orig = orig_;
-
if (is_exactly_a<mul>(symmterm_) && is_exactly_a<numeric>(symmterm_.op(symmterm_.nops()-1))) {
coeff = symmterm_.op(symmterm_.nops()-1);
symmterm = symmterm_ / coeff;
}
}
- symminfo(const symminfo & other)
- {
- symmterm = other.symmterm;
- coeff = other.coeff;
- orig = other.orig;
- }
+ ex symmterm; /**< symmetrized term */
+ ex coeff; /**< coefficient of symmetrized term */
+ ex orig; /**< original term */
+ unsigned num; /**< how many symmetrized terms resulted from the original term */
+};
- const symminfo & operator=(const symminfo & other)
+class symminfo_is_less_by_symmterm {
+public:
+ bool operator() (const symminfo & si1, const symminfo & si2) const
{
- if (this != &other) {
- symmterm = other.symmterm;
- coeff = other.coeff;
- orig = other.orig;
- }
- return *this;
+ int comp = si1.symmterm.compare(si2.symmterm);
+ if (comp < 0) return true;
+#if 0
+ if (comp > 0) return false;
+ comp = si1.orig.compare(si2.orig);
+ if (comp < 0) return true;
+ if (comp > 0) return false;
+ comp = si1.coeff.compare(si2.coeff);
+ if (comp < 0) return true;
+#endif
+ return false;
}
-
- ex symmterm;
- ex coeff;
- ex orig;
};
-class symminfo_is_less {
+class symminfo_is_less_by_orig {
public:
- bool operator() (const symminfo & si1, const symminfo & si2)
+ bool operator() (const symminfo & si1, const symminfo & si2) const
{
- int comp = si1.symmterm.compare(si2.symmterm);
+ int comp = si1.orig.compare(si2.orig);
if (comp < 0) return true;
+#if 0
if (comp > 0) return false;
- comp = si1.orig.compare(si2.orig);
+ comp = si1.symmterm.compare(si2.symmterm);
if (comp < 0) return true;
if (comp > 0) return false;
comp = si1.coeff.compare(si2.coeff);
if (comp < 0) return true;
+#endif
return false;
}
};
sum = term;
first = false;
} else {
- if (!indices_consistent(free_indices, free_indices_of_term))
- throw (std::runtime_error("simplify_indexed: inconsistent indices in sum"));
+ if (!indices_consistent(free_indices, free_indices_of_term)) {
+ std::ostringstream s;
+ s << "simplify_indexed: inconsistent indices in sum: ";
+ s << exprseq(free_indices) << " vs. " << exprseq(free_indices_of_term);
+ throw (std::runtime_error(s.str()));
+ }
if (is_ex_of_type(sum, indexed) && is_ex_of_type(term, indexed))
sum = ex_to<basic>(sum.op(0)).add_indexed(sum, term);
else
std::vector<symminfo> v;
for (int i=0; i<sum.nops(); i++) {
ex sum_symm = sum.op(i).symmetrize(dummy_syms);
+ if (sum_symm.is_zero())
+ continue;
if (is_exactly_a<add>(sum_symm))
for (int j=0; j<sum_symm.nops(); j++)
- v.push_back(symminfo(sum_symm.op(j), sum.op(i)));
+ v.push_back(symminfo(sum_symm.op(j), sum.op(i), sum_symm.nops()));
else
- v.push_back(symminfo(sum_symm, sum.op(i)));
+ v.push_back(symminfo(sum_symm, sum.op(i), 1));
}
// Now add up all the unsymmetrized versions of the terms that
// did not cancel out in the symmetrization
exvector result;
- std::sort(v.begin(), v.end(), symminfo_is_less());
- for (std::vector<symminfo>::iterator i=v.begin(); i!=v.end(); ) {
- std::vector<symminfo>::iterator j = i;
- for (j++; j!=v.end() && i->symmterm == j->symmterm; j++) ;
- for (std::vector<symminfo>::iterator k=i; k!=j; k++)
- result.push_back((k->coeff)*(i->orig));
+ std::sort(v.begin(), v.end(), symminfo_is_less_by_symmterm());
+
+ std::vector<symminfo> v_pass2;
+ for (std::vector<symminfo>::const_iterator i=v.begin(); i!=v.end(); ) {
+
+ // Combine equal terms
+ std::vector<symminfo>::const_iterator j = i + 1;
+ if (j != v.end() && j->symmterm == i->symmterm) {
+
+ // More than one term, collect the coefficients
+ ex coeff = i->coeff;
+ while (j != v.end() && j->symmterm == i->symmterm) {
+ coeff += j->coeff;
+ j++;
+ }
+
+ // Add combined term to result
+ if (!coeff.is_zero())
+ result.push_back(coeff * i->symmterm);
+
+ } else {
+
+ // Single term, store for second pass
+ v_pass2.push_back(*i);
+ }
+
i = j;
}
+
+ // Were there any remaining terms that didn't get combined?
+ if (v_pass2.size() > 0) {
+
+ // Yes, sort them by their original term
+ std::sort(v_pass2.begin(), v_pass2.end(), symminfo_is_less_by_orig());
+
+ for (std::vector<symminfo>::const_iterator i=v_pass2.begin(); i!=v_pass2.end(); ) {
+
+ // How many symmetrized terms of this original term are left?
+ unsigned num = 1;
+ std::vector<symminfo>::const_iterator j = i + 1;
+ while (j != v_pass2.end() && j->orig == i->orig) {
+ num++;
+ j++;
+ }
+
+ if (num == i->num) {
+
+ // All terms left, then add the original term to the result
+ result.push_back(i->orig);
+
+ } else {
+
+ // Some terms were combined with others, add up the remaining symmetrized terms
+ std::vector<symminfo>::const_iterator k;
+ for (k=i; k!=j; k++)
+ result.push_back(k->coeff * k->symmterm);
+ }
+
+ i = j;
+ }
+ }
+
+ // Add all resulting terms
ex sum_symm = (new add(result))->setflag(status_flags::dynallocated);
if (sum_symm.is_zero())
free_indices.clear();