6548366ee16f1facb3dab2c1751444730171f236
[ginac.git] / ginac / add.cpp
1 /** @file add.cpp
2  *
3  *  Implementation of GiNaC's sums of expressions. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2010 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
21  */
22
23 #include "add.h"
24 #include "mul.h"
25 #include "archive.h"
26 #include "operators.h"
27 #include "matrix.h"
28 #include "utils.h"
29 #include "clifford.h"
30 #include "ncmul.h"
31
32 #include <iostream>
33 #include <limits>
34 #include <stdexcept>
35 #include <string>
36
37 namespace GiNaC {
38
// Declare add as a registered class derived from expairseq and associate each
// print context type with the member function that renders a sum in it.
GINAC_IMPLEMENT_REGISTERED_CLASS_OPT(add, expairseq,
  print_func<print_context>(&add::do_print).
  print_func<print_latex>(&add::do_print_latex).
  print_func<print_csrc>(&add::do_print_csrc).
  print_func<print_tree>(&add::do_print_tree).
  print_func<print_python_repr>(&add::do_print_python_repr))
45
46 //////////
47 // default constructor
48 //////////
49
/** Default constructor; performs no additional initialization here. */
add::add()
{
}

//////////
// other constructors
//////////

// public

/** Construct the sum of the two expressions lh and rh. */
add::add(const ex & lh, const ex & rh)
{
	// The constant term is set before construct_from_* runs, presumably
	// because the construction helpers accumulate numeric terms into it.
	overall_coeff = _ex0;
	construct_from_2_ex(lh,rh);
	GINAC_ASSERT(is_canonical());
}

/** Construct the sum of all expressions in v. */
add::add(const exvector & v)
{
	overall_coeff = _ex0;
	construct_from_exvector(v);
	GINAC_ASSERT(is_canonical());
}

/** Construct a sum from a list of (rest, coeff) pairs with zero constant term. */
add::add(const epvector & v)
{
	overall_coeff = _ex0;
	construct_from_epvector(v);
	GINAC_ASSERT(is_canonical());
}

/** Construct a sum from a list of (rest, coeff) pairs plus the numeric
 *  constant term oc. */
add::add(const epvector & v, const ex & oc)
{
	overall_coeff = oc;
	construct_from_epvector(v);
	GINAC_ASSERT(is_canonical());
}

/** Like add(const epvector &, const ex &), but takes ownership of the pair
 *  list: the vector is released when vp goes out of scope at the end of
 *  this constructor. vp must not be null. */
add::add(std::auto_ptr<epvector> vp, const ex & oc)
{
	GINAC_ASSERT(vp.get()!=0);
	overall_coeff = oc;
	construct_from_epvector(*vp);
	GINAC_ASSERT(is_canonical());
}
95
96 //////////
97 // archiving
98 //////////
99
// Bind the unarchiver for class add (presumably used to reconstruct sums
// when reading archived expressions).
GINAC_BIND_UNARCHIVER(add);
101
102 //////////
103 // functions overriding virtual functions from base classes
104 //////////
105
106 // public
107
108 void add::print_add(const print_context & c, const char *openbrace, const char *closebrace, const char *mul_sym, unsigned level) const
109 {
110         if (precedence() <= level)
111                 c.s << openbrace << '(';
112
113         numeric coeff;
114         bool first = true;
115
116         // First print the overall numeric coefficient, if present
117         if (!overall_coeff.is_zero()) {
118                 overall_coeff.print(c, 0);
119                 first = false;
120         }
121
122         // Then proceed with the remaining factors
123         epvector::const_iterator it = seq.begin(), itend = seq.end();
124         while (it != itend) {
125                 coeff = ex_to<numeric>(it->coeff);
126                 if (!first) {
127                         if (coeff.csgn() == -1) c.s << '-'; else c.s << '+';
128                 } else {
129                         if (coeff.csgn() == -1) c.s << '-';
130                         first = false;
131                 }
132                 if (!coeff.is_equal(*_num1_p) &&
133                     !coeff.is_equal(*_num_1_p)) {
134                         if (coeff.is_rational()) {
135                                 if (coeff.is_negative())
136                                         (-coeff).print(c);
137                                 else
138                                         coeff.print(c);
139                         } else {
140                                 if (coeff.csgn() == -1)
141                                         (-coeff).print(c, precedence());
142                                 else
143                                         coeff.print(c, precedence());
144                         }
145                         c.s << mul_sym;
146                 }
147                 it->rest.print(c, precedence());
148                 ++it;
149         }
150
151         if (precedence() <= level)
152                 c.s << ')' << closebrace;
153 }
154
155 void add::do_print(const print_context & c, unsigned level) const
156 {
157         print_add(c, "", "", "*", level);
158 }
159
160 void add::do_print_latex(const print_latex & c, unsigned level) const
161 {
162         print_add(c, "{", "}", " ", level);
163 }
164
165 void add::do_print_csrc(const print_csrc & c, unsigned level) const
166 {
167         if (precedence() <= level)
168                 c.s << "(";
169         
170         // Print arguments, separated by "+" or "-"
171         epvector::const_iterator it = seq.begin(), itend = seq.end();
172         char separator = ' ';
173         while (it != itend) {
174                 
175                 // If the coefficient is negative, separator is "-"
176                 if (it->coeff.is_equal(_ex_1) || 
177                         ex_to<numeric>(it->coeff).numer().is_equal(*_num_1_p))
178                         separator = '-';
179                 c.s << separator;
180                 if (it->coeff.is_equal(_ex1) || it->coeff.is_equal(_ex_1)) {
181                         it->rest.print(c, precedence());
182                 } else if (ex_to<numeric>(it->coeff).numer().is_equal(*_num1_p) ||
183                                  ex_to<numeric>(it->coeff).numer().is_equal(*_num_1_p))
184                 {
185                         it->rest.print(c, precedence());
186                         c.s << '/';
187                         ex_to<numeric>(it->coeff).denom().print(c, precedence());
188                 } else {
189                         it->coeff.print(c, precedence());
190                         c.s << '*';
191                         it->rest.print(c, precedence());
192                 }
193                 
194                 ++it;
195                 separator = '+';
196         }
197         
198         if (!overall_coeff.is_zero()) {
199                 if (overall_coeff.info(info_flags::positive)
200                  || is_a<print_csrc_cl_N>(c) || !overall_coeff.info(info_flags::real))  // sign inside ctor argument
201                         c.s << '+';
202                 overall_coeff.print(c, precedence());
203         }
204                 
205         if (precedence() <= level)
206                 c.s << ")";
207 }
208
209 void add::do_print_python_repr(const print_python_repr & c, unsigned level) const
210 {
211         c.s << class_name() << '(';
212         op(0).print(c);
213         for (size_t i=1; i<nops(); ++i) {
214                 c.s << ',';
215                 op(i).print(c);
216         }
217         c.s << ')';
218 }
219
220 bool add::info(unsigned inf) const
221 {
222         switch (inf) {
223                 case info_flags::polynomial:
224                 case info_flags::integer_polynomial:
225                 case info_flags::cinteger_polynomial:
226                 case info_flags::rational_polynomial:
227                 case info_flags::real:
228                 case info_flags::rational:
229                 case info_flags::integer:
230                 case info_flags::crational:
231                 case info_flags::cinteger:
232                 case info_flags::positive:
233                 case info_flags::nonnegative:
234                 case info_flags::posint:
235                 case info_flags::nonnegint:
236                 case info_flags::even:
237                 case info_flags::crational_polynomial:
238                 case info_flags::rational_function: {
239                         epvector::const_iterator i = seq.begin(), end = seq.end();
240                         while (i != end) {
241                                 if (!(recombine_pair_to_ex(*i).info(inf)))
242                                         return false;
243                                 ++i;
244                         }
245                         if (overall_coeff.is_zero() && (inf == info_flags::positive || inf == info_flags::posint))
246                                 return true;
247                         return overall_coeff.info(inf);
248                 }
249                 case info_flags::algebraic: {
250                         epvector::const_iterator i = seq.begin(), end = seq.end();
251                         while (i != end) {
252                                 if ((recombine_pair_to_ex(*i).info(inf)))
253                                         return true;
254                                 ++i;
255                         }
256                         return false;
257                 }
258         }
259         return inherited::info(inf);
260 }
261
262 int add::degree(const ex & s) const
263 {
264         int deg = std::numeric_limits<int>::min();
265         if (!overall_coeff.is_zero())
266                 deg = 0;
267         
268         // Find maximum of degrees of individual terms
269         epvector::const_iterator i = seq.begin(), end = seq.end();
270         while (i != end) {
271                 int cur_deg = i->rest.degree(s);
272                 if (cur_deg > deg)
273                         deg = cur_deg;
274                 ++i;
275         }
276         return deg;
277 }
278
279 int add::ldegree(const ex & s) const
280 {
281         int deg = std::numeric_limits<int>::max();
282         if (!overall_coeff.is_zero())
283                 deg = 0;
284         
285         // Find minimum of degrees of individual terms
286         epvector::const_iterator i = seq.begin(), end = seq.end();
287         while (i != end) {
288                 int cur_deg = i->rest.ldegree(s);
289                 if (cur_deg < deg)
290                         deg = cur_deg;
291                 ++i;
292         }
293         return deg;
294 }
295
296 ex add::coeff(const ex & s, int n) const
297 {
298         std::auto_ptr<epvector> coeffseq(new epvector);
299         std::auto_ptr<epvector> coeffseq_cliff(new epvector);
300         char rl = clifford_max_label(s);
301         bool do_clifford = (rl != -1);
302         bool nonscalar = false;
303
304         // Calculate sum of coefficients in each term
305         epvector::const_iterator i = seq.begin(), end = seq.end();
306         while (i != end) {
307                 ex restcoeff = i->rest.coeff(s, n);
308                 if (!restcoeff.is_zero()) {
309                         if (do_clifford) {
310                                 if (clifford_max_label(restcoeff) == -1) {
311                                         coeffseq_cliff->push_back(combine_ex_with_coeff_to_pair(ncmul(restcoeff, dirac_ONE(rl)), i->coeff));
312                                 } else {
313                                         coeffseq_cliff->push_back(combine_ex_with_coeff_to_pair(restcoeff, i->coeff));
314                                         nonscalar = true;
315                                 }
316                         }
317                         coeffseq->push_back(combine_ex_with_coeff_to_pair(restcoeff, i->coeff));
318                 }
319                 ++i;
320         }
321
322         return (new add(nonscalar ? coeffseq_cliff : coeffseq,
323                         n==0 ? overall_coeff : _ex0))->setflag(status_flags::dynallocated);
324 }
325
/** Perform automatic term rewriting rules in this class.  In the following
 *  x stands for a symbolic variable of type ex and c stands for such
 *  an expression that contains a plain number.
 *  - +(;c) -> c
 *  - +(x;0) -> x
 *
 *  @param level cut-off in recursive evaluation
 *  @return the constant term alone if there are no other terms; the single
 *          term if there is exactly one and the constant term is zero;
 *          otherwise this sum, held
 *  @throws std::logic_error if the constant term is non-zero while the
 *          first term is not commutative */
ex add::eval(int level) const
{
	// Evaluate the terms first. A non-null result is a new term list
	// (presumably because some term changed); build a fresh sum from it.
	std::auto_ptr<epvector> evaled_seqp = evalchildren(level);
	if (evaled_seqp.get()) {
		// do more evaluation later
		return (new add(evaled_seqp, overall_coeff))->
		       setflag(status_flags::dynallocated);
	}
	
#ifdef DO_GINAC_ASSERT
	// Debug-only invariant check: no term may itself be a sum or a bare number.
	epvector::const_iterator i = seq.begin(), end = seq.end();
	while (i != end) {
		GINAC_ASSERT(!is_exactly_a<add>(i->rest));
		if (is_exactly_a<numeric>(i->rest))
			dbgprint();
		GINAC_ASSERT(!is_exactly_a<numeric>(i->rest));
		++i;
	}
#endif // def DO_GINAC_ASSERT
	
	// Already evaluated: nothing more to do.
	if (flags & status_flags::evaluated) {
		GINAC_ASSERT(seq.size()>0);
		GINAC_ASSERT(seq.size()>1 || !overall_coeff.is_zero());
		return *this;
	}
	
	int seq_size = seq.size();
	if (seq_size == 0) {
		// +(;c) -> c
		return overall_coeff;
	} else if (seq_size == 1 && overall_coeff.is_zero()) {
		// +(x;0) -> x
		return recombine_pair_to_ex(*(seq.begin()));
	} else if (!overall_coeff.is_zero() && seq[0].rest.return_type() != return_types::commutative) {
		// A numeric term cannot be added to non-commutative objects.
		throw (std::logic_error("add::eval(): sum of non-commutative objects has non-zero numeric term"));
	}
	// Nothing to rewrite; hold() presumably marks the sum as evaluated.
	return this->hold();
}
371
372 ex add::evalm() const
373 {
374         // Evaluate children first and add up all matrices. Stop if there's one
375         // term that is not a matrix.
376         std::auto_ptr<epvector> s(new epvector);
377         s->reserve(seq.size());
378
379         bool all_matrices = true;
380         bool first_term = true;
381         matrix sum;
382
383         epvector::const_iterator it = seq.begin(), itend = seq.end();
384         while (it != itend) {
385                 const ex &m = recombine_pair_to_ex(*it).evalm();
386                 s->push_back(split_ex_to_pair(m));
387                 if (is_a<matrix>(m)) {
388                         if (first_term) {
389                                 sum = ex_to<matrix>(m);
390                                 first_term = false;
391                         } else
392                                 sum = sum.add(ex_to<matrix>(m));
393                 } else
394                         all_matrices = false;
395                 ++it;
396         }
397
398         if (all_matrices)
399                 return sum + overall_coeff;
400         else
401                 return (new add(s, overall_coeff))->setflag(status_flags::dynallocated);
402 }
403
404 ex add::conjugate() const
405 {
406         exvector *v = 0;
407         for (size_t i=0; i<nops(); ++i) {
408                 if (v) {
409                         v->push_back(op(i).conjugate());
410                         continue;
411                 }
412                 ex term = op(i);
413                 ex ccterm = term.conjugate();
414                 if (are_ex_trivially_equal(term, ccterm))
415                         continue;
416                 v = new exvector;
417                 v->reserve(nops());
418                 for (size_t j=0; j<i; ++j)
419                         v->push_back(op(j));
420                 v->push_back(ccterm);
421         }
422         if (v) {
423                 ex result = add(*v);
424                 delete v;
425                 return result;
426         }
427         return *this;
428 }
429
430 ex add::real_part() const
431 {
432         epvector v;
433         v.reserve(seq.size());
434         for (epvector::const_iterator i=seq.begin(); i!=seq.end(); ++i)
435                 if ((i->coeff).info(info_flags::real)) {
436                         ex rp = (i->rest).real_part();
437                         if (!rp.is_zero())
438                                 v.push_back(expair(rp, i->coeff));
439                 } else {
440                         ex rp=recombine_pair_to_ex(*i).real_part();
441                         if (!rp.is_zero())
442                                 v.push_back(split_ex_to_pair(rp));
443                 }
444         return (new add(v, overall_coeff.real_part()))
445                 -> setflag(status_flags::dynallocated);
446 }
447
448 ex add::imag_part() const
449 {
450         epvector v;
451         v.reserve(seq.size());
452         for (epvector::const_iterator i=seq.begin(); i!=seq.end(); ++i)
453                 if ((i->coeff).info(info_flags::real)) {
454                         ex ip = (i->rest).imag_part();
455                         if (!ip.is_zero())
456                                 v.push_back(expair(ip, i->coeff));
457                 } else {
458                         ex ip=recombine_pair_to_ex(*i).imag_part();
459                         if (!ip.is_zero())
460                                 v.push_back(split_ex_to_pair(ip));
461                 }
462         return (new add(v, overall_coeff.imag_part()))
463                 -> setflag(status_flags::dynallocated);
464 }
465
466 ex add::eval_ncmul(const exvector & v) const
467 {
468         if (seq.empty())
469                 return inherited::eval_ncmul(v);
470         else
471                 return seq.begin()->rest.eval_ncmul(v);
472 }    
473
474 // protected
475
476 /** Implementation of ex::diff() for a sum. It differentiates each term.
477  *  @see ex::diff */
478 ex add::derivative(const symbol & y) const
479 {
480         std::auto_ptr<epvector> s(new epvector);
481         s->reserve(seq.size());
482         
483         // Only differentiate the "rest" parts of the expairs. This is faster
484         // than the default implementation in basic::derivative() although
485         // if performs the same function (differentiate each term).
486         epvector::const_iterator i = seq.begin(), end = seq.end();
487         while (i != end) {
488                 s->push_back(combine_ex_with_coeff_to_pair(i->rest.diff(y), i->coeff));
489                 ++i;
490         }
491         return (new add(s, _ex0))->setflag(status_flags::dynallocated);
492 }
493
494 int add::compare_same_type(const basic & other) const
495 {
496         return inherited::compare_same_type(other);
497 }
498
499 unsigned add::return_type() const
500 {
501         if (seq.empty())
502                 return return_types::commutative;
503         else
504                 return seq.begin()->rest.return_type();
505 }
506
507 return_type_t add::return_type_tinfo() const
508 {
509         if (seq.empty())
510                 return make_return_type_t<add>();
511         else
512                 return seq.begin()->rest.return_type_tinfo();
513 }
514
515 // Note: do_index_renaming is ignored because it makes no sense for an add.
516 ex add::thisexpairseq(const epvector & v, const ex & oc, bool do_index_renaming) const
517 {
518         return (new add(v,oc))->setflag(status_flags::dynallocated);
519 }
520
521 // Note: do_index_renaming is ignored because it makes no sense for an add.
522 ex add::thisexpairseq(std::auto_ptr<epvector> vp, const ex & oc, bool do_index_renaming) const
523 {
524         return (new add(vp,oc))->setflag(status_flags::dynallocated);
525 }
526
527 expair add::split_ex_to_pair(const ex & e) const
528 {
529         if (is_exactly_a<mul>(e)) {
530                 const mul &mulref(ex_to<mul>(e));
531                 const ex &numfactor = mulref.overall_coeff;
532                 mul *mulcopyp = new mul(mulref);
533                 mulcopyp->overall_coeff = _ex1;
534                 mulcopyp->clearflag(status_flags::evaluated);
535                 mulcopyp->clearflag(status_flags::hash_calculated);
536                 mulcopyp->setflag(status_flags::dynallocated);
537                 return expair(*mulcopyp,numfactor);
538         }
539         return expair(e,_ex1);
540 }
541
542 expair add::combine_ex_with_coeff_to_pair(const ex & e,
543                                                                                   const ex & c) const
544 {
545         GINAC_ASSERT(is_exactly_a<numeric>(c));
546         if (is_exactly_a<mul>(e)) {
547                 const mul &mulref(ex_to<mul>(e));
548                 const ex &numfactor = mulref.overall_coeff;
549                 mul *mulcopyp = new mul(mulref);
550                 mulcopyp->overall_coeff = _ex1;
551                 mulcopyp->clearflag(status_flags::evaluated);
552                 mulcopyp->clearflag(status_flags::hash_calculated);
553                 mulcopyp->setflag(status_flags::dynallocated);
554                 if (c.is_equal(_ex1))
555                         return expair(*mulcopyp, numfactor);
556                 else if (numfactor.is_equal(_ex1))
557                         return expair(*mulcopyp, c);
558                 else
559                         return expair(*mulcopyp, ex_to<numeric>(numfactor).mul_dyn(ex_to<numeric>(c)));
560         } else if (is_exactly_a<numeric>(e)) {
561                 if (c.is_equal(_ex1))
562                         return expair(e, _ex1);
563                 return expair(ex_to<numeric>(e).mul_dyn(ex_to<numeric>(c)), _ex1);
564         }
565         return expair(e, c);
566 }
567
568 expair add::combine_pair_with_coeff_to_pair(const expair & p,
569                                                                                         const ex & c) const
570 {
571         GINAC_ASSERT(is_exactly_a<numeric>(p.coeff));
572         GINAC_ASSERT(is_exactly_a<numeric>(c));
573
574         if (is_exactly_a<numeric>(p.rest)) {
575                 GINAC_ASSERT(ex_to<numeric>(p.coeff).is_equal(*_num1_p)); // should be normalized
576                 return expair(ex_to<numeric>(p.rest).mul_dyn(ex_to<numeric>(c)),_ex1);
577         }
578
579         return expair(p.rest,ex_to<numeric>(p.coeff).mul_dyn(ex_to<numeric>(c)));
580 }
581         
582 ex add::recombine_pair_to_ex(const expair & p) const
583 {
584         if (ex_to<numeric>(p.coeff).is_equal(*_num1_p))
585                 return p.rest;
586         else
587                 return (new mul(p.rest,p.coeff))->setflag(status_flags::dynallocated);
588 }
589
590 ex add::expand(unsigned options) const
591 {
592         std::auto_ptr<epvector> vp = expandchildren(options);
593         if (vp.get() == 0) {
594                 // the terms have not changed, so it is safe to declare this expanded
595                 return (options == 0) ? setflag(status_flags::expanded) : *this;
596         }
597
598         return (new add(vp, overall_coeff))->setflag(status_flags::dynallocated | (options == 0 ? status_flags::expanded : 0));
599 }
600
601 } // namespace GiNaC