introduced new class for constructing symmetry tree definitions
[ginac.git] / ginac / symmetry.cpp
1 /** @file symmetry.cpp
2  *
3  *  Implementation of GiNaC's symmetry definitions. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2001 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
21  */
22
#include <algorithm>
#include <climits>
#include <functional>
#include <iterator>
#include <stdexcept>
#include <vector>
26
27 #define DO_GINAC_ASSERT
28 #include "assertion.h"
29
30 #include "symmetry.h"
31 #include "lst.h"
32 #include "numeric.h" // for factorial()
33 #include "print.h"
34 #include "archive.h"
35 #include "utils.h"
36 #include "debugmsg.h"
37
38 namespace GiNaC {
39
40 GINAC_IMPLEMENT_REGISTERED_CLASS(symmetry, basic)
41
42 //////////
43 // default constructor, destructor, copy constructor assignment operator and helpers
44 //////////
45
46 symmetry::symmetry() : type(none)
47 {
48         debugmsg("symmetry default constructor", LOGLEVEL_CONSTRUCT);
49         tinfo_key = TINFO_symmetry;
50 }
51
52 void symmetry::copy(const symmetry & other)
53 {
54         inherited::copy(other);
55         type = other.type;
56         indices = other.indices;
57         children = other.children;
58 }
59
60 DEFAULT_DESTROY(symmetry)
61
62 //////////
63 // other constructors
64 //////////
65
66 symmetry::symmetry(unsigned i) : type(none)
67 {
68         debugmsg("symmetry constructor from unsigned", LOGLEVEL_CONSTRUCT);
69         indices.insert(i);
70         tinfo_key = TINFO_symmetry;
71 }
72
73 symmetry::symmetry(symmetry_type t, const symmetry &c1, const symmetry &c2) : type(t)
74 {
75         debugmsg("symmetry constructor from symmetry_type,symmetry &,symmetry &", LOGLEVEL_CONSTRUCT);
76         add(c1); add(c2);
77         tinfo_key = TINFO_symmetry;
78 }
79
80 //////////
81 // archiving
82 //////////
83
84 /** Construct object from archive_node. */
85 symmetry::symmetry(const archive_node &n, const lst &sym_lst) : inherited(n, sym_lst)
86 {
87         debugmsg("symmetry ctor from archive_node", LOGLEVEL_CONSTRUCT);
88
89         unsigned t;
90         if (!(n.find_unsigned("type", t)))
91                 throw (std::runtime_error("unknown symmetry type in archive"));
92         type = (symmetry_type)t;
93
94         unsigned i = 0;
95         while (true) {
96                 ex e;
97                 if (n.find_ex("child", e, sym_lst, i))
98                         add(ex_to_symmetry(e));
99                 else
100                         break;
101                 i++;
102         }
103
104         if (i == 0) {
105                 while (true) {
106                         unsigned u;
107                         if (n.find_unsigned("index", u, i))
108                                 indices.insert(u);
109                         else
110                                 break;
111                         i++;
112                 }
113         }
114 }
115
116 /** Archive the object. */
117 void symmetry::archive(archive_node &n) const
118 {
119         inherited::archive(n);
120
121         n.add_unsigned("type", type);
122
123         if (children.empty()) {
124                 std::set<unsigned>::const_iterator i = indices.begin(), iend = indices.end();
125                 while (i != iend) {
126                         n.add_unsigned("index", *i);
127                         i++;
128                 }
129         } else {
130                 exvector::const_iterator i = children.begin(), iend = children.end();
131                 while (i != iend) {
132                         n.add_ex("child", *i);
133                         i++;
134                 }
135         }
136 }
137
138 DEFAULT_UNARCHIVE(symmetry)
139
140 //////////
// functions overriding virtual functions from base classes
142 //////////
143
144 int symmetry::compare_same_type(const basic & other) const
145 {
146         GINAC_ASSERT(is_of_type(other, symmetry));
147         const symmetry &o = static_cast<const symmetry &>(other);
148
149         // All symmetry trees are equal. They are not supposed to appear in
150         // ordinary expressions anyway...
151         return 0;
152 }
153
154 void symmetry::print(const print_context & c, unsigned level = 0) const
155 {
156         debugmsg("symmetry print", LOGLEVEL_PRINT);
157
158         if (children.empty()) {
159                 if (indices.size() > 0)
160                         c.s << *(indices.begin());
161         } else {
162                 switch (type) {
163                         case none: c.s << '!'; break;
164                         case symmetric: c.s << '+'; break;
165                         case antisymmetric: c.s << '-'; break;
166                         case cyclic: c.s << '@'; break;
167                         default: c.s << '?'; break;
168                 }
169                 c.s << '(';
170                 for (unsigned i=0; i<children.size(); i++) {
171                         children[i].print(c);
172                         if (i != children.size() - 1)
173                                 c.s << ",";
174                 }
175                 c.s << ')';
176         }
177 }
178
179 //////////
180 // non-virtual functions in this class
181 //////////
182
183 symmetry &symmetry::add(const symmetry &c)
184 {
185         // All children must have the same number of indices
186         if (type != none && !children.empty()) {
187                 GINAC_ASSERT(is_ex_exactly_of_type(children[0], symmetry));
188                 if (ex_to_symmetry(children[0]).indices.size() != c.indices.size())
189                         throw (std::logic_error("symmetry:add(): children must have same number of indices"));
190         }
191
192         // Compute union of indices and check whether the two sets are disjoint
193         std::set<unsigned> un;
194         set_union(indices.begin(), indices.end(), c.indices.begin(), c.indices.end(), inserter(un, un.begin()));
195         if (un.size() != indices.size() + c.indices.size())
196                 throw (std::logic_error("symmetry::add(): the same index appears in more than one child"));
197
198         // Set new index set
199         indices.swap(un);
200
201         // Add child node
202         children.push_back(c);
203         return *this;
204 }
205
206 void symmetry::validate(unsigned n)
207 {
208         if (indices.upper_bound(n - 1) != indices.end())
209                 throw (std::range_error("symmetry::verify(): index values are out of range"));
210         if (type != none && indices.empty()) {
211                 for (unsigned i=0; i<n; i++)
212                         add(i);
213         }
214 }
215
216 //////////
217 // global functions
218 //////////
219
220 class sy_is_less : public std::binary_function<ex, ex, bool> {
221         exvector::iterator v;
222
223 public:
224         sy_is_less(exvector::iterator v_) : v(v_) {}
225
226         bool operator() (const ex &lh, const ex &rh) const
227         {
228                 GINAC_ASSERT(is_ex_exactly_of_type(lh, symmetry));
229                 GINAC_ASSERT(is_ex_exactly_of_type(rh, symmetry));
230                 GINAC_ASSERT(ex_to_symmetry(lh).indices.size() == ex_to_symmetry(rh).indices.size());
231                 std::set<unsigned>::const_iterator ait = ex_to_symmetry(lh).indices.begin(), aitend = ex_to_symmetry(lh).indices.end(), bit = ex_to_symmetry(rh).indices.begin();
232                 while (ait != aitend) {
233                         int cmpval = v[*ait].compare(v[*bit]);
234                         if (cmpval < 0)
235                                 return true;
236                         else if (cmpval > 0)
237                                 return false;
238                         ++ait; ++bit;
239                 }
240                 return false;
241         }
242 };
243
244 class sy_swap : public std::binary_function<ex, ex, void> {
245         exvector::iterator v;
246
247 public:
248         bool &swapped;
249
250         sy_swap(exvector::iterator v_, bool &s) : v(v_), swapped(s) {}
251
252         void operator() (const ex &lh, const ex &rh)
253         {
254                 GINAC_ASSERT(is_ex_exactly_of_type(lh, symmetry));
255                 GINAC_ASSERT(is_ex_exactly_of_type(rh, symmetry));
256                 GINAC_ASSERT(ex_to_symmetry(lh).indices.size() == ex_to_symmetry(rh).indices.size());
257                 std::set<unsigned>::const_iterator ait = ex_to_symmetry(lh).indices.begin(), aitend = ex_to_symmetry(lh).indices.end(), bit = ex_to_symmetry(rh).indices.begin();
258                 while (ait != aitend) {
259                         v[*ait].swap(v[*bit]);
260                         ++ait; ++bit;
261                 }
262                 swapped = true;
263         }
264 };
265
266 int canonicalize(exvector::iterator v, const symmetry &symm)
267 {
268         // No children? Then do nothing
269         if (symm.children.empty())
270                 return INT_MAX;
271
272         // Canonicalize children first
273         bool something_changed = false;
274         int sign = 1;
275         exvector::const_iterator first = symm.children.begin(), last = symm.children.end();
276         while (first != last) {
277                 GINAC_ASSERT(is_ex_exactly_of_type(*first, symmetry));
278                 int child_sign = canonicalize(v, ex_to_symmetry(*first));
279                 if (child_sign == 0)
280                         return 0;
281                 if (child_sign != INT_MAX) {
282                         something_changed = true;
283                         sign *= child_sign;
284                 }
285                 first++;
286         }
287
288         // Now reorder the children
289         first = symm.children.begin();
290         switch (symm.type) {
291                 case symmetry::symmetric:
292                         shaker_sort(first, last, sy_is_less(v), sy_swap(v, something_changed));
293                         break;
294                 case symmetry::antisymmetric:
295                         sign *= permutation_sign(first, last, sy_is_less(v), sy_swap(v, something_changed));
296                         break;
297                 case symmetry::cyclic:
298                         cyclic_permutation(first, last, min_element(first, last, sy_is_less(v)), sy_swap(v, something_changed));
299                         break;
300                 default:
301                         break;
302         }
303         return something_changed ? sign : INT_MAX;
304 }
305
306
307 // Symmetrize/antisymmetrize over a vector of objects
308 static ex symm(const ex & e, exvector::const_iterator first, exvector::const_iterator last, bool asymmetric)
309 {
310         // Need at least 2 objects for this operation
311         int num = last - first;
312         if (num < 2)
313                 return e;
314
315         // Transform object vector to a list
316         exlist iv_lst;
317         iv_lst.insert(iv_lst.begin(), first, last);
318         lst orig_lst(iv_lst, true);
319
320         // Create index vectors for permutation
321         unsigned *iv = new unsigned[num], *iv2;
322         for (unsigned i=0; i<num; i++)
323                 iv[i] = i;
324         iv2 = (asymmetric ? new unsigned[num] : NULL);
325
326         // Loop over all permutations (the first permutation, which is the
327         // identity, is unrolled)
328         ex sum = e;
329         while (std::next_permutation(iv, iv + num)) {
330                 lst new_lst;
331                 for (unsigned i=0; i<num; i++)
332                         new_lst.append(orig_lst.op(iv[i]));
333                 ex term = e.subs(orig_lst, new_lst);
334                 if (asymmetric) {
335                         memcpy(iv2, iv, num * sizeof(unsigned));
336                         term *= permutation_sign(iv2, iv2 + num);
337                 }
338                 sum += term;
339         }
340
341         delete[] iv;
342         delete[] iv2;
343
344         return sum / factorial(numeric(num));
345 }
346
347 ex symmetrize(const ex & e, exvector::const_iterator first, exvector::const_iterator last)
348 {
349         return symm(e, first, last, false);
350 }
351
352 ex antisymmetrize(const ex & e, exvector::const_iterator first, exvector::const_iterator last)
353 {
354         return symm(e, first, last, true);
355 }
356
357 ex symmetrize_cyclic(const ex & e, exvector::const_iterator first, exvector::const_iterator last)
358 {
359         // Need at least 2 objects for this operation
360         int num = last - first;
361         if (num < 2)
362                 return e;
363
364         // Transform object vector to a list
365         exlist iv_lst;
366         iv_lst.insert(iv_lst.begin(), first, last);
367         lst orig_lst(iv_lst, true);
368         lst new_lst = orig_lst;
369
370         // Loop over all cyclic permutations (the first permutation, which is
371         // the identity, is unrolled)
372         ex sum = e;
373         for (unsigned i=0; i<num-1; i++) {
374                 ex perm = new_lst.op(0);
375                 new_lst.remove_first().append(perm);
376                 sum += e.subs(orig_lst, new_lst);
377         }
378         return sum / num;
379 }
380
381 /** Symmetrize expression over a list of objects (symbols, indices). */
382 ex ex::symmetrize(const lst & l) const
383 {
384         exvector v;
385         v.reserve(l.nops());
386         for (unsigned i=0; i<l.nops(); i++)
387                 v.push_back(l.op(i));
388         return symm(*this, v.begin(), v.end(), false);
389 }
390
391 /** Antisymmetrize expression over a list of objects (symbols, indices). */
392 ex ex::antisymmetrize(const lst & l) const
393 {
394         exvector v;
395         v.reserve(l.nops());
396         for (unsigned i=0; i<l.nops(); i++)
397                 v.push_back(l.op(i));
398         return symm(*this, v.begin(), v.end(), true);
399 }
400
401 /** Symmetrize expression by cyclic permutation over a list of objects
402  *  (symbols, indices). */
403 ex ex::symmetrize_cyclic(const lst & l) const
404 {
405         exvector v;
406         v.reserve(l.nops());
407         for (unsigned i=0; i<l.nops(); i++)
408                 v.push_back(l.op(i));
409         return GiNaC::symmetrize_cyclic(*this, v.begin(), v.end());
410 }
411
412 } // namespace GiNaC