]> www.ginac.de Git - ginac.git/blob - ginac/diff.cpp
dec8dc7d590fce307de4fe0136e4bab5ed33b836
[ginac.git] / ginac / diff.cpp
1 /** @file diff.cpp
2  *
3  *  Implementation of symbolic differentiation in all of GiNaC's classes. */
4
5 /*
6  *  GiNaC Copyright (C) 1999 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
21  */
22
23 #include <stdexcept>
24
25 #include "basic.h"
26 #include "ex.h"
27 #include "add.h"
28 #include "constant.h"
29 #include "expairseq.h"
30 #include "indexed.h"
31 #include "inifcns.h"
32 #include "mul.h"
33 #include "ncmul.h"
34 #include "numeric.h"
35 #include "power.h"
36 #include "relational.h"
37 #include "series.h"
38 #include "symbol.h"
39
40 /** Default implementation of ex::diff(). It prints and error message and returns a fail object.
41  *  @see ex::diff */
42 ex basic::diff(symbol const & s) const
43 {
44     throw(std::logic_error("differentiation not supported by this type"));
45 }
46
47
48 /** Implementation of ex::diff() for a numeric. It always returns 0.
49  *
50  *  @see ex::diff */
51 ex numeric::diff(symbol const & s) const
52 {
53     return exZERO();
54 }
55
56
57 /** Implementation of ex::diff() for single differentiation of a symbol.
58  *  It returns 1 or 0.
59  *
60  *  @see ex::diff */
61 ex symbol::diff(symbol const & s) const
62 {
63     if (compare_same_type(s)) {
64         return exZERO();
65     } else {
66         return exONE();
67     }
68 }
69
70 /** Implementation of ex::diff() for a constant. It always returns 0.
71  *
72  *  @see ex::diff */
73 ex constant::diff(symbol const & s) const
74 {
75     return exZERO();
76 }
77
78 /** Implementation of ex::diff() for multiple differentiation of a symbol.
79  *  It returns the symbol, 1 or 0.
80  *
81  *  @param nth order of differentiation
82  *  @see ex::diff */
83 ex symbol::diff(symbol const & s, unsigned nth) const
84 {
85     if (compare_same_type(s)) {
86         switch (nth) {
87         case 0:
88             return s;
89             break;
90         case 1:
91             return exONE();
92             break;
93         default:
94             return exZERO();
95         }
96     } else {
97         return exONE();
98     }
99 }
100
101
102 /** Implementation of ex::diff() for an indexed object. It always returns 0.
103  *  @see ex::diff */
104 ex indexed::diff(symbol const & s) const
105 {
106         return exZERO();
107 }
108
109
110 /** Implementation of ex::diff() for an expairseq. It differentiates all elements of the sequence.
111  *  @see ex::diff */
112 ex expairseq::diff(symbol const & s) const
113 {
114     return thisexpairseq(diffchildren(s),overall_coeff);
115 }
116
117
118 /** Implementation of ex::diff() for a sum. It differentiates each term.
119  *  @see ex::diff */
120 ex add::diff(symbol const & s) const
121 {
122     // D(a+b+c)=D(a)+D(b)+D(c)
123     return (new add(diffchildren(s)))->setflag(status_flags::dynallocated);
124 }
125
126
127 /** Implementation of ex::diff() for a product. It applies the product rule.
128  *  @see ex::diff */
129 ex mul::diff(symbol const & s) const
130 {
131     exvector new_seq;
132     new_seq.reserve(seq.size());
133
134     // D(a*b*c)=D(a)*b*c+a*D(b)*c+a*b*D(c)
135     for (unsigned i=0; i!=seq.size(); i++) {
136         epvector sub_seq=seq;
137         sub_seq[i] = split_ex_to_pair(sub_seq[i].coeff*
138                                       power(sub_seq[i].rest,sub_seq[i].coeff-1)*
139                                       sub_seq[i].rest.diff(s));
140         new_seq.push_back((new mul(sub_seq,overall_coeff))->setflag(status_flags::dynallocated));
141     }
142     return (new add(new_seq))->setflag(status_flags::dynallocated);
143 }
144
145
146 /** Implementation of ex::diff() for a non-commutative product. It always returns 0.
147  *  @see ex::diff */
148 ex ncmul::diff(symbol const & s) const
149 {
150     return exZERO();
151 }
152
153
154 /** Implementation of ex::diff() for a power.
155  *  @see ex::diff */
156 ex power::diff(symbol const & s) const
157 {
158     if (exponent.info(info_flags::real)) {
159         // D(b^r) = r * b^(r-1) * D(b) (faster than the formula below)
160         return mul(mul(exponent, power(basis, exponent - exONE())), basis.diff(s));
161     } else {
162         // D(b^e) = b^e * (D(e)*ln(b) + e*D(b)/b)
163         return mul(power(basis, exponent),
164                    add(mul(exponent.diff(s), log(basis)),
165                        mul(mul(exponent, basis.diff(s)), power(basis, -1))));
166     }
167 }
168
169
170 /** Implementation of ex::diff() for functions. It applies the chain rule,
171  *  except for the Order term function.
172  *  @see ex::diff */
173 ex function::diff(symbol const & s) const
174 {
175     exvector new_seq;
176
177     if (serial == function_index_Order) {
178
179         // Order Term function only differentiates the argument
180         return Order(seq[0].diff(s));
181
182     } else {
183
184         // Chain rule
185         for (unsigned i=0; i!=seq.size(); i++) {
186             new_seq.push_back(mul(pdiff(i), seq[i].diff(s)));
187         }
188     }
189     return add(new_seq);
190 }
191
192
193 /** Implementation of ex::diff() for a power-series. It treats the series as a polynomial.
194  *  @see ex::diff */
195 ex series::diff(symbol const & s) const
196 {
197     if (s == var) {
198         epvector new_seq;
199         epvector::const_iterator it = seq.begin(), itend = seq.end();
200         
201         //!! coeff might depend on var
202         while (it != itend) {
203             if (is_order_function(it->rest)) {
204                 new_seq.push_back(expair(it->rest, it->coeff - 1));
205             } else {
206                 ex c = it->rest * it->coeff;
207                 if (!c.is_zero())
208                     new_seq.push_back(expair(c, it->coeff - 1));
209             }
210             it++;
211         }
212         return series(var, point, new_seq);
213     } else {
214         return *this;
215     }
216 }
217
218
219 /** Compute partial derivative of an expression.
220  *
221  *  @param s  symbol by which the expression is derived
222  *  @param nth  order of derivative (default 1)
223  *  @return partial derivative as a new expression */
224
225 ex ex::diff(symbol const & s, unsigned nth) const
226 {
227     ASSERT(bp!=0);
228
229     if ( nth==0 ) {
230         return *this;
231     }
232
233     ex ndiff = bp->diff(s);
234     while ( nth>1 ) {
235         ndiff = ndiff.diff(s);
236         --nth;
237     }
238     return ndiff;
239 }