Skip to content

Commit c5c4bf4

Browse files
authored
verify that exported files are models files
1 parent 16019be commit c5c4bf4

File tree

1 file changed

+18
-7
lines changed

1 file changed

+18
-7
lines changed

Diff for: tabs/train/train.py

+18-7
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,6 @@ def refresh_embedders_folders():
214214

215215

216216
# Export
217-
## Get Pth and Index Files
218217
def get_pth_list():
219218
return [
220219
os.path.relpath(os.path.join(dirpath, filename), now_dir)
@@ -240,20 +239,32 @@ def refresh_pth_and_index_list():
240239
)
241240

242241

243-
## Export Pth and Index Files
242+
# Export Pth and Index Files
244243
def export_pth(pth_path):
245-
if pth_path and os.path.exists(pth_path):
244+
allowed_paths = get_pth_list()
245+
normalized_allowed_paths = [os.path.abspath(os.path.join(now_dir, p)) for p in allowed_paths]
246+
normalized_pth_path = os.path.abspath(os.path.join(now_dir, pth_path))
247+
248+
if normalized_pth_path in normalized_allowed_paths:
246249
return pth_path
247-
return None
250+
else:
251+
print(f"Attempted to export invalid pth path: {pth_path}")
252+
return None
248253

249254

250255
def export_index(index_path):
251-
if index_path and os.path.exists(index_path):
256+
allowed_paths = get_index_list()
257+
normalized_allowed_paths = [os.path.abspath(os.path.join(now_dir, p)) for p in allowed_paths]
258+
normalized_index_path = os.path.abspath(os.path.join(now_dir, index_path))
259+
260+
if normalized_index_path in normalized_allowed_paths:
252261
return index_path
253-
return None
262+
else:
263+
print(f"Attempted to export invalid index path: {index_path}")
264+
return None
254265

255266

256-
## Upload to Google Drive
267+
# Upload to Google Drive
257268
def upload_to_google_drive(pth_path, index_path):
258269
def upload_file(file_path):
259270
if file_path:

0 commit comments

Comments
 (0)