]> www.ginac.de Git - ginac.git/blobdiff - ginac/function.cppy
Explicit derivation of functions.
[ginac.git] / ginac / function.cppy
index d8a261f6ca3321e2d8f20ee7b586c3b0bf0372fe..dba9f4e0f096b8234234e540c109484ebdf260f8 100644 (file)
@@ -79,7 +79,7 @@ void function_options::initialize()
        set_name("unnamed_function", "\\\\mbox{unnamed}");
        nparams = 0;
        eval_f = evalf_f = real_part_f = imag_part_f = conjugate_f = expand_f
-               = derivative_f = power_f = series_f = 0;
+               = derivative_f = expl_derivative_f = power_f = series_f = 0;
        info_f = 0;
        evalf_params_first = true;
        use_return_type = false;
@@ -90,6 +90,7 @@ void function_options::initialize()
        imag_part_use_exvector_args = false;
        expand_use_exvector_args = false;
        derivative_use_exvector_args = false;
+       expl_derivative_use_exvector_args = false;
        power_use_exvector_args = false;
        series_use_exvector_args = false;
        print_use_exvector_args = false;
@@ -630,10 +631,10 @@ ex function::derivative(const symbol & s) const
 {
        ex result;
 
-       if (serial == Order_SERIAL::serial) {
-               // Order Term function only differentiates the argument
-               return Order(seq[0].diff(s));
-       } else {
+       try {
+               // Explicit derivation
+               result = expl_derivative(s);
+       } catch (...) {
                // Chain rule
                ex arg_diff;
                size_t num = seq.size();
@@ -752,6 +753,28 @@ ex function::pderivative(unsigned diff_param) const // partial differentiation
        throw(std::logic_error("function::pderivative(): no diff function defined"));
 }
 
+ex function::expl_derivative(const symbol & s) const // explicit differentiation
+{
+       GINAC_ASSERT(serial<registered_functions().size());
+       const function_options &opt = registered_functions()[serial];
+
+       // No explicit derivative defined? Then this function shall not be called!
+       if (opt.expl_derivative_f == NULL)
+               throw(std::logic_error("function::expl_derivative(): explicit derivation is called, but no such function defined"));
+
+       current_serial = serial;
+       if (opt.expl_derivative_use_exvector_args)
+               return ((expl_derivative_funcp_exvector)(opt.expl_derivative_f))(seq, s);
+       switch (opt.nparams) {
+               // the following lines have been generated for max. @maxargs@ parameters
++++ for N in range(1, maxargs + 1):
+               case @N@:
+                       return ((expl_derivative_funcp_@N@)(opt.expl_derivative_f))(@seq('seq[%(n)d]', N, 0)@, s);
+---
+               // end of generated lines
+       }
+}
+
 ex function::power(const ex & power_param) const // power of function
 {
        GINAC_ASSERT(serial<registered_functions().size());