2b1c5d77ab84a3a54f2d089e406a037b5d09b553
[ginac.git] / ginac / add.cpp
1 /** @file add.cpp
2  *
3  *  Implementation of GiNaC's sums of expressions. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2000 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
21  */
22
23 #include <iostream>
24 #include <stdexcept>
25
26 #include "add.h"
27 #include "mul.h"
28 #include "archive.h"
29 #include "debugmsg.h"
30 #include "utils.h"
31
32 #ifndef NO_NAMESPACE_GINAC
33 namespace GiNaC {
34 #endif // ndef NO_NAMESPACE_GINAC
35
36 GINAC_IMPLEMENT_REGISTERED_CLASS(add, expairseq)
37
38 //////////
39 // default constructor, destructor, copy constructor, assignment operator and helpers
40 //////////
41
42 // public
43
44 add::add()
45 {
46         debugmsg("add default constructor",LOGLEVEL_CONSTRUCT);
47         tinfo_key = TINFO_add;
48 }
49
50 add::~add()
51 {
52         debugmsg("add destructor",LOGLEVEL_DESTRUCT);
53         destroy(false);
54 }
55
56 add::add(const add & other)
57 {
58         debugmsg("add copy constructor",LOGLEVEL_CONSTRUCT);
59         copy(other);
60 }
61
62 const add & add::operator=(const add & other)
63 {
64         debugmsg("add operator=",LOGLEVEL_ASSIGNMENT);
65         if (this != &other) {
66                 destroy(true);
67                 copy(other);
68         }
69         return *this;
70 }
71
72 // protected
73
74 void add::copy(const add & other)
75 {
76         inherited::copy(other);
77 }
78
79 void add::destroy(bool call_parent)
80 {
81         if (call_parent) inherited::destroy(call_parent);
82 }
83
84 //////////
85 // other constructors
86 //////////
87
88 // public
89
90 add::add(const ex & lh, const ex & rh)
91 {
92         debugmsg("add constructor from ex,ex",LOGLEVEL_CONSTRUCT);
93         tinfo_key = TINFO_add;
94         overall_coeff = _ex0();
95         construct_from_2_ex(lh,rh);
96         GINAC_ASSERT(is_canonical());
97 }
98
99 add::add(const exvector & v)
100 {
101         debugmsg("add constructor from exvector",LOGLEVEL_CONSTRUCT);
102         tinfo_key = TINFO_add;
103         overall_coeff = _ex0();
104         construct_from_exvector(v);
105         GINAC_ASSERT(is_canonical());
106 }
107
108 /*
109 add::add(const epvector & v, bool do_not_canonicalize)
110 {
111         debugmsg("add constructor from epvector,bool",LOGLEVEL_CONSTRUCT);
112         tinfo_key = TINFO_add;
113         if (do_not_canonicalize) {
114                 seq=v;
115 #ifdef EXPAIRSEQ_USE_HASHTAB
116                 combine_same_terms(); // to build hashtab
117 #endif // def EXPAIRSEQ_USE_HASHTAB
118         } else {
119                 construct_from_epvector(v);
120         }
121         GINAC_ASSERT(is_canonical());
122 }
123 */
124
125 add::add(const epvector & v)
126 {
127         debugmsg("add constructor from epvector",LOGLEVEL_CONSTRUCT);
128         tinfo_key = TINFO_add;
129         overall_coeff = _ex0();
130         construct_from_epvector(v);
131         GINAC_ASSERT(is_canonical());
132 }
133
134 add::add(const epvector & v, const ex & oc)
135 {
136         debugmsg("add constructor from epvector,ex",LOGLEVEL_CONSTRUCT);
137         tinfo_key = TINFO_add;
138         overall_coeff = oc;
139         construct_from_epvector(v);
140         GINAC_ASSERT(is_canonical());
141 }
142
143 add::add(epvector * vp, const ex & oc)
144 {
145         debugmsg("add constructor from epvector *,ex",LOGLEVEL_CONSTRUCT);
146         tinfo_key = TINFO_add;
147         GINAC_ASSERT(vp!=0);
148         overall_coeff = oc;
149         construct_from_epvector(*vp);
150         delete vp;
151         GINAC_ASSERT(is_canonical());
152 }
153
154 //////////
155 // archiving
156 //////////
157
158 /** Construct object from archive_node. */
159 add::add(const archive_node &n, const lst &sym_lst) : inherited(n, sym_lst)
160 {
161         debugmsg("add constructor from archive_node", LOGLEVEL_CONSTRUCT);
162 }
163
164 /** Unarchive the object. */
165 ex add::unarchive(const archive_node &n, const lst &sym_lst)
166 {
167         return (new add(n, sym_lst))->setflag(status_flags::dynallocated);
168 }
169
170 /** Archive the object. */
171 void add::archive(archive_node &n) const
172 {
173         inherited::archive(n);
174 }
175
176 //////////
177 // functions overriding virtual functions from base classes
178 //////////
179
180 // public
181
182 basic * add::duplicate() const
183 {
184         debugmsg("add duplicate",LOGLEVEL_DUPLICATE);
185         return new add(*this);
186 }
187
188 void add::print(std::ostream & os, unsigned upper_precedence) const
189 {
190         debugmsg("add print",LOGLEVEL_PRINT);
191         if (precedence<=upper_precedence) os << "(";
192         numeric coeff;
193         bool first = true;
194         // first print the overall numeric coefficient, if present:
195         if (!overall_coeff.is_zero()) {
196                 os << overall_coeff;
197                 first = false;
198         }
199         // then proceed with the remaining factors:
200         for (epvector::const_iterator cit=seq.begin(); cit!=seq.end(); ++cit) {
201                 coeff = ex_to_numeric(cit->coeff);
202                 if (!first) {
203                         if (coeff.csgn()==-1) os << '-'; else os << '+';
204                 } else {
205                         if (coeff.csgn()==-1) os << '-';
206                         first = false;
207                 }
208                 if (!coeff.is_equal(_num1()) &&
209                         !coeff.is_equal(_num_1())) {
210                         if (coeff.is_rational()) {
211                                 if (coeff.is_negative())
212                                         os << -coeff;
213                                 else
214                                         os << coeff;
215                         } else {
216                                 if (coeff.csgn()==-1)
217                                         (-coeff).print(os, precedence);
218                                 else
219                                         coeff.print(os, precedence);
220                         }
221                         os << '*';
222                 }
223                 os << cit->rest;
224         }
225         if (precedence<=upper_precedence) os << ")";
226 }
227
228 void add::printraw(std::ostream & os) const
229 {
230         debugmsg("add printraw",LOGLEVEL_PRINT);
231
232         os << "+(";
233         for (epvector::const_iterator it=seq.begin(); it!=seq.end(); ++it) {
234                 os << "(";
235                 (*it).rest.bp->printraw(os);
236                 os << ",";
237                 (*it).coeff.bp->printraw(os);        
238                 os << "),";
239         }
240         os << ",hash=" << hashvalue << ",flags=" << flags;
241         os << ")";
242 }
243
244 void add::printcsrc(std::ostream & os, unsigned type, unsigned upper_precedence) const
245 {
246         debugmsg("add print csrc", LOGLEVEL_PRINT);
247         if (precedence <= upper_precedence)
248                 os << "(";
249
250         // Print arguments, separated by "+"
251         epvector::const_iterator it = seq.begin();
252         epvector::const_iterator itend = seq.end();
253         while (it != itend) {
254
255                 // If the coefficient is -1, it is replaced by a single minus sign
256                 if (it->coeff.compare(_num1()) == 0) {
257                         it->rest.bp->printcsrc(os, type, precedence);
258                 } else if (it->coeff.compare(_num_1()) == 0) {
259                         os << "-";
260                         it->rest.bp->printcsrc(os, type, precedence);
261                 } else if (ex_to_numeric(it->coeff).numer().compare(_num1()) == 0) {
262                         it->rest.bp->printcsrc(os, type, precedence);
263                         os << "/";
264                         ex_to_numeric(it->coeff).denom().printcsrc(os, type, precedence);
265                 } else if (ex_to_numeric(it->coeff).numer().compare(_num_1()) == 0) {
266                         os << "-";
267                         it->rest.bp->printcsrc(os, type, precedence);
268                         os << "/";
269                         ex_to_numeric(it->coeff).denom().printcsrc(os, type, precedence);
270                 } else {
271                         it->coeff.bp->printcsrc(os, type, precedence);
272                         os << "*";
273                         it->rest.bp->printcsrc(os, type, precedence);
274                 }
275
276                 // Separator is "+", except if the following expression would have a leading minus sign
277                 it++;
278                 if (it != itend && !(it->coeff.compare(_num0()) < 0 || (it->coeff.compare(_num1()) == 0 && is_ex_exactly_of_type(it->rest, numeric) && it->rest.compare(_num0()) < 0)))
279                         os << "+";
280         }
281         
282         if (!overall_coeff.is_equal(_ex0())) {
283                 if (overall_coeff.info(info_flags::positive)) os << '+';
284                 overall_coeff.bp->printcsrc(os,type,precedence);
285         }
286         
287         if (precedence <= upper_precedence)
288                 os << ")";
289 }
290
291 bool add::info(unsigned inf) const
292 {
293         switch (inf) {
294                 case info_flags::polynomial:
295                 case info_flags::integer_polynomial:
296                 case info_flags::cinteger_polynomial:
297                 case info_flags::rational_polynomial:
298                 case info_flags::crational_polynomial:
299                 case info_flags::rational_function: {
300                         for (epvector::const_iterator i=seq.begin(); i!=seq.end(); ++i) {
301                                 if (!(recombine_pair_to_ex(*i).info(inf)))
302                                         return false;
303                         }
304                         return overall_coeff.info(inf);
305                 }
306                 case info_flags::algebraic: {
307                         for (epvector::const_iterator i=seq.begin(); i!=seq.end(); ++i) {
308                                 if ((recombine_pair_to_ex(*i).info(inf)))
309                                         return true;
310                         }
311                         return false;
312                 }
313         }
314         return inherited::info(inf);
315 }
316
317 int add::degree(const symbol & s) const
318 {
319         int deg = INT_MIN;
320         if (!overall_coeff.is_equal(_ex0())) {
321                 deg = 0;
322         }
323         int cur_deg;
324         for (epvector::const_iterator cit=seq.begin(); cit!=seq.end(); ++cit) {
325                 cur_deg=(*cit).rest.degree(s);
326                 if (cur_deg>deg) deg=cur_deg;
327         }
328         return deg;
329 }
330
331 int add::ldegree(const symbol & s) const
332 {
333         int deg = INT_MAX;
334         if (!overall_coeff.is_equal(_ex0())) {
335                 deg = 0;
336         }
337         int cur_deg;
338         for (epvector::const_iterator cit=seq.begin(); cit!=seq.end(); ++cit) {
339                 cur_deg = (*cit).rest.ldegree(s);
340                 if (cur_deg<deg) deg=cur_deg;
341         }
342         return deg;
343 }
344
345 ex add::coeff(const symbol & s, int n) const
346 {
347         epvector coeffseq;
348         coeffseq.reserve(seq.size());
349
350         epvector::const_iterator it=seq.begin();
351         while (it!=seq.end()) {
352                 coeffseq.push_back(combine_ex_with_coeff_to_pair((*it).rest.coeff(s,n),
353                                                                                                                  (*it).coeff));
354                 ++it;
355         }
356         if (n==0) {
357                 return (new add(coeffseq,overall_coeff))->setflag(status_flags::dynallocated);
358         }
359         return (new add(coeffseq))->setflag(status_flags::dynallocated);
360 }
361
362 ex add::eval(int level) const
363 {
364         // simplifications: +(;c) -> c
365         //                  +(x;1) -> x
366
367         debugmsg("add eval",LOGLEVEL_MEMBER_FUNCTION);
368
369         epvector * evaled_seqp=evalchildren(level);
370         if (evaled_seqp!=0) {
371                 // do more evaluation later
372                 return (new add(evaled_seqp,overall_coeff))->
373                                    setflag(status_flags::dynallocated);
374         }
375         
376 #ifdef DO_GINAC_ASSERT
377         for (epvector::const_iterator cit=seq.begin(); cit!=seq.end(); ++cit) {
378                 GINAC_ASSERT(!is_ex_exactly_of_type((*cit).rest,add));
379                 if (is_ex_exactly_of_type((*cit).rest,numeric)) {
380                         dbgprint();
381                 }
382                 GINAC_ASSERT(!is_ex_exactly_of_type((*cit).rest,numeric));
383         }
384 #endif // def DO_GINAC_ASSERT
385         
386         if (flags & status_flags::evaluated) {
387                 GINAC_ASSERT(seq.size()>0);
388                 GINAC_ASSERT((seq.size()>1)||!overall_coeff.is_equal(_ex0()));
389                 return *this;
390         }
391         
392         int seq_size=seq.size();
393         if (seq_size==0) {
394                 // +(;c) -> c
395                 return overall_coeff;
396         } else if ((seq_size==1)&&overall_coeff.is_equal(_ex0())) {
397                 // +(x;0) -> x
398                 return recombine_pair_to_ex(*(seq.begin()));
399         }
400         return this->hold();
401 }
402
403 exvector add::get_indices(void) const
404 {
405         // FIXME: all terms in the sum should have the same indices (compatible
406         // tensors) however this is not checked, since there is no function yet
407         // which compares indices (idxvector can be unsorted)
408         if (seq.size()==0) {
409                 return exvector();
410         }
411         return (seq.begin())->rest.get_indices();
412 }    
413
414 ex add::simplify_ncmul(const exvector & v) const
415 {
416         if (seq.size()==0) {
417                 return inherited::simplify_ncmul(v);
418         }
419         return (*seq.begin()).rest.simplify_ncmul(v);
420 }    
421
422 // protected
423
424 /** Implementation of ex::diff() for a sum. It differentiates each term.
425  *  @see ex::diff */
426 ex add::derivative(const symbol & s) const
427 {
428         // D(a+b+c)=D(a)+D(b)+D(c)
429         return (new add(diffchildren(s)))->setflag(status_flags::dynallocated);
430 }
431
432 int add::compare_same_type(const basic & other) const
433 {
434         return inherited::compare_same_type(other);
435 }
436
437 bool add::is_equal_same_type(const basic & other) const
438 {
439         return inherited::is_equal_same_type(other);
440 }
441
442 unsigned add::return_type(void) const
443 {
444         if (seq.size()==0) {
445                 return return_types::commutative;
446         }
447         return (*seq.begin()).rest.return_type();
448 }
449    
450 unsigned add::return_type_tinfo(void) const
451 {
452         if (seq.size()==0) {
453                 return tinfo_key;
454         }
455         return (*seq.begin()).rest.return_type_tinfo();
456 }
457
458 ex add::thisexpairseq(const epvector & v, const ex & oc) const
459 {
460         return (new add(v,oc))->setflag(status_flags::dynallocated);
461 }
462
463 ex add::thisexpairseq(epvector * vp, const ex & oc) const
464 {
465         return (new add(vp,oc))->setflag(status_flags::dynallocated);
466 }
467
468 expair add::split_ex_to_pair(const ex & e) const
469 {
470         if (is_ex_exactly_of_type(e,mul)) {
471                 const mul & mulref=ex_to_mul(e);
472                 ex numfactor=mulref.overall_coeff;
473                 // mul * mulcopyp=static_cast<mul *>(mulref.duplicate());
474                 mul * mulcopyp=new mul(mulref);
475                 mulcopyp->overall_coeff=_ex1();
476                 mulcopyp->clearflag(status_flags::evaluated);
477                 mulcopyp->clearflag(status_flags::hash_calculated);
478                 return expair(mulcopyp->setflag(status_flags::dynallocated),numfactor);
479         }
480         return expair(e,_ex1());
481 }
482
483 expair add::combine_ex_with_coeff_to_pair(const ex & e,
484                                                                                   const ex & c) const
485 {
486         GINAC_ASSERT(is_ex_exactly_of_type(c, numeric));
487         ex one = _ex1();
488         if (is_ex_exactly_of_type(e, mul)) {
489                 const mul &mulref = ex_to_mul(e);
490                 ex numfactor = mulref.overall_coeff;
491                 mul *mulcopyp = new mul(mulref);
492                 mulcopyp->overall_coeff = one;
493                 mulcopyp->clearflag(status_flags::evaluated);
494                 mulcopyp->clearflag(status_flags::hash_calculated);
495                 mulcopyp->setflag(status_flags::dynallocated);
496                 if (are_ex_trivially_equal(c, one)) {
497                         return expair(*mulcopyp, numfactor);
498                 } else if (are_ex_trivially_equal(numfactor, one)) {
499                         return expair(*mulcopyp, c);
500                 }
501                 return expair(*mulcopyp, ex_to_numeric(numfactor).mul_dyn(ex_to_numeric(c)));
502         } else if (is_ex_exactly_of_type(e, numeric)) {
503                 if (are_ex_trivially_equal(c, one)) {
504                         return expair(e, one);
505                 }
506                 return expair(ex_to_numeric(e).mul_dyn(ex_to_numeric(c)), one);
507         }
508         return expair(e, c);
509 }
510         
511 expair add::combine_pair_with_coeff_to_pair(const expair & p,
512                                                                                         const ex & c) const
513 {
514         GINAC_ASSERT(is_ex_exactly_of_type(p.coeff,numeric));
515         GINAC_ASSERT(is_ex_exactly_of_type(c,numeric));
516
517         if (is_ex_exactly_of_type(p.rest,numeric)) {
518                 GINAC_ASSERT(ex_to_numeric(p.coeff).is_equal(_num1())); // should be normalized
519                 return expair(ex_to_numeric(p.rest).mul_dyn(ex_to_numeric(c)),_ex1());
520         }
521
522         return expair(p.rest,ex_to_numeric(p.coeff).mul_dyn(ex_to_numeric(c)));
523 }
524         
525 ex add::recombine_pair_to_ex(const expair & p) const
526 {
527         if (ex_to_numeric(p.coeff).is_equal(_num1()))
528                 return p.rest;
529         else
530                 return p.rest*p.coeff;
531 }
532
533 ex add::expand(unsigned options) const
534 {
535         if (flags & status_flags::expanded)
536                 return *this;
537         
538         epvector * vp = expandchildren(options);
539         if (vp==0)
540                 return *this;
541         
542         return (new add(vp,overall_coeff))->setflag(status_flags::expanded | status_flags::dynallocated);
543 }
544
545 //////////
546 // new virtual functions which can be overridden by derived classes
547 //////////
548
549 // none
550
551 //////////
552 // non-virtual functions in this class
553 //////////
554
555 // none
556
557 //////////
558 // static member variables
559 //////////
560
561 // protected
562
563 unsigned add::precedence = 40;
564
565 //////////
566 // global constants
567 //////////
568
569 const add some_add;
570 const std::type_info & typeid_add = typeid(some_add);
571
572 #ifndef NO_NAMESPACE_GINAC
573 } // namespace GiNaC
574 #endif // ndef NO_NAMESPACE_GINAC