Finalize 1.7.6 release.
[ginac.git] / ginac / operators.cpp
1 /** @file operators.cpp
2  *
3  *  Implementation of GiNaC's overloaded operators. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2019 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
21  */
22
23 #include "operators.h"
24 #include "numeric.h"
25 #include "add.h"
26 #include "mul.h"
27 #include "power.h"
28 #include "ncmul.h"
29 #include "relational.h"
30 #include "print.h"
31 #include "utils.h"
32
33 #include <iostream>
34
35 namespace GiNaC {
36
37 /** Used internally by operator+() to add two ex objects. */
38 static inline const ex exadd(const ex & lh, const ex & rh)
39 {
40         return dynallocate<add>(lh, rh);
41 }
42
43 /** Used internally by operator*() to multiply two ex objects. */
44 static inline const ex exmul(const ex & lh, const ex & rh)
45 {
46         // Check if we are constructing a mul object or a ncmul object.  Due to
47         // ncmul::eval()'s rule to pull out commutative elements we need to check
48         // only one of the elements.
49         if (rh.return_type()==return_types::commutative ||
50             lh.return_type()==return_types::commutative) {
51                 return dynallocate<mul>(lh, rh);
52         } else {
53                 return dynallocate<ncmul>(lh, rh);
54         }
55 }
56
57 /** Used internally by operator-() and friends to change the sign of an argument. */
58 static inline const ex exminus(const ex & lh)
59 {
60         return dynallocate<mul>(lh, _ex_1);
61 }
62
63 // binary arithmetic operators ex with ex
64
65 const ex operator+(const ex & lh, const ex & rh)
66 {
67         return exadd(lh, rh);
68 }
69
70 const ex operator-(const ex & lh, const ex & rh)
71 {
72         return exadd(lh, exminus(rh));
73 }
74
75 const ex operator*(const ex & lh, const ex & rh)
76 {
77         return exmul(lh, rh);
78 }
79
80 const ex operator/(const ex & lh, const ex & rh)
81 {
82         return exmul(lh, power(rh,_ex_1));
83 }
84
85
86 // binary arithmetic operators numeric with numeric
87
88 const numeric operator+(const numeric & lh, const numeric & rh)
89 {
90         return lh.add(rh);
91 }
92
93 const numeric operator-(const numeric & lh, const numeric & rh)
94 {
95         return lh.sub(rh);
96 }
97
98 const numeric operator*(const numeric & lh, const numeric & rh)
99 {
100         return lh.mul(rh);
101 }
102
103 const numeric operator/(const numeric & lh, const numeric & rh)
104 {
105         return lh.div(rh);
106 }
107
108
109 // binary arithmetic assignment operators with ex
110
111 ex & operator+=(ex & lh, const ex & rh)
112 {
113         return lh = exadd(lh, rh);
114 }
115
116 ex & operator-=(ex & lh, const ex & rh)
117 {
118         return lh = exadd(lh, exminus(rh));
119 }
120
121 ex & operator*=(ex & lh, const ex & rh)
122 {
123         return lh = exmul(lh, rh);
124 }
125
126 ex & operator/=(ex & lh, const ex & rh)
127 {
128         return lh = exmul(lh, power(rh,_ex_1));
129 }
130
131
132 // binary arithmetic assignment operators with numeric
133
134 numeric & operator+=(numeric & lh, const numeric & rh)
135 {
136         lh = lh.add(rh);
137         return lh;
138 }
139
140 numeric & operator-=(numeric & lh, const numeric & rh)
141 {
142         lh = lh.sub(rh);
143         return lh;
144 }
145
146 numeric & operator*=(numeric & lh, const numeric & rh)
147 {
148         lh = lh.mul(rh);
149         return lh;
150 }
151
152 numeric & operator/=(numeric & lh, const numeric & rh)
153 {
154         lh = lh.div(rh);
155         return lh;
156 }
157
158
159 // unary operators
160
161 const ex operator+(const ex & lh)
162 {
163         return lh;
164 }
165
166 const ex operator-(const ex & lh)
167 {
168         return exminus(lh);
169 }
170
171 const numeric operator+(const numeric & lh)
172 {
173         return lh;
174 }
175
176 const numeric operator-(const numeric & lh)
177 {
178         return _num_1_p->mul(lh);
179 }
180
181
182 // increment / decrement operators
183
184 /** Expression prefix increment.  Adds 1 and returns incremented ex. */
185 ex & operator++(ex & rh)
186 {
187         return rh = exadd(rh, _ex1);
188 }
189
190 /** Expression prefix decrement.  Subtracts 1 and returns decremented ex. */
191 ex & operator--(ex & rh)
192 {
193         return rh = exadd(rh, _ex_1);
194 }
195
196 /** Expression postfix increment.  Returns the ex and leaves the original
197  *  incremented by 1. */
198 const ex operator++(ex & lh, int)
199 {
200         ex tmp(lh);
201         lh = exadd(lh, _ex1);
202         return tmp;
203 }
204
205 /** Expression postfix decrement.  Returns the ex and leaves the original
206  *  decremented by 1. */
207 const ex operator--(ex & lh, int)
208 {
209         ex tmp(lh);
210         lh = exadd(lh, _ex_1);
211         return tmp;
212 }
213
214 /** Numeric prefix increment.  Adds 1 and returns incremented number. */
215 numeric& operator++(numeric & rh)
216 {
217         rh = rh.add(*_num1_p);
218         return rh;
219 }
220
221 /** Numeric prefix decrement.  Subtracts 1 and returns decremented number. */
222 numeric& operator--(numeric & rh)
223 {
224         rh = rh.add(*_num_1_p);
225         return rh;
226 }
227
228 /** Numeric postfix increment.  Returns the number and leaves the original
229  *  incremented by 1. */
230 const numeric operator++(numeric & lh, int)
231 {
232         numeric tmp(lh);
233         lh = lh.add(*_num1_p);
234         return tmp;
235 }
236
237 /** Numeric postfix decrement.  Returns the number and leaves the original
238  *  decremented by 1. */
239 const numeric operator--(numeric & lh, int)
240 {
241         numeric tmp(lh);
242         lh = lh.add(*_num_1_p);
243         return tmp;
244 }
245
246 // binary relational operators ex with ex
247
248 const relational operator==(const ex & lh, const ex & rh)
249 {
250         return relational(lh, rh, relational::equal);
251 }
252
253 const relational operator!=(const ex & lh, const ex & rh)
254 {
255         return relational(lh, rh, relational::not_equal);
256 }
257
258 const relational operator<(const ex & lh, const ex & rh)
259 {
260         return relational(lh, rh, relational::less);
261 }
262
263 const relational operator<=(const ex & lh, const ex & rh)
264 {
265         return relational(lh, rh, relational::less_or_equal);
266 }
267
268 const relational operator>(const ex & lh, const ex & rh)
269 {
270         return relational(lh, rh, relational::greater);
271 }
272
273 const relational operator>=(const ex & lh, const ex & rh)
274 {
275         return relational(lh, rh, relational::greater_or_equal);
276 }
277
278 // input/output stream operators and manipulators
279
280 static int my_ios_index()
281 {
282         static int i = std::ios_base::xalloc();
283         return i;
284 }
285
286 // Stream format gets copied or destroyed
287 static void my_ios_callback(std::ios_base::event ev, std::ios_base & s, int i)
288 {
289         print_context *p = static_cast<print_context *>(s.pword(i));
290         if (ev == std::ios_base::erase_event) {
291                 delete p;
292                 s.pword(i) = nullptr;
293         } else if (ev == std::ios_base::copyfmt_event && p != nullptr)
294                 s.pword(i) = p->duplicate();
295 }
296
297 enum {
298         callback_registered = 1
299 };
300
301 // Get print_context associated with stream, may return 0 if no context has
302 // been associated yet
303 static inline print_context *get_print_context(std::ios_base & s)
304 {
305         return static_cast<print_context *>(s.pword(my_ios_index()));
306 }
307
308 // Set print_context associated with stream, retain options
309 static void set_print_context(std::ios_base & s, const print_context & c)
310 {
311         int i = my_ios_index();
312         long flags = s.iword(i);
313         if (!(flags & callback_registered)) {
314                 s.register_callback(my_ios_callback, i);
315                 s.iword(i) = flags | callback_registered;
316         }
317         print_context *p = static_cast<print_context *>(s.pword(i));
318         unsigned options = p ? p->options : c.options;
319         delete p;
320         p = c.duplicate();
321         p->options = options;
322         s.pword(i) = p;
323 }
324
325 // Get options for print_context associated with stream
326 static inline unsigned get_print_options(std::ios_base & s)
327 {
328         print_context *p = get_print_context(s);
329         return p ? p->options : 0;
330 }
331
332 // Set options for print_context associated with stream
333 static void set_print_options(std::ostream & s, unsigned options)
334 {
335         print_context *p = get_print_context(s);
336         if (p == nullptr)
337                 set_print_context(s, print_dflt(s, options));
338         else
339                 p->options = options;
340 }
341
342 std::ostream & operator<<(std::ostream & os, const ex & e)
343 {
344         print_context *p = get_print_context(os);
345         if (p == nullptr)
346                 e.print(print_dflt(os));
347         else
348                 e.print(*p);
349         return os;
350 }
351
352 std::ostream & operator<<(std::ostream & os, const exvector & e)
353 {
354         print_context *p = get_print_context(os);
355         auto i = e.begin();
356         auto vend = e.end();
357
358         if (i==vend) {
359                 os << "[]";
360                 return os;
361         }
362
363         os << "[";
364         while (true) {
365                 if (p == nullptr)
366                         i -> print(print_dflt(os));
367                 else
368                         i -> print(*p);
369                 ++i;
370                 if (i==vend)
371                         break;
372                 os << ",";
373         }
374         os << "]";
375
376         return os;
377 }
378
379 std::ostream & operator<<(std::ostream & os, const exset & e)
380 {
381         print_context *p = get_print_context(os);
382         auto i = e.begin();
383         auto send = e.end();
384
385         if (i==send) {
386                 os << "<>";
387                 return os;
388         }
389
390         os << "<";
391         while (true) {
392                 if (p == nullptr)
393                         i->print(print_dflt(os));
394                 else
395                         i->print(*p);
396                 ++i;
397                 if (i == send)
398                         break;
399                 os << ",";
400         }
401         os << ">";
402
403         return os;
404 }
405
406 std::ostream & operator<<(std::ostream & os, const exmap & e)
407 {
408         print_context *p = get_print_context(os);
409         auto i = e.begin();
410         auto mend = e.end();
411
412         if (i==mend) {
413                 os << "{}";
414                 return os;
415         }
416
417         os << "{";
418         while (true) {
419                 if (p == nullptr)
420                         i->first.print(print_dflt(os));
421                 else
422                         i->first.print(*p);
423                 os << "==";
424                 if (p == nullptr)
425                         i->second.print(print_dflt(os));
426                 else
427                         i->second.print(*p);
428                 ++i;
429                 if( i==mend )
430                         break;
431                 os << ",";
432         }
433         os << "}";
434
435         return os;
436 }
437
438 std::istream & operator>>(std::istream & is, ex & e)
439 {
440         throw (std::logic_error("expression input from streams not implemented"));
441 }
442
443 std::ostream & dflt(std::ostream & os)
444 {
445         set_print_context(os, print_dflt(os));
446         set_print_options(os, 0);
447         return os;
448 }
449
450 std::ostream & latex(std::ostream & os)
451 {
452         set_print_context(os, print_latex(os));
453         return os;
454 }
455
456 std::ostream & python(std::ostream & os)
457 {
458         set_print_context(os, print_python(os));
459         return os;
460 }
461
462 std::ostream & python_repr(std::ostream & os)
463 {
464         set_print_context(os, print_python_repr(os));
465         return os;
466 }
467
468 std::ostream & tree(std::ostream & os)
469 {
470         set_print_context(os, print_tree(os));
471         return os;
472 }
473
474 std::ostream & csrc(std::ostream & os)
475 {
476         set_print_context(os, print_csrc_double(os));
477         return os;
478 }
479
480 std::ostream & csrc_float(std::ostream & os)
481 {
482         set_print_context(os, print_csrc_float(os));
483         return os;
484 }
485
486 std::ostream & csrc_double(std::ostream & os)
487 {
488         set_print_context(os, print_csrc_double(os));
489         return os;
490 }
491
492 std::ostream & csrc_cl_N(std::ostream & os)
493 {
494         set_print_context(os, print_csrc_cl_N(os));
495         return os;
496 }
497
498 std::ostream & index_dimensions(std::ostream & os)
499 {
500         set_print_options(os, get_print_options(os) | print_options::print_index_dimensions);
501         return os;
502 }
503
504 std::ostream & no_index_dimensions(std::ostream & os)
505 {
506         set_print_options(os, get_print_options(os) & ~print_options::print_index_dimensions);
507         return os;
508 }
509
510 } // namespace GiNaC