65ee609a27cc8d04a7c76c43a585dc473e397f00
[ginac.git] / ginac / diff.cpp
1 /** @file diff.cpp
2  *
3  *  Implementation of symbolic differentiation in all of GiNaC's classes. */
4
5 /*
6  *  GiNaC Copyright (C) 1999 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
21  */
22
23 #include <stdexcept>
24
25 #include "basic.h"
26 #include "ex.h"
27 #include "add.h"
28 #include "constant.h"
29 #include "expairseq.h"
30 #include "indexed.h"
31 #include "inifcns.h"
32 #include "mul.h"
33 #include "ncmul.h"
34 #include "numeric.h"
35 #include "power.h"
36 #include "relational.h"
37 #include "series.h"
38 #include "symbol.h"
39
40 #ifndef NO_GINAC_NAMESPACE
41 namespace GiNaC {
42 #endif // ndef NO_GINAC_NAMESPACE
43
44 /** Default implementation of ex::diff(). It prints and error message and returns a fail object.
45  *  @see ex::diff */
46 ex basic::diff(symbol const & s) const
47 {
48     throw(std::logic_error("differentiation not supported by this type"));
49 }
50
51
52 /** Implementation of ex::diff() for a numeric. It always returns 0.
53  *
54  *  @see ex::diff */
55 ex numeric::diff(symbol const & s) const
56 {
57     return exZERO();
58 }
59
60
61 /** Implementation of ex::diff() for single differentiation of a symbol.
62  *  It returns 1 or 0.
63  *
64  *  @see ex::diff */
65 ex symbol::diff(symbol const & s) const
66 {
67     if (compare_same_type(s)) {
68         return exZERO();
69     } else {
70         return exONE();
71     }
72 }
73
74 /** Implementation of ex::diff() for a constant. It always returns 0.
75  *
76  *  @see ex::diff */
77 ex constant::diff(symbol const & s) const
78 {
79     return exZERO();
80 }
81
82 /** Implementation of ex::diff() for multiple differentiation of a symbol.
83  *  It returns the symbol, 1 or 0.
84  *
85  *  @param nth order of differentiation
86  *  @see ex::diff */
87 ex symbol::diff(symbol const & s, unsigned nth) const
88 {
89     if (compare_same_type(s)) {
90         switch (nth) {
91         case 0:
92             return s;
93             break;
94         case 1:
95             return exONE();
96             break;
97         default:
98             return exZERO();
99         }
100     } else {
101         return exONE();
102     }
103 }
104
105
106 /** Implementation of ex::diff() for an indexed object. It always returns 0.
107  *  @see ex::diff */
108 ex indexed::diff(symbol const & s) const
109 {
110         return exZERO();
111 }
112
113
114 /** Implementation of ex::diff() for an expairseq. It differentiates all elements of the sequence.
115  *  @see ex::diff */
116 ex expairseq::diff(symbol const & s) const
117 {
118     return thisexpairseq(diffchildren(s),overall_coeff);
119 }
120
121
122 /** Implementation of ex::diff() for a sum. It differentiates each term.
123  *  @see ex::diff */
124 ex add::diff(symbol const & s) const
125 {
126     // D(a+b+c)=D(a)+D(b)+D(c)
127     return (new add(diffchildren(s)))->setflag(status_flags::dynallocated);
128 }
129
130
131 /** Implementation of ex::diff() for a product. It applies the product rule.
132  *  @see ex::diff */
133 ex mul::diff(symbol const & s) const
134 {
135     exvector new_seq;
136     new_seq.reserve(seq.size());
137
138     // D(a*b*c)=D(a)*b*c+a*D(b)*c+a*b*D(c)
139     for (unsigned i=0; i!=seq.size(); i++) {
140         epvector sub_seq=seq;
141         sub_seq[i] = split_ex_to_pair(sub_seq[i].coeff*
142                                       power(sub_seq[i].rest,sub_seq[i].coeff-1)*
143                                       sub_seq[i].rest.diff(s));
144         new_seq.push_back((new mul(sub_seq,overall_coeff))->setflag(status_flags::dynallocated));
145     }
146     return (new add(new_seq))->setflag(status_flags::dynallocated);
147 }
148
149
150 /** Implementation of ex::diff() for a non-commutative product. It always returns 0.
151  *  @see ex::diff */
152 ex ncmul::diff(symbol const & s) const
153 {
154     return exZERO();
155 }
156
157
158 /** Implementation of ex::diff() for a power.
159  *  @see ex::diff */
160 ex power::diff(symbol const & s) const
161 {
162     if (exponent.info(info_flags::real)) {
163         // D(b^r) = r * b^(r-1) * D(b) (faster than the formula below)
164         return mul(mul(exponent, power(basis, exponent - exONE())), basis.diff(s));
165     } else {
166         // D(b^e) = b^e * (D(e)*ln(b) + e*D(b)/b)
167         return mul(power(basis, exponent),
168                    add(mul(exponent.diff(s), log(basis)),
169                        mul(mul(exponent, basis.diff(s)), power(basis, -1))));
170     }
171 }
172
173
174 /** Implementation of ex::diff() for functions. It applies the chain rule,
175  *  except for the Order term function.
176  *  @see ex::diff */
177 ex function::diff(symbol const & s) const
178 {
179     exvector new_seq;
180     
181     if (serial==function_index_Order) {
182         // Order Term function only differentiates the argument
183         return Order(seq[0].diff(s));
184     } else {
185         // Chain rule
186         ex arg_diff;
187         for (unsigned i=0; i!=seq.size(); i++) {
188             arg_diff = seq[i].diff(s);
189             // We apply the chain rule only when it makes sense.  This is not
190             // just for performance reasons but also to allow functions to
191             // throw when differentiated with respect to one of its arguments
192             // without running into trouble with our automatic full
193             // differentiation:
194             if (!arg_diff.is_zero())
195                 new_seq.push_back(mul(pdiff(i), arg_diff));
196         }
197     }
198     return add(new_seq);
199 }
200
201
202 /** Implementation of ex::diff() for a power-series. It treats the series as a polynomial.
203  *  @see ex::diff */
204 ex series::diff(symbol const & s) const
205 {
206     if (s == var) {
207         epvector new_seq;
208         epvector::const_iterator it = seq.begin(), itend = seq.end();
209         
210         // FIXME: coeff might depend on var
211         while (it != itend) {
212             if (is_order_function(it->rest)) {
213                 new_seq.push_back(expair(it->rest, it->coeff - 1));
214             } else {
215                 ex c = it->rest * it->coeff;
216                 if (!c.is_zero())
217                     new_seq.push_back(expair(c, it->coeff - 1));
218             }
219             it++;
220         }
221         return series(var, point, new_seq);
222     } else {
223         return *this;
224     }
225 }
226
227
228 /** Compute partial derivative of an expression.
229  *
230  *  @param s  symbol by which the expression is derived
231  *  @param nth  order of derivative (default 1)
232  *  @return partial derivative as a new expression */
233
234 ex ex::diff(symbol const & s, unsigned nth) const
235 {
236     GINAC_ASSERT(bp!=0);
237
238     if (nth==0) {
239         return *this;
240     }
241
242     ex ndiff = bp->diff(s);
243     while (nth>1) {
244         ndiff = ndiff.diff(s);
245         --nth;
246     }
247     return ndiff;
248 }
249
250 #ifndef NO_GINAC_NAMESPACE
251 } // namespace GiNaC
252 #endif // ndef NO_GINAC_NAMESPACE