X-Git-Url: https://www.ginac.de/ginac.git//ginac.git?a=blobdiff_plain;f=ginac%2Fmatrix.cpp;h=60869f09fa709bbd17b051206fa4b32c29db0e73;hb=ff7eeaa487dca86b36ddbb19b1dcefb2134327be;hp=65b1254936e0a378eb5ba1760ee9709a3a52c409;hpb=b9cd4b49ffbfbf3e1c36a2b594ec3148a5baca64;p=ginac.git diff --git a/ginac/matrix.cpp b/ginac/matrix.cpp index 65b12549..60869f09 100644 --- a/ginac/matrix.cpp +++ b/ginac/matrix.cpp @@ -141,7 +141,7 @@ void matrix::archive(archive_node &n) const DEFAULT_UNARCHIVE(matrix) ////////// -// functions overriding virtual functions from bases classes +// functions overriding virtual functions from base classes ////////// // public @@ -156,23 +156,23 @@ void matrix::print(const print_context & c, unsigned level) const } else { - c.s << "[[ "; + c.s << "["; for (unsigned y=0; y(const_cast(other)); + const matrix & o = static_cast(other); // compare number of rows if (row != o.rows()) @@ -293,6 +259,16 @@ int matrix::compare_same_type(const basic & other) const return 0; } +bool matrix::match_same_type(const basic & other) const +{ + GINAC_ASSERT(is_exactly_of_type(other, matrix)); + const matrix & o = static_cast(other); + + // The number of rows and columns must be the same. This is necessary to + // prevent a 2x3 matrix from matching a 3x2 one. + return row == o.rows() && col == o.cols(); +} + /** Automatic symbolic evaluation of an indexed matrix. */ ex matrix::eval_indexed(const basic & i) const { @@ -308,7 +284,7 @@ ex matrix::eval_indexed(const basic & i) const if (row != 1 && col != 1) throw (std::runtime_error("matrix::eval_indexed(): vector must have exactly 1 index")); - const idx & i1 = ex_to_idx(i.op(1)); + const idx & i1 = ex_to(i.op(1)); if (col == 1) { @@ -318,7 +294,7 @@ ex matrix::eval_indexed(const basic & i) const // Index numeric -> return vector element if (all_indices_unsigned) { - unsigned n1 = ex_to_numeric(i1.get_value()).to_int(); + unsigned n1 = ex_to(i1.get_value()).to_int(); if (n1 >= row) throw (std::runtime_error("matrix::eval_indexed(): value of index exceeds number of vector elements")); return (*this)(n1, 0); @@ -332,7 +308,7 @@ ex matrix::eval_indexed(const basic & i) const // Index numeric -> return vector element if (all_indices_unsigned) { - unsigned n1 = ex_to_numeric(i1.get_value()).to_int(); + unsigned n1 = ex_to(i1.get_value()).to_int(); if (n1 >= col) throw (std::runtime_error("matrix::eval_indexed(): value of index exceeds number of vector elements")); return (*this)(0, n1); @@ -342,8 +318,8 @@ ex matrix::eval_indexed(const basic & i) const } else if (i.nops() == 3) { // Two indices - const idx & i1 = ex_to_idx(i.op(1)); - const idx & i2 = ex_to_idx(i.op(2)); + const idx & i1 = ex_to(i.op(1)); + const idx & i2 = ex_to(i.op(2)); if (!i1.get_dim().is_equal(row)) throw (std::runtime_error("matrix::eval_indexed(): dimension of first index must match number of rows")); @@ -356,7 +332,7 @@ ex matrix::eval_indexed(const basic & i) const // Both indices numeric -> return matrix element if (all_indices_unsigned) { - unsigned n1 = ex_to_numeric(i1.get_value()).to_int(), n2 = ex_to_numeric(i2.get_value()).to_int(); + unsigned n1 = ex_to(i1.get_value()).to_int(), n2 = ex_to(i2.get_value()).to_int(); if (n1 >= row) throw (std::runtime_error("matrix::eval_indexed(): value of first index exceeds number of rows")); if (n2 >= col) @@ -382,8 +358,8 @@ ex matrix::add_indexed(const ex & self, const ex & other) const if (is_ex_of_type(other.op(0), matrix)) { GINAC_ASSERT(other.nops() == 2 || other.nops() == 3); - const matrix &self_matrix = ex_to_matrix(self.op(0)); - const matrix &other_matrix = ex_to_matrix(other.op(0)); + const matrix &self_matrix = ex_to(self.op(0)); + const matrix &other_matrix = ex_to(other.op(0)); if (self.nops() == 2 && other.nops() == 2) { // vector + vector @@ -413,7 +389,7 @@ ex matrix::scalar_mul_indexed(const ex & self, const numeric & other) const GINAC_ASSERT(is_ex_of_type(self.op(0), matrix)); GINAC_ASSERT(self.nops() == 2 || self.nops() == 3); - const matrix &self_matrix = ex_to_matrix(self.op(0)); + const matrix &self_matrix = ex_to(self.op(0)); if (self.nops() == 2) return indexed(self_matrix.mul(other), self.op(1)); @@ -435,14 +411,12 @@ bool matrix::contract_with(exvector::iterator self, exvector::iterator other, ex GINAC_ASSERT(other->nops() == 2 || other->nops() == 3); - const matrix &self_matrix = ex_to_matrix(self->op(0)); - const matrix &other_matrix = ex_to_matrix(other->op(0)); + const matrix &self_matrix = ex_to(self->op(0)); + const matrix &other_matrix = ex_to(other->op(0)); if (self->nops() == 2) { - unsigned self_dim = (self_matrix.col == 1) ? self_matrix.row : self_matrix.col; if (other->nops() == 2) { // vector * vector (scalar product) - unsigned other_dim = (other_matrix.col == 1) ? other_matrix.row : other_matrix.col; if (self_matrix.col == 1) { if (other_matrix.col == 1) { @@ -534,13 +508,13 @@ bool matrix::contract_with(exvector::iterator self, exvector::iterator other, ex matrix matrix::add(const matrix & other) const { if (col != other.col || row != other.row) - throw (std::logic_error("matrix::add(): incompatible matrices")); + throw std::logic_error("matrix::add(): incompatible matrices"); exvector sum(this->m); - exvector::iterator i; - exvector::const_iterator ci; - for (i=sum.begin(), ci=other.m.begin(); i!=sum.end(); ++i, ++ci) - (*i) += (*ci); + exvector::iterator i = sum.begin(), end = sum.end(); + exvector::const_iterator ci = other.m.begin(); + while (i != end) + *i++ += *ci++; return matrix(row,col,sum); } @@ -552,13 +526,13 @@ matrix matrix::add(const matrix & other) const matrix matrix::sub(const matrix & other) const { if (col != other.col || row != other.row) - throw (std::logic_error("matrix::sub(): incompatible matrices")); + throw std::logic_error("matrix::sub(): incompatible matrices"); exvector dif(this->m); - exvector::iterator i; - exvector::const_iterator ci; - for (i=dif.begin(), ci=other.m.begin(); i!=dif.end(); ++i, ++ci) - (*i) -= (*ci); + exvector::iterator i = dif.begin(), end = dif.end(); + exvector::const_iterator ci = other.m.begin(); + while (i != end) + *i++ -= *ci++; return matrix(row,col,dif); } @@ -570,7 +544,7 @@ matrix matrix::sub(const matrix & other) const matrix matrix::mul(const matrix & other) const { if (this->cols() != other.rows()) - throw (std::logic_error("matrix::mul(): incompatible matrices")); + throw std::logic_error("matrix::mul(): incompatible matrices"); exvector prod(this->rows()*other.cols()); @@ -599,7 +573,64 @@ matrix matrix::mul(const numeric & other) const } -/** operator() to access elements. +/** Product of matrix and scalar expression. */ +matrix matrix::mul_scalar(const ex & other) const +{ + if (other.return_type() != return_types::commutative) + throw std::runtime_error("matrix::mul_scalar(): non-commutative scalar"); + + exvector prod(row * col); + + for (unsigned r=0; r(expn); + matrix A(row,col); + if (expn.info(info_flags::negative)) { + b *= -1; + A = this->inverse(); + } else { + A = *this; + } + matrix C(row,col); + for (unsigned r=0; r=row || co>=col) - throw (std::range_error("matrix::set(): index out of range")); - + throw (std::range_error("matrix::operator(): index out of range")); + ensure_if_modifiable(); - m[ro*col+co] = value; - return *this; + return m[ro*col+co]; } @@ -640,7 +672,6 @@ matrix matrix::transpose(void) const return matrix(this->cols(),this->rows(),trans); } - /** Determinant of square matrix. This routine doesn't actually calculate the * determinant, it only implements some heuristics about which algorithm to * run. If all the elements of the matrix are elements of an integral domain @@ -665,9 +696,10 @@ ex matrix::determinant(unsigned algo) const bool numeric_flag = true; bool normal_flag = false; unsigned sparse_count = 0; // counts non-zero elements - for (exvector::const_iterator r=m.begin(); r!=m.end(); ++r) { + exvector::const_iterator r = m.begin(), rend = m.end(); + while (r != rend) { lst srl; // symbol replacement list - ex rtest = (*r).to_rational(srl); + ex rtest = r->to_rational(srl); if (!rtest.is_zero()) ++sparse_count; if (!rtest.info(info_flags::numeric)) @@ -675,6 +707,7 @@ ex matrix::determinant(unsigned algo) const if (!rtest.info(info_flags::crational_polynomial) && rtest.info(info_flags::rational_function)) normal_flag = true; + ++r; } // Here is the heuristics in case this routine has to decide: @@ -754,12 +787,13 @@ ex matrix::determinant(unsigned algo) const } sort(c_zeros.begin(),c_zeros.end()); std::vector pre_sort; - for (std::vector::iterator i=c_zeros.begin(); i!=c_zeros.end(); ++i) + for (std::vector::const_iterator i=c_zeros.begin(); i!=c_zeros.end(); ++i) pre_sort.push_back(i->second); - int sign = permutation_sign(pre_sort); + std::vector pre_sort_test(pre_sort); // permutation_sign() modifies the vector so we make a copy here + int sign = permutation_sign(pre_sort_test.begin(), pre_sort_test.end()); exvector result(row*col); // represents sorted matrix unsigned c = 0; - for (std::vector::iterator i=pre_sort.begin(); + for (std::vector::const_iterator i=pre_sort.begin(); i!=pre_sort.end(); ++i,++c) { for (unsigned r=0; rinfo(info_flags::numeric)) numeric_flag = false; - } + ++r; } // The pure numeric case is traditionally rather common. Hence, it is @@ -865,7 +900,7 @@ matrix matrix::inverse(void) const // First populate the identity matrix supposed to become the right hand side. matrix identity(row,col); for (unsigned i=0; iinfo(info_flags::numeric)) numeric_flag = false; + ++r; } // Here is the heuristics in case this routine has to decide: @@ -973,19 +1010,18 @@ matrix matrix::solve(const matrix & vars, // assign solutions for vars between fnz+1 and // last_assigned_sol-1: free parameters for (unsigned c=fnz; cm.begin(); - exvector::iterator tmp_n_it = tmp_n.m.begin(); - exvector::iterator tmp_d_it = tmp_d.m.begin(); - for (; it!= this->m.end(); ++it, ++tmp_n_it, ++tmp_d_it) { - (*tmp_n_it) = (*it).normal().to_rational(srl); - (*tmp_d_it) = (*tmp_n_it).denom(); - (*tmp_n_it) = (*tmp_n_it).numer(); + exvector::const_iterator cit = this->m.begin(), citend = this->m.end(); + exvector::iterator tmp_n_it = tmp_n.m.begin(), tmp_d_it = tmp_d.m.begin(); + while (cit != citend) { + ex nd = cit->normal().to_rational(srl).numer_denom(); + ++cit; + *tmp_n_it++ = nd.op(0); + *tmp_d_it++ = nd.op(1); } unsigned r0 = 0; @@ -1333,11 +1369,11 @@ int matrix::fraction_free_elimination(const bool det) } } // repopulate *this matrix: - it = this->m.begin(); + exvector::iterator it = this->m.begin(), itend = this->m.end(); tmp_n_it = tmp_n.m.begin(); tmp_d_it = tmp_d.m.begin(); - for (; it!= this->m.end(); ++it, ++tmp_n_it, ++tmp_d_it) - (*it) = ((*tmp_n_it)/(*tmp_d_it)).subs(srl); + while (it != itend) + *it++ = ((*tmp_n_it++)/(*tmp_d_it++)).subs(srl); return sign; } @@ -1367,10 +1403,10 @@ int matrix::pivot(unsigned ro, unsigned co, bool symbolic) // search largest element in column co beginning at row ro GINAC_ASSERT(is_ex_of_type(this->m[k*col+co],numeric)); unsigned kmax = k+1; - numeric mmax = abs(ex_to_numeric(m[kmax*col+co])); + numeric mmax = abs(ex_to(m[kmax*col+co])); while (kmaxm[kmax*col+co],numeric)); - numeric tmp = ex_to_numeric(this->m[kmax*col+co]); + numeric tmp = ex_to(this->m[kmax*col+co]); if (abs(tmp) > mmax) { mmax = tmp; k = kmax; @@ -1408,9 +1444,9 @@ ex lst_to_matrix(const lst & l) for (i=0; i j) - m.set(i, j, l.op(i).op(j)); + m(i, j) = l.op(i).op(j); else - m.set(i, j, ex(0)); + m(i, j) = _ex0(); return m; } @@ -1421,7 +1457,7 @@ ex diag_matrix(const lst & l) matrix &m = *new matrix(dim, dim); m.setflag(status_flags::dynallocated); for (unsigned i=0; i