Univariate Hensel lifting now uses upoly.
[ginac.git] / ginac / symmetry.cpp
1 /** @file symmetry.cpp
2  *
3  *  Implementation of GiNaC's symmetry definitions. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2008 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
21  */
22
23 #include <iostream>
24 #include <stdexcept>
25 #include <functional>
26 #include <limits>
27
28 #include "symmetry.h"
29 #include "lst.h"
30 #include "numeric.h" // for factorial()
31 #include "operators.h"
32 #include "archive.h"
33 #include "utils.h"
34
35 namespace GiNaC {
36
// Register symmetry as a GiNaC class derived from basic. Output to a plain
// print_context is dispatched to do_print(), output to a print_tree context
// to do_print_tree().
GINAC_IMPLEMENT_REGISTERED_CLASS_OPT(symmetry, basic,
  print_func<print_context>(&symmetry::do_print).
  print_func<print_tree>(&symmetry::do_print_tree))
40
/*
   Some notes about the structure of a symmetry tree:
    - The leaf nodes of the tree are of type "none", have exactly one index,
      and no children. They are constructed by the symmetry(unsigned)
      constructor.
    - Leaf nodes are the only nodes with exactly one index.
    - A default-constructed symmetry (as returned by not_symmetric()) is also
      of type "none", but has neither indices nor children; it stands for
      "no symmetry at all".
    - Container nodes contain two or more children. The "indices" set member
      is the set union of the index sets of all children, and the "children"
      vector stores the children themselves.
    - The index set of each child of a "symm", "anti" or "cycl" node must
      have the same size. It follows that the children of such a node are
      either all leaf nodes, or all container nodes with two or more indices.
*/
54
55 //////////
56 // default constructor
57 //////////
58
59 symmetry::symmetry() :  type(none)
60 {
61         setflag(status_flags::evaluated | status_flags::expanded);
62 }
63
64 //////////
65 // other constructors
66 //////////
67
68 symmetry::symmetry(unsigned i) :  type(none)
69 {
70         indices.insert(i);
71         setflag(status_flags::evaluated | status_flags::expanded);
72 }
73
74 symmetry::symmetry(symmetry_type t, const symmetry &c1, const symmetry &c2) :  type(t)
75 {
76         add(c1); add(c2);
77         setflag(status_flags::evaluated | status_flags::expanded);
78 }
79
80 //////////
81 // archiving
82 //////////
83
84 /** Construct object from archive_node. */
85 void symmetry::read_archive(const archive_node &n, lst &sym_lst)
86 {
87         inherited::read_archive(n, sym_lst);
88         unsigned t;
89         if (!(n.find_unsigned("type", t)))
90                 throw (std::runtime_error("unknown symmetry type in archive"));
91         type = (symmetry_type)t;
92
93         unsigned i = 0;
94         while (true) {
95                 ex e;
96                 if (n.find_ex("child", e, sym_lst, i))
97                         add(ex_to<symmetry>(e));
98                 else
99                         break;
100                 i++;
101         }
102
103         if (i == 0) {
104                 while (true) {
105                         unsigned u;
106                         if (n.find_unsigned("index", u, i))
107                                 indices.insert(u);
108                         else
109                                 break;
110                         i++;
111                 }
112         }
113 }
114 GINAC_BIND_UNARCHIVER(symmetry);
115
116 /** Archive the object. */
117 void symmetry::archive(archive_node &n) const
118 {
119         inherited::archive(n);
120
121         n.add_unsigned("type", type);
122
123         if (children.empty()) {
124                 std::set<unsigned>::const_iterator i = indices.begin(), iend = indices.end();
125                 while (i != iend) {
126                         n.add_unsigned("index", *i);
127                         i++;
128                 }
129         } else {
130                 exvector::const_iterator i = children.begin(), iend = children.end();
131                 while (i != iend) {
132                         n.add_ex("child", *i);
133                         i++;
134                 }
135         }
136 }
137
138 //////////
139 // functions overriding virtual functions from base classes
140 //////////
141
142 int symmetry::compare_same_type(const basic & other) const
143 {
144         GINAC_ASSERT(is_a<symmetry>(other));
145
146         // For archiving purposes we need to have an ordering of symmetries.
147         const symmetry &othersymm = ex_to<symmetry>(other);
148
149         // Compare type.
150         if (type > othersymm.type)
151                 return 1;
152         if (type < othersymm.type)
153                 return -1;
154
155         // Compare the index set.
156         size_t this_size = indices.size();
157         size_t that_size = othersymm.indices.size();
158         if (this_size > that_size)
159                 return 1;
160         if (this_size < that_size)
161                 return -1;
162         typedef std::set<unsigned>::iterator set_it;
163         set_it end = indices.end();
164         for (set_it i=indices.begin(),j=othersymm.indices.begin(); i!=end; ++i,++j) {
165                 if(*i < *j)
166                         return 1;
167                 if(*i > *j)
168                         return -1;
169         }
170
171         // Compare the children.
172         if (children.size() > othersymm.children.size())
173                 return 1;
174         if (children.size() < othersymm.children.size())
175                 return -1;
176         for (size_t i=0; i<children.size(); ++i) {
177                 int cmpval = ex_to<symmetry>(children[i])
178                         .compare_same_type(ex_to<symmetry>(othersymm.children[i]));
179                 if (cmpval)
180                         return cmpval;
181         }
182
183         return 0;
184 }
185
186 unsigned symmetry::calchash() const
187 {
188         const void* this_tinfo = (const void*)typeid(*this).name();
189         unsigned v = golden_ratio_hash((p_int)this_tinfo);
190
191         if (type == none) {
192                 v = rotate_left(v);
193                 v ^= *(indices.begin());
194         } else {
195                 for (exvector::const_iterator i=children.begin(); i!=children.end(); ++i)
196                 {
197                         v = rotate_left(v);
198                         v ^= i->gethash();
199                 }
200         }
201
202         if (flags & status_flags::evaluated) {
203                 setflag(status_flags::hash_calculated);
204                 hashvalue = v;
205         }
206
207         return v;
208 }
209
210 void symmetry::do_print(const print_context & c, unsigned level) const
211 {
212         if (children.empty()) {
213                 if (indices.size() > 0)
214                         c.s << *(indices.begin());
215                 else
216                         c.s << "none";
217         } else {
218                 switch (type) {
219                         case none: c.s << '!'; break;
220                         case symmetric: c.s << '+'; break;
221                         case antisymmetric: c.s << '-'; break;
222                         case cyclic: c.s << '@'; break;
223                         default: c.s << '?'; break;
224                 }
225                 c.s << '(';
226                 size_t num = children.size();
227                 for (size_t i=0; i<num; i++) {
228                         children[i].print(c);
229                         if (i != num - 1)
230                                 c.s << ",";
231                 }
232                 c.s << ')';
233         }
234 }
235
236 void symmetry::do_print_tree(const print_tree & c, unsigned level) const
237 {
238         c.s << std::string(level, ' ') << class_name() << " @" << this
239             << std::hex << ", hash=0x" << hashvalue << ", flags=0x" << flags << std::dec
240             << ", type=";
241
242         switch (type) {
243                 case none: c.s << "none"; break;
244                 case symmetric: c.s << "symm"; break;
245                 case antisymmetric: c.s << "anti"; break;
246                 case cyclic: c.s << "cycl"; break;
247                 default: c.s << "<unknown>"; break;
248         }
249
250         c.s << ", indices=(";
251         if (!indices.empty()) {
252                 std::set<unsigned>::const_iterator i = indices.begin(), end = indices.end();
253                 --end;
254                 while (i != end)
255                         c.s << *i++ << ",";
256                 c.s << *i;
257         }
258         c.s << ")\n";
259
260         exvector::const_iterator i = children.begin(), end = children.end();
261         while (i != end) {
262                 i->print(c, level + c.delta_indent);
263                 ++i;
264         }
265 }
266
267 //////////
268 // non-virtual functions in this class
269 //////////
270
271 bool symmetry::has_cyclic() const
272 {
273         if (type == cyclic)
274                 return true;
275
276         for (exvector::const_iterator i=children.begin(); i!=children.end(); ++i)
277                 if (ex_to<symmetry>(*i).has_cyclic())
278                         return true;
279
280         return false;
281 }
282
283 symmetry &symmetry::add(const symmetry &c)
284 {
285         // All children must have the same number of indices
286         if (type != none && !children.empty()) {
287                 GINAC_ASSERT(is_exactly_a<symmetry>(children[0]));
288                 if (ex_to<symmetry>(children[0]).indices.size() != c.indices.size())
289                         throw (std::logic_error("symmetry:add(): children must have same number of indices"));
290         }
291
292         // Compute union of indices and check whether the two sets are disjoint
293         std::set<unsigned> un;
294         set_union(indices.begin(), indices.end(), c.indices.begin(), c.indices.end(), inserter(un, un.begin()));
295         if (un.size() != indices.size() + c.indices.size())
296                 throw (std::logic_error("symmetry::add(): the same index appears in more than one child"));
297
298         // Set new index set
299         indices.swap(un);
300
301         // Add child node
302         children.push_back(c);
303         return *this;
304 }
305
306 void symmetry::validate(unsigned n)
307 {
308         if (indices.upper_bound(n - 1) != indices.end())
309                 throw (std::range_error("symmetry::verify(): index values are out of range"));
310         if (type != none && indices.empty()) {
311                 for (unsigned i=0; i<n; i++)
312                         add(i);
313         }
314 }
315
316 //////////
317 // global functions
318 //////////
319
320 static const symmetry & index0()
321 {
322         static ex s = (new symmetry(0))->setflag(status_flags::dynallocated);
323         return ex_to<symmetry>(s);
324 }
325
326 static const symmetry & index1()
327 {
328         static ex s = (new symmetry(1))->setflag(status_flags::dynallocated);
329         return ex_to<symmetry>(s);
330 }
331
332 static const symmetry & index2()
333 {
334         static ex s = (new symmetry(2))->setflag(status_flags::dynallocated);
335         return ex_to<symmetry>(s);
336 }
337
338 static const symmetry & index3()
339 {
340         static ex s = (new symmetry(3))->setflag(status_flags::dynallocated);
341         return ex_to<symmetry>(s);
342 }
343
344 const symmetry & not_symmetric()
345 {
346         static ex s = (new symmetry)->setflag(status_flags::dynallocated);
347         return ex_to<symmetry>(s);
348 }
349
350 const symmetry & symmetric2()
351 {
352         static ex s = (new symmetry(symmetry::symmetric, index0(), index1()))->setflag(status_flags::dynallocated);
353         return ex_to<symmetry>(s);
354 }
355
356 const symmetry & symmetric3()
357 {
358         static ex s = (new symmetry(symmetry::symmetric, index0(), index1()))->add(index2()).setflag(status_flags::dynallocated);
359         return ex_to<symmetry>(s);
360 }
361
362 const symmetry & symmetric4()
363 {
364         static ex s = (new symmetry(symmetry::symmetric, index0(), index1()))->add(index2()).add(index3()).setflag(status_flags::dynallocated);
365         return ex_to<symmetry>(s);
366 }
367
368 const symmetry & antisymmetric2()
369 {
370         static ex s = (new symmetry(symmetry::antisymmetric, index0(), index1()))->setflag(status_flags::dynallocated);
371         return ex_to<symmetry>(s);
372 }
373
374 const symmetry & antisymmetric3()
375 {
376         static ex s = (new symmetry(symmetry::antisymmetric, index0(), index1()))->add(index2()).setflag(status_flags::dynallocated);
377         return ex_to<symmetry>(s);
378 }
379
380 const symmetry & antisymmetric4()
381 {
382         static ex s = (new symmetry(symmetry::antisymmetric, index0(), index1()))->add(index2()).add(index3()).setflag(status_flags::dynallocated);
383         return ex_to<symmetry>(s);
384 }
385
386 class sy_is_less : public std::binary_function<ex, ex, bool> {
387         exvector::iterator v;
388
389 public:
390         sy_is_less(exvector::iterator v_) : v(v_) {}
391
392         bool operator() (const ex &lh, const ex &rh) const
393         {
394                 GINAC_ASSERT(is_exactly_a<symmetry>(lh));
395                 GINAC_ASSERT(is_exactly_a<symmetry>(rh));
396                 GINAC_ASSERT(ex_to<symmetry>(lh).indices.size() == ex_to<symmetry>(rh).indices.size());
397                 std::set<unsigned>::const_iterator ait = ex_to<symmetry>(lh).indices.begin(), aitend = ex_to<symmetry>(lh).indices.end(), bit = ex_to<symmetry>(rh).indices.begin();
398                 while (ait != aitend) {
399                         int cmpval = v[*ait].compare(v[*bit]);
400                         if (cmpval < 0)
401                                 return true;
402                         else if (cmpval > 0)
403                                 return false;
404                         ++ait; ++bit;
405                 }
406                 return false;
407         }
408 };
409
410 class sy_swap : public std::binary_function<ex, ex, void> {
411         exvector::iterator v;
412
413 public:
414         bool &swapped;
415
416         sy_swap(exvector::iterator v_, bool &s) : v(v_), swapped(s) {}
417
418         void operator() (const ex &lh, const ex &rh)
419         {
420                 GINAC_ASSERT(is_exactly_a<symmetry>(lh));
421                 GINAC_ASSERT(is_exactly_a<symmetry>(rh));
422                 GINAC_ASSERT(ex_to<symmetry>(lh).indices.size() == ex_to<symmetry>(rh).indices.size());
423                 std::set<unsigned>::const_iterator ait = ex_to<symmetry>(lh).indices.begin(), aitend = ex_to<symmetry>(lh).indices.end(), bit = ex_to<symmetry>(rh).indices.begin();
424                 while (ait != aitend) {
425                         v[*ait].swap(v[*bit]);
426                         ++ait; ++bit;
427                 }
428                 swapped = true;
429         }
430 };
431
432 int canonicalize(exvector::iterator v, const symmetry &symm)
433 {
434         // Less than two elements? Then do nothing
435         if (symm.indices.size() < 2)
436                 return std::numeric_limits<int>::max();
437
438         // Canonicalize children first
439         bool something_changed = false;
440         int sign = 1;
441         exvector::const_iterator first = symm.children.begin(), last = symm.children.end();
442         while (first != last) {
443                 GINAC_ASSERT(is_exactly_a<symmetry>(*first));
444                 int child_sign = canonicalize(v, ex_to<symmetry>(*first));
445                 if (child_sign == 0)
446                         return 0;
447                 if (child_sign != std::numeric_limits<int>::max()) {
448                         something_changed = true;
449                         sign *= child_sign;
450                 }
451                 first++;
452         }
453
454         // Now reorder the children
455         first = symm.children.begin();
456         switch (symm.type) {
457                 case symmetry::symmetric:
458                         // Sort the children in ascending order
459                         shaker_sort(first, last, sy_is_less(v), sy_swap(v, something_changed));
460                         break;
461                 case symmetry::antisymmetric:
462                         // Sort the children in ascending order, keeping track of the signum
463                         sign *= permutation_sign(first, last, sy_is_less(v), sy_swap(v, something_changed));
464                         if (sign == 0)
465                                 return 0;
466                         break;
467                 case symmetry::cyclic:
468                         // Permute the smallest child to the front
469                         cyclic_permutation(first, last, min_element(first, last, sy_is_less(v)), sy_swap(v, something_changed));
470                         break;
471                 default:
472                         break;
473         }
474         return something_changed ? sign : std::numeric_limits<int>::max();
475 }
476
477
478 // Symmetrize/antisymmetrize over a vector of objects
479 static ex symm(const ex & e, exvector::const_iterator first, exvector::const_iterator last, bool asymmetric)
480 {
481         // Need at least 2 objects for this operation
482         unsigned num = last - first;
483         if (num < 2)
484                 return e;
485
486         // Transform object vector to a lst (for subs())
487         lst orig_lst(first, last);
488
489         // Create index vectors for permutation
490         unsigned *iv = new unsigned[num], *iv2;
491         for (unsigned i=0; i<num; i++)
492                 iv[i] = i;
493         iv2 = (asymmetric ? new unsigned[num] : NULL);
494
495         // Loop over all permutations (the first permutation, which is the
496         // identity, is unrolled)
497         ex sum = e;
498         while (std::next_permutation(iv, iv + num)) {
499                 lst new_lst;
500                 for (unsigned i=0; i<num; i++)
501                         new_lst.append(orig_lst.op(iv[i]));
502                 ex term = e.subs(orig_lst, new_lst, subs_options::no_pattern|subs_options::no_index_renaming);
503                 if (asymmetric) {
504                         memcpy(iv2, iv, num * sizeof(unsigned));
505                         term *= permutation_sign(iv2, iv2 + num);
506                 }
507                 sum += term;
508         }
509
510         delete[] iv;
511         delete[] iv2;
512
513         return sum / factorial(numeric(num));
514 }
515
516 ex symmetrize(const ex & e, exvector::const_iterator first, exvector::const_iterator last)
517 {
518         return symm(e, first, last, false);
519 }
520
521 ex antisymmetrize(const ex & e, exvector::const_iterator first, exvector::const_iterator last)
522 {
523         return symm(e, first, last, true);
524 }
525
526 ex symmetrize_cyclic(const ex & e, exvector::const_iterator first, exvector::const_iterator last)
527 {
528         // Need at least 2 objects for this operation
529         unsigned num = last - first;
530         if (num < 2)
531                 return e;
532
533         // Transform object vector to a lst (for subs())
534         lst orig_lst(first, last);
535         lst new_lst = orig_lst;
536
537         // Loop over all cyclic permutations (the first permutation, which is
538         // the identity, is unrolled)
539         ex sum = e;
540         for (unsigned i=0; i<num-1; i++) {
541                 ex perm = new_lst.op(0);
542                 new_lst.remove_first().append(perm);
543                 sum += e.subs(orig_lst, new_lst, subs_options::no_pattern|subs_options::no_index_renaming);
544         }
545         return sum / num;
546 }
547
548 /** Symmetrize expression over a list of objects (symbols, indices). */
549 ex ex::symmetrize(const lst & l) const
550 {
551         exvector v(l.begin(), l.end());
552         return symm(*this, v.begin(), v.end(), false);
553 }
554
555 /** Antisymmetrize expression over a list of objects (symbols, indices). */
556 ex ex::antisymmetrize(const lst & l) const
557 {
558         exvector v(l.begin(), l.end());
559         return symm(*this, v.begin(), v.end(), true);
560 }
561
562 /** Symmetrize expression by cyclic permutation over a list of objects
563  *  (symbols, indices). */
564 ex ex::symmetrize_cyclic(const lst & l) const
565 {
566         exvector v(l.begin(), l.end());
567         return GiNaC::symmetrize_cyclic(*this, v.begin(), v.end());
568 }
569
570 } // namespace GiNaC