-
Notifications
You must be signed in to change notification settings - Fork 39
/
Copy pathdemo.py
40 lines (28 loc) · 1.13 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import torch
import os
import sys
import cv2
from models.mbv2_mlsd_tiny import MobileV2_MLSD_Tiny
from models.mbv2_mlsd_large import MobileV2_MLSD_Large
from utils import pred_lines
def main():
    """Run the large M-LSD model on the bundled sample frame and save the result.

    Loads ``models/mlsd_large_512_fp32.pth`` and ``data/frame_1.jpg`` from the
    directory containing this script, resizes the frame to 512x512, passes it
    to ``pred_lines``, draws every returned segment on the frame and writes
    the annotated image to ``data/frame_1_out.jpg``.

    Raises:
        FileNotFoundError: If the input image cannot be read.
        OSError: If the annotated image cannot be written.
    """
    # Resolve to an absolute directory so the script works from any cwd and
    # the joined paths never contain a doubled separator.
    current_dir = os.path.dirname(os.path.abspath(__file__))

    # Pick the device first and use it everywhere; calling .cuda()
    # unconditionally crashes on machines without a CUDA device.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # To try the tiny variant instead, swap in MobileV2_MLSD_Tiny and
    # 'mlsd_tiny_512_fp32.pth'.
    model_path = os.path.join(current_dir, 'models', 'mlsd_large_512_fp32.pth')
    model = MobileV2_MLSD_Large()
    model.load_state_dict(torch.load(model_path, map_location=device), strict=True)
    model = model.to(device).eval()

    img_fn = os.path.join(current_dir, 'data', 'frame_1.jpg')
    img = cv2.imread(img_fn)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError('could not read image: ' + img_fn)

    img = cv2.resize(img, (512, 512))
    # The frame is converted BGR -> RGB before being handed to pred_lines.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # NOTE(review): pred_lines lives in utils and is not visible here; confirm
    # it places its input tensor on the model's device, otherwise CPU-only
    # runs will still fail inside it.
    with torch.no_grad():  # inference only, no autograd graph needed
        lines = pred_lines(img, model, [512, 512], 0.1, 20)

    # Back to BGR, the channel order cv2.imwrite expects.
    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    for line in lines:
        x1, y1, x2, y2 = (int(v) for v in line[:4])
        cv2.line(img, (x1, y1), (x2, y2), (0, 200, 200), 1, cv2.LINE_AA)

    out_fn = os.path.join(current_dir, 'data', 'frame_1_out.jpg')
    if not cv2.imwrite(out_fn, img):
        # cv2.imwrite returns False on failure rather than raising.
        raise OSError('could not write image: ' + out_fn)
# Run the demo only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()