412aad52d9184206e34523beb75d494aa96fd6da
[ginac.git] / ginac / indexed.cpp
1 /** @file indexed.cpp
2  *
3  *  Implementation of GiNaC's indexed expressions. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2003 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
21  */
22
23 #include <iostream>
24 #include <sstream>
25 #include <stdexcept>
26
27 #include "indexed.h"
28 #include "idx.h"
29 #include "add.h"
30 #include "mul.h"
31 #include "ncmul.h"
32 #include "power.h"
33 #include "relational.h"
34 #include "symmetry.h"
35 #include "operators.h"
36 #include "lst.h"
37 #include "archive.h"
38 #include "utils.h"
39
40 namespace GiNaC {
41
42 GINAC_IMPLEMENT_REGISTERED_CLASS(indexed, exprseq)
43
44 //////////
45 // default constructor
46 //////////
47
48 indexed::indexed() : symtree(sy_none())
49 {
50         tinfo_key = TINFO_indexed;
51 }
52
53 //////////
54 // other constructors
55 //////////
56
57 indexed::indexed(const ex & b) : inherited(b), symtree(sy_none())
58 {
59         tinfo_key = TINFO_indexed;
60         validate();
61 }
62
63 indexed::indexed(const ex & b, const ex & i1) : inherited(b, i1), symtree(sy_none())
64 {
65         tinfo_key = TINFO_indexed;
66         validate();
67 }
68
69 indexed::indexed(const ex & b, const ex & i1, const ex & i2) : inherited(b, i1, i2), symtree(sy_none())
70 {
71         tinfo_key = TINFO_indexed;
72         validate();
73 }
74
75 indexed::indexed(const ex & b, const ex & i1, const ex & i2, const ex & i3) : inherited(b, i1, i2, i3), symtree(sy_none())
76 {
77         tinfo_key = TINFO_indexed;
78         validate();
79 }
80
81 indexed::indexed(const ex & b, const ex & i1, const ex & i2, const ex & i3, const ex & i4) : inherited(b, i1, i2, i3, i4), symtree(sy_none())
82 {
83         tinfo_key = TINFO_indexed;
84         validate();
85 }
86
87 indexed::indexed(const ex & b, const symmetry & symm, const ex & i1, const ex & i2) : inherited(b, i1, i2), symtree(symm)
88 {
89         tinfo_key = TINFO_indexed;
90         validate();
91 }
92
93 indexed::indexed(const ex & b, const symmetry & symm, const ex & i1, const ex & i2, const ex & i3) : inherited(b, i1, i2, i3), symtree(symm)
94 {
95         tinfo_key = TINFO_indexed;
96         validate();
97 }
98
99 indexed::indexed(const ex & b, const symmetry & symm, const ex & i1, const ex & i2, const ex & i3, const ex & i4) : inherited(b, i1, i2, i3, i4), symtree(symm)
100 {
101         tinfo_key = TINFO_indexed;
102         validate();
103 }
104
105 indexed::indexed(const ex & b, const exvector & v) : inherited(b), symtree(sy_none())
106 {
107         seq.insert(seq.end(), v.begin(), v.end());
108         tinfo_key = TINFO_indexed;
109         validate();
110 }
111
112 indexed::indexed(const ex & b, const symmetry & symm, const exvector & v) : inherited(b), symtree(symm)
113 {
114         seq.insert(seq.end(), v.begin(), v.end());
115         tinfo_key = TINFO_indexed;
116         validate();
117 }
118
119 indexed::indexed(const symmetry & symm, const exprseq & es) : inherited(es), symtree(symm)
120 {
121         tinfo_key = TINFO_indexed;
122 }
123
124 indexed::indexed(const symmetry & symm, const exvector & v, bool discardable) : inherited(v, discardable), symtree(symm)
125 {
126         tinfo_key = TINFO_indexed;
127 }
128
129 indexed::indexed(const symmetry & symm, exvector * vp) : inherited(vp), symtree(symm)
130 {
131         tinfo_key = TINFO_indexed;
132 }
133
134 //////////
135 // archiving
136 //////////
137
138 indexed::indexed(const archive_node &n, lst &sym_lst) : inherited(n, sym_lst)
139 {
140         if (!n.find_ex("symmetry", symtree, sym_lst)) {
141                 // GiNaC versions <= 0.9.0 had an unsigned "symmetry" property
142                 unsigned symm = 0;
143                 n.find_unsigned("symmetry", symm);
144                 switch (symm) {
145                         case 1:
146                                 symtree = sy_symm();
147                                 break;
148                         case 2:
149                                 symtree = sy_anti();
150                                 break;
151                         default:
152                                 symtree = sy_none();
153                                 break;
154                 }
155                 const_cast<symmetry &>(ex_to<symmetry>(symtree)).validate(seq.size() - 1);
156         }
157 }
158
159 void indexed::archive(archive_node &n) const
160 {
161         inherited::archive(n);
162         n.add_ex("symmetry", symtree);
163 }
164
165 DEFAULT_UNARCHIVE(indexed)
166
167 //////////
168 // functions overriding virtual functions from base classes
169 //////////
170
171 void indexed::print(const print_context & c, unsigned level) const
172 {
173         GINAC_ASSERT(seq.size() > 0);
174
175         if (is_a<print_tree>(c)) {
176
177                 c.s << std::string(level, ' ') << class_name()
178                     << std::hex << ", hash=0x" << hashvalue << ", flags=0x" << flags << std::dec
179                     << ", " << seq.size()-1 << " indices"
180                     << ", symmetry=" << symtree << std::endl;
181                 unsigned delta_indent = static_cast<const print_tree &>(c).delta_indent;
182                 seq[0].print(c, level + delta_indent);
183                 printindices(c, level + delta_indent);
184
185         } else {
186
187                 bool is_tex = is_a<print_latex>(c);
188                 const ex & base = seq[0];
189
190                 if (precedence() <= level)
191                         c.s << (is_tex ? "{(" : "(");
192                 if (is_tex)
193                         c.s << "{";
194                 base.print(c, precedence());
195                 if (is_tex)
196                         c.s << "}";
197                 printindices(c, level);
198                 if (precedence() <= level)
199                         c.s << (is_tex ? ")}" : ")");
200         }
201 }
202
203 bool indexed::info(unsigned inf) const
204 {
205         if (inf == info_flags::indexed) return true;
206         if (inf == info_flags::has_indices) return seq.size() > 1;
207         return inherited::info(inf);
208 }
209
210 struct idx_is_not : public std::binary_function<ex, unsigned, bool> {
211         bool operator() (const ex & e, unsigned inf) const {
212                 return !(ex_to<idx>(e).get_value().info(inf));
213         }
214 };
215
216 bool indexed::all_index_values_are(unsigned inf) const
217 {
218         // No indices? Then no property can be fulfilled
219         if (seq.size() < 2)
220                 return false;
221
222         // Check all indices
223         return find_if(seq.begin() + 1, seq.end(), bind2nd(idx_is_not(), inf)) == seq.end();
224 }
225
226 int indexed::compare_same_type(const basic & other) const
227 {
228         GINAC_ASSERT(is_a<indexed>(other));
229         return inherited::compare_same_type(other);
230 }
231
232 ex indexed::eval(int level) const
233 {
234         // First evaluate children, then we will end up here again
235         if (level > 1)
236                 return indexed(ex_to<symmetry>(symtree), evalchildren(level));
237
238         const ex &base = seq[0];
239
240         // If the base object is 0, the whole object is 0
241         if (base.is_zero())
242                 return _ex0;
243
244         // If the base object is a product, pull out the numeric factor
245         if (is_exactly_a<mul>(base) && is_exactly_a<numeric>(base.op(base.nops() - 1))) {
246                 exvector v(seq);
247                 ex f = ex_to<numeric>(base.op(base.nops() - 1));
248                 v[0] = seq[0] / f;
249                 return f * thiscontainer(v);
250         }
251
252         // Canonicalize indices according to the symmetry properties
253         if (seq.size() > 2) {
254                 exvector v = seq;
255                 GINAC_ASSERT(is_exactly_a<symmetry>(symtree));
256                 int sig = canonicalize(v.begin() + 1, ex_to<symmetry>(symtree));
257                 if (sig != INT_MAX) {
258                         // Something has changed while sorting indices, more evaluations later
259                         if (sig == 0)
260                                 return _ex0;
261                         return ex(sig) * thiscontainer(v);
262                 }
263         }
264
265         // Let the class of the base object perform additional evaluations
266         return ex_to<basic>(base).eval_indexed(*this);
267 }
268
269 ex indexed::thiscontainer(const exvector & v) const
270 {
271         return indexed(ex_to<symmetry>(symtree), v);
272 }
273
274 ex indexed::thiscontainer(exvector * vp) const
275 {
276         return indexed(ex_to<symmetry>(symtree), vp);
277 }
278
279 ex indexed::expand(unsigned options) const
280 {
281         GINAC_ASSERT(seq.size() > 0);
282
283         if ((options & expand_options::expand_indexed) && is_exactly_a<add>(seq[0])) {
284
285                 // expand_indexed expands (a+b).i -> a.i + b.i
286                 const ex & base = seq[0];
287                 ex sum = _ex0;
288                 for (size_t i=0; i<base.nops(); i++) {
289                         exvector s = seq;
290                         s[0] = base.op(i);
291                         sum += thiscontainer(s).expand();
292                 }
293                 return sum;
294
295         } else
296                 return inherited::expand(options);
297 }
298
299 //////////
300 // virtual functions which can be overridden by derived classes
301 //////////
302
303 // none
304
305 //////////
306 // non-virtual functions in this class
307 //////////
308
309 void indexed::printindices(const print_context & c, unsigned level) const
310 {
311         if (seq.size() > 1) {
312
313                 exvector::const_iterator it=seq.begin() + 1, itend = seq.end();
314
315                 if (is_a<print_latex>(c)) {
316
317                         // TeX output: group by variance
318                         bool first = true;
319                         bool covariant = true;
320
321                         while (it != itend) {
322                                 bool cur_covariant = (is_a<varidx>(*it) ? ex_to<varidx>(*it).is_covariant() : true);
323                                 if (first || cur_covariant != covariant) { // Variance changed
324                                         // The empty {} prevents indices from ending up on top of each other
325                                         if (!first)
326                                                 c.s << "}{}";
327                                         covariant = cur_covariant;
328                                         if (covariant)
329                                                 c.s << "_{";
330                                         else
331                                                 c.s << "^{";
332                                 }
333                                 it->print(c, level);
334                                 c.s << " ";
335                                 first = false;
336                                 it++;
337                         }
338                         c.s << "}";
339
340                 } else {
341
342                         // Ordinary output
343                         while (it != itend) {
344                                 it->print(c, level);
345                                 it++;
346                         }
347                 }
348         }
349 }
350
351 /** Check whether all indices are of class idx and validate the symmetry
352  *  tree. This function is used internally to make sure that all constructed
353  *  indexed objects really carry indices and not some other classes. */
354 void indexed::validate() const
355 {
356         GINAC_ASSERT(seq.size() > 0);
357         exvector::const_iterator it = seq.begin() + 1, itend = seq.end();
358         while (it != itend) {
359                 if (!is_a<idx>(*it))
360                         throw(std::invalid_argument("indices of indexed object must be of type idx"));
361                 it++;
362         }
363
364         if (!symtree.is_zero()) {
365                 if (!is_exactly_a<symmetry>(symtree))
366                         throw(std::invalid_argument("symmetry of indexed object must be of type symmetry"));
367                 const_cast<symmetry &>(ex_to<symmetry>(symtree)).validate(seq.size() - 1);
368         }
369 }
370
371 /** Implementation of ex::diff() for an indexed object always returns 0.
372  *
373  *  @see ex::diff */
374 ex indexed::derivative(const symbol & s) const
375 {
376         return _ex0;
377 }
378
379 //////////
380 // global functions
381 //////////
382
383 struct idx_is_equal_ignore_dim : public std::binary_function<ex, ex, bool> {
384         bool operator() (const ex &lh, const ex &rh) const
385         {
386                 if (lh.is_equal(rh))
387                         return true;
388                 else
389                         try {
390                                 // Replacing the dimension might cause an error (e.g. with
391                                 // index classes that only work in a fixed number of dimensions)
392                                 return lh.is_equal(ex_to<idx>(rh).replace_dim(ex_to<idx>(lh).get_dim()));
393                         } catch (...) {
394                                 return false;
395                         }
396         }
397 };
398
399 /** Check whether two sorted index vectors are consistent (i.e. equal). */
400 static bool indices_consistent(const exvector & v1, const exvector & v2)
401 {
402         // Number of indices must be the same
403         if (v1.size() != v2.size())
404                 return false;
405
406         return equal(v1.begin(), v1.end(), v2.begin(), idx_is_equal_ignore_dim());
407 }
408
409 exvector indexed::get_indices() const
410 {
411         GINAC_ASSERT(seq.size() >= 1);
412         return exvector(seq.begin() + 1, seq.end());
413 }
414
415 exvector indexed::get_dummy_indices() const
416 {
417         exvector free_indices, dummy_indices;
418         find_free_and_dummy(seq.begin() + 1, seq.end(), free_indices, dummy_indices);
419         return dummy_indices;
420 }
421
422 exvector indexed::get_dummy_indices(const indexed & other) const
423 {
424         exvector indices = get_free_indices();
425         exvector other_indices = other.get_free_indices();
426         indices.insert(indices.end(), other_indices.begin(), other_indices.end());
427         exvector dummy_indices;
428         find_dummy_indices(indices, dummy_indices);
429         return dummy_indices;
430 }
431
432 bool indexed::has_dummy_index_for(const ex & i) const
433 {
434         exvector::const_iterator it = seq.begin() + 1, itend = seq.end();
435         while (it != itend) {
436                 if (is_dummy_pair(*it, i))
437                         return true;
438                 it++;
439         }
440         return false;
441 }
442
443 exvector indexed::get_free_indices() const
444 {
445         exvector free_indices, dummy_indices;
446         find_free_and_dummy(seq.begin() + 1, seq.end(), free_indices, dummy_indices);
447         return free_indices;
448 }
449
450 exvector add::get_free_indices() const
451 {
452         exvector free_indices;
453         for (size_t i=0; i<nops(); i++) {
454                 if (i == 0)
455                         free_indices = op(i).get_free_indices();
456                 else {
457                         exvector free_indices_of_term = op(i).get_free_indices();
458                         if (!indices_consistent(free_indices, free_indices_of_term))
459                                 throw (std::runtime_error("add::get_free_indices: inconsistent indices in sum"));
460                 }
461         }
462         return free_indices;
463 }
464
465 exvector mul::get_free_indices() const
466 {
467         // Concatenate free indices of all factors
468         exvector un;
469         for (size_t i=0; i<nops(); i++) {
470                 exvector free_indices_of_factor = op(i).get_free_indices();
471                 un.insert(un.end(), free_indices_of_factor.begin(), free_indices_of_factor.end());
472         }
473
474         // And remove the dummy indices
475         exvector free_indices, dummy_indices;
476         find_free_and_dummy(un, free_indices, dummy_indices);
477         return free_indices;
478 }
479
480 exvector ncmul::get_free_indices() const
481 {
482         // Concatenate free indices of all factors
483         exvector un;
484         for (size_t i=0; i<nops(); i++) {
485                 exvector free_indices_of_factor = op(i).get_free_indices();
486                 un.insert(un.end(), free_indices_of_factor.begin(), free_indices_of_factor.end());
487         }
488
489         // And remove the dummy indices
490         exvector free_indices, dummy_indices;
491         find_free_and_dummy(un, free_indices, dummy_indices);
492         return free_indices;
493 }
494
495 exvector power::get_free_indices() const
496 {
497         // Return free indices of basis
498         return basis.get_free_indices();
499 }
500
501 /** Rename dummy indices in an expression.
502  *
503  *  @param e Expression to work on
504  *  @param local_dummy_indices The set of dummy indices that appear in the
505  *    expression "e"
506  *  @param global_dummy_indices The set of dummy indices that have appeared
507  *    before and which we would like to use in "e", too. This gets updated
508  *    by the function */
509 static ex rename_dummy_indices(const ex & e, exvector & global_dummy_indices, exvector & local_dummy_indices)
510 {
511         size_t global_size = global_dummy_indices.size(),
512                local_size = local_dummy_indices.size();
513
514         // Any local dummy indices at all?
515         if (local_size == 0)
516                 return e;
517
518         if (global_size < local_size) {
519
520                 // More local indices than we encountered before, add the new ones
521                 // to the global set
522                 size_t old_global_size = global_size;
523                 int remaining = local_size - global_size;
524                 exvector::const_iterator it = local_dummy_indices.begin(), itend = local_dummy_indices.end();
525                 while (it != itend && remaining > 0) {
526                         if (find_if(global_dummy_indices.begin(), global_dummy_indices.end(), bind2nd(op0_is_equal(), *it)) == global_dummy_indices.end()) {
527                                 global_dummy_indices.push_back(*it);
528                                 global_size++;
529                                 remaining--;
530                         }
531                         it++;
532                 }
533
534                 // If this is the first set of local indices, do nothing
535                 if (old_global_size == 0)
536                         return e;
537         }
538         GINAC_ASSERT(local_size <= global_size);
539
540         // Construct vectors of index symbols
541         exvector local_syms, global_syms;
542         local_syms.reserve(local_size);
543         global_syms.reserve(local_size);
544         for (size_t i=0; i<local_size; i++)
545                 local_syms.push_back(local_dummy_indices[i].op(0));
546         shaker_sort(local_syms.begin(), local_syms.end(), ex_is_less(), ex_swap());
547         for (size_t i=0; i<local_size; i++) // don't use more global symbols than necessary
548                 global_syms.push_back(global_dummy_indices[i].op(0));
549         shaker_sort(global_syms.begin(), global_syms.end(), ex_is_less(), ex_swap());
550
551         // Remove common indices
552         exvector local_uniq, global_uniq;
553         set_difference(local_syms.begin(), local_syms.end(), global_syms.begin(), global_syms.end(), std::back_insert_iterator<exvector>(local_uniq), ex_is_less());
554         set_difference(global_syms.begin(), global_syms.end(), local_syms.begin(), local_syms.end(), std::back_insert_iterator<exvector>(global_uniq), ex_is_less());
555
556         // Replace remaining non-common local index symbols by global ones
557         if (local_uniq.empty())
558                 return e;
559         else {
560                 while (global_uniq.size() > local_uniq.size())
561                         global_uniq.pop_back();
562                 return e.subs(lst(local_uniq.begin(), local_uniq.end()), lst(global_uniq.begin(), global_uniq.end()), subs_options::no_pattern);
563         }
564 }
565
566 /** Given a set of indices, extract those of class varidx. */
567 static void find_variant_indices(const exvector & v, exvector & variant_indices)
568 {
569         exvector::const_iterator it1, itend;
570         for (it1 = v.begin(), itend = v.end(); it1 != itend; ++it1) {
571                 if (is_exactly_a<varidx>(*it1))
572                         variant_indices.push_back(*it1);
573         }
574 }
575
576 /** Raise/lower dummy indices in a single indexed objects to canonicalize their
577  *  variance.
578  *
579  *  @param e Object to work on
580  *  @param variant_dummy_indices The set of indices that might need repositioning (will be changed by this function)
581  *  @param moved_indices The set of indices that have been repositioned (will be changed by this function)
582  *  @return true if 'e' was changed */
583 bool reposition_dummy_indices(ex & e, exvector & variant_dummy_indices, exvector & moved_indices)
584 {
585         bool something_changed = false;
586
587         // If a dummy index is encountered for the first time in the
588         // product, pull it up, otherwise, pull it down
589         exvector::const_iterator it2, it2start, it2end;
590         for (it2start = ex_to<indexed>(e).seq.begin(), it2end = ex_to<indexed>(e).seq.end(), it2 = it2start + 1; it2 != it2end; ++it2) {
591                 if (!is_exactly_a<varidx>(*it2))
592                         continue;
593
594                 exvector::iterator vit, vitend;
595                 for (vit = variant_dummy_indices.begin(), vitend = variant_dummy_indices.end(); vit != vitend; ++vit) {
596                         if (it2->op(0).is_equal(vit->op(0))) {
597                                 if (ex_to<varidx>(*it2).is_covariant()) {
598                                         e = e.subs(lst(
599                                                 *it2 == ex_to<varidx>(*it2).toggle_variance(),
600                                                 ex_to<varidx>(*it2).toggle_variance() == *it2
601                                         ), subs_options::no_pattern);
602                                         something_changed = true;
603                                         it2 = ex_to<indexed>(e).seq.begin() + (it2 - it2start);
604                                         it2start = ex_to<indexed>(e).seq.begin();
605                                         it2end = ex_to<indexed>(e).seq.end();
606                                 }
607                                 moved_indices.push_back(*vit);
608                                 variant_dummy_indices.erase(vit);
609                                 goto next_index;
610                         }
611                 }
612
613                 for (vit = moved_indices.begin(), vitend = moved_indices.end(); vit != vitend; ++vit) {
614                         if (it2->op(0).is_equal(vit->op(0))) {
615                                 if (ex_to<varidx>(*it2).is_contravariant()) {
616                                         e = e.subs(*it2 == ex_to<varidx>(*it2).toggle_variance(), subs_options::no_pattern);
617                                         something_changed = true;
618                                         it2 = ex_to<indexed>(e).seq.begin() + (it2 - it2start);
619                                         it2start = ex_to<indexed>(e).seq.begin();
620                                         it2end = ex_to<indexed>(e).seq.end();
621                                 }
622                                 goto next_index;
623                         }
624                 }
625
626 next_index: ;
627         }
628
629         return something_changed;
630 }
631
632 /* Ordering that only compares the base expressions of indexed objects. */
633 struct ex_base_is_less : public std::binary_function<ex, ex, bool> {
634         bool operator() (const ex &lh, const ex &rh) const
635         {
636                 return (is_a<indexed>(lh) ? lh.op(0) : lh).compare(is_a<indexed>(rh) ? rh.op(0) : rh) < 0;
637         }
638 };
639
640 /** Simplify product of indexed expressions (commutative, noncommutative and
641  *  simple squares), return list of free indices. */
642 ex simplify_indexed_product(const ex & e, exvector & free_indices, exvector & dummy_indices, const scalar_products & sp)
643 {
644         // Remember whether the product was commutative or noncommutative
645         // (because we chop it into factors and need to reassemble later)
646         bool non_commutative = is_exactly_a<ncmul>(e);
647
648         // Collect factors in an exvector, store squares twice
649         exvector v;
650         v.reserve(e.nops() * 2);
651
652         if (is_exactly_a<power>(e)) {
653                 // We only get called for simple squares, split a^2 -> a*a
654                 GINAC_ASSERT(e.op(1).is_equal(_ex2));
655                 v.push_back(e.op(0));
656                 v.push_back(e.op(0));
657         } else {
658                 for (size_t i=0; i<e.nops(); i++) {
659                         ex f = e.op(i);
660                         if (is_exactly_a<power>(f) && f.op(1).is_equal(_ex2)) {
661                                 v.push_back(f.op(0));
662                     v.push_back(f.op(0));
663                         } else if (is_exactly_a<ncmul>(f)) {
664                                 // Noncommutative factor found, split it as well
665                                 non_commutative = true; // everything becomes noncommutative, ncmul will sort out the commutative factors later
666                                 for (size_t j=0; j<f.nops(); j++)
667                                         v.push_back(f.op(j));
668                         } else
669                                 v.push_back(f);
670                 }
671         }
672
673         // Perform contractions
674         bool something_changed = false;
675         GINAC_ASSERT(v.size() > 1);
676         exvector::iterator it1, itend = v.end(), next_to_last = itend - 1;
677         for (it1 = v.begin(); it1 != next_to_last; it1++) {
678
679 try_again:
680                 if (!is_a<indexed>(*it1))
681                         continue;
682
683                 bool first_noncommutative = (it1->return_type() != return_types::commutative);
684
685                 // Indexed factor found, get free indices and look for contraction
686                 // candidates
687                 exvector free1, dummy1;
688                 find_free_and_dummy(ex_to<indexed>(*it1).seq.begin() + 1, ex_to<indexed>(*it1).seq.end(), free1, dummy1);
689
690                 exvector::iterator it2;
691                 for (it2 = it1 + 1; it2 != itend; it2++) {
692
693                         if (!is_a<indexed>(*it2))
694                                 continue;
695
696                         bool second_noncommutative = (it2->return_type() != return_types::commutative);
697
698                         // Find free indices of second factor and merge them with free
699                         // indices of first factor
700                         exvector un;
701                         find_free_and_dummy(ex_to<indexed>(*it2).seq.begin() + 1, ex_to<indexed>(*it2).seq.end(), un, dummy1);
702                         un.insert(un.end(), free1.begin(), free1.end());
703
704                         // Check whether the two factors share dummy indices
705                         exvector free, dummy;
706                         find_free_and_dummy(un, free, dummy);
707                         size_t num_dummies = dummy.size();
708                         if (num_dummies == 0)
709                                 continue;
710
711                         // At least one dummy index, is it a defined scalar product?
712                         bool contracted = false;
713                         if (free.empty()) {
714
715                                 // Find minimal dimension of all indices of both factors
716                                 exvector::const_iterator dit = ex_to<indexed>(*it1).seq.begin() + 1, ditend = ex_to<indexed>(*it1).seq.end();
717                                 ex dim = ex_to<idx>(*dit).get_dim();
718                                 ++dit;
719                                 for (; dit != ditend; ++dit) {
720                                         dim = minimal_dim(dim, ex_to<idx>(*dit).get_dim());
721                                 }
722                                 dit = ex_to<indexed>(*it2).seq.begin() + 1;
723                                 ditend = ex_to<indexed>(*it2).seq.end();
724                                 for (; dit != ditend; ++dit) {
725                                         dim = minimal_dim(dim, ex_to<idx>(*dit).get_dim());
726                                 }
727
728                                 // User-defined scalar product?
729                                 if (sp.is_defined(*it1, *it2, dim)) {
730
731                                         // Yes, substitute it
732                                         *it1 = sp.evaluate(*it1, *it2, dim);
733                                         *it2 = _ex1;
734                                         goto contraction_done;
735                                 }
736                         }
737
738                         // Try to contract the first one with the second one
739                         contracted = ex_to<basic>(it1->op(0)).contract_with(it1, it2, v);
740                         if (!contracted) {
741
742                                 // That didn't work; maybe the second object knows how to
743                                 // contract itself with the first one
744                                 contracted = ex_to<basic>(it2->op(0)).contract_with(it2, it1, v);
745                         }
746                         if (contracted) {
747 contraction_done:
748                                 if (first_noncommutative || second_noncommutative
749                                  || is_exactly_a<add>(*it1) || is_exactly_a<add>(*it2)
750                                  || is_exactly_a<mul>(*it1) || is_exactly_a<mul>(*it2)
751                                  || is_exactly_a<ncmul>(*it1) || is_exactly_a<ncmul>(*it2)) {
752
753                                         // One of the factors became a sum or product:
754                                         // re-expand expression and run again
755                                         // Non-commutative products are always re-expanded to give
756                                         // eval_ncmul() the chance to re-order and canonicalize
757                                         // the product
758                                         ex r = (non_commutative ? ex(ncmul(v, true)) : ex(mul(v)));
759                                         return simplify_indexed(r, free_indices, dummy_indices, sp);
760                                 }
761
762                                 // Both objects may have new indices now or they might
763                                 // even not be indexed objects any more, so we have to
764                                 // start over
765                                 something_changed = true;
766                                 goto try_again;
767                         }
768                 }
769         }
770
771         // Find free indices (concatenate them all and call find_free_and_dummy())
772         // and all dummy indices that appear
773         exvector un, individual_dummy_indices;
774         for (it1 = v.begin(), itend = v.end(); it1 != itend; ++it1) {
775                 exvector free_indices_of_factor;
776                 if (is_a<indexed>(*it1)) {
777                         exvector dummy_indices_of_factor;
778                         find_free_and_dummy(ex_to<indexed>(*it1).seq.begin() + 1, ex_to<indexed>(*it1).seq.end(), free_indices_of_factor, dummy_indices_of_factor);
779                         individual_dummy_indices.insert(individual_dummy_indices.end(), dummy_indices_of_factor.begin(), dummy_indices_of_factor.end());
780                 } else
781                         free_indices_of_factor = it1->get_free_indices();
782                 un.insert(un.end(), free_indices_of_factor.begin(), free_indices_of_factor.end());
783         }
784         exvector local_dummy_indices;
785         find_free_and_dummy(un, free_indices, local_dummy_indices);
786         local_dummy_indices.insert(local_dummy_indices.end(), individual_dummy_indices.begin(), individual_dummy_indices.end());
787
788         // Filter out the dummy indices with variance
789         exvector variant_dummy_indices;
790         find_variant_indices(local_dummy_indices, variant_dummy_indices);
791
792         // Any indices with variance present at all?
793         if (!variant_dummy_indices.empty()) {
794
795                 // Yes, bring the product into a canonical order that only depends on
796                 // the base expressions of indexed objects
797                 if (!non_commutative)
798                         std::sort(v.begin(), v.end(), ex_base_is_less());
799
800                 exvector moved_indices;
801
802                 // Iterate over all indexed objects in the product
803                 for (it1 = v.begin(), itend = v.end(); it1 != itend; ++it1) {
804                         if (!is_a<indexed>(*it1))
805                                 continue;
806
807                         if (reposition_dummy_indices(*it1, variant_dummy_indices, moved_indices))
808                                 something_changed = true;
809                 }
810         }
811
812         ex r;
813         if (something_changed)
814                 r = non_commutative ? ex(ncmul(v, true)) : ex(mul(v));
815         else
816                 r = e;
817
818         // The result should be symmetric with respect to exchange of dummy
819         // indices, so if the symmetrization vanishes, the whole expression is
820         // zero. This detects things like eps.i.j.k * p.j * p.k = 0.
821         if (local_dummy_indices.size() >= 2) {
822                 exvector dummy_syms;
823                 dummy_syms.reserve(local_dummy_indices.size());
824                 for (exvector::const_iterator it = local_dummy_indices.begin(); it != local_dummy_indices.end(); ++it)
825                         dummy_syms.push_back(it->op(0));
826                 if (symmetrize(r, dummy_syms).is_zero()) {
827                         free_indices.clear();
828                         return _ex0;
829                 }
830         }
831
832         // Dummy index renaming
833         r = rename_dummy_indices(r, dummy_indices, local_dummy_indices);
834
835         // Product of indexed object with a scalar?
836         if (is_exactly_a<mul>(r) && r.nops() == 2
837          && is_exactly_a<numeric>(r.op(1)) && is_a<indexed>(r.op(0)))
838                 return ex_to<basic>(r.op(0).op(0)).scalar_mul_indexed(r.op(0), ex_to<numeric>(r.op(1)));
839         else
840                 return r;
841 }
842
843 /** This structure stores the original and symmetrized versions of terms
844  *  obtained during the simplification of sums. */
845 class terminfo {
846 public:
847         terminfo(const ex & orig_, const ex & symm_) : orig(orig_), symm(symm_) {}
848
849         ex orig; /**< original term */
850         ex symm; /**< symmtrized term */
851 };
852
853 class terminfo_is_less {
854 public:
855         bool operator() (const terminfo & ti1, const terminfo & ti2) const
856         {
857                 return (ti1.symm.compare(ti2.symm) < 0);
858         }
859 };
860
861 /** This structure stores the individual symmetrized terms obtained during
862  *  the simplification of sums. */
863 class symminfo {
864 public:
865         symminfo() : num(0) {}
866
867         symminfo(const ex & symmterm_, const ex & orig_, size_t num_) : orig(orig_), num(num_)
868         {
869                 if (is_exactly_a<mul>(symmterm_) && is_exactly_a<numeric>(symmterm_.op(symmterm_.nops()-1))) {
870                         coeff = symmterm_.op(symmterm_.nops()-1);
871                         symmterm = symmterm_ / coeff;
872                 } else {
873                         coeff = 1;
874                         symmterm = symmterm_;
875                 }
876         }
877
878         ex symmterm;  /**< symmetrized term */
879         ex coeff;     /**< coefficient of symmetrized term */
880         ex orig;      /**< original term */
881         size_t num; /**< how many symmetrized terms resulted from the original term */
882 };
883
884 class symminfo_is_less_by_symmterm {
885 public:
886         bool operator() (const symminfo & si1, const symminfo & si2) const
887         {
888                 return (si1.symmterm.compare(si2.symmterm) < 0);
889         }
890 };
891
892 class symminfo_is_less_by_orig {
893 public:
894         bool operator() (const symminfo & si1, const symminfo & si2) const
895         {
896                 return (si1.orig.compare(si2.orig) < 0);
897         }
898 };
899
/** Simplify indexed expression, return list of free indices.
 *
 *  The expression is expanded first and then handled by shape:
 *   - a single indexed object gets its variant dummy indices repositioned
 *     and its dummy indices renamed;
 *   - a sum is simplified term by term. The free indices of all non-zero
 *     terms must agree. If there are at least two terms and at least two
 *     dummy indices, terms that become equal after symmetrization over the
 *     dummy indices are combined;
 *   - products (mul, ncmul, and squares of indexed objects) are handed to
 *     simplify_indexed_product();
 *   - anything else is returned unchanged (after expansion).
 *
 *  @param e Expression to simplify
 *  @param free_indices Output: free indices of the returned expression;
 *         cleared when the result is zero or nothing could be done
 *  @param dummy_indices Dummy indices shared across the recursive calls for
 *         the terms of a sum (presumably extended by rename_dummy_indices() --
 *         confirm); the sum branch symmetrizes over them
 *  @param sp Scalar products forwarded to the product simplification
 *  @return simplified expression
 *  @throws std::runtime_error if the terms of a sum have inconsistent free indices */
ex simplify_indexed(const ex & e, exvector & free_indices, exvector & dummy_indices, const scalar_products & sp)
{
        // Expand the expression
        ex e_expanded = e.expand();

        // Simplification of single indexed object: just find the free indices
        // and perform dummy index renaming/repositioning
        if (is_a<indexed>(e_expanded)) {

                // Find the dummy indices (seq[0] is the base expression, the
                // indices start at seq[1])
                const indexed &i = ex_to<indexed>(e_expanded);
                exvector local_dummy_indices;
                find_free_and_dummy(i.seq.begin() + 1, i.seq.end(), free_indices, local_dummy_indices);

                // Filter out the dummy indices with variance
                exvector variant_dummy_indices;
                find_variant_indices(local_dummy_indices, variant_dummy_indices);

                // Any indices with variance present at all?
                if (!variant_dummy_indices.empty()) {

                        // Yes, reposition them (e_expanded is modified in place)
                        exvector moved_indices;
                        reposition_dummy_indices(e_expanded, variant_dummy_indices, moved_indices);
                }

                // Rename the dummy indices
                return rename_dummy_indices(e_expanded, dummy_indices, local_dummy_indices);
        }

        // Simplification of sum = sum of simplifications, check consistency of
        // free indices in each term
        if (is_exactly_a<add>(e_expanded)) {
                bool first = true;
                // NOTE(review): relies on a default-constructed ex being zero
                // when every term vanishes (see the is_zero() check below)
                ex sum;
                free_indices.clear();

                for (size_t i=0; i<e_expanded.nops(); i++) {
                        exvector free_indices_of_term;
                        ex term = simplify_indexed(e_expanded.op(i), free_indices_of_term, dummy_indices, sp);

                        // Vanishing terms are skipped so they take no part in
                        // the index consistency check
                        if (!term.is_zero()) {
                                if (first) {
                                        // The first non-zero term fixes the free
                                        // indices all other terms must match
                                        free_indices = free_indices_of_term;
                                        sum = term;
                                        first = false;
                                } else {
                                        if (!indices_consistent(free_indices, free_indices_of_term)) {
                                                std::ostringstream s;
                                                s << "simplify_indexed: inconsistent indices in sum: ";
                                                s << exprseq(free_indices) << " vs. " << exprseq(free_indices_of_term);
                                                throw (std::runtime_error(s.str()));
                                        }
                                        // Two indexed objects: let the base object
                                        // decide how to add them
                                        if (is_a<indexed>(sum) && is_a<indexed>(term))
                                                sum = ex_to<basic>(sum.op(0)).add_indexed(sum, term);
                                        else
                                                sum += term;
                                }
                        }
                }

                // If the sum turns out to be zero, we are finished
                if (sum.is_zero()) {
                        free_indices.clear();
                        return sum;
                }

                // More than one term and more than one dummy index?
                size_t num_terms_orig = (is_exactly_a<add>(sum) ? sum.nops() : 1);
                if (num_terms_orig < 2 || dummy_indices.size() < 2)
                        return sum;

                // Yes, construct vector of all dummy index symbols
                exvector dummy_syms;
                dummy_syms.reserve(dummy_indices.size());
                for (exvector::const_iterator it = dummy_indices.begin(); it != dummy_indices.end(); ++it)
                        dummy_syms.push_back(it->op(0));

                // Chop the sum into terms and symmetrize each one over the dummy
                // indices; terms whose symmetrization vanishes are dropped
                std::vector<terminfo> terms;
                for (size_t i=0; i<sum.nops(); i++) {
                        const ex & term = sum.op(i);
                        ex term_symm = symmetrize(term, dummy_syms);
                        if (term_symm.is_zero())
                                continue;
                        terms.push_back(terminfo(term, term_symm));
                }

                // Sort by symmetrized terms so that equal ones become adjacent
                std::sort(terms.begin(), terms.end(), terminfo_is_less());

                // Combine equal symmetrized terms: a run of num equal entries is
                // replaced by num times its first member
                std::vector<terminfo> terms_pass2;
                for (std::vector<terminfo>::const_iterator i=terms.begin(); i!=terms.end(); ) {
                        size_t num = 1;
                        std::vector<terminfo>::const_iterator j = i + 1;
                        while (j != terms.end() && j->symm == i->symm) {
                                num++;
                                j++;
                        }
                        terms_pass2.push_back(terminfo(i->orig * num, i->symm * num));
                        i = j;
                }

                // If there is only one term left, we are finished
                if (terms_pass2.size() == 1)
                        return terms_pass2[0].orig;

                // Chop the symmetrized terms into subterms; each subterm
                // remembers its original term and how many subterms that
                // original term produced
                std::vector<symminfo> sy;
                for (std::vector<terminfo>::const_iterator i=terms_pass2.begin(); i!=terms_pass2.end(); ++i) {
                        if (is_exactly_a<add>(i->symm)) {
                                size_t num = i->symm.nops();
                                for (size_t j=0; j<num; j++)
                                        sy.push_back(symminfo(i->symm.op(j), i->orig, num));
                        } else
                                sy.push_back(symminfo(i->symm, i->orig, 1));
                }

                // Sort by symmetrized subterms
                std::sort(sy.begin(), sy.end(), symminfo_is_less_by_symmterm());

                // Combine equal symmetrized subterms (possibly coming from
                // different original terms) by adding their coefficients
                std::vector<symminfo> sy_pass2;
                exvector result;
                for (std::vector<symminfo>::const_iterator i=sy.begin(); i!=sy.end(); ) {

                        // Combine equal terms
                        std::vector<symminfo>::const_iterator j = i + 1;
                        if (j != sy.end() && j->symmterm == i->symmterm) {

                                // More than one term, collect the coefficients
                                ex coeff = i->coeff;
                                while (j != sy.end() && j->symmterm == i->symmterm) {
                                        coeff += j->coeff;
                                        j++;
                                }

                                // Add combined term to result
                                if (!coeff.is_zero())
                                        result.push_back(coeff * i->symmterm);

                        } else {

                                // Single term, store for second pass
                                sy_pass2.push_back(*i);
                        }

                        i = j;
                }

                // Were there any remaining terms that didn't get combined?
                if (sy_pass2.size() > 0) {

                        // Yes, sort by their original terms so that subterms of
                        // the same original term become adjacent
                        std::sort(sy_pass2.begin(), sy_pass2.end(), symminfo_is_less_by_orig());

                        for (std::vector<symminfo>::const_iterator i=sy_pass2.begin(); i!=sy_pass2.end(); ) {

                                // How many symmetrized terms of this original term are left?
                                size_t num = 1;
                                std::vector<symminfo>::const_iterator j = i + 1;
                                while (j != sy_pass2.end() && j->orig == i->orig) {
                                        num++;
                                        j++;
                                }

                                if (num == i->num) {

                                        // All terms left, then add the original term to the result
                                        // (keeps the unsymmetrized form where nothing was merged)
                                        result.push_back(i->orig);

                                } else {

                                        // Some terms were combined with others, add up the remaining symmetrized terms
                                        std::vector<symminfo>::const_iterator k;
                                        for (k=i; k!=j; k++)
                                                result.push_back(k->coeff * k->symmterm);
                                }

                                i = j;
                        }
                }

                // Add all resulting terms (the dynallocated flag presumably
                // hands the heap object over to ex's reference counting)
                ex sum_symm = (new add(result))->setflag(status_flags::dynallocated);
                if (sum_symm.is_zero())
                        free_indices.clear();
                return sum_symm;
        }

        // Simplification of products; the square of an indexed object is
        // treated as a product as well
        if (is_exactly_a<mul>(e_expanded)
         || is_exactly_a<ncmul>(e_expanded)
         || (is_exactly_a<power>(e_expanded) && is_a<indexed>(e_expanded.op(0)) && e_expanded.op(1).is_equal(_ex2)))
                return simplify_indexed_product(e_expanded, free_indices, dummy_indices, sp);

        // Cannot do anything
        free_indices.clear();
        return e_expanded;
}
1102
1103 /** Simplify/canonicalize expression containing indexed objects. This
1104  *  performs contraction of dummy indices where possible and checks whether
1105  *  the free indices in sums are consistent.
1106  *
1107  *  @return simplified expression */
1108 ex ex::simplify_indexed() const
1109 {
1110         exvector free_indices, dummy_indices;
1111         scalar_products sp;
1112         return GiNaC::simplify_indexed(*this, free_indices, dummy_indices, sp);
1113 }
1114
1115 /** Simplify/canonicalize expression containing indexed objects. This
1116  *  performs contraction of dummy indices where possible, checks whether
1117  *  the free indices in sums are consistent, and automatically replaces
1118  *  scalar products by known values if desired.
1119  *
1120  *  @param sp Scalar products to be replaced automatically
1121  *  @return simplified expression */
1122 ex ex::simplify_indexed(const scalar_products & sp) const
1123 {
1124         exvector free_indices, dummy_indices;
1125         return GiNaC::simplify_indexed(*this, free_indices, dummy_indices, sp);
1126 }
1127
1128 /** Symmetrize expression over its free indices. */
1129 ex ex::symmetrize() const
1130 {
1131         return GiNaC::symmetrize(*this, get_free_indices());
1132 }
1133
1134 /** Antisymmetrize expression over its free indices. */
1135 ex ex::antisymmetrize() const
1136 {
1137         return GiNaC::antisymmetrize(*this, get_free_indices());
1138 }
1139
1140 /** Symmetrize expression by cyclic permutation over its free indices. */
1141 ex ex::symmetrize_cyclic() const
1142 {
1143         return GiNaC::symmetrize_cyclic(*this, get_free_indices());
1144 }
1145
1146 //////////
1147 // helper classes
1148 //////////
1149
1150 spmapkey::spmapkey(const ex & v1_, const ex & v2_, const ex & dim_) : dim(dim_)
1151 {
1152         // If indexed, extract base objects
1153         ex s1 = is_a<indexed>(v1_) ? v1_.op(0) : v1_;
1154         ex s2 = is_a<indexed>(v2_) ? v2_.op(0) : v2_;
1155
1156         // Enforce canonical order in pair
1157         if (s1.compare(s2) > 0) {
1158                 v1 = s2;
1159                 v2 = s1;
1160         } else {
1161                 v1 = s1;
1162                 v2 = s2;
1163         }
1164 }
1165
1166 bool spmapkey::operator==(const spmapkey &other) const
1167 {
1168         if (!v1.is_equal(other.v1))
1169                 return false;
1170         if (!v2.is_equal(other.v2))
1171                 return false;
1172         if (is_a<wildcard>(dim) || is_a<wildcard>(other.dim))
1173                 return true;
1174         else
1175                 return dim.is_equal(other.dim);
1176 }
1177
1178 bool spmapkey::operator<(const spmapkey &other) const
1179 {
1180         int cmp = v1.compare(other.v1);
1181         if (cmp)
1182                 return cmp < 0;
1183         cmp = v2.compare(other.v2);
1184         if (cmp)
1185                 return cmp < 0;
1186
1187         // Objects are equal, now check dimensions
1188         if (is_a<wildcard>(dim) || is_a<wildcard>(other.dim))
1189                 return false;
1190         else
1191                 return dim.compare(other.dim) < 0;
1192 }
1193
1194 void spmapkey::debugprint() const
1195 {
1196         std::cerr << "(" << v1 << "," << v2 << "," << dim << ")";
1197 }
1198
1199 void scalar_products::add(const ex & v1, const ex & v2, const ex & sp)
1200 {
1201         spm[spmapkey(v1, v2)] = sp;
1202 }
1203
1204 void scalar_products::add(const ex & v1, const ex & v2, const ex & dim, const ex & sp)
1205 {
1206         spm[spmapkey(v1, v2, dim)] = sp;
1207 }
1208
1209 void scalar_products::add_vectors(const lst & l, const ex & dim)
1210 {
1211         // Add all possible pairs of products
1212         for (lst::const_iterator it1 = l.begin(); it1 != l.end(); ++it1)
1213                 for (lst::const_iterator it2 = l.begin(); it2 != l.end(); ++it2)
1214                         add(*it1, *it2, *it1 * *it2);
1215 }
1216
1217 void scalar_products::clear()
1218 {
1219         spm.clear();
1220 }
1221
1222 /** Check whether scalar product pair is defined. */
1223 bool scalar_products::is_defined(const ex & v1, const ex & v2, const ex & dim) const
1224 {
1225         return spm.find(spmapkey(v1, v2, dim)) != spm.end();
1226 }
1227
1228 /** Return value of defined scalar product pair. */
1229 ex scalar_products::evaluate(const ex & v1, const ex & v2, const ex & dim) const
1230 {
1231         return spm.find(spmapkey(v1, v2, dim))->second;
1232 }
1233
1234 void scalar_products::debugprint() const
1235 {
1236         std::cerr << "map size=" << spm.size() << std::endl;
1237         spmap::const_iterator i = spm.begin(), end = spm.end();
1238         while (i != end) {
1239                 const spmapkey & k = i->first;
1240                 std::cerr << "item key=";
1241                 k.debugprint();
1242                 std::cerr << ", value=" << i->second << std::endl;
1243                 ++i;
1244         }
1245 }
1246
1247 } // namespace GiNaC