Skip to content

Commit 9ee14bd

Browse files
committed
refactor session retrieval
1 parent 9639fcb commit 9ee14bd

File tree

3 files changed

+50
-67
lines changed

3 files changed

+50
-67
lines changed

crates/handlers/src/upstream_oauth2/cookie.rs

-4
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,6 @@ impl UpstreamSessions {
6565
pub fn is_empty(&self) -> bool {
6666
self.0.is_empty()
6767
}
68-
/// Returns the session IDs in the cookie
69-
pub fn session_ids(&self) -> Vec<Ulid> {
70-
self.0.iter().map(|p| p.session).collect()
71-
}
7268

7369
/// Save the upstreams sessions to the cookie jar
7470
pub fn save<C>(self, cookie_jar: CookieJar, clock: &C) -> CookieJar

crates/handlers/src/upstream_oauth2/logout.rs

+47-61
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,14 @@
33
// SPDX-License-Identifier: AGPL-3.0-only
44
// Please see LICENSE in the repository root for full details.
55

6-
use mas_axum_utils::cookies::CookieJar;
6+
use mas_data_model::{AuthenticationMethod, BrowserSession};
77
use mas_router::UrlBuilder;
88
use mas_storage::{RepositoryAccess, upstream_oauth2::UpstreamOAuthProviderRepository};
99
use serde::{Deserialize, Serialize};
1010
use thiserror::Error;
11-
use tracing::{error, warn};
11+
use tracing::error;
1212
use url::Url;
1313

14-
use super::UpstreamSessionsCookie;
1514
use crate::impl_from_error_for_route;
1615

1716
#[derive(Serialize, Deserialize)]
@@ -69,7 +68,7 @@ impl From<reqwest::Error> for RouteError {
6968
pub async fn get_rp_initiated_logout_endpoints<E>(
7069
url_builder: &UrlBuilder,
7170
repo: &mut impl RepositoryAccess<Error = E>,
72-
cookie_jar: &CookieJar,
71+
browser_session: &BrowserSession,
7372
) -> Result<UpstreamLogoutInfo, RouteError>
7473
where
7574
RouteError: std::convert::From<E>,
@@ -81,68 +80,55 @@ where
8180
.to_string();
8281
result.post_logout_redirect_uri = Some(post_logout_redirect_uri.clone());
8382

84-
let sessions_cookie = UpstreamSessionsCookie::load(cookie_jar);
85-
// Standard location for OIDC end session endpoint
86-
let session_ids = sessions_cookie.session_ids();
87-
if session_ids.is_empty() {
88-
return Ok(result);
89-
}
90-
// We only support the first upstream session
91-
let mut provider = None;
92-
let mut upstream_session = None;
93-
for session_id in session_ids {
94-
// Get the session and assign its value, wrapped in Some
95-
let session = repo
96-
.upstream_oauth_session()
97-
.lookup(session_id)
98-
.await?
99-
.ok_or(RouteError::SessionNotFound)?;
100-
// Get the provider and assign its value, wrapped in Some
101-
let prov = repo
102-
.upstream_oauth_provider()
103-
.lookup(session.provider_id)
104-
.await?
105-
.ok_or(RouteError::ProviderNotFound)?;
83+
let upstream_oauth2_session_id = repo
84+
.browser_session()
85+
.get_last_authentication(browser_session)
86+
.await?
87+
.ok_or(RouteError::SessionNotFound)
88+
.map(|auth| match auth.authentication_method {
89+
AuthenticationMethod::UpstreamOAuth2 {
90+
upstream_oauth2_session_id,
91+
} => Some(upstream_oauth2_session_id),
92+
_ => None,
93+
})?
94+
.ok_or(RouteError::SessionNotFound)?;
10695

107-
if prov.allow_rp_initiated_logout {
108-
upstream_session = Some(session);
109-
provider = Some(prov);
110-
break;
111-
}
112-
}
96+
// Get the session and assign its value, wrapped in Some
97+
let upstream_session = repo
98+
.upstream_oauth_session()
99+
.lookup(upstream_oauth2_session_id)
100+
.await?
101+
.ok_or(RouteError::SessionNotFound)?;
102+
// Get the provider and assign its value, wrapped in Some
103+
let provider = repo
104+
.upstream_oauth_provider()
105+
.lookup(upstream_session.provider_id)
106+
.await?
107+
.filter(|provider| provider.allow_rp_initiated_logout)
108+
.ok_or(RouteError::ProviderNotFound)?;
113109

114-
// Check if we found a provider with allow_rp_initiated_logout
115-
if let Some(provider) = provider {
116-
// Look for end session endpoint
117-
// In a real implementation, we'd have end_session_endpoint fields in the
118-
// provider For now, we'll try to construct one from the issuer if
119-
// available
120-
if let Some(issuer) = &provider.issuer {
121-
let end_session_endpoint = format!("{issuer}/protocol/openid-connect/logout");
122-
let mut logout_url = end_session_endpoint;
123-
// Add post_logout_redirect_uri
124-
if let Some(post_uri) = &result.post_logout_redirect_uri {
125-
if let Ok(mut url) = Url::parse(&logout_url) {
126-
url.query_pairs_mut()
127-
.append_pair("post_logout_redirect_uri", post_uri);
128-
url.query_pairs_mut()
129-
.append_pair("client_id", &provider.client_id);
130-
// Add id_token_hint if available
131-
if let Some(session) = &upstream_session {
132-
if let Some(id_token) = session.id_token() {
133-
url.query_pairs_mut().append_pair("id_token_hint", id_token);
134-
}
135-
}
136-
logout_url = url.to_string();
110+
// Look for end session endpoint
111+
// In a real implementation, we'd have end_session_endpoint fields in the
112+
// provider For now, we'll try to construct one from the issuer if
113+
// available
114+
if let Some(issuer) = &provider.issuer {
115+
let end_session_endpoint = format!("{issuer}/protocol/openid-connect/logout");
116+
let mut logout_url = end_session_endpoint;
117+
// Add post_logout_redirect_uri
118+
if let Some(post_uri) = &result.post_logout_redirect_uri {
119+
if let Ok(mut url) = Url::parse(&logout_url) {
120+
url.query_pairs_mut()
121+
.append_pair("post_logout_redirect_uri", post_uri);
122+
url.query_pairs_mut()
123+
.append_pair("client_id", &provider.client_id);
124+
// Add id_token_hint if available
125+
if let Some(id_token) = upstream_session.id_token() {
126+
url.query_pairs_mut().append_pair("id_token_hint", id_token);
137127
}
128+
logout_url = url.to_string();
138129
}
139-
result.logout_endpoints.clone_from(&logout_url);
140-
} else {
141-
warn!(
142-
upstream_oauth_provider.id = %provider.id,
143-
"Provider has no issuer defined, cannot construct RP-initiated logout URL"
144-
);
145130
}
131+
result.logout_endpoints.clone_from(&logout_url);
146132
}
147133
Ok(result)
148134
}

crates/handlers/src/views/logout.rs

+3-2
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,9 @@ pub(crate) async fn post(
4343

4444
// First, get RP-initiated logout endpoints before actually finishing the
4545
// session
46-
match get_rp_initiated_logout_endpoints(&url_builder, &mut repo, &cookie_jar).await
47-
{
46+
// match get_rp_initiated_logout_endpoints(&url_builder, &mut repo,
47+
// &cookie_jar).await
48+
match get_rp_initiated_logout_endpoints(&url_builder, &mut repo, &session).await {
4849
Ok(logout_info) => {
4950
// If we have any RP-initiated logout endpoints, use the first one
5051
if !logout_info.logout_endpoints.is_empty() {

0 commit comments

Comments
 (0)