]> www.ginac.de Git - ginac.git/blob - ginac/basic.cpp
Added .is_polynomial() method.
[ginac.git] / ginac / basic.cpp
1 /** @file basic.cpp
2  *
3  *  Implementation of GiNaC's ABC. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2006 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
21  */
22
23 #include <iostream>
24 #include <stdexcept>
25 #ifdef DO_GINAC_ASSERT
26 #  include <typeinfo>
27 #endif
28
29 #include "basic.h"
30 #include "ex.h"
31 #include "numeric.h"
32 #include "power.h"
33 #include "symbol.h"
34 #include "lst.h"
35 #include "ncmul.h"
36 #include "relational.h"
37 #include "operators.h"
38 #include "wildcard.h"
39 #include "archive.h"
40 #include "utils.h"
41
42 namespace GiNaC {
43
44 GINAC_IMPLEMENT_REGISTERED_CLASS_OPT(basic, void,
45   print_func<print_context>(&basic::do_print).
46   print_func<print_tree>(&basic::do_print_tree).
47   print_func<print_python_repr>(&basic::do_print_python_repr))
48
49 //////////
50 // default constructor, destructor, copy constructor and assignment operator
51 //////////
52
53 // public
54
55 /** basic copy constructor: implicitly assumes that the other class is of
56  *  the exact same type (as it's used by duplicate()), so it can copy the
57  *  tinfo_key and the hash value. */
58 basic::basic(const basic & other) : tinfo_key(other.tinfo_key), flags(other.flags & ~status_flags::dynallocated), hashvalue(other.hashvalue)
59 {
60 }
61
62 /** basic assignment operator: the other object might be of a derived class. */
63 const basic & basic::operator=(const basic & other)
64 {
65         unsigned fl = other.flags & ~status_flags::dynallocated;
66         if (tinfo_key != other.tinfo_key) {
67                 // The other object is of a derived class, so clear the flags as they
68                 // might no longer apply (especially hash_calculated). Oh, and don't
69                 // copy the tinfo_key: it is already set correctly for this object.
70                 fl &= ~(status_flags::evaluated | status_flags::expanded | status_flags::hash_calculated);
71         } else {
72                 // The objects are of the exact same class, so copy the hash value.
73                 hashvalue = other.hashvalue;
74         }
75         flags = fl;
76         set_refcount(0);
77         return *this;
78 }
79
80 // protected
81
82 // none (all inlined)
83
84 //////////
85 // other constructors
86 //////////
87
88 // none (all inlined)
89
90 //////////
91 // archiving
92 //////////
93
94 /** Construct object from archive_node. */
95 basic::basic(const archive_node &n, lst &sym_lst) : flags(0)
96 {
97         // Reconstruct tinfo_key from class name
98         std::string class_name;
99         if (n.find_string("class", class_name))
100                 tinfo_key = find_tinfo_key(class_name);
101         else
102                 throw (std::runtime_error("archive node contains no class name"));
103 }
104
105 /** Unarchive the object. */
106 DEFAULT_UNARCHIVE(basic)
107
108 /** Archive the object. */
109 void basic::archive(archive_node &n) const
110 {
111         n.add_string("class", class_name());
112 }
113
114 //////////
115 // new virtual functions which can be overridden by derived classes
116 //////////
117
118 // public
119
120 /** Output to stream. This performs double dispatch on the dynamic type of
121  *  *this and the dynamic type of the supplied print context.
122  *  @param c print context object that describes the output formatting
123  *  @param level value that is used to identify the precedence or indentation
124  *               level for placing parentheses and formatting */
125 void basic::print(const print_context & c, unsigned level) const
126 {
127         print_dispatch(get_class_info(), c, level);
128 }
129
130 /** Like print(), but dispatch to the specified class. Can be used by
131  *  implementations of print methods to dispatch to the method of the
132  *  superclass.
133  *
134  *  @see basic::print */
135 void basic::print_dispatch(const registered_class_info & ri, const print_context & c, unsigned level) const
136 {
137         // Double dispatch on object type and print_context type
138         const registered_class_info * reg_info = &ri;
139         const print_context_class_info * pc_info = &c.get_class_info();
140
141 next_class:
142         const std::vector<print_functor> & pdt = reg_info->options.get_print_dispatch_table();
143
144 next_context:
145         unsigned id = pc_info->options.get_id();
146         if (id >= pdt.size() || !(pdt[id].is_valid())) {
147
148                 // Method not found, try parent print_context class
149                 const print_context_class_info * parent_pc_info = pc_info->get_parent();
150                 if (parent_pc_info) {
151                         pc_info = parent_pc_info;
152                         goto next_context;
153                 }
154
155                 // Method still not found, try parent class
156                 const registered_class_info * parent_reg_info = reg_info->get_parent();
157                 if (parent_reg_info) {
158                         reg_info = parent_reg_info;
159                         pc_info = &c.get_class_info();
160                         goto next_class;
161                 }
162
163                 // Method still not found. This shouldn't happen because basic (the
164                 // base class of the algebraic hierarchy) registers a method for
165                 // print_context (the base class of the print context hierarchy),
166                 // so if we end up here, there's something wrong with the class
167                 // registry.
168                 throw (std::runtime_error(std::string("basic::print(): method for ") + class_name() + "/" + c.class_name() + " not found"));
169
170         } else {
171
172                 // Call method
173                 pdt[id](*this, c, level);
174         }
175 }
176
177 /** Default output to stream. */
178 void basic::do_print(const print_context & c, unsigned level) const
179 {
180         c.s << "[" << class_name() << " object]";
181 }
182
183 /** Tree output to stream. */
184 void basic::do_print_tree(const print_tree & c, unsigned level) const
185 {
186         c.s << std::string(level, ' ') << class_name() << " @" << this
187             << std::hex << ", hash=0x" << hashvalue << ", flags=0x" << flags << std::dec;
188         if (nops())
189                 c.s << ", nops=" << nops();
190         c.s << std::endl;
191         for (size_t i=0; i<nops(); ++i)
192                 op(i).print(c, level + c.delta_indent);
193 }
194
195 /** Python parsable output to stream. */
196 void basic::do_print_python_repr(const print_python_repr & c, unsigned level) const
197 {
198         c.s << class_name() << "()";
199 }
200
201 /** Little wrapper around print to be called within a debugger.
202  *  This is needed because you cannot call foo.print(cout) from within the
203  *  debugger because it might not know what cout is.  This method can be
204  *  invoked with no argument and it will simply print to stdout.
205  *
206  *  @see basic::print
207  *  @see basic::dbgprinttree */
208 void basic::dbgprint() const
209 {
210         this->print(print_dflt(std::cerr));
211         std::cerr << std::endl;
212 }
213
214 /** Little wrapper around printtree to be called within a debugger.
215  *
216  *  @see basic::dbgprint */
217 void basic::dbgprinttree() const
218 {
219         this->print(print_tree(std::cerr));
220 }
221
222 /** Return relative operator precedence (for parenthezing output). */
223 unsigned basic::precedence() const
224 {
225         return 70;
226 }
227
228 /** Information about the object.
229  *
230  *  @see class info_flags */
231 bool basic::info(unsigned inf) const
232 {
233         // all possible properties are false for basic objects
234         return false;
235 }
236
237 /** Number of operands/members. */
238 size_t basic::nops() const
239 {
240         // iterating from 0 to nops() on atomic objects should be an empty loop,
241         // and accessing their elements is a range error.  Container objects should
242         // override this.
243         return 0;
244 }
245
246 /** Return operand/member at position i. */
247 ex basic::op(size_t i) const
248 {
249         throw(std::range_error(std::string("basic::op(): ") + class_name() + std::string(" has no operands")));
250 }
251
252 /** Return modifyable operand/member at position i. */
253 ex & basic::let_op(size_t i)
254 {
255         ensure_if_modifiable();
256         throw(std::range_error(std::string("basic::let_op(): ") + class_name() + std::string(" has no operands")));
257 }
258
259 ex basic::operator[](const ex & index) const
260 {
261         if (is_exactly_a<numeric>(index))
262                 return op(static_cast<size_t>(ex_to<numeric>(index).to_int()));
263
264         throw(std::invalid_argument(std::string("non-numeric indices not supported by ") + class_name()));
265 }
266
267 ex basic::operator[](size_t i) const
268 {
269         return op(i);
270 }
271
272 ex & basic::operator[](const ex & index)
273 {
274         if (is_exactly_a<numeric>(index))
275                 return let_op(ex_to<numeric>(index).to_int());
276
277         throw(std::invalid_argument(std::string("non-numeric indices not supported by ") + class_name()));
278 }
279
280 ex & basic::operator[](size_t i)
281 {
282         return let_op(i);
283 }
284
285 /** Test for occurrence of a pattern.  An object 'has' a pattern if it matches
286  *  the pattern itself or one of the children 'has' it.  As a consequence
287  *  (according to the definition of children) given e=x+y+z, e.has(x) is true
288  *  but e.has(x+y) is false. */
289 bool basic::has(const ex & pattern, unsigned options) const
290 {
291         lst repl_lst;
292         if (match(pattern, repl_lst))
293                 return true;
294         for (size_t i=0; i<nops(); i++)
295                 if (op(i).has(pattern, options))
296                         return true;
297         
298         return false;
299 }
300
301 /** Construct new expression by applying the specified function to all
302  *  sub-expressions (one level only, not recursively). */
303 ex basic::map(map_function & f) const
304 {
305         size_t num = nops();
306         if (num == 0)
307                 return *this;
308
309         basic *copy = NULL;
310         for (size_t i=0; i<num; i++) {
311                 const ex & o = op(i);
312                 const ex & n = f(o);
313                 if (!are_ex_trivially_equal(o, n)) {
314                         if (copy == NULL)
315                                 copy = duplicate();
316                         copy->let_op(i) = n;
317                 }
318         }
319
320         if (copy) {
321                 copy->setflag(status_flags::dynallocated);
322                 copy->clearflag(status_flags::hash_calculated | status_flags::expanded);
323                 return *copy;
324         } else
325                 return *this;
326 }
327
328 /** Check whether this is a polynomial in the given variables. */
329 bool basic::is_polynomial(const ex & var) const
330 {
331         return !has(var) || is_equal(ex_to<basic>(var));
332 }
333
334 /** Return degree of highest power in object s. */
335 int basic::degree(const ex & s) const
336 {
337         return is_equal(ex_to<basic>(s)) ? 1 : 0;
338 }
339
340 /** Return degree of lowest power in object s. */
341 int basic::ldegree(const ex & s) const
342 {
343         return is_equal(ex_to<basic>(s)) ? 1 : 0;
344 }
345
346 /** Return coefficient of degree n in object s. */
347 ex basic::coeff(const ex & s, int n) const
348 {
349         if (is_equal(ex_to<basic>(s)))
350                 return n==1 ? _ex1 : _ex0;
351         else
352                 return n==0 ? *this : _ex0;
353 }
354
355 /** Sort expanded expression in terms of powers of some object(s).
356  *  @param s object(s) to sort in
357  *  @param distributed recursive or distributed form (only used when s is a list) */
358 ex basic::collect(const ex & s, bool distributed) const
359 {
360         ex x;
361         if (is_a<lst>(s)) {
362
363                 // List of objects specified
364                 if (s.nops() == 0)
365                         return *this;
366                 if (s.nops() == 1)
367                         return collect(s.op(0));
368
369                 else if (distributed) {
370
371                         // Get lower/upper degree of all symbols in list
372                         size_t num = s.nops();
373                         struct sym_info {
374                                 ex sym;
375                                 int ldeg, deg;
376                                 int cnt;  // current degree, 'counter'
377                                 ex coeff; // coefficient for degree 'cnt'
378                         };
379                         sym_info *si = new sym_info[num];
380                         ex c = *this;
381                         for (size_t i=0; i<num; i++) {
382                                 si[i].sym = s.op(i);
383                                 si[i].ldeg = si[i].cnt = this->ldegree(si[i].sym);
384                                 si[i].deg = this->degree(si[i].sym);
385                                 c = si[i].coeff = c.coeff(si[i].sym, si[i].cnt);
386                         }
387
388                         while (true) {
389
390                                 // Calculate coeff*x1^c1*...*xn^cn
391                                 ex y = _ex1;
392                                 for (size_t i=0; i<num; i++) {
393                                         int cnt = si[i].cnt;
394                                         y *= power(si[i].sym, cnt);
395                                 }
396                                 x += y * si[num - 1].coeff;
397
398                                 // Increment counters
399                                 size_t n = num - 1;
400                                 while (true) {
401                                         ++si[n].cnt;
402                                         if (si[n].cnt <= si[n].deg) {
403                                                 // Update coefficients
404                                                 ex c;
405                                                 if (n == 0)
406                                                         c = *this;
407                                                 else
408                                                         c = si[n - 1].coeff;
409                                                 for (size_t i=n; i<num; i++)
410                                                         c = si[i].coeff = c.coeff(si[i].sym, si[i].cnt);
411                                                 break;
412                                         }
413                                         if (n == 0)
414                                                 goto done;
415                                         si[n].cnt = si[n].ldeg;
416                                         n--;
417                                 }
418                         }
419
420 done:           delete[] si;
421
422                 } else {
423
424                         // Recursive form
425                         x = *this;
426                         size_t n = s.nops() - 1;
427                         while (true) {
428                                 x = x.collect(s[n]);
429                                 if (n == 0)
430                                         break;
431                                 n--;
432                         }
433                 }
434
435         } else {
436
437                 // Only one object specified
438                 for (int n=this->ldegree(s); n<=this->degree(s); ++n)
439                         x += this->coeff(s,n)*power(s,n);
440         }
441         
442         // correct for lost fractional arguments and return
443         return x + (*this - x).expand();
444 }
445
446 /** Perform automatic non-interruptive term rewriting rules. */
447 ex basic::eval(int level) const
448 {
449         // There is nothing to do for basic objects:
450         return hold();
451 }
452
453 /** Function object to be applied by basic::evalf(). */
454 struct evalf_map_function : public map_function {
455         int level;
456         evalf_map_function(int l) : level(l) {}
457         ex operator()(const ex & e) { return evalf(e, level); }
458 };
459
460 /** Evaluate object numerically. */
461 ex basic::evalf(int level) const
462 {
463         if (nops() == 0)
464                 return *this;
465         else {
466                 if (level == 1)
467                         return *this;
468                 else if (level == -max_recursion_level)
469                         throw(std::runtime_error("max recursion level reached"));
470                 else {
471                         evalf_map_function map_evalf(level - 1);
472                         return map(map_evalf);
473                 }
474         }
475 }
476
477 /** Function object to be applied by basic::evalm(). */
478 struct evalm_map_function : public map_function {
479         ex operator()(const ex & e) { return evalm(e); }
480 } map_evalm;
481
482 /** Evaluate sums, products and integer powers of matrices. */
483 ex basic::evalm() const
484 {
485         if (nops() == 0)
486                 return *this;
487         else
488                 return map(map_evalm);
489 }
490
491 /** Function object to be applied by basic::eval_integ(). */
492 struct eval_integ_map_function : public map_function {
493         ex operator()(const ex & e) { return eval_integ(e); }
494 } map_eval_integ;
495
496 /** Evaluate integrals, if result is known. */
497 ex basic::eval_integ() const
498 {
499         if (nops() == 0)
500                 return *this;
501         else
502                 return map(map_eval_integ);
503 }
504
505 /** Perform automatic symbolic evaluations on indexed expression that
506  *  contains this object as the base expression. */
507 ex basic::eval_indexed(const basic & i) const
508  // this function can't take a "const ex & i" because that would result
509  // in an infinite eval() loop
510 {
511         // There is nothing to do for basic objects
512         return i.hold();
513 }
514
515 /** Add two indexed expressions. They are guaranteed to be of class indexed
516  *  (or a subclass) and their indices are compatible. This function is used
517  *  internally by simplify_indexed().
518  *
519  *  @param self First indexed expression; its base object is *this
520  *  @param other Second indexed expression
521  *  @return sum of self and other 
522  *  @see ex::simplify_indexed() */
523 ex basic::add_indexed(const ex & self, const ex & other) const
524 {
525         return self + other;
526 }
527
528 /** Multiply an indexed expression with a scalar. This function is used
529  *  internally by simplify_indexed().
530  *
531  *  @param self Indexed expression; its base object is *this
532  *  @param other Numeric value
533  *  @return product of self and other
534  *  @see ex::simplify_indexed() */
535 ex basic::scalar_mul_indexed(const ex & self, const numeric & other) const
536 {
537         return self * other;
538 }
539
540 /** Try to contract two indexed expressions that appear in the same product. 
541  *  If a contraction exists, the function overwrites one or both of the
542  *  expressions and returns true. Otherwise it returns false. It is
543  *  guaranteed that both expressions are of class indexed (or a subclass)
544  *  and that at least one dummy index has been found. This functions is
545  *  used internally by simplify_indexed().
546  *
547  *  @param self Pointer to first indexed expression; its base object is *this
548  *  @param other Pointer to second indexed expression
549  *  @param v The complete vector of factors
550  *  @return true if the contraction was successful, false otherwise
551  *  @see ex::simplify_indexed() */
552 bool basic::contract_with(exvector::iterator self, exvector::iterator other, exvector & v) const
553 {
554         // Do nothing
555         return false;
556 }
557
558 /** Check whether the expression matches a given pattern. For every wildcard
559  *  object in the pattern, an expression of the form "wildcard == matching_expression"
560  *  is added to repl_lst. */
561 bool basic::match(const ex & pattern, lst & repl_lst) const
562 {
563 /*
564         Sweet sweet shapes, sweet sweet shapes,
565         That's the key thing, right right.
566         Feed feed face, feed feed shapes,
567         But who is the king tonight?
568         Who is the king tonight?
569         Pattern is the thing, the key thing-a-ling,
570         But who is the king of Pattern?
571         But who is the king, the king thing-a-ling,
572         Who is the king of Pattern?
573         Bog is the king, the king thing-a-ling,
574         Bog is the king of Pattern.
575         Ba bu-bu-bu-bu bu-bu-bu-bu-bu-bu bu-bu
576         Bog is the king of Pattern.
577 */
578
579         if (is_exactly_a<wildcard>(pattern)) {
580
581                 // Wildcard matches anything, but check whether we already have found
582                 // a match for that wildcard first (if so, the earlier match must be
583                 // the same expression)
584                 for (lst::const_iterator it = repl_lst.begin(); it != repl_lst.end(); ++it) {
585                         if (it->op(0).is_equal(pattern))
586                                 return is_equal(ex_to<basic>(it->op(1)));
587                 }
588                 repl_lst.append(pattern == *this);
589                 return true;
590
591         } else {
592
593                 // Expression must be of the same type as the pattern
594                 if (tinfo() != ex_to<basic>(pattern).tinfo())
595                         return false;
596
597                 // Number of subexpressions must match
598                 if (nops() != pattern.nops())
599                         return false;
600
601                 // No subexpressions? Then just compare the objects (there can't be
602                 // wildcards in the pattern)
603                 if (nops() == 0)
604                         return is_equal_same_type(ex_to<basic>(pattern));
605
606                 // Check whether attributes that are not subexpressions match
607                 if (!match_same_type(ex_to<basic>(pattern)))
608                         return false;
609
610                 // Otherwise the subexpressions must match one-to-one
611                 for (size_t i=0; i<nops(); i++)
612                         if (!op(i).match(pattern.op(i), repl_lst))
613                                 return false;
614
615                 // Looks similar enough, match found
616                 return true;
617         }
618 }
619
620 /** Helper function for subs(). Does not recurse into subexpressions. */
621 ex basic::subs_one_level(const exmap & m, unsigned options) const
622 {
623         exmap::const_iterator it;
624
625         if (options & subs_options::no_pattern) {
626                 ex thisex = *this;
627                 it = m.find(thisex);
628                 if (it != m.end())
629                         return it->second;
630                 return thisex;
631         } else {
632                 for (it = m.begin(); it != m.end(); ++it) {
633                         lst repl_lst;
634                         if (match(ex_to<basic>(it->first), repl_lst))
635                                 return it->second.subs(repl_lst, options | subs_options::no_pattern); // avoid infinite recursion when re-substituting the wildcards
636                 }
637         }
638
639         return *this;
640 }
641
642 /** Substitute a set of objects by arbitrary expressions. The ex returned
643  *  will already be evaluated. */
644 ex basic::subs(const exmap & m, unsigned options) const
645 {
646         size_t num = nops();
647         if (num) {
648
649                 // Substitute in subexpressions
650                 for (size_t i=0; i<num; i++) {
651                         const ex & orig_op = op(i);
652                         const ex & subsed_op = orig_op.subs(m, options);
653                         if (!are_ex_trivially_equal(orig_op, subsed_op)) {
654
655                                 // Something changed, clone the object
656                                 basic *copy = duplicate();
657                                 copy->setflag(status_flags::dynallocated);
658                                 copy->clearflag(status_flags::hash_calculated | status_flags::expanded);
659
660                                 // Substitute the changed operand
661                                 copy->let_op(i++) = subsed_op;
662
663                                 // Substitute the other operands
664                                 for (; i<num; i++)
665                                         copy->let_op(i) = op(i).subs(m, options);
666
667                                 // Perform substitutions on the new object as a whole
668                                 return copy->subs_one_level(m, options);
669                         }
670                 }
671         }
672
673         // Nothing changed or no subexpressions
674         return subs_one_level(m, options);
675 }
676
677 /** Default interface of nth derivative ex::diff(s, n).  It should be called
678  *  instead of ::derivative(s) for first derivatives and for nth derivatives it
679  *  just recurses down.
680  *
681  *  @param s symbol to differentiate in
682  *  @param nth order of differentiation
683  *  @see ex::diff */
684 ex basic::diff(const symbol & s, unsigned nth) const
685 {
686         // trivial: zeroth derivative
687         if (nth==0)
688                 return ex(*this);
689         
690         // evaluate unevaluated *this before differentiating
691         if (!(flags & status_flags::evaluated))
692                 return ex(*this).diff(s, nth);
693         
694         ex ndiff = this->derivative(s);
695         while (!ndiff.is_zero() &&    // stop differentiating zeros
696                nth>1) {
697                 ndiff = ndiff.diff(s);
698                 --nth;
699         }
700         return ndiff;
701 }
702
703 /** Return a vector containing the free indices of an expression. */
704 exvector basic::get_free_indices() const
705 {
706         return exvector(); // return an empty exvector
707 }
708
709 ex basic::conjugate() const
710 {
711         return *this;
712 }
713
714 ex basic::eval_ncmul(const exvector & v) const
715 {
716         return hold_ncmul(v);
717 }
718
719 // protected
720
721 /** Function object to be applied by basic::derivative(). */
722 struct derivative_map_function : public map_function {
723         const symbol &s;
724         derivative_map_function(const symbol &sym) : s(sym) {}
725         ex operator()(const ex & e) { return diff(e, s); }
726 };
727
728 /** Default implementation of ex::diff(). It maps the operation on the
729  *  operands (or returns 0 when the object has no operands).
730  *
731  *  @see ex::diff */
732 ex basic::derivative(const symbol & s) const
733 {
734         if (nops() == 0)
735                 return _ex0;
736         else {
737                 derivative_map_function map_derivative(s);
738                 return map(map_derivative);
739         }
740 }
741
742 /** Returns order relation between two objects of same type.  This needs to be
743  *  implemented by each class. It may never return anything else than 0,
744  *  signalling equality, or +1 and -1 signalling inequality and determining
745  *  the canonical ordering.  (Perl hackers will wonder why C++ doesn't feature
746  *  the spaceship operator <=> for denoting just this.) */
747 int basic::compare_same_type(const basic & other) const
748 {
749         return compare_pointers(this, &other);
750 }
751
752 /** Returns true if two objects of same type are equal.  Normally needs
753  *  not be reimplemented as long as it wasn't overwritten by some parent
754  *  class, since it just calls compare_same_type().  The reason why this
755  *  function exists is that sometimes it is easier to determine equality
756  *  than an order relation and then it can be overridden. */
757 bool basic::is_equal_same_type(const basic & other) const
758 {
759         return compare_same_type(other)==0;
760 }
761
762 /** Returns true if the attributes of two objects are similar enough for
763  *  a match. This function must not match subexpressions (this is already
764  *  done by basic::match()). Only attributes not accessible by op() should
765  *  be compared. This is also the reason why this function doesn't take the
766  *  wildcard replacement list from match() as an argument: only subexpressions
767  *  are subject to wildcard matches. Also, this function only needs to be
768  *  implemented for container classes because is_equal_same_type() is
769  *  automatically used instead of match_same_type() if nops() == 0.
770  *
771  *  @see basic::match */
772 bool basic::match_same_type(const basic & other) const
773 {
774         // The default is to only consider subexpressions, but not any other
775         // attributes
776         return true;
777 }
778
779 unsigned basic::return_type() const
780 {
781         return return_types::commutative;
782 }
783
784 tinfo_t basic::return_type_tinfo() const
785 {
786         return tinfo_key;
787 }
788
789 /** Compute the hash value of an object and if it makes sense to store it in
790  *  the objects status_flags, do so.  The method inherited from class basic
791  *  computes a hash value based on the type and hash values of possible
792  *  members.  For this reason it is well suited for container classes but
793  *  atomic classes should override this implementation because otherwise they
794  *  would all end up with the same hashvalue. */
795 unsigned basic::calchash() const
796 {
797         unsigned v = golden_ratio_hash((p_int)tinfo());
798         for (size_t i=0; i<nops(); i++) {
799                 v = rotate_left(v);
800                 v ^= this->op(i).gethash();
801         }
802
803         // store calculated hash value only if object is already evaluated
804         if (flags & status_flags::evaluated) {
805                 setflag(status_flags::hash_calculated);
806                 hashvalue = v;
807         }
808
809         return v;
810 }
811
812 /** Function object to be applied by basic::expand(). */
813 struct expand_map_function : public map_function {
814         unsigned options;
815         expand_map_function(unsigned o) : options(o) {}
816         ex operator()(const ex & e) { return e.expand(options); }
817 };
818
819 /** Expand expression, i.e. multiply it out and return the result as a new
820  *  expression. */
821 ex basic::expand(unsigned options) const
822 {
823         if (nops() == 0)
824                 return (options == 0) ? setflag(status_flags::expanded) : *this;
825         else {
826                 expand_map_function map_expand(options);
827                 return ex_to<basic>(map(map_expand)).setflag(options == 0 ? status_flags::expanded : 0);
828         }
829 }
830
831
832 //////////
833 // non-virtual functions in this class
834 //////////
835
836 // public
837
838 /** Compare objects syntactically to establish canonical ordering.
839  *  All compare functions return: -1 for *this less than other, 0 equal,
840  *  1 greater. */
841 int basic::compare(const basic & other) const
842 {
843 #ifdef GINAC_COMPARE_STATISTICS
844         compare_statistics.total_basic_compares++;
845 #endif
846         const unsigned hash_this = gethash();
847         const unsigned hash_other = other.gethash();
848         if (hash_this<hash_other) return -1;
849         if (hash_this>hash_other) return 1;
850 #ifdef GINAC_COMPARE_STATISTICS
851         compare_statistics.compare_same_hashvalue++;
852 #endif
853
854         const tinfo_t typeid_this = tinfo();
855         const tinfo_t typeid_other = other.tinfo();
856         if (typeid_this==typeid_other) {
857                 GINAC_ASSERT(typeid(*this)==typeid(other));
858 //              int cmpval = compare_same_type(other);
859 //              if (cmpval!=0) {
860 //                      std::cout << "hash collision, same type: " 
861 //                                << *this << " and " << other << std::endl;
862 //                      this->print(print_tree(std::cout));
863 //                      std::cout << " and ";
864 //                      other.print(print_tree(std::cout));
865 //                      std::cout << std::endl;
866 //              }
867 //              return cmpval;
868 #ifdef GINAC_COMPARE_STATISTICS
869                 compare_statistics.compare_same_type++;
870 #endif
871                 return compare_same_type(other);
872         } else {
873 //              std::cout << "hash collision, different types: " 
874 //                        << *this << " and " << other << std::endl;
875 //              this->print(print_tree(std::cout));
876 //              std::cout << " and ";
877 //              other.print(print_tree(std::cout));
878 //              std::cout << std::endl;
879                 return (typeid_this<typeid_other ? -1 : 1);
880         }
881 }
882
883 /** Test for syntactic equality.
884  *  This is only a quick test, meaning objects should be in the same domain.
885  *  You might have to .expand(), .normal() objects first, depending on the
886  *  domain of your computation, to get a more reliable answer.
887  *
888  *  @see is_equal_same_type */
889 bool basic::is_equal(const basic & other) const
890 {
891 #ifdef GINAC_COMPARE_STATISTICS
892         compare_statistics.total_basic_is_equals++;
893 #endif
894         if (this->gethash()!=other.gethash())
895                 return false;
896 #ifdef GINAC_COMPARE_STATISTICS
897         compare_statistics.is_equal_same_hashvalue++;
898 #endif
899         if (this->tinfo()!=other.tinfo())
900                 return false;
901         
902         GINAC_ASSERT(typeid(*this)==typeid(other));
903         
904 #ifdef GINAC_COMPARE_STATISTICS
905         compare_statistics.is_equal_same_type++;
906 #endif
907         return is_equal_same_type(other);
908 }
909
910 // protected
911
912 /** Stop further evaluation.
913  *
914  *  @see basic::eval */
915 const basic & basic::hold() const
916 {
917         return setflag(status_flags::evaluated);
918 }
919
920 /** Ensure the object may be modified without hurting others, throws if this
921  *  is not the case. */
922 void basic::ensure_if_modifiable() const
923 {
924         if (get_refcount() > 1)
925                 throw(std::runtime_error("cannot modify multiply referenced object"));
926         clearflag(status_flags::hash_calculated | status_flags::evaluated);
927 }
928
929 //////////
930 // global variables
931 //////////
932
933 int max_recursion_level = 1024;
934
935
936 #ifdef GINAC_COMPARE_STATISTICS
937 compare_statistics_t::~compare_statistics_t()
938 {
939         std::clog << "ex::compare() called " << total_compares << " times" << std::endl;
940         std::clog << "nontrivial compares: " << nontrivial_compares << " times" << std::endl;
941         std::clog << "basic::compare() called " << total_basic_compares << " times" << std::endl;
942         std::clog << "same hashvalue in compare(): " << compare_same_hashvalue << " times" << std::endl;
943         std::clog << "compare_same_type() called " << compare_same_type << " times" << std::endl;
944         std::clog << std::endl;
945         std::clog << "ex::is_equal() called " << total_is_equals << " times" << std::endl;
946         std::clog << "nontrivial is_equals: " << nontrivial_is_equals << " times" << std::endl;
947         std::clog << "basic::is_equal() called " << total_basic_is_equals << " times" << std::endl;
948         std::clog << "same hashvalue in is_equal(): " << is_equal_same_hashvalue << " times" << std::endl;
949         std::clog << "is_equal_same_type() called " << is_equal_same_type << " times" << std::endl;
950         std::clog << std::endl;
951         std::clog << "basic::gethash() called " << total_gethash << " times" << std::endl;
952         std::clog << "used cached hashvalue " << gethash_cached << " times" << std::endl;
953 }
954
955 compare_statistics_t compare_statistics;
956 #endif
957
958 } // namespace GiNaC