Update copyright statements.
[ginac.git] / ginac / operators.cpp
1 /** @file operators.cpp
2  *
3  *  Implementation of GiNaC's overloaded operators. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2014 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
21  */
22
23 #include "operators.h"
24 #include "numeric.h"
25 #include "add.h"
26 #include "mul.h"
27 #include "power.h"
28 #include "ncmul.h"
29 #include "relational.h"
30 #include "print.h"
31 #include "utils.h"
32
33 #include <iomanip>
34 #include <iostream>
35
36 namespace GiNaC {
37
38 /** Used internally by operator+() to add two ex objects together. */
39 static inline const ex exadd(const ex & lh, const ex & rh)
40 {
41         return (new add(lh,rh))->setflag(status_flags::dynallocated);
42 }
43
44 /** Used internally by operator*() to multiply two ex objects together. */
45 static inline const ex exmul(const ex & lh, const ex & rh)
46 {
47         // Check if we are constructing a mul object or a ncmul object.  Due to
48         // ncmul::eval()'s rule to pull out commutative elements we need to check
49         // only one of the elements.
50         if (rh.return_type()==return_types::commutative ||
51             lh.return_type()==return_types::commutative) {
52                 return (new mul(lh,rh))->setflag(status_flags::dynallocated);
53         } else {
54                 return (new ncmul(lh,rh))->setflag(status_flags::dynallocated);
55         }
56 }
57
58 /** Used internally by operator-() and friends to change the sign of an argument. */
59 static inline const ex exminus(const ex & lh)
60 {
61         return (new mul(lh,_ex_1))->setflag(status_flags::dynallocated);
62 }
63
64 // binary arithmetic operators ex with ex
65
66 const ex operator+(const ex & lh, const ex & rh)
67 {
68         return exadd(lh, rh);
69 }
70
71 const ex operator-(const ex & lh, const ex & rh)
72 {
73         return exadd(lh, exminus(rh));
74 }
75
76 const ex operator*(const ex & lh, const ex & rh)
77 {
78         return exmul(lh, rh);
79 }
80
81 const ex operator/(const ex & lh, const ex & rh)
82 {
83         return exmul(lh, power(rh,_ex_1));
84 }
85
86
87 // binary arithmetic operators numeric with numeric
88
89 const numeric operator+(const numeric & lh, const numeric & rh)
90 {
91         return lh.add(rh);
92 }
93
94 const numeric operator-(const numeric & lh, const numeric & rh)
95 {
96         return lh.sub(rh);
97 }
98
99 const numeric operator*(const numeric & lh, const numeric & rh)
100 {
101         return lh.mul(rh);
102 }
103
104 const numeric operator/(const numeric & lh, const numeric & rh)
105 {
106         return lh.div(rh);
107 }
108
109
110 // binary arithmetic assignment operators with ex
111
112 ex & operator+=(ex & lh, const ex & rh)
113 {
114         return lh = exadd(lh, rh);
115 }
116
117 ex & operator-=(ex & lh, const ex & rh)
118 {
119         return lh = exadd(lh, exminus(rh));
120 }
121
122 ex & operator*=(ex & lh, const ex & rh)
123 {
124         return lh = exmul(lh, rh);
125 }
126
127 ex & operator/=(ex & lh, const ex & rh)
128 {
129         return lh = exmul(lh, power(rh,_ex_1));
130 }
131
132
133 // binary arithmetic assignment operators with numeric
134
135 numeric & operator+=(numeric & lh, const numeric & rh)
136 {
137         lh = lh.add(rh);
138         return lh;
139 }
140
141 numeric & operator-=(numeric & lh, const numeric & rh)
142 {
143         lh = lh.sub(rh);
144         return lh;
145 }
146
147 numeric & operator*=(numeric & lh, const numeric & rh)
148 {
149         lh = lh.mul(rh);
150         return lh;
151 }
152
153 numeric & operator/=(numeric & lh, const numeric & rh)
154 {
155         lh = lh.div(rh);
156         return lh;
157 }
158
159
160 // unary operators
161
162 const ex operator+(const ex & lh)
163 {
164         return lh;
165 }
166
167 const ex operator-(const ex & lh)
168 {
169         return exminus(lh);
170 }
171
172 const numeric operator+(const numeric & lh)
173 {
174         return lh;
175 }
176
177 const numeric operator-(const numeric & lh)
178 {
179         return _num_1_p->mul(lh);
180 }
181
182
183 // increment / decrement operators
184
185 /** Expression prefix increment.  Adds 1 and returns incremented ex. */
186 ex & operator++(ex & rh)
187 {
188         return rh = exadd(rh, _ex1);
189 }
190
191 /** Expression prefix decrement.  Subtracts 1 and returns decremented ex. */
192 ex & operator--(ex & rh)
193 {
194         return rh = exadd(rh, _ex_1);
195 }
196
197 /** Expression postfix increment.  Returns the ex and leaves the original
198  *  incremented by 1. */
199 const ex operator++(ex & lh, int)
200 {
201         ex tmp(lh);
202         lh = exadd(lh, _ex1);
203         return tmp;
204 }
205
206 /** Expression postfix decrement.  Returns the ex and leaves the original
207  *  decremented by 1. */
208 const ex operator--(ex & lh, int)
209 {
210         ex tmp(lh);
211         lh = exadd(lh, _ex_1);
212         return tmp;
213 }
214
215 /** Numeric prefix increment.  Adds 1 and returns incremented number. */
216 numeric& operator++(numeric & rh)
217 {
218         rh = rh.add(*_num1_p);
219         return rh;
220 }
221
222 /** Numeric prefix decrement.  Subtracts 1 and returns decremented number. */
223 numeric& operator--(numeric & rh)
224 {
225         rh = rh.add(*_num_1_p);
226         return rh;
227 }
228
229 /** Numeric postfix increment.  Returns the number and leaves the original
230  *  incremented by 1. */
231 const numeric operator++(numeric & lh, int)
232 {
233         numeric tmp(lh);
234         lh = lh.add(*_num1_p);
235         return tmp;
236 }
237
238 /** Numeric postfix decrement.  Returns the number and leaves the original
239  *  decremented by 1. */
240 const numeric operator--(numeric & lh, int)
241 {
242         numeric tmp(lh);
243         lh = lh.add(*_num_1_p);
244         return tmp;
245 }
246
247 // binary relational operators ex with ex
248
249 const relational operator==(const ex & lh, const ex & rh)
250 {
251         return relational(lh,rh,relational::equal);
252 }
253
254 const relational operator!=(const ex & lh, const ex & rh)
255 {
256         return relational(lh,rh,relational::not_equal);
257 }
258
259 const relational operator<(const ex & lh, const ex & rh)
260 {
261         return relational(lh,rh,relational::less);
262 }
263
264 const relational operator<=(const ex & lh, const ex & rh)
265 {
266         return relational(lh,rh,relational::less_or_equal);
267 }
268
269 const relational operator>(const ex & lh, const ex & rh)
270 {
271         return relational(lh,rh,relational::greater);
272 }
273
274 const relational operator>=(const ex & lh, const ex & rh)
275 {
276         return relational(lh,rh,relational::greater_or_equal);
277 }
278
279 // input/output stream operators and manipulators
280
281 static int my_ios_index()
282 {
283         static int i = std::ios_base::xalloc();
284         return i;
285 }
286
287 // Stream format gets copied or destroyed
288 static void my_ios_callback(std::ios_base::event ev, std::ios_base & s, int i)
289 {
290         print_context *p = static_cast<print_context *>(s.pword(i));
291         if (ev == std::ios_base::erase_event) {
292                 delete p;
293                 s.pword(i) = 0;
294         } else if (ev == std::ios_base::copyfmt_event && p != 0)
295                 s.pword(i) = p->duplicate();
296 }
297
298 enum {
299         callback_registered = 1
300 };
301
302 // Get print_context associated with stream, may return 0 if no context has
303 // been associated yet
304 static inline print_context *get_print_context(std::ios_base & s)
305 {
306         return static_cast<print_context *>(s.pword(my_ios_index()));
307 }
308
309 // Set print_context associated with stream, retain options
310 static void set_print_context(std::ios_base & s, const print_context & c)
311 {
312         int i = my_ios_index();
313         long flags = s.iword(i);
314         if (!(flags & callback_registered)) {
315                 s.register_callback(my_ios_callback, i);
316                 s.iword(i) = flags | callback_registered;
317         }
318         print_context *p = static_cast<print_context *>(s.pword(i));
319         unsigned options = p ? p->options : c.options;
320         delete p;
321         p = c.duplicate();
322         p->options = options;
323         s.pword(i) = p;
324 }
325
326 // Get options for print_context associated with stream
327 static inline unsigned get_print_options(std::ios_base & s)
328 {
329         print_context *p = get_print_context(s);
330         return p ? p->options : 0;
331 }
332
333 // Set options for print_context associated with stream
334 static void set_print_options(std::ostream & s, unsigned options)
335 {
336         print_context *p = get_print_context(s);
337         if (p == 0)
338                 set_print_context(s, print_dflt(s, options));
339         else
340                 p->options = options;
341 }
342
343 std::ostream & operator<<(std::ostream & os, const ex & e)
344 {
345         print_context *p = get_print_context(os);
346         if (p == 0)
347                 e.print(print_dflt(os));
348         else
349                 e.print(*p);
350         return os;
351 }
352
353 std::ostream & operator<<(std::ostream & os, const exvector & e)
354 {
355         print_context *p = get_print_context(os);
356         exvector::const_iterator i = e.begin();
357         exvector::const_iterator vend = e.end();
358
359         if (i==vend) {
360                 os << "[]";
361                 return os;
362         }
363
364         os << "[";
365         while (true) {
366                 if (p == 0)
367                         i -> print(print_dflt(os));
368                 else
369                         i -> print(*p);
370                 ++i;
371                 if (i==vend)
372                         break;
373                 os << ",";
374         }
375         os << "]";
376
377         return os;
378 }
379
380 std::ostream & operator<<(std::ostream & os, const exset & e)
381 {
382         print_context *p = get_print_context(os);
383         exset::const_iterator i = e.begin();
384         exset::const_iterator send = e.end();
385
386         if (i==send) {
387                 os << "<>";
388                 return os;
389         }
390
391         os << "<";
392         while (true) {
393                 if (p == 0)
394                         i->print(print_dflt(os));
395                 else
396                         i->print(*p);
397                 ++i;
398                 if (i == send)
399                         break;
400                 os << ",";
401         }
402         os << ">";
403
404         return os;
405 }
406
407 std::ostream & operator<<(std::ostream & os, const exmap & e)
408 {
409         print_context *p = get_print_context(os);
410         exmap::const_iterator i = e.begin();
411         exmap::const_iterator mend = e.end();
412
413         if (i==mend) {
414                 os << "{}";
415                 return os;
416         }
417
418         os << "{";
419         while (true) {
420                 if (p == 0)
421                         i->first.print(print_dflt(os));
422                 else
423                         i->first.print(*p);
424                 os << "==";
425                 if (p == 0)
426                         i->second.print(print_dflt(os));
427                 else
428                         i->second.print(*p);
429                 ++i;
430                 if( i==mend )
431                         break;
432                 os << ",";
433         }
434         os << "}";
435
436         return os;
437 }
438
439 std::istream & operator>>(std::istream & is, ex & e)
440 {
441         throw (std::logic_error("expression input from streams not implemented"));
442 }
443
444 std::ostream & dflt(std::ostream & os)
445 {
446         set_print_context(os, print_dflt(os));
447         set_print_options(os, 0);
448         return os;
449 }
450
451 std::ostream & latex(std::ostream & os)
452 {
453         set_print_context(os, print_latex(os));
454         return os;
455 }
456
457 std::ostream & python(std::ostream & os)
458 {
459         set_print_context(os, print_python(os));
460         return os;
461 }
462
463 std::ostream & python_repr(std::ostream & os)
464 {
465         set_print_context(os, print_python_repr(os));
466         return os;
467 }
468
469 std::ostream & tree(std::ostream & os)
470 {
471         set_print_context(os, print_tree(os));
472         return os;
473 }
474
475 std::ostream & csrc(std::ostream & os)
476 {
477         set_print_context(os, print_csrc_double(os));
478         return os;
479 }
480
481 std::ostream & csrc_float(std::ostream & os)
482 {
483         set_print_context(os, print_csrc_float(os));
484         return os;
485 }
486
487 std::ostream & csrc_double(std::ostream & os)
488 {
489         set_print_context(os, print_csrc_double(os));
490         return os;
491 }
492
493 std::ostream & csrc_cl_N(std::ostream & os)
494 {
495         set_print_context(os, print_csrc_cl_N(os));
496         return os;
497 }
498
499 std::ostream & index_dimensions(std::ostream & os)
500 {
501         set_print_options(os, get_print_options(os) | print_options::print_index_dimensions);
502         return os;
503 }
504
505 std::ostream & no_index_dimensions(std::ostream & os)
506 {
507         set_print_options(os, get_print_options(os) & ~print_options::print_index_dimensions);
508         return os;
509 }
510
511 } // namespace GiNaC