www.ginac.de Git - ginac.git/blob - ginac/polynomial/collect_vargs.cpp
Implement modular multivariate GCD (based on the Chinese remaindering algorithm).
[ginac.git] / ginac / polynomial / collect_vargs.cpp
1 #include <iterator>
2 #include <map>
3 #include <algorithm>
4 #include <stdexcept>
5 #include <string>
6 #include <ginac/ginac.h>
7 #include "collect_vargs.h"
8 #include <cln/integer.h>
9 #include "smod_helpers.h"
10 #include "debug.hpp"
11
12 namespace GiNaC
13 {
14
15 typedef std::map<exp_vector_t, ex> ex_collect_priv_t;
16
17 static void 
18 collect_vargs(ex_collect_priv_t& ec, ex e, const exvector& vars);
19 static void
20 collect_term(ex_collect_priv_t& ec, const ex& e, const exvector& vars);
21 static void wipe_out_zeros(ex_collect_priv_t& ec);
22
/**
 * Strict-weak-ordering comparator for (exponent vector, coefficient) terms.
 *
 * Exponent vectors are compared lexicographically starting from the LAST
 * component, so the exponent of the last variable is the most significant.
 * Terms with identical exponent vectors are ordered by their coefficients
 * using @a CoeffCMP.
 *
 * @tparam T         pair-like term type: T::first is a reversible sequence
 *                   of exponents, T::second is the coefficient.
 * @tparam CoeffCMP  binary predicate (strict weak ordering) on coefficients;
 *                   its operator() must be const.
 */
template<typename T, typename CoeffCMP>
struct compare_terms
{
        // Stored by value: comparator functors are cheap to copy, and a
        // reference member would dangle once the temporary it was built
        // from goes out of scope.
        CoeffCMP coeff_cmp;

        explicit compare_terms(const CoeffCMP& coeff_cmp_) : coeff_cmp(coeff_cmp_)
        { }

        inline bool operator()(const T& t1, const T& t2) const
        {
                const bool exponent_is_less =
                        std::lexicographical_compare(t1.first.rbegin(),
                                                     t1.first.rend(),
                                                     t2.first.rbegin(),
                                                     t2.first.rend());
                if (exponent_is_less)
                        return true;

                // Same monomial: fall back to comparing the coefficients
                // of the two terms (t1 against t2).
                return (t1.first == t2.first) &&
                        coeff_cmp(t1.second, t2.second);
        }
};
45
46 template<typename T, typename CoeffCMP>
47 static struct compare_terms<T, CoeffCMP>
48 make_compare_terms(const T& dummy, const CoeffCMP& coeff_cmp)
49 {
50         return compare_terms<T, CoeffCMP>(coeff_cmp);
51 }
52
53 void collect_vargs(ex_collect_t& ec, const ex& e, const exvector& vars)
54 {
55         ex_collect_priv_t ecp;
56         collect_vargs(ecp, e, vars);
57         ec.reserve(ecp.size());
58         std::copy(ecp.begin(), ecp.end(), std::back_inserter(ec));
59         std::sort(ec.begin(), ec.end(),
60                   make_compare_terms(*ec.begin(), ex_is_less()));
61 }
62
63 static void 
64 collect_vargs(ex_collect_priv_t& ec, ex e, const exvector& vars)
65 {
66         e = e.expand();
67         if (e.is_zero()) {
68                 ec.clear();
69                 return;
70         }
71
72         if (!is_a<add>(e)) {
73                 collect_term(ec, e, vars);
74                 return;
75         }
76
77         for (const_iterator i = e.begin(); i != e.end(); ++i)
78                 collect_term(ec, *i, vars);
79
80         wipe_out_zeros(ec);
81 }
82
83 static void
84 collect_term(ex_collect_priv_t& ec, const ex& e, const exvector& vars)
85 {
86         if (e.is_zero())
87                 return;
88         static const ex ex1(1);
89         exp_vector_t key(vars.size());
90         ex pre_coeff = e;
91         for (std::size_t i = 0; i < vars.size(); ++i) {
92                 const int var_i_pow = pre_coeff.degree(vars[i]);
93                 key[i] = var_i_pow;
94                 pre_coeff = pre_coeff.coeff(vars[i], var_i_pow);
95         }
96         ex_collect_priv_t::iterator i = ec.find(key);
97         if (i != ec.end())
98                 i->second += pre_coeff;
99         else
100                 ec.insert(ex_collect_priv_t::value_type(key, pre_coeff));
101 }
102
103 static void wipe_out_zeros(ex_collect_priv_t& m)
104 {
105         ex_collect_priv_t::iterator i = m.begin();
106         while (i != m.end()) {
107                 // be careful to not invalide iterator, use post-increment
108                 // for that, see e.g.
109                 // http://coding.derkeiler.com/Archive/C_CPP/comp.lang.cpp/2004-02/0502.html
110                 if (i->second.is_zero())
111                         m.erase(i++);
112                 else
113                         ++i;
114         }
115 }
116
117 GiNaC::ex
118 ex_collect_to_ex(const ex_collect_t& ec, const GiNaC::exvector& vars)
119 {
120         exvector ev;
121         ev.reserve(ec.size());
122         for (std::size_t i = 0; i < ec.size(); ++i) {
123                 exvector tv;
124                 tv.reserve(vars.size() + 1);
125                 for (std::size_t j = 0; j < vars.size(); ++j) {
126                         if (ec[i].first[j] != 0)
127                                 tv.push_back(power(vars[j], ec[i].first[j]));
128                 }
129                 tv.push_back(ec[i].second);
130                 ex tmp = (new mul(tv))->setflag(status_flags::dynallocated);
131                 ev.push_back(tmp);
132         }
133         ex ret = (new add(ev))->setflag(status_flags::dynallocated);
134         return ret;
135 }
136
137 ex lcoeff_wrt(ex e, const exvector& x)
138 {
139         static const ex ex0(0);
140         e = e.expand();
141         if (e.is_zero())
142                 return ex0;
143
144         ex_collect_t ec;
145         collect_vargs(ec, e, x);
146         return ec.rbegin()->second;
147 }
148
149 cln::cl_I integer_lcoeff(const ex& e, const exvector& vars)
150 {
151         ex_collect_t ec;
152         collect_vargs(ec, e, vars);
153         if (ec.size() == 0)
154                 return cln::cl_I(0);
155         ex lc = ec.rbegin()->second;
156         bug_on(!is_a<numeric>(lc), "leading coefficient is not an integer");
157         bug_on(!lc.info(info_flags::integer),
158                 "leading coefficient is not an integer");
159
160         return to_cl_I(lc);
161 }
162
163 } // namespace GiNaC
164