* Remove check for empty seq in Python output.
[ginac.git] / ginac / add.cpp
1 /** @file add.cpp
2  *
3  *  Implementation of GiNaC's sums of expressions. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2001 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
21  */
22
23 #include <iostream>
24 #include <stdexcept>
25
26 #include "add.h"
27 #include "mul.h"
28 #include "matrix.h"
29 #include "archive.h"
30 #include "utils.h"
31
32 namespace GiNaC {
33
// Registered-class boilerplate for add; the second argument names the parent
// class (expairseq), which this file refers to as `inherited`. What exactly the
// macro expands to is defined elsewhere (not visible here).
GINAC_IMPLEMENT_REGISTERED_CLASS(add, expairseq)
35
36 //////////
37 // default ctor, dtor, copy ctor, assignment operator and helpers
38 //////////
39
40 add::add()
41 {
42         tinfo_key = TINFO_add;
43 }
44
// Copy and destroy helpers come from GiNaC's generic macros; add has no
// copy/destroy code of its own in this file.
DEFAULT_COPY(add)
DEFAULT_DESTROY(add)
47
48 //////////
49 // other constructors
50 //////////
51
52 // public
53
54 add::add(const ex & lh, const ex & rh)
55 {
56         tinfo_key = TINFO_add;
57         overall_coeff = _ex0;
58         construct_from_2_ex(lh,rh);
59         GINAC_ASSERT(is_canonical());
60 }
61
62 add::add(const exvector & v)
63 {
64         tinfo_key = TINFO_add;
65         overall_coeff = _ex0;
66         construct_from_exvector(v);
67         GINAC_ASSERT(is_canonical());
68 }
69
70 add::add(const epvector & v)
71 {
72         tinfo_key = TINFO_add;
73         overall_coeff = _ex0;
74         construct_from_epvector(v);
75         GINAC_ASSERT(is_canonical());
76 }
77
78 add::add(const epvector & v, const ex & oc)
79 {
80         tinfo_key = TINFO_add;
81         overall_coeff = oc;
82         construct_from_epvector(v);
83         GINAC_ASSERT(is_canonical());
84 }
85
86 add::add(epvector * vp, const ex & oc)
87 {
88         tinfo_key = TINFO_add;
89         GINAC_ASSERT(vp!=0);
90         overall_coeff = oc;
91         construct_from_epvector(*vp);
92         delete vp;
93         GINAC_ASSERT(is_canonical());
94 }
95
96 //////////
97 // archiving
98 //////////
99
// Archiving support comes from the generic DEFAULT_ARCHIVING macro; there is
// no add-specific archiving code in this file.
DEFAULT_ARCHIVING(add)
101
102 //////////
103 // functions overriding virtual functions from base classes
104 //////////
105
106 // public
107
108 void add::print(const print_context & c, unsigned level) const
109 {
110         if (is_a<print_tree>(c)) {
111
112                 inherited::print(c, level);
113
114         } else if (is_a<print_csrc>(c)) {
115
116                 if (precedence() <= level)
117                         c.s << "(";
118         
119                 // Print arguments, separated by "+"
120                 epvector::const_iterator it = seq.begin(), itend = seq.end();
121                 while (it != itend) {
122                 
123                         // If the coefficient is -1, it is replaced by a single minus sign
124                         if (it->coeff.compare(_num1) == 0) {
125                                 it->rest.print(c, precedence());
126                         } else if (it->coeff.compare(_num_1) == 0) {
127                                 c.s << "-";
128                                 it->rest.print(c, precedence());
129                         } else if (ex_to<numeric>(it->coeff).numer().compare(_num1) == 0) {
130                                 it->rest.print(c, precedence());
131                                 c.s << "/";
132                                 ex_to<numeric>(it->coeff).denom().print(c, precedence());
133                         } else if (ex_to<numeric>(it->coeff).numer().compare(_num_1) == 0) {
134                                 c.s << "-";
135                                 it->rest.print(c, precedence());
136                                 c.s << "/";
137                                 ex_to<numeric>(it->coeff).denom().print(c, precedence());
138                         } else {
139                                 it->coeff.print(c, precedence());
140                                 c.s << "*";
141                                 it->rest.print(c, precedence());
142                         }
143                 
144                         // Separator is "+", except if the following expression would have a leading minus sign
145                         ++it;
146                         if (it != itend && !(it->coeff.compare(_num0) < 0 || (it->coeff.compare(_num1) == 0 && is_exactly_a<numeric>(it->rest) && it->rest.compare(_num0) < 0)))
147                                 c.s << "+";
148                 }
149         
150                 if (!overall_coeff.is_zero()) {
151                         if (overall_coeff.info(info_flags::positive))
152                                 c.s << '+';
153                         overall_coeff.print(c, precedence());
154                 }
155                 
156                 if (precedence() <= level)
157                         c.s << ")";
158
159         } else if (is_a<print_python_repr>(c)) {
160
161                 c.s << class_name() << '(';
162                 op(0).print(c);
163                 for (unsigned i=1; i<nops(); ++i) {
164                         c.s << ',';
165                         op(i).print(c);
166                 }
167                 c.s << ')';
168
169         } else {
170
171                 if (precedence() <= level) {
172                         if (is_a<print_latex>(c))
173                                 c.s << "{(";
174                         else
175                                 c.s << "(";
176                 }
177
178                 numeric coeff;
179                 bool first = true;
180
181                 // First print the overall numeric coefficient, if present
182                 if (!overall_coeff.is_zero()) {
183                         if (!is_a<print_tree>(c))
184                                 overall_coeff.print(c, 0);
185                         else
186                                 overall_coeff.print(c, precedence());
187                         first = false;
188                 }
189
190                 // Then proceed with the remaining factors
191                 epvector::const_iterator it = seq.begin(), itend = seq.end();
192                 while (it != itend) {
193                         coeff = ex_to<numeric>(it->coeff);
194                         if (!first) {
195                                 if (coeff.csgn() == -1) c.s << '-'; else c.s << '+';
196                         } else {
197                                 if (coeff.csgn() == -1) c.s << '-';
198                                 first = false;
199                         }
200                         if (!coeff.is_equal(_num1) &&
201                             !coeff.is_equal(_num_1)) {
202                                 if (coeff.is_rational()) {
203                                         if (coeff.is_negative())
204                                                 (-coeff).print(c);
205                                         else
206                                                 coeff.print(c);
207                                 } else {
208                                         if (coeff.csgn() == -1)
209                                                 (-coeff).print(c, precedence());
210                                         else
211                                                 coeff.print(c, precedence());
212                                 }
213                                 if (is_a<print_latex>(c))
214                                         c.s << ' ';
215                                 else
216                                         c.s << '*';
217                         }
218                         it->rest.print(c, precedence());
219                         ++it;
220                 }
221
222                 if (precedence() <= level) {
223                         if (is_a<print_latex>(c))
224                                 c.s << ")}";
225                         else
226                                 c.s << ")";
227                 }
228         }
229 }
230
231 bool add::info(unsigned inf) const
232 {
233         switch (inf) {
234                 case info_flags::polynomial:
235                 case info_flags::integer_polynomial:
236                 case info_flags::cinteger_polynomial:
237                 case info_flags::rational_polynomial:
238                 case info_flags::crational_polynomial:
239                 case info_flags::rational_function: {
240                         epvector::const_iterator i = seq.begin(), end = seq.end();
241                         while (i != end) {
242                                 if (!(recombine_pair_to_ex(*i).info(inf)))
243                                         return false;
244                                 ++i;
245                         }
246                         return overall_coeff.info(inf);
247                 }
248                 case info_flags::algebraic: {
249                         epvector::const_iterator i = seq.begin(), end = seq.end();
250                         while (i != end) {
251                                 if ((recombine_pair_to_ex(*i).info(inf)))
252                                         return true;
253                                 ++i;
254                         }
255                         return false;
256                 }
257         }
258         return inherited::info(inf);
259 }
260
261 int add::degree(const ex & s) const
262 {
263         int deg = INT_MIN;
264         if (!overall_coeff.is_zero())
265                 deg = 0;
266         
267         // Find maximum of degrees of individual terms
268         epvector::const_iterator i = seq.begin(), end = seq.end();
269         while (i != end) {
270                 int cur_deg = i->rest.degree(s);
271                 if (cur_deg > deg)
272                         deg = cur_deg;
273                 ++i;
274         }
275         return deg;
276 }
277
278 int add::ldegree(const ex & s) const
279 {
280         int deg = INT_MAX;
281         if (!overall_coeff.is_zero())
282                 deg = 0;
283         
284         // Find minimum of degrees of individual terms
285         epvector::const_iterator i = seq.begin(), end = seq.end();
286         while (i != end) {
287                 int cur_deg = i->rest.ldegree(s);
288                 if (cur_deg < deg)
289                         deg = cur_deg;
290                 ++i;
291         }
292         return deg;
293 }
294
295 ex add::coeff(const ex & s, int n) const
296 {
297         epvector *coeffseq = new epvector();
298
299         // Calculate sum of coefficients in each term
300         epvector::const_iterator i = seq.begin(), end = seq.end();
301         while (i != end) {
302                 ex restcoeff = i->rest.coeff(s, n);
303                 if (!restcoeff.is_zero())
304                         coeffseq->push_back(combine_ex_with_coeff_to_pair(restcoeff, i->coeff));
305                 ++i;
306         }
307
308         return (new add(coeffseq, n==0 ? overall_coeff : _ex0))->setflag(status_flags::dynallocated);
309 }
310
311 /** Perform automatic term rewriting rules in this class.  In the following
312  *  x stands for a symbolic variables of type ex and c stands for such
313  *  an expression that contain a plain number.
314  *  - +(;c) -> c
315  *  - +(x;1) -> x
316  *
317  *  @param level cut-off in recursive evaluation */
318 ex add::eval(int level) const
319 {
320         epvector *evaled_seqp = evalchildren(level);
321         if (evaled_seqp) {
322                 // do more evaluation later
323                 return (new add(evaled_seqp, overall_coeff))->
324                        setflag(status_flags::dynallocated);
325         }
326         
327 #ifdef DO_GINAC_ASSERT
328         epvector::const_iterator i = seq.begin(), end = seq.end();
329         while (i != end) {
330                 GINAC_ASSERT(!is_exactly_a<add>(i->rest));
331                 if (is_ex_exactly_of_type(i->rest,numeric))
332                         dbgprint();
333                 GINAC_ASSERT(!is_exactly_a<numeric>(i->rest));
334                 ++i;
335         }
336 #endif // def DO_GINAC_ASSERT
337         
338         if (flags & status_flags::evaluated) {
339                 GINAC_ASSERT(seq.size()>0);
340                 GINAC_ASSERT(seq.size()>1 || !overall_coeff.is_zero());
341                 return *this;
342         }
343         
344         int seq_size = seq.size();
345         if (seq_size == 0) {
346                 // +(;c) -> c
347                 return overall_coeff;
348         } else if (seq_size == 1 && overall_coeff.is_zero()) {
349                 // +(x;0) -> x
350                 return recombine_pair_to_ex(*(seq.begin()));
351         } else if (!overall_coeff.is_zero() && seq[0].rest.return_type() != return_types::commutative) {
352                 throw (std::logic_error("add::eval(): sum of non-commutative objects has non-zero numeric term"));
353         }
354         return this->hold();
355 }
356
357 ex add::evalm(void) const
358 {
359         // Evaluate children first and add up all matrices. Stop if there's one
360         // term that is not a matrix.
361         epvector *s = new epvector;
362         s->reserve(seq.size());
363
364         bool all_matrices = true;
365         bool first_term = true;
366         matrix sum;
367
368         epvector::const_iterator it = seq.begin(), itend = seq.end();
369         while (it != itend) {
370                 const ex &m = recombine_pair_to_ex(*it).evalm();
371                 s->push_back(split_ex_to_pair(m));
372                 if (is_ex_of_type(m, matrix)) {
373                         if (first_term) {
374                                 sum = ex_to<matrix>(m);
375                                 first_term = false;
376                         } else
377                                 sum = sum.add(ex_to<matrix>(m));
378                 } else
379                         all_matrices = false;
380                 ++it;
381         }
382
383         if (all_matrices) {
384                 delete s;
385                 return sum + overall_coeff;
386         } else
387                 return (new add(s, overall_coeff))->setflag(status_flags::dynallocated);
388 }
389
390 ex add::simplify_ncmul(const exvector & v) const
391 {
392         if (seq.empty())
393                 return inherited::simplify_ncmul(v);
394         else
395                 return seq.begin()->rest.simplify_ncmul(v);
396 }    
397
398 // protected
399
400 /** Implementation of ex::diff() for a sum. It differentiates each term.
401  *  @see ex::diff */
402 ex add::derivative(const symbol & y) const
403 {
404         epvector *s = new epvector();
405         s->reserve(seq.size());
406         
407         // Only differentiate the "rest" parts of the expairs. This is faster
408         // than the default implementation in basic::derivative() although
409         // if performs the same function (differentiate each term).
410         epvector::const_iterator i = seq.begin(), end = seq.end();
411         while (i != end) {
412                 s->push_back(combine_ex_with_coeff_to_pair(i->rest.diff(y), i->coeff));
413                 ++i;
414         }
415         return (new add(s, _ex0))->setflag(status_flags::dynallocated);
416 }
417
418 int add::compare_same_type(const basic & other) const
419 {
420         return inherited::compare_same_type(other);
421 }
422
423 bool add::is_equal_same_type(const basic & other) const
424 {
425         return inherited::is_equal_same_type(other);
426 }
427
428 unsigned add::return_type(void) const
429 {
430         if (seq.empty())
431                 return return_types::commutative;
432         else
433                 return seq.begin()->rest.return_type();
434 }
435    
436 unsigned add::return_type_tinfo(void) const
437 {
438         if (seq.empty())
439                 return tinfo_key;
440         else
441                 return seq.begin()->rest.return_type_tinfo();
442 }
443
444 ex add::thisexpairseq(const epvector & v, const ex & oc) const
445 {
446         return (new add(v,oc))->setflag(status_flags::dynallocated);
447 }
448
449 ex add::thisexpairseq(epvector * vp, const ex & oc) const
450 {
451         return (new add(vp,oc))->setflag(status_flags::dynallocated);
452 }
453
454 expair add::split_ex_to_pair(const ex & e) const
455 {
456         if (is_ex_exactly_of_type(e,mul)) {
457                 const mul &mulref(ex_to<mul>(e));
458                 const ex &numfactor = mulref.overall_coeff;
459                 mul *mulcopyp = new mul(mulref);
460                 mulcopyp->overall_coeff = _ex1;
461                 mulcopyp->clearflag(status_flags::evaluated);
462                 mulcopyp->clearflag(status_flags::hash_calculated);
463                 mulcopyp->setflag(status_flags::dynallocated);
464                 return expair(*mulcopyp,numfactor);
465         }
466         return expair(e,_ex1);
467 }
468
469 expair add::combine_ex_with_coeff_to_pair(const ex & e,
470                                                                                   const ex & c) const
471 {
472         GINAC_ASSERT(is_exactly_a<numeric>(c));
473         if (is_ex_exactly_of_type(e, mul)) {
474                 const mul &mulref(ex_to<mul>(e));
475                 const ex &numfactor = mulref.overall_coeff;
476                 mul *mulcopyp = new mul(mulref);
477                 mulcopyp->overall_coeff = _ex1;
478                 mulcopyp->clearflag(status_flags::evaluated);
479                 mulcopyp->clearflag(status_flags::hash_calculated);
480                 mulcopyp->setflag(status_flags::dynallocated);
481                 if (are_ex_trivially_equal(c, _ex1))
482                         return expair(*mulcopyp, numfactor);
483                 else if (are_ex_trivially_equal(numfactor, _ex1))
484                         return expair(*mulcopyp, c);
485                 else
486                         return expair(*mulcopyp, ex_to<numeric>(numfactor).mul_dyn(ex_to<numeric>(c)));
487         } else if (is_ex_exactly_of_type(e, numeric)) {
488                 if (are_ex_trivially_equal(c, _ex1))
489                         return expair(e, _ex1);
490                 return expair(ex_to<numeric>(e).mul_dyn(ex_to<numeric>(c)), _ex1);
491         }
492         return expair(e, c);
493 }
494
495 expair add::combine_pair_with_coeff_to_pair(const expair & p,
496                                                                                         const ex & c) const
497 {
498         GINAC_ASSERT(is_exactly_a<numeric>(p.coeff));
499         GINAC_ASSERT(is_exactly_a<numeric>(c));
500
501         if (is_ex_exactly_of_type(p.rest,numeric)) {
502                 GINAC_ASSERT(ex_to<numeric>(p.coeff).is_equal(_num1)); // should be normalized
503                 return expair(ex_to<numeric>(p.rest).mul_dyn(ex_to<numeric>(c)),_ex1);
504         }
505
506         return expair(p.rest,ex_to<numeric>(p.coeff).mul_dyn(ex_to<numeric>(c)));
507 }
508         
509 ex add::recombine_pair_to_ex(const expair & p) const
510 {
511         if (ex_to<numeric>(p.coeff).is_equal(_num1))
512                 return p.rest;
513         else
514                 return (new mul(p.rest,p.coeff))->setflag(status_flags::dynallocated);
515 }
516
517 ex add::expand(unsigned options) const
518 {
519         epvector *vp = expandchildren(options);
520         if (vp == NULL) {
521                 // the terms have not changed, so it is safe to declare this expanded
522                 return (options == 0) ? setflag(status_flags::expanded) : *this;
523         }
524         
525         return (new add(vp, overall_coeff))->setflag(status_flags::dynallocated | (options == 0 ? status_flags::expanded : 0));
526 }
527
528 } // namespace GiNaC