- added skeleton implementation of color and clifford classes (don't bother
[ginac.git] / ginac / add.cpp
1 /** @file add.cpp
2  *
3  *  Implementation of GiNaC's sums of expressions. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2001 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
21  */
22
23 #include <iostream>
24 #include <stdexcept>
25
26 #include "add.h"
27 #include "mul.h"
28 #include "archive.h"
29 #include "debugmsg.h"
30 #include "utils.h"
31
32 namespace GiNaC {
33
34 GINAC_IMPLEMENT_REGISTERED_CLASS(add, expairseq)
35
36 //////////
37 // default constructor, destructor, copy constructor assignment operator and helpers
38 //////////
39
40 add::add()
41 {
42         debugmsg("add default constructor",LOGLEVEL_CONSTRUCT);
43         tinfo_key = TINFO_add;
44 }
45
46 DEFAULT_COPY(add)
47 DEFAULT_DESTROY(add)
48
49 //////////
50 // other constructors
51 //////////
52
53 // public
54
55 add::add(const ex & lh, const ex & rh)
56 {
57         debugmsg("add constructor from ex,ex",LOGLEVEL_CONSTRUCT);
58         tinfo_key = TINFO_add;
59         overall_coeff = _ex0();
60         construct_from_2_ex(lh,rh);
61         GINAC_ASSERT(is_canonical());
62 }
63
64 add::add(const exvector & v)
65 {
66         debugmsg("add constructor from exvector",LOGLEVEL_CONSTRUCT);
67         tinfo_key = TINFO_add;
68         overall_coeff = _ex0();
69         construct_from_exvector(v);
70         GINAC_ASSERT(is_canonical());
71 }
72
73 add::add(const epvector & v)
74 {
75         debugmsg("add constructor from epvector",LOGLEVEL_CONSTRUCT);
76         tinfo_key = TINFO_add;
77         overall_coeff = _ex0();
78         construct_from_epvector(v);
79         GINAC_ASSERT(is_canonical());
80 }
81
82 add::add(const epvector & v, const ex & oc)
83 {
84         debugmsg("add constructor from epvector,ex",LOGLEVEL_CONSTRUCT);
85         tinfo_key = TINFO_add;
86         overall_coeff = oc;
87         construct_from_epvector(v);
88         GINAC_ASSERT(is_canonical());
89 }
90
91 add::add(epvector * vp, const ex & oc)
92 {
93         debugmsg("add constructor from epvector *,ex",LOGLEVEL_CONSTRUCT);
94         tinfo_key = TINFO_add;
95         GINAC_ASSERT(vp!=0);
96         overall_coeff = oc;
97         construct_from_epvector(*vp);
98         delete vp;
99         GINAC_ASSERT(is_canonical());
100 }
101
102 //////////
103 // archiving
104 //////////
105
106 DEFAULT_ARCHIVING(add)
107
108 //////////
109 // functions overriding virtual functions from bases classes
110 //////////
111
112 // public
113
114 void add::print(std::ostream & os, unsigned upper_precedence) const
115 {
116         debugmsg("add print",LOGLEVEL_PRINT);
117         if (precedence<=upper_precedence) os << "(";
118         numeric coeff;
119         bool first = true;
120         // first print the overall numeric coefficient, if present:
121         if (!overall_coeff.is_zero()) {
122                 os << overall_coeff;
123                 first = false;
124         }
125         // then proceed with the remaining factors:
126         for (epvector::const_iterator cit=seq.begin(); cit!=seq.end(); ++cit) {
127                 coeff = ex_to_numeric(cit->coeff);
128                 if (!first) {
129                         if (coeff.csgn()==-1) os << '-'; else os << '+';
130                 } else {
131                         if (coeff.csgn()==-1) os << '-';
132                         first = false;
133                 }
134                 if (!coeff.is_equal(_num1()) &&
135                     !coeff.is_equal(_num_1())) {
136                         if (coeff.is_rational()) {
137                                 if (coeff.is_negative())
138                                         os << -coeff;
139                                 else
140                                         os << coeff;
141                         } else {
142                                 if (coeff.csgn()==-1)
143                                         (-coeff).print(os, precedence);
144                                 else
145                                         coeff.print(os, precedence);
146                         }
147                         os << '*';
148                 }
149                 cit->rest.print(os, precedence);
150         }
151         if (precedence<=upper_precedence) os << ")";
152 }
153
154 void add::printraw(std::ostream & os) const
155 {
156         debugmsg("add printraw",LOGLEVEL_PRINT);
157
158         os << "+(";
159         for (epvector::const_iterator it=seq.begin(); it!=seq.end(); ++it) {
160                 os << "(";
161                 (*it).rest.bp->printraw(os);
162                 os << ",";
163                 (*it).coeff.bp->printraw(os);        
164                 os << "),";
165         }
166         os << ",hash=" << hashvalue << ",flags=" << flags;
167         os << ")";
168 }
169
170 void add::printcsrc(std::ostream & os, unsigned type, unsigned upper_precedence) const
171 {
172         debugmsg("add print csrc", LOGLEVEL_PRINT);
173         if (precedence <= upper_precedence)
174                 os << "(";
175         
176         // Print arguments, separated by "+"
177         epvector::const_iterator it = seq.begin();
178         epvector::const_iterator itend = seq.end();
179         while (it != itend) {
180                 
181                 // If the coefficient is -1, it is replaced by a single minus sign
182                 if (it->coeff.compare(_num1()) == 0) {
183                         it->rest.bp->printcsrc(os, type, precedence);
184                 } else if (it->coeff.compare(_num_1()) == 0) {
185                         os << "-";
186                         it->rest.bp->printcsrc(os, type, precedence);
187                 } else if (ex_to_numeric(it->coeff).numer().compare(_num1()) == 0) {
188                         it->rest.bp->printcsrc(os, type, precedence);
189                         os << "/";
190                         ex_to_numeric(it->coeff).denom().printcsrc(os, type, precedence);
191                 } else if (ex_to_numeric(it->coeff).numer().compare(_num_1()) == 0) {
192                         os << "-";
193                         it->rest.bp->printcsrc(os, type, precedence);
194                         os << "/";
195                         ex_to_numeric(it->coeff).denom().printcsrc(os, type, precedence);
196                 } else {
197                         it->coeff.bp->printcsrc(os, type, precedence);
198                         os << "*";
199                         it->rest.bp->printcsrc(os, type, precedence);
200                 }
201                 
202                 // Separator is "+", except if the following expression would have a leading minus sign
203                 it++;
204                 if (it != itend && !(it->coeff.compare(_num0()) < 0 || (it->coeff.compare(_num1()) == 0 && is_ex_exactly_of_type(it->rest, numeric) && it->rest.compare(_num0()) < 0)))
205                         os << "+";
206         }
207         
208         if (!overall_coeff.is_zero()) {
209                 if (overall_coeff.info(info_flags::positive)) os << '+';
210                 overall_coeff.bp->printcsrc(os,type,precedence);
211         }
212         
213         if (precedence <= upper_precedence)
214                 os << ")";
215 }
216
217 bool add::info(unsigned inf) const
218 {
219         switch (inf) {
220                 case info_flags::polynomial:
221                 case info_flags::integer_polynomial:
222                 case info_flags::cinteger_polynomial:
223                 case info_flags::rational_polynomial:
224                 case info_flags::crational_polynomial:
225                 case info_flags::rational_function: {
226                         for (epvector::const_iterator i=seq.begin(); i!=seq.end(); ++i) {
227                                 if (!(recombine_pair_to_ex(*i).info(inf)))
228                                         return false;
229                         }
230                         return overall_coeff.info(inf);
231                 }
232                 case info_flags::algebraic: {
233                         for (epvector::const_iterator i=seq.begin(); i!=seq.end(); ++i) {
234                                 if ((recombine_pair_to_ex(*i).info(inf)))
235                                         return true;
236                         }
237                         return false;
238                 }
239         }
240         return inherited::info(inf);
241 }
242
243 int add::degree(const symbol & s) const
244 {
245         int deg = INT_MIN;
246         if (!overall_coeff.is_equal(_ex0()))
247                 deg = 0;
248         
249         int cur_deg;
250         for (epvector::const_iterator cit=seq.begin(); cit!=seq.end(); ++cit) {
251                 cur_deg = (*cit).rest.degree(s);
252                 if (cur_deg>deg)
253                         deg = cur_deg;
254         }
255         return deg;
256 }
257
258 int add::ldegree(const symbol & s) const
259 {
260         int deg = INT_MAX;
261         if (!overall_coeff.is_equal(_ex0()))
262                 deg = 0;
263         
264         int cur_deg;
265         for (epvector::const_iterator cit=seq.begin(); cit!=seq.end(); ++cit) {
266                 cur_deg = (*cit).rest.ldegree(s);
267                 if (cur_deg<deg) deg=cur_deg;
268         }
269         return deg;
270 }
271
272 ex add::coeff(const symbol & s, int n) const
273 {
274         epvector coeffseq;
275         coeffseq.reserve(seq.size());
276
277         epvector::const_iterator it=seq.begin();
278         while (it!=seq.end()) {
279                 coeffseq.push_back(combine_ex_with_coeff_to_pair((*it).rest.coeff(s,n),
280                                                                  (*it).coeff));
281                 ++it;
282         }
283         if (n==0) {
284                 return (new add(coeffseq,overall_coeff))->setflag(status_flags::dynallocated);
285         }
286         return (new add(coeffseq))->setflag(status_flags::dynallocated);
287 }
288
289 ex add::eval(int level) const
290 {
291         // simplifications: +(;c) -> c
292         //                  +(x;1) -> x
293         
294         debugmsg("add eval",LOGLEVEL_MEMBER_FUNCTION);
295         
296         epvector * evaled_seqp = evalchildren(level);
297         if (evaled_seqp!=0) {
298                 // do more evaluation later
299                 return (new add(evaled_seqp,overall_coeff))->
300                        setflag(status_flags::dynallocated);
301         }
302         
303 #ifdef DO_GINAC_ASSERT
304         for (epvector::const_iterator cit=seq.begin(); cit!=seq.end(); ++cit) {
305                 GINAC_ASSERT(!is_ex_exactly_of_type((*cit).rest,add));
306                 if (is_ex_exactly_of_type((*cit).rest,numeric))
307                         dbgprint();
308                 GINAC_ASSERT(!is_ex_exactly_of_type((*cit).rest,numeric));
309         }
310 #endif // def DO_GINAC_ASSERT
311         
312         if (flags & status_flags::evaluated) {
313                 GINAC_ASSERT(seq.size()>0);
314                 GINAC_ASSERT(seq.size()>1 || !overall_coeff.is_zero());
315                 return *this;
316         }
317         
318         int seq_size = seq.size();
319         if (seq_size==0) {
320                 // +(;c) -> c
321                 return overall_coeff;
322         } else if ((seq_size==1) && overall_coeff.is_equal(_ex0())) {
323                 // +(x;0) -> x
324                 return recombine_pair_to_ex(*(seq.begin()));
325         }
326         return this->hold();
327 }
328
329 ex add::simplify_ncmul(const exvector & v) const
330 {
331         if (seq.size()==0) {
332                 return inherited::simplify_ncmul(v);
333         }
334         return (*seq.begin()).rest.simplify_ncmul(v);
335 }    
336
337 // protected
338
339 /** Implementation of ex::diff() for a sum. It differentiates each term.
340  *  @see ex::diff */
341 ex add::derivative(const symbol & s) const
342 {
343         // D(a+b+c)=D(a)+D(b)+D(c)
344         return (new add(diffchildren(s)))->setflag(status_flags::dynallocated);
345 }
346
347 int add::compare_same_type(const basic & other) const
348 {
349         return inherited::compare_same_type(other);
350 }
351
352 bool add::is_equal_same_type(const basic & other) const
353 {
354         return inherited::is_equal_same_type(other);
355 }
356
357 unsigned add::return_type(void) const
358 {
359         if (seq.size()==0) {
360                 return return_types::commutative;
361         }
362         return (*seq.begin()).rest.return_type();
363 }
364    
365 unsigned add::return_type_tinfo(void) const
366 {
367         if (seq.size()==0) {
368                 return tinfo_key;
369         }
370         return (*seq.begin()).rest.return_type_tinfo();
371 }
372
373 ex add::thisexpairseq(const epvector & v, const ex & oc) const
374 {
375         return (new add(v,oc))->setflag(status_flags::dynallocated);
376 }
377
378 ex add::thisexpairseq(epvector * vp, const ex & oc) const
379 {
380         return (new add(vp,oc))->setflag(status_flags::dynallocated);
381 }
382
383 expair add::split_ex_to_pair(const ex & e) const
384 {
385         if (is_ex_exactly_of_type(e,mul)) {
386                 const mul &mulref = ex_to_mul(e);
387                 ex numfactor = mulref.overall_coeff;
388                 mul *mulcopyp = new mul(mulref);
389                 mulcopyp->overall_coeff = _ex1();
390                 mulcopyp->clearflag(status_flags::evaluated);
391                 mulcopyp->clearflag(status_flags::hash_calculated);
392                 mulcopyp->setflag(status_flags::dynallocated);
393                 return expair(*mulcopyp,numfactor);
394         }
395         return expair(e,_ex1());
396 }
397
398 expair add::combine_ex_with_coeff_to_pair(const ex & e,
399                                                                                   const ex & c) const
400 {
401         GINAC_ASSERT(is_ex_exactly_of_type(c, numeric));
402         if (is_ex_exactly_of_type(e, mul)) {
403                 const mul &mulref = ex_to_mul(e);
404                 ex numfactor = mulref.overall_coeff;
405                 mul *mulcopyp = new mul(mulref);
406                 mulcopyp->overall_coeff = _ex1();
407                 mulcopyp->clearflag(status_flags::evaluated);
408                 mulcopyp->clearflag(status_flags::hash_calculated);
409                 mulcopyp->setflag(status_flags::dynallocated);
410                 if (are_ex_trivially_equal(c, _ex1()))
411                         return expair(*mulcopyp, numfactor);
412                 else if (are_ex_trivially_equal(numfactor, _ex1()))
413                         return expair(*mulcopyp, c);
414                 else
415                         return expair(*mulcopyp, ex_to_numeric(numfactor).mul_dyn(ex_to_numeric(c)));
416         } else if (is_ex_exactly_of_type(e, numeric)) {
417                 if (are_ex_trivially_equal(c, _ex1()))
418                         return expair(e, _ex1());
419                 return expair(ex_to_numeric(e).mul_dyn(ex_to_numeric(c)), _ex1());
420         }
421         return expair(e, c);
422 }
423
424 expair add::combine_pair_with_coeff_to_pair(const expair & p,
425                                                                                         const ex & c) const
426 {
427         GINAC_ASSERT(is_ex_exactly_of_type(p.coeff,numeric));
428         GINAC_ASSERT(is_ex_exactly_of_type(c,numeric));
429
430         if (is_ex_exactly_of_type(p.rest,numeric)) {
431                 GINAC_ASSERT(ex_to_numeric(p.coeff).is_equal(_num1())); // should be normalized
432                 return expair(ex_to_numeric(p.rest).mul_dyn(ex_to_numeric(c)),_ex1());
433         }
434
435         return expair(p.rest,ex_to_numeric(p.coeff).mul_dyn(ex_to_numeric(c)));
436 }
437         
438 ex add::recombine_pair_to_ex(const expair & p) const
439 {
440         if (ex_to_numeric(p.coeff).is_equal(_num1()))
441                 return p.rest;
442         else
443                 return p.rest*p.coeff;
444 }
445
446 ex add::expand(unsigned options) const
447 {
448         if (flags & status_flags::expanded)
449                 return *this;
450         
451         epvector * vp = expandchildren(options);
452         if (vp==0) {
453                 // the terms have not changed, so it is safe to declare this expanded
454                 setflag(status_flags::expanded);
455                 return *this;
456         }
457         
458         return (new add(vp,overall_coeff))->setflag(status_flags::expanded | status_flags::dynallocated);
459 }
460
461 //////////
462 // static member variables
463 //////////
464
465 // protected
466
467 unsigned add::precedence = 40;
468
469 } // namespace GiNaC