-
Notifications
You must be signed in to change notification settings - Fork 32
/
Copy path_penguins.py
142 lines (106 loc) · 3.57 KB
/
_penguins.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
"""Penguins dataset."""
from typing import Optional, List, Dict, Tuple, Any
from torch.utils.data import Dataset
from torchvision.transforms import Compose
from pandas import DataFrame
from palmerpenguins import load_penguins
# pylint: disable=too-many-arguments
class PenguinDataset(Dataset):
    """Penguin dataset class.

    Parameters
    ----------
    input_keys : List[str]
        The column titles to use in the input feature vectors.
    target_keys : List[str]
        The column titles to use in the target feature vectors.
    train : bool
        If ``True``, this object will serve as the training set, and if
        ``False``, the validation set.
    x_tfms : Compose, optional
        A composition of transforms to apply to the inputs.
    y_tfms : Compose, optional
        A composition of transforms to apply to the targets.

    Notes
    -----
    The validation split contains 10 male and 10 female penguins of each
    species.

    """

    def __init__(
        self,
        input_keys: List[str],
        target_keys: List[str],
        train: bool,
        x_tfms: Optional[Compose] = None,
        y_tfms: Optional[Compose] = None,
    ):
        """Build ``PenguinDataset``."""
        self.input_keys = input_keys
        self.target_keys = target_keys
        self.full_df = _load_penguin_data()
        # ``train`` is a bool, so plain truthiness is the idiomatic test
        # (``train is True`` was an identity comparison — PEP 8 anti-pattern).
        self.split = _split_data(self.full_df)["train" if train else "valid"]
        self.x_tfms = x_tfms
        self.y_tfms = y_tfms

    def __len__(self) -> int:
        """Return the length of requested split.

        Returns
        -------
        int
            The number of items in the dataset.

        """
        return len(self.split)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        """Return an input-target pair.

        Parameters
        ----------
        idx : int
            Index of the input-target pair to return.

        Returns
        -------
        feats : Any
            Inputs.
        tgts : Any
            Targets.

        """
        # Fetch the row once rather than doing two positional lookups.
        row = self.split.iloc[idx]
        feats = tuple(row[self.input_keys])
        tgts = tuple(row[self.target_keys])

        if self.x_tfms is not None:
            feats = self.x_tfms(feats)

        if self.y_tfms is not None:
            tgts = self.y_tfms(tgts)

        return feats, tgts
def _load_penguin_data() -> DataFrame:
    """Return the cleaned penguin data.

    Returns
    -------
    DataFrame
        The penguin dataset, with rows containing ``NaN``s dropped.

    """
    raw = load_penguins()
    # Keep fully-observed rows only, then impose a deterministic row order
    # by sorting on every column (alphabetically) and renumbering the index.
    cleaned = raw.dropna().sort_values(by=sorted(raw.keys()))
    cleaned = cleaned.reset_index(drop=True)
    # Transform the sex field into a float, with male represented by 1.0, female by 0.0
    cleaned["sex"] = (cleaned["sex"] == "male").astype(float)
    return cleaned
def _split_data(penguin_df: DataFrame) -> Dict[str, DataFrame]:
"""Split the ``penguin_df`` into a training and validation set.
Parameters
----------
penguin_df : DataFrame
The full penguin data set.
Returns
-------
Dict[str, DataFrame]
Dictionary holding the ``"train"`` and ``"valid"`` splits. The valid
split has 10 females and 10 males of each species, and the training
split contains the rest of the dataset.
"""
valid_df = penguin_df.groupby(by=["species", "sex"]).sample(
n=10,
random_state=123,
)
# The training items are simply the items *not* in the valid split
train_df = penguin_df.loc[~penguin_df.index.isin(valid_df.index)]
return {"train": train_df, "valid": valid_df}