6a50fe1dcebcc1563da2a74f14da58157b933655
[ginac.git] / ginac / symmetry.cpp
1 /** @file symmetry.cpp
2  *
3  *  Implementation of GiNaC's symmetry definitions. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2003 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
21  */
22
#include <algorithm>
#include <climits>
#include <functional>
#include <iostream>
#include <iterator>
#include <set>
#include <stdexcept>
#include <string>
#include <vector>

#include "symmetry.h"
#include "lst.h"
#include "numeric.h" // for factorial()
#include "operators.h"
#include "print.h"
#include "archive.h"
#include "utils.h"
34
35 namespace GiNaC {
36
// Register symmetry as a GiNaC class derived from basic and bind its print
// methods: do_print() handles print_context output, do_print_tree() handles
// print_tree output.
GINAC_IMPLEMENT_REGISTERED_CLASS_OPT(symmetry, basic,
  print_func<print_context>(&symmetry::do_print).
  print_func<print_tree>(&symmetry::do_print_tree))
40
41 /*
42    Some notes about the structure of a symmetry tree:
43     - The leaf nodes of the tree are of type "none", have one index, and no
44       children (of course). They are constructed by the symmetry(unsigned)
45       constructor.
46     - Leaf nodes are the only nodes that only have one index.
47     - Container nodes contain two or more children. The "indices" set member
48       is the set union of the index sets of all children, and the "children"
49       vector stores the children themselves.
50     - The index set of each child of a "symm", "anti" or "cycl" node must
51       have the same size. It follows that the children of such a node are
52       either all leaf nodes, or all container nodes with two or more indices.
53 */
54
55 //////////
56 // default constructor
57 //////////
58
59 symmetry::symmetry() : type(none)
60 {
61         tinfo_key = TINFO_symmetry;
62 }
63
64 //////////
65 // other constructors
66 //////////
67
68 symmetry::symmetry(unsigned i) : type(none)
69 {
70         indices.insert(i);
71         tinfo_key = TINFO_symmetry;
72 }
73
74 symmetry::symmetry(symmetry_type t, const symmetry &c1, const symmetry &c2) : type(t)
75 {
76         add(c1); add(c2);
77         tinfo_key = TINFO_symmetry;
78 }
79
80 //////////
81 // archiving
82 //////////
83
84 /** Construct object from archive_node. */
85 symmetry::symmetry(const archive_node &n, lst &sym_lst) : inherited(n, sym_lst)
86 {
87         unsigned t;
88         if (!(n.find_unsigned("type", t)))
89                 throw (std::runtime_error("unknown symmetry type in archive"));
90         type = (symmetry_type)t;
91
92         unsigned i = 0;
93         while (true) {
94                 ex e;
95                 if (n.find_ex("child", e, sym_lst, i))
96                         add(ex_to<symmetry>(e));
97                 else
98                         break;
99                 i++;
100         }
101
102         if (i == 0) {
103                 while (true) {
104                         unsigned u;
105                         if (n.find_unsigned("index", u, i))
106                                 indices.insert(u);
107                         else
108                                 break;
109                         i++;
110                 }
111         }
112 }
113
114 /** Archive the object. */
115 void symmetry::archive(archive_node &n) const
116 {
117         inherited::archive(n);
118
119         n.add_unsigned("type", type);
120
121         if (children.empty()) {
122                 std::set<unsigned>::const_iterator i = indices.begin(), iend = indices.end();
123                 while (i != iend) {
124                         n.add_unsigned("index", *i);
125                         i++;
126                 }
127         } else {
128                 exvector::const_iterator i = children.begin(), iend = children.end();
129                 while (i != iend) {
130                         n.add_ex("child", *i);
131                         i++;
132                 }
133         }
134 }
135
// Generate the default unarchiving hook for symmetry; presumably it builds
// the object via the archive_node constructor (macro body not visible here).
DEFAULT_UNARCHIVE(symmetry)
137
138 //////////
139 // functions overriding virtual functions from base classes
140 //////////
141
142 int symmetry::compare_same_type(const basic & other) const
143 {
144         GINAC_ASSERT(is_a<symmetry>(other));
145
146         // All symmetry trees are equal. They are not supposed to appear in
147         // ordinary expressions anyway...
148         return 0;
149 }
150
151 void symmetry::do_print(const print_context & c, unsigned level) const
152 {
153         if (children.empty()) {
154                 if (indices.size() > 0)
155                         c.s << *(indices.begin());
156                 else
157                         c.s << "none";
158         } else {
159                 switch (type) {
160                         case none: c.s << '!'; break;
161                         case symmetric: c.s << '+'; break;
162                         case antisymmetric: c.s << '-'; break;
163                         case cyclic: c.s << '@'; break;
164                         default: c.s << '?'; break;
165                 }
166                 c.s << '(';
167                 size_t num = children.size();
168                 for (size_t i=0; i<num; i++) {
169                         children[i].print(c);
170                         if (i != num - 1)
171                                 c.s << ",";
172                 }
173                 c.s << ')';
174         }
175 }
176
177 void symmetry::do_print_tree(const print_tree & c, unsigned level) const
178 {
179         c.s << std::string(level, ' ') << class_name()
180             << std::hex << ", hash=0x" << hashvalue << ", flags=0x" << flags << std::dec
181             << ", type=";
182
183         switch (type) {
184                 case none: c.s << "none"; break;
185                 case symmetric: c.s << "symm"; break;
186                 case antisymmetric: c.s << "anti"; break;
187                 case cyclic: c.s << "cycl"; break;
188                 default: c.s << "<unknown>"; break;
189         }
190
191         c.s << ", indices=(";
192         if (!indices.empty()) {
193                 std::set<unsigned>::const_iterator i = indices.begin(), end = indices.end();
194                 --end;
195                 while (i != end)
196                         c.s << *i++ << ",";
197                 c.s << *i;
198         }
199         c.s << ")\n";
200
201         exvector::const_iterator i = children.begin(), end = children.end();
202         while (i != end) {
203                 i->print(c, level + c.delta_indent);
204                 ++i;
205         }
206 }
207
208 //////////
209 // non-virtual functions in this class
210 //////////
211
212 symmetry &symmetry::add(const symmetry &c)
213 {
214         // All children must have the same number of indices
215         if (type != none && !children.empty()) {
216                 GINAC_ASSERT(is_exactly_a<symmetry>(children[0]));
217                 if (ex_to<symmetry>(children[0]).indices.size() != c.indices.size())
218                         throw (std::logic_error("symmetry:add(): children must have same number of indices"));
219         }
220
221         // Compute union of indices and check whether the two sets are disjoint
222         std::set<unsigned> un;
223         set_union(indices.begin(), indices.end(), c.indices.begin(), c.indices.end(), inserter(un, un.begin()));
224         if (un.size() != indices.size() + c.indices.size())
225                 throw (std::logic_error("symmetry::add(): the same index appears in more than one child"));
226
227         // Set new index set
228         indices.swap(un);
229
230         // Add child node
231         children.push_back(c);
232         return *this;
233 }
234
235 void symmetry::validate(unsigned n)
236 {
237         if (indices.upper_bound(n - 1) != indices.end())
238                 throw (std::range_error("symmetry::verify(): index values are out of range"));
239         if (type != none && indices.empty()) {
240                 for (unsigned i=0; i<n; i++)
241                         add(i);
242         }
243 }
244
245 //////////
246 // global functions
247 //////////
248
249 class sy_is_less : public std::binary_function<ex, ex, bool> {
250         exvector::iterator v;
251
252 public:
253         sy_is_less(exvector::iterator v_) : v(v_) {}
254
255         bool operator() (const ex &lh, const ex &rh) const
256         {
257                 GINAC_ASSERT(is_exactly_a<symmetry>(lh));
258                 GINAC_ASSERT(is_exactly_a<symmetry>(rh));
259                 GINAC_ASSERT(ex_to<symmetry>(lh).indices.size() == ex_to<symmetry>(rh).indices.size());
260                 std::set<unsigned>::const_iterator ait = ex_to<symmetry>(lh).indices.begin(), aitend = ex_to<symmetry>(lh).indices.end(), bit = ex_to<symmetry>(rh).indices.begin();
261                 while (ait != aitend) {
262                         int cmpval = v[*ait].compare(v[*bit]);
263                         if (cmpval < 0)
264                                 return true;
265                         else if (cmpval > 0)
266                                 return false;
267                         ++ait; ++bit;
268                 }
269                 return false;
270         }
271 };
272
273 class sy_swap : public std::binary_function<ex, ex, void> {
274         exvector::iterator v;
275
276 public:
277         bool &swapped;
278
279         sy_swap(exvector::iterator v_, bool &s) : v(v_), swapped(s) {}
280
281         void operator() (const ex &lh, const ex &rh)
282         {
283                 GINAC_ASSERT(is_exactly_a<symmetry>(lh));
284                 GINAC_ASSERT(is_exactly_a<symmetry>(rh));
285                 GINAC_ASSERT(ex_to<symmetry>(lh).indices.size() == ex_to<symmetry>(rh).indices.size());
286                 std::set<unsigned>::const_iterator ait = ex_to<symmetry>(lh).indices.begin(), aitend = ex_to<symmetry>(lh).indices.end(), bit = ex_to<symmetry>(rh).indices.begin();
287                 while (ait != aitend) {
288                         v[*ait].swap(v[*bit]);
289                         ++ait; ++bit;
290                 }
291                 swapped = true;
292         }
293 };
294
/** Reorder the objects v[idx] in place into the canonical order permitted
 *  by the symmetry tree symm.
 *  Children are canonicalized first (recursively). Then this node's
 *  children are reordered according to its type:
 *  - "symm": sorted ascending;
 *  - "anti": sorted ascending while tracking the permutation sign;
 *  - "cycl": rotated so that the smallest child comes first.
 *  @param v     iterator to the object sequence; the indices stored in the
 *               tree are offsets from v
 *  @param symm  symmetry tree describing the allowed permutations
 *  @return 0 if a child or permutation_sign() yields sign 0 (presumably two
 *          equal elements under antisymmetry); INT_MAX if nothing was
 *          reordered; otherwise the product of the signs of all applied
 *          permutations */
int canonicalize(exvector::iterator v, const symmetry &symm)
{
	// Less than two elements? Then do nothing
	if (symm.indices.size() < 2)
		return INT_MAX;

	// Canonicalize children first
	bool something_changed = false;
	int sign = 1;
	exvector::const_iterator first = symm.children.begin(), last = symm.children.end();
	while (first != last) {
		GINAC_ASSERT(is_exactly_a<symmetry>(*first));
		int child_sign = canonicalize(v, ex_to<symmetry>(*first));
		if (child_sign == 0)
			return 0;   // a zero sign from any child propagates upward
		if (child_sign != INT_MAX) {   // INT_MAX means "child left unchanged"
			something_changed = true;
			sign *= child_sign;
		}
		first++;
	}

	// Now reorder the children. Each sy_swap holds a reference to
	// something_changed and sets it to true whenever it swaps two children.
	first = symm.children.begin();
	switch (symm.type) {
		case symmetry::symmetric:
			// Sort the children in ascending order
			shaker_sort(first, last, sy_is_less(v), sy_swap(v, something_changed));
			break;
		case symmetry::antisymmetric:
			// Sort the children in ascending order, keeping track of the signum
			sign *= permutation_sign(first, last, sy_is_less(v), sy_swap(v, something_changed));
			if (sign == 0)
				return 0;
			break;
		case symmetry::cyclic:
			// Permute the smallest child to the front
			cyclic_permutation(first, last, min_element(first, last, sy_is_less(v)), sy_swap(v, something_changed));
			break;
		default:
			// "none": only the children's own canonicalization applies
			break;
	}
	return something_changed ? sign : INT_MAX;
}
339
340
341 // Symmetrize/antisymmetrize over a vector of objects
342 static ex symm(const ex & e, exvector::const_iterator first, exvector::const_iterator last, bool asymmetric)
343 {
344         // Need at least 2 objects for this operation
345         unsigned num = last - first;
346         if (num < 2)
347                 return e;
348
349         // Transform object vector to a lst (for subs())
350         lst orig_lst(first, last);
351
352         // Create index vectors for permutation
353         unsigned *iv = new unsigned[num], *iv2;
354         for (unsigned i=0; i<num; i++)
355                 iv[i] = i;
356         iv2 = (asymmetric ? new unsigned[num] : NULL);
357
358         // Loop over all permutations (the first permutation, which is the
359         // identity, is unrolled)
360         ex sum = e;
361         while (std::next_permutation(iv, iv + num)) {
362                 lst new_lst;
363                 for (unsigned i=0; i<num; i++)
364                         new_lst.append(orig_lst.op(iv[i]));
365                 ex term = e.subs(orig_lst, new_lst, subs_options::no_pattern);
366                 if (asymmetric) {
367                         memcpy(iv2, iv, num * sizeof(unsigned));
368                         term *= permutation_sign(iv2, iv2 + num);
369                 }
370                 sum += term;
371         }
372
373         delete[] iv;
374         delete[] iv2;
375
376         return sum / factorial(numeric(num));
377 }
378
379 ex symmetrize(const ex & e, exvector::const_iterator first, exvector::const_iterator last)
380 {
381         return symm(e, first, last, false);
382 }
383
384 ex antisymmetrize(const ex & e, exvector::const_iterator first, exvector::const_iterator last)
385 {
386         return symm(e, first, last, true);
387 }
388
389 ex symmetrize_cyclic(const ex & e, exvector::const_iterator first, exvector::const_iterator last)
390 {
391         // Need at least 2 objects for this operation
392         unsigned num = last - first;
393         if (num < 2)
394                 return e;
395
396         // Transform object vector to a lst (for subs())
397         lst orig_lst(first, last);
398         lst new_lst = orig_lst;
399
400         // Loop over all cyclic permutations (the first permutation, which is
401         // the identity, is unrolled)
402         ex sum = e;
403         for (unsigned i=0; i<num-1; i++) {
404                 ex perm = new_lst.op(0);
405                 new_lst.remove_first().append(perm);
406                 sum += e.subs(orig_lst, new_lst, subs_options::no_pattern);
407         }
408         return sum / num;
409 }
410
411 /** Symmetrize expression over a list of objects (symbols, indices). */
412 ex ex::symmetrize(const lst & l) const
413 {
414         exvector v(l.begin(), l.end());
415         return symm(*this, v.begin(), v.end(), false);
416 }
417
418 /** Antisymmetrize expression over a list of objects (symbols, indices). */
419 ex ex::antisymmetrize(const lst & l) const
420 {
421         exvector v(l.begin(), l.end());
422         return symm(*this, v.begin(), v.end(), true);
423 }
424
425 /** Symmetrize expression by cyclic permutation over a list of objects
426  *  (symbols, indices). */
427 ex ex::symmetrize_cyclic(const lst & l) const
428 {
429         exvector v(l.begin(), l.end());
430         return GiNaC::symmetrize_cyclic(*this, v.begin(), v.end());
431 }
432
433 } // namespace GiNaC