-
Notifications
You must be signed in to change notification settings - Fork 236
/
Copy pathserver.py
705 lines (609 loc) · 22.8 KB
/
server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
import asyncio
import json
import logging
import time
from contextlib import asynccontextmanager
from io import BytesIO
from pathlib import Path
import cv2
import fire
import matplotlib.pyplot as plt
import numpy as np
import requests
import torch
import torch._dynamo.config
import torch._inductor.config
import uvicorn
from compile_export_utils import (
export_model,
load_exported_model,
set_fast,
set_furious,
)
from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from torch._inductor import config as inductorconfig
from torchao._models.utils import (
get_arch_name,
write_json_result_local,
write_json_result_ossci,
)
inductorconfig.triton.unique_kernel_names = True
inductorconfig.coordinate_descent_tuning = True
inductorconfig.coordinate_descent_check_all_directions = True
inductorconfig.allow_buffer_reuse = False
# torch._dynamo.config.capture_dynamic_output_shape_ops = True
torch._dynamo.config.capture_dynamic_output_shape_ops = True
def download_file(url, download_dir, timeout=60):
    """Download ``url`` into ``download_dir``, named after the URL's last path segment.

    The body is streamed in 8 KiB chunks into a temporary ``<name>.part`` file.
    That file is renamed to its final name only after the whole body has been
    written. An interrupted download therefore never leaves a truncated file
    under the final name. This matters because ``model_type_to_paths`` treats
    an existing checkpoint file as complete.

    Args:
        url: HTTP(S) URL to fetch.
        download_dir: Target directory; created (with parents) if missing.
        timeout: Forwarded to ``requests.get`` as its ``timeout`` so a stalled
            connection raises instead of hanging forever.

    Returns:
        pathlib.Path of the downloaded file.

    Raises:
        requests.HTTPError: The server answered with an error status.
    """
    from urllib.parse import urlparse

    download_dir = Path(download_dir)
    download_dir.mkdir(parents=True, exist_ok=True)
    # Take the name from the URL *path* only, so a query string or fragment
    # does not end up in the file name.
    file_name = urlparse(url).path.split("/")[-1]
    file_path = download_dir / file_name
    tmp_path = download_dir / (file_name + ".part")
    try:
        # Using the response as a context manager releases the connection
        # even if writing fails part-way.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()  # Raise an error for bad responses
            print(f"Downloading '{file_name}' to '{download_dir}'")
            with open(tmp_path, "wb") as file:
                for chunk in response.iter_content(chunk_size=8192):
                    file.write(chunk)
        tmp_path.replace(file_path)
    except BaseException:
        # Clean up the partial file on any failure (including Ctrl-C), then
        # propagate the original error.
        if tmp_path.exists():
            tmp_path.unlink()
        raise
    print(f"Downloaded '{file_name}' to '{download_dir}'")
    return file_path
def example_shapes():
    """Return the first list of benchmark image shapes as (H, W, 3) tuples.

    ``main`` uses each tuple as the ``size`` of a random uint8 image. A new
    list is built on every call.
    """
    height_width = (
        (848, 480), (720, 1280), (848, 480), (1280, 720),
        (480, 848), (1080, 1920), (1280, 720), (1280, 720),
        (720, 1280), (848, 480), (480, 848), (864, 480),
        (1920, 1080), (1920, 1080), (1280, 720), (1232, 672),
        (848, 480), (848, 480), (1920, 1080), (1080, 1920),
        (480, 848), (848, 480), (480, 848), (480, 848),
        (720, 1280), (720, 1280), (900, 720), (848, 480),
        (864, 480), (360, 640), (360, 640), (864, 480),
    )
    return [(h, w, 3) for h, w in height_width]
def example_shapes_2():
    """Return the second list of benchmark image shapes as (H, W, 3) tuples.

    ``main`` uses each tuple as the ``size`` of a random uint8 image. A new
    list is built on every call.
    """
    height_width = (
        (1080, 1920), (1920, 1080), (1920, 1080), (1080, 1920),
        (848, 480), (864, 480), (720, 1280), (864, 480),
        (848, 480), (848, 480), (848, 480), (848, 480),
        (720, 1280), (864, 480), (480, 848), (1280, 720),
        (720, 1280), (1080, 1920), (1080, 1920), (1280, 720),
        (1080, 1920), (1080, 1920), (720, 1280), (720, 1280),
        (1280, 720), (360, 640), (864, 480), (1920, 1080),
        (1080, 1920), (1920, 1080), (1920, 1080), (1080, 1920),
    )
    return [(h, w, 3) for h, w in height_width]
# torch.set_float32_matmul_precision('high')
def iou(mask1, mask2):
    """Return intersection-over-union of two 2-D masks as a tensor.

    Both masks are reduced over their two (last) dimensions. Non-zero entries
    count as "inside". If both masks are empty, the result is 0/0.
    """
    assert mask1.dim() == 2
    assert mask2.dim() == 2
    spatial_dims = (-1, -2)
    overlap = torch.logical_and(mask1, mask2).sum(dim=spatial_dims)
    combined = torch.logical_or(mask1, mask2).sum(dim=spatial_dims)
    return overlap / combined
def show_anns(anns, rle_to_mask, sort_by_area=True, seed=None):
    """Paint semi-transparent coloured overlays for ``anns`` on the current axes.

    Args:
        anns: Sequence of annotation dicts. Each needs ``"segmentation"``
            (decoded with ``rle_to_mask``) and, when ``sort_by_area`` is set,
            ``"area"``.
        rle_to_mask: Callable that turns an annotation's ``"segmentation"``
            into a 2-D mask usable as a boolean index into an (H, W, 4) array.
        sort_by_area: Draw the largest-area annotations first, so smaller ones
            are painted over them where they overlap.
        seed: Seed for the overlay colours; ``None`` gives unseeded colours.

    Returns:
        ``torch.stack`` of the decoded masks in drawing order, or ``None``
        when ``anns`` is empty.

    The caller's annotation dicts are left unmodified. NumPy's global random
    state is not reseeded.
    """
    if len(anns) == 0:
        return None
    if sort_by_area:
        ordered = sorted(anns, key=lambda ann: ann["area"], reverse=True)
    else:
        ordered = list(anns)
    ax = plt.gca()
    ax.set_autoscale_on(False)
    # Decode into a local list instead of overwriting ann["segmentation"]:
    # the caller keeps its RLE data, and calling this twice on the same
    # annotations no longer feeds an already-decoded mask to rle_to_mask.
    decoded = [rle_to_mask(ann["segmentation"]) for ann in ordered]
    img = np.ones((decoded[0].shape[0], decoded[0].shape[1], 4))
    img[:, :, 3] = 0  # fully transparent until a mask paints a pixel
    # A private RandomState draws the same colour sequence as
    # np.random.seed(seed) + np.random.random(3), without clobbering the
    # process-global generator.
    rng = np.random.RandomState(seed)
    ms = []
    for m in decoded:
        ms.append(torch.as_tensor(m))
        color_mask = np.concatenate([rng.random_sample(3), [0.35]])
        img[m] = color_mask
    ax.imshow(img)
    return torch.stack(ms)
def profiler_runner(path, fn, *args, **kwargs):
    """Call ``fn(*args, **kwargs)`` under ``torch.profiler`` and save a Chrome trace.

    CPU and CUDA activities are profiled with shape recording enabled. The
    trace is written to ``path`` after ``fn`` returns. ``fn``'s return value
    is passed through.
    """
    activity = torch.profiler.ProfilerActivity
    profiler = torch.profiler.profile(
        activities=[activity.CPU, activity.CUDA],
        record_shapes=True,
    )
    with profiler:
        outcome = fn(*args, **kwargs)
    profiler.export_chrome_trace(path)
    return outcome
def memory_runner(path, fn, *args, **kwargs):
    """Call ``fn(*args, **kwargs)`` with CUDA allocator history recording on.

    The resulting ``torch.cuda.memory._snapshot()`` is pickled to ``path``.
    Recording is switched off again before returning, including when ``fn``
    raises. Previously it stayed enabled, and kept tracing every allocation,
    for the rest of the process.

    Returns:
        Whatever ``fn`` returns.
    """
    import pickle

    print("Start memory recording")
    torch.cuda.synchronize()
    torch.cuda.memory._record_memory_history(
        True, trace_alloc_max_entries=100000, trace_alloc_record_context=True
    )
    try:
        result = fn(*args, **kwargs)
        torch.cuda.synchronize()
        # Take the snapshot while recording is still active.
        snapshot = torch.cuda.memory._snapshot()
    finally:
        torch.cuda.memory._record_memory_history(False)
    print("Finish memory recording")
    with open(path, "wb") as f:
        pickle.dump(snapshot, f)
    # Use to convert pickle file into html
    # python torch/cuda/_memory_viz.py trace_plot <snapshot>.pickle -o <snapshot>.html
    return result
def image_tensor_to_masks(example_image, mask_generator):
    """Return ``mask_generator.generate(example_image)`` for a single image."""
    return mask_generator.generate(example_image)
def image_tensors_to_masks(example_images, mask_generator):
    """Return ``mask_generator.generate_batch(example_images)`` for a list of images."""
    generate_batch = mask_generator.generate_batch
    batched_masks = generate_batch(example_images)
    return batched_masks
def file_bytes_to_image_tensor(file_bytes, output_format="numpy"):
    """Decode encoded image bytes into an RGB image.

    Args:
        file_bytes: Encoded image data. It is converted with
            ``np.asarray(..., dtype=np.uint8)``, so pass a ``bytearray``
            (callers in this file do).
        output_format: ``"numpy"`` returns the 3-channel RGB array produced
            by ``cv2.imdecode(..., IMREAD_COLOR)`` + BGR->RGB conversion.
            ``"torch"`` returns torchvision ``ToTensor()`` applied to that
            array (per torchvision docs, a float CHW tensor in [0, 1]).

    Raises:
        ValueError: ``output_format`` is not ``"numpy"`` or ``"torch"``, or
            the bytes could not be decoded as an image.
    """
    # Validate before decoding so a bad argument fails without wasted work.
    if output_format not in ("numpy", "torch"):
        raise ValueError(
            f"Expected output_format to be numpy or torch, but got {output_format}"
        )
    image_array = np.asarray(file_bytes, dtype=np.uint8)
    decoded = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
    # cv2.imdecode reports undecodable input by returning None instead of
    # raising. Without this check the failure surfaces as an opaque cv2.error
    # from cvtColor.
    if decoded is None:
        raise ValueError("Could not decode the uploaded bytes as an image")
    example_image = cv2.cvtColor(decoded, cv2.COLOR_BGR2RGB)
    if output_format == "numpy":
        return example_image
    from torchvision.transforms import ToTensor

    return ToTensor()(example_image)
def masks_to_rle_dict(masks):
    """Map each mask's ``"segmentation"`` to the key ``"mask_<index>"``.

    Returns a new dict in input order; an empty input gives ``{}``.
    """
    return {f"mask_{idx}": entry["segmentation"] for idx, entry in enumerate(masks)}
# Pending (image, asyncio.Future) pairs put by the HTTP handlers and drained
# by batch_worker.
request_queue: asyncio.Queue = asyncio.Queue()
# Seconds batch_worker sleeps between polls of request_queue.
batch_interval = 1e-2
def process_batch(batch, mask_generator):
    """Generate masks for every image in a batch of ``(image, future)`` pairs.

    A single-entry batch goes through ``mask_generator.generate``. Larger
    batches go through ``generate_batch``. Prints the average time per entry
    and the peak CUDA memory via ``max_memory_allocated()``.

    Returns:
        A list with one masks result per batch entry, in order.
    """
    started = time.time()
    images = [entry[0] for entry in batch]
    count = len(batch)
    if count != 1:
        print(f"Processing batch of len {count} using generate_batch")
        masks = mask_generator.generate_batch(images)
    else:
        print(f"Processing batch of len {count} using generate")
        masks = [mask_generator.generate(images[0])]
    print(f"Took avg. {(time.time() - started) / count}s per batch entry")
    max_memory_allocated()
    return masks
async def batch_worker(mask_generator, batch_size, *, pad_batch=True, furious=False):
    """Serve queued requests in batches until the task is cancelled.

    Every ``batch_interval`` seconds, takes up to ``batch_size`` queued
    ``(image, future)`` pairs from ``request_queue`` and runs them through
    ``process_batch``. Each future is then resolved with its own masks.

    Args:
        mask_generator: Passed through to ``process_batch``.
        batch_size: Maximum number of requests processed together.
        pad_batch: Repeat the last request so the processed batch always has
            exactly ``batch_size`` entries. Presumably this keeps input counts
            fixed for compiled models (TODO confirm). Padded results are
            discarded.
        furious: Accepted for interface compatibility; unused here.

    A failure in ``process_batch`` is logged and delivered to the affected
    requests as an exception; the worker keeps running. Before, the
    exception killed the task and every later request hung forever. Futures
    that are already done, e.g. cancelled because the client went away, are
    skipped. Setting a result on them would raise InvalidStateError and also
    kill the worker.
    """
    while True:
        batch = []
        while len(batch) < batch_size and not request_queue.empty():
            batch.append(await request_queue.get())
        if batch:
            padded_batch = batch
            if pad_batch:
                padded_batch = batch + ([batch[-1]] * (batch_size - len(batch)))
            try:
                results = process_batch(padded_batch, mask_generator)
                for i, (_, response_future) in enumerate(batch):
                    if not response_future.done():
                        response_future.set_result(results[i])
            except Exception as exc:
                logging.exception(
                    "Failed to process batch of %d request(s)", len(batch)
                )
                for _, response_future in batch:
                    if not response_future.done():
                        response_future.set_exception(exc)
        await asyncio.sleep(batch_interval)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run ``batch_worker`` in the background for the lifetime of the app.

    Reads ``mask_generator``, ``batch_size`` and ``furious`` from
    ``app.state`` (set by ``main`` before the server starts). On shutdown,
    including when startup or serving raises, the worker is cancelled and
    awaited. Awaiting it means the cancellation actually completes, and any
    unexpected error the worker died with is surfaced rather than dropped.
    """
    # Startup logic
    mask_generator = app.state.mask_generator
    batch_size = app.state.batch_size
    furious = app.state.furious
    task = asyncio.create_task(
        batch_worker(mask_generator, batch_size, furious=furious)
    )
    try:
        yield
    finally:
        # Shutdown logic
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass
def benchmark_fn(func, inp, mask_generator, warmup=3, runs=10):
    """Time ``func(inp, mask_generator)`` on the GPU.

    Empties the CUDA cache and resets peak-memory stats. Then runs ``warmup``
    untimed calls, then times ``runs`` calls. The device is synchronized
    before and after the timed loop: CUDA work is launched asynchronously,
    and without the synchronization queued kernels would be attributed to
    the wrong window.

    Returns:
        ``(avg_seconds_per_run, peak_allocated_mib, peak_allocated_percentage)``.
        The middle value is MiB (as returned by ``max_memory_allocated``),
        despite the ``..._bytes`` name callers give it.

    Raises:
        ValueError: ``runs`` is less than 1.
    """
    if runs < 1:
        raise ValueError(f"runs must be at least 1, but got {runs}")
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    # These messages were missing the f-prefix and logged literal braces.
    logging.info("Running %s warmup iterations.", warmup)
    for _ in range(warmup):
        func(inp, mask_generator)
    logging.info("Running %s benchmark iterations.", runs)
    torch.cuda.synchronize()
    t = time.perf_counter()
    for _ in range(runs):
        func(inp, mask_generator)
    torch.cuda.synchronize()
    avg_time_per_run = (time.perf_counter() - t) / runs
    print(f"Benchmark took {avg_time_per_run}s per iteration.")
    max_memory_allocated_mib, max_memory_allocated_percentage = max_memory_allocated()
    return avg_time_per_run, max_memory_allocated_mib, max_memory_allocated_percentage
def max_memory_allocated_stats():
    """Report peak CUDA memory allocated so far.

    Returns a dict with ``"bytes"``, the value of
    ``torch.cuda.max_memory_allocated()``, and ``"percentage"``. The
    percentage is that peak as a truncated integer percentage of the total
    device memory reported by ``torch.cuda.mem_get_info()``.
    """
    peak_bytes = torch.cuda.max_memory_allocated()
    device_total = torch.cuda.mem_get_info()[1]
    return {
        "bytes": peak_bytes,
        "percentage": int(100 * (peak_bytes / device_total)),
    }
def max_memory_allocated():
    """Print and return peak CUDA memory as ``(MiB, percentage_of_total)``.

    The MiB figure is the byte count from ``max_memory_allocated_stats``
    floor-divided by 2**20.
    """
    stats = max_memory_allocated_stats()
    peak_mib = stats["bytes"] // 2**20
    share = stats["percentage"]
    print(f"max_memory_allocated_bytes: {peak_mib}MiB")
    print(f"max_memory_allocated_percentage: {share}%")
    return peak_mib, share
def unittest_fn(masks, ref_masks, order_by_area=False, verbose=False):
    """Compare ``masks`` with ``ref_masks`` via ``compare_rle_lists.compare_masks``.

    Prints an exact-match message when every mask matches. Otherwise prints
    the mIoU and how many masks matched exactly.
    """
    from compare_rle_lists import compare_masks

    miou, equal_count = compare_masks(
        masks, ref_masks, order_by_area=order_by_area, verbose=verbose
    )
    total = len(masks)
    summary = (
        "Masks exactly match reference."
        if equal_count == total
        else f"mIoU is {miou} with equal count {equal_count} out of {total}"
    )
    print(summary)
# Model size -> config file name (model_type_to_paths prefixes it with
# "configs/sam2.1/").
MODEL_TYPES_TO_CONFIG = {
    "tiny": "sam2.1_hiera_t.yaml",
    "small": "sam2.1_hiera_s.yaml",
    "plus": "sam2.1_hiera_b+.yaml",
    "large": "sam2.1_hiera_l.yaml",
}

# Model size -> checkpoint file name expected inside the checkpoint folder.
MODEL_TYPES_TO_MODEL = {
    "tiny": "sam2.1_hiera_tiny.pt",
    "small": "sam2.1_hiera_small.pt",
    "plus": "sam2.1_hiera_base_plus.pt",
    "large": "sam2.1_hiera_large.pt",
}

# Every checkpoint is published under the same prefix with the same file name
# as in MODEL_TYPES_TO_MODEL, so the download URLs are derived from it.
_SAM2_CHECKPOINT_BASE_URL = "https://dl.fbaipublicfiles.com/segment_anything_2/092824/"
MODEL_TYPES_TO_URL = {
    model_type: _SAM2_CHECKPOINT_BASE_URL + checkpoint_name
    for model_type, checkpoint_name in MODEL_TYPES_TO_MODEL.items()
}
def main_docstring():
    """Build the help text assigned to ``main.__doc__`` for the fire CLI.

    The list of valid model types is read from ``MODEL_TYPES_TO_MODEL`` so it
    stays in sync with the checkpoint tables. Previously only the two
    positional arguments were documented; every option of ``main`` is now
    described.
    """
    model_types = ", ".join(MODEL_TYPES_TO_MODEL.keys())
    return f"""
    Args:
        checkpoint_path (str): Path to folder containing checkpoints from https://github.com/facebookresearch/sam2?tab=readme-ov-file#download-checkpoints
        model_type (str): Choose from one of {model_types}
        baseline (bool): Use the upstream sam2 package instead of torchao's copy. Requires batch_size 1 and cannot be combined with fast or save_fast.
        fast (bool): Apply compile_export_utils.set_fast to the mask generator.
        furious (bool): Apply compile_export_utils.set_furious to the mask generator. Cannot be combined with autoquant_type.
        autoquant_type (str): One of autoquant, autoquant-fp, autoquant-all. Applies torchao autoquant to the image encoder.
        min_sqnr (float): Forwarded to torchao autoquant as min_sqnr.
        unittest (bool): Compare masks generated for dog.jpg against dog_rle.json.
        benchmark (bool): Time mask generation on dog.jpg and on random images of the example shapes.
        profile (str): If set, write a torch.profiler Chrome trace to this path.
        memory_profile (str): If set, write a pickled CUDA memory snapshot to this path.
        verbose (bool): Enable INFO-level logging.
        points_per_batch (int): Forwarded to SAM2AutomaticMaskGenerator.
        port (int): Port for the HTTP server.
        host (str): Host for the HTTP server.
        dry (bool): Run the selected checks, then exit without starting the server.
        batch_size (int): Maximum number of requests the server processes together; also the batch size used by the checks above.
        load_fast (str): If non-empty, passed to load_exported_model as the location of exported models.
        save_fast (str): If non-empty, export the compiled models under this directory.
        output_json_path (str): With benchmark, write the time and memory results as JSON to this path.
        output_json_local (bool): Use write_json_result_local instead of write_json_result_ossci for output_json_path.
    """
def model_type_to_paths(checkpoint_path, model_type):
    """Resolve the checkpoint file and config path for a SAM2.1 model type.

    Downloads the checkpoint into ``checkpoint_path`` if it is not there yet.

    Args:
        checkpoint_path: Directory holding (or receiving) the ``.pt`` files.
        model_type: A key of ``MODEL_TYPES_TO_CONFIG``.

    Returns:
        ``(checkpoint pathlib.Path, "configs/sam2.1/<config yaml>")``.

    Raises:
        ValueError: ``model_type`` is unknown.
        FileNotFoundError: The checkpoint is still missing after the
            download. This was an ``assert``, which ``python -O`` strips.
    """
    if model_type not in MODEL_TYPES_TO_CONFIG:
        raise ValueError(
            f"Expected model_type to be one of {', '.join(MODEL_TYPES_TO_MODEL.keys())} but got {model_type}"
        )
    sam2_checkpoint = Path(checkpoint_path) / MODEL_TYPES_TO_MODEL[model_type]
    if not sam2_checkpoint.exists():
        print(
            f"Can't find checkpoint {sam2_checkpoint} in folder {checkpoint_path}. Downloading."
        )
        download_file(MODEL_TYPES_TO_URL[model_type], checkpoint_path)
        if not sam2_checkpoint.exists():
            raise FileNotFoundError(
                "Can't find downloaded file. Please open an issue."
            )
    model_cfg = f"configs/sam2.1/{MODEL_TYPES_TO_CONFIG[model_type]}"
    return sam2_checkpoint, model_cfg
def set_autoquant(mask_generator, autoquant_type, min_sqnr):
    """Quantize the SAM2 image encoder in place with ``torchao.autoquant``.

    Args:
        mask_generator: Object exposing ``predictor.model.image_encoder``;
            that attribute is replaced by the autoquantized module.
        autoquant_type: ``"autoquant"`` (autoquant's default candidates),
            ``"autoquant-fp"`` (``DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST``) or
            ``"autoquant-all"`` (``ALL_AUTOQUANT_CLASS_LIST``).
        min_sqnr: Forwarded to ``autoquant``.

    Also sets ``predictor._transforms_device`` to ``predictor.device``, and
    float32 matmul precision to ``"high"``.

    Raises:
        ValueError: ``autoquant_type`` is not one of the names above.
    """
    import torchao
    from torchao import autoquant

    # NOTE: Not baseline feature
    # (type name, torchao.quantization class-list attribute or None for the
    # default). Attributes are looked up only for the selected type.
    choices = (
        ("autoquant", None),
        ("autoquant-fp", "DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST"),
        ("autoquant-all", "ALL_AUTOQUANT_CLASS_LIST"),
    )
    for choice_name, class_list_attr in choices:
        if autoquant_type == choice_name:
            break
    else:
        raise ValueError(f"Unexpected autoquant type: {autoquant_type}")

    extra_kwargs = {}
    if class_list_attr is not None:
        extra_kwargs["qtensor_class_list"] = getattr(
            torchao.quantization, class_list_attr
        )
    model = mask_generator.predictor.model
    model.image_encoder = autoquant(
        model.image_encoder, min_sqnr=min_sqnr, **extra_kwargs
    )

    mask_generator.predictor._transforms_device = mask_generator.predictor.device
    torch.set_float32_matmul_precision("high")
    # NOTE: this fails when we run
    # python server.py ~/checkpoints/sam2 large --port 8000 --host localhost --fast --autoquant_type autoquant --unittest
    # https://gist.github.com/jerryzh168/d337cb5de0a1dec306069fe48ac8225e
    # mask_generator.predictor.model.sam_mask_decoder = autoquant(mask_generator.predictor.model.sam_mask_decoder, qtensor_class_list=DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, min_sqnr=40)
def _read_json(path):
    """Return the parsed JSON content of the file at ``path``, closing it afterwards."""
    with open(path) as f:
        return json.load(f)


def _benchmark_example_shapes(mask_generator, batch_size):
    """Benchmark on random uint8 images with the shapes from both example lists."""
    for i, shapes in enumerate([example_shapes(), example_shapes_2()]):
        print(f"batch size {batch_size} example shapes {i} benchmark")
        random_images = [
            np.random.randint(0, 256, size=size, dtype=np.uint8) for size in shapes
        ]
        if batch_size > len(random_images):
            # Ceiling division: repeat the list until it holds at least
            # batch_size images. The old (len + batch_size) // batch_size
            # evaluated to 1 for e.g. batch_size=64, leaving only 32 images.
            num_repeat = (batch_size + len(random_images) - 1) // len(random_images)
            random_images = num_repeat * random_images
        if batch_size == 1:
            for r in random_images:
                benchmark_fn(image_tensor_to_masks, r, mask_generator)
        else:
            random_images = random_images[:batch_size]
            print("len(random_images): ", len(random_images))
            benchmark_fn(image_tensors_to_masks, random_images, mask_generator)


def _write_benchmark_json(
    output_json_path,
    output_json_local,
    result,
    model_type,
    autoquant_type,
    min_sqnr,
    fast,
    device,
):
    """Write the memory and timing rows of a ``benchmark_fn`` result as JSON."""
    headers = [
        "name",
        "dtype",
        "min_sqnr",
        "compile",
        "device",
        "arch",
        "metric",
        "actual",
        "target",
    ]
    name = "sam2-" + model_type
    arch = get_arch_name()
    dtype = autoquant_type or "noquant"
    # boolean flag to indicate whether it's eager or compile
    compiled = fast
    avg_time_per_run, max_memory_allocated_mib, _ = result
    common = [name, dtype, min_sqnr, compiled, device, arch]
    memory_result = common + ["memory(MiB)", max_memory_allocated_mib, None]
    performance_result = common + ["time_s(avg)", avg_time_per_run, None]
    write_json_result = (
        write_json_result_local if output_json_local else write_json_result_ossci
    )
    write_json_result(output_json_path, headers, memory_result)
    write_json_result(output_json_path, headers, performance_result)


def _create_app(mask_generator, batch_size, furious, rle_to_mask):
    """Build the FastAPI app that serves masks through the shared request queue.

    Routes:
        POST /upload_rle: returns ``masks_to_rle_dict`` of the image's masks.
        POST /upload: returns a PNG of the image with the masks overlaid.
    """
    app = FastAPI(lifespan=lifespan)
    app.state.mask_generator = mask_generator
    app.state.batch_size = batch_size
    app.state.furious = furious

    # Allow all origins (you can restrict it in production)
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    async def _generate_masks(image_tensor):
        """Queue one image for batch_worker and wait for its masks."""
        response_future = asyncio.get_running_loop().create_future()
        await request_queue.put((image_tensor, response_future))
        return await response_future

    @app.post("/upload_rle")
    async def upload_rle(image: UploadFile = File(...)):
        image_tensor = file_bytes_to_image_tensor(bytearray(await image.read()))
        masks = await _generate_masks(image_tensor)
        return masks_to_rle_dict(masks)

    @app.post("/upload")
    async def upload_image(image: UploadFile = File(...)):
        image_tensor = file_bytes_to_image_tensor(bytearray(await image.read()))
        masks = await _generate_masks(image_tensor)
        fig = plt.figure(
            figsize=(image_tensor.shape[1] / 100.0, image_tensor.shape[0] / 100.0),
            dpi=100,
        )
        # Close the figure even if drawing fails, so figures do not pile up
        # in a long-running server.
        try:
            plt.imshow(image_tensor)
            show_anns(masks, rle_to_mask)
            plt.axis("off")
            plt.tight_layout()
            buf = BytesIO()
            plt.savefig(buf, format="png")
        finally:
            plt.close(fig)
        buf.seek(0)
        return StreamingResponse(buf, media_type="image/png")

    return app


def main(
    checkpoint_path,
    model_type,
    baseline=False,
    fast=False,
    furious=False,
    autoquant_type=None,
    min_sqnr=None,
    unittest=False,
    benchmark=False,
    profile=None,
    memory_profile=None,
    verbose=False,
    points_per_batch=64,
    port=5000,
    host="127.0.0.1",
    dry=False,
    batch_size=1,
    load_fast="",
    save_fast="",
    output_json_path=None,
    output_json_local=False,
):
    # CLI entry point, run through fire; its help text is assigned separately
    # via `main.__doc__ = main_docstring()`.
    # Steps:
    #   1. build the SAM2 mask generator;
    #   2. apply the requested load/compile/quantization options;
    #   3. optionally run the unittest / benchmark / profiling steps on
    #      dog.jpg;
    #   4. unless dry, serve the model over HTTP.
    if verbose:
        logging.basicConfig(
            level=logging.INFO,
            format="%(asctime)s - %(levelname)s - %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
    logging.info(f"Running with fast set to {fast} and furious set to {furious}")
    logging.info(f"Running with port {port} and host {host}")
    logging.info(f"Running with batch size {batch_size}")

    if baseline:
        assert batch_size == 1, "baseline only supports batch size 1."
        logging.info(
            "Importing sam2 from outside of torchao. If this errors, install https://github.com/facebookresearch/sam2"
        )
        from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
        from sam2.build_sam import build_sam2
        from sam2.utils.amg import rle_to_mask
    else:
        from torchao._models.sam2.automatic_mask_generator import (
            SAM2AutomaticMaskGenerator,
        )
        from torchao._models.sam2.build_sam import build_sam2
        from torchao._models.sam2.utils.amg import rle_to_mask

    device = "cuda"
    sam2_checkpoint, model_cfg = model_type_to_paths(checkpoint_path, model_type)
    logging.info(f"Loading model {sam2_checkpoint} with config {model_cfg}")
    sam2 = build_sam2(
        model_cfg, sam2_checkpoint, device=device, apply_postprocessing=False
    )
    logging.info(f"Using {points_per_batch} points_per_batch")
    mask_generator = SAM2AutomaticMaskGenerator(
        sam2, points_per_batch=points_per_batch, output_mode="uncompressed_rle"
    )

    if load_fast != "":
        load_exported_model(
            mask_generator, load_fast, "amg", furious, batch_size, points_per_batch
        )
    if furious:
        set_furious(mask_generator)
    if save_fast != "":
        assert (
            load_fast == ""
        ), "Can't save compiled models while loading them with --load-fast."
        assert not baseline, "--fast cannot be combined with baseline. code to be torch.compile(fullgraph=True) compatible."
        print(f"Saving compiled models under directory {save_fast}")
        export_model(
            mask_generator,
            save_fast,
            "amg",
            furious=furious,
            batch_size=batch_size,
            points_per_batch=points_per_batch,
        )
    if fast:
        assert not baseline, "--fast cannot be combined with baseline. code to be torch.compile(fullgraph=True) compatible."
        set_fast(mask_generator, "amg", load_fast)
    # since autoquant is replicating what furious mode is doing, don't use these two together
    if autoquant_type is not None:
        assert not furious, "autoquant can't be used together with furious"
        set_autoquant(mask_generator, autoquant_type, min_sqnr)

    with open("dog.jpg", "rb") as f:
        output_format = "numpy" if baseline else "torch"
        image_tensor = file_bytes_to_image_tensor(
            bytearray(f.read()), output_format=output_format
        )

    def make_workload():
        """Return (fn, input) running dog.jpg at batch_size; input list is rebuilt per call."""
        if batch_size == 1:
            return image_tensor_to_masks, image_tensor
        return image_tensors_to_masks, [image_tensor] * batch_size

    if unittest:
        # TODO: Transpose dog image to create diversity in input image shape
        logging.info(f"batch size {batch_size} unittest")
        run_fn, run_input = make_workload()
        outputs = run_fn(run_input, mask_generator)
        # The single-image path returns one mask list; the batched path
        # returns one mask list per image.
        per_image_masks = [outputs] if batch_size == 1 else outputs
        ref_masks = _read_json("dog_rle.json")
        for masks in per_image_masks:
            unittest_fn(
                masks_to_rle_dict(masks), ref_masks, order_by_area=True, verbose=verbose
            )

    if benchmark:
        print(f"batch size {batch_size} dog benchmark")
        run_fn, run_input = make_workload()
        result = benchmark_fn(run_fn, run_input, mask_generator)
        _benchmark_example_shapes(mask_generator, batch_size)
        if output_json_path:
            _write_benchmark_json(
                output_json_path,
                output_json_local,
                result,
                model_type,
                autoquant_type,
                min_sqnr,
                fast,
                device,
            )
    elif output_json_path:
        # Previously this combination crashed with a NameError on `result`.
        logging.warning(
            "output_json_path=%s ignored: results are only written together with --benchmark",
            output_json_path,
        )

    if profile is not None:
        print(f"Saving profile under {profile}")
        run_fn, run_input = make_workload()
        profiler_runner(profile, run_fn, run_input, mask_generator)

    if memory_profile is not None:
        print(f"Saving memory profile under {memory_profile}")
        run_fn, run_input = make_workload()
        memory_runner(memory_profile, run_fn, run_input, mask_generator)

    if dry:
        return

    app = _create_app(mask_generator, batch_size, furious, rle_to_mask)
    uvicorn.run(app, host=host, port=port)
# Attach the generated argument help so the fire CLI can display it.
main.__doc__ = main_docstring()

if __name__ == "__main__":
    fire.Fire(component=main)