- ASSERT macro renamed to GINAC_ASSERT
[ginac.git] / ginac / matrix.cpp
1 /** @file matrix.cpp
2  *
3  *  Implementation of symbolic matrices */
4
5 /*
6  *  GiNaC Copyright (C) 1999 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
21  */
22
23 #include <algorithm>
24 #include <stdexcept>
25
26 #include "matrix.h"
27 #include "debugmsg.h"
28
29 namespace GiNaC {
30
31 //////////
32 // default constructor, destructor, copy constructor, assignment operator
33 // and helpers:
34 //////////
35
36 // public
37
38 /** Default ctor.  Initializes to 1 x 1-dimensional zero-matrix. */
39 matrix::matrix()
40     : basic(TINFO_matrix), row(1), col(1)
41 {
42     debugmsg("matrix default constructor",LOGLEVEL_CONSTRUCT);
43     m.push_back(exZERO());
44 }
45
46 matrix::~matrix()
47 {
48     debugmsg("matrix destructor",LOGLEVEL_DESTRUCT);
49 }
50
51 matrix::matrix(matrix const & other)
52 {
53     debugmsg("matrix copy constructor",LOGLEVEL_CONSTRUCT);
54     copy(other);
55 }
56
57 matrix const & matrix::operator=(matrix const & other)
58 {
59     debugmsg("matrix operator=",LOGLEVEL_ASSIGNMENT);
60     if (this != &other) {
61         destroy(1);
62         copy(other);
63     }
64     return *this;
65 }
66
67 // protected
68
69 void matrix::copy(matrix const & other)
70 {
71     basic::copy(other);
72     row=other.row;
73     col=other.col;
74     m=other.m;  // use STL's vector copying
75 }
76
77 void matrix::destroy(bool call_parent)
78 {
79     if (call_parent) basic::destroy(call_parent);
80 }
81
82 //////////
83 // other constructors
84 //////////
85
86 // public
87
88 /** Very common ctor.  Initializes to r x c-dimensional zero-matrix.
89  *
90  *  @param r number of rows
91  *  @param c number of cols */
92 matrix::matrix(int r, int c)
93     : basic(TINFO_matrix), row(r), col(c)
94 {
95     debugmsg("matrix constructor from int,int",LOGLEVEL_CONSTRUCT);
96     m.resize(r*c, exZERO());
97 }
98
99 // protected
100
101 /** Ctor from representation, for internal use only. */
102 matrix::matrix(int r, int c, vector<ex> const & m2)
103     : basic(TINFO_matrix), row(r), col(c), m(m2)
104 {
105     debugmsg("matrix constructor from int,int,vector<ex>",LOGLEVEL_CONSTRUCT);
106 }
107
108 //////////
109 // functions overriding virtual functions from bases classes
110 //////////
111
112 // public
113
114 basic * matrix::duplicate() const
115 {
116     debugmsg("matrix duplicate",LOGLEVEL_DUPLICATE);
117     return new matrix(*this);
118 }
119
120 /** nops is defined to be rows x columns. */
121 int matrix::nops() const
122 {
123     return row*col;
124 }
125
126 /** returns matrix entry at position (i/col, i%col). */
127 ex & matrix::let_op(int const i)
128 {
129     return m[i];
130 }
131
132 /** expands the elements of a matrix entry by entry. */
133 ex matrix::expand(unsigned options) const
134 {
135     vector<ex> tmp(row*col);
136     for (int i=0; i<row*col; ++i) {
137         tmp[i]=m[i].expand(options);
138     }
139     return matrix(row, col, tmp);
140 }
141
142 /** Search ocurrences.  A matrix 'has' an expression if it is the expression
143  *  itself or one of the elements 'has' it. */
144 bool matrix::has(ex const & other) const
145 {
146     GINAC_ASSERT(other.bp!=0);
147     
148     // tautology: it is the expression itself
149     if (is_equal(*other.bp)) return true;
150     
151     // search all the elements
152     for (vector<ex>::const_iterator r=m.begin(); r!=m.end(); ++r) {
153         if ((*r).has(other)) return true;
154     }
155     return false;
156 }
157
158 /** evaluate matrix entry by entry. */
159 ex matrix::eval(int level) const
160 {
161     debugmsg("matrix eval",LOGLEVEL_MEMBER_FUNCTION);
162     
163     // check if we have to do anything at all
164     if ((level==1)&&(flags & status_flags::evaluated)) {
165         return *this;
166     }
167     
168     // emergency break
169     if (level == -max_recursion_level) {
170         throw (std::runtime_error("matrix::eval(): recursion limit exceeded"));
171     }
172     
173     // eval() entry by entry
174     vector<ex> m2(row*col);
175     --level;    
176     for (int r=0; r<row; ++r) {
177         for (int c=0; c<col; ++c) {
178             m2[r*col+c] = m[r*col+c].eval(level);
179         }
180     }
181     
182     return (new matrix(row, col, m2))->setflag(status_flags::dynallocated |
183                                                status_flags::evaluated );
184 }
185
186 /** evaluate matrix numerically entry by entry. */
187 ex matrix::evalf(int level) const
188 {
189     debugmsg("matrix evalf",LOGLEVEL_MEMBER_FUNCTION);
190         
191     // check if we have to do anything at all
192     if (level==1) {
193         return *this;
194     }
195     
196     // emergency break
197     if (level == -max_recursion_level) {
198         throw (std::runtime_error("matrix::evalf(): recursion limit exceeded"));
199     }
200     
201     // evalf() entry by entry
202     vector<ex> m2(row*col);
203     --level;
204     for (int r=0; r<row; ++r) {
205         for (int c=0; c<col; ++c) {
206             m2[r*col+c] = m[r*col+c].evalf(level);
207         }
208     }
209     return matrix(row, col, m2);
210 }
211
212 // protected
213
214 int matrix::compare_same_type(basic const & other) const
215 {
216     GINAC_ASSERT(is_exactly_of_type(other, matrix));
217     matrix const & o=static_cast<matrix &>(const_cast<basic &>(other));
218     
219     // compare number of rows
220     if (row != o.rows()) {
221         return row < o.rows() ? -1 : 1;
222     }
223     
224     // compare number of columns
225     if (col != o.cols()) {
226         return col < o.cols() ? -1 : 1;
227     }
228     
229     // equal number of rows and columns, compare individual elements
230     int cmpval;
231     for (int r=0; r<row; ++r) {
232         for (int c=0; c<col; ++c) {
233             cmpval=((*this)(r,c)).compare(o(r,c));
234             if (cmpval!=0) return cmpval;
235         }
236     }
237     // all elements are equal => matrices are equal;
238     return 0;
239 }
240
241 //////////
242 // non-virtual functions in this class
243 //////////
244
245 // public
246
247 /** Sum of matrices.
248  *
249  *  @exception logic_error (incompatible matrices) */
250 matrix matrix::add(matrix const & other) const
251 {
252     if (col != other.col || row != other.row) {
253         throw (std::logic_error("matrix::add(): incompatible matrices"));
254     }
255     
256     vector<ex> sum(this->m);
257     vector<ex>::iterator i;
258     vector<ex>::const_iterator ci;
259     for (i=sum.begin(), ci=other.m.begin();
260          i!=sum.end();
261          ++i, ++ci) {
262         (*i) += (*ci);
263     }
264     return matrix(row,col,sum);
265 }
266
267 /** Difference of matrices.
268  *
269  *  @exception logic_error (incompatible matrices) */
270 matrix matrix::sub(matrix const & other) const
271 {
272     if (col != other.col || row != other.row) {
273         throw (std::logic_error("matrix::sub(): incompatible matrices"));
274     }
275     
276     vector<ex> dif(this->m);
277     vector<ex>::iterator i;
278     vector<ex>::const_iterator ci;
279     for (i=dif.begin(), ci=other.m.begin();
280          i!=dif.end();
281          ++i, ++ci) {
282         (*i) -= (*ci);
283     }
284     return matrix(row,col,dif);
285 }
286
287 /** Product of matrices.
288  *
289  *  @exception logic_error (incompatible matrices) */
290 matrix matrix::mul(matrix const & other) const
291 {
292     if (col != other.row) {
293         throw (std::logic_error("matrix::mul(): incompatible matrices"));
294     }
295     
296     vector<ex> prod(row*other.col);
297     for (int i=0; i<row; ++i) {
298         for (int j=0; j<other.col; ++j) {
299             for (int l=0; l<col; ++l) {
300                 prod[i*other.col+j] += m[i*col+l] * other.m[l*other.col+j];
301             }
302         }
303     }
304     return matrix(row, other.col, prod);
305 }
306
307 /** operator() to access elements.
308  *
309  *  @param ro row of element
310  *  @param co column of element 
311  *  @exception range_error (index out of range) */
312 ex const & matrix::operator() (int ro, int co) const
313 {
314     if (ro<0 || ro>=row || co<0 || co>=col) {
315         throw (std::range_error("matrix::operator(): index out of range"));
316     }
317     
318     return m[ro*col+co];
319 }
320
321 /** Set individual elements manually.
322  *
323  *  @exception range_error (index out of range) */
324 matrix & matrix::set(int ro, int co, ex value)
325 {
326     if (ro<0 || ro>=row || co<0 || co>=col) {
327         throw (std::range_error("matrix::set(): index out of range"));
328     }
329     
330     ensure_if_modifiable();
331     m[ro*col+co]=value;
332     return *this;
333 }
334
335 /** Transposed of an m x n matrix, producing a new n x m matrix object that
336  *  represents the transposed. */
337 matrix matrix::transpose(void) const
338 {
339     vector<ex> trans(col*row);
340     
341     for (int r=0; r<col; ++r) {
342         for (int c=0; c<row; ++c) {
343             trans[r*row+c] = m[c*col+r];
344         }
345     }
346     return matrix(col,row,trans);
347 }
348
349 /* Determiant of purely numeric matrix, using pivoting. This routine is only
350  * called internally by matrix::determinant(). */
351 ex determinant_numeric(const matrix & M)
352 {
353     GINAC_ASSERT(M.rows()==M.cols());  // cannot happen, just in case...
354     matrix tmp(M);
355     ex det=exONE();
356     ex piv;
357     
358     for (int r1=0; r1<M.rows(); ++r1) {
359         int indx = tmp.pivot(r1);
360         if (indx == -1) {
361             return exZERO();
362         }
363         if (indx != 0) {
364             det *= exMINUSONE();
365         }
366         det = det * tmp.m[r1*M.cols()+r1];
367         for (int r2=r1+1; r2<M.rows(); ++r2) {
368             piv = tmp.m[r2*M.cols()+r1] / tmp.m[r1*M.cols()+r1];
369             for (int c=r1+1; c<M.cols(); c++) {
370                 tmp.m[r2*M.cols()+c] -= piv * tmp.m[r1*M.cols()+c];
371             }
372         }
373     }
374     return det;
375 }
376
377 // Compute the sign of a permutation of a vector of things, used internally
378 // by determinant_symbolic_perm() where it is instantiated for int.
379 template <class T>
380 int permutation_sign(vector<T> s)
381 {
382     if (s.size() < 2)
383         return 0;
384     int sigma=1;
385     for (typename vector<T>::iterator i=s.begin(); i!=s.end()-1; ++i) {
386         for (typename vector<T>::iterator j=i+1; j!=s.end(); ++j) {
387             if (*i == *j)
388                 return 0;
389             if (*i > *j) {
390                 iter_swap(i,j);
391                 sigma = -sigma;
392             }
393         }
394     }
395     return sigma;
396 }
397
398 /** Determinant built by application of the full permutation group. This
399  *  routine is only called internally by matrix::determinant(). */
400 ex determinant_symbolic_perm(const matrix & M)
401 {
402     GINAC_ASSERT(M.rows()==M.cols());  // cannot happen, just in case...
403     
404     if (M.rows()==1) {  // speed things up
405         return M(0,0);
406     }
407     
408     ex det;
409     ex term;
410     vector<int> sigma(M.cols());
411     for (int i=0; i<M.cols(); ++i) sigma[i]=i;
412     
413     do {
414         term = M(sigma[0],0);
415         for (int i=1; i<M.cols(); ++i) term *= M(sigma[i],i);
416         det += permutation_sign(sigma)*term;
417     } while (next_permutation(sigma.begin(), sigma.end()));
418     
419     return det;
420 }
421
422 /** Recursive determiant for small matrices having at least one symbolic entry.
423  *  This algorithm is also known as Laplace-expansion. This routine is only
424  *  called internally by matrix::determinant(). */
425 ex determinant_symbolic_minor(const matrix & M)
426 {
427     GINAC_ASSERT(M.rows()==M.cols());  // cannot happen, just in case...
428     
429     if (M.rows()==1) {  // end of recursion
430         return M(0,0);
431     }
432     if (M.rows()==2) {  // speed things up
433         return (M(0,0)*M(1,1)-
434                 M(1,0)*M(0,1));
435     }
436     if (M.rows()==3) {  // speed things up even a little more
437         return ((M(2,1)*M(0,2)-M(2,2)*M(0,1))*M(1,0)+
438                 (M(1,2)*M(0,1)-M(1,1)*M(0,2))*M(2,0)+
439                 (M(2,2)*M(1,1)-M(2,1)*M(1,2))*M(0,0));
440     }
441     
442     ex det;
443     matrix minorM(M.rows()-1,M.cols()-1);
444     for (int r1=0; r1<M.rows(); ++r1) {
445         // assemble the minor matrix
446         for (int r=0; r<minorM.rows(); ++r) {
447             for (int c=0; c<minorM.cols(); ++c) {
448                 if (r<r1) {
449                     minorM.set(r,c,M(r,c+1));
450                 } else {
451                     minorM.set(r,c,M(r+1,c+1));
452                 }
453             }
454         }
455         // recurse down
456         if (r1%2) {
457             det -= M(r1,0) * determinant_symbolic_minor(minorM);
458         } else {
459             det += M(r1,0) * determinant_symbolic_minor(minorM);
460         }
461     }
462     return det;
463 }
464
465 /*  Leverrier algorithm for large matrices having at least one symbolic entry.
466  *  This routine is only called internally by matrix::determinant(). The
467  *  algorithm is deemed bad for symbolic matrices since it returns expressions
468  *  that are very hard to canonicalize. */
469 /*ex determinant_symbolic_leverrier(const matrix & M)
470  *{
471  *    GINAC_ASSERT(M.rows()==M.cols());  // cannot happen, just in case...
472  *    
473  *    matrix B(M);
474  *    matrix I(M.row, M.col);
475  *    ex c=B.trace();
476  *    for (int i=1; i<M.row; ++i) {
477  *        for (int j=0; j<M.row; ++j)
478  *            I.m[j*M.col+j] = c;
479  *        B = M.mul(B.sub(I));
480  *        c = B.trace()/ex(i+1);
481  *    }
482  *    if (M.row%2) {
483  *        return c;
484  *    } else {
485  *        return -c;
486  *    }
487  *}*/
488
489 /** Determinant of square matrix.  This routine doesn't actually calculate the
490  *  determinant, it only implements some heuristics about which algorithm to
491  *  call.  When the parameter for normalization is explicitly turned off this
492  *  method does not normalize its result at the end, which might imply that
493  *  the symbolic 2x2 matrix [[a/(a-b),1],[b/(a-b),1]] is not immediatly
494  *  recognized to be unity.  (This is Mathematica's default behaviour, it
495  *  should be used with care.)
496  *
497  *  @param     normalized may be set to false if no normalization of the
498  *             result is desired (i.e. to force Mathematica behavior, Maple
499  *             does normalize the result).
500  *  @return    the determinant as a new expression
501  *  @exception logic_error (matrix not square) */
502 ex matrix::determinant(bool normalized) const
503 {
504     if (row != col) {
505         throw (std::logic_error("matrix::determinant(): matrix not square"));
506     }
507
508     // check, if there are non-numeric entries in the matrix:
509     for (vector<ex>::const_iterator r=m.begin(); r!=m.end(); ++r) {
510         if (!(*r).info(info_flags::numeric)) {
511             if (normalized) {
512                 return determinant_symbolic_minor(*this).normal();
513             } else {
514                 return determinant_symbolic_perm(*this);
515             }
516         }
517     }
518     // if it turns out that all elements are numeric
519     return determinant_numeric(*this);
520 }
521
522 /** Trace of a matrix.
523  *
524  *  @return    the sum of diagonal elements
525  *  @exception logic_error (matrix not square) */
526 ex matrix::trace(void) const
527 {
528     if (row != col) {
529         throw (std::logic_error("matrix::trace(): matrix not square"));
530     }
531     
532     ex tr;
533     for (int r=0; r<col; ++r) {
534         tr += m[r*col+r];
535     }
536     return tr;
537 }
538
539 /** Characteristic Polynomial.  The characteristic polynomial of a matrix M is
540  *  defined as the determiant of (M - lambda * 1) where 1 stands for the unit
541  *  matrix of the same dimension as M.  This method returns the characteristic
542  *  polynomial as a new expression.
543  *
544  *  @return    characteristic polynomial as new expression
545  *  @exception logic_error (matrix not square)
546  *  @see       matrix::determinant() */
547 ex matrix::charpoly(ex const & lambda) const
548 {
549     if (row != col) {
550         throw (std::logic_error("matrix::charpoly(): matrix not square"));
551     }
552     
553     matrix M(*this);
554     for (int r=0; r<col; ++r) {
555         M.m[r*col+r] -= lambda;
556     }
557     return (M.determinant());
558 }
559
560 /** Inverse of this matrix.
561  *
562  *  @return    the inverted matrix
563  *  @exception logic_error (matrix not square)
564  *  @exception runtime_error (singular matrix) */
565 matrix matrix::inverse(void) const
566 {
567     if (row != col) {
568         throw (std::logic_error("matrix::inverse(): matrix not square"));
569     }
570     
571     matrix tmp(row,col);
572     // set tmp to the unit matrix
573     for (int i=0; i<col; ++i) {
574         tmp.m[i*col+i] = exONE();
575     }
576     // create a copy of this matrix
577     matrix cpy(*this);
578     for (int r1=0; r1<row; ++r1) {
579         int indx = cpy.pivot(r1);
580         if (indx == -1) {
581             throw (std::runtime_error("matrix::inverse(): singular matrix"));
582         }
583         if (indx != 0) {  // swap rows r and indx of matrix tmp
584             for (int i=0; i<col; ++i) {
585                 tmp.m[r1*col+i].swap(tmp.m[indx*col+i]);
586             }
587         }
588         ex a1 = cpy.m[r1*col+r1];
589         for (int c=0; c<col; ++c) {
590             cpy.m[r1*col+c] /= a1;
591             tmp.m[r1*col+c] /= a1;
592         }
593         for (int r2=0; r2<row; ++r2) {
594             if (r2 != r1) {
595                 ex a2 = cpy.m[r2*col+r1];
596                 for (int c=0; c<col; ++c) {
597                     cpy.m[r2*col+c] -= a2 * cpy.m[r1*col+c];
598                     tmp.m[r2*col+c] -= a2 * tmp.m[r1*col+c];
599                 }
600             }
601         }
602     }
603     return tmp;
604 }
605
606 void matrix::ffe_swap(int r1, int c1, int r2 ,int c2)
607 {
608     ensure_if_modifiable();
609     
610     ex tmp=ffe_get(r1,c1);
611     ffe_set(r1,c1,ffe_get(r2,c2));
612     ffe_set(r2,c2,tmp);
613 }
614
615 void matrix::ffe_set(int r, int c, ex e)
616 {
617     set(r-1,c-1,e);
618 }
619
620 ex matrix::ffe_get(int r, int c) const
621 {
622     return operator()(r-1,c-1);
623 }
624
625 /** Solve a set of equations for an m x n matrix by fraction-free Gaussian
626  *  elimination. Based on algorithm 9.1 from 'Algorithms for Computer Algebra'
627  *  by Keith O. Geddes et al.
628  *
629  *  @param vars n x p matrix
630  *  @param rhs m x p matrix
631  *  @exception logic_error (incompatible matrices)
632  *  @exception runtime_error (singular matrix) */
633 matrix matrix::fraction_free_elim(matrix const & vars,
634                                   matrix const & rhs) const
635 {
636     if ((row != rhs.row) || (col != vars.row) || (rhs.col != vars.col)) {
637         throw (std::logic_error("matrix::solve(): incompatible matrices"));
638     }
639     
640     matrix a(*this); // make a copy of the matrix
641     matrix b(rhs);     // make a copy of the rhs vector
642     
643     // given an m x n matrix a, reduce it to upper echelon form
644     int m=a.row;
645     int n=a.col;
646     int sign=1;
647     ex divisor=1;
648     int r=1;
649     
650     // eliminate below row r, with pivot in column k
651     for (int k=1; (k<=n)&&(r<=m); ++k) {
652         // find a nonzero pivot
653         int p;
654         for (p=r; (p<=m)&&(a.ffe_get(p,k).is_equal(exZERO())); ++p) {}
655         // pivot is in row p
656         if (p<=m) {
657             if (p!=r) {
658                 // switch rows p and r
659                 for (int j=k; j<=n; ++j) {
660                     a.ffe_swap(p,j,r,j);
661                 }
662                 b.ffe_swap(p,1,r,1);
663                 // keep track of sign changes due to row exchange
664                 sign=-sign;
665             }
666             for (int i=r+1; i<=m; ++i) {
667                 for (int j=k+1; j<=n; ++j) {
668                     a.ffe_set(i,j,(a.ffe_get(r,k)*a.ffe_get(i,j)
669                                   -a.ffe_get(r,j)*a.ffe_get(i,k))/divisor);
670                     a.ffe_set(i,j,a.ffe_get(i,j).normal() /*.normal() */ );
671                 }
672                 b.ffe_set(i,1,(a.ffe_get(r,k)*b.ffe_get(i,1)
673                               -b.ffe_get(r,1)*a.ffe_get(i,k))/divisor);
674                 b.ffe_set(i,1,b.ffe_get(i,1).normal() /*.normal() */ );
675                 a.ffe_set(i,k,0);
676             }
677             divisor=a.ffe_get(r,k);
678             r++;
679         }
680     }
681     // optionally compute the determinant for square or augmented matrices
682     // if (r==m+1) { det=sign*divisor; } else { det=0; }
683     
684     /*
685     for (int r=1; r<=m; ++r) {
686         for (int c=1; c<=n; ++c) {
687             cout << a.ffe_get(r,c) << "\t";
688         }
689         cout << " | " <<  b.ffe_get(r,1) << endl;
690     }
691     */
692     
693 #ifdef DO_GINAC_ASSERT
694     // test if we really have an upper echelon matrix
695     int zero_in_last_row=-1;
696     for (int r=1; r<=m; ++r) {
697         int zero_in_this_row=0;
698         for (int c=1; c<=n; ++c) {
699             if (a.ffe_get(r,c).is_equal(exZERO())) {
700                zero_in_this_row++;
701             } else {
702                 break;
703             }
704         }
705         GINAC_ASSERT((zero_in_this_row>zero_in_last_row)||(zero_in_this_row=n));
706         zero_in_last_row=zero_in_this_row;
707     }
708 #endif // def DO_GINAC_ASSERT
709     
710     // assemble solution
711     matrix sol(n,1);
712     int last_assigned_sol=n+1;
713     for (int r=m; r>0; --r) {
714         int first_non_zero=1;
715         while ((first_non_zero<=n)&&(a.ffe_get(r,first_non_zero).is_zero())) {
716             first_non_zero++;
717         }
718         if (first_non_zero>n) {
719             // row consists only of zeroes, corresponding rhs must be 0 as well
720             if (!b.ffe_get(r,1).is_zero()) {
721                 throw (std::runtime_error("matrix::fraction_free_elim(): singular matrix"));
722             }
723         } else {
724             // assign solutions for vars between first_non_zero+1 and
725             // last_assigned_sol-1: free parameters
726             for (int c=first_non_zero+1; c<=last_assigned_sol-1; ++c) {
727                 sol.ffe_set(c,1,vars.ffe_get(c,1));
728             }
729             ex e=b.ffe_get(r,1);
730             for (int c=first_non_zero+1; c<=n; ++c) {
731                 e=e-a.ffe_get(r,c)*sol.ffe_get(c,1);
732             }
733             sol.ffe_set(first_non_zero,1,
734                         (e/a.ffe_get(r,first_non_zero)).normal());
735             last_assigned_sol=first_non_zero;
736         }
737     }
738     // assign solutions for vars between 1 and
739     // last_assigned_sol-1: free parameters
740     for (int c=1; c<=last_assigned_sol-1; ++c) {
741         sol.ffe_set(c,1,vars.ffe_get(c,1));
742     }
743
744     /*
745     for (int c=1; c<=n; ++c) {
746         cout << vars.ffe_get(c,1) << "->" << sol.ffe_get(c,1) << endl;
747     }
748     */
749     
750 #ifdef DO_GINAC_ASSERT
751     // test solution with echelon matrix
752     for (int r=1; r<=m; ++r) {
753         ex e=0;
754         for (int c=1; c<=n; ++c) {
755             e=e+a.ffe_get(r,c)*sol.ffe_get(c,1);
756         }
757         if (!(e-b.ffe_get(r,1)).normal().is_zero()) {
758             cout << "e=" << e;
759             cout << "b.ffe_get(" << r<<",1)=" << b.ffe_get(r,1) << endl;
760             cout << "diff=" << (e-b.ffe_get(r,1)).normal() << endl;
761         }
762         GINAC_ASSERT((e-b.ffe_get(r,1)).normal().is_zero());
763     }
764
765     // test solution with original matrix
766     for (int r=1; r<=m; ++r) {
767         ex e=0;
768         for (int c=1; c<=n; ++c) {
769             e=e+ffe_get(r,c)*sol.ffe_get(c,1);
770         }
771         try {
772         if (!(e-rhs.ffe_get(r,1)).normal().is_zero()) {
773             cout << "e=" << e << endl;
774             e.printtree(cout);
775             ex en=e.normal();
776             cout << "e.normal()=" << en << endl;
777             en.printtree(cout);
778             cout << "rhs.ffe_get(" << r<<",1)=" << rhs.ffe_get(r,1) << endl;
779             cout << "diff=" << (e-rhs.ffe_get(r,1)).normal() << endl;
780         }
781         } catch (...) {
782             ex xxx=e-rhs.ffe_get(r,1);
783             cerr << "xxx=" << xxx << endl << endl;
784         }
785         GINAC_ASSERT((e-rhs.ffe_get(r,1)).normal().is_zero());
786     }
787 #endif // def DO_GINAC_ASSERT
788     
789     return sol;
790 }   
791     
792 /** Solve simultaneous set of equations. */
793 matrix matrix::solve(matrix const & v) const
794 {
795     if (!(row == col && col == v.row)) {
796         throw (std::logic_error("matrix::solve(): incompatible matrices"));
797     }
798     
799     // build the extended matrix of *this with v attached to the right
800     matrix tmp(row,col+v.col);
801     for (int r=0; r<row; ++r) {
802         for (int c=0; c<col; ++c) {
803             tmp.m[r*tmp.col+c] = m[r*col+c];
804         }
805         for (int c=0; c<v.col; ++c) {
806             tmp.m[r*tmp.col+c+col] = v.m[r*v.col+c];
807         }
808     }
809     for (int r1=0; r1<row; ++r1) {
810         int indx = tmp.pivot(r1);
811         if (indx == -1) {
812             throw (std::runtime_error("matrix::solve(): singular matrix"));
813         }
814         for (int c=r1; c<tmp.col; ++c) {
815             tmp.m[r1*tmp.col+c] /= tmp.m[r1*tmp.col+r1];
816         }
817         for (int r2=r1+1; r2<row; ++r2) {
818             for (int c=r1; c<tmp.col; ++c) {
819                 tmp.m[r2*tmp.col+c]
820                     -= tmp.m[r2*tmp.col+r1] * tmp.m[r1*tmp.col+c];
821             }
822         }
823     }
824     
825     // assemble the solution matrix
826     vector<ex> sol(v.row*v.col);
827     for (int c=0; c<v.col; ++c) {
828         for (int r=col-1; r>=0; --r) {
829             sol[r*v.col+c] = tmp[r*tmp.col+c];
830             for (int i=r+1; i<col; ++i) {
831                 sol[r*v.col+c]
832                     -= tmp[r*tmp.col+i] * sol[i*v.col+c];
833             }
834         }
835     }
836     return matrix(v.row, v.col, sol);
837 }
838
839 // protected
840
841 /** Partial pivoting method.
842  *  Usual pivoting returns the index to the element with the largest absolute
843  *  value and swaps the current row with the one where the element was found.
844  *  Here it does the same with the first non-zero element. (This works fine,
845  *  but may be far from optimal for numerics.) */
846 int matrix::pivot(int ro)
847 {
848     int k=ro;
849     
850     for (int r=ro; r<row; ++r) {
851         if (!m[r*col+ro].is_zero()) {
852             k = r;
853             break;
854         }
855     }
856     if (m[k*col+ro].is_zero()) {
857         return -1;
858     }
859     if (k!=ro) {  // swap rows
860         for (int c=0; c<col; ++c) {
861             m[k*col+c].swap(m[ro*col+c]);
862         }
863         return k;
864     }
865     return 0;
866 }
867
868 //////////
869 // global constants
870 //////////
871
872 const matrix some_matrix;
873 type_info const & typeid_matrix=typeid(some_matrix);
874
875 } // namespace GiNaC