8a766214aaa1024d93eaa1992039622a76115887
[ginac.git] / ginac / symmetry.cpp
1 /** @file symmetry.cpp
2  *
3  *  Implementation of GiNaC's symmetry definitions. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2002 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
21  */
22
#include <algorithm>
#include <climits>
#include <functional>
#include <iostream>
#include <iterator>
#include <set>
#include <stdexcept>
#include <vector>

#include "symmetry.h"
#include "lst.h"
#include "numeric.h" // for factorial()
#include "print.h"
#include "archive.h"
#include "utils.h"
33
34 namespace GiNaC {
35
36 GINAC_IMPLEMENT_REGISTERED_CLASS(symmetry, basic)
37
38 /*
39    Some notes about the structure of a symmetry tree:
40     - The leaf nodes of the tree are of type "none", have one index, and no
41       children (of course). They are constructed by the symmetry(unsigned)
42       constructor.
43     - Leaf nodes are the only nodes that only have one index.
44     - Container nodes contain two or more children. The "indices" set member
45       is the set union of the index sets of all children, and the "children"
46       vector stores the children themselves.
47     - The index set of each child of a "symm", "anti" or "cycl" node must
48       have the same size. It follows that the children of such a node are
49       either all leaf nodes, or all container nodes with two or more indices.
50 */
51
52 //////////
53 // default ctor, dtor, copy ctor, assignment operator and helpers
54 //////////
55
56 symmetry::symmetry() : type(none)
57 {
58         tinfo_key = TINFO_symmetry;
59 }
60
61 void symmetry::copy(const symmetry & other)
62 {
63         inherited::copy(other);
64         type = other.type;
65         indices = other.indices;
66         children = other.children;
67 }
68
69 DEFAULT_DESTROY(symmetry)
70
71 //////////
72 // other constructors
73 //////////
74
75 symmetry::symmetry(unsigned i) : type(none)
76 {
77         indices.insert(i);
78         tinfo_key = TINFO_symmetry;
79 }
80
81 symmetry::symmetry(symmetry_type t, const symmetry &c1, const symmetry &c2) : type(t)
82 {
83         add(c1); add(c2);
84         tinfo_key = TINFO_symmetry;
85 }
86
87 //////////
88 // archiving
89 //////////
90
91 /** Construct object from archive_node. */
92 symmetry::symmetry(const archive_node &n, const lst &sym_lst) : inherited(n, sym_lst)
93 {
94         unsigned t;
95         if (!(n.find_unsigned("type", t)))
96                 throw (std::runtime_error("unknown symmetry type in archive"));
97         type = (symmetry_type)t;
98
99         unsigned i = 0;
100         while (true) {
101                 ex e;
102                 if (n.find_ex("child", e, sym_lst, i))
103                         add(ex_to<symmetry>(e));
104                 else
105                         break;
106                 i++;
107         }
108
109         if (i == 0) {
110                 while (true) {
111                         unsigned u;
112                         if (n.find_unsigned("index", u, i))
113                                 indices.insert(u);
114                         else
115                                 break;
116                         i++;
117                 }
118         }
119 }
120
121 /** Archive the object. */
122 void symmetry::archive(archive_node &n) const
123 {
124         inherited::archive(n);
125
126         n.add_unsigned("type", type);
127
128         if (children.empty()) {
129                 std::set<unsigned>::const_iterator i = indices.begin(), iend = indices.end();
130                 while (i != iend) {
131                         n.add_unsigned("index", *i);
132                         i++;
133                 }
134         } else {
135                 exvector::const_iterator i = children.begin(), iend = children.end();
136                 while (i != iend) {
137                         n.add_ex("child", *i);
138                         i++;
139                 }
140         }
141 }
142
143 DEFAULT_UNARCHIVE(symmetry)
144
145 //////////
146 // functions overriding virtual functions from base classes
147 //////////
148
149 int symmetry::compare_same_type(const basic & other) const
150 {
151         GINAC_ASSERT(is_a<symmetry>(other));
152
153         // All symmetry trees are equal. They are not supposed to appear in
154         // ordinary expressions anyway...
155         return 0;
156 }
157
158 void symmetry::print(const print_context & c, unsigned level) const
159 {
160         if (is_a<print_tree>(c)) {
161
162                 c.s << std::string(level, ' ') << class_name()
163                     << std::hex << ", hash=0x" << hashvalue << ", flags=0x" << flags << std::dec
164                     << ", type=";
165
166                 switch (type) {
167                         case none: c.s << "none"; break;
168                         case symmetric: c.s << "symm"; break;
169                         case antisymmetric: c.s << "anti"; break;
170                         case cyclic: c.s << "cycl"; break;
171                         default: c.s << "<unknown>"; break;
172                 }
173
174                 c.s << ", indices=(";
175                 if (!indices.empty()) {
176                         std::set<unsigned>::const_iterator i = indices.begin(), end = indices.end();
177                         --end;
178                         while (i != end)
179                                 c.s << *i++ << ",";
180                         c.s << *i;
181                 }
182                 c.s << ")\n";
183
184                 unsigned delta_indent = static_cast<const print_tree &>(c).delta_indent;
185                 exvector::const_iterator i = children.begin(), end = children.end();
186                 while (i != end) {
187                         i->print(c, level + delta_indent);
188                         ++i;
189                 }
190
191         } else {
192
193                 if (children.empty()) {
194                         if (indices.size() > 0)
195                                 c.s << *(indices.begin());
196                         else
197                                 c.s << "none";
198                 } else {
199                         switch (type) {
200                                 case none: c.s << '!'; break;
201                                 case symmetric: c.s << '+'; break;
202                                 case antisymmetric: c.s << '-'; break;
203                                 case cyclic: c.s << '@'; break;
204                                 default: c.s << '?'; break;
205                         }
206                         c.s << '(';
207                         unsigned num = children.size();
208                         for (unsigned i=0; i<num; i++) {
209                                 children[i].print(c);
210                                 if (i != num - 1)
211                                         c.s << ",";
212                         }
213                         c.s << ')';
214                 }
215         }
216 }
217
218 //////////
219 // non-virtual functions in this class
220 //////////
221
222 symmetry &symmetry::add(const symmetry &c)
223 {
224         // All children must have the same number of indices
225         if (type != none && !children.empty()) {
226                 GINAC_ASSERT(is_exactly_a<symmetry>(children[0]));
227                 if (ex_to<symmetry>(children[0]).indices.size() != c.indices.size())
228                         throw (std::logic_error("symmetry:add(): children must have same number of indices"));
229         }
230
231         // Compute union of indices and check whether the two sets are disjoint
232         std::set<unsigned> un;
233         set_union(indices.begin(), indices.end(), c.indices.begin(), c.indices.end(), inserter(un, un.begin()));
234         if (un.size() != indices.size() + c.indices.size())
235                 throw (std::logic_error("symmetry::add(): the same index appears in more than one child"));
236
237         // Set new index set
238         indices.swap(un);
239
240         // Add child node
241         children.push_back(c);
242         return *this;
243 }
244
245 void symmetry::validate(unsigned n)
246 {
247         if (indices.upper_bound(n - 1) != indices.end())
248                 throw (std::range_error("symmetry::verify(): index values are out of range"));
249         if (type != none && indices.empty()) {
250                 for (unsigned i=0; i<n; i++)
251                         add(i);
252         }
253 }
254
255 //////////
256 // global functions
257 //////////
258
259 class sy_is_less : public std::binary_function<ex, ex, bool> {
260         exvector::iterator v;
261
262 public:
263         sy_is_less(exvector::iterator v_) : v(v_) {}
264
265         bool operator() (const ex &lh, const ex &rh) const
266         {
267                 GINAC_ASSERT(is_exactly_a<symmetry>(lh));
268                 GINAC_ASSERT(is_exactly_a<symmetry>(rh));
269                 GINAC_ASSERT(ex_to<symmetry>(lh).indices.size() == ex_to<symmetry>(rh).indices.size());
270                 std::set<unsigned>::const_iterator ait = ex_to<symmetry>(lh).indices.begin(), aitend = ex_to<symmetry>(lh).indices.end(), bit = ex_to<symmetry>(rh).indices.begin();
271                 while (ait != aitend) {
272                         int cmpval = v[*ait].compare(v[*bit]);
273                         if (cmpval < 0)
274                                 return true;
275                         else if (cmpval > 0)
276                                 return false;
277                         ++ait; ++bit;
278                 }
279                 return false;
280         }
281 };
282
283 class sy_swap : public std::binary_function<ex, ex, void> {
284         exvector::iterator v;
285
286 public:
287         bool &swapped;
288
289         sy_swap(exvector::iterator v_, bool &s) : v(v_), swapped(s) {}
290
291         void operator() (const ex &lh, const ex &rh)
292         {
293                 GINAC_ASSERT(is_exactly_a<symmetry>(lh));
294                 GINAC_ASSERT(is_exactly_a<symmetry>(rh));
295                 GINAC_ASSERT(ex_to<symmetry>(lh).indices.size() == ex_to<symmetry>(rh).indices.size());
296                 std::set<unsigned>::const_iterator ait = ex_to<symmetry>(lh).indices.begin(), aitend = ex_to<symmetry>(lh).indices.end(), bit = ex_to<symmetry>(rh).indices.begin();
297                 while (ait != aitend) {
298                         v[*ait].swap(v[*bit]);
299                         ++ait; ++bit;
300                 }
301                 swapped = true;
302         }
303 };
304
305 int canonicalize(exvector::iterator v, const symmetry &symm)
306 {
307         // Less than two indices? Then do nothing
308         if (symm.indices.size() < 2)
309                 return INT_MAX;
310
311         // Canonicalize children first
312         bool something_changed = false;
313         int sign = 1;
314         exvector::const_iterator first = symm.children.begin(), last = symm.children.end();
315         while (first != last) {
316                 GINAC_ASSERT(is_exactly_a<symmetry>(*first));
317                 int child_sign = canonicalize(v, ex_to<symmetry>(*first));
318                 if (child_sign == 0)
319                         return 0;
320                 if (child_sign != INT_MAX) {
321                         something_changed = true;
322                         sign *= child_sign;
323                 }
324                 first++;
325         }
326
327         // Now reorder the children
328         first = symm.children.begin();
329         switch (symm.type) {
330                 case symmetry::symmetric:
331                         // Sort the children in ascending order
332                         shaker_sort(first, last, sy_is_less(v), sy_swap(v, something_changed));
333                         break;
334                 case symmetry::antisymmetric:
335                         // Sort the children in ascending order, keeping track of the signum
336                         sign *= permutation_sign(first, last, sy_is_less(v), sy_swap(v, something_changed));
337                         break;
338                 case symmetry::cyclic:
339                         // Permute the smallest child to the front
340                         cyclic_permutation(first, last, min_element(first, last, sy_is_less(v)), sy_swap(v, something_changed));
341                         break;
342                 default:
343                         break;
344         }
345         return something_changed ? sign : INT_MAX;
346 }
347
348
349 // Symmetrize/antisymmetrize over a vector of objects
350 static ex symm(const ex & e, exvector::const_iterator first, exvector::const_iterator last, bool asymmetric)
351 {
352         // Need at least 2 objects for this operation
353         unsigned num = last - first;
354         if (num < 2)
355                 return e;
356
357         // Transform object vector to a list
358         exlist iv_lst;
359         iv_lst.insert(iv_lst.begin(), first, last);
360         lst orig_lst(iv_lst, true);
361
362         // Create index vectors for permutation
363         unsigned *iv = new unsigned[num], *iv2;
364         for (unsigned i=0; i<num; i++)
365                 iv[i] = i;
366         iv2 = (asymmetric ? new unsigned[num] : NULL);
367
368         // Loop over all permutations (the first permutation, which is the
369         // identity, is unrolled)
370         ex sum = e;
371         while (std::next_permutation(iv, iv + num)) {
372                 lst new_lst;
373                 for (unsigned i=0; i<num; i++)
374                         new_lst.append(orig_lst.op(iv[i]));
375                 ex term = e.subs(orig_lst, new_lst);
376                 if (asymmetric) {
377                         memcpy(iv2, iv, num * sizeof(unsigned));
378                         term *= permutation_sign(iv2, iv2 + num);
379                 }
380                 sum += term;
381         }
382
383         delete[] iv;
384         delete[] iv2;
385
386         return sum / factorial(numeric(num));
387 }
388
389 ex symmetrize(const ex & e, exvector::const_iterator first, exvector::const_iterator last)
390 {
391         return symm(e, first, last, false);
392 }
393
394 ex antisymmetrize(const ex & e, exvector::const_iterator first, exvector::const_iterator last)
395 {
396         return symm(e, first, last, true);
397 }
398
399 ex symmetrize_cyclic(const ex & e, exvector::const_iterator first, exvector::const_iterator last)
400 {
401         // Need at least 2 objects for this operation
402         unsigned num = last - first;
403         if (num < 2)
404                 return e;
405
406         // Transform object vector to a list
407         exlist iv_lst;
408         iv_lst.insert(iv_lst.begin(), first, last);
409         lst orig_lst(iv_lst, true);
410         lst new_lst = orig_lst;
411
412         // Loop over all cyclic permutations (the first permutation, which is
413         // the identity, is unrolled)
414         ex sum = e;
415         for (unsigned i=0; i<num-1; i++) {
416                 ex perm = new_lst.op(0);
417                 new_lst.remove_first().append(perm);
418                 sum += e.subs(orig_lst, new_lst);
419         }
420         return sum / num;
421 }
422
423 /** Symmetrize expression over a list of objects (symbols, indices). */
424 ex ex::symmetrize(const lst & l) const
425 {
426         exvector v;
427         v.reserve(l.nops());
428         for (unsigned i=0; i<l.nops(); i++)
429                 v.push_back(l.op(i));
430         return symm(*this, v.begin(), v.end(), false);
431 }
432
433 /** Antisymmetrize expression over a list of objects (symbols, indices). */
434 ex ex::antisymmetrize(const lst & l) const
435 {
436         exvector v;
437         v.reserve(l.nops());
438         for (unsigned i=0; i<l.nops(); i++)
439                 v.push_back(l.op(i));
440         return symm(*this, v.begin(), v.end(), true);
441 }
442
443 /** Symmetrize expression by cyclic permutation over a list of objects
444  *  (symbols, indices). */
445 ex ex::symmetrize_cyclic(const lst & l) const
446 {
447         exvector v;
448         v.reserve(l.nops());
449         for (unsigned i=0; i<l.nops(); i++)
450                 v.push_back(l.op(i));
451         return GiNaC::symmetrize_cyclic(*this, v.begin(), v.end());
452 }
453
454 } // namespace GiNaC