Initial revision
[ginac.git] / check / differentiation.cpp
1 // check/differentiation.cpp
2
3 /* Tests for symbolic differentiation, including various functions. */
4
5 #include "ginac.h"
6
7 static unsigned check_diff(const ex &e, const symbol &x,
8                            const ex &d, unsigned nth=1)
9 {
10     ex ed = e.diff(x, nth);
11     if ((ed - d).compare(exZERO()) != 0) {
12         switch (nth) {
13         case 0:
14             clog << "zeroth ";
15             break;
16         case 1:
17             break;
18         case 2:
19             clog << "second ";
20             break;
21         case 3:
22             clog << "third ";
23             break;
24         default:
25             clog << nth << "th ";
26         }
27         clog << "derivative of " << e << " by " << x << " returned "
28              << ed << " instead of " << d << endl;
29         clog << "returned:" << endl;
30         ed.printtree(clog);
31         clog << endl << "instead of" << endl;
32         d.printtree(clog);
33
34         return 1;
35     }
36     return 0;
37 }
38
39 // Simple (expanded) polynomials
40 static unsigned differentiation1(void)
41 {
42     unsigned result = 0;
43     symbol x("x"), y("y");
44     ex e1, e2, e, d;
45     
46     // construct bivariate polynomial e to be diff'ed:
47     e1 = pow(x, -2) * 3 + pow(x, -1) * 5 + 7 + x * 11 + pow(x, 2) * 13;
48     e2 = pow(y, -2) * 5 + pow(y, -1) * 7 + 11 + y * 13 + pow(y, 2) * 17;
49     e = (e1 * e2).expand();
50     
51     // d e / dx:
52     d = 121 - 55*pow(x,-2) - 66*pow(x,-3) - 30*pow(x,-3)*pow(y,-2)
53         - 42*pow(x,-3)*pow(y,-1) - 78*pow(x,-3)*y
54         - 102*pow(x,-3)*pow(y,2) - 25*pow(x,-2) * pow(y,-2)
55         - 35*pow(x,-2)*pow(y,-1) - 65*pow(x,-2)*y
56         - 85*pow(x,-2)*pow(y,2) + 77*pow(y,-1) + 143*y + 187*pow(y,2)
57         + 130*x*pow(y,-2) + 182*pow(y,-1)*x + 338*x*y + 442*x*pow(y,2)
58         + 55*pow(y,-2) + 286*x;
59     result += check_diff(e, x, d);
60     
61     // d e / dy:
62     d = 91 - 30*pow(x,-2)*pow(y,-3) - 21*pow(x,-2)*pow(y,-2)
63         + 39*pow(x,-2) + 102*pow(x,-2)*y - 50*pow(x,-1)*pow(y,-3)
64         - 35*pow(x,-1)*pow(y,-2) + 65*pow(x,-1) + 170*pow(x,-1)*y
65         - 77*pow(y,-2)*x + 143*x + 374*x*y - 130*pow(y,-3)*pow(x,2)
66         - 91*pow(y,-2)*pow(x,2) + 169*pow(x,2) + 442*pow(x,2)*y
67         - 110*pow(y,-3)*x - 70*pow(y,-3) + 238*y - 49*pow(y,-2);
68     result += check_diff(e, y, d);
69     
70     // d^2 e / dx^2:
71     d = 286 + 90*pow(x,-4)*pow(y,-2) + 126*pow(x,-4)*pow(y,-1)
72         + 234*pow(x,-4)*y + 306*pow(x,-4)*pow(y,2)
73         + 50*pow(x,-3)*pow(y,-2) + 70*pow(x,-3)*pow(y,-1)
74         + 130*pow(x,-3)*y + 170*pow(x,-3)*pow(y,2)
75         + 130*pow(y,-2) + 182*pow(y,-1) + 338*y + 442*pow(y,2)
76         + 198*pow(x,-4) + 110*pow(x,-3);
77     result += check_diff(e, x, d, 2);
78     
79     // d^2 e / dy^2:
80     d = 238 + 90*pow(x,-2)*pow(y,-4) + 42*pow(x,-2)*pow(y,-3)
81         + 102*pow(x,-2) + 150*pow(x,-1)*pow(y,-4)
82         + 70*pow(x,-1)*pow(y,-3) + 170*pow(x,-1) + 330*x*pow(y,-4)
83         + 154*x*pow(y,-3) + 374*x + 390*pow(x,2)*pow(y,-4)
84         + 182*pow(x,2)*pow(y,-3) + 442*pow(x,2) + 210*pow(y,-4)
85         + 98*pow(y,-3);
86     result += check_diff(e, y, d, 2);
87     
88     return result;
89 }
90
91 // Trigonometric and transcendental functions
92 static unsigned differentiation2(void)
93 {
94     unsigned result = 0;
95     symbol x("x"), y("y"), a("a"), b("b");
96     ex e1, e2, e, d;
97     
98     // construct expression e to be diff'ed:
99     e1 = y*pow(x, 2) + a*x + b;
100     e2 = sin(e1);
101     e = b*pow(e2, 2) + y*e2 + a;
102     
103     d = 2*b*e2*cos(e1)*(2*x*y + a) + y*cos(e1)*(2*x*y + a);
104     result += check_diff(e, x, d);
105     
106     d = 2*b*pow(cos(e1),2)*pow(2*x*y + a, 2) + 4*b*y*e2*cos(e1)
107         - 2*b*pow(e2,2)*pow(2*x*y + a, 2) - y*e2*pow(2*x*y + a, 2)
108         + 2*pow(y,2)*cos(e1);
109     result += check_diff(e, x, d, 2);
110     
111     d = 2*b*e2*cos(e1)*pow(x, 2) + e2 + y*cos(e1)*pow(x, 2);
112     result += check_diff(e, y, d);
113
114     d = 2*b*pow(cos(e1),2)*pow(x,4) - 2*b*pow(e2,2)*pow(x,4)
115         + 2*cos(e1)*pow(x,2) - y*e2*pow(x,4);
116     result += check_diff(e, y, d, 2);
117     
118     // construct expression e to be diff'ed:
119     e2 = cos(e1);
120     e = b*pow(e2, 2) + y*e2 + a;
121     
122     d = -2*b*e2*sin(e1)*(2*x*y + a) - y*sin(e1)*(2*x*y + a);
123     result += check_diff(e, x, d);
124     
125     d = 2*b*pow(sin(e1),2)*pow(2*y*x + a,2) - 4*b*e2*sin(e1)*y 
126         - 2*b*pow(e2,2)*pow(2*y*x + a,2) - y*e2*pow(2*y*x + a,2)
127         - 2*pow(y,2)*sin(e1);
128     result += check_diff(e, x, d, 2);
129     
130     d = -2*b*e2*sin(e1)*pow(x,2) + e2 - y*sin(e1)*pow(x, 2);
131     result += check_diff(e, y, d);
132     
133     d = -2*b*pow(e2,2)*pow(x,4) + 2*b*pow(sin(e1),2)*pow(x,4)
134         - 2*sin(e1)*pow(x,2) - y*e2*pow(x,4);
135     result += check_diff(e, y, d, 2);
136     
137     // construct expression e to be diff'ed:
138     e2 = exp(e1);
139     e = b*pow(e2, 2) + y*e2 + a;
140     
141     d = 2*b*pow(e2, 2)*(2*x*y + a) + y*e2*(2*x*y + a);
142     result += check_diff(e, x, d);
143     
144     d = 4*b*pow(e2,2)*pow(2*y*x + a,2) + 4*b*pow(e2,2)*y
145         + 2*pow(y,2)*e2 + y*e2*pow(2*y*x + a,2);
146     result += check_diff(e, x, d, 2);
147     
148     d = 2*b*pow(e2,2)*pow(x,2) + e2 + y*e2*pow(x,2);
149     result += check_diff(e, y, d);
150     
151     d = 4*b*pow(e2,2)*pow(x,4) + 2*e2*pow(x,2) + y*e2*pow(x,4);
152     result += check_diff(e, y, d, 2);
153     
154     // construct expression e to be diff'ed:
155     e2 = log(e1);
156     e = b*pow(e2, 2) + y*e2 + a;
157     
158     d = 2*b*e2*(2*x*y + a)/e1 + y*(2*x*y + a)/e1;
159     result += check_diff(e, x, d);
160     
161     d = 2*b*pow((2*x*y + a),2)*pow(e1,-2) + 4*b*y*e2/e1
162         - 2*b*e2*pow(2*x*y + a,2)*pow(e1,-2) + 2*pow(y,2)/e1
163         - y*pow(2*x*y + a,2)*pow(e1,-2);
164     result += check_diff(e, x, d, 2);
165     
166     d = 2*b*e2*pow(x,2)/e1 + e2 + y*pow(x,2)/e1;
167     result += check_diff(e, y, d);
168     
169     d = 2*b*pow(x,4)*pow(e1,-2) - 2*b*e2*pow(e1,-2)*pow(x,4)
170         + 2*pow(x,2)/e1 - y*pow(x,4)*pow(e1,-2);
171     result += check_diff(e, y, d, 2);
172     
173     // test for functions with two variables: atan2
174     e1 = y*pow(x, 2) + a*x + b;
175     e2 = x*pow(y, 2) + b*y + a;
176     e = atan2(e1,e2);
177     /*
178     d = pow(y,2)*(-b-y*pow(x,2)-a*x)/(pow(b+y*pow(x,2)+a*x,2)+pow(x*pow(y,2)+b*y+a,2))
179         +(2*y*x+a)/((x*pow(y,2)+b*y+a)*(1+pow(b*y*pow(x,2)+a*x,2)/pow(x*pow(y,2)+b*y+a,2)));
180     */
181     /*
182     d = ((a+2*y*x)*pow(y*b+pow(y,2)*x+a,-1)-(a*x+b+y*pow(x,2))*
183          pow(y*b+pow(y,2)*x+a,-2)*pow(y,2))*
184         pow(1+pow(a*x+b+y*pow(x,2),2)*pow(y*b+pow(y,2)*x+a,-2),-1);
185     */
186     d = pow(1+pow(a*x+b+y*pow(x,2),2)*pow(y*b+pow(y,2)*x+a,-2),-1)
187         *pow(y*b+pow(y,2)*x+a,-1)*(a+2*y*x)
188         +pow(y,2)*(-a*x-b-y*pow(x,2))*
189         pow(pow(y*b+pow(y,2)*x+a,2)+pow(a*x+b+y*pow(x,2),2),-1);
190     result += check_diff(e, x, d);
191     
192     return result;
193 }
194
195 // Series
196 static unsigned differentiation3(void)
197 {
198     symbol x("x");
199     ex e, d, ed;
200     
201     e = sin(x).series(x, exZERO(), 8);
202     d = cos(x).series(x, exZERO(), 7);
203     ed = e.diff(x);
204     ed = static_cast<series *>(ed.bp)->convert_to_poly();
205     d = static_cast<series *>(d.bp)->convert_to_poly();
206     
207     if ((ed - d).compare(exZERO()) != 0) {
208         clog << "derivative of " << e << " by " << x << " returned "
209              << ed << " instead of " << d << ")" << endl;
210         return 1;
211     }
212     return 0;
213 }
214
215 unsigned differentiation(void)
216 {
217     unsigned result = 0;
218     
219     cout << "checking symbolic differentiation..." << flush;
220     clog << "---------symbolic differentiation:" << endl;
221     
222     result += differentiation1();
223     result += differentiation2();
224     result += differentiation3();
225     
226     if (!result) {
227         cout << " passed ";
228         clog << "(no output)" << endl;
229     } else {
230         cout << " failed ";
231     }
232     return result;
233 }