Skip to content

Commit f377495

Browse files
committed
feat: support middleware
1 parent 2b21143 commit f377495

File tree

3 files changed

+180
-0
lines changed

3 files changed

+180
-0
lines changed

lightbug_http/__init__.mojo

+1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ from lightbug_http.service import HTTPService, Welcome
33
from lightbug_http.sys.server import SysServer
44
from lightbug_http.tests.run import run_tests
55
from lightbug_http.middleware import *
6+
from lightbug_http.middleware import *
67

78
trait DefaultConstructible:
89
fn __init__(inout self) raises:

lightbug_http/middleware.mojo

+172
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
from lightbug_http.http import HTTPRequest, HTTPResponse
2+
3+
struct Context:
4+
var request: Request
5+
var params: Dict[String, AnyType]
6+
7+
fn __init__(self, request: Request):
8+
self.request = request
9+
self.params = Dict[String, AnyType]()
10+
11+
trait Middleware:
12+
var next: Middleware
13+
14+
fn call(self, context: Context) -> Response:
15+
...
16+
17+
struct ErrorMiddleware(Middleware):
18+
fn call(self, context: Context) -> Response:
19+
try:
20+
return next.call(context: context)
21+
catch e: Exception:
22+
return InternalServerError()
23+
24+
struct LoggerMiddleware(Middleware):
25+
fn call(self, context: Context) -> Response:
26+
print("Request: \(context.request)")
27+
return next.call(context: context)
28+
29+
struct StaticMiddleware(Middleware):
30+
var path: String
31+
32+
fnt __init__(self, path: String):
33+
self.path = path
34+
35+
fn call(self, context: Context) -> Response:
36+
if context.request.path == "/":
37+
var file = File(path: path + "index.html")
38+
else:
39+
var file = File(path: path + context.request.path)
40+
41+
if file.exists:
42+
var html: String
43+
with open(file, "r") as f:
44+
html = f.read()
45+
return OK(html.as_bytes(), "text/html")
46+
else:
47+
return next.call(context: context)
48+
49+
struct CorsMiddleware(Middleware):
50+
var allow_origin: String
51+
52+
fn __init__(self, allow_origin: String):
53+
self.allow_origin = allow_origin
54+
55+
fn call(self, context: Context) -> Response:
56+
if context.request.method == "OPTIONS":
57+
var response = next.call(context: context)
58+
response.headers["Access-Control-Allow-Origin"] = allow_origin
59+
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, DELETE, OPTIONS"
60+
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization"
61+
return response
62+
63+
if context.request.origin == allow_origin:
64+
return next.call(context: context)
65+
else:
66+
return Unauthorized()
67+
68+
struct CompressionMiddleware(Middleware):
69+
fn call(self, context: Context) -> Response:
70+
var response = next.call(context: context)
71+
response.body = compress(response.body)
72+
return response
73+
74+
fn compress(self, body: Bytes) -> Bytes:
75+
#TODO: implement compression
76+
return body
77+
78+
79+
struct RouterMiddleware(Middleware):
80+
var routes: Dict[String, Middleware]
81+
82+
fn __init__(self):
83+
self.routes = Dict[String, Middleware]()
84+
85+
fn add(self, method: String, route: String, middleware: Middleware):
86+
routes[method + ":" + route] = middleware
87+
88+
fn call(self, context: Context) -> Response:
89+
# TODO: create a more advanced router
90+
var method = context.request.method
91+
var route = context.request.path
92+
if middleware = routes[method + ":" + route]:
93+
return middleware.call(context: context)
94+
else:
95+
return next.call(context: context)
96+
97+
struct BasicAuthMiddleware(Middleware):
98+
var username: String
99+
var password: String
100+
101+
fn __init__(self, username: String, password: String):
102+
self.username = username
103+
self.password = password
104+
105+
fn call(self, context: Context) -> Response:
106+
var request = context.request
107+
var auth = request.headers["Authorization"]
108+
if auth == "Basic \(username):\(password)":
109+
context.params["username"] = username
110+
return next.call(context: context)
111+
else:
112+
return Unauthorized()
113+
114+
# always add at the end of the middleware chain
115+
struct NotFoundMiddleware(Middleware):
116+
fn call(self, context: Context) -> Response:
117+
return NotFound()
118+
119+
struct MiddlewareChain(HttpService):
120+
var middlewares: Array[Middleware]
121+
122+
fn __init__(self):
123+
self.middlewares = Array[Middleware]()
124+
125+
fn add(self, middleware: Middleware):
126+
if middlewares.count == 0:
127+
middlewares.append(middleware)
128+
else:
129+
var last = middlewares[middlewares.count - 1]
130+
last.next = middleware
131+
middlewares.append(middleware)
132+
133+
fn func(self, request: Request) -> Response:
134+
self.add(NotFoundMiddleware())
135+
var context = Context(request: request, response: response)
136+
return middlewares[0].call(context: context)
137+
138+
fn OK(body: Bytes) -> HTTPResponse:
139+
return OK(body, String("text/plain"))
140+
141+
fn OK(body: Bytes, content_type: String) -> HTTPResponse:
142+
return HTTPResponse(
143+
ResponseHeader(True, 200, String("OK").as_bytes(), content_type.as_bytes()),
144+
body,
145+
)
146+
147+
fn NotFound(body: Bytes) -> HTTPResponse:
148+
return NotFoundResponse(body, String("text/plain"))
149+
150+
fn NotFound(body: Bytes, content_type: String) -> HTTPResponse:
151+
return HTTPResponse(
152+
ResponseHeader(True, 404, String("Not Found").as_bytes(), content_type.as_bytes()),
153+
body,
154+
)
155+
156+
fn InternalServerError(body: Bytes) -> HTTPResponse:
157+
return InternalServerErrorResponse(body, String("text/plain"))
158+
159+
fn InternalServerError(body: Bytes, content_type: String) -> HTTPResponse:
160+
return HTTPResponse(
161+
ResponseHeader(True, 500, String("Internal Server Error").as_bytes(), content_type.as_bytes()),
162+
body,
163+
)
164+
165+
fn Unauthorized(body: Bytes) -> HTTPResponse:
166+
return UnauthorizedResponse(body, String("text/plain"))
167+
168+
fn Unauthorized(body: Bytes, content_type: String) -> HTTPResponse:
169+
return HTTPResponse(
170+
ResponseHeader(True, 401, String("Unauthorized").as_bytes(), content_type.as_bytes()),
171+
body,
172+
)

0 commit comments

Comments
 (0)