]> www.ginac.de Git - ginac.git/blobdiff - ginac/matrix.cpp
- ++version_major.
[ginac.git] / ginac / matrix.cpp
index 65b1254936e0a378eb5ba1760ee9709a3a52c409..28c2f1d25a853faf7750220d1a82867d9f425775 100644 (file)
@@ -156,23 +156,23 @@ void matrix::print(const print_context & c, unsigned level) const
 
        } else {
 
-               c.s << "[";
+               c.s << "[";
                for (unsigned y=0; y<row-1; ++y) {
-                       c.s << "[[";
+                       c.s << "[";
                        for (unsigned x=0; x<col-1; ++x) {
                                m[y*col+x].print(c);
                                c.s << ",";
                        }
                        m[col*(y+1)-1].print(c);
-                       c.s << "]], ";
+                       c.s << "],";
                }
-               c.s << "[[";
+               c.s << "[";
                for (unsigned x=0; x<col-1; ++x) {
                        m[(row-1)*col+x].print(c);
                        c.s << ",";
                }
                m[row*col-1].print(c);
-               c.s << "]] ]]";
+               c.s << "]]";
 
        }
 }
@@ -534,7 +534,7 @@ bool matrix::contract_with(exvector::iterator self, exvector::iterator other, ex
 matrix matrix::add(const matrix & other) const
 {
        if (col != other.col || row != other.row)
-               throw (std::logic_error("matrix::add(): incompatible matrices"));
+               throw std::logic_error("matrix::add(): incompatible matrices");
        
        exvector sum(this->m);
        exvector::iterator i;
@@ -552,7 +552,7 @@ matrix matrix::add(const matrix & other) const
 matrix matrix::sub(const matrix & other) const
 {
        if (col != other.col || row != other.row)
-               throw (std::logic_error("matrix::sub(): incompatible matrices"));
+               throw std::logic_error("matrix::sub(): incompatible matrices");
        
        exvector dif(this->m);
        exvector::iterator i;
@@ -570,7 +570,7 @@ matrix matrix::sub(const matrix & other) const
 matrix matrix::mul(const matrix & other) const
 {
        if (this->cols() != other.rows())
-               throw (std::logic_error("matrix::mul(): incompatible matrices"));
+               throw std::logic_error("matrix::mul(): incompatible matrices");
        
        exvector prod(this->rows()*other.cols());
        
@@ -599,6 +599,63 @@ matrix matrix::mul(const numeric & other) const
 }
 
 
+/** Product of matrix and scalar expression. */
+matrix matrix::mul_scalar(const ex & other) const
+{
+       if (other.return_type() != return_types::commutative)
+               throw std::runtime_error("matrix::mul_scalar(): non-commutative scalar");
+
+       exvector prod(row * col);
+
+       for (unsigned r=0; r<row; ++r)
+               for (unsigned c=0; c<col; ++c)
+                       prod[r*col+c] = m[r*col+c] * other;
+
+       return matrix(row, col, prod);
+}
+
+
+/** Power of a matrix.  Currently handles integer exponents only. */
+matrix matrix::pow(const ex & expn) const
+{
+       if (col!=row)
+               throw (std::logic_error("matrix::pow(): matrix not square"));
+       
+       if (is_ex_exactly_of_type(expn, numeric)) {
+               // Integer cases are computed by successive multiplication, using the
+               // obvious shortcut of storing temporaries, like A^4 == (A*A)*(A*A).
+               if (expn.info(info_flags::integer)) {
+                       numeric k;
+                       matrix prod(row,col);
+                       if (expn.info(info_flags::negative)) {
+                               k = -ex_to_numeric(expn);
+                               prod = this->inverse();
+                       } else {
+                               k = ex_to_numeric(expn);
+                               prod = *this;
+                       }
+                       matrix result(row,col);
+                       for (unsigned r=0; r<row; ++r)
+                               result.set(r,r,_ex1());
+                       numeric b(1);
+                       // this loop computes the representation of k in base 2 and multiplies
+                       // the factors whenever needed:
+                       while (b.compare(k)<=0) {
+                               b *= numeric(2);
+                               numeric r(mod(k,b));
+                               if (!r.is_zero()) {
+                                       k -= r;
+                                       result = result.mul(prod);
+                               }
+                               prod = prod.mul(prod);
+                       }
+                       return result;
+               }
+       }
+       throw (std::runtime_error("matrix::pow(): don't know how to handle exponent"));
+}
+
+
 /** operator() to access elements.
  *
  *  @param ro row of element
@@ -620,6 +677,8 @@ matrix & matrix::set(unsigned ro, unsigned co, ex value)
 {
        if (ro>=row || co>=col)
                throw (std::range_error("matrix::set(): index out of range"));
+       if (value.return_type() != return_types::commutative)
+               throw std::runtime_error("matrix::set(): non-commutative argument");
     
        ensure_if_modifiable();
        m[ro*col+co] = value;
@@ -640,7 +699,6 @@ matrix matrix::transpose(void) const
        return matrix(this->cols(),this->rows(),trans);
 }
 
-
 /** Determinant of square matrix.  This routine doesn't actually calculate the
  *  determinant, it only implements some heuristics about which algorithm to
  *  run.  If all the elements of the matrix are elements of an integral domain
@@ -756,7 +814,8 @@ ex matrix::determinant(unsigned algo) const
                        std::vector<unsigned> pre_sort;
                        for (std::vector<uintpair>::iterator i=c_zeros.begin(); i!=c_zeros.end(); ++i)
                                pre_sort.push_back(i->second);
-                       int sign = permutation_sign(pre_sort);
+                       std::vector<unsigned> pre_sort_test(pre_sort); // permutation_sign() modifies the vector so we make a copy here
+                       int sign = permutation_sign(pre_sort_test.begin(), pre_sort_test.end());
                        exvector result(row*col);  // represents sorted matrix
                        unsigned c = 0;
                        for (std::vector<unsigned>::iterator i=pre_sort.begin();