diff --git a/src/tcp/tls/enrollment.rs b/src/tcp/tls/enrollment.rs index 922441a..d03fc56 100644 --- a/src/tcp/tls/enrollment.rs +++ b/src/tcp/tls/enrollment.rs @@ -7,6 +7,9 @@ use uuid::Uuid; use super::connector::connect_mtls_from_pem; use crate::tcp::transport::TransportStream; +const DEFAULT_MTLS_SERVER_PORT: &str = "8089"; +const DEFAULT_ENROLL_PATH: &str = "/Marti/api/tls/signClient/v2"; + #[derive(Deserialize)] struct EnrollmentResponse { #[serde(rename = "signedCert")] @@ -27,6 +30,23 @@ fn extract_tag_value(xml: &str, tag_name: &str) -> Option { Some(xml[start..end].trim().to_string()) } +fn normalize_certificate_pem(certificate: &str) -> String { + let trimmed = certificate.trim(); + if trimmed.contains("-----BEGIN CERTIFICATE-----") { + if trimmed.ends_with('\n') { + trimmed.to_string() + } else { + format!("{}\n", trimmed) + } + } else { + wrap_pem_body( + trimmed, + "-----BEGIN CERTIFICATE-----", + "-----END CERTIFICATE-----", + ) + } +} + fn wrap_pem_body(base64_body: &str, begin: &str, end: &str) -> String { let mut wrapped = String::new(); let normalized = base64_body.trim().replace(['\r', '\n'], ""); @@ -89,10 +109,20 @@ fn fetch_enrollment_config(host: &str, enroll_port: &str) -> Result"; + + assert!(extract_tag_value(xml, "serverPort").is_none()); + assert!(extract_tag_value(xml, "enrollPath").is_none()); + } + + #[test] + fn normalizes_base64_certificate_body_to_pem() { + let pem = normalize_certificate_pem("QUJDREVGRw=="); + + assert_eq!( + pem, + wrap_pem_body( + "QUJDREVGRw==", + "-----BEGIN CERTIFICATE-----", + "-----END CERTIFICATE-----" + ) + ); + } + + #[test] + fn preserves_existing_certificate_pem() { + let pem = "-----BEGIN CERTIFICATE-----\nQUJDREVGRw==\n-----END CERTIFICATE-----\n"; + + assert_eq!(normalize_certificate_pem(pem), pem); + } +}