diff --git a/coremltools/converters/mil/frontend/torch/ops.py b/coremltools/converters/mil/frontend/torch/ops.py
index 5b00bfe2d..3a52cda73 100644
--- a/coremltools/converters/mil/frontend/torch/ops.py
+++ b/coremltools/converters/mil/frontend/torch/ops.py
@@ -2598,6 +2598,11 @@ def upsample_nearest2d(context, node):
 def tupleunpack(context, node):
     inputs = _get_inputs(context, node, expected=1)
     values = inputs[0]
+
+    # A single output means the input is not a real tuple; wrap it for unpacking.
+    if len(node.outputs) == 1:
+        values = [values]
+
     # Node input could have been turned into constant array in @tupleconstruct
     if not isinstance(values, tuple) and not isinstance(values, list):
         values = values.val
@@ -3097,8 +3101,6 @@ def index(context, node):
-    # For multiple index axes case, we now assume that all the index have equal shape
-    for index in valid_indices:
-        if not is_compatible_symbolic_vector(index.shape, valid_indices[0].shape):
-            raise NotImplementedError("Broadcasable tensor index not supported.")
-
+    # For the multiple index axes case, broadcast the indices to a common shape.
+    valid_indices = _broadcast_tensors(valid_indices)
+
     # First stack the index together
     indices_rank = valid_indices[0].rank
     indices = mb.stack(values=valid_indices, axis=indices_rank)
@@ -3398,6 +3405,19 @@ def _slice(context, node):
     context.add(res)
 
+def _num_splits_and_sizes(split_sizes):
+    if split_sizes.sym_val is not None:
+        return len(split_sizes.sym_val), split_sizes.sym_val
+
+    if any_symbolic(split_sizes.shape):
+        raise ValueError("Unable to determine number of splits")
+
+    # split_sizes is a rank-1 vector, so its static length is the number of splits.
+    num_splits = split_sizes.shape[0]
+    sizes = [get_new_symbol() for _ in range(num_splits)]
+    return num_splits, sizes
+
+
 @register_torch_op(torch_alias=["split_with_sizes"])
 def split(context, node):
     inputs = _get_inputs(context, node, expected=3)
@@ -3425,6 +3444,14 @@ def split(context, node):
     else:
         partial_size = mb.mul(x=tmp, y=remainder)
         split_sizes = mb.concat(values=[whole_sizes, partial_size], axis=0)
+
+    num_splits, sizes = _num_splits_and_sizes(split_sizes=split_sizes)
+    if num_splits == 1:
+        # Splitting into one piece is a no-op; pass the input through.
+        out = mb.identity(x=x, name=node.name)
+        context.add(out, torch_name=node.name)
+        return
+
     res = mb.split(x=x, split_sizes=split_sizes, axis=dim, name=node.name)
     context.add(res, torch_name=node.name)
@@ -3482,6 +3509,13 @@ def to(context, node):
             "Received invalid arguments for PyTorch conversion of op {}".format(node)
         )
 
+    # If dtype is not given, the output keeps the input's dtype, so the op is a
+    # no-op; see https://pytorch.org/docs/stable/generated/torch.Tensor.to.html?highlight=#torch.Tensor.to
+    if dtype is None:
+        out = mb.identity(x=_input, name=node.name)
+        context.add(out, torch_name=node.name)
+        return
+
     torch_dtype = NUM_TO_TORCH_DTYPE[dtype]
     if isinstance(_input, Var) and _input.val is not None:
         _input = _input.val
@@ -3924,8 +3958,21 @@ def ceil(context, node):
 @register_torch_op
 def clamp(context, node):
     inputs = _get_inputs(context, node, expected=3)
-    min_val = inputs[1] if inputs[1] else _np.finfo(_np.float32).min
-    max_val = inputs[2] if inputs[2] else _np.finfo(_np.float32).max
+    # mb.clip expects fp32 bounds, so integer min/max values are cast.
+    if not inputs[1]:
+        min_val = _np.finfo(_np.float32).min
+    else:
+        min_val = inputs[1]
+        if types.builtin_to_string(min_val.dtype).startswith('int'):
+            min_val = mb.cast(x=min_val, dtype='fp32')
+
+    if not inputs[2]:
+        max_val = _np.finfo(_np.float32).max
+    else:
+        max_val = inputs[2]
+        if types.builtin_to_string(max_val.dtype).startswith('int'):
+            max_val = mb.cast(x=max_val, dtype='fp32')
+
     context.add(mb.clip(x=inputs[0], alpha=min_val, beta=max_val, name=node.name))
 
 @register_torch_op
@@ -4074,7 +4120,7 @@ def is_floating_point(context, node):
     is_float = types.is_float(inputs[0].dtype)
     context.add(mb.const(val=is_float, name=node.name))
 
-@register_torch_op()
+@register_torch_op(torch_alias=["__and_", "__and__"])
 def logical_and(context, node):
     inputs = _get_inputs(context, node, expected=2)
     x, y = inputs
@@ -4253,6 +4299,11 @@ def _make_tensor(list_of_tensor, name, rank):
         context.add(mb.identity(x=val, name=node.name))
         return
 
+    if inputs[2] is None:
+        res = mb.const(val=[val.val], name=node.name)
+        context.add(res, torch_name=node.name)
+        return
+
     # Case 2: Create a tensor filled with a single value
     val = val.val  # element val to fill
     msg_prefix = 'torch::tensor {} '.format(node.name)
@@ -4483,7 +4534,6 @@ def _scatter(context, inputs, mode, name):
                          axis=axis,
                          mode=mode,
                          name=name)
     context.add(result)
-
 @register_torch_op
 def scatter(context, node):
     inputs = _get_inputs(context, node)
@@ -4501,8 +4551,110 @@ def scatter(context, node):
     _scatter(context, inputs, mode, node.name)
 
-
 @register_torch_op
 def scatter_add(context, node):
     inputs = _get_inputs(context, node)
     _scatter(context, inputs, 'add', node.name)
+
+@register_torch_op
+def roi_align(context, node):
+    inputs = _get_inputs(context, node)
+
+    x = context[node.inputs[0]]
+    input_shape = x.shape  # (B, C, h_in, w_in)
+    if len(input_shape) != 4:
+        raise ValueError(
+            '"CropResize" op: expected input rank 4, got {}'.format(x.rank)
+        )
+
+    const_box_info = True
+    if context[node.inputs[1]].val is None or context[node.inputs[2]].val is None:
+        const_box_info = False
+
+    spatial_scale = context[node.inputs[2]].val
+
+    # Core ML expects the index information along with the boxes
+    if const_box_info:
+        boxes = context[node.inputs[1]].val
+        # Core ML expects boxes/ROI in
+        # [N, 1, 5, 1, 1] format
+        boxes = boxes.reshape(boxes.shape[0], 1, boxes.shape[1], 1, 1)
+    else:
+        boxes = inputs[1]
+        boxes = mb.reshape(x=boxes, shape=[boxes.shape[0], 1, boxes.shape[1], 1, 1])
+    # Get the height and width of the crop
+    h_out = inputs[3]
+    w_out = inputs[4]
+
+    # Torch input format: [B, C, h_in, w_in]
+    # Core ML input format: [B, C, h_in, w_in]
+
+    # Crop and resize
+    x = mb.crop_resize(
+        x=x,
+        roi=boxes,
+        target_height=h_out.val,
+        target_width=w_out.val,
+        normalized_coordinates=True,
+        spatial_scale=spatial_scale,
+        box_coordinate_mode="CORNERS_HEIGHT_FIRST",
+        sampling_mode='OFFSET_CORNERS',
+    )
+
+    # Core ML output format: [N, 1, C, h_out, w_out]
+    # Torch output format: [N, C, h_out, w_out]
+    x = mb.squeeze(x=x, axes=[1])
+
+    context.add(x, torch_name=node.outputs[0])
+
+@register_torch_op
+def numel(context, node):
+    inputs = _get_inputs(context, node, expected=1)
+    # torch.numel counts elements: take the product of the shape, not of the values.
+    x = mb.shape(x=inputs[0])
+    x = mb.reduce_prod(x=x, axes=[0], name=node.name)
+    context.add(x, torch_name=node.outputs[0])
+
+@register_torch_op
+def nms(context, node):
+    inputs = _get_inputs(context, node)
+    boxes = inputs[0]
+
+    num_boxes = boxes.shape[0]
+    max_boxes = num_boxes  # set max_boxes to the number of input boxes
+
+    scores = inputs[1]
+    iou_threshold = inputs[2]
+    boxes = mb.expand_dims(x=boxes, axes=[0])
+    scores = mb.expand_dims(x=scores, axes=[0, -1])
+
+    # Use TensorFlow's default score_threshold; Core ML has no float('-inf')
+    # support, so fall back to the smallest float32 value instead.
+    score_threshold = -3.4e38
+
+    _, _, x, _ = mb.non_maximum_suppression(
+        boxes=boxes,
+        scores=scores,
+        iou_threshold=iou_threshold,
+        score_threshold=score_threshold,
+        max_boxes=max_boxes
+    )
+
+    if not is_symbolic(num_boxes):
+        x = mb.squeeze(x=x, axes=[0])
+        x = mb.slice_by_index(x=x, begin=[0], end=[max_boxes], name=node.name)
+    else:
+        x = mb.squeeze(x=x, axes=[0], name=node.name)
+    context.add(x, torch_name=node.name)
+
+@register_torch_op
+def narrow(context, node):
+    data, dim, start, length = _get_inputs(context, node, expected=4)
+    # narrow(dim, start, length) is the slice [start : start + length) along `dim`.
+    data_shape = mb.shape(x=data).val
+    begin = [0] * len(data_shape)
+    end = list(data_shape)
+    begin[dim.val] = start.val
+    end[dim.val] = start.val + length.val
+    out = mb.slice_by_index(x=data, begin=begin, end=end)
+    context.add(out, torch_name=node.name)
diff --git a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py
index 95e6690b6..8ebb78982 100644
--- a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py
+++ b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py
@@ -10,6 +10,8 @@ import pytest
 import torch.nn as nn
 
+import torchvision
+
 from .testing_utils import (
     contains_op,
     generate_input_data,
@@ -4564,3 +4566,77 @@ def forward(self, x):
         backend=backend,
         converter_input_type=converter_input_type,
     )
+
+class TestNumel(TorchBaseTest):
+    @pytest.mark.parametrize(
+        "shapes, backend",
+        itertools.product(
+            [
+                [(2, 1)],
+                [(5, 1, 4, 1)],
+                [(1,)],
+            ],
+            backends
+        ),
+    )
+    def test_numel(self, shapes, backend):
+        class Model(nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                v = torch.numel(x)
+                return torch.tensor(v)
+
+        model = Model()
+        self.run_compare_torch(shapes, model, backend=backend)
+
+
+class TestNarrow(TorchBaseTest):
+    @pytest.mark.parametrize(
+        "shapes, dim_start_length, backend",
+        itertools.product(
+            [
+                [(3, 3)],
+            ],
+            [
+                (0, 0, 2),
+            ],
+            backends
+        ),
+    )
+    def test_narrow(self, shapes, dim_start_length, backend):
+        dim, start, length = dim_start_length
+
+        class Model(nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return torch.narrow(x, dim, start, length)
+
+        model = Model()
+        self.run_compare_torch(shapes, model, backend=backend)
+
+
+class TestNonMaximalSuppression(TorchBaseTest):
+    @pytest.mark.parametrize(
+        "shapes, scores_shape, backend",
+        itertools.product(
+            [[(2, 4)]],
+            [(2,)],
+            backends
+        ),
+    )
+    def test_non_maximal_suppression(self, shapes, scores_shape, backend):
+        scores = torch.rand(scores_shape)
+
+        class Model(nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return torchvision.ops.nms(x, scores, iou_threshold=0.7)
+
+        model = Model()
+        self.run_compare_torch(shapes, model, backend=backend)
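
As a sanity reference (not part of the patch), the PyTorch semantics that the new clamp, narrow, and numel lowerings have to reproduce can be checked directly in plain PyTorch:

import torch

x = torch.arange(9, dtype=torch.float32).reshape(3, 3)

# clamp with integer bounds on a float tensor: torch promotes the bounds,
# which is why the converter casts int min/max to fp32 before mb.clip.
print(torch.clamp(x, min=1, max=7))

# narrow(dim, start, length) slices [start : start + length) along dim,
# matching the begin/end arrays built by the new `narrow` lowering.
assert torch.equal(torch.narrow(x, 0, 0, 2), x[0:2, :])

# numel is the product of the shape, not of the values, which is what the
# shape-based lowering computes.
assert torch.numel(x) == 9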
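
And a minimal end-to-end sketch of what the new torchvision `nms` registration enables. The shapes and iou_threshold are illustrative assumptions, not taken from the patch; ct.convert and ct.TensorType are coremltools' standard conversion entry points:

import torch
import torchvision
import coremltools as ct

class NMSModel(torch.nn.Module):
    def forward(self, boxes, scores):
        # boxes: (N, 4) as (x1, y1, x2, y2); scores: (N,)
        return torchvision.ops.nms(boxes, scores, iou_threshold=0.7)

boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0], [1.0, 1.0, 11.0, 11.0]])
scores = torch.tensor([0.9, 0.8])
traced = torch.jit.trace(NMSModel(), (boxes, scores))

# The nms op registered above maps this onto mb.non_maximum_suppression.
mlmodel = ct.convert(
    traced,
    inputs=[
        ct.TensorType(name="boxes", shape=(2, 4)),
        ct.TensorType(name="scores", shape=(2,)),
    ],
)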