generous use of subs_options::no_pattern
[ginac.git] / ginac / symmetry.cpp
1 /** @file symmetry.cpp
2  *
3  *  Implementation of GiNaC's symmetry definitions. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2003 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
21  */
22
#include <algorithm>
#include <climits>
#include <cstring>
#include <functional>
#include <iostream>
#include <iterator>
#include <set>
#include <stdexcept>
#include <vector>

#include "symmetry.h"
#include "lst.h"
#include "numeric.h" // for factorial()
#include "operators.h"
#include "print.h"
#include "archive.h"
#include "utils.h"
34
35 namespace GiNaC {
36
// Class-registration boilerplate for symmetry, naming basic as its parent
// class (presumably; see the GINAC_IMPLEMENT_REGISTERED_CLASS macro definition
// for what it actually expands to).
GINAC_IMPLEMENT_REGISTERED_CLASS(symmetry, basic)
38
/*
   Some notes about the structure of a symmetry tree:
    - The leaf nodes of the tree are of type "none", have exactly one index,
      and (of course) no children. They are constructed by the
      symmetry(unsigned) constructor.
    - Leaf nodes are the only nodes that have exactly one index.
    - Container nodes contain two or more children. The "indices" set member
      is the set union of the index sets of all children, and the "children"
      vector stores the children themselves.
    - The index set of each child of a "symm", "anti" or "cycl" node must
      have the same size. It follows that the children of such a node are
      either all leaf nodes, or all container nodes with two or more indices.
    - A default-constructed symmetry has type "none" and neither indices nor
      children. validate(n) fills a node whose type is not "none" but which
      has no indices yet with one leaf child for each index 0, ..., n-1.
*/
52
53 //////////
54 // default constructor
55 //////////
56
57 symmetry::symmetry() : type(none)
58 {
59         tinfo_key = TINFO_symmetry;
60 }
61
62 //////////
63 // other constructors
64 //////////
65
66 symmetry::symmetry(unsigned i) : type(none)
67 {
68         indices.insert(i);
69         tinfo_key = TINFO_symmetry;
70 }
71
72 symmetry::symmetry(symmetry_type t, const symmetry &c1, const symmetry &c2) : type(t)
73 {
74         add(c1); add(c2);
75         tinfo_key = TINFO_symmetry;
76 }
77
78 //////////
79 // archiving
80 //////////
81
82 /** Construct object from archive_node. */
83 symmetry::symmetry(const archive_node &n, lst &sym_lst) : inherited(n, sym_lst)
84 {
85         unsigned t;
86         if (!(n.find_unsigned("type", t)))
87                 throw (std::runtime_error("unknown symmetry type in archive"));
88         type = (symmetry_type)t;
89
90         unsigned i = 0;
91         while (true) {
92                 ex e;
93                 if (n.find_ex("child", e, sym_lst, i))
94                         add(ex_to<symmetry>(e));
95                 else
96                         break;
97                 i++;
98         }
99
100         if (i == 0) {
101                 while (true) {
102                         unsigned u;
103                         if (n.find_unsigned("index", u, i))
104                                 indices.insert(u);
105                         else
106                                 break;
107                         i++;
108                 }
109         }
110 }
111
112 /** Archive the object. */
113 void symmetry::archive(archive_node &n) const
114 {
115         inherited::archive(n);
116
117         n.add_unsigned("type", type);
118
119         if (children.empty()) {
120                 std::set<unsigned>::const_iterator i = indices.begin(), iend = indices.end();
121                 while (i != iend) {
122                         n.add_unsigned("index", *i);
123                         i++;
124                 }
125         } else {
126                 exvector::const_iterator i = children.begin(), iend = children.end();
127                 while (i != iend) {
128                         n.add_ex("child", *i);
129                         i++;
130                 }
131         }
132 }
133
134 DEFAULT_UNARCHIVE(symmetry)
135
136 //////////
137 // functions overriding virtual functions from base classes
138 //////////
139
140 int symmetry::compare_same_type(const basic & other) const
141 {
142         GINAC_ASSERT(is_a<symmetry>(other));
143
144         // All symmetry trees are equal. They are not supposed to appear in
145         // ordinary expressions anyway...
146         return 0;
147 }
148
149 void symmetry::print(const print_context & c, unsigned level) const
150 {
151         if (is_a<print_tree>(c)) {
152
153                 c.s << std::string(level, ' ') << class_name()
154                     << std::hex << ", hash=0x" << hashvalue << ", flags=0x" << flags << std::dec
155                     << ", type=";
156
157                 switch (type) {
158                         case none: c.s << "none"; break;
159                         case symmetric: c.s << "symm"; break;
160                         case antisymmetric: c.s << "anti"; break;
161                         case cyclic: c.s << "cycl"; break;
162                         default: c.s << "<unknown>"; break;
163                 }
164
165                 c.s << ", indices=(";
166                 if (!indices.empty()) {
167                         std::set<unsigned>::const_iterator i = indices.begin(), end = indices.end();
168                         --end;
169                         while (i != end)
170                                 c.s << *i++ << ",";
171                         c.s << *i;
172                 }
173                 c.s << ")\n";
174
175                 unsigned delta_indent = static_cast<const print_tree &>(c).delta_indent;
176                 exvector::const_iterator i = children.begin(), end = children.end();
177                 while (i != end) {
178                         i->print(c, level + delta_indent);
179                         ++i;
180                 }
181
182         } else {
183
184                 if (children.empty()) {
185                         if (indices.size() > 0)
186                                 c.s << *(indices.begin());
187                         else
188                                 c.s << "none";
189                 } else {
190                         switch (type) {
191                                 case none: c.s << '!'; break;
192                                 case symmetric: c.s << '+'; break;
193                                 case antisymmetric: c.s << '-'; break;
194                                 case cyclic: c.s << '@'; break;
195                                 default: c.s << '?'; break;
196                         }
197                         c.s << '(';
198                         size_t num = children.size();
199                         for (size_t i=0; i<num; i++) {
200                                 children[i].print(c);
201                                 if (i != num - 1)
202                                         c.s << ",";
203                         }
204                         c.s << ')';
205                 }
206         }
207 }
208
209 //////////
210 // non-virtual functions in this class
211 //////////
212
213 symmetry &symmetry::add(const symmetry &c)
214 {
215         // All children must have the same number of indices
216         if (type != none && !children.empty()) {
217                 GINAC_ASSERT(is_exactly_a<symmetry>(children[0]));
218                 if (ex_to<symmetry>(children[0]).indices.size() != c.indices.size())
219                         throw (std::logic_error("symmetry:add(): children must have same number of indices"));
220         }
221
222         // Compute union of indices and check whether the two sets are disjoint
223         std::set<unsigned> un;
224         set_union(indices.begin(), indices.end(), c.indices.begin(), c.indices.end(), inserter(un, un.begin()));
225         if (un.size() != indices.size() + c.indices.size())
226                 throw (std::logic_error("symmetry::add(): the same index appears in more than one child"));
227
228         // Set new index set
229         indices.swap(un);
230
231         // Add child node
232         children.push_back(c);
233         return *this;
234 }
235
236 void symmetry::validate(unsigned n)
237 {
238         if (indices.upper_bound(n - 1) != indices.end())
239                 throw (std::range_error("symmetry::verify(): index values are out of range"));
240         if (type != none && indices.empty()) {
241                 for (unsigned i=0; i<n; i++)
242                         add(i);
243         }
244 }
245
246 //////////
247 // global functions
248 //////////
249
250 class sy_is_less : public std::binary_function<ex, ex, bool> {
251         exvector::iterator v;
252
253 public:
254         sy_is_less(exvector::iterator v_) : v(v_) {}
255
256         bool operator() (const ex &lh, const ex &rh) const
257         {
258                 GINAC_ASSERT(is_exactly_a<symmetry>(lh));
259                 GINAC_ASSERT(is_exactly_a<symmetry>(rh));
260                 GINAC_ASSERT(ex_to<symmetry>(lh).indices.size() == ex_to<symmetry>(rh).indices.size());
261                 std::set<unsigned>::const_iterator ait = ex_to<symmetry>(lh).indices.begin(), aitend = ex_to<symmetry>(lh).indices.end(), bit = ex_to<symmetry>(rh).indices.begin();
262                 while (ait != aitend) {
263                         int cmpval = v[*ait].compare(v[*bit]);
264                         if (cmpval < 0)
265                                 return true;
266                         else if (cmpval > 0)
267                                 return false;
268                         ++ait; ++bit;
269                 }
270                 return false;
271         }
272 };
273
274 class sy_swap : public std::binary_function<ex, ex, void> {
275         exvector::iterator v;
276
277 public:
278         bool &swapped;
279
280         sy_swap(exvector::iterator v_, bool &s) : v(v_), swapped(s) {}
281
282         void operator() (const ex &lh, const ex &rh)
283         {
284                 GINAC_ASSERT(is_exactly_a<symmetry>(lh));
285                 GINAC_ASSERT(is_exactly_a<symmetry>(rh));
286                 GINAC_ASSERT(ex_to<symmetry>(lh).indices.size() == ex_to<symmetry>(rh).indices.size());
287                 std::set<unsigned>::const_iterator ait = ex_to<symmetry>(lh).indices.begin(), aitend = ex_to<symmetry>(lh).indices.end(), bit = ex_to<symmetry>(rh).indices.begin();
288                 while (ait != aitend) {
289                         v[*ait].swap(v[*bit]);
290                         ++ait; ++bit;
291                 }
292                 swapped = true;
293         }
294 };
295
296 int canonicalize(exvector::iterator v, const symmetry &symm)
297 {
298         // Less than two elements? Then do nothing
299         if (symm.indices.size() < 2)
300                 return INT_MAX;
301
302         // Canonicalize children first
303         bool something_changed = false;
304         int sign = 1;
305         exvector::const_iterator first = symm.children.begin(), last = symm.children.end();
306         while (first != last) {
307                 GINAC_ASSERT(is_exactly_a<symmetry>(*first));
308                 int child_sign = canonicalize(v, ex_to<symmetry>(*first));
309                 if (child_sign == 0)
310                         return 0;
311                 if (child_sign != INT_MAX) {
312                         something_changed = true;
313                         sign *= child_sign;
314                 }
315                 first++;
316         }
317
318         // Now reorder the children
319         first = symm.children.begin();
320         switch (symm.type) {
321                 case symmetry::symmetric:
322                         // Sort the children in ascending order
323                         shaker_sort(first, last, sy_is_less(v), sy_swap(v, something_changed));
324                         break;
325                 case symmetry::antisymmetric:
326                         // Sort the children in ascending order, keeping track of the signum
327                         sign *= permutation_sign(first, last, sy_is_less(v), sy_swap(v, something_changed));
328                         if (sign == 0)
329                                 return 0;
330                         break;
331                 case symmetry::cyclic:
332                         // Permute the smallest child to the front
333                         cyclic_permutation(first, last, min_element(first, last, sy_is_less(v)), sy_swap(v, something_changed));
334                         break;
335                 default:
336                         break;
337         }
338         return something_changed ? sign : INT_MAX;
339 }
340
341
342 // Symmetrize/antisymmetrize over a vector of objects
343 static ex symm(const ex & e, exvector::const_iterator first, exvector::const_iterator last, bool asymmetric)
344 {
345         // Need at least 2 objects for this operation
346         unsigned num = last - first;
347         if (num < 2)
348                 return e;
349
350         // Transform object vector to a lst (for subs())
351         lst orig_lst(first, last);
352
353         // Create index vectors for permutation
354         unsigned *iv = new unsigned[num], *iv2;
355         for (unsigned i=0; i<num; i++)
356                 iv[i] = i;
357         iv2 = (asymmetric ? new unsigned[num] : NULL);
358
359         // Loop over all permutations (the first permutation, which is the
360         // identity, is unrolled)
361         ex sum = e;
362         while (std::next_permutation(iv, iv + num)) {
363                 lst new_lst;
364                 for (unsigned i=0; i<num; i++)
365                         new_lst.append(orig_lst.op(iv[i]));
366                 ex term = e.subs(orig_lst, new_lst, subs_options::no_pattern);
367                 if (asymmetric) {
368                         memcpy(iv2, iv, num * sizeof(unsigned));
369                         term *= permutation_sign(iv2, iv2 + num);
370                 }
371                 sum += term;
372         }
373
374         delete[] iv;
375         delete[] iv2;
376
377         return sum / factorial(numeric(num));
378 }
379
380 ex symmetrize(const ex & e, exvector::const_iterator first, exvector::const_iterator last)
381 {
382         return symm(e, first, last, false);
383 }
384
385 ex antisymmetrize(const ex & e, exvector::const_iterator first, exvector::const_iterator last)
386 {
387         return symm(e, first, last, true);
388 }
389
390 ex symmetrize_cyclic(const ex & e, exvector::const_iterator first, exvector::const_iterator last)
391 {
392         // Need at least 2 objects for this operation
393         unsigned num = last - first;
394         if (num < 2)
395                 return e;
396
397         // Transform object vector to a lst (for subs())
398         lst orig_lst(first, last);
399         lst new_lst = orig_lst;
400
401         // Loop over all cyclic permutations (the first permutation, which is
402         // the identity, is unrolled)
403         ex sum = e;
404         for (unsigned i=0; i<num-1; i++) {
405                 ex perm = new_lst.op(0);
406                 new_lst.remove_first().append(perm);
407                 sum += e.subs(orig_lst, new_lst, subs_options::no_pattern);
408         }
409         return sum / num;
410 }
411
412 /** Symmetrize expression over a list of objects (symbols, indices). */
413 ex ex::symmetrize(const lst & l) const
414 {
415         exvector v(l.begin(), l.end());
416         return symm(*this, v.begin(), v.end(), false);
417 }
418
419 /** Antisymmetrize expression over a list of objects (symbols, indices). */
420 ex ex::antisymmetrize(const lst & l) const
421 {
422         exvector v(l.begin(), l.end());
423         return symm(*this, v.begin(), v.end(), true);
424 }
425
426 /** Symmetrize expression by cyclic permutation over a list of objects
427  *  (symbols, indices). */
428 ex ex::symmetrize_cyclic(const lst & l) const
429 {
430         exvector v(l.begin(), l.end());
431         return GiNaC::symmetrize_cyclic(*this, v.begin(), v.end());
432 }
433
434 } // namespace GiNaC