34d5f2f93c008a0deeac8aeaf0e4aff64992ed18
[ginac.git] / ginac / matrix.cpp
1 /** @file matrix.cpp
2  *
3  *  Implementation of symbolic matrices */
4
5 /*
6  *  GiNaC Copyright (C) 1999 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
21  */
22
23 #include <algorithm>
24 #include <stdexcept>
25
26 #include "matrix.h"
27
28 //////////
29 // default constructor, destructor, copy constructor, assignment operator
30 // and helpers:
31 //////////
32
33 // public
34
35 /** Default ctor.  Initializes to 1 x 1-dimensional zero-matrix. */
36 matrix::matrix()
37     : basic(TINFO_matrix), row(1), col(1)
38 {
39     debugmsg("matrix default constructor",LOGLEVEL_CONSTRUCT);
40     m.push_back(exZERO());
41 }
42
43 matrix::~matrix()
44 {
45     debugmsg("matrix destructor",LOGLEVEL_DESTRUCT);
46 }
47
48 matrix::matrix(matrix const & other)
49 {
50     debugmsg("matrix copy constructor",LOGLEVEL_CONSTRUCT);
51     copy(other);
52 }
53
54 matrix const & matrix::operator=(matrix const & other)
55 {
56     debugmsg("matrix operator=",LOGLEVEL_ASSIGNMENT);
57     if (this != &other) {
58         destroy(1);
59         copy(other);
60     }
61     return *this;
62 }
63
64 // protected
65
66 void matrix::copy(matrix const & other)
67 {
68     basic::copy(other);
69     row=other.row;
70     col=other.col;
71     m=other.m;  // use STL's vector copying
72 }
73
74 void matrix::destroy(bool call_parent)
75 {
76     if (call_parent) basic::destroy(call_parent);
77 }
78
79 //////////
80 // other constructors
81 //////////
82
83 // public
84
85 /** Very common ctor.  Initializes to r x c-dimensional zero-matrix.
86  *
87  *  @param r number of rows
88  *  @param c number of cols */
89 matrix::matrix(int r, int c)
90     : basic(TINFO_matrix), row(r), col(c)
91 {
92     debugmsg("matrix constructor from int,int",LOGLEVEL_CONSTRUCT);
93     m.resize(r*c, exZERO());
94 }
95
96 // protected
97
98 /** Ctor from representation, for internal use only. */
99 matrix::matrix(int r, int c, vector<ex> const & m2)
100     : basic(TINFO_matrix), row(r), col(c), m(m2)
101 {
102     debugmsg("matrix constructor from int,int,vector<ex>",LOGLEVEL_CONSTRUCT);
103 }
104
105 //////////
106 // functions overriding virtual functions from bases classes
107 //////////
108
109 // public
110
111 basic * matrix::duplicate() const
112 {
113     debugmsg("matrix duplicate",LOGLEVEL_DUPLICATE);
114     return new matrix(*this);
115 }
116
117 /** nops is defined to be rows x columns. */
118 int matrix::nops() const
119 {
120     return row*col;
121 }
122
123 /** returns matrix entry at position (i/col, i%col). */
124 ex & matrix::let_op(int const i)
125 {
126     return m[i];
127 }
128
129 /** expands the elements of a matrix entry by entry. */
130 ex matrix::expand(unsigned options) const
131 {
132     vector<ex> tmp(row*col);
133     for (int i=0; i<row*col; ++i) {
134         tmp[i]=m[i].expand(options);
135     }
136     return matrix(row, col, tmp);
137 }
138
139 /** Search ocurrences.  A matrix 'has' an expression if it is the expression
140  *  itself or one of the elements 'has' it. */
141 bool matrix::has(ex const & other) const
142 {
143     ASSERT(other.bp!=0);
144     
145     // tautology: it is the expression itself
146     if (is_equal(*other.bp)) return true;
147     
148     // search all the elements
149     for (vector<ex>::const_iterator r=m.begin(); r!=m.end(); ++r) {
150         if ((*r).has(other)) return true;
151     }
152     return false;
153 }
154
155 /** evaluate matrix entry by entry. */
156 ex matrix::eval(int level) const
157 {
158     debugmsg("matrix eval",LOGLEVEL_MEMBER_FUNCTION);
159     
160     // check if we have to do anything at all
161     if ((level==1)&&(flags & status_flags::evaluated)) {
162         return *this;
163     }
164     
165     // emergency break
166     if (level == -max_recursion_level) {
167         throw (std::runtime_error("matrix::eval(): recursion limit exceeded"));
168     }
169     
170     // eval() entry by entry
171     vector<ex> m2(row*col);
172     --level;    
173     for (int r=0; r<row; ++r) {
174         for (int c=0; c<col; ++c) {
175             m2[r*col+c] = m[r*col+c].eval(level);
176         }
177     }
178     
179     return (new matrix(row, col, m2))->setflag(status_flags::dynallocated |
180                                                status_flags::evaluated );
181 }
182
183 /** evaluate matrix numerically entry by entry. */
184 ex matrix::evalf(int level) const
185 {
186     debugmsg("matrix evalf",LOGLEVEL_MEMBER_FUNCTION);
187         
188     // check if we have to do anything at all
189     if (level==1) {
190         return *this;
191     }
192     
193     // emergency break
194     if (level == -max_recursion_level) {
195         throw (std::runtime_error("matrix::evalf(): recursion limit exceeded"));
196     }
197     
198     // evalf() entry by entry
199     vector<ex> m2(row*col);
200     --level;
201     for (int r=0; r<row; ++r) {
202         for (int c=0; c<col; ++c) {
203             m2[r*col+c] = m[r*col+c].evalf(level);
204         }
205     }
206     return matrix(row, col, m2);
207 }
208
209 // protected
210
211 int matrix::compare_same_type(basic const & other) const
212 {
213     ASSERT(is_exactly_of_type(other, matrix));
214     matrix const & o=static_cast<matrix &>(const_cast<basic &>(other));
215     
216     // compare number of rows
217     if (row != o.rows()) {
218         return row < o.rows() ? -1 : 1;
219     }
220     
221     // compare number of columns
222     if (col != o.cols()) {
223         return col < o.cols() ? -1 : 1;
224     }
225     
226     // equal number of rows and columns, compare individual elements
227     int cmpval;
228     for (int r=0; r<row; ++r) {
229         for (int c=0; c<col; ++c) {
230             cmpval=((*this)(r,c)).compare(o(r,c));
231             if (cmpval!=0) return cmpval;
232         }
233     }
234     // all elements are equal => matrices are equal;
235     return 0;
236 }
237
238 //////////
239 // non-virtual functions in this class
240 //////////
241
242 // public
243
244 /** Sum of matrices.
245  *
246  *  @exception logic_error (incompatible matrices) */
247 matrix matrix::add(matrix const & other) const
248 {
249     if (col != other.col || row != other.row) {
250         throw (std::logic_error("matrix::add(): incompatible matrices"));
251     }
252     
253     vector<ex> sum(this->m);
254     vector<ex>::iterator i;
255     vector<ex>::const_iterator ci;
256     for (i=sum.begin(), ci=other.m.begin();
257          i!=sum.end();
258          ++i, ++ci) {
259         (*i) += (*ci);
260     }
261     return matrix(row,col,sum);
262 }
263
264 /** Difference of matrices.
265  *
266  *  @exception logic_error (incompatible matrices) */
267 matrix matrix::sub(matrix const & other) const
268 {
269     if (col != other.col || row != other.row) {
270         throw (std::logic_error("matrix::sub(): incompatible matrices"));
271     }
272     
273     vector<ex> dif(this->m);
274     vector<ex>::iterator i;
275     vector<ex>::const_iterator ci;
276     for (i=dif.begin(), ci=other.m.begin();
277          i!=dif.end();
278          ++i, ++ci) {
279         (*i) -= (*ci);
280     }
281     return matrix(row,col,dif);
282 }
283
284 /** Product of matrices.
285  *
286  *  @exception logic_error (incompatible matrices) */
287 matrix matrix::mul(matrix const & other) const
288 {
289     if (col != other.row) {
290         throw (std::logic_error("matrix::mul(): incompatible matrices"));
291     }
292     
293     vector<ex> prod(row*other.col);
294     for (int i=0; i<row; ++i) {
295         for (int j=0; j<other.col; ++j) {
296             for (int l=0; l<col; ++l) {
297                 prod[i*other.col+j] += m[i*col+l] * other.m[l*other.col+j];
298             }
299         }
300     }
301     return matrix(row, other.col, prod);
302 }
303
304 /** operator() to access elements.
305  *
306  *  @param ro row of element
307  *  @param co column of element 
308  *  @exception range_error (index out of range) */
309 ex const & matrix::operator() (int ro, int co) const
310 {
311     if (ro<0 || ro>=row || co<0 || co>=col) {
312         throw (std::range_error("matrix::operator(): index out of range"));
313     }
314     
315     return m[ro*col+co];
316 }
317
318 /** Set individual elements manually.
319  *
320  *  @exception range_error (index out of range) */
321 matrix & matrix::set(int ro, int co, ex value)
322 {
323     if (ro<0 || ro>=row || co<0 || co>=col) {
324         throw (std::range_error("matrix::set(): index out of range"));
325     }
326     
327     ensure_if_modifiable();
328     m[ro*col+co]=value;
329     return *this;
330 }
331
332 /** Transposed of an m x n matrix, producing a new n x m matrix object that
333  *  represents the transposed. */
334 matrix matrix::transpose(void) const
335 {
336     vector<ex> trans(col*row);
337     
338     for (int r=0; r<col; ++r) {
339         for (int c=0; c<row; ++c) {
340             trans[r*row+c] = m[c*col+r];
341         }
342     }
343     return matrix(col,row,trans);
344 }
345
346 /* Determiant of purely numeric matrix, using pivoting. This routine is only
347  * called internally by matrix::determinant(). */
348 ex determinant_numeric(const matrix & M)
349 {
350     ASSERT(M.rows()==M.cols());  // cannot happen, just in case...
351     matrix tmp(M);
352     ex det=exONE();
353     ex piv;
354     
355     for (int r1=0; r1<M.rows(); ++r1) {
356         int indx = tmp.pivot(r1);
357         if (indx == -1) {
358             return exZERO();
359         }
360         if (indx != 0) {
361             det *= exMINUSONE();
362         }
363         det = det * tmp.m[r1*M.cols()+r1];
364         for (int r2=r1+1; r2<M.rows(); ++r2) {
365             piv = tmp.m[r2*M.cols()+r1] / tmp.m[r1*M.cols()+r1];
366             for (int c=r1+1; c<M.cols(); c++) {
367                 tmp.m[r2*M.cols()+c] -= piv * tmp.m[r1*M.cols()+c];
368             }
369         }
370     }
371     return det;
372 }
373
374 // Compute the sign of a permutation of a vector of things, used internally
375 // by determinant_symbolic_perm() where it is instantiated for int.
376 template <class T>
377 int permutation_sign(vector<T> s)
378 {
379     if (s.size() < 2)
380         return 0;
381     int sigma=1;
382     for (typename vector<T>::iterator i=s.begin(); i!=s.end()-1; ++i) {
383         for (typename vector<T>::iterator j=i+1; j!=s.end(); ++j) {
384             if (*i == *j)
385                 return 0;
386             if (*i > *j) {
387                 iter_swap(i,j);
388                 sigma = -sigma;
389             }
390         }
391     }
392     return sigma;
393 }
394
395 /** Determinant built by application of the full permutation group. This
396  *  routine is only called internally by matrix::determinant(). */
397 ex determinant_symbolic_perm(const matrix & M)
398 {
399     ASSERT(M.rows()==M.cols());  // cannot happen, just in case...
400     
401     if (M.rows()==1) {  // speed things up
402         return M(0,0);
403     }
404     
405     ex det;
406     ex term;
407     vector<int> sigma(M.cols());
408     for (int i=0; i<M.cols(); ++i) sigma[i]=i;
409     
410     do {
411         term = M(sigma[0],0);
412         for (int i=1; i<M.cols(); ++i) term *= M(sigma[i],i);
413         det += permutation_sign(sigma)*term;
414     } while (next_permutation(sigma.begin(), sigma.end()));
415     
416     return det;
417 }
418
419 /** Recursive determiant for small matrices having at least one symbolic entry.
420  *  This algorithm is also known as Laplace-expansion. This routine is only
421  *  called internally by matrix::determinant(). */
422 ex determinant_symbolic_minor(const matrix & M)
423 {
424     ASSERT(M.rows()==M.cols());  // cannot happen, just in case...
425     
426     if (M.rows()==1) {  // end of recursion
427         return M(0,0);
428     }
429     if (M.rows()==2) {  // speed things up
430         return (M(0,0)*M(1,1)-
431                 M(1,0)*M(0,1));
432     }
433     if (M.rows()==3) {  // speed things up even a little more
434         return ((M(2,1)*M(0,2)-M(2,2)*M(0,1))*M(1,0)+
435                 (M(1,2)*M(0,1)-M(1,1)*M(0,2))*M(2,0)+
436                 (M(2,2)*M(1,1)-M(2,1)*M(1,2))*M(0,0));
437     }
438     
439     ex det;
440     matrix minorM(M.rows()-1,M.cols()-1);
441     for (int r1=0; r1<M.rows(); ++r1) {
442         // assemble the minor matrix
443         for (int r=0; r<minorM.rows(); ++r) {
444             for (int c=0; c<minorM.cols(); ++c) {
445                 if (r<r1) {
446                     minorM.set(r,c,M(r,c+1));
447                 } else {
448                     minorM.set(r,c,M(r+1,c+1));
449                 }
450             }
451         }
452         // recurse down
453         if (r1%2) {
454             det -= M(r1,0) * determinant_symbolic_minor(minorM);
455         } else {
456             det += M(r1,0) * determinant_symbolic_minor(minorM);
457         }
458     }
459     return det;
460 }
461
462 /*  Leverrier algorithm for large matrices having at least one symbolic entry.
463  *  This routine is only called internally by matrix::determinant(). The
464  *  algorithm is deemed bad for symbolic matrices since it returns expressions
465  *  that are very hard to canonicalize. */
466 /*ex determinant_symbolic_leverrier(const matrix & M)
467  *{
468  *    ASSERT(M.rows()==M.cols());  // cannot happen, just in case...
469  *    
470  *    matrix B(M);
471  *    matrix I(M.row, M.col);
472  *    ex c=B.trace();
473  *    for (int i=1; i<M.row; ++i) {
474  *        for (int j=0; j<M.row; ++j)
475  *            I.m[j*M.col+j] = c;
476  *        B = M.mul(B.sub(I));
477  *        c = B.trace()/ex(i+1);
478  *    }
479  *    if (M.row%2) {
480  *        return c;
481  *    } else {
482  *        return -c;
483  *    }
484  *}*/
485
486 /** Determinant of square matrix.  This routine doesn't actually calculate the
487  *  determinant, it only implements some heuristics about which algorithm to
488  *  call.  When the parameter for normalization is explicitly turned off this
489  *  method does not normalize its result at the end, which might imply that
490  *  the symbolic 2x2 matrix [[a/(a-b),1],[b/(a-b),1]] is not immediatly
491  *  recognized to be unity.  (This is Mathematica's default behaviour, it
492  *  should be used with care.)
493  *
494  *  @param     normalized may be set to false if no normalization of the
495  *             result is desired (i.e. to force Mathematica behavior, Maple
496  *             does normalize the result).
497  *  @return    the determinant as a new expression
498  *  @exception logic_error (matrix not square) */
499 ex matrix::determinant(bool normalized) const
500 {
501     if (row != col) {
502         throw (std::logic_error("matrix::determinant(): matrix not square"));
503     }
504
505     // check, if there are non-numeric entries in the matrix:
506     for (vector<ex>::const_iterator r=m.begin(); r!=m.end(); ++r) {
507         if (!(*r).info(info_flags::numeric)) {
508             if (normalized) {
509                 return determinant_symbolic_minor(*this).normal();
510             } else {
511                 return determinant_symbolic_perm(*this);
512             }
513         }
514     }
515     // if it turns out that all elements are numeric
516     return determinant_numeric(*this);
517 }
518
519 /** Trace of a matrix.
520  *
521  *  @return    the sum of diagonal elements
522  *  @exception logic_error (matrix not square) */
523 ex matrix::trace(void) const
524 {
525     if (row != col) {
526         throw (std::logic_error("matrix::trace(): matrix not square"));
527     }
528     
529     ex tr;
530     for (int r=0; r<col; ++r) {
531         tr += m[r*col+r];
532     }
533     return tr;
534 }
535
536 /** Characteristic Polynomial.  The characteristic polynomial of a matrix M is
537  *  defined as the determiant of (M - lambda * 1) where 1 stands for the unit
538  *  matrix of the same dimension as M.  This method returns the characteristic
539  *  polynomial as a new expression.
540  *
541  *  @return    characteristic polynomial as new expression
542  *  @exception logic_error (matrix not square)
543  *  @see       matrix::determinant() */
544 ex matrix::charpoly(ex const & lambda) const
545 {
546     if (row != col) {
547         throw (std::logic_error("matrix::charpoly(): matrix not square"));
548     }
549     
550     matrix M(*this);
551     for (int r=0; r<col; ++r) {
552         M.m[r*col+r] -= lambda;
553     }
554     return (M.determinant());
555 }
556
557 /** Inverse of this matrix.
558  *
559  *  @return    the inverted matrix
560  *  @exception logic_error (matrix not square)
561  *  @exception runtime_error (singular matrix) */
562 matrix matrix::inverse(void) const
563 {
564     if (row != col) {
565         throw (std::logic_error("matrix::inverse(): matrix not square"));
566     }
567     
568     matrix tmp(row,col);
569     // set tmp to the unit matrix
570     for (int i=0; i<col; ++i) {
571         tmp.m[i*col+i] = exONE();
572     }
573     // create a copy of this matrix
574     matrix cpy(*this);
575     for (int r1=0; r1<row; ++r1) {
576         int indx = cpy.pivot(r1);
577         if (indx == -1) {
578             throw (std::runtime_error("matrix::inverse(): singular matrix"));
579         }
580         if (indx != 0) {  // swap rows r and indx of matrix tmp
581             for (int i=0; i<col; ++i) {
582                 tmp.m[r1*col+i].swap(tmp.m[indx*col+i]);
583             }
584         }
585         ex a1 = cpy.m[r1*col+r1];
586         for (int c=0; c<col; ++c) {
587             cpy.m[r1*col+c] /= a1;
588             tmp.m[r1*col+c] /= a1;
589         }
590         for (int r2=0; r2<row; ++r2) {
591             if (r2 != r1) {
592                 ex a2 = cpy.m[r2*col+r1];
593                 for (int c=0; c<col; ++c) {
594                     cpy.m[r2*col+c] -= a2 * cpy.m[r1*col+c];
595                     tmp.m[r2*col+c] -= a2 * tmp.m[r1*col+c];
596                 }
597             }
598         }
599     }
600     return tmp;
601 }
602
603 void matrix::ffe_swap(int r1, int c1, int r2 ,int c2)
604 {
605     ensure_if_modifiable();
606     
607     ex tmp=ffe_get(r1,c1);
608     ffe_set(r1,c1,ffe_get(r2,c2));
609     ffe_set(r2,c2,tmp);
610 }
611
612 void matrix::ffe_set(int r, int c, ex e)
613 {
614     set(r-1,c-1,e);
615 }
616
617 ex matrix::ffe_get(int r, int c) const
618 {
619     return operator()(r-1,c-1);
620 }
621
622 /** Solve a set of equations for an m x n matrix by fraction-free Gaussian
623  *  elimination. Based on algorithm 9.1 from 'Algorithms for Computer Algebra'
624  *  by Keith O. Geddes et al.
625  *
626  *  @param vars n x p matrix
627  *  @param rhs m x p matrix
628  *  @exception logic_error (incompatible matrices)
629  *  @exception runtime_error (singular matrix) */
630 matrix matrix::fraction_free_elim(matrix const & vars,
631                                   matrix const & rhs) const
632 {
633     if ((row != rhs.row) || (col != vars.row) || (rhs.col != vars.col)) {
634         throw (std::logic_error("matrix::solve(): incompatible matrices"));
635     }
636     
637     matrix a(*this); // make a copy of the matrix
638     matrix b(rhs);     // make a copy of the rhs vector
639     
640     // given an m x n matrix a, reduce it to upper echelon form
641     int m=a.row;
642     int n=a.col;
643     int sign=1;
644     ex divisor=1;
645     int r=1;
646     
647     // eliminate below row r, with pivot in column k
648     for (int k=1; (k<=n)&&(r<=m); ++k) {
649         // find a nonzero pivot
650         int p;
651         for (p=r; (p<=m)&&(a.ffe_get(p,k).is_equal(exZERO())); ++p) {}
652         // pivot is in row p
653         if (p<=m) {
654             if (p!=r) {
655                 // switch rows p and r
656                 for (int j=k; j<=n; ++j) {
657                     a.ffe_swap(p,j,r,j);
658                 }
659                 b.ffe_swap(p,1,r,1);
660                 // keep track of sign changes due to row exchange
661                 sign=-sign;
662             }
663             for (int i=r+1; i<=m; ++i) {
664                 for (int j=k+1; j<=n; ++j) {
665                     a.ffe_set(i,j,(a.ffe_get(r,k)*a.ffe_get(i,j)
666                                   -a.ffe_get(r,j)*a.ffe_get(i,k))/divisor);
667                     a.ffe_set(i,j,a.ffe_get(i,j).normal() /*.normal() */ );
668                 }
669                 b.ffe_set(i,1,(a.ffe_get(r,k)*b.ffe_get(i,1)
670                               -b.ffe_get(r,1)*a.ffe_get(i,k))/divisor);
671                 b.ffe_set(i,1,b.ffe_get(i,1).normal() /*.normal() */ );
672                 a.ffe_set(i,k,0);
673             }
674             divisor=a.ffe_get(r,k);
675             r++;
676         }
677     }
678     // optionally compute the determinant for square or augmented matrices
679     // if (r==m+1) { det=sign*divisor; } else { det=0; }
680     
681     /*
682     for (int r=1; r<=m; ++r) {
683         for (int c=1; c<=n; ++c) {
684             cout << a.ffe_get(r,c) << "\t";
685         }
686         cout << " | " <<  b.ffe_get(r,1) << endl;
687     }
688     */
689     
690 #ifdef DOASSERT
691     // test if we really have an upper echelon matrix
692     int zero_in_last_row=-1;
693     for (int r=1; r<=m; ++r) {
694         int zero_in_this_row=0;
695         for (int c=1; c<=n; ++c) {
696             if (a.ffe_get(r,c).is_equal(exZERO())) {
697                zero_in_this_row++;
698             } else {
699                 break;
700             }
701         }
702         ASSERT((zero_in_this_row>zero_in_last_row)||(zero_in_this_row=n));
703         zero_in_last_row=zero_in_this_row;
704     }
705 #endif // def DOASSERT
706     
707     // assemble solution
708     matrix sol(n,1);
709     int last_assigned_sol=n+1;
710     for (int r=m; r>0; --r) {
711         int first_non_zero=1;
712         while ((first_non_zero<=n)&&(a.ffe_get(r,first_non_zero).is_zero())) {
713             first_non_zero++;
714         }
715         if (first_non_zero>n) {
716             // row consists only of zeroes, corresponding rhs must be 0 as well
717             if (!b.ffe_get(r,1).is_zero()) {
718                 throw (std::runtime_error("matrix::fraction_free_elim(): singular matrix"));
719             }
720         } else {
721             // assign solutions for vars between first_non_zero+1 and
722             // last_assigned_sol-1: free parameters
723             for (int c=first_non_zero+1; c<=last_assigned_sol-1; ++c) {
724                 sol.ffe_set(c,1,vars.ffe_get(c,1));
725             }
726             ex e=b.ffe_get(r,1);
727             for (int c=first_non_zero+1; c<=n; ++c) {
728                 e=e-a.ffe_get(r,c)*sol.ffe_get(c,1);
729             }
730             sol.ffe_set(first_non_zero,1,
731                         (e/a.ffe_get(r,first_non_zero)).normal());
732             last_assigned_sol=first_non_zero;
733         }
734     }
735     // assign solutions for vars between 1 and
736     // last_assigned_sol-1: free parameters
737     for (int c=1; c<=last_assigned_sol-1; ++c) {
738         sol.ffe_set(c,1,vars.ffe_get(c,1));
739     }
740
741     /*
742     for (int c=1; c<=n; ++c) {
743         cout << vars.ffe_get(c,1) << "->" << sol.ffe_get(c,1) << endl;
744     }
745     */
746     
747 #ifdef DOASSERT
748     // test solution with echelon matrix
749     for (int r=1; r<=m; ++r) {
750         ex e=0;
751         for (int c=1; c<=n; ++c) {
752             e=e+a.ffe_get(r,c)*sol.ffe_get(c,1);
753         }
754         if (!(e-b.ffe_get(r,1)).normal().is_zero()) {
755             cout << "e=" << e;
756             cout << "b.ffe_get(" << r<<",1)=" << b.ffe_get(r,1) << endl;
757             cout << "diff=" << (e-b.ffe_get(r,1)).normal() << endl;
758         }
759         ASSERT((e-b.ffe_get(r,1)).normal().is_zero());
760     }
761
762     // test solution with original matrix
763     for (int r=1; r<=m; ++r) {
764         ex e=0;
765         for (int c=1; c<=n; ++c) {
766             e=e+ffe_get(r,c)*sol.ffe_get(c,1);
767         }
768         try {
769         if (!(e-rhs.ffe_get(r,1)).normal().is_zero()) {
770             cout << "e=" << e << endl;
771             e.printtree(cout);
772             ex en=e.normal();
773             cout << "e.normal()=" << en << endl;
774             en.printtree(cout);
775             cout << "rhs.ffe_get(" << r<<",1)=" << rhs.ffe_get(r,1) << endl;
776             cout << "diff=" << (e-rhs.ffe_get(r,1)).normal() << endl;
777         }
778         } catch (...) {
779             ex xxx=e-rhs.ffe_get(r,1);
780             cerr << "xxx=" << xxx << endl << endl;
781         }
782         ASSERT((e-rhs.ffe_get(r,1)).normal().is_zero());
783     }
784 #endif // def DOASSERT
785     
786     return sol;
787 }   
788     
789 /** Solve simultaneous set of equations. */
790 matrix matrix::solve(matrix const & v) const
791 {
792     if (!(row == col && col == v.row)) {
793         throw (std::logic_error("matrix::solve(): incompatible matrices"));
794     }
795     
796     // build the extended matrix of *this with v attached to the right
797     matrix tmp(row,col+v.col);
798     for (int r=0; r<row; ++r) {
799         for (int c=0; c<col; ++c) {
800             tmp.m[r*tmp.col+c] = m[r*col+c];
801         }
802         for (int c=0; c<v.col; ++c) {
803             tmp.m[r*tmp.col+c+col] = v.m[r*v.col+c];
804         }
805     }
806     for (int r1=0; r1<row; ++r1) {
807         int indx = tmp.pivot(r1);
808         if (indx == -1) {
809             throw (std::runtime_error("matrix::solve(): singular matrix"));
810         }
811         for (int c=r1; c<tmp.col; ++c) {
812             tmp.m[r1*tmp.col+c] /= tmp.m[r1*tmp.col+r1];
813         }
814         for (int r2=r1+1; r2<row; ++r2) {
815             for (int c=r1; c<tmp.col; ++c) {
816                 tmp.m[r2*tmp.col+c]
817                     -= tmp.m[r2*tmp.col+r1] * tmp.m[r1*tmp.col+c];
818             }
819         }
820     }
821     
822     // assemble the solution matrix
823     vector<ex> sol(v.row*v.col);
824     for (int c=0; c<v.col; ++c) {
825         for (int r=col-1; r>=0; --r) {
826             sol[r*v.col+c] = tmp[r*tmp.col+c];
827             for (int i=r+1; i<col; ++i) {
828                 sol[r*v.col+c]
829                     -= tmp[r*tmp.col+i] * sol[i*v.col+c];
830             }
831         }
832     }
833     return matrix(v.row, v.col, sol);
834 }
835
836 // protected
837
838 /** Partial pivoting method.
839  *  Usual pivoting returns the index to the element with the largest absolute
840  *  value and swaps the current row with the one where the element was found.
841  *  Here it does the same with the first non-zero element. (This works fine,
842  *  but may be far from optimal for numerics.) */
843 int matrix::pivot(int ro)
844 {
845     int k=ro;
846     
847     for (int r=ro; r<row; ++r) {
848         if (!m[r*col+ro].is_zero()) {
849             k = r;
850             break;
851         }
852     }
853     if (m[k*col+ro].is_zero()) {
854         return -1;
855     }
856     if (k!=ro) {  // swap rows
857         for (int c=0; c<col; ++c) {
858             m[k*col+c].swap(m[ro*col+c]);
859         }
860         return k;
861     }
862     return 0;
863 }
864
865 //////////
866 // global constants
867 //////////
868
869 const matrix some_matrix;
870 type_info const & typeid_matrix=typeid(some_matrix);