Update copyright statements.
[ginac.git] / check / check_lsolve.cpp
1 /** @file check_lsolve.cpp
2  *
3  *  These test routines do some simple checks on solving linear systems of
4  *  symbolic equations.  They are a well-tried resource for cross-checking
5  *  the underlying symbolic manipulations. */
6
7 /*
8  *  GiNaC Copyright (C) 1999-2014 Johannes Gutenberg University Mainz, Germany
9  *
10  *  This program is free software; you can redistribute it and/or modify
11  *  it under the terms of the GNU General Public License as published by
12  *  the Free Software Foundation; either version 2 of the License, or
13  *  (at your option) any later version.
14  *
15  *  This program is distributed in the hope that it will be useful,
16  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
17  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
18  *  GNU General Public License for more details.
19  *
20  *  You should have received a copy of the GNU General Public License
21  *  along with this program; if not, write to the Free Software
22  *  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
23  */
24
25 #include "ginac.h"
26 using namespace GiNaC;
27
28 #include <cstdlib> // for rand()
29 #include <iostream>
30 #include <sstream>
31 using namespace std;
32
33 extern const ex 
34 dense_univariate_poly(const symbol & x, unsigned degree);
35
36 static unsigned check_matrix_solve(unsigned m, unsigned n, unsigned p,
37                                                                    unsigned degree)
38 {
39         const symbol a("a");
40         matrix A(m,n);
41         matrix B(m,p);
42         // set the first min(m,n) rows of A and B
43         for (unsigned ro=0; (ro<m)&&(ro<n); ++ro) {
44                 for (unsigned co=0; co<n; ++co)
45                         A.set(ro,co,dense_univariate_poly(a,degree));
46                 for (unsigned co=0; co<p; ++co)
47                         B.set(ro,co,dense_univariate_poly(a,degree));
48         }
49         // repeat excessive rows of A and B to avoid excessive construction of
50         // overdetermined linear systems
51         for (unsigned ro=n; ro<m; ++ro) {
52                 for (unsigned co=0; co<n; ++co)
53                         A.set(ro,co,A(ro-1,co));
54                 for (unsigned co=0; co<p; ++co)
55                         B.set(ro,co,B(ro-1,co));
56         }
57         // create a vector of n*p symbols all named "xrc" where r and c are ints
58         vector<symbol> x;
59         matrix X(n,p);
60         for (unsigned i=0; i<n; ++i) {
61                 for (unsigned j=0; j<p; ++j) {
62                         ostringstream buf;
63                         buf << "x" << i << j << ends;
64                         x.push_back(symbol(buf.str()));
65                         X.set(i,j,x[p*i+j]);
66                 }
67         }
68         matrix sol(n,p);
69         // Solve the system A*X==B:
70         try {
71                 sol = A.solve(X, B);
72         } catch (const exception & err) {  // catch runtime_error
73                 // Presumably, the coefficient matrix A was degenerate
74                 string errwhat = err.what();
75                 if (errwhat == "matrix::solve(): inconsistent linear system")
76                         return 0;
77                 else
78                         clog << "caught exception: " << errwhat << endl;
79                 throw;
80         }
81         
82         // check the result with our original matrix:
83         bool errorflag = false;
84         for (unsigned ro=0; ro<m; ++ro) {
85                 for (unsigned pco=0; pco<p; ++pco) {
86                         ex e = 0;
87                         for (unsigned co=0; co<n; ++co)
88                         e += A(ro,co)*sol(co,pco);
89                         if (!(e-B(ro,pco)).normal().is_zero())
90                                 errorflag = true;
91                 }
92         }
93         if (errorflag) {
94                 clog << "Our solve method claims that A*X==B, with matrices" << endl
95                      << "A == " << A << endl
96                      << "X == " << sol << endl
97                      << "B == " << B << endl;
98                 return 1;
99         }
100         
101         return 0;
102 }
103
104 static unsigned check_inifcns_lsolve(unsigned n)
105 {
106         unsigned result = 0;
107         
108         for (int repetition=0; repetition<200; ++repetition) {
109                 // create two size n vectors of symbols, one for the coefficients
110                 // a[0],..,a[n], one for indeterminates x[0]..x[n]:
111                 vector<symbol> a;
112                 vector<symbol> x;
113                 for (unsigned i=0; i<n; ++i) {
114                         ostringstream buf;
115                         buf << i << ends;
116                         a.push_back(symbol(string("a")+buf.str()));
117                         x.push_back(symbol(string("x")+buf.str()));
118                 }
119                 lst eqns;  // equation list
120                 lst vars;  // variable list
121                 ex sol; // solution
122                 // Create a random linear system...
123                 for (unsigned i=0; i<n; ++i) {
124                         ex lhs = rand()%201-100;
125                         ex rhs = rand()%201-100;
126                         for (unsigned j=0; j<n; ++j) {
127                                 // ...with small coefficients to give degeneracy a chance...
128                                 lhs += a[j]*(rand()%21-10);
129                                 rhs += x[j]*(rand()%21-10);
130                         }
131                         eqns.append(lhs==rhs);
132                         vars.append(x[i]);
133                 }
134                 // ...solve it...
135                 sol = lsolve(eqns, vars);
136                 
137                 // ...and check the solution:
138                 if (sol.nops() == 0) {
139                         // no solution was found
140                         // is the coefficient matrix really, really, really degenerate?
141                         matrix coeffmat(n,n);
142                         for (unsigned ro=0; ro<n; ++ro)
143                                 for (unsigned co=0; co<n; ++co)
144                                         coeffmat.set(ro,co,eqns.op(co).rhs().coeff(a[co],1));
145                         if (!coeffmat.determinant().is_zero()) {
146                                 ++result;
147                                 clog << "solution of the system " << eqns << " for " << vars
148                                          << " was not found" << endl;
149                         }
150                 } else {
151                         // insert the solution into rhs of out equations
152                         bool errorflag = false;
153                         for (unsigned i=0; i<n; ++i)
154                                 if (eqns.op(i).rhs().subs(sol) != eqns.op(i).lhs())
155                                         errorflag = true;
156                         if (errorflag) {
157                                 ++result;
158                                 clog << "solution of the system " << eqns << " for " << vars
159                                      << " erroneously returned " << sol << endl;
160                         }
161                 }
162         }
163         
164         return result;
165 }
166
167 unsigned check_lsolve()
168 {
169         unsigned result = 0;
170         
171         cout << "checking linear solve" << flush;
172         
173         // solve some numeric linear systems
174         for (unsigned n=1; n<14; ++n)
175                 result += check_matrix_solve(n, n, 1, 0);
176         cout << '.' << flush;
177         // solve some underdetermined numeric systems
178         for (unsigned n=1; n<14; ++n)
179                 result += check_matrix_solve(n+1, n, 1, 0);
180         cout << '.' << flush;
181         // solve some overdetermined numeric systems
182         for (unsigned n=1; n<14; ++n)
183                 result += check_matrix_solve(n, n+1, 1, 0);
184         cout << '.' << flush;
185         // solve some multiple numeric systems
186         for (unsigned n=1; n<14; ++n)
187                 result += check_matrix_solve(n, n, n/3+1, 0);
188         cout << '.' << flush;
189         // solve some symbolic linear systems
190         for (unsigned n=1; n<8; ++n)
191                 result += check_matrix_solve(n, n, 1, 2);
192         cout << '.' << flush;
193         
194         // check lsolve, the wrapper function around matrix::solve()
195         result += check_inifcns_lsolve(2);  cout << '.' << flush;
196         result += check_inifcns_lsolve(3);  cout << '.' << flush;
197         result += check_inifcns_lsolve(4);  cout << '.' << flush;
198         result += check_inifcns_lsolve(5);  cout << '.' << flush;
199         result += check_inifcns_lsolve(6);  cout << '.' << flush;
200                 
201         return result;
202 }
203
204 int main(int argc, char** argv)
205 {
206         return check_lsolve();
207 }