diff --git a/test/test_entitylist.py b/test/test_entitylist.py index b22bd54f..f8f3160f 100644 --- a/test/test_entitylist.py +++ b/test/test_entitylist.py @@ -10,9 +10,9 @@ from unittest import TestCase, main import torch -from torch_extensions.tensorlist.tensorlist import TensorList from torchbiggraph.entitylist import EntityList +from torchbiggraph.tensorlist import TensorList def tensor_list_from_lists(lists: Sequence[Sequence[int]]) -> TensorList: diff --git a/test/test_graph_storages.py b/test/test_graph_storages.py index 3461e3e0..ac8f9b69 100644 --- a/test/test_graph_storages.py +++ b/test/test_graph_storages.py @@ -12,9 +12,9 @@ import h5py import numpy as np import torch -from torch_extensions.tensorlist.tensorlist import TensorList from torchbiggraph.graph_storages import FileEdgeAppender +from torchbiggraph.tensorlist import TensorList class TestFileEdgeAppender(TestCase): diff --git a/test/test_model.py b/test/test_model.py index fdca7823..63121549 100644 --- a/test/test_model.py +++ b/test/test_model.py @@ -14,7 +14,6 @@ from unittest import TestCase, main import torch -from torch_extensions.tensorlist.tensorlist import TensorList from torchbiggraph.entitylist import EntityList from torchbiggraph.model import ( @@ -39,6 +38,7 @@ TranslationOperator, match_shape, ) +from torchbiggraph.tensorlist import TensorList class TensorTestCase(TestCase): diff --git a/torch_extensions/README.md b/torch_extensions/README.md deleted file mode 100644 index 7fd3893e..00000000 --- a/torch_extensions/README.md +++ /dev/null @@ -1 +0,0 @@ -These packages are extensions to PyTorch needed by torchbiggraph. diff --git a/torch_extensions/__init__.py b/torch_extensions/__init__.py deleted file mode 100644 index 3f26b2db..00000000 --- a/torch_extensions/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE.txt file in the root directory of this source tree. diff --git a/torch_extensions/rpc/__init__.py b/torch_extensions/rpc/__init__.py deleted file mode 100644 index 46de451f..00000000 --- a/torch_extensions/rpc/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. diff --git a/torch_extensions/tensorlist/__init__.py b/torch_extensions/tensorlist/__init__.py deleted file mode 100644 index 46de451f..00000000 --- a/torch_extensions/tensorlist/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. diff --git a/torchbiggraph/bucket_scheduling.py b/torchbiggraph/bucket_scheduling.py index 22fd4819..2b547df6 100644 --- a/torchbiggraph/bucket_scheduling.py +++ b/torchbiggraph/bucket_scheduling.py @@ -11,10 +11,9 @@ from abc import ABC, abstractmethod from typing import Dict, List, NamedTuple, Optional, Set, Tuple -from torch_extensions.rpc.rpc import Client, Server - from torchbiggraph.config import BucketOrder from torchbiggraph.distributed import Startable +from torchbiggraph.rpc import Client, Server from torchbiggraph.stats import Stats from torchbiggraph.types import Bucket, EntityName, Partition, Rank, Side diff --git a/torchbiggraph/entitylist.py b/torchbiggraph/entitylist.py index 865af388..ae1941a4 100644 --- a/torchbiggraph/entitylist.py +++ b/torchbiggraph/entitylist.py @@ -9,8 +9,8 @@ from typing import Any, Sequence, Union import torch -from torch_extensions.tensorlist.tensorlist import TensorList +from torchbiggraph.tensorlist import TensorList from torchbiggraph.types import LongTensorType diff --git a/torchbiggraph/graph_storages.py b/torchbiggraph/graph_storages.py index cd73f193..f216a668 100644 --- a/torchbiggraph/graph_storages.py +++ b/torchbiggraph/graph_storages.py @@ -18,11 +18,11 @@ import h5py import numpy as np import torch -from torch_extensions.tensorlist.tensorlist import TensorList from torchbiggraph.edgelist import EdgeList from torchbiggraph.entitylist import EntityList from torchbiggraph.plugin import URLPluginRegistry +from torchbiggraph.tensorlist import TensorList from torchbiggraph.util import CouldNotLoadData diff --git a/torchbiggraph/model.py b/torchbiggraph/model.py index 87e0c980..ec4e6396 100644 --- a/torchbiggraph/model.py +++ b/torchbiggraph/model.py @@ -23,7 +23,6 @@ import torch import torch.nn as nn import torch.nn.functional as F -from torch_extensions.tensorlist.tensorlist import TensorList from torchbiggraph.config import ( ConfigSchema, @@ -34,6 +33,7 @@ from torchbiggraph.entitylist import EntityList from torchbiggraph.graph_storages import RELATION_TYPE_STORAGES from torchbiggraph.plugin import PluginRegistry +from torchbiggraph.tensorlist import TensorList from torchbiggraph.types import FloatTensorType, LongTensorType, Side from torchbiggraph.util import CouldNotLoadData diff --git a/torchbiggraph/parameter_sharing.py b/torchbiggraph/parameter_sharing.py index ea389819..7fca5d2d 100644 --- a/torchbiggraph/parameter_sharing.py +++ b/torchbiggraph/parameter_sharing.py @@ -63,7 +63,7 @@ class ParameterServer(Startable): get tensors by string key. Operations on the parameter server are globally synchronous. - FIXME: torch_extensions.rpc should be fixed to not require torch.serialization, + FIXME: torchbiggraph.rpc should be fixed to not require torch.serialization, then most of this code can be removed. FIXME: torch.distributed.recv should not require you to provide the tensor to write to; the type and size should be sent in the header. diff --git a/torch_extensions/rpc/rpc.py b/torchbiggraph/rpc.py similarity index 100% rename from torch_extensions/rpc/rpc.py rename to torchbiggraph/rpc.py diff --git a/torch_extensions/tensorlist/tensorlist.py b/torchbiggraph/tensorlist.py similarity index 100% rename from torch_extensions/tensorlist/tensorlist.py rename to torchbiggraph/tensorlist.py