From f4442aef4ebc5b9f99e221e369f9ac4fb20c8a55 Mon Sep 17 00:00:00 2001
From: HakurrrPunk <82894964+farzanekram07@users.noreply.github.com>
Date: Tue, 23 May 2023 09:21:22 +0530
Subject: [PATCH] faiss_index_bq_dataset.py: drop unused import, fix argv handling, make log level configurable

1. Remove unused imports: the code imports os but never uses it, so that import can be dropped. Note that urlsplit is not unused (it is still called further down to extract the output bucket name), and apache_beam is a third-party package installed from PyPI, so both of those imports stay exactly as they are; a relative import like from .apache_beam.options.pipeline_options import PipelineOptions would raise ImportError at runtime.

2. Handle the case when argv is not provided: run_pipeline used a mutable default argument (argv=[]) that is later mutated through argv_with_extras.extend(...), a classic Python pitfall. Default to argv=None and normalize inside the function, and keep forwarding the caller's argv to parse_d6w_config; hard-coding parse_d6w_config(argv=None) would silently discard every command-line argument (see the sketch after this list).

3. Make the logging level configurable: instead of hard-coding logging.INFO, accept a --log_level command-line argument in the __main__ block, pass it into run_pipeline, and apply it there with logging.getLogger().setLevel(). An environment variable would work just as well (see the sketch after this list).
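
Below is a minimal sketch of points 2 and 3 combined, assuming the Beam pipeline body itself is unchanged. The LOG_LEVEL environment-variable fallback is an illustrative assumption, not something this patch adds:

    import argparse
    import logging
    import os
    import sys


    def run_pipeline(argv=None, log_level=logging.INFO):
      # Normalize here instead of sharing one mutable default list across calls.
      argv = argv if argv is not None else []
      # setLevel accepts level names ("DEBUG") as well as constants (logging.DEBUG).
      logging.getLogger().setLevel(log_level)
      ...  # build and run the Beam pipeline


    if __name__ == "__main__":
      parser = argparse.ArgumentParser()
      # The LOG_LEVEL env var fallback is an assumption for illustration only.
      parser.add_argument("--log_level", default=os.environ.get("LOG_LEVEL", "INFO"))
      # parse_known_args leaves unrecognized flags (and sys.argv[0]) in pipeline_args,
      # preserving the old behavior of passing sys.argv straight to run_pipeline.
      args, pipeline_args = parser.parse_known_args(sys.argv)
      run_pipeline(pipeline_args, log_level=args.log_level)

Invoked as, e.g., python faiss_index_bq_dataset.py --log_level=DEBUG followed by the usual pipeline arguments.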
---
 .../python/dataflow/faiss_index_bq_dataset.py | 14 ++++++++++----
 1 file changed, 10 insertions(+), 4 deletions(-)

diff --git a/ann/src/main/python/dataflow/faiss_index_bq_dataset.py b/ann/src/main/python/dataflow/faiss_index_bq_dataset.py
index dd45070db..dd17ecfa0 100644
--- a/ann/src/main/python/dataflow/faiss_index_bq_dataset.py
+++ b/ann/src/main/python/dataflow/faiss_index_bq_dataset.py
@@ -1,12 +1,11 @@
 import argparse
 import logging
-import os
 import pkgutil
 import sys
 from urllib.parse import urlsplit
 
 import apache_beam as beam
 from apache_beam.options.pipeline_options import PipelineOptions
 import faiss
 
 
@@ -94,8 +93,10 @@ def parse_metric(config):
     raise Exception(f"Unknown metric: {metric_str}")
 
 
-def run_pipeline(argv=[]):
+def run_pipeline(argv=None, log_level=logging.INFO):
+  # argv=[] was a mutable default argument; it is mutated below via extend().
+  argv = argv if argv is not None else []
   config = parse_d6w_config(argv)
   argv_with_extras = argv
   if config["gpu"]:
     argv_with_extras.extend(["--experiments", "use_runner_v2"])
@@ -108,7 +109,9 @@ def run_pipeline(argv=[]):
         "gcr.io/twttr-recos-ml-prod/dataflow-gpu/beam2_39_0_py3_7",
       ]
     )
 
+  # Apply the requested log level before constructing pipeline options.
+  logging.getLogger().setLevel(log_level)
   options = PipelineOptions(argv_with_extras)
   output_bucket_name = urlsplit(config["output_location"]).netloc
 
@@ -228,5 +231,8 @@ def extract_output(self, rows):
 
 
 if __name__ == "__main__":
-  logging.getLogger().setLevel(logging.INFO)
-  run_pipeline(sys.argv)
+  # Parse --log_level here and hand every remaining argument to the pipeline.
+  parser = argparse.ArgumentParser()
+  parser.add_argument("--log_level", default="INFO", help="Root logging level name")
+  args, pipeline_args = parser.parse_known_args(sys.argv)
+  run_pipeline(pipeline_args, log_level=args.log_level)