]> www.ginac.de Git - ginac.git/blob - ginac/symmetry.cpp
G_eval: fix incorrect use of STL iterator.
[ginac.git] / ginac / symmetry.cpp
1 /** @file symmetry.cpp
2  *
3  *  Implementation of GiNaC's symmetry definitions. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2009 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
21  */
22
#include "symmetry.h"
#include "lst.h"
#include "numeric.h" // for factorial()
#include "operators.h"
#include "archive.h"
#include "utils.h"
#include "hash_seed.h"

#include <algorithm>
#include <functional>
#include <iostream>
#include <iterator>
#include <limits>
#include <stdexcept>
#include <vector>
35
36 namespace GiNaC {
37
// Register class symmetry (with parent class basic) and install its print
// methods: do_print() for generic print_context output and do_print_tree()
// for print_tree (debug tree) output.
GINAC_IMPLEMENT_REGISTERED_CLASS_OPT(symmetry, basic,
  print_func<print_context>(&symmetry::do_print).
  print_func<print_tree>(&symmetry::do_print_tree))
41
/*
   Some notes about the structure of a symmetry tree:
    - The leaf nodes of the tree are of type "none", have exactly one index,
      and (of course) no children. They are constructed by the
      symmetry(unsigned) constructor.
    - Leaf nodes are the only nodes that have exactly one index.
    - The default constructor creates an empty node of type "none" with no
      indices and no children (this is what not_symmetric() returns).
    - Container nodes contain two or more children. The "indices" set member
      is the set union of the index sets of all children, and the "children"
      vector stores the children themselves.
    - The index sets of all children of a "symm", "anti" or "cycl" node must
      have the same size. It follows that the children of such a node are
      either all leaf nodes, or all container nodes with two or more indices.
*/
55
56 //////////
57 // default constructor
58 //////////
59
60 symmetry::symmetry() :  type(none)
61 {
62         setflag(status_flags::evaluated | status_flags::expanded);
63 }
64
65 //////////
66 // other constructors
67 //////////
68
69 symmetry::symmetry(unsigned i) :  type(none)
70 {
71         indices.insert(i);
72         setflag(status_flags::evaluated | status_flags::expanded);
73 }
74
75 symmetry::symmetry(symmetry_type t, const symmetry &c1, const symmetry &c2) :  type(t)
76 {
77         add(c1); add(c2);
78         setflag(status_flags::evaluated | status_flags::expanded);
79 }
80
81 //////////
82 // archiving
83 //////////
84
85 /** Construct object from archive_node. */
86 void symmetry::read_archive(const archive_node &n, lst &sym_lst)
87 {
88         inherited::read_archive(n, sym_lst);
89         unsigned t;
90         if (!(n.find_unsigned("type", t)))
91                 throw (std::runtime_error("unknown symmetry type in archive"));
92         type = (symmetry_type)t;
93
94         unsigned i = 0;
95         while (true) {
96                 ex e;
97                 if (n.find_ex("child", e, sym_lst, i))
98                         add(ex_to<symmetry>(e));
99                 else
100                         break;
101                 i++;
102         }
103
104         if (i == 0) {
105                 while (true) {
106                         unsigned u;
107                         if (n.find_unsigned("index", u, i))
108                                 indices.insert(u);
109                         else
110                                 break;
111                         i++;
112                 }
113         }
114 }
115 GINAC_BIND_UNARCHIVER(symmetry);
116
117 /** Archive the object. */
118 void symmetry::archive(archive_node &n) const
119 {
120         inherited::archive(n);
121
122         n.add_unsigned("type", type);
123
124         if (children.empty()) {
125                 std::set<unsigned>::const_iterator i = indices.begin(), iend = indices.end();
126                 while (i != iend) {
127                         n.add_unsigned("index", *i);
128                         i++;
129                 }
130         } else {
131                 exvector::const_iterator i = children.begin(), iend = children.end();
132                 while (i != iend) {
133                         n.add_ex("child", *i);
134                         i++;
135                 }
136         }
137 }
138
139 //////////
140 // functions overriding virtual functions from base classes
141 //////////
142
143 int symmetry::compare_same_type(const basic & other) const
144 {
145         GINAC_ASSERT(is_a<symmetry>(other));
146
147         // For archiving purposes we need to have an ordering of symmetries.
148         const symmetry &othersymm = ex_to<symmetry>(other);
149
150         // Compare type.
151         if (type > othersymm.type)
152                 return 1;
153         if (type < othersymm.type)
154                 return -1;
155
156         // Compare the index set.
157         size_t this_size = indices.size();
158         size_t that_size = othersymm.indices.size();
159         if (this_size > that_size)
160                 return 1;
161         if (this_size < that_size)
162                 return -1;
163         typedef std::set<unsigned>::iterator set_it;
164         set_it end = indices.end();
165         for (set_it i=indices.begin(),j=othersymm.indices.begin(); i!=end; ++i,++j) {
166                 if(*i < *j)
167                         return 1;
168                 if(*i > *j)
169                         return -1;
170         }
171
172         // Compare the children.
173         if (children.size() > othersymm.children.size())
174                 return 1;
175         if (children.size() < othersymm.children.size())
176                 return -1;
177         for (size_t i=0; i<children.size(); ++i) {
178                 int cmpval = ex_to<symmetry>(children[i])
179                         .compare_same_type(ex_to<symmetry>(othersymm.children[i]));
180                 if (cmpval)
181                         return cmpval;
182         }
183
184         return 0;
185 }
186
187 unsigned symmetry::calchash() const
188 {
189         unsigned v = make_hash_seed(typeid(*this));
190
191         if (type == none) {
192                 v = rotate_left(v);
193                 v ^= *(indices.begin());
194         } else {
195                 for (exvector::const_iterator i=children.begin(); i!=children.end(); ++i)
196                 {
197                         v = rotate_left(v);
198                         v ^= i->gethash();
199                 }
200         }
201
202         if (flags & status_flags::evaluated) {
203                 setflag(status_flags::hash_calculated);
204                 hashvalue = v;
205         }
206
207         return v;
208 }
209
210 void symmetry::do_print(const print_context & c, unsigned level) const
211 {
212         if (children.empty()) {
213                 if (indices.size() > 0)
214                         c.s << *(indices.begin());
215                 else
216                         c.s << "none";
217         } else {
218                 switch (type) {
219                         case none: c.s << '!'; break;
220                         case symmetric: c.s << '+'; break;
221                         case antisymmetric: c.s << '-'; break;
222                         case cyclic: c.s << '@'; break;
223                         default: c.s << '?'; break;
224                 }
225                 c.s << '(';
226                 size_t num = children.size();
227                 for (size_t i=0; i<num; i++) {
228                         children[i].print(c);
229                         if (i != num - 1)
230                                 c.s << ",";
231                 }
232                 c.s << ')';
233         }
234 }
235
236 void symmetry::do_print_tree(const print_tree & c, unsigned level) const
237 {
238         c.s << std::string(level, ' ') << class_name() << " @" << this
239             << std::hex << ", hash=0x" << hashvalue << ", flags=0x" << flags << std::dec
240             << ", type=";
241
242         switch (type) {
243                 case none: c.s << "none"; break;
244                 case symmetric: c.s << "symm"; break;
245                 case antisymmetric: c.s << "anti"; break;
246                 case cyclic: c.s << "cycl"; break;
247                 default: c.s << "<unknown>"; break;
248         }
249
250         c.s << ", indices=(";
251         if (!indices.empty()) {
252                 std::set<unsigned>::const_iterator i = indices.begin(), end = indices.end();
253                 --end;
254                 while (i != end)
255                         c.s << *i++ << ",";
256                 c.s << *i;
257         }
258         c.s << ")\n";
259
260         exvector::const_iterator i = children.begin(), end = children.end();
261         while (i != end) {
262                 i->print(c, level + c.delta_indent);
263                 ++i;
264         }
265 }
266
267 //////////
268 // non-virtual functions in this class
269 //////////
270
271 bool symmetry::has_cyclic() const
272 {
273         if (type == cyclic)
274                 return true;
275
276         for (exvector::const_iterator i=children.begin(); i!=children.end(); ++i)
277                 if (ex_to<symmetry>(*i).has_cyclic())
278                         return true;
279
280         return false;
281 }
282
283 symmetry &symmetry::add(const symmetry &c)
284 {
285         // All children must have the same number of indices
286         if (type != none && !children.empty()) {
287                 GINAC_ASSERT(is_exactly_a<symmetry>(children[0]));
288                 if (ex_to<symmetry>(children[0]).indices.size() != c.indices.size())
289                         throw (std::logic_error("symmetry:add(): children must have same number of indices"));
290         }
291
292         // Compute union of indices and check whether the two sets are disjoint
293         std::set<unsigned> un;
294         set_union(indices.begin(), indices.end(), c.indices.begin(), c.indices.end(), inserter(un, un.begin()));
295         if (un.size() != indices.size() + c.indices.size())
296                 throw (std::logic_error("symmetry::add(): the same index appears in more than one child"));
297
298         // Set new index set
299         indices.swap(un);
300
301         // Add child node
302         children.push_back(c);
303         return *this;
304 }
305
306 void symmetry::validate(unsigned n)
307 {
308         if (indices.upper_bound(n - 1) != indices.end())
309                 throw (std::range_error("symmetry::verify(): index values are out of range"));
310         if (type != none && indices.empty()) {
311                 for (unsigned i=0; i<n; i++)
312                         add(i);
313         }
314 }
315
316 //////////
317 // global functions
318 //////////
319
320 static const symmetry & index0()
321 {
322         static ex s = (new symmetry(0))->setflag(status_flags::dynallocated);
323         return ex_to<symmetry>(s);
324 }
325
326 static const symmetry & index1()
327 {
328         static ex s = (new symmetry(1))->setflag(status_flags::dynallocated);
329         return ex_to<symmetry>(s);
330 }
331
332 static const symmetry & index2()
333 {
334         static ex s = (new symmetry(2))->setflag(status_flags::dynallocated);
335         return ex_to<symmetry>(s);
336 }
337
338 static const symmetry & index3()
339 {
340         static ex s = (new symmetry(3))->setflag(status_flags::dynallocated);
341         return ex_to<symmetry>(s);
342 }
343
344 const symmetry & not_symmetric()
345 {
346         static ex s = (new symmetry)->setflag(status_flags::dynallocated);
347         return ex_to<symmetry>(s);
348 }
349
350 const symmetry & symmetric2()
351 {
352         static ex s = (new symmetry(symmetry::symmetric, index0(), index1()))->setflag(status_flags::dynallocated);
353         return ex_to<symmetry>(s);
354 }
355
356 const symmetry & symmetric3()
357 {
358         static ex s = (new symmetry(symmetry::symmetric, index0(), index1()))->add(index2()).setflag(status_flags::dynallocated);
359         return ex_to<symmetry>(s);
360 }
361
362 const symmetry & symmetric4()
363 {
364         static ex s = (new symmetry(symmetry::symmetric, index0(), index1()))->add(index2()).add(index3()).setflag(status_flags::dynallocated);
365         return ex_to<symmetry>(s);
366 }
367
368 const symmetry & antisymmetric2()
369 {
370         static ex s = (new symmetry(symmetry::antisymmetric, index0(), index1()))->setflag(status_flags::dynallocated);
371         return ex_to<symmetry>(s);
372 }
373
374 const symmetry & antisymmetric3()
375 {
376         static ex s = (new symmetry(symmetry::antisymmetric, index0(), index1()))->add(index2()).setflag(status_flags::dynallocated);
377         return ex_to<symmetry>(s);
378 }
379
380 const symmetry & antisymmetric4()
381 {
382         static ex s = (new symmetry(symmetry::antisymmetric, index0(), index1()))->add(index2()).add(index3()).setflag(status_flags::dynallocated);
383         return ex_to<symmetry>(s);
384 }
385
386 class sy_is_less : public std::binary_function<ex, ex, bool> {
387         exvector::iterator v;
388
389 public:
390         sy_is_less(exvector::iterator v_) : v(v_) {}
391
392         bool operator() (const ex &lh, const ex &rh) const
393         {
394                 GINAC_ASSERT(is_exactly_a<symmetry>(lh));
395                 GINAC_ASSERT(is_exactly_a<symmetry>(rh));
396                 GINAC_ASSERT(ex_to<symmetry>(lh).indices.size() == ex_to<symmetry>(rh).indices.size());
397                 std::set<unsigned>::const_iterator ait = ex_to<symmetry>(lh).indices.begin(), aitend = ex_to<symmetry>(lh).indices.end(), bit = ex_to<symmetry>(rh).indices.begin();
398                 while (ait != aitend) {
399                         int cmpval = v[*ait].compare(v[*bit]);
400                         if (cmpval < 0)
401                                 return true;
402                         else if (cmpval > 0)
403                                 return false;
404                         ++ait; ++bit;
405                 }
406                 return false;
407         }
408 };
409
410 class sy_swap : public std::binary_function<ex, ex, void> {
411         exvector::iterator v;
412
413 public:
414         bool &swapped;
415
416         sy_swap(exvector::iterator v_, bool &s) : v(v_), swapped(s) {}
417
418         void operator() (const ex &lh, const ex &rh)
419         {
420                 GINAC_ASSERT(is_exactly_a<symmetry>(lh));
421                 GINAC_ASSERT(is_exactly_a<symmetry>(rh));
422                 GINAC_ASSERT(ex_to<symmetry>(lh).indices.size() == ex_to<symmetry>(rh).indices.size());
423                 std::set<unsigned>::const_iterator ait = ex_to<symmetry>(lh).indices.begin(), aitend = ex_to<symmetry>(lh).indices.end(), bit = ex_to<symmetry>(rh).indices.begin();
424                 while (ait != aitend) {
425                         v[*ait].swap(v[*bit]);
426                         ++ait; ++bit;
427                 }
428                 swapped = true;
429         }
430 };
431
432 int canonicalize(exvector::iterator v, const symmetry &symm)
433 {
434         // Less than two elements? Then do nothing
435         if (symm.indices.size() < 2)
436                 return std::numeric_limits<int>::max();
437
438         // Canonicalize children first
439         bool something_changed = false;
440         int sign = 1;
441         exvector::const_iterator first = symm.children.begin(), last = symm.children.end();
442         while (first != last) {
443                 GINAC_ASSERT(is_exactly_a<symmetry>(*first));
444                 int child_sign = canonicalize(v, ex_to<symmetry>(*first));
445                 if (child_sign == 0)
446                         return 0;
447                 if (child_sign != std::numeric_limits<int>::max()) {
448                         something_changed = true;
449                         sign *= child_sign;
450                 }
451                 first++;
452         }
453
454         // Now reorder the children
455         first = symm.children.begin();
456         switch (symm.type) {
457                 case symmetry::symmetric:
458                         // Sort the children in ascending order
459                         shaker_sort(first, last, sy_is_less(v), sy_swap(v, something_changed));
460                         break;
461                 case symmetry::antisymmetric:
462                         // Sort the children in ascending order, keeping track of the signum
463                         sign *= permutation_sign(first, last, sy_is_less(v), sy_swap(v, something_changed));
464                         if (sign == 0)
465                                 return 0;
466                         break;
467                 case symmetry::cyclic:
468                         // Permute the smallest child to the front
469                         cyclic_permutation(first, last, min_element(first, last, sy_is_less(v)), sy_swap(v, something_changed));
470                         break;
471                 default:
472                         break;
473         }
474         return something_changed ? sign : std::numeric_limits<int>::max();
475 }
476
477
478 // Symmetrize/antisymmetrize over a vector of objects
479 static ex symm(const ex & e, exvector::const_iterator first, exvector::const_iterator last, bool asymmetric)
480 {
481         // Need at least 2 objects for this operation
482         unsigned num = last - first;
483         if (num < 2)
484                 return e;
485
486         // Transform object vector to a lst (for subs())
487         lst orig_lst(first, last);
488
489         // Create index vectors for permutation
490         unsigned *iv = new unsigned[num], *iv2;
491         for (unsigned i=0; i<num; i++)
492                 iv[i] = i;
493         iv2 = (asymmetric ? new unsigned[num] : NULL);
494
495         // Loop over all permutations (the first permutation, which is the
496         // identity, is unrolled)
497         ex sum = e;
498         while (std::next_permutation(iv, iv + num)) {
499                 lst new_lst;
500                 for (unsigned i=0; i<num; i++)
501                         new_lst.append(orig_lst.op(iv[i]));
502                 ex term = e.subs(orig_lst, new_lst, subs_options::no_pattern|subs_options::no_index_renaming);
503                 if (asymmetric) {
504                         memcpy(iv2, iv, num * sizeof(unsigned));
505                         term *= permutation_sign(iv2, iv2 + num);
506                 }
507                 sum += term;
508         }
509
510         delete[] iv;
511         delete[] iv2;
512
513         return sum / factorial(numeric(num));
514 }
515
516 ex symmetrize(const ex & e, exvector::const_iterator first, exvector::const_iterator last)
517 {
518         return symm(e, first, last, false);
519 }
520
521 ex antisymmetrize(const ex & e, exvector::const_iterator first, exvector::const_iterator last)
522 {
523         return symm(e, first, last, true);
524 }
525
526 ex symmetrize_cyclic(const ex & e, exvector::const_iterator first, exvector::const_iterator last)
527 {
528         // Need at least 2 objects for this operation
529         unsigned num = last - first;
530         if (num < 2)
531                 return e;
532
533         // Transform object vector to a lst (for subs())
534         lst orig_lst(first, last);
535         lst new_lst = orig_lst;
536
537         // Loop over all cyclic permutations (the first permutation, which is
538         // the identity, is unrolled)
539         ex sum = e;
540         for (unsigned i=0; i<num-1; i++) {
541                 ex perm = new_lst.op(0);
542                 new_lst.remove_first().append(perm);
543                 sum += e.subs(orig_lst, new_lst, subs_options::no_pattern|subs_options::no_index_renaming);
544         }
545         return sum / num;
546 }
547
548 /** Symmetrize expression over a list of objects (symbols, indices). */
549 ex ex::symmetrize(const lst & l) const
550 {
551         exvector v(l.begin(), l.end());
552         return symm(*this, v.begin(), v.end(), false);
553 }
554
555 /** Antisymmetrize expression over a list of objects (symbols, indices). */
556 ex ex::antisymmetrize(const lst & l) const
557 {
558         exvector v(l.begin(), l.end());
559         return symm(*this, v.begin(), v.end(), true);
560 }
561
562 /** Symmetrize expression by cyclic permutation over a list of objects
563  *  (symbols, indices). */
564 ex ex::symmetrize_cyclic(const lst & l) const
565 {
566         exvector v(l.begin(), l.end());
567         return GiNaC::symmetrize_cyclic(*this, v.begin(), v.end());
568 }
569
570 } // namespace GiNaC