diff --git a/Cargo.lock b/Cargo.lock index 6badbc9..603ac2e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1053,8 +1053,8 @@ version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "75e669e5202259b5314d1ea5397316ad400819437857b90861765f24c4cf80a2" dependencies = [ + "aws-lc-rs", "pem", - "ring", "rustls-pki-types", "time", "yasna", diff --git a/Cargo.toml b/Cargo.toml index 9a62ecd..a6d34d0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,7 +10,7 @@ lazy_static = "1.5.0" log = "0.4.22" log4rs = "1.3.0" reqwest = { version = "0.12.15", default-features = false, features = ["blocking", "json", "rustls-tls"] } -rcgen = "0.13.2" +rcgen = { version = "0.13.2", default-features = false, features = ["crypto", "pem", "aws_lc_rs"] } rustls = "0.23.23" rustls-pemfile = "2.2.0" serde = { version = "1.0.210", features = ["derive"] } diff --git a/src/lib.rs b/src/lib.rs index 26c470e..f3cfb53 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,5 @@ use arma_rs::{arma, Extension, Group}; +use rustls::crypto::aws_lc_rs; mod structs; mod tcp; mod tests; @@ -31,6 +32,9 @@ pub fn init() -> Extension { log4rs::init_config(config).unwrap(); + let _ = aws_lc_rs::default_provider().install_default(); + log::info!("Initialized rustls aws-lc crypto provider."); + Extension::build() .command("local_ip", utils::address::get_local_address) .command("uuid", utils::uuid::get_uuid) diff --git a/src/tcp/client.rs b/src/tcp/client.rs index 18ba1e5..d3cbff4 100644 --- a/src/tcp/client.rs +++ b/src/tcp/client.rs @@ -1,13 +1,18 @@ use arma_rs::Context; -use log::info; -use std::sync::mpsc::{Receiver, Sender}; -use std::sync::{Arc, Mutex}; +use log::{info, warn}; +use std::collections::VecDeque; +use std::panic::{self, AssertUnwindSafe}; +use std::sync::mpsc::{self, Receiver, RecvTimeoutError, Sender, TryRecvError}; use std::thread; +use std::time::Duration; use super::config::ConnectionConfig; use super::transport::{connect_stream, TransportStream}; use super::TCP_CLIENT; +const CONNECT_POLL_INTERVAL: Duration = Duration::from_millis(200); +const MAX_PENDING_MESSAGES: usize = 128; + pub enum TcpCommand { SendMessage(String, Context), Stop, @@ -17,80 +22,259 @@ pub struct TcpClient { pub(crate) tx: Sender, } +enum ConnectionState { + Connecting, + Connected, + Failed(String), +} + +enum ConnectEvent { + Connected(TransportStream), + Failed(String), +} + +fn describe_panic_payload(payload: Box) -> String { + if let Some(message) = payload.downcast_ref::<&str>() { + (*message).to_string() + } else if let Some(message) = payload.downcast_ref::() { + message.clone() + } else { + "unknown panic payload".to_string() + } +} + +fn log_message_preview(message: &str) -> String { + message.chars().take(96).collect::() +} + +fn send_over_stream( + stream: &mut TransportStream, + context: &Context, + message: String, +) -> Result<(), String> { + let message_len = message.len(); + info!("Sending TCP payload ({} bytes)", message_len); + stream + .write_message(message.as_bytes()) + .map_err(|e| { + let message = e.to_string(); + let _ = context.callback_data( + "TCP SOCKET ERROR", + "TAK Socket disconnected", + message.clone(), + ); + message + }) +} + +fn flush_pending_messages( + connection: &mut Option, + pending_messages: &mut VecDeque<(String, Context)>, + state: &mut ConnectionState, +) { + if pending_messages.is_empty() { + return; + } + + let Some(stream) = connection.as_mut() else { + return; + }; + + info!( + "Flushing {} queued TCP payload(s) after connection became active", + pending_messages.len() + ); + + while let Some((message, context)) = pending_messages.pop_front() { + if let Err(error) = send_over_stream(stream, &context, message) { + info!("Failed to send queued message: {}", error); + *state = ConnectionState::Failed(error); + *connection = None; + return; + } + } +} + +fn poll_connect_event( + connect_rx: &Receiver, + connection: &mut Option, + state: &mut ConnectionState, + pending_messages: &mut VecDeque<(String, Context)>, + ctx: &Context, + connection_message: &str, + target: &str, +) { + loop { + match connect_rx.try_recv() { + Ok(ConnectEvent::Connected(stream)) => { + info!("TCP connection established successfully: {}", target); + let _ = ctx.callback_data("TCP SOCKET", connection_message, target.to_string()); + *connection = Some(stream); + *state = ConnectionState::Connected; + flush_pending_messages(connection, pending_messages, state); + } + Ok(ConnectEvent::Failed(error)) => { + info!("Failed to connect to TCP server: {}", error); + let _ = ctx.callback_data( + "TCP SOCKET ERROR", + "TAK Socket connection failed", + error.clone(), + ); + *state = ConnectionState::Failed(error); + } + Err(TryRecvError::Empty) => break, + Err(TryRecvError::Disconnected) => break, + } + } +} + impl TcpClient { pub fn start(&self, config: ConnectionConfig, rx: Receiver, ctx: Context) { if let Some(ref client) = *TCP_CLIENT.lock().unwrap() { + info!("Existing TCP client detected; stopping previous instance before restart."); client.stop(); } - let connection = Arc::new(Mutex::new(None::)); - let connection_clone = Arc::clone(&connection); - thread::spawn(move || { let mut running = true; let connection_message = config.connected_message(); + let config_description = config.describe(); + let target = config.target(); + let mut state = ConnectionState::Connecting; + let mut connection: Option = None; + let mut pending_messages: VecDeque<(String, Context)> = VecDeque::new(); + let (connect_tx, connect_rx) = mpsc::channel(); - let tcp_thread = thread::spawn(move || match connect_stream(&config) { - Ok(stream) => { - let target = config.target(); - let _ = ctx.callback_data("TCP SOCKET", connection_message, target); - *connection_clone.lock().unwrap() = Some(stream); - } - Err(e) => { - let _ = ctx.callback_data( - "TCP SOCKET ERROR", - "TAK Socket connection failed", - e.to_string(), - ); - info!("Failed to connect to TCP server: {}", e); + info!("TCP worker thread started with config: {}", config_description); + + let tcp_thread = thread::spawn(move || { + let connect_result = panic::catch_unwind(AssertUnwindSafe(|| connect_stream(&config))); + + match connect_result { + Ok(Ok(stream)) => { + let _ = connect_tx.send(ConnectEvent::Connected(stream)); + } + Ok(Err(error)) => { + let _ = connect_tx.send(ConnectEvent::Failed(error)); + } + Err(payload) => { + let message = format!( + "TCP connection worker panicked: {}", + describe_panic_payload(payload) + ); + let _ = connect_tx.send(ConnectEvent::Failed(message)); + } } }); while running { - match rx.recv() { - Ok(TcpCommand::SendMessage(message, context)) => { - if let Some(stream) = connection.lock().unwrap().as_mut() { - if let Err(e) = stream.write_message(message.as_bytes()) { - info!("Failed to send message: {}", e); + poll_connect_event( + &connect_rx, + &mut connection, + &mut state, + &mut pending_messages, + &ctx, + connection_message, + &target, + ); + match rx.recv_timeout(CONNECT_POLL_INTERVAL) { + Ok(TcpCommand::SendMessage(message, context)) => { + let message_len = message.len(); + match &mut state { + ConnectionState::Connected => { + if let Some(stream) = connection.as_mut() { + if let Err(error) = send_over_stream(stream, &context, message) { + info!("Failed to send message: {}", error); + state = ConnectionState::Failed(error); + connection = None; + } + } else { + warn!( + "Connection state said connected, but no socket was present; queuing payload." + ); + pending_messages.push_back((message, context)); + } + } + ConnectionState::Connecting => { + if pending_messages.len() >= MAX_PENDING_MESSAGES { + let preview = log_message_preview(&message); + warn!( + "Dropping TCP payload because connection is still pending and queue is full ({} bytes, preview={:?})", + message_len, preview + ); + let _ = context.callback_data( + "TCP SOCKET ERROR", + "TAK Socket is still connecting", + format!( + "queue full while connecting; dropped payload ({} bytes, preview={:?})", + message_len, preview + ), + ); + } else { + info!( + "Queueing TCP payload while connection is pending ({} bytes, queued={})", + message_len, + pending_messages.len() + 1 + ); + pending_messages.push_back((message, context)); + } + } + ConnectionState::Failed(error) => { + let preview = log_message_preview(&message); + warn!( + "Dropping TCP payload because connection is in failed state ({} bytes, preview={:?}, error={})", + message_len, preview, error + ); let _ = context.callback_data( "TCP SOCKET ERROR", - "TAK Socket disconnected", - e.to_string(), + "TAK Socket is not connected", + error.clone(), ); - - running = false; } - } else { - let _ = context - .callback_null("TCP SOCKET ERROR", "TAK Socket is not active"); } } Ok(TcpCommand::Stop) => { running = false; info!("Stopping TCP client."); } - Err(error) => { - info!("Error receiving command: {}", error); + Err(RecvTimeoutError::Timeout) => {} + Err(RecvTimeoutError::Disconnected) => { + warn!("TCP command channel disconnected."); + running = false; } } } - tcp_thread.join().unwrap(); + info!("Waiting for TCP connection thread to finish."); + match tcp_thread.join() { + Ok(()) => info!("TCP connection thread joined successfully."), + Err(payload) => warn!( + "TCP connection thread join reported a panic: {}", + describe_panic_payload(payload) + ), + } + info!("TCP worker thread finished."); }); } pub fn send_payload(&self, context: Context, payload: String) { let tx = self.tx.clone(); thread::spawn(move || { - tx.send(TcpCommand::SendMessage(payload, context)).unwrap(); + info!("Dispatching queued TCP payload command."); + if let Err(error) = tx.send(TcpCommand::SendMessage(payload, context)) { + warn!("Failed to dispatch TCP payload command: {}", error); + } }); } pub fn stop(&self) { let tx = self.tx.clone(); thread::spawn(move || { - tx.send(TcpCommand::Stop).unwrap(); + info!("Dispatching TCP stop command."); + if let Err(error) = tx.send(TcpCommand::Stop) { + warn!("Failed to dispatch TCP stop command: {}", error); + } }); } } diff --git a/src/tcp/config.rs b/src/tcp/config.rs index d8e1460..fb05a19 100644 --- a/src/tcp/config.rs +++ b/src/tcp/config.rs @@ -34,4 +34,31 @@ impl ConnectionConfig { Self::EnrollMtls { host, .. } => host.clone(), } } + + pub fn describe(&self) -> String { + match self { + Self::Plain { address } => format!("plain tcp -> {}", address), + Self::Mtls { + address, + server_name, + ca_cert_path, + client_cert_path, + client_key_path, + } => format!( + "manual mtls -> {} (server_name={}, ca={}, cert={}, key={})", + address, server_name, ca_cert_path, client_cert_path, client_key_path + ), + Self::EnrollMtls { + host, + server_name, + enroll_port, + username, + client_uid, + .. + } => format!( + "enroll mtls -> host={} enroll_port={} server_name={} username={} client_uid={}", + host, enroll_port, server_name, username, client_uid + ), + } + } } diff --git a/src/tcp/mod.rs b/src/tcp/mod.rs index 641625f..5212fd5 100644 --- a/src/tcp/mod.rs +++ b/src/tcp/mod.rs @@ -14,13 +14,13 @@ pub mod draw; use client::{TcpClient, TcpCommand}; use config::ConnectionConfig; -use tls::artifacts::clear_enrollment_artifacts; lazy_static! { static ref TCP_CLIENT: Arc>> = Arc::new(Mutex::new(None)); } fn start_with_config(ctx: Context, config: ConnectionConfig) { + info!("Starting TCP client with config: {}", config.describe()); let (tx, rx): (Sender, Receiver) = mpsc::channel(); let client = TcpClient { tx }; @@ -67,7 +67,6 @@ pub fn start_enroll_mtls( password: String, client_uid: String, ) -> &'static str { - clear_enrollment_artifacts(); start_with_config( ctx, ConnectionConfig::EnrollMtls { @@ -85,6 +84,7 @@ pub fn start_enroll_mtls( pub fn send_payload(ctx: Context, payload: String) -> &'static str { if let Some(ref client) = *TCP_CLIENT.lock().unwrap() { + info!("Queueing TCP payload ({} bytes)", payload.len()); client.send_payload(ctx, payload); } else { let _ = ctx.callback_null("TCP SOCKET ERROR", "TCP Client is not running"); @@ -96,9 +96,9 @@ pub fn send_payload(ctx: Context, payload: String) -> &'static str { pub fn stop(ctx: Context) -> &'static str { if let Some(ref client) = *TCP_CLIENT.lock().unwrap() { + info!("Stopping TCP client via extension command."); client.stop(); let _ = ctx.callback_null("TCP SOCKET", "TCP client stopped"); - clear_enrollment_artifacts(); } else { let _ = ctx.callback_null("TCP SOCKET ERROR", "TCP client is not running"); } diff --git a/src/tcp/tls/artifacts.rs b/src/tcp/tls/artifacts.rs deleted file mode 100644 index fc5267c..0000000 --- a/src/tcp/tls/artifacts.rs +++ /dev/null @@ -1,101 +0,0 @@ -use lazy_static::lazy_static; -use std::env; -use std::fs::{self, create_dir_all}; -use std::path::PathBuf; -use std::sync::Mutex; - -#[derive(Clone)] -pub struct EnrollmentArtifacts { - pub ca_cert_path: String, - pub client_cert_path: String, - pub client_key_path: String, -} - -lazy_static! { - static ref ENROLLMENT_ARTIFACTS: Mutex> = Mutex::new(None); -} - -fn current_artifacts_dir() -> Result { - let mut path = env::current_dir().map_err(|e| format!("failed to resolve cwd: {}", e))?; - path.push(".armatak"); - path.push("session-certs"); - create_dir_all(&path) - .map_err(|e| format!("failed to create cert dir {}: {}", path.display(), e))?; - Ok(path) -} - -pub fn persist_enrollment_artifacts( - client_uid: &str, - ca_pem: &str, - cert_pem: &str, - key_pem: &str, -) -> Result { - let mut base_dir = current_artifacts_dir()?; - let safe_uid = client_uid - .chars() - .map(|ch| { - if ch.is_ascii_alphanumeric() || ch == '-' || ch == '_' { - ch - } else { - '_' - } - }) - .collect::(); - - base_dir.push(safe_uid); - create_dir_all(&base_dir).map_err(|e| { - format!( - "failed to create session cert dir {}: {}", - base_dir.display(), - e - ) - })?; - - let ca_cert_path = base_dir.join("ca.pem"); - let client_cert_path = base_dir.join("client.pem"); - let client_key_path = base_dir.join("client.key"); - - fs::write(&ca_cert_path, ca_pem).map_err(|e| { - format!( - "failed to persist CA cert {}: {}", - ca_cert_path.display(), - e - ) - })?; - fs::write(&client_cert_path, cert_pem).map_err(|e| { - format!( - "failed to persist client cert {}: {}", - client_cert_path.display(), - e - ) - })?; - fs::write(&client_key_path, key_pem).map_err(|e| { - format!( - "failed to persist client key {}: {}", - client_key_path.display(), - e - ) - })?; - - Ok(EnrollmentArtifacts { - ca_cert_path: ca_cert_path.to_string_lossy().to_string(), - client_cert_path: client_cert_path.to_string_lossy().to_string(), - client_key_path: client_key_path.to_string_lossy().to_string(), - }) -} - -pub fn store_enrollment_artifacts(artifacts: EnrollmentArtifacts) { - *ENROLLMENT_ARTIFACTS.lock().unwrap() = Some(artifacts); -} - -pub fn clear_enrollment_artifacts() { - if let Some(artifacts) = ENROLLMENT_ARTIFACTS.lock().unwrap().take() { - for path in [ - artifacts.ca_cert_path, - artifacts.client_cert_path, - artifacts.client_key_path, - ] { - let _ = fs::remove_file(path); - } - } -} diff --git a/src/tcp/tls/connector.rs b/src/tcp/tls/connector.rs index a1b4ec2..f770be2 100644 --- a/src/tcp/tls/connector.rs +++ b/src/tcp/tls/connector.rs @@ -1,13 +1,19 @@ +use log::info; use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName}; use rustls::{ClientConfig, ClientConnection, RootCertStore, StreamOwned}; use rustls_pemfile::{certs, private_key}; use std::fs::File; use std::io::BufReader; -use std::net::TcpStream; +use std::io::Cursor; +use std::net::{SocketAddr, TcpStream, ToSocketAddrs}; +use std::time::Duration; use std::sync::Arc; use crate::tcp::transport::TransportStream; +const TCP_CONNECT_TIMEOUT: Duration = Duration::from_secs(10); +const SOCKET_IO_TIMEOUT: Duration = Duration::from_secs(10); + fn load_certificates(path: &str) -> Result>, String> { let file = File::open(path).map_err(|e| format!("failed to open cert file {}: {}", path, e))?; let mut reader = BufReader::new(file); @@ -26,6 +32,22 @@ fn load_private_key(path: &str) -> Result, String> { .ok_or_else(|| format!("no supported private key found in {}", path)) } +fn load_certificates_from_pem(pem: &str) -> Result>, String> { + let mut reader = Cursor::new(pem.as_bytes()); + + certs(&mut reader) + .collect::, _>>() + .map_err(|e| format!("failed to read certs from PEM payload: {}", e)) +} + +fn load_private_key_from_pem(pem: &str) -> Result, String> { + let mut reader = Cursor::new(pem.as_bytes()); + + private_key(&mut reader) + .map_err(|e| format!("failed to read private key from PEM payload: {}", e))? + .ok_or_else(|| "no supported private key found in PEM payload".to_string()) +} + fn infer_server_name(address: &str) -> &str { address .trim() @@ -38,6 +60,34 @@ fn infer_server_name(address: &str) -> &str { .unwrap_or(address) } +fn resolve_address(address: &str) -> Result { + address + .to_socket_addrs() + .map_err(|e| format!("failed to resolve {}: {}", address, e))? + .next() + .ok_or_else(|| format!("failed to resolve {}: no socket addresses returned", address)) +} + +fn connect_tcp(address: &str) -> Result { + let socket_addr = resolve_address(address)?; + info!( + "Opening TCP connection to {} (resolved={}) with timeout {:?}", + address, socket_addr, TCP_CONNECT_TIMEOUT + ); + + let tcp_stream = TcpStream::connect_timeout(&socket_addr, TCP_CONNECT_TIMEOUT) + .map_err(|e| format!("failed to connect to {}: {}", address, e))?; + + tcp_stream + .set_read_timeout(Some(SOCKET_IO_TIMEOUT)) + .map_err(|e| format!("failed to set read timeout on {}: {}", address, e))?; + tcp_stream + .set_write_timeout(Some(SOCKET_IO_TIMEOUT)) + .map_err(|e| format!("failed to set write timeout on {}: {}", address, e))?; + + Ok(tcp_stream) +} + pub fn connect_mtls( address: &str, server_name: &str, @@ -45,23 +95,39 @@ pub fn connect_mtls( client_cert_path: &str, client_key_path: &str, ) -> Result { + info!( + "Connecting mTLS from file paths to {} using server_name={}", + address, server_name + ); let mut root_store = RootCertStore::empty(); - for certificate in load_certificates(ca_cert_path)? { + let ca_certificates = load_certificates(ca_cert_path)?; + info!( + "Loaded {} CA certificate(s) from {}", + ca_certificates.len(), + ca_cert_path + ); + for certificate in ca_certificates { root_store .add(certificate) .map_err(|e| format!("failed to add CA certificate from {}: {}", ca_cert_path, e))?; } let client_certificates = load_certificates(client_cert_path)?; + info!( + "Loaded {} client certificate(s) from {}", + client_certificates.len(), + client_cert_path + ); let client_key = load_private_key(client_key_path)?; + info!("Loaded client private key from {}", client_key_path); let tls_config = ClientConfig::builder() .with_root_certificates(root_store) .with_client_auth_cert(client_certificates, client_key) .map_err(|e| format!("failed to configure mTLS client: {}", e))?; + info!("Constructed rustls client config for {}", address); - let tcp_stream = TcpStream::connect(address) - .map_err(|e| format!("failed to connect to {}: {}", address, e))?; + let tcp_stream = connect_tcp(address)?; let resolved_server_name = if server_name.trim().is_empty() { infer_server_name(address).to_string() } else { @@ -76,6 +142,7 @@ pub fn connect_mtls( tcp_stream, ); + info!("Starting mTLS handshake for {}", address); while tls_stream.conn.is_handshaking() { tls_stream .conn @@ -83,5 +150,72 @@ pub fn connect_mtls( .map_err(|e| format!("TLS handshake failed: {}", e))?; } + info!("mTLS handshake completed successfully for {}", address); + + Ok(TransportStream::Mtls(tls_stream)) +} + +pub fn connect_mtls_from_pem( + address: &str, + server_name: &str, + ca_cert_pem: &str, + client_cert_pem: &str, + client_key_pem: &str, +) -> Result { + info!( + "Connecting mTLS from in-memory PEM payloads to {} using server_name={}", + address, server_name + ); + let mut root_store = RootCertStore::empty(); + let ca_certificates = load_certificates_from_pem(ca_cert_pem)?; + info!( + "Loaded {} CA certificate(s) from enrollment payload", + ca_certificates.len() + ); + for certificate in ca_certificates { + root_store + .add(certificate) + .map_err(|e| format!("failed to add CA certificate from PEM payload: {}", e))?; + } + + let client_certificates = load_certificates_from_pem(client_cert_pem)?; + info!( + "Loaded {} client certificate(s) from enrollment payload", + client_certificates.len() + ); + let client_key = load_private_key_from_pem(client_key_pem)?; + info!("Loaded client private key from enrollment payload"); + + let tls_config = ClientConfig::builder() + .with_root_certificates(root_store) + .with_client_auth_cert(client_certificates, client_key) + .map_err(|e| format!("failed to configure mTLS client: {}", e))?; + info!("Constructed rustls client config for {}", address); + + let tcp_stream = connect_tcp(address)?; + let resolved_server_name = if server_name.trim().is_empty() { + infer_server_name(address).to_string() + } else { + server_name.trim().to_string() + }; + + let server_name = ServerName::try_from(resolved_server_name.clone()) + .map_err(|_| format!("invalid TLS server name: {}", resolved_server_name))?; + let mut tls_stream = StreamOwned::new( + ClientConnection::new(Arc::new(tls_config), server_name) + .map_err(|e| format!("failed to create TLS client: {}", e))?, + tcp_stream, + ); + + info!("Starting mTLS handshake for {}", address); + while tls_stream.conn.is_handshaking() { + tls_stream + .conn + .complete_io(&mut tls_stream.sock) + .map_err(|e| format!("TLS handshake failed: {}", e))?; + } + + info!("mTLS handshake completed successfully for {}", address); + Ok(TransportStream::Mtls(tls_stream)) } diff --git a/src/tcp/tls/enrollment.rs b/src/tcp/tls/enrollment.rs index 46395fb..ad36d19 100644 --- a/src/tcp/tls/enrollment.rs +++ b/src/tcp/tls/enrollment.rs @@ -1,12 +1,10 @@ -use rcgen::{CertificateParams, DistinguishedName, DnType, KeyPair, PKCS_ECDSA_P256_SHA256}; +use rcgen::{CertificateParams, DistinguishedName, DnType, KeyPair, PKCS_RSA_SHA256}; +use log::info; use reqwest::blocking::Client; use serde::Deserialize; use uuid::Uuid; -use super::artifacts::{ - persist_enrollment_artifacts, store_enrollment_artifacts, EnrollmentArtifacts, -}; -use super::connector::connect_mtls; +use super::connector::connect_mtls_from_pem; use crate::tcp::transport::TransportStream; #[derive(Deserialize)] @@ -72,6 +70,7 @@ fn fetch_enrollment_config(host: &str, enroll_port: &str) -> Result Result Result { - let key_pair = KeyPair::generate_for(&PKCS_ECDSA_P256_SHA256) +) -> Result<(String, String, String), String> { + info!( + "Generating RSA client keypair and CSR for enrolled TAK client {}", + client_uid + ); + let key_pair = KeyPair::generate_for(&PKCS_RSA_SHA256) .map_err(|e| format!("failed to generate client keypair: {}", e))?; let mut distinguished_name = DistinguishedName::new(); @@ -133,6 +141,10 @@ fn enroll_client_certificate( enroll_path.trim(), client_uid.trim() ); + info!( + "Submitting client certificate enrollment request for {} to {}", + client_uid, url + ); let response = enrollment_http_client()? .post(&url) @@ -154,6 +166,12 @@ fn enroll_client_certificate( let enrollment: EnrollmentResponse = response .json() .map_err(|e| format!("failed to parse enrollment response: {}", e))?; + info!( + "Enrollment response parsed successfully for {} (signed_cert_len={}, ca_len={})", + client_uid, + enrollment.signed_cert.len(), + enrollment.ca0.len() + ); let cert_pem = wrap_pem_body( &enrollment.signed_cert, @@ -162,7 +180,7 @@ fn enroll_client_certificate( ); let key_pem = key_pair.serialize_pem(); - persist_enrollment_artifacts(client_uid, &enrollment.ca0, &cert_pem, &key_pem) + Ok((enrollment.ca0, cert_pem, key_pem)) } pub fn enroll_and_connect( @@ -178,9 +196,16 @@ pub fn enroll_and_connect( } else { client_uid.trim().to_string() }; + info!( + "Starting enroll_and_connect for host={} enroll_port={} server_name={} client_uid={}", + host, + enroll_port, + server_name, + normalized_client_uid + ); let enrollment_config = fetch_enrollment_config(host, enroll_port)?; - let artifacts = enroll_client_certificate( + let (ca_cert_pem, client_cert_pem, client_key_pem) = enroll_client_certificate( host, enroll_port, &enrollment_config.enroll_path, @@ -189,17 +214,15 @@ pub fn enroll_and_connect( &normalized_client_uid, )?; - store_enrollment_artifacts(artifacts.clone()); - - connect_mtls( + connect_mtls_from_pem( &format!("{}:{}", host.trim(), enrollment_config.server_port.trim()), if server_name.trim().is_empty() { host.trim() } else { server_name.trim() }, - &artifacts.ca_cert_path, - &artifacts.client_cert_path, - &artifacts.client_key_path, + &ca_cert_pem, + &client_cert_pem, + &client_key_pem, ) } diff --git a/src/tcp/tls/mod.rs b/src/tcp/tls/mod.rs index 1b7e579..52b21a6 100644 --- a/src/tcp/tls/mod.rs +++ b/src/tcp/tls/mod.rs @@ -1,4 +1,3 @@ -pub mod artifacts; mod connector; mod enrollment; diff --git a/src/tcp/transport.rs b/src/tcp/transport.rs index 8e47890..f11d883 100644 --- a/src/tcp/transport.rs +++ b/src/tcp/transport.rs @@ -1,10 +1,15 @@ +use log::info; use rustls::{ClientConnection, StreamOwned}; use std::io::Write; -use std::net::TcpStream; +use std::net::{SocketAddr, TcpStream, ToSocketAddrs}; +use std::time::Duration; use super::config::ConnectionConfig; use super::tls::{connect_mtls, enroll_and_connect}; +const TCP_CONNECT_TIMEOUT: Duration = Duration::from_secs(10); +const SOCKET_IO_TIMEOUT: Duration = Duration::from_secs(10); + pub enum TransportStream { Plain(TcpStream), Mtls(StreamOwned), @@ -25,11 +30,34 @@ impl TransportStream { } } +fn connect_plain(address: &str) -> Result { + let socket_addr: SocketAddr = address + .to_socket_addrs() + .map_err(|e| format!("failed to resolve {}: {}", address, e))? + .next() + .ok_or_else(|| format!("failed to resolve {}: no socket addresses returned", address))?; + + info!( + "Opening plain TCP connection to {} (resolved={}) with timeout {:?}", + address, socket_addr, TCP_CONNECT_TIMEOUT + ); + + let stream = TcpStream::connect_timeout(&socket_addr, TCP_CONNECT_TIMEOUT) + .map_err(|e| format!("failed to connect to {}: {}", address, e))?; + stream + .set_read_timeout(Some(SOCKET_IO_TIMEOUT)) + .map_err(|e| format!("failed to set read timeout on {}: {}", address, e))?; + stream + .set_write_timeout(Some(SOCKET_IO_TIMEOUT)) + .map_err(|e| format!("failed to set write timeout on {}: {}", address, e))?; + + Ok(TransportStream::Plain(stream)) +} + pub fn connect_stream(config: &ConnectionConfig) -> Result { + info!("connect_stream invoked for {}", config.describe()); match config { - ConnectionConfig::Plain { address } => TcpStream::connect(address) - .map(TransportStream::Plain) - .map_err(|e| format!("failed to connect to {}: {}", address, e)), + ConnectionConfig::Plain { address } => connect_plain(address), ConnectionConfig::Mtls { address, server_name,