- modified GiNaC headers to Alexander's liking
[ginac.git] / ginac / ncmul.cpp
1 /** @file ncmul.cpp
2  *
3  *  Implementation of GiNaC's non-commutative products of expressions.
4  *
5  *  GiNaC Copyright (C) 1999 Johannes Gutenberg University Mainz, Germany
6  *
7  *  This program is free software; you can redistribute it and/or modify
8  *  it under the terms of the GNU General Public License as published by
9  *  the Free Software Foundation; either version 2 of the License, or
10  *  (at your option) any later version.
11  *
12  *  This program is distributed in the hope that it will be useful,
13  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
14  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15  *  GNU General Public License for more details.
16  *
17  *  You should have received a copy of the GNU General Public License
18  *  along with this program; if not, write to the Free Software
19  *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
20  */
21
22 #include <algorithm>
23 #include <iostream>
24 #include <stdexcept>
25
26 #include "ncmul.h"
27 #include "ex.h"
28 #include "add.h"
29 #include "mul.h"
30
31 //////////
32 // default constructor, destructor, copy constructor assignment operator and helpers
33 //////////
34
35 // public
36
37 ncmul::ncmul()
38 {
39     debugmsg("ncmul default constructor",LOGLEVEL_CONSTRUCT);
40     tinfo_key = TINFO_ncmul;
41 }
42
43 ncmul::~ncmul()
44 {
45     debugmsg("ncmul destructor",LOGLEVEL_DESTRUCT);
46     destroy(0);
47 }
48
49 ncmul::ncmul(ncmul const & other)
50 {
51     debugmsg("ncmul copy constructor",LOGLEVEL_CONSTRUCT);
52     copy(other);
53 }
54
55 ncmul const & ncmul::operator=(ncmul const & other)
56 {
57     debugmsg("ncmul operator=",LOGLEVEL_ASSIGNMENT);
58     if (this != &other) {
59         destroy(1);
60         copy(other);
61     }
62     return *this;
63 }
64
65 // protected
66
67 void ncmul::copy(ncmul const & other)
68 {
69     exprseq::copy(other);
70 }
71
72 void ncmul::destroy(bool call_parent)
73 {
74     if (call_parent) exprseq::destroy(call_parent);
75 }
76
77 //////////
78 // other constructors
79 //////////
80
81 // public
82
83 ncmul::ncmul(ex const & lh, ex const & rh) :
84     exprseq(lh,rh)
85 {
86     debugmsg("ncmul constructor from ex,ex",LOGLEVEL_CONSTRUCT);
87     tinfo_key = TINFO_ncmul;
88 }
89
90 ncmul::ncmul(ex const & f1, ex const & f2, ex const & f3) :
91     exprseq(f1,f2,f3)
92 {
93     debugmsg("ncmul constructor from 3 ex",LOGLEVEL_CONSTRUCT);
94     tinfo_key = TINFO_ncmul;
95 }
96
97 ncmul::ncmul(ex const & f1, ex const & f2, ex const & f3,
98       ex const & f4) : exprseq(f1,f2,f3,f4)
99 {
100     debugmsg("ncmul constructor from 4 ex",LOGLEVEL_CONSTRUCT);
101     tinfo_key = TINFO_ncmul;
102 }
103
104 ncmul::ncmul(ex const & f1, ex const & f2, ex const & f3,
105       ex const & f4, ex const & f5) : exprseq(f1,f2,f3,f4,f5)
106 {
107     debugmsg("ncmul constructor from 5 ex",LOGLEVEL_CONSTRUCT);
108     tinfo_key = TINFO_ncmul;
109 }
110
111 ncmul::ncmul(ex const & f1, ex const & f2, ex const & f3,
112       ex const & f4, ex const & f5, ex const & f6) :
113     exprseq(f1,f2,f3,f4,f5,f6)
114 {
115     debugmsg("ncmul constructor from 6 ex",LOGLEVEL_CONSTRUCT);
116     tinfo_key = TINFO_ncmul;
117 }
118
119 ncmul::ncmul(exvector const & v, bool discardable) : exprseq(v,discardable)
120 {
121     debugmsg("ncmul constructor from exvector,bool",LOGLEVEL_CONSTRUCT);
122     tinfo_key = TINFO_ncmul;
123 }
124
125 ncmul::ncmul(exvector * vp) : exprseq(vp)
126 {
127     debugmsg("ncmul constructor from exvector *",LOGLEVEL_CONSTRUCT);
128     tinfo_key = TINFO_ncmul;
129 }
130     
131 //////////
// functions overriding virtual functions from base classes
133 //////////
134
135 // public
136
137 basic * ncmul::duplicate() const
138 {
139     debugmsg("ncmul duplicate",LOGLEVEL_ASSIGNMENT);
140     return new ncmul(*this);
141 }
142
143 bool ncmul::info(unsigned inf) const
144 {
145     throw(std::logic_error("which flags have to be implemented in ncmul::info()?"));
146 }
147
148 typedef vector<int> intvector;
149
150 ex ncmul::expand(unsigned options) const
151 {
152     exvector sub_expanded_seq;
153     intvector positions_of_adds;
154     intvector number_of_add_operands;
155
156     exvector expanded_seq=expandchildren(options);
157
158     positions_of_adds.resize(expanded_seq.size());
159     number_of_add_operands.resize(expanded_seq.size());
160
161     int number_of_adds=0;
162     int number_of_expanded_terms=1;
163
164     unsigned current_position=0;
165     exvector::const_iterator last=expanded_seq.end();
166     for (exvector::const_iterator cit=expanded_seq.begin(); cit!=last; ++cit) {
167         if (is_ex_exactly_of_type((*cit),add)) {
168             positions_of_adds[number_of_adds]=current_position;
169             add const & expanded_addref=ex_to_add(*cit);
170             number_of_add_operands[number_of_adds]=expanded_addref.seq.size();
171             number_of_expanded_terms *= expanded_addref.seq.size();
172             number_of_adds++;
173         }
174         current_position++;
175     }
176
177     if (number_of_adds==0) {
178         return (new ncmul(expanded_seq,1))->setflag(status_flags::dynallocated ||
179                                                     status_flags::expanded);
180     }
181
182     exvector distrseq;
183     distrseq.reserve(number_of_expanded_terms);
184
185     intvector k;
186     k.resize(number_of_adds);
187     
188     int l;
189     for (l=0; l<number_of_adds; l++) {
190         k[l]=0;
191     }
192
193     while (1) {
194         exvector term;
195         term=expanded_seq;
196         for (l=0; l<number_of_adds; l++) {
197             ASSERT(is_ex_exactly_of_type(expanded_seq[positions_of_adds[l]],add));
198             add const & addref=ex_to_add(expanded_seq[positions_of_adds[l]]);
199             term[positions_of_adds[l]]=addref.recombine_pair_to_ex(addref.seq[k[l]]);
200         }
201         distrseq.push_back((new ncmul(term,1))->setflag(status_flags::dynallocated |
202                                                         status_flags::expanded));
203
204         // increment k[]
205         l=number_of_adds-1;
206         while ((l>=0)&&((++k[l])>=number_of_add_operands[l])) {
207             k[l]=0;    
208             l--;
209         }
210         if (l<0) break;
211     }
212
213     return (new add(distrseq))->setflag(status_flags::dynallocated |
214                                         status_flags::expanded);
215 }
216
217 int ncmul::degree(symbol const & s) const
218 {
219     int deg_sum=0;
220     for (exvector::const_iterator cit=seq.begin(); cit!=seq.end(); ++cit) {
221         deg_sum+=(*cit).degree(s);
222     }
223     return deg_sum;
224 }
225
226 int ncmul::ldegree(symbol const & s) const
227 {
228     int deg_sum=0;
229     for (exvector::const_iterator cit=seq.begin(); cit!=seq.end(); ++cit) {
230         deg_sum+=(*cit).ldegree(s);
231     }
232     return deg_sum;
233 }
234
235 ex ncmul::coeff(symbol const & s, int const n) const
236 {
237     exvector coeffseq;
238     coeffseq.reserve(seq.size());
239
240     if (n==0) {
241         // product of individual coeffs
242         // if a non-zero power of s is found, the resulting product will be 0
243         exvector::const_iterator it=seq.begin();
244         while (it!=seq.end()) {
245             coeffseq.push_back((*it).coeff(s,n));
246             ++it;
247         }
248         return (new ncmul(coeffseq,1))->setflag(status_flags::dynallocated);
249     }
250          
251     exvector::const_iterator it=seq.begin();
252     bool coeff_found=0;
253     while (it!=seq.end()) {
254         ex c=(*it).coeff(s,n);
255         if (!c.is_zero()) {
256             coeffseq.push_back(c);
257             coeff_found=1;
258         } else {
259             coeffseq.push_back(*it);
260         }
261         ++it;
262     }
263
264     if (coeff_found) return (new ncmul(coeffseq,1))->setflag(status_flags::dynallocated);
265     
266     return exZERO();
267 }
268
269 unsigned ncmul::count_factors(ex const & e) const
270 {
271     if ((is_ex_exactly_of_type(e,mul)&&(e.return_type()!=return_types::commutative))||
272         (is_ex_exactly_of_type(e,ncmul))) {
273         unsigned factors=0;
274         for (int i=0; i<e.nops(); i++) {
275             factors += count_factors(e.op(i));
276         }
277         return factors;
278     }
279     return 1;
280 }
281         
282 void ncmul::append_factors(exvector & v, ex const & e) const
283 {
284     if ((is_ex_exactly_of_type(e,mul)&&(e.return_type()!=return_types::commutative))||
285         (is_ex_exactly_of_type(e,ncmul))) {
286         for (int i=0; i<e.nops(); i++) {
287             append_factors(v,e.op(i));
288         }
289         return;
290     }
291     v.push_back(e);
292 }
293
294 typedef vector<unsigned> unsignedvector;
295 typedef vector<exvector> exvectorvector;
296
297 ex ncmul::eval(int level) const
298 {
299     // simplifications: ncmul(...,*(x1,x2),...,ncmul(x3,x4),...) ->
300     //                      ncmul(...,x1,x2,...,x3,x4,...) (associativity)
301     //                  ncmul(x) -> x
302     //                  ncmul() -> 1
303     //                  ncmul(...,c1,...,c2,...) ->
304     //                      *(c1,c2,ncmul(...)) (pull out commutative elements)
305     //                  ncmul(x1,y1,x2,y2) -> *(ncmul(x1,x2),ncmul(y1,y2))
306     //                      (collect elements of same type)
307     //                  ncmul(x1,x2,x3,...) -> x::eval_ncmul(x1,x2,x3,...)
308     // the following rule would be nice, but produces a recursion,
309     // which must be trapped by introducing a flag that the sub-ncmuls()
310     // are already evaluated (maybe later...)
311     //                  ncmul(x1,x2,...,X,y1,y2,...) ->
312     //                      ncmul(ncmul(x1,x2,...),X,ncmul(y1,y2,...)
313     //                      (X noncommutative_composite)
314
315     if ((level==1)&&(flags & status_flags::evaluated)) {
316         return *this;
317     }
318
319     exvector evaledseq=evalchildren(level);
320
321     // ncmul(...,*(x1,x2),...,ncmul(x3,x4),...) ->
322     //     ncmul(...,x1,x2,...,x3,x4,...) (associativity)
323     unsigned factors=0;
324     for (exvector::const_iterator cit=evaledseq.begin(); cit!=evaledseq.end(); ++cit) {
325         factors += count_factors(*cit);
326     }
327
328     exvector assocseq;
329     assocseq.reserve(factors);
330     for (exvector::const_iterator cit=evaledseq.begin(); cit!=evaledseq.end(); ++cit) {
331         append_factors(assocseq,*cit);
332     }
333
334     // ncmul(x) -> x
335     if (assocseq.size()==1) return *(seq.begin());
336
337     // ncmul() -> 1
338     if (assocseq.size()==0) return exONE();
339
340     // determine return types
341     unsignedvector rettypes;
342     rettypes.reserve(assocseq.size());
343     unsigned i=0;
344     unsigned count_commutative=0;
345     unsigned count_noncommutative=0;
346     unsigned count_noncommutative_composite=0;
347     for (exvector::const_iterator cit=assocseq.begin(); cit!=assocseq.end(); ++cit) {
348         switch (rettypes[i]=(*cit).return_type()) {
349         case return_types::commutative:
350             count_commutative++;
351             break;
352         case return_types::noncommutative:
353             count_noncommutative++;
354             break;
355         case return_types::noncommutative_composite:
356             count_noncommutative_composite++;
357             break;
358         default:
359             throw(std::logic_error("ncmul::eval(): invalid return type"));
360         }
361         ++i;
362     }
363     ASSERT(count_commutative+count_noncommutative+count_noncommutative_composite==assocseq.size());
364
365     // ncmul(...,c1,...,c2,...) ->
366     //     *(c1,c2,ncmul(...)) (pull out commutative elements)
367     if (count_commutative!=0) {
368         exvector commutativeseq;
369         commutativeseq.reserve(count_commutative+1);
370         exvector noncommutativeseq;
371         noncommutativeseq.reserve(assocseq.size()-count_commutative);
372         for (i=0; i<assocseq.size(); ++i) {
373             if (rettypes[i]==return_types::commutative) {
374                 commutativeseq.push_back(assocseq[i]);
375             } else {
376                 noncommutativeseq.push_back(assocseq[i]);
377             }
378         }
379         commutativeseq.push_back((new ncmul(noncommutativeseq,1))->
380                                   setflag(status_flags::dynallocated));
381         return (new mul(commutativeseq))->setflag(status_flags::dynallocated);
382     }
383         
384     // ncmul(x1,y1,x2,y2) -> *(ncmul(x1,x2),ncmul(y1,y2))
385     //     (collect elements of same type)
386
387     if (count_noncommutative_composite==0) {
388         // there are neither commutative nor noncommutative_composite
389         // elements in assocseq
390         ASSERT(count_commutative==0);
391
392         exvectorvector evv;
393         unsignedvector rttinfos;
394         evv.reserve(assocseq.size());
395         rttinfos.reserve(assocseq.size());
396
397         for (exvector::const_iterator cit=assocseq.begin(); cit!=assocseq.end(); ++cit) {
398             unsigned ti=(*cit).return_type_tinfo();
399             // search type in vector of known types
400             for (i=0; i<rttinfos.size(); ++i) {
401                 if (ti==rttinfos[i]) {
402                     evv[i].push_back(*cit);
403                     break;
404                 }
405             }
406             if (i>=rttinfos.size()) {
407                 // new type
408                 rttinfos.push_back(ti);
409                 evv.push_back(exvector());
410                 (*(evv.end()-1)).reserve(assocseq.size());
411                 (*(evv.end()-1)).push_back(*cit);
412             }
413         }
414
415 #ifdef DOASSERT
416         ASSERT(evv.size()==rttinfos.size());
417         ASSERT(evv.size()>0);
418         unsigned s=0;
419         for (i=0; i<evv.size(); ++i) {
420             s += evv[i].size();
421         }
422         ASSERT(s==assocseq.size());
423 #endif // def DOASSERT
424         
425         // if all elements are of same type, simplify the string
426         if (evv.size()==1) {
427             return evv[0][0].simplify_ncmul(evv[0]);
428         }
429         
430         exvector splitseq;
431         splitseq.reserve(evv.size());
432         for (i=0; i<evv.size(); ++i) {
433             splitseq.push_back((new ncmul(evv[i]))->
434                                setflag(status_flags::dynallocated));
435         }
436
437         return (new mul(splitseq))->setflag(status_flags::dynallocated);
438     }
439     
440     return (new ncmul(assocseq))->setflag(status_flags::dynallocated |
441                                           status_flags::evaluated);
442 }
443
444 exvector ncmul::get_indices(void) const
445 {
446     // return union of indices of factors
447     exvector iv;
448     for (exvector::const_iterator cit=seq.begin(); cit!=seq.end(); ++cit) {
449         exvector subiv=(*cit).get_indices();
450         iv.reserve(iv.size()+subiv.size());
451         for (exvector::const_iterator cit2=subiv.begin(); cit2!=subiv.end(); ++cit2) {
452             iv.push_back(*cit2);
453         }
454     }
455     return iv;
456 }
457
458 ex ncmul::subs(lst const & ls, lst const & lr) const
459 {
460     return ncmul(subschildren(ls, lr));
461 }
462
463 ex ncmul::thisexprseq(exvector const & v) const
464 {
465     return (new ncmul(v))->setflag(status_flags::dynallocated);
466 }
467
468 ex ncmul::thisexprseq(exvector * vp) const
469 {
470     return (new ncmul(vp))->setflag(status_flags::dynallocated);
471 }
472
473 // protected
474
475 int ncmul::compare_same_type(basic const & other) const
476 {
477     return exprseq::compare_same_type(other);
478 }
479
480 unsigned ncmul::return_type(void) const
481 {
482     if (seq.size()==0) {
483         // ncmul without factors: should not happen, but commutes
484         return return_types::commutative;
485     }
486
487     bool all_commutative=1;
488     unsigned rt;
489     exvector::const_iterator cit_noncommutative_element; // point to first found nc element
490
491     for (exvector::const_iterator cit=seq.begin(); cit!=seq.end(); ++cit) {
492         rt=(*cit).return_type();
493         if (rt==return_types::noncommutative_composite) return rt; // one ncc -> mul also ncc
494         if ((rt==return_types::noncommutative)&&(all_commutative)) {
495             // first nc element found, remember position
496             cit_noncommutative_element=cit;
497             all_commutative=0;
498         }
499         if ((rt==return_types::noncommutative)&&(!all_commutative)) {
500             // another nc element found, compare type_infos
501             if ((*cit_noncommutative_element).return_type_tinfo()!=(*cit).return_type_tinfo()) {
502                 // diffent types -> mul is ncc
503                 return return_types::noncommutative_composite;
504             }
505         }
506     }
507     // all factors checked
508     ASSERT(!all_commutative); // not all factors should commute, because this is a ncmul();
509     return all_commutative ? return_types::commutative : return_types::noncommutative;
510 }
511    
512 unsigned ncmul::return_type_tinfo(void) const
513 {
514     if (seq.size()==0) {
515         // mul without factors: should not happen
516         return tinfo_key;
517     }
518     // return type_info of first noncommutative element
519     for (exvector::const_iterator cit=seq.begin(); cit!=seq.end(); ++cit) {
520         if ((*cit).return_type()==return_types::noncommutative) {
521             return (*cit).return_type_tinfo();
522         }
523     }
524     // no noncommutative element found, should not happen
525     return tinfo_key;
526 }
527
528 //////////
529 // new virtual functions which can be overridden by derived classes
530 //////////
531
532 // none
533
534 //////////
535 // non-virtual functions in this class
536 //////////
537
538 exvector ncmul::expandchildren(unsigned options) const
539 {
540     exvector s;
541     s.reserve(seq.size());
542
543     for (exvector::const_iterator it=seq.begin(); it!=seq.end(); ++it) {
544         s.push_back((*it).expand(options));
545     }
546     return s;
547 }
548
549 exvector const & ncmul::get_factors(void) const
550 {
551     return seq;
552 }
553
554 //////////
555 // static member variables
556 //////////
557
558 // protected
559
560 unsigned ncmul::precedence=50;
561
562
563 //////////
564 // global constants
565 //////////
566
567 const ncmul some_ncmul;
568 type_info const & typeid_ncmul=typeid(some_ncmul);
569
570 //////////
571 // friend functions
572 //////////
573
574 ex nonsimplified_ncmul(exvector const & v)
575 {
576     return (new ncmul(v))->setflag(status_flags::dynallocated);
577 }
578
579 ex simplified_ncmul(exvector const & v)
580 {
581     if (v.size()==0) {
582         return exONE();
583     } else if (v.size()==1) {
584         return v[0];
585     }
586     return (new ncmul(v))->setflag(status_flags::dynallocated |
587                                    status_flags::evaluated);
588 }
589
590