@@ -103,3 +103,45 @@ def combine_datasets_balanced(list_of_datasets, class_labels, train_per_class, v
103
103
val_dataset = torch .utils .data .ConcatDataset (val_dataset )
104
104
105
105
return train_dataset , val_dataset , test_dataset
106
+
107
+
108
def split_dataset_regression(dataset, train_size, test_size, val_size, seed=None):
    """
    Split a dataset into random, disjoint train, validation, and test subsets.

    The split is performed with ``torch.utils.data.random_split``. Samples
    that are not needed to fill the three requested subsets are discarded.

    Parameters
    ----------
    dataset : torch.utils.data.Dataset
        Dataset to be split. Must support ``len()``.
    train_size : int
        Number of samples in the train set.
    test_size : int
        Number of samples in the test set.
    val_size : int
        Number of samples in the validation set.
    seed : int, optional
        Seed for a dedicated ``torch.Generator`` used for the split. If
        ``None`` (default), ``random_split`` falls back to its default
        random source.

    Returns
    -------
    train : torch.utils.data.Dataset
        Train dataset.
    val : torch.utils.data.Dataset
        Validation dataset.
    test : torch.utils.data.Dataset
        Test dataset.

    Raises
    ------
    ValueError
        If any requested size is negative, or if the requested sizes add up
        to more samples than the dataset contains.

    Notes
    -----
    The return order is (train, val, test), which differs from the parameter
    order (train, test, val).
    """
    # A negative size would inflate the residual, slip past the length check
    # below, and make random_split hand back overlapping, wrongly sized subsets.
    requested = (
        ("train_size", train_size),
        ("test_size", test_size),
        ("val_size", val_size),
    )
    for name, size in requested:
        if size < 0:
            raise ValueError(f"{name} must be non-negative, but got {size}.")

    residual_size = len(dataset) - train_size - test_size - val_size

    if residual_size < 0:
        raise ValueError(
            f"Dataset with length {len(dataset)} is too small to be split into test set of size {test_size} , "
            f"train set of size {train_size} , and validation set of size {val_size} . "
        )

    # Only pass a generator when a seed was requested, so that the unseeded
    # path keeps using random_split's own default.
    split_kwargs = {}
    if seed is not None:
        gen = torch.Generator()
        gen.manual_seed(seed)
        split_kwargs["generator"] = gen

    # The fourth subset soaks up the leftover samples and is dropped.
    train, test, val, _ = torch.utils.data.random_split(
        dataset,
        [train_size, test_size, val_size, residual_size],
        **split_kwargs,
    )

    return train, val, test
0 commit comments