www.ginac.de Git - ginac.git/blob - ginac/fderivative.cpp
parser: add necessary checks to operator() to stop accepting nonsense.
[ginac.git] / ginac / fderivative.cpp
1 /** @file fderivative.cpp
2  *
3  *  Implementation of abstract derivatives of functions. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2008 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
21  */
22
23 #include <iostream>
24
25 #include "fderivative.h"
26 #include "operators.h"
27 #include "archive.h"
28 #include "utils.h"
29
30 namespace GiNaC {
31
32 GINAC_IMPLEMENT_REGISTERED_CLASS_OPT(fderivative, function,
33   print_func<print_context>(&fderivative::do_print).
34   print_func<print_csrc>(&fderivative::do_print_csrc).
35   print_func<print_tree>(&fderivative::do_print_tree))
36
37 //////////
38 // default constructor
39 //////////
40
41 fderivative::fderivative()
42 {
43         tinfo_key = &fderivative::tinfo_static;
44 }
45
46 //////////
47 // other constructors
48 //////////
49
50 fderivative::fderivative(unsigned ser, unsigned param, const exvector & args) : function(ser, args)
51 {
52         parameter_set.insert(param);
53         tinfo_key = &fderivative::tinfo_static;
54 }
55
56 fderivative::fderivative(unsigned ser, const paramset & params, const exvector & args) : function(ser, args), parameter_set(params)
57 {
58         tinfo_key = &fderivative::tinfo_static;
59 }
60
61 fderivative::fderivative(unsigned ser, const paramset & params, std::auto_ptr<exvector> vp) : function(ser, vp), parameter_set(params)
62 {
63         tinfo_key = &fderivative::tinfo_static;
64 }
65
66 //////////
67 // archiving
68 //////////
69
70 fderivative::fderivative(const archive_node &n, lst &sym_lst) : inherited(n, sym_lst)
71 {
72         unsigned i = 0;
73         while (true) {
74                 unsigned u;
75                 if (n.find_unsigned("param", u, i))
76                         parameter_set.insert(u);
77                 else
78                         break;
79                 ++i;
80         }
81 }
82
83 void fderivative::archive(archive_node &n) const
84 {
85         inherited::archive(n);
86         paramset::const_iterator i = parameter_set.begin(), end = parameter_set.end();
87         while (i != end) {
88                 n.add_unsigned("param", *i);
89                 ++i;
90         }
91 }
92
93 DEFAULT_UNARCHIVE(fderivative)
94
95 //////////
96 // functions overriding virtual functions from base classes
97 //////////
98
99 void fderivative::print(const print_context & c, unsigned level) const
100 {
101         // class function overrides print(), but we don't want that
102         basic::print(c, level);
103 }
104
105 void fderivative::do_print(const print_context & c, unsigned level) const
106 {
107         c.s << "D[";
108         paramset::const_iterator i = parameter_set.begin(), end = parameter_set.end();
109         --end;
110         while (i != end) {
111                 c.s << *i++ << ",";
112         }
113         c.s << *i << "](" << registered_functions()[serial].name << ")";
114         printseq(c, '(', ',', ')', exprseq::precedence(), function::precedence());
115 }
116
117 void fderivative::do_print_csrc(const print_csrc & c, unsigned level) const
118 {
119         c.s << "D_";
120         paramset::const_iterator i = parameter_set.begin(), end = parameter_set.end();
121         --end;
122         while (i != end)
123                 c.s << *i++ << "_";
124         c.s << *i << "_" << registered_functions()[serial].name;
125         printseq(c, '(', ',', ')', exprseq::precedence(), function::precedence());
126 }
127
128 void fderivative::do_print_tree(const print_tree & c, unsigned level) const
129 {
130         c.s << std::string(level, ' ') << class_name() << " "
131             << registered_functions()[serial].name << " @" << this
132             << std::hex << ", hash=0x" << hashvalue << ", flags=0x" << flags << std::dec
133             << ", nops=" << nops()
134             << ", params=";
135         paramset::const_iterator i = parameter_set.begin(), end = parameter_set.end();
136         --end;
137         while (i != end)
138                 c.s << *i++ << ",";
139         c.s << *i << std::endl;
140         for (size_t i=0; i<seq.size(); ++i)
141                 seq[i].print(c, level + c.delta_indent);
142         c.s << std::string(level + c.delta_indent, ' ') << "=====" << std::endl;
143 }
144
145 ex fderivative::eval(int level) const
146 {
147         if (level > 1) {
148                 // first evaluate children, then we will end up here again
149                 return fderivative(serial, parameter_set, evalchildren(level));
150         }
151
152         // No parameters specified? Then return the function itself
153         if (parameter_set.empty())
154                 return function(serial, seq);
155
156         // If the function in question actually has a derivative, return it
157         if (registered_functions()[serial].has_derivative() && parameter_set.size() == 1)
158                 return pderivative(*(parameter_set.begin()));
159
160         return this->hold();
161 }
162
163 /** Numeric evaluation falls back to evaluation of arguments.
164  *  @see basic::evalf */
165 ex fderivative::evalf(int level) const
166 {
167         return basic::evalf(level);
168 }
169
170 /** The series expansion of derivatives falls back to Taylor expansion.
171  *  @see basic::series */
172 ex fderivative::series(const relational & r, int order, unsigned options) const
173 {
174         return basic::series(r, order, options);
175 }
176
177 ex fderivative::thiscontainer(const exvector & v) const
178 {
179         return fderivative(serial, parameter_set, v);
180 }
181
182 ex fderivative::thiscontainer(std::auto_ptr<exvector> vp) const
183 {
184         return fderivative(serial, parameter_set, vp);
185 }
186
187 /** Implementation of ex::diff() for derivatives. It applies the chain rule.
188  *  @see ex::diff */
189 ex fderivative::derivative(const symbol & s) const
190 {
191         ex result;
192         for (size_t i=0; i<seq.size(); i++) {
193                 ex arg_diff = seq[i].diff(s);
194                 if (!arg_diff.is_zero()) {
195                         paramset ps = parameter_set;
196                         ps.insert(i);
197                         result += arg_diff * fderivative(serial, ps, seq);
198                 }
199         }
200         return result;
201 }
202
203 int fderivative::compare_same_type(const basic & other) const
204 {
205         GINAC_ASSERT(is_a<fderivative>(other));
206         const fderivative & o = static_cast<const fderivative &>(other);
207
208         if (parameter_set != o.parameter_set)
209                 return parameter_set < o.parameter_set ? -1 : 1;
210         else
211                 return inherited::compare_same_type(o);
212 }
213
214 bool fderivative::is_equal_same_type(const basic & other) const
215 {
216         GINAC_ASSERT(is_a<fderivative>(other));
217         const fderivative & o = static_cast<const fderivative &>(other);
218
219         if (parameter_set != o.parameter_set)
220                 return false;
221         else
222                 return inherited::is_equal_same_type(o);
223 }
224
225 bool fderivative::match_same_type(const basic & other) const
226 {
227         GINAC_ASSERT(is_a<fderivative>(other));
228         const fderivative & o = static_cast<const fderivative &>(other);
229
230         return parameter_set == o.parameter_set && inherited::match_same_type(other);
231 }
232
233 } // namespace GiNaC