]> www.ginac.de Git - ginac.git/blob - ginac/add.cpp
Tutorial: how to create noncommutative symbols?
[ginac.git] / ginac / add.cpp
1 /** @file add.cpp
2  *
3  *  Implementation of GiNaC's sums of expressions. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2015 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
21  */
22
23 #include "add.h"
24 #include "mul.h"
25 #include "archive.h"
26 #include "operators.h"
27 #include "matrix.h"
28 #include "utils.h"
29 #include "clifford.h"
30 #include "ncmul.h"
31 #include "compiler.h"
32
33 #include <iostream>
34 #include <limits>
35 #include <stdexcept>
36 #include <string>
37
38 namespace GiNaC {
39
// Class registration for add (base class: expairseq).  The chained
// print_func<> entries bind each print context type to the member function
// that renders a sum in that context.  (The macro itself is defined outside
// this file; presumably it also sets up the registered-class info — confirm.)
GINAC_IMPLEMENT_REGISTERED_CLASS_OPT(add, expairseq,
  print_func<print_context>(&add::do_print).
  print_func<print_latex>(&add::do_print_latex).
  print_func<print_csrc>(&add::do_print_csrc).
  print_func<print_tree>(&add::do_print_tree).
  print_func<print_python_repr>(&add::do_print_python_repr))
46
47 //////////
48 // default constructor
49 //////////
50
51 add::add()
52 {
53 }
54
55 //////////
56 // other constructors
57 //////////
58
59 // public
60
61 add::add(const ex & lh, const ex & rh)
62 {
63         overall_coeff = _ex0;
64         construct_from_2_ex(lh,rh);
65         GINAC_ASSERT(is_canonical());
66 }
67
68 add::add(const exvector & v)
69 {
70         overall_coeff = _ex0;
71         construct_from_exvector(v);
72         GINAC_ASSERT(is_canonical());
73 }
74
75 add::add(const epvector & v)
76 {
77         overall_coeff = _ex0;
78         construct_from_epvector(v);
79         GINAC_ASSERT(is_canonical());
80 }
81
82 add::add(const epvector & v, const ex & oc)
83 {
84         overall_coeff = oc;
85         construct_from_epvector(v);
86         GINAC_ASSERT(is_canonical());
87 }
88
89 add::add(epvector && vp, const ex & oc)
90 {
91         overall_coeff = oc;
92         construct_from_epvector(std::move(vp));
93         GINAC_ASSERT(is_canonical());
94 }
95
96 //////////
97 // archiving
98 //////////
99
// Hook class add into the archiving machinery (macro presumably provided by
// archive.h, which is included above — confirm).
GINAC_BIND_UNARCHIVER(add);
101
102 //////////
103 // functions overriding virtual functions from base classes
104 //////////
105
106 // public
107
108 void add::print_add(const print_context & c, const char *openbrace, const char *closebrace, const char *mul_sym, unsigned level) const
109 {
110         if (precedence() <= level)
111                 c.s << openbrace << '(';
112
113         numeric coeff;
114         bool first = true;
115
116         // First print the overall numeric coefficient, if present
117         if (!overall_coeff.is_zero()) {
118                 overall_coeff.print(c, 0);
119                 first = false;
120         }
121
122         // Then proceed with the remaining factors
123         for (auto & it : seq) {
124                 coeff = ex_to<numeric>(it.coeff);
125                 if (!first) {
126                         if (coeff.csgn() == -1) c.s << '-'; else c.s << '+';
127                 } else {
128                         if (coeff.csgn() == -1) c.s << '-';
129                         first = false;
130                 }
131                 if (!coeff.is_equal(*_num1_p) &&
132                     !coeff.is_equal(*_num_1_p)) {
133                         if (coeff.is_rational()) {
134                                 if (coeff.is_negative())
135                                         (-coeff).print(c);
136                                 else
137                                         coeff.print(c);
138                         } else {
139                                 if (coeff.csgn() == -1)
140                                         (-coeff).print(c, precedence());
141                                 else
142                                         coeff.print(c, precedence());
143                         }
144                         c.s << mul_sym;
145                 }
146                 it.rest.print(c, precedence());
147         }
148
149         if (precedence() <= level)
150                 c.s << ')' << closebrace;
151 }
152
153 void add::do_print(const print_context & c, unsigned level) const
154 {
155         print_add(c, "", "", "*", level);
156 }
157
158 void add::do_print_latex(const print_latex & c, unsigned level) const
159 {
160         print_add(c, "{", "}", " ", level);
161 }
162
163 void add::do_print_csrc(const print_csrc & c, unsigned level) const
164 {
165         if (precedence() <= level)
166                 c.s << "(";
167         
168         // Print arguments, separated by "+" or "-"
169         char separator = ' ';
170         for (auto & it : seq) {
171                 
172                 // If the coefficient is negative, separator is "-"
173                 if (it.coeff.is_equal(_ex_1) ||
174                         ex_to<numeric>(it.coeff).numer().is_equal(*_num_1_p))
175                         separator = '-';
176                 c.s << separator;
177                 if (it.coeff.is_equal(_ex1) || it.coeff.is_equal(_ex_1)) {
178                         it.rest.print(c, precedence());
179                 } else if (ex_to<numeric>(it.coeff).numer().is_equal(*_num1_p) ||
180                                  ex_to<numeric>(it.coeff).numer().is_equal(*_num_1_p))
181                 {
182                         it.rest.print(c, precedence());
183                         c.s << '/';
184                         ex_to<numeric>(it.coeff).denom().print(c, precedence());
185                 } else {
186                         it.coeff.print(c, precedence());
187                         c.s << '*';
188                         it.rest.print(c, precedence());
189                 }
190                 
191                 separator = '+';
192         }
193         
194         if (!overall_coeff.is_zero()) {
195                 if (overall_coeff.info(info_flags::positive)
196                  || is_a<print_csrc_cl_N>(c) || !overall_coeff.info(info_flags::real))  // sign inside ctor argument
197                         c.s << '+';
198                 overall_coeff.print(c, precedence());
199         }
200                 
201         if (precedence() <= level)
202                 c.s << ")";
203 }
204
205 void add::do_print_python_repr(const print_python_repr & c, unsigned level) const
206 {
207         c.s << class_name() << '(';
208         op(0).print(c);
209         for (size_t i=1; i<nops(); ++i) {
210                 c.s << ',';
211                 op(i).print(c);
212         }
213         c.s << ')';
214 }
215
216 bool add::info(unsigned inf) const
217 {
218         switch (inf) {
219                 case info_flags::polynomial:
220                 case info_flags::integer_polynomial:
221                 case info_flags::cinteger_polynomial:
222                 case info_flags::rational_polynomial:
223                 case info_flags::real:
224                 case info_flags::rational:
225                 case info_flags::integer:
226                 case info_flags::crational:
227                 case info_flags::cinteger:
228                 case info_flags::positive:
229                 case info_flags::nonnegative:
230                 case info_flags::posint:
231                 case info_flags::nonnegint:
232                 case info_flags::even:
233                 case info_flags::crational_polynomial:
234                 case info_flags::rational_function: {
235                         for (auto & i : seq) {
236                                 if (!(recombine_pair_to_ex(i).info(inf)))
237                                         return false;
238                         }
239                         if (overall_coeff.is_zero() && (inf == info_flags::positive || inf == info_flags::posint))
240                                 return true;
241                         return overall_coeff.info(inf);
242                 }
243                 case info_flags::algebraic: {
244                         epvector::const_iterator i = seq.begin(), end = seq.end();
245                         while (i != end) {
246                                 if ((recombine_pair_to_ex(*i).info(inf)))
247                                         return true;
248                                 ++i;
249                         }
250                         return false;
251                 }
252         }
253         return inherited::info(inf);
254 }
255
256 bool add::is_polynomial(const ex & var) const
257 {
258         for (auto & i : seq) {
259                 if (!i.rest.is_polynomial(var)) {
260                         return false;
261                 }
262         }
263         return true;
264 }
265
266 int add::degree(const ex & s) const
267 {
268         int deg = std::numeric_limits<int>::min();
269         if (!overall_coeff.is_zero())
270                 deg = 0;
271         
272         // Find maximum of degrees of individual terms
273         for (auto & i : seq) {
274                 int cur_deg = i.rest.degree(s);
275                 if (cur_deg > deg)
276                         deg = cur_deg;
277         }
278         return deg;
279 }
280
281 int add::ldegree(const ex & s) const
282 {
283         int deg = std::numeric_limits<int>::max();
284         if (!overall_coeff.is_zero())
285                 deg = 0;
286         
287         // Find minimum of degrees of individual terms
288         for (auto & i : seq) {
289                 int cur_deg = i.rest.ldegree(s);
290                 if (cur_deg < deg)
291                         deg = cur_deg;
292         }
293         return deg;
294 }
295
296 ex add::coeff(const ex & s, int n) const
297 {
298         epvector coeffseq;
299         epvector coeffseq_cliff;
300         int rl = clifford_max_label(s);
301         bool do_clifford = (rl != -1);
302         bool nonscalar = false;
303
304         // Calculate sum of coefficients in each term
305         for (auto & i : seq) {
306                 ex restcoeff = i.rest.coeff(s, n);
307                 if (!restcoeff.is_zero()) {
308                         if (do_clifford) {
309                                 if (clifford_max_label(restcoeff) == -1) {
310                                         coeffseq_cliff.push_back(combine_ex_with_coeff_to_pair(ncmul(restcoeff, dirac_ONE(rl)), i.coeff));
311                                 } else {
312                                         coeffseq_cliff.push_back(combine_ex_with_coeff_to_pair(restcoeff, i.coeff));
313                                         nonscalar = true;
314                                 }
315                         }
316                         coeffseq.push_back(combine_ex_with_coeff_to_pair(restcoeff, i.coeff));
317                 }
318         }
319
320         return (new add(nonscalar ? std::move(coeffseq_cliff) : std::move(coeffseq),
321                         n==0 ? overall_coeff : _ex0))->setflag(status_flags::dynallocated);
322 }
323
324 /** Perform automatic term rewriting rules in this class.  In the following
325  *  x stands for a symbolic variables of type ex and c stands for such
326  *  an expression that contain a plain number.
327  *  - +(;c) -> c
328  *  - +(x;0) -> x
329  *
330  *  @param level cut-off in recursive evaluation */
331 ex add::eval(int level) const
332 {
333         epvector evaled = evalchildren(level);
334         if (unlikely(!evaled.empty())) {
335                 // do more evaluation later
336                 return (new add(std::move(evaled), overall_coeff))->
337                         setflag(status_flags::dynallocated);
338         }
339
340 #ifdef DO_GINAC_ASSERT
341         for (auto & i : seq) {
342                 GINAC_ASSERT(!is_exactly_a<add>(i.rest));
343         }
344 #endif // def DO_GINAC_ASSERT
345         
346         if (flags & status_flags::evaluated) {
347                 GINAC_ASSERT(seq.size()>0);
348                 GINAC_ASSERT(seq.size()>1 || !overall_coeff.is_zero());
349                 return *this;
350         }
351         
352         int seq_size = seq.size();
353         if (seq_size == 0) {
354                 // +(;c) -> c
355                 return overall_coeff;
356         } else if (seq_size == 1 && overall_coeff.is_zero()) {
357                 // +(x;0) -> x
358                 return recombine_pair_to_ex(*(seq.begin()));
359         } else if (!overall_coeff.is_zero() && seq[0].rest.return_type() != return_types::commutative) {
360                 throw (std::logic_error("add::eval(): sum of non-commutative objects has non-zero numeric term"));
361         }
362         
363         // if any terms in the sum still are purely numeric, then they are more
364         // appropriately collected into the overall coefficient
365         int terms_to_collect = 0;
366         for (auto & it : seq) {
367                 if (unlikely(is_a<numeric>(it.rest)))
368                         ++terms_to_collect;
369         }
370         if (terms_to_collect) {
371                 epvector s;
372                 s.reserve(seq_size - terms_to_collect);
373                 numeric oc = *_num1_p;
374                 for (auto & it : seq) {
375                         if (unlikely(is_a<numeric>(it.rest)))
376                                 oc = oc.mul(ex_to<numeric>(it.rest)).mul(ex_to<numeric>(it.coeff));
377                         else
378                                 s.push_back(it);
379                 }
380                 return (new add(std::move(s), ex_to<numeric>(overall_coeff).add_dyn(oc)))
381                         ->setflag(status_flags::dynallocated);
382         }
383         
384         return this->hold();
385 }
386
387 ex add::evalm() const
388 {
389         // Evaluate children first and add up all matrices. Stop if there's one
390         // term that is not a matrix.
391         epvector s;
392         s.reserve(seq.size());
393
394         bool all_matrices = true;
395         bool first_term = true;
396         matrix sum;
397
398         for (auto & it : seq) {
399                 const ex &m = recombine_pair_to_ex(it).evalm();
400                 s.push_back(split_ex_to_pair(m));
401                 if (is_a<matrix>(m)) {
402                         if (first_term) {
403                                 sum = ex_to<matrix>(m);
404                                 first_term = false;
405                         } else
406                                 sum = sum.add(ex_to<matrix>(m));
407                 } else
408                         all_matrices = false;
409         }
410
411         if (all_matrices)
412                 return sum + overall_coeff;
413         else
414                 return (new add(std::move(s), overall_coeff))->setflag(status_flags::dynallocated);
415 }
416
417 ex add::conjugate() const
418 {
419         std::unique_ptr<exvector> v(nullptr);
420         for (size_t i=0; i<nops(); ++i) {
421                 if (v) {
422                         v->push_back(op(i).conjugate());
423                         continue;
424                 }
425                 ex term = op(i);
426                 ex ccterm = term.conjugate();
427                 if (are_ex_trivially_equal(term, ccterm))
428                         continue;
429                 v.reset(new exvector);
430                 v->reserve(nops());
431                 for (size_t j=0; j<i; ++j)
432                         v->push_back(op(j));
433                 v->push_back(ccterm);
434         }
435         if (v) {
436                 return add(std::move(*v));
437         }
438         return *this;
439 }
440
441 ex add::real_part() const
442 {
443         epvector v;
444         v.reserve(seq.size());
445         for (auto & it : seq)
446                 if (it.coeff.info(info_flags::real)) {
447                         ex rp = it.rest.real_part();
448                         if (!rp.is_zero())
449                                 v.push_back(expair(rp, it.coeff));
450                 } else {
451                         ex rp = recombine_pair_to_ex(it).real_part();
452                         if (!rp.is_zero())
453                                 v.push_back(split_ex_to_pair(rp));
454                 }
455         return (new add(std::move(v), overall_coeff.real_part()))
456                 -> setflag(status_flags::dynallocated);
457 }
458
459 ex add::imag_part() const
460 {
461         epvector v;
462         v.reserve(seq.size());
463         for (auto & it : seq)
464                 if (it.coeff.info(info_flags::real)) {
465                         ex ip = it.rest.imag_part();
466                         if (!ip.is_zero())
467                                 v.push_back(expair(ip, it.coeff));
468                 } else {
469                         ex ip = recombine_pair_to_ex(it).imag_part();
470                         if (!ip.is_zero())
471                                 v.push_back(split_ex_to_pair(ip));
472                 }
473         return (new add(std::move(v), overall_coeff.imag_part()))
474                 -> setflag(status_flags::dynallocated);
475 }
476
477 ex add::eval_ncmul(const exvector & v) const
478 {
479         if (seq.empty())
480                 return inherited::eval_ncmul(v);
481         else
482                 return seq.begin()->rest.eval_ncmul(v);
483 }    
484
485 // protected
486
487 /** Implementation of ex::diff() for a sum. It differentiates each term.
488  *  @see ex::diff */
489 ex add::derivative(const symbol & y) const
490 {
491         epvector s;
492         s.reserve(seq.size());
493         
494         // Only differentiate the "rest" parts of the expairs. This is faster
495         // than the default implementation in basic::derivative() although
496         // if performs the same function (differentiate each term).
497         for (auto & it : seq)
498                 s.push_back(combine_ex_with_coeff_to_pair(it.rest.diff(y), it.coeff));
499
500         return (new add(std::move(s), _ex0))->setflag(status_flags::dynallocated);
501 }
502
503 int add::compare_same_type(const basic & other) const
504 {
505         return inherited::compare_same_type(other);
506 }
507
508 unsigned add::return_type() const
509 {
510         if (seq.empty())
511                 return return_types::commutative;
512         else
513                 return seq.begin()->rest.return_type();
514 }
515
516 return_type_t add::return_type_tinfo() const
517 {
518         if (seq.empty())
519                 return make_return_type_t<add>();
520         else
521                 return seq.begin()->rest.return_type_tinfo();
522 }
523
524 // Note: do_index_renaming is ignored because it makes no sense for an add.
525 ex add::thisexpairseq(const epvector & v, const ex & oc, bool do_index_renaming) const
526 {
527         return (new add(v,oc))->setflag(status_flags::dynallocated);
528 }
529
530 // Note: do_index_renaming is ignored because it makes no sense for an add.
531 ex add::thisexpairseq(epvector && vp, const ex & oc, bool do_index_renaming) const
532 {
533         return (new add(std::move(vp), oc))->setflag(status_flags::dynallocated);
534 }
535
536 expair add::split_ex_to_pair(const ex & e) const
537 {
538         if (is_exactly_a<mul>(e)) {
539                 const mul &mulref(ex_to<mul>(e));
540                 const ex &numfactor = mulref.overall_coeff;
541                 if (numfactor.is_equal(_ex1))
542                         return expair(e, _ex1);
543                 mul *mulcopyp = new mul(mulref);
544                 mulcopyp->overall_coeff = _ex1;
545                 mulcopyp->clearflag(status_flags::evaluated);
546                 mulcopyp->clearflag(status_flags::hash_calculated);
547                 mulcopyp->setflag(status_flags::dynallocated);
548                 return expair(*mulcopyp,numfactor);
549         }
550         return expair(e,_ex1);
551 }
552
553 expair add::combine_ex_with_coeff_to_pair(const ex & e,
554                                           const ex & c) const
555 {
556         GINAC_ASSERT(is_exactly_a<numeric>(c));
557         if (is_exactly_a<mul>(e)) {
558                 const mul &mulref(ex_to<mul>(e));
559                 const ex &numfactor = mulref.overall_coeff;
560                 if (numfactor.is_equal(_ex1))
561                         return expair(e, c);
562                 mul *mulcopyp = new mul(mulref);
563                 mulcopyp->overall_coeff = _ex1;
564                 mulcopyp->clearflag(status_flags::evaluated);
565                 mulcopyp->clearflag(status_flags::hash_calculated);
566                 mulcopyp->setflag(status_flags::dynallocated);
567                 if (c.is_equal(_ex1))
568                         return expair(*mulcopyp, numfactor);
569                 else
570                         return expair(*mulcopyp, ex_to<numeric>(numfactor).mul_dyn(ex_to<numeric>(c)));
571         } else if (is_exactly_a<numeric>(e)) {
572                 if (c.is_equal(_ex1))
573                         return expair(e, _ex1);
574                 if (e.is_equal(_ex1))
575                         return expair(c, _ex1);
576                 return expair(ex_to<numeric>(e).mul_dyn(ex_to<numeric>(c)), _ex1);
577         }
578         return expair(e, c);
579 }
580
581 expair add::combine_pair_with_coeff_to_pair(const expair & p,
582                                             const ex & c) const
583 {
584         GINAC_ASSERT(is_exactly_a<numeric>(p.coeff));
585         GINAC_ASSERT(is_exactly_a<numeric>(c));
586
587         if (is_exactly_a<numeric>(p.rest)) {
588                 GINAC_ASSERT(ex_to<numeric>(p.coeff).is_equal(*_num1_p)); // should be normalized
589                 return expair(ex_to<numeric>(p.rest).mul_dyn(ex_to<numeric>(c)),_ex1);
590         }
591
592         return expair(p.rest,ex_to<numeric>(p.coeff).mul_dyn(ex_to<numeric>(c)));
593 }
594
595 ex add::recombine_pair_to_ex(const expair & p) const
596 {
597         if (ex_to<numeric>(p.coeff).is_equal(*_num1_p))
598                 return p.rest;
599         else
600                 return (new mul(p.rest,p.coeff))->setflag(status_flags::dynallocated);
601 }
602
603 ex add::expand(unsigned options) const
604 {
605         epvector expanded = expandchildren(options);
606         if (expanded.empty())
607                 return (options == 0) ? setflag(status_flags::expanded) : *this;
608
609         return (new add(std::move(expanded), overall_coeff))->setflag(status_flags::dynallocated |
610                                                                       (options == 0 ? status_flags::expanded : 0));
611 }
612
613 } // namespace GiNaC