@@ -1,12 +1,12 @@
 import logging
 import os
+import random
 from pathlib import Path
 
 import atom3.pair as pa
 import click
 import pandas as pd
 from atom3 import database as db
-from sklearn.model_selection import train_test_split
 from tqdm import tqdm
 
 from project.utils.constants import DB5_TEST_PDB_CODES, ATOM_COUNT_LIMIT
@@ -64,15 +64,25 @@ def main(output_dir: str, source_type: str, filter_by_atom_count: bool, max_atom
     if not os.path.exists(pairs_postprocessed_val_txt):  # Create val data list if not already existent
         open(pairs_postprocessed_val_txt, 'w').close()
     # Write out training-validation partitions for DIPS
+    output_dirs = [filename
+                   for filename in os.listdir(output_dir)
+                   if os.path.isdir(os.path.join(output_dir, filename))]
+    # Get training and validation directories separately
+    num_train_dirs = int(0.8 * len(output_dirs))
+    train_dirs = random.sample(output_dirs, num_train_dirs)
+    val_dirs = list(set(output_dirs) - set(train_dirs))
+    # Ascertain training and validation filenames separately
     filenames_frame = pd.read_csv(pairs_postprocessed_txt, header=None)
-    train_filenames_frame, val_filenames_frame, _, _ = train_test_split(
-        filenames_frame,
-        # Ignore labels for now - will create feature vectors in dataset class
-        [None for _ in range(len(filenames_frame))],
-        train_size=(8 / 10),
-        test_size=(2 / 10)
-    )
+    train_filenames = [os.path.join(train_dir, filename)
+                       for train_dir in train_dirs
+                       for filename in os.listdir(os.path.join(output_dir, train_dir))
+                       if os.path.join(train_dir, filename) in filenames_frame.values]
+    val_filenames = [os.path.join(val_dir, filename)
+                     for val_dir in val_dirs
+                     for filename in os.listdir(os.path.join(output_dir, val_dir))
+                     if os.path.join(val_dir, filename) in filenames_frame.values]
     # Create separate .txt files to describe the training list and validation list, respectively
+    train_filenames_frame, val_filenames_frame = pd.DataFrame(train_filenames), pd.DataFrame(val_filenames)
     train_filenames_frame.to_csv(pairs_postprocessed_train_txt, header=None, index=None, sep=' ', mode='a')
     val_filenames_frame.to_csv(pairs_postprocessed_val_txt, header=None, index=None, sep=' ', mode='a')
 
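For reference, the change replaces the row-level `train_test_split` over the filename list with a directory-level 80/20 split. A minimal standalone sketch of that idea (hypothetical helper name; the seed is added here only for illustration, the diff itself draws an unseeded `random.sample`):

```python
# Hypothetical sketch (not part of the commit): partition the per-code
# subdirectories of an output directory into train/val sets at a fixed ratio.
import os
import random


def split_pair_dirs(output_dir: str, train_fraction: float = 0.8, seed: int = 42):
    """Return (train_dirs, val_dirs) sampled from the subdirectories of output_dir."""
    random.seed(seed)  # assumption: seeded for reproducibility; the diff samples unseeded
    dirs = [d for d in os.listdir(output_dir)
            if os.path.isdir(os.path.join(output_dir, d))]
    num_train = int(train_fraction * len(dirs))
    train_dirs = random.sample(dirs, num_train)
    val_dirs = [d for d in dirs if d not in train_dirs]
    return train_dirs, val_dirs
```

One practical consequence of splitting by directory is that every file under a given subdirectory lands entirely in either the training or the validation list, whereas the previous row-level split could scatter a directory's files across both.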