]> www.ginac.de Git - ginac.git/commitdiff
Explicit derivation of functions.
authorVladimir V. Kisil <kisilv@maths.leeds.ac.uk>
Sun, 8 Feb 2015 19:50:51 +0000 (20:50 +0100)
committerRichard Kreckel <kreckel@ginac.de>
Sun, 8 Feb 2015 19:52:57 +0000 (20:52 +0100)
Some function cannot be cleanly differentiated through the chain rule.
For example, it is natural to define derivative of the absolute value as

(abs(f))'=(f'*f.conjugate()+f*f'.conjugate())/2/abs(f)

This patch adds a possibility to define derivatives of functions in this way.
In particular the derivative of abs(), Order(), real_part(), imag_part() and
conjugate() are defined.

For example, conjugate of a derivative with respect of a real symbol
If x is real then U.diff(x)-I*V.diff(x) represents both
conjugate(U+I*V).diff(x) and conjugate((U+I*V).diff(x))
Thus in this patch we use the rule

conjugate(f)'=conjugate(f')

for a derivative with respect to the real symbol.

Signed-off-by: Vladimir V. Kisil <kisilv@maths.leeds.ac.uk>
check/exam_inifcns.cpp
doc/tutorial/ginac.texi
ginac/function.cppy
ginac/function.hppy
ginac/function.py
ginac/inifcns.cpp

index 19ad1b98def75b8b334b6a4d79fb29da166158c9..a0acb2de8af183d1f1094b06c7d9f506e5c37b6b 100644 (file)
@@ -343,6 +343,71 @@ static unsigned inifcns_consist_various()
        return result;
 }
 
        return result;
 }
 
+/* Several tests for derivetives */
+static unsigned inifcns_consist_derivatives()
+{
+       unsigned result = 0;
+       symbol z, w;
+       realsymbol x;
+       ex e, e1;
+
+       e=pow(x,z).conjugate().diff(x);
+       e1=pow(x,z).conjugate()*z.conjugate()/x;
+       if (! (e-e1).normal().is_zero() ) {
+               clog << "ERROR: pow(x,z).conjugate().diff(x) " << e << " != " << e1 << endl;
+               ++result;
+       }
+
+       e=pow(w,z).conjugate().diff(w);
+       e1=pow(w,z).conjugate()*z.conjugate()/w;
+       if ( (e-e1).normal().is_zero() ) {
+               clog << "ERROR: pow(w,z).conjugate().diff(w) " << e << " = " << e1 << endl;
+               ++result;
+       }
+
+       e=atanh(x).imag_part().diff(x);
+       if (! e.is_zero() ) {
+               clog << "ERROR: atanh(x).imag_part().diff(x) " << e << " != 0" << endl;
+               ++result;
+       }
+
+       e=atanh(w).imag_part().diff(w);
+       if ( e.is_zero() ) {
+               clog << "ERROR: atanh(w).imag_part().diff(w) " << e << " = 0" << endl;
+               ++result;
+       }
+
+       e=atanh(x).real_part().diff(x);
+       e1=pow(1-x*x,-1);
+       if (! (e-e1).normal().is_zero() ) {
+               clog << "ERROR: atanh(x).real_part().diff(x) " << e << " != " << e1 << endl;
+               ++result;
+       }
+
+       e=atanh(w).real_part().diff(w);
+       e1=pow(1-w*w,-1);
+       if ( (e-e1).normal().is_zero() ) {
+               clog << "ERROR: atanh(w).real_part().diff(w) " << e << " = " << e1 << endl;
+               ++result;
+       }
+
+       e=abs(log(z)).diff(z);
+       e1=(conjugate(log(z))/z+log(z)/conjugate(z))/abs(log(z))/2;
+       if (! (e-e1).normal().is_zero() ) {
+               clog << "ERROR: abs(log(z)).diff(z) " << e << " != " << e1 << endl;
+               ++result;
+       }
+
+       e=Order(pow(x,4)).diff(x);
+       e1=Order(pow(x,3));
+       if (! (e-e1).normal().is_zero() ) {
+               clog << "ERROR: Order(pow(x,4)).diff(x) " << e << " != " << e1 << endl;
+               ++result;
+       }
+
+       return result;
+}
+
 unsigned exam_inifcns()
 {
        unsigned result = 0;
 unsigned exam_inifcns()
 {
        unsigned result = 0;
@@ -357,6 +422,7 @@ unsigned exam_inifcns()
        result += inifcns_consist_exp();  cout << '.' << flush;
        result += inifcns_consist_log();  cout << '.' << flush;
        result += inifcns_consist_various();  cout << '.' << flush;
        result += inifcns_consist_exp();  cout << '.' << flush;
        result += inifcns_consist_log();  cout << '.' << flush;
        result += inifcns_consist_various();  cout << '.' << flush;
+       result += inifcns_consist_derivatives();  cout << '.' << flush;
        
        return result;
 }
        
        return result;
 }
index 21e31b2c98501fc7581ceed6c3013e978dbbc3f1..3ac53981fe9e5602d1ddb1b11b923b66fedaa78d 100644 (file)
@@ -7103,6 +7103,25 @@ specifies which parameter to differentiate in a partial derivative in
 case the function has more than one parameter, and its main application
 is for correct handling of the chain rule.
 
 case the function has more than one parameter, and its main application
 is for correct handling of the chain rule.
 
+Derivatives of some functions, for example @code{abs()} and
+@code{Order()}, could not be evaluated through the chain rule. In such
+cases the full derivative may be specified as shown for @code{Order()}:
+
+@example
+static ex Order_expl_derivative(const ex & arg, const symbol & s)
+@{
+       return Order(arg.diff(s));
+@}
+@end example
+
+That is, we need to supply a procedure, which returns the expression of
+derivative with respect to the variable @code{s} for the argument
+@code{arg}. This procedure need to be registered with the function
+through the option @code{expl_derivative_func} (see the next
+Subsection). In contrast, a partial derivative, e.g. as was defined for
+@code{cos()} above, needs to be registered through the option
+@code{derivative_func}. 
+
 An implementation of the series expansion is not needed for @code{cos()} as
 it doesn't have any poles and GiNaC can do Taylor expansion by itself (as
 long as it knows what the derivative of @code{cos()} is). @code{tan()}, on
 An implementation of the series expansion is not needed for @code{cos()} as
 it doesn't have any poles and GiNaC can do Taylor expansion by itself (as
 long as it knows what the derivative of @code{cos()} is). @code{tan()}, on
@@ -7138,14 +7157,15 @@ functions without any special options.
 eval_func(<C++ function>)
 evalf_func(<C++ function>)
 derivative_func(<C++ function>)
 eval_func(<C++ function>)
 evalf_func(<C++ function>)
 derivative_func(<C++ function>)
+expl_derivative_func(<C++ function>)
 series_func(<C++ function>)
 conjugate_func(<C++ function>)
 @end example
 
 These specify the C++ functions that implement symbolic evaluation,
 series_func(<C++ function>)
 conjugate_func(<C++ function>)
 @end example
 
 These specify the C++ functions that implement symbolic evaluation,
-numeric evaluation, partial derivatives, and series expansion, respectively.
-They correspond to the GiNaC methods @code{eval()}, @code{evalf()},
-@code{diff()} and @code{series()}.
+numeric evaluation, partial derivatives, explicit derivative, and series
+expansion, respectively.  They correspond to the GiNaC methods
+@code{eval()}, @code{evalf()}, @code{diff()} and @code{series()}.
 
 The @code{eval_func()} function needs to use @code{.hold()} if no further
 automatic evaluation is desired or possible.
 
 The @code{eval_func()} function needs to use @code{.hold()} if no further
 automatic evaluation is desired or possible.
index d8a261f6ca3321e2d8f20ee7b586c3b0bf0372fe..dba9f4e0f096b8234234e540c109484ebdf260f8 100644 (file)
@@ -79,7 +79,7 @@ void function_options::initialize()
        set_name("unnamed_function", "\\\\mbox{unnamed}");
        nparams = 0;
        eval_f = evalf_f = real_part_f = imag_part_f = conjugate_f = expand_f
        set_name("unnamed_function", "\\\\mbox{unnamed}");
        nparams = 0;
        eval_f = evalf_f = real_part_f = imag_part_f = conjugate_f = expand_f
-               = derivative_f = power_f = series_f = 0;
+               = derivative_f = expl_derivative_f = power_f = series_f = 0;
        info_f = 0;
        evalf_params_first = true;
        use_return_type = false;
        info_f = 0;
        evalf_params_first = true;
        use_return_type = false;
@@ -90,6 +90,7 @@ void function_options::initialize()
        imag_part_use_exvector_args = false;
        expand_use_exvector_args = false;
        derivative_use_exvector_args = false;
        imag_part_use_exvector_args = false;
        expand_use_exvector_args = false;
        derivative_use_exvector_args = false;
+       expl_derivative_use_exvector_args = false;
        power_use_exvector_args = false;
        series_use_exvector_args = false;
        print_use_exvector_args = false;
        power_use_exvector_args = false;
        series_use_exvector_args = false;
        print_use_exvector_args = false;
@@ -630,10 +631,10 @@ ex function::derivative(const symbol & s) const
 {
        ex result;
 
 {
        ex result;
 
-       if (serial == Order_SERIAL::serial) {
-               // Order Term function only differentiates the argument
-               return Order(seq[0].diff(s));
-       } else {
+       try {
+               // Explicit derivation
+               result = expl_derivative(s);
+       } catch (...) {
                // Chain rule
                ex arg_diff;
                size_t num = seq.size();
                // Chain rule
                ex arg_diff;
                size_t num = seq.size();
@@ -752,6 +753,28 @@ ex function::pderivative(unsigned diff_param) const // partial differentiation
        throw(std::logic_error("function::pderivative(): no diff function defined"));
 }
 
        throw(std::logic_error("function::pderivative(): no diff function defined"));
 }
 
+ex function::expl_derivative(const symbol & s) const // explicit differentiation
+{
+       GINAC_ASSERT(serial<registered_functions().size());
+       const function_options &opt = registered_functions()[serial];
+
+       // No explicit derivative defined? Then this function shall not be called!
+       if (opt.expl_derivative_f == NULL)
+               throw(std::logic_error("function::expl_derivative(): explicit derivation is called, but no such function defined"));
+
+       current_serial = serial;
+       if (opt.expl_derivative_use_exvector_args)
+               return ((expl_derivative_funcp_exvector)(opt.expl_derivative_f))(seq, s);
+       switch (opt.nparams) {
+               // the following lines have been generated for max. @maxargs@ parameters
++++ for N in range(1, maxargs + 1):
+               case @N@:
+                       return ((expl_derivative_funcp_@N@)(opt.expl_derivative_f))(@seq('seq[%(n)d]', N, 0)@, s);
+---
+               // end of generated lines
+       }
+}
+
 ex function::power(const ex & power_param) const // power of function
 {
        GINAC_ASSERT(serial<registered_functions().size());
 ex function::power(const ex & power_param) const // power of function
 {
        GINAC_ASSERT(serial<registered_functions().size());
index 6259d7a6f40fa25903484cec92c7fbfea0594641..971786d23a57efde2701899fb0de2523e21feabf 100644 (file)
@@ -59,6 +59,7 @@ typedef ex (* real_part_funcp)();
 typedef ex (* imag_part_funcp)();
 typedef ex (* expand_funcp)();
 typedef ex (* derivative_funcp)();
 typedef ex (* imag_part_funcp)();
 typedef ex (* expand_funcp)();
 typedef ex (* derivative_funcp)();
+typedef ex (* expl_derivative_funcp)();
 typedef ex (* power_funcp)();
 typedef ex (* series_funcp)();
 typedef void (* print_funcp)();
 typedef ex (* power_funcp)();
 typedef ex (* series_funcp)();
 typedef void (* print_funcp)();
@@ -73,6 +74,7 @@ typedef ex (* real_part_funcp_@N@)( @args@ );
 typedef ex (* imag_part_funcp_@N@)( @args@ );
 typedef ex (* expand_funcp_@N@)( @args@, unsigned );
 typedef ex (* derivative_funcp_@N@)( @args@, unsigned );
 typedef ex (* imag_part_funcp_@N@)( @args@ );
 typedef ex (* expand_funcp_@N@)( @args@, unsigned );
 typedef ex (* derivative_funcp_@N@)( @args@, unsigned );
+typedef ex (* expl_derivative_funcp_@N@)( @args@, const symbol & );
 typedef ex (* power_funcp_@N@)( @args@, const ex & );
 typedef ex (* series_funcp_@N@)( @args@, const relational &, int, unsigned );
 typedef void (* print_funcp_@N@)( @args@, const print_context & );
 typedef ex (* power_funcp_@N@)( @args@, const ex & );
 typedef ex (* series_funcp_@N@)( @args@, const relational &, int, unsigned );
 typedef void (* print_funcp_@N@)( @args@, const print_context & );
@@ -87,6 +89,7 @@ typedef ex (* @fp@_funcp_exvector)(const exvector &);
 ---
 typedef ex (* expand_funcp_exvector)(const exvector &, unsigned);
 typedef ex (* derivative_funcp_exvector)(const exvector &, unsigned);
 ---
 typedef ex (* expand_funcp_exvector)(const exvector &, unsigned);
 typedef ex (* derivative_funcp_exvector)(const exvector &, unsigned);
+typedef ex (* expl_derivative_funcp_exvector)(const exvector &, const symbol &);
 typedef ex (* power_funcp_exvector)(const exvector &, const ex &);
 typedef ex (* series_funcp_exvector)(const exvector &, const relational &, int, unsigned);
 typedef void (* print_funcp_exvector)(const exvector &, const print_context &);
 typedef ex (* power_funcp_exvector)(const exvector &, const ex &);
 typedef ex (* series_funcp_exvector)(const exvector &, const relational &, int, unsigned);
 typedef void (* print_funcp_exvector)(const exvector &, const print_context &);
@@ -159,6 +162,7 @@ protected:
        imag_part_funcp imag_part_f;
        expand_funcp expand_f;
        derivative_funcp derivative_f;
        imag_part_funcp imag_part_f;
        expand_funcp expand_f;
        derivative_funcp derivative_f;
+       expl_derivative_funcp expl_derivative_f;
        power_funcp power_f;
        series_funcp series_f;
        std::vector<print_funcp> print_dispatch_table;
        power_funcp power_f;
        series_funcp series_f;
        std::vector<print_funcp> print_dispatch_table;
@@ -182,6 +186,7 @@ protected:
        bool imag_part_use_exvector_args;
        bool expand_use_exvector_args;
        bool derivative_use_exvector_args;
        bool imag_part_use_exvector_args;
        bool expand_use_exvector_args;
        bool derivative_use_exvector_args;
+       bool expl_derivative_use_exvector_args;
        bool power_use_exvector_args;
        bool series_use_exvector_args;
        bool print_use_exvector_args;
        bool power_use_exvector_args;
        bool series_use_exvector_args;
        bool print_use_exvector_args;
@@ -251,6 +256,7 @@ protected:
        // non-virtual functions in this class
 protected:
        ex pderivative(unsigned diff_param) const; // partial differentiation
        // non-virtual functions in this class
 protected:
        ex pderivative(unsigned diff_param) const; // partial differentiation
+       ex expl_derivative(const symbol & s) const; // partial differentiation
        static std::vector<function_options> & registered_functions();
        bool lookup_remember_table(ex & result) const;
        void store_remember_table(ex const & result) const;
        static std::vector<function_options> & registered_functions();
        bool lookup_remember_table(ex & result) const;
        void store_remember_table(ex const & result) const;
index 3f5e54eb6028d95fef63e72c880c1cc6d17f2ab7..465976b349c28b650afc34e698d92234ff70f2c5 100755 (executable)
@@ -2,7 +2,7 @@
 # encoding: utf-8
 
 maxargs = 14
 # encoding: utf-8
 
 maxargs = 14
-methods = "eval evalf conjugate real_part imag_part expand derivative power series info print".split()
+methods = "eval evalf conjugate real_part imag_part expand derivative expl_derivative power series info print".split()
 
 import sys, os, optparse
 sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'scripts'))
 
 import sys, os, optparse
 sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'scripts'))
index 84cd2852ddd6bdb23160bf6d9ccf11e4352d0b72..28d54fed3e79924dffbd95c457add5ff2bbd662a 100644 (file)
@@ -24,6 +24,7 @@
 #include "ex.h"
 #include "constant.h"
 #include "lst.h"
 #include "ex.h"
 #include "constant.h"
 #include "lst.h"
+#include "fderivative.h"
 #include "matrix.h"
 #include "mul.h"
 #include "power.h"
 #include "matrix.h"
 #include "mul.h"
 #include "power.h"
@@ -66,6 +67,19 @@ static ex conjugate_conjugate(const ex & arg)
        return arg;
 }
 
        return arg;
 }
 
+// If x is real then U.diff(x)-I*V.diff(x) represents both conjugate(U+I*V).diff(x) 
+// and conjugate((U+I*V).diff(x))
+static ex conjugate_expl_derivative(const ex & arg, const symbol & s)
+{
+       if (s.info(info_flags::real))
+               return conjugate(arg.diff(s));
+       else {
+               exvector vec_arg;
+               vec_arg.push_back(arg);
+               return fderivative(ex_to<function>(conjugate(arg)).get_serial(),0,vec_arg).hold()*arg.diff(s);
+       }
+}
+
 static ex conjugate_real_part(const ex & arg)
 {
        return arg.real_part();
 static ex conjugate_real_part(const ex & arg)
 {
        return arg.real_part();
@@ -115,6 +129,7 @@ static bool conjugate_info(const ex & arg, unsigned inf)
 
 REGISTER_FUNCTION(conjugate_function, eval_func(conjugate_eval).
                                       evalf_func(conjugate_evalf).
 
 REGISTER_FUNCTION(conjugate_function, eval_func(conjugate_eval).
                                       evalf_func(conjugate_evalf).
+                                      expl_derivative_func(conjugate_expl_derivative).
                                       info_func(conjugate_info).
                                       print_func<print_latex>(conjugate_print_latex).
                                       conjugate_func(conjugate_conjugate).
                                       info_func(conjugate_info).
                                       print_func<print_latex>(conjugate_print_latex).
                                       conjugate_func(conjugate_conjugate).
@@ -159,8 +174,21 @@ static ex real_part_imag_part(const ex & arg)
        return 0;
 }
 
        return 0;
 }
 
+// If x is real then Re(e).diff(x) is equal to Re(e.diff(x)) 
+static ex real_part_expl_derivative(const ex & arg, const symbol & s)
+{
+       if (s.info(info_flags::real))
+               return real_part_function(arg.diff(s));
+       else {
+               exvector vec_arg;
+               vec_arg.push_back(arg);
+               return fderivative(ex_to<function>(real_part(arg)).get_serial(),0,vec_arg).hold()*arg.diff(s);
+       }
+}
+
 REGISTER_FUNCTION(real_part_function, eval_func(real_part_eval).
                                       evalf_func(real_part_evalf).
 REGISTER_FUNCTION(real_part_function, eval_func(real_part_eval).
                                       evalf_func(real_part_evalf).
+                                      expl_derivative_func(real_part_expl_derivative).
                                       print_func<print_latex>(real_part_print_latex).
                                       conjugate_func(real_part_conjugate).
                                       real_part_func(real_part_real_part).
                                       print_func<print_latex>(real_part_print_latex).
                                       conjugate_func(real_part_conjugate).
                                       real_part_func(real_part_real_part).
@@ -204,8 +232,21 @@ static ex imag_part_imag_part(const ex & arg)
        return 0;
 }
 
        return 0;
 }
 
+// If x is real then Im(e).diff(x) is equal to Im(e.diff(x)) 
+static ex imag_part_expl_derivative(const ex & arg, const symbol & s)
+{
+       if (s.info(info_flags::real))
+               return imag_part_function(arg.diff(s));
+       else {
+               exvector vec_arg;
+               vec_arg.push_back(arg);
+               return fderivative(ex_to<function>(imag_part(arg)).get_serial(),0,vec_arg).hold()*arg.diff(s);
+       }
+}
+
 REGISTER_FUNCTION(imag_part_function, eval_func(imag_part_eval).
                                       evalf_func(imag_part_evalf).
 REGISTER_FUNCTION(imag_part_function, eval_func(imag_part_eval).
                                       evalf_func(imag_part_evalf).
+                                      expl_derivative_func(imag_part_expl_derivative).
                                       print_func<print_latex>(imag_part_print_latex).
                                       conjugate_func(imag_part_conjugate).
                                       real_part_func(imag_part_real_part).
                                       print_func<print_latex>(imag_part_print_latex).
                                       conjugate_func(imag_part_conjugate).
                                       real_part_func(imag_part_real_part).
@@ -275,6 +316,12 @@ static ex abs_expand(const ex & arg, unsigned options)
                return abs(arg).hold();
 }
 
                return abs(arg).hold();
 }
 
+static ex abs_expl_derivative(const ex & arg, const symbol & s)
+{
+       ex diff_arg = arg.diff(s);
+       return (diff_arg*arg.conjugate()+arg*diff_arg.conjugate())/2/abs(arg);
+}
+
 static void abs_print_latex(const ex & arg, const print_context & c)
 {
        c.s << "{|"; arg.print(c); c.s << "|}";
 static void abs_print_latex(const ex & arg, const print_context & c)
 {
        c.s << "{|"; arg.print(c); c.s << "|}";
@@ -341,6 +388,7 @@ bool abs_info(const ex & arg, unsigned inf)
 REGISTER_FUNCTION(abs, eval_func(abs_eval).
                        evalf_func(abs_evalf).
                        expand_func(abs_expand).
 REGISTER_FUNCTION(abs, eval_func(abs_eval).
                        evalf_func(abs_evalf).
                        expand_func(abs_expand).
+                       expl_derivative_func(abs_expl_derivative).
                        info_func(abs_info).
                        print_func<print_latex>(abs_print_latex).
                        print_func<print_csrc_float>(abs_print_csrc_float).
                        info_func(abs_info).
                        print_func<print_latex>(abs_print_latex).
                        print_func<print_csrc_float>(abs_print_csrc_float).
@@ -977,11 +1025,15 @@ static ex Order_imag_part(const ex & x)
        return Order(x).hold();
 }
 
        return Order(x).hold();
 }
 
-// Differentiation is handled in function::derivative because of its special requirements
+static ex Order_expl_derivative(const ex & arg, const symbol & s)
+{
+       return Order(arg.diff(s));
+}
 
 REGISTER_FUNCTION(Order, eval_func(Order_eval).
                          series_func(Order_series).
                          latex_name("\\mathcal{O}").
 
 REGISTER_FUNCTION(Order, eval_func(Order_eval).
                          series_func(Order_series).
                          latex_name("\\mathcal{O}").
+                         expl_derivative_func(Order_expl_derivative).
                          conjugate_func(Order_conjugate).
                          real_part_func(Order_real_part).
                          imag_part_func(Order_imag_part));
                          conjugate_func(Order_conjugate).
                          real_part_func(Order_real_part).
                          imag_part_func(Order_imag_part));