]> www.ginac.de Git - ginac.git/blobdiff - ginac/matrix.cpp
- matrix::pow(): omit last big multiplication if it's not needed.
[ginac.git] / ginac / matrix.cpp
index 28c2f1d25a853faf7750220d1a82867d9f425775..872b8e006012c3a4ecf6279cf315ed8c34e5f793 100644 (file)
@@ -636,10 +636,10 @@ matrix matrix::pow(const ex & expn) const
                        }
                        matrix result(row,col);
                        for (unsigned r=0; r<row; ++r)
-                               result.set(r,r,_ex1());
+                               result(r,r) = _ex1();
                        numeric b(1);
-                       // this loop computes the representation of k in base 2 and multiplies
-                       // the factors whenever needed:
+                       // this loop computes the representation of k in base 2 and
+                       // multiplies the factors whenever needed:
                        while (b.compare(k)<=0) {
                                b *= numeric(2);
                                numeric r(mod(k,b));
@@ -647,7 +647,8 @@ matrix matrix::pow(const ex & expn) const
                                        k -= r;
                                        result = result.mul(prod);
                                }
-                               prod = prod.mul(prod);
+                               if (b.compare(k)<=0)
+                                       prod = prod.mul(prod);
                        }
                        return result;
                }
@@ -656,7 +657,7 @@ matrix matrix::pow(const ex & expn) const
 }
 
 
-/** operator() to access elements.
+/** operator() to access elements for reading.
  *
  *  @param ro row of element
  *  @param co column of element
@@ -670,19 +671,19 @@ const ex & matrix::operator() (unsigned ro, unsigned co) const
 }
 
 
-/** Set individual elements manually.
+/** operator() to access elements for writing.
  *
+ *  @param ro row of element
+ *  @param co column of element
  *  @exception range_error (index out of range) */
-matrix & matrix::set(unsigned ro, unsigned co, ex value)
+ex & matrix::operator() (unsigned ro, unsigned co)
 {
        if (ro>=row || co>=col)
-               throw (std::range_error("matrix::set(): index out of range"));
-       if (value.return_type() != return_types::commutative)
-               throw std::runtime_error("matrix::set(): non-commutative argument");
-    
+               throw (std::range_error("matrix::operator(): index out of range"));
+
        ensure_if_modifiable();
-       m[ro*col+co] = value;
-       return *this;
+       clearflag(status_flags::hash_calculated);
+       return m[ro*col+co];
 }
 
 
@@ -924,7 +925,7 @@ matrix matrix::inverse(void) const
        // First populate the identity matrix supposed to become the right hand side.
        matrix identity(row,col);
        for (unsigned i=0; i<row; ++i)
-               identity.set(i,i,_ex1());
+               identity(i,i) = _ex1();
        
        // Populate a dummy matrix of variables, just because of compatibility with
        // matrix::solve() which wants this (for compatibility with under-determined
@@ -932,7 +933,7 @@ matrix matrix::inverse(void) const
        matrix vars(row,col);
        for (unsigned r=0; r<row; ++r)
                for (unsigned c=0; c<col; ++c)
-                       vars.set(r,c,symbol());
+                       vars(r,c) = symbol();
        
        matrix sol(row,col);
        try {
@@ -1032,19 +1033,18 @@ matrix matrix::solve(const matrix & vars,
                                // assign solutions for vars between fnz+1 and
                                // last_assigned_sol-1: free parameters
                                for (unsigned c=fnz; c<last_assigned_sol-1; ++c)
-                                       sol.set(c,co,vars.m[c*p+co]);
+                                       sol(c,co) = vars.m[c*p+co];
                                ex e = aug.m[r*(n+p)+n+co];
                                for (unsigned c=fnz; c<n; ++c)
                                        e -= aug.m[r*(n+p)+c]*sol.m[c*p+co];
-                               sol.set(fnz-1,co,
-                                               (e/(aug.m[r*(n+p)+(fnz-1)])).normal());
+                               sol(fnz-1,co) = (e/(aug.m[r*(n+p)+(fnz-1)])).normal();
                                last_assigned_sol = fnz;
                        }
                }
                // assign solutions for vars between 1 and
                // last_assigned_sol-1: free parameters
                for (unsigned ro=0; ro<last_assigned_sol-1; ++ro)
-                       sol.set(ro,co,vars(ro,co));
+                       sol(ro,co) = vars(ro,co);
        }
        
        return sol;
@@ -1088,9 +1088,9 @@ ex matrix::determinant_minor(void) const
        //     for (unsigned r=0; r<minorM.rows(); ++r) {
        //         for (unsigned c=0; c<minorM.cols(); ++c) {
        //             if (r<r1)
-       //                 minorM.set(r,c,m[r*col+c+1]);
+       //                 minorM(r,c) = m[r*col+c+1];
        //             else
-       //                 minorM.set(r,c,m[(r+1)*col+c+1]);
+       //                 minorM(r,c) = m[(r+1)*col+c+1];
        //         }
        //     }
        //     // recurse down and care for sign:
@@ -1467,9 +1467,9 @@ ex lst_to_matrix(const lst & l)
        for (i=0; i<rows; i++)
                for (j=0; j<cols; j++)
                        if (l.op(i).nops() > j)
-                               m.set(i, j, l.op(i).op(j));
+                               m(i, j) = l.op(i).op(j);
                        else
-                               m.set(i, j, ex(0));
+                               m(i, j) = _ex0();
        return m;
 }
 
@@ -1480,7 +1480,7 @@ ex diag_matrix(const lst & l)
        matrix &m = *new matrix(dim, dim);
        m.setflag(status_flags::dynallocated);
        for (unsigned i=0; i<dim; i++)
-               m.set(i, i, l.op(i));
+               m(i, i) = l.op(i);
 
        return m;
 }