* Some internal reorganization WRT flyweight handling and initialization,
[ginac.git] / ginac / add.cpp
1 /** @file add.cpp
2  *
3  *  Implementation of GiNaC's sums of expressions. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2001 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
21  */
22
23 #include <iostream>
24 #include <stdexcept>
25
26 #include "add.h"
27 #include "mul.h"
28 #include "matrix.h"
29 #include "archive.h"
30 #include "debugmsg.h"
31 #include "utils.h"
32
33 namespace GiNaC {
34
35 GINAC_IMPLEMENT_REGISTERED_CLASS(add, expairseq)
36
37 //////////
38 // default constructor, destructor, copy constructor assignment operator and helpers
39 //////////
40
41 add::add()
42 {
43         debugmsg("add default constructor",LOGLEVEL_CONSTRUCT);
44         tinfo_key = TINFO_add;
45 }
46
47 DEFAULT_COPY(add)
48 DEFAULT_DESTROY(add)
49
50 //////////
51 // other constructors
52 //////////
53
54 // public
55
56 add::add(const ex & lh, const ex & rh)
57 {
58         debugmsg("add constructor from ex,ex",LOGLEVEL_CONSTRUCT);
59         tinfo_key = TINFO_add;
60         overall_coeff = _ex0;
61         construct_from_2_ex(lh,rh);
62         GINAC_ASSERT(is_canonical());
63 }
64
65 add::add(const exvector & v)
66 {
67         debugmsg("add constructor from exvector",LOGLEVEL_CONSTRUCT);
68         tinfo_key = TINFO_add;
69         overall_coeff = _ex0;
70         construct_from_exvector(v);
71         GINAC_ASSERT(is_canonical());
72 }
73
74 add::add(const epvector & v)
75 {
76         debugmsg("add constructor from epvector",LOGLEVEL_CONSTRUCT);
77         tinfo_key = TINFO_add;
78         overall_coeff = _ex0;
79         construct_from_epvector(v);
80         GINAC_ASSERT(is_canonical());
81 }
82
83 add::add(const epvector & v, const ex & oc)
84 {
85         debugmsg("add constructor from epvector,ex",LOGLEVEL_CONSTRUCT);
86         tinfo_key = TINFO_add;
87         overall_coeff = oc;
88         construct_from_epvector(v);
89         GINAC_ASSERT(is_canonical());
90 }
91
92 add::add(epvector * vp, const ex & oc)
93 {
94         debugmsg("add constructor from epvector *,ex",LOGLEVEL_CONSTRUCT);
95         tinfo_key = TINFO_add;
96         GINAC_ASSERT(vp!=0);
97         overall_coeff = oc;
98         construct_from_epvector(*vp);
99         delete vp;
100         GINAC_ASSERT(is_canonical());
101 }
102
103 //////////
104 // archiving
105 //////////
106
107 DEFAULT_ARCHIVING(add)
108
109 //////////
110 // functions overriding virtual functions from base classes
111 //////////
112
113 // public
114
115 void add::print(const print_context & c, unsigned level) const
116 {
117         debugmsg("add print", LOGLEVEL_PRINT);
118
119         if (is_a<print_tree>(c)) {
120
121                 inherited::print(c, level);
122
123         } else if (is_a<print_csrc>(c)) {
124
125                 if (precedence() <= level)
126                         c.s << "(";
127         
128                 // Print arguments, separated by "+"
129                 epvector::const_iterator it = seq.begin(), itend = seq.end();
130                 while (it != itend) {
131                 
132                         // If the coefficient is -1, it is replaced by a single minus sign
133                         if (it->coeff.compare(_num1) == 0) {
134                                 it->rest.print(c, precedence());
135                         } else if (it->coeff.compare(_num_1) == 0) {
136                                 c.s << "-";
137                                 it->rest.print(c, precedence());
138                         } else if (ex_to<numeric>(it->coeff).numer().compare(_num1) == 0) {
139                                 it->rest.print(c, precedence());
140                                 c.s << "/";
141                                 ex_to<numeric>(it->coeff).denom().print(c, precedence());
142                         } else if (ex_to<numeric>(it->coeff).numer().compare(_num_1) == 0) {
143                                 c.s << "-";
144                                 it->rest.print(c, precedence());
145                                 c.s << "/";
146                                 ex_to<numeric>(it->coeff).denom().print(c, precedence());
147                         } else {
148                                 it->coeff.print(c, precedence());
149                                 c.s << "*";
150                                 it->rest.print(c, precedence());
151                         }
152                 
153                         // Separator is "+", except if the following expression would have a leading minus sign
154                         ++it;
155                         if (it != itend && !(it->coeff.compare(_num0) < 0 || (it->coeff.compare(_num1) == 0 && is_exactly_a<numeric>(it->rest) && it->rest.compare(_num0) < 0)))
156                                 c.s << "+";
157                 }
158         
159                 if (!overall_coeff.is_zero()) {
160                         if (overall_coeff.info(info_flags::positive))
161                                 c.s << '+';
162                         overall_coeff.print(c, precedence());
163                 }
164         
165                 if (precedence() <= level)
166                         c.s << ")";
167
168         } else {
169
170                 if (precedence() <= level) {
171                         if (is_a<print_latex>(c))
172                                 c.s << "{(";
173                         else
174                                 c.s << "(";
175                 }
176
177                 numeric coeff;
178                 bool first = true;
179
180                 // First print the overall numeric coefficient, if present
181                 if (!overall_coeff.is_zero()) {
182                         if (!is_a<print_tree>(c))
183                                 overall_coeff.print(c, 0);
184                         else
185                                 overall_coeff.print(c, precedence());
186                         first = false;
187                 }
188
189                 // Then proceed with the remaining factors
190                 epvector::const_iterator it = seq.begin(), itend = seq.end();
191                 while (it != itend) {
192                         coeff = ex_to<numeric>(it->coeff);
193                         if (!first) {
194                                 if (coeff.csgn() == -1) c.s << '-'; else c.s << '+';
195                         } else {
196                                 if (coeff.csgn() == -1) c.s << '-';
197                                 first = false;
198                         }
199                         if (!coeff.is_equal(_num1) &&
200                             !coeff.is_equal(_num_1)) {
201                                 if (coeff.is_rational()) {
202                                         if (coeff.is_negative())
203                                                 (-coeff).print(c);
204                                         else
205                                                 coeff.print(c);
206                                 } else {
207                                         if (coeff.csgn() == -1)
208                                                 (-coeff).print(c, precedence());
209                                         else
210                                                 coeff.print(c, precedence());
211                                 }
212                                 if (is_a<print_latex>(c))
213                                         c.s << ' ';
214                                 else
215                                         c.s << '*';
216                         }
217                         it->rest.print(c, precedence());
218                         ++it;
219                 }
220
221                 if (precedence() <= level) {
222                         if (is_a<print_latex>(c))
223                                 c.s << ")}";
224                         else
225                                 c.s << ")";
226                 }
227         }
228 }
229
230 bool add::info(unsigned inf) const
231 {
232         switch (inf) {
233                 case info_flags::polynomial:
234                 case info_flags::integer_polynomial:
235                 case info_flags::cinteger_polynomial:
236                 case info_flags::rational_polynomial:
237                 case info_flags::crational_polynomial:
238                 case info_flags::rational_function: {
239                         epvector::const_iterator i = seq.begin(), end = seq.end();
240                         while (i != end) {
241                                 if (!(recombine_pair_to_ex(*i).info(inf)))
242                                         return false;
243                                 ++i;
244                         }
245                         return overall_coeff.info(inf);
246                 }
247                 case info_flags::algebraic: {
248                         epvector::const_iterator i = seq.begin(), end = seq.end();
249                         while (i != end) {
250                                 if ((recombine_pair_to_ex(*i).info(inf)))
251                                         return true;
252                                 ++i;
253                         }
254                         return false;
255                 }
256         }
257         return inherited::info(inf);
258 }
259
260 int add::degree(const ex & s) const
261 {
262         int deg = INT_MIN;
263         if (!overall_coeff.is_zero())
264                 deg = 0;
265         
266         // Find maximum of degrees of individual terms
267         epvector::const_iterator i = seq.begin(), end = seq.end();
268         while (i != end) {
269                 int cur_deg = i->rest.degree(s);
270                 if (cur_deg > deg)
271                         deg = cur_deg;
272                 ++i;
273         }
274         return deg;
275 }
276
277 int add::ldegree(const ex & s) const
278 {
279         int deg = INT_MAX;
280         if (!overall_coeff.is_zero())
281                 deg = 0;
282         
283         // Find minimum of degrees of individual terms
284         epvector::const_iterator i = seq.begin(), end = seq.end();
285         while (i != end) {
286                 int cur_deg = i->rest.ldegree(s);
287                 if (cur_deg < deg)
288                         deg = cur_deg;
289                 ++i;
290         }
291         return deg;
292 }
293
294 ex add::coeff(const ex & s, int n) const
295 {
296         epvector *coeffseq = new epvector();
297
298         // Calculate sum of coefficients in each term
299         epvector::const_iterator i = seq.begin(), end = seq.end();
300         while (i != end) {
301                 ex restcoeff = i->rest.coeff(s, n);
302                 if (!restcoeff.is_zero())
303                         coeffseq->push_back(combine_ex_with_coeff_to_pair(restcoeff, i->coeff));
304                 ++i;
305         }
306
307         return (new add(coeffseq, n==0 ? overall_coeff : _ex0))->setflag(status_flags::dynallocated);
308 }
309
310 /** Perform automatic term rewriting rules in this class.  In the following
311  *  x stands for a symbolic variables of type ex and c stands for such
312  *  an expression that contain a plain number.
313  *  - +(;c) -> c
314  *  - +(x;1) -> x
315  *
316  *  @param level cut-off in recursive evaluation */
317 ex add::eval(int level) const
318 {
319         debugmsg("add eval",LOGLEVEL_MEMBER_FUNCTION);
320         
321         epvector *evaled_seqp = evalchildren(level);
322         if (evaled_seqp) {
323                 // do more evaluation later
324                 return (new add(evaled_seqp, overall_coeff))->
325                        setflag(status_flags::dynallocated);
326         }
327         
328 #ifdef DO_GINAC_ASSERT
329         epvector::const_iterator i = seq.begin(), end = seq.end();
330         while (i != end) {
331                 GINAC_ASSERT(!is_exactly_a<add>(i->rest));
332                 if (is_ex_exactly_of_type(i->rest,numeric))
333                         dbgprint();
334                 GINAC_ASSERT(!is_exactly_a<numeric>(i->rest));
335                 ++i;
336         }
337 #endif // def DO_GINAC_ASSERT
338         
339         if (flags & status_flags::evaluated) {
340                 GINAC_ASSERT(seq.size()>0);
341                 GINAC_ASSERT(seq.size()>1 || !overall_coeff.is_zero());
342                 return *this;
343         }
344         
345         int seq_size = seq.size();
346         if (seq_size == 0) {
347                 // +(;c) -> c
348                 return overall_coeff;
349         } else if (seq_size == 1 && overall_coeff.is_zero()) {
350                 // +(x;0) -> x
351                 return recombine_pair_to_ex(*(seq.begin()));
352         } else if (!overall_coeff.is_zero() && seq[0].rest.return_type() != return_types::commutative) {
353                 throw (std::logic_error("add::eval(): sum of non-commutative objects has non-zero numeric term"));
354         }
355         return this->hold();
356 }
357
358 ex add::evalm(void) const
359 {
360         // Evaluate children first and add up all matrices. Stop if there's one
361         // term that is not a matrix.
362         epvector *s = new epvector;
363         s->reserve(seq.size());
364
365         bool all_matrices = true;
366         bool first_term = true;
367         matrix sum;
368
369         epvector::const_iterator it = seq.begin(), itend = seq.end();
370         while (it != itend) {
371                 const ex &m = recombine_pair_to_ex(*it).evalm();
372                 s->push_back(split_ex_to_pair(m));
373                 if (is_ex_of_type(m, matrix)) {
374                         if (first_term) {
375                                 sum = ex_to<matrix>(m);
376                                 first_term = false;
377                         } else
378                                 sum = sum.add(ex_to<matrix>(m));
379                 } else
380                         all_matrices = false;
381                 it++;
382         }
383
384         if (all_matrices) {
385                 delete s;
386                 return sum + overall_coeff;
387         } else
388                 return (new add(s, overall_coeff))->setflag(status_flags::dynallocated);
389 }
390
391 ex add::simplify_ncmul(const exvector & v) const
392 {
393         if (seq.empty())
394                 return inherited::simplify_ncmul(v);
395         else
396                 return seq.begin()->rest.simplify_ncmul(v);
397 }    
398
399 // protected
400
401 /** Implementation of ex::diff() for a sum. It differentiates each term.
402  *  @see ex::diff */
403 ex add::derivative(const symbol & y) const
404 {
405         epvector *s = new epvector();
406         s->reserve(seq.size());
407         
408         // Only differentiate the "rest" parts of the expairs. This is faster
409         // than the default implementation in basic::derivative() although
410         // if performs the same function (differentiate each term).
411         epvector::const_iterator i = seq.begin(), end = seq.end();
412         while (i != end) {
413                 s->push_back(combine_ex_with_coeff_to_pair(i->rest.diff(y), i->coeff));
414                 ++i;
415         }
416         return (new add(s, _ex0))->setflag(status_flags::dynallocated);
417 }
418
419 int add::compare_same_type(const basic & other) const
420 {
421         return inherited::compare_same_type(other);
422 }
423
424 bool add::is_equal_same_type(const basic & other) const
425 {
426         return inherited::is_equal_same_type(other);
427 }
428
429 unsigned add::return_type(void) const
430 {
431         if (seq.empty())
432                 return return_types::commutative;
433         else
434                 return seq.begin()->rest.return_type();
435 }
436    
437 unsigned add::return_type_tinfo(void) const
438 {
439         if (seq.empty())
440                 return tinfo_key;
441         else
442                 return seq.begin()->rest.return_type_tinfo();
443 }
444
445 ex add::thisexpairseq(const epvector & v, const ex & oc) const
446 {
447         return (new add(v,oc))->setflag(status_flags::dynallocated);
448 }
449
450 ex add::thisexpairseq(epvector * vp, const ex & oc) const
451 {
452         return (new add(vp,oc))->setflag(status_flags::dynallocated);
453 }
454
455 expair add::split_ex_to_pair(const ex & e) const
456 {
457         if (is_ex_exactly_of_type(e,mul)) {
458                 const mul &mulref(ex_to<mul>(e));
459                 ex numfactor = mulref.overall_coeff;
460                 mul *mulcopyp = new mul(mulref);
461                 mulcopyp->overall_coeff = _ex1;
462                 mulcopyp->clearflag(status_flags::evaluated);
463                 mulcopyp->clearflag(status_flags::hash_calculated);
464                 mulcopyp->setflag(status_flags::dynallocated);
465                 return expair(*mulcopyp,numfactor);
466         }
467         return expair(e,_ex1);
468 }
469
470 expair add::combine_ex_with_coeff_to_pair(const ex & e,
471                                                                                   const ex & c) const
472 {
473         GINAC_ASSERT(is_exactly_a<numeric>(c));
474         if (is_ex_exactly_of_type(e, mul)) {
475                 const mul &mulref(ex_to<mul>(e));
476                 ex numfactor = mulref.overall_coeff;
477                 mul *mulcopyp = new mul(mulref);
478                 mulcopyp->overall_coeff = _ex1;
479                 mulcopyp->clearflag(status_flags::evaluated);
480                 mulcopyp->clearflag(status_flags::hash_calculated);
481                 mulcopyp->setflag(status_flags::dynallocated);
482                 if (are_ex_trivially_equal(c, _ex1))
483                         return expair(*mulcopyp, numfactor);
484                 else if (are_ex_trivially_equal(numfactor, _ex1))
485                         return expair(*mulcopyp, c);
486                 else
487                         return expair(*mulcopyp, ex_to<numeric>(numfactor).mul_dyn(ex_to<numeric>(c)));
488         } else if (is_ex_exactly_of_type(e, numeric)) {
489                 if (are_ex_trivially_equal(c, _ex1))
490                         return expair(e, _ex1);
491                 return expair(ex_to<numeric>(e).mul_dyn(ex_to<numeric>(c)), _ex1);
492         }
493         return expair(e, c);
494 }
495
496 expair add::combine_pair_with_coeff_to_pair(const expair & p,
497                                                                                         const ex & c) const
498 {
499         GINAC_ASSERT(is_exactly_a<numeric>(p.coeff));
500         GINAC_ASSERT(is_exactly_a<numeric>(c));
501
502         if (is_ex_exactly_of_type(p.rest,numeric)) {
503                 GINAC_ASSERT(ex_to<numeric>(p.coeff).is_equal(_num1)); // should be normalized
504                 return expair(ex_to<numeric>(p.rest).mul_dyn(ex_to<numeric>(c)),_ex1);
505         }
506
507         return expair(p.rest,ex_to<numeric>(p.coeff).mul_dyn(ex_to<numeric>(c)));
508 }
509         
510 ex add::recombine_pair_to_ex(const expair & p) const
511 {
512         if (ex_to<numeric>(p.coeff).is_equal(_num1))
513                 return p.rest;
514         else
515                 return p.rest*p.coeff;
516 }
517
518 ex add::expand(unsigned options) const
519 {
520         epvector *vp = expandchildren(options);
521         if (vp == NULL) {
522                 // the terms have not changed, so it is safe to declare this expanded
523                 return (options == 0) ? setflag(status_flags::expanded) : *this;
524         }
525         
526         return (new add(vp, overall_coeff))->setflag(status_flags::dynallocated | (options == 0 ? status_flags::expanded : 0));
527 }
528
529 } // namespace GiNaC