]> www.ginac.de Git - ginac.git/blob - ginac/printcsrc.cpp
418d3845f8eb9edfc30f54a6fe97e548bbbf83fc
[ginac.git] / ginac / printcsrc.cpp
1 /** @file printcsrc.cpp
2  *
3  *  The methods .printcsrc() are responsible for C-source output of
4  *  objects.  All related helper-functions go in here as well. */
5
#include <iostream>
#include <limits>

#include "ginac.h"
9
10 /** Print expression as a C++ statement. The output looks like
11  *  "<type> <var_name> = <expression>;". The "type" parameter has an effect
12  *  on how number literals are printed.
13  *
14  *  @param os output stream
15  *  @param type variable type (one of the csrc_types)
16  *  @param var_name variable name to be printed */
17 void ex::printcsrc(ostream & os, unsigned type, const char *var_name) const
18 {
19     debugmsg("ex print csrc", LOGLEVEL_PRINT);
20     ASSERT(bp!=0);
21         switch (type) {
22                 case csrc_types::ctype_float:
23                         os << "float ";
24                         break;
25                 case csrc_types::ctype_double:
26                         os << "double ";
27                         break;
28                 case csrc_types::ctype_cl_N:
29                         os << "cl_N ";
30                         break;
31         }
32     os << var_name << " = ";
33     bp->printcsrc(os, type, 0);
34     os << ";\n";
35 }
36
37 void basic::printcsrc(ostream & os, unsigned type, unsigned upper_precedence) const
38 {
39     debugmsg("basic print csrc", LOGLEVEL_PRINT);
40 }
41
42 void numeric::printcsrc(ostream & os, unsigned type, unsigned upper_precedence) const
43 {
44     debugmsg("numeric print csrc", LOGLEVEL_PRINT);
45         ios::fmtflags oldflags = os.flags();
46         os.setf(ios::scientific);
47     if (is_rational() && !is_integer()) {
48         if (compare(numZERO()) > 0) {
49             os << "(";
50                         if (type == csrc_types::ctype_cl_N)
51                                 os << "cl_F(\"" << numer().evalf() << "\")";
52                         else
53                     os << numer().to_double();
54         } else {
55             os << "-(";
56                         if (type == csrc_types::ctype_cl_N)
57                                 os << "cl_F(\"" << -numer().evalf() << "\")";
58                         else
59                     os << -numer().to_double();
60         }
61         os << "/";
62                 if (type == csrc_types::ctype_cl_N)
63                         os << "cl_F(\"" << denom().evalf() << "\")";
64                 else
65                 os << denom().to_double();
66         os << ")";
67     } else {
68                 if (type == csrc_types::ctype_cl_N)
69                         os << "cl_F(\"" << evalf() << "\")";
70                 else
71                 os << to_double();
72         }
73         os.flags(oldflags);
74 }
75
76 void symbol::printcsrc(ostream & os, unsigned type, unsigned upper_precedence) const
77 {
78     debugmsg("symbol print csrc", LOGLEVEL_PRINT);
79     os << name;
80 }
81
82 void constant::printcsrc(ostream & os, unsigned type, unsigned upper_precedence) const
83 {
84     debugmsg("constant print csrc",LOGLEVEL_PRINT);
85     os << name;
86 }
87
88 static void print_sym_pow(ostream & os, unsigned type, const symbol &x, int exp)
89 {
90     // Optimal output of integer powers of symbols to aid compiler CSE
91     if (exp == 1) {
92         x.printcsrc(os, type, 0);
93     } else if (exp == 2) {
94         x.printcsrc(os, type, 0);
95         os << "*";
96         x.printcsrc(os, type, 0);
97     } else if (exp & 1) {
98         x.printcsrc(os, 0);
99         os << "*";
100         print_sym_pow(os, type, x, exp-1);
101     } else {
102         os << "(";
103         print_sym_pow(os, type, x, exp >> 1);
104         os << ")*(";
105         print_sym_pow(os, type, x, exp >> 1);
106         os << ")";
107     }
108 }
109
110 void power::printcsrc(ostream & os, unsigned type, unsigned upper_precedence) const
111 {
112     debugmsg("power print csrc", LOGLEVEL_PRINT);
113
114         // Integer powers of symbols are printed in a special, optimized way
115     if (exponent.info(info_flags::integer) &&
116         (is_ex_exactly_of_type(basis, symbol) ||
117          is_ex_exactly_of_type(basis, constant))) {
118         int exp = ex_to_numeric(exponent).to_int();
119         if (exp > 0)
120             os << "(";
121         else {
122             exp = -exp;
123                         if (type == csrc_types::ctype_cl_N)
124                                 os << "recip(";
125                         else
126                     os << "1.0/(";
127         }
128         print_sym_pow(os, type, static_cast<const symbol &>(*basis.bp), exp);
129         os << ")";
130
131         // <expr>^-1 is printed as "1.0/<expr>" or with the recip() function of CLN
132     } else if (exponent.compare(numMINUSONE()) == 0) {
133                 if (type == csrc_types::ctype_cl_N)
134                         os << "recip(";
135                 else
136                 os << "1.0/(";
137         basis.bp->printcsrc(os, type, 0);
138                 os << ")";
139
140         // Otherwise, use the pow() or expt() (CLN) functions
141     } else {
142                 if (type == csrc_types::ctype_cl_N)
143                         os << "expt(";
144                 else
145                 os << "pow(";
146         basis.bp->printcsrc(os, type, 0);
147         os << ",";
148         exponent.bp->printcsrc(os, type, 0);
149         os << ")";
150     }
151 }
152
153 void add::printcsrc(ostream & os, unsigned type, unsigned upper_precedence) const
154 {
155     debugmsg("add print csrc", LOGLEVEL_PRINT);
156     if (precedence <= upper_precedence)
157         os << "(";
158
159         // Print arguments, separated by "+"
160     epvector::const_iterator it = seq.begin();
161     epvector::const_iterator itend = seq.end();
162     while (it != itend) {
163
164                 // If the coefficient is -1, it is replaced by a single minus sign
165         if (it->coeff.compare(numONE()) == 0) {
166                 it->rest.bp->printcsrc(os, type, precedence);
167         } else if (it->coeff.compare(numMINUSONE()) == 0) {
168             os << "-";
169                 it->rest.bp->printcsrc(os, type, precedence);
170                 } else if (ex_to_numeric(it->coeff).numer().compare(numONE()) == 0) {
171                 it->rest.bp->printcsrc(os, type, precedence);
172                         os << "/";
173             ex_to_numeric(it->coeff).denom().printcsrc(os, type, precedence);
174                 } else if (ex_to_numeric(it->coeff).numer().compare(numMINUSONE()) == 0) {
175                         os << "-";
176                 it->rest.bp->printcsrc(os, type, precedence);
177                         os << "/";
178             ex_to_numeric(it->coeff).denom().printcsrc(os, type, precedence);
179                 } else {
180             it->coeff.bp->printcsrc(os, type, precedence);
181             os << "*";
182                 it->rest.bp->printcsrc(os, type, precedence);
183         }
184
185                 // Separator is "+", except it the following expression would have a leading minus sign
186         it++;
187         if (it != itend && !(it->coeff.compare(numZERO()) < 0 || (it->coeff.compare(numONE()) == 0 && is_ex_exactly_of_type(it->rest, numeric) && it->rest.compare(numZERO()) < 0)))
188             os << "+";
189     }
190     
191     if (!overall_coeff.is_equal(exZERO())) {
192         if (overall_coeff > 0) os << '+';
193         overall_coeff.bp->printcsrc(os,type,precedence);
194     }
195     
196     if (precedence <= upper_precedence)
197         os << ")";
198 }
199
200 void mul::printcsrc(ostream & os, unsigned type, unsigned upper_precedence) const
201 {
202     debugmsg("mul print csrc", LOGLEVEL_PRINT);
203     if (precedence <= upper_precedence)
204         os << "(";
205
206     if (!overall_coeff.is_equal(exONE())) {
207         overall_coeff.bp->printcsrc(os,type,precedence);
208         os << "*";
209     }
210     
211         // Print arguments, separated by "*" or "/"
212     epvector::const_iterator it = seq.begin();
213     epvector::const_iterator itend = seq.end();
214     while (it != itend) {
215
216                 // If the first argument is a negative integer power, it gets printed as "1.0/<expr>"
217         if (it == seq.begin() && ex_to_numeric(it->coeff).is_integer() && it->coeff.compare(numZERO()) < 0) {
218                         if (type == csrc_types::ctype_cl_N)
219                                 os << "recip(";
220                         else
221                     os << "1.0/";
222                 }
223
224                 // If the exponent is 1 or -1, it is left out
225         if (it->coeff.compare(exONE()) == 0 || it->coeff.compare(numMINUSONE()) == 0)
226             it->rest.bp->printcsrc(os, type, precedence);
227         else
228             // outer parens around ex needed for broken gcc-2.95 parser:
229             (ex(power(it->rest, abs(ex_to_numeric(it->coeff))))).bp->printcsrc(os, type, upper_precedence);
230
231                 // Separator is "/" for negative integer powers, "*" otherwise
232         it++;
233         if (it != itend) {
234             if (ex_to_numeric(it->coeff).is_integer() && it->coeff.compare(numZERO()) < 0)
235                 os << "/";
236             else
237                 os << "*";
238         }
239     }
240     if (precedence <= upper_precedence)
241         os << ")";
242 }
243
244 void ncmul::printcsrc(ostream & os, unsigned upper_precedence) const
245 {
246     debugmsg("ncmul print csrc",LOGLEVEL_PRINT);
247     exvector::const_iterator it;
248     exvector::const_iterator itend = seq.end()-1;
249     os << "ncmul(";
250     for (it=seq.begin(); it!=itend; ++it) {
251         (*it).bp->printcsrc(os,precedence);
252         os << ",";
253     }
254     (*it).bp->printcsrc(os,precedence);
255     os << ")";
256 }
257
258 void relational::printcsrc(ostream & os, unsigned type, unsigned upper_precedence) const
259 {
260     debugmsg("relational print csrc", LOGLEVEL_PRINT);
261     if (precedence<=upper_precedence)
262                 os << "(";
263
264         // Print left-hand expression
265     lh.bp->printcsrc(os, type, precedence);
266
267         // Print relational operator
268     switch (o) {
269             case equal:
270                 os << "==";
271                 break;
272             case not_equal:
273                 os << "!=";
274                 break;
275             case less:
276                 os << "<";
277                 break;
278             case less_or_equal:
279                 os << "<=";
280                 break;
281             case greater:
282                 os << ">";
283                 break;
284             case greater_or_equal:
285                 os << ">=";
286                 break;
287             default:
288                 os << "(INVALID RELATIONAL OPERATOR)";
289                         break;
290     }
291
292         // Print right-hand operator
293     rh.bp->printcsrc(os, type, precedence);
294
295     if (precedence <= upper_precedence)
296                 os << ")";
297 }