- changed old-style power() to new-style pow()
[ginac.git] / ginac / diff.cpp
1 /** @file diff.cpp
2  *
3  *  Implementation of symbolic differentiation in all of GiNaC's classes. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2000 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
21  */
22
23 #include <stdexcept>
24
25 #include "basic.h"
26 #include "ex.h"
27 #include "add.h"
28 #include "constant.h"
29 #include "expairseq.h"
30 #include "indexed.h"
31 #include "inifcns.h"
32 #include "mul.h"
33 #include "ncmul.h"
34 #include "numeric.h"
35 #include "power.h"
36 #include "relational.h"
37 #include "pseries.h"
38 #include "symbol.h"
39 #include "utils.h"
40
41 #ifndef NO_GINAC_NAMESPACE
42 namespace GiNaC {
43 #endif // ndef NO_GINAC_NAMESPACE
44
45 /** Default implementation of ex::diff(). It prints and error message and returns a fail object.
46  *  @see ex::diff */
47 ex basic::diff(symbol const & s) const
48 {
49     throw(std::logic_error("differentiation not supported by this type"));
50 }
51
52
53 /** Implementation of ex::diff() for a numeric. It always returns 0.
54  *
55  *  @see ex::diff */
56 ex numeric::diff(symbol const & s) const
57 {
58     return _ex0();
59 }
60
61
62 /** Implementation of ex::diff() for single differentiation of a symbol.
63  *  It returns 1 or 0.
64  *
65  *  @see ex::diff */
66 ex symbol::diff(symbol const & s) const
67 {
68     if (compare_same_type(s)) {
69         return _ex0();
70     } else {
71         return _ex1();
72     }
73 }
74
75 /** Implementation of ex::diff() for a constant. It always returns 0.
76  *
77  *  @see ex::diff */
78 ex constant::diff(symbol const & s) const
79 {
80     return _ex0();
81 }
82
83 /** Implementation of ex::diff() for multiple differentiation of a symbol.
84  *  It returns the symbol, 1 or 0.
85  *
86  *  @param nth order of differentiation
87  *  @see ex::diff */
88 ex symbol::diff(symbol const & s, unsigned nth) const
89 {
90     if (compare_same_type(s)) {
91         switch (nth) {
92         case 0:
93             return s;
94             break;
95         case 1:
96             return _ex1();
97             break;
98         default:
99             return _ex0();
100         }
101     } else {
102         return _ex1();
103     }
104 }
105
106
107 /** Implementation of ex::diff() for an indexed object. It always returns 0.
108  *  @see ex::diff */
109 ex indexed::diff(symbol const & s) const
110 {
111         return _ex0();
112 }
113
114
115 /** Implementation of ex::diff() for an expairseq. It differentiates all elements of the sequence.
116  *  @see ex::diff */
117 ex expairseq::diff(symbol const & s) const
118 {
119     return thisexpairseq(diffchildren(s),overall_coeff);
120 }
121
122
123 /** Implementation of ex::diff() for a sum. It differentiates each term.
124  *  @see ex::diff */
125 ex add::diff(symbol const & s) const
126 {
127     // D(a+b+c)=D(a)+D(b)+D(c)
128     return (new add(diffchildren(s)))->setflag(status_flags::dynallocated);
129 }
130
131
132 /** Implementation of ex::diff() for a product. It applies the product rule.
133  *  @see ex::diff */
134 ex mul::diff(symbol const & s) const
135 {
136     exvector new_seq;
137     new_seq.reserve(seq.size());
138
139     // D(a*b*c)=D(a)*b*c+a*D(b)*c+a*b*D(c)
140     for (unsigned i=0; i!=seq.size(); i++) {
141         epvector sub_seq=seq;
142         sub_seq[i] = split_ex_to_pair(sub_seq[i].coeff*
143                                       power(sub_seq[i].rest,sub_seq[i].coeff-1)*
144                                       sub_seq[i].rest.diff(s));
145         new_seq.push_back((new mul(sub_seq,overall_coeff))->setflag(status_flags::dynallocated));
146     }
147     return (new add(new_seq))->setflag(status_flags::dynallocated);
148 }
149
150
151 /** Implementation of ex::diff() for a non-commutative product. It always returns 0.
152  *  @see ex::diff */
153 ex ncmul::diff(symbol const & s) const
154 {
155     return _ex0();
156 }
157
158
159 /** Implementation of ex::diff() for a power.
160  *  @see ex::diff */
161 ex power::diff(symbol const & s) const
162 {
163     if (exponent.info(info_flags::real)) {
164         // D(b^r) = r * b^(r-1) * D(b) (faster than the formula below)
165         return mul(mul(exponent, power(basis, exponent - _ex1())), basis.diff(s));
166     } else {
167         // D(b^e) = b^e * (D(e)*ln(b) + e*D(b)/b)
168         return mul(power(basis, exponent),
169                    add(mul(exponent.diff(s), log(basis)),
170                        mul(mul(exponent, basis.diff(s)), power(basis, -1))));
171     }
172 }
173
174
175 /** Implementation of ex::diff() for functions. It applies the chain rule,
176  *  except for the Order term function.
177  *  @see ex::diff */
178 ex function::diff(symbol const & s) const
179 {
180     exvector new_seq;
181     
182     if (serial==function_index_Order) {
183         // Order Term function only differentiates the argument
184         return Order(seq[0].diff(s));
185     } else {
186         // Chain rule
187         ex arg_diff;
188         for (unsigned i=0; i!=seq.size(); i++) {
189             arg_diff = seq[i].diff(s);
190             // We apply the chain rule only when it makes sense.  This is not
191             // just for performance reasons but also to allow functions to
192             // throw when differentiated with respect to one of its arguments
193             // without running into trouble with our automatic full
194             // differentiation:
195             if (!arg_diff.is_zero())
196                 new_seq.push_back(mul(pdiff(i), arg_diff));
197         }
198     }
199     return add(new_seq);
200 }
201
202
203 /** Implementation of ex::diff() for a power series. It treats the series as a polynomial.
204  *  @see ex::diff */
205 ex pseries::diff(symbol const & s) const
206 {
207     if (s == var) {
208         epvector new_seq;
209         epvector::const_iterator it = seq.begin(), itend = seq.end();
210         
211         // FIXME: coeff might depend on var
212         while (it != itend) {
213             if (is_order_function(it->rest)) {
214                 new_seq.push_back(expair(it->rest, it->coeff - 1));
215             } else {
216                 ex c = it->rest * it->coeff;
217                 if (!c.is_zero())
218                     new_seq.push_back(expair(c, it->coeff - 1));
219             }
220             it++;
221         }
222         return pseries(var, point, new_seq);
223     } else {
224         return *this;
225     }
226 }
227
228
229 /** Compute partial derivative of an expression.
230  *
231  *  @param s  symbol by which the expression is derived
232  *  @param nth  order of derivative (default 1)
233  *  @return partial derivative as a new expression */
234
235 ex ex::diff(symbol const & s, unsigned nth) const
236 {
237     GINAC_ASSERT(bp!=0);
238
239     if (nth==0) {
240         return *this;
241     }
242
243     ex ndiff = bp->diff(s);
244     while (nth>1) {
245         ndiff = ndiff.diff(s);
246         --nth;
247     }
248     return ndiff;
249 }
250
251 #ifndef NO_GINAC_NAMESPACE
252 } // namespace GiNaC
253 #endif // ndef NO_GINAC_NAMESPACE