]> www.ginac.de Git - ginac.git/blob - ginac/add.cpp
0c22465a10021e54828d149bdda3502ee1e3e25b
[ginac.git] / ginac / add.cpp
1 /** @file add.cpp
2  *
3  *  Implementation of GiNaC's sums of expressions. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2001 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
21  */
22
#include <iostream>
#include <limits>
#include <stdexcept>

#include "add.h"
#include "mul.h"
#include "archive.h"
#include "debugmsg.h"
#include "utils.h"
31
32 namespace GiNaC {
33
// Class-registration boilerplate for add, naming expairseq as its parent class.
// The macro is defined elsewhere (not visible here); presumably it pairs with a
// matching declaration macro in add.h -- TODO confirm.
34 GINAC_IMPLEMENT_REGISTERED_CLASS(add, expairseq)
35
36 //////////
37 // default constructor, destructor, copy constructor, assignment operator and helpers
38 //////////
39
40 // public
41
42 add::add()
43 {
44         debugmsg("add default constructor",LOGLEVEL_CONSTRUCT);
45         tinfo_key = TINFO_add;
46 }
47
48 // protected
49
50 /** For use by copy ctor and assignment operator. */
51 void add::copy(const add & other)
52 {
53         inherited::copy(other);
54 }
55
56 void add::destroy(bool call_parent)
57 {
58         if (call_parent) inherited::destroy(call_parent);
59 }
60
61 //////////
62 // other constructors
63 //////////
64
65 // public
66
67 add::add(const ex & lh, const ex & rh)
68 {
69         debugmsg("add constructor from ex,ex",LOGLEVEL_CONSTRUCT);
70         tinfo_key = TINFO_add;
71         overall_coeff = _ex0();
72         construct_from_2_ex(lh,rh);
73         GINAC_ASSERT(is_canonical());
74 }
75
76 add::add(const exvector & v)
77 {
78         debugmsg("add constructor from exvector",LOGLEVEL_CONSTRUCT);
79         tinfo_key = TINFO_add;
80         overall_coeff = _ex0();
81         construct_from_exvector(v);
82         GINAC_ASSERT(is_canonical());
83 }
84
85 add::add(const epvector & v)
86 {
87         debugmsg("add constructor from epvector",LOGLEVEL_CONSTRUCT);
88         tinfo_key = TINFO_add;
89         overall_coeff = _ex0();
90         construct_from_epvector(v);
91         GINAC_ASSERT(is_canonical());
92 }
93
94 add::add(const epvector & v, const ex & oc)
95 {
96         debugmsg("add constructor from epvector,ex",LOGLEVEL_CONSTRUCT);
97         tinfo_key = TINFO_add;
98         overall_coeff = oc;
99         construct_from_epvector(v);
100         GINAC_ASSERT(is_canonical());
101 }
102
103 add::add(epvector * vp, const ex & oc)
104 {
105         debugmsg("add constructor from epvector *,ex",LOGLEVEL_CONSTRUCT);
106         tinfo_key = TINFO_add;
107         GINAC_ASSERT(vp!=0);
108         overall_coeff = oc;
109         construct_from_epvector(*vp);
110         delete vp;
111         GINAC_ASSERT(is_canonical());
112 }
113
114 //////////
115 // archiving
116 //////////
117
118 /** Construct object from archive_node. */
119 add::add(const archive_node &n, const lst &sym_lst) : inherited(n, sym_lst)
120 {
121         debugmsg("add constructor from archive_node", LOGLEVEL_CONSTRUCT);
122 }
123
124 /** Unarchive the object. */
125 ex add::unarchive(const archive_node &n, const lst &sym_lst)
126 {
127         return (new add(n, sym_lst))->setflag(status_flags::dynallocated);
128 }
129
130 /** Archive the object. */
131 void add::archive(archive_node &n) const
132 {
133         inherited::archive(n);
134 }
135
136 //////////
137 // functions overriding virtual functions from base classes
138 //////////
139
140 // public
141
142 void add::print(std::ostream & os, unsigned upper_precedence) const
143 {
144         debugmsg("add print",LOGLEVEL_PRINT);
145         if (precedence<=upper_precedence) os << "(";
146         numeric coeff;
147         bool first = true;
148         // first print the overall numeric coefficient, if present:
149         if (!overall_coeff.is_zero()) {
150                 os << overall_coeff;
151                 first = false;
152         }
153         // then proceed with the remaining factors:
154         for (epvector::const_iterator cit=seq.begin(); cit!=seq.end(); ++cit) {
155                 coeff = ex_to_numeric(cit->coeff);
156                 if (!first) {
157                         if (coeff.csgn()==-1) os << '-'; else os << '+';
158                 } else {
159                         if (coeff.csgn()==-1) os << '-';
160                         first = false;
161                 }
162                 if (!coeff.is_equal(_num1()) &&
163                     !coeff.is_equal(_num_1())) {
164                         if (coeff.is_rational()) {
165                                 if (coeff.is_negative())
166                                         os << -coeff;
167                                 else
168                                         os << coeff;
169                         } else {
170                                 if (coeff.csgn()==-1)
171                                         (-coeff).print(os, precedence);
172                                 else
173                                         coeff.print(os, precedence);
174                         }
175                         os << '*';
176                 }
177                 cit->rest.print(os, precedence);
178         }
179         if (precedence<=upper_precedence) os << ")";
180 }
181
182 void add::printraw(std::ostream & os) const
183 {
184         debugmsg("add printraw",LOGLEVEL_PRINT);
185
186         os << "+(";
187         for (epvector::const_iterator it=seq.begin(); it!=seq.end(); ++it) {
188                 os << "(";
189                 (*it).rest.bp->printraw(os);
190                 os << ",";
191                 (*it).coeff.bp->printraw(os);        
192                 os << "),";
193         }
194         os << ",hash=" << hashvalue << ",flags=" << flags;
195         os << ")";
196 }
197
198 void add::printcsrc(std::ostream & os, unsigned type, unsigned upper_precedence) const
199 {
200         debugmsg("add print csrc", LOGLEVEL_PRINT);
201         if (precedence <= upper_precedence)
202                 os << "(";
203         
204         // Print arguments, separated by "+"
205         epvector::const_iterator it = seq.begin();
206         epvector::const_iterator itend = seq.end();
207         while (it != itend) {
208                 
209                 // If the coefficient is -1, it is replaced by a single minus sign
210                 if (it->coeff.compare(_num1()) == 0) {
211                         it->rest.bp->printcsrc(os, type, precedence);
212                 } else if (it->coeff.compare(_num_1()) == 0) {
213                         os << "-";
214                         it->rest.bp->printcsrc(os, type, precedence);
215                 } else if (ex_to_numeric(it->coeff).numer().compare(_num1()) == 0) {
216                         it->rest.bp->printcsrc(os, type, precedence);
217                         os << "/";
218                         ex_to_numeric(it->coeff).denom().printcsrc(os, type, precedence);
219                 } else if (ex_to_numeric(it->coeff).numer().compare(_num_1()) == 0) {
220                         os << "-";
221                         it->rest.bp->printcsrc(os, type, precedence);
222                         os << "/";
223                         ex_to_numeric(it->coeff).denom().printcsrc(os, type, precedence);
224                 } else {
225                         it->coeff.bp->printcsrc(os, type, precedence);
226                         os << "*";
227                         it->rest.bp->printcsrc(os, type, precedence);
228                 }
229                 
230                 // Separator is "+", except if the following expression would have a leading minus sign
231                 it++;
232                 if (it != itend && !(it->coeff.compare(_num0()) < 0 || (it->coeff.compare(_num1()) == 0 && is_ex_exactly_of_type(it->rest, numeric) && it->rest.compare(_num0()) < 0)))
233                         os << "+";
234         }
235         
236         if (!overall_coeff.is_zero()) {
237                 if (overall_coeff.info(info_flags::positive)) os << '+';
238                 overall_coeff.bp->printcsrc(os,type,precedence);
239         }
240         
241         if (precedence <= upper_precedence)
242                 os << ")";
243 }
244
245 bool add::info(unsigned inf) const
246 {
247         switch (inf) {
248                 case info_flags::polynomial:
249                 case info_flags::integer_polynomial:
250                 case info_flags::cinteger_polynomial:
251                 case info_flags::rational_polynomial:
252                 case info_flags::crational_polynomial:
253                 case info_flags::rational_function: {
254                         for (epvector::const_iterator i=seq.begin(); i!=seq.end(); ++i) {
255                                 if (!(recombine_pair_to_ex(*i).info(inf)))
256                                         return false;
257                         }
258                         return overall_coeff.info(inf);
259                 }
260                 case info_flags::algebraic: {
261                         for (epvector::const_iterator i=seq.begin(); i!=seq.end(); ++i) {
262                                 if ((recombine_pair_to_ex(*i).info(inf)))
263                                         return true;
264                         }
265                         return false;
266                 }
267         }
268         return inherited::info(inf);
269 }
270
271 int add::degree(const symbol & s) const
272 {
273         int deg = INT_MIN;
274         if (!overall_coeff.is_equal(_ex0()))
275                 deg = 0;
276         
277         int cur_deg;
278         for (epvector::const_iterator cit=seq.begin(); cit!=seq.end(); ++cit) {
279                 cur_deg = (*cit).rest.degree(s);
280                 if (cur_deg>deg)
281                         deg = cur_deg;
282         }
283         return deg;
284 }
285
286 int add::ldegree(const symbol & s) const
287 {
288         int deg = INT_MAX;
289         if (!overall_coeff.is_equal(_ex0()))
290                 deg = 0;
291         
292         int cur_deg;
293         for (epvector::const_iterator cit=seq.begin(); cit!=seq.end(); ++cit) {
294                 cur_deg = (*cit).rest.ldegree(s);
295                 if (cur_deg<deg) deg=cur_deg;
296         }
297         return deg;
298 }
299
300 ex add::coeff(const symbol & s, int n) const
301 {
302         epvector coeffseq;
303         coeffseq.reserve(seq.size());
304
305         epvector::const_iterator it=seq.begin();
306         while (it!=seq.end()) {
307                 coeffseq.push_back(combine_ex_with_coeff_to_pair((*it).rest.coeff(s,n),
308                                                                  (*it).coeff));
309                 ++it;
310         }
311         if (n==0) {
312                 return (new add(coeffseq,overall_coeff))->setflag(status_flags::dynallocated);
313         }
314         return (new add(coeffseq))->setflag(status_flags::dynallocated);
315 }
316
317 ex add::eval(int level) const
318 {
319         // simplifications: +(;c) -> c
320         //                  +(x;1) -> x
321         
322         debugmsg("add eval",LOGLEVEL_MEMBER_FUNCTION);
323         
324         epvector * evaled_seqp = evalchildren(level);
325         if (evaled_seqp!=0) {
326                 // do more evaluation later
327                 return (new add(evaled_seqp,overall_coeff))->
328                        setflag(status_flags::dynallocated);
329         }
330         
331 #ifdef DO_GINAC_ASSERT
332         for (epvector::const_iterator cit=seq.begin(); cit!=seq.end(); ++cit) {
333                 GINAC_ASSERT(!is_ex_exactly_of_type((*cit).rest,add));
334                 if (is_ex_exactly_of_type((*cit).rest,numeric))
335                         dbgprint();
336                 GINAC_ASSERT(!is_ex_exactly_of_type((*cit).rest,numeric));
337         }
338 #endif // def DO_GINAC_ASSERT
339         
340         if (flags & status_flags::evaluated) {
341                 GINAC_ASSERT(seq.size()>0);
342                 GINAC_ASSERT(seq.size()>1 || !overall_coeff.is_zero());
343                 return *this;
344         }
345         
346         int seq_size = seq.size();
347         if (seq_size==0) {
348                 // +(;c) -> c
349                 return overall_coeff;
350         } else if ((seq_size==1) && overall_coeff.is_equal(_ex0())) {
351                 // +(x;0) -> x
352                 return recombine_pair_to_ex(*(seq.begin()));
353         }
354         return this->hold();
355 }
356
357 exvector add::get_indices(void) const
358 {
359         // FIXME: all terms in the sum should have the same indices (compatible
360         // tensors) however this is not checked, since there is no function yet
361         // which compares indices (idxvector can be unsorted)
362         if (seq.size()==0) {
363                 return exvector();
364         }
365         return (seq.begin())->rest.get_indices();
366 }    
367
368 ex add::simplify_ncmul(const exvector & v) const
369 {
370         if (seq.size()==0) {
371                 return inherited::simplify_ncmul(v);
372         }
373         return (*seq.begin()).rest.simplify_ncmul(v);
374 }    
375
376 // protected
377
378 /** Implementation of ex::diff() for a sum. It differentiates each term.
379  *  @see ex::diff */
380 ex add::derivative(const symbol & s) const
381 {
382         // D(a+b+c)=D(a)+D(b)+D(c)
383         return (new add(diffchildren(s)))->setflag(status_flags::dynallocated);
384 }
385
386 int add::compare_same_type(const basic & other) const
387 {
388         return inherited::compare_same_type(other);
389 }
390
391 bool add::is_equal_same_type(const basic & other) const
392 {
393         return inherited::is_equal_same_type(other);
394 }
395
396 unsigned add::return_type(void) const
397 {
398         if (seq.size()==0) {
399                 return return_types::commutative;
400         }
401         return (*seq.begin()).rest.return_type();
402 }
403    
404 unsigned add::return_type_tinfo(void) const
405 {
406         if (seq.size()==0) {
407                 return tinfo_key;
408         }
409         return (*seq.begin()).rest.return_type_tinfo();
410 }
411
412 ex add::thisexpairseq(const epvector & v, const ex & oc) const
413 {
414         return (new add(v,oc))->setflag(status_flags::dynallocated);
415 }
416
417 ex add::thisexpairseq(epvector * vp, const ex & oc) const
418 {
419         return (new add(vp,oc))->setflag(status_flags::dynallocated);
420 }
421
422 expair add::split_ex_to_pair(const ex & e) const
423 {
424         if (is_ex_exactly_of_type(e,mul)) {
425                 const mul &mulref = ex_to_mul(e);
426                 ex numfactor = mulref.overall_coeff;
427                 mul *mulcopyp = new mul(mulref);
428                 mulcopyp->overall_coeff = _ex1();
429                 mulcopyp->clearflag(status_flags::evaluated);
430                 mulcopyp->clearflag(status_flags::hash_calculated);
431                 mulcopyp->setflag(status_flags::dynallocated);
432                 return expair(*mulcopyp,numfactor);
433         }
434         return expair(e,_ex1());
435 }
436
437 expair add::combine_ex_with_coeff_to_pair(const ex & e,
438                                                                                   const ex & c) const
439 {
440         GINAC_ASSERT(is_ex_exactly_of_type(c, numeric));
441         if (is_ex_exactly_of_type(e, mul)) {
442                 const mul &mulref = ex_to_mul(e);
443                 ex numfactor = mulref.overall_coeff;
444                 mul *mulcopyp = new mul(mulref);
445                 mulcopyp->overall_coeff = _ex1();
446                 mulcopyp->clearflag(status_flags::evaluated);
447                 mulcopyp->clearflag(status_flags::hash_calculated);
448                 mulcopyp->setflag(status_flags::dynallocated);
449                 if (are_ex_trivially_equal(c, _ex1()))
450                         return expair(*mulcopyp, numfactor);
451                 else if (are_ex_trivially_equal(numfactor, _ex1()))
452                         return expair(*mulcopyp, c);
453                 else
454                         return expair(*mulcopyp, ex_to_numeric(numfactor).mul_dyn(ex_to_numeric(c)));
455         } else if (is_ex_exactly_of_type(e, numeric)) {
456                 if (are_ex_trivially_equal(c, _ex1()))
457                         return expair(e, _ex1());
458                 return expair(ex_to_numeric(e).mul_dyn(ex_to_numeric(c)), _ex1());
459         }
460         return expair(e, c);
461 }
462
463 expair add::combine_pair_with_coeff_to_pair(const expair & p,
464                                                                                         const ex & c) const
465 {
466         GINAC_ASSERT(is_ex_exactly_of_type(p.coeff,numeric));
467         GINAC_ASSERT(is_ex_exactly_of_type(c,numeric));
468
469         if (is_ex_exactly_of_type(p.rest,numeric)) {
470                 GINAC_ASSERT(ex_to_numeric(p.coeff).is_equal(_num1())); // should be normalized
471                 return expair(ex_to_numeric(p.rest).mul_dyn(ex_to_numeric(c)),_ex1());
472         }
473
474         return expair(p.rest,ex_to_numeric(p.coeff).mul_dyn(ex_to_numeric(c)));
475 }
476         
477 ex add::recombine_pair_to_ex(const expair & p) const
478 {
479         if (ex_to_numeric(p.coeff).is_equal(_num1()))
480                 return p.rest;
481         else
482                 return p.rest*p.coeff;
483 }
484
485 ex add::expand(unsigned options) const
486 {
487         if (flags & status_flags::expanded)
488                 return *this;
489         
490         epvector * vp = expandchildren(options);
491         if (vp==0) {
492                 // the terms have not changed, so it is safe to declare this expanded
493                 setflag(status_flags::expanded);
494                 return *this;
495         }
496         
497         return (new add(vp,overall_coeff))->setflag(status_flags::expanded | status_flags::dynallocated);
498 }
499
500 //////////
501 // static member variables
502 //////////
503
504 // protected
505
506 unsigned add::precedence = 40;
507
508 } // namespace GiNaC