fixed some -Wall stuff
[ginac.git] / ginac / symmetry.cpp
1 /** @file symmetry.cpp
2  *
3  *  Implementation of GiNaC's symmetry definitions. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2001 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
21  */
22
#include <algorithm>
#include <climits>
#include <functional>
#include <iterator>
#include <stdexcept>
#include <vector>
26
27 #include "symmetry.h"
28 #include "lst.h"
29 #include "numeric.h" // for factorial()
30 #include "print.h"
31 #include "archive.h"
32 #include "utils.h"
33 #include "debugmsg.h"
34
35 namespace GiNaC {
36
// Macro (defined elsewhere) that presumably implements the class-registry
// boilerplate for symmetry, with basic as its parent class -- confirm
// against the macro definition.
GINAC_IMPLEMENT_REGISTERED_CLASS(symmetry, basic)
38
39 //////////
// default constructor, destructor, copy constructor, assignment operator and helpers
41 //////////
42
43 symmetry::symmetry() : type(none)
44 {
45         debugmsg("symmetry default constructor", LOGLEVEL_CONSTRUCT);
46         tinfo_key = TINFO_symmetry;
47 }
48
49 void symmetry::copy(const symmetry & other)
50 {
51         inherited::copy(other);
52         type = other.type;
53         indices = other.indices;
54         children = other.children;
55 }
56
// Macro (defined elsewhere) that presumably supplies the default destroy()
// helper for symmetry.
DEFAULT_DESTROY(symmetry)
58
59 //////////
60 // other constructors
61 //////////
62
63 symmetry::symmetry(unsigned i) : type(none)
64 {
65         debugmsg("symmetry constructor from unsigned", LOGLEVEL_CONSTRUCT);
66         indices.insert(i);
67         tinfo_key = TINFO_symmetry;
68 }
69
70 symmetry::symmetry(symmetry_type t, const symmetry &c1, const symmetry &c2) : type(t)
71 {
72         debugmsg("symmetry constructor from symmetry_type,symmetry &,symmetry &", LOGLEVEL_CONSTRUCT);
73         add(c1); add(c2);
74         tinfo_key = TINFO_symmetry;
75 }
76
77 //////////
78 // archiving
79 //////////
80
81 /** Construct object from archive_node. */
82 symmetry::symmetry(const archive_node &n, const lst &sym_lst) : inherited(n, sym_lst)
83 {
84         debugmsg("symmetry ctor from archive_node", LOGLEVEL_CONSTRUCT);
85
86         unsigned t;
87         if (!(n.find_unsigned("type", t)))
88                 throw (std::runtime_error("unknown symmetry type in archive"));
89         type = (symmetry_type)t;
90
91         unsigned i = 0;
92         while (true) {
93                 ex e;
94                 if (n.find_ex("child", e, sym_lst, i))
95                         add(ex_to<symmetry>(e));
96                 else
97                         break;
98                 i++;
99         }
100
101         if (i == 0) {
102                 while (true) {
103                         unsigned u;
104                         if (n.find_unsigned("index", u, i))
105                                 indices.insert(u);
106                         else
107                                 break;
108                         i++;
109                 }
110         }
111 }
112
113 /** Archive the object. */
114 void symmetry::archive(archive_node &n) const
115 {
116         inherited::archive(n);
117
118         n.add_unsigned("type", type);
119
120         if (children.empty()) {
121                 std::set<unsigned>::const_iterator i = indices.begin(), iend = indices.end();
122                 while (i != iend) {
123                         n.add_unsigned("index", *i);
124                         i++;
125                 }
126         } else {
127                 exvector::const_iterator i = children.begin(), iend = children.end();
128                 while (i != iend) {
129                         n.add_ex("child", *i);
130                         i++;
131                 }
132         }
133 }
134
// Macro (defined elsewhere) that presumably provides the standard unarchive()
// entry point, building a symmetry object from an archive_node.
DEFAULT_UNARCHIVE(symmetry)
136
137 //////////
// functions overriding virtual functions from base classes
139 //////////
140
141 int symmetry::compare_same_type(const basic & other) const
142 {
143         GINAC_ASSERT(is_of_type(other, symmetry));
144
145         // All symmetry trees are equal. They are not supposed to appear in
146         // ordinary expressions anyway...
147         return 0;
148 }
149
150 void symmetry::print(const print_context & c, unsigned level = 0) const
151 {
152         debugmsg("symmetry print", LOGLEVEL_PRINT);
153
154         if (children.empty()) {
155                 if (indices.size() > 0)
156                         c.s << *(indices.begin());
157         } else {
158                 switch (type) {
159                         case none: c.s << '!'; break;
160                         case symmetric: c.s << '+'; break;
161                         case antisymmetric: c.s << '-'; break;
162                         case cyclic: c.s << '@'; break;
163                         default: c.s << '?'; break;
164                 }
165                 c.s << '(';
166                 unsigned num = children.size();
167                 for (unsigned i=0; i<num; i++) {
168                         children[i].print(c);
169                         if (i != num - 1)
170                                 c.s << ",";
171                 }
172                 c.s << ')';
173         }
174 }
175
176 //////////
177 // non-virtual functions in this class
178 //////////
179
180 symmetry &symmetry::add(const symmetry &c)
181 {
182         // All children must have the same number of indices
183         if (type != none && !children.empty()) {
184                 GINAC_ASSERT(is_ex_exactly_of_type(children[0], symmetry));
185                 if (ex_to<symmetry>(children[0]).indices.size() != c.indices.size())
186                         throw (std::logic_error("symmetry:add(): children must have same number of indices"));
187         }
188
189         // Compute union of indices and check whether the two sets are disjoint
190         std::set<unsigned> un;
191         set_union(indices.begin(), indices.end(), c.indices.begin(), c.indices.end(), inserter(un, un.begin()));
192         if (un.size() != indices.size() + c.indices.size())
193                 throw (std::logic_error("symmetry::add(): the same index appears in more than one child"));
194
195         // Set new index set
196         indices.swap(un);
197
198         // Add child node
199         children.push_back(c);
200         return *this;
201 }
202
203 void symmetry::validate(unsigned n)
204 {
205         if (indices.upper_bound(n - 1) != indices.end())
206                 throw (std::range_error("symmetry::verify(): index values are out of range"));
207         if (type != none && indices.empty()) {
208                 for (unsigned i=0; i<n; i++)
209                         add(i);
210         }
211 }
212
213 //////////
214 // global functions
215 //////////
216
217 class sy_is_less : public std::binary_function<ex, ex, bool> {
218         exvector::iterator v;
219
220 public:
221         sy_is_less(exvector::iterator v_) : v(v_) {}
222
223         bool operator() (const ex &lh, const ex &rh) const
224         {
225                 GINAC_ASSERT(is_ex_exactly_of_type(lh, symmetry));
226                 GINAC_ASSERT(is_ex_exactly_of_type(rh, symmetry));
227                 GINAC_ASSERT(ex_to<symmetry>(lh).indices.size() == ex_to<symmetry>(rh).indices.size());
228                 std::set<unsigned>::const_iterator ait = ex_to<symmetry>(lh).indices.begin(), aitend = ex_to<symmetry>(lh).indices.end(), bit = ex_to<symmetry>(rh).indices.begin();
229                 while (ait != aitend) {
230                         int cmpval = v[*ait].compare(v[*bit]);
231                         if (cmpval < 0)
232                                 return true;
233                         else if (cmpval > 0)
234                                 return false;
235                         ++ait; ++bit;
236                 }
237                 return false;
238         }
239 };
240
241 class sy_swap : public std::binary_function<ex, ex, void> {
242         exvector::iterator v;
243
244 public:
245         bool &swapped;
246
247         sy_swap(exvector::iterator v_, bool &s) : v(v_), swapped(s) {}
248
249         void operator() (const ex &lh, const ex &rh)
250         {
251                 GINAC_ASSERT(is_ex_exactly_of_type(lh, symmetry));
252                 GINAC_ASSERT(is_ex_exactly_of_type(rh, symmetry));
253                 GINAC_ASSERT(ex_to<symmetry>(lh).indices.size() == ex_to<symmetry>(rh).indices.size());
254                 std::set<unsigned>::const_iterator ait = ex_to<symmetry>(lh).indices.begin(), aitend = ex_to<symmetry>(lh).indices.end(), bit = ex_to<symmetry>(rh).indices.begin();
255                 while (ait != aitend) {
256                         v[*ait].swap(v[*bit]);
257                         ++ait; ++bit;
258                 }
259                 swapped = true;
260         }
261 };
262
263 int canonicalize(exvector::iterator v, const symmetry &symm)
264 {
265         // No children? Then do nothing
266         if (symm.children.empty())
267                 return INT_MAX;
268
269         // Canonicalize children first
270         bool something_changed = false;
271         int sign = 1;
272         exvector::const_iterator first = symm.children.begin(), last = symm.children.end();
273         while (first != last) {
274                 GINAC_ASSERT(is_ex_exactly_of_type(*first, symmetry));
275                 int child_sign = canonicalize(v, ex_to<symmetry>(*first));
276                 if (child_sign == 0)
277                         return 0;
278                 if (child_sign != INT_MAX) {
279                         something_changed = true;
280                         sign *= child_sign;
281                 }
282                 first++;
283         }
284
285         // Now reorder the children
286         first = symm.children.begin();
287         switch (symm.type) {
288                 case symmetry::symmetric:
289                         // Sort the children in ascending order
290                         shaker_sort(first, last, sy_is_less(v), sy_swap(v, something_changed));
291                         break;
292                 case symmetry::antisymmetric:
293                         // Sort the children in ascending order, keeping track of the signum
294                         sign *= permutation_sign(first, last, sy_is_less(v), sy_swap(v, something_changed));
295                         break;
296                 case symmetry::cyclic:
297                         // Permute the smallest child to the front
298                         cyclic_permutation(first, last, min_element(first, last, sy_is_less(v)), sy_swap(v, something_changed));
299                         break;
300                 default:
301                         break;
302         }
303         return something_changed ? sign : INT_MAX;
304 }
305
306
307 // Symmetrize/antisymmetrize over a vector of objects
308 static ex symm(const ex & e, exvector::const_iterator first, exvector::const_iterator last, bool asymmetric)
309 {
310         // Need at least 2 objects for this operation
311         unsigned num = last - first;
312         if (num < 2)
313                 return e;
314
315         // Transform object vector to a list
316         exlist iv_lst;
317         iv_lst.insert(iv_lst.begin(), first, last);
318         lst orig_lst(iv_lst, true);
319
320         // Create index vectors for permutation
321         unsigned *iv = new unsigned[num], *iv2;
322         for (unsigned i=0; i<num; i++)
323                 iv[i] = i;
324         iv2 = (asymmetric ? new unsigned[num] : NULL);
325
326         // Loop over all permutations (the first permutation, which is the
327         // identity, is unrolled)
328         ex sum = e;
329         while (std::next_permutation(iv, iv + num)) {
330                 lst new_lst;
331                 for (unsigned i=0; i<num; i++)
332                         new_lst.append(orig_lst.op(iv[i]));
333                 ex term = e.subs(orig_lst, new_lst);
334                 if (asymmetric) {
335                         memcpy(iv2, iv, num * sizeof(unsigned));
336                         term *= permutation_sign(iv2, iv2 + num);
337                 }
338                 sum += term;
339         }
340
341         delete[] iv;
342         delete[] iv2;
343
344         return sum / factorial(numeric(num));
345 }
346
347 ex symmetrize(const ex & e, exvector::const_iterator first, exvector::const_iterator last)
348 {
349         return symm(e, first, last, false);
350 }
351
352 ex antisymmetrize(const ex & e, exvector::const_iterator first, exvector::const_iterator last)
353 {
354         return symm(e, first, last, true);
355 }
356
357 ex symmetrize_cyclic(const ex & e, exvector::const_iterator first, exvector::const_iterator last)
358 {
359         // Need at least 2 objects for this operation
360         unsigned num = last - first;
361         if (num < 2)
362                 return e;
363
364         // Transform object vector to a list
365         exlist iv_lst;
366         iv_lst.insert(iv_lst.begin(), first, last);
367         lst orig_lst(iv_lst, true);
368         lst new_lst = orig_lst;
369
370         // Loop over all cyclic permutations (the first permutation, which is
371         // the identity, is unrolled)
372         ex sum = e;
373         for (unsigned i=0; i<num-1; i++) {
374                 ex perm = new_lst.op(0);
375                 new_lst.remove_first().append(perm);
376                 sum += e.subs(orig_lst, new_lst);
377         }
378         return sum / num;
379 }
380
381 /** Symmetrize expression over a list of objects (symbols, indices). */
382 ex ex::symmetrize(const lst & l) const
383 {
384         exvector v;
385         v.reserve(l.nops());
386         for (unsigned i=0; i<l.nops(); i++)
387                 v.push_back(l.op(i));
388         return symm(*this, v.begin(), v.end(), false);
389 }
390
391 /** Antisymmetrize expression over a list of objects (symbols, indices). */
392 ex ex::antisymmetrize(const lst & l) const
393 {
394         exvector v;
395         v.reserve(l.nops());
396         for (unsigned i=0; i<l.nops(); i++)
397                 v.push_back(l.op(i));
398         return symm(*this, v.begin(), v.end(), true);
399 }
400
401 /** Symmetrize expression by cyclic permutation over a list of objects
402  *  (symbols, indices). */
403 ex ex::symmetrize_cyclic(const lst & l) const
404 {
405         exvector v;
406         v.reserve(l.nops());
407         for (unsigned i=0; i<l.nops(); i++)
408                 v.push_back(l.op(i));
409         return GiNaC::symmetrize_cyclic(*this, v.begin(), v.end());
410 }
411
412 } // namespace GiNaC