some cleanups
[ginac.git] / ginac / symmetry.cpp
1 /** @file symmetry.cpp
2  *
3  *  Implementation of GiNaC's symmetry definitions. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2001 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
21  */
22
#include <algorithm>
#include <climits>
#include <cstring>
#include <functional>
#include <iterator>
#include <set>
#include <stdexcept>
#include <vector>

#include "symmetry.h"
#include "lst.h"
#include "numeric.h" // for factorial()
#include "print.h"
#include "archive.h"
#include "utils.h"
#include "debugmsg.h"
34
35 namespace GiNaC {
36
// NOTE(review): this macro is defined outside this file; judging by its name
// it supplies the class-registration boilerplate for `symmetry` with `basic`
// as its parent class -- confirm against the macro definition.
GINAC_IMPLEMENT_REGISTERED_CLASS(symmetry, basic)
38
39 //////////
40 // default constructor, destructor, copy constructor assignment operator and helpers
41 //////////
42
43 symmetry::symmetry() : type(none)
44 {
45         debugmsg("symmetry default constructor", LOGLEVEL_CONSTRUCT);
46         tinfo_key = TINFO_symmetry;
47 }
48
49 void symmetry::copy(const symmetry & other)
50 {
51         inherited::copy(other);
52         type = other.type;
53         indices = other.indices;
54         children = other.children;
55 }
56
// NOTE(review): macro defined outside this file; presumably it generates the
// default destroy() helper for `symmetry` -- confirm against its definition.
DEFAULT_DESTROY(symmetry)
58
59 //////////
60 // other constructors
61 //////////
62
63 symmetry::symmetry(unsigned i) : type(none)
64 {
65         debugmsg("symmetry constructor from unsigned", LOGLEVEL_CONSTRUCT);
66         indices.insert(i);
67         tinfo_key = TINFO_symmetry;
68 }
69
70 symmetry::symmetry(symmetry_type t, const symmetry &c1, const symmetry &c2) : type(t)
71 {
72         debugmsg("symmetry constructor from symmetry_type,symmetry &,symmetry &", LOGLEVEL_CONSTRUCT);
73         add(c1); add(c2);
74         tinfo_key = TINFO_symmetry;
75 }
76
77 //////////
78 // archiving
79 //////////
80
81 /** Construct object from archive_node. */
82 symmetry::symmetry(const archive_node &n, const lst &sym_lst) : inherited(n, sym_lst)
83 {
84         debugmsg("symmetry ctor from archive_node", LOGLEVEL_CONSTRUCT);
85
86         unsigned t;
87         if (!(n.find_unsigned("type", t)))
88                 throw (std::runtime_error("unknown symmetry type in archive"));
89         type = (symmetry_type)t;
90
91         unsigned i = 0;
92         while (true) {
93                 ex e;
94                 if (n.find_ex("child", e, sym_lst, i))
95                         add(ex_to<symmetry>(e));
96                 else
97                         break;
98                 i++;
99         }
100
101         if (i == 0) {
102                 while (true) {
103                         unsigned u;
104                         if (n.find_unsigned("index", u, i))
105                                 indices.insert(u);
106                         else
107                                 break;
108                         i++;
109                 }
110         }
111 }
112
113 /** Archive the object. */
114 void symmetry::archive(archive_node &n) const
115 {
116         inherited::archive(n);
117
118         n.add_unsigned("type", type);
119
120         if (children.empty()) {
121                 std::set<unsigned>::const_iterator i = indices.begin(), iend = indices.end();
122                 while (i != iend) {
123                         n.add_unsigned("index", *i);
124                         i++;
125                 }
126         } else {
127                 exvector::const_iterator i = children.begin(), iend = children.end();
128                 while (i != iend) {
129                         n.add_ex("child", *i);
130                         i++;
131                 }
132         }
133 }
134
// NOTE(review): macro defined outside this file; presumably it generates the
// default unarchiving function for `symmetry` -- confirm against its definition.
DEFAULT_UNARCHIVE(symmetry)
136
137 //////////
// functions overriding virtual functions from base classes
139 //////////
140
141 int symmetry::compare_same_type(const basic & other) const
142 {
143         GINAC_ASSERT(is_of_type(other, symmetry));
144         const symmetry &o = static_cast<const symmetry &>(other);
145
146         // All symmetry trees are equal. They are not supposed to appear in
147         // ordinary expressions anyway...
148         return 0;
149 }
150
151 void symmetry::print(const print_context & c, unsigned level = 0) const
152 {
153         debugmsg("symmetry print", LOGLEVEL_PRINT);
154
155         if (children.empty()) {
156                 if (indices.size() > 0)
157                         c.s << *(indices.begin());
158         } else {
159                 switch (type) {
160                         case none: c.s << '!'; break;
161                         case symmetric: c.s << '+'; break;
162                         case antisymmetric: c.s << '-'; break;
163                         case cyclic: c.s << '@'; break;
164                         default: c.s << '?'; break;
165                 }
166                 c.s << '(';
167                 unsigned num = children.size();
168                 for (unsigned i=0; i<num; i++) {
169                         children[i].print(c);
170                         if (i != num - 1)
171                                 c.s << ",";
172                 }
173                 c.s << ')';
174         }
175 }
176
177 //////////
178 // non-virtual functions in this class
179 //////////
180
181 symmetry &symmetry::add(const symmetry &c)
182 {
183         // All children must have the same number of indices
184         if (type != none && !children.empty()) {
185                 GINAC_ASSERT(is_ex_exactly_of_type(children[0], symmetry));
186                 if (ex_to<symmetry>(children[0]).indices.size() != c.indices.size())
187                         throw (std::logic_error("symmetry:add(): children must have same number of indices"));
188         }
189
190         // Compute union of indices and check whether the two sets are disjoint
191         std::set<unsigned> un;
192         set_union(indices.begin(), indices.end(), c.indices.begin(), c.indices.end(), inserter(un, un.begin()));
193         if (un.size() != indices.size() + c.indices.size())
194                 throw (std::logic_error("symmetry::add(): the same index appears in more than one child"));
195
196         // Set new index set
197         indices.swap(un);
198
199         // Add child node
200         children.push_back(c);
201         return *this;
202 }
203
204 void symmetry::validate(unsigned n)
205 {
206         if (indices.upper_bound(n - 1) != indices.end())
207                 throw (std::range_error("symmetry::verify(): index values are out of range"));
208         if (type != none && indices.empty()) {
209                 for (unsigned i=0; i<n; i++)
210                         add(i);
211         }
212 }
213
214 //////////
215 // global functions
216 //////////
217
218 class sy_is_less : public std::binary_function<ex, ex, bool> {
219         exvector::iterator v;
220
221 public:
222         sy_is_less(exvector::iterator v_) : v(v_) {}
223
224         bool operator() (const ex &lh, const ex &rh) const
225         {
226                 GINAC_ASSERT(is_ex_exactly_of_type(lh, symmetry));
227                 GINAC_ASSERT(is_ex_exactly_of_type(rh, symmetry));
228                 GINAC_ASSERT(ex_to<symmetry>(lh).indices.size() == ex_to<symmetry>(rh).indices.size());
229                 std::set<unsigned>::const_iterator ait = ex_to<symmetry>(lh).indices.begin(), aitend = ex_to<symmetry>(lh).indices.end(), bit = ex_to<symmetry>(rh).indices.begin();
230                 while (ait != aitend) {
231                         int cmpval = v[*ait].compare(v[*bit]);
232                         if (cmpval < 0)
233                                 return true;
234                         else if (cmpval > 0)
235                                 return false;
236                         ++ait; ++bit;
237                 }
238                 return false;
239         }
240 };
241
242 class sy_swap : public std::binary_function<ex, ex, void> {
243         exvector::iterator v;
244
245 public:
246         bool &swapped;
247
248         sy_swap(exvector::iterator v_, bool &s) : v(v_), swapped(s) {}
249
250         void operator() (const ex &lh, const ex &rh)
251         {
252                 GINAC_ASSERT(is_ex_exactly_of_type(lh, symmetry));
253                 GINAC_ASSERT(is_ex_exactly_of_type(rh, symmetry));
254                 GINAC_ASSERT(ex_to<symmetry>(lh).indices.size() == ex_to<symmetry>(rh).indices.size());
255                 std::set<unsigned>::const_iterator ait = ex_to<symmetry>(lh).indices.begin(), aitend = ex_to<symmetry>(lh).indices.end(), bit = ex_to<symmetry>(rh).indices.begin();
256                 while (ait != aitend) {
257                         v[*ait].swap(v[*bit]);
258                         ++ait; ++bit;
259                 }
260                 swapped = true;
261         }
262 };
263
264 int canonicalize(exvector::iterator v, const symmetry &symm)
265 {
266         // No children? Then do nothing
267         if (symm.children.empty())
268                 return INT_MAX;
269
270         // Canonicalize children first
271         bool something_changed = false;
272         int sign = 1;
273         exvector::const_iterator first = symm.children.begin(), last = symm.children.end();
274         while (first != last) {
275                 GINAC_ASSERT(is_ex_exactly_of_type(*first, symmetry));
276                 int child_sign = canonicalize(v, ex_to<symmetry>(*first));
277                 if (child_sign == 0)
278                         return 0;
279                 if (child_sign != INT_MAX) {
280                         something_changed = true;
281                         sign *= child_sign;
282                 }
283                 first++;
284         }
285
286         // Now reorder the children
287         first = symm.children.begin();
288         switch (symm.type) {
289                 case symmetry::symmetric:
290                         // Sort the children in ascending order
291                         shaker_sort(first, last, sy_is_less(v), sy_swap(v, something_changed));
292                         break;
293                 case symmetry::antisymmetric:
294                         // Sort the children in ascending order, keeping track of the signum
295                         sign *= permutation_sign(first, last, sy_is_less(v), sy_swap(v, something_changed));
296                         break;
297                 case symmetry::cyclic:
298                         // Permute the smallest child to the front
299                         cyclic_permutation(first, last, min_element(first, last, sy_is_less(v)), sy_swap(v, something_changed));
300                         break;
301                 default:
302                         break;
303         }
304         return something_changed ? sign : INT_MAX;
305 }
306
307
308 // Symmetrize/antisymmetrize over a vector of objects
309 static ex symm(const ex & e, exvector::const_iterator first, exvector::const_iterator last, bool asymmetric)
310 {
311         // Need at least 2 objects for this operation
312         int num = last - first;
313         if (num < 2)
314                 return e;
315
316         // Transform object vector to a list
317         exlist iv_lst;
318         iv_lst.insert(iv_lst.begin(), first, last);
319         lst orig_lst(iv_lst, true);
320
321         // Create index vectors for permutation
322         unsigned *iv = new unsigned[num], *iv2;
323         for (unsigned i=0; i<num; i++)
324                 iv[i] = i;
325         iv2 = (asymmetric ? new unsigned[num] : NULL);
326
327         // Loop over all permutations (the first permutation, which is the
328         // identity, is unrolled)
329         ex sum = e;
330         while (std::next_permutation(iv, iv + num)) {
331                 lst new_lst;
332                 for (unsigned i=0; i<num; i++)
333                         new_lst.append(orig_lst.op(iv[i]));
334                 ex term = e.subs(orig_lst, new_lst);
335                 if (asymmetric) {
336                         memcpy(iv2, iv, num * sizeof(unsigned));
337                         term *= permutation_sign(iv2, iv2 + num);
338                 }
339                 sum += term;
340         }
341
342         delete[] iv;
343         delete[] iv2;
344
345         return sum / factorial(numeric(num));
346 }
347
348 ex symmetrize(const ex & e, exvector::const_iterator first, exvector::const_iterator last)
349 {
350         return symm(e, first, last, false);
351 }
352
353 ex antisymmetrize(const ex & e, exvector::const_iterator first, exvector::const_iterator last)
354 {
355         return symm(e, first, last, true);
356 }
357
358 ex symmetrize_cyclic(const ex & e, exvector::const_iterator first, exvector::const_iterator last)
359 {
360         // Need at least 2 objects for this operation
361         int num = last - first;
362         if (num < 2)
363                 return e;
364
365         // Transform object vector to a list
366         exlist iv_lst;
367         iv_lst.insert(iv_lst.begin(), first, last);
368         lst orig_lst(iv_lst, true);
369         lst new_lst = orig_lst;
370
371         // Loop over all cyclic permutations (the first permutation, which is
372         // the identity, is unrolled)
373         ex sum = e;
374         for (unsigned i=0; i<num-1; i++) {
375                 ex perm = new_lst.op(0);
376                 new_lst.remove_first().append(perm);
377                 sum += e.subs(orig_lst, new_lst);
378         }
379         return sum / num;
380 }
381
382 /** Symmetrize expression over a list of objects (symbols, indices). */
383 ex ex::symmetrize(const lst & l) const
384 {
385         exvector v;
386         v.reserve(l.nops());
387         for (unsigned i=0; i<l.nops(); i++)
388                 v.push_back(l.op(i));
389         return symm(*this, v.begin(), v.end(), false);
390 }
391
392 /** Antisymmetrize expression over a list of objects (symbols, indices). */
393 ex ex::antisymmetrize(const lst & l) const
394 {
395         exvector v;
396         v.reserve(l.nops());
397         for (unsigned i=0; i<l.nops(); i++)
398                 v.push_back(l.op(i));
399         return symm(*this, v.begin(), v.end(), true);
400 }
401
402 /** Symmetrize expression by cyclic permutation over a list of objects
403  *  (symbols, indices). */
404 ex ex::symmetrize_cyclic(const lst & l) const
405 {
406         exvector v;
407         v.reserve(l.nops());
408         for (unsigned i=0; i<l.nops(); i++)
409                 v.push_back(l.op(i));
410         return GiNaC::symmetrize_cyclic(*this, v.begin(), v.end());
411 }
412
413 } // namespace GiNaC