51792a3ed0e29d0b52457b5d3e86c2657e18a752
[ginac.git] / check / check_lsolve.cpp
1 /** @file check_lsolve.cpp
2  *
3  *  These test routines do some simple checks on solving linear systems of
4  *  symbolic equations. */
5
6 /*
7  *  GiNaC Copyright (C) 1999-2002 Johannes Gutenberg University Mainz, Germany
8  *
9  *  This program is free software; you can redistribute it and/or modify
10  *  it under the terms of the GNU General Public License as published by
11  *  the Free Software Foundation; either version 2 of the License, or
12  *  (at your option) any later version.
13  *
14  *  This program is distributed in the hope that it will be useful,
15  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
16  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
17  *  GNU General Public License for more details.
18  *
19  *  You should have received a copy of the GNU General Public License
20  *  along with this program; if not, write to the Free Software
21  *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
22  */
23
24 #include "checks.h"
25 #include <sstream>
26
27 static unsigned check_matrix_solve(unsigned m, unsigned n, unsigned p,
28                                                                    unsigned degree)
29 {
30         const symbol a("a");
31         matrix A(m,n);
32         matrix B(m,p);
33         // set the first min(m,n) rows of A and B
34         for (unsigned ro=0; (ro<m)&&(ro<n); ++ro) {
35                 for (unsigned co=0; co<n; ++co)
36                         A.set(ro,co,dense_univariate_poly(a,degree));
37                 for (unsigned co=0; co<p; ++co)
38                         B.set(ro,co,dense_univariate_poly(a,degree));
39         }
40         // repeat excessive rows of A and B to avoid excessive construction of
41         // overdetermined linear systems
42         for (unsigned ro=n; ro<m; ++ro) {
43                 for (unsigned co=0; co<n; ++co)
44                         A.set(ro,co,A(ro-1,co));
45                 for (unsigned co=0; co<p; ++co)
46                         B.set(ro,co,B(ro-1,co));
47         }
48         // create a vector of n*p symbols all named "xrc" where r and c are ints
49         vector<symbol> x;
50         matrix X(n,p);
51         for (unsigned i=0; i<n; ++i) {
52                 for (unsigned j=0; j<p; ++j) {
53                         ostringstream buf;
54                         buf << "x" << i << j << ends;
55                         x.push_back(symbol(buf.str()));
56                         X.set(i,j,x[p*i+j]);
57                 }
58         }
59         matrix sol(n,p);
60         // Solve the system A*X==B:
61         try {
62                 sol = A.solve(X, B);
63         } catch (const exception & err) {  // catch runtime_error
64                 // Presumably, the coefficient matrix A was degenerate
65                 string errwhat = err.what();
66                 if (errwhat == "matrix::solve(): inconsistent linear system")
67                         return 0;
68                 else
69                         clog << "caught exception: " << errwhat << endl;
70                 throw;
71         }
72         
73         // check the result with our original matrix:
74         bool errorflag = false;
75         for (unsigned ro=0; ro<m; ++ro) {
76                 for (unsigned pco=0; pco<p; ++pco) {
77                         ex e = 0;
78                         for (unsigned co=0; co<n; ++co)
79                         e += A(ro,co)*sol(co,pco);
80                         if (!(e-B(ro,pco)).normal().is_zero())
81                                 errorflag = true;
82                 }
83         }
84         if (errorflag) {
85                 clog << "Our solve method claims that A*X==B, with matrices" << endl
86                      << "A == " << A << endl
87                      << "X == " << sol << endl
88                      << "B == " << B << endl;
89                 return 1;
90         }
91         
92         return 0;
93 }
94
95 static unsigned check_inifcns_lsolve(unsigned n)
96 {
97         unsigned result = 0;
98         
99         for (int repetition=0; repetition<100; ++repetition) {
100                 // create two size n vectors of symbols, one for the coefficients
101                 // a[0],..,a[n], one for indeterminates x[0]..x[n]:
102                 vector<symbol> a;
103                 vector<symbol> x;
104                 for (unsigned i=0; i<n; ++i) {
105                         ostringstream buf;
106                         buf << i << ends;
107                         a.push_back(symbol(string("a")+buf.str()));
108                         x.push_back(symbol(string("x")+buf.str()));
109                 }
110                 lst eqns;  // equation list
111                 lst vars;  // variable list
112                 ex sol; // solution
113                 // Create a random linear system...
114                 for (unsigned i=0; i<n; ++i) {
115                         ex lhs = rand()%201-100;
116                         ex rhs = rand()%201-100;
117                         for (unsigned j=0; j<n; ++j) {
118                                 // ...with small coefficients to give degeneracy a chance...
119                                 lhs += a[j]*(rand()%21-10);
120                                 rhs += x[j]*(rand()%21-10);
121                         }
122                         eqns.append(lhs==rhs);
123                         vars.append(x[i]);
124                 }
125                 // ...solve it...
126                 sol = lsolve(eqns, vars);
127                 
128                 // ...and check the solution:
129                 if (sol.nops() == 0) {
130                         // no solution was found
131                         // is the coefficient matrix really, really, really degenerate?
132                         matrix coeffmat(n,n);
133                         for (unsigned ro=0; ro<n; ++ro)
134                                 for (unsigned co=0; co<n; ++co)
135                                         coeffmat.set(ro,co,eqns.op(co).rhs().coeff(a[co],1));
136                         if (!coeffmat.determinant().is_zero()) {
137                                 ++result;
138                                 clog << "solution of the system " << eqns << " for " << vars
139                                          << " was not found" << endl;
140                         }
141                 } else {
142                         // insert the solution into rhs of out equations
143                         bool errorflag = false;
144                         for (unsigned i=0; i<n; ++i)
145                                 if (eqns.op(i).rhs().subs(sol) != eqns.op(i).lhs())
146                                         errorflag = true;
147                         if (errorflag) {
148                                 ++result;
149                                 clog << "solution of the system " << eqns << " for " << vars
150                                      << " erroneously returned " << sol << endl;
151                         }
152                 }
153         }
154         
155         return result;
156 }
157
158 unsigned check_lsolve(void)
159 {
160         unsigned result = 0;
161         
162         cout << "checking linear solve" << flush;
163         clog << "---------linear solve:" << endl;
164         
165         // solve some numeric linear systems
166         for (unsigned n=1; n<12; ++n)
167                 result += check_matrix_solve(n, n, 1, 0);
168         cout << '.' << flush;
169         // solve some underdetermined numeric systems
170         for (unsigned n=1; n<12; ++n)
171                 result += check_matrix_solve(n+1, n, 1, 0);
172         cout << '.' << flush;
173         // solve some overdetermined numeric systems
174         for (unsigned n=1; n<12; ++n)
175                 result += check_matrix_solve(n, n+1, 1, 0);
176         cout << '.' << flush;
177         // solve some multiple numeric systems
178         for (unsigned n=1; n<12; ++n)
179                 result += check_matrix_solve(n, n, n/3+1, 0);
180         cout << '.' << flush;
181         // solve some symbolic linear systems
182         for (unsigned n=1; n<7; ++n)
183                 result += check_matrix_solve(n, n, 1, 2);
184         cout << '.' << flush;
185         
186         // check lsolve, the wrapper function around matrix::solve()
187         result += check_inifcns_lsolve(2);  cout << '.' << flush;
188         result += check_inifcns_lsolve(3);  cout << '.' << flush;
189         result += check_inifcns_lsolve(4);  cout << '.' << flush;
190         result += check_inifcns_lsolve(5);  cout << '.' << flush;
191         result += check_inifcns_lsolve(6);  cout << '.' << flush;
192                 
193         if (!result) {
194                 cout << " passed " << endl;
195                 clog << "(no output)" << endl;
196         } else {
197                 cout << " failed " << endl;
198         }
199         
200         return result;
201 }