]> www.ginac.de Git - ginac.git/blobdiff - ginac/function.cppy
[PATCH] Check number of parameters when reading function from archive.
[ginac.git] / ginac / function.cppy
index 0a521e11a19e333d3ca0b3769248815018de1a05..6739eb9ebafdf6ebf8108e4b3a30125f0493ef78 100644 (file)
@@ -7,7 +7,7 @@
  *  Please do not modify it directly, edit function.cppy instead!
  *  function.py options: maxargs=@maxargs@
  *
- *  GiNaC Copyright (C) 1999-2015 Johannes Gutenberg University Mainz, Germany
+ *  GiNaC Copyright (C) 1999-2020 Johannes Gutenberg University Mainz, Germany
  *
  *  This program is free software; you can redistribute it and/or modify
  *  it under the terms of the GNU General Public License as published by
@@ -34,7 +34,6 @@
 #include "power.h"
 #include "archive.h"
 #include "inifcns.h"
-#include "tostring.h"
 #include "utils.h"
 #include "hash_seed.h"
 #include "remember.h"
@@ -76,11 +75,11 @@ function_options::~function_options()
 
 void function_options::initialize()
 {
-       set_name("unnamed_function", "\\\\mbox{unnamed}");
+       set_name("unnamed_function", "\\mbox{unnamed}");
        nparams = 0;
        eval_f = evalf_f = real_part_f = imag_part_f = conjugate_f = expand_f
-               = derivative_f = expl_derivative_f = power_f = series_f = 0;
-       info_f = 0;
+               = derivative_f = expl_derivative_f = power_f = series_f = nullptr;
+       info_f = nullptr;
        evalf_params_first = true;
        use_return_type = false;
        eval_use_exvector_args = false;
@@ -105,7 +104,7 @@ function_options & function_options::set_name(std::string const & n,
 {
        name = n;
        if (tn==std::string())
-               TeX_name = "\\\\mbox{"+name+"}";
+               TeX_name = "\\mbox{"+name+"}";
        else
                TeX_name = tn;
        return *this;
@@ -143,7 +142,7 @@ function_options & function_options::set_return_type(unsigned rt, const return_t
 {
        use_return_type = true;
        return_type = rt;
-       if (rtt != 0)
+       if (rtt != nullptr)
                return_type_tinfo = *rtt;
        else
                return_type_tinfo = make_return_type_t<function>();
@@ -229,7 +228,7 @@ function::function(unsigned ser) : serial(ser)
 // the following lines have been generated for max. @maxargs@ parameters
 +++ for N in range(1, maxargs + 1):
 function::function(unsigned ser, @seq('const ex & param%(n)d', N)@)
-       : exprseq(@seq('param%(n)d', N)@), serial(ser)
+       : exprseq{@seq('param%(n)d', N)@}, serial(ser)
 {
 }
 ---
@@ -242,8 +241,8 @@ function::function(unsigned ser, const exprseq & es) : exprseq(es), serial(ser)
        clearflag(status_flags::evaluated);
 }
 
-function::function(unsigned ser, const exvector & v, bool discardable) 
-  : exprseq(v,discardable), serial(ser)
+function::function(unsigned ser, const exvector & v)
+  : exprseq(v), serial(ser)
 {
 }
 
@@ -260,18 +259,20 @@ function::function(unsigned ser, exvector && v)
 void function::read_archive(const archive_node& n, lst& sym_lst)
 {
        inherited::read_archive(n, sym_lst);
-       // Find serial number by function name
+       // Find serial number by function name and number of parameters
+       unsigned np = seq.size();
        std::string s;
        if (n.find_string("name", s)) {
                unsigned int ser = 0;
                for (auto & it : registered_functions()) {
-                       if (s == it.name) {
+                       if (s == it.name && np == registered_functions()[ser].nparams) {
                                serial = ser;
                                return;
                        }
                        ++ser;
                }
-               throw (std::runtime_error("unknown function '" + s + "' in archive"));
+               throw (std::runtime_error("unknown function '" + s +
+                                         "' with " + std::to_string(np) + " parameters in archive"));
        } else
                throw (std::runtime_error("unnamed function in archive"));
 }
@@ -363,11 +364,10 @@ next_context:
        }
 }
 
-ex function::eval(int level) const
+ex function::eval() const
 {
-       if (level>1) {
-               // first evaluate children, then we will end up here again
-               return function(serial,evalchildren(level));
+       if (flags & status_flags::evaluated) {
+               return *this;
        }
 
        GINAC_ASSERT(serial<registered_functions().size());
@@ -382,11 +382,11 @@ ex function::eval(int level) const
                        // Something has changed while sorting arguments, more evaluations later
                        if (sig == 0)
                                return _ex0;
-                       return ex(sig) * thiscontainer(v);
+                       return ex(sig) * thiscontainer(std::move(v));
                }
        }
 
-       if (opt.eval_f==0) {
+       if (opt.eval_f==nullptr) {
                return this->hold();
        }
 
@@ -416,26 +416,23 @@ ex function::eval(int level) const
        return eval_result;
 }
 
-ex function::evalf(int level) const
+ex function::evalf() const
 {
        GINAC_ASSERT(serial<registered_functions().size());
        const function_options &opt = registered_functions()[serial];
 
        // Evaluate children first
        exvector eseq;
-       if (level == 1 || !(opt.evalf_params_first))
+       if (!opt.evalf_params_first)
                eseq = seq;
-       else if (level == -max_recursion_level)
-               throw(std::runtime_error("max recursion level reached"));
        else {
                eseq.reserve(seq.size());
-               --level;
                for (auto & it : seq) {
-                       eseq.push_back(it.evalf(level));
+                       eseq.push_back(it.evalf());
                }
        }
 
-       if (opt.evalf_f==0) {
+       if (opt.evalf_f==nullptr) {
                return function(serial,eseq).hold();
        }
        current_serial = serial;
@@ -453,7 +450,7 @@ ex function::evalf(int level) const
 }
 
 /**
- *  This method is defined to be in line with behaviour of function::return_type()
+ *  This method is defined to be in line with behavior of function::return_type()
  */
 ex function::eval_ncmul(const exvector & v) const
 {
@@ -494,7 +491,7 @@ ex function::series(const relational & r, int order, unsigned options) const
        GINAC_ASSERT(serial<registered_functions().size());
        const function_options &opt = registered_functions()[serial];
 
-       if (opt.series_f==0) {
+       if (opt.series_f==nullptr) {
                return basic::series(r, order);
        }
        ex res;
@@ -529,7 +526,7 @@ ex function::conjugate() const
        GINAC_ASSERT(serial<registered_functions().size());
        const function_options & opt = registered_functions()[serial];
 
-       if (opt.conjugate_f==0) {
+       if (opt.conjugate_f==nullptr) {
                return conjugate_function(*this).hold();
        }
 
@@ -554,7 +551,7 @@ ex function::real_part() const
        GINAC_ASSERT(serial<registered_functions().size());
        const function_options & opt = registered_functions()[serial];
 
-       if (opt.real_part_f==0)
+       if (opt.real_part_f==nullptr)
                return basic::real_part();
 
        if (opt.real_part_use_exvector_args)
@@ -577,7 +574,7 @@ ex function::imag_part() const
        GINAC_ASSERT(serial<registered_functions().size());
        const function_options & opt = registered_functions()[serial];
 
-       if (opt.imag_part_f==0)
+       if (opt.imag_part_f==nullptr)
                return basic::imag_part();
 
        if (opt.imag_part_use_exvector_args)
@@ -600,7 +597,7 @@ bool function::info(unsigned inf) const
        GINAC_ASSERT(serial<registered_functions().size());
        const function_options & opt = registered_functions()[serial];
 
-       if (opt.info_f==0) {
+       if (opt.info_f==nullptr) {
                return basic::info(inf);
        }
 
@@ -769,7 +766,7 @@ ex function::expl_derivative(const symbol & s) const // explicit differentiation
                        // end of generated lines
                }
        }
-       // There is no fallback for explicit deriviative.
+       // There is no fallback for explicit derivative.
        throw(std::logic_error("function::expl_derivative(): explicit derivation is called, but no such function defined"));
 }
 
@@ -793,8 +790,7 @@ ex function::power(const ex & power_param) const // power of function
                }
        }
        // No power function defined? Fall back to returning a power object.
-       return (new GiNaC::power(*this, power_param))->setflag(status_flags::dynallocated |
-                                                              status_flags::evaluated);
+       return dynallocate<GiNaC::power>(*this, power_param).setflag(status_flags::evaluated);
 }
 
 ex function::expand(unsigned options) const
@@ -844,8 +840,8 @@ void function::store_remember_table(ex const & result) const
 unsigned function::register_new(function_options const & opt)
 {
        size_t same_name = 0;
-       for (size_t i=0; i<registered_functions().size(); ++i) {
-               if (registered_functions()[i].name==opt.name) {
+       for (auto & i : registered_functions()) {
+               if (i.name==opt.name) {
                        ++same_name;
                }
        }
@@ -878,7 +874,7 @@ unsigned function::find_function(const std::string &name, unsigned nparams)
                        return serial;
                ++serial;
        }
-       throw (std::runtime_error("no function '" + name + "' with " + ToString(nparams) + " parameters defined"));
+       throw (std::runtime_error("no function '" + name + "' with " + std::to_string(nparams) + " parameters defined"));
 }
 
 /** Return the print name of the function. */