Alexeis patches for better handling of collect_common_factors.
authorChris Dams <Chris.Dams@mi.infn.it>
Wed, 22 Nov 2006 21:52:33 +0000 (21:52 +0000)
committerChris Dams <Chris.Dams@mi.infn.it>
Wed, 22 Nov 2006 21:52:33 +0000 (21:52 +0000)
ginac/normal.cpp

index 7b90030de6b0fb595b781b9f3c7a532ad4e6c015..b2a4e8c8d0cfdfe2204799fefe96b2710771cb24 100644 (file)
@@ -614,6 +614,72 @@ bool divide(const ex &a, const ex &b, ex &q, bool check_args)
        if (!get_first_symbol(a, x) && !get_first_symbol(b, x))
                throw(std::invalid_argument("invalid expression in divide()"));
 
+       // Try to avoid expanding partially factored expressions.
+       if (is_exactly_a<mul>(b)) {
+       // Divide sequentially by each term
+               ex rem_new, rem_old = a;
+               for (size_t i=0; i < b.nops(); i++) {
+                       if (! divide(rem_old, b.op(i), rem_new, false))
+                               return false;
+                       rem_old = rem_new;
+               }
+               q = rem_new;
+               return true;
+       } else if (is_exactly_a<power>(b)) {
+               const ex& bb(b.op(0));
+               int exp_b = ex_to<numeric>(b.op(1)).to_int();
+               ex rem_new, rem_old = a;
+               for (int i=exp_b; i>0; i--) {
+                       if (! divide(rem_old, bb, rem_new, false))
+                               return false;
+                       rem_old = rem_new;
+               }
+               q = rem_new;
+               return true;
+       } 
+       
+       if (is_exactly_a<mul>(a)) {
+               // Divide sequentially each term. If some term in a is divisible 
+               // by b we are done... and if not, we can't really say anything.
+               size_t i;
+               ex rem_i;
+               bool divisible_p = false;
+               for (i=0; i < a.nops(); ++i) {
+                       if (divide(a.op(i), b, rem_i, false)) {
+                               divisible_p = true;
+                               break;
+                       }
+               }
+               if (divisible_p) {
+                       exvector resv;
+                       resv.reserve(a.nops());
+                       for (size_t j=0; j < a.nops(); j++) {
+                               if (j==i)
+                                       resv.push_back(rem_i);
+                               else
+                                       resv.push_back(a.op(j));
+                       }
+                       q = (new mul(resv))->setflag(status_flags::dynallocated);
+                       return true;
+               }
+       } else if (is_exactly_a<power>(a)) {
+               // The base itself might be divisible by b, in that case we don't
+               // need to expand a
+               const ex& ab(a.op(0));
+               int a_exp = ex_to<numeric>(a.op(1)).to_int();
+               ex rem_i;
+               if (divide(ab, b, rem_i, false)) {
+                       q = rem_i*power(ab, a_exp - 1);
+                       return true;
+               }
+               for (int i=2; i < a_exp; i++) {
+                       if (divide(power(ab, i), b, rem_i, false)) {
+                               q = rem_i*power(ab, a_exp - i);
+                               return true;
+                       }
+               } // ... so we *really* need to expand expression.
+       }
+       
        // Polynomial long division (recursive)
        ex r = a.expand();
        if (r.is_zero()) {
@@ -2389,7 +2455,16 @@ ex power::to_polynomial(exmap & repl) const
        if (exponent.info(info_flags::posint))
                return power(basis.to_rational(repl), exponent);
        else if (exponent.info(info_flags::negint))
-               return power(replace_with_symbol(power(basis, _ex_1), repl), -exponent);
+       {
+               ex basis_pref = collect_common_factors(basis);
+               if (is_exactly_a<mul>(basis_pref) || is_exactly_a<power>(basis_pref)) {
+                       // (A*B)^n will be automagically transformed to A^n*B^n
+                       ex t = power(basis_pref, exponent);
+                       return t.to_polynomial(repl);
+               }
+               else
+                       return power(replace_with_symbol(power(basis, _ex_1), repl), -exponent);
+       } 
        else
                return replace_with_symbol(*this, repl);
 }
@@ -2447,7 +2522,7 @@ static ex find_common_factor(const ex & e, ex & factor, exmap & repl)
                for (size_t i=0; i<num; i++) {
                        ex x = e.op(i).to_polynomial(repl);
 
-                       if (is_exactly_a<add>(x) || is_exactly_a<mul>(x)) {
+                       if (is_exactly_a<add>(x) || is_exactly_a<mul>(x) || is_a<power>(x)) {
                                ex f = 1;
                                x = find_common_factor(x, f, repl);
                                x *= f;
@@ -2507,7 +2582,7 @@ term_done:        ;
 
        } else if (is_exactly_a<power>(e)) {
                const ex e_exp(e.op(1));
-               if (e_exp.info(info_flags::posint)) {
+               if (e_exp.info(info_flags::integer)) {
                        ex eb = e.op(0).to_polynomial(repl);
                        ex factor_local(_ex1);
                        ex pre_res = find_common_factor(eb, factor_local, repl);