manycast/cli/
writer.rs

1use std::fs::File;
2use std::io;
3use std::io::Write;
4
5use bimap::BiHashMap;
6use csv::Writer;
7use tokio::sync::mpsc::UnboundedReceiver;
8
9use custom_module::manycastr::{Configuration, Reply, TaskResult};
10use custom_module::Separated;
11use flate2::write::GzEncoder;
12use flate2::Compression;
13use std::io::BufWriter;
14use std::sync::Arc;
15
16use parquet::basic::{Compression as ParquetCompression, LogicalType, Repetition};
17use parquet::data_type::{ByteArray, DoubleType, Int32Type, Int64Type};
18use parquet::file::properties::WriterProperties;
19use parquet::file::writer::SerializedFileWriter;
20use parquet::schema::types::{Type as SchemaType, TypePtr};
21
22use crate::{custom_module, CHAOS_ID, TCP_ID};
23
24/// Configuration for the results writing process.
25///
26/// This struct bundles all the necessary parameters for `write_results`
27/// to determine where and how to output measurement results,
28/// including formatting options and contextual metadata.
29pub struct WriteConfig<'a> {
30    /// Determines whether the results should also be printed to the command-line interface.
31    pub print_to_cli: bool,
32    /// The file handle to which the measurement results should be written.
33    pub output_file: File,
34    /// Metadata for the measurement, to be written at the beginning of the output file.
35    pub metadata_args: MetadataArgs<'a>,
36    /// The type of measurement being performed, influencing how results are processed or formatted.
37    /// (e.g., 1 for ICMP, 2 for DNS/A, 3 for TCP, 4 for DNS/CHAOS, etc.)
38    pub m_type: u32,
39    /// Indicates whether the measurement involves multiple origins, which affects
40    /// how results are written.
41    pub is_multi_origin: bool,
42    /// Indicates whether the measurement is symmetric (e.g., sender == receiver is always true),
43    /// to simplify certain result interpretations.
44    pub is_symmetric: bool,
45    /// A bidirectional map used to convert worker IDs (u32) to their corresponding hostnames (String).
46    pub worker_map: BiHashMap<u32, String>,
47}
48
49/// Holds all the arguments required to metadata for the output file.
50pub struct MetadataArgs<'a> {
51    /// Divide-and-conquer measurement flag.
52    pub is_divide: bool,
53    /// The origins used in the measurement.
54    pub origin_str: String,
55    /// Path to the hitlist used.
56    pub hitlist: &'a str,
57    /// Whether the hitlist was shuffled.
58    pub is_shuffle: bool,
59    /// A string representation of the measurement type (e.g., "ICMP", "DNS").
60    pub m_type_str: String,
61    /// The probing rate used.
62    pub probing_rate: u32,
63    /// The interval between subsequent workers.
64    pub interval: u32,
65    /// Hostnames of the workers selected to probe.
66    pub active_workers: Vec<String>,
67    /// A bidirectional map of all possible worker IDs to their hostnames.
68    pub all_workers: &'a BiHashMap<u32, String>,
69    /// Optional configuration file used.
70    pub configurations: &'a Vec<Configuration>,
71    /// Whether this is a configuration-based measurement.
72    pub is_config: bool,
73    /// Whether this is a latency-based measurement.
74    pub is_latency: bool,
75    /// Whether this is a responsiveness-based measurement.
76    pub is_responsive: bool,
77}
78
79/// Writes the results to a file (and optionally to the command-line)
80///
81/// # Arguments
82///
83/// * 'rx' - The receiver channel that receives the results
84///
85/// * 'config' - The configuration for writing results, including file handle, metadata, and measurement type
86pub fn write_results(mut rx: UnboundedReceiver<TaskResult>, config: WriteConfig) {
87    // CSV writer to command-line interface
88    let mut wtr_cli = if config.print_to_cli {
89        Some(Writer::from_writer(io::stdout()))
90    } else {
91        None
92    };
93
94    let buffered_file_writer = BufWriter::new(config.output_file);
95    let mut gz_encoder = GzEncoder::new(buffered_file_writer, Compression::default());
96
97    // Write metadata to file
98    let md_lines = get_csv_metadata(config.metadata_args, &config.worker_map);
99    for line in md_lines {
100        if let Err(e) = writeln!(gz_encoder, "{line}") {
101            eprintln!("Failed to write metadata line to Gzip stream: {e}");
102        }
103    }
104
105    // .gz writer
106    let mut wtr_file = Writer::from_writer(gz_encoder);
107
108    // Write header
109    let header = get_header(config.m_type, config.is_multi_origin, config.is_symmetric);
110    if let Some(wtr) = wtr_cli.as_mut() {
111        wtr.write_record(&header)
112            .expect("Failed to write header to stdout")
113    };
114    wtr_file
115        .write_record(header)
116        .expect("Failed to write header to file");
117
118    tokio::spawn(async move {
119        // Receive tasks from the outbound channel
120        while let Some(task_result) = rx.recv().await {
121            if task_result == TaskResult::default() {
122                break;
123            }
124            let results: Vec<Reply> = task_result.result_list;
125            for result in results {
126                let result = get_row(
127                    result,
128                    task_result.worker_id,
129                    config.m_type as u8,
130                    config.is_symmetric,
131                    config.worker_map.clone(),
132                );
133
134                // Write to command-line
135                if let Some(ref mut wtr) = wtr_cli {
136                    wtr.write_record(&result)
137                        .expect("Failed to write payload to CLI");
138                    wtr.flush().expect("Failed to flush stdout");
139                }
140
141                // Write to file
142                wtr_file
143                    .write_record(result)
144                    .expect("Failed to write payload to file");
145            }
146            wtr_file.flush().expect("Failed to flush file");
147        }
148        rx.close();
149        wtr_file.flush().expect("Failed to flush file");
150    });
151}
152
153/// Creates the appropriate CSV header for the results file (based on the measurement type)
154///
155/// # Arguments
156///
157/// * 'measurement_type' - The type of measurement being performed
158///
159/// * 'is_multi_origin' - A boolean that determines whether multiple origins are used
160///
161/// * 'is_symmetric' - A boolean that determines whether the measurement is symmetric (i.e., sender == receiver is always true)
162pub fn get_header(m_type: u32, is_multi_origin: bool, is_symmetric: bool) -> Vec<&'static str> {
163    let mut header = if is_symmetric {
164        vec!["rx", "addr", "ttl", "rtt"]
165    } else {
166        // TCP anycast does not have tx_time
167        if m_type == TCP_ID as u32 {
168            vec!["rx", "rx_time", "addr", "ttl", "tx"]
169        } else {
170            vec!["rx", "rx_time", "addr", "ttl", "tx_time", "tx"]
171        }
172    };
173
174    if m_type == CHAOS_ID as u32 {
175        header.push("chaos_data");
176    }
177
178    if is_multi_origin {
179        header.push("origin_id");
180    }
181
182    header
183}
184
185/// Get the result (csv row) from a Reply message
186///
187/// # Arguments
188///
189/// * `result` - The Reply that is being written to this row
190///
191/// * `rx_worker_id` - The worker ID of the receiver
192///
193/// * `m_type` - The type of measurement being performed
194///
195/// * `is_symmetric` - A boolean that determines whether the measurement is symmetric (i.e., sender == receiver is always true)
196///
197/// * `worker_map` - A map of worker IDs to hostnames, used to convert worker IDs to hostnames in the results
198///
199/// # Returns
200///
201/// A vector of strings representing the row in the CSV file
202fn get_row(
203    result: Reply,
204    rx_worker_id: u32,
205    m_type: u8,
206    is_symmetric: bool,
207    worker_map: BiHashMap<u32, String>,
208) -> Vec<String> {
209    let origin_id = result.origin_id.to_string();
210    let is_multi_origin = result.origin_id != 0 && result.origin_id != u32::MAX;
211    let rx_worker_id = rx_worker_id.to_string();
212    // convert the worker ID to hostname
213    let rx_hostname = worker_map
214        .get_by_left(&rx_worker_id.parse::<u32>().unwrap())
215        .unwrap_or(&String::from("Unknown"))
216        .to_string();
217    let rx_time = result.rx_time.to_string();
218    let tx_time = result.tx_time.to_string();
219    let tx_id = result.tx_id;
220    let ttl = result.ttl.to_string();
221    let reply_src = result.src.unwrap().to_string();
222
223    let mut row = if is_symmetric {
224        let rtt = format!(
225            "{:.2}",
226            calculate_rtt(result.rx_time, result.tx_time, m_type == TCP_ID)
227        );
228        vec![rx_hostname, reply_src, ttl, rtt]
229    } else {
230        let tx_hostname = worker_map
231            .get_by_left(&tx_id)
232            .unwrap_or(&String::from("Unknown"))
233            .to_string();
234
235        // TCP anycast does not have tx_time
236        if m_type == TCP_ID {
237            vec![rx_hostname, rx_time, reply_src, ttl, tx_hostname]
238        } else {
239            vec![rx_hostname, rx_time, reply_src, ttl, tx_time, tx_hostname]
240        }
241    };
242
243    // Optional fields
244    if let Some(chaos) = result.chaos {
245        row.push(chaos);
246    }
247    if is_multi_origin {
248        row.push(origin_id);
249    }
250
251    row
252}
253
/// Computes a round-trip time in milliseconds from a receive and a transmit timestamp.
///
/// * Non-TCP: `(rx_time - tx_time) / 1000`. The division suggests both timestamps are
///   in microseconds (TODO confirm). The difference is computed in signed arithmetic,
///   so clock skew (`rx_time < tx_time`) yields a negative RTT. Previously it underflowed,
///   which panics in debug builds and produces a huge bogus value in release builds.
/// * TCP: `rx_time` is converted to milliseconds and truncated to 32 bits, then compared
///   with `tx_time` truncated to 32 bits (presumably a 32-bit millisecond timestamp
///   echoed back by the target — confirm). The subtraction is modulo 2^32, so an RTT
///   spanning the 32-bit wrap-around is still computed correctly.
pub fn calculate_rtt(rx_time: u64, tx_time: u64, is_tcp: bool) -> f64 {
    if is_tcp {
        let rx_time_ms = (rx_time / 1_000) as u32;
        f64::from(rx_time_ms.wrapping_sub(tx_time as u32))
    } else {
        (i128::from(rx_time) - i128::from(tx_time)) as f64 / 1_000.0
    }
}
264
265/// Returns a vector of lines containing the metadata of the measurement
266///
267/// # Arguments
268///
269/// Variables describing the measurement
270pub fn get_csv_metadata(
271    args: MetadataArgs<'_>,
272    worker_map: &BiHashMap<u32, String>,
273) -> Vec<String> {
274    let mut md_file = Vec::new();
275    if args.is_divide {
276        md_file.push("# Measurement style: Divide-and-conquer".to_string());
277    } else if args.is_latency {
278        md_file.push("# Measurement style: Anycast latency".to_string());
279    } else if args.is_responsive {
280        md_file.push("# Measurement style: Responsive-mode".to_string());
281    }
282    md_file.push(format!("# Origin used: {}", args.origin_str));
283    md_file.push(format!(
284        "# Hitlist{}: {}",
285        if args.is_shuffle { " (shuffled)" } else { "" },
286        args.hitlist
287    ));
288    md_file.push(format!("# Measurement type: {}", args.m_type_str));
289    md_file.push(format!(
290        "# Probing rate: {}",
291        args.probing_rate.with_separator()
292    ));
293    md_file.push(format!("# Worker interval: {}", args.interval));
294    if !args.active_workers.is_empty() {
295        md_file.push(format!(
296            "# Selective probing using the following workers: {:?}",
297            args.active_workers
298        ));
299    }
300    md_file.push(format!("# {} connected workers:", args.all_workers.len()));
301    for (_, hostname) in args.all_workers {
302        md_file.push(format!("# * {hostname}"))
303    }
304
305    // Write configurations used for the measurement
306    if args.is_config {
307        md_file.push("# Configurations:".to_string());
308        for configuration in args.configurations {
309            let origin = configuration.origin.unwrap();
310            let src = origin.src.expect("Invalid source address");
311            let hostname = if configuration.worker_id == u32::MAX {
312                "ALL".to_string()
313            } else {
314                worker_map
315                    .get_by_left(&configuration.worker_id)
316                    .unwrap_or(&String::from("Unknown"))
317                    .to_string()
318            };
319            md_file.push(format!(
320                "# * {:<2}, source IP: {}, source port: {}, destination port: {}",
321                hostname, src, origin.sport, origin.dport
322            ));
323        }
324    }
325
326    md_file
327}
328
329/// Returns a vector of key-value pairs containing the metadata of the measurement.
330pub fn get_parquet_metadata(
331    args: MetadataArgs<'_>,
332    worker_map: &BiHashMap<u32, String>,
333) -> Vec<(String, String)> {
334    let mut md = Vec::new();
335
336    if args.is_divide {
337        md.push((
338            "measurement_style".to_string(),
339            "Divide-and-conquer".to_string(),
340        ));
341    }
342    if args.is_latency {
343        md.push((
344            "measurement_style".to_string(),
345            "Anycast-latency".to_string(),
346        ));
347    }
348    if args.is_responsive {
349        md.push((
350            "measurement_style".to_string(),
351            "Responsive-mode".to_string(),
352        ));
353    }
354
355    md.push(("origin_used".to_string(), args.origin_str));
356    md.push(("hitlist_path".to_string(), args.hitlist.to_string()));
357    md.push(("hitlist_shuffled".to_string(), args.is_shuffle.to_string()));
358    md.push(("measurement_type".to_string(), args.m_type_str));
359    // Store numbers without separators for easier parsing later
360    md.push((
361        "probing_rate_pps".to_string(),
362        args.probing_rate.to_string(),
363    ));
364    md.push(("worker_interval_ms".to_string(), args.interval.to_string()));
365
366    // Store active workers as a JSON string
367    if !args.active_workers.is_empty() {
368        md.push((
369            "selective_probing_workers".to_string(),
370            serde_json::to_string(&args.active_workers).unwrap_or_default(),
371        ));
372    }
373
374    let worker_hostnames: Vec<&String> = args.all_workers.right_values().collect();
375    md.push((
376        "connected_workers".to_string(),
377        serde_json::to_string(&worker_hostnames).unwrap_or_default(),
378    ));
379    md.push((
380        "connected_workers_count".to_string(),
381        args.all_workers.len().to_string(),
382    ));
383
384    if args.is_config && !args.configurations.is_empty() {
385        let config_str = args
386            .configurations
387            .iter()
388            .map(|c| {
389                format!(
390                    "Worker: {}, SrcIP: {}, SrcPort: {}, DstPort: {}",
391                    if c.worker_id == u32::MAX {
392                        "ALL".to_string()
393                    } else {
394                        worker_map
395                            .get_by_left(&c.worker_id)
396                            .unwrap_or(&String::from("Unknown"))
397                            .to_string()
398                    },
399                    c.origin
400                        .as_ref()
401                        .and_then(|o| o.src)
402                        .map_or("N/A".to_string(), |s| s.to_string()),
403                    c.origin.as_ref().map_or(0, |o| o.sport),
404                    c.origin.as_ref().map_or(0, |o| o.dport)
405                )
406            })
407            .collect::<Vec<_>>();
408
409        md.push((
410            "configurations".to_string(),
411            serde_json::to_string(&config_str).unwrap_or_default(),
412        ));
413    }
414
415    md
416}
417
/// Number of buffered Parquet rows that triggers writing a new row group.
const BATCH_SIZE: usize = 1_024;
419
/// One result row destined for a Parquet file.
///
/// Every field is optional. `None` is written as a null, and a field is only written
/// when its column is part of the header chosen for the measurement type and configuration.
struct ParquetDataRow {
    /// Hostname of the worker that received the reply.
    rx: Option<String>,
    /// Timestamp at which the reply was received. NOTE(review): the old docs said
    /// nanoseconds, but `calculate_rtt` divides by 1 000 to obtain milliseconds, which
    /// suggests microseconds — confirm.
    rx_time: Option<u64>,
    /// Source address of the reply, rendered as text.
    addr: Option<String>,
    /// TTL value observed on the reply.
    ttl: Option<u8>,
    /// Timestamp at which the probe was sent (same unit as `rx_time`).
    tx_time: Option<u64>,
    /// Hostname of the worker that sent the probe.
    tx: Option<String>,
    /// Round-trip time in milliseconds (symmetric measurements only).
    rtt: Option<f64>,
    /// DNS TXT CHAOS record value.
    chaos_data: Option<String>,
    /// Origin identifier in multi-origin measurements (source address, ports).
    origin_id: Option<u8>,
}
442
443/// Write results to a Parquet file as they are received from the channel.
444/// This function processes the results in batches to optimize writing performance.
445///
446/// # Arguments
447///
448/// * `rx` - The receiver channel that receives the results.
449///
450/// * `config` - The configuration for writing results, including file handle, metadata, and measurement type.
451///
452/// * `metadata_args` - Arguments for generating metadata, including measurement type, origins, and configurations.
453pub fn write_results_parquet(mut rx: UnboundedReceiver<TaskResult>, config: WriteConfig) {
454    let schema = build_parquet_schema(config.m_type, config.is_multi_origin, config.is_symmetric);
455
456    // Get metadata key-value pairs for the Parquet file
457    let key_value_tuples = get_parquet_metadata(config.metadata_args, &config.worker_map);
458
459    // Configure writer properties, including compression and metadata
460    let key_value_metadata: Vec<parquet::file::metadata::KeyValue> = key_value_tuples
461        .into_iter()
462        .map(|(key, value)| parquet::file::metadata::KeyValue::new(key, value))
463        .collect();
464
465    let props = Arc::new(
466        WriterProperties::builder()
467            .set_compression(ParquetCompression::SNAPPY)
468            .set_key_value_metadata(Some(key_value_metadata)) // Use the clean metadata
469            .build(),
470    );
471
472    let mut writer = SerializedFileWriter::new(config.output_file, schema.clone(), props)
473        .expect("Failed to create parquet writer");
474
475    // Get the appropriate header for the Parquet file based on the measurement type and configuration
476    let headers = get_header(config.m_type, config.is_multi_origin, config.is_symmetric);
477
478    tokio::spawn(async move {
479        let mut row_buffer: Vec<ParquetDataRow> = Vec::with_capacity(BATCH_SIZE);
480
481        while let Some(task_result) = rx.recv().await {
482            if task_result == TaskResult::default() {
483                break; // End of stream
484            }
485
486            let worker_id = task_result.worker_id;
487            for reply in task_result.result_list {
488                let parquet_row = reply_to_parquet_row(
489                    reply,
490                    worker_id,
491                    config.m_type as u8,
492                    config.is_symmetric,
493                    &config.worker_map,
494                );
495                row_buffer.push(parquet_row);
496            }
497
498            // If the buffer is full, write the batch to the file
499            if row_buffer.len() >= BATCH_SIZE {
500                write_batch_to_parquet(&mut writer, &row_buffer, &headers)
501                    .expect("Failed to write batch to Parquet file");
502                row_buffer.clear();
503            }
504        }
505
506        // Write any remaining rows in the buffer
507        if !row_buffer.is_empty() {
508            write_batch_to_parquet(&mut writer, &row_buffer, &headers)
509                .expect("Failed to write final batch to Parquet file");
510        }
511
512        writer.close().expect("Failed to close Parquet writer");
513        rx.close();
514    });
515}
516
517/// Creates a parquet data schema from the headers based on the measurement type and configuration.
518fn build_parquet_schema(m_type: u32, is_multi_origin: bool, is_symmetric: bool) -> TypePtr {
519    let headers = get_header(m_type, is_multi_origin, is_symmetric);
520    let mut fields = Vec::new();
521
522    for &header in &headers {
523        let field = match header {
524            "rx" | "addr" | "tx" | "chaos_data" => {
525                SchemaType::primitive_type_builder(header, parquet::basic::Type::BYTE_ARRAY)
526                    .with_repetition(Repetition::OPTIONAL)
527                    .with_logical_type(Some(parquet::basic::LogicalType::String))
528                    .build()
529                    .unwrap()
530            }
531            "rx_time" | "tx_time" => {
532                SchemaType::primitive_type_builder(header, parquet::basic::Type::INT64)
533                    .with_repetition(Repetition::OPTIONAL)
534                    .with_logical_type(Some(LogicalType::Integer {
535                        bit_width: 64,
536                        is_signed: false,
537                    })) // u64
538                    .build()
539                    .unwrap()
540            }
541            "ttl" | "origin_id" => {
542                SchemaType::primitive_type_builder(header, parquet::basic::Type::INT32)
543                    .with_repetition(Repetition::OPTIONAL)
544                    .with_logical_type(Some(LogicalType::Integer {
545                        bit_width: 8,
546                        is_signed: false,
547                    })) // u8
548                    .build()
549                    .unwrap()
550            }
551            "rtt" => SchemaType::primitive_type_builder(header, parquet::basic::Type::DOUBLE)
552                .with_repetition(Repetition::OPTIONAL)
553                .build()
554                .unwrap(),
555            _ => panic!("Unknown header column: {header}"),
556        };
557        fields.push(Arc::new(field));
558    }
559
560    Arc::new(
561        SchemaType::group_type_builder("schema")
562            .with_fields(fields)
563            .build()
564            .unwrap(),
565    )
566}
567
568/// Converts a Reply message into a ParquetDataRow for writing to a Parquet file.
569fn reply_to_parquet_row(
570    result: Reply,
571    rx_worker_id: u32,
572    m_type: u8,
573    is_symmetric: bool,
574    worker_map: &BiHashMap<u32, String>,
575) -> ParquetDataRow {
576    let mut row = ParquetDataRow {
577        rx: worker_map.get_by_left(&rx_worker_id).cloned(),
578        rx_time: Some(result.rx_time),
579        addr: result.src.map(|s| s.to_string()),
580        ttl: Some(result.ttl as u8),
581        tx_time: Some(result.tx_time),
582        tx: None,
583        rtt: None,
584        chaos_data: result.chaos,
585        origin_id: if result.origin_id != 0 && result.origin_id != u32::MAX {
586            Some(result.origin_id as u8)
587        } else {
588            None
589        },
590    };
591
592    if is_symmetric {
593        row.rtt = Some(calculate_rtt(
594            result.rx_time,
595            result.tx_time,
596            m_type == TCP_ID,
597        ));
598        row.rx_time = None;
599        row.tx_time = None;
600    } else {
601        row.tx = worker_map.get_by_left(&result.tx_id).cloned();
602        if m_type == TCP_ID {
603            row.tx_time = None;
604        }
605    }
606
607    row
608}
609
610/// Writes a batch of ParquetDataRow to the Parquet file using the provided writer.
611fn write_batch_to_parquet(
612    writer: &mut SerializedFileWriter<File>,
613    batch: &[ParquetDataRow],
614    headers: &[&str],
615) -> Result<(), parquet::errors::ParquetError> {
616    let mut row_group_writer = writer.next_row_group()?;
617
618    for &header in headers {
619        if let Some(mut col_writer) = row_group_writer.next_column()? {
620            match header {
621                "rx" | "addr" | "tx" | "chaos_data" => {
622                    let mut values = Vec::new();
623                    let def_levels: Vec<i16> = batch
624                        .iter()
625                        .map(|row| {
626                            let opt_val = match header {
627                                "rx" => row.rx.as_ref(),
628                                "addr" => row.addr.as_ref(),
629                                "tx" => row.tx.as_ref(),
630                                "chaos_data" => row.chaos_data.as_ref(),
631                                _ => None,
632                            };
633                            if let Some(val) = opt_val {
634                                values.push(ByteArray::from(val.as_str()));
635                                1 // 1 means the value is defined (not NULL)
636                            } else {
637                                0 // 0 means the value is NULL
638                            }
639                        })
640                        .collect();
641                    col_writer
642                        .typed::<parquet::data_type::ByteArrayType>()
643                        .write_batch(&values, Some(&def_levels), None)?;
644                }
645                "rx_time" | "tx_time" => {
646                    let mut values = Vec::new();
647                    let def_levels: Vec<i16> = batch
648                        .iter()
649                        .map(|row| {
650                            let opt_val = match header {
651                                "rx_time" => row.rx_time,
652                                "tx_time" => row.tx_time,
653                                _ => None,
654                            };
655                            if let Some(val) = opt_val {
656                                values.push(val as i64);
657                                1
658                            } else {
659                                0
660                            }
661                        })
662                        .collect();
663                    col_writer.typed::<Int64Type>().write_batch(
664                        &values,
665                        Some(&def_levels),
666                        None,
667                    )?;
668                }
669                "ttl" | "origin_id" => {
670                    let mut values = Vec::new();
671                    let def_levels: Vec<i16> = batch
672                        .iter()
673                        .map(|row| {
674                            let opt_val = match header {
675                                "ttl" => row.ttl,
676                                "origin_id" => row.origin_id,
677                                _ => None,
678                            };
679                            if let Some(val) = opt_val {
680                                values.push(val as i32);
681                                1
682                            } else {
683                                0
684                            }
685                        })
686                        .collect();
687                    col_writer.typed::<Int32Type>().write_batch(
688                        &values,
689                        Some(&def_levels),
690                        None,
691                    )?;
692                }
693                "rtt" => {
694                    let mut values = Vec::new();
695                    let def_levels: Vec<i16> = batch
696                        .iter()
697                        .map(|row| {
698                            if let Some(val) = row.rtt {
699                                values.push(val);
700                                1 // 1 means the value is defined (not NULL)
701                            } else {
702                                0 // 0 means the value is NULL
703                            }
704                        })
705                        .collect();
706                    col_writer.typed::<DoubleType>().write_batch(
707                        &values,
708                        Some(&def_levels),
709                        None,
710                    )?;
711                }
712                _ => {}
713            }
714            col_writer.close()?;
715        }
716    }
717    row_group_writer.close()?;
718    Ok(())
719}