]> www.ginac.de Git - ginac.git/blob - ginac/container.h
tiny optimization in subs()
[ginac.git] / ginac / container.h
1 /** @file container.h
2  *
3  *  Wrapper template for making GiNaC classes out of STL containers. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2003 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
21  */
22
23 #ifndef __GINAC_CONTAINER_H__
24 #define __GINAC_CONTAINER_H__
25
26 #include <iterator>
27 #include <stdexcept>
28 #include <algorithm>
29 #include <vector>
30 #include <list>
31
32 #include "ex.h"
33 #include "print.h"
34 #include "archive.h"
35 #include "assertion.h"
36
37 namespace GiNaC {
38
39
40 /** Helper template for encapsulating the reserve() mechanics of STL containers. */
41 template <template <class> class C>
42 class container_storage {
43 protected:
44         typedef C<ex> STLT;
45
46         container_storage() {}
47         container_storage(size_t n, const ex & e) : seq(n, e) {}
48
49         template <class In>
50         container_storage(In b, In e) : seq(b, e) {}
51
52         void reserve(size_t) {}
53         static void reserve(STLT &, size_t) {}
54
55         STLT seq;
56
57         // disallow destruction of container through a container_storage*
58 protected:
59         ~container_storage() {}
60 };
61
62 template <>
63 inline void container_storage<std::vector>::reserve(size_t n) { seq.reserve(n); }
64
65 template <>
66 inline void container_storage<std::vector>::reserve(std::vector<ex> & v, size_t n) { v.reserve(n); }
67
68
69 /** Wrapper template for making GiNaC classes out of STL containers. */
70 template <template <class> class C>
71 class container : public basic, public container_storage<C> {
72         GINAC_DECLARE_REGISTERED_CLASS(container, basic)
73
74         typedef typename container_storage<C>::STLT STLT;
75
76 public:
77         typedef typename STLT::const_iterator const_iterator;
78         typedef typename STLT::const_reverse_iterator const_reverse_iterator;
79
80 protected:
81         // helpers
82         static unsigned get_tinfo() { return TINFO_fail; }
83         static char get_open_delim() { return '('; }
84         static char get_close_delim() { return ')'; }
85
86         // constructors
87 public:
88         container(STLT const & s, bool discardable = false)
89         {
90                 if (discardable)
91                         seq.swap(const_cast<STLT &>(s));
92                 else
93                         seq = s;
94         }
95
96         explicit container(STLT * vp)
97         {
98                 GINAC_ASSERT(vp);
99                 seq.swap(*vp);
100                 delete vp;
101         }
102
103         container(exvector::const_iterator b, exvector::const_iterator e)
104          : inherited(get_tinfo()), container_storage<C>(b, e) {}
105
106         explicit container(const ex & p1)
107          : inherited(get_tinfo()), container_storage<C>(1, p1) {}
108
109         container(const ex & p1, const ex & p2) : inherited(get_tinfo())
110         {
111                 reserve(seq, 2);
112                 seq.push_back(p1); seq.push_back(p2);
113         }
114
115         container(const ex & p1, const ex & p2, const ex & p3) : inherited(get_tinfo())
116         {
117                 reserve(seq, 3);
118                 seq.push_back(p1); seq.push_back(p2); seq.push_back(p3);
119         }
120
121         container(const ex & p1, const ex & p2, const ex & p3,
122                            const ex & p4) : inherited(get_tinfo())
123         {
124                 reserve(seq, 4);
125                 seq.push_back(p1); seq.push_back(p2); seq.push_back(p3);
126                 seq.push_back(p4);
127         }
128
129         container(const ex & p1, const ex & p2, const ex & p3,
130                   const ex & p4, const ex & p5) : inherited(get_tinfo())
131         {
132                 reserve(seq, 5);
133                 seq.push_back(p1); seq.push_back(p2); seq.push_back(p3);
134                 seq.push_back(p4); seq.push_back(p5);
135         }
136
137         container(const ex & p1, const ex & p2, const ex & p3,
138                   const ex & p4, const ex & p5, const ex & p6) : inherited(get_tinfo())
139         {
140                 reserve(seq, 6);
141                 seq.push_back(p1); seq.push_back(p2); seq.push_back(p3);
142                 seq.push_back(p4); seq.push_back(p5); seq.push_back(p6);
143         }
144
145         container(const ex & p1, const ex & p2, const ex & p3,
146                   const ex & p4, const ex & p5, const ex & p6,
147                   const ex & p7) : inherited(get_tinfo())
148         {
149                 reserve(seq, 7);
150                 seq.push_back(p1); seq.push_back(p2); seq.push_back(p3);
151                 seq.push_back(p4); seq.push_back(p5); seq.push_back(p6);
152                 seq.push_back(p7);
153         }
154
155         container(const ex & p1, const ex & p2, const ex & p3,
156                   const ex & p4, const ex & p5, const ex & p6,
157                   const ex & p7, const ex & p8) : inherited(get_tinfo())
158         {
159                 reserve(seq, 8);
160                 seq.push_back(p1); seq.push_back(p2); seq.push_back(p3);
161                 seq.push_back(p4); seq.push_back(p5); seq.push_back(p6);
162                 seq.push_back(p7); seq.push_back(p8);
163         }
164
165         container(const ex & p1, const ex & p2, const ex & p3,
166                   const ex & p4, const ex & p5, const ex & p6,
167                   const ex & p7, const ex & p8, const ex & p9) : inherited(get_tinfo())
168         {
169                 reserve(seq, 9);
170                 seq.push_back(p1); seq.push_back(p2); seq.push_back(p3);
171                 seq.push_back(p4); seq.push_back(p5); seq.push_back(p6);
172                 seq.push_back(p7); seq.push_back(p8); seq.push_back(p9);
173         }
174
175         container(const ex & p1, const ex & p2, const ex & p3,
176                   const ex & p4, const ex & p5, const ex & p6,
177                   const ex & p7, const ex & p8, const ex & p9,
178                   const ex & p10) : inherited(get_tinfo())
179         {
180                 reserve(seq, 10);
181                 seq.push_back(p1); seq.push_back(p2); seq.push_back(p3);
182                 seq.push_back(p4); seq.push_back(p5); seq.push_back(p6);
183                 seq.push_back(p7); seq.push_back(p8); seq.push_back(p9);
184                 seq.push_back(p10);
185         }
186
187         container(const ex & p1, const ex & p2, const ex & p3,
188                   const ex & p4, const ex & p5, const ex & p6,
189                   const ex & p7, const ex & p8, const ex & p9,
190                   const ex & p10, const ex & p11) : inherited(get_tinfo())
191         {
192                 reserve(seq, 11);
193                 seq.push_back(p1); seq.push_back(p2); seq.push_back(p3);
194                 seq.push_back(p4); seq.push_back(p5); seq.push_back(p6);
195                 seq.push_back(p7); seq.push_back(p8); seq.push_back(p9);
196                 seq.push_back(p10); seq.push_back(p11);
197         }
198
199         container(const ex & p1, const ex & p2, const ex & p3,
200                   const ex & p4, const ex & p5, const ex & p6,
201                   const ex & p7, const ex & p8, const ex & p9,
202                   const ex & p10, const ex & p11, const ex & p12) : inherited(get_tinfo())
203         {
204                 reserve(seq, 12);
205                 seq.push_back(p1); seq.push_back(p2); seq.push_back(p3);
206                 seq.push_back(p4); seq.push_back(p5); seq.push_back(p6);
207                 seq.push_back(p7); seq.push_back(p8); seq.push_back(p9);
208                 seq.push_back(p10); seq.push_back(p11); seq.push_back(p12);
209         }
210
211         container(const ex & p1, const ex & p2, const ex & p3,
212                   const ex & p4, const ex & p5, const ex & p6,
213                   const ex & p7, const ex & p8, const ex & p9,
214                   const ex & p10, const ex & p11, const ex & p12,
215                   const ex & p13) : inherited(get_tinfo())
216         {
217                 reserve(seq, 13);
218                 seq.push_back(p1); seq.push_back(p2); seq.push_back(p3);
219                 seq.push_back(p4); seq.push_back(p5); seq.push_back(p6);
220                 seq.push_back(p7); seq.push_back(p8); seq.push_back(p9);
221                 seq.push_back(p10); seq.push_back(p11); seq.push_back(p12);
222                 seq.push_back(p13);
223         }
224
225         container(const ex & p1, const ex & p2, const ex & p3,
226                   const ex & p4, const ex & p5, const ex & p6,
227                   const ex & p7, const ex & p8, const ex & p9,
228                   const ex & p10, const ex & p11, const ex & p12,
229                   const ex & p13, const ex & p14) : inherited(get_tinfo())
230         {
231                 reserve(seq, 14);
232                 seq.push_back(p1); seq.push_back(p2); seq.push_back(p3);
233                 seq.push_back(p4); seq.push_back(p5); seq.push_back(p6);
234                 seq.push_back(p7); seq.push_back(p8); seq.push_back(p9);
235                 seq.push_back(p10); seq.push_back(p11); seq.push_back(p12);
236                 seq.push_back(p13); seq.push_back(p14);
237         }
238
239         container(const ex & p1, const ex & p2, const ex & p3,
240                   const ex & p4, const ex & p5, const ex & p6,
241                   const ex & p7, const ex & p8, const ex & p9,
242                   const ex & p10, const ex & p11, const ex & p12,
243                   const ex & p13, const ex & p14, const ex & p15) : inherited(get_tinfo())
244         {
245                 reserve(seq, 15);
246                 seq.push_back(p1); seq.push_back(p2); seq.push_back(p3);
247                 seq.push_back(p4); seq.push_back(p5); seq.push_back(p6);
248                 seq.push_back(p7); seq.push_back(p8); seq.push_back(p9);
249                 seq.push_back(p10); seq.push_back(p11); seq.push_back(p12);
250                 seq.push_back(p13); seq.push_back(p14); seq.push_back(p15);
251         }
252
253         container(const ex & p1, const ex & p2, const ex & p3,
254                   const ex & p4, const ex & p5, const ex & p6,
255                   const ex & p7, const ex & p8, const ex & p9,
256                   const ex & p10, const ex & p11, const ex & p12,
257                   const ex & p13, const ex & p14, const ex & p15,
258                   const ex & p16) : inherited(get_tinfo())
259         {
260                 reserve(seq, 16);
261                 seq.push_back(p1); seq.push_back(p2); seq.push_back(p3);
262                 seq.push_back(p4); seq.push_back(p5); seq.push_back(p6);
263                 seq.push_back(p7); seq.push_back(p8); seq.push_back(p9);
264                 seq.push_back(p10); seq.push_back(p11); seq.push_back(p12);
265                 seq.push_back(p13); seq.push_back(p14); seq.push_back(p15);
266                 seq.push_back(p16);
267         }
268
269         // functions overriding virtual functions from base classes
270 public:
271         void print(const print_context & c, unsigned level = 0) const;
272         bool info(unsigned inf) const { return inherited::info(inf); }
273         unsigned precedence() const { return 10; }
274         size_t nops() const { return seq.size(); }
275         ex op(size_t i) const;
276         ex & let_op(size_t i);
277         ex eval(int level = 0) const;
278         ex subs(const lst & ls, const lst & lr, unsigned options = 0) const;
279
280 protected:
281         bool is_equal_same_type(const basic & other) const;
282
283         // new virtual functions which can be overridden by derived classes
284 protected:
285         /** Similar to duplicate(), but with a preset sequence. Must be
286          *  overridden by derived classes. */
287         virtual ex thiscontainer(const STLT & v) const { return container(v); }
288
289         /** Similar to duplicate(), but with a preset sequence (which gets
290          *  deleted). Must be overridden by derived classes. */
291         virtual ex thiscontainer(STLT * vp) const { return container(vp); }
292
293         virtual void printseq(const print_context & c, char openbracket, char delim,
294                               char closebracket, unsigned this_precedence,
295                               unsigned upper_precedence = 0) const;
296
297         // non-virtual functions in this class
298 private:
299         void sort_(std::random_access_iterator_tag)
300         {
301                 std::sort(seq.begin(), seq.end(), ex_is_less());
302         }
303
304         void sort_(std::input_iterator_tag)
305         {
306                 seq.sort(ex_is_less());
307         }
308
309         void unique_()
310         {
311                 typename STLT::iterator p = std::unique(seq.begin(), seq.end(), ex_is_equal());
312                 seq.erase(p, seq.end());
313         }
314
315 public:
316         container & prepend(const ex & b);
317         container & append(const ex & b);
318         container & remove_first();
319         container & remove_last();
320         container & remove_all();
321         container & sort();
322         container & unique();
323
324         const_iterator begin() const {return seq.begin();}
325         const_iterator end() const {return seq.end();}
326         const_reverse_iterator rbegin() const {return seq.rbegin();}
327         const_reverse_iterator rend() const {return seq.rend();}
328
329 protected:
330         STLT evalchildren(int level) const;
331         STLT *subschildren(const lst & ls, const lst & lr, unsigned options = 0) const;
332 };
333
334 /** Default constructor */
335 template <template <class> class C>
336 container<C>::container() : inherited(get_tinfo()) {}
337
338 /** Construct object from archive_node. */
339 template <template <class> class C>
340 container<C>::container(const archive_node &n, lst &sym_lst) : inherited(n, sym_lst)
341 {
342         for (unsigned int i=0; true; i++) {
343                 ex e;
344                 if (n.find_ex("seq", e, sym_lst, i))
345                         seq.push_back(e);
346                 else
347                         break;
348         }
349 }
350
351 /** Unarchive the object. */
352 template <template <class> class C>
353 ex container<C>::unarchive(const archive_node &n, lst &sym_lst)
354 {
355         return (new container(n, sym_lst))->setflag(status_flags::dynallocated);
356 }
357
358 /** Archive the object. */
359 template <template <class> class C>
360 void container<C>::archive(archive_node &n) const
361 {
362         inherited::archive(n);
363         const_iterator i = seq.begin(), end = seq.end();
364         while (i != end) {
365                 n.add_ex("seq", *i);
366                 ++i;
367         }
368 }
369
370 template <template <class> class C>
371 void container<C>::print(const print_context & c, unsigned level) const
372 {
373         if (is_a<print_tree>(c)) {
374                 c.s << std::string(level, ' ') << class_name()
375                     << std::hex << ", hash=0x" << hashvalue << ", flags=0x" << flags << std::dec
376                     << ", nops=" << nops()
377                     << std::endl;
378                 unsigned delta_indent = static_cast<const print_tree &>(c).delta_indent;
379                 const_iterator i = seq.begin(), end = seq.end();
380                 while (i != end) {
381                         i->print(c, level + delta_indent);
382                         ++i;
383                 }
384                 c.s << std::string(level + delta_indent,' ') << "=====" << std::endl;
385         } else if (is_a<print_python>(c)) {
386                 printseq(c, '[', ',', ']', precedence(), precedence()+1);
387         } else if (is_a<print_python_repr>(c)) {
388                 c.s << class_name ();
389                 printseq(c, '(', ',', ')', precedence(), precedence()+1);
390         } else {
391                 // always print brackets around seq, ignore upper_precedence
392                 printseq(c, get_open_delim(), ',', get_close_delim(), precedence(), precedence()+1);
393         }
394 }
395
396 template <template <class> class C>
397 ex container<C>::op(size_t i) const
398 {
399         GINAC_ASSERT(i < nops());
400
401         const_iterator it = seq.begin();
402         advance(it, i);
403         return *it;
404 }
405
406 template <template <class> class C>
407 ex & container<C>::let_op(size_t i)
408 {
409         GINAC_ASSERT(i < nops());
410
411         ensure_if_modifiable();
412         typename STLT::iterator it = seq.begin();
413         advance(it, i);
414         return *it;
415 }
416
417 template <template <class> class C>
418 ex container<C>::eval(int level) const
419 {
420         if (level == 1)
421                 return hold();
422         else
423                 return thiscontainer(evalchildren(level));
424 }
425
426 template <template <class> class C>
427 ex container<C>::subs(const lst & ls, const lst & lr, unsigned options) const
428 {
429         STLT *vp = subschildren(ls, lr, options);
430         if (vp)
431                 return ex_to<basic>(thiscontainer(vp)).subs_one_level(ls, lr, options);
432         else
433                 return subs_one_level(ls, lr, options);
434 }
435
436 /** Compare two containers of the same type. */
437 template <template <class> class C>
438 int container<C>::compare_same_type(const basic & other) const
439 {
440         GINAC_ASSERT(is_a<container>(other));
441         const container & o = static_cast<const container &>(other);
442
443         const_iterator it1 = seq.begin(), it1end = seq.end(),
444                        it2 = o.seq.begin(), it2end = o.seq.end();
445
446         while (it1 != it1end && it2 != it2end) {
447                 int cmpval = it1->compare(*it2);
448                 if (cmpval)
449                         return cmpval;
450                 ++it1; ++it2;
451         }
452
453         return (it1 == it1end) ? (it2 == it2end ? 0 : -1) : 1;
454 }
455
456 template <template <class> class C>
457 bool container<C>::is_equal_same_type(const basic & other) const
458 {
459         GINAC_ASSERT(is_a<container>(other));
460         const container & o = static_cast<const container &>(other);
461
462         if (seq.size() != o.seq.size())
463                 return false;
464
465         const_iterator it1 = seq.begin(), it1end = seq.end(), it2 = o.seq.begin();
466         while (it1 != it1end) {
467                 if (!it1->is_equal(*it2))
468                         return false;
469                 ++it1; ++it2;
470         }
471
472         return true;
473 }
474
475 /** Add element at front. */
476 template <template <class> class C>
477 container<C> & container<C>::prepend(const ex & b)
478 {
479         ensure_if_modifiable();
480         seq.push_front(b);
481         return *this;
482 }
483
484 /** Add element at back. */
485 template <template <class> class C>
486 container<C> & container<C>::append(const ex & b)
487 {
488         ensure_if_modifiable();
489         seq.push_back(b);
490         return *this;
491 }
492
493 /** Remove first element. */
494 template <template <class> class C>
495 container<C> & container<C>::remove_first()
496 {
497         ensure_if_modifiable();
498         seq.pop_front();
499         return *this;
500 }
501
502 /** Remove last element. */
503 template <template <class> class C>
504 container<C> & container<C>::remove_last()
505 {
506         ensure_if_modifiable();
507         seq.pop_back();
508         return *this;
509 }
510
511 /** Remove all elements. */
512 template <template <class> class C>
513 container<C> & container<C>::remove_all()
514 {
515         ensure_if_modifiable();
516         seq.clear();
517         return *this;
518 }
519
520 /** Sort elements. */
521 template <template <class> class C>
522 container<C> & container<C>::sort()
523 {
524         ensure_if_modifiable();
525         sort_(std::iterator_traits<typename STLT::iterator>::iterator_category());
526         return *this;
527 }
528
529 /** Specialization of container::unique_() for std::list. */
530 inline void container<std::list>::unique_()
531 {
532         seq.unique(ex_is_equal());
533 }
534
535 /** Remove adjacent duplicate elements. */
536 template <template <class> class C>
537 container<C> & container<C>::unique()
538 {
539         ensure_if_modifiable();
540         unique_();
541         return *this;
542 }
543
544 /** Print sequence of contained elements. */
545 template <template <class> class C>
546 void container<C>::printseq(const print_context & c, char openbracket, char delim,
547                             char closebracket, unsigned this_precedence,
548                             unsigned upper_precedence) const
549 {
550         if (this_precedence <= upper_precedence)
551                 c.s << openbracket;
552
553         if (!seq.empty()) {
554                 const_iterator it = seq.begin(), itend = seq.end();
555                 --itend;
556                 while (it != itend) {
557                         it->print(c, this_precedence);
558                         c.s << delim;
559                         ++it;
560                 }
561                 it->print(c, this_precedence);
562         }
563
564         if (this_precedence <= upper_precedence)
565                 c.s << closebracket;
566 }
567
568 template <template <class> class C>
569 typename container<C>::STLT container<C>::evalchildren(int level) const
570 {
571         if (level == 1)
572                 return seq;
573         else if (level == -max_recursion_level)
574                 throw std::runtime_error("max recursion level reached");
575
576         STLT s;
577         reserve(s, seq.size());
578
579         --level;
580         const_iterator it = seq.begin(), itend = seq.end();
581         while (it != itend) {
582                 s.push_back(it->eval(level));
583                 ++it;
584         }
585
586         return s;
587 }
588
589 template <template <class> class C>
590 typename container<C>::STLT *container<C>::subschildren(const lst & ls, const lst & lr, unsigned options) const
591 {
592         // returns a NULL pointer if nothing had to be substituted
593         // returns a pointer to a newly created epvector otherwise
594         // (which has to be deleted somewhere else)
595
596         const_iterator cit = seq.begin(), end = seq.end();
597         while (cit != end) {
598                 const ex & subsed_ex = cit->subs(ls, lr, options);
599                 if (!are_ex_trivially_equal(*cit, subsed_ex)) {
600
601                         // copy first part of seq which hasn't changed
602                         STLT *s = new STLT(seq.begin(), cit);
603                         reserve(*s, seq.size());
604
605                         // insert changed element
606                         s->push_back(subsed_ex);
607                         ++cit;
608
609                         // copy rest
610                         while (cit != end) {
611                                 s->push_back(cit->subs(ls, lr, options));
612                                 ++cit;
613                         }
614
615                         return s;
616                 }
617
618                 ++cit;
619         }
620         
621         return 0; // nothing has changed
622 }
623
624 } // namespace GiNaC
625
626 #endif // ndef __GINAC_CONTAINER_H__