Skip to content

Commit bf2ce16

Browse files
committed
Add support for apps created by factory functions
1 parent 63267a3 commit bf2ce16

14 files changed

+738
-39
lines changed

Diff for: src/fastapi_cli/cli.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def _run(
6262
proxy_headers: bool = False,
6363
) -> None:
6464
try:
65-
use_uvicorn_app = get_import_string(path=path, app_name=app)
65+
use_uvicorn_app, is_factory = get_import_string(path=path, app_name=app)
6666
except FastAPICLIException as e:
6767
logger.error(str(e))
6868
raise typer.Exit(code=1) from None
@@ -97,6 +97,7 @@ def _run(
9797
workers=workers,
9898
root_path=root_path,
9999
proxy_headers=proxy_headers,
100+
factory=is_factory,
100101
)
101102

102103

Diff for: src/fastapi_cli/discover.py

+28-12
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from dataclasses import dataclass
44
from logging import getLogger
55
from pathlib import Path
6-
from typing import Union
6+
from typing import Any, Callable, Union, get_type_hints
77

88
from rich import print
99
from rich.padding import Padding
@@ -98,7 +98,9 @@ def get_module_data_from_path(path: Path) -> ModuleData:
9898
)
9999

100100

101-
def get_app_name(*, mod_data: ModuleData, app_name: Union[str, None] = None) -> str:
101+
def get_app_name(
102+
*, mod_data: ModuleData, app_name: Union[str, None] = None
103+
) -> tuple[str, bool]:
102104
try:
103105
mod = importlib.import_module(mod_data.module_import_str)
104106
except (ImportError, ValueError) as e:
@@ -119,26 +121,40 @@ def get_app_name(*, mod_data: ModuleData, app_name: Union[str, None] = None) ->
119121
f"Could not find app name {app_name} in {mod_data.module_import_str}"
120122
)
121123
app = getattr(mod, app_name)
124+
is_factory = False
122125
if not isinstance(app, FastAPI):
123-
raise FastAPICLIException(
124-
f"The app name {app_name} in {mod_data.module_import_str} doesn't seem to be a FastAPI app"
125-
)
126-
return app_name
127-
for preferred_name in ["app", "api"]:
126+
is_factory = check_factory(app)
127+
if not is_factory:
128+
raise FastAPICLIException(
129+
f"The app name {app_name} in {mod_data.module_import_str} doesn't seem to be a FastAPI app"
130+
)
131+
return app_name, is_factory
132+
for preferred_name in ["app", "api", "create_app", "create_api"]:
128133
if preferred_name in object_names_set:
129134
obj = getattr(mod, preferred_name)
130135
if isinstance(obj, FastAPI):
131-
return preferred_name
136+
return preferred_name, False
137+
if check_factory(obj):
138+
return preferred_name, True
132139
for name in object_names:
133140
obj = getattr(mod, name)
134141
if isinstance(obj, FastAPI):
135-
return name
142+
return name, False
136143
raise FastAPICLIException("Could not find FastAPI app in module, try using --app")
137144

138145

146+
def check_factory(fn: Callable[[], Any]) -> bool:
147+
"""Checks whether the return-type of a factory function is FastAPI"""
148+
# if not callable(fn):
149+
# return False
150+
type_hints = get_type_hints(fn)
151+
return_type = type_hints.get("return")
152+
return return_type is not None and issubclass(return_type, FastAPI)
153+
154+
139155
def get_import_string(
140156
*, path: Union[Path, None] = None, app_name: Union[str, None] = None
141-
) -> str:
157+
) -> tuple[str, bool]:
142158
if not path:
143159
path = get_default_path()
144160
logger.info(f"Using path [blue]{path}[/blue]")
@@ -147,7 +163,7 @@ def get_import_string(
147163
raise FastAPICLIException(f"Path does not exist {path}")
148164
mod_data = get_module_data_from_path(path)
149165
sys.path.insert(0, str(mod_data.extra_sys_path))
150-
use_app_name = get_app_name(mod_data=mod_data, app_name=app_name)
166+
use_app_name, is_factory = get_app_name(mod_data=mod_data, app_name=app_name)
151167
import_example = Syntax(
152168
f"from {mod_data.module_import_str} import {use_app_name}", "python"
153169
)
@@ -164,4 +180,4 @@ def get_import_string(
164180
print(import_panel)
165181
import_string = f"{mod_data.module_import_str}:{use_app_name}"
166182
logger.info(f"Using import string [b green]{import_string}[/b green]")
167-
return import_string
183+
return import_string, is_factory

Diff for: tests/assets/factory_create_api.py

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from fastapi import FastAPI
2+
3+
4+
def create_api() -> FastAPI:
5+
app = FastAPI()
6+
7+
@app.get("/")
8+
def app_root():
9+
return {"message": "single file factory app"}
10+
11+
return app

Diff for: tests/assets/factory_create_app.py

+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from fastapi import FastAPI
2+
3+
4+
class App(FastAPI): ...
5+
6+
7+
def create_app_other() -> App:
8+
app = App()
9+
10+
@app.get("/")
11+
def app_root():
12+
return {"message": "single file factory app inherited"}
13+
14+
return app
15+
16+
17+
def create_app() -> FastAPI:
18+
app = FastAPI()
19+
20+
@app.get("/")
21+
def app_root():
22+
return {"message": "single file factory app"}
23+
24+
return app

Diff for: tests/assets/package/mod/factory_api.py

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from fastapi import FastAPI
2+
3+
4+
def create_api() -> FastAPI:
5+
app = FastAPI()
6+
7+
@app.get("/")
8+
def root():
9+
return {"message": "package create_api"}
10+
11+
return app

Diff for: tests/assets/package/mod/factory_app.py

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from fastapi import FastAPI
2+
3+
4+
def create_app() -> FastAPI:
5+
app = FastAPI()
6+
7+
@app.get("/")
8+
def root():
9+
return {"message": "package create_app"}
10+
11+
return app

Diff for: tests/assets/package/mod/factory_inherit.py

+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from fastapi import FastAPI
2+
3+
4+
class App(FastAPI): ...
5+
6+
7+
def create_app() -> App:
8+
app = App()
9+
10+
@app.get("/")
11+
def root():
12+
return {"message": "package build_app"}
13+
14+
return app

Diff for: tests/assets/package/mod/factory_other.py

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from fastapi import FastAPI
2+
3+
4+
def build_app() -> FastAPI:
5+
app = FastAPI()
6+
7+
@app.get("/")
8+
def root():
9+
return {"message": "package build_app"}
10+
11+
return app

Diff for: tests/test_cli.py

+58
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def test_dev() -> None:
2929
"workers": None,
3030
"root_path": "",
3131
"proxy_headers": True,
32+
"factory": False,
3233
}
3334
assert "Using import string single_file_app:app" in result.output
3435
assert (
@@ -40,6 +41,33 @@ def test_dev() -> None:
4041
assert "│ fastapi run" in result.output
4142

4243

44+
def test_dev_factory() -> None:
45+
with changing_dir(assets_path):
46+
with patch.object(uvicorn, "run") as mock_run:
47+
result = runner.invoke(app, ["dev", "factory_create_app.py"])
48+
assert result.exit_code == 0, result.output
49+
assert mock_run.called
50+
assert mock_run.call_args
51+
assert mock_run.call_args.kwargs == {
52+
"app": "factory_create_app:create_app",
53+
"host": "127.0.0.1",
54+
"port": 8000,
55+
"reload": True,
56+
"workers": None,
57+
"root_path": "",
58+
"proxy_headers": True,
59+
"factory": True,
60+
}
61+
assert "Using import string factory_create_app:create_app" in result.output
62+
assert (
63+
"╭────────── FastAPI CLI - Development mode ───────────╮" in result.output
64+
)
65+
assert "│ Serving at: http://127.0.0.1:8000" in result.output
66+
assert "│ API docs: http://127.0.0.1:8000/docs" in result.output
67+
assert "│ Running in development mode, for production use:" in result.output
68+
assert "│ fastapi run" in result.output
69+
70+
4371
def test_dev_args() -> None:
4472
with changing_dir(assets_path):
4573
with patch.object(uvicorn, "run") as mock_run:
@@ -71,6 +99,7 @@ def test_dev_args() -> None:
7199
"workers": None,
72100
"root_path": "/api",
73101
"proxy_headers": False,
102+
"factory": False,
74103
}
75104
assert "Using import string single_file_app:api" in result.output
76105
assert (
@@ -97,6 +126,7 @@ def test_run() -> None:
97126
"workers": None,
98127
"root_path": "",
99128
"proxy_headers": True,
129+
"factory": False,
100130
}
101131
assert "Using import string single_file_app:app" in result.output
102132
assert (
@@ -108,6 +138,33 @@ def test_run() -> None:
108138
assert "│ fastapi dev" in result.output
109139

110140

141+
def test_run_factory() -> None:
142+
with changing_dir(assets_path):
143+
with patch.object(uvicorn, "run") as mock_run:
144+
result = runner.invoke(app, ["run", "factory_create_app.py"])
145+
assert result.exit_code == 0, result.output
146+
assert mock_run.called
147+
assert mock_run.call_args
148+
assert mock_run.call_args.kwargs == {
149+
"app": "factory_create_app:create_app",
150+
"host": "0.0.0.0",
151+
"port": 8000,
152+
"reload": False,
153+
"workers": None,
154+
"root_path": "",
155+
"proxy_headers": True,
156+
"factory": True,
157+
}
158+
assert "Using import string factory_create_app:create_app" in result.output
159+
assert (
160+
"╭─────────── FastAPI CLI - Production mode ───────────╮" in result.output
161+
)
162+
assert "│ Serving at: http://0.0.0.0:8000" in result.output
163+
assert "│ API docs: http://0.0.0.0:8000/docs" in result.output
164+
assert "│ Running in production mode, for development use:" in result.output
165+
assert "│ fastapi dev" in result.output
166+
167+
111168
def test_run_args() -> None:
112169
with changing_dir(assets_path):
113170
with patch.object(uvicorn, "run") as mock_run:
@@ -141,6 +198,7 @@ def test_run_args() -> None:
141198
"workers": 2,
142199
"root_path": "/api",
143200
"proxy_headers": False,
201+
"factory": False,
144202
}
145203
assert "Using import string single_file_app:api" in result.output
146204
assert (

Diff for: tests/test_utils_check_factory.py

+32
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from fastapi import FastAPI
2+
from fastapi_cli.discover import check_factory
3+
4+
5+
def test_check_untyped_factory() -> None:
6+
def create_app(): # type: ignore[no-untyped-def]
7+
return FastAPI() # pragma: no cover
8+
9+
assert check_factory(create_app) is False
10+
11+
12+
def test_check_typed_factory() -> None:
13+
def create_app() -> FastAPI:
14+
return FastAPI() # pragma: no cover
15+
16+
assert check_factory(create_app) is True
17+
18+
19+
def test_check_typed_factory_inherited() -> None:
20+
class MyApp(FastAPI): ...
21+
22+
def create_app() -> MyApp:
23+
return MyApp() # pragma: no cover
24+
25+
assert check_factory(create_app) is True
26+
27+
28+
def test_create_app_with_different_type() -> None:
29+
def create_app() -> int:
30+
return 1 # pragma: no cover
31+
32+
assert check_factory(create_app) is False

Diff for: tests/test_utils_default_dir.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@
1212

1313
def test_app_dir_main(capsys: CaptureFixture[str]) -> None:
1414
with changing_dir(assets_path / "default_files" / "default_app_dir_main"):
15-
import_string = get_import_string()
15+
import_string, is_factory = get_import_string()
1616
assert import_string == "app.main:app"
17+
assert is_factory is False
1718

1819
captured = capsys.readouterr()
1920
assert "Using path app/main.py" in captured.out
@@ -36,8 +37,9 @@ def test_app_dir_main(capsys: CaptureFixture[str]) -> None:
3637

3738
def test_app_dir_app(capsys: CaptureFixture[str]) -> None:
3839
with changing_dir(assets_path / "default_files" / "default_app_dir_app"):
39-
import_string = get_import_string()
40+
import_string, is_factory = get_import_string()
4041
assert import_string == "app.app:app"
42+
assert is_factory is False
4143

4244
captured = capsys.readouterr()
4345
assert "Using path app/app.py" in captured.out
@@ -58,8 +60,9 @@ def test_app_dir_app(capsys: CaptureFixture[str]) -> None:
5860

5961
def test_app_dir_api(capsys: CaptureFixture[str]) -> None:
6062
with changing_dir(assets_path / "default_files" / "default_app_dir_api"):
61-
import_string = get_import_string()
63+
import_string, is_factory = get_import_string()
6264
assert import_string == "app.api:app"
65+
assert is_factory is False
6366

6467
captured = capsys.readouterr()
6568
assert "Using path app/api.py" in captured.out

Diff for: tests/test_utils_default_file.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,9 @@ def test_single_file_main(capsys: CaptureFixture[str]) -> None:
2020
mod = importlib.import_module("main")
2121

2222
importlib.reload(mod)
23-
import_string = get_import_string()
23+
import_string, is_factory = get_import_string()
2424
assert import_string == "main:app"
25+
assert is_factory is False
2526

2627
captured = capsys.readouterr()
2728
assert "Using path main.py" in captured.out
@@ -47,8 +48,9 @@ def test_single_file_app(capsys: CaptureFixture[str]) -> None:
4748
mod = importlib.import_module("app")
4849

4950
importlib.reload(mod)
50-
import_string = get_import_string()
51+
import_string, is_factory = get_import_string()
5152
assert import_string == "app:app"
53+
assert is_factory is False
5254

5355
captured = capsys.readouterr()
5456
assert "Using path app.py" in captured.out
@@ -74,8 +76,9 @@ def test_single_file_api(capsys: CaptureFixture[str]) -> None:
7476
mod = importlib.import_module("api")
7577

7678
importlib.reload(mod)
77-
import_string = get_import_string()
79+
import_string, is_factory = get_import_string()
7880
assert import_string == "api:app"
81+
assert is_factory is False
7982

8083
captured = capsys.readouterr()
8184
assert "Using path api.py" in captured.out

0 commit comments

Comments
 (0)