- New tinfo mechanism
[ginac.git] / ginac / symmetry.cpp
1 /** @file symmetry.cpp
2  *
3  *  Implementation of GiNaC's symmetry definitions. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2006 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
21  */
22
#include <algorithm>
#include <climits>
#include <functional>
#include <iostream>
#include <iterator>
#include <stdexcept>
#include <vector>

#include "symmetry.h"
#include "lst.h"
#include "numeric.h" // for factorial()
#include "operators.h"
#include "archive.h"
#include "utils.h"
33
34 namespace GiNaC {
35
36 GINAC_IMPLEMENT_REGISTERED_CLASS_OPT(symmetry, basic,
37   print_func<print_context>(&symmetry::do_print).
38   print_func<print_tree>(&symmetry::do_print_tree))
39
40 /*
41    Some notes about the structure of a symmetry tree:
42     - The leaf nodes of the tree are of type "none", have one index, and no
43       children (of course). They are constructed by the symmetry(unsigned)
44       constructor.
45     - Leaf nodes are the only nodes that only have one index.
46     - Container nodes contain two or more children. The "indices" set member
47       is the set union of the index sets of all children, and the "children"
48       vector stores the children themselves.
49     - The index set of each child of a "symm", "anti" or "cycl" node must
50       have the same size. It follows that the children of such a node are
51       either all leaf nodes, or all container nodes with two or more indices.
52 */
53
54 //////////
55 // default constructor
56 //////////
57
58 symmetry::symmetry() : inherited(&symmetry::tinfo_static), type(none)
59 {
60         setflag(status_flags::evaluated | status_flags::expanded);
61 }
62
63 //////////
64 // other constructors
65 //////////
66
67 symmetry::symmetry(unsigned i) : inherited(&symmetry::tinfo_static), type(none)
68 {
69         indices.insert(i);
70         setflag(status_flags::evaluated | status_flags::expanded);
71 }
72
73 symmetry::symmetry(symmetry_type t, const symmetry &c1, const symmetry &c2) : inherited(&symmetry::tinfo_static), type(t)
74 {
75         add(c1); add(c2);
76         setflag(status_flags::evaluated | status_flags::expanded);
77 }
78
79 //////////
80 // archiving
81 //////////
82
83 /** Construct object from archive_node. */
84 symmetry::symmetry(const archive_node &n, lst &sym_lst) : inherited(n, sym_lst)
85 {
86         unsigned t;
87         if (!(n.find_unsigned("type", t)))
88                 throw (std::runtime_error("unknown symmetry type in archive"));
89         type = (symmetry_type)t;
90
91         unsigned i = 0;
92         while (true) {
93                 ex e;
94                 if (n.find_ex("child", e, sym_lst, i))
95                         add(ex_to<symmetry>(e));
96                 else
97                         break;
98                 i++;
99         }
100
101         if (i == 0) {
102                 while (true) {
103                         unsigned u;
104                         if (n.find_unsigned("index", u, i))
105                                 indices.insert(u);
106                         else
107                                 break;
108                         i++;
109                 }
110         }
111 }
112
113 /** Archive the object. */
114 void symmetry::archive(archive_node &n) const
115 {
116         inherited::archive(n);
117
118         n.add_unsigned("type", type);
119
120         if (children.empty()) {
121                 std::set<unsigned>::const_iterator i = indices.begin(), iend = indices.end();
122                 while (i != iend) {
123                         n.add_unsigned("index", *i);
124                         i++;
125                 }
126         } else {
127                 exvector::const_iterator i = children.begin(), iend = children.end();
128                 while (i != iend) {
129                         n.add_ex("child", *i);
130                         i++;
131                 }
132         }
133 }
134
135 DEFAULT_UNARCHIVE(symmetry)
136
137 //////////
138 // functions overriding virtual functions from base classes
139 //////////
140
141 int symmetry::compare_same_type(const basic & other) const
142 {
143         GINAC_ASSERT(is_a<symmetry>(other));
144
145         // All symmetry trees are equal. They are not supposed to appear in
146         // ordinary expressions anyway...
147         return 0;
148 }
149
150 void symmetry::do_print(const print_context & c, unsigned level) const
151 {
152         if (children.empty()) {
153                 if (indices.size() > 0)
154                         c.s << *(indices.begin());
155                 else
156                         c.s << "none";
157         } else {
158                 switch (type) {
159                         case none: c.s << '!'; break;
160                         case symmetric: c.s << '+'; break;
161                         case antisymmetric: c.s << '-'; break;
162                         case cyclic: c.s << '@'; break;
163                         default: c.s << '?'; break;
164                 }
165                 c.s << '(';
166                 size_t num = children.size();
167                 for (size_t i=0; i<num; i++) {
168                         children[i].print(c);
169                         if (i != num - 1)
170                                 c.s << ",";
171                 }
172                 c.s << ')';
173         }
174 }
175
176 void symmetry::do_print_tree(const print_tree & c, unsigned level) const
177 {
178         c.s << std::string(level, ' ') << class_name() << " @" << this
179             << std::hex << ", hash=0x" << hashvalue << ", flags=0x" << flags << std::dec
180             << ", type=";
181
182         switch (type) {
183                 case none: c.s << "none"; break;
184                 case symmetric: c.s << "symm"; break;
185                 case antisymmetric: c.s << "anti"; break;
186                 case cyclic: c.s << "cycl"; break;
187                 default: c.s << "<unknown>"; break;
188         }
189
190         c.s << ", indices=(";
191         if (!indices.empty()) {
192                 std::set<unsigned>::const_iterator i = indices.begin(), end = indices.end();
193                 --end;
194                 while (i != end)
195                         c.s << *i++ << ",";
196                 c.s << *i;
197         }
198         c.s << ")\n";
199
200         exvector::const_iterator i = children.begin(), end = children.end();
201         while (i != end) {
202                 i->print(c, level + c.delta_indent);
203                 ++i;
204         }
205 }
206
207 //////////
208 // non-virtual functions in this class
209 //////////
210
211 symmetry &symmetry::add(const symmetry &c)
212 {
213         // All children must have the same number of indices
214         if (type != none && !children.empty()) {
215                 GINAC_ASSERT(is_exactly_a<symmetry>(children[0]));
216                 if (ex_to<symmetry>(children[0]).indices.size() != c.indices.size())
217                         throw (std::logic_error("symmetry:add(): children must have same number of indices"));
218         }
219
220         // Compute union of indices and check whether the two sets are disjoint
221         std::set<unsigned> un;
222         set_union(indices.begin(), indices.end(), c.indices.begin(), c.indices.end(), inserter(un, un.begin()));
223         if (un.size() != indices.size() + c.indices.size())
224                 throw (std::logic_error("symmetry::add(): the same index appears in more than one child"));
225
226         // Set new index set
227         indices.swap(un);
228
229         // Add child node
230         children.push_back(c);
231         return *this;
232 }
233
234 void symmetry::validate(unsigned n)
235 {
236         if (indices.upper_bound(n - 1) != indices.end())
237                 throw (std::range_error("symmetry::verify(): index values are out of range"));
238         if (type != none && indices.empty()) {
239                 for (unsigned i=0; i<n; i++)
240                         add(i);
241         }
242 }
243
244 //////////
245 // global functions
246 //////////
247
248 static const symmetry & index0()
249 {
250         static ex s = (new symmetry(0))->setflag(status_flags::dynallocated);
251         return ex_to<symmetry>(s);
252 }
253
254 static const symmetry & index1()
255 {
256         static ex s = (new symmetry(1))->setflag(status_flags::dynallocated);
257         return ex_to<symmetry>(s);
258 }
259
260 static const symmetry & index2()
261 {
262         static ex s = (new symmetry(2))->setflag(status_flags::dynallocated);
263         return ex_to<symmetry>(s);
264 }
265
266 static const symmetry & index3()
267 {
268         static ex s = (new symmetry(3))->setflag(status_flags::dynallocated);
269         return ex_to<symmetry>(s);
270 }
271
272 const symmetry & not_symmetric()
273 {
274         static ex s = (new symmetry)->setflag(status_flags::dynallocated);
275         return ex_to<symmetry>(s);
276 }
277
278 const symmetry & symmetric2()
279 {
280         static ex s = (new symmetry(symmetry::symmetric, index0(), index1()))->setflag(status_flags::dynallocated);
281         return ex_to<symmetry>(s);
282 }
283
284 const symmetry & symmetric3()
285 {
286         static ex s = (new symmetry(symmetry::symmetric, index0(), index1()))->add(index2()).setflag(status_flags::dynallocated);
287         return ex_to<symmetry>(s);
288 }
289
290 const symmetry & symmetric4()
291 {
292         static ex s = (new symmetry(symmetry::symmetric, index0(), index1()))->add(index2()).add(index3()).setflag(status_flags::dynallocated);
293         return ex_to<symmetry>(s);
294 }
295
296 const symmetry & antisymmetric2()
297 {
298         static ex s = (new symmetry(symmetry::antisymmetric, index0(), index1()))->setflag(status_flags::dynallocated);
299         return ex_to<symmetry>(s);
300 }
301
302 const symmetry & antisymmetric3()
303 {
304         static ex s = (new symmetry(symmetry::antisymmetric, index0(), index1()))->add(index2()).setflag(status_flags::dynallocated);
305         return ex_to<symmetry>(s);
306 }
307
308 const symmetry & antisymmetric4()
309 {
310         static ex s = (new symmetry(symmetry::antisymmetric, index0(), index1()))->add(index2()).add(index3()).setflag(status_flags::dynallocated);
311         return ex_to<symmetry>(s);
312 }
313
314 class sy_is_less : public std::binary_function<ex, ex, bool> {
315         exvector::iterator v;
316
317 public:
318         sy_is_less(exvector::iterator v_) : v(v_) {}
319
320         bool operator() (const ex &lh, const ex &rh) const
321         {
322                 GINAC_ASSERT(is_exactly_a<symmetry>(lh));
323                 GINAC_ASSERT(is_exactly_a<symmetry>(rh));
324                 GINAC_ASSERT(ex_to<symmetry>(lh).indices.size() == ex_to<symmetry>(rh).indices.size());
325                 std::set<unsigned>::const_iterator ait = ex_to<symmetry>(lh).indices.begin(), aitend = ex_to<symmetry>(lh).indices.end(), bit = ex_to<symmetry>(rh).indices.begin();
326                 while (ait != aitend) {
327                         int cmpval = v[*ait].compare(v[*bit]);
328                         if (cmpval < 0)
329                                 return true;
330                         else if (cmpval > 0)
331                                 return false;
332                         ++ait; ++bit;
333                 }
334                 return false;
335         }
336 };
337
338 class sy_swap : public std::binary_function<ex, ex, void> {
339         exvector::iterator v;
340
341 public:
342         bool &swapped;
343
344         sy_swap(exvector::iterator v_, bool &s) : v(v_), swapped(s) {}
345
346         void operator() (const ex &lh, const ex &rh)
347         {
348                 GINAC_ASSERT(is_exactly_a<symmetry>(lh));
349                 GINAC_ASSERT(is_exactly_a<symmetry>(rh));
350                 GINAC_ASSERT(ex_to<symmetry>(lh).indices.size() == ex_to<symmetry>(rh).indices.size());
351                 std::set<unsigned>::const_iterator ait = ex_to<symmetry>(lh).indices.begin(), aitend = ex_to<symmetry>(lh).indices.end(), bit = ex_to<symmetry>(rh).indices.begin();
352                 while (ait != aitend) {
353                         v[*ait].swap(v[*bit]);
354                         ++ait; ++bit;
355                 }
356                 swapped = true;
357         }
358 };
359
360 int canonicalize(exvector::iterator v, const symmetry &symm)
361 {
362         // Less than two elements? Then do nothing
363         if (symm.indices.size() < 2)
364                 return INT_MAX;
365
366         // Canonicalize children first
367         bool something_changed = false;
368         int sign = 1;
369         exvector::const_iterator first = symm.children.begin(), last = symm.children.end();
370         while (first != last) {
371                 GINAC_ASSERT(is_exactly_a<symmetry>(*first));
372                 int child_sign = canonicalize(v, ex_to<symmetry>(*first));
373                 if (child_sign == 0)
374                         return 0;
375                 if (child_sign != INT_MAX) {
376                         something_changed = true;
377                         sign *= child_sign;
378                 }
379                 first++;
380         }
381
382         // Now reorder the children
383         first = symm.children.begin();
384         switch (symm.type) {
385                 case symmetry::symmetric:
386                         // Sort the children in ascending order
387                         shaker_sort(first, last, sy_is_less(v), sy_swap(v, something_changed));
388                         break;
389                 case symmetry::antisymmetric:
390                         // Sort the children in ascending order, keeping track of the signum
391                         sign *= permutation_sign(first, last, sy_is_less(v), sy_swap(v, something_changed));
392                         if (sign == 0)
393                                 return 0;
394                         break;
395                 case symmetry::cyclic:
396                         // Permute the smallest child to the front
397                         cyclic_permutation(first, last, min_element(first, last, sy_is_less(v)), sy_swap(v, something_changed));
398                         break;
399                 default:
400                         break;
401         }
402         return something_changed ? sign : INT_MAX;
403 }
404
405
406 // Symmetrize/antisymmetrize over a vector of objects
407 static ex symm(const ex & e, exvector::const_iterator first, exvector::const_iterator last, bool asymmetric)
408 {
409         // Need at least 2 objects for this operation
410         unsigned num = last - first;
411         if (num < 2)
412                 return e;
413
414         // Transform object vector to a lst (for subs())
415         lst orig_lst(first, last);
416
417         // Create index vectors for permutation
418         unsigned *iv = new unsigned[num], *iv2;
419         for (unsigned i=0; i<num; i++)
420                 iv[i] = i;
421         iv2 = (asymmetric ? new unsigned[num] : NULL);
422
423         // Loop over all permutations (the first permutation, which is the
424         // identity, is unrolled)
425         ex sum = e;
426         while (std::next_permutation(iv, iv + num)) {
427                 lst new_lst;
428                 for (unsigned i=0; i<num; i++)
429                         new_lst.append(orig_lst.op(iv[i]));
430                 ex term = e.subs(orig_lst, new_lst, subs_options::no_pattern|subs_options::no_index_renaming);
431                 if (asymmetric) {
432                         memcpy(iv2, iv, num * sizeof(unsigned));
433                         term *= permutation_sign(iv2, iv2 + num);
434                 }
435                 sum += term;
436         }
437
438         delete[] iv;
439         delete[] iv2;
440
441         return sum / factorial(numeric(num));
442 }
443
444 ex symmetrize(const ex & e, exvector::const_iterator first, exvector::const_iterator last)
445 {
446         return symm(e, first, last, false);
447 }
448
449 ex antisymmetrize(const ex & e, exvector::const_iterator first, exvector::const_iterator last)
450 {
451         return symm(e, first, last, true);
452 }
453
454 ex symmetrize_cyclic(const ex & e, exvector::const_iterator first, exvector::const_iterator last)
455 {
456         // Need at least 2 objects for this operation
457         unsigned num = last - first;
458         if (num < 2)
459                 return e;
460
461         // Transform object vector to a lst (for subs())
462         lst orig_lst(first, last);
463         lst new_lst = orig_lst;
464
465         // Loop over all cyclic permutations (the first permutation, which is
466         // the identity, is unrolled)
467         ex sum = e;
468         for (unsigned i=0; i<num-1; i++) {
469                 ex perm = new_lst.op(0);
470                 new_lst.remove_first().append(perm);
471                 sum += e.subs(orig_lst, new_lst, subs_options::no_pattern|subs_options::no_index_renaming);
472         }
473         return sum / num;
474 }
475
476 /** Symmetrize expression over a list of objects (symbols, indices). */
477 ex ex::symmetrize(const lst & l) const
478 {
479         exvector v(l.begin(), l.end());
480         return symm(*this, v.begin(), v.end(), false);
481 }
482
483 /** Antisymmetrize expression over a list of objects (symbols, indices). */
484 ex ex::antisymmetrize(const lst & l) const
485 {
486         exvector v(l.begin(), l.end());
487         return symm(*this, v.begin(), v.end(), true);
488 }
489
490 /** Symmetrize expression by cyclic permutation over a list of objects
491  *  (symbols, indices). */
492 ex ex::symmetrize_cyclic(const lst & l) const
493 {
494         exvector v(l.begin(), l.end());
495         return GiNaC::symmetrize_cyclic(*this, v.begin(), v.end());
496 }
497
498 } // namespace GiNaC