115df0bfe4924505dfc507eadf08754b592afada
[ginac.git] / ginac / symmetry.cpp
1 /** @file symmetry.cpp
2  *
3  *  Implementation of GiNaC's symmetry definitions. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2019 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
21  */
22
23 #include "symmetry.h"
24 #include "lst.h"
25 #include "add.h"
26 #include "numeric.h" // for factorial()
27 #include "operators.h"
28 #include "archive.h"
29 #include "utils.h"
30 #include "hash_seed.h"
31
#include <algorithm>
#include <functional>
#include <iostream>
#include <iterator>
#include <limits>
#include <stdexcept>
#include <vector>
36
37 namespace GiNaC {
38
// Register symmetry as a GiNaC class derived from basic and bind its print
// methods: do_print() for print_context and do_print_tree() for print_tree.
GINAC_IMPLEMENT_REGISTERED_CLASS_OPT(symmetry, basic,
  print_func<print_context>(&symmetry::do_print).
  print_func<print_tree>(&symmetry::do_print_tree))
42
43 /*
44    Some notes about the structure of a symmetry tree:
45     - The leaf nodes of the tree are of type "none", have one index, and no
46       children (of course). They are constructed by the symmetry(unsigned)
47       constructor.
48     - Leaf nodes are the only nodes that only have one index.
49     - Container nodes contain two or more children. The "indices" set member
50       is the set union of the index sets of all children, and the "children"
51       vector stores the children themselves.
52     - The index set of each child of a "symm", "anti" or "cycl" node must
53       have the same size. It follows that the children of such a node are
54       either all leaf nodes, or all container nodes with two or more indices.
55 */
56
57 //////////
58 // default constructor
59 //////////
60
61 symmetry::symmetry() :  type(none)
62 {
63         setflag(status_flags::evaluated | status_flags::expanded);
64 }
65
66 //////////
67 // other constructors
68 //////////
69
70 symmetry::symmetry(unsigned i) :  type(none)
71 {
72         indices.insert(i);
73         setflag(status_flags::evaluated | status_flags::expanded);
74 }
75
76 symmetry::symmetry(symmetry_type t, const symmetry &c1, const symmetry &c2) :  type(t)
77 {
78         add(c1); add(c2);
79         setflag(status_flags::evaluated | status_flags::expanded);
80 }
81
82 //////////
83 // archiving
84 //////////
85
86 /** Construct object from archive_node. */
87 void symmetry::read_archive(const archive_node &n, lst &sym_lst)
88 {
89         inherited::read_archive(n, sym_lst);
90         unsigned t;
91         if (!(n.find_unsigned("type", t)))
92                 throw (std::runtime_error("unknown symmetry type in archive"));
93         type = (symmetry_type)t;
94
95         unsigned i = 0;
96         while (true) {
97                 ex e;
98                 if (n.find_ex("child", e, sym_lst, i))
99                         add(ex_to<symmetry>(e));
100                 else
101                         break;
102                 i++;
103         }
104
105         if (i == 0) {
106                 while (true) {
107                         unsigned u;
108                         if (n.find_unsigned("index", u, i))
109                                 indices.insert(u);
110                         else
111                                 break;
112                         i++;
113                 }
114         }
115 }
116 GINAC_BIND_UNARCHIVER(symmetry);
117
118 /** Archive the object. */
119 void symmetry::archive(archive_node &n) const
120 {
121         inherited::archive(n);
122
123         n.add_unsigned("type", type);
124
125         if (children.empty()) {
126                 for (auto & i : indices) {
127                         n.add_unsigned("index", i);
128                 }
129         } else {
130                 for (auto & i : children) {
131                         n.add_ex("child", i);
132                 }
133         }
134 }
135
136 //////////
137 // functions overriding virtual functions from base classes
138 //////////
139
140 int symmetry::compare_same_type(const basic & other) const
141 {
142         GINAC_ASSERT(is_a<symmetry>(other));
143
144         // For archiving purposes we need to have an ordering of symmetries.
145         const symmetry &othersymm = ex_to<symmetry>(other);
146
147         // Compare type.
148         if (type > othersymm.type)
149                 return 1;
150         if (type < othersymm.type)
151                 return -1;
152
153         // Compare the index set.
154         size_t this_size = indices.size();
155         size_t that_size = othersymm.indices.size();
156         if (this_size > that_size)
157                 return 1;
158         if (this_size < that_size)
159                 return -1;
160         auto end = indices.end();
161         for (auto i=indices.begin(),j=othersymm.indices.begin(); i!=end; ++i,++j) {
162                 if(*i < *j)
163                         return 1;
164                 if(*i > *j)
165                         return -1;
166         }
167
168         // Compare the children.
169         if (children.size() > othersymm.children.size())
170                 return 1;
171         if (children.size() < othersymm.children.size())
172                 return -1;
173         for (size_t i=0; i<children.size(); ++i) {
174                 int cmpval = ex_to<symmetry>(children[i])
175                         .compare_same_type(ex_to<symmetry>(othersymm.children[i]));
176                 if (cmpval)
177                         return cmpval;
178         }
179
180         return 0;
181 }
182
183 unsigned symmetry::calchash() const
184 {
185         unsigned v = make_hash_seed(typeid(*this));
186
187         if (type == none) {
188                 v = rotate_left(v);
189                 if (!indices.empty())
190                         v ^= *(indices.begin());
191         } else {
192                 for (auto & i : children) {
193                         v = rotate_left(v);
194                         v ^= i.gethash();
195                 }
196         }
197
198         if (flags & status_flags::evaluated) {
199                 setflag(status_flags::hash_calculated);
200                 hashvalue = v;
201         }
202
203         return v;
204 }
205
206 void symmetry::do_print(const print_context & c, unsigned level) const
207 {
208         if (children.empty()) {
209                 if (indices.size() > 0)
210                         c.s << *(indices.begin());
211                 else
212                         c.s << "none";
213         } else {
214                 switch (type) {
215                         case none: c.s << '!'; break;
216                         case symmetric: c.s << '+'; break;
217                         case antisymmetric: c.s << '-'; break;
218                         case cyclic: c.s << '@'; break;
219                         default: c.s << '?'; break;
220                 }
221                 c.s << '(';
222                 size_t num = children.size();
223                 for (size_t i=0; i<num; i++) {
224                         children[i].print(c);
225                         if (i != num - 1)
226                                 c.s << ",";
227                 }
228                 c.s << ')';
229         }
230 }
231
232 void symmetry::do_print_tree(const print_tree & c, unsigned level) const
233 {
234         c.s << std::string(level, ' ') << class_name() << " @" << this
235             << std::hex << ", hash=0x" << hashvalue << ", flags=0x" << flags << std::dec
236             << ", type=";
237
238         switch (type) {
239                 case none: c.s << "none"; break;
240                 case symmetric: c.s << "symm"; break;
241                 case antisymmetric: c.s << "anti"; break;
242                 case cyclic: c.s << "cycl"; break;
243                 default: c.s << "<unknown>"; break;
244         }
245
246         c.s << ", indices=(";
247         if (!indices.empty()) {
248                 auto i = indices.begin(), end = indices.end();
249                 --end;
250                 while (i != end)
251                         c.s << *i++ << ",";
252                 c.s << *i;
253         }
254         c.s << ")\n";
255
256         for (auto & i : children) {
257                 i.print(c, level + c.delta_indent);
258         }
259 }
260
261 //////////
262 // non-virtual functions in this class
263 //////////
264
265 bool symmetry::has_nonsymmetric() const
266 {
267         if (type == antisymmetric || type == cyclic)
268                 return true;
269
270         for (auto & i : children)
271                 if (ex_to<symmetry>(i).has_nonsymmetric())
272                         return true;
273
274         return false;
275 }
276
277 bool symmetry::has_cyclic() const
278 {
279         if (type == cyclic)
280                 return true;
281
282         for (auto & i : children)
283                 if (ex_to<symmetry>(i).has_cyclic())
284                         return true;
285
286         return false;
287 }
288
289 symmetry &symmetry::add(const symmetry &c)
290 {
291         // All children must have the same number of indices
292         if (type != none && !children.empty()) {
293                 GINAC_ASSERT(is_exactly_a<symmetry>(children[0]));
294                 if (ex_to<symmetry>(children[0]).indices.size() != c.indices.size())
295                         throw (std::logic_error("symmetry:add(): children must have same number of indices"));
296         }
297
298         // Compute union of indices and check whether the two sets are disjoint
299         std::set<unsigned> un;
300         set_union(indices.begin(), indices.end(), c.indices.begin(), c.indices.end(), inserter(un, un.begin()));
301         if (un.size() != indices.size() + c.indices.size())
302                 throw (std::logic_error("symmetry::add(): the same index appears in more than one child"));
303
304         // Set new index set
305         indices.swap(un);
306
307         // Add child node
308         children.push_back(c);
309         return *this;
310 }
311
312 void symmetry::validate(unsigned n)
313 {
314         if (indices.upper_bound(n - 1) != indices.end())
315                 throw (std::range_error("symmetry::verify(): index values are out of range"));
316         if (type != none && indices.empty()) {
317                 for (unsigned i=0; i<n; i++)
318                         add(i);
319         }
320 }
321
322 //////////
323 // global functions
324 //////////
325
326 static const symmetry & index0()
327 {
328         static ex s = dynallocate<symmetry>(0);
329         return ex_to<symmetry>(s);
330 }
331
332 static const symmetry & index1()
333 {
334         static ex s = dynallocate<symmetry>(1);
335         return ex_to<symmetry>(s);
336 }
337
338 static const symmetry & index2()
339 {
340         static ex s = dynallocate<symmetry>(2);
341         return ex_to<symmetry>(s);
342 }
343
344 static const symmetry & index3()
345 {
346         static ex s = dynallocate<symmetry>(3);
347         return ex_to<symmetry>(s);
348 }
349
350 const symmetry & not_symmetric()
351 {
352         static ex s = dynallocate<symmetry>();
353         return ex_to<symmetry>(s);
354 }
355
356 const symmetry & symmetric2()
357 {
358         static ex s = dynallocate<symmetry>(symmetry::symmetric, index0(), index1());
359         return ex_to<symmetry>(s);
360 }
361
362 const symmetry & symmetric3()
363 {
364         static ex s = dynallocate<symmetry>(symmetry::symmetric, index0(), index1()).add(index2());
365         return ex_to<symmetry>(s);
366 }
367
368 const symmetry & symmetric4()
369 {
370         static ex s = dynallocate<symmetry>(symmetry::symmetric, index0(), index1()).add(index2()).add(index3());
371         return ex_to<symmetry>(s);
372 }
373
374 const symmetry & antisymmetric2()
375 {
376         static ex s = dynallocate<symmetry>(symmetry::antisymmetric, index0(), index1());
377         return ex_to<symmetry>(s);
378 }
379
380 const symmetry & antisymmetric3()
381 {
382         static ex s = dynallocate<symmetry>(symmetry::antisymmetric, index0(), index1()).add(index2());
383         return ex_to<symmetry>(s);
384 }
385
386 const symmetry & antisymmetric4()
387 {
388         static ex s = dynallocate<symmetry>(symmetry::antisymmetric, index0(), index1()).add(index2()).add(index3());
389         return ex_to<symmetry>(s);
390 }
391
392 class sy_is_less {
393         exvector::iterator v;
394
395 public:
396         sy_is_less(exvector::iterator v_) : v(v_) {}
397
398         bool operator() (const ex &lh, const ex &rh) const
399         {
400                 GINAC_ASSERT(is_exactly_a<symmetry>(lh));
401                 GINAC_ASSERT(is_exactly_a<symmetry>(rh));
402                 GINAC_ASSERT(ex_to<symmetry>(lh).indices.size() == ex_to<symmetry>(rh).indices.size());
403                 auto ait = ex_to<symmetry>(lh).indices.begin(), aitend = ex_to<symmetry>(lh).indices.end(), bit = ex_to<symmetry>(rh).indices.begin();
404                 while (ait != aitend) {
405                         int cmpval = v[*ait].compare(v[*bit]);
406                         if (cmpval < 0)
407                                 return true;
408                         else if (cmpval > 0)
409                                 return false;
410                         ++ait; ++bit;
411                 }
412                 return false;
413         }
414 };
415
416 class sy_swap {
417         exvector::iterator v;
418
419 public:
420         bool &swapped;
421
422         sy_swap(exvector::iterator v_, bool &s) : v(v_), swapped(s) {}
423
424         void operator() (const ex &lh, const ex &rh)
425         {
426                 GINAC_ASSERT(is_exactly_a<symmetry>(lh));
427                 GINAC_ASSERT(is_exactly_a<symmetry>(rh));
428                 GINAC_ASSERT(ex_to<symmetry>(lh).indices.size() == ex_to<symmetry>(rh).indices.size());
429                 auto ait = ex_to<symmetry>(lh).indices.begin(), aitend = ex_to<symmetry>(lh).indices.end(), bit = ex_to<symmetry>(rh).indices.begin();
430                 while (ait != aitend) {
431                         v[*ait].swap(v[*bit]);
432                         ++ait; ++bit;
433                 }
434                 swapped = true;
435         }
436 };
437
438 int canonicalize(exvector::iterator v, const symmetry &symm)
439 {
440         // Less than two elements? Then do nothing
441         if (symm.indices.size() < 2)
442                 return std::numeric_limits<int>::max();
443
444         // Canonicalize children first
445         bool something_changed = false;
446         int sign = 1;
447         auto first = symm.children.begin(), last = symm.children.end();
448         while (first != last) {
449                 GINAC_ASSERT(is_exactly_a<symmetry>(*first));
450                 int child_sign = canonicalize(v, ex_to<symmetry>(*first));
451                 if (child_sign == 0)
452                         return 0;
453                 if (child_sign != std::numeric_limits<int>::max()) {
454                         something_changed = true;
455                         sign *= child_sign;
456                 }
457                 first++;
458         }
459
460         // Now reorder the children
461         first = symm.children.begin();
462         switch (symm.type) {
463                 case symmetry::symmetric:
464                         // Sort the children in ascending order
465                         shaker_sort(first, last, sy_is_less(v), sy_swap(v, something_changed));
466                         break;
467                 case symmetry::antisymmetric:
468                         // Sort the children in ascending order, keeping track of the signum
469                         sign *= permutation_sign(first, last, sy_is_less(v), sy_swap(v, something_changed));
470                         if (sign == 0)
471                                 return 0;
472                         break;
473                 case symmetry::cyclic:
474                         // Permute the smallest child to the front
475                         cyclic_permutation(first, last, min_element(first, last, sy_is_less(v)), sy_swap(v, something_changed));
476                         break;
477                 default:
478                         break;
479         }
480         return something_changed ? sign : std::numeric_limits<int>::max();
481 }
482
483
484 // Symmetrize/antisymmetrize over a vector of objects
485 static ex symm(const ex & e, exvector::const_iterator first, exvector::const_iterator last, bool asymmetric)
486 {
487         // Need at least 2 objects for this operation
488         unsigned num = last - first;
489         if (num < 2)
490                 return e;
491
492         // Transform object vector to a lst (for subs())
493         lst orig_lst(first, last);
494
495         // Create index vectors for permutation
496         unsigned *iv = new unsigned[num], *iv2;
497         for (unsigned i=0; i<num; i++)
498                 iv[i] = i;
499         iv2 = (asymmetric ? new unsigned[num] : nullptr);
500
501         // Loop over all permutations (the first permutation, which is the
502         // identity, is unrolled)
503         exvector sum_v;
504         sum_v.push_back(e);
505         while (std::next_permutation(iv, iv + num)) {
506                 lst new_lst;
507                 for (unsigned i=0; i<num; i++)
508                         new_lst.append(orig_lst.op(iv[i]));
509                 ex term = e.subs(orig_lst, new_lst, subs_options::no_pattern|subs_options::no_index_renaming);
510                 if (asymmetric) {
511                         memcpy(iv2, iv, num * sizeof(unsigned));
512                         term *= permutation_sign(iv2, iv2 + num);
513                 }
514                 sum_v.push_back(term);
515         }
516         ex sum = dynallocate<add>(sum_v);
517
518         delete[] iv;
519         delete[] iv2;
520
521         return sum / factorial(numeric(num));
522 }
523
524 ex symmetrize(const ex & e, exvector::const_iterator first, exvector::const_iterator last)
525 {
526         return symm(e, first, last, false);
527 }
528
529 ex antisymmetrize(const ex & e, exvector::const_iterator first, exvector::const_iterator last)
530 {
531         return symm(e, first, last, true);
532 }
533
534 ex symmetrize_cyclic(const ex & e, exvector::const_iterator first, exvector::const_iterator last)
535 {
536         // Need at least 2 objects for this operation
537         unsigned num = last - first;
538         if (num < 2)
539                 return e;
540
541         // Transform object vector to a lst (for subs())
542         lst orig_lst(first, last);
543         lst new_lst = orig_lst;
544
545         // Loop over all cyclic permutations (the first permutation, which is
546         // the identity, is unrolled)
547         ex sum = e;
548         for (unsigned i=0; i<num-1; i++) {
549                 ex perm = new_lst.op(0);
550                 new_lst.remove_first().append(perm);
551                 sum += e.subs(orig_lst, new_lst, subs_options::no_pattern|subs_options::no_index_renaming);
552         }
553         return sum / num;
554 }
555
556 /** Symmetrize expression over a list of objects (symbols, indices). */
557 ex ex::symmetrize(const lst & l) const
558 {
559         exvector v(l.begin(), l.end());
560         return symm(*this, v.begin(), v.end(), false);
561 }
562
563 /** Antisymmetrize expression over a list of objects (symbols, indices). */
564 ex ex::antisymmetrize(const lst & l) const
565 {
566         exvector v(l.begin(), l.end());
567         return symm(*this, v.begin(), v.end(), true);
568 }
569
570 /** Symmetrize expression by cyclic permutation over a list of objects
571  *  (symbols, indices). */
572 ex ex::symmetrize_cyclic(const lst & l) const
573 {
574         exvector v(l.begin(), l.end());
575         return GiNaC::symmetrize_cyclic(*this, v.begin(), v.end());
576 }
577
578 } // namespace GiNaC