]> www.ginac.de Git - ginac.git/blobdiff - ginac/matrix.cpp
- matrix::pow(): omit last big multiplication if it's not needed.
[ginac.git] / ginac / matrix.cpp
index 5767d4cfb924f91e4ec5adc7c43a59cecc409f05..872b8e006012c3a4ecf6279cf315ed8c34e5f793 100644 (file)
@@ -534,7 +534,7 @@ bool matrix::contract_with(exvector::iterator self, exvector::iterator other, ex
 matrix matrix::add(const matrix & other) const
 {
        if (col != other.col || row != other.row)
-               throw (std::logic_error("matrix::add(): incompatible matrices"));
+               throw std::logic_error("matrix::add(): incompatible matrices");
        
        exvector sum(this->m);
        exvector::iterator i;
@@ -552,7 +552,7 @@ matrix matrix::add(const matrix & other) const
 matrix matrix::sub(const matrix & other) const
 {
        if (col != other.col || row != other.row)
-               throw (std::logic_error("matrix::sub(): incompatible matrices"));
+               throw std::logic_error("matrix::sub(): incompatible matrices");
        
        exvector dif(this->m);
        exvector::iterator i;
@@ -570,7 +570,7 @@ matrix matrix::sub(const matrix & other) const
 matrix matrix::mul(const matrix & other) const
 {
        if (this->cols() != other.rows())
-               throw (std::logic_error("matrix::mul(): incompatible matrices"));
+               throw std::logic_error("matrix::mul(): incompatible matrices");
        
        exvector prod(this->rows()*other.cols());
        
@@ -602,6 +602,9 @@ matrix matrix::mul(const numeric & other) const
 /** Product of matrix and scalar expression. */
 matrix matrix::mul_scalar(const ex & other) const
 {
+       if (other.return_type() != return_types::commutative)
+               throw std::runtime_error("matrix::mul_scalar(): non-commutative scalar");
+
        exvector prod(row * col);
 
        for (unsigned r=0; r<row; ++r)
@@ -612,7 +615,49 @@ matrix matrix::mul_scalar(const ex & other) const
 }
 
 
-/** operator() to access elements.
+/** Power of a matrix.  Currently handles integer exponents only. */
+matrix matrix::pow(const ex & expn) const
+{
+       if (col!=row)
+               throw (std::logic_error("matrix::pow(): matrix not square"));
+       
+       if (is_ex_exactly_of_type(expn, numeric)) {
+               // Integer cases are computed by successive multiplication, using the
+               // obvious shortcut of storing temporaries, like A^4 == (A*A)*(A*A).
+               if (expn.info(info_flags::integer)) {
+                       numeric k;
+                       matrix prod(row,col);
+                       if (expn.info(info_flags::negative)) {
+                               k = -ex_to_numeric(expn);
+                               prod = this->inverse();
+                       } else {
+                               k = ex_to_numeric(expn);
+                               prod = *this;
+                       }
+                       matrix result(row,col);
+                       for (unsigned r=0; r<row; ++r)
+                               result(r,r) = _ex1();
+                       numeric b(1);
+                       // this loop computes the representation of k in base 2 and
+                       // multiplies the factors whenever needed:
+                       while (b.compare(k)<=0) {
+                               b *= numeric(2);
+                               numeric r(mod(k,b));
+                               if (!r.is_zero()) {
+                                       k -= r;
+                                       result = result.mul(prod);
+                               }
+                               if (b.compare(k)<=0)
+                                       prod = prod.mul(prod);
+                       }
+                       return result;
+               }
+       }
+       throw (std::runtime_error("matrix::pow(): don't know how to handle exponent"));
+}
+
+
+/** operator() to access elements for reading.
  *
  *  @param ro row of element
  *  @param co column of element
@@ -626,17 +671,19 @@ const ex & matrix::operator() (unsigned ro, unsigned co) const
 }
 
 
-/** Set individual elements manually.
+/** operator() to access elements for writing.
  *
+ *  @param ro row of element
+ *  @param co column of element
  *  @exception range_error (index out of range) */
-matrix & matrix::set(unsigned ro, unsigned co, ex value)
+ex & matrix::operator() (unsigned ro, unsigned co)
 {
        if (ro>=row || co>=col)
-               throw (std::range_error("matrix::set(): index out of range"));
-    
+               throw (std::range_error("matrix::operator(): index out of range"));
+
        ensure_if_modifiable();
-       m[ro*col+co] = value;
-       return *this;
+       clearflag(status_flags::hash_calculated);
+       return m[ro*col+co];
 }
 
 
@@ -878,7 +925,7 @@ matrix matrix::inverse(void) const
        // First populate the identity matrix supposed to become the right hand side.
        matrix identity(row,col);
        for (unsigned i=0; i<row; ++i)
-               identity.set(i,i,_ex1());
+               identity(i,i) = _ex1();
        
        // Populate a dummy matrix of variables, just because of compatibility with
        // matrix::solve() which wants this (for compatibility with under-determined
@@ -886,7 +933,7 @@ matrix matrix::inverse(void) const
        matrix vars(row,col);
        for (unsigned r=0; r<row; ++r)
                for (unsigned c=0; c<col; ++c)
-                       vars.set(r,c,symbol());
+                       vars(r,c) = symbol();
        
        matrix sol(row,col);
        try {
@@ -986,19 +1033,18 @@ matrix matrix::solve(const matrix & vars,
                                // assign solutions for vars between fnz+1 and
                                // last_assigned_sol-1: free parameters
                                for (unsigned c=fnz; c<last_assigned_sol-1; ++c)
-                                       sol.set(c,co,vars.m[c*p+co]);
+                                       sol(c,co) = vars.m[c*p+co];
                                ex e = aug.m[r*(n+p)+n+co];
                                for (unsigned c=fnz; c<n; ++c)
                                        e -= aug.m[r*(n+p)+c]*sol.m[c*p+co];
-                               sol.set(fnz-1,co,
-                                               (e/(aug.m[r*(n+p)+(fnz-1)])).normal());
+                               sol(fnz-1,co) = (e/(aug.m[r*(n+p)+(fnz-1)])).normal();
                                last_assigned_sol = fnz;
                        }
                }
                // assign solutions for vars between 1 and
                // last_assigned_sol-1: free parameters
                for (unsigned ro=0; ro<last_assigned_sol-1; ++ro)
-                       sol.set(ro,co,vars(ro,co));
+                       sol(ro,co) = vars(ro,co);
        }
        
        return sol;
@@ -1042,9 +1088,9 @@ ex matrix::determinant_minor(void) const
        //     for (unsigned r=0; r<minorM.rows(); ++r) {
        //         for (unsigned c=0; c<minorM.cols(); ++c) {
        //             if (r<r1)
-       //                 minorM.set(r,c,m[r*col+c+1]);
+       //                 minorM(r,c) = m[r*col+c+1];
        //             else
-       //                 minorM.set(r,c,m[(r+1)*col+c+1]);
+       //                 minorM(r,c) = m[(r+1)*col+c+1];
        //         }
        //     }
        //     // recurse down and care for sign:
@@ -1421,9 +1467,9 @@ ex lst_to_matrix(const lst & l)
        for (i=0; i<rows; i++)
                for (j=0; j<cols; j++)
                        if (l.op(i).nops() > j)
-                               m.set(i, j, l.op(i).op(j));
+                               m(i, j) = l.op(i).op(j);
                        else
-                               m.set(i, j, ex(0));
+                               m(i, j) = _ex0();
        return m;
 }
 
@@ -1434,7 +1480,7 @@ ex diag_matrix(const lst & l)
        matrix &m = *new matrix(dim, dim);
        m.setflag(status_flags::dynallocated);
        for (unsigned i=0; i<dim; i++)
-               m.set(i, i, l.op(i));
+               m(i, i) = l.op(i);
 
        return m;
 }