226b6c2d3a0d8144d126f0f6f83fb584c43767ec
[ginac.git] / check / check_lsolve.cpp
1 /** @file check_lsolve.cpp
2  *
3  *  These test routines do some simple checks on solving linear systems of
4  *  symbolic equations.  They are a well-tried resource for cross-checking
5  *  the underlying symbolic manipulations. */
6
7 /*
8  *  GiNaC Copyright (C) 1999-2008 Johannes Gutenberg University Mainz, Germany
9  *
10  *  This program is free software; you can redistribute it and/or modify
11  *  it under the terms of the GNU General Public License as published by
12  *  the Free Software Foundation; either version 2 of the License, or
13  *  (at your option) any later version.
14  *
15  *  This program is distributed in the hope that it will be useful,
16  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
17  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
18  *  GNU General Public License for more details.
19  *
20  *  You should have received a copy of the GNU General Public License
21  *  along with this program; if not, write to the Free Software
22  *  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
23  */
24
25 #include <iostream>
26 #include <sstream>
27 #include <cstdlib> // rand()
28 #include "ginac.h"
29 using namespace std;
30 using namespace GiNaC;
31
32 extern const ex 
33 dense_univariate_poly(const symbol & x, unsigned degree);
34
35 static unsigned check_matrix_solve(unsigned m, unsigned n, unsigned p,
36                                                                    unsigned degree)
37 {
38         const symbol a("a");
39         matrix A(m,n);
40         matrix B(m,p);
41         // set the first min(m,n) rows of A and B
42         for (unsigned ro=0; (ro<m)&&(ro<n); ++ro) {
43                 for (unsigned co=0; co<n; ++co)
44                         A.set(ro,co,dense_univariate_poly(a,degree));
45                 for (unsigned co=0; co<p; ++co)
46                         B.set(ro,co,dense_univariate_poly(a,degree));
47         }
48         // repeat excessive rows of A and B to avoid excessive construction of
49         // overdetermined linear systems
50         for (unsigned ro=n; ro<m; ++ro) {
51                 for (unsigned co=0; co<n; ++co)
52                         A.set(ro,co,A(ro-1,co));
53                 for (unsigned co=0; co<p; ++co)
54                         B.set(ro,co,B(ro-1,co));
55         }
56         // create a vector of n*p symbols all named "xrc" where r and c are ints
57         vector<symbol> x;
58         matrix X(n,p);
59         for (unsigned i=0; i<n; ++i) {
60                 for (unsigned j=0; j<p; ++j) {
61                         ostringstream buf;
62                         buf << "x" << i << j << ends;
63                         x.push_back(symbol(buf.str()));
64                         X.set(i,j,x[p*i+j]);
65                 }
66         }
67         matrix sol(n,p);
68         // Solve the system A*X==B:
69         try {
70                 sol = A.solve(X, B);
71         } catch (const exception & err) {  // catch runtime_error
72                 // Presumably, the coefficient matrix A was degenerate
73                 string errwhat = err.what();
74                 if (errwhat == "matrix::solve(): inconsistent linear system")
75                         return 0;
76                 else
77                         clog << "caught exception: " << errwhat << endl;
78                 throw;
79         }
80         
81         // check the result with our original matrix:
82         bool errorflag = false;
83         for (unsigned ro=0; ro<m; ++ro) {
84                 for (unsigned pco=0; pco<p; ++pco) {
85                         ex e = 0;
86                         for (unsigned co=0; co<n; ++co)
87                         e += A(ro,co)*sol(co,pco);
88                         if (!(e-B(ro,pco)).normal().is_zero())
89                                 errorflag = true;
90                 }
91         }
92         if (errorflag) {
93                 clog << "Our solve method claims that A*X==B, with matrices" << endl
94                      << "A == " << A << endl
95                      << "X == " << sol << endl
96                      << "B == " << B << endl;
97                 return 1;
98         }
99         
100         return 0;
101 }
102
103 static unsigned check_inifcns_lsolve(unsigned n)
104 {
105         unsigned result = 0;
106         
107         for (int repetition=0; repetition<200; ++repetition) {
108                 // create two size n vectors of symbols, one for the coefficients
109                 // a[0],..,a[n], one for indeterminates x[0]..x[n]:
110                 vector<symbol> a;
111                 vector<symbol> x;
112                 for (unsigned i=0; i<n; ++i) {
113                         ostringstream buf;
114                         buf << i << ends;
115                         a.push_back(symbol(string("a")+buf.str()));
116                         x.push_back(symbol(string("x")+buf.str()));
117                 }
118                 lst eqns;  // equation list
119                 lst vars;  // variable list
120                 ex sol; // solution
121                 // Create a random linear system...
122                 for (unsigned i=0; i<n; ++i) {
123                         ex lhs = rand()%201-100;
124                         ex rhs = rand()%201-100;
125                         for (unsigned j=0; j<n; ++j) {
126                                 // ...with small coefficients to give degeneracy a chance...
127                                 lhs += a[j]*(rand()%21-10);
128                                 rhs += x[j]*(rand()%21-10);
129                         }
130                         eqns.append(lhs==rhs);
131                         vars.append(x[i]);
132                 }
133                 // ...solve it...
134                 sol = lsolve(eqns, vars);
135                 
136                 // ...and check the solution:
137                 if (sol.nops() == 0) {
138                         // no solution was found
139                         // is the coefficient matrix really, really, really degenerate?
140                         matrix coeffmat(n,n);
141                         for (unsigned ro=0; ro<n; ++ro)
142                                 for (unsigned co=0; co<n; ++co)
143                                         coeffmat.set(ro,co,eqns.op(co).rhs().coeff(a[co],1));
144                         if (!coeffmat.determinant().is_zero()) {
145                                 ++result;
146                                 clog << "solution of the system " << eqns << " for " << vars
147                                          << " was not found" << endl;
148                         }
149                 } else {
150                         // insert the solution into rhs of out equations
151                         bool errorflag = false;
152                         for (unsigned i=0; i<n; ++i)
153                                 if (eqns.op(i).rhs().subs(sol) != eqns.op(i).lhs())
154                                         errorflag = true;
155                         if (errorflag) {
156                                 ++result;
157                                 clog << "solution of the system " << eqns << " for " << vars
158                                      << " erroneously returned " << sol << endl;
159                         }
160                 }
161         }
162         
163         return result;
164 }
165
166 unsigned check_lsolve()
167 {
168         unsigned result = 0;
169         
170         cout << "checking linear solve" << flush;
171         
172         // solve some numeric linear systems
173         for (unsigned n=1; n<14; ++n)
174                 result += check_matrix_solve(n, n, 1, 0);
175         cout << '.' << flush;
176         // solve some underdetermined numeric systems
177         for (unsigned n=1; n<14; ++n)
178                 result += check_matrix_solve(n+1, n, 1, 0);
179         cout << '.' << flush;
180         // solve some overdetermined numeric systems
181         for (unsigned n=1; n<14; ++n)
182                 result += check_matrix_solve(n, n+1, 1, 0);
183         cout << '.' << flush;
184         // solve some multiple numeric systems
185         for (unsigned n=1; n<14; ++n)
186                 result += check_matrix_solve(n, n, n/3+1, 0);
187         cout << '.' << flush;
188         // solve some symbolic linear systems
189         for (unsigned n=1; n<8; ++n)
190                 result += check_matrix_solve(n, n, 1, 2);
191         cout << '.' << flush;
192         
193         // check lsolve, the wrapper function around matrix::solve()
194         result += check_inifcns_lsolve(2);  cout << '.' << flush;
195         result += check_inifcns_lsolve(3);  cout << '.' << flush;
196         result += check_inifcns_lsolve(4);  cout << '.' << flush;
197         result += check_inifcns_lsolve(5);  cout << '.' << flush;
198         result += check_inifcns_lsolve(6);  cout << '.' << flush;
199                 
200         return result;
201 }
202
203 int main(int argc, char** argv)
204 {
205         return check_lsolve();
206 }