Skip to content

Commit

Permalink
split api_version_negotiation
Browse files Browse the repository at this point in the history
  • Loading branch information
boxdot committed Feb 19, 2025
1 parent 54be631 commit b2c7752
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 73 deletions.
16 changes: 8 additions & 8 deletions apiclient/src/qs_api/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,10 @@ use phnxtypes::{
use thiserror::Error;
use tls_codec::{DeserializeBytes, Serialize};

use crate::{version::api_version_negotiation, ApiClient, Protocol};
use crate::{
version::{extract_api_version_negotiation, negotiate_api_version},
ApiClient, Protocol,
};

pub mod ws;

Expand Down Expand Up @@ -84,16 +87,13 @@ impl ApiClient {
let response = send_qs_message(&self.client, &endpoint, &message).await?;

// check if we need to negotiate a new API version
let Some(accepted_version) = api_version_negotiation(
&response,
api_version,
ClientToQsMessageTbs::SUPPORTED_API_VERSIONS,
)
.transpose()?
else {
let Some(accepted_versions) = extract_api_version_negotiation(&response) else {
return process_response(response).await;
};

let supported_versions = ClientToQsMessageTbs::SUPPORTED_API_VERSIONS;
let accepted_version = negotiate_api_version(accepted_versions, supported_versions)
.ok_or_else(|| VersionError::new(api_version, supported_versions.to_vec()))?;
self.negotiated_versions()
.set_qs_api_version(accepted_version);

Expand Down
110 changes: 45 additions & 65 deletions apiclient/src/version.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,7 @@ use std::{

use http::{HeaderMap, StatusCode};
use phnxtypes::{
messages::{
client_qs::{ClientToQsMessageTbs, VersionError},
ApiVersion,
},
messages::{client_qs::ClientToQsMessageTbs, ApiVersion},
ACCEPTED_API_VERSIONS_HEADER,
};
use tracing::error;
Expand Down Expand Up @@ -42,32 +39,16 @@ impl NegotiatedApiVersions {
}
}

/// Returns `Some` if the server supports a different API version, otherwise None.
///
/// If there is no API version supported by this client which is accepted by the server, the
/// returned result is an error.
pub(crate) fn api_version_negotiation(
pub(crate) fn extract_api_version_negotiation(
response: &reqwest::Response,
current_version: ApiVersion,
supported_versions: &[ApiVersion],
) -> Option<Result<ApiVersion, VersionError>> {
) -> Option<HashSet<ApiVersion>> {
if response.status() != StatusCode::NOT_ACCEPTABLE {
return None;
}

let accepted_versions = parse_accepted_versions_header(response.headers())?;

let accepted_version = negotiate_version(
accepted_versions,
supported_versions.iter().copied().collect(),
);
let accepted_version = accepted_version
.ok_or_else(|| VersionError::new(current_version, supported_versions.to_vec()));

Some(accepted_version)
parse_accepted_versions_header(response.headers())
}

fn parse_accepted_versions_header(headers: &HeaderMap) -> Option<HashSet<ApiVersion>> {
pub(crate) fn parse_accepted_versions_header(headers: &HeaderMap) -> Option<HashSet<ApiVersion>> {
let value = headers.get(ACCEPTED_API_VERSIONS_HEADER)?;
let Ok(value) = value.to_str() else {
error!(
Expand All @@ -91,13 +72,14 @@ fn parse_accepted_versions_header(headers: &HeaderMap) -> Option<HashSet<ApiVers
}

/// Returns the highest API version that is supported by both the client and the server.
fn negotiate_version(
pub(crate) fn negotiate_api_version(
accepted_versions: HashSet<ApiVersion>,
supported_versions: HashSet<ApiVersion>,
supported_versions: &[ApiVersion],
) -> Option<ApiVersion> {
accepted_versions
.intersection(&supported_versions)
.cloned()
supported_versions
.iter()
.copied()
.filter(|version| accepted_versions.contains(version))
.max()
}

Expand All @@ -106,66 +88,64 @@ mod tests {
use super::*;

#[test]
fn api_version_negotiation_needed() {
fn extract_api_version_negotiation_some() {
let response = http::response::Builder::new()
.status(StatusCode::NOT_ACCEPTABLE)
.header(ACCEPTED_API_VERSIONS_HEADER, "1,something,3")
.body(Vec::new())
.unwrap()
.into();

let v1 = ApiVersion::new(1).unwrap();
let v2 = ApiVersion::new(2).unwrap();
let v3 = ApiVersion::new(3).unwrap();
let v4 = ApiVersion::new(4).unwrap();

let current_version = v1;
assert_eq!(
api_version_negotiation(&response, current_version, &[v1])
.transpose()
.unwrap(),
Some(v1)
);
assert_eq!(
api_version_negotiation(&response, current_version, &[v1, v3])
.transpose()
.unwrap(),
Some(v3)
);
assert_eq!(
api_version_negotiation(&response, current_version, &[v1, v2, v3, v4])
.transpose()
.unwrap(),
Some(v3)
);
assert!(
api_version_negotiation(&response, current_version, &[v2, v4])
.transpose()
.is_err()
extract_api_version_negotiation(&response),
Some(
[ApiVersion::new(1).unwrap(), ApiVersion::new(3).unwrap()]
.into_iter()
.collect()
),
);
}

#[test]
fn api_version_negotiation_not_needed() {
fn extract_api_version_negotiation_status_ok() {
let response = http::response::Builder::new()
.status(StatusCode::OK)
.body(Vec::new())
.unwrap()
.into();

let v1 = ApiVersion::new(1).unwrap();
assert!(api_version_negotiation(&response, v1, &[v1]).is_none());
assert!(extract_api_version_negotiation(&response).is_none());
}

#[test]
fn api_version_negotiation_header_missing() {
fn extract_api_version_negotiation_header_missing() {
let response = http::response::Builder::new()
.status(StatusCode::NOT_ACCEPTABLE)
.body(Vec::new())
.unwrap()
.into();
assert!(extract_api_version_negotiation(&response).is_none());
}

let v1 = ApiVersion::new(1).unwrap();
assert!(api_version_negotiation(&response, v1, &[v1]).is_none());
#[test]
fn negotiate_api_version_success() {
let accepted_versions = [ApiVersion::new(1).unwrap(), ApiVersion::new(3).unwrap()]
.into_iter()
.collect();
let supported_versions = &[ApiVersion::new(1).unwrap(), ApiVersion::new(2).unwrap()];
assert_eq!(
negotiate_api_version(accepted_versions, supported_versions),
Some(ApiVersion::new(1).unwrap())
);
}

#[test]
fn negotiate_api_version_failure() {
let accepted_versions = [ApiVersion::new(2).unwrap(), ApiVersion::new(4).unwrap()]
.into_iter()
.collect();
let supported_versions = &[ApiVersion::new(1).unwrap(), ApiVersion::new(0).unwrap()];
assert_eq!(
negotiate_api_version(accepted_versions, supported_versions),
None
);
}
}

0 comments on commit b2c7752

Please sign in to comment.