]> www.ginac.de Git - ginac.git/blobdiff - ginac/function.pl
- Instead of just totally symmetric or antisymmetric, complex symmetries
[ginac.git] / ginac / function.pl
index a08d35f11f507d61e963048e9d8a329f989aa5c2..7d45d6d449dec417dd3d96d5c2f0b32de5075edd 100755 (executable)
@@ -242,6 +242,7 @@ if (!automatic_typecheck) { \\
 namespace GiNaC {
 
 class function;
+class symmetry;
 
 typedef ex (* eval_funcp)();
 typedef ex (* evalf_funcp)();
@@ -276,6 +277,7 @@ $series_func_interface
        function_options & remember(unsigned size, unsigned assoc_size=0,
                                    unsigned strategy=remember_strategies::delete_never);
        function_options & overloaded(unsigned o);
+       function_options & set_symmetry(const symmetry & s);
        void test_and_set_nparams(unsigned n);
        std::string get_name(void) const { return name; }
        unsigned get_nparams(void) const { return nparams; }
@@ -303,6 +305,8 @@ protected:
        unsigned remember_strategy;
 
        unsigned functions_with_same_name;
+
+       ex symtree;
 };
 
 /** The class function is used to implement builtin functions like sin, cos...
@@ -344,6 +348,7 @@ public:
        ex evalf(int level=0) const;
        unsigned calchash(void) const;
        ex series(const relational & r, int order, unsigned options = 0) const;
+       bool match(const ex & pattern, lst & repl_lst) const;
        ex thisexprseq(const exvector & v) const;
        ex thisexprseq(exvector * vp) const;
 protected:
@@ -422,6 +427,7 @@ $implementation=<<END_OF_IMPLEMENTATION;
 #include "function.h"
 #include "ex.h"
 #include "lst.h"
+#include "symmetry.h"
 #include "print.h"
 #include "archive.h"
 #include "inifcns.h"
@@ -460,6 +466,7 @@ void function_options::initialize(void)
        use_return_type=false;
        use_remember=false;
        functions_with_same_name=1;
+       symtree = 0;
 }
 
 function_options & function_options::set_name(std::string const & n,
@@ -517,6 +524,12 @@ function_options & function_options::overloaded(unsigned o)
        functions_with_same_name=o;
        return *this;
 }
+
+function_options & function_options::set_symmetry(const symmetry & s)
+{
+       symtree = s;
+       return *this;
+}
        
 void function_options::test_and_set_nparams(unsigned n)
 {
@@ -677,7 +690,7 @@ void function::print(const print_context & c, unsigned level) const
                }
                c.s << ")";
 
-       } else if is_of_type(c, print_latex) {
+       } else if (is_of_type(c, print_latex)) {
                c.s << registered_functions()[serial].TeX_name;
                printseq(c, '(', ',', ')', exprseq::precedence(), function::precedence());
        } else {
@@ -718,17 +731,32 @@ ex function::eval(int level) const
                return function(serial,evalchildren(level));
        }
 
-       if (registered_functions()[serial].eval_f==0) {
+       const function_options &opt = registered_functions()[serial];
+
+       // Canonicalize argument order according to the symmetry properties
+       if (seq.size() > 1 && !(opt.symtree.is_zero())) {
+               exvector v = seq;
+               GINAC_ASSERT(is_ex_exactly_of_type(opt.symtree, symmetry));
+               int sig = canonicalize(v.begin(), ex_to_symmetry(opt.symtree));
+               if (sig != INT_MAX) {
+                       // Something has changed while sorting arguments, more evaluations later
+                       if (sig == 0)
+                               return _ex0();
+                       return ex(sig) * thisexprseq(v);
+               }
+       }
+
+       if (opt.eval_f==0) {
                return this->hold();
        }
 
-       bool use_remember=registered_functions()[serial].use_remember;
+       bool use_remember=opt.use_remember;
        ex eval_result;
        if (use_remember && lookup_remember_table(eval_result)) {
                return eval_result;
        }
 
-       switch (registered_functions()[serial].nparams) {
+       switch (opt.nparams) {
                // the following lines have been generated for max. ${maxargs} parameters
 ${eval_switch_statement}
                // end of generated lines
@@ -801,6 +829,14 @@ ${series_switch_statement}
        throw(std::logic_error("function::series(): invalid nparams"));
 }
 
+bool function::match(const ex & pattern, lst & repl_lst) const
+{
+       // Serial number must match
+       if (is_ex_of_type(pattern, function) && serial != ex_to_function(pattern).serial)
+               return false;
+       return inherited::match(pattern, repl_lst);
+}
+
 // protected