]> www.ginac.de Git - ginac.git/blobdiff - ginac/function.pl
- removed inert Diff() function; only Derivative() remains
[ginac.git] / ginac / function.pl
index cb0fe80f1d74b5e012c4b5be9acc2b094b4948e7..caa4175c3cc269b77d6619a654c5a8bd783a78f7 100755 (executable)
@@ -1,6 +1,6 @@
 #!/usr/bin/perl -w
 
-$maxargs=10;
+$maxargs=13;
 
 sub generate_seq {
     my ($seq_template,$n)=@_;
@@ -508,7 +508,11 @@ protected:
     unsigned serial;
 };
 
-// utility macros
+// utility functions/macros
+inline const function &ex_to_function(const ex &e)
+{
+    return static_cast<const function &>(*e.bp);
+}
 
 #ifndef NO_NAMESPACE_GINAC
 
@@ -568,6 +572,7 @@ $implementation=<<END_OF_IMPLEMENTATION;
 
 #include "function.h"
 #include "ex.h"
+#include "lst.h"
 #include "archive.h"
 #include "inifcns.h"
 #include "utils.h"
@@ -968,9 +973,24 @@ ex function::derivative(const symbol & s) const
 {
     ex result;
     
-    if (serial==function_index_Order) {
+    if (serial == function_index_Order) {
         // Order Term function only differentiates the argument
         return Order(seq[0].diff(s));
+    } else if (serial == function_index_Derivative) {
+        // Inert derivative performs chain rule on the first argument only, and
+        // adds differentiation parameter to list (second argument)
+        GINAC_ASSERT(is_ex_exactly_of_type(seq[0], function));
+        GINAC_ASSERT(is_ex_exactly_of_type(seq[1], function));
+        ex fcn = seq[0];
+        ex arg_diff;
+        for (unsigned i=0; i!=fcn.nops(); i++) {
+            arg_diff = fcn.op(i).diff(s);
+            if (!arg_diff.is_zero()) {
+                lst new_lst = ex_to_lst(seq[1]);
+                new_lst.append(i);
+                result += arg_diff * Derivative(fcn, new_lst);
+            }
+        }
     } else {
         // Chain rule
         ex arg_diff;
@@ -1041,7 +1061,7 @@ ex function::pderivative(unsigned diff_param) const // partial differentiation
     GINAC_ASSERT(serial<registered_functions().size());
     
     if (registered_functions()[serial].derivative_f==0) {
-        throw(std::logic_error(string("function::pderivative(") + registered_functions()[serial].name + "): no diff function defined"));
+        return Derivative(*this, lst(diff_param));
     }
     switch (registered_functions()[serial].nparams) {
         // the following lines have been generated for max. ${maxargs} parameters