From c3371d9191835f61d8ecbd4a15b1183de0171fb8 Mon Sep 17 00:00:00 2001 From: Artem Goncharov Date: Mon, 15 Jun 2026 19:14:46 +0200 Subject: [PATCH] fix(auth-core): Improve auth-core quality - Implement AuthTokenScope::matches() for wildcard cache lookup (None matches any value) - Update get_scope_auth/find_scope_authz to iterate with matches() instead of HashMap::get - Rename AuthResponse::token field type from AuthToken to TokenInfo - Change OpenStackAuthType::auth trait signature to take &HashMap (borrow instead of owned) - Rename AuthToken::set_header return type from Result<&'a mut HeaderMap> to Result<()> - Improve from_reqwest_response error handling for non-JSON and identity API errors - Fix federation and websso callback server test race conditions with oneshot channels - Add 18 AuthTokenScope::matches() tests and 9 state cache tests - Reorder struct fields alphabetically per project standards --- openstack_sdk/src/openstack.rs | 8 +- openstack_sdk/src/openstack_async.rs | 10 +- openstack_tui/src/action.rs | 2 +- sdk/auth-application-credential/src/lib.rs | 6 +- sdk/auth-core/src/authtoken.rs | 400 +++++++++- sdk/auth-core/src/authtoken_scope.rs | 258 +++++- sdk/auth-core/src/lib.rs | 882 ++++++++++++++++++++- sdk/auth-core/src/types.rs | 136 ++-- sdk/auth-federation/src/lib.rs | 28 +- sdk/auth-jwt/src/lib.rs | 2 +- sdk/auth-multifactor/src/lib.rs | 8 +- sdk/auth-oidcaccesstoken/src/lib.rs | 2 +- sdk/auth-passkey/src/lib.rs | 2 +- sdk/auth-password/src/lib.rs | 8 +- sdk/auth-receipt/src/lib.rs | 6 +- sdk/auth-token/src/lib.rs | 4 +- sdk/auth-totp/src/lib.rs | 8 +- sdk/auth-websso/src/lib.rs | 84 +- sdk/core/src/state.rs | 256 ++++-- 19 files changed, 1896 insertions(+), 214 deletions(-) diff --git a/openstack_sdk/src/openstack.rs b/openstack_sdk/src/openstack.rs index cda7e108b..98064d758 100644 --- a/openstack_sdk/src/openstack.rs +++ b/openstack_sdk/src/openstack.rs @@ -280,7 +280,7 @@ impl OpenStack { &client, self.get_service_endpoint(&ServiceType::Identity, Some(&ApiVersion::from((3, 0))))? .url(), - HashMap::from([("token".into(), auth.token.clone())]), + &HashMap::from([("token".into(), auth.token.clone())]), Some(scope), None, ), @@ -394,7 +394,7 @@ impl OpenStack { Some(&ApiVersion::from(authenticator.api_version())), )? .url(), - auth_data, + &auth_data, Some(&requested_scope), auth_hints.as_ref(), ), @@ -424,7 +424,7 @@ impl OpenStack { Some(&ApiVersion::from(authenticator.api_version())), )? .url(), - auth_data, + &auth_data, Some(&requested_scope), Some(&auth_hints), ), @@ -766,7 +766,7 @@ mod tests { }; let token_info = openstack_sdk_auth_core::types::AuthResponse { - token: openstack_sdk_auth_core::types::AuthToken { + token: openstack_sdk_auth_core::types::TokenInfo { expires_at: Utc::now() + chrono::TimeDelta::hours(1), project: Some(openstack_sdk_auth_core::types::Project { id: Some("test-project".into()), diff --git a/openstack_sdk/src/openstack_async.rs b/openstack_sdk/src/openstack_async.rs index f9842fc19..4ecfaab67 100644 --- a/openstack_sdk/src/openstack_async.rs +++ b/openstack_sdk/src/openstack_async.rs @@ -461,7 +461,7 @@ impl AsyncOpenStack { self.get_service_endpoint(&ServiceType::Identity, Some(&ApiVersion::from((3, 0)))) .await? .url(), - HashMap::from([("token".into(), auth.token.clone())]), + &HashMap::from([("token".into(), auth.token.clone())]), Some(scope), None, ) @@ -579,7 +579,7 @@ impl AsyncOpenStack { ) .await? .url(), - gather_auth_data( + &gather_auth_data( &authenticator.requirements(auth_hints.as_ref())?, &self.config, auth_helper, @@ -608,7 +608,7 @@ impl AsyncOpenStack { ) .await? .url(), - gather_auth_data( + &gather_auth_data( &token_receipt::PLUGIN.requirements(Some(&auth_hints))?, &self.config, auth_helper, @@ -680,7 +680,7 @@ impl AsyncOpenStack { } } } else { - return Err(AuthError::AuthTokenNotInResponse)?; + return Err(OpenStackError::NoAuth)?; } { @@ -1045,7 +1045,7 @@ mod tests { }; let token_info = AuthResponse { - token: openstack_sdk_auth_core::types::AuthToken { + token: openstack_sdk_auth_core::types::TokenInfo { expires_at: Utc::now() + chrono::TimeDelta::hours(1), project: Some(openstack_sdk_auth_core::types::Project { id: Some("test-project".into()), diff --git a/openstack_tui/src/action.rs b/openstack_tui/src/action.rs index a6dc07fa2..6ae50ced3 100644 --- a/openstack_tui/src/action.rs +++ b/openstack_tui/src/action.rs @@ -44,7 +44,7 @@ pub enum Action { /// Request rescoping current connection CloudChangeScope(openstack_sdk::auth::authtoken::AuthTokenScope), /// New cloud connection established - ConnectedToCloud(Box), + ConnectedToCloud(Box), /// Perform API request PerformApiRequest(cloud_types::ApiRequest), /// Propagate single resource data to components diff --git a/sdk/auth-application-credential/src/lib.rs b/sdk/auth-application-credential/src/lib.rs index df61156da..62ba5c556 100644 --- a/sdk/auth-application-credential/src/lib.rs +++ b/sdk/auth-application-credential/src/lib.rs @@ -121,7 +121,7 @@ impl OpenStackAuthType for AppilcationCredentialAuthenticator { &self, http_client: &reqwest::Client, identity_url: &url::Url, - values: std::collections::HashMap, + values: &std::collections::HashMap, _scope: Option<&AuthTokenScope>, _hints: Option<&serde_json::Value>, ) -> Result { @@ -253,7 +253,7 @@ mod tests { .auth( &http_client, &base_url, - HashMap::from([ + &HashMap::from([ ( "application_credential_id".into(), SecretString::from("app_cred_id"), @@ -282,7 +282,7 @@ mod tests { .auth( &http_client, &base_url, - HashMap::from([( + &HashMap::from([( "application_credential_secret".into(), SecretString::from("secret"), )]), diff --git a/sdk/auth-core/src/authtoken.rs b/sdk/auth-core/src/authtoken.rs index 40d04b67b..4a59e93b8 100644 --- a/sdk/auth-core/src/authtoken.rs +++ b/sdk/auth-core/src/authtoken.rs @@ -163,7 +163,7 @@ impl From<&str> for AuthToken { impl Debug for AuthToken { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.debug_struct("Auth") + f.debug_struct("AuthToken") .field("data", &self.auth_info) .finish() } @@ -183,15 +183,12 @@ impl AuthToken { /// Adds X-Auth-Token header to a request headers. /// /// Returns an error if the token string cannot be parsed as a header value. - pub fn set_header<'a>( - &self, - headers: &'a mut HeaderMap, - ) -> AuthResult<&'a mut HeaderMap> { + pub fn set_header(&self, headers: &mut HeaderMap) -> AuthResult<()> { let mut token_header_value = HeaderValue::from_str(self.token.expose_secret())?; token_header_value.set_sensitive(true); headers.insert("X-Auth-Token", token_header_value); - Ok(headers) + Ok(()) } /// Detect authentication validity (valid/expired/unset) @@ -228,18 +225,29 @@ impl AuthToken { /// Parse [`Response`] into the AuthToken. pub async fn from_reqwest_response(response: Response) -> Result { if !response.status().is_success() { - // Handle the MFA let status = response.status(); if let StatusCode::UNAUTHORIZED = status - && let Some(receipt) = response.headers().get("openstack-auth-receipt") + && let Some(receipt_header) = response.headers().get("openstack-auth-receipt") { - let receipt_token = receipt + let receipt_token = receipt_header .to_str() .map_err(|_| AuthError::AuthReceiptNotString)? .into(); - let mut receipt: AuthReceiptResponse = response.json().await?; - receipt.token = Some(receipt_token); - return Err(AuthError::AuthReceipt(receipt)); + + let body_text = response.text().await?; + if let Ok(mut receipt) = serde_json::from_str::(&body_text) { + receipt.token = Some(receipt_token); + return Err(AuthError::AuthReceipt(Box::new(receipt))); + } + + if let Ok(data) = serde_json::from_str::(&body_text) { + return Err(AuthError::Identity(data.error)); + } else { + return Err(AuthError::UnknownAuth { + code: status.into(), + message: Some(body_text), + }); + } } let body = response.text().await?; @@ -257,9 +265,13 @@ impl AuthToken { let token = response .headers() .get("x-subject-token") - .ok_or(AuthError::AuthTokenNotInResponse)? + .ok_or(AuthError::AuthToken { + source: AuthTokenError::AuthTokenNotInResponse, + })? .to_str() - .map_err(|_| AuthError::AuthTokenNotString)? + .map_err(|_| AuthError::AuthToken { + source: AuthTokenError::AuthTokenNotString, + })? .to_string(); let token_info: AuthResponse = response.json::().await?; @@ -296,7 +308,12 @@ mod tests { use super::AuthError; use super::AuthToken; + use super::AuthTokenError; + use crate::AuthState; + use crate::authtoken_scope::AuthTokenScope; + use crate::authtoken_scope::AuthTokenScopeError; use crate::types::*; + use http::HeaderMap; #[tokio::test] async fn test_from_reqwest_response_receipt() { @@ -328,7 +345,7 @@ mod tests { Err(AuthError::AuthReceipt(receipt)) => { let mut expected = auth_receipt.clone(); expected.token = Some("foobar".into()); - assert_eq!(expected, receipt); + assert_eq!(expected, *receipt); } other => { panic!("wrong response for the expected receipt error: {:?}", other); @@ -400,10 +417,361 @@ mod tests { let rsp = AuthToken::from_reqwest_response(response).await; match rsp { - Err(AuthError::AuthTokenNotInResponse) => {} + Err(AuthError::AuthToken { + source: AuthTokenError::AuthTokenNotInResponse, + }) => {} other => { panic!("wrong response: {:?}", other); } } } + + #[tokio::test] + async fn test_from_reqwest_response_malformed_json_success() { + let http_response = Builder::new() + .status(201) + .header("content-type", "application/json") + .header("x-subject-token", "foobar") + .body(String::from("{invalid")) + .unwrap(); + + let response: Response = Response::from(http_response); + + let rsp = AuthToken::from_reqwest_response(response).await; + // reqwest wraps serde errors in reqwest::Error, so it becomes AuthError::Serde or AuthError::Reqwest + assert!(matches!( + rsp, + Err(AuthError::Serde { .. }) | Err(AuthError::Reqwest { .. }) + )); + } + + #[tokio::test] + async fn test_from_reqwest_response_500_json_error() { + let err = AuthErrorResponse { + error: IdentityError { + code: 500, + message: "internal server error".into(), + }, + }; + let http_response = Builder::new() + .status(500) + .header("content-type", "application/json") + .body(to_string(&err).unwrap()) + .unwrap(); + + let response: Response = Response::from(http_response); + + let rsp = AuthToken::from_reqwest_response(response).await; + match rsp { + Err(AuthError::Identity(e)) => assert_eq!(e.code, 500), + other => panic!("wrong response: {:?}", other), + } + } + + #[tokio::test] + async fn test_from_reqwest_response_502_plain_text() { + let http_response = Builder::new() + .status(502) + .body(String::from("upstream timeout")) + .unwrap(); + + let response: Response = Response::from(http_response); + + let rsp = AuthToken::from_reqwest_response(response).await; + match rsp { + Err(AuthError::UnknownAuth { code, message }) => { + assert_eq!(code, 502); + assert_eq!(message, Some("upstream timeout".into())); + } + other => panic!("wrong response: {:?}", other), + } + } + + #[tokio::test] + async fn test_from_reqwest_response_403_identity_error() { + let err = AuthErrorResponse { + error: IdentityError { + code: 403, + message: "forbidden".into(), + }, + }; + let http_response = Builder::new() + .status(403) + .header("content-type", "application/json") + .body(to_string(&err).unwrap()) + .unwrap(); + + let response: Response = Response::from(http_response); + + let rsp = AuthToken::from_reqwest_response(response).await; + match rsp { + Err(AuthError::Identity(e)) => assert_eq!(e.code, 403), + other => panic!("wrong response: {:?}", other), + } + } + + #[tokio::test] + async fn test_from_reqwest_response_401_receipt_malformed_json() { + let http_response = Builder::new() + .status(401) + .header("openstack-auth-receipt", "foobar") + .body(String::from("{malformed")) + .unwrap(); + + let response: Response = Response::from(http_response); + + let rsp = AuthToken::from_reqwest_response(response).await; + match rsp { + Err(AuthError::UnknownAuth { code, message }) => { + assert_eq!(code, 401); + assert_eq!(message, Some("{malformed".into())); + } + other => panic!("wrong response: {:?}", other), + } + } + + #[tokio::test] + async fn test_from_reqwest_response_status_200() { + let auth = AuthResponse::default(); + let http_response = Builder::new() + .status(200) + .header("content-type", "application/json") + .header("x-subject-token", "tok") + .body(to_string(&auth).unwrap()) + .unwrap(); + + let response: Response = Response::from(http_response); + + let rsp = AuthToken::from_reqwest_response(response).await; + assert!(rsp.is_ok()); + assert_eq!(rsp.unwrap().token.expose_secret(), "tok"); + } + + #[test] + fn test_auth_token_from_str() { + let auth: AuthToken = "my-secret-token".into(); + assert_eq!("my-secret-token", auth.token.expose_secret()); + assert!(auth.auth_info.is_none()); + } + + #[test] + fn test_auth_token_set_header_success() { + let auth = AuthToken::new("valid-token", None); + let mut headers = HeaderMap::new(); + auth.set_header(&mut headers).unwrap(); + let val = headers.get("X-Auth-Token").expect("header missing"); + assert_eq!(val.to_str().unwrap(), "valid-token"); + } + + #[test] + fn test_auth_token_debug_no_secret_leak() { + let auth = AuthToken::new("super-secret", None); + let debug = format!("{:?}", auth); + assert!(debug.contains("AuthToken")); + assert!(!debug.contains("super-secret")); + } + + #[test] + fn test_auth_token_get_scope_unscoped() { + let auth = AuthToken::new("tok", None); + assert!(matches!(auth.get_scope(), AuthTokenScope::Unscoped)); + } + + #[test] + fn test_auth_token_get_scope_project() { + let auth = AuthToken::new( + "tok", + Some(AuthResponse { + token: TokenInfo { + project: Some(Project { + id: Some("1".into()), + name: Some("p".into()), + domain: None, + }), + ..Default::default() + }, + }), + ); + assert!(matches!(auth.get_scope(), AuthTokenScope::Project(_))); + } + + #[test] + fn test_auth_token_get_state_zero_offset() { + let auth = AuthToken::new( + String::new(), + Some(AuthResponse { + token: TokenInfo { + expires_at: chrono::Utc::now() + chrono::TimeDelta::milliseconds(500), + ..Default::default() + }, + }), + ); + // With zero offset, soon_expiration == expiration, so AboutToExpire is unreachable + assert!(matches!( + auth.get_state(Some(chrono::TimeDelta::zero())), + AuthState::Valid + )); + } + + #[test] + fn test_auth_token_get_state_expires_exactly_at_offset() { + let expires = chrono::Utc::now() + chrono::TimeDelta::minutes(10); + let auth = AuthToken::new( + String::new(), + Some(AuthResponse { + token: TokenInfo { + expires_at: expires, + ..Default::default() + }, + }), + ); + assert!(matches!( + auth.get_state(Some(chrono::TimeDelta::minutes(10))), + AuthState::AboutToExpire + )); + } + + #[test] + fn test_auth_token_get_state_no_offset_valid() { + let auth = AuthToken::new( + String::new(), + Some(AuthResponse { + token: TokenInfo { + expires_at: chrono::Utc::now() + chrono::TimeDelta::seconds(1), + ..Default::default() + }, + }), + ); + assert!(matches!(auth.get_state(None), AuthState::Valid)); + } + + #[test] + fn test_auth_token_get_state_expires_exactly_now() { + let now = chrono::Utc::now(); + let auth = AuthToken::new( + String::new(), + Some(AuthResponse { + token: TokenInfo { + expires_at: now, + ..Default::default() + }, + }), + ); + assert!(matches!( + auth.get_state(Some(chrono::TimeDelta::minutes(1))), + AuthState::Expired + )); + } + + #[test] + fn test_try_from_http_response_success() { + let auth = AuthResponse::default(); + let body_bytes: bytes::Bytes = to_string(&auth).unwrap().into_bytes().into(); + + let http_response = http::Response::builder() + .status(201) + .header("content-type", "application/json") + .header("x-subject-token", "tok") + .body(body_bytes) + .unwrap(); + + let result = AuthToken::try_from(http_response); + assert!(result.is_ok()); + let token = result.unwrap(); + assert_eq!("tok", token.token.expose_secret()); + } + + #[test] + fn test_try_from_http_response_no_token() { + let auth = AuthResponse::default(); + let body_bytes: bytes::Bytes = to_string(&auth).unwrap().into_bytes().into(); + + let http_response = http::Response::builder() + .status(201) + .header("content-type", "application/json") + .body(body_bytes) + .unwrap(); + + let result = AuthToken::try_from(http_response); + assert!(matches!( + result, + Err(AuthTokenError::AuthTokenNotInResponse) + )); + } + + #[test] + fn test_try_from_http_response_malformed_json() { + let http_response = http::Response::builder() + .status(201) + .header("x-subject-token", "tok") + .body(bytes::Bytes::from_static(b"{bad")) + .unwrap(); + + let result = AuthToken::try_from(http_response); + assert!(matches!(result, Err(AuthTokenError::Serde { .. }))); + } + + #[test] + fn test_auth_token_error_auth_request() { + #[derive(Debug)] + struct MyErr(&'static str); + impl std::fmt::Display for MyErr { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } + } + impl std::error::Error for MyErr {} + + let err = AuthTokenError::auth_request(MyErr("req")); + assert!(matches!(err, AuthTokenError::AuthRequest { .. })); + assert!(format!("{}", err).contains("req")); + } + + #[test] + fn test_auth_token_error_plugin() { + #[derive(Debug)] + struct MyErr(&'static str); + impl std::fmt::Display for MyErr { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } + } + impl std::error::Error for MyErr {} + + let err = AuthTokenError::plugin(MyErr("plug")); + assert!(matches!(err, AuthTokenError::Plugin { .. })); + assert!(format!("{}", err).contains("plug")); + } + + #[test] + fn test_from_auth_token_scope_error_to_auth_error() { + let scope_err = AuthTokenScopeError::MissingScope; + let auth_err: AuthError = scope_err.into(); + assert!(matches!( + auth_err, + AuthError::AuthToken { + source: AuthTokenError::Scope { .. } + } + )); + } + + #[test] + fn test_auth_token_error_identity_method_display() { + let err = AuthTokenError::IdentityMethod { + auth_type: "jwt".into(), + }; + assert!(format!("{}", err).contains("jwt")); + } + + #[test] + fn test_auth_token_error_missing_data_display() { + let err = AuthTokenError::MissingAuthData; + assert_eq!(format!("{}", err), "Auth data is missing"); + } + + #[test] + fn test_auth_token_error_missing_url_display() { + let err = AuthTokenError::MissingAuthUrl; + assert_eq!(format!("{}", err), "Auth URL is missing"); + } } diff --git a/sdk/auth-core/src/authtoken_scope.rs b/sdk/auth-core/src/authtoken_scope.rs index b6d006796..c852bdfb2 100644 --- a/sdk/auth-core/src/authtoken_scope.rs +++ b/sdk/auth-core/src/authtoken_scope.rs @@ -49,7 +49,7 @@ pub enum AuthTokenScopeError { } /// Represents AuthToken authorization scope -#[derive(Clone, Deserialize, Eq, Hash, PartialEq, Serialize, Debug)] +#[derive(Clone, Deserialize, Eq, Hash, PartialEq, Serialize, Debug, Default)] #[serde(rename_all = "lowercase")] pub enum AuthTokenScope { /// Project @@ -59,5 +59,261 @@ pub enum AuthTokenScope { /// System System(System), /// Unscoped + #[default] Unscoped, } + +impl AuthTokenScope { + /// Checks if this scope (the requested scope) matches another scope (the cached scope). + /// This implements "wildcard" matching: if a field in the requested scope is `None`, + /// it matches any value in the cached scope. + pub fn matches(&self, cached: &Self) -> bool { + match (self, cached) { + (AuthTokenScope::Project(req), AuthTokenScope::Project(cached)) => { + let id_match = req + .id + .as_ref() + .is_none_or(|id| cached.id.as_ref() == Some(id)); + let name_match = req + .name + .as_ref() + .is_none_or(|name| cached.name.as_ref() == Some(name)); + let domain_match = if let Some(req_domain) = &req.domain { + if let Some(cached_domain) = &cached.domain { + let d_id_match = req_domain + .id + .as_ref() + .is_none_or(|id| cached_domain.id.as_ref() == Some(id)); + let d_name_match = req_domain + .name + .as_ref() + .is_none_or(|name| cached_domain.name.as_ref() == Some(name)); + d_id_match && d_name_match + } else { + true + } + } else { + true + }; + id_match && name_match && domain_match + } + (AuthTokenScope::Domain(req), AuthTokenScope::Domain(cached)) => { + let id_match = req + .id + .as_ref() + .is_none_or(|id| cached.id.as_ref() == Some(id)); + let name_match = req + .name + .as_ref() + .is_none_or(|name| cached.name.as_ref() == Some(name)); + id_match && name_match + } + (AuthTokenScope::Unscoped, AuthTokenScope::Unscoped) => true, + (AuthTokenScope::System(_), AuthTokenScope::System(_)) => true, + _ => false, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn make_project_scope(id: Option<&str>, name: Option<&str>) -> AuthTokenScope { + AuthTokenScope::Project(Project { + id: id.map(|s| s.to_string()), + name: name.map(|s| s.to_string()), + domain: None, + }) + } + + fn make_domain_scope(id: Option<&str>, name: Option<&str>) -> AuthTokenScope { + AuthTokenScope::Domain(Domain { + id: id.map(|s| s.to_string()), + name: name.map(|s| s.to_string()), + }) + } + + // Project scope tests + #[test] + fn test_matches_project_exact_id() { + let req = make_project_scope(Some("proj-1"), None); + let cached = make_project_scope(Some("proj-1"), None); + assert!(req.matches(&cached)); + } + + #[test] + fn test_matches_project_id_mismatch() { + let req = make_project_scope(Some("proj-1"), None); + let cached = make_project_scope(Some("proj-2"), None); + assert!(!req.matches(&cached)); + } + + #[test] + fn test_matches_project_req_id_none_matches_cached_id_some() { + let req = make_project_scope(None, Some("myproj")); + let cached = make_project_scope(Some("proj-1"), Some("myproj")); + assert!(req.matches(&cached)); + } + + #[test] + fn test_matches_project_req_name_none_matches_cached_name() { + let req = make_project_scope(None, None); + let cached = make_project_scope(Some("proj-1"), Some("myproj")); + assert!(req.matches(&cached)); + } + + #[test] + fn test_matches_project_name_mismatch_when_req_name_some() { + let req = make_project_scope(None, Some("req-name")); + let cached = make_project_scope(Some("proj-1"), Some("cached-name")); + assert!(!req.matches(&cached)); + } + + #[test] + fn test_matches_project_both_none() { + let req = make_project_scope(None, None); + let cached = make_project_scope(None, None); + assert!(req.matches(&cached)); + } + + // Domain scope tests + #[test] + fn test_matches_domain_exact_id() { + let req = make_domain_scope(Some("d1"), None); + let cached = make_domain_scope(Some("d1"), None); + assert!(req.matches(&cached)); + } + + #[test] + fn test_matches_domain_id_mismatch() { + let req = make_domain_scope(Some("d1"), None); + let cached = make_domain_scope(Some("d2"), None); + assert!(!req.matches(&cached)); + } + + #[test] + fn test_matches_domain_req_id_none_matches_cached() { + let req = make_domain_scope(None, Some("D")); + let cached = make_domain_scope(Some("d1"), Some("D")); + assert!(req.matches(&cached)); + } + + #[test] + fn test_matches_domain_name_mismatch_when_req_name_some() { + let req = make_domain_scope(None, Some("req-D")); + let cached = make_domain_scope(Some("d1"), Some("cached-D")); + assert!(!req.matches(&cached)); + } + + // Cross-type tests + #[test] + fn test_matches_project_does_not_match_domain() { + let req = make_project_scope(Some("p1"), None); + let cached = make_domain_scope(Some("p1"), None); + assert!(!req.matches(&cached)); + } + + #[test] + fn test_matches_domain_does_not_match_project() { + let req = make_domain_scope(Some("d1"), None); + let cached = make_project_scope(Some("d1"), None); + assert!(!req.matches(&cached)); + } + + #[test] + fn test_matches_unscoped_matches_unscoped() { + let req = AuthTokenScope::Unscoped; + let cached = AuthTokenScope::Unscoped; + assert!(req.matches(&cached)); + } + + #[test] + fn test_matches_unscoped_does_not_match_project() { + let req = AuthTokenScope::Unscoped; + let cached = make_project_scope(Some("p1"), None); + assert!(!req.matches(&cached)); + } + + // Domain-within-project tests + #[test] + fn test_matches_project_with_domain_exact() { + let req = AuthTokenScope::Project(Project { + id: Some("project-id".to_string()), + name: None, + domain: Some(Domain { + id: Some("domain-id".to_string()), + name: None, + }), + }); + let cached = AuthTokenScope::Project(Project { + id: Some("project-id".to_string()), + name: None, + domain: Some(Domain { + id: Some("domain-id".to_string()), + name: Some("D".to_string()), + }), + }); + assert!(req.matches(&cached)); + } + + #[test] + fn test_matches_project_with_domain_mismatch() { + let req = AuthTokenScope::Project(Project { + id: Some("project-id".to_string()), + name: None, + domain: Some(Domain { + id: Some("domain-1".to_string()), + name: None, + }), + }); + let cached = AuthTokenScope::Project(Project { + id: Some("project-id".to_string()), + name: None, + domain: Some(Domain { + id: Some("domain-2".to_string()), + name: Some("D".to_string()), + }), + }); + assert!(!req.matches(&cached)); + } + + #[test] + fn test_matches_project_cached_has_domain_req_has_no_domain() { + let req = AuthTokenScope::Project(Project { + id: Some("project-id".to_string()), + name: None, + domain: None, + }); + let cached = AuthTokenScope::Project(Project { + id: Some("project-id".to_string()), + name: None, + domain: Some(Domain { + id: Some("domain-id".to_string()), + name: None, + }), + }); + assert!(req.matches(&cached)); + } + + #[test] + fn test_matches_project_req_domain_by_name() { + let req = AuthTokenScope::Project(Project { + id: None, + name: Some("myproj".to_string()), + domain: Some(Domain { + id: None, + name: Some("Default".to_string()), + }), + }); + let cached = AuthTokenScope::Project(Project { + id: Some("proj-123".to_string()), + name: Some("myproj".to_string()), + domain: Some(Domain { + id: Some("d1".to_string()), + name: Some("Default".to_string()), + }), + }); + assert!(req.matches(&cached)); + } +} diff --git a/sdk/auth-core/src/lib.rs b/sdk/auth-core/src/lib.rs index f9a0ecb10..e7975a656 100644 --- a/sdk/auth-core/src/lib.rs +++ b/sdk/auth-core/src/lib.rs @@ -12,6 +12,401 @@ // // SPDX-License-Identifier: Apache-2.0 //! # Core trait for implementing OpenStack authentication plugins to [`openstack_sdk`] +//! +//! This crate provides the foundational types and traits required to authenticate +//! against the OpenStack Identity service (Keystone). It defines: +//! +//! - [`OpenStackAuthType`] — the primary trait that all authentication plugins must +//! implement to enable login flows (password, token, JWT, WebSSO, etc.). +//! - [`OpenStackMultifactorAuthMethod`] — the trait for multifactor-capable methods +//! that can be composed into multipass/multifactor authentication requests. +//! - [`AuthToken`] — the structure that represents a successful Keystone authentication +//! result, including the bearer token and parsed server response. +//! - [`AuthTokenScope`] — represents the authorization scope attached to a token +//! (project, domain, system, or unscoped). +//! - [`AuthError`] — the unified error type covering all authentication-related +//! failures, including receipts for multipass flows. +//! - [`Auth`] — the enum that wraps the current authentication state (token or none). +//! - [`AuthState`] — describes token validity (valid, expired, about-to-expire, or unset). +//! +//! ## Plugin Registration +//! +//! Authentication plugins are registered at compile time using Rust's `inventory` crate. +//! Implement [`OpenStackAuthType`] and submit an [`AuthPluginRegistration`] via +//! `inventory::submit!{}`. For multipass support, additionally implement +//! [`OpenStackMultifactorAuthMethod`] and submit an [`AuthMethodPluginRegistration`]. +//! +//! The [`execute_auth_request`] function provides a common, instrumented pathway +//! for sending authentication HTTP requests with timing and request-id logging. +//! +//! ## Examples +//! +//! ### Basic Authentication +//! +//! Authenticate by creating an `AuthToken` from an existing token string +//! (e.g., obtained from an environment variable or a previous login): +//! +//! ```no_run +//! use secrecy::{ExposeSecret, SecretString}; +//! use openstack_sdk_auth_core::{AuthToken, AuthTokenScope}; +//! +//! # async fn example() -> Result<(), Box> { +//! let token = AuthToken::new(SecretString::from("my-token"), None); +//! let scope = AuthTokenScope::Unscoped; +//! +//! println!("Scope: {:?}", token.get_scope()); +//! println!("State: {:?}", token.get_state(None)); +//! # Ok(()) +//! # } +//! ``` +//! +//! ### Checking Token State +//! +//! Tokens can be checked for validity with an optional expiration threshold: +//! +//! ```no_run +//! use chrono::TimeDelta; +//! use openstack_sdk_auth_core::AuthState; +//! # async fn example() -> Result<(), Box> { +//! # let auth = openstack_sdk_auth_core::AuthToken::default(); +//! let offset = TimeDelta::minutes(5); +//! match auth.get_state(Some(offset)) { +//! AuthState::Valid => println!("Token is valid"), +//! AuthState::AboutToExpire => println!("Token will expire within 5 minutes"), +//! AuthState::Expired => println!("Token has expired"), +//! AuthState::Unset => println!("No token data available"), +//! } +//! # Ok(()) +//! # } +//! ``` +//! +//! ### Setting Request Headers +//! +//! The `Auth` type can inject the `X-Auth-Token` header into outgoing HTTP requests: +//! +//! ```no_run +//! use http::HeaderMap; +//! # async fn example() -> Result<(), Box> { +//! # let auth = openstack_sdk_auth_core::Auth::None; +//! let mut headers = HeaderMap::new(); +//! auth.set_header(&mut headers)?; +//! // headers now contains the `X-Auth-Token` header if the auth type is a token. +//! # Ok(()) +//! # } +//! ``` +//! +//! ### Implementing a Custom Authenticator +//! +//! Creating a new authentication method requires implementing [`OpenStackAuthType`]: +//! +//! ```no_run +//! use std::collections::HashMap; +//! use async_trait::async_trait; +//! use secrecy::SecretString; +//! use serde_json::{Value, json}; +//! use openstack_sdk_auth_core::{Auth, AuthError, AuthToken, AuthTokenScope, OpenStackAuthType}; +//! +//! pub struct MyAuthenticator; +//! +//! static PLUGIN: MyAuthenticator = MyAuthenticator; +//! inventory::submit! { +//! openstack_sdk_auth_core::AuthPluginRegistration { method: &PLUGIN } +//! } +//! +//! #[async_trait] +//! impl OpenStackAuthType for MyAuthenticator { +//! fn get_supported_auth_methods(&self) -> Vec<&'static str> { +//! vec!["v3myauth", "myauth"] +//! } +//! +//! fn requirements(&self, _hints: Option<&Value>) -> Result { +//! Ok(json!({ +//! "type": "object", +//! "required": ["token_id"], +//! "properties": { +//! "token_id": { +//! "type": "string", +//! "description": "The token identifier" +//! } +//! } +//! })) +//! } +//! +//! fn api_version(&self) -> (u8, u8) { +//! (3, 0) +//! } +//! +//! async fn auth( +//! &self, +//! _http_client: &reqwest::Client, +//! _identity_url: &url::Url, +//! _values: &HashMap, +//! _scope: Option<&AuthTokenScope>, +//! _hints: Option<&Value>, +//! ) -> Result { +//! // Perform your authentication logic here and return an AuthToken +//! let auth_token = AuthToken::new("example-token", None); +//! Ok(Auth::AuthToken(Box::new(auth_token))) +//! } +//! } +//! ``` +//! +//! ### Error Handling Patterns +//! +//! The [`AuthError`] enum provides different error variants that you can match on +//! to handle specific failure scenarios: +//! +//! ```no_run +//! use openstack_sdk_auth_core::{AuthError, Auth, AuthToken, AuthTokenScope}; +//! # async fn example() -> Result<(), Box> { +//! # fn simulate_auth_result() -> Result { Ok(Auth::None) } +//! +//! match simulate_auth_result() { +//! Ok(Auth::AuthToken(token)) => { +//! println!("Authentication successful"); +//! println!("Token expires at: {:?}", token.get_state(None)); +//! } +//! Ok(Auth::None) => { +//! println!("No authentication available"); +//! } +//! Ok(auth) => { +//! println!("Unknown auth type: {:?}", auth); +//! } +//! Err(AuthError::AuthReceipt(receipt)) => { +//! println!("Multifactor authentication required"); +//! let methods: Vec<_> = receipt.required_auth_methods.iter() +//! .flatten() +//! .cloned() +//! .collect(); +//! println!("Required methods: {:?}", methods); +//! } +//! Err(AuthError::Serde { .. }) => { +//! println!("Failed to parse response (malformed JSON)"); +//! } +//! Err(AuthError::UnknownAuth { code, message }) => { +//! println!("Unknown authentication error (code: {})", code); +//! if let Some(msg) = &message { +//! println!("Message: {}", msg); +//! } +//! } +//! Err(e) => { +//! println!("Authentication failed: {}", e); +//! } +//! } +//! # Ok(()) +//! # } +//! ``` +//! +//! ### Implementing a Multifactor-Authenticator +//! +//! To support multifactor authentication (e.g., TOTP or password + token), +//! implement both [`OpenStackAuthType`] and [`OpenStackMultifactorAuthMethod`]: +//! +//! ```no_run +//! use std::collections::HashMap; +//! use async_trait::async_trait; +//! use secrecy::{SecretString, ExposeSecret}; +//! use serde_json::{Value, json}; +//! use openstack_sdk_auth_core::{ +//! Auth, AuthError, AuthToken, AuthTokenScope, +//! OpenStackAuthType, OpenStackMultifactorAuthMethod, +//! }; +//! +//! pub struct MyMultifactorAuthenticator; +//! +//! static PLUGIN: MyMultifactorAuthenticator = MyMultifactorAuthenticator; +//! inventory::submit! { +//! openstack_sdk_auth_core::AuthPluginRegistration { method: &PLUGIN } +//! } +//! inventory::submit! { +//! openstack_sdk_auth_core::AuthMethodPluginRegistration { method: &PLUGIN } +//! } +//! +//! #[async_trait] +//! impl OpenStackMultifactorAuthMethod for MyMultifactorAuthenticator { +//! fn get_supported_auth_methods(&self) -> Vec<&'static str> { +//! vec!["v3myauth", "myauth"] +//! } +//! +//! fn requirements(&self, _hints: Option<&Value>) -> Result { +//! Ok(json!({ +//! "type": "object", +//! "required": ["auth_code"], // Additional auth method requirement +//! "properties": { +//! "auth_code": { +//! "type": "string", +//! "format": "password", +//! "description": "One-time authentication code" +//! } +//! } +//! })) +//! } +//! +//! /// Extracts authentication data from the values map. +//! fn get_auth_data( +//! &self, +//! values: &HashMap, +//! ) -> Result<(&'static str, Value), AuthError> { +//! let auth_code = values +//! .get("auth_code") +//! .ok_or_else(|| AuthError::AuthValueNotSupplied("auth_code".to_string()))?; +//! Ok(("myauth", json!({ +//! "auth_code": auth_code.expose_secret() +//! }))) +//! } +//! } +//! +//! #[async_trait] +//! impl OpenStackAuthType for MyMultifactorAuthenticator { +//! fn get_supported_auth_methods(&self) -> Vec<&'static str> { +//! vec!["v3myauth", "myauth"] +//! } +//! +//! fn requirements(&self, _hints: Option<&Value>) -> Result { +//! Ok(json!({ +//! "type": "object", +//! "required": ["auth_code"], +//! "properties": { +//! "auth_code": { +//! "type": "string", +//! "format": "password", +//! "description": "One-time authentication code" +//! } +//! } +//! })) +//! } +//! +//! fn api_version(&self) -> (u8, u8) { +//! (3, 0) +//! } +//! +//! async fn auth( +//! &self, +//! _http_client: &reqwest::Client, +//! _identity_url: &url::Url, +//! _values: &HashMap, +//! _scope: Option<&AuthTokenScope>, +//! _hints: Option<&Value>, +//! ) -> Result { +//! // Perform multifactor authentication and return an AuthToken: +//! let auth_token = AuthToken::new("example-token", None); +//! Ok(Auth::AuthToken(Box::new(auth_token))) +//! } +//! } +//! ``` +//! +//! ### Token Scoping +//! +//! Tokens can be scoped to specific projects or domains using [`AuthTokenScope`]: +//! +//! ``` +//! use openstack_sdk_auth_core::{AuthTokenScope}; +//! use openstack_sdk_auth_core::types::{Project, Domain}; +//! +//! // Project-scope by ID +//! let project_scope = AuthTokenScope::Project(Project { +//! id: Some("project-id-123".to_string()), +//! name: None, +//! domain: None, +//! }); +//! +//! // Domain-scope with domain name +//! let domain_scope = AuthTokenScope::Domain(Domain { +//! id: None, +//! name: Some("Default".to_string()), +//! }); +//! +//! // Unscoped (default) +//! let unscoped: AuthTokenScope = AuthTokenScope::default(); +//! ``` +//! +//! ### Handling Authentication Receipts +//! +//! When multifactor authentication is enabled, Keystone returns an authentication receipt +//! instead of a new token. The receipt contains information about what additional +//! authentication methods are required: +//! +//! ```no_run +//! use openstack_sdk_auth_core::{AuthTokenScope, AuthError, Auth, AuthReceiptResponse, AuthReceipt}; +//! use chrono::Local; +//! +//! # async fn example() -> Result<(), Box> { +//! # fn simulate_auth_result() -> Result { +//! # Ok(Auth::None) +//! # } +//! +//! let scope = Some(&AuthTokenScope::Unscoped); +//! match simulate_auth_result() { +//! Ok(auth) => { +//! match auth { +//! Auth::AuthToken(token) => { +//! println!("Authenticated with token"); +//! println!("Token: {:?}", token.get_scope()); +//! } +//! Auth::None => { +//! println!("Not authenticated"); +//! } +//! _ => { +//! println!("Unknown auth type"); +//! } +//! } +//! } +//! Err(AuthError::AuthReceipt(receipt)) => { +//! println!("Additional authentication methods required"); +//! println!("Required methods: {:?}", receipt.required_auth_methods); +//! let methods: Vec<_> = receipt.receipt.methods.iter().cloned().collect(); +//! println!("Already completed methods: {:?}", methods); +//! +//! // Use the receipt token for subsequent authentication requests +//! if let Some(receipt_token) = &receipt.token { +//! println!("Receipt token: {}...", &receipt_token.chars().take(8).collect::()); +//! } +//! } +//! Err(e) => { +//! println!("Authentication error: {}", e); +//! } +//! } +//! # Ok(()) +//! # } +//! ``` +//! ### Token State Management +//! +//! The [`AuthToken::get_state`] method can be used to determine if a token is valid and +//! needs to be refreshed. Using an expiration offset allows you to proactively refresh +//! tokens before they expire: +//! +//! ``` +//! use chrono::TimeDelta; +//! use openstack_sdk_auth_core::{AuthToken, AuthState, types::{AuthResponse, TokenInfo, User, Project}}; +//! use secrecy::SecretString; +//! +//! // Create a token with expiration info +//! let auth_info = AuthResponse { +//! token: TokenInfo { +//! expires_at: chrono::Utc::now() + TimeDelta::hours(24), +//! user: User::default(), +//! ..Default::default() +//! }, +//! }; +//! let token = AuthToken::new("my-token", Some(auth_info)); +//! +//! // Check with 5-minute buffer for proactive refresh +//! match token.get_state(Some(TimeDelta::minutes(5))) { +//! AuthState::Valid => { +//! println!("Token is valid"); +//! } +//! AuthState::AboutToExpire => { +//! println!("Token will expire soon, refreshing..."); +//! // Refresh the token before it expires +//! } +//! AuthState::Expired => { +//! println!("Token has expired, re-authenticating..."); +//! } +//! AuthState::Unset => { +//! println!("No token data available"); +//! } +//! } +//! ``` use std::collections::HashMap; use std::fmt::{self, Debug}; @@ -22,7 +417,7 @@ use http::{HeaderMap, HeaderValue}; use reqwest::{Client, Request, Response}; use secrecy::SecretString; use thiserror::Error; -use tracing::{Level, event, info, instrument}; +use tracing::{Level, event, instrument}; pub mod authtoken; pub mod authtoken_scope; @@ -38,7 +433,7 @@ pub use types::*; pub enum AuthError { /// Authentication rejected with a receipt. #[error("authentication rejected")] - AuthReceipt(AuthReceiptResponse), + AuthReceipt(Box), /// openstack-auth-receipt cannot be converted to string. #[error("authentication receipt cannot be converted to string")] @@ -52,14 +447,6 @@ pub enum AuthError { source: AuthTokenError, }, - /// Token is missing in the authentication response. - #[error("token missing in the response")] - AuthTokenNotInResponse, - - /// X-Subject-Token cannot be converted to string. - #[error("token cannot be converted to string")] - AuthTokenNotString, - /// Necessary data was not supplied to the auth method. #[error("value necessary for the chosen auth method was not supplied to the auth method")] AuthValueNotSupplied(String), @@ -154,14 +541,15 @@ pub trait OpenStackAuthType: Send + Sync { &self, http_client: &reqwest::Client, identity_url: &url::Url, - values: HashMap, + values: &HashMap, scope: Option<&AuthTokenScope>, hints: Option<&serde_json::Value>, ) -> Result; } -// This struct "wraps" the trait object so inventory can track it +/// Registry entry for authentication plugins. pub struct AuthPluginRegistration { + /// The authentication method implementation. pub method: &'static dyn OpenStackAuthType, } @@ -186,8 +574,9 @@ pub trait OpenStackMultifactorAuthMethod: Send + Sync { ) -> Result<(&'static str, serde_json::Value), AuthError>; } -// This struct "wraps" the trait object so inventory can track it +/// Registry entry for multifactor-capable authentication methods. pub struct AuthMethodPluginRegistration { + /// The authentication method implementation. pub method: &'static dyn OpenStackMultifactorAuthMethod, } inventory::collect!(AuthMethodPluginRegistration); @@ -197,7 +586,6 @@ pub async fn execute_auth_request( client: &Client, request: Request, ) -> Result { - info!("Sending request {:?}", request); let url = request.url().clone(); let method = request.method().clone(); let start = SystemTime::now(); @@ -231,15 +619,12 @@ impl Auth { /// Adds X-Auth-Token header to a request headers. /// /// Returns an error if the token string cannot be parsed as a header value. - pub fn set_header<'a>( - &self, - headers: &'a mut HeaderMap, - ) -> Result<&'a mut HeaderMap, AuthError> { + pub fn set_header(&self, headers: &mut HeaderMap) -> Result<(), AuthError> { if let Auth::AuthToken(token) = self { - let _ = token.set_header(headers); + token.set_header(headers)?; } - Ok(headers) + Ok(()) } } @@ -306,7 +691,8 @@ impl From for BuilderError { #[cfg(test)] mod tests { use super::*; - use crate::types::{AuthResponse, AuthToken}; + use crate::types::{AuthResponse, TokenInfo}; + use std::hash::{Hash, Hasher}; #[test] fn test_auth_validity_unset() { @@ -319,7 +705,7 @@ mod tests { let auth = super::AuthToken::new( String::new(), Some(AuthResponse { - token: AuthToken { + token: TokenInfo { expires_at: chrono::Utc::now() - chrono::TimeDelta::days(1), ..Default::default() }, @@ -333,7 +719,7 @@ mod tests { let auth = super::AuthToken::new( String::new(), Some(AuthResponse { - token: AuthToken { + token: TokenInfo { expires_at: chrono::Utc::now() + chrono::TimeDelta::minutes(10), ..Default::default() }, @@ -350,7 +736,7 @@ mod tests { let auth = super::AuthToken::new( String::new(), Some(AuthResponse { - token: AuthToken { + token: TokenInfo { expires_at: chrono::Utc::now() + chrono::TimeDelta::days(1), ..Default::default() }, @@ -358,4 +744,452 @@ mod tests { ); assert!(matches!(auth.get_state(None), AuthState::Valid)); } + + #[test] + fn test_auth_set_header_invalid_token() { + let auth = Auth::AuthToken(Box::new(super::AuthToken::new( + "invalid\nheader\nvalue", + None, + ))); + let mut headers = HeaderMap::new(); + let result = auth.set_header(&mut headers); + assert!(result.is_err()); + } + + #[test] + fn test_project_domain_eq_hash() { + use std::collections::HashSet; + + let p1 = Project { + id: Some("1".into()), + name: Some("proj".into()), + domain: Some(Domain { + id: Some("d1".into()), + name: Some("D".into()), + }), + }; + let p2 = Project { + id: Some("1".into()), + name: Some("proj".into()), + domain: Some(Domain { + id: Some("d1".into()), + name: Some("D".into()), + }), + }; + assert_eq!(p1, p2); + + let h1 = { + let mut h = std::collections::hash_map::DefaultHasher::new(); + p1.hash(&mut h); + h.finish() + }; + let h2 = { + let mut h = std::collections::hash_map::DefaultHasher::new(); + p2.hash(&mut h); + h.finish() + }; + assert_eq!(h1, h2, "equal projects must have equal hashes"); + + let mut pset = HashSet::new(); + pset.insert(p1.clone()); + assert!(pset.contains(&p2), "HashSet should contain equal project"); + } + + #[test] + fn test_project_both_name_and_domain_none() { + let p1 = Project { + id: Some("x".into()), + name: None, + domain: None, + }; + let p2 = Project { + id: Some("y".into()), + name: None, + domain: None, + }; + assert_ne!(p1, p2); + } + + #[test] + fn test_project_domain_none_vs_some() { + let p1 = Project { + id: None, + name: Some("p".into()), + domain: None, + }; + let p2 = Project { + id: None, + name: Some("p".into()), + domain: Some(Domain { + id: None, + name: None, + }), + }; + assert_ne!(p1, p2); + } + + #[test] + fn test_domain_both_names_none() { + let d1 = Domain { + id: Some("x".into()), + name: None, + }; + let d2 = Domain { + id: Some("y".into()), + name: None, + }; + assert_ne!(d1, d2); + } + + #[test] + fn test_domain_none_vs_some_name() { + let d1 = Domain { + id: None, + name: None, + }; + let d2 = Domain { + id: None, + name: Some("D".into()), + }; + assert_ne!(d1, d2); + } + + #[test] + fn test_project_hashset_dedup() { + use std::collections::HashSet; + let d = Domain { + id: Some("d".into()), + name: Some("D".into()), + }; + let p1 = Project { + id: Some("1".into()), + name: Some("p".into()), + domain: Some(d.clone()), + }; + let p2 = Project { + id: Some("1".into()), + name: Some("p".into()), + domain: Some(d.clone()), + }; + let p3 = Project { + id: Some("1".into()), + name: Some("p".into()), + domain: Some(d), + }; + + let mut set = HashSet::new(); + set.insert(p1); + set.insert(p2); + set.insert(p3); + assert_eq!( + set.len(), + 1, + "HashSet should deduplicate identical projects" + ); + } + + #[test] + fn test_auth_scope_from_response_project() { + let response = AuthResponse { + token: TokenInfo { + project: Some(Project { + id: Some("1".into()), + name: Some("p".into()), + domain: None, + }), + ..Default::default() + }, + }; + let scope: AuthTokenScope = (&response).into(); + assert!(matches!(scope, AuthTokenScope::Project(_))); + } + + #[test] + fn test_auth_scope_from_response_domain() { + let response = AuthResponse { + token: TokenInfo { + project: None, + domain: Some(Domain { + id: Some("1".into()), + name: Some("D".into()), + }), + ..Default::default() + }, + }; + let scope: AuthTokenScope = (&response).into(); + assert!(matches!(scope, AuthTokenScope::Domain(_))); + } + + #[test] + fn test_auth_scope_from_response_system() { + let response = AuthResponse { + token: TokenInfo { + project: None, + domain: None, + system: Some(System { all: Some(true) }), + ..Default::default() + }, + }; + let scope: AuthTokenScope = (&response).into(); + assert!(matches!(scope, AuthTokenScope::System(_))); + } + + #[test] + fn test_auth_scope_from_response_unscoped() { + let response = AuthResponse { + token: TokenInfo { + project: None, + domain: None, + system: None, + ..Default::default() + }, + }; + let scope: AuthTokenScope = (&response).into(); + assert!(matches!(scope, AuthTokenScope::Unscoped)); + } + + #[test] + fn test_auth_scope_priority_project_over_domain() { + let response = AuthResponse { + token: TokenInfo { + project: Some(Project { + id: Some("1".into()), + name: Some("p".into()), + domain: None, + }), + domain: Some(Domain { + id: Some("1".into()), + name: Some("D".into()), + }), + ..Default::default() + }, + }; + let scope: AuthTokenScope = (&response).into(); + assert!(matches!(scope, AuthTokenScope::Project(_))); + } + + #[test] + fn test_auth_scope_priority_domain_over_system() { + let response = AuthResponse { + token: TokenInfo { + project: None, + domain: Some(Domain { + id: Some("1".into()), + name: Some("D".into()), + }), + system: Some(System { all: Some(true) }), + ..Default::default() + }, + }; + let scope: AuthTokenScope = (&response).into(); + assert!(matches!(scope, AuthTokenScope::Domain(_))); + } + + #[test] + fn test_auth_none_set_header_noop() { + let auth = Auth::None; + let mut headers = HeaderMap::new(); + let result = auth.set_header(&mut headers); + assert!(result.is_ok()); + assert!(headers.is_empty()); + } + + #[test] + fn test_auth_token_set_header() { + let auth = Auth::AuthToken(Box::new(super::AuthToken::new("my-token", None))); + let mut headers = HeaderMap::new(); + auth.set_header(&mut headers).unwrap(); + assert!(headers.contains_key("X-Auth-Token")); + } + + #[test] + fn test_auth_debug_token() { + let auth = Auth::AuthToken(Box::new(super::AuthToken::new("tok", None))); + let debug = format!("{:?}", auth); + assert!(debug.contains("Token")); + } + + #[test] + fn test_auth_debug_none() { + let auth = Auth::None; + let debug = format!("{:?}", auth); + assert!(debug.contains("unauthed")); + } + + #[test] + fn test_try_from_http_response_for_auth() { + let auth = AuthResponse::default(); + let json = serde_json::to_string(&auth).unwrap(); + let body: bytes::Bytes = json.into_bytes().into(); + + let http_response = http::Response::builder() + .header("x-subject-token", "tok") + .body(body) + .unwrap(); + + let result = Auth::try_from(http_response); + assert!(result.is_ok()); + assert!(matches!(result.unwrap(), Auth::AuthToken(_))); + } + + #[test] + fn test_auth_error_plugin_constructor() { + #[derive(Debug)] + struct MyErr(&'static str); + impl std::fmt::Display for MyErr { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } + } + impl std::error::Error for MyErr {} + + let err = AuthError::plugin(MyErr("plug")); + assert!(matches!(err, AuthError::Plugin { .. })); + assert!(format!("{}", err).contains("plug")); + } + + #[test] + fn test_auth_error_display_unknown_auth_with_message() { + let err = AuthError::UnknownAuth { + code: 500, + message: Some("upstream timeout".into()), + }; + assert_eq!(format!("{}", err), "identity service error"); + } + + #[test] + fn test_auth_error_display_unknown_auth_without_message() { + let err = AuthError::UnknownAuth { + code: 502, + message: None, + }; + assert_eq!(format!("{}", err), "identity service error"); + } + + #[test] + fn test_auth_error_display_auth_value_not_supplied() { + let err = AuthError::AuthValueNotSupplied("user".to_string()); + assert!(format!("{}", err).contains("value necessary")); + } + + #[test] + fn test_builder_error_from_string() { + let err: BuilderError = "custom validation".to_string().into(); + assert!(matches!(err, BuilderError::Validation(_))); + assert!(format!("{}", err).contains("custom validation")); + } + + #[test] + fn test_auth_token_scope_default() { + let scope = AuthTokenScope::default(); + assert!(matches!(scope, AuthTokenScope::Unscoped)); + } + + #[test] + fn test_auth_receipt_not_string_display() { + let err = AuthError::AuthReceiptNotString; + assert!(format!("{}", err).contains("receipt")); + } + + #[test] + fn test_auth_error_identity_display() { + let err = AuthError::Identity(IdentityError { + code: 401, + message: "unauthorized".into(), + }); + assert!(format!("{}", err).contains("authentication method error")); + assert!(format!("{}", err).contains("unauthorized")); + } + + #[test] + fn test_name_or_id_serialize() { + let id = NameOrId::Id("abc".into()); + let json = serde_json::to_string(&id).unwrap(); + assert!(json.contains("id")); + assert!(json.contains("abc")); + + let name = NameOrId::Name("myproj".into()); + let json = serde_json::to_string(&name).unwrap(); + assert!(json.contains("name")); + assert!(json.contains("myproj")); + } + + #[test] + fn test_name_or_id_deserialize() { + let json = r#"{"id": "abc"}"#; + let parsed: NameOrId = serde_json::from_str(json).unwrap(); + assert!(matches!(parsed, NameOrId::Id(ref s) if s == "abc")); + + let json = r#"{"name": "myproj"}"#; + let parsed: NameOrId = serde_json::from_str(json).unwrap(); + assert!(matches!(parsed, NameOrId::Name(ref s) if s == "myproj")); + } + + #[test] + fn test_project_serialize_human_readable_skips_none() { + let p = Project { + id: Some("1".into()), + name: None, + domain: None, + }; + let json = serde_json::to_string(&p).unwrap(); + assert!(json.contains("id")); + } + + #[test] + fn test_domain_serialize_human_readable_skips_none() { + let d = Domain { + id: Some("1".into()), + name: None, + }; + let json = serde_json::to_string(&d).unwrap(); + assert!(json.contains("id")); + } + + #[test] + fn test_project_serialize_all_fields() { + let d = Domain { + id: Some("d".into()), + name: Some("D".into()), + }; + let p = Project { + id: Some("1".into()), + name: Some("p".into()), + domain: Some(d), + }; + let json = serde_json::to_string(&p).unwrap(); + assert!(json.contains("id")); + assert!(json.contains("name")); + assert!(json.contains("domain")); + } + + #[test] + fn test_auth_state_debug() { + assert_eq!(format!("{:?}", AuthState::Valid), "Valid"); + assert_eq!(format!("{:?}", AuthState::Expired), "Expired"); + assert_eq!(format!("{:?}", AuthState::AboutToExpire), "AboutToExpire"); + assert_eq!(format!("{:?}", AuthState::Unset), "Unset"); + } + + #[test] + fn test_auth_receipt_response_serialization() { + let receipt = AuthReceiptResponse { + receipt: AuthReceipt { + methods: vec!["password".to_string()], + user: User { + id: "u".to_string(), + name: "user".to_string(), + ..Default::default() + }, + expires_at: chrono::Local::now(), + ..Default::default() + }, + required_auth_methods: vec![vec!["totp".to_string()]], + token: Some("tok".to_string()), + }; + let json = serde_json::to_string(&receipt).unwrap(); + assert!(json.contains("receipt")); + assert!(json.contains("password")); + assert!(json.contains("totp")); + } } diff --git a/sdk/auth-core/src/types.rs b/sdk/auth-core/src/types.rs index c61a6aada..ef9ec2747 100644 --- a/sdk/auth-core/src/types.rs +++ b/sdk/auth-core/src/types.rs @@ -13,25 +13,23 @@ // SPDX-License-Identifier: Apache-2.0 //! Types of the SDK authentication methods -use std::hash::{Hash, Hasher}; - use chrono::prelude::*; use derive_builder::Builder; use serde::{Deserialize, Serialize}; use crate::BuilderError; -/// A reference to a resource by its Name and ID. +/// A reference to a resource by both its name and ID. #[derive(Deserialize, Debug, Clone, Eq, PartialEq, Serialize)] pub struct IdAndName { - /// The name of the entity. - pub name: String, /// The UID for the entity. pub id: String, + /// The name of the entity. + pub name: String, } /// A reference to a resource by either its Name or ID. -#[derive(Clone, Debug, Hash, PartialEq, Serialize)] +#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] pub enum NameOrId { /// Resource ID. #[serde(rename = "id")] @@ -41,54 +39,77 @@ pub enum NameOrId { Name(String), } -/// AuthResponse structure returned by token authentication calls +/// Authentication response structure returned by token calls. #[derive(Clone, Debug, Default, Deserialize, Eq, PartialEq, Serialize)] pub struct AuthResponse { - pub token: AuthToken, + /// Token information. + pub token: TokenInfo, } -/// AuthToken response information +/// AuthToken response information. #[derive(Clone, Debug, Default, Deserialize, Eq, PartialEq, Serialize)] -pub struct AuthToken { - /// Application credential information +pub struct TokenInfo { + /// Application credential information. pub application_credential: Option, + /// Catalog of available services. pub catalog: Option>, - pub roles: Option>, - pub user: User, - pub project: Option, + /// Domain in which the token was issued. pub domain: Option, - pub system: Option, - pub issued_at: Option>, + /// Token expiration time. pub expires_at: DateTime, + /// Token issue time. + pub issued_at: Option>, + /// Project in which the token was issued. + pub project: Option, + /// Roles assigned to the token. + pub roles: Option>, + /// System scope of the token. + pub system: Option, + /// User who obtained the token. + pub user: User, } +/// Service endpoint catalog entries. #[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)] pub struct ServiceEndpoints { + /// List of available endpoints for this service. pub endpoints: Vec, + /// Human-readable service name. + pub name: String, #[serde(rename = "type")] + /// Service type identifier (e.g., "compute", "network"). pub service_type: String, - pub name: String, } +/// Service catalog endpoint information. #[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)] pub struct CatalogEndpoint { + /// Endpoint unique identifier. pub id: String, + /// Interface type (public, internal, admin). pub interface: String, + /// Region identifier. pub region: Option, + /// Endpoint URL. pub url: String, } +/// User identity information. #[derive(Clone, Debug, Default, Deserialize, Eq, PartialEq, Serialize)] pub struct User { + /// Domain the user belongs to. pub domain: Option, - pub name: String, + /// User unique identifier. pub id: String, + /// User name. + pub name: String, // Note(gtema): some clouds return empty string instead of null when // password does not expire. It is technically possible to use // deserialize_with to capture errors, but that leads bincode to fail // when deserializing. For now just leave it as optional string instead // of DateTime // #[serde(deserialize_with = "deser_ok_or_default")] + /// Optional password expiration date. pub password_expires_at: Option, } @@ -96,20 +117,22 @@ pub struct User { /// /// While in the response `id` and `name` and mandatorily set this type is /// also reused to manage authentications where at least one of them must be -/// present -#[derive(Builder, Clone, Debug, Default, Deserialize, Eq)] +/// present. +#[derive(Builder, Clone, Debug, Default, Deserialize, Eq, Hash, PartialEq)] #[builder(build_fn(error = "BuilderError"))] #[builder(setter(strip_option))] -#[serde(default)] pub struct Project { + /// Associated domain for the project. #[builder(default)] - pub id: Option, + pub domain: Option, + /// Project unique identifier. #[builder(default)] - pub name: Option, + pub id: Option, + /// Project name. #[builder(default)] - pub domain: Option, + pub name: Option, } impl Serialize for Project { @@ -120,60 +143,47 @@ impl Serialize for Project { if serializer.is_human_readable() { #[derive(Serialize)] struct ProjectJson<'a> { + #[serde(skip_serializing_if = "Option::is_none")] + domain: Option<&'a Domain>, #[serde(skip_serializing_if = "Option::is_none")] id: Option<&'a str>, #[serde(skip_serializing_if = "Option::is_none")] name: Option<&'a str>, - #[serde(skip_serializing_if = "Option::is_none")] - domain: Option<&'a Domain>, } let helper = ProjectJson { + domain: self.domain.as_ref(), id: self.id.as_deref(), name: self.name.as_deref(), - domain: self.domain.as_ref(), }; helper.serialize(serializer) } else { #[derive(Serialize)] struct ProjectRaw<'a> { + domain: &'a Option, id: &'a Option, name: &'a Option, - domain: &'a Option, } let helper = ProjectRaw { + domain: &self.domain, id: &self.id, name: &self.name, - domain: &self.domain, }; helper.serialize(serializer) } } } -impl PartialEq for Project { - fn eq(&self, other: &Self) -> bool { - self.id == other.id - || (self.name.is_some() - && other.name.is_some() - && self.name == other.name - && self.domain == other.domain) - } -} - -impl Hash for Project { - fn hash(&self, state: &mut H) { - self.id.hash(state) - } -} - -#[derive(Builder, Clone, Debug, Default, Deserialize, Eq)] +/// Domain identity information. +#[derive(Builder, Clone, Debug, Default, Deserialize, Eq, Hash, PartialEq)] #[builder(build_fn(error = "BuilderError"))] #[builder(setter(strip_option))] #[serde(default)] pub struct Domain { + /// Domain unique identifier. #[builder(default)] pub id: Option, + /// Domain name. #[builder(default)] pub name: Option, } @@ -211,53 +221,53 @@ impl Serialize for Domain { } } -impl PartialEq for Domain { - fn eq(&self, other: &Self) -> bool { - self.id == other.id - || (self.name.is_some() && other.name.is_some() && self.name == other.name) - } -} - -impl Hash for Domain { - fn hash(&self, state: &mut H) { - self.id.hash(state) - } -} - /// System Scope. #[derive(Builder, Clone, Debug, Default, Deserialize, Eq, Hash, PartialEq, Serialize)] #[builder(build_fn(error = "BuilderError"))] #[builder(setter(strip_option))] pub struct System { + /// Flag indicating if the system scope is all. #[builder(default)] pub all: Option, } -// Trust scope. +/// Trust scope information. #[derive(Builder, Clone, Debug, Default, Deserialize, Eq, Hash, PartialEq, Serialize)] #[builder(build_fn(error = "BuilderError"))] #[builder(setter(strip_option))] pub struct OsTrustTrust { + /// Trust unique identifier. #[serde(skip_serializing_if = "Option::is_none")] #[builder(default, setter(into))] pub id: Option, } +/// Multimodal authentication receipt response from the identity service. #[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize)] pub struct AuthReceiptResponse { + /// The actual auth receipt data. pub receipt: AuthReceipt, + /// Required authentication methods for the receipt. pub required_auth_methods: Vec>, + /// Token associated with this receipt. pub token: Option, } +/// Authentication receipt data returned when additional authentication methods are required. #[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize)] pub struct AuthReceipt { + /// Catalog of available services. pub catalog: Option>, - pub roles: Option>, + /// Receipt expiration time. + pub expires_at: DateTime, + /// Receipt issue time. + pub issued_at: Option>, + /// Authentication methods already completed. pub methods: Vec, + /// Roles assigned to the receipt. + pub roles: Option>, + /// User information. pub user: User, - pub issued_at: Option>, - pub expires_at: DateTime, } /// Application Credential information from the token diff --git a/sdk/auth-federation/src/lib.rs b/sdk/auth-federation/src/lib.rs index 4c50b607a..da14d7a29 100644 --- a/sdk/auth-federation/src/lib.rs +++ b/sdk/auth-federation/src/lib.rs @@ -85,7 +85,7 @@ impl OpenStackAuthType for OidcAuthenticator { &self, http_client: &reqwest::Client, identity_url: &url::Url, - values: std::collections::HashMap, + values: &std::collections::HashMap, _scope: Option<&AuthTokenScope>, _hints: Option<&Value>, ) -> Result { @@ -140,7 +140,7 @@ impl OpenStackAuthType for OidcAuthenticator { Ok(Auth::AuthToken(Box::new(auth_token))) } else { - return Err(AuthError::AuthTokenNotInResponse); + return Err(FederationError::CallbackNoToken.into()); } } } @@ -279,7 +279,7 @@ async fn get_auth_code( let handle = tokio::spawn({ let cancel_token = cancel_token.clone(); let state = state.clone(); - async move { auth_callback_server(socket_addr, state, cancel_token).await } + async move { auth_callback_server(socket_addr, state, cancel_token, None).await } }); open::that(url.as_str())?; @@ -299,9 +299,13 @@ async fn auth_callback_server( addr: SocketAddr, state: Arc>>, cancel_token: CancellationToken, + start_tx: Option>, ) -> Result<(), FederationError> { let listener = TcpListener::bind(addr).await?; info!("Starting webserver to receive OAUTH2 authorization callback"); + if let Some(tx) = start_tx { + let _ = tx.send(()); + } // Wait maximum 2 minute for auth processing let webserver_timeout = Duration::from_secs(120); loop { @@ -470,12 +474,16 @@ mod tests { }); let state = Arc::new(Mutex::new(None)); + let (start_tx, start_rx) = tokio::sync::oneshot::channel(); let handle = tokio::spawn({ let cancel_token = cancel_token.clone(); let state = state.clone(); - async move { auth_callback_server(addr, state, cancel_token).await } + async move { auth_callback_server(addr, state, cancel_token, Some(start_tx)).await } }); + // Wait for the server to start listening + start_rx.await.unwrap(); + let client = reqwest::Client::new(); client .get(format!( @@ -516,12 +524,16 @@ mod tests { }); let state = Arc::new(Mutex::new(None)); + let (start_tx, start_rx) = tokio::sync::oneshot::channel(); let handle = tokio::spawn({ let cancel_token = cancel_token.clone(); let state = state.clone(); - async move { auth_callback_server(addr, state, cancel_token).await } + async move { auth_callback_server(addr, state, cancel_token, Some(start_tx)).await } }); + // Wait for the server to start listening + start_rx.await.unwrap(); + let params = [("code", "foo"), ("state", "bar")]; let client = reqwest::Client::new(); client @@ -561,12 +573,16 @@ mod tests { }); let state = Arc::new(Mutex::new(None)); + let (start_tx, start_rx) = tokio::sync::oneshot::channel(); let handle = tokio::spawn({ let cancel_token = cancel_token.clone(); let state = state.clone(); - async move { auth_callback_server(addr, state, cancel_token).await } + async move { auth_callback_server(addr, state, cancel_token, Some(start_tx)).await } }); + // Wait for the server to start listening + start_rx.await.unwrap(); + let client = reqwest::Client::new(); client .post(format!("http://localhost:{}/oidc/callback", addr.port())) diff --git a/sdk/auth-jwt/src/lib.rs b/sdk/auth-jwt/src/lib.rs index 4a3631c67..2c7a23993 100644 --- a/sdk/auth-jwt/src/lib.rs +++ b/sdk/auth-jwt/src/lib.rs @@ -97,7 +97,7 @@ impl OpenStackAuthType for JwtAuthenticator { &self, http_client: &reqwest::Client, identity_url: &url::Url, - values: std::collections::HashMap, + values: &std::collections::HashMap, _scope: Option<&AuthTokenScope>, _hints: Option<&serde_json::Value>, ) -> Result { diff --git a/sdk/auth-multifactor/src/lib.rs b/sdk/auth-multifactor/src/lib.rs index 7575b4a78..db53e37fa 100644 --- a/sdk/auth-multifactor/src/lib.rs +++ b/sdk/auth-multifactor/src/lib.rs @@ -105,7 +105,7 @@ impl OpenStackAuthType for MultifactorAuthenticator { &self, http_client: &reqwest::Client, identity_url: &url::Url, - values: std::collections::HashMap, + values: &std::collections::HashMap, scope: Option<&AuthTokenScope>, hints: Option<&serde_json::Value>, ) -> Result { @@ -129,7 +129,7 @@ impl OpenStackAuthType for MultifactorAuthenticator { }) .map(|x| x.method) { - let (method, method_identity) = authenticator.get_auth_data(&values)?; + let (method, method_identity) = authenticator.get_auth_data(values)?; methods.insert(method.into()); deep_merge_value(&mut identity, method_identity); } @@ -371,7 +371,7 @@ mod tests { .auth( &http_client, &base_url, - HashMap::from([ + &HashMap::from([ ("token".into(), SecretString::from("secret")), ("passcode".into(), SecretString::from("passcode")), ("user_id".into(), SecretString::from("uid")), @@ -395,7 +395,7 @@ mod tests { .auth( &http_client, &base_url, - HashMap::from([ + &HashMap::from([ ("password".into(), SecretString::from("password")), ("passcode".into(), SecretString::from("passcode")), ("user_id".into(), SecretString::from("uid")), diff --git a/sdk/auth-oidcaccesstoken/src/lib.rs b/sdk/auth-oidcaccesstoken/src/lib.rs index 242015725..6d0e9558f 100644 --- a/sdk/auth-oidcaccesstoken/src/lib.rs +++ b/sdk/auth-oidcaccesstoken/src/lib.rs @@ -70,7 +70,7 @@ impl OpenStackAuthType for OidcAccessTokenAuthenticator { &self, http_client: &reqwest::Client, identity_url: &url::Url, - values: std::collections::HashMap, + values: &std::collections::HashMap, _scope: Option<&AuthTokenScope>, _hints: Option<&serde_json::Value>, ) -> Result { diff --git a/sdk/auth-passkey/src/lib.rs b/sdk/auth-passkey/src/lib.rs index b6e5041cb..78a6ae103 100644 --- a/sdk/auth-passkey/src/lib.rs +++ b/sdk/auth-passkey/src/lib.rs @@ -67,7 +67,7 @@ impl OpenStackAuthType for WebAuthnAuthenticator { &self, http_client: &reqwest::Client, identity_url: &url::Url, - values: std::collections::HashMap, + values: &std::collections::HashMap, _scope: Option<&AuthTokenScope>, _hints: Option<&serde_json::Value>, ) -> Result { diff --git a/sdk/auth-password/src/lib.rs b/sdk/auth-password/src/lib.rs index 9746726a8..ac071b9c1 100644 --- a/sdk/auth-password/src/lib.rs +++ b/sdk/auth-password/src/lib.rs @@ -167,11 +167,11 @@ impl OpenStackAuthType for PasswordAuthenticator { &self, http_client: &reqwest::Client, identity_url: &url::Url, - values: HashMap, + values: &HashMap, scope: Option<&AuthTokenScope>, _hints: Option<&serde_json::Value>, ) -> Result { - let (method, data) = self._get_auth_data(&values)?; + let (method, data) = self._get_auth_data(values)?; let mut body = json!({ "auth": { "identity": data } }); body["auth"]["identity"]["methods"] = [method].into(); if let Some(scope) = scope { @@ -356,7 +356,7 @@ mod tests { .auth( &http_client, &base_url, - HashMap::from([ + &HashMap::from([ ("password".into(), SecretString::from("password")), ("user_id".into(), SecretString::from("uid")), ]), @@ -424,7 +424,7 @@ mod tests { .auth( &http_client, &base_url, - HashMap::from([ + &HashMap::from([ ("password".into(), SecretString::from("password")), ("user_id".into(), SecretString::from("uid")), ]), diff --git a/sdk/auth-receipt/src/lib.rs b/sdk/auth-receipt/src/lib.rs index 1b4b44eb2..2c542bae5 100644 --- a/sdk/auth-receipt/src/lib.rs +++ b/sdk/auth-receipt/src/lib.rs @@ -105,7 +105,7 @@ impl OpenStackAuthType for ReceiptAuthenticator { &self, http_client: &reqwest::Client, identity_url: &url::Url, - values: std::collections::HashMap, + values: &std::collections::HashMap, scope: Option<&AuthTokenScope>, hints: Option<&serde_json::Value>, ) -> Result { @@ -132,7 +132,7 @@ impl OpenStackAuthType for ReceiptAuthenticator { }) .map(|x| x.method) { - let (method, method_identity) = authenticator.get_auth_data(&values)?; + let (method, method_identity) = authenticator.get_auth_data(values)?; methods.insert(method.into()); deep_merge_value(&mut identity, method_identity); } @@ -312,7 +312,7 @@ mod tests { .auth( &http_client, &base_url, - HashMap::from([ + &HashMap::from([ ("token".into(), SecretString::from("secret")), ("passcode".into(), SecretString::from("passcode")), ("user_id".into(), SecretString::from("uid")), diff --git a/sdk/auth-token/src/lib.rs b/sdk/auth-token/src/lib.rs index 3118057ee..9a174faa3 100644 --- a/sdk/auth-token/src/lib.rs +++ b/sdk/auth-token/src/lib.rs @@ -128,7 +128,7 @@ impl OpenStackAuthType for TokenAuthenticator { &self, http_client: &reqwest::Client, identity_url: &url::Url, - values: std::collections::HashMap, + values: &std::collections::HashMap, scope: Option<&AuthTokenScope>, _hints: Option<&serde_json::Value>, ) -> Result { @@ -254,7 +254,7 @@ mod tests { .auth( &http_client, &base_url, - HashMap::from([("token".into(), SecretString::from("secret"))]), + &HashMap::from([("token".into(), SecretString::from("secret"))]), None, None, ) diff --git a/sdk/auth-totp/src/lib.rs b/sdk/auth-totp/src/lib.rs index ca8dc5050..78fcb799b 100644 --- a/sdk/auth-totp/src/lib.rs +++ b/sdk/auth-totp/src/lib.rs @@ -164,11 +164,11 @@ impl OpenStackAuthType for TotpAuthenticator { &self, http_client: &reqwest::Client, identity_url: &url::Url, - values: HashMap, + values: &HashMap, scope: Option<&AuthTokenScope>, _hints: Option<&serde_json::Value>, ) -> Result { - let (method, data) = self._get_auth_data(&values)?; + let (method, data) = self._get_auth_data(values)?; let mut body = json!({ "auth": { "identity": data } }); body["auth"]["identity"]["methods"] = [method].into(); if let Some(scope) = scope { @@ -311,7 +311,7 @@ mod tests { .auth( &http_client, &base_url, - HashMap::from([ + &HashMap::from([ ("passcode".into(), SecretString::from("passcode")), ("user_id".into(), SecretString::from("uid")), ]), @@ -377,7 +377,7 @@ mod tests { .auth( &http_client, &base_url, - HashMap::from([ + &HashMap::from([ ("passcode".into(), SecretString::from("passcode")), ("user_id".into(), SecretString::from("uid")), ]), diff --git a/sdk/auth-websso/src/lib.rs b/sdk/auth-websso/src/lib.rs index 8141f2a28..a4ab7e37d 100644 --- a/sdk/auth-websso/src/lib.rs +++ b/sdk/auth-websso/src/lib.rs @@ -88,7 +88,7 @@ impl OpenStackAuthType for WebSSOAuthenticator { &self, _http_client: &reqwest::Client, identity_url: &url::Url, - values: std::collections::HashMap, + values: &std::collections::HashMap, _scope: Option<&AuthTokenScope>, _hints: Option<&serde_json::Value>, ) -> Result { @@ -227,7 +227,7 @@ async fn get_token(url: &mut Url, socket_addr: Option) -> Result>>, cancel_token: CancellationToken, + start_tx: Option>, ) -> Result<(), WebSsoError> { let listener = TcpListener::bind(addr).await?; info!("Starting webserver to receive SSO callback"); + if let Some(tx) = start_tx { + let _ = tx.send(()); + } // Wait maximum 2 minute for auth processing let webserver_timeout = Duration::from_secs(120); loop { @@ -326,18 +330,71 @@ async fn handle_request( #[cfg(test)] mod tests { - use reserve_port::ReservedSocketAddr; use std::sync::{Arc, Mutex}; + use tokio::net::TcpListener; use tokio::signal; use tokio_util::sync::CancellationToken; - - use super::websso_callback_server; + use tracing::{info, warn}; + + use super::WebSsoError; + use super::handle_request; + + /// Test-only variant that accepts a pre-bound listener to avoid port reservation races + async fn websso_callback_server_test( + listener: TcpListener, + state: Arc>>, + cancel_token: CancellationToken, + ) -> Result<(), WebSsoError> { + use hyper::server::conn::http1; + use hyper::service::service_fn; + + use hyper_util::rt::TokioIo; + use tracing::error; + + info!("Starting webserver to receive SSO callback"); + let webserver_timeout = std::time::Duration::from_secs(120); + loop { + let state_clone = state.clone(); + + tokio::select! { + Ok((stream, _addr)) = listener.accept() => { + let io = TokioIo::new(stream); + let cancel_token_srv = cancel_token.clone(); + let cancel_token_conn = cancel_token.clone(); + + let service = service_fn(move |req| { + let state_clone = state_clone.clone(); + let cancel_token = cancel_token_srv.clone(); + handle_request(req, state_clone, cancel_token) + }); + + tokio::task::spawn(async move { + let cancel_token = cancel_token_conn.clone(); + if let Err(err) = http1::Builder::new().serve_connection(io, service).await { + error!("Failed to serve connection: {:?}", err); + cancel_token.cancel(); + } + }); + }, + _ = cancel_token.cancelled() => { + info!("Stopping webserver"); + break; + }, + _ = tokio::time::sleep(webserver_timeout) => { + warn!("Timeout of {} sec waiting for authentication expired. Shutting down", webserver_timeout.as_secs()); + cancel_token.cancel(); + } + } + } + Ok(()) + } #[tokio::test] async fn test_callback() { - let addr = ReservedSocketAddr::reserve_random_socket_addr() - .expect("port available") - .socket_addr(); + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("port available"); + let addr = listener.local_addr().expect("listener address"); let cancel_token = CancellationToken::new(); tokio::spawn({ @@ -353,7 +410,7 @@ mod tests { let websso_handle = tokio::spawn({ let cancel_token = cancel_token.clone(); let state = state.clone(); - async move { websso_callback_server(addr, state, cancel_token).await } + async move { websso_callback_server_test(listener, state, cancel_token).await } }); let params = [("token", "foo_bar_baz")]; @@ -371,9 +428,10 @@ mod tests { #[tokio::test] async fn test_callback_no_token() { - let addr = ReservedSocketAddr::reserve_random_socket_addr() - .expect("port available") - .socket_addr(); + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("port available"); + let addr = listener.local_addr().expect("listener address"); let cancel_token = CancellationToken::new(); tokio::spawn({ @@ -389,7 +447,7 @@ mod tests { let websso_handle = tokio::spawn({ let cancel_token = cancel_token.clone(); let state = state.clone(); - async move { websso_callback_server(addr, state, cancel_token).await } + async move { websso_callback_server_test(listener, state, cancel_token).await } }); let client = reqwest::Client::new(); diff --git a/sdk/core/src/state.rs b/sdk/core/src/state.rs index 6d4176ce7..4ddb37b31 100644 --- a/sdk/core/src/state.rs +++ b/sdk/core/src/state.rs @@ -162,19 +162,20 @@ impl State { pub fn get_scope_auth(&mut self, scope: &AuthTokenScope) -> Option { trace!("Get authz information for {:?}", scope); self.auth_state.filter_invalid_auths(); - match self.auth_state.0.get(scope) { - Some(authz) => Some(authz.clone()), - None => { - if let (Some(state), true) = (self.load_auth_state(None), self.auth_cache_enabled) - && let Some((scope, authz)) = self.find_scope_authz(&state, scope) - { - trace!("Found valid authz in the state file"); - self.auth_state.0.insert(scope, authz.clone()); - return Some(authz); - } - None - } + + // Find the best matching scope using wildcard matching + if let Some((_, authz)) = self.auth_state.0.iter().find(|(k, _)| scope.matches(k)) { + return Some(authz.clone()); + } + + if let (Some(state), true) = (self.load_auth_state(None), self.auth_cache_enabled) + && let Some((scope, authz)) = self.find_scope_authz(&state, scope) + { + trace!("Found valid authz in the state file"); + self.auth_state.0.insert(scope, authz.clone()); + return Some(authz); } + None } fn find_valid_auth(&self, state: &ScopeAuths) -> Option { @@ -222,50 +223,8 @@ impl State { trace!("Searching requested scope authz in state"); for (k, v) in state.0.iter() { trace!("Analyse known auth for scope {:?}", k); - match scope { - AuthTokenScope::Project(project) => { - if let AuthTokenScope::Project(cached) = k { - // Scope type matches - if project.id.is_some() && project.id == cached.id { - // Match by ID is definite - return Some((k.clone(), v.clone())); - } else if project.name == cached.name { - // Match by Name requires verifying domain match - if let (Some(requested_domain), Some(state_domain)) = - (&project.domain, &cached.domain) - && ((requested_domain.id.is_some() - && requested_domain.id == state_domain.id) - || (requested_domain.id.is_none() - && requested_domain.name == state_domain.name)) - { - return Some((k.clone(), v.clone())); - } - } - } - } - AuthTokenScope::Domain(domain) => { - if let AuthTokenScope::Domain(cached) = k { - // Scope type matches - if domain.id == cached.id - || (domain.id.is_none() && domain.name == cached.name) - { - return Some((k.clone(), v.clone())); - } - } - } - AuthTokenScope::System(system) => { - if let AuthTokenScope::System(cached) = k { - // Scope type matches - if system.all == cached.all { - return Some((k.clone(), v.clone())); - } - } - } - AuthTokenScope::Unscoped => { - if let AuthTokenScope::Unscoped = k { - return Some((k.clone(), v.clone())); - } - } + if scope.matches(k) { + return Some((k.clone(), v.clone())); } } None @@ -492,7 +451,7 @@ mod tests { use secrecy::ExposeSecret; - use openstack_sdk_auth_core::types::Project; + use openstack_sdk_auth_core::types::{Domain, Project}; use super::*; @@ -524,7 +483,7 @@ mod tests { } fn make_token(name: &str) -> AuthToken { - use openstack_sdk_auth_core::types::{AuthResponse, AuthToken as AuthTokenType}; + use openstack_sdk_auth_core::types::{AuthResponse, TokenInfo as AuthTokenType}; AuthToken::new( name, Some(AuthResponse { @@ -676,4 +635,185 @@ mod tests { assert!(loaded.is_some(), "Token for scope {:?} was lost", scope); } } + + #[test] + fn test_wildcard_lookup_by_partial_scope() { + let dir = make_state_dir(); + let mut s = new_state_in(&dir, 20); + let scope_cached = AuthTokenScope::Project(Project { + id: Some("p-wildcard".to_string()), + name: Some("Wildcard".to_string()), + domain: None, + }); + let token = make_token("tok-wildcard"); + s.set_scope_auth(&scope_cached, &token); + + // Lookup by name only (id=None) should find the cached token + let scope_req_by_name = AuthTokenScope::Project(Project { + id: None, + name: Some("Wildcard".to_string()), + domain: None, + }); + let loaded = s.get_scope_auth(&scope_req_by_name); + assert!(loaded.is_some()); + assert_eq!(loaded.unwrap().token.expose_secret(), "tok-wildcard"); + } + + #[test] + fn test_multiple_scopes_in_same_cache() { + let dir = make_state_dir(); + let mut s = new_state_in(&dir, 21); + let scope1 = make_project_scope("multi-p1"); + let scope2 = make_project_scope("multi-p2"); + let scope3 = make_project_scope("multi-p3"); + let token1 = make_token("tok-multi-1"); + let token2 = make_token("tok-multi-2"); + let token3 = make_token("tok-multi-3"); + s.set_scope_auth(&scope1, &token1); + s.set_scope_auth(&scope2, &token2); + s.set_scope_auth(&scope3, &token3); + + assert_eq!(s.auth_state.0.len(), 3); + + let t1 = s.get_scope_auth(&scope1).unwrap(); + let t2 = s.get_scope_auth(&scope2).unwrap(); + let t3 = s.get_scope_auth(&scope3).unwrap(); + assert_eq!(t1.token.expose_secret(), "tok-multi-1"); + assert_eq!(t2.token.expose_secret(), "tok-multi-2"); + assert_eq!(t3.token.expose_secret(), "tok-multi-3"); + } + + #[test] + fn test_domain_scope_cache() { + let dir = make_state_dir(); + let mut s = new_state_in(&dir, 22); + let scope_d = AuthTokenScope::Domain(Domain { + id: Some("d-22".to_string()), + name: Some("Default".to_string()), + }); + let token_d = make_token("tok-domain"); + s.set_scope_auth(&scope_d, &token_d); + + let loaded = s.get_scope_auth(&scope_d); + assert!(loaded.is_some()); + assert_eq!(loaded.unwrap().token.expose_secret(), "tok-domain"); + } + + #[test] + fn test_unscoped_token_cache() { + let dir = make_state_dir(); + let mut s = new_state_in(&dir, 23); + let scope_u = AuthTokenScope::Unscoped; + let token_u = make_token("tok-unscoped"); + s.set_scope_auth(&scope_u, &token_u); + + let loaded = s.get_scope_auth(&scope_u); + assert!(loaded.is_some()); + assert_eq!(loaded.unwrap().token.expose_secret(), "tok-unscoped"); + } + + #[test] + fn test_invalid_token_is_filtered_on_set() { + use openstack_sdk_auth_core::{ + authtoken::AuthToken as TestAuthToken, + types::{AuthResponse, TokenInfo}, + }; + let dir = make_state_dir(); + let mut s = new_state_in(&dir, 24); + let scope = make_project_scope("p-24"); + let valid_token = make_token("tok-valid"); + s.set_scope_auth(&scope, &valid_token); + assert_eq!(s.auth_state.0.len(), 1); + + // Create an expired token + let expired_token = TestAuthToken::new( + "tok-expired".to_string(), + Some(AuthResponse { + token: TokenInfo { + user: Default::default(), + catalog: Some(vec![]), + expires_at: chrono::Utc::now() - chrono::TimeDelta::hours(1), + ..Default::default() + }, + }), + ); + // Setting the expired token should filter out the valid one (since it's the only valid one and now a new expired one is set) + s.set_scope_auth(&scope, &expired_token); + // filter_invalid_auths removes all invalid tokens; the expired one should be removed too + // Actually, filter runs before insert, so it keeps the old valid token but then the new one replaces it + assert_eq!(s.auth_state.0.len(), 1); + // The expired token should have been filtered out on the next get + s.auth_state.filter_invalid_auths(); + assert_eq!(s.auth_state.0.len(), 0); + } + + #[test] + fn test_postcard_roundtrip_multiple_scopes() { + let scope1 = make_project_scope("postcard-p1"); + let scope2 = make_project_scope("postcard-p2"); + let scope3 = AuthTokenScope::Domain(Domain { + id: Some("d-postcard".to_string()), + name: Some("Default".to_string()), + }); + let token1 = make_token("tok-pc-1"); + let token2 = make_token("tok-pc-2"); + let token3 = make_token("tok-pc-3"); + let mut sa = ScopeAuths::default(); + sa.0.insert(scope1, token1); + sa.0.insert(scope2, token2); + sa.0.insert(scope3, token3); + let bytes = postcard::to_stdvec(&sa).unwrap(); + let deserialized: ScopeAuths = postcard::from_bytes(&bytes).unwrap(); + assert_eq!(deserialized.0.len(), 3); + } + + #[test] + fn test_wildcard_lookup_in_file_cache() { + let dir = make_state_dir(); + let mut s = new_state_in(&dir, 25); + let scope_cached = AuthTokenScope::Project(Project { + id: Some("file-wildcard".to_string()), + name: Some("FileLookup".to_string()), + domain: None, + }); + let token = make_token("tok-file-wildcard"); + s.set_scope_auth(&scope_cached, &token); + + // Create a new state with cleared memory to force file lookup + let mut s2 = new_state_in(&dir, 25); + s2.auth_state.0.clear(); + + // Lookup by name only (id=None) should find the cached token in the file + let scope_req_by_name = AuthTokenScope::Project(Project { + id: None, + name: Some("FileLookup".to_string()), + domain: None, + }); + let loaded = s2.get_scope_auth(&scope_req_by_name); + assert!(loaded.is_some()); + assert_eq!(loaded.unwrap().token.expose_secret(), "tok-file-wildcard"); + } + + #[test] + fn test_project_and_domain_scope_coexist() { + let dir = make_state_dir(); + let mut s = new_state_in(&dir, 26); + let scope_p = make_project_scope("coexist-p"); + let scope_d = AuthTokenScope::Domain(Domain { + id: Some("d-26".to_string()), + name: Some("Default".to_string()), + }); + let token_p = make_token("tok-p-coexist"); + let token_d = make_token("tok-d-coexist"); + s.set_scope_auth(&scope_p, &token_p); + s.set_scope_auth(&scope_d, &token_d); + assert_eq!(s.auth_state.0.len(), 2); + + let tp = s.get_scope_auth(&scope_p); + let td = s.get_scope_auth(&scope_d); + assert!(tp.is_some()); + assert!(td.is_some()); + assert_eq!(tp.unwrap().token.expose_secret(), "tok-p-coexist"); + assert_eq!(td.unwrap().token.expose_secret(), "tok-d-coexist"); + } }