// tp_lib_core/io/csv.rs
1//! CSV parsing and writing
2
3use crate::errors::ProjectionError;
4use crate::models::{AssociatedNetElement, GnssPosition, ProjectedPosition, TrainPath};
5use chrono::{DateTime, FixedOffset, NaiveDateTime, TimeZone, Utc};
6use polars::prelude::*;
7use std::collections::HashMap;
8
// CSV column name constants for projected positions output.
// Kept as constants so the header row in `write_csv` and any readers of its
// output stay in sync from a single definition.
const COL_ORIGINAL_LAT: &str = "original_lat";
const COL_ORIGINAL_LON: &str = "original_lon";
const COL_ORIGINAL_TIME: &str = "original_time";
const COL_PROJECTED_LAT: &str = "projected_lat";
const COL_PROJECTED_LON: &str = "projected_lon";
const COL_NETELEMENT_ID: &str = "netelement_id";
const COL_MEASURE_METERS: &str = "measure_meters";
const COL_PROJECTION_DISTANCE_METERS: &str = "projection_distance_meters";
const COL_CRS: &str = "crs";

// CSV column names for train path output.
// COL_NETELEMENT_ID above is shared between both output formats.
const COL_PROBABILITY: &str = "probability";
const COL_START_INTRINSIC: &str = "start_intrinsic";
const COL_END_INTRINSIC: &str = "end_intrinsic";
const COL_GNSS_START_INDEX: &str = "gnss_start_index";
const COL_GNSS_END_INDEX: &str = "gnss_end_index";
27/// Parse a timestamp string, accepting RFC3339 (with timezone) or ISO 8601 without timezone
28/// (assumed to be UTC).
29fn parse_timestamp(s: &str) -> Result<DateTime<FixedOffset>, String> {
30    // First try full RFC3339 (includes timezone)
31    if let Ok(dt) = DateTime::parse_from_rfc3339(s) {
32        return Ok(dt);
33    }
34    // Fall back: treat timezone-less ISO 8601 datetime as UTC
35    let naive = NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S%.f")
36        .or_else(|_| NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S"))
37        .map_err(|e| {
38            format!(
39                "{} (expected RFC3339 with timezone, e.g., 2025-12-09T14:30:00+01:00, or ISO 8601 without timezone assumed UTC)",
40                e
41            )
42        })?;
43    Ok(Utc.from_utc_datetime(&naive).fixed_offset())
44}
45
46/// Parse GNSS positions from CSV file
47pub fn parse_gnss_csv(
48    path: &str,
49    crs: &str,
50    lat_col: &str,
51    lon_col: &str,
52    time_col: &str,
53) -> Result<Vec<GnssPosition>, ProjectionError> {
54    // Read CSV file using polars
55    let df = CsvReadOptions::default()
56        .with_has_header(true)
57        .try_into_reader_with_file_path(Some(path.into()))
58        .map_err(|e| {
59            ProjectionError::IoError(std::io::Error::new(
60                std::io::ErrorKind::InvalidData,
61                format!("Failed to read CSV: {}", e),
62            ))
63        })?
64        .finish()
65        .map_err(|e| {
66            ProjectionError::IoError(std::io::Error::new(
67                std::io::ErrorKind::InvalidData,
68                format!("Failed to parse CSV: {}", e),
69            ))
70        })?;
71
72    // Handle empty CSV (only headers) - polars can't infer types from empty data
73    if df.height() == 0 {
74        return Ok(Vec::new());
75    }
76
77    // Validate required columns exist
78    let schema = df.schema();
79    if !schema.contains(lat_col) {
80        return Err(ProjectionError::InvalidCoordinate(format!(
81            "Latitude column '{}' not found in CSV",
82            lat_col
83        )));
84    }
85    if !schema.contains(lon_col) {
86        return Err(ProjectionError::InvalidCoordinate(format!(
87            "Longitude column '{}' not found in CSV",
88            lon_col
89        )));
90    }
91    if !schema.contains(time_col) {
92        return Err(ProjectionError::InvalidTimestamp(format!(
93            "Timestamp column '{}' not found in CSV",
94            time_col
95        )));
96    }
97
98    // Get all column names for metadata preservation
99    let all_columns: Vec<String> = schema.iter_names().map(|s| s.to_string()).collect();
100
101    // Check if heading and distance columns exist (optional - US4: T115-T116)
102    let has_heading = schema.contains("heading");
103    let has_distance = schema.contains("distance");
104
105    // Extract required columns
106    let lat_series = df.column(lat_col).map_err(|e| {
107        ProjectionError::InvalidCoordinate(format!("Failed to get latitude: {}", e))
108    })?;
109    let lon_series = df.column(lon_col).map_err(|e| {
110        ProjectionError::InvalidCoordinate(format!("Failed to get longitude: {}", e))
111    })?;
112    let time_series = df.column(time_col).map_err(|e| {
113        ProjectionError::InvalidTimestamp(format!("Failed to get timestamp: {}", e))
114    })?;
115
116    // Convert to f64 arrays
117    let lat_array = lat_series.f64().map_err(|e| {
118        ProjectionError::InvalidCoordinate(format!("Latitude must be numeric: {}", e))
119    })?;
120    let lon_array = lon_series.f64().map_err(|e| {
121        ProjectionError::InvalidCoordinate(format!("Longitude must be numeric: {}", e))
122    })?;
123    let time_array = time_series.str().map_err(|e| {
124        ProjectionError::InvalidTimestamp(format!("Timestamp must be string: {}", e))
125    })?;
126
127    // Get optional heading and distance series if they exist
128    let heading_series = if has_heading {
129        Some(df.column("heading").map_err(|e| {
130            ProjectionError::InvalidGeometry(format!("Failed to get heading: {}", e))
131        })?)
132    } else {
133        None
134    };
135
136    let distance_series = if has_distance {
137        Some(df.column("distance").map_err(|e| {
138            ProjectionError::InvalidGeometry(format!("Failed to get distance: {}", e))
139        })?)
140    } else {
141        None
142    };
143
144    // Convert heading and distance to typed arrays
145    let heading_array = heading_series
146        .as_ref()
147        .map(|s| s.f64())
148        .transpose()
149        .map_err(|e| ProjectionError::InvalidGeometry(format!("Heading must be numeric: {}", e)))?;
150
151    let distance_array = distance_series
152        .as_ref()
153        .map(|s| s.f64())
154        .transpose()
155        .map_err(|e| {
156            ProjectionError::InvalidGeometry(format!("Distance must be numeric: {}", e))
157        })?;
158
159    // Build GNSS positions
160    let mut positions = Vec::new();
161    let row_count = df.height();
162
163    for i in 0..row_count {
164        // Get coordinates
165        let latitude = lat_array.get(i).ok_or_else(|| {
166            ProjectionError::InvalidCoordinate(format!("Missing latitude at row {}", i))
167        })?;
168        let longitude = lon_array.get(i).ok_or_else(|| {
169            ProjectionError::InvalidCoordinate(format!("Missing longitude at row {}", i))
170        })?;
171
172        // Get and parse timestamp
173        let time_str = time_array.get(i).ok_or_else(|| {
174            ProjectionError::InvalidTimestamp(format!("Missing timestamp at row {}", i))
175        })?;
176
177        let timestamp = parse_timestamp(time_str).map_err(|e| {
178            ProjectionError::InvalidTimestamp(format!(
179                "Invalid timestamp '{}' at row {}: {}",
180                time_str, i, e
181            ))
182        })?;
183
184        // Build metadata from other columns
185        let mut metadata = HashMap::new();
186        for col_name in &all_columns {
187            if col_name != lat_col
188                && col_name != lon_col
189                && col_name != time_col
190                && col_name != "heading"
191                && col_name != "distance"
192            {
193                if let Ok(series) = df.column(col_name) {
194                    if let Ok(str_series) = series.cast(&DataType::String) {
195                        if let Ok(str_chunked) = str_series.str() {
196                            if let Some(value) = str_chunked.get(i) {
197                                metadata.insert(col_name.clone(), value.to_string());
198                            }
199                        }
200                    }
201                }
202            }
203        }
204
205        // Extract heading if present (0-360°), validate range
206        let heading = heading_array.as_ref().and_then(|arr| arr.get(i));
207        if let Some(h) = heading {
208            if !(0.0..=360.0).contains(&h) {
209                return Err(ProjectionError::InvalidGeometry(format!(
210                    "Heading must be in [0, 360], got {} at row {}",
211                    h, i
212                )));
213            }
214        }
215
216        // Extract distance if present (must be >= 0)
217        let distance = distance_array.as_ref().and_then(|arr| arr.get(i));
218        if let Some(d) = distance {
219            if d < 0.0 {
220                return Err(ProjectionError::InvalidGeometry(format!(
221                    "Distance must be >= 0, got {} at row {}",
222                    d, i
223                )));
224            }
225        }
226
227        // Create GNSS position with heading and distance if available (US4: T115-T116)
228        let mut position = GnssPosition::with_heading_distance(
229            latitude,
230            longitude,
231            timestamp,
232            crs.to_string(),
233            heading,
234            distance,
235        )?;
236        position.metadata = metadata;
237        positions.push(position);
238    }
239
240    Ok(positions)
241}
242
243/// Write projected positions to CSV
244pub fn write_csv(
245    positions: &[ProjectedPosition],
246    writer: &mut impl std::io::Write,
247) -> Result<(), ProjectionError> {
248    use csv::Writer;
249
250    let mut csv_writer = Writer::from_writer(writer);
251
252    // Write header
253    csv_writer.write_record([
254        COL_ORIGINAL_LAT,
255        COL_ORIGINAL_LON,
256        COL_ORIGINAL_TIME,
257        COL_PROJECTED_LAT,
258        COL_PROJECTED_LON,
259        COL_NETELEMENT_ID,
260        COL_MEASURE_METERS,
261        COL_PROJECTION_DISTANCE_METERS,
262        COL_CRS,
263    ])?;
264
265    // Write data rows
266    for pos in positions {
267        csv_writer.write_record(&[
268            pos.original.latitude.to_string(),
269            pos.original.longitude.to_string(),
270            pos.original.timestamp.to_rfc3339(),
271            pos.projected_coords.y().to_string(),
272            pos.projected_coords.x().to_string(),
273            pos.netelement_id.clone(),
274            pos.measure_meters.to_string(),
275            pos.projection_distance_meters.to_string(),
276            pos.crs.clone(),
277        ])?;
278    }
279
280    csv_writer.flush()?;
281    Ok(())
282}
283
284/// Write TrainPath to CSV
285///
286/// Output format: One row per segment with columns:
287/// - netelement_id: ID of the netelement
288/// - probability: Segment probability (0.0 to 1.0)
289/// - start_intrinsic: Entry point on netelement (0.0 to 1.0)
290/// - end_intrinsic: Exit point on netelement (0.0 to 1.0)
291/// - gnss_start_index: First GNSS position index
292/// - gnss_end_index: Last GNSS position index
293///
294/// The overall_probability is written as a comment in the first line.
295///
296/// # Example Output
297///
298/// ```csv
299/// # overall_probability: 0.89
300/// netelement_id,probability,start_intrinsic,end_intrinsic,gnss_start_index,gnss_end_index
301/// NE_A,0.87,0.0,1.0,0,10
302/// NE_B,0.92,0.0,1.0,11,18
303/// ```
304pub fn write_trainpath_csv(
305    train_path: &TrainPath,
306    writer: &mut impl std::io::Write,
307) -> Result<(), ProjectionError> {
308    use csv::Writer;
309
310    // Write overall probability as comment
311    writeln!(
312        writer,
313        "# overall_probability: {}",
314        train_path.overall_probability
315    )?;
316
317    if let Some(calculated_at) = &train_path.calculated_at {
318        writeln!(writer, "# calculated_at: {}", calculated_at.to_rfc3339())?;
319    }
320
321    let mut csv_writer = Writer::from_writer(writer);
322
323    // Write header
324    csv_writer.write_record([
325        COL_NETELEMENT_ID,
326        COL_PROBABILITY,
327        COL_START_INTRINSIC,
328        COL_END_INTRINSIC,
329        COL_GNSS_START_INDEX,
330        COL_GNSS_END_INDEX,
331    ])?;
332
333    // Write data rows
334    for segment in &train_path.segments {
335        csv_writer.write_record(&[
336            segment.netelement_id.clone(),
337            segment.probability.to_string(),
338            segment.start_intrinsic.to_string(),
339            segment.end_intrinsic.to_string(),
340            segment.gnss_start_index.to_string(),
341            segment.gnss_end_index.to_string(),
342        ])?;
343    }
344
345    csv_writer.flush()?;
346    Ok(())
347}
348
349/// Parse TrainPath from CSV
350///
351/// Reads a CSV file in the format produced by write_trainpath_csv.
352/// Expects columns: netelement_id, probability, start_intrinsic, end_intrinsic,
353/// gnss_start_index, gnss_end_index
354///
355/// The overall_probability can be specified in a comment line starting with
356/// `# overall_probability:` or will default to the average of segment probabilities.
357///
358/// # Arguments
359///
360/// * `path` - Path to CSV file
361///
362/// # Returns
363///
364/// A TrainPath struct reconstructed from the CSV data
365pub fn parse_trainpath_csv(path: &str) -> Result<TrainPath, ProjectionError> {
366    // Read the file to extract comment lines and filter them out
367    let file_content = std::fs::read_to_string(path)?;
368    let mut overall_probability: Option<f64> = None;
369    let mut calculated_at: Option<chrono::DateTime<chrono::Utc>> = None;
370    let mut csv_lines = Vec::new();
371
372    // Parse comment lines and collect non-comment lines
373    for line in file_content.lines() {
374        if let Some(comment) = line.strip_prefix('#') {
375            let comment = comment.trim();
376            if let Some(value) = comment.strip_prefix("overall_probability:") {
377                overall_probability = value.trim().parse().ok();
378            } else if let Some(value) = comment.strip_prefix("calculated_at:") {
379                if let Ok(dt) = chrono::DateTime::parse_from_rfc3339(value.trim()) {
380                    calculated_at = Some(dt.with_timezone(&chrono::Utc));
381                }
382            }
383        } else {
384            csv_lines.push(line);
385        }
386    }
387
388    // Write filtered CSV to temporary string for polars
389    // Use thread ID and timestamp to avoid race conditions with parallel tests
390    let filtered_csv = csv_lines.join("\n");
391    let unique_id = format!(
392        "{}_{:?}",
393        std::process::id(),
394        std::time::SystemTime::now()
395            .duration_since(std::time::UNIX_EPOCH)
396            .map(|d| d.as_nanos())
397            .unwrap_or(0)
398    );
399    let temp_file = std::env::temp_dir().join(format!("trainpath_{}.csv", unique_id));
400    std::fs::write(&temp_file, filtered_csv)?;
401
402    // Read CSV using polars
403    let df = CsvReadOptions::default()
404        .with_has_header(true)
405        .try_into_reader_with_file_path(Some(temp_file.clone()))
406        .map_err(|e| {
407            ProjectionError::IoError(std::io::Error::new(
408                std::io::ErrorKind::InvalidData,
409                format!("Failed to read TrainPath CSV: {}", e),
410            ))
411        })?
412        .finish()
413        .map_err(|e| {
414            ProjectionError::IoError(std::io::Error::new(
415                std::io::ErrorKind::InvalidData,
416                format!("Failed to parse TrainPath CSV: {}", e),
417            ))
418        })?;
419
420    // Clean up temp file
421    let _ = std::fs::remove_file(temp_file);
422
423    // Handle empty CSV (only headers)
424    if df.height() == 0 {
425        return TrainPath::new(Vec::new(), 1.0, None, None);
426    }
427
428    // Extract columns and cast to correct types
429    let netelement_id = df
430        .column("netelement_id")
431        .map_err(|e| ProjectionError::GeoJsonError(format!("Missing netelement_id column: {}", e)))?
432        .str()
433        .map_err(|e| ProjectionError::GeoJsonError(format!("netelement_id must be string: {}", e)))?
434        .clone();
435
436    let probability_series = df
437        .column("probability")
438        .map_err(|e| ProjectionError::GeoJsonError(format!("Missing probability column: {}", e)))?
439        .cast(&DataType::Float64)
440        .map_err(|e| ProjectionError::GeoJsonError(format!("probability cast failed: {}", e)))?;
441    let probability = probability_series.f64().map_err(|e| {
442        ProjectionError::GeoJsonError(format!("probability must be numeric: {}", e))
443    })?;
444
445    let start_intrinsic_series = df
446        .column("start_intrinsic")
447        .map_err(|e| {
448            ProjectionError::GeoJsonError(format!("Missing start_intrinsic column: {}", e))
449        })?
450        .cast(&DataType::Float64)
451        .map_err(|e| {
452            ProjectionError::GeoJsonError(format!("start_intrinsic cast failed: {}", e))
453        })?;
454    let start_intrinsic = start_intrinsic_series.f64().map_err(|e| {
455        ProjectionError::GeoJsonError(format!("start_intrinsic must be numeric: {}", e))
456    })?;
457
458    let end_intrinsic_series = df
459        .column("end_intrinsic")
460        .map_err(|e| ProjectionError::GeoJsonError(format!("Missing end_intrinsic column: {}", e)))?
461        .cast(&DataType::Float64)
462        .map_err(|e| ProjectionError::GeoJsonError(format!("end_intrinsic cast failed: {}", e)))?;
463    let end_intrinsic = end_intrinsic_series.f64().map_err(|e| {
464        ProjectionError::GeoJsonError(format!("end_intrinsic must be numeric: {}", e))
465    })?;
466
467    let gnss_start_index_series = df
468        .column("gnss_start_index")
469        .map_err(|e| {
470            ProjectionError::GeoJsonError(format!("Missing gnss_start_index column: {}", e))
471        })?
472        .cast(&DataType::UInt32)
473        .map_err(|e| {
474            ProjectionError::GeoJsonError(format!("gnss_start_index cast failed: {}", e))
475        })?;
476    let gnss_start_index = gnss_start_index_series.u32().map_err(|e| {
477        ProjectionError::GeoJsonError(format!("gnss_start_index must be integer: {}", e))
478    })?;
479
480    let gnss_end_index_series = df
481        .column("gnss_end_index")
482        .map_err(|e| {
483            ProjectionError::GeoJsonError(format!("Missing gnss_end_index column: {}", e))
484        })?
485        .cast(&DataType::UInt32)
486        .map_err(|e| ProjectionError::GeoJsonError(format!("gnss_end_index cast failed: {}", e)))?;
487    let gnss_end_index = gnss_end_index_series.u32().map_err(|e| {
488        ProjectionError::GeoJsonError(format!("gnss_end_index must be integer: {}", e))
489    })?;
490
491    // Build segments
492    let mut segments = Vec::new();
493    let row_count = df.height();
494
495    for i in 0..row_count {
496        let id = netelement_id
497            .get(i)
498            .ok_or_else(|| {
499                ProjectionError::GeoJsonError(format!("Missing netelement_id at row {}", i))
500            })?
501            .to_string();
502
503        let prob = probability.get(i).ok_or_else(|| {
504            ProjectionError::GeoJsonError(format!("Missing probability at row {}", i))
505        })?;
506
507        let start_intr = start_intrinsic.get(i).ok_or_else(|| {
508            ProjectionError::GeoJsonError(format!("Missing start_intrinsic at row {}", i))
509        })?;
510
511        let end_intr = end_intrinsic.get(i).ok_or_else(|| {
512            ProjectionError::GeoJsonError(format!("Missing end_intrinsic at row {}", i))
513        })?;
514
515        let start_idx = gnss_start_index.get(i).ok_or_else(|| {
516            ProjectionError::GeoJsonError(format!("Missing gnss_start_index at row {}", i))
517        })? as usize;
518
519        let end_idx = gnss_end_index.get(i).ok_or_else(|| {
520            ProjectionError::GeoJsonError(format!("Missing gnss_end_index at row {}", i))
521        })? as usize;
522
523        let segment =
524            AssociatedNetElement::new(id, prob, start_intr, end_intr, start_idx, end_idx)?;
525
526        segments.push(segment);
527    }
528
529    // Calculate overall probability if not provided
530    let overall_prob = overall_probability.unwrap_or_else(|| {
531        let sum: f64 = segments.iter().map(|s| s.probability).sum();
532        sum / segments.len() as f64
533    });
534
535    // Create TrainPath
536    TrainPath::new(segments, overall_prob, calculated_at, None)
537}
538
539#[cfg(test)]
540mod tests;