diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 2a4e722f68b..466f9d69bde 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -294,28 +294,16 @@ def register_comparison_ops(): # ============================================================================= -@update_features(exir_ops.edge.aten.bitwise_and.Tensor) -def register_bitwise_and(): - return OpFeatures( - inputs_storage=utils.ANY_STORAGE, - inputs_dtypes=utils.BOOL_T, - supports_resize=True, - supports_highdim=True, - ) - - -@update_features(exir_ops.edge.aten.bitwise_not.default) -def register_bitwise_not(): - return OpFeatures( - inputs_storage=utils.ANY_STORAGE, - inputs_dtypes=utils.BOOL_T, - supports_resize=True, - supports_highdim=True, - ) - - -@update_features(exir_ops.edge.aten.logical_and.default) -def register_logical_and(): +@update_features( + [ + exir_ops.edge.aten.bitwise_and.Tensor, + exir_ops.edge.aten.bitwise_or.Tensor, + exir_ops.edge.aten.bitwise_not.default, + exir_ops.edge.aten.logical_and.default, + exir_ops.edge.aten.logical_or.default, + ] +) +def register_bool_binary_ops(): return OpFeatures( inputs_storage=utils.ANY_STORAGE, inputs_dtypes=utils.BOOL_T, diff --git a/backends/vulkan/runtime/graph/ops/glsl/binary_op_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/binary_op_buffer.yaml index 8aef89cd739..1f217acb127 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/binary_op_buffer.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/binary_op_buffer.yaml @@ -52,3 +52,8 @@ binary_op_buffer: generate_variant_forall: DTYPE: - VALUE: uint8 + - NAME: binary_bitwise_or_buffer + OPERATOR: X | Y + generate_variant_forall: + DTYPE: + - VALUE: uint8 diff --git a/backends/vulkan/runtime/graph/ops/glsl/binary_op_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/binary_op_texture.yaml index 437803b2410..289466e7845 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/binary_op_texture.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/binary_op_texture.yaml @@ -54,3 +54,8 @@ binary_op_texture: generate_variant_forall: DTYPE: - VALUE: uint8 + - NAME: binary_bitwise_or_texture3d + OPERATOR: X | Y + generate_variant_forall: + DTYPE: + - VALUE: uint8 diff --git a/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp b/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp index 6ff58e72dc3..9e696a008fe 100644 --- a/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp @@ -143,6 +143,7 @@ DEFINE_BINARY_OP_FN(le); DEFINE_BINARY_OP_FN(gt); DEFINE_BINARY_OP_FN(ge); DEFINE_BINARY_OP_FN(bitwise_and); +DEFINE_BINARY_OP_FN(bitwise_or); REGISTER_OPERATORS { VK_REGISTER_OP(aten.add.Tensor, add); @@ -159,6 +160,8 @@ REGISTER_OPERATORS { VK_REGISTER_OP(aten.ge.Tensor, ge); VK_REGISTER_OP(aten.bitwise_and.Tensor, bitwise_and); VK_REGISTER_OP(aten.logical_and.default, bitwise_and); + VK_REGISTER_OP(aten.bitwise_or.Tensor, bitwise_or); + VK_REGISTER_OP(aten.logical_or.default, bitwise_or); } } // namespace vkcompute diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index 6efae3d0398..7a3b6943653 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -2141,14 +2141,18 @@ def get_where_inputs(): return test_suite -@register_test_suite("aten.bitwise_and.Tensor") -def get_bitwise_and_inputs(): +@register_test_suite( + ["aten.bitwise_and.Tensor", "aten.bitwise_or.Tensor", "aten.logical_or.default"] +) +def get_bitwise_binary_inputs(): test_suite = VkTestSuite( [ ((M1, M2), (M1, M2)), ((S, S1, S2), (S, S1, S2)), ((XS, S, S1, S2), (XS, S, S1, S2)), ((1, M1), (1, M1)), + ((1, M2), (M1, M2)), + ((XS, 1, S1, 1), (1, S, 1, S2)), ] ) test_suite.layouts = [ @@ -2160,7 +2164,6 @@ def get_bitwise_and_inputs(): "utils::kTexture3D", ] test_suite.dtypes = ["at::kBool"] - test_suite.data_gen = "make_seq_tensor" return test_suite