]> www.ginac.de Git - ginac.git/blob - ginac/polynomial/collect_vargs.cpp
Use C++11 range-based for loops and auto, where possible.
[ginac.git] / ginac / polynomial / collect_vargs.cpp
1 /** @file collect_vargs.cpp
2  *
3  *  Utility functions. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2015 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
21  */
22
23 #include "add.h"
24 #include "mul.h"
25 #include "operators.h"
26 #include "power.h"
27 #include "collect_vargs.h"
28 #include "smod_helpers.h"
29 #include "debug.h"
30
31 #include <algorithm>
32 #include <cln/integer.h>
33 #include <iterator>
34 #include <map>
35 #include <stdexcept>
36 #include <string>
37
38 namespace GiNaC {
39
40 typedef std::map<exp_vector_t, ex> ex_collect_priv_t;
41
42 static void 
43 collect_vargs(ex_collect_priv_t& ec, ex e, const exvector& vars);
44 static void
45 collect_term(ex_collect_priv_t& ec, const ex& e, const exvector& vars);
46 static void wipe_out_zeros(ex_collect_priv_t& ec);
47
/**
 * Strict weak ordering on (exponent vector, coefficient) terms.
 *
 * Exponent vectors are compared lexicographically starting from their LAST
 * entry, so the last variable is the most significant one. Terms whose
 * exponent vectors are equal are ordered by their coefficients via coeff_cmp.
 *
 * T must provide `.first` (a container with reverse iterators and
 * operator==) and `.second` (a value accepted by CoeffCMP).
 */
template<typename T, typename CoeffCMP>
struct compare_terms
{
	// Held by value: std::sort copies the comparator freely, and a stored
	// reference would dangle if it was bound to a temporary CoeffCMP.
	CoeffCMP coeff_cmp;
	explicit compare_terms(const CoeffCMP& coeff_cmp_) : coeff_cmp(coeff_cmp_)
	{ }
	inline bool operator()(const T& t1, const T& t2) const
	{
		const bool exponent_is_less =
			std::lexicographical_compare(t1.first.rbegin(),
			                             t1.first.rend(),
			                             t2.first.rbegin(),
			                             t2.first.rend());
		if (exponent_is_less)
			return true;

		// Same monomial: tie-break by comparing t1's coefficient with t2's.
		return (t1.first == t2.first) &&
		       coeff_cmp(t1.second, t2.second);
	}
};
70
71 template<typename T, typename CoeffCMP>
72 static compare_terms<T, CoeffCMP>
73 make_compare_terms(const T& dummy, const CoeffCMP& coeff_cmp)
74 {
75         return compare_terms<T, CoeffCMP>(coeff_cmp);
76 }
77
78 void collect_vargs(ex_collect_t& ec, const ex& e, const exvector& vars)
79 {
80         ex_collect_priv_t ecp;
81         collect_vargs(ecp, e, vars);
82         ec.reserve(ecp.size());
83         std::copy(ecp.begin(), ecp.end(), std::back_inserter(ec));
84         std::sort(ec.begin(), ec.end(),
85                   make_compare_terms(*ec.begin(), ex_is_less()));
86 }
87
88 static void 
89 collect_vargs(ex_collect_priv_t& ec, ex e, const exvector& vars)
90 {
91         e = e.expand();
92         if (e.is_zero()) {
93                 ec.clear();
94                 return;
95         }
96
97         if (!is_a<add>(e)) {
98                 collect_term(ec, e, vars);
99                 return;
100         }
101
102         for (const_iterator i = e.begin(); i != e.end(); ++i)
103                 collect_term(ec, *i, vars);
104
105         wipe_out_zeros(ec);
106 }
107
108 static void
109 collect_term(ex_collect_priv_t& ec, const ex& e, const exvector& vars)
110 {
111         if (e.is_zero())
112                 return;
113         static const ex ex1(1);
114         exp_vector_t key(vars.size());
115         ex pre_coeff = e;
116         for (std::size_t i = 0; i < vars.size(); ++i) {
117                 const int var_i_pow = pre_coeff.degree(vars[i]);
118                 key[i] = var_i_pow;
119                 pre_coeff = pre_coeff.coeff(vars[i], var_i_pow);
120         }
121         ex_collect_priv_t::iterator i = ec.find(key);
122         if (i != ec.end())
123                 i->second += pre_coeff;
124         else
125                 ec.insert(ex_collect_priv_t::value_type(key, pre_coeff));
126 }
127
128 static void wipe_out_zeros(ex_collect_priv_t& m)
129 {
130         ex_collect_priv_t::iterator i = m.begin();
131         while (i != m.end()) {
132                 // be careful to not invalide iterator, use post-increment
133                 // for that, see e.g.
134                 // http://coding.derkeiler.com/Archive/C_CPP/comp.lang.cpp/2004-02/0502.html
135                 if (i->second.is_zero())
136                         m.erase(i++);
137                 else
138                         ++i;
139         }
140 }
141
142 ex
143 ex_collect_to_ex(const ex_collect_t& ec, const exvector& vars)
144 {
145         exvector ev;
146         ev.reserve(ec.size());
147         for (std::size_t i = 0; i < ec.size(); ++i) {
148                 exvector tv;
149                 tv.reserve(vars.size() + 1);
150                 for (std::size_t j = 0; j < vars.size(); ++j) {
151                         const exp_vector_t& exp_vector(ec[i].first);
152
153                         bug_on(exp_vector.size() != vars.size(),
154                                 "expected " << vars.size() << " variables, "
155                                 "expression has " << exp_vector.size() << " instead");
156
157                         if (exp_vector[j] != 0)
158                                 tv.push_back(power(vars[j], exp_vector[j]));
159                 }
160                 tv.push_back(ec[i].second);
161                 ex tmp = (new mul(tv))->setflag(status_flags::dynallocated);
162                 ev.push_back(tmp);
163         }
164         ex ret = (new add(ev))->setflag(status_flags::dynallocated);
165         return ret;
166 }
167
168 ex lcoeff_wrt(ex e, const exvector& x)
169 {
170         static const ex ex0(0);
171         e = e.expand();
172         if (e.is_zero())
173                 return ex0;
174
175         ex_collect_t ec;
176         collect_vargs(ec, e, x);
177         return ec.rbegin()->second;
178 }
179
180 exp_vector_t degree_vector(ex e, const exvector& vars)
181 {
182         e = e.expand();
183         exp_vector_t dvec(vars.size());
184         for (std::size_t i = vars.size(); i-- != 0; ) {
185                 const int deg_i = e.degree(vars[i]);
186                 e = e.coeff(vars[i], deg_i);
187                 dvec[i] = deg_i;
188         }
189         return dvec;
190 }
191
192 cln::cl_I integer_lcoeff(const ex& e, const exvector& vars)
193 {
194         ex_collect_t ec;
195         collect_vargs(ec, e, vars);
196         if (ec.size() == 0)
197                 return cln::cl_I(0);
198         ex lc = ec.rbegin()->second;
199         bug_on(!is_a<numeric>(lc), "leading coefficient is not an integer");
200         bug_on(!lc.info(info_flags::integer),
201                 "leading coefficient is not an integer");
202
203         return to_cl_I(lc);
204 }
205
206 } // namespace GiNaC