]> www.ginac.de Git - ginac.git/blobdiff - ginac/expairseq.cpp
- first implementation of pattern matching
[ginac.git] / ginac / expairseq.cpp
index b14ffb3027e7b72c61ae4642306fece67def22aa..2d9a8f6bc67d764f18d20e6c9efa563df28b5db1 100644 (file)
@@ -26,6 +26,7 @@
 
 #include "expairseq.h"
 #include "lst.h"
+#include "relational.h"
 #include "print.h"
 #include "archive.h"
 #include "debugmsg.h"
@@ -321,13 +322,87 @@ ex expairseq::normal(lst &sym_lst, lst &repl_lst, int level) const
        return n.bp->basic::normal(sym_lst,repl_lst,level);
 }
 
-ex expairseq::subs(const lst &ls, const lst &lr) const
+bool expairseq::match(const ex & pattern, lst & repl_lst) const
 {
-       epvector *vp = subschildren(ls,lr);
-       if (vp==0)
-               return inherited::subs(ls, lr);
-       
-       return thisexpairseq(vp,overall_coeff);
+//clog << "match " << *this << " with " << pattern << ", repl_lst = " << repl_lst << endl;
+       // This differs from basic::match() because we want "a+b+c+d" to
+       // match "d+*+b" with "*" being "a+c", and we want to honor commutativity
+
+       if (tinfo() == pattern.bp->tinfo()) {
+
+               // Check whether global wildcard (one that matches the "rest of the
+               // expression", like "*" above) is present
+               bool has_global_wildcard = false;
+               ex global_wildcard;
+               for (unsigned int i=0; i<pattern.nops(); i++) {
+                       if (is_ex_exactly_of_type(pattern.op(i), wildcard)) {
+                               has_global_wildcard = true;
+                               global_wildcard = pattern.op(i);
+                               break;
+                       }
+               }
+
+               // Unfortunately, this is an O(N^2) operation because we can't
+               // sort the pattern in a useful way...
+
+               // Chop into terms
+               exvector ops;
+               ops.reserve(nops());
+               for (unsigned i=0; i<nops(); i++)
+                       ops.push_back(op(i));
+
+               // Now, for every term of the pattern, look for a matching term in
+               // the expression and remove the match
+               for (unsigned i=0; i<pattern.nops(); i++) {
+                       ex p = pattern.op(i);
+                       if (has_global_wildcard && p.is_equal(global_wildcard))
+                               continue;
+                       exvector::iterator it = ops.begin(), itend = ops.end();
+                       while (it != itend) {
+                               if (it->match(p, repl_lst)) {
+                                       ops.erase(it);
+                                       goto found;
+                               }
+                               it++;
+                       }
+                       return false; // no match found
+found:         ;
+               }
+
+               if (has_global_wildcard) {
+
+                       // Assign all the remaining terms to the global wildcard (unless
+                       // it has already been matched before, in which case the matches
+                       // must be equal)
+                       epvector *vp = new epvector();
+                       vp->reserve(ops.size());
+                       for (unsigned i=0; i<ops.size(); i++)
+                               vp->push_back(split_ex_to_pair(ops[i]));
+                       ex rest = thisexpairseq(vp, default_overall_coeff());
+                       for (unsigned i=0; i<repl_lst.nops(); i++) {
+                               if (repl_lst.op(i).op(0).is_equal(global_wildcard))
+                                       return rest.is_equal(*repl_lst.op(i).op(1).bp);
+                       }
+                       repl_lst.append(global_wildcard == rest);
+                       return true;
+
+               } else {
+
+                       // No global wildcard, then the match fails if there are any
+                       // unmatched terms left
+                       return ops.empty();
+               }
+       }
+       return inherited::match(pattern, repl_lst);
+}
+
+ex expairseq::subs(const lst &ls, const lst &lr, bool no_pattern) const
+{
+       epvector *vp = subschildren(ls, lr, no_pattern);
+       if (vp)
+               return thisexpairseq(vp, overall_coeff).bp->basic::subs(ls, lr, no_pattern);
+       else
+               return basic::subs(ls, lr, no_pattern);
 }
 
 // protected
@@ -1586,7 +1661,7 @@ epvector expairseq::diffchildren(const symbol &y) const
  *  @see expairseq::subs()
  *  @return pointer to epvector containing pairs after application of subs or zero
  *  pointer, if no members were changed. */
-epvector * expairseq::subschildren(const lst &ls, const lst &lr) const
+epvector * expairseq::subschildren(const lst &ls, const lst &lr, bool no_pattern) const
 {
        // returns a NULL pointer if nothing had to be substituted
        // returns a pointer to a newly created epvector otherwise
@@ -1596,7 +1671,7 @@ epvector * expairseq::subschildren(const lst &ls, const lst &lr) const
        epvector::const_iterator last = seq.end();
        epvector::const_iterator cit = seq.begin();
        while (cit!=last) {
-               const ex &subsed_ex=(*cit).rest.subs(ls,lr);
+               const ex &subsed_ex=(*cit).rest.subs(ls,lr,no_pattern);
                if (!are_ex_trivially_equal((*cit).rest,subsed_ex)) {
                        
                        // something changed, copy seq, subs and return it
@@ -1615,7 +1690,7 @@ epvector * expairseq::subschildren(const lst &ls, const lst &lr) const
                        ++cit2;
                        // copy rest
                        while (cit2!=last) {
-                               s->push_back(combine_ex_with_coeff_to_pair((*cit2).rest.subs(ls,lr),
+                               s->push_back(combine_ex_with_coeff_to_pair((*cit2).rest.subs(ls,lr,no_pattern),
                                                                           (*cit2).coeff));
                                ++cit2;
                        }