diff --git a/src/lib.rs b/src/lib.rs index 61aca55..26c470e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -47,6 +47,8 @@ pub fn init() -> Extension { "tcp_socket", Group::new() .command("start", tcp::start) + .command("start_mtls", tcp::start_mtls) + .command("start_enroll_mtls", tcp::start_enroll_mtls) .command("stop", tcp::stop) .command("send_payload", tcp::send_payload) .group( diff --git a/src/tcp/client.rs b/src/tcp/client.rs new file mode 100644 index 0000000..18ba1e5 --- /dev/null +++ b/src/tcp/client.rs @@ -0,0 +1,96 @@ +use arma_rs::Context; +use log::info; +use std::sync::mpsc::{Receiver, Sender}; +use std::sync::{Arc, Mutex}; +use std::thread; + +use super::config::ConnectionConfig; +use super::transport::{connect_stream, TransportStream}; +use super::TCP_CLIENT; + +pub enum TcpCommand { + SendMessage(String, Context), + Stop, +} + +pub struct TcpClient { + pub(crate) tx: Sender, +} + +impl TcpClient { + pub fn start(&self, config: ConnectionConfig, rx: Receiver, ctx: Context) { + if let Some(ref client) = *TCP_CLIENT.lock().unwrap() { + client.stop(); + } + + let connection = Arc::new(Mutex::new(None::)); + let connection_clone = Arc::clone(&connection); + + thread::spawn(move || { + let mut running = true; + let connection_message = config.connected_message(); + + let tcp_thread = thread::spawn(move || match connect_stream(&config) { + Ok(stream) => { + let target = config.target(); + let _ = ctx.callback_data("TCP SOCKET", connection_message, target); + *connection_clone.lock().unwrap() = Some(stream); + } + Err(e) => { + let _ = ctx.callback_data( + "TCP SOCKET ERROR", + "TAK Socket connection failed", + e.to_string(), + ); + info!("Failed to connect to TCP server: {}", e); + } + }); + + while running { + match rx.recv() { + Ok(TcpCommand::SendMessage(message, context)) => { + if let Some(stream) = connection.lock().unwrap().as_mut() { + if let Err(e) = stream.write_message(message.as_bytes()) { + info!("Failed to send message: {}", e); + + let _ = context.callback_data( + "TCP SOCKET ERROR", + "TAK Socket disconnected", + e.to_string(), + ); + + running = false; + } + } else { + let _ = context + .callback_null("TCP SOCKET ERROR", "TAK Socket is not active"); + } + } + Ok(TcpCommand::Stop) => { + running = false; + info!("Stopping TCP client."); + } + Err(error) => { + info!("Error receiving command: {}", error); + } + } + } + + tcp_thread.join().unwrap(); + }); + } + + pub fn send_payload(&self, context: Context, payload: String) { + let tx = self.tx.clone(); + thread::spawn(move || { + tx.send(TcpCommand::SendMessage(payload, context)).unwrap(); + }); + } + + pub fn stop(&self) { + let tx = self.tx.clone(); + thread::spawn(move || { + tx.send(TcpCommand::Stop).unwrap(); + }); + } +} diff --git a/src/tcp/config.rs b/src/tcp/config.rs new file mode 100644 index 0000000..d8e1460 --- /dev/null +++ b/src/tcp/config.rs @@ -0,0 +1,37 @@ +pub enum ConnectionConfig { + Plain { + address: String, + }, + Mtls { + address: String, + server_name: String, + ca_cert_path: String, + client_cert_path: String, + client_key_path: String, + }, + EnrollMtls { + host: String, + server_name: String, + enroll_port: String, + username: String, + password: String, + client_uid: String, + }, +} + +impl ConnectionConfig { + pub fn connected_message(&self) -> &'static str { + match self { + Self::Plain { .. } => "Connected to TCP Server", + Self::Mtls { .. } => "Connected to TAK Server via mTLS", + Self::EnrollMtls { .. } => "Connected to TAK Server via enrolled mTLS certificate", + } + } + + pub fn target(&self) -> String { + match self { + Self::Plain { address } | Self::Mtls { address, .. } => address.clone(), + Self::EnrollMtls { host, .. } => host.clone(), + } + } +} diff --git a/src/tcp/mod.rs b/src/tcp/mod.rs index df7b490..641625f 100644 --- a/src/tcp/mod.rs +++ b/src/tcp/mod.rs @@ -1,118 +1,88 @@ use arma_rs::Context; use lazy_static::lazy_static; use log::info; -use std::io::Write; -use std::net::TcpStream; use std::sync::mpsc::{self, Receiver, Sender}; use std::sync::{Arc, Mutex}; -use std::thread; + +mod client; +mod config; +mod tls; +mod transport; pub mod cot; pub mod draw; -pub enum TcpCommand { - SendMessage(String, Context), - Stop, -} - -pub struct TcpClient { - pub(crate) tx: Sender, -} - -impl TcpClient { - pub fn start(&self, address: String, rx: Receiver, ctx: Context) { - if let Some(ref client) = *TCP_CLIENT.lock().unwrap() { - client.stop(); - } - - let connection = Arc::new(Mutex::new(None)); - let connection_clone = Arc::clone(&connection); - - thread::spawn(move || { - let mut running = true; - - let tcp_thread = thread::spawn(move || match TcpStream::connect(&address) { - Ok(stream) => { - let _ = ctx.callback_data("TCP SOCKET", "Connected to TCP Server", address); - *connection_clone.lock().unwrap() = Some(stream); - } - Err(e) => { - let _ = ctx.callback_data( - "TCP SOCKET ERROR", - "TAK Socket connection failed", - e.to_string(), - ); - info!("Failed to connect to TCP server: {}", e); - } - }); - - while running { - match rx.recv() { - Ok(TcpCommand::SendMessage(message, context)) => { - if let Some(mut stream) = connection.lock().unwrap().as_ref() { - if let Err(e) = stream.write_all(message.as_bytes()) { - info!("Failed to send message: {}", e); - - let _ = context.callback_data( - "TCP SOCKET ERROR", - "TAK Socket disconnected", - e.to_string(), - ); - - running = false; - } - } else { - let _ = context.callback_null( - "TCP SOCKET ERROR", - "TAK Socket is not active", - ); - } - } - Ok(TcpCommand::Stop) => { - running = false; - info!("Stopping TCP client."); - } - Err(error) => { - info!("Error receiving command: {}", error.to_string()); - } - } - } - - tcp_thread.join().unwrap(); - }); - } - - pub fn send_payload(&self, context: Context, payload: String) { - let tx = self.tx.clone(); - thread::spawn(move || { - tx.send(TcpCommand::SendMessage(payload, context)).unwrap(); - }); - } - - pub fn stop(&self) { - let tx = self.tx.clone(); - thread::spawn(move || { - tx.send(TcpCommand::Stop).unwrap(); - }); - } -} +use client::{TcpClient, TcpCommand}; +use config::ConnectionConfig; +use tls::artifacts::clear_enrollment_artifacts; lazy_static! { static ref TCP_CLIENT: Arc>> = Arc::new(Mutex::new(None)); } -pub fn start(ctx: Context, address: String) -> &'static str { +fn start_with_config(ctx: Context, config: ConnectionConfig) { let (tx, rx): (Sender, Receiver) = mpsc::channel(); let client = TcpClient { tx }; - client.start(address, rx, ctx); + client.start(config, rx, ctx); let mut client_guard = TCP_CLIENT.lock().unwrap(); *client_guard = Some(client); +} + +pub fn start(ctx: Context, address: String) -> &'static str { + start_with_config(ctx, ConnectionConfig::Plain { address }); "Starting TCP Client" } +pub fn start_mtls( + ctx: Context, + address: String, + server_name: String, + ca_cert_path: String, + client_cert_path: String, + client_key_path: String, +) -> &'static str { + start_with_config( + ctx, + ConnectionConfig::Mtls { + address, + server_name, + ca_cert_path, + client_cert_path, + client_key_path, + }, + ); + + "Starting mTLS TCP Client" +} + +pub fn start_enroll_mtls( + ctx: Context, + host: String, + server_name: String, + enroll_port: String, + username: String, + password: String, + client_uid: String, +) -> &'static str { + clear_enrollment_artifacts(); + start_with_config( + ctx, + ConnectionConfig::EnrollMtls { + host, + server_name, + enroll_port, + username, + password, + client_uid, + }, + ); + + "Starting enrolled mTLS TCP Client" +} + pub fn send_payload(ctx: Context, payload: String) -> &'static str { if let Some(ref client) = *TCP_CLIENT.lock().unwrap() { client.send_payload(ctx, payload); @@ -128,6 +98,7 @@ pub fn stop(ctx: Context) -> &'static str { if let Some(ref client) = *TCP_CLIENT.lock().unwrap() { client.stop(); let _ = ctx.callback_null("TCP SOCKET", "TCP client stopped"); + clear_enrollment_artifacts(); } else { let _ = ctx.callback_null("TCP SOCKET ERROR", "TCP client is not running"); } diff --git a/src/tcp/transport.rs b/src/tcp/transport.rs new file mode 100644 index 0000000..8e47890 --- /dev/null +++ b/src/tcp/transport.rs @@ -0,0 +1,62 @@ +use rustls::{ClientConnection, StreamOwned}; +use std::io::Write; +use std::net::TcpStream; + +use super::config::ConnectionConfig; +use super::tls::{connect_mtls, enroll_and_connect}; + +pub enum TransportStream { + Plain(TcpStream), + Mtls(StreamOwned), +} + +impl TransportStream { + pub fn write_message(&mut self, message: &[u8]) -> Result<(), std::io::Error> { + match self { + Self::Plain(stream) => { + stream.write_all(message)?; + stream.flush() + } + Self::Mtls(stream) => { + stream.write_all(message)?; + stream.flush() + } + } + } +} + +pub fn connect_stream(config: &ConnectionConfig) -> Result { + match config { + ConnectionConfig::Plain { address } => TcpStream::connect(address) + .map(TransportStream::Plain) + .map_err(|e| format!("failed to connect to {}: {}", address, e)), + ConnectionConfig::Mtls { + address, + server_name, + ca_cert_path, + client_cert_path, + client_key_path, + } => connect_mtls( + address, + server_name, + ca_cert_path, + client_cert_path, + client_key_path, + ), + ConnectionConfig::EnrollMtls { + host, + server_name, + enroll_port, + username, + password, + client_uid, + } => enroll_and_connect( + host, + server_name, + enroll_port, + username, + password, + client_uid, + ), + } +}