]> www.ginac.de Git - ginac.git/blobdiff - ginac/matrix.cpp
Added a document about the coding conventions used in GiNaC. Corrections,
[ginac.git] / ginac / matrix.cpp
index cdb875d3aa2c483d6ea00b8acbbd13161a98a4db..2c191ff4f69c04d993b2724b16f13e13d445156e 100644 (file)
 #include "symbol.h"
 #include "operators.h"
 #include "normal.h"
-#include "print.h"
 #include "archive.h"
 #include "utils.h"
 
 namespace GiNaC {
 
-GINAC_IMPLEMENT_REGISTERED_CLASS(matrix, basic)
+GINAC_IMPLEMENT_REGISTERED_CLASS_OPT(matrix, basic,
+  print_func<print_context>(&matrix::do_print).
+  print_func<print_latex>(&matrix::do_print_latex).
+  print_func<print_tree>(&basic::do_print_tree).
+  print_func<print_python_repr>(&matrix::do_print_python_repr))
 
 //////////
 // default constructor
 //////////
 
 /** Default ctor.  Initializes to 1 x 1-dimensional zero-matrix. */
-matrix::matrix() : inherited(TINFO_matrix), row(1), col(1)
+matrix::matrix() : inherited(TINFO_matrix), row(1), col(1), m(1, _ex0)
 {
-       m.push_back(_ex0);
+       setflag(status_flags::not_shareable);
 }
 
 //////////
@@ -66,25 +69,28 @@ matrix::matrix() : inherited(TINFO_matrix), row(1), col(1)
  *  @param r number of rows
  *  @param c number of cols */
 matrix::matrix(unsigned r, unsigned c)
-  : inherited(TINFO_matrix), row(r), col(c)
+  : inherited(TINFO_matrix), row(r), col(c), m(r*c, _ex0)
 {
-       m.resize(r*c, _ex0);
+       setflag(status_flags::not_shareable);
 }
 
 // protected
 
 /** Ctor from representation, for internal use only. */
 matrix::matrix(unsigned r, unsigned c, const exvector & m2)
-  : inherited(TINFO_matrix), row(r), col(c), m(m2) {}
+  : inherited(TINFO_matrix), row(r), col(c), m(m2)
+{
+       setflag(status_flags::not_shareable);
+}
 
 /** Construct matrix from (flat) list of elements. If the list has fewer
  *  elements than the matrix, the remaining matrix elements are set to zero.
  *  If the list has more elements than the matrix, the excessive elements are
  *  thrown away. */
 matrix::matrix(unsigned r, unsigned c, const lst & l)
-  : inherited(TINFO_matrix), row(r), col(c)
+  : inherited(TINFO_matrix), row(r), col(c), m(r*c, _ex0)
 {
-       m.resize(r*c, _ex0);
+       setflag(status_flags::not_shareable);
 
        size_t i = 0;
        for (lst::const_iterator it = l.begin(); it != l.end(); ++it, ++i) {
@@ -102,6 +108,8 @@ matrix::matrix(unsigned r, unsigned c, const lst & l)
 
 matrix::matrix(const archive_node &n, lst &sym_lst) : inherited(n, sym_lst)
 {
+       setflag(status_flags::not_shareable);
+
        if (!(n.find_unsigned("row", row)) || !(n.find_unsigned("col", col)))
                throw (std::runtime_error("unknown matrix dimensions in archive"));
        m.reserve(row * col);
@@ -134,54 +142,41 @@ DEFAULT_UNARCHIVE(matrix)
 
 // public
 
-void matrix::print(const print_context & c, unsigned level) const
+void matrix::print_elements(const print_context & c, const char *row_start, const char *row_end, const char *row_sep, const char *col_sep) const
 {
-       if (is_a<print_tree>(c)) {
-
-               inherited::print(c, level);
-
-       } else {
-
-               if (is_a<print_python_repr>(c))
-                       c.s << class_name() << '(';
-
-               if (is_a<print_latex>(c))
-                       c.s << "\\left(\\begin{array}{" << std::string(col,'c') << "}";
-               else
-                       c.s << "[";
-
-               for (unsigned ro=0; ro<row; ++ro) {
-                       if (!is_a<print_latex>(c))
-                               c.s << "[";
-                       for (unsigned co=0; co<col; ++co) {
-                               m[ro*col+co].print(c);
-                               if (co<col-1) {
-                                       if (is_a<print_latex>(c))
-                                               c.s << "&";
-                                       else
-                                               c.s << ",";
-                               } else {
-                                       if (!is_a<print_latex>(c))
-                                               c.s << "]";
-                               }
-                       }
-                       if (ro<row-1) {
-                               if (is_a<print_latex>(c))
-                                       c.s << "\\\\";
-                               else
-                                       c.s << ",";
-                       }
+       for (unsigned ro=0; ro<row; ++ro) {
+               c.s << row_start;
+               for (unsigned co=0; co<col; ++co) {
+                       m[ro*col+co].print(c);
+                       if (co < col-1)
+                               c.s << col_sep;
+                       else
+                               c.s << row_end;
                }
+               if (ro < row-1)
+                       c.s << row_sep;
+       }
+}
 
-               if (is_a<print_latex>(c))
-                       c.s << "\\end{array}\\right)";
-               else
-                       c.s << "]";
+void matrix::do_print(const print_context & c, unsigned level) const
+{
+       c.s << "[";
+       print_elements(c, "[", "]", ",", ",");
+       c.s << "]";
+}
 
-               if (is_a<print_python_repr>(c))
-                       c.s << ')';
+void matrix::do_print_latex(const print_latex & c, unsigned level) const
+{
+       c.s << "\\left(\\begin{array}{" << std::string(col,'c') << "}";
+       print_elements(c, "", "", "\\\\", "&");
+       c.s << "\\end{array}\\right)";
+}
 
-       }
+void matrix::do_print_python_repr(const print_python_repr & c, unsigned level) const
+{
+       c.s << class_name() << '(';
+       print_elements(c, "[", "]", ",", ",");
+       c.s << ')';
 }
 
 /** nops is defined to be rows x columns. */
@@ -707,7 +702,7 @@ ex matrix::determinant(unsigned algo) const
        unsigned sparse_count = 0;  // counts non-zero elements
        exvector::const_iterator r = m.begin(), rend = m.end();
        while (r != rend) {
-               lst srl;  // symbol replacement list
+               exmap srl;  // symbol replacement list
                ex rtest = r->to_rational(srl);
                if (!rtest.is_zero())
                        ++sparse_count;
@@ -855,7 +850,7 @@ ex matrix::trace() const
  *  @return    characteristic polynomial as new expression
  *  @exception logic_error (matrix not square)
  *  @see       matrix::determinant() */
-ex matrix::charpoly(const symbol & lambda) const
+ex matrix::charpoly(const ex & lambda) const
 {
        if (row != col)
                throw (std::logic_error("matrix::charpoly(): matrix not square"));
@@ -875,13 +870,13 @@ ex matrix::charpoly(const symbol & lambda) const
 
                matrix B(*this);
                ex c = B.trace();
-               ex poly = power(lambda,row)-c*power(lambda,row-1);
+               ex poly = power(lambda, row) - c*power(lambda, row-1);
                for (unsigned i=1; i<row; ++i) {
                        for (unsigned j=0; j<row; ++j)
                                B.m[j*col+j] -= c;
                        B = this->mul(B);
                        c = B.trace() / ex(i+1);
-                       poly -= c*power(lambda,row-i-1);
+                       poly -= c*power(lambda, row-i-1);
                }
                if (row%2)
                        return -poly;
@@ -943,6 +938,7 @@ matrix matrix::inverse() const
  *
  *  @param vars n x p matrix, all elements must be symbols 
  *  @param rhs m x p matrix
+ *  @param algo selects the solving algorithm
  *  @return n x p solution matrix
  *  @exception logic_error (incompatible matrices)
  *  @exception invalid_argument (1st argument must be matrix of symbols)
@@ -1325,7 +1321,7 @@ int matrix::fraction_free_elimination(const bool det)
        // makes things more complicated than they need to be.
        matrix tmp_n(*this);
        matrix tmp_d(m,n);  // for denominators, if needed
-       lst srl;  // symbol replacement list
+       exmap srl;  // symbol replacement list
        exvector::const_iterator cit = this->m.begin(), citend = this->m.end();
        exvector::iterator tmp_n_it = tmp_n.m.begin(), tmp_d_it = tmp_d.m.begin();
        while (cit != citend) {
@@ -1388,7 +1384,7 @@ int matrix::fraction_free_elimination(const bool det)
        tmp_n_it = tmp_n.m.begin();
        tmp_d_it = tmp_d.m.begin();
        while (it != itend)
-               *it++ = ((*tmp_n_it++)/(*tmp_d_it++)).subs(srl);
+               *it++ = ((*tmp_n_it++)/(*tmp_d_it++)).subs(srl, subs_options::no_pattern);
        
        return sign;
 }