4424253ce4369c9a1ebb5c271d12e500ca08638b
[ginac.git] / ginac / ncmul.cpp
1 /** @file ncmul.cpp
2  *
3  *  Implementation of GiNaC's non-commutative products of expressions. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2005 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
21  */
22
23 #include <algorithm>
24 #include <iostream>
25 #include <stdexcept>
26
27 #include "ncmul.h"
28 #include "ex.h"
29 #include "add.h"
30 #include "mul.h"
31 #include "matrix.h"
32 #include "archive.h"
33 #include "utils.h"
34
35 namespace GiNaC {
36
37 GINAC_IMPLEMENT_REGISTERED_CLASS_OPT(ncmul, exprseq,
38   print_func<print_context>(&ncmul::do_print).
39   print_func<print_tree>(&ncmul::do_print_tree).
40   print_func<print_csrc>(&ncmul::do_print_csrc).
41   print_func<print_python_repr>(&ncmul::do_print_csrc))
42
43
44 //////////
45 // default constructor
46 //////////
47
48 ncmul::ncmul()
49 {
50         tinfo_key = TINFO_ncmul;
51 }
52
53 //////////
54 // other constructors
55 //////////
56
57 // public
58
59 ncmul::ncmul(const ex & lh, const ex & rh) : inherited(lh,rh)
60 {
61         tinfo_key = TINFO_ncmul;
62 }
63
64 ncmul::ncmul(const ex & f1, const ex & f2, const ex & f3) : inherited(f1,f2,f3)
65 {
66         tinfo_key = TINFO_ncmul;
67 }
68
69 ncmul::ncmul(const ex & f1, const ex & f2, const ex & f3,
70              const ex & f4) : inherited(f1,f2,f3,f4)
71 {
72         tinfo_key = TINFO_ncmul;
73 }
74
75 ncmul::ncmul(const ex & f1, const ex & f2, const ex & f3,
76              const ex & f4, const ex & f5) : inherited(f1,f2,f3,f4,f5)
77 {
78         tinfo_key = TINFO_ncmul;
79 }
80
81 ncmul::ncmul(const ex & f1, const ex & f2, const ex & f3,
82              const ex & f4, const ex & f5, const ex & f6) : inherited(f1,f2,f3,f4,f5,f6)
83 {
84         tinfo_key = TINFO_ncmul;
85 }
86
87 ncmul::ncmul(const exvector & v, bool discardable) : inherited(v,discardable)
88 {
89         tinfo_key = TINFO_ncmul;
90 }
91
92 ncmul::ncmul(std::auto_ptr<exvector> vp) : inherited(vp)
93 {
94         tinfo_key = TINFO_ncmul;
95 }
96
97 //////////
98 // archiving
99 //////////
100
101 DEFAULT_ARCHIVING(ncmul)
102         
103 //////////
104 // functions overriding virtual functions from base classes
105 //////////
106
107 // public
108
109 void ncmul::do_print(const print_context & c, unsigned level) const
110 {
111         printseq(c, '(', '*', ')', precedence(), level);
112 }
113
114 void ncmul::do_print_csrc(const print_context & c, unsigned level) const
115 {
116         c.s << class_name();
117         printseq(c, '(', ',', ')', precedence(), precedence());
118 }
119
120 bool ncmul::info(unsigned inf) const
121 {
122         return inherited::info(inf);
123 }
124
125 typedef std::vector<int> intvector;
126
127 ex ncmul::expand(unsigned options) const
128 {
129         // First, expand the children
130         std::auto_ptr<exvector> vp = expandchildren(options);
131         const exvector &expanded_seq = vp.get() ? *vp : this->seq;
132         
133         // Now, look for all the factors that are sums and remember their
134         // position and number of terms.
135         intvector positions_of_adds(expanded_seq.size());
136         intvector number_of_add_operands(expanded_seq.size());
137
138         size_t number_of_adds = 0;
139         size_t number_of_expanded_terms = 1;
140
141         size_t current_position = 0;
142         exvector::const_iterator last = expanded_seq.end();
143         for (exvector::const_iterator cit=expanded_seq.begin(); cit!=last; ++cit) {
144                 if (is_exactly_a<add>(*cit)) {
145                         positions_of_adds[number_of_adds] = current_position;
146                         size_t num_ops = cit->nops();
147                         number_of_add_operands[number_of_adds] = num_ops;
148                         number_of_expanded_terms *= num_ops;
149                         number_of_adds++;
150                 }
151                 ++current_position;
152         }
153
154         // If there are no sums, we are done
155         if (number_of_adds == 0) {
156                 if (vp.get())
157                         return (new ncmul(vp))->
158                                 setflag(status_flags::dynallocated | (options == 0 ? status_flags::expanded : 0));
159                 else
160                         return *this;
161         }
162
163         // Now, form all possible products of the terms of the sums with the
164         // remaining factors, and add them together
165         exvector distrseq;
166         distrseq.reserve(number_of_expanded_terms);
167
168         intvector k(number_of_adds);
169
170         while (true) {
171                 exvector term = expanded_seq;
172                 for (size_t i=0; i<number_of_adds; i++)
173                         term[positions_of_adds[i]] = expanded_seq[positions_of_adds[i]].op(k[i]);
174                 distrseq.push_back((new ncmul(term, true))->
175                                     setflag(status_flags::dynallocated | (options == 0 ? status_flags::expanded : 0)));
176
177                 // increment k[]
178                 int l = number_of_adds-1;
179                 while ((l>=0) && ((++k[l]) >= number_of_add_operands[l])) {
180                         k[l] = 0;
181                         l--;
182                 }
183                 if (l<0)
184                         break;
185         }
186
187         return (new add(distrseq))->
188                 setflag(status_flags::dynallocated | (options == 0 ? status_flags::expanded : 0));
189 }
190
191 int ncmul::degree(const ex & s) const
192 {
193         if (is_equal(ex_to<basic>(s)))
194                 return 1;
195
196         // Sum up degrees of factors
197         int deg_sum = 0;
198         exvector::const_iterator i = seq.begin(), end = seq.end();
199         while (i != end) {
200                 deg_sum += i->degree(s);
201                 ++i;
202         }
203         return deg_sum;
204 }
205
206 int ncmul::ldegree(const ex & s) const
207 {
208         if (is_equal(ex_to<basic>(s)))
209                 return 1;
210
211         // Sum up degrees of factors
212         int deg_sum = 0;
213         exvector::const_iterator i = seq.begin(), end = seq.end();
214         while (i != end) {
215                 deg_sum += i->degree(s);
216                 ++i;
217         }
218         return deg_sum;
219 }
220
221 ex ncmul::coeff(const ex & s, int n) const
222 {
223         if (is_equal(ex_to<basic>(s)))
224                 return n==1 ? _ex1 : _ex0;
225
226         exvector coeffseq;
227         coeffseq.reserve(seq.size());
228
229         if (n == 0) {
230                 // product of individual coeffs
231                 // if a non-zero power of s is found, the resulting product will be 0
232                 exvector::const_iterator it=seq.begin();
233                 while (it!=seq.end()) {
234                         coeffseq.push_back((*it).coeff(s,n));
235                         ++it;
236                 }
237                 return (new ncmul(coeffseq,1))->setflag(status_flags::dynallocated);
238         }
239                  
240         exvector::const_iterator i = seq.begin(), end = seq.end();
241         bool coeff_found = false;
242         while (i != end) {
243                 ex c = i->coeff(s,n);
244                 if (c.is_zero()) {
245                         coeffseq.push_back(*i);
246                 } else {
247                         coeffseq.push_back(c);
248                         coeff_found = true;
249                 }
250                 ++i;
251         }
252
253         if (coeff_found) return (new ncmul(coeffseq,1))->setflag(status_flags::dynallocated);
254         
255         return _ex0;
256 }
257
258 size_t ncmul::count_factors(const ex & e) const
259 {
260         if ((is_exactly_a<mul>(e)&&(e.return_type()!=return_types::commutative))||
261                 (is_exactly_a<ncmul>(e))) {
262                 size_t factors=0;
263                 for (size_t i=0; i<e.nops(); i++)
264                         factors += count_factors(e.op(i));
265                 
266                 return factors;
267         }
268         return 1;
269 }
270                 
271 void ncmul::append_factors(exvector & v, const ex & e) const
272 {
273         if ((is_exactly_a<mul>(e)&&(e.return_type()!=return_types::commutative))||
274                 (is_exactly_a<ncmul>(e))) {
275                 for (size_t i=0; i<e.nops(); i++)
276                         append_factors(v, e.op(i));
277         } else 
278                 v.push_back(e);
279 }
280
281 typedef std::vector<unsigned> unsignedvector;
282 typedef std::vector<exvector> exvectorvector;
283
284 /** Perform automatic term rewriting rules in this class.  In the following
285  *  x, x1, x2,... stand for a symbolic variables of type ex and c, c1, c2...
286  *  stand for such expressions that contain a plain number.
287  *  - ncmul(...,*(x1,x2),...,ncmul(x3,x4),...) -> ncmul(...,x1,x2,...,x3,x4,...)  (associativity)
288  *  - ncmul(x) -> x
289  *  - ncmul() -> 1
290  *  - ncmul(...,c1,...,c2,...) -> *(c1,c2,ncmul(...))  (pull out commutative elements)
291  *  - ncmul(x1,y1,x2,y2) -> *(ncmul(x1,x2),ncmul(y1,y2))  (collect elements of same type)
292  *  - ncmul(x1,x2,x3,...) -> x::eval_ncmul(x1,x2,x3,...)
293  *
294  *  @param level cut-off in recursive evaluation */
295 ex ncmul::eval(int level) const
296 {
297         // The following additional rule would be nice, but produces a recursion,
298         // which must be trapped by introducing a flag that the sub-ncmuls()
299         // are already evaluated (maybe later...)
300         //                  ncmul(x1,x2,...,X,y1,y2,...) ->
301         //                      ncmul(ncmul(x1,x2,...),X,ncmul(y1,y2,...)
302         //                      (X noncommutative_composite)
303
304         if ((level==1) && (flags & status_flags::evaluated)) {
305                 return *this;
306         }
307
308         exvector evaledseq=evalchildren(level);
309
310         // ncmul(...,*(x1,x2),...,ncmul(x3,x4),...) ->
311         //     ncmul(...,x1,x2,...,x3,x4,...)  (associativity)
312         size_t factors = 0;
313         exvector::const_iterator cit = evaledseq.begin(), citend = evaledseq.end();
314         while (cit != citend)
315                 factors += count_factors(*cit++);
316         
317         exvector assocseq;
318         assocseq.reserve(factors);
319         cit = evaledseq.begin();
320         while (cit != citend)
321                 append_factors(assocseq, *cit++);
322         
323         // ncmul(x) -> x
324         if (assocseq.size()==1) return *(seq.begin());
325
326         // ncmul() -> 1
327         if (assocseq.empty()) return _ex1;
328
329         // determine return types
330         unsignedvector rettypes;
331         rettypes.reserve(assocseq.size());
332         size_t i = 0;
333         size_t count_commutative=0;
334         size_t count_noncommutative=0;
335         size_t count_noncommutative_composite=0;
336         cit = assocseq.begin(); citend = assocseq.end();
337         while (cit != citend) {
338                 switch (rettypes[i] = cit->return_type()) {
339                 case return_types::commutative:
340                         count_commutative++;
341                         break;
342                 case return_types::noncommutative:
343                         count_noncommutative++;
344                         break;
345                 case return_types::noncommutative_composite:
346                         count_noncommutative_composite++;
347                         break;
348                 default:
349                         throw(std::logic_error("ncmul::eval(): invalid return type"));
350                 }
351                 ++i; ++cit;
352         }
353         GINAC_ASSERT(count_commutative+count_noncommutative+count_noncommutative_composite==assocseq.size());
354
355         // ncmul(...,c1,...,c2,...) ->
356         //     *(c1,c2,ncmul(...)) (pull out commutative elements)
357         if (count_commutative!=0) {
358                 exvector commutativeseq;
359                 commutativeseq.reserve(count_commutative+1);
360                 exvector noncommutativeseq;
361                 noncommutativeseq.reserve(assocseq.size()-count_commutative);
362                 size_t num = assocseq.size();
363                 for (size_t i=0; i<num; ++i) {
364                         if (rettypes[i]==return_types::commutative)
365                                 commutativeseq.push_back(assocseq[i]);
366                         else
367                                 noncommutativeseq.push_back(assocseq[i]);
368                 }
369                 commutativeseq.push_back((new ncmul(noncommutativeseq,1))->setflag(status_flags::dynallocated));
370                 return (new mul(commutativeseq))->setflag(status_flags::dynallocated);
371         }
372                 
373         // ncmul(x1,y1,x2,y2) -> *(ncmul(x1,x2),ncmul(y1,y2))
374         //     (collect elements of same type)
375
376         if (count_noncommutative_composite==0) {
377                 // there are neither commutative nor noncommutative_composite
378                 // elements in assocseq
379                 GINAC_ASSERT(count_commutative==0);
380
381                 size_t assoc_num = assocseq.size();
382                 exvectorvector evv;
383                 unsignedvector rttinfos;
384                 evv.reserve(assoc_num);
385                 rttinfos.reserve(assoc_num);
386
387                 cit = assocseq.begin(), citend = assocseq.end();
388                 while (cit != citend) {
389                         unsigned ti = cit->return_type_tinfo();
390                         size_t rtt_num = rttinfos.size();
391                         // search type in vector of known types
392                         for (i=0; i<rtt_num; ++i) {
393                                 if (ti == rttinfos[i]) {
394                                         evv[i].push_back(*cit);
395                                         break;
396                                 }
397                         }
398                         if (i >= rtt_num) {
399                                 // new type
400                                 rttinfos.push_back(ti);
401                                 evv.push_back(exvector());
402                                 (evv.end()-1)->reserve(assoc_num);
403                                 (evv.end()-1)->push_back(*cit);
404                         }
405                         ++cit;
406                 }
407
408                 size_t evv_num = evv.size();
409 #ifdef DO_GINAC_ASSERT
410                 GINAC_ASSERT(evv_num == rttinfos.size());
411                 GINAC_ASSERT(evv_num > 0);
412                 size_t s=0;
413                 for (i=0; i<evv_num; ++i)
414                         s += evv[i].size();
415                 GINAC_ASSERT(s == assoc_num);
416 #endif // def DO_GINAC_ASSERT
417                 
418                 // if all elements are of same type, simplify the string
419                 if (evv_num == 1)
420                         return evv[0][0].eval_ncmul(evv[0]);
421                 
422                 exvector splitseq;
423                 splitseq.reserve(evv_num);
424                 for (i=0; i<evv_num; ++i)
425                         splitseq.push_back((new ncmul(evv[i]))->setflag(status_flags::dynallocated));
426                 
427                 return (new mul(splitseq))->setflag(status_flags::dynallocated);
428         }
429         
430         return (new ncmul(assocseq))->setflag(status_flags::dynallocated |
431                                                                                   status_flags::evaluated);
432 }
433
434 ex ncmul::evalm() const
435 {
436         // Evaluate children first
437         std::auto_ptr<exvector> s(new exvector);
438         s->reserve(seq.size());
439         exvector::const_iterator it = seq.begin(), itend = seq.end();
440         while (it != itend) {
441                 s->push_back(it->evalm());
442                 it++;
443         }
444
445         // If there are only matrices, simply multiply them
446         it = s->begin(); itend = s->end();
447         if (is_a<matrix>(*it)) {
448                 matrix prod(ex_to<matrix>(*it));
449                 it++;
450                 while (it != itend) {
451                         if (!is_a<matrix>(*it))
452                                 goto no_matrix;
453                         prod = prod.mul(ex_to<matrix>(*it));
454                         it++;
455                 }
456                 return prod;
457         }
458
459 no_matrix:
460         return (new ncmul(s))->setflag(status_flags::dynallocated);
461 }
462
463 ex ncmul::thiscontainer(const exvector & v) const
464 {
465         return (new ncmul(v))->setflag(status_flags::dynallocated);
466 }
467
468 ex ncmul::thiscontainer(std::auto_ptr<exvector> vp) const
469 {
470         return (new ncmul(vp))->setflag(status_flags::dynallocated);
471 }
472
473 ex ncmul::conjugate() const
474 {
475         if (return_type() != return_types::noncommutative) {
476                 return exprseq::conjugate();
477         }
478
479         if ((return_type_tinfo() & 0xffffff00U) != TINFO_clifford) {
480                 return exprseq::conjugate();
481         }
482
483         exvector ev;
484         ev.reserve(nops());
485         for (const_iterator i=end(); i!=begin();) {
486                 --i;
487                 ev.push_back(i->conjugate());
488         }
489         return (new ncmul(ev, true))->setflag(status_flags::dynallocated).eval();
490 }
491
492 // protected
493
494 /** Implementation of ex::diff() for a non-commutative product. It applies
495  *  the product rule.
496  *  @see ex::diff */
497 ex ncmul::derivative(const symbol & s) const
498 {
499         size_t num = seq.size();
500         exvector addseq;
501         addseq.reserve(num);
502         
503         // D(a*b*c) = D(a)*b*c + a*D(b)*c + a*b*D(c)
504         exvector ncmulseq = seq;
505         for (size_t i=0; i<num; ++i) {
506                 ex e = seq[i].diff(s);
507                 e.swap(ncmulseq[i]);
508                 addseq.push_back((new ncmul(ncmulseq))->setflag(status_flags::dynallocated));
509                 e.swap(ncmulseq[i]);
510         }
511         return (new add(addseq))->setflag(status_flags::dynallocated);
512 }
513
514 int ncmul::compare_same_type(const basic & other) const
515 {
516         return inherited::compare_same_type(other);
517 }
518
519 unsigned ncmul::return_type() const
520 {
521         if (seq.empty())
522                 return return_types::commutative;
523
524         bool all_commutative = true;
525         exvector::const_iterator noncommutative_element; // point to first found nc element
526
527         exvector::const_iterator i = seq.begin(), end = seq.end();
528         while (i != end) {
529                 unsigned rt = i->return_type();
530                 if (rt == return_types::noncommutative_composite)
531                         return rt; // one ncc -> mul also ncc
532                 if ((rt == return_types::noncommutative) && (all_commutative)) {
533                         // first nc element found, remember position
534                         noncommutative_element = i;
535                         all_commutative = false;
536                 }
537                 if ((rt == return_types::noncommutative) && (!all_commutative)) {
538                         // another nc element found, compare type_infos
539                         if (noncommutative_element->return_type_tinfo() != i->return_type_tinfo()) {
540                                 // diffent types -> mul is ncc
541                                 return return_types::noncommutative_composite;
542                         }
543                 }
544                 ++i;
545         }
546         // all factors checked
547         GINAC_ASSERT(!all_commutative); // not all factors should commutate, because this is a ncmul();
548         return all_commutative ? return_types::commutative : return_types::noncommutative;
549 }
550    
551 unsigned ncmul::return_type_tinfo() const
552 {
553         if (seq.empty())
554                 return tinfo_key;
555
556         // return type_info of first noncommutative element
557         exvector::const_iterator i = seq.begin(), end = seq.end();
558         while (i != end) {
559                 if (i->return_type() == return_types::noncommutative)
560                         return i->return_type_tinfo();
561                 ++i;
562         }
563
564         // no noncommutative element found, should not happen
565         return tinfo_key;
566 }
567
568 //////////
569 // new virtual functions which can be overridden by derived classes
570 //////////
571
572 // none
573
574 //////////
575 // non-virtual functions in this class
576 //////////
577
578 std::auto_ptr<exvector> ncmul::expandchildren(unsigned options) const
579 {
580         const_iterator cit = this->seq.begin(), end = this->seq.end();
581         while (cit != end) {
582                 const ex & expanded_ex = cit->expand(options);
583                 if (!are_ex_trivially_equal(*cit, expanded_ex)) {
584
585                         // copy first part of seq which hasn't changed
586                         std::auto_ptr<exvector> s(new exvector(this->seq.begin(), cit));
587                         reserve(*s, this->seq.size());
588
589                         // insert changed element
590                         s->push_back(expanded_ex);
591                         ++cit;
592
593                         // copy rest
594                         while (cit != end) {
595                                 s->push_back(cit->expand(options));
596                                 ++cit;
597                         }
598
599                         return s;
600                 }
601
602                 ++cit;
603         }
604
605         return std::auto_ptr<exvector>(0); // nothing has changed
606 }
607
608 const exvector & ncmul::get_factors() const
609 {
610         return seq;
611 }
612
613 //////////
614 // friend functions
615 //////////
616
617 ex reeval_ncmul(const exvector & v)
618 {
619         return (new ncmul(v))->setflag(status_flags::dynallocated);
620 }
621
622 ex hold_ncmul(const exvector & v)
623 {
624         if (v.empty())
625                 return _ex1;
626         else if (v.size() == 1)
627                 return v[0];
628         else
629                 return (new ncmul(v))->setflag(status_flags::dynallocated |
630                                                status_flags::evaluated);
631 }
632
633 } // namespace GiNaC