Use custom isnan for IPU #490

Closed · wants to merge 3 commits
2 changes: 2 additions & 0 deletions graphium/ipu/__init__.py
@@ -0,0 +1,2 @@
# Very confusing, but does match the directory structure
from .isnan.isnan import _ipu_isnan as ipu_isnan
16 changes: 11 additions & 5 deletions graphium/ipu/ipu_losses.py
@@ -5,6 +5,7 @@
from loguru import logger
from graphium.trainer.losses import HybridCELoss

from graphium.ipu import ipu_isnan

class BCEWithLogitsLossIPU(BCEWithLogitsLoss):
"""
@@ -29,7 +30,8 @@ def forward(self, input: Tensor, target: Tensor) -> Tensor:

# Replace the nan-targets by 0 or 1. Take the value closest to the input.
# Give a weight of 0 where there are nan-targets
nan_targets = target.isnan()
#nan_targets = target.isnan()
Review comment (Collaborator):
Although the name of the class ends with IPU, this class might be used on CPU or GPU, for reproducibility, debugging, or laziness. So I would recommend creating a new function:

import torch
from graphium.ipu import ipu_isnan, is_running_on_ipu

def device_is_nan(tensor):
    if is_running_on_ipu():
        return ipu_isnan(tensor)
    else:
        return torch.isnan(tensor)
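
A possible sketch of that helper, for illustration only: `is_running_on_ipu` does not exist in graphium, and the use of `poptorch.isRunningOnIpu()` is an assumption about the PopTorch API.

def is_running_on_ipu() -> bool:
    # Hypothetical helper: assumes poptorch is only importable in IPU
    # environments and that poptorch.isRunningOnIpu() reports whether we
    # are currently inside a PopTorch-compiled graph.
    try:
        import poptorch
        return poptorch.isRunningOnIpu()
    except ImportError:
        return False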

nan_targets = ipu_isnan(target)
nan_targets_0 = (input < 0.5) & nan_targets
nan_targets_1 = (input >= 0.5) & nan_targets
target[nan_targets_0] = 0.0
@@ -74,7 +76,8 @@ def forward(self, input: Tensor, target: Tensor) -> Tensor:

# Replace the nan-targets by 0 or 1. Take the value closest to the input.
# Give a weight of 0 where there are nan-targets
nan_targets = target.isnan()
#nan_targets = target.isnan()
nan_targets = ipu_isnan(target)
nan_targets_0 = (input < 0.5) & nan_targets
nan_targets_1 = (input >= 0.5) & nan_targets
target[nan_targets_0] = 0.0
@@ -109,7 +112,8 @@ def forward(self, input: Tensor, target: Tensor) -> Tensor:
input = input.clone()

# Replace the nan-targets in the input/target tensors by 0
nan_targets = target.isnan()
#nan_targets = target.isnan()
nan_targets = ipu_isnan(target)
input[nan_targets] = 0.0
target[nan_targets] = 0.0

@@ -137,7 +141,8 @@ def forward(self, input: Tensor, target: Tensor) -> Tensor:
input = input.clone()

# Replace the nan-targets in the input/target tensors by 0
nan_targets = target.isnan()
#nan_targets = target.isnan()
nan_targets = ipu_isnan(target)
input[nan_targets] = 0.0
target[nan_targets] = 0.0

@@ -175,7 +180,8 @@ def forward(self, input: Tensor, target: Tensor) -> Tensor:
input = input.clone()

# Replace the nan-targets in the input/target tensors by 0
nan_targets = target.isnan()
#nan_targets = target.isnan()
nan_targets = ipu_isnan(target)

# Compute the loss, and rescale by the number of nan elements
loss = super().forward(input, target, nan_targets)
18 changes: 8 additions & 10 deletions graphium/ipu/ipu_metrics.py
@@ -18,7 +18,7 @@

from graphium.utils.tensor import nan_mean
from graphium.ipu.ipu_utils import import_poptorch

from graphium.ipu import ipu_isnan

def auroc_ipu(
preds: Tensor,
@@ -41,7 +41,7 @@ def auroc_ipu(
preds = preds.clone()

# Replace the nan-targets in the preds/target tensors by 0
nan_targets = target.isnan()
nan_targets = ipu_isnan(target)
preds[nan_targets] = 0.0
target[nan_targets] = 0.0

@@ -85,9 +85,7 @@ def average_precision_ipu(
target = target.clone()
preds = preds.clone()

# Replace the nan-targets in the preds/target tensors by 0
# Average precision is not sensitive to true negatives
nan_targets = target.isnan()
nan_targets = ipu_isnan(target)
preds[nan_targets] = 0.0
target[nan_targets] = 0.0

@@ -401,7 +399,7 @@ def get_confusion_matrix(

#### ADDED ####
# Put all the NaNs as the 0-class
nans = torch.isnan(target)
nans = ipu_isnan(target)
target[nans] = 0
preds[nans] = 0
if (preds.ndim > 1) and (preds.shape[1] > 1):
@@ -452,7 +450,7 @@ def get_nans(self) -> BoolTensor:
In the case of a boolean tensor, this returns a Tensor filled with `False`
"""
if self.is_floating_point():
return self.isnan()
return ipu_isnan(self)
elif self.is_signed():
return self == torch.iinfo(self.dtype).min
else:
@@ -578,7 +576,7 @@ def spearman_ipu(preds, target):
preds: estimated scores
target: ground truth scores
"""
nans = target.isnan()
nans = ipu_isnan(target)
dtype = preds.dtype
preds[nans] = float("inf")
target[nans] = float("inf")
@@ -849,7 +847,7 @@ def mean_squared_error_ipu(preds: Tensor, target: Tensor, squared: bool) -> Tensor:
preds = preds.clone()

# Replace the nan-targets in the preds/target tensors by 0
nan_targets = target.isnan()
nan_targets = ipu_isnan(target)
preds[nan_targets] = 0.0
target[nan_targets] = 0.0

@@ -882,7 +880,7 @@ def mean_absolute_error_ipu(preds: Tensor, target: Tensor) -> Tensor:
preds = preds.clone()

# Replace the nan-targets in the preds/target tensors by 0
nan_targets = target.isnan()
nan_targets = ipu_isnan(target)
preds[nan_targets] = 0.0
target[nan_targets] = 0.0

28 changes: 28 additions & 0 deletions graphium/ipu/isnan/Makefile
@@ -0,0 +1,28 @@
CXX ?= g++
CXXFLAGS = -std=c++14 -fPIC -g
LDLIBS = -shared -lpopart
ONNX_NAMESPACE = -DONNX_NAMESPACE=onnx

BUILD_DIR = build
SOURCES = isnan_popart.cpp
TARGET = $(BUILD_DIR)/custom_ops.so
# Added rule (an assumption, not in the original PR): IsNanOpx calls
# graph().addCodelets("isnan.gp") but nothing built that file; popc is
# Graphcore's codelet compiler, and "isnan.gp" is resolved at runtime cwd.
CODELET = isnan.gp

all: create_build_dir isnan_custom_op $(CODELET)

.PHONY: create_build_dir
create_build_dir:
	mkdir -p $(BUILD_DIR)

isnan_custom_op: isnan_popart.cpp
	$(CXX) $(SOURCES) $(LDLIBS) $(CXXFLAGS) $(ONNX_NAMESPACE) -o $(TARGET)

$(CODELET): isnan_codelet.cpp
	popc isnan_codelet.cpp -o $(CODELET)

.PHONY: clean
clean:
	rm -rf $(BUILD_DIR) $(CODELET)
Empty file.
19 changes: 19 additions & 0 deletions graphium/ipu/isnan/isnan.py
@@ -0,0 +1,19 @@
import torch
import poptorch
import pathlib
import ctypes

myso = list(pathlib.Path(__file__).parent.rglob("build/cus*.so"))
assert myso, "Failed to find custom op .so file - please cd into `graphium/ipu/isnan` and run `make`"
assert len(myso) == 1, f"Too many ({len(myso)}) custom op .so files, there should only be one"
ctypes.cdll.LoadLibrary(str(myso[0]))

def _ipu_isnan(x):
    # poptorch.custom_op returns a list of output tensors; unwrap the single output
    return poptorch.custom_op(
        inputs=(x,),
        name="IsNanCustom",
        domain="custom.ops",
        domain_version=1,
        example_outputs=(x.bool(),),
    )[0]
27 changes: 27 additions & 0 deletions graphium/ipu/isnan/isnan_codelet.cpp
@@ -0,0 +1,27 @@
// Copyright (c) 2018 Graphcore Ltd. All rights reserved.
#include <poplar/Vertex.hpp>
#include <ipu_builtins.h>

class IsNaNUsingF32ClassVertex : public poplar::Vertex {
public:
// Fields
poplar::Input<poplar::Vector<float>> in;
poplar::Output<poplar::Vector<bool>> out;

// Compute function
bool compute() {

for (unsigned i = 0; i < in.size(); ++i) {

auto inClass = __builtin_ipu_f32class(in[i]);

// TFPU_CLASS_UNC = 0
// TFPU_CLASS_SNAN = 1
// TFPU_CLASS_QNAN = 2
// All other classes are >= 3 and are not NaN
out[i] = inClass < 3;

}
return true;
}
};
106 changes: 106 additions & 0 deletions graphium/ipu/isnan/isnan_popart.cpp
@@ -0,0 +1,106 @@
// Copyright (c) 2020 Graphcore Ltd. All rights reserved.

#include <popart/opmanager.hpp>
#include <popart/opserialiser.hpp>
#include <popart/popx/opxmanager.hpp>

#include <popart/popx/opx.hpp>
#include <popops/ElementWise.hpp>

#include <iostream>

namespace CustomOperators {
const popart::OperatorIdentifier IsNanId = {"custom.ops", "IsNanCustom", 1};
} // namespace CustomOperators

class IsNanOp;
class IsNanOpx;
class IsNanGradOpx;

class IsNanOp : public popart::Op {
public:
IsNanOp(const popart::OperatorIdentifier &_opid,
const popart::Op::Settings &settings_)
: popart::Op(_opid, settings_) {}

std::unique_ptr<Op> clone() const final {
return std::make_unique<IsNanOp>(*this);
}

void setup() final {
outInfo(0) = popart::TensorInfo(popart::DataType::BOOL,
inInfo(0).shape());
}

void appendAttributes(popart::OpSerialiserBase &os) const override {
Op::appendAttributes(os);
}

void appendOutlineAttributes(popart::OpSerialiserBase &os) const override {
Op::appendOutlineAttributes(os);
}

float getSubgraphValue() const final { return getHighSubgraphValue(); }

bool requiresRandomSeed() const override { return false; }

};

namespace {
using popart::DataType;
using popart::OpDefinition;

static OpDefinition isNanOpDef({OpDefinition::Inputs({{"input", {popart::DataType::FLOAT}}}),
OpDefinition::Outputs({{"output", {popart::DataType::BOOL}}}),
OpDefinition::Attributes({})
});

static popart::OpCreator<IsNanOp> isNanOpCreator(
popart::OpDefinitions({{CustomOperators::IsNanId, isNanOpDef}}),
[](const popart::OpCreatorInfo &info) {
return std::make_unique<IsNanOp>(info.opid, info.settings);
},
true);
} // namespace

namespace pe = popops::expr;

class IsNanOpx : public popart::popx::Opx {
public:
IsNanOpx(popart::Op *op, popart::popx::Devicex *devicex)
: popart::popx::Opx(op, devicex) {
verifyOp<IsNanOp>(op, {CustomOperators::IsNanId});
}

void grow(poplar::program::Sequence &prog) const final {

auto op = getOp<IsNanOp>();

poplar::Tensor input = getInTensor(0);
auto output = graph().addVariable(poplar::BOOL, input.shape(), debugContext("IsNanCustom"));
auto tileMapping = graph().getTileMapping(input);
graph().setTileMapping(output, tileMapping);

graph().addCodelets("isnan.gp");
auto computeSet = graph().addComputeSet("isNanComputeSet");

for (unsigned i = 0; i < tileMapping.size(); ++i) {
auto intervals = tileMapping.at(i);
//std::cerr << "i = " << i << std::endl;
//std::cerr << "intervals.size() = " << intervals.size() << std::endl;
for (auto interval : intervals) {
auto vertex = graph().addVertex(computeSet, "IsNaNUsingF32ClassVertex");
graph().setTileMapping(vertex, i);
graph().connect(vertex["in"], input.flatten().slice(interval.begin(), interval.end()));
graph().connect(vertex["out"], output.flatten().slice(interval.begin(), interval.end()));
}
}

prog.add(poplar::program::Execute(computeSet));

setOutTensor(0, output);
}
};

static popart::popx::OpxCreator<IsNanOpx>
IsNanOpxCreator({CustomOperators::IsNanId});
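
For context, the two static registrars above (`isNanOpCreator` and `IsNanOpxCreator`) are what make the op visible to PopART: their initialisers run when the shared library is loaded. A sketch of the load-time flow, assuming the Makefile output path:

import ctypes
import pathlib

# dlopen-ing the library executes the static OpCreator / OpxCreator
# initialisers, registering "custom.ops::IsNanCustom" with PopART. This
# must happen before poptorch compiles a model that uses the op, which
# is why isnan.py loads the library at import time.
so_path = pathlib.Path("graphium/ipu/isnan/build/custom_ops.so")
ctypes.cdll.LoadLibrary(str(so_path))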
4 changes: 3 additions & 1 deletion graphium/nn/encoders/laplace_pos_encoder.py
@@ -6,6 +6,7 @@
from graphium.nn.base_layers import MLP, get_norm, FCLayer, TransformerEncoderLayerMup
from graphium.nn.encoders.base_encoder import BaseEncoder

from graphium.ipu import ipu_isnan

class LapPENodeEncoder(BaseEncoder):
def __init__(
@@ -191,7 +192,8 @@ def forward(
pos_enc = torch.cat(
(eigvecs.unsqueeze(2), eigvals.unsqueeze(2)), dim=2
) # (Num nodes) x (Num Eigenvectors) x 2
empty_mask = torch.isnan(pos_enc) # (Num nodes) x (Num Eigenvectors) x 2
#empty_mask = torch.isnan(pos_enc) # (Num nodes) x (Num Eigenvectors) x 2
empty_mask = ipu_isnan(pos_enc) # (Num nodes) x (Num Eigenvectors) x 2

pos_enc[empty_mask] = 0 # (Num nodes) x (Num Eigenvectors) x 2
if self.first_normalization:
5 changes: 3 additions & 2 deletions graphium/nn/encoders/signnet_pos_encoder.py
@@ -12,7 +12,7 @@

from graphium.nn.base_layers import MLP
from graphium.nn.encoders.base_encoder import BaseEncoder

from graphium.ipu import ipu_isnan

class SimpleGIN(nn.Module):
def __init__(
@@ -306,7 +306,8 @@ def forward(self, batch: Batch, key_prefix: Optional[str] = None) -> Dict[str, torch.Tensor]:
eigvecs, edge_index, batch_index = batch[input_keys[0]], batch["edge_index"], batch["batch_index"]
pos_enc = eigvecs.unsqueeze(-1) # (Num nodes) x (Num Eigenvectors) x 1

empty_mask = torch.isnan(pos_enc)
#empty_mask = torch.isnan(pos_enc)
empty_mask = ipu_isnan(pos_enc)
pos_enc[empty_mask] = 0 # (Num nodes) x (Num Eigenvectors) x 1

# SignNet
5 changes: 3 additions & 2 deletions graphium/nn/pyg_layers/utils.py
@@ -9,7 +9,7 @@

from graphium.nn.base_layers import MLP, get_norm
from graphium.ipu.to_dense_batch import to_dense_batch, to_sparse_batch

from graphium.ipu import ipu_isnan

class PreprocessPositions(nn.Module):
"""
@@ -96,7 +96,8 @@ def forward(
# for real nodes, if the 3d position does not exist, it is NaN. For padding nodes, 3d positions will be 0
# if the first node of a molecule has 3d position as nan, the whole molecule will be masked out.
# [batch]
nan_mask = torch.isnan(pos)[:, 0, 0]
#nan_mask = torch.isnan(pos)[:, 0, 0]
nan_mask = ipu_isnan(pos)[:, 0, 0]
# apply nan_mask on pos so that it does not give nan gradient
# when applying gaussian kernels
pos.masked_fill_(nan_mask.unsqueeze(1).unsqueeze(2), 0.0)