]> www.ginac.de Git - ginac.git/blobdiff - ginac/matrix.cpp
sums of indexed matrices are now possible
[ginac.git] / ginac / matrix.cpp
index 81d31bdb9e8739a5d073a2e20ddbb95e58f25c3f..f34534afdb1ab3ee39b567bd11c7ed7408010855 100644 (file)
 #include "archive.h"
 #include "numeric.h"
 #include "lst.h"
+#include "idx.h"
+#include "indexed.h"
 #include "utils.h"
 #include "debugmsg.h"
 #include "power.h"
 #include "symbol.h"
 #include "normal.h"
 
-#ifndef NO_NAMESPACE_GINAC
 namespace GiNaC {
-#endif // ndef NO_NAMESPACE_GINAC
 
 GINAC_IMPLEMENT_REGISTERED_CLASS(matrix, basic)
 
 //////////
-// default constructor, destructor, copy constructor, assignment operator
-// and helpers:
+// default ctor, dtor, copy ctor, assignment operator and helpers:
 //////////
 
 // public
@@ -50,12 +49,13 @@ GINAC_IMPLEMENT_REGISTERED_CLASS(matrix, basic)
 /** Default ctor.  Initializes to 1 x 1-dimensional zero-matrix. */
 matrix::matrix() : inherited(TINFO_matrix), row(1), col(1)
 {
-       debugmsg("matrix default constructor",LOGLEVEL_CONSTRUCT);
+       debugmsg("matrix default ctor",LOGLEVEL_CONSTRUCT);
        m.push_back(_ex0());
 }
 
 // protected
 
+/** For use by copy ctor and assignment operator. */
 void matrix::copy(const matrix & other)
 {
        inherited::copy(other);
@@ -70,7 +70,7 @@ void matrix::destroy(bool call_parent)
 }
 
 //////////
-// other constructors
+// other ctors
 //////////
 
 // public
@@ -82,7 +82,7 @@ void matrix::destroy(bool call_parent)
 matrix::matrix(unsigned r, unsigned c)
   : inherited(TINFO_matrix), row(r), col(c)
 {
-       debugmsg("matrix constructor from unsigned,unsigned",LOGLEVEL_CONSTRUCT);
+       debugmsg("matrix ctor from unsigned,unsigned",LOGLEVEL_CONSTRUCT);
        m.resize(r*c, _ex0());
 }
 
@@ -92,7 +92,7 @@ matrix::matrix(unsigned r, unsigned c)
 matrix::matrix(unsigned r, unsigned c, const exvector & m2)
   : inherited(TINFO_matrix), row(r), col(c), m(m2)
 {
-       debugmsg("matrix constructor from unsigned,unsigned,exvector",LOGLEVEL_CONSTRUCT);
+       debugmsg("matrix ctor from unsigned,unsigned,exvector",LOGLEVEL_CONSTRUCT);
 }
 
 //////////
@@ -102,7 +102,7 @@ matrix::matrix(unsigned r, unsigned c, const exvector & m2)
 /** Construct object from archive_node. */
 matrix::matrix(const archive_node &n, const lst &sym_lst) : inherited(n, sym_lst)
 {
-       debugmsg("matrix constructor from archive_node", LOGLEVEL_CONSTRUCT);
+       debugmsg("matrix ctor from archive_node", LOGLEVEL_CONSTRUCT);
        if (!(n.find_unsigned("row", row)) || !(n.find_unsigned("col", col)))
                throw (std::runtime_error("unknown matrix dimensions in archive"));
        m.reserve(row * col);
@@ -159,7 +159,7 @@ void matrix::print(std::ostream & os, unsigned upper_precedence) const
 void matrix::printraw(std::ostream & os) const
 {
        debugmsg("matrix printraw",LOGLEVEL_PRINT);
-       os << "matrix(" << row << "," << col <<",";
+       os << class_name() << "(" << row << "," << col <<",";
        for (unsigned r=0; r<row-1; ++r) {
                os << "(";
                for (unsigned c=0; c<col-1; ++c)
@@ -219,7 +219,7 @@ bool matrix::has(const ex & other) const
        return false;
 }
 
-/** evaluate matrix entry by entry. */
+/** Evaluate matrix entry by entry. */
 ex matrix::eval(int level) const
 {
        debugmsg("matrix eval",LOGLEVEL_MEMBER_FUNCTION);
@@ -243,7 +243,7 @@ ex matrix::eval(int level) const
                                                                                           status_flags::evaluated );
 }
 
-/** evaluate matrix numerically entry by entry. */
+/** Evaluate matrix numerically entry by entry. */
 ex matrix::evalf(int level) const
 {
        debugmsg("matrix evalf",LOGLEVEL_MEMBER_FUNCTION);
@@ -267,6 +267,16 @@ ex matrix::evalf(int level) const
        return matrix(row, col, m2);
 }
 
+ex matrix::subs(const lst & ls, const lst & lr) const
+{
+       exvector m2(row * col);
+       for (unsigned r=0; r<row; ++r)
+               for (unsigned c=0; c<col; ++c)
+                       m2[r*col+c] = m[r*col+c].subs(ls, lr);
+
+       return matrix(row, col, m2);
+}
+
 // protected
 
 int matrix::compare_same_type(const basic & other) const
@@ -294,6 +304,219 @@ int matrix::compare_same_type(const basic & other) const
        return 0;
 }
 
+/** Automatic symbolic evaluation of an indexed matrix. */
+ex matrix::eval_indexed(const basic & i) const
+{
+       GINAC_ASSERT(is_of_type(i, indexed));
+       GINAC_ASSERT(is_ex_of_type(i.op(0), matrix));
+
+       bool all_indices_unsigned = static_cast<const indexed &>(i).all_index_values_are(info_flags::nonnegint);
+
+       // Check indices
+       if (i.nops() == 2) {
+
+               // One index, must be one-dimensional vector
+               if (row != 1 && col != 1)
+                       throw (std::runtime_error("matrix::eval_indexed(): vector must have exactly 1 index"));
+
+               const idx & i1 = ex_to_idx(i.op(1));
+
+               if (col == 1) {
+
+                       // Column vector
+                       if (!i1.get_dim().is_equal(row))
+                               throw (std::runtime_error("matrix::eval_indexed(): dimension of index must match number of vector elements"));
+
+                       // Index numeric -> return vector element
+                       if (all_indices_unsigned) {
+                               unsigned n1 = ex_to_numeric(i1.get_value()).to_int();
+                               if (n1 >= row)
+                                       throw (std::runtime_error("matrix::eval_indexed(): value of index exceeds number of vector elements"));
+                               return (*this)(n1, 0);
+                       }
+
+               } else {
+
+                       // Row vector
+                       if (!i1.get_dim().is_equal(col))
+                               throw (std::runtime_error("matrix::eval_indexed(): dimension of index must match number of vector elements"));
+
+                       // Index numeric -> return vector element
+                       if (all_indices_unsigned) {
+                               unsigned n1 = ex_to_numeric(i1.get_value()).to_int();
+                               if (n1 >= col)
+                                       throw (std::runtime_error("matrix::eval_indexed(): value of index exceeds number of vector elements"));
+                               return (*this)(0, n1);
+                       }
+               }
+
+       } else if (i.nops() == 3) {
+
+               // Two indices
+               const idx & i1 = ex_to_idx(i.op(1));
+               const idx & i2 = ex_to_idx(i.op(2));
+
+               if (!i1.get_dim().is_equal(row))
+                       throw (std::runtime_error("matrix::eval_indexed(): dimension of first index must match number of rows"));
+               if (!i2.get_dim().is_equal(col))
+                       throw (std::runtime_error("matrix::eval_indexed(): dimension of second index must match number of columns"));
+
+               // Pair of dummy indices -> compute trace
+               if (is_dummy_pair(i1, i2))
+                       return trace();
+
+               // Both indices numeric -> return matrix element
+               if (all_indices_unsigned) {
+                       unsigned n1 = ex_to_numeric(i1.get_value()).to_int(), n2 = ex_to_numeric(i2.get_value()).to_int();
+                       if (n1 >= row)
+                               throw (std::runtime_error("matrix::eval_indexed(): value of first index exceeds number of rows"));
+                       if (n2 >= col)
+                               throw (std::runtime_error("matrix::eval_indexed(): value of second index exceeds number of columns"));
+                       return (*this)(n1, n2);
+               }
+
+       } else
+               throw (std::runtime_error("matrix::eval_indexed(): matrix must have exactly 2 indices"));
+
+       return i.hold();
+}
+
+/** Sum of two indexed matrices. */
+ex matrix::add_indexed(const ex & self, const ex & other) const
+{
+       GINAC_ASSERT(is_ex_of_type(self, indexed));
+       GINAC_ASSERT(is_ex_of_type(other, indexed));
+       GINAC_ASSERT(self.nops() == 2 || self.nops() == 3);
+
+       // Only add two matrices
+       if (is_ex_of_type(other.op(0), matrix)) {
+               GINAC_ASSERT(other.nops() == 2 || other.nops() == 3);
+
+               const matrix &self_matrix = ex_to_matrix(self.op(0));
+               const matrix &other_matrix = ex_to_matrix(other.op(0));
+
+               if (self.nops() == 2 && other.nops() == 2) { // vector + vector
+
+                       if (self_matrix.row == other_matrix.row)
+                               return indexed(self_matrix.add(other_matrix), self.op(1));
+                       else if (self_matrix.row == other_matrix.col)
+                               return indexed(self_matrix.add(other_matrix.transpose()), self.op(1));
+
+               } else if (self.nops() == 3 && other.nops() == 3) { // matrix + matrix
+
+                       if (self.op(1).is_equal(other.op(1)) && self.op(2).is_equal(other.op(2)))
+                               return indexed(self_matrix.add(other_matrix), self.op(1), self.op(2));
+                       else if (self.op(1).is_equal(other.op(2)) && self.op(2).is_equal(other.op(1)))
+                               return indexed(self_matrix.add(other_matrix.transpose()), self.op(1), self.op(2));
+
+               }
+       }
+
+       // Don't know what to do, return unevaluated sum
+       return self + other;
+}
+
+/** Contraction of an indexed matrix with something else. */
+bool matrix::contract_with(exvector::iterator self, exvector::iterator other, exvector & v) const
+{
+       GINAC_ASSERT(is_ex_of_type(*self, indexed));
+       GINAC_ASSERT(is_ex_of_type(*other, indexed));
+       GINAC_ASSERT(self->nops() == 2 || self->nops() == 3);
+       GINAC_ASSERT(is_ex_of_type(self->op(0), matrix));
+
+       // Only contract with other matrices
+       if (!is_ex_of_type(other->op(0), matrix))
+               return false;
+
+       GINAC_ASSERT(other->nops() == 2 || other->nops() == 3);
+
+       const matrix &self_matrix = ex_to_matrix(self->op(0));
+       const matrix &other_matrix = ex_to_matrix(other->op(0));
+
+       if (self->nops() == 2) {
+               unsigned self_dim = (self_matrix.col == 1) ? self_matrix.row : self_matrix.col;
+
+               if (other->nops() == 2) { // vector * vector (scalar product)
+                       unsigned other_dim = (other_matrix.col == 1) ? other_matrix.row : other_matrix.col;
+
+                       if (self_matrix.col == 1) {
+                               if (other_matrix.col == 1) {
+                                       // Column vector * column vector, transpose first vector
+                                       *self = self_matrix.transpose().mul(other_matrix)(0, 0);
+                               } else {
+                                       // Column vector * row vector, swap factors
+                                       *self = other_matrix.mul(self_matrix)(0, 0);
+                               }
+                       } else {
+                               if (other_matrix.col == 1) {
+                                       // Row vector * column vector, perfect
+                                       *self = self_matrix.mul(other_matrix)(0, 0);
+                               } else {
+                                       // Row vector * row vector, transpose second vector
+                                       *self = self_matrix.mul(other_matrix.transpose())(0, 0);
+                               }
+                       }
+                       *other = _ex1();
+                       return true;
+
+               } else { // vector * matrix
+
+                       // B_i * A_ij = (B*A)_j (B is row vector)
+                       if (is_dummy_pair(self->op(1), other->op(1))) {
+                               if (self_matrix.row == 1)
+                                       *self = indexed(self_matrix.mul(other_matrix), other->op(2));
+                               else
+                                       *self = indexed(self_matrix.transpose().mul(other_matrix), other->op(2));
+                               *other = _ex1();
+                               return true;
+                       }
+
+                       // B_j * A_ij = (A*B)_i (B is column vector)
+                       if (is_dummy_pair(self->op(1), other->op(2))) {
+                               if (self_matrix.col == 1)
+                                       *self = indexed(other_matrix.mul(self_matrix), other->op(1));
+                               else
+                                       *self = indexed(other_matrix.mul(self_matrix.transpose()), other->op(1));
+                               *other = _ex1();
+                               return true;
+                       }
+               }
+
+       } else if (other->nops() == 3) { // matrix * matrix
+
+               // A_ij * B_jk = (A*B)_ik
+               if (is_dummy_pair(self->op(2), other->op(1))) {
+                       *self = indexed(self_matrix.mul(other_matrix), self->op(1), other->op(2));
+                       *other = _ex1();
+                       return true;
+               }
+
+               // A_ij * B_kj = (A*Btrans)_ik
+               if (is_dummy_pair(self->op(2), other->op(2))) {
+                       *self = indexed(self_matrix.mul(other_matrix.transpose()), self->op(1), other->op(1));
+                       *other = _ex1();
+                       return true;
+               }
+
+               // A_ji * B_jk = (Atrans*B)_ik
+               if (is_dummy_pair(self->op(1), other->op(1))) {
+                       *self = indexed(self_matrix.transpose().mul(other_matrix), self->op(2), other->op(2));
+                       *other = _ex1();
+                       return true;
+               }
+
+               // A_ji * B_kj = (B*A)_ki
+               if (is_dummy_pair(self->op(1), other->op(2))) {
+                       *self = indexed(other_matrix.mul(self_matrix), other->op(1), self->op(2));
+                       *other = _ex1();
+                       return true;
+               }
+       }
+
+       return false;
+}
+
+
 //////////
 // non-virtual functions in this class
 //////////
@@ -1192,6 +1415,4 @@ ex lst_to_matrix(const ex &l)
        return m;
 }
 
-#ifndef NO_NAMESPACE_GINAC
 } // namespace GiNaC
-#endif // ndef NO_NAMESPACE_GINAC