]> www.ginac.de Git - ginac.git/blobdiff - ginac/function.pl
- New figure classhierarchy.fig, which we all know, included in...
[ginac.git] / ginac / function.pl
index a680fc8685562b67630521aadb2dfa702eb2cc12..31893b2685e46991b685b0c4095bb270612c6706 100755 (executable)
@@ -1,6 +1,4 @@
-#!/usr/bin/perl -w
-
-$maxargs=10;
+$maxargs=13;
 
 sub generate_seq {
     my ($seq_template,$n)=@_;
@@ -155,7 +153,7 @@ $typedef_derivative_funcp=generate(
 'const ex &','');
 
 $typedef_series_funcp=generate(
-'typedef ex (* series_funcp_${N})(${SEQ1}, const symbol &, const ex &, int);'."\n",
+'typedef ex (* series_funcp_${N})(${SEQ1}, const relational &, int);'."\n",
 'const ex &','');
 
 $eval_func_interface=generate('    function_options & eval_func(eval_funcp_${N} e);'."\n",'','');
@@ -205,9 +203,9 @@ $series_switch_statement=generate(
     <<'END_OF_SERIES_SWITCH_STATEMENT','seq[${N}-1]','');
     case ${N}:
         try {
-            res = ((series_funcp_${N})(registered_functions()[serial].series_f))(${SEQ1},s,point,order);
+            res = ((series_funcp_${N})(registered_functions()[serial].series_f))(${SEQ1},r,order);
         } catch (do_taylor) {
-            res = basic::series(s, point, order);
+            res = basic::series(r, order);
         }
         return res;
         break;
@@ -479,7 +477,7 @@ public:
     ex expand(unsigned options=0) const;
     ex eval(int level=0) const;
     ex evalf(int level=0) const;
-    ex series(const symbol & s, const ex & point, int order) const;
+    ex series(const relational & r, int order) const;
     ex thisexprseq(const exvector & v) const;
     ex thisexprseq(exvector * vp) const;
 protected:
@@ -500,6 +498,7 @@ protected:
     void store_remember_table(ex const & result) const;
 public:
     static unsigned register_new(function_options const & opt);
+    static unsigned find_function(const string &name, unsigned nparams);
     unsigned getserial(void) const {return serial;}
     
 // member variables
@@ -508,7 +507,11 @@ protected:
     unsigned serial;
 };
 
-// utility macros
+// utility functions/macros
+inline const function &ex_to_function(const ex &e)
+{
+    return static_cast<const function &>(*e.bp);
+}
 
 #ifndef NO_NAMESPACE_GINAC
 
@@ -568,15 +571,21 @@ $implementation=<<END_OF_IMPLEMENTATION;
 
 #include "function.h"
 #include "ex.h"
+#include "lst.h"
 #include "archive.h"
 #include "inifcns.h"
 #include "utils.h"
 #include "debugmsg.h"
+#include "remember.h"
 
 #ifndef NO_NAMESPACE_GINAC
 namespace GiNaC {
 #endif // ndef NO_NAMESPACE_GINAC
 
+//////////
+// helper class function_options
+//////////
+
 function_options::function_options()
 {
     initialize();
@@ -668,113 +677,6 @@ void function_options::test_and_set_nparams(unsigned n)
     }
 }
 
-class remember_table_entry {
-public:
-    remember_table_entry(function const & f, ex const & r) :
-        hashvalue(f.gethash()), seq(f.seq), result(r)
-    {
-        last_access=0;
-        successful_hits=0;
-    }
-    bool is_equal(function const & f) const
-    {
-        GINAC_ASSERT(f.seq.size()==seq.size());
-        if (f.gethash()!=hashvalue) return false;
-        for (unsigned i=0; i<seq.size(); ++i) {
-            if (!seq[i].is_equal(f.seq[i])) return false;
-        }
-        last_access=access_counter++;
-        successful_hits++;
-        return true;
-    }
-    unsigned hashvalue;
-    exvector seq;
-    ex result;
-    mutable unsigned long last_access;
-    mutable unsigned successful_hits;
-
-    static unsigned access_counter;
-};    
-
-unsigned remember_table_entry::access_counter=0;
-
-class remember_table_list : public list<remember_table_entry> {
-public:
-    remember_table_list()
-    {
-        max_assoc_size=0;
-        delete_strategy=0;
-    }
-    remember_table_list(unsigned as, unsigned strat)
-    {
-        max_assoc_size=as;
-        delete_strategy=strat;
-    }
-    void add_entry(function const & f, ex const & result)
-    {
-        push_back(remember_table_entry(f,result));
-    }        
-    bool lookup_entry(function const & f, ex & result) const
-    {
-        for (const_iterator cit=begin(); cit!=end(); ++cit) {
-            if (cit->is_equal(f)) {
-                result=cit->result;
-                return true;
-            }
-        }
-        return false;
-    }
-protected:
-    unsigned max_assoc_size;
-    unsigned delete_strategy;
-};
-
-
-class remember_table : public vector<remember_table_list> {
-public:
-    remember_table()
-    {
-    }
-    remember_table(unsigned s, unsigned as, unsigned strat)
-    {
-        calc_size(s);
-        reserve(table_size);
-        for (unsigned i=0; i<table_size; ++i) {
-            push_back(remember_table_list(as,strat));
-        }
-    }
-    bool lookup_entry(function const & f, ex & result) const
-    {
-        unsigned entry=f.gethash() & (table_size-1);
-        if (entry>=size()) {
-            cerr << "entry=" << entry << ",size=" << size() << endl;
-        }
-        GINAC_ASSERT(entry<size());
-        return operator[](entry).lookup_entry(f,result);
-    }
-    void add_entry(function const & f, ex const & result)
-    {
-        unsigned entry=f.gethash() & (table_size-1);
-        GINAC_ASSERT(entry<size());
-        operator[](entry).add_entry(f,result);
-    }        
-    void calc_size(unsigned s)
-    {
-        // use some power of 2 next to s
-        table_size=1 << log2(s);
-    }
-protected:
-    unsigned table_size;
-};      
-
-// this is not declared as a static function in the class function
-// (like registered_function()) because of issues with cint
-static vector<remember_table> & remember_tables(void)
-{
-    static vector<remember_table> * rt=new vector<remember_table>;
-    return *rt;
-}
-
 GINAC_IMPLEMENT_REGISTERED_CLASS(function, exprseq)
 
 //////////
@@ -1044,12 +946,12 @@ ex function::thisexprseq(exvector * vp) const
 
 /** Implementation of ex::series for functions.
  *  \@see ex::series */
-ex function::series(const symbol & s, const ex & point, int order) const
+ex function::series(const relational & r, int order) const
 {
     GINAC_ASSERT(serial<registered_functions().size());
 
     if (registered_functions()[serial].series_f==0) {
-        return basic::series(s, point, order);
+        return basic::series(r, order);
     }
     ex res;
     switch (registered_functions()[serial].nparams) {
@@ -1070,9 +972,24 @@ ex function::derivative(const symbol & s) const
 {
     ex result;
     
-    if (serial==function_index_Order) {
+    if (serial == function_index_Order) {
         // Order Term function only differentiates the argument
         return Order(seq[0].diff(s));
+    } else if (serial == function_index_Derivative) {
+        // Inert derivative performs chain rule on the first argument only, and
+        // adds differentiation parameter to list (second argument)
+        GINAC_ASSERT(is_ex_exactly_of_type(seq[0], function));
+        GINAC_ASSERT(is_ex_exactly_of_type(seq[1], function));
+        ex fcn = seq[0];
+        ex arg_diff;
+        for (unsigned i=0; i!=fcn.nops(); i++) {
+            arg_diff = fcn.op(i).diff(s);
+            if (!arg_diff.is_zero()) {
+                lst new_lst = ex_to_lst(seq[1]);
+                new_lst.append(i);
+                result += arg_diff * Derivative(fcn, new_lst);
+            }
+        }
     } else {
         // Chain rule
         ex arg_diff;
@@ -1143,7 +1060,7 @@ ex function::pderivative(unsigned diff_param) const // partial differentiation
     GINAC_ASSERT(serial<registered_functions().size());
     
     if (registered_functions()[serial].derivative_f==0) {
-        throw(std::logic_error(string("function::pderivative(") + registered_functions()[serial].name + "): no diff function defined"));
+        return Derivative(*this, lst(diff_param));
     }
     switch (registered_functions()[serial].nparams) {
         // the following lines have been generated for max. ${maxargs} parameters
@@ -1161,12 +1078,12 @@ vector<function_options> & function::registered_functions(void)
 
 bool function::lookup_remember_table(ex & result) const
 {
-    return remember_tables()[serial].lookup_entry(*this,result);
+    return remember_table::remember_tables()[serial].lookup_entry(*this,result);
 }
 
 void function::store_remember_table(ex const & result) const
 {
-    remember_tables()[serial].add_entry(*this,result);
+    remember_table::remember_tables()[serial].add_entry(*this,result);
 }
 
 // public
@@ -1188,15 +1105,31 @@ unsigned function::register_new(function_options const & opt)
     }
     registered_functions().push_back(opt);
     if (opt.use_remember) {
-        remember_tables().push_back(remember_table(opt.remember_size,
-                                                   opt.remember_assoc_size,
-                                                   opt.remember_strategy));
+        remember_table::remember_tables().
+            push_back(remember_table(opt.remember_size,
+                                     opt.remember_assoc_size,
+                                     opt.remember_strategy));
     } else {
-        remember_tables().push_back(remember_table());
+        remember_table::remember_tables().push_back(remember_table());
     }
     return registered_functions().size()-1;
 }
 
+/** Find serial number of function by name and number of parameters.
+ *  Throws exception if function was not found. */
+unsigned function::find_function(const string &name, unsigned nparams)
+{
+    vector<function_options>::const_iterator i = function::registered_functions().begin(), end = function::registered_functions().end();
+    unsigned serial = 0;
+    while (i != end) {
+        if (i->get_name() == name && i->get_nparams() == nparams)
+            return serial;
+        i++;
+        serial++;
+    }
+    throw (std::runtime_error("no function '" + name + "' with " + ToString(nparams) + " parameters defined"));
+}
+
 //////////
 // static member variables
 //////////