synced to 1.2 (typos, better ncmul::degree() and ::coeff())
[ginac.git] / check / exam_differentiation.cpp
1 /** @file exam_differentiation.cpp
2  *
3  *  Tests for symbolic differentiation, including various functions. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2004 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
21  */
22
23 #include "exams.h"
24
25 static unsigned check_diff(const ex &e, const symbol &x,
26                                                    const ex &d, unsigned nth=1)
27 {
28         ex ed = e.diff(x, nth);
29         if (!(ed - d).is_zero()) {
30                 switch (nth) {
31                 case 0:
32                         clog << "zeroth ";
33                         break;
34                 case 1:
35                         break;
36                 case 2:
37                         clog << "second ";
38                         break;
39                 case 3:
40                         clog << "third ";
41                         break;
42                 default:
43                         clog << nth << "th ";
44                 }
45                 clog << "derivative of " << e << " by " << x << " returned "
46                      << ed << " instead of " << d << endl;
47                 clog << "returned:" << endl;
48                 clog << tree << ed << "instead of\n" << d << dflt;
49
50                 return 1;
51         }
52         return 0;
53 }
54
55 // Simple (expanded) polynomials
56 static unsigned exam_differentiation1()
57 {
58         unsigned result = 0;
59         symbol x("x"), y("y");
60         ex e1, e2, e, d;
61         
62         // construct bivariate polynomial e to be diff'ed:
63         e1 = pow(x, -2) * 3 + pow(x, -1) * 5 + 7 + x * 11 + pow(x, 2) * 13;
64         e2 = pow(y, -2) * 5 + pow(y, -1) * 7 + 11 + y * 13 + pow(y, 2) * 17;
65         e = (e1 * e2).expand();
66         
67         // d e / dx:
68         d = ex("121-55/x^2-66/x^3-30/x^3/y^2-42/x^3/y-78/x^3*y-102/x^3*y^2-25/x^2/y^2-35/x^2/y-65/x^2*y-85/x^2*y^2+77/y+143*y+187*y^2+130*x/y^2+182/y*x+338*x*y+442*x*y^2+55/y^2+286*x",lst(x,y));
69         result += check_diff(e, x, d);
70         
71         // d e / dy:
72         d = ex("91-30/x^2/y^3-21/x^2/y^2+39/x^2+102/x^2*y-50/x/y^3-35/x/y^2+65/x+170/x*y-77*x/y^2+143*x+374*x*y-130/y^3*x^2-91/y^2*x^2+169*x^2+442*x^2*y-110/y^3*x-70/y^3+238*y-49/y^2",lst(x,y));
73         result += check_diff(e, y, d);
74         
75         // d^2 e / dx^2:
76         d = ex("286+90/x^4/y^2+126/x^4/y+234/x^4*y+306/x^4*y^2+50/x^3/y^2+70/x^3/y+130/x^3*y+170/x^3*y^2+130/y^2+182/y+338*y+442*y^2+198/x^4+110/x^3",lst(x,y));
77         result += check_diff(e, x, d, 2);
78         
79         // d^2 e / dy^2:
80         d = ex("238+90/x^2/y^4+42/x^2/y^3+102/x^2+150/x/y^4+70/x/y^3+170/x+330*x/y^4+154*x/y^3+374*x+390*x^2/y^4+182*x^2/y^3+442*x^2+210/y^4+98/y^3",lst(x,y));
81         result += check_diff(e, y, d, 2);
82         
83         return result;
84 }
85
86 // Trigonometric functions
87 static unsigned exam_differentiation2()
88 {
89         unsigned result = 0;
90         symbol x("x"), y("y"), a("a"), b("b");
91         ex e1, e2, e, d;
92         
93         // construct expression e to be diff'ed:
94         e1 = y*pow(x, 2) + a*x + b;
95         e2 = sin(e1);
96         e = b*pow(e2, 2) + y*e2 + a;
97         
98         d = 2*b*e2*cos(e1)*(2*x*y + a) + y*cos(e1)*(2*x*y + a);
99         result += check_diff(e, x, d);
100         
101         d = 2*b*pow(cos(e1),2)*pow(2*x*y + a, 2) + 4*b*y*e2*cos(e1)
102             - 2*b*pow(e2,2)*pow(2*x*y + a, 2) - y*e2*pow(2*x*y + a, 2)
103             + 2*pow(y,2)*cos(e1);
104         result += check_diff(e, x, d, 2);
105         
106         d = 2*b*e2*cos(e1)*pow(x, 2) + e2 + y*cos(e1)*pow(x, 2);
107         result += check_diff(e, y, d);
108
109         d = 2*b*pow(cos(e1),2)*pow(x,4) - 2*b*pow(e2,2)*pow(x,4)
110             + 2*cos(e1)*pow(x,2) - y*e2*pow(x,4);
111         result += check_diff(e, y, d, 2);
112         
113         // construct expression e to be diff'ed:
114         e2 = cos(e1);
115         e = b*pow(e2, 2) + y*e2 + a;
116         
117         d = -2*b*e2*sin(e1)*(2*x*y + a) - y*sin(e1)*(2*x*y + a);
118         result += check_diff(e, x, d);
119         
120         d = 2*b*pow(sin(e1),2)*pow(2*y*x + a,2) - 4*b*e2*sin(e1)*y 
121             - 2*b*pow(e2,2)*pow(2*y*x + a,2) - y*e2*pow(2*y*x + a,2)
122             - 2*pow(y,2)*sin(e1);
123         result += check_diff(e, x, d, 2);
124         
125         d = -2*b*e2*sin(e1)*pow(x,2) + e2 - y*sin(e1)*pow(x, 2);
126         result += check_diff(e, y, d);
127         
128         d = -2*b*pow(e2,2)*pow(x,4) + 2*b*pow(sin(e1),2)*pow(x,4)
129             - 2*sin(e1)*pow(x,2) - y*e2*pow(x,4);
130         result += check_diff(e, y, d, 2);
131
132         return result;
133 }
134         
135 // exp function
136 static unsigned exam_differentiation3()
137 {
138         unsigned result = 0;
139         symbol x("x"), y("y"), a("a"), b("b");
140         ex e1, e2, e, d;
141
142         // construct expression e to be diff'ed:
143         e1 = y*pow(x, 2) + a*x + b;
144         e2 = exp(e1);
145         e = b*pow(e2, 2) + y*e2 + a;
146         
147         d = 2*b*pow(e2, 2)*(2*x*y + a) + y*e2*(2*x*y + a);
148         result += check_diff(e, x, d);
149         
150         d = 4*b*pow(e2,2)*pow(2*y*x + a,2) + 4*b*pow(e2,2)*y
151             + 2*pow(y,2)*e2 + y*e2*pow(2*y*x + a,2);
152         result += check_diff(e, x, d, 2);
153         
154         d = 2*b*pow(e2,2)*pow(x,2) + e2 + y*e2*pow(x,2);
155         result += check_diff(e, y, d);
156         
157         d = 4*b*pow(e2,2)*pow(x,4) + 2*e2*pow(x,2) + y*e2*pow(x,4);
158         result += check_diff(e, y, d, 2);
159
160         return result;
161 }
162
163 // log functions
164 static unsigned exam_differentiation4()
165 {
166         unsigned result = 0;
167         symbol x("x"), y("y"), a("a"), b("b");
168         ex e1, e2, e, d;
169         
170         // construct expression e to be diff'ed:
171         e1 = y*pow(x, 2) + a*x + b;
172         e2 = log(e1);
173         e = b*pow(e2, 2) + y*e2 + a;
174         
175         d = 2*b*e2*(2*x*y + a)/e1 + y*(2*x*y + a)/e1;
176         result += check_diff(e, x, d);
177         
178         d = 2*b*pow((2*x*y + a),2)*pow(e1,-2) + 4*b*y*e2/e1
179             - 2*b*e2*pow(2*x*y + a,2)*pow(e1,-2) + 2*pow(y,2)/e1
180             - y*pow(2*x*y + a,2)*pow(e1,-2);
181         result += check_diff(e, x, d, 2);
182         
183         d = 2*b*e2*pow(x,2)/e1 + e2 + y*pow(x,2)/e1;
184         result += check_diff(e, y, d);
185         
186         d = 2*b*pow(x,4)*pow(e1,-2) - 2*b*e2*pow(e1,-2)*pow(x,4)
187             + 2*pow(x,2)/e1 - y*pow(x,4)*pow(e1,-2);
188         result += check_diff(e, y, d, 2);
189
190         return result;
191 }
192
193 // Functions with two variables
194 static unsigned exam_differentiation5()
195 {
196         unsigned result = 0;
197         symbol x("x"), y("y"), a("a"), b("b");
198         ex e1, e2, e, d;
199         
200         // test atan2
201         e1 = y*pow(x, 2) + a*x + b;
202         e2 = x*pow(y, 2) + b*y + a;
203         e = atan2(e1,e2);
204         
205         d = pow(y,2)*pow(pow(b+y*pow(x,2)+x*a,2)+pow(y*b+pow(y,2)*x+a,2),-1)*
206             (-b-y*pow(x,2)-x*a)
207            +pow(pow(b+y*pow(x,2)+x*a,2)+pow(y*b+pow(y,2)*x+a,2),-1)*
208             (y*b+pow(y,2)*x+a)*(2*y*x+a);
209         result += check_diff(e, x, d);
210         
211         return result;
212 }
213
214 // Series
215 static unsigned exam_differentiation6()
216 {
217         symbol x("x");
218         ex e, d, ed;
219         
220         e = sin(x).series(x==0, 8);
221         d = cos(x).series(x==0, 7);
222         ed = e.diff(x);
223         ed = series_to_poly(ed);
224         d = series_to_poly(d);
225         
226         if (!(ed - d).is_zero()) {
227                 clog << "derivative of " << e << " by " << x << " returned "
228                      << ed << " instead of " << d << ")" << endl;
229                 return 1;
230         }
231         return 0;
232 }
233
234 // Hashing can help a lot, if differentiation is done cleverly
235 static unsigned exam_differentiation7()
236 {
237         symbol x("x");
238         ex P = x + pow(x,3);
239         ex e = (P.diff(x) / P).diff(x, 2);
240         ex d = 6/P - 18*x/pow(P,2) - 54*pow(x,3)/pow(P,2) + 2/pow(P,3)
241             +18*pow(x,2)/pow(P,3) + 54*pow(x,4)/pow(P,3) + 54*pow(x,6)/pow(P,3);
242         
243         if (!(e-d).expand().is_zero()) {
244                 clog << "expanded second derivative of " << (P.diff(x) / P) << " by " << x
245                      << " returned " << e.expand() << " instead of " << d << endl;
246                 return 1;
247         }
248         if (e.nops() > 3) {
249                 clog << "second derivative of " << (P.diff(x) / P) << " by " << x
250                      << " has " << e.nops() << " operands.  "
251                      << "The result is still correct but not optimal: 3 are enough!  "
252                      << "(Hint: maybe the product rule for objects of class mul should be more careful about assembling the result?)" << endl;
253                 return 1;
254         }
255         return 0;
256 }
257
258 unsigned exam_differentiation()
259 {
260         unsigned result = 0;
261         
262         cout << "examining symbolic differentiation" << flush;
263         clog << "----------symbolic differentiation:" << endl;
264         
265         result += exam_differentiation1();  cout << '.' << flush;
266         result += exam_differentiation2();  cout << '.' << flush;
267         result += exam_differentiation3();  cout << '.' << flush;
268         result += exam_differentiation4();  cout << '.' << flush;
269         result += exam_differentiation5();  cout << '.' << flush;
270         result += exam_differentiation6();  cout << '.' << flush;
271         result += exam_differentiation7();  cout << '.' << flush;
272         
273         if (!result) {
274                 cout << " passed " << endl;
275                 clog << "(no output)" << endl;
276         } else {
277                 cout << " failed " << endl;
278         }
279         return result;
280 }