|
1 | 1 | use std::{
|
| 2 | + fmt::Display, |
2 | 3 | net::{IpAddr, Ipv6Addr, SocketAddr},
|
3 | 4 | sync::Arc,
|
4 | 5 | thread,
|
@@ -26,46 +27,65 @@ use super::routes::{
|
26 | 27 | };
|
27 | 28 | use crate::compute::ComputeNode;
|
28 | 29 |
|
29 |
| -async fn handle_404() -> Response { |
30 |
| - StatusCode::NOT_FOUND.into_response() |
31 |
| -} |
32 |
| - |
33 | 30 | const X_REQUEST_ID: &str = "x-request-id";
|
34 | 31 |
|
35 |
| -/// This middleware function allows compute_ctl to generate its own request ID |
36 |
| -/// if one isn't supplied. The control plane will always send one as a UUID. The |
37 |
| -/// neon Postgres extension on the other hand does not send one. |
38 |
| -async fn maybe_add_request_id_header(mut request: Request, next: Next) -> Response { |
39 |
| - let headers = request.headers_mut(); |
| 32 | +/// `compute_ctl` has two servers: internal and external. The internal server |
| 33 | +/// binds to the loopback interface and handles communication from clients on |
| 34 | +/// the compute. The external server is what receives communication from the |
| 35 | +/// control plane, the metrics scraper, etc. We make the distinction because |
| 36 | +/// certain routes in `compute_ctl` only need to be exposed to local processes |
| 37 | +/// like Postgres via the neon extension and local_proxy. |
| 38 | +#[derive(Clone, Copy, Debug)] |
| 39 | +pub enum Server { |
| 40 | + Internal(u16), |
| 41 | + External(u16), |
| 42 | +} |
40 | 43 |
|
41 |
| - if headers.get(X_REQUEST_ID).is_none() { |
42 |
| - headers.append(X_REQUEST_ID, Uuid::new_v4().to_string().parse().unwrap()); |
| 44 | +impl Display for Server { |
| 45 | + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { |
| 46 | + match self { |
| 47 | + Server::Internal(_) => f.write_str("internal"), |
| 48 | + Server::External(_) => f.write_str("external"), |
| 49 | + } |
43 | 50 | }
|
44 |
| - |
45 |
| - next.run(request).await |
46 | 51 | }
|
47 | 52 |
|
48 |
| -/// Run the HTTP server and wait on it forever. |
49 |
| -#[tokio::main] |
50 |
| -async fn serve(port: u16, compute: Arc<ComputeNode>) { |
51 |
| - let mut app = Router::new() |
52 |
| - .route("/check_writability", post(check_writability::is_writable)) |
53 |
| - .route("/configure", post(configure::configure)) |
54 |
| - .route("/database_schema", get(database_schema::get_schema_dump)) |
55 |
| - .route("/dbs_and_roles", get(dbs_and_roles::get_catalog_objects)) |
56 |
| - .route( |
57 |
| - "/extension_server/{*filename}", |
58 |
| - post(extension_server::download_extension), |
59 |
| - ) |
60 |
| - .route("/extensions", post(extensions::install_extension)) |
61 |
| - .route("/grants", post(grants::add_grant)) |
62 |
| - .route("/insights", get(insights::get_insights)) |
63 |
| - .route("/metrics", get(metrics::get_metrics)) |
64 |
| - .route("/metrics.json", get(metrics_json::get_metrics)) |
65 |
| - .route("/status", get(status::get_status)) |
66 |
| - .route("/terminate", post(terminate::terminate)) |
67 |
| - .fallback(handle_404) |
68 |
| - .layer( |
| 53 | +impl From<Server> for Router<Arc<ComputeNode>> { |
| 54 | + fn from(server: Server) -> Self { |
| 55 | + let mut router = Router::<Arc<ComputeNode>>::new(); |
| 56 | + |
| 57 | + router = match server { |
| 58 | + Server::Internal(_) => { |
| 59 | + router = router |
| 60 | + .route( |
| 61 | + "/extension_server/{*filename}", |
| 62 | + post(extension_server::download_extension), |
| 63 | + ) |
| 64 | + .route("/extensions", post(extensions::install_extension)) |
| 65 | + .route("/grants", post(grants::add_grant)); |
| 66 | + |
| 67 | + // Add in any testing support |
| 68 | + if cfg!(feature = "testing") { |
| 69 | + use super::routes::failpoints; |
| 70 | + |
| 71 | + router = router.route("/failpoints", post(failpoints::configure_failpoints)); |
| 72 | + } |
| 73 | + |
| 74 | + router |
| 75 | + } |
| 76 | + Server::External(_) => router |
| 77 | + .route("/check_writability", post(check_writability::is_writable)) |
| 78 | + .route("/configure", post(configure::configure)) |
| 79 | + .route("/database_schema", get(database_schema::get_schema_dump)) |
| 80 | + .route("/dbs_and_roles", get(dbs_and_roles::get_catalog_objects)) |
| 81 | + .route("/insights", get(insights::get_insights)) |
| 82 | + .route("/metrics", get(metrics::get_metrics)) |
| 83 | + .route("/metrics.json", get(metrics_json::get_metrics)) |
| 84 | + .route("/status", get(status::get_status)) |
| 85 | + .route("/terminate", post(terminate::terminate)), |
| 86 | + }; |
| 87 | + |
| 88 | + router.fallback(Server::handle_404).method_not_allowed_fallback(Server::handle_405).layer( |
69 | 89 | ServiceBuilder::new()
|
70 | 90 | // Add this middleware since we assume the request ID exists
|
71 | 91 | .layer(middleware::from_fn(maybe_add_request_id_header))
|
@@ -105,45 +125,92 @@ async fn serve(port: u16, compute: Arc<ComputeNode>) {
|
105 | 125 | )
|
106 | 126 | .layer(PropagateRequestIdLayer::x_request_id()),
|
107 | 127 | )
|
108 |
| - .with_state(compute); |
| 128 | + } |
| 129 | +} |
| 130 | + |
| 131 | +impl Server { |
| 132 | + async fn handle_404() -> impl IntoResponse { |
| 133 | + StatusCode::NOT_FOUND |
| 134 | + } |
| 135 | + |
| 136 | + async fn handle_405() -> impl IntoResponse { |
| 137 | + StatusCode::METHOD_NOT_ALLOWED |
| 138 | + } |
| 139 | + |
| 140 | + async fn listener(&self) -> Result<TcpListener> { |
| 141 | + let addr = SocketAddr::new(self.ip(), self.port()); |
| 142 | + let listener = TcpListener::bind(&addr).await?; |
109 | 143 |
|
110 |
| - // Add in any testing support |
111 |
| - if cfg!(feature = "testing") { |
112 |
| - use super::routes::failpoints; |
| 144 | + Ok(listener) |
| 145 | + } |
| 146 | + |
| 147 | + fn ip(&self) -> IpAddr { |
| 148 | + match self { |
| 149 | + // TODO: Change this to Ipv6Addr::LOCALHOST when the GitHub runners |
| 150 | + // allow binding to localhost |
| 151 | + Server::Internal(_) => IpAddr::from(Ipv6Addr::UNSPECIFIED), |
| 152 | + Server::External(_) => IpAddr::from(Ipv6Addr::UNSPECIFIED), |
| 153 | + } |
| 154 | + } |
113 | 155 |
|
114 |
| - app = app.route("/failpoints", post(failpoints::configure_failpoints)) |
| 156 | + fn port(self) -> u16 { |
| 157 | + match self { |
| 158 | + Server::Internal(port) => port, |
| 159 | + Server::External(port) => port, |
| 160 | + } |
115 | 161 | }
|
116 | 162 |
|
117 |
| - // This usually binds to both IPv4 and IPv6 on Linux, see |
118 |
| - // https://github.com/rust-lang/rust/pull/34440 for more information |
119 |
| - let addr = SocketAddr::new(IpAddr::from(Ipv6Addr::UNSPECIFIED), port); |
120 |
| - let listener = match TcpListener::bind(&addr).await { |
121 |
| - Ok(listener) => listener, |
122 |
| - Err(e) => { |
123 |
| - error!( |
124 |
| - "failed to bind the compute_ctl HTTP server to port {}: {}", |
125 |
| - port, e |
| 163 | + #[tokio::main] |
| 164 | + async fn serve(self, compute: Arc<ComputeNode>) { |
| 165 | + let listener = self.listener().await.unwrap_or_else(|e| { |
| 166 | + // If we can't bind, the compute cannot operate correctly |
| 167 | + panic!( |
| 168 | + "failed to bind the compute_ctl {} HTTP server to {}: {}", |
| 169 | + self, |
| 170 | + SocketAddr::new(self.ip(), self.port()), |
| 171 | + e |
| 172 | + ); |
| 173 | + }); |
| 174 | + |
| 175 | + if tracing::enabled!(tracing::Level::INFO) { |
| 176 | + let local_addr = match listener.local_addr() { |
| 177 | + Ok(local_addr) => local_addr, |
| 178 | + Err(_) => SocketAddr::new(self.ip(), self.port()), |
| 179 | + }; |
| 180 | + |
| 181 | + info!( |
| 182 | + "compute_ctl {} HTTP server listening at {}", |
| 183 | + self, local_addr |
126 | 184 | );
|
127 |
| - return; |
128 | 185 | }
|
129 |
| - }; |
130 | 186 |
|
131 |
| - if let Ok(local_addr) = listener.local_addr() { |
132 |
| - info!("compute_ctl HTTP server listening on {}", local_addr); |
133 |
| - } else { |
134 |
| - info!("compute_ctl HTTP server listening on port {}", port); |
| 187 | + let router = Router::from(self).with_state(compute); |
| 188 | + |
| 189 | + if let Err(e) = axum::serve(listener, router).await { |
| 190 | + error!("compute_ctl {} HTTP server error: {}", self, e); |
| 191 | + } |
135 | 192 | }
|
136 | 193 |
|
137 |
| - if let Err(e) = axum::serve(listener, app).await { |
138 |
| - error!("compute_ctl HTTP server error: {}", e); |
| 194 | + pub fn launch(self, compute: &Arc<ComputeNode>) { |
| 195 | + let state = Arc::clone(compute); |
| 196 | + |
| 197 | + info!("Launching the {} server", self); |
| 198 | + thread::Builder::new() |
| 199 | + .name(format!("http-server-{self}")) |
| 200 | + .spawn(move || self.serve(state)) |
| 201 | + .unwrap_or_else(|_| panic!("Failed to start the {self} HTTP server")); |
139 | 202 | }
|
140 | 203 | }
|
141 | 204 |
|
142 |
| -/// Launch a separate HTTP server thread and return its `JoinHandle`. |
143 |
| -pub fn launch_http_server(port: u16, state: &Arc<ComputeNode>) -> Result<thread::JoinHandle<()>> { |
144 |
| - let state = Arc::clone(state); |
| 205 | +/// This middleware function allows compute_ctl to generate its own request ID |
| 206 | +/// if one isn't supplied. The control plane will always send one as a UUID. The |
| 207 | +/// neon Postgres extension on the other hand does not send one. |
| 208 | +async fn maybe_add_request_id_header(mut request: Request, next: Next) -> Response { |
| 209 | + let headers = request.headers_mut(); |
145 | 210 |
|
146 |
| - Ok(thread::Builder::new() |
147 |
| - .name("http-server".into()) |
148 |
| - .spawn(move || serve(port, state))?) |
| 211 | + if headers.get(X_REQUEST_ID).is_none() { |
| 212 | + headers.append(X_REQUEST_ID, Uuid::new_v4().to_string().parse().unwrap()); |
| 213 | + } |
| 214 | + |
| 215 | + next.run(request).await |
149 | 216 | }
|
0 commit comments