use std::cmp::PartialEq;
use std::collections::{HashMap, HashSet, VecDeque};
use std::fmt;
use std::fs;
use std::net::SocketAddr;
use std::ops::AddAssign;
use std::path::Path;
use std::pin::Pin;
use std::sync::atomic::AtomicBool;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use std::time::Duration;

use crate::custom_module;
use crate::custom_module::manycastr::Address;
use crate::orchestrator::mpsc::Sender;
use crate::orchestrator::WorkerStatus::{Disconnected, Idle, Listening, Probing};
use clap::ArgMatches;
use custom_module::manycastr::{
    controller_server::Controller, controller_server::ControllerServer, task::Data::End as TaskEnd,
    task::Data::Start as TaskStart, Ack, Empty, End, Finished, ScheduleMeasurement, Start,
    Status as ServerStatus, Targets, Task, TaskResult, Worker,
};
use futures_core::Stream;
use rand::Rng;
use tokio::spawn;
use tokio::sync::mpsc;
use tokio::time::Instant;
use tonic::codec::CompressionEncoding;
use tonic::transport::{Identity, ServerTlsConfig};
use tonic::{transport::Server, Request, Response, Status};

type ResultMessage = Result<TaskResult, Status>;
type CliSender = Sender<ResultMessage>;
type CliHandle = Arc<Mutex<Option<CliSender>>>;

type TaskMessage = Result<Task, Status>;

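/// Shared state behind the orchestrator's gRPC `Controller` implementation.
///
/// Connected workers and the CLI result channel live behind `Arc<Mutex<..>>` so
/// the asynchronous tonic handlers can access them concurrently;
/// `open_measurements` maps a measurement ID to the number of workers that
/// still have to finish it, and `worker_stacks` queues follow-up targets per
/// worker for discovery-based measurements.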
#[derive(Debug)]
pub struct ControllerService {
    workers: Arc<Mutex<Vec<WorkerSender<TaskMessage>>>>,
    cli_sender: CliHandle,
    open_measurements: Arc<Mutex<HashMap<u32, u32>>>,
    m_id: Arc<Mutex<u32>>,
    unique_id: Arc<Mutex<u32>>,
    is_active: Arc<Mutex<bool>>,
    is_responsive: Arc<AtomicBool>,
    is_latency: Arc<AtomicBool>,
    worker_config: Option<HashMap<String, u32>>,
    worker_stacks: Arc<Mutex<HashMap<u32, VecDeque<Address>>>>,
}

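// Sentinel worker IDs used on the task-distribution channel (see `task_sender`):
// `BREAK_SIGNAL` stops the distributor, `ALL_WORKERS_DIRECT` broadcasts a task to
// every worker immediately, and `ALL_WORKERS_INTERVAL` sends it to the probing
// workers, staggered by the configured worker interval.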
const BREAK_SIGNAL: u32 = u32::MAX - 1;
const ALL_WORKERS_DIRECT: u32 = u32::MAX;
const ALL_WORKERS_INTERVAL: u32 = u32::MAX - 2;

#[derive(Debug, Clone, Copy)]
pub enum WorkerStatus {
    Idle,
    Probing,
    Listening,
    Disconnected,
}

impl fmt::Display for WorkerStatus {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let s = match self {
            Idle => "IDLE",
            Probing => "PROBING",
            Listening => "LISTENING",
            Disconnected => "DISCONNECTED",
        };
        write!(f, "{s}")
    }
}

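/// Server-side streaming handle for a connected worker.
///
/// Dropping the receiver (the worker's task stream closing) marks the worker as
/// `Disconnected` and decrements the open-measurement counters; if it was the
/// last worker of a measurement, a finished signal is pushed to the CLI.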
pub struct WorkerReceiver<T> {
    inner: mpsc::Receiver<T>,
    open_measurements: Arc<Mutex<HashMap<u32, u32>>>,
    cli_sender: CliHandle,
    hostname: String,
    status: Arc<Mutex<WorkerStatus>>,
}

impl<T> Stream for WorkerReceiver<T> {
    type Item = T;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> {
        self.inner.poll_recv(cx)
    }
}

impl<T> Drop for WorkerReceiver<T> {
    fn drop(&mut self) {
        println!("[Orchestrator] Worker {} lost connection", self.hostname);

        let mut open_measurements = self.open_measurements.lock().unwrap();
        if !open_measurements.is_empty() {
            for (m_id, remaining) in open_measurements.clone().iter() {
                if remaining == &0 {
                    continue;
                }
                if remaining == &1 {
                    open_measurements.remove(m_id);

                    println!("[Orchestrator] The last worker for a measurement dropped, sending measurement finished signal to CLI");
                    match self
                        .cli_sender
                        .lock()
                        .unwrap()
                        .clone()
                        .unwrap()
                        .try_send(Ok(TaskResult::default()))
                    {
                        Ok(_) => (),
                        Err(_) => println!(
                            "[Orchestrator] Failed to send measurement finished signal to CLI"
                        ),
                    }
                } else {
                    *open_measurements.get_mut(m_id).unwrap() -= 1;
                }
            }
        }

        let mut status = self.status.lock().unwrap();
        *status = Disconnected;
    }
}

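/// Sending half used by the orchestrator to push tasks to a single worker,
/// together with the worker's identity, advertised unicast addresses, and a
/// shared status flag that is updated when sends fail or a measurement ends.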
#[derive(Clone)]
pub struct WorkerSender<T> {
    inner: Sender<T>,
    worker_id: u32,
    hostname: String,
    status: Arc<Mutex<WorkerStatus>>,
    unicast_v4: Option<Address>,
    unicast_v6: Option<Address>,
}

impl<T> WorkerSender<T> {
    pub fn is_closed(&self) -> bool {
        self.inner.is_closed()
    }

    pub async fn send(&self, task: T) -> Result<(), mpsc::error::SendError<T>> {
        match self.inner.send(task).await {
            Ok(_) => Ok(()),
            Err(e) => {
                self.cleanup();
                Err(e)
            }
        }
    }

    fn cleanup(&self) {
        let mut status = self.status.lock().unwrap();
        *status = Disconnected;

        println!(
            "[Orchestrator] Worker {} (ID: {}) dropped",
            self.hostname, self.worker_id
        );
    }

    pub fn is_probing(&self) -> bool {
        *self.status.lock().unwrap() == Probing
    }

    pub fn get_status(&self) -> String {
        self.status.lock().unwrap().clone().to_string()
    }

    pub fn finished(&self) {
        let mut status = self.status.lock().unwrap();
        if *status != Disconnected {
            *status = Idle;
        }
    }
}

impl<T> fmt::Debug for WorkerSender<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "WorkerSender {{ worker_id: {}, hostname: {}, status: {} }}",
            self.worker_id,
            self.hostname,
            self.get_status()
        )
    }
}

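/// Result stream handed to the CLI for an ongoing measurement.
///
/// Dropping it (the CLI disconnecting) marks the measurement as no longer
/// active and removes it from the open-measurement map, which causes the task
/// distributors to stop.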
pub struct CLIReceiver<T> {
    inner: mpsc::Receiver<T>,
    m_active: Arc<Mutex<bool>>,
    m_id: u32,
    open_measurements: Arc<Mutex<HashMap<u32, u32>>>,
}

impl<T> Stream for CLIReceiver<T> {
    type Item = T;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> {
        self.inner.poll_recv(cx)
    }
}

impl<T> Drop for CLIReceiver<T> {
    fn drop(&mut self) {
        let mut is_active = self.m_active.lock().unwrap();

        if *is_active {
            println!(
                "[Orchestrator] CLI dropped during an active measurement, terminating measurement"
            );
        }
        *is_active = false;

        let mut open_measurements = self.open_measurements.lock().unwrap();
        open_measurements.remove(&self.m_id);
    }
}

impl PartialEq for WorkerStatus {
    fn eq(&self, other: &Self) -> bool {
        matches!(
            (self, other),
            (Idle, Idle)
                | (Probing, Probing)
                | (Listening, Listening)
                | (Disconnected, Disconnected)
        )
    }
}

#[tonic::async_trait]
impl Controller for ControllerService {
    async fn measurement_finished(
        &self,
        request: Request<Finished>,
    ) -> Result<Response<Ack>, Status> {
        let finished_measurement = request.into_inner();
        let m_id: u32 = finished_measurement.m_id;
        let tx = self.cli_sender.lock().unwrap().clone().unwrap();

        let is_finished = {
            let mut open_measurements = self.open_measurements.lock().unwrap();

            let remaining = if let Some(remaining) = open_measurements.get(&m_id) {
                remaining
            } else {
                println!("[Orchestrator] Received measurement finished signal for non-existent measurement {}", &m_id);
                return Ok(Response::new(Ack {
                    is_success: false,
                    error_message: "Measurement unknown".to_string(),
                }));
            };

            if remaining == &1 {
                println!("[Orchestrator] All workers finished");

                open_measurements.remove(&m_id);
                true
            } else {
                *open_measurements.get_mut(&m_id).unwrap() -= 1;

                false
            }
        };

        if is_finished {
            println!("[Orchestrator] Notifying CLI that the measurement is finished");
            *self.is_active.lock().unwrap() = false;

            match tx.send(Ok(TaskResult::default())).await {
                Ok(_) => Ok(Response::new(Ack {
                    is_success: true,
                    error_message: "".to_string(),
                })),
                Err(_) => Ok(Response::new(Ack {
                    is_success: false,
                    error_message: "CLI disconnected".to_string(),
                })),
            }
        } else {
            Ok(Response::new(Ack {
                is_success: true,
                error_message: "".to_string(),
            }))
        }
    }

    type WorkerConnectStream = WorkerReceiver<Result<Task, Status>>;

    async fn worker_connect(
        &self,
        request: Request<Worker>,
    ) -> Result<Response<Self::WorkerConnectStream>, Status> {
        let worker = request.into_inner();
        let hostname = worker.hostname;
        let unicast_v4 = worker.unicast_v4;
        let unicast_v6 = worker.unicast_v6;
        let (tx, rx) = mpsc::channel::<Result<Task, Status>>(1000);
        let (worker_id, is_reconnect) = self
            .get_worker_id(&hostname)
            .map_err(|boxed_status| *boxed_status)?;

        if is_reconnect {
            println!("[Orchestrator] Reconnecting worker: {hostname}");
        } else {
            println!("[Orchestrator] New worker connected: {hostname}");
        }

        tx.send(Ok(Task {
            worker_id: Some(worker_id),
            data: None,
        }))
        .await
        .expect("Failed to send worker ID");

        let worker_status = Arc::new(Mutex::new(Idle));

        let worker_tx = WorkerSender {
            inner: tx,
            worker_id,
            hostname: hostname.clone(),
            status: worker_status.clone(),
            unicast_v4,
            unicast_v6,
        };

        if is_reconnect {
            let mut senders = self.workers.lock().unwrap();
            senders.retain(|sender| sender.worker_id != worker_id);
        }

        self.workers.lock().unwrap().push(worker_tx);

        let worker_rx = WorkerReceiver {
            inner: rx,
            open_measurements: self.open_measurements.clone(),
            cli_sender: self.cli_sender.clone(),
            hostname,
            status: worker_status,
        };

        Ok(Response::new(worker_rx))
    }

    type DoMeasurementStream = CLIReceiver<Result<TaskResult, Status>>;

    async fn do_measurement(
        &self,
        request: Request<ScheduleMeasurement>,
    ) -> Result<Response<Self::DoMeasurementStream>, Status> {
        println!("[Orchestrator] Received CLI measurement request");

        {
            let mut active = self.is_active.lock().unwrap();
            if *active {
                println!("[Orchestrator] There is already an active measurement, returning");
                return Err(Status::new(
                    tonic::Code::Cancelled,
                    "There is already an active measurement",
                ));
            }

            for (_, open) in self
                .open_measurements
                .lock()
                .expect("No open measurements map")
                .iter()
            {
                if open > &0 {
                    println!("[Orchestrator] There is already an active measurement, returning");
                    return Err(Status::new(
                        tonic::Code::Cancelled,
                        "There are still workers working on an active measurement",
                    ));
                }
            }

            *active = true;
        }

        let m_definition = request.into_inner();
        let is_responsive = m_definition.is_responsive;
        let is_latency = m_definition.is_latency;
        let is_divide = m_definition.is_divide;
        let worker_interval = m_definition.worker_interval as u64;
        let probe_interval = m_definition.probe_interval as u64;
        let number_of_probes = m_definition.number_of_probes as u8;

        let workers: Vec<WorkerSender<Result<Task, Status>>> = {
            let mut workers = self.workers.lock().unwrap().clone();
            workers.retain(|worker| {
                if *worker.status.lock().unwrap() == Disconnected {
                    println!("[Orchestrator] Worker {} unavailable.", worker.hostname);
                    false
                } else {
                    true
                }
            });

            if m_definition.configurations.iter().any(|conf| {
                !workers
                    .iter()
                    .any(|sender| sender.worker_id == conf.worker_id)
                    && conf.worker_id != u32::MAX
            }) {
                println!(
                    "[Orchestrator] Unknown worker in configuration list, terminating measurement."
                );
                *self.is_active.lock().unwrap() = false;
                return Err(Status::new(
                    tonic::Code::Cancelled,
                    "Unknown worker in configuration",
                ));
            }

            for worker in workers.iter_mut() {
                let is_probing = m_definition.configurations.iter().any(|config| {
                    config.worker_id == worker.worker_id || config.worker_id == u32::MAX
                });

                if is_probing {
                    *worker.status.lock().unwrap() = Probing;
                } else {
                    *worker.status.lock().unwrap() = Listening;
                }
            }

            workers
        };

        if workers.is_empty() {
            println!("[Orchestrator] No connected workers, terminating measurement.");
            *self.is_active.lock().unwrap() = false;
            return Err(Status::new(tonic::Code::Cancelled, "No connected workers"));
        }

        let m_id = {
            let mut uniq_m_id = self.m_id.lock().unwrap();
            let id = *uniq_m_id;
            *uniq_m_id = uniq_m_id.wrapping_add(1);
            id
        };

        let is_unicast = m_definition.is_unicast;

        let number_of_workers = workers.len() as u32;
        let number_of_probing_workers = workers.iter().filter(|sender| sender.is_probing()).count();

        let number_of_listeners = if is_unicast {
            number_of_probing_workers as u32
        } else {
            number_of_workers
        };

        self.open_measurements
            .lock()
            .unwrap()
            .insert(m_id, number_of_listeners);

        let probing_rate = m_definition.probing_rate;
        let m_type = m_definition.m_type;
        let is_ipv6 = m_definition.is_ipv6;
        let dst_addresses = m_definition
            .targets
            .expect("Received measurement with no targets")
            .dst_list;
        let dns_record = m_definition.record;
        let info_url = m_definition.url;

        println!("[Orchestrator] {number_of_probing_workers} workers will probe, {number_of_workers} will listen ({worker_interval} seconds between probing workers)");

        let (tx, rx) = mpsc::channel::<Result<TaskResult, Status>>(1000);
        let _ = self.cli_sender.lock().unwrap().insert(tx);

        let mut rx_origins = vec![];
        for configuration in m_definition.configurations.iter() {
            if let Some(origin) = &configuration.origin {
                if !rx_origins.contains(origin) {
                    rx_origins.push(*origin);
                }
            }
        }

        let (tx_t, rx_t) = mpsc::channel::<(u32, Task, bool)>(1000);

        let probing_worker_ids = workers
            .iter()
            .filter(|sender| sender.is_probing())
            .map(|sender| sender.worker_id)
            .collect::<Vec<u32>>();

        for worker in workers.iter() {
            let worker_id = worker.worker_id;
            let mut tx_origins = vec![];

            for configuration in &m_definition.configurations {
                if (configuration.worker_id == worker_id) || (configuration.worker_id == u32::MAX) {
                    if let Some(origin) = &configuration.origin {
                        tx_origins.push(*origin);
                    }
                }
            }

            let start_task = Task {
                worker_id: None,
                data: Some(TaskStart(Start {
                    rate: probing_rate,
                    m_id,
                    m_type,
                    is_unicast,
                    is_ipv6,
                    tx_origins,
                    rx_origins: rx_origins.clone(),
                    record: dns_record.clone(),
                    url: info_url.clone(),
                    is_latency,
                })),
            };

            tx_t.send((worker_id, start_task, false))
                .await
                .expect("Failed to send task to TaskDistributor");
        }

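        // All Start tasks are queued; the distributor spawned below relays them
        // (and every subsequent target task) to the workers.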
        spawn(async move {
            task_sender(
                rx_t,
                workers,
                worker_interval,
                probe_interval,
                number_of_probes,
            )
            .await;
        });

        tokio::time::sleep(Duration::from_secs(1)).await;

        self.is_responsive
            .store(is_responsive, std::sync::atomic::Ordering::SeqCst);
        println!("[Orchestrator] Responsive probing mode: {}", is_responsive);
        self.is_latency
            .store(is_latency, std::sync::atomic::Ordering::SeqCst);

        let mut probing_rate_interval = if is_latency || is_divide {
            tokio::time::interval(Duration::from_secs(1) / number_of_probing_workers as u32)
        } else {
            tokio::time::interval(Duration::from_secs(1))
        };

        let is_active = self.is_active.clone();

        let is_discovery = if is_responsive || is_latency {
            Some(true)
        } else {
            None
        };

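        // Divide, responsive, and latency measurements use a round-robin distributor
        // that feeds each probing worker its own slice of targets (plus any follow-up
        // targets queued in `worker_stacks`); all other measurements broadcast the
        // hitlist to every probing worker in rate-sized chunks.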
        if is_divide || is_responsive || is_latency {
            println!("[Orchestrator] Starting Round-Robin Task Distributor.");

            let is_latency_signal = self.is_latency.clone();
            let is_responsive_signal = self.is_responsive.clone();
            let mut cooldown_timer: Option<Instant> = None;

            let worker_stacks = self.worker_stacks.clone();

            spawn(async move {
                let mut sender_cycler = probing_worker_ids.into_iter().cycle();
                let mut hitlist_iter = dst_addresses.into_iter();
                let mut hitlist_is_empty = false;

                loop {
                    if !(*is_active.lock().unwrap()) {
                        println!("[Orchestrator] CLI disconnected; ending measurement");
                        break;
                    }

                    let worker_id = sender_cycler.next().expect("No probing workers available");

                    let f_worker_id = if is_responsive {
                        ALL_WORKERS_INTERVAL
                    } else {
                        worker_id
                    };

                    let mut follow_ups: Vec<Address> = Vec::new();

                    {
                        let mut stacks = worker_stacks.lock().unwrap();
                        if let Some(queue) = stacks.get_mut(&f_worker_id) {
                            let num_to_take = std::cmp::min(probing_rate as usize, queue.len());

                            follow_ups.extend(queue.drain(..num_to_take));
                        }
                    }

                    let follow_ups_len = follow_ups.len();

                    if follow_ups_len > 0 {
                        let task = Task {
                            worker_id: None,
                            data: Some(custom_module::manycastr::task::Data::Targets(Targets {
                                dst_list: follow_ups,
                                is_discovery: None,
                            })),
                        };

                        tx_t.send((f_worker_id, task, true))
                            .await
                            .expect("Failed to send task to TaskDistributor");
                    }

                    let remainder_needed = if is_responsive {
                        probing_rate as usize
                    } else {
                        (probing_rate as usize).saturating_sub(follow_ups_len)
                    };

                    let hitlist_targets = if remainder_needed > 0 && !hitlist_is_empty {
                        let addresses_from_hitlist: Vec<Address> =
                            hitlist_iter.by_ref().take(remainder_needed).collect();

                        if (addresses_from_hitlist.len() < remainder_needed) && !hitlist_is_empty {
                            println!("[Orchestrator] All discovery probes sent, awaiting follow-up probes.");
                            hitlist_is_empty = true;
                        }

                        addresses_from_hitlist
                    } else {
                        Vec::new()
                    };

                    if !hitlist_targets.is_empty() {
                        let task = Task {
                            worker_id: None,
                            data: Some(custom_module::manycastr::task::Data::Targets(Targets {
                                dst_list: hitlist_targets,
                                is_discovery,
                            })),
                        };

                        tx_t.send((worker_id, task, is_discovery.is_none()))
                            .await
                            .expect("Failed to send task to TaskDistributor");
                    }

                    if hitlist_is_empty {
                        if let Some(start_time) = cooldown_timer {
                            if start_time.elapsed()
                                >= Duration::from_secs(
                                    number_of_probing_workers as u64 * worker_interval + 5,
                                )
                            {
                                println!("[Orchestrator] Task distribution finished.");
                                break;
                            }
                        } else {
                            let all_stacks_empty = {
                                let stacks_guard = worker_stacks.lock().unwrap();
                                stacks_guard.values().all(|queue| queue.is_empty())
                            };
                            if all_stacks_empty {
                                println!(
                                    "[Orchestrator] No more tasks. Waiting {} seconds for cooldown.",
                                    number_of_probing_workers as u64 * worker_interval + 5
                                );
                                cooldown_timer = Some(Instant::now());
                            }
                        }
                    }

                    probing_rate_interval.tick().await;
                }

                tx_t.send((
                    ALL_WORKERS_DIRECT,
                    Task {
                        worker_id: None,
                        data: Some(TaskEnd(End { code: 0 })),
                    },
                    false,
                ))
                .await
                .expect("Failed to send end task to TaskDistributor");

                while *is_active.lock().unwrap() {
                    tokio::time::sleep(Duration::from_secs(1)).await;
                }

                is_latency_signal.store(false, std::sync::atomic::Ordering::SeqCst);
                is_responsive_signal.store(false, std::sync::atomic::Ordering::SeqCst);

                {
                    let mut stacks_guard = worker_stacks.lock().unwrap();
                    *stacks_guard = HashMap::new();
                }

                tx_t.send((
                    BREAK_SIGNAL,
                    Task {
                        worker_id: None,
                        data: None,
                    },
                    false,
                ))
                .await
                .expect("Failed to send end task to TaskDistributor");
            });
        } else {
            println!("[Orchestrator] Starting Broadcast Task Distributor.");
            spawn(async move {
                for chunk in dst_addresses.chunks(probing_rate as usize) {
                    if !(*is_active.lock().unwrap()) {
                        println!("[Orchestrator] Measurement no longer active");
                        break;
                    }

                    tx_t.send((
                        ALL_WORKERS_INTERVAL,
                        Task {
                            worker_id: None,
                            data: Some(custom_module::manycastr::task::Data::Targets(Targets {
                                dst_list: chunk.to_vec(),
                                is_discovery,
                            })),
                        },
                        is_discovery.is_none(),
                    ))
                    .await
                    .expect("Failed to send task to TaskDistributor");

                    probing_rate_interval.tick().await;
                }

                tokio::time::sleep(Duration::from_secs(
                    (number_of_probing_workers as u64 * worker_interval) + 1,
                ))
                .await;

                println!("[Orchestrator] Task distribution finished");

                tx_t.send((
                    ALL_WORKERS_DIRECT,
                    Task {
                        worker_id: None,
                        data: Some(TaskEnd(End { code: 0 })),
                    },
                    false,
                ))
                .await
                .expect("Failed to send end task to TaskDistributor");

                while *is_active.lock().unwrap() {
                    tokio::time::sleep(Duration::from_secs(1)).await;
                }

                tx_t.send((
                    BREAK_SIGNAL,
                    Task {
                        worker_id: None,
                        data: None,
                    },
                    false,
                ))
                .await
                .expect("Failed to send end task to TaskDistributor");
            });
        }

        let rx = CLIReceiver {
            inner: rx,
            m_active: self.is_active.clone(),
            m_id,
            open_measurements: self.open_measurements.clone(),
        };

        Ok(Response::new(rx))
    }

    async fn list_workers(
        &self,
        _request: Request<Empty>,
    ) -> Result<Response<ServerStatus>, Status> {
        let workers_list = self.workers.lock().unwrap();
        let mut workers = Vec::new();
        for worker in workers_list.iter() {
            workers.push(Worker {
                worker_id: worker.worker_id,
                hostname: worker.hostname.clone(),
                status: worker.get_status(),
                unicast_v4: worker.unicast_v4,
                unicast_v6: worker.unicast_v6,
            });
        }

        let status = ServerStatus { workers };
        Ok(Response::new(status))
    }

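    /// Handles a result batch reported by a worker.
    ///
    /// Discovery results (while latency or responsive mode is active) are not
    /// forwarded to the CLI; their source addresses are pushed onto
    /// `worker_stacks` so the round-robin distributor can schedule follow-up
    /// probes. All other results are streamed to the CLI.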
    async fn send_result(&self, request: Request<TaskResult>) -> Result<Response<Ack>, Status> {
        let task_result = request.into_inner();

        let is_discovery = task_result.is_discovery;

        if is_discovery
            && (self.is_latency.load(std::sync::atomic::Ordering::SeqCst)
                || self.is_responsive.load(std::sync::atomic::Ordering::SeqCst))
        {
            let rx_id = if self.is_responsive.load(std::sync::atomic::Ordering::SeqCst) {
                ALL_WORKERS_INTERVAL
            } else {
                task_result.worker_id
            };

            let responsive_targets: Vec<Address> = task_result
                .result_list
                .iter()
                .map(|result| result.src.unwrap())
                .collect();

            tokio::time::sleep(Duration::from_secs(1)).await;

            {
                let mut worker_stacks = self.worker_stacks.lock().unwrap();
                worker_stacks
                    .entry(rx_id)
                    .or_default()
                    .extend(responsive_targets);
            }

            return Ok(Response::new(Ack {
                is_success: true,
                error_message: "".to_string(),
            }));
        }

        let tx = {
            let sender = self.cli_sender.lock().unwrap();
            sender.clone().unwrap()
        };

        match tx.send(Ok(task_result)).await {
            Ok(_) => Ok(Response::new(Ack {
                is_success: true,
                error_message: "".to_string(),
            })),
            Err(_) => Ok(Response::new(Ack {
                is_success: false,
                error_message: "CLI disconnected".to_string(),
            })),
        }
    }
}

impl ControllerService {
    fn get_unique_id(&self) -> u32 {
        let mut unique_id = self.unique_id.lock().unwrap();
        let worker_id = *unique_id;
        unique_id.add_assign(1);

        worker_id
    }

    fn get_worker_id(&self, hostname: &str) -> Result<(u32, bool), Box<Status>> {
        {
            let workers = self.workers.lock().unwrap();
            if let Some(existing_worker) = workers.iter().find(|w| w.hostname == hostname) {
                return if !existing_worker.is_closed() {
                    println!(
                        "[Orchestrator] Refusing worker as the hostname already exists: {hostname}"
                    );
                    Err(Box::new(Status::already_exists(
                        "This hostname already exists",
                    )))
                } else {
                    let id = existing_worker.worker_id;
                    Ok((id, true))
                };
            }
        }

        if let Some(worker_config) = &self.worker_config {
            if let Some(worker_id) = worker_config.get(hostname) {
                return Ok((*worker_id, false));
            }
        }

        let new_id = self.get_unique_id();
        Ok((new_id, false))
    }
}

impl PartialEq<WorkerStatus> for Mutex<WorkerStatus> {
    fn eq(&self, other: &WorkerStatus) -> bool {
        let status = self.lock().unwrap();
        *status == *other
    }
}

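/// Relays `(worker_id, task, multiple)` tuples from the task-distribution
/// channel to the connected workers.
///
/// `worker_id` selects a single worker unless it is one of the sentinel values:
/// `ALL_WORKERS_DIRECT` broadcasts immediately, `ALL_WORKERS_INTERVAL` staggers
/// probing workers by `inter_client_interval` seconds, and `BREAK_SIGNAL` stops
/// the distributor. When `multiple` is set, the task is repeated
/// `number_of_probes` times, `inter_probe_interval` seconds apart.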
async fn task_sender(
    mut rx: mpsc::Receiver<(u32, Task, bool)>,
    workers: Vec<WorkerSender<Result<Task, Status>>>,
    inter_client_interval: u64,
    inter_probe_interval: u64,
    number_of_probes: u8,
) {
    while let Some((worker_id, task, multiple)) = rx.recv().await {
        let nprobes = if multiple { number_of_probes } else { 1 };

        if worker_id == BREAK_SIGNAL {
            break;
        } else if worker_id == ALL_WORKERS_DIRECT {
            for sender in &workers {
                sender.send(Ok(task.clone())).await.unwrap_or_else(|e| {
                    sender.cleanup();
                    eprintln!(
                        "[Orchestrator] Failed to send broadcast task to worker {}: {:?}",
                        sender.hostname, e
                    );
                });
                sender.finished();
            }
        } else if worker_id == ALL_WORKERS_INTERVAL {
            let mut probing_index = 0;

            for sender in &workers {
                if *sender.status == Probing {
                    let sender_c = sender.clone();
                    let task_c = task.clone();
                    spawn(async move {
                        tokio::time::sleep(Duration::from_secs(
                            probing_index * inter_client_interval,
                        ))
                        .await;

                        spawn(async move {
                            for _ in 0..nprobes {
                                sender_c.send(Ok(task_c.clone())).await.unwrap_or_else(|e| {
                                    sender_c.cleanup();
                                    eprintln!(
                                        "[Orchestrator] Failed to send broadcast task to probing worker {}: {:?}",
                                        sender_c.hostname, e
                                    );
                                });
                                tokio::time::sleep(Duration::from_secs(inter_probe_interval)).await;
                            }
                        });
                    });
                    probing_index += 1;
                }
            }
        } else if let Some(sender) = workers.iter().find(|s| s.worker_id == worker_id) {
            if nprobes < 2 {
                sender.send(Ok(task)).await.unwrap_or_else(|e| {
                    sender.cleanup();
                    eprintln!(
                        "[Orchestrator] Failed to send task to worker {}: {:?}",
                        sender.hostname, e
                    );
                });
            } else {
                let sender_clone = sender.clone();
                spawn(async move {
                    for _ in 0..number_of_probes {
                        sender_clone
                            .send(Ok(task.clone()))
                            .await
                            .unwrap_or_else(|e| {
                                sender_clone.cleanup();
                                eprintln!(
                                    "[Orchestrator] Failed to send task to worker {}: {:?}",
                                    sender_clone.hostname, e
                                );
                            });
                        tokio::time::sleep(Duration::from_secs(inter_probe_interval)).await;
                    }
                });
            }
        } else {
            eprintln!("[Orchestrator] No sender found for worker ID {worker_id}");
        }
    }

    println!("[Orchestrator] Task distributor finished");
}

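/// Starts the orchestrator's gRPC server.
///
/// Reads the listening `port`, an optional worker `config` file path, and the
/// `tls` flag from the parsed CLI arguments, then serves the `ControllerServer`
/// (with zstd request decompression and keep-alive settings) on `[::]:port`,
/// with or without TLS.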
pub async fn start(args: &ArgMatches) -> Result<(), Box<dyn std::error::Error>> {
    let port = *args.get_one::<u16>("port").unwrap();
    let addr: SocketAddr = format!("[::]:{port}").parse().unwrap();

    let (current_worker_id, worker_config) = args
        .get_one::<String>("config")
        .map(load_worker_config)
        .unwrap_or_else(|| (Arc::new(Mutex::new(1)), None));

    let m_id = rand::rng().random_range(0..u32::MAX);

    let controller = ControllerService {
        workers: Arc::new(Mutex::new(Vec::new())),
        cli_sender: Arc::new(Mutex::new(None)),
        open_measurements: Arc::new(Mutex::new(HashMap::new())),
        m_id: Arc::new(Mutex::new(m_id)),
        unique_id: current_worker_id,
        is_active: Arc::new(Mutex::new(false)),
        is_responsive: Arc::new(AtomicBool::new(false)),
        is_latency: Arc::new(AtomicBool::new(false)),
        worker_config,
        worker_stacks: Arc::new(Mutex::new(HashMap::new())),
    };

    let svc = ControllerServer::new(controller)
        .accept_compressed(CompressionEncoding::Zstd)
        .max_decoding_message_size(10 * 1024 * 1024 * 1024)
        .max_encoding_message_size(10 * 1024 * 1024 * 1024);

    if args.get_flag("tls") {
        println!("[Orchestrator] Starting orchestrator with TLS enabled");
        Server::builder()
            .tls_config(ServerTlsConfig::new().identity(load_tls()))
            .expect("Failed to load TLS certificate")
            .http2_keepalive_interval(Some(Duration::from_secs(10)))
            .http2_keepalive_timeout(Some(Duration::from_secs(20)))
            .tcp_keepalive(Some(Duration::from_secs(30)))
            .add_service(svc)
            .serve(addr)
            .await
            .expect("Failed to start orchestrator with TLS");
    } else {
        Server::builder()
            .http2_keepalive_interval(Some(Duration::from_secs(10)))
            .http2_keepalive_timeout(Some(Duration::from_secs(20)))
            .tcp_keepalive(Some(Duration::from_secs(30)))
            .add_service(svc)
            .serve(addr)
            .await
            .expect("Failed to start orchestrator");
    }

    Ok(())
}

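/// Loads the static hostname-to-worker-ID mapping from a configuration file.
///
/// Each non-empty, non-`#` line must be of the form `hostname,id` (a sketch of
/// the accepted format, derived from the parser below):
///
/// ```text
/// # anycast site workers
/// worker-a,1
/// worker-b,2
/// ```
///
/// Hostnames and IDs must be unique, and the reserved sentinel IDs are rejected.
/// Returns the next free worker ID (max assigned ID + 1) and the mapping.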
fn load_worker_config(config_path: &String) -> (Arc<Mutex<u32>>, Option<HashMap<String, u32>>) {
    if !Path::new(config_path).exists() {
        panic!("[Orchestrator] Configuration file {config_path} not found!");
    }

    let config_content = fs::read_to_string(config_path)
        .expect("[Orchestrator] Could not read the configuration file.");

    let mut hosts = HashMap::new();
    let mut used_ids = HashSet::new();

    for (i, line) in config_content.lines().enumerate() {
        let line_number = i + 1;

        let trimmed_line = line.trim();

        if trimmed_line.is_empty() || trimmed_line.starts_with('#') {
            continue;
        }

        let parts: Vec<&str> = trimmed_line.split(',').collect();
        if parts.len() != 2 {
            panic!(
                "[Orchestrator] Error on line {line_number}: Malformed entry. Expected 'hostname,id', found '{line}'"
            );
        }

        let hostname = parts[0].trim().to_string();
        let id = match parts[1].trim().parse::<u32>() {
            Ok(val) => val,
            Err(_) => {
                panic!(
                    "[Orchestrator] Error on line {}: Invalid ID '{}'. ID must be a non-negative integer.",
                    line_number, parts[1].trim()
                );
            }
        };

        if hosts.contains_key(&hostname) {
            panic!(
                "[Orchestrator] Error on line {line_number}: Duplicate hostname '{hostname}' found. Hostnames must be unique."
            );
        }

        if !used_ids.insert(id) {
            panic!(
                "[Orchestrator] Error on line {line_number}: Duplicate ID '{id}' found. IDs must be unique."
            );
        }

        if id == ALL_WORKERS_INTERVAL || id == ALL_WORKERS_DIRECT || id == BREAK_SIGNAL {
            panic!(
                "[Orchestrator] Error on line {line_number}: ID '{id}' is reserved for special purposes. Please use a different ID."
            );
        }

        hosts.insert(hostname, id);
    }

    println!("[Orchestrator] {} hosts loaded.", hosts.len());

    let current_worker_id = hosts.values().max().map_or(1, |&max_id| max_id + 1);

    (Arc::new(Mutex::new(current_worker_id)), Some(hosts))
}

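/// Loads the orchestrator's TLS identity from `tls/orchestrator.crt` and
/// `tls/orchestrator.key`, relative to the working directory; panics if either
/// file cannot be read.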
fn load_tls() -> Identity {
    let cert = fs::read("tls/orchestrator.crt")
        .expect("Unable to read certificate file at ./tls/orchestrator.crt");
    let key = fs::read("tls/orchestrator.key")
        .expect("Unable to read key file at ./tls/orchestrator.key");

    Identity::from_pem(cert, key)
}
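
#[cfg(test)]
mod tests {
    use super::*;

    // Minimal sketch of a check for `load_worker_config`; the file name is
    // hypothetical and a writable temp directory is assumed.
    #[test]
    fn load_worker_config_parses_hostname_id_pairs() {
        let path = std::env::temp_dir().join("orchestrator_worker_config_example.txt");
        std::fs::write(&path, "# comment line\nworker-a,1\n\nworker-b,2\n").unwrap();

        let (next_id, hosts) = load_worker_config(&path.to_string_lossy().to_string());
        let hosts = hosts.expect("a config file should yield a hostname map");

        // IDs are taken verbatim from the file; the next free ID is max + 1.
        assert_eq!(hosts.get("worker-a"), Some(&1));
        assert_eq!(hosts.get("worker-b"), Some(&2));
        assert_eq!(*next_id.lock().unwrap(), 3);

        let _ = std::fs::remove_file(&path);
    }
}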