]> www.ginac.de Git - ginac.git/blobdiff - ginac/inifcns.cpp
Improve abs(arg).
[ginac.git] / ginac / inifcns.cpp
index 84cd2852ddd6bdb23160bf6d9ccf11e4352d0b72..ecb6e0072fb8333067ab5c54ba8d7453c0dacc97 100644 (file)
@@ -24,6 +24,7 @@
 #include "ex.h"
 #include "constant.h"
 #include "lst.h"
+#include "fderivative.h"
 #include "matrix.h"
 #include "mul.h"
 #include "power.h"
@@ -66,6 +67,19 @@ static ex conjugate_conjugate(const ex & arg)
        return arg;
 }
 
+// If x is real then U.diff(x)-I*V.diff(x) represents both conjugate(U+I*V).diff(x) 
+// and conjugate((U+I*V).diff(x))
+static ex conjugate_expl_derivative(const ex & arg, const symbol & s)
+{
+       if (s.info(info_flags::real))
+               return conjugate(arg.diff(s));
+       else {
+               exvector vec_arg;
+               vec_arg.push_back(arg);
+               return fderivative(ex_to<function>(conjugate(arg)).get_serial(),0,vec_arg).hold()*arg.diff(s);
+       }
+}
+
 static ex conjugate_real_part(const ex & arg)
 {
        return arg.real_part();
@@ -115,6 +129,7 @@ static bool conjugate_info(const ex & arg, unsigned inf)
 
 REGISTER_FUNCTION(conjugate_function, eval_func(conjugate_eval).
                                       evalf_func(conjugate_evalf).
+                                      expl_derivative_func(conjugate_expl_derivative).
                                       info_func(conjugate_info).
                                       print_func<print_latex>(conjugate_print_latex).
                                       conjugate_func(conjugate_conjugate).
@@ -159,8 +174,21 @@ static ex real_part_imag_part(const ex & arg)
        return 0;
 }
 
+// If x is real then Re(e).diff(x) is equal to Re(e.diff(x)) 
+static ex real_part_expl_derivative(const ex & arg, const symbol & s)
+{
+       if (s.info(info_flags::real))
+               return real_part_function(arg.diff(s));
+       else {
+               exvector vec_arg;
+               vec_arg.push_back(arg);
+               return fderivative(ex_to<function>(real_part(arg)).get_serial(),0,vec_arg).hold()*arg.diff(s);
+       }
+}
+
 REGISTER_FUNCTION(real_part_function, eval_func(real_part_eval).
                                       evalf_func(real_part_evalf).
+                                      expl_derivative_func(real_part_expl_derivative).
                                       print_func<print_latex>(real_part_print_latex).
                                       conjugate_func(real_part_conjugate).
                                       real_part_func(real_part_real_part).
@@ -204,8 +232,21 @@ static ex imag_part_imag_part(const ex & arg)
        return 0;
 }
 
+// If x is real then Im(e).diff(x) is equal to Im(e.diff(x)) 
+static ex imag_part_expl_derivative(const ex & arg, const symbol & s)
+{
+       if (s.info(info_flags::real))
+               return imag_part_function(arg.diff(s));
+       else {
+               exvector vec_arg;
+               vec_arg.push_back(arg);
+               return fderivative(ex_to<function>(imag_part(arg)).get_serial(),0,vec_arg).hold()*arg.diff(s);
+       }
+}
+
 REGISTER_FUNCTION(imag_part_function, eval_func(imag_part_eval).
                                       evalf_func(imag_part_evalf).
+                                      expl_derivative_func(imag_part_expl_derivative).
                                       print_func<print_latex>(imag_part_print_latex).
                                       conjugate_func(imag_part_conjugate).
                                       real_part_func(imag_part_real_part).
@@ -232,6 +273,9 @@ static ex abs_eval(const ex & arg)
        if (arg.info(info_flags::nonnegative))
                return arg;
 
+       if (arg.info(info_flags::negative) || (-arg).info(info_flags::nonnegative))
+               return -arg;
+
        if (is_ex_the_function(arg, abs))
                return arg;
 
@@ -275,6 +319,12 @@ static ex abs_expand(const ex & arg, unsigned options)
                return abs(arg).hold();
 }
 
+static ex abs_expl_derivative(const ex & arg, const symbol & s)
+{
+       ex diff_arg = arg.diff(s);
+       return (diff_arg*arg.conjugate()+arg*diff_arg.conjugate())/2/abs(arg);
+}
+
 static void abs_print_latex(const ex & arg, const print_context & c)
 {
        c.s << "{|"; arg.print(c); c.s << "|}";
@@ -341,6 +391,7 @@ bool abs_info(const ex & arg, unsigned inf)
 REGISTER_FUNCTION(abs, eval_func(abs_eval).
                        evalf_func(abs_evalf).
                        expand_func(abs_expand).
+                       expl_derivative_func(abs_expl_derivative).
                        info_func(abs_info).
                        print_func<print_latex>(abs_print_latex).
                        print_func<print_csrc_float>(abs_print_csrc_float).
@@ -977,11 +1028,15 @@ static ex Order_imag_part(const ex & x)
        return Order(x).hold();
 }
 
-// Differentiation is handled in function::derivative because of its special requirements
+static ex Order_expl_derivative(const ex & arg, const symbol & s)
+{
+       return Order(arg.diff(s));
+}
 
 REGISTER_FUNCTION(Order, eval_func(Order_eval).
                          series_func(Order_series).
                          latex_name("\\mathcal{O}").
+                         expl_derivative_func(Order_expl_derivative).
                          conjugate_func(Order_conjugate).
                          real_part_func(Order_real_part).
                          imag_part_func(Order_imag_part));