]> www.ginac.de Git - ginac.git/blob - ginac/matrix.cpp
- pseries::print(): did not insert parenthesis when needed for precedence.
[ginac.git] / ginac / matrix.cpp
1 /** @file matrix.cpp
2  *
3  *  Implementation of symbolic matrices */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2000 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
21  */
22
23 #include <algorithm>
24 #include <map>
25 #include <stdexcept>
26
27 #include "matrix.h"
28 #include "archive.h"
29 #include "numeric.h"
30 #include "lst.h"
31 #include "utils.h"
32 #include "debugmsg.h"
33 #include "power.h"
34 #include "symbol.h"
35 #include "normal.h"
36
37 #ifndef NO_NAMESPACE_GINAC
38 namespace GiNaC {
39 #endif // ndef NO_NAMESPACE_GINAC
40
41 GINAC_IMPLEMENT_REGISTERED_CLASS(matrix, basic)
42
43 //////////
44 // default constructor, destructor, copy constructor, assignment operator
45 // and helpers:
46 //////////
47
48 // public
49
50 /** Default ctor.  Initializes to 1 x 1-dimensional zero-matrix. */
51 matrix::matrix() : inherited(TINFO_matrix), row(1), col(1)
52 {
53         debugmsg("matrix default constructor",LOGLEVEL_CONSTRUCT);
54         m.push_back(_ex0());
55 }
56
57 matrix::~matrix()
58 {
59         debugmsg("matrix destructor",LOGLEVEL_DESTRUCT);
60         destroy(false);
61 }
62
63 matrix::matrix(const matrix & other)
64 {
65         debugmsg("matrix copy constructor",LOGLEVEL_CONSTRUCT);
66         copy(other);
67 }
68
69 const matrix & matrix::operator=(const matrix & other)
70 {
71         debugmsg("matrix operator=",LOGLEVEL_ASSIGNMENT);
72         if (this != &other) {
73                 destroy(true);
74                 copy(other);
75         }
76         return *this;
77 }
78
79 // protected
80
81 void matrix::copy(const matrix & other)
82 {
83         inherited::copy(other);
84         row = other.row;
85         col = other.col;
86         m = other.m;  // STL's vector copying invoked here
87 }
88
89 void matrix::destroy(bool call_parent)
90 {
91         if (call_parent) inherited::destroy(call_parent);
92 }
93
94 //////////
95 // other constructors
96 //////////
97
98 // public
99
100 /** Very common ctor.  Initializes to r x c-dimensional zero-matrix.
101  *
102  *  @param r number of rows
103  *  @param c number of cols */
104 matrix::matrix(unsigned r, unsigned c)
105   : inherited(TINFO_matrix), row(r), col(c)
106 {
107         debugmsg("matrix constructor from unsigned,unsigned",LOGLEVEL_CONSTRUCT);
108         m.resize(r*c, _ex0());
109 }
110
111 // protected
112
113 /** Ctor from representation, for internal use only. */
114 matrix::matrix(unsigned r, unsigned c, const exvector & m2)
115   : inherited(TINFO_matrix), row(r), col(c), m(m2)
116 {
117         debugmsg("matrix constructor from unsigned,unsigned,exvector",LOGLEVEL_CONSTRUCT);
118 }
119
120 //////////
121 // archiving
122 //////////
123
124 /** Construct object from archive_node. */
125 matrix::matrix(const archive_node &n, const lst &sym_lst) : inherited(n, sym_lst)
126 {
127         debugmsg("matrix constructor from archive_node", LOGLEVEL_CONSTRUCT);
128         if (!(n.find_unsigned("row", row)) || !(n.find_unsigned("col", col)))
129                 throw (std::runtime_error("unknown matrix dimensions in archive"));
130         m.reserve(row * col);
131         for (unsigned int i=0; true; i++) {
132                 ex e;
133                 if (n.find_ex("m", e, sym_lst, i))
134                         m.push_back(e);
135                 else
136                         break;
137         }
138 }
139
140 /** Unarchive the object. */
141 ex matrix::unarchive(const archive_node &n, const lst &sym_lst)
142 {
143         return (new matrix(n, sym_lst))->setflag(status_flags::dynallocated);
144 }
145
146 /** Archive the object. */
147 void matrix::archive(archive_node &n) const
148 {
149         inherited::archive(n);
150         n.add_unsigned("row", row);
151         n.add_unsigned("col", col);
152         exvector::const_iterator i = m.begin(), iend = m.end();
153         while (i != iend) {
154                 n.add_ex("m", *i);
155                 ++i;
156         }
157 }
158
159 //////////
160 // functions overriding virtual functions from bases classes
161 //////////
162
163 // public
164
165 basic * matrix::duplicate() const
166 {
167         debugmsg("matrix duplicate",LOGLEVEL_DUPLICATE);
168         return new matrix(*this);
169 }
170
171 void matrix::print(std::ostream & os, unsigned upper_precedence) const
172 {
173         debugmsg("matrix print",LOGLEVEL_PRINT);
174         os << "[[ ";
175         for (unsigned r=0; r<row-1; ++r) {
176                 os << "[[";
177                 for (unsigned c=0; c<col-1; ++c)
178                         os << m[r*col+c] << ",";
179                 os << m[col*(r+1)-1] << "]], ";
180         }
181         os << "[[";
182         for (unsigned c=0; c<col-1; ++c)
183                 os << m[(row-1)*col+c] << ",";
184         os << m[row*col-1] << "]] ]]";
185 }
186
187 void matrix::printraw(std::ostream & os) const
188 {
189         debugmsg("matrix printraw",LOGLEVEL_PRINT);
190         os << "matrix(" << row << "," << col <<",";
191         for (unsigned r=0; r<row-1; ++r) {
192                 os << "(";
193                 for (unsigned c=0; c<col-1; ++c)
194                         os << m[r*col+c] << ",";
195                 os << m[col*(r-1)-1] << "),";
196         }
197         os << "(";
198         for (unsigned c=0; c<col-1; ++c)
199                 os << m[(row-1)*col+c] << ",";
200         os << m[row*col-1] << "))";
201 }
202
203 /** nops is defined to be rows x columns. */
204 unsigned matrix::nops() const
205 {
206         return row*col;
207 }
208
209 /** returns matrix entry at position (i/col, i%col). */
210 ex matrix::op(int i) const
211 {
212         return m[i];
213 }
214
215 /** returns matrix entry at position (i/col, i%col). */
216 ex & matrix::let_op(int i)
217 {
218         GINAC_ASSERT(i>=0);
219         GINAC_ASSERT(i<nops());
220         
221         return m[i];
222 }
223
224 /** expands the elements of a matrix entry by entry. */
225 ex matrix::expand(unsigned options) const
226 {
227         exvector tmp(row*col);
228         for (unsigned i=0; i<row*col; ++i)
229                 tmp[i] = m[i].expand(options);
230         
231         return matrix(row, col, tmp);
232 }
233
234 /** Search ocurrences.  A matrix 'has' an expression if it is the expression
235  *  itself or one of the elements 'has' it. */
236 bool matrix::has(const ex & other) const
237 {
238         GINAC_ASSERT(other.bp!=0);
239         
240         // tautology: it is the expression itself
241         if (is_equal(*other.bp)) return true;
242         
243         // search all the elements
244         for (exvector::const_iterator r=m.begin(); r!=m.end(); ++r)
245                 if ((*r).has(other)) return true;
246         
247         return false;
248 }
249
250 /** evaluate matrix entry by entry. */
251 ex matrix::eval(int level) const
252 {
253         debugmsg("matrix eval",LOGLEVEL_MEMBER_FUNCTION);
254         
255         // check if we have to do anything at all
256         if ((level==1)&&(flags & status_flags::evaluated))
257                 return *this;
258         
259         // emergency break
260         if (level == -max_recursion_level)
261                 throw (std::runtime_error("matrix::eval(): recursion limit exceeded"));
262         
263         // eval() entry by entry
264         exvector m2(row*col);
265         --level;
266         for (unsigned r=0; r<row; ++r)
267                 for (unsigned c=0; c<col; ++c)
268                         m2[r*col+c] = m[r*col+c].eval(level);
269         
270         return (new matrix(row, col, m2))->setflag(status_flags::dynallocated |
271                                                                                            status_flags::evaluated );
272 }
273
274 /** evaluate matrix numerically entry by entry. */
275 ex matrix::evalf(int level) const
276 {
277         debugmsg("matrix evalf",LOGLEVEL_MEMBER_FUNCTION);
278                 
279         // check if we have to do anything at all
280         if (level==1)
281                 return *this;
282         
283         // emergency break
284         if (level == -max_recursion_level) {
285                 throw (std::runtime_error("matrix::evalf(): recursion limit exceeded"));
286         }
287         
288         // evalf() entry by entry
289         exvector m2(row*col);
290         --level;
291         for (unsigned r=0; r<row; ++r)
292                 for (unsigned c=0; c<col; ++c)
293                         m2[r*col+c] = m[r*col+c].evalf(level);
294         
295         return matrix(row, col, m2);
296 }
297
298 // protected
299
300 int matrix::compare_same_type(const basic & other) const
301 {
302         GINAC_ASSERT(is_exactly_of_type(other, matrix));
303         const matrix & o = static_cast<matrix &>(const_cast<basic &>(other));
304         
305         // compare number of rows
306         if (row != o.rows())
307                 return row < o.rows() ? -1 : 1;
308         
309         // compare number of columns
310         if (col != o.cols())
311                 return col < o.cols() ? -1 : 1;
312         
313         // equal number of rows and columns, compare individual elements
314         int cmpval;
315         for (unsigned r=0; r<row; ++r) {
316                 for (unsigned c=0; c<col; ++c) {
317                         cmpval = ((*this)(r,c)).compare(o(r,c));
318                         if (cmpval!=0) return cmpval;
319                 }
320         }
321         // all elements are equal => matrices are equal;
322         return 0;
323 }
324
325 //////////
326 // non-virtual functions in this class
327 //////////
328
329 // public
330
331 /** Sum of matrices.
332  *
333  *  @exception logic_error (incompatible matrices) */
334 matrix matrix::add(const matrix & other) const
335 {
336         if (col != other.col || row != other.row)
337                 throw (std::logic_error("matrix::add(): incompatible matrices"));
338         
339         exvector sum(this->m);
340         exvector::iterator i;
341         exvector::const_iterator ci;
342         for (i=sum.begin(), ci=other.m.begin(); i!=sum.end(); ++i, ++ci)
343                 (*i) += (*ci);
344         
345         return matrix(row,col,sum);
346 }
347
348
349 /** Difference of matrices.
350  *
351  *  @exception logic_error (incompatible matrices) */
352 matrix matrix::sub(const matrix & other) const
353 {
354         if (col != other.col || row != other.row)
355                 throw (std::logic_error("matrix::sub(): incompatible matrices"));
356         
357         exvector dif(this->m);
358         exvector::iterator i;
359         exvector::const_iterator ci;
360         for (i=dif.begin(), ci=other.m.begin(); i!=dif.end(); ++i, ++ci)
361                 (*i) -= (*ci);
362         
363         return matrix(row,col,dif);
364 }
365
366
367 /** Product of matrices.
368  *
369  *  @exception logic_error (incompatible matrices) */
370 matrix matrix::mul(const matrix & other) const
371 {
372         if (this->cols() != other.rows())
373                 throw (std::logic_error("matrix::mul(): incompatible matrices"));
374         
375         exvector prod(this->rows()*other.cols());
376         
377         for (unsigned r1=0; r1<this->rows(); ++r1) {
378                 for (unsigned c=0; c<this->cols(); ++c) {
379                         if (m[r1*col+c].is_zero())
380                                 continue;
381                         for (unsigned r2=0; r2<other.cols(); ++r2)
382                                 prod[r1*other.col+r2] += (m[r1*col+c] * other.m[c*other.col+r2]).expand();
383                 }
384         }
385         return matrix(row, other.col, prod);
386 }
387
388
389 /** operator() to access elements.
390  *
391  *  @param ro row of element
392  *  @param co column of element
393  *  @exception range_error (index out of range) */
394 const ex & matrix::operator() (unsigned ro, unsigned co) const
395 {
396         if (ro>=row || co>=col)
397                 throw (std::range_error("matrix::operator(): index out of range"));
398
399         return m[ro*col+co];
400 }
401
402
403 /** Set individual elements manually.
404  *
405  *  @exception range_error (index out of range) */
406 matrix & matrix::set(unsigned ro, unsigned co, ex value)
407 {
408         if (ro>=row || co>=col)
409                 throw (std::range_error("matrix::set(): index out of range"));
410     
411         ensure_if_modifiable();
412         m[ro*col+co] = value;
413         return *this;
414 }
415
416
417 /** Transposed of an m x n matrix, producing a new n x m matrix object that
418  *  represents the transposed. */
419 matrix matrix::transpose(void) const
420 {
421         exvector trans(this->cols()*this->rows());
422         
423         for (unsigned r=0; r<this->cols(); ++r)
424                 for (unsigned c=0; c<this->rows(); ++c)
425                         trans[r*this->rows()+c] = m[c*this->cols()+r];
426         
427         return matrix(this->cols(),this->rows(),trans);
428 }
429
430
431 /** Determinant of square matrix.  This routine doesn't actually calculate the
432  *  determinant, it only implements some heuristics about which algorithm to
433  *  run.  If all the elements of the matrix are elements of an integral domain
434  *  the determinant is also in that integral domain and the result is expanded
435  *  only.  If one or more elements are from a quotient field the determinant is
436  *  usually also in that quotient field and the result is normalized before it
437  *  is returned.  This implies that the determinant of the symbolic 2x2 matrix
438  *  [[a/(a-b),1],[b/(a-b),1]] is returned as unity.  (In this respect, it
439  *  behaves like MapleV and unlike Mathematica.)
440  *
441  *  @param     algo allows to chose an algorithm
442  *  @return    the determinant as a new expression
443  *  @exception logic_error (matrix not square)
444  *  @see       determinant_algo */
445 ex matrix::determinant(unsigned algo) const
446 {
447         if (row!=col)
448                 throw (std::logic_error("matrix::determinant(): matrix not square"));
449         GINAC_ASSERT(row*col==m.capacity());
450         
451         // Gather some statistical information about this matrix:
452         bool numeric_flag = true;
453         bool normal_flag = false;
454         unsigned sparse_count = 0;  // counts non-zero elements
455         for (exvector::const_iterator r=m.begin(); r!=m.end(); ++r) {
456                 lst srl;  // symbol replacement list
457                 ex rtest = (*r).to_rational(srl);
458                 if (!rtest.is_zero())
459                         ++sparse_count;
460                 if (!rtest.info(info_flags::numeric))
461                         numeric_flag = false;
462                 if (!rtest.info(info_flags::crational_polynomial) &&
463                          rtest.info(info_flags::rational_function))
464                         normal_flag = true;
465         }
466         
467         // Here is the heuristics in case this routine has to decide:
468         if (algo == determinant_algo::automatic) {
469                 // Minor expansion is generally a good guess:
470                 algo = determinant_algo::laplace;
471                 // Does anybody know when a matrix is really sparse?
472                 // Maybe <~row/2.236 nonzero elements average in a row?
473                 if (row>3 && 5*sparse_count<=row*col)
474                         algo = determinant_algo::bareiss;
475                 // Purely numeric matrix can be handled by Gauss elimination.
476                 // This overrides any prior decisions.
477                 if (numeric_flag)
478                         algo = determinant_algo::gauss;
479         }
480         
481         // Trap the trivial case here, since some algorithms don't like it
482         if (this->row==1) {
483                 // for consistency with non-trivial determinants...
484                 if (normal_flag)
485                         return m[0].normal();
486                 else
487                         return m[0].expand();
488         }
489         
490         // Compute the determinant
491         switch(algo) {
492                 case determinant_algo::gauss: {
493                         ex det = 1;
494                         matrix tmp(*this);
495                         int sign = tmp.gauss_elimination(true);
496                         for (unsigned d=0; d<row; ++d)
497                                 det *= tmp.m[d*col+d];
498                         if (normal_flag)
499                                 return (sign*det).normal();
500                         else
501                                 return (sign*det).normal().expand();
502                 }
503                 case determinant_algo::bareiss: {
504                         matrix tmp(*this);
505                         int sign;
506                         sign = tmp.fraction_free_elimination(true);
507                         if (normal_flag)
508                                 return (sign*tmp.m[row*col-1]).normal();
509                         else
510                                 return (sign*tmp.m[row*col-1]).expand();
511                 }
512                 case determinant_algo::divfree: {
513                         matrix tmp(*this);
514                         int sign;
515                         sign = tmp.division_free_elimination(true);
516                         if (sign==0)
517                                 return _ex0();
518                         ex det = tmp.m[row*col-1];
519                         // factor out accumulated bogus slag
520                         for (unsigned d=0; d<row-2; ++d)
521                                 for (unsigned j=0; j<row-d-2; ++j)
522                                         det = (det/tmp.m[d*col+d]).normal();
523                         return (sign*det);
524                 }
525                 case determinant_algo::laplace:
526                 default: {
527                         // This is the minor expansion scheme.  We always develop such
528                         // that the smallest minors (i.e, the trivial 1x1 ones) are on the
529                         // rightmost column.  For this to be efficient it turns out that
530                         // the emptiest columns (i.e. the ones with most zeros) should be
531                         // the ones on the right hand side.  Therefore we presort the
532                         // columns of the matrix:
533                         typedef std::pair<unsigned,unsigned> uintpair;
534                         std::vector<uintpair> c_zeros;  // number of zeros in column
535                         for (unsigned c=0; c<col; ++c) {
536                                 unsigned acc = 0;
537                                 for (unsigned r=0; r<row; ++r)
538                                         if (m[r*col+c].is_zero())
539                                                 ++acc;
540                                 c_zeros.push_back(uintpair(acc,c));
541                         }
542                         sort(c_zeros.begin(),c_zeros.end());
543                         std::vector<unsigned> pre_sort;
544                         for (std::vector<uintpair>::iterator i=c_zeros.begin(); i!=c_zeros.end(); ++i)
545                                 pre_sort.push_back(i->second);
546                         int sign = permutation_sign(pre_sort);
547                         exvector result(row*col);  // represents sorted matrix
548                         unsigned c = 0;
549                         for (std::vector<unsigned>::iterator i=pre_sort.begin();
550                                  i!=pre_sort.end();
551                                  ++i,++c) {
552                                 for (unsigned r=0; r<row; ++r)
553                                         result[r*col+c] = m[r*col+(*i)];
554                         }
555                         
556                         if (normal_flag)
557                                 return (sign*matrix(row,col,result).determinant_minor()).normal();
558                         else
559                                 return sign*matrix(row,col,result).determinant_minor();
560                 }
561         }
562 }
563
564
565 /** Trace of a matrix.  The result is normalized if it is in some quotient
566  *  field and expanded only otherwise.  This implies that the trace of the
567  *  symbolic 2x2 matrix [[a/(a-b),x],[y,b/(b-a)]] is recognized to be unity.
568  *
569  *  @return    the sum of diagonal elements
570  *  @exception logic_error (matrix not square) */
571 ex matrix::trace(void) const
572 {
573         if (row != col)
574                 throw (std::logic_error("matrix::trace(): matrix not square"));
575         
576         ex tr;
577         for (unsigned r=0; r<col; ++r)
578                 tr += m[r*col+r];
579         
580         if (tr.info(info_flags::rational_function) &&
581                 !tr.info(info_flags::crational_polynomial))
582                 return tr.normal();
583         else
584                 return tr.expand();
585 }
586
587
588 /** Characteristic Polynomial.  Following mathematica notation the
589  *  characteristic polynomial of a matrix M is defined as the determiant of
590  *  (M - lambda * 1) where 1 stands for the unit matrix of the same dimension
591  *  as M.  Note that some CASs define it with a sign inside the determinant
592  *  which gives rise to an overall sign if the dimension is odd.  This method
593  *  returns the characteristic polynomial collected in powers of lambda as a
594  *  new expression.
595  *
596  *  @return    characteristic polynomial as new expression
597  *  @exception logic_error (matrix not square)
598  *  @see       matrix::determinant() */
599 ex matrix::charpoly(const symbol & lambda) const
600 {
601         if (row != col)
602                 throw (std::logic_error("matrix::charpoly(): matrix not square"));
603         
604         bool numeric_flag = true;
605         for (exvector::const_iterator r=m.begin(); r!=m.end(); ++r) {
606                 if (!(*r).info(info_flags::numeric)) {
607                         numeric_flag = false;
608                 }
609         }
610         
611         // The pure numeric case is traditionally rather common.  Hence, it is
612         // trapped and we use Leverrier's algorithm which goes as row^3 for
613         // every coefficient.  The expensive part is the matrix multiplication.
614         if (numeric_flag) {
615                 matrix B(*this);
616                 ex c = B.trace();
617                 ex poly = power(lambda,row)-c*power(lambda,row-1);
618                 for (unsigned i=1; i<row; ++i) {
619                         for (unsigned j=0; j<row; ++j)
620                                 B.m[j*col+j] -= c;
621                         B = this->mul(B);
622                         c = B.trace()/ex(i+1);
623                         poly -= c*power(lambda,row-i-1);
624                 }
625                 if (row%2)
626                         return -poly;
627                 else
628                         return poly;
629         }
630         
631         matrix M(*this);
632         for (unsigned r=0; r<col; ++r)
633                 M.m[r*col+r] -= lambda;
634         
635         return M.determinant().collect(lambda);
636 }
637
638
639 /** Inverse of this matrix.
640  *
641  *  @return    the inverted matrix
642  *  @exception logic_error (matrix not square)
643  *  @exception runtime_error (singular matrix) */
644 matrix matrix::inverse(void) const
645 {
646         if (row != col)
647                 throw (std::logic_error("matrix::inverse(): matrix not square"));
648         
649         // NOTE: the Gauss-Jordan elimination used here can in principle be
650         // replaced by two clever calls to gauss_elimination() and some to
651         // transpose().  Wouldn't be more efficient (maybe less?), just more
652         // orthogonal.
653         matrix tmp(row,col);
654         // set tmp to the unit matrix
655         for (unsigned i=0; i<col; ++i)
656                 tmp.m[i*col+i] = _ex1();
657         
658         // create a copy of this matrix
659         matrix cpy(*this);
660         for (unsigned r1=0; r1<row; ++r1) {
661                 int indx = cpy.pivot(r1, r1);
662                 if (indx == -1) {
663                         throw (std::runtime_error("matrix::inverse(): singular matrix"));
664                 }
665                 if (indx != 0) {  // swap rows r and indx of matrix tmp
666                         for (unsigned i=0; i<col; ++i)
667                                 tmp.m[r1*col+i].swap(tmp.m[indx*col+i]);
668                 }
669                 ex a1 = cpy.m[r1*col+r1];
670                 for (unsigned c=0; c<col; ++c) {
671                         cpy.m[r1*col+c] /= a1;
672                         tmp.m[r1*col+c] /= a1;
673                 }
674                 for (unsigned r2=0; r2<row; ++r2) {
675                         if (r2 != r1) {
676                                 if (!cpy.m[r2*col+r1].is_zero()) {
677                                         ex a2 = cpy.m[r2*col+r1];
678                                         // yes, there is something to do in this column
679                                         for (unsigned c=0; c<col; ++c) {
680                                                 cpy.m[r2*col+c] -= a2 * cpy.m[r1*col+c];
681                                                 if (!cpy.m[r2*col+c].info(info_flags::numeric))
682                                                         cpy.m[r2*col+c] = cpy.m[r2*col+c].normal();
683                                                 tmp.m[r2*col+c] -= a2 * tmp.m[r1*col+c];
684                                                 if (!tmp.m[r2*col+c].info(info_flags::numeric))
685                                                         tmp.m[r2*col+c] = tmp.m[r2*col+c].normal();
686                                         }
687                                 }
688                         }
689                 }
690         }
691         
692         return tmp;
693 }
694
695
696 /** Solve a linear system consisting of a m x n matrix and a m x p right hand
697  *  side by applying an elimination scheme to the augmented matrix.
698  *
699  *  @param vars n x p matrix, all elements must be symbols 
700  *  @param rhs m x p matrix
701  *  @return n x p solution matrix
702  *  @exception logic_error (incompatible matrices)
703  *  @exception invalid_argument (1st argument must be matrix of symbols)
704  *  @exception runtime_error (inconsistent linear system)
705  *  @see       solve_algo */
706 matrix matrix::solve(const matrix & vars,
707                                          const matrix & rhs,
708                                          unsigned algo) const
709 {
710         const unsigned m = this->rows();
711         const unsigned n = this->cols();
712         const unsigned p = rhs.cols();
713         
714         // syntax checks    
715         if ((rhs.rows() != m) || (vars.rows() != n) || (vars.col != p))
716                 throw (std::logic_error("matrix::solve(): incompatible matrices"));
717         for (unsigned ro=0; ro<n; ++ro)
718                 for (unsigned co=0; co<p; ++co)
719                         if (!vars(ro,co).info(info_flags::symbol))
720                                 throw (std::invalid_argument("matrix::solve(): 1st argument must be matrix of symbols"));
721         
722         // build the augmented matrix of *this with rhs attached to the right
723         matrix aug(m,n+p);
724         for (unsigned r=0; r<m; ++r) {
725                 for (unsigned c=0; c<n; ++c)
726                         aug.m[r*(n+p)+c] = this->m[r*n+c];
727                 for (unsigned c=0; c<p; ++c)
728                         aug.m[r*(n+p)+c+n] = rhs.m[r*p+c];
729         }
730         
731         // Gather some statistical information about the augmented matrix:
732         bool numeric_flag = true;
733         for (exvector::const_iterator r=aug.m.begin(); r!=aug.m.end(); ++r) {
734                 if (!(*r).info(info_flags::numeric))
735                         numeric_flag = false;
736         }
737         
738         // Here is the heuristics in case this routine has to decide:
739         if (algo == solve_algo::automatic) {
740                 // Bareiss (fraction-free) elimination is generally a good guess:
741                 algo = solve_algo::bareiss;
742                 // For m<3, Bareiss elimination is equivalent to division free
743                 // elimination but has more logistic overhead
744                 if (m<3)
745                         algo = solve_algo::divfree;
746                 // This overrides any prior decisions.
747                 if (numeric_flag)
748                         algo = solve_algo::gauss;
749         }
750         
751         // Eliminate the augmented matrix:
752         switch(algo) {
753                 case solve_algo::gauss:
754                         aug.gauss_elimination();
755                 case solve_algo::divfree:
756                         aug.division_free_elimination();
757                 case solve_algo::bareiss:
758                 default:
759                         aug.fraction_free_elimination();
760         }
761         
762         // assemble the solution matrix:
763         matrix sol(n,p);
764         for (unsigned co=0; co<p; ++co) {
765                 unsigned last_assigned_sol = n+1;
766                 for (int r=m-1; r>=0; --r) {
767                         unsigned fnz = 1;    // first non-zero in row
768                         while ((fnz<=n) && (aug.m[r*(n+p)+(fnz-1)].is_zero()))
769                                 ++fnz;
770                         if (fnz>n) {
771                                 // row consists only of zeros, corresponding rhs must be 0, too
772                                 if (!aug.m[r*(n+p)+n+co].is_zero()) {
773                                         throw (std::runtime_error("matrix::solve(): inconsistent linear system"));
774                                 }
775                         } else {
776                                 // assign solutions for vars between fnz+1 and
777                                 // last_assigned_sol-1: free parameters
778                                 for (unsigned c=fnz; c<last_assigned_sol-1; ++c)
779                                         sol.set(c,co,vars.m[c*p+co]);
780                                 ex e = aug.m[r*(n+p)+n+co];
781                                 for (unsigned c=fnz; c<n; ++c)
782                                         e -= aug.m[r*(n+p)+c]*sol.m[c*p+co];
783                                 sol.set(fnz-1,co,
784                                                 (e/(aug.m[r*(n+p)+(fnz-1)])).normal());
785                                 last_assigned_sol = fnz;
786                         }
787                 }
788                 // assign solutions for vars between 1 and
789                 // last_assigned_sol-1: free parameters
790                 for (unsigned ro=0; ro<last_assigned_sol-1; ++ro)
791                         sol.set(ro,co,vars(ro,co));
792         }
793         
794         return sol;
795 }
796
797
798 // protected
799
800 /** Recursive determinant for small matrices having at least one symbolic
801  *  entry.  The basic algorithm, known as Laplace-expansion, is enhanced by
802  *  some bookkeeping to avoid calculation of the same submatrices ("minors")
803  *  more than once.  According to W.M.Gentleman and S.C.Johnson this algorithm
804  *  is better than elimination schemes for matrices of sparse multivariate
805  *  polynomials and also for matrices of dense univariate polynomials if the
806  *  matrix' dimesion is larger than 7.
807  *
808  *  @return the determinant as a new expression (in expanded form)
809  *  @see matrix::determinant() */
810 ex matrix::determinant_minor(void) const
811 {
812         // for small matrices the algorithm does not make any sense:
813         const unsigned n = this->cols();
814         if (n==1)
815                 return m[0].expand();
816         if (n==2)
817                 return (m[0]*m[3]-m[2]*m[1]).expand();
818         if (n==3)
819                 return (m[0]*m[4]*m[8]-m[0]*m[5]*m[7]-
820                         m[1]*m[3]*m[8]+m[2]*m[3]*m[7]+
821                         m[1]*m[5]*m[6]-m[2]*m[4]*m[6]).expand();
822         
823         // This algorithm can best be understood by looking at a naive
824         // implementation of Laplace-expansion, like this one:
825         // ex det;
826         // matrix minorM(this->rows()-1,this->cols()-1);
827         // for (unsigned r1=0; r1<this->rows(); ++r1) {
828         //     // shortcut if element(r1,0) vanishes
829         //     if (m[r1*col].is_zero())
830         //         continue;
831         //     // assemble the minor matrix
832         //     for (unsigned r=0; r<minorM.rows(); ++r) {
833         //         for (unsigned c=0; c<minorM.cols(); ++c) {
834         //             if (r<r1)
835         //                 minorM.set(r,c,m[r*col+c+1]);
836         //             else
837         //                 minorM.set(r,c,m[(r+1)*col+c+1]);
838         //         }
839         //     }
840         //     // recurse down and care for sign:
841         //     if (r1%2)
842         //         det -= m[r1*col] * minorM.determinant_minor();
843         //     else
844         //         det += m[r1*col] * minorM.determinant_minor();
845         // }
846         // return det.expand();
847         // What happens is that while proceeding down many of the minors are
848         // computed more than once.  In particular, there are binomial(n,k)
849         // kxk minors and each one is computed factorial(n-k) times.  Therefore
850         // it is reasonable to store the results of the minors.  We proceed from
851         // right to left.  At each column c we only need to retrieve the minors
852         // calculated in step c-1.  We therefore only have to store at most 
853         // 2*binomial(n,n/2) minors.
854         
855         // Unique flipper counter for partitioning into minors
856         std::vector<unsigned> Pkey;
857         Pkey.reserve(n);
858         // key for minor determinant (a subpartition of Pkey)
859         std::vector<unsigned> Mkey;
860         Mkey.reserve(n-1);
861         // we store our subminors in maps, keys being the rows they arise from
862         typedef std::map<std::vector<unsigned>,class ex> Rmap;
863         typedef std::map<std::vector<unsigned>,class ex>::value_type Rmap_value;
864         Rmap A;
865         Rmap B;
866         ex det;
867         // initialize A with last column:
868         for (unsigned r=0; r<n; ++r) {
869                 Pkey.erase(Pkey.begin(),Pkey.end());
870                 Pkey.push_back(r);
871                 A.insert(Rmap_value(Pkey,m[n*(r+1)-1]));
872         }
873         // proceed from right to left through matrix
874         for (int c=n-2; c>=0; --c) {
875                 Pkey.erase(Pkey.begin(),Pkey.end());  // don't change capacity
876                 Mkey.erase(Mkey.begin(),Mkey.end());
877                 for (unsigned i=0; i<n-c; ++i)
878                         Pkey.push_back(i);
879                 unsigned fc = 0;  // controls logic for our strange flipper counter
880                 do {
881                         det = _ex0();
882                         for (unsigned r=0; r<n-c; ++r) {
883                                 // maybe there is nothing to do?
884                                 if (m[Pkey[r]*n+c].is_zero())
885                                         continue;
886                                 // create the sorted key for all possible minors
887                                 Mkey.erase(Mkey.begin(),Mkey.end());
888                                 for (unsigned i=0; i<n-c; ++i)
889                                         if (i!=r)
890                                                 Mkey.push_back(Pkey[i]);
891                                 // Fetch the minors and compute the new determinant
892                                 if (r%2)
893                                         det -= m[Pkey[r]*n+c]*A[Mkey];
894                                 else
895                                         det += m[Pkey[r]*n+c]*A[Mkey];
896                         }
897                         // prevent build-up of deep nesting of expressions saves time:
898                         det = det.expand();
899                         // store the new determinant at its place in B:
900                         if (!det.is_zero())
901                                 B.insert(Rmap_value(Pkey,det));
902                         // increment our strange flipper counter
903                         for (fc=n-c; fc>0; --fc) {
904                                 ++Pkey[fc-1];
905                                 if (Pkey[fc-1]<fc+c)
906                                         break;
907                         }
908                         if (fc<n-c && fc>0)
909                                 for (unsigned j=fc; j<n-c; ++j)
910                                         Pkey[j] = Pkey[j-1]+1;
911                 } while(fc);
912                 // next column, so change the role of A and B:
913                 A = B;
914                 B.clear();
915         }
916         
917         return det;
918 }
919
920
921 /** Perform the steps of an ordinary Gaussian elimination to bring the m x n
922  *  matrix into an upper echelon form.  The algorithm is ok for matrices
923  *  with numeric coefficients but quite unsuited for symbolic matrices.
924  *
925  *  @param det may be set to true to save a lot of space if one is only
926  *  interested in the diagonal elements (i.e. for calculating determinants).
927  *  The others are set to zero in this case.
928  *  @return sign is 1 if an even number of rows was swapped, -1 if an odd
929  *  number of rows was swapped and 0 if the matrix is singular. */
930 int matrix::gauss_elimination(const bool det)
931 {
932         ensure_if_modifiable();
933         const unsigned m = this->rows();
934         const unsigned n = this->cols();
935         GINAC_ASSERT(!det || n==m);
936         int sign = 1;
937         
938         unsigned r0 = 0;
939         for (unsigned r1=0; (r1<n-1)&&(r0<m-1); ++r1) {
940                 int indx = pivot(r0, r1, true);
941                 if (indx == -1) {
942                         sign = 0;
943                         if (det)
944                                 return 0;  // leaves *this in a messy state
945                 }
946                 if (indx>=0) {
947                         if (indx > 0)
948                                 sign = -sign;
949                         for (unsigned r2=r0+1; r2<m; ++r2) {
950                                 if (!this->m[r2*n+r1].is_zero()) {
951                                         // yes, there is something to do in this row
952                                         ex piv = this->m[r2*n+r1] / this->m[r0*n+r1];
953                                         for (unsigned c=r1+1; c<n; ++c) {
954                                                 this->m[r2*n+c] -= piv * this->m[r0*n+c];
955                                                 if (!this->m[r2*n+c].info(info_flags::numeric))
956                                                         this->m[r2*n+c] = this->m[r2*n+c].normal();
957                                         }
958                                 }
959                                 // fill up left hand side with zeros
960                                 for (unsigned c=0; c<=r1; ++c)
961                                         this->m[r2*n+c] = _ex0();
962                         }
963                         if (det) {
964                                 // save space by deleting no longer needed elements
965                                 for (unsigned c=r0+1; c<n; ++c)
966                                         this->m[r0*n+c] = _ex0();
967                         }
968                         ++r0;
969                 }
970         }
971         
972         return sign;
973 }
974
975
976 /** Perform the steps of division free elimination to bring the m x n matrix
977  *  into an upper echelon form.
978  *
979  *  @param det may be set to true to save a lot of space if one is only
980  *  interested in the diagonal elements (i.e. for calculating determinants).
981  *  The others are set to zero in this case.
982  *  @return sign is 1 if an even number of rows was swapped, -1 if an odd
983  *  number of rows was swapped and 0 if the matrix is singular. */
984 int matrix::division_free_elimination(const bool det)
985 {
986         ensure_if_modifiable();
987         const unsigned m = this->rows();
988         const unsigned n = this->cols();
989         GINAC_ASSERT(!det || n==m);
990         int sign = 1;
991         
992         unsigned r0 = 0;
993         for (unsigned r1=0; (r1<n-1)&&(r0<m-1); ++r1) {
994                 int indx = pivot(r0, r1, true);
995                 if (indx==-1) {
996                         sign = 0;
997                         if (det)
998                                 return 0;  // leaves *this in a messy state
999                 }
1000                 if (indx>=0) {
1001                         if (indx>0)
1002                                 sign = -sign;
1003                         for (unsigned r2=r0+1; r2<m; ++r2) {
1004                                 for (unsigned c=r1+1; c<n; ++c)
1005                                         this->m[r2*n+c] = (this->m[r0*n+r1]*this->m[r2*n+c] - this->m[r2*n+r1]*this->m[r0*n+c]).expand();
1006                                 // fill up left hand side with zeros
1007                                 for (unsigned c=0; c<=r1; ++c)
1008                                         this->m[r2*n+c] = _ex0();
1009                         }
1010                         if (det) {
1011                                 // save space by deleting no longer needed elements
1012                                 for (unsigned c=r0+1; c<n; ++c)
1013                                         this->m[r0*n+c] = _ex0();
1014                         }
1015                         ++r0;
1016                 }
1017         }
1018         
1019         return sign;
1020 }
1021
1022
1023 /** Perform the steps of Bareiss' one-step fraction free elimination to bring
1024  *  the matrix into an upper echelon form.  Fraction free elimination means
1025  *  that divide is used straightforwardly, without computing GCDs first.  This
1026  *  is possible, since we know the divisor at each step.
1027  *  
1028  *  @param det may be set to true to save a lot of space if one is only
1029  *  interested in the last element (i.e. for calculating determinants). The
1030  *  others are set to zero in this case.
1031  *  @return sign is 1 if an even number of rows was swapped, -1 if an odd
1032  *  number of rows was swapped and 0 if the matrix is singular. */
1033 int matrix::fraction_free_elimination(const bool det)
1034 {
1035         // Method:
1036         // (single-step fraction free elimination scheme, already known to Jordan)
1037         //
1038         // Usual division-free elimination sets m[0](r,c) = m(r,c) and then sets
1039         //     m[k+1](r,c) = m[k](k,k) * m[k](r,c) - m[k](r,k) * m[k](k,c).
1040         //
1041         // Bareiss (fraction-free) elimination in addition divides that element
1042         // by m[k-1](k-1,k-1) for k>1, where it can be shown by means of the
1043         // Sylvester determinant that this really divides m[k+1](r,c).
1044         //
1045         // We also allow rational functions where the original prove still holds.
1046         // However, we must care for numerator and denominator separately and
1047         // "manually" work in the integral domains because of subtle cancellations
1048         // (see below).  This blows up the bookkeeping a bit and the formula has
1049         // to be modified to expand like this (N{x} stands for numerator of x,
1050         // D{x} for denominator of x):
1051         //     N{m[k+1](r,c)} = N{m[k](k,k)}*N{m[k](r,c)}*D{m[k](r,k)}*D{m[k](k,c)}
1052         //                     -N{m[k](r,k)}*N{m[k](k,c)}*D{m[k](k,k)}*D{m[k](r,c)}
1053         //     D{m[k+1](r,c)} = D{m[k](k,k)}*D{m[k](r,c)}*D{m[k](r,k)}*D{m[k](k,c)}
1054         // where for k>1 we now divide N{m[k+1](r,c)} by
1055         //     N{m[k-1](k-1,k-1)}
1056         // and D{m[k+1](r,c)} by
1057         //     D{m[k-1](k-1,k-1)}.
1058         
1059         ensure_if_modifiable();
1060         const unsigned m = this->rows();
1061         const unsigned n = this->cols();
1062         GINAC_ASSERT(!det || n==m);
1063         int sign = 1;
1064         if (m==1)
1065                 return 1;
1066         ex divisor_n = 1;
1067         ex divisor_d = 1;
1068         ex dividend_n;
1069         ex dividend_d;
1070         
1071         // We populate temporary matrices to subsequently operate on.  There is
1072         // one holding numerators and another holding denominators of entries.
1073         // This is a must since the evaluator (or even earlier mul's constructor)
1074         // might cancel some trivial element which causes divide() to fail.  The
1075         // elements are normalized first (yes, even though this algorithm doesn't
1076         // need GCDs) since the elements of *this might be unnormalized, which
1077         // makes things more complicated than they need to be.
1078         matrix tmp_n(*this);
1079         matrix tmp_d(m,n);  // for denominators, if needed
1080         lst srl;  // symbol replacement list
1081         exvector::iterator it = this->m.begin();
1082         exvector::iterator tmp_n_it = tmp_n.m.begin();
1083         exvector::iterator tmp_d_it = tmp_d.m.begin();
1084         for (; it!= this->m.end(); ++it, ++tmp_n_it, ++tmp_d_it) {
1085                 (*tmp_n_it) = (*it).normal().to_rational(srl);
1086                 (*tmp_d_it) = (*tmp_n_it).denom();
1087                 (*tmp_n_it) = (*tmp_n_it).numer();
1088         }
1089         
1090         unsigned r0 = 0;
1091         for (unsigned r1=0; (r1<n-1)&&(r0<m-1); ++r1) {
1092                 int indx = tmp_n.pivot(r0, r1, true);
1093                 if (indx==-1) {
1094                         sign = 0;
1095                         if (det)
1096                                 return 0;
1097                 }
1098                 if (indx>=0) {
1099                         if (indx>0) {
1100                                 sign = -sign;
1101                                 // tmp_n's rows r0 and indx were swapped, do the same in tmp_d:
1102                                 for (unsigned c=r1; c<n; ++c)
1103                                         tmp_d.m[n*indx+c].swap(tmp_d.m[n*r0+c]);
1104                         }
1105                         for (unsigned r2=r0+1; r2<m; ++r2) {
1106                                 for (unsigned c=r1+1; c<n; ++c) {
1107                                         dividend_n = (tmp_n.m[r0*n+r1]*tmp_n.m[r2*n+c]*
1108                                                       tmp_d.m[r2*n+r1]*tmp_d.m[r0*n+c]
1109                                                      -tmp_n.m[r2*n+r1]*tmp_n.m[r0*n+c]*
1110                                                       tmp_d.m[r0*n+r1]*tmp_d.m[r2*n+c]).expand();
1111                                         dividend_d = (tmp_d.m[r2*n+r1]*tmp_d.m[r0*n+c]*
1112                                                       tmp_d.m[r0*n+r1]*tmp_d.m[r2*n+c]).expand();
1113                                         bool check = divide(dividend_n, divisor_n,
1114                                                             tmp_n.m[r2*n+c], true);
1115                                         check &= divide(dividend_d, divisor_d,
1116                                                         tmp_d.m[r2*n+c], true);
1117                                         GINAC_ASSERT(check);
1118                                 }
1119                                 // fill up left hand side with zeros
1120                                 for (unsigned c=0; c<=r1; ++c)
1121                                         tmp_n.m[r2*n+c] = _ex0();
1122                         }
1123                         if ((r1<n-1)&&(r0<m-1)) {
1124                                 // compute next iteration's divisor
1125                                 divisor_n = tmp_n.m[r0*n+r1].expand();
1126                                 divisor_d = tmp_d.m[r0*n+r1].expand();
1127                                 if (det) {
1128                                         // save space by deleting no longer needed elements
1129                                         for (unsigned c=0; c<n; ++c) {
1130                                                 tmp_n.m[r0*n+c] = _ex0();
1131                                                 tmp_d.m[r0*n+c] = _ex1();
1132                                         }
1133                                 }
1134                         }
1135                         ++r0;
1136                 }
1137         }
1138         // repopulate *this matrix:
1139         it = this->m.begin();
1140         tmp_n_it = tmp_n.m.begin();
1141         tmp_d_it = tmp_d.m.begin();
1142         for (; it!= this->m.end(); ++it, ++tmp_n_it, ++tmp_d_it)
1143                 (*it) = ((*tmp_n_it)/(*tmp_d_it)).subs(srl);
1144         
1145         return sign;
1146 }
1147
1148
1149 /** Partial pivoting method for matrix elimination schemes.
1150  *  Usual pivoting (symbolic==false) returns the index to the element with the
1151  *  largest absolute value in column ro and swaps the current row with the one
1152  *  where the element was found.  With (symbolic==true) it does the same thing
1153  *  with the first non-zero element.
1154  *
1155  *  @param ro is the row from where to begin
1156  *  @param co is the column to be inspected
1157  *  @param symbolic signal if we want the first non-zero element to be pivoted
1158  *  (true) or the one with the largest absolute value (false).
1159  *  @return 0 if no interchange occured, -1 if all are zero (usually signaling
1160  *  a degeneracy) and positive integer k means that rows ro and k were swapped.
1161  */
1162 int matrix::pivot(unsigned ro, unsigned co, bool symbolic)
1163 {
1164         unsigned k = ro;
1165         if (symbolic) {
1166                 // search first non-zero element in column co beginning at row ro
1167                 while ((k<row) && (this->m[k*col+co].expand().is_zero()))
1168                         ++k;
1169         } else {
1170                 // search largest element in column co beginning at row ro
1171                 GINAC_ASSERT(is_ex_of_type(this->m[k*col+co],numeric));
1172                 unsigned kmax = k+1;
1173                 numeric mmax = abs(ex_to_numeric(m[kmax*col+co]));
1174                 while (kmax<row) {
1175                         GINAC_ASSERT(is_ex_of_type(this->m[kmax*col+co],numeric));
1176                         numeric tmp = ex_to_numeric(this->m[kmax*col+co]);
1177                         if (abs(tmp) > mmax) {
1178                                 mmax = tmp;
1179                                 k = kmax;
1180                         }
1181                         ++kmax;
1182                 }
1183                 if (!mmax.is_zero())
1184                         k = kmax;
1185         }
1186         if (k==row)
1187                 // all elements in column co below row ro vanish
1188                 return -1;
1189         if (k==ro)
1190                 // matrix needs no pivoting
1191                 return 0;
1192         // matrix needs pivoting, so swap rows k and ro
1193         ensure_if_modifiable();
1194         for (unsigned c=0; c<col; ++c)
1195                 this->m[k*col+c].swap(this->m[ro*col+c]);
1196         
1197         return k;
1198 }
1199
1200 /** Convert list of lists to matrix. */
1201 ex lst_to_matrix(const ex &l)
1202 {
1203         if (!is_ex_of_type(l, lst))
1204                 throw(std::invalid_argument("argument to lst_to_matrix() must be a lst"));
1205         
1206         // Find number of rows and columns
1207         unsigned rows = l.nops(), cols = 0, i, j;
1208         for (i=0; i<rows; i++)
1209                 if (l.op(i).nops() > cols)
1210                         cols = l.op(i).nops();
1211         
1212         // Allocate and fill matrix
1213         matrix &m = *new matrix(rows, cols);
1214         for (i=0; i<rows; i++)
1215                 for (j=0; j<cols; j++)
1216                         if (l.op(i).nops() > j)
1217                                 m.set(i, j, l.op(i).op(j));
1218                         else
1219                                 m.set(i, j, ex(0));
1220         return m;
1221 }
1222
1223 //////////
1224 // global constants
1225 //////////
1226
1227 const matrix some_matrix;
1228 const std::type_info & typeid_matrix = typeid(some_matrix);
1229
1230 #ifndef NO_NAMESPACE_GINAC
1231 } // namespace GiNaC
1232 #endif // ndef NO_NAMESPACE_GINAC