Update copyright statements.
[ginac.git] / check / exam_differentiation.cpp
1 /** @file exam_differentiation.cpp
2  *
3  *  Tests for symbolic differentiation, including various functions. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2014 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
21  */
22
23 #include "ginac.h"
24 using namespace GiNaC;
25
26 #include <iostream>
27 using namespace std;
28
29 static unsigned check_diff(const ex &e, const symbol &x,
30                                                    const ex &d, unsigned nth=1)
31 {
32         ex ed = e.diff(x, nth);
33         if (!(ed - d).is_zero()) {
34                 switch (nth) {
35                 case 0:
36                         clog << "zeroth ";
37                         break;
38                 case 1:
39                         break;
40                 case 2:
41                         clog << "second ";
42                         break;
43                 case 3:
44                         clog << "third ";
45                         break;
46                 default:
47                         clog << nth << "th ";
48                 }
49                 clog << "derivative of " << e << " by " << x << " returned "
50                      << ed << " instead of " << d << endl;
51                 clog << "returned:" << endl;
52                 clog << tree << ed << "instead of\n" << d << dflt;
53
54                 return 1;
55         }
56         return 0;
57 }
58
59 // Simple (expanded) polynomials
60 static unsigned exam_differentiation1()
61 {
62         unsigned result = 0;
63         symbol x("x"), y("y");
64         ex e1, e2, e, d;
65         
66         // construct bivariate polynomial e to be diff'ed:
67         e1 = pow(x, -2) * 3 + pow(x, -1) * 5 + 7 + x * 11 + pow(x, 2) * 13;
68         e2 = pow(y, -2) * 5 + pow(y, -1) * 7 + 11 + y * 13 + pow(y, 2) * 17;
69         e = (e1 * e2).expand();
70         
71         // d e / dx:
72         d = ex("121-55/x^2-66/x^3-30/x^3/y^2-42/x^3/y-78/x^3*y-102/x^3*y^2-25/x^2/y^2-35/x^2/y-65/x^2*y-85/x^2*y^2+77/y+143*y+187*y^2+130*x/y^2+182/y*x+338*x*y+442*x*y^2+55/y^2+286*x",lst(x,y));
73         result += check_diff(e, x, d);
74         
75         // d e / dy:
76         d = ex("91-30/x^2/y^3-21/x^2/y^2+39/x^2+102/x^2*y-50/x/y^3-35/x/y^2+65/x+170/x*y-77*x/y^2+143*x+374*x*y-130/y^3*x^2-91/y^2*x^2+169*x^2+442*x^2*y-110/y^3*x-70/y^3+238*y-49/y^2",lst(x,y));
77         result += check_diff(e, y, d);
78         
79         // d^2 e / dx^2:
80         d = ex("286+90/x^4/y^2+126/x^4/y+234/x^4*y+306/x^4*y^2+50/x^3/y^2+70/x^3/y+130/x^3*y+170/x^3*y^2+130/y^2+182/y+338*y+442*y^2+198/x^4+110/x^3",lst(x,y));
81         result += check_diff(e, x, d, 2);
82         
83         // d^2 e / dy^2:
84         d = ex("238+90/x^2/y^4+42/x^2/y^3+102/x^2+150/x/y^4+70/x/y^3+170/x+330*x/y^4+154*x/y^3+374*x+390*x^2/y^4+182*x^2/y^3+442*x^2+210/y^4+98/y^3",lst(x,y));
85         result += check_diff(e, y, d, 2);
86         
87         return result;
88 }
89
90 // Trigonometric functions
91 static unsigned exam_differentiation2()
92 {
93         unsigned result = 0;
94         symbol x("x"), y("y"), a("a"), b("b");
95         ex e1, e2, e, d;
96         
97         // construct expression e to be diff'ed:
98         e1 = y*pow(x, 2) + a*x + b;
99         e2 = sin(e1);
100         e = b*pow(e2, 2) + y*e2 + a;
101         
102         d = 2*b*e2*cos(e1)*(2*x*y + a) + y*cos(e1)*(2*x*y + a);
103         result += check_diff(e, x, d);
104         
105         d = 2*b*pow(cos(e1),2)*pow(2*x*y + a, 2) + 4*b*y*e2*cos(e1)
106             - 2*b*pow(e2,2)*pow(2*x*y + a, 2) - y*e2*pow(2*x*y + a, 2)
107             + 2*pow(y,2)*cos(e1);
108         result += check_diff(e, x, d, 2);
109         
110         d = 2*b*e2*cos(e1)*pow(x, 2) + e2 + y*cos(e1)*pow(x, 2);
111         result += check_diff(e, y, d);
112
113         d = 2*b*pow(cos(e1),2)*pow(x,4) - 2*b*pow(e2,2)*pow(x,4)
114             + 2*cos(e1)*pow(x,2) - y*e2*pow(x,4);
115         result += check_diff(e, y, d, 2);
116         
117         // construct expression e to be diff'ed:
118         e2 = cos(e1);
119         e = b*pow(e2, 2) + y*e2 + a;
120         
121         d = -2*b*e2*sin(e1)*(2*x*y + a) - y*sin(e1)*(2*x*y + a);
122         result += check_diff(e, x, d);
123         
124         d = 2*b*pow(sin(e1),2)*pow(2*y*x + a,2) - 4*b*e2*sin(e1)*y 
125             - 2*b*pow(e2,2)*pow(2*y*x + a,2) - y*e2*pow(2*y*x + a,2)
126             - 2*pow(y,2)*sin(e1);
127         result += check_diff(e, x, d, 2);
128         
129         d = -2*b*e2*sin(e1)*pow(x,2) + e2 - y*sin(e1)*pow(x, 2);
130         result += check_diff(e, y, d);
131         
132         d = -2*b*pow(e2,2)*pow(x,4) + 2*b*pow(sin(e1),2)*pow(x,4)
133             - 2*sin(e1)*pow(x,2) - y*e2*pow(x,4);
134         result += check_diff(e, y, d, 2);
135
136         return result;
137 }
138         
139 // exp function
140 static unsigned exam_differentiation3()
141 {
142         unsigned result = 0;
143         symbol x("x"), y("y"), a("a"), b("b");
144         ex e1, e2, e, d;
145
146         // construct expression e to be diff'ed:
147         e1 = y*pow(x, 2) + a*x + b;
148         e2 = exp(e1);
149         e = b*pow(e2, 2) + y*e2 + a;
150         
151         d = 2*b*pow(e2, 2)*(2*x*y + a) + y*e2*(2*x*y + a);
152         result += check_diff(e, x, d);
153         
154         d = 4*b*pow(e2,2)*pow(2*y*x + a,2) + 4*b*pow(e2,2)*y
155             + 2*pow(y,2)*e2 + y*e2*pow(2*y*x + a,2);
156         result += check_diff(e, x, d, 2);
157         
158         d = 2*b*pow(e2,2)*pow(x,2) + e2 + y*e2*pow(x,2);
159         result += check_diff(e, y, d);
160         
161         d = 4*b*pow(e2,2)*pow(x,4) + 2*e2*pow(x,2) + y*e2*pow(x,4);
162         result += check_diff(e, y, d, 2);
163
164         return result;
165 }
166
167 // log functions
168 static unsigned exam_differentiation4()
169 {
170         unsigned result = 0;
171         symbol x("x"), y("y"), a("a"), b("b");
172         ex e1, e2, e, d;
173         
174         // construct expression e to be diff'ed:
175         e1 = y*pow(x, 2) + a*x + b;
176         e2 = log(e1);
177         e = b*pow(e2, 2) + y*e2 + a;
178         
179         d = 2*b*e2*(2*x*y + a)/e1 + y*(2*x*y + a)/e1;
180         result += check_diff(e, x, d);
181         
182         d = 2*b*pow((2*x*y + a),2)*pow(e1,-2) + 4*b*y*e2/e1
183             - 2*b*e2*pow(2*x*y + a,2)*pow(e1,-2) + 2*pow(y,2)/e1
184             - y*pow(2*x*y + a,2)*pow(e1,-2);
185         result += check_diff(e, x, d, 2);
186         
187         d = 2*b*e2*pow(x,2)/e1 + e2 + y*pow(x,2)/e1;
188         result += check_diff(e, y, d);
189         
190         d = 2*b*pow(x,4)*pow(e1,-2) - 2*b*e2*pow(e1,-2)*pow(x,4)
191             + 2*pow(x,2)/e1 - y*pow(x,4)*pow(e1,-2);
192         result += check_diff(e, y, d, 2);
193
194         return result;
195 }
196
197 // Functions with two variables
198 static unsigned exam_differentiation5()
199 {
200         unsigned result = 0;
201         symbol x("x"), y("y"), a("a"), b("b");
202         ex e1, e2, e, d;
203         
204         // test atan2
205         e1 = y*pow(x, 2) + a*x + b;
206         e2 = x*pow(y, 2) + b*y + a;
207         e = atan2(e1,e2);
208         
209         d = pow(y,2)*pow(pow(b+y*pow(x,2)+x*a,2)+pow(y*b+pow(y,2)*x+a,2),-1)*
210             (-b-y*pow(x,2)-x*a)
211            +pow(pow(b+y*pow(x,2)+x*a,2)+pow(y*b+pow(y,2)*x+a,2),-1)*
212             (y*b+pow(y,2)*x+a)*(2*y*x+a);
213         result += check_diff(e, x, d);
214         
215         return result;
216 }
217
218 // Series
219 static unsigned exam_differentiation6()
220 {
221         symbol x("x");
222         ex e, d, ed;
223         
224         e = sin(x).series(x==0, 8);
225         d = cos(x).series(x==0, 7);
226         ed = e.diff(x);
227         ed = series_to_poly(ed);
228         d = series_to_poly(d);
229         
230         if (!(ed - d).is_zero()) {
231                 clog << "derivative of " << e << " by " << x << " returned "
232                      << ed << " instead of " << d << ")" << endl;
233                 return 1;
234         }
235         return 0;
236 }
237
238 // Hashing can help a lot, if differentiation is done cleverly
239 static unsigned exam_differentiation7()
240 {
241         symbol x("x");
242         ex P = x + pow(x,3);
243         ex e = (P.diff(x) / P).diff(x, 2);
244         ex d = 6/P - 18*x/pow(P,2) - 54*pow(x,3)/pow(P,2) + 2/pow(P,3)
245             +18*pow(x,2)/pow(P,3) + 54*pow(x,4)/pow(P,3) + 54*pow(x,6)/pow(P,3);
246         
247         if (!(e-d).expand().is_zero()) {
248                 clog << "expanded second derivative of " << (P.diff(x) / P) << " by " << x
249                      << " returned " << e.expand() << " instead of " << d << endl;
250                 return 1;
251         }
252         if (e.nops() > 3) {
253                 clog << "second derivative of " << (P.diff(x) / P) << " by " << x
254                      << " has " << e.nops() << " operands.  "
255                      << "The result is still correct but not optimal: 3 are enough!  "
256                      << "(Hint: maybe the product rule for objects of class mul should be more careful about assembling the result?)" << endl;
257                 return 1;
258         }
259         return 0;
260 }
261
262 unsigned exam_differentiation()
263 {
264         unsigned result = 0;
265         
266         cout << "examining symbolic differentiation" << flush;
267         
268         result += exam_differentiation1();  cout << '.' << flush;
269         result += exam_differentiation2();  cout << '.' << flush;
270         result += exam_differentiation3();  cout << '.' << flush;
271         result += exam_differentiation4();  cout << '.' << flush;
272         result += exam_differentiation5();  cout << '.' << flush;
273         result += exam_differentiation6();  cout << '.' << flush;
274         result += exam_differentiation7();  cout << '.' << flush;
275         
276         return result;
277 }
278
279 int main(int argc, char** argv)
280 {
281         return exam_differentiation();
282 }