]> www.ginac.de Git - ginac.git/blob - ginac/add.cpp
This patch fixes a bug on machines where char is unsigned by default, by
[ginac.git] / ginac / add.cpp
1 /** @file add.cpp
2  *
3  *  Implementation of GiNaC's sums of expressions. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2011 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
21  */
22
23 #include "add.h"
24 #include "mul.h"
25 #include "archive.h"
26 #include "operators.h"
27 #include "matrix.h"
28 #include "utils.h"
29 #include "clifford.h"
30 #include "ncmul.h"
31 #include "compiler.h"
32
33 #include <iostream>
34 #include <limits>
35 #include <stdexcept>
36 #include <string>
37
38 namespace GiNaC {
39
40 GINAC_IMPLEMENT_REGISTERED_CLASS_OPT(add, expairseq,
41   print_func<print_context>(&add::do_print).
42   print_func<print_latex>(&add::do_print_latex).
43   print_func<print_csrc>(&add::do_print_csrc).
44   print_func<print_tree>(&add::do_print_tree).
45   print_func<print_python_repr>(&add::do_print_python_repr))
46
47 //////////
48 // default constructor
49 //////////
50
51 add::add()
52 {
53 }
54
55 //////////
56 // other constructors
57 //////////
58
59 // public
60
61 add::add(const ex & lh, const ex & rh)
62 {
63         overall_coeff = _ex0;
64         construct_from_2_ex(lh,rh);
65         GINAC_ASSERT(is_canonical());
66 }
67
68 add::add(const exvector & v)
69 {
70         overall_coeff = _ex0;
71         construct_from_exvector(v);
72         GINAC_ASSERT(is_canonical());
73 }
74
75 add::add(const epvector & v)
76 {
77         overall_coeff = _ex0;
78         construct_from_epvector(v);
79         GINAC_ASSERT(is_canonical());
80 }
81
82 add::add(const epvector & v, const ex & oc)
83 {
84         overall_coeff = oc;
85         construct_from_epvector(v);
86         GINAC_ASSERT(is_canonical());
87 }
88
89 add::add(std::auto_ptr<epvector> vp, const ex & oc)
90 {
91         GINAC_ASSERT(vp.get()!=0);
92         overall_coeff = oc;
93         construct_from_epvector(*vp);
94         GINAC_ASSERT(is_canonical());
95 }
96
97 //////////
98 // archiving
99 //////////
100
101 GINAC_BIND_UNARCHIVER(add);
102
103 //////////
104 // functions overriding virtual functions from base classes
105 //////////
106
107 // public
108
109 void add::print_add(const print_context & c, const char *openbrace, const char *closebrace, const char *mul_sym, unsigned level) const
110 {
111         if (precedence() <= level)
112                 c.s << openbrace << '(';
113
114         numeric coeff;
115         bool first = true;
116
117         // First print the overall numeric coefficient, if present
118         if (!overall_coeff.is_zero()) {
119                 overall_coeff.print(c, 0);
120                 first = false;
121         }
122
123         // Then proceed with the remaining factors
124         epvector::const_iterator it = seq.begin(), itend = seq.end();
125         while (it != itend) {
126                 coeff = ex_to<numeric>(it->coeff);
127                 if (!first) {
128                         if (coeff.csgn() == -1) c.s << '-'; else c.s << '+';
129                 } else {
130                         if (coeff.csgn() == -1) c.s << '-';
131                         first = false;
132                 }
133                 if (!coeff.is_equal(*_num1_p) &&
134                     !coeff.is_equal(*_num_1_p)) {
135                         if (coeff.is_rational()) {
136                                 if (coeff.is_negative())
137                                         (-coeff).print(c);
138                                 else
139                                         coeff.print(c);
140                         } else {
141                                 if (coeff.csgn() == -1)
142                                         (-coeff).print(c, precedence());
143                                 else
144                                         coeff.print(c, precedence());
145                         }
146                         c.s << mul_sym;
147                 }
148                 it->rest.print(c, precedence());
149                 ++it;
150         }
151
152         if (precedence() <= level)
153                 c.s << ')' << closebrace;
154 }
155
156 void add::do_print(const print_context & c, unsigned level) const
157 {
158         print_add(c, "", "", "*", level);
159 }
160
161 void add::do_print_latex(const print_latex & c, unsigned level) const
162 {
163         print_add(c, "{", "}", " ", level);
164 }
165
166 void add::do_print_csrc(const print_csrc & c, unsigned level) const
167 {
168         if (precedence() <= level)
169                 c.s << "(";
170         
171         // Print arguments, separated by "+" or "-"
172         epvector::const_iterator it = seq.begin(), itend = seq.end();
173         char separator = ' ';
174         while (it != itend) {
175                 
176                 // If the coefficient is negative, separator is "-"
177                 if (it->coeff.is_equal(_ex_1) || 
178                         ex_to<numeric>(it->coeff).numer().is_equal(*_num_1_p))
179                         separator = '-';
180                 c.s << separator;
181                 if (it->coeff.is_equal(_ex1) || it->coeff.is_equal(_ex_1)) {
182                         it->rest.print(c, precedence());
183                 } else if (ex_to<numeric>(it->coeff).numer().is_equal(*_num1_p) ||
184                                  ex_to<numeric>(it->coeff).numer().is_equal(*_num_1_p))
185                 {
186                         it->rest.print(c, precedence());
187                         c.s << '/';
188                         ex_to<numeric>(it->coeff).denom().print(c, precedence());
189                 } else {
190                         it->coeff.print(c, precedence());
191                         c.s << '*';
192                         it->rest.print(c, precedence());
193                 }
194                 
195                 ++it;
196                 separator = '+';
197         }
198         
199         if (!overall_coeff.is_zero()) {
200                 if (overall_coeff.info(info_flags::positive)
201                  || is_a<print_csrc_cl_N>(c) || !overall_coeff.info(info_flags::real))  // sign inside ctor argument
202                         c.s << '+';
203                 overall_coeff.print(c, precedence());
204         }
205                 
206         if (precedence() <= level)
207                 c.s << ")";
208 }
209
210 void add::do_print_python_repr(const print_python_repr & c, unsigned level) const
211 {
212         c.s << class_name() << '(';
213         op(0).print(c);
214         for (size_t i=1; i<nops(); ++i) {
215                 c.s << ',';
216                 op(i).print(c);
217         }
218         c.s << ')';
219 }
220
221 bool add::info(unsigned inf) const
222 {
223         switch (inf) {
224                 case info_flags::polynomial:
225                 case info_flags::integer_polynomial:
226                 case info_flags::cinteger_polynomial:
227                 case info_flags::rational_polynomial:
228                 case info_flags::real:
229                 case info_flags::rational:
230                 case info_flags::integer:
231                 case info_flags::crational:
232                 case info_flags::cinteger:
233                 case info_flags::positive:
234                 case info_flags::nonnegative:
235                 case info_flags::posint:
236                 case info_flags::nonnegint:
237                 case info_flags::even:
238                 case info_flags::crational_polynomial:
239                 case info_flags::rational_function: {
240                         epvector::const_iterator i = seq.begin(), end = seq.end();
241                         while (i != end) {
242                                 if (!(recombine_pair_to_ex(*i).info(inf)))
243                                         return false;
244                                 ++i;
245                         }
246                         if (overall_coeff.is_zero() && (inf == info_flags::positive || inf == info_flags::posint))
247                                 return true;
248                         return overall_coeff.info(inf);
249                 }
250                 case info_flags::algebraic: {
251                         epvector::const_iterator i = seq.begin(), end = seq.end();
252                         while (i != end) {
253                                 if ((recombine_pair_to_ex(*i).info(inf)))
254                                         return true;
255                                 ++i;
256                         }
257                         return false;
258                 }
259         }
260         return inherited::info(inf);
261 }
262
263 bool add::is_polynomial(const ex & var) const
264 {
265         for (epvector::const_iterator i=seq.begin(); i!=seq.end(); ++i) {
266                 if (!(i->rest).is_polynomial(var)) {
267                         return false;
268                 }
269         }
270         return true;
271 }
272
273 int add::degree(const ex & s) const
274 {
275         int deg = std::numeric_limits<int>::min();
276         if (!overall_coeff.is_zero())
277                 deg = 0;
278         
279         // Find maximum of degrees of individual terms
280         epvector::const_iterator i = seq.begin(), end = seq.end();
281         while (i != end) {
282                 int cur_deg = i->rest.degree(s);
283                 if (cur_deg > deg)
284                         deg = cur_deg;
285                 ++i;
286         }
287         return deg;
288 }
289
290 int add::ldegree(const ex & s) const
291 {
292         int deg = std::numeric_limits<int>::max();
293         if (!overall_coeff.is_zero())
294                 deg = 0;
295         
296         // Find minimum of degrees of individual terms
297         epvector::const_iterator i = seq.begin(), end = seq.end();
298         while (i != end) {
299                 int cur_deg = i->rest.ldegree(s);
300                 if (cur_deg < deg)
301                         deg = cur_deg;
302                 ++i;
303         }
304         return deg;
305 }
306
307 ex add::coeff(const ex & s, int n) const
308 {
309         std::auto_ptr<epvector> coeffseq(new epvector);
310         std::auto_ptr<epvector> coeffseq_cliff(new epvector);
311         int rl = clifford_max_label(s);
312         bool do_clifford = (rl != -1);
313         bool nonscalar = false;
314
315         // Calculate sum of coefficients in each term
316         epvector::const_iterator i = seq.begin(), end = seq.end();
317         while (i != end) {
318                 ex restcoeff = i->rest.coeff(s, n);
319                 if (!restcoeff.is_zero()) {
320                         if (do_clifford) {
321                                 if (clifford_max_label(restcoeff) == -1) {
322                                         coeffseq_cliff->push_back(combine_ex_with_coeff_to_pair(ncmul(restcoeff, dirac_ONE(rl)), i->coeff));
323                                 } else {
324                                         coeffseq_cliff->push_back(combine_ex_with_coeff_to_pair(restcoeff, i->coeff));
325                                         nonscalar = true;
326                                 }
327                         }
328                         coeffseq->push_back(combine_ex_with_coeff_to_pair(restcoeff, i->coeff));
329                 }
330                 ++i;
331         }
332
333         return (new add(nonscalar ? coeffseq_cliff : coeffseq,
334                         n==0 ? overall_coeff : _ex0))->setflag(status_flags::dynallocated);
335 }
336
337 /** Perform automatic term rewriting rules in this class.  In the following
338  *  x stands for a symbolic variables of type ex and c stands for such
339  *  an expression that contain a plain number.
340  *  - +(;c) -> c
341  *  - +(x;0) -> x
342  *
343  *  @param level cut-off in recursive evaluation */
344 ex add::eval(int level) const
345 {
346         std::auto_ptr<epvector> evaled_seqp = evalchildren(level);
347         if (evaled_seqp.get()) {
348                 // do more evaluation later
349                 return (new add(evaled_seqp, overall_coeff))->
350                        setflag(status_flags::dynallocated);
351         }
352         
353 #ifdef DO_GINAC_ASSERT
354         epvector::const_iterator i = seq.begin(), end = seq.end();
355         while (i != end) {
356                 GINAC_ASSERT(!is_exactly_a<add>(i->rest));
357                 ++i;
358         }
359 #endif // def DO_GINAC_ASSERT
360         
361         if (flags & status_flags::evaluated) {
362                 GINAC_ASSERT(seq.size()>0);
363                 GINAC_ASSERT(seq.size()>1 || !overall_coeff.is_zero());
364                 return *this;
365         }
366         
367         int seq_size = seq.size();
368         if (seq_size == 0) {
369                 // +(;c) -> c
370                 return overall_coeff;
371         } else if (seq_size == 1 && overall_coeff.is_zero()) {
372                 // +(x;0) -> x
373                 return recombine_pair_to_ex(*(seq.begin()));
374         } else if (!overall_coeff.is_zero() && seq[0].rest.return_type() != return_types::commutative) {
375                 throw (std::logic_error("add::eval(): sum of non-commutative objects has non-zero numeric term"));
376         }
377         
378         // if any terms in the sum still are purely numeric, then they are more
379         // appropriately collected into the overall coefficient
380         epvector::const_iterator last = seq.end();
381         epvector::const_iterator j = seq.begin();
382         int terms_to_collect = 0;
383         while (j != last) {
384                 if (unlikely(is_a<numeric>(j->rest)))
385                         ++terms_to_collect;
386                 ++j;
387         }
388         if (terms_to_collect) {
389                 std::auto_ptr<epvector> s(new epvector);
390                 s->reserve(seq_size - terms_to_collect);
391                 numeric oc = *_num1_p;
392                 j = seq.begin();
393                 while (j != last) {
394                         if (unlikely(is_a<numeric>(j->rest)))
395                                 oc = oc.mul(ex_to<numeric>(j->rest)).mul(ex_to<numeric>(j->coeff));
396                         else
397                                 s->push_back(*j);
398                         ++j;
399                 }
400                 return (new add(s, ex_to<numeric>(overall_coeff).add_dyn(oc)))
401                         ->setflag(status_flags::dynallocated);
402         }
403         
404         return this->hold();
405 }
406
407 ex add::evalm() const
408 {
409         // Evaluate children first and add up all matrices. Stop if there's one
410         // term that is not a matrix.
411         std::auto_ptr<epvector> s(new epvector);
412         s->reserve(seq.size());
413
414         bool all_matrices = true;
415         bool first_term = true;
416         matrix sum;
417
418         epvector::const_iterator it = seq.begin(), itend = seq.end();
419         while (it != itend) {
420                 const ex &m = recombine_pair_to_ex(*it).evalm();
421                 s->push_back(split_ex_to_pair(m));
422                 if (is_a<matrix>(m)) {
423                         if (first_term) {
424                                 sum = ex_to<matrix>(m);
425                                 first_term = false;
426                         } else
427                                 sum = sum.add(ex_to<matrix>(m));
428                 } else
429                         all_matrices = false;
430                 ++it;
431         }
432
433         if (all_matrices)
434                 return sum + overall_coeff;
435         else
436                 return (new add(s, overall_coeff))->setflag(status_flags::dynallocated);
437 }
438
439 ex add::conjugate() const
440 {
441         exvector *v = 0;
442         for (size_t i=0; i<nops(); ++i) {
443                 if (v) {
444                         v->push_back(op(i).conjugate());
445                         continue;
446                 }
447                 ex term = op(i);
448                 ex ccterm = term.conjugate();
449                 if (are_ex_trivially_equal(term, ccterm))
450                         continue;
451                 v = new exvector;
452                 v->reserve(nops());
453                 for (size_t j=0; j<i; ++j)
454                         v->push_back(op(j));
455                 v->push_back(ccterm);
456         }
457         if (v) {
458                 ex result = add(*v);
459                 delete v;
460                 return result;
461         }
462         return *this;
463 }
464
465 ex add::real_part() const
466 {
467         epvector v;
468         v.reserve(seq.size());
469         for (epvector::const_iterator i=seq.begin(); i!=seq.end(); ++i)
470                 if ((i->coeff).info(info_flags::real)) {
471                         ex rp = (i->rest).real_part();
472                         if (!rp.is_zero())
473                                 v.push_back(expair(rp, i->coeff));
474                 } else {
475                         ex rp=recombine_pair_to_ex(*i).real_part();
476                         if (!rp.is_zero())
477                                 v.push_back(split_ex_to_pair(rp));
478                 }
479         return (new add(v, overall_coeff.real_part()))
480                 -> setflag(status_flags::dynallocated);
481 }
482
483 ex add::imag_part() const
484 {
485         epvector v;
486         v.reserve(seq.size());
487         for (epvector::const_iterator i=seq.begin(); i!=seq.end(); ++i)
488                 if ((i->coeff).info(info_flags::real)) {
489                         ex ip = (i->rest).imag_part();
490                         if (!ip.is_zero())
491                                 v.push_back(expair(ip, i->coeff));
492                 } else {
493                         ex ip=recombine_pair_to_ex(*i).imag_part();
494                         if (!ip.is_zero())
495                                 v.push_back(split_ex_to_pair(ip));
496                 }
497         return (new add(v, overall_coeff.imag_part()))
498                 -> setflag(status_flags::dynallocated);
499 }
500
501 ex add::eval_ncmul(const exvector & v) const
502 {
503         if (seq.empty())
504                 return inherited::eval_ncmul(v);
505         else
506                 return seq.begin()->rest.eval_ncmul(v);
507 }    
508
509 // protected
510
511 /** Implementation of ex::diff() for a sum. It differentiates each term.
512  *  @see ex::diff */
513 ex add::derivative(const symbol & y) const
514 {
515         std::auto_ptr<epvector> s(new epvector);
516         s->reserve(seq.size());
517         
518         // Only differentiate the "rest" parts of the expairs. This is faster
519         // than the default implementation in basic::derivative() although
520         // if performs the same function (differentiate each term).
521         epvector::const_iterator i = seq.begin(), end = seq.end();
522         while (i != end) {
523                 s->push_back(combine_ex_with_coeff_to_pair(i->rest.diff(y), i->coeff));
524                 ++i;
525         }
526         return (new add(s, _ex0))->setflag(status_flags::dynallocated);
527 }
528
529 int add::compare_same_type(const basic & other) const
530 {
531         return inherited::compare_same_type(other);
532 }
533
534 unsigned add::return_type() const
535 {
536         if (seq.empty())
537                 return return_types::commutative;
538         else
539                 return seq.begin()->rest.return_type();
540 }
541
542 return_type_t add::return_type_tinfo() const
543 {
544         if (seq.empty())
545                 return make_return_type_t<add>();
546         else
547                 return seq.begin()->rest.return_type_tinfo();
548 }
549
550 // Note: do_index_renaming is ignored because it makes no sense for an add.
551 ex add::thisexpairseq(const epvector & v, const ex & oc, bool do_index_renaming) const
552 {
553         return (new add(v,oc))->setflag(status_flags::dynallocated);
554 }
555
556 // Note: do_index_renaming is ignored because it makes no sense for an add.
557 ex add::thisexpairseq(std::auto_ptr<epvector> vp, const ex & oc, bool do_index_renaming) const
558 {
559         return (new add(vp,oc))->setflag(status_flags::dynallocated);
560 }
561
562 expair add::split_ex_to_pair(const ex & e) const
563 {
564         if (is_exactly_a<mul>(e)) {
565                 const mul &mulref(ex_to<mul>(e));
566                 const ex &numfactor = mulref.overall_coeff;
567                 mul *mulcopyp = new mul(mulref);
568                 mulcopyp->overall_coeff = _ex1;
569                 mulcopyp->clearflag(status_flags::evaluated);
570                 mulcopyp->clearflag(status_flags::hash_calculated);
571                 mulcopyp->setflag(status_flags::dynallocated);
572                 return expair(*mulcopyp,numfactor);
573         }
574         return expair(e,_ex1);
575 }
576
577 expair add::combine_ex_with_coeff_to_pair(const ex & e,
578                                                                                   const ex & c) const
579 {
580         GINAC_ASSERT(is_exactly_a<numeric>(c));
581         if (is_exactly_a<mul>(e)) {
582                 const mul &mulref(ex_to<mul>(e));
583                 const ex &numfactor = mulref.overall_coeff;
584                 mul *mulcopyp = new mul(mulref);
585                 mulcopyp->overall_coeff = _ex1;
586                 mulcopyp->clearflag(status_flags::evaluated);
587                 mulcopyp->clearflag(status_flags::hash_calculated);
588                 mulcopyp->setflag(status_flags::dynallocated);
589                 if (c.is_equal(_ex1))
590                         return expair(*mulcopyp, numfactor);
591                 else if (numfactor.is_equal(_ex1))
592                         return expair(*mulcopyp, c);
593                 else
594                         return expair(*mulcopyp, ex_to<numeric>(numfactor).mul_dyn(ex_to<numeric>(c)));
595         } else if (is_exactly_a<numeric>(e)) {
596                 if (c.is_equal(_ex1))
597                         return expair(e, _ex1);
598                 return expair(ex_to<numeric>(e).mul_dyn(ex_to<numeric>(c)), _ex1);
599         }
600         return expair(e, c);
601 }
602
603 expair add::combine_pair_with_coeff_to_pair(const expair & p,
604                                                                                         const ex & c) const
605 {
606         GINAC_ASSERT(is_exactly_a<numeric>(p.coeff));
607         GINAC_ASSERT(is_exactly_a<numeric>(c));
608
609         if (is_exactly_a<numeric>(p.rest)) {
610                 GINAC_ASSERT(ex_to<numeric>(p.coeff).is_equal(*_num1_p)); // should be normalized
611                 return expair(ex_to<numeric>(p.rest).mul_dyn(ex_to<numeric>(c)),_ex1);
612         }
613
614         return expair(p.rest,ex_to<numeric>(p.coeff).mul_dyn(ex_to<numeric>(c)));
615 }
616         
617 ex add::recombine_pair_to_ex(const expair & p) const
618 {
619         if (ex_to<numeric>(p.coeff).is_equal(*_num1_p))
620                 return p.rest;
621         else
622                 return (new mul(p.rest,p.coeff))->setflag(status_flags::dynallocated);
623 }
624
625 ex add::expand(unsigned options) const
626 {
627         std::auto_ptr<epvector> vp = expandchildren(options);
628         if (vp.get() == 0) {
629                 // the terms have not changed, so it is safe to declare this expanded
630                 return (options == 0) ? setflag(status_flags::expanded) : *this;
631         }
632
633         return (new add(vp, overall_coeff))->setflag(status_flags::dynallocated | (options == 0 ? status_flags::expanded : 0));
634 }
635
636 } // namespace GiNaC