synced to 1.0
[ginac.git] / ginac / add.cpp
1 /** @file add.cpp
2  *
3  *  Implementation of GiNaC's sums of expressions. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2003 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
21  */
22
#include <iostream>
#include <limits>
#include <stdexcept>

#include "add.h"
#include "mul.h"
#include "matrix.h"
#include "archive.h"
#include "utils.h"
31
32 namespace GiNaC {
33
// Class-registration boilerplate for add, naming expairseq as its parent
// class; see the macro definition for exactly what it generates.
GINAC_IMPLEMENT_REGISTERED_CLASS(add, expairseq)
35
36 //////////
37 // default ctor, dtor, copy ctor, assignment operator and helpers
38 //////////
39
40 add::add()
41 {
42         tinfo_key = TINFO_add;
43 }
44
45 DEFAULT_COPY(add)
46 DEFAULT_DESTROY(add)
47
48 //////////
49 // other constructors
50 //////////
51
52 // public
53
54 add::add(const ex & lh, const ex & rh)
55 {
56         tinfo_key = TINFO_add;
57         overall_coeff = _ex0;
58         construct_from_2_ex(lh,rh);
59         GINAC_ASSERT(is_canonical());
60 }
61
62 add::add(const exvector & v)
63 {
64         tinfo_key = TINFO_add;
65         overall_coeff = _ex0;
66         construct_from_exvector(v);
67         GINAC_ASSERT(is_canonical());
68 }
69
70 add::add(const epvector & v)
71 {
72         tinfo_key = TINFO_add;
73         overall_coeff = _ex0;
74         construct_from_epvector(v);
75         GINAC_ASSERT(is_canonical());
76 }
77
78 add::add(const epvector & v, const ex & oc)
79 {
80         tinfo_key = TINFO_add;
81         overall_coeff = oc;
82         construct_from_epvector(v);
83         GINAC_ASSERT(is_canonical());
84 }
85
86 add::add(epvector * vp, const ex & oc)
87 {
88         tinfo_key = TINFO_add;
89         GINAC_ASSERT(vp!=0);
90         overall_coeff = oc;
91         construct_from_epvector(*vp);
92         delete vp;
93         GINAC_ASSERT(is_canonical());
94 }
95
96 //////////
97 // archiving
98 //////////
99
// Archiving support for add, generated by the DEFAULT_ARCHIVING macro.
DEFAULT_ARCHIVING(add)
101
102 //////////
103 // functions overriding virtual functions from base classes
104 //////////
105
106 // public
107
108 void add::print(const print_context & c, unsigned level) const
109 {
110         if (is_a<print_tree>(c)) {
111
112                 inherited::print(c, level);
113
114         } else if (is_a<print_csrc>(c)) {
115
116                 if (precedence() <= level)
117                         c.s << "(";
118         
119                 // Print arguments, separated by "+"
120                 epvector::const_iterator it = seq.begin(), itend = seq.end();
121                 while (it != itend) {
122                 
123                         // If the coefficient is -1, it is replaced by a single minus sign
124                         if (it->coeff.is_equal(_ex1)) {
125                                 it->rest.print(c, precedence());
126                         } else if (it->coeff.is_equal(_ex_1)) {
127                                 c.s << "-";
128                                 it->rest.print(c, precedence());
129                         } else if (ex_to<numeric>(it->coeff).numer().is_equal(_num1)) {
130                                 it->rest.print(c, precedence());
131                                 c.s << "/";
132                                 ex_to<numeric>(it->coeff).denom().print(c, precedence());
133                         } else if (ex_to<numeric>(it->coeff).numer().is_equal(_num_1)) {
134                                 c.s << "-";
135                                 it->rest.print(c, precedence());
136                                 c.s << "/";
137                                 ex_to<numeric>(it->coeff).denom().print(c, precedence());
138                         } else {
139                                 it->coeff.print(c, precedence());
140                                 c.s << "*";
141                                 it->rest.print(c, precedence());
142                         }
143                 
144                         // Separator is "+", except if the following expression would have a leading minus sign or the sign is sitting in parenthesis (as in a ctor)
145                         ++it;
146                         if (it != itend
147                          && (is_a<print_csrc_cl_N>(c) || !it->coeff.info(info_flags::real)  // sign inside ctor arguments
148                           || !(it->coeff.info(info_flags::negative) || (it->coeff.is_equal(_num1) && is_exactly_a<numeric>(it->rest) && it->rest.info(info_flags::negative)))))
149                                 c.s << "+";
150                 }
151         
152                 if (!overall_coeff.is_zero()) {
153                         if (overall_coeff.info(info_flags::positive)
154                          || is_a<print_csrc_cl_N>(c) || !overall_coeff.info(info_flags::real))  // sign inside ctor argument
155                                 c.s << '+';
156                         overall_coeff.print(c, precedence());
157                 }
158                 
159                 if (precedence() <= level)
160                         c.s << ")";
161
162         } else if (is_a<print_python_repr>(c)) {
163
164                 c.s << class_name() << '(';
165                 op(0).print(c);
166                 for (unsigned i=1; i<nops(); ++i) {
167                         c.s << ',';
168                         op(i).print(c);
169                 }
170                 c.s << ')';
171
172         } else {
173
174                 if (precedence() <= level) {
175                         if (is_a<print_latex>(c))
176                                 c.s << "{(";
177                         else
178                                 c.s << "(";
179                 }
180
181                 numeric coeff;
182                 bool first = true;
183
184                 // First print the overall numeric coefficient, if present
185                 if (!overall_coeff.is_zero()) {
186                         if (!is_a<print_tree>(c))
187                                 overall_coeff.print(c, 0);
188                         else
189                                 overall_coeff.print(c, precedence());
190                         first = false;
191                 }
192
193                 // Then proceed with the remaining factors
194                 epvector::const_iterator it = seq.begin(), itend = seq.end();
195                 while (it != itend) {
196                         coeff = ex_to<numeric>(it->coeff);
197                         if (!first) {
198                                 if (coeff.csgn() == -1) c.s << '-'; else c.s << '+';
199                         } else {
200                                 if (coeff.csgn() == -1) c.s << '-';
201                                 first = false;
202                         }
203                         if (!coeff.is_equal(_num1) &&
204                             !coeff.is_equal(_num_1)) {
205                                 if (coeff.is_rational()) {
206                                         if (coeff.is_negative())
207                                                 (-coeff).print(c);
208                                         else
209                                                 coeff.print(c);
210                                 } else {
211                                         if (coeff.csgn() == -1)
212                                                 (-coeff).print(c, precedence());
213                                         else
214                                                 coeff.print(c, precedence());
215                                 }
216                                 if (is_a<print_latex>(c))
217                                         c.s << ' ';
218                                 else
219                                         c.s << '*';
220                         }
221                         it->rest.print(c, precedence());
222                         ++it;
223                 }
224
225                 if (precedence() <= level) {
226                         if (is_a<print_latex>(c))
227                                 c.s << ")}";
228                         else
229                                 c.s << ")";
230                 }
231         }
232 }
233
234 bool add::info(unsigned inf) const
235 {
236         switch (inf) {
237                 case info_flags::polynomial:
238                 case info_flags::integer_polynomial:
239                 case info_flags::cinteger_polynomial:
240                 case info_flags::rational_polynomial:
241                 case info_flags::crational_polynomial:
242                 case info_flags::rational_function: {
243                         epvector::const_iterator i = seq.begin(), end = seq.end();
244                         while (i != end) {
245                                 if (!(recombine_pair_to_ex(*i).info(inf)))
246                                         return false;
247                                 ++i;
248                         }
249                         return overall_coeff.info(inf);
250                 }
251                 case info_flags::algebraic: {
252                         epvector::const_iterator i = seq.begin(), end = seq.end();
253                         while (i != end) {
254                                 if ((recombine_pair_to_ex(*i).info(inf)))
255                                         return true;
256                                 ++i;
257                         }
258                         return false;
259                 }
260         }
261         return inherited::info(inf);
262 }
263
264 int add::degree(const ex & s) const
265 {
266         int deg = INT_MIN;
267         if (!overall_coeff.is_zero())
268                 deg = 0;
269         
270         // Find maximum of degrees of individual terms
271         epvector::const_iterator i = seq.begin(), end = seq.end();
272         while (i != end) {
273                 int cur_deg = i->rest.degree(s);
274                 if (cur_deg > deg)
275                         deg = cur_deg;
276                 ++i;
277         }
278         return deg;
279 }
280
281 int add::ldegree(const ex & s) const
282 {
283         int deg = INT_MAX;
284         if (!overall_coeff.is_zero())
285                 deg = 0;
286         
287         // Find minimum of degrees of individual terms
288         epvector::const_iterator i = seq.begin(), end = seq.end();
289         while (i != end) {
290                 int cur_deg = i->rest.ldegree(s);
291                 if (cur_deg < deg)
292                         deg = cur_deg;
293                 ++i;
294         }
295         return deg;
296 }
297
298 ex add::coeff(const ex & s, int n) const
299 {
300         epvector *coeffseq = new epvector();
301
302         // Calculate sum of coefficients in each term
303         epvector::const_iterator i = seq.begin(), end = seq.end();
304         while (i != end) {
305                 ex restcoeff = i->rest.coeff(s, n);
306                 if (!restcoeff.is_zero())
307                         coeffseq->push_back(combine_ex_with_coeff_to_pair(restcoeff, i->coeff));
308                 ++i;
309         }
310
311         return (new add(coeffseq, n==0 ? overall_coeff : _ex0))->setflag(status_flags::dynallocated);
312 }
313
314 /** Perform automatic term rewriting rules in this class.  In the following
315  *  x stands for a symbolic variables of type ex and c stands for such
316  *  an expression that contain a plain number.
317  *  - +(;c) -> c
318  *  - +(x;0) -> x
319  *
320  *  @param level cut-off in recursive evaluation */
321 ex add::eval(int level) const
322 {
323         epvector *evaled_seqp = evalchildren(level);
324         if (evaled_seqp) {
325                 // do more evaluation later
326                 return (new add(evaled_seqp, overall_coeff))->
327                        setflag(status_flags::dynallocated);
328         }
329         
330 #ifdef DO_GINAC_ASSERT
331         epvector::const_iterator i = seq.begin(), end = seq.end();
332         while (i != end) {
333                 GINAC_ASSERT(!is_exactly_a<add>(i->rest));
334                 if (is_ex_exactly_of_type(i->rest,numeric))
335                         dbgprint();
336                 GINAC_ASSERT(!is_exactly_a<numeric>(i->rest));
337                 ++i;
338         }
339 #endif // def DO_GINAC_ASSERT
340         
341         if (flags & status_flags::evaluated) {
342                 GINAC_ASSERT(seq.size()>0);
343                 GINAC_ASSERT(seq.size()>1 || !overall_coeff.is_zero());
344                 return *this;
345         }
346         
347         int seq_size = seq.size();
348         if (seq_size == 0) {
349                 // +(;c) -> c
350                 return overall_coeff;
351         } else if (seq_size == 1 && overall_coeff.is_zero()) {
352                 // +(x;0) -> x
353                 return recombine_pair_to_ex(*(seq.begin()));
354         } else if (!overall_coeff.is_zero() && seq[0].rest.return_type() != return_types::commutative) {
355                 throw (std::logic_error("add::eval(): sum of non-commutative objects has non-zero numeric term"));
356         }
357         return this->hold();
358 }
359
360 ex add::evalm(void) const
361 {
362         // Evaluate children first and add up all matrices. Stop if there's one
363         // term that is not a matrix.
364         epvector *s = new epvector;
365         s->reserve(seq.size());
366
367         bool all_matrices = true;
368         bool first_term = true;
369         matrix sum;
370
371         epvector::const_iterator it = seq.begin(), itend = seq.end();
372         while (it != itend) {
373                 const ex &m = recombine_pair_to_ex(*it).evalm();
374                 s->push_back(split_ex_to_pair(m));
375                 if (is_ex_of_type(m, matrix)) {
376                         if (first_term) {
377                                 sum = ex_to<matrix>(m);
378                                 first_term = false;
379                         } else
380                                 sum = sum.add(ex_to<matrix>(m));
381                 } else
382                         all_matrices = false;
383                 ++it;
384         }
385
386         if (all_matrices) {
387                 delete s;
388                 return sum + overall_coeff;
389         } else
390                 return (new add(s, overall_coeff))->setflag(status_flags::dynallocated);
391 }
392
393 ex add::simplify_ncmul(const exvector & v) const
394 {
395         if (seq.empty())
396                 return inherited::simplify_ncmul(v);
397         else
398                 return seq.begin()->rest.simplify_ncmul(v);
399 }    
400
401 // protected
402
403 /** Implementation of ex::diff() for a sum. It differentiates each term.
404  *  @see ex::diff */
405 ex add::derivative(const symbol & y) const
406 {
407         epvector *s = new epvector();
408         s->reserve(seq.size());
409         
410         // Only differentiate the "rest" parts of the expairs. This is faster
411         // than the default implementation in basic::derivative() although
412         // if performs the same function (differentiate each term).
413         epvector::const_iterator i = seq.begin(), end = seq.end();
414         while (i != end) {
415                 s->push_back(combine_ex_with_coeff_to_pair(i->rest.diff(y), i->coeff));
416                 ++i;
417         }
418         return (new add(s, _ex0))->setflag(status_flags::dynallocated);
419 }
420
421 int add::compare_same_type(const basic & other) const
422 {
423         return inherited::compare_same_type(other);
424 }
425
426 bool add::is_equal_same_type(const basic & other) const
427 {
428         return inherited::is_equal_same_type(other);
429 }
430
431 unsigned add::return_type(void) const
432 {
433         if (seq.empty())
434                 return return_types::commutative;
435         else
436                 return seq.begin()->rest.return_type();
437 }
438    
439 unsigned add::return_type_tinfo(void) const
440 {
441         if (seq.empty())
442                 return tinfo_key;
443         else
444                 return seq.begin()->rest.return_type_tinfo();
445 }
446
447 ex add::thisexpairseq(const epvector & v, const ex & oc) const
448 {
449         return (new add(v,oc))->setflag(status_flags::dynallocated);
450 }
451
452 ex add::thisexpairseq(epvector * vp, const ex & oc) const
453 {
454         return (new add(vp,oc))->setflag(status_flags::dynallocated);
455 }
456
457 expair add::split_ex_to_pair(const ex & e) const
458 {
459         if (is_ex_exactly_of_type(e,mul)) {
460                 const mul &mulref(ex_to<mul>(e));
461                 const ex &numfactor = mulref.overall_coeff;
462                 mul *mulcopyp = new mul(mulref);
463                 mulcopyp->overall_coeff = _ex1;
464                 mulcopyp->clearflag(status_flags::evaluated);
465                 mulcopyp->clearflag(status_flags::hash_calculated);
466                 mulcopyp->setflag(status_flags::dynallocated);
467                 return expair(*mulcopyp,numfactor);
468         }
469         return expair(e,_ex1);
470 }
471
472 expair add::combine_ex_with_coeff_to_pair(const ex & e,
473                                                                                   const ex & c) const
474 {
475         GINAC_ASSERT(is_exactly_a<numeric>(c));
476         if (is_ex_exactly_of_type(e, mul)) {
477                 const mul &mulref(ex_to<mul>(e));
478                 const ex &numfactor = mulref.overall_coeff;
479                 mul *mulcopyp = new mul(mulref);
480                 mulcopyp->overall_coeff = _ex1;
481                 mulcopyp->clearflag(status_flags::evaluated);
482                 mulcopyp->clearflag(status_flags::hash_calculated);
483                 mulcopyp->setflag(status_flags::dynallocated);
484                 if (are_ex_trivially_equal(c, _ex1))
485                         return expair(*mulcopyp, numfactor);
486                 else if (are_ex_trivially_equal(numfactor, _ex1))
487                         return expair(*mulcopyp, c);
488                 else
489                         return expair(*mulcopyp, ex_to<numeric>(numfactor).mul_dyn(ex_to<numeric>(c)));
490         } else if (is_ex_exactly_of_type(e, numeric)) {
491                 if (are_ex_trivially_equal(c, _ex1))
492                         return expair(e, _ex1);
493                 return expair(ex_to<numeric>(e).mul_dyn(ex_to<numeric>(c)), _ex1);
494         }
495         return expair(e, c);
496 }
497
498 expair add::combine_pair_with_coeff_to_pair(const expair & p,
499                                                                                         const ex & c) const
500 {
501         GINAC_ASSERT(is_exactly_a<numeric>(p.coeff));
502         GINAC_ASSERT(is_exactly_a<numeric>(c));
503
504         if (is_ex_exactly_of_type(p.rest,numeric)) {
505                 GINAC_ASSERT(ex_to<numeric>(p.coeff).is_equal(_num1)); // should be normalized
506                 return expair(ex_to<numeric>(p.rest).mul_dyn(ex_to<numeric>(c)),_ex1);
507         }
508
509         return expair(p.rest,ex_to<numeric>(p.coeff).mul_dyn(ex_to<numeric>(c)));
510 }
511         
512 ex add::recombine_pair_to_ex(const expair & p) const
513 {
514         if (ex_to<numeric>(p.coeff).is_equal(_num1))
515                 return p.rest;
516         else
517                 return (new mul(p.rest,p.coeff))->setflag(status_flags::dynallocated);
518 }
519
520 ex add::expand(unsigned options) const
521 {
522         epvector *vp = expandchildren(options);
523         if (vp == NULL) {
524                 // the terms have not changed, so it is safe to declare this expanded
525                 return (options == 0) ? setflag(status_flags::expanded) : *this;
526         }
527
528         return (new add(vp, overall_coeff))->setflag(status_flags::dynallocated | (options == 0 ? status_flags::expanded : 0));
529 }
530
531 } // namespace GiNaC