0d88e329316c64b93eb5640a15983cb3ea56c43e
[ginac.git] / ginac / add.cpp
1 /** @file add.cpp
2  *
3  *  Implementation of GiNaC's sums of expressions. */
4
5 #include <iostream>
6 #include <stdexcept>
7
8 #include "ginac.h"
9
10 //////////
// default constructor, destructor, copy constructor, assignment operator and helpers
12 //////////
13
14 // public
15
16 add::add()
17 {
18     debugmsg("add default constructor",LOGLEVEL_CONSTRUCT);
19     tinfo_key = TINFO_ADD;
20 }
21
22 add::~add()
23 {
24     debugmsg("add destructor",LOGLEVEL_DESTRUCT);
25     destroy(0);
26 }
27
28 add::add(add const & other)
29 {
30     debugmsg("add copy constructor",LOGLEVEL_CONSTRUCT);
31     copy(other);
32 }
33
34 add const & add::operator=(add const & other)
35 {
36     debugmsg("add operator=",LOGLEVEL_ASSIGNMENT);
37     if (this != &other) {
38         destroy(1);
39         copy(other);
40     }
41     return *this;
42 }
43
44 // protected
45
46 void add::copy(add const & other)
47 {
48     expairseq::copy(other);
49 }
50
51 void add::destroy(bool call_parent)
52 {
53     if (call_parent) expairseq::destroy(call_parent);
54 }
55
56 //////////
57 // other constructors
58 //////////
59
60 // public
61
62 add::add(ex const & lh, ex const & rh)
63 {
64     debugmsg("add constructor from ex,ex",LOGLEVEL_CONSTRUCT);
65     tinfo_key = TINFO_ADD;
66     overall_coeff=exZERO();
67     construct_from_2_ex(lh,rh);
68     ASSERT(is_canonical());
69 }
70
71 add::add(exvector const & v)
72 {
73     debugmsg("add constructor from exvector",LOGLEVEL_CONSTRUCT);
74     tinfo_key = TINFO_ADD;
75     overall_coeff=exZERO();
76     construct_from_exvector(v);
77     ASSERT(is_canonical());
78 }
79
80 /*
81 add::add(epvector const & v, bool do_not_canonicalize)
82 {
83     debugmsg("add constructor from epvector,bool",LOGLEVEL_CONSTRUCT);
84     tinfo_key = TINFO_ADD;
85     if (do_not_canonicalize) {
86         seq=v;
87 #ifdef EXPAIRSEQ_USE_HASHTAB
88         combine_same_terms(); // to build hashtab
89 #endif // def EXPAIRSEQ_USE_HASHTAB
90     } else {
91         construct_from_epvector(v);
92     }
93     ASSERT(is_canonical());
94 }
95 */
96
97 add::add(epvector const & v)
98 {
99     debugmsg("add constructor from epvector",LOGLEVEL_CONSTRUCT);
100     tinfo_key = TINFO_ADD;
101     overall_coeff=exZERO();
102     construct_from_epvector(v);
103     ASSERT(is_canonical());
104 }
105
106 add::add(epvector const & v, ex const & oc)
107 {
108     debugmsg("add constructor from epvector,ex",LOGLEVEL_CONSTRUCT);
109     tinfo_key = TINFO_ADD;
110     overall_coeff=oc;
111     construct_from_epvector(v);
112     ASSERT(is_canonical());
113 }
114
115 add::add(epvector * vp, ex const & oc)
116 {
117     debugmsg("add constructor from epvector *,ex",LOGLEVEL_CONSTRUCT);
118     tinfo_key = TINFO_ADD;
119     ASSERT(vp!=0);
120     overall_coeff=oc;
121     construct_from_epvector(*vp);
122     delete vp;
123     ASSERT(is_canonical());
124 }
125
126 //////////
127 // functions overriding virtual functions from bases classes
128 //////////
129
130 // public
131
132 basic * add::duplicate() const
133 {
134     debugmsg("add duplicate",LOGLEVEL_DUPLICATE);
135     return new add(*this);
136 }
137
138 bool add::info(unsigned inf) const
139 {
140     // TODO: optimize
141     if (inf==info_flags::polynomial || inf==info_flags::integer_polynomial || inf==info_flags::rational_polynomial || inf==info_flags::rational_function) {
142         for (epvector::const_iterator it=seq.begin(); it!=seq.end(); ++it) {
143             if (!(recombine_pair_to_ex(*it).info(inf)))
144                 return false;
145         }
146         return true;
147     } else {
148         return expairseq::info(inf);
149     }
150 }
151
152 int add::degree(symbol const & s) const
153 {
154     int deg=INT_MIN;
155     if (!overall_coeff.is_equal(exZERO())) {
156         deg=0;
157     }
158     int cur_deg;
159     for (epvector::const_iterator cit=seq.begin(); cit!=seq.end(); ++cit) {
160         cur_deg=(*cit).rest.degree(s);
161         if (cur_deg>deg) deg=cur_deg;
162     }
163     return deg;
164 }
165
166 int add::ldegree(symbol const & s) const
167 {
168     int deg=INT_MAX;
169     if (!overall_coeff.is_equal(exZERO())) {
170         deg=0;
171     }
172     int cur_deg;
173     for (epvector::const_iterator cit=seq.begin(); cit!=seq.end(); ++cit) {
174         cur_deg=(*cit).rest.ldegree(s);
175         if (cur_deg<deg) deg=cur_deg;
176     }
177     return deg;
178 }
179
180 ex add::coeff(symbol const & s, int const n) const
181 {
182     epvector coeffseq;
183     coeffseq.reserve(seq.size());
184
185     epvector::const_iterator it=seq.begin();
186     while (it!=seq.end()) {
187         coeffseq.push_back(combine_ex_with_coeff_to_pair((*it).rest.coeff(s,n),
188                                                          (*it).coeff));
189         ++it;
190     }
191     if (n==0) {
192         return (new add(coeffseq,overall_coeff))->setflag(status_flags::dynallocated);
193     }
194     return (new add(coeffseq))->setflag(status_flags::dynallocated);
195 }
196
197 /*
198 ex add::eval(int level) const
199 {
200     // simplifications: +(...,x,c1,c2) -> +(...,x,c1+c2) (c1, c2 numeric())
201     //                  +(...,(c1,c2)) -> (...,(c1*c2,1)) (normalize)
202     //                  +(...,x,0) -> +(...,x)
203     //                  +(x) -> x
204     //                  +() -> 0
205
206     debugmsg("add eval",LOGLEVEL_MEMBER_FUNCTION);
207
208     epvector newseq=seq;
209     epvector::iterator it1,it2;
210     
211     // +(...,x,c1,c2) -> +(...,x,c1+c2) (c1, c2 numeric())
212     it2=newseq.end()-1;
213     it1=it2-1;
214     while ((newseq.size()>=2)&&is_exactly_of_type(*(*it1).rest.bp,numeric)&&
215                                is_exactly_of_type(*(*it2).rest.bp,numeric)) {
216         *it1=expair(ex_to_numeric((*it1).rest).mul(ex_to_numeric((*it1).coeff))
217                     .add(ex_to_numeric((*it2).rest).mul(ex_to_numeric((*it2).coeff))),exONE());
218         newseq.pop_back();
219         it2=newseq.end()-1;
220         it1=it2-1;
221     }
222
223     if ((newseq.size()>=1)&&is_exactly_of_type(*(*it2).rest.bp,numeric)) {
224         // +(...,(c1,c2)) -> (...,(c1*c2,1)) (normalize)
225         *it2=expair(ex_to_numeric((*it2).rest).mul(ex_to_numeric((*it2).coeff)),exONE());
226         // +(...,x,0) -> +(...,x)
227         if (ex_to_numeric((*it2).rest).compare(0)==0) {
228             newseq.pop_back();
229         }
230     }
231
232     if (newseq.size()==0) {
233         // +() -> 0
234         return exZERO();
235     } else if (newseq.size()==1) {
236         // +(x) -> x
237         return recombine_pair_to_ex(*(newseq.begin()));
238     }
239
240     return (new add(newseq,1))->setflag(status_flags::dynallocated  |
241                                         status_flags::evaluated );
242 }
243 */
244
245 /*
246 ex add::eval(int level) const
247 {
248     // simplifications: +(...,x,c1,c2) -> +(...,x,c1+c2) (c1, c2 numeric())
249     //                  +(...,(c1,c2)) -> (...,(c1*c2,1)) (normalize)
250     //                  +(...,x,0) -> +(...,x)
251     //                  +(x) -> x
252     //                  +() -> 0
253
254     debugmsg("add eval",LOGLEVEL_MEMBER_FUNCTION);
255
256     if ((level==1)&&(flags & status_flags::evaluated)) {
257 #ifdef DOASSERT
258         for (epvector::const_iterator cit=seq.begin(); cit!=seq.end(); ++cit) {
259             ASSERT(!is_ex_exactly_of_type((*cit).rest,add));
260             ASSERT(!(is_ex_exactly_of_type((*cit).rest,numeric)&&
261                      (ex_to_numeric((*cit).coeff).compare(numONE())!=0)));
262         }
263 #endif // def DOASSERT
264         return *this;
265     }
266
267     epvector newseq;
268     epvector::iterator it1,it2;
269     bool seq_copied=false;
270
271     epvector * evaled_seqp=evalchildren(level);
272     if (evaled_seqp!=0) {
273         // do more evaluation later
274         return (new add(evaled_seqp))->setflag(status_flags::dynallocated);
275     }
276
277 #ifdef DOASSERT
278     for (epvector::const_iterator cit=seq.begin(); cit!=seq.end(); ++cit) {
279         ASSERT(!is_ex_exactly_of_type((*cit).rest,add));
280         ASSERT(!(is_ex_exactly_of_type((*cit).rest,numeric)&&
281                  (ex_to_numeric((*cit).coeff).compare(numONE())!=0)));
282     }
283 #endif // def DOASSERT
284
285     if (flags & status_flags::evaluated) {
286         return *this;
287     }
288     
289     expair const & last_expair=*(seq.end()-1);
290     expair const & next_to_last_expair=*(seq.end()-2);
291     int seq_size = seq.size();
292
293     // +(...,x,c1,c2) -> +(...,x,c1+c2) (c1, c2 numeric())
294     if ((!seq_copied)&&(seq_size>=2)&&
295         is_ex_exactly_of_type(last_expair.rest,numeric)&&
296         is_ex_exactly_of_type(next_to_last_expair.rest,numeric)) {
297         newseq=seq;
298         seq_copied=true;
299         it2=newseq.end()-1;
300         it1=it2-1;
301     }
302     while (seq_copied&&(newseq.size()>=2)&&
303            is_ex_exactly_of_type((*it1).rest,numeric)&&
304            is_ex_exactly_of_type((*it2).rest,numeric)) {
305         *it1=expair(ex_to_numeric((*it1).rest).mul(ex_to_numeric((*it1).coeff))
306                     .add_dyn(ex_to_numeric((*it2).rest).mul(ex_to_numeric((*it2).coeff))),exONE());
307         newseq.pop_back();
308         it2=newseq.end()-1;
309         it1=it2-1;
310     }
311
312     // +(...,(c1,c2)) -> (...,(c1*c2,1)) (normalize)
313     if ((!seq_copied)&&(seq_size>=1)&&
314         (is_ex_exactly_of_type(last_expair.rest,numeric))&&
315         (ex_to_numeric(last_expair.coeff).compare(numONE())!=0)) {
316         newseq=seq;
317         seq_copied=true;
318         it2=newseq.end()-1;
319     }
320     if (seq_copied&&(newseq.size()>=1)&&
321         (is_ex_exactly_of_type((*it2).rest,numeric))&&
322         (ex_to_numeric((*it2).coeff).compare(numONE())!=0)) {
323         *it2=expair(ex_to_numeric((*it2).rest).mul_dyn(ex_to_numeric((*it2).coeff)),exONE());
324     }
325         
326     // +(...,x,0) -> +(...,x)
327     if ((!seq_copied)&&(seq_size>=1)&&
328         (is_ex_exactly_of_type(last_expair.rest,numeric))&&
329         (ex_to_numeric(last_expair.rest).is_zero())) {
330         newseq=seq;
331         seq_copied=true;
332         it2=newseq.end()-1;
333     }
334     if (seq_copied&&(newseq.size()>=1)&&
335         (is_ex_exactly_of_type((*it2).rest,numeric))&&
336         (ex_to_numeric((*it2).rest).is_zero())) {
337         newseq.pop_back();
338     }
339
340     // +() -> 0
341     if ((!seq_copied)&&(seq_size==0)) {
342         return exZERO();
343     } else if (seq_copied&&(newseq.size()==0)) {
344         return exZERO();
345     }
346
347     // +(x) -> x
348     if ((!seq_copied)&&(seq_size==1)) {
349         return recombine_pair_to_ex(*(seq.begin()));
350     } else if (seq_copied&&(newseq.size()==1)) {
351         return recombine_pair_to_ex(*(newseq.begin()));
352     }
353
354     if (!seq_copied) return this->hold();
355
356     return (new add(newseq,1))->setflag(status_flags::dynallocated  |
357                                         status_flags::evaluated );
358 }
359 */
360
361 ex add::eval(int level) const
362 {
363     // simplifications: +(;c) -> c
364     //                  +(x;1) -> x
365
366     debugmsg("add eval",LOGLEVEL_MEMBER_FUNCTION);
367
368     epvector * evaled_seqp=evalchildren(level);
369     if (evaled_seqp!=0) {
370         // do more evaluation later
371         return (new add(evaled_seqp,overall_coeff))->
372                    setflag(status_flags::dynallocated);
373     }
374
375 #ifdef DOASSERT
376     for (epvector::const_iterator cit=seq.begin(); cit!=seq.end(); ++cit) {
377         ASSERT(!is_ex_exactly_of_type((*cit).rest,add));
378         if (is_ex_exactly_of_type((*cit).rest,numeric)) {
379             dbgprint();
380         }
381         ASSERT(!is_ex_exactly_of_type((*cit).rest,numeric));
382     }
383 #endif // def DOASSERT
384
385     if (flags & status_flags::evaluated) {
386         ASSERT(seq.size()>0);
387         ASSERT((seq.size()>1)||!overall_coeff.is_equal(exZERO()));
388         return *this;
389     }
390
391     int seq_size=seq.size();
392     if (seq_size==0) {
393         // +(;c) -> c
394         return overall_coeff;
395     } else if ((seq_size==1)&&overall_coeff.is_equal(exZERO())) {
396         // +(x;0) -> x
397         return recombine_pair_to_ex(*(seq.begin()));
398     }
399     return this->hold();
400 }
401
402 exvector add::get_indices(void) const
403 {
404     // all terms in the sum should have the same indices (compatible tensors)
405     // however this is not checked, since there is no function yet which
406     // compares indices (idxvector can be unsorted) !!!!!!!!!!!
407     if (seq.size()==0) {
408         return exvector();
409     }
410     return (seq.begin())->rest.get_indices();
411 }    
412
413 ex add::simplify_ncmul(exvector const & v) const
414 {
415     if (seq.size()==0) {
416         return expairseq::simplify_ncmul(v);
417     }
418     return (*seq.begin()).rest.simplify_ncmul(v);
419 }    
420
421 // protected
422
423 int add::compare_same_type(basic const & other) const
424 {
425     return expairseq::compare_same_type(other);
426 }
427
428 bool add::is_equal_same_type(basic const & other) const
429 {
430     return expairseq::is_equal_same_type(other);
431 }
432
433 unsigned add::return_type(void) const
434 {
435     if (seq.size()==0) {
436         return return_types::commutative;
437     }
438     return (*seq.begin()).rest.return_type();
439 }
440    
441 unsigned add::return_type_tinfo(void) const
442 {
443     if (seq.size()==0) {
444         return tinfo_key;
445     }
446     return (*seq.begin()).rest.return_type_tinfo();
447 }
448
449 ex add::thisexpairseq(epvector const & v, ex const & oc) const
450 {
451     return (new add(v,oc))->setflag(status_flags::dynallocated);
452 }
453
454 ex add::thisexpairseq(epvector * vp, ex const & oc) const
455 {
456     return (new add(vp,oc))->setflag(status_flags::dynallocated);
457 }
458
459 /*
460 expair add::split_ex_to_pair(ex const & e) const
461 {
462     if (is_ex_exactly_of_type(e,mul)) {
463         mul const & mulref=ex_to_mul(e);
464         ASSERT(mulref.seq.size()>1);
465         ex const & lastfactor_rest=(*(mulref.seq.end()-1)).rest;
466         ex const & lastfactor_coeff=(*(mulref.seq.end()-1)).coeff;
467         if (is_ex_exactly_of_type(lastfactor_rest,numeric) &&
468             ex_to_numeric(lastfactor_coeff).is_equal(numONE())) {
469             epvector s=mulref.seq;
470             //s.pop_back();
471             //return expair((new mul(s,1))->setflag(status_flags::dynallocated),
472             //              lastfactor);
473             mul * mulp=static_cast<mul *>(mulref.duplicate());
474 #ifdef EXPAIRSEQ_USE_HASHTAB
475             mulp->remove_hashtab_entry(mulp->seq.end()-1);
476 #endif // def EXPAIRSEQ_USE_HASHTAB
477             mulp->seq.pop_back();
478 #ifdef EXPAIRSEQ_USE_HASHTAB
479             mulp->shrink_hashtab();
480 #endif // def EXPAIRSEQ_USE_HASHTAB
481             mulp->clearflag(status_flags::evaluated);
482             mulp->clearflag(status_flags::hash_calculated);
483             return expair(mulp->setflag(status_flags::dynallocated),lastfactor_rest);
484         }
485     }
486     return expair(e,exONE());
487 }
488 */
489
490 expair add::split_ex_to_pair(ex const & e) const
491 {
492     if (is_ex_exactly_of_type(e,mul)) {
493         mul const & mulref=ex_to_mul(e);
494         ex numfactor=mulref.overall_coeff;
495         // mul * mulcopyp=static_cast<mul *>(mulref.duplicate());
496         mul * mulcopyp=new mul(mulref);
497         mulcopyp->overall_coeff=exONE();
498         mulcopyp->clearflag(status_flags::evaluated);
499         mulcopyp->clearflag(status_flags::hash_calculated);
500         return expair(mulcopyp->setflag(status_flags::dynallocated),numfactor);
501     }
502     return expair(e,exONE());
503 }
504
505 /*
506 expair add::combine_ex_with_coeff_to_pair(ex const & e,
507                                           ex const & c) const
508 {
509     ASSERT(is_ex_exactly_of_type(c,numeric));
510     if (is_ex_exactly_of_type(e,mul)) {
511         mul const & mulref=ex_to_mul(e);
512         ASSERT(mulref.seq.size()>1);
513         ex const & lastfactor_rest=(*(mulref.seq.end()-1)).rest;
514         ex const & lastfactor_coeff=(*(mulref.seq.end()-1)).coeff;
515         if (is_ex_exactly_of_type(lastfactor_rest,numeric) &&
516             ex_to_numeric(lastfactor_coeff).is_equal(numONE())) {
517             //epvector s=mulref.seq;
518             //s.pop_back();
519             //return expair((new mul(s,1))->setflag(status_flags::dynallocated),
520             //              ex_to_numeric(lastfactor).mul_dyn(ex_to_numeric(c)));
521             mul * mulp=static_cast<mul *>(mulref.duplicate());
522 #ifdef EXPAIRSEQ_USE_HASHTAB
523             mulp->remove_hashtab_entry(mulp->seq.end()-1);
524 #endif // def EXPAIRSEQ_USE_HASHTAB
525             mulp->seq.pop_back();
526 #ifdef EXPAIRSEQ_USE_HASHTAB
527             mulp->shrink_hashtab();
528 #endif // def EXPAIRSEQ_USE_HASHTAB
529             mulp->clearflag(status_flags::evaluated);
530             mulp->clearflag(status_flags::hash_calculated);
531             if (are_ex_trivially_equal(c,exONE())) {
532                 return expair(mulp->setflag(status_flags::dynallocated),lastfactor_rest);
533             } else if (are_ex_trivially_equal(lastfactor_rest,exONE())) {
534                 return expair(mulp->setflag(status_flags::dynallocated),c);
535             }                
536             return expair(mulp->setflag(status_flags::dynallocated),
537                           ex_to_numeric(lastfactor_rest).mul_dyn(ex_to_numeric(c)));
538         }
539     }
540     return expair(e,c);
541 }
542 */
543
544 expair add::combine_ex_with_coeff_to_pair(ex const & e,
545                                           ex const & c) const
546 {
547     ASSERT(is_ex_exactly_of_type(c,numeric));
548     if (is_ex_exactly_of_type(e,mul)) {
549         mul const & mulref=ex_to_mul(e);
550         ex numfactor=mulref.overall_coeff;
551         //mul * mulcopyp=static_cast<mul *>(mulref.duplicate());
552         mul * mulcopyp=new mul(mulref);
553         mulcopyp->overall_coeff=exONE();
554         mulcopyp->clearflag(status_flags::evaluated);
555         mulcopyp->clearflag(status_flags::hash_calculated);
556         if (are_ex_trivially_equal(c,exONE())) {
557             return expair(mulcopyp->setflag(status_flags::dynallocated),numfactor);
558         } else if (are_ex_trivially_equal(numfactor,exONE())) {
559             return expair(mulcopyp->setflag(status_flags::dynallocated),c);
560         }
561         return expair(mulcopyp->setflag(status_flags::dynallocated),
562                           ex_to_numeric(numfactor).mul_dyn(ex_to_numeric(c)));
563     } else if (is_ex_exactly_of_type(e,numeric)) {
564         if (are_ex_trivially_equal(c,exONE())) {
565             return expair(e,exONE());
566         }
567         return expair(ex_to_numeric(e).mul_dyn(ex_to_numeric(c)),exONE());
568     }
569     return expair(e,c);
570 }
571     
572 expair add::combine_pair_with_coeff_to_pair(expair const & p,
573                                             ex const & c) const
574 {
575     ASSERT(is_ex_exactly_of_type(p.coeff,numeric));
576     ASSERT(is_ex_exactly_of_type(c,numeric));
577
578     if (is_ex_exactly_of_type(p.rest,numeric)) {
579         ASSERT(ex_to_numeric(p.coeff).is_equal(numONE())); // should be normalized
580         return expair(ex_to_numeric(p.rest).mul_dyn(ex_to_numeric(c)),exONE());
581     }
582
583     return expair(p.rest,ex_to_numeric(p.coeff).mul_dyn(ex_to_numeric(c)));
584 }
585     
586 ex add::recombine_pair_to_ex(expair const & p) const
587 {
588     //if (p.coeff.compare(exONE())==0) {
589     //if (are_ex_trivially_equal(p.coeff,exONE())) {
590     if (ex_to_numeric(p.coeff).is_equal(numONE())) {
591         return p.rest;
592     } else {
593         return p.rest*p.coeff;
594     }
595 }
596
597 ex add::expand(unsigned options) const
598 {
599     epvector * vp=expandchildren(options);
600     if (vp==0) {
601         return *this;
602     }
603     return (new add(vp,overall_coeff))->setflag(status_flags::expanded    |
604                                                 status_flags::dynallocated );
605 }
606
607 //////////
608 // new virtual functions which can be overridden by derived classes
609 //////////
610
611 // none
612
613 //////////
614 // non-virtual functions in this class
615 //////////
616
617 // none
618
619 //////////
620 // static member variables
621 //////////
622
623 // protected
624
625 unsigned add::precedence=40;
626
627 //////////
628 // global constants
629 //////////
630
631 const add some_add;
632 type_info const & typeid_add=typeid(some_add);
633
634
635