]> www.ginac.de Git - ginac.git/blob - ginac/add.cpp
pseries::expand(): do not generate zero terms.
[ginac.git] / ginac / add.cpp
1 /** @file add.cpp
2  *
3  *  Implementation of GiNaC's sums of expressions. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2001 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
21  */
22
#include <iostream>
#include <limits>
#include <stdexcept>

#include "add.h"
#include "mul.h"
#include "archive.h"
#include "debugmsg.h"
#include "utils.h"
31
32 #ifndef NO_NAMESPACE_GINAC
33 namespace GiNaC {
34 #endif // ndef NO_NAMESPACE_GINAC
35
// Registration glue for class add, naming expairseq as its parent class.
// NOTE(review): what exactly this macro expands to is defined elsewhere
// (not visible here) -- confirm against its definition.
GINAC_IMPLEMENT_REGISTERED_CLASS(add, expairseq)
37
38 //////////
39 // default constructor, destructor, copy constructor assignment operator and helpers
40 //////////
41
42 // public
43
44 add::add()
45 {
46         debugmsg("add default constructor",LOGLEVEL_CONSTRUCT);
47         tinfo_key = TINFO_add;
48 }
49
50 // protected
51
52 void add::copy(const add & other)
53 {
54         inherited::copy(other);
55 }
56
57 void add::destroy(bool call_parent)
58 {
59         if (call_parent) inherited::destroy(call_parent);
60 }
61
62 //////////
63 // other constructors
64 //////////
65
66 // public
67
68 add::add(const ex & lh, const ex & rh)
69 {
70         debugmsg("add constructor from ex,ex",LOGLEVEL_CONSTRUCT);
71         tinfo_key = TINFO_add;
72         overall_coeff = _ex0();
73         construct_from_2_ex(lh,rh);
74         GINAC_ASSERT(is_canonical());
75 }
76
77 add::add(const exvector & v)
78 {
79         debugmsg("add constructor from exvector",LOGLEVEL_CONSTRUCT);
80         tinfo_key = TINFO_add;
81         overall_coeff = _ex0();
82         construct_from_exvector(v);
83         GINAC_ASSERT(is_canonical());
84 }
85
86 /*
87 add::add(const epvector & v, bool do_not_canonicalize)
88 {
89         debugmsg("add constructor from epvector,bool",LOGLEVEL_CONSTRUCT);
90         tinfo_key = TINFO_add;
91         if (do_not_canonicalize) {
92                 seq=v;
93 #ifdef EXPAIRSEQ_USE_HASHTAB
94                 combine_same_terms(); // to build hashtab
95 #endif // def EXPAIRSEQ_USE_HASHTAB
96         } else {
97                 construct_from_epvector(v);
98         }
99         GINAC_ASSERT(is_canonical());
100 }
101 */
102
103 add::add(const epvector & v)
104 {
105         debugmsg("add constructor from epvector",LOGLEVEL_CONSTRUCT);
106         tinfo_key = TINFO_add;
107         overall_coeff = _ex0();
108         construct_from_epvector(v);
109         GINAC_ASSERT(is_canonical());
110 }
111
112 add::add(const epvector & v, const ex & oc)
113 {
114         debugmsg("add constructor from epvector,ex",LOGLEVEL_CONSTRUCT);
115         tinfo_key = TINFO_add;
116         overall_coeff = oc;
117         construct_from_epvector(v);
118         GINAC_ASSERT(is_canonical());
119 }
120
121 add::add(epvector * vp, const ex & oc)
122 {
123         debugmsg("add constructor from epvector *,ex",LOGLEVEL_CONSTRUCT);
124         tinfo_key = TINFO_add;
125         GINAC_ASSERT(vp!=0);
126         overall_coeff = oc;
127         construct_from_epvector(*vp);
128         delete vp;
129         GINAC_ASSERT(is_canonical());
130 }
131
132 //////////
133 // archiving
134 //////////
135
136 /** Construct object from archive_node. */
137 add::add(const archive_node &n, const lst &sym_lst) : inherited(n, sym_lst)
138 {
139         debugmsg("add constructor from archive_node", LOGLEVEL_CONSTRUCT);
140 }
141
142 /** Unarchive the object. */
143 ex add::unarchive(const archive_node &n, const lst &sym_lst)
144 {
145         return (new add(n, sym_lst))->setflag(status_flags::dynallocated);
146 }
147
148 /** Archive the object. */
149 void add::archive(archive_node &n) const
150 {
151         inherited::archive(n);
152 }
153
154 //////////
155 // functions overriding virtual functions from bases classes
156 //////////
157
158 // public
159
160 void add::print(std::ostream & os, unsigned upper_precedence) const
161 {
162         debugmsg("add print",LOGLEVEL_PRINT);
163         if (precedence<=upper_precedence) os << "(";
164         numeric coeff;
165         bool first = true;
166         // first print the overall numeric coefficient, if present:
167         if (!overall_coeff.is_zero()) {
168                 os << overall_coeff;
169                 first = false;
170         }
171         // then proceed with the remaining factors:
172         for (epvector::const_iterator cit=seq.begin(); cit!=seq.end(); ++cit) {
173                 coeff = ex_to_numeric(cit->coeff);
174                 if (!first) {
175                         if (coeff.csgn()==-1) os << '-'; else os << '+';
176                 } else {
177                         if (coeff.csgn()==-1) os << '-';
178                         first = false;
179                 }
180                 if (!coeff.is_equal(_num1()) &&
181                     !coeff.is_equal(_num_1())) {
182                         if (coeff.is_rational()) {
183                                 if (coeff.is_negative())
184                                         os << -coeff;
185                                 else
186                                         os << coeff;
187                         } else {
188                                 if (coeff.csgn()==-1)
189                                         (-coeff).print(os, precedence);
190                                 else
191                                         coeff.print(os, precedence);
192                         }
193                         os << '*';
194                 }
195                 cit->rest.print(os, precedence);
196         }
197         if (precedence<=upper_precedence) os << ")";
198 }
199
200 void add::printraw(std::ostream & os) const
201 {
202         debugmsg("add printraw",LOGLEVEL_PRINT);
203
204         os << "+(";
205         for (epvector::const_iterator it=seq.begin(); it!=seq.end(); ++it) {
206                 os << "(";
207                 (*it).rest.bp->printraw(os);
208                 os << ",";
209                 (*it).coeff.bp->printraw(os);        
210                 os << "),";
211         }
212         os << ",hash=" << hashvalue << ",flags=" << flags;
213         os << ")";
214 }
215
216 void add::printcsrc(std::ostream & os, unsigned type, unsigned upper_precedence) const
217 {
218         debugmsg("add print csrc", LOGLEVEL_PRINT);
219         if (precedence <= upper_precedence)
220                 os << "(";
221
222         // Print arguments, separated by "+"
223         epvector::const_iterator it = seq.begin();
224         epvector::const_iterator itend = seq.end();
225         while (it != itend) {
226
227                 // If the coefficient is -1, it is replaced by a single minus sign
228                 if (it->coeff.compare(_num1()) == 0) {
229                         it->rest.bp->printcsrc(os, type, precedence);
230                 } else if (it->coeff.compare(_num_1()) == 0) {
231                         os << "-";
232                         it->rest.bp->printcsrc(os, type, precedence);
233                 } else if (ex_to_numeric(it->coeff).numer().compare(_num1()) == 0) {
234                         it->rest.bp->printcsrc(os, type, precedence);
235                         os << "/";
236                         ex_to_numeric(it->coeff).denom().printcsrc(os, type, precedence);
237                 } else if (ex_to_numeric(it->coeff).numer().compare(_num_1()) == 0) {
238                         os << "-";
239                         it->rest.bp->printcsrc(os, type, precedence);
240                         os << "/";
241                         ex_to_numeric(it->coeff).denom().printcsrc(os, type, precedence);
242                 } else {
243                         it->coeff.bp->printcsrc(os, type, precedence);
244                         os << "*";
245                         it->rest.bp->printcsrc(os, type, precedence);
246                 }
247
248                 // Separator is "+", except if the following expression would have a leading minus sign
249                 it++;
250                 if (it != itend && !(it->coeff.compare(_num0()) < 0 || (it->coeff.compare(_num1()) == 0 && is_ex_exactly_of_type(it->rest, numeric) && it->rest.compare(_num0()) < 0)))
251                         os << "+";
252         }
253         
254         if (!overall_coeff.is_equal(_ex0())) {
255                 if (overall_coeff.info(info_flags::positive)) os << '+';
256                 overall_coeff.bp->printcsrc(os,type,precedence);
257         }
258         
259         if (precedence <= upper_precedence)
260                 os << ")";
261 }
262
263 bool add::info(unsigned inf) const
264 {
265         switch (inf) {
266                 case info_flags::polynomial:
267                 case info_flags::integer_polynomial:
268                 case info_flags::cinteger_polynomial:
269                 case info_flags::rational_polynomial:
270                 case info_flags::crational_polynomial:
271                 case info_flags::rational_function: {
272                         for (epvector::const_iterator i=seq.begin(); i!=seq.end(); ++i) {
273                                 if (!(recombine_pair_to_ex(*i).info(inf)))
274                                         return false;
275                         }
276                         return overall_coeff.info(inf);
277                 }
278                 case info_flags::algebraic: {
279                         for (epvector::const_iterator i=seq.begin(); i!=seq.end(); ++i) {
280                                 if ((recombine_pair_to_ex(*i).info(inf)))
281                                         return true;
282                         }
283                         return false;
284                 }
285         }
286         return inherited::info(inf);
287 }
288
289 int add::degree(const symbol & s) const
290 {
291         int deg = INT_MIN;
292         if (!overall_coeff.is_equal(_ex0())) {
293                 deg = 0;
294         }
295         int cur_deg;
296         for (epvector::const_iterator cit=seq.begin(); cit!=seq.end(); ++cit) {
297                 cur_deg=(*cit).rest.degree(s);
298                 if (cur_deg>deg) deg=cur_deg;
299         }
300         return deg;
301 }
302
303 int add::ldegree(const symbol & s) const
304 {
305         int deg = INT_MAX;
306         if (!overall_coeff.is_equal(_ex0())) {
307                 deg = 0;
308         }
309         int cur_deg;
310         for (epvector::const_iterator cit=seq.begin(); cit!=seq.end(); ++cit) {
311                 cur_deg = (*cit).rest.ldegree(s);
312                 if (cur_deg<deg) deg=cur_deg;
313         }
314         return deg;
315 }
316
317 ex add::coeff(const symbol & s, int n) const
318 {
319         epvector coeffseq;
320         coeffseq.reserve(seq.size());
321
322         epvector::const_iterator it=seq.begin();
323         while (it!=seq.end()) {
324                 coeffseq.push_back(combine_ex_with_coeff_to_pair((*it).rest.coeff(s,n),
325                                                                  (*it).coeff));
326                 ++it;
327         }
328         if (n==0) {
329                 return (new add(coeffseq,overall_coeff))->setflag(status_flags::dynallocated);
330         }
331         return (new add(coeffseq))->setflag(status_flags::dynallocated);
332 }
333
334 ex add::eval(int level) const
335 {
336         // simplifications: +(;c) -> c
337         //                  +(x;1) -> x
338
339         debugmsg("add eval",LOGLEVEL_MEMBER_FUNCTION);
340
341         epvector * evaled_seqp=evalchildren(level);
342         if (evaled_seqp!=0) {
343                 // do more evaluation later
344                 return (new add(evaled_seqp,overall_coeff))->
345                        setflag(status_flags::dynallocated);
346         }
347         
348 #ifdef DO_GINAC_ASSERT
349         for (epvector::const_iterator cit=seq.begin(); cit!=seq.end(); ++cit) {
350                 GINAC_ASSERT(!is_ex_exactly_of_type((*cit).rest,add));
351                 if (is_ex_exactly_of_type((*cit).rest,numeric)) {
352                         dbgprint();
353                 }
354                 GINAC_ASSERT(!is_ex_exactly_of_type((*cit).rest,numeric));
355         }
356 #endif // def DO_GINAC_ASSERT
357         
358         if (flags & status_flags::evaluated) {
359                 GINAC_ASSERT(seq.size()>0);
360                 GINAC_ASSERT((seq.size()>1)||!overall_coeff.is_equal(_ex0()));
361                 return *this;
362         }
363         
364         int seq_size=seq.size();
365         if (seq_size==0) {
366                 // +(;c) -> c
367                 return overall_coeff;
368         } else if ((seq_size==1)&&overall_coeff.is_equal(_ex0())) {
369                 // +(x;0) -> x
370                 return recombine_pair_to_ex(*(seq.begin()));
371         }
372         return this->hold();
373 }
374
375 exvector add::get_indices(void) const
376 {
377         // FIXME: all terms in the sum should have the same indices (compatible
378         // tensors) however this is not checked, since there is no function yet
379         // which compares indices (idxvector can be unsorted)
380         if (seq.size()==0) {
381                 return exvector();
382         }
383         return (seq.begin())->rest.get_indices();
384 }    
385
386 ex add::simplify_ncmul(const exvector & v) const
387 {
388         if (seq.size()==0) {
389                 return inherited::simplify_ncmul(v);
390         }
391         return (*seq.begin()).rest.simplify_ncmul(v);
392 }    
393
394 // protected
395
396 /** Implementation of ex::diff() for a sum. It differentiates each term.
397  *  @see ex::diff */
398 ex add::derivative(const symbol & s) const
399 {
400         // D(a+b+c)=D(a)+D(b)+D(c)
401         return (new add(diffchildren(s)))->setflag(status_flags::dynallocated);
402 }
403
404 int add::compare_same_type(const basic & other) const
405 {
406         return inherited::compare_same_type(other);
407 }
408
409 bool add::is_equal_same_type(const basic & other) const
410 {
411         return inherited::is_equal_same_type(other);
412 }
413
414 unsigned add::return_type(void) const
415 {
416         if (seq.size()==0) {
417                 return return_types::commutative;
418         }
419         return (*seq.begin()).rest.return_type();
420 }
421    
422 unsigned add::return_type_tinfo(void) const
423 {
424         if (seq.size()==0) {
425                 return tinfo_key;
426         }
427         return (*seq.begin()).rest.return_type_tinfo();
428 }
429
430 ex add::thisexpairseq(const epvector & v, const ex & oc) const
431 {
432         return (new add(v,oc))->setflag(status_flags::dynallocated);
433 }
434
435 ex add::thisexpairseq(epvector * vp, const ex & oc) const
436 {
437         return (new add(vp,oc))->setflag(status_flags::dynallocated);
438 }
439
440 expair add::split_ex_to_pair(const ex & e) const
441 {
442         if (is_ex_exactly_of_type(e,mul)) {
443                 const mul & mulref = ex_to_mul(e);
444                 ex numfactor = mulref.overall_coeff;
445                 // mul * mulcopyp = static_cast<mul *>(mulref.duplicate());
446                 mul * mulcopyp = new mul(mulref);
447                 mulcopyp->overall_coeff = _ex1();
448                 mulcopyp->clearflag(status_flags::evaluated);
449                 mulcopyp->clearflag(status_flags::hash_calculated);
450                 return expair(mulcopyp->setflag(status_flags::dynallocated),numfactor);
451         }
452         return expair(e,_ex1());
453 }
454
455 expair add::combine_ex_with_coeff_to_pair(const ex & e,
456                                                                                   const ex & c) const
457 {
458         GINAC_ASSERT(is_ex_exactly_of_type(c, numeric));
459         if (is_ex_exactly_of_type(e, mul)) {
460                 const mul &mulref = ex_to_mul(e);
461                 ex numfactor = mulref.overall_coeff;
462                 mul *mulcopyp = new mul(mulref);
463                 mulcopyp->overall_coeff = _ex1();
464                 mulcopyp->clearflag(status_flags::evaluated);
465                 mulcopyp->clearflag(status_flags::hash_calculated);
466                 mulcopyp->setflag(status_flags::dynallocated);
467                 if (are_ex_trivially_equal(c, _ex1())) {
468                         return expair(*mulcopyp, numfactor);
469                 } else if (are_ex_trivially_equal(numfactor, _ex1())) {
470                         return expair(*mulcopyp, c);
471                 }
472                 return expair(*mulcopyp, ex_to_numeric(numfactor).mul_dyn(ex_to_numeric(c)));
473         } else if (is_ex_exactly_of_type(e, numeric)) {
474                 if (are_ex_trivially_equal(c, _ex1()))
475                         return expair(e, _ex1());
476                 return expair(ex_to_numeric(e).mul_dyn(ex_to_numeric(c)), _ex1());
477         }
478         return expair(e, c);
479 }
480         
481 expair add::combine_pair_with_coeff_to_pair(const expair & p,
482                                                                                         const ex & c) const
483 {
484         GINAC_ASSERT(is_ex_exactly_of_type(p.coeff,numeric));
485         GINAC_ASSERT(is_ex_exactly_of_type(c,numeric));
486
487         if (is_ex_exactly_of_type(p.rest,numeric)) {
488                 GINAC_ASSERT(ex_to_numeric(p.coeff).is_equal(_num1())); // should be normalized
489                 return expair(ex_to_numeric(p.rest).mul_dyn(ex_to_numeric(c)),_ex1());
490         }
491
492         return expair(p.rest,ex_to_numeric(p.coeff).mul_dyn(ex_to_numeric(c)));
493 }
494         
495 ex add::recombine_pair_to_ex(const expair & p) const
496 {
497         if (ex_to_numeric(p.coeff).is_equal(_num1()))
498                 return p.rest;
499         else
500                 return p.rest*p.coeff;
501 }
502
503 ex add::expand(unsigned options) const
504 {
505         if (flags & status_flags::expanded)
506                 return *this;
507         
508         epvector * vp = expandchildren(options);
509         if (vp==0)
510                 return *this;
511         
512         return (new add(vp,overall_coeff))->setflag(status_flags::expanded | status_flags::dynallocated);
513 }
514
515 //////////
516 // static member variables
517 //////////
518
519 // protected
520
521 unsigned add::precedence = 40;
522
523 #ifndef NO_NAMESPACE_GINAC
524 } // namespace GiNaC
525 #endif // ndef NO_NAMESPACE_GINAC