]> www.ginac.de Git - ginac.git/blobdiff - ginac/matrix.cpp
- matrix::pow(): omit last big multiplication if it's not needed.
[ginac.git] / ginac / matrix.cpp
index cc80f7cae43b66d1223f0364acdac57ba533cf9f..872b8e006012c3a4ecf6279cf315ed8c34e5f793 100644 (file)
 #include <stdexcept>
 
 #include "matrix.h"
-#include "archive.h"
 #include "numeric.h"
 #include "lst.h"
 #include "idx.h"
 #include "indexed.h"
-#include "utils.h"
-#include "debugmsg.h"
 #include "power.h"
 #include "symbol.h"
 #include "normal.h"
+#include "print.h"
+#include "archive.h"
+#include "utils.h"
+#include "debugmsg.h"
 
 namespace GiNaC {
 
@@ -44,8 +45,6 @@ GINAC_IMPLEMENT_REGISTERED_CLASS(matrix, basic)
 // default ctor, dtor, copy ctor, assignment operator and helpers:
 //////////
 
-// public
-
 /** Default ctor.  Initializes to 1 x 1-dimensional zero-matrix. */
 matrix::matrix() : inherited(TINFO_matrix), row(1), col(1)
 {
@@ -53,9 +52,6 @@ matrix::matrix() : inherited(TINFO_matrix), row(1), col(1)
        m.push_back(_ex0());
 }
 
-// protected
-
-/** For use by copy ctor and assignment operator. */
 void matrix::copy(const matrix & other)
 {
        inherited::copy(other);
@@ -64,10 +60,7 @@ void matrix::copy(const matrix & other)
        m = other.m;  // STL's vector copying invoked here
 }
 
-void matrix::destroy(bool call_parent)
-{
-       if (call_parent) inherited::destroy(call_parent);
-}
+DEFAULT_DESTROY(matrix)
 
 //////////
 // other ctors
@@ -118,7 +111,6 @@ matrix::matrix(unsigned r, unsigned c, const lst & l)
 // archiving
 //////////
 
-/** Construct object from archive_node. */
 matrix::matrix(const archive_node &n, const lst &sym_lst) : inherited(n, sym_lst)
 {
        debugmsg("matrix ctor from archive_node", LOGLEVEL_CONSTRUCT);
@@ -134,13 +126,6 @@ matrix::matrix(const archive_node &n, const lst &sym_lst) : inherited(n, sym_lst
        }
 }
 
-/** Unarchive the object. */
-ex matrix::unarchive(const archive_node &n, const lst &sym_lst)
-{
-       return (new matrix(n, sym_lst))->setflag(status_flags::dynallocated);
-}
-
-/** Archive the object. */
 void matrix::archive(archive_node &n) const
 {
        inherited::archive(n);
@@ -153,42 +138,43 @@ void matrix::archive(archive_node &n) const
        }
 }
 
+DEFAULT_UNARCHIVE(matrix)
+
 //////////
 // functions overriding virtual functions from bases classes
 //////////
 
 // public
 
-void matrix::print(std::ostream & os, unsigned upper_precedence) const
+void matrix::print(const print_context & c, unsigned level) const
 {
-       debugmsg("matrix print",LOGLEVEL_PRINT);
-       os << "[[ ";
-       for (unsigned r=0; r<row-1; ++r) {
-               os << "[[";
-               for (unsigned c=0; c<col-1; ++c)
-                       os << m[r*col+c] << ",";
-               os << m[col*(r+1)-1] << "]], ";
-       }
-       os << "[[";
-       for (unsigned c=0; c<col-1; ++c)
-               os << m[(row-1)*col+c] << ",";
-       os << m[row*col-1] << "]] ]]";
-}
+       debugmsg("matrix print", LOGLEVEL_PRINT);
+
+       if (is_of_type(c, print_tree)) {
+
+               inherited::print(c, level);
+
+       } else {
+
+               c.s << "[";
+               for (unsigned y=0; y<row-1; ++y) {
+                       c.s << "[";
+                       for (unsigned x=0; x<col-1; ++x) {
+                               m[y*col+x].print(c);
+                               c.s << ",";
+                       }
+                       m[col*(y+1)-1].print(c);
+                       c.s << "],";
+               }
+               c.s << "[";
+               for (unsigned x=0; x<col-1; ++x) {
+                       m[(row-1)*col+x].print(c);
+                       c.s << ",";
+               }
+               m[row*col-1].print(c);
+               c.s << "]]";
 
-void matrix::printraw(std::ostream & os) const
-{
-       debugmsg("matrix printraw",LOGLEVEL_PRINT);
-       os << class_name() << "(" << row << "," << col <<",";
-       for (unsigned r=0; r<row-1; ++r) {
-               os << "(";
-               for (unsigned c=0; c<col-1; ++c)
-                       os << m[r*col+c] << ",";
-               os << m[col*(r-1)-1] << "),";
        }
-       os << "(";
-       for (unsigned c=0; c<col-1; ++c)
-               os << m[(row-1)*col+c] << ",";
-       os << m[row*col-1] << "))";
 }
 
 /** nops is defined to be rows x columns. */
@@ -222,22 +208,6 @@ ex matrix::expand(unsigned options) const
        return matrix(row, col, tmp);
 }
 
-/** Search ocurrences.  A matrix 'has' an expression if it is the expression
- *  itself or one of the elements 'has' it. */
-bool matrix::has(const ex & other) const
-{
-       GINAC_ASSERT(other.bp!=0);
-       
-       // tautology: it is the expression itself
-       if (is_equal(*other.bp)) return true;
-       
-       // search all the elements
-       for (exvector::const_iterator r=m.begin(); r!=m.end(); ++r)
-               if ((*r).has(other)) return true;
-       
-       return false;
-}
-
 /** Evaluate matrix entry by entry. */
 ex matrix::eval(int level) const
 {
@@ -286,14 +256,14 @@ ex matrix::evalf(int level) const
        return matrix(row, col, m2);
 }
 
-ex matrix::subs(const lst & ls, const lst & lr) const
+ex matrix::subs(const lst & ls, const lst & lr, bool no_pattern) const
 {
        exvector m2(row * col);
        for (unsigned r=0; r<row; ++r)
                for (unsigned c=0; c<col; ++c)
-                       m2[r*col+c] = m[r*col+c].subs(ls, lr);
+                       m2[r*col+c] = m[r*col+c].subs(ls, lr, no_pattern);
 
-       return matrix(row, col, m2);
+       return ex(matrix(row, col, m2)).bp->basic::subs(ls, lr, no_pattern);
 }
 
 // protected
@@ -564,7 +534,7 @@ bool matrix::contract_with(exvector::iterator self, exvector::iterator other, ex
 matrix matrix::add(const matrix & other) const
 {
        if (col != other.col || row != other.row)
-               throw (std::logic_error("matrix::add(): incompatible matrices"));
+               throw std::logic_error("matrix::add(): incompatible matrices");
        
        exvector sum(this->m);
        exvector::iterator i;
@@ -582,7 +552,7 @@ matrix matrix::add(const matrix & other) const
 matrix matrix::sub(const matrix & other) const
 {
        if (col != other.col || row != other.row)
-               throw (std::logic_error("matrix::sub(): incompatible matrices"));
+               throw std::logic_error("matrix::sub(): incompatible matrices");
        
        exvector dif(this->m);
        exvector::iterator i;
@@ -600,7 +570,7 @@ matrix matrix::sub(const matrix & other) const
 matrix matrix::mul(const matrix & other) const
 {
        if (this->cols() != other.rows())
-               throw (std::logic_error("matrix::mul(): incompatible matrices"));
+               throw std::logic_error("matrix::mul(): incompatible matrices");
        
        exvector prod(this->rows()*other.cols());
        
@@ -629,7 +599,65 @@ matrix matrix::mul(const numeric & other) const
 }
 
 
-/** operator() to access elements.
+/** Product of matrix and scalar expression. */
+matrix matrix::mul_scalar(const ex & other) const
+{
+       if (other.return_type() != return_types::commutative)
+               throw std::runtime_error("matrix::mul_scalar(): non-commutative scalar");
+
+       exvector prod(row * col);
+
+       for (unsigned r=0; r<row; ++r)
+               for (unsigned c=0; c<col; ++c)
+                       prod[r*col+c] = m[r*col+c] * other;
+
+       return matrix(row, col, prod);
+}
+
+
+/** Power of a matrix.  Currently handles integer exponents only. */
+matrix matrix::pow(const ex & expn) const
+{
+       if (col!=row)
+               throw (std::logic_error("matrix::pow(): matrix not square"));
+       
+       if (is_ex_exactly_of_type(expn, numeric)) {
+               // Integer cases are computed by successive multiplication, using the
+               // obvious shortcut of storing temporaries, like A^4 == (A*A)*(A*A).
+               if (expn.info(info_flags::integer)) {
+                       numeric k;
+                       matrix prod(row,col);
+                       if (expn.info(info_flags::negative)) {
+                               k = -ex_to_numeric(expn);
+                               prod = this->inverse();
+                       } else {
+                               k = ex_to_numeric(expn);
+                               prod = *this;
+                       }
+                       matrix result(row,col);
+                       for (unsigned r=0; r<row; ++r)
+                               result(r,r) = _ex1();
+                       numeric b(1);
+                       // this loop computes the representation of k in base 2 and
+                       // multiplies the factors whenever needed:
+                       while (b.compare(k)<=0) {
+                               b *= numeric(2);
+                               numeric r(mod(k,b));
+                               if (!r.is_zero()) {
+                                       k -= r;
+                                       result = result.mul(prod);
+                               }
+                               if (b.compare(k)<=0)
+                                       prod = prod.mul(prod);
+                       }
+                       return result;
+               }
+       }
+       throw (std::runtime_error("matrix::pow(): don't know how to handle exponent"));
+}
+
+
+/** operator() to access elements for reading.
  *
  *  @param ro row of element
  *  @param co column of element
@@ -643,17 +671,19 @@ const ex & matrix::operator() (unsigned ro, unsigned co) const
 }
 
 
-/** Set individual elements manually.
+/** operator() to access elements for writing.
  *
+ *  @param ro row of element
+ *  @param co column of element
  *  @exception range_error (index out of range) */
-matrix & matrix::set(unsigned ro, unsigned co, ex value)
+ex & matrix::operator() (unsigned ro, unsigned co)
 {
        if (ro>=row || co>=col)
-               throw (std::range_error("matrix::set(): index out of range"));
-    
+               throw (std::range_error("matrix::operator(): index out of range"));
+
        ensure_if_modifiable();
-       m[ro*col+co] = value;
-       return *this;
+       clearflag(status_flags::hash_calculated);
+       return m[ro*col+co];
 }
 
 
@@ -670,7 +700,6 @@ matrix matrix::transpose(void) const
        return matrix(this->cols(),this->rows(),trans);
 }
 
-
 /** Determinant of square matrix.  This routine doesn't actually calculate the
  *  determinant, it only implements some heuristics about which algorithm to
  *  run.  If all the elements of the matrix are elements of an integral domain
@@ -786,7 +815,8 @@ ex matrix::determinant(unsigned algo) const
                        std::vector<unsigned> pre_sort;
                        for (std::vector<uintpair>::iterator i=c_zeros.begin(); i!=c_zeros.end(); ++i)
                                pre_sort.push_back(i->second);
-                       int sign = permutation_sign(pre_sort);
+                       std::vector<unsigned> pre_sort_test(pre_sort); // permutation_sign() modifies the vector so we make a copy here
+                       int sign = permutation_sign(pre_sort_test.begin(), pre_sort_test.end());
                        exvector result(row*col);  // represents sorted matrix
                        unsigned c = 0;
                        for (std::vector<unsigned>::iterator i=pre_sort.begin();
@@ -889,50 +919,32 @@ matrix matrix::inverse(void) const
        if (row != col)
                throw (std::logic_error("matrix::inverse(): matrix not square"));
        
-       // NOTE: the Gauss-Jordan elimination used here can in principle be
-       // replaced by two clever calls to gauss_elimination() and some to
-       // transpose().  Wouldn't be more efficient (maybe less?), just more
-       // orthogonal.
-       matrix tmp(row,col);
-       // set tmp to the unit matrix
-       for (unsigned i=0; i<col; ++i)
-               tmp.m[i*col+i] = _ex1();
+       // This routine actually doesn't do anything fancy at all.  We compute the
+       // inverse of the matrix A by solving the system A * A^{-1} == Id.
        
-       // create a copy of this matrix
-       matrix cpy(*this);
-       for (unsigned r1=0; r1<row; ++r1) {
-               int indx = cpy.pivot(r1, r1);
-               if (indx == -1) {
+       // First populate the identity matrix supposed to become the right hand side.
+       matrix identity(row,col);
+       for (unsigned i=0; i<row; ++i)
+               identity(i,i) = _ex1();
+       
+       // Populate a dummy matrix of variables, just because of compatibility with
+       // matrix::solve() which wants this (for compatibility with under-determined
+       // systems of equations).
+       matrix vars(row,col);
+       for (unsigned r=0; r<row; ++r)
+               for (unsigned c=0; c<col; ++c)
+                       vars(r,c) = symbol();
+       
+       matrix sol(row,col);
+       try {
+               sol = this->solve(vars,identity);
+       } catch (const std::runtime_error & e) {
+           if (e.what()==std::string("matrix::solve(): inconsistent linear system"))
                        throw (std::runtime_error("matrix::inverse(): singular matrix"));
-               }
-               if (indx != 0) {  // swap rows r and indx of matrix tmp
-                       for (unsigned i=0; i<col; ++i)
-                               tmp.m[r1*col+i].swap(tmp.m[indx*col+i]);
-               }
-               ex a1 = cpy.m[r1*col+r1];
-               for (unsigned c=0; c<col; ++c) {
-                       cpy.m[r1*col+c] /= a1;
-                       tmp.m[r1*col+c] /= a1;
-               }
-               for (unsigned r2=0; r2<row; ++r2) {
-                       if (r2 != r1) {
-                               if (!cpy.m[r2*col+r1].is_zero()) {
-                                       ex a2 = cpy.m[r2*col+r1];
-                                       // yes, there is something to do in this column
-                                       for (unsigned c=0; c<col; ++c) {
-                                               cpy.m[r2*col+c] -= a2 * cpy.m[r1*col+c];
-                                               if (!cpy.m[r2*col+c].info(info_flags::numeric))
-                                                       cpy.m[r2*col+c] = cpy.m[r2*col+c].normal();
-                                               tmp.m[r2*col+c] -= a2 * tmp.m[r1*col+c];
-                                               if (!tmp.m[r2*col+c].info(info_flags::numeric))
-                                                       tmp.m[r2*col+c] = tmp.m[r2*col+c].normal();
-                                       }
-                               }
-                       }
-               }
+               else
+                       throw;
        }
-       
-       return tmp;
+       return sol;
 }
 
 
@@ -995,8 +1007,10 @@ matrix matrix::solve(const matrix & vars,
        switch(algo) {
                case solve_algo::gauss:
                        aug.gauss_elimination();
+                       break;
                case solve_algo::divfree:
                        aug.division_free_elimination();
+                       break;
                case solve_algo::bareiss:
                default:
                        aug.fraction_free_elimination();
@@ -1019,19 +1033,18 @@ matrix matrix::solve(const matrix & vars,
                                // assign solutions for vars between fnz+1 and
                                // last_assigned_sol-1: free parameters
                                for (unsigned c=fnz; c<last_assigned_sol-1; ++c)
-                                       sol.set(c,co,vars.m[c*p+co]);
+                                       sol(c,co) = vars.m[c*p+co];
                                ex e = aug.m[r*(n+p)+n+co];
                                for (unsigned c=fnz; c<n; ++c)
                                        e -= aug.m[r*(n+p)+c]*sol.m[c*p+co];
-                               sol.set(fnz-1,co,
-                                               (e/(aug.m[r*(n+p)+(fnz-1)])).normal());
+                               sol(fnz-1,co) = (e/(aug.m[r*(n+p)+(fnz-1)])).normal();
                                last_assigned_sol = fnz;
                        }
                }
                // assign solutions for vars between 1 and
                // last_assigned_sol-1: free parameters
                for (unsigned ro=0; ro<last_assigned_sol-1; ++ro)
-                       sol.set(ro,co,vars(ro,co));
+                       sol(ro,co) = vars(ro,co);
        }
        
        return sol;
@@ -1075,9 +1088,9 @@ ex matrix::determinant_minor(void) const
        //     for (unsigned r=0; r<minorM.rows(); ++r) {
        //         for (unsigned c=0; c<minorM.cols(); ++c) {
        //             if (r<r1)
-       //                 minorM.set(r,c,m[r*col+c+1]);
+       //                 minorM(r,c) = m[r*col+c+1];
        //             else
-       //                 minorM.set(r,c,m[(r+1)*col+c+1]);
+       //                 minorM(r,c) = m[(r+1)*col+c+1];
        //         }
        //     }
        //     // recurse down and care for sign:
@@ -1454,9 +1467,9 @@ ex lst_to_matrix(const lst & l)
        for (i=0; i<rows; i++)
                for (j=0; j<cols; j++)
                        if (l.op(i).nops() > j)
-                               m.set(i, j, l.op(i).op(j));
+                               m(i, j) = l.op(i).op(j);
                        else
-                               m.set(i, j, ex(0));
+                               m(i, j) = _ex0();
        return m;
 }
 
@@ -1467,7 +1480,7 @@ ex diag_matrix(const lst & l)
        matrix &m = *new matrix(dim, dim);
        m.setflag(status_flags::dynallocated);
        for (unsigned i=0; i<dim; i++)
-               m.set(i, i, l.op(i));
+               m(i, i) = l.op(i);
 
        return m;
 }