mirror of
https://github.com/valmojr/armatak.git
synced 2026-06-13 09:53:30 +00:00
Fixed CoT queue during armatak connection to the TAK Server, running soft as butter
This commit is contained in:
2
Cargo.lock
generated
2
Cargo.lock
generated
@@ -1053,8 +1053,8 @@ version = "0.13.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "75e669e5202259b5314d1ea5397316ad400819437857b90861765f24c4cf80a2"
|
||||
dependencies = [
|
||||
"aws-lc-rs",
|
||||
"pem",
|
||||
"ring",
|
||||
"rustls-pki-types",
|
||||
"time",
|
||||
"yasna",
|
||||
|
||||
@@ -10,7 +10,7 @@ lazy_static = "1.5.0"
|
||||
log = "0.4.22"
|
||||
log4rs = "1.3.0"
|
||||
reqwest = { version = "0.12.15", default-features = false, features = ["blocking", "json", "rustls-tls"] }
|
||||
rcgen = "0.13.2"
|
||||
rcgen = { version = "0.13.2", default-features = false, features = ["crypto", "pem", "aws_lc_rs"] }
|
||||
rustls = "0.23.23"
|
||||
rustls-pemfile = "2.2.0"
|
||||
serde = { version = "1.0.210", features = ["derive"] }
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use arma_rs::{arma, Extension, Group};
|
||||
use rustls::crypto::aws_lc_rs;
|
||||
mod structs;
|
||||
mod tcp;
|
||||
mod tests;
|
||||
@@ -31,6 +32,9 @@ pub fn init() -> Extension {
|
||||
|
||||
log4rs::init_config(config).unwrap();
|
||||
|
||||
let _ = aws_lc_rs::default_provider().install_default();
|
||||
log::info!("Initialized rustls aws-lc crypto provider.");
|
||||
|
||||
Extension::build()
|
||||
.command("local_ip", utils::address::get_local_address)
|
||||
.command("uuid", utils::uuid::get_uuid)
|
||||
|
||||
@@ -1,13 +1,18 @@
|
||||
use arma_rs::Context;
|
||||
use log::info;
|
||||
use std::sync::mpsc::{Receiver, Sender};
|
||||
use std::sync::{Arc, Mutex};
|
||||
use log::{info, warn};
|
||||
use std::collections::VecDeque;
|
||||
use std::panic::{self, AssertUnwindSafe};
|
||||
use std::sync::mpsc::{self, Receiver, RecvTimeoutError, Sender, TryRecvError};
|
||||
use std::thread;
|
||||
use std::time::Duration;
|
||||
|
||||
use super::config::ConnectionConfig;
|
||||
use super::transport::{connect_stream, TransportStream};
|
||||
use super::TCP_CLIENT;
|
||||
|
||||
const CONNECT_POLL_INTERVAL: Duration = Duration::from_millis(200);
|
||||
const MAX_PENDING_MESSAGES: usize = 128;
|
||||
|
||||
pub enum TcpCommand {
|
||||
SendMessage(String, Context),
|
||||
Stop,
|
||||
@@ -17,80 +22,259 @@ pub struct TcpClient {
|
||||
pub(crate) tx: Sender<TcpCommand>,
|
||||
}
|
||||
|
||||
enum ConnectionState {
|
||||
Connecting,
|
||||
Connected,
|
||||
Failed(String),
|
||||
}
|
||||
|
||||
enum ConnectEvent {
|
||||
Connected(TransportStream),
|
||||
Failed(String),
|
||||
}
|
||||
|
||||
fn describe_panic_payload(payload: Box<dyn std::any::Any + Send>) -> String {
|
||||
if let Some(message) = payload.downcast_ref::<&str>() {
|
||||
(*message).to_string()
|
||||
} else if let Some(message) = payload.downcast_ref::<String>() {
|
||||
message.clone()
|
||||
} else {
|
||||
"unknown panic payload".to_string()
|
||||
}
|
||||
}
|
||||
|
||||
fn log_message_preview(message: &str) -> String {
|
||||
message.chars().take(96).collect::<String>()
|
||||
}
|
||||
|
||||
fn send_over_stream(
|
||||
stream: &mut TransportStream,
|
||||
context: &Context,
|
||||
message: String,
|
||||
) -> Result<(), String> {
|
||||
let message_len = message.len();
|
||||
info!("Sending TCP payload ({} bytes)", message_len);
|
||||
stream
|
||||
.write_message(message.as_bytes())
|
||||
.map_err(|e| {
|
||||
let message = e.to_string();
|
||||
let _ = context.callback_data(
|
||||
"TCP SOCKET ERROR",
|
||||
"TAK Socket disconnected",
|
||||
message.clone(),
|
||||
);
|
||||
message
|
||||
})
|
||||
}
|
||||
|
||||
fn flush_pending_messages(
|
||||
connection: &mut Option<TransportStream>,
|
||||
pending_messages: &mut VecDeque<(String, Context)>,
|
||||
state: &mut ConnectionState,
|
||||
) {
|
||||
if pending_messages.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let Some(stream) = connection.as_mut() else {
|
||||
return;
|
||||
};
|
||||
|
||||
info!(
|
||||
"Flushing {} queued TCP payload(s) after connection became active",
|
||||
pending_messages.len()
|
||||
);
|
||||
|
||||
while let Some((message, context)) = pending_messages.pop_front() {
|
||||
if let Err(error) = send_over_stream(stream, &context, message) {
|
||||
info!("Failed to send queued message: {}", error);
|
||||
*state = ConnectionState::Failed(error);
|
||||
*connection = None;
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_connect_event(
|
||||
connect_rx: &Receiver<ConnectEvent>,
|
||||
connection: &mut Option<TransportStream>,
|
||||
state: &mut ConnectionState,
|
||||
pending_messages: &mut VecDeque<(String, Context)>,
|
||||
ctx: &Context,
|
||||
connection_message: &str,
|
||||
target: &str,
|
||||
) {
|
||||
loop {
|
||||
match connect_rx.try_recv() {
|
||||
Ok(ConnectEvent::Connected(stream)) => {
|
||||
info!("TCP connection established successfully: {}", target);
|
||||
let _ = ctx.callback_data("TCP SOCKET", connection_message, target.to_string());
|
||||
*connection = Some(stream);
|
||||
*state = ConnectionState::Connected;
|
||||
flush_pending_messages(connection, pending_messages, state);
|
||||
}
|
||||
Ok(ConnectEvent::Failed(error)) => {
|
||||
info!("Failed to connect to TCP server: {}", error);
|
||||
let _ = ctx.callback_data(
|
||||
"TCP SOCKET ERROR",
|
||||
"TAK Socket connection failed",
|
||||
error.clone(),
|
||||
);
|
||||
*state = ConnectionState::Failed(error);
|
||||
}
|
||||
Err(TryRecvError::Empty) => break,
|
||||
Err(TryRecvError::Disconnected) => break,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TcpClient {
|
||||
pub fn start(&self, config: ConnectionConfig, rx: Receiver<TcpCommand>, ctx: Context) {
|
||||
if let Some(ref client) = *TCP_CLIENT.lock().unwrap() {
|
||||
info!("Existing TCP client detected; stopping previous instance before restart.");
|
||||
client.stop();
|
||||
}
|
||||
|
||||
let connection = Arc::new(Mutex::new(None::<TransportStream>));
|
||||
let connection_clone = Arc::clone(&connection);
|
||||
|
||||
thread::spawn(move || {
|
||||
let mut running = true;
|
||||
let connection_message = config.connected_message();
|
||||
let config_description = config.describe();
|
||||
let target = config.target();
|
||||
let mut state = ConnectionState::Connecting;
|
||||
let mut connection: Option<TransportStream> = None;
|
||||
let mut pending_messages: VecDeque<(String, Context)> = VecDeque::new();
|
||||
let (connect_tx, connect_rx) = mpsc::channel();
|
||||
|
||||
let tcp_thread = thread::spawn(move || match connect_stream(&config) {
|
||||
Ok(stream) => {
|
||||
let target = config.target();
|
||||
let _ = ctx.callback_data("TCP SOCKET", connection_message, target);
|
||||
*connection_clone.lock().unwrap() = Some(stream);
|
||||
}
|
||||
Err(e) => {
|
||||
let _ = ctx.callback_data(
|
||||
"TCP SOCKET ERROR",
|
||||
"TAK Socket connection failed",
|
||||
e.to_string(),
|
||||
);
|
||||
info!("Failed to connect to TCP server: {}", e);
|
||||
info!("TCP worker thread started with config: {}", config_description);
|
||||
|
||||
let tcp_thread = thread::spawn(move || {
|
||||
let connect_result = panic::catch_unwind(AssertUnwindSafe(|| connect_stream(&config)));
|
||||
|
||||
match connect_result {
|
||||
Ok(Ok(stream)) => {
|
||||
let _ = connect_tx.send(ConnectEvent::Connected(stream));
|
||||
}
|
||||
Ok(Err(error)) => {
|
||||
let _ = connect_tx.send(ConnectEvent::Failed(error));
|
||||
}
|
||||
Err(payload) => {
|
||||
let message = format!(
|
||||
"TCP connection worker panicked: {}",
|
||||
describe_panic_payload(payload)
|
||||
);
|
||||
let _ = connect_tx.send(ConnectEvent::Failed(message));
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
while running {
|
||||
match rx.recv() {
|
||||
Ok(TcpCommand::SendMessage(message, context)) => {
|
||||
if let Some(stream) = connection.lock().unwrap().as_mut() {
|
||||
if let Err(e) = stream.write_message(message.as_bytes()) {
|
||||
info!("Failed to send message: {}", e);
|
||||
poll_connect_event(
|
||||
&connect_rx,
|
||||
&mut connection,
|
||||
&mut state,
|
||||
&mut pending_messages,
|
||||
&ctx,
|
||||
connection_message,
|
||||
&target,
|
||||
);
|
||||
|
||||
match rx.recv_timeout(CONNECT_POLL_INTERVAL) {
|
||||
Ok(TcpCommand::SendMessage(message, context)) => {
|
||||
let message_len = message.len();
|
||||
match &mut state {
|
||||
ConnectionState::Connected => {
|
||||
if let Some(stream) = connection.as_mut() {
|
||||
if let Err(error) = send_over_stream(stream, &context, message) {
|
||||
info!("Failed to send message: {}", error);
|
||||
state = ConnectionState::Failed(error);
|
||||
connection = None;
|
||||
}
|
||||
} else {
|
||||
warn!(
|
||||
"Connection state said connected, but no socket was present; queuing payload."
|
||||
);
|
||||
pending_messages.push_back((message, context));
|
||||
}
|
||||
}
|
||||
ConnectionState::Connecting => {
|
||||
if pending_messages.len() >= MAX_PENDING_MESSAGES {
|
||||
let preview = log_message_preview(&message);
|
||||
warn!(
|
||||
"Dropping TCP payload because connection is still pending and queue is full ({} bytes, preview={:?})",
|
||||
message_len, preview
|
||||
);
|
||||
let _ = context.callback_data(
|
||||
"TCP SOCKET ERROR",
|
||||
"TAK Socket is still connecting",
|
||||
format!(
|
||||
"queue full while connecting; dropped payload ({} bytes, preview={:?})",
|
||||
message_len, preview
|
||||
),
|
||||
);
|
||||
} else {
|
||||
info!(
|
||||
"Queueing TCP payload while connection is pending ({} bytes, queued={})",
|
||||
message_len,
|
||||
pending_messages.len() + 1
|
||||
);
|
||||
pending_messages.push_back((message, context));
|
||||
}
|
||||
}
|
||||
ConnectionState::Failed(error) => {
|
||||
let preview = log_message_preview(&message);
|
||||
warn!(
|
||||
"Dropping TCP payload because connection is in failed state ({} bytes, preview={:?}, error={})",
|
||||
message_len, preview, error
|
||||
);
|
||||
let _ = context.callback_data(
|
||||
"TCP SOCKET ERROR",
|
||||
"TAK Socket disconnected",
|
||||
e.to_string(),
|
||||
"TAK Socket is not connected",
|
||||
error.clone(),
|
||||
);
|
||||
|
||||
running = false;
|
||||
}
|
||||
} else {
|
||||
let _ = context
|
||||
.callback_null("TCP SOCKET ERROR", "TAK Socket is not active");
|
||||
}
|
||||
}
|
||||
Ok(TcpCommand::Stop) => {
|
||||
running = false;
|
||||
info!("Stopping TCP client.");
|
||||
}
|
||||
Err(error) => {
|
||||
info!("Error receiving command: {}", error);
|
||||
Err(RecvTimeoutError::Timeout) => {}
|
||||
Err(RecvTimeoutError::Disconnected) => {
|
||||
warn!("TCP command channel disconnected.");
|
||||
running = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tcp_thread.join().unwrap();
|
||||
info!("Waiting for TCP connection thread to finish.");
|
||||
match tcp_thread.join() {
|
||||
Ok(()) => info!("TCP connection thread joined successfully."),
|
||||
Err(payload) => warn!(
|
||||
"TCP connection thread join reported a panic: {}",
|
||||
describe_panic_payload(payload)
|
||||
),
|
||||
}
|
||||
info!("TCP worker thread finished.");
|
||||
});
|
||||
}
|
||||
|
||||
pub fn send_payload(&self, context: Context, payload: String) {
|
||||
let tx = self.tx.clone();
|
||||
thread::spawn(move || {
|
||||
tx.send(TcpCommand::SendMessage(payload, context)).unwrap();
|
||||
info!("Dispatching queued TCP payload command.");
|
||||
if let Err(error) = tx.send(TcpCommand::SendMessage(payload, context)) {
|
||||
warn!("Failed to dispatch TCP payload command: {}", error);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
pub fn stop(&self) {
|
||||
let tx = self.tx.clone();
|
||||
thread::spawn(move || {
|
||||
tx.send(TcpCommand::Stop).unwrap();
|
||||
info!("Dispatching TCP stop command.");
|
||||
if let Err(error) = tx.send(TcpCommand::Stop) {
|
||||
warn!("Failed to dispatch TCP stop command: {}", error);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -34,4 +34,31 @@ impl ConnectionConfig {
|
||||
Self::EnrollMtls { host, .. } => host.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn describe(&self) -> String {
|
||||
match self {
|
||||
Self::Plain { address } => format!("plain tcp -> {}", address),
|
||||
Self::Mtls {
|
||||
address,
|
||||
server_name,
|
||||
ca_cert_path,
|
||||
client_cert_path,
|
||||
client_key_path,
|
||||
} => format!(
|
||||
"manual mtls -> {} (server_name={}, ca={}, cert={}, key={})",
|
||||
address, server_name, ca_cert_path, client_cert_path, client_key_path
|
||||
),
|
||||
Self::EnrollMtls {
|
||||
host,
|
||||
server_name,
|
||||
enroll_port,
|
||||
username,
|
||||
client_uid,
|
||||
..
|
||||
} => format!(
|
||||
"enroll mtls -> host={} enroll_port={} server_name={} username={} client_uid={}",
|
||||
host, enroll_port, server_name, username, client_uid
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,13 +14,13 @@ pub mod draw;
|
||||
|
||||
use client::{TcpClient, TcpCommand};
|
||||
use config::ConnectionConfig;
|
||||
use tls::artifacts::clear_enrollment_artifacts;
|
||||
|
||||
lazy_static! {
|
||||
static ref TCP_CLIENT: Arc<Mutex<Option<TcpClient>>> = Arc::new(Mutex::new(None));
|
||||
}
|
||||
|
||||
fn start_with_config(ctx: Context, config: ConnectionConfig) {
|
||||
info!("Starting TCP client with config: {}", config.describe());
|
||||
let (tx, rx): (Sender<TcpCommand>, Receiver<TcpCommand>) = mpsc::channel();
|
||||
|
||||
let client = TcpClient { tx };
|
||||
@@ -67,7 +67,6 @@ pub fn start_enroll_mtls(
|
||||
password: String,
|
||||
client_uid: String,
|
||||
) -> &'static str {
|
||||
clear_enrollment_artifacts();
|
||||
start_with_config(
|
||||
ctx,
|
||||
ConnectionConfig::EnrollMtls {
|
||||
@@ -85,6 +84,7 @@ pub fn start_enroll_mtls(
|
||||
|
||||
pub fn send_payload(ctx: Context, payload: String) -> &'static str {
|
||||
if let Some(ref client) = *TCP_CLIENT.lock().unwrap() {
|
||||
info!("Queueing TCP payload ({} bytes)", payload.len());
|
||||
client.send_payload(ctx, payload);
|
||||
} else {
|
||||
let _ = ctx.callback_null("TCP SOCKET ERROR", "TCP Client is not running");
|
||||
@@ -96,9 +96,9 @@ pub fn send_payload(ctx: Context, payload: String) -> &'static str {
|
||||
|
||||
pub fn stop(ctx: Context) -> &'static str {
|
||||
if let Some(ref client) = *TCP_CLIENT.lock().unwrap() {
|
||||
info!("Stopping TCP client via extension command.");
|
||||
client.stop();
|
||||
let _ = ctx.callback_null("TCP SOCKET", "TCP client stopped");
|
||||
clear_enrollment_artifacts();
|
||||
} else {
|
||||
let _ = ctx.callback_null("TCP SOCKET ERROR", "TCP client is not running");
|
||||
}
|
||||
|
||||
@@ -1,101 +0,0 @@
|
||||
use lazy_static::lazy_static;
|
||||
use std::env;
|
||||
use std::fs::{self, create_dir_all};
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Mutex;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct EnrollmentArtifacts {
|
||||
pub ca_cert_path: String,
|
||||
pub client_cert_path: String,
|
||||
pub client_key_path: String,
|
||||
}
|
||||
|
||||
lazy_static! {
|
||||
static ref ENROLLMENT_ARTIFACTS: Mutex<Option<EnrollmentArtifacts>> = Mutex::new(None);
|
||||
}
|
||||
|
||||
fn current_artifacts_dir() -> Result<PathBuf, String> {
|
||||
let mut path = env::current_dir().map_err(|e| format!("failed to resolve cwd: {}", e))?;
|
||||
path.push(".armatak");
|
||||
path.push("session-certs");
|
||||
create_dir_all(&path)
|
||||
.map_err(|e| format!("failed to create cert dir {}: {}", path.display(), e))?;
|
||||
Ok(path)
|
||||
}
|
||||
|
||||
pub fn persist_enrollment_artifacts(
|
||||
client_uid: &str,
|
||||
ca_pem: &str,
|
||||
cert_pem: &str,
|
||||
key_pem: &str,
|
||||
) -> Result<EnrollmentArtifacts, String> {
|
||||
let mut base_dir = current_artifacts_dir()?;
|
||||
let safe_uid = client_uid
|
||||
.chars()
|
||||
.map(|ch| {
|
||||
if ch.is_ascii_alphanumeric() || ch == '-' || ch == '_' {
|
||||
ch
|
||||
} else {
|
||||
'_'
|
||||
}
|
||||
})
|
||||
.collect::<String>();
|
||||
|
||||
base_dir.push(safe_uid);
|
||||
create_dir_all(&base_dir).map_err(|e| {
|
||||
format!(
|
||||
"failed to create session cert dir {}: {}",
|
||||
base_dir.display(),
|
||||
e
|
||||
)
|
||||
})?;
|
||||
|
||||
let ca_cert_path = base_dir.join("ca.pem");
|
||||
let client_cert_path = base_dir.join("client.pem");
|
||||
let client_key_path = base_dir.join("client.key");
|
||||
|
||||
fs::write(&ca_cert_path, ca_pem).map_err(|e| {
|
||||
format!(
|
||||
"failed to persist CA cert {}: {}",
|
||||
ca_cert_path.display(),
|
||||
e
|
||||
)
|
||||
})?;
|
||||
fs::write(&client_cert_path, cert_pem).map_err(|e| {
|
||||
format!(
|
||||
"failed to persist client cert {}: {}",
|
||||
client_cert_path.display(),
|
||||
e
|
||||
)
|
||||
})?;
|
||||
fs::write(&client_key_path, key_pem).map_err(|e| {
|
||||
format!(
|
||||
"failed to persist client key {}: {}",
|
||||
client_key_path.display(),
|
||||
e
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(EnrollmentArtifacts {
|
||||
ca_cert_path: ca_cert_path.to_string_lossy().to_string(),
|
||||
client_cert_path: client_cert_path.to_string_lossy().to_string(),
|
||||
client_key_path: client_key_path.to_string_lossy().to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn store_enrollment_artifacts(artifacts: EnrollmentArtifacts) {
|
||||
*ENROLLMENT_ARTIFACTS.lock().unwrap() = Some(artifacts);
|
||||
}
|
||||
|
||||
pub fn clear_enrollment_artifacts() {
|
||||
if let Some(artifacts) = ENROLLMENT_ARTIFACTS.lock().unwrap().take() {
|
||||
for path in [
|
||||
artifacts.ca_cert_path,
|
||||
artifacts.client_cert_path,
|
||||
artifacts.client_key_path,
|
||||
] {
|
||||
let _ = fs::remove_file(path);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,13 +1,19 @@
|
||||
use log::info;
|
||||
use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName};
|
||||
use rustls::{ClientConfig, ClientConnection, RootCertStore, StreamOwned};
|
||||
use rustls_pemfile::{certs, private_key};
|
||||
use std::fs::File;
|
||||
use std::io::BufReader;
|
||||
use std::net::TcpStream;
|
||||
use std::io::Cursor;
|
||||
use std::net::{SocketAddr, TcpStream, ToSocketAddrs};
|
||||
use std::time::Duration;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::tcp::transport::TransportStream;
|
||||
|
||||
const TCP_CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
|
||||
const SOCKET_IO_TIMEOUT: Duration = Duration::from_secs(10);
|
||||
|
||||
fn load_certificates(path: &str) -> Result<Vec<CertificateDer<'static>>, String> {
|
||||
let file = File::open(path).map_err(|e| format!("failed to open cert file {}: {}", path, e))?;
|
||||
let mut reader = BufReader::new(file);
|
||||
@@ -26,6 +32,22 @@ fn load_private_key(path: &str) -> Result<PrivateKeyDer<'static>, String> {
|
||||
.ok_or_else(|| format!("no supported private key found in {}", path))
|
||||
}
|
||||
|
||||
fn load_certificates_from_pem(pem: &str) -> Result<Vec<CertificateDer<'static>>, String> {
|
||||
let mut reader = Cursor::new(pem.as_bytes());
|
||||
|
||||
certs(&mut reader)
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.map_err(|e| format!("failed to read certs from PEM payload: {}", e))
|
||||
}
|
||||
|
||||
fn load_private_key_from_pem(pem: &str) -> Result<PrivateKeyDer<'static>, String> {
|
||||
let mut reader = Cursor::new(pem.as_bytes());
|
||||
|
||||
private_key(&mut reader)
|
||||
.map_err(|e| format!("failed to read private key from PEM payload: {}", e))?
|
||||
.ok_or_else(|| "no supported private key found in PEM payload".to_string())
|
||||
}
|
||||
|
||||
fn infer_server_name(address: &str) -> &str {
|
||||
address
|
||||
.trim()
|
||||
@@ -38,6 +60,34 @@ fn infer_server_name(address: &str) -> &str {
|
||||
.unwrap_or(address)
|
||||
}
|
||||
|
||||
fn resolve_address(address: &str) -> Result<SocketAddr, String> {
|
||||
address
|
||||
.to_socket_addrs()
|
||||
.map_err(|e| format!("failed to resolve {}: {}", address, e))?
|
||||
.next()
|
||||
.ok_or_else(|| format!("failed to resolve {}: no socket addresses returned", address))
|
||||
}
|
||||
|
||||
fn connect_tcp(address: &str) -> Result<TcpStream, String> {
|
||||
let socket_addr = resolve_address(address)?;
|
||||
info!(
|
||||
"Opening TCP connection to {} (resolved={}) with timeout {:?}",
|
||||
address, socket_addr, TCP_CONNECT_TIMEOUT
|
||||
);
|
||||
|
||||
let tcp_stream = TcpStream::connect_timeout(&socket_addr, TCP_CONNECT_TIMEOUT)
|
||||
.map_err(|e| format!("failed to connect to {}: {}", address, e))?;
|
||||
|
||||
tcp_stream
|
||||
.set_read_timeout(Some(SOCKET_IO_TIMEOUT))
|
||||
.map_err(|e| format!("failed to set read timeout on {}: {}", address, e))?;
|
||||
tcp_stream
|
||||
.set_write_timeout(Some(SOCKET_IO_TIMEOUT))
|
||||
.map_err(|e| format!("failed to set write timeout on {}: {}", address, e))?;
|
||||
|
||||
Ok(tcp_stream)
|
||||
}
|
||||
|
||||
pub fn connect_mtls(
|
||||
address: &str,
|
||||
server_name: &str,
|
||||
@@ -45,23 +95,39 @@ pub fn connect_mtls(
|
||||
client_cert_path: &str,
|
||||
client_key_path: &str,
|
||||
) -> Result<TransportStream, String> {
|
||||
info!(
|
||||
"Connecting mTLS from file paths to {} using server_name={}",
|
||||
address, server_name
|
||||
);
|
||||
let mut root_store = RootCertStore::empty();
|
||||
for certificate in load_certificates(ca_cert_path)? {
|
||||
let ca_certificates = load_certificates(ca_cert_path)?;
|
||||
info!(
|
||||
"Loaded {} CA certificate(s) from {}",
|
||||
ca_certificates.len(),
|
||||
ca_cert_path
|
||||
);
|
||||
for certificate in ca_certificates {
|
||||
root_store
|
||||
.add(certificate)
|
||||
.map_err(|e| format!("failed to add CA certificate from {}: {}", ca_cert_path, e))?;
|
||||
}
|
||||
|
||||
let client_certificates = load_certificates(client_cert_path)?;
|
||||
info!(
|
||||
"Loaded {} client certificate(s) from {}",
|
||||
client_certificates.len(),
|
||||
client_cert_path
|
||||
);
|
||||
let client_key = load_private_key(client_key_path)?;
|
||||
info!("Loaded client private key from {}", client_key_path);
|
||||
|
||||
let tls_config = ClientConfig::builder()
|
||||
.with_root_certificates(root_store)
|
||||
.with_client_auth_cert(client_certificates, client_key)
|
||||
.map_err(|e| format!("failed to configure mTLS client: {}", e))?;
|
||||
info!("Constructed rustls client config for {}", address);
|
||||
|
||||
let tcp_stream = TcpStream::connect(address)
|
||||
.map_err(|e| format!("failed to connect to {}: {}", address, e))?;
|
||||
let tcp_stream = connect_tcp(address)?;
|
||||
let resolved_server_name = if server_name.trim().is_empty() {
|
||||
infer_server_name(address).to_string()
|
||||
} else {
|
||||
@@ -76,6 +142,7 @@ pub fn connect_mtls(
|
||||
tcp_stream,
|
||||
);
|
||||
|
||||
info!("Starting mTLS handshake for {}", address);
|
||||
while tls_stream.conn.is_handshaking() {
|
||||
tls_stream
|
||||
.conn
|
||||
@@ -83,5 +150,72 @@ pub fn connect_mtls(
|
||||
.map_err(|e| format!("TLS handshake failed: {}", e))?;
|
||||
}
|
||||
|
||||
info!("mTLS handshake completed successfully for {}", address);
|
||||
|
||||
Ok(TransportStream::Mtls(tls_stream))
|
||||
}
|
||||
|
||||
pub fn connect_mtls_from_pem(
|
||||
address: &str,
|
||||
server_name: &str,
|
||||
ca_cert_pem: &str,
|
||||
client_cert_pem: &str,
|
||||
client_key_pem: &str,
|
||||
) -> Result<TransportStream, String> {
|
||||
info!(
|
||||
"Connecting mTLS from in-memory PEM payloads to {} using server_name={}",
|
||||
address, server_name
|
||||
);
|
||||
let mut root_store = RootCertStore::empty();
|
||||
let ca_certificates = load_certificates_from_pem(ca_cert_pem)?;
|
||||
info!(
|
||||
"Loaded {} CA certificate(s) from enrollment payload",
|
||||
ca_certificates.len()
|
||||
);
|
||||
for certificate in ca_certificates {
|
||||
root_store
|
||||
.add(certificate)
|
||||
.map_err(|e| format!("failed to add CA certificate from PEM payload: {}", e))?;
|
||||
}
|
||||
|
||||
let client_certificates = load_certificates_from_pem(client_cert_pem)?;
|
||||
info!(
|
||||
"Loaded {} client certificate(s) from enrollment payload",
|
||||
client_certificates.len()
|
||||
);
|
||||
let client_key = load_private_key_from_pem(client_key_pem)?;
|
||||
info!("Loaded client private key from enrollment payload");
|
||||
|
||||
let tls_config = ClientConfig::builder()
|
||||
.with_root_certificates(root_store)
|
||||
.with_client_auth_cert(client_certificates, client_key)
|
||||
.map_err(|e| format!("failed to configure mTLS client: {}", e))?;
|
||||
info!("Constructed rustls client config for {}", address);
|
||||
|
||||
let tcp_stream = connect_tcp(address)?;
|
||||
let resolved_server_name = if server_name.trim().is_empty() {
|
||||
infer_server_name(address).to_string()
|
||||
} else {
|
||||
server_name.trim().to_string()
|
||||
};
|
||||
|
||||
let server_name = ServerName::try_from(resolved_server_name.clone())
|
||||
.map_err(|_| format!("invalid TLS server name: {}", resolved_server_name))?;
|
||||
let mut tls_stream = StreamOwned::new(
|
||||
ClientConnection::new(Arc::new(tls_config), server_name)
|
||||
.map_err(|e| format!("failed to create TLS client: {}", e))?,
|
||||
tcp_stream,
|
||||
);
|
||||
|
||||
info!("Starting mTLS handshake for {}", address);
|
||||
while tls_stream.conn.is_handshaking() {
|
||||
tls_stream
|
||||
.conn
|
||||
.complete_io(&mut tls_stream.sock)
|
||||
.map_err(|e| format!("TLS handshake failed: {}", e))?;
|
||||
}
|
||||
|
||||
info!("mTLS handshake completed successfully for {}", address);
|
||||
|
||||
Ok(TransportStream::Mtls(tls_stream))
|
||||
}
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
use rcgen::{CertificateParams, DistinguishedName, DnType, KeyPair, PKCS_ECDSA_P256_SHA256};
|
||||
use rcgen::{CertificateParams, DistinguishedName, DnType, KeyPair, PKCS_RSA_SHA256};
|
||||
use log::info;
|
||||
use reqwest::blocking::Client;
|
||||
use serde::Deserialize;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::artifacts::{
|
||||
persist_enrollment_artifacts, store_enrollment_artifacts, EnrollmentArtifacts,
|
||||
};
|
||||
use super::connector::connect_mtls;
|
||||
use super::connector::connect_mtls_from_pem;
|
||||
use crate::tcp::transport::TransportStream;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
@@ -72,6 +70,7 @@ fn fetch_enrollment_config(host: &str, enroll_port: &str) -> Result<EnrollmentCo
|
||||
host.trim(),
|
||||
enroll_port.trim()
|
||||
);
|
||||
info!("Fetching TAK enrollment config from {}", url);
|
||||
|
||||
let response = enrollment_http_client()?
|
||||
.get(&url)
|
||||
@@ -95,6 +94,11 @@ fn fetch_enrollment_config(host: &str, enroll_port: &str) -> Result<EnrollmentCo
|
||||
let enroll_path = extract_tag_value(&response_text, "enrollPath")
|
||||
.ok_or_else(|| "missing enrollPath in /Marti/api/tls/config response".to_string())?;
|
||||
|
||||
info!(
|
||||
"Enrollment config received: server_port={} enroll_path={}",
|
||||
server_port, enroll_path
|
||||
);
|
||||
|
||||
Ok(EnrollmentConfig {
|
||||
server_port,
|
||||
enroll_path,
|
||||
@@ -108,8 +112,12 @@ fn enroll_client_certificate(
|
||||
username: &str,
|
||||
password: &str,
|
||||
client_uid: &str,
|
||||
) -> Result<EnrollmentArtifacts, String> {
|
||||
let key_pair = KeyPair::generate_for(&PKCS_ECDSA_P256_SHA256)
|
||||
) -> Result<(String, String, String), String> {
|
||||
info!(
|
||||
"Generating RSA client keypair and CSR for enrolled TAK client {}",
|
||||
client_uid
|
||||
);
|
||||
let key_pair = KeyPair::generate_for(&PKCS_RSA_SHA256)
|
||||
.map_err(|e| format!("failed to generate client keypair: {}", e))?;
|
||||
|
||||
let mut distinguished_name = DistinguishedName::new();
|
||||
@@ -133,6 +141,10 @@ fn enroll_client_certificate(
|
||||
enroll_path.trim(),
|
||||
client_uid.trim()
|
||||
);
|
||||
info!(
|
||||
"Submitting client certificate enrollment request for {} to {}",
|
||||
client_uid, url
|
||||
);
|
||||
|
||||
let response = enrollment_http_client()?
|
||||
.post(&url)
|
||||
@@ -154,6 +166,12 @@ fn enroll_client_certificate(
|
||||
let enrollment: EnrollmentResponse = response
|
||||
.json()
|
||||
.map_err(|e| format!("failed to parse enrollment response: {}", e))?;
|
||||
info!(
|
||||
"Enrollment response parsed successfully for {} (signed_cert_len={}, ca_len={})",
|
||||
client_uid,
|
||||
enrollment.signed_cert.len(),
|
||||
enrollment.ca0.len()
|
||||
);
|
||||
|
||||
let cert_pem = wrap_pem_body(
|
||||
&enrollment.signed_cert,
|
||||
@@ -162,7 +180,7 @@ fn enroll_client_certificate(
|
||||
);
|
||||
let key_pem = key_pair.serialize_pem();
|
||||
|
||||
persist_enrollment_artifacts(client_uid, &enrollment.ca0, &cert_pem, &key_pem)
|
||||
Ok((enrollment.ca0, cert_pem, key_pem))
|
||||
}
|
||||
|
||||
pub fn enroll_and_connect(
|
||||
@@ -178,9 +196,16 @@ pub fn enroll_and_connect(
|
||||
} else {
|
||||
client_uid.trim().to_string()
|
||||
};
|
||||
info!(
|
||||
"Starting enroll_and_connect for host={} enroll_port={} server_name={} client_uid={}",
|
||||
host,
|
||||
enroll_port,
|
||||
server_name,
|
||||
normalized_client_uid
|
||||
);
|
||||
|
||||
let enrollment_config = fetch_enrollment_config(host, enroll_port)?;
|
||||
let artifacts = enroll_client_certificate(
|
||||
let (ca_cert_pem, client_cert_pem, client_key_pem) = enroll_client_certificate(
|
||||
host,
|
||||
enroll_port,
|
||||
&enrollment_config.enroll_path,
|
||||
@@ -189,17 +214,15 @@ pub fn enroll_and_connect(
|
||||
&normalized_client_uid,
|
||||
)?;
|
||||
|
||||
store_enrollment_artifacts(artifacts.clone());
|
||||
|
||||
connect_mtls(
|
||||
connect_mtls_from_pem(
|
||||
&format!("{}:{}", host.trim(), enrollment_config.server_port.trim()),
|
||||
if server_name.trim().is_empty() {
|
||||
host.trim()
|
||||
} else {
|
||||
server_name.trim()
|
||||
},
|
||||
&artifacts.ca_cert_path,
|
||||
&artifacts.client_cert_path,
|
||||
&artifacts.client_key_path,
|
||||
&ca_cert_pem,
|
||||
&client_cert_pem,
|
||||
&client_key_pem,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
pub mod artifacts;
|
||||
mod connector;
|
||||
mod enrollment;
|
||||
|
||||
|
||||
@@ -1,10 +1,15 @@
|
||||
use log::info;
|
||||
use rustls::{ClientConnection, StreamOwned};
|
||||
use std::io::Write;
|
||||
use std::net::TcpStream;
|
||||
use std::net::{SocketAddr, TcpStream, ToSocketAddrs};
|
||||
use std::time::Duration;
|
||||
|
||||
use super::config::ConnectionConfig;
|
||||
use super::tls::{connect_mtls, enroll_and_connect};
|
||||
|
||||
const TCP_CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
|
||||
const SOCKET_IO_TIMEOUT: Duration = Duration::from_secs(10);
|
||||
|
||||
pub enum TransportStream {
|
||||
Plain(TcpStream),
|
||||
Mtls(StreamOwned<ClientConnection, TcpStream>),
|
||||
@@ -25,11 +30,34 @@ impl TransportStream {
|
||||
}
|
||||
}
|
||||
|
||||
fn connect_plain(address: &str) -> Result<TransportStream, String> {
|
||||
let socket_addr: SocketAddr = address
|
||||
.to_socket_addrs()
|
||||
.map_err(|e| format!("failed to resolve {}: {}", address, e))?
|
||||
.next()
|
||||
.ok_or_else(|| format!("failed to resolve {}: no socket addresses returned", address))?;
|
||||
|
||||
info!(
|
||||
"Opening plain TCP connection to {} (resolved={}) with timeout {:?}",
|
||||
address, socket_addr, TCP_CONNECT_TIMEOUT
|
||||
);
|
||||
|
||||
let stream = TcpStream::connect_timeout(&socket_addr, TCP_CONNECT_TIMEOUT)
|
||||
.map_err(|e| format!("failed to connect to {}: {}", address, e))?;
|
||||
stream
|
||||
.set_read_timeout(Some(SOCKET_IO_TIMEOUT))
|
||||
.map_err(|e| format!("failed to set read timeout on {}: {}", address, e))?;
|
||||
stream
|
||||
.set_write_timeout(Some(SOCKET_IO_TIMEOUT))
|
||||
.map_err(|e| format!("failed to set write timeout on {}: {}", address, e))?;
|
||||
|
||||
Ok(TransportStream::Plain(stream))
|
||||
}
|
||||
|
||||
pub fn connect_stream(config: &ConnectionConfig) -> Result<TransportStream, String> {
|
||||
info!("connect_stream invoked for {}", config.describe());
|
||||
match config {
|
||||
ConnectionConfig::Plain { address } => TcpStream::connect(address)
|
||||
.map(TransportStream::Plain)
|
||||
.map_err(|e| format!("failed to connect to {}: {}", address, e)),
|
||||
ConnectionConfig::Plain { address } => connect_plain(address),
|
||||
ConnectionConfig::Mtls {
|
||||
address,
|
||||
server_name,
|
||||
|
||||
Reference in New Issue
Block a user