generous use of auto_ptr to provide better exception safety and make the code
[ginac.git] / ginac / ncmul.cpp
1 /** @file ncmul.cpp
2  *
3  *  Implementation of GiNaC's non-commutative products of expressions. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2003 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
21  */
22
23 #include <algorithm>
24 #include <iostream>
25 #include <stdexcept>
26
27 #include "ncmul.h"
28 #include "ex.h"
29 #include "add.h"
30 #include "mul.h"
31 #include "matrix.h"
32 #include "archive.h"
33 #include "utils.h"
34
35 namespace GiNaC {
36
// Class registration for ncmul (parent class: exprseq).  The print_func chain
// selects the printing method per print context:
//  - print_context:     ncmul::do_print      (factors joined by '*')
//  - print_tree:        basic::do_print_tree (generic tree dump)
//  - print_csrc and
//    print_python_repr: ncmul::do_print_csrc (class name followed by a
//                       parenthesized, comma-separated factor list)
GINAC_IMPLEMENT_REGISTERED_CLASS_OPT(ncmul, exprseq,
  print_func<print_context>(&ncmul::do_print).
  print_func<print_tree>(&basic::do_print_tree).
  print_func<print_csrc>(&ncmul::do_print_csrc).
  print_func<print_python_repr>(&ncmul::do_print_csrc))
42
43
44 //////////
45 // default constructor
46 //////////
47
48 ncmul::ncmul()
49 {
50         tinfo_key = TINFO_ncmul;
51 }
52
53 //////////
54 // other constructors
55 //////////
56
57 // public
58
59 ncmul::ncmul(const ex & lh, const ex & rh) : inherited(lh,rh)
60 {
61         tinfo_key = TINFO_ncmul;
62 }
63
64 ncmul::ncmul(const ex & f1, const ex & f2, const ex & f3) : inherited(f1,f2,f3)
65 {
66         tinfo_key = TINFO_ncmul;
67 }
68
69 ncmul::ncmul(const ex & f1, const ex & f2, const ex & f3,
70              const ex & f4) : inherited(f1,f2,f3,f4)
71 {
72         tinfo_key = TINFO_ncmul;
73 }
74
75 ncmul::ncmul(const ex & f1, const ex & f2, const ex & f3,
76              const ex & f4, const ex & f5) : inherited(f1,f2,f3,f4,f5)
77 {
78         tinfo_key = TINFO_ncmul;
79 }
80
81 ncmul::ncmul(const ex & f1, const ex & f2, const ex & f3,
82              const ex & f4, const ex & f5, const ex & f6) : inherited(f1,f2,f3,f4,f5,f6)
83 {
84         tinfo_key = TINFO_ncmul;
85 }
86
87 ncmul::ncmul(const exvector & v, bool discardable) : inherited(v,discardable)
88 {
89         tinfo_key = TINFO_ncmul;
90 }
91
92 ncmul::ncmul(std::auto_ptr<exvector> vp) : inherited(vp)
93 {
94         tinfo_key = TINFO_ncmul;
95 }
96
97 //////////
98 // archiving
99 //////////
100
// Archiving support for ncmul comes entirely from the DEFAULT_ARCHIVING macro
// (presumably the default archive constructor / unarchive() / archive() for a
// class without extra data members -- confirm in archive.h).
DEFAULT_ARCHIVING(ncmul)
102         
103 //////////
104 // functions overriding virtual functions from base classes
105 //////////
106
107 // public
108
109 void ncmul::do_print(const print_context & c, unsigned level) const
110 {
111         printseq(c, '(', '*', ')', precedence(), level);
112 }
113
114 void ncmul::do_print_csrc(const print_context & c, unsigned level) const
115 {
116         c.s << class_name();
117         printseq(c, '(', ',', ')', precedence(), precedence());
118 }
119
120 bool ncmul::info(unsigned inf) const
121 {
122         return inherited::info(inf);
123 }
124
125 typedef std::vector<int> intvector;
126
127 ex ncmul::expand(unsigned options) const
128 {
129         // First, expand the children
130         exvector expanded_seq = expandchildren(options);
131         
132         // Now, look for all the factors that are sums and remember their
133         // position and number of terms.
134         intvector positions_of_adds(expanded_seq.size());
135         intvector number_of_add_operands(expanded_seq.size());
136
137         size_t number_of_adds = 0;
138         size_t number_of_expanded_terms = 1;
139
140         size_t current_position = 0;
141         exvector::const_iterator last = expanded_seq.end();
142         for (exvector::const_iterator cit=expanded_seq.begin(); cit!=last; ++cit) {
143                 if (is_exactly_a<add>(*cit)) {
144                         positions_of_adds[number_of_adds] = current_position;
145                         size_t num_ops = cit->nops();
146                         number_of_add_operands[number_of_adds] = num_ops;
147                         number_of_expanded_terms *= num_ops;
148                         number_of_adds++;
149                 }
150                 ++current_position;
151         }
152
153         // If there are no sums, we are done
154         if (number_of_adds == 0)
155                 return (new ncmul(expanded_seq, true))->
156                         setflag(status_flags::dynallocated | (options == 0 ? status_flags::expanded : 0));
157
158         // Now, form all possible products of the terms of the sums with the
159         // remaining factors, and add them together
160         exvector distrseq;
161         distrseq.reserve(number_of_expanded_terms);
162
163         intvector k(number_of_adds);
164
165         while (true) {
166                 exvector term = expanded_seq;
167                 for (size_t i=0; i<number_of_adds; i++)
168                         term[positions_of_adds[i]] = expanded_seq[positions_of_adds[i]].op(k[i]);
169                 distrseq.push_back((new ncmul(term, true))->
170                                     setflag(status_flags::dynallocated | (options == 0 ? status_flags::expanded : 0)));
171
172                 // increment k[]
173                 int l = number_of_adds-1;
174                 while ((l>=0) && ((++k[l]) >= number_of_add_operands[l])) {
175                         k[l] = 0;
176                         l--;
177                 }
178                 if (l<0)
179                         break;
180         }
181
182         return (new add(distrseq))->
183                 setflag(status_flags::dynallocated | (options == 0 ? status_flags::expanded : 0));
184 }
185
186 int ncmul::degree(const ex & s) const
187 {
188         // Sum up degrees of factors
189         int deg_sum = 0;
190         exvector::const_iterator i = seq.begin(), end = seq.end();
191         while (i != end) {
192                 deg_sum += i->degree(s);
193                 ++i;
194         }
195         return deg_sum;
196 }
197
198 int ncmul::ldegree(const ex & s) const
199 {
200         // Sum up degrees of factors
201         int deg_sum = 0;
202         exvector::const_iterator i = seq.begin(), end = seq.end();
203         while (i != end) {
204                 deg_sum += i->degree(s);
205                 ++i;
206         }
207         return deg_sum;
208 }
209
210 ex ncmul::coeff(const ex & s, int n) const
211 {
212         exvector coeffseq;
213         coeffseq.reserve(seq.size());
214
215         if (n == 0) {
216                 // product of individual coeffs
217                 // if a non-zero power of s is found, the resulting product will be 0
218                 exvector::const_iterator it=seq.begin();
219                 while (it!=seq.end()) {
220                         coeffseq.push_back((*it).coeff(s,n));
221                         ++it;
222                 }
223                 return (new ncmul(coeffseq,1))->setflag(status_flags::dynallocated);
224         }
225                  
226         exvector::const_iterator i = seq.begin(), end = seq.end();
227         bool coeff_found = false;
228         while (i != end) {
229                 ex c = i->coeff(s,n);
230                 if (c.is_zero()) {
231                         coeffseq.push_back(*i);
232                 } else {
233                         coeffseq.push_back(c);
234                         coeff_found = true;
235                 }
236                 ++i;
237         }
238
239         if (coeff_found) return (new ncmul(coeffseq,1))->setflag(status_flags::dynallocated);
240         
241         return _ex0;
242 }
243
244 size_t ncmul::count_factors(const ex & e) const
245 {
246         if ((is_exactly_a<mul>(e)&&(e.return_type()!=return_types::commutative))||
247                 (is_exactly_a<ncmul>(e))) {
248                 size_t factors=0;
249                 for (size_t i=0; i<e.nops(); i++)
250                         factors += count_factors(e.op(i));
251                 
252                 return factors;
253         }
254         return 1;
255 }
256                 
257 void ncmul::append_factors(exvector & v, const ex & e) const
258 {
259         if ((is_exactly_a<mul>(e)&&(e.return_type()!=return_types::commutative))||
260                 (is_exactly_a<ncmul>(e))) {
261                 for (size_t i=0; i<e.nops(); i++)
262                         append_factors(v, e.op(i));
263         } else 
264                 v.push_back(e);
265 }
266
267 typedef std::vector<unsigned> unsignedvector;
268 typedef std::vector<exvector> exvectorvector;
269
270 /** Perform automatic term rewriting rules in this class.  In the following
271  *  x, x1, x2,... stand for a symbolic variables of type ex and c, c1, c2...
272  *  stand for such expressions that contain a plain number.
273  *  - ncmul(...,*(x1,x2),...,ncmul(x3,x4),...) -> ncmul(...,x1,x2,...,x3,x4,...)  (associativity)
274  *  - ncmul(x) -> x
275  *  - ncmul() -> 1
276  *  - ncmul(...,c1,...,c2,...) -> *(c1,c2,ncmul(...))  (pull out commutative elements)
277  *  - ncmul(x1,y1,x2,y2) -> *(ncmul(x1,x2),ncmul(y1,y2))  (collect elements of same type)
278  *  - ncmul(x1,x2,x3,...) -> x::eval_ncmul(x1,x2,x3,...)
279  *
280  *  @param level cut-off in recursive evaluation */
281 ex ncmul::eval(int level) const
282 {
283         // The following additional rule would be nice, but produces a recursion,
284         // which must be trapped by introducing a flag that the sub-ncmuls()
285         // are already evaluated (maybe later...)
286         //                  ncmul(x1,x2,...,X,y1,y2,...) ->
287         //                      ncmul(ncmul(x1,x2,...),X,ncmul(y1,y2,...)
288         //                      (X noncommutative_composite)
289
290         if ((level==1) && (flags & status_flags::evaluated)) {
291                 return *this;
292         }
293
294         exvector evaledseq=evalchildren(level);
295
296         // ncmul(...,*(x1,x2),...,ncmul(x3,x4),...) ->
297         //     ncmul(...,x1,x2,...,x3,x4,...)  (associativity)
298         size_t factors = 0;
299         exvector::const_iterator cit = evaledseq.begin(), citend = evaledseq.end();
300         while (cit != citend)
301                 factors += count_factors(*cit++);
302         
303         exvector assocseq;
304         assocseq.reserve(factors);
305         cit = evaledseq.begin();
306         while (cit != citend)
307                 append_factors(assocseq, *cit++);
308         
309         // ncmul(x) -> x
310         if (assocseq.size()==1) return *(seq.begin());
311
312         // ncmul() -> 1
313         if (assocseq.empty()) return _ex1;
314
315         // determine return types
316         unsignedvector rettypes;
317         rettypes.reserve(assocseq.size());
318         size_t i = 0;
319         size_t count_commutative=0;
320         size_t count_noncommutative=0;
321         size_t count_noncommutative_composite=0;
322         cit = assocseq.begin(); citend = assocseq.end();
323         while (cit != citend) {
324                 switch (rettypes[i] = cit->return_type()) {
325                 case return_types::commutative:
326                         count_commutative++;
327                         break;
328                 case return_types::noncommutative:
329                         count_noncommutative++;
330                         break;
331                 case return_types::noncommutative_composite:
332                         count_noncommutative_composite++;
333                         break;
334                 default:
335                         throw(std::logic_error("ncmul::eval(): invalid return type"));
336                 }
337                 ++i; ++cit;
338         }
339         GINAC_ASSERT(count_commutative+count_noncommutative+count_noncommutative_composite==assocseq.size());
340
341         // ncmul(...,c1,...,c2,...) ->
342         //     *(c1,c2,ncmul(...)) (pull out commutative elements)
343         if (count_commutative!=0) {
344                 exvector commutativeseq;
345                 commutativeseq.reserve(count_commutative+1);
346                 exvector noncommutativeseq;
347                 noncommutativeseq.reserve(assocseq.size()-count_commutative);
348                 size_t num = assocseq.size();
349                 for (size_t i=0; i<num; ++i) {
350                         if (rettypes[i]==return_types::commutative)
351                                 commutativeseq.push_back(assocseq[i]);
352                         else
353                                 noncommutativeseq.push_back(assocseq[i]);
354                 }
355                 commutativeseq.push_back((new ncmul(noncommutativeseq,1))->setflag(status_flags::dynallocated));
356                 return (new mul(commutativeseq))->setflag(status_flags::dynallocated);
357         }
358                 
359         // ncmul(x1,y1,x2,y2) -> *(ncmul(x1,x2),ncmul(y1,y2))
360         //     (collect elements of same type)
361
362         if (count_noncommutative_composite==0) {
363                 // there are neither commutative nor noncommutative_composite
364                 // elements in assocseq
365                 GINAC_ASSERT(count_commutative==0);
366
367                 size_t assoc_num = assocseq.size();
368                 exvectorvector evv;
369                 unsignedvector rttinfos;
370                 evv.reserve(assoc_num);
371                 rttinfos.reserve(assoc_num);
372
373                 cit = assocseq.begin(), citend = assocseq.end();
374                 while (cit != citend) {
375                         unsigned ti = cit->return_type_tinfo();
376                         size_t rtt_num = rttinfos.size();
377                         // search type in vector of known types
378                         for (i=0; i<rtt_num; ++i) {
379                                 if (ti == rttinfos[i]) {
380                                         evv[i].push_back(*cit);
381                                         break;
382                                 }
383                         }
384                         if (i >= rtt_num) {
385                                 // new type
386                                 rttinfos.push_back(ti);
387                                 evv.push_back(exvector());
388                                 (evv.end()-1)->reserve(assoc_num);
389                                 (evv.end()-1)->push_back(*cit);
390                         }
391                         ++cit;
392                 }
393
394                 size_t evv_num = evv.size();
395 #ifdef DO_GINAC_ASSERT
396                 GINAC_ASSERT(evv_num == rttinfos.size());
397                 GINAC_ASSERT(evv_num > 0);
398                 size_t s=0;
399                 for (i=0; i<evv_num; ++i)
400                         s += evv[i].size();
401                 GINAC_ASSERT(s == assoc_num);
402 #endif // def DO_GINAC_ASSERT
403                 
404                 // if all elements are of same type, simplify the string
405                 if (evv_num == 1)
406                         return evv[0][0].eval_ncmul(evv[0]);
407                 
408                 exvector splitseq;
409                 splitseq.reserve(evv_num);
410                 for (i=0; i<evv_num; ++i)
411                         splitseq.push_back((new ncmul(evv[i]))->setflag(status_flags::dynallocated));
412                 
413                 return (new mul(splitseq))->setflag(status_flags::dynallocated);
414         }
415         
416         return (new ncmul(assocseq))->setflag(status_flags::dynallocated |
417                                                                                   status_flags::evaluated);
418 }
419
420 ex ncmul::evalm() const
421 {
422         // Evaluate children first
423         std::auto_ptr<exvector> s(new exvector);
424         s->reserve(seq.size());
425         exvector::const_iterator it = seq.begin(), itend = seq.end();
426         while (it != itend) {
427                 s->push_back(it->evalm());
428                 it++;
429         }
430
431         // If there are only matrices, simply multiply them
432         it = s->begin(); itend = s->end();
433         if (is_a<matrix>(*it)) {
434                 matrix prod(ex_to<matrix>(*it));
435                 it++;
436                 while (it != itend) {
437                         if (!is_a<matrix>(*it))
438                                 goto no_matrix;
439                         prod = prod.mul(ex_to<matrix>(*it));
440                         it++;
441                 }
442                 return prod;
443         }
444
445 no_matrix:
446         return (new ncmul(s))->setflag(status_flags::dynallocated);
447 }
448
449 ex ncmul::thiscontainer(const exvector & v) const
450 {
451         return (new ncmul(v))->setflag(status_flags::dynallocated);
452 }
453
454 ex ncmul::thiscontainer(std::auto_ptr<exvector> vp) const
455 {
456         return (new ncmul(vp))->setflag(status_flags::dynallocated);
457 }
458
459 // protected
460
461 /** Implementation of ex::diff() for a non-commutative product. It applies
462  *  the product rule.
463  *  @see ex::diff */
464 ex ncmul::derivative(const symbol & s) const
465 {
466         size_t num = seq.size();
467         exvector addseq;
468         addseq.reserve(num);
469         
470         // D(a*b*c) = D(a)*b*c + a*D(b)*c + a*b*D(c)
471         exvector ncmulseq = seq;
472         for (size_t i=0; i<num; ++i) {
473                 ex e = seq[i].diff(s);
474                 e.swap(ncmulseq[i]);
475                 addseq.push_back((new ncmul(ncmulseq))->setflag(status_flags::dynallocated));
476                 e.swap(ncmulseq[i]);
477         }
478         return (new add(addseq))->setflag(status_flags::dynallocated);
479 }
480
481 int ncmul::compare_same_type(const basic & other) const
482 {
483         return inherited::compare_same_type(other);
484 }
485
486 unsigned ncmul::return_type() const
487 {
488         if (seq.empty())
489                 return return_types::commutative;
490
491         bool all_commutative = true;
492         exvector::const_iterator noncommutative_element; // point to first found nc element
493
494         exvector::const_iterator i = seq.begin(), end = seq.end();
495         while (i != end) {
496                 unsigned rt = i->return_type();
497                 if (rt == return_types::noncommutative_composite)
498                         return rt; // one ncc -> mul also ncc
499                 if ((rt == return_types::noncommutative) && (all_commutative)) {
500                         // first nc element found, remember position
501                         noncommutative_element = i;
502                         all_commutative = false;
503                 }
504                 if ((rt == return_types::noncommutative) && (!all_commutative)) {
505                         // another nc element found, compare type_infos
506                         if (noncommutative_element->return_type_tinfo() != i->return_type_tinfo()) {
507                                 // diffent types -> mul is ncc
508                                 return return_types::noncommutative_composite;
509                         }
510                 }
511                 ++i;
512         }
513         // all factors checked
514         GINAC_ASSERT(!all_commutative); // not all factors should commute, because this is a ncmul();
515         return all_commutative ? return_types::commutative : return_types::noncommutative;
516 }
517    
518 unsigned ncmul::return_type_tinfo() const
519 {
520         if (seq.empty())
521                 return tinfo_key;
522
523         // return type_info of first noncommutative element
524         exvector::const_iterator i = seq.begin(), end = seq.end();
525         while (i != end) {
526                 if (i->return_type() == return_types::noncommutative)
527                         return i->return_type_tinfo();
528                 ++i;
529         }
530
531         // no noncommutative element found, should not happen
532         return tinfo_key;
533 }
534
535 //////////
536 // new virtual functions which can be overridden by derived classes
537 //////////
538
539 // none
540
541 //////////
542 // non-virtual functions in this class
543 //////////
544
545 exvector ncmul::expandchildren(unsigned options) const
546 {
547         exvector s;
548         s.reserve(seq.size());
549         exvector::const_iterator it = seq.begin(), itend = seq.end();
550         while (it != itend) {
551                 s.push_back(it->expand(options));
552                 it++;
553         }
554         return s;
555 }
556
557 const exvector & ncmul::get_factors() const
558 {
559         return seq;
560 }
561
562 //////////
563 // friend functions
564 //////////
565
566 ex reeval_ncmul(const exvector & v)
567 {
568         return (new ncmul(v))->setflag(status_flags::dynallocated);
569 }
570
571 ex hold_ncmul(const exvector & v)
572 {
573         if (v.empty())
574                 return _ex1;
575         else if (v.size() == 1)
576                 return v[0];
577         else
578                 return (new ncmul(v))->setflag(status_flags::dynallocated |
579                                                status_flags::evaluated);
580 }
581
582 } // namespace GiNaC