Skip to content

Commit a48886c

Browse files
authored
Fix Accuracy metric (qubvel-org#186)
1 parent ecc848f commit a48886c

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

segmentation_models_pytorch/utils/functional.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def accuracy(pr, gt, threshold=0.5, ignore_channels=None):
7777
pr = _threshold(pr, threshold=threshold)
7878
pr, gt = _take_channels(pr, gt, ignore_channels=ignore_channels)
7979

80-
tp = torch.sum(gt == pr)
80+
tp = torch.sum(gt == pr, dtype=pr.dtype)
8181
score = tp / gt.view(-1).shape[0]
8282
return score
8383

0 commit comments

Comments
 (0)