obvious patch from Chris Dams
[ginac.git] / ginac / ncmul.cpp
1 /** @file ncmul.cpp
2  *
3  *  Implementation of GiNaC's non-commutative products of expressions. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2004 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
21  */
22
23 #include <algorithm>
24 #include <iostream>
25 #include <stdexcept>
26
27 #include "ncmul.h"
28 #include "ex.h"
29 #include "add.h"
30 #include "mul.h"
31 #include "matrix.h"
32 #include "archive.h"
33 #include "utils.h"
34
35 namespace GiNaC {
36
// Class-implementation macro for ncmul (parent class: exprseq).  The chained
// print_func<> entries bind the print methods: generic print_context output
// uses do_print, print_tree uses do_print_tree, and both print_csrc and
// print_python_repr share do_print_csrc.
GINAC_IMPLEMENT_REGISTERED_CLASS_OPT(ncmul, exprseq,
  print_func<print_context>(&ncmul::do_print).
  print_func<print_tree>(&ncmul::do_print_tree).
  print_func<print_csrc>(&ncmul::do_print_csrc).
  print_func<print_python_repr>(&ncmul::do_print_csrc))
42
43
44 //////////
45 // default constructor
46 //////////
47
48 ncmul::ncmul()
49 {
50         tinfo_key = TINFO_ncmul;
51 }
52
53 //////////
54 // other constructors
55 //////////
56
57 // public
58
59 ncmul::ncmul(const ex & lh, const ex & rh) : inherited(lh,rh)
60 {
61         tinfo_key = TINFO_ncmul;
62 }
63
64 ncmul::ncmul(const ex & f1, const ex & f2, const ex & f3) : inherited(f1,f2,f3)
65 {
66         tinfo_key = TINFO_ncmul;
67 }
68
69 ncmul::ncmul(const ex & f1, const ex & f2, const ex & f3,
70              const ex & f4) : inherited(f1,f2,f3,f4)
71 {
72         tinfo_key = TINFO_ncmul;
73 }
74
75 ncmul::ncmul(const ex & f1, const ex & f2, const ex & f3,
76              const ex & f4, const ex & f5) : inherited(f1,f2,f3,f4,f5)
77 {
78         tinfo_key = TINFO_ncmul;
79 }
80
81 ncmul::ncmul(const ex & f1, const ex & f2, const ex & f3,
82              const ex & f4, const ex & f5, const ex & f6) : inherited(f1,f2,f3,f4,f5,f6)
83 {
84         tinfo_key = TINFO_ncmul;
85 }
86
87 ncmul::ncmul(const exvector & v, bool discardable) : inherited(v,discardable)
88 {
89         tinfo_key = TINFO_ncmul;
90 }
91
92 ncmul::ncmul(std::auto_ptr<exvector> vp) : inherited(vp)
93 {
94         tinfo_key = TINFO_ncmul;
95 }
96
97 //////////
98 // archiving
99 //////////
100
// Archiving support for ncmul, generated by the DEFAULT_ARCHIVING macro
// (defined elsewhere in GiNaC).
DEFAULT_ARCHIVING(ncmul)
102         
103 //////////
104 // functions overriding virtual functions from base classes
105 //////////
106
107 // public
108
109 void ncmul::do_print(const print_context & c, unsigned level) const
110 {
111         printseq(c, '(', '*', ')', precedence(), level);
112 }
113
114 void ncmul::do_print_csrc(const print_context & c, unsigned level) const
115 {
116         c.s << class_name();
117         printseq(c, '(', ',', ')', precedence(), precedence());
118 }
119
120 bool ncmul::info(unsigned inf) const
121 {
122         return inherited::info(inf);
123 }
124
125 typedef std::vector<int> intvector;
126
127 ex ncmul::expand(unsigned options) const
128 {
129         // First, expand the children
130         std::auto_ptr<exvector> vp = expandchildren(options);
131         const exvector &expanded_seq = vp.get() ? *vp : this->seq;
132         
133         // Now, look for all the factors that are sums and remember their
134         // position and number of terms.
135         intvector positions_of_adds(expanded_seq.size());
136         intvector number_of_add_operands(expanded_seq.size());
137
138         size_t number_of_adds = 0;
139         size_t number_of_expanded_terms = 1;
140
141         size_t current_position = 0;
142         exvector::const_iterator last = expanded_seq.end();
143         for (exvector::const_iterator cit=expanded_seq.begin(); cit!=last; ++cit) {
144                 if (is_exactly_a<add>(*cit)) {
145                         positions_of_adds[number_of_adds] = current_position;
146                         size_t num_ops = cit->nops();
147                         number_of_add_operands[number_of_adds] = num_ops;
148                         number_of_expanded_terms *= num_ops;
149                         number_of_adds++;
150                 }
151                 ++current_position;
152         }
153
154         // If there are no sums, we are done
155         if (number_of_adds == 0) {
156                 if (vp.get())
157                         return (new ncmul(vp))->
158                                 setflag(status_flags::dynallocated | (options == 0 ? status_flags::expanded : 0));
159                 else
160                         return *this;
161         }
162
163         // Now, form all possible products of the terms of the sums with the
164         // remaining factors, and add them together
165         exvector distrseq;
166         distrseq.reserve(number_of_expanded_terms);
167
168         intvector k(number_of_adds);
169
170         while (true) {
171                 exvector term = expanded_seq;
172                 for (size_t i=0; i<number_of_adds; i++)
173                         term[positions_of_adds[i]] = expanded_seq[positions_of_adds[i]].op(k[i]);
174                 distrseq.push_back((new ncmul(term, true))->
175                                     setflag(status_flags::dynallocated | (options == 0 ? status_flags::expanded : 0)));
176
177                 // increment k[]
178                 int l = number_of_adds-1;
179                 while ((l>=0) && ((++k[l]) >= number_of_add_operands[l])) {
180                         k[l] = 0;
181                         l--;
182                 }
183                 if (l<0)
184                         break;
185         }
186
187         return (new add(distrseq))->
188                 setflag(status_flags::dynallocated | (options == 0 ? status_flags::expanded : 0));
189 }
190
191 int ncmul::degree(const ex & s) const
192 {
193         // Sum up degrees of factors
194         int deg_sum = 0;
195         exvector::const_iterator i = seq.begin(), end = seq.end();
196         while (i != end) {
197                 deg_sum += i->degree(s);
198                 ++i;
199         }
200         return deg_sum;
201 }
202
203 int ncmul::ldegree(const ex & s) const
204 {
205         // Sum up degrees of factors
206         int deg_sum = 0;
207         exvector::const_iterator i = seq.begin(), end = seq.end();
208         while (i != end) {
209                 deg_sum += i->degree(s);
210                 ++i;
211         }
212         return deg_sum;
213 }
214
215 ex ncmul::coeff(const ex & s, int n) const
216 {
217         exvector coeffseq;
218         coeffseq.reserve(seq.size());
219
220         if (n == 0) {
221                 // product of individual coeffs
222                 // if a non-zero power of s is found, the resulting product will be 0
223                 exvector::const_iterator it=seq.begin();
224                 while (it!=seq.end()) {
225                         coeffseq.push_back((*it).coeff(s,n));
226                         ++it;
227                 }
228                 return (new ncmul(coeffseq,1))->setflag(status_flags::dynallocated);
229         }
230                  
231         exvector::const_iterator i = seq.begin(), end = seq.end();
232         bool coeff_found = false;
233         while (i != end) {
234                 ex c = i->coeff(s,n);
235                 if (c.is_zero()) {
236                         coeffseq.push_back(*i);
237                 } else {
238                         coeffseq.push_back(c);
239                         coeff_found = true;
240                 }
241                 ++i;
242         }
243
244         if (coeff_found) return (new ncmul(coeffseq,1))->setflag(status_flags::dynallocated);
245         
246         return _ex0;
247 }
248
249 size_t ncmul::count_factors(const ex & e) const
250 {
251         if ((is_exactly_a<mul>(e)&&(e.return_type()!=return_types::commutative))||
252                 (is_exactly_a<ncmul>(e))) {
253                 size_t factors=0;
254                 for (size_t i=0; i<e.nops(); i++)
255                         factors += count_factors(e.op(i));
256                 
257                 return factors;
258         }
259         return 1;
260 }
261                 
262 void ncmul::append_factors(exvector & v, const ex & e) const
263 {
264         if ((is_exactly_a<mul>(e)&&(e.return_type()!=return_types::commutative))||
265                 (is_exactly_a<ncmul>(e))) {
266                 for (size_t i=0; i<e.nops(); i++)
267                         append_factors(v, e.op(i));
268         } else 
269                 v.push_back(e);
270 }
271
272 typedef std::vector<unsigned> unsignedvector;
273 typedef std::vector<exvector> exvectorvector;
274
275 /** Perform automatic term rewriting rules in this class.  In the following
276  *  x, x1, x2,... stand for a symbolic variables of type ex and c, c1, c2...
277  *  stand for such expressions that contain a plain number.
278  *  - ncmul(...,*(x1,x2),...,ncmul(x3,x4),...) -> ncmul(...,x1,x2,...,x3,x4,...)  (associativity)
279  *  - ncmul(x) -> x
280  *  - ncmul() -> 1
281  *  - ncmul(...,c1,...,c2,...) -> *(c1,c2,ncmul(...))  (pull out commutative elements)
282  *  - ncmul(x1,y1,x2,y2) -> *(ncmul(x1,x2),ncmul(y1,y2))  (collect elements of same type)
283  *  - ncmul(x1,x2,x3,...) -> x::eval_ncmul(x1,x2,x3,...)
284  *
285  *  @param level cut-off in recursive evaluation */
286 ex ncmul::eval(int level) const
287 {
288         // The following additional rule would be nice, but produces a recursion,
289         // which must be trapped by introducing a flag that the sub-ncmuls()
290         // are already evaluated (maybe later...)
291         //                  ncmul(x1,x2,...,X,y1,y2,...) ->
292         //                      ncmul(ncmul(x1,x2,...),X,ncmul(y1,y2,...)
293         //                      (X noncommutative_composite)
294
295         if ((level==1) && (flags & status_flags::evaluated)) {
296                 return *this;
297         }
298
299         exvector evaledseq=evalchildren(level);
300
301         // ncmul(...,*(x1,x2),...,ncmul(x3,x4),...) ->
302         //     ncmul(...,x1,x2,...,x3,x4,...)  (associativity)
303         size_t factors = 0;
304         exvector::const_iterator cit = evaledseq.begin(), citend = evaledseq.end();
305         while (cit != citend)
306                 factors += count_factors(*cit++);
307         
308         exvector assocseq;
309         assocseq.reserve(factors);
310         cit = evaledseq.begin();
311         while (cit != citend)
312                 append_factors(assocseq, *cit++);
313         
314         // ncmul(x) -> x
315         if (assocseq.size()==1) return *(seq.begin());
316
317         // ncmul() -> 1
318         if (assocseq.empty()) return _ex1;
319
320         // determine return types
321         unsignedvector rettypes;
322         rettypes.reserve(assocseq.size());
323         size_t i = 0;
324         size_t count_commutative=0;
325         size_t count_noncommutative=0;
326         size_t count_noncommutative_composite=0;
327         cit = assocseq.begin(); citend = assocseq.end();
328         while (cit != citend) {
329                 switch (rettypes[i] = cit->return_type()) {
330                 case return_types::commutative:
331                         count_commutative++;
332                         break;
333                 case return_types::noncommutative:
334                         count_noncommutative++;
335                         break;
336                 case return_types::noncommutative_composite:
337                         count_noncommutative_composite++;
338                         break;
339                 default:
340                         throw(std::logic_error("ncmul::eval(): invalid return type"));
341                 }
342                 ++i; ++cit;
343         }
344         GINAC_ASSERT(count_commutative+count_noncommutative+count_noncommutative_composite==assocseq.size());
345
346         // ncmul(...,c1,...,c2,...) ->
347         //     *(c1,c2,ncmul(...)) (pull out commutative elements)
348         if (count_commutative!=0) {
349                 exvector commutativeseq;
350                 commutativeseq.reserve(count_commutative+1);
351                 exvector noncommutativeseq;
352                 noncommutativeseq.reserve(assocseq.size()-count_commutative);
353                 size_t num = assocseq.size();
354                 for (size_t i=0; i<num; ++i) {
355                         if (rettypes[i]==return_types::commutative)
356                                 commutativeseq.push_back(assocseq[i]);
357                         else
358                                 noncommutativeseq.push_back(assocseq[i]);
359                 }
360                 commutativeseq.push_back((new ncmul(noncommutativeseq,1))->setflag(status_flags::dynallocated));
361                 return (new mul(commutativeseq))->setflag(status_flags::dynallocated);
362         }
363                 
364         // ncmul(x1,y1,x2,y2) -> *(ncmul(x1,x2),ncmul(y1,y2))
365         //     (collect elements of same type)
366
367         if (count_noncommutative_composite==0) {
368                 // there are neither commutative nor noncommutative_composite
369                 // elements in assocseq
370                 GINAC_ASSERT(count_commutative==0);
371
372                 size_t assoc_num = assocseq.size();
373                 exvectorvector evv;
374                 unsignedvector rttinfos;
375                 evv.reserve(assoc_num);
376                 rttinfos.reserve(assoc_num);
377
378                 cit = assocseq.begin(), citend = assocseq.end();
379                 while (cit != citend) {
380                         unsigned ti = cit->return_type_tinfo();
381                         size_t rtt_num = rttinfos.size();
382                         // search type in vector of known types
383                         for (i=0; i<rtt_num; ++i) {
384                                 if (ti == rttinfos[i]) {
385                                         evv[i].push_back(*cit);
386                                         break;
387                                 }
388                         }
389                         if (i >= rtt_num) {
390                                 // new type
391                                 rttinfos.push_back(ti);
392                                 evv.push_back(exvector());
393                                 (evv.end()-1)->reserve(assoc_num);
394                                 (evv.end()-1)->push_back(*cit);
395                         }
396                         ++cit;
397                 }
398
399                 size_t evv_num = evv.size();
400 #ifdef DO_GINAC_ASSERT
401                 GINAC_ASSERT(evv_num == rttinfos.size());
402                 GINAC_ASSERT(evv_num > 0);
403                 size_t s=0;
404                 for (i=0; i<evv_num; ++i)
405                         s += evv[i].size();
406                 GINAC_ASSERT(s == assoc_num);
407 #endif // def DO_GINAC_ASSERT
408                 
409                 // if all elements are of same type, simplify the string
410                 if (evv_num == 1)
411                         return evv[0][0].eval_ncmul(evv[0]);
412                 
413                 exvector splitseq;
414                 splitseq.reserve(evv_num);
415                 for (i=0; i<evv_num; ++i)
416                         splitseq.push_back((new ncmul(evv[i]))->setflag(status_flags::dynallocated));
417                 
418                 return (new mul(splitseq))->setflag(status_flags::dynallocated);
419         }
420         
421         return (new ncmul(assocseq))->setflag(status_flags::dynallocated |
422                                                                                   status_flags::evaluated);
423 }
424
425 ex ncmul::evalm() const
426 {
427         // Evaluate children first
428         std::auto_ptr<exvector> s(new exvector);
429         s->reserve(seq.size());
430         exvector::const_iterator it = seq.begin(), itend = seq.end();
431         while (it != itend) {
432                 s->push_back(it->evalm());
433                 it++;
434         }
435
436         // If there are only matrices, simply multiply them
437         it = s->begin(); itend = s->end();
438         if (is_a<matrix>(*it)) {
439                 matrix prod(ex_to<matrix>(*it));
440                 it++;
441                 while (it != itend) {
442                         if (!is_a<matrix>(*it))
443                                 goto no_matrix;
444                         prod = prod.mul(ex_to<matrix>(*it));
445                         it++;
446                 }
447                 return prod;
448         }
449
450 no_matrix:
451         return (new ncmul(s))->setflag(status_flags::dynallocated);
452 }
453
454 ex ncmul::thiscontainer(const exvector & v) const
455 {
456         return (new ncmul(v))->setflag(status_flags::dynallocated);
457 }
458
459 ex ncmul::thiscontainer(std::auto_ptr<exvector> vp) const
460 {
461         return (new ncmul(vp))->setflag(status_flags::dynallocated);
462 }
463
464 ex ncmul::conjugate() const
465 {
466         if (return_type() != return_types::noncommutative) {
467                 return exprseq::conjugate();
468         }
469
470         if (return_type_tinfo() & 0xffffff00U != TINFO_clifford) {
471                 return exprseq::conjugate();
472         }
473
474         exvector ev;
475         ev.reserve(nops());
476         for (const_iterator i=end(); i!=begin();) {
477                 --i;
478                 ev.push_back(i->conjugate());
479         }
480         return (new ncmul(ev, true))->setflag(status_flags::dynallocated).eval();
481 }
482
483 // protected
484
485 /** Implementation of ex::diff() for a non-commutative product. It applies
486  *  the product rule.
487  *  @see ex::diff */
488 ex ncmul::derivative(const symbol & s) const
489 {
490         size_t num = seq.size();
491         exvector addseq;
492         addseq.reserve(num);
493         
494         // D(a*b*c) = D(a)*b*c + a*D(b)*c + a*b*D(c)
495         exvector ncmulseq = seq;
496         for (size_t i=0; i<num; ++i) {
497                 ex e = seq[i].diff(s);
498                 e.swap(ncmulseq[i]);
499                 addseq.push_back((new ncmul(ncmulseq))->setflag(status_flags::dynallocated));
500                 e.swap(ncmulseq[i]);
501         }
502         return (new add(addseq))->setflag(status_flags::dynallocated);
503 }
504
505 int ncmul::compare_same_type(const basic & other) const
506 {
507         return inherited::compare_same_type(other);
508 }
509
510 unsigned ncmul::return_type() const
511 {
512         if (seq.empty())
513                 return return_types::commutative;
514
515         bool all_commutative = true;
516         exvector::const_iterator noncommutative_element; // point to first found nc element
517
518         exvector::const_iterator i = seq.begin(), end = seq.end();
519         while (i != end) {
520                 unsigned rt = i->return_type();
521                 if (rt == return_types::noncommutative_composite)
522                         return rt; // one ncc -> mul also ncc
523                 if ((rt == return_types::noncommutative) && (all_commutative)) {
524                         // first nc element found, remember position
525                         noncommutative_element = i;
526                         all_commutative = false;
527                 }
528                 if ((rt == return_types::noncommutative) && (!all_commutative)) {
529                         // another nc element found, compare type_infos
530                         if (noncommutative_element->return_type_tinfo() != i->return_type_tinfo()) {
531                                 // diffent types -> mul is ncc
532                                 return return_types::noncommutative_composite;
533                         }
534                 }
535                 ++i;
536         }
537         // all factors checked
538         GINAC_ASSERT(!all_commutative); // not all factors should commute, because this is a ncmul();
539         return all_commutative ? return_types::commutative : return_types::noncommutative;
540 }
541    
542 unsigned ncmul::return_type_tinfo() const
543 {
544         if (seq.empty())
545                 return tinfo_key;
546
547         // return type_info of first noncommutative element
548         exvector::const_iterator i = seq.begin(), end = seq.end();
549         while (i != end) {
550                 if (i->return_type() == return_types::noncommutative)
551                         return i->return_type_tinfo();
552                 ++i;
553         }
554
555         // no noncommutative element found, should not happen
556         return tinfo_key;
557 }
558
559 //////////
560 // new virtual functions which can be overridden by derived classes
561 //////////
562
563 // none
564
565 //////////
566 // non-virtual functions in this class
567 //////////
568
569 std::auto_ptr<exvector> ncmul::expandchildren(unsigned options) const
570 {
571         const_iterator cit = this->seq.begin(), end = this->seq.end();
572         while (cit != end) {
573                 const ex & expanded_ex = cit->expand(options);
574                 if (!are_ex_trivially_equal(*cit, expanded_ex)) {
575
576                         // copy first part of seq which hasn't changed
577                         std::auto_ptr<exvector> s(new exvector(this->seq.begin(), cit));
578                         reserve(*s, this->seq.size());
579
580                         // insert changed element
581                         s->push_back(expanded_ex);
582                         ++cit;
583
584                         // copy rest
585                         while (cit != end) {
586                                 s->push_back(cit->expand(options));
587                                 ++cit;
588                         }
589
590                         return s;
591                 }
592
593                 ++cit;
594         }
595
596         return std::auto_ptr<exvector>(0); // nothing has changed
597 }
598
599 const exvector & ncmul::get_factors() const
600 {
601         return seq;
602 }
603
604 //////////
605 // friend functions
606 //////////
607
608 ex reeval_ncmul(const exvector & v)
609 {
610         return (new ncmul(v))->setflag(status_flags::dynallocated);
611 }
612
613 ex hold_ncmul(const exvector & v)
614 {
615         if (v.empty())
616                 return _ex1;
617         else if (v.size() == 1)
618                 return v[0];
619         else
620                 return (new ncmul(v))->setflag(status_flags::dynallocated |
621                                                status_flags::evaluated);
622 }
623
624 } // namespace GiNaC