* Implementation of GiNaC's indexed expressions. */
/*
- * GiNaC Copyright (C) 1999-2002 Johannes Gutenberg University Mainz, Germany
+ * GiNaC Copyright (C) 1999-2003 Johannes Gutenberg University Mainz, Germany
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
*/
#include <iostream>
+#include <sstream>
#include <stdexcept>
#include "indexed.h"
// global functions
//////////
+struct idx_is_equal_ignore_dim : public std::binary_function<ex, ex, bool> {
+ bool operator() (const ex &lh, const ex &rh) const
+ {
+ if (lh.is_equal(rh))
+ return true;
+ else
+ try {
+ // Replacing the dimension might cause an error (e.g. with
+ // index classes that only work in a fixed number of dimensions)
+ return lh.is_equal(ex_to<idx>(rh).replace_dim(ex_to<idx>(lh).get_dim()));
+ } catch (...) {
+ return false;
+ }
+ }
+};
+
/** Check whether two sorted index vectors are consistent (i.e. equal). */
static bool indices_consistent(const exvector & v1, const exvector & v2)
{
if (v1.size() != v2.size())
return false;
- return equal(v1.begin(), v1.end(), v2.begin(), ex_is_equal());
+ return equal(v1.begin(), v1.end(), v2.begin(), idx_is_equal_ignore_dim());
}
exvector indexed::get_indices(void) const
int remaining = local_size - global_size;
exvector::const_iterator it = local_dummy_indices.begin(), itend = local_dummy_indices.end();
while (it != itend && remaining > 0) {
- if (find_if(global_dummy_indices.begin(), global_dummy_indices.end(), bind2nd(ex_is_equal(), *it)) == global_dummy_indices.end()) {
+ if (find_if(global_dummy_indices.begin(), global_dummy_indices.end(), bind2nd(op0_is_equal(), *it)) == global_dummy_indices.end()) {
global_dummy_indices.push_back(*it);
global_size++;
remaining--;
for (unsigned i=0; i<local_size; i++)
local_syms.push_back(local_dummy_indices[i].op(0));
shaker_sort(local_syms.begin(), local_syms.end(), ex_is_less(), ex_swap());
- for (unsigned i=0; i<global_size; i++)
+ for (unsigned i=0; i<local_size; i++) // don't use more global symbols than necessary
global_syms.push_back(global_dummy_indices[i].op(0));
shaker_sort(global_syms.begin(), global_syms.end(), ex_is_less(), ex_swap());
/** This structure stores the original and symmetrized versions of terms
* obtained during the simplification of sums. */
-class symminfo {
+class terminfo {
public:
- symminfo() {}
- ~symminfo() {}
+ terminfo(const ex & orig_, const ex & symm_) : orig(orig_), symm(symm_) {}
- symminfo(const ex & symmterm_, const ex & orig_)
+ ex orig; /**< original term */
+ ex symm; /**< symmtrized term */
+};
+
+class terminfo_is_less {
+public:
+ bool operator() (const terminfo & ti1, const terminfo & ti2) const
{
- if (is_a<mul>(orig_) && is_a<numeric>(orig_.op(orig_.nops()-1))) {
- ex tmp = orig_.op(orig_.nops()-1);
- orig = orig_ / tmp;
- } else
- orig = orig_;
+ return (ti1.symm.compare(ti2.symm) < 0);
+ }
+};
+
+/** This structure stores the individual symmetrized terms obtained during
+ * the simplification of sums. */
+class symminfo {
+public:
+ symminfo() : num(0) {}
- if (is_a<mul>(symmterm_) && is_a<numeric>(symmterm_.op(symmterm_.nops()-1))) {
+ symminfo(const ex & symmterm_, const ex & orig_, unsigned num_) : orig(orig_), num(num_)
+ {
+ if (is_exactly_a<mul>(symmterm_) && is_exactly_a<numeric>(symmterm_.op(symmterm_.nops()-1))) {
coeff = symmterm_.op(symmterm_.nops()-1);
symmterm = symmterm_ / coeff;
} else {
}
}
- symminfo(const symminfo & other)
- {
- symmterm = other.symmterm;
- coeff = other.coeff;
- orig = other.orig;
- }
+ ex symmterm; /**< symmetrized term */
+ ex coeff; /**< coefficient of symmetrized term */
+ ex orig; /**< original term */
+ unsigned num; /**< how many symmetrized terms resulted from the original term */
+};
- const symminfo & operator=(const symminfo & other)
+class symminfo_is_less_by_symmterm {
+public:
+ bool operator() (const symminfo & si1, const symminfo & si2) const
{
- if (this != &other) {
- symmterm = other.symmterm;
- coeff = other.coeff;
- orig = other.orig;
- }
- return *this;
+ return (si1.symmterm.compare(si2.symmterm) < 0);
}
-
- ex symmterm;
- ex coeff;
- ex orig;
};
-class symminfo_is_less {
+class symminfo_is_less_by_orig {
public:
- bool operator() (const symminfo & si1, const symminfo & si2)
+ bool operator() (const symminfo & si1, const symminfo & si2) const
{
- int comp = si1.symmterm.compare(si2.symmterm);
- if (comp < 0) return true;
- if (comp > 0) return false;
- comp = si1.orig.compare(si2.orig);
- if (comp < 0) return true;
- if (comp > 0) return false;
- comp = si1.coeff.compare(si2.coeff);
- if (comp < 0) return true;
- return false;
+ return (si1.orig.compare(si2.orig) < 0);
}
};
// free indices in each term
if (is_ex_exactly_of_type(e_expanded, add)) {
bool first = true;
- ex sum = _ex0;
+ ex sum;
free_indices.clear();
for (unsigned i=0; i<e_expanded.nops(); i++) {
sum = term;
first = false;
} else {
- if (!indices_consistent(free_indices, free_indices_of_term))
- throw (std::runtime_error("simplify_indexed: inconsistent indices in sum"));
+ if (!indices_consistent(free_indices, free_indices_of_term)) {
+ std::ostringstream s;
+ s << "simplify_indexed: inconsistent indices in sum: ";
+ s << exprseq(free_indices) << " vs. " << exprseq(free_indices_of_term);
+ throw (std::runtime_error(s.str()));
+ }
if (is_ex_of_type(sum, indexed) && is_ex_of_type(term, indexed))
sum = ex_to<basic>(sum.op(0)).add_indexed(sum, term);
else
return sum;
}
- // Symmetrizing over the dummy indices may cancel terms
- int num_terms_orig = (is_a<add>(sum) ? sum.nops() : 1);
- if (num_terms_orig > 1 && dummy_indices.size() >= 2) {
-
- // Construct list of all dummy index symbols
- lst dummy_syms;
- for (int i=0; i<dummy_indices.size(); i++)
- dummy_syms.append(dummy_indices[i].op(0));
-
- // Symmetrize each term separately and store the resulting
- // terms in a list of symminfo structures
- std::vector<symminfo> v;
- for (int i=0; i<sum.nops(); i++) {
- ex sum_symm = sum.op(i).symmetrize(dummy_syms);
- if (is_a<add>(sum_symm))
- for (int j=0; j<sum_symm.nops(); j++)
- v.push_back(symminfo(sum_symm.op(j), sum.op(i)));
- else
- v.push_back(symminfo(sum_symm, sum.op(i)));
+ // More than one term and more than one dummy index?
+ int num_terms_orig = (is_exactly_a<add>(sum) ? sum.nops() : 1);
+ if (num_terms_orig < 2 || dummy_indices.size() < 2)
+ return sum;
+
+ // Yes, construct list of all dummy index symbols
+ lst dummy_syms;
+ for (int i=0; i<dummy_indices.size(); i++)
+ dummy_syms.append(dummy_indices[i].op(0));
+
+ // Chop the sum into terms and symmetrize each one over the dummy
+ // indices
+ std::vector<terminfo> terms;
+ for (unsigned i=0; i<sum.nops(); i++) {
+ const ex & term = sum.op(i);
+ ex term_symm = term.symmetrize(dummy_syms);
+ if (term_symm.is_zero())
+ continue;
+ terms.push_back(terminfo(term, term_symm));
+ }
+
+ // Sort by symmetrized terms
+ std::sort(terms.begin(), terms.end(), terminfo_is_less());
+
+ // Combine equal symmetrized terms
+ std::vector<terminfo> terms_pass2;
+ for (std::vector<terminfo>::const_iterator i=terms.begin(); i!=terms.end(); ) {
+ unsigned num = 1;
+ std::vector<terminfo>::const_iterator j = i + 1;
+ while (j != terms.end() && j->symm == i->symm) {
+ num++;
+ j++;
}
+ terms_pass2.push_back(terminfo(i->orig * num, i->symm * num));
+ i = j;
+ }
+
+ // If there is only one term left, we are finished
+ if (terms_pass2.size() == 1)
+ return terms_pass2[0].orig;
+
+ // Chop the symmetrized terms into subterms
+ std::vector<symminfo> sy;
+ for (std::vector<terminfo>::const_iterator i=terms_pass2.begin(); i!=terms_pass2.end(); ++i) {
+ if (is_exactly_a<add>(i->symm)) {
+ unsigned num = i->symm.nops();
+ for (unsigned j=0; j<num; j++)
+ sy.push_back(symminfo(i->symm.op(j), i->orig, num));
+ } else
+ sy.push_back(symminfo(i->symm, i->orig, 1));
+ }
+
+ // Sort by symmetrized subterms
+ std::sort(sy.begin(), sy.end(), symminfo_is_less_by_symmterm());
+
+ // Combine equal symmetrized subterms
+ std::vector<symminfo> sy_pass2;
+ exvector result;
+ for (std::vector<symminfo>::const_iterator i=sy.begin(); i!=sy.end(); ) {
+
+ // Combine equal terms
+ std::vector<symminfo>::const_iterator j = i + 1;
+ if (j != sy.end() && j->symmterm == i->symmterm) {
+
+ // More than one term, collect the coefficients
+ ex coeff = i->coeff;
+ while (j != sy.end() && j->symmterm == i->symmterm) {
+ coeff += j->coeff;
+ j++;
+ }
+
+ // Add combined term to result
+ if (!coeff.is_zero())
+ result.push_back(coeff * i->symmterm);
+
+ } else {
+
+ // Single term, store for second pass
+ sy_pass2.push_back(*i);
+ }
+
+ i = j;
+ }
+
+ // Were there any remaining terms that didn't get combined?
+ if (sy_pass2.size() > 0) {
+
+ // Yes, sort by their original terms
+ std::sort(sy_pass2.begin(), sy_pass2.end(), symminfo_is_less_by_orig());
+
+ for (std::vector<symminfo>::const_iterator i=sy_pass2.begin(); i!=sy_pass2.end(); ) {
+
+ // How many symmetrized terms of this original term are left?
+ unsigned num = 1;
+ std::vector<symminfo>::const_iterator j = i + 1;
+ while (j != sy_pass2.end() && j->orig == i->orig) {
+ num++;
+ j++;
+ }
+
+ if (num == i->num) {
+
+ // All terms left, then add the original term to the result
+ result.push_back(i->orig);
+
+ } else {
+
+ // Some terms were combined with others, add up the remaining symmetrized terms
+ std::vector<symminfo>::const_iterator k;
+ for (k=i; k!=j; k++)
+ result.push_back(k->coeff * k->symmterm);
+ }
- // Now add up all the unsymmetrized versions of the terms that
- // did not cancel out in the symmetrization
- exvector result;
- std::sort(v.begin(), v.end(), symminfo_is_less());
- for (std::vector<symminfo>::iterator i=v.begin(); i!=v.end(); ) {
- std::vector<symminfo>::iterator j = i;
- for (j++; j!=v.end() && i->symmterm == j->symmterm; j++) ;
- for (std::vector<symminfo>::iterator k=i; k!=j; k++)
- result.push_back((k->coeff)*(i->orig));
i = j;
}
- ex sum_symm = (new add(result))->setflag(status_flags::dynallocated);
- if (sum_symm.is_zero())
- free_indices.clear();
- return sum_symm;
}
- return sum;
+ // Add all resulting terms
+ ex sum_symm = (new add(result))->setflag(status_flags::dynallocated);
+ if (sum_symm.is_zero())
+ free_indices.clear();
+ return sum_symm;
}
// Simplification of products