www.ginac.de Git - ginac.git/blob - ginac/fderivative.cpp
series expansion behaviour fixed.
[ginac.git] / ginac / fderivative.cpp
1 /** @file fderivative.cpp
2  *
3  *  Implementation of abstract derivatives of functions. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2003 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
21  */
22
23 #include <iostream>
24
25 #include "fderivative.h"
26 #include "operators.h"
27 #include "print.h"
28 #include "archive.h"
29 #include "utils.h"
30
31 namespace GiNaC {
32
33 GINAC_IMPLEMENT_REGISTERED_CLASS(fderivative, function)
34
35 //////////
36 // default ctor, dtor, copy ctor, assignment operator and helpers
37 //////////
38
39 fderivative::fderivative()
40 {
41         tinfo_key = TINFO_fderivative;
42 }
43
44 void fderivative::copy(const fderivative & other)
45 {
46         inherited::copy(other);
47         parameter_set = other.parameter_set;
48 }
49
50 DEFAULT_DESTROY(fderivative)
51
52 //////////
53 // other constructors
54 //////////
55
56 fderivative::fderivative(unsigned ser, unsigned param, const exvector & args) : function(ser, args)
57 {
58         parameter_set.insert(param);
59         tinfo_key = TINFO_fderivative;
60 }
61
62 fderivative::fderivative(unsigned ser, const paramset & params, const exvector & args) : function(ser, args), parameter_set(params)
63 {
64         tinfo_key = TINFO_fderivative;
65 }
66
67 fderivative::fderivative(unsigned ser, const paramset & params, exvector * vp) : function(ser, vp), parameter_set(params)
68 {
69         tinfo_key = TINFO_fderivative;
70 }
71
72 //////////
73 // archiving
74 //////////
75
76 fderivative::fderivative(const archive_node &n, lst &sym_lst) : inherited(n, sym_lst)
77 {
78         unsigned i = 0;
79         while (true) {
80                 unsigned u;
81                 if (n.find_unsigned("param", u, i))
82                         parameter_set.insert(u);
83                 else
84                         break;
85                 ++i;
86         }
87 }
88
89 void fderivative::archive(archive_node &n) const
90 {
91         inherited::archive(n);
92         paramset::const_iterator i = parameter_set.begin(), end = parameter_set.end();
93         while (i != end) {
94                 n.add_unsigned("param", *i);
95                 ++i;
96         }
97 }
98
99 DEFAULT_UNARCHIVE(fderivative)
100
101 //////////
102 // functions overriding virtual functions from base classes
103 //////////
104
105 void fderivative::print(const print_context & c, unsigned level) const
106 {
107         if (is_a<print_tree>(c)) {
108
109                 c.s << std::string(level, ' ') << class_name() << " "
110                     << registered_functions()[serial].name
111                     << std::hex << ", hash=0x" << hashvalue << ", flags=0x" << flags << std::dec
112                     << ", nops=" << nops()
113                     << ", params=";
114                 paramset::const_iterator i = parameter_set.begin(), end = parameter_set.end();
115                 --end;
116                 while (i != end)
117                         c.s << *i++ << ",";
118                 c.s << *i << std::endl;
119                 unsigned delta_indent = static_cast<const print_tree &>(c).delta_indent;
120                 for (size_t i=0; i<seq.size(); ++i)
121                         seq[i].print(c, level + delta_indent);
122                 c.s << std::string(level + delta_indent, ' ') << "=====" << std::endl;
123
124         } else {
125
126                 c.s << "D[";
127                 paramset::const_iterator i = parameter_set.begin(), end = parameter_set.end();
128                 --end;
129                 while (i != end)
130                         c.s << *i++ << ",";
131                 c.s << *i << "](" << registered_functions()[serial].name << ")";
132                 printseq(c, '(', ',', ')', exprseq::precedence(), function::precedence());
133         }
134 }
135
136 ex fderivative::eval(int level) const
137 {
138         if (level > 1) {
139                 // first evaluate children, then we will end up here again
140                 return fderivative(serial, parameter_set, evalchildren(level));
141         }
142
143         // No parameters specified? Then return the function itself
144         if (parameter_set.empty())
145                 return function(serial, seq);
146
147         // If the function in question actually has a derivative, return it
148         if (registered_functions()[serial].has_derivative() && parameter_set.size() == 1)
149                 return pderivative(*(parameter_set.begin()));
150
151         return this->hold();
152 }
153
154 /** Numeric evaluation falls back to evaluation of arguments.
155  *  @see basic::evalf */
156 ex fderivative::evalf(int level) const
157 {
158         return basic::evalf(level);
159 }
160
161 /** The series expansion of derivatives falls back to Taylor expansion.
162  *  @see basic::series */
163 ex fderivative::series(const relational & r, int order, unsigned options) const
164 {
165         return basic::series(r, order, options);
166 }
167
168 ex fderivative::thisexprseq(const exvector & v) const
169 {
170         return fderivative(serial, parameter_set, v);
171 }
172
173 ex fderivative::thisexprseq(exvector * vp) const
174 {
175         return fderivative(serial, parameter_set, vp);
176 }
177
178 /** Implementation of ex::diff() for derivatives. It applies the chain rule.
179  *  @see ex::diff */
180 ex fderivative::derivative(const symbol & s) const
181 {
182         ex result;
183         for (size_t i=0; i<seq.size(); i++) {
184                 ex arg_diff = seq[i].diff(s);
185                 if (!arg_diff.is_zero()) {
186                         paramset ps = parameter_set;
187                         ps.insert(i);
188                         result += arg_diff * fderivative(serial, ps, seq);
189                 }
190         }
191         return result;
192 }
193
194 int fderivative::compare_same_type(const basic & other) const
195 {
196         GINAC_ASSERT(is_a<fderivative>(other));
197         const fderivative & o = static_cast<const fderivative &>(other);
198
199         if (parameter_set != o.parameter_set)
200                 return parameter_set < o.parameter_set ? -1 : 1;
201         else
202                 return inherited::compare_same_type(o);
203 }
204
205 bool fderivative::is_equal_same_type(const basic & other) const
206 {
207         GINAC_ASSERT(is_a<fderivative>(other));
208         const fderivative & o = static_cast<const fderivative &>(other);
209
210         if (parameter_set != o.parameter_set)
211                 return false;
212         else
213                 return inherited::is_equal_same_type(o);
214 }
215
216 bool fderivative::match_same_type(const basic & other) const
217 {
218         GINAC_ASSERT(is_a<fderivative>(other));
219         const fderivative & o = static_cast<const fderivative &>(other);
220
221         return parameter_set == o.parameter_set;
222 }
223
224 } // namespace GiNaC