- See if __GNUC__ < 2.97 before using std::vector<..,malloc_alloc>. Sorry,
[ginac.git] / check / check_lsolve.cpp
1 /** @file check_lsolve.cpp
2  *
3  *  These test routines do some simple checks on solving linear systems of
4  *  symbolic equations. */
5
6 /*
7  *  GiNaC Copyright (C) 1999-2001 Johannes Gutenberg University Mainz, Germany
8  *
9  *  This program is free software; you can redistribute it and/or modify
10  *  it under the terms of the GNU General Public License as published by
11  *  the Free Software Foundation; either version 2 of the License, or
12  *  (at your option) any later version.
13  *
14  *  This program is distributed in the hope that it will be useful,
15  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
16  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
17  *  GNU General Public License for more details.
18  *
19  *  You should have received a copy of the GNU General Public License
20  *  along with this program; if not, write to the Free Software
21  *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
22  */
23
24 #include "checks.h"
25
26 static unsigned check_matrix_solve(unsigned m, unsigned n, unsigned p,
27                                                                    unsigned degree)
28 {
29         const symbol a("a");
30         matrix A(m,n);
31         matrix B(m,p);
32         // set the first min(m,n) rows of A and B
33         for (unsigned ro=0; (ro<m)&&(ro<n); ++ro) {
34                 for (unsigned co=0; co<n; ++co)
35                         A.set(ro,co,dense_univariate_poly(a,degree));
36                 for (unsigned co=0; co<p; ++co)
37                         B.set(ro,co,dense_univariate_poly(a,degree));
38         }
39         // repeat excessive rows of A and B to avoid excessive construction of
40         // overdetermined linear systems
41         for (unsigned ro=n; ro<m; ++ro) {
42                 for (unsigned co=0; co<n; ++co)
43                         A.set(ro,co,A(ro-1,co));
44                 for (unsigned co=0; co<p; ++co)
45                         B.set(ro,co,B(ro-1,co));
46         }
47         // create a vector of n*p symbols all named "xrc" where r and c are ints
48         vector<symbol> x;
49         matrix X(n,p);
50         for (unsigned i=0; i<n; ++i) {
51                 for (unsigned j=0; j<p; ++j) {
52                         char buf[4];
53                         ostrstream(buf,sizeof(buf)) << i << j << ends;
54                         x.push_back(symbol(string("x")+buf));
55                         X.set(i,j,x[p*i+j]);
56                 }
57         }
58         matrix sol(n,p);
59         // Solve the system A*X==B:
60         try {
61                 sol = A.solve(X, B);
62         } catch (const exception & err) {  // catch runtime_error
63                 // Presumably, the coefficient matrix A was degenerate
64                 string errwhat = err.what();
65                 if (errwhat == "matrix::solve(): inconsistent linear system")
66                         return 0;
67                 else
68                         clog << "caught exception: " << errwhat << endl;
69                 throw;
70         }
71         
72         // check the result with our original matrix:
73         bool errorflag = false;
74         for (unsigned ro=0; ro<m; ++ro) {
75                 for (unsigned pco=0; pco<p; ++pco) {
76                         ex e = 0;
77                         for (unsigned co=0; co<n; ++co)
78                         e += A(ro,co)*sol(co,pco);
79                         if (!(e-B(ro,pco)).normal().is_zero())
80                                 errorflag = true;
81                 }
82         }
83         if (errorflag) {
84                 clog << "Our solve method claims that A*X==B, with matrices" << endl
85                      << "A == " << A << endl
86                      << "X == " << sol << endl
87                      << "B == " << B << endl;
88                 return 1;
89         }
90         
91         return 0;
92 }
93
94 static unsigned check_inifcns_lsolve(unsigned n)
95 {
96         unsigned result = 0;
97         
98         for (int repetition=0; repetition<100; ++repetition) {
99                 // create two size n vectors of symbols, one for the coefficients
100                 // a[0],..,a[n], one for indeterminates x[0]..x[n]:
101                 vector<symbol> a;
102                 vector<symbol> x;
103                 for (unsigned i=0; i<n; ++i) {
104                         char buf[3];
105                         ostrstream(buf,sizeof(buf)) << i << ends;
106                         a.push_back(symbol(string("a")+buf));
107                         x.push_back(symbol(string("x")+buf));
108                 }
109                 lst eqns;  // equation list
110                 lst vars;  // variable list
111                 ex sol; // solution
112                 // Create a random linear system...
113                 for (unsigned i=0; i<n; ++i) {
114                         ex lhs = rand()%201-100;
115                         ex rhs = rand()%201-100;
116                         for (unsigned j=0; j<n; ++j) {
117                                 // ...with small coefficients to give degeneracy a chance...
118                                 lhs += a[j]*(rand()%21-10);
119                                 rhs += x[j]*(rand()%21-10);
120                         }
121                         eqns.append(lhs==rhs);
122                         vars.append(x[i]);
123                 }
124                 // ...solve it...
125                 sol = lsolve(eqns, vars);
126                 
127                 // ...and check the solution:
128                 if (sol.nops() == 0) {
129                         // no solution was found
130                         // is the coefficient matrix really, really, really degenerate?
131                         matrix coeffmat(n,n);
132                         for (unsigned ro=0; ro<n; ++ro)
133                                 for (unsigned co=0; co<n; ++co)
134                                         coeffmat.set(ro,co,eqns.op(co).rhs().coeff(a[co],1));
135                         if (!coeffmat.determinant().is_zero()) {
136                                 ++result;
137                                 clog << "solution of the system " << eqns << " for " << vars
138                                          << " was not found" << endl;
139                         }
140                 } else {
141                         // insert the solution into rhs of out equations
142                         bool errorflag = false;
143                         for (unsigned i=0; i<n; ++i)
144                                 if (eqns.op(i).rhs().subs(sol) != eqns.op(i).lhs())
145                                         errorflag = true;
146                         if (errorflag) {
147                                 ++result;
148                                 clog << "solution of the system " << eqns << " for " << vars
149                                      << " erroneously returned " << sol << endl;
150                         }
151                 }
152         }
153         
154         return result;
155 }
156
157 unsigned check_lsolve(void)
158 {
159         unsigned result = 0;
160         
161         cout << "checking linear solve" << flush;
162         clog << "---------linear solve:" << endl;
163         
164         // solve some numeric linear systems
165         for (unsigned n=1; n<12; ++n)
166                 result += check_matrix_solve(n, n, 1, 0);
167         cout << '.' << flush;
168         // solve some underdetermined numeric systems
169         for (unsigned n=1; n<12; ++n)
170                 result += check_matrix_solve(n+1, n, 1, 0);
171         cout << '.' << flush;
172         // solve some overdetermined numeric systems
173         for (unsigned n=1; n<12; ++n)
174                 result += check_matrix_solve(n, n+1, 1, 0);
175         cout << '.' << flush;
176         // solve some multiple numeric systems
177         for (unsigned n=1; n<12; ++n)
178                 result += check_matrix_solve(n, n, n/3+1, 0);
179         cout << '.' << flush;
180         // solve some symbolic linear systems
181         for (unsigned n=1; n<7; ++n)
182                 result += check_matrix_solve(n, n, 1, 2);
183         cout << '.' << flush;
184         
185         // check lsolve, the wrapper function around matrix::solve()
186         result += check_inifcns_lsolve(2);  cout << '.' << flush;
187         result += check_inifcns_lsolve(3);  cout << '.' << flush;
188         result += check_inifcns_lsolve(4);  cout << '.' << flush;
189         result += check_inifcns_lsolve(5);  cout << '.' << flush;
190         result += check_inifcns_lsolve(6);  cout << '.' << flush;
191                 
192         if (!result) {
193                 cout << " passed " << endl;
194                 clog << "(no output)" << endl;
195         } else {
196                 cout << " failed " << endl;
197         }
198         
199         return result;
200 }