-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmath_unit.py
102 lines (81 loc) · 3.22 KB
/
math_unit.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
#!env python
from __future__ import print_function
#
# This class simulates Tofino Math unit. Please, see BEACON-201 or similar
# course for more details
#
# Usage:
#
# # Calculating the square of the number
# math_unit = MathUnit(shift=1, invert=False, scale=-6,
# lookup=[x*x for x in range(15, -1, -1)])
#
# result = math_unit.compute(100)
# # Hopefully this is close to 10000 :) Should be 9216, so 7.84% error
#
class MathUnit():
#
# Parameters (should be taken from your P4 program):
# shift -- the same value as the attribute math_unit_exponent_shift
# invert -- the same value as the attribute math_unit_exponent_invert
# scale -- the same value as the attribute math_unit_output_scale
# lookup -- the same values as the attribute math_unit_lookup_table
# size -- register width in bits (8, 16 or 32)
def __init__(self, shift=0, invert=False, scale=-3,
lookup=range(15, -1, -1), size=32):
if shift not in [-1, 0, 1]:
raise ValueError("Shift can only be 0, 1 or -1")
self.shift = shift
if invert not in [True, False]:
raise ValueError("Invert must be True of False")
self.invert = invert
self.scale = scale
if len(lookup) != 16:
raise ValueError("lookup must have exactly 16 elements")
for x in lookup:
if not (0 <= x <= 255):
raise ValueError("All lookup values must be between 0 and 255")
self.lookup = lookup
if size not in [8, 16, 32]:
raiseValueError("Size can only be 8, 16 or 32 (bits)")
self.size = size
self.mask = (1 << size) -1
def compute(self, arg, verbose=False):
if arg == 0:
return 0
# Get the exponent and mantissa
arg1 = arg << 3 # Add 3 extra zeroes on the left
exp1 = arg1.bit_length() - 1;
# If we are going to calculate sqrt, we need an even exponent
if self.shift == -1 and exp1 % 2 != 0:
exp1 += 1
exp = exp1 - 3 # Compensate for the initial shift
mantissa = (arg1 >> exp) & 0xF # First 4 bits
if verbose:
print('{} ~= {} * 2^{}'.format(arg, mantissa, exp-3))
# Calculate the new exponent
# First let's do linear, square or square root
if self.shift == -1:
new_exp = exp >> 1
elif self.shift == 0:
new_exp = exp
elif self.shift == 1:
new_exp = exp << 1
else:
raise ValueError("Shift can only be 0, 1 or -1")
# For the reciprocals the exponent is inverted
if self.invert:
new_exp = -new_exp
# Finally, add the scale
new_exp += self.scale
new_mantissa = self.lookup[15 - mantissa]
if verbose:
print('{} = {}[{}]'.format(
new_mantissa, self.lookup, 15 - mantissa))
if new_exp < 0:
result = new_mantissa >> -new_exp
else:
result = new_mantissa << new_exp
if verbose:
print('{} = {} * 2^{}'.format(result, new_mantissa, new_exp))
return result & self.mask