GiNaC  1.6.2
indexed.cpp
Go to the documentation of this file.
00001 
00005 /*
00006  *  GiNaC Copyright (C) 1999-2011 Johannes Gutenberg University Mainz, Germany
00007  *
00008  *  This program is free software; you can redistribute it and/or modify
00009  *  it under the terms of the GNU General Public License as published by
00010  *  the Free Software Foundation; either version 2 of the License, or
00011  *  (at your option) any later version.
00012  *
00013  *  This program is distributed in the hope that it will be useful,
00014  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
00015  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
00016  *  GNU General Public License for more details.
00017  *
00018  *  You should have received a copy of the GNU General Public License
00019  *  along with this program; if not, write to the Free Software
00020  *  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
00021  */
00022 
00023 #include "indexed.h"
00024 #include "idx.h"
00025 #include "add.h"
00026 #include "mul.h"
00027 #include "ncmul.h"
00028 #include "power.h"
00029 #include "relational.h"
00030 #include "symmetry.h"
00031 #include "operators.h"
00032 #include "lst.h"
00033 #include "archive.h"
00034 #include "symbol.h"
00035 #include "utils.h"
00036 #include "integral.h"
00037 #include "matrix.h"
00038 #include "inifcns.h"
00039 
00040 #include <iostream>
00041 #include <limits>
00042 #include <sstream>
00043 #include <stdexcept>
00044 
00045 namespace GiNaC {
00046 
// Register class "indexed" (derived from exprseq) with GiNaC's class
// registration machinery and bind its output methods: do_print for a
// generic print_context, do_print_latex for print_latex and do_print_tree
// for print_tree.
GINAC_IMPLEMENT_REGISTERED_CLASS_OPT(indexed, exprseq,
  print_func<print_context>(&indexed::do_print).
  print_func<print_latex>(&indexed::do_print_latex).
  print_func<print_tree>(&indexed::do_print_tree))
00051 
00052 
00053 // default constructor
00055 
00056 indexed::indexed() : symtree(not_symmetric())
00057 {
00058 }
00059 
00061 // other constructors
00063 
00064 indexed::indexed(const ex & b) : inherited(b), symtree(not_symmetric())
00065 {
00066     validate();
00067 }
00068 
00069 indexed::indexed(const ex & b, const ex & i1) : inherited(b, i1), symtree(not_symmetric())
00070 {
00071     validate();
00072 }
00073 
00074 indexed::indexed(const ex & b, const ex & i1, const ex & i2) : inherited(b, i1, i2), symtree(not_symmetric())
00075 {
00076     validate();
00077 }
00078 
00079 indexed::indexed(const ex & b, const ex & i1, const ex & i2, const ex & i3) : inherited(b, i1, i2, i3), symtree(not_symmetric())
00080 {
00081     validate();
00082 }
00083 
00084 indexed::indexed(const ex & b, const ex & i1, const ex & i2, const ex & i3, const ex & i4) : inherited(b, i1, i2, i3, i4), symtree(not_symmetric())
00085 {
00086     validate();
00087 }
00088 
00089 indexed::indexed(const ex & b, const symmetry & symm, const ex & i1, const ex & i2) : inherited(b, i1, i2), symtree(symm)
00090 {
00091     validate();
00092 }
00093 
00094 indexed::indexed(const ex & b, const symmetry & symm, const ex & i1, const ex & i2, const ex & i3) : inherited(b, i1, i2, i3), symtree(symm)
00095 {
00096     validate();
00097 }
00098 
00099 indexed::indexed(const ex & b, const symmetry & symm, const ex & i1, const ex & i2, const ex & i3, const ex & i4) : inherited(b, i1, i2, i3, i4), symtree(symm)
00100 {
00101     validate();
00102 }
00103 
00104 indexed::indexed(const ex & b, const exvector & v) : inherited(b), symtree(not_symmetric())
00105 {
00106     seq.insert(seq.end(), v.begin(), v.end());
00107     validate();
00108 }
00109 
00110 indexed::indexed(const ex & b, const symmetry & symm, const exvector & v) : inherited(b), symtree(symm)
00111 {
00112     seq.insert(seq.end(), v.begin(), v.end());
00113     validate();
00114 }
00115 
00116 indexed::indexed(const symmetry & symm, const exprseq & es) : inherited(es), symtree(symm)
00117 {
00118 }
00119 
00120 indexed::indexed(const symmetry & symm, const exvector & v, bool discardable) : inherited(v, discardable), symtree(symm)
00121 {
00122 }
00123 
00124 indexed::indexed(const symmetry & symm, std::auto_ptr<exvector> vp) : inherited(vp), symtree(symm)
00125 {
00126 }
00127 
00129 // archiving
00131 
00132 void indexed::read_archive(const archive_node &n, lst &sym_lst)
00133 {
00134     inherited::read_archive(n, sym_lst);
00135     if (!n.find_ex("symmetry", symtree, sym_lst)) {
00136         // GiNaC versions <= 0.9.0 had an unsigned "symmetry" property
00137         unsigned symm = 0;
00138         n.find_unsigned("symmetry", symm);
00139         switch (symm) {
00140             case 1:
00141                 symtree = sy_symm();
00142                 break;
00143             case 2:
00144                 symtree = sy_anti();
00145                 break;
00146             default:
00147                 symtree = not_symmetric();
00148                 break;
00149         }
00150         const_cast<symmetry &>(ex_to<symmetry>(symtree)).validate(seq.size() - 1);
00151     }
00152 }
00153 GINAC_BIND_UNARCHIVER(indexed);
00154 
00155 void indexed::archive(archive_node &n) const
00156 {
00157     inherited::archive(n);
00158     n.add_ex("symmetry", symtree);
00159 }
00160 
00162 // functions overriding virtual functions from base classes
00164 
00165 void indexed::printindices(const print_context & c, unsigned level) const
00166 {
00167     if (seq.size() > 1) {
00168 
00169         exvector::const_iterator it=seq.begin() + 1, itend = seq.end();
00170 
00171         if (is_a<print_latex>(c)) {
00172 
00173             // TeX output: group by variance
00174             bool first = true;
00175             bool covariant = true;
00176 
00177             while (it != itend) {
00178                 bool cur_covariant = (is_a<varidx>(*it) ? ex_to<varidx>(*it).is_covariant() : true);
00179                 if (first || cur_covariant != covariant) { // Variance changed
00180                     // The empty {} prevents indices from ending up on top of each other
00181                     if (!first)
00182                         c.s << "}{}";
00183                     covariant = cur_covariant;
00184                     if (covariant)
00185                         c.s << "_{";
00186                     else
00187                         c.s << "^{";
00188                 }
00189                 it->print(c, level);
00190                 c.s << " ";
00191                 first = false;
00192                 it++;
00193             }
00194             c.s << "}";
00195 
00196         } else {
00197 
00198             // Ordinary output
00199             while (it != itend) {
00200                 it->print(c, level);
00201                 it++;
00202             }
00203         }
00204     }
00205 }
00206 
00207 void indexed::print_indexed(const print_context & c, const char *openbrace, const char *closebrace, unsigned level) const
00208 {
00209     if (precedence() <= level)
00210         c.s << openbrace << '(';
00211     c.s << openbrace;
00212     seq[0].print(c, precedence());
00213     c.s << closebrace;
00214     printindices(c, level);
00215     if (precedence() <= level)
00216         c.s << ')' << closebrace;
00217 }
00218 
00219 void indexed::do_print(const print_context & c, unsigned level) const
00220 {
00221     print_indexed(c, "", "", level);
00222 }
00223 
00224 void indexed::do_print_latex(const print_latex & c, unsigned level) const
00225 {
00226     print_indexed(c, "{", "}", level);
00227 }
00228 
00229 void indexed::do_print_tree(const print_tree & c, unsigned level) const
00230 {
00231     c.s << std::string(level, ' ') << class_name() << " @" << this
00232         << std::hex << ", hash=0x" << hashvalue << ", flags=0x" << flags << std::dec
00233         << ", " << seq.size()-1 << " indices"
00234         << ", symmetry=" << symtree << std::endl;
00235     seq[0].print(c, level + c.delta_indent);
00236     printindices(c, level + c.delta_indent);
00237 }
00238 
00239 bool indexed::info(unsigned inf) const
00240 {
00241     if (inf == info_flags::indexed) return true;
00242     if (inf == info_flags::has_indices) return seq.size() > 1;
00243     return inherited::info(inf);
00244 }
00245 
00246 struct idx_is_not : public std::binary_function<ex, unsigned, bool> {
00247     bool operator() (const ex & e, unsigned inf) const {
00248         return !(ex_to<idx>(e).get_value().info(inf));
00249     }
00250 };
00251 
00252 bool indexed::all_index_values_are(unsigned inf) const
00253 {
00254     // No indices? Then no property can be fulfilled
00255     if (seq.size() < 2)
00256         return false;
00257 
00258     // Check all indices
00259     return find_if(seq.begin() + 1, seq.end(), bind2nd(idx_is_not(), inf)) == seq.end();
00260 }
00261 
00262 int indexed::compare_same_type(const basic & other) const
00263 {
00264     GINAC_ASSERT(is_a<indexed>(other));
00265     return inherited::compare_same_type(other);
00266 }
00267 
/** Evaluate an indexed object. The steps, in order:
 *   - level > 1: rebuild the object from evaluated children (evaluation of
 *     the rebuilt object comes back here),
 *   - a zero base makes the whole object zero,
 *   - if the base is a product whose last factor is numeric, that factor
 *     is pulled out in front of the indexed object,
 *   - an object of exact type indexed without any index collapses to its
 *     base,
 *   - with two or more indices, the indices are canonicalized according to
 *     the symmetry tree; a resulting sign of 0 makes the object vanish,
 *     any other change yields sign * (object with reordered indices),
 *   - finally the base object's eval_indexed() gets the last word. */
ex indexed::eval(int level) const
{
    // First evaluate children, then we will end up here again
    if (level > 1)
        return indexed(ex_to<symmetry>(symtree), evalchildren(level));

    const ex &base = seq[0];

    // If the base object is 0, the whole object is 0
    if (base.is_zero())
        return _ex0;

    // If the base object is a product, pull out the numeric factor
    // (a numeric factor is looked for only as the last operand of the mul)
    if (is_exactly_a<mul>(base) && is_exactly_a<numeric>(base.op(base.nops() - 1))) {
        exvector v(seq);
        ex f = ex_to<numeric>(base.op(base.nops() - 1));
        v[0] = seq[0] / f;
        return f * thiscontainer(v);
    }

    // Only the plain indexed class collapses; derived classes keep their
    // index-free form.
    if((typeid(*this) == typeid(indexed)) && seq.size()==1)
        return base;

    // Canonicalize indices according to the symmetry properties
    if (seq.size() > 2) {
        exvector v = seq;
        GINAC_ASSERT(is_exactly_a<symmetry>(symtree));
        // canonicalize() returns numeric_limits<int>::max() when the index
        // order is already canonical; otherwise it returns the sign of the
        // permutation it applied to v.
        int sig = canonicalize(v.begin() + 1, ex_to<symmetry>(symtree));
        if (sig != std::numeric_limits<int>::max()) {
            // Something has changed while sorting indices, more evaluations later
            if (sig == 0)
                return _ex0;
            return ex(sig) * thiscontainer(v);
        }
    }

    // Let the class of the base object perform additional evaluations
    return ex_to<basic>(base).eval_indexed(*this);
}
00307 
00308 ex indexed::real_part() const
00309 {
00310     if(op(0).info(info_flags::real))
00311         return *this;
00312     return real_part_function(*this).hold();
00313 }
00314 
00315 ex indexed::imag_part() const
00316 {
00317     if(op(0).info(info_flags::real))
00318         return 0;
00319     return imag_part_function(*this).hold();
00320 }
00321 
00322 ex indexed::thiscontainer(const exvector & v) const
00323 {
00324     return indexed(ex_to<symmetry>(symtree), v);
00325 }
00326 
00327 ex indexed::thiscontainer(std::auto_ptr<exvector> vp) const
00328 {
00329     return indexed(ex_to<symmetry>(symtree), vp);
00330 }
00331 
00332 unsigned indexed::return_type() const
00333 {
00334     if(is_a<matrix>(op(0)))
00335         return return_types::commutative;
00336     else
00337         return op(0).return_type();
00338 }
00339 
00340 ex indexed::expand(unsigned options) const
00341 {
00342     GINAC_ASSERT(seq.size() > 0);
00343 
00344     if (options & expand_options::expand_indexed) {
00345         ex newbase = seq[0].expand(options);
00346         if (is_exactly_a<add>(newbase)) {
00347             ex sum = _ex0;
00348             for (size_t i=0; i<newbase.nops(); i++) {
00349                 exvector s = seq;
00350                 s[0] = newbase.op(i);
00351                 sum += thiscontainer(s).expand(options);
00352             }
00353             return sum;
00354         }
00355         if (!are_ex_trivially_equal(newbase, seq[0])) {
00356             exvector s = seq;
00357             s[0] = newbase;
00358             return ex_to<indexed>(thiscontainer(s)).inherited::expand(options);
00359         }
00360     }
00361     return inherited::expand(options);
00362 }
00363 
00365 // virtual functions which can be overridden by derived classes
00367 
00368 // none
00369 
00371 // non-virtual functions in this class
00373 
00377 void indexed::validate() const
00378 {
00379     GINAC_ASSERT(seq.size() > 0);
00380     exvector::const_iterator it = seq.begin() + 1, itend = seq.end();
00381     while (it != itend) {
00382         if (!is_a<idx>(*it))
00383             throw(std::invalid_argument("indices of indexed object must be of type idx"));
00384         it++;
00385     }
00386 
00387     if (!symtree.is_zero()) {
00388         if (!is_exactly_a<symmetry>(symtree))
00389             throw(std::invalid_argument("symmetry of indexed object must be of type symmetry"));
00390         const_cast<symmetry &>(ex_to<symmetry>(symtree)).validate(seq.size() - 1);
00391     }
00392 }
00393 
00397 ex indexed::derivative(const symbol & s) const
00398 {
00399     return _ex0;
00400 }
00401 
00403 // global functions
00405 
00406 struct idx_is_equal_ignore_dim : public std::binary_function<ex, ex, bool> {
00407     bool operator() (const ex &lh, const ex &rh) const
00408     {
00409         if (lh.is_equal(rh))
00410             return true;
00411         else
00412             try {
00413                 // Replacing the dimension might cause an error (e.g. with
00414                 // index classes that only work in a fixed number of dimensions)
00415                 return lh.is_equal(ex_to<idx>(rh).replace_dim(ex_to<idx>(lh).get_dim()));
00416             } catch (...) {
00417                 return false;
00418             }
00419     }
00420 };
00421 
00423 static bool indices_consistent(const exvector & v1, const exvector & v2)
00424 {
00425     // Number of indices must be the same
00426     if (v1.size() != v2.size())
00427         return false;
00428 
00429     return equal(v1.begin(), v1.end(), v2.begin(), idx_is_equal_ignore_dim());
00430 }
00431 
00432 exvector indexed::get_indices() const
00433 {
00434     GINAC_ASSERT(seq.size() >= 1);
00435     return exvector(seq.begin() + 1, seq.end());
00436 }
00437 
00438 exvector indexed::get_dummy_indices() const
00439 {
00440     exvector free_indices, dummy_indices;
00441     find_free_and_dummy(seq.begin() + 1, seq.end(), free_indices, dummy_indices);
00442     return dummy_indices;
00443 }
00444 
00445 exvector indexed::get_dummy_indices(const indexed & other) const
00446 {
00447     exvector indices = get_free_indices();
00448     exvector other_indices = other.get_free_indices();
00449     indices.insert(indices.end(), other_indices.begin(), other_indices.end());
00450     exvector dummy_indices;
00451     find_dummy_indices(indices, dummy_indices);
00452     return dummy_indices;
00453 }
00454 
00455 bool indexed::has_dummy_index_for(const ex & i) const
00456 {
00457     exvector::const_iterator it = seq.begin() + 1, itend = seq.end();
00458     while (it != itend) {
00459         if (is_dummy_pair(*it, i))
00460             return true;
00461         it++;
00462     }
00463     return false;
00464 }
00465 
00466 exvector indexed::get_free_indices() const
00467 {
00468     exvector free_indices, dummy_indices;
00469     find_free_and_dummy(seq.begin() + 1, seq.end(), free_indices, dummy_indices);
00470     return free_indices;
00471 }
00472 
00473 exvector add::get_free_indices() const
00474 {
00475     exvector free_indices;
00476     for (size_t i=0; i<nops(); i++) {
00477         if (i == 0)
00478             free_indices = op(i).get_free_indices();
00479         else {
00480             exvector free_indices_of_term = op(i).get_free_indices();
00481             if (!indices_consistent(free_indices, free_indices_of_term))
00482                 throw (std::runtime_error("add::get_free_indices: inconsistent indices in sum"));
00483         }
00484     }
00485     return free_indices;
00486 }
00487 
00488 exvector mul::get_free_indices() const
00489 {
00490     // Concatenate free indices of all factors
00491     exvector un;
00492     for (size_t i=0; i<nops(); i++) {
00493         exvector free_indices_of_factor = op(i).get_free_indices();
00494         un.insert(un.end(), free_indices_of_factor.begin(), free_indices_of_factor.end());
00495     }
00496 
00497     // And remove the dummy indices
00498     exvector free_indices, dummy_indices;
00499     find_free_and_dummy(un, free_indices, dummy_indices);
00500     return free_indices;
00501 }
00502 
00503 exvector ncmul::get_free_indices() const
00504 {
00505     // Concatenate free indices of all factors
00506     exvector un;
00507     for (size_t i=0; i<nops(); i++) {
00508         exvector free_indices_of_factor = op(i).get_free_indices();
00509         un.insert(un.end(), free_indices_of_factor.begin(), free_indices_of_factor.end());
00510     }
00511 
00512     // And remove the dummy indices
00513     exvector free_indices, dummy_indices;
00514     find_free_and_dummy(un, free_indices, dummy_indices);
00515     return free_indices;
00516 }
00517 
00518 struct is_summation_idx : public std::unary_function<ex, bool> {
00519     bool operator()(const ex & e)
00520     {
00521         return is_dummy_pair(e, e);
00522     }
00523 };
00524 
00525 exvector integral::get_free_indices() const
00526 {
00527     if (a.get_free_indices().size() || b.get_free_indices().size())
00528         throw (std::runtime_error("integral::get_free_indices: boundary values should not have free indices"));
00529     return f.get_free_indices();
00530 }
00531 
00532 template<class T> size_t number_of_type(const exvector&v)
00533 {
00534     size_t number = 0;
00535     for(exvector::const_iterator i=v.begin(); i!=v.end(); ++i)
00536         if(is_exactly_a<T>(*i))
00537             ++number;
00538     return number;
00539 }
00540 
/** Rename the dummy indices of exact type T occurring in expression e.
 *
 *  @param e  Expression to work on
 *  @param global_dummy_indices  Dummy indices collected so far. If e has more
 *         type-T dummy indices than this set, local type-T indices that no
 *         global index matches (ignoring dimensions) are appended to it in
 *         place.
 *  @param local_dummy_indices  The dummy indices appearing in e
 *  @return e, with the symbols of its type-T dummy indices that are not
 *          among the chosen global ones substituted by unused global
 *          symbols. e itself is returned when there are no type-T local
 *          dummies, when this is the first set of dummies seen, or when no
 *          renaming is needed.
 *
 *  NOTE(review): the apparent purpose is to give different terms a common
 *  set of dummy index names so that equal terms can be recognized; confirm
 *  this against the callers. */
template<class T> static ex rename_dummy_indices(const ex & e, exvector & global_dummy_indices, exvector & local_dummy_indices)
{
    size_t global_size = number_of_type<T>(global_dummy_indices),
           local_size = number_of_type<T>(local_dummy_indices);

    // Any local dummy indices at all?
    if (local_size == 0)
        return e;

    if (global_size < local_size) {

        // More local indices than we encountered before, add the new ones
        // to the global set
        // (a local index counts as new if no global index matches it when
        // dimensions are ignored)
        size_t old_global_size = global_size;
        int remaining = local_size - global_size;
        exvector::const_iterator it = local_dummy_indices.begin(), itend = local_dummy_indices.end();
        while (it != itend && remaining > 0) {
            if (is_exactly_a<T>(*it) && find_if(global_dummy_indices.begin(), global_dummy_indices.end(), bind2nd(idx_is_equal_ignore_dim(), *it)) == global_dummy_indices.end()) {
                global_dummy_indices.push_back(*it);
                global_size++;
                remaining--;
            }
            it++;
        }

        // If this is the first set of local indices, do nothing
        if (old_global_size == 0)
            return e;
    }
    GINAC_ASSERT(local_size <= global_size);

    // Construct vectors of index symbols
    // (the symbols of all local type-T dummies, and the symbols of the first
    // local_size global type-T dummies; both vectors sorted with ex_is_less)
    exvector local_syms, global_syms;
    local_syms.reserve(local_size);
    global_syms.reserve(local_size);
    for (size_t i=0; local_syms.size()!=local_size; i++)
        if(is_exactly_a<T>(local_dummy_indices[i]))
            local_syms.push_back(local_dummy_indices[i].op(0));
    shaker_sort(local_syms.begin(), local_syms.end(), ex_is_less(), ex_swap());
    for (size_t i=0; global_syms.size()!=local_size; i++) // don't use more global symbols than necessary
        if(is_exactly_a<T>(global_dummy_indices[i]))
            global_syms.push_back(global_dummy_indices[i].op(0));
    shaker_sort(global_syms.begin(), global_syms.end(), ex_is_less(), ex_swap());

    // Remove common indices
    // (set_difference requires both ranges sorted by the same ordering,
    // which is why both vectors were sorted with ex_is_less above)
    exvector local_uniq, global_uniq;
    set_difference(local_syms.begin(), local_syms.end(), global_syms.begin(), global_syms.end(), std::back_insert_iterator<exvector>(local_uniq), ex_is_less());
    set_difference(global_syms.begin(), global_syms.end(), local_syms.begin(), local_syms.end(), std::back_insert_iterator<exvector>(global_uniq), ex_is_less());

    // Replace remaining non-common local index symbols by global ones
    if (local_uniq.empty())
        return e;
    else {
        // local_uniq[k] is replaced by global_uniq[k]; surplus global
        // symbols are dropped so both lists have the same length
        while (global_uniq.size() > local_uniq.size())
            global_uniq.pop_back();
        return e.subs(lst(local_uniq.begin(), local_uniq.end()), lst(global_uniq.begin(), global_uniq.end()), subs_options::no_pattern);
    }
}
00607 
00609 static void find_variant_indices(const exvector & v, exvector & variant_indices)
00610 {
00611     exvector::const_iterator it1, itend;
00612     for (it1 = v.begin(), itend = v.end(); it1 != itend; ++it1) {
00613         if (is_exactly_a<varidx>(*it1))
00614             variant_indices.push_back(*it1);
00615     }
00616 }
00617 
00625 bool reposition_dummy_indices(ex & e, exvector & variant_dummy_indices, exvector & moved_indices)
00626 {
00627     bool something_changed = false;
00628 
00629     // Find dummy symbols that occur twice in the same indexed object.
00630     exvector local_var_dummies;
00631     local_var_dummies.reserve(e.nops()/2);
00632     for (size_t i=1; i<e.nops(); ++i) {
00633         if (!is_a<varidx>(e.op(i)))
00634             continue;
00635         for (size_t j=i+1; j<e.nops(); ++j) {
00636             if (is_dummy_pair(e.op(i), e.op(j))) {
00637                 local_var_dummies.push_back(e.op(i));
00638                 for (exvector::iterator k = variant_dummy_indices.begin();
00639                         k!=variant_dummy_indices.end(); ++k) {
00640                     if (e.op(i).op(0) == k->op(0)) {
00641                         variant_dummy_indices.erase(k);
00642                         break;
00643                     }
00644                 }
00645                 break;
00646             }
00647         }
00648     }
00649 
00650     // In the case where a dummy symbol occurs twice in the same indexed object
00651     // we try all posibilities of raising/lowering and keep the least one in
00652     // the sense of ex_is_less.
00653     ex optimal_e = e;
00654     size_t numpossibs = 1 << local_var_dummies.size();
00655     for (size_t i=0; i<numpossibs; ++i) {
00656         ex try_e = e;
00657         for (size_t j=0; j<local_var_dummies.size(); ++j) {
00658             exmap m;
00659             if (1<<j & i) {
00660                 ex curr_idx = local_var_dummies[j];
00661                 ex curr_toggle = ex_to<varidx>(curr_idx).toggle_variance();
00662                 m[curr_idx] = curr_toggle;
00663                 m[curr_toggle] = curr_idx;
00664             }
00665             try_e = e.subs(m, subs_options::no_pattern);
00666         }
00667         if(ex_is_less()(try_e, optimal_e))
00668         {   optimal_e = try_e;
00669             something_changed = true;
00670         }
00671     }
00672     e = optimal_e;
00673 
00674     if (!is_a<indexed>(e))
00675         return true;
00676 
00677     exvector seq = ex_to<indexed>(e).seq;
00678 
00679     // If a dummy index is encountered for the first time in the
00680     // product, pull it up, otherwise, pull it down
00681     for (exvector::iterator it2 = seq.begin()+1, it2end = seq.end();
00682             it2 != it2end; ++it2) {
00683         if (!is_exactly_a<varidx>(*it2))
00684             continue;
00685 
00686         exvector::iterator vit, vitend;
00687         for (vit = variant_dummy_indices.begin(), vitend = variant_dummy_indices.end(); vit != vitend; ++vit) {
00688             if (it2->op(0).is_equal(vit->op(0))) {
00689                 if (ex_to<varidx>(*it2).is_covariant()) {
00690                     /*
00691                      * N.B. we don't want to use
00692                      *
00693                      *  e = e.subs(lst(
00694                      *  *it2 == ex_to<varidx>(*it2).toggle_variance(),
00695                      *  ex_to<varidx>(*it2).toggle_variance() == *it2
00696                      *  ), subs_options::no_pattern);
00697                      *
00698                      * since this can trigger non-trivial repositioning of indices,
00699                      * e.g. due to non-trivial symmetry properties of e, thus
00700                      * invalidating iterators
00701                      */
00702                     *it2 = ex_to<varidx>(*it2).toggle_variance();
00703                     something_changed = true;
00704                 }
00705                 moved_indices.push_back(*vit);
00706                 variant_dummy_indices.erase(vit);
00707                 goto next_index;
00708             }
00709         }
00710 
00711         for (vit = moved_indices.begin(), vitend = moved_indices.end(); vit != vitend; ++vit) {
00712             if (it2->op(0).is_equal(vit->op(0))) {
00713                 if (ex_to<varidx>(*it2).is_contravariant()) {
00714                     *it2 = ex_to<varidx>(*it2).toggle_variance();
00715                     something_changed = true;
00716                 }
00717                 goto next_index;
00718             }
00719         }
00720 
00721 next_index: ;
00722     }
00723 
00724     if (something_changed)
00725         e = ex_to<indexed>(e).thiscontainer(seq);
00726 
00727     return something_changed;
00728 }
00729 
00730 /* Ordering that only compares the base expressions of indexed objects. */
00731 struct ex_base_is_less : public std::binary_function<ex, ex, bool> {
00732     bool operator() (const ex &lh, const ex &rh) const
00733     {
00734         return (is_a<indexed>(lh) ? lh.op(0) : lh).compare(is_a<indexed>(rh) ? rh.op(0) : rh) < 0;
00735     }
00736 };
00737 
00738 /* An auxiliary function used by simplify_indexed() and expand_dummy_sum() 
00739  * It returns an exvector of factors from the supplied product */
00740 static void product_to_exvector(const ex & e, exvector & v, bool & non_commutative)
00741 {
00742     // Remember whether the product was commutative or noncommutative
00743     // (because we chop it into factors and need to reassemble later)
00744     non_commutative = is_exactly_a<ncmul>(e);
00745 
00746     // Collect factors in an exvector, store squares twice
00747     v.reserve(e.nops() * 2);
00748 
00749     if (is_exactly_a<power>(e)) {
00750         // We only get called for simple squares, split a^2 -> a*a
00751         GINAC_ASSERT(e.op(1).is_equal(_ex2));
00752         v.push_back(e.op(0));
00753         v.push_back(e.op(0));
00754     } else {
00755         for (size_t i=0; i<e.nops(); i++) {
00756             ex f = e.op(i);
00757             if (is_exactly_a<power>(f) && f.op(1).is_equal(_ex2)) {
00758                 v.push_back(f.op(0));
00759                 v.push_back(f.op(0));
00760             } else if (is_exactly_a<ncmul>(f)) {
00761                 // Noncommutative factor found, split it as well
00762                 non_commutative = true; // everything becomes noncommutative, ncmul will sort out the commutative factors later
00763                 for (size_t j=0; j<f.nops(); j++)
00764                     v.push_back(f.op(j));
00765             } else
00766                 v.push_back(f);
00767         }
00768     }
00769 }
00770 
00771 template<class T> ex idx_symmetrization(const ex& r,const exvector& local_dummy_indices)
00772 {   exvector dummy_syms;
00773     dummy_syms.reserve(r.nops());
00774     for (exvector::const_iterator it = local_dummy_indices.begin(); it != local_dummy_indices.end(); ++it)
00775             if(is_exactly_a<T>(*it))
00776                 dummy_syms.push_back(it->op(0));
00777     if(dummy_syms.size() < 2)
00778         return r;
00779     ex q=symmetrize(r, dummy_syms);
00780     return q;
00781 }
00782 
// Forward declaration needed in absence of friend injection, C.f. [namespace.memdef]:
// simplify_indexed_product() below calls back into this overload; its
// definition is presumably further down in this file (not visible here).
ex simplify_indexed(const ex & e, exvector & free_indices, exvector & dummy_indices, const scalar_products & sp);
00785 
00788 ex simplify_indexed_product(const ex & e, exvector & free_indices, exvector & dummy_indices, const scalar_products & sp)
00789 {
00790     // Collect factors in an exvector
00791     exvector v;
00792 
00793     // Remember whether the product was commutative or noncommutative
00794     // (because we chop it into factors and need to reassemble later)
00795     bool non_commutative;
00796     product_to_exvector(e, v, non_commutative);
00797 
00798     // Perform contractions
00799     bool something_changed = false;
00800     bool has_nonsymmetric = false;
00801     GINAC_ASSERT(v.size() > 1);
00802     exvector::iterator it1, itend = v.end(), next_to_last = itend - 1;
00803     for (it1 = v.begin(); it1 != next_to_last; it1++) {
00804 
00805 try_again:
00806         if (!is_a<indexed>(*it1))
00807             continue;
00808 
00809         bool first_noncommutative = (it1->return_type() != return_types::commutative);
00810         bool first_nonsymmetric = ex_to<symmetry>(ex_to<indexed>(*it1).get_symmetry()).has_nonsymmetric();
00811 
00812         // Indexed factor found, get free indices and look for contraction
00813         // candidates
00814         exvector free1, dummy1;
00815         find_free_and_dummy(ex_to<indexed>(*it1).seq.begin() + 1, ex_to<indexed>(*it1).seq.end(), free1, dummy1);
00816 
00817         exvector::iterator it2;
00818         for (it2 = it1 + 1; it2 != itend; it2++) {
00819 
00820             if (!is_a<indexed>(*it2))
00821                 continue;
00822 
00823             bool second_noncommutative = (it2->return_type() != return_types::commutative);
00824 
00825             // Find free indices of second factor and merge them with free
00826             // indices of first factor
00827             exvector un;
00828             find_free_and_dummy(ex_to<indexed>(*it2).seq.begin() + 1, ex_to<indexed>(*it2).seq.end(), un, dummy1);
00829             un.insert(un.end(), free1.begin(), free1.end());
00830 
00831             // Check whether the two factors share dummy indices
00832             exvector free, dummy;
00833             find_free_and_dummy(un, free, dummy);
00834             size_t num_dummies = dummy.size();
00835             if (num_dummies == 0)
00836                 continue;
00837 
00838             // At least one dummy index, is it a defined scalar product?
00839             bool contracted = false;
00840             if (free.empty() && it1->nops()==2 && it2->nops()==2) {
00841 
00842                 ex dim = minimal_dim(
00843                     ex_to<idx>(it1->op(1)).get_dim(),
00844                     ex_to<idx>(it2->op(1)).get_dim()
00845                 );
00846 
00847                 // User-defined scalar product?
00848                 if (sp.is_defined(*it1, *it2, dim)) {
00849 
00850                     // Yes, substitute it
00851                     *it1 = sp.evaluate(*it1, *it2, dim);
00852                     *it2 = _ex1;
00853                     goto contraction_done;
00854                 }
00855             }
00856 
00857             // Try to contract the first one with the second one
00858             contracted = ex_to<basic>(it1->op(0)).contract_with(it1, it2, v);
00859             if (!contracted) {
00860 
00861                 // That didn't work; maybe the second object knows how to
00862                 // contract itself with the first one
00863                 contracted = ex_to<basic>(it2->op(0)).contract_with(it2, it1, v);
00864             }
00865             if (contracted) {
00866 contraction_done:
00867                 if (first_noncommutative || second_noncommutative
00868                  || is_exactly_a<add>(*it1) || is_exactly_a<add>(*it2)
00869                  || is_exactly_a<mul>(*it1) || is_exactly_a<mul>(*it2)
00870                  || is_exactly_a<ncmul>(*it1) || is_exactly_a<ncmul>(*it2)) {
00871 
00872                     // One of the factors became a sum or product:
00873                     // re-expand expression and run again
00874                     // Non-commutative products are always re-expanded to give
00875                     // eval_ncmul() the chance to re-order and canonicalize
00876                     // the product
00877                     ex r = (non_commutative ? ex(ncmul(v, true)) : ex(mul(v)));
00878                     return simplify_indexed(r, free_indices, dummy_indices, sp);
00879                 }
00880 
00881                 // Both objects may have new indices now or they might
00882                 // even not be indexed objects any more, so we have to
00883                 // start over
00884                 something_changed = true;
00885                 goto try_again;
00886             }
00887             else if (!has_nonsymmetric &&
00888                     (first_nonsymmetric ||
00889                      ex_to<symmetry>(ex_to<indexed>(*it2).get_symmetry()).has_nonsymmetric())) {
00890                 has_nonsymmetric = true;
00891             }
00892         }
00893     }
00894 
00895     // Find free indices (concatenate them all and call find_free_and_dummy())
00896     // and all dummy indices that appear
00897     exvector un, individual_dummy_indices;
00898     for (it1 = v.begin(), itend = v.end(); it1 != itend; ++it1) {
00899         exvector free_indices_of_factor;
00900         if (is_a<indexed>(*it1)) {
00901             exvector dummy_indices_of_factor;
00902             find_free_and_dummy(ex_to<indexed>(*it1).seq.begin() + 1, ex_to<indexed>(*it1).seq.end(), free_indices_of_factor, dummy_indices_of_factor);
00903             individual_dummy_indices.insert(individual_dummy_indices.end(), dummy_indices_of_factor.begin(), dummy_indices_of_factor.end());
00904         } else
00905             free_indices_of_factor = it1->get_free_indices();
00906         un.insert(un.end(), free_indices_of_factor.begin(), free_indices_of_factor.end());
00907     }
00908     exvector local_dummy_indices;
00909     find_free_and_dummy(un, free_indices, local_dummy_indices);
00910     local_dummy_indices.insert(local_dummy_indices.end(), individual_dummy_indices.begin(), individual_dummy_indices.end());
00911 
00912     // Filter out the dummy indices with variance
00913     exvector variant_dummy_indices;
00914     find_variant_indices(local_dummy_indices, variant_dummy_indices);
00915 
00916     // Any indices with variance present at all?
00917     if (!variant_dummy_indices.empty()) {
00918 
00919         // Yes, bring the product into a canonical order that only depends on
00920         // the base expressions of indexed objects
00921         if (!non_commutative)
00922             std::sort(v.begin(), v.end(), ex_base_is_less());
00923 
00924         exvector moved_indices;
00925 
00926         // Iterate over all indexed objects in the product
00927         for (it1 = v.begin(), itend = v.end(); it1 != itend; ++it1) {
00928             if (!is_a<indexed>(*it1))
00929                 continue;
00930 
00931             if (reposition_dummy_indices(*it1, variant_dummy_indices, moved_indices))
00932                 something_changed = true;
00933         }
00934     }
00935 
00936     ex r;
00937     if (something_changed)
00938         r = non_commutative ? ex(ncmul(v, true)) : ex(mul(v));
00939     else
00940         r = e;
00941 
00942     // The result should be symmetric with respect to exchange of dummy
00943     // indices, so if the symmetrization vanishes, the whole expression is
00944     // zero. This detects things like eps.i.j.k * p.j * p.k = 0.
00945     if (has_nonsymmetric) {
00946         ex q = idx_symmetrization<idx>(r, local_dummy_indices);
00947         if (q.is_zero()) {
00948             free_indices.clear();
00949             return _ex0;
00950         }
00951         q = idx_symmetrization<varidx>(q, local_dummy_indices);
00952         if (q.is_zero()) {
00953             free_indices.clear();
00954             return _ex0;
00955         }
00956         q = idx_symmetrization<spinidx>(q, local_dummy_indices);
00957         if (q.is_zero()) {
00958             free_indices.clear();
00959             return _ex0;
00960         }
00961     }
00962 
00963     // Dummy index renaming
00964     r = rename_dummy_indices<idx>(r, dummy_indices, local_dummy_indices);
00965     r = rename_dummy_indices<varidx>(r, dummy_indices, local_dummy_indices);
00966     r = rename_dummy_indices<spinidx>(r, dummy_indices, local_dummy_indices);
00967 
00968     // Product of indexed object with a scalar?
00969     if (is_exactly_a<mul>(r) && r.nops() == 2
00970      && is_exactly_a<numeric>(r.op(1)) && is_a<indexed>(r.op(0)))
00971         return ex_to<basic>(r.op(0).op(0)).scalar_mul_indexed(r.op(0), ex_to<numeric>(r.op(1)));
00972     else
00973         return r;
00974 }
00975 
00978 class terminfo {
00979 public:
00980     terminfo(const ex & orig_, const ex & symm_) : orig(orig_), symm(symm_) {}
00981 
00982     ex orig; 
00983     ex symm; 
00984 };
00985 
00986 class terminfo_is_less {
00987 public:
00988     bool operator() (const terminfo & ti1, const terminfo & ti2) const
00989     {
00990         return (ti1.symm.compare(ti2.symm) < 0);
00991     }
00992 };
00993 
00996 class symminfo {
00997 public:
00998     symminfo() : num(0) {}
00999 
01000     symminfo(const ex & symmterm_, const ex & orig_, size_t num_) : orig(orig_), num(num_)
01001     {
01002         if (is_exactly_a<mul>(symmterm_) && is_exactly_a<numeric>(symmterm_.op(symmterm_.nops()-1))) {
01003             coeff = symmterm_.op(symmterm_.nops()-1);
01004             symmterm = symmterm_ / coeff;
01005         } else {
01006             coeff = 1;
01007             symmterm = symmterm_;
01008         }
01009     }
01010 
01011     ex symmterm;  
01012     ex coeff;     
01013     ex orig;      
01014     size_t num; 
01015 };
01016 
01017 class symminfo_is_less_by_symmterm {
01018 public:
01019     bool operator() (const symminfo & si1, const symminfo & si2) const
01020     {
01021         return (si1.symmterm.compare(si2.symmterm) < 0);
01022     }
01023 };
01024 
01025 class symminfo_is_less_by_orig {
01026 public:
01027     bool operator() (const symminfo & si1, const symminfo & si2) const
01028     {
01029         return (si1.orig.compare(si2.orig) < 0);
01030     }
01031 };
01032 
01033 bool hasindex(const ex &x, const ex &sym)
01034 {   
01035     if(is_a<idx>(x) && x.op(0)==sym)
01036         return true;
01037     else
01038         for(size_t i=0; i<x.nops(); ++i)
01039             if(hasindex(x.op(i), sym))
01040                 return true;
01041     return false;
01042 }
01043 
01045 ex simplify_indexed(const ex & e, exvector & free_indices, exvector & dummy_indices, const scalar_products & sp)
01046 {
01047     // Expand the expression
01048     ex e_expanded = e.expand();
01049 
01050     // Simplification of single indexed object: just find the free indices
01051     // and perform dummy index renaming/repositioning
01052     if (is_a<indexed>(e_expanded)) {
01053 
01054         // Find the dummy indices
01055         const indexed &i = ex_to<indexed>(e_expanded);
01056         exvector local_dummy_indices;
01057         find_free_and_dummy(i.seq.begin() + 1, i.seq.end(), free_indices, local_dummy_indices);
01058 
01059         // Filter out the dummy indices with variance
01060         exvector variant_dummy_indices;
01061         find_variant_indices(local_dummy_indices, variant_dummy_indices);
01062 
01063         // Any indices with variance present at all?
01064         if (!variant_dummy_indices.empty()) {
01065 
01066             // Yes, reposition them
01067             exvector moved_indices;
01068             reposition_dummy_indices(e_expanded, variant_dummy_indices, moved_indices);
01069         }
01070 
01071         // Rename the dummy indices
01072         e_expanded = rename_dummy_indices<idx>(e_expanded, dummy_indices, local_dummy_indices);
01073         e_expanded = rename_dummy_indices<varidx>(e_expanded, dummy_indices, local_dummy_indices);
01074         e_expanded = rename_dummy_indices<spinidx>(e_expanded, dummy_indices, local_dummy_indices);
01075         return e_expanded;
01076     }
01077 
01078     // Simplification of sum = sum of simplifications, check consistency of
01079     // free indices in each term
01080     if (is_exactly_a<add>(e_expanded)) {
01081         bool first = true;
01082         ex sum;
01083         free_indices.clear();
01084 
01085         for (size_t i=0; i<e_expanded.nops(); i++) {
01086             exvector free_indices_of_term;
01087             ex term = simplify_indexed(e_expanded.op(i), free_indices_of_term, dummy_indices, sp);
01088             if (!term.is_zero()) {
01089                 if (first) {
01090                     free_indices = free_indices_of_term;
01091                     sum = term;
01092                     first = false;
01093                 } else {
01094                     if (!indices_consistent(free_indices, free_indices_of_term)) {
01095                         std::ostringstream s;
01096                         s << "simplify_indexed: inconsistent indices in sum: ";
01097                         s << exprseq(free_indices) << " vs. " << exprseq(free_indices_of_term);
01098                         throw (std::runtime_error(s.str()));
01099                     }
01100                     if (is_a<indexed>(sum) && is_a<indexed>(term))
01101                         sum = ex_to<basic>(sum.op(0)).add_indexed(sum, term);
01102                     else
01103                         sum += term;
01104                 }
01105             }
01106         }
01107 
01108         // If the sum turns out to be zero, we are finished
01109         if (sum.is_zero()) {
01110             free_indices.clear();
01111             return sum;
01112         }
01113 
01114         // More than one term and more than one dummy index?
01115         size_t num_terms_orig = (is_exactly_a<add>(sum) ? sum.nops() : 1);
01116         if (num_terms_orig < 2 || dummy_indices.size() < 2)
01117             return sum;
01118 
01119         // Chop the sum into terms and symmetrize each one over the dummy
01120         // indices
01121         std::vector<terminfo> terms;
01122         for (size_t i=0; i<sum.nops(); i++) {
01123             const ex & term = sum.op(i);
01124             exvector dummy_indices_of_term;
01125             dummy_indices_of_term.reserve(dummy_indices.size());
01126             for(exvector::iterator i=dummy_indices.begin(); i!=dummy_indices.end(); ++i)
01127                 if(hasindex(term,i->op(0)))
01128                     dummy_indices_of_term.push_back(*i);
01129             ex term_symm = idx_symmetrization<idx>(term, dummy_indices_of_term);
01130             term_symm = idx_symmetrization<varidx>(term_symm, dummy_indices_of_term);
01131             term_symm = idx_symmetrization<spinidx>(term_symm, dummy_indices_of_term);
01132             if (term_symm.is_zero())
01133                 continue;
01134             terms.push_back(terminfo(term, term_symm));
01135         }
01136 
01137         // Sort by symmetrized terms
01138         std::sort(terms.begin(), terms.end(), terminfo_is_less());
01139 
01140         // Combine equal symmetrized terms
01141         std::vector<terminfo> terms_pass2;
01142         for (std::vector<terminfo>::const_iterator i=terms.begin(); i!=terms.end(); ) {
01143             size_t num = 1;
01144             std::vector<terminfo>::const_iterator j = i + 1;
01145             while (j != terms.end() && j->symm == i->symm) {
01146                 num++;
01147                 j++;
01148             }
01149             terms_pass2.push_back(terminfo(i->orig * num, i->symm * num));
01150             i = j;
01151         }
01152 
01153         // If there is only one term left, we are finished
01154         if (terms_pass2.size() == 1)
01155             return terms_pass2[0].orig;
01156 
01157         // Chop the symmetrized terms into subterms
01158         std::vector<symminfo> sy;
01159         for (std::vector<terminfo>::const_iterator i=terms_pass2.begin(); i!=terms_pass2.end(); ++i) {
01160             if (is_exactly_a<add>(i->symm)) {
01161                 size_t num = i->symm.nops();
01162                 for (size_t j=0; j<num; j++)
01163                     sy.push_back(symminfo(i->symm.op(j), i->orig, num));
01164             } else
01165                 sy.push_back(symminfo(i->symm, i->orig, 1));
01166         }
01167 
01168         // Sort by symmetrized subterms
01169         std::sort(sy.begin(), sy.end(), symminfo_is_less_by_symmterm());
01170 
01171         // Combine equal symmetrized subterms
01172         std::vector<symminfo> sy_pass2;
01173         exvector result;
01174         for (std::vector<symminfo>::const_iterator i=sy.begin(); i!=sy.end(); ) {
01175 
01176             // Combine equal terms
01177             std::vector<symminfo>::const_iterator j = i + 1;
01178             if (j != sy.end() && j->symmterm == i->symmterm) {
01179 
01180                 // More than one term, collect the coefficients
01181                 ex coeff = i->coeff;
01182                 while (j != sy.end() && j->symmterm == i->symmterm) {
01183                     coeff += j->coeff;
01184                     j++;
01185                 }
01186 
01187                 // Add combined term to result
01188                 if (!coeff.is_zero())
01189                     result.push_back(coeff * i->symmterm);
01190 
01191             } else {
01192 
01193                 // Single term, store for second pass
01194                 sy_pass2.push_back(*i);
01195             }
01196 
01197             i = j;
01198         }
01199 
01200         // Were there any remaining terms that didn't get combined?
01201         if (sy_pass2.size() > 0) {
01202 
01203             // Yes, sort by their original terms
01204             std::sort(sy_pass2.begin(), sy_pass2.end(), symminfo_is_less_by_orig());
01205 
01206             for (std::vector<symminfo>::const_iterator i=sy_pass2.begin(); i!=sy_pass2.end(); ) {
01207 
01208                 // How many symmetrized terms of this original term are left?
01209                 size_t num = 1;
01210                 std::vector<symminfo>::const_iterator j = i + 1;
01211                 while (j != sy_pass2.end() && j->orig == i->orig) {
01212                     num++;
01213                     j++;
01214                 }
01215 
01216                 if (num == i->num) {
01217 
01218                     // All terms left, then add the original term to the result
01219                     result.push_back(i->orig);
01220 
01221                 } else {
01222 
01223                     // Some terms were combined with others, add up the remaining symmetrized terms
01224                     std::vector<symminfo>::const_iterator k;
01225                     for (k=i; k!=j; k++)
01226                         result.push_back(k->coeff * k->symmterm);
01227                 }
01228 
01229                 i = j;
01230             }
01231         }
01232 
01233         // Add all resulting terms
01234         ex sum_symm = (new add(result))->setflag(status_flags::dynallocated);
01235         if (sum_symm.is_zero())
01236             free_indices.clear();
01237         return sum_symm;
01238     }
01239 
01240     // Simplification of products
01241     if (is_exactly_a<mul>(e_expanded)
01242      || is_exactly_a<ncmul>(e_expanded)
01243      || (is_exactly_a<power>(e_expanded) && is_a<indexed>(e_expanded.op(0)) && e_expanded.op(1).is_equal(_ex2)))
01244         return simplify_indexed_product(e_expanded, free_indices, dummy_indices, sp);
01245 
01246     // Cannot do anything
01247     free_indices.clear();
01248     return e_expanded;
01249 }
01250 
01257 ex ex::simplify_indexed(unsigned options) const
01258 {
01259     exvector free_indices, dummy_indices;
01260     scalar_products sp;
01261     return GiNaC::simplify_indexed(*this, free_indices, dummy_indices, sp);
01262 }
01263 
01272 ex ex::simplify_indexed(const scalar_products & sp, unsigned options) const
01273 {
01274     exvector free_indices, dummy_indices;
01275     return GiNaC::simplify_indexed(*this, free_indices, dummy_indices, sp);
01276 }
01277 
01279 ex ex::symmetrize() const
01280 {
01281     return GiNaC::symmetrize(*this, get_free_indices());
01282 }
01283 
01285 ex ex::antisymmetrize() const
01286 {
01287     return GiNaC::antisymmetrize(*this, get_free_indices());
01288 }
01289 
01291 ex ex::symmetrize_cyclic() const
01292 {
01293     return GiNaC::symmetrize_cyclic(*this, get_free_indices());
01294 }
01295 
01297 // helper classes
01299 
01300 spmapkey::spmapkey(const ex & v1_, const ex & v2_, const ex & dim_) : dim(dim_)
01301 {
01302     // If indexed, extract base objects
01303     ex s1 = is_a<indexed>(v1_) ? v1_.op(0) : v1_;
01304     ex s2 = is_a<indexed>(v2_) ? v2_.op(0) : v2_;
01305 
01306     // Enforce canonical order in pair
01307     if (s1.compare(s2) > 0) {
01308         v1 = s2;
01309         v2 = s1;
01310     } else {
01311         v1 = s1;
01312         v2 = s2;
01313     }
01314 }
01315 
01316 bool spmapkey::operator==(const spmapkey &other) const
01317 {
01318     if (!v1.is_equal(other.v1))
01319         return false;
01320     if (!v2.is_equal(other.v2))
01321         return false;
01322     if (is_a<wildcard>(dim) || is_a<wildcard>(other.dim))
01323         return true;
01324     else
01325         return dim.is_equal(other.dim);
01326 }
01327 
01328 bool spmapkey::operator<(const spmapkey &other) const
01329 {
01330     int cmp = v1.compare(other.v1);
01331     if (cmp)
01332         return cmp < 0;
01333     cmp = v2.compare(other.v2);
01334     if (cmp)
01335         return cmp < 0;
01336 
01337     // Objects are equal, now check dimensions
01338     if (is_a<wildcard>(dim) || is_a<wildcard>(other.dim))
01339         return false;
01340     else
01341         return dim.compare(other.dim) < 0;
01342 }
01343 
01344 void spmapkey::debugprint() const
01345 {
01346     std::cerr << "(" << v1 << "," << v2 << "," << dim << ")";
01347 }
01348 
01349 void scalar_products::add(const ex & v1, const ex & v2, const ex & sp)
01350 {
01351     spm[spmapkey(v1, v2)] = sp;
01352 }
01353 
01354 void scalar_products::add(const ex & v1, const ex & v2, const ex & dim, const ex & sp)
01355 {
01356     spm[spmapkey(v1, v2, dim)] = sp;
01357 }
01358 
01359 void scalar_products::add_vectors(const lst & l, const ex & dim)
01360 {
01361     // Add all possible pairs of products
01362     for (lst::const_iterator it1 = l.begin(); it1 != l.end(); ++it1)
01363         for (lst::const_iterator it2 = l.begin(); it2 != l.end(); ++it2)
01364             add(*it1, *it2, *it1 * *it2);
01365 }
01366 
01367 void scalar_products::clear()
01368 {
01369     spm.clear();
01370 }
01371 
01373 bool scalar_products::is_defined(const ex & v1, const ex & v2, const ex & dim) const
01374 {
01375     return spm.find(spmapkey(v1, v2, dim)) != spm.end();
01376 }
01377 
01379 ex scalar_products::evaluate(const ex & v1, const ex & v2, const ex & dim) const
01380 {
01381     return spm.find(spmapkey(v1, v2, dim))->second;
01382 }
01383 
01384 void scalar_products::debugprint() const
01385 {
01386     std::cerr << "map size=" << spm.size() << std::endl;
01387     spmap::const_iterator i = spm.begin(), end = spm.end();
01388     while (i != end) {
01389         const spmapkey & k = i->first;
01390         std::cerr << "item key=";
01391         k.debugprint();
01392         std::cerr << ", value=" << i->second << std::endl;
01393         ++i;
01394     }
01395 }
01396 
01397 exvector get_all_dummy_indices_safely(const ex & e)
01398 {
01399     if (is_a<indexed>(e))
01400         return ex_to<indexed>(e).get_dummy_indices();
01401     else if (is_a<power>(e) && e.op(1)==2) {
01402         return e.op(0).get_free_indices();
01403     }   
01404     else if (is_a<mul>(e) || is_a<ncmul>(e)) {
01405         exvector dummies;
01406         exvector free_indices;
01407         for (std::size_t i = 0; i < e.nops(); ++i) {
01408             exvector dummies_of_factor = get_all_dummy_indices_safely(e.op(i));
01409             dummies.insert(dummies.end(), dummies_of_factor.begin(),
01410                 dummies_of_factor.end());
01411             exvector free_of_factor = e.op(i).get_free_indices();
01412             free_indices.insert(free_indices.begin(), free_of_factor.begin(),
01413                 free_of_factor.end());
01414         }
01415         exvector free_out, dummy_out;
01416         find_free_and_dummy(free_indices.begin(), free_indices.end(), free_out,
01417             dummy_out);
01418         dummies.insert(dummies.end(), dummy_out.begin(), dummy_out.end());
01419         return dummies;
01420     }
01421     else if(is_a<add>(e)) {
01422         exvector result;
01423         for(std::size_t i = 0; i < e.nops(); ++i) {
01424             exvector dummies_of_term = get_all_dummy_indices_safely(e.op(i));
01425             sort(dummies_of_term.begin(), dummies_of_term.end());
01426             exvector new_vec;
01427             set_union(result.begin(), result.end(), dummies_of_term.begin(),
01428                 dummies_of_term.end(), std::back_inserter<exvector>(new_vec),
01429                 ex_is_less());
01430             result.swap(new_vec);
01431         }
01432         return result;
01433     }
01434     return exvector();
01435 }
01436 
01438 exvector get_all_dummy_indices(const ex & e)
01439 {
01440     exvector p;
01441     bool nc;
01442     product_to_exvector(e, p, nc);
01443     exvector::const_iterator ip = p.begin(), ipend = p.end();
01444     exvector v, v1;
01445     while (ip != ipend) {
01446         if (is_a<indexed>(*ip)) {
01447             v1 = ex_to<indexed>(*ip).get_dummy_indices();
01448             v.insert(v.end(), v1.begin(), v1.end());
01449             exvector::const_iterator ip1 = ip+1;
01450             while (ip1 != ipend) {
01451                 if (is_a<indexed>(*ip1)) {
01452                     v1 = ex_to<indexed>(*ip).get_dummy_indices(ex_to<indexed>(*ip1));
01453                     v.insert(v.end(), v1.begin(), v1.end());
01454                 }
01455                 ++ip1;
01456             }
01457         }
01458         ++ip;
01459     }
01460     return v;
01461 }
01462 
01463 lst rename_dummy_indices_uniquely(const exvector & va, const exvector & vb)
01464 {
01465     exvector common_indices;
01466     set_intersection(va.begin(), va.end(), vb.begin(), vb.end(), std::back_insert_iterator<exvector>(common_indices), ex_is_less());
01467     if (common_indices.empty()) {
01468         return lst(lst(), lst());
01469     } else {
01470         exvector new_indices, old_indices;
01471         old_indices.reserve(2*common_indices.size());
01472         new_indices.reserve(2*common_indices.size());
01473         exvector::const_iterator ip = common_indices.begin(), ipend = common_indices.end();
01474         while (ip != ipend) {
01475             ex newsym=(new symbol)->setflag(status_flags::dynallocated);
01476             ex newidx;
01477             if(is_exactly_a<spinidx>(*ip))
01478                 newidx = (new spinidx(newsym, ex_to<spinidx>(*ip).get_dim(),
01479                         ex_to<spinidx>(*ip).is_covariant(),
01480                         ex_to<spinidx>(*ip).is_dotted()))
01481                     -> setflag(status_flags::dynallocated);
01482             else if (is_exactly_a<varidx>(*ip))
01483                 newidx = (new varidx(newsym, ex_to<varidx>(*ip).get_dim(),
01484                         ex_to<varidx>(*ip).is_covariant()))
01485                     -> setflag(status_flags::dynallocated);
01486             else
01487                 newidx = (new idx(newsym, ex_to<idx>(*ip).get_dim()))
01488                     -> setflag(status_flags::dynallocated);
01489             old_indices.push_back(*ip);
01490             new_indices.push_back(newidx);
01491             if(is_a<varidx>(*ip)) {
01492                 old_indices.push_back(ex_to<varidx>(*ip).toggle_variance());
01493                 new_indices.push_back(ex_to<varidx>(newidx).toggle_variance());
01494             }
01495             ++ip;
01496         }
01497         return lst(lst(old_indices.begin(), old_indices.end()), lst(new_indices.begin(), new_indices.end()));
01498     }
01499 }
01500 
01501 ex rename_dummy_indices_uniquely(const exvector & va, const exvector & vb, const ex & b)
01502 {
01503     lst indices_subs = rename_dummy_indices_uniquely(va, vb);
01504     return (indices_subs.op(0).nops()>0 ? b.subs(ex_to<lst>(indices_subs.op(0)), ex_to<lst>(indices_subs.op(1)), subs_options::no_pattern|subs_options::no_index_renaming) : b);
01505 }
01506 
01507 ex rename_dummy_indices_uniquely(const ex & a, const ex & b)
01508 {
01509     exvector va = get_all_dummy_indices_safely(a);
01510     if (va.size() > 0) {
01511         exvector vb = get_all_dummy_indices_safely(b);
01512         if (vb.size() > 0) {
01513             sort(va.begin(), va.end(), ex_is_less());
01514             sort(vb.begin(), vb.end(), ex_is_less());
01515             lst indices_subs = rename_dummy_indices_uniquely(va, vb);
01516             if (indices_subs.op(0).nops() > 0)
01517                 return b.subs(ex_to<lst>(indices_subs.op(0)), ex_to<lst>(indices_subs.op(1)), subs_options::no_pattern|subs_options::no_index_renaming);
01518         }
01519     }
01520     return b;
01521 }
01522 
01523 ex rename_dummy_indices_uniquely(exvector & va, const ex & b, bool modify_va)
01524 {
01525     if (va.size() > 0) {
01526         exvector vb = get_all_dummy_indices_safely(b);
01527         if (vb.size() > 0) {
01528             sort(vb.begin(), vb.end(), ex_is_less());
01529             lst indices_subs = rename_dummy_indices_uniquely(va, vb);
01530             if (indices_subs.op(0).nops() > 0) {
01531                 if (modify_va) {
01532                     for (lst::const_iterator i = ex_to<lst>(indices_subs.op(1)).begin(); i != ex_to<lst>(indices_subs.op(1)).end(); ++i)
01533                         va.push_back(*i);
01534                     exvector uncommon_indices;
01535                     set_difference(vb.begin(), vb.end(), indices_subs.op(0).begin(), indices_subs.op(0).end(), std::back_insert_iterator<exvector>(uncommon_indices), ex_is_less());
01536                     exvector::const_iterator ip = uncommon_indices.begin(), ipend = uncommon_indices.end();
01537                     while (ip != ipend) {
01538                         va.push_back(*ip);
01539                         ++ip;
01540                     }
01541                     sort(va.begin(), va.end(), ex_is_less());
01542                 }
01543                 return b.subs(ex_to<lst>(indices_subs.op(0)), ex_to<lst>(indices_subs.op(1)), subs_options::no_pattern|subs_options::no_index_renaming);
01544             }
01545         }
01546     }
01547     return b;
01548 }
01549 
01550 ex expand_dummy_sum(const ex & e, bool subs_idx)
01551 {
01552     ex e_expanded = e.expand();
01553     pointer_to_map_function_1arg<bool> fcn(expand_dummy_sum, subs_idx);
01554     if (is_a<add>(e_expanded) || is_a<lst>(e_expanded) || is_a<matrix>(e_expanded)) {
01555         return e_expanded.map(fcn);
01556     } else if (is_a<ncmul>(e_expanded) || is_a<mul>(e_expanded) || is_a<power>(e_expanded) || is_a<indexed>(e_expanded)) {
01557         exvector v;
01558         if (is_a<indexed>(e_expanded))
01559             v = ex_to<indexed>(e_expanded).get_dummy_indices();
01560         else
01561             v = get_all_dummy_indices(e_expanded);
01562         ex result = e_expanded;
01563         for(exvector::const_iterator it=v.begin(); it!=v.end(); ++it) {
01564             ex nu = *it;
01565             if (ex_to<idx>(nu).get_dim().info(info_flags::nonnegint)) {
01566                 int idim = ex_to<numeric>(ex_to<idx>(nu).get_dim()).to_int();
01567                 ex en = 0;
01568                 for (int i=0; i < idim; i++) {
01569                     if (subs_idx && is_a<varidx>(nu)) {
01570                         ex other = ex_to<varidx>(nu).toggle_variance();
01571                         en += result.subs(lst(
01572                             nu == idx(i, idim),
01573                             other == idx(i, idim)
01574                         ));
01575                     } else {
01576                         en += result.subs( nu.op(0) == i );
01577                     }
01578                 }
01579                 result = en;
01580             }
01581         }
01582         return result;
01583     } else {
01584         return e;
01585     }
01586 }
01587 
01588 } // namespace GiNaC

This page is part of the GiNaC developer's reference. It was generated automatically by doxygen. For an introduction, see the tutorial.