- Same shit as Christian did yesterday in ginac/.
[ginac.git] / check / exam_differentiation.cpp
1 /** @file exam_differentiation.cpp
2  *
3  *  Tests for symbolic differentiation, including various functions. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2000 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
21  */
22
23 #include "exams.h"
24
25 static unsigned check_diff(const ex &e, const symbol &x,
26                                                    const ex &d, unsigned nth=1)
27 {
28         ex ed = e.diff(x, nth);
29         if ((ed - d).compare(ex(0)) != 0) {
30                 switch (nth) {
31                 case 0:
32                         clog << "zeroth ";
33                         break;
34                 case 1:
35                         break;
36                 case 2:
37                         clog << "second ";
38                         break;
39                 case 3:
40                         clog << "third ";
41                         break;
42                 default:
43                         clog << nth << "th ";
44                 }
45                 clog << "derivative of " << e << " by " << x << " returned "
46                          << ed << " instead of " << d << endl;
47                 clog << "returned:" << endl;
48                 ed.printtree(clog);
49                 clog << endl << "instead of" << endl;
50                 d.printtree(clog);
51
52                 return 1;
53         }
54         return 0;
55 }
56
57 // Simple (expanded) polynomials
58 static unsigned exam_differentiation1(void)
59 {
60         unsigned result = 0;
61         symbol x("x"), y("y");
62         ex e1, e2, e, d;
63         
64         // construct bivariate polynomial e to be diff'ed:
65         e1 = pow(x, -2) * 3 + pow(x, -1) * 5 + 7 + x * 11 + pow(x, 2) * 13;
66         e2 = pow(y, -2) * 5 + pow(y, -1) * 7 + 11 + y * 13 + pow(y, 2) * 17;
67         e = (e1 * e2).expand();
68         
69         // d e / dx:
70         d = ex("121-55/x^2-66/x^3-30/x^3/y^2-42/x^3/y-78/x^3*y-102/x^3*y^2-25/x^2/y^2-35/x^2/y-65/x^2*y-85/x^2*y^2+77/y+143*y+187*y^2+130*x/y^2+182/y*x+338*x*y+442*x*y^2+55/y^2+286*x",lst(x,y));
71         result += check_diff(e, x, d);
72         
73         // d e / dy:
74         d = ex("91-30/x^2/y^3-21/x^2/y^2+39/x^2+102/x^2*y-50/x/y^3-35/x/y^2+65/x+170/x*y-77*x/y^2+143*x+374*x*y-130/y^3*x^2-91/y^2*x^2+169*x^2+442*x^2*y-110/y^3*x-70/y^3+238*y-49/y^2",lst(x,y));
75         result += check_diff(e, y, d);
76         
77         // d^2 e / dx^2:
78         d = ex("286+90/x^4/y^2+126/x^4/y+234/x^4*y+306/x^4*y^2+50/x^3/y^2+70/x^3/y+130/x^3*y+170/x^3*y^2+130/y^2+182/y+338*y+442*y^2+198/x^4+110/x^3",lst(x,y));
79         result += check_diff(e, x, d, 2);
80         
81         // d^2 e / dy^2:
82         d = ex("238+90/x^2/y^4+42/x^2/y^3+102/x^2+150/x/y^4+70/x/y^3+170/x+330*x/y^4+154*x/y^3+374*x+390*x^2/y^4+182*x^2/y^3+442*x^2+210/y^4+98/y^3",lst(x,y));
83         result += check_diff(e, y, d, 2);
84         
85         return result;
86 }
87
88 // Trigonometric functions
89 static unsigned exam_differentiation2(void)
90 {
91         unsigned result = 0;
92         symbol x("x"), y("y"), a("a"), b("b");
93         ex e1, e2, e, d;
94         
95         // construct expression e to be diff'ed:
96         e1 = y*pow(x, 2) + a*x + b;
97         e2 = sin(e1);
98         e = b*pow(e2, 2) + y*e2 + a;
99         
100         d = 2*b*e2*cos(e1)*(2*x*y + a) + y*cos(e1)*(2*x*y + a);
101         result += check_diff(e, x, d);
102         
103         d = 2*b*pow(cos(e1),2)*pow(2*x*y + a, 2) + 4*b*y*e2*cos(e1)
104                 - 2*b*pow(e2,2)*pow(2*x*y + a, 2) - y*e2*pow(2*x*y + a, 2)
105                 + 2*pow(y,2)*cos(e1);
106         result += check_diff(e, x, d, 2);
107         
108         d = 2*b*e2*cos(e1)*pow(x, 2) + e2 + y*cos(e1)*pow(x, 2);
109         result += check_diff(e, y, d);
110
111         d = 2*b*pow(cos(e1),2)*pow(x,4) - 2*b*pow(e2,2)*pow(x,4)
112                 + 2*cos(e1)*pow(x,2) - y*e2*pow(x,4);
113         result += check_diff(e, y, d, 2);
114         
115         // construct expression e to be diff'ed:
116         e2 = cos(e1);
117         e = b*pow(e2, 2) + y*e2 + a;
118         
119         d = -2*b*e2*sin(e1)*(2*x*y + a) - y*sin(e1)*(2*x*y + a);
120         result += check_diff(e, x, d);
121         
122         d = 2*b*pow(sin(e1),2)*pow(2*y*x + a,2) - 4*b*e2*sin(e1)*y 
123                 - 2*b*pow(e2,2)*pow(2*y*x + a,2) - y*e2*pow(2*y*x + a,2)
124                 - 2*pow(y,2)*sin(e1);
125         result += check_diff(e, x, d, 2);
126         
127         d = -2*b*e2*sin(e1)*pow(x,2) + e2 - y*sin(e1)*pow(x, 2);
128         result += check_diff(e, y, d);
129         
130         d = -2*b*pow(e2,2)*pow(x,4) + 2*b*pow(sin(e1),2)*pow(x,4)
131                 - 2*sin(e1)*pow(x,2) - y*e2*pow(x,4);
132         result += check_diff(e, y, d, 2);
133
134         return result;
135 }
136         
137 // exp function
138 static unsigned exam_differentiation3(void)
139 {
140         unsigned result = 0;
141         symbol x("x"), y("y"), a("a"), b("b");
142         ex e1, e2, e, d;
143
144         // construct expression e to be diff'ed:
145         e1 = y*pow(x, 2) + a*x + b;
146         e2 = exp(e1);
147         e = b*pow(e2, 2) + y*e2 + a;
148         
149         d = 2*b*pow(e2, 2)*(2*x*y + a) + y*e2*(2*x*y + a);
150         result += check_diff(e, x, d);
151         
152         d = 4*b*pow(e2,2)*pow(2*y*x + a,2) + 4*b*pow(e2,2)*y
153                 + 2*pow(y,2)*e2 + y*e2*pow(2*y*x + a,2);
154         result += check_diff(e, x, d, 2);
155         
156         d = 2*b*pow(e2,2)*pow(x,2) + e2 + y*e2*pow(x,2);
157         result += check_diff(e, y, d);
158         
159         d = 4*b*pow(e2,2)*pow(x,4) + 2*e2*pow(x,2) + y*e2*pow(x,4);
160         result += check_diff(e, y, d, 2);
161
162         return result;
163 }
164
165 // log functions
166 static unsigned exam_differentiation4(void)
167 {
168         unsigned result = 0;
169         symbol x("x"), y("y"), a("a"), b("b");
170         ex e1, e2, e, d;
171         
172         // construct expression e to be diff'ed:
173         e1 = y*pow(x, 2) + a*x + b;
174         e2 = log(e1);
175         e = b*pow(e2, 2) + y*e2 + a;
176         
177         d = 2*b*e2*(2*x*y + a)/e1 + y*(2*x*y + a)/e1;
178         result += check_diff(e, x, d);
179         
180         d = 2*b*pow((2*x*y + a),2)*pow(e1,-2) + 4*b*y*e2/e1
181                 - 2*b*e2*pow(2*x*y + a,2)*pow(e1,-2) + 2*pow(y,2)/e1
182                 - y*pow(2*x*y + a,2)*pow(e1,-2);
183         result += check_diff(e, x, d, 2);
184         
185         d = 2*b*e2*pow(x,2)/e1 + e2 + y*pow(x,2)/e1;
186         result += check_diff(e, y, d);
187         
188         d = 2*b*pow(x,4)*pow(e1,-2) - 2*b*e2*pow(e1,-2)*pow(x,4)
189                 + 2*pow(x,2)/e1 - y*pow(x,4)*pow(e1,-2);
190         result += check_diff(e, y, d, 2);
191
192         return result;
193 }
194
195 // Functions with two variables
196 static unsigned exam_differentiation5(void)
197 {
198         unsigned result = 0;
199         symbol x("x"), y("y"), a("a"), b("b");
200         ex e1, e2, e, d;
201         
202         // test atan2
203         e1 = y*pow(x, 2) + a*x + b;
204         e2 = x*pow(y, 2) + b*y + a;
205         e = atan2(e1,e2);
206         /*
207         d = pow(y,2)*(-b-y*pow(x,2)-a*x)/(pow(b+y*pow(x,2)+a*x,2)+pow(x*pow(y,2)+b*y+a,2))
208                 +(2*y*x+a)/((x*pow(y,2)+b*y+a)*(1+pow(b*y*pow(x,2)+a*x,2)/pow(x*pow(y,2)+b*y+a,2)));
209         */
210         /*
211         d = ((a+2*y*x)*pow(y*b+pow(y,2)*x+a,-1)-(a*x+b+y*pow(x,2))*
212                  pow(y*b+pow(y,2)*x+a,-2)*pow(y,2))*
213                 pow(1+pow(a*x+b+y*pow(x,2),2)*pow(y*b+pow(y,2)*x+a,-2),-1);
214         */
215         /*
216         d = pow(1+pow(a*x+b+y*pow(x,2),2)*pow(y*b+pow(y,2)*x+a,-2),-1)
217                 *pow(y*b+pow(y,2)*x+a,-1)*(a+2*y*x)
218                 +pow(y,2)*(-a*x-b-y*pow(x,2))*
219                 pow(pow(y*b+pow(y,2)*x+a,2)+pow(a*x+b+y*pow(x,2),2),-1);
220         */
221         d = pow(y,2)*pow(pow(b+y*pow(x,2)+x*a,2)+pow(y*b+pow(y,2)*x+a,2),-1)*
222                 (-b-y*pow(x,2)-x*a)+
223                 pow(pow(b+y*pow(x,2)+x*a,2)+pow(y*b+pow(y,2)*x+a,2),-1)*
224                 (y*b+pow(y,2)*x+a)*(2*y*x+a);
225         result += check_diff(e, x, d);
226         
227         return result;
228 }
229
230 // Series
231 static unsigned exam_differentiation6(void)
232 {
233         symbol x("x");
234         ex e, d, ed;
235         
236         e = sin(x).series(x==0, 8);
237         d = cos(x).series(x==0, 7);
238         ed = e.diff(x);
239         ed = series_to_poly(ed);
240         d = series_to_poly(d);
241         
242         if ((ed - d).compare(ex(0)) != 0) {
243                 clog << "derivative of " << e << " by " << x << " returned "
244                          << ed << " instead of " << d << ")" << endl;
245                 return 1;
246         }
247         return 0;
248 }
249
250 // Hashing can help a lot, if differentiation is done cleverly
251 static unsigned exam_differentiation7(void)
252 {
253         symbol x("x");
254         ex P = x + pow(x,3);
255         ex e = (P.diff(x) / P).diff(x, 2);
256         ex d = 6/P - 18*x/pow(P,2) - 54*pow(x,3)/pow(P,2) + 2/pow(P,3)
257                 +18*pow(x,2)/pow(P,3) + 54*pow(x,4)/pow(P,3) + 54*pow(x,6)/pow(P,3);
258         
259         if (!(e-d).expand().is_zero()) {
260                 clog << "expanded second derivative of " << (P.diff(x) / P) << " by " << x
261                          << " returned " << e.expand() << " instead of " << d << endl;
262                 return 1;
263         }
264         if (e.nops() > 3) {
265                 clog << "second derivative of " << (P.diff(x) / P) << " by " << x
266                          << " has " << e.nops() << " operands.  "
267                          << "The result is still correct but not optimal: 3 are enough!  "
268                          << "(Hint: maybe the product rule for objects of class mul should be more careful about assembling the result?)" << endl;
269                 return 1;
270         }
271         return 0;
272 }
273
274 unsigned exam_differentiation(void)
275 {
276         unsigned result = 0;
277         
278         cout << "examining symbolic differentiation" << flush;
279         clog << "----------symbolic differentiation:" << endl;
280         
281         result += exam_differentiation1();  cout << '.' << flush;
282         result += exam_differentiation2();  cout << '.' << flush;
283         result += exam_differentiation3();  cout << '.' << flush;
284         result += exam_differentiation4();  cout << '.' << flush;
285         result += exam_differentiation5();  cout << '.' << flush;
286         result += exam_differentiation6();  cout << '.' << flush;
287         result += exam_differentiation7();  cout << '.' << flush;
288         
289         if (!result) {
290                 cout << " passed " << endl;
291                 clog << "(no output)" << endl;
292         } else {
293                 cout << " failed " << endl;
294         }
295         return result;
296 }