3dcdf2cc4a7bc3733d49cef210230bbb0e2086b9
[ginac.git] / ginac / add.cpp
1 /** @file add.cpp
2  *
3  *  Implementation of GiNaC's sums of expressions. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2003 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
21  */
22
#include <iostream>
#include <limits>
#include <stdexcept>

#include "add.h"
#include "mul.h"
#include "archive.h"
#include "operators.h"
#include "matrix.h"
#include "utils.h"
32
33 namespace GiNaC {
34
// Class registration for add (parent class: expairseq), together with the
// print method used for each print context.  Note that print_tree does not
// get an add-specific routine: it uses the inherited implementation.
GINAC_IMPLEMENT_REGISTERED_CLASS_OPT(add, expairseq,
  print_func<print_context>(&add::do_print).
  print_func<print_latex>(&add::do_print_latex).
  print_func<print_csrc>(&add::do_print_csrc).
  print_func<print_tree>(&inherited::do_print_tree).
  print_func<print_python_repr>(&add::do_print_python_repr))
41
42 //////////
43 // default constructor
44 //////////
45
46 add::add()
47 {
48         tinfo_key = TINFO_add;
49 }
50
51 //////////
52 // other constructors
53 //////////
54
55 // public
56
57 add::add(const ex & lh, const ex & rh)
58 {
59         tinfo_key = TINFO_add;
60         overall_coeff = _ex0;
61         construct_from_2_ex(lh,rh);
62         GINAC_ASSERT(is_canonical());
63 }
64
65 add::add(const exvector & v)
66 {
67         tinfo_key = TINFO_add;
68         overall_coeff = _ex0;
69         construct_from_exvector(v);
70         GINAC_ASSERT(is_canonical());
71 }
72
73 add::add(const epvector & v)
74 {
75         tinfo_key = TINFO_add;
76         overall_coeff = _ex0;
77         construct_from_epvector(v);
78         GINAC_ASSERT(is_canonical());
79 }
80
81 add::add(const epvector & v, const ex & oc)
82 {
83         tinfo_key = TINFO_add;
84         overall_coeff = oc;
85         construct_from_epvector(v);
86         GINAC_ASSERT(is_canonical());
87 }
88
89 add::add(epvector * vp, const ex & oc)
90 {
91         tinfo_key = TINFO_add;
92         GINAC_ASSERT(vp!=0);
93         overall_coeff = oc;
94         construct_from_epvector(*vp);
95         delete vp;
96         GINAC_ASSERT(is_canonical());
97 }
98
99 //////////
100 // archiving
101 //////////
102
// Archiving support for add via the DEFAULT_ARCHIVING macro (presumably from
// archive.h; add defines no custom archiving code in this file).
DEFAULT_ARCHIVING(add)
104
105 //////////
106 // functions overriding virtual functions from base classes
107 //////////
108
109 // public
110
111 void add::print_add(const print_context & c, const char *openbrace, const char *closebrace, const char *mul_sym, unsigned level) const
112 {
113         if (precedence() <= level)
114                 c.s << openbrace << '(';
115
116         numeric coeff;
117         bool first = true;
118
119         // First print the overall numeric coefficient, if present
120         if (!overall_coeff.is_zero()) {
121                 overall_coeff.print(c, 0);
122                 first = false;
123         }
124
125         // Then proceed with the remaining factors
126         epvector::const_iterator it = seq.begin(), itend = seq.end();
127         while (it != itend) {
128                 coeff = ex_to<numeric>(it->coeff);
129                 if (!first) {
130                         if (coeff.csgn() == -1) c.s << '-'; else c.s << '+';
131                 } else {
132                         if (coeff.csgn() == -1) c.s << '-';
133                         first = false;
134                 }
135                 if (!coeff.is_equal(_num1) &&
136                     !coeff.is_equal(_num_1)) {
137                         if (coeff.is_rational()) {
138                                 if (coeff.is_negative())
139                                         (-coeff).print(c);
140                                 else
141                                         coeff.print(c);
142                         } else {
143                                 if (coeff.csgn() == -1)
144                                         (-coeff).print(c, precedence());
145                                 else
146                                         coeff.print(c, precedence());
147                         }
148                         c.s << mul_sym;
149                 }
150                 it->rest.print(c, precedence());
151                 ++it;
152         }
153
154         if (precedence() <= level)
155                 c.s << ')' << closebrace;
156 }
157
158 void add::do_print(const print_context & c, unsigned level) const
159 {
160         print_add(c, "", "", "*", level);
161 }
162
163 void add::do_print_latex(const print_latex & c, unsigned level) const
164 {
165         print_add(c, "{", "}", " ", level);
166 }
167
/** C source output of a sum.
 *
 *  Each term is printed according to its numeric coefficient (1, -1,
 *  1/denominator, -1/denominator, or general), then the numeric constant
 *  overall_coeff is appended last.  A "+" separator is only emitted when the
 *  next printed item would not already start with a minus sign, or when the
 *  sign would end up inside a constructor argument (print_csrc_cl_N, or a
 *  non-real number).
 *
 *  @param c     C source print context
 *  @param level precedence of the enclosing expression */
void add::do_print_csrc(const print_csrc & c, unsigned level) const
{
	if (precedence() <= level)
		c.s << "(";
	
	// Print arguments, separated by "+"
	epvector::const_iterator it = seq.begin(), itend = seq.end();
	while (it != itend) {
		
		// A coefficient of 1 is not printed at all
		if (it->coeff.is_equal(_ex1)) {
			it->rest.print(c, precedence());
		// If the coefficient is -1, it is replaced by a single minus sign
		} else if (it->coeff.is_equal(_ex_1)) {
			c.s << "-";
			it->rest.print(c, precedence());
		// Coefficient 1/d is printed as a division "term/d"
		} else if (ex_to<numeric>(it->coeff).numer().is_equal(_num1)) {
			it->rest.print(c, precedence());
			c.s << "/";
			ex_to<numeric>(it->coeff).denom().print(c, precedence());
		// Coefficient -1/d is printed as "-term/d"
		} else if (ex_to<numeric>(it->coeff).numer().is_equal(_num_1)) {
			c.s << "-";
			it->rest.print(c, precedence());
			c.s << "/";
			ex_to<numeric>(it->coeff).denom().print(c, precedence());
		// Any other coefficient is printed explicitly as "coeff*term"
		} else {
			it->coeff.print(c, precedence());
			c.s << "*";
			it->rest.print(c, precedence());
		}
		
		// Separator is "+", except if the following expression would have a leading minus sign or the sign is sitting in parenthesis (as in a ctor)
		++it;
		if (it != itend
		 && (is_a<print_csrc_cl_N>(c) || !it->coeff.info(info_flags::real)  // sign inside ctor arguments
		  || !(it->coeff.info(info_flags::negative) || (it->coeff.is_equal(_num1) && is_exactly_a<numeric>(it->rest) && it->rest.info(info_flags::negative)))))
			c.s << "+";
	}
	
	// The numeric constant goes last; a negative real constant prints its
	// own leading minus sign, so "+" is only needed in the other cases
	if (!overall_coeff.is_zero()) {
		if (overall_coeff.info(info_flags::positive)
		 || is_a<print_csrc_cl_N>(c) || !overall_coeff.info(info_flags::real))  // sign inside ctor argument
			c.s << '+';
		overall_coeff.print(c, precedence());
	}
		
	if (precedence() <= level)
		c.s << ")";
}
216
217 void add::do_print_python_repr(const print_python_repr & c, unsigned level) const
218 {
219         c.s << class_name() << '(';
220         op(0).print(c);
221         for (size_t i=1; i<nops(); ++i) {
222                 c.s << ',';
223                 op(i).print(c);
224         }
225         c.s << ')';
226 }
227
228 bool add::info(unsigned inf) const
229 {
230         switch (inf) {
231                 case info_flags::polynomial:
232                 case info_flags::integer_polynomial:
233                 case info_flags::cinteger_polynomial:
234                 case info_flags::rational_polynomial:
235                 case info_flags::crational_polynomial:
236                 case info_flags::rational_function: {
237                         epvector::const_iterator i = seq.begin(), end = seq.end();
238                         while (i != end) {
239                                 if (!(recombine_pair_to_ex(*i).info(inf)))
240                                         return false;
241                                 ++i;
242                         }
243                         return overall_coeff.info(inf);
244                 }
245                 case info_flags::algebraic: {
246                         epvector::const_iterator i = seq.begin(), end = seq.end();
247                         while (i != end) {
248                                 if ((recombine_pair_to_ex(*i).info(inf)))
249                                         return true;
250                                 ++i;
251                         }
252                         return false;
253                 }
254         }
255         return inherited::info(inf);
256 }
257
258 int add::degree(const ex & s) const
259 {
260         int deg = INT_MIN;
261         if (!overall_coeff.is_zero())
262                 deg = 0;
263         
264         // Find maximum of degrees of individual terms
265         epvector::const_iterator i = seq.begin(), end = seq.end();
266         while (i != end) {
267                 int cur_deg = i->rest.degree(s);
268                 if (cur_deg > deg)
269                         deg = cur_deg;
270                 ++i;
271         }
272         return deg;
273 }
274
275 int add::ldegree(const ex & s) const
276 {
277         int deg = INT_MAX;
278         if (!overall_coeff.is_zero())
279                 deg = 0;
280         
281         // Find minimum of degrees of individual terms
282         epvector::const_iterator i = seq.begin(), end = seq.end();
283         while (i != end) {
284                 int cur_deg = i->rest.ldegree(s);
285                 if (cur_deg < deg)
286                         deg = cur_deg;
287                 ++i;
288         }
289         return deg;
290 }
291
292 ex add::coeff(const ex & s, int n) const
293 {
294         epvector *coeffseq = new epvector();
295
296         // Calculate sum of coefficients in each term
297         epvector::const_iterator i = seq.begin(), end = seq.end();
298         while (i != end) {
299                 ex restcoeff = i->rest.coeff(s, n);
300                 if (!restcoeff.is_zero())
301                         coeffseq->push_back(combine_ex_with_coeff_to_pair(restcoeff, i->coeff));
302                 ++i;
303         }
304
305         return (new add(coeffseq, n==0 ? overall_coeff : _ex0))->setflag(status_flags::dynallocated);
306 }
307
308 /** Perform automatic term rewriting rules in this class.  In the following
309  *  x stands for a symbolic variables of type ex and c stands for such
310  *  an expression that contain a plain number.
311  *  - +(;c) -> c
312  *  - +(x;0) -> x
313  *
314  *  @param level cut-off in recursive evaluation */
315 ex add::eval(int level) const
316 {
317         epvector *evaled_seqp = evalchildren(level);
318         if (evaled_seqp) {
319                 // do more evaluation later
320                 return (new add(evaled_seqp, overall_coeff))->
321                        setflag(status_flags::dynallocated);
322         }
323         
324 #ifdef DO_GINAC_ASSERT
325         epvector::const_iterator i = seq.begin(), end = seq.end();
326         while (i != end) {
327                 GINAC_ASSERT(!is_exactly_a<add>(i->rest));
328                 if (is_exactly_a<numeric>(i->rest))
329                         dbgprint();
330                 GINAC_ASSERT(!is_exactly_a<numeric>(i->rest));
331                 ++i;
332         }
333 #endif // def DO_GINAC_ASSERT
334         
335         if (flags & status_flags::evaluated) {
336                 GINAC_ASSERT(seq.size()>0);
337                 GINAC_ASSERT(seq.size()>1 || !overall_coeff.is_zero());
338                 return *this;
339         }
340         
341         int seq_size = seq.size();
342         if (seq_size == 0) {
343                 // +(;c) -> c
344                 return overall_coeff;
345         } else if (seq_size == 1 && overall_coeff.is_zero()) {
346                 // +(x;0) -> x
347                 return recombine_pair_to_ex(*(seq.begin()));
348         } else if (!overall_coeff.is_zero() && seq[0].rest.return_type() != return_types::commutative) {
349                 throw (std::logic_error("add::eval(): sum of non-commutative objects has non-zero numeric term"));
350         }
351         return this->hold();
352 }
353
354 ex add::evalm() const
355 {
356         // Evaluate children first and add up all matrices. Stop if there's one
357         // term that is not a matrix.
358         epvector *s = new epvector;
359         s->reserve(seq.size());
360
361         bool all_matrices = true;
362         bool first_term = true;
363         matrix sum;
364
365         epvector::const_iterator it = seq.begin(), itend = seq.end();
366         while (it != itend) {
367                 const ex &m = recombine_pair_to_ex(*it).evalm();
368                 s->push_back(split_ex_to_pair(m));
369                 if (is_a<matrix>(m)) {
370                         if (first_term) {
371                                 sum = ex_to<matrix>(m);
372                                 first_term = false;
373                         } else
374                                 sum = sum.add(ex_to<matrix>(m));
375                 } else
376                         all_matrices = false;
377                 ++it;
378         }
379
380         if (all_matrices) {
381                 delete s;
382                 return sum + overall_coeff;
383         } else
384                 return (new add(s, overall_coeff))->setflag(status_flags::dynallocated);
385 }
386
387 ex add::eval_ncmul(const exvector & v) const
388 {
389         if (seq.empty())
390                 return inherited::eval_ncmul(v);
391         else
392                 return seq.begin()->rest.eval_ncmul(v);
393 }    
394
395 // protected
396
397 /** Implementation of ex::diff() for a sum. It differentiates each term.
398  *  @see ex::diff */
399 ex add::derivative(const symbol & y) const
400 {
401         epvector *s = new epvector();
402         s->reserve(seq.size());
403         
404         // Only differentiate the "rest" parts of the expairs. This is faster
405         // than the default implementation in basic::derivative() although
406         // if performs the same function (differentiate each term).
407         epvector::const_iterator i = seq.begin(), end = seq.end();
408         while (i != end) {
409                 s->push_back(combine_ex_with_coeff_to_pair(i->rest.diff(y), i->coeff));
410                 ++i;
411         }
412         return (new add(s, _ex0))->setflag(status_flags::dynallocated);
413 }
414
415 int add::compare_same_type(const basic & other) const
416 {
417         return inherited::compare_same_type(other);
418 }
419
420 unsigned add::return_type() const
421 {
422         if (seq.empty())
423                 return return_types::commutative;
424         else
425                 return seq.begin()->rest.return_type();
426 }
427    
428 unsigned add::return_type_tinfo() const
429 {
430         if (seq.empty())
431                 return tinfo_key;
432         else
433                 return seq.begin()->rest.return_type_tinfo();
434 }
435
436 ex add::thisexpairseq(const epvector & v, const ex & oc) const
437 {
438         return (new add(v,oc))->setflag(status_flags::dynallocated);
439 }
440
441 ex add::thisexpairseq(epvector * vp, const ex & oc) const
442 {
443         return (new add(vp,oc))->setflag(status_flags::dynallocated);
444 }
445
446 expair add::split_ex_to_pair(const ex & e) const
447 {
448         if (is_exactly_a<mul>(e)) {
449                 const mul &mulref(ex_to<mul>(e));
450                 const ex &numfactor = mulref.overall_coeff;
451                 mul *mulcopyp = new mul(mulref);
452                 mulcopyp->overall_coeff = _ex1;
453                 mulcopyp->clearflag(status_flags::evaluated);
454                 mulcopyp->clearflag(status_flags::hash_calculated);
455                 mulcopyp->setflag(status_flags::dynallocated);
456                 return expair(*mulcopyp,numfactor);
457         }
458         return expair(e,_ex1);
459 }
460
461 expair add::combine_ex_with_coeff_to_pair(const ex & e,
462                                                                                   const ex & c) const
463 {
464         GINAC_ASSERT(is_exactly_a<numeric>(c));
465         if (is_exactly_a<mul>(e)) {
466                 const mul &mulref(ex_to<mul>(e));
467                 const ex &numfactor = mulref.overall_coeff;
468                 mul *mulcopyp = new mul(mulref);
469                 mulcopyp->overall_coeff = _ex1;
470                 mulcopyp->clearflag(status_flags::evaluated);
471                 mulcopyp->clearflag(status_flags::hash_calculated);
472                 mulcopyp->setflag(status_flags::dynallocated);
473                 if (c.is_equal(_ex1))
474                         return expair(*mulcopyp, numfactor);
475                 else if (numfactor.is_equal(_ex1))
476                         return expair(*mulcopyp, c);
477                 else
478                         return expair(*mulcopyp, ex_to<numeric>(numfactor).mul_dyn(ex_to<numeric>(c)));
479         } else if (is_exactly_a<numeric>(e)) {
480                 if (c.is_equal(_ex1))
481                         return expair(e, _ex1);
482                 return expair(ex_to<numeric>(e).mul_dyn(ex_to<numeric>(c)), _ex1);
483         }
484         return expair(e, c);
485 }
486
487 expair add::combine_pair_with_coeff_to_pair(const expair & p,
488                                                                                         const ex & c) const
489 {
490         GINAC_ASSERT(is_exactly_a<numeric>(p.coeff));
491         GINAC_ASSERT(is_exactly_a<numeric>(c));
492
493         if (is_exactly_a<numeric>(p.rest)) {
494                 GINAC_ASSERT(ex_to<numeric>(p.coeff).is_equal(_num1)); // should be normalized
495                 return expair(ex_to<numeric>(p.rest).mul_dyn(ex_to<numeric>(c)),_ex1);
496         }
497
498         return expair(p.rest,ex_to<numeric>(p.coeff).mul_dyn(ex_to<numeric>(c)));
499 }
500         
501 ex add::recombine_pair_to_ex(const expair & p) const
502 {
503         if (ex_to<numeric>(p.coeff).is_equal(_num1))
504                 return p.rest;
505         else
506                 return (new mul(p.rest,p.coeff))->setflag(status_flags::dynallocated);
507 }
508
509 ex add::expand(unsigned options) const
510 {
511         epvector *vp = expandchildren(options);
512         if (vp == NULL) {
513                 // the terms have not changed, so it is safe to declare this expanded
514                 return (options == 0) ? setflag(status_flags::expanded) : *this;
515         }
516
517         return (new add(vp, overall_coeff))->setflag(status_flags::dynallocated | (options == 0 ? status_flags::expanded : 0));
518 }
519
520 } // namespace GiNaC