Add CORSMiddleware to shortfin servers (#965)
We need to add `CORSMiddleware` to our servers so they can receive
requests from [shark-ui](https://github.com/nod-ai/shark-ui); without it,
the browser blocks those cross-origin requests.

Here's a quick article on
[CORS](https://auth0.com/blog/cors-tutorial-a-guide-to-cross-origin-resource-sharing/).

@bjacobgordon and I would like to hear your thoughts on whether we need
further restrictions, specifically whether the allowed origins need to
be restricted.

What the origin setting does is tell browsers which origins (the scheme,
host, and port of the page making the request) may call our server;
requests from any other origin are blocked. So, for example, if we had a
top-secret govt project, we'd say that only these 4 origins are allowed
to send requests to our server.
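
To make the origin check concrete, here is a simplified sketch of the decision a CORS layer makes. It is an illustration only, not Starlette's actual implementation, and `allow_origin_header` is a hypothetical helper name.

```python
from typing import Optional, Sequence


def allow_origin_header(
    request_origin: Optional[str], allowed: Sequence[str]
) -> Optional[str]:
    """Return the Access-Control-Allow-Origin value to send, or None to send none.

    Hypothetical helper for illustration; real middleware handles more cases
    (credentials, preflight requests, origin regexes).
    """
    if request_origin is None:
        # No Origin header: the request is not treated as a CORS request.
        return None
    if "*" in allowed:
        # Wildcard: any origin may read the response.
        return "*"
    if request_origin in allowed:
        # Exact match on scheme + host + port.
        return request_origin
    # Unknown origin: omit the header, so the browser hides the response.
    return None
```

With `allowed=["*"]`, as in this commit, every origin is accepted; with an explicit list, only exact matches are.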

Normally, in a commercial deployment scenario, this is a consideration.
For us, I don't think it is. Here's why:

1.) We're open source. So, our model is to have people spin up this API
anywhere and send requests to it from anywhere. So, we don't have a set
of origins we would want to set it to. Otherwise, we would cut off our
users. Maybe if we had a prod deployment, we would want to set it to
only allow requests from some load balancer, but we don't have that
right now.
2.) We're running open source software, serving open source models, so I
don't know if we have security considerations like that.
3.) It's easy to spoof the Origin header from a non-browser client, so
this isn't really a strong security feature anyway. We'd rely on infra
components, like a private network or load balancers, IF this was a
commercial deployment.
4.) We're not yet at the point where we would really need to worry about
this, and doing more advanced features would require time we don't have
right now.
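
Point 3 can be seen with nothing but the standard library: a non-browser client sets whatever `Origin` it likes. This is a sketch; the host and port are assumptions, and no request is actually sent.

```python
import urllib.request

# Hypothetical server address; only the /generate path appears in this commit.
req = urllib.request.Request(
    "http://localhost:8000/generate",
    data=b"{}",
    method="PUT",
    headers={
        "Content-Type": "application/json",
        # A script can claim any origin; only browsers set this header honestly.
        "Origin": "https://trusted.example",
    },
)

print(req.get_header("Origin"))
```

Because the header is client-supplied, an origin allow-list protects browser users from malicious pages but does not keep arbitrary clients away from the server.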

All of this to say, I think it's fine like this, but would like to hear
other opinions. I would be open to making the settings configurable,
where we default to wildcard but users can specify stricter settings.
That way, anyone who wants to deploy in production with stronger
settings can. If we do this, I would like it to be done the same way on
both servers, so I would like to know which of these we'd prefer:

a.) Env var (something like "SHORTFIN_ALLOW_ORIGINS=a,b")
b.) Discuss an idea for a JSON-based config that both servers could
share for settings like this.
c.) Something else
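
As a rough sketch of option (a), assuming the variable is named `SHORTFIN_ALLOW_ORIGINS` as floated above and holds a comma-separated list, both servers could share a small helper like this. The helper names are hypothetical and nothing here is part of this commit.

```python
import os
from typing import Dict, List, Mapping, Optional


def parse_allowed_origins(raw: Optional[str]) -> List[str]:
    """Split a comma-separated origin list, defaulting to the wildcard."""
    if raw is None:
        return ["*"]
    origins = [item.strip() for item in raw.split(",") if item.strip()]
    # An empty or whitespace-only value also falls back to the wildcard.
    return origins or ["*"]


def cors_kwargs(environ: Mapping[str, str] = os.environ) -> Dict[str, List[str]]:
    """Build keyword arguments for CORSMiddleware from the environment."""
    return {
        "allow_origins": parse_allowed_origins(environ.get("SHORTFIN_ALLOW_ORIGINS")),
        "allow_methods": ["*"],
        "allow_headers": ["*"],
    }
```

Each server would then call `app.add_middleware(CORSMiddleware, **cors_kwargs())`, which keeps today's wildcard behavior when the variable is unset.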
stbaione authored Feb 13, 2025
1 parent 3744873 commit bc76526
Showing 2 changed files with 22 additions and 0 deletions.
12 changes: 12 additions & 0 deletions shortfin/python/shortfin_apps/llm/application.py
@@ -5,6 +5,7 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

from .lifecycle_hooks import lifespan
from .routes import application_router, generation_router
@@ -16,7 +17,18 @@ def add_routes(app: FastAPI):
    return app


def add_middleware(app: FastAPI):
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_methods=["*"],
        allow_headers=["*"],
    )
    return app


def get_app() -> FastAPI:
    app = FastAPI(lifespan=lifespan)
    app = add_routes(app)
    app = add_middleware(app)
    return app
10 changes: 10 additions & 0 deletions shortfin/python/shortfin_apps/sd/server.py
@@ -20,6 +20,7 @@
from shortfin.support.logging_setup import native_handler

from fastapi import FastAPI, Request, Response
from fastapi.middleware.cors import CORSMiddleware

from .components.generate import ClientGenerateBatchProcess
from .components.config_struct import ModelParams
@@ -105,6 +106,15 @@ async def generate_request(gen_req: GenerateReqInput, request: Request):
app.put("/generate")(generate_request)


# -------- MIDDLEWARE --------
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


def configure_sys(args) -> SystemManager:
    # Setup system (configure devices, etc).
    model_config, topology_config, flagfile, tuning_spec, args = get_configs(args)