]> www.ginac.de Git - ginac.git/blob - ginac/polynomial/collect_vargs.cpp
aa9ae34aeaba090ec2c2d02689bba8c2697b0e5a
[ginac.git] / ginac / polynomial / collect_vargs.cpp
1 #include <iterator>
2 #include <map>
3 #include <algorithm>
4 #include <stdexcept>
5 #include <string>
6 #include "add.h"
7 #include "mul.h"
8 #include "operators.h"
9 #include "power.h"
10 #include "collect_vargs.h"
11 #include <cln/integer.h>
12 #include "smod_helpers.h"
13 #include "debug.hpp"
14
15 namespace GiNaC
16 {
17
18 typedef std::map<exp_vector_t, ex> ex_collect_priv_t;
19
20 static void 
21 collect_vargs(ex_collect_priv_t& ec, ex e, const exvector& vars);
22 static void
23 collect_term(ex_collect_priv_t& ec, const ex& e, const exvector& vars);
24 static void wipe_out_zeros(ex_collect_priv_t& ec);
25
/**
 * Strict weak ordering on (exponent vector, coefficient) pairs.
 *
 * Terms are ordered by their exponent vectors, compared lexicographically
 * starting from the LAST entry (the last variable is the most significant
 * one).  Terms with identical exponent vectors are ordered by their
 * coefficients using @a coeff_cmp.
 *
 * @tparam T         pair-like type; `first` is a reversible, equality-
 *                   comparable sequence of exponents, `second` is the
 *                   coefficient
 * @tparam CoeffCMP  strict weak ordering on the coefficient type; its
 *                   operator() must be callable on a const object
 */
template<typename T, typename CoeffCMP>
struct compare_terms
{
	// Held by value (coefficient comparators are typically small,
	// stateless function objects) so that a copy of this comparator can
	// never refer to a temporary that has already been destroyed.
	CoeffCMP coeff_cmp;

	explicit compare_terms(const CoeffCMP& coeff_cmp_) : coeff_cmp(coeff_cmp_)
	{ }

	inline bool operator()(const T& t1, const T& t2) const
	{
		// Primary key: exponent vectors, most significant variable last.
		if (std::lexicographical_compare(t1.first.rbegin(),
		                                 t1.first.rend(),
		                                 t2.first.rbegin(),
		                                 t2.first.rend()))
			return true;

		// Secondary key: when the exponents coincide, compare the
		// coefficients of the two terms against each other.
		return (t1.first == t2.first) && coeff_cmp(t1.second, t2.second);
	}
};
48
49 template<typename T, typename CoeffCMP>
50 static struct compare_terms<T, CoeffCMP>
51 make_compare_terms(const T& dummy, const CoeffCMP& coeff_cmp)
52 {
53         return compare_terms<T, CoeffCMP>(coeff_cmp);
54 }
55
56 void collect_vargs(ex_collect_t& ec, const ex& e, const exvector& vars)
57 {
58         ex_collect_priv_t ecp;
59         collect_vargs(ecp, e, vars);
60         ec.reserve(ecp.size());
61         std::copy(ecp.begin(), ecp.end(), std::back_inserter(ec));
62         std::sort(ec.begin(), ec.end(),
63                   make_compare_terms(*ec.begin(), ex_is_less()));
64 }
65
66 static void 
67 collect_vargs(ex_collect_priv_t& ec, ex e, const exvector& vars)
68 {
69         e = e.expand();
70         if (e.is_zero()) {
71                 ec.clear();
72                 return;
73         }
74
75         if (!is_a<add>(e)) {
76                 collect_term(ec, e, vars);
77                 return;
78         }
79
80         for (const_iterator i = e.begin(); i != e.end(); ++i)
81                 collect_term(ec, *i, vars);
82
83         wipe_out_zeros(ec);
84 }
85
86 static void
87 collect_term(ex_collect_priv_t& ec, const ex& e, const exvector& vars)
88 {
89         if (e.is_zero())
90                 return;
91         static const ex ex1(1);
92         exp_vector_t key(vars.size());
93         ex pre_coeff = e;
94         for (std::size_t i = 0; i < vars.size(); ++i) {
95                 const int var_i_pow = pre_coeff.degree(vars[i]);
96                 key[i] = var_i_pow;
97                 pre_coeff = pre_coeff.coeff(vars[i], var_i_pow);
98         }
99         ex_collect_priv_t::iterator i = ec.find(key);
100         if (i != ec.end())
101                 i->second += pre_coeff;
102         else
103                 ec.insert(ex_collect_priv_t::value_type(key, pre_coeff));
104 }
105
106 static void wipe_out_zeros(ex_collect_priv_t& m)
107 {
108         ex_collect_priv_t::iterator i = m.begin();
109         while (i != m.end()) {
110                 // be careful to not invalide iterator, use post-increment
111                 // for that, see e.g.
112                 // http://coding.derkeiler.com/Archive/C_CPP/comp.lang.cpp/2004-02/0502.html
113                 if (i->second.is_zero())
114                         m.erase(i++);
115                 else
116                         ++i;
117         }
118 }
119
120 GiNaC::ex
121 ex_collect_to_ex(const ex_collect_t& ec, const GiNaC::exvector& vars)
122 {
123         exvector ev;
124         ev.reserve(ec.size());
125         for (std::size_t i = 0; i < ec.size(); ++i) {
126                 exvector tv;
127                 tv.reserve(vars.size() + 1);
128                 for (std::size_t j = 0; j < vars.size(); ++j) {
129                         if (ec[i].first[j] != 0)
130                                 tv.push_back(power(vars[j], ec[i].first[j]));
131                 }
132                 tv.push_back(ec[i].second);
133                 ex tmp = (new mul(tv))->setflag(status_flags::dynallocated);
134                 ev.push_back(tmp);
135         }
136         ex ret = (new add(ev))->setflag(status_flags::dynallocated);
137         return ret;
138 }
139
140 ex lcoeff_wrt(ex e, const exvector& x)
141 {
142         static const ex ex0(0);
143         e = e.expand();
144         if (e.is_zero())
145                 return ex0;
146
147         ex_collect_t ec;
148         collect_vargs(ec, e, x);
149         return ec.rbegin()->second;
150 }
151
152 cln::cl_I integer_lcoeff(const ex& e, const exvector& vars)
153 {
154         ex_collect_t ec;
155         collect_vargs(ec, e, vars);
156         if (ec.size() == 0)
157                 return cln::cl_I(0);
158         ex lc = ec.rbegin()->second;
159         bug_on(!is_a<numeric>(lc), "leading coefficient is not an integer");
160         bug_on(!lc.info(info_flags::integer),
161                 "leading coefficient is not an integer");
162
163         return to_cl_I(lc);
164 }
165
166 } // namespace GiNaC
167