]> www.ginac.de Git - ginac.git/blob - ginac/diff.cpp
- enforced GiNaC coding standards :-)
[ginac.git] / ginac / diff.cpp
1 /** @file diff.cpp
2  *
3  *  Implementation of symbolic differentiation in all of GiNaC's classes.
4  *
5  *  GiNaC Copyright (C) 1999 Johannes Gutenberg University Mainz, Germany
6  *
7  *  This program is free software; you can redistribute it and/or modify
8  *  it under the terms of the GNU General Public License as published by
9  *  the Free Software Foundation; either version 2 of the License, or
10  *  (at your option) any later version.
11  *
12  *  This program is distributed in the hope that it will be useful,
13  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
14  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15  *  GNU General Public License for more details.
16  *
17  *  You should have received a copy of the GNU General Public License
18  *  along with this program; if not, write to the Free Software
19  *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
20  */
21
22 #include <stdexcept>
23
24 #include "ginac.h"
25
26 /** Default implementation of ex::diff(). It prints and error message and returns a fail object.
27  *  @see ex::diff */
28 ex basic::diff(symbol const & s) const
29 {
30     throw(std::logic_error("differentiation not supported by this type"));
31 }
32
33
34 /** Implementation of ex::diff() for a numeric. It always returns 0.
35  *
36  *  @see ex::diff */
37 ex numeric::diff(symbol const & s) const
38 {
39     return exZERO();
40 }
41
42
43 /** Implementation of ex::diff() for single differentiation of a symbol.
44  *  It returns 1 or 0.
45  *
46  *  @see ex::diff */
47 ex symbol::diff(symbol const & s) const
48 {
49     if (compare_same_type(s)) {
50         return exZERO();
51     } else {
52         return exONE();
53     }
54 }
55
56 /** Implementation of ex::diff() for a constant. It always returns 0.
57  *
58  *  @see ex::diff */
59 ex constant::diff(symbol const & s) const
60 {
61     return exZERO();
62 }
63
64 /** Implementation of ex::diff() for multiple differentiation of a symbol.
65  *  It returns the symbol, 1 or 0.
66  *
67  *  @param nth order of differentiation
68  *  @see ex::diff */
69 ex symbol::diff(symbol const & s, unsigned nth) const
70 {
71     if (compare_same_type(s)) {
72         switch (nth) {
73         case 0:
74             return s;
75             break;
76         case 1:
77             return exONE();
78             break;
79         default:
80             return exZERO();
81         }
82     } else {
83         return exONE();
84     }
85 }
86
87
88 /** Implementation of ex::diff() for an indexed object. It always returns 0.
89  *  @see ex::diff */
90 ex indexed::diff(symbol const & s) const
91 {
92         return exZERO();
93 }
94
95
96 /** Implementation of ex::diff() for an expairseq. It differentiates all elements of the sequence.
97  *  @see ex::diff */
98 ex expairseq::diff(symbol const & s) const
99 {
100     return thisexpairseq(diffchildren(s),overall_coeff);
101 }
102
103
104 /** Implementation of ex::diff() for a sum. It differentiates each term.
105  *  @see ex::diff */
106 ex add::diff(symbol const & s) const
107 {
108     // D(a+b+c)=D(a)+D(b)+D(c)
109     return (new add(diffchildren(s)))->setflag(status_flags::dynallocated);
110 }
111
112
113 /** Implementation of ex::diff() for a product. It applies the product rule.
114  *  @see ex::diff */
115 ex mul::diff(symbol const & s) const
116 {
117     exvector new_seq;
118     new_seq.reserve(seq.size());
119
120     // D(a*b*c)=D(a)*b*c+a*D(b)*c+a*b*D(c)
121     for (unsigned i=0; i!=seq.size(); i++) {
122         epvector sub_seq=seq;
123         sub_seq[i] = split_ex_to_pair(sub_seq[i].coeff*
124                                       power(sub_seq[i].rest,sub_seq[i].coeff-1)*
125                                       sub_seq[i].rest.diff(s));
126         new_seq.push_back((new mul(sub_seq,overall_coeff))->setflag(status_flags::dynallocated));
127     }
128     return (new add(new_seq))->setflag(status_flags::dynallocated);
129 }
130
131
132 /** Implementation of ex::diff() for a non-commutative product. It always returns 0.
133  *  @see ex::diff */
134 ex ncmul::diff(symbol const & s) const
135 {
136     return exZERO();
137 }
138
139
140 /** Implementation of ex::diff() for a power.
141  *  @see ex::diff */
142 ex power::diff(symbol const & s) const
143 {
144     if (exponent.info(info_flags::real)) {
145         // D(b^r) = r * b^(r-1) * D(b) (faster than the formula below)
146         return mul(mul(exponent, power(basis, exponent - exONE())), basis.diff(s));
147     } else {
148         // D(b^e) = b^e * (D(e)*ln(b) + e*D(b)/b)
149         return mul(power(basis, exponent),
150                    add(mul(exponent.diff(s), log(basis)),
151                        mul(mul(exponent, basis.diff(s)), power(basis, -1))));
152     }
153 }
154
155
156 /** Implementation of ex::diff() for functions. It applies the chain rule,
157  *  except for the Order term function.
158  *  @see ex::diff */
159 ex function::diff(symbol const & s) const
160 {
161     exvector new_seq;
162
163     if (serial == function_index_Order) {
164
165         // Order Term function only differentiates the argument
166         return Order(seq[0].diff(s));
167
168     } else {
169
170         // Chain rule
171         for (unsigned i=0; i!=seq.size(); i++) {
172             new_seq.push_back(mul(pdiff(i), seq[i].diff(s)));
173         }
174     }
175     return add(new_seq);
176 }
177
178
179 /** Implementation of ex::diff() for a power-series. It treats the series as a polynomial.
180  *  @see ex::diff */
181 ex series::diff(symbol const & s) const
182 {
183     if (s == var) {
184         epvector new_seq;
185         epvector::const_iterator it = seq.begin(), itend = seq.end();
186         
187         //!! coeff might depend on var
188         while (it != itend) {
189             if (is_order_function(it->rest)) {
190                 new_seq.push_back(expair(it->rest, it->coeff - 1));
191             } else {
192                 ex c = it->rest * it->coeff;
193                 if (!c.is_zero())
194                     new_seq.push_back(expair(c, it->coeff - 1));
195             }
196             it++;
197         }
198         return series(var, point, new_seq);
199     } else {
200         return *this;
201     }
202 }
203
204
205 /** Compute partial derivative of an expression.
206  *
207  *  @param s  symbol by which the expression is derived
208  *  @param nth  order of derivative (default 1)
209  *  @return partial derivative as a new expression */
210
211 ex ex::diff(symbol const & s, unsigned nth) const
212 {
213     ASSERT(bp!=0);
214
215     if ( nth==0 ) {
216         return *this;
217     }
218
219     ex ndiff = bp->diff(s);
220     while ( nth>1 ) {
221         ndiff = ndiff.diff(s);
222         --nth;
223     }
224     return ndiff;
225 }