]> www.ginac.de Git - ginac.git/blobdiff - ginac/matrix.cpp
cleaned up some is_a<> vs. is_exactly_a<> stuff
[ginac.git] / ginac / matrix.cpp
index 3d735fe46689aa57fc0d973a0b8588e07f2049f8..46d7a64cd9faa2330cca7a035e9782c2cb334da6 100644 (file)
@@ -3,7 +3,7 @@
  *  Implementation of symbolic matrices */
 
 /*
- *  GiNaC Copyright (C) 1999-2001 Johannes Gutenberg University Mainz, Germany
+ *  GiNaC Copyright (C) 1999-2002 Johannes Gutenberg University Mainz, Germany
  *
  *  This program is free software; you can redistribute it and/or modify
  *  it under the terms of the GNU General Public License as published by
@@ -20,7 +20,9 @@
  *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
  */
 
+#include <string>
 #include <iostream>
+#include <sstream>
 #include <algorithm>
 #include <map>
 #include <stdexcept>
@@ -147,23 +149,44 @@ void matrix::print(const print_context & c, unsigned level) const
 
        } else {
 
-               c.s << "[";
-               for (unsigned y=0; y<row-1; ++y) {
+               if (is_a<print_python_repr>(c))
+                       c.s << class_name() << '(';
+
+               if (is_a<print_latex>(c))
+                       c.s << "\\left(\\begin{array}{" << std::string(col,'c') << "}";
+               else
                        c.s << "[";
-                       for (unsigned x=0; x<col-1; ++x) {
-                               m[y*col+x].print(c);
-                               c.s << ",";
+
+               for (unsigned ro=0; ro<row; ++ro) {
+                       if (!is_a<print_latex>(c))
+                               c.s << "[";
+                       for (unsigned co=0; co<col; ++co) {
+                               m[ro*col+co].print(c);
+                               if (co<col-1) {
+                                       if (is_a<print_latex>(c))
+                                               c.s << "&";
+                                       else
+                                               c.s << ",";
+                               } else {
+                                       if (!is_a<print_latex>(c))
+                                               c.s << "]";
+                               }
+                       }
+                       if (ro<row-1) {
+                               if (is_a<print_latex>(c))
+                                       c.s << "\\\\";
+                               else
+                                       c.s << ",";
                        }
-                       m[col*(y+1)-1].print(c);
-                       c.s << "],";
-               }
-               c.s << "[";
-               for (unsigned x=0; x<col-1; ++x) {
-                       m[(row-1)*col+x].print(c);
-                       c.s << ",";
                }
-               m[row*col-1].print(c);
-               c.s << "]]";
+
+               if (is_a<print_latex>(c))
+                       c.s << "\\end{array}\\right)";
+               else
+                       c.s << "]";
+
+               if (is_a<print_python_repr>(c))
+                       c.s << ')';
 
        }
 }
@@ -208,7 +231,7 @@ ex matrix::eval(int level) const
                        m2[r*col+c] = m[r*col+c].eval(level);
        
        return (new matrix(row, col, m2))->setflag(status_flags::dynallocated |
-                                                                                          status_flags::evaluated );
+                                                                                          status_flags::evaluated);
 }
 
 ex matrix::subs(const lst & ls, const lst & lr, bool no_pattern) const
@@ -599,17 +622,19 @@ matrix matrix::pow(const ex & expn) const
                        matrix C(row,col);
                        for (unsigned r=0; r<row; ++r)
                                C(r,r) = _ex1;
+                       if (b.is_zero())
+                               return C;
                        // This loop computes the representation of b in base 2 from right
                        // to left and multiplies the factors whenever needed.  Note
                        // that this is not entirely optimal but close to optimal and
                        // "better" algorithms are much harder to implement.  (See Knuth,
                        // TAoCP2, section "Evaluation of Powers" for a good discussion.)
-                       while (b!=1) {
+                       while (b!=_num1) {
                                if (b.is_odd()) {
                                        C = C.mul(A);
-                                       b -= 1;
+                                       --b;
                                }
-                               b *= _num1_2;  // b /= 2, still integer.
+                               b /= _num2;  // still integer.
                                A = A.mul(A);
                        }
                        return A.mul(C);
@@ -1393,11 +1418,11 @@ int matrix::pivot(unsigned ro, unsigned co, bool symbolic)
                        ++k;
        } else {
                // search largest element in column co beginning at row ro
-               GINAC_ASSERT(is_a<numeric>(this->m[k*col+co]));
+               GINAC_ASSERT(is_exactly_a<numeric>(this->m[k*col+co]));
                unsigned kmax = k+1;
                numeric mmax = abs(ex_to<numeric>(m[kmax*col+co]));
                while (kmax<row) {
-                       GINAC_ASSERT(is_a<numeric>(this->m[kmax*col+co]));
+                       GINAC_ASSERT(is_exactly_a<numeric>(this->m[kmax*col+co]));
                        numeric tmp = ex_to<numeric>(this->m[kmax*col+co]);
                        if (abs(tmp) > mmax) {
                                mmax = tmp;
@@ -1431,15 +1456,15 @@ ex lst_to_matrix(const lst & l)
                        cols = l.op(i).nops();
 
        // Allocate and fill matrix
-       matrix &m = *new matrix(rows, cols);
-       m.setflag(status_flags::dynallocated);
+       matrix &M = *new matrix(rows, cols);
+       M.setflag(status_flags::dynallocated);
        for (i=0; i<rows; i++)
                for (j=0; j<cols; j++)
                        if (l.op(i).nops() > j)
-                               m(i, j) = l.op(i).op(j);
+                               M(i, j) = l.op(i).op(j);
                        else
-                               m(i, j) = _ex0;
-       return m;
+                               M(i, j) = _ex0;
+       return M;
 }
 
 ex diag_matrix(const lst & l)
@@ -1454,4 +1479,49 @@ ex diag_matrix(const lst & l)
        return m;
 }
 
+ex unit_matrix(unsigned r, unsigned c)
+{
+       matrix Id(r,c);
+       for (unsigned i=0; i<r && i<c; ++i)
+               Id(i,i) = _ex1;
+       return Id;
+}
+
+ex symbolic_matrix(unsigned r, unsigned c, const std::string & base_name, const std::string & tex_base_name)
+{
+       matrix &M = *new matrix(r, c);
+       M.setflag(status_flags::dynallocated | status_flags::evaluated);
+
+       bool long_format = (r > 10 || c > 10);
+       bool single_row = (r == 1 || c == 1);
+
+       for (unsigned i=0; i<r; i++) {
+               for (unsigned j=0; j<c; j++) {
+                       std::ostringstream s1, s2;
+                       s1 << base_name;
+                       s2 << tex_base_name << "_{";
+                       if (single_row) {
+                               if (c == 1) {
+                                       s1 << i;
+                                       s2 << i << '}';
+                               } else {
+                                       s1 << j;
+                                       s2 << j << '}';
+                               }
+                       } else {
+                               if (long_format) {
+                                       s1 << '_' << i << '_' << j;
+                                       s2 << i << ';' << j << "}";
+                               } else {
+                                       s1 << i << j;
+                                       s2 << i << j << '}';
+                               }
+                       }
+                       M(i, j) = symbol(s1.str(), s2.str());
+               }
+       }
+
+       return M;
+}
+
 } // namespace GiNaC