Added integral class.
[ginac.git] / ginac / integral.cpp
1 /** @file integral.cpp
2  *
3  *  Implementation of GiNaC's symbolic  integral. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2004 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
21  */
22
23 #include "integral.h"
24 #include "numeric.h"
25 #include "symbol.h"
26 #include "add.h"
27 #include "mul.h"
28 #include "power.h"
29 #include "inifcns.h"
30 #include "wildcard.h"
31 #include "archive.h"
32 #include "registrar.h"
33 #include "utils.h"
34 #include "operators.h"
35 #include "relational.h"
36
37 using namespace std;
38
39 namespace GiNaC {
40
41 GINAC_IMPLEMENT_REGISTERED_CLASS_OPT(integral, basic,
42   print_func<print_dflt>(&integral::do_print).
43   print_func<print_latex>(&integral::do_print_latex))
44
45
46 //////////
47 // default constructor
48 //////////
49
50 integral::integral()
51                 : inherited(TINFO_integral),
52                 x((new symbol())->setflag(status_flags::dynallocated))
53 {}
54
55 //////////
56 // other constructors
57 //////////
58
59 // public
60
61 integral::integral(const ex & x_, const ex & a_, const ex & b_, const ex & f_)
62                 : inherited(TINFO_integral), x(x_), a(a_), b(b_), f(f_)
63 {
64         if (!is_a<symbol>(x)) {
65                 throw(std::invalid_argument("first argument of integral must be of type symbol"));
66         }
67 }
68
69 //////////
70 // archiving
71 //////////
72
73 integral::integral(const archive_node & n, lst & sym_lst) : inherited(n, sym_lst)
74 {
75         n.find_ex("x", x, sym_lst);
76         n.find_ex("a", a, sym_lst);
77         n.find_ex("b", b, sym_lst);
78         n.find_ex("f", f, sym_lst);
79 }
80
81 void integral::archive(archive_node & n) const
82 {
83         inherited::archive(n);
84         n.add_ex("x", x);
85         n.add_ex("a", a);
86         n.add_ex("b", b);
87         n.add_ex("f", f);
88 }
89
90 DEFAULT_UNARCHIVE(integral)
91
92 //////////
93 // functions overriding virtual functions from base classes
94 //////////
95
96 void integral::do_print(const print_context & c, unsigned level) const
97 {
98         c.s << "integral(";
99         x.print(c);
100         c.s << ",";
101         a.print(c);
102         c.s << ",";
103         b.print(c);
104         c.s << ",";
105         f.print(c);
106         c.s << ")";
107 }
108
109 void integral::do_print_latex(const print_latex & c, unsigned level) const
110 {
111         string varname = ex_to<symbol>(x).get_name();
112         if (level > precedence())
113                 c.s << "\\left(";
114         c.s << "\\int_{";
115         a.print(c);
116         c.s << "}^{";
117         b.print(c);
118         c.s << "} d";
119         if (varname.size() > 1)
120                 c.s << "\\," << varname << "\\:";
121         else
122                 c.s << varname << "\\,";
123         f.print(c,precedence());
124         if (level > precedence())
125                 c.s << "\\right)";
126 }
127
128 int integral::compare_same_type(const basic & other) const
129 {
130         GINAC_ASSERT(is_exactly_a<integral>(other));
131         const integral &o = static_cast<const integral &>(other);
132
133         int cmpval = x.compare(o.x);
134         if (cmpval)
135                 return cmpval;
136         cmpval = a.compare(o.a);
137         if (cmpval)
138                 return cmpval;
139         cmpval = b.compare(o.b);
140         if (cmpval)
141                 return cmpval;
142         return f.compare(o.f);
143 }
144
145 ex integral::eval(int level) const
146 {
147         if ((level==1) && (flags & status_flags::evaluated))
148                 return *this;
149         if (level == -max_recursion_level)
150                 throw(std::runtime_error("max recursion level reached"));
151
152         ex eintvar = (level==1) ? x : x.eval(level-1);
153         ex ea      = (level==1) ? a : a.eval(level-1);
154         ex eb      = (level==1) ? b : b.eval(level-1);
155         ex ef      = (level==1) ? f : f.eval(level-1);
156
157         if (!ef.has(eintvar) && !haswild(ef))
158                 return eb*ef-ea*ef;
159
160         if (ea==eb)
161                 return _ex0;
162
163         if (are_ex_trivially_equal(eintvar,x) && are_ex_trivially_equal(ea,a)
164                         && are_ex_trivially_equal(eb,b) && are_ex_trivially_equal(ef,f))
165                 return this->hold();
166         return (new integral(eintvar, ea, eb, ef))
167                 ->setflag(status_flags::dynallocated | status_flags::evaluated);
168 }
169
170 ex integral::evalf(int level) const
171 {
172         ex ea;
173         ex eb;
174         ex ef;
175
176         if (level==1) {
177                 ea = a;
178                 eb = b;
179                 ef = f;
180         } else if (level == -max_recursion_level) {
181                 throw(runtime_error("max recursion level reached"));
182         } else {
183                 ea = a.evalf(level-1);
184                 eb = b.evalf(level-1);
185                 ef = f.evalf(level-1);
186         }
187
188         // 12.34 is just an arbitrary number used to check whether a number
189         // results after subsituting a number for the integration variable.
190         if (is_exactly_a<numeric>(ea) && is_exactly_a<numeric>(eb) 
191                         && is_exactly_a<numeric>(ef.subs(x==12.34).evalf())) {
192                 try {
193                         return adaptivesimpson(x, ea, eb, ef);
194                 } catch (runtime_error &rte) {}
195         }
196
197         if (are_ex_trivially_equal(a, ea) && are_ex_trivially_equal(b, eb)
198                                 && are_ex_trivially_equal(f, ef))
199                         return *this;
200                 else
201                         return (new integral(x, ea, eb, ef))
202                                 ->setflag(status_flags::dynallocated);
203 }
204
205 int integral::max_integration_level = 15;
206 ex integral::relative_integration_error = power(10,-8).evalf();
207
208 ex subsvalue(const ex & var, const ex & value, const ex & fun)
209 {
210         ex result = fun.subs(var==value).evalf();
211         if (is_a<numeric>(result))
212                 return result;
213         throw logic_error("integrant does not evaluate to numeric");
214 }
215
216 /** Numeric integration routine based upon the "Adaptive Quadrature" one
217   * in "Numerical Analysis" by Burden and Faires. Parameters are integration
218   * variable, left boundary, right boundary, function to be integrated and
219   * the relative integration error. The function should evalf into a number
220   * after substituting the integration variable by a number. Another thing
221   * to note is that this implementation is no good at integrating functions
222   * with discontinuities. */
223 ex adaptivesimpson(const ex & x, const ex & a, const ex & b, const ex & f, const ex & error)
224 {
225         // use lookup table to be potentially much faster.
226         static exmap lookup;
227         static symbol ivar("ivar");
228         ex lookupex = integral(ivar,a,b,f.subs(x==ivar));
229         exmap::iterator emi = lookup.find(lookupex);
230         if (emi!=lookup.end())
231                 return emi->second;
232
233         ex app = 0;
234         int i = 1;
235         exvector avec(integral::max_integration_level+1);
236         exvector hvec(integral::max_integration_level+1);
237         exvector favec(integral::max_integration_level+1);
238         exvector fbvec(integral::max_integration_level+1);
239         exvector fcvec(integral::max_integration_level+1);
240         exvector svec(integral::max_integration_level+1);
241         exvector errorvec(integral::max_integration_level+1);
242         vector<int> lvec(integral::max_integration_level+1);
243
244         avec[i] = a;
245         hvec[i] = (b-a)/2;
246         favec[i] = subsvalue(x, a, f);
247         fcvec[i] = subsvalue(x, a+hvec[i], f);
248         fbvec[i] = subsvalue(x, b, f);
249         svec[i] = hvec[i]*(favec[i]+4*fcvec[i]+fbvec[i])/3;
250         lvec[i] = 1;
251         errorvec[i] = integral::relative_integration_error*svec[i];
252
253         while (i>0) {
254                 ex fd = subsvalue(x, avec[i]+hvec[i]/2, f);
255                 ex fe = subsvalue(x, avec[i]+3*hvec[i]/2, f);
256                 ex s1 = hvec[i]*(favec[i]+4*fd+fcvec[i])/6;
257                 ex s2 = hvec[i]*(fcvec[i]+4*fe+fbvec[i])/6;
258                 ex nu1 = avec[i];
259                 ex nu2 = favec[i];
260                 ex nu3 = fcvec[i];
261                 ex nu4 = fbvec[i];
262                 ex nu5 = hvec[i];
263                 // hopefully prevents a crash if the function is zero sometimes.
264                 ex nu6 = max(errorvec[i], (s1+s2)*integral::relative_integration_error);
265                 ex nu7 = svec[i];
266                 int nu8 = lvec[i];
267                 --i;
268                 if (abs(ex_to<numeric>(s1+s2-nu7)) <= nu6)
269                         app+=(s1+s2);
270                 else {
271                         if (nu8>=integral::max_integration_level)
272                                 throw runtime_error("max integration level reached");
273                         ++i;
274                         avec[i] = nu1+nu5;
275                         favec[i] = nu3;
276                         fcvec[i] = fe;
277                         fbvec[i] = nu4;
278                         hvec[i] = nu5/2;
279                         errorvec[i]=nu6/2;
280                         svec[i] = s2;
281                         lvec[i] = nu8+1;
282                         ++i;
283                         avec[i] = nu1;
284                         favec[i] = nu2;
285                         fcvec[i] = fd;
286                         fbvec[i] = nu3;
287                         hvec[i] = hvec[i-1];
288                         errorvec[i]=errorvec[i-1];
289                         svec[i] = s1;
290                         lvec[i] = lvec[i-1];
291                 }
292         }
293
294         lookup[lookupex]=app;
295         return app;
296 }
297
298 int integral::degree(const ex & s) const
299 {
300         return ((b-a)*f).degree(s);
301 }
302
303 int integral::ldegree(const ex & s) const
304 {
305         return ((b-a)*f).ldegree(s);
306 }
307
308 ex integral::eval_ncmul(const exvector & v) const
309 {
310         return f.eval_ncmul(v);
311 }
312
313 size_t integral::nops() const
314 {
315         return 4;
316 }
317
318 ex integral::op(size_t i) const
319 {
320         GINAC_ASSERT(i<4);
321
322         switch(i) {
323                 case(0):
324                         return x;
325                 case(1):
326                         return a;
327                 case(2):
328                         return b;
329                 case(3):
330                         return f;
331         }
332 }
333
334 ex & integral::let_op(size_t i)
335 {
336         ensure_if_modifiable();
337         switch(i) {
338                 case(0):
339                         return x;
340                 case(1):
341                         return a;
342                 case(2):
343                         return b;
344                 case(3):
345                         return f;
346         }
347 }
348
349 ex integral::expand(unsigned options) const
350 {
351         if (options==0 && (flags & status_flags::expanded))
352                 return *this;
353
354         ex newa = a.expand(options);
355         ex newb = b.expand(options);
356         ex newf = f.expand(options);
357
358         if (is_a<add>(newf)) {
359                 exvector v;
360                 v.reserve(newf.nops());
361                 for (size_t i=0; i<newf.nops(); ++i)
362                         v.push_back(integral(x, newa, newb, newf.op(i)).expand(options));
363                 return ex(add(v)).expand(options);
364         }
365
366         if (is_a<mul>(newf)) {
367                 ex prefactor = 1;
368                 ex rest = 1;
369                 for (size_t i=0; i<newf.nops(); ++i)
370                         if (newf.op(i).has(x))
371                                 rest *= newf.op(i);
372                         else
373                                 prefactor *= newf.op(i);
374                 if (prefactor != 1)
375                         return (prefactor*integral(x, newa, newb, rest)).expand(options);
376         }
377
378         if (are_ex_trivially_equal(a, newa) && are_ex_trivially_equal(b, newb)
379                         && are_ex_trivially_equal(f, newf)) {
380                 if (options==0)
381                         this->setflag(status_flags::expanded);
382                 return *this;
383         }
384
385         const basic & newint = (new integral(x, newa, newb, newf))
386                 ->setflag(status_flags::dynallocated);
387         if (options == 0)
388                 newint.setflag(status_flags::expanded);
389         return newint;
390 }
391
392 ex integral::derivative(const symbol & s) const
393 {       if (s==x)
394                 throw(logic_error("differentiation with respect to dummy variable"));
395         return b.diff(s)*f.subs(x==b)-a.diff(s)*f.subs(x==a)+integral(x, a, b, f.diff(s));
396 }
397
398 unsigned integral::return_type() const
399 {
400         return f.return_type();
401 }
402
403 unsigned integral::return_type_tinfo() const
404 {
405         return f.return_type_tinfo();
406 }
407
408 ex integral::conjugate() const
409 {
410         ex conja = a.conjugate();
411         ex conjb = b.conjugate();
412         ex conjf = f.conjugate().subs(x.conjugate()==x);
413
414         if (are_ex_trivially_equal(a, conja) && are_ex_trivially_equal(b, conjb)
415                         && are_ex_trivially_equal(f, conjf))
416                 return *this;
417
418         return (new integral(x, conja, conjb, conjf))
419                 ->setflag(status_flags::dynallocated);
420 }
421
422 ex integral::eval_integ() const
423 {
424         if (!(flags & status_flags::expanded))
425                 return this->expand().eval_integ();
426         
427         if (f==x)
428                 return b*b/2-a*a/2;
429         if (is_a<power>(f) && f.op(0)==x) {
430                 if (f.op(1)==-1)
431                         return log(b/a);
432                 if (!f.op(1).has(x)) {
433                         ex primit = power(x,f.op(1)+1)/(f.op(1)+1);
434                         return primit.subs(x==b)-primit.subs(x==a);
435                 }
436         }
437
438         return *this;
439 }
440
441 } // namespace GiNaC