725506fa6fa5a8e4af653c9c647ff33ad020834e
[ginac.git] / ginac / symmetry.cpp
1 /** @file symmetry.cpp
2  *
3  *  Implementation of GiNaC's symmetry definitions. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2009 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
21  */
22
#include "symmetry.h"
#include "lst.h"
#include "numeric.h" // for factorial()
#include "operators.h"
#include "archive.h"
#include "utils.h"
#include "hash_seed.h"

#include <algorithm>
#include <functional>
#include <iostream>
#include <iterator>
#include <limits>
#include <stdexcept>
#include <vector>
35
36 namespace GiNaC {
37
// Register symmetry as a GiNaC class derived from basic. Output to a
// print_context is handled by do_print(), output to a print_tree by
// do_print_tree().
GINAC_IMPLEMENT_REGISTERED_CLASS_OPT(symmetry, basic,
  print_func<print_context>(&symmetry::do_print).
  print_func<print_tree>(&symmetry::do_print_tree))
41
42 /*
43    Some notes about the structure of a symmetry tree:
44     - The leaf nodes of the tree are of type "none", have one index, and no
45       children (of course). They are constructed by the symmetry(unsigned)
46       constructor.
47     - Leaf nodes are the only nodes that only have one index.
48     - Container nodes contain two or more children. The "indices" set member
49       is the set union of the index sets of all children, and the "children"
50       vector stores the children themselves.
51     - The index set of each child of a "symm", "anti" or "cycl" node must
52       have the same size. It follows that the children of such a node are
53       either all leaf nodes, or all container nodes with two or more indices.
54 */
55
56 //////////
57 // default constructor
58 //////////
59
60 symmetry::symmetry() :  type(none)
61 {
62         setflag(status_flags::evaluated | status_flags::expanded);
63 }
64
65 //////////
66 // other constructors
67 //////////
68
69 symmetry::symmetry(unsigned i) :  type(none)
70 {
71         indices.insert(i);
72         setflag(status_flags::evaluated | status_flags::expanded);
73 }
74
75 symmetry::symmetry(symmetry_type t, const symmetry &c1, const symmetry &c2) :  type(t)
76 {
77         add(c1); add(c2);
78         setflag(status_flags::evaluated | status_flags::expanded);
79 }
80
81 //////////
82 // archiving
83 //////////
84
85 /** Construct object from archive_node. */
86 void symmetry::read_archive(const archive_node &n, lst &sym_lst)
87 {
88         inherited::read_archive(n, sym_lst);
89         unsigned t;
90         if (!(n.find_unsigned("type", t)))
91                 throw (std::runtime_error("unknown symmetry type in archive"));
92         type = (symmetry_type)t;
93
94         unsigned i = 0;
95         while (true) {
96                 ex e;
97                 if (n.find_ex("child", e, sym_lst, i))
98                         add(ex_to<symmetry>(e));
99                 else
100                         break;
101                 i++;
102         }
103
104         if (i == 0) {
105                 while (true) {
106                         unsigned u;
107                         if (n.find_unsigned("index", u, i))
108                                 indices.insert(u);
109                         else
110                                 break;
111                         i++;
112                 }
113         }
114 }
115 GINAC_BIND_UNARCHIVER(symmetry);
116
117 /** Archive the object. */
118 void symmetry::archive(archive_node &n) const
119 {
120         inherited::archive(n);
121
122         n.add_unsigned("type", type);
123
124         if (children.empty()) {
125                 std::set<unsigned>::const_iterator i = indices.begin(), iend = indices.end();
126                 while (i != iend) {
127                         n.add_unsigned("index", *i);
128                         i++;
129                 }
130         } else {
131                 exvector::const_iterator i = children.begin(), iend = children.end();
132                 while (i != iend) {
133                         n.add_ex("child", *i);
134                         i++;
135                 }
136         }
137 }
138
139 //////////
140 // functions overriding virtual functions from base classes
141 //////////
142
143 int symmetry::compare_same_type(const basic & other) const
144 {
145         GINAC_ASSERT(is_a<symmetry>(other));
146
147         // For archiving purposes we need to have an ordering of symmetries.
148         const symmetry &othersymm = ex_to<symmetry>(other);
149
150         // Compare type.
151         if (type > othersymm.type)
152                 return 1;
153         if (type < othersymm.type)
154                 return -1;
155
156         // Compare the index set.
157         size_t this_size = indices.size();
158         size_t that_size = othersymm.indices.size();
159         if (this_size > that_size)
160                 return 1;
161         if (this_size < that_size)
162                 return -1;
163         typedef std::set<unsigned>::iterator set_it;
164         set_it end = indices.end();
165         for (set_it i=indices.begin(),j=othersymm.indices.begin(); i!=end; ++i,++j) {
166                 if(*i < *j)
167                         return 1;
168                 if(*i > *j)
169                         return -1;
170         }
171
172         // Compare the children.
173         if (children.size() > othersymm.children.size())
174                 return 1;
175         if (children.size() < othersymm.children.size())
176                 return -1;
177         for (size_t i=0; i<children.size(); ++i) {
178                 int cmpval = ex_to<symmetry>(children[i])
179                         .compare_same_type(ex_to<symmetry>(othersymm.children[i]));
180                 if (cmpval)
181                         return cmpval;
182         }
183
184         return 0;
185 }
186
187 unsigned symmetry::calchash() const
188 {
189         unsigned v = make_hash_seed(typeid(*this));
190
191         if (type == none) {
192                 v = rotate_left(v);
193                 if (!indices.empty())
194                         v ^= *(indices.begin());
195         } else {
196                 for (exvector::const_iterator i=children.begin(); i!=children.end(); ++i)
197                 {
198                         v = rotate_left(v);
199                         v ^= i->gethash();
200                 }
201         }
202
203         if (flags & status_flags::evaluated) {
204                 setflag(status_flags::hash_calculated);
205                 hashvalue = v;
206         }
207
208         return v;
209 }
210
211 void symmetry::do_print(const print_context & c, unsigned level) const
212 {
213         if (children.empty()) {
214                 if (indices.size() > 0)
215                         c.s << *(indices.begin());
216                 else
217                         c.s << "none";
218         } else {
219                 switch (type) {
220                         case none: c.s << '!'; break;
221                         case symmetric: c.s << '+'; break;
222                         case antisymmetric: c.s << '-'; break;
223                         case cyclic: c.s << '@'; break;
224                         default: c.s << '?'; break;
225                 }
226                 c.s << '(';
227                 size_t num = children.size();
228                 for (size_t i=0; i<num; i++) {
229                         children[i].print(c);
230                         if (i != num - 1)
231                                 c.s << ",";
232                 }
233                 c.s << ')';
234         }
235 }
236
237 void symmetry::do_print_tree(const print_tree & c, unsigned level) const
238 {
239         c.s << std::string(level, ' ') << class_name() << " @" << this
240             << std::hex << ", hash=0x" << hashvalue << ", flags=0x" << flags << std::dec
241             << ", type=";
242
243         switch (type) {
244                 case none: c.s << "none"; break;
245                 case symmetric: c.s << "symm"; break;
246                 case antisymmetric: c.s << "anti"; break;
247                 case cyclic: c.s << "cycl"; break;
248                 default: c.s << "<unknown>"; break;
249         }
250
251         c.s << ", indices=(";
252         if (!indices.empty()) {
253                 std::set<unsigned>::const_iterator i = indices.begin(), end = indices.end();
254                 --end;
255                 while (i != end)
256                         c.s << *i++ << ",";
257                 c.s << *i;
258         }
259         c.s << ")\n";
260
261         exvector::const_iterator i = children.begin(), end = children.end();
262         while (i != end) {
263                 i->print(c, level + c.delta_indent);
264                 ++i;
265         }
266 }
267
268 //////////
269 // non-virtual functions in this class
270 //////////
271
272 bool symmetry::has_cyclic() const
273 {
274         if (type == cyclic)
275                 return true;
276
277         for (exvector::const_iterator i=children.begin(); i!=children.end(); ++i)
278                 if (ex_to<symmetry>(*i).has_cyclic())
279                         return true;
280
281         return false;
282 }
283
284 symmetry &symmetry::add(const symmetry &c)
285 {
286         // All children must have the same number of indices
287         if (type != none && !children.empty()) {
288                 GINAC_ASSERT(is_exactly_a<symmetry>(children[0]));
289                 if (ex_to<symmetry>(children[0]).indices.size() != c.indices.size())
290                         throw (std::logic_error("symmetry:add(): children must have same number of indices"));
291         }
292
293         // Compute union of indices and check whether the two sets are disjoint
294         std::set<unsigned> un;
295         set_union(indices.begin(), indices.end(), c.indices.begin(), c.indices.end(), inserter(un, un.begin()));
296         if (un.size() != indices.size() + c.indices.size())
297                 throw (std::logic_error("symmetry::add(): the same index appears in more than one child"));
298
299         // Set new index set
300         indices.swap(un);
301
302         // Add child node
303         children.push_back(c);
304         return *this;
305 }
306
307 void symmetry::validate(unsigned n)
308 {
309         if (indices.upper_bound(n - 1) != indices.end())
310                 throw (std::range_error("symmetry::verify(): index values are out of range"));
311         if (type != none && indices.empty()) {
312                 for (unsigned i=0; i<n; i++)
313                         add(i);
314         }
315 }
316
317 //////////
318 // global functions
319 //////////
320
321 static const symmetry & index0()
322 {
323         static ex s = (new symmetry(0))->setflag(status_flags::dynallocated);
324         return ex_to<symmetry>(s);
325 }
326
327 static const symmetry & index1()
328 {
329         static ex s = (new symmetry(1))->setflag(status_flags::dynallocated);
330         return ex_to<symmetry>(s);
331 }
332
333 static const symmetry & index2()
334 {
335         static ex s = (new symmetry(2))->setflag(status_flags::dynallocated);
336         return ex_to<symmetry>(s);
337 }
338
339 static const symmetry & index3()
340 {
341         static ex s = (new symmetry(3))->setflag(status_flags::dynallocated);
342         return ex_to<symmetry>(s);
343 }
344
345 const symmetry & not_symmetric()
346 {
347         static ex s = (new symmetry)->setflag(status_flags::dynallocated);
348         return ex_to<symmetry>(s);
349 }
350
351 const symmetry & symmetric2()
352 {
353         static ex s = (new symmetry(symmetry::symmetric, index0(), index1()))->setflag(status_flags::dynallocated);
354         return ex_to<symmetry>(s);
355 }
356
357 const symmetry & symmetric3()
358 {
359         static ex s = (new symmetry(symmetry::symmetric, index0(), index1()))->add(index2()).setflag(status_flags::dynallocated);
360         return ex_to<symmetry>(s);
361 }
362
363 const symmetry & symmetric4()
364 {
365         static ex s = (new symmetry(symmetry::symmetric, index0(), index1()))->add(index2()).add(index3()).setflag(status_flags::dynallocated);
366         return ex_to<symmetry>(s);
367 }
368
369 const symmetry & antisymmetric2()
370 {
371         static ex s = (new symmetry(symmetry::antisymmetric, index0(), index1()))->setflag(status_flags::dynallocated);
372         return ex_to<symmetry>(s);
373 }
374
375 const symmetry & antisymmetric3()
376 {
377         static ex s = (new symmetry(symmetry::antisymmetric, index0(), index1()))->add(index2()).setflag(status_flags::dynallocated);
378         return ex_to<symmetry>(s);
379 }
380
381 const symmetry & antisymmetric4()
382 {
383         static ex s = (new symmetry(symmetry::antisymmetric, index0(), index1()))->add(index2()).add(index3()).setflag(status_flags::dynallocated);
384         return ex_to<symmetry>(s);
385 }
386
387 class sy_is_less : public std::binary_function<ex, ex, bool> {
388         exvector::iterator v;
389
390 public:
391         sy_is_less(exvector::iterator v_) : v(v_) {}
392
393         bool operator() (const ex &lh, const ex &rh) const
394         {
395                 GINAC_ASSERT(is_exactly_a<symmetry>(lh));
396                 GINAC_ASSERT(is_exactly_a<symmetry>(rh));
397                 GINAC_ASSERT(ex_to<symmetry>(lh).indices.size() == ex_to<symmetry>(rh).indices.size());
398                 std::set<unsigned>::const_iterator ait = ex_to<symmetry>(lh).indices.begin(), aitend = ex_to<symmetry>(lh).indices.end(), bit = ex_to<symmetry>(rh).indices.begin();
399                 while (ait != aitend) {
400                         int cmpval = v[*ait].compare(v[*bit]);
401                         if (cmpval < 0)
402                                 return true;
403                         else if (cmpval > 0)
404                                 return false;
405                         ++ait; ++bit;
406                 }
407                 return false;
408         }
409 };
410
411 class sy_swap : public std::binary_function<ex, ex, void> {
412         exvector::iterator v;
413
414 public:
415         bool &swapped;
416
417         sy_swap(exvector::iterator v_, bool &s) : v(v_), swapped(s) {}
418
419         void operator() (const ex &lh, const ex &rh)
420         {
421                 GINAC_ASSERT(is_exactly_a<symmetry>(lh));
422                 GINAC_ASSERT(is_exactly_a<symmetry>(rh));
423                 GINAC_ASSERT(ex_to<symmetry>(lh).indices.size() == ex_to<symmetry>(rh).indices.size());
424                 std::set<unsigned>::const_iterator ait = ex_to<symmetry>(lh).indices.begin(), aitend = ex_to<symmetry>(lh).indices.end(), bit = ex_to<symmetry>(rh).indices.begin();
425                 while (ait != aitend) {
426                         v[*ait].swap(v[*bit]);
427                         ++ait; ++bit;
428                 }
429                 swapped = true;
430         }
431 };
432
433 int canonicalize(exvector::iterator v, const symmetry &symm)
434 {
435         // Less than two elements? Then do nothing
436         if (symm.indices.size() < 2)
437                 return std::numeric_limits<int>::max();
438
439         // Canonicalize children first
440         bool something_changed = false;
441         int sign = 1;
442         exvector::const_iterator first = symm.children.begin(), last = symm.children.end();
443         while (first != last) {
444                 GINAC_ASSERT(is_exactly_a<symmetry>(*first));
445                 int child_sign = canonicalize(v, ex_to<symmetry>(*first));
446                 if (child_sign == 0)
447                         return 0;
448                 if (child_sign != std::numeric_limits<int>::max()) {
449                         something_changed = true;
450                         sign *= child_sign;
451                 }
452                 first++;
453         }
454
455         // Now reorder the children
456         first = symm.children.begin();
457         switch (symm.type) {
458                 case symmetry::symmetric:
459                         // Sort the children in ascending order
460                         shaker_sort(first, last, sy_is_less(v), sy_swap(v, something_changed));
461                         break;
462                 case symmetry::antisymmetric:
463                         // Sort the children in ascending order, keeping track of the signum
464                         sign *= permutation_sign(first, last, sy_is_less(v), sy_swap(v, something_changed));
465                         if (sign == 0)
466                                 return 0;
467                         break;
468                 case symmetry::cyclic:
469                         // Permute the smallest child to the front
470                         cyclic_permutation(first, last, min_element(first, last, sy_is_less(v)), sy_swap(v, something_changed));
471                         break;
472                 default:
473                         break;
474         }
475         return something_changed ? sign : std::numeric_limits<int>::max();
476 }
477
478
479 // Symmetrize/antisymmetrize over a vector of objects
480 static ex symm(const ex & e, exvector::const_iterator first, exvector::const_iterator last, bool asymmetric)
481 {
482         // Need at least 2 objects for this operation
483         unsigned num = last - first;
484         if (num < 2)
485                 return e;
486
487         // Transform object vector to a lst (for subs())
488         lst orig_lst(first, last);
489
490         // Create index vectors for permutation
491         unsigned *iv = new unsigned[num], *iv2;
492         for (unsigned i=0; i<num; i++)
493                 iv[i] = i;
494         iv2 = (asymmetric ? new unsigned[num] : NULL);
495
496         // Loop over all permutations (the first permutation, which is the
497         // identity, is unrolled)
498         ex sum = e;
499         while (std::next_permutation(iv, iv + num)) {
500                 lst new_lst;
501                 for (unsigned i=0; i<num; i++)
502                         new_lst.append(orig_lst.op(iv[i]));
503                 ex term = e.subs(orig_lst, new_lst, subs_options::no_pattern|subs_options::no_index_renaming);
504                 if (asymmetric) {
505                         memcpy(iv2, iv, num * sizeof(unsigned));
506                         term *= permutation_sign(iv2, iv2 + num);
507                 }
508                 sum += term;
509         }
510
511         delete[] iv;
512         delete[] iv2;
513
514         return sum / factorial(numeric(num));
515 }
516
517 ex symmetrize(const ex & e, exvector::const_iterator first, exvector::const_iterator last)
518 {
519         return symm(e, first, last, false);
520 }
521
522 ex antisymmetrize(const ex & e, exvector::const_iterator first, exvector::const_iterator last)
523 {
524         return symm(e, first, last, true);
525 }
526
527 ex symmetrize_cyclic(const ex & e, exvector::const_iterator first, exvector::const_iterator last)
528 {
529         // Need at least 2 objects for this operation
530         unsigned num = last - first;
531         if (num < 2)
532                 return e;
533
534         // Transform object vector to a lst (for subs())
535         lst orig_lst(first, last);
536         lst new_lst = orig_lst;
537
538         // Loop over all cyclic permutations (the first permutation, which is
539         // the identity, is unrolled)
540         ex sum = e;
541         for (unsigned i=0; i<num-1; i++) {
542                 ex perm = new_lst.op(0);
543                 new_lst.remove_first().append(perm);
544                 sum += e.subs(orig_lst, new_lst, subs_options::no_pattern|subs_options::no_index_renaming);
545         }
546         return sum / num;
547 }
548
549 /** Symmetrize expression over a list of objects (symbols, indices). */
550 ex ex::symmetrize(const lst & l) const
551 {
552         exvector v(l.begin(), l.end());
553         return symm(*this, v.begin(), v.end(), false);
554 }
555
556 /** Antisymmetrize expression over a list of objects (symbols, indices). */
557 ex ex::antisymmetrize(const lst & l) const
558 {
559         exvector v(l.begin(), l.end());
560         return symm(*this, v.begin(), v.end(), true);
561 }
562
563 /** Symmetrize expression by cyclic permutation over a list of objects
564  *  (symbols, indices). */
565 ex ex::symmetrize_cyclic(const lst & l) const
566 {
567         exvector v(l.begin(), l.end());
568         return GiNaC::symmetrize_cyclic(*this, v.begin(), v.end());
569 }
570
571 } // namespace GiNaC