]> www.ginac.de Git - ginac.git/blobdiff - ginac/function.cppy
Make add::eval(), mul::eval() work without compromise.
[ginac.git] / ginac / function.cppy
index b8259f9810f260d8937a12d781ab4d25d4484651..789ca056c2dfa9dcebd7e0d86801b7a59e518ee9 100644 (file)
@@ -7,7 +7,7 @@
  *  Please do not modify it directly, edit function.cppy instead!
  *  function.py options: maxargs=@maxargs@
  *
- *  GiNaC Copyright (C) 1999-2014 Johannes Gutenberg University Mainz, Germany
+ *  GiNaC Copyright (C) 1999-2015 Johannes Gutenberg University Mainz, Germany
  *
  *  This program is free software; you can redistribute it and/or modify
  *  it under the terms of the GNU General Public License as published by
@@ -34,7 +34,6 @@
 #include "power.h"
 #include "archive.h"
 #include "inifcns.h"
-#include "tostring.h"
 #include "utils.h"
 #include "hash_seed.h"
 #include "remember.h"
@@ -79,7 +78,7 @@ void function_options::initialize()
        set_name("unnamed_function", "\\\\mbox{unnamed}");
        nparams = 0;
        eval_f = evalf_f = real_part_f = imag_part_f = conjugate_f = expand_f
-               = derivative_f = power_f = series_f = 0;
+               = derivative_f = expl_derivative_f = power_f = series_f = nullptr;
        info_f = 0;
        evalf_params_first = true;
        use_return_type = false;
@@ -90,6 +89,7 @@ void function_options::initialize()
        imag_part_use_exvector_args = false;
        expand_use_exvector_args = false;
        derivative_use_exvector_args = false;
+       expl_derivative_use_exvector_args = false;
        power_use_exvector_args = false;
        series_use_exvector_args = false;
        print_use_exvector_args = false;
@@ -228,7 +228,7 @@ function::function(unsigned ser) : serial(ser)
 // the following lines have been generated for max. @maxargs@ parameters
 +++ for N in range(1, maxargs + 1):
 function::function(unsigned ser, @seq('const ex & param%(n)d', N)@)
-       : exprseq(@seq('param%(n)d', N)@), serial(ser)
+       : exprseq{@seq('param%(n)d', N)@}, serial(ser)
 {
 }
 ---
@@ -241,13 +241,13 @@ function::function(unsigned ser, const exprseq & es) : exprseq(es), serial(ser)
        clearflag(status_flags::evaluated);
 }
 
-function::function(unsigned ser, const exvector & v, bool discardable) 
-  : exprseq(v,discardable), serial(ser)
+function::function(unsigned ser, const exvector & v)
+  : exprseq(v), serial(ser)
 {
 }
 
-function::function(unsigned ser, std::auto_ptr<exvector> vp) 
-  : exprseq(vp), serial(ser)
+function::function(unsigned ser, exvector && v)
+  : exprseq(std::move(v)), serial(ser)
 {
 }
 
@@ -263,13 +263,12 @@ void function::read_archive(const archive_node& n, lst& sym_lst)
        std::string s;
        if (n.find_string("name", s)) {
                unsigned int ser = 0;
-               std::vector<function_options>::const_iterator i = registered_functions().begin(), iend = registered_functions().end();
-               while (i != iend) {
-                       if (s == i->name) {
+               for (auto & it : registered_functions()) {
+                       if (s == it.name) {
                                serial = ser;
                                return;
                        }
-                       ++i; ++ser;
+                       ++ser;
                }
                throw (std::runtime_error("unknown function '" + s + "' in archive"));
        } else
@@ -303,7 +302,7 @@ void function::print(const print_context & c, unsigned level) const
 
 next_context:
        unsigned id = pc_info->options.get_id();
-       if (id >= pdt.size() || pdt[id] == NULL) {
+       if (id >= pdt.size() || pdt[id] == nullptr) {
 
                // Method not found, try parent print_context class
                const print_context_class_info *parent_pc_info = pc_info->get_parent();
@@ -365,6 +364,10 @@ next_context:
 
 ex function::eval(int level) const
 {
+       if ((level == 1) && (flags & status_flags::evaluated)) {
+               return *this;
+       }
+
        if (level>1) {
                // first evaluate children, then we will end up here again
                return function(serial,evalchildren(level));
@@ -386,7 +389,7 @@ ex function::eval(int level) const
                }
        }
 
-       if (opt.eval_f==0) {
+       if (opt.eval_f==nullptr) {
                return this->hold();
        }
 
@@ -430,10 +433,8 @@ ex function::evalf(int level) const
        else {
                eseq.reserve(seq.size());
                --level;
-               exvector::const_iterator it = seq.begin(), itend = seq.end();
-               while (it != itend) {
-                       eseq.push_back(it->evalf(level));
-                       ++it;
+               for (auto & it : seq) {
+                       eseq.push_back(it.evalf(level));
                }
        }
 
@@ -484,9 +485,9 @@ ex function::thiscontainer(const exvector & v) const
        return function(serial, v);
 }
 
-ex function::thiscontainer(std::auto_ptr<exvector> vp) const
+ex function::thiscontainer(exvector && v) const
 {
-       return function(serial, vp);
+       return function(serial, std::move(v));
 }
 
 /** Implementation of ex::series for functions.
@@ -630,10 +631,10 @@ ex function::derivative(const symbol & s) const
 {
        ex result;
 
-       if (serial == Order_SERIAL::serial) {
-               // Order Term function only differentiates the argument
-               return Order(seq[0].diff(s));
-       } else {
+       try {
+               // Explicit derivation
+               result = expl_derivative(s);
+       } catch (...) {
                // Chain rule
                ex arg_diff;
                size_t num = seq.size();
@@ -734,22 +735,45 @@ ex function::pderivative(unsigned diff_param) const // partial differentiation
        GINAC_ASSERT(serial<registered_functions().size());
        const function_options &opt = registered_functions()[serial];
        
-       // No derivative defined? Then return abstract derivative object
-       if (opt.derivative_f == NULL)
-               return fderivative(serial, diff_param, seq);
+       if (opt.derivative_f) {
+               // Invoke the defined derivative function.
+               current_serial = serial;
+               if (opt.derivative_use_exvector_args)
+                       return ((derivative_funcp_exvector)(opt.derivative_f))(seq, diff_param);
+               switch (opt.nparams) {
+                       // the following lines have been generated for max. @maxargs@ parameters
++++ for N in range(1, maxargs + 1):
+                       case @N@:
+                               return ((derivative_funcp_@N@)(opt.derivative_f))(@seq('seq[%(n)d]', N, 0)@, diff_param);
+---
+                       // end of generated lines
+               }
+       }
+       // No derivative defined? Fall back to abstract derivative object.
+       return fderivative(serial, diff_param, seq);
+}
 
-       current_serial = serial;
-       if (opt.derivative_use_exvector_args)
-               return ((derivative_funcp_exvector)(opt.derivative_f))(seq, diff_param);
-       switch (opt.nparams) {
-               // the following lines have been generated for max. @maxargs@ parameters
+ex function::expl_derivative(const symbol & s) const // explicit differentiation
+{
+       GINAC_ASSERT(serial<registered_functions().size());
+       const function_options &opt = registered_functions()[serial];
+
+       if (opt.expl_derivative_f) {
+               // Invoke the defined explicit derivative function.
+               current_serial = serial;
+               if (opt.expl_derivative_use_exvector_args)
+                       return ((expl_derivative_funcp_exvector)(opt.expl_derivative_f))(seq, s);
+               switch (opt.nparams) {
+                       // the following lines have been generated for max. @maxargs@ parameters
 +++ for N in range(1, maxargs + 1):
-               case @N@:
-                       return ((derivative_funcp_@N@)(opt.derivative_f))(@seq('seq[%(n)d]', N, 0)@, diff_param);
+                       case @N@:
+                               return ((expl_derivative_funcp_@N@)(opt.expl_derivative_f))(@seq('seq[%(n)d]', N, 0)@, s);
 ---
-               // end of generated lines
+                       // end of generated lines
+               }
        }
-       throw(std::logic_error("function::pderivative(): no diff function defined"));
+       // There is no fallback for explicit deriviative.
+       throw(std::logic_error("function::expl_derivative(): explicit derivation is called, but no such function defined"));
 }
 
 ex function::power(const ex & power_param) const // power of function
@@ -757,22 +781,22 @@ ex function::power(const ex & power_param) const // power of function
        GINAC_ASSERT(serial<registered_functions().size());
        const function_options &opt = registered_functions()[serial];
        
-       if (opt.power_f == NULL)
-               return (new GiNaC::power(*this, power_param))->setflag(status_flags::dynallocated |
-                                                                      status_flags::evaluated);
-
-       current_serial = serial;
-       if (opt.power_use_exvector_args)
-               return ((power_funcp_exvector)(opt.power_f))(seq,  power_param);
-       switch (opt.nparams) {
-               // the following lines have been generated for max. @maxargs@ parameters
+       if (opt.power_f) {
+               // Invoke the defined power function.
+               current_serial = serial;
+               if (opt.power_use_exvector_args)
+                       return ((power_funcp_exvector)(opt.power_f))(seq,  power_param);
+               switch (opt.nparams) {
+                       // the following lines have been generated for max. @maxargs@ parameters
 +++ for N in range(1, maxargs + 1):
-               case @N@:
-                       return ((power_funcp_@N@)(opt.power_f))(@seq('seq[%(n)d]', N, 0)@, power_param);
+                       case @N@:
+                               return ((power_funcp_@N@)(opt.power_f))(@seq('seq[%(n)d]', N, 0)@, power_param);
 ---
-               // end of generated lines
+                       // end of generated lines
+               }
        }
-       throw(std::logic_error("function::power(): no power function defined"));
+       // No power function defined? Fall back to returning a power object.
+       return dynallocate<GiNaC::power>(*this, power_param).setflag(status_flags::evaluated);
 }
 
 ex function::expand(unsigned options) const
@@ -780,27 +804,25 @@ ex function::expand(unsigned options) const
        GINAC_ASSERT(serial<registered_functions().size());
        const function_options &opt = registered_functions()[serial];
 
-       // No expand defined? Then return the same function with expanded arguments (if required)
-       if (opt.expand_f == NULL) {
-               // Only expand arguments when asked to do so
-               if (options & expand_options::expand_function_args)
-                       return inherited::expand(options);
-               else
-                       return (options == 0) ? setflag(status_flags::expanded) : *this;
-       }
-
-       current_serial = serial;
-       if (opt.expand_use_exvector_args)
-               return ((expand_funcp_exvector)(opt.expand_f))(seq,  options);
-       switch (opt.nparams) {
-               // the following lines have been generated for max. @maxargs@ parameters
+       if (opt.expand_f) {
+               // Invoke the defined expand function.
+               current_serial = serial;
+               if (opt.expand_use_exvector_args)
+                       return ((expand_funcp_exvector)(opt.expand_f))(seq,  options);
+               switch (opt.nparams) {
+                       // the following lines have been generated for max. @maxargs@ parameters
 +++ for N in range(1, maxargs + 1):
-               case @N@:
-                       return ((expand_funcp_@N@)(opt.expand_f))(@seq('seq[%(n)d]', N, 0)@, options);
+                       case @N@:
+                               return ((expand_funcp_@N@)(opt.expand_f))(@seq('seq[%(n)d]', N, 0)@, options);
 ---
-               // end of generated lines
+                       // end of generated lines
+               }
        }
-       throw(std::logic_error("function::expand(): no expand of function defined"));
+       // No expand function defined? Return the same function with expanded arguments (if required)
+       if (options & expand_options::expand_function_args)
+               return inherited::expand(options);
+       else
+               return (options == 0) ? setflag(status_flags::expanded) : *this;
 }
 
 std::vector<function_options> & function::registered_functions()
@@ -852,15 +874,13 @@ unsigned function::register_new(function_options const & opt)
  *  Throws exception if function was not found. */
 unsigned function::find_function(const std::string &name, unsigned nparams)
 {
-       std::vector<function_options>::const_iterator i = function::registered_functions().begin(), end = function::registered_functions().end();
        unsigned serial = 0;
-       while (i != end) {
-               if (i->get_name() == name && i->get_nparams() == nparams)
+       for (auto & it : function::registered_functions()) {
+               if (it.get_name() == name && it.get_nparams() == nparams)
                        return serial;
-               ++i;
                ++serial;
        }
-       throw (std::runtime_error("no function '" + name + "' with " + ToString(nparams) + " parameters defined"));
+       throw (std::runtime_error("no function '" + name + "' with " + std::to_string(nparams) + " parameters defined"));
 }
 
 /** Return the print name of the function. */