-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcustom_transforms.py
30 lines (24 loc) · 1.25 KB
/
custom_transforms.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import numpy as np
from PIL import Image
from torchvision.transforms import functional as F
class SegmentationIdToTrainId:
    """Remap the raw label ids of a segmentation mask to train ids.

    The mapping is built once from ``labels``. Calling the transform on a
    mask returns a new 8-bit PIL image in which every pixel holds the train
    id of the original pixel's label id. Pixels whose id is ``-1`` or is not
    present in the mapping receive ``IGNORE_TRAIN_ID``.
    """

    # Train id written for pixels that should be ignored (id -1 or unknown id).
    IGNORE_TRAIN_ID = 255

    def __init__(self, labels):
        """Build the id -> trainId lookup.

        Args:
            labels: Iterable of objects exposing ``id`` and ``trainId``
                attributes.
        """
        self.id_to_trainid = {label.id: label.trainId for label in labels}

    def _remap_array(self, seg_array):
        """Return ``seg_array`` with every id replaced by its train id.

        Args:
            seg_array: Array-like of label ids, any shape.

        Returns:
            A ``numpy.uint8`` array with the same shape as ``seg_array``.
            Train ids are cast to ``uint8``, so they are expected to fit in
            the 0..255 range.
        """
        seg_array = np.asarray(seg_array)
        ignore = self.IGNORE_TRAIN_ID

        # Resolve each *distinct* id once in Python, then broadcast the result
        # back over the image with a single fancy-indexing step. This avoids
        # the per-pixel Python call that np.vectorize performs, and it also
        # works for negative ids and for empty arrays.
        unique_ids, inverse = np.unique(seg_array, return_inverse=True)
        train_ids = np.array(
            [
                ignore if raw_id == -1 else self.id_to_trainid.get(raw_id, ignore)
                for raw_id in unique_ids.tolist()
            ],
            dtype=np.int64,
        )

        # The shape of `inverse` differs between NumPy versions, so restore
        # the input shape explicitly.
        remapped = train_ids[inverse].reshape(seg_array.shape)
        return remapped.astype(np.uint8)

    def __call__(self, seg_image):
        """Apply the id -> trainId mapping to a segmentation mask.

        Args:
            seg_image: A PIL image or array-like holding label ids.

        Returns:
            A PIL image built from the remapped ``uint8`` array.
        """
        return Image.fromarray(self._remap_array(np.asarray(seg_image)))

    def __repr__(self):
        return f'{self.__class__.__name__}()'