#include "expairseq.h"
#include "lst.h"
+#include "relational.h"
#include "print.h"
#include "archive.h"
#include "debugmsg.h"
return n.bp->basic::normal(sym_lst,repl_lst,level);
}
-ex expairseq::subs(const lst &ls, const lst &lr) const
+bool expairseq::match(const ex & pattern, lst & repl_lst) const
{
- epvector *vp = subschildren(ls,lr);
- if (vp==0)
- return inherited::subs(ls, lr);
-
- return thisexpairseq(vp,overall_coeff);
+//clog << "match " << *this << " with " << pattern << ", repl_lst = " << repl_lst << endl;
+ // This differs from basic::match() because we want "a+b+c+d" to
+ // match "d+*+b" with "*" being "a+c", and we want to honor commutativity
+
+ if (tinfo() == pattern.bp->tinfo()) {
+
+ // Check whether global wildcard (one that matches the "rest of the
+ // expression", like "*" above) is present
+ bool has_global_wildcard = false;
+ ex global_wildcard;
+ for (unsigned int i=0; i<pattern.nops(); i++) {
+ if (is_ex_exactly_of_type(pattern.op(i), wildcard)) {
+ has_global_wildcard = true;
+ global_wildcard = pattern.op(i);
+ break;
+ }
+ }
+
+ // Unfortunately, this is an O(N^2) operation because we can't
+ // sort the pattern in a useful way...
+
+ // Chop into terms
+ exvector ops;
+ ops.reserve(nops());
+ for (unsigned i=0; i<nops(); i++)
+ ops.push_back(op(i));
+
+ // Now, for every term of the pattern, look for a matching term in
+ // the expression and remove the match
+ for (unsigned i=0; i<pattern.nops(); i++) {
+ ex p = pattern.op(i);
+ if (has_global_wildcard && p.is_equal(global_wildcard))
+ continue;
+ exvector::iterator it = ops.begin(), itend = ops.end();
+ while (it != itend) {
+ if (it->match(p, repl_lst)) {
+ ops.erase(it);
+ goto found;
+ }
+ it++;
+ }
+ return false; // no match found
+found: ;
+ }
+
+ if (has_global_wildcard) {
+
+ // Assign all the remaining terms to the global wildcard (unless
+ // it has already been matched before, in which case the matches
+ // must be equal)
+ epvector *vp = new epvector();
+ vp->reserve(ops.size());
+ for (unsigned i=0; i<ops.size(); i++)
+ vp->push_back(split_ex_to_pair(ops[i]));
+ ex rest = thisexpairseq(vp, default_overall_coeff());
+ for (unsigned i=0; i<repl_lst.nops(); i++) {
+ if (repl_lst.op(i).op(0).is_equal(global_wildcard))
+ return rest.is_equal(*repl_lst.op(i).op(1).bp);
+ }
+ repl_lst.append(global_wildcard == rest);
+ return true;
+
+ } else {
+
+ // No global wildcard, then the match fails if there are any
+ // unmatched terms left
+ return ops.empty();
+ }
+ }
+ return inherited::match(pattern, repl_lst);
+}
+
+ex expairseq::subs(const lst &ls, const lst &lr, bool no_pattern) const
+{
+ epvector *vp = subschildren(ls, lr, no_pattern);
+ if (vp)
+ return thisexpairseq(vp, overall_coeff).bp->basic::subs(ls, lr, no_pattern);
+ else
+ return basic::subs(ls, lr, no_pattern);
}
// protected
* @see expairseq::subs()
* @return pointer to epvector containing pairs after application of subs or zero
* pointer, if no members were changed. */
-epvector * expairseq::subschildren(const lst &ls, const lst &lr) const
+epvector * expairseq::subschildren(const lst &ls, const lst &lr, bool no_pattern) const
{
// returns a NULL pointer if nothing had to be substituted
// returns a pointer to a newly created epvector otherwise
epvector::const_iterator last = seq.end();
epvector::const_iterator cit = seq.begin();
while (cit!=last) {
- const ex &subsed_ex=(*cit).rest.subs(ls,lr);
+ const ex &subsed_ex=(*cit).rest.subs(ls,lr,no_pattern);
if (!are_ex_trivially_equal((*cit).rest,subsed_ex)) {
// something changed, copy seq, subs and return it
++cit2;
// copy rest
while (cit2!=last) {
- s->push_back(combine_ex_with_coeff_to_pair((*cit2).rest.subs(ls,lr),
+ s->push_back(combine_ex_with_coeff_to_pair((*cit2).rest.subs(ls,lr,no_pattern),
(*cit2).coeff));
++cit2;
}