epsilon tensor contractions evaluate to metric tensors instead of deltas,
[ginac.git] / ginac / symmetry.cpp
1 /** @file symmetry.cpp
2  *
3  *  Implementation of GiNaC's symmetry definitions. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2001 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
21  */
22
#include <algorithm>
#include <climits>
#include <functional>
#include <iterator>
#include <stdexcept>
#include <vector>

#include "symmetry.h"
#include "lst.h"
#include "numeric.h" // for factorial()
#include "print.h"
#include "archive.h"
#include "utils.h"
#include "debugmsg.h"
34
35 namespace GiNaC {
36
// Register class "symmetry" (derived from "basic") with GiNaC's class
// registry. The macro is defined in a project header not shown here.
GINAC_IMPLEMENT_REGISTERED_CLASS(symmetry, basic)
38
39 /*
40    Some notes about the structure of a symmetry tree:
41     - The leaf nodes of the tree are of type "none", have one index, and no
42       children (of course). They are constructed by the symmetry(unsigned)
43       constructor.
    - Leaf nodes are the only nodes that have exactly one index.
45     - Container nodes contain two or more children. The "indices" set member
46       is the set union of the index sets of all children, and the "children"
47       vector stores the children themselves.
48     - The index set of each child of a "symm", "anti" or "cycl" node must
49       have the same size. It follows that the children of such a node are
50       either all leaf nodes, or all container nodes with two or more indices.
51 */
52
53 //////////
// default constructor, destructor, copy constructor, assignment operator and helpers
55 //////////
56
57 symmetry::symmetry() : type(none)
58 {
59         debugmsg("symmetry default constructor", LOGLEVEL_CONSTRUCT);
60         tinfo_key = TINFO_symmetry;
61 }
62
63 void symmetry::copy(const symmetry & other)
64 {
65         inherited::copy(other);
66         type = other.type;
67         indices = other.indices;
68         children = other.children;
69 }
70
// Default destruction boilerplate for class symmetry. The macro is
// defined in a project header not shown here; NOTE(review): confirm its
// exact expansion there.
DEFAULT_DESTROY(symmetry)
72
73 //////////
74 // other constructors
75 //////////
76
77 symmetry::symmetry(unsigned i) : type(none)
78 {
79         debugmsg("symmetry constructor from unsigned", LOGLEVEL_CONSTRUCT);
80         indices.insert(i);
81         tinfo_key = TINFO_symmetry;
82 }
83
84 symmetry::symmetry(symmetry_type t, const symmetry &c1, const symmetry &c2) : type(t)
85 {
86         debugmsg("symmetry constructor from symmetry_type,symmetry &,symmetry &", LOGLEVEL_CONSTRUCT);
87         add(c1); add(c2);
88         tinfo_key = TINFO_symmetry;
89 }
90
91 //////////
92 // archiving
93 //////////
94
95 /** Construct object from archive_node. */
96 symmetry::symmetry(const archive_node &n, const lst &sym_lst) : inherited(n, sym_lst)
97 {
98         debugmsg("symmetry ctor from archive_node", LOGLEVEL_CONSTRUCT);
99
100         unsigned t;
101         if (!(n.find_unsigned("type", t)))
102                 throw (std::runtime_error("unknown symmetry type in archive"));
103         type = (symmetry_type)t;
104
105         unsigned i = 0;
106         while (true) {
107                 ex e;
108                 if (n.find_ex("child", e, sym_lst, i))
109                         add(ex_to<symmetry>(e));
110                 else
111                         break;
112                 i++;
113         }
114
115         if (i == 0) {
116                 while (true) {
117                         unsigned u;
118                         if (n.find_unsigned("index", u, i))
119                                 indices.insert(u);
120                         else
121                                 break;
122                         i++;
123                 }
124         }
125 }
126
127 /** Archive the object. */
128 void symmetry::archive(archive_node &n) const
129 {
130         inherited::archive(n);
131
132         n.add_unsigned("type", type);
133
134         if (children.empty()) {
135                 std::set<unsigned>::const_iterator i = indices.begin(), iend = indices.end();
136                 while (i != iend) {
137                         n.add_unsigned("index", *i);
138                         i++;
139                 }
140         } else {
141                 exvector::const_iterator i = children.begin(), iend = children.end();
142                 while (i != iend) {
143                         n.add_ex("child", *i);
144                         i++;
145                 }
146         }
147 }
148
// Default unarchiving boilerplate for class symmetry. The macro is defined
// in a project header not shown here; presumably it forwards to the
// archive_node constructor defined in this file — confirm.
DEFAULT_UNARCHIVE(symmetry)
150
151 //////////
152 // functions overriding virtual functions from base classes
153 //////////
154
155 int symmetry::compare_same_type(const basic & other) const
156 {
157         GINAC_ASSERT(is_of_type(other, symmetry));
158
159         // All symmetry trees are equal. They are not supposed to appear in
160         // ordinary expressions anyway...
161         return 0;
162 }
163
164 void symmetry::print(const print_context & c, unsigned level = 0) const
165 {
166         debugmsg("symmetry print", LOGLEVEL_PRINT);
167
168         if (is_of_type(c, print_tree)) {
169
170                 c.s << std::string(level, ' ') << class_name()
171                     << std::hex << ", hash=0x" << hashvalue << ", flags=0x" << flags << std::dec
172                     << ", type=";
173
174                 switch (type) {
175                         case none: c.s << "none"; break;
176                         case symmetric: c.s << "symm"; break;
177                         case antisymmetric: c.s << "anti"; break;
178                         case cyclic: c.s << "cycl"; break;
179                         default: c.s << "<unknown>"; break;
180                 }
181
182                 c.s << ", indices=(";
183                 if (!indices.empty()) {
184                         std::set<unsigned>::const_iterator i = indices.begin(), end = indices.end();
185                         --end;
186                         while (i != end)
187                                 c.s << *i++ << ",";
188                         c.s << *i;
189                 }
190                 c.s << ")\n";
191
192                 unsigned delta_indent = static_cast<const print_tree &>(c).delta_indent;
193                 exvector::const_iterator i = children.begin(), end = children.end();
194                 while (i != end) {
195                         i->print(c, level + delta_indent);
196                         ++i;
197                 }
198
199         } else {
200
201                 if (children.empty()) {
202                         if (indices.size() > 0)
203                                 c.s << *(indices.begin());
204                         else
205                                 c.s << "none";
206                 } else {
207                         switch (type) {
208                                 case none: c.s << '!'; break;
209                                 case symmetric: c.s << '+'; break;
210                                 case antisymmetric: c.s << '-'; break;
211                                 case cyclic: c.s << '@'; break;
212                                 default: c.s << '?'; break;
213                         }
214                         c.s << '(';
215                         unsigned num = children.size();
216                         for (unsigned i=0; i<num; i++) {
217                                 children[i].print(c);
218                                 if (i != num - 1)
219                                         c.s << ",";
220                         }
221                         c.s << ')';
222                 }
223         }
224 }
225
226 //////////
227 // non-virtual functions in this class
228 //////////
229
230 symmetry &symmetry::add(const symmetry &c)
231 {
232         // All children must have the same number of indices
233         if (type != none && !children.empty()) {
234                 GINAC_ASSERT(is_ex_exactly_of_type(children[0], symmetry));
235                 if (ex_to<symmetry>(children[0]).indices.size() != c.indices.size())
236                         throw (std::logic_error("symmetry:add(): children must have same number of indices"));
237         }
238
239         // Compute union of indices and check whether the two sets are disjoint
240         std::set<unsigned> un;
241         set_union(indices.begin(), indices.end(), c.indices.begin(), c.indices.end(), inserter(un, un.begin()));
242         if (un.size() != indices.size() + c.indices.size())
243                 throw (std::logic_error("symmetry::add(): the same index appears in more than one child"));
244
245         // Set new index set
246         indices.swap(un);
247
248         // Add child node
249         children.push_back(c);
250         return *this;
251 }
252
253 void symmetry::validate(unsigned n)
254 {
255         if (indices.upper_bound(n - 1) != indices.end())
256                 throw (std::range_error("symmetry::verify(): index values are out of range"));
257         if (type != none && indices.empty()) {
258                 for (unsigned i=0; i<n; i++)
259                         add(i);
260         }
261 }
262
263 //////////
264 // global functions
265 //////////
266
267 class sy_is_less : public std::binary_function<ex, ex, bool> {
268         exvector::iterator v;
269
270 public:
271         sy_is_less(exvector::iterator v_) : v(v_) {}
272
273         bool operator() (const ex &lh, const ex &rh) const
274         {
275                 GINAC_ASSERT(is_ex_exactly_of_type(lh, symmetry));
276                 GINAC_ASSERT(is_ex_exactly_of_type(rh, symmetry));
277                 GINAC_ASSERT(ex_to<symmetry>(lh).indices.size() == ex_to<symmetry>(rh).indices.size());
278                 std::set<unsigned>::const_iterator ait = ex_to<symmetry>(lh).indices.begin(), aitend = ex_to<symmetry>(lh).indices.end(), bit = ex_to<symmetry>(rh).indices.begin();
279                 while (ait != aitend) {
280                         int cmpval = v[*ait].compare(v[*bit]);
281                         if (cmpval < 0)
282                                 return true;
283                         else if (cmpval > 0)
284                                 return false;
285                         ++ait; ++bit;
286                 }
287                 return false;
288         }
289 };
290
291 class sy_swap : public std::binary_function<ex, ex, void> {
292         exvector::iterator v;
293
294 public:
295         bool &swapped;
296
297         sy_swap(exvector::iterator v_, bool &s) : v(v_), swapped(s) {}
298
299         void operator() (const ex &lh, const ex &rh)
300         {
301                 GINAC_ASSERT(is_ex_exactly_of_type(lh, symmetry));
302                 GINAC_ASSERT(is_ex_exactly_of_type(rh, symmetry));
303                 GINAC_ASSERT(ex_to<symmetry>(lh).indices.size() == ex_to<symmetry>(rh).indices.size());
304                 std::set<unsigned>::const_iterator ait = ex_to<symmetry>(lh).indices.begin(), aitend = ex_to<symmetry>(lh).indices.end(), bit = ex_to<symmetry>(rh).indices.begin();
305                 while (ait != aitend) {
306                         v[*ait].swap(v[*bit]);
307                         ++ait; ++bit;
308                 }
309                 swapped = true;
310         }
311 };
312
313 int canonicalize(exvector::iterator v, const symmetry &symm)
314 {
315         // Less than two indices? Then do nothing
316         if (symm.indices.size() < 2)
317                 return INT_MAX;
318
319         // Canonicalize children first
320         bool something_changed = false;
321         int sign = 1;
322         exvector::const_iterator first = symm.children.begin(), last = symm.children.end();
323         while (first != last) {
324                 GINAC_ASSERT(is_ex_exactly_of_type(*first, symmetry));
325                 int child_sign = canonicalize(v, ex_to<symmetry>(*first));
326                 if (child_sign == 0)
327                         return 0;
328                 if (child_sign != INT_MAX) {
329                         something_changed = true;
330                         sign *= child_sign;
331                 }
332                 first++;
333         }
334
335         // Now reorder the children
336         first = symm.children.begin();
337         switch (symm.type) {
338                 case symmetry::symmetric:
339                         // Sort the children in ascending order
340                         shaker_sort(first, last, sy_is_less(v), sy_swap(v, something_changed));
341                         break;
342                 case symmetry::antisymmetric:
343                         // Sort the children in ascending order, keeping track of the signum
344                         sign *= permutation_sign(first, last, sy_is_less(v), sy_swap(v, something_changed));
345                         break;
346                 case symmetry::cyclic:
347                         // Permute the smallest child to the front
348                         cyclic_permutation(first, last, min_element(first, last, sy_is_less(v)), sy_swap(v, something_changed));
349                         break;
350                 default:
351                         break;
352         }
353         return something_changed ? sign : INT_MAX;
354 }
355
356
357 // Symmetrize/antisymmetrize over a vector of objects
358 static ex symm(const ex & e, exvector::const_iterator first, exvector::const_iterator last, bool asymmetric)
359 {
360         // Need at least 2 objects for this operation
361         unsigned num = last - first;
362         if (num < 2)
363                 return e;
364
365         // Transform object vector to a list
366         exlist iv_lst;
367         iv_lst.insert(iv_lst.begin(), first, last);
368         lst orig_lst(iv_lst, true);
369
370         // Create index vectors for permutation
371         unsigned *iv = new unsigned[num], *iv2;
372         for (unsigned i=0; i<num; i++)
373                 iv[i] = i;
374         iv2 = (asymmetric ? new unsigned[num] : NULL);
375
376         // Loop over all permutations (the first permutation, which is the
377         // identity, is unrolled)
378         ex sum = e;
379         while (std::next_permutation(iv, iv + num)) {
380                 lst new_lst;
381                 for (unsigned i=0; i<num; i++)
382                         new_lst.append(orig_lst.op(iv[i]));
383                 ex term = e.subs(orig_lst, new_lst);
384                 if (asymmetric) {
385                         memcpy(iv2, iv, num * sizeof(unsigned));
386                         term *= permutation_sign(iv2, iv2 + num);
387                 }
388                 sum += term;
389         }
390
391         delete[] iv;
392         delete[] iv2;
393
394         return sum / factorial(numeric(num));
395 }
396
397 ex symmetrize(const ex & e, exvector::const_iterator first, exvector::const_iterator last)
398 {
399         return symm(e, first, last, false);
400 }
401
402 ex antisymmetrize(const ex & e, exvector::const_iterator first, exvector::const_iterator last)
403 {
404         return symm(e, first, last, true);
405 }
406
407 ex symmetrize_cyclic(const ex & e, exvector::const_iterator first, exvector::const_iterator last)
408 {
409         // Need at least 2 objects for this operation
410         unsigned num = last - first;
411         if (num < 2)
412                 return e;
413
414         // Transform object vector to a list
415         exlist iv_lst;
416         iv_lst.insert(iv_lst.begin(), first, last);
417         lst orig_lst(iv_lst, true);
418         lst new_lst = orig_lst;
419
420         // Loop over all cyclic permutations (the first permutation, which is
421         // the identity, is unrolled)
422         ex sum = e;
423         for (unsigned i=0; i<num-1; i++) {
424                 ex perm = new_lst.op(0);
425                 new_lst.remove_first().append(perm);
426                 sum += e.subs(orig_lst, new_lst);
427         }
428         return sum / num;
429 }
430
431 /** Symmetrize expression over a list of objects (symbols, indices). */
432 ex ex::symmetrize(const lst & l) const
433 {
434         exvector v;
435         v.reserve(l.nops());
436         for (unsigned i=0; i<l.nops(); i++)
437                 v.push_back(l.op(i));
438         return symm(*this, v.begin(), v.end(), false);
439 }
440
441 /** Antisymmetrize expression over a list of objects (symbols, indices). */
442 ex ex::antisymmetrize(const lst & l) const
443 {
444         exvector v;
445         v.reserve(l.nops());
446         for (unsigned i=0; i<l.nops(); i++)
447                 v.push_back(l.op(i));
448         return symm(*this, v.begin(), v.end(), true);
449 }
450
451 /** Symmetrize expression by cyclic permutation over a list of objects
452  *  (symbols, indices). */
453 ex ex::symmetrize_cyclic(const lst & l) const
454 {
455         exvector v;
456         v.reserve(l.nops());
457         for (unsigned i=0; i<l.nops(); i++)
458                 v.push_back(l.op(i));
459         return GiNaC::symmetrize_cyclic(*this, v.begin(), v.end());
460 }
461
462 } // namespace GiNaC