// Solve linear system
//////////
+static void insert_symbols(exset &es, const ex &e)
+{
+ if (is_a<symbol>(e)) {
+ es.insert(e);
+ } else {
+ for (const ex &sube : e) {
+ insert_symbols(es, sube);
+ }
+ }
+}
+
+static exset symbolset(const ex &e)
+{
+ exset s;
+ insert_symbols(s, e);
+ return s;
+}
+
ex lsolve(const ex &eqns, const ex &symbols, unsigned options)
{
// solve a system of linear equations
for (size_t r=0; r<eqns.nops(); r++) {
const ex eq = eqns.op(r).op(0)-eqns.op(r).op(1); // lhs-rhs==0
+ const exset syms = symbolset(eq);
ex linpart = eq;
for (size_t c=0; c<symbols.nops(); c++) {
+ if (syms.count(symbols.op(c)) == 0) continue;
const ex co = eq.coeff(ex_to<symbol>(symbols.op(c)),1);
linpart -= co*symbols.op(c);
sys(r,c) = co;
}
// test if system is linear and fill vars matrix
+ const exset sys_syms = symbolset(sys);
+ const exset rhs_syms = symbolset(rhs);
for (size_t i=0; i<symbols.nops(); i++) {
vars(i,0) = symbols.op(i);
- if (sys.has(symbols.op(i)))
+ if (sys_syms.count(symbols.op(i)) != 0)
throw(std::logic_error("lsolve: system is not linear"));
- if (rhs.has(symbols.op(i)))
+ if (rhs_syms.count(symbols.op(i)) != 0)
throw(std::logic_error("lsolve: system is not linear"));
}