www.ginac.de Git - ginac.git/blob - ginac/add.cpp
series expansion behaviour fixed.
[ginac.git] / ginac / add.cpp
1 /** @file add.cpp
2  *
3  *  Implementation of GiNaC's sums of expressions. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2003 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
21  */
22
23 #include <iostream>
24 #include <stdexcept>
25
26 #include "add.h"
27 #include "mul.h"
28 #include "archive.h"
29 #include "operators.h"
30 #include "matrix.h"
31 #include "utils.h"
32
33 namespace GiNaC {
34
35 GINAC_IMPLEMENT_REGISTERED_CLASS(add, expairseq)
36
37 //////////
38 // default ctor, dtor, copy ctor, assignment operator and helpers
39 //////////
40
41 add::add()
42 {
43         tinfo_key = TINFO_add;
44 }
45
46 DEFAULT_COPY(add)
47 DEFAULT_DESTROY(add)
48
49 //////////
50 // other constructors
51 //////////
52
53 // public
54
55 add::add(const ex & lh, const ex & rh)
56 {
57         tinfo_key = TINFO_add;
58         overall_coeff = _ex0;
59         construct_from_2_ex(lh,rh);
60         GINAC_ASSERT(is_canonical());
61 }
62
63 add::add(const exvector & v)
64 {
65         tinfo_key = TINFO_add;
66         overall_coeff = _ex0;
67         construct_from_exvector(v);
68         GINAC_ASSERT(is_canonical());
69 }
70
71 add::add(const epvector & v)
72 {
73         tinfo_key = TINFO_add;
74         overall_coeff = _ex0;
75         construct_from_epvector(v);
76         GINAC_ASSERT(is_canonical());
77 }
78
79 add::add(const epvector & v, const ex & oc)
80 {
81         tinfo_key = TINFO_add;
82         overall_coeff = oc;
83         construct_from_epvector(v);
84         GINAC_ASSERT(is_canonical());
85 }
86
87 add::add(epvector * vp, const ex & oc)
88 {
89         tinfo_key = TINFO_add;
90         GINAC_ASSERT(vp!=0);
91         overall_coeff = oc;
92         construct_from_epvector(*vp);
93         delete vp;
94         GINAC_ASSERT(is_canonical());
95 }
96
97 //////////
98 // archiving
99 //////////
100
101 DEFAULT_ARCHIVING(add)
102
103 //////////
104 // functions overriding virtual functions from base classes
105 //////////
106
107 // public
108
109 void add::print(const print_context & c, unsigned level) const
110 {
111         if (is_a<print_tree>(c)) {
112
113                 inherited::print(c, level);
114
115         } else if (is_a<print_csrc>(c)) {
116
117                 if (precedence() <= level)
118                         c.s << "(";
119         
120                 // Print arguments, separated by "+"
121                 epvector::const_iterator it = seq.begin(), itend = seq.end();
122                 while (it != itend) {
123                 
124                         // If the coefficient is -1, it is replaced by a single minus sign
125                         if (it->coeff.is_equal(_ex1)) {
126                                 it->rest.print(c, precedence());
127                         } else if (it->coeff.is_equal(_ex_1)) {
128                                 c.s << "-";
129                                 it->rest.print(c, precedence());
130                         } else if (ex_to<numeric>(it->coeff).numer().is_equal(_num1)) {
131                                 it->rest.print(c, precedence());
132                                 c.s << "/";
133                                 ex_to<numeric>(it->coeff).denom().print(c, precedence());
134                         } else if (ex_to<numeric>(it->coeff).numer().is_equal(_num_1)) {
135                                 c.s << "-";
136                                 it->rest.print(c, precedence());
137                                 c.s << "/";
138                                 ex_to<numeric>(it->coeff).denom().print(c, precedence());
139                         } else {
140                                 it->coeff.print(c, precedence());
141                                 c.s << "*";
142                                 it->rest.print(c, precedence());
143                         }
144                 
145                         // Separator is "+", except if the following expression would have a leading minus sign or the sign is sitting in parenthesis (as in a ctor)
146                         ++it;
147                         if (it != itend
148                          && (is_a<print_csrc_cl_N>(c) || !it->coeff.info(info_flags::real)  // sign inside ctor arguments
149                           || !(it->coeff.info(info_flags::negative) || (it->coeff.is_equal(_num1) && is_exactly_a<numeric>(it->rest) && it->rest.info(info_flags::negative)))))
150                                 c.s << "+";
151                 }
152         
153                 if (!overall_coeff.is_zero()) {
154                         if (overall_coeff.info(info_flags::positive)
155                          || is_a<print_csrc_cl_N>(c) || !overall_coeff.info(info_flags::real))  // sign inside ctor argument
156                                 c.s << '+';
157                         overall_coeff.print(c, precedence());
158                 }
159                 
160                 if (precedence() <= level)
161                         c.s << ")";
162
163         } else if (is_a<print_python_repr>(c)) {
164
165                 c.s << class_name() << '(';
166                 op(0).print(c);
167                 for (size_t i=1; i<nops(); ++i) {
168                         c.s << ',';
169                         op(i).print(c);
170                 }
171                 c.s << ')';
172
173         } else {
174
175                 if (precedence() <= level) {
176                         if (is_a<print_latex>(c))
177                                 c.s << "{(";
178                         else
179                                 c.s << "(";
180                 }
181
182                 numeric coeff;
183                 bool first = true;
184
185                 // First print the overall numeric coefficient, if present
186                 if (!overall_coeff.is_zero()) {
187                         if (!is_a<print_tree>(c))
188                                 overall_coeff.print(c, 0);
189                         else
190                                 overall_coeff.print(c, precedence());
191                         first = false;
192                 }
193
194                 // Then proceed with the remaining factors
195                 epvector::const_iterator it = seq.begin(), itend = seq.end();
196                 while (it != itend) {
197                         coeff = ex_to<numeric>(it->coeff);
198                         if (!first) {
199                                 if (coeff.csgn() == -1) c.s << '-'; else c.s << '+';
200                         } else {
201                                 if (coeff.csgn() == -1) c.s << '-';
202                                 first = false;
203                         }
204                         if (!coeff.is_equal(_num1) &&
205                             !coeff.is_equal(_num_1)) {
206                                 if (coeff.is_rational()) {
207                                         if (coeff.is_negative())
208                                                 (-coeff).print(c);
209                                         else
210                                                 coeff.print(c);
211                                 } else {
212                                         if (coeff.csgn() == -1)
213                                                 (-coeff).print(c, precedence());
214                                         else
215                                                 coeff.print(c, precedence());
216                                 }
217                                 if (is_a<print_latex>(c))
218                                         c.s << ' ';
219                                 else
220                                         c.s << '*';
221                         }
222                         it->rest.print(c, precedence());
223                         ++it;
224                 }
225
226                 if (precedence() <= level) {
227                         if (is_a<print_latex>(c))
228                                 c.s << ")}";
229                         else
230                                 c.s << ")";
231                 }
232         }
233 }
234
235 bool add::info(unsigned inf) const
236 {
237         switch (inf) {
238                 case info_flags::polynomial:
239                 case info_flags::integer_polynomial:
240                 case info_flags::cinteger_polynomial:
241                 case info_flags::rational_polynomial:
242                 case info_flags::crational_polynomial:
243                 case info_flags::rational_function: {
244                         epvector::const_iterator i = seq.begin(), end = seq.end();
245                         while (i != end) {
246                                 if (!(recombine_pair_to_ex(*i).info(inf)))
247                                         return false;
248                                 ++i;
249                         }
250                         return overall_coeff.info(inf);
251                 }
252                 case info_flags::algebraic: {
253                         epvector::const_iterator i = seq.begin(), end = seq.end();
254                         while (i != end) {
255                                 if ((recombine_pair_to_ex(*i).info(inf)))
256                                         return true;
257                                 ++i;
258                         }
259                         return false;
260                 }
261         }
262         return inherited::info(inf);
263 }
264
265 int add::degree(const ex & s) const
266 {
267         int deg = INT_MIN;
268         if (!overall_coeff.is_zero())
269                 deg = 0;
270         
271         // Find maximum of degrees of individual terms
272         epvector::const_iterator i = seq.begin(), end = seq.end();
273         while (i != end) {
274                 int cur_deg = i->rest.degree(s);
275                 if (cur_deg > deg)
276                         deg = cur_deg;
277                 ++i;
278         }
279         return deg;
280 }
281
282 int add::ldegree(const ex & s) const
283 {
284         int deg = INT_MAX;
285         if (!overall_coeff.is_zero())
286                 deg = 0;
287         
288         // Find minimum of degrees of individual terms
289         epvector::const_iterator i = seq.begin(), end = seq.end();
290         while (i != end) {
291                 int cur_deg = i->rest.ldegree(s);
292                 if (cur_deg < deg)
293                         deg = cur_deg;
294                 ++i;
295         }
296         return deg;
297 }
298
299 ex add::coeff(const ex & s, int n) const
300 {
301         epvector *coeffseq = new epvector();
302
303         // Calculate sum of coefficients in each term
304         epvector::const_iterator i = seq.begin(), end = seq.end();
305         while (i != end) {
306                 ex restcoeff = i->rest.coeff(s, n);
307                 if (!restcoeff.is_zero())
308                         coeffseq->push_back(combine_ex_with_coeff_to_pair(restcoeff, i->coeff));
309                 ++i;
310         }
311
312         return (new add(coeffseq, n==0 ? overall_coeff : _ex0))->setflag(status_flags::dynallocated);
313 }
314
315 /** Perform automatic term rewriting rules in this class.  In the following
316  *  x stands for a symbolic variables of type ex and c stands for such
317  *  an expression that contain a plain number.
318  *  - +(;c) -> c
319  *  - +(x;0) -> x
320  *
321  *  @param level cut-off in recursive evaluation */
322 ex add::eval(int level) const
323 {
324         epvector *evaled_seqp = evalchildren(level);
325         if (evaled_seqp) {
326                 // do more evaluation later
327                 return (new add(evaled_seqp, overall_coeff))->
328                        setflag(status_flags::dynallocated);
329         }
330         
331 #ifdef DO_GINAC_ASSERT
332         epvector::const_iterator i = seq.begin(), end = seq.end();
333         while (i != end) {
334                 GINAC_ASSERT(!is_exactly_a<add>(i->rest));
335                 if (is_exactly_a<numeric>(i->rest))
336                         dbgprint();
337                 GINAC_ASSERT(!is_exactly_a<numeric>(i->rest));
338                 ++i;
339         }
340 #endif // def DO_GINAC_ASSERT
341         
342         if (flags & status_flags::evaluated) {
343                 GINAC_ASSERT(seq.size()>0);
344                 GINAC_ASSERT(seq.size()>1 || !overall_coeff.is_zero());
345                 return *this;
346         }
347         
348         int seq_size = seq.size();
349         if (seq_size == 0) {
350                 // +(;c) -> c
351                 return overall_coeff;
352         } else if (seq_size == 1 && overall_coeff.is_zero()) {
353                 // +(x;0) -> x
354                 return recombine_pair_to_ex(*(seq.begin()));
355         } else if (!overall_coeff.is_zero() && seq[0].rest.return_type() != return_types::commutative) {
356                 throw (std::logic_error("add::eval(): sum of non-commutative objects has non-zero numeric term"));
357         }
358         return this->hold();
359 }
360
361 ex add::evalm(void) const
362 {
363         // Evaluate children first and add up all matrices. Stop if there's one
364         // term that is not a matrix.
365         epvector *s = new epvector;
366         s->reserve(seq.size());
367
368         bool all_matrices = true;
369         bool first_term = true;
370         matrix sum;
371
372         epvector::const_iterator it = seq.begin(), itend = seq.end();
373         while (it != itend) {
374                 const ex &m = recombine_pair_to_ex(*it).evalm();
375                 s->push_back(split_ex_to_pair(m));
376                 if (is_a<matrix>(m)) {
377                         if (first_term) {
378                                 sum = ex_to<matrix>(m);
379                                 first_term = false;
380                         } else
381                                 sum = sum.add(ex_to<matrix>(m));
382                 } else
383                         all_matrices = false;
384                 ++it;
385         }
386
387         if (all_matrices) {
388                 delete s;
389                 return sum + overall_coeff;
390         } else
391                 return (new add(s, overall_coeff))->setflag(status_flags::dynallocated);
392 }
393
394 ex add::eval_ncmul(const exvector & v) const
395 {
396         if (seq.empty())
397                 return inherited::eval_ncmul(v);
398         else
399                 return seq.begin()->rest.eval_ncmul(v);
400 }    
401
402 // protected
403
404 /** Implementation of ex::diff() for a sum. It differentiates each term.
405  *  @see ex::diff */
406 ex add::derivative(const symbol & y) const
407 {
408         epvector *s = new epvector();
409         s->reserve(seq.size());
410         
411         // Only differentiate the "rest" parts of the expairs. This is faster
412         // than the default implementation in basic::derivative() although
413         // if performs the same function (differentiate each term).
414         epvector::const_iterator i = seq.begin(), end = seq.end();
415         while (i != end) {
416                 s->push_back(combine_ex_with_coeff_to_pair(i->rest.diff(y), i->coeff));
417                 ++i;
418         }
419         return (new add(s, _ex0))->setflag(status_flags::dynallocated);
420 }
421
422 int add::compare_same_type(const basic & other) const
423 {
424         return inherited::compare_same_type(other);
425 }
426
427 unsigned add::return_type(void) const
428 {
429         if (seq.empty())
430                 return return_types::commutative;
431         else
432                 return seq.begin()->rest.return_type();
433 }
434    
435 unsigned add::return_type_tinfo(void) const
436 {
437         if (seq.empty())
438                 return tinfo_key;
439         else
440                 return seq.begin()->rest.return_type_tinfo();
441 }
442
443 ex add::thisexpairseq(const epvector & v, const ex & oc) const
444 {
445         return (new add(v,oc))->setflag(status_flags::dynallocated);
446 }
447
448 ex add::thisexpairseq(epvector * vp, const ex & oc) const
449 {
450         return (new add(vp,oc))->setflag(status_flags::dynallocated);
451 }
452
453 expair add::split_ex_to_pair(const ex & e) const
454 {
455         if (is_exactly_a<mul>(e)) {
456                 const mul &mulref(ex_to<mul>(e));
457                 const ex &numfactor = mulref.overall_coeff;
458                 mul *mulcopyp = new mul(mulref);
459                 mulcopyp->overall_coeff = _ex1;
460                 mulcopyp->clearflag(status_flags::evaluated);
461                 mulcopyp->clearflag(status_flags::hash_calculated);
462                 mulcopyp->setflag(status_flags::dynallocated);
463                 return expair(*mulcopyp,numfactor);
464         }
465         return expair(e,_ex1);
466 }
467
468 expair add::combine_ex_with_coeff_to_pair(const ex & e,
469                                                                                   const ex & c) const
470 {
471         GINAC_ASSERT(is_exactly_a<numeric>(c));
472         if (is_exactly_a<mul>(e)) {
473                 const mul &mulref(ex_to<mul>(e));
474                 const ex &numfactor = mulref.overall_coeff;
475                 mul *mulcopyp = new mul(mulref);
476                 mulcopyp->overall_coeff = _ex1;
477                 mulcopyp->clearflag(status_flags::evaluated);
478                 mulcopyp->clearflag(status_flags::hash_calculated);
479                 mulcopyp->setflag(status_flags::dynallocated);
480                 if (c.is_equal(_ex1))
481                         return expair(*mulcopyp, numfactor);
482                 else if (numfactor.is_equal(_ex1))
483                         return expair(*mulcopyp, c);
484                 else
485                         return expair(*mulcopyp, ex_to<numeric>(numfactor).mul_dyn(ex_to<numeric>(c)));
486         } else if (is_exactly_a<numeric>(e)) {
487                 if (c.is_equal(_ex1))
488                         return expair(e, _ex1);
489                 return expair(ex_to<numeric>(e).mul_dyn(ex_to<numeric>(c)), _ex1);
490         }
491         return expair(e, c);
492 }
493
494 expair add::combine_pair_with_coeff_to_pair(const expair & p,
495                                                                                         const ex & c) const
496 {
497         GINAC_ASSERT(is_exactly_a<numeric>(p.coeff));
498         GINAC_ASSERT(is_exactly_a<numeric>(c));
499
500         if (is_exactly_a<numeric>(p.rest)) {
501                 GINAC_ASSERT(ex_to<numeric>(p.coeff).is_equal(_num1)); // should be normalized
502                 return expair(ex_to<numeric>(p.rest).mul_dyn(ex_to<numeric>(c)),_ex1);
503         }
504
505         return expair(p.rest,ex_to<numeric>(p.coeff).mul_dyn(ex_to<numeric>(c)));
506 }
507         
508 ex add::recombine_pair_to_ex(const expair & p) const
509 {
510         if (ex_to<numeric>(p.coeff).is_equal(_num1))
511                 return p.rest;
512         else
513                 return (new mul(p.rest,p.coeff))->setflag(status_flags::dynallocated);
514 }
515
516 ex add::expand(unsigned options) const
517 {
518         epvector *vp = expandchildren(options);
519         if (vp == NULL) {
520                 // the terms have not changed, so it is safe to declare this expanded
521                 return (options == 0) ? setflag(status_flags::expanded) : *this;
522         }
523
524         return (new add(vp, overall_coeff))->setflag(status_flags::dynallocated | (options == 0 ? status_flags::expanded : 0));
525 }
526
527 } // namespace GiNaC