1use clap::ArgMatches;
2use gethostname::gethostname;
3use local_ip_address::{local_ip, local_ipv6};
4use std::error::Error;
5use std::sync::atomic::{AtomicBool, Ordering};
6use std::sync::{Arc, Mutex};
7use std::thread;
8use std::time::Duration;
9use tonic::transport::{Certificate, Channel, ClientTlsConfig};
10use tonic::Request;
11
12use pnet::datalink::{self, Channel as SocketChannel};
13
14use custom_module::manycastr::{
15 controller_client::ControllerClient, task::Data, Address, End, Finished, Origin, Task,
16 TaskResult,
17};
18
19use crate::net::packet::is_in_prefix;
20use crate::worker::inbound::{inbound, InboundConfig};
21use crate::worker::outbound::{outbound, OutboundConfig};
22use crate::{custom_module, ALL_ID, A_ID, CHAOS_ID, ICMP_ID, TCP_ID};
23
24mod inbound;
25mod outbound;
26
27#[derive(Clone)]
40pub struct Worker {
41 grpc_client: ControllerClient<Channel>,
42 hostname: String,
43 is_active: Arc<Mutex<bool>>,
44 current_m_id: Arc<Mutex<u32>>,
45 outbound_tx: Option<tokio::sync::mpsc::Sender<Data>>,
46 abort_s: Arc<AtomicBool>,
47}
48
49impl Worker {
50 pub async fn new(args: &ArgMatches) -> Result<Worker, Box<dyn Error>> {
58 let hostname = args
60 .get_one::<String>("hostname")
61 .map(|h| h.parse::<String>().expect("Unable to parse hostname"))
62 .unwrap_or_else(|| gethostname().into_string().expect("Unable to get hostname"));
63
64 let orc_addr = args.get_one::<String>("orchestrator").unwrap();
65 let fqdn = args.get_one::<String>("tls");
66 let client = Worker::connect(orc_addr.parse().unwrap(), fqdn)
67 .await
68 .expect("Unable to connect to orchestrator");
69
70 let mut worker = Worker {
72 grpc_client: client,
73 hostname,
74 is_active: Arc::new(Mutex::new(false)),
75 current_m_id: Arc::new(Mutex::new(0)),
76 outbound_tx: None,
77 abort_s: Arc::new(AtomicBool::new(false)),
78 };
79
80 worker.connect_to_server().await?;
81
82 Ok(worker)
83 }
84
85 async fn connect(
99 address: String,
100 fqdn: Option<&String>,
101 ) -> Result<ControllerClient<Channel>, Box<dyn Error>> {
102 let channel = if let Some(fqdn) = fqdn {
103 let addr = format!("https://{address}");
105
106 let pem = std::fs::read_to_string("tls/orchestrator.crt")
108 .expect("Unable to read CA certificate at ./tls/orchestrator.crt");
109 let ca = Certificate::from_pem(pem);
110 let tls = ClientTlsConfig::new().ca_certificate(ca).domain_name(fqdn);
112
113 let builder = Channel::from_shared(addr.to_owned()).expect("Unable to set address"); builder
115 .keep_alive_timeout(Duration::from_secs(30))
116 .http2_keep_alive_interval(Duration::from_secs(15))
117 .tcp_keepalive(Some(Duration::from_secs(60)))
118 .tls_config(tls)
119 .expect("Unable to set TLS configuration")
120 .connect()
121 .await
122 .expect("Unable to connect to orchestrator")
123 } else {
124 let addr = format!("http://{address}");
126
127 Channel::from_shared(addr.to_owned())
128 .expect("Unable to set address")
129 .keep_alive_timeout(Duration::from_secs(30))
130 .http2_keep_alive_interval(Duration::from_secs(15))
131 .tcp_keepalive(Some(Duration::from_secs(60)))
132 .connect()
133 .await
134 .expect("Unable to connect to orchestrator")
135 };
136 let client = ControllerClient::new(channel);
138
139 Ok(client)
140 }
141
142 fn init(&mut self, task: Task, worker_id: u16, abort_s: Option<Arc<AtomicBool>>) {
156 let start_measurement = if let Data::Start(start) = task.data.unwrap() {
157 start
158 } else {
159 panic!("Received non-start packet for init")
160 };
161 let m_id = start_measurement.m_id;
162 let is_ipv6 = start_measurement.is_ipv6;
163 let mut rx_origins: Vec<Origin> = start_measurement.rx_origins;
164 let is_unicast = start_measurement.is_unicast;
165 let is_probing = !start_measurement.tx_origins.is_empty();
166 let qname = start_measurement.record;
167 let info_url = start_measurement.url;
168 let probing_rate = start_measurement.rate;
169 let is_latency = start_measurement.is_latency;
170
171 let outbound_rx = if is_probing {
173 let (tx, rx) = tokio::sync::mpsc::channel(1000);
174 self.outbound_tx = Some(tx);
175 Some(rx)
176 } else {
177 None
178 };
179
180 let tx_origins: Vec<Origin> = if !is_probing {
181 vec![]
182 } else if is_unicast {
183 let sport = start_measurement.tx_origins[0].sport;
185 let dport = start_measurement.tx_origins[0].dport;
186
187 let unicast_ip = Address::from(if is_ipv6 {
189 local_ipv6().expect("Unable to get local unicast IPv6 address")
190 } else {
191 local_ip().expect("Unable to get local unicast IPv4 address")
192 });
193
194 let unicast_origin = Origin {
195 src: Some(unicast_ip), sport, dport, origin_id: u32::MAX, };
200
201 rx_origins = vec![unicast_origin];
203
204 println!("[Worker] Using local unicast IP address: {unicast_ip}");
205 vec![unicast_origin]
207 } else {
208 start_measurement.tx_origins
210 };
211
212 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
214
215 let interfaces = datalink::interfaces();
217
218 let addr = rx_origins[0].src.unwrap().to_string();
220 let interface = if let Some(interface) = interfaces
221 .iter()
222 .find(|iface| iface.ips.iter().any(|ip| is_in_prefix(&addr, ip)))
223 {
224 println!(
225 "[Worker] Found interface: {}, for address {}",
226 interface.name, addr
227 );
228 interface.clone() } else {
230 let interface = interfaces
232 .into_iter()
233 .find(|iface| !iface.is_loopback())
234 .expect("Failed to find default interface");
235 println!(
236 "[Worker] No interface found for address: {}, using default interface {}",
237 addr, interface.name
238 );
239 interface
240 };
241
242 let config = datalink::Config {
244 write_buffer_size: 10 * 1024 * 1024, read_buffer_size: 10 * 1024 * 1024, ..Default::default()
247 };
248 let (socket_tx, socket_rx) = match datalink::channel(&interface, config) {
249 Ok(SocketChannel::Ethernet(socket_tx, socket_rx)) => (socket_tx, socket_rx),
250 Ok(_) => panic!("Unsupported channel type"),
251 Err(e) => panic!("Failed to create datalink channel: {e}"),
252 };
253
254 if !is_unicast || is_probing {
256 let config = InboundConfig {
257 m_id,
258 worker_id,
259 is_ipv6,
260 m_type: start_measurement.m_type as u8,
261 origin_map: rx_origins,
262 abort_s: self.abort_s.clone(),
263 };
264
265 inbound(config, tx, socket_rx);
266 }
267
268 if is_probing {
269 match start_measurement.m_type as u8 {
270 ICMP_ID => {
271 for origin in tx_origins.iter() {
273 println!(
274 "[Worker] Sending on address: {} using ICMP identifier {}",
275 origin.src.unwrap(),
276 origin.dport
277 );
278 }
279 }
280 A_ID | TCP_ID | CHAOS_ID | ALL_ID => {
281 for origin in tx_origins.iter() {
283 println!(
284 "[Worker] Sending on address: {}, from src port {}, to dst port {}",
285 origin.src.unwrap(),
286 origin.sport,
287 origin.dport
288 );
289 }
290 }
291 _ => (),
292 }
293
294 let config = OutboundConfig {
295 worker_id,
296 tx_origins,
297 abort_s: abort_s.unwrap(),
298 is_ipv6,
299 is_symmetric: is_latency || is_unicast,
300 m_id,
301 m_type: start_measurement.m_type as u8,
302 qname,
303 info_url,
304 if_name: interface.name,
305 probing_rate,
306 };
307
308 outbound(config, outbound_rx.unwrap(), socket_tx);
310 } else {
311 println!("[Worker] Not sending probes");
312 }
313
314 let mut self_clone = self.clone();
315 thread::Builder::new()
317 .name("forwarder_thread".to_string())
318 .spawn(move || {
319 let rt = tokio::runtime::Runtime::new().unwrap();
320 let _enter = rt.enter();
321
322 rt.block_on(async {
323 while let Some(packet) = rx.recv().await {
325 if packet == TaskResult::default() {
327 self_clone
328 .measurement_finish_to_server(Finished {
329 m_id,
330 worker_id: worker_id.into(),
331 })
332 .await
333 .unwrap();
334
335 break;
336 }
337
338 self_clone
339 .send_result_to_server(packet)
340 .await
341 .expect("Unable to send task result to orchestrator");
342 }
343 rx.close();
344 });
345 })
346 .expect("Unable to start forwarder thread");
347 }
348
349 async fn connect_to_server(&mut self) -> Result<(), Box<dyn Error>> {
353 println!("[Worker] Connecting to orchestrator");
354 let mut abort_s: Option<Arc<AtomicBool>> = None;
355
356 let unicast_v6 = local_ipv6().ok().map(Address::from);
358 let unicast_v4 = local_ip().ok().map(Address::from);
359
360 let worker = custom_module::manycastr::Worker {
361 hostname: self.hostname.clone(),
362 worker_id: 0, status: "".to_string(),
364 unicast_v6,
365 unicast_v4,
366 };
367
368 let response = self
370 .grpc_client
371 .worker_connect(Request::new(worker))
372 .await
373 .expect("Unable to connect to orchestrator");
374
375 let mut stream = response.into_inner();
376 let id_message = stream
378 .message()
379 .await
380 .expect("Unable to await stream")
381 .expect("Unable to receive worker ID");
382 let worker_id = id_message.worker_id.expect("No initial worker ID set") as u16;
383 println!(
384 "[Worker] Successfully connected with the orchestrator with worker_id: {worker_id}"
385 );
386
387 while let Some(task) = stream.message().await.expect("Unable to receive task") {
389 if *self.is_active.lock().unwrap() {
391 match task.data {
393 None => {
394 println!("[Worker] Received empty task, skipping");
395 continue;
396 }
397 Some(Data::Start(_)) => {
398 println!("[Worker] Received new measurement during an active measurement, skipping");
399 continue;
400 }
401 Some(Data::End(data)) => {
402 if data.code == 0 {
404 println!(
405 "[Worker] Received measurement finished signal from orchestrator"
406 );
407 self.abort_s.store(true, Ordering::SeqCst);
409 if let Some(tx) = self.outbound_tx.take() {
411 tx.send(Data::End(End { code: 0 })).await.expect(
412 "Unable to send measurement_finished to outbound thread",
413 );
414 }
415 } else if data.code == 1 {
416 println!("[Worker] CLI disconnected, aborting measurement");
417
418 self.abort_s.store(true, Ordering::SeqCst);
420 if let Some(abort_s) = &abort_s {
422 abort_s.store(true, Ordering::SeqCst);
424 }
425 } else {
426 println!("[Worker] Received invalid code from orchestrator");
427 continue;
428 }
429 }
430 Some(task) => {
431 if let Some(outbound_tx) = &self.outbound_tx {
433 outbound_tx
435 .send(task)
436 .await
437 .expect("Unable to send task to outbound thread");
438 }
439 }
440 };
441
442 } else {
444 let (is_unicast, is_probing, m_id) =
445 match task.clone().data.expect("None start measurement task") {
446 Data::Start(start) => {
447 (start.is_unicast, !start.tx_origins.is_empty(), start.m_id)
448 }
449 _ => {
450 continue;
452 }
453 };
454
455 if is_unicast && !is_probing {
457 println!("[Worker] Not probing for unicast measurement, skipping");
458 continue;
459 }
460
461 println!("[Worker] Starting new measurement");
462
463 *self.is_active.lock().unwrap() = true;
464 *self.current_m_id.lock().unwrap() = m_id;
465 self.abort_s.store(false, Ordering::SeqCst);
466
467 if is_probing {
468 abort_s = Some(Arc::new(AtomicBool::new(false)));
471
472 self.init(task, worker_id, abort_s.clone());
473 } else {
474 abort_s = None;
476 self.outbound_tx = None;
477 self.init(task, worker_id, None);
478 }
479 }
480 }
481 println!("[Worker] Stopped awaiting tasks");
482
483 Ok(())
484 }
485
486 async fn send_result_to_server(
488 &mut self,
489 task_result: TaskResult,
490 ) -> Result<(), Box<dyn Error>> {
491 self.grpc_client
492 .send_result(Request::new(task_result))
493 .await?;
494
495 Ok(())
496 }
497
498 async fn measurement_finish_to_server(
506 &mut self,
507 finished: Finished,
508 ) -> Result<(), Box<dyn Error>> {
509 println!(
510 "[Worker] Letting the orchestrator know that this worker finished the measurement"
511 );
512 *self.is_active.lock().unwrap() = false;
513 self.grpc_client
514 .measurement_finished(Request::new(finished))
515 .await?;
516
517 Ok(())
518 }
519}