Skip to content

Commit 8820aed

Browse files
committed
Update dataset filename partitioning logic to match latest codebase changes
1 parent a03d6a2 commit 8820aed

File tree

1 file changed

+18
-8
lines changed

1 file changed

+18
-8
lines changed

project/datasets/builder/partition_dataset_filenames.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
import logging
22
import os
3+
import random
34
from pathlib import Path
45

56
import atom3.pair as pa
67
import click
78
import pandas as pd
89
from atom3 import database as db
9-
from sklearn.model_selection import train_test_split
1010
from tqdm import tqdm
1111

1212
from project.utils.constants import DB5_TEST_PDB_CODES, ATOM_COUNT_LIMIT
@@ -64,15 +64,25 @@ def main(output_dir: str, source_type: str, filter_by_atom_count: bool, max_atom
6464
if not os.path.exists(pairs_postprocessed_val_txt): # Create val data list if not already existent
6565
open(pairs_postprocessed_val_txt, 'w').close()
6666
# Write out training-validation partitions for DIPS
67+
output_dirs = [filename
68+
for filename in os.listdir(output_dir)
69+
if os.path.isdir(os.path.join(output_dir, filename))]
70+
# Get training and validation directories separately
71+
num_train_dirs = int(0.8 * len(output_dirs))
72+
train_dirs = random.sample(output_dirs, num_train_dirs)
73+
val_dirs = list(set(output_dirs) - set(train_dirs))
74+
# Ascertain training and validation filename separately
6775
filenames_frame = pd.read_csv(pairs_postprocessed_txt, header=None)
68-
train_filenames_frame, val_filenames_frame, _, _ = train_test_split(
69-
filenames_frame,
70-
# Ignore labels for now - will create feature vectors in dataset class
71-
[None for _ in range(len(filenames_frame))],
72-
train_size=(8 / 10),
73-
test_size=(2 / 10)
74-
)
76+
train_filenames = [os.path.join(train_dir, filename)
77+
for train_dir in train_dirs
78+
for filename in os.listdir(os.path.join(output_dir, train_dir))
79+
if os.path.join(train_dir, filename) in filenames_frame.values]
80+
val_filenames = [os.path.join(val_dir, filename)
81+
for val_dir in val_dirs
82+
for filename in os.listdir(os.path.join(output_dir, val_dir))
83+
if os.path.join(val_dir, filename) in filenames_frame.values]
7584
# Create separate .txt files to describe the training list and validation list, respectively
85+
train_filenames_frame, val_filenames_frame = pd.DataFrame(train_filenames), pd.DataFrame(val_filenames)
7686
train_filenames_frame.to_csv(pairs_postprocessed_train_txt, header=None, index=None, sep=' ', mode='a')
7787
val_filenames_frame.to_csv(pairs_postprocessed_val_txt, header=None, index=None, sep=' ', mode='a')
7888

0 commit comments

Comments
 (0)