]> www.ginac.de Git - ginac.git/blob - ginac/ncmul.cpp
symbol: make get_domain() a virtual method, remove symbol::domain.
[ginac.git] / ginac / ncmul.cpp
1 /** @file ncmul.cpp
2  *
3  *  Implementation of GiNaC's non-commutative products of expressions. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2008 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
21  */
22
23 #include <algorithm>
24 #include <iostream>
25 #include <stdexcept>
26
27 #include "ncmul.h"
28 #include "ex.h"
29 #include "add.h"
30 #include "mul.h"
31 #include "clifford.h"
32 #include "matrix.h"
33 #include "archive.h"
34 #include "indexed.h"
35 #include "utils.h"
36
37 namespace GiNaC {
38
// Register ncmul (base class: exprseq) and attach its print methods:
// do_print for a generic print_context, do_print_tree for print_tree, and
// do_print_csrc for both print_csrc and print_python_repr (the Python repr
// output intentionally reuses the C-source printer).
GINAC_IMPLEMENT_REGISTERED_CLASS_OPT(ncmul, exprseq,
  print_func<print_context>(&ncmul::do_print).
  print_func<print_tree>(&ncmul::do_print_tree).
  print_func<print_csrc>(&ncmul::do_print_csrc).
  print_func<print_python_repr>(&ncmul::do_print_csrc))
44
45
46 //////////
47 // default constructor
48 //////////
49
50 ncmul::ncmul()
51 {
52 }
53
54 //////////
55 // other constructors
56 //////////
57
58 // public
59
60 ncmul::ncmul(const ex & lh, const ex & rh) : inherited(lh,rh)
61 {
62 }
63
64 ncmul::ncmul(const ex & f1, const ex & f2, const ex & f3) : inherited(f1,f2,f3)
65 {
66 }
67
68 ncmul::ncmul(const ex & f1, const ex & f2, const ex & f3,
69              const ex & f4) : inherited(f1,f2,f3,f4)
70 {
71 }
72
73 ncmul::ncmul(const ex & f1, const ex & f2, const ex & f3,
74              const ex & f4, const ex & f5) : inherited(f1,f2,f3,f4,f5)
75 {
76 }
77
78 ncmul::ncmul(const ex & f1, const ex & f2, const ex & f3,
79              const ex & f4, const ex & f5, const ex & f6) : inherited(f1,f2,f3,f4,f5,f6)
80 {
81 }
82
83 ncmul::ncmul(const exvector & v, bool discardable) : inherited(v,discardable)
84 {
85 }
86
87 ncmul::ncmul(std::auto_ptr<exvector> vp) : inherited(vp)
88 {
89 }
90
91 //////////
92 // archiving
93 //////////
94
// ncmul has no archive/unarchive code of its own in this file; it relies on the
// generic implementation generated by DEFAULT_ARCHIVING (presumably declared in
// archive.h - TODO confirm).
DEFAULT_ARCHIVING(ncmul)
96         
97 //////////
98 // functions overriding virtual functions from base classes
99 //////////
100
101 // public
102
103 void ncmul::do_print(const print_context & c, unsigned level) const
104 {
105         printseq(c, '(', '*', ')', precedence(), level);
106 }
107
108 void ncmul::do_print_csrc(const print_context & c, unsigned level) const
109 {
110         c.s << class_name();
111         printseq(c, '(', ',', ')', precedence(), precedence());
112 }
113
114 bool ncmul::info(unsigned inf) const
115 {
116         return inherited::info(inf);
117 }
118
// Index vector used by expand() to record positions and operand counts of sums.
typedef std::vector<std::size_t> uintvector;

/** Expand a non-commutative product.
 *  The children are expanded first.  If no factor is an add, the result is
 *  the product of the expanded children (or *this when nothing changed).
 *  Otherwise every combination of one summand per add factor is formed, with
 *  factor order preserved, and the resulting ncmuls are returned as an add.
 *  Non-sum factors and chosen summands are passed through
 *  rename_dummy_indices_uniquely() with one shared vector va.
 *  @param options expansion options; results are flagged as expanded only
 *         when options == 0 */
ex ncmul::expand(unsigned options) const
{
	// First, expand the children
	std::auto_ptr<exvector> vp = expandchildren(options);
	// expandchildren() returns a null pointer when no child changed; reuse seq then.
	const exvector &expanded_seq = vp.get() ? *vp : this->seq;
	
	// Now, look for all the factors that are sums and remember their
	// position and number of terms.
	// Only the first number_of_adds entries are meaningful; the rest stay 0.
	uintvector positions_of_adds(expanded_seq.size());
	uintvector number_of_add_operands(expanded_seq.size());

	size_t number_of_adds = 0;
	size_t number_of_expanded_terms = 1;

	size_t current_position = 0;
	exvector::const_iterator last = expanded_seq.end();
	for (exvector::const_iterator cit=expanded_seq.begin(); cit!=last; ++cit) {
		if (is_exactly_a<add>(*cit)) {
			positions_of_adds[number_of_adds] = current_position;
			size_t num_ops = cit->nops();
			number_of_add_operands[number_of_adds] = num_ops;
			number_of_expanded_terms *= num_ops;
			number_of_adds++;
		}
		++current_position;
	}

	// If there are no sums, we are done
	if (number_of_adds == 0) {
		if (vp.get())
			return (new ncmul(vp))->
				setflag(status_flags::dynallocated | (options == 0 ? status_flags::expanded : 0));
		else
			return *this;
	}

	// Now, form all possible products of the terms of the sums with the
	// remaining factors, and add them together
	exvector distrseq;
	distrseq.reserve(number_of_expanded_terms);

	// k[i] selects the current summand of the i-th add factor.
	uintvector k(number_of_adds);

	/* Rename indices in the static members of the product */
	// ("static" = the non-sum factors, identical in every generated term;
	// sum positions get a placeholder _ex1 that is overwritten per term.)
	exvector expanded_seq_mod;
	size_t j = 0;
	exvector va;

	for (size_t i=0; i<expanded_seq.size(); i++) {
		// Once all add positions are consumed, positions_of_adds[j] is the
		// zero fill while i > 0, so no further spurious match can occur.
		if (i == positions_of_adds[j]) {
			expanded_seq_mod.push_back(_ex1);
			j++;
		} else {
			expanded_seq_mod.push_back(rename_dummy_indices_uniquely(va, expanded_seq[i], true));
		}
	}

	while (true) {
		exvector term = expanded_seq_mod;
		for (size_t i=0; i<number_of_adds; i++) {
			term[positions_of_adds[i]] = rename_dummy_indices_uniquely(va, expanded_seq[positions_of_adds[i]].op(k[i]), true);
		}

		distrseq.push_back((new ncmul(term, true))->
		                    setflag(status_flags::dynallocated | (options == 0 ? status_flags::expanded : 0)));

		// increment k[]
		// k[] is a mixed-radix counter (radix number_of_add_operands[l]);
		// l < 0 means it wrapped around, i.e. every combination was emitted.
		int l = number_of_adds-1;
		while ((l>=0) && ((++k[l]) >= number_of_add_operands[l])) {
			k[l] = 0;
			l--;
		}
		if (l<0)
			break;
	}

	return (new add(distrseq))->
		setflag(status_flags::dynallocated | (options == 0 ? status_flags::expanded : 0));
}
200
201 int ncmul::degree(const ex & s) const
202 {
203         if (is_equal(ex_to<basic>(s)))
204                 return 1;
205
206         // Sum up degrees of factors
207         int deg_sum = 0;
208         exvector::const_iterator i = seq.begin(), end = seq.end();
209         while (i != end) {
210                 deg_sum += i->degree(s);
211                 ++i;
212         }
213         return deg_sum;
214 }
215
216 int ncmul::ldegree(const ex & s) const
217 {
218         if (is_equal(ex_to<basic>(s)))
219                 return 1;
220
221         // Sum up degrees of factors
222         int deg_sum = 0;
223         exvector::const_iterator i = seq.begin(), end = seq.end();
224         while (i != end) {
225                 deg_sum += i->degree(s);
226                 ++i;
227         }
228         return deg_sum;
229 }
230
231 ex ncmul::coeff(const ex & s, int n) const
232 {
233         if (is_equal(ex_to<basic>(s)))
234                 return n==1 ? _ex1 : _ex0;
235
236         exvector coeffseq;
237         coeffseq.reserve(seq.size());
238
239         if (n == 0) {
240                 // product of individual coeffs
241                 // if a non-zero power of s is found, the resulting product will be 0
242                 exvector::const_iterator it=seq.begin();
243                 while (it!=seq.end()) {
244                         coeffseq.push_back((*it).coeff(s,n));
245                         ++it;
246                 }
247                 return (new ncmul(coeffseq,1))->setflag(status_flags::dynallocated);
248         }
249                  
250         exvector::const_iterator i = seq.begin(), end = seq.end();
251         bool coeff_found = false;
252         while (i != end) {
253                 ex c = i->coeff(s,n);
254                 if (c.is_zero()) {
255                         coeffseq.push_back(*i);
256                 } else {
257                         coeffseq.push_back(c);
258                         coeff_found = true;
259                 }
260                 ++i;
261         }
262
263         if (coeff_found) return (new ncmul(coeffseq,1))->setflag(status_flags::dynallocated);
264         
265         return _ex0;
266 }
267
268 size_t ncmul::count_factors(const ex & e) const
269 {
270         if ((is_exactly_a<mul>(e)&&(e.return_type()!=return_types::commutative))||
271                 (is_exactly_a<ncmul>(e))) {
272                 size_t factors=0;
273                 for (size_t i=0; i<e.nops(); i++)
274                         factors += count_factors(e.op(i));
275                 
276                 return factors;
277         }
278         return 1;
279 }
280                 
281 void ncmul::append_factors(exvector & v, const ex & e) const
282 {
283         if ((is_exactly_a<mul>(e)&&(e.return_type()!=return_types::commutative))||
284                 (is_exactly_a<ncmul>(e))) {
285                 for (size_t i=0; i<e.nops(); i++)
286                         append_factors(v, e.op(i));
287         } else 
288                 v.push_back(e);
289 }
290
291 typedef std::vector<unsigned> unsignedvector;
292 typedef std::vector<exvector> exvectorvector;
293
294 /** Perform automatic term rewriting rules in this class.  In the following
295  *  x, x1, x2,... stand for a symbolic variables of type ex and c, c1, c2...
296  *  stand for such expressions that contain a plain number.
297  *  - ncmul(...,*(x1,x2),...,ncmul(x3,x4),...) -> ncmul(...,x1,x2,...,x3,x4,...)  (associativity)
298  *  - ncmul(x) -> x
299  *  - ncmul() -> 1
300  *  - ncmul(...,c1,...,c2,...) -> *(c1,c2,ncmul(...))  (pull out commutative elements)
301  *  - ncmul(x1,y1,x2,y2) -> *(ncmul(x1,x2),ncmul(y1,y2))  (collect elements of same type)
302  *  - ncmul(x1,x2,x3,...) -> x::eval_ncmul(x1,x2,x3,...)
303  *
304  *  @param level cut-off in recursive evaluation */
305 ex ncmul::eval(int level) const
306 {
307         // The following additional rule would be nice, but produces a recursion,
308         // which must be trapped by introducing a flag that the sub-ncmuls()
309         // are already evaluated (maybe later...)
310         //                  ncmul(x1,x2,...,X,y1,y2,...) ->
311         //                      ncmul(ncmul(x1,x2,...),X,ncmul(y1,y2,...)
312         //                      (X noncommutative_composite)
313
314         if ((level==1) && (flags & status_flags::evaluated)) {
315                 return *this;
316         }
317
318         exvector evaledseq=evalchildren(level);
319
320         // ncmul(...,*(x1,x2),...,ncmul(x3,x4),...) ->
321         //     ncmul(...,x1,x2,...,x3,x4,...)  (associativity)
322         size_t factors = 0;
323         exvector::const_iterator cit = evaledseq.begin(), citend = evaledseq.end();
324         while (cit != citend)
325                 factors += count_factors(*cit++);
326         
327         exvector assocseq;
328         assocseq.reserve(factors);
329         cit = evaledseq.begin();
330         make_flat_inserter mf(evaledseq, true);
331         while (cit != citend)
332         {       ex factor = mf.handle_factor(*(cit++), 1);
333                 append_factors(assocseq, factor);
334         }
335         
336         // ncmul(x) -> x
337         if (assocseq.size()==1) return *(seq.begin());
338  
339         // ncmul() -> 1
340         if (assocseq.empty()) return _ex1;
341
342         // determine return types
343         unsignedvector rettypes;
344         rettypes.reserve(assocseq.size());
345         size_t i = 0;
346         size_t count_commutative=0;
347         size_t count_noncommutative=0;
348         size_t count_noncommutative_composite=0;
349         cit = assocseq.begin(); citend = assocseq.end();
350         while (cit != citend) {
351                 switch (rettypes[i] = cit->return_type()) {
352                 case return_types::commutative:
353                         count_commutative++;
354                         break;
355                 case return_types::noncommutative:
356                         count_noncommutative++;
357                         break;
358                 case return_types::noncommutative_composite:
359                         count_noncommutative_composite++;
360                         break;
361                 default:
362                         throw(std::logic_error("ncmul::eval(): invalid return type"));
363                 }
364                 ++i; ++cit;
365         }
366         GINAC_ASSERT(count_commutative+count_noncommutative+count_noncommutative_composite==assocseq.size());
367
368         // ncmul(...,c1,...,c2,...) ->
369         //     *(c1,c2,ncmul(...)) (pull out commutative elements)
370         if (count_commutative!=0) {
371                 exvector commutativeseq;
372                 commutativeseq.reserve(count_commutative+1);
373                 exvector noncommutativeseq;
374                 noncommutativeseq.reserve(assocseq.size()-count_commutative);
375                 size_t num = assocseq.size();
376                 for (size_t i=0; i<num; ++i) {
377                         if (rettypes[i]==return_types::commutative)
378                                 commutativeseq.push_back(assocseq[i]);
379                         else
380                                 noncommutativeseq.push_back(assocseq[i]);
381                 }
382                 commutativeseq.push_back((new ncmul(noncommutativeseq,1))->setflag(status_flags::dynallocated));
383                 return (new mul(commutativeseq))->setflag(status_flags::dynallocated);
384         }
385                 
386         // ncmul(x1,y1,x2,y2) -> *(ncmul(x1,x2),ncmul(y1,y2))
387         //     (collect elements of same type)
388
389         if (count_noncommutative_composite==0) {
390                 // there are neither commutative nor noncommutative_composite
391                 // elements in assocseq
392                 GINAC_ASSERT(count_commutative==0);
393
394                 size_t assoc_num = assocseq.size();
395                 exvectorvector evv;
396                 std::vector<return_type_t> rttinfos;
397                 evv.reserve(assoc_num);
398                 rttinfos.reserve(assoc_num);
399
400                 cit = assocseq.begin(), citend = assocseq.end();
401                 while (cit != citend) {
402                         return_type_t ti = cit->return_type_tinfo();
403                         size_t rtt_num = rttinfos.size();
404                         // search type in vector of known types
405                         for (i=0; i<rtt_num; ++i) {
406                                 if(ti == rttinfos[i]) {
407                                         evv[i].push_back(*cit);
408                                         break;
409                                 }
410                         }
411                         if (i >= rtt_num) {
412                                 // new type
413                                 rttinfos.push_back(ti);
414                                 evv.push_back(exvector());
415                                 (evv.end()-1)->reserve(assoc_num);
416                                 (evv.end()-1)->push_back(*cit);
417                         }
418                         ++cit;
419                 }
420
421                 size_t evv_num = evv.size();
422 #ifdef DO_GINAC_ASSERT
423                 GINAC_ASSERT(evv_num == rttinfos.size());
424                 GINAC_ASSERT(evv_num > 0);
425                 size_t s=0;
426                 for (i=0; i<evv_num; ++i)
427                         s += evv[i].size();
428                 GINAC_ASSERT(s == assoc_num);
429 #endif // def DO_GINAC_ASSERT
430                 
431                 // if all elements are of same type, simplify the string
432                 if (evv_num == 1) {
433                         return evv[0][0].eval_ncmul(evv[0]);
434                 }
435                 
436                 exvector splitseq;
437                 splitseq.reserve(evv_num);
438                 for (i=0; i<evv_num; ++i)
439                         splitseq.push_back((new ncmul(evv[i]))->setflag(status_flags::dynallocated));
440                 
441                 return (new mul(splitseq))->setflag(status_flags::dynallocated);
442         }
443         
444         return (new ncmul(assocseq))->setflag(status_flags::dynallocated |
445                                                                                   status_flags::evaluated);
446 }
447
448 ex ncmul::evalm() const
449 {
450         // Evaluate children first
451         std::auto_ptr<exvector> s(new exvector);
452         s->reserve(seq.size());
453         exvector::const_iterator it = seq.begin(), itend = seq.end();
454         while (it != itend) {
455                 s->push_back(it->evalm());
456                 it++;
457         }
458
459         // If there are only matrices, simply multiply them
460         it = s->begin(); itend = s->end();
461         if (is_a<matrix>(*it)) {
462                 matrix prod(ex_to<matrix>(*it));
463                 it++;
464                 while (it != itend) {
465                         if (!is_a<matrix>(*it))
466                                 goto no_matrix;
467                         prod = prod.mul(ex_to<matrix>(*it));
468                         it++;
469                 }
470                 return prod;
471         }
472
473 no_matrix:
474         return (new ncmul(s))->setflag(status_flags::dynallocated);
475 }
476
477 ex ncmul::thiscontainer(const exvector & v) const
478 {
479         return (new ncmul(v))->setflag(status_flags::dynallocated);
480 }
481
482 ex ncmul::thiscontainer(std::auto_ptr<exvector> vp) const
483 {
484         return (new ncmul(vp))->setflag(status_flags::dynallocated);
485 }
486
487 ex ncmul::conjugate() const
488 {
489         if (return_type() != return_types::noncommutative) {
490                 return exprseq::conjugate();
491         }
492
493         if (!is_clifford_tinfo(return_type_tinfo())) {
494                 return exprseq::conjugate();
495         }
496
497         exvector ev;
498         ev.reserve(nops());
499         for (const_iterator i=end(); i!=begin();) {
500                 --i;
501                 ev.push_back(i->conjugate());
502         }
503         return (new ncmul(ev, true))->setflag(status_flags::dynallocated).eval();
504 }
505
506 ex ncmul::real_part() const
507 {
508         return basic::real_part();
509 }
510
511 ex ncmul::imag_part() const
512 {
513         return basic::imag_part();
514 }
515
516 // protected
517
518 /** Implementation of ex::diff() for a non-commutative product. It applies
519  *  the product rule.
520  *  @see ex::diff */
521 ex ncmul::derivative(const symbol & s) const
522 {
523         size_t num = seq.size();
524         exvector addseq;
525         addseq.reserve(num);
526         
527         // D(a*b*c) = D(a)*b*c + a*D(b)*c + a*b*D(c)
528         exvector ncmulseq = seq;
529         for (size_t i=0; i<num; ++i) {
530                 ex e = seq[i].diff(s);
531                 e.swap(ncmulseq[i]);
532                 addseq.push_back((new ncmul(ncmulseq))->setflag(status_flags::dynallocated));
533                 e.swap(ncmulseq[i]);
534         }
535         return (new add(addseq))->setflag(status_flags::dynallocated);
536 }
537
538 int ncmul::compare_same_type(const basic & other) const
539 {
540         return inherited::compare_same_type(other);
541 }
542
543 unsigned ncmul::return_type() const
544 {
545         if (seq.empty())
546                 return return_types::commutative;
547
548         bool all_commutative = true;
549         exvector::const_iterator noncommutative_element; // point to first found nc element
550
551         exvector::const_iterator i = seq.begin(), end = seq.end();
552         while (i != end) {
553                 unsigned rt = i->return_type();
554                 if (rt == return_types::noncommutative_composite)
555                         return rt; // one ncc -> mul also ncc
556                 if ((rt == return_types::noncommutative) && (all_commutative)) {
557                         // first nc element found, remember position
558                         noncommutative_element = i;
559                         all_commutative = false;
560                 }
561                 if ((rt == return_types::noncommutative) && (!all_commutative)) {
562                         // another nc element found, compare type_infos
563                         if(noncommutative_element->return_type_tinfo() != i->return_type_tinfo())
564                                         return return_types::noncommutative_composite;
565                 }
566                 ++i;
567         }
568         // all factors checked
569         GINAC_ASSERT(!all_commutative); // not all factors should commutate, because this is a ncmul();
570         return all_commutative ? return_types::commutative : return_types::noncommutative;
571 }
572    
573 return_type_t ncmul::return_type_tinfo() const
574 {
575         if (seq.empty())
576                 return make_return_type_t<ncmul>();
577
578         // return type_info of first noncommutative element
579         exvector::const_iterator i = seq.begin(), end = seq.end();
580         while (i != end) {
581                 if (i->return_type() == return_types::noncommutative)
582                         return i->return_type_tinfo();
583                 ++i;
584         }
585
586         // no noncommutative element found, should not happen
587         return make_return_type_t<ncmul>();
588 }
589
590 //////////
591 // new virtual functions which can be overridden by derived classes
592 //////////
593
594 // none
595
596 //////////
597 // non-virtual functions in this class
598 //////////
599
600 std::auto_ptr<exvector> ncmul::expandchildren(unsigned options) const
601 {
602         const_iterator cit = this->seq.begin(), end = this->seq.end();
603         while (cit != end) {
604                 const ex & expanded_ex = cit->expand(options);
605                 if (!are_ex_trivially_equal(*cit, expanded_ex)) {
606
607                         // copy first part of seq which hasn't changed
608                         std::auto_ptr<exvector> s(new exvector(this->seq.begin(), cit));
609                         reserve(*s, this->seq.size());
610
611                         // insert changed element
612                         s->push_back(expanded_ex);
613                         ++cit;
614
615                         // copy rest
616                         while (cit != end) {
617                                 s->push_back(cit->expand(options));
618                                 ++cit;
619                         }
620
621                         return s;
622                 }
623
624                 ++cit;
625         }
626
627         return std::auto_ptr<exvector>(0); // nothing has changed
628 }
629
630 const exvector & ncmul::get_factors() const
631 {
632         return seq;
633 }
634
635 //////////
636 // friend functions
637 //////////
638
639 ex reeval_ncmul(const exvector & v)
640 {
641         return (new ncmul(v))->setflag(status_flags::dynallocated);
642 }
643
644 ex hold_ncmul(const exvector & v)
645 {
646         if (v.empty())
647                 return _ex1;
648         else if (v.size() == 1)
649                 return v[0];
650         else
651                 return (new ncmul(v))->setflag(status_flags::dynallocated |
652                                                status_flags::evaluated);
653 }
654
655 } // namespace GiNaC