Added calchash and compare_same_type for symmetries. Necessary for archiving.
[ginac.git] / ginac / symmetry.cpp
1 /** @file symmetry.cpp
2  *
3  *  Implementation of GiNaC's symmetry definitions. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2006 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
21  */
22
#include <algorithm>
#include <functional>
#include <iostream>
#include <iterator>
#include <stdexcept>
#include <vector>

#include "symmetry.h"
#include "lst.h"
#include "numeric.h" // for factorial()
#include "operators.h"
#include "archive.h"
#include "utils.h"
33
34 namespace GiNaC {
35
36 GINAC_IMPLEMENT_REGISTERED_CLASS_OPT(symmetry, basic,
37   print_func<print_context>(&symmetry::do_print).
38   print_func<print_tree>(&symmetry::do_print_tree))
39
/*
   Some notes about the structure of a symmetry tree:
    - The leaf nodes of the tree are of type "none", have one index, and no
      children (of course). They are constructed by the symmetry(unsigned)
      constructor.
    - Leaf nodes are the only nodes that have exactly one index.
    - Container nodes contain two or more children. The "indices" set member
      is the set union of the index sets of all children, and the "children"
      vector stores the children themselves.
    - The index set of each child of a "symm", "anti" or "cycl" node must
      have the same size. It follows that the children of such a node are
      either all leaf nodes, or all container nodes with two or more indices.
    - A default-constructed symmetry (as returned by not_symmetric()) is of
      type "none" and has neither indices nor children.
*/
53
54 //////////
55 // default constructor
56 //////////
57
58 symmetry::symmetry() : inherited(&symmetry::tinfo_static), type(none)
59 {
60         setflag(status_flags::evaluated | status_flags::expanded);
61 }
62
63 //////////
64 // other constructors
65 //////////
66
67 symmetry::symmetry(unsigned i) : inherited(&symmetry::tinfo_static), type(none)
68 {
69         indices.insert(i);
70         setflag(status_flags::evaluated | status_flags::expanded);
71 }
72
73 symmetry::symmetry(symmetry_type t, const symmetry &c1, const symmetry &c2) : inherited(&symmetry::tinfo_static), type(t)
74 {
75         add(c1); add(c2);
76         setflag(status_flags::evaluated | status_flags::expanded);
77 }
78
79 //////////
80 // archiving
81 //////////
82
83 /** Construct object from archive_node. */
84 symmetry::symmetry(const archive_node &n, lst &sym_lst) : inherited(n, sym_lst)
85 {
86         unsigned t;
87         if (!(n.find_unsigned("type", t)))
88                 throw (std::runtime_error("unknown symmetry type in archive"));
89         type = (symmetry_type)t;
90
91         unsigned i = 0;
92         while (true) {
93                 ex e;
94                 if (n.find_ex("child", e, sym_lst, i))
95                         add(ex_to<symmetry>(e));
96                 else
97                         break;
98                 i++;
99         }
100
101         if (i == 0) {
102                 while (true) {
103                         unsigned u;
104                         if (n.find_unsigned("index", u, i))
105                                 indices.insert(u);
106                         else
107                                 break;
108                         i++;
109                 }
110         }
111 }
112
113 /** Archive the object. */
114 void symmetry::archive(archive_node &n) const
115 {
116         inherited::archive(n);
117
118         n.add_unsigned("type", type);
119
120         if (children.empty()) {
121                 std::set<unsigned>::const_iterator i = indices.begin(), iend = indices.end();
122                 while (i != iend) {
123                         n.add_unsigned("index", *i);
124                         i++;
125                 }
126         } else {
127                 exvector::const_iterator i = children.begin(), iend = children.end();
128                 while (i != iend) {
129                         n.add_ex("child", *i);
130                         i++;
131                 }
132         }
133 }
134
135 DEFAULT_UNARCHIVE(symmetry)
136
137 //////////
138 // functions overriding virtual functions from base classes
139 //////////
140
141 int symmetry::compare_same_type(const basic & other) const
142 {
143         GINAC_ASSERT(is_a<symmetry>(other));
144
145         // For archiving purposes we need to have an ordering of symmetries.
146         const symmetry &othersymm = ex_to<symmetry>(other);
147
148         // Compare type.
149         if (type > othersymm.type)
150                 return 1;
151         if (type < othersymm.type)
152                 return -1;
153
154         // Compare the index set.
155         size_t this_size = indices.size();
156         size_t that_size = othersymm.indices.size();
157         if (this_size > that_size)
158                 return 1;
159         if (this_size < that_size)
160                 return -1;
161         typedef std::set<unsigned>::iterator set_it;
162         set_it end = indices.end();
163         for (set_it i=indices.begin(),j=othersymm.indices.begin(); i!=end; ++i,++j) {
164                 if(*i < *j)
165                         return 1;
166                 if(*i > *j)
167                         return -1;
168         }
169
170         // Compare the children.
171         if (children.size() > othersymm.children.size())
172                 return 1;
173         if (children.size() < othersymm.children.size())
174                 return -1;
175         for (size_t i=0; i<children.size(); ++i) {
176                 int cmpval = ex_to<symmetry>(children[i])
177                         .compare_same_type(ex_to<symmetry>(othersymm.children[i]));
178                 if (cmpval)
179                         return cmpval;
180         }
181
182         return 0;
183 }
184
185 unsigned symmetry::calchash() const
186 {
187         unsigned v = golden_ratio_hash((p_int)tinfo());
188
189         if (type == none) {
190                 v = rotate_left(v);
191                 v ^= *(indices.begin());
192         } else {
193                 for (exvector::const_iterator i=children.begin(); i!=children.end(); ++i)
194                 {
195                         v = rotate_left(v);
196                         v ^= i->gethash();
197                 }
198         }
199
200         if (flags & status_flags::evaluated) {
201                 setflag(status_flags::hash_calculated);
202                 hashvalue = v;
203         }
204
205         return v;
206 }
207
208 void symmetry::do_print(const print_context & c, unsigned level) const
209 {
210         if (children.empty()) {
211                 if (indices.size() > 0)
212                         c.s << *(indices.begin());
213                 else
214                         c.s << "none";
215         } else {
216                 switch (type) {
217                         case none: c.s << '!'; break;
218                         case symmetric: c.s << '+'; break;
219                         case antisymmetric: c.s << '-'; break;
220                         case cyclic: c.s << '@'; break;
221                         default: c.s << '?'; break;
222                 }
223                 c.s << '(';
224                 size_t num = children.size();
225                 for (size_t i=0; i<num; i++) {
226                         children[i].print(c);
227                         if (i != num - 1)
228                                 c.s << ",";
229                 }
230                 c.s << ')';
231         }
232 }
233
234 void symmetry::do_print_tree(const print_tree & c, unsigned level) const
235 {
236         c.s << std::string(level, ' ') << class_name() << " @" << this
237             << std::hex << ", hash=0x" << hashvalue << ", flags=0x" << flags << std::dec
238             << ", type=";
239
240         switch (type) {
241                 case none: c.s << "none"; break;
242                 case symmetric: c.s << "symm"; break;
243                 case antisymmetric: c.s << "anti"; break;
244                 case cyclic: c.s << "cycl"; break;
245                 default: c.s << "<unknown>"; break;
246         }
247
248         c.s << ", indices=(";
249         if (!indices.empty()) {
250                 std::set<unsigned>::const_iterator i = indices.begin(), end = indices.end();
251                 --end;
252                 while (i != end)
253                         c.s << *i++ << ",";
254                 c.s << *i;
255         }
256         c.s << ")\n";
257
258         exvector::const_iterator i = children.begin(), end = children.end();
259         while (i != end) {
260                 i->print(c, level + c.delta_indent);
261                 ++i;
262         }
263 }
264
265 //////////
266 // non-virtual functions in this class
267 //////////
268
269 symmetry &symmetry::add(const symmetry &c)
270 {
271         // All children must have the same number of indices
272         if (type != none && !children.empty()) {
273                 GINAC_ASSERT(is_exactly_a<symmetry>(children[0]));
274                 if (ex_to<symmetry>(children[0]).indices.size() != c.indices.size())
275                         throw (std::logic_error("symmetry:add(): children must have same number of indices"));
276         }
277
278         // Compute union of indices and check whether the two sets are disjoint
279         std::set<unsigned> un;
280         set_union(indices.begin(), indices.end(), c.indices.begin(), c.indices.end(), inserter(un, un.begin()));
281         if (un.size() != indices.size() + c.indices.size())
282                 throw (std::logic_error("symmetry::add(): the same index appears in more than one child"));
283
284         // Set new index set
285         indices.swap(un);
286
287         // Add child node
288         children.push_back(c);
289         return *this;
290 }
291
292 void symmetry::validate(unsigned n)
293 {
294         if (indices.upper_bound(n - 1) != indices.end())
295                 throw (std::range_error("symmetry::verify(): index values are out of range"));
296         if (type != none && indices.empty()) {
297                 for (unsigned i=0; i<n; i++)
298                         add(i);
299         }
300 }
301
302 //////////
303 // global functions
304 //////////
305
306 static const symmetry & index0()
307 {
308         static ex s = (new symmetry(0))->setflag(status_flags::dynallocated);
309         return ex_to<symmetry>(s);
310 }
311
312 static const symmetry & index1()
313 {
314         static ex s = (new symmetry(1))->setflag(status_flags::dynallocated);
315         return ex_to<symmetry>(s);
316 }
317
318 static const symmetry & index2()
319 {
320         static ex s = (new symmetry(2))->setflag(status_flags::dynallocated);
321         return ex_to<symmetry>(s);
322 }
323
324 static const symmetry & index3()
325 {
326         static ex s = (new symmetry(3))->setflag(status_flags::dynallocated);
327         return ex_to<symmetry>(s);
328 }
329
330 const symmetry & not_symmetric()
331 {
332         static ex s = (new symmetry)->setflag(status_flags::dynallocated);
333         return ex_to<symmetry>(s);
334 }
335
336 const symmetry & symmetric2()
337 {
338         static ex s = (new symmetry(symmetry::symmetric, index0(), index1()))->setflag(status_flags::dynallocated);
339         return ex_to<symmetry>(s);
340 }
341
342 const symmetry & symmetric3()
343 {
344         static ex s = (new symmetry(symmetry::symmetric, index0(), index1()))->add(index2()).setflag(status_flags::dynallocated);
345         return ex_to<symmetry>(s);
346 }
347
348 const symmetry & symmetric4()
349 {
350         static ex s = (new symmetry(symmetry::symmetric, index0(), index1()))->add(index2()).add(index3()).setflag(status_flags::dynallocated);
351         return ex_to<symmetry>(s);
352 }
353
354 const symmetry & antisymmetric2()
355 {
356         static ex s = (new symmetry(symmetry::antisymmetric, index0(), index1()))->setflag(status_flags::dynallocated);
357         return ex_to<symmetry>(s);
358 }
359
360 const symmetry & antisymmetric3()
361 {
362         static ex s = (new symmetry(symmetry::antisymmetric, index0(), index1()))->add(index2()).setflag(status_flags::dynallocated);
363         return ex_to<symmetry>(s);
364 }
365
366 const symmetry & antisymmetric4()
367 {
368         static ex s = (new symmetry(symmetry::antisymmetric, index0(), index1()))->add(index2()).add(index3()).setflag(status_flags::dynallocated);
369         return ex_to<symmetry>(s);
370 }
371
372 class sy_is_less : public std::binary_function<ex, ex, bool> {
373         exvector::iterator v;
374
375 public:
376         sy_is_less(exvector::iterator v_) : v(v_) {}
377
378         bool operator() (const ex &lh, const ex &rh) const
379         {
380                 GINAC_ASSERT(is_exactly_a<symmetry>(lh));
381                 GINAC_ASSERT(is_exactly_a<symmetry>(rh));
382                 GINAC_ASSERT(ex_to<symmetry>(lh).indices.size() == ex_to<symmetry>(rh).indices.size());
383                 std::set<unsigned>::const_iterator ait = ex_to<symmetry>(lh).indices.begin(), aitend = ex_to<symmetry>(lh).indices.end(), bit = ex_to<symmetry>(rh).indices.begin();
384                 while (ait != aitend) {
385                         int cmpval = v[*ait].compare(v[*bit]);
386                         if (cmpval < 0)
387                                 return true;
388                         else if (cmpval > 0)
389                                 return false;
390                         ++ait; ++bit;
391                 }
392                 return false;
393         }
394 };
395
396 class sy_swap : public std::binary_function<ex, ex, void> {
397         exvector::iterator v;
398
399 public:
400         bool &swapped;
401
402         sy_swap(exvector::iterator v_, bool &s) : v(v_), swapped(s) {}
403
404         void operator() (const ex &lh, const ex &rh)
405         {
406                 GINAC_ASSERT(is_exactly_a<symmetry>(lh));
407                 GINAC_ASSERT(is_exactly_a<symmetry>(rh));
408                 GINAC_ASSERT(ex_to<symmetry>(lh).indices.size() == ex_to<symmetry>(rh).indices.size());
409                 std::set<unsigned>::const_iterator ait = ex_to<symmetry>(lh).indices.begin(), aitend = ex_to<symmetry>(lh).indices.end(), bit = ex_to<symmetry>(rh).indices.begin();
410                 while (ait != aitend) {
411                         v[*ait].swap(v[*bit]);
412                         ++ait; ++bit;
413                 }
414                 swapped = true;
415         }
416 };
417
418 int canonicalize(exvector::iterator v, const symmetry &symm)
419 {
420         // Less than two elements? Then do nothing
421         if (symm.indices.size() < 2)
422                 return INT_MAX;
423
424         // Canonicalize children first
425         bool something_changed = false;
426         int sign = 1;
427         exvector::const_iterator first = symm.children.begin(), last = symm.children.end();
428         while (first != last) {
429                 GINAC_ASSERT(is_exactly_a<symmetry>(*first));
430                 int child_sign = canonicalize(v, ex_to<symmetry>(*first));
431                 if (child_sign == 0)
432                         return 0;
433                 if (child_sign != INT_MAX) {
434                         something_changed = true;
435                         sign *= child_sign;
436                 }
437                 first++;
438         }
439
440         // Now reorder the children
441         first = symm.children.begin();
442         switch (symm.type) {
443                 case symmetry::symmetric:
444                         // Sort the children in ascending order
445                         shaker_sort(first, last, sy_is_less(v), sy_swap(v, something_changed));
446                         break;
447                 case symmetry::antisymmetric:
448                         // Sort the children in ascending order, keeping track of the signum
449                         sign *= permutation_sign(first, last, sy_is_less(v), sy_swap(v, something_changed));
450                         if (sign == 0)
451                                 return 0;
452                         break;
453                 case symmetry::cyclic:
454                         // Permute the smallest child to the front
455                         cyclic_permutation(first, last, min_element(first, last, sy_is_less(v)), sy_swap(v, something_changed));
456                         break;
457                 default:
458                         break;
459         }
460         return something_changed ? sign : INT_MAX;
461 }
462
463
464 // Symmetrize/antisymmetrize over a vector of objects
465 static ex symm(const ex & e, exvector::const_iterator first, exvector::const_iterator last, bool asymmetric)
466 {
467         // Need at least 2 objects for this operation
468         unsigned num = last - first;
469         if (num < 2)
470                 return e;
471
472         // Transform object vector to a lst (for subs())
473         lst orig_lst(first, last);
474
475         // Create index vectors for permutation
476         unsigned *iv = new unsigned[num], *iv2;
477         for (unsigned i=0; i<num; i++)
478                 iv[i] = i;
479         iv2 = (asymmetric ? new unsigned[num] : NULL);
480
481         // Loop over all permutations (the first permutation, which is the
482         // identity, is unrolled)
483         ex sum = e;
484         while (std::next_permutation(iv, iv + num)) {
485                 lst new_lst;
486                 for (unsigned i=0; i<num; i++)
487                         new_lst.append(orig_lst.op(iv[i]));
488                 ex term = e.subs(orig_lst, new_lst, subs_options::no_pattern|subs_options::no_index_renaming);
489                 if (asymmetric) {
490                         memcpy(iv2, iv, num * sizeof(unsigned));
491                         term *= permutation_sign(iv2, iv2 + num);
492                 }
493                 sum += term;
494         }
495
496         delete[] iv;
497         delete[] iv2;
498
499         return sum / factorial(numeric(num));
500 }
501
502 ex symmetrize(const ex & e, exvector::const_iterator first, exvector::const_iterator last)
503 {
504         return symm(e, first, last, false);
505 }
506
507 ex antisymmetrize(const ex & e, exvector::const_iterator first, exvector::const_iterator last)
508 {
509         return symm(e, first, last, true);
510 }
511
512 ex symmetrize_cyclic(const ex & e, exvector::const_iterator first, exvector::const_iterator last)
513 {
514         // Need at least 2 objects for this operation
515         unsigned num = last - first;
516         if (num < 2)
517                 return e;
518
519         // Transform object vector to a lst (for subs())
520         lst orig_lst(first, last);
521         lst new_lst = orig_lst;
522
523         // Loop over all cyclic permutations (the first permutation, which is
524         // the identity, is unrolled)
525         ex sum = e;
526         for (unsigned i=0; i<num-1; i++) {
527                 ex perm = new_lst.op(0);
528                 new_lst.remove_first().append(perm);
529                 sum += e.subs(orig_lst, new_lst, subs_options::no_pattern|subs_options::no_index_renaming);
530         }
531         return sum / num;
532 }
533
534 /** Symmetrize expression over a list of objects (symbols, indices). */
535 ex ex::symmetrize(const lst & l) const
536 {
537         exvector v(l.begin(), l.end());
538         return symm(*this, v.begin(), v.end(), false);
539 }
540
541 /** Antisymmetrize expression over a list of objects (symbols, indices). */
542 ex ex::antisymmetrize(const lst & l) const
543 {
544         exvector v(l.begin(), l.end());
545         return symm(*this, v.begin(), v.end(), true);
546 }
547
548 /** Symmetrize expression by cyclic permutation over a list of objects
549  *  (symbols, indices). */
550 ex ex::symmetrize_cyclic(const lst & l) const
551 {
552         exvector v(l.begin(), l.end());
553         return GiNaC::symmetrize_cyclic(*this, v.begin(), v.end());
554 }
555
556 } // namespace GiNaC