]> www.ginac.de Git - ginac.git/blob - ginac/add.cpp
79994e11c5ddf47e54b84daf883b014b1f1410a4
[ginac.git] / ginac / add.cpp
1 /** @file add.cpp
2  *
3  *  Implementation of GiNaC's sums of expressions. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2015 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
21  */
22
23 #include "add.h"
24 #include "mul.h"
25 #include "archive.h"
26 #include "operators.h"
27 #include "matrix.h"
28 #include "utils.h"
29 #include "clifford.h"
30 #include "ncmul.h"
31 #include "compiler.h"
32
33 #include <iostream>
34 #include <limits>
35 #include <stdexcept>
36 #include <string>
37
38 namespace GiNaC {
39
40 GINAC_IMPLEMENT_REGISTERED_CLASS_OPT(add, expairseq,
41   print_func<print_context>(&add::do_print).
42   print_func<print_latex>(&add::do_print_latex).
43   print_func<print_csrc>(&add::do_print_csrc).
44   print_func<print_tree>(&add::do_print_tree).
45   print_func<print_python_repr>(&add::do_print_python_repr))
46
47 //////////
48 // default constructor
49 //////////
50
51 add::add()
52 {
53 }
54
55 //////////
56 // other constructors
57 //////////
58
59 // public
60
61 add::add(const ex & lh, const ex & rh)
62 {
63         overall_coeff = _ex0;
64         construct_from_2_ex(lh,rh);
65         GINAC_ASSERT(is_canonical());
66 }
67
68 add::add(const exvector & v)
69 {
70         overall_coeff = _ex0;
71         construct_from_exvector(v);
72         GINAC_ASSERT(is_canonical());
73 }
74
75 add::add(const epvector & v)
76 {
77         overall_coeff = _ex0;
78         construct_from_epvector(v);
79         GINAC_ASSERT(is_canonical());
80 }
81
82 add::add(const epvector & v, const ex & oc)
83 {
84         overall_coeff = oc;
85         construct_from_epvector(v);
86         GINAC_ASSERT(is_canonical());
87 }
88
89 add::add(epvector && vp)
90 {
91         overall_coeff = _ex0;
92         construct_from_epvector(std::move(vp));
93         GINAC_ASSERT(is_canonical());
94 }
95
96 add::add(epvector && vp, const ex & oc)
97 {
98         overall_coeff = oc;
99         construct_from_epvector(std::move(vp));
100         GINAC_ASSERT(is_canonical());
101 }
102
103 //////////
104 // archiving
105 //////////
106
107 GINAC_BIND_UNARCHIVER(add);
108
109 //////////
110 // functions overriding virtual functions from base classes
111 //////////
112
113 // public
114
115 void add::print_add(const print_context & c, const char *openbrace, const char *closebrace, const char *mul_sym, unsigned level) const
116 {
117         if (precedence() <= level)
118                 c.s << openbrace << '(';
119
120         numeric coeff;
121         bool first = true;
122
123         // First print the overall numeric coefficient, if present
124         if (!overall_coeff.is_zero()) {
125                 overall_coeff.print(c, 0);
126                 first = false;
127         }
128
129         // Then proceed with the remaining factors
130         for (auto & it : seq) {
131                 coeff = ex_to<numeric>(it.coeff);
132                 if (!first) {
133                         if (coeff.csgn() == -1) c.s << '-'; else c.s << '+';
134                 } else {
135                         if (coeff.csgn() == -1) c.s << '-';
136                         first = false;
137                 }
138                 if (!coeff.is_equal(*_num1_p) &&
139                     !coeff.is_equal(*_num_1_p)) {
140                         if (coeff.is_rational()) {
141                                 if (coeff.is_negative())
142                                         (-coeff).print(c);
143                                 else
144                                         coeff.print(c);
145                         } else {
146                                 if (coeff.csgn() == -1)
147                                         (-coeff).print(c, precedence());
148                                 else
149                                         coeff.print(c, precedence());
150                         }
151                         c.s << mul_sym;
152                 }
153                 it.rest.print(c, precedence());
154         }
155
156         if (precedence() <= level)
157                 c.s << ')' << closebrace;
158 }
159
160 void add::do_print(const print_context & c, unsigned level) const
161 {
162         print_add(c, "", "", "*", level);
163 }
164
165 void add::do_print_latex(const print_latex & c, unsigned level) const
166 {
167         print_add(c, "{", "}", " ", level);
168 }
169
170 void add::do_print_csrc(const print_csrc & c, unsigned level) const
171 {
172         if (precedence() <= level)
173                 c.s << "(";
174         
175         // Print arguments, separated by "+" or "-"
176         char separator = ' ';
177         for (auto & it : seq) {
178                 
179                 // If the coefficient is negative, separator is "-"
180                 if (it.coeff.is_equal(_ex_1) ||
181                         ex_to<numeric>(it.coeff).numer().is_equal(*_num_1_p))
182                         separator = '-';
183                 c.s << separator;
184                 if (it.coeff.is_equal(_ex1) || it.coeff.is_equal(_ex_1)) {
185                         it.rest.print(c, precedence());
186                 } else if (ex_to<numeric>(it.coeff).numer().is_equal(*_num1_p) ||
187                                  ex_to<numeric>(it.coeff).numer().is_equal(*_num_1_p))
188                 {
189                         it.rest.print(c, precedence());
190                         c.s << '/';
191                         ex_to<numeric>(it.coeff).denom().print(c, precedence());
192                 } else {
193                         it.coeff.print(c, precedence());
194                         c.s << '*';
195                         it.rest.print(c, precedence());
196                 }
197                 
198                 separator = '+';
199         }
200         
201         if (!overall_coeff.is_zero()) {
202                 if (overall_coeff.info(info_flags::positive)
203                  || is_a<print_csrc_cl_N>(c) || !overall_coeff.info(info_flags::real))  // sign inside ctor argument
204                         c.s << '+';
205                 overall_coeff.print(c, precedence());
206         }
207                 
208         if (precedence() <= level)
209                 c.s << ")";
210 }
211
212 void add::do_print_python_repr(const print_python_repr & c, unsigned level) const
213 {
214         c.s << class_name() << '(';
215         op(0).print(c);
216         for (size_t i=1; i<nops(); ++i) {
217                 c.s << ',';
218                 op(i).print(c);
219         }
220         c.s << ')';
221 }
222
223 bool add::info(unsigned inf) const
224 {
225         switch (inf) {
226                 case info_flags::polynomial:
227                 case info_flags::integer_polynomial:
228                 case info_flags::cinteger_polynomial:
229                 case info_flags::rational_polynomial:
230                 case info_flags::real:
231                 case info_flags::rational:
232                 case info_flags::integer:
233                 case info_flags::crational:
234                 case info_flags::cinteger:
235                 case info_flags::positive:
236                 case info_flags::nonnegative:
237                 case info_flags::posint:
238                 case info_flags::nonnegint:
239                 case info_flags::even:
240                 case info_flags::crational_polynomial:
241                 case info_flags::rational_function: {
242                         for (auto & i : seq) {
243                                 if (!(recombine_pair_to_ex(i).info(inf)))
244                                         return false;
245                         }
246                         if (overall_coeff.is_zero() && (inf == info_flags::positive || inf == info_flags::posint))
247                                 return true;
248                         return overall_coeff.info(inf);
249                 }
250                 case info_flags::algebraic: {
251                         epvector::const_iterator i = seq.begin(), end = seq.end();
252                         while (i != end) {
253                                 if ((recombine_pair_to_ex(*i).info(inf)))
254                                         return true;
255                                 ++i;
256                         }
257                         return false;
258                 }
259         }
260         return inherited::info(inf);
261 }
262
263 bool add::is_polynomial(const ex & var) const
264 {
265         for (auto & i : seq) {
266                 if (!i.rest.is_polynomial(var)) {
267                         return false;
268                 }
269         }
270         return true;
271 }
272
273 int add::degree(const ex & s) const
274 {
275         int deg = std::numeric_limits<int>::min();
276         if (!overall_coeff.is_zero())
277                 deg = 0;
278         
279         // Find maximum of degrees of individual terms
280         for (auto & i : seq) {
281                 int cur_deg = i.rest.degree(s);
282                 if (cur_deg > deg)
283                         deg = cur_deg;
284         }
285         return deg;
286 }
287
288 int add::ldegree(const ex & s) const
289 {
290         int deg = std::numeric_limits<int>::max();
291         if (!overall_coeff.is_zero())
292                 deg = 0;
293         
294         // Find minimum of degrees of individual terms
295         for (auto & i : seq) {
296                 int cur_deg = i.rest.ldegree(s);
297                 if (cur_deg < deg)
298                         deg = cur_deg;
299         }
300         return deg;
301 }
302
303 ex add::coeff(const ex & s, int n) const
304 {
305         epvector coeffseq;
306         epvector coeffseq_cliff;
307         int rl = clifford_max_label(s);
308         bool do_clifford = (rl != -1);
309         bool nonscalar = false;
310
311         // Calculate sum of coefficients in each term
312         for (auto & i : seq) {
313                 ex restcoeff = i.rest.coeff(s, n);
314                 if (!restcoeff.is_zero()) {
315                         if (do_clifford) {
316                                 if (clifford_max_label(restcoeff) == -1) {
317                                         coeffseq_cliff.push_back(expair(ncmul(restcoeff, dirac_ONE(rl)), i.coeff));
318                                 } else {
319                                         coeffseq_cliff.push_back(expair(restcoeff, i.coeff));
320                                         nonscalar = true;
321                                 }
322                         }
323                         coeffseq.push_back(expair(restcoeff, i.coeff));
324                 }
325         }
326
327         return dynallocate<add>(nonscalar ? std::move(coeffseq_cliff) : std::move(coeffseq),
328                                 n==0 ? overall_coeff : _ex0);
329 }
330
331 /** Perform automatic term rewriting rules in this class.  In the following
332  *  x stands for a symbolic variables of type ex and c stands for such
333  *  an expression that contain a plain number.
334  *  - +(;c) -> c
335  *  - +(x;0) -> x
336  */
337 ex add::eval() const
338 {
339         if (flags & status_flags::evaluated) {
340                 GINAC_ASSERT(seq.size()>0);
341                 GINAC_ASSERT(seq.size()>1 || !overall_coeff.is_zero());
342                 return *this;
343         }
344
345         const epvector evaled = evalchildren();
346         if (unlikely(!evaled.empty())) {
347                 // start over evaluating a new object
348                 return dynallocate<add>(std::move(evaled), overall_coeff);
349         }
350
351 #ifdef DO_GINAC_ASSERT
352         for (auto & i : seq) {
353                 GINAC_ASSERT(!is_exactly_a<add>(i.rest));
354         }
355 #endif // def DO_GINAC_ASSERT
356
357         int seq_size = seq.size();
358         if (seq_size == 0) {
359                 // +(;c) -> c
360                 return overall_coeff;
361         } else if (seq_size == 1 && overall_coeff.is_zero()) {
362                 // +(x;0) -> x
363                 return recombine_pair_to_ex(*(seq.begin()));
364         } else if (!overall_coeff.is_zero() && seq[0].rest.return_type() != return_types::commutative) {
365                 throw (std::logic_error("add::eval(): sum of non-commutative objects has non-zero numeric term"));
366         }
367         
368         // if any terms in the sum still are purely numeric, then they are more
369         // appropriately collected into the overall coefficient
370         int terms_to_collect = 0;
371         for (auto & it : seq) {
372                 if (unlikely(is_a<numeric>(it.rest)))
373                         ++terms_to_collect;
374         }
375         if (terms_to_collect) {
376                 epvector s;
377                 s.reserve(seq_size - terms_to_collect);
378                 numeric oc = *_num1_p;
379                 for (auto & it : seq) {
380                         if (unlikely(is_a<numeric>(it.rest)))
381                                 oc = oc.mul(ex_to<numeric>(it.rest)).mul(ex_to<numeric>(it.coeff));
382                         else
383                                 s.push_back(it);
384                 }
385                 return dynallocate<add>(std::move(s), ex_to<numeric>(overall_coeff).add_dyn(oc));
386         }
387         
388         return this->hold();
389 }
390
391 ex add::evalm() const
392 {
393         // Evaluate children first and add up all matrices. Stop if there's one
394         // term that is not a matrix.
395         epvector s;
396         s.reserve(seq.size());
397
398         bool all_matrices = true;
399         bool first_term = true;
400         matrix sum;
401
402         for (auto & it : seq) {
403                 const ex &m = recombine_pair_to_ex(it).evalm();
404                 s.push_back(split_ex_to_pair(m));
405                 if (is_a<matrix>(m)) {
406                         if (first_term) {
407                                 sum = ex_to<matrix>(m);
408                                 first_term = false;
409                         } else
410                                 sum = sum.add(ex_to<matrix>(m));
411                 } else
412                         all_matrices = false;
413         }
414
415         if (all_matrices)
416                 return sum + overall_coeff;
417         else
418                 return dynallocate<add>(std::move(s), overall_coeff);
419 }
420
421 ex add::conjugate() const
422 {
423         std::unique_ptr<exvector> v(nullptr);
424         for (size_t i=0; i<nops(); ++i) {
425                 if (v) {
426                         v->push_back(op(i).conjugate());
427                         continue;
428                 }
429                 ex term = op(i);
430                 ex ccterm = term.conjugate();
431                 if (are_ex_trivially_equal(term, ccterm))
432                         continue;
433                 v.reset(new exvector);
434                 v->reserve(nops());
435                 for (size_t j=0; j<i; ++j)
436                         v->push_back(op(j));
437                 v->push_back(ccterm);
438         }
439         if (v) {
440                 return add(std::move(*v));
441         }
442         return *this;
443 }
444
445 ex add::real_part() const
446 {
447         epvector v;
448         v.reserve(seq.size());
449         for (auto & it : seq)
450                 if (it.coeff.info(info_flags::real)) {
451                         ex rp = it.rest.real_part();
452                         if (!rp.is_zero())
453                                 v.push_back(expair(rp, it.coeff));
454                 } else {
455                         ex rp = recombine_pair_to_ex(it).real_part();
456                         if (!rp.is_zero())
457                                 v.push_back(split_ex_to_pair(rp));
458                 }
459         return dynallocate<add>(std::move(v), overall_coeff.real_part());
460 }
461
462 ex add::imag_part() const
463 {
464         epvector v;
465         v.reserve(seq.size());
466         for (auto & it : seq)
467                 if (it.coeff.info(info_flags::real)) {
468                         ex ip = it.rest.imag_part();
469                         if (!ip.is_zero())
470                                 v.push_back(expair(ip, it.coeff));
471                 } else {
472                         ex ip = recombine_pair_to_ex(it).imag_part();
473                         if (!ip.is_zero())
474                                 v.push_back(split_ex_to_pair(ip));
475                 }
476         return dynallocate<add>(std::move(v), overall_coeff.imag_part());
477 }
478
479 ex add::eval_ncmul(const exvector & v) const
480 {
481         if (seq.empty())
482                 return inherited::eval_ncmul(v);
483         else
484                 return seq.begin()->rest.eval_ncmul(v);
485 }    
486
487 // protected
488
489 /** Implementation of ex::diff() for a sum. It differentiates each term.
490  *  @see ex::diff */
491 ex add::derivative(const symbol & y) const
492 {
493         epvector s;
494         s.reserve(seq.size());
495         
496         // Only differentiate the "rest" parts of the expairs. This is faster
497         // than the default implementation in basic::derivative() although
498         // if performs the same function (differentiate each term).
499         for (auto & it : seq)
500                 s.push_back(expair(it.rest.diff(y), it.coeff));
501
502         return dynallocate<add>(std::move(s));
503 }
504
505 int add::compare_same_type(const basic & other) const
506 {
507         return inherited::compare_same_type(other);
508 }
509
510 unsigned add::return_type() const
511 {
512         if (seq.empty())
513                 return return_types::commutative;
514         else
515                 return seq.begin()->rest.return_type();
516 }
517
518 return_type_t add::return_type_tinfo() const
519 {
520         if (seq.empty())
521                 return make_return_type_t<add>();
522         else
523                 return seq.begin()->rest.return_type_tinfo();
524 }
525
526 // Note: do_index_renaming is ignored because it makes no sense for an add.
527 ex add::thisexpairseq(const epvector & v, const ex & oc, bool do_index_renaming) const
528 {
529         return dynallocate<add>(v, oc);
530 }
531
532 // Note: do_index_renaming is ignored because it makes no sense for an add.
533 ex add::thisexpairseq(epvector && vp, const ex & oc, bool do_index_renaming) const
534 {
535         return dynallocate<add>(std::move(vp), oc);
536 }
537
538 expair add::split_ex_to_pair(const ex & e) const
539 {
540         if (is_exactly_a<mul>(e)) {
541                 const mul &mulref(ex_to<mul>(e));
542                 const ex &numfactor = mulref.overall_coeff;
543                 if (numfactor.is_equal(_ex1))
544                         return expair(e, _ex1);
545                 mul & mulcopy = dynallocate<mul>(mulref);
546                 mulcopy.overall_coeff = _ex1;
547                 mulcopy.clearflag(status_flags::evaluated | status_flags::hash_calculated);
548                 return expair(mulcopy, numfactor);
549         }
550         return expair(e,_ex1);
551 }
552
553 expair add::combine_ex_with_coeff_to_pair(const ex & e,
554                                           const ex & c) const
555 {
556         GINAC_ASSERT(is_exactly_a<numeric>(c));
557         if (is_exactly_a<mul>(e)) {
558                 const mul &mulref(ex_to<mul>(e));
559                 const ex &numfactor = mulref.overall_coeff;
560                 if (likely(numfactor.is_equal(_ex1)))
561                         return expair(e, c);
562                 mul & mulcopy = dynallocate<mul>(mulref);
563                 mulcopy.overall_coeff = _ex1;
564                 mulcopy.clearflag(status_flags::evaluated | status_flags::hash_calculated);
565                 if (c.is_equal(_ex1))
566                         return expair(mulcopy, numfactor);
567                 else
568                         return expair(mulcopy, ex_to<numeric>(numfactor).mul_dyn(ex_to<numeric>(c)));
569         } else if (is_exactly_a<numeric>(e)) {
570                 if (c.is_equal(_ex1))
571                         return expair(e, _ex1);
572                 if (e.is_equal(_ex1))
573                         return expair(c, _ex1);
574                 return expair(ex_to<numeric>(e).mul_dyn(ex_to<numeric>(c)), _ex1);
575         }
576         return expair(e, c);
577 }
578
579 expair add::combine_pair_with_coeff_to_pair(const expair & p,
580                                             const ex & c) const
581 {
582         GINAC_ASSERT(is_exactly_a<numeric>(p.coeff));
583         GINAC_ASSERT(is_exactly_a<numeric>(c));
584
585         if (is_exactly_a<numeric>(p.rest)) {
586                 GINAC_ASSERT(ex_to<numeric>(p.coeff).is_equal(*_num1_p)); // should be normalized
587                 return expair(ex_to<numeric>(p.rest).mul_dyn(ex_to<numeric>(c)),_ex1);
588         }
589
590         return expair(p.rest,ex_to<numeric>(p.coeff).mul_dyn(ex_to<numeric>(c)));
591 }
592
593 ex add::recombine_pair_to_ex(const expair & p) const
594 {
595         if (ex_to<numeric>(p.coeff).is_equal(*_num1_p))
596                 return p.rest;
597         else
598                 return dynallocate<mul>(p.rest, p.coeff);
599 }
600
601 ex add::expand(unsigned options) const
602 {
603         epvector expanded = expandchildren(options);
604         if (expanded.empty())
605                 return (options == 0) ? setflag(status_flags::expanded) : *this;
606
607         return dynallocate<add>(std::move(expanded), overall_coeff).setflag(options == 0 ? status_flags::expanded : 0);
608 }
609
610 } // namespace GiNaC