# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
import itertools
import json
import multiprocessing as mp
import os
import subprocess
import sys
import time
from typing import Any, Callable, Dict, Iterable, Iterator, Optional
import uuid

import pytest

# Default options.
UUID = uuid.uuid4().hex[0:6]
PROJECT = os.environ["GOOGLE_CLOUD_PROJECT"]
REGION = "us-west1"
ZONE = "us-west1-b"

RETRY_MAX_TIME = 5 * 60  # 5 minutes in seconds


@dataclass
class Utils:
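    """Helpers to create and clean up Google Cloud resources in tests.

    Each resource helper below is a generator: it creates the resource,
    yields its name, and deletes the resource when resumed. This lets a
    conftest wrap them as pytest fixtures with `yield from` (see the
    sketch at the end of this file).
    """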
    uuid: str = UUID
    project: str = PROJECT
    region: str = REGION
    zone: str = ZONE

    @staticmethod
    def storage_bucket(bucket_name: str) -> Iterator[str]:
        from google.cloud import storage

        storage_client = storage.Client()
        bucket_unique_name = f"{bucket_name}-{UUID}"
        bucket = storage_client.create_bucket(bucket_unique_name)

        print(f"storage_bucket: {bucket_unique_name}")
        yield bucket_unique_name

        bucket.delete(force=True)

    @staticmethod
    def bigquery_dataset(dataset_name: str, project: str = PROJECT) -> Iterator[str]:
        from google.cloud import bigquery

        bigquery_client = bigquery.Client()
        dataset = bigquery_client.create_dataset(
            bigquery.Dataset(f"{project}.{dataset_name.replace('-', '_')}_{UUID}")
        )

        print(f"bigquery_dataset: {dataset.full_dataset_id}")
        yield dataset.full_dataset_id

        bigquery_client.delete_dataset(
            dataset.full_dataset_id.replace(":", "."), delete_contents=True
        )

    @staticmethod
    def bigquery_query(query: str) -> Iterable[Dict[str, Any]]:
        from google.cloud import bigquery

        bigquery_client = bigquery.Client()
        for row in bigquery_client.query(query):
            yield dict(row)

    @staticmethod
    def pubsub_topic(topic_name: str, project: str = PROJECT) -> Iterator[str]:
        from google.cloud import pubsub

        publisher_client = pubsub.PublisherClient()
        topic_path = publisher_client.topic_path(project, f"{topic_name}-{UUID}")
        topic = publisher_client.create_topic(topic_path)

        print(f"pubsub_topic: {topic.name}")
        yield topic.name

        # Due to the pinned library dependencies in apache-beam, the client
        # library throws an error upon deletion.
        # We use gcloud as a workaround. See also:
        # https://github.com/GoogleCloudPlatform/python-docs-samples/issues/4492
        cmd = ["gcloud", "pubsub", "--project", project, "topics", "delete", topic.name]
        print(cmd)
        subprocess.run(cmd, check=True)

    @staticmethod
    def pubsub_subscription(
        topic_path: str,
        subscription_name: str,
        project: str = PROJECT,
    ) -> Iterator[str]:
        from google.cloud import pubsub

        subscriber = pubsub.SubscriberClient()
        subscription_path = subscriber.subscription_path(
            project, f"{subscription_name}-{UUID}"
        )
        subscription = subscriber.create_subscription(subscription_path, topic_path)

        print(f"pubsub_subscription: {subscription.name}")
        yield subscription.name

        # Due to the pinned library dependencies in apache-beam, the client
        # library throws an error upon deletion.
        # We use gcloud as a workaround. See also:
        # https://github.com/GoogleCloudPlatform/python-docs-samples/issues/4492
        cmd = [
            "gcloud",
            "pubsub",
            "--project",
            project,
            "subscriptions",
            "delete",
            subscription.name,
        ]
        print(cmd)
        subprocess.run(cmd, check=True)

    @staticmethod
    def pubsub_publisher(
        topic_path: str,
        new_msg: Callable[[int], str] = lambda i: json.dumps(
            {"id": i, "content": f"message {i}"}
        ),
        sleep_sec: int = 1,
    ) -> Iterator[bool]:
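        """Starts a background process that publishes one message every
        `sleep_sec` seconds, yields whether the process started, and
        terminates the process when resumed."""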
        from google.cloud import pubsub

        def _infinite_publish_job() -> None:
            publisher_client = pubsub.PublisherClient()
            for i in itertools.count():
                msg = new_msg(i)
                publisher_client.publish(topic_path, msg.encode("utf-8")).result()
                time.sleep(sleep_sec)

        # Start a background process to do the publishing.
        print(f"Starting publisher on {topic_path}")
        p = mp.Process(target=_infinite_publish_job)
        p.start()

        yield p.is_alive()

        # For cleanup, terminate the background process and wait for it.
        print("Stopping publisher")
        p.terminate()
        p.join()

    @staticmethod
    def container_image(
        image_path: str,
        project: str = PROJECT,
        tag: str = "latest",
    ) -> Iterator[str]:
| 166 | + image_name = f"gcr.io/{project}/{image_path}-{UUID}:{tag}" |
| 167 | + cmd = ["gcloud", "auth", "configure-docker"] |
| 168 | + print(cmd) |
| 169 | + subprocess.run(cmd, check=True) |
| 170 | + cmd = [ |
| 171 | + "gcloud", |
| 172 | + "builds", |
| 173 | + "submit", |
| 174 | + f"--project={project}", |
| 175 | + f"--tag={image_name}", |
| 176 | + ".", |
| 177 | + ] |
| 178 | + print(cmd) |
| 179 | + subprocess.run(cmd, check=True) |
| 180 | + |
| 181 | + print(f"container_image: {image_name}") |
| 182 | + yield image_name |
| 183 | + |
| 184 | + cmd = [ |
| 185 | + "gcloud", |
| 186 | + "container", |
| 187 | + "images", |
| 188 | + "delete", |
| 189 | + image_name, |
| 190 | + f"--project={project}", |
| 191 | + "--quiet", |
| 192 | + ] |
| 193 | + print(cmd) |
| 194 | + subprocess.run(cmd, check=True) |
| 195 | + |
    @staticmethod
    def dataflow_job_id_from_job_name(
        job_name: str,
        project: str = PROJECT,
    ) -> Optional[str]:
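        """Returns the ID of the active Dataflow job named `job_name`, or
        None if it is not found among the 50 most recent active jobs."""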
        from googleapiclient.discovery import build

        dataflow = build("dataflow", "v1b3")

        # Only return the 50 most recent results - our job is likely to be in here.
        # If the job is not found, first try increasing this number.
        # For more info see:
        # https://cloud.google.com/dataflow/docs/reference/rest/v1b3/projects.jobs/list
        jobs_request = (
            dataflow.projects()
            .jobs()
            .list(
                projectId=project,
                filter="ACTIVE",
                pageSize=50,
            )
        )
        response = jobs_request.execute()

        # Search for the job in the list that has our name (names are unique).
        for job in response.get("jobs", []):
            if job["name"] == job_name:
                return job["id"]
        return None

    @staticmethod
    def dataflow_jobs_wait(
        job_id: str,
        project: str = PROJECT,
        status: str = "JOB_STATE_RUNNING",
    ) -> bool:
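        """Polls the job every 30 seconds, for up to 10 minutes, until it
        reaches `status`. Returns True on success, False on timeout."""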
        from googleapiclient.discovery import build

        dataflow = build("dataflow", "v1b3")

        sleep_time_seconds = 30
        max_sleep_time = 10 * 60

        print(f"Waiting for Dataflow job ID: {job_id} (until status {status})")
        for _ in range(0, max_sleep_time, sleep_time_seconds):
            try:
                # For more info see:
                # https://cloud.google.com/dataflow/docs/reference/rest/v1b3/projects.jobs/get
                jobs_request = (
                    dataflow.projects()
                    .jobs()
                    .get(
                        projectId=project,
                        jobId=job_id,
                        view="JOB_VIEW_SUMMARY",
                    )
                )
                response = jobs_request.execute()
                print(response)
                if response["currentState"] == status:
                    return True
            except Exception:
                # Ignore transient errors and keep polling until timeout.
                pass
            time.sleep(sleep_time_seconds)
        return False

    @staticmethod
    def dataflow_jobs_cancel_by_job_id(
        job_id: str, project: str = PROJECT, region: str = REGION
    ) -> None:
        print(f"Canceling Dataflow job ID: {job_id}")
        # We get an error using the googleapiclient.discovery APIs, probably
        # due to incompatible dependencies with apache-beam.
        # We use gcloud instead to cancel the job.
        # https://cloud.google.com/sdk/gcloud/reference/dataflow/jobs/cancel
        cmd = [
            "gcloud",
            f"--project={project}",
            "dataflow",
            "jobs",
            "cancel",
            job_id,
            f"--region={region}",
        ]
        subprocess.run(cmd, check=True)

    @staticmethod
    def dataflow_jobs_cancel_by_job_name(
        job_name: str, project: str = PROJECT, region: str = REGION
    ) -> None:
        # To cancel a Dataflow job, we need its ID, not its name.
        # If no job with that name is found, job_id will be None.
        job_id = Utils.dataflow_job_id_from_job_name(job_name, project)
        if job_id is not None:
            Utils.dataflow_jobs_cancel_by_job_id(job_id, project, region)

    @staticmethod
    def dataflow_flex_template_build(
        bucket_name: str,
        template_image: str,
        metadata_file: str,
        project: str = PROJECT,
        template_file: str = "template.json",
    ) -> Iterator[str]:
        # https://cloud.google.com/sdk/gcloud/reference/dataflow/flex-template/build
        template_gcs_path = f"gs://{bucket_name}/{template_file}"
        cmd = [
            "gcloud",
            "dataflow",
            "flex-template",
            "build",
            template_gcs_path,
            f"--project={project}",
            f"--image={template_image}",
            "--sdk-language=PYTHON",
            f"--metadata-file={metadata_file}",
        ]
        print(cmd)
        subprocess.run(cmd, check=True)

        print(f"dataflow_flex_template_build: {template_gcs_path}")
        yield template_gcs_path
        # The template file gets deleted when we delete the bucket.

    @staticmethod
    def dataflow_flex_template_run(
        job_name: str,
        template_path: str,
        bucket_name: str,
        parameters: Optional[Dict[str, str]] = None,
        project: str = PROJECT,
        region: str = REGION,
    ) -> str:
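        """Launches the Flex Template at `template_path` under a unique job
        name and returns the new job's ID, parsed from the gcloud output."""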
        import yaml

        # https://cloud.google.com/sdk/gcloud/reference/dataflow/flex-template/run
        unique_job_name = f"{job_name}-{UUID}"
        print(f"dataflow_job_name: {unique_job_name}")
        cmd = [
            "gcloud",
            "dataflow",
            "flex-template",
            "run",
            unique_job_name,
            f"--template-file-gcs-location={template_path}",
            f"--project={project}",
            f"--region={region}",
        ] + [
            f"--parameters={name}={value}"
            for name, value in {
                **(parameters or {}),
                "temp_location": f"gs://{bucket_name}/temp",
            }.items()
        ]
        print(cmd)
        try:
            # The `capture_output` option was added in Python 3.7, so we must
            # pass the `stdout` and `stderr` options explicitly to support 3.6.
            # https://docs.python.org/3/library/subprocess.html#subprocess.run
            p = subprocess.run(
                cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE
            )
            stdout = p.stdout.decode("utf-8")
            stderr = p.stderr.decode("utf-8")
            print(f"Launched Dataflow Flex Template job: {unique_job_name}")
        except subprocess.CalledProcessError as e:
            print(e, file=sys.stderr)
            stdout = e.stdout.decode("utf-8")
            stderr = e.stderr.decode("utf-8")
        finally:
            print("--- stderr ---")
            print(stderr)
            print("--- stdout ---")
            print(stdout)
            print("--- end ---")
        return yaml.safe_load(stdout)["job"]["id"]


@pytest.fixture(scope="session")
def utils() -> Utils:
    print(f"Test unique identifier: {UUID}")
    subprocess.run(["gcloud", "version"])
    return Utils()
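

# A minimal sketch (illustrative only, not used by the samples themselves) of
# how a test module can consume these helpers: each resource method above is a
# generator, so a pytest fixture can delegate to it with `yield from` to get
# both setup and teardown. The bucket name here is a made-up example.
@pytest.fixture(scope="session")
def example_bucket_name(utils: Utils) -> Iterator[str]:
    yield from utils.storage_bucket("dataflow-example-tests")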