diff options
author | 2024-03-31 21:39:46 +0200 | |
---|---|---|
committer | 2024-03-31 21:39:46 +0200 | |
commit | facada6f8d5aecb5ce0bc2a13042622f93f15807 (patch) | |
tree | f872c48279d636df7bbea6edb56de4d5130c8fe5 | |
parent | 7af7a7626a8e83fe3f9c3b0d2ad7d2b32da41d45 (diff) |
♻️ (error.rs): remove unnecessary error conversions for IllegalArgumentException and IllegalStateException
♻️ (ls_client.rs): refactor connect method to accept shutdown signal and return generic error
✨ (ls_client.rs): add support for graceful shutdown using Notify
✨ (ls_client.rs): implement session creation and subscription logic in connect method
♻️ (main.rs): replace SharedState with Notify for handling shutdown signal
✨ (main.rs): add retry logic with a maximum of 5 retries for the client connection in main function
✨ (main.rs): ensure graceful client disconnect and orderly shutdown of the application
-rw-r--r-- | src/error.rs | 24 | ||||
-rw-r--r-- | src/ls_client.rs | 218 | ||||
-rw-r--r-- | src/main.rs | 51 |
3 files changed, 183 insertions, 110 deletions
diff --git a/src/error.rs b/src/error.rs index 9c3a916..cd1d7c2 100644 --- a/src/error.rs +++ b/src/error.rs @@ -22,12 +22,6 @@ impl Error for IllegalArgumentException { } } -impl From<Box<dyn Error>> for IllegalArgumentException { - fn from(error: Box<dyn Error>) -> Self { - IllegalArgumentException::new(&error.to_string()) - } -} - #[derive(Debug)] pub struct IllegalStateException { details: String @@ -49,22 +43,4 @@ impl Error for IllegalStateException { fn description(&self) -> &str { &self.details } -} - -impl From<Box<dyn Error>> for IllegalStateException { - fn from(error: Box<dyn Error>) -> Self { - IllegalStateException::new(&error.to_string()) - } -} - -impl From<serde_urlencoded::ser::Error> for IllegalStateException { - fn from(err: serde_urlencoded::ser::Error) -> Self { - IllegalStateException::new(&format!("Serialization error: {}", err)) - } -} - -impl From<tokio_tungstenite::tungstenite::Error> for IllegalStateException { - fn from(err: tokio_tungstenite::tungstenite::Error) -> Self { - IllegalStateException::new(&format!("WebSocket error: {}", err)) - } }
\ No newline at end of file diff --git a/src/ls_client.rs b/src/ls_client.rs index 26328f7..51e4786 100644 --- a/src/ls_client.rs +++ b/src/ls_client.rs @@ -2,14 +2,17 @@ use crate::client_listener::ClientListener; use crate::client_message_listener::ClientMessageListener; use crate::connection_details::ConnectionDetails; use crate::connection_options::ConnectionOptions; -use crate::subscription::Subscription; use crate::error::IllegalStateException; +use crate::subscription::Subscription; use crate::util::*; use cookie::Cookie; use futures_util::{SinkExt, StreamExt}; use std::collections::HashMap; +use std::error::Error; use std::fmt::{self, Debug, Formatter}; +use std::sync::Arc; +use tokio::sync::Notify; use tokio_tungstenite::{ connect_async, tungstenite::{ @@ -187,29 +190,36 @@ impl LightstreamerClient { /// See also `ClientListener.onStatusChange()` /// /// See also `ConnectionDetails.setServerAddress()` - pub async fn connect(&mut self) -> Result<(), IllegalStateException> { + pub async fn connect(&mut self, shutdown_signal: Arc<Notify>) -> Result<(), Box<dyn Error>> { if self.server_address.is_none() { - return Err(IllegalStateException::new( + return Err(Box::new(IllegalStateException::new( "No server address was configured.", - )); + ))); } let forced_transport = self.connection_options.get_forced_transport(); - if forced_transport.is_none() || *forced_transport.unwrap() != Transport::WsStreaming { - // unwrap() is safe here. - return Err(IllegalStateException::new( + if forced_transport.is_none() + || *forced_transport.unwrap() /* unwrap() is safe here */ != Transport::WsStreaming + { + return Err(Box::new(IllegalStateException::new( "Only WebSocket streaming transport is currently supported.", - )); + ))); } - let mut params = HashMap::new(); + let mut base_params = HashMap::new(); // - // Build the mandatory request parameters. + // Build the request base parameters. // - params.insert("LS_protocol", "TLCP-2.5.0"); - params.insert("LS_cid", "mgQkwtwdysogQz2BJ4Ji%20kOj2Bg"); + base_params.extend([ + ("LS_protocol", "TLCP-2.5.0"), + ("LS_cid", "mgQkwtwdysogQz2BJ4Ji%20kOj2Bg"), + ]); + + if let Some(adapter_set) = self.connection_details.get_adapter_set() { + base_params.insert("LS_adapter_set", adapter_set); + } // // Add optional parameters @@ -300,10 +310,10 @@ impl LightstreamerClient { .set_scheme("wss") .expect("Failed to set scheme to wss for WebSocket URL."), invalid_scheme => { - return Err(IllegalStateException::new(&format!( + return Err(Box::new(IllegalStateException::new(&format!( "Unsupported scheme '{}' found when converting HTTP URL to WebSocket URL.", invalid_scheme - ))); + )))); } } let ws_url = url.as_str(); @@ -318,7 +328,10 @@ impl LightstreamerClient { .header( HeaderName::from_static("host"), HeaderValue::from_str(url.host_str().unwrap_or("localhost")).map_err(|err| { - IllegalStateException::new(&format!("Invalid header value for header with name 'host': {}", err)) + IllegalStateException::new(&format!( + "Invalid header value for header with name 'host': {}", + err + )) })?, ) .header( @@ -337,34 +350,41 @@ impl LightstreamerClient { HeaderName::from_static("upgrade"), HeaderValue::from_static("websocket"), ) - .body(()) - .unwrap(); + .body(())?; // Connect to the Lightstreamer server using WebSocket. let ws_stream = match connect_async(request).await { Ok((ws_stream, response)) => { if let Some(server_header) = response.headers().get("server") { - println!("Connected to Lightstreamer server: {}", server_header.to_str().unwrap_or("")); + println!( + "Connected to Lightstreamer server: {}", + server_header.to_str().unwrap_or("") + ); } else { println!("Connected to Lightstreamer server"); } ws_stream - }, + } Err(err) => { - return Err(IllegalStateException::new(&format!( - "Failed to connect to Lightstreamer server with WebSocket: {}", - err + return Err(Box::new(std::io::Error::new( + std::io::ErrorKind::ConnectionRefused, + format!( + "Failed to connect to Lightstreamer server with WebSocket: {}", + err + ), ))); } }; - // Split the WebSocket stream into a write and read stream. + // Split the WebSocket stream into a write and a read stream. let (mut write_stream, mut read_stream) = ws_stream.split(); // // Confirm the connection by sending a 'wsok' message to the server. // - write_stream.send(Message::Text("wsok".into())).await.expect("Failed to send message"); + write_stream + .send(Message::Text("wsok".into())) + .await?; if let Some(result) = read_stream.next().await { match result? { Message::Text(text) => { @@ -372,53 +392,137 @@ impl LightstreamerClient { if clean_text == "wsok" { println!("Connection confirmed by server"); } else { - return Err(IllegalStateException::new(&format!( - "Unexpected message received from server: {}", - clean_text + return Err(Box::new(std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("Unexpected message received from server: {}", clean_text), ))); } - }, + } non_text_message => { - println!("Unexpected non-text message from server: {:?}", non_text_message); - }, + return Err(Box::new(std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!( + "Unexpected non-text message from server: {:?}", + non_text_message + ), + ))); + } } } - /* - // Session creation parameters - let params = [ - ("LS_op2", "create_session"), - ("LS_cid", "mgQkwtwdysogQz2BJ4Ji kOj2Bg"), - ("LS_adapter_set", "DEMO"), - ]; - - let encoded_params = serde_urlencoded::to_string(¶ms)?; - - // Send the create session message + // + // Session creation. + // + let mut session_id: Option<String> = None; + let encoded_params = serde_urlencoded::to_string(&base_params)?; write_stream - .send(Message::Text(format!("{}\n", encoded_params))) + .send(Message::Text(format!("create_session\r\n{}\n", encoded_params))) .await?; - */ - - // Listen for messages from the server - while let Some(message) = read_stream.next().await { - match message? { + if let Some(result) = read_stream.next().await { + match result? { Message::Text(text) => { - if text.starts_with("CONOK") { - let session_info: Vec<&str> = text.split(",").collect(); - let session_id = session_info.get(1).unwrap_or(&"").to_string(); - println!("Session established with ID: {}", session_id); - //subscribe_to_channel_ws(session_id, write_stream).await?; - break; // Exit after successful subscription + let clean_text = clean_message(&text); + if clean_text.starts_with("conok") { + let session_info: Vec<&str> = clean_text.split(",").collect(); + session_id = session_info.get(1).map(|s| s.to_string()); } else { - println!("Received unexpected message from server: {}", text); + return Err(Box::new(std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("Unexpected message received from server: {}", clean_text), + ))); } } - msg => { println!("Received non-text message from server: {:?}", msg); } + non_text_message => { + return Err(Box::new(std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!( + "Unexpected non-text message from server: {:?}", + non_text_message + ), + ))); + } + } + } + + // + // Perform subscription. + // + if let Some(session_id) = session_id { + let mut params = base_params.clone(); + params.extend([ + ("LS_session", session_id.as_str()), + ("LS_op", "add"), + ("LS_table", "1"), + ("LS_id", "1"), + ("LS_mode", "MERGE"), + ("LS_schema", "stock_name,last_price"), + ("LS_data_adapter", "QUOTE_ADAPTER"), + ("LS_snapshot", "true"), + ]); + let encoded_params = serde_urlencoded::to_string(&base_params)?; + write_stream + .send(Message::Text(format!("control\r\n{}\n", encoded_params))) + .await?; + if let Some(result) = read_stream.next().await { + match result? { + Message::Text(text) => { + let clean_text = clean_message(&text); + if clean_text.starts_with("subok") { + println!("Subscription confirmed by server"); + } else { + return Err(Box::new(std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("Unexpected message received from server: {}", clean_text), + ))); + } + } + non_text_message => { + return Err(Box::new(std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!( + "Unexpected non-text message from server: {:?}", + non_text_message + ), + ))); + } + } + } + + + // Listen for messages from the server + loop { + tokio::select! { + message = read_stream.next() => { + match message { + Some(Ok(Message::Text(text))) => { + println!("Received message from server: {}", text); + }, + Some(Ok(non_text_message)) => { + println!("Received non-text message from server: {:?}", non_text_message); + }, + Some(Err(err)) => { + return Err(Box::new( + std::io::Error::new(std::io::ErrorKind::InvalidData, format!( + "Error reading message from server: {}", + err + )), + )); + }, + None => { + println!("No more messages from server"); + break; + }, + } + }, + _ = shutdown_signal.notified() => { + println!("Received shutdown signal"); + break; + }, + } } } - println!("No more messages from server"); + println!("Ending function connect() to Lightstreamer server"); Ok(()) } @@ -439,7 +543,7 @@ impl LightstreamerClient { /// "DISCONNECTED", then nothing will be done. /// /// See also `connect()` - pub fn disconnect(&mut self) { + pub async fn disconnect(&mut self) { // Implementation for disconnect println!("Disconnecting from Lightstreamer server"); } diff --git a/src/main.rs b/src/main.rs index 87f9da4..8431333 100644 --- a/src/main.rs +++ b/src/main.rs @@ -13,14 +13,9 @@ use signal_hook::{consts::SIGINT, consts::SIGTERM, iterator::Signals}; use std::error::Error; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; -use tokio::sync::Mutex; +use tokio::sync::{Notify, Mutex}; use tokio_tungstenite::{connect_async, tungstenite::protocol::Message}; -struct SharedState { - client: Arc<Mutex<LightstreamerClient>>, - should_disconnect: Arc<AtomicBool>, -} - /// Sets up a signal hook for SIGINT and SIGTERM. /// /// Creates a signal hook for the specified signals and spawns a thread to handle them. @@ -35,7 +30,7 @@ struct SharedState { /// /// The function panics if it fails to create the signal iterator. /// -async fn setup_signal_hook(shared_state: Arc<Mutex<SharedState>>) { +async fn setup_signal_hook(shutdown_signal: Arc<Notify>) { // Create a signal set of signals to be handled and a signal iterator to monitor them. let signals = &[SIGINT, SIGTERM]; let mut signals_iterator = Signals::new(signals).expect("Failed to create signal iterator"); @@ -44,18 +39,8 @@ async fn setup_signal_hook(shared_state: Arc<Mutex<SharedState>>) { tokio::spawn(async move { for signal in signals_iterator.forever() { println!("Received signal: {}", signal_name(signal).unwrap()); - // - // Clean up and prepare to exit... - // ... - { - let shared_state = shared_state.lock().await; - shared_state.should_disconnect.store(true, Ordering::Relaxed); - let mut client = shared_state.client.lock().await; - client.disconnect(); - } - - // Exit with 0 code to indicate orderly shutdown. - std::process::exit(0); + let _ = shutdown_signal.notify_one(); + break; } }); } @@ -106,14 +91,10 @@ async fn main() -> Result<(), Box<dyn Error>> { client.connection_options.set_forced_transport(Some(Transport::WsStreaming)); } - let should_disconnect = Arc::new(AtomicBool::new(false)); - let shared_state = Arc::new(Mutex::new(SharedState { - client: client.clone(), - should_disconnect: should_disconnect.clone(), - })); - + // Create a new Notify instance to send a shutdown signal to the signal handler thread. + let shutdown_signal = Arc::new(tokio::sync::Notify::new()); // Spawn a new thread to handle SIGINT and SIGTERM process signals. - setup_signal_hook(shared_state).await; + setup_signal_hook(Arc::clone(&shutdown_signal)).await; // // Infinite loop that will indefinitely retry failed connections unless @@ -121,10 +102,13 @@ async fn main() -> Result<(), Box<dyn Error>> { // let mut retry_interval_milis: u64 = 0; let mut retry_counter: u64 = 0; - loop { + while retry_counter < 5 { let mut client = client.lock().await; - match client.connect().await { - Ok(_) => {} + match client.connect(Arc::clone(&shutdown_signal)).await { + Ok(_) => { + client.disconnect().await; + break; + } Err(e) => { println!("Failed to connect: {:?}", e); tokio::time::sleep(std::time::Duration::from_millis(retry_interval_milis)).await; @@ -137,4 +121,13 @@ async fn main() -> Result<(), Box<dyn Error>> { } } } + + if retry_counter == 5 { + println!("Failed to connect after {} retries. Exiting...", retry_counter); + } else { + println!("Exiting orderly from Lightstreamer client..."); + } + + // Exit using std::process::exit() to avoid waiting for existing tokio tasks to complete. + std::process::exit(0); } |