- changed function::diff() to be more tolerant by checking first if the
[ginac.git] / ginac / diff.cpp
1 /** @file diff.cpp
2  *
3  *  Implementation of symbolic differentiation in all of GiNaC's classes. */
4
5 /*
6  *  GiNaC Copyright (C) 1999 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
21  */
22
23 #include <stdexcept>
24
25 #include "basic.h"
26 #include "ex.h"
27 #include "add.h"
28 #include "constant.h"
29 #include "expairseq.h"
30 #include "indexed.h"
31 #include "inifcns.h"
32 #include "mul.h"
33 #include "ncmul.h"
34 #include "numeric.h"
35 #include "power.h"
36 #include "relational.h"
37 #include "series.h"
38 #include "symbol.h"
39
40 namespace GiNaC {
41
42 /** Default implementation of ex::diff(). It prints and error message and returns a fail object.
43  *  @see ex::diff */
44 ex basic::diff(symbol const & s) const
45 {
46     throw(std::logic_error("differentiation not supported by this type"));
47 }
48
49
50 /** Implementation of ex::diff() for a numeric. It always returns 0.
51  *
52  *  @see ex::diff */
53 ex numeric::diff(symbol const & s) const
54 {
55     return exZERO();
56 }
57
58
59 /** Implementation of ex::diff() for single differentiation of a symbol.
60  *  It returns 1 or 0.
61  *
62  *  @see ex::diff */
63 ex symbol::diff(symbol const & s) const
64 {
65     if (compare_same_type(s)) {
66         return exZERO();
67     } else {
68         return exONE();
69     }
70 }
71
72 /** Implementation of ex::diff() for a constant. It always returns 0.
73  *
74  *  @see ex::diff */
75 ex constant::diff(symbol const & s) const
76 {
77     return exZERO();
78 }
79
80 /** Implementation of ex::diff() for multiple differentiation of a symbol.
81  *  It returns the symbol, 1 or 0.
82  *
83  *  @param nth order of differentiation
84  *  @see ex::diff */
85 ex symbol::diff(symbol const & s, unsigned nth) const
86 {
87     if (compare_same_type(s)) {
88         switch (nth) {
89         case 0:
90             return s;
91             break;
92         case 1:
93             return exONE();
94             break;
95         default:
96             return exZERO();
97         }
98     } else {
99         return exONE();
100     }
101 }
102
103
104 /** Implementation of ex::diff() for an indexed object. It always returns 0.
105  *  @see ex::diff */
106 ex indexed::diff(symbol const & s) const
107 {
108         return exZERO();
109 }
110
111
112 /** Implementation of ex::diff() for an expairseq. It differentiates all elements of the sequence.
113  *  @see ex::diff */
114 ex expairseq::diff(symbol const & s) const
115 {
116     return thisexpairseq(diffchildren(s),overall_coeff);
117 }
118
119
120 /** Implementation of ex::diff() for a sum. It differentiates each term.
121  *  @see ex::diff */
122 ex add::diff(symbol const & s) const
123 {
124     // D(a+b+c)=D(a)+D(b)+D(c)
125     return (new add(diffchildren(s)))->setflag(status_flags::dynallocated);
126 }
127
128
129 /** Implementation of ex::diff() for a product. It applies the product rule.
130  *  @see ex::diff */
131 ex mul::diff(symbol const & s) const
132 {
133     exvector new_seq;
134     new_seq.reserve(seq.size());
135
136     // D(a*b*c)=D(a)*b*c+a*D(b)*c+a*b*D(c)
137     for (unsigned i=0; i!=seq.size(); i++) {
138         epvector sub_seq=seq;
139         sub_seq[i] = split_ex_to_pair(sub_seq[i].coeff*
140                                       power(sub_seq[i].rest,sub_seq[i].coeff-1)*
141                                       sub_seq[i].rest.diff(s));
142         new_seq.push_back((new mul(sub_seq,overall_coeff))->setflag(status_flags::dynallocated));
143     }
144     return (new add(new_seq))->setflag(status_flags::dynallocated);
145 }
146
147
148 /** Implementation of ex::diff() for a non-commutative product. It always returns 0.
149  *  @see ex::diff */
150 ex ncmul::diff(symbol const & s) const
151 {
152     return exZERO();
153 }
154
155
156 /** Implementation of ex::diff() for a power.
157  *  @see ex::diff */
158 ex power::diff(symbol const & s) const
159 {
160     if (exponent.info(info_flags::real)) {
161         // D(b^r) = r * b^(r-1) * D(b) (faster than the formula below)
162         return mul(mul(exponent, power(basis, exponent - exONE())), basis.diff(s));
163     } else {
164         // D(b^e) = b^e * (D(e)*ln(b) + e*D(b)/b)
165         return mul(power(basis, exponent),
166                    add(mul(exponent.diff(s), log(basis)),
167                        mul(mul(exponent, basis.diff(s)), power(basis, -1))));
168     }
169 }
170
171
172 /** Implementation of ex::diff() for functions. It applies the chain rule,
173  *  except for the Order term function.
174  *  @see ex::diff */
175 ex function::diff(symbol const & s) const
176 {
177     exvector new_seq;
178     
179     if (serial==function_index_Order) {
180         // Order Term function only differentiates the argument
181         return Order(seq[0].diff(s));
182     } else {
183         // Chain rule
184         ex arg_diff;
185         for (unsigned i=0; i!=seq.size(); i++) {
186             arg_diff = seq[i].diff(s);
187             // We apply the chain rule only when it makes sense.  This is not
188             // just for performance reasons but also to allow functions to
189             // throw when differentiated with respect to one of its arguments
190             // without running into trouble with our automatic full
191             // differentiation:
192             if (!arg_diff.is_zero())
193                 new_seq.push_back(mul(pdiff(i), arg_diff));
194         }
195     }
196     return add(new_seq);
197 }
198
199
200 /** Implementation of ex::diff() for a power-series. It treats the series as a polynomial.
201  *  @see ex::diff */
202 ex series::diff(symbol const & s) const
203 {
204     if (s == var) {
205         epvector new_seq;
206         epvector::const_iterator it = seq.begin(), itend = seq.end();
207         
208         // FIXME: coeff might depend on var
209         while (it != itend) {
210             if (is_order_function(it->rest)) {
211                 new_seq.push_back(expair(it->rest, it->coeff - 1));
212             } else {
213                 ex c = it->rest * it->coeff;
214                 if (!c.is_zero())
215                     new_seq.push_back(expair(c, it->coeff - 1));
216             }
217             it++;
218         }
219         return series(var, point, new_seq);
220     } else {
221         return *this;
222     }
223 }
224
225
226 /** Compute partial derivative of an expression.
227  *
228  *  @param s  symbol by which the expression is derived
229  *  @param nth  order of derivative (default 1)
230  *  @return partial derivative as a new expression */
231
232 ex ex::diff(symbol const & s, unsigned nth) const
233 {
234     GINAC_ASSERT(bp!=0);
235
236     if ( nth==0 ) {
237         return *this;
238     }
239
240     ex ndiff = bp->diff(s);
241     while ( nth>1 ) {
242         ndiff = ndiff.diff(s);
243         --nth;
244     }
245     return ndiff;
246 }
247
248 } // namespace GiNaC