manycast/worker/
mod.rs

1use clap::ArgMatches;
2use gethostname::gethostname;
3use local_ip_address::{local_ip, local_ipv6};
4use std::error::Error;
5use std::sync::atomic::{AtomicBool, Ordering};
6use std::sync::{Arc, Mutex};
7use std::thread;
8use std::time::Duration;
9use tonic::transport::{Certificate, Channel, ClientTlsConfig};
10use tonic::Request;
11
12use pnet::datalink::{self, Channel as SocketChannel};
13
14use custom_module::manycastr::{
15    controller_client::ControllerClient, task::Data, Address, End, Finished, Origin, Task,
16    TaskResult,
17};
18
19use crate::net::packet::is_in_prefix;
20use crate::worker::inbound::{inbound, InboundConfig};
21use crate::worker::outbound::{outbound, OutboundConfig};
22use crate::{custom_module, ALL_ID, A_ID, CHAOS_ID, ICMP_ID, TCP_ID};
23
24mod inbound;
25mod outbound;
26
27/// The worker that is run at the anycast sites and performs measurements as instructed by the orchestrator.
28///
29/// The worker is responsible for establishing a connection with the orchestrator, receiving tasks, and performing measurements.
30///
31/// # Fields
32///
33/// * 'grpc_client' - the worker gRPC connection with the orchestrator
34/// * 'hostname' - the hostname of the worker
35/// * 'is_active' - boolean value that is set to true when the worker is currently doing a measurement
36/// * 'current_m_id' - contains the ID of the current measurement
37/// * 'outbound_tx' - contains the sender of a channel to the outbound prober that tasks are send to
38/// * 'inbound_f' - an atomic boolean that is used to signal the inbound thread to stop listening for packets
39#[derive(Clone)]
40pub struct Worker {
41    grpc_client: ControllerClient<Channel>,
42    hostname: String,
43    is_active: Arc<Mutex<bool>>,
44    current_m_id: Arc<Mutex<u32>>,
45    outbound_tx: Option<tokio::sync::mpsc::Sender<Data>>,
46    abort_s: Arc<AtomicBool>,
47}
48
49impl Worker {
50    /// Create a worker instance, which includes establishing a connection with the orchestrator.
51    ///
52    /// Extracts the parameters of the command-line arguments.
53    ///
54    /// # Arguments
55    ///
56    /// * 'args' - contains the parsed command-line arguments
57    pub async fn new(args: &ArgMatches) -> Result<Worker, Box<dyn Error>> {
58        // Get hostname from command line arguments or use the system hostname
59        let hostname = args
60            .get_one::<String>("hostname")
61            .map(|h| h.parse::<String>().expect("Unable to parse hostname"))
62            .unwrap_or_else(|| gethostname().into_string().expect("Unable to get hostname"));
63
64        let orc_addr = args.get_one::<String>("orchestrator").unwrap();
65        let fqdn = args.get_one::<String>("tls");
66        let client = Worker::connect(orc_addr.parse().unwrap(), fqdn)
67            .await
68            .expect("Unable to connect to orchestrator");
69
70        // Initialize a worker instance
71        let mut worker = Worker {
72            grpc_client: client,
73            hostname,
74            is_active: Arc::new(Mutex::new(false)),
75            current_m_id: Arc::new(Mutex::new(0)),
76            outbound_tx: None,
77            abort_s: Arc::new(AtomicBool::new(false)),
78        };
79
80        worker.connect_to_server().await?;
81
82        Ok(worker)
83    }
84
85    /// Connect to the orchestrator.
86    ///
87    /// # Arguments
88    ///
89    /// * 'address' - the address of the orchestrator in string format, containing both the IPv4 address and port number
90    ///
91    /// * 'fqdn' - an optional string that contains the FQDN of the orchestrator certificate (if TLS is enabled)
92    ///
93    /// # Example
94    ///
95    /// ```
96    /// let client = connect("127.0.0.0:50001", true);
97    /// ```
98    async fn connect(
99        address: String,
100        fqdn: Option<&String>,
101    ) -> Result<ControllerClient<Channel>, Box<dyn Error>> {
102        let channel = if let Some(fqdn) = fqdn {
103            // Secure connection
104            let addr = format!("https://{address}");
105
106            // Load the CA certificate used to authenticate the orchestrator
107            let pem = std::fs::read_to_string("tls/orchestrator.crt")
108                .expect("Unable to read CA certificate at ./tls/orchestrator.crt");
109            let ca = Certificate::from_pem(pem);
110            // Create a TLS configuration
111            let tls = ClientTlsConfig::new().ca_certificate(ca).domain_name(fqdn);
112
113            let builder = Channel::from_shared(addr.to_owned()).expect("Unable to set address"); // Use the address provided
114            builder
115                .keep_alive_timeout(Duration::from_secs(30))
116                .http2_keep_alive_interval(Duration::from_secs(15))
117                .tcp_keepalive(Some(Duration::from_secs(60)))
118                .tls_config(tls)
119                .expect("Unable to set TLS configuration")
120                .connect()
121                .await
122                .expect("Unable to connect to orchestrator")
123        } else {
124            // Unsecure connection
125            let addr = format!("http://{address}");
126
127            Channel::from_shared(addr.to_owned())
128                .expect("Unable to set address")
129                .keep_alive_timeout(Duration::from_secs(30))
130                .http2_keep_alive_interval(Duration::from_secs(15))
131                .tcp_keepalive(Some(Duration::from_secs(60)))
132                .connect()
133                .await
134                .expect("Unable to connect to orchestrator")
135        };
136        // Create worker with secret token that is used to authenticate worker commands.
137        let client = ControllerClient::new(channel);
138
139        Ok(client)
140    }
141
142    /// Initialize a new measurement by creating outbound and inbound threads, and ensures task results are sent back to the orchestrator.
143    ///
144    /// Extracts the protocol type from the measurement definition, and determines which source address to use.
145    /// Creates a socket to send out probes and receive replies with, calls the appropriate inbound & outbound functions.
146    /// Creates an additional thread that forwards task results to the orchestrator.
147    ///
148    /// # Arguments
149    ///
150    /// * 'task' - the first 'Task' message sent by the orchestrator, that contains the measurement definition
151    ///
152    /// * 'worker_id' - the unique ID of this worker
153    ///
154    /// * 'abort_s' - an optional Arc<AtomicBool> that is used to signal the outbound thread to stop sending probes
155    fn init(&mut self, task: Task, worker_id: u16, abort_s: Option<Arc<AtomicBool>>) {
156        let start_measurement = if let Data::Start(start) = task.data.unwrap() {
157            start
158        } else {
159            panic!("Received non-start packet for init")
160        };
161        let m_id = start_measurement.m_id;
162        let is_ipv6 = start_measurement.is_ipv6;
163        let mut rx_origins: Vec<Origin> = start_measurement.rx_origins;
164        let is_unicast = start_measurement.is_unicast;
165        let is_probing = !start_measurement.tx_origins.is_empty();
166        let qname = start_measurement.record;
167        let info_url = start_measurement.url;
168        let probing_rate = start_measurement.rate;
169        let is_latency = start_measurement.is_latency;
170
171        // Channel for forwarding tasks to outbound
172        let outbound_rx = if is_probing {
173            let (tx, rx) = tokio::sync::mpsc::channel(1000);
174            self.outbound_tx = Some(tx);
175            Some(rx)
176        } else {
177            None
178        };
179
180        let tx_origins: Vec<Origin> = if !is_probing {
181            vec![]
182        } else if is_unicast {
183            // Use the local unicast address and CLI defined ports
184            let sport = start_measurement.tx_origins[0].sport;
185            let dport = start_measurement.tx_origins[0].dport;
186
187            // Get the local unicast address
188            let unicast_ip = Address::from(if is_ipv6 {
189                local_ipv6().expect("Unable to get local unicast IPv6 address")
190            } else {
191                local_ip().expect("Unable to get local unicast IPv4 address")
192            });
193
194            let unicast_origin = Origin {
195                src: Some(unicast_ip), // Unicast IP
196                sport,                 // CLI defined source port
197                dport,                 // CLI defined destination port
198                origin_id: u32::MAX,   // ID for unicast address
199            };
200
201            // We only listen to our own unicast address (each worker has its own unicast address)
202            rx_origins = vec![unicast_origin];
203
204            println!("[Worker] Using local unicast IP address: {unicast_ip}");
205            // Use the local unicast address
206            vec![unicast_origin]
207        } else {
208            // Use the sender origins set by the orchestrator
209            start_measurement.tx_origins
210        };
211
212        // Channel for sending from inbound to the orchestrator forwarder thread
213        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
214
215        // Get the network interface to use
216        let interfaces = datalink::interfaces();
217
218        // Look for the interface that uses the listening IP address
219        let addr = rx_origins[0].src.unwrap().to_string();
220        let interface = if let Some(interface) = interfaces
221            .iter()
222            .find(|iface| iface.ips.iter().any(|ip| is_in_prefix(&addr, ip)))
223        {
224            println!(
225                "[Worker] Found interface: {}, for address {}",
226                interface.name, addr
227            );
228            interface.clone() // Return the found interface
229        } else {
230            // Use the default interface (first non-loopback interface)
231            let interface = interfaces
232                .into_iter()
233                .find(|iface| !iface.is_loopback())
234                .expect("Failed to find default interface");
235            println!(
236                "[Worker] No interface found for address: {}, using default interface {}",
237                addr, interface.name
238            );
239            interface
240        };
241
242        // Create a socket to send out probes and receive replies with
243        let config = datalink::Config {
244            write_buffer_size: 10 * 1024 * 1024, // 10 MB
245            read_buffer_size: 10 * 1024 * 1024,  // 10 MB
246            ..Default::default()
247        };
248        let (socket_tx, socket_rx) = match datalink::channel(&interface, config) {
249            Ok(SocketChannel::Ethernet(socket_tx, socket_rx)) => (socket_tx, socket_rx),
250            Ok(_) => panic!("Unsupported channel type"),
251            Err(e) => panic!("Failed to create datalink channel: {e}"),
252        };
253
254        // Start listening thread (except if it is a unicast measurement and we are not probing)
255        if !is_unicast || is_probing {
256            let config = InboundConfig {
257                m_id,
258                worker_id,
259                is_ipv6,
260                m_type: start_measurement.m_type as u8,
261                origin_map: rx_origins,
262                abort_s: self.abort_s.clone(),
263            };
264
265            inbound(config, tx, socket_rx);
266        }
267
268        if is_probing {
269            match start_measurement.m_type as u8 {
270                ICMP_ID => {
271                    // Print all probe origin addresses
272                    for origin in tx_origins.iter() {
273                        println!(
274                            "[Worker] Sending on address: {} using ICMP identifier {}",
275                            origin.src.unwrap(),
276                            origin.dport
277                        );
278                    }
279                }
280                A_ID | TCP_ID | CHAOS_ID | ALL_ID => {
281                    // Print all probe origin addresses
282                    for origin in tx_origins.iter() {
283                        println!(
284                            "[Worker] Sending on address: {}, from src port {}, to dst port {}",
285                            origin.src.unwrap(),
286                            origin.sport,
287                            origin.dport
288                        );
289                    }
290                }
291                _ => (),
292            }
293
294            let config = OutboundConfig {
295                worker_id,
296                tx_origins,
297                abort_s: abort_s.unwrap(),
298                is_ipv6,
299                is_symmetric: is_latency || is_unicast,
300                m_id,
301                m_type: start_measurement.m_type as u8,
302                qname,
303                info_url,
304                if_name: interface.name,
305                probing_rate,
306            };
307
308            // Start sending thread
309            outbound(config, outbound_rx.unwrap(), socket_tx);
310        } else {
311            println!("[Worker] Not sending probes");
312        }
313
314        let mut self_clone = self.clone();
315        // Thread that listens for task results from inbound and forwards them to the orchestrator
316        thread::Builder::new()
317            .name("forwarder_thread".to_string())
318            .spawn(move || {
319                let rt = tokio::runtime::Runtime::new().unwrap();
320                let _enter = rt.enter();
321
322                rt.block_on(async {
323                    // Obtain TaskResults from the unbounded channel and send them to the orchestrator
324                    while let Some(packet) = rx.recv().await {
325                        // A default TaskResult notifies this sender that there will be no more results
326                        if packet == TaskResult::default() {
327                            self_clone
328                                .measurement_finish_to_server(Finished {
329                                    m_id,
330                                    worker_id: worker_id.into(),
331                                })
332                                .await
333                                .unwrap();
334
335                            break;
336                        }
337
338                        self_clone
339                            .send_result_to_server(packet)
340                            .await
341                            .expect("Unable to send task result to orchestrator");
342                    }
343                    rx.close();
344                });
345            })
346            .expect("Unable to start forwarder thread");
347    }
348
349    /// Establish a formal connection with the orchestrator.
350    ///
351    /// Obtains a unique worker ID from the orchestrator, establishes a stream for receiving tasks, and handles tasks as they come in.
352    async fn connect_to_server(&mut self) -> Result<(), Box<dyn Error>> {
353        println!("[Worker] Connecting to orchestrator");
354        let mut abort_s: Option<Arc<AtomicBool>> = None;
355
356        // Get the local unicast addresses
357        let unicast_v6 = local_ipv6().ok().map(Address::from);
358        let unicast_v4 = local_ip().ok().map(Address::from);
359
360        let worker = custom_module::manycastr::Worker {
361            hostname: self.hostname.clone(),
362            worker_id: 0, // This will be set after the connection
363            status: "".to_string(),
364            unicast_v6,
365            unicast_v4,
366        };
367
368        // Connect to the orchestrator
369        let response = self
370            .grpc_client
371            .worker_connect(Request::new(worker))
372            .await
373            .expect("Unable to connect to orchestrator");
374
375        let mut stream = response.into_inner();
376        // Read the assigned unique worker ID
377        let id_message = stream
378            .message()
379            .await
380            .expect("Unable to await stream")
381            .expect("Unable to receive worker ID");
382        let worker_id = id_message.worker_id.expect("No initial worker ID set") as u16;
383        println!(
384            "[Worker] Successfully connected with the orchestrator with worker_id: {worker_id}"
385        );
386
387        // Await tasks
388        while let Some(task) = stream.message().await.expect("Unable to receive task") {
389            // If we already have an active measurement
390            if *self.is_active.lock().unwrap() {
391                // If the CLI disconnected we will receive this message
392                match task.data {
393                    None => {
394                        println!("[Worker] Received empty task, skipping");
395                        continue;
396                    }
397                    Some(Data::Start(_)) => {
398                        println!("[Worker] Received new measurement during an active measurement, skipping");
399                        continue;
400                    }
401                    Some(Data::End(data)) => {
402                        // Received finish signal
403                        if data.code == 0 {
404                            println!(
405                                "[Worker] Received measurement finished signal from orchestrator"
406                            );
407                            // Close inbound threads
408                            self.abort_s.store(true, Ordering::SeqCst);
409                            // Close outbound threads gracefully
410                            if let Some(tx) = self.outbound_tx.take() {
411                                tx.send(Data::End(End { code: 0 })).await.expect(
412                                    "Unable to send measurement_finished to outbound thread",
413                                );
414                            }
415                        } else if data.code == 1 {
416                            println!("[Worker] CLI disconnected, aborting measurement");
417
418                            // Close the inbound threads
419                            self.abort_s.store(true, Ordering::SeqCst);
420                            // finish will be None if this worker is not probing
421                            if let Some(abort_s) = &abort_s {
422                                // Close outbound threads
423                                abort_s.store(true, Ordering::SeqCst);
424                            }
425                        } else {
426                            println!("[Worker] Received invalid code from orchestrator");
427                            continue;
428                        }
429                    }
430                    Some(task) => {
431                        // outbound_tx will be None if this worker is not probing
432                        if let Some(outbound_tx) = &self.outbound_tx {
433                            // Send the task to the prober
434                            outbound_tx
435                                .send(task)
436                                .await
437                                .expect("Unable to send task to outbound thread");
438                        }
439                    }
440                };
441
442                // If we don't have an active measurement
443            } else {
444                let (is_unicast, is_probing, m_id) =
445                    match task.clone().data.expect("None start measurement task") {
446                        Data::Start(start) => {
447                            (start.is_unicast, !start.tx_origins.is_empty(), start.m_id)
448                        }
449                        _ => {
450                            // First task is not a start measurement task
451                            continue;
452                        }
453                    };
454
455                // If we are not probing for a unicast measurement, we do nothing
456                if is_unicast && !is_probing {
457                    println!("[Worker] Not probing for unicast measurement, skipping");
458                    continue;
459                }
460
461                println!("[Worker] Starting new measurement");
462
463                *self.is_active.lock().unwrap() = true;
464                *self.current_m_id.lock().unwrap() = m_id;
465                self.abort_s.store(false, Ordering::SeqCst);
466
467                if is_probing {
468                    // This worker is probing
469                    // Initialize signal finish atomic boolean
470                    abort_s = Some(Arc::new(AtomicBool::new(false)));
471
472                    self.init(task, worker_id, abort_s.clone());
473                } else {
474                    // This worker is not probing
475                    abort_s = None;
476                    self.outbound_tx = None;
477                    self.init(task, worker_id, None);
478                }
479            }
480        }
481        println!("[Worker] Stopped awaiting tasks");
482
483        Ok(())
484    }
485
486    /// Send a TaskResult to the orchestrator
487    async fn send_result_to_server(
488        &mut self,
489        task_result: TaskResult,
490    ) -> Result<(), Box<dyn Error>> {
491        self.grpc_client
492            .send_result(Request::new(task_result))
493            .await?;
494
495        Ok(())
496    }
497
498    /// Let the orchestrator know the current measurement is finished.
499    ///
500    /// When a measurement is finished the orchestrator knows not to expect any more results from this worker.
501    ///
502    /// # Arguments
503    ///
504    /// * 'finished' - the 'Finished' message to send to the orchestrator
505    async fn measurement_finish_to_server(
506        &mut self,
507        finished: Finished,
508    ) -> Result<(), Box<dyn Error>> {
509        println!(
510            "[Worker] Letting the orchestrator know that this worker finished the measurement"
511        );
512        *self.is_active.lock().unwrap() = false;
513        self.grpc_client
514            .measurement_finished(Request::new(finished))
515            .await?;
516
517        Ok(())
518    }
519}