]> www.ginac.de Git - ginac.git/blob - ginac/symmetry.cpp
[bugfix] remainder_in_ring: using exact division is plain wrong.
[ginac.git] / ginac / symmetry.cpp
1 /** @file symmetry.cpp
2  *
3  *  Implementation of GiNaC's symmetry definitions. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2008 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
21  */
22
#include <algorithm>
#include <functional>
#include <iostream>
#include <iterator>
#include <limits>
#include <stdexcept>
#include <vector>

#include "symmetry.h"
#include "lst.h"
#include "numeric.h" // for factorial()
#include "operators.h"
#include "archive.h"
#include "utils.h"
34
35 namespace GiNaC {
36
37 GINAC_IMPLEMENT_REGISTERED_CLASS_OPT(symmetry, basic,
38   print_func<print_context>(&symmetry::do_print).
39   print_func<print_tree>(&symmetry::do_print_tree))
40
/*
   Some notes about the structure of a symmetry tree:
    - The leaf nodes of the tree are of type "none", have one index, and no
      children (of course). They are constructed by the symmetry(unsigned)
      constructor.
    - The default constructor produces a "none" node with neither indices
      nor children. This is the trivial symmetry returned by not_symmetric().
    - Leaf nodes are the only nodes that have exactly one index.
    - Container nodes contain two or more children. The "indices" set member
      is the set union of the index sets of all children, and the "children"
      vector stores the children themselves.
    - The index set of each child of a "symm", "anti" or "cycl" node must
      have the same size. It follows that the children of such a node are
      either all leaf nodes, or all container nodes with two or more indices.
*/
54
55 //////////
56 // default constructor
57 //////////
58
59 symmetry::symmetry() : inherited(&symmetry::tinfo_static), type(none)
60 {
61         setflag(status_flags::evaluated | status_flags::expanded);
62 }
63
64 //////////
65 // other constructors
66 //////////
67
68 symmetry::symmetry(unsigned i) : inherited(&symmetry::tinfo_static), type(none)
69 {
70         indices.insert(i);
71         setflag(status_flags::evaluated | status_flags::expanded);
72 }
73
74 symmetry::symmetry(symmetry_type t, const symmetry &c1, const symmetry &c2) : inherited(&symmetry::tinfo_static), type(t)
75 {
76         add(c1); add(c2);
77         setflag(status_flags::evaluated | status_flags::expanded);
78 }
79
80 //////////
81 // archiving
82 //////////
83
84 /** Construct object from archive_node. */
85 symmetry::symmetry(const archive_node &n, lst &sym_lst) : inherited(n, sym_lst)
86 {
87         unsigned t;
88         if (!(n.find_unsigned("type", t)))
89                 throw (std::runtime_error("unknown symmetry type in archive"));
90         type = (symmetry_type)t;
91
92         unsigned i = 0;
93         while (true) {
94                 ex e;
95                 if (n.find_ex("child", e, sym_lst, i))
96                         add(ex_to<symmetry>(e));
97                 else
98                         break;
99                 i++;
100         }
101
102         if (i == 0) {
103                 while (true) {
104                         unsigned u;
105                         if (n.find_unsigned("index", u, i))
106                                 indices.insert(u);
107                         else
108                                 break;
109                         i++;
110                 }
111         }
112 }
113
114 /** Archive the object. */
115 void symmetry::archive(archive_node &n) const
116 {
117         inherited::archive(n);
118
119         n.add_unsigned("type", type);
120
121         if (children.empty()) {
122                 std::set<unsigned>::const_iterator i = indices.begin(), iend = indices.end();
123                 while (i != iend) {
124                         n.add_unsigned("index", *i);
125                         i++;
126                 }
127         } else {
128                 exvector::const_iterator i = children.begin(), iend = children.end();
129                 while (i != iend) {
130                         n.add_ex("child", *i);
131                         i++;
132                 }
133         }
134 }
135
136 DEFAULT_UNARCHIVE(symmetry)
137
138 //////////
139 // functions overriding virtual functions from base classes
140 //////////
141
142 int symmetry::compare_same_type(const basic & other) const
143 {
144         GINAC_ASSERT(is_a<symmetry>(other));
145
146         // For archiving purposes we need to have an ordering of symmetries.
147         const symmetry &othersymm = ex_to<symmetry>(other);
148
149         // Compare type.
150         if (type > othersymm.type)
151                 return 1;
152         if (type < othersymm.type)
153                 return -1;
154
155         // Compare the index set.
156         size_t this_size = indices.size();
157         size_t that_size = othersymm.indices.size();
158         if (this_size > that_size)
159                 return 1;
160         if (this_size < that_size)
161                 return -1;
162         typedef std::set<unsigned>::iterator set_it;
163         set_it end = indices.end();
164         for (set_it i=indices.begin(),j=othersymm.indices.begin(); i!=end; ++i,++j) {
165                 if(*i < *j)
166                         return 1;
167                 if(*i > *j)
168                         return -1;
169         }
170
171         // Compare the children.
172         if (children.size() > othersymm.children.size())
173                 return 1;
174         if (children.size() < othersymm.children.size())
175                 return -1;
176         for (size_t i=0; i<children.size(); ++i) {
177                 int cmpval = ex_to<symmetry>(children[i])
178                         .compare_same_type(ex_to<symmetry>(othersymm.children[i]));
179                 if (cmpval)
180                         return cmpval;
181         }
182
183         return 0;
184 }
185
186 unsigned symmetry::calchash() const
187 {
188         unsigned v = golden_ratio_hash((p_int)tinfo());
189
190         if (type == none) {
191                 v = rotate_left(v);
192                 v ^= *(indices.begin());
193         } else {
194                 for (exvector::const_iterator i=children.begin(); i!=children.end(); ++i)
195                 {
196                         v = rotate_left(v);
197                         v ^= i->gethash();
198                 }
199         }
200
201         if (flags & status_flags::evaluated) {
202                 setflag(status_flags::hash_calculated);
203                 hashvalue = v;
204         }
205
206         return v;
207 }
208
209 void symmetry::do_print(const print_context & c, unsigned level) const
210 {
211         if (children.empty()) {
212                 if (indices.size() > 0)
213                         c.s << *(indices.begin());
214                 else
215                         c.s << "none";
216         } else {
217                 switch (type) {
218                         case none: c.s << '!'; break;
219                         case symmetric: c.s << '+'; break;
220                         case antisymmetric: c.s << '-'; break;
221                         case cyclic: c.s << '@'; break;
222                         default: c.s << '?'; break;
223                 }
224                 c.s << '(';
225                 size_t num = children.size();
226                 for (size_t i=0; i<num; i++) {
227                         children[i].print(c);
228                         if (i != num - 1)
229                                 c.s << ",";
230                 }
231                 c.s << ')';
232         }
233 }
234
235 void symmetry::do_print_tree(const print_tree & c, unsigned level) const
236 {
237         c.s << std::string(level, ' ') << class_name() << " @" << this
238             << std::hex << ", hash=0x" << hashvalue << ", flags=0x" << flags << std::dec
239             << ", type=";
240
241         switch (type) {
242                 case none: c.s << "none"; break;
243                 case symmetric: c.s << "symm"; break;
244                 case antisymmetric: c.s << "anti"; break;
245                 case cyclic: c.s << "cycl"; break;
246                 default: c.s << "<unknown>"; break;
247         }
248
249         c.s << ", indices=(";
250         if (!indices.empty()) {
251                 std::set<unsigned>::const_iterator i = indices.begin(), end = indices.end();
252                 --end;
253                 while (i != end)
254                         c.s << *i++ << ",";
255                 c.s << *i;
256         }
257         c.s << ")\n";
258
259         exvector::const_iterator i = children.begin(), end = children.end();
260         while (i != end) {
261                 i->print(c, level + c.delta_indent);
262                 ++i;
263         }
264 }
265
266 //////////
267 // non-virtual functions in this class
268 //////////
269
270 bool symmetry::has_cyclic() const
271 {
272         if (type == cyclic)
273                 return true;
274
275         for (exvector::const_iterator i=children.begin(); i!=children.end(); ++i)
276                 if (ex_to<symmetry>(*i).has_cyclic())
277                         return true;
278
279         return false;
280 }
281
282 symmetry &symmetry::add(const symmetry &c)
283 {
284         // All children must have the same number of indices
285         if (type != none && !children.empty()) {
286                 GINAC_ASSERT(is_exactly_a<symmetry>(children[0]));
287                 if (ex_to<symmetry>(children[0]).indices.size() != c.indices.size())
288                         throw (std::logic_error("symmetry:add(): children must have same number of indices"));
289         }
290
291         // Compute union of indices and check whether the two sets are disjoint
292         std::set<unsigned> un;
293         set_union(indices.begin(), indices.end(), c.indices.begin(), c.indices.end(), inserter(un, un.begin()));
294         if (un.size() != indices.size() + c.indices.size())
295                 throw (std::logic_error("symmetry::add(): the same index appears in more than one child"));
296
297         // Set new index set
298         indices.swap(un);
299
300         // Add child node
301         children.push_back(c);
302         return *this;
303 }
304
305 void symmetry::validate(unsigned n)
306 {
307         if (indices.upper_bound(n - 1) != indices.end())
308                 throw (std::range_error("symmetry::verify(): index values are out of range"));
309         if (type != none && indices.empty()) {
310                 for (unsigned i=0; i<n; i++)
311                         add(i);
312         }
313 }
314
315 //////////
316 // global functions
317 //////////
318
319 static const symmetry & index0()
320 {
321         static ex s = (new symmetry(0))->setflag(status_flags::dynallocated);
322         return ex_to<symmetry>(s);
323 }
324
325 static const symmetry & index1()
326 {
327         static ex s = (new symmetry(1))->setflag(status_flags::dynallocated);
328         return ex_to<symmetry>(s);
329 }
330
331 static const symmetry & index2()
332 {
333         static ex s = (new symmetry(2))->setflag(status_flags::dynallocated);
334         return ex_to<symmetry>(s);
335 }
336
337 static const symmetry & index3()
338 {
339         static ex s = (new symmetry(3))->setflag(status_flags::dynallocated);
340         return ex_to<symmetry>(s);
341 }
342
343 const symmetry & not_symmetric()
344 {
345         static ex s = (new symmetry)->setflag(status_flags::dynallocated);
346         return ex_to<symmetry>(s);
347 }
348
349 const symmetry & symmetric2()
350 {
351         static ex s = (new symmetry(symmetry::symmetric, index0(), index1()))->setflag(status_flags::dynallocated);
352         return ex_to<symmetry>(s);
353 }
354
355 const symmetry & symmetric3()
356 {
357         static ex s = (new symmetry(symmetry::symmetric, index0(), index1()))->add(index2()).setflag(status_flags::dynallocated);
358         return ex_to<symmetry>(s);
359 }
360
361 const symmetry & symmetric4()
362 {
363         static ex s = (new symmetry(symmetry::symmetric, index0(), index1()))->add(index2()).add(index3()).setflag(status_flags::dynallocated);
364         return ex_to<symmetry>(s);
365 }
366
367 const symmetry & antisymmetric2()
368 {
369         static ex s = (new symmetry(symmetry::antisymmetric, index0(), index1()))->setflag(status_flags::dynallocated);
370         return ex_to<symmetry>(s);
371 }
372
373 const symmetry & antisymmetric3()
374 {
375         static ex s = (new symmetry(symmetry::antisymmetric, index0(), index1()))->add(index2()).setflag(status_flags::dynallocated);
376         return ex_to<symmetry>(s);
377 }
378
379 const symmetry & antisymmetric4()
380 {
381         static ex s = (new symmetry(symmetry::antisymmetric, index0(), index1()))->add(index2()).add(index3()).setflag(status_flags::dynallocated);
382         return ex_to<symmetry>(s);
383 }
384
385 class sy_is_less : public std::binary_function<ex, ex, bool> {
386         exvector::iterator v;
387
388 public:
389         sy_is_less(exvector::iterator v_) : v(v_) {}
390
391         bool operator() (const ex &lh, const ex &rh) const
392         {
393                 GINAC_ASSERT(is_exactly_a<symmetry>(lh));
394                 GINAC_ASSERT(is_exactly_a<symmetry>(rh));
395                 GINAC_ASSERT(ex_to<symmetry>(lh).indices.size() == ex_to<symmetry>(rh).indices.size());
396                 std::set<unsigned>::const_iterator ait = ex_to<symmetry>(lh).indices.begin(), aitend = ex_to<symmetry>(lh).indices.end(), bit = ex_to<symmetry>(rh).indices.begin();
397                 while (ait != aitend) {
398                         int cmpval = v[*ait].compare(v[*bit]);
399                         if (cmpval < 0)
400                                 return true;
401                         else if (cmpval > 0)
402                                 return false;
403                         ++ait; ++bit;
404                 }
405                 return false;
406         }
407 };
408
409 class sy_swap : public std::binary_function<ex, ex, void> {
410         exvector::iterator v;
411
412 public:
413         bool &swapped;
414
415         sy_swap(exvector::iterator v_, bool &s) : v(v_), swapped(s) {}
416
417         void operator() (const ex &lh, const ex &rh)
418         {
419                 GINAC_ASSERT(is_exactly_a<symmetry>(lh));
420                 GINAC_ASSERT(is_exactly_a<symmetry>(rh));
421                 GINAC_ASSERT(ex_to<symmetry>(lh).indices.size() == ex_to<symmetry>(rh).indices.size());
422                 std::set<unsigned>::const_iterator ait = ex_to<symmetry>(lh).indices.begin(), aitend = ex_to<symmetry>(lh).indices.end(), bit = ex_to<symmetry>(rh).indices.begin();
423                 while (ait != aitend) {
424                         v[*ait].swap(v[*bit]);
425                         ++ait; ++bit;
426                 }
427                 swapped = true;
428         }
429 };
430
431 int canonicalize(exvector::iterator v, const symmetry &symm)
432 {
433         // Less than two elements? Then do nothing
434         if (symm.indices.size() < 2)
435                 return std::numeric_limits<int>::max();
436
437         // Canonicalize children first
438         bool something_changed = false;
439         int sign = 1;
440         exvector::const_iterator first = symm.children.begin(), last = symm.children.end();
441         while (first != last) {
442                 GINAC_ASSERT(is_exactly_a<symmetry>(*first));
443                 int child_sign = canonicalize(v, ex_to<symmetry>(*first));
444                 if (child_sign == 0)
445                         return 0;
446                 if (child_sign != std::numeric_limits<int>::max()) {
447                         something_changed = true;
448                         sign *= child_sign;
449                 }
450                 first++;
451         }
452
453         // Now reorder the children
454         first = symm.children.begin();
455         switch (symm.type) {
456                 case symmetry::symmetric:
457                         // Sort the children in ascending order
458                         shaker_sort(first, last, sy_is_less(v), sy_swap(v, something_changed));
459                         break;
460                 case symmetry::antisymmetric:
461                         // Sort the children in ascending order, keeping track of the signum
462                         sign *= permutation_sign(first, last, sy_is_less(v), sy_swap(v, something_changed));
463                         if (sign == 0)
464                                 return 0;
465                         break;
466                 case symmetry::cyclic:
467                         // Permute the smallest child to the front
468                         cyclic_permutation(first, last, min_element(first, last, sy_is_less(v)), sy_swap(v, something_changed));
469                         break;
470                 default:
471                         break;
472         }
473         return something_changed ? sign : std::numeric_limits<int>::max();
474 }
475
476
477 // Symmetrize/antisymmetrize over a vector of objects
478 static ex symm(const ex & e, exvector::const_iterator first, exvector::const_iterator last, bool asymmetric)
479 {
480         // Need at least 2 objects for this operation
481         unsigned num = last - first;
482         if (num < 2)
483                 return e;
484
485         // Transform object vector to a lst (for subs())
486         lst orig_lst(first, last);
487
488         // Create index vectors for permutation
489         unsigned *iv = new unsigned[num], *iv2;
490         for (unsigned i=0; i<num; i++)
491                 iv[i] = i;
492         iv2 = (asymmetric ? new unsigned[num] : NULL);
493
494         // Loop over all permutations (the first permutation, which is the
495         // identity, is unrolled)
496         ex sum = e;
497         while (std::next_permutation(iv, iv + num)) {
498                 lst new_lst;
499                 for (unsigned i=0; i<num; i++)
500                         new_lst.append(orig_lst.op(iv[i]));
501                 ex term = e.subs(orig_lst, new_lst, subs_options::no_pattern|subs_options::no_index_renaming);
502                 if (asymmetric) {
503                         memcpy(iv2, iv, num * sizeof(unsigned));
504                         term *= permutation_sign(iv2, iv2 + num);
505                 }
506                 sum += term;
507         }
508
509         delete[] iv;
510         delete[] iv2;
511
512         return sum / factorial(numeric(num));
513 }
514
515 ex symmetrize(const ex & e, exvector::const_iterator first, exvector::const_iterator last)
516 {
517         return symm(e, first, last, false);
518 }
519
520 ex antisymmetrize(const ex & e, exvector::const_iterator first, exvector::const_iterator last)
521 {
522         return symm(e, first, last, true);
523 }
524
525 ex symmetrize_cyclic(const ex & e, exvector::const_iterator first, exvector::const_iterator last)
526 {
527         // Need at least 2 objects for this operation
528         unsigned num = last - first;
529         if (num < 2)
530                 return e;
531
532         // Transform object vector to a lst (for subs())
533         lst orig_lst(first, last);
534         lst new_lst = orig_lst;
535
536         // Loop over all cyclic permutations (the first permutation, which is
537         // the identity, is unrolled)
538         ex sum = e;
539         for (unsigned i=0; i<num-1; i++) {
540                 ex perm = new_lst.op(0);
541                 new_lst.remove_first().append(perm);
542                 sum += e.subs(orig_lst, new_lst, subs_options::no_pattern|subs_options::no_index_renaming);
543         }
544         return sum / num;
545 }
546
547 /** Symmetrize expression over a list of objects (symbols, indices). */
548 ex ex::symmetrize(const lst & l) const
549 {
550         exvector v(l.begin(), l.end());
551         return symm(*this, v.begin(), v.end(), false);
552 }
553
554 /** Antisymmetrize expression over a list of objects (symbols, indices). */
555 ex ex::antisymmetrize(const lst & l) const
556 {
557         exvector v(l.begin(), l.end());
558         return symm(*this, v.begin(), v.end(), true);
559 }
560
561 /** Symmetrize expression by cyclic permutation over a list of objects
562  *  (symbols, indices). */
563 ex ex::symmetrize_cyclic(const lst & l) const
564 {
565         exvector v(l.begin(), l.end());
566         return GiNaC::symmetrize_cyclic(*this, v.begin(), v.end());
567 }
568
569 } // namespace GiNaC