|
2 | 2 | """Classes for interacting with jobs."""
|
3 | 3 |
|
4 | 4 | import logging
|
| 5 | +from multiprocessing.sharedctypes import Value |
5 | 6 | import time
|
| 7 | +import boto3 |
6 | 8 | from datetime import datetime
|
7 | 9 | from types import SimpleNamespace
|
8 | 10 | from urllib.parse import urlencode
|
9 | 11 | from ._api_object import ApiObject
|
10 | 12 | from ._size import human_read_to_bytes
|
11 | 13 | from ._util import encode_data_uri, depth, file_to_chunks, bytes_to_chunks
|
12 | 14 | from .error import Timeout
|
13 |
| -from .models import Model |
| 15 | +from .models import Model, Models |
14 | 16 | from deprecation import deprecated
|
15 | 17 |
|
16 | 18 |
|
@@ -201,6 +203,34 @@ def __fix_single_source_job(self, sources):
|
201 | 203 | else:
|
202 | 204 | return sources
|
203 | 205 |
|
| 206 | + def check_storagegrid_endpoint(self, endpoint, bucket, access_key_id, secret_access_key): |
| 207 | + |
| 208 | + # establish session with aws sdk |
| 209 | + session = boto3.session.Session() |
| 210 | + # try to connect to storagegrid endpoint |
| 211 | + req_check=False |
| 212 | + try: |
| 213 | + s3 = session.resource(service_name='s3', endpoint_url=endpoint, verify=False, aws_access_key_id=access_key_id, aws_secret_access_key=secret_access_key) |
| 214 | + bucket_exists = s3.Bucket(bucket) in s3.buckets.all() |
| 215 | + if not bucket_exists: |
| 216 | + raise ValueError |
| 217 | + req_check=True |
| 218 | + new_endpoint=endpoint |
| 219 | + except Exception as e: |
| 220 | + if not str(endpoint).startswith("https://") and not str(endpoint).startswith("http://"): |
| 221 | + new_endpoint = "https://" + str(endpoint) |
| 222 | + self.check_storagegrid_endpoint(new_endpoint, bucket, access_key_id, secret_access_key) |
| 223 | + req_check = True |
| 224 | + elif str(endpoint).startswith("http://"): |
| 225 | + new_endpoint = "https://" + endpoint.split("http://")[-1] |
| 226 | + self.check_storagegrid_endpoint(new_endpoint, bucket, access_key_id, secret_access_key) |
| 227 | + req_check = True |
| 228 | + |
| 229 | + if not req_check: |
| 230 | + raise ValueError("Invalid endpoint or bucket name. The endpoint param should point to a valid endpoint associated with your StorageGRID account. Confirm both your endpoint and the bucket name in your input sources dictionary are correct and try again.") |
| 231 | + |
| 232 | + return new_endpoint |
| 233 | + |
204 | 234 | def submit_text(self, model, version, sources, explain=False):
|
205 | 235 | """Submits text data for a multiple source `Job`.
|
206 | 236 |
|
@@ -613,14 +643,19 @@ def submit_storagegrid(self, model, version, sources, access_key_id, secret_acce
|
613 | 643 | identifier = Model._coerce_identifier(model)
|
614 | 644 | version = str(version)
|
615 | 645 | access_key_id = str(access_key_id)
|
616 |
| - # storageGRID endpoint must begin with "https://", so this conducts a quick test |
617 |
| - if not str(endpoint).startswith("https://") and not str(endpoint).startswith("http://"): |
618 |
| - endpoint = "https://" + str(endpoint) |
619 |
| - elif str(endpoint).startswith("http://"): |
620 |
| - endpoint = "https://" + endpoint.split("http://")[-1] |
621 |
| - else: |
622 |
| - endpoint = str(endpoint) |
623 |
| - |
| 646 | + models = Models(self._api_client) |
| 647 | + sample_input_key = list(models.get_version_input_sample(identifier, version)["input"]["sources"].keys())[0] |
| 648 | + input_filename = list(models.get_version_input_sample(identifier, version)["input"]["sources"][sample_input_key].keys())[0] |
| 649 | + # validate storageGRID endpoint |
| 650 | + try: |
| 651 | + first_key = list(sources.keys())[0] |
| 652 | + if first_key == input_filename: |
| 653 | + bucket = sources[first_key]["bucket"] |
| 654 | + else: |
| 655 | + bucket = sources[first_key][input_filename]["bucket"] |
| 656 | + except Exception: |
| 657 | + raise ValueError("Invalid input sources file. Confirm your sources dictionary meets the required format (https://docs.modzy.com/docs/clientjobssubmit_netapp_storage_grid) and try again.") |
| 658 | + endpoint = self.check_storagegrid_endpoint(endpoint, bucket, access_key_id, secret_access_key) |
624 | 659 | body = {
|
625 | 660 | "model": {
|
626 | 661 | "identifier": identifier,
|
|
0 commit comments