- removed debugging code in match()
[ginac.git] / ginac / basic.cpp
1 /** @file basic.cpp
2  *
3  *  Implementation of GiNaC's ABC. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2001 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
21  */
22
23 #include <iostream>
24 #include <stdexcept>
25 #ifdef DO_GINAC_ASSERT
26 #  include <typeinfo>
27 #endif
28
29 #include "basic.h"
30 #include "ex.h"
31 #include "numeric.h"
32 #include "power.h"
33 #include "symbol.h"
34 #include "lst.h"
35 #include "ncmul.h"
36 #include "relational.h"
37 #include "print.h"
38 #include "archive.h"
39 #include "utils.h"
40 #include "debugmsg.h"
41
42 namespace GiNaC {
43
44 GINAC_IMPLEMENT_REGISTERED_CLASS_NO_CTORS(basic, void)
45
46 //////////
47 // default ctor, dtor, copy ctor assignment operator and helpers
48 //////////
49
50 // public
51
52 basic::basic(const basic & other) : tinfo_key(TINFO_basic), flags(0), refcount(0)
53 {
54         debugmsg("basic copy ctor", LOGLEVEL_CONSTRUCT);
55         copy(other);
56 }
57
58 const basic & basic::operator=(const basic & other)
59 {
60         debugmsg("basic operator=", LOGLEVEL_ASSIGNMENT);
61         if (this != &other) {
62                 destroy(true);
63                 copy(other);
64         }
65         return *this;
66 }
67
68 // protected
69
70 // none (all conditionally inlined)
71
72 //////////
73 // other ctors
74 //////////
75
76 // none (all conditionally inlined)
77
78 //////////
79 // archiving
80 //////////
81
82 /** Construct object from archive_node. */
83 basic::basic(const archive_node &n, const lst &sym_lst) : flags(0), refcount(0)
84 {
85         debugmsg("basic ctor from archive_node", LOGLEVEL_CONSTRUCT);
86
87         // Reconstruct tinfo_key from class name
88         std::string class_name;
89         if (n.find_string("class", class_name))
90                 tinfo_key = find_tinfo_key(class_name);
91         else
92                 throw (std::runtime_error("archive node contains no class name"));
93 }
94
95 /** Unarchive the object. */
96 DEFAULT_UNARCHIVE(basic)
97
98 /** Archive the object. */
99 void basic::archive(archive_node &n) const
100 {
101         n.add_string("class", class_name());
102 }
103
104 //////////
105 // functions overriding virtual functions from bases classes
106 //////////
107
108 // none
109
110 //////////
111 // new virtual functions which can be overridden by derived classes
112 //////////
113
114 // public
115
116 /** Output to stream.
117  *  @param c print context object that describes the output formatting
118  *  @param level value that is used to identify the precedence or indentation
119  *               level for placing parentheses and formatting */
120 void basic::print(const print_context & c, unsigned level) const
121 {
122         debugmsg("basic print", LOGLEVEL_PRINT);
123
124         if (is_of_type(c, print_tree)) {
125
126                 c.s << std::string(level, ' ') << class_name()
127                     << std::hex << ", hash=0x" << hashvalue << ", flags=0x" << flags << std::dec
128                     << ", nops=" << nops()
129                     << std::endl;
130                 for (unsigned i=0; i<nops(); ++i)
131                         op(i).print(c, level + static_cast<const print_tree &>(c).delta_indent);
132
133         } else
134                 c.s << "[" << class_name() << " object]";
135 }
136
137 /** Little wrapper arount print to be called within a debugger.
138  *  This is needed because you cannot call foo.print(cout) from within the
139  *  debugger because it might not know what cout is.  This method can be
140  *  invoked with no argument and it will simply print to stdout.
141  *
142  *  @see basic::print */
143 void basic::dbgprint(void) const
144 {
145         this->print(std::cerr);
146         std::cerr << std::endl;
147 }
148
149 /** Little wrapper arount printtree to be called within a debugger.
150  *
151  *  @see basic::dbgprint
152  *  @see basic::printtree */
153 void basic::dbgprinttree(void) const
154 {
155         this->print(print_tree(std::cerr));
156 }
157
158 /** Return relative operator precedence (for parenthizing output). */
159 unsigned basic::precedence(void) const
160 {
161         return 70;
162 }
163
164 /** Create a new copy of this on the heap.  One can think of this as simulating
165  *  a virtual copy constructor which is needed for instance by the refcounted
166  *  construction of an ex from a basic. */
167 basic * basic::duplicate() const
168 {
169         debugmsg("basic duplicate",LOGLEVEL_DUPLICATE);
170         return new basic(*this);
171 }
172
173 /** Information about the object.
174  *
175  *  @see class info_flags */
176 bool basic::info(unsigned inf) const
177 {
178         // all possible properties are false for basic objects
179         return false;
180 }
181
182 /** Number of operands/members. */
183 unsigned basic::nops() const
184 {
185         // iterating from 0 to nops() on atomic objects should be an empty loop,
186         // and accessing their elements is a range error.  Container objects should
187         // override this.
188         return 0;
189 }
190
191 /** Return operand/member at position i. */
192 ex basic::op(int i) const
193 {
194         return (const_cast<basic *>(this))->let_op(i);
195 }
196
197 /** Return modifyable operand/member at position i. */
198 ex & basic::let_op(int i)
199 {
200         throw(std::out_of_range("op() out of range"));
201 }
202
203 ex basic::operator[](const ex & index) const
204 {
205         if (is_exactly_of_type(*index.bp,numeric))
206                 return op(static_cast<const numeric &>(*index.bp).to_int());
207         
208         throw(std::invalid_argument("non-numeric indices not supported by this type"));
209 }
210
211 ex basic::operator[](int i) const
212 {
213         return op(i);
214 }
215
216 /** Search ocurrences.  An object 'has' an expression if it is the expression
217  *  itself or one of the children 'has' it.  As a consequence (according to
218  *  the definition of children) given e=x+y+z, e.has(x) is true but e.has(x+y)
219  *  is false.  The expression can also contain wildcards. */
220 bool basic::has(const ex & other) const
221 {
222         GINAC_ASSERT(other.bp!=0);
223         lst repl_lst;
224         if (match(*other.bp, repl_lst)) return true;
225         if (nops()>0) {
226                 for (unsigned i=0; i<nops(); i++)
227                         if (op(i).has(other))
228                                 return true;
229         }
230         
231         return false;
232 }
233
234 /** Return degree of highest power in object s. */
235 int basic::degree(const ex & s) const
236 {
237         return 0;
238 }
239
240 /** Return degree of lowest power in object s. */
241 int basic::ldegree(const ex & s) const
242 {
243         return 0;
244 }
245
246 /** Return coefficient of degree n in object s. */
247 ex basic::coeff(const ex & s, int n) const
248 {
249         return n==0 ? *this : _ex0();
250 }
251
252 /** Sort expression in terms of powers of some object(s).
253  *  @param s object(s) to sort in
254  *  @param distributed recursive or distributed form (only used when s is a list) */
255 ex basic::collect(const ex & s, bool distributed) const
256 {
257         ex x;
258         if (is_ex_of_type(s, lst)) {
259
260                 // List of objects specified
261                 if (s.nops() == 1)
262                         return collect(s.op(0));
263
264                 else if (distributed) {
265
266                         // Get lower/upper degree of all symbols in list
267                         int num = s.nops();
268                         struct sym_info {
269                                 ex sym;
270                                 int ldeg, deg;
271                                 int cnt;  // current degree, 'counter'
272                                 ex coeff; // coefficient for degree 'cnt'
273                         };
274                         sym_info *si = new sym_info[num];
275                         ex c = *this;
276                         for (int i=0; i<num; i++) {
277                                 si[i].sym = s.op(i);
278                                 si[i].ldeg = si[i].cnt = this->ldegree(si[i].sym);
279                                 si[i].deg = this->degree(si[i].sym);
280                                 c = si[i].coeff = c.coeff(si[i].sym, si[i].cnt);
281                         }
282
283                         while (true) {
284
285                                 // Calculate coeff*x1^c1*...*xn^cn
286                                 ex y = _ex1();
287                                 for (int i=0; i<num; i++) {
288                                         int cnt = si[i].cnt;
289                                         y *= power(si[i].sym, cnt);
290                                 }
291                                 x += y * si[num - 1].coeff;
292
293                                 // Increment counters
294                                 int n = num - 1;
295                                 while (true) {
296                                         si[n].cnt++;
297                                         if (si[n].cnt <= si[n].deg) {
298                                                 // Update coefficients
299                                                 ex c;
300                                                 if (n == 0)
301                                                         c = *this;
302                                                 else
303                                                         c = si[n - 1].coeff;
304                                                 for (int i=n; i<num; i++)
305                                                         c = si[i].coeff = c.coeff(si[i].sym, si[i].cnt);
306                                                 break;
307                                         }
308                                         if (n == 0)
309                                                 goto done;
310                                         si[n].cnt = si[n].ldeg;
311                                         n--;
312                                 }
313                         }
314
315 done:           delete[] si;
316
317                 } else {
318
319                         // Recursive form
320                         x = *this;
321                         for (int n=s.nops()-1; n>=0; n--)
322                                 x = x.collect(s[n]);
323                 }
324
325         } else {
326
327                 // Only one object specified
328                 for (int n=this->ldegree(s); n<=this->degree(s); ++n)
329                         x += this->coeff(s,n)*power(s,n);
330         }
331         
332         // correct for lost fractional arguments and return
333         return x + (*this - x).expand();
334 }
335
336 /** Perform automatic non-interruptive symbolic evaluation on expression. */
337 ex basic::eval(int level) const
338 {
339         // There is nothing to do for basic objects:
340         return this->hold();
341 }
342
343 /** Evaluate object numerically. */
344 ex basic::evalf(int level) const
345 {
346         // There is nothing to do for basic objects:
347         return *this;
348 }
349
350 /** Perform automatic symbolic evaluations on indexed expression that
351  *  contains this object as the base expression. */
352 ex basic::eval_indexed(const basic & i) const
353  // this function can't take a "const ex & i" because that would result
354  // in an infinite eval() loop
355 {
356         // There is nothing to do for basic objects
357         return i.hold();
358 }
359
360 /** Add two indexed expressions. They are guaranteed to be of class indexed
361  *  (or a subclass) and their indices are compatible. This function is used
362  *  internally by simplify_indexed().
363  *
364  *  @param self First indexed expression; it's base object is *this
365  *  @param other Second indexed expression
366  *  @return sum of self and other 
367  *  @see ex::simplify_indexed() */
368 ex basic::add_indexed(const ex & self, const ex & other) const
369 {
370         return self + other;
371 }
372
373 /** Multiply an indexed expression with a scalar. This function is used
374  *  internally by simplify_indexed().
375  *
376  *  @param self Indexed expression; it's base object is *this
377  *  @param other Numeric value
378  *  @return product of self and other
379  *  @see ex::simplify_indexed() */
380 ex basic::scalar_mul_indexed(const ex & self, const numeric & other) const
381 {
382         return self * other;
383 }
384
385 /** Try to contract two indexed expressions that appear in the same product. 
386  *  If a contraction exists, the function overwrites one or both of the
387  *  expressions and returns true. Otherwise it returns false. It is
388  *  guaranteed that both expressions are of class indexed (or a subclass)
389  *  and that at least one dummy index has been found. This functions is
390  *  used internally by simplify_indexed().
391  *
392  *  @param self Pointer to first indexed expression; it's base object is *this
393  *  @param other Pointer to second indexed expression
394  *  @param v The complete vector of factors
395  *  @return true if the contraction was successful, false otherwise
396  *  @see ex::simplify_indexed() */
397 bool basic::contract_with(exvector::iterator self, exvector::iterator other, exvector & v) const
398 {
399         // Do nothing
400         return false;
401 }
402
403 /** Check whether the expression matches a given pattern. For every wildcard
404  *  object in the pattern, an expression of the form "wildcard == matching_expression"
405  *  is added to repl_lst. */
406 bool basic::match(const ex & pattern, lst & repl_lst) const
407 {
408 /*
409         Sweet sweet shapes, sweet sweet shapes,
410         Thats the key thing, right right.
411         Feed feed face, feed feed shapes,
412         But who is the king tonight?
413         Who is the king tonight?
414         Pattern is the thing, the key thing-a-ling,
415         But who is the king of pattern?
416         But who is the king, the king thing-a-ling,
417         Who is the king of Pattern?
418         Bog is the king, the king thing-a-ling,
419         Bog is the king of Pattern.
420         Ba bu-bu-bu-bu bu-bu-bu-bu-bu-bu bu-bu
421         Bog is the king of Pattern.
422 */
423
424         if (is_ex_exactly_of_type(pattern, wildcard)) {
425
426                 // Wildcard matches anything, but check whether we already have found
427                 // a match for that wildcard first (if so, it the earlier match must
428                 // be the same expression)
429                 for (unsigned i=0; i<repl_lst.nops(); i++) {
430                         if (repl_lst.op(i).op(0).is_equal(pattern))
431                                 return is_equal(*repl_lst.op(i).op(1).bp);
432                 }
433                 repl_lst.append(pattern == *this);
434                 return true;
435
436         } else {
437
438                 // Expression must be of the same type as the pattern
439                 if (tinfo() != pattern.bp->tinfo())
440                         return false;
441
442                 // Number of subexpressions must match
443                 if (nops() != pattern.nops())
444                         return false;
445
446                 // No subexpressions? Then just compare the objects (there can't be
447                 // wildcards in the pattern)
448                 if (nops() == 0)
449                         return is_equal(*pattern.bp);
450
451                 // Otherwise the subexpressions must match one-to-one
452                 for (unsigned i=0; i<nops(); i++)
453                         if (!op(i).match(pattern.op(i), repl_lst))
454                                 return false;
455
456                 // Looks similar enough, match found
457                 return true;
458         }
459 }
460
461 /** Substitute a set of objects by arbitrary expressions. The ex returned
462  *  will already be evaluated. */
463 ex basic::subs(const lst & ls, const lst & lr, bool no_pattern) const
464 {
465         GINAC_ASSERT(ls.nops() == lr.nops());
466
467         if (no_pattern) {
468                 for (unsigned i=0; i<ls.nops(); i++) {
469                         if (is_equal(*ls.op(i).bp))
470                                 return lr.op(i);
471                 }
472         } else {
473                 for (unsigned i=0; i<ls.nops(); i++) {
474                         lst repl_lst;
475                         if (match(*ls.op(i).bp, repl_lst))
476                                 return lr.op(i).bp->subs(repl_lst, true); // avoid infinite recursion when re-substituting the wildcards
477                 }
478         }
479
480         return *this;
481 }
482
483 /** Default interface of nth derivative ex::diff(s, n).  It should be called
484  *  instead of ::derivative(s) for first derivatives and for nth derivatives it
485  *  just recurses down.
486  *
487  *  @param s symbol to differentiate in
488  *  @param nth order of differentiation
489  *  @see ex::diff */
490 ex basic::diff(const symbol & s, unsigned nth) const
491 {
492         // trivial: zeroth derivative
493         if (nth==0)
494                 return ex(*this);
495         
496         // evaluate unevaluated *this before differentiating
497         if (!(flags & status_flags::evaluated))
498                 return ex(*this).diff(s, nth);
499         
500         ex ndiff = this->derivative(s);
501         while (!ndiff.is_zero() &&    // stop differentiating zeros
502                nth>1) {
503                 ndiff = ndiff.diff(s);
504                 --nth;
505         }
506         return ndiff;
507 }
508
509 /** Return a vector containing the free indices of an expression. */
510 exvector basic::get_free_indices(void) const
511 {
512         return exvector(); // return an empty exvector
513 }
514
515 ex basic::simplify_ncmul(const exvector & v) const
516 {
517         return simplified_ncmul(v);
518 }
519
520 // protected
521
522 /** Default implementation of ex::diff(). It simply throws an error message.
523  *
524  *  @exception logic_error (differentiation not supported by this type)
525  *  @see ex::diff */
526 ex basic::derivative(const symbol & s) const
527 {
528         throw(std::logic_error("differentiation not supported by this type"));
529 }
530
531 /** Returns order relation between two objects of same type.  This needs to be
532  *  implemented by each class. It may never return anything else than 0,
533  *  signalling equality, or +1 and -1 signalling inequality and determining
534  *  the canonical ordering.  (Perl hackers will wonder why C++ doesn't feature
535  *  the spaceship operator <=> for denoting just this.) */
536 int basic::compare_same_type(const basic & other) const
537 {
538         return compare_pointers(this, &other);
539 }
540
541 /** Returns true if two objects of same type are equal.  Normally needs
542  *  not be reimplemented as long as it wasn't overwritten by some parent
543  *  class, since it just calls compare_same_type().  The reason why this
544  *  function exists is that sometimes it is easier to determine equality
545  *  than an order relation and then it can be overridden. */
546 bool basic::is_equal_same_type(const basic & other) const
547 {
548         return this->compare_same_type(other)==0;
549 }
550
551 unsigned basic::return_type(void) const
552 {
553         return return_types::commutative;
554 }
555
556 unsigned basic::return_type_tinfo(void) const
557 {
558         return tinfo();
559 }
560
561 /** Compute the hash value of an object and if it makes sense to store it in
562  *  the objects status_flags, do so.  The method inherited from class basic
563  *  computes a hash value based on the type and hash values of possible
564  *  members.  For this reason it is well suited for container classes but
565  *  atomic classes should override this implementation because otherwise they
566  *  would all end up with the same hashvalue. */
567 unsigned basic::calchash(void) const
568 {
569         unsigned v = golden_ratio_hash(tinfo());
570         for (unsigned i=0; i<nops(); i++) {
571                 v = rotate_left_31(v);
572                 v ^= (const_cast<basic *>(this))->op(i).gethash();
573         }
574         
575         // mask out numeric hashes:
576         v &= 0x7FFFFFFFU;
577         
578         // store calculated hash value only if object is already evaluated
579         if (flags & status_flags::evaluated) {
580                 setflag(status_flags::hash_calculated);
581                 hashvalue = v;
582         }
583
584         return v;
585 }
586
587 /** Expand expression, i.e. multiply it out and return the result as a new
588  *  expression. */
589 ex basic::expand(unsigned options) const
590 {
591         return this->setflag(status_flags::expanded);
592 }
593
594
595 //////////
596 // non-virtual functions in this class
597 //////////
598
599 // public
600
601 /** Substitute objects in an expression (syntactic substitution) and return
602  *  the result as a new expression.  There are two valid types of
603  *  replacement arguments: 1) a relational like object==ex and 2) a list of
604  *  relationals lst(object1==ex1,object2==ex2,...), which is converted to
605  *  subs(lst(object1,object2,...),lst(ex1,ex2,...)). */
606 ex basic::subs(const ex & e, bool no_pattern) const
607 {
608         if (e.info(info_flags::relation_equal)) {
609                 return subs(lst(e), no_pattern);
610         }
611         if (!e.info(info_flags::list)) {
612                 throw(std::invalid_argument("basic::subs(ex): argument must be a list"));
613         }
614         lst ls;
615         lst lr;
616         for (unsigned i=0; i<e.nops(); i++) {
617                 ex r = e.op(i);
618                 if (!r.info(info_flags::relation_equal)) {
619                         throw(std::invalid_argument("basic::subs(ex): argument must be a list of equations"));
620                 }
621                 ls.append(r.op(0));
622                 lr.append(r.op(1));
623         }
624         return subs(ls, lr, no_pattern);
625 }
626
627 /** Compare objects to establish canonical ordering.
628  *  All compare functions return: -1 for *this less than other, 0 equal,
629  *  1 greater. */
630 int basic::compare(const basic & other) const
631 {
632         unsigned hash_this = gethash();
633         unsigned hash_other = other.gethash();
634         
635         if (hash_this<hash_other) return -1;
636         if (hash_this>hash_other) return 1;
637         
638         unsigned typeid_this = tinfo();
639         unsigned typeid_other = other.tinfo();
640         
641         if (typeid_this<typeid_other) {
642 //              std::cout << "hash collision, different types: " 
643 //                        << *this << " and " << other << std::endl;
644 //              this->print(print_tree(std::cout));
645 //              std::cout << " and ";
646 //              other.print(print_tree(std::cout));
647 //              std::cout << std::endl;
648                 return -1;
649         }
650         if (typeid_this>typeid_other) {
651 //              std::cout << "hash collision, different types: " 
652 //                        << *this << " and " << other << std::endl;
653 //              this->print(print_tree(std::cout));
654 //              std::cout << " and ";
655 //              other.print(print_tree(std::cout));
656 //              std::cout << std::endl;
657                 return 1;
658         }
659         
660         GINAC_ASSERT(typeid(*this)==typeid(other));
661         
662 //      int cmpval = compare_same_type(other);
663 //      if ((cmpval!=0) && (hash_this<0x80000000U)) {
664 //              std::cout << "hash collision, same type: " 
665 //                        << *this << " and " << other << std::endl;
666 //              this->print(print_tree(std::cout));
667 //              std::cout << " and ";
668 //              other.print(print_tree(std::cout));
669 //              std::cout << std::endl;
670 //      }
671 //      return cmpval;
672         
673         return compare_same_type(other);
674 }
675
676 /** Test for equality.
677  *  This is only a quick test, meaning objects should be in the same domain.
678  *  You might have to .expand(), .normal() objects first, depending on the
679  *  domain of your computation, to get a more reliable answer.
680  *
681  *  @see is_equal_same_type */
682 bool basic::is_equal(const basic & other) const
683 {
684         if (this->gethash()!=other.gethash())
685                 return false;
686         if (this->tinfo()!=other.tinfo())
687                 return false;
688         
689         GINAC_ASSERT(typeid(*this)==typeid(other));
690         
691         return this->is_equal_same_type(other);
692 }
693
694 // protected
695
696 /** Stop further evaluation.
697  *
698  *  @see basic::eval */
699 const basic & basic::hold(void) const
700 {
701         return this->setflag(status_flags::evaluated);
702 }
703
704 /** Ensure the object may be modified without hurting others, throws if this
705  *  is not the case. */
706 void basic::ensure_if_modifiable(void) const
707 {
708         if (this->refcount>1)
709                 throw(std::runtime_error("cannot modify multiply referenced object"));
710 }
711
712 //////////
713 // global variables
714 //////////
715
716 int max_recursion_level = 1024;
717
718 } // namespace GiNaC