d5fda588164ab990fa2d17d7222475e97c53b63c
[ginac.git] / check / differentiation.cpp
1 /** @file differentiation.cpp
2  *
3  *  Tests for symbolic differentiation, including various functions. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2000 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
21  */
22
23 #include <ginac/ginac.h>
24
25 #ifndef NO_GINAC_NAMESPACE
26 using namespace GiNaC;
27 #endif // ndef NO_GINAC_NAMESPACE
28
29 static unsigned check_diff(const ex &e, const symbol &x,
30                            const ex &d, unsigned nth=1)
31 {
32     ex ed = e.diff(x, nth);
33     if ((ed - d).compare(ex(0)) != 0) {
34         switch (nth) {
35         case 0:
36             clog << "zeroth ";
37             break;
38         case 1:
39             break;
40         case 2:
41             clog << "second ";
42             break;
43         case 3:
44             clog << "third ";
45             break;
46         default:
47             clog << nth << "th ";
48         }
49         clog << "derivative of " << e << " by " << x << " returned "
50              << ed << " instead of " << d << endl;
51         clog << "returned:" << endl;
52         ed.printtree(clog);
53         clog << endl << "instead of" << endl;
54         d.printtree(clog);
55
56         return 1;
57     }
58     return 0;
59 }
60
61 // Simple (expanded) polynomials
62 static unsigned differentiation1(void)
63 {
64     unsigned result = 0;
65     symbol x("x"), y("y");
66     ex e1, e2, e, d;
67     
68     // construct bivariate polynomial e to be diff'ed:
69     e1 = pow(x, -2) * 3 + pow(x, -1) * 5 + 7 + x * 11 + pow(x, 2) * 13;
70     e2 = pow(y, -2) * 5 + pow(y, -1) * 7 + 11 + y * 13 + pow(y, 2) * 17;
71     e = (e1 * e2).expand();
72     
73     // d e / dx:
74     d = 121 - 55*pow(x,-2) - 66*pow(x,-3) - 30*pow(x,-3)*pow(y,-2)
75         - 42*pow(x,-3)*pow(y,-1) - 78*pow(x,-3)*y
76         - 102*pow(x,-3)*pow(y,2) - 25*pow(x,-2) * pow(y,-2)
77         - 35*pow(x,-2)*pow(y,-1) - 65*pow(x,-2)*y
78         - 85*pow(x,-2)*pow(y,2) + 77*pow(y,-1) + 143*y + 187*pow(y,2)
79         + 130*x*pow(y,-2) + 182*pow(y,-1)*x + 338*x*y + 442*x*pow(y,2)
80         + 55*pow(y,-2) + 286*x;
81     result += check_diff(e, x, d);
82     
83     // d e / dy:
84     d = 91 - 30*pow(x,-2)*pow(y,-3) - 21*pow(x,-2)*pow(y,-2)
85         + 39*pow(x,-2) + 102*pow(x,-2)*y - 50*pow(x,-1)*pow(y,-3)
86         - 35*pow(x,-1)*pow(y,-2) + 65*pow(x,-1) + 170*pow(x,-1)*y
87         - 77*pow(y,-2)*x + 143*x + 374*x*y - 130*pow(y,-3)*pow(x,2)
88         - 91*pow(y,-2)*pow(x,2) + 169*pow(x,2) + 442*pow(x,2)*y
89         - 110*pow(y,-3)*x - 70*pow(y,-3) + 238*y - 49*pow(y,-2);
90     result += check_diff(e, y, d);
91     
92     // d^2 e / dx^2:
93     d = 286 + 90*pow(x,-4)*pow(y,-2) + 126*pow(x,-4)*pow(y,-1)
94         + 234*pow(x,-4)*y + 306*pow(x,-4)*pow(y,2)
95         + 50*pow(x,-3)*pow(y,-2) + 70*pow(x,-3)*pow(y,-1)
96         + 130*pow(x,-3)*y + 170*pow(x,-3)*pow(y,2)
97         + 130*pow(y,-2) + 182*pow(y,-1) + 338*y + 442*pow(y,2)
98         + 198*pow(x,-4) + 110*pow(x,-3);
99     result += check_diff(e, x, d, 2);
100     
101     // d^2 e / dy^2:
102     d = 238 + 90*pow(x,-2)*pow(y,-4) + 42*pow(x,-2)*pow(y,-3)
103         + 102*pow(x,-2) + 150*pow(x,-1)*pow(y,-4)
104         + 70*pow(x,-1)*pow(y,-3) + 170*pow(x,-1) + 330*x*pow(y,-4)
105         + 154*x*pow(y,-3) + 374*x + 390*pow(x,2)*pow(y,-4)
106         + 182*pow(x,2)*pow(y,-3) + 442*pow(x,2) + 210*pow(y,-4)
107         + 98*pow(y,-3);
108     result += check_diff(e, y, d, 2);
109     
110     return result;
111 }
112
113 // Trigonometric functions
114 static unsigned differentiation2(void)
115 {
116     unsigned result = 0;
117     symbol x("x"), y("y"), a("a"), b("b");
118     ex e1, e2, e, d;
119     
120     // construct expression e to be diff'ed:
121     e1 = y*pow(x, 2) + a*x + b;
122     e2 = sin(e1);
123     e = b*pow(e2, 2) + y*e2 + a;
124     
125     d = 2*b*e2*cos(e1)*(2*x*y + a) + y*cos(e1)*(2*x*y + a);
126     result += check_diff(e, x, d);
127     
128     d = 2*b*pow(cos(e1),2)*pow(2*x*y + a, 2) + 4*b*y*e2*cos(e1)
129         - 2*b*pow(e2,2)*pow(2*x*y + a, 2) - y*e2*pow(2*x*y + a, 2)
130         + 2*pow(y,2)*cos(e1);
131     result += check_diff(e, x, d, 2);
132     
133     d = 2*b*e2*cos(e1)*pow(x, 2) + e2 + y*cos(e1)*pow(x, 2);
134     result += check_diff(e, y, d);
135
136     d = 2*b*pow(cos(e1),2)*pow(x,4) - 2*b*pow(e2,2)*pow(x,4)
137         + 2*cos(e1)*pow(x,2) - y*e2*pow(x,4);
138     result += check_diff(e, y, d, 2);
139     
140     // construct expression e to be diff'ed:
141     e2 = cos(e1);
142     e = b*pow(e2, 2) + y*e2 + a;
143     
144     d = -2*b*e2*sin(e1)*(2*x*y + a) - y*sin(e1)*(2*x*y + a);
145     result += check_diff(e, x, d);
146     
147     d = 2*b*pow(sin(e1),2)*pow(2*y*x + a,2) - 4*b*e2*sin(e1)*y 
148         - 2*b*pow(e2,2)*pow(2*y*x + a,2) - y*e2*pow(2*y*x + a,2)
149         - 2*pow(y,2)*sin(e1);
150     result += check_diff(e, x, d, 2);
151     
152     d = -2*b*e2*sin(e1)*pow(x,2) + e2 - y*sin(e1)*pow(x, 2);
153     result += check_diff(e, y, d);
154     
155     d = -2*b*pow(e2,2)*pow(x,4) + 2*b*pow(sin(e1),2)*pow(x,4)
156         - 2*sin(e1)*pow(x,2) - y*e2*pow(x,4);
157     result += check_diff(e, y, d, 2);
158
159         return result;
160 }
161     
162 // exp function
163 static unsigned differentiation3(void)
164 {
165     unsigned result = 0;
166     symbol x("x"), y("y"), a("a"), b("b");
167     ex e1, e2, e, d;
168
169     // construct expression e to be diff'ed:
170     e1 = y*pow(x, 2) + a*x + b;
171     e2 = exp(e1);
172     e = b*pow(e2, 2) + y*e2 + a;
173     
174     d = 2*b*pow(e2, 2)*(2*x*y + a) + y*e2*(2*x*y + a);
175     result += check_diff(e, x, d);
176     
177     d = 4*b*pow(e2,2)*pow(2*y*x + a,2) + 4*b*pow(e2,2)*y
178         + 2*pow(y,2)*e2 + y*e2*pow(2*y*x + a,2);
179     result += check_diff(e, x, d, 2);
180     
181     d = 2*b*pow(e2,2)*pow(x,2) + e2 + y*e2*pow(x,2);
182     result += check_diff(e, y, d);
183     
184     d = 4*b*pow(e2,2)*pow(x,4) + 2*e2*pow(x,2) + y*e2*pow(x,4);
185     result += check_diff(e, y, d, 2);
186
187         return result;
188 }
189
190 // log functions
191 static unsigned differentiation4(void)
192 {
193     unsigned result = 0;
194     symbol x("x"), y("y"), a("a"), b("b");
195     ex e1, e2, e, d;
196     
197     // construct expression e to be diff'ed:
198     e1 = y*pow(x, 2) + a*x + b;
199     e2 = log(e1);
200     e = b*pow(e2, 2) + y*e2 + a;
201     
202     d = 2*b*e2*(2*x*y + a)/e1 + y*(2*x*y + a)/e1;
203     result += check_diff(e, x, d);
204     
205     d = 2*b*pow((2*x*y + a),2)*pow(e1,-2) + 4*b*y*e2/e1
206         - 2*b*e2*pow(2*x*y + a,2)*pow(e1,-2) + 2*pow(y,2)/e1
207         - y*pow(2*x*y + a,2)*pow(e1,-2);
208     result += check_diff(e, x, d, 2);
209     
210     d = 2*b*e2*pow(x,2)/e1 + e2 + y*pow(x,2)/e1;
211     result += check_diff(e, y, d);
212     
213     d = 2*b*pow(x,4)*pow(e1,-2) - 2*b*e2*pow(e1,-2)*pow(x,4)
214         + 2*pow(x,2)/e1 - y*pow(x,4)*pow(e1,-2);
215     result += check_diff(e, y, d, 2);
216
217         return result;
218 }
219
220 // Functions with two variables
221 static unsigned differentiation5(void)
222 {
223     unsigned result = 0;
224     symbol x("x"), y("y"), a("a"), b("b");
225     ex e1, e2, e, d;
226     
227     // test atan2
228     e1 = y*pow(x, 2) + a*x + b;
229     e2 = x*pow(y, 2) + b*y + a;
230     e = atan2(e1,e2);
231     /*
232     d = pow(y,2)*(-b-y*pow(x,2)-a*x)/(pow(b+y*pow(x,2)+a*x,2)+pow(x*pow(y,2)+b*y+a,2))
233         +(2*y*x+a)/((x*pow(y,2)+b*y+a)*(1+pow(b*y*pow(x,2)+a*x,2)/pow(x*pow(y,2)+b*y+a,2)));
234     */
235     /*
236     d = ((a+2*y*x)*pow(y*b+pow(y,2)*x+a,-1)-(a*x+b+y*pow(x,2))*
237          pow(y*b+pow(y,2)*x+a,-2)*pow(y,2))*
238         pow(1+pow(a*x+b+y*pow(x,2),2)*pow(y*b+pow(y,2)*x+a,-2),-1);
239     */
240     /*
241     d = pow(1+pow(a*x+b+y*pow(x,2),2)*pow(y*b+pow(y,2)*x+a,-2),-1)
242         *pow(y*b+pow(y,2)*x+a,-1)*(a+2*y*x)
243         +pow(y,2)*(-a*x-b-y*pow(x,2))*
244         pow(pow(y*b+pow(y,2)*x+a,2)+pow(a*x+b+y*pow(x,2),2),-1);
245     */
246     d = pow(y,2)*pow(pow(b+y*pow(x,2)+x*a,2)+pow(y*b+pow(y,2)*x+a,2),-1)*
247         (-b-y*pow(x,2)-x*a)+
248         pow(pow(b+y*pow(x,2)+x*a,2)+pow(y*b+pow(y,2)*x+a,2),-1)*
249         (y*b+pow(y,2)*x+a)*(2*y*x+a);
250     result += check_diff(e, x, d);
251     
252     return result;
253 }
254
255 // Series
256 static unsigned differentiation6(void)
257 {
258     symbol x("x");
259     ex e, d, ed;
260     
261     e = sin(x).series(x, 0, 8);
262     d = cos(x).series(x, 0, 7);
263     ed = e.diff(x);
264     ed = static_cast<series *>(ed.bp)->convert_to_poly();
265     d = static_cast<series *>(d.bp)->convert_to_poly();
266     
267     if ((ed - d).compare(ex(0)) != 0) {
268         clog << "derivative of " << e << " by " << x << " returned "
269              << ed << " instead of " << d << ")" << endl;
270         return 1;
271     }
272     return 0;
273 }
274
275 unsigned differentiation(void)
276 {
277     unsigned result = 0;
278     
279     cout << "checking symbolic differentiation..." << flush;
280     clog << "---------symbolic differentiation:" << endl;
281     
282     result += differentiation1();
283     result += differentiation2();
284     result += differentiation3();
285     result += differentiation4();
286     result += differentiation5();
287     result += differentiation6();
288     
289     if (!result) {
290         cout << " passed ";
291         clog << "(no output)" << endl;
292     } else {
293         cout << " failed ";
294     }
295     return result;
296 }