3e1c25e5de5ae6af233f23272763860432758c2b
[ginac.git] / check / exam_matrices.cpp
1 /** @file exam_matrices.cpp
2  *
3  *  Here we examine manipulations on GiNaC's symbolic matrices. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2015 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
21  */
22
23 #include "ginac.h"
24 using namespace GiNaC;
25
26 #include <iostream>
27 #include <stdexcept>
28 using namespace std;
29
30 static unsigned matrix_determinants()
31 {
32         unsigned result = 0;
33         ex det;
34         matrix m1(1,1), m2(2,2), m3(3,3), m4(4,4);
35         symbol a("a"), b("b"), c("c");
36         symbol d("d"), e("e"), f("f");
37         symbol g("g"), h("h"), i("i");
38         
39         // check symbolic trivial matrix determinant
40         m1.set(0,0,a);
41         det = m1.determinant();
42         if (det != a) {
43                 clog << "determinant of 1x1 matrix " << m1
44                      << " erroneously returned " << det << endl;
45                 ++result;
46         }
47         
48         // check generic dense symbolic 2x2 matrix determinant
49         m2.set(0,0,a).set(0,1,b);
50         m2.set(1,0,c).set(1,1,d);
51         det = m2.determinant();
52         if (det != (a*d-b*c)) {
53                 clog << "determinant of 2x2 matrix " << m2
54                      << " erroneously returned " << det << endl;
55                 ++result;
56         }
57         
58         // check generic dense symbolic 3x3 matrix determinant
59         m3.set(0,0,a).set(0,1,b).set(0,2,c);
60         m3.set(1,0,d).set(1,1,e).set(1,2,f);
61         m3.set(2,0,g).set(2,1,h).set(2,2,i);
62         det = m3.determinant();
63         if (det != (a*e*i - a*f*h - d*b*i + d*c*h + g*b*f - g*c*e)) {
64                 clog << "determinant of 3x3 matrix " << m3
65                      << " erroneously returned " << det << endl;
66                 ++result;
67         }
68         
69         // check dense numeric 3x3 matrix determinant
70         m3.set(0,0,numeric(0)).set(0,1,numeric(-1)).set(0,2,numeric(3));
71         m3.set(1,0,numeric(3)).set(1,1,numeric(-2)).set(1,2,numeric(2));
72         m3.set(2,0,numeric(3)).set(2,1,numeric(4)).set(2,2,numeric(-2));
73         det = m3.determinant();
74         if (det != 42) {
75                 clog << "determinant of 3x3 matrix " << m3
76                      << " erroneously returned " << det << endl;
77                 ++result;
78         }
79         
80         // check dense symbolic 2x2 matrix determinant
81         m2.set(0,0,a/(a-b)).set(0,1,1);
82         m2.set(1,0,b/(a-b)).set(1,1,1);
83         det = m2.determinant();
84         if (det != 1) {
85                 if (det.normal() == 1)  // only half wrong
86                         clog << "determinant of 2x2 matrix " << m2
87                              << " was returned unnormalized as " << det << endl;
88                 else  // totally wrong
89                         clog << "determinant of 2x2 matrix " << m2
90                              << " erroneously returned " << det << endl;
91                 ++result;
92         }
93         
94         // check sparse symbolic 4x4 matrix determinant
95         m4.set(0,1,a).set(1,0,b).set(3,2,c).set(2,3,d);
96         det = m4.determinant();
97         if (det != a*b*c*d) {
98                 clog << "determinant of 4x4 matrix " << m4
99                      << " erroneously returned " << det << endl;
100                 ++result;
101         }
102         
103         // check characteristic polynomial
104         m3.set(0,0,a).set(0,1,-2).set(0,2,2);
105         m3.set(1,0,3).set(1,1,a-1).set(1,2,2);
106         m3.set(2,0,3).set(2,1,4).set(2,2,a-3);
107         ex p = m3.charpoly(a);
108         if (p != 0) {
109                 clog << "charpoly of 3x3 matrix " << m3
110                      << " erroneously returned " << p << endl;
111                 ++result;
112         }
113         
114         return result;
115 }
116
117 static unsigned matrix_invert1()
118 {
119         unsigned result = 0;
120         matrix m(1,1);
121         symbol a("a");
122         
123         m.set(0,0,a);
124         matrix m_i = m.inverse();
125         
126         if (m_i(0,0) != pow(a,-1)) {
127                 clog << "inversion of 1x1 matrix " << m
128                      << " erroneously returned " << m_i << endl;
129                 ++result;
130         }
131         
132         return result;
133 }
134
135 static unsigned matrix_invert2()
136 {
137         unsigned result = 0;
138         matrix m(2,2);
139         symbol a("a"), b("b"), c("c"), d("d");
140         m.set(0,0,a).set(0,1,b);
141         m.set(1,0,c).set(1,1,d);
142         matrix m_i = m.inverse();
143         ex det = m.determinant();
144         
145         if ((normal(m_i(0,0)*det) != d) ||
146             (normal(m_i(0,1)*det) != -b) ||
147             (normal(m_i(1,0)*det) != -c) ||
148             (normal(m_i(1,1)*det) != a)) {
149                 clog << "inversion of 2x2 matrix " << m
150                      << " erroneously returned " << m_i << endl;
151                 ++result;
152         }
153         
154         return result;
155 }
156
157 static unsigned matrix_invert3()
158 {
159         unsigned result = 0;
160         matrix m(3,3);
161         symbol a("a"), b("b"), c("c");
162         symbol d("d"), e("e"), f("f");
163         symbol g("g"), h("h"), i("i");
164         m.set(0,0,a).set(0,1,b).set(0,2,c);
165         m.set(1,0,d).set(1,1,e).set(1,2,f);
166         m.set(2,0,g).set(2,1,h).set(2,2,i);
167         matrix m_i = m.inverse();
168         ex det = m.determinant();
169         
170         if ((normal(m_i(0,0)*det) != (e*i-f*h)) ||
171             (normal(m_i(0,1)*det) != (c*h-b*i)) ||
172             (normal(m_i(0,2)*det) != (b*f-c*e)) ||
173             (normal(m_i(1,0)*det) != (f*g-d*i)) ||
174             (normal(m_i(1,1)*det) != (a*i-c*g)) ||
175             (normal(m_i(1,2)*det) != (c*d-a*f)) ||
176             (normal(m_i(2,0)*det) != (d*h-e*g)) ||
177             (normal(m_i(2,1)*det) != (b*g-a*h)) ||
178             (normal(m_i(2,2)*det) != (a*e-b*d))) {
179                 clog << "inversion of 3x3 matrix " << m
180                      << " erroneously returned " << m_i << endl;
181                 ++result;
182         }
183         
184         return result;
185 }
186
187 static unsigned matrix_solve2()
188 {
189         // check the solution of the multiple system A*X = B:
190         //       [ 1  2 -1 ] [ x0 y0 ]   [ 4 0 ]
191         //       [ 1  4 -2 ]*[ x1 y1 ] = [ 7 0 ]
192         //       [ a -2  2 ] [ x2 y2 ]   [ a 4 ]
193         unsigned result = 0;
194         symbol a("a");
195         symbol x0("x0"), x1("x1"), x2("x2");
196         symbol y0("y0"), y1("y1"), y2("y2");
197         matrix A(3,3);
198         A.set(0,0,1).set(0,1,2).set(0,2,-1);
199         A.set(1,0,1).set(1,1,4).set(1,2,-2);
200         A.set(2,0,a).set(2,1,-2).set(2,2,2);
201         matrix B(3,2);
202         B.set(0,0,4).set(1,0,7).set(2,0,a);
203         B.set(0,1,0).set(1,1,0).set(2,1,4);
204         matrix X(3,2);
205         X.set(0,0,x0).set(1,0,x1).set(2,0,x2);
206         X.set(0,1,y0).set(1,1,y1).set(2,1,y2);
207         matrix cmp(3,2);
208         cmp.set(0,0,1).set(1,0,3).set(2,0,3);
209         cmp.set(0,1,0).set(1,1,2).set(2,1,4);
210         matrix sol(A.solve(X, B));
211         for (unsigned ro=0; ro<3; ++ro)
212                 for (unsigned co=0; co<2; ++co)
213                         if (cmp(ro,co) != sol(ro,co))
214                                 result = 1;
215         if (result) {
216                 clog << "Solving " << A << " * " << X << " == " << B << endl
217                      << "erroneously returned " << sol << endl;
218         }
219         
220         return result;
221 }
222
223 static unsigned matrix_evalm()
224 {
225         unsigned result = 0;
226
227         matrix S(2, 2, lst(
228                 1, 2,
229                 3, 4
230         )), T(2, 2, lst(
231                 1, 1,
232                 2, -1
233         )), R(2, 2, lst(
234                 27, 14,
235                 36, 26
236         ));
237
238         ex e = ((S + T) * (S + 2*T));
239         ex f = e.evalm();
240         if (!f.is_equal(R)) {
241                 clog << "Evaluating " << e << " erroneously returned " << f << " instead of " << R << endl;
242                 result++;
243         }
244
245         return result;
246 }
247
248 static unsigned matrix_rank()
249 {
250         unsigned result = 0;
251         symbol x("x"), y("y");
252         matrix m(3,3);
253
254         // the zero matrix always has rank 0
255         if (m.rank() != 0) {
256                 clog << "The rank of " << m << " was not computed correctly." << endl;
257                 ++result;
258         }
259
260         // a trivial rank one example
261         m = 1, 0, 0,
262             2, 0, 0,
263             3, 0, 0;
264         if (m.rank() != 1) {
265                 clog << "The rank of " << m << " was not computed correctly." << endl;
266                 ++result;
267         }
268
269         // an example from Maple's help with rank two
270         m = x,  1,  0,
271             0,  0,  1,
272            x*y, y,  1;
273         if (m.rank() != 2) {
274                 clog << "The rank of " << m << " was not computed correctly." << endl;
275                 ++result;
276         }
277
278         // the 3x3 unit matrix has rank 3
279         m = ex_to<matrix>(unit_matrix(3,3));
280         if (m.rank() != 3) {
281                 clog << "The rank of " << m << " was not computed correctly." << endl;
282                 ++result;
283         }
284
285         return result;  
286 }
287
288 static unsigned matrix_misc()
289 {
290         unsigned result = 0;
291         matrix m1(2,2);
292         symbol a("a"), b("b"), c("c"), d("d"), e("e"), f("f");
293         m1.set(0,0,a).set(0,1,b);
294         m1.set(1,0,c).set(1,1,d);
295         ex tr = trace(m1);
296         
297         // check a simple trace
298         if (tr.compare(a+d)) {
299                 clog << "trace of 2x2 matrix " << m1
300                      << " erroneously returned " << tr << endl;
301                 ++result;
302         }
303         
304         // and two simple transpositions
305         matrix m2 = transpose(m1);
306         if (m2(0,0) != a || m2(0,1) != c || m2(1,0) != b || m2(1,1) != d) {
307                 clog << "transpose of 2x2 matrix " << m1
308                          << " erroneously returned " << m2 << endl;
309                 ++result;
310         }
311         matrix m3(3,2);
312         m3.set(0,0,a).set(0,1,b);
313         m3.set(1,0,c).set(1,1,d);
314         m3.set(2,0,e).set(2,1,f);
315         if (transpose(transpose(m3)) != m3) {
316                 clog << "transposing 3x2 matrix " << m3 << " twice"
317                      << " erroneously returned " << transpose(transpose(m3)) << endl;
318                 ++result;
319         }
320         
321         // produce a runtime-error by inverting a singular matrix and catch it
322         matrix m4(2,2);
323         matrix m5;
324         bool caught = false;
325         try {
326                 m5 = inverse(m4);
327         } catch (std::runtime_error err) {
328                 caught = true;
329         }
330         if (!caught) {
331                 cerr << "singular 2x2 matrix " << m4
332                      << " erroneously inverted to " << m5 << endl;
333                 ++result;
334         }
335         
336         return result;
337 }
338
339 unsigned exam_matrices()
340 {
341         unsigned result = 0;
342         
343         cout << "examining symbolic matrix manipulations" << flush;
344         
345         result += matrix_determinants();  cout << '.' << flush;
346         result += matrix_invert1();  cout << '.' << flush;
347         result += matrix_invert2();  cout << '.' << flush;
348         result += matrix_invert3();  cout << '.' << flush;
349         result += matrix_solve2();  cout << '.' << flush;
350         result += matrix_evalm();  cout << "." << flush;
351         result += matrix_rank();  cout << "." << flush;
352         result += matrix_misc();  cout << '.' << flush;
353         
354         return result;
355 }
356
357 int main(int argc, char** argv)
358 {
359         return exam_matrices();
360 }