A better return_type_tinfo() mechanism.
[ginac.git] / ginac / add.cpp
1 /** @file add.cpp
2  *
3  *  Implementation of GiNaC's sums of expressions. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2008 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
21  */
22
23 #include <iostream>
24 #include <stdexcept>
25 #include <limits>
26 #include <string>
27
28 #include "add.h"
29 #include "mul.h"
30 #include "archive.h"
31 #include "operators.h"
32 #include "matrix.h"
33 #include "utils.h"
34 #include "clifford.h"
35 #include "ncmul.h"
36
37 namespace GiNaC {
38
39 GINAC_IMPLEMENT_REGISTERED_CLASS_OPT(add, expairseq,
40   print_func<print_context>(&add::do_print).
41   print_func<print_latex>(&add::do_print_latex).
42   print_func<print_csrc>(&add::do_print_csrc).
43   print_func<print_tree>(&add::do_print_tree).
44   print_func<print_python_repr>(&add::do_print_python_repr))
45
46 //////////
47 // default constructor
48 //////////
49
50 add::add()
51 {
52         tinfo_key = &add::tinfo_static;
53 }
54
55 //////////
56 // other constructors
57 //////////
58
59 // public
60
61 add::add(const ex & lh, const ex & rh)
62 {
63         tinfo_key = &add::tinfo_static;
64         overall_coeff = _ex0;
65         construct_from_2_ex(lh,rh);
66         GINAC_ASSERT(is_canonical());
67 }
68
69 add::add(const exvector & v)
70 {
71         tinfo_key = &add::tinfo_static;
72         overall_coeff = _ex0;
73         construct_from_exvector(v);
74         GINAC_ASSERT(is_canonical());
75 }
76
77 add::add(const epvector & v)
78 {
79         tinfo_key = &add::tinfo_static;
80         overall_coeff = _ex0;
81         construct_from_epvector(v);
82         GINAC_ASSERT(is_canonical());
83 }
84
85 add::add(const epvector & v, const ex & oc)
86 {
87         tinfo_key = &add::tinfo_static;
88         overall_coeff = oc;
89         construct_from_epvector(v);
90         GINAC_ASSERT(is_canonical());
91 }
92
93 add::add(std::auto_ptr<epvector> vp, const ex & oc)
94 {
95         tinfo_key = &add::tinfo_static;
96         GINAC_ASSERT(vp.get()!=0);
97         overall_coeff = oc;
98         construct_from_epvector(*vp);
99         GINAC_ASSERT(is_canonical());
100 }
101
102 //////////
103 // archiving
104 //////////
105
106 DEFAULT_ARCHIVING(add)
107
108 //////////
109 // functions overriding virtual functions from base classes
110 //////////
111
112 // public
113
114 void add::print_add(const print_context & c, const char *openbrace, const char *closebrace, const char *mul_sym, unsigned level) const
115 {
116         if (precedence() <= level)
117                 c.s << openbrace << '(';
118
119         numeric coeff;
120         bool first = true;
121
122         // First print the overall numeric coefficient, if present
123         if (!overall_coeff.is_zero()) {
124                 overall_coeff.print(c, 0);
125                 first = false;
126         }
127
128         // Then proceed with the remaining factors
129         epvector::const_iterator it = seq.begin(), itend = seq.end();
130         while (it != itend) {
131                 coeff = ex_to<numeric>(it->coeff);
132                 if (!first) {
133                         if (coeff.csgn() == -1) c.s << '-'; else c.s << '+';
134                 } else {
135                         if (coeff.csgn() == -1) c.s << '-';
136                         first = false;
137                 }
138                 if (!coeff.is_equal(*_num1_p) &&
139                     !coeff.is_equal(*_num_1_p)) {
140                         if (coeff.is_rational()) {
141                                 if (coeff.is_negative())
142                                         (-coeff).print(c);
143                                 else
144                                         coeff.print(c);
145                         } else {
146                                 if (coeff.csgn() == -1)
147                                         (-coeff).print(c, precedence());
148                                 else
149                                         coeff.print(c, precedence());
150                         }
151                         c.s << mul_sym;
152                 }
153                 it->rest.print(c, precedence());
154                 ++it;
155         }
156
157         if (precedence() <= level)
158                 c.s << ')' << closebrace;
159 }
160
161 void add::do_print(const print_context & c, unsigned level) const
162 {
163         print_add(c, "", "", "*", level);
164 }
165
166 void add::do_print_latex(const print_latex & c, unsigned level) const
167 {
168         print_add(c, "{", "}", " ", level);
169 }
170
171 void add::do_print_csrc(const print_csrc & c, unsigned level) const
172 {
173         if (precedence() <= level)
174                 c.s << "(";
175         
176         // Print arguments, separated by "+" or "-"
177         epvector::const_iterator it = seq.begin(), itend = seq.end();
178         char separator = ' ';
179         while (it != itend) {
180                 
181                 // If the coefficient is negative, separator is "-"
182                 if (it->coeff.is_equal(_ex_1) || 
183                         ex_to<numeric>(it->coeff).numer().is_equal(*_num_1_p))
184                         separator = '-';
185                 c.s << separator;
186                 if (it->coeff.is_equal(_ex1) || it->coeff.is_equal(_ex_1)) {
187                         it->rest.print(c, precedence());
188                 } else if (ex_to<numeric>(it->coeff).numer().is_equal(*_num1_p) ||
189                                  ex_to<numeric>(it->coeff).numer().is_equal(*_num_1_p))
190                 {
191                         it->rest.print(c, precedence());
192                         c.s << '/';
193                         ex_to<numeric>(it->coeff).denom().print(c, precedence());
194                 } else {
195                         it->coeff.print(c, precedence());
196                         c.s << '*';
197                         it->rest.print(c, precedence());
198                 }
199                 
200                 ++it;
201                 separator = '+';
202         }
203         
204         if (!overall_coeff.is_zero()) {
205                 if (overall_coeff.info(info_flags::positive)
206                  || is_a<print_csrc_cl_N>(c) || !overall_coeff.info(info_flags::real))  // sign inside ctor argument
207                         c.s << '+';
208                 overall_coeff.print(c, precedence());
209         }
210                 
211         if (precedence() <= level)
212                 c.s << ")";
213 }
214
215 void add::do_print_python_repr(const print_python_repr & c, unsigned level) const
216 {
217         c.s << class_name() << '(';
218         op(0).print(c);
219         for (size_t i=1; i<nops(); ++i) {
220                 c.s << ',';
221                 op(i).print(c);
222         }
223         c.s << ')';
224 }
225
226 bool add::info(unsigned inf) const
227 {
228         switch (inf) {
229                 case info_flags::polynomial:
230                 case info_flags::integer_polynomial:
231                 case info_flags::cinteger_polynomial:
232                 case info_flags::rational_polynomial:
233                 case info_flags::crational_polynomial:
234                 case info_flags::rational_function: {
235                         epvector::const_iterator i = seq.begin(), end = seq.end();
236                         while (i != end) {
237                                 if (!(recombine_pair_to_ex(*i).info(inf)))
238                                         return false;
239                                 ++i;
240                         }
241                         return overall_coeff.info(inf);
242                 }
243                 case info_flags::algebraic: {
244                         epvector::const_iterator i = seq.begin(), end = seq.end();
245                         while (i != end) {
246                                 if ((recombine_pair_to_ex(*i).info(inf)))
247                                         return true;
248                                 ++i;
249                         }
250                         return false;
251                 }
252         }
253         return inherited::info(inf);
254 }
255
256 int add::degree(const ex & s) const
257 {
258         int deg = std::numeric_limits<int>::min();
259         if (!overall_coeff.is_zero())
260                 deg = 0;
261         
262         // Find maximum of degrees of individual terms
263         epvector::const_iterator i = seq.begin(), end = seq.end();
264         while (i != end) {
265                 int cur_deg = i->rest.degree(s);
266                 if (cur_deg > deg)
267                         deg = cur_deg;
268                 ++i;
269         }
270         return deg;
271 }
272
273 int add::ldegree(const ex & s) const
274 {
275         int deg = std::numeric_limits<int>::max();
276         if (!overall_coeff.is_zero())
277                 deg = 0;
278         
279         // Find minimum of degrees of individual terms
280         epvector::const_iterator i = seq.begin(), end = seq.end();
281         while (i != end) {
282                 int cur_deg = i->rest.ldegree(s);
283                 if (cur_deg < deg)
284                         deg = cur_deg;
285                 ++i;
286         }
287         return deg;
288 }
289
290 ex add::coeff(const ex & s, int n) const
291 {
292         std::auto_ptr<epvector> coeffseq(new epvector);
293         std::auto_ptr<epvector> coeffseq_cliff(new epvector);
294         char rl = clifford_max_label(s);
295         bool do_clifford = (rl != -1);
296         bool nonscalar = false;
297
298         // Calculate sum of coefficients in each term
299         epvector::const_iterator i = seq.begin(), end = seq.end();
300         while (i != end) {
301                 ex restcoeff = i->rest.coeff(s, n);
302                 if (!restcoeff.is_zero()) {
303                         if (do_clifford) {
304                                 if (clifford_max_label(restcoeff) == -1) {
305                                         coeffseq_cliff->push_back(combine_ex_with_coeff_to_pair(ncmul(restcoeff, dirac_ONE(rl)), i->coeff));
306                                 } else {
307                                         coeffseq_cliff->push_back(combine_ex_with_coeff_to_pair(restcoeff, i->coeff));
308                                         nonscalar = true;
309                                 }
310                         }
311                         coeffseq->push_back(combine_ex_with_coeff_to_pair(restcoeff, i->coeff));
312                 }
313                 ++i;
314         }
315
316         return (new add(nonscalar ? coeffseq_cliff : coeffseq,
317                         n==0 ? overall_coeff : _ex0))->setflag(status_flags::dynallocated);
318 }
319
320 /** Perform automatic term rewriting rules in this class.  In the following
321  *  x stands for a symbolic variables of type ex and c stands for such
322  *  an expression that contain a plain number.
323  *  - +(;c) -> c
324  *  - +(x;0) -> x
325  *
326  *  @param level cut-off in recursive evaluation */
327 ex add::eval(int level) const
328 {
329         std::auto_ptr<epvector> evaled_seqp = evalchildren(level);
330         if (evaled_seqp.get()) {
331                 // do more evaluation later
332                 return (new add(evaled_seqp, overall_coeff))->
333                        setflag(status_flags::dynallocated);
334         }
335         
336 #ifdef DO_GINAC_ASSERT
337         epvector::const_iterator i = seq.begin(), end = seq.end();
338         while (i != end) {
339                 GINAC_ASSERT(!is_exactly_a<add>(i->rest));
340                 if (is_exactly_a<numeric>(i->rest))
341                         dbgprint();
342                 GINAC_ASSERT(!is_exactly_a<numeric>(i->rest));
343                 ++i;
344         }
345 #endif // def DO_GINAC_ASSERT
346         
347         if (flags & status_flags::evaluated) {
348                 GINAC_ASSERT(seq.size()>0);
349                 GINAC_ASSERT(seq.size()>1 || !overall_coeff.is_zero());
350                 return *this;
351         }
352         
353         int seq_size = seq.size();
354         if (seq_size == 0) {
355                 // +(;c) -> c
356                 return overall_coeff;
357         } else if (seq_size == 1 && overall_coeff.is_zero()) {
358                 // +(x;0) -> x
359                 return recombine_pair_to_ex(*(seq.begin()));
360         } else if (!overall_coeff.is_zero() && seq[0].rest.return_type() != return_types::commutative) {
361                 throw (std::logic_error("add::eval(): sum of non-commutative objects has non-zero numeric term"));
362         }
363         return this->hold();
364 }
365
366 ex add::evalm() const
367 {
368         // Evaluate children first and add up all matrices. Stop if there's one
369         // term that is not a matrix.
370         std::auto_ptr<epvector> s(new epvector);
371         s->reserve(seq.size());
372
373         bool all_matrices = true;
374         bool first_term = true;
375         matrix sum;
376
377         epvector::const_iterator it = seq.begin(), itend = seq.end();
378         while (it != itend) {
379                 const ex &m = recombine_pair_to_ex(*it).evalm();
380                 s->push_back(split_ex_to_pair(m));
381                 if (is_a<matrix>(m)) {
382                         if (first_term) {
383                                 sum = ex_to<matrix>(m);
384                                 first_term = false;
385                         } else
386                                 sum = sum.add(ex_to<matrix>(m));
387                 } else
388                         all_matrices = false;
389                 ++it;
390         }
391
392         if (all_matrices)
393                 return sum + overall_coeff;
394         else
395                 return (new add(s, overall_coeff))->setflag(status_flags::dynallocated);
396 }
397
398 ex add::conjugate() const
399 {
400         exvector *v = 0;
401         for (size_t i=0; i<nops(); ++i) {
402                 if (v) {
403                         v->push_back(op(i).conjugate());
404                         continue;
405                 }
406                 ex term = op(i);
407                 ex ccterm = term.conjugate();
408                 if (are_ex_trivially_equal(term, ccterm))
409                         continue;
410                 v = new exvector;
411                 v->reserve(nops());
412                 for (size_t j=0; j<i; ++j)
413                         v->push_back(op(j));
414                 v->push_back(ccterm);
415         }
416         if (v) {
417                 ex result = add(*v);
418                 delete v;
419                 return result;
420         }
421         return *this;
422 }
423
424 ex add::real_part() const
425 {
426         epvector v;
427         v.reserve(seq.size());
428         for (epvector::const_iterator i=seq.begin(); i!=seq.end(); ++i)
429                 if ((i->coeff).info(info_flags::real)) {
430                         ex rp = (i->rest).real_part();
431                         if (!rp.is_zero())
432                                 v.push_back(expair(rp, i->coeff));
433                 } else {
434                         ex rp=recombine_pair_to_ex(*i).real_part();
435                         if (!rp.is_zero())
436                                 v.push_back(split_ex_to_pair(rp));
437                 }
438         return (new add(v, overall_coeff.real_part()))
439                 -> setflag(status_flags::dynallocated);
440 }
441
442 ex add::imag_part() const
443 {
444         epvector v;
445         v.reserve(seq.size());
446         for (epvector::const_iterator i=seq.begin(); i!=seq.end(); ++i)
447                 if ((i->coeff).info(info_flags::real)) {
448                         ex ip = (i->rest).imag_part();
449                         if (!ip.is_zero())
450                                 v.push_back(expair(ip, i->coeff));
451                 } else {
452                         ex ip=recombine_pair_to_ex(*i).imag_part();
453                         if (!ip.is_zero())
454                                 v.push_back(split_ex_to_pair(ip));
455                 }
456         return (new add(v, overall_coeff.imag_part()))
457                 -> setflag(status_flags::dynallocated);
458 }
459
460 ex add::eval_ncmul(const exvector & v) const
461 {
462         if (seq.empty())
463                 return inherited::eval_ncmul(v);
464         else
465                 return seq.begin()->rest.eval_ncmul(v);
466 }    
467
468 // protected
469
470 /** Implementation of ex::diff() for a sum. It differentiates each term.
471  *  @see ex::diff */
472 ex add::derivative(const symbol & y) const
473 {
474         std::auto_ptr<epvector> s(new epvector);
475         s->reserve(seq.size());
476         
477         // Only differentiate the "rest" parts of the expairs. This is faster
478         // than the default implementation in basic::derivative() although
479         // if performs the same function (differentiate each term).
480         epvector::const_iterator i = seq.begin(), end = seq.end();
481         while (i != end) {
482                 s->push_back(combine_ex_with_coeff_to_pair(i->rest.diff(y), i->coeff));
483                 ++i;
484         }
485         return (new add(s, _ex0))->setflag(status_flags::dynallocated);
486 }
487
488 int add::compare_same_type(const basic & other) const
489 {
490         return inherited::compare_same_type(other);
491 }
492
493 unsigned add::return_type() const
494 {
495         if (seq.empty())
496                 return return_types::commutative;
497         else
498                 return seq.begin()->rest.return_type();
499 }
500
501 return_type_t add::return_type_tinfo() const
502 {
503         if (seq.empty())
504                 return make_return_type_t<add>();
505         else
506                 return seq.begin()->rest.return_type_tinfo();
507 }
508
509 // Note: do_index_renaming is ignored because it makes no sense for an add.
510 ex add::thisexpairseq(const epvector & v, const ex & oc, bool do_index_renaming) const
511 {
512         return (new add(v,oc))->setflag(status_flags::dynallocated);
513 }
514
515 // Note: do_index_renaming is ignored because it makes no sense for an add.
516 ex add::thisexpairseq(std::auto_ptr<epvector> vp, const ex & oc, bool do_index_renaming) const
517 {
518         return (new add(vp,oc))->setflag(status_flags::dynallocated);
519 }
520
521 expair add::split_ex_to_pair(const ex & e) const
522 {
523         if (is_exactly_a<mul>(e)) {
524                 const mul &mulref(ex_to<mul>(e));
525                 const ex &numfactor = mulref.overall_coeff;
526                 mul *mulcopyp = new mul(mulref);
527                 mulcopyp->overall_coeff = _ex1;
528                 mulcopyp->clearflag(status_flags::evaluated);
529                 mulcopyp->clearflag(status_flags::hash_calculated);
530                 mulcopyp->setflag(status_flags::dynallocated);
531                 return expair(*mulcopyp,numfactor);
532         }
533         return expair(e,_ex1);
534 }
535
536 expair add::combine_ex_with_coeff_to_pair(const ex & e,
537                                                                                   const ex & c) const
538 {
539         GINAC_ASSERT(is_exactly_a<numeric>(c));
540         if (is_exactly_a<mul>(e)) {
541                 const mul &mulref(ex_to<mul>(e));
542                 const ex &numfactor = mulref.overall_coeff;
543                 mul *mulcopyp = new mul(mulref);
544                 mulcopyp->overall_coeff = _ex1;
545                 mulcopyp->clearflag(status_flags::evaluated);
546                 mulcopyp->clearflag(status_flags::hash_calculated);
547                 mulcopyp->setflag(status_flags::dynallocated);
548                 if (c.is_equal(_ex1))
549                         return expair(*mulcopyp, numfactor);
550                 else if (numfactor.is_equal(_ex1))
551                         return expair(*mulcopyp, c);
552                 else
553                         return expair(*mulcopyp, ex_to<numeric>(numfactor).mul_dyn(ex_to<numeric>(c)));
554         } else if (is_exactly_a<numeric>(e)) {
555                 if (c.is_equal(_ex1))
556                         return expair(e, _ex1);
557                 return expair(ex_to<numeric>(e).mul_dyn(ex_to<numeric>(c)), _ex1);
558         }
559         return expair(e, c);
560 }
561
562 expair add::combine_pair_with_coeff_to_pair(const expair & p,
563                                                                                         const ex & c) const
564 {
565         GINAC_ASSERT(is_exactly_a<numeric>(p.coeff));
566         GINAC_ASSERT(is_exactly_a<numeric>(c));
567
568         if (is_exactly_a<numeric>(p.rest)) {
569                 GINAC_ASSERT(ex_to<numeric>(p.coeff).is_equal(*_num1_p)); // should be normalized
570                 return expair(ex_to<numeric>(p.rest).mul_dyn(ex_to<numeric>(c)),_ex1);
571         }
572
573         return expair(p.rest,ex_to<numeric>(p.coeff).mul_dyn(ex_to<numeric>(c)));
574 }
575         
576 ex add::recombine_pair_to_ex(const expair & p) const
577 {
578         if (ex_to<numeric>(p.coeff).is_equal(*_num1_p))
579                 return p.rest;
580         else
581                 return (new mul(p.rest,p.coeff))->setflag(status_flags::dynallocated);
582 }
583
584 ex add::expand(unsigned options) const
585 {
586         std::auto_ptr<epvector> vp = expandchildren(options);
587         if (vp.get() == 0) {
588                 // the terms have not changed, so it is safe to declare this expanded
589                 return (options == 0) ? setflag(status_flags::expanded) : *this;
590         }
591
592         return (new add(vp, overall_coeff))->setflag(status_flags::dynallocated | (options == 0 ? status_flags::expanded : 0));
593 }
594
595 } // namespace GiNaC