]> www.ginac.de Git - ginac.git/blob - ginac/expairseq.cpp
Remove dependence on deprecated std::auto_ptr<T>.
[ginac.git] / ginac / expairseq.cpp
1 /** @file expairseq.cpp
2  *
3  *  Implementation of sequences of expression pairs. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2015 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
21  */
22
23 #include "expairseq.h"
24 #include "lst.h"
25 #include "add.h"
26 #include "mul.h"
27 #include "power.h"
28 #include "relational.h"
29 #include "wildcard.h"
30 #include "archive.h"
31 #include "operators.h"
32 #include "utils.h"
33 #include "hash_seed.h"
34 #include "indexed.h"
35
#include <algorithm>
#include <iostream>
#include <iterator>
#include <memory>
#include <stdexcept>
#include <string>
41
42 namespace GiNaC {
43
44         
// Register expairseq as a GiNaC class derived from basic, with do_print()
// handling print_context output and do_print_tree() handling print_tree output.
GINAC_IMPLEMENT_REGISTERED_CLASS_OPT(expairseq, basic,
  print_func<print_context>(&expairseq::do_print).
  print_func<print_tree>(&expairseq::do_print_tree))
48
49
50 //////////
51 // helper classes
52 //////////
53
54 class epp_is_less
55 {
56 public:
57         bool operator()(const epp &lh, const epp &rh) const
58         {
59                 return (*lh).is_less(*rh);
60         }
61 };
62
63 //////////
64 // default constructor
65 //////////
66
67 // public
68
69 expairseq::expairseq() 
70 {}
71
72 // protected
73
74 //////////
75 // other constructors
76 //////////
77
78 expairseq::expairseq(const ex &lh, const ex &rh)
79 {
80         construct_from_2_ex(lh,rh);
81         GINAC_ASSERT(is_canonical());
82 }
83
84 expairseq::expairseq(const exvector &v)
85 {
86         construct_from_exvector(v);
87         GINAC_ASSERT(is_canonical());
88 }
89
90 expairseq::expairseq(const epvector &v, const ex &oc, bool do_index_renaming)
91   :  overall_coeff(oc)
92 {
93         GINAC_ASSERT(is_a<numeric>(oc));
94         construct_from_epvector(v, do_index_renaming);
95         GINAC_ASSERT(is_canonical());
96 }
97
98 expairseq::expairseq(epvector && vp, const ex &oc, bool do_index_renaming)
99   :  overall_coeff(oc)
100 {
101         GINAC_ASSERT(is_a<numeric>(oc));
102         construct_from_epvector(std::move(vp), do_index_renaming);
103         GINAC_ASSERT(is_canonical());
104 }
105
106 //////////
107 // archiving
108 //////////
109
110 void expairseq::read_archive(const archive_node &n, lst &sym_lst) 
111 {
112         inherited::read_archive(n, sym_lst);
113         archive_node::archive_node_cit first = n.find_first("rest");
114         archive_node::archive_node_cit last = n.find_last("coeff");
115         ++last;
116         seq.reserve((last-first)/2);
117
118         for (archive_node::archive_node_cit loc = first; loc < last;) {
119                 ex rest;
120                 ex coeff;
121                 n.find_ex_by_loc(loc++, rest, sym_lst);
122                 n.find_ex_by_loc(loc++, coeff, sym_lst);
123                 seq.push_back(expair(rest, coeff));
124         }
125
126         n.find_ex("overall_coeff", overall_coeff, sym_lst);
127
128         canonicalize();
129         GINAC_ASSERT(is_canonical());
130 }
131
132 void expairseq::archive(archive_node &n) const
133 {
134         inherited::archive(n);
135         epvector::const_iterator i = seq.begin(), iend = seq.end();
136         while (i != iend) {
137                 n.add_ex("rest", i->rest);
138                 n.add_ex("coeff", i->coeff);
139                 ++i;
140         }
141         n.add_ex("overall_coeff", overall_coeff);
142 }
143
144
145 //////////
146 // functions overriding virtual functions from base classes
147 //////////
148
149 // public
150
151 void expairseq::do_print(const print_context & c, unsigned level) const
152 {
153         c.s << "[[";
154         printseq(c, ',', precedence(), level);
155         c.s << "]]";
156 }
157
158 void expairseq::do_print_tree(const print_tree & c, unsigned level) const
159 {
160         c.s << std::string(level, ' ') << class_name() << " @" << this
161             << std::hex << ", hash=0x" << hashvalue << ", flags=0x" << flags << std::dec
162             << ", nops=" << nops()
163             << std::endl;
164         size_t num = seq.size();
165         for (size_t i=0; i<num; ++i) {
166                 seq[i].rest.print(c, level + c.delta_indent);
167                 seq[i].coeff.print(c, level + c.delta_indent);
168                 if (i != num - 1)
169                         c.s << std::string(level + c.delta_indent, ' ') << "-----" << std::endl;
170         }
171         if (!overall_coeff.is_equal(default_overall_coeff())) {
172                 c.s << std::string(level + c.delta_indent, ' ') << "-----" << std::endl
173                     << std::string(level + c.delta_indent, ' ') << "overall_coeff" << std::endl;
174                 overall_coeff.print(c, level + c.delta_indent);
175         }
176         c.s << std::string(level + c.delta_indent,' ') << "=====" << std::endl;
177 }
178
179 bool expairseq::info(unsigned inf) const
180 {
181         switch(inf) {
182                 case info_flags::expanded:
183                         return (flags & status_flags::expanded);
184                 case info_flags::has_indices: {
185                         if (flags & status_flags::has_indices)
186                                 return true;
187                         else if (flags & status_flags::has_no_indices)
188                                 return false;
189                         for (epvector::const_iterator i = seq.begin(); i != seq.end(); ++i) {
190                                 if (i->rest.info(info_flags::has_indices)) {
191                                         this->setflag(status_flags::has_indices);
192                                         this->clearflag(status_flags::has_no_indices);
193                                         return true;
194                                 }
195                         }
196                         this->clearflag(status_flags::has_indices);
197                         this->setflag(status_flags::has_no_indices);
198                         return false;
199                 }
200         }
201         return inherited::info(inf);
202 }
203
204 size_t expairseq::nops() const
205 {
206         if (overall_coeff.is_equal(default_overall_coeff()))
207                 return seq.size();
208         else
209                 return seq.size()+1;
210 }
211
212 ex expairseq::op(size_t i) const
213 {
214         if (i < seq.size())
215                 return recombine_pair_to_ex(seq[i]);
216         GINAC_ASSERT(!overall_coeff.is_equal(default_overall_coeff()));
217         return overall_coeff;
218 }
219
220 ex expairseq::map(map_function &f) const
221 {
222         epvector v;
223         v.reserve(seq.size()+1);
224
225         epvector::const_iterator cit = seq.begin(), last = seq.end();
226         while (cit != last) {
227                 v.push_back(split_ex_to_pair(f(recombine_pair_to_ex(*cit))));
228                 ++cit;
229         }
230
231         if (overall_coeff.is_equal(default_overall_coeff()))
232                 return thisexpairseq(std::move(v), default_overall_coeff(), true);
233         else {
234                 ex newcoeff = f(overall_coeff);
235                 if(is_a<numeric>(newcoeff))
236                         return thisexpairseq(std::move(v), newcoeff, true);
237                 else {
238                         v.push_back(split_ex_to_pair(newcoeff));
239                         return thisexpairseq(std::move(v), default_overall_coeff(), true);
240                 }
241         }
242 }
243
244 /** Perform coefficient-wise automatic term rewriting rules in this class. */
245 ex expairseq::eval(int level) const
246 {
247         if ((level==1) && (flags &status_flags::evaluated))
248                 return *this;
249
250         epvector evaled = evalchildren(level);
251         if (!evaled.empty())
252                 return (new expairseq(std::move(evaled), overall_coeff))->setflag(status_flags::dynallocated | status_flags::evaluated);
253         else
254                 return this->hold();
255 }
256
257 epvector* conjugateepvector(const epvector&epv)
258 {
259         epvector *newepv = 0;
260         for (epvector::const_iterator i=epv.begin(); i!=epv.end(); ++i) {
261                 if(newepv) {
262                         newepv->push_back(i->conjugate());
263                         continue;
264                 }
265                 expair x = i->conjugate();
266                 if (x.is_equal(*i)) {
267                         continue;
268                 }
269                 newepv = new epvector;
270                 newepv->reserve(epv.size());
271                 for (epvector::const_iterator j=epv.begin(); j!=i; ++j) {
272                         newepv->push_back(*j);
273                 }
274                 newepv->push_back(x);
275         }
276         return newepv;
277 }
278
279 ex expairseq::conjugate() const
280 {
281         epvector* newepv = conjugateepvector(seq);
282         ex x = overall_coeff.conjugate();
283         if (newepv) {
284                 ex result = thisexpairseq(std::move(*newepv), x);
285                 delete newepv;
286                 return result;
287         }
288         if (are_ex_trivially_equal(x, overall_coeff)) {
289                 return *this;
290         }
291         return thisexpairseq(seq, x);
292 }
293
294 bool expairseq::match(const ex & pattern, exmap & repl_lst) const
295 {
296         // This differs from basic::match() because we want "a+b+c+d" to
297         // match "d+*+b" with "*" being "a+c", and we want to honor commutativity
298
299         if (typeid(*this) == typeid(ex_to<basic>(pattern))) {
300
301                 // Check whether global wildcard (one that matches the "rest of the
302                 // expression", like "*" above) is present
303                 bool has_global_wildcard = false;
304                 ex global_wildcard;
305                 for (size_t i=0; i<pattern.nops(); i++) {
306                         if (is_exactly_a<wildcard>(pattern.op(i))) {
307                                 has_global_wildcard = true;
308                                 global_wildcard = pattern.op(i);
309                                 break;
310                         }
311                 }
312
313                 // Even if the expression does not match the pattern, some of
314                 // its subexpressions could match it. For example, x^5*y^(-1)
315                 // does not match the pattern $0^5, but its subexpression x^5
316                 // does. So, save repl_lst in order to not add bogus entries.
317                 exmap tmp_repl = repl_lst;
318
319                 // Unfortunately, this is an O(N^2) operation because we can't
320                 // sort the pattern in a useful way...
321
322                 // Chop into terms
323                 exvector ops;
324                 ops.reserve(nops());
325                 for (size_t i=0; i<nops(); i++)
326                         ops.push_back(op(i));
327
328                 // Now, for every term of the pattern, look for a matching term in
329                 // the expression and remove the match
330                 for (size_t i=0; i<pattern.nops(); i++) {
331                         ex p = pattern.op(i);
332                         if (has_global_wildcard && p.is_equal(global_wildcard))
333                                 continue;
334                         exvector::iterator it = ops.begin(), itend = ops.end();
335                         while (it != itend) {
336                                 if (it->match(p, tmp_repl)) {
337                                         ops.erase(it);
338                                         goto found;
339                                 }
340                                 ++it;
341                         }
342                         return false; // no match found
343 found:          ;
344                 }
345
346                 if (has_global_wildcard) {
347
348                         // Assign all the remaining terms to the global wildcard (unless
349                         // it has already been matched before, in which case the matches
350                         // must be equal)
351                         size_t num = ops.size();
352                         epvector vp;
353                         vp.reserve(num);
354                         for (size_t i=0; i<num; i++)
355                                 vp.push_back(split_ex_to_pair(ops[i]));
356                         ex rest = thisexpairseq(std::move(vp), default_overall_coeff());
357                         for (exmap::const_iterator it = tmp_repl.begin(); it != tmp_repl.end(); ++it) {
358                                 if (it->first.is_equal(global_wildcard)) {
359                                         if (rest.is_equal(it->second)) {
360                                                 repl_lst = tmp_repl;
361                                                 return true;
362                                         }
363                                         return false;
364                                 }
365                         }
366                         repl_lst = tmp_repl;
367                         repl_lst[global_wildcard] = rest;
368                         return true;
369
370                 } else {
371
372                         // No global wildcard, then the match fails if there are any
373                         // unmatched terms left
374                         if (ops.empty()) {
375                                 repl_lst = tmp_repl;
376                                 return true;
377                         }
378                         return false;
379                 }
380         }
381         return inherited::match(pattern, repl_lst);
382 }
383
384 ex expairseq::subs(const exmap & m, unsigned options) const
385 {
386         epvector subsed = subschildren(m, options);
387         if (!subsed.empty())
388                 return ex_to<basic>(thisexpairseq(std::move(subsed), overall_coeff, (options & subs_options::no_index_renaming) == 0));
389         else if ((options & subs_options::algebraic) && is_exactly_a<mul>(*this))
390                 return static_cast<const mul *>(this)->algebraic_subs_mul(m, options);
391         else
392                 return subs_one_level(m, options);
393 }
394
395 // protected
396
397 int expairseq::compare_same_type(const basic &other) const
398 {
399         GINAC_ASSERT(is_a<expairseq>(other));
400         const expairseq &o = static_cast<const expairseq &>(other);
401         
402         int cmpval;
403         
404         // compare number of elements
405         if (seq.size() != o.seq.size())
406                 return (seq.size()<o.seq.size()) ? -1 : 1;
407         
408         // compare overall_coeff
409         cmpval = overall_coeff.compare(o.overall_coeff);
410         if (cmpval!=0)
411                 return cmpval;
412         
413         epvector::const_iterator cit1 = seq.begin();
414         epvector::const_iterator cit2 = o.seq.begin();
415         epvector::const_iterator last1 = seq.end();
416         epvector::const_iterator last2 = o.seq.end();
417                 
418         for (; (cit1!=last1)&&(cit2!=last2); ++cit1, ++cit2) {
419                 cmpval = (*cit1).compare(*cit2);
420                 if (cmpval!=0) return cmpval;
421         }
422                 
423         GINAC_ASSERT(cit1==last1);
424         GINAC_ASSERT(cit2==last2);
425                 
426         return 0;
427 }
428
429 bool expairseq::is_equal_same_type(const basic &other) const
430 {
431         const expairseq &o = static_cast<const expairseq &>(other);
432         
433         // compare number of elements
434         if (seq.size()!=o.seq.size())
435                 return false;
436         
437         // compare overall_coeff
438         if (!overall_coeff.is_equal(o.overall_coeff))
439                 return false;
440         
441         epvector::const_iterator cit1 = seq.begin();
442         epvector::const_iterator cit2 = o.seq.begin();
443         epvector::const_iterator last1 = seq.end();
444                 
445         while (cit1!=last1) {
446                 if (!(*cit1).is_equal(*cit2)) return false;
447                 ++cit1;
448                 ++cit2;
449         }
450
451         return true;
452 }
453
454 unsigned expairseq::return_type() const
455 {
456         return return_types::noncommutative_composite;
457 }
458
459 unsigned expairseq::calchash() const
460 {
461         unsigned v = make_hash_seed(typeid(*this));
462         epvector::const_iterator i = seq.begin();
463         const epvector::const_iterator end = seq.end();
464         while (i != end) {
465                 v ^= i->rest.gethash();
466                 v = rotate_left(v);
467                 v ^= i->coeff.gethash();
468                 ++i;
469         }
470
471         v ^= overall_coeff.gethash();
472
473         // store calculated hash value only if object is already evaluated
474         if (flags &status_flags::evaluated) {
475                 setflag(status_flags::hash_calculated);
476                 hashvalue = v;
477         }
478         
479         return v;
480 }
481
482 ex expairseq::expand(unsigned options) const
483 {
484         epvector expanded = expandchildren(options);
485         if (!expanded.empty()) {
486                 return thisexpairseq(std::move(expanded), overall_coeff);
487         }
488         return (options == 0) ? setflag(status_flags::expanded) : *this;
489 }
490
491 //////////
492 // new virtual functions which can be overridden by derived classes
493 //////////
494
495 // protected
496
497 /** Create an object of this type.
498  *  This method works similar to a constructor.  It is useful because expairseq
499  *  has (at least) two possible different semantics but we want to inherit
500  *  methods thus avoiding code duplication.  Sometimes a method in expairseq
501  *  has to create a new one of the same semantics, which cannot be done by a
502  *  ctor because the name (add, mul,...) is unknown on the expairseq level.  In
503  *  order for this trick to work a derived class must of course override this
504  *  definition. */
505 ex expairseq::thisexpairseq(const epvector &v, const ex &oc, bool do_index_renaming) const
506 {
507         return expairseq(v, oc, do_index_renaming);
508 }
509
510 ex expairseq::thisexpairseq(epvector && vp, const ex &oc, bool do_index_renaming) const
511 {
512         return expairseq(std::move(vp), oc, do_index_renaming);
513 }
514
515 void expairseq::printpair(const print_context & c, const expair & p, unsigned upper_precedence) const
516 {
517         c.s << "[[";
518         p.rest.print(c, precedence());
519         c.s << ",";
520         p.coeff.print(c, precedence());
521         c.s << "]]";
522 }
523
524 void expairseq::printseq(const print_context & c, char delim,
525                          unsigned this_precedence,
526                          unsigned upper_precedence) const
527 {
528         if (this_precedence <= upper_precedence)
529                 c.s << "(";
530         epvector::const_iterator it, it_last = seq.end() - 1;
531         for (it=seq.begin(); it!=it_last; ++it) {
532                 printpair(c, *it, this_precedence);
533                 c.s << delim;
534         }
535         printpair(c, *it, this_precedence);
536         if (!overall_coeff.is_equal(default_overall_coeff())) {
537                 c.s << delim;
538                 overall_coeff.print(c, this_precedence);
539         }
540         
541         if (this_precedence <= upper_precedence)
542                 c.s << ")";
543 }
544
545
546 /** Form an expair from an ex, using the corresponding semantics.
547  *  @see expairseq::recombine_pair_to_ex() */
548 expair expairseq::split_ex_to_pair(const ex &e) const
549 {
550         return expair(e,_ex1);
551 }
552
553
554 expair expairseq::combine_ex_with_coeff_to_pair(const ex &e,
555                                                 const ex &c) const
556 {
557         GINAC_ASSERT(is_exactly_a<numeric>(c));
558         
559         return expair(e,c);
560 }
561
562
563 expair expairseq::combine_pair_with_coeff_to_pair(const expair &p,
564                                                   const ex &c) const
565 {
566         GINAC_ASSERT(is_exactly_a<numeric>(p.coeff));
567         GINAC_ASSERT(is_exactly_a<numeric>(c));
568         
569         return expair(p.rest,ex_to<numeric>(p.coeff).mul_dyn(ex_to<numeric>(c)));
570 }
571
572
573 /** Form an ex out of an expair, using the corresponding semantics.
574  *  @see expairseq::split_ex_to_pair() */
575 ex expairseq::recombine_pair_to_ex(const expair &p) const
576 {
577         return lst(p.rest,p.coeff);
578 }
579
580 bool expairseq::expair_needs_further_processing(epp it)
581 {
582         return false;
583 }
584
585 ex expairseq::default_overall_coeff() const
586 {
587         return _ex0;
588 }
589
590 void expairseq::combine_overall_coeff(const ex &c)
591 {
592         GINAC_ASSERT(is_exactly_a<numeric>(overall_coeff));
593         GINAC_ASSERT(is_exactly_a<numeric>(c));
594         overall_coeff = ex_to<numeric>(overall_coeff).add_dyn(ex_to<numeric>(c));
595 }
596
597 void expairseq::combine_overall_coeff(const ex &c1, const ex &c2)
598 {
599         GINAC_ASSERT(is_exactly_a<numeric>(overall_coeff));
600         GINAC_ASSERT(is_exactly_a<numeric>(c1));
601         GINAC_ASSERT(is_exactly_a<numeric>(c2));
602         overall_coeff = ex_to<numeric>(overall_coeff).
603                         add_dyn(ex_to<numeric>(c1).mul(ex_to<numeric>(c2)));
604 }
605
606 bool expairseq::can_make_flat(const expair &p) const
607 {
608         return true;
609 }
610
611
612 //////////
613 // non-virtual functions in this class
614 //////////
615
616 void expairseq::construct_from_2_ex_via_exvector(const ex &lh, const ex &rh)
617 {
618         exvector v;
619         v.reserve(2);
620         v.push_back(lh);
621         v.push_back(rh);
622         construct_from_exvector(v);
623 }
624
625 void expairseq::construct_from_2_ex(const ex &lh, const ex &rh)
626 {
627         if (typeid(ex_to<basic>(lh)) == typeid(*this)) {
628                 if (typeid(ex_to<basic>(rh)) == typeid(*this)) {
629                         if (is_a<mul>(lh) && lh.info(info_flags::has_indices) && 
630                                 rh.info(info_flags::has_indices)) {
631                                 ex newrh=rename_dummy_indices_uniquely(lh, rh);
632                                 construct_from_2_expairseq(ex_to<expairseq>(lh),
633                                                            ex_to<expairseq>(newrh));
634                         }
635                         else
636                                 construct_from_2_expairseq(ex_to<expairseq>(lh),
637                                                            ex_to<expairseq>(rh));
638                         return;
639                 } else {
640                         construct_from_expairseq_ex(ex_to<expairseq>(lh), rh);
641                         return;
642                 }
643         } else if (typeid(ex_to<basic>(rh)) == typeid(*this)) {
644                 construct_from_expairseq_ex(ex_to<expairseq>(rh),lh);
645                 return;
646         }
647         
648         if (is_exactly_a<numeric>(lh)) {
649                 if (is_exactly_a<numeric>(rh)) {
650                         combine_overall_coeff(lh);
651                         combine_overall_coeff(rh);
652                 } else {
653                         combine_overall_coeff(lh);
654                         seq.push_back(split_ex_to_pair(rh));
655                 }
656         } else {
657                 if (is_exactly_a<numeric>(rh)) {
658                         combine_overall_coeff(rh);
659                         seq.push_back(split_ex_to_pair(lh));
660                 } else {
661                         expair p1 = split_ex_to_pair(lh);
662                         expair p2 = split_ex_to_pair(rh);
663                         
664                         int cmpval = p1.rest.compare(p2.rest);
665                         if (cmpval==0) {
666                                 p1.coeff = ex_to<numeric>(p1.coeff).add_dyn(ex_to<numeric>(p2.coeff));
667                                 if (!ex_to<numeric>(p1.coeff).is_zero()) {
668                                         // no further processing is necessary, since this
669                                         // one element will usually be recombined in eval()
670                                         seq.push_back(p1);
671                                 }
672                         } else {
673                                 seq.reserve(2);
674                                 if (cmpval<0) {
675                                         seq.push_back(p1);
676                                         seq.push_back(p2);
677                                 } else {
678                                         seq.push_back(p2);
679                                         seq.push_back(p1);
680                                 }
681                         }
682                 }
683         }
684 }
685
686 void expairseq::construct_from_2_expairseq(const expairseq &s1,
687                                            const expairseq &s2)
688 {
689         combine_overall_coeff(s1.overall_coeff);
690         combine_overall_coeff(s2.overall_coeff);
691
692         epvector::const_iterator first1 = s1.seq.begin();
693         epvector::const_iterator last1 = s1.seq.end();
694         epvector::const_iterator first2 = s2.seq.begin();
695         epvector::const_iterator last2 = s2.seq.end();
696
697         seq.reserve(s1.seq.size()+s2.seq.size());
698
699         bool needs_further_processing=false;
700         
701         while (first1!=last1 && first2!=last2) {
702                 int cmpval = (*first1).rest.compare((*first2).rest);
703
704                 if (cmpval==0) {
705                         // combine terms
706                         const numeric &newcoeff = ex_to<numeric>(first1->coeff).
707                                                    add(ex_to<numeric>(first2->coeff));
708                         if (!newcoeff.is_zero()) {
709                                 seq.push_back(expair(first1->rest,newcoeff));
710                                 if (expair_needs_further_processing(seq.end()-1)) {
711                                         needs_further_processing = true;
712                                 }
713                         }
714                         ++first1;
715                         ++first2;
716                 } else if (cmpval<0) {
717                         seq.push_back(*first1);
718                         ++first1;
719                 } else {
720                         seq.push_back(*first2);
721                         ++first2;
722                 }
723         }
724         
725         while (first1!=last1) {
726                 seq.push_back(*first1);
727                 ++first1;
728         }
729         while (first2!=last2) {
730                 seq.push_back(*first2);
731                 ++first2;
732         }
733         
734         if (needs_further_processing) {
735                 epvector v = seq;
736                 seq.clear();
737                 construct_from_epvector(std::move(v));
738         }
739 }
740
741 void expairseq::construct_from_expairseq_ex(const expairseq &s,
742                                             const ex &e)
743 {
744         combine_overall_coeff(s.overall_coeff);
745         if (is_exactly_a<numeric>(e)) {
746                 combine_overall_coeff(e);
747                 seq = s.seq;
748                 return;
749         }
750         
751         epvector::const_iterator first = s.seq.begin();
752         epvector::const_iterator last = s.seq.end();
753         expair p = split_ex_to_pair(e);
754         
755         seq.reserve(s.seq.size()+1);
756         bool p_pushed = false;
757         
758         bool needs_further_processing=false;
759         
760         // merge p into s.seq
761         while (first!=last) {
762                 int cmpval = (*first).rest.compare(p.rest);
763                 if (cmpval==0) {
764                         // combine terms
765                         const numeric &newcoeff = ex_to<numeric>(first->coeff).
766                                                    add(ex_to<numeric>(p.coeff));
767                         if (!newcoeff.is_zero()) {
768                                 seq.push_back(expair(first->rest,newcoeff));
769                                 if (expair_needs_further_processing(seq.end()-1))
770                                         needs_further_processing = true;
771                         }
772                         ++first;
773                         p_pushed = true;
774                         break;
775                 } else if (cmpval<0) {
776                         seq.push_back(*first);
777                         ++first;
778                 } else {
779                         seq.push_back(p);
780                         p_pushed = true;
781                         break;
782                 }
783         }
784         
785         if (p_pushed) {
786                 // while loop exited because p was pushed, now push rest of s.seq
787                 while (first!=last) {
788                         seq.push_back(*first);
789                         ++first;
790                 }
791         } else {
792                 // while loop exited because s.seq was pushed, now push p
793                 seq.push_back(p);
794         }
795
796         if (needs_further_processing) {
797                 epvector v = seq;
798                 seq.clear();
799                 construct_from_epvector(std::move(v));
800         }
801 }
802
803 void expairseq::construct_from_exvector(const exvector &v)
804 {
805         // simplifications: +(a,+(b,c),d) -> +(a,b,c,d) (associativity)
806         //                  +(d,b,c,a) -> +(a,b,c,d) (canonicalization)
807         //                  +(...,x,*(x,c1),*(x,c2)) -> +(...,*(x,1+c1+c2)) (c1, c2 numeric)
808         //                  (same for (+,*) -> (*,^)
809
810         make_flat(v);
811         canonicalize();
812         combine_same_terms_sorted_seq();
813 }
814
815 void expairseq::construct_from_epvector(const epvector &v, bool do_index_renaming)
816 {
817         // simplifications: +(a,+(b,c),d) -> +(a,b,c,d) (associativity)
818         //                  +(d,b,c,a) -> +(a,b,c,d) (canonicalization)
819         //                  +(...,x,*(x,c1),*(x,c2)) -> +(...,*(x,1+c1+c2)) (c1, c2 numeric)
820         //                  same for (+,*) -> (*,^)
821
822         make_flat(v, do_index_renaming);
823         canonicalize();
824         combine_same_terms_sorted_seq();
825 }
826
827 void expairseq::construct_from_epvector(epvector &&v, bool do_index_renaming)
828 {
829         // simplifications: +(a,+(b,c),d) -> +(a,b,c,d) (associativity)
830         //                  +(d,b,c,a) -> +(a,b,c,d) (canonicalization)
831         //                  +(...,x,*(x,c1),*(x,c2)) -> +(...,*(x,1+c1+c2)) (c1, c2 numeric)
832         //                  same for (+,*) -> (*,^)
833
834         make_flat(std::move(v), do_index_renaming);
835         canonicalize();
836         combine_same_terms_sorted_seq();
837 }
838
839 /** Combine this expairseq with argument exvector.
840  *  It cares for associativity as well as for special handling of numerics. */
841 void expairseq::make_flat(const exvector &v)
842 {
843         exvector::const_iterator cit;
844         
845         // count number of operands which are of same expairseq derived type
846         // and their cumulative number of operands
847         int nexpairseqs = 0;
848         int noperands = 0;
849         bool do_idx_rename = false;
850         
851         cit = v.begin();
852         while (cit!=v.end()) {
853                 if (typeid(ex_to<basic>(*cit)) == typeid(*this)) {
854                         ++nexpairseqs;
855                         noperands += ex_to<expairseq>(*cit).seq.size();
856                 }
857                 if (is_a<mul>(*this) && (!do_idx_rename) &&
858                                 cit->info(info_flags::has_indices))
859                         do_idx_rename = true;
860                 ++cit;
861         }
862         
863         // reserve seq and coeffseq which will hold all operands
864         seq.reserve(v.size()+noperands-nexpairseqs);
865         
866         // copy elements and split off numerical part
867         make_flat_inserter mf(v, do_idx_rename);
868         cit = v.begin();
869         while (cit!=v.end()) {
870                 if (typeid(ex_to<basic>(*cit)) == typeid(*this)) {
871                         ex newfactor = mf.handle_factor(*cit, _ex1);
872                         const expairseq &subseqref = ex_to<expairseq>(newfactor);
873                         combine_overall_coeff(subseqref.overall_coeff);
874                         epvector::const_iterator cit_s = subseqref.seq.begin();
875                         while (cit_s!=subseqref.seq.end()) {
876                                 seq.push_back(*cit_s);
877                                 ++cit_s;
878                         }
879                 } else {
880                         if (is_exactly_a<numeric>(*cit))
881                                 combine_overall_coeff(*cit);
882                         else {
883                                 ex newfactor = mf.handle_factor(*cit, _ex1);
884                                 seq.push_back(split_ex_to_pair(newfactor));
885                         }
886                 }
887                 ++cit;
888         }
889 }
890
891 /** Combine this expairseq with argument epvector.
892  *  It cares for associativity as well as for special handling of numerics. */
893 void expairseq::make_flat(const epvector &v, bool do_index_renaming)
894 {
895         epvector::const_iterator cit;
896         
897         // count number of operands which are of same expairseq derived type
898         // and their cumulative number of operands
899         int nexpairseqs = 0;
900         int noperands = 0;
901         bool really_need_rename_inds = false;
902         
903         cit = v.begin();
904         while (cit!=v.end()) {
905                 if (typeid(ex_to<basic>(cit->rest)) == typeid(*this)) {
906                         ++nexpairseqs;
907                         noperands += ex_to<expairseq>(cit->rest).seq.size();
908                 }
909                 if ((!really_need_rename_inds) && is_a<mul>(*this) &&
910                                 cit->rest.info(info_flags::has_indices))
911                         really_need_rename_inds = true;
912                 ++cit;
913         }
914         do_index_renaming = do_index_renaming && really_need_rename_inds;
915         
916         // reserve seq and coeffseq which will hold all operands
917         seq.reserve(v.size()+noperands-nexpairseqs);
918         make_flat_inserter mf(v, do_index_renaming);
919         
920         // copy elements and split off numerical part
921         cit = v.begin();
922         while (cit!=v.end()) {
923                 if ((typeid(ex_to<basic>(cit->rest)) == typeid(*this)) &&
924                     this->can_make_flat(*cit)) {
925                         ex newrest = mf.handle_factor(cit->rest, cit->coeff);
926                         const expairseq &subseqref = ex_to<expairseq>(newrest);
927                         combine_overall_coeff(ex_to<numeric>(subseqref.overall_coeff),
928                                               ex_to<numeric>(cit->coeff));
929                         epvector::const_iterator cit_s = subseqref.seq.begin();
930                         while (cit_s!=subseqref.seq.end()) {
931                                 seq.push_back(expair(cit_s->rest,
932                                                      ex_to<numeric>(cit_s->coeff).mul_dyn(ex_to<numeric>(cit->coeff))));
933                                 ++cit_s;
934                         }
935                 } else {
936                         if (cit->is_canonical_numeric())
937                                 combine_overall_coeff(mf.handle_factor(cit->rest, _ex1));
938                         else {
939                                 ex rest = cit->rest;
940                                 ex newrest = mf.handle_factor(rest, cit->coeff);
941                                 if (are_ex_trivially_equal(newrest, rest))
942                                         seq.push_back(*cit);
943                                 else
944                                         seq.push_back(expair(newrest, cit->coeff));
945                         }
946                 }
947                 ++cit;
948         }
949 }
950
951 /** Brings this expairseq into a sorted (canonical) form. */
952 void expairseq::canonicalize()
953 {
954         std::sort(seq.begin(), seq.end(), expair_rest_is_less());
955 }
956
957
958 /** Compact a presorted expairseq by combining all matching expairs to one
959  *  each.  On an add object, this is responsible for 2*x+3*x+y -> 5*x+y, for
960  *  instance. */
961 void expairseq::combine_same_terms_sorted_seq()
962 {
963         if (seq.size()<2)
964                 return;
965
966         bool needs_further_processing = false;
967
968         epvector::iterator itin1 = seq.begin();
969         epvector::iterator itin2 = itin1+1;
970         epvector::iterator itout = itin1;
971         epvector::iterator last = seq.end();
972         // must_copy will be set to true the first time some combination is 
973         // possible from then on the sequence has changed and must be compacted
974         bool must_copy = false;
975         while (itin2!=last) {
976                 if (itin1->rest.compare(itin2->rest)==0) {
977                         itin1->coeff = ex_to<numeric>(itin1->coeff).
978                                        add_dyn(ex_to<numeric>(itin2->coeff));
979                         if (expair_needs_further_processing(itin1))
980                                 needs_further_processing = true;
981                         must_copy = true;
982                 } else {
983                         if (!ex_to<numeric>(itin1->coeff).is_zero()) {
984                                 if (must_copy)
985                                         *itout = *itin1;
986                                 ++itout;
987                         }
988                         itin1 = itin2;
989                 }
990                 ++itin2;
991         }
992         if (!ex_to<numeric>(itin1->coeff).is_zero()) {
993                 if (must_copy)
994                         *itout = *itin1;
995                 ++itout;
996         }
997         if (itout!=last)
998                 seq.erase(itout,last);
999
1000         if (needs_further_processing) {
1001                 epvector v = seq;
1002                 seq.clear();
1003                 construct_from_epvector(std::move(v));
1004         }
1005 }
1006
1007 /** Check if this expairseq is in sorted (canonical) form.  Useful mainly for
1008  *  debugging or in assertions since being sorted is an invariance. */
1009 bool expairseq::is_canonical() const
1010 {
1011         if (seq.size() <= 1)
1012                 return 1;
1013         
1014         epvector::const_iterator it = seq.begin(), itend = seq.end();
1015         epvector::const_iterator it_last = it;
1016         for (++it; it!=itend; it_last=it, ++it) {
1017                 if (!(it_last->is_less(*it) || it_last->is_equal(*it))) {
1018                         if (!is_exactly_a<numeric>(it_last->rest) ||
1019                                 !is_exactly_a<numeric>(it->rest)) {
1020                                 // double test makes it easier to set a breakpoint...
1021                                 if (!is_exactly_a<numeric>(it_last->rest) ||
1022                                         !is_exactly_a<numeric>(it->rest)) {
1023                                         printpair(std::clog, *it_last, 0);
1024                                         std::clog << ">";
1025                                         printpair(std::clog, *it, 0);
1026                                         std::clog << "\n";
1027                                         std::clog << "pair1:" << std::endl;
1028                                         it_last->rest.print(print_tree(std::clog));
1029                                         it_last->coeff.print(print_tree(std::clog));
1030                                         std::clog << "pair2:" << std::endl;
1031                                         it->rest.print(print_tree(std::clog));
1032                                         it->coeff.print(print_tree(std::clog));
1033                                         return 0;
1034                                 }
1035                         }
1036                 }
1037         }
1038         return 1;
1039 }
1040
1041 /** Member-wise expand the expairs in this sequence.
1042  *
1043  *  @see expairseq::expand()
1044  *  @return epvector containing expanded pairs, empty if no members
1045  *    had to be changed. */
1046 epvector expairseq::expandchildren(unsigned options) const
1047 {
1048         const epvector::const_iterator last = seq.end();
1049         epvector::const_iterator cit = seq.begin();
1050         while (cit!=last) {
1051                 const ex &expanded_ex = cit->rest.expand(options);
1052                 if (!are_ex_trivially_equal(cit->rest,expanded_ex)) {
1053                         
1054                         // something changed, copy seq, eval and return it
1055                         epvector s;
1056                         s.reserve(seq.size());
1057                         
1058                         // copy parts of seq which are known not to have changed
1059                         epvector::const_iterator cit2 = seq.begin();
1060                         while (cit2!=cit) {
1061                                 s.push_back(*cit2);
1062                                 ++cit2;
1063                         }
1064
1065                         // copy first changed element
1066                         s.push_back(combine_ex_with_coeff_to_pair(expanded_ex,
1067                                                                   cit2->coeff));
1068                         ++cit2;
1069
1070                         // copy rest
1071                         while (cit2!=last) {
1072                                 s.push_back(combine_ex_with_coeff_to_pair(cit2->rest.expand(options),
1073                                                                           cit2->coeff));
1074                                 ++cit2;
1075                         }
1076                         return s;
1077                 }
1078                 ++cit;
1079         }
1080         
1081         return epvector(); // empty signalling nothing has changed
1082 }
1083
1084
1085 /** Member-wise evaluate the expairs in this sequence.
1086  *
1087  *  @see expairseq::eval()
1088  *  @return epvector containing evaluated pairs, empty if no members
1089  *    had to be changed. */
1090 epvector expairseq::evalchildren(int level) const
1091 {
1092         if (level==1)
1093                 return epvector();  // nothing had to be evaluated
1094         
1095         if (level == -max_recursion_level)
1096                 throw(std::runtime_error("max recursion level reached"));
1097         
1098         --level;
1099         epvector::const_iterator last = seq.end();
1100         epvector::const_iterator cit = seq.begin();
1101         while (cit!=last) {
1102                 const ex evaled_ex = cit->rest.eval(level);
1103                 if (!are_ex_trivially_equal(cit->rest,evaled_ex)) {
1104                         
1105                         // something changed, copy seq, eval and return it
1106                         epvector s;
1107                         s.reserve(seq.size());
1108                         
1109                         // copy parts of seq which are known not to have changed
1110                         epvector::const_iterator cit2=seq.begin();
1111                         while (cit2!=cit) {
1112                                 s.push_back(*cit2);
1113                                 ++cit2;
1114                         }
1115
1116                         // copy first changed element
1117                         s.push_back(combine_ex_with_coeff_to_pair(evaled_ex,
1118                                                                   cit2->coeff));
1119                         ++cit2;
1120
1121                         // copy rest
1122                         while (cit2!=last) {
1123                                 s.push_back(combine_ex_with_coeff_to_pair(cit2->rest.eval(level),
1124                                                                           cit2->coeff));
1125                                 ++cit2;
1126                         }
1127                         return std::move(s);
1128                 }
1129                 ++cit;
1130         }
1131
1132         return epvector(); // signalling nothing has changed
1133 }
1134
1135 /** Member-wise substitute in this sequence.
1136  *
1137  *  @see expairseq::subs()
1138  *  @return epvector containing expanded pairs, empty if no members
1139  *    had to be changed. */
1140 epvector expairseq::subschildren(const exmap & m, unsigned options) const
1141 {
1142         // When any of the objects to be substituted is a product or power
1143         // we have to recombine the pairs because the numeric coefficients may
1144         // be part of the search pattern.
1145         if (!(options & (subs_options::pattern_is_product | subs_options::pattern_is_not_product))) {
1146
1147                 // Search the list of substitutions and cache our findings
1148                 for (exmap::const_iterator it = m.begin(); it != m.end(); ++it) {
1149                         if (is_exactly_a<mul>(it->first) || is_exactly_a<power>(it->first)) {
1150                                 options |= subs_options::pattern_is_product;
1151                                 break;
1152                         }
1153                 }
1154                 if (!(options & subs_options::pattern_is_product))
1155                         options |= subs_options::pattern_is_not_product;
1156         }
1157
1158         if (options & subs_options::pattern_is_product) {
1159
1160                 // Substitute in the recombined pairs
1161                 epvector::const_iterator cit = seq.begin(), last = seq.end();
1162                 while (cit != last) {
1163
1164                         const ex &orig_ex = recombine_pair_to_ex(*cit);
1165                         const ex &subsed_ex = orig_ex.subs(m, options);
1166                         if (!are_ex_trivially_equal(orig_ex, subsed_ex)) {
1167
1168                                 // Something changed, copy seq, subs and return it
1169                                 epvector s;
1170                                 s.reserve(seq.size());
1171
1172                                 // Copy parts of seq which are known not to have changed
1173                                 s.insert(s.begin(), seq.begin(), cit);
1174
1175                                 // Copy first changed element
1176                                 s.push_back(split_ex_to_pair(subsed_ex));
1177                                 ++cit;
1178
1179                                 // Copy rest
1180                                 while (cit != last) {
1181                                         s.push_back(split_ex_to_pair(recombine_pair_to_ex(*cit).subs(m, options)));
1182                                         ++cit;
1183                                 }
1184                                 return s;
1185                         }
1186
1187                         ++cit;
1188                 }
1189
1190         } else {
1191
1192                 // Substitute only in the "rest" part of the pairs
1193                 epvector::const_iterator cit = seq.begin(), last = seq.end();
1194                 while (cit != last) {
1195
1196                         const ex &subsed_ex = cit->rest.subs(m, options);
1197                         if (!are_ex_trivially_equal(cit->rest, subsed_ex)) {
1198                         
1199                                 // Something changed, copy seq, subs and return it
1200                                 epvector s;
1201                                 s.reserve(seq.size());
1202
1203                                 // Copy parts of seq which are known not to have changed
1204                                 s.insert(s.begin(), seq.begin(), cit);
1205                         
1206                                 // Copy first changed element
1207                                 s.push_back(combine_ex_with_coeff_to_pair(subsed_ex, cit->coeff));
1208                                 ++cit;
1209
1210                                 // Copy rest
1211                                 while (cit != last) {
1212                                         s.push_back(combine_ex_with_coeff_to_pair(cit->rest.subs(m, options), cit->coeff));
1213                                         ++cit;
1214                                 }
1215                                 return s;
1216                         }
1217
1218                         ++cit;
1219                 }
1220         }
1221         
1222         // Nothing has changed
1223         return epvector();
1224 }
1225
1226 //////////
1227 // static member variables
1228 //////////
1229
1230 } // namespace GiNaC