X-Git-Url: https://www.ginac.de/ginac.git//ginac.git?p=ginac.git;a=blobdiff_plain;f=ginac%2Ffunction.pl;h=565b0f91c4b9f98ac58874a19a3b659393689dfd;hp=b36df872337b225da8534ed867741009e6d92347;hb=b11c30cf00d90113c924e4a96e8fed0341c246c6;hpb=72c6a645189bbeacc565968753c15885f6c6829f diff --git a/ginac/function.pl b/ginac/function.pl index b36df872..565b0f91 100755 --- a/ginac/function.pl +++ b/ginac/function.pl @@ -34,25 +34,11 @@ sub generate { return generate_from_to($template,$seq_template1,$seq_template2,1,$maxargs); } -$declare_function_macro = <<'END_OF_DECLARE_FUNCTION_1_AND_2P_MACRO'; -#define DECLARE_FUNCTION_1P(NAME) \ -extern const unsigned function_index_##NAME; \ -inline GiNaC::function NAME(const GiNaC::ex & p1) { \ - return GiNaC::function(function_index_##NAME, p1); \ -} -#define DECLARE_FUNCTION_2P(NAME) \ -extern const unsigned function_index_##NAME; \ -inline GiNaC::function NAME(const GiNaC::ex & p1, const GiNaC::ex & p2) { \ - return GiNaC::function(function_index_##NAME, p1, p2); \ -} - -END_OF_DECLARE_FUNCTION_1_AND_2P_MACRO - -$declare_function_macro .= generate_from_to( - <<'END_OF_DECLARE_FUNCTION_MACRO','const GiNaC::ex & p${N}','p${N}',3,$maxargs); +$declare_function_macro = generate_from_to( + <<'END_OF_DECLARE_FUNCTION_MACRO','const GiNaC::ex & p${N}','p${N}',1,$maxargs); #define DECLARE_FUNCTION_${N}P(NAME) \\ extern const unsigned function_index_##NAME; \\ -inline GiNaC::function NAME(${SEQ1}) { \\ +inline const GiNaC::function NAME(${SEQ1}) { \\ return GiNaC::function(function_index_##NAME, ${SEQ2}); \\ } @@ -99,7 +85,7 @@ END_OF_CONSTRUCTORS_IMPLEMENTATION $eval_switch_statement=generate( <<'END_OF_EVAL_SWITCH_STATEMENT','seq[${N}-1]',''); case ${N}: - eval_result=((eval_funcp_${N})(registered_functions()[serial].eval_f))(${SEQ1}); + eval_result = ((eval_funcp_${N})(registered_functions()[serial].eval_f))(${SEQ1}); break; END_OF_EVAL_SWITCH_STATEMENT @@ -131,9 +117,9 @@ $eval_func_implementation=generate( function_options & function_options::eval_func(eval_funcp_${N} e) { test_and_set_nparams(${N}); - eval_f=eval_funcp(e); + eval_f = eval_funcp(e); return *this; -} +} END_OF_EVAL_FUNC_IMPLEMENTATION $evalf_func_implementation=generate( @@ -141,9 +127,9 @@ $evalf_func_implementation=generate( function_options & function_options::evalf_func(evalf_funcp_${N} ef) { test_and_set_nparams(${N}); - evalf_f=evalf_funcp(ef); + evalf_f = evalf_funcp(ef); return *this; -} +} END_OF_EVALF_FUNC_IMPLEMENTATION $derivative_func_implementation=generate( @@ -151,9 +137,9 @@ $derivative_func_implementation=generate( function_options & function_options::derivative_func(derivative_funcp_${N} d) { test_and_set_nparams(${N}); - derivative_f=derivative_funcp(d); + derivative_f = derivative_funcp(d); return *this; -} +} END_OF_DERIVATIVE_FUNC_IMPLEMENTATION $series_func_implementation=generate( @@ -161,9 +147,9 @@ $series_func_implementation=generate( function_options & function_options::series_func(series_funcp_${N} s) { test_and_set_nparams(${N}); - series_f=series_funcp(s); + series_f = series_funcp(s); return *this; -} +} END_OF_SERIES_FUNC_IMPLEMENTATION $interface=< #include -// CINT needs to work properly with +// CINT needs to work properly with #include #include "exprseq.h" @@ -212,19 +198,11 @@ $declare_function_macro const unsigned function_index_##NAME= \\ GiNaC::function::register_new(GiNaC::function_options(#NAME).OPT); -#define REGISTER_FUNCTION_OLD(NAME,E,EF,D,S) \\ -const unsigned function_index_##NAME= \\ - GiNaC::function::register_new(GiNaC::function_options(#NAME). \\ - eval_func(E). \\ - evalf_func(EF). \\ - derivative_func(D). \\ - series_func(S)); - #define BEGIN_TYPECHECK \\ bool automatic_typecheck=true; #define TYPECHECK(VAR,TYPE) \\ -if (!is_ex_exactly_of_type(VAR,TYPE)) { \\ +if (!is_exactly_a(VAR)) { \\ automatic_typecheck=false; \\ } else @@ -242,6 +220,7 @@ if (!automatic_typecheck) { \\ namespace GiNaC { class function; +class symmetry; typedef ex (* eval_funcp)(); typedef ex (* evalf_funcp)(); @@ -264,6 +243,7 @@ public: ~function_options(); void initialize(void); function_options & set_name(std::string const & n, std::string const & tn=std::string()); + function_options & latex_name(std::string const & tn); // the following lines have been generated for max. ${maxargs} parameters $eval_func_interface $evalf_func_interface @@ -275,9 +255,11 @@ $series_func_interface function_options & remember(unsigned size, unsigned assoc_size=0, unsigned strategy=remember_strategies::delete_never); function_options & overloaded(unsigned o); + function_options & set_symmetry(const symmetry & s); void test_and_set_nparams(unsigned n); std::string get_name(void) const { return name; } unsigned get_nparams(void) const { return nparams; } + bool has_derivative(void) const { return derivative_f != NULL; } protected: std::string name; @@ -302,6 +284,8 @@ protected: unsigned remember_strategy; unsigned functions_with_same_name; + + ex symtree; }; /** The class function is used to implement builtin functions like sin, cos... @@ -318,6 +302,7 @@ class function : public exprseq friend class remember_table_entry; // friend class remember_table_list; // friend class remember_table; + friend ex Derivative_eval(const ex &, const ex &); // member functions @@ -328,15 +313,13 @@ public: $constructors_interface // end of generated lines function(unsigned ser, const exprseq & es); - function(unsigned ser, const exvector & v, bool discardable=0); + function(unsigned ser, const exvector & v, bool discardable = false); function(unsigned ser, exvector * vp); // vp will be deleted // functions overriding virtual functions from bases classes public: - void printraw(std::ostream & os) const; - void print(std::ostream & os, unsigned upper_precedence=0) const; - void printtree(std::ostream & os, unsigned indent) const; - void printcsrc(std::ostream & os, unsigned type, unsigned upper_precedence=0) const; + void print(const print_context & c, unsigned level = 0) const; + unsigned precedence(void) const {return 70;} int degree(const ex & s) const; int ldegree(const ex & s) const; ex coeff(const ex & s, int n = 1) const; @@ -350,6 +333,7 @@ public: protected: ex derivative(const symbol & s) const; bool is_equal_same_type(const basic & other) const; + bool match_same_type(const basic & other) const; unsigned return_type(void) const; unsigned return_type_tinfo(void) const; @@ -365,7 +349,8 @@ protected: public: static unsigned register_new(function_options const & opt); static unsigned find_function(const std::string &name, unsigned nparams); - unsigned getserial(void) const {return serial;} + unsigned get_serial(void) const {return serial;} + std::string get_name(void) const; // member variables @@ -374,18 +359,21 @@ protected: }; // utility functions/macros +/** Return the object of type function handled by an ex. + * This is unsafe: you need to check the type first. */ inline const function &ex_to_function(const ex &e) { return static_cast(*e.bp); } -#define is_ex_the_function(OBJ, FUNCNAME) \\ - (is_ex_exactly_of_type(OBJ, function) && static_cast(OBJ.bp)->getserial() == function_index_##FUNCNAME) - -// global constants +/** Specialization of is_exactly_a(obj) for objects of type function. */ +template<> inline bool is_exactly_a(const basic & obj) +{ + return obj.tinfo()==TINFO_function; +} -extern const function some_function; -extern const std::type_info & typeid_function; +#define is_ex_the_function(OBJ, FUNCNAME) \\ + (is_ex_exactly_of_type(OBJ, function) && static_cast(OBJ.bp)->get_serial() == function_index_##FUNCNAME) } // namespace GiNaC @@ -427,6 +415,8 @@ $implementation=<print(os); - } - os << ")"; -} - -void function::print(std::ostream & os, unsigned upper_precedence) const -{ - debugmsg("function print",LOGLEVEL_PRINT); - - GINAC_ASSERT(serialprint(c); + ++it; + if (it != itend) + c.s << ","; + } + c.s << ")"; - // Print arguments, separated by commas - exvector::const_iterator it = seq.begin(); - exvector::const_iterator itend = seq.end(); - while (it != itend) { - it->bp->printcsrc(os, type, 0); - it++; - if (it != itend) - os << ","; + } else if (is_of_type(c, print_latex)) { + c.s << registered_functions()[serial].TeX_name; + printseq(c, '(', ',', ')', exprseq::precedence(), function::precedence()); + } else { + c.s << registered_functions()[serial].name; + printseq(c, '(', ',', ')', exprseq::precedence(), function::precedence()); } - os << ")"; } ex function::expand(unsigned options) const { - return this->setflag(status_flags::expanded); + // Only expand arguments when asked to do so + if (options & expand_options::expand_function_args) + return inherited::expand(options); + else + return (options == 0) ? setflag(status_flags::expanded) : *this; } int function::degree(const ex & s) const @@ -738,17 +724,32 @@ ex function::eval(int level) const return function(serial,evalchildren(level)); } - if (registered_functions()[serial].eval_f==0) { + const function_options &opt = registered_functions()[serial]; + + // Canonicalize argument order according to the symmetry properties + if (seq.size() > 1 && !(opt.symtree.is_zero())) { + exvector v = seq; + GINAC_ASSERT(is_ex_exactly_of_type(opt.symtree, symmetry)); + int sig = canonicalize(v.begin(), ex_to(opt.symtree)); + if (sig != INT_MAX) { + // Something has changed while sorting arguments, more evaluations later + if (sig == 0) + return _ex0(); + return ex(sig) * thisexprseq(v); + } + } + + if (opt.eval_f==0) { return this->hold(); } - bool use_remember=registered_functions()[serial].use_remember; + bool use_remember = opt.use_remember; ex eval_result; if (use_remember && lookup_remember_table(eval_result)) { return eval_result; } - switch (registered_functions()[serial].nparams) { + switch (opt.nparams) { // the following lines have been generated for max. ${maxargs} parameters ${eval_switch_statement} // end of generated lines @@ -765,7 +766,20 @@ ex function::evalf(int level) const { GINAC_ASSERT(serialevalf(level)); + ++it; + } if (registered_functions()[serial].evalf_f==0) { return function(serial,eseq).hold(); @@ -823,7 +837,6 @@ ${series_switch_statement} // protected - /** Implementation of ex::diff() for functions. It applies the chain rule, * except for the Order term function. * \@see ex::diff */ @@ -844,7 +857,7 @@ ex function::derivative(const symbol & s) const for (unsigned i=0; i!=fcn.nops(); i++) { arg_diff = fcn.op(i).diff(s); if (!arg_diff.is_zero()) { - lst new_lst = ex_to_lst(seq[1]); + lst new_lst = ex_to(seq[1]); new_lst.append(i); result += arg_diff * Derivative(fcn, new_lst); } @@ -852,7 +865,8 @@ ex function::derivative(const symbol & s) const } else { // Chain rule ex arg_diff; - for (unsigned i=0; i!=seq.size(); i++) { + unsigned num = seq.size(); + for (unsigned i=0; i(const_cast(other)); + const function & o = static_cast(other); - if (serial!=o.serial) { + if (serial != o.serial) return serial < o.serial ? -1 : 1; - } - return exprseq::compare_same_type(o); + else + return exprseq::compare_same_type(o); } bool function::is_equal_same_type(const basic & other) const { GINAC_ASSERT(is_of_type(other, function)); - const function & o=static_cast(const_cast(other)); + const function & o = static_cast(other); + + if (serial != o.serial) + return false; + else + return exprseq::is_equal_same_type(o); +} + +bool function::match_same_type(const basic & other) const +{ + GINAC_ASSERT(is_of_type(other, function)); + const function & o = static_cast(other); - if (serial!=o.serial) return false; - return exprseq::is_equal_same_type(o); + return serial == o.serial; } unsigned function::return_type(void) const { - if (seq.size()==0) { + if (seq.empty()) return return_types::commutative; - } - return (*seq.begin()).return_type(); + else + return seq.begin()->return_type(); } - + unsigned function::return_type_tinfo(void) const { - if (seq.size()==0) { + if (seq.empty()) return tinfo_key; - } - return (*seq.begin()).return_type_tinfo(); + else + return seq.begin()->return_type_tinfo(); } ////////// @@ -925,7 +949,7 @@ ex function::pderivative(unsigned diff_param) const // partial differentiation // the following lines have been generated for max. ${maxargs} parameters ${diff_switch_statement} // end of generated lines - } + } throw(std::logic_error("function::pderivative(): no diff function defined")); } @@ -949,10 +973,10 @@ void function::store_remember_table(ex const & result) const unsigned function::register_new(function_options const & opt) { - unsigned same_name=0; + unsigned same_name = 0; for (unsigned i=0; i=opt.functions_with_same_name) { @@ -983,24 +1007,18 @@ unsigned function::find_function(const std::string &name, unsigned nparams) while (i != end) { if (i->get_name() == name && i->get_nparams() == nparams) return serial; - i++; - serial++; + ++i; + ++serial; } throw (std::runtime_error("no function '" + name + "' with " + ToString(nparams) + " parameters defined")); } -////////// -// static member variables -////////// - -// none - -////////// -// global constants -////////// - -const function some_function; -const std::type_info & typeid_function=typeid(some_function); +/** Return the print name of the function. */ +std::string function::get_name(void) const +{ + GINAC_ASSERT(serial