Skip to content
This repository was archived by the owner on Dec 10, 2019. It is now read-only.

Commit a565bfa

Browse files
committed
Add RasterVisionBatchClient abstraction
This class allows us to avoid passing the job_def, job_queue, branch_name, and attempts to every method that starts a Raster Vision job.
1 parent beba57c commit a565bfa

File tree

3 files changed

+51
-41
lines changed

3 files changed

+51
-41
lines changed

rasterfoundry/api.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88

99
from .models import Project, MapToken
1010
from .exceptions import RefreshTokenException
11-
from .utils import start_raster_vision_job, upload_raster_vision_config
1211
from .settings import (
1312
RV_CPU_JOB_DEF, RV_CPU_QUEUE, DEVELOP_BRANCH, RV_CONFIG_URI_ROOT)
13+
from .utils import upload_raster_vision_config
1414

1515
SPEC_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)),
1616
'spec.yml')
@@ -156,15 +156,17 @@ def get_project_configs(self, project_ids, annotation_uris):
156156

157157
return project_configs
158158

159-
def start_prep_train_data_job(self, project_ids, annotation_uris,
160159
output_zip_uri,
161160
config_uri_root=RV_CONFIG_URI_ROOT,
162161
job_queue=RV_CPU_QUEUE,
163162
job_definition=RV_CPU_JOB_DEF,
164163
branch_name=DEVELOP_BRANCH, attempts=1):
164+
def start_prep_train_data_job(self, rv_batch_client, project_ids,
165165
"""Start a Batch job to prepare object detection training data.
166166
167167
Args:
168+
rv_batch_client: a RasterVisionBatchClient object used to start
169+
Batch jobs
168170
project_ids (list of str): ids of projects to make train data for
169171
annotation_uris (list of str): annotation URIs for projects
170172
output_zip_uri (str): URI of output zip file
@@ -188,7 +190,5 @@ def start_prep_train_data_job(self, project_ids, annotation_uris,
188190
config_uri, output_zip_uri)
189191

190192
job_name = 'prep_train_data_{}'.format(uuid.uuid1())
191-
job_id = start_raster_vision_job(
192-
job_name, command, job_queue, job_definition, branch_name,
193-
attempts=attempts)
193+
job_id = rv_batch_client.start_raster_vision_job(job_name, command)
194194
return job_id

rasterfoundry/models/project.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55
from .. import NOTEBOOK_SUPPORT
66
from ..decorators import check_notebook
77
from ..exceptions import GatewayTimeoutException
8-
from ..utils import start_raster_vision_job
9-
from ..settings import RV_CPU_JOB_DEF, RV_CPU_QUEUE, DEVELOP_BRANCH
108
from .map_token import MapToken
119

1210

@@ -176,13 +174,13 @@ def get_image_source_uris(self):
176174

177175
return source_uris
178176

179-
def start_predict_job(self, inference_graph_uri, label_map_uri,
180-
predictions_uri, job_queue=RV_CPU_QUEUE,
181-
job_definition=RV_CPU_JOB_DEF,
182-
branch_name=DEVELOP_BRANCH, attempts=1):
177+
def start_predict_job(self, rv_batch_client, inference_graph_uri,
178+
label_map_uri, predictions_uri):
183179
"""Start a Batch job to perform object detection on this project.
184180
185181
Args:
182+
rv_batch_client: instance of RasterVisionBatchClient used to start
183+
Batch jobs
186184
inference_graph_uri (str): file with exported object detection
187185
model file
188186
label_map_uri (str): file with mapping from class id to display name
@@ -203,10 +201,7 @@ def start_predict_job(self, inference_graph_uri, label_map_uri,
203201
command = 'python -m rv.run predict {} {} {} {}'.format(
204202
inference_graph_uri, label_map_uri, source_uris_str,
205203
predictions_uri)
206-
job_id = start_raster_vision_job(
207-
job_name, command, job_queue=job_queue,
208-
job_definition=job_definition, branch_name=branch_name,
209-
attempts=attempts)
204+
job_id = rv_batch_client.start_raster_vision_job(job_name, command)
210205

211206
return job_id
212207

rasterfoundry/utils.py

Lines changed: 41 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -8,36 +8,51 @@
88

99
import boto3
1010

11+
from .settings import RV_CPU_JOB_DEF, RV_CPU_QUEUE, DEVELOP_BRANCH
1112

12-
def start_raster_vision_job(job_name, command, job_queue, job_definition,
13-
branch_name, attempts=1):
14-
"""Start a raster-vision Batch job.
1513

16-
Args:
17-
job_name (str): name of the Batch job
18-
command (str): command to run inside the Docker container
19-
job_queue (str): name of the Batch job queue to run the job in
20-
job_definition (str): name of the Batch job definition
21-
branch_name (str): branch of the raster-vision repo to use
22-
attempts (int): number of attempts for the Batch job
14+
class RasterVisionBatchClient():
15+
def __init__(self, job_queue=RV_CPU_QUEUE, job_definition=RV_CPU_JOB_DEF,
16+
branch_name=DEVELOP_BRANCH, attempts=1):
17+
"""Create a Raster Vision Batch Client
2318
24-
Returns:
25-
job_id (str): job_id of job started on Batch
26-
"""
27-
batch_client = boto3.client('batch')
28-
# `run_script.sh $branch_name $command` downloads a branch of the
29-
# raster-vision repo and then runs the command.
30-
job_command = ['run_script.sh', branch_name, command]
31-
job_id = batch_client.submit_job(
32-
jobName=job_name, jobQueue=job_queue, jobDefinition=job_definition,
33-
containerOverrides={
34-
'command': job_command
35-
},
36-
retryStrategy={
37-
'attempts': attempts
38-
})['jobId']
19+
Args:
20+
job_queue (str): name of the Batch job queue to run the job in
21+
job_definition (str): name of the Batch job definition
22+
branch_name (str): branch of the raster-vision repo to use
23+
attempts (int): number of attempts for each job
24+
"""
25+
26+
self.job_queue = job_queue
27+
self.job_definition = job_definition
28+
self.branch_name = branch_name
29+
self.attempts = attempts
30+
self.batch_client = boto3.client('batch')
31+
32+
def start_raster_vision_job(self, job_name, command):
33+
"""Start a raster-vision Batch job.
34+
35+
Args:
36+
job_name (str): name of the Batch job
37+
command (str): command to run inside the Docker container
38+
39+
Returns:
40+
job_id (str): job_id of job started on Batch
41+
"""
42+
# `run_script.sh $branch_name $command` downloads a branch of the
43+
# raster-vision repo and then runs the command.
44+
job_command = ['run_script.sh', self.branch_name, command]
45+
job_id = self.batch_client.submit_job(
46+
jobName=job_name, jobQueue=self.job_queue,
47+
jobDefinition=self.job_definition,
48+
containerOverrides={
49+
'command': job_command
50+
},
51+
retryStrategy={
52+
'attempts': self.attempts
53+
})['jobId']
3954

40-
return job_id
55+
return job_id
4156

4257

4358
def upload_raster_vision_config(config_dict, config_uri_root):

0 commit comments

Comments
 (0)