]> www.ginac.de Git - ginac.git/blob - ginac/diff.cpp
- switched to automake build environment
[ginac.git] / ginac / diff.cpp
1 /** @file diff.cpp
2  *
3  *  Implementation of symbolic differentiation in all of GiNaC's classes. */
4
5 #include <stdexcept>
6
7 #include "ginac.h"
8
9 /** Default implementation of ex::diff(). It prints and error message and returns a fail object.
10  *  @see ex::diff */
11 ex basic::diff(symbol const & s) const
12 {
13     throw(std::logic_error("differentiation not supported by this type"));
14 }
15
16
17 /** Implementation of ex::diff() for a numeric. It always returns 0.
18  *
19  *  @see ex::diff */
20 ex numeric::diff(symbol const & s) const
21 {
22     return exZERO();
23 }
24
25
26 /** Implementation of ex::diff() for single differentiation of a symbol.
27  *  It returns 1 or 0.
28  *
29  *  @see ex::diff */
30 ex symbol::diff(symbol const & s) const
31 {
32     if (compare_same_type(s)) {
33         return exZERO();
34     } else {
35         return exONE();
36     }
37 }
38
39 /** Implementation of ex::diff() for a constant. It always returns 0.
40  *
41  *  @see ex::diff */
42 ex constant::diff(symbol const & s) const
43 {
44     return exZERO();
45 }
46
47 /** Implementation of ex::diff() for multiple differentiation of a symbol.
48  *  It returns the symbol, 1 or 0.
49  *
50  *  @param nth order of differentiation
51  *  @see ex::diff */
52 ex symbol::diff(symbol const & s, unsigned nth) const
53 {
54     if (compare_same_type(s)) {
55         switch (nth) {
56         case 0:
57             return s;
58             break;
59         case 1:
60             return exONE();
61             break;
62         default:
63             return exZERO();
64         }
65     } else {
66         return exONE();
67     }
68 }
69
70
71 /** Implementation of ex::diff() for an indexed object. It always returns 0.
72  *  @see ex::diff */
73 ex indexed::diff(symbol const & s) const
74 {
75         return exZERO();
76 }
77
78
79 /** Implementation of ex::diff() for an expairseq. It differentiates all elements of the sequence.
80  *  @see ex::diff */
81 ex expairseq::diff(symbol const & s) const
82 {
83     return thisexpairseq(diffchildren(s),overall_coeff);
84 }
85
86
87 /** Implementation of ex::diff() for a sum. It differentiates each term.
88  *  @see ex::diff */
89 ex add::diff(symbol const & s) const
90 {
91     // D(a+b+c)=D(a)+D(b)+D(c)
92     return (new add(diffchildren(s)))->setflag(status_flags::dynallocated);
93 }
94
95
96 /** Implementation of ex::diff() for a product. It applies the product rule.
97  *  @see ex::diff */
98 ex mul::diff(symbol const & s) const
99 {
100     exvector new_seq;
101     new_seq.reserve(seq.size());
102
103     // D(a*b*c)=D(a)*b*c+a*D(b)*c+a*b*D(c)
104     for (unsigned i=0; i!=seq.size(); i++) {
105         epvector sub_seq=seq;
106         sub_seq[i] = split_ex_to_pair(sub_seq[i].coeff*
107                                       power(sub_seq[i].rest,sub_seq[i].coeff-1)*
108                                       sub_seq[i].rest.diff(s));
109         new_seq.push_back((new mul(sub_seq,overall_coeff))->setflag(status_flags::dynallocated));
110     }
111     return (new add(new_seq))->setflag(status_flags::dynallocated);
112 }
113
114
115 /** Implementation of ex::diff() for a non-commutative product. It always returns 0.
116  *  @see ex::diff */
117 ex ncmul::diff(symbol const & s) const
118 {
119     return exZERO();
120 }
121
122
123 /** Implementation of ex::diff() for a power.
124  *  @see ex::diff */
125 ex power::diff(symbol const & s) const
126 {
127     if (exponent.info(info_flags::real)) {
128         // D(b^r) = r * b^(r-1) * D(b) (faster than the formula below)
129         return mul(mul(exponent, power(basis, exponent - exONE())), basis.diff(s));
130     } else {
131         // D(b^e) = b^e * (D(e)*ln(b) + e*D(b)/b)
132         return mul(power(basis, exponent),
133                    add(mul(exponent.diff(s), log(basis)),
134                        mul(mul(exponent, basis.diff(s)), power(basis, -1))));
135     }
136 }
137
138
139 /** Implementation of ex::diff() for functions. It applies the chain rule,
140  *  except for the Order term function.
141  *  @see ex::diff */
142 ex function::diff(symbol const & s) const
143 {
144     exvector new_seq;
145
146     if (serial == function_index_Order) {
147
148         // Order Term function only differentiates the argument
149         return Order(seq[0].diff(s));
150
151     } else {
152
153         // Chain rule
154         for (unsigned i=0; i!=seq.size(); i++) {
155             new_seq.push_back(mul(pdiff(i), seq[i].diff(s)));
156         }
157     }
158     return add(new_seq);
159 }
160
161
162 /** Implementation of ex::diff() for a power-series. It treats the series as a polynomial.
163  *  @see ex::diff */
164 ex series::diff(symbol const & s) const
165 {
166     if (s == var) {
167         epvector new_seq;
168         epvector::const_iterator it = seq.begin(), itend = seq.end();
169         
170         //!! coeff might depend on var
171         while (it != itend) {
172             if (is_order_function(it->rest)) {
173                 new_seq.push_back(expair(it->rest, it->coeff - 1));
174             } else {
175                 ex c = it->rest * it->coeff;
176                 if (!c.is_zero())
177                     new_seq.push_back(expair(c, it->coeff - 1));
178             }
179             it++;
180         }
181         return series(var, point, new_seq);
182     } else {
183         return *this;
184     }
185 }
186
187
188 /** Compute partial derivative of an expression.
189  *
190  *  @param s  symbol by which the expression is derived
191  *  @param nth  order of derivative (default 1)
192  *  @return partial derivative as a new expression */
193
194 ex ex::diff(symbol const & s, unsigned nth) const
195 {
196     ASSERT(bp!=0);
197
198     if ( nth==0 ) {
199         return *this;
200     }
201
202     ex ndiff = bp->diff(s);
203     while ( nth>1 ) {
204         ndiff = ndiff.diff(s);
205         --nth;
206     }
207     return ndiff;
208 }