www.ginac.de Git - ginac.git/blob - ginac/add.cpp
normal() uses an additional reverse lookup map
[ginac.git] / ginac / add.cpp
1 /** @file add.cpp
2  *
3  *  Implementation of GiNaC's sums of expressions. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2003 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
21  */
22
#include <climits>
#include <iostream>
#include <stdexcept>

#include "add.h"
#include "mul.h"
#include "archive.h"
#include "operators.h"
#include "matrix.h"
#include "utils.h"
32
33 namespace GiNaC {
34
// Class-registration boilerplate for 'add', naming expairseq as its parent class.
// NOTE(review): the macro is defined in a GiNaC header not shown here; presumably
// it provides the registered-class bookkeeping (class name, tinfo) for 'add'.
35 GINAC_IMPLEMENT_REGISTERED_CLASS(add, expairseq)
36
37 //////////
38 // default constructor
39 //////////
40
41 add::add()
42 {
43         tinfo_key = TINFO_add;
44 }
45
46 //////////
47 // other constructors
48 //////////
49
50 // public
51
52 add::add(const ex & lh, const ex & rh)
53 {
54         tinfo_key = TINFO_add;
55         overall_coeff = _ex0;
56         construct_from_2_ex(lh,rh);
57         GINAC_ASSERT(is_canonical());
58 }
59
60 add::add(const exvector & v)
61 {
62         tinfo_key = TINFO_add;
63         overall_coeff = _ex0;
64         construct_from_exvector(v);
65         GINAC_ASSERT(is_canonical());
66 }
67
68 add::add(const epvector & v)
69 {
70         tinfo_key = TINFO_add;
71         overall_coeff = _ex0;
72         construct_from_epvector(v);
73         GINAC_ASSERT(is_canonical());
74 }
75
76 add::add(const epvector & v, const ex & oc)
77 {
78         tinfo_key = TINFO_add;
79         overall_coeff = oc;
80         construct_from_epvector(v);
81         GINAC_ASSERT(is_canonical());
82 }
83
84 add::add(epvector * vp, const ex & oc)
85 {
86         tinfo_key = TINFO_add;
87         GINAC_ASSERT(vp!=0);
88         overall_coeff = oc;
89         construct_from_epvector(*vp);
90         delete vp;
91         GINAC_ASSERT(is_canonical());
92 }
93
94 //////////
95 // archiving
96 //////////
97
// Archiving support for 'add' via the DEFAULT_ARCHIVING macro.
// NOTE(review): the macro is defined elsewhere (archive.h is included above);
// presumably it generates the default archive/unarchive members — confirm there.
98 DEFAULT_ARCHIVING(add)
99
100 //////////
101 // functions overriding virtual functions from base classes
102 //////////
103
104 // public
105
106 void add::print(const print_context & c, unsigned level) const
107 {
108         if (is_a<print_tree>(c)) {
109
110                 inherited::print(c, level);
111
112         } else if (is_a<print_csrc>(c)) {
113
114                 if (precedence() <= level)
115                         c.s << "(";
116         
117                 // Print arguments, separated by "+"
118                 epvector::const_iterator it = seq.begin(), itend = seq.end();
119                 while (it != itend) {
120                 
121                         // If the coefficient is -1, it is replaced by a single minus sign
122                         if (it->coeff.is_equal(_ex1)) {
123                                 it->rest.print(c, precedence());
124                         } else if (it->coeff.is_equal(_ex_1)) {
125                                 c.s << "-";
126                                 it->rest.print(c, precedence());
127                         } else if (ex_to<numeric>(it->coeff).numer().is_equal(_num1)) {
128                                 it->rest.print(c, precedence());
129                                 c.s << "/";
130                                 ex_to<numeric>(it->coeff).denom().print(c, precedence());
131                         } else if (ex_to<numeric>(it->coeff).numer().is_equal(_num_1)) {
132                                 c.s << "-";
133                                 it->rest.print(c, precedence());
134                                 c.s << "/";
135                                 ex_to<numeric>(it->coeff).denom().print(c, precedence());
136                         } else {
137                                 it->coeff.print(c, precedence());
138                                 c.s << "*";
139                                 it->rest.print(c, precedence());
140                         }
141                 
142                         // Separator is "+", except if the following expression would have a leading minus sign or the sign is sitting in parenthesis (as in a ctor)
143                         ++it;
144                         if (it != itend
145                          && (is_a<print_csrc_cl_N>(c) || !it->coeff.info(info_flags::real)  // sign inside ctor arguments
146                           || !(it->coeff.info(info_flags::negative) || (it->coeff.is_equal(_num1) && is_exactly_a<numeric>(it->rest) && it->rest.info(info_flags::negative)))))
147                                 c.s << "+";
148                 }
149         
150                 if (!overall_coeff.is_zero()) {
151                         if (overall_coeff.info(info_flags::positive)
152                          || is_a<print_csrc_cl_N>(c) || !overall_coeff.info(info_flags::real))  // sign inside ctor argument
153                                 c.s << '+';
154                         overall_coeff.print(c, precedence());
155                 }
156                 
157                 if (precedence() <= level)
158                         c.s << ")";
159
160         } else if (is_a<print_python_repr>(c)) {
161
162                 c.s << class_name() << '(';
163                 op(0).print(c);
164                 for (size_t i=1; i<nops(); ++i) {
165                         c.s << ',';
166                         op(i).print(c);
167                 }
168                 c.s << ')';
169
170         } else {
171
172                 if (precedence() <= level) {
173                         if (is_a<print_latex>(c))
174                                 c.s << "{(";
175                         else
176                                 c.s << "(";
177                 }
178
179                 numeric coeff;
180                 bool first = true;
181
182                 // First print the overall numeric coefficient, if present
183                 if (!overall_coeff.is_zero()) {
184                         if (!is_a<print_tree>(c))
185                                 overall_coeff.print(c, 0);
186                         else
187                                 overall_coeff.print(c, precedence());
188                         first = false;
189                 }
190
191                 // Then proceed with the remaining factors
192                 epvector::const_iterator it = seq.begin(), itend = seq.end();
193                 while (it != itend) {
194                         coeff = ex_to<numeric>(it->coeff);
195                         if (!first) {
196                                 if (coeff.csgn() == -1) c.s << '-'; else c.s << '+';
197                         } else {
198                                 if (coeff.csgn() == -1) c.s << '-';
199                                 first = false;
200                         }
201                         if (!coeff.is_equal(_num1) &&
202                             !coeff.is_equal(_num_1)) {
203                                 if (coeff.is_rational()) {
204                                         if (coeff.is_negative())
205                                                 (-coeff).print(c);
206                                         else
207                                                 coeff.print(c);
208                                 } else {
209                                         if (coeff.csgn() == -1)
210                                                 (-coeff).print(c, precedence());
211                                         else
212                                                 coeff.print(c, precedence());
213                                 }
214                                 if (is_a<print_latex>(c))
215                                         c.s << ' ';
216                                 else
217                                         c.s << '*';
218                         }
219                         it->rest.print(c, precedence());
220                         ++it;
221                 }
222
223                 if (precedence() <= level) {
224                         if (is_a<print_latex>(c))
225                                 c.s << ")}";
226                         else
227                                 c.s << ")";
228                 }
229         }
230 }
231
232 bool add::info(unsigned inf) const
233 {
234         switch (inf) {
235                 case info_flags::polynomial:
236                 case info_flags::integer_polynomial:
237                 case info_flags::cinteger_polynomial:
238                 case info_flags::rational_polynomial:
239                 case info_flags::crational_polynomial:
240                 case info_flags::rational_function: {
241                         epvector::const_iterator i = seq.begin(), end = seq.end();
242                         while (i != end) {
243                                 if (!(recombine_pair_to_ex(*i).info(inf)))
244                                         return false;
245                                 ++i;
246                         }
247                         return overall_coeff.info(inf);
248                 }
249                 case info_flags::algebraic: {
250                         epvector::const_iterator i = seq.begin(), end = seq.end();
251                         while (i != end) {
252                                 if ((recombine_pair_to_ex(*i).info(inf)))
253                                         return true;
254                                 ++i;
255                         }
256                         return false;
257                 }
258         }
259         return inherited::info(inf);
260 }
261
262 int add::degree(const ex & s) const
263 {
264         int deg = INT_MIN;
265         if (!overall_coeff.is_zero())
266                 deg = 0;
267         
268         // Find maximum of degrees of individual terms
269         epvector::const_iterator i = seq.begin(), end = seq.end();
270         while (i != end) {
271                 int cur_deg = i->rest.degree(s);
272                 if (cur_deg > deg)
273                         deg = cur_deg;
274                 ++i;
275         }
276         return deg;
277 }
278
279 int add::ldegree(const ex & s) const
280 {
281         int deg = INT_MAX;
282         if (!overall_coeff.is_zero())
283                 deg = 0;
284         
285         // Find minimum of degrees of individual terms
286         epvector::const_iterator i = seq.begin(), end = seq.end();
287         while (i != end) {
288                 int cur_deg = i->rest.ldegree(s);
289                 if (cur_deg < deg)
290                         deg = cur_deg;
291                 ++i;
292         }
293         return deg;
294 }
295
296 ex add::coeff(const ex & s, int n) const
297 {
298         epvector *coeffseq = new epvector();
299
300         // Calculate sum of coefficients in each term
301         epvector::const_iterator i = seq.begin(), end = seq.end();
302         while (i != end) {
303                 ex restcoeff = i->rest.coeff(s, n);
304                 if (!restcoeff.is_zero())
305                         coeffseq->push_back(combine_ex_with_coeff_to_pair(restcoeff, i->coeff));
306                 ++i;
307         }
308
309         return (new add(coeffseq, n==0 ? overall_coeff : _ex0))->setflag(status_flags::dynallocated);
310 }
311
312 /** Perform automatic term rewriting rules in this class.  In the following
313  *  x stands for a symbolic variables of type ex and c stands for such
314  *  an expression that contain a plain number.
315  *  - +(;c) -> c
316  *  - +(x;0) -> x
317  *
318  *  @param level cut-off in recursive evaluation */
319 ex add::eval(int level) const
320 {
321         epvector *evaled_seqp = evalchildren(level);
322         if (evaled_seqp) {
323                 // do more evaluation later
324                 return (new add(evaled_seqp, overall_coeff))->
325                        setflag(status_flags::dynallocated);
326         }
327         
328 #ifdef DO_GINAC_ASSERT
329         epvector::const_iterator i = seq.begin(), end = seq.end();
330         while (i != end) {
331                 GINAC_ASSERT(!is_exactly_a<add>(i->rest));
332                 if (is_exactly_a<numeric>(i->rest))
333                         dbgprint();
334                 GINAC_ASSERT(!is_exactly_a<numeric>(i->rest));
335                 ++i;
336         }
337 #endif // def DO_GINAC_ASSERT
338         
339         if (flags & status_flags::evaluated) {
340                 GINAC_ASSERT(seq.size()>0);
341                 GINAC_ASSERT(seq.size()>1 || !overall_coeff.is_zero());
342                 return *this;
343         }
344         
345         int seq_size = seq.size();
346         if (seq_size == 0) {
347                 // +(;c) -> c
348                 return overall_coeff;
349         } else if (seq_size == 1 && overall_coeff.is_zero()) {
350                 // +(x;0) -> x
351                 return recombine_pair_to_ex(*(seq.begin()));
352         } else if (!overall_coeff.is_zero() && seq[0].rest.return_type() != return_types::commutative) {
353                 throw (std::logic_error("add::eval(): sum of non-commutative objects has non-zero numeric term"));
354         }
355         return this->hold();
356 }
357
358 ex add::evalm() const
359 {
360         // Evaluate children first and add up all matrices. Stop if there's one
361         // term that is not a matrix.
362         epvector *s = new epvector;
363         s->reserve(seq.size());
364
365         bool all_matrices = true;
366         bool first_term = true;
367         matrix sum;
368
369         epvector::const_iterator it = seq.begin(), itend = seq.end();
370         while (it != itend) {
371                 const ex &m = recombine_pair_to_ex(*it).evalm();
372                 s->push_back(split_ex_to_pair(m));
373                 if (is_a<matrix>(m)) {
374                         if (first_term) {
375                                 sum = ex_to<matrix>(m);
376                                 first_term = false;
377                         } else
378                                 sum = sum.add(ex_to<matrix>(m));
379                 } else
380                         all_matrices = false;
381                 ++it;
382         }
383
384         if (all_matrices) {
385                 delete s;
386                 return sum + overall_coeff;
387         } else
388                 return (new add(s, overall_coeff))->setflag(status_flags::dynallocated);
389 }
390
391 ex add::eval_ncmul(const exvector & v) const
392 {
393         if (seq.empty())
394                 return inherited::eval_ncmul(v);
395         else
396                 return seq.begin()->rest.eval_ncmul(v);
397 }    
398
399 // protected
400
401 /** Implementation of ex::diff() for a sum. It differentiates each term.
402  *  @see ex::diff */
403 ex add::derivative(const symbol & y) const
404 {
405         epvector *s = new epvector();
406         s->reserve(seq.size());
407         
408         // Only differentiate the "rest" parts of the expairs. This is faster
409         // than the default implementation in basic::derivative() although
410         // if performs the same function (differentiate each term).
411         epvector::const_iterator i = seq.begin(), end = seq.end();
412         while (i != end) {
413                 s->push_back(combine_ex_with_coeff_to_pair(i->rest.diff(y), i->coeff));
414                 ++i;
415         }
416         return (new add(s, _ex0))->setflag(status_flags::dynallocated);
417 }
418
419 int add::compare_same_type(const basic & other) const
420 {
421         return inherited::compare_same_type(other);
422 }
423
424 unsigned add::return_type() const
425 {
426         if (seq.empty())
427                 return return_types::commutative;
428         else
429                 return seq.begin()->rest.return_type();
430 }
431    
432 unsigned add::return_type_tinfo() const
433 {
434         if (seq.empty())
435                 return tinfo_key;
436         else
437                 return seq.begin()->rest.return_type_tinfo();
438 }
439
440 ex add::thisexpairseq(const epvector & v, const ex & oc) const
441 {
442         return (new add(v,oc))->setflag(status_flags::dynallocated);
443 }
444
445 ex add::thisexpairseq(epvector * vp, const ex & oc) const
446 {
447         return (new add(vp,oc))->setflag(status_flags::dynallocated);
448 }
449
450 expair add::split_ex_to_pair(const ex & e) const
451 {
452         if (is_exactly_a<mul>(e)) {
453                 const mul &mulref(ex_to<mul>(e));
454                 const ex &numfactor = mulref.overall_coeff;
455                 mul *mulcopyp = new mul(mulref);
456                 mulcopyp->overall_coeff = _ex1;
457                 mulcopyp->clearflag(status_flags::evaluated);
458                 mulcopyp->clearflag(status_flags::hash_calculated);
459                 mulcopyp->setflag(status_flags::dynallocated);
460                 return expair(*mulcopyp,numfactor);
461         }
462         return expair(e,_ex1);
463 }
464
465 expair add::combine_ex_with_coeff_to_pair(const ex & e,
466                                                                                   const ex & c) const
467 {
468         GINAC_ASSERT(is_exactly_a<numeric>(c));
469         if (is_exactly_a<mul>(e)) {
470                 const mul &mulref(ex_to<mul>(e));
471                 const ex &numfactor = mulref.overall_coeff;
472                 mul *mulcopyp = new mul(mulref);
473                 mulcopyp->overall_coeff = _ex1;
474                 mulcopyp->clearflag(status_flags::evaluated);
475                 mulcopyp->clearflag(status_flags::hash_calculated);
476                 mulcopyp->setflag(status_flags::dynallocated);
477                 if (c.is_equal(_ex1))
478                         return expair(*mulcopyp, numfactor);
479                 else if (numfactor.is_equal(_ex1))
480                         return expair(*mulcopyp, c);
481                 else
482                         return expair(*mulcopyp, ex_to<numeric>(numfactor).mul_dyn(ex_to<numeric>(c)));
483         } else if (is_exactly_a<numeric>(e)) {
484                 if (c.is_equal(_ex1))
485                         return expair(e, _ex1);
486                 return expair(ex_to<numeric>(e).mul_dyn(ex_to<numeric>(c)), _ex1);
487         }
488         return expair(e, c);
489 }
490
491 expair add::combine_pair_with_coeff_to_pair(const expair & p,
492                                                                                         const ex & c) const
493 {
494         GINAC_ASSERT(is_exactly_a<numeric>(p.coeff));
495         GINAC_ASSERT(is_exactly_a<numeric>(c));
496
497         if (is_exactly_a<numeric>(p.rest)) {
498                 GINAC_ASSERT(ex_to<numeric>(p.coeff).is_equal(_num1)); // should be normalized
499                 return expair(ex_to<numeric>(p.rest).mul_dyn(ex_to<numeric>(c)),_ex1);
500         }
501
502         return expair(p.rest,ex_to<numeric>(p.coeff).mul_dyn(ex_to<numeric>(c)));
503 }
504         
505 ex add::recombine_pair_to_ex(const expair & p) const
506 {
507         if (ex_to<numeric>(p.coeff).is_equal(_num1))
508                 return p.rest;
509         else
510                 return (new mul(p.rest,p.coeff))->setflag(status_flags::dynallocated);
511 }
512
513 ex add::expand(unsigned options) const
514 {
515         epvector *vp = expandchildren(options);
516         if (vp == NULL) {
517                 // the terms have not changed, so it is safe to declare this expanded
518                 return (options == 0) ? setflag(status_flags::expanded) : *this;
519         }
520
521         return (new add(vp, overall_coeff))->setflag(status_flags::dynallocated | (options == 0 ? status_flags::expanded : 0));
522 }
523
524 } // namespace GiNaC