]> www.ginac.de Git - ginac.git/blob - ginac/polynomial/collect_vargs.cpp
Copyright goes 2010.
[ginac.git] / ginac / polynomial / collect_vargs.cpp
1 /** @file collect_vargs.cpp
2  *
3  *  Utility functions for collecting the terms of a polynomial with respect to a list of variables. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2010 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
21  */
22
23 #include "add.h"
24 #include "mul.h"
25 #include "operators.h"
26 #include "power.h"
27 #include "collect_vargs.h"
28 #include "smod_helpers.h"
29 #include "debug.h"
30
31 #include <algorithm>
32 #include <cln/integer.h>
33 #include <iterator>
34 #include <map>
35 #include <stdexcept>
36 #include <string>
37
38 namespace GiNaC {
39
40 typedef std::map<exp_vector_t, ex> ex_collect_priv_t;
41
42 static void 
43 collect_vargs(ex_collect_priv_t& ec, ex e, const exvector& vars);
44 static void
45 collect_term(ex_collect_priv_t& ec, const ex& e, const exvector& vars);
46 static void wipe_out_zeros(ex_collect_priv_t& ec);
47
/** Strict weak ordering on collected terms.
 *
 *  A term is a pair whose @c first member is an exponent vector and whose
 *  @c second member is a coefficient. Exponent vectors are compared
 *  lexicographically starting from their last entry, so the last variable
 *  is the most significant one. Terms with identical exponent vectors are
 *  ordered by their coefficients using @a CoeffCMP.
 *
 *  @tparam T         term type (pair-like: @c first supports reverse
 *                    iteration and ==, @c second is the coefficient)
 *  @tparam CoeffCMP  strict weak ordering on coefficients; its
 *                    operator() must be const */
template<typename T, typename CoeffCMP>
struct compare_terms
{
        // Stored by value: sorting algorithms copy the comparator freely,
        // and a reference member would dangle once the caller's (often
        // temporary) comparator object is destroyed.
        CoeffCMP coeff_cmp;

        explicit compare_terms(const CoeffCMP& coeff_cmp_) : coeff_cmp(coeff_cmp_)
        { }

        inline bool operator()(const T& t1, const T& t2) const
        {
                bool exponent_is_less =
                        std::lexicographical_compare(t1.first.rbegin(),
                                                     t1.first.rend(),
                                                     t2.first.rbegin(),
                                                     t2.first.rend());
                if (exponent_is_less)
                        return true;

                // Same monomial: break the tie by comparing the two
                // coefficients (t1 against t2, not t2 against itself).
                if (t1.first == t2.first)
                        return coeff_cmp(t1.second, t2.second);
                return false;
        }
};
70
71 template<typename T, typename CoeffCMP>
72 static struct compare_terms<T, CoeffCMP>
73 make_compare_terms(const T& dummy, const CoeffCMP& coeff_cmp)
74 {
75         return compare_terms<T, CoeffCMP>(coeff_cmp);
76 }
77
78 void collect_vargs(ex_collect_t& ec, const ex& e, const exvector& vars)
79 {
80         ex_collect_priv_t ecp;
81         collect_vargs(ecp, e, vars);
82         ec.reserve(ecp.size());
83         std::copy(ecp.begin(), ecp.end(), std::back_inserter(ec));
84         std::sort(ec.begin(), ec.end(),
85                   make_compare_terms(*ec.begin(), ex_is_less()));
86 }
87
88 static void 
89 collect_vargs(ex_collect_priv_t& ec, ex e, const exvector& vars)
90 {
91         e = e.expand();
92         if (e.is_zero()) {
93                 ec.clear();
94                 return;
95         }
96
97         if (!is_a<add>(e)) {
98                 collect_term(ec, e, vars);
99                 return;
100         }
101
102         for (const_iterator i = e.begin(); i != e.end(); ++i)
103                 collect_term(ec, *i, vars);
104
105         wipe_out_zeros(ec);
106 }
107
108 static void
109 collect_term(ex_collect_priv_t& ec, const ex& e, const exvector& vars)
110 {
111         if (e.is_zero())
112                 return;
113         static const ex ex1(1);
114         exp_vector_t key(vars.size());
115         ex pre_coeff = e;
116         for (std::size_t i = 0; i < vars.size(); ++i) {
117                 const int var_i_pow = pre_coeff.degree(vars[i]);
118                 key[i] = var_i_pow;
119                 pre_coeff = pre_coeff.coeff(vars[i], var_i_pow);
120         }
121         ex_collect_priv_t::iterator i = ec.find(key);
122         if (i != ec.end())
123                 i->second += pre_coeff;
124         else
125                 ec.insert(ex_collect_priv_t::value_type(key, pre_coeff));
126 }
127
128 static void wipe_out_zeros(ex_collect_priv_t& m)
129 {
130         ex_collect_priv_t::iterator i = m.begin();
131         while (i != m.end()) {
132                 // be careful to not invalide iterator, use post-increment
133                 // for that, see e.g.
134                 // http://coding.derkeiler.com/Archive/C_CPP/comp.lang.cpp/2004-02/0502.html
135                 if (i->second.is_zero())
136                         m.erase(i++);
137                 else
138                         ++i;
139         }
140 }
141
142 GiNaC::ex
143 ex_collect_to_ex(const ex_collect_t& ec, const GiNaC::exvector& vars)
144 {
145         exvector ev;
146         ev.reserve(ec.size());
147         for (std::size_t i = 0; i < ec.size(); ++i) {
148                 exvector tv;
149                 tv.reserve(vars.size() + 1);
150                 for (std::size_t j = 0; j < vars.size(); ++j) {
151                         if (ec[i].first[j] != 0)
152                                 tv.push_back(power(vars[j], ec[i].first[j]));
153                 }
154                 tv.push_back(ec[i].second);
155                 ex tmp = (new mul(tv))->setflag(status_flags::dynallocated);
156                 ev.push_back(tmp);
157         }
158         ex ret = (new add(ev))->setflag(status_flags::dynallocated);
159         return ret;
160 }
161
162 ex lcoeff_wrt(ex e, const exvector& x)
163 {
164         static const ex ex0(0);
165         e = e.expand();
166         if (e.is_zero())
167                 return ex0;
168
169         ex_collect_t ec;
170         collect_vargs(ec, e, x);
171         return ec.rbegin()->second;
172 }
173
174 cln::cl_I integer_lcoeff(const ex& e, const exvector& vars)
175 {
176         ex_collect_t ec;
177         collect_vargs(ec, e, vars);
178         if (ec.size() == 0)
179                 return cln::cl_I(0);
180         ex lc = ec.rbegin()->second;
181         bug_on(!is_a<numeric>(lc), "leading coefficient is not an integer");
182         bug_on(!lc.info(info_flags::integer),
183                 "leading coefficient is not an integer");
184
185         return to_cl_I(lc);
186 }
187
188 } // namespace GiNaC