- Move several comments into doxygen comment blocks.
[ginac.git] / ginac / ncmul.cpp
1 /** @file ncmul.cpp
2  *
3  *  Implementation of GiNaC's non-commutative products of expressions. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2001 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
21  */
22
23 #include <algorithm>
24 #include <iostream>
25 #include <stdexcept>
26
27 #include "ncmul.h"
28 #include "ex.h"
29 #include "add.h"
30 #include "mul.h"
31 #include "matrix.h"
32 #include "print.h"
33 #include "archive.h"
34 #include "debugmsg.h"
35 #include "utils.h"
36
37 namespace GiNaC {
38
39 GINAC_IMPLEMENT_REGISTERED_CLASS(ncmul, exprseq)
40
41 //////////
42 // default constructor, destructor, copy constructor assignment operator and helpers
43 //////////
44
45 ncmul::ncmul()
46 {
47         debugmsg("ncmul default constructor",LOGLEVEL_CONSTRUCT);
48         tinfo_key = TINFO_ncmul;
49 }
50
51 DEFAULT_COPY(ncmul)
52 DEFAULT_DESTROY(ncmul)
53
54 //////////
55 // other constructors
56 //////////
57
58 // public
59
60 ncmul::ncmul(const ex & lh, const ex & rh) : inherited(lh,rh)
61 {
62         debugmsg("ncmul constructor from ex,ex",LOGLEVEL_CONSTRUCT);
63         tinfo_key = TINFO_ncmul;
64 }
65
66 ncmul::ncmul(const ex & f1, const ex & f2, const ex & f3) : inherited(f1,f2,f3)
67 {
68         debugmsg("ncmul constructor from 3 ex",LOGLEVEL_CONSTRUCT);
69         tinfo_key = TINFO_ncmul;
70 }
71
72 ncmul::ncmul(const ex & f1, const ex & f2, const ex & f3,
73              const ex & f4) : inherited(f1,f2,f3,f4)
74 {
75         debugmsg("ncmul constructor from 4 ex",LOGLEVEL_CONSTRUCT);
76         tinfo_key = TINFO_ncmul;
77 }
78
79 ncmul::ncmul(const ex & f1, const ex & f2, const ex & f3,
80              const ex & f4, const ex & f5) : inherited(f1,f2,f3,f4,f5)
81 {
82         debugmsg("ncmul constructor from 5 ex",LOGLEVEL_CONSTRUCT);
83         tinfo_key = TINFO_ncmul;
84 }
85
86 ncmul::ncmul(const ex & f1, const ex & f2, const ex & f3,
87              const ex & f4, const ex & f5, const ex & f6) : inherited(f1,f2,f3,f4,f5,f6)
88 {
89         debugmsg("ncmul constructor from 6 ex",LOGLEVEL_CONSTRUCT);
90         tinfo_key = TINFO_ncmul;
91 }
92
93 ncmul::ncmul(const exvector & v, bool discardable) : inherited(v,discardable)
94 {
95         debugmsg("ncmul constructor from exvector,bool",LOGLEVEL_CONSTRUCT);
96         tinfo_key = TINFO_ncmul;
97 }
98
99 ncmul::ncmul(exvector * vp) : inherited(vp)
100 {
101         debugmsg("ncmul constructor from exvector *",LOGLEVEL_CONSTRUCT);
102         tinfo_key = TINFO_ncmul;
103 }
104
105 //////////
106 // archiving
107 //////////
108
109 DEFAULT_ARCHIVING(ncmul)
110         
111 //////////
112 // functions overriding virtual functions from base classes
113 //////////
114
115 // public
116
117 void ncmul::print(const print_context & c, unsigned level) const
118 {
119         debugmsg("ncmul print", LOGLEVEL_PRINT);
120
121         if (is_of_type(c, print_tree)) {
122
123                 inherited::print(c, level);
124
125         } else if (is_of_type(c, print_csrc)) {
126
127                 c.s << "ncmul(";
128                 exvector::const_iterator it = seq.begin(), itend = seq.end()-1;
129                 while (it != itend) {
130                         it->print(c, precedence());
131                         c.s << ",";
132                         it++;
133                 }
134                 it->print(c, precedence());
135                 c.s << ")";
136
137         } else
138                 printseq(c, '(', '*', ')', precedence(), level);
139 }
140
141 bool ncmul::info(unsigned inf) const
142 {
143         return inherited::info(inf);
144 }
145
146 typedef std::vector<int> intvector;
147
148 ex ncmul::expand(unsigned options) const
149 {
150         // First, expand the children
151         exvector expanded_seq = expandchildren(options);
152         
153         // Now, look for all the factors that are sums and remember their
154         // position and number of terms.
155         intvector positions_of_adds(expanded_seq.size());
156         intvector number_of_add_operands(expanded_seq.size());
157
158         int number_of_adds = 0;
159         int number_of_expanded_terms = 1;
160
161         unsigned current_position = 0;
162         exvector::const_iterator last = expanded_seq.end();
163         for (exvector::const_iterator cit=expanded_seq.begin(); cit!=last; ++cit) {
164                 if (is_exactly_a<add>(*cit)) {
165                         positions_of_adds[number_of_adds] = current_position;
166                         unsigned num_ops = cit->nops();
167                         number_of_add_operands[number_of_adds] = num_ops;
168                         number_of_expanded_terms *= num_ops;
169                         number_of_adds++;
170                 }
171                 ++current_position;
172         }
173
174         // If there are no sums, we are done
175         if (number_of_adds == 0)
176                 return (new ncmul(expanded_seq, true))->
177                         setflag(status_flags::dynallocated | (options == 0 ? status_flags::expanded : 0));
178
179         // Now, form all possible products of the terms of the sums with the
180         // remaining factors, and add them together
181         exvector distrseq;
182         distrseq.reserve(number_of_expanded_terms);
183
184         intvector k(number_of_adds);
185
186         while (true) {
187                 exvector term = expanded_seq;
188                 for (int i=0; i<number_of_adds; i++)
189                         term[positions_of_adds[i]] = expanded_seq[positions_of_adds[i]].op(k[i]);
190                 distrseq.push_back((new ncmul(term, true))->
191                                     setflag(status_flags::dynallocated | (options == 0 ? status_flags::expanded : 0)));
192
193                 // increment k[]
194                 int l = number_of_adds-1;
195                 while ((l>=0) && ((++k[l]) >= number_of_add_operands[l])) {
196                         k[l] = 0;
197                         l--;
198                 }
199                 if (l<0)
200                         break;
201         }
202
203         return (new add(distrseq))->
204                 setflag(status_flags::dynallocated | (options == 0 ? status_flags::expanded : 0));
205 }
206
207 int ncmul::degree(const ex & s) const
208 {
209         // Sum up degrees of factors
210         int deg_sum = 0;
211         exvector::const_iterator i = seq.begin(), end = seq.end();
212         while (i != end) {
213                 deg_sum += i->degree(s);
214                 ++i;
215         }
216         return deg_sum;
217 }
218
219 int ncmul::ldegree(const ex & s) const
220 {
221         // Sum up degrees of factors
222         int deg_sum = 0;
223         exvector::const_iterator i = seq.begin(), end = seq.end();
224         while (i != end) {
225                 deg_sum += i->degree(s);
226                 ++i;
227         }
228         return deg_sum;
229 }
230
231 ex ncmul::coeff(const ex & s, int n) const
232 {
233         exvector coeffseq;
234         coeffseq.reserve(seq.size());
235
236         if (n == 0) {
237                 // product of individual coeffs
238                 // if a non-zero power of s is found, the resulting product will be 0
239                 exvector::const_iterator it=seq.begin();
240                 while (it!=seq.end()) {
241                         coeffseq.push_back((*it).coeff(s,n));
242                         ++it;
243                 }
244                 return (new ncmul(coeffseq,1))->setflag(status_flags::dynallocated);
245         }
246                  
247         exvector::const_iterator i = seq.begin(), end = seq.end();
248         bool coeff_found = false;
249         while (i != end) {
250                 ex c = i->coeff(s,n);
251                 if (c.is_zero()) {
252                         coeffseq.push_back(*i);
253                 } else {
254                         coeffseq.push_back(c);
255                         coeff_found = true;
256                 }
257                 ++i;
258         }
259
260         if (coeff_found) return (new ncmul(coeffseq,1))->setflag(status_flags::dynallocated);
261         
262         return _ex0();
263 }
264
265 unsigned ncmul::count_factors(const ex & e) const
266 {
267         if ((is_ex_exactly_of_type(e,mul)&&(e.return_type()!=return_types::commutative))||
268                 (is_ex_exactly_of_type(e,ncmul))) {
269                 unsigned factors=0;
270                 for (unsigned i=0; i<e.nops(); i++)
271                         factors += count_factors(e.op(i));
272                 
273                 return factors;
274         }
275         return 1;
276 }
277                 
278 void ncmul::append_factors(exvector & v, const ex & e) const
279 {
280         if ((is_ex_exactly_of_type(e,mul)&&(e.return_type()!=return_types::commutative))||
281                 (is_ex_exactly_of_type(e,ncmul))) {
282                 for (unsigned i=0; i<e.nops(); i++)
283                         append_factors(v,e.op(i));
284         } else 
285                 v.push_back(e);
286 }
287
288 typedef std::vector<unsigned> unsignedvector;
289 typedef std::vector<exvector> exvectorvector;
290
291 /** Perform automatic term rewriting rules in this class.  In the following
292  *  x, x1, x2,... stand for a symbolic variables of type ex and c, c1, c2...
293  *  stand for such expressions that contain a plain number.
294  *  - ncmul(...,*(x1,x2),...,ncmul(x3,x4),...) -> ncmul(...,x1,x2,...,x3,x4,...)  (associativity)
295  *  - ncmul(x) -> x
296  *  - ncmul() -> 1
297  *  - ncmul(...,c1,...,c2,...) -> *(c1,c2,ncmul(...))  (pull out commutative elements)
298  *  - ncmul(x1,y1,x2,y2) -> *(ncmul(x1,x2),ncmul(y1,y2))  (collect elements of same type)
299  *  - ncmul(x1,x2,x3,...) -> x::simplify_ncmul(x1,x2,x3,...)
300  *
301  *  @param level cut-off in recursive evaluation */
302 ex ncmul::eval(int level) const
303 {
304         // The following additional rule would be nice, but produces a recursion,
305         // which must be trapped by introducing a flag that the sub-ncmuls()
306         // are already evaluated (maybe later...)
307         //                  ncmul(x1,x2,...,X,y1,y2,...) ->
308         //                      ncmul(ncmul(x1,x2,...),X,ncmul(y1,y2,...)
309         //                      (X noncommutative_composite)
310
311         if ((level==1) && (flags & status_flags::evaluated)) {
312                 return *this;
313         }
314
315         exvector evaledseq=evalchildren(level);
316
317         // ncmul(...,*(x1,x2),...,ncmul(x3,x4),...) ->
318         //     ncmul(...,x1,x2,...,x3,x4,...)  (associativity)
319         unsigned factors = 0;
320         exvector::const_iterator cit = evaledseq.begin(), citend = evaledseq.end();
321         while (cit != citend)
322                 factors += count_factors(*cit++);
323         
324         exvector assocseq;
325         assocseq.reserve(factors);
326         cit = evaledseq.begin();
327         while (cit != citend)
328                 append_factors(assocseq, *cit++);
329         
330         // ncmul(x) -> x
331         if (assocseq.size()==1) return *(seq.begin());
332
333         // ncmul() -> 1
334         if (assocseq.empty()) return _ex1();
335
336         // determine return types
337         unsignedvector rettypes;
338         rettypes.reserve(assocseq.size());
339         unsigned i = 0;
340         unsigned count_commutative=0;
341         unsigned count_noncommutative=0;
342         unsigned count_noncommutative_composite=0;
343         cit = assocseq.begin(); citend = assocseq.end();
344         while (cit != citend) {
345                 switch (rettypes[i] = cit->return_type()) {
346                 case return_types::commutative:
347                         count_commutative++;
348                         break;
349                 case return_types::noncommutative:
350                         count_noncommutative++;
351                         break;
352                 case return_types::noncommutative_composite:
353                         count_noncommutative_composite++;
354                         break;
355                 default:
356                         throw(std::logic_error("ncmul::eval(): invalid return type"));
357                 }
358                 ++i; ++cit;
359         }
360         GINAC_ASSERT(count_commutative+count_noncommutative+count_noncommutative_composite==assocseq.size());
361
362         // ncmul(...,c1,...,c2,...) ->
363         //     *(c1,c2,ncmul(...)) (pull out commutative elements)
364         if (count_commutative!=0) {
365                 exvector commutativeseq;
366                 commutativeseq.reserve(count_commutative+1);
367                 exvector noncommutativeseq;
368                 noncommutativeseq.reserve(assocseq.size()-count_commutative);
369                 unsigned num = assocseq.size();
370                 for (unsigned i=0; i<num; ++i) {
371                         if (rettypes[i]==return_types::commutative)
372                                 commutativeseq.push_back(assocseq[i]);
373                         else
374                                 noncommutativeseq.push_back(assocseq[i]);
375                 }
376                 commutativeseq.push_back((new ncmul(noncommutativeseq,1))->setflag(status_flags::dynallocated));
377                 return (new mul(commutativeseq))->setflag(status_flags::dynallocated);
378         }
379                 
380         // ncmul(x1,y1,x2,y2) -> *(ncmul(x1,x2),ncmul(y1,y2))
381         //     (collect elements of same type)
382
383         if (count_noncommutative_composite==0) {
384                 // there are neither commutative nor noncommutative_composite
385                 // elements in assocseq
386                 GINAC_ASSERT(count_commutative==0);
387
388                 unsigned assoc_num = assocseq.size();
389                 exvectorvector evv;
390                 unsignedvector rttinfos;
391                 evv.reserve(assoc_num);
392                 rttinfos.reserve(assoc_num);
393
394                 cit = assocseq.begin(), citend = assocseq.end();
395                 while (cit != citend) {
396                         unsigned ti = cit->return_type_tinfo();
397                         unsigned rtt_num = rttinfos.size();
398                         // search type in vector of known types
399                         for (i=0; i<rtt_num; ++i) {
400                                 if (ti == rttinfos[i]) {
401                                         evv[i].push_back(*cit);
402                                         break;
403                                 }
404                         }
405                         if (i >= rtt_num) {
406                                 // new type
407                                 rttinfos.push_back(ti);
408                                 evv.push_back(exvector());
409                                 (evv.end()-1)->reserve(assoc_num);
410                                 (evv.end()-1)->push_back(*cit);
411                         }
412                         ++cit;
413                 }
414
415                 unsigned evv_num = evv.size();
416 #ifdef DO_GINAC_ASSERT
417                 GINAC_ASSERT(evv_num == rttinfos.size());
418                 GINAC_ASSERT(evv_num > 0);
419                 unsigned s=0;
420                 for (i=0; i<evv_num; ++i)
421                         s += evv[i].size();
422                 GINAC_ASSERT(s == assoc_num);
423 #endif // def DO_GINAC_ASSERT
424                 
425                 // if all elements are of same type, simplify the string
426                 if (evv_num == 1)
427                         return evv[0][0].simplify_ncmul(evv[0]);
428                 
429                 exvector splitseq;
430                 splitseq.reserve(evv_num);
431                 for (i=0; i<evv_num; ++i)
432                         splitseq.push_back((new ncmul(evv[i]))->setflag(status_flags::dynallocated));
433                 
434                 return (new mul(splitseq))->setflag(status_flags::dynallocated);
435         }
436         
437         return (new ncmul(assocseq))->setflag(status_flags::dynallocated |
438                                                                                   status_flags::evaluated);
439 }
440
441 ex ncmul::evalm(void) const
442 {
443         // Evaluate children first
444         exvector *s = new exvector;
445         s->reserve(seq.size());
446         exvector::const_iterator it = seq.begin(), itend = seq.end();
447         while (it != itend) {
448                 s->push_back(it->evalm());
449                 it++;
450         }
451
452         // If there are only matrices, simply multiply them
453         it = s->begin(); itend = s->end();
454         if (is_ex_of_type(*it, matrix)) {
455                 matrix prod(ex_to<matrix>(*it));
456                 it++;
457                 while (it != itend) {
458                         if (!is_ex_of_type(*it, matrix))
459                                 goto no_matrix;
460                         prod = prod.mul(ex_to<matrix>(*it));
461                         it++;
462                 }
463                 delete s;
464                 return prod;
465         }
466
467 no_matrix:
468         return (new ncmul(s))->setflag(status_flags::dynallocated);
469 }
470
471 ex ncmul::thisexprseq(const exvector & v) const
472 {
473         return (new ncmul(v))->setflag(status_flags::dynallocated);
474 }
475
476 ex ncmul::thisexprseq(exvector * vp) const
477 {
478         return (new ncmul(vp))->setflag(status_flags::dynallocated);
479 }
480
481 // protected
482
483 /** Implementation of ex::diff() for a non-commutative product. It applies
484  *  the product rule.
485  *  @see ex::diff */
486 ex ncmul::derivative(const symbol & s) const
487 {
488         unsigned num = seq.size();
489         exvector addseq;
490         addseq.reserve(num);
491         
492         // D(a*b*c) = D(a)*b*c + a*D(b)*c + a*b*D(c)
493         exvector ncmulseq = seq;
494         for (unsigned i=0; i<num; ++i) {
495                 ex e = seq[i].diff(s);
496                 e.swap(ncmulseq[i]);
497                 addseq.push_back((new ncmul(ncmulseq))->setflag(status_flags::dynallocated));
498                 e.swap(ncmulseq[i]);
499         }
500         return (new add(addseq))->setflag(status_flags::dynallocated);
501 }
502
503 int ncmul::compare_same_type(const basic & other) const
504 {
505         return inherited::compare_same_type(other);
506 }
507
508 unsigned ncmul::return_type(void) const
509 {
510         if (seq.empty())
511                 return return_types::commutative;
512
513         bool all_commutative = true;
514         exvector::const_iterator noncommutative_element; // point to first found nc element
515
516         exvector::const_iterator i = seq.begin(), end = seq.end();
517         while (i != end) {
518                 unsigned rt = i->return_type();
519                 if (rt == return_types::noncommutative_composite)
520                         return rt; // one ncc -> mul also ncc
521                 if ((rt == return_types::noncommutative) && (all_commutative)) {
522                         // first nc element found, remember position
523                         noncommutative_element = i;
524                         all_commutative = false;
525                 }
526                 if ((rt == return_types::noncommutative) && (!all_commutative)) {
527                         // another nc element found, compare type_infos
528                         if (noncommutative_element->return_type_tinfo() != i->return_type_tinfo()) {
529                                 // diffent types -> mul is ncc
530                                 return return_types::noncommutative_composite;
531                         }
532                 }
533                 ++i;
534         }
535         // all factors checked
536         GINAC_ASSERT(!all_commutative); // not all factors should commute, because this is a ncmul();
537         return all_commutative ? return_types::commutative : return_types::noncommutative;
538 }
539    
540 unsigned ncmul::return_type_tinfo(void) const
541 {
542         if (seq.empty())
543                 return tinfo_key;
544
545         // return type_info of first noncommutative element
546         exvector::const_iterator i = seq.begin(), end = seq.end();
547         while (i != end) {
548                 if (i->return_type() == return_types::noncommutative)
549                         return i->return_type_tinfo();
550                 ++i;
551         }
552
553         // no noncommutative element found, should not happen
554         return tinfo_key;
555 }
556
557 //////////
558 // new virtual functions which can be overridden by derived classes
559 //////////
560
561 // none
562
563 //////////
564 // non-virtual functions in this class
565 //////////
566
567 exvector ncmul::expandchildren(unsigned options) const
568 {
569         exvector s;
570         s.reserve(seq.size());
571         exvector::const_iterator it = seq.begin(), itend = seq.end();
572         while (it != itend) {
573                 s.push_back(it->expand(options));
574                 it++;
575         }
576         return s;
577 }
578
579 const exvector & ncmul::get_factors(void) const
580 {
581         return seq;
582 }
583
584 //////////
585 // friend functions
586 //////////
587
588 ex nonsimplified_ncmul(const exvector & v)
589 {
590         return (new ncmul(v))->setflag(status_flags::dynallocated);
591 }
592
593 ex simplified_ncmul(const exvector & v)
594 {
595         if (v.empty())
596                 return _ex1();
597         else if (v.size() == 1)
598                 return v[0];
599         else
600                 return (new ncmul(v))->setflag(status_flags::dynallocated |
601                                                status_flags::evaluated);
602 }
603
604 } // namespace GiNaC