3 * Implementation of GiNaC's ABC. */
6 * GiNaC Copyright (C) 1999-2004 Johannes Gutenberg University Mainz, Germany
8 * This program is free software; you can redistribute it and/or modify
9 * it under the terms of the GNU General Public License as published by
10 * the Free Software Foundation; either version 2 of the License, or
11 * (at your option) any later version.
13 * This program is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 * GNU General Public License for more details.
18 * You should have received a copy of the GNU General Public License
19 * along with this program; if not, write to the Free Software
20 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
25 #ifdef DO_GINAC_ASSERT
36 #include "relational.h"
37 #include "operators.h"
44 GINAC_IMPLEMENT_REGISTERED_CLASS_OPT(basic, void,
45 print_func<print_context>(&basic::do_print).
46 print_func<print_tree>(&basic::do_print_tree).
47 print_func<print_python_repr>(&basic::do_print_python_repr))
50 // default constructor, destructor, copy constructor and assignment operator
55 /** basic copy constructor: implicitly assumes that the other class is of
56 * the exact same type (as it's used by duplicate()), so it can copy the
57 * tinfo_key and the hash value. */
58 basic::basic(const basic & other) : tinfo_key(other.tinfo_key), flags(other.flags & ~status_flags::dynallocated), hashvalue(other.hashvalue)
60 GINAC_ASSERT(typeid(*this) == typeid(other));
// Assignment from another basic (possibly of a derived class).
// Visible logic: fl takes other's flags without dynallocated; if the
// tinfo_keys differ, evaluated/expanded/hash_calculated are cleared too.
// Line 74 copies the hash value; per the comment on line 73 this should only
// happen when both objects are of the same class (presumably an `else` on
// the missing line 72).
// NOTE(review): the embedded numbering skips 65, 72 and everything after 74,
// so the opening brace, the else-branch and presumably the code that stores
// `fl` into `flags` and returns *this are not visible in this copy — restore
// them from upstream before building.
63 /** basic assignment operator: the other object might be of a derived class. */
64 const basic & basic::operator=(const basic & other)
66 unsigned fl = other.flags & ~status_flags::dynallocated;
67 if (tinfo_key != other.tinfo_key) {
68 // The other object is of a derived class, so clear the flags as they
69 // might no longer apply (especially hash_calculated). Oh, and don't
70 // copy the tinfo_key: it is already set correctly for this object.
71 fl &= ~(status_flags::evaluated | status_flags::expanded | status_flags::hash_calculated);
73 // The objects are of the exact same class, so copy the hash value.
74 hashvalue = other.hashvalue;
95 /** Construct object from archive_node. */
96 basic::basic(const archive_node &n, lst &sym_lst) : flags(0)
98 // Reconstruct tinfo_key from class name
99 std::string class_name;
100 if (n.find_string("class", class_name))
101 tinfo_key = find_tinfo_key(class_name);
103 throw (std::runtime_error("archive node contains no class name"));
106 /** Unarchive the object. */
107 DEFAULT_UNARCHIVE(basic)
109 /** Archive the object. */
110 void basic::archive(archive_node &n) const
112 n.add_string("class", class_name());
116 // new virtual functions which can be overridden by derived classes
121 /** Output to stream. This performs double dispatch on the dynamic type of
122 * *this and the dynamic type of the supplied print context.
123 * @param c print context object that describes the output formatting
124 * @param level value that is used to identify the precedence or indentation
125 * level for placing parentheses and formatting */
126 void basic::print(const print_context & c, unsigned level) const
128 print_dispatch(get_class_info(), c, level);
// Double dispatch for print(): looks up the print functor registered for the
// pair (algebraic class reg_info, print-context class pc_info) in reg_info's
// print dispatch table, indexed by the print context's id.
// If no valid entry exists, pc_info moves to the parent print_context class;
// when that chain is exhausted, reg_info moves to the parent algebraic class
// and pc_info is reset to c's own class. If both chains are exhausted,
// std::runtime_error is thrown; otherwise the functor found is invoked.
// NOTE(review): the embedded numbering skips 133-134 (rest of the doc
// comment), 137, 141-142, 144-145, 148, 153-155, 161-163, 168, 170-173 and
// everything after 174. In particular, the statements that restart the lookup
// after pc_info/reg_info are updated, and the closing braces, are not visible
// in this copy — restore them from upstream before building.
131 /** Like print(), but dispatch to the specified class. Can be used by
132 * implementations of print methods to dispatch to the method of the
135 * @see basic::print */
136 void basic::print_dispatch(const registered_class_info & ri, const print_context & c, unsigned level) const
138 // Double dispatch on object type and print_context type
139 const registered_class_info * reg_info = &ri;
140 const print_context_class_info * pc_info = &c.get_class_info();
143 const std::vector<print_functor> & pdt = reg_info->options.get_print_dispatch_table();
146 unsigned id = pc_info->options.get_id();
147 if (id >= pdt.size() || !(pdt[id].is_valid())) {
149 // Method not found, try parent print_context class
150 const print_context_class_info * parent_pc_info = pc_info->get_parent();
151 if (parent_pc_info) {
152 pc_info = parent_pc_info;
156 // Method still not found, try parent class
157 const registered_class_info * parent_reg_info = reg_info->get_parent();
158 if (parent_reg_info) {
159 reg_info = parent_reg_info;
160 pc_info = &c.get_class_info();
164 // Method still not found. This shouldn't happen because basic (the
165 // base class of the algebraic hierarchy) registers a method for
166 // print_context (the base class of the print context hierarchy),
167 // so if we end up here, there's something wrong with the class
169 throw (std::runtime_error(std::string("basic::print(): method for ") + class_name() + "/" + c.class_name() + " not found"));
174 pdt[id](*this, c, level);
178 /** Default output to stream. */
179 void basic::do_print(const print_context & c, unsigned level) const
181 c.s << "[" << class_name() << " object]";
// Tree output: writes `level` spaces of indentation, the class name, the
// object address, then hash value and flags in hex, then ", nops=" with the
// operand count, and recursively prints each operand at
// level + c.delta_indent.
// NOTE(review): the embedded numbering skips 186, 189, 191 and everything
// after 193, so the braces, whatever precedes line 190 (possibly a guard on
// nops()) and whatever ends the line before the operands are printed are not
// visible in this copy.
184 /** Tree output to stream. */
185 void basic::do_print_tree(const print_tree & c, unsigned level) const
187 c.s << std::string(level, ' ') << class_name() << " @" << this
188 << std::hex << ", hash=0x" << hashvalue << ", flags=0x" << flags << std::dec;
190 c.s << ", nops=" << nops();
192 for (size_t i=0; i<nops(); ++i)
193 op(i).print(c, level + c.delta_indent);
196 /** Python parsable output to stream. */
197 void basic::do_print_python_repr(const print_python_repr & c, unsigned level) const
199 c.s << class_name() << "()";
202 /** Little wrapper around print to be called within a debugger.
203 * This is needed because you cannot call foo.print(cout) from within the
204 * debugger because it might not know what cout is. This method can be
205 * invoked with no argument and it will simply print to stdout.
208 * @see basic::dbgprinttree */
209 void basic::dbgprint() const
211 this->print(std::cerr);
212 std::cerr << std::endl;
215 /** Little wrapper around printtree to be called within a debugger.
217 * @see basic::dbgprint */
218 void basic::dbgprinttree() const
220 this->print(print_tree(std::cerr));
// NOTE(review): only the doc comment and signature survive in this copy; the
// body (numbering resumes at 229) is missing, so the precedence value that
// basic returns cannot be stated here — restore from upstream.
223 /** Return relative operator precedence (for parenthesizing output). */
224 unsigned basic::precedence() const
// Default property query for plain basic objects; per the comment in the
// body, every info_flags property is false at this level.
// NOTE(review): lines 230, 233 and everything after 234 (including the
// return statement and braces) are missing from this copy.
229 /** Information about the object.
231 * @see class info_flags */
232 bool basic::info(unsigned inf) const
234 // all possible properties are false for basic objects
// Default operand count. Per the comment below, atomic objects have no
// operands (a 0..nops() loop is empty) and container classes are expected
// to override this.
// NOTE(review): lines 240 and 243 onward (end of the comment, the return
// statement and braces) are missing from this copy.
238 /** Number of operands/members. */
239 size_t basic::nops() const
241 // iterating from 0 to nops() on atomic objects should be an empty loop,
242 // and accessing their elements is a range error. Container objects should
247 /** Return operand/member at position i. */
248 ex basic::op(size_t i) const
250 throw(std::range_error(std::string("basic::op(): ") + class_name() + std::string(" has no operands")));
253 /** Return modifyable operand/member at position i. */
254 ex & basic::let_op(size_t i)
256 ensure_if_modifiable();
257 throw(std::range_error(std::string("basic::let_op(): ") + class_name() + std::string(" has no operands")));
260 ex basic::operator[](const ex & index) const
262 if (is_exactly_a<numeric>(index))
263 return op(static_cast<size_t>(ex_to<numeric>(index).to_int()));
265 throw(std::invalid_argument(std::string("non-numeric indices not supported by ") + class_name()));
// Operand access by plain index (const overload).
// NOTE(review): only the signature survives in this copy (line 268; the
// numbering resumes at 273); the body is missing — restore from upstream.
268 ex basic::operator[](size_t i) const
273 ex & basic::operator[](const ex & index)
275 if (is_exactly_a<numeric>(index))
276 return let_op(ex_to<numeric>(index).to_int());
278 throw(std::invalid_argument(std::string("non-numeric indices not supported by ") + class_name()));
// Modifiable operand access by plain index.
// NOTE(review): only the signature survives in this copy (line 281; the
// numbering resumes at 286); the body is missing — restore from upstream.
281 ex & basic::operator[](size_t i)
// Visible logic: a successful match() of *this against `pattern`, or a
// recursive has() hit in any operand op(i), counts as an occurrence.
// NOTE(review): lines 291-292, 294 and 297 onward are missing from this copy,
// including the declaration of repl_lst used on line 293, the return
// statements and the braces.
286 /** Test for occurrence of a pattern. An object 'has' a pattern if it matches
287 * the pattern itself or one of the children 'has' it. As a consequence
288 * (according to the definition of children) given e=x+y+z, e.has(x) is true
289 * but e.has(x+y) is false. */
290 bool basic::has(const ex & pattern) const
293 if (match(pattern, repl_lst))
295 for (size_t i=0; i<nops(); i++)
296 if (op(i).has(pattern))
// Visible logic: walks the operands; when an operand's mapped value `n` is
// not trivially equal to the original operand `o`, the object named `copy`
// is marked dynallocated and has its hash_calculated/expanded flags cleared.
// NOTE(review): lines 305-310, 313, 315-321 and 324 onward are missing,
// including the declarations of num, n (presumably the result of applying f)
// and copy, the assignment of mapped operands into the copy, the return
// statements and the braces.
302 /** Construct new expression by applying the specified function to all
303 * sub-expressions (one level only, not recursively). */
304 ex basic::map(map_function & f) const
311 for (size_t i=0; i<num; i++) {
312 const ex & o = op(i);
314 if (!are_ex_trivially_equal(o, n)) {
322 copy->setflag(status_flags::dynallocated);
323 copy->clearflag(status_flags::hash_calculated | status_flags::expanded);
329 /** Return degree of highest power in object s. */
330 int basic::degree(const ex & s) const
332 return is_equal(ex_to<basic>(s)) ? 1 : 0;
335 /** Return degree of lowest power in object s. */
336 int basic::ldegree(const ex & s) const
338 return is_equal(ex_to<basic>(s)) ? 1 : 0;
341 /** Return coefficient of degree n in object s. */
342 ex basic::coeff(const ex & s, int n) const
344 if (is_equal(ex_to<basic>(s)))
345 return n==1 ? _ex1 : _ex0;
347 return n==0 ? *this : _ex0;
// Collects *this in powers of s.
// If s is a list: a single-element list recurses on that element (line 362).
// In distributed mode, each list element gets a counter in si running from
// its ldegree to its degree; every combination contributes
// coeff * x1^c1 * ... * xn^cn to x. If s is a single object, x is the sum of
// coeff(s,n)*s^n for n from ldegree(s) to degree(s). Finally the part not
// captured by x is added back via (*this - x).expand().
// NOTE(review): many lines are missing from this copy (numbering skips
// 354-357, 359-361, 363, 365, 368-370, 373, 375, 377, 381-384, 386, 388, 390,
// 392, 394-396, 399-403, 406-409, 411-420, 422-431 and 435-436). Among them
// are the sym_info struct header, the declarations of x, c, y and n, and
// presumably the non-distributed list branch (line 421 looks like part of
// it). No delete[] matching the new[] on line 374 is visible: verify si is
// released on every path, e.g. by switching to std::vector<sym_info>.
350 /** Sort expanded expression in terms of powers of some object(s).
351 * @param s object(s) to sort in
352 * @param distributed recursive or distributed form (only used when s is a list) */
353 ex basic::collect(const ex & s, bool distributed) const
358 // List of objects specified
362 return collect(s.op(0));
364 else if (distributed) {
366 // Get lower/upper degree of all symbols in list
367 size_t num = s.nops();
371 int cnt; // current degree, 'counter'
372 ex coeff; // coefficient for degree 'cnt'
374 sym_info *si = new sym_info[num];
376 for (size_t i=0; i<num; i++) {
378 si[i].ldeg = si[i].cnt = this->ldegree(si[i].sym);
379 si[i].deg = this->degree(si[i].sym);
380 c = si[i].coeff = c.coeff(si[i].sym, si[i].cnt);
385 // Calculate coeff*x1^c1*...*xn^cn
387 for (size_t i=0; i<num; i++) {
// NOTE(review): `cnt` below is not a visible local; presumably a copy of
// si[i].cnt declared on the missing line 388 — confirm.
389 y *= power(si[i].sym, cnt);
391 x += y * si[num - 1].coeff;
393 // Increment counters
397 if (si[n].cnt <= si[n].deg) {
398 // Update coefficients
404 for (size_t i=n; i<num; i++)
405 c = si[i].coeff = c.coeff(si[i].sym, si[i].cnt);
410 si[n].cnt = si[n].ldeg;
421 size_t n = s.nops() - 1;
432 // Only one object specified
433 for (int n=this->ldegree(s); n<=this->degree(s); ++n)
434 x += this->coeff(s,n)*power(s,n);
437 // correct for lost fractional arguments and return
438 return x + (*this - x).expand();
// Default evaluation; per the comment below, basic performs no rewriting.
// NOTE(review): lines 443 and 445 onward (braces and the return statement)
// are missing from this copy.
441 /** Perform automatic non-interruptive term rewriting rules. */
442 ex basic::eval(int level) const
444 // There is nothing to do for basic objects:
// map_function that applies evalf(e, level) to each operand; basic::evalf()
// constructs it with level - 1.
// NOTE(review): line 450 (presumably the declaration of the `level` member
// initialized by the constructor) and the closing of the struct (453 onward)
// are missing from this copy.
448 /** Function object to be applied by basic::evalf(). */
449 struct evalf_map_function : public map_function {
451 evalf_map_function(int l) : level(l) {}
452 ex operator()(const ex & e) { return evalf(e, level); }
// Visible logic: throws std::runtime_error once level reaches
// -max_recursion_level; otherwise maps evalf with level - 1 over the
// operands.
// NOTE(review): lines 457-462 (including the `if` that the `else if` on
// line 463 continues), 465 and 468 onward are missing from this copy.
455 /** Evaluate object numerically. */
456 ex basic::evalf(int level) const
463 else if (level == -max_recursion_level)
464 throw(std::runtime_error("max recursion level reached"));
466 evalf_map_function map_evalf(level - 1);
467 return map(map_evalf);
// map_function that applies evalm(e) to each operand.
// NOTE(review): the closing of the struct (numbering resumes at 477) is
// missing. basic::evalm() uses an object named map_evalm that is not declared
// in any visible line; presumably it is declared on the missing line(s) here
// — confirm.
472 /** Function object to be applied by basic::evalm(). */
473 struct evalm_map_function : public map_function {
474 ex operator()(const ex & e) { return evalm(e); }
// Visible logic: maps evalm over the operands via map_evalm.
// NOTE(review): lines 479-482 and 484 onward (braces and whatever precedes
// the map call) are missing from this copy.
477 /** Evaluate sums, products and integer powers of matrices. */
478 ex basic::evalm() const
483 return map(map_evalm);
// Default indexed evaluation hook; per the comment, basic does nothing here.
// NOTE(review): lines 491 and 493 onward (braces and the return statement)
// are missing from this copy.
486 /** Perform automatic symbolic evaluations on indexed expression that
487 * contains this object as the base expression. */
488 ex basic::eval_indexed(const basic & i) const
489 // this function can't take a "const ex & i" because that would result
490 // in an infinite eval() loop
492 // There is nothing to do for basic objects
// NOTE(review): line 499 and the whole body (505 onward) are missing from
// this copy; only the documentation and signature survive.
496 /** Add two indexed expressions. They are guaranteed to be of class indexed
497 * (or a subclass) and their indices are compatible. This function is used
498 * internally by simplify_indexed().
500 * @param self First indexed expression; its base object is *this
501 * @param other Second indexed expression
502 * @return sum of self and other
503 * @see ex::simplify_indexed() */
504 ex basic::add_indexed(const ex & self, const ex & other) const
// NOTE(review): line 511 and the whole body (517 onward) are missing from
// this copy; only the documentation and signature survive.
509 /** Multiply an indexed expression with a scalar. This function is used
510 * internally by simplify_indexed().
512 * @param self Indexed expression; its base object is *this
513 * @param other Numeric value
514 * @return product of self and other
515 * @see ex::simplify_indexed() */
516 ex basic::scalar_mul_indexed(const ex & self, const numeric & other) const
// NOTE(review): line 527 and the whole body (534 onward) are missing from
// this copy; only the documentation and signature survive.
521 /** Try to contract two indexed expressions that appear in the same product.
522 * If a contraction exists, the function overwrites one or both of the
523 * expressions and returns true. Otherwise it returns false. It is
524 * guaranteed that both expressions are of class indexed (or a subclass)
525 * and that at least one dummy index has been found. This function is
526 * used internally by simplify_indexed().
528 * @param self Pointer to first indexed expression; its base object is *this
529 * @param other Pointer to second indexed expression
530 * @param v The complete vector of factors
531 * @return true if the contraction was successful, false otherwise
532 * @see ex::simplify_indexed() */
533 bool basic::contract_with(exvector::iterator self, exvector::iterator other, exvector & v) const
// Pattern matching. Visible logic:
// - A wildcard pattern matches anything. If repl_lst already binds that
//   wildcard, the earlier binding must be equal to *this; otherwise
//   `pattern == *this` is appended to repl_lst.
// - Otherwise the tinfo must agree and the operand counts must agree. With
//   no operands the objects are compared with is_equal_same_type().
//   Otherwise match_same_type() must accept the non-operand attributes and
//   every operand must match the corresponding pattern operand.
// NOTE(review): lines 543-544, 558-559, 561, 568, 570-573, 576-577, 580-581,
// 584, 586, 589-590, 594-595 and 597 onward are missing, including the
// braces, the return statements after the failing checks and presumably the
// `if (nops() == 0)` guard before line 585.
539 /** Check whether the expression matches a given pattern. For every wildcard
540 * object in the pattern, an expression of the form "wildcard == matching_expression"
541 * is added to repl_lst. */
542 bool basic::match(const ex & pattern, lst & repl_lst) const
// NOTE(review): lines 545-557 below are a verse, not code. Their delimiters
// (presumably a /* ... */ block comment on the missing lines 543-544 /
// 558-559) are not in this copy — confirm and restore.
545 Sweet sweet shapes, sweet sweet shapes,
546 That's the key thing, right right.
547 Feed feed face, feed feed shapes,
548 But who is the king tonight?
549 Who is the king tonight?
550 Pattern is the thing, the key thing-a-ling,
551 But who is the king of Pattern?
552 But who is the king, the king thing-a-ling,
553 Who is the king of Pattern?
554 Bog is the king, the king thing-a-ling,
555 Bog is the king of Pattern.
556 Ba bu-bu-bu-bu bu-bu-bu-bu-bu-bu bu-bu
557 Bog is the king of Pattern.
560 if (is_exactly_a<wildcard>(pattern)) {
562 // Wildcard matches anything, but check whether we already have found
563 // a match for that wildcard first (if so, the earlier match must be
564 // the same expression)
565 for (lst::const_iterator it = repl_lst.begin(); it != repl_lst.end(); ++it) {
566 if (it->op(0).is_equal(pattern))
567 return is_equal(ex_to<basic>(it->op(1)));
569 repl_lst.append(pattern == *this);
574 // Expression must be of the same type as the pattern
575 if (tinfo() != ex_to<basic>(pattern).tinfo())
578 // Number of subexpressions must match
579 if (nops() != pattern.nops())
582 // No subexpressions? Then just compare the objects (there can't be
583 // wildcards in the pattern)
585 return is_equal_same_type(ex_to<basic>(pattern));
587 // Check whether attributes that are not subexpressions match
588 if (!match_same_type(ex_to<basic>(pattern)))
591 // Otherwise the subexpressions must match one-to-one
592 for (size_t i=0; i<nops(); i++)
593 if (!op(i).match(pattern.op(i), repl_lst))
596 // Looks similar enough, match found
// One-level substitution. In the visible pattern-matching path, for each
// (pattern, replacement) entry of m that *this matches, the replacement is
// returned with the wildcard bindings substituted in. The no_pattern flag is
// added there to avoid infinite recursion.
// NOTE(review): lines 603, 605, 607-610, 612 and 615 onward are missing,
// including the body of the no_pattern branch, the declaration of repl_lst
// and the final return.
601 /** Helper function for subs(). Does not recurse into subexpressions. */
602 ex basic::subs_one_level(const exmap & m, unsigned options) const
604 exmap::const_iterator it;
606 if (options & subs_options::no_pattern) {
611 for (it = m.begin(); it != m.end(); ++it) {
613 if (match(ex_to<basic>(it->first), repl_lst))
614 return it->second.subs(repl_lst, options | subs_options::no_pattern); // avoid infinite recursion when re-substituting the wildcards
// Visible logic: substitutes into each operand. On the first operand that
// changes, the object is duplicated (marked dynallocated, hash/expanded flags
// cleared). That operand and the remaining ones are then substituted into
// the copy, and subs_one_level() is applied to it. If nothing changed,
// subs_one_level() is applied to *this.
// NOTE(review): lines 624-627 (including the declaration of num), 633, 638,
// 641, 643 (presumably the loop over the remaining operands around line 644),
// 645, 648-651 and 654 onward are missing from this copy.
621 /** Substitute a set of objects by arbitrary expressions. The ex returned
622 * will already be evaluated. */
623 ex basic::subs(const exmap & m, unsigned options) const
628 // Substitute in subexpressions
629 for (size_t i=0; i<num; i++) {
630 const ex & orig_op = op(i);
631 const ex & subsed_op = orig_op.subs(m, options);
632 if (!are_ex_trivially_equal(orig_op, subsed_op)) {
634 // Something changed, clone the object
635 basic *copy = duplicate();
636 copy->setflag(status_flags::dynallocated);
637 copy->clearflag(status_flags::hash_calculated | status_flags::expanded);
639 // Substitute the changed operand
640 copy->let_op(i++) = subsed_op;
642 // Substitute the other operands
644 copy->let_op(i) = op(i).subs(m, options);
646 // Perform substitutions on the new object as a whole
647 return copy->subs_one_level(m, options);
652 // Nothing changed or no subexpressions
653 return subs_one_level(m, options);
// Visible logic: an unevaluated *this is evaluated first and differentiated
// as an ex. Otherwise the first derivative comes from derivative(s), and
// further derivatives are taken in a loop that stops early once the result
// is zero.
// NOTE(review): lines 659, 662 (presumably the closing "*/" of the doc
// comment, which is therefore unterminated in this copy), 664, 666-668
// (the zeroth-derivative case), 672, 675 (rest of the while condition) and
// 677 onward are missing.
656 /** Default interface of nth derivative ex::diff(s, n). It should be called
657 * instead of ::derivative(s) for first derivatives and for nth derivatives it
658 * just recurses down.
660 * @param s symbol to differentiate in
661 * @param nth order of differentiation
663 ex basic::diff(const symbol & s, unsigned nth) const
665 // trivial: zeroth derivative
669 // evaluate unevaluated *this before differentiating
670 if (!(flags & status_flags::evaluated))
671 return ex(*this).diff(s, nth);
673 ex ndiff = this->derivative(s);
674 while (!ndiff.is_zero() && // stop differentiating zeros
676 ndiff = ndiff.diff(s);
682 /** Return a vector containing the free indices of an expression. */
683 exvector basic::get_free_indices() const
685 return exvector(); // return an empty exvector
// Complex conjugation hook.
// NOTE(review): only the signature survives in this copy (line 688; the
// numbering resumes at 693); the body is missing — restore from upstream.
688 ex basic::conjugate() const
693 ex basic::eval_ncmul(const exvector & v) const
695 return hold_ncmul(v);
// map_function that differentiates each operand with respect to symbol s.
// NOTE(review): line 702 (presumably the declaration of the member `s`
// initialized by the constructor) and the closing of the struct (705 onward)
// are missing from this copy.
700 /** Function object to be applied by basic::derivative(). */
701 struct derivative_map_function : public map_function {
703 derivative_map_function(const symbol &sym) : s(sym) {}
704 ex operator()(const ex & e) { return diff(e, s); }
// Visible logic: differentiates every operand via derivative_map_function
// and map().
// NOTE(review): lines 709-710 (the rest of the doc comment, including its
// closing "*/", so the comment is unterminated in this copy), 712-715 (per
// the doc, presumably the zero result for objects without operands) and 718
// onward are missing.
707 /** Default implementation of ex::diff(). It maps the operation on the
708 * operands (or returns 0 when the object has no operands).
711 ex basic::derivative(const symbol & s) const
716 derivative_map_function map_derivative(s);
717 return map(map_derivative);
721 /** Returns order relation between two objects of same type. This needs to be
722 * implemented by each class. It may never return anything else than 0,
723 * signalling equality, or +1 and -1 signalling inequality and determining
724 * the canonical ordering. (Perl hackers will wonder why C++ doesn't feature
725 * the spaceship operator <=> for denoting just this.) */
726 int basic::compare_same_type(const basic & other) const
728 return compare_pointers(this, &other);
731 /** Returns true if two objects of same type are equal. Normally needs
732 * not be reimplemented as long as it wasn't overwritten by some parent
733 * class, since it just calls compare_same_type(). The reason why this
734 * function exists is that sometimes it is easier to determine equality
735 * than an order relation and then it can be overridden. */
736 bool basic::is_equal_same_type(const basic & other) const
738 return compare_same_type(other)==0;
// NOTE(review): lines 749, 752 and 754 onward are missing from this copy,
// including the end of the comment on line 753, the return statement and
// the braces.
741 /** Returns true if the attributes of two objects are similar enough for
742 * a match. This function must not match subexpressions (this is already
743 * done by basic::match()). Only attributes not accessible by op() should
744 * be compared. This is also the reason why this function doesn't take the
745 * wildcard replacement list from match() as an argument: only subexpressions
746 * are subject to wildcard matches. Also, this function only needs to be
747 * implemented for container classes because is_equal_same_type() is
748 * automatically used instead of match_same_type() if nops() == 0.
750 * @see basic::match */
751 bool basic::match_same_type(const basic & other) const
753 // The default is to only consider subexpressions, but not any other
758 unsigned basic::return_type() const
760 return return_types::commutative;
// NOTE(review): only the signature survives in this copy (line 763; the
// numbering resumes at 768); the body is missing — restore from upstream.
763 unsigned basic::return_type_tinfo() const
// Visible logic: starts from golden_ratio_hash(tinfo()), XORs in each
// operand's hash and, if the object is evaluated, sets hash_calculated.
// NOTE(review): lines 775, 778, 780-781 and 785 onward are missing. Line 778
// sits inside the loop before the XOR; presumably it mixes/rotates v (with
// plain XOR alone, permuted operands would hash identically) — confirm. The
// store of v into hashvalue and the return are also not visible.
768 /** Compute the hash value of an object and if it makes sense to store it in
769 * the object's status_flags, do so. The method inherited from class basic
770 * computes a hash value based on the type and hash values of possible
771 * members. For this reason it is well suited for container classes but
772 * atomic classes should override this implementation because otherwise they
773 * would all end up with the same hashvalue. */
774 unsigned basic::calchash() const
776 unsigned v = golden_ratio_hash(tinfo());
777 for (size_t i=0; i<nops(); i++) {
779 v ^= this->op(i).gethash();
782 // store calculated hash value only if object is already evaluated
783 if (flags & status_flags::evaluated) {
784 setflag(status_flags::hash_calculated);
// map_function that expands each operand with the stored expansion options.
// NOTE(review): line 793 (presumably the declaration of the `options` member
// initialized by the constructor) and the closing of the struct (796 onward)
// are missing from this copy.
791 /** Function object to be applied by basic::expand(). */
792 struct expand_map_function : public map_function {
794 expand_map_function(unsigned o) : options(o) {}
795 ex operator()(const ex & e) { return e.expand(options); }
// Visible logic: with options == 0 the result is marked expanded. The
// general path maps expand over the operands and marks the mapped result.
// NOTE(review): lines 799 (rest of the doc comment, including its closing
// "*/", so the comment is unterminated in this copy), 801-802 (presumably the
// condition guarding the early return on line 803), 804 and 807 onward are
// missing.
798 /** Expand expression, i.e. multiply it out and return the result as a new
800 ex basic::expand(unsigned options) const
803 return (options == 0) ? setflag(status_flags::expanded) : *this;
805 expand_map_function map_expand(options);
806 return ex_to<basic>(map(map_expand)).setflag(options == 0 ? status_flags::expanded : 0);
812 // non-virtual functions in this class
// Canonical ordering:
// 1. Order by hash value.
// 2. On equal hashes, order by tinfo.
// 3. On equal tinfo, defer to compare_same_type().
// Statistics counters are updated when GINAC_COMPARE_STATISTICS is defined.
// NOTE(review): lines 819 (end of the doc comment, including its closing
// "*/", so it is unterminated in this copy), 821, 824 and 831-832
// (presumably the #endif lines closing the #ifdef on 822 and 829), 838,
// 845-846, 849, 851 and 859 onward are missing.
817 /** Compare objects syntactically to establish canonical ordering.
818 * All compare functions return: -1 for *this less than other, 0 equal,
820 int basic::compare(const basic & other) const
822 #ifdef GINAC_COMPARE_STATISTICS
823 compare_statistics.total_basic_compares++;
825 const unsigned hash_this = gethash();
826 const unsigned hash_other = other.gethash();
827 if (hash_this<hash_other) return -1;
828 if (hash_this>hash_other) return 1;
829 #ifdef GINAC_COMPARE_STATISTICS
830 compare_statistics.compare_same_hashvalue++;
833 const unsigned typeid_this = tinfo();
834 const unsigned typeid_other = other.tinfo();
835 if (typeid_this==typeid_other) {
836 GINAC_ASSERT(typeid(*this)==typeid(other));
837 // int cmpval = compare_same_type(other);
839 // std::cout << "hash collision, same type: "
840 // << *this << " and " << other << std::endl;
841 // this->print(print_tree(std::cout));
842 // std::cout << " and ";
843 // other.print(print_tree(std::cout));
844 // std::cout << std::endl;
847 #ifdef GINAC_COMPARE_STATISTICS
848 compare_statistics.compare_same_type++;
850 return compare_same_type(other);
852 // std::cout << "hash collision, different types: "
853 // << *this << " and " << other << std::endl;
854 // this->print(print_tree(std::cout));
855 // std::cout << " and ";
856 // other.print(print_tree(std::cout));
857 // std::cout << std::endl;
858 return (typeid_this<typeid_other ? -1 : 1);
// Syntactic equality: different hashes or different tinfo mean "not equal";
// otherwise the test is deferred to is_equal_same_type(). Statistics are
// counted when GINAC_COMPARE_STATISTICS is defined.
// NOTE(review): lines 866, 869, 872, 874, 877, 879-880, 882, 885 and 887
// onward are missing. Among them are presumably the `return false;`
// statements for the checks on lines 873 and 878 and the #endif lines. As
// shown, the `if` on 873 would govern the statistics lines instead —
// restore from upstream.
862 /** Test for syntactic equality.
863 * This is only a quick test, meaning objects should be in the same domain.
864 * You might have to .expand(), .normal() objects first, depending on the
865 * domain of your computation, to get a more reliable answer.
867 * @see is_equal_same_type */
868 bool basic::is_equal(const basic & other) const
870 #ifdef GINAC_COMPARE_STATISTICS
871 compare_statistics.total_basic_is_equals++;
873 if (this->gethash()!=other.gethash())
875 #ifdef GINAC_COMPARE_STATISTICS
876 compare_statistics.is_equal_same_hashvalue++;
878 if (this->tinfo()!=other.tinfo())
881 GINAC_ASSERT(typeid(*this)==typeid(other));
883 #ifdef GINAC_COMPARE_STATISTICS
884 compare_statistics.is_equal_same_type++;
886 return is_equal_same_type(other);
891 /** Stop further evaluation.
893 * @see basic::eval */
894 const basic & basic::hold() const
896 return setflag(status_flags::evaluated);
899 /** Ensure the object may be modified without hurting others, throws if this
900 * is not the case. */
901 void basic::ensure_if_modifiable() const
903 if (get_refcount() > 1)
904 throw(std::runtime_error("cannot modify multiply referenced object"));
905 clearflag(status_flags::hash_calculated | status_flags::evaluated);
/** Recursion limit for numeric evaluation: basic::evalf() throws
 *  std::runtime_error("max recursion level reached") once its level argument
 *  reaches -max_recursion_level. */
int max_recursion_level = 1024;
915 #ifdef GINAC_COMPARE_STATISTICS
916 compare_statistics_t::~compare_statistics_t()
918 std::clog << "ex::compare() called " << total_compares << " times" << std::endl;
919 std::clog << "nontrivial compares: " << nontrivial_compares << " times" << std::endl;
920 std::clog << "basic::compare() called " << total_basic_compares << " times" << std::endl;
921 std::clog << "same hashvalue in compare(): " << compare_same_hashvalue << " times" << std::endl;
922 std::clog << "compare_same_type() called " << compare_same_type << " times" << std::endl;
923 std::clog << std::endl;
924 std::clog << "ex::is_equal() called " << total_is_equals << " times" << std::endl;
925 std::clog << "nontrivial is_equals: " << nontrivial_is_equals << " times" << std::endl;
926 std::clog << "basic::is_equal() called " << total_basic_is_equals << " times" << std::endl;
927 std::clog << "same hashvalue in is_equal(): " << is_equal_same_hashvalue << " times" << std::endl;
928 std::clog << "is_equal_same_type() called " << is_equal_same_type << " times" << std::endl;
929 std::clog << std::endl;
930 std::clog << "basic::gethash() called " << total_gethash << " times" << std::endl;
931 std::clog << "used cached hashvalue " << gethash_cached << " times" << std::endl;
// Global statistics object, incremented by basic::compare() and
// basic::is_equal(); its destructor prints the counters to std::clog.
// NOTE(review): the #ifdef GINAC_COMPARE_STATISTICS opened on line 915 has no
// visible #endif (the numbering stops at 934); presumably it follows on a line
// not included in this copy — confirm.
934 compare_statistics_t compare_statistics;