e48e269851e0846442bbbc1975928348e8d24271
[ginac.git] / ginac / symmetry.cpp
1 /** @file symmetry.cpp
2  *
3  *  Implementation of GiNaC's symmetry definitions. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2006 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
21  */
22
#include <algorithm>
#include <climits>
#include <cstring>
#include <functional>
#include <iostream>
#include <iterator>
#include <stdexcept>
#include <vector>

#include "symmetry.h"
#include "lst.h"
#include "numeric.h" // for factorial()
#include "operators.h"
#include "archive.h"
#include "utils.h"
33
34 namespace GiNaC {
35
// Class registration boilerplate for symmetry, with basic as the parent class
// (the constructors below initialize "inherited"). The chained print_func<>
// entries presumably route print_context output to do_print() and print_tree
// output to do_print_tree() — see the macro definition to confirm.
GINAC_IMPLEMENT_REGISTERED_CLASS_OPT(symmetry, basic,
	print_func<print_context>(&symmetry::do_print).
	print_func<print_tree>(&symmetry::do_print_tree))
39
/*
   Some notes about the structure of a symmetry tree:
    - The leaf nodes of the tree are of type "none", have one index, and no
      children (of course). They are constructed by the symmetry(unsigned)
      constructor.
    - Leaf nodes are the only nodes that have exactly one index. (In addition,
      the default constructor creates an empty node of type "none" with no
      indices and no children; not_symmetric() returns such a node.)
    - Container nodes contain two or more children. The "indices" set member
      is the set union of the index sets of all children, and the "children"
      vector stores the children themselves.
    - The index set of each child of a "symm", "anti" or "cycl" node must
      have the same size. It follows that the children of such a node are
      either all leaf nodes, or all container nodes with two or more indices.
*/
53
54 //////////
55 // default constructor
56 //////////
57
58 symmetry::symmetry() : inherited(&symmetry::tinfo_static), type(none)
59 {
60         setflag(status_flags::evaluated | status_flags::expanded);
61 }
62
63 //////////
64 // other constructors
65 //////////
66
67 symmetry::symmetry(unsigned i) : inherited(&symmetry::tinfo_static), type(none)
68 {
69         indices.insert(i);
70         setflag(status_flags::evaluated | status_flags::expanded);
71 }
72
73 symmetry::symmetry(symmetry_type t, const symmetry &c1, const symmetry &c2) : inherited(&symmetry::tinfo_static), type(t)
74 {
75         add(c1); add(c2);
76         setflag(status_flags::evaluated | status_flags::expanded);
77 }
78
79 //////////
80 // archiving
81 //////////
82
83 /** Construct object from archive_node. */
84 symmetry::symmetry(const archive_node &n, lst &sym_lst) : inherited(n, sym_lst)
85 {
86         unsigned t;
87         if (!(n.find_unsigned("type", t)))
88                 throw (std::runtime_error("unknown symmetry type in archive"));
89         type = (symmetry_type)t;
90
91         unsigned i = 0;
92         while (true) {
93                 ex e;
94                 if (n.find_ex("child", e, sym_lst, i))
95                         add(ex_to<symmetry>(e));
96                 else
97                         break;
98                 i++;
99         }
100
101         if (i == 0) {
102                 while (true) {
103                         unsigned u;
104                         if (n.find_unsigned("index", u, i))
105                                 indices.insert(u);
106                         else
107                                 break;
108                         i++;
109                 }
110         }
111 }
112
113 /** Archive the object. */
114 void symmetry::archive(archive_node &n) const
115 {
116         inherited::archive(n);
117
118         n.add_unsigned("type", type);
119
120         if (children.empty()) {
121                 std::set<unsigned>::const_iterator i = indices.begin(), iend = indices.end();
122                 while (i != iend) {
123                         n.add_unsigned("index", *i);
124                         i++;
125                 }
126         } else {
127                 exvector::const_iterator i = children.begin(), iend = children.end();
128                 while (i != iend) {
129                         n.add_ex("child", *i);
130                         i++;
131                 }
132         }
133 }
134
135 DEFAULT_UNARCHIVE(symmetry)
136
137 //////////
138 // functions overriding virtual functions from base classes
139 //////////
140
141 int symmetry::compare_same_type(const basic & other) const
142 {
143         GINAC_ASSERT(is_a<symmetry>(other));
144
145         // For archiving purposes we need to have an ordering of symmetries.
146         const symmetry &othersymm = ex_to<symmetry>(other);
147
148         // Compare type.
149         if (type > othersymm.type)
150                 return 1;
151         if (type < othersymm.type)
152                 return -1;
153
154         // Compare the index set.
155         size_t this_size = indices.size();
156         size_t that_size = othersymm.indices.size();
157         if (this_size > that_size)
158                 return 1;
159         if (this_size < that_size)
160                 return -1;
161         typedef std::set<unsigned>::iterator set_it;
162         set_it end = indices.end();
163         for (set_it i=indices.begin(),j=othersymm.indices.begin(); i!=end; ++i,++j) {
164                 if(*i < *j)
165                         return 1;
166                 if(*i > *j)
167                         return -1;
168         }
169
170         // Compare the children.
171         if (children.size() > othersymm.children.size())
172                 return 1;
173         if (children.size() < othersymm.children.size())
174                 return -1;
175         for (size_t i=0; i<children.size(); ++i) {
176                 int cmpval = ex_to<symmetry>(children[i])
177                         .compare_same_type(ex_to<symmetry>(othersymm.children[i]));
178                 if (cmpval)
179                         return cmpval;
180         }
181
182         return 0;
183 }
184
185 unsigned symmetry::calchash() const
186 {
187         unsigned v = golden_ratio_hash((p_int)tinfo());
188
189         if (type == none) {
190                 v = rotate_left(v);
191                 v ^= *(indices.begin());
192         } else {
193                 for (exvector::const_iterator i=children.begin(); i!=children.end(); ++i)
194                 {
195                         v = rotate_left(v);
196                         v ^= i->gethash();
197                 }
198         }
199
200         if (flags & status_flags::evaluated) {
201                 setflag(status_flags::hash_calculated);
202                 hashvalue = v;
203         }
204
205         return v;
206 }
207
208 void symmetry::do_print(const print_context & c, unsigned level) const
209 {
210         if (children.empty()) {
211                 if (indices.size() > 0)
212                         c.s << *(indices.begin());
213                 else
214                         c.s << "none";
215         } else {
216                 switch (type) {
217                         case none: c.s << '!'; break;
218                         case symmetric: c.s << '+'; break;
219                         case antisymmetric: c.s << '-'; break;
220                         case cyclic: c.s << '@'; break;
221                         default: c.s << '?'; break;
222                 }
223                 c.s << '(';
224                 size_t num = children.size();
225                 for (size_t i=0; i<num; i++) {
226                         children[i].print(c);
227                         if (i != num - 1)
228                                 c.s << ",";
229                 }
230                 c.s << ')';
231         }
232 }
233
234 void symmetry::do_print_tree(const print_tree & c, unsigned level) const
235 {
236         c.s << std::string(level, ' ') << class_name() << " @" << this
237             << std::hex << ", hash=0x" << hashvalue << ", flags=0x" << flags << std::dec
238             << ", type=";
239
240         switch (type) {
241                 case none: c.s << "none"; break;
242                 case symmetric: c.s << "symm"; break;
243                 case antisymmetric: c.s << "anti"; break;
244                 case cyclic: c.s << "cycl"; break;
245                 default: c.s << "<unknown>"; break;
246         }
247
248         c.s << ", indices=(";
249         if (!indices.empty()) {
250                 std::set<unsigned>::const_iterator i = indices.begin(), end = indices.end();
251                 --end;
252                 while (i != end)
253                         c.s << *i++ << ",";
254                 c.s << *i;
255         }
256         c.s << ")\n";
257
258         exvector::const_iterator i = children.begin(), end = children.end();
259         while (i != end) {
260                 i->print(c, level + c.delta_indent);
261                 ++i;
262         }
263 }
264
265 //////////
266 // non-virtual functions in this class
267 //////////
268
269 bool symmetry::has_cyclic() const
270 {
271         if (type == cyclic)
272                 return true;
273
274         for (exvector::const_iterator i=children.begin(); i!=children.end(); ++i)
275                 if (ex_to<symmetry>(*i).has_cyclic())
276                         return true;
277
278         return false;
279 }
280
281 symmetry &symmetry::add(const symmetry &c)
282 {
283         // All children must have the same number of indices
284         if (type != none && !children.empty()) {
285                 GINAC_ASSERT(is_exactly_a<symmetry>(children[0]));
286                 if (ex_to<symmetry>(children[0]).indices.size() != c.indices.size())
287                         throw (std::logic_error("symmetry:add(): children must have same number of indices"));
288         }
289
290         // Compute union of indices and check whether the two sets are disjoint
291         std::set<unsigned> un;
292         set_union(indices.begin(), indices.end(), c.indices.begin(), c.indices.end(), inserter(un, un.begin()));
293         if (un.size() != indices.size() + c.indices.size())
294                 throw (std::logic_error("symmetry::add(): the same index appears in more than one child"));
295
296         // Set new index set
297         indices.swap(un);
298
299         // Add child node
300         children.push_back(c);
301         return *this;
302 }
303
304 void symmetry::validate(unsigned n)
305 {
306         if (indices.upper_bound(n - 1) != indices.end())
307                 throw (std::range_error("symmetry::verify(): index values are out of range"));
308         if (type != none && indices.empty()) {
309                 for (unsigned i=0; i<n; i++)
310                         add(i);
311         }
312 }
313
314 //////////
315 // global functions
316 //////////
317
318 static const symmetry & index0()
319 {
320         static ex s = (new symmetry(0))->setflag(status_flags::dynallocated);
321         return ex_to<symmetry>(s);
322 }
323
324 static const symmetry & index1()
325 {
326         static ex s = (new symmetry(1))->setflag(status_flags::dynallocated);
327         return ex_to<symmetry>(s);
328 }
329
330 static const symmetry & index2()
331 {
332         static ex s = (new symmetry(2))->setflag(status_flags::dynallocated);
333         return ex_to<symmetry>(s);
334 }
335
336 static const symmetry & index3()
337 {
338         static ex s = (new symmetry(3))->setflag(status_flags::dynallocated);
339         return ex_to<symmetry>(s);
340 }
341
342 const symmetry & not_symmetric()
343 {
344         static ex s = (new symmetry)->setflag(status_flags::dynallocated);
345         return ex_to<symmetry>(s);
346 }
347
348 const symmetry & symmetric2()
349 {
350         static ex s = (new symmetry(symmetry::symmetric, index0(), index1()))->setflag(status_flags::dynallocated);
351         return ex_to<symmetry>(s);
352 }
353
354 const symmetry & symmetric3()
355 {
356         static ex s = (new symmetry(symmetry::symmetric, index0(), index1()))->add(index2()).setflag(status_flags::dynallocated);
357         return ex_to<symmetry>(s);
358 }
359
360 const symmetry & symmetric4()
361 {
362         static ex s = (new symmetry(symmetry::symmetric, index0(), index1()))->add(index2()).add(index3()).setflag(status_flags::dynallocated);
363         return ex_to<symmetry>(s);
364 }
365
366 const symmetry & antisymmetric2()
367 {
368         static ex s = (new symmetry(symmetry::antisymmetric, index0(), index1()))->setflag(status_flags::dynallocated);
369         return ex_to<symmetry>(s);
370 }
371
372 const symmetry & antisymmetric3()
373 {
374         static ex s = (new symmetry(symmetry::antisymmetric, index0(), index1()))->add(index2()).setflag(status_flags::dynallocated);
375         return ex_to<symmetry>(s);
376 }
377
378 const symmetry & antisymmetric4()
379 {
380         static ex s = (new symmetry(symmetry::antisymmetric, index0(), index1()))->add(index2()).add(index3()).setflag(status_flags::dynallocated);
381         return ex_to<symmetry>(s);
382 }
383
384 class sy_is_less : public std::binary_function<ex, ex, bool> {
385         exvector::iterator v;
386
387 public:
388         sy_is_less(exvector::iterator v_) : v(v_) {}
389
390         bool operator() (const ex &lh, const ex &rh) const
391         {
392                 GINAC_ASSERT(is_exactly_a<symmetry>(lh));
393                 GINAC_ASSERT(is_exactly_a<symmetry>(rh));
394                 GINAC_ASSERT(ex_to<symmetry>(lh).indices.size() == ex_to<symmetry>(rh).indices.size());
395                 std::set<unsigned>::const_iterator ait = ex_to<symmetry>(lh).indices.begin(), aitend = ex_to<symmetry>(lh).indices.end(), bit = ex_to<symmetry>(rh).indices.begin();
396                 while (ait != aitend) {
397                         int cmpval = v[*ait].compare(v[*bit]);
398                         if (cmpval < 0)
399                                 return true;
400                         else if (cmpval > 0)
401                                 return false;
402                         ++ait; ++bit;
403                 }
404                 return false;
405         }
406 };
407
408 class sy_swap : public std::binary_function<ex, ex, void> {
409         exvector::iterator v;
410
411 public:
412         bool &swapped;
413
414         sy_swap(exvector::iterator v_, bool &s) : v(v_), swapped(s) {}
415
416         void operator() (const ex &lh, const ex &rh)
417         {
418                 GINAC_ASSERT(is_exactly_a<symmetry>(lh));
419                 GINAC_ASSERT(is_exactly_a<symmetry>(rh));
420                 GINAC_ASSERT(ex_to<symmetry>(lh).indices.size() == ex_to<symmetry>(rh).indices.size());
421                 std::set<unsigned>::const_iterator ait = ex_to<symmetry>(lh).indices.begin(), aitend = ex_to<symmetry>(lh).indices.end(), bit = ex_to<symmetry>(rh).indices.begin();
422                 while (ait != aitend) {
423                         v[*ait].swap(v[*bit]);
424                         ++ait; ++bit;
425                 }
426                 swapped = true;
427         }
428 };
429
/** Bring the objects v[idx] (idx ranging over the indices of symm) into a
 *  canonical order allowed by the symmetry tree symm, by swapping entries of
 *  v in place.
 *
 *  Children are canonicalized first (recursively). Then the children of this
 *  node are reordered according to its type: sorted for "symmetric", sorted
 *  with sign tracking for "antisymmetric", rotated so that the smallest child
 *  comes first for "cyclic"; other types are left alone.
 *
 *  @param v     start of the object vector; the indices in symm are offsets
 *               relative to v
 *  @param symm  symmetry tree describing the allowed permutations
 *  @return INT_MAX if nothing was changed (including nodes with fewer than
 *          two indices); 0 if a child or the antisymmetric reordering yielded
 *          a sign of 0 (presumably: two equal objects, so the expression
 *          vanishes); otherwise the accumulated sign of all permutations */
int canonicalize(exvector::iterator v, const symmetry &symm)
{
	// Less than two elements? Then do nothing
	if (symm.indices.size() < 2)
		return INT_MAX;

	// Canonicalize children first. Their signs multiply into ours, and a
	// vanishing child makes the whole result vanish.
	bool something_changed = false;
	int sign = 1;
	exvector::const_iterator first = symm.children.begin(), last = symm.children.end();
	while (first != last) {
		GINAC_ASSERT(is_exactly_a<symmetry>(*first));
		int child_sign = canonicalize(v, ex_to<symmetry>(*first));
		if (child_sign == 0)
			return 0;
		if (child_sign != INT_MAX) {
			something_changed = true;
			sign *= child_sign;
		}
		first++;
	}

	// Now reorder the children. sy_is_less/sy_swap compare and swap the
	// objects in v addressed by each child's indices (not the children
	// themselves); sy_swap sets something_changed whenever it is called.
	first = symm.children.begin();
	switch (symm.type) {
		case symmetry::symmetric:
			// Sort the children in ascending order
			shaker_sort(first, last, sy_is_less(v), sy_swap(v, something_changed));
			break;
		case symmetry::antisymmetric:
			// Sort the children in ascending order, keeping track of the signum
			sign *= permutation_sign(first, last, sy_is_less(v), sy_swap(v, something_changed));
			if (sign == 0)
				return 0;
			break;
		case symmetry::cyclic:
			// Permute the smallest child to the front
			cyclic_permutation(first, last, min_element(first, last, sy_is_less(v)), sy_swap(v, something_changed));
			break;
		default:
			break;
	}
	return something_changed ? sign : INT_MAX;
}
474
475
476 // Symmetrize/antisymmetrize over a vector of objects
477 static ex symm(const ex & e, exvector::const_iterator first, exvector::const_iterator last, bool asymmetric)
478 {
479         // Need at least 2 objects for this operation
480         unsigned num = last - first;
481         if (num < 2)
482                 return e;
483
484         // Transform object vector to a lst (for subs())
485         lst orig_lst(first, last);
486
487         // Create index vectors for permutation
488         unsigned *iv = new unsigned[num], *iv2;
489         for (unsigned i=0; i<num; i++)
490                 iv[i] = i;
491         iv2 = (asymmetric ? new unsigned[num] : NULL);
492
493         // Loop over all permutations (the first permutation, which is the
494         // identity, is unrolled)
495         ex sum = e;
496         while (std::next_permutation(iv, iv + num)) {
497                 lst new_lst;
498                 for (unsigned i=0; i<num; i++)
499                         new_lst.append(orig_lst.op(iv[i]));
500                 ex term = e.subs(orig_lst, new_lst, subs_options::no_pattern|subs_options::no_index_renaming);
501                 if (asymmetric) {
502                         memcpy(iv2, iv, num * sizeof(unsigned));
503                         term *= permutation_sign(iv2, iv2 + num);
504                 }
505                 sum += term;
506         }
507
508         delete[] iv;
509         delete[] iv2;
510
511         return sum / factorial(numeric(num));
512 }
513
514 ex symmetrize(const ex & e, exvector::const_iterator first, exvector::const_iterator last)
515 {
516         return symm(e, first, last, false);
517 }
518
519 ex antisymmetrize(const ex & e, exvector::const_iterator first, exvector::const_iterator last)
520 {
521         return symm(e, first, last, true);
522 }
523
524 ex symmetrize_cyclic(const ex & e, exvector::const_iterator first, exvector::const_iterator last)
525 {
526         // Need at least 2 objects for this operation
527         unsigned num = last - first;
528         if (num < 2)
529                 return e;
530
531         // Transform object vector to a lst (for subs())
532         lst orig_lst(first, last);
533         lst new_lst = orig_lst;
534
535         // Loop over all cyclic permutations (the first permutation, which is
536         // the identity, is unrolled)
537         ex sum = e;
538         for (unsigned i=0; i<num-1; i++) {
539                 ex perm = new_lst.op(0);
540                 new_lst.remove_first().append(perm);
541                 sum += e.subs(orig_lst, new_lst, subs_options::no_pattern|subs_options::no_index_renaming);
542         }
543         return sum / num;
544 }
545
546 /** Symmetrize expression over a list of objects (symbols, indices). */
547 ex ex::symmetrize(const lst & l) const
548 {
549         exvector v(l.begin(), l.end());
550         return symm(*this, v.begin(), v.end(), false);
551 }
552
553 /** Antisymmetrize expression over a list of objects (symbols, indices). */
554 ex ex::antisymmetrize(const lst & l) const
555 {
556         exvector v(l.begin(), l.end());
557         return symm(*this, v.begin(), v.end(), true);
558 }
559
560 /** Symmetrize expression by cyclic permutation over a list of objects
561  *  (symbols, indices). */
562 ex ex::symmetrize_cyclic(const lst & l) const
563 {
564         exvector v(l.begin(), l.end());
565         return GiNaC::symmetrize_cyclic(*this, v.begin(), v.end());
566 }
567
568 } // namespace GiNaC