8
8
9
9
from .models import Project , MapToken
10
10
from .exceptions import RefreshTokenException
11
- from .settings import (
12
- RV_CPU_JOB_DEF , RV_CPU_QUEUE , DEVELOP_BRANCH , RV_CONFIG_URI_ROOT )
13
11
from .utils import upload_raster_vision_config
12
+ from .settings import RV_PROJ_CONFIG_DIR_URI
14
13
15
14
SPEC_PATH = os .path .join (os .path .dirname (os .path .realpath (__file__ )),
16
15
'spec.yml' )
@@ -156,12 +155,12 @@ def get_project_configs(self, project_ids, annotation_uris):
156
155
157
156
return project_configs
158
157
159
- output_zip_uri ,
160
- config_uri_root = RV_CONFIG_URI_ROOT ,
161
- job_queue = RV_CPU_QUEUE ,
162
- job_definition = RV_CPU_JOB_DEF ,
163
- branch_name = DEVELOP_BRANCH , attempts = 1 ):
164
158
def start_prep_train_data_job (self , rv_batch_client , project_ids ,
159
+ annotation_uris ,
160
+ output_zip_uri , label_map_uri ,
161
+ proj_config_dir_uri = RV_PROJ_CONFIG_DIR_URI ,
162
+ min_area = None , single_label = None ,
163
+ no_partial = True , channel_order = None ):
165
164
"""Start a Batch job to prepare object detection training data.
166
165
167
166
Args:
@@ -170,24 +169,41 @@ def start_prep_train_data_job(self, rv_batch_client, project_ids,
170
169
project_ids (list of str): ids of projects to make train data for
171
170
annotation_uris (list of str): annotation URIs for projects
172
171
output_zip_uri (str): URI of output zip file
173
- config_uri_root (str): The root of generated URIs for config files
174
- job_queue (str): name of the Batch job queue to run the job in
175
- job_definition (str): name of the Batch job definition
176
- branch_name (str): branch of the raster-vision repo to use
177
- attempts (int): number of attempts for the Batch job
172
+ label_map_uri (str): URI of output label map
173
+ proj_config_dir_uri (str): The root of generated URIs for config
174
+ files
175
+ min_area (float): minimum area of bounding boxes to include
176
+ single_label (str): Convert all labels to this label
177
+ no_partial (bool): Black out partially visible objects
178
+ channel_order: list of length 3 with GeoTIFF channel indices to
179
+ map to RGB.
178
180
179
181
Returns:
180
182
job_id (str): job_id of job started on Batch
181
183
"""
182
184
project_configs = self .get_project_configs (
183
185
project_ids , annotation_uris )
184
186
config_uri = upload_raster_vision_config (
185
- project_configs , config_uri_root )
186
-
187
- command = ('python -m rv.run prep_train_data --debug ' +
188
- '--chip-size 300 --num-neg-chips 100 ' +
189
- '--max-attempts 500 {} {}' ).format (
190
- config_uri , output_zip_uri )
187
+ project_configs , proj_config_dir_uri )
188
+
189
+ base_command = \
190
+ 'python -m rv.run prep_train_data --debug --chip-size 300 '
191
+ min_area_opt = ('--min-area {} ' .format (min_area )
192
+ if min_area is not None else '' )
193
+ single_label_opt = ('--single-label {} ' .format (single_label )
194
+ if single_label is not None else '' )
195
+ no_partial_opt = '--no-partial ' if no_partial else ''
196
+
197
+ channel_order_opt = ''
198
+ if channel_order is not None :
199
+ channel_order_str = ' ' .join ([
200
+ str (channel_ind ) for channel_ind in channel_order ])
201
+ channel_order_opt = ('--channel-order {} '
202
+ .format (channel_order_str ))
203
+
204
+ command = (base_command + min_area_opt + single_label_opt +
205
+ no_partial_opt + channel_order_opt + '{} {} {}' )
206
+ command = command .format (config_uri , output_zip_uri , label_map_uri )
191
207
192
208
job_name = 'prep_train_data_{}' .format (uuid .uuid1 ())
193
209
job_id = rv_batch_client .start_raster_vision_job (job_name , command )
0 commit comments