- Added inclusion of <sstream> or <strstream> since <ginac/ginac.h> now
[ginac.git] / check / check_lsolve.cpp
1 /** @file check_lsolve.cpp
2  *
3  *  These test routines do some simple checks on solving linear systems of
4  *  symbolic equations. */
5
6 /*
7  *  GiNaC Copyright (C) 1999-2001 Johannes Gutenberg University Mainz, Germany
8  *
9  *  This program is free software; you can redistribute it and/or modify
10  *  it under the terms of the GNU General Public License as published by
11  *  the Free Software Foundation; either version 2 of the License, or
12  *  (at your option) any later version.
13  *
14  *  This program is distributed in the hope that it will be useful,
15  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
16  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
17  *  GNU General Public License for more details.
18  *
19  *  You should have received a copy of the GNU General Public License
20  *  along with this program; if not, write to the Free Software
21  *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
22  */
23
24 #include "checks.h"
25
26 #if defined(HAVE_SSTREAM)
27 #  include <sstream>
28 #else
29 #  include <strstream>
30 #endif
31
32 static unsigned check_matrix_solve(unsigned m, unsigned n, unsigned p,
33                                                                    unsigned degree)
34 {
35         const symbol a("a");
36         matrix A(m,n);
37         matrix B(m,p);
38         // set the first min(m,n) rows of A and B
39         for (unsigned ro=0; (ro<m)&&(ro<n); ++ro) {
40                 for (unsigned co=0; co<n; ++co)
41                         A.set(ro,co,dense_univariate_poly(a,degree));
42                 for (unsigned co=0; co<p; ++co)
43                         B.set(ro,co,dense_univariate_poly(a,degree));
44         }
45         // repeat excessive rows of A and B to avoid excessive construction of
46         // overdetermined linear systems
47         for (unsigned ro=n; ro<m; ++ro) {
48                 for (unsigned co=0; co<n; ++co)
49                         A.set(ro,co,A(ro-1,co));
50                 for (unsigned co=0; co<p; ++co)
51                         B.set(ro,co,B(ro-1,co));
52         }
53         // create a vector of n*p symbols all named "xrc" where r and c are ints
54         vector<symbol> x;
55         matrix X(n,p);
56         for (unsigned i=0; i<n; ++i) {
57                 for (unsigned j=0; j<p; ++j) {
58 #if defined(HAVE_SSTREAM)
59                         ostringstream buf;
60                         buf << "x" << i << j << ends;
61                         x.push_back(symbol(buf.str()));
62 #else
63                         char buf[4];
64                         ostrstream(buf,sizeof(buf)) << i << j << ends;
65                         x.push_back(symbol(string("x")+buf));
66 #endif
67                         X.set(i,j,x[p*i+j]);
68                 }
69         }
70         matrix sol(n,p);
71         // Solve the system A*X==B:
72         try {
73                 sol = A.solve(X, B);
74         } catch (const exception & err) {  // catch runtime_error
75                 // Presumably, the coefficient matrix A was degenerate
76                 string errwhat = err.what();
77                 if (errwhat == "matrix::solve(): inconsistent linear system")
78                         return 0;
79                 else
80                         clog << "caught exception: " << errwhat << endl;
81                 throw;
82         }
83         
84         // check the result with our original matrix:
85         bool errorflag = false;
86         for (unsigned ro=0; ro<m; ++ro) {
87                 for (unsigned pco=0; pco<p; ++pco) {
88                         ex e = 0;
89                         for (unsigned co=0; co<n; ++co)
90                         e += A(ro,co)*sol(co,pco);
91                         if (!(e-B(ro,pco)).normal().is_zero())
92                                 errorflag = true;
93                 }
94         }
95         if (errorflag) {
96                 clog << "Our solve method claims that A*X==B, with matrices" << endl
97                      << "A == " << A << endl
98                      << "X == " << sol << endl
99                      << "B == " << B << endl;
100                 return 1;
101         }
102         
103         return 0;
104 }
105
106 static unsigned check_inifcns_lsolve(unsigned n)
107 {
108         unsigned result = 0;
109         
110         for (int repetition=0; repetition<100; ++repetition) {
111                 // create two size n vectors of symbols, one for the coefficients
112                 // a[0],..,a[n], one for indeterminates x[0]..x[n]:
113                 vector<symbol> a;
114                 vector<symbol> x;
115                 for (unsigned i=0; i<n; ++i) {
116 #if defined(HAVE_SSTREAM)
117                         ostringstream buf;
118                         buf << i << ends;
119                         a.push_back(symbol(string("a")+buf.str()));
120                         x.push_back(symbol(string("x")+buf.str()));
121 #else
122                         char buf[3];
123                         ostrstream(buf,sizeof(buf)) << i << ends;
124                         a.push_back(symbol(string("a")+buf));
125                         x.push_back(symbol(string("x")+buf));
126 #endif
127                 }
128                 lst eqns;  // equation list
129                 lst vars;  // variable list
130                 ex sol; // solution
131                 // Create a random linear system...
132                 for (unsigned i=0; i<n; ++i) {
133                         ex lhs = rand()%201-100;
134                         ex rhs = rand()%201-100;
135                         for (unsigned j=0; j<n; ++j) {
136                                 // ...with small coefficients to give degeneracy a chance...
137                                 lhs += a[j]*(rand()%21-10);
138                                 rhs += x[j]*(rand()%21-10);
139                         }
140                         eqns.append(lhs==rhs);
141                         vars.append(x[i]);
142                 }
143                 // ...solve it...
144                 sol = lsolve(eqns, vars);
145                 
146                 // ...and check the solution:
147                 if (sol.nops() == 0) {
148                         // no solution was found
149                         // is the coefficient matrix really, really, really degenerate?
150                         matrix coeffmat(n,n);
151                         for (unsigned ro=0; ro<n; ++ro)
152                                 for (unsigned co=0; co<n; ++co)
153                                         coeffmat.set(ro,co,eqns.op(co).rhs().coeff(a[co],1));
154                         if (!coeffmat.determinant().is_zero()) {
155                                 ++result;
156                                 clog << "solution of the system " << eqns << " for " << vars
157                                          << " was not found" << endl;
158                         }
159                 } else {
160                         // insert the solution into rhs of out equations
161                         bool errorflag = false;
162                         for (unsigned i=0; i<n; ++i)
163                                 if (eqns.op(i).rhs().subs(sol) != eqns.op(i).lhs())
164                                         errorflag = true;
165                         if (errorflag) {
166                                 ++result;
167                                 clog << "solution of the system " << eqns << " for " << vars
168                                      << " erroneously returned " << sol << endl;
169                         }
170                 }
171         }
172         
173         return result;
174 }
175
176 unsigned check_lsolve(void)
177 {
178         unsigned result = 0;
179         
180         cout << "checking linear solve" << flush;
181         clog << "---------linear solve:" << endl;
182         
183         // solve some numeric linear systems
184         for (unsigned n=1; n<12; ++n)
185                 result += check_matrix_solve(n, n, 1, 0);
186         cout << '.' << flush;
187         // solve some underdetermined numeric systems
188         for (unsigned n=1; n<12; ++n)
189                 result += check_matrix_solve(n+1, n, 1, 0);
190         cout << '.' << flush;
191         // solve some overdetermined numeric systems
192         for (unsigned n=1; n<12; ++n)
193                 result += check_matrix_solve(n, n+1, 1, 0);
194         cout << '.' << flush;
195         // solve some multiple numeric systems
196         for (unsigned n=1; n<12; ++n)
197                 result += check_matrix_solve(n, n, n/3+1, 0);
198         cout << '.' << flush;
199         // solve some symbolic linear systems
200         for (unsigned n=1; n<7; ++n)
201                 result += check_matrix_solve(n, n, 1, 2);
202         cout << '.' << flush;
203         
204         // check lsolve, the wrapper function around matrix::solve()
205         result += check_inifcns_lsolve(2);  cout << '.' << flush;
206         result += check_inifcns_lsolve(3);  cout << '.' << flush;
207         result += check_inifcns_lsolve(4);  cout << '.' << flush;
208         result += check_inifcns_lsolve(5);  cout << '.' << flush;
209         result += check_inifcns_lsolve(6);  cout << '.' << flush;
210                 
211         if (!result) {
212                 cout << " passed " << endl;
213                 clog << "(no output)" << endl;
214         } else {
215                 cout << " failed " << endl;
216         }
217         
218         return result;
219 }