Skip to content

Commit 7bd9e3a

Browse files
committed
add open source drone dataset pytorch reader
1 parent 690471c commit 7bd9e3a

File tree

1 file changed

+65
-0
lines changed

1 file changed

+65
-0
lines changed
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
from glob import glob
2+
from os.path import join
3+
import numpy as np
4+
import pandas as pd
5+
from PIL import Image
6+
from typing import List, Tuple
7+
from src.datasets.coco import BoundingBox
8+
9+
from torch.utils.data import Dataset
10+
11+
12+
class DroneDataset(Dataset):
13+
def __init__(self, images_dir: str, mask_dir: str, class_dict_path: str):
14+
self.images_dir = images_dir
15+
self.mask_dir = mask_dir
16+
self.images_index = [
17+
filename.split(".")[0] for filename in glob("*.jpg")
18+
]
19+
20+
class_dict = pd.read_csv(class_dict_path).to_dict("index")
21+
self.class_id_to_name = {
22+
class_id: rec["name"] for class_id, rec in class_dict.items()
23+
}
24+
self.rgb_to_class = {
25+
(rec["r"], rec["g"], rec["b"]): int(class_id)
26+
for class_id, rec in class_dict.items()
27+
}
28+
29+
def _mask_rgb_to_class_label(self, rgb_mask: np.ndarray):
30+
"""The Semantic Drone Dataset formats their masks as an RGB mask
31+
To prepare the mask for use with a PyTorch model, we must encode
32+
the mask as a 2D array of class labels
33+
34+
Parameters
35+
----------
36+
rgb_mask : np.ndarray
37+
Mask array with RGB values for each class
38+
39+
Returns
40+
-------
41+
mask : np.ndarray
42+
Mask with shape `(height, width)` with class_id values where they occur
43+
"""
44+
height, width, _ = rgb_mask.shape
45+
mask = np.zeros((height, width))
46+
for i in range(height):
47+
for j in range(width):
48+
mask[i][j] = self.rgb_to_class[tuple(rgb_mask[i][j])]
49+
return mask
50+
51+
def __getitem__(
52+
self, image_id: int
53+
) -> Tuple[np.ndarray, List[BoundingBox]]:
54+
filename = self.images_index[image_id]
55+
image_filepath = join(self.images_dir, f"{filename}.jpg")
56+
image = Image.open(image_filepath).convert("RGB")
57+
image = np.array(image).astype("float32")
58+
59+
mask_filepath = join(self.images_dir, f"{filename}.png")
60+
mask = Image.open(mask_filepath).convert("RGB")
61+
mask = np.array(mask).astype("uint8")
62+
63+
mask = self._mask_rgb_to_class_label(mask)
64+
65+
return image, [], mask, []

0 commit comments

Comments
 (0)