We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent ecc848f commit a48886cCopy full SHA for a48886c
segmentation_models_pytorch/utils/functional.py
@@ -77,7 +77,7 @@ def accuracy(pr, gt, threshold=0.5, ignore_channels=None):
77
pr = _threshold(pr, threshold=threshold)
78
pr, gt = _take_channels(pr, gt, ignore_channels=ignore_channels)
79
80
- tp = torch.sum(gt == pr)
+ tp = torch.sum(gt == pr, dtype=pr.dtype)
81
score = tp / gt.view(-1).shape[0]
82
return score
83
0 commit comments