Skip to content

Commit 25dea80

Browse files
author
Ruilong Li
committed
update before release
1 parent 709b8cb commit 25dea80

File tree

3 files changed

+31
-59
lines changed

3 files changed

+31
-59
lines changed

demo.py

+21-8
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,26 @@
33
import torch
44
import numpy as np
55
import cv2
6+
import glob
67
import streamer_pytorch as streamer
78

9+
810
parser = argparse.ArgumentParser(description='.')
911
parser.add_argument(
10-
'--camera', action="store_true")
12+
'--camera', action="store_true", help="whether to use webcam.")
13+
parser.add_argument(
14+
'--images', default="", nargs="*", help="paths of image.")
1115
parser.add_argument(
12-
'--images', default="", nargs="*")
16+
'--image_folder', default="", help="path of image folder.")
1317
parser.add_argument(
14-
'--videos', default="", nargs="*")
18+
'--videos', default="", nargs="*", help="paths of video.")
1519
parser.add_argument(
16-
'--loop', action="store_true")
20+
'--loop', action="store_true", help="whether to repeat images/video.")
1721
parser.add_argument(
18-
'--vis', action="store_true")
22+
'--vis', action="store_true", help="whether to visualize.")
1923
args = parser.parse_args()
2024

25+
2126
def visulization(data):
2227
window = data[0].numpy()
2328
window = window.transpose(1, 2, 0)
@@ -29,14 +34,21 @@ def visulization(data):
2934
cv2.imshow('window', window)
3035
cv2.waitKey(1)
3136

37+
3238
if args.camera:
33-
data_stream = streamer.CaptureStreamer()
39+
data_stream = streamer.CaptureStreamer(pad=False)
3440
elif len(args.videos) > 0:
3541
data_stream = streamer.VideoListStreamer(
36-
args.videos * (100 if args.loop else 1))
42+
args.videos * (10 if args.loop else 1))
3743
elif len(args.images) > 0:
3844
data_stream = streamer.ImageListStreamer(
39-
args.images * (100 if args.loop else 1))
45+
args.images * (10000 if args.loop else 1))
46+
elif args.image_folder is not None:
47+
images = sorted(glob.glob(args.image_folder+'/*.jpg'))
48+
images += sorted(glob.glob(args.image_folder+'/*.png'))
49+
data_stream = streamer.ImageListStreamer(
50+
images * (10 if args.loop else 1))
51+
4052

4153
loader = torch.utils.data.DataLoader(
4254
data_stream,
@@ -45,6 +57,7 @@ def visulization(data):
4557
pin_memory=False,
4658
)
4759

60+
4861
try:
4962
for data in tqdm.tqdm(loader):
5063
if args.vis:

setup.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44

55
setuptools.setup(
66
name='streamer_pytorch',
7-
url='https://github.com/liruilong940607/streamer_pytorch',
7+
url='https://github.com/Project-Splinter/streamer_pytorch',
88
description='Pytorch based data streamer. (Capture, Video & Image).',
9-
version='0.0.1',
9+
version='0.0.2',
1010
author='Ruilong Li',
1111
author_email='[email protected]',
1212
license='MIT License',

streamer_pytorch/streamer.py

+8-49
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ def get_affine_matrix(center, translate, scale):
4343

4444

4545
class BaseStreamer():
46+
"""This streamer will return images at 512x512 size.
47+
"""
4648
def __init__(self,
4749
width=512, height=512, pad=True,
4850
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
@@ -76,6 +78,8 @@ def __len__(self):
7678

7779

7880
class CaptureStreamer(BaseStreamer):
81+
"""This streamer takes webcam as input.
82+
"""
7983
def __init__(self, id=0, width=512, height=512, pad=True, **kwargs):
8084
super().__init__(width, height, pad, **kwargs)
8185
self.capture = cv2.VideoCapture(id)
@@ -94,6 +98,8 @@ def __del__(self):
9498

9599

96100
class VideoListStreamer(BaseStreamer):
101+
"""This streamer takes a list of video files as input.
102+
"""
97103
def __init__(self, files, width=512, height=512, pad=True, **kwargs):
98104
super().__init__(width, height, pad, **kwargs)
99105
self.files = files
@@ -115,6 +121,8 @@ def __del__(self):
115121

116122

117123
class ImageListStreamer(BaseStreamer):
124+
"""This streamer takes a list of image files as input.
125+
"""
118126
def __init__(self, files, width=512, height=512, pad=True, **kwargs):
119127
super().__init__(width, height, pad, **kwargs)
120128
self.files = files
@@ -129,53 +137,4 @@ def __len__(self):
129137
return len(self.files)
130138

131139

132-
if __name__ == "__main__":
133-
import tqdm
134-
import argparse
135-
136-
parser = argparse.ArgumentParser(description='.')
137-
parser.add_argument(
138-
'--camera', action="store_true")
139-
parser.add_argument(
140-
'--images', default="", nargs="*")
141-
parser.add_argument(
142-
'--videos', default="", nargs="*")
143-
parser.add_argument(
144-
'--loop', action="store_true")
145-
args = parser.parse_args()
146-
147-
def visulization(data):
148-
window = data[0].numpy()
149-
window = window.transpose(1, 2, 0)
150-
window = (window * 0.5 + 0.5) * 255.0
151-
window = np.uint8(window)
152-
window = cv2.cvtColor(window, cv2.COLOR_BGR2RGB)
153-
window = cv2.resize(window, (0, 0), fx=2, fy=2)
154-
155-
cv2.imshow('window', window)
156-
cv2.waitKey(1)
157-
158-
if args.camera:
159-
data_stream = CaptureStreamer()
160-
elif len(args.videos) > 0:
161-
data_stream = VideoListStreamer(args.videos * (100 if args.loop else 1))
162-
elif len(args.images) > 0:
163-
data_stream = ImageListStreamer(args.images * (100 if args.loop else 1))
164-
165-
loader = torch.utils.data.DataLoader(
166-
data_stream,
167-
batch_size=1,
168-
num_workers=1,
169-
pin_memory=False,
170-
)
171-
172-
try:
173-
for data in tqdm.tqdm(loader):
174-
visulization(data)
175-
pass
176-
except Exception as e:
177-
print (e)
178-
del data_stream
179-
180-
181140

0 commit comments

Comments
 (0)