Make sure add::eval() collects all numeric terms.
[ginac.git] / ginac / add.cpp
1 /** @file add.cpp
2  *
3  *  Implementation of GiNaC's sums of expressions. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2010 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
21  */
22
23 #include "add.h"
24 #include "mul.h"
25 #include "archive.h"
26 #include "operators.h"
27 #include "matrix.h"
28 #include "utils.h"
29 #include "clifford.h"
30 #include "ncmul.h"
31 #include "compiler.h"
32
33 #include <iostream>
34 #include <limits>
35 #include <stdexcept>
36 #include <string>
37
38 namespace GiNaC {
39
// Registered-class boilerplate for add, whose base class is expairseq.
// The chained print_func<> entries select which member function prints an
// add in each print context.  Note: do_print_tree is not defined in this
// file (presumably inherited from expairseq -- verify).
GINAC_IMPLEMENT_REGISTERED_CLASS_OPT(add, expairseq,
  print_func<print_context>(&add::do_print).
  print_func<print_latex>(&add::do_print_latex).
  print_func<print_csrc>(&add::do_print_csrc).
  print_func<print_tree>(&add::do_print_tree).
  print_func<print_python_repr>(&add::do_print_python_repr))
46
47 //////////
48 // default constructor
49 //////////
50
51 add::add()
52 {
53 }
54
55 //////////
56 // other constructors
57 //////////
58
59 // public
60
61 add::add(const ex & lh, const ex & rh)
62 {
63         overall_coeff = _ex0;
64         construct_from_2_ex(lh,rh);
65         GINAC_ASSERT(is_canonical());
66 }
67
68 add::add(const exvector & v)
69 {
70         overall_coeff = _ex0;
71         construct_from_exvector(v);
72         GINAC_ASSERT(is_canonical());
73 }
74
75 add::add(const epvector & v)
76 {
77         overall_coeff = _ex0;
78         construct_from_epvector(v);
79         GINAC_ASSERT(is_canonical());
80 }
81
82 add::add(const epvector & v, const ex & oc)
83 {
84         overall_coeff = oc;
85         construct_from_epvector(v);
86         GINAC_ASSERT(is_canonical());
87 }
88
89 add::add(std::auto_ptr<epvector> vp, const ex & oc)
90 {
91         GINAC_ASSERT(vp.get()!=0);
92         overall_coeff = oc;
93         construct_from_epvector(*vp);
94         GINAC_ASSERT(is_canonical());
95 }
96
97 //////////
98 // archiving
99 //////////
100
// Bind add to the archive machinery (macro defined elsewhere; presumably
// allows archived add objects to be reconstructed -- verify in archive.h).
GINAC_BIND_UNARCHIVER(add);
102
103 //////////
104 // functions overriding virtual functions from base classes
105 //////////
106
107 // public
108
109 void add::print_add(const print_context & c, const char *openbrace, const char *closebrace, const char *mul_sym, unsigned level) const
110 {
111         if (precedence() <= level)
112                 c.s << openbrace << '(';
113
114         numeric coeff;
115         bool first = true;
116
117         // First print the overall numeric coefficient, if present
118         if (!overall_coeff.is_zero()) {
119                 overall_coeff.print(c, 0);
120                 first = false;
121         }
122
123         // Then proceed with the remaining factors
124         epvector::const_iterator it = seq.begin(), itend = seq.end();
125         while (it != itend) {
126                 coeff = ex_to<numeric>(it->coeff);
127                 if (!first) {
128                         if (coeff.csgn() == -1) c.s << '-'; else c.s << '+';
129                 } else {
130                         if (coeff.csgn() == -1) c.s << '-';
131                         first = false;
132                 }
133                 if (!coeff.is_equal(*_num1_p) &&
134                     !coeff.is_equal(*_num_1_p)) {
135                         if (coeff.is_rational()) {
136                                 if (coeff.is_negative())
137                                         (-coeff).print(c);
138                                 else
139                                         coeff.print(c);
140                         } else {
141                                 if (coeff.csgn() == -1)
142                                         (-coeff).print(c, precedence());
143                                 else
144                                         coeff.print(c, precedence());
145                         }
146                         c.s << mul_sym;
147                 }
148                 it->rest.print(c, precedence());
149                 ++it;
150         }
151
152         if (precedence() <= level)
153                 c.s << ')' << closebrace;
154 }
155
156 void add::do_print(const print_context & c, unsigned level) const
157 {
158         print_add(c, "", "", "*", level);
159 }
160
161 void add::do_print_latex(const print_latex & c, unsigned level) const
162 {
163         print_add(c, "{", "}", " ", level);
164 }
165
166 void add::do_print_csrc(const print_csrc & c, unsigned level) const
167 {
168         if (precedence() <= level)
169                 c.s << "(";
170         
171         // Print arguments, separated by "+" or "-"
172         epvector::const_iterator it = seq.begin(), itend = seq.end();
173         char separator = ' ';
174         while (it != itend) {
175                 
176                 // If the coefficient is negative, separator is "-"
177                 if (it->coeff.is_equal(_ex_1) || 
178                         ex_to<numeric>(it->coeff).numer().is_equal(*_num_1_p))
179                         separator = '-';
180                 c.s << separator;
181                 if (it->coeff.is_equal(_ex1) || it->coeff.is_equal(_ex_1)) {
182                         it->rest.print(c, precedence());
183                 } else if (ex_to<numeric>(it->coeff).numer().is_equal(*_num1_p) ||
184                                  ex_to<numeric>(it->coeff).numer().is_equal(*_num_1_p))
185                 {
186                         it->rest.print(c, precedence());
187                         c.s << '/';
188                         ex_to<numeric>(it->coeff).denom().print(c, precedence());
189                 } else {
190                         it->coeff.print(c, precedence());
191                         c.s << '*';
192                         it->rest.print(c, precedence());
193                 }
194                 
195                 ++it;
196                 separator = '+';
197         }
198         
199         if (!overall_coeff.is_zero()) {
200                 if (overall_coeff.info(info_flags::positive)
201                  || is_a<print_csrc_cl_N>(c) || !overall_coeff.info(info_flags::real))  // sign inside ctor argument
202                         c.s << '+';
203                 overall_coeff.print(c, precedence());
204         }
205                 
206         if (precedence() <= level)
207                 c.s << ")";
208 }
209
210 void add::do_print_python_repr(const print_python_repr & c, unsigned level) const
211 {
212         c.s << class_name() << '(';
213         op(0).print(c);
214         for (size_t i=1; i<nops(); ++i) {
215                 c.s << ',';
216                 op(i).print(c);
217         }
218         c.s << ')';
219 }
220
221 bool add::info(unsigned inf) const
222 {
223         switch (inf) {
224                 case info_flags::polynomial:
225                 case info_flags::integer_polynomial:
226                 case info_flags::cinteger_polynomial:
227                 case info_flags::rational_polynomial:
228                 case info_flags::real:
229                 case info_flags::rational:
230                 case info_flags::integer:
231                 case info_flags::crational:
232                 case info_flags::cinteger:
233                 case info_flags::positive:
234                 case info_flags::nonnegative:
235                 case info_flags::posint:
236                 case info_flags::nonnegint:
237                 case info_flags::even:
238                 case info_flags::crational_polynomial:
239                 case info_flags::rational_function: {
240                         epvector::const_iterator i = seq.begin(), end = seq.end();
241                         while (i != end) {
242                                 if (!(recombine_pair_to_ex(*i).info(inf)))
243                                         return false;
244                                 ++i;
245                         }
246                         if (overall_coeff.is_zero() && (inf == info_flags::positive || inf == info_flags::posint))
247                                 return true;
248                         return overall_coeff.info(inf);
249                 }
250                 case info_flags::algebraic: {
251                         epvector::const_iterator i = seq.begin(), end = seq.end();
252                         while (i != end) {
253                                 if ((recombine_pair_to_ex(*i).info(inf)))
254                                         return true;
255                                 ++i;
256                         }
257                         return false;
258                 }
259         }
260         return inherited::info(inf);
261 }
262
263 int add::degree(const ex & s) const
264 {
265         int deg = std::numeric_limits<int>::min();
266         if (!overall_coeff.is_zero())
267                 deg = 0;
268         
269         // Find maximum of degrees of individual terms
270         epvector::const_iterator i = seq.begin(), end = seq.end();
271         while (i != end) {
272                 int cur_deg = i->rest.degree(s);
273                 if (cur_deg > deg)
274                         deg = cur_deg;
275                 ++i;
276         }
277         return deg;
278 }
279
280 int add::ldegree(const ex & s) const
281 {
282         int deg = std::numeric_limits<int>::max();
283         if (!overall_coeff.is_zero())
284                 deg = 0;
285         
286         // Find minimum of degrees of individual terms
287         epvector::const_iterator i = seq.begin(), end = seq.end();
288         while (i != end) {
289                 int cur_deg = i->rest.ldegree(s);
290                 if (cur_deg < deg)
291                         deg = cur_deg;
292                 ++i;
293         }
294         return deg;
295 }
296
297 ex add::coeff(const ex & s, int n) const
298 {
299         std::auto_ptr<epvector> coeffseq(new epvector);
300         std::auto_ptr<epvector> coeffseq_cliff(new epvector);
301         char rl = clifford_max_label(s);
302         bool do_clifford = (rl != -1);
303         bool nonscalar = false;
304
305         // Calculate sum of coefficients in each term
306         epvector::const_iterator i = seq.begin(), end = seq.end();
307         while (i != end) {
308                 ex restcoeff = i->rest.coeff(s, n);
309                 if (!restcoeff.is_zero()) {
310                         if (do_clifford) {
311                                 if (clifford_max_label(restcoeff) == -1) {
312                                         coeffseq_cliff->push_back(combine_ex_with_coeff_to_pair(ncmul(restcoeff, dirac_ONE(rl)), i->coeff));
313                                 } else {
314                                         coeffseq_cliff->push_back(combine_ex_with_coeff_to_pair(restcoeff, i->coeff));
315                                         nonscalar = true;
316                                 }
317                         }
318                         coeffseq->push_back(combine_ex_with_coeff_to_pair(restcoeff, i->coeff));
319                 }
320                 ++i;
321         }
322
323         return (new add(nonscalar ? coeffseq_cliff : coeffseq,
324                         n==0 ? overall_coeff : _ex0))->setflag(status_flags::dynallocated);
325 }
326
327 /** Perform automatic term rewriting rules in this class.  In the following
328  *  x stands for a symbolic variables of type ex and c stands for such
329  *  an expression that contain a plain number.
330  *  - +(;c) -> c
331  *  - +(x;0) -> x
332  *
333  *  @param level cut-off in recursive evaluation */
334 ex add::eval(int level) const
335 {
336         std::auto_ptr<epvector> evaled_seqp = evalchildren(level);
337         if (evaled_seqp.get()) {
338                 // do more evaluation later
339                 return (new add(evaled_seqp, overall_coeff))->
340                        setflag(status_flags::dynallocated);
341         }
342         
343 #ifdef DO_GINAC_ASSERT
344         epvector::const_iterator i = seq.begin(), end = seq.end();
345         while (i != end) {
346                 GINAC_ASSERT(!is_exactly_a<add>(i->rest));
347                 ++i;
348         }
349 #endif // def DO_GINAC_ASSERT
350         
351         if (flags & status_flags::evaluated) {
352                 GINAC_ASSERT(seq.size()>0);
353                 GINAC_ASSERT(seq.size()>1 || !overall_coeff.is_zero());
354                 return *this;
355         }
356         
357         int seq_size = seq.size();
358         if (seq_size == 0) {
359                 // +(;c) -> c
360                 return overall_coeff;
361         } else if (seq_size == 1 && overall_coeff.is_zero()) {
362                 // +(x;0) -> x
363                 return recombine_pair_to_ex(*(seq.begin()));
364         } else if (!overall_coeff.is_zero() && seq[0].rest.return_type() != return_types::commutative) {
365                 throw (std::logic_error("add::eval(): sum of non-commutative objects has non-zero numeric term"));
366         }
367         
368         // if any terms in the sum still are purely numeric, then they are more
369         // appropriately collected into the overall coefficient
370         epvector::const_iterator last = seq.end();
371         epvector::const_iterator j = seq.begin();
372         int terms_to_collect = 0;
373         while (j != last) {
374                 if (unlikely(is_a<numeric>(j->rest)))
375                         ++terms_to_collect;
376                 ++j;
377         }
378         if (terms_to_collect) {
379                 std::auto_ptr<epvector> s(new epvector);
380                 s->reserve(seq_size - terms_to_collect);
381                 numeric oc = *_num1_p;
382                 j = seq.begin();
383                 while (j != last) {
384                         if (unlikely(is_a<numeric>(j->rest)))
385                                 oc = oc.mul(ex_to<numeric>(j->rest)).mul(ex_to<numeric>(j->coeff));
386                         else
387                                 s->push_back(*j);
388                         ++j;
389                 }
390                 return (new add(s, ex_to<numeric>(overall_coeff).add_dyn(oc)))
391                         ->setflag(status_flags::dynallocated);
392         }
393         
394         return this->hold();
395 }
396
397 ex add::evalm() const
398 {
399         // Evaluate children first and add up all matrices. Stop if there's one
400         // term that is not a matrix.
401         std::auto_ptr<epvector> s(new epvector);
402         s->reserve(seq.size());
403
404         bool all_matrices = true;
405         bool first_term = true;
406         matrix sum;
407
408         epvector::const_iterator it = seq.begin(), itend = seq.end();
409         while (it != itend) {
410                 const ex &m = recombine_pair_to_ex(*it).evalm();
411                 s->push_back(split_ex_to_pair(m));
412                 if (is_a<matrix>(m)) {
413                         if (first_term) {
414                                 sum = ex_to<matrix>(m);
415                                 first_term = false;
416                         } else
417                                 sum = sum.add(ex_to<matrix>(m));
418                 } else
419                         all_matrices = false;
420                 ++it;
421         }
422
423         if (all_matrices)
424                 return sum + overall_coeff;
425         else
426                 return (new add(s, overall_coeff))->setflag(status_flags::dynallocated);
427 }
428
429 ex add::conjugate() const
430 {
431         exvector *v = 0;
432         for (size_t i=0; i<nops(); ++i) {
433                 if (v) {
434                         v->push_back(op(i).conjugate());
435                         continue;
436                 }
437                 ex term = op(i);
438                 ex ccterm = term.conjugate();
439                 if (are_ex_trivially_equal(term, ccterm))
440                         continue;
441                 v = new exvector;
442                 v->reserve(nops());
443                 for (size_t j=0; j<i; ++j)
444                         v->push_back(op(j));
445                 v->push_back(ccterm);
446         }
447         if (v) {
448                 ex result = add(*v);
449                 delete v;
450                 return result;
451         }
452         return *this;
453 }
454
455 ex add::real_part() const
456 {
457         epvector v;
458         v.reserve(seq.size());
459         for (epvector::const_iterator i=seq.begin(); i!=seq.end(); ++i)
460                 if ((i->coeff).info(info_flags::real)) {
461                         ex rp = (i->rest).real_part();
462                         if (!rp.is_zero())
463                                 v.push_back(expair(rp, i->coeff));
464                 } else {
465                         ex rp=recombine_pair_to_ex(*i).real_part();
466                         if (!rp.is_zero())
467                                 v.push_back(split_ex_to_pair(rp));
468                 }
469         return (new add(v, overall_coeff.real_part()))
470                 -> setflag(status_flags::dynallocated);
471 }
472
473 ex add::imag_part() const
474 {
475         epvector v;
476         v.reserve(seq.size());
477         for (epvector::const_iterator i=seq.begin(); i!=seq.end(); ++i)
478                 if ((i->coeff).info(info_flags::real)) {
479                         ex ip = (i->rest).imag_part();
480                         if (!ip.is_zero())
481                                 v.push_back(expair(ip, i->coeff));
482                 } else {
483                         ex ip=recombine_pair_to_ex(*i).imag_part();
484                         if (!ip.is_zero())
485                                 v.push_back(split_ex_to_pair(ip));
486                 }
487         return (new add(v, overall_coeff.imag_part()))
488                 -> setflag(status_flags::dynallocated);
489 }
490
491 ex add::eval_ncmul(const exvector & v) const
492 {
493         if (seq.empty())
494                 return inherited::eval_ncmul(v);
495         else
496                 return seq.begin()->rest.eval_ncmul(v);
497 }    
498
499 // protected
500
501 /** Implementation of ex::diff() for a sum. It differentiates each term.
502  *  @see ex::diff */
503 ex add::derivative(const symbol & y) const
504 {
505         std::auto_ptr<epvector> s(new epvector);
506         s->reserve(seq.size());
507         
508         // Only differentiate the "rest" parts of the expairs. This is faster
509         // than the default implementation in basic::derivative() although
510         // if performs the same function (differentiate each term).
511         epvector::const_iterator i = seq.begin(), end = seq.end();
512         while (i != end) {
513                 s->push_back(combine_ex_with_coeff_to_pair(i->rest.diff(y), i->coeff));
514                 ++i;
515         }
516         return (new add(s, _ex0))->setflag(status_flags::dynallocated);
517 }
518
519 int add::compare_same_type(const basic & other) const
520 {
521         return inherited::compare_same_type(other);
522 }
523
524 unsigned add::return_type() const
525 {
526         if (seq.empty())
527                 return return_types::commutative;
528         else
529                 return seq.begin()->rest.return_type();
530 }
531
532 return_type_t add::return_type_tinfo() const
533 {
534         if (seq.empty())
535                 return make_return_type_t<add>();
536         else
537                 return seq.begin()->rest.return_type_tinfo();
538 }
539
540 // Note: do_index_renaming is ignored because it makes no sense for an add.
541 ex add::thisexpairseq(const epvector & v, const ex & oc, bool do_index_renaming) const
542 {
543         return (new add(v,oc))->setflag(status_flags::dynallocated);
544 }
545
546 // Note: do_index_renaming is ignored because it makes no sense for an add.
547 ex add::thisexpairseq(std::auto_ptr<epvector> vp, const ex & oc, bool do_index_renaming) const
548 {
549         return (new add(vp,oc))->setflag(status_flags::dynallocated);
550 }
551
552 expair add::split_ex_to_pair(const ex & e) const
553 {
554         if (is_exactly_a<mul>(e)) {
555                 const mul &mulref(ex_to<mul>(e));
556                 const ex &numfactor = mulref.overall_coeff;
557                 mul *mulcopyp = new mul(mulref);
558                 mulcopyp->overall_coeff = _ex1;
559                 mulcopyp->clearflag(status_flags::evaluated);
560                 mulcopyp->clearflag(status_flags::hash_calculated);
561                 mulcopyp->setflag(status_flags::dynallocated);
562                 return expair(*mulcopyp,numfactor);
563         }
564         return expair(e,_ex1);
565 }
566
567 expair add::combine_ex_with_coeff_to_pair(const ex & e,
568                                                                                   const ex & c) const
569 {
570         GINAC_ASSERT(is_exactly_a<numeric>(c));
571         if (is_exactly_a<mul>(e)) {
572                 const mul &mulref(ex_to<mul>(e));
573                 const ex &numfactor = mulref.overall_coeff;
574                 mul *mulcopyp = new mul(mulref);
575                 mulcopyp->overall_coeff = _ex1;
576                 mulcopyp->clearflag(status_flags::evaluated);
577                 mulcopyp->clearflag(status_flags::hash_calculated);
578                 mulcopyp->setflag(status_flags::dynallocated);
579                 if (c.is_equal(_ex1))
580                         return expair(*mulcopyp, numfactor);
581                 else if (numfactor.is_equal(_ex1))
582                         return expair(*mulcopyp, c);
583                 else
584                         return expair(*mulcopyp, ex_to<numeric>(numfactor).mul_dyn(ex_to<numeric>(c)));
585         } else if (is_exactly_a<numeric>(e)) {
586                 if (c.is_equal(_ex1))
587                         return expair(e, _ex1);
588                 return expair(ex_to<numeric>(e).mul_dyn(ex_to<numeric>(c)), _ex1);
589         }
590         return expair(e, c);
591 }
592
593 expair add::combine_pair_with_coeff_to_pair(const expair & p,
594                                                                                         const ex & c) const
595 {
596         GINAC_ASSERT(is_exactly_a<numeric>(p.coeff));
597         GINAC_ASSERT(is_exactly_a<numeric>(c));
598
599         if (is_exactly_a<numeric>(p.rest)) {
600                 GINAC_ASSERT(ex_to<numeric>(p.coeff).is_equal(*_num1_p)); // should be normalized
601                 return expair(ex_to<numeric>(p.rest).mul_dyn(ex_to<numeric>(c)),_ex1);
602         }
603
604         return expair(p.rest,ex_to<numeric>(p.coeff).mul_dyn(ex_to<numeric>(c)));
605 }
606         
607 ex add::recombine_pair_to_ex(const expair & p) const
608 {
609         if (ex_to<numeric>(p.coeff).is_equal(*_num1_p))
610                 return p.rest;
611         else
612                 return (new mul(p.rest,p.coeff))->setflag(status_flags::dynallocated);
613 }
614
615 ex add::expand(unsigned options) const
616 {
617         std::auto_ptr<epvector> vp = expandchildren(options);
618         if (vp.get() == 0) {
619                 // the terms have not changed, so it is safe to declare this expanded
620                 return (options == 0) ? setflag(status_flags::expanded) : *this;
621         }
622
623         return (new add(vp, overall_coeff))->setflag(status_flags::dynallocated | (options == 0 ? status_flags::expanded : 0));
624 }
625
626 } // namespace GiNaC