]> www.ginac.de Git - ginac.git/blobdiff - ginac/normal.cpp
replaced fraction addition algorithm with a slightly faster one (no polynomial
[ginac.git] / ginac / normal.cpp
index 1cbb87737cf305b96ef50590a704690dd954b1bb..067670ced9f380b7aff2623755bde19f72704293 100644 (file)
@@ -6,7 +6,7 @@
  *  computation, square-free factorization and rational function normalization. */
 
 /*
- *  GiNaC Copyright (C) 1999-2000 Johannes Gutenberg University Mainz, Germany
+ *  GiNaC Copyright (C) 1999-2001 Johannes Gutenberg University Mainz, Germany
  *
  *  This program is free software; you can redistribute it and/or modify
  *  it under the terms of the GNU General Public License as published by
@@ -1910,13 +1910,15 @@ static ex frac_cancel(const ex &n, const ex &d)
 ex add::normal(lst &sym_lst, lst &repl_lst, int level) const
 {
        if (level == 1)
-               return (new lst(*this, _ex1()))->setflag(status_flags::dynallocated);
+               return (new lst(replace_with_symbol(*this, sym_lst, repl_lst), _ex1()))->setflag(status_flags::dynallocated);
        else if (level == -max_recursion_level)
                throw(std::runtime_error("max recursion level reached"));
 
-       // Normalize and expand children, chop into summands
-       exvector o;
-       o.reserve(seq.size()+1);
+       // Normalize and expand children, chop into summands and split each
+       // one into numerator and denominator
+       exvector nums, dens;
+       nums.reserve(seq.size()+1);
+       dens.reserve(seq.size()+1);
        epvector::const_iterator it = seq.begin(), itend = seq.end();
        while (it != itend) {
 
@@ -1927,7 +1929,8 @@ ex add::normal(lst &sym_lst, lst &repl_lst, int level) const
                if (is_ex_exactly_of_type(n.op(0), add)) {
                        epvector::const_iterator bit = ex_to_add(n.op(0)).seq.begin(), bitend = ex_to_add(n.op(0)).seq.end();
                        while (bit != bitend) {
-                               o.push_back((new lst(recombine_pair_to_ex(*bit), n.op(1)))->setflag(status_flags::dynallocated));
+                               nums.push_back(recombine_pair_to_ex(*bit));
+                               dens.push_back(n.op(1));
                                bit++;
                        }
 
@@ -1935,53 +1938,40 @@ ex add::normal(lst &sym_lst, lst &repl_lst, int level) const
                        // split it into numerator and denominator
                        GINAC_ASSERT(ex_to_numeric(ex_to_add(n.op(0)).overall_coeff).is_rational());
                        numeric overall = ex_to_numeric(ex_to_add(n.op(0)).overall_coeff);
-                       o.push_back((new lst(overall.numer(), overall.denom() * n.op(1)))->setflag(status_flags::dynallocated));
-               } else
-                       o.push_back(n);
+                       nums.push_back(overall.numer());
+                       dens.push_back(overall.denom() * n.op(1));
+               } else {
+                       nums.push_back(n.op(0));
+                       dens.push_back(n.op(1));
+               }
                it++;
        }
-       o.push_back(overall_coeff.bp->normal(sym_lst, repl_lst, level-1));
+       ex n = overall_coeff.bp->normal(sym_lst, repl_lst, level-1);
+       nums.push_back(n.op(0));
+       dens.push_back(n.op(1));
+       GINAC_ASSERT(nums.size() == dens.size());
 
-       // o is now a vector of {numerator, denominator} lists
+       // Now, nums is a vector of all numerators and dens is a vector of
+       // all denominators
 
-       // Determine common denominator
-       ex den = _ex1();
-       exvector::const_iterator ait = o.begin(), aitend = o.end();
+       // Add fractions sequentially
+       exvector::const_iterator num_it = nums.begin(), num_itend = nums.end();
+       exvector::const_iterator den_it = dens.begin(), den_itend = dens.end();
 //std::clog << "add::normal uses the following summands:\n";
-       while (ait != aitend) {
-//std::clog << " num = " << ait->op(0) << ", den = " << ait->op(1) << endl;
-               den = lcm(ait->op(1), den, false);
-               ait++;
+//std::clog << " num = " << *num_it << ", den = " << *den_it << endl;
+       ex num = *num_it++, den = *den_it++;
+       while (num_it != num_itend) {
+//std::clog << " num = " << *num_it << ", den = " << *den_it << endl;
+               ex co_den1, co_den2;
+               ex g = gcd(den, *den_it, &co_den1, &co_den2, false);
+               num = (num * co_den2) + (*num_it * co_den1);
+               den *= co_den2;         // this is the lcm(den, *den_it)
+               num_it++; den_it++;
        }
 //std::clog << " common denominator = " << den << endl;
 
-       // Add fractions
-       if (den.is_equal(_ex1())) {
-
-               // Common denominator is 1, simply add all fractions
-               exvector num_seq;
-               for (ait=o.begin(); ait!=aitend; ait++) {
-                       num_seq.push_back(ait->op(0) / ait->op(1));
-               }
-               return (new lst((new add(num_seq))->setflag(status_flags::dynallocated), den))->setflag(status_flags::dynallocated);
-
-       } else {
-
-               // Perform fractional addition
-               exvector num_seq;
-               for (ait=o.begin(); ait!=aitend; ait++) {
-                       ex q;
-                       if (!divide(den, ait->op(1), q, false)) {
-                               // should not happen
-                               throw(std::runtime_error("invalid expression in add::normal, division failed"));
-                       }
-                       num_seq.push_back((ait->op(0) * q).expand());
-               }
-               ex num = (new add(num_seq))->setflag(status_flags::dynallocated);
-
-               // Cancel common factors from num/den
-               return frac_cancel(num, den);
-       }
+       // Cancel common factors from num/den
+       return frac_cancel(num, den);
 }
 
 
@@ -1991,7 +1981,7 @@ ex add::normal(lst &sym_lst, lst &repl_lst, int level) const
 ex mul::normal(lst &sym_lst, lst &repl_lst, int level) const
 {
        if (level == 1)
-               return (new lst(*this, _ex1()))->setflag(status_flags::dynallocated);
+               return (new lst(replace_with_symbol(*this, sym_lst, repl_lst), _ex1()))->setflag(status_flags::dynallocated);
        else if (level == -max_recursion_level)
                throw(std::runtime_error("max recursion level reached"));
 
@@ -2022,7 +2012,7 @@ ex mul::normal(lst &sym_lst, lst &repl_lst, int level) const
 ex power::normal(lst &sym_lst, lst &repl_lst, int level) const
 {
        if (level == 1)
-               return (new lst(*this, _ex1()))->setflag(status_flags::dynallocated);
+               return (new lst(replace_with_symbol(*this, sym_lst, repl_lst), _ex1()))->setflag(status_flags::dynallocated);
        else if (level == -max_recursion_level)
                throw(std::runtime_error("max recursion level reached"));