]> www.ginac.de Git - ginac.git/blob - ginac/ncmul.cpp
[DOC] Update the tutorial for GiNaC::function resolution.
[ginac.git] / ginac / ncmul.cpp
1 /** @file ncmul.cpp
2  *
3  *  Implementation of GiNaC's non-commutative products of expressions. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2025 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
21  */
22
23 #include "ncmul.h"
24 #include "ex.h"
25 #include "add.h"
26 #include "mul.h"
27 #include "clifford.h"
28 #include "matrix.h"
29 #include "archive.h"
30 #include "indexed.h"
31 #include "utils.h"
32
33 #include <algorithm>
34 #include <stdexcept>
35
36 namespace GiNaC {
37
// Register ncmul as a class derived from exprseq and bind its print methods:
// generic print_context output uses do_print, print_tree uses do_print_tree
// (not defined in this file; presumably inherited), and both C-source and
// Python-repr output share do_print_csrc.
GINAC_IMPLEMENT_REGISTERED_CLASS_OPT(ncmul, exprseq,
  print_func<print_context>(&ncmul::do_print).
  print_func<print_tree>(&ncmul::do_print_tree).
  print_func<print_csrc>(&ncmul::do_print_csrc).
  print_func<print_python_repr>(&ncmul::do_print_csrc))
43
44
45 //////////
46 // default constructor
47 //////////
48
/** Default constructor.  The factor sequence is left default-constructed
 *  by the exprseq base (presumably empty -- an ncmul with no factors). */
ncmul::ncmul()
{
}
52
53 //////////
54 // other constructors
55 //////////
56
57 // public
58
/** Construct the non-commutative product lh*rh; factor order is kept. */
ncmul::ncmul(const ex & lh, const ex & rh) : inherited{lh,rh}
{
}

/** Construct the non-commutative product f1*f2*f3 in this order. */
ncmul::ncmul(const ex & f1, const ex & f2, const ex & f3) : inherited{f1,f2,f3}
{
}

/** Construct the non-commutative product f1*...*f4 in this order. */
ncmul::ncmul(const ex & f1, const ex & f2, const ex & f3,
             const ex & f4) : inherited{f1,f2,f3,f4}
{
}

/** Construct the non-commutative product f1*...*f5 in this order. */
ncmul::ncmul(const ex & f1, const ex & f2, const ex & f3,
             const ex & f4, const ex & f5) : inherited{f1,f2,f3,f4,f5}
{
}

/** Construct the non-commutative product f1*...*f6 in this order. */
ncmul::ncmul(const ex & f1, const ex & f2, const ex & f3,
             const ex & f4, const ex & f5, const ex & f6) : inherited{f1,f2,f3,f4,f5,f6}
{
}

/** Construct from a vector of factors, which is copied. */
ncmul::ncmul(const exvector & v) : inherited(v)
{
}

/** Construct from a vector of factors, taking over its contents. */
ncmul::ncmul(exvector && v) : inherited(std::move(v))
{
}
89
90 //////////
91 // archiving
92 //////////
93
94
95 //////////
96 // functions overriding virtual functions from base classes
97 //////////
98
99 // public
100
101 void ncmul::do_print(const print_context & c, unsigned level) const
102 {
103         printseq(c, '(', '*', ')', precedence(), level);
104 }
105
106 void ncmul::do_print_csrc(const print_context & c, unsigned level) const
107 {
108         c.s << class_name();
109         printseq(c, '(', ',', ')', precedence(), precedence());
110 }
111
112 bool ncmul::info(unsigned inf) const
113 {
114         return inherited::info(inf);
115 }
116
117 typedef std::vector<std::size_t> uintvector;
118
119 ex ncmul::expand(unsigned options) const
120 {
121         // First, expand the children
122         exvector v = expandchildren(options);
123         const exvector &expanded_seq = v.empty() ? this->seq : v;
124         
125         // Now, look for all the factors that are sums and remember their
126         // position and number of terms.
127         uintvector positions_of_adds(expanded_seq.size());
128         uintvector number_of_add_operands(expanded_seq.size());
129
130         size_t number_of_adds = 0;
131         size_t number_of_expanded_terms = 1;
132
133         size_t current_position = 0;
134         for (auto & it : expanded_seq) {
135                 if (is_exactly_a<add>(it)) {
136                         positions_of_adds[number_of_adds] = current_position;
137                         size_t num_ops = it.nops();
138                         number_of_add_operands[number_of_adds] = num_ops;
139                         number_of_expanded_terms *= num_ops;
140                         number_of_adds++;
141                 }
142                 ++current_position;
143         }
144
145         // If there are no sums, we are done
146         if (number_of_adds == 0) {
147                 if (!v.empty())
148                         return dynallocate<ncmul>(std::move(v)).setflag(options == 0 ? status_flags::expanded : 0);
149                 else
150                         return *this;
151         }
152
153         // Now, form all possible products of the terms of the sums with the
154         // remaining factors, and add them together
155         exvector distrseq;
156         distrseq.reserve(number_of_expanded_terms);
157
158         uintvector k(number_of_adds);
159
160         /* Rename indices in the static members of the product */
161         exvector expanded_seq_mod;
162         size_t j = 0;
163         exvector va;
164
165         for (size_t i=0; i<expanded_seq.size(); i++) {
166                 if (i == positions_of_adds[j]) {
167                         expanded_seq_mod.push_back(_ex1);
168                         j++;
169                 } else {
170                         expanded_seq_mod.push_back(rename_dummy_indices_uniquely(va, expanded_seq[i], true));
171                 }
172         }
173
174         while (true) {
175                 exvector term = expanded_seq_mod;
176                 for (size_t i=0; i<number_of_adds; i++) {
177                         term[positions_of_adds[i]] = rename_dummy_indices_uniquely(va, expanded_seq[positions_of_adds[i]].op(k[i]), true);
178                 }
179
180                 distrseq.push_back(dynallocate<ncmul>(std::move(term)).setflag(options == 0 ? status_flags::expanded : 0));
181
182                 // increment k[]
183                 int l = number_of_adds-1;
184                 while ((l>=0) && ((++k[l]) >= number_of_add_operands[l])) {
185                         k[l] = 0;
186                         l--;
187                 }
188                 if (l<0)
189                         break;
190         }
191
192         return dynallocate<add>(distrseq).setflag(options == 0 ? status_flags::expanded : 0);
193 }
194
195 int ncmul::degree(const ex & s) const
196 {
197         if (is_equal(ex_to<basic>(s)))
198                 return 1;
199
200         // Sum up degrees of factors
201         int deg_sum = 0;
202         for (auto & i : seq)
203                 deg_sum += i.degree(s);
204         return deg_sum;
205 }
206
207 int ncmul::ldegree(const ex & s) const
208 {
209         if (is_equal(ex_to<basic>(s)))
210                 return 1;
211
212         // Sum up degrees of factors
213         int deg_sum = 0;
214         for (auto & i : seq)
215                 deg_sum += i.degree(s);
216         return deg_sum;
217 }
218
219 ex ncmul::coeff(const ex & s, int n) const
220 {
221         if (is_equal(ex_to<basic>(s)))
222                 return n==1 ? _ex1 : _ex0;
223
224         exvector coeffseq;
225         coeffseq.reserve(seq.size());
226
227         if (n == 0) {
228                 // product of individual coeffs
229                 // if a non-zero power of s is found, the resulting product will be 0
230                 for (auto & it : seq)
231                         coeffseq.push_back(it.coeff(s,n));
232                 return dynallocate<ncmul>(std::move(coeffseq));
233         }
234                  
235         bool coeff_found = false;
236         for (auto & i : seq) {
237                 ex c = i.coeff(s,n);
238                 if (c.is_zero()) {
239                         coeffseq.push_back(i);
240                 } else {
241                         coeffseq.push_back(c);
242                         coeff_found = true;
243                 }
244         }
245
246         if (coeff_found)
247                 return dynallocate<ncmul>(std::move(coeffseq));
248         
249         return _ex0;
250 }
251
252 size_t ncmul::count_factors(const ex & e) const
253 {
254         if ((is_exactly_a<mul>(e)&&(e.return_type()!=return_types::commutative))||
255                 (is_exactly_a<ncmul>(e))) {
256                 size_t factors=0;
257                 for (size_t i=0; i<e.nops(); i++)
258                         factors += count_factors(e.op(i));
259                 
260                 return factors;
261         }
262         return 1;
263 }
264                 
265 void ncmul::append_factors(exvector & v, const ex & e) const
266 {
267         if ((is_exactly_a<mul>(e)&&(e.return_type()!=return_types::commutative))||
268                 (is_exactly_a<ncmul>(e))) {
269                 for (size_t i=0; i<e.nops(); i++)
270                         append_factors(v, e.op(i));
271         } else 
272                 v.push_back(e);
273 }
274
275 typedef std::vector<unsigned> unsignedvector;
276 typedef std::vector<exvector> exvectorvector;
277
278 /** Perform automatic term rewriting rules in this class.  In the following
279  *  x, x1, x2,... stand for a symbolic variables of type ex and c, c1, c2...
280  *  stand for such expressions that contain a plain number.
281  *  - ncmul(...,*(x1,x2),...,ncmul(x3,x4),...) -> ncmul(...,x1,x2,...,x3,x4,...)  (associativity)
282  *  - ncmul(x) -> x
283  *  - ncmul() -> 1
284  *  - ncmul(...,c1,...,c2,...) -> *(c1,c2,ncmul(...))  (pull out commutative elements)
285  *  - ncmul(x1,y1,x2,y2) -> *(ncmul(x1,x2),ncmul(y1,y2))  (collect elements of same type)
286  *  - ncmul(x1,x2,x3,...) -> x::eval_ncmul(x1,x2,x3,...)
287  */
288 ex ncmul::eval() const
289 {
290         // The following additional rule would be nice, but produces a recursion,
291         // which must be trapped by introducing a flag that the sub-ncmuls()
292         // are already evaluated (maybe later...)
293         //                  ncmul(x1,x2,...,X,y1,y2,...) ->
294         //                      ncmul(ncmul(x1,x2,...),X,ncmul(y1,y2,...)
295         //                      (X noncommutative_composite)
296
297         if (flags & status_flags::evaluated) {
298                 return *this;
299         }
300
301         // ncmul(...,*(x1,x2),...,ncmul(x3,x4),...) ->
302         //     ncmul(...,x1,x2,...,x3,x4,...)  (associativity)
303         size_t factors = 0;
304         for (auto & it : seq)
305                 factors += count_factors(it);
306         
307         exvector assocseq;
308         assocseq.reserve(factors);
309         make_flat_inserter mf(seq, true);
310         for (auto & it : seq) {
311                 ex factor = mf.handle_factor(it, 1);
312                 append_factors(assocseq, factor);
313         }
314         
315         // ncmul(x) -> x
316         if (assocseq.size()==1) return *(seq.begin());
317
318         // ncmul() -> 1
319         if (assocseq.empty()) return _ex1;
320
321         // determine return types
322         unsignedvector rettypes(assocseq.size());
323         size_t i = 0;
324         size_t count_commutative=0;
325         size_t count_noncommutative=0;
326         size_t count_noncommutative_composite=0;
327         for (auto & it : assocseq) {
328                 rettypes[i] = it.return_type();
329                 switch (rettypes[i]) {
330                 case return_types::commutative:
331                         count_commutative++;
332                         break;
333                 case return_types::noncommutative:
334                         count_noncommutative++;
335                         break;
336                 case return_types::noncommutative_composite:
337                         count_noncommutative_composite++;
338                         break;
339                 default:
340                         throw(std::logic_error("ncmul::eval(): invalid return type"));
341                 }
342                 ++i;
343         }
344         GINAC_ASSERT(count_commutative+count_noncommutative+count_noncommutative_composite==assocseq.size());
345
346         // ncmul(...,c1,...,c2,...) ->
347         //     *(c1,c2,ncmul(...)) (pull out commutative elements)
348         if (count_commutative!=0) {
349                 exvector commutativeseq;
350                 commutativeseq.reserve(count_commutative+1);
351                 exvector noncommutativeseq;
352                 noncommutativeseq.reserve(assocseq.size()-count_commutative);
353                 size_t num = assocseq.size();
354                 for (size_t i=0; i<num; ++i) {
355                         if (rettypes[i]==return_types::commutative)
356                                 commutativeseq.push_back(assocseq[i]);
357                         else
358                                 noncommutativeseq.push_back(assocseq[i]);
359                 }
360                 commutativeseq.push_back(dynallocate<ncmul>(std::move(noncommutativeseq)));
361                 return dynallocate<mul>(std::move(commutativeseq));
362         }
363                 
364         // ncmul(x1,y1,x2,y2) -> *(ncmul(x1,x2),ncmul(y1,y2))
365         //     (collect elements of same type)
366
367         if (count_noncommutative_composite==0) {
368                 // there are neither commutative nor noncommutative_composite
369                 // elements in assocseq
370                 GINAC_ASSERT(count_commutative==0);
371
372                 size_t assoc_num = assocseq.size();
373                 exvectorvector evv;
374                 std::vector<return_type_t> rttinfos;
375                 evv.reserve(assoc_num);
376                 rttinfos.reserve(assoc_num);
377
378                 for (auto & it : assocseq) {
379                         return_type_t ti = it.return_type_tinfo();
380                         size_t rtt_num = rttinfos.size();
381                         // search type in vector of known types
382                         for (i=0; i<rtt_num; ++i) {
383                                 if(ti == rttinfos[i]) {
384                                         evv[i].push_back(it);
385                                         break;
386                                 }
387                         }
388                         if (i >= rtt_num) {
389                                 // new type
390                                 rttinfos.push_back(ti);
391                                 evv.push_back(exvector());
392                                 (evv.end()-1)->reserve(assoc_num);
393                                 (evv.end()-1)->push_back(it);
394                         }
395                 }
396
397                 size_t evv_num = evv.size();
398 #ifdef DO_GINAC_ASSERT
399                 GINAC_ASSERT(evv_num == rttinfos.size());
400                 GINAC_ASSERT(evv_num > 0);
401                 size_t s=0;
402                 for (i=0; i<evv_num; ++i)
403                         s += evv[i].size();
404                 GINAC_ASSERT(s == assoc_num);
405 #endif // def DO_GINAC_ASSERT
406                 
407                 // if all elements are of same type, simplify the string
408                 if (evv_num == 1) {
409                         return evv[0][0].eval_ncmul(evv[0]);
410                 }
411                 
412                 exvector splitseq;
413                 splitseq.reserve(evv_num);
414                 for (i=0; i<evv_num; ++i)
415                         splitseq.push_back(dynallocate<ncmul>(evv[i]));
416                 
417                 return dynallocate<mul>(splitseq);
418         }
419         
420         return dynallocate<ncmul>(assocseq).setflag(status_flags::evaluated);
421 }
422
423 ex ncmul::evalm() const
424 {
425         // Evaluate children first
426         exvector s;
427         s.reserve(seq.size());
428         for (auto & it : seq)
429                 s.push_back(it.evalm());
430
431         // If there are only matrices, simply multiply them
432         auto it = s.begin(), itend = s.end();
433         if (is_a<matrix>(*it)) {
434                 matrix prod(ex_to<matrix>(*it));
435                 it++;
436                 while (it != itend) {
437                         if (!is_a<matrix>(*it))
438                                 goto no_matrix;
439                         prod = prod.mul(ex_to<matrix>(*it));
440                         it++;
441                 }
442                 return prod;
443         }
444
445 no_matrix:
446         return dynallocate<ncmul>(std::move(s));
447 }
448
449 ex ncmul::thiscontainer(const exvector & v) const
450 {
451         return dynallocate<ncmul>(v);
452 }
453
454 ex ncmul::thiscontainer(exvector && v) const
455 {
456         return dynallocate<ncmul>(std::move(v));
457 }
458
459 ex ncmul::conjugate() const
460 {
461         if (return_type() != return_types::noncommutative) {
462                 return exprseq::conjugate();
463         }
464
465         if (!is_clifford_tinfo(return_type_tinfo())) {
466                 return exprseq::conjugate();
467         }
468
469         exvector ev;
470         ev.reserve(nops());
471         for (auto i=end(); i!=begin();) {
472                 --i;
473                 ev.push_back(i->conjugate());
474         }
475         return dynallocate<ncmul>(std::move(ev));
476 }
477
478 ex ncmul::real_part() const
479 {
480         return basic::real_part();
481 }
482
483 ex ncmul::imag_part() const
484 {
485         return basic::imag_part();
486 }
487
488 // protected
489
490 /** Implementation of ex::diff() for a non-commutative product. It applies
491  *  the product rule.
492  *  @see ex::diff */
493 ex ncmul::derivative(const symbol & s) const
494 {
495         size_t num = seq.size();
496         exvector addseq;
497         addseq.reserve(num);
498         
499         // D(a*b*c) = D(a)*b*c + a*D(b)*c + a*b*D(c)
500         exvector ncmulseq = seq;
501         for (size_t i=0; i<num; ++i) {
502                 ex e = seq[i].diff(s);
503                 e.swap(ncmulseq[i]);
504                 addseq.push_back(dynallocate<ncmul>(ncmulseq));
505                 e.swap(ncmulseq[i]);
506         }
507         return dynallocate<add>(addseq);
508 }
509
510 int ncmul::compare_same_type(const basic & other) const
511 {
512         return inherited::compare_same_type(other);
513 }
514
515 unsigned ncmul::return_type() const
516 {
517         if (seq.empty())
518                 return return_types::commutative;
519
520         bool all_commutative = true;
521         exvector::const_iterator noncommutative_element; // point to first found nc element
522
523         auto i = seq.begin(), end = seq.end();
524         while (i != end) {
525                 unsigned rt = i->return_type();
526                 if (rt == return_types::noncommutative_composite)
527                         return rt; // one ncc -> mul also ncc
528                 if ((rt == return_types::noncommutative) && (all_commutative)) {
529                         // first nc element found, remember position
530                         noncommutative_element = i;
531                         all_commutative = false;
532                 }
533                 if ((rt == return_types::noncommutative) && (!all_commutative)) {
534                         // another nc element found, compare type_infos
535                         if(noncommutative_element->return_type_tinfo() != i->return_type_tinfo())
536                                         return return_types::noncommutative_composite;
537                 }
538                 ++i;
539         }
540         // all factors checked
541         GINAC_ASSERT(!all_commutative); // not all factors should commutate, because this is a ncmul();
542         return all_commutative ? return_types::commutative : return_types::noncommutative;
543 }
544
545 return_type_t ncmul::return_type_tinfo() const
546 {
547         if (seq.empty())
548                 return make_return_type_t<ncmul>();
549
550         // return type_info of first noncommutative element
551         for (auto & i : seq)
552                 if (i.return_type() == return_types::noncommutative)
553                         return i.return_type_tinfo();
554
555         // no noncommutative element found, should not happen
556         return make_return_type_t<ncmul>();
557 }
558
559 //////////
560 // new virtual functions which can be overridden by derived classes
561 //////////
562
563 // none
564
565 //////////
566 // non-virtual functions in this class
567 //////////
568
569 exvector ncmul::expandchildren(unsigned options) const
570 {
571         auto cit = this->seq.begin(), end = this->seq.end();
572         while (cit != end) {
573                 const ex & expanded_ex = cit->expand(options);
574                 if (!are_ex_trivially_equal(*cit, expanded_ex)) {
575
576                         // copy first part of seq which hasn't changed
577                         exvector s(this->seq.begin(), cit);
578                         s.reserve(this->seq.size());
579
580                         // insert changed element
581                         s.push_back(expanded_ex);
582                         ++cit;
583
584                         // copy rest
585                         while (cit != end) {
586                                 s.push_back(cit->expand(options));
587                                 ++cit;
588                         }
589
590                         return s;
591                 }
592
593                 ++cit;
594         }
595
596         return exvector(); // nothing has changed
597 }
598
599 const exvector & ncmul::get_factors() const
600 {
601         return seq;
602 }
603
604 //////////
605 // friend functions
606 //////////
607
608 ex reeval_ncmul(const exvector & v)
609 {
610         return dynallocate<ncmul>(v);
611 }
612
613 ex hold_ncmul(const exvector & v)
614 {
615         if (v.empty())
616                 return _ex1;
617         else if (v.size() == 1)
618                 return v[0];
619         else
620                 return dynallocate<ncmul>(v).setflag(status_flags::evaluated);
621 }
622
// Bind ncmul to GiNaC's unarchiving machinery (macro from archive.h;
// presumably this lets archived ncmul objects be reconstructed -- see there).
GINAC_BIND_UNARCHIVER(ncmul);
624
625 } // namespace GiNaC