804f07ae8ddebc12aaf83be08a8e4c8018fd415e
[ginac.git] / ginac / add.cpp
1 /** @file add.cpp
2  *
3  *  Implementation of GiNaC's sums of expressions. */
4
5 /*
6  *  GiNaC Copyright (C) 1999 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
21  */
22
23 #include <iostream>
24 #include <stdexcept>
25
26 #include "add.h"
27 #include "mul.h"
28 #include "debugmsg.h"
29
30 #ifndef NO_GINAC_NAMESPACE
31 namespace GiNaC {
32 #endif // ndef NO_GINAC_NAMESPACE
33
34 //////////
// default constructor, destructor, copy constructor, assignment operator and helpers
36 //////////
37
38 // public
39
40 add::add()
41 {
42     debugmsg("add default constructor",LOGLEVEL_CONSTRUCT);
43     tinfo_key = TINFO_add;
44 }
45
46 add::~add()
47 {
48     debugmsg("add destructor",LOGLEVEL_DESTRUCT);
49     destroy(0);
50 }
51
52 add::add(add const & other)
53 {
54     debugmsg("add copy constructor",LOGLEVEL_CONSTRUCT);
55     copy(other);
56 }
57
58 add const & add::operator=(add const & other)
59 {
60     debugmsg("add operator=",LOGLEVEL_ASSIGNMENT);
61     if (this != &other) {
62         destroy(1);
63         copy(other);
64     }
65     return *this;
66 }
67
68 // protected
69
70 void add::copy(add const & other)
71 {
72     expairseq::copy(other);
73 }
74
75 void add::destroy(bool call_parent)
76 {
77     if (call_parent) expairseq::destroy(call_parent);
78 }
79
80 //////////
81 // other constructors
82 //////////
83
84 // public
85
86 add::add(ex const & lh, ex const & rh)
87 {
88     debugmsg("add constructor from ex,ex",LOGLEVEL_CONSTRUCT);
89     tinfo_key = TINFO_add;
90     overall_coeff=exZERO();
91     construct_from_2_ex(lh,rh);
92     GINAC_ASSERT(is_canonical());
93 }
94
95 add::add(exvector const & v)
96 {
97     debugmsg("add constructor from exvector",LOGLEVEL_CONSTRUCT);
98     tinfo_key = TINFO_add;
99     overall_coeff=exZERO();
100     construct_from_exvector(v);
101     GINAC_ASSERT(is_canonical());
102 }
103
104 /*
105 add::add(epvector const & v, bool do_not_canonicalize)
106 {
107     debugmsg("add constructor from epvector,bool",LOGLEVEL_CONSTRUCT);
108     tinfo_key = TINFO_add;
109     if (do_not_canonicalize) {
110         seq=v;
111 #ifdef EXPAIRSEQ_USE_HASHTAB
112         combine_same_terms(); // to build hashtab
113 #endif // def EXPAIRSEQ_USE_HASHTAB
114     } else {
115         construct_from_epvector(v);
116     }
117     GINAC_ASSERT(is_canonical());
118 }
119 */
120
121 add::add(epvector const & v)
122 {
123     debugmsg("add constructor from epvector",LOGLEVEL_CONSTRUCT);
124     tinfo_key = TINFO_add;
125     overall_coeff=exZERO();
126     construct_from_epvector(v);
127     GINAC_ASSERT(is_canonical());
128 }
129
130 add::add(epvector const & v, ex const & oc)
131 {
132     debugmsg("add constructor from epvector,ex",LOGLEVEL_CONSTRUCT);
133     tinfo_key = TINFO_add;
134     overall_coeff=oc;
135     construct_from_epvector(v);
136     GINAC_ASSERT(is_canonical());
137 }
138
139 add::add(epvector * vp, ex const & oc)
140 {
141     debugmsg("add constructor from epvector *,ex",LOGLEVEL_CONSTRUCT);
142     tinfo_key = TINFO_add;
143     GINAC_ASSERT(vp!=0);
144     overall_coeff=oc;
145     construct_from_epvector(*vp);
146     delete vp;
147     GINAC_ASSERT(is_canonical());
148 }
149
150 //////////
// functions overriding virtual functions from base classes
152 //////////
153
154 // public
155
156 basic * add::duplicate() const
157 {
158     debugmsg("add duplicate",LOGLEVEL_DUPLICATE);
159     return new add(*this);
160 }
161
162 void add::print(ostream & os, unsigned upper_precedence) const
163 {
164     debugmsg("add print",LOGLEVEL_PRINT);
165     if (precedence<=upper_precedence) os << "(";
166     numeric coeff;
167     bool first=true;
168     for (epvector::const_iterator cit=seq.begin(); cit!=seq.end(); ++cit) {
169         coeff = ex_to_numeric(cit->coeff);
170         if (!first) {
171             if (coeff.csgn()==-1) os << '-'; else os << '+';
172         } else {
173             if (coeff.csgn()==-1) os << '-';
174             first=false;
175         }
176         if (!coeff.is_equal(numONE()) &&
177             !coeff.is_equal(numMINUSONE())) {
178             if (coeff.csgn()==-1)
179                 (numMINUSONE()*coeff).print(os, precedence);
180             else
181                 coeff.print(os, precedence);
182             os << '*';
183         }
184         os << cit->rest;
185     }
186     // print the overall numeric coefficient, if present:
187     if (!overall_coeff.is_zero()) {
188         if (overall_coeff.info(info_flags::positive)) os << '+';
189         os << overall_coeff;
190     }
191     if (precedence<=upper_precedence) os << ")";
192 }
193
194 void add::printraw(ostream & os) const
195 {
196     debugmsg("add printraw",LOGLEVEL_PRINT);
197
198     os << "+(";
199     for (epvector::const_iterator it=seq.begin(); it!=seq.end(); ++it) {
200         os << "(";
201         (*it).rest.bp->printraw(os);
202         os << ",";
203         (*it).coeff.bp->printraw(os);        
204         os << "),";
205     }
206     os << ",hash=" << hashvalue << ",flags=" << flags;
207     os << ")";
208 }
209
210 void add::printcsrc(ostream & os, unsigned type, unsigned upper_precedence) const
211 {
212     debugmsg("add print csrc", LOGLEVEL_PRINT);
213     if (precedence <= upper_precedence)
214         os << "(";
215
216     // Print arguments, separated by "+"
217     epvector::const_iterator it = seq.begin();
218     epvector::const_iterator itend = seq.end();
219     while (it != itend) {
220
221         // If the coefficient is -1, it is replaced by a single minus sign
222         if (it->coeff.compare(numONE()) == 0) {
223             it->rest.bp->printcsrc(os, type, precedence);
224         } else if (it->coeff.compare(numMINUSONE()) == 0) {
225             os << "-";
226             it->rest.bp->printcsrc(os, type, precedence);
227         } else if (ex_to_numeric(it->coeff).numer().compare(numONE()) == 0) {
228             it->rest.bp->printcsrc(os, type, precedence);
229             os << "/";
230             ex_to_numeric(it->coeff).denom().printcsrc(os, type, precedence);
231         } else if (ex_to_numeric(it->coeff).numer().compare(numMINUSONE()) == 0) {
232             os << "-";
233             it->rest.bp->printcsrc(os, type, precedence);
234             os << "/";
235             ex_to_numeric(it->coeff).denom().printcsrc(os, type, precedence);
236         } else {
237             it->coeff.bp->printcsrc(os, type, precedence);
238             os << "*";
239             it->rest.bp->printcsrc(os, type, precedence);
240         }
241
242         // Separator is "+", except if the following expression would have a leading minus sign
243         it++;
244         if (it != itend && !(it->coeff.compare(numZERO()) < 0 || (it->coeff.compare(numONE()) == 0 && is_ex_exactly_of_type(it->rest, numeric) && it->rest.compare(numZERO()) < 0)))
245             os << "+";
246     }
247     
248     if (!overall_coeff.is_equal(exZERO())) {
249         if (overall_coeff.info(info_flags::positive)) os << '+';
250         overall_coeff.bp->printcsrc(os,type,precedence);
251     }
252     
253     if (precedence <= upper_precedence)
254         os << ")";
255 }
256
257 bool add::info(unsigned inf) const
258 {
259     // TODO: optimize
260     if (inf==info_flags::polynomial ||
261         inf==info_flags::integer_polynomial ||
262         inf==info_flags::cinteger_polynomial ||
263         inf==info_flags::rational_polynomial ||
264         inf==info_flags::crational_polynomial ||
265         inf==info_flags::rational_function) {
266         for (epvector::const_iterator it=seq.begin(); it!=seq.end(); ++it) {
267             if (!(recombine_pair_to_ex(*it).info(inf)))
268                 return false;
269         }
270         return overall_coeff.info(inf);
271     } else {
272         return expairseq::info(inf);
273     }
274 }
275
276 int add::degree(symbol const & s) const
277 {
278     int deg=INT_MIN;
279     if (!overall_coeff.is_equal(exZERO())) {
280         deg=0;
281     }
282     int cur_deg;
283     for (epvector::const_iterator cit=seq.begin(); cit!=seq.end(); ++cit) {
284         cur_deg=(*cit).rest.degree(s);
285         if (cur_deg>deg) deg=cur_deg;
286     }
287     return deg;
288 }
289
290 int add::ldegree(symbol const & s) const
291 {
292     int deg=INT_MAX;
293     if (!overall_coeff.is_equal(exZERO())) {
294         deg=0;
295     }
296     int cur_deg;
297     for (epvector::const_iterator cit=seq.begin(); cit!=seq.end(); ++cit) {
298         cur_deg=(*cit).rest.ldegree(s);
299         if (cur_deg<deg) deg=cur_deg;
300     }
301     return deg;
302 }
303
304 ex add::coeff(symbol const & s, int const n) const
305 {
306     epvector coeffseq;
307     coeffseq.reserve(seq.size());
308
309     epvector::const_iterator it=seq.begin();
310     while (it!=seq.end()) {
311         coeffseq.push_back(combine_ex_with_coeff_to_pair((*it).rest.coeff(s,n),
312                                                          (*it).coeff));
313         ++it;
314     }
315     if (n==0) {
316         return (new add(coeffseq,overall_coeff))->setflag(status_flags::dynallocated);
317     }
318     return (new add(coeffseq))->setflag(status_flags::dynallocated);
319 }
320
321 ex add::eval(int level) const
322 {
323     // simplifications: +(;c) -> c
324     //                  +(x;1) -> x
325
326     debugmsg("add eval",LOGLEVEL_MEMBER_FUNCTION);
327
328     epvector * evaled_seqp=evalchildren(level);
329     if (evaled_seqp!=0) {
330         // do more evaluation later
331         return (new add(evaled_seqp,overall_coeff))->
332                    setflag(status_flags::dynallocated);
333     }
334
335 #ifdef DO_GINAC_ASSERT
336     for (epvector::const_iterator cit=seq.begin(); cit!=seq.end(); ++cit) {
337         GINAC_ASSERT(!is_ex_exactly_of_type((*cit).rest,add));
338         if (is_ex_exactly_of_type((*cit).rest,numeric)) {
339             dbgprint();
340         }
341         GINAC_ASSERT(!is_ex_exactly_of_type((*cit).rest,numeric));
342     }
343 #endif // def DO_GINAC_ASSERT
344
345     if (flags & status_flags::evaluated) {
346         GINAC_ASSERT(seq.size()>0);
347         GINAC_ASSERT((seq.size()>1)||!overall_coeff.is_equal(exZERO()));
348         return *this;
349     }
350
351     int seq_size=seq.size();
352     if (seq_size==0) {
353         // +(;c) -> c
354         return overall_coeff;
355     } else if ((seq_size==1)&&overall_coeff.is_equal(exZERO())) {
356         // +(x;0) -> x
357         return recombine_pair_to_ex(*(seq.begin()));
358     }
359     return this->hold();
360 }
361
362 exvector add::get_indices(void) const
363 {
364     // FIXME: all terms in the sum should have the same indices (compatible
365     // tensors) however this is not checked, since there is no function yet
366     // which compares indices (idxvector can be unsorted)
367     if (seq.size()==0) {
368         return exvector();
369     }
370     return (seq.begin())->rest.get_indices();
371 }    
372
373 ex add::simplify_ncmul(exvector const & v) const
374 {
375     if (seq.size()==0) {
376         return expairseq::simplify_ncmul(v);
377     }
378     return (*seq.begin()).rest.simplify_ncmul(v);
379 }    
380
381 // protected
382
383 int add::compare_same_type(basic const & other) const
384 {
385     return expairseq::compare_same_type(other);
386 }
387
388 bool add::is_equal_same_type(basic const & other) const
389 {
390     return expairseq::is_equal_same_type(other);
391 }
392
393 unsigned add::return_type(void) const
394 {
395     if (seq.size()==0) {
396         return return_types::commutative;
397     }
398     return (*seq.begin()).rest.return_type();
399 }
400    
401 unsigned add::return_type_tinfo(void) const
402 {
403     if (seq.size()==0) {
404         return tinfo_key;
405     }
406     return (*seq.begin()).rest.return_type_tinfo();
407 }
408
409 ex add::thisexpairseq(epvector const & v, ex const & oc) const
410 {
411     return (new add(v,oc))->setflag(status_flags::dynallocated);
412 }
413
414 ex add::thisexpairseq(epvector * vp, ex const & oc) const
415 {
416     return (new add(vp,oc))->setflag(status_flags::dynallocated);
417 }
418
419 expair add::split_ex_to_pair(ex const & e) const
420 {
421     if (is_ex_exactly_of_type(e,mul)) {
422         mul const & mulref=ex_to_mul(e);
423         ex numfactor=mulref.overall_coeff;
424         // mul * mulcopyp=static_cast<mul *>(mulref.duplicate());
425         mul * mulcopyp=new mul(mulref);
426         mulcopyp->overall_coeff=exONE();
427         mulcopyp->clearflag(status_flags::evaluated);
428         mulcopyp->clearflag(status_flags::hash_calculated);
429         return expair(mulcopyp->setflag(status_flags::dynallocated),numfactor);
430     }
431     return expair(e,exONE());
432 }
433
434 expair add::combine_ex_with_coeff_to_pair(ex const & e,
435                                           ex const & c) const
436 {
437     GINAC_ASSERT(is_ex_exactly_of_type(c,numeric));
438     if (is_ex_exactly_of_type(e,mul)) {
439         mul const & mulref=ex_to_mul(e);
440         ex numfactor=mulref.overall_coeff;
441         //mul * mulcopyp=static_cast<mul *>(mulref.duplicate());
442         mul * mulcopyp=new mul(mulref);
443         mulcopyp->overall_coeff=exONE();
444         mulcopyp->clearflag(status_flags::evaluated);
445         mulcopyp->clearflag(status_flags::hash_calculated);
446         if (are_ex_trivially_equal(c,exONE())) {
447             return expair(mulcopyp->setflag(status_flags::dynallocated),numfactor);
448         } else if (are_ex_trivially_equal(numfactor,exONE())) {
449             return expair(mulcopyp->setflag(status_flags::dynallocated),c);
450         }
451         return expair(mulcopyp->setflag(status_flags::dynallocated),
452                           ex_to_numeric(numfactor).mul_dyn(ex_to_numeric(c)));
453     } else if (is_ex_exactly_of_type(e,numeric)) {
454         if (are_ex_trivially_equal(c,exONE())) {
455             return expair(e,exONE());
456         }
457         return expair(ex_to_numeric(e).mul_dyn(ex_to_numeric(c)),exONE());
458     }
459     return expair(e,c);
460 }
461     
462 expair add::combine_pair_with_coeff_to_pair(expair const & p,
463                                             ex const & c) const
464 {
465     GINAC_ASSERT(is_ex_exactly_of_type(p.coeff,numeric));
466     GINAC_ASSERT(is_ex_exactly_of_type(c,numeric));
467
468     if (is_ex_exactly_of_type(p.rest,numeric)) {
469         GINAC_ASSERT(ex_to_numeric(p.coeff).is_equal(numONE())); // should be normalized
470         return expair(ex_to_numeric(p.rest).mul_dyn(ex_to_numeric(c)),exONE());
471     }
472
473     return expair(p.rest,ex_to_numeric(p.coeff).mul_dyn(ex_to_numeric(c)));
474 }
475     
476 ex add::recombine_pair_to_ex(expair const & p) const
477 {
478     //if (p.coeff.compare(exONE())==0) {
479     //if (are_ex_trivially_equal(p.coeff,exONE())) {
480     if (ex_to_numeric(p.coeff).is_equal(numONE())) {
481         return p.rest;
482     } else {
483         return p.rest*p.coeff;
484     }
485 }
486
487 ex add::expand(unsigned options) const
488 {
489     epvector * vp=expandchildren(options);
490     if (vp==0) {
491         return *this;
492     }
493     return (new add(vp,overall_coeff))->setflag(status_flags::expanded    |
494                                                 status_flags::dynallocated );
495 }
496
497 //////////
498 // new virtual functions which can be overridden by derived classes
499 //////////
500
501 // none
502
503 //////////
504 // non-virtual functions in this class
505 //////////
506
507 // none
508
509 //////////
510 // static member variables
511 //////////
512
513 // protected
514
515 unsigned add::precedence=40;
516
517 //////////
518 // global constants
519 //////////
520
521 const add some_add;
522 type_info const & typeid_add=typeid(some_add);
523
524 #ifndef NO_GINAC_NAMESPACE
525 } // namespace GiNaC
526 #endif // ndef NO_GINAC_NAMESPACE