Make it possible to print exmaps and exvectors.
[ginac.git] / ginac / operators.cpp
1 /** @file operators.cpp
2  *
3  *  Implementation of GiNaC's overloaded operators. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2005 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
21  */
22
23 #include <iostream>
24 #include <iomanip>
25
26 #include "operators.h"
27 #include "numeric.h"
28 #include "add.h"
29 #include "mul.h"
30 #include "power.h"
31 #include "ncmul.h"
32 #include "relational.h"
33 #include "print.h"
34 #include "utils.h"
35
36 namespace GiNaC {
37
38 /** Used internally by operator+() to add two ex objects together. */
39 static inline const ex exadd(const ex & lh, const ex & rh)
40 {
41         return (new add(lh,rh))->setflag(status_flags::dynallocated);
42 }
43
44 /** Used internally by operator*() to multiply two ex objects together. */
45 static inline const ex exmul(const ex & lh, const ex & rh)
46 {
47         // Check if we are constructing a mul object or a ncmul object.  Due to
48         // ncmul::eval()'s rule to pull out commutative elements we need to check
49         // only one of the elements.
50         if (rh.return_type()==return_types::commutative ||
51             lh.return_type()==return_types::commutative) {
52                 return (new mul(lh,rh))->setflag(status_flags::dynallocated);
53         } else {
54                 return (new ncmul(lh,rh))->setflag(status_flags::dynallocated);
55         }
56 }
57
58 /** Used internally by operator-() and friends to change the sign of an argument. */
59 static inline const ex exminus(const ex & lh)
60 {
61         return (new mul(lh,_ex_1))->setflag(status_flags::dynallocated);
62 }
63
64 // binary arithmetic operators ex with ex
65
66 const ex operator+(const ex & lh, const ex & rh)
67 {
68         return exadd(lh, rh);
69 }
70
71 const ex operator-(const ex & lh, const ex & rh)
72 {
73         return exadd(lh, exminus(rh));
74 }
75
76 const ex operator*(const ex & lh, const ex & rh)
77 {
78         return exmul(lh, rh);
79 }
80
81 const ex operator/(const ex & lh, const ex & rh)
82 {
83         return exmul(lh, power(rh,_ex_1));
84 }
85
86
87 // binary arithmetic operators numeric with numeric
88
89 const numeric operator+(const numeric & lh, const numeric & rh)
90 {
91         return lh.add(rh);
92 }
93
94 const numeric operator-(const numeric & lh, const numeric & rh)
95 {
96         return lh.sub(rh);
97 }
98
99 const numeric operator*(const numeric & lh, const numeric & rh)
100 {
101         return lh.mul(rh);
102 }
103
104 const numeric operator/(const numeric & lh, const numeric & rh)
105 {
106         return lh.div(rh);
107 }
108
109
110 // binary arithmetic assignment operators with ex
111
112 ex & operator+=(ex & lh, const ex & rh)
113 {
114         return lh = exadd(lh, rh);
115 }
116
117 ex & operator-=(ex & lh, const ex & rh)
118 {
119         return lh = exadd(lh, exminus(rh));
120 }
121
122 ex & operator*=(ex & lh, const ex & rh)
123 {
124         return lh = exmul(lh, rh);
125 }
126
127 ex & operator/=(ex & lh, const ex & rh)
128 {
129         return lh = exmul(lh, power(rh,_ex_1));
130 }
131
132
133 // binary arithmetic assignment operators with numeric
134
135 numeric & operator+=(numeric & lh, const numeric & rh)
136 {
137         lh = lh.add(rh);
138         return lh;
139 }
140
141 numeric & operator-=(numeric & lh, const numeric & rh)
142 {
143         lh = lh.sub(rh);
144         return lh;
145 }
146
147 numeric & operator*=(numeric & lh, const numeric & rh)
148 {
149         lh = lh.mul(rh);
150         return lh;
151 }
152
153 numeric & operator/=(numeric & lh, const numeric & rh)
154 {
155         lh = lh.div(rh);
156         return lh;
157 }
158
159
160 // unary operators
161
162 const ex operator+(const ex & lh)
163 {
164         return lh;
165 }
166
167 const ex operator-(const ex & lh)
168 {
169         return exminus(lh);
170 }
171
172 const numeric operator+(const numeric & lh)
173 {
174         return lh;
175 }
176
177 const numeric operator-(const numeric & lh)
178 {
179         return _num_1_p->mul(lh);
180 }
181
182
183 // increment / decrement operators
184
185 /** Expression prefix increment.  Adds 1 and returns incremented ex. */
186 ex & operator++(ex & rh)
187 {
188         return rh = exadd(rh, _ex1);
189 }
190
191 /** Expression prefix decrement.  Subtracts 1 and returns decremented ex. */
192 ex & operator--(ex & rh)
193 {
194         return rh = exadd(rh, _ex_1);
195 }
196
197 /** Expression postfix increment.  Returns the ex and leaves the original
198  *  incremented by 1. */
199 const ex operator++(ex & lh, int)
200 {
201         ex tmp(lh);
202         lh = exadd(lh, _ex1);
203         return tmp;
204 }
205
206 /** Expression postfix decrement.  Returns the ex and leaves the original
207  *  decremented by 1. */
208 const ex operator--(ex & lh, int)
209 {
210         ex tmp(lh);
211         lh = exadd(lh, _ex_1);
212         return tmp;
213 }
214
215 /** Numeric prefix increment.  Adds 1 and returns incremented number. */
216 numeric& operator++(numeric & rh)
217 {
218         rh = rh.add(*_num1_p);
219         return rh;
220 }
221
222 /** Numeric prefix decrement.  Subtracts 1 and returns decremented number. */
223 numeric& operator--(numeric & rh)
224 {
225         rh = rh.add(*_num_1_p);
226         return rh;
227 }
228
229 /** Numeric postfix increment.  Returns the number and leaves the original
230  *  incremented by 1. */
231 const numeric operator++(numeric & lh, int)
232 {
233         numeric tmp(lh);
234         lh = lh.add(*_num1_p);
235         return tmp;
236 }
237
238 /** Numeric postfix decrement.  Returns the number and leaves the original
239  *  decremented by 1. */
240 const numeric operator--(numeric & lh, int)
241 {
242         numeric tmp(lh);
243         lh = lh.add(*_num_1_p);
244         return tmp;
245 }
246
247 // binary relational operators ex with ex
248
249 const relational operator==(const ex & lh, const ex & rh)
250 {
251         return relational(lh,rh,relational::equal);
252 }
253
254 const relational operator!=(const ex & lh, const ex & rh)
255 {
256         return relational(lh,rh,relational::not_equal);
257 }
258
259 const relational operator<(const ex & lh, const ex & rh)
260 {
261         return relational(lh,rh,relational::less);
262 }
263
264 const relational operator<=(const ex & lh, const ex & rh)
265 {
266         return relational(lh,rh,relational::less_or_equal);
267 }
268
269 const relational operator>(const ex & lh, const ex & rh)
270 {
271         return relational(lh,rh,relational::greater);
272 }
273
274 const relational operator>=(const ex & lh, const ex & rh)
275 {
276         return relational(lh,rh,relational::greater_or_equal);
277 }
278
279 // input/output stream operators and manipulators
280
281 static int my_ios_index()
282 {
283         static int i = std::ios_base::xalloc();
284         return i;
285 }
286
287 // Stream format gets copied or destroyed
288 static void my_ios_callback(std::ios_base::event ev, std::ios_base & s, int i)
289 {
290         print_context *p = static_cast<print_context *>(s.pword(i));
291         if (ev == std::ios_base::erase_event) {
292                 delete p;
293                 s.pword(i) = 0;
294         } else if (ev == std::ios_base::copyfmt_event && p != 0)
295                 s.pword(i) = p->duplicate();
296 }
297
298 enum {
299         callback_registered = 1
300 };
301
302 // Get print_context associated with stream, may return 0 if no context has
303 // been associated yet
304 static inline print_context *get_print_context(std::ios_base & s)
305 {
306         return static_cast<print_context *>(s.pword(my_ios_index()));
307 }
308
309 // Set print_context associated with stream, retain options
310 static void set_print_context(std::ios_base & s, const print_context & c)
311 {
312         int i = my_ios_index();
313         long flags = s.iword(i);
314         if (!(flags & callback_registered)) {
315                 s.register_callback(my_ios_callback, i);
316                 s.iword(i) = flags | callback_registered;
317         }
318         print_context *p = static_cast<print_context *>(s.pword(i));
319         unsigned options = p ? p->options : c.options;
320         delete p;
321         p = c.duplicate();
322         p->options = options;
323         s.pword(i) = p;
324 }
325
326 // Get options for print_context associated with stream
327 static inline unsigned get_print_options(std::ios_base & s)
328 {
329         print_context *p = get_print_context(s);
330         return p ? p->options : 0;
331 }
332
333 // Set options for print_context associated with stream
334 static void set_print_options(std::ostream & s, unsigned options)
335 {
336         print_context *p = get_print_context(s);
337         if (p == 0)
338                 set_print_context(s, print_dflt(s, options));
339         else
340                 p->options = options;
341 }
342
343 std::ostream & operator<<(std::ostream & os, const ex & e)
344 {
345         print_context *p = get_print_context(os);
346         if (p == 0)
347                 e.print(print_dflt(os));
348         else
349                 e.print(*p);
350         return os;
351 }
352
353 std::ostream & operator<<(std::ostream & os, const exvector & e)
354 {
355         print_context *p = get_print_context(os);
356         exvector::const_iterator i = e.begin();
357         exvector::const_iterator vend = e.end();
358
359         if (i==vend) {
360                 os << "[]";
361                 return os;
362         }
363
364         os << "[";
365         while (true) {
366                 if (p == 0)
367                         i -> print(print_dflt(os));
368                 else
369                         i -> print(*p);
370                 ++i;
371                 if (i==vend)
372                         break;
373                 os << ",";
374         }
375         os << "]";
376
377         return os;
378 }
379
380 std::ostream & operator<<(std::ostream & os, const exmap & e)
381 {
382         print_context *p = get_print_context(os);
383         exmap::const_iterator i = e.begin();
384         exmap::const_iterator mend = e.end();
385
386         if (i==mend) {
387                 os << "{}";
388                 return os;
389         }
390
391         os << "{";
392         while (true) {
393                 if (p == 0)
394                         i->first.print(print_dflt(os));
395                 else
396                         i->first.print(*p);
397                 os << "==";
398                 if (p == 0)
399                         i->second.print(print_dflt(os));
400                 else
401                         i->second.print(*p);
402                 ++i;
403                 if( i==mend )
404                         break;
405                 os << ",";
406         }
407         os << "}";
408
409         return os;
410 }
411
412 std::istream & operator>>(std::istream & is, ex & e)
413 {
414         throw (std::logic_error("expression input from streams not implemented"));
415 }
416
417 std::ostream & dflt(std::ostream & os)
418 {
419         set_print_context(os, print_dflt(os));
420         set_print_options(os, 0);
421         return os;
422 }
423
424 std::ostream & latex(std::ostream & os)
425 {
426         set_print_context(os, print_latex(os));
427         return os;
428 }
429
430 std::ostream & python(std::ostream & os)
431 {
432         set_print_context(os, print_python(os));
433         return os;
434 }
435
436 std::ostream & python_repr(std::ostream & os)
437 {
438         set_print_context(os, print_python_repr(os));
439         return os;
440 }
441
442 std::ostream & tree(std::ostream & os)
443 {
444         set_print_context(os, print_tree(os));
445         return os;
446 }
447
448 std::ostream & csrc(std::ostream & os)
449 {
450         set_print_context(os, print_csrc_double(os));
451         return os;
452 }
453
454 std::ostream & csrc_float(std::ostream & os)
455 {
456         set_print_context(os, print_csrc_float(os));
457         return os;
458 }
459
460 std::ostream & csrc_double(std::ostream & os)
461 {
462         set_print_context(os, print_csrc_double(os));
463         return os;
464 }
465
466 std::ostream & csrc_cl_N(std::ostream & os)
467 {
468         set_print_context(os, print_csrc_cl_N(os));
469         return os;
470 }
471
472 std::ostream & index_dimensions(std::ostream & os)
473 {
474         set_print_options(os, get_print_options(os) | print_options::print_index_dimensions);
475         return os;
476 }
477
478 std::ostream & no_index_dimensions(std::ostream & os)
479 {
480         set_print_options(os, get_print_options(os) & ~print_options::print_index_dimensions);
481         return os;
482 }
483
484 } // namespace GiNaC