865011119cf3ce27f9b1d87e6fc5846cfa714132
[ginac.git] / ginac / ncmul.cpp
1 /** @file ncmul.cpp
2  *
3  *  Implementation of GiNaC's non-commutative products of expressions. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2001 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
21  */
22
23 #include <algorithm>
24 #include <iostream>
25 #include <stdexcept>
26
27 #include "ncmul.h"
28 #include "ex.h"
29 #include "add.h"
30 #include "mul.h"
31 #include "archive.h"
32 #include "debugmsg.h"
33 #include "utils.h"
34
35 #ifndef NO_NAMESPACE_GINAC
36 namespace GiNaC {
37 #endif // ndef NO_NAMESPACE_GINAC
38
39 GINAC_IMPLEMENT_REGISTERED_CLASS(ncmul, exprseq)
40
41 //////////
42 // default constructor, destructor, copy constructor, assignment operator and helpers
43 //////////
44
45 // public
46
47 ncmul::ncmul()
48 {
49         debugmsg("ncmul default constructor",LOGLEVEL_CONSTRUCT);
50         tinfo_key = TINFO_ncmul;
51 }
52
53 // protected
54
55 void ncmul::copy(const ncmul & other)
56 {
57         inherited::copy(other);
58 }
59
60 void ncmul::destroy(bool call_parent)
61 {
62         if (call_parent) inherited::destroy(call_parent);
63 }
64
65 //////////
66 // other constructors
67 //////////
68
69 // public
70
71 ncmul::ncmul(const ex & lh, const ex & rh) : inherited(lh,rh)
72 {
73         debugmsg("ncmul constructor from ex,ex",LOGLEVEL_CONSTRUCT);
74         tinfo_key = TINFO_ncmul;
75 }
76
77 ncmul::ncmul(const ex & f1, const ex & f2, const ex & f3) : inherited(f1,f2,f3)
78 {
79         debugmsg("ncmul constructor from 3 ex",LOGLEVEL_CONSTRUCT);
80         tinfo_key = TINFO_ncmul;
81 }
82
83 ncmul::ncmul(const ex & f1, const ex & f2, const ex & f3,
84              const ex & f4) : inherited(f1,f2,f3,f4)
85 {
86         debugmsg("ncmul constructor from 4 ex",LOGLEVEL_CONSTRUCT);
87         tinfo_key = TINFO_ncmul;
88 }
89
90 ncmul::ncmul(const ex & f1, const ex & f2, const ex & f3,
91              const ex & f4, const ex & f5) : inherited(f1,f2,f3,f4,f5)
92 {
93         debugmsg("ncmul constructor from 5 ex",LOGLEVEL_CONSTRUCT);
94         tinfo_key = TINFO_ncmul;
95 }
96
97 ncmul::ncmul(const ex & f1, const ex & f2, const ex & f3,
98              const ex & f4, const ex & f5, const ex & f6) : inherited(f1,f2,f3,f4,f5,f6)
99 {
100         debugmsg("ncmul constructor from 6 ex",LOGLEVEL_CONSTRUCT);
101         tinfo_key = TINFO_ncmul;
102 }
103
104 ncmul::ncmul(const exvector & v, bool discardable) : inherited(v,discardable)
105 {
106         debugmsg("ncmul constructor from exvector,bool",LOGLEVEL_CONSTRUCT);
107         tinfo_key = TINFO_ncmul;
108 }
109
110 ncmul::ncmul(exvector * vp) : inherited(vp)
111 {
112         debugmsg("ncmul constructor from exvector *",LOGLEVEL_CONSTRUCT);
113         tinfo_key = TINFO_ncmul;
114 }
115
116 //////////
117 // archiving
118 //////////
119
120 /** Construct object from archive_node. */
121 ncmul::ncmul(const archive_node &n, const lst &sym_lst) : inherited(n, sym_lst)
122 {
123         debugmsg("ncmul constructor from archive_node", LOGLEVEL_CONSTRUCT);
124 }
125
126 /** Unarchive the object. */
127 ex ncmul::unarchive(const archive_node &n, const lst &sym_lst)
128 {
129         return (new ncmul(n, sym_lst))->setflag(status_flags::dynallocated);
130 }
131
132 /** Archive the object. */
133 void ncmul::archive(archive_node &n) const
134 {
135         inherited::archive(n);
136 }
137
138         
139 //////////
140 // functions overriding virtual functions from base classes
141 //////////
142
143 // public
144
145 void ncmul::print(std::ostream & os, unsigned upper_precedence) const
146 {
147         debugmsg("ncmul print",LOGLEVEL_PRINT);
148         printseq(os,'(','%',')',precedence,upper_precedence);
149 }
150
151 void ncmul::printraw(std::ostream & os) const
152 {
153         debugmsg("ncmul printraw",LOGLEVEL_PRINT);
154         os << "%(";
155         for (exvector::const_iterator it=seq.begin(); it!=seq.end(); ++it) {
156                 (*it).bp->printraw(os);
157                 os << ",";
158         }
159         os << ",hash=" << hashvalue << ",flags=" << flags;
160         os << ")";
161 }
162
163 void ncmul::printcsrc(std::ostream & os, unsigned type, unsigned upper_precedence) const
164 {
165         debugmsg("ncmul print csrc",LOGLEVEL_PRINT);
166         exvector::const_iterator it;
167         exvector::const_iterator itend = seq.end()-1;
168         os << "ncmul(";
169         for (it=seq.begin(); it!=itend; ++it) {
170                 (*it).bp->printcsrc(os,precedence);
171                 os << ",";
172         }
173         (*it).bp->printcsrc(os,precedence);
174         os << ")";
175 }
176
177 bool ncmul::info(unsigned inf) const
178 {
179         throw(std::logic_error("which flags have to be implemented in ncmul::info()?"));
180 }
181
182 typedef std::vector<int> intvector;
183
184 ex ncmul::expand(unsigned options) const
185 {
186         exvector sub_expanded_seq;
187         intvector positions_of_adds;
188         intvector number_of_add_operands;
189
190         exvector expanded_seq=expandchildren(options);
191
192         positions_of_adds.resize(expanded_seq.size());
193         number_of_add_operands.resize(expanded_seq.size());
194
195         int number_of_adds=0;
196         int number_of_expanded_terms=1;
197
198         unsigned current_position=0;
199         exvector::const_iterator last=expanded_seq.end();
200         for (exvector::const_iterator cit=expanded_seq.begin(); cit!=last; ++cit) {
201                 if (is_ex_exactly_of_type((*cit),add)) {
202                         positions_of_adds[number_of_adds]=current_position;
203                         const add & expanded_addref=ex_to_add(*cit);
204                         number_of_add_operands[number_of_adds]=expanded_addref.seq.size();
205                         number_of_expanded_terms *= expanded_addref.seq.size();
206                         number_of_adds++;
207                 }
208                 current_position++;
209         }
210
211         if (number_of_adds==0) {
212                 return (new ncmul(expanded_seq,1))->setflag(status_flags::dynallocated ||
213                                                                                                         status_flags::expanded);
214         }
215
216         exvector distrseq;
217         distrseq.reserve(number_of_expanded_terms);
218
219         intvector k;
220         k.resize(number_of_adds);
221         
222         int l;
223         for (l=0; l<number_of_adds; l++) {
224                 k[l]=0;
225         }
226
227         while (1) {
228                 exvector term;
229                 term=expanded_seq;
230                 for (l=0; l<number_of_adds; l++) {
231                         GINAC_ASSERT(is_ex_exactly_of_type(expanded_seq[positions_of_adds[l]],add));
232                         const add & addref=ex_to_add(expanded_seq[positions_of_adds[l]]);
233                         term[positions_of_adds[l]]=addref.recombine_pair_to_ex(addref.seq[k[l]]);
234                 }
235                 distrseq.push_back((new ncmul(term,1))->setflag(status_flags::dynallocated |
236                                                                                                                 status_flags::expanded));
237
238                 // increment k[]
239                 l=number_of_adds-1;
240                 while ((l>=0)&&((++k[l])>=number_of_add_operands[l])) {
241                         k[l]=0;    
242                         l--;
243                 }
244                 if (l<0) break;
245         }
246
247         return (new add(distrseq))->setflag(status_flags::dynallocated |
248                                                                                 status_flags::expanded);
249 }
250
251 int ncmul::degree(const symbol & s) const
252 {
253         int deg_sum=0;
254         for (exvector::const_iterator cit=seq.begin(); cit!=seq.end(); ++cit) {
255                 deg_sum+=(*cit).degree(s);
256         }
257         return deg_sum;
258 }
259
260 int ncmul::ldegree(const symbol & s) const
261 {
262         int deg_sum=0;
263         for (exvector::const_iterator cit=seq.begin(); cit!=seq.end(); ++cit) {
264                 deg_sum+=(*cit).ldegree(s);
265         }
266         return deg_sum;
267 }
268
269 ex ncmul::coeff(const symbol & s, int n) const
270 {
271         exvector coeffseq;
272         coeffseq.reserve(seq.size());
273
274         if (n==0) {
275                 // product of individual coeffs
276                 // if a non-zero power of s is found, the resulting product will be 0
277                 exvector::const_iterator it=seq.begin();
278                 while (it!=seq.end()) {
279                         coeffseq.push_back((*it).coeff(s,n));
280                         ++it;
281                 }
282                 return (new ncmul(coeffseq,1))->setflag(status_flags::dynallocated);
283         }
284                  
285         exvector::const_iterator it=seq.begin();
286         bool coeff_found=0;
287         while (it!=seq.end()) {
288                 ex c=(*it).coeff(s,n);
289                 if (!c.is_zero()) {
290                         coeffseq.push_back(c);
291                         coeff_found=1;
292                 } else {
293                         coeffseq.push_back(*it);
294                 }
295                 ++it;
296         }
297
298         if (coeff_found) return (new ncmul(coeffseq,1))->setflag(status_flags::dynallocated);
299         
300         return _ex0();
301 }
302
303 unsigned ncmul::count_factors(const ex & e) const
304 {
305         if ((is_ex_exactly_of_type(e,mul)&&(e.return_type()!=return_types::commutative))||
306                 (is_ex_exactly_of_type(e,ncmul))) {
307                 unsigned factors=0;
308                 for (unsigned i=0; i<e.nops(); i++)
309                         factors += count_factors(e.op(i));
310                 
311                 return factors;
312         }
313         return 1;
314 }
315                 
316 void ncmul::append_factors(exvector & v, const ex & e) const
317 {
318         if ((is_ex_exactly_of_type(e,mul)&&(e.return_type()!=return_types::commutative))||
319                 (is_ex_exactly_of_type(e,ncmul))) {
320                 for (unsigned i=0; i<e.nops(); i++)
321                         append_factors(v,e.op(i));
322                 
323                 return;
324         }
325         v.push_back(e);
326 }
327
328 typedef std::vector<unsigned> unsignedvector;
329 typedef std::vector<exvector> exvectorvector;
330
331 ex ncmul::eval(int level) const
332 {
333         // simplifications: ncmul(...,*(x1,x2),...,ncmul(x3,x4),...) ->
334         //                      ncmul(...,x1,x2,...,x3,x4,...) (associativity)
335         //                  ncmul(x) -> x
336         //                  ncmul() -> 1
337         //                  ncmul(...,c1,...,c2,...)
338         //                      *(c1,c2,ncmul(...)) (pull out commutative elements)
339         //                  ncmul(x1,y1,x2,y2) -> *(ncmul(x1,x2),ncmul(y1,y2))
340         //                      (collect elements of same type)
341         //                  ncmul(x1,x2,x3,...) -> x::eval_ncmul(x1,x2,x3,...)
342         // the following rule would be nice, but produces a recursion,
343         // which must be trapped by introducing a flag that the sub-ncmuls()
344         // are already evaluated (maybe later...)
345         //                  ncmul(x1,x2,...,X,y1,y2,...) ->
346         //                      ncmul(ncmul(x1,x2,...),X,ncmul(y1,y2,...)
347         //                      (X noncommutative_composite)
348
349         if ((level==1)&&(flags & status_flags::evaluated)) {
350                 return *this;
351         }
352
353         exvector evaledseq=evalchildren(level);
354
355         // ncmul(...,*(x1,x2),...,ncmul(x3,x4),...) ->
356         //     ncmul(...,x1,x2,...,x3,x4,...) (associativity)
357         unsigned factors=0;
358         for (exvector::const_iterator cit=evaledseq.begin(); cit!=evaledseq.end(); ++cit) {
359                 factors += count_factors(*cit);
360         }
361
362         exvector assocseq;
363         assocseq.reserve(factors);
364         for (exvector::const_iterator cit=evaledseq.begin(); cit!=evaledseq.end(); ++cit) {
365                 append_factors(assocseq,*cit);
366         }
367
368         // ncmul(x) -> x
369         if (assocseq.size()==1) return *(seq.begin());
370
371         // ncmul() -> 1
372         if (assocseq.size()==0) return _ex1();
373
374         // determine return types
375         unsignedvector rettypes;
376         rettypes.reserve(assocseq.size());
377         unsigned i=0;
378         unsigned count_commutative=0;
379         unsigned count_noncommutative=0;
380         unsigned count_noncommutative_composite=0;
381         for (exvector::const_iterator cit=assocseq.begin(); cit!=assocseq.end(); ++cit) {
382                 switch (rettypes[i]=(*cit).return_type()) {
383                 case return_types::commutative:
384                         count_commutative++;
385                         break;
386                 case return_types::noncommutative:
387                         count_noncommutative++;
388                         break;
389                 case return_types::noncommutative_composite:
390                         count_noncommutative_composite++;
391                         break;
392                 default:
393                         throw(std::logic_error("ncmul::eval(): invalid return type"));
394                 }
395                 ++i;
396         }
397         GINAC_ASSERT(count_commutative+count_noncommutative+count_noncommutative_composite==assocseq.size());
398
399         // ncmul(...,c1,...,c2,...) ->
400         //     *(c1,c2,ncmul(...)) (pull out commutative elements)
401         if (count_commutative!=0) {
402                 exvector commutativeseq;
403                 commutativeseq.reserve(count_commutative+1);
404                 exvector noncommutativeseq;
405                 noncommutativeseq.reserve(assocseq.size()-count_commutative);
406                 for (i=0; i<assocseq.size(); ++i) {
407                         if (rettypes[i]==return_types::commutative) {
408                                 commutativeseq.push_back(assocseq[i]);
409                         } else {
410                                 noncommutativeseq.push_back(assocseq[i]);
411                         }
412                 }
413                 commutativeseq.push_back((new ncmul(noncommutativeseq,1))->setflag(status_flags::dynallocated));
414                 return (new mul(commutativeseq))->setflag(status_flags::dynallocated);
415         }
416                 
417         // ncmul(x1,y1,x2,y2) -> *(ncmul(x1,x2),ncmul(y1,y2))
418         //     (collect elements of same type)
419
420         if (count_noncommutative_composite==0) {
421                 // there are neither commutative nor noncommutative_composite
422                 // elements in assocseq
423                 GINAC_ASSERT(count_commutative==0);
424
425                 exvectorvector evv;
426                 unsignedvector rttinfos;
427                 evv.reserve(assocseq.size());
428                 rttinfos.reserve(assocseq.size());
429
430                 for (exvector::const_iterator cit=assocseq.begin(); cit!=assocseq.end(); ++cit) {
431                         unsigned ti=(*cit).return_type_tinfo();
432                         // search type in vector of known types
433                         for (i=0; i<rttinfos.size(); ++i) {
434                                 if (ti==rttinfos[i]) {
435                                         evv[i].push_back(*cit);
436                                         break;
437                                 }
438                         }
439                         if (i>=rttinfos.size()) {
440                                 // new type
441                                 rttinfos.push_back(ti);
442                                 evv.push_back(exvector());
443                                 (*(evv.end()-1)).reserve(assocseq.size());
444                                 (*(evv.end()-1)).push_back(*cit);
445                         }
446                 }
447
448 #ifdef DO_GINAC_ASSERT
449                 GINAC_ASSERT(evv.size()==rttinfos.size());
450                 GINAC_ASSERT(evv.size()>0);
451                 unsigned s=0;
452                 for (i=0; i<evv.size(); ++i) {
453                         s += evv[i].size();
454                 }
455                 GINAC_ASSERT(s==assocseq.size());
456 #endif // def DO_GINAC_ASSERT
457                 
458                 // if all elements are of same type, simplify the string
459                 if (evv.size()==1) {
460                         return evv[0][0].simplify_ncmul(evv[0]);
461                 }
462                 
463                 exvector splitseq;
464                 splitseq.reserve(evv.size());
465                 for (i=0; i<evv.size(); ++i) {
466                         splitseq.push_back((new ncmul(evv[i]))->setflag(status_flags::dynallocated));
467                 }
468
469                 return (new mul(splitseq))->setflag(status_flags::dynallocated);
470         }
471         
472         return (new ncmul(assocseq))->setflag(status_flags::dynallocated |
473                                                                                   status_flags::evaluated);
474 }
475
476 exvector ncmul::get_indices(void) const
477 {
478         // return union of indices of factors
479         exvector iv;
480         for (exvector::const_iterator cit=seq.begin(); cit!=seq.end(); ++cit) {
481                 exvector subiv=(*cit).get_indices();
482                 iv.reserve(iv.size()+subiv.size());
483                 for (exvector::const_iterator cit2=subiv.begin(); cit2!=subiv.end(); ++cit2) {
484                         iv.push_back(*cit2);
485                 }
486         }
487         return iv;
488 }
489
490 ex ncmul::subs(const lst & ls, const lst & lr) const
491 {
492         return ncmul(subschildren(ls, lr));
493 }
494
495 ex ncmul::thisexprseq(const exvector & v) const
496 {
497         return (new ncmul(v))->setflag(status_flags::dynallocated);
498 }
499
500 ex ncmul::thisexprseq(exvector * vp) const
501 {
502         return (new ncmul(vp))->setflag(status_flags::dynallocated);
503 }
504
505 // protected
506
507 /** Implementation of ex::diff() for a non-commutative product. It always returns 0.
508  *  @see ex::diff */
509 ex ncmul::derivative(const symbol & s) const
510 {
511         return _ex0();
512 }
513
514 int ncmul::compare_same_type(const basic & other) const
515 {
516         return inherited::compare_same_type(other);
517 }
518
519 unsigned ncmul::return_type(void) const
520 {
521         if (seq.size()==0) {
522                 // ncmul without factors: should not happen, but commutes
523                 return return_types::commutative;
524         }
525
526         bool all_commutative=1;
527         unsigned rt;
528         exvector::const_iterator cit_noncommutative_element; // point to first found nc element
529
530         for (exvector::const_iterator cit=seq.begin(); cit!=seq.end(); ++cit) {
531                 rt=(*cit).return_type();
532                 if (rt==return_types::noncommutative_composite) return rt; // one ncc -> mul also ncc
533                 if ((rt==return_types::noncommutative)&&(all_commutative)) {
534                         // first nc element found, remember position
535                         cit_noncommutative_element=cit;
536                         all_commutative=0;
537                 }
538                 if ((rt==return_types::noncommutative)&&(!all_commutative)) {
539                         // another nc element found, compare type_infos
540                         if ((*cit_noncommutative_element).return_type_tinfo()!=(*cit).return_type_tinfo()) {
541                                 // diffent types -> mul is ncc
542                                 return return_types::noncommutative_composite;
543                         }
544                 }
545         }
546         // all factors checked
547         GINAC_ASSERT(!all_commutative); // not all factors should commute, because this is a ncmul();
548         return all_commutative ? return_types::commutative : return_types::noncommutative;
549 }
550    
551 unsigned ncmul::return_type_tinfo(void) const
552 {
553         if (seq.size()==0) {
554                 // mul without factors: should not happen
555                 return tinfo_key;
556         }
557         // return type_info of first noncommutative element
558         for (exvector::const_iterator cit=seq.begin(); cit!=seq.end(); ++cit) {
559                 if ((*cit).return_type()==return_types::noncommutative) {
560                         return (*cit).return_type_tinfo();
561                 }
562         }
563         // no noncommutative element found, should not happen
564         return tinfo_key;
565 }
566
567 //////////
568 // new virtual functions which can be overridden by derived classes
569 //////////
570
571 // none
572
573 //////////
574 // non-virtual functions in this class
575 //////////
576
577 exvector ncmul::expandchildren(unsigned options) const
578 {
579         exvector s;
580         s.reserve(seq.size());
581
582         for (exvector::const_iterator it=seq.begin(); it!=seq.end(); ++it) {
583                 s.push_back((*it).expand(options));
584         }
585         return s;
586 }
587
588 const exvector & ncmul::get_factors(void) const
589 {
590         return seq;
591 }
592
593 //////////
594 // static member variables
595 //////////
596
597 // protected
598
599 unsigned ncmul::precedence=50;
600
601 //////////
602 // friend functions
603 //////////
604
605 ex nonsimplified_ncmul(const exvector & v)
606 {
607         return (new ncmul(v))->setflag(status_flags::dynallocated);
608 }
609
610 ex simplified_ncmul(const exvector & v)
611 {
612         if (v.size()==0) {
613                 return _ex1();
614         } else if (v.size()==1) {
615                 return v[0];
616         }
617         return (new ncmul(v))->setflag(status_flags::dynallocated |
618                                        status_flags::evaluated);
619 }
620
621 #ifndef NO_NAMESPACE_GINAC
622 } // namespace GiNaC
623 #endif // ndef NO_NAMESPACE_GINAC