* zeta(n,x) is now zetaderiv(n,s)
[ginac.git] / ginac / symmetry.cpp
1 /** @file symmetry.cpp
2  *
3  *  Implementation of GiNaC's symmetry definitions. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2003 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
21  */
22
#include <algorithm>
#include <functional>
#include <iostream>
#include <stdexcept>
#include <vector>

#include "symmetry.h"
#include "lst.h"
#include "numeric.h" // for factorial()
#include "operators.h"
#include "archive.h"
#include "utils.h"
33
34 namespace GiNaC {
35
// Class-registration boilerplate for symmetry (parent class: basic). The
// chained print_func<> arguments bind do_print() to generic print_context
// output and do_print_tree() to print_tree output.
GINAC_IMPLEMENT_REGISTERED_CLASS_OPT(symmetry, basic,
	print_func<print_context>(&symmetry::do_print).
	print_func<print_tree>(&symmetry::do_print_tree))
39
/*
   Some notes about the structure of a symmetry tree:
    - The leaf nodes of the tree are of type "none", have exactly one index
      and no children. They are constructed by the symmetry(unsigned)
      constructor.
    - Leaf nodes are the only nodes that have exactly one index.
    - Container nodes contain two or more children. The "indices" set member
      is the set union of the index sets of all children, and the "children"
      vector stores the children themselves.
    - All children of a "symm", "anti" or "cycl" node must have index sets
      of the same size. It follows that the children of such a node are
      either all leaf nodes, or all container nodes with two or more indices.
*/
53
54 //////////
55 // default constructor
56 //////////
57
58 symmetry::symmetry() : type(none)
59 {
60         tinfo_key = TINFO_symmetry;
61 }
62
63 //////////
64 // other constructors
65 //////////
66
67 symmetry::symmetry(unsigned i) : type(none)
68 {
69         indices.insert(i);
70         tinfo_key = TINFO_symmetry;
71 }
72
73 symmetry::symmetry(symmetry_type t, const symmetry &c1, const symmetry &c2) : type(t)
74 {
75         add(c1); add(c2);
76         tinfo_key = TINFO_symmetry;
77 }
78
79 //////////
80 // archiving
81 //////////
82
83 /** Construct object from archive_node. */
84 symmetry::symmetry(const archive_node &n, lst &sym_lst) : inherited(n, sym_lst)
85 {
86         unsigned t;
87         if (!(n.find_unsigned("type", t)))
88                 throw (std::runtime_error("unknown symmetry type in archive"));
89         type = (symmetry_type)t;
90
91         unsigned i = 0;
92         while (true) {
93                 ex e;
94                 if (n.find_ex("child", e, sym_lst, i))
95                         add(ex_to<symmetry>(e));
96                 else
97                         break;
98                 i++;
99         }
100
101         if (i == 0) {
102                 while (true) {
103                         unsigned u;
104                         if (n.find_unsigned("index", u, i))
105                                 indices.insert(u);
106                         else
107                                 break;
108                         i++;
109                 }
110         }
111 }
112
113 /** Archive the object. */
114 void symmetry::archive(archive_node &n) const
115 {
116         inherited::archive(n);
117
118         n.add_unsigned("type", type);
119
120         if (children.empty()) {
121                 std::set<unsigned>::const_iterator i = indices.begin(), iend = indices.end();
122                 while (i != iend) {
123                         n.add_unsigned("index", *i);
124                         i++;
125                 }
126         } else {
127                 exvector::const_iterator i = children.begin(), iend = children.end();
128                 while (i != iend) {
129                         n.add_ex("child", *i);
130                         i++;
131                 }
132         }
133 }
134
// Unarchiving boilerplate for symmetry; presumably it creates the object via
// the archive_node constructor (NOTE(review): confirm against the macro's
// definition).
DEFAULT_UNARCHIVE(symmetry)
136
137 //////////
138 // functions overriding virtual functions from base classes
139 //////////
140
141 int symmetry::compare_same_type(const basic & other) const
142 {
143         GINAC_ASSERT(is_a<symmetry>(other));
144
145         // All symmetry trees are equal. They are not supposed to appear in
146         // ordinary expressions anyway...
147         return 0;
148 }
149
150 void symmetry::do_print(const print_context & c, unsigned level) const
151 {
152         if (children.empty()) {
153                 if (indices.size() > 0)
154                         c.s << *(indices.begin());
155                 else
156                         c.s << "none";
157         } else {
158                 switch (type) {
159                         case none: c.s << '!'; break;
160                         case symmetric: c.s << '+'; break;
161                         case antisymmetric: c.s << '-'; break;
162                         case cyclic: c.s << '@'; break;
163                         default: c.s << '?'; break;
164                 }
165                 c.s << '(';
166                 size_t num = children.size();
167                 for (size_t i=0; i<num; i++) {
168                         children[i].print(c);
169                         if (i != num - 1)
170                                 c.s << ",";
171                 }
172                 c.s << ')';
173         }
174 }
175
176 void symmetry::do_print_tree(const print_tree & c, unsigned level) const
177 {
178         c.s << std::string(level, ' ') << class_name() << " @" << this
179             << std::hex << ", hash=0x" << hashvalue << ", flags=0x" << flags << std::dec
180             << ", type=";
181
182         switch (type) {
183                 case none: c.s << "none"; break;
184                 case symmetric: c.s << "symm"; break;
185                 case antisymmetric: c.s << "anti"; break;
186                 case cyclic: c.s << "cycl"; break;
187                 default: c.s << "<unknown>"; break;
188         }
189
190         c.s << ", indices=(";
191         if (!indices.empty()) {
192                 std::set<unsigned>::const_iterator i = indices.begin(), end = indices.end();
193                 --end;
194                 while (i != end)
195                         c.s << *i++ << ",";
196                 c.s << *i;
197         }
198         c.s << ")\n";
199
200         exvector::const_iterator i = children.begin(), end = children.end();
201         while (i != end) {
202                 i->print(c, level + c.delta_indent);
203                 ++i;
204         }
205 }
206
207 //////////
208 // non-virtual functions in this class
209 //////////
210
211 symmetry &symmetry::add(const symmetry &c)
212 {
213         // All children must have the same number of indices
214         if (type != none && !children.empty()) {
215                 GINAC_ASSERT(is_exactly_a<symmetry>(children[0]));
216                 if (ex_to<symmetry>(children[0]).indices.size() != c.indices.size())
217                         throw (std::logic_error("symmetry:add(): children must have same number of indices"));
218         }
219
220         // Compute union of indices and check whether the two sets are disjoint
221         std::set<unsigned> un;
222         set_union(indices.begin(), indices.end(), c.indices.begin(), c.indices.end(), inserter(un, un.begin()));
223         if (un.size() != indices.size() + c.indices.size())
224                 throw (std::logic_error("symmetry::add(): the same index appears in more than one child"));
225
226         // Set new index set
227         indices.swap(un);
228
229         // Add child node
230         children.push_back(c);
231         return *this;
232 }
233
234 void symmetry::validate(unsigned n)
235 {
236         if (indices.upper_bound(n - 1) != indices.end())
237                 throw (std::range_error("symmetry::verify(): index values are out of range"));
238         if (type != none && indices.empty()) {
239                 for (unsigned i=0; i<n; i++)
240                         add(i);
241         }
242 }
243
244 //////////
245 // global functions
246 //////////
247
248 class sy_is_less : public std::binary_function<ex, ex, bool> {
249         exvector::iterator v;
250
251 public:
252         sy_is_less(exvector::iterator v_) : v(v_) {}
253
254         bool operator() (const ex &lh, const ex &rh) const
255         {
256                 GINAC_ASSERT(is_exactly_a<symmetry>(lh));
257                 GINAC_ASSERT(is_exactly_a<symmetry>(rh));
258                 GINAC_ASSERT(ex_to<symmetry>(lh).indices.size() == ex_to<symmetry>(rh).indices.size());
259                 std::set<unsigned>::const_iterator ait = ex_to<symmetry>(lh).indices.begin(), aitend = ex_to<symmetry>(lh).indices.end(), bit = ex_to<symmetry>(rh).indices.begin();
260                 while (ait != aitend) {
261                         int cmpval = v[*ait].compare(v[*bit]);
262                         if (cmpval < 0)
263                                 return true;
264                         else if (cmpval > 0)
265                                 return false;
266                         ++ait; ++bit;
267                 }
268                 return false;
269         }
270 };
271
272 class sy_swap : public std::binary_function<ex, ex, void> {
273         exvector::iterator v;
274
275 public:
276         bool &swapped;
277
278         sy_swap(exvector::iterator v_, bool &s) : v(v_), swapped(s) {}
279
280         void operator() (const ex &lh, const ex &rh)
281         {
282                 GINAC_ASSERT(is_exactly_a<symmetry>(lh));
283                 GINAC_ASSERT(is_exactly_a<symmetry>(rh));
284                 GINAC_ASSERT(ex_to<symmetry>(lh).indices.size() == ex_to<symmetry>(rh).indices.size());
285                 std::set<unsigned>::const_iterator ait = ex_to<symmetry>(lh).indices.begin(), aitend = ex_to<symmetry>(lh).indices.end(), bit = ex_to<symmetry>(rh).indices.begin();
286                 while (ait != aitend) {
287                         v[*ait].swap(v[*bit]);
288                         ++ait; ++bit;
289                 }
290                 swapped = true;
291         }
292 };
293
294 int canonicalize(exvector::iterator v, const symmetry &symm)
295 {
296         // Less than two elements? Then do nothing
297         if (symm.indices.size() < 2)
298                 return INT_MAX;
299
300         // Canonicalize children first
301         bool something_changed = false;
302         int sign = 1;
303         exvector::const_iterator first = symm.children.begin(), last = symm.children.end();
304         while (first != last) {
305                 GINAC_ASSERT(is_exactly_a<symmetry>(*first));
306                 int child_sign = canonicalize(v, ex_to<symmetry>(*first));
307                 if (child_sign == 0)
308                         return 0;
309                 if (child_sign != INT_MAX) {
310                         something_changed = true;
311                         sign *= child_sign;
312                 }
313                 first++;
314         }
315
316         // Now reorder the children
317         first = symm.children.begin();
318         switch (symm.type) {
319                 case symmetry::symmetric:
320                         // Sort the children in ascending order
321                         shaker_sort(first, last, sy_is_less(v), sy_swap(v, something_changed));
322                         break;
323                 case symmetry::antisymmetric:
324                         // Sort the children in ascending order, keeping track of the signum
325                         sign *= permutation_sign(first, last, sy_is_less(v), sy_swap(v, something_changed));
326                         if (sign == 0)
327                                 return 0;
328                         break;
329                 case symmetry::cyclic:
330                         // Permute the smallest child to the front
331                         cyclic_permutation(first, last, min_element(first, last, sy_is_less(v)), sy_swap(v, something_changed));
332                         break;
333                 default:
334                         break;
335         }
336         return something_changed ? sign : INT_MAX;
337 }
338
339
340 // Symmetrize/antisymmetrize over a vector of objects
341 static ex symm(const ex & e, exvector::const_iterator first, exvector::const_iterator last, bool asymmetric)
342 {
343         // Need at least 2 objects for this operation
344         unsigned num = last - first;
345         if (num < 2)
346                 return e;
347
348         // Transform object vector to a lst (for subs())
349         lst orig_lst(first, last);
350
351         // Create index vectors for permutation
352         unsigned *iv = new unsigned[num], *iv2;
353         for (unsigned i=0; i<num; i++)
354                 iv[i] = i;
355         iv2 = (asymmetric ? new unsigned[num] : NULL);
356
357         // Loop over all permutations (the first permutation, which is the
358         // identity, is unrolled)
359         ex sum = e;
360         while (std::next_permutation(iv, iv + num)) {
361                 lst new_lst;
362                 for (unsigned i=0; i<num; i++)
363                         new_lst.append(orig_lst.op(iv[i]));
364                 ex term = e.subs(orig_lst, new_lst, subs_options::no_pattern);
365                 if (asymmetric) {
366                         memcpy(iv2, iv, num * sizeof(unsigned));
367                         term *= permutation_sign(iv2, iv2 + num);
368                 }
369                 sum += term;
370         }
371
372         delete[] iv;
373         delete[] iv2;
374
375         return sum / factorial(numeric(num));
376 }
377
378 ex symmetrize(const ex & e, exvector::const_iterator first, exvector::const_iterator last)
379 {
380         return symm(e, first, last, false);
381 }
382
383 ex antisymmetrize(const ex & e, exvector::const_iterator first, exvector::const_iterator last)
384 {
385         return symm(e, first, last, true);
386 }
387
388 ex symmetrize_cyclic(const ex & e, exvector::const_iterator first, exvector::const_iterator last)
389 {
390         // Need at least 2 objects for this operation
391         unsigned num = last - first;
392         if (num < 2)
393                 return e;
394
395         // Transform object vector to a lst (for subs())
396         lst orig_lst(first, last);
397         lst new_lst = orig_lst;
398
399         // Loop over all cyclic permutations (the first permutation, which is
400         // the identity, is unrolled)
401         ex sum = e;
402         for (unsigned i=0; i<num-1; i++) {
403                 ex perm = new_lst.op(0);
404                 new_lst.remove_first().append(perm);
405                 sum += e.subs(orig_lst, new_lst, subs_options::no_pattern);
406         }
407         return sum / num;
408 }
409
410 /** Symmetrize expression over a list of objects (symbols, indices). */
411 ex ex::symmetrize(const lst & l) const
412 {
413         exvector v(l.begin(), l.end());
414         return symm(*this, v.begin(), v.end(), false);
415 }
416
417 /** Antisymmetrize expression over a list of objects (symbols, indices). */
418 ex ex::antisymmetrize(const lst & l) const
419 {
420         exvector v(l.begin(), l.end());
421         return symm(*this, v.begin(), v.end(), true);
422 }
423
424 /** Symmetrize expression by cyclic permutation over a list of objects
425  *  (symbols, indices). */
426 ex ex::symmetrize_cyclic(const lst & l) const
427 {
428         exvector v(l.begin(), l.end());
429         return GiNaC::symmetrize_cyclic(*this, v.begin(), v.end());
430 }
431
432 } // namespace GiNaC