]> www.ginac.de Git - ginac.git/blob - ginac/add.cpp
re-enabled the assignment operator of class symbol
[ginac.git] / ginac / add.cpp
1 /** @file add.cpp
2  *
3  *  Implementation of GiNaC's sums of expressions. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2002 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
21  */
22
23 #include <iostream>
24 #include <stdexcept>
25
26 #include "add.h"
27 #include "mul.h"
28 #include "matrix.h"
29 #include "archive.h"
30 #include "utils.h"
31
32 namespace GiNaC {
33
34 GINAC_IMPLEMENT_REGISTERED_CLASS(add, expairseq)
35
36 //////////
37 // default ctor, dtor, copy ctor, assignment operator and helpers
38 //////////
39
40 add::add()
41 {
42         tinfo_key = TINFO_add;
43 }
44
45 DEFAULT_COPY(add)
46 DEFAULT_DESTROY(add)
47
48 //////////
49 // other constructors
50 //////////
51
52 // public
53
54 add::add(const ex & lh, const ex & rh)
55 {
56         tinfo_key = TINFO_add;
57         overall_coeff = _ex0;
58         construct_from_2_ex(lh,rh);
59         GINAC_ASSERT(is_canonical());
60 }
61
62 add::add(const exvector & v)
63 {
64         tinfo_key = TINFO_add;
65         overall_coeff = _ex0;
66         construct_from_exvector(v);
67         GINAC_ASSERT(is_canonical());
68 }
69
70 add::add(const epvector & v)
71 {
72         tinfo_key = TINFO_add;
73         overall_coeff = _ex0;
74         construct_from_epvector(v);
75         GINAC_ASSERT(is_canonical());
76 }
77
78 add::add(const epvector & v, const ex & oc)
79 {
80         tinfo_key = TINFO_add;
81         overall_coeff = oc;
82         construct_from_epvector(v);
83         GINAC_ASSERT(is_canonical());
84 }
85
86 add::add(epvector * vp, const ex & oc)
87 {
88         tinfo_key = TINFO_add;
89         GINAC_ASSERT(vp!=0);
90         overall_coeff = oc;
91         construct_from_epvector(*vp);
92         delete vp;
93         GINAC_ASSERT(is_canonical());
94 }
95
96 //////////
97 // archiving
98 //////////
99
100 DEFAULT_ARCHIVING(add)
101
102 //////////
103 // functions overriding virtual functions from base classes
104 //////////
105
106 // public
107
108 void add::print(const print_context & c, unsigned level) const
109 {
110         if (is_a<print_tree>(c)) {
111
112                 inherited::print(c, level);
113
114         } else if (is_a<print_csrc>(c)) {
115
116                 if (precedence() <= level)
117                         c.s << "(";
118         
119                 // Print arguments, separated by "+"
120                 epvector::const_iterator it = seq.begin(), itend = seq.end();
121                 while (it != itend) {
122                 
123                         // If the coefficient is -1, it is replaced by a single minus sign
124                         if (it->coeff.is_equal(_ex1)) {
125                                 it->rest.print(c, precedence());
126                         } else if (it->coeff.is_equal(_ex_1)) {
127                                 c.s << "-";
128                                 it->rest.print(c, precedence());
129                         } else if (ex_to<numeric>(it->coeff).numer().is_equal(_num1)) {
130                                 it->rest.print(c, precedence());
131                                 c.s << "/";
132                                 ex_to<numeric>(it->coeff).denom().print(c, precedence());
133                         } else if (ex_to<numeric>(it->coeff).numer().is_equal(_num_1)) {
134                                 c.s << "-";
135                                 it->rest.print(c, precedence());
136                                 c.s << "/";
137                                 ex_to<numeric>(it->coeff).denom().print(c, precedence());
138                         } else {
139                                 it->coeff.print(c, precedence());
140                                 c.s << "*";
141                                 it->rest.print(c, precedence());
142                         }
143                 
144                         // Separator is "+", except if the following expression would have a leading minus sign or the sign is sitting in parenthesis (as in a ctor)
145                         ++it;
146                         if (it != itend
147                                 && (is_a<print_csrc_cl_N>(c)  // sign inside ctor arguments
148                                         || !(it->coeff.info(info_flags::negative) || (it->coeff.is_equal(_num1) && is_exactly_a<numeric>(it->rest) && it->rest.info(info_flags::negative)))))
149                                 c.s << "+";
150                 }
151         
152                 if (!overall_coeff.is_zero()) {
153                         if (overall_coeff.info(info_flags::positive))
154                                 c.s << '+';
155                         overall_coeff.print(c, precedence());
156                 }
157                 
158                 if (precedence() <= level)
159                         c.s << ")";
160
161         } else if (is_a<print_python_repr>(c)) {
162
163                 c.s << class_name() << '(';
164                 op(0).print(c);
165                 for (unsigned i=1; i<nops(); ++i) {
166                         c.s << ',';
167                         op(i).print(c);
168                 }
169                 c.s << ')';
170
171         } else {
172
173                 if (precedence() <= level) {
174                         if (is_a<print_latex>(c))
175                                 c.s << "{(";
176                         else
177                                 c.s << "(";
178                 }
179
180                 numeric coeff;
181                 bool first = true;
182
183                 // First print the overall numeric coefficient, if present
184                 if (!overall_coeff.is_zero()) {
185                         if (!is_a<print_tree>(c))
186                                 overall_coeff.print(c, 0);
187                         else
188                                 overall_coeff.print(c, precedence());
189                         first = false;
190                 }
191
192                 // Then proceed with the remaining factors
193                 epvector::const_iterator it = seq.begin(), itend = seq.end();
194                 while (it != itend) {
195                         coeff = ex_to<numeric>(it->coeff);
196                         if (!first) {
197                                 if (coeff.csgn() == -1) c.s << '-'; else c.s << '+';
198                         } else {
199                                 if (coeff.csgn() == -1) c.s << '-';
200                                 first = false;
201                         }
202                         if (!coeff.is_equal(_num1) &&
203                             !coeff.is_equal(_num_1)) {
204                                 if (coeff.is_rational()) {
205                                         if (coeff.is_negative())
206                                                 (-coeff).print(c);
207                                         else
208                                                 coeff.print(c);
209                                 } else {
210                                         if (coeff.csgn() == -1)
211                                                 (-coeff).print(c, precedence());
212                                         else
213                                                 coeff.print(c, precedence());
214                                 }
215                                 if (is_a<print_latex>(c))
216                                         c.s << ' ';
217                                 else
218                                         c.s << '*';
219                         }
220                         it->rest.print(c, precedence());
221                         ++it;
222                 }
223
224                 if (precedence() <= level) {
225                         if (is_a<print_latex>(c))
226                                 c.s << ")}";
227                         else
228                                 c.s << ")";
229                 }
230         }
231 }
232
233 bool add::info(unsigned inf) const
234 {
235         switch (inf) {
236                 case info_flags::polynomial:
237                 case info_flags::integer_polynomial:
238                 case info_flags::cinteger_polynomial:
239                 case info_flags::rational_polynomial:
240                 case info_flags::crational_polynomial:
241                 case info_flags::rational_function: {
242                         epvector::const_iterator i = seq.begin(), end = seq.end();
243                         while (i != end) {
244                                 if (!(recombine_pair_to_ex(*i).info(inf)))
245                                         return false;
246                                 ++i;
247                         }
248                         return overall_coeff.info(inf);
249                 }
250                 case info_flags::algebraic: {
251                         epvector::const_iterator i = seq.begin(), end = seq.end();
252                         while (i != end) {
253                                 if ((recombine_pair_to_ex(*i).info(inf)))
254                                         return true;
255                                 ++i;
256                         }
257                         return false;
258                 }
259         }
260         return inherited::info(inf);
261 }
262
263 int add::degree(const ex & s) const
264 {
265         int deg = INT_MIN;
266         if (!overall_coeff.is_zero())
267                 deg = 0;
268         
269         // Find maximum of degrees of individual terms
270         epvector::const_iterator i = seq.begin(), end = seq.end();
271         while (i != end) {
272                 int cur_deg = i->rest.degree(s);
273                 if (cur_deg > deg)
274                         deg = cur_deg;
275                 ++i;
276         }
277         return deg;
278 }
279
280 int add::ldegree(const ex & s) const
281 {
282         int deg = INT_MAX;
283         if (!overall_coeff.is_zero())
284                 deg = 0;
285         
286         // Find minimum of degrees of individual terms
287         epvector::const_iterator i = seq.begin(), end = seq.end();
288         while (i != end) {
289                 int cur_deg = i->rest.ldegree(s);
290                 if (cur_deg < deg)
291                         deg = cur_deg;
292                 ++i;
293         }
294         return deg;
295 }
296
297 ex add::coeff(const ex & s, int n) const
298 {
299         epvector *coeffseq = new epvector();
300
301         // Calculate sum of coefficients in each term
302         epvector::const_iterator i = seq.begin(), end = seq.end();
303         while (i != end) {
304                 ex restcoeff = i->rest.coeff(s, n);
305                 if (!restcoeff.is_zero())
306                         coeffseq->push_back(combine_ex_with_coeff_to_pair(restcoeff, i->coeff));
307                 ++i;
308         }
309
310         return (new add(coeffseq, n==0 ? overall_coeff : _ex0))->setflag(status_flags::dynallocated);
311 }
312
313 /** Perform automatic term rewriting rules in this class.  In the following
314  *  x stands for a symbolic variables of type ex and c stands for such
315  *  an expression that contain a plain number.
316  *  - +(;c) -> c
317  *  - +(x;1) -> x
318  *
319  *  @param level cut-off in recursive evaluation */
320 ex add::eval(int level) const
321 {
322         epvector *evaled_seqp = evalchildren(level);
323         if (evaled_seqp) {
324                 // do more evaluation later
325                 return (new add(evaled_seqp, overall_coeff))->
326                        setflag(status_flags::dynallocated);
327         }
328         
329 #ifdef DO_GINAC_ASSERT
330         epvector::const_iterator i = seq.begin(), end = seq.end();
331         while (i != end) {
332                 GINAC_ASSERT(!is_exactly_a<add>(i->rest));
333                 if (is_ex_exactly_of_type(i->rest,numeric))
334                         dbgprint();
335                 GINAC_ASSERT(!is_exactly_a<numeric>(i->rest));
336                 ++i;
337         }
338 #endif // def DO_GINAC_ASSERT
339         
340         if (flags & status_flags::evaluated) {
341                 GINAC_ASSERT(seq.size()>0);
342                 GINAC_ASSERT(seq.size()>1 || !overall_coeff.is_zero());
343                 return *this;
344         }
345         
346         int seq_size = seq.size();
347         if (seq_size == 0) {
348                 // +(;c) -> c
349                 return overall_coeff;
350         } else if (seq_size == 1 && overall_coeff.is_zero()) {
351                 // +(x;0) -> x
352                 return recombine_pair_to_ex(*(seq.begin()));
353         } else if (!overall_coeff.is_zero() && seq[0].rest.return_type() != return_types::commutative) {
354                 throw (std::logic_error("add::eval(): sum of non-commutative objects has non-zero numeric term"));
355         }
356         return this->hold();
357 }
358
359 ex add::evalm(void) const
360 {
361         // Evaluate children first and add up all matrices. Stop if there's one
362         // term that is not a matrix.
363         epvector *s = new epvector;
364         s->reserve(seq.size());
365
366         bool all_matrices = true;
367         bool first_term = true;
368         matrix sum;
369
370         epvector::const_iterator it = seq.begin(), itend = seq.end();
371         while (it != itend) {
372                 const ex &m = recombine_pair_to_ex(*it).evalm();
373                 s->push_back(split_ex_to_pair(m));
374                 if (is_ex_of_type(m, matrix)) {
375                         if (first_term) {
376                                 sum = ex_to<matrix>(m);
377                                 first_term = false;
378                         } else
379                                 sum = sum.add(ex_to<matrix>(m));
380                 } else
381                         all_matrices = false;
382                 ++it;
383         }
384
385         if (all_matrices) {
386                 delete s;
387                 return sum + overall_coeff;
388         } else
389                 return (new add(s, overall_coeff))->setflag(status_flags::dynallocated);
390 }
391
392 ex add::simplify_ncmul(const exvector & v) const
393 {
394         if (seq.empty())
395                 return inherited::simplify_ncmul(v);
396         else
397                 return seq.begin()->rest.simplify_ncmul(v);
398 }    
399
400 // protected
401
402 /** Implementation of ex::diff() for a sum. It differentiates each term.
403  *  @see ex::diff */
404 ex add::derivative(const symbol & y) const
405 {
406         epvector *s = new epvector();
407         s->reserve(seq.size());
408         
409         // Only differentiate the "rest" parts of the expairs. This is faster
410         // than the default implementation in basic::derivative() although
411         // if performs the same function (differentiate each term).
412         epvector::const_iterator i = seq.begin(), end = seq.end();
413         while (i != end) {
414                 s->push_back(combine_ex_with_coeff_to_pair(i->rest.diff(y), i->coeff));
415                 ++i;
416         }
417         return (new add(s, _ex0))->setflag(status_flags::dynallocated);
418 }
419
420 int add::compare_same_type(const basic & other) const
421 {
422         return inherited::compare_same_type(other);
423 }
424
425 bool add::is_equal_same_type(const basic & other) const
426 {
427         return inherited::is_equal_same_type(other);
428 }
429
430 unsigned add::return_type(void) const
431 {
432         if (seq.empty())
433                 return return_types::commutative;
434         else
435                 return seq.begin()->rest.return_type();
436 }
437    
438 unsigned add::return_type_tinfo(void) const
439 {
440         if (seq.empty())
441                 return tinfo_key;
442         else
443                 return seq.begin()->rest.return_type_tinfo();
444 }
445
446 ex add::thisexpairseq(const epvector & v, const ex & oc) const
447 {
448         return (new add(v,oc))->setflag(status_flags::dynallocated);
449 }
450
451 ex add::thisexpairseq(epvector * vp, const ex & oc) const
452 {
453         return (new add(vp,oc))->setflag(status_flags::dynallocated);
454 }
455
456 expair add::split_ex_to_pair(const ex & e) const
457 {
458         if (is_ex_exactly_of_type(e,mul)) {
459                 const mul &mulref(ex_to<mul>(e));
460                 const ex &numfactor = mulref.overall_coeff;
461                 mul *mulcopyp = new mul(mulref);
462                 mulcopyp->overall_coeff = _ex1;
463                 mulcopyp->clearflag(status_flags::evaluated);
464                 mulcopyp->clearflag(status_flags::hash_calculated);
465                 mulcopyp->setflag(status_flags::dynallocated);
466                 return expair(*mulcopyp,numfactor);
467         }
468         return expair(e,_ex1);
469 }
470
471 expair add::combine_ex_with_coeff_to_pair(const ex & e,
472                                                                                   const ex & c) const
473 {
474         GINAC_ASSERT(is_exactly_a<numeric>(c));
475         if (is_ex_exactly_of_type(e, mul)) {
476                 const mul &mulref(ex_to<mul>(e));
477                 const ex &numfactor = mulref.overall_coeff;
478                 mul *mulcopyp = new mul(mulref);
479                 mulcopyp->overall_coeff = _ex1;
480                 mulcopyp->clearflag(status_flags::evaluated);
481                 mulcopyp->clearflag(status_flags::hash_calculated);
482                 mulcopyp->setflag(status_flags::dynallocated);
483                 if (are_ex_trivially_equal(c, _ex1))
484                         return expair(*mulcopyp, numfactor);
485                 else if (are_ex_trivially_equal(numfactor, _ex1))
486                         return expair(*mulcopyp, c);
487                 else
488                         return expair(*mulcopyp, ex_to<numeric>(numfactor).mul_dyn(ex_to<numeric>(c)));
489         } else if (is_ex_exactly_of_type(e, numeric)) {
490                 if (are_ex_trivially_equal(c, _ex1))
491                         return expair(e, _ex1);
492                 return expair(ex_to<numeric>(e).mul_dyn(ex_to<numeric>(c)), _ex1);
493         }
494         return expair(e, c);
495 }
496
497 expair add::combine_pair_with_coeff_to_pair(const expair & p,
498                                                                                         const ex & c) const
499 {
500         GINAC_ASSERT(is_exactly_a<numeric>(p.coeff));
501         GINAC_ASSERT(is_exactly_a<numeric>(c));
502
503         if (is_ex_exactly_of_type(p.rest,numeric)) {
504                 GINAC_ASSERT(ex_to<numeric>(p.coeff).is_equal(_num1)); // should be normalized
505                 return expair(ex_to<numeric>(p.rest).mul_dyn(ex_to<numeric>(c)),_ex1);
506         }
507
508         return expair(p.rest,ex_to<numeric>(p.coeff).mul_dyn(ex_to<numeric>(c)));
509 }
510         
511 ex add::recombine_pair_to_ex(const expair & p) const
512 {
513         if (ex_to<numeric>(p.coeff).is_equal(_num1))
514                 return p.rest;
515         else
516                 return (new mul(p.rest,p.coeff))->setflag(status_flags::dynallocated);
517 }
518
519 ex add::expand(unsigned options) const
520 {
521         epvector *vp = expandchildren(options);
522         if (vp == NULL) {
523                 // the terms have not changed, so it is safe to declare this expanded
524                 return (options == 0) ? setflag(status_flags::expanded) : *this;
525         }
526         
527         return (new add(vp, overall_coeff))->setflag(status_flags::dynallocated | (options == 0 ? status_flags::expanded : 0));
528 }
529
530 } // namespace GiNaC