Deeplabv3 modified #171

Open · wants to merge 11 commits into master

core/data/dataloader/cityscapes.py (1 addition, 1 deletion)
@@ -37,7 +37,7 @@ class CitySegmentation(SegmentationDataset):
     BASE_DIR = 'cityscapes'
     NUM_CLASS = 19

-    def __init__(self, root='../datasets/citys', split='train', mode=None, transform=None, **kwargs):
+    def __init__(self, root='/content/data', split='train', mode=None, transform=None, **kwargs):
         super(CitySegmentation, self).__init__(root, split, mode, transform, **kwargs)
         # self.root = os.path.join(root, self.BASE_DIR)
         assert os.path.exists(self.root), "Please setup the dataset using ../datasets/cityscapes.py"
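
Reviewer note: the new default root '/content/data' hard-codes a Colab path. A minimal sketch of keeping the path overridable instead, assuming a hypothetical CITYS_ROOT environment variable (not part of this repo) and the import path implied by this diff's file layout:

    import os

    from core.data.dataloader.cityscapes import CitySegmentation

    # Hypothetical: let an environment variable override the dataset root,
    # falling back to the repository's original relative default.
    root = os.environ.get('CITYS_ROOT', '../datasets/citys')
    train_set = CitySegmentation(root=root, split='train', mode='train')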

core/models/deeplabv3.py (1 addition, 1 deletion)
@@ -58,7 +58,7 @@ def forward(self, x):
 class _DeepLabHead(nn.Module):
     def __init__(self, nclass, norm_layer=nn.BatchNorm2d, norm_kwargs=None, **kwargs):
         super(_DeepLabHead, self).__init__()
-        self.aspp = _ASPP(2048, [12, 24, 36], norm_layer=norm_layer, norm_kwargs=norm_kwargs, **kwargs)
+        self.aspp = _ASPP(2048, [6, 12, 18], norm_layer=norm_layer, norm_kwargs=norm_kwargs, **kwargs)
         self.block = nn.Sequential(
             nn.Conv2d(256, 256, 3, padding=1, bias=False),
             norm_layer(256, **({} if norm_kwargs is None else norm_kwargs)),
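
For context on this change: [6, 12, 18] are the atrous rates the DeepLabv3 paper pairs with output stride 16, while the original [12, 24, 36] (rates doubled) correspond to output stride 8. A minimal standalone sketch of what one rate does in an ASPP branch; this is an illustration, not this repo's _ASPP:

    import torch
    import torch.nn as nn

    def aspp_branch(in_ch, out_ch, rate):
        # A 3x3 conv whose dilation (with matching padding) controls how far
        # apart the sampled input locations are, i.e. the receptive field.
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=rate, dilation=rate, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(True))

    branches = nn.ModuleList(aspp_branch(2048, 256, r) for r in (6, 12, 18))
    x = torch.randn(1, 2048, 32, 32)                  # stride-16 backbone features
    out = torch.cat([b(x) for b in branches], dim=1)  # shape (1, 768, 32, 32)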

core/models/deeplabv3_plus.py (1 addition, 1 deletion)
@@ -84,7 +84,7 @@ def forward(self, x):
         size = x.size()[2:]
         c1, c3, c4 = self.base_forward(x)
         outputs = list()
-        x = self.head(c4, c1)
+        x = self.head(c4, c1)  # skip-level feature fusion
         x = F.interpolate(x, size, mode='bilinear', align_corners=True)
         outputs.append(x)
         if self.aux:
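
The translated comment marks the DeepLabv3+ decoder step where high-level ASPP output is fused with low-level backbone features. A minimal sketch of that fusion pattern, with illustrative tensor shapes rather than this repo's exact head:

    import torch
    import torch.nn.functional as F

    def skip_fuse(aspp_out, low_level):
        # Upsample coarse ASPP features to the low-level feature map's size,
        # then concatenate along the channel dimension for the decoder convs.
        aspp_up = F.interpolate(aspp_out, size=low_level.shape[2:],
                                mode='bilinear', align_corners=True)
        return torch.cat([aspp_up, low_level], dim=1)

    c4 = torch.randn(1, 256, 32, 32)    # ASPP output (stride 16)
    c1 = torch.randn(1, 256, 128, 128)  # low-level features (stride 4)
    fused = skip_fuse(c4, c1)           # (1, 512, 128, 128)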

scripts/train.py (46 additions, 10 deletions)
@@ -24,6 +24,11 @@
 from core.utils.score import SegmentationMetric


+from torch.utils.tensorboard import SummaryWriter
+# define the name of this experiment
+
+writer = SummaryWriter("/content/log")
+
 def parse_args():
     parser = argparse.ArgumentParser(description='Semantic Segmentation Training With Pytorch')
     # model and dataset
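
The comment above says to define this experiment's name, but the writer is built at import time with one fixed directory, so successive runs mix their curves. A minimal sketch of one run directory per experiment, assuming the /content/log root from this diff:

    import os
    import time
    from torch.utils.tensorboard import SummaryWriter

    # One subdirectory per run lets TensorBoard overlay experiments
    # instead of interleaving their scalars in a single event stream.
    run_name = time.strftime('deeplabv3-%Y%m%d-%H%M%S')
    writer = SummaryWriter(os.path.join('/content/log', run_name))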
@@ -97,6 +102,11 @@ def parse_args():
                         help='run validation every val-epoch')
     parser.add_argument('--skip-val', action='store_true', default=False,
                         help='skip validation during training')
+
+    ### AAAAA cloud-drive save path
+    parser.add_argument('--dirtang', type=str, default=None,
+                        help='directory for saving training info')
+
     args = parser.parse_args()

     # default settings for epochs, batch_size and lr
@@ -129,7 +139,12 @@ class Trainer(object):
     def __init__(self, args):
         self.args = args
         self.device = torch.device(args.device)
+
+        #### AAAA commands I defined
+        self.cmd_tang1 = 'cp -f ~/.torch/models/* ' + args.dirtang + '/pth'  # temporary save command
+        self.cmd_tang2 = 'cp -f /content/log/* ' + args.dirtang + '/log'
+
         # image transform
         input_transform = transforms.Compose([
             transforms.ToTensor(),
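
These cp commands are later run with os.system, which hides failures (a bad --dirtang just prints to stderr and training continues). A minimal sketch of the same sync via shutil so errors raise as exceptions; the paths are the ones from this diff:

    import glob
    import os
    import shutil

    def sync_files(src_glob, dst_dir):
        # Copy every file matching src_glob into dst_dir, creating it if needed.
        os.makedirs(dst_dir, exist_ok=True)
        for path in glob.glob(os.path.expanduser(src_glob)):
            shutil.copy2(path, dst_dir)

    # Equivalents of cmd_tang1 and cmd_tang2 for a given dirtang:
    # sync_files('~/.torch/models/*', dirtang + '/pth')
    # sync_files('/content/log/*', dirtang + '/log')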
@@ -147,6 +162,7 @@ def __init__(self, args):
         val_sampler = make_data_sampler(val_dataset, False, args.distributed)
         val_batch_sampler = make_batch_data_sampler(val_sampler, args.batch_size)

+        # data augmentation modification point
         self.train_loader = data.DataLoader(dataset=train_dataset,
                                             batch_sampler=train_batch_sampler,
                                             num_workers=args.workers,
@@ -155,12 +171,12 @@ def __init__(self, args):
                                           batch_sampler=val_batch_sampler,
                                           num_workers=args.workers,
                                           pin_memory=True)

-        # create network
+        # create network: initialize the network
         BatchNorm2d = nn.SyncBatchNorm if args.distributed else nn.BatchNorm2d
         self.model = get_segmentation_model(model=args.model, dataset=args.dataset, backbone=args.backbone,
                                             aux=args.aux, jpu=args.jpu, norm_layer=BatchNorm2d).to(self.device)

         # resume checkpoint if needed
         if args.resume:
             if os.path.isfile(args.resume):
@@ -180,6 +196,8 @@ def __init__(self, args):
         if hasattr(self.model, 'exclusive'):
             for module in self.model.exclusive:
                 params_list.append({'params': getattr(self.model, module).parameters(), 'lr': args.lr * 10})
+
+        # optimizer modification point
         self.optimizer = torch.optim.SGD(params_list,
                                          lr=args.lr,
                                          momentum=args.momentum,
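
The marked SGD call relies on params_list carrying a 10x learning rate for the head modules, built a few lines above from model.exclusive. A minimal standalone sketch of that two-group pattern; the modules are illustrative stand-ins:

    import torch
    import torch.nn as nn

    backbone = nn.Conv2d(3, 64, 3)   # stand-in for pretrained layers
    head = nn.Conv2d(64, 19, 1)      # stand-in for the segmentation head

    base_lr = 1e-4
    optimizer = torch.optim.SGD(
        [{'params': backbone.parameters()},                   # base lr
         {'params': head.parameters(), 'lr': base_lr * 10}],  # 10x for the new head
        lr=base_lr, momentum=0.9, weight_decay=1e-4)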
@@ -209,7 +227,10 @@ def train(self):
         save_per_iters = self.args.save_epoch * self.args.iters_per_epoch
         start_time = time.time()
         logger.info('Start training, Total Epochs: {:d} = Total Iterations {:d}'.format(epochs, max_iters))
+
+        ### tang3
+
         self.model.train()
         for iteration, (images, targets, _) in enumerate(self.train_loader):
             iteration = iteration + 1
@@ -220,6 +241,9 @@ def train(self):

             outputs = self.model(images)
             loss_dict = self.criterion(outputs, targets)
+
+            ### AAA
+            del outputs  # reduce memory consumption

             losses = sum(loss for loss in loss_dict.values())

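A caveat on del outputs: it only drops this function's reference, and the autograd graph keeps the activations alive until backward() runs, so the saving at this point is modest. A minimal sketch of where the memory is actually released:

    import torch

    def training_step(model, criterion, optimizer, images, targets):
        outputs = model(images)
        loss = criterion(outputs, targets)
        del outputs  # drops one reference; activations survive in the graph
        optimizer.zero_grad()
        loss.backward()  # the graph (and its activations) is freed after this
        optimizer.step()
        return loss.item()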
@@ -239,22 +263,31 @@ def train(self):
                     "Iters: {:d}/{:d} || Lr: {:.6f} || Loss: {:.4f} || Cost Time: {} || Estimated Time: {}".format(
                         iteration, max_iters, self.optimizer.param_groups[0]['lr'], losses_reduced.item(),
                         str(datetime.timedelta(seconds=int(time.time() - start_time))), eta_string))
+
+                ### AAA1
+                writer.add_scalar('loss', losses_reduced.item(), iteration)
+                writer.add_scalar('Learn rate', self.optimizer.param_groups[0]['lr'], iteration)

             if iteration % save_per_iters == 0 and save_to_disk:
                 save_checkpoint(self.model, self.args, is_best=False)

             if not self.args.skip_val and iteration % val_per_iters == 0:
-                self.validation()
+                self.validation(iteration)
+
+                # AAAAA
+                os.system(str(self.cmd_tang1))
+                os.system(str(self.cmd_tang2))  # save to the cloud drive after validation
+
                 self.model.train()

         save_checkpoint(self.model, self.args, is_best=False)
         total_training_time = time.time() - start_time
         total_training_str = str(datetime.timedelta(seconds=total_training_time))
         logger.info(
             "Total training time: {} ({:.4f}s / it)".format(
                 total_training_str, total_training_time / max_iters))

-    def validation(self):
+    def validation(self, iteration):
         # total_inter, total_union, total_correct, total_label = 0, 0, 0, 0
         is_best = False
         self.metric.reset()
@@ -273,7 +306,10 @@ def validation(self):
             self.metric.update(outputs[0], target)
             pixAcc, mIoU = self.metric.get()
             logger.info("Sample: {:d}, Validation pixAcc: {:.3f}, mIoU: {:.3f}".format(i + 1, pixAcc, mIoU))
+
+        ### AAA2
+        writer.add_scalar('mIOU', mIoU, iteration)
+        writer.add_scalar('pixAcc', pixAcc, iteration)
+        writer.flush()
         new_pred = (pixAcc + mIoU) / 2
         if new_pred > self.best_pred:
             is_best = True
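
One loose end in this diff: the module-level writer is flushed during validation but never closed, so the last buffered events can be lost if the process exits between validations. A minimal sketch, assuming the writer created at the top of scripts/train.py:

    import atexit
    from torch.utils.tensorboard import SummaryWriter

    writer = SummaryWriter('/content/log')
    # Close (and flush) the event file on interpreter shutdown, including
    # after unhandled exceptions; close() is safe to call more than once.
    atexit.register(writer.close)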