matrix matrix::add(const matrix & other) const
{
if (col != other.col || row != other.row)
- throw (std::logic_error("matrix::add(): incompatible matrices"));
+ throw std::logic_error("matrix::add(): incompatible matrices");
exvector sum(this->m);
exvector::iterator i;
matrix matrix::sub(const matrix & other) const
{
if (col != other.col || row != other.row)
- throw (std::logic_error("matrix::sub(): incompatible matrices"));
+ throw std::logic_error("matrix::sub(): incompatible matrices");
exvector dif(this->m);
exvector::iterator i;
matrix matrix::mul(const matrix & other) const
{
if (this->cols() != other.rows())
- throw (std::logic_error("matrix::mul(): incompatible matrices"));
+ throw std::logic_error("matrix::mul(): incompatible matrices");
exvector prod(this->rows()*other.cols());
/** Product of matrix and scalar expression. */
matrix matrix::mul_scalar(const ex & other) const
{
+ if (other.return_type() != return_types::commutative)
+ throw std::runtime_error("matrix::mul_scalar(): non-commutative scalar");
+
exvector prod(row * col);
for (unsigned r=0; r<row; ++r)
}
+/** Power of a matrix. Currently handles integer exponents only. */
+matrix matrix::pow(const ex & expn) const
+{
+ if (col!=row)
+ throw (std::logic_error("matrix::pow(): matrix not square"));
+
+ if (is_ex_exactly_of_type(expn, numeric)) {
+ // Integer cases are computed by successive multiplication, using the
+ // obvious shortcut of storing temporaries, like A^4 == (A*A)*(A*A).
+ if (expn.info(info_flags::integer)) {
+ numeric k;
+ matrix prod(row,col);
+ if (expn.info(info_flags::negative)) {
+ k = -ex_to_numeric(expn);
+ prod = this->inverse();
+ } else {
+ k = ex_to_numeric(expn);
+ prod = *this;
+ }
+ matrix result(row,col);
+ for (unsigned r=0; r<row; ++r)
+ result.set(r,r,_ex1());
+ numeric b(1);
+ // this loop computes the representation of k in base 2 and multiplies
+ // the factors whenever needed:
+ while (b.compare(k)<=0) {
+ b *= numeric(2);
+ numeric r(mod(k,b));
+ if (!r.is_zero()) {
+ k -= r;
+ result = result.mul(prod);
+ }
+ prod = prod.mul(prod);
+ }
+ return result;
+ }
+ }
+ throw (std::runtime_error("matrix::pow(): don't know how to handle exponent"));
+}
+
+
/** operator() to access elements.
*
* @param ro row of element
{
if (ro>=row || co>=col)
throw (std::range_error("matrix::set(): index out of range"));
+ if (value.return_type() != return_types::commutative)
+ throw std::runtime_error("matrix::set(): non-commutative argument");
ensure_if_modifiable();
m[ro*col+co] = value;