- The default implementations of evalf(), diff(), normal() and expand() use
authorChristian Bauer <Christian.Bauer@uni-mainz.de>
Tue, 19 Jun 2001 19:50:02 +0000 (19:50 +0000)
committerChristian Bauer <Christian.Bauer@uni-mainz.de>
Tue, 19 Jun 2001 19:50:02 +0000 (19:50 +0000)
  map() where useful. This has the nice effect of having a more reasonable
  default behaviour for container functions (most of the evalfchildren() etc.
  stuff is gone).
- diff() works with non-commutative products (product rule) and no longer
  bails on indexed objects.
- added decomp_rational()
- added sqrfree_parfrac() which doesn't yet work in the general case and
  is unsupported

21 files changed:
ginac/add.cpp
ginac/basic.cpp
ginac/basic.h
ginac/constant.cpp
ginac/container.pl
ginac/expairseq.cpp
ginac/expairseq.h
ginac/function.pl
ginac/idx.cpp
ginac/idx.h
ginac/indexed.cpp
ginac/indexed.h
ginac/matrix.cpp
ginac/matrix.h
ginac/ncmul.cpp
ginac/normal.cpp
ginac/normal.h
ginac/relational.cpp
ginac/relational.h
ginac/symbol.cpp
ginac/symbol.h

index 647de17..31e2f8f 100644 (file)
@@ -381,10 +381,19 @@ ex add::simplify_ncmul(const exvector & v) const
 
 /** Implementation of ex::diff() for a sum. It differentiates each term.
  *  @see ex::diff */
-ex add::derivative(const symbol & s) const
+ex add::derivative(const symbol & y) const
 {
-       // D(a+b+c)=D(a)+D(b)+D(c)
-       return (new add(diffchildren(s)))->setflag(status_flags::dynallocated);
+       epvector *s = new epvector();
+       s->reserve(seq.size());
+       
+       // Only differentiate the "rest" parts of the expairs. This is faster
+       // than the default implementation in basic::derivative() although
+       // if performs the same function (differentiate each term).
+       for (epvector::const_iterator it=seq.begin(); it!=seq.end(); ++it) {
+               s->push_back(combine_ex_with_coeff_to_pair((*it).rest.diff(y),
+                                                          (*it).coeff));
+       }
+       return (new add(s, _ex0()))->setflag(status_flags::dynallocated);
 }
 
 int add::compare_same_type(const basic & other) const
@@ -492,10 +501,9 @@ ex add::expand(unsigned options) const
                return *this;
        
        epvector * vp = expandchildren(options);
-       if (vp==0) {
+       if (vp == NULL) {
                // the terms have not changed, so it is safe to declare this expanded
-               setflag(status_flags::expanded);
-               return *this;
+               return this->setflag(status_flags::expanded);
        }
        
        return (new add(vp,overall_coeff))->setflag(status_flags::expanded | status_flags::dynallocated);
index 92bd160..70f84ea 100644 (file)
@@ -240,7 +240,7 @@ ex basic::map(map_function & f) const
 
        basic *copy = duplicate();
        copy->setflag(status_flags::dynallocated);
-       copy->clearflag(status_flags::hash_calculated);
+       copy->clearflag(status_flags::hash_calculated | status_flags::expanded);
        ex e(*copy);
        for (unsigned i=0; i<num; i++)
                e.let_op(i) = f(e.op(i));
@@ -356,24 +356,42 @@ ex basic::eval(int level) const
        return this->hold();
 }
 
+/** Function object to be applied by basic::evalf(). */
+struct evalf_map_function : public map_function {
+       int level;
+       evalf_map_function(int l) : level(l) {}
+       ex operator()(const ex & e) { return evalf(e, level); }
+};
+
 /** Evaluate object numerically. */
 ex basic::evalf(int level) const
 {
-       // There is nothing to do for basic objects:
-       return *this;
+       if (nops() == 0)
+               return *this;
+       else {
+               if (level == 1)
+                       return *this;
+               else if (level == -max_recursion_level)
+                       throw(std::runtime_error("max recursion level reached"));
+               else {
+                       evalf_map_function map_evalf(level - 1);
+                       return map(map_evalf);
+               }
+       }
 }
 
 /** Function object to be applied by basic::evalm(). */
 struct evalm_map_function : public map_function {
-       ex operator()(const ex & e) { return GiNaC::evalm(e); }
-} fcn;
+       ex operator()(const ex & e) { return evalm(e); }
+} map_evalm;
+
 /** Evaluate sums, products and integer powers of matrices. */
 ex basic::evalm(void) const
 {
        if (nops() == 0)
                return *this;
        else
-               return map(fcn);
+               return map(map_evalm);
 }
 
 /** Perform automatic symbolic evaluations on indexed expression that
@@ -548,13 +566,25 @@ ex basic::simplify_ncmul(const exvector & v) const
 
 // protected
 
-/** Default implementation of ex::diff(). It simply throws an error message.
+/** Function object to be applied by basic::derivative(). */
+struct derivative_map_function : public map_function {
+       const symbol &s;
+       derivative_map_function(const symbol &sym) : s(sym) {}
+       ex operator()(const ex & e) { return diff(e, s); }
+};
+
+/** Default implementation of ex::diff(). It maps the operation on the
+ *  operands (or returns 0 when the object has no operands).
  *
- *  @exception logic_error (differentiation not supported by this type)
  *  @see ex::diff */
 ex basic::derivative(const symbol & s) const
 {
-       throw(std::logic_error("differentiation not supported by this type"));
+       if (nops() == 0)
+               return _ex0();
+       else {
+               derivative_map_function map_derivative(s);
+               return map(map_derivative);
+       }
 }
 
 /** Returns order relation between two objects of same type.  This needs to be
@@ -613,11 +643,23 @@ unsigned basic::calchash(void) const
        return v;
 }
 
+/** Function object to be applied by basic::expand(). */
+struct expand_map_function : public map_function {
+       unsigned options;
+       expand_map_function(unsigned o) : options(o) {}
+       ex operator()(const ex & e) { return expand(e, options); }
+};
+
 /** Expand expression, i.e. multiply it out and return the result as a new
  *  expression. */
 ex basic::expand(unsigned options) const
 {
-       return this->setflag(status_flags::expanded);
+       if (nops() == 0)
+               return this->setflag(status_flags::expanded);
+       else {
+               expand_map_function map_expand(options);
+               return map(map_expand).bp->setflag(status_flags::expanded);
+       }
 }
 
 
index f53a4e3..84ed992 100644 (file)
@@ -112,6 +112,7 @@ public: // only const functions please (may break reference counting)
        virtual ex & let_op(int i);
        virtual ex operator[](const ex & index) const;
        virtual ex operator[](int i) const;
+       virtual ex expand(unsigned options = 0) const;
        virtual bool has(const ex & other) const;
        virtual ex map(map_function & f) const;
        virtual int degree(const ex & s) const;
@@ -130,19 +131,18 @@ public: // only const functions please (may break reference counting)
        virtual ex smod(const numeric &xi) const;
        virtual numeric max_coefficient(void) const;
        virtual exvector get_free_indices(void) const;
-       virtual ex simplify_ncmul(const exvector & v) const;
        virtual ex eval_indexed(const basic & i) const;
        virtual ex add_indexed(const ex & self, const ex & other) const;
        virtual ex scalar_mul_indexed(const ex & self, const numeric & other) const;
        virtual bool contract_with(exvector::iterator self, exvector::iterator other, exvector & v) const;
-protected: // non-const functions should be called from class ex only
+       virtual unsigned return_type(void) const;
+       virtual unsigned return_type_tinfo(void) const;
+protected: // functions that should be called from class ex only
        virtual ex derivative(const symbol & s) const;
        virtual int compare_same_type(const basic & other) const;
        virtual bool is_equal_same_type(const basic & other) const;
-       virtual unsigned return_type(void) const;
-       virtual unsigned return_type_tinfo(void) const;
        virtual unsigned calchash(void) const;
-       virtual ex expand(unsigned options = 0) const;
+       virtual ex simplify_ncmul(const exvector & v) const;
        
        // non-virtual functions in this class
 public:
index ba8b151..2292385 100644 (file)
@@ -83,7 +83,7 @@ constant::constant(const std::string & initname, evalffunctype efun, const std::
                TeX_name = "\\mbox{" + name + "}";
        else
                TeX_name = texname;
-       setflag(status_flags::evaluated);
+       setflag(status_flags::evaluated | status_flags::expanded);
 }
 
 constant::constant(const std::string & initname, const numeric & initnumber, const std::string & texname)
@@ -94,7 +94,7 @@ constant::constant(const std::string & initname, const numeric & initnumber, con
                TeX_name = "\\mbox{" + name + "}";
        else
                TeX_name = texname;
-       setflag(status_flags::evaluated);
+       setflag(status_flags::evaluated | status_flags::expanded);
 }
 
 //////////
@@ -180,7 +180,7 @@ ex constant::evalf(int level) const
 
 // protected
 
-/** Implementation of ex::diff() for a constant. It always returns 0.
+/** Implementation of ex::diff() for a constant always returns 0.
  *
  *  @see ex::diff */
 ex constant::derivative(const symbol & s) const
index 6db886c..7b8304b 100755 (executable)
@@ -211,11 +211,7 @@ public:
        unsigned nops() const;
        ex & let_op(int i);
        ex map(map_function & f) const;
-       ex expand(unsigned options=0) const;
        ex eval(int level=0) const;
-       ex evalf(int level=0) const;
-       ex normal(lst &sym_lst, lst &repl_lst, int level=0) const;
-       ex derivative(const symbol & s) const;
        ex subs(const lst & ls, const lst & lr, bool no_pattern = false) const;
 protected:
        bool is_equal_same_type(const basic & other) const;
@@ -236,9 +232,6 @@ protected:
 protected:
        bool is_canonical() const;
        ${STLT} evalchildren(int level) const;
-       ${STLT} evalfchildren(int level) const;
-       ${STLT} normalchildren(int level) const;
-       ${STLT} diffchildren(const symbol & s) const;
        ${STLT} * subschildren(const lst & ls, const lst & lr, bool no_pattern = false) const;
 
 protected:
@@ -452,17 +445,6 @@ ex ${CONTAINER}::map(map_function & f) const
        return this${CONTAINER}(s);
 }
 
-ex ${CONTAINER}::expand(unsigned options) const
-{
-       ${STLT} s;
-       RESERVE(s,seq.size());
-       for (${STLT}::const_iterator it=seq.begin(); it!=seq.end(); ++it) {
-               s.push_back((*it).expand(options));
-       }
-
-       return this${CONTAINER}(s);
-}
-
 ex ${CONTAINER}::eval(int level) const
 {
        if (level==1) {
@@ -471,25 +453,6 @@ ex ${CONTAINER}::eval(int level) const
        return this${CONTAINER}(evalchildren(level));
 }
 
-ex ${CONTAINER}::evalf(int level) const
-{
-       return this${CONTAINER}(evalfchildren(level));
-}
-
-/** Implementation of ex::normal() for ${CONTAINER}s. It normalizes the arguments
- *  and replaces the ${CONTAINER} by a temporary symbol.
- *  \@see ex::normal */
-ex ${CONTAINER}::normal(lst &sym_lst, lst &repl_lst, int level) const
-{
-       ex n=this${CONTAINER}(normalchildren(level));
-       return n.bp->basic::normal(sym_lst,repl_lst,level);
-}
-
-ex ${CONTAINER}::derivative(const symbol & s) const
-{
-       return this${CONTAINER}(diffchildren(s));
-}
-
 ex ${CONTAINER}::subs(const lst & ls, const lst & lr, bool no_pattern) const
 {
        ${STLT} *vp = subschildren(ls, lr, no_pattern);
@@ -645,52 +608,6 @@ ${STLT} ${CONTAINER}::evalchildren(int level) const
        return s;
 }
 
-${STLT} ${CONTAINER}::evalfchildren(int level) const
-{
-       ${STLT} s;
-       RESERVE(s,seq.size());
-
-       if (level==1) {
-               return seq;
-       }
-       if (level == -max_recursion_level) {
-               throw(std::runtime_error("max recursion level reached"));
-       }
-       --level;
-       for (${STLT}::const_iterator it=seq.begin(); it!=seq.end(); ++it) {
-               s.push_back((*it).evalf(level));
-       }
-       return s;
-}
-
-${STLT} ${CONTAINER}::normalchildren(int level) const
-{
-       ${STLT} s;
-       RESERVE(s,seq.size());
-
-       if (level==1) {
-               return seq;
-       }
-       if (level == -max_recursion_level) {
-               throw(std::runtime_error("max recursion level reached"));
-       }
-       --level;
-       for (${STLT}::const_iterator it=seq.begin(); it!=seq.end(); ++it) {
-               s.push_back((*it).normal(level));
-       }
-       return s;
-}
-
-${STLT} ${CONTAINER}::diffchildren(const symbol & y) const
-{
-       ${STLT} s;
-       RESERVE(s,seq.size());
-       for (${STLT}::const_iterator it=seq.begin(); it!=seq.end(); ++it) {
-               s.push_back((*it).diff(y));
-       }
-       return s;
-}
-
 ${STLT} * ${CONTAINER}::subschildren(const lst & ls, const lst & lr, bool no_pattern) const
 {
        // returns a NULL pointer if nothing had to be substituted
index 86f687d..605494a 100644 (file)
@@ -325,17 +325,6 @@ ex expairseq::eval(int level) const
        return (new expairseq(vp,overall_coeff))->setflag(status_flags::dynallocated | status_flags::evaluated);
 }
 
-ex expairseq::evalf(int level) const
-{
-       return thisexpairseq(evalfchildren(level),overall_coeff.evalf(level-1));
-}
-
-ex expairseq::normal(lst &sym_lst, lst &repl_lst, int level) const
-{
-       ex n = thisexpairseq(normalchildren(level),overall_coeff);
-       return n.bp->basic::normal(sym_lst,repl_lst,level);
-}
-
 bool expairseq::match(const ex & pattern, lst & repl_lst) const
 {
        // This differs from basic::match() because we want "a+b+c+d" to
@@ -420,14 +409,6 @@ ex expairseq::subs(const lst &ls, const lst &lr, bool no_pattern) const
 
 // protected
 
-/** Implementation of ex::diff() for an expairseq.
- *  It differentiates all elements of the sequence.
- *  @see ex::diff */
-ex expairseq::derivative(const symbol &s) const
-{
-       return thisexpairseq(diffchildren(s),overall_coeff);
-}
-
 int expairseq::compare_same_type(const basic &other) const
 {
        GINAC_ASSERT(is_of_type(other, expairseq));
@@ -592,13 +573,11 @@ unsigned expairseq::calchash(void) const
 ex expairseq::expand(unsigned options) const
 {
        epvector *vp = expandchildren(options);
-       if (vp==0) {
-               // the terms have not changed, so it is safe to declare this expanded
-               setflag(status_flags::expanded);
-               return *this;
-       }
-       
-       return thisexpairseq(vp,overall_coeff);
+       if (vp == NULL) {
+               // The terms have not changed, so it is safe to declare this expanded
+               return this->setflag(status_flags::expanded);
+       } else
+               return thisexpairseq(vp, overall_coeff);
 }
 
 //////////
@@ -1604,71 +1583,6 @@ epvector * expairseq::evalchildren(int level) const
 }
 
 
-/** Member-wise evaluate numerically all expairs in this sequence.
- *
- *  @see expairseq::evalf()
- *  @return epvector with all entries evaluated numerically. */
-epvector expairseq::evalfchildren(int level) const
-{
-       if (level==1)
-               return seq;
-       
-       if (level==-max_recursion_level)
-               throw(std::runtime_error("max recursion level reached"));
-       
-       epvector s;
-       s.reserve(seq.size());
-       
-       --level;
-       for (epvector::const_iterator it=seq.begin(); it!=seq.end(); ++it) {
-               s.push_back(combine_ex_with_coeff_to_pair((*it).rest.evalf(level),
-                                                         (*it).coeff.evalf(level)));
-       }
-       return s;
-}
-
-
-/** Member-wise normalize all expairs in this sequence.
- *
- *  @see expairseq::normal()
- *  @return epvector with all entries normalized. */
-epvector expairseq::normalchildren(int level) const
-{
-       if (level==1)
-               return seq;
-       
-       if (level==-max_recursion_level)
-               throw(std::runtime_error("max recursion level reached"));
-       
-       epvector s;
-       s.reserve(seq.size());
-       
-       --level;
-       for (epvector::const_iterator it=seq.begin(); it!=seq.end(); ++it) {
-               s.push_back(combine_ex_with_coeff_to_pair((*it).rest.normal(level),
-                                                         (*it).coeff));
-       }
-       return s;
-}
-
-
-/** Member-wise differentiate all expairs in this sequence.
- *
- *  @see expairseq::diff()
- *  @return epvector with all entries differentiated. */
-epvector expairseq::diffchildren(const symbol &y) const
-{
-       epvector s;
-       s.reserve(seq.size());
-       
-       for (epvector::const_iterator it=seq.begin(); it!=seq.end(); ++it) {
-               s.push_back(combine_ex_with_coeff_to_pair((*it).rest.diff(y),
-                                                         (*it).coeff));
-       }
-       return s;
-}
-
-
 /** Member-wise substitute in this sequence.
  *
  *  @see expairseq::subs()
index 8851a27..cebc016 100644 (file)
@@ -95,13 +95,10 @@ public:
        ex & let_op(int i);
        ex map(map_function & f) const;
        ex eval(int level=0) const;
-       ex evalf(int level=0) const;
-       ex normal(lst &sym_lst, lst &repl_lst, int level=0) const;
        ex to_rational(lst &repl_lst) const;
        bool match(const ex & pattern, lst & repl_lst) const;
        ex subs(const lst & ls, const lst & lr, bool no_pattern = false) const;
 protected:
-       ex derivative(const symbol & s) const;
        int compare_same_type(const basic & other) const;
        bool is_equal_same_type(const basic & other) const;
        unsigned return_type(void) const;
@@ -167,9 +164,6 @@ protected:
        bool is_canonical() const;
        epvector * expandchildren(unsigned options) const;
        epvector * evalchildren(int level) const;
-       epvector evalfchildren(int level) const;
-       epvector normalchildren(int level) const;
-       epvector diffchildren(const symbol & s) const;
        epvector * subschildren(const lst & ls, const lst & lr, bool no_pattern = false) const;
        
 // member variables
index c3dbcd6..33fc505 100755 (executable)
@@ -773,7 +773,20 @@ ex function::evalf(int level) const
 {
        GINAC_ASSERT(serial<registered_functions().size());
 
-       exvector eseq=evalfchildren(level);
+       // Evaluate children first
+       exvector eseq;
+       if (level == 1)
+               eseq = seq;
+       else if (level == -max_recursion_level)
+               throw(std::runtime_error("max recursion level reached"));
+       else
+               eseq.reserve(seq.size());
+       level--;
+       exvector::const_iterator it = seq.begin(), itend = seq.end();
+       while (it != itend) {
+               eseq.push_back((*it).evalf(level));
+               it++;
+       }
        
        if (registered_functions()[serial].evalf_f==0) {
                return function(serial,eseq).hold();
index df79607..67249e5 100644 (file)
@@ -316,6 +316,13 @@ int spinidx::compare_same_type(const basic & other) const
        return 0;
 }
 
+/** By default, basic::evalf would evaluate the index value but we don't want
+ *  a.1 to become a.(1.0). */
+ex idx::evalf(int level) const
+{
+       return *this;
+}
+
 bool idx::match(const ex & pattern, lst & repl_lst) const
 {
        if (!is_ex_of_type(pattern, idx))
@@ -377,6 +384,14 @@ ex idx::subs(const lst & ls, const lst & lr, bool no_pattern) const
        return i_copy->setflag(status_flags::dynallocated);
 }
 
+/** Implementation of ex::diff() for an index always returns 0.
+ *
+ *  @see ex::diff */
+ex idx::derivative(const symbol & s) const
+{
+       return _ex0();
+}
+
 //////////
 // new virtual functions
 //////////
index 79b9209..1c7bb9c 100644 (file)
@@ -51,9 +51,13 @@ public:
        bool info(unsigned inf) const;
        unsigned nops() const;
        ex & let_op(int i);
+       ex evalf(int level = 0) const;
        bool match(const ex & pattern, lst & repl_lst) const;
        ex subs(const lst & ls, const lst & lr, bool no_pattern = false) const;
 
+protected:
+       ex derivative(const symbol & s) const;
+
        // new virtual functions in this class
 public:
        /** Check whether the index forms a dummy index pair with another index
index 459caa9..4bbeeee 100644 (file)
@@ -411,6 +411,14 @@ void indexed::validate(void) const
        }
 }
 
+/** Implementation of ex::diff() for an indexed object always returns 0.
+ *
+ *  @see ex::diff */
+ex indexed::derivative(const symbol & s) const
+{
+       return _ex0();
+}
+
 //////////
 // global functions
 //////////
@@ -557,7 +565,8 @@ static ex rename_dummy_indices(const ex & e, exvector & global_dummy_indices, ex
        for (unsigned i=0; i<local_size; i++) {
                ex loc_sym = local_dummy_indices[i].op(0);
                ex glob_sym = global_dummy_indices[i].op(0);
-               if (!loc_sym.is_equal(glob_sym)) {
+               if (!loc_sym.is_equal(glob_sym)
+                && ex_to<idx>(local_dummy_indices[i]).get_dim().is_equal(ex_to<idx>(global_dummy_indices[i]).get_dim())) {
                        all_equal = false;
                        local_syms.append(loc_sym);
                        global_syms.append(glob_sym);
index 363bbe8..b5a0dce 100644 (file)
@@ -151,6 +151,7 @@ public:
        exvector get_free_indices(void) const;
 
 protected:
+       ex derivative(const symbol & s) const;
        ex thisexprseq(const exvector & v) const;
        ex thisexprseq(exvector * vp) const;
        unsigned return_type(void) const { return return_types::commutative; }
index 965e1da..9f9f67a 100644 (file)
@@ -198,16 +198,6 @@ ex & matrix::let_op(int i)
        return m[i];
 }
 
-/** expands the elements of a matrix entry by entry. */
-ex matrix::expand(unsigned options) const
-{
-       exvector tmp(row*col);
-       for (unsigned i=0; i<row*col; ++i)
-               tmp[i] = m[i].expand(options);
-       
-       return matrix(row, col, tmp);
-}
-
 /** Evaluate matrix entry by entry. */
 ex matrix::eval(int level) const
 {
@@ -232,30 +222,6 @@ ex matrix::eval(int level) const
                                                                                           status_flags::evaluated );
 }
 
-/** Evaluate matrix numerically entry by entry. */
-ex matrix::evalf(int level) const
-{
-       debugmsg("matrix evalf",LOGLEVEL_MEMBER_FUNCTION);
-               
-       // check if we have to do anything at all
-       if (level==1)
-               return *this;
-       
-       // emergency break
-       if (level == -max_recursion_level) {
-               throw (std::runtime_error("matrix::evalf(): recursion limit exceeded"));
-       }
-       
-       // evalf() entry by entry
-       exvector m2(row*col);
-       --level;
-       for (unsigned r=0; r<row; ++r)
-               for (unsigned c=0; c<col; ++c)
-                       m2[r*col+c] = m[r*col+c].evalf(level);
-       
-       return matrix(row, col, m2);
-}
-
 ex matrix::subs(const lst & ls, const lst & lr, bool no_pattern) const
 {
        exvector m2(row * col);
index b7f3f1c..fcf2d63 100644 (file)
@@ -46,9 +46,7 @@ public:
        unsigned nops() const;
        ex op(int i) const;
        ex & let_op(int i);
-       ex expand(unsigned options=0) const;
        ex eval(int level=0) const;
-       ex evalf(int level=0) const;
        ex evalm(void) const {return *this;}
        ex subs(const lst & ls, const lst & lr, bool no_pattern = false) const;
        ex eval_indexed(const basic & i) const;
index e4a9186..dbd5732 100644 (file)
@@ -475,11 +475,21 @@ ex ncmul::thisexprseq(exvector * vp) const
 
 // protected
 
-/** Implementation of ex::diff() for a non-commutative product. It always returns 0.
+/** Implementation of ex::diff() for a non-commutative product. It applies
+ *  the product rule.
  *  @see ex::diff */
 ex ncmul::derivative(const symbol & s) const
 {
-       return _ex0();
+       exvector addseq;
+       addseq.reserve(seq.size());
+       
+       // D(a*b*c) = D(a)*b*c + a*D(b)*c + a*b*D(c)
+       for (unsigned i=0; i!=seq.size(); ++i) {
+               exvector ncmulseq = seq;
+               ncmulseq[i] = seq[i].diff(s);
+               addseq.push_back((new ncmul(ncmulseq))->setflag(status_flags::dynallocated));
+       }
+       return (new add(addseq))->setflag(status_flags::dynallocated);
 }
 
 int ncmul::compare_same_type(const basic & other) const
index 18e24e9..619a232 100644 (file)
@@ -39,6 +39,7 @@
 #include "numeric.h"
 #include "power.h"
 #include "relational.h"
+#include "matrix.h"
 #include "pseries.h"
 #include "symbol.h"
 #include "utils.h"
@@ -415,7 +416,7 @@ ex rem(const ex &a, const ex &b, const symbol &x, bool check_args)
                if  (is_ex_exactly_of_type(b, numeric))
                        return _ex0();
                else
-                       return b;
+                       return a;
        }
 #if FAST_COMPARE
        if (a.is_equal(b))
@@ -450,6 +451,24 @@ ex rem(const ex &a, const ex &b, const symbol &x, bool check_args)
 }
 
 
+/** Decompose rational function a(x)=N(x)/D(x) into P(x)+n(x)/D(x)
+ *  with degree(n, x) < degree(D, x).
+ *
+ *  @param a rational function in x
+ *  @param x a is a function of x
+ *  @return decomposed function. */
+ex decomp_rational(const ex &a, const symbol &x)
+{
+       ex nd = numer_denom(a);
+       ex numer = nd.op(0), denom = nd.op(1);
+       ex q = quo(numer, denom, x);
+       if (is_ex_exactly_of_type(q, fail))
+               return a;
+       else
+               return q + rem(numer, denom, x) / denom;
+}
+
+
 /** Pseudo-remainder of polynomials a(x) and b(x) in Z[x].
  *
  *  @param a  first polynomial in x (dividend)
@@ -1717,6 +1736,7 @@ static exvector sqrfree_yun(const ex &a, const symbol &x)
        } while (!z.is_zero());
        return res;
 }
+
 /** Compute square-free factorization of multivariate polynomial in Q[X].
  *
  *  @param a  multivariate polynomial over Q[X]
@@ -1769,6 +1789,75 @@ ex sqrfree(const ex &a, const lst &l)
        return result * lcm.inverse();
 }
 
+/** Compute square-free partial fraction decomposition of rational function
+ *  a(x).
+ *
+ *  @param a rational function over Z[x], treated as univariate polynomial
+ *           in x
+ *  @param x variable to factor in
+ *  @return decomposed rational function */
+ex sqrfree_parfrac(const ex & a, const symbol & x)
+{
+       // Find numerator and denominator
+       ex nd = numer_denom(a);
+       ex numer = nd.op(0), denom = nd.op(1);
+//clog << "numer = " << numer << ", denom = " << denom << endl;
+
+       // Convert N(x)/D(x) -> Q(x) + R(x)/D(x), so degree(R) < degree(D)
+       ex red_poly = quo(numer, denom, x), red_numer = rem(numer, denom, x).expand();
+//clog << "red_poly = " << red_poly << ", red_numer = " << red_numer << endl;
+
+       // Factorize denominator and compute cofactors
+       exvector yun = sqrfree_yun(denom, x);
+//clog << "yun factors: " << exprseq(yun) << endl;
+       int num_yun = yun.size();
+       exvector factor; factor.reserve(num_yun);
+       exvector cofac; cofac.reserve(num_yun);
+       for (unsigned i=0; i<num_yun; i++) {
+               if (!yun[i].is_equal(_ex1())) {
+                       for (unsigned j=0; j<=i; j++) {
+                               factor.push_back(pow(yun[i], j+1));
+                               ex prod = 1;
+                               for (unsigned k=0; k<num_yun; k++) {
+                                       if (k == i)
+                                               prod *= pow(yun[k], i-j);
+                                       else
+                                               prod *= pow(yun[k], k+1);
+                               }
+                               cofac.push_back(prod.expand());
+                       }
+               }
+       }
+       int num_factors = factor.size();
+//clog << "factors  : " << exprseq(factor) << endl;
+//clog << "cofactors: " << exprseq(cofac) << endl;
+
+       // Construct coefficient matrix for decomposition
+       int max_denom_deg = denom.degree(x);
+       matrix sys(max_denom_deg + 1, num_factors);
+       matrix rhs(max_denom_deg + 1, 1);
+       for (unsigned i=0; i<=max_denom_deg; i++) {
+               for (unsigned j=0; j<num_factors; j++)
+                       sys(i, j) = cofac[j].coeff(x, i);
+               rhs(i, 0) = red_numer.coeff(x, i);
+       }
+//clog << "coeffs: " << sys << endl;
+//clog << "rhs   : " << rhs << endl;
+
+       // Solve resulting linear system
+       matrix vars(num_factors, 1);
+       for (unsigned i=0; i<num_factors; i++)
+               vars(i, 0) = symbol();
+       matrix sol = sys.solve(vars, rhs);
+
+       // Sum up decomposed fractions
+       ex sum = 0;
+       for (unsigned i=0; i<num_factors; i++)
+               sum += sol(i, 0) / factor[i];
+
+       return red_poly + sum;
+}
+
 
 /*
  *  Normal form of rational functions
@@ -1782,6 +1871,7 @@ ex sqrfree(const ex &a, const lst &l)
  *  the information that (a+b) is the numerator and 3 is the denominator.
  */
 
+
 /** Create a symbol for replacing the expression "e" (or return a previously
  *  assigned symbol). The symbol is appended to sym_lst and returned, the
  *  expression is appended to repl_lst.
@@ -1825,12 +1915,31 @@ static ex replace_with_symbol(const ex &e, lst &repl_lst)
        return es;
 }
 
-/** Default implementation of ex::normal(). It replaces the object with a
- *  temporary symbol.
+
+/** Function object to be applied by basic::normal(). */
+struct normal_map_function : public map_function {
+       int level;
+       normal_map_function(int l) : level(l) {}
+       ex operator()(const ex & e) { return normal(e, level); }
+};
+
+/** Default implementation of ex::normal(). It normalizes the children and
+ *  replaces the object with a temporary symbol.
  *  @see ex::normal */
 ex basic::normal(lst &sym_lst, lst &repl_lst, int level) const
 {
-       return (new lst(replace_with_symbol(*this, sym_lst, repl_lst), _ex1()))->setflag(status_flags::dynallocated);
+       if (nops() == 0)
+               return (new lst(replace_with_symbol(*this, sym_lst, repl_lst), _ex1()))->setflag(status_flags::dynallocated);
+       else {
+               if (level == 1)
+                       return (new lst(replace_with_symbol(*this, sym_lst, repl_lst), _ex1()))->setflag(status_flags::dynallocated);
+               else if (level == -max_recursion_level)
+                       throw(std::runtime_error("max recursion level reached"));
+               else {
+                       normal_map_function map_normal(level - 1);
+                       return (new lst(replace_with_symbol(map(map_normal), sym_lst, repl_lst), _ex1()))->setflag(status_flags::dynallocated);
+               }
+       }
 }
 
 
@@ -2083,14 +2192,6 @@ ex pseries::normal(lst &sym_lst, lst &repl_lst, int level) const
 }
 
 
-/** Implementation of ex::normal() for relationals. It normalizes both sides.
- *  @see ex::normal */
-ex relational::normal(lst &sym_lst, lst &repl_lst, int level) const
-{
-       return (new lst(relational(lh.normal(), rh.normal(), o), _ex1()))->setflag(status_flags::dynallocated);
-}
-
-
 /** Normalization of rational functions.
  *  This function converts an expression to its normal form
  *  "numerator/denominator", where numerator and denominator are (relatively
index fb6960f..133addf 100644 (file)
@@ -39,6 +39,9 @@ extern ex quo(const ex &a, const ex &b, const symbol &x, bool check_args = true)
 // Remainder r(x) of polynomials a(x) and b(x) in Q[x], so that a(x)=b(x)*q(x)+r(x)
 extern ex rem(const ex &a, const ex &b, const symbol &x, bool check_args = true);
 
+// Decompose rational function a(x)=N(x)/D(x) into Q(x)+R(x)/D(x) with degree(R, x) < degree(D, x)
+extern ex decomp_rational(const ex &a, const symbol &x);
+
 // Pseudo-remainder of polynomials a(x) and b(x) in Z[x]
 extern ex prem(const ex &a, const ex &b, const symbol &x, bool check_args = true);
 
@@ -54,6 +57,9 @@ extern ex lcm(const ex &a, const ex &b, bool check_args = true);
 // Square-free factorization of a polynomial a(x)
 extern ex sqrfree(const ex &a, const lst &l = lst());
 
+// Square-free partial fraction decomposition of a rational function a(x)
+extern ex sqrfree_parfrac(const ex & a, const symbol & x);
+
 } // namespace GiNaC
 
 #endif // ndef __GINAC_NORMAL_H__
index 82b4bb9..0bc6c70 100644 (file)
@@ -183,17 +183,6 @@ ex relational::eval(int level) const
        return (new relational(lh.eval(level-1),rh.eval(level-1),o))->setflag(status_flags::dynallocated | status_flags::evaluated);
 }
 
-ex relational::evalf(int level) const
-{
-       if (level==1)
-               return *this;
-       
-       if (level==-max_recursion_level)
-               throw(std::runtime_error("max recursion level reached"));
-       
-       return (new relational(lh.eval(level-1),rh.eval(level-1),o))->setflag(status_flags::dynallocated);
-}
-
 ex relational::simplify_ncmul(const exvector & v) const
 {
        return lh.simplify_ncmul(v);
index a5188b5..0447e94 100644 (file)
@@ -57,8 +57,6 @@ public:
        unsigned nops() const;
        ex & let_op(int i);
        ex eval(int level=0) const;
-       ex evalf(int level=0) const;
-       ex normal(lst &sym_lst, lst &repl_lst, int level=0) const;
        ex simplify_ncmul(const exvector & v) const;
 protected:
        unsigned return_type(void) const;
index ea453a9..b15f9b7 100644 (file)
@@ -179,11 +179,6 @@ bool symbol::info(unsigned inf) const
                return inherited::info(inf);
 }
 
-ex symbol::expand(unsigned options) const
-{
-       return this->hold();
-}
-
 bool symbol::has(const ex & other) const
 {
        if (this->is_equal(*other.bp))
index 69d9b9e..9bbf3f2 100644 (file)
@@ -75,12 +75,12 @@ public:
        basic * duplicate() const;
        void print(const print_context & c, unsigned level = 0) const;
        bool info(unsigned inf) const;
-       ex expand(unsigned options = 0) const;
        bool has(const ex & other) const;
        int degree(const ex & s) const;
        int ldegree(const ex & s) const;
        ex coeff(const ex & s, int n = 1) const;
        ex eval(int level = 0) const;
+       ex evalf(int level = 0) const { return *this; } // overwrites basic::evalf() for performance reasons
        ex series(const relational & s, int order, unsigned options = 0) const;
        ex normal(lst &sym_lst, lst &repl_lst, int level = 0) const;
        ex to_rational(lst &repl_lst) const;