]> www.ginac.de Git - ginac.git/blobdiff - ginac/function.pl
- Pearu Peterson's patch for class function giving them better support
[ginac.git] / ginac / function.pl
index 26f37afdc29db6bd3e07cd6382bb51731749d79d..68673ddc38e9fb0032eefa1fb733a17f6f623201 100755 (executable)
@@ -240,6 +240,13 @@ $typedef_derivative_funcp
 $typedef_series_funcp
 // end of generated lines
 
+// Alternatively, an exvector may be passed into the static function, instead
+// of individual ex objects.  Then, the number of arguments is not limited.
+typedef ex (* eval_funcp_exvector)(const exvector &);
+typedef ex (* evalf_funcp_exvector)(const exvector &);
+typedef ex (* derivative_funcp_exvector)(const exvector &, unsigned);
+typedef ex (* series_funcp_exvector)(const exvector &, const relational &, int, unsigned);
+
 class function_options
 {
        friend class function;
@@ -257,6 +264,11 @@ $evalf_func_interface
 $derivative_func_interface
 $series_func_interface
 // end of generated lines
+       function_options & eval_func(eval_funcp_exvector e);
+       function_options & evalf_func(evalf_funcp_exvector ef);
+       function_options & derivative_func(derivative_funcp_exvector d);
+       function_options & series_func(series_funcp_exvector s);
+
        function_options & set_return_type(unsigned rt, unsigned rtt=0);
        function_options & do_not_evalf_params(void);
        function_options & remember(unsigned size, unsigned assoc_size=0,
@@ -290,6 +302,11 @@ protected:
        unsigned remember_assoc_size;
        unsigned remember_strategy;
 
+       bool eval_use_exvector_args;
+       bool evalf_use_exvector_args;
+       bool derivative_use_exvector_args;
+       bool series_use_exvector_args;
+
        unsigned functions_with_same_name;
 
        ex symtree;
@@ -354,10 +371,11 @@ protected:
        void store_remember_table(ex const & result) const;
 public:
        static unsigned register_new(function_options const & opt);
+       static unsigned current_serial;
        static unsigned find_function(const std::string &name, unsigned nparams);
        unsigned get_serial(void) const {return serial;}
        std::string get_name(void) const;
-       
+
 // member variables
 
 protected:
@@ -365,12 +383,6 @@ protected:
 };
 
 // utility functions/macros
-/** Return the object of type function handled by an ex.
- *  This is unsafe: you need to check the type first. */
-inline const function &ex_to_function(const ex &e)
-{
-       return static_cast<const function &>(*e.bp);
-}
 
 /** Specialization of is_exactly_a<function>(obj) for objects of type function. */
 template<> inline bool is_exactly_a<function>(const basic & obj)
@@ -379,7 +391,7 @@ template<> inline bool is_exactly_a<function>(const basic & obj)
 }
 
 #define is_ex_the_function(OBJ, FUNCNAME) \\
-       (is_exactly_a<function>(OBJ) && static_cast<GiNaC::function *>(OBJ.bp)->get_serial() == function_index_##FUNCNAME)
+       (is_exactly_a<GiNaC::function>(OBJ) && ex_to<GiNaC::function>(OBJ).get_serial() == function_index_##FUNCNAME)
 
 } // namespace GiNaC
 
@@ -460,6 +472,10 @@ void function_options::initialize(void)
        eval_f = evalf_f = derivative_f = series_f = 0;
        evalf_params_first = true;
        use_return_type = false;
+       eval_use_exvector_args = false;
+       evalf_use_exvector_args = false;
+       derivative_use_exvector_args = false;
+       series_use_exvector_args = false;
        use_remember = false;
        functions_with_same_name = 1;
        symtree = 0;
@@ -468,7 +484,7 @@ void function_options::initialize(void)
 function_options & function_options::set_name(std::string const & n,
                                               std::string const & tn)
 {
-       name=n;
+       name = n;
        if (tn==std::string())
                TeX_name = "\\\\mbox{"+name+"}";
        else
@@ -478,7 +494,7 @@ function_options & function_options::set_name(std::string const & n,
 
 function_options & function_options::latex_name(std::string const & tn)
 {
-       TeX_name=tn;
+       TeX_name = tn;
        return *this;
 }
 
@@ -489,6 +505,32 @@ $derivative_func_implementation
 $series_func_implementation
 // end of generated lines
 
+function_options& function_options::eval_func(eval_funcp_exvector e)
+{
+       eval_use_exvector_args = true;
+       eval_f = eval_funcp(e);
+       return *this;
+}
+function_options& function_options::evalf_func(evalf_funcp_exvector ef)
+{
+       evalf_use_exvector_args = true;
+       evalf_f = evalf_funcp(ef);
+       return *this;
+}
+function_options& function_options::derivative_func(derivative_funcp_exvector d)
+{
+       derivative_use_exvector_args = true;
+       derivative_f = derivative_funcp(d);
+       return *this;
+}
+function_options& function_options::series_func(series_funcp_exvector s)
+{
+       series_use_exvector_args = true;
+       series_f = series_funcp(s);
+       return *this;
+}
+
+
 function_options & function_options::set_return_type(unsigned rt, unsigned rtt)
 {
        use_return_type = true;
@@ -540,6 +582,10 @@ void function_options::test_and_set_nparams(unsigned n)
        }
 }
 
+/** This can be used as a hook for external applications. */
+unsigned function::current_serial = 0;
+
+
 GINAC_IMPLEMENT_REGISTERED_CLASS(function, exprseq)
 
 //////////
@@ -707,17 +753,17 @@ ex function::expand(unsigned options) const
 
 int function::degree(const ex & s) const
 {
-       return is_equal(*s.bp) ? 1 : 0;
+       return is_equal(ex_to<basic>(s)) ? 1 : 0;
 }
 
 int function::ldegree(const ex & s) const
 {
-       return is_equal(*s.bp) ? 1 : 0;
+       return is_equal(ex_to<basic>(s)) ? 1 : 0;
 }
 
 ex function::coeff(const ex & s, int n) const
 {
-       if (is_equal(*s.bp))
+       if (is_equal(ex_to<basic>(s)))
                return n==1 ? _ex1() : _ex0();
        else
                return n==0 ? ex(*this) : _ex0();
@@ -756,7 +802,10 @@ ex function::eval(int level) const
        if (use_remember && lookup_remember_table(eval_result)) {
                return eval_result;
        }
-
+       current_serial = serial;
+       if (registered_functions()[serial].eval_use_exvector_args)
+               eval_result = ((eval_funcp_exvector)(registered_functions()[serial].eval_f))(seq);
+       else
        switch (opt.nparams) {
                // the following lines have been generated for max. ${maxargs} parameters
 ${eval_switch_statement}
@@ -792,6 +841,9 @@ ex function::evalf(int level) const
        if (registered_functions()[serial].evalf_f==0) {
                return function(serial,eseq).hold();
        }
+       current_serial = serial;
+       if (registered_functions()[serial].evalf_use_exvector_args)
+               return ((evalf_funcp_exvector)(registered_functions()[serial].evalf_f))(seq);
        switch (registered_functions()[serial].nparams) {
                // the following lines have been generated for max. ${maxargs} parameters
 ${evalf_switch_statement}
@@ -835,6 +887,15 @@ ex function::series(const relational & r, int order, unsigned options) const
                return basic::series(r, order);
        }
        ex res;
+       current_serial = serial;
+       if (registered_functions()[serial].series_use_exvector_args) {
+               try {
+                       res = ((series_funcp_exvector)(registered_functions()[serial].series_f))(seq, r, order, options);
+               } catch (do_taylor) {
+                       res = basic::series(r, order, options);
+               }
+               return res;
+       }
        switch (registered_functions()[serial].nparams) {
                // the following lines have been generated for max. ${maxargs} parameters
 ${series_switch_statement}
@@ -939,6 +1000,9 @@ ex function::pderivative(unsigned diff_param) const // partial differentiation
        if (registered_functions()[serial].derivative_f == NULL)
                return fderivative(serial, diff_param, seq);
 
+       current_serial = serial;
+       if (registered_functions()[serial].derivative_use_exvector_args)
+               return ((derivative_funcp_exvector)(registered_functions()[serial].derivative_f))(seq, diff_param);
        switch (registered_functions()[serial].nparams) {
                // the following lines have been generated for max. ${maxargs} parameters
 ${diff_switch_statement}