]> www.ginac.de Git - ginac.git/blobdiff - ginac/normal.cpp
when there are multiple variables with the same maximum degree, the one with
[ginac.git] / ginac / normal.cpp
index 0e1a0bfa4f39b2e8f5911cb4813e969c5e3ac52d..3eaf4206bbbe314d1ea2ff667db8a792c0d32853 100644 (file)
@@ -6,7 +6,7 @@
  *  computation, square-free factorization and rational function normalization. */
 
 /*
- *  GiNaC Copyright (C) 1999-2000 Johannes Gutenberg University Mainz, Germany
+ *  GiNaC Copyright (C) 1999-2001 Johannes Gutenberg University Mainz, Germany
  *
  *  This program is free software; you can redistribute it and/or modify
  *  it under the terms of the GNU General Public License as published by
@@ -34,7 +34,6 @@
 #include "constant.h"
 #include "expairseq.h"
 #include "fail.h"
-#include "indexed.h"
 #include "inifcns.h"
 #include "lst.h"
 #include "mul.h"
@@ -140,8 +139,17 @@ struct sym_desc {
        /** Maximum of deg_a and deg_b (Used for sorting) */
        int max_deg;
 
+       /** Maximum number of terms of leading coefficient of symbol in both polynomials */
+       int max_lcnops;
+
        /** Commparison operator for sorting */
-       bool operator<(const sym_desc &x) const {return max_deg < x.max_deg;}
+       bool operator<(const sym_desc &x) const
+       {
+               if (max_deg == x.max_deg)
+                       return max_lcnops < x.max_lcnops;
+               else
+                       return max_deg < x.max_deg;
+       }
 };
 
 // Vector of sym_desc structures
@@ -196,7 +204,8 @@ static void get_symbol_stats(const ex &a, const ex &b, sym_desc_vec &v)
                int deg_b = b.degree(*(it->sym));
                it->deg_a = deg_a;
                it->deg_b = deg_b;
-               it->max_deg = max(deg_a, deg_b);
+               it->max_deg = std::max(deg_a, deg_b);
+               it->max_lcnops = std::max(a.lcoeff(*(it->sym)).nops(), b.lcoeff(*(it->sym)).nops());
                it->ldeg_a = a.ldegree(*(it->sym));
                it->ldeg_b = b.ldegree(*(it->sym));
                it++;
@@ -206,7 +215,7 @@ static void get_symbol_stats(const ex &a, const ex &b, sym_desc_vec &v)
        std::clog << "Symbols:\n";
        it = v.begin(); itend = v.end();
        while (it != itend) {
-               std::clog << " " << *it->sym << ": deg_a=" << it->deg_a << ", deg_b=" << it->deg_b << ", ldeg_a=" << it->ldeg_a << ", ldeg_b=" << it->ldeg_b << ", max_deg=" << it->max_deg << endl;
+               std::clog << " " << *it->sym << ": deg_a=" << it->deg_a << ", deg_b=" << it->deg_b << ", ldeg_a=" << it->ldeg_a << ", ldeg_b=" << it->ldeg_b << ", max_deg=" << it->max_deg << ", max_lcnops=" << it->max_lcnops << endl;
                std::clog << "  lcoeff_a=" << a.lcoeff(*(it->sym)) << ", lcoeff_b=" << b.lcoeff(*(it->sym)) << endl;
                it++;
        }
@@ -1353,10 +1362,11 @@ static ex heur_gcd(const ex &a, const ex &b, ex *ca, ex *cb, sym_desc_vec::const
        numeric rgc = gc.inverse();
        ex p = a * rgc;
        ex q = b * rgc;
-       int maxdeg = max(p.degree(x), q.degree(x));
-
+       int maxdeg =  std::max(p.degree(x),q.degree(x));
+       
        // Find evaluation point
-       numeric mp = p.max_coefficient(), mq = q.max_coefficient();
+       numeric mp = p.max_coefficient();
+       numeric mq = q.max_coefficient();
        numeric xi;
        if (mp > mq)
                xi = mq * _num2() + _num2();
@@ -1463,7 +1473,7 @@ ex gcd(const ex &a, const ex &b, ex *ca, ex *cb, bool check_args)
        }
 
        // Check arguments
-       if (check_args && !a.info(info_flags::rational_polynomial) || !b.info(info_flags::rational_polynomial)) {
+       if (check_args && (!a.info(info_flags::rational_polynomial) || !b.info(info_flags::rational_polynomial))) {
                throw(std::invalid_argument("gcd: arguments must be polynomials over the rationals"));
        }
 
@@ -1595,7 +1605,7 @@ factored_b:
        // Cancel trivial common factor
        int ldeg_a = var->ldeg_a;
        int ldeg_b = var->ldeg_b;
-       int min_ldeg = min(ldeg_a, ldeg_b);
+       int min_ldeg = std::min(ldeg_a,ldeg_b);
        if (min_ldeg > 0) {
                ex common = power(x, min_ldeg);
 //std::clog << "trivial common factor " << common << endl;
@@ -1677,7 +1687,7 @@ ex lcm(const ex &a, const ex &b, bool check_args)
 {
        if (is_ex_exactly_of_type(a, numeric) && is_ex_exactly_of_type(b, numeric))
                return lcm(ex_to_numeric(a), ex_to_numeric(b));
-       if (check_args && !a.info(info_flags::rational_polynomial) || !b.info(info_flags::rational_polynomial))
+       if (check_args && (!a.info(info_flags::rational_polynomial) || !b.info(info_flags::rational_polynomial)))
                throw(std::invalid_argument("lcm: arguments must be polynomials over the rationals"));
        
        ex ca, cb;
@@ -1865,9 +1875,13 @@ static ex frac_cancel(const ex &n, const ex &d)
 
 //std::clog << "frac_cancel num = " << num << ", den = " << den << endl;
 
+       // Handle trivial case where denominator is 1
+       if (den.is_equal(_ex1()))
+               return (new lst(num, den))->setflag(status_flags::dynallocated);
+
        // Handle special cases where numerator or denominator is 0
        if (num.is_zero())
-               return (new lst(_ex0(), _ex1()))->setflag(status_flags::dynallocated);
+               return (new lst(num, _ex1()))->setflag(status_flags::dynallocated);
        if (den.expand().is_zero())
                throw(std::overflow_error("frac_cancel: division by zero in frac_cancel"));
 
@@ -1909,78 +1923,56 @@ static ex frac_cancel(const ex &n, const ex &d)
 ex add::normal(lst &sym_lst, lst &repl_lst, int level) const
 {
        if (level == 1)
-               return (new lst(*this, _ex1()))->setflag(status_flags::dynallocated);
+               return (new lst(replace_with_symbol(*this, sym_lst, repl_lst), _ex1()))->setflag(status_flags::dynallocated);
        else if (level == -max_recursion_level)
                throw(std::runtime_error("max recursion level reached"));
 
-       // Normalize and expand children, chop into summands
-       exvector o;
-       o.reserve(seq.size()+1);
+       // Normalize children and split each one into numerator and denominator
+       exvector nums, dens;
+       nums.reserve(seq.size()+1);
+       dens.reserve(seq.size()+1);
        epvector::const_iterator it = seq.begin(), itend = seq.end();
        while (it != itend) {
-
-               // Normalize and expand child
-               ex n = recombine_pair_to_ex(*it).bp->normal(sym_lst, repl_lst, level-1).expand();
-
-               // If numerator is a sum, chop into summands
-               if (is_ex_exactly_of_type(n.op(0), add)) {
-                       epvector::const_iterator bit = ex_to_add(n.op(0)).seq.begin(), bitend = ex_to_add(n.op(0)).seq.end();
-                       while (bit != bitend) {
-                               o.push_back((new lst(recombine_pair_to_ex(*bit), n.op(1)))->setflag(status_flags::dynallocated));
-                               bit++;
-                       }
-
-                       // The overall_coeff is already normalized (== rational), we just
-                       // split it into numerator and denominator
-                       GINAC_ASSERT(ex_to_numeric(ex_to_add(n.op(0)).overall_coeff).is_rational());
-                       numeric overall = ex_to_numeric(ex_to_add(n.op(0)).overall_coeff);
-                       o.push_back((new lst(overall.numer(), overall.denom() * n.op(1)))->setflag(status_flags::dynallocated));
-               } else
-                       o.push_back(n);
+               ex n = recombine_pair_to_ex(*it).bp->normal(sym_lst, repl_lst, level-1);
+               nums.push_back(n.op(0));
+               dens.push_back(n.op(1));
                it++;
        }
-       o.push_back(overall_coeff.bp->normal(sym_lst, repl_lst, level-1));
-
-       // o is now a vector of {numerator, denominator} lists
+       ex n = overall_coeff.bp->normal(sym_lst, repl_lst, level-1);
+       nums.push_back(n.op(0));
+       dens.push_back(n.op(1));
+       GINAC_ASSERT(nums.size() == dens.size());
+
+       // Now, nums is a vector of all numerators and dens is a vector of
+       // all denominators
+//std::clog << "add::normal uses " << nums.size() << " summands:\n";
+
+       // Add fractions sequentially
+       exvector::const_iterator num_it = nums.begin(), num_itend = nums.end();
+       exvector::const_iterator den_it = dens.begin(), den_itend = dens.end();
+//std::clog << " num = " << *num_it << ", den = " << *den_it << endl;
+       ex num = *num_it++, den = *den_it++;
+       while (num_it != num_itend) {
+//std::clog << " num = " << *num_it << ", den = " << *den_it << endl;
+               ex next_num = *num_it++, next_den = *den_it++;
+
+               // Trivially add sequences of fractions with identical denominators
+               while ((den_it != den_itend) && next_den.is_equal(*den_it)) {
+                       next_num += *num_it;
+                       num_it++; den_it++;
+               }
 
-       // Determine common denominator
-       ex den = _ex1();
-       exvector::const_iterator ait = o.begin(), aitend = o.end();
-//std::clog << "add::normal uses the following summands:\n";
-       while (ait != aitend) {
-//std::clog << " num = " << ait->op(0) << ", den = " << ait->op(1) << endl;
-               den = lcm(ait->op(1), den, false);
-               ait++;
+               // Additiion of two fractions, taking advantage of the fact that
+               // the heuristic GCD algorithm computes the cofactors at no extra cost
+               ex co_den1, co_den2;
+               ex g = gcd(den, next_den, &co_den1, &co_den2, false);
+               num = ((num * co_den2) + (next_num * co_den1)).expand();
+               den *= co_den2;         // this is the lcm(den, next_den)
        }
 //std::clog << " common denominator = " << den << endl;
 
-       // Add fractions
-       if (den.is_equal(_ex1())) {
-
-               // Common denominator is 1, simply add all fractions
-               exvector num_seq;
-               for (ait=o.begin(); ait!=aitend; ait++) {
-                       num_seq.push_back(ait->op(0) / ait->op(1));
-               }
-               return (new lst((new add(num_seq))->setflag(status_flags::dynallocated), den))->setflag(status_flags::dynallocated);
-
-       } else {
-
-               // Perform fractional addition
-               exvector num_seq;
-               for (ait=o.begin(); ait!=aitend; ait++) {
-                       ex q;
-                       if (!divide(den, ait->op(1), q, false)) {
-                               // should not happen
-                               throw(std::runtime_error("invalid expression in add::normal, division failed"));
-                       }
-                       num_seq.push_back((ait->op(0) * q).expand());
-               }
-               ex num = (new add(num_seq))->setflag(status_flags::dynallocated);
-
-               // Cancel common factors from num/den
-               return frac_cancel(num, den);
-       }
+       // Cancel common factors from num/den
+       return frac_cancel(num, den);
 }
 
 
@@ -1990,7 +1982,7 @@ ex add::normal(lst &sym_lst, lst &repl_lst, int level) const
 ex mul::normal(lst &sym_lst, lst &repl_lst, int level) const
 {
        if (level == 1)
-               return (new lst(*this, _ex1()))->setflag(status_flags::dynallocated);
+               return (new lst(replace_with_symbol(*this, sym_lst, repl_lst), _ex1()))->setflag(status_flags::dynallocated);
        else if (level == -max_recursion_level)
                throw(std::runtime_error("max recursion level reached"));
 
@@ -2021,50 +2013,52 @@ ex mul::normal(lst &sym_lst, lst &repl_lst, int level) const
 ex power::normal(lst &sym_lst, lst &repl_lst, int level) const
 {
        if (level == 1)
-               return (new lst(*this, _ex1()))->setflag(status_flags::dynallocated);
+               return (new lst(replace_with_symbol(*this, sym_lst, repl_lst), _ex1()))->setflag(status_flags::dynallocated);
        else if (level == -max_recursion_level)
                throw(std::runtime_error("max recursion level reached"));
 
-       // Normalize basis
-       ex n = basis.bp->normal(sym_lst, repl_lst, level-1);
+       // Normalize basis and exponent (exponent gets reassembled)
+       ex n_basis = basis.bp->normal(sym_lst, repl_lst, level-1);
+       ex n_exponent = exponent.bp->normal(sym_lst, repl_lst, level-1);
+       n_exponent = n_exponent.op(0) / n_exponent.op(1);
 
-       if (exponent.info(info_flags::integer)) {
+       if (n_exponent.info(info_flags::integer)) {
 
-               if (exponent.info(info_flags::positive)) {
+               if (n_exponent.info(info_flags::positive)) {
 
                        // (a/b)^n -> {a^n, b^n}
-                       return (new lst(power(n.op(0), exponent), power(n.op(1), exponent)))->setflag(status_flags::dynallocated);
+                       return (new lst(power(n_basis.op(0), n_exponent), power(n_basis.op(1), n_exponent)))->setflag(status_flags::dynallocated);
 
-               } else if (exponent.info(info_flags::negative)) {
+               } else if (n_exponent.info(info_flags::negative)) {
 
                        // (a/b)^-n -> {b^n, a^n}
-                       return (new lst(power(n.op(1), -exponent), power(n.op(0), -exponent)))->setflag(status_flags::dynallocated);
+                       return (new lst(power(n_basis.op(1), -n_exponent), power(n_basis.op(0), -n_exponent)))->setflag(status_flags::dynallocated);
                }
 
        } else {
 
-               if (exponent.info(info_flags::positive)) {
+               if (n_exponent.info(info_flags::positive)) {
 
                        // (a/b)^x -> {sym((a/b)^x), 1}
-                       return (new lst(replace_with_symbol(power(n.op(0) / n.op(1), exponent), sym_lst, repl_lst), _ex1()))->setflag(status_flags::dynallocated);
+                       return (new lst(replace_with_symbol(power(n_basis.op(0) / n_basis.op(1), n_exponent), sym_lst, repl_lst), _ex1()))->setflag(status_flags::dynallocated);
 
-               } else if (exponent.info(info_flags::negative)) {
+               } else if (n_exponent.info(info_flags::negative)) {
 
-                       if (n.op(1).is_equal(_ex1())) {
+                       if (n_basis.op(1).is_equal(_ex1())) {
 
                                // a^-x -> {1, sym(a^x)}
-                               return (new lst(_ex1(), replace_with_symbol(power(n.op(0), -exponent), sym_lst, repl_lst)))->setflag(status_flags::dynallocated);
+                               return (new lst(_ex1(), replace_with_symbol(power(n_basis.op(0), -n_exponent), sym_lst, repl_lst)))->setflag(status_flags::dynallocated);
 
                        } else {
 
                                // (a/b)^-x -> {sym((b/a)^x), 1}
-                               return (new lst(replace_with_symbol(power(n.op(1) / n.op(0), -exponent), sym_lst, repl_lst), _ex1()))->setflag(status_flags::dynallocated);
+                               return (new lst(replace_with_symbol(power(n_basis.op(1) / n_basis.op(0), -n_exponent), sym_lst, repl_lst), _ex1()))->setflag(status_flags::dynallocated);
                        }
 
-               } else {        // exponent not numeric
+               } else {        // n_exponent not numeric
 
                        // (a/b)^x -> {sym((a/b)^x, 1}
-                       return (new lst(replace_with_symbol(power(n.op(0) / n.op(1), exponent), sym_lst, repl_lst), _ex1()))->setflag(status_flags::dynallocated);
+                       return (new lst(replace_with_symbol(power(n_basis.op(0) / n_basis.op(1), n_exponent), sym_lst, repl_lst), _ex1()))->setflag(status_flags::dynallocated);
                }
        }
 }
@@ -2075,15 +2069,13 @@ ex power::normal(lst &sym_lst, lst &repl_lst, int level) const
  *  @see ex::normal */
 ex pseries::normal(lst &sym_lst, lst &repl_lst, int level) const
 {
-       epvector new_seq;
-       new_seq.reserve(seq.size());
-
-       epvector::const_iterator it = seq.begin(), itend = seq.end();
-       while (it != itend) {
-               new_seq.push_back(expair(it->rest.normal(), it->coeff));
-               it++;
+       epvector newseq;
+       for (epvector::const_iterator i=seq.begin(); i!=seq.end(); ++i) {
+               ex restexp = i->rest.normal();
+               if (!restexp.is_zero())
+                       newseq.push_back(expair(restexp, i->coeff));
        }
-       ex n = pseries(relational(var,point), new_seq);
+       ex n = pseries(relational(var,point), newseq);
        return (new lst(replace_with_symbol(n, sym_lst, repl_lst), _ex1()))->setflag(status_flags::dynallocated);
 }