X-Git-Url: https://www.ginac.de/ginac.git//ginac.git?p=ginac.git;a=blobdiff_plain;f=ginac%2Fmatrix.cpp;h=650799e55f6353db6807e1d42a8d7904cc55291b;hp=ec7ee2c8e004415bd04eb4e8df44acc9d3fa7294;hb=08d556dc3ac3fbf2b0ad3acd37016a1f925d7c02;hpb=e58227e1112f989f3b5417e497a61d53fc2971fa diff --git a/ginac/matrix.cpp b/ginac/matrix.cpp index ec7ee2c8..650799e5 100644 --- a/ginac/matrix.cpp +++ b/ginac/matrix.cpp @@ -28,21 +28,20 @@ #include "archive.h" #include "numeric.h" #include "lst.h" +#include "idx.h" +#include "indexed.h" #include "utils.h" #include "debugmsg.h" #include "power.h" #include "symbol.h" #include "normal.h" -#ifndef NO_NAMESPACE_GINAC namespace GiNaC { -#endif // ndef NO_NAMESPACE_GINAC GINAC_IMPLEMENT_REGISTERED_CLASS(matrix, basic) ////////// -// default constructor, destructor, copy constructor, assignment operator -// and helpers: +// default ctor, dtor, copy ctor, assignment operator and helpers: ////////// // public @@ -50,34 +49,13 @@ GINAC_IMPLEMENT_REGISTERED_CLASS(matrix, basic) /** Default ctor. Initializes to 1 x 1-dimensional zero-matrix. */ matrix::matrix() : inherited(TINFO_matrix), row(1), col(1) { - debugmsg("matrix default constructor",LOGLEVEL_CONSTRUCT); + debugmsg("matrix default ctor",LOGLEVEL_CONSTRUCT); m.push_back(_ex0()); } -matrix::~matrix() -{ - debugmsg("matrix destructor",LOGLEVEL_DESTRUCT); - destroy(false); -} - -matrix::matrix(const matrix & other) -{ - debugmsg("matrix copy constructor",LOGLEVEL_CONSTRUCT); - copy(other); -} - -const matrix & matrix::operator=(const matrix & other) -{ - debugmsg("matrix operator=",LOGLEVEL_ASSIGNMENT); - if (this != &other) { - destroy(true); - copy(other); - } - return *this; -} - // protected +/** For use by copy ctor and assignment operator. */ void matrix::copy(const matrix & other) { inherited::copy(other); @@ -92,7 +70,7 @@ void matrix::destroy(bool call_parent) } ////////// -// other constructors +// other ctors ////////// // public @@ -104,7 +82,7 @@ void matrix::destroy(bool call_parent) matrix::matrix(unsigned r, unsigned c) : inherited(TINFO_matrix), row(r), col(c) { - debugmsg("matrix constructor from unsigned,unsigned",LOGLEVEL_CONSTRUCT); + debugmsg("matrix ctor from unsigned,unsigned",LOGLEVEL_CONSTRUCT); m.resize(r*c, _ex0()); } @@ -114,7 +92,26 @@ matrix::matrix(unsigned r, unsigned c) matrix::matrix(unsigned r, unsigned c, const exvector & m2) : inherited(TINFO_matrix), row(r), col(c), m(m2) { - debugmsg("matrix constructor from unsigned,unsigned,exvector",LOGLEVEL_CONSTRUCT); + debugmsg("matrix ctor from unsigned,unsigned,exvector",LOGLEVEL_CONSTRUCT); +} + +/** Construct matrix from (flat) list of elements. If the list has fewer + * elements than the matrix, the remaining matrix elements are set to zero. + * If the list has more elements than the matrix, the excessive elements are + * thrown away. */ +matrix::matrix(unsigned r, unsigned c, const lst & l) + : inherited(TINFO_matrix), row(r), col(c) +{ + debugmsg("matrix ctor from unsigned,unsigned,lst",LOGLEVEL_CONSTRUCT); + m.resize(r*c, _ex0()); + + for (unsigned i=0; i= r) + break; // matrix smaller than list: throw away excessive elements + m[y*c+x] = l.op(i); + } } ////////// @@ -124,7 +121,7 @@ matrix::matrix(unsigned r, unsigned c, const exvector & m2) /** Construct object from archive_node. */ matrix::matrix(const archive_node &n, const lst &sym_lst) : inherited(n, sym_lst) { - debugmsg("matrix constructor from archive_node", LOGLEVEL_CONSTRUCT); + debugmsg("matrix ctor from archive_node", LOGLEVEL_CONSTRUCT); if (!(n.find_unsigned("row", row)) || !(n.find_unsigned("col", col))) throw (std::runtime_error("unknown matrix dimensions in archive")); m.reserve(row * col); @@ -162,12 +159,6 @@ void matrix::archive(archive_node &n) const // public -basic * matrix::duplicate() const -{ - debugmsg("matrix duplicate",LOGLEVEL_DUPLICATE); - return new matrix(*this); -} - void matrix::print(std::ostream & os, unsigned upper_precedence) const { debugmsg("matrix print",LOGLEVEL_PRINT); @@ -187,7 +178,7 @@ void matrix::print(std::ostream & os, unsigned upper_precedence) const void matrix::printraw(std::ostream & os) const { debugmsg("matrix printraw",LOGLEVEL_PRINT); - os << "matrix(" << row << "," << col <<","; + os << class_name() << "(" << row << "," << col <<","; for (unsigned r=0; r(i).all_index_values_are(info_flags::nonnegint); + + // Check indices + if (i.nops() == 2) { + + // One index, must be one-dimensional vector + if (row != 1 && col != 1) + throw (std::runtime_error("matrix::eval_indexed(): vector must have exactly 1 index")); + + const idx & i1 = ex_to_idx(i.op(1)); + + if (col == 1) { + + // Column vector + if (!i1.get_dim().is_equal(row)) + throw (std::runtime_error("matrix::eval_indexed(): dimension of index must match number of vector elements")); + + // Index numeric -> return vector element + if (all_indices_unsigned) { + unsigned n1 = ex_to_numeric(i1.get_value()).to_int(); + if (n1 >= row) + throw (std::runtime_error("matrix::eval_indexed(): value of index exceeds number of vector elements")); + return (*this)(n1, 0); + } + + } else { + + // Row vector + if (!i1.get_dim().is_equal(col)) + throw (std::runtime_error("matrix::eval_indexed(): dimension of index must match number of vector elements")); + + // Index numeric -> return vector element + if (all_indices_unsigned) { + unsigned n1 = ex_to_numeric(i1.get_value()).to_int(); + if (n1 >= col) + throw (std::runtime_error("matrix::eval_indexed(): value of index exceeds number of vector elements")); + return (*this)(0, n1); + } + } + + } else if (i.nops() == 3) { + + // Two indices + const idx & i1 = ex_to_idx(i.op(1)); + const idx & i2 = ex_to_idx(i.op(2)); + + if (!i1.get_dim().is_equal(row)) + throw (std::runtime_error("matrix::eval_indexed(): dimension of first index must match number of rows")); + if (!i2.get_dim().is_equal(col)) + throw (std::runtime_error("matrix::eval_indexed(): dimension of second index must match number of columns")); + + // Pair of dummy indices -> compute trace + if (is_dummy_pair(i1, i2)) + return trace(); + + // Both indices numeric -> return matrix element + if (all_indices_unsigned) { + unsigned n1 = ex_to_numeric(i1.get_value()).to_int(), n2 = ex_to_numeric(i2.get_value()).to_int(); + if (n1 >= row) + throw (std::runtime_error("matrix::eval_indexed(): value of first index exceeds number of rows")); + if (n2 >= col) + throw (std::runtime_error("matrix::eval_indexed(): value of second index exceeds number of columns")); + return (*this)(n1, n2); + } + + } else + throw (std::runtime_error("matrix::eval_indexed(): matrix must have exactly 2 indices")); + + return i.hold(); +} + +/** Sum of two indexed matrices. */ +ex matrix::add_indexed(const ex & self, const ex & other) const +{ + GINAC_ASSERT(is_ex_of_type(self, indexed)); + GINAC_ASSERT(is_ex_of_type(self.op(0), matrix)); + GINAC_ASSERT(is_ex_of_type(other, indexed)); + GINAC_ASSERT(self.nops() == 2 || self.nops() == 3); + + // Only add two matrices + if (is_ex_of_type(other.op(0), matrix)) { + GINAC_ASSERT(other.nops() == 2 || other.nops() == 3); + + const matrix &self_matrix = ex_to_matrix(self.op(0)); + const matrix &other_matrix = ex_to_matrix(other.op(0)); + + if (self.nops() == 2 && other.nops() == 2) { // vector + vector + + if (self_matrix.row == other_matrix.row) + return indexed(self_matrix.add(other_matrix), self.op(1)); + else if (self_matrix.row == other_matrix.col) + return indexed(self_matrix.add(other_matrix.transpose()), self.op(1)); + + } else if (self.nops() == 3 && other.nops() == 3) { // matrix + matrix + + if (self.op(1).is_equal(other.op(1)) && self.op(2).is_equal(other.op(2))) + return indexed(self_matrix.add(other_matrix), self.op(1), self.op(2)); + else if (self.op(1).is_equal(other.op(2)) && self.op(2).is_equal(other.op(1))) + return indexed(self_matrix.add(other_matrix.transpose()), self.op(1), self.op(2)); + + } + } + + // Don't know what to do, return unevaluated sum + return self + other; +} + +/** Product of an indexed matrix with a number. */ +ex matrix::scalar_mul_indexed(const ex & self, const numeric & other) const +{ + GINAC_ASSERT(is_ex_of_type(self, indexed)); + GINAC_ASSERT(is_ex_of_type(self.op(0), matrix)); + GINAC_ASSERT(self.nops() == 2 || self.nops() == 3); + + const matrix &self_matrix = ex_to_matrix(self.op(0)); + + if (self.nops() == 2) + return indexed(self_matrix.mul(other), self.op(1)); + else // self.nops() == 3 + return indexed(self_matrix.mul(other), self.op(1), self.op(2)); +} + +/** Contraction of an indexed matrix with something else. */ +bool matrix::contract_with(exvector::iterator self, exvector::iterator other, exvector & v) const +{ + GINAC_ASSERT(is_ex_of_type(*self, indexed)); + GINAC_ASSERT(is_ex_of_type(*other, indexed)); + GINAC_ASSERT(self->nops() == 2 || self->nops() == 3); + GINAC_ASSERT(is_ex_of_type(self->op(0), matrix)); + + // Only contract with other matrices + if (!is_ex_of_type(other->op(0), matrix)) + return false; + + GINAC_ASSERT(other->nops() == 2 || other->nops() == 3); + + const matrix &self_matrix = ex_to_matrix(self->op(0)); + const matrix &other_matrix = ex_to_matrix(other->op(0)); + + if (self->nops() == 2) { + unsigned self_dim = (self_matrix.col == 1) ? self_matrix.row : self_matrix.col; + + if (other->nops() == 2) { // vector * vector (scalar product) + unsigned other_dim = (other_matrix.col == 1) ? other_matrix.row : other_matrix.col; + + if (self_matrix.col == 1) { + if (other_matrix.col == 1) { + // Column vector * column vector, transpose first vector + *self = self_matrix.transpose().mul(other_matrix)(0, 0); + } else { + // Column vector * row vector, swap factors + *self = other_matrix.mul(self_matrix)(0, 0); + } + } else { + if (other_matrix.col == 1) { + // Row vector * column vector, perfect + *self = self_matrix.mul(other_matrix)(0, 0); + } else { + // Row vector * row vector, transpose second vector + *self = self_matrix.mul(other_matrix.transpose())(0, 0); + } + } + *other = _ex1(); + return true; + + } else { // vector * matrix + + // B_i * A_ij = (B*A)_j (B is row vector) + if (is_dummy_pair(self->op(1), other->op(1))) { + if (self_matrix.row == 1) + *self = indexed(self_matrix.mul(other_matrix), other->op(2)); + else + *self = indexed(self_matrix.transpose().mul(other_matrix), other->op(2)); + *other = _ex1(); + return true; + } + + // B_j * A_ij = (A*B)_i (B is column vector) + if (is_dummy_pair(self->op(1), other->op(2))) { + if (self_matrix.col == 1) + *self = indexed(other_matrix.mul(self_matrix), other->op(1)); + else + *self = indexed(other_matrix.mul(self_matrix.transpose()), other->op(1)); + *other = _ex1(); + return true; + } + } + + } else if (other->nops() == 3) { // matrix * matrix + + // A_ij * B_jk = (A*B)_ik + if (is_dummy_pair(self->op(2), other->op(1))) { + *self = indexed(self_matrix.mul(other_matrix), self->op(1), other->op(2)); + *other = _ex1(); + return true; + } + + // A_ij * B_kj = (A*Btrans)_ik + if (is_dummy_pair(self->op(2), other->op(2))) { + *self = indexed(self_matrix.mul(other_matrix.transpose()), self->op(1), other->op(1)); + *other = _ex1(); + return true; + } + + // A_ji * B_jk = (Atrans*B)_ik + if (is_dummy_pair(self->op(1), other->op(1))) { + *self = indexed(self_matrix.transpose().mul(other_matrix), self->op(2), other->op(2)); + *other = _ex1(); + return true; + } + + // A_ji * B_kj = (B*A)_ki + if (is_dummy_pair(self->op(1), other->op(2))) { + *self = indexed(other_matrix.mul(self_matrix), other->op(1), self->op(2)); + *other = _ex1(); + return true; + } + } + + return false; +} + + ////////// // non-virtual functions in this class ////////// @@ -386,6 +616,19 @@ matrix matrix::mul(const matrix & other) const } +/** Product of matrix and scalar. */ +matrix matrix::mul(const numeric & other) const +{ + exvector prod(row * col); + + for (unsigned r=0; r