Skip to content

Commit e449479

Browse files
committed
Fix multiplication and division of complex numbers
Multiplication and division of complex numbers are not just pointwise applications of those operations. Fixes: #8375
1 parent 629dbcd commit e449479

File tree

4 files changed

+136
-44
lines changed

4 files changed

+136
-44
lines changed

regression/cbmc/complex2/main.c

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#include <complex.h>
2+
3+
int main()
4+
{
5+
char choice;
6+
float re = choice ? 1.3f : 2.1f; // a non-constant well-behaved float
7+
_Complex float z1 = I + re;
8+
_Complex float z2 = z1 * z1;
9+
_Complex float expected = 2 * I * re + re * re - 1; // (a+i)^2 = 2ai + a^2 - 1
10+
_Complex float actual =
11+
re * re + I; // (a1 + b1*i)*(a2 + b2*i) = (a1*a2 + b1*b2*i)
12+
__CPROVER_assert(z2 == expected, "right");
13+
__CPROVER_assert(z2 == actual, "wrong");
14+
return 0;
15+
}

regression/cbmc/complex2/test.desc

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
CORE no-new-smt gcc-only
2+
main.c
3+
4+
^\[main.assertion.1\] line 12 right: SUCCESS$
5+
^\[main.assertion.2\] line 13 wrong: FAILURE$
6+
^VERIFICATION FAILED$
7+
^EXIT=10$
8+
^SIGNAL=0$
9+
--
10+
^warning: ignoring
11+
--
12+
Visual Studio does not directly support `complex` or `_Complex`, or using
13+
standard arithmetic operators over complex numbers.

src/goto-programs/remove_complex.cpp

+50-6
Original file line numberDiff line numberDiff line change
@@ -127,13 +127,10 @@ static void remove_complex(exprt &expr)
127127

128128
if(expr.type().id()==ID_complex)
129129
{
130-
if(expr.id()==ID_plus || expr.id()==ID_minus ||
131-
expr.id()==ID_mult || expr.id()==ID_div)
130+
if(expr.id() == ID_plus || expr.id() == ID_minus)
132131
{
133-
// FIXME plus and mult are defined as n-ary operations
134-
// rather than binary. This code assumes that they
135-
// can only have exactly 2 operands, and it is not clear
136-
// that it is safe to do so in this context
132+
// plus and mult are n-ary expressions, but front-ends currently ensure
133+
// that we only see them as binary ones
137134
PRECONDITION(expr.operands().size() == 2);
138135
// do component-wise:
139136
// x+y -> complex(x.r+y.r,x.i+y.i)
@@ -153,6 +150,53 @@ static void remove_complex(exprt &expr)
153150

154151
expr=struct_expr;
155152
}
153+
else if(expr.id() == ID_mult)
154+
{
155+
// plus and mult are n-ary expressions, but front-ends currently ensure
156+
// that we only see them as binary ones
157+
PRECONDITION(expr.operands().size() == 2);
158+
exprt lhs_real = complex_member(to_binary_expr(expr).op0(), ID_real);
159+
exprt lhs_imag = complex_member(to_binary_expr(expr).op0(), ID_imag);
160+
exprt rhs_real = complex_member(to_binary_expr(expr).op1(), ID_real);
161+
exprt rhs_imag = complex_member(to_binary_expr(expr).op1(), ID_imag);
162+
163+
struct_exprt struct_expr{
164+
{minus_exprt{
165+
mult_exprt{lhs_real, rhs_real}, mult_exprt{lhs_imag, rhs_imag}},
166+
plus_exprt{
167+
mult_exprt{lhs_imag, rhs_real}, mult_exprt{lhs_real, rhs_imag}}},
168+
expr.type()};
169+
170+
struct_expr.op0().add_source_location() = expr.source_location();
171+
struct_expr.op1().add_source_location() = expr.source_location();
172+
173+
expr = struct_expr;
174+
}
175+
else if(expr.id() == ID_div)
176+
{
177+
exprt lhs_real = complex_member(to_binary_expr(expr).op0(), ID_real);
178+
exprt lhs_imag = complex_member(to_binary_expr(expr).op0(), ID_imag);
179+
exprt rhs_real = complex_member(to_binary_expr(expr).op1(), ID_real);
180+
exprt rhs_imag = complex_member(to_binary_expr(expr).op1(), ID_imag);
181+
182+
plus_exprt numerator_real{
183+
mult_exprt{lhs_real, rhs_real}, mult_exprt{lhs_imag, rhs_imag}};
184+
minus_exprt numerator_imag{
185+
mult_exprt{lhs_imag, rhs_real}, mult_exprt{lhs_real, rhs_imag}};
186+
187+
plus_exprt denominator{
188+
mult_exprt{rhs_real, rhs_real}, mult_exprt{rhs_imag, rhs_imag}};
189+
190+
struct_exprt struct_expr{
191+
{div_exprt{numerator_real, denominator},
192+
div_exprt{numerator_imag, denominator}},
193+
expr.type()};
194+
195+
struct_expr.op0().add_source_location() = expr.source_location();
196+
struct_expr.op1().add_source_location() = expr.source_location();
197+
198+
expr = struct_expr;
199+
}
156200
else if(expr.id()==ID_unary_minus)
157201
{
158202
auto const &unary_minus_expr = to_unary_minus_expr(expr);

src/solvers/flattening/boolbv_floatbv_op.cpp

+58-38
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,14 @@ Author: Daniel Kroening, [email protected]
66
77
\*******************************************************************/
88

9-
#include "boolbv.h"
10-
11-
#include <algorithm>
12-
139
#include <util/bitvector_types.h>
1410
#include <util/c_types.h>
1511
#include <util/floatbv_expr.h>
1612

1713
#include <solvers/floatbv/float_utils.h>
1814

15+
#include "boolbv.h"
16+
1917
bvt boolbvt::convert_floatbv_typecast(const floatbv_typecast_exprt &expr)
2018
{
2119
const exprt &op0=expr.op(); // number to convert
@@ -131,44 +129,66 @@ bvt boolbvt::convert_floatbv_op(const ieee_float_op_exprt &expr)
131129
sub_width > 0 && width % sub_width == 0,
132130
"width of a complex subtype must be positive and evenly divide the "
133131
"width of the complex expression");
132+
DATA_INVARIANT(
133+
sub_width * 2 == width, "a complex type consists of exactly two parts");
134134

135-
std::size_t size=width/sub_width;
136-
bvt result_bv;
137-
result_bv.resize(width);
135+
bvt lhs_real{lhs_as_bv.begin(), lhs_as_bv.begin() + sub_width};
136+
bvt rhs_real{rhs_as_bv.begin(), rhs_as_bv.begin() + sub_width};
138137

139-
for(std::size_t i=0; i<size; i++)
138+
bvt lhs_imag{lhs_as_bv.begin() + sub_width, lhs_as_bv.end()};
139+
bvt rhs_imag{rhs_as_bv.begin() + sub_width, rhs_as_bv.end()};
140+
141+
bvt result_real, result_imag;
142+
143+
if(expr.id() == ID_floatbv_plus || expr.id() == ID_floatbv_minus)
144+
{
145+
result_real = float_utils.add_sub(
146+
lhs_real, rhs_real, expr.id() == ID_floatbv_minus);
147+
result_imag = float_utils.add_sub(
148+
lhs_imag, rhs_imag, expr.id() == ID_floatbv_minus);
149+
}
150+
else if(expr.id() == ID_floatbv_mult)
151+
{
152+
// Could be optimised to just three multiplications with more additions
153+
// instead, but then we'd have to worry about the impact of possible
154+
// overflows. So we use the naive approach for now:
155+
result_real = float_utils.add_sub(
156+
float_utils.mul(lhs_real, rhs_real),
157+
float_utils.mul(lhs_imag, rhs_imag),
158+
true);
159+
result_imag = float_utils.add_sub(
160+
float_utils.mul(lhs_real, rhs_imag),
161+
float_utils.mul(lhs_imag, rhs_real),
162+
false);
163+
}
164+
else if(expr.id() == ID_floatbv_div)
140165
{
141-
bvt lhs_sub_bv, rhs_sub_bv, sub_result_bv;
142-
143-
lhs_sub_bv.assign(
144-
lhs_as_bv.begin() + i * sub_width,
145-
lhs_as_bv.begin() + (i + 1) * sub_width);
146-
rhs_sub_bv.assign(
147-
rhs_as_bv.begin() + i * sub_width,
148-
rhs_as_bv.begin() + (i + 1) * sub_width);
149-
150-
if(expr.id()==ID_floatbv_plus)
151-
sub_result_bv = float_utils.add_sub(lhs_sub_bv, rhs_sub_bv, false);
152-
else if(expr.id()==ID_floatbv_minus)
153-
sub_result_bv = float_utils.add_sub(lhs_sub_bv, rhs_sub_bv, true);
154-
else if(expr.id()==ID_floatbv_mult)
155-
sub_result_bv = float_utils.mul(lhs_sub_bv, rhs_sub_bv);
156-
else if(expr.id()==ID_floatbv_div)
157-
sub_result_bv = float_utils.div(lhs_sub_bv, rhs_sub_bv);
158-
else
159-
UNREACHABLE;
160-
161-
INVARIANT(
162-
sub_result_bv.size() == sub_width,
163-
"we constructed a new complex of the right size");
164-
INVARIANT(
165-
i * sub_width + sub_width - 1 < result_bv.size(),
166-
"the sub-bitvector fits into the result bitvector");
167-
std::copy(
168-
sub_result_bv.begin(),
169-
sub_result_bv.end(),
170-
result_bv.begin() + i * sub_width);
166+
bvt numerator_real = float_utils.add_sub(
167+
float_utils.mul(lhs_real, rhs_real),
168+
float_utils.mul(lhs_imag, rhs_imag),
169+
false);
170+
bvt numerator_imag = float_utils.add_sub(
171+
float_utils.mul(lhs_imag, rhs_real),
172+
float_utils.mul(lhs_real, rhs_imag),
173+
true);
174+
175+
bvt denominator = float_utils.add_sub(
176+
float_utils.mul(rhs_real, rhs_real),
177+
float_utils.mul(rhs_imag, rhs_imag),
178+
false);
179+
180+
result_real = float_utils.div(numerator_real, denominator);
181+
result_imag = float_utils.div(numerator_imag, denominator);
171182
}
183+
else
184+
UNREACHABLE;
185+
186+
bvt result_bv = std::move(result_real);
187+
result_bv.reserve(width);
188+
result_bv.insert(
189+
result_bv.end(),
190+
std::make_move_iterator(result_imag.begin()),
191+
std::make_move_iterator(result_imag.end()));
172192

173193
return result_bv;
174194
}

0 commit comments

Comments
 (0)