Happy New Year!
[ginac.git] / ginac / operators.cpp
1 /** @file operators.cpp
2  *
3  *  Implementation of GiNaC's overloaded operators. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2004 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
21  */
22
23 #include <iostream>
24 #include <iomanip>
25
26 #include "operators.h"
27 #include "numeric.h"
28 #include "add.h"
29 #include "mul.h"
30 #include "power.h"
31 #include "ncmul.h"
32 #include "relational.h"
33 #include "print.h"
34 #include "utils.h"
35
36 namespace GiNaC {
37
38 /** Used internally by operator+() to add two ex objects together. */
39 static inline const ex exadd(const ex & lh, const ex & rh)
40 {
41         return (new add(lh,rh))->setflag(status_flags::dynallocated);
42 }
43
44 /** Used internally by operator*() to multiply two ex objects together. */
45 static inline const ex exmul(const ex & lh, const ex & rh)
46 {
47         // Check if we are constructing a mul object or a ncmul object.  Due to
48         // ncmul::eval()'s rule to pull out commutative elements we need to check
49         // only one of the elements.
50         if (rh.return_type()==return_types::commutative ||
51             lh.return_type()==return_types::commutative)
52                 return (new mul(lh,rh))->setflag(status_flags::dynallocated);
53         else
54                 return (new ncmul(lh,rh))->setflag(status_flags::dynallocated);
55 }
56
57 /** Used internally by operator-() and friends to change the sign of an argument. */
58 static inline const ex exminus(const ex & lh)
59 {
60         return (new mul(lh,_ex_1))->setflag(status_flags::dynallocated);
61 }
62
63 // binary arithmetic operators ex with ex
64
65 const ex operator+(const ex & lh, const ex & rh)
66 {
67         return exadd(lh, rh);
68 }
69
70 const ex operator-(const ex & lh, const ex & rh)
71 {
72         return exadd(lh, exminus(rh));
73 }
74
75 const ex operator*(const ex & lh, const ex & rh)
76 {
77         return exmul(lh, rh);
78 }
79
80 const ex operator/(const ex & lh, const ex & rh)
81 {
82         return exmul(lh, power(rh,_ex_1));
83 }
84
85
86 // binary arithmetic operators numeric with numeric
87
88 const numeric operator+(const numeric & lh, const numeric & rh)
89 {
90         return lh.add(rh);
91 }
92
93 const numeric operator-(const numeric & lh, const numeric & rh)
94 {
95         return lh.sub(rh);
96 }
97
98 const numeric operator*(const numeric & lh, const numeric & rh)
99 {
100         return lh.mul(rh);
101 }
102
103 const numeric operator/(const numeric & lh, const numeric & rh)
104 {
105         return lh.div(rh);
106 }
107
108
109 // binary arithmetic assignment operators with ex
110
111 ex & operator+=(ex & lh, const ex & rh)
112 {
113         return lh = exadd(lh, rh);
114 }
115
116 ex & operator-=(ex & lh, const ex & rh)
117 {
118         return lh = exadd(lh, exminus(rh));
119 }
120
121 ex & operator*=(ex & lh, const ex & rh)
122 {
123         return lh = exmul(lh, rh);
124 }
125
126 ex & operator/=(ex & lh, const ex & rh)
127 {
128         return lh = exmul(lh, power(rh,_ex_1));
129 }
130
131
132 // binary arithmetic assignment operators with numeric
133
134 numeric & operator+=(numeric & lh, const numeric & rh)
135 {
136         lh = lh.add(rh);
137         return lh;
138 }
139
140 numeric & operator-=(numeric & lh, const numeric & rh)
141 {
142         lh = lh.sub(rh);
143         return lh;
144 }
145
146 numeric & operator*=(numeric & lh, const numeric & rh)
147 {
148         lh = lh.mul(rh);
149         return lh;
150 }
151
152 numeric & operator/=(numeric & lh, const numeric & rh)
153 {
154         lh = lh.div(rh);
155         return lh;
156 }
157
158
159 // unary operators
160
161 const ex operator+(const ex & lh)
162 {
163         return lh;
164 }
165
166 const ex operator-(const ex & lh)
167 {
168         return exminus(lh);
169 }
170
171 const numeric operator+(const numeric & lh)
172 {
173         return lh;
174 }
175
176 const numeric operator-(const numeric & lh)
177 {
178         return _num_1.mul(lh);
179 }
180
181
182 // increment / decrement operators
183
184 /** Expression prefix increment.  Adds 1 and returns incremented ex. */
185 ex & operator++(ex & rh)
186 {
187         return rh = exadd(rh, _ex1);
188 }
189
190 /** Expression prefix decrement.  Subtracts 1 and returns decremented ex. */
191 ex & operator--(ex & rh)
192 {
193         return rh = exadd(rh, _ex_1);
194 }
195
196 /** Expression postfix increment.  Returns the ex and leaves the original
197  *  incremented by 1. */
198 const ex operator++(ex & lh, int)
199 {
200         ex tmp(lh);
201         lh = exadd(lh, _ex1);
202         return tmp;
203 }
204
205 /** Expression postfix decrement.  Returns the ex and leaves the original
206  *  decremented by 1. */
207 const ex operator--(ex & lh, int)
208 {
209         ex tmp(lh);
210         lh = exadd(lh, _ex_1);
211         return tmp;
212 }
213
214 /** Numeric prefix increment.  Adds 1 and returns incremented number. */
215 numeric& operator++(numeric & rh)
216 {
217         rh = rh.add(_num1);
218         return rh;
219 }
220
221 /** Numeric prefix decrement.  Subtracts 1 and returns decremented number. */
222 numeric& operator--(numeric & rh)
223 {
224         rh = rh.add(_num_1);
225         return rh;
226 }
227
228 /** Numeric postfix increment.  Returns the number and leaves the original
229  *  incremented by 1. */
230 const numeric operator++(numeric & lh, int)
231 {
232         numeric tmp(lh);
233         lh = lh.add(_num1);
234         return tmp;
235 }
236
237 /** Numeric postfix decrement.  Returns the number and leaves the original
238  *  decremented by 1. */
239 const numeric operator--(numeric & lh, int)
240 {
241         numeric tmp(lh);
242         lh = lh.add(_num_1);
243         return tmp;
244 }
245
246 // binary relational operators ex with ex
247
248 const relational operator==(const ex & lh, const ex & rh)
249 {
250         return relational(lh,rh,relational::equal);
251 }
252
253 const relational operator!=(const ex & lh, const ex & rh)
254 {
255         return relational(lh,rh,relational::not_equal);
256 }
257
258 const relational operator<(const ex & lh, const ex & rh)
259 {
260         return relational(lh,rh,relational::less);
261 }
262
263 const relational operator<=(const ex & lh, const ex & rh)
264 {
265         return relational(lh,rh,relational::less_or_equal);
266 }
267
268 const relational operator>(const ex & lh, const ex & rh)
269 {
270         return relational(lh,rh,relational::greater);
271 }
272
273 const relational operator>=(const ex & lh, const ex & rh)
274 {
275         return relational(lh,rh,relational::greater_or_equal);
276 }
277
278 // input/output stream operators and manipulators
279
280 static int my_ios_index()
281 {
282         static int i = std::ios_base::xalloc();
283         return i;
284 }
285
286 // Stream format gets copied or destroyed
287 static void my_ios_callback(std::ios_base::event ev, std::ios_base & s, int i)
288 {
289         print_context *p = static_cast<print_context *>(s.pword(i));
290         if (ev == std::ios_base::erase_event) {
291                 delete p;
292                 s.pword(i) = 0;
293         } else if (ev == std::ios_base::copyfmt_event && p != 0)
294                 s.pword(i) = p->duplicate();
295 }
296
297 enum {
298         callback_registered = 1
299 };
300
301 // Get print_context associated with stream, may return 0 if no context has
302 // been associated yet
303 static inline print_context *get_print_context(std::ios_base & s)
304 {
305         return static_cast<print_context *>(s.pword(my_ios_index()));
306 }
307
308 // Set print_context associated with stream, retain options
309 static void set_print_context(std::ios_base & s, const print_context & c)
310 {
311         int i = my_ios_index();
312         long flags = s.iword(i);
313         if (!(flags & callback_registered)) {
314                 s.register_callback(my_ios_callback, i);
315                 s.iword(i) = flags | callback_registered;
316         }
317         print_context *p = static_cast<print_context *>(s.pword(i));
318         unsigned options = p ? p->options : c.options;
319         delete p;
320         p = c.duplicate();
321         p->options = options;
322         s.pword(i) = p;
323 }
324
325 // Get options for print_context associated with stream
326 static inline unsigned get_print_options(std::ios_base & s)
327 {
328         print_context *p = get_print_context(s);
329         return p ? p->options : 0;
330 }
331
332 // Set options for print_context associated with stream
333 static void set_print_options(std::ostream & s, unsigned options)
334 {
335         print_context *p = get_print_context(s);
336         if (p == 0)
337                 set_print_context(s, print_dflt(s, options));
338         else
339                 p->options = options;
340 }
341
342 std::ostream & operator<<(std::ostream & os, const ex & e)
343 {
344         print_context *p = get_print_context(os);
345         if (p == 0)
346                 e.print(print_dflt(os));
347         else
348                 e.print(*p);
349         return os;
350 }
351
352 std::istream & operator>>(std::istream & is, ex & e)
353 {
354         throw (std::logic_error("expression input from streams not implemented"));
355 }
356
357 std::ostream & dflt(std::ostream & os)
358 {
359         set_print_context(os, print_dflt(os));
360         set_print_options(os, 0);
361         return os;
362 }
363
364 std::ostream & latex(std::ostream & os)
365 {
366         set_print_context(os, print_latex(os));
367         return os;
368 }
369
370 std::ostream & python(std::ostream & os)
371 {
372         set_print_context(os, print_python(os));
373         return os;
374 }
375
376 std::ostream & python_repr(std::ostream & os)
377 {
378         set_print_context(os, print_python_repr(os));
379         return os;
380 }
381
382 std::ostream & tree(std::ostream & os)
383 {
384         set_print_context(os, print_tree(os));
385         return os;
386 }
387
388 std::ostream & csrc(std::ostream & os)
389 {
390         set_print_context(os, print_csrc_double(os));
391         return os;
392 }
393
394 std::ostream & csrc_float(std::ostream & os)
395 {
396         set_print_context(os, print_csrc_float(os));
397         return os;
398 }
399
400 std::ostream & csrc_double(std::ostream & os)
401 {
402         set_print_context(os, print_csrc_double(os));
403         return os;
404 }
405
406 std::ostream & csrc_cl_N(std::ostream & os)
407 {
408         set_print_context(os, print_csrc_cl_N(os));
409         return os;
410 }
411
412 std::ostream & index_dimensions(std::ostream & os)
413 {
414         set_print_options(os, get_print_options(os) | print_options::print_index_dimensions);
415         return os;
416 }
417
418 std::ostream & no_index_dimensions(std::ostream & os)
419 {
420         set_print_options(os, get_print_options(os) & ~print_options::print_index_dimensions);
421         return os;
422 }
423
424 } // namespace GiNaC