1b818191a64115849189bb8eda2de59f1de0406b
[ginac.git] / ginac / symmetry.cpp
1 /** @file symmetry.cpp
2  *
3  *  Implementation of GiNaC's symmetry definitions. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2001 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
21  */
22
#include <algorithm>
#include <climits>
#include <functional>
#include <iterator>
#include <stdexcept>
#include <vector>
26
27 #include "symmetry.h"
28 #include "lst.h"
29 #include "numeric.h" // for factorial()
30 #include "print.h"
31 #include "archive.h"
32 #include "utils.h"
33 #include "debugmsg.h"
34
35 namespace GiNaC {
36
37 GINAC_IMPLEMENT_REGISTERED_CLASS(symmetry, basic)
38
39 //////////
40 // default constructor, destructor, copy constructor assignment operator and helpers
41 //////////
42
43 symmetry::symmetry() : type(none)
44 {
45         debugmsg("symmetry default constructor", LOGLEVEL_CONSTRUCT);
46         tinfo_key = TINFO_symmetry;
47 }
48
49 void symmetry::copy(const symmetry & other)
50 {
51         inherited::copy(other);
52         type = other.type;
53         indices = other.indices;
54         children = other.children;
55 }
56
57 DEFAULT_DESTROY(symmetry)
58
59 //////////
60 // other constructors
61 //////////
62
63 symmetry::symmetry(unsigned i) : type(none)
64 {
65         debugmsg("symmetry constructor from unsigned", LOGLEVEL_CONSTRUCT);
66         indices.insert(i);
67         tinfo_key = TINFO_symmetry;
68 }
69
70 symmetry::symmetry(symmetry_type t, const symmetry &c1, const symmetry &c2) : type(t)
71 {
72         debugmsg("symmetry constructor from symmetry_type,symmetry &,symmetry &", LOGLEVEL_CONSTRUCT);
73         add(c1); add(c2);
74         tinfo_key = TINFO_symmetry;
75 }
76
77 //////////
78 // archiving
79 //////////
80
81 /** Construct object from archive_node. */
82 symmetry::symmetry(const archive_node &n, const lst &sym_lst) : inherited(n, sym_lst)
83 {
84         debugmsg("symmetry ctor from archive_node", LOGLEVEL_CONSTRUCT);
85
86         unsigned t;
87         if (!(n.find_unsigned("type", t)))
88                 throw (std::runtime_error("unknown symmetry type in archive"));
89         type = (symmetry_type)t;
90
91         unsigned i = 0;
92         while (true) {
93                 ex e;
94                 if (n.find_ex("child", e, sym_lst, i))
95                         add(ex_to<symmetry>(e));
96                 else
97                         break;
98                 i++;
99         }
100
101         if (i == 0) {
102                 while (true) {
103                         unsigned u;
104                         if (n.find_unsigned("index", u, i))
105                                 indices.insert(u);
106                         else
107                                 break;
108                         i++;
109                 }
110         }
111 }
112
113 /** Archive the object. */
114 void symmetry::archive(archive_node &n) const
115 {
116         inherited::archive(n);
117
118         n.add_unsigned("type", type);
119
120         if (children.empty()) {
121                 std::set<unsigned>::const_iterator i = indices.begin(), iend = indices.end();
122                 while (i != iend) {
123                         n.add_unsigned("index", *i);
124                         i++;
125                 }
126         } else {
127                 exvector::const_iterator i = children.begin(), iend = children.end();
128                 while (i != iend) {
129                         n.add_ex("child", *i);
130                         i++;
131                 }
132         }
133 }
134
135 DEFAULT_UNARCHIVE(symmetry)
136
137 //////////
138 // functions overriding virtual functions from base classes
139 //////////
140
141 int symmetry::compare_same_type(const basic & other) const
142 {
143         GINAC_ASSERT(is_of_type(other, symmetry));
144
145         // All symmetry trees are equal. They are not supposed to appear in
146         // ordinary expressions anyway...
147         return 0;
148 }
149
150 void symmetry::print(const print_context & c, unsigned level = 0) const
151 {
152         debugmsg("symmetry print", LOGLEVEL_PRINT);
153
154         if (children.empty()) {
155                 if (indices.size() > 0)
156                         c.s << *(indices.begin());
157                 else
158                         c.s << "none";
159         } else {
160                 switch (type) {
161                         case none: c.s << '!'; break;
162                         case symmetric: c.s << '+'; break;
163                         case antisymmetric: c.s << '-'; break;
164                         case cyclic: c.s << '@'; break;
165                         default: c.s << '?'; break;
166                 }
167                 c.s << '(';
168                 unsigned num = children.size();
169                 for (unsigned i=0; i<num; i++) {
170                         children[i].print(c);
171                         if (i != num - 1)
172                                 c.s << ",";
173                 }
174                 c.s << ')';
175         }
176 }
177
178 //////////
179 // non-virtual functions in this class
180 //////////
181
182 symmetry &symmetry::add(const symmetry &c)
183 {
184         // All children must have the same number of indices
185         if (type != none && !children.empty()) {
186                 GINAC_ASSERT(is_ex_exactly_of_type(children[0], symmetry));
187                 if (ex_to<symmetry>(children[0]).indices.size() != c.indices.size())
188                         throw (std::logic_error("symmetry:add(): children must have same number of indices"));
189         }
190
191         // Compute union of indices and check whether the two sets are disjoint
192         std::set<unsigned> un;
193         set_union(indices.begin(), indices.end(), c.indices.begin(), c.indices.end(), inserter(un, un.begin()));
194         if (un.size() != indices.size() + c.indices.size())
195                 throw (std::logic_error("symmetry::add(): the same index appears in more than one child"));
196
197         // Set new index set
198         indices.swap(un);
199
200         // Add child node
201         children.push_back(c);
202         return *this;
203 }
204
205 void symmetry::validate(unsigned n)
206 {
207         if (indices.upper_bound(n - 1) != indices.end())
208                 throw (std::range_error("symmetry::verify(): index values are out of range"));
209         if (type != none && indices.empty()) {
210                 for (unsigned i=0; i<n; i++)
211                         add(i);
212         }
213 }
214
215 //////////
216 // global functions
217 //////////
218
219 class sy_is_less : public std::binary_function<ex, ex, bool> {
220         exvector::iterator v;
221
222 public:
223         sy_is_less(exvector::iterator v_) : v(v_) {}
224
225         bool operator() (const ex &lh, const ex &rh) const
226         {
227                 GINAC_ASSERT(is_ex_exactly_of_type(lh, symmetry));
228                 GINAC_ASSERT(is_ex_exactly_of_type(rh, symmetry));
229                 GINAC_ASSERT(ex_to<symmetry>(lh).indices.size() == ex_to<symmetry>(rh).indices.size());
230                 std::set<unsigned>::const_iterator ait = ex_to<symmetry>(lh).indices.begin(), aitend = ex_to<symmetry>(lh).indices.end(), bit = ex_to<symmetry>(rh).indices.begin();
231                 while (ait != aitend) {
232                         int cmpval = v[*ait].compare(v[*bit]);
233                         if (cmpval < 0)
234                                 return true;
235                         else if (cmpval > 0)
236                                 return false;
237                         ++ait; ++bit;
238                 }
239                 return false;
240         }
241 };
242
243 class sy_swap : public std::binary_function<ex, ex, void> {
244         exvector::iterator v;
245
246 public:
247         bool &swapped;
248
249         sy_swap(exvector::iterator v_, bool &s) : v(v_), swapped(s) {}
250
251         void operator() (const ex &lh, const ex &rh)
252         {
253                 GINAC_ASSERT(is_ex_exactly_of_type(lh, symmetry));
254                 GINAC_ASSERT(is_ex_exactly_of_type(rh, symmetry));
255                 GINAC_ASSERT(ex_to<symmetry>(lh).indices.size() == ex_to<symmetry>(rh).indices.size());
256                 std::set<unsigned>::const_iterator ait = ex_to<symmetry>(lh).indices.begin(), aitend = ex_to<symmetry>(lh).indices.end(), bit = ex_to<symmetry>(rh).indices.begin();
257                 while (ait != aitend) {
258                         v[*ait].swap(v[*bit]);
259                         ++ait; ++bit;
260                 }
261                 swapped = true;
262         }
263 };
264
265 int canonicalize(exvector::iterator v, const symmetry &symm)
266 {
267         // No children? Then do nothing
268         if (symm.children.empty())
269                 return INT_MAX;
270
271         // Canonicalize children first
272         bool something_changed = false;
273         int sign = 1;
274         exvector::const_iterator first = symm.children.begin(), last = symm.children.end();
275         while (first != last) {
276                 GINAC_ASSERT(is_ex_exactly_of_type(*first, symmetry));
277                 int child_sign = canonicalize(v, ex_to<symmetry>(*first));
278                 if (child_sign == 0)
279                         return 0;
280                 if (child_sign != INT_MAX) {
281                         something_changed = true;
282                         sign *= child_sign;
283                 }
284                 first++;
285         }
286
287         // Now reorder the children
288         first = symm.children.begin();
289         switch (symm.type) {
290                 case symmetry::symmetric:
291                         // Sort the children in ascending order
292                         shaker_sort(first, last, sy_is_less(v), sy_swap(v, something_changed));
293                         break;
294                 case symmetry::antisymmetric:
295                         // Sort the children in ascending order, keeping track of the signum
296                         sign *= permutation_sign(first, last, sy_is_less(v), sy_swap(v, something_changed));
297                         break;
298                 case symmetry::cyclic:
299                         // Permute the smallest child to the front
300                         cyclic_permutation(first, last, min_element(first, last, sy_is_less(v)), sy_swap(v, something_changed));
301                         break;
302                 default:
303                         break;
304         }
305         return something_changed ? sign : INT_MAX;
306 }
307
308
309 // Symmetrize/antisymmetrize over a vector of objects
310 static ex symm(const ex & e, exvector::const_iterator first, exvector::const_iterator last, bool asymmetric)
311 {
312         // Need at least 2 objects for this operation
313         unsigned num = last - first;
314         if (num < 2)
315                 return e;
316
317         // Transform object vector to a list
318         exlist iv_lst;
319         iv_lst.insert(iv_lst.begin(), first, last);
320         lst orig_lst(iv_lst, true);
321
322         // Create index vectors for permutation
323         unsigned *iv = new unsigned[num], *iv2;
324         for (unsigned i=0; i<num; i++)
325                 iv[i] = i;
326         iv2 = (asymmetric ? new unsigned[num] : NULL);
327
328         // Loop over all permutations (the first permutation, which is the
329         // identity, is unrolled)
330         ex sum = e;
331         while (std::next_permutation(iv, iv + num)) {
332                 lst new_lst;
333                 for (unsigned i=0; i<num; i++)
334                         new_lst.append(orig_lst.op(iv[i]));
335                 ex term = e.subs(orig_lst, new_lst);
336                 if (asymmetric) {
337                         memcpy(iv2, iv, num * sizeof(unsigned));
338                         term *= permutation_sign(iv2, iv2 + num);
339                 }
340                 sum += term;
341         }
342
343         delete[] iv;
344         delete[] iv2;
345
346         return sum / factorial(numeric(num));
347 }
348
349 ex symmetrize(const ex & e, exvector::const_iterator first, exvector::const_iterator last)
350 {
351         return symm(e, first, last, false);
352 }
353
354 ex antisymmetrize(const ex & e, exvector::const_iterator first, exvector::const_iterator last)
355 {
356         return symm(e, first, last, true);
357 }
358
359 ex symmetrize_cyclic(const ex & e, exvector::const_iterator first, exvector::const_iterator last)
360 {
361         // Need at least 2 objects for this operation
362         unsigned num = last - first;
363         if (num < 2)
364                 return e;
365
366         // Transform object vector to a list
367         exlist iv_lst;
368         iv_lst.insert(iv_lst.begin(), first, last);
369         lst orig_lst(iv_lst, true);
370         lst new_lst = orig_lst;
371
372         // Loop over all cyclic permutations (the first permutation, which is
373         // the identity, is unrolled)
374         ex sum = e;
375         for (unsigned i=0; i<num-1; i++) {
376                 ex perm = new_lst.op(0);
377                 new_lst.remove_first().append(perm);
378                 sum += e.subs(orig_lst, new_lst);
379         }
380         return sum / num;
381 }
382
383 /** Symmetrize expression over a list of objects (symbols, indices). */
384 ex ex::symmetrize(const lst & l) const
385 {
386         exvector v;
387         v.reserve(l.nops());
388         for (unsigned i=0; i<l.nops(); i++)
389                 v.push_back(l.op(i));
390         return symm(*this, v.begin(), v.end(), false);
391 }
392
393 /** Antisymmetrize expression over a list of objects (symbols, indices). */
394 ex ex::antisymmetrize(const lst & l) const
395 {
396         exvector v;
397         v.reserve(l.nops());
398         for (unsigned i=0; i<l.nops(); i++)
399                 v.push_back(l.op(i));
400         return symm(*this, v.begin(), v.end(), true);
401 }
402
403 /** Symmetrize expression by cyclic permutation over a list of objects
404  *  (symbols, indices). */
405 ex ex::symmetrize_cyclic(const lst & l) const
406 {
407         exvector v;
408         v.reserve(l.nops());
409         for (unsigned i=0; i<l.nops(); i++)
410                 v.push_back(l.op(i));
411         return GiNaC::symmetrize_cyclic(*this, v.begin(), v.end());
412 }
413
414 } // namespace GiNaC