* Some minor optimization glitches.
[ginac.git] / ginac / add.cpp
1 /** @file add.cpp
2  *
3  *  Implementation of GiNaC's sums of expressions. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2001 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
21  */
22
23 #include <iostream>
24 #include <stdexcept>
25
26 #include "add.h"
27 #include "mul.h"
28 #include "matrix.h"
29 #include "archive.h"
30 #include "utils.h"
31
32 namespace GiNaC {
33
// Register class add, derived from expairseq, with GiNaC's class machinery
// (NOTE(review): macro expansion is defined in a project header — confirm
// exactly which members/registry entries it generates).
GINAC_IMPLEMENT_REGISTERED_CLASS(add, expairseq)
35
36 //////////
37 // default ctor, dtor, copy ctor, assignment operator and helpers
38 //////////
39
40 add::add()
41 {
42         tinfo_key = TINFO_add;
43 }
44
// Copy and destruction helpers for add (see the "default ctor, dtor, copy
// ctor, assignment operator and helpers" section above), generated by the
// DEFAULT_COPY / DEFAULT_DESTROY macros.
// NOTE(review): the macro expansions live in a project header — confirm.
DEFAULT_COPY(add)
DEFAULT_DESTROY(add)
47
48 //////////
49 // other constructors
50 //////////
51
52 // public
53
54 add::add(const ex & lh, const ex & rh)
55 {
56         tinfo_key = TINFO_add;
57         overall_coeff = _ex0;
58         construct_from_2_ex(lh,rh);
59         GINAC_ASSERT(is_canonical());
60 }
61
62 add::add(const exvector & v)
63 {
64         tinfo_key = TINFO_add;
65         overall_coeff = _ex0;
66         construct_from_exvector(v);
67         GINAC_ASSERT(is_canonical());
68 }
69
70 add::add(const epvector & v)
71 {
72         tinfo_key = TINFO_add;
73         overall_coeff = _ex0;
74         construct_from_epvector(v);
75         GINAC_ASSERT(is_canonical());
76 }
77
78 add::add(const epvector & v, const ex & oc)
79 {
80         tinfo_key = TINFO_add;
81         overall_coeff = oc;
82         construct_from_epvector(v);
83         GINAC_ASSERT(is_canonical());
84 }
85
86 add::add(epvector * vp, const ex & oc)
87 {
88         tinfo_key = TINFO_add;
89         GINAC_ASSERT(vp!=0);
90         overall_coeff = oc;
91         construct_from_epvector(*vp);
92         delete vp;
93         GINAC_ASSERT(is_canonical());
94 }
95
96 //////////
97 // archiving
98 //////////
99
// Archiving support for add, generated by the DEFAULT_ARCHIVING macro.
// NOTE(review): presumably provides the archive constructor and the
// archive/unarchive members — confirm against the macro definition.
DEFAULT_ARCHIVING(add)
101
102 //////////
103 // functions overriding virtual functions from base classes
104 //////////
105
106 // public
107
108 void add::print(const print_context & c, unsigned level) const
109 {
110         if (is_a<print_tree>(c)) {
111
112                 inherited::print(c, level);
113
114         } else if (is_a<print_csrc>(c)) {
115
116                 if (precedence() <= level)
117                         c.s << "(";
118         
119                 // Print arguments, separated by "+"
120                 epvector::const_iterator it = seq.begin(), itend = seq.end();
121                 while (it != itend) {
122                 
123                         // If the coefficient is -1, it is replaced by a single minus sign
124                         if (it->coeff.compare(_num1) == 0) {
125                                 it->rest.print(c, precedence());
126                         } else if (it->coeff.compare(_num_1) == 0) {
127                                 c.s << "-";
128                                 it->rest.print(c, precedence());
129                         } else if (ex_to<numeric>(it->coeff).numer().compare(_num1) == 0) {
130                                 it->rest.print(c, precedence());
131                                 c.s << "/";
132                                 ex_to<numeric>(it->coeff).denom().print(c, precedence());
133                         } else if (ex_to<numeric>(it->coeff).numer().compare(_num_1) == 0) {
134                                 c.s << "-";
135                                 it->rest.print(c, precedence());
136                                 c.s << "/";
137                                 ex_to<numeric>(it->coeff).denom().print(c, precedence());
138                         } else {
139                                 it->coeff.print(c, precedence());
140                                 c.s << "*";
141                                 it->rest.print(c, precedence());
142                         }
143                 
144                         // Separator is "+", except if the following expression would have a leading minus sign
145                         ++it;
146                         if (it != itend && !(it->coeff.compare(_num0) < 0 || (it->coeff.compare(_num1) == 0 && is_exactly_a<numeric>(it->rest) && it->rest.compare(_num0) < 0)))
147                                 c.s << "+";
148                 }
149         
150                 if (!overall_coeff.is_zero()) {
151                         if (overall_coeff.info(info_flags::positive))
152                                 c.s << '+';
153                         overall_coeff.print(c, precedence());
154                 }
155         
156                 if (precedence() <= level)
157                         c.s << ")";
158
159         } else {
160
161                 if (precedence() <= level) {
162                         if (is_a<print_latex>(c))
163                                 c.s << "{(";
164                         else
165                                 c.s << "(";
166                 }
167
168                 numeric coeff;
169                 bool first = true;
170
171                 // First print the overall numeric coefficient, if present
172                 if (!overall_coeff.is_zero()) {
173                         if (!is_a<print_tree>(c))
174                                 overall_coeff.print(c, 0);
175                         else
176                                 overall_coeff.print(c, precedence());
177                         first = false;
178                 }
179
180                 // Then proceed with the remaining factors
181                 epvector::const_iterator it = seq.begin(), itend = seq.end();
182                 while (it != itend) {
183                         coeff = ex_to<numeric>(it->coeff);
184                         if (!first) {
185                                 if (coeff.csgn() == -1) c.s << '-'; else c.s << '+';
186                         } else {
187                                 if (coeff.csgn() == -1) c.s << '-';
188                                 first = false;
189                         }
190                         if (!coeff.is_equal(_num1) &&
191                             !coeff.is_equal(_num_1)) {
192                                 if (coeff.is_rational()) {
193                                         if (coeff.is_negative())
194                                                 (-coeff).print(c);
195                                         else
196                                                 coeff.print(c);
197                                 } else {
198                                         if (coeff.csgn() == -1)
199                                                 (-coeff).print(c, precedence());
200                                         else
201                                                 coeff.print(c, precedence());
202                                 }
203                                 if (is_a<print_latex>(c))
204                                         c.s << ' ';
205                                 else
206                                         c.s << '*';
207                         }
208                         it->rest.print(c, precedence());
209                         ++it;
210                 }
211
212                 if (precedence() <= level) {
213                         if (is_a<print_latex>(c))
214                                 c.s << ")}";
215                         else
216                                 c.s << ")";
217                 }
218         }
219 }
220
221 bool add::info(unsigned inf) const
222 {
223         switch (inf) {
224                 case info_flags::polynomial:
225                 case info_flags::integer_polynomial:
226                 case info_flags::cinteger_polynomial:
227                 case info_flags::rational_polynomial:
228                 case info_flags::crational_polynomial:
229                 case info_flags::rational_function: {
230                         epvector::const_iterator i = seq.begin(), end = seq.end();
231                         while (i != end) {
232                                 if (!(recombine_pair_to_ex(*i).info(inf)))
233                                         return false;
234                                 ++i;
235                         }
236                         return overall_coeff.info(inf);
237                 }
238                 case info_flags::algebraic: {
239                         epvector::const_iterator i = seq.begin(), end = seq.end();
240                         while (i != end) {
241                                 if ((recombine_pair_to_ex(*i).info(inf)))
242                                         return true;
243                                 ++i;
244                         }
245                         return false;
246                 }
247         }
248         return inherited::info(inf);
249 }
250
251 int add::degree(const ex & s) const
252 {
253         int deg = INT_MIN;
254         if (!overall_coeff.is_zero())
255                 deg = 0;
256         
257         // Find maximum of degrees of individual terms
258         epvector::const_iterator i = seq.begin(), end = seq.end();
259         while (i != end) {
260                 int cur_deg = i->rest.degree(s);
261                 if (cur_deg > deg)
262                         deg = cur_deg;
263                 ++i;
264         }
265         return deg;
266 }
267
268 int add::ldegree(const ex & s) const
269 {
270         int deg = INT_MAX;
271         if (!overall_coeff.is_zero())
272                 deg = 0;
273         
274         // Find minimum of degrees of individual terms
275         epvector::const_iterator i = seq.begin(), end = seq.end();
276         while (i != end) {
277                 int cur_deg = i->rest.ldegree(s);
278                 if (cur_deg < deg)
279                         deg = cur_deg;
280                 ++i;
281         }
282         return deg;
283 }
284
285 ex add::coeff(const ex & s, int n) const
286 {
287         epvector *coeffseq = new epvector();
288
289         // Calculate sum of coefficients in each term
290         epvector::const_iterator i = seq.begin(), end = seq.end();
291         while (i != end) {
292                 ex restcoeff = i->rest.coeff(s, n);
293                 if (!restcoeff.is_zero())
294                         coeffseq->push_back(combine_ex_with_coeff_to_pair(restcoeff, i->coeff));
295                 ++i;
296         }
297
298         return (new add(coeffseq, n==0 ? overall_coeff : _ex0))->setflag(status_flags::dynallocated);
299 }
300
/** Perform automatic term rewriting rules in this class.  In the following
 *  x stands for a symbolic variable of type ex and c stands for such
 *  an expression that contains a plain number.
 *  - +(;c) -> c
 *  - +(x;0) -> x
 *
 *  A sum whose first term is non-commutative but which has a non-zero
 *  numeric term is rejected.
 *
 *  @param level cut-off in recursive evaluation
 *  @return the evaluated expression
 *  @exception logic_error (sum of non-commutative objects has non-zero numeric term) */
ex add::eval(int level) const
{
	epvector *evaled_seqp = evalchildren(level);
	if (evaled_seqp) {
		// do more evaluation later
		// (NOTE(review): a non-null result presumably means some child changed
		// under evaluation; a fresh add is built from the new sequence.)
		return (new add(evaled_seqp, overall_coeff))->
		       setflag(status_flags::dynallocated);
	}
	
#ifdef DO_GINAC_ASSERT
	// Debug-only sanity checks: no term's rest may itself be an add or a
	// numeric; dbgprint() dumps this object first if a numeric slipped in.
	epvector::const_iterator i = seq.begin(), end = seq.end();
	while (i != end) {
		GINAC_ASSERT(!is_exactly_a<add>(i->rest));
		if (is_ex_exactly_of_type(i->rest,numeric))
			dbgprint();
		GINAC_ASSERT(!is_exactly_a<numeric>(i->rest));
		++i;
	}
#endif // def DO_GINAC_ASSERT
	
	// Already evaluated: nothing left to rewrite
	if (flags & status_flags::evaluated) {
		GINAC_ASSERT(seq.size()>0);
		GINAC_ASSERT(seq.size()>1 || !overall_coeff.is_zero());
		return *this;
	}
	
	int seq_size = seq.size();
	if (seq_size == 0) {
		// +(;c) -> c
		return overall_coeff;
	} else if (seq_size == 1 && overall_coeff.is_zero()) {
		// +(x;0) -> x
		return recombine_pair_to_ex(*(seq.begin()));
	} else if (!overall_coeff.is_zero() && seq[0].rest.return_type() != return_types::commutative) {
		throw (std::logic_error("add::eval(): sum of non-commutative objects has non-zero numeric term"));
	}
	return this->hold();
}
346
347 ex add::evalm(void) const
348 {
349         // Evaluate children first and add up all matrices. Stop if there's one
350         // term that is not a matrix.
351         epvector *s = new epvector;
352         s->reserve(seq.size());
353
354         bool all_matrices = true;
355         bool first_term = true;
356         matrix sum;
357
358         epvector::const_iterator it = seq.begin(), itend = seq.end();
359         while (it != itend) {
360                 const ex &m = recombine_pair_to_ex(*it).evalm();
361                 s->push_back(split_ex_to_pair(m));
362                 if (is_ex_of_type(m, matrix)) {
363                         if (first_term) {
364                                 sum = ex_to<matrix>(m);
365                                 first_term = false;
366                         } else
367                                 sum = sum.add(ex_to<matrix>(m));
368                 } else
369                         all_matrices = false;
370                 ++it;
371         }
372
373         if (all_matrices) {
374                 delete s;
375                 return sum + overall_coeff;
376         } else
377                 return (new add(s, overall_coeff))->setflag(status_flags::dynallocated);
378 }
379
380 ex add::simplify_ncmul(const exvector & v) const
381 {
382         if (seq.empty())
383                 return inherited::simplify_ncmul(v);
384         else
385                 return seq.begin()->rest.simplify_ncmul(v);
386 }    
387
388 // protected
389
390 /** Implementation of ex::diff() for a sum. It differentiates each term.
391  *  @see ex::diff */
392 ex add::derivative(const symbol & y) const
393 {
394         epvector *s = new epvector();
395         s->reserve(seq.size());
396         
397         // Only differentiate the "rest" parts of the expairs. This is faster
398         // than the default implementation in basic::derivative() although
399         // if performs the same function (differentiate each term).
400         epvector::const_iterator i = seq.begin(), end = seq.end();
401         while (i != end) {
402                 s->push_back(combine_ex_with_coeff_to_pair(i->rest.diff(y), i->coeff));
403                 ++i;
404         }
405         return (new add(s, _ex0))->setflag(status_flags::dynallocated);
406 }
407
408 int add::compare_same_type(const basic & other) const
409 {
410         return inherited::compare_same_type(other);
411 }
412
413 bool add::is_equal_same_type(const basic & other) const
414 {
415         return inherited::is_equal_same_type(other);
416 }
417
418 unsigned add::return_type(void) const
419 {
420         if (seq.empty())
421                 return return_types::commutative;
422         else
423                 return seq.begin()->rest.return_type();
424 }
425    
426 unsigned add::return_type_tinfo(void) const
427 {
428         if (seq.empty())
429                 return tinfo_key;
430         else
431                 return seq.begin()->rest.return_type_tinfo();
432 }
433
434 ex add::thisexpairseq(const epvector & v, const ex & oc) const
435 {
436         return (new add(v,oc))->setflag(status_flags::dynallocated);
437 }
438
439 ex add::thisexpairseq(epvector * vp, const ex & oc) const
440 {
441         return (new add(vp,oc))->setflag(status_flags::dynallocated);
442 }
443
444 expair add::split_ex_to_pair(const ex & e) const
445 {
446         if (is_ex_exactly_of_type(e,mul)) {
447                 const mul &mulref(ex_to<mul>(e));
448                 const ex &numfactor = mulref.overall_coeff;
449                 mul *mulcopyp = new mul(mulref);
450                 mulcopyp->overall_coeff = _ex1;
451                 mulcopyp->clearflag(status_flags::evaluated);
452                 mulcopyp->clearflag(status_flags::hash_calculated);
453                 mulcopyp->setflag(status_flags::dynallocated);
454                 return expair(*mulcopyp,numfactor);
455         }
456         return expair(e,_ex1);
457 }
458
459 expair add::combine_ex_with_coeff_to_pair(const ex & e,
460                                                                                   const ex & c) const
461 {
462         GINAC_ASSERT(is_exactly_a<numeric>(c));
463         if (is_ex_exactly_of_type(e, mul)) {
464                 const mul &mulref(ex_to<mul>(e));
465                 const ex &numfactor = mulref.overall_coeff;
466                 mul *mulcopyp = new mul(mulref);
467                 mulcopyp->overall_coeff = _ex1;
468                 mulcopyp->clearflag(status_flags::evaluated);
469                 mulcopyp->clearflag(status_flags::hash_calculated);
470                 mulcopyp->setflag(status_flags::dynallocated);
471                 if (are_ex_trivially_equal(c, _ex1))
472                         return expair(*mulcopyp, numfactor);
473                 else if (are_ex_trivially_equal(numfactor, _ex1))
474                         return expair(*mulcopyp, c);
475                 else
476                         return expair(*mulcopyp, ex_to<numeric>(numfactor).mul_dyn(ex_to<numeric>(c)));
477         } else if (is_ex_exactly_of_type(e, numeric)) {
478                 if (are_ex_trivially_equal(c, _ex1))
479                         return expair(e, _ex1);
480                 return expair(ex_to<numeric>(e).mul_dyn(ex_to<numeric>(c)), _ex1);
481         }
482         return expair(e, c);
483 }
484
485 expair add::combine_pair_with_coeff_to_pair(const expair & p,
486                                                                                         const ex & c) const
487 {
488         GINAC_ASSERT(is_exactly_a<numeric>(p.coeff));
489         GINAC_ASSERT(is_exactly_a<numeric>(c));
490
491         if (is_ex_exactly_of_type(p.rest,numeric)) {
492                 GINAC_ASSERT(ex_to<numeric>(p.coeff).is_equal(_num1)); // should be normalized
493                 return expair(ex_to<numeric>(p.rest).mul_dyn(ex_to<numeric>(c)),_ex1);
494         }
495
496         return expair(p.rest,ex_to<numeric>(p.coeff).mul_dyn(ex_to<numeric>(c)));
497 }
498         
499 ex add::recombine_pair_to_ex(const expair & p) const
500 {
501         if (ex_to<numeric>(p.coeff).is_equal(_num1))
502                 return p.rest;
503         else
504                 return (new mul(p.rest,p.coeff))->setflag(status_flags::dynallocated);
505 }
506
507 ex add::expand(unsigned options) const
508 {
509         epvector *vp = expandchildren(options);
510         if (vp == NULL) {
511                 // the terms have not changed, so it is safe to declare this expanded
512                 return (options == 0) ? setflag(status_flags::expanded) : *this;
513         }
514         
515         return (new add(vp, overall_coeff))->setflag(status_flags::dynallocated | (options == 0 ? status_flags::expanded : 0));
516 }
517
518 } // namespace GiNaC