- prepared for 1.0.13 release
[ginac.git] / ginac / symmetry.cpp
1 /** @file symmetry.cpp
2  *
3  *  Implementation of GiNaC's symmetry definitions. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2003 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
21  */
22
#include <algorithm>
#include <climits>
#include <functional>
#include <iostream>
#include <iterator>
#include <set>
#include <stdexcept>
#include <vector>

#include "symmetry.h"
#include "lst.h"
#include "numeric.h" // for factorial()
#include "print.h"
#include "archive.h"
#include "utils.h"
33
34 namespace GiNaC {
35
36 GINAC_IMPLEMENT_REGISTERED_CLASS(symmetry, basic)
37
38 /*
39    Some notes about the structure of a symmetry tree:
40     - The leaf nodes of the tree are of type "none", have one index, and no
41       children (of course). They are constructed by the symmetry(unsigned)
42       constructor.
43     - Leaf nodes are the only nodes that only have one index.
44     - Container nodes contain two or more children. The "indices" set member
45       is the set union of the index sets of all children, and the "children"
46       vector stores the children themselves.
47     - The index set of each child of a "symm", "anti" or "cycl" node must
48       have the same size. It follows that the children of such a node are
49       either all leaf nodes, or all container nodes with two or more indices.
50 */
51
52 //////////
53 // default ctor, dtor, copy ctor, assignment operator and helpers
54 //////////
55
56 symmetry::symmetry() : type(none)
57 {
58         tinfo_key = TINFO_symmetry;
59 }
60
61 void symmetry::copy(const symmetry & other)
62 {
63         inherited::copy(other);
64         type = other.type;
65         indices = other.indices;
66         children = other.children;
67 }
68
69 DEFAULT_DESTROY(symmetry)
70
71 //////////
72 // other constructors
73 //////////
74
75 symmetry::symmetry(unsigned i) : type(none)
76 {
77         indices.insert(i);
78         tinfo_key = TINFO_symmetry;
79 }
80
81 symmetry::symmetry(symmetry_type t, const symmetry &c1, const symmetry &c2) : type(t)
82 {
83         add(c1); add(c2);
84         tinfo_key = TINFO_symmetry;
85 }
86
87 //////////
88 // archiving
89 //////////
90
91 /** Construct object from archive_node. */
92 symmetry::symmetry(const archive_node &n, const lst &sym_lst) : inherited(n, sym_lst)
93 {
94         unsigned t;
95         if (!(n.find_unsigned("type", t)))
96                 throw (std::runtime_error("unknown symmetry type in archive"));
97         type = (symmetry_type)t;
98
99         unsigned i = 0;
100         while (true) {
101                 ex e;
102                 if (n.find_ex("child", e, sym_lst, i))
103                         add(ex_to<symmetry>(e));
104                 else
105                         break;
106                 i++;
107         }
108
109         if (i == 0) {
110                 while (true) {
111                         unsigned u;
112                         if (n.find_unsigned("index", u, i))
113                                 indices.insert(u);
114                         else
115                                 break;
116                         i++;
117                 }
118         }
119 }
120
121 /** Archive the object. */
122 void symmetry::archive(archive_node &n) const
123 {
124         inherited::archive(n);
125
126         n.add_unsigned("type", type);
127
128         if (children.empty()) {
129                 std::set<unsigned>::const_iterator i = indices.begin(), iend = indices.end();
130                 while (i != iend) {
131                         n.add_unsigned("index", *i);
132                         i++;
133                 }
134         } else {
135                 exvector::const_iterator i = children.begin(), iend = children.end();
136                 while (i != iend) {
137                         n.add_ex("child", *i);
138                         i++;
139                 }
140         }
141 }
142
143 DEFAULT_UNARCHIVE(symmetry)
144
145 //////////
146 // functions overriding virtual functions from base classes
147 //////////
148
149 int symmetry::compare_same_type(const basic & other) const
150 {
151         GINAC_ASSERT(is_a<symmetry>(other));
152
153         // All symmetry trees are equal. They are not supposed to appear in
154         // ordinary expressions anyway...
155         return 0;
156 }
157
158 void symmetry::print(const print_context & c, unsigned level) const
159 {
160         if (is_a<print_tree>(c)) {
161
162                 c.s << std::string(level, ' ') << class_name()
163                     << std::hex << ", hash=0x" << hashvalue << ", flags=0x" << flags << std::dec
164                     << ", type=";
165
166                 switch (type) {
167                         case none: c.s << "none"; break;
168                         case symmetric: c.s << "symm"; break;
169                         case antisymmetric: c.s << "anti"; break;
170                         case cyclic: c.s << "cycl"; break;
171                         default: c.s << "<unknown>"; break;
172                 }
173
174                 c.s << ", indices=(";
175                 if (!indices.empty()) {
176                         std::set<unsigned>::const_iterator i = indices.begin(), end = indices.end();
177                         --end;
178                         while (i != end)
179                                 c.s << *i++ << ",";
180                         c.s << *i;
181                 }
182                 c.s << ")\n";
183
184                 unsigned delta_indent = static_cast<const print_tree &>(c).delta_indent;
185                 exvector::const_iterator i = children.begin(), end = children.end();
186                 while (i != end) {
187                         i->print(c, level + delta_indent);
188                         ++i;
189                 }
190
191         } else {
192
193                 if (children.empty()) {
194                         if (indices.size() > 0)
195                                 c.s << *(indices.begin());
196                         else
197                                 c.s << "none";
198                 } else {
199                         switch (type) {
200                                 case none: c.s << '!'; break;
201                                 case symmetric: c.s << '+'; break;
202                                 case antisymmetric: c.s << '-'; break;
203                                 case cyclic: c.s << '@'; break;
204                                 default: c.s << '?'; break;
205                         }
206                         c.s << '(';
207                         unsigned num = children.size();
208                         for (unsigned i=0; i<num; i++) {
209                                 children[i].print(c);
210                                 if (i != num - 1)
211                                         c.s << ",";
212                         }
213                         c.s << ')';
214                 }
215         }
216 }
217
218 //////////
219 // non-virtual functions in this class
220 //////////
221
222 symmetry &symmetry::add(const symmetry &c)
223 {
224         // All children must have the same number of indices
225         if (type != none && !children.empty()) {
226                 GINAC_ASSERT(is_exactly_a<symmetry>(children[0]));
227                 if (ex_to<symmetry>(children[0]).indices.size() != c.indices.size())
228                         throw (std::logic_error("symmetry:add(): children must have same number of indices"));
229         }
230
231         // Compute union of indices and check whether the two sets are disjoint
232         std::set<unsigned> un;
233         set_union(indices.begin(), indices.end(), c.indices.begin(), c.indices.end(), inserter(un, un.begin()));
234         if (un.size() != indices.size() + c.indices.size())
235                 throw (std::logic_error("symmetry::add(): the same index appears in more than one child"));
236
237         // Set new index set
238         indices.swap(un);
239
240         // Add child node
241         children.push_back(c);
242         return *this;
243 }
244
245 void symmetry::validate(unsigned n)
246 {
247         if (indices.upper_bound(n - 1) != indices.end())
248                 throw (std::range_error("symmetry::verify(): index values are out of range"));
249         if (type != none && indices.empty()) {
250                 for (unsigned i=0; i<n; i++)
251                         add(i);
252         }
253 }
254
255 //////////
256 // global functions
257 //////////
258
259 class sy_is_less : public std::binary_function<ex, ex, bool> {
260         exvector::iterator v;
261
262 public:
263         sy_is_less(exvector::iterator v_) : v(v_) {}
264
265         bool operator() (const ex &lh, const ex &rh) const
266         {
267                 GINAC_ASSERT(is_exactly_a<symmetry>(lh));
268                 GINAC_ASSERT(is_exactly_a<symmetry>(rh));
269                 GINAC_ASSERT(ex_to<symmetry>(lh).indices.size() == ex_to<symmetry>(rh).indices.size());
270                 std::set<unsigned>::const_iterator ait = ex_to<symmetry>(lh).indices.begin(), aitend = ex_to<symmetry>(lh).indices.end(), bit = ex_to<symmetry>(rh).indices.begin();
271                 while (ait != aitend) {
272                         int cmpval = v[*ait].compare(v[*bit]);
273                         if (cmpval < 0)
274                                 return true;
275                         else if (cmpval > 0)
276                                 return false;
277                         ++ait; ++bit;
278                 }
279                 return false;
280         }
281 };
282
283 class sy_swap : public std::binary_function<ex, ex, void> {
284         exvector::iterator v;
285
286 public:
287         bool &swapped;
288
289         sy_swap(exvector::iterator v_, bool &s) : v(v_), swapped(s) {}
290
291         void operator() (const ex &lh, const ex &rh)
292         {
293                 GINAC_ASSERT(is_exactly_a<symmetry>(lh));
294                 GINAC_ASSERT(is_exactly_a<symmetry>(rh));
295                 GINAC_ASSERT(ex_to<symmetry>(lh).indices.size() == ex_to<symmetry>(rh).indices.size());
296                 std::set<unsigned>::const_iterator ait = ex_to<symmetry>(lh).indices.begin(), aitend = ex_to<symmetry>(lh).indices.end(), bit = ex_to<symmetry>(rh).indices.begin();
297                 while (ait != aitend) {
298                         v[*ait].swap(v[*bit]);
299                         ++ait; ++bit;
300                 }
301                 swapped = true;
302         }
303 };
304
305 int canonicalize(exvector::iterator v, const symmetry &symm)
306 {
307         // Less than two elements? Then do nothing
308         if (symm.indices.size() < 2)
309                 return INT_MAX;
310
311         // Canonicalize children first
312         bool something_changed = false;
313         int sign = 1;
314         exvector::const_iterator first = symm.children.begin(), last = symm.children.end();
315         while (first != last) {
316                 GINAC_ASSERT(is_exactly_a<symmetry>(*first));
317                 int child_sign = canonicalize(v, ex_to<symmetry>(*first));
318                 if (child_sign == 0)
319                         return 0;
320                 if (child_sign != INT_MAX) {
321                         something_changed = true;
322                         sign *= child_sign;
323                 }
324                 first++;
325         }
326
327         // Now reorder the children
328         first = symm.children.begin();
329         switch (symm.type) {
330                 case symmetry::symmetric:
331                         // Sort the children in ascending order
332                         shaker_sort(first, last, sy_is_less(v), sy_swap(v, something_changed));
333                         break;
334                 case symmetry::antisymmetric:
335                         // Sort the children in ascending order, keeping track of the signum
336                         sign *= permutation_sign(first, last, sy_is_less(v), sy_swap(v, something_changed));
337                         if (sign == 0)
338                                 return 0;
339                         break;
340                 case symmetry::cyclic:
341                         // Permute the smallest child to the front
342                         cyclic_permutation(first, last, min_element(first, last, sy_is_less(v)), sy_swap(v, something_changed));
343                         break;
344                 default:
345                         break;
346         }
347         return something_changed ? sign : INT_MAX;
348 }
349
350
351 // Symmetrize/antisymmetrize over a vector of objects
352 static ex symm(const ex & e, exvector::const_iterator first, exvector::const_iterator last, bool asymmetric)
353 {
354         // Need at least 2 objects for this operation
355         unsigned num = last - first;
356         if (num < 2)
357                 return e;
358
359         // Transform object vector to a list
360         exlist iv_lst;
361         iv_lst.insert(iv_lst.begin(), first, last);
362         lst orig_lst(iv_lst, true);
363
364         // Create index vectors for permutation
365         unsigned *iv = new unsigned[num], *iv2;
366         for (unsigned i=0; i<num; i++)
367                 iv[i] = i;
368         iv2 = (asymmetric ? new unsigned[num] : NULL);
369
370         // Loop over all permutations (the first permutation, which is the
371         // identity, is unrolled)
372         ex sum = e;
373         while (std::next_permutation(iv, iv + num)) {
374                 lst new_lst;
375                 for (unsigned i=0; i<num; i++)
376                         new_lst.append(orig_lst.op(iv[i]));
377                 ex term = e.subs(orig_lst, new_lst);
378                 if (asymmetric) {
379                         memcpy(iv2, iv, num * sizeof(unsigned));
380                         term *= permutation_sign(iv2, iv2 + num);
381                 }
382                 sum += term;
383         }
384
385         delete[] iv;
386         delete[] iv2;
387
388         return sum / factorial(numeric(num));
389 }
390
391 ex symmetrize(const ex & e, exvector::const_iterator first, exvector::const_iterator last)
392 {
393         return symm(e, first, last, false);
394 }
395
396 ex antisymmetrize(const ex & e, exvector::const_iterator first, exvector::const_iterator last)
397 {
398         return symm(e, first, last, true);
399 }
400
401 ex symmetrize_cyclic(const ex & e, exvector::const_iterator first, exvector::const_iterator last)
402 {
403         // Need at least 2 objects for this operation
404         unsigned num = last - first;
405         if (num < 2)
406                 return e;
407
408         // Transform object vector to a list
409         exlist iv_lst;
410         iv_lst.insert(iv_lst.begin(), first, last);
411         lst orig_lst(iv_lst, true);
412         lst new_lst = orig_lst;
413
414         // Loop over all cyclic permutations (the first permutation, which is
415         // the identity, is unrolled)
416         ex sum = e;
417         for (unsigned i=0; i<num-1; i++) {
418                 ex perm = new_lst.op(0);
419                 new_lst.remove_first().append(perm);
420                 sum += e.subs(orig_lst, new_lst);
421         }
422         return sum / num;
423 }
424
425 /** Symmetrize expression over a list of objects (symbols, indices). */
426 ex ex::symmetrize(const lst & l) const
427 {
428         exvector v;
429         v.reserve(l.nops());
430         for (unsigned i=0; i<l.nops(); i++)
431                 v.push_back(l.op(i));
432         return symm(*this, v.begin(), v.end(), false);
433 }
434
435 /** Antisymmetrize expression over a list of objects (symbols, indices). */
436 ex ex::antisymmetrize(const lst & l) const
437 {
438         exvector v;
439         v.reserve(l.nops());
440         for (unsigned i=0; i<l.nops(); i++)
441                 v.push_back(l.op(i));
442         return symm(*this, v.begin(), v.end(), true);
443 }
444
445 /** Symmetrize expression by cyclic permutation over a list of objects
446  *  (symbols, indices). */
447 ex ex::symmetrize_cyclic(const lst & l) const
448 {
449         exvector v;
450         v.reserve(l.nops());
451         for (unsigned i=0; i<l.nops(); i++)
452                 v.push_back(l.op(i));
453         return GiNaC::symmetrize_cyclic(*this, v.begin(), v.end());
454 }
455
456 } // namespace GiNaC