- replaced the Derivative() function with a more reasonable fderivative class;
[ginac.git] / ginac / ncmul.cpp
1 /** @file ncmul.cpp
2  *
3  *  Implementation of GiNaC's non-commutative products of expressions. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2001 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
21  */
22
23 #include <algorithm>
24 #include <iostream>
25 #include <stdexcept>
26
27 #include "ncmul.h"
28 #include "ex.h"
29 #include "add.h"
30 #include "mul.h"
31 #include "matrix.h"
32 #include "print.h"
33 #include "archive.h"
34 #include "debugmsg.h"
35 #include "utils.h"
36
37 namespace GiNaC {
38
// Registers ncmul with GiNaC's class-registration machinery, naming exprseq
// as its parent class. NOTE(review): the exact code emitted is defined by the
// macro in a project header, not visible here.
39 GINAC_IMPLEMENT_REGISTERED_CLASS(ncmul, exprseq)
40
41 //////////
42 // default constructor, destructor, copy constructor, assignment operator and helpers
43 //////////
44
45 ncmul::ncmul()
46 {
47         debugmsg("ncmul default constructor",LOGLEVEL_CONSTRUCT);
48         tinfo_key = TINFO_ncmul;
49 }
50
// Copy and destroy helpers for ncmul, generated by the DEFAULT_COPY and
// DEFAULT_DESTROY macros. NOTE(review): the members these expand to are
// defined by the macros elsewhere in the project; confirm there.
51 DEFAULT_COPY(ncmul)
52 DEFAULT_DESTROY(ncmul)
53
54 //////////
55 // other constructors
56 //////////
57
58 // public
59
60 ncmul::ncmul(const ex & lh, const ex & rh) : inherited(lh,rh)
61 {
62         debugmsg("ncmul constructor from ex,ex",LOGLEVEL_CONSTRUCT);
63         tinfo_key = TINFO_ncmul;
64 }
65
66 ncmul::ncmul(const ex & f1, const ex & f2, const ex & f3) : inherited(f1,f2,f3)
67 {
68         debugmsg("ncmul constructor from 3 ex",LOGLEVEL_CONSTRUCT);
69         tinfo_key = TINFO_ncmul;
70 }
71
72 ncmul::ncmul(const ex & f1, const ex & f2, const ex & f3,
73              const ex & f4) : inherited(f1,f2,f3,f4)
74 {
75         debugmsg("ncmul constructor from 4 ex",LOGLEVEL_CONSTRUCT);
76         tinfo_key = TINFO_ncmul;
77 }
78
79 ncmul::ncmul(const ex & f1, const ex & f2, const ex & f3,
80              const ex & f4, const ex & f5) : inherited(f1,f2,f3,f4,f5)
81 {
82         debugmsg("ncmul constructor from 5 ex",LOGLEVEL_CONSTRUCT);
83         tinfo_key = TINFO_ncmul;
84 }
85
86 ncmul::ncmul(const ex & f1, const ex & f2, const ex & f3,
87              const ex & f4, const ex & f5, const ex & f6) : inherited(f1,f2,f3,f4,f5,f6)
88 {
89         debugmsg("ncmul constructor from 6 ex",LOGLEVEL_CONSTRUCT);
90         tinfo_key = TINFO_ncmul;
91 }
92
93 ncmul::ncmul(const exvector & v, bool discardable) : inherited(v,discardable)
94 {
95         debugmsg("ncmul constructor from exvector,bool",LOGLEVEL_CONSTRUCT);
96         tinfo_key = TINFO_ncmul;
97 }
98
99 ncmul::ncmul(exvector * vp) : inherited(vp)
100 {
101         debugmsg("ncmul constructor from exvector *",LOGLEVEL_CONSTRUCT);
102         tinfo_key = TINFO_ncmul;
103 }
104
105 //////////
106 // archiving
107 //////////
108
// Archiving support for ncmul, generated by the DEFAULT_ARCHIVING macro.
// NOTE(review): the exact members emitted are defined by the macro elsewhere
// in the project.
109 DEFAULT_ARCHIVING(ncmul)
110         
111 //////////
112 // functions overriding virtual functions from base classes
113 //////////
114
115 // public
116
117 void ncmul::print(const print_context & c, unsigned level) const
118 {
119         debugmsg("ncmul print", LOGLEVEL_PRINT);
120
121         if (is_of_type(c, print_tree)) {
122
123                 inherited::print(c, level);
124
125         } else if (is_of_type(c, print_csrc)) {
126
127                 c.s << "ncmul(";
128                 exvector::const_iterator it = seq.begin(), itend = seq.end()-1;
129                 while (it != itend) {
130                         it->print(c, precedence());
131                         c.s << ",";
132                         it++;
133                 }
134                 it->print(c, precedence());
135                 c.s << ")";
136
137         } else
138                 printseq(c, '(', '*', ')', precedence(), level);
139 }
140
141 bool ncmul::info(unsigned inf) const
142 {
143         return inherited::info(inf);
144 }
145
146 typedef std::vector<int> intvector;
147
148 ex ncmul::expand(unsigned options) const
149 {
150         // First, expand the children
151         exvector expanded_seq = expandchildren(options);
152         
153         // Now, look for all the factors that are sums and remember their
154         // position and number of terms.
155         intvector positions_of_adds(expanded_seq.size());
156         intvector number_of_add_operands(expanded_seq.size());
157
158         int number_of_adds = 0;
159         int number_of_expanded_terms = 1;
160
161         unsigned current_position = 0;
162         exvector::const_iterator last = expanded_seq.end();
163         for (exvector::const_iterator cit=expanded_seq.begin(); cit!=last; ++cit) {
164                 if (is_exactly_a<add>(*cit)) {
165                         positions_of_adds[number_of_adds] = current_position;
166                         unsigned num_ops = cit->nops();
167                         number_of_add_operands[number_of_adds] = num_ops;
168                         number_of_expanded_terms *= num_ops;
169                         number_of_adds++;
170                 }
171                 ++current_position;
172         }
173
174         // If there are no sums, we are done
175         if (number_of_adds == 0)
176                 return (new ncmul(expanded_seq, true))->
177                         setflag(status_flags::dynallocated | (options == 0 ? status_flags::expanded : 0));
178
179         // Now, form all possible products of the terms of the sums with the
180         // remaining factors, and add them together
181         exvector distrseq;
182         distrseq.reserve(number_of_expanded_terms);
183
184         intvector k(number_of_adds);
185
186         while (true) {
187                 exvector term = expanded_seq;
188                 for (int i=0; i<number_of_adds; i++)
189                         term[positions_of_adds[i]] = expanded_seq[positions_of_adds[i]].op(k[i]);
190                 distrseq.push_back((new ncmul(term, true))->
191                                     setflag(status_flags::dynallocated | (options == 0 ? status_flags::expanded : 0)));
192
193                 // increment k[]
194                 int l = number_of_adds-1;
195                 while ((l>=0) && ((++k[l]) >= number_of_add_operands[l])) {
196                         k[l] = 0;
197                         l--;
198                 }
199                 if (l<0)
200                         break;
201         }
202
203         return (new add(distrseq))->
204                 setflag(status_flags::dynallocated | (options == 0 ? status_flags::expanded : 0));
205 }
206
207 int ncmul::degree(const ex & s) const
208 {
209         // Sum up degrees of factors
210         int deg_sum = 0;
211         exvector::const_iterator i = seq.begin(), end = seq.end();
212         while (i != end) {
213                 deg_sum += i->degree(s);
214                 ++i;
215         }
216         return deg_sum;
217 }
218
219 int ncmul::ldegree(const ex & s) const
220 {
221         // Sum up degrees of factors
222         int deg_sum = 0;
223         exvector::const_iterator i = seq.begin(), end = seq.end();
224         while (i != end) {
225                 deg_sum += i->degree(s);
226                 ++i;
227         }
228         return deg_sum;
229 }
230
231 ex ncmul::coeff(const ex & s, int n) const
232 {
233         exvector coeffseq;
234         coeffseq.reserve(seq.size());
235
236         if (n == 0) {
237                 // product of individual coeffs
238                 // if a non-zero power of s is found, the resulting product will be 0
239                 exvector::const_iterator it=seq.begin();
240                 while (it!=seq.end()) {
241                         coeffseq.push_back((*it).coeff(s,n));
242                         ++it;
243                 }
244                 return (new ncmul(coeffseq,1))->setflag(status_flags::dynallocated);
245         }
246                  
247         exvector::const_iterator i = seq.begin(), end = seq.end();
248         bool coeff_found = false;
249         while (i != end) {
250                 ex c = i->coeff(s,n);
251                 if (c.is_zero()) {
252                         coeffseq.push_back(*i);
253                 } else {
254                         coeffseq.push_back(c);
255                         coeff_found = true;
256                 }
257                 ++i;
258         }
259
260         if (coeff_found) return (new ncmul(coeffseq,1))->setflag(status_flags::dynallocated);
261         
262         return _ex0();
263 }
264
265 unsigned ncmul::count_factors(const ex & e) const
266 {
267         if ((is_ex_exactly_of_type(e,mul)&&(e.return_type()!=return_types::commutative))||
268                 (is_ex_exactly_of_type(e,ncmul))) {
269                 unsigned factors=0;
270                 for (unsigned i=0; i<e.nops(); i++)
271                         factors += count_factors(e.op(i));
272                 
273                 return factors;
274         }
275         return 1;
276 }
277                 
278 void ncmul::append_factors(exvector & v, const ex & e) const
279 {
280         if ((is_ex_exactly_of_type(e,mul)&&(e.return_type()!=return_types::commutative))||
281                 (is_ex_exactly_of_type(e,ncmul))) {
282                 for (unsigned i=0; i<e.nops(); i++)
283                         append_factors(v,e.op(i));
284         } else 
285                 v.push_back(e);
286 }
287
288 typedef std::vector<unsigned> unsignedvector;
289 typedef std::vector<exvector> exvectorvector;
290
291 ex ncmul::eval(int level) const
292 {
293         // simplifications: ncmul(...,*(x1,x2),...,ncmul(x3,x4),...) ->
294         //                      ncmul(...,x1,x2,...,x3,x4,...) (associativity)
295         //                  ncmul(x) -> x
296         //                  ncmul() -> 1
297         //                  ncmul(...,c1,...,c2,...)
298         //                      *(c1,c2,ncmul(...)) (pull out commutative elements)
299         //                  ncmul(x1,y1,x2,y2) -> *(ncmul(x1,x2),ncmul(y1,y2))
300         //                      (collect elements of same type)
301         //                  ncmul(x1,x2,x3,...) -> x::simplify_ncmul(x1,x2,x3,...)
302         // the following rule would be nice, but produces a recursion,
303         // which must be trapped by introducing a flag that the sub-ncmuls()
304         // are already evaluated (maybe later...)
305         //                  ncmul(x1,x2,...,X,y1,y2,...) ->
306         //                      ncmul(ncmul(x1,x2,...),X,ncmul(y1,y2,...)
307         //                      (X noncommutative_composite)
308
309         if ((level==1) && (flags & status_flags::evaluated)) {
310                 return *this;
311         }
312
313         exvector evaledseq=evalchildren(level);
314
315         // ncmul(...,*(x1,x2),...,ncmul(x3,x4),...) ->
316         //     ncmul(...,x1,x2,...,x3,x4,...) (associativity)
317         unsigned factors = 0;
318         exvector::const_iterator cit = evaledseq.begin(), citend = evaledseq.end();
319         while (cit != citend)
320                 factors += count_factors(*cit++);
321         
322         exvector assocseq;
323         assocseq.reserve(factors);
324         cit = evaledseq.begin();
325         while (cit != citend)
326                 append_factors(assocseq, *cit++);
327         
328         // ncmul(x) -> x
329         if (assocseq.size()==1) return *(seq.begin());
330
331         // ncmul() -> 1
332         if (assocseq.empty()) return _ex1();
333
334         // determine return types
335         unsignedvector rettypes;
336         rettypes.reserve(assocseq.size());
337         unsigned i = 0;
338         unsigned count_commutative=0;
339         unsigned count_noncommutative=0;
340         unsigned count_noncommutative_composite=0;
341         cit = assocseq.begin(); citend = assocseq.end();
342         while (cit != citend) {
343                 switch (rettypes[i] = cit->return_type()) {
344                 case return_types::commutative:
345                         count_commutative++;
346                         break;
347                 case return_types::noncommutative:
348                         count_noncommutative++;
349                         break;
350                 case return_types::noncommutative_composite:
351                         count_noncommutative_composite++;
352                         break;
353                 default:
354                         throw(std::logic_error("ncmul::eval(): invalid return type"));
355                 }
356                 ++i; ++cit;
357         }
358         GINAC_ASSERT(count_commutative+count_noncommutative+count_noncommutative_composite==assocseq.size());
359
360         // ncmul(...,c1,...,c2,...) ->
361         //     *(c1,c2,ncmul(...)) (pull out commutative elements)
362         if (count_commutative!=0) {
363                 exvector commutativeseq;
364                 commutativeseq.reserve(count_commutative+1);
365                 exvector noncommutativeseq;
366                 noncommutativeseq.reserve(assocseq.size()-count_commutative);
367                 unsigned num = assocseq.size();
368                 for (unsigned i=0; i<num; ++i) {
369                         if (rettypes[i]==return_types::commutative)
370                                 commutativeseq.push_back(assocseq[i]);
371                         else
372                                 noncommutativeseq.push_back(assocseq[i]);
373                 }
374                 commutativeseq.push_back((new ncmul(noncommutativeseq,1))->setflag(status_flags::dynallocated));
375                 return (new mul(commutativeseq))->setflag(status_flags::dynallocated);
376         }
377                 
378         // ncmul(x1,y1,x2,y2) -> *(ncmul(x1,x2),ncmul(y1,y2))
379         //     (collect elements of same type)
380
381         if (count_noncommutative_composite==0) {
382                 // there are neither commutative nor noncommutative_composite
383                 // elements in assocseq
384                 GINAC_ASSERT(count_commutative==0);
385
386                 unsigned assoc_num = assocseq.size();
387                 exvectorvector evv;
388                 unsignedvector rttinfos;
389                 evv.reserve(assoc_num);
390                 rttinfos.reserve(assoc_num);
391
392                 cit = assocseq.begin(), citend = assocseq.end();
393                 while (cit != citend) {
394                         unsigned ti = cit->return_type_tinfo();
395                         unsigned rtt_num = rttinfos.size();
396                         // search type in vector of known types
397                         for (i=0; i<rtt_num; ++i) {
398                                 if (ti == rttinfos[i]) {
399                                         evv[i].push_back(*cit);
400                                         break;
401                                 }
402                         }
403                         if (i >= rtt_num) {
404                                 // new type
405                                 rttinfos.push_back(ti);
406                                 evv.push_back(exvector());
407                                 (evv.end()-1)->reserve(assoc_num);
408                                 (evv.end()-1)->push_back(*cit);
409                         }
410                         ++cit;
411                 }
412
413                 unsigned evv_num = evv.size();
414 #ifdef DO_GINAC_ASSERT
415                 GINAC_ASSERT(evv_num == rttinfos.size());
416                 GINAC_ASSERT(evv_num > 0);
417                 unsigned s=0;
418                 for (i=0; i<evv_num; ++i)
419                         s += evv[i].size();
420                 GINAC_ASSERT(s == assoc_num);
421 #endif // def DO_GINAC_ASSERT
422                 
423                 // if all elements are of same type, simplify the string
424                 if (evv_num == 1)
425                         return evv[0][0].simplify_ncmul(evv[0]);
426                 
427                 exvector splitseq;
428                 splitseq.reserve(evv_num);
429                 for (i=0; i<evv_num; ++i)
430                         splitseq.push_back((new ncmul(evv[i]))->setflag(status_flags::dynallocated));
431                 
432                 return (new mul(splitseq))->setflag(status_flags::dynallocated);
433         }
434         
435         return (new ncmul(assocseq))->setflag(status_flags::dynallocated |
436                                                                                   status_flags::evaluated);
437 }
438
439 ex ncmul::evalm(void) const
440 {
441         // Evaluate children first
442         exvector *s = new exvector;
443         s->reserve(seq.size());
444         exvector::const_iterator it = seq.begin(), itend = seq.end();
445         while (it != itend) {
446                 s->push_back(it->evalm());
447                 it++;
448         }
449
450         // If there are only matrices, simply multiply them
451         it = s->begin(); itend = s->end();
452         if (is_ex_of_type(*it, matrix)) {
453                 matrix prod(ex_to<matrix>(*it));
454                 it++;
455                 while (it != itend) {
456                         if (!is_ex_of_type(*it, matrix))
457                                 goto no_matrix;
458                         prod = prod.mul(ex_to<matrix>(*it));
459                         it++;
460                 }
461                 delete s;
462                 return prod;
463         }
464
465 no_matrix:
466         return (new ncmul(s))->setflag(status_flags::dynallocated);
467 }
468
469 ex ncmul::thisexprseq(const exvector & v) const
470 {
471         return (new ncmul(v))->setflag(status_flags::dynallocated);
472 }
473
474 ex ncmul::thisexprseq(exvector * vp) const
475 {
476         return (new ncmul(vp))->setflag(status_flags::dynallocated);
477 }
478
479 // protected
480
481 /** Implementation of ex::diff() for a non-commutative product. It applies
482  *  the product rule.
483  *  @see ex::diff */
484 ex ncmul::derivative(const symbol & s) const
485 {
486         unsigned num = seq.size();
487         exvector addseq;
488         addseq.reserve(num);
489         
490         // D(a*b*c) = D(a)*b*c + a*D(b)*c + a*b*D(c)
491         exvector ncmulseq = seq;
492         for (unsigned i=0; i<num; ++i) {
493                 ex e = seq[i].diff(s);
494                 e.swap(ncmulseq[i]);
495                 addseq.push_back((new ncmul(ncmulseq))->setflag(status_flags::dynallocated));
496                 e.swap(ncmulseq[i]);
497         }
498         return (new add(addseq))->setflag(status_flags::dynallocated);
499 }
500
501 int ncmul::compare_same_type(const basic & other) const
502 {
503         return inherited::compare_same_type(other);
504 }
505
506 unsigned ncmul::return_type(void) const
507 {
508         if (seq.empty())
509                 return return_types::commutative;
510
511         bool all_commutative = true;
512         exvector::const_iterator noncommutative_element; // point to first found nc element
513
514         exvector::const_iterator i = seq.begin(), end = seq.end();
515         while (i != end) {
516                 unsigned rt = i->return_type();
517                 if (rt == return_types::noncommutative_composite)
518                         return rt; // one ncc -> mul also ncc
519                 if ((rt == return_types::noncommutative) && (all_commutative)) {
520                         // first nc element found, remember position
521                         noncommutative_element = i;
522                         all_commutative = false;
523                 }
524                 if ((rt == return_types::noncommutative) && (!all_commutative)) {
525                         // another nc element found, compare type_infos
526                         if (noncommutative_element->return_type_tinfo() != i->return_type_tinfo()) {
527                                 // diffent types -> mul is ncc
528                                 return return_types::noncommutative_composite;
529                         }
530                 }
531                 ++i;
532         }
533         // all factors checked
534         GINAC_ASSERT(!all_commutative); // not all factors should commute, because this is a ncmul();
535         return all_commutative ? return_types::commutative : return_types::noncommutative;
536 }
537    
538 unsigned ncmul::return_type_tinfo(void) const
539 {
540         if (seq.empty())
541                 return tinfo_key;
542
543         // return type_info of first noncommutative element
544         exvector::const_iterator i = seq.begin(), end = seq.end();
545         while (i != end) {
546                 if (i->return_type() == return_types::noncommutative)
547                         return i->return_type_tinfo();
548                 ++i;
549         }
550
551         // no noncommutative element found, should not happen
552         return tinfo_key;
553 }
554
555 //////////
556 // new virtual functions which can be overridden by derived classes
557 //////////
558
559 // none
560
561 //////////
562 // non-virtual functions in this class
563 //////////
564
565 exvector ncmul::expandchildren(unsigned options) const
566 {
567         exvector s;
568         s.reserve(seq.size());
569         exvector::const_iterator it = seq.begin(), itend = seq.end();
570         while (it != itend) {
571                 s.push_back(it->expand(options));
572                 it++;
573         }
574         return s;
575 }
576
577 const exvector & ncmul::get_factors(void) const
578 {
579         return seq;
580 }
581
582 //////////
583 // friend functions
584 //////////
585
586 ex nonsimplified_ncmul(const exvector & v)
587 {
588         return (new ncmul(v))->setflag(status_flags::dynallocated);
589 }
590
591 ex simplified_ncmul(const exvector & v)
592 {
593         if (v.empty())
594                 return _ex1();
595         else if (v.size() == 1)
596                 return v[0];
597         else
598                 return (new ncmul(v))->setflag(status_flags::dynallocated |
599                                                status_flags::evaluated);
600 }
601
602 } // namespace GiNaC