]> www.ginac.de Git - ginac.git/blobdiff - ginac/function.cppy
Replace use of NULL by C++11 nullptr.
[ginac.git] / ginac / function.cppy
index fb83cdc4f05b1f2683589a87e2f4d8fb928504d9..45f33ec16b53ba32d6e020067bf708f0ba262ced 100644 (file)
@@ -7,7 +7,7 @@
  *  Please do not modify it directly, edit function.cppy instead!
  *  function.py options: maxargs=@maxargs@
  *
- *  GiNaC Copyright (C) 1999-2010 Johannes Gutenberg University Mainz, Germany
+ *  GiNaC Copyright (C) 1999-2015 Johannes Gutenberg University Mainz, Germany
  *
  *  This program is free software; you can redistribute it and/or modify
  *  it under the terms of the GNU General Public License as published by
@@ -78,8 +78,9 @@ void function_options::initialize()
 {
        set_name("unnamed_function", "\\\\mbox{unnamed}");
        nparams = 0;
-       eval_f = evalf_f = real_part_f = imag_part_f = conjugate_f = derivative_f
-               = power_f = series_f = 0;
+       eval_f = evalf_f = real_part_f = imag_part_f = conjugate_f = expand_f
+               = derivative_f = expl_derivative_f = power_f = series_f = 0;
+       info_f = 0;
        evalf_params_first = true;
        use_return_type = false;
        eval_use_exvector_args = false;
@@ -87,10 +88,13 @@ void function_options::initialize()
        conjugate_use_exvector_args = false;
        real_part_use_exvector_args = false;
        imag_part_use_exvector_args = false;
+       expand_use_exvector_args = false;
        derivative_use_exvector_args = false;
+       expl_derivative_use_exvector_args = false;
        power_use_exvector_args = false;
        series_use_exvector_args = false;
        print_use_exvector_args = false;
+       info_use_exvector_args = false;
        use_remember = false;
        functions_with_same_name = 1;
        symtree = 0;
@@ -243,8 +247,8 @@ function::function(unsigned ser, const exvector & v, bool discardable)
 {
 }
 
-function::function(unsigned ser, std::auto_ptr<exvector> vp) 
-  : exprseq(vp), serial(ser)
+function::function(unsigned ser, exvector && v)
+  : exprseq(std::move(v)), serial(ser)
 {
 }
 
@@ -300,7 +304,7 @@ void function::print(const print_context & c, unsigned level) const
 
 next_context:
        unsigned id = pc_info->options.get_id();
-       if (id >= pdt.size() || pdt[id] == NULL) {
+       if (id >= pdt.size() || pdt[id] == nullptr) {
 
                // Method not found, try parent print_context class
                const print_context_class_info *parent_pc_info = pc_info->get_parent();
@@ -360,15 +364,6 @@ next_context:
        }
 }
 
-ex function::expand(unsigned options) const
-{
-       // Only expand arguments when asked to do so
-       if (options & expand_options::expand_function_args)
-               return inherited::expand(options);
-       else
-               return (options == 0) ? setflag(status_flags::expanded) : *this;
-}
-
 ex function::eval(int level) const
 {
        if (level>1) {
@@ -490,9 +485,9 @@ ex function::thiscontainer(const exvector & v) const
        return function(serial, v);
 }
 
-ex function::thiscontainer(std::auto_ptr<exvector> vp) const
+ex function::thiscontainer(exvector && v) const
 {
-       return function(serial, vp);
+       return function(serial, std::move(v));
 }
 
 /** Implementation of ex::series for functions.
@@ -602,6 +597,31 @@ ex function::imag_part() const
        throw(std::logic_error("function::imag_part(): invalid nparams"));
 }
 
+/** Implementation of ex::info for functions. */
+bool function::info(unsigned inf) const
+{
+       GINAC_ASSERT(serial<registered_functions().size());
+       const function_options & opt = registered_functions()[serial];
+
+       if (opt.info_f==0) {
+               return basic::info(inf);
+       }
+
+       if (opt.info_use_exvector_args) {
+               return ((info_funcp_exvector)(opt.info_f))(seq, inf);
+       }
+
+       switch (opt.nparams) {
+               // the following lines have been generated for max. @maxargs@ parameters
++++ for N in range(1, maxargs + 1):
+               case @N@:
+                       return ((info_funcp_@N@)(opt.info_f))(@seq('seq[%(n)d]', N, 0)@, inf);
+---
+               // end of generated lines
+       }
+       throw(std::logic_error("function::info(): invalid nparams"));
+}
+
 // protected
 
 /** Implementation of ex::diff() for functions. It applies the chain rule,
@@ -611,10 +631,10 @@ ex function::derivative(const symbol & s) const
 {
        ex result;
 
-       if (serial == Order_SERIAL::serial) {
-               // Order Term function only differentiates the argument
-               return Order(seq[0].diff(s));
-       } else {
+       try {
+               // Explicit derivation
+               result = expl_derivative(s);
+       } catch (...) {
                // Chain rule
                ex arg_diff;
                size_t num = seq.size();
@@ -716,7 +736,7 @@ ex function::pderivative(unsigned diff_param) const // partial differentiation
        const function_options &opt = registered_functions()[serial];
        
        // No derivative defined? Then return abstract derivative object
-       if (opt.derivative_f == NULL)
+       if (opt.derivative_f == nullptr)
                return fderivative(serial, diff_param, seq);
 
        current_serial = serial;
@@ -733,12 +753,34 @@ ex function::pderivative(unsigned diff_param) const // partial differentiation
        throw(std::logic_error("function::pderivative(): no diff function defined"));
 }
 
+ex function::expl_derivative(const symbol & s) const // explicit differentiation
+{
+       GINAC_ASSERT(serial<registered_functions().size());
+       const function_options &opt = registered_functions()[serial];
+
+       // No explicit derivative defined? Then this function shall not be called!
+       if (opt.expl_derivative_f == nullptr)
+               throw(std::logic_error("function::expl_derivative(): explicit derivation is called, but no such function defined"));
+
+       current_serial = serial;
+       if (opt.expl_derivative_use_exvector_args)
+               return ((expl_derivative_funcp_exvector)(opt.expl_derivative_f))(seq, s);
+       switch (opt.nparams) {
+               // the following lines have been generated for max. @maxargs@ parameters
++++ for N in range(1, maxargs + 1):
+               case @N@:
+                       return ((expl_derivative_funcp_@N@)(opt.expl_derivative_f))(@seq('seq[%(n)d]', N, 0)@, s);
+---
+               // end of generated lines
+       }
+}
+
 ex function::power(const ex & power_param) const // power of function
 {
        GINAC_ASSERT(serial<registered_functions().size());
        const function_options &opt = registered_functions()[serial];
        
-       if (opt.power_f == NULL)
+       if (opt.power_f == nullptr)
                return (new GiNaC::power(*this, power_param))->setflag(status_flags::dynallocated |
                                                                       status_flags::evaluated);
 
@@ -756,6 +798,34 @@ ex function::power(const ex & power_param) const // power of function
        throw(std::logic_error("function::power(): no power function defined"));
 }
 
+ex function::expand(unsigned options) const
+{
+       GINAC_ASSERT(serial<registered_functions().size());
+       const function_options &opt = registered_functions()[serial];
+
+       // No expand defined? Then return the same function with expanded arguments (if required)
+       if (opt.expand_f == nullptr) {
+               // Only expand arguments when asked to do so
+               if (options & expand_options::expand_function_args)
+                       return inherited::expand(options);
+               else
+                       return (options == 0) ? setflag(status_flags::expanded) : *this;
+       }
+
+       current_serial = serial;
+       if (opt.expand_use_exvector_args)
+               return ((expand_funcp_exvector)(opt.expand_f))(seq,  options);
+       switch (opt.nparams) {
+               // the following lines have been generated for max. @maxargs@ parameters
++++ for N in range(1, maxargs + 1):
+               case @N@:
+                       return ((expand_funcp_@N@)(opt.expand_f))(@seq('seq[%(n)d]', N, 0)@, options);
+---
+               // end of generated lines
+       }
+       throw(std::logic_error("function::expand(): no expand of function defined"));
+}
+
 std::vector<function_options> & function::registered_functions()
 {
        static std::vector<function_options> rf = std::vector<function_options>();