Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Removed DictType from TVM #69

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 0 additions & 54 deletions include/tvm/ir/type.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
// SPDX-FileCopyrightText: © 2019-2023 The Apache Software Foundation © 2024 Tenstorrent AI ULC
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this resetting this file to the original? If so, can remove SPDX headers

//
// SPDX-License-Identifier: Apache-2.0
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
Expand Down Expand Up @@ -371,34 +368,6 @@ class TupleTypeNode : public TypeNode {
TVM_DECLARE_FINAL_OBJECT_INFO(TupleTypeNode, TypeNode);
};

/*!
* \brief The type of dictionary of str, value.
* \sa DictType
*/
class DictTypeNode : public TypeNode {
public:
/*! \brief the key (String) and type of each value */
Array<String> keys;
Array<Type> values;

DictTypeNode() {}

void VisitAttrs(AttrVisitor* v) {
v->Visit("keys", &keys);
v->Visit("values", &values);
v->Visit("span", &span);
}

bool SEqualReduce(const DictTypeNode* other, SEqualReducer equal) const {
return equal(keys, other->keys) && equal(values, other->values);
}

void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(keys); }

static constexpr const char* _type_key = "DictType";
TVM_DECLARE_FINAL_OBJECT_INFO(DictTypeNode, TypeNode);
};

/*!
* \brief Managed reference to TupleTypeNode.
* \sa TupleTypeNode.
Expand All @@ -421,29 +390,6 @@ class TupleType : public Type {
TVM_DEFINE_OBJECT_REF_METHODS(TupleType, Type, TupleTypeNode);
};

/*!
* \brief Managed reference to DictTypeNode.
* \sa DictTypeNode.
*/
class DictType : public Type {
public:
/*!
* \brief Constructor
* \param keys keys of the dict.
* \param values values of the dict
* \param span The span of the type.
*/
TVM_DLL explicit DictType(Array<String> keys, Array<Type> values, Span span = Span());

/*!
* \brief Create an empty Dict type that constains nothing.
* \return A empty Dict type.
*/
TVM_DLL DictType static Empty();

TVM_DEFINE_OBJECT_REF_METHODS(DictType, Type, DictTypeNode);
};

/*!
* \return a type that represents void.
*/
Expand Down
1 change: 0 additions & 1 deletion include/tvm/relay/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ using TypeVarNode = tvm::TypeVarNode;
using GlobalTypeVar = tvm::GlobalTypeVar;
using GlobalTypeVarNode = tvm::GlobalTypeVarNode;
using TupleType = tvm::TupleType;
using DictType = tvm::DictType;
using TupleTypeNode = tvm::TupleTypeNode;
using TypeConstraint = tvm::TypeConstraint;
using TypeConstraintNode = tvm::TypeConstraintNode;
Expand Down
14 changes: 0 additions & 14 deletions python/tvm/ir/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,20 +174,6 @@ class TupleType(Type):
def __init__(self, fields):
self.__init_handle_by_constructor__(_ffi_api.TupleType, fields)

@tvm._ffi.register_object("DictType")
class DictType(Type):
"""The type of dict values.

Parameters
----------
keys : List[str]
The keys in the dict
values : List[Type]
The values in the dict
"""

def __init__(self, keys, values):
self.__init_handle_by_constructor__(_ffi_api.DictType, keys, values)

@tvm._ffi.register_object("TypeConstraint")
class TypeConstraint(Type):
Expand Down
25 changes: 2 additions & 23 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
import numpy as np
import tvm
from tvm.ir import IRModule
from tvm.ir.type import DictType
from tvm.topi.utils import get_const_tuple

from .. import analysis as _analysis
Expand Down Expand Up @@ -2606,18 +2605,6 @@ def logical_xor(self, inputs, input_types):
return _op.logical_xor(lhs, rhs)

def getitem(self, inputs, input_types):
if input_types[0] == 'DictType':
keys = inputs[0].type_annotation.keys
values = inputs[0].type_annotation.values
name = inputs[0].name_hint

index = 0
for key in keys:
if key == inputs[1]:
break
index += 1
return _expr.var(f"{name}_{inputs[1]}", values[index])
else:
return self.prelude.nth(inputs[0], _wrap_const(inputs[1]))

def list_len(self, inputs, input_types):
Expand Down Expand Up @@ -2848,7 +2835,8 @@ def index(self, inputs, input_types):

def meshgrid(self, inputs, input_types):
data = inputs[0]
return _op.meshgrid(data, indexing="ij")
indexing = inputs[1] if len(inputs) > 1 else "ij"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Intended change?

return _op.meshgrid(data,indexing)

def nms(self, inputs, input_types):
boxes = inputs[0]
Expand Down Expand Up @@ -5681,8 +5669,6 @@ def _get_pytorch_value_type(typ, default_dtype="float32"):

elif kind == "ListType":
return "ListType"
elif kind == "DictType":
return "DictType"
elif kind in ["IntType", "FloatType", "BoolType", "StringType", "OptionalType"]:
pt_dtype = str(typ).lower()
dtype = pt_dtype if kind == "OptionalType" else _convert_data_type(pt_dtype)
Expand Down Expand Up @@ -5878,13 +5864,6 @@ def get_relay_ty(ishape, itype, pt_type):
raise RuntimeError(msg)
rlist, _, _ = prelude.mod.get_type("List")
return rlist(elem_tys[0])
elif pt_type.kind() == "DictType":
pt_elemtype = pt_type.getValueType()
keys, values = [], []
for k, v in ishape.items():
keys.append(k)
values.append(get_relay_ty(v, itype, pt_elemtype))
return DictType(keys, values)
elif pt_type.kind() == "OptionalType":
# we do not support None yet, so we fill in the type
return get_relay_ty(ishape, itype, pt_type.getElementType())
Expand Down
19 changes: 0 additions & 19 deletions src/ir/type.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
// SPDX-FileCopyrightText: © 2019-2023 The Apache Software Foundation © 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
Expand Down Expand Up @@ -116,22 +113,6 @@ TVM_REGISTER_GLOBAL("ir.TupleType").set_body_typed([](Array<Type> fields) {
return TupleType(fields);
});

DictType::DictType(Array<String> keys, Array<Type> values, Span span) {
ObjectPtr<DictTypeNode> n = make_object<DictTypeNode>();
n->keys = std::move(keys);
n->values = std::move(values);
n->span = std::move(span);
data_ = std::move(n);
}

DictType DictType::Empty() { return DictType(Array<String>(), Array<Type>()); }

TVM_REGISTER_NODE_TYPE(DictTypeNode);

TVM_REGISTER_GLOBAL("ir.DictType").set_body_typed([](Array<String> keys, Array<Type> values) {
return DictType(keys, values);
});

IncompleteType::IncompleteType(TypeKind kind, Span span) {
auto n = make_object<IncompleteTypeNode>();
n->kind = std::move(kind);
Expand Down