Skip to content

Commit f252383

Browse files
committed
modified for relative path
1 parent 3e2d68d commit f252383

File tree

2 files changed

+13
-5
lines changed

2 files changed

+13
-5
lines changed

.gitignore

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,4 @@ data
99
tests
1010
tutorial/model_test.py
1111
model_test.py
12-
data_test
12+
tutorial/data_test

tutorial/train.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from model import UNET
99
from utils import (load_checkpoint, save_checkpoint, get_loaders, check_accuracy, save_predictions_as_imgs)
1010
from utils import DiceLoss2D
11+
import os
1112

1213
# Hyperparameters
1314
LEARNING_RATE = 1e-4
@@ -19,10 +20,17 @@
1920
IMAGE_WIDTH = 701 # 1918 originally
2021
PIN_MEMORY = True
2122
LOAD_MODEL = False
22-
TRAIN_IMG_DIR = "/Users/harsha/PycharmProjects/robotic_surgery_tool_segmentation/data_test/train_images/"
23-
TRAIN_MASK_DIR = "/Users/harsha/PycharmProjects/robotic_surgery_tool_segmentation/data_test/train_masks/"
24-
VAL_IMG_DIR = "/Users/harsha/PycharmProjects/robotic_surgery_tool_segmentation/data_test/val_images/"
25-
VAL_MASK_DIR = "/Users/harsha/PycharmProjects/robotic_surgery_tool_segmentation/data_test/val_masks/"
23+
24+
current_dir = os.path.abspath(os.getcwd())
25+
TRAIN_IMG_DIR = os.path.join(current_dir, 'data_test/train_images/')
26+
TRAIN_MASK_DIR = os.path.join(current_dir, 'data_test/train_masks/')
27+
VAL_IMG_DIR = os.path.join(current_dir, 'data_test/val_images/')
28+
VAL_MASK_DIR = os.path.join(current_dir, 'data_test/val_masks/')
29+
30+
# TRAIN_IMG_DIR = "/Users/harsha/PycharmProjects/robotic_surgery_tool_segmentation/data_test/train_images/"
31+
# TRAIN_MASK_DIR = "/Users/harsha/PycharmProjects/robotic_surgery_tool_segmentation/data_test/train_masks/"
32+
# VAL_IMG_DIR = "/Users/harsha/PycharmProjects/robotic_surgery_tool_segmentation/data_test/val_images/"
33+
# VAL_MASK_DIR = "/Users/harsha/PycharmProjects/robotic_surgery_tool_segmentation/data_test/val_masks/"
2634

2735
def train_fn(loader, model, optimizer, loss_fn, scaler):
2836
loop = tqdm(loader)

0 commit comments

Comments
 (0)