18aba0dbdcb8900587ce5cc746efebf5b064db80
[ginac.git] / ginac / expairseq.cpp
1 /** @file expairseq.cpp
2  *
3  *  Implementation of sequences of expression pairs. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2015 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
21  */
22
23 #include "expairseq.h"
24 #include "lst.h"
25 #include "add.h"
26 #include "mul.h"
27 #include "power.h"
28 #include "relational.h"
29 #include "wildcard.h"
30 #include "archive.h"
31 #include "operators.h"
32 #include "utils.h"
33 #include "hash_seed.h"
34 #include "indexed.h"
35
36 #include <algorithm>
37 #include <iostream>
38 #include <iterator>
39 #include <stdexcept>
40 #include <string>
41
42 namespace GiNaC {
43
44         
45 GINAC_IMPLEMENT_REGISTERED_CLASS_OPT(expairseq, basic,
46   print_func<print_context>(&expairseq::do_print).
47   print_func<print_tree>(&expairseq::do_print_tree))
48
49
50 //////////
51 // helper classes
52 //////////
53
54 class epp_is_less
55 {
56 public:
57         bool operator()(const epp &lh, const epp &rh) const
58         {
59                 return (*lh).is_less(*rh);
60         }
61 };
62
63 //////////
64 // default constructor
65 //////////
66
67 // public
68
69 expairseq::expairseq() 
70 {}
71
72 // protected
73
74 //////////
75 // other constructors
76 //////////
77
78 expairseq::expairseq(const ex &lh, const ex &rh)
79 {
80         construct_from_2_ex(lh,rh);
81         GINAC_ASSERT(is_canonical());
82 }
83
84 expairseq::expairseq(const exvector &v)
85 {
86         construct_from_exvector(v);
87         GINAC_ASSERT(is_canonical());
88 }
89
90 expairseq::expairseq(const epvector &v, const ex &oc, bool do_index_renaming)
91   :  overall_coeff(oc)
92 {
93         GINAC_ASSERT(is_a<numeric>(oc));
94         construct_from_epvector(v, do_index_renaming);
95         GINAC_ASSERT(is_canonical());
96 }
97
98 expairseq::expairseq(std::auto_ptr<epvector> vp, const ex &oc, bool do_index_renaming)
99   :  overall_coeff(oc)
100 {
101         GINAC_ASSERT(vp.get()!=0);
102         GINAC_ASSERT(is_a<numeric>(oc));
103         construct_from_epvector(*vp, do_index_renaming);
104         GINAC_ASSERT(is_canonical());
105 }
106
107 //////////
108 // archiving
109 //////////
110
111 void expairseq::read_archive(const archive_node &n, lst &sym_lst) 
112 {
113         inherited::read_archive(n, sym_lst);
114         archive_node::archive_node_cit first = n.find_first("rest");
115         archive_node::archive_node_cit last = n.find_last("coeff");
116         ++last;
117         seq.reserve((last-first)/2);
118
119         for (archive_node::archive_node_cit loc = first; loc < last;) {
120                 ex rest;
121                 ex coeff;
122                 n.find_ex_by_loc(loc++, rest, sym_lst);
123                 n.find_ex_by_loc(loc++, coeff, sym_lst);
124                 seq.push_back(expair(rest, coeff));
125         }
126
127         n.find_ex("overall_coeff", overall_coeff, sym_lst);
128
129         canonicalize();
130         GINAC_ASSERT(is_canonical());
131 }
132
133 void expairseq::archive(archive_node &n) const
134 {
135         inherited::archive(n);
136         epvector::const_iterator i = seq.begin(), iend = seq.end();
137         while (i != iend) {
138                 n.add_ex("rest", i->rest);
139                 n.add_ex("coeff", i->coeff);
140                 ++i;
141         }
142         n.add_ex("overall_coeff", overall_coeff);
143 }
144
145
146 //////////
147 // functions overriding virtual functions from base classes
148 //////////
149
150 // public
151
152 void expairseq::do_print(const print_context & c, unsigned level) const
153 {
154         c.s << "[[";
155         printseq(c, ',', precedence(), level);
156         c.s << "]]";
157 }
158
159 void expairseq::do_print_tree(const print_tree & c, unsigned level) const
160 {
161         c.s << std::string(level, ' ') << class_name() << " @" << this
162             << std::hex << ", hash=0x" << hashvalue << ", flags=0x" << flags << std::dec
163             << ", nops=" << nops()
164             << std::endl;
165         size_t num = seq.size();
166         for (size_t i=0; i<num; ++i) {
167                 seq[i].rest.print(c, level + c.delta_indent);
168                 seq[i].coeff.print(c, level + c.delta_indent);
169                 if (i != num - 1)
170                         c.s << std::string(level + c.delta_indent, ' ') << "-----" << std::endl;
171         }
172         if (!overall_coeff.is_equal(default_overall_coeff())) {
173                 c.s << std::string(level + c.delta_indent, ' ') << "-----" << std::endl
174                     << std::string(level + c.delta_indent, ' ') << "overall_coeff" << std::endl;
175                 overall_coeff.print(c, level + c.delta_indent);
176         }
177         c.s << std::string(level + c.delta_indent,' ') << "=====" << std::endl;
178 }
179
180 bool expairseq::info(unsigned inf) const
181 {
182         switch(inf) {
183                 case info_flags::expanded:
184                         return (flags & status_flags::expanded);
185                 case info_flags::has_indices: {
186                         if (flags & status_flags::has_indices)
187                                 return true;
188                         else if (flags & status_flags::has_no_indices)
189                                 return false;
190                         for (epvector::const_iterator i = seq.begin(); i != seq.end(); ++i) {
191                                 if (i->rest.info(info_flags::has_indices)) {
192                                         this->setflag(status_flags::has_indices);
193                                         this->clearflag(status_flags::has_no_indices);
194                                         return true;
195                                 }
196                         }
197                         this->clearflag(status_flags::has_indices);
198                         this->setflag(status_flags::has_no_indices);
199                         return false;
200                 }
201         }
202         return inherited::info(inf);
203 }
204
205 size_t expairseq::nops() const
206 {
207         if (overall_coeff.is_equal(default_overall_coeff()))
208                 return seq.size();
209         else
210                 return seq.size()+1;
211 }
212
213 ex expairseq::op(size_t i) const
214 {
215         if (i < seq.size())
216                 return recombine_pair_to_ex(seq[i]);
217         GINAC_ASSERT(!overall_coeff.is_equal(default_overall_coeff()));
218         return overall_coeff;
219 }
220
221 ex expairseq::map(map_function &f) const
222 {
223         std::auto_ptr<epvector> v(new epvector);
224         v->reserve(seq.size()+1);
225
226         epvector::const_iterator cit = seq.begin(), last = seq.end();
227         while (cit != last) {
228                 v->push_back(split_ex_to_pair(f(recombine_pair_to_ex(*cit))));
229                 ++cit;
230         }
231
232         if (overall_coeff.is_equal(default_overall_coeff()))
233                 return thisexpairseq(v, default_overall_coeff(), true);
234         else {
235                 ex newcoeff = f(overall_coeff);
236                 if(is_a<numeric>(newcoeff))
237                         return thisexpairseq(v, newcoeff, true);
238                 else {
239                         v->push_back(split_ex_to_pair(newcoeff));
240                         return thisexpairseq(v, default_overall_coeff(), true);
241                 }
242         }
243 }
244
245 /** Perform coefficient-wise automatic term rewriting rules in this class. */
246 ex expairseq::eval(int level) const
247 {
248         if ((level==1) && (flags &status_flags::evaluated))
249                 return *this;
250         
251         std::auto_ptr<epvector> vp = evalchildren(level);
252         if (vp.get() == 0)
253                 return this->hold();
254         
255         return (new expairseq(vp, overall_coeff))->setflag(status_flags::dynallocated | status_flags::evaluated);
256 }
257
258 epvector* conjugateepvector(const epvector&epv)
259 {
260         epvector *newepv = 0;
261         for (epvector::const_iterator i=epv.begin(); i!=epv.end(); ++i) {
262                 if(newepv) {
263                         newepv->push_back(i->conjugate());
264                         continue;
265                 }
266                 expair x = i->conjugate();
267                 if (x.is_equal(*i)) {
268                         continue;
269                 }
270                 newepv = new epvector;
271                 newepv->reserve(epv.size());
272                 for (epvector::const_iterator j=epv.begin(); j!=i; ++j) {
273                         newepv->push_back(*j);
274                 }
275                 newepv->push_back(x);
276         }
277         return newepv;
278 }
279
280 ex expairseq::conjugate() const
281 {
282         epvector* newepv = conjugateepvector(seq);
283         ex x = overall_coeff.conjugate();
284         if (!newepv && are_ex_trivially_equal(x, overall_coeff)) {
285                 return *this;
286         }
287         ex result = thisexpairseq(newepv ? *newepv : seq, x);
288         delete newepv;
289         return result;
290 }
291
292 bool expairseq::match(const ex & pattern, exmap & repl_lst) const
293 {
294         // This differs from basic::match() because we want "a+b+c+d" to
295         // match "d+*+b" with "*" being "a+c", and we want to honor commutativity
296
297         if (typeid(*this) == typeid(ex_to<basic>(pattern))) {
298
299                 // Check whether global wildcard (one that matches the "rest of the
300                 // expression", like "*" above) is present
301                 bool has_global_wildcard = false;
302                 ex global_wildcard;
303                 for (size_t i=0; i<pattern.nops(); i++) {
304                         if (is_exactly_a<wildcard>(pattern.op(i))) {
305                                 has_global_wildcard = true;
306                                 global_wildcard = pattern.op(i);
307                                 break;
308                         }
309                 }
310
311                 // Even if the expression does not match the pattern, some of
312                 // its subexpressions could match it. For example, x^5*y^(-1)
313                 // does not match the pattern $0^5, but its subexpression x^5
314                 // does. So, save repl_lst in order to not add bogus entries.
315                 exmap tmp_repl = repl_lst;
316
317                 // Unfortunately, this is an O(N^2) operation because we can't
318                 // sort the pattern in a useful way...
319
320                 // Chop into terms
321                 exvector ops;
322                 ops.reserve(nops());
323                 for (size_t i=0; i<nops(); i++)
324                         ops.push_back(op(i));
325
326                 // Now, for every term of the pattern, look for a matching term in
327                 // the expression and remove the match
328                 for (size_t i=0; i<pattern.nops(); i++) {
329                         ex p = pattern.op(i);
330                         if (has_global_wildcard && p.is_equal(global_wildcard))
331                                 continue;
332                         exvector::iterator it = ops.begin(), itend = ops.end();
333                         while (it != itend) {
334                                 if (it->match(p, tmp_repl)) {
335                                         ops.erase(it);
336                                         goto found;
337                                 }
338                                 ++it;
339                         }
340                         return false; // no match found
341 found:          ;
342                 }
343
344                 if (has_global_wildcard) {
345
346                         // Assign all the remaining terms to the global wildcard (unless
347                         // it has already been matched before, in which case the matches
348                         // must be equal)
349                         size_t num = ops.size();
350                         std::auto_ptr<epvector> vp(new epvector);
351                         vp->reserve(num);
352                         for (size_t i=0; i<num; i++)
353                                 vp->push_back(split_ex_to_pair(ops[i]));
354                         ex rest = thisexpairseq(vp, default_overall_coeff());
355                         for (exmap::const_iterator it = tmp_repl.begin(); it != tmp_repl.end(); ++it) {
356                                 if (it->first.is_equal(global_wildcard)) {
357                                         if (rest.is_equal(it->second)) {
358                                                 repl_lst = tmp_repl;
359                                                 return true;
360                                         }
361                                         return false;
362                                 }
363                         }
364                         repl_lst = tmp_repl;
365                         repl_lst[global_wildcard] = rest;
366                         return true;
367
368                 } else {
369
370                         // No global wildcard, then the match fails if there are any
371                         // unmatched terms left
372                         if (ops.empty()) {
373                                 repl_lst = tmp_repl;
374                                 return true;
375                         }
376                         return false;
377                 }
378         }
379         return inherited::match(pattern, repl_lst);
380 }
381
382 ex expairseq::subs(const exmap & m, unsigned options) const
383 {
384         std::auto_ptr<epvector> vp = subschildren(m, options);
385         if (vp.get())
386                 return ex_to<basic>(thisexpairseq(vp, overall_coeff, (options & subs_options::no_index_renaming) == 0));
387         else if ((options & subs_options::algebraic) && is_exactly_a<mul>(*this))
388                 return static_cast<const mul *>(this)->algebraic_subs_mul(m, options);
389         else
390                 return subs_one_level(m, options);
391 }
392
393 // protected
394
395 int expairseq::compare_same_type(const basic &other) const
396 {
397         GINAC_ASSERT(is_a<expairseq>(other));
398         const expairseq &o = static_cast<const expairseq &>(other);
399         
400         int cmpval;
401         
402         // compare number of elements
403         if (seq.size() != o.seq.size())
404                 return (seq.size()<o.seq.size()) ? -1 : 1;
405         
406         // compare overall_coeff
407         cmpval = overall_coeff.compare(o.overall_coeff);
408         if (cmpval!=0)
409                 return cmpval;
410         
411         epvector::const_iterator cit1 = seq.begin();
412         epvector::const_iterator cit2 = o.seq.begin();
413         epvector::const_iterator last1 = seq.end();
414         epvector::const_iterator last2 = o.seq.end();
415                 
416         for (; (cit1!=last1)&&(cit2!=last2); ++cit1, ++cit2) {
417                 cmpval = (*cit1).compare(*cit2);
418                 if (cmpval!=0) return cmpval;
419         }
420                 
421         GINAC_ASSERT(cit1==last1);
422         GINAC_ASSERT(cit2==last2);
423                 
424         return 0;
425 }
426
427 bool expairseq::is_equal_same_type(const basic &other) const
428 {
429         const expairseq &o = static_cast<const expairseq &>(other);
430         
431         // compare number of elements
432         if (seq.size()!=o.seq.size())
433                 return false;
434         
435         // compare overall_coeff
436         if (!overall_coeff.is_equal(o.overall_coeff))
437                 return false;
438         
439         epvector::const_iterator cit1 = seq.begin();
440         epvector::const_iterator cit2 = o.seq.begin();
441         epvector::const_iterator last1 = seq.end();
442                 
443         while (cit1!=last1) {
444                 if (!(*cit1).is_equal(*cit2)) return false;
445                 ++cit1;
446                 ++cit2;
447         }
448
449         return true;
450 }
451
452 unsigned expairseq::return_type() const
453 {
454         return return_types::noncommutative_composite;
455 }
456
457 unsigned expairseq::calchash() const
458 {
459         unsigned v = make_hash_seed(typeid(*this));
460         epvector::const_iterator i = seq.begin();
461         const epvector::const_iterator end = seq.end();
462         while (i != end) {
463                 v ^= i->rest.gethash();
464                 // rotation spoils commutativity!
465                 v = rotate_left(v);
466                 v ^= i->coeff.gethash();
467                 ++i;
468         }
469
470         v ^= overall_coeff.gethash();
471
472         // store calculated hash value only if object is already evaluated
473         if (flags &status_flags::evaluated) {
474                 setflag(status_flags::hash_calculated);
475                 hashvalue = v;
476         }
477         
478         return v;
479 }
480
481 ex expairseq::expand(unsigned options) const
482 {
483         std::auto_ptr<epvector> vp = expandchildren(options);
484         if (vp.get())
485                 return thisexpairseq(vp, overall_coeff);
486         else {
487                 // The terms have not changed, so it is safe to declare this expanded
488                 return (options == 0) ? setflag(status_flags::expanded) : *this;
489         }
490 }
491
492 //////////
493 // new virtual functions which can be overridden by derived classes
494 //////////
495
496 // protected
497
498 /** Create an object of this type.
499  *  This method works similar to a constructor.  It is useful because expairseq
500  *  has (at least) two possible different semantics but we want to inherit
501  *  methods thus avoiding code duplication.  Sometimes a method in expairseq
502  *  has to create a new one of the same semantics, which cannot be done by a
503  *  ctor because the name (add, mul,...) is unknown on the expairseq level.  In
504  *  order for this trick to work a derived class must of course override this
505  *  definition. */
506 ex expairseq::thisexpairseq(const epvector &v, const ex &oc, bool do_index_renaming) const
507 {
508         return expairseq(v, oc, do_index_renaming);
509 }
510
511 ex expairseq::thisexpairseq(std::auto_ptr<epvector> vp, const ex &oc, bool do_index_renaming) const
512 {
513         return expairseq(vp, oc, do_index_renaming);
514 }
515
516 void expairseq::printpair(const print_context & c, const expair & p, unsigned upper_precedence) const
517 {
518         c.s << "[[";
519         p.rest.print(c, precedence());
520         c.s << ",";
521         p.coeff.print(c, precedence());
522         c.s << "]]";
523 }
524
525 void expairseq::printseq(const print_context & c, char delim,
526                          unsigned this_precedence,
527                          unsigned upper_precedence) const
528 {
529         if (this_precedence <= upper_precedence)
530                 c.s << "(";
531         epvector::const_iterator it, it_last = seq.end() - 1;
532         for (it=seq.begin(); it!=it_last; ++it) {
533                 printpair(c, *it, this_precedence);
534                 c.s << delim;
535         }
536         printpair(c, *it, this_precedence);
537         if (!overall_coeff.is_equal(default_overall_coeff())) {
538                 c.s << delim;
539                 overall_coeff.print(c, this_precedence);
540         }
541         
542         if (this_precedence <= upper_precedence)
543                 c.s << ")";
544 }
545
546
547 /** Form an expair from an ex, using the corresponding semantics.
548  *  @see expairseq::recombine_pair_to_ex() */
549 expair expairseq::split_ex_to_pair(const ex &e) const
550 {
551         return expair(e,_ex1);
552 }
553
554
555 expair expairseq::combine_ex_with_coeff_to_pair(const ex &e,
556                                                 const ex &c) const
557 {
558         GINAC_ASSERT(is_exactly_a<numeric>(c));
559         
560         return expair(e,c);
561 }
562
563
564 expair expairseq::combine_pair_with_coeff_to_pair(const expair &p,
565                                                   const ex &c) const
566 {
567         GINAC_ASSERT(is_exactly_a<numeric>(p.coeff));
568         GINAC_ASSERT(is_exactly_a<numeric>(c));
569         
570         return expair(p.rest,ex_to<numeric>(p.coeff).mul_dyn(ex_to<numeric>(c)));
571 }
572
573
574 /** Form an ex out of an expair, using the corresponding semantics.
575  *  @see expairseq::split_ex_to_pair() */
576 ex expairseq::recombine_pair_to_ex(const expair &p) const
577 {
578         return lst(p.rest,p.coeff);
579 }
580
581 bool expairseq::expair_needs_further_processing(epp it)
582 {
583         return false;
584 }
585
586 ex expairseq::default_overall_coeff() const
587 {
588         return _ex0;
589 }
590
591 void expairseq::combine_overall_coeff(const ex &c)
592 {
593         GINAC_ASSERT(is_exactly_a<numeric>(overall_coeff));
594         GINAC_ASSERT(is_exactly_a<numeric>(c));
595         overall_coeff = ex_to<numeric>(overall_coeff).add_dyn(ex_to<numeric>(c));
596 }
597
598 void expairseq::combine_overall_coeff(const ex &c1, const ex &c2)
599 {
600         GINAC_ASSERT(is_exactly_a<numeric>(overall_coeff));
601         GINAC_ASSERT(is_exactly_a<numeric>(c1));
602         GINAC_ASSERT(is_exactly_a<numeric>(c2));
603         overall_coeff = ex_to<numeric>(overall_coeff).
604                         add_dyn(ex_to<numeric>(c1).mul(ex_to<numeric>(c2)));
605 }
606
607 bool expairseq::can_make_flat(const expair &p) const
608 {
609         return true;
610 }
611
612
613 //////////
614 // non-virtual functions in this class
615 //////////
616
617 void expairseq::construct_from_2_ex_via_exvector(const ex &lh, const ex &rh)
618 {
619         exvector v;
620         v.reserve(2);
621         v.push_back(lh);
622         v.push_back(rh);
623         construct_from_exvector(v);
624 }
625
626 void expairseq::construct_from_2_ex(const ex &lh, const ex &rh)
627 {
628         if (typeid(ex_to<basic>(lh)) == typeid(*this)) {
629                 if (typeid(ex_to<basic>(rh)) == typeid(*this)) {
630                         if (is_a<mul>(lh) && lh.info(info_flags::has_indices) && 
631                                 rh.info(info_flags::has_indices)) {
632                                 ex newrh=rename_dummy_indices_uniquely(lh, rh);
633                                 construct_from_2_expairseq(ex_to<expairseq>(lh),
634                                                            ex_to<expairseq>(newrh));
635                         }
636                         else
637                                 construct_from_2_expairseq(ex_to<expairseq>(lh),
638                                                            ex_to<expairseq>(rh));
639                         return;
640                 } else {
641                         construct_from_expairseq_ex(ex_to<expairseq>(lh), rh);
642                         return;
643                 }
644         } else if (typeid(ex_to<basic>(rh)) == typeid(*this)) {
645                 construct_from_expairseq_ex(ex_to<expairseq>(rh),lh);
646                 return;
647         }
648         
649         if (is_exactly_a<numeric>(lh)) {
650                 if (is_exactly_a<numeric>(rh)) {
651                         combine_overall_coeff(lh);
652                         combine_overall_coeff(rh);
653                 } else {
654                         combine_overall_coeff(lh);
655                         seq.push_back(split_ex_to_pair(rh));
656                 }
657         } else {
658                 if (is_exactly_a<numeric>(rh)) {
659                         combine_overall_coeff(rh);
660                         seq.push_back(split_ex_to_pair(lh));
661                 } else {
662                         expair p1 = split_ex_to_pair(lh);
663                         expair p2 = split_ex_to_pair(rh);
664                         
665                         int cmpval = p1.rest.compare(p2.rest);
666                         if (cmpval==0) {
667                                 p1.coeff = ex_to<numeric>(p1.coeff).add_dyn(ex_to<numeric>(p2.coeff));
668                                 if (!ex_to<numeric>(p1.coeff).is_zero()) {
669                                         // no further processing is necessary, since this
670                                         // one element will usually be recombined in eval()
671                                         seq.push_back(p1);
672                                 }
673                         } else {
674                                 seq.reserve(2);
675                                 if (cmpval<0) {
676                                         seq.push_back(p1);
677                                         seq.push_back(p2);
678                                 } else {
679                                         seq.push_back(p2);
680                                         seq.push_back(p1);
681                                 }
682                         }
683                 }
684         }
685 }
686
687 void expairseq::construct_from_2_expairseq(const expairseq &s1,
688                                            const expairseq &s2)
689 {
690         combine_overall_coeff(s1.overall_coeff);
691         combine_overall_coeff(s2.overall_coeff);
692
693         epvector::const_iterator first1 = s1.seq.begin();
694         epvector::const_iterator last1 = s1.seq.end();
695         epvector::const_iterator first2 = s2.seq.begin();
696         epvector::const_iterator last2 = s2.seq.end();
697
698         seq.reserve(s1.seq.size()+s2.seq.size());
699
700         bool needs_further_processing=false;
701         
702         while (first1!=last1 && first2!=last2) {
703                 int cmpval = (*first1).rest.compare((*first2).rest);
704
705                 if (cmpval==0) {
706                         // combine terms
707                         const numeric &newcoeff = ex_to<numeric>(first1->coeff).
708                                                    add(ex_to<numeric>(first2->coeff));
709                         if (!newcoeff.is_zero()) {
710                                 seq.push_back(expair(first1->rest,newcoeff));
711                                 if (expair_needs_further_processing(seq.end()-1)) {
712                                         needs_further_processing = true;
713                                 }
714                         }
715                         ++first1;
716                         ++first2;
717                 } else if (cmpval<0) {
718                         seq.push_back(*first1);
719                         ++first1;
720                 } else {
721                         seq.push_back(*first2);
722                         ++first2;
723                 }
724         }
725         
726         while (first1!=last1) {
727                 seq.push_back(*first1);
728                 ++first1;
729         }
730         while (first2!=last2) {
731                 seq.push_back(*first2);
732                 ++first2;
733         }
734         
735         if (needs_further_processing) {
736                 epvector v = seq;
737                 seq.clear();
738                 construct_from_epvector(v);
739         }
740 }
741
742 void expairseq::construct_from_expairseq_ex(const expairseq &s,
743                                             const ex &e)
744 {
745         combine_overall_coeff(s.overall_coeff);
746         if (is_exactly_a<numeric>(e)) {
747                 combine_overall_coeff(e);
748                 seq = s.seq;
749                 return;
750         }
751         
752         epvector::const_iterator first = s.seq.begin();
753         epvector::const_iterator last = s.seq.end();
754         expair p = split_ex_to_pair(e);
755         
756         seq.reserve(s.seq.size()+1);
757         bool p_pushed = false;
758         
759         bool needs_further_processing=false;
760         
761         // merge p into s.seq
762         while (first!=last) {
763                 int cmpval = (*first).rest.compare(p.rest);
764                 if (cmpval==0) {
765                         // combine terms
766                         const numeric &newcoeff = ex_to<numeric>(first->coeff).
767                                                    add(ex_to<numeric>(p.coeff));
768                         if (!newcoeff.is_zero()) {
769                                 seq.push_back(expair(first->rest,newcoeff));
770                                 if (expair_needs_further_processing(seq.end()-1))
771                                         needs_further_processing = true;
772                         }
773                         ++first;
774                         p_pushed = true;
775                         break;
776                 } else if (cmpval<0) {
777                         seq.push_back(*first);
778                         ++first;
779                 } else {
780                         seq.push_back(p);
781                         p_pushed = true;
782                         break;
783                 }
784         }
785         
786         if (p_pushed) {
787                 // while loop exited because p was pushed, now push rest of s.seq
788                 while (first!=last) {
789                         seq.push_back(*first);
790                         ++first;
791                 }
792         } else {
793                 // while loop exited because s.seq was pushed, now push p
794                 seq.push_back(p);
795         }
796
797         if (needs_further_processing) {
798                 epvector v = seq;
799                 seq.clear();
800                 construct_from_epvector(v);
801         }
802 }
803
804 void expairseq::construct_from_exvector(const exvector &v)
805 {
806         // simplifications: +(a,+(b,c),d) -> +(a,b,c,d) (associativity)
807         //                  +(d,b,c,a) -> +(a,b,c,d) (canonicalization)
808         //                  +(...,x,*(x,c1),*(x,c2)) -> +(...,*(x,1+c1+c2)) (c1, c2 numeric)
809         //                  (same for (+,*) -> (*,^)
810
811         make_flat(v);
812         canonicalize();
813         combine_same_terms_sorted_seq();
814 }
815
816 void expairseq::construct_from_epvector(const epvector &v, bool do_index_renaming)
817 {
818         // simplifications: +(a,+(b,c),d) -> +(a,b,c,d) (associativity)
819         //                  +(d,b,c,a) -> +(a,b,c,d) (canonicalization)
820         //                  +(...,x,*(x,c1),*(x,c2)) -> +(...,*(x,1+c1+c2)) (c1, c2 numeric)
821         //                  same for (+,*) -> (*,^)
822
823         make_flat(v, do_index_renaming);
824         canonicalize();
825         combine_same_terms_sorted_seq();
826 }
827
828 /** Combine this expairseq with argument exvector.
829  *  It cares for associativity as well as for special handling of numerics. */
830 void expairseq::make_flat(const exvector &v)
831 {
832         exvector::const_iterator cit;
833         
834         // count number of operands which are of same expairseq derived type
835         // and their cumulative number of operands
836         int nexpairseqs = 0;
837         int noperands = 0;
838         bool do_idx_rename = false;
839         
840         cit = v.begin();
841         while (cit!=v.end()) {
842                 if (typeid(ex_to<basic>(*cit)) == typeid(*this)) {
843                         ++nexpairseqs;
844                         noperands += ex_to<expairseq>(*cit).seq.size();
845                 }
846                 if (is_a<mul>(*this) && (!do_idx_rename) &&
847                                 cit->info(info_flags::has_indices))
848                         do_idx_rename = true;
849                 ++cit;
850         }
851         
852         // reserve seq and coeffseq which will hold all operands
853         seq.reserve(v.size()+noperands-nexpairseqs);
854         
855         // copy elements and split off numerical part
856         make_flat_inserter mf(v, do_idx_rename);
857         cit = v.begin();
858         while (cit!=v.end()) {
859                 if (typeid(ex_to<basic>(*cit)) == typeid(*this)) {
860                         ex newfactor = mf.handle_factor(*cit, _ex1);
861                         const expairseq &subseqref = ex_to<expairseq>(newfactor);
862                         combine_overall_coeff(subseqref.overall_coeff);
863                         epvector::const_iterator cit_s = subseqref.seq.begin();
864                         while (cit_s!=subseqref.seq.end()) {
865                                 seq.push_back(*cit_s);
866                                 ++cit_s;
867                         }
868                 } else {
869                         if (is_exactly_a<numeric>(*cit))
870                                 combine_overall_coeff(*cit);
871                         else {
872                                 ex newfactor = mf.handle_factor(*cit, _ex1);
873                                 seq.push_back(split_ex_to_pair(newfactor));
874                         }
875                 }
876                 ++cit;
877         }
878 }
879
880 /** Combine this expairseq with argument epvector.
881  *  It cares for associativity as well as for special handling of numerics. */
882 void expairseq::make_flat(const epvector &v, bool do_index_renaming)
883 {
884         epvector::const_iterator cit;
885         
886         // count number of operands which are of same expairseq derived type
887         // and their cumulative number of operands
888         int nexpairseqs = 0;
889         int noperands = 0;
890         bool really_need_rename_inds = false;
891         
892         cit = v.begin();
893         while (cit!=v.end()) {
894                 if (typeid(ex_to<basic>(cit->rest)) == typeid(*this)) {
895                         ++nexpairseqs;
896                         noperands += ex_to<expairseq>(cit->rest).seq.size();
897                 }
898                 if ((!really_need_rename_inds) && is_a<mul>(*this) &&
899                                 cit->rest.info(info_flags::has_indices))
900                         really_need_rename_inds = true;
901                 ++cit;
902         }
903         do_index_renaming = do_index_renaming && really_need_rename_inds;
904         
905         // reserve seq and coeffseq which will hold all operands
906         seq.reserve(v.size()+noperands-nexpairseqs);
907         make_flat_inserter mf(v, do_index_renaming);
908         
909         // copy elements and split off numerical part
910         cit = v.begin();
911         while (cit!=v.end()) {
912                 if ((typeid(ex_to<basic>(cit->rest)) == typeid(*this)) &&
913                     this->can_make_flat(*cit)) {
914                         ex newrest = mf.handle_factor(cit->rest, cit->coeff);
915                         const expairseq &subseqref = ex_to<expairseq>(newrest);
916                         combine_overall_coeff(ex_to<numeric>(subseqref.overall_coeff),
917                                                             ex_to<numeric>(cit->coeff));
918                         epvector::const_iterator cit_s = subseqref.seq.begin();
919                         while (cit_s!=subseqref.seq.end()) {
920                                 seq.push_back(expair(cit_s->rest,
921                                                      ex_to<numeric>(cit_s->coeff).mul_dyn(ex_to<numeric>(cit->coeff))));
922                                 //seq.push_back(combine_pair_with_coeff_to_pair(*cit_s,
923                                 //                                              (*cit).coeff));
924                                 ++cit_s;
925                         }
926                 } else {
927                         if (cit->is_canonical_numeric())
928                                 combine_overall_coeff(mf.handle_factor(cit->rest, _ex1));
929                         else {
930                                 ex rest = cit->rest;
931                                 ex newrest = mf.handle_factor(rest, cit->coeff);
932                                 if (are_ex_trivially_equal(newrest, rest))
933                                         seq.push_back(*cit);
934                                 else
935                                         seq.push_back(expair(newrest, cit->coeff));
936                         }
937                 }
938                 ++cit;
939         }
940 }
941
942 /** Brings this expairseq into a sorted (canonical) form. */
943 void expairseq::canonicalize()
944 {
945         std::sort(seq.begin(), seq.end(), expair_rest_is_less());
946 }
947
948
949 /** Compact a presorted expairseq by combining all matching expairs to one
950  *  each.  On an add object, this is responsible for 2*x+3*x+y -> 5*x+y, for
951  *  instance. */
952 void expairseq::combine_same_terms_sorted_seq()
953 {
954         if (seq.size()<2)
955                 return;
956
957         bool needs_further_processing = false;
958
959         epvector::iterator itin1 = seq.begin();
960         epvector::iterator itin2 = itin1+1;
961         epvector::iterator itout = itin1;
962         epvector::iterator last = seq.end();
963         // must_copy will be set to true the first time some combination is 
964         // possible from then on the sequence has changed and must be compacted
965         bool must_copy = false;
966         while (itin2!=last) {
967                 if (itin1->rest.compare(itin2->rest)==0) {
968                         itin1->coeff = ex_to<numeric>(itin1->coeff).
969                                        add_dyn(ex_to<numeric>(itin2->coeff));
970                         if (expair_needs_further_processing(itin1))
971                                 needs_further_processing = true;
972                         must_copy = true;
973                 } else {
974                         if (!ex_to<numeric>(itin1->coeff).is_zero()) {
975                                 if (must_copy)
976                                         *itout = *itin1;
977                                 ++itout;
978                         }
979                         itin1 = itin2;
980                 }
981                 ++itin2;
982         }
983         if (!ex_to<numeric>(itin1->coeff).is_zero()) {
984                 if (must_copy)
985                         *itout = *itin1;
986                 ++itout;
987         }
988         if (itout!=last)
989                 seq.erase(itout,last);
990
991         if (needs_further_processing) {
992                 epvector v = seq;
993                 seq.clear();
994                 construct_from_epvector(v);
995         }
996 }
997
998 /** Check if this expairseq is in sorted (canonical) form.  Useful mainly for
999  *  debugging or in assertions since being sorted is an invariance. */
1000 bool expairseq::is_canonical() const
1001 {
1002         if (seq.size() <= 1)
1003                 return 1;
1004         
1005         epvector::const_iterator it = seq.begin(), itend = seq.end();
1006         epvector::const_iterator it_last = it;
1007         for (++it; it!=itend; it_last=it, ++it) {
1008                 if (!(it_last->is_less(*it) || it_last->is_equal(*it))) {
1009                         if (!is_exactly_a<numeric>(it_last->rest) ||
1010                                 !is_exactly_a<numeric>(it->rest)) {
1011                                 // double test makes it easier to set a breakpoint...
1012                                 if (!is_exactly_a<numeric>(it_last->rest) ||
1013                                         !is_exactly_a<numeric>(it->rest)) {
1014                                         printpair(std::clog, *it_last, 0);
1015                                         std::clog << ">";
1016                                         printpair(std::clog, *it, 0);
1017                                         std::clog << "\n";
1018                                         std::clog << "pair1:" << std::endl;
1019                                         it_last->rest.print(print_tree(std::clog));
1020                                         it_last->coeff.print(print_tree(std::clog));
1021                                         std::clog << "pair2:" << std::endl;
1022                                         it->rest.print(print_tree(std::clog));
1023                                         it->coeff.print(print_tree(std::clog));
1024                                         return 0;
1025                                 }
1026                         }
1027                 }
1028         }
1029         return 1;
1030 }
1031
1032
1033 /** Member-wise expand the expairs in this sequence.
1034  *
1035  *  @see expairseq::expand()
1036  *  @return pointer to epvector containing expanded pairs or zero pointer,
1037  *  if no members were changed. */
1038 std::auto_ptr<epvector> expairseq::expandchildren(unsigned options) const
1039 {
1040         const epvector::const_iterator last = seq.end();
1041         epvector::const_iterator cit = seq.begin();
1042         while (cit!=last) {
1043                 const ex &expanded_ex = cit->rest.expand(options);
1044                 if (!are_ex_trivially_equal(cit->rest,expanded_ex)) {
1045                         
1046                         // something changed, copy seq, eval and return it
1047                         std::auto_ptr<epvector> s(new epvector);
1048                         s->reserve(seq.size());
1049                         
1050                         // copy parts of seq which are known not to have changed
1051                         epvector::const_iterator cit2 = seq.begin();
1052                         while (cit2!=cit) {
1053                                 s->push_back(*cit2);
1054                                 ++cit2;
1055                         }
1056
1057                         // copy first changed element
1058                         s->push_back(combine_ex_with_coeff_to_pair(expanded_ex,
1059                                                                    cit2->coeff));
1060                         ++cit2;
1061
1062                         // copy rest
1063                         while (cit2!=last) {
1064                                 s->push_back(combine_ex_with_coeff_to_pair(cit2->rest.expand(options),
1065                                                                            cit2->coeff));
1066                                 ++cit2;
1067                         }
1068                         return s;
1069                 }
1070                 ++cit;
1071         }
1072         
1073         return std::auto_ptr<epvector>(0); // signalling nothing has changed
1074 }
1075
1076
1077 /** Member-wise evaluate the expairs in this sequence.
1078  *
1079  *  @see expairseq::eval()
1080  *  @return pointer to epvector containing evaluated pairs or zero pointer,
1081  *  if no members were changed. */
1082 std::auto_ptr<epvector> expairseq::evalchildren(int level) const
1083 {
1084         // returns a NULL pointer if nothing had to be evaluated
1085         // returns a pointer to a newly created epvector otherwise
1086         // (which has to be deleted somewhere else)
1087
1088         if (level==1)
1089                 return std::auto_ptr<epvector>(0);
1090         
1091         if (level == -max_recursion_level)
1092                 throw(std::runtime_error("max recursion level reached"));
1093         
1094         --level;
1095         epvector::const_iterator last = seq.end();
1096         epvector::const_iterator cit = seq.begin();
1097         while (cit!=last) {
1098                 const ex &evaled_ex = cit->rest.eval(level);
1099                 if (!are_ex_trivially_equal(cit->rest,evaled_ex)) {
1100                         
1101                         // something changed, copy seq, eval and return it
1102                         std::auto_ptr<epvector> s(new epvector);
1103                         s->reserve(seq.size());
1104                         
1105                         // copy parts of seq which are known not to have changed
1106                         epvector::const_iterator cit2=seq.begin();
1107                         while (cit2!=cit) {
1108                                 s->push_back(*cit2);
1109                                 ++cit2;
1110                         }
1111
1112                         // copy first changed element
1113                         s->push_back(combine_ex_with_coeff_to_pair(evaled_ex,
1114                                                                    cit2->coeff));
1115                         ++cit2;
1116
1117                         // copy rest
1118                         while (cit2!=last) {
1119                                 s->push_back(combine_ex_with_coeff_to_pair(cit2->rest.eval(level),
1120                                                                            cit2->coeff));
1121                                 ++cit2;
1122                         }
1123                         return s;
1124                 }
1125                 ++cit;
1126         }
1127         
1128         return std::auto_ptr<epvector>(0); // signalling nothing has changed
1129 }
1130
1131 /** Member-wise substitute in this sequence.
1132  *
1133  *  @see expairseq::subs()
1134  *  @return pointer to epvector containing pairs after application of subs,
1135  *    or NULL pointer if no members were changed. */
1136 std::auto_ptr<epvector> expairseq::subschildren(const exmap & m, unsigned options) const
1137 {
1138         // When any of the objects to be substituted is a product or power
1139         // we have to recombine the pairs because the numeric coefficients may
1140         // be part of the search pattern.
1141         if (!(options & (subs_options::pattern_is_product | subs_options::pattern_is_not_product))) {
1142
1143                 // Search the list of substitutions and cache our findings
1144                 for (exmap::const_iterator it = m.begin(); it != m.end(); ++it) {
1145                         if (is_exactly_a<mul>(it->first) || is_exactly_a<power>(it->first)) {
1146                                 options |= subs_options::pattern_is_product;
1147                                 break;
1148                         }
1149                 }
1150                 if (!(options & subs_options::pattern_is_product))
1151                         options |= subs_options::pattern_is_not_product;
1152         }
1153
1154         if (options & subs_options::pattern_is_product) {
1155
1156                 // Substitute in the recombined pairs
1157                 epvector::const_iterator cit = seq.begin(), last = seq.end();
1158                 while (cit != last) {
1159
1160                         const ex &orig_ex = recombine_pair_to_ex(*cit);
1161                         const ex &subsed_ex = orig_ex.subs(m, options);
1162                         if (!are_ex_trivially_equal(orig_ex, subsed_ex)) {
1163
1164                                 // Something changed, copy seq, subs and return it
1165                                 std::auto_ptr<epvector> s(new epvector);
1166                                 s->reserve(seq.size());
1167
1168                                 // Copy parts of seq which are known not to have changed
1169                                 s->insert(s->begin(), seq.begin(), cit);
1170
1171                                 // Copy first changed element
1172                                 s->push_back(split_ex_to_pair(subsed_ex));
1173                                 ++cit;
1174
1175                                 // Copy rest
1176                                 while (cit != last) {
1177                                         s->push_back(split_ex_to_pair(recombine_pair_to_ex(*cit).subs(m, options)));
1178                                         ++cit;
1179                                 }
1180                                 return s;
1181                         }
1182
1183                         ++cit;
1184                 }
1185
1186         } else {
1187
1188                 // Substitute only in the "rest" part of the pairs
1189                 epvector::const_iterator cit = seq.begin(), last = seq.end();
1190                 while (cit != last) {
1191
1192                         const ex &subsed_ex = cit->rest.subs(m, options);
1193                         if (!are_ex_trivially_equal(cit->rest, subsed_ex)) {
1194                         
1195                                 // Something changed, copy seq, subs and return it
1196                                 std::auto_ptr<epvector> s(new epvector);
1197                                 s->reserve(seq.size());
1198
1199                                 // Copy parts of seq which are known not to have changed
1200                                 s->insert(s->begin(), seq.begin(), cit);
1201                         
1202                                 // Copy first changed element
1203                                 s->push_back(combine_ex_with_coeff_to_pair(subsed_ex, cit->coeff));
1204                                 ++cit;
1205
1206                                 // Copy rest
1207                                 while (cit != last) {
1208                                         s->push_back(combine_ex_with_coeff_to_pair(cit->rest.subs(m, options), cit->coeff));
1209                                         ++cit;
1210                                 }
1211                                 return s;
1212                         }
1213
1214                         ++cit;
1215                 }
1216         }
1217         
1218         // Nothing has changed
1219         return std::auto_ptr<epvector>(0);
1220 }
1221
1222 //////////
1223 // static member variables
1224 //////////
1225
1226 } // namespace GiNaC