]> www.ginac.de Git - ginac.git/blob - ginac/symmetry.cpp
series expansion behaviour fixed.
[ginac.git] / ginac / symmetry.cpp
1 /** @file symmetry.cpp
2  *
3  *  Implementation of GiNaC's symmetry definitions. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2003 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
21  */
22
#include <algorithm>
#include <climits>
#include <functional>
#include <iostream>
#include <iterator>
#include <stdexcept>
#include <vector>
26
27 #include "symmetry.h"
28 #include "lst.h"
29 #include "numeric.h" // for factorial()
30 #include "operators.h"
31 #include "print.h"
32 #include "archive.h"
33 #include "utils.h"
34
35 namespace GiNaC {
36
// Registers class symmetry (whose base class is basic) with GiNaC's
// class registration machinery.
GINAC_IMPLEMENT_REGISTERED_CLASS(symmetry, basic)
38
39 /*
40    Some notes about the structure of a symmetry tree:
41     - The leaf nodes of the tree are of type "none", have one index, and no
42       children (of course). They are constructed by the symmetry(unsigned)
43       constructor.
44     - Leaf nodes are the only nodes that only have one index.
45     - Container nodes contain two or more children. The "indices" set member
46       is the set union of the index sets of all children, and the "children"
47       vector stores the children themselves.
48     - The index set of each child of a "symm", "anti" or "cycl" node must
49       have the same size. It follows that the children of such a node are
50       either all leaf nodes, or all container nodes with two or more indices.
51 */
52
53 //////////
54 // default ctor, dtor, copy ctor, assignment operator and helpers
55 //////////
56
57 symmetry::symmetry() : type(none)
58 {
59         tinfo_key = TINFO_symmetry;
60 }
61
62 void symmetry::copy(const symmetry & other)
63 {
64         inherited::copy(other);
65         type = other.type;
66         indices = other.indices;
67         children = other.children;
68 }
69
// Default destroy() helper. symmetry only holds value members (type,
// indices, children), so presumably nothing extra needs releasing.
DEFAULT_DESTROY(symmetry)
71
72 //////////
73 // other constructors
74 //////////
75
76 symmetry::symmetry(unsigned i) : type(none)
77 {
78         indices.insert(i);
79         tinfo_key = TINFO_symmetry;
80 }
81
82 symmetry::symmetry(symmetry_type t, const symmetry &c1, const symmetry &c2) : type(t)
83 {
84         add(c1); add(c2);
85         tinfo_key = TINFO_symmetry;
86 }
87
88 //////////
89 // archiving
90 //////////
91
92 /** Construct object from archive_node. */
93 symmetry::symmetry(const archive_node &n, lst &sym_lst) : inherited(n, sym_lst)
94 {
95         unsigned t;
96         if (!(n.find_unsigned("type", t)))
97                 throw (std::runtime_error("unknown symmetry type in archive"));
98         type = (symmetry_type)t;
99
100         unsigned i = 0;
101         while (true) {
102                 ex e;
103                 if (n.find_ex("child", e, sym_lst, i))
104                         add(ex_to<symmetry>(e));
105                 else
106                         break;
107                 i++;
108         }
109
110         if (i == 0) {
111                 while (true) {
112                         unsigned u;
113                         if (n.find_unsigned("index", u, i))
114                                 indices.insert(u);
115                         else
116                                 break;
117                         i++;
118                 }
119         }
120 }
121
122 /** Archive the object. */
123 void symmetry::archive(archive_node &n) const
124 {
125         inherited::archive(n);
126
127         n.add_unsigned("type", type);
128
129         if (children.empty()) {
130                 std::set<unsigned>::const_iterator i = indices.begin(), iend = indices.end();
131                 while (i != iend) {
132                         n.add_unsigned("index", *i);
133                         i++;
134                 }
135         } else {
136                 exvector::const_iterator i = children.begin(), iend = children.end();
137                 while (i != iend) {
138                         n.add_ex("child", *i);
139                         i++;
140                 }
141         }
142 }
143
// Default unarchive() helper. It presumably forwards to the
// archive_node constructor of symmetry.
DEFAULT_UNARCHIVE(symmetry)
145
146 //////////
147 // functions overriding virtual functions from base classes
148 //////////
149
150 int symmetry::compare_same_type(const basic & other) const
151 {
152         GINAC_ASSERT(is_a<symmetry>(other));
153
154         // All symmetry trees are equal. They are not supposed to appear in
155         // ordinary expressions anyway...
156         return 0;
157 }
158
159 void symmetry::print(const print_context & c, unsigned level) const
160 {
161         if (is_a<print_tree>(c)) {
162
163                 c.s << std::string(level, ' ') << class_name()
164                     << std::hex << ", hash=0x" << hashvalue << ", flags=0x" << flags << std::dec
165                     << ", type=";
166
167                 switch (type) {
168                         case none: c.s << "none"; break;
169                         case symmetric: c.s << "symm"; break;
170                         case antisymmetric: c.s << "anti"; break;
171                         case cyclic: c.s << "cycl"; break;
172                         default: c.s << "<unknown>"; break;
173                 }
174
175                 c.s << ", indices=(";
176                 if (!indices.empty()) {
177                         std::set<unsigned>::const_iterator i = indices.begin(), end = indices.end();
178                         --end;
179                         while (i != end)
180                                 c.s << *i++ << ",";
181                         c.s << *i;
182                 }
183                 c.s << ")\n";
184
185                 unsigned delta_indent = static_cast<const print_tree &>(c).delta_indent;
186                 exvector::const_iterator i = children.begin(), end = children.end();
187                 while (i != end) {
188                         i->print(c, level + delta_indent);
189                         ++i;
190                 }
191
192         } else {
193
194                 if (children.empty()) {
195                         if (indices.size() > 0)
196                                 c.s << *(indices.begin());
197                         else
198                                 c.s << "none";
199                 } else {
200                         switch (type) {
201                                 case none: c.s << '!'; break;
202                                 case symmetric: c.s << '+'; break;
203                                 case antisymmetric: c.s << '-'; break;
204                                 case cyclic: c.s << '@'; break;
205                                 default: c.s << '?'; break;
206                         }
207                         c.s << '(';
208                         size_t num = children.size();
209                         for (size_t i=0; i<num; i++) {
210                                 children[i].print(c);
211                                 if (i != num - 1)
212                                         c.s << ",";
213                         }
214                         c.s << ')';
215                 }
216         }
217 }
218
219 //////////
220 // non-virtual functions in this class
221 //////////
222
223 symmetry &symmetry::add(const symmetry &c)
224 {
225         // All children must have the same number of indices
226         if (type != none && !children.empty()) {
227                 GINAC_ASSERT(is_exactly_a<symmetry>(children[0]));
228                 if (ex_to<symmetry>(children[0]).indices.size() != c.indices.size())
229                         throw (std::logic_error("symmetry:add(): children must have same number of indices"));
230         }
231
232         // Compute union of indices and check whether the two sets are disjoint
233         std::set<unsigned> un;
234         set_union(indices.begin(), indices.end(), c.indices.begin(), c.indices.end(), inserter(un, un.begin()));
235         if (un.size() != indices.size() + c.indices.size())
236                 throw (std::logic_error("symmetry::add(): the same index appears in more than one child"));
237
238         // Set new index set
239         indices.swap(un);
240
241         // Add child node
242         children.push_back(c);
243         return *this;
244 }
245
246 void symmetry::validate(unsigned n)
247 {
248         if (indices.upper_bound(n - 1) != indices.end())
249                 throw (std::range_error("symmetry::verify(): index values are out of range"));
250         if (type != none && indices.empty()) {
251                 for (unsigned i=0; i<n; i++)
252                         add(i);
253         }
254 }
255
256 //////////
257 // global functions
258 //////////
259
260 class sy_is_less : public std::binary_function<ex, ex, bool> {
261         exvector::iterator v;
262
263 public:
264         sy_is_less(exvector::iterator v_) : v(v_) {}
265
266         bool operator() (const ex &lh, const ex &rh) const
267         {
268                 GINAC_ASSERT(is_exactly_a<symmetry>(lh));
269                 GINAC_ASSERT(is_exactly_a<symmetry>(rh));
270                 GINAC_ASSERT(ex_to<symmetry>(lh).indices.size() == ex_to<symmetry>(rh).indices.size());
271                 std::set<unsigned>::const_iterator ait = ex_to<symmetry>(lh).indices.begin(), aitend = ex_to<symmetry>(lh).indices.end(), bit = ex_to<symmetry>(rh).indices.begin();
272                 while (ait != aitend) {
273                         int cmpval = v[*ait].compare(v[*bit]);
274                         if (cmpval < 0)
275                                 return true;
276                         else if (cmpval > 0)
277                                 return false;
278                         ++ait; ++bit;
279                 }
280                 return false;
281         }
282 };
283
284 class sy_swap : public std::binary_function<ex, ex, void> {
285         exvector::iterator v;
286
287 public:
288         bool &swapped;
289
290         sy_swap(exvector::iterator v_, bool &s) : v(v_), swapped(s) {}
291
292         void operator() (const ex &lh, const ex &rh)
293         {
294                 GINAC_ASSERT(is_exactly_a<symmetry>(lh));
295                 GINAC_ASSERT(is_exactly_a<symmetry>(rh));
296                 GINAC_ASSERT(ex_to<symmetry>(lh).indices.size() == ex_to<symmetry>(rh).indices.size());
297                 std::set<unsigned>::const_iterator ait = ex_to<symmetry>(lh).indices.begin(), aitend = ex_to<symmetry>(lh).indices.end(), bit = ex_to<symmetry>(rh).indices.begin();
298                 while (ait != aitend) {
299                         v[*ait].swap(v[*bit]);
300                         ++ait; ++bit;
301                 }
302                 swapped = true;
303         }
304 };
305
306 int canonicalize(exvector::iterator v, const symmetry &symm)
307 {
308         // Less than two elements? Then do nothing
309         if (symm.indices.size() < 2)
310                 return INT_MAX;
311
312         // Canonicalize children first
313         bool something_changed = false;
314         int sign = 1;
315         exvector::const_iterator first = symm.children.begin(), last = symm.children.end();
316         while (first != last) {
317                 GINAC_ASSERT(is_exactly_a<symmetry>(*first));
318                 int child_sign = canonicalize(v, ex_to<symmetry>(*first));
319                 if (child_sign == 0)
320                         return 0;
321                 if (child_sign != INT_MAX) {
322                         something_changed = true;
323                         sign *= child_sign;
324                 }
325                 first++;
326         }
327
328         // Now reorder the children
329         first = symm.children.begin();
330         switch (symm.type) {
331                 case symmetry::symmetric:
332                         // Sort the children in ascending order
333                         shaker_sort(first, last, sy_is_less(v), sy_swap(v, something_changed));
334                         break;
335                 case symmetry::antisymmetric:
336                         // Sort the children in ascending order, keeping track of the signum
337                         sign *= permutation_sign(first, last, sy_is_less(v), sy_swap(v, something_changed));
338                         if (sign == 0)
339                                 return 0;
340                         break;
341                 case symmetry::cyclic:
342                         // Permute the smallest child to the front
343                         cyclic_permutation(first, last, min_element(first, last, sy_is_less(v)), sy_swap(v, something_changed));
344                         break;
345                 default:
346                         break;
347         }
348         return something_changed ? sign : INT_MAX;
349 }
350
351
352 // Symmetrize/antisymmetrize over a vector of objects
353 static ex symm(const ex & e, exvector::const_iterator first, exvector::const_iterator last, bool asymmetric)
354 {
355         // Need at least 2 objects for this operation
356         unsigned num = last - first;
357         if (num < 2)
358                 return e;
359
360         // Transform object vector to a list
361         exlist iv_lst;
362         iv_lst.insert(iv_lst.begin(), first, last);
363         lst orig_lst(iv_lst, true);
364
365         // Create index vectors for permutation
366         unsigned *iv = new unsigned[num], *iv2;
367         for (unsigned i=0; i<num; i++)
368                 iv[i] = i;
369         iv2 = (asymmetric ? new unsigned[num] : NULL);
370
371         // Loop over all permutations (the first permutation, which is the
372         // identity, is unrolled)
373         ex sum = e;
374         while (std::next_permutation(iv, iv + num)) {
375                 lst new_lst;
376                 for (unsigned i=0; i<num; i++)
377                         new_lst.append(orig_lst.op(iv[i]));
378                 ex term = e.subs(orig_lst, new_lst, subs_options::no_pattern);
379                 if (asymmetric) {
380                         memcpy(iv2, iv, num * sizeof(unsigned));
381                         term *= permutation_sign(iv2, iv2 + num);
382                 }
383                 sum += term;
384         }
385
386         delete[] iv;
387         delete[] iv2;
388
389         return sum / factorial(numeric(num));
390 }
391
392 ex symmetrize(const ex & e, exvector::const_iterator first, exvector::const_iterator last)
393 {
394         return symm(e, first, last, false);
395 }
396
397 ex antisymmetrize(const ex & e, exvector::const_iterator first, exvector::const_iterator last)
398 {
399         return symm(e, first, last, true);
400 }
401
402 ex symmetrize_cyclic(const ex & e, exvector::const_iterator first, exvector::const_iterator last)
403 {
404         // Need at least 2 objects for this operation
405         unsigned num = last - first;
406         if (num < 2)
407                 return e;
408
409         // Transform object vector to a list
410         exlist iv_lst;
411         iv_lst.insert(iv_lst.begin(), first, last);
412         lst orig_lst(iv_lst, true);
413         lst new_lst = orig_lst;
414
415         // Loop over all cyclic permutations (the first permutation, which is
416         // the identity, is unrolled)
417         ex sum = e;
418         for (unsigned i=0; i<num-1; i++) {
419                 ex perm = new_lst.op(0);
420                 new_lst.remove_first().append(perm);
421                 sum += e.subs(orig_lst, new_lst, subs_options::no_pattern);
422         }
423         return sum / num;
424 }
425
426 /** Symmetrize expression over a list of objects (symbols, indices). */
427 ex ex::symmetrize(const lst & l) const
428 {
429         exvector v(l.begin(), l.end());
430         return symm(*this, v.begin(), v.end(), false);
431 }
432
433 /** Antisymmetrize expression over a list of objects (symbols, indices). */
434 ex ex::antisymmetrize(const lst & l) const
435 {
436         exvector v(l.begin(), l.end());
437         return symm(*this, v.begin(), v.end(), true);
438 }
439
440 /** Symmetrize expression by cyclic permutation over a list of objects
441  *  (symbols, indices). */
442 ex ex::symmetrize_cyclic(const lst & l) const
443 {
444         exvector v(l.begin(), l.end());
445         return GiNaC::symmetrize_cyclic(*this, v.begin(), v.end());
446 }
447
448 } // namespace GiNaC