]> www.ginac.de Git - ginac.git/blob - ginac/symmetry.cpp
Happy new year!
[ginac.git] / ginac / symmetry.cpp
1 /** @file symmetry.cpp
2  *
3  *  Implementation of GiNaC's symmetry definitions. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2008 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
21  */
22
#include <algorithm>
#include <functional>
#include <iostream>
#include <iterator>
#include <limits>
#include <stdexcept>
#include <vector>

#include "symmetry.h"
#include "lst.h"
#include "numeric.h" // for factorial()
#include "operators.h"
#include "archive.h"
#include "utils.h"
34
35 namespace GiNaC {
36
// Class registration for symmetry (with basic named as its parent class).
// The option list attaches two print handlers: do_print for the generic
// print_context and do_print_tree for print_tree.
GINAC_IMPLEMENT_REGISTERED_CLASS_OPT(symmetry, basic,
  print_func<print_context>(&symmetry::do_print).
  print_func<print_tree>(&symmetry::do_print_tree))
40
41 /*
42    Some notes about the structure of a symmetry tree:
43     - The leaf nodes of the tree are of type "none", have one index, and no
44       children (of course). They are constructed by the symmetry(unsigned)
45       constructor.
46     - Leaf nodes are the only nodes that only have one index.
47     - Container nodes contain two or more children. The "indices" set member
48       is the set union of the index sets of all children, and the "children"
49       vector stores the children themselves.
50     - The index set of each child of a "symm", "anti" or "cycl" node must
51       have the same size. It follows that the children of such a node are
52       either all leaf nodes, or all container nodes with two or more indices.
53 */
54
55 //////////
56 // default constructor
57 //////////
58
59 symmetry::symmetry() : inherited(TINFO_symmetry), type(none)
60 {
61         setflag(status_flags::evaluated | status_flags::expanded);
62 }
63
64 //////////
65 // other constructors
66 //////////
67
68 symmetry::symmetry(unsigned i) : inherited(TINFO_symmetry), type(none)
69 {
70         indices.insert(i);
71         setflag(status_flags::evaluated | status_flags::expanded);
72 }
73
74 symmetry::symmetry(symmetry_type t, const symmetry &c1, const symmetry &c2) : inherited(TINFO_symmetry), type(t)
75 {
76         add(c1); add(c2);
77         setflag(status_flags::evaluated | status_flags::expanded);
78 }
79
80 //////////
81 // archiving
82 //////////
83
84 /** Construct object from archive_node. */
85 symmetry::symmetry(const archive_node &n, lst &sym_lst) : inherited(n, sym_lst)
86 {
87         unsigned t;
88         if (!(n.find_unsigned("type", t)))
89                 throw (std::runtime_error("unknown symmetry type in archive"));
90         type = (symmetry_type)t;
91
92         unsigned i = 0;
93         while (true) {
94                 ex e;
95                 if (n.find_ex("child", e, sym_lst, i))
96                         add(ex_to<symmetry>(e));
97                 else
98                         break;
99                 i++;
100         }
101
102         if (i == 0) {
103                 while (true) {
104                         unsigned u;
105                         if (n.find_unsigned("index", u, i))
106                                 indices.insert(u);
107                         else
108                                 break;
109                         i++;
110                 }
111         }
112 }
113
114 /** Archive the object. */
115 void symmetry::archive(archive_node &n) const
116 {
117         inherited::archive(n);
118
119         n.add_unsigned("type", type);
120
121         if (children.empty()) {
122                 std::set<unsigned>::const_iterator i = indices.begin(), iend = indices.end();
123                 while (i != iend) {
124                         n.add_unsigned("index", *i);
125                         i++;
126                 }
127         } else {
128                 exvector::const_iterator i = children.begin(), iend = children.end();
129                 while (i != iend) {
130                         n.add_ex("child", *i);
131                         i++;
132                 }
133         }
134 }
135
// Unarchiving entry point for symmetry, generated by the project macro
// (presumably built on the archive_node constructor of this class --
// confirm against the macro definition).
DEFAULT_UNARCHIVE(symmetry)
137
138 //////////
139 // functions overriding virtual functions from base classes
140 //////////
141
142 int symmetry::compare_same_type(const basic & other) const
143 {
144         GINAC_ASSERT(is_a<symmetry>(other));
145
146         // All symmetry trees are equal. They are not supposed to appear in
147         // ordinary expressions anyway...
148         return 0;
149 }
150
151 void symmetry::do_print(const print_context & c, unsigned level) const
152 {
153         if (children.empty()) {
154                 if (indices.size() > 0)
155                         c.s << *(indices.begin());
156                 else
157                         c.s << "none";
158         } else {
159                 switch (type) {
160                         case none: c.s << '!'; break;
161                         case symmetric: c.s << '+'; break;
162                         case antisymmetric: c.s << '-'; break;
163                         case cyclic: c.s << '@'; break;
164                         default: c.s << '?'; break;
165                 }
166                 c.s << '(';
167                 size_t num = children.size();
168                 for (size_t i=0; i<num; i++) {
169                         children[i].print(c);
170                         if (i != num - 1)
171                                 c.s << ",";
172                 }
173                 c.s << ')';
174         }
175 }
176
177 void symmetry::do_print_tree(const print_tree & c, unsigned level) const
178 {
179         c.s << std::string(level, ' ') << class_name() << " @" << this
180             << std::hex << ", hash=0x" << hashvalue << ", flags=0x" << flags << std::dec
181             << ", type=";
182
183         switch (type) {
184                 case none: c.s << "none"; break;
185                 case symmetric: c.s << "symm"; break;
186                 case antisymmetric: c.s << "anti"; break;
187                 case cyclic: c.s << "cycl"; break;
188                 default: c.s << "<unknown>"; break;
189         }
190
191         c.s << ", indices=(";
192         if (!indices.empty()) {
193                 std::set<unsigned>::const_iterator i = indices.begin(), end = indices.end();
194                 --end;
195                 while (i != end)
196                         c.s << *i++ << ",";
197                 c.s << *i;
198         }
199         c.s << ")\n";
200
201         exvector::const_iterator i = children.begin(), end = children.end();
202         while (i != end) {
203                 i->print(c, level + c.delta_indent);
204                 ++i;
205         }
206 }
207
208 //////////
209 // non-virtual functions in this class
210 //////////
211
212 symmetry &symmetry::add(const symmetry &c)
213 {
214         // All children must have the same number of indices
215         if (type != none && !children.empty()) {
216                 GINAC_ASSERT(is_exactly_a<symmetry>(children[0]));
217                 if (ex_to<symmetry>(children[0]).indices.size() != c.indices.size())
218                         throw (std::logic_error("symmetry:add(): children must have same number of indices"));
219         }
220
221         // Compute union of indices and check whether the two sets are disjoint
222         std::set<unsigned> un;
223         set_union(indices.begin(), indices.end(), c.indices.begin(), c.indices.end(), inserter(un, un.begin()));
224         if (un.size() != indices.size() + c.indices.size())
225                 throw (std::logic_error("symmetry::add(): the same index appears in more than one child"));
226
227         // Set new index set
228         indices.swap(un);
229
230         // Add child node
231         children.push_back(c);
232         return *this;
233 }
234
235 void symmetry::validate(unsigned n)
236 {
237         if (indices.upper_bound(n - 1) != indices.end())
238                 throw (std::range_error("symmetry::verify(): index values are out of range"));
239         if (type != none && indices.empty()) {
240                 for (unsigned i=0; i<n; i++)
241                         add(i);
242         }
243 }
244
245 //////////
246 // global functions
247 //////////
248
249 static const symmetry & index0()
250 {
251         static ex s = (new symmetry(0))->setflag(status_flags::dynallocated);
252         return ex_to<symmetry>(s);
253 }
254
255 static const symmetry & index1()
256 {
257         static ex s = (new symmetry(1))->setflag(status_flags::dynallocated);
258         return ex_to<symmetry>(s);
259 }
260
261 static const symmetry & index2()
262 {
263         static ex s = (new symmetry(2))->setflag(status_flags::dynallocated);
264         return ex_to<symmetry>(s);
265 }
266
267 static const symmetry & index3()
268 {
269         static ex s = (new symmetry(3))->setflag(status_flags::dynallocated);
270         return ex_to<symmetry>(s);
271 }
272
273 const symmetry & not_symmetric()
274 {
275         static ex s = (new symmetry)->setflag(status_flags::dynallocated);
276         return ex_to<symmetry>(s);
277 }
278
279 const symmetry & symmetric2()
280 {
281         static ex s = (new symmetry(symmetry::symmetric, index0(), index1()))->setflag(status_flags::dynallocated);
282         return ex_to<symmetry>(s);
283 }
284
285 const symmetry & symmetric3()
286 {
287         static ex s = (new symmetry(symmetry::symmetric, index0(), index1()))->add(index2()).setflag(status_flags::dynallocated);
288         return ex_to<symmetry>(s);
289 }
290
291 const symmetry & symmetric4()
292 {
293         static ex s = (new symmetry(symmetry::symmetric, index0(), index1()))->add(index2()).add(index3()).setflag(status_flags::dynallocated);
294         return ex_to<symmetry>(s);
295 }
296
297 const symmetry & antisymmetric2()
298 {
299         static ex s = (new symmetry(symmetry::antisymmetric, index0(), index1()))->setflag(status_flags::dynallocated);
300         return ex_to<symmetry>(s);
301 }
302
303 const symmetry & antisymmetric3()
304 {
305         static ex s = (new symmetry(symmetry::antisymmetric, index0(), index1()))->add(index2()).setflag(status_flags::dynallocated);
306         return ex_to<symmetry>(s);
307 }
308
309 const symmetry & antisymmetric4()
310 {
311         static ex s = (new symmetry(symmetry::antisymmetric, index0(), index1()))->add(index2()).add(index3()).setflag(status_flags::dynallocated);
312         return ex_to<symmetry>(s);
313 }
314
315 class sy_is_less : public std::binary_function<ex, ex, bool> {
316         exvector::iterator v;
317
318 public:
319         sy_is_less(exvector::iterator v_) : v(v_) {}
320
321         bool operator() (const ex &lh, const ex &rh) const
322         {
323                 GINAC_ASSERT(is_exactly_a<symmetry>(lh));
324                 GINAC_ASSERT(is_exactly_a<symmetry>(rh));
325                 GINAC_ASSERT(ex_to<symmetry>(lh).indices.size() == ex_to<symmetry>(rh).indices.size());
326                 std::set<unsigned>::const_iterator ait = ex_to<symmetry>(lh).indices.begin(), aitend = ex_to<symmetry>(lh).indices.end(), bit = ex_to<symmetry>(rh).indices.begin();
327                 while (ait != aitend) {
328                         int cmpval = v[*ait].compare(v[*bit]);
329                         if (cmpval < 0)
330                                 return true;
331                         else if (cmpval > 0)
332                                 return false;
333                         ++ait; ++bit;
334                 }
335                 return false;
336         }
337 };
338
339 class sy_swap : public std::binary_function<ex, ex, void> {
340         exvector::iterator v;
341
342 public:
343         bool &swapped;
344
345         sy_swap(exvector::iterator v_, bool &s) : v(v_), swapped(s) {}
346
347         void operator() (const ex &lh, const ex &rh)
348         {
349                 GINAC_ASSERT(is_exactly_a<symmetry>(lh));
350                 GINAC_ASSERT(is_exactly_a<symmetry>(rh));
351                 GINAC_ASSERT(ex_to<symmetry>(lh).indices.size() == ex_to<symmetry>(rh).indices.size());
352                 std::set<unsigned>::const_iterator ait = ex_to<symmetry>(lh).indices.begin(), aitend = ex_to<symmetry>(lh).indices.end(), bit = ex_to<symmetry>(rh).indices.begin();
353                 while (ait != aitend) {
354                         v[*ait].swap(v[*bit]);
355                         ++ait; ++bit;
356                 }
357                 swapped = true;
358         }
359 };
360
361 int canonicalize(exvector::iterator v, const symmetry &symm)
362 {
363         // Less than two elements? Then do nothing
364         if (symm.indices.size() < 2)
365                 return std::numeric_limits<int>::max();
366
367         // Canonicalize children first
368         bool something_changed = false;
369         int sign = 1;
370         exvector::const_iterator first = symm.children.begin(), last = symm.children.end();
371         while (first != last) {
372                 GINAC_ASSERT(is_exactly_a<symmetry>(*first));
373                 int child_sign = canonicalize(v, ex_to<symmetry>(*first));
374                 if (child_sign == 0)
375                         return 0;
376                 if (child_sign != std::numeric_limits<int>::max()) {
377                         something_changed = true;
378                         sign *= child_sign;
379                 }
380                 first++;
381         }
382
383         // Now reorder the children
384         first = symm.children.begin();
385         switch (symm.type) {
386                 case symmetry::symmetric:
387                         // Sort the children in ascending order
388                         shaker_sort(first, last, sy_is_less(v), sy_swap(v, something_changed));
389                         break;
390                 case symmetry::antisymmetric:
391                         // Sort the children in ascending order, keeping track of the signum
392                         sign *= permutation_sign(first, last, sy_is_less(v), sy_swap(v, something_changed));
393                         if (sign == 0)
394                                 return 0;
395                         break;
396                 case symmetry::cyclic:
397                         // Permute the smallest child to the front
398                         cyclic_permutation(first, last, min_element(first, last, sy_is_less(v)), sy_swap(v, something_changed));
399                         break;
400                 default:
401                         break;
402         }
403         return something_changed ? sign : std::numeric_limits<int>::max();
404 }
405
406
407 // Symmetrize/antisymmetrize over a vector of objects
408 static ex symm(const ex & e, exvector::const_iterator first, exvector::const_iterator last, bool asymmetric)
409 {
410         // Need at least 2 objects for this operation
411         unsigned num = last - first;
412         if (num < 2)
413                 return e;
414
415         // Transform object vector to a lst (for subs())
416         lst orig_lst(first, last);
417
418         // Create index vectors for permutation
419         unsigned *iv = new unsigned[num], *iv2;
420         for (unsigned i=0; i<num; i++)
421                 iv[i] = i;
422         iv2 = (asymmetric ? new unsigned[num] : NULL);
423
424         // Loop over all permutations (the first permutation, which is the
425         // identity, is unrolled)
426         ex sum = e;
427         while (std::next_permutation(iv, iv + num)) {
428                 lst new_lst;
429                 for (unsigned i=0; i<num; i++)
430                         new_lst.append(orig_lst.op(iv[i]));
431                 ex term = e.subs(orig_lst, new_lst, subs_options::no_pattern);
432                 if (asymmetric) {
433                         memcpy(iv2, iv, num * sizeof(unsigned));
434                         term *= permutation_sign(iv2, iv2 + num);
435                 }
436                 sum += term;
437         }
438
439         delete[] iv;
440         delete[] iv2;
441
442         return sum / factorial(numeric(num));
443 }
444
445 ex symmetrize(const ex & e, exvector::const_iterator first, exvector::const_iterator last)
446 {
447         return symm(e, first, last, false);
448 }
449
450 ex antisymmetrize(const ex & e, exvector::const_iterator first, exvector::const_iterator last)
451 {
452         return symm(e, first, last, true);
453 }
454
455 ex symmetrize_cyclic(const ex & e, exvector::const_iterator first, exvector::const_iterator last)
456 {
457         // Need at least 2 objects for this operation
458         unsigned num = last - first;
459         if (num < 2)
460                 return e;
461
462         // Transform object vector to a lst (for subs())
463         lst orig_lst(first, last);
464         lst new_lst = orig_lst;
465
466         // Loop over all cyclic permutations (the first permutation, which is
467         // the identity, is unrolled)
468         ex sum = e;
469         for (unsigned i=0; i<num-1; i++) {
470                 ex perm = new_lst.op(0);
471                 new_lst.remove_first().append(perm);
472                 sum += e.subs(orig_lst, new_lst, subs_options::no_pattern);
473         }
474         return sum / num;
475 }
476
477 /** Symmetrize expression over a list of objects (symbols, indices). */
478 ex ex::symmetrize(const lst & l) const
479 {
480         exvector v(l.begin(), l.end());
481         return symm(*this, v.begin(), v.end(), false);
482 }
483
484 /** Antisymmetrize expression over a list of objects (symbols, indices). */
485 ex ex::antisymmetrize(const lst & l) const
486 {
487         exvector v(l.begin(), l.end());
488         return symm(*this, v.begin(), v.end(), true);
489 }
490
491 /** Symmetrize expression by cyclic permutation over a list of objects
492  *  (symbols, indices). */
493 ex ex::symmetrize_cyclic(const lst & l) const
494 {
495         exvector v(l.begin(), l.end());
496         return GiNaC::symmetrize_cyclic(*this, v.begin(), v.end());
497 }
498
499 } // namespace GiNaC