|
13 | 13 | import requests
|
14 | 14 |
|
15 | 15 | import test_utils
|
| 16 | +from concurrent import futures |
16 | 17 |
|
17 | 18 | REPO_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../")
|
18 | 19 | data_file_mnist = os.path.join(REPO_ROOT, "examples/image_classifier/mnist/test_data/1.png")
|
@@ -137,17 +138,22 @@ def test_batch_input(set_custom_handler, handler_name):
|
137 | 138 | """
|
138 | 139 | Tests pytorch profiler integration with batch inference
|
139 | 140 | """
|
| 141 | + |
140 | 142 | CUSTOM_PATH = "/tmp/output/resnet-152-batch"
|
| 143 | + |
141 | 144 | if os.path.exists(CUSTOM_PATH):
|
142 | 145 | shutil.rmtree(CUSTOM_PATH)
|
143 | 146 | assert os.path.exists(data_file_resnet)
|
144 | 147 |
|
145 |
| - cmd = ["bash", os.path.join(profiler_utils, "resnet_batch.sh")] |
146 |
| - |
147 |
| - proc = subprocess.run(cmd, stdout=subprocess.PIPE, check=True, timeout=1000) |
| 148 | + def invoke_batch_input(): |
| 149 | + data = open(data_file_resnet, "rb") |
| 150 | + response = requests.post("{}/predictions/resnet152".format(TF_INFERENCE_API), data) |
| 151 | + assert response.status_code == 200 |
| 152 | + assert "tiger_cat" in json.loads(response.content) |
148 | 153 |
|
149 |
| - assert "tiger_cat" in proc.stdout.decode("utf-8") |
150 |
| - assert "Labrador_retriever" in proc.stdout.decode("utf-8") |
| 154 | + with futures.ThreadPoolExecutor(2) as executor: |
| 155 | + for _ in range(2): |
| 156 | + executor.submit(invoke_batch_input) |
151 | 157 |
|
152 | 158 | assert len(glob.glob("{}/*.pt.trace.json".format(CUSTOM_PATH))) == 1
|
153 | 159 | test_utils.unregister_model("resnet152")
|
|
0 commit comments