]> www.ginac.de Git - ginac.git/blob - ginac/add.cpp
36bb2011013efee1abbe28fbc3ea38433dc1791d
[ginac.git] / ginac / add.cpp
1 /** @file add.cpp
2  *
3  *  Implementation of GiNaC's sums of expressions. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2015 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
21  */
22
23 #include "add.h"
24 #include "mul.h"
25 #include "archive.h"
26 #include "operators.h"
27 #include "matrix.h"
28 #include "utils.h"
29 #include "clifford.h"
30 #include "ncmul.h"
31 #include "compiler.h"
32
33 #include <iostream>
34 #include <limits>
35 #include <stdexcept>
36 #include <string>
37
38 namespace GiNaC {
39
40 GINAC_IMPLEMENT_REGISTERED_CLASS_OPT(add, expairseq,
41   print_func<print_context>(&add::do_print).
42   print_func<print_latex>(&add::do_print_latex).
43   print_func<print_csrc>(&add::do_print_csrc).
44   print_func<print_tree>(&add::do_print_tree).
45   print_func<print_python_repr>(&add::do_print_python_repr))
46
47 //////////
48 // default constructor
49 //////////
50
51 add::add()
52 {
53 }
54
55 //////////
56 // other constructors
57 //////////
58
59 // public
60
61 add::add(const ex & lh, const ex & rh)
62 {
63         overall_coeff = _ex0;
64         construct_from_2_ex(lh,rh);
65         GINAC_ASSERT(is_canonical());
66 }
67
68 add::add(const exvector & v)
69 {
70         overall_coeff = _ex0;
71         construct_from_exvector(v);
72         GINAC_ASSERT(is_canonical());
73 }
74
75 add::add(const epvector & v)
76 {
77         overall_coeff = _ex0;
78         construct_from_epvector(v);
79         GINAC_ASSERT(is_canonical());
80 }
81
82 add::add(const epvector & v, const ex & oc)
83 {
84         overall_coeff = oc;
85         construct_from_epvector(v);
86         GINAC_ASSERT(is_canonical());
87 }
88
89 add::add(epvector && vp, const ex & oc)
90 {
91         overall_coeff = oc;
92         construct_from_epvector(std::move(vp));
93         GINAC_ASSERT(is_canonical());
94 }
95
96 //////////
97 // archiving
98 //////////
99
100 GINAC_BIND_UNARCHIVER(add);
101
102 //////////
103 // functions overriding virtual functions from base classes
104 //////////
105
106 // public
107
108 void add::print_add(const print_context & c, const char *openbrace, const char *closebrace, const char *mul_sym, unsigned level) const
109 {
110         if (precedence() <= level)
111                 c.s << openbrace << '(';
112
113         numeric coeff;
114         bool first = true;
115
116         // First print the overall numeric coefficient, if present
117         if (!overall_coeff.is_zero()) {
118                 overall_coeff.print(c, 0);
119                 first = false;
120         }
121
122         // Then proceed with the remaining factors
123         epvector::const_iterator it = seq.begin(), itend = seq.end();
124         while (it != itend) {
125                 coeff = ex_to<numeric>(it->coeff);
126                 if (!first) {
127                         if (coeff.csgn() == -1) c.s << '-'; else c.s << '+';
128                 } else {
129                         if (coeff.csgn() == -1) c.s << '-';
130                         first = false;
131                 }
132                 if (!coeff.is_equal(*_num1_p) &&
133                     !coeff.is_equal(*_num_1_p)) {
134                         if (coeff.is_rational()) {
135                                 if (coeff.is_negative())
136                                         (-coeff).print(c);
137                                 else
138                                         coeff.print(c);
139                         } else {
140                                 if (coeff.csgn() == -1)
141                                         (-coeff).print(c, precedence());
142                                 else
143                                         coeff.print(c, precedence());
144                         }
145                         c.s << mul_sym;
146                 }
147                 it->rest.print(c, precedence());
148                 ++it;
149         }
150
151         if (precedence() <= level)
152                 c.s << ')' << closebrace;
153 }
154
155 void add::do_print(const print_context & c, unsigned level) const
156 {
157         print_add(c, "", "", "*", level);
158 }
159
160 void add::do_print_latex(const print_latex & c, unsigned level) const
161 {
162         print_add(c, "{", "}", " ", level);
163 }
164
165 void add::do_print_csrc(const print_csrc & c, unsigned level) const
166 {
167         if (precedence() <= level)
168                 c.s << "(";
169         
170         // Print arguments, separated by "+" or "-"
171         epvector::const_iterator it = seq.begin(), itend = seq.end();
172         char separator = ' ';
173         while (it != itend) {
174                 
175                 // If the coefficient is negative, separator is "-"
176                 if (it->coeff.is_equal(_ex_1) || 
177                         ex_to<numeric>(it->coeff).numer().is_equal(*_num_1_p))
178                         separator = '-';
179                 c.s << separator;
180                 if (it->coeff.is_equal(_ex1) || it->coeff.is_equal(_ex_1)) {
181                         it->rest.print(c, precedence());
182                 } else if (ex_to<numeric>(it->coeff).numer().is_equal(*_num1_p) ||
183                                  ex_to<numeric>(it->coeff).numer().is_equal(*_num_1_p))
184                 {
185                         it->rest.print(c, precedence());
186                         c.s << '/';
187                         ex_to<numeric>(it->coeff).denom().print(c, precedence());
188                 } else {
189                         it->coeff.print(c, precedence());
190                         c.s << '*';
191                         it->rest.print(c, precedence());
192                 }
193                 
194                 ++it;
195                 separator = '+';
196         }
197         
198         if (!overall_coeff.is_zero()) {
199                 if (overall_coeff.info(info_flags::positive)
200                  || is_a<print_csrc_cl_N>(c) || !overall_coeff.info(info_flags::real))  // sign inside ctor argument
201                         c.s << '+';
202                 overall_coeff.print(c, precedence());
203         }
204                 
205         if (precedence() <= level)
206                 c.s << ")";
207 }
208
209 void add::do_print_python_repr(const print_python_repr & c, unsigned level) const
210 {
211         c.s << class_name() << '(';
212         op(0).print(c);
213         for (size_t i=1; i<nops(); ++i) {
214                 c.s << ',';
215                 op(i).print(c);
216         }
217         c.s << ')';
218 }
219
220 bool add::info(unsigned inf) const
221 {
222         switch (inf) {
223                 case info_flags::polynomial:
224                 case info_flags::integer_polynomial:
225                 case info_flags::cinteger_polynomial:
226                 case info_flags::rational_polynomial:
227                 case info_flags::real:
228                 case info_flags::rational:
229                 case info_flags::integer:
230                 case info_flags::crational:
231                 case info_flags::cinteger:
232                 case info_flags::positive:
233                 case info_flags::nonnegative:
234                 case info_flags::posint:
235                 case info_flags::nonnegint:
236                 case info_flags::even:
237                 case info_flags::crational_polynomial:
238                 case info_flags::rational_function: {
239                         epvector::const_iterator i = seq.begin(), end = seq.end();
240                         while (i != end) {
241                                 if (!(recombine_pair_to_ex(*i).info(inf)))
242                                         return false;
243                                 ++i;
244                         }
245                         if (overall_coeff.is_zero() && (inf == info_flags::positive || inf == info_flags::posint))
246                                 return true;
247                         return overall_coeff.info(inf);
248                 }
249                 case info_flags::algebraic: {
250                         epvector::const_iterator i = seq.begin(), end = seq.end();
251                         while (i != end) {
252                                 if ((recombine_pair_to_ex(*i).info(inf)))
253                                         return true;
254                                 ++i;
255                         }
256                         return false;
257                 }
258         }
259         return inherited::info(inf);
260 }
261
262 bool add::is_polynomial(const ex & var) const
263 {
264         for (epvector::const_iterator i=seq.begin(); i!=seq.end(); ++i) {
265                 if (!(i->rest).is_polynomial(var)) {
266                         return false;
267                 }
268         }
269         return true;
270 }
271
272 int add::degree(const ex & s) const
273 {
274         int deg = std::numeric_limits<int>::min();
275         if (!overall_coeff.is_zero())
276                 deg = 0;
277         
278         // Find maximum of degrees of individual terms
279         epvector::const_iterator i = seq.begin(), end = seq.end();
280         while (i != end) {
281                 int cur_deg = i->rest.degree(s);
282                 if (cur_deg > deg)
283                         deg = cur_deg;
284                 ++i;
285         }
286         return deg;
287 }
288
289 int add::ldegree(const ex & s) const
290 {
291         int deg = std::numeric_limits<int>::max();
292         if (!overall_coeff.is_zero())
293                 deg = 0;
294         
295         // Find minimum of degrees of individual terms
296         epvector::const_iterator i = seq.begin(), end = seq.end();
297         while (i != end) {
298                 int cur_deg = i->rest.ldegree(s);
299                 if (cur_deg < deg)
300                         deg = cur_deg;
301                 ++i;
302         }
303         return deg;
304 }
305
306 ex add::coeff(const ex & s, int n) const
307 {
308         epvector coeffseq;
309         epvector coeffseq_cliff;
310         int rl = clifford_max_label(s);
311         bool do_clifford = (rl != -1);
312         bool nonscalar = false;
313
314         // Calculate sum of coefficients in each term
315         epvector::const_iterator i = seq.begin(), end = seq.end();
316         while (i != end) {
317                 ex restcoeff = i->rest.coeff(s, n);
318                 if (!restcoeff.is_zero()) {
319                         if (do_clifford) {
320                                 if (clifford_max_label(restcoeff) == -1) {
321                                         coeffseq_cliff.push_back(combine_ex_with_coeff_to_pair(ncmul(restcoeff, dirac_ONE(rl)), i->coeff));
322                                 } else {
323                                         coeffseq_cliff.push_back(combine_ex_with_coeff_to_pair(restcoeff, i->coeff));
324                                         nonscalar = true;
325                                 }
326                         }
327                         coeffseq.push_back(combine_ex_with_coeff_to_pair(restcoeff, i->coeff));
328                 }
329                 ++i;
330         }
331
332         return (new add(nonscalar ? std::move(coeffseq_cliff) : std::move(coeffseq),
333                         n==0 ? overall_coeff : _ex0))->setflag(status_flags::dynallocated);
334 }
335
336 /** Perform automatic term rewriting rules in this class.  In the following
337  *  x stands for a symbolic variables of type ex and c stands for such
338  *  an expression that contain a plain number.
339  *  - +(;c) -> c
340  *  - +(x;0) -> x
341  *
342  *  @param level cut-off in recursive evaluation */
343 ex add::eval(int level) const
344 {
345         epvector evaled = evalchildren(level);
346         if (!evaled.empty()) {
347                 // do more evaluation later
348                 return (new add(std::move(evaled), overall_coeff))->
349                         setflag(status_flags::dynallocated);
350         }
351
352 #ifdef DO_GINAC_ASSERT
353         epvector::const_iterator i = seq.begin(), end = seq.end();
354         while (i != end) {
355                 GINAC_ASSERT(!is_exactly_a<add>(i->rest));
356                 ++i;
357         }
358 #endif // def DO_GINAC_ASSERT
359         
360         if (flags & status_flags::evaluated) {
361                 GINAC_ASSERT(seq.size()>0);
362                 GINAC_ASSERT(seq.size()>1 || !overall_coeff.is_zero());
363                 return *this;
364         }
365         
366         int seq_size = seq.size();
367         if (seq_size == 0) {
368                 // +(;c) -> c
369                 return overall_coeff;
370         } else if (seq_size == 1 && overall_coeff.is_zero()) {
371                 // +(x;0) -> x
372                 return recombine_pair_to_ex(*(seq.begin()));
373         } else if (!overall_coeff.is_zero() && seq[0].rest.return_type() != return_types::commutative) {
374                 throw (std::logic_error("add::eval(): sum of non-commutative objects has non-zero numeric term"));
375         }
376         
377         // if any terms in the sum still are purely numeric, then they are more
378         // appropriately collected into the overall coefficient
379         epvector::const_iterator last = seq.end();
380         epvector::const_iterator j = seq.begin();
381         int terms_to_collect = 0;
382         while (j != last) {
383                 if (unlikely(is_a<numeric>(j->rest)))
384                         ++terms_to_collect;
385                 ++j;
386         }
387         if (terms_to_collect) {
388                 epvector s;
389                 s.reserve(seq_size - terms_to_collect);
390                 numeric oc = *_num1_p;
391                 j = seq.begin();
392                 while (j != last) {
393                         if (unlikely(is_a<numeric>(j->rest)))
394                                 oc = oc.mul(ex_to<numeric>(j->rest)).mul(ex_to<numeric>(j->coeff));
395                         else
396                                 s.push_back(*j);
397                         ++j;
398                 }
399                 return (new add(std::move(s), ex_to<numeric>(overall_coeff).add_dyn(oc)))
400                         ->setflag(status_flags::dynallocated);
401         }
402         
403         return this->hold();
404 }
405
406 ex add::evalm() const
407 {
408         // Evaluate children first and add up all matrices. Stop if there's one
409         // term that is not a matrix.
410         epvector s;
411         s.reserve(seq.size());
412
413         bool all_matrices = true;
414         bool first_term = true;
415         matrix sum;
416
417         epvector::const_iterator it = seq.begin(), itend = seq.end();
418         while (it != itend) {
419                 const ex &m = recombine_pair_to_ex(*it).evalm();
420                 s.push_back(split_ex_to_pair(m));
421                 if (is_a<matrix>(m)) {
422                         if (first_term) {
423                                 sum = ex_to<matrix>(m);
424                                 first_term = false;
425                         } else
426                                 sum = sum.add(ex_to<matrix>(m));
427                 } else
428                         all_matrices = false;
429                 ++it;
430         }
431
432         if (all_matrices)
433                 return sum + overall_coeff;
434         else
435                 return (new add(std::move(s), overall_coeff))->setflag(status_flags::dynallocated);
436 }
437
438 ex add::conjugate() const
439 {
440         exvector *v = 0;
441         for (size_t i=0; i<nops(); ++i) {
442                 if (v) {
443                         v->push_back(op(i).conjugate());
444                         continue;
445                 }
446                 ex term = op(i);
447                 ex ccterm = term.conjugate();
448                 if (are_ex_trivially_equal(term, ccterm))
449                         continue;
450                 v = new exvector;
451                 v->reserve(nops());
452                 for (size_t j=0; j<i; ++j)
453                         v->push_back(op(j));
454                 v->push_back(ccterm);
455         }
456         if (v) {
457                 ex result = add(*v);
458                 delete v;
459                 return result;
460         }
461         return *this;
462 }
463
464 ex add::real_part() const
465 {
466         epvector v;
467         v.reserve(seq.size());
468         for (epvector::const_iterator i=seq.begin(); i!=seq.end(); ++i)
469                 if ((i->coeff).info(info_flags::real)) {
470                         ex rp = (i->rest).real_part();
471                         if (!rp.is_zero())
472                                 v.push_back(expair(rp, i->coeff));
473                 } else {
474                         ex rp=recombine_pair_to_ex(*i).real_part();
475                         if (!rp.is_zero())
476                                 v.push_back(split_ex_to_pair(rp));
477                 }
478         return (new add(v, overall_coeff.real_part()))
479                 -> setflag(status_flags::dynallocated);
480 }
481
482 ex add::imag_part() const
483 {
484         epvector v;
485         v.reserve(seq.size());
486         for (epvector::const_iterator i=seq.begin(); i!=seq.end(); ++i)
487                 if ((i->coeff).info(info_flags::real)) {
488                         ex ip = (i->rest).imag_part();
489                         if (!ip.is_zero())
490                                 v.push_back(expair(ip, i->coeff));
491                 } else {
492                         ex ip=recombine_pair_to_ex(*i).imag_part();
493                         if (!ip.is_zero())
494                                 v.push_back(split_ex_to_pair(ip));
495                 }
496         return (new add(v, overall_coeff.imag_part()))
497                 -> setflag(status_flags::dynallocated);
498 }
499
500 ex add::eval_ncmul(const exvector & v) const
501 {
502         if (seq.empty())
503                 return inherited::eval_ncmul(v);
504         else
505                 return seq.begin()->rest.eval_ncmul(v);
506 }    
507
508 // protected
509
510 /** Implementation of ex::diff() for a sum. It differentiates each term.
511  *  @see ex::diff */
512 ex add::derivative(const symbol & y) const
513 {
514         epvector s;
515         s.reserve(seq.size());
516         
517         // Only differentiate the "rest" parts of the expairs. This is faster
518         // than the default implementation in basic::derivative() although
519         // if performs the same function (differentiate each term).
520         epvector::const_iterator i = seq.begin(), end = seq.end();
521         while (i != end) {
522                 s.push_back(combine_ex_with_coeff_to_pair(i->rest.diff(y), i->coeff));
523                 ++i;
524         }
525         return (new add(std::move(s), _ex0))->setflag(status_flags::dynallocated);
526 }
527
528 int add::compare_same_type(const basic & other) const
529 {
530         return inherited::compare_same_type(other);
531 }
532
533 unsigned add::return_type() const
534 {
535         if (seq.empty())
536                 return return_types::commutative;
537         else
538                 return seq.begin()->rest.return_type();
539 }
540
541 return_type_t add::return_type_tinfo() const
542 {
543         if (seq.empty())
544                 return make_return_type_t<add>();
545         else
546                 return seq.begin()->rest.return_type_tinfo();
547 }
548
549 // Note: do_index_renaming is ignored because it makes no sense for an add.
550 ex add::thisexpairseq(const epvector & v, const ex & oc, bool do_index_renaming) const
551 {
552         return (new add(v,oc))->setflag(status_flags::dynallocated);
553 }
554
555 // Note: do_index_renaming is ignored because it makes no sense for an add.
556 ex add::thisexpairseq(epvector && vp, const ex & oc, bool do_index_renaming) const
557 {
558         return (new add(std::move(vp), oc))->setflag(status_flags::dynallocated);
559 }
560
561 expair add::split_ex_to_pair(const ex & e) const
562 {
563         if (is_exactly_a<mul>(e)) {
564                 const mul &mulref(ex_to<mul>(e));
565                 const ex &numfactor = mulref.overall_coeff;
566                 mul *mulcopyp = new mul(mulref);
567                 mulcopyp->overall_coeff = _ex1;
568                 mulcopyp->clearflag(status_flags::evaluated);
569                 mulcopyp->clearflag(status_flags::hash_calculated);
570                 mulcopyp->setflag(status_flags::dynallocated);
571                 return expair(*mulcopyp,numfactor);
572         }
573         return expair(e,_ex1);
574 }
575
576 expair add::combine_ex_with_coeff_to_pair(const ex & e,
577                                           const ex & c) const
578 {
579         GINAC_ASSERT(is_exactly_a<numeric>(c));
580         if (is_exactly_a<mul>(e)) {
581                 const mul &mulref(ex_to<mul>(e));
582                 const ex &numfactor = mulref.overall_coeff;
583                 mul *mulcopyp = new mul(mulref);
584                 mulcopyp->overall_coeff = _ex1;
585                 mulcopyp->clearflag(status_flags::evaluated);
586                 mulcopyp->clearflag(status_flags::hash_calculated);
587                 mulcopyp->setflag(status_flags::dynallocated);
588                 if (c.is_equal(_ex1))
589                         return expair(*mulcopyp, numfactor);
590                 else if (numfactor.is_equal(_ex1))
591                         return expair(*mulcopyp, c);
592                 else
593                         return expair(*mulcopyp, ex_to<numeric>(numfactor).mul_dyn(ex_to<numeric>(c)));
594         } else if (is_exactly_a<numeric>(e)) {
595                 if (c.is_equal(_ex1))
596                         return expair(e, _ex1);
597                 return expair(ex_to<numeric>(e).mul_dyn(ex_to<numeric>(c)), _ex1);
598         }
599         return expair(e, c);
600 }
601
602 expair add::combine_pair_with_coeff_to_pair(const expair & p,
603                                             const ex & c) const
604 {
605         GINAC_ASSERT(is_exactly_a<numeric>(p.coeff));
606         GINAC_ASSERT(is_exactly_a<numeric>(c));
607
608         if (is_exactly_a<numeric>(p.rest)) {
609                 GINAC_ASSERT(ex_to<numeric>(p.coeff).is_equal(*_num1_p)); // should be normalized
610                 return expair(ex_to<numeric>(p.rest).mul_dyn(ex_to<numeric>(c)),_ex1);
611         }
612
613         return expair(p.rest,ex_to<numeric>(p.coeff).mul_dyn(ex_to<numeric>(c)));
614 }
615
616 ex add::recombine_pair_to_ex(const expair & p) const
617 {
618         if (ex_to<numeric>(p.coeff).is_equal(*_num1_p))
619                 return p.rest;
620         else
621                 return (new mul(p.rest,p.coeff))->setflag(status_flags::dynallocated);
622 }
623
624 ex add::expand(unsigned options) const
625 {
626         epvector expanded = expandchildren(options);
627         if (expanded.empty())
628                 return (options == 0) ? setflag(status_flags::expanded) : *this;
629
630         return (new add(std::move(expanded), overall_coeff))->setflag(status_flags::dynallocated |
631                                                                       (options == 0 ? status_flags::expanded : 0));
632 }
633
634 } // namespace GiNaC