Skip to content

Commit

Permalink
Perfomed some clean-up and undid all changes which aren't ready yet a…
Browse files Browse the repository at this point in the history
…nd depend on issue xdslproject#3654 to be fixed
  • Loading branch information
watermelonwolverine committed Jan 2, 2025
1 parent 0ff1569 commit b331962
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 78 deletions.
79 changes: 3 additions & 76 deletions xdsl/dialects/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,10 +402,9 @@ def verify_permutation_map(
permutation_map: AffineMap,
):
"""
TODO test
This mirrors VectorOps.cpp -> verifyPermutationMap
"""

# This mirrors VectorOps.cpp -> verifyPermutationMap
seen: list[bool] = [False for _ in range(permutation_map.num_dims)]

for expr in permutation_map.results:
Expand All @@ -430,9 +429,6 @@ def verify_transfer_op(
op: TransferReadOp | TransferWriteOp,
shaped_type: MemRefType[Attribute] | TensorType[Attribute],
vector_type: VectorType[Attribute],
mask_type: VectorType[I1] | None,
# WJOG9GVF: TODO fix: remove None type from inferred_mask_type once 7S4F0FZA has been fixed
inferred_mask_type: VectorType[I1] | None,
permutation_map: AffineMap,
in_bounds: ArrayAttr[BoolAttr],
):
Expand Down Expand Up @@ -481,13 +477,6 @@ def verify_transfer_op(
f'"{op.name}" requires a permutation_map with input dims of the same rank as the source type'
)

# WJOG9GVF: TODO fix: uncomment this when 7S4F0FZA has been fixed
# if mask_type:
# if mask_type != inferred_mask_type:
# raise VerifyException(
# f'"{op.name}" inferred mask type ({inferred_mask_type}) and mask operand type ({mask_type}) don\'t match'
# )

if len(in_bounds) != len(permutation_map.results):
raise VerifyException(
f'"{op.name}" expects the optional in_bounds attr of same rank as permutation_map results: {str(permutation_map)} vs in_bounds of of size {len(in_bounds)}'
Expand All @@ -503,44 +492,8 @@ def verify_transfer_op(
)


def infer_transfer_op_mask_type(
vector_type: VectorType[Attribute],
affine_map: AffineMap,
) -> VectorType[I1] | None:
"""
TODO test
"""

# 7S4F0FZA
# TODO uncomment and test this once VectorType has been fixed, see issue #3654
# When you do this also fix all WJOG9GVF

# inverse_permutation_map = affine_map.compress_dims(
# affine_map.unused_dims_bit_vector()
# ).inverse_permutation()

# assert inverse_permutation_map

# mask_shape = inverse_permutation_map.compose_with_values(vector_type.get_shape())

# scalable_dims = inverse_permutation_map.eval(
# [1 if dim_scalable else 0 for dim_scalable in vector_type.get_scalable_dims()],
# [],
# )

# return VectorType(
# i1,
# mask_shape,
# [dim_scalable == 1 for dim_scalable in scalable_dims],
# )

return None


class VectorTransferOp(ABC):
"""
TODO document
TODO test
Mirrors VectorTransferOpInterface from VectorInterfaces.h.inc
"""

Expand Down Expand Up @@ -581,31 +534,17 @@ class TransferReadOp(IRDLOperation, VectorTransferOp):

irdl_options = [AttrSizedOperandSegments(as_property=True), ParsePropInAttrDict()]

# assembly_format = "$source `[` $indices `]` `,` $padding ( `,` $mask^ )? attr-dict `:` type($source) `,` type($result)"

def verify_(self):
assert isa(self.source.type, MemRefType[Attribute] | TensorType[Attribute])
assert isa(self.result.type, VectorType[Attribute])
if self.mask:
assert isa(self.mask.type, VectorType[I1])
mask_type = self.mask.type
else:
mask_type = None

if len(self.indices) != self.source.type.get_num_dims():
raise VerifyException("Expected an index for each memref/tensor dimension.")

inferred_mask_type = infer_transfer_op_mask_type(
self.result.type,
self.permutation_map.data,
)

verify_transfer_op(
self,
self.source.type,
self.result.type,
mask_type,
inferred_mask_type,
self.permutation_map.data,
self.in_bounds,
)
Expand Down Expand Up @@ -642,7 +581,7 @@ def __init__(
properties={"permutation_map": permutation_map, "in_bounds": in_bounds},
)

# override
# override VectorTransferOp.get_permutation_map
def get_permutation_map(self):
return self.permutation_map.data

Expand All @@ -666,11 +605,6 @@ class TransferWriteOp(IRDLOperation, VectorTransferOp):
def verify_(self):
assert isa(self.source.type, MemRefType[Attribute] | TensorType[Attribute])
assert isa(self.vector.type, VectorType[Attribute])
if self.mask:
assert isa(self.mask.type, VectorType[I1])
mask_type = self.mask.type
else:
mask_type = None

if len(self.indices) != self.source.type.get_num_dims():
raise VerifyException("Expected an index for each memref/tensor dimension.")
Expand All @@ -680,17 +614,10 @@ def verify_(self):
f'"{self.name}" should not have broadcast dimensions.'
)

inferred_mask_type = infer_transfer_op_mask_type(
self.vector.type,
self.permutation_map.data,
)

verify_transfer_op(
self,
self.source.type,
self.vector.type,
mask_type,
inferred_mask_type,
self.permutation_map.data,
self.in_bounds,
)
Expand All @@ -716,7 +643,7 @@ def __init__(
result_types=[result_type],
)

# override
# override VectorTransferOp.get_permutation_map
def get_permutation_map(self):
return self.permutation_map.data

Expand Down
2 changes: 0 additions & 2 deletions xdsl/ir/affine/affine_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,6 @@ def compose(self, other: AffineMap) -> AffineMap:

def compose_with_values(self, values: Sequence[int]) -> tuple[int, ...]:
"""
TODO document
TODO test
Same as SmallVector<int64_t, 4> AffineMap::compose(ArrayRef<int64_t> values) const from AffineMap.cpp
"""
assert self.num_symbols == 0
Expand Down

0 comments on commit b331962

Please sign in to comment.