Fixed CoT queue during armatak connection to the TAK Server, running soft as butter

This commit is contained in:
2026-03-26 03:45:05 -03:00
parent e32aadda4e
commit 708fe5e670
11 changed files with 464 additions and 166 deletions

2
Cargo.lock generated
View File

@@ -1053,8 +1053,8 @@ version = "0.13.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "75e669e5202259b5314d1ea5397316ad400819437857b90861765f24c4cf80a2" checksum = "75e669e5202259b5314d1ea5397316ad400819437857b90861765f24c4cf80a2"
dependencies = [ dependencies = [
"aws-lc-rs",
"pem", "pem",
"ring",
"rustls-pki-types", "rustls-pki-types",
"time", "time",
"yasna", "yasna",

View File

@@ -10,7 +10,7 @@ lazy_static = "1.5.0"
log = "0.4.22" log = "0.4.22"
log4rs = "1.3.0" log4rs = "1.3.0"
reqwest = { version = "0.12.15", default-features = false, features = ["blocking", "json", "rustls-tls"] } reqwest = { version = "0.12.15", default-features = false, features = ["blocking", "json", "rustls-tls"] }
rcgen = "0.13.2" rcgen = { version = "0.13.2", default-features = false, features = ["crypto", "pem", "aws_lc_rs"] }
rustls = "0.23.23" rustls = "0.23.23"
rustls-pemfile = "2.2.0" rustls-pemfile = "2.2.0"
serde = { version = "1.0.210", features = ["derive"] } serde = { version = "1.0.210", features = ["derive"] }

View File

@@ -1,4 +1,5 @@
use arma_rs::{arma, Extension, Group}; use arma_rs::{arma, Extension, Group};
use rustls::crypto::aws_lc_rs;
mod structs; mod structs;
mod tcp; mod tcp;
mod tests; mod tests;
@@ -31,6 +32,9 @@ pub fn init() -> Extension {
log4rs::init_config(config).unwrap(); log4rs::init_config(config).unwrap();
let _ = aws_lc_rs::default_provider().install_default();
log::info!("Initialized rustls aws-lc crypto provider.");
Extension::build() Extension::build()
.command("local_ip", utils::address::get_local_address) .command("local_ip", utils::address::get_local_address)
.command("uuid", utils::uuid::get_uuid) .command("uuid", utils::uuid::get_uuid)

View File

@@ -1,13 +1,18 @@
use arma_rs::Context; use arma_rs::Context;
use log::info; use log::{info, warn};
use std::sync::mpsc::{Receiver, Sender}; use std::collections::VecDeque;
use std::sync::{Arc, Mutex}; use std::panic::{self, AssertUnwindSafe};
use std::sync::mpsc::{self, Receiver, RecvTimeoutError, Sender, TryRecvError};
use std::thread; use std::thread;
use std::time::Duration;
use super::config::ConnectionConfig; use super::config::ConnectionConfig;
use super::transport::{connect_stream, TransportStream}; use super::transport::{connect_stream, TransportStream};
use super::TCP_CLIENT; use super::TCP_CLIENT;
const CONNECT_POLL_INTERVAL: Duration = Duration::from_millis(200);
const MAX_PENDING_MESSAGES: usize = 128;
pub enum TcpCommand { pub enum TcpCommand {
SendMessage(String, Context), SendMessage(String, Context),
Stop, Stop,
@@ -17,80 +22,259 @@ pub struct TcpClient {
pub(crate) tx: Sender<TcpCommand>, pub(crate) tx: Sender<TcpCommand>,
} }
enum ConnectionState {
Connecting,
Connected,
Failed(String),
}
enum ConnectEvent {
Connected(TransportStream),
Failed(String),
}
fn describe_panic_payload(payload: Box<dyn std::any::Any + Send>) -> String {
if let Some(message) = payload.downcast_ref::<&str>() {
(*message).to_string()
} else if let Some(message) = payload.downcast_ref::<String>() {
message.clone()
} else {
"unknown panic payload".to_string()
}
}
fn log_message_preview(message: &str) -> String {
message.chars().take(96).collect::<String>()
}
fn send_over_stream(
stream: &mut TransportStream,
context: &Context,
message: String,
) -> Result<(), String> {
let message_len = message.len();
info!("Sending TCP payload ({} bytes)", message_len);
stream
.write_message(message.as_bytes())
.map_err(|e| {
let message = e.to_string();
let _ = context.callback_data(
"TCP SOCKET ERROR",
"TAK Socket disconnected",
message.clone(),
);
message
})
}
fn flush_pending_messages(
connection: &mut Option<TransportStream>,
pending_messages: &mut VecDeque<(String, Context)>,
state: &mut ConnectionState,
) {
if pending_messages.is_empty() {
return;
}
let Some(stream) = connection.as_mut() else {
return;
};
info!(
"Flushing {} queued TCP payload(s) after connection became active",
pending_messages.len()
);
while let Some((message, context)) = pending_messages.pop_front() {
if let Err(error) = send_over_stream(stream, &context, message) {
info!("Failed to send queued message: {}", error);
*state = ConnectionState::Failed(error);
*connection = None;
return;
}
}
}
fn poll_connect_event(
connect_rx: &Receiver<ConnectEvent>,
connection: &mut Option<TransportStream>,
state: &mut ConnectionState,
pending_messages: &mut VecDeque<(String, Context)>,
ctx: &Context,
connection_message: &str,
target: &str,
) {
loop {
match connect_rx.try_recv() {
Ok(ConnectEvent::Connected(stream)) => {
info!("TCP connection established successfully: {}", target);
let _ = ctx.callback_data("TCP SOCKET", connection_message, target.to_string());
*connection = Some(stream);
*state = ConnectionState::Connected;
flush_pending_messages(connection, pending_messages, state);
}
Ok(ConnectEvent::Failed(error)) => {
info!("Failed to connect to TCP server: {}", error);
let _ = ctx.callback_data(
"TCP SOCKET ERROR",
"TAK Socket connection failed",
error.clone(),
);
*state = ConnectionState::Failed(error);
}
Err(TryRecvError::Empty) => break,
Err(TryRecvError::Disconnected) => break,
}
}
}
impl TcpClient { impl TcpClient {
pub fn start(&self, config: ConnectionConfig, rx: Receiver<TcpCommand>, ctx: Context) { pub fn start(&self, config: ConnectionConfig, rx: Receiver<TcpCommand>, ctx: Context) {
if let Some(ref client) = *TCP_CLIENT.lock().unwrap() { if let Some(ref client) = *TCP_CLIENT.lock().unwrap() {
info!("Existing TCP client detected; stopping previous instance before restart.");
client.stop(); client.stop();
} }
let connection = Arc::new(Mutex::new(None::<TransportStream>));
let connection_clone = Arc::clone(&connection);
thread::spawn(move || { thread::spawn(move || {
let mut running = true; let mut running = true;
let connection_message = config.connected_message(); let connection_message = config.connected_message();
let config_description = config.describe();
let target = config.target();
let mut state = ConnectionState::Connecting;
let mut connection: Option<TransportStream> = None;
let mut pending_messages: VecDeque<(String, Context)> = VecDeque::new();
let (connect_tx, connect_rx) = mpsc::channel();
let tcp_thread = thread::spawn(move || match connect_stream(&config) { info!("TCP worker thread started with config: {}", config_description);
Ok(stream) => {
let target = config.target(); let tcp_thread = thread::spawn(move || {
let _ = ctx.callback_data("TCP SOCKET", connection_message, target); let connect_result = panic::catch_unwind(AssertUnwindSafe(|| connect_stream(&config)));
*connection_clone.lock().unwrap() = Some(stream);
} match connect_result {
Err(e) => { Ok(Ok(stream)) => {
let _ = ctx.callback_data( let _ = connect_tx.send(ConnectEvent::Connected(stream));
"TCP SOCKET ERROR", }
"TAK Socket connection failed", Ok(Err(error)) => {
e.to_string(), let _ = connect_tx.send(ConnectEvent::Failed(error));
); }
info!("Failed to connect to TCP server: {}", e); Err(payload) => {
let message = format!(
"TCP connection worker panicked: {}",
describe_panic_payload(payload)
);
let _ = connect_tx.send(ConnectEvent::Failed(message));
}
} }
}); });
while running { while running {
match rx.recv() { poll_connect_event(
Ok(TcpCommand::SendMessage(message, context)) => { &connect_rx,
if let Some(stream) = connection.lock().unwrap().as_mut() { &mut connection,
if let Err(e) = stream.write_message(message.as_bytes()) { &mut state,
info!("Failed to send message: {}", e); &mut pending_messages,
&ctx,
connection_message,
&target,
);
match rx.recv_timeout(CONNECT_POLL_INTERVAL) {
Ok(TcpCommand::SendMessage(message, context)) => {
let message_len = message.len();
match &mut state {
ConnectionState::Connected => {
if let Some(stream) = connection.as_mut() {
if let Err(error) = send_over_stream(stream, &context, message) {
info!("Failed to send message: {}", error);
state = ConnectionState::Failed(error);
connection = None;
}
} else {
warn!(
"Connection state said connected, but no socket was present; queuing payload."
);
pending_messages.push_back((message, context));
}
}
ConnectionState::Connecting => {
if pending_messages.len() >= MAX_PENDING_MESSAGES {
let preview = log_message_preview(&message);
warn!(
"Dropping TCP payload because connection is still pending and queue is full ({} bytes, preview={:?})",
message_len, preview
);
let _ = context.callback_data(
"TCP SOCKET ERROR",
"TAK Socket is still connecting",
format!(
"queue full while connecting; dropped payload ({} bytes, preview={:?})",
message_len, preview
),
);
} else {
info!(
"Queueing TCP payload while connection is pending ({} bytes, queued={})",
message_len,
pending_messages.len() + 1
);
pending_messages.push_back((message, context));
}
}
ConnectionState::Failed(error) => {
let preview = log_message_preview(&message);
warn!(
"Dropping TCP payload because connection is in failed state ({} bytes, preview={:?}, error={})",
message_len, preview, error
);
let _ = context.callback_data( let _ = context.callback_data(
"TCP SOCKET ERROR", "TCP SOCKET ERROR",
"TAK Socket disconnected", "TAK Socket is not connected",
e.to_string(), error.clone(),
); );
running = false;
} }
} else {
let _ = context
.callback_null("TCP SOCKET ERROR", "TAK Socket is not active");
} }
} }
Ok(TcpCommand::Stop) => { Ok(TcpCommand::Stop) => {
running = false; running = false;
info!("Stopping TCP client."); info!("Stopping TCP client.");
} }
Err(error) => { Err(RecvTimeoutError::Timeout) => {}
info!("Error receiving command: {}", error); Err(RecvTimeoutError::Disconnected) => {
warn!("TCP command channel disconnected.");
running = false;
} }
} }
} }
tcp_thread.join().unwrap(); info!("Waiting for TCP connection thread to finish.");
match tcp_thread.join() {
Ok(()) => info!("TCP connection thread joined successfully."),
Err(payload) => warn!(
"TCP connection thread join reported a panic: {}",
describe_panic_payload(payload)
),
}
info!("TCP worker thread finished.");
}); });
} }
pub fn send_payload(&self, context: Context, payload: String) { pub fn send_payload(&self, context: Context, payload: String) {
let tx = self.tx.clone(); let tx = self.tx.clone();
thread::spawn(move || { thread::spawn(move || {
tx.send(TcpCommand::SendMessage(payload, context)).unwrap(); info!("Dispatching queued TCP payload command.");
if let Err(error) = tx.send(TcpCommand::SendMessage(payload, context)) {
warn!("Failed to dispatch TCP payload command: {}", error);
}
}); });
} }
pub fn stop(&self) { pub fn stop(&self) {
let tx = self.tx.clone(); let tx = self.tx.clone();
thread::spawn(move || { thread::spawn(move || {
tx.send(TcpCommand::Stop).unwrap(); info!("Dispatching TCP stop command.");
if let Err(error) = tx.send(TcpCommand::Stop) {
warn!("Failed to dispatch TCP stop command: {}", error);
}
}); });
} }
} }

View File

@@ -34,4 +34,31 @@ impl ConnectionConfig {
Self::EnrollMtls { host, .. } => host.clone(), Self::EnrollMtls { host, .. } => host.clone(),
} }
} }
pub fn describe(&self) -> String {
match self {
Self::Plain { address } => format!("plain tcp -> {}", address),
Self::Mtls {
address,
server_name,
ca_cert_path,
client_cert_path,
client_key_path,
} => format!(
"manual mtls -> {} (server_name={}, ca={}, cert={}, key={})",
address, server_name, ca_cert_path, client_cert_path, client_key_path
),
Self::EnrollMtls {
host,
server_name,
enroll_port,
username,
client_uid,
..
} => format!(
"enroll mtls -> host={} enroll_port={} server_name={} username={} client_uid={}",
host, enroll_port, server_name, username, client_uid
),
}
}
} }

View File

@@ -14,13 +14,13 @@ pub mod draw;
use client::{TcpClient, TcpCommand}; use client::{TcpClient, TcpCommand};
use config::ConnectionConfig; use config::ConnectionConfig;
use tls::artifacts::clear_enrollment_artifacts;
lazy_static! { lazy_static! {
static ref TCP_CLIENT: Arc<Mutex<Option<TcpClient>>> = Arc::new(Mutex::new(None)); static ref TCP_CLIENT: Arc<Mutex<Option<TcpClient>>> = Arc::new(Mutex::new(None));
} }
fn start_with_config(ctx: Context, config: ConnectionConfig) { fn start_with_config(ctx: Context, config: ConnectionConfig) {
info!("Starting TCP client with config: {}", config.describe());
let (tx, rx): (Sender<TcpCommand>, Receiver<TcpCommand>) = mpsc::channel(); let (tx, rx): (Sender<TcpCommand>, Receiver<TcpCommand>) = mpsc::channel();
let client = TcpClient { tx }; let client = TcpClient { tx };
@@ -67,7 +67,6 @@ pub fn start_enroll_mtls(
password: String, password: String,
client_uid: String, client_uid: String,
) -> &'static str { ) -> &'static str {
clear_enrollment_artifacts();
start_with_config( start_with_config(
ctx, ctx,
ConnectionConfig::EnrollMtls { ConnectionConfig::EnrollMtls {
@@ -85,6 +84,7 @@ pub fn start_enroll_mtls(
pub fn send_payload(ctx: Context, payload: String) -> &'static str { pub fn send_payload(ctx: Context, payload: String) -> &'static str {
if let Some(ref client) = *TCP_CLIENT.lock().unwrap() { if let Some(ref client) = *TCP_CLIENT.lock().unwrap() {
info!("Queueing TCP payload ({} bytes)", payload.len());
client.send_payload(ctx, payload); client.send_payload(ctx, payload);
} else { } else {
let _ = ctx.callback_null("TCP SOCKET ERROR", "TCP Client is not running"); let _ = ctx.callback_null("TCP SOCKET ERROR", "TCP Client is not running");
@@ -96,9 +96,9 @@ pub fn send_payload(ctx: Context, payload: String) -> &'static str {
pub fn stop(ctx: Context) -> &'static str { pub fn stop(ctx: Context) -> &'static str {
if let Some(ref client) = *TCP_CLIENT.lock().unwrap() { if let Some(ref client) = *TCP_CLIENT.lock().unwrap() {
info!("Stopping TCP client via extension command.");
client.stop(); client.stop();
let _ = ctx.callback_null("TCP SOCKET", "TCP client stopped"); let _ = ctx.callback_null("TCP SOCKET", "TCP client stopped");
clear_enrollment_artifacts();
} else { } else {
let _ = ctx.callback_null("TCP SOCKET ERROR", "TCP client is not running"); let _ = ctx.callback_null("TCP SOCKET ERROR", "TCP client is not running");
} }

View File

@@ -1,101 +0,0 @@
use lazy_static::lazy_static;
use std::env;
use std::fs::{self, create_dir_all};
use std::path::PathBuf;
use std::sync::Mutex;
#[derive(Clone)]
pub struct EnrollmentArtifacts {
pub ca_cert_path: String,
pub client_cert_path: String,
pub client_key_path: String,
}
lazy_static! {
static ref ENROLLMENT_ARTIFACTS: Mutex<Option<EnrollmentArtifacts>> = Mutex::new(None);
}
fn current_artifacts_dir() -> Result<PathBuf, String> {
let mut path = env::current_dir().map_err(|e| format!("failed to resolve cwd: {}", e))?;
path.push(".armatak");
path.push("session-certs");
create_dir_all(&path)
.map_err(|e| format!("failed to create cert dir {}: {}", path.display(), e))?;
Ok(path)
}
pub fn persist_enrollment_artifacts(
client_uid: &str,
ca_pem: &str,
cert_pem: &str,
key_pem: &str,
) -> Result<EnrollmentArtifacts, String> {
let mut base_dir = current_artifacts_dir()?;
let safe_uid = client_uid
.chars()
.map(|ch| {
if ch.is_ascii_alphanumeric() || ch == '-' || ch == '_' {
ch
} else {
'_'
}
})
.collect::<String>();
base_dir.push(safe_uid);
create_dir_all(&base_dir).map_err(|e| {
format!(
"failed to create session cert dir {}: {}",
base_dir.display(),
e
)
})?;
let ca_cert_path = base_dir.join("ca.pem");
let client_cert_path = base_dir.join("client.pem");
let client_key_path = base_dir.join("client.key");
fs::write(&ca_cert_path, ca_pem).map_err(|e| {
format!(
"failed to persist CA cert {}: {}",
ca_cert_path.display(),
e
)
})?;
fs::write(&client_cert_path, cert_pem).map_err(|e| {
format!(
"failed to persist client cert {}: {}",
client_cert_path.display(),
e
)
})?;
fs::write(&client_key_path, key_pem).map_err(|e| {
format!(
"failed to persist client key {}: {}",
client_key_path.display(),
e
)
})?;
Ok(EnrollmentArtifacts {
ca_cert_path: ca_cert_path.to_string_lossy().to_string(),
client_cert_path: client_cert_path.to_string_lossy().to_string(),
client_key_path: client_key_path.to_string_lossy().to_string(),
})
}
pub fn store_enrollment_artifacts(artifacts: EnrollmentArtifacts) {
*ENROLLMENT_ARTIFACTS.lock().unwrap() = Some(artifacts);
}
pub fn clear_enrollment_artifacts() {
if let Some(artifacts) = ENROLLMENT_ARTIFACTS.lock().unwrap().take() {
for path in [
artifacts.ca_cert_path,
artifacts.client_cert_path,
artifacts.client_key_path,
] {
let _ = fs::remove_file(path);
}
}
}

View File

@@ -1,13 +1,19 @@
use log::info;
use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName}; use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName};
use rustls::{ClientConfig, ClientConnection, RootCertStore, StreamOwned}; use rustls::{ClientConfig, ClientConnection, RootCertStore, StreamOwned};
use rustls_pemfile::{certs, private_key}; use rustls_pemfile::{certs, private_key};
use std::fs::File; use std::fs::File;
use std::io::BufReader; use std::io::BufReader;
use std::net::TcpStream; use std::io::Cursor;
use std::net::{SocketAddr, TcpStream, ToSocketAddrs};
use std::time::Duration;
use std::sync::Arc; use std::sync::Arc;
use crate::tcp::transport::TransportStream; use crate::tcp::transport::TransportStream;
const TCP_CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
const SOCKET_IO_TIMEOUT: Duration = Duration::from_secs(10);
fn load_certificates(path: &str) -> Result<Vec<CertificateDer<'static>>, String> { fn load_certificates(path: &str) -> Result<Vec<CertificateDer<'static>>, String> {
let file = File::open(path).map_err(|e| format!("failed to open cert file {}: {}", path, e))?; let file = File::open(path).map_err(|e| format!("failed to open cert file {}: {}", path, e))?;
let mut reader = BufReader::new(file); let mut reader = BufReader::new(file);
@@ -26,6 +32,22 @@ fn load_private_key(path: &str) -> Result<PrivateKeyDer<'static>, String> {
.ok_or_else(|| format!("no supported private key found in {}", path)) .ok_or_else(|| format!("no supported private key found in {}", path))
} }
fn load_certificates_from_pem(pem: &str) -> Result<Vec<CertificateDer<'static>>, String> {
let mut reader = Cursor::new(pem.as_bytes());
certs(&mut reader)
.collect::<Result<Vec<_>, _>>()
.map_err(|e| format!("failed to read certs from PEM payload: {}", e))
}
fn load_private_key_from_pem(pem: &str) -> Result<PrivateKeyDer<'static>, String> {
let mut reader = Cursor::new(pem.as_bytes());
private_key(&mut reader)
.map_err(|e| format!("failed to read private key from PEM payload: {}", e))?
.ok_or_else(|| "no supported private key found in PEM payload".to_string())
}
fn infer_server_name(address: &str) -> &str { fn infer_server_name(address: &str) -> &str {
address address
.trim() .trim()
@@ -38,6 +60,34 @@ fn infer_server_name(address: &str) -> &str {
.unwrap_or(address) .unwrap_or(address)
} }
fn resolve_address(address: &str) -> Result<SocketAddr, String> {
address
.to_socket_addrs()
.map_err(|e| format!("failed to resolve {}: {}", address, e))?
.next()
.ok_or_else(|| format!("failed to resolve {}: no socket addresses returned", address))
}
fn connect_tcp(address: &str) -> Result<TcpStream, String> {
let socket_addr = resolve_address(address)?;
info!(
"Opening TCP connection to {} (resolved={}) with timeout {:?}",
address, socket_addr, TCP_CONNECT_TIMEOUT
);
let tcp_stream = TcpStream::connect_timeout(&socket_addr, TCP_CONNECT_TIMEOUT)
.map_err(|e| format!("failed to connect to {}: {}", address, e))?;
tcp_stream
.set_read_timeout(Some(SOCKET_IO_TIMEOUT))
.map_err(|e| format!("failed to set read timeout on {}: {}", address, e))?;
tcp_stream
.set_write_timeout(Some(SOCKET_IO_TIMEOUT))
.map_err(|e| format!("failed to set write timeout on {}: {}", address, e))?;
Ok(tcp_stream)
}
pub fn connect_mtls( pub fn connect_mtls(
address: &str, address: &str,
server_name: &str, server_name: &str,
@@ -45,23 +95,39 @@ pub fn connect_mtls(
client_cert_path: &str, client_cert_path: &str,
client_key_path: &str, client_key_path: &str,
) -> Result<TransportStream, String> { ) -> Result<TransportStream, String> {
info!(
"Connecting mTLS from file paths to {} using server_name={}",
address, server_name
);
let mut root_store = RootCertStore::empty(); let mut root_store = RootCertStore::empty();
for certificate in load_certificates(ca_cert_path)? { let ca_certificates = load_certificates(ca_cert_path)?;
info!(
"Loaded {} CA certificate(s) from {}",
ca_certificates.len(),
ca_cert_path
);
for certificate in ca_certificates {
root_store root_store
.add(certificate) .add(certificate)
.map_err(|e| format!("failed to add CA certificate from {}: {}", ca_cert_path, e))?; .map_err(|e| format!("failed to add CA certificate from {}: {}", ca_cert_path, e))?;
} }
let client_certificates = load_certificates(client_cert_path)?; let client_certificates = load_certificates(client_cert_path)?;
info!(
"Loaded {} client certificate(s) from {}",
client_certificates.len(),
client_cert_path
);
let client_key = load_private_key(client_key_path)?; let client_key = load_private_key(client_key_path)?;
info!("Loaded client private key from {}", client_key_path);
let tls_config = ClientConfig::builder() let tls_config = ClientConfig::builder()
.with_root_certificates(root_store) .with_root_certificates(root_store)
.with_client_auth_cert(client_certificates, client_key) .with_client_auth_cert(client_certificates, client_key)
.map_err(|e| format!("failed to configure mTLS client: {}", e))?; .map_err(|e| format!("failed to configure mTLS client: {}", e))?;
info!("Constructed rustls client config for {}", address);
let tcp_stream = TcpStream::connect(address) let tcp_stream = connect_tcp(address)?;
.map_err(|e| format!("failed to connect to {}: {}", address, e))?;
let resolved_server_name = if server_name.trim().is_empty() { let resolved_server_name = if server_name.trim().is_empty() {
infer_server_name(address).to_string() infer_server_name(address).to_string()
} else { } else {
@@ -76,6 +142,7 @@ pub fn connect_mtls(
tcp_stream, tcp_stream,
); );
info!("Starting mTLS handshake for {}", address);
while tls_stream.conn.is_handshaking() { while tls_stream.conn.is_handshaking() {
tls_stream tls_stream
.conn .conn
@@ -83,5 +150,72 @@ pub fn connect_mtls(
.map_err(|e| format!("TLS handshake failed: {}", e))?; .map_err(|e| format!("TLS handshake failed: {}", e))?;
} }
info!("mTLS handshake completed successfully for {}", address);
Ok(TransportStream::Mtls(tls_stream))
}
pub fn connect_mtls_from_pem(
address: &str,
server_name: &str,
ca_cert_pem: &str,
client_cert_pem: &str,
client_key_pem: &str,
) -> Result<TransportStream, String> {
info!(
"Connecting mTLS from in-memory PEM payloads to {} using server_name={}",
address, server_name
);
let mut root_store = RootCertStore::empty();
let ca_certificates = load_certificates_from_pem(ca_cert_pem)?;
info!(
"Loaded {} CA certificate(s) from enrollment payload",
ca_certificates.len()
);
for certificate in ca_certificates {
root_store
.add(certificate)
.map_err(|e| format!("failed to add CA certificate from PEM payload: {}", e))?;
}
let client_certificates = load_certificates_from_pem(client_cert_pem)?;
info!(
"Loaded {} client certificate(s) from enrollment payload",
client_certificates.len()
);
let client_key = load_private_key_from_pem(client_key_pem)?;
info!("Loaded client private key from enrollment payload");
let tls_config = ClientConfig::builder()
.with_root_certificates(root_store)
.with_client_auth_cert(client_certificates, client_key)
.map_err(|e| format!("failed to configure mTLS client: {}", e))?;
info!("Constructed rustls client config for {}", address);
let tcp_stream = connect_tcp(address)?;
let resolved_server_name = if server_name.trim().is_empty() {
infer_server_name(address).to_string()
} else {
server_name.trim().to_string()
};
let server_name = ServerName::try_from(resolved_server_name.clone())
.map_err(|_| format!("invalid TLS server name: {}", resolved_server_name))?;
let mut tls_stream = StreamOwned::new(
ClientConnection::new(Arc::new(tls_config), server_name)
.map_err(|e| format!("failed to create TLS client: {}", e))?,
tcp_stream,
);
info!("Starting mTLS handshake for {}", address);
while tls_stream.conn.is_handshaking() {
tls_stream
.conn
.complete_io(&mut tls_stream.sock)
.map_err(|e| format!("TLS handshake failed: {}", e))?;
}
info!("mTLS handshake completed successfully for {}", address);
Ok(TransportStream::Mtls(tls_stream)) Ok(TransportStream::Mtls(tls_stream))
} }

View File

@@ -1,12 +1,10 @@
use rcgen::{CertificateParams, DistinguishedName, DnType, KeyPair, PKCS_ECDSA_P256_SHA256}; use rcgen::{CertificateParams, DistinguishedName, DnType, KeyPair, PKCS_RSA_SHA256};
use log::info;
use reqwest::blocking::Client; use reqwest::blocking::Client;
use serde::Deserialize; use serde::Deserialize;
use uuid::Uuid; use uuid::Uuid;
use super::artifacts::{ use super::connector::connect_mtls_from_pem;
persist_enrollment_artifacts, store_enrollment_artifacts, EnrollmentArtifacts,
};
use super::connector::connect_mtls;
use crate::tcp::transport::TransportStream; use crate::tcp::transport::TransportStream;
#[derive(Deserialize)] #[derive(Deserialize)]
@@ -72,6 +70,7 @@ fn fetch_enrollment_config(host: &str, enroll_port: &str) -> Result<EnrollmentCo
host.trim(), host.trim(),
enroll_port.trim() enroll_port.trim()
); );
info!("Fetching TAK enrollment config from {}", url);
let response = enrollment_http_client()? let response = enrollment_http_client()?
.get(&url) .get(&url)
@@ -95,6 +94,11 @@ fn fetch_enrollment_config(host: &str, enroll_port: &str) -> Result<EnrollmentCo
let enroll_path = extract_tag_value(&response_text, "enrollPath") let enroll_path = extract_tag_value(&response_text, "enrollPath")
.ok_or_else(|| "missing enrollPath in /Marti/api/tls/config response".to_string())?; .ok_or_else(|| "missing enrollPath in /Marti/api/tls/config response".to_string())?;
info!(
"Enrollment config received: server_port={} enroll_path={}",
server_port, enroll_path
);
Ok(EnrollmentConfig { Ok(EnrollmentConfig {
server_port, server_port,
enroll_path, enroll_path,
@@ -108,8 +112,12 @@ fn enroll_client_certificate(
username: &str, username: &str,
password: &str, password: &str,
client_uid: &str, client_uid: &str,
) -> Result<EnrollmentArtifacts, String> { ) -> Result<(String, String, String), String> {
let key_pair = KeyPair::generate_for(&PKCS_ECDSA_P256_SHA256) info!(
"Generating RSA client keypair and CSR for enrolled TAK client {}",
client_uid
);
let key_pair = KeyPair::generate_for(&PKCS_RSA_SHA256)
.map_err(|e| format!("failed to generate client keypair: {}", e))?; .map_err(|e| format!("failed to generate client keypair: {}", e))?;
let mut distinguished_name = DistinguishedName::new(); let mut distinguished_name = DistinguishedName::new();
@@ -133,6 +141,10 @@ fn enroll_client_certificate(
enroll_path.trim(), enroll_path.trim(),
client_uid.trim() client_uid.trim()
); );
info!(
"Submitting client certificate enrollment request for {} to {}",
client_uid, url
);
let response = enrollment_http_client()? let response = enrollment_http_client()?
.post(&url) .post(&url)
@@ -154,6 +166,12 @@ fn enroll_client_certificate(
let enrollment: EnrollmentResponse = response let enrollment: EnrollmentResponse = response
.json() .json()
.map_err(|e| format!("failed to parse enrollment response: {}", e))?; .map_err(|e| format!("failed to parse enrollment response: {}", e))?;
info!(
"Enrollment response parsed successfully for {} (signed_cert_len={}, ca_len={})",
client_uid,
enrollment.signed_cert.len(),
enrollment.ca0.len()
);
let cert_pem = wrap_pem_body( let cert_pem = wrap_pem_body(
&enrollment.signed_cert, &enrollment.signed_cert,
@@ -162,7 +180,7 @@ fn enroll_client_certificate(
); );
let key_pem = key_pair.serialize_pem(); let key_pem = key_pair.serialize_pem();
persist_enrollment_artifacts(client_uid, &enrollment.ca0, &cert_pem, &key_pem) Ok((enrollment.ca0, cert_pem, key_pem))
} }
pub fn enroll_and_connect( pub fn enroll_and_connect(
@@ -178,9 +196,16 @@ pub fn enroll_and_connect(
} else { } else {
client_uid.trim().to_string() client_uid.trim().to_string()
}; };
info!(
"Starting enroll_and_connect for host={} enroll_port={} server_name={} client_uid={}",
host,
enroll_port,
server_name,
normalized_client_uid
);
let enrollment_config = fetch_enrollment_config(host, enroll_port)?; let enrollment_config = fetch_enrollment_config(host, enroll_port)?;
let artifacts = enroll_client_certificate( let (ca_cert_pem, client_cert_pem, client_key_pem) = enroll_client_certificate(
host, host,
enroll_port, enroll_port,
&enrollment_config.enroll_path, &enrollment_config.enroll_path,
@@ -189,17 +214,15 @@ pub fn enroll_and_connect(
&normalized_client_uid, &normalized_client_uid,
)?; )?;
store_enrollment_artifacts(artifacts.clone()); connect_mtls_from_pem(
connect_mtls(
&format!("{}:{}", host.trim(), enrollment_config.server_port.trim()), &format!("{}:{}", host.trim(), enrollment_config.server_port.trim()),
if server_name.trim().is_empty() { if server_name.trim().is_empty() {
host.trim() host.trim()
} else { } else {
server_name.trim() server_name.trim()
}, },
&artifacts.ca_cert_path, &ca_cert_pem,
&artifacts.client_cert_path, &client_cert_pem,
&artifacts.client_key_path, &client_key_pem,
) )
} }

View File

@@ -1,4 +1,3 @@
pub mod artifacts;
mod connector; mod connector;
mod enrollment; mod enrollment;

View File

@@ -1,10 +1,15 @@
use log::info;
use rustls::{ClientConnection, StreamOwned}; use rustls::{ClientConnection, StreamOwned};
use std::io::Write; use std::io::Write;
use std::net::TcpStream; use std::net::{SocketAddr, TcpStream, ToSocketAddrs};
use std::time::Duration;
use super::config::ConnectionConfig; use super::config::ConnectionConfig;
use super::tls::{connect_mtls, enroll_and_connect}; use super::tls::{connect_mtls, enroll_and_connect};
const TCP_CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
const SOCKET_IO_TIMEOUT: Duration = Duration::from_secs(10);
pub enum TransportStream { pub enum TransportStream {
Plain(TcpStream), Plain(TcpStream),
Mtls(StreamOwned<ClientConnection, TcpStream>), Mtls(StreamOwned<ClientConnection, TcpStream>),
@@ -25,11 +30,34 @@ impl TransportStream {
} }
} }
fn connect_plain(address: &str) -> Result<TransportStream, String> {
let socket_addr: SocketAddr = address
.to_socket_addrs()
.map_err(|e| format!("failed to resolve {}: {}", address, e))?
.next()
.ok_or_else(|| format!("failed to resolve {}: no socket addresses returned", address))?;
info!(
"Opening plain TCP connection to {} (resolved={}) with timeout {:?}",
address, socket_addr, TCP_CONNECT_TIMEOUT
);
let stream = TcpStream::connect_timeout(&socket_addr, TCP_CONNECT_TIMEOUT)
.map_err(|e| format!("failed to connect to {}: {}", address, e))?;
stream
.set_read_timeout(Some(SOCKET_IO_TIMEOUT))
.map_err(|e| format!("failed to set read timeout on {}: {}", address, e))?;
stream
.set_write_timeout(Some(SOCKET_IO_TIMEOUT))
.map_err(|e| format!("failed to set write timeout on {}: {}", address, e))?;
Ok(TransportStream::Plain(stream))
}
pub fn connect_stream(config: &ConnectionConfig) -> Result<TransportStream, String> { pub fn connect_stream(config: &ConnectionConfig) -> Result<TransportStream, String> {
info!("connect_stream invoked for {}", config.describe());
match config { match config {
ConnectionConfig::Plain { address } => TcpStream::connect(address) ConnectionConfig::Plain { address } => connect_plain(address),
.map(TransportStream::Plain)
.map_err(|e| format!("failed to connect to {}: {}", address, e)),
ConnectionConfig::Mtls { ConnectionConfig::Mtls {
address, address,
server_name, server_name,