diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py index c6bdcd244ca..d3a5c853a4a 100644 --- a/backends/cadence/aot/ref_implementations.py +++ b/backends/cadence/aot/ref_implementations.py @@ -318,22 +318,19 @@ def quantized_add_per_tensor( f"X and Y dtypes need to be in {supported_dtypes}. Got {dtype}" ) - if dtype == torch.uint8: - X = X.to(torch.int8) - Y = Y.to(torch.int8) + qmin = torch.iinfo(dtype).min + qmax = torch.iinfo(dtype).max - # TODO(agrebenisan): This should be done in fixed point arithmetic, but to match the quantized_add_out.cpp - # reference implementation, we'll do it in floating point. - dequant_X = X_scale * (X - X_zero_point) - dequant_Y = Y_scale * (Y - Y_zero_point) + dequant_X = dequantize_per_tensor(X, X_scale, X_zero_point, qmin, qmax, dtype) + dequant_Y = dequantize_per_tensor(Y, Y_scale, Y_zero_point, qmin, qmax, dtype) # q_min/q_max are unused args return quantize_per_tensor( dequant_X + dequant_Y, out_scale, out_zero_point, - torch.iinfo(dtype).min, - torch.iinfo(dtype).max, + qmin, + qmax, dtype, ) @@ -394,9 +391,9 @@ def quantized_add_asym8uxasym8u_asym8u_per_tensor( out_zero_point: int, ) -> torch.Tensor: if X.dtype != torch.uint8: - raise ValueError("X dtype must be torch.int8") + raise ValueError("X dtype must be torch.uint8") if Y.dtype != torch.uint8: - raise ValueError("Y dtype must be torch.int8") + raise ValueError("Y dtype must be torch.uint8") return quantized_add_per_tensor( X, X_scale, X_zero_point, Y, Y_scale, Y_zero_point, out_scale, out_zero_point @@ -447,19 +444,18 @@ def quantized_mul_per_tensor( f"X and Y dtypes need to be in {supported_dtypes}. Got {dtype}" ) - if dtype == torch.uint8: - X = X.to(torch.int8) - Y = Y.to(torch.int8) + qmin = torch.iinfo(dtype).min + qmax = torch.iinfo(dtype).max - dequant_X = X_scale * (X - X_zero_point) - dequant_Y = Y_scale * (Y - Y_zero_point) + dequant_X = dequantize_per_tensor(X, X_scale, X_zero_point, qmin, qmax, dtype) + dequant_Y = dequantize_per_tensor(Y, Y_scale, Y_zero_point, qmin, qmax, dtype) return quantize_per_tensor( dequant_X * dequant_Y, out_scale, out_zero_point, - torch.iinfo(dtype).min, - torch.iinfo(dtype).max, + qmin, + qmax, dtype, ) @@ -503,8 +499,8 @@ def quantized_linear_common( ) out = torch.nn.functional.linear( - (src - in_zero_point).float(), - (weight - weight_zero_point).float(), + src.float() - in_zero_point, + weight.float() - weight_zero_point, bias.float(), ) return quantize_per_tensor( @@ -673,8 +669,8 @@ def quantized_matmul( out_scale = 1.0 / (-out_multiplier * (1 / (1 << 31)) * (2**out_shift)) out = torch.matmul( - (X - X_zero_point).float(), - (Y - Y_zero_point).float(), + X.float() - X_zero_point, + Y.float() - Y_zero_point, ) return quantize_per_tensor( out, @@ -857,10 +853,10 @@ def quantized_conv_per_tensor( - out_shift (int): Unused """ if len(input_tensor.shape) == 3: - float_out = torch.nn.functional.conv1d( - (input_tensor - in_zero_point).float(), - (weight - weight_zero_point).float(), - (bias * bias_scale).float(), + acc = torch.nn.functional.conv1d( + input_tensor.float() - in_zero_point, + weight.float() - weight_zero_point, + bias.float(), stride[-1], padding[-1], dilation[-1], @@ -868,10 +864,10 @@ def quantized_conv_per_tensor( ) elif len(input_tensor.shape) == 4: - float_out = torch.nn.functional.conv2d( - (input_tensor - in_zero_point).float(), - (weight - weight_zero_point).float(), - (bias * bias_scale).float(), + acc = torch.nn.functional.conv2d( + input_tensor.float() - in_zero_point, + weight.float() - weight_zero_point, + bias.float(), stride, padding, dilation, @@ -880,6 +876,11 @@ def quantized_conv_per_tensor( else: raise ValueError("Input tensor must be 3D or 4D") + # conv accumulates in the integer domain (scale = in_scale * weight_scale = + # bias_scale) with the integer bias added pre-scale; dequantize the whole + # accumulation by bias_scale to get the floating-point result. + float_out = acc * bias_scale + return quantize_per_tensor( float_out, output_scale, @@ -1944,8 +1945,8 @@ def quantized_relu_common( out_scale = 1.0 / (-out_multiplier * (1 / (1 << 31)) * (2**out_shift)) dequantized_X = torch.where( - X > X_zero_point, X - X_zero_point, torch.zeros_like(X) - ).to(torch.float32) + X > X_zero_point, X.float() - X_zero_point, torch.zeros_like(X) + ) out = quantize_per_tensor( dequantized_X, out_scale, diff --git a/backends/cadence/aot/tests/test_ref_implementations.py b/backends/cadence/aot/tests/test_ref_implementations.py index f089e36d4d5..005bf9d85bd 100644 --- a/backends/cadence/aot/tests/test_ref_implementations.py +++ b/backends/cadence/aot/tests/test_ref_implementations.py @@ -670,10 +670,11 @@ def test_quantized_layer_norm_per_tensor( 0, # unused out_shift torch.uint8, # dtype torch.tensor( - [[[[238]]]], dtype=torch.uint8 - ), # (130 - 128) + (134 - 128) = 10 - # + bias -> 10 + 1 = 11 - # round(11 / 0.1 + 128) = 238, + [[[[148]]]], dtype=torch.uint8 + ), # conv_acc = sum((input - 128) * (weight - 128)) + # = 2*1 + 4*0 + 6*0 + 8*1 = 10 + # float_out = bias_scale * (conv_acc + bias) = 0.1 * (10 + 10) = 2.0 + # round(2.0 / 0.1 + 128) = 148 memory_format, ) for memory_format in [torch.contiguous_format, torch.channels_last] @@ -918,6 +919,34 @@ def test_quantized_layer_norm_per_tensor( ) for memory_format in [torch.contiguous_format, torch.channels_last] ], + # Zero-point overflow: int8 input minus a negative zero point exceeds + # the int8 range and wraps unless the subtraction is upcast first. + *[ + ( + torch.tensor( + [[[[120, 120], [120, 120]]]], dtype=torch.int8 + ), # input + torch.tensor([[[[1, 0], [0, 1]]]], dtype=torch.int8), # weight + torch.tensor([0], dtype=torch.int32), # bias + (1, 1), # stride + (0, 0), # padding + (1, 1), # dilation + 1, # groups + -20, # in_zero_point (120 - (-20) = 140 wraps to -116 in int8) + 0, # weight_zero_point + 1.0, # bias_scale + 4.0, # output_scale + 0, # output_zero_point + 0, # unused out_multiplier + 0, # unused out_shift + torch.int8, # dtype + torch.tensor( + [[[[70]]]], dtype=torch.int8 + ), # (120 + 20) * 2 / 4 = 70 + memory_format, + ) + for memory_format in [torch.contiguous_format, torch.channels_last] + ], ] ) def test_quantized_conv_per_tensor( @@ -1407,6 +1436,23 @@ def test_quantized_w8a32_linear( ) for dtype in [torch.uint8] ], + # Zero-point overflow: int8 X minus a negative zero point exceeds the + # int8 range and wraps unless the subtraction is upcast first. + *[ + ( + "int8_negative_zp_overflow", + torch.tensor([120], dtype=dtype), # input + -20, # X_zero_point (120 - (-20) = 140 wraps to -116 in int8) + 0, # out_zero_point + 1073741824, # out_multiplier (0.5 * 2^31) + 0, # out_shift + dtype, # dtype + torch.tensor( + [-70], dtype=dtype + ), # shifted = 140; 140 * (-0.5) = -70 + ) + for dtype in [torch.int8] + ], ] ) def test_quantized_relu_per_tensor(