From 61ba9f6d632223b22e15f7ed05575a8eef56092b Mon Sep 17 00:00:00 2001 From: Valmo Date: Tue, 24 Mar 2026 16:55:05 -0300 Subject: [PATCH] Added connector and enrollment for mTLS client certificate auto enrollment on game sessions, will MOCK a official tak client behavior when authenticating --- src/tcp/tls/artifacts.rs | 101 ++++++++++++++++++++++ src/tcp/tls/connector.rs | 87 +++++++++++++++++++ src/tcp/tls/enrollment.rs | 175 ++++++++++++++++++++++++++++++++++++++ src/tcp/tls/mod.rs | 6 ++ 4 files changed, 369 insertions(+) create mode 100644 src/tcp/tls/artifacts.rs create mode 100644 src/tcp/tls/connector.rs create mode 100644 src/tcp/tls/enrollment.rs create mode 100644 src/tcp/tls/mod.rs diff --git a/src/tcp/tls/artifacts.rs b/src/tcp/tls/artifacts.rs new file mode 100644 index 0000000..fc5267c --- /dev/null +++ b/src/tcp/tls/artifacts.rs @@ -0,0 +1,101 @@ +use lazy_static::lazy_static; +use std::env; +use std::fs::{self, create_dir_all}; +use std::path::PathBuf; +use std::sync::Mutex; + +#[derive(Clone)] +pub struct EnrollmentArtifacts { + pub ca_cert_path: String, + pub client_cert_path: String, + pub client_key_path: String, +} + +lazy_static! { + static ref ENROLLMENT_ARTIFACTS: Mutex> = Mutex::new(None); +} + +fn current_artifacts_dir() -> Result { + let mut path = env::current_dir().map_err(|e| format!("failed to resolve cwd: {}", e))?; + path.push(".armatak"); + path.push("session-certs"); + create_dir_all(&path) + .map_err(|e| format!("failed to create cert dir {}: {}", path.display(), e))?; + Ok(path) +} + +pub fn persist_enrollment_artifacts( + client_uid: &str, + ca_pem: &str, + cert_pem: &str, + key_pem: &str, +) -> Result { + let mut base_dir = current_artifacts_dir()?; + let safe_uid = client_uid + .chars() + .map(|ch| { + if ch.is_ascii_alphanumeric() || ch == '-' || ch == '_' { + ch + } else { + '_' + } + }) + .collect::(); + + base_dir.push(safe_uid); + create_dir_all(&base_dir).map_err(|e| { + format!( + "failed to create session cert dir {}: {}", + base_dir.display(), + e + ) + })?; + + let ca_cert_path = base_dir.join("ca.pem"); + let client_cert_path = base_dir.join("client.pem"); + let client_key_path = base_dir.join("client.key"); + + fs::write(&ca_cert_path, ca_pem).map_err(|e| { + format!( + "failed to persist CA cert {}: {}", + ca_cert_path.display(), + e + ) + })?; + fs::write(&client_cert_path, cert_pem).map_err(|e| { + format!( + "failed to persist client cert {}: {}", + client_cert_path.display(), + e + ) + })?; + fs::write(&client_key_path, key_pem).map_err(|e| { + format!( + "failed to persist client key {}: {}", + client_key_path.display(), + e + ) + })?; + + Ok(EnrollmentArtifacts { + ca_cert_path: ca_cert_path.to_string_lossy().to_string(), + client_cert_path: client_cert_path.to_string_lossy().to_string(), + client_key_path: client_key_path.to_string_lossy().to_string(), + }) +} + +pub fn store_enrollment_artifacts(artifacts: EnrollmentArtifacts) { + *ENROLLMENT_ARTIFACTS.lock().unwrap() = Some(artifacts); +} + +pub fn clear_enrollment_artifacts() { + if let Some(artifacts) = ENROLLMENT_ARTIFACTS.lock().unwrap().take() { + for path in [ + artifacts.ca_cert_path, + artifacts.client_cert_path, + artifacts.client_key_path, + ] { + let _ = fs::remove_file(path); + } + } +} diff --git a/src/tcp/tls/connector.rs b/src/tcp/tls/connector.rs new file mode 100644 index 0000000..a1b4ec2 --- /dev/null +++ b/src/tcp/tls/connector.rs @@ -0,0 +1,87 @@ +use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName}; +use rustls::{ClientConfig, ClientConnection, RootCertStore, StreamOwned}; +use rustls_pemfile::{certs, private_key}; +use std::fs::File; +use std::io::BufReader; +use std::net::TcpStream; +use std::sync::Arc; + +use crate::tcp::transport::TransportStream; + +fn load_certificates(path: &str) -> Result>, String> { + let file = File::open(path).map_err(|e| format!("failed to open cert file {}: {}", path, e))?; + let mut reader = BufReader::new(file); + + certs(&mut reader) + .collect::, _>>() + .map_err(|e| format!("failed to read certs from {}: {}", path, e)) +} + +fn load_private_key(path: &str) -> Result, String> { + let file = File::open(path).map_err(|e| format!("failed to open key file {}: {}", path, e))?; + let mut reader = BufReader::new(file); + + private_key(&mut reader) + .map_err(|e| format!("failed to read private key from {}: {}", path, e))? + .ok_or_else(|| format!("no supported private key found in {}", path)) +} + +fn infer_server_name(address: &str) -> &str { + address + .trim() + .trim_start_matches('[') + .split(']') + .next() + .unwrap_or(address) + .split(':') + .next() + .unwrap_or(address) +} + +pub fn connect_mtls( + address: &str, + server_name: &str, + ca_cert_path: &str, + client_cert_path: &str, + client_key_path: &str, +) -> Result { + let mut root_store = RootCertStore::empty(); + for certificate in load_certificates(ca_cert_path)? { + root_store + .add(certificate) + .map_err(|e| format!("failed to add CA certificate from {}: {}", ca_cert_path, e))?; + } + + let client_certificates = load_certificates(client_cert_path)?; + let client_key = load_private_key(client_key_path)?; + + let tls_config = ClientConfig::builder() + .with_root_certificates(root_store) + .with_client_auth_cert(client_certificates, client_key) + .map_err(|e| format!("failed to configure mTLS client: {}", e))?; + + let tcp_stream = TcpStream::connect(address) + .map_err(|e| format!("failed to connect to {}: {}", address, e))?; + let resolved_server_name = if server_name.trim().is_empty() { + infer_server_name(address).to_string() + } else { + server_name.trim().to_string() + }; + + let server_name = ServerName::try_from(resolved_server_name.clone()) + .map_err(|_| format!("invalid TLS server name: {}", resolved_server_name))?; + let mut tls_stream = StreamOwned::new( + ClientConnection::new(Arc::new(tls_config), server_name) + .map_err(|e| format!("failed to create TLS client: {}", e))?, + tcp_stream, + ); + + while tls_stream.conn.is_handshaking() { + tls_stream + .conn + .complete_io(&mut tls_stream.sock) + .map_err(|e| format!("TLS handshake failed: {}", e))?; + } + + Ok(TransportStream::Mtls(tls_stream)) +} diff --git a/src/tcp/tls/enrollment.rs b/src/tcp/tls/enrollment.rs new file mode 100644 index 0000000..b08e9e2 --- /dev/null +++ b/src/tcp/tls/enrollment.rs @@ -0,0 +1,175 @@ +use rcgen::{CertificateParams, DistinguishedName, DnType, KeyPair, PKCS_ECDSA_P256_SHA256}; +use reqwest::blocking::Client; +use serde::Deserialize; +use uuid::Uuid; + +use super::artifacts::{ + persist_enrollment_artifacts, store_enrollment_artifacts, EnrollmentArtifacts, +}; +use super::connector::connect_mtls; +use crate::tcp::transport::TransportStream; + +#[derive(Deserialize)] +struct EnrollmentResponse { + #[serde(rename = "signedCert")] + signed_cert: String, + ca0: String, +} + +struct EnrollmentConfig { + server_port: String, + enroll_path: String, +} + +fn extract_tag_value(xml: &str, tag_name: &str) -> Option { + let open_tag = format!("<{}>", tag_name); + let close_tag = format!("", tag_name); + let start = xml.find(&open_tag)? + open_tag.len(); + let end = xml[start..].find(&close_tag)? + start; + Some(xml[start..end].trim().to_string()) +} + +fn wrap_pem_body(base64_body: &str, begin: &str, end: &str) -> String { + let mut wrapped = String::new(); + let normalized = base64_body.trim().replace(['\r', '\n'], ""); + + wrapped.push_str(begin); + wrapped.push('\n'); + for chunk in normalized.as_bytes().chunks(64) { + wrapped.push_str(std::str::from_utf8(chunk).unwrap_or_default()); + wrapped.push('\n'); + } + wrapped.push_str(end); + wrapped.push('\n'); + wrapped +} + +fn enrollment_http_client() -> Result { + Client::builder() + .danger_accept_invalid_certs(true) + .build() + .map_err(|e| format!("failed to build enrollment HTTP client: {}", e)) +} + +fn fetch_enrollment_config(host: &str, enroll_port: &str) -> Result { + let url = format!( + "https://{}:{}/Marti/api/tls/config", + host.trim(), + enroll_port.trim() + ); + + let response_text = enrollment_http_client()? + .get(&url) + .send() + .and_then(|response| response.error_for_status()) + .map_err(|e| format!("failed to fetch {}: {}", url, e))? + .text() + .map_err(|e| format!("failed to read config response from {}: {}", url, e))?; + + let server_port = extract_tag_value(&response_text, "serverPort") + .ok_or_else(|| "missing serverPort in /Marti/api/tls/config response".to_string())?; + let enroll_path = extract_tag_value(&response_text, "enrollPath") + .ok_or_else(|| "missing enrollPath in /Marti/api/tls/config response".to_string())?; + + Ok(EnrollmentConfig { + server_port, + enroll_path, + }) +} + +fn enroll_client_certificate( + host: &str, + enroll_port: &str, + enroll_path: &str, + username: &str, + password: &str, + client_uid: &str, +) -> Result { + let key_pair = KeyPair::generate_for(&PKCS_ECDSA_P256_SHA256) + .map_err(|e| format!("failed to generate client keypair: {}", e))?; + + let mut distinguished_name = DistinguishedName::new(); + distinguished_name.push(DnType::CommonName, client_uid); + distinguished_name.push(DnType::OrganizationName, "ArmaTAK"); + distinguished_name.push(DnType::OrganizationalUnitName, "ArmaTAK Session"); + + let mut params = CertificateParams::new(vec![]) + .map_err(|e| format!("failed to create CSR params: {}", e))?; + params.distinguished_name = distinguished_name; + + let csr = params + .serialize_request(&key_pair) + .map_err(|e| format!("failed to generate CSR: {}", e))? + .pem() + .map_err(|e| format!("failed to serialize CSR to PEM: {}", e))?; + + let url = format!( + "https://{}:{}{}?clientUid={}", + host.trim(), + enroll_port.trim(), + enroll_path.trim(), + client_uid.trim() + ); + + let response = enrollment_http_client()? + .post(&url) + .basic_auth(username.trim(), Some(password.to_string())) + .header("Accept", "application/json") + .header("Content-Type", "application/x-pem-file") + .body(csr) + .send() + .and_then(|response| response.error_for_status()) + .map_err(|e| format!("failed to enroll client certificate at {}: {}", url, e))?; + + let enrollment: EnrollmentResponse = response + .json() + .map_err(|e| format!("failed to parse enrollment response: {}", e))?; + + let cert_pem = wrap_pem_body( + &enrollment.signed_cert, + "-----BEGIN CERTIFICATE-----", + "-----END CERTIFICATE-----", + ); + let key_pem = key_pair.serialize_pem(); + + persist_enrollment_artifacts(client_uid, &enrollment.ca0, &cert_pem, &key_pem) +} + +pub fn enroll_and_connect( + host: &str, + server_name: &str, + enroll_port: &str, + username: &str, + password: &str, + client_uid: &str, +) -> Result { + let normalized_client_uid = if client_uid.trim().is_empty() { + format!("armatak-{}", Uuid::new_v4()) + } else { + client_uid.trim().to_string() + }; + + let enrollment_config = fetch_enrollment_config(host, enroll_port)?; + let artifacts = enroll_client_certificate( + host, + enroll_port, + &enrollment_config.enroll_path, + username, + password, + &normalized_client_uid, + )?; + + store_enrollment_artifacts(artifacts.clone()); + + connect_mtls( + &format!("{}:{}", host.trim(), enrollment_config.server_port.trim()), + if server_name.trim().is_empty() { + host.trim() + } else { + server_name.trim() + }, + &artifacts.ca_cert_path, + &artifacts.client_cert_path, + &artifacts.client_key_path, + ) +} diff --git a/src/tcp/tls/mod.rs b/src/tcp/tls/mod.rs new file mode 100644 index 0000000..1b7e579 --- /dev/null +++ b/src/tcp/tls/mod.rs @@ -0,0 +1,6 @@ +pub mod artifacts; +mod connector; +mod enrollment; + +pub use connector::connect_mtls; +pub use enrollment::enroll_and_connect;