log(-<realnumber>) now returns a real number
[ginac.git] / check / check_lsolve.cpp
1 /** @file check_lsolve.cpp
2  *
3  *  These test routines do some simple checks on solving linear systems of
4  *  symbolic equations.  They are a well-tried resource for cross-checking
5  *  the underlying symbolic manipulations. */
6
7 /*
8  *  GiNaC Copyright (C) 1999-2004 Johannes Gutenberg University Mainz, Germany
9  *
10  *  This program is free software; you can redistribute it and/or modify
11  *  it under the terms of the GNU General Public License as published by
12  *  the Free Software Foundation; either version 2 of the License, or
13  *  (at your option) any later version.
14  *
15  *  This program is distributed in the hope that it will be useful,
16  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
17  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
18  *  GNU General Public License for more details.
19  *
20  *  You should have received a copy of the GNU General Public License
21  *  along with this program; if not, write to the Free Software
22  *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
23  */
24
25 #include "checks.h"
26 #include <sstream>
27
28 static unsigned check_matrix_solve(unsigned m, unsigned n, unsigned p,
29                                                                    unsigned degree)
30 {
31         const symbol a("a");
32         matrix A(m,n);
33         matrix B(m,p);
34         // set the first min(m,n) rows of A and B
35         for (unsigned ro=0; (ro<m)&&(ro<n); ++ro) {
36                 for (unsigned co=0; co<n; ++co)
37                         A.set(ro,co,dense_univariate_poly(a,degree));
38                 for (unsigned co=0; co<p; ++co)
39                         B.set(ro,co,dense_univariate_poly(a,degree));
40         }
41         // repeat excessive rows of A and B to avoid excessive construction of
42         // overdetermined linear systems
43         for (unsigned ro=n; ro<m; ++ro) {
44                 for (unsigned co=0; co<n; ++co)
45                         A.set(ro,co,A(ro-1,co));
46                 for (unsigned co=0; co<p; ++co)
47                         B.set(ro,co,B(ro-1,co));
48         }
49         // create a vector of n*p symbols all named "xrc" where r and c are ints
50         vector<symbol> x;
51         matrix X(n,p);
52         for (unsigned i=0; i<n; ++i) {
53                 for (unsigned j=0; j<p; ++j) {
54                         ostringstream buf;
55                         buf << "x" << i << j << ends;
56                         x.push_back(symbol(buf.str()));
57                         X.set(i,j,x[p*i+j]);
58                 }
59         }
60         matrix sol(n,p);
61         // Solve the system A*X==B:
62         try {
63                 sol = A.solve(X, B);
64         } catch (const exception & err) {  // catch runtime_error
65                 // Presumably, the coefficient matrix A was degenerate
66                 string errwhat = err.what();
67                 if (errwhat == "matrix::solve(): inconsistent linear system")
68                         return 0;
69                 else
70                         clog << "caught exception: " << errwhat << endl;
71                 throw;
72         }
73         
74         // check the result with our original matrix:
75         bool errorflag = false;
76         for (unsigned ro=0; ro<m; ++ro) {
77                 for (unsigned pco=0; pco<p; ++pco) {
78                         ex e = 0;
79                         for (unsigned co=0; co<n; ++co)
80                         e += A(ro,co)*sol(co,pco);
81                         if (!(e-B(ro,pco)).normal().is_zero())
82                                 errorflag = true;
83                 }
84         }
85         if (errorflag) {
86                 clog << "Our solve method claims that A*X==B, with matrices" << endl
87                      << "A == " << A << endl
88                      << "X == " << sol << endl
89                      << "B == " << B << endl;
90                 return 1;
91         }
92         
93         return 0;
94 }
95
96 static unsigned check_inifcns_lsolve(unsigned n)
97 {
98         unsigned result = 0;
99         
100         for (int repetition=0; repetition<200; ++repetition) {
101                 // create two size n vectors of symbols, one for the coefficients
102                 // a[0],..,a[n], one for indeterminates x[0]..x[n]:
103                 vector<symbol> a;
104                 vector<symbol> x;
105                 for (unsigned i=0; i<n; ++i) {
106                         ostringstream buf;
107                         buf << i << ends;
108                         a.push_back(symbol(string("a")+buf.str()));
109                         x.push_back(symbol(string("x")+buf.str()));
110                 }
111                 lst eqns;  // equation list
112                 lst vars;  // variable list
113                 ex sol; // solution
114                 // Create a random linear system...
115                 for (unsigned i=0; i<n; ++i) {
116                         ex lhs = rand()%201-100;
117                         ex rhs = rand()%201-100;
118                         for (unsigned j=0; j<n; ++j) {
119                                 // ...with small coefficients to give degeneracy a chance...
120                                 lhs += a[j]*(rand()%21-10);
121                                 rhs += x[j]*(rand()%21-10);
122                         }
123                         eqns.append(lhs==rhs);
124                         vars.append(x[i]);
125                 }
126                 // ...solve it...
127                 sol = lsolve(eqns, vars);
128                 
129                 // ...and check the solution:
130                 if (sol.nops() == 0) {
131                         // no solution was found
132                         // is the coefficient matrix really, really, really degenerate?
133                         matrix coeffmat(n,n);
134                         for (unsigned ro=0; ro<n; ++ro)
135                                 for (unsigned co=0; co<n; ++co)
136                                         coeffmat.set(ro,co,eqns.op(co).rhs().coeff(a[co],1));
137                         if (!coeffmat.determinant().is_zero()) {
138                                 ++result;
139                                 clog << "solution of the system " << eqns << " for " << vars
140                                          << " was not found" << endl;
141                         }
142                 } else {
143                         // insert the solution into rhs of out equations
144                         bool errorflag = false;
145                         for (unsigned i=0; i<n; ++i)
146                                 if (eqns.op(i).rhs().subs(sol) != eqns.op(i).lhs())
147                                         errorflag = true;
148                         if (errorflag) {
149                                 ++result;
150                                 clog << "solution of the system " << eqns << " for " << vars
151                                      << " erroneously returned " << sol << endl;
152                         }
153                 }
154         }
155         
156         return result;
157 }
158
159 unsigned check_lsolve()
160 {
161         unsigned result = 0;
162         
163         cout << "checking linear solve" << flush;
164         clog << "---------linear solve:" << endl;
165         
166         // solve some numeric linear systems
167         for (unsigned n=1; n<14; ++n)
168                 result += check_matrix_solve(n, n, 1, 0);
169         cout << '.' << flush;
170         // solve some underdetermined numeric systems
171         for (unsigned n=1; n<14; ++n)
172                 result += check_matrix_solve(n+1, n, 1, 0);
173         cout << '.' << flush;
174         // solve some overdetermined numeric systems
175         for (unsigned n=1; n<14; ++n)
176                 result += check_matrix_solve(n, n+1, 1, 0);
177         cout << '.' << flush;
178         // solve some multiple numeric systems
179         for (unsigned n=1; n<14; ++n)
180                 result += check_matrix_solve(n, n, n/3+1, 0);
181         cout << '.' << flush;
182         // solve some symbolic linear systems
183         for (unsigned n=1; n<8; ++n)
184                 result += check_matrix_solve(n, n, 1, 2);
185         cout << '.' << flush;
186         
187         // check lsolve, the wrapper function around matrix::solve()
188         result += check_inifcns_lsolve(2);  cout << '.' << flush;
189         result += check_inifcns_lsolve(3);  cout << '.' << flush;
190         result += check_inifcns_lsolve(4);  cout << '.' << flush;
191         result += check_inifcns_lsolve(5);  cout << '.' << flush;
192         result += check_inifcns_lsolve(6);  cout << '.' << flush;
193                 
194         if (!result) {
195                 cout << " passed " << endl;
196                 clog << "(no output)" << endl;
197         } else {
198                 cout << " failed " << endl;
199         }
200         
201         return result;
202 }