- Derivatives are now assembled in a slightly different manner (i.e. they
[ginac.git] / ginac / add.cpp
1 /** @file add.cpp
2  *
3  *  Implementation of GiNaC's sums of expressions. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2000 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
21  */
22
23 #include <iostream>
24 #include <stdexcept>
25
26 #include "add.h"
27 #include "mul.h"
28 #include "archive.h"
29 #include "debugmsg.h"
30 #include "utils.h"
31
32 #ifndef NO_NAMESPACE_GINAC
33 namespace GiNaC {
34 #endif // ndef NO_NAMESPACE_GINAC
35
36 GINAC_IMPLEMENT_REGISTERED_CLASS(add, expairseq)
37
38 //////////
39 // default constructor, destructor, copy constructor assignment operator and helpers
40 //////////
41
42 // public
43
44 add::add()
45 {
46     debugmsg("add default constructor",LOGLEVEL_CONSTRUCT);
47     tinfo_key = TINFO_add;
48 }
49
50 add::~add()
51 {
52     debugmsg("add destructor",LOGLEVEL_DESTRUCT);
53     destroy(0);
54 }
55
56 add::add(const add & other)
57 {
58     debugmsg("add copy constructor",LOGLEVEL_CONSTRUCT);
59     copy(other);
60 }
61
62 const add & add::operator=(const add & other)
63 {
64     debugmsg("add operator=",LOGLEVEL_ASSIGNMENT);
65     if (this != &other) {
66         destroy(1);
67         copy(other);
68     }
69     return *this;
70 }
71
72 // protected
73
74 void add::copy(const add & other)
75 {
76     inherited::copy(other);
77 }
78
79 void add::destroy(bool call_parent)
80 {
81     if (call_parent) inherited::destroy(call_parent);
82 }
83
84 //////////
85 // other constructors
86 //////////
87
88 // public
89
90 add::add(const ex & lh, const ex & rh)
91 {
92     debugmsg("add constructor from ex,ex",LOGLEVEL_CONSTRUCT);
93     tinfo_key = TINFO_add;
94     overall_coeff = _ex0();
95     construct_from_2_ex(lh,rh);
96     GINAC_ASSERT(is_canonical());
97 }
98
99 add::add(const exvector & v)
100 {
101     debugmsg("add constructor from exvector",LOGLEVEL_CONSTRUCT);
102     tinfo_key = TINFO_add;
103     overall_coeff = _ex0();
104     construct_from_exvector(v);
105     GINAC_ASSERT(is_canonical());
106 }
107
108 /*
109 add::add(const epvector & v, bool do_not_canonicalize)
110 {
111     debugmsg("add constructor from epvector,bool",LOGLEVEL_CONSTRUCT);
112     tinfo_key = TINFO_add;
113     if (do_not_canonicalize) {
114         seq=v;
115 #ifdef EXPAIRSEQ_USE_HASHTAB
116         combine_same_terms(); // to build hashtab
117 #endif // def EXPAIRSEQ_USE_HASHTAB
118     } else {
119         construct_from_epvector(v);
120     }
121     GINAC_ASSERT(is_canonical());
122 }
123 */
124
125 add::add(const epvector & v)
126 {
127     debugmsg("add constructor from epvector",LOGLEVEL_CONSTRUCT);
128     tinfo_key = TINFO_add;
129     overall_coeff = _ex0();
130     construct_from_epvector(v);
131     GINAC_ASSERT(is_canonical());
132 }
133
134 add::add(const epvector & v, const ex & oc)
135 {
136     debugmsg("add constructor from epvector,ex",LOGLEVEL_CONSTRUCT);
137     tinfo_key = TINFO_add;
138     overall_coeff = oc;
139     construct_from_epvector(v);
140     GINAC_ASSERT(is_canonical());
141 }
142
143 add::add(epvector * vp, const ex & oc)
144 {
145     debugmsg("add constructor from epvector *,ex",LOGLEVEL_CONSTRUCT);
146     tinfo_key = TINFO_add;
147     GINAC_ASSERT(vp!=0);
148     overall_coeff = oc;
149     construct_from_epvector(*vp);
150     delete vp;
151     GINAC_ASSERT(is_canonical());
152 }
153
154 //////////
155 // archiving
156 //////////
157
158 /** Construct object from archive_node. */
159 add::add(const archive_node &n, const lst &sym_lst) : inherited(n, sym_lst)
160 {
161     debugmsg("add constructor from archive_node", LOGLEVEL_CONSTRUCT);
162 }
163
164 /** Unarchive the object. */
165 ex add::unarchive(const archive_node &n, const lst &sym_lst)
166 {
167     return (new add(n, sym_lst))->setflag(status_flags::dynallocated);
168 }
169
170 /** Archive the object. */
171 void add::archive(archive_node &n) const
172 {
173     inherited::archive(n);
174 }
175
176 //////////
177 // functions overriding virtual functions from bases classes
178 //////////
179
180 // public
181
182 basic * add::duplicate() const
183 {
184     debugmsg("add duplicate",LOGLEVEL_DUPLICATE);
185     return new add(*this);
186 }
187
188 void add::print(std::ostream & os, unsigned upper_precedence) const
189 {
190     debugmsg("add print",LOGLEVEL_PRINT);
191     if (precedence<=upper_precedence) os << "(";
192     numeric coeff;
193     bool first = true;
194     // first print the overall numeric coefficient, if present:
195     if (!overall_coeff.is_zero()) {
196         os << overall_coeff;
197         first = false;
198     }
199     // then proceed with the remaining factors:
200     for (epvector::const_iterator cit=seq.begin(); cit!=seq.end(); ++cit) {
201         coeff = ex_to_numeric(cit->coeff);
202         if (!first) {
203             if (coeff.csgn()==-1) os << '-'; else os << '+';
204         } else {
205             if (coeff.csgn()==-1) os << '-';
206             first = false;
207         }
208         if (!coeff.is_equal(_num1()) &&
209             !coeff.is_equal(_num_1())) {
210             if (coeff.is_rational()) {
211                 if (coeff.is_negative())
212                     os << -coeff;
213                 else
214                     os << coeff;
215             } else {
216                 if (coeff.csgn()==-1)
217                     (-coeff).print(os, precedence);
218                 else
219                     coeff.print(os, precedence);
220             }
221             os << '*';
222         }
223         os << cit->rest;
224     }
225     if (precedence<=upper_precedence) os << ")";
226 }
227
228 void add::printraw(std::ostream & os) const
229 {
230     debugmsg("add printraw",LOGLEVEL_PRINT);
231
232     os << "+(";
233     for (epvector::const_iterator it=seq.begin(); it!=seq.end(); ++it) {
234         os << "(";
235         (*it).rest.bp->printraw(os);
236         os << ",";
237         (*it).coeff.bp->printraw(os);        
238         os << "),";
239     }
240     os << ",hash=" << hashvalue << ",flags=" << flags;
241     os << ")";
242 }
243
244 void add::printcsrc(std::ostream & os, unsigned type, unsigned upper_precedence) const
245 {
246     debugmsg("add print csrc", LOGLEVEL_PRINT);
247     if (precedence <= upper_precedence)
248         os << "(";
249
250     // Print arguments, separated by "+"
251     epvector::const_iterator it = seq.begin();
252     epvector::const_iterator itend = seq.end();
253     while (it != itend) {
254
255         // If the coefficient is -1, it is replaced by a single minus sign
256         if (it->coeff.compare(_num1()) == 0) {
257             it->rest.bp->printcsrc(os, type, precedence);
258         } else if (it->coeff.compare(_num_1()) == 0) {
259             os << "-";
260             it->rest.bp->printcsrc(os, type, precedence);
261         } else if (ex_to_numeric(it->coeff).numer().compare(_num1()) == 0) {
262             it->rest.bp->printcsrc(os, type, precedence);
263             os << "/";
264             ex_to_numeric(it->coeff).denom().printcsrc(os, type, precedence);
265         } else if (ex_to_numeric(it->coeff).numer().compare(_num_1()) == 0) {
266             os << "-";
267             it->rest.bp->printcsrc(os, type, precedence);
268             os << "/";
269             ex_to_numeric(it->coeff).denom().printcsrc(os, type, precedence);
270         } else {
271             it->coeff.bp->printcsrc(os, type, precedence);
272             os << "*";
273             it->rest.bp->printcsrc(os, type, precedence);
274         }
275
276         // Separator is "+", except if the following expression would have a leading minus sign
277         it++;
278         if (it != itend && !(it->coeff.compare(_num0()) < 0 || (it->coeff.compare(_num1()) == 0 && is_ex_exactly_of_type(it->rest, numeric) && it->rest.compare(_num0()) < 0)))
279             os << "+";
280     }
281     
282     if (!overall_coeff.is_equal(_ex0())) {
283         if (overall_coeff.info(info_flags::positive)) os << '+';
284         overall_coeff.bp->printcsrc(os,type,precedence);
285     }
286     
287     if (precedence <= upper_precedence)
288         os << ")";
289 }
290
291 bool add::info(unsigned inf) const
292 {
293     switch (inf) {
294         case info_flags::polynomial:
295         case info_flags::integer_polynomial:
296         case info_flags::cinteger_polynomial:
297         case info_flags::rational_polynomial:
298         case info_flags::crational_polynomial:
299         case info_flags::rational_function: {
300             for (epvector::const_iterator i=seq.begin(); i!=seq.end(); ++i) {
301                 if (!(recombine_pair_to_ex(*i).info(inf)))
302                     return false;
303             }
304             return overall_coeff.info(inf);
305         }
306         case info_flags::algebraic: {
307             for (epvector::const_iterator i=seq.begin(); i!=seq.end(); ++i) {
308                 if ((recombine_pair_to_ex(*i).info(inf)))
309                     return true;
310             }
311             return false;
312         }
313     }
314     return inherited::info(inf);
315 }
316
317 int add::degree(const symbol & s) const
318 {
319     int deg = INT_MIN;
320     if (!overall_coeff.is_equal(_ex0())) {
321         deg = 0;
322     }
323     int cur_deg;
324     for (epvector::const_iterator cit=seq.begin(); cit!=seq.end(); ++cit) {
325         cur_deg=(*cit).rest.degree(s);
326         if (cur_deg>deg) deg=cur_deg;
327     }
328     return deg;
329 }
330
331 int add::ldegree(const symbol & s) const
332 {
333     int deg = INT_MAX;
334     if (!overall_coeff.is_equal(_ex0())) {
335         deg = 0;
336     }
337     int cur_deg;
338     for (epvector::const_iterator cit=seq.begin(); cit!=seq.end(); ++cit) {
339         cur_deg = (*cit).rest.ldegree(s);
340         if (cur_deg<deg) deg=cur_deg;
341     }
342     return deg;
343 }
344
345 ex add::coeff(const symbol & s, int n) const
346 {
347     epvector coeffseq;
348     coeffseq.reserve(seq.size());
349
350     epvector::const_iterator it=seq.begin();
351     while (it!=seq.end()) {
352         coeffseq.push_back(combine_ex_with_coeff_to_pair((*it).rest.coeff(s,n),
353                                                          (*it).coeff));
354         ++it;
355     }
356     if (n==0) {
357         return (new add(coeffseq,overall_coeff))->setflag(status_flags::dynallocated);
358     }
359     return (new add(coeffseq))->setflag(status_flags::dynallocated);
360 }
361
362 ex add::eval(int level) const
363 {
364     // simplifications: +(;c) -> c
365     //                  +(x;1) -> x
366
367     debugmsg("add eval",LOGLEVEL_MEMBER_FUNCTION);
368
369     epvector * evaled_seqp=evalchildren(level);
370     if (evaled_seqp!=0) {
371         // do more evaluation later
372         return (new add(evaled_seqp,overall_coeff))->
373                    setflag(status_flags::dynallocated);
374     }
375     
376 #ifdef DO_GINAC_ASSERT
377     for (epvector::const_iterator cit=seq.begin(); cit!=seq.end(); ++cit) {
378         GINAC_ASSERT(!is_ex_exactly_of_type((*cit).rest,add));
379         if (is_ex_exactly_of_type((*cit).rest,numeric)) {
380             dbgprint();
381         }
382         GINAC_ASSERT(!is_ex_exactly_of_type((*cit).rest,numeric));
383     }
384 #endif // def DO_GINAC_ASSERT
385     
386     if (flags & status_flags::evaluated) {
387         GINAC_ASSERT(seq.size()>0);
388         GINAC_ASSERT((seq.size()>1)||!overall_coeff.is_equal(_ex0()));
389         return *this;
390     }
391     
392     int seq_size=seq.size();
393     if (seq_size==0) {
394         // +(;c) -> c
395         return overall_coeff;
396     } else if ((seq_size==1)&&overall_coeff.is_equal(_ex0())) {
397         // +(x;0) -> x
398         return recombine_pair_to_ex(*(seq.begin()));
399     }
400     return this->hold();
401 }
402
403 exvector add::get_indices(void) const
404 {
405     // FIXME: all terms in the sum should have the same indices (compatible
406     // tensors) however this is not checked, since there is no function yet
407     // which compares indices (idxvector can be unsorted)
408     if (seq.size()==0) {
409         return exvector();
410     }
411     return (seq.begin())->rest.get_indices();
412 }    
413
414 ex add::simplify_ncmul(const exvector & v) const
415 {
416     if (seq.size()==0) {
417         return inherited::simplify_ncmul(v);
418     }
419     return (*seq.begin()).rest.simplify_ncmul(v);
420 }    
421
422 // protected
423
424 /** Implementation of ex::diff() for a sum. It differentiates each term.
425  *  @see ex::diff */
426 ex add::derivative(const symbol & s) const
427 {
428     // D(a+b+c)=D(a)+D(b)+D(c)
429     return (new add(diffchildren(s)))->setflag(status_flags::dynallocated);
430 }
431
432 int add::compare_same_type(const basic & other) const
433 {
434     return inherited::compare_same_type(other);
435 }
436
437 bool add::is_equal_same_type(const basic & other) const
438 {
439     return inherited::is_equal_same_type(other);
440 }
441
442 unsigned add::return_type(void) const
443 {
444     if (seq.size()==0) {
445         return return_types::commutative;
446     }
447     return (*seq.begin()).rest.return_type();
448 }
449    
450 unsigned add::return_type_tinfo(void) const
451 {
452     if (seq.size()==0) {
453         return tinfo_key;
454     }
455     return (*seq.begin()).rest.return_type_tinfo();
456 }
457
458 ex add::thisexpairseq(const epvector & v, const ex & oc) const
459 {
460     return (new add(v,oc))->setflag(status_flags::dynallocated);
461 }
462
463 ex add::thisexpairseq(epvector * vp, const ex & oc) const
464 {
465     return (new add(vp,oc))->setflag(status_flags::dynallocated);
466 }
467
468 expair add::split_ex_to_pair(const ex & e) const
469 {
470     if (is_ex_exactly_of_type(e,mul)) {
471         const mul & mulref=ex_to_mul(e);
472         ex numfactor=mulref.overall_coeff;
473         // mul * mulcopyp=static_cast<mul *>(mulref.duplicate());
474         mul * mulcopyp=new mul(mulref);
475         mulcopyp->overall_coeff=_ex1();
476         mulcopyp->clearflag(status_flags::evaluated);
477         mulcopyp->clearflag(status_flags::hash_calculated);
478         return expair(mulcopyp->setflag(status_flags::dynallocated),numfactor);
479     }
480     return expair(e,_ex1());
481 }
482
483 expair add::combine_ex_with_coeff_to_pair(const ex & e,
484                                           const ex & c) const
485 {
486     GINAC_ASSERT(is_ex_exactly_of_type(c,numeric));
487     if (is_ex_exactly_of_type(e,mul)) {
488         const mul & mulref=ex_to_mul(e);
489         ex numfactor=mulref.overall_coeff;
490         //mul * mulcopyp=static_cast<mul *>(mulref.duplicate());
491         mul * mulcopyp=new mul(mulref);
492         mulcopyp->overall_coeff=_ex1();
493         mulcopyp->clearflag(status_flags::evaluated);
494         mulcopyp->clearflag(status_flags::hash_calculated);
495         if (are_ex_trivially_equal(c,_ex1())) {
496             return expair(mulcopyp->setflag(status_flags::dynallocated),numfactor);
497         } else if (are_ex_trivially_equal(numfactor,_ex1())) {
498             return expair(mulcopyp->setflag(status_flags::dynallocated),c);
499         }
500         return expair(mulcopyp->setflag(status_flags::dynallocated),
501                           ex_to_numeric(numfactor).mul_dyn(ex_to_numeric(c)));
502     } else if (is_ex_exactly_of_type(e,numeric)) {
503         if (are_ex_trivially_equal(c,_ex1())) {
504             return expair(e,_ex1());
505         }
506         return expair(ex_to_numeric(e).mul_dyn(ex_to_numeric(c)),_ex1());
507     }
508     return expair(e,c);
509 }
510     
511 expair add::combine_pair_with_coeff_to_pair(const expair & p,
512                                             const ex & c) const
513 {
514     GINAC_ASSERT(is_ex_exactly_of_type(p.coeff,numeric));
515     GINAC_ASSERT(is_ex_exactly_of_type(c,numeric));
516
517     if (is_ex_exactly_of_type(p.rest,numeric)) {
518         GINAC_ASSERT(ex_to_numeric(p.coeff).is_equal(_num1())); // should be normalized
519         return expair(ex_to_numeric(p.rest).mul_dyn(ex_to_numeric(c)),_ex1());
520     }
521
522     return expair(p.rest,ex_to_numeric(p.coeff).mul_dyn(ex_to_numeric(c)));
523 }
524     
525 ex add::recombine_pair_to_ex(const expair & p) const
526 {
527     //if (p.coeff.compare(_ex1())==0) {
528     //if (are_ex_trivially_equal(p.coeff,_ex1())) {
529     if (ex_to_numeric(p.coeff).is_equal(_num1())) {
530         return p.rest;
531     } else {
532         return p.rest*p.coeff;
533     }
534 }
535
536 ex add::expand(unsigned options) const
537 {
538     if (flags & status_flags::expanded)
539         return *this;
540     
541     epvector * vp = expandchildren(options);
542     if (vp==0)
543         return *this;
544     
545     return (new add(vp,overall_coeff))->
546         setflag(status_flags::expanded |
547                 status_flags::dynallocated);
548 }
549
550 //////////
551 // new virtual functions which can be overridden by derived classes
552 //////////
553
554 // none
555
556 //////////
557 // non-virtual functions in this class
558 //////////
559
560 // none
561
562 //////////
563 // static member variables
564 //////////
565
566 // protected
567
568 unsigned add::precedence = 40;
569
570 //////////
571 // global constants
572 //////////
573
574 const add some_add;
575 const type_info & typeid_add = typeid(some_add);
576
577 #ifndef NO_NAMESPACE_GINAC
578 } // namespace GiNaC
579 #endif // ndef NO_NAMESPACE_GINAC