]> www.ginac.de Git - ginac.git/blobdiff - ginac/indexed.cpp
(p.i*p.i).get_free_indices() returned (.i) instead of ()
[ginac.git] / ginac / indexed.cpp
index 48bb14217dd06a6bf7b18792ea87d31c7fc54452..ad17a497a9dc703e52cd435d6901312071beab42 100644 (file)
@@ -3,7 +3,7 @@
  *  Implementation of GiNaC's indexed expressions. */
 
 /*
- *  GiNaC Copyright (C) 1999-2003 Johannes Gutenberg University Mainz, Germany
+ *  GiNaC Copyright (C) 1999-2004 Johannes Gutenberg University Mainz, Germany
  *
  *  This program is free software; you can redistribute it and/or modify
  *  it under the terms of the GNU General Public License as published by
 #include "symmetry.h"
 #include "operators.h"
 #include "lst.h"
-#include "print.h"
 #include "archive.h"
 #include "utils.h"
 
 namespace GiNaC {
 
-GINAC_IMPLEMENT_REGISTERED_CLASS(indexed, exprseq)
+GINAC_IMPLEMENT_REGISTERED_CLASS_OPT(indexed, exprseq,
+  print_func<print_context>(&indexed::do_print).
+  print_func<print_latex>(&indexed::do_print_latex).
+  print_func<print_tree>(&indexed::do_print_tree))
 
 //////////
 // default constructor
 //////////
 
-indexed::indexed() : symtree(sy_none())
+indexed::indexed() : symtree(not_symmetric())
 {
        tinfo_key = TINFO_indexed;
 }
@@ -55,31 +57,31 @@ indexed::indexed() : symtree(sy_none())
 // other constructors
 //////////
 
-indexed::indexed(const ex & b) : inherited(b), symtree(sy_none())
+indexed::indexed(const ex & b) : inherited(b), symtree(not_symmetric())
 {
        tinfo_key = TINFO_indexed;
        validate();
 }
 
-indexed::indexed(const ex & b, const ex & i1) : inherited(b, i1), symtree(sy_none())
+indexed::indexed(const ex & b, const ex & i1) : inherited(b, i1), symtree(not_symmetric())
 {
        tinfo_key = TINFO_indexed;
        validate();
 }
 
-indexed::indexed(const ex & b, const ex & i1, const ex & i2) : inherited(b, i1, i2), symtree(sy_none())
+indexed::indexed(const ex & b, const ex & i1, const ex & i2) : inherited(b, i1, i2), symtree(not_symmetric())
 {
        tinfo_key = TINFO_indexed;
        validate();
 }
 
-indexed::indexed(const ex & b, const ex & i1, const ex & i2, const ex & i3) : inherited(b, i1, i2, i3), symtree(sy_none())
+indexed::indexed(const ex & b, const ex & i1, const ex & i2, const ex & i3) : inherited(b, i1, i2, i3), symtree(not_symmetric())
 {
        tinfo_key = TINFO_indexed;
        validate();
 }
 
-indexed::indexed(const ex & b, const ex & i1, const ex & i2, const ex & i3, const ex & i4) : inherited(b, i1, i2, i3, i4), symtree(sy_none())
+indexed::indexed(const ex & b, const ex & i1, const ex & i2, const ex & i3, const ex & i4) : inherited(b, i1, i2, i3, i4), symtree(not_symmetric())
 {
        tinfo_key = TINFO_indexed;
        validate();
@@ -103,7 +105,7 @@ indexed::indexed(const ex & b, const symmetry & symm, const ex & i1, const ex &
        validate();
 }
 
-indexed::indexed(const ex & b, const exvector & v) : inherited(b), symtree(sy_none())
+indexed::indexed(const ex & b, const exvector & v) : inherited(b), symtree(not_symmetric())
 {
        seq.insert(seq.end(), v.begin(), v.end());
        tinfo_key = TINFO_indexed;
@@ -127,7 +129,7 @@ indexed::indexed(const symmetry & symm, const exvector & v, bool discardable) :
        tinfo_key = TINFO_indexed;
 }
 
-indexed::indexed(const symmetry & symm, exvector * vp) : inherited(vp), symtree(symm)
+indexed::indexed(const symmetry & symm, std::auto_ptr<exvector> vp) : inherited(vp), symtree(symm)
 {
        tinfo_key = TINFO_indexed;
 }
@@ -150,7 +152,7 @@ indexed::indexed(const archive_node &n, lst &sym_lst) : inherited(n, sym_lst)
                                symtree = sy_anti();
                                break;
                        default:
-                               symtree = sy_none();
+                               symtree = not_symmetric();
                                break;
                }
                const_cast<symmetry &>(ex_to<symmetry>(symtree)).validate(seq.size() - 1);
@@ -169,38 +171,80 @@ DEFAULT_UNARCHIVE(indexed)
 // functions overriding virtual functions from base classes
 //////////
 
-void indexed::print(const print_context & c, unsigned level) const
+void indexed::printindices(const print_context & c, unsigned level) const
 {
-       GINAC_ASSERT(seq.size() > 0);
-
-       if (is_a<print_tree>(c)) {
+       if (seq.size() > 1) {
 
-               c.s << std::string(level, ' ') << class_name()
-                   << std::hex << ", hash=0x" << hashvalue << ", flags=0x" << flags << std::dec
-                   << ", " << seq.size()-1 << " indices"
-                   << ", symmetry=" << symtree << std::endl;
-               unsigned delta_indent = static_cast<const print_tree &>(c).delta_indent;
-               seq[0].print(c, level + delta_indent);
-               printindices(c, level + delta_indent);
+               exvector::const_iterator it=seq.begin() + 1, itend = seq.end();
 
-       } else {
+               if (is_a<print_latex>(c)) {
 
-               bool is_tex = is_a<print_latex>(c);
-               const ex & base = seq[0];
+                       // TeX output: group by variance
+                       bool first = true;
+                       bool covariant = true;
 
-               if (precedence() <= level)
-                       c.s << (is_tex ? "{(" : "(");
-               if (is_tex)
-                       c.s << "{";
-               base.print(c, precedence());
-               if (is_tex)
+                       while (it != itend) {
+                               bool cur_covariant = (is_a<varidx>(*it) ? ex_to<varidx>(*it).is_covariant() : true);
+                               if (first || cur_covariant != covariant) { // Variance changed
+                                       // The empty {} prevents indices from ending up on top of each other
+                                       if (!first)
+                                               c.s << "}{}";
+                                       covariant = cur_covariant;
+                                       if (covariant)
+                                               c.s << "_{";
+                                       else
+                                               c.s << "^{";
+                               }
+                               it->print(c, level);
+                               c.s << " ";
+                               first = false;
+                               it++;
+                       }
                        c.s << "}";
-               printindices(c, level);
-               if (precedence() <= level)
-                       c.s << (is_tex ? ")}" : ")");
+
+               } else {
+
+                       // Ordinary output
+                       while (it != itend) {
+                               it->print(c, level);
+                               it++;
+                       }
+               }
        }
 }
 
+void indexed::print_indexed(const print_context & c, const char *openbrace, const char *closebrace, unsigned level) const
+{
+       if (precedence() <= level)
+               c.s << openbrace << '(';
+       c.s << openbrace;
+       seq[0].print(c, precedence());
+       c.s << closebrace;
+       printindices(c, level);
+       if (precedence() <= level)
+               c.s << ')' << closebrace;
+}
+
+void indexed::do_print(const print_context & c, unsigned level) const
+{
+       print_indexed(c, "", "", level);
+}
+
+void indexed::do_print_latex(const print_latex & c, unsigned level) const
+{
+       print_indexed(c, "{", "}", level);
+}
+
+void indexed::do_print_tree(const print_tree & c, unsigned level) const
+{
+       c.s << std::string(level, ' ') << class_name() << " @" << this
+           << std::hex << ", hash=0x" << hashvalue << ", flags=0x" << flags << std::dec
+           << ", " << seq.size()-1 << " indices"
+           << ", symmetry=" << symtree << std::endl;
+       seq[0].print(c, level + c.delta_indent);
+       printindices(c, level + c.delta_indent);
+}
+
 bool indexed::info(unsigned inf) const
 {
        if (inf == info_flags::indexed) return true;
@@ -272,7 +316,7 @@ ex indexed::thiscontainer(const exvector & v) const
        return indexed(ex_to<symmetry>(symtree), v);
 }
 
-ex indexed::thiscontainer(exvector * vp) const
+ex indexed::thiscontainer(std::auto_ptr<exvector> vp) const
 {
        return indexed(ex_to<symmetry>(symtree), vp);
 }
@@ -281,20 +325,24 @@ ex indexed::expand(unsigned options) const
 {
        GINAC_ASSERT(seq.size() > 0);
 
-       if ((options & expand_options::expand_indexed) && is_exactly_a<add>(seq[0])) {
-
-               // expand_indexed expands (a+b).i -> a.i + b.i
-               const ex & base = seq[0];
-               ex sum = _ex0;
-               for (size_t i=0; i<base.nops(); i++) {
+       if (options & expand_options::expand_indexed) {
+               ex newbase = seq[0].expand(options);
+               if (is_exactly_a<add>(newbase)) {
+                       ex sum = _ex0;
+                       for (size_t i=0; i<newbase.nops(); i++) {
+                               exvector s = seq;
+                               s[0] = newbase.op(i);
+                               sum += thiscontainer(s).expand(options);
+                       }
+                       return sum;
+               }
+               if (!are_ex_trivially_equal(newbase, seq[0])) {
                        exvector s = seq;
-                       s[0] = base.op(i);
-                       sum += thiscontainer(s).expand();
+                       s[0] = newbase;
+                       return ex_to<indexed>(thiscontainer(s)).inherited::expand(options);
                }
-               return sum;
-
-       } else
-               return inherited::expand(options);
+       }
+       return inherited::expand(options);
 }
 
 //////////
@@ -307,48 +355,6 @@ ex indexed::expand(unsigned options) const
 // non-virtual functions in this class
 //////////
 
-void indexed::printindices(const print_context & c, unsigned level) const
-{
-       if (seq.size() > 1) {
-
-               exvector::const_iterator it=seq.begin() + 1, itend = seq.end();
-
-               if (is_a<print_latex>(c)) {
-
-                       // TeX output: group by variance
-                       bool first = true;
-                       bool covariant = true;
-
-                       while (it != itend) {
-                               bool cur_covariant = (is_a<varidx>(*it) ? ex_to<varidx>(*it).is_covariant() : true);
-                               if (first || cur_covariant != covariant) { // Variance changed
-                                       // The empty {} prevents indices from ending up on top of each other
-                                       if (!first)
-                                               c.s << "}{}";
-                                       covariant = cur_covariant;
-                                       if (covariant)
-                                               c.s << "_{";
-                                       else
-                                               c.s << "^{";
-                               }
-                               it->print(c, level);
-                               c.s << " ";
-                               first = false;
-                               it++;
-                       }
-                       c.s << "}";
-
-               } else {
-
-                       // Ordinary output
-                       while (it != itend) {
-                               it->print(c, level);
-                               it++;
-                       }
-               }
-       }
-}
-
 /** Check whether all indices are of class idx and validate the symmetry
  *  tree. This function is used internally to make sure that all constructed
  *  indexed objects really carry indices and not some other classes. */
@@ -493,10 +499,27 @@ exvector ncmul::get_free_indices() const
        return free_indices;
 }
 
+struct is_summation_idx : public std::unary_function<ex, bool> {
+       bool operator()(const ex & e)
+       {
+               return is_dummy_pair(e, e);
+       }
+};
+
 exvector power::get_free_indices() const
 {
-       // Return free indices of basis
-       return basis.get_free_indices();
+       // Get free indices of basis
+       exvector basis_indices = basis.get_free_indices();
+
+       if (exponent.info(info_flags::even)) {
+               // If the exponent is an even number, then any "free" index that
+               // forms a dummy pair with itself is actually a summation index
+               exvector really_free;
+               std::remove_copy_if(basis_indices.begin(), basis_indices.end(),
+                                   std::back_inserter(really_free), is_summation_idx());
+               return really_free;
+       } else
+               return basis_indices;
 }
 
 /** Rename dummy indices in an expression.
@@ -1106,7 +1129,7 @@ ex simplify_indexed(const ex & e, exvector & free_indices, exvector & dummy_indi
  *  the free indices in sums are consistent.
  *
  *  @return simplified expression */
-ex ex::simplify_indexed() const
+ex ex::simplify_indexed(unsigned options) const
 {
        exvector free_indices, dummy_indices;
        scalar_products sp;
@@ -1120,7 +1143,7 @@ ex ex::simplify_indexed() const
  *
  *  @param sp Scalar products to be replaced automatically
  *  @return simplified expression */
-ex ex::simplify_indexed(const scalar_products & sp) const
+ex ex::simplify_indexed(const scalar_products & sp, unsigned options) const
 {
        exvector free_indices, dummy_indices;
        return GiNaC::simplify_indexed(*this, free_indices, dummy_indices, sp);