ff7855af4f0854dd2e7868cfeb0e01454f890d16
[ginac.git] / ginac / diff.cpp
1 /** @file diff.cpp
2  *
3  *  Implementation of symbolic differentiation in all of GiNaC's classes.
4  *
5  *  GiNaC Copyright (C) 1999 Johannes Gutenberg University Mainz, Germany
6  *
7  *  This program is free software; you can redistribute it and/or modify
8  *  it under the terms of the GNU General Public License as published by
9  *  the Free Software Foundation; either version 2 of the License, or
10  *  (at your option) any later version.
11  *
12  *  This program is distributed in the hope that it will be useful,
13  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
14  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15  *  GNU General Public License for more details.
16  *
17  *  You should have received a copy of the GNU General Public License
18  *  along with this program; if not, write to the Free Software
19  *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
20  */
21
22 #include <stdexcept>
23
24 #include "basic.h"
25 #include "ex.h"
26 #include "add.h"
27 #include "constant.h"
28 #include "expairseq.h"
29 #include "indexed.h"
30 #include "inifcns.h"
31 #include "mul.h"
32 #include "ncmul.h"
33 #include "numeric.h"
34 #include "power.h"
35 #include "relational.h"
36 #include "series.h"
37 #include "symbol.h"
38
39 /** Default implementation of ex::diff(). It prints and error message and returns a fail object.
40  *  @see ex::diff */
41 ex basic::diff(symbol const & s) const
42 {
43     throw(std::logic_error("differentiation not supported by this type"));
44 }
45
46
47 /** Implementation of ex::diff() for a numeric. It always returns 0.
48  *
49  *  @see ex::diff */
50 ex numeric::diff(symbol const & s) const
51 {
52     return exZERO();
53 }
54
55
56 /** Implementation of ex::diff() for single differentiation of a symbol.
57  *  It returns 1 or 0.
58  *
59  *  @see ex::diff */
60 ex symbol::diff(symbol const & s) const
61 {
62     if (compare_same_type(s)) {
63         return exZERO();
64     } else {
65         return exONE();
66     }
67 }
68
69 /** Implementation of ex::diff() for a constant. It always returns 0.
70  *
71  *  @see ex::diff */
72 ex constant::diff(symbol const & s) const
73 {
74     return exZERO();
75 }
76
77 /** Implementation of ex::diff() for multiple differentiation of a symbol.
78  *  It returns the symbol, 1 or 0.
79  *
80  *  @param nth order of differentiation
81  *  @see ex::diff */
82 ex symbol::diff(symbol const & s, unsigned nth) const
83 {
84     if (compare_same_type(s)) {
85         switch (nth) {
86         case 0:
87             return s;
88             break;
89         case 1:
90             return exONE();
91             break;
92         default:
93             return exZERO();
94         }
95     } else {
96         return exONE();
97     }
98 }
99
100
101 /** Implementation of ex::diff() for an indexed object. It always returns 0.
102  *  @see ex::diff */
103 ex indexed::diff(symbol const & s) const
104 {
105         return exZERO();
106 }
107
108
109 /** Implementation of ex::diff() for an expairseq. It differentiates all elements of the sequence.
110  *  @see ex::diff */
111 ex expairseq::diff(symbol const & s) const
112 {
113     return thisexpairseq(diffchildren(s),overall_coeff);
114 }
115
116
117 /** Implementation of ex::diff() for a sum. It differentiates each term.
118  *  @see ex::diff */
119 ex add::diff(symbol const & s) const
120 {
121     // D(a+b+c)=D(a)+D(b)+D(c)
122     return (new add(diffchildren(s)))->setflag(status_flags::dynallocated);
123 }
124
125
126 /** Implementation of ex::diff() for a product. It applies the product rule.
127  *  @see ex::diff */
128 ex mul::diff(symbol const & s) const
129 {
130     exvector new_seq;
131     new_seq.reserve(seq.size());
132
133     // D(a*b*c)=D(a)*b*c+a*D(b)*c+a*b*D(c)
134     for (unsigned i=0; i!=seq.size(); i++) {
135         epvector sub_seq=seq;
136         sub_seq[i] = split_ex_to_pair(sub_seq[i].coeff*
137                                       power(sub_seq[i].rest,sub_seq[i].coeff-1)*
138                                       sub_seq[i].rest.diff(s));
139         new_seq.push_back((new mul(sub_seq,overall_coeff))->setflag(status_flags::dynallocated));
140     }
141     return (new add(new_seq))->setflag(status_flags::dynallocated);
142 }
143
144
145 /** Implementation of ex::diff() for a non-commutative product. It always returns 0.
146  *  @see ex::diff */
147 ex ncmul::diff(symbol const & s) const
148 {
149     return exZERO();
150 }
151
152
153 /** Implementation of ex::diff() for a power.
154  *  @see ex::diff */
155 ex power::diff(symbol const & s) const
156 {
157     if (exponent.info(info_flags::real)) {
158         // D(b^r) = r * b^(r-1) * D(b) (faster than the formula below)
159         return mul(mul(exponent, power(basis, exponent - exONE())), basis.diff(s));
160     } else {
161         // D(b^e) = b^e * (D(e)*ln(b) + e*D(b)/b)
162         return mul(power(basis, exponent),
163                    add(mul(exponent.diff(s), log(basis)),
164                        mul(mul(exponent, basis.diff(s)), power(basis, -1))));
165     }
166 }
167
168
169 /** Implementation of ex::diff() for functions. It applies the chain rule,
170  *  except for the Order term function.
171  *  @see ex::diff */
172 ex function::diff(symbol const & s) const
173 {
174     exvector new_seq;
175
176     if (serial == function_index_Order) {
177
178         // Order Term function only differentiates the argument
179         return Order(seq[0].diff(s));
180
181     } else {
182
183         // Chain rule
184         for (unsigned i=0; i!=seq.size(); i++) {
185             new_seq.push_back(mul(pdiff(i), seq[i].diff(s)));
186         }
187     }
188     return add(new_seq);
189 }
190
191
192 /** Implementation of ex::diff() for a power-series. It treats the series as a polynomial.
193  *  @see ex::diff */
194 ex series::diff(symbol const & s) const
195 {
196     if (s == var) {
197         epvector new_seq;
198         epvector::const_iterator it = seq.begin(), itend = seq.end();
199         
200         //!! coeff might depend on var
201         while (it != itend) {
202             if (is_order_function(it->rest)) {
203                 new_seq.push_back(expair(it->rest, it->coeff - 1));
204             } else {
205                 ex c = it->rest * it->coeff;
206                 if (!c.is_zero())
207                     new_seq.push_back(expair(c, it->coeff - 1));
208             }
209             it++;
210         }
211         return series(var, point, new_seq);
212     } else {
213         return *this;
214     }
215 }
216
217
218 /** Compute partial derivative of an expression.
219  *
220  *  @param s  symbol by which the expression is derived
221  *  @param nth  order of derivative (default 1)
222  *  @return partial derivative as a new expression */
223
224 ex ex::diff(symbol const & s, unsigned nth) const
225 {
226     ASSERT(bp!=0);
227
228     if ( nth==0 ) {
229         return *this;
230     }
231
232     ex ndiff = bp->diff(s);
233     while ( nth>1 ) {
234         ndiff = ndiff.diff(s);
235         --nth;
236     }
237     return ndiff;
238 }