Skip to content

Commit 72ee7d6

Browse files
committed
add method for data splitting regression
1 parent 32b805f commit 72ee7d6

File tree

1 file changed

+42
-0
lines changed

1 file changed

+42
-0
lines changed

src/sparcscore/ml/utils.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,3 +103,45 @@ def combine_datasets_balanced(list_of_datasets, class_labels, train_per_class, v
103103
val_dataset = torch.utils.data.ConcatDataset(val_dataset)
104104

105105
return train_dataset, val_dataset, test_dataset
106+
107+
108+
def split_dataset_regression(dataset, train_size, test_size, val_size, seed=None):
109+
"""
110+
Split a dataset into train, test, and validation set.
111+
112+
Parameters
113+
----------
114+
dataset : torch.utils.data.Dataset
115+
Dataset to be split.
116+
train_size : int
117+
Number of samples in the train set.
118+
test_size : int
119+
Number of samples in the test set.
120+
val_size : int
121+
Number of samples in the validation set.
122+
123+
Returns
124+
-------
125+
train : torch.utils.data.Dataset
126+
Train dataset.
127+
val : torch.utils.data.Dataset
128+
Validation dataset.
129+
test : torch.utils.data.Dataset
130+
Test dataset.
131+
"""
132+
residual_size = len(dataset) - train_size - test_size - val_size
133+
134+
if residual_size < 0:
135+
raise ValueError(
136+
f"Dataset with length {len(dataset)} is too small to be split into test set of size {test_size}, "
137+
f"train set of size {train_size}, and validation set of size {val_size}. "
138+
)
139+
140+
if seed is not None:
141+
gen = torch.Generator()
142+
gen.manual_seed(seed)
143+
train, test, val, _ = torch.utils.data.random_split(dataset, [train_size, test_size, val_size, residual_size], generator=gen)
144+
else:
145+
train, test, val, _ = torch.utils.data.random_split(dataset, [train_size, test_size, val_size, residual_size])
146+
147+
return train, val, test

0 commit comments

Comments
 (0)