diff --git a/graphium/ipu/__init__.py b/graphium/ipu/__init__.py
index e69de29bb..263b83df1 100644
--- a/graphium/ipu/__init__.py
+++ b/graphium/ipu/__init__.py
@@ -0,0 +1,2 @@
+# Very confusing, but does match the directory structure
+from .isnan.isnan import _ipu_isnan as ipu_isnan
diff --git a/graphium/ipu/ipu_losses.py b/graphium/ipu/ipu_losses.py
index 5e7a93c85..a7c0eb7a1 100644
--- a/graphium/ipu/ipu_losses.py
+++ b/graphium/ipu/ipu_losses.py
@@ -5,6 +5,7 @@
 from loguru import logger
 
 from graphium.trainer.losses import HybridCELoss
+from graphium.ipu import ipu_isnan
 
 
 class BCEWithLogitsLossIPU(BCEWithLogitsLoss):
     """
@@ -29,7 +30,8 @@ def forward(self, input: Tensor, target: Tensor) -> Tensor:
 
         # Replace the nan-targets by 0 or 1. Take the value closest to the input.
         # Give a weight of 0 where there are nan-targets
-        nan_targets = target.isnan()
+        #nan_targets = target.isnan()
+        nan_targets = ipu_isnan(target)
         nan_targets_0 = (input < 0.5) & nan_targets
         nan_targets_1 = (input >= 0.5) & nan_targets
         target[nan_targets_0] = 0.0
@@ -74,7 +76,8 @@ def forward(self, input: Tensor, target: Tensor) -> Tensor:
 
         # Replace the nan-targets by 0 or 1. Take the value closest to the input.
         # Give a weight of 0 where there are nan-targets
-        nan_targets = target.isnan()
+        #nan_targets = target.isnan()
+        nan_targets = ipu_isnan(target)
         nan_targets_0 = (input < 0.5) & nan_targets
         nan_targets_1 = (input >= 0.5) & nan_targets
         target[nan_targets_0] = 0.0
@@ -109,7 +112,8 @@ def forward(self, input: Tensor, target: Tensor) -> Tensor:
         input = input.clone()
 
         # Replace the nan-targets in the input/target tensors by 0
-        nan_targets = target.isnan()
+        #nan_targets = target.isnan()
+        nan_targets = ipu_isnan(target)
         input[nan_targets] = 0.0
         target[nan_targets] = 0.0
 
@@ -137,7 +141,8 @@ def forward(self, input: Tensor, target: Tensor) -> Tensor:
         input = input.clone()
 
         # Replace the nan-targets in the input/target tensors by 0
-        nan_targets = target.isnan()
+        #nan_targets = target.isnan()
+        nan_targets = ipu_isnan(target)
         input[nan_targets] = 0.0
         target[nan_targets] = 0.0
 
@@ -175,7 +180,8 @@ def forward(self, input: Tensor, target: Tensor) -> Tensor:
         input = input.clone()
 
         # Replace the nan-targets in the input/target tensors by 0
-        nan_targets = target.isnan()
+        #nan_targets = target.isnan()
+        nan_targets = ipu_isnan(target)
 
         # Compute the loss, and rescale by the number of nan elements
         loss = super().forward(input, target, nan_targets)
diff --git a/graphium/ipu/ipu_metrics.py b/graphium/ipu/ipu_metrics.py
index ba42108c5..7fae1156c 100644
--- a/graphium/ipu/ipu_metrics.py
+++ b/graphium/ipu/ipu_metrics.py
@@ -18,7 +18,7 @@
 from graphium.utils.tensor import nan_mean
 from graphium.ipu.ipu_utils import import_poptorch
-
+from graphium.ipu import ipu_isnan
 
 
 def auroc_ipu(
     preds: Tensor,
@@ -41,7 +41,7 @@ def auroc_ipu(
     preds = preds.clone()
 
     # Replace the nan-targets in the preds/target tensors by 0
-    nan_targets = target.isnan()
+    nan_targets = ipu_isnan(target)
     preds[nan_targets] = 0.0
     target[nan_targets] = 0.0
 
@@ -85,9 +85,7 @@ def average_precision_ipu(
     target = target.clone()
     preds = preds.clone()
 
-    # Replace the nan-targets in the preds/target tensors by 0
-    # Average precision is not sensitive to true negatives
-    nan_targets = target.isnan()
+    nan_targets = ipu_isnan(target)
     preds[nan_targets] = 0.0
     target[nan_targets] = 0.0
 
@@ -401,7 +399,7 @@ def get_confusion_matrix(
 
     #### ADDED ####
     # Put all the NaNs as the 0-class
-    nans = torch.isnan(target)
+    nans = ipu_isnan(target)
     target[nans] = 0
     preds[nans] = 0
     if (preds.ndim > 1) and (preds.shape[1] > 1):
@@ -452,7 +450,7 @@ def get_nans(self) -> BoolTensor:
         In the case of a boolean tensor, this returns a Tensor filled with `False`
         """
         if self.is_floating_point():
-            return self.isnan()
+            return ipu_isnan(self)
         elif self.is_signed():
             return self == torch.iinfo(self.dtype).min
         else:
@@ -578,7 +576,7 @@ def spearman_ipu(preds, target):
         preds: estimated scores
         target: ground truth scores
     """
-    nans = target.isnan()
+    nans = ipu_isnan(target)
     dtype = preds.dtype
     preds[nans] = float("inf")
     target[nans] = float("inf")
@@ -849,7 +847,7 @@ def mean_squared_error_ipu(preds: Tensor, target: Tensor, squared: bool) -> Tensor:
     preds = preds.clone()
 
     # Replace the nan-targets in the preds/target tensors by 0
-    nan_targets = target.isnan()
+    nan_targets = ipu_isnan(target)
     preds[nan_targets] = 0.0
     target[nan_targets] = 0.0
 
@@ -882,7 +880,7 @@ def mean_absolute_error_ipu(preds: Tensor, target: Tensor) -> Tensor:
     preds = preds.clone()
 
     # Replace the nan-targets in the preds/target tensors by 0
-    nan_targets = target.isnan()
+    nan_targets = ipu_isnan(target)
     preds[nan_targets] = 0.0
     target[nan_targets] = 0.0
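Every hunk above follows the same recipe: rewrite the NaN targets to a harmless value, run the regular loss or metric, then rescale by the number of valid elements. Rewriting instead of boolean-filtering keeps tensor shapes static, which the IPU's ahead-of-time compilation requires. A minimal CPU sketch of the recipe (illustrative only: `torch.isnan` stands in for `ipu_isnan`, and `mse_ignoring_nans` is a hypothetical helper, not part of this PR):

```python
import torch
from torch.nn import MSELoss

def mse_ignoring_nans(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    target = target.clone()
    preds = preds.clone()
    nan_targets = torch.isnan(target)  # `ipu_isnan(target)` inside a poptorch graph
    preds[nan_targets] = 0.0           # zeroed pairs contribute nothing to the sum
    target[nan_targets] = 0.0
    loss = MSELoss(reduction="sum")(preds, target)
    return loss / (~nan_targets).sum()  # rescale by the count of valid targets

preds = torch.tensor([0.5, 1.0, 2.0])
target = torch.tensor([1.0, float("nan"), 2.0])
assert torch.isclose(mse_ignoring_nans(preds, target), torch.tensor(0.125))
```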
diff --git a/graphium/ipu/isnan/Makefile b/graphium/ipu/isnan/Makefile
new file mode 100644
index 000000000..a84d9cce5
--- /dev/null
+++ b/graphium/ipu/isnan/Makefile
@@ -0,0 +1,21 @@
+CXX ?= g++
+CXXFLAGS = -std=c++14 -fPIC -g
+LDLIBS = -shared -lpopart
+ONNX_NAMESPACE = -DONNX_NAMESPACE=onnx
+
+BUILD_DIR = build
+SOURCES = isnan_popart.cpp
+TARGET = $(BUILD_DIR)/custom_ops.so
+
+all: create_build_dir isnan_custom_op
+
+.PHONY: create_build_dir
+create_build_dir:
+	mkdir -p $(BUILD_DIR)
+
+isnan_custom_op: isnan_popart.cpp
+	$(CXX) $(SOURCES) $(LDLIBS) $(CXXFLAGS) $(ONNX_NAMESPACE) -o $(TARGET)
+
+.PHONY: clean
+clean:
+	rm -rf $(BUILD_DIR)
diff --git a/graphium/ipu/isnan/__init__.py b/graphium/ipu/isnan/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/graphium/ipu/isnan/isnan.py b/graphium/ipu/isnan/isnan.py
new file mode 100644
index 000000000..1f177152e
--- /dev/null
+++ b/graphium/ipu/isnan/isnan.py
@@ -0,0 +1,21 @@
+import torch
+import poptorch
+import pathlib
+import ctypes
+
+# Locate and load the compiled PopART custom op (built by the Makefile in this directory)
+myso = list(pathlib.Path(__file__).parent.rglob("build/cus*.so"))
+assert myso, "Failed to find custom op .so file - please cd into `graphium/ipu/isnan` and run `make`"
+assert len(myso) == 1, f"Too many ({len(myso)}) custom op .so files, there should only be one"
+ctypes.cdll.LoadLibrary(str(myso[0]))
+
+
+def _ipu_isnan(x):
+    """IPU-compatible replacement for `torch.isnan`, dispatched to the custom PopART op."""
+    return poptorch.custom_op(
+        inputs=(x,),
+        name="IsNanCustom",
+        domain="custom.ops",
+        domain_version=1,
+        example_outputs=(x.bool(),),
+    )
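A hedged smoke test for the wrapper (assumptions, not verified by this diff: `make` has been run in `graphium/ipu/isnan`; the codelet has additionally been compiled to `isnan.gp` with `popc`, a step the Makefile above does not perform; and an IPU device is attached, since the codelet relies on an IPU-only builtin):

```python
import torch
import poptorch
from graphium.ipu import ipu_isnan

class NanMask(torch.nn.Module):
    def forward(self, x):
        return ipu_isnan(x)  # only legal inside a poptorch-compiled graph

# Compile for IPU and run; plain eager CPU execution would fail, because the
# op is registered with PopART rather than with eager PyTorch.
model = poptorch.inferenceModel(NanMask())
x = torch.tensor([1.0, float("nan"), 3.0])
print(model(x))  # expected: tensor([False, True, False])
```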
diff --git a/graphium/ipu/isnan/isnan_codelet.cpp b/graphium/ipu/isnan/isnan_codelet.cpp
new file mode 100644
index 000000000..9db0c5fb7
--- /dev/null
+++ b/graphium/ipu/isnan/isnan_codelet.cpp
@@ -0,0 +1,27 @@
+// Copyright (c) 2018 Graphcore Ltd. All rights reserved.
+#include <poplar/Vertex.hpp>
+#include <ipu_builtins.h>
+
+class IsNaNUsingF32ClassVertex : public poplar::Vertex {
+public:
+  // Fields
+  poplar::Input<poplar::Vector<float>> in;
+  poplar::Output<poplar::Vector<bool>> out;
+
+  // Compute function
+  bool compute() {
+
+    for (int i = 0; i < in.size(); ++i) {
+
+      auto inClass = __builtin_ipu_f32class(in[i]);
+
+      // TFPU_CLASS_UNC = 0
+      // TFPU_CLASS_SNAN = 1
+      // TFPU_CLASS_QNAN = 2
+      // All others are >= 3 and not NaN
+      out[i] = inClass < 3;
+
+    }
+    return true;
+  }
+};
diff --git a/graphium/ipu/isnan/isnan_popart.cpp b/graphium/ipu/isnan/isnan_popart.cpp
new file mode 100644
index 000000000..1ac2118c7
--- /dev/null
+++ b/graphium/ipu/isnan/isnan_popart.cpp
@@ -0,0 +1,103 @@
+// Copyright (c) 2020 Graphcore Ltd. All rights reserved.
+
+#include <popart/op.hpp>
+#include <popart/opmanager.hpp>
+#include <popart/opserialiser.hpp>
+
+#include <popart/popx/opx.hpp>
+#include <popart/popx/opxmanager.hpp>
+
+#include <popops/ElementWise.hpp>
+
+namespace CustomOperators {
+const popart::OperatorIdentifier IsNanId = {"custom.ops", "IsNanCustom", 1};
+} // namespace CustomOperators
+
+class IsNanOp;
+class IsNanOpx;
+class IsNanGradOpx;
+
+class IsNanOp : public popart::Op {
+public:
+  IsNanOp(const popart::OperatorIdentifier &_opid,
+          const popart::Op::Settings &settings_)
+      : popart::Op(_opid, settings_) {}
+
+  std::unique_ptr<popart::Op> clone() const final {
+    return std::make_unique<IsNanOp>(*this);
+  }
+
+  void setup() final {
+    outInfo(0) = popart::TensorInfo(popart::DataType::BOOL,
+                                    inInfo(0).shape());
+  }
+
+  void appendAttributes(popart::OpSerialiserBase &os) const override {
+    Op::appendAttributes(os);
+  }
+
+  void appendOutlineAttributes(popart::OpSerialiserBase &os) const override {
+    Op::appendOutlineAttributes(os);
+  }
+
+  float getSubgraphValue() const final { return getHighSubgraphValue(); }
+
+  bool requiresRandomSeed() const override { return false; }
+};
+
+namespace {
+using popart::DataType;
+using popart::OpDefinition;
+
+static OpDefinition isNanOpDef({OpDefinition::Inputs({{"input", {popart::DataType::FLOAT}}}),
+                                OpDefinition::Outputs({{"output", {popart::DataType::BOOL}}}),
+                                OpDefinition::Attributes({})});
+
+static popart::OpCreator<IsNanOp> isNanOpCreator(
+    popart::OpDefinitions({{CustomOperators::IsNanId, isNanOpDef}}),
+    [](const popart::OpCreatorInfo &info) {
+      return std::make_unique<IsNanOp>(info.opid, info.settings);
+    },
+    true);
+} // namespace
+
+namespace pe = popops::expr;
+
+class IsNanOpx : public popart::popx::Opx {
+public:
+  IsNanOpx(popart::Op *op, popart::popx::Devicex *devicex)
+      : popart::popx::Opx(op, devicex) {
+    verifyOp<IsNanOp>(op, {CustomOperators::IsNanId});
+  }
+
+  void grow(poplar::program::Sequence &prog) const final {
+
+    auto op = getOp<IsNanOp>();
+
+    poplar::Tensor input = getInTensor(0);
+    auto output = graph().addVariable(poplar::BOOL, input.shape(), debugContext("IsNanCustom"));
+    auto tileMapping = graph().getTileMapping(input);
+    graph().setTileMapping(output, tileMapping);
+
+    graph().addCodelets("isnan.gp");
+    auto computeSet = graph().addComputeSet("isNanComputeSet");
+
+    // Add one vertex per contiguous interval of the input's tile mapping
+    for (unsigned i = 0; i < tileMapping.size(); ++i) {
+      auto intervals = tileMapping.at(i);
+      for (auto interval : intervals) {
+        auto vertex = graph().addVertex(computeSet, "IsNaNUsingF32ClassVertex");
+        graph().setTileMapping(vertex, i);
+        graph().connect(vertex["in"], input.flatten().slice(interval.begin(), interval.end()));
+        graph().connect(vertex["out"], output.flatten().slice(interval.begin(), interval.end()));
+      }
+    }
+
+    prog.add(poplar::program::Execute(computeSet));
+
+    setOutTensor(0, output);
+  }
+};
+
+static popart::popx::OpxCreator<IsNanOpx>
+    IsNanOpxCreator({CustomOperators::IsNanId});
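For reference, the vertex's class test reduces to the standard IEEE-754 rule: a float32 is NaN exactly when its exponent bits are all ones and its mantissa is non-zero, which is what the SNAN/QNAN classes encode. A plain-Python sanity check of that bit-level rule (illustrative only, not part of the PR):

```python
import struct

def isnan_f32(value: float) -> bool:
    # Reinterpret the float32 bit pattern as an unsigned integer
    (bits,) = struct.unpack("<I", struct.pack("<f", value))
    exponent = (bits >> 23) & 0xFF
    mantissa = bits & 0x7FFFFF
    # NaN: exponent all ones and mantissa non-zero (inf has a zero mantissa)
    return exponent == 0xFF and mantissa != 0

assert isnan_f32(float("nan"))
assert not isnan_f32(float("inf"))
assert not isnan_f32(1.0)
```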
diff --git a/graphium/nn/encoders/laplace_pos_encoder.py b/graphium/nn/encoders/laplace_pos_encoder.py
index 7cc69919b..6b594e8cc 100644
--- a/graphium/nn/encoders/laplace_pos_encoder.py
+++ b/graphium/nn/encoders/laplace_pos_encoder.py
@@ -6,6 +6,7 @@
 
 from graphium.nn.base_layers import MLP, get_norm, FCLayer, TransformerEncoderLayerMup
 from graphium.nn.encoders.base_encoder import BaseEncoder
+from graphium.ipu import ipu_isnan
 
 
 class LapPENodeEncoder(BaseEncoder):
     def __init__(
@@ -191,7 +192,8 @@ def forward(
         pos_enc = torch.cat(
             (eigvecs.unsqueeze(2), eigvals.unsqueeze(2)), dim=2
         )  # (Num nodes) x (Num Eigenvectors) x 2
-        empty_mask = torch.isnan(pos_enc)  # (Num nodes) x (Num Eigenvectors) x 2
+        #empty_mask = torch.isnan(pos_enc)  # (Num nodes) x (Num Eigenvectors) x 2
+        empty_mask = ipu_isnan(pos_enc)  # (Num nodes) x (Num Eigenvectors) x 2
         pos_enc[empty_mask] = 0  # (Num nodes) x (Num Eigenvectors) x 2
 
         if self.first_normalization:
diff --git a/graphium/nn/encoders/signnet_pos_encoder.py b/graphium/nn/encoders/signnet_pos_encoder.py
index e95278cd5..7b675ef9a 100644
--- a/graphium/nn/encoders/signnet_pos_encoder.py
+++ b/graphium/nn/encoders/signnet_pos_encoder.py
@@ -12,7 +12,7 @@
 from graphium.nn.base_layers import MLP
 from graphium.nn.encoders.base_encoder import BaseEncoder
-
+from graphium.ipu import ipu_isnan
 
 
 class SimpleGIN(nn.Module):
     def __init__(
@@ -306,7 +306,8 @@ def forward(self, batch: Batch, key_prefix: Optional[str] = None) -> Dict[str, torch.Tensor]:
         eigvecs, edge_index, batch_index = batch[input_keys[0]], batch["edge_index"], batch["batch_index"]
         pos_enc = eigvecs.unsqueeze(-1)  # (Num nodes) x (Num Eigenvectors) x 1
-        empty_mask = torch.isnan(pos_enc)
+        #empty_mask = torch.isnan(pos_enc)
+        empty_mask = ipu_isnan(pos_enc)
         pos_enc[empty_mask] = 0  # (Num nodes) x (Num Eigenvectors) x 1
 
         # SignNet
diff --git a/graphium/nn/pyg_layers/utils.py b/graphium/nn/pyg_layers/utils.py
index 100ccec51..87ca2fe33 100644
--- a/graphium/nn/pyg_layers/utils.py
+++ b/graphium/nn/pyg_layers/utils.py
@@ -9,7 +9,7 @@
 from graphium.nn.base_layers import MLP, get_norm
 from graphium.ipu.to_dense_batch import to_dense_batch, to_sparse_batch
-
+from graphium.ipu import ipu_isnan
 
 
 class PreprocessPositions(nn.Module):
     """
@@ -96,7 +96,8 @@ def forward(
         # for real nodes, if 3d position does not exist, it is nans. For padding nodes, 3d positions will be 0
         # if the first node of a molecule has 3d position as nan, the whole molecule will be masked out.
         # [batch]
-        nan_mask = torch.isnan(pos)[:, 0, 0]
+        #nan_mask = torch.isnan(pos)[:, 0, 0]
+        nan_mask = ipu_isnan(pos)[:, 0, 0]
         # apply nan_mask on pos so that it does not give nan gradient
         # when applying gaussian kernels
         pos.masked_fill_(nan_mask.unsqueeze(1).unsqueeze(2), 0.0)
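The encoder-side hunks apply the same mask-and-zero idea to features rather than losses; `PreprocessPositions` masks a whole molecule when its first node has no 3D coordinates. A CPU sketch with hypothetical tensors (`torch.isnan` again standing in for `ipu_isnan`):

```python
import torch

pos = torch.tensor([
    [[float("nan")] * 3, [1.0, 2.0, 3.0]],  # molecule 0: missing 3D conformer
    [[0.5, 0.1, 0.2], [0.3, 0.4, 0.6]],     # molecule 1: valid coordinates
])  # [batch, nodes, 3]

nan_mask = torch.isnan(pos)[:, 0, 0]  # [batch]; check the first node only
pos.masked_fill_(nan_mask.unsqueeze(1).unsqueeze(2), 0.0)

print(pos[0].abs().sum())  # tensor(0.) - the masked molecule is now all zeros
```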
diff --git a/graphium/trainer/metrics.py b/graphium/trainer/metrics.py
index d935401ed..aac9b4227 100644
--- a/graphium/trainer/metrics.py
+++ b/graphium/trainer/metrics.py
@@ -10,6 +10,7 @@
 import torchmetrics.functional.regression.mae
 
 from graphium.utils.tensor import nan_mean
+from graphium.ipu import ipu_isnan
 
 # NOTE(hadim): the below is a fix to be able to import previously saved Graphium models that are incompatible
 # with the current version of torchmetrics.
@@ -265,7 +266,8 @@ def compute(self, preds: Tensor, target: Tensor) -> Tensor:
         if self.thresholder is not None:
             preds, target = self.thresholder(preds, target)
 
-        target_nans = torch.isnan(target)
+        #target_nans = torch.isnan(target)
+        target_nans = ipu_isnan(target)
 
         # for the classifigression task, cast predictions from
         # (batch_size, n_targets * n_brackets) to (batch_size, n_targets, n_brackets)
@@ -329,13 +331,15 @@ def compute(self, preds: Tensor, target: Tensor) -> Tensor:
 
     def _filter_nans(self, preds: Tensor, target: Tensor):
         """Handle the NaNs according to the chosen options"""
-        target_nans = torch.isnan(target)
+        #target_nans = torch.isnan(target)
+        target_nans = ipu_isnan(target)
 
         if self.target_nan_mask is None:
             pass
         elif isinstance(self.target_nan_mask, (int, float)):
             target = target.clone()
-            target[torch.isnan(target)] = self.target_nan_mask
+            #target[torch.isnan(target)] = self.target_nan_mask
+            target[ipu_isnan(target)] = self.target_nan_mask
         elif self.target_nan_mask == "ignore":
             target = target[~target_nans]
             preds = preds[~target_nans]
diff --git a/graphium/utils/tensor.py b/graphium/utils/tensor.py
index 1514e620c..233aa4b86 100644
--- a/graphium/utils/tensor.py
+++ b/graphium/utils/tensor.py
@@ -133,7 +133,8 @@ def nan_mean(input: Tensor, *args, **kwargs) -> Tensor:
     """
     sum = torch.nansum(input, *args, **kwargs)
-    num = torch.sum(~torch.isnan(input), *args, **kwargs)
+    #num = torch.sum(~torch.isnan(input), *args, **kwargs)
+    num = torch.sum(~ipu_isnan(input), *args, **kwargs)
     mean = sum / num
     return mean
 
@@ -220,7 +221,8 @@ def nan_var(input: Tensor, unbiased: bool = True, **kwargs) -> Tensor:
     var = nan_mean(dist2, **kwargs)
     if unbiased:
-        num = torch.sum(~torch.isnan(input), **kwargs)
+        #num = torch.sum(~torch.isnan(input), **kwargs)
+        num = torch.sum(~ipu_isnan(input), **kwargs)
         var = var * num / (num - 1)
     return var
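After the swap, `nan_mean` is still a nansum divided by the count of non-NaN entries, i.e. the same result as `torch.nanmean`. Note that the diff as shown adds no `ipu_isnan` import to `graphium/utils/tensor.py`, so the CPU sketch below sticks to `torch.isnan`:

```python
import torch

def nan_mean(input: torch.Tensor, *args, **kwargs) -> torch.Tensor:
    total = torch.nansum(input, *args, **kwargs)
    num = torch.sum(~torch.isnan(input), *args, **kwargs)  # count of valid entries
    return total / num

x = torch.tensor([1.0, float("nan"), 3.0])
assert nan_mean(x) == 2.0
assert torch.equal(nan_mean(x), torch.nanmean(x))
```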