Skip to main content

tp_lib_core/io/
csv.rs

1//! CSV parsing and writing
2
3use crate::errors::ProjectionError;
4use crate::models::{AssociatedNetElement, GnssPosition, ProjectedPosition, TrainPath};
5use crate::temporal::parse_timestamp_flexible_str;
6use chrono::{DateTime, FixedOffset};
7use polars::prelude::*;
8use std::collections::HashMap;
9
// ---------------------------------------------------------------------------
// Header names of the projected-positions CSV produced by `write_csv`.
// ---------------------------------------------------------------------------

/// Latitude of the original (unprojected) GNSS fix.
const COL_ORIGINAL_LAT: &str = "original_lat";
/// Longitude of the original (unprojected) GNSS fix.
const COL_ORIGINAL_LON: &str = "original_lon";
/// Timestamp of the original GNSS fix.
const COL_ORIGINAL_TIME: &str = "original_time";
/// Latitude of the projected point.
const COL_PROJECTED_LAT: &str = "projected_lat";
/// Longitude of the projected point.
const COL_PROJECTED_LON: &str = "projected_lon";
/// Netelement identifier (shared by the projected-positions and train-path CSVs).
const COL_NETELEMENT_ID: &str = "netelement_id";
/// Measure along the netelement, in meters.
const COL_MEASURE_METERS: &str = "measure_meters";
/// Distance between the original fix and its projection, in meters.
const COL_PROJECTION_DISTANCE_METERS: &str = "projection_distance_meters";
/// Coordinate reference system identifier.
const COL_CRS: &str = "crs";

// ---------------------------------------------------------------------------
// Header names of the train-path CSV produced by `write_trainpath_csv`.
// ---------------------------------------------------------------------------

/// Per-segment probability.
const COL_PROBABILITY: &str = "probability";
/// Intrinsic entry position on the netelement.
const COL_START_INTRINSIC: &str = "start_intrinsic";
/// Intrinsic exit position on the netelement.
const COL_END_INTRINSIC: &str = "end_intrinsic";
/// Index of the first GNSS position associated with the segment.
const COL_GNSS_START_INDEX: &str = "gnss_start_index";
/// Index of the last GNSS position associated with the segment.
const COL_GNSS_END_INDEX: &str = "gnss_end_index";
27
28/// Parse a timestamp string, accepting RFC3339 (with timezone) or a naive
29/// ISO 8601 datetime without timezone (interpreted as local time).
30fn parse_timestamp(s: &str) -> Result<DateTime<FixedOffset>, String> {
31    parse_timestamp_flexible_str(s)
32}
33
34/// Parse GNSS positions from CSV file
35pub fn parse_gnss_csv(
36    path: &str,
37    crs: &str,
38    lat_col: &str,
39    lon_col: &str,
40    time_col: &str,
41) -> Result<Vec<GnssPosition>, ProjectionError> {
42    // Read CSV file using polars
43    let df = CsvReadOptions::default()
44        .with_has_header(true)
45        .try_into_reader_with_file_path(Some(path.into()))
46        .map_err(|e| {
47            ProjectionError::IoError(std::io::Error::new(
48                std::io::ErrorKind::InvalidData,
49                format!("Failed to read CSV: {}", e),
50            ))
51        })?
52        .finish()
53        .map_err(|e| {
54            ProjectionError::IoError(std::io::Error::new(
55                std::io::ErrorKind::InvalidData,
56                format!("Failed to parse CSV: {}", e),
57            ))
58        })?;
59
60    gnss_positions_from_df(df, crs, lat_col, lon_col, time_col)
61}
62
63/// In-memory variant of [`parse_gnss_csv`] that accepts the full CSV text
64/// directly. No disk I/O is performed; required by the .NET bindings for
65/// database-backed callers (FR-012).
66pub fn parse_gnss_csv_str(
67    csv_text: &str,
68    crs: &str,
69    lat_col: &str,
70    lon_col: &str,
71    time_col: &str,
72) -> Result<Vec<GnssPosition>, ProjectionError> {
73    let cursor = std::io::Cursor::new(csv_text.as_bytes().to_vec());
74    let df = CsvReadOptions::default()
75        .with_has_header(true)
76        .into_reader_with_file_handle(cursor)
77        .finish()
78        .map_err(|e| {
79            ProjectionError::IoError(std::io::Error::new(
80                std::io::ErrorKind::InvalidData,
81                format!("Failed to parse CSV: {}", e),
82            ))
83        })?;
84    gnss_positions_from_df(df, crs, lat_col, lon_col, time_col)
85}
86
87/// Shared body: convert a polars DataFrame to a sequence of GnssPosition rows.
88fn gnss_positions_from_df(
89    df: DataFrame,
90    crs: &str,
91    lat_col: &str,
92    lon_col: &str,
93    time_col: &str,
94) -> Result<Vec<GnssPosition>, ProjectionError> {
95    // Handle empty CSV (only headers) - polars can't infer types from empty data
96    if df.height() == 0 {
97        return Ok(Vec::new());
98    }
99
100    // Validate required columns exist
101    let schema = df.schema();
102    if !schema.contains(lat_col) {
103        return Err(ProjectionError::InvalidCoordinate(format!(
104            "Latitude column '{}' not found in CSV",
105            lat_col
106        )));
107    }
108    if !schema.contains(lon_col) {
109        return Err(ProjectionError::InvalidCoordinate(format!(
110            "Longitude column '{}' not found in CSV",
111            lon_col
112        )));
113    }
114    if !schema.contains(time_col) {
115        return Err(ProjectionError::InvalidTimestamp(format!(
116            "Timestamp column '{}' not found in CSV",
117            time_col
118        )));
119    }
120
121    // Get all column names for metadata preservation
122    let all_columns: Vec<String> = schema.iter_names().map(|s| s.to_string()).collect();
123
124    // Check if heading and distance columns exist (optional - US4: T115-T116)
125    let has_heading = schema.contains("heading");
126    let has_distance = schema.contains("distance");
127
128    // Extract required columns
129    let lat_series = df.column(lat_col).map_err(|e| {
130        ProjectionError::InvalidCoordinate(format!("Failed to get latitude: {}", e))
131    })?;
132    let lon_series = df.column(lon_col).map_err(|e| {
133        ProjectionError::InvalidCoordinate(format!("Failed to get longitude: {}", e))
134    })?;
135    let time_series = df.column(time_col).map_err(|e| {
136        ProjectionError::InvalidTimestamp(format!("Failed to get timestamp: {}", e))
137    })?;
138
139    // Convert to f64 arrays
140    let lat_array = lat_series.f64().map_err(|e| {
141        ProjectionError::InvalidCoordinate(format!("Latitude must be numeric: {}", e))
142    })?;
143    let lon_array = lon_series.f64().map_err(|e| {
144        ProjectionError::InvalidCoordinate(format!("Longitude must be numeric: {}", e))
145    })?;
146    let time_array = time_series.str().map_err(|e| {
147        ProjectionError::InvalidTimestamp(format!("Timestamp must be string: {}", e))
148    })?;
149
150    // Get optional heading and distance series if they exist
151    let heading_series = if has_heading {
152        Some(df.column("heading").map_err(|e| {
153            ProjectionError::InvalidGeometry(format!("Failed to get heading: {}", e))
154        })?)
155    } else {
156        None
157    };
158
159    let distance_series = if has_distance {
160        Some(df.column("distance").map_err(|e| {
161            ProjectionError::InvalidGeometry(format!("Failed to get distance: {}", e))
162        })?)
163    } else {
164        None
165    };
166
167    // Convert heading and distance to typed arrays
168    let heading_array = heading_series
169        .as_ref()
170        .map(|s| s.f64())
171        .transpose()
172        .map_err(|e| ProjectionError::InvalidGeometry(format!("Heading must be numeric: {}", e)))?;
173
174    let distance_array = distance_series
175        .as_ref()
176        .map(|s| s.f64())
177        .transpose()
178        .map_err(|e| {
179            ProjectionError::InvalidGeometry(format!("Distance must be numeric: {}", e))
180        })?;
181
182    // Build GNSS positions
183    let mut positions = Vec::new();
184    let row_count = df.height();
185
186    for i in 0..row_count {
187        // Get coordinates
188        let latitude = lat_array.get(i).ok_or_else(|| {
189            ProjectionError::InvalidCoordinate(format!("Missing latitude at row {}", i))
190        })?;
191        let longitude = lon_array.get(i).ok_or_else(|| {
192            ProjectionError::InvalidCoordinate(format!("Missing longitude at row {}", i))
193        })?;
194
195        // Get and parse timestamp
196        let time_str = time_array.get(i).ok_or_else(|| {
197            ProjectionError::InvalidTimestamp(format!("Missing timestamp at row {}", i))
198        })?;
199
200        let timestamp = parse_timestamp(time_str).map_err(|e| {
201            ProjectionError::InvalidTimestamp(format!(
202                "Invalid timestamp '{}' at row {}: {}",
203                time_str, i, e
204            ))
205        })?;
206
207        // Build metadata from other columns
208        let mut metadata = HashMap::new();
209        for col_name in &all_columns {
210            if col_name != lat_col
211                && col_name != lon_col
212                && col_name != time_col
213                && col_name != "heading"
214                && col_name != "distance"
215            {
216                if let Ok(series) = df.column(col_name) {
217                    if let Ok(str_series) = series.cast(&DataType::String) {
218                        if let Ok(str_chunked) = str_series.str() {
219                            if let Some(value) = str_chunked.get(i) {
220                                metadata.insert(col_name.clone(), value.to_string());
221                            }
222                        }
223                    }
224                }
225            }
226        }
227
228        // Extract heading if present (0-360°), validate range
229        let heading = heading_array.as_ref().and_then(|arr| arr.get(i));
230        if let Some(h) = heading {
231            if !(0.0..=360.0).contains(&h) {
232                return Err(ProjectionError::InvalidGeometry(format!(
233                    "Heading must be in [0, 360], got {} at row {}",
234                    h, i
235                )));
236            }
237        }
238
239        // Extract distance if present (must be >= 0)
240        let distance = distance_array.as_ref().and_then(|arr| arr.get(i));
241        if let Some(d) = distance {
242            if d < 0.0 {
243                return Err(ProjectionError::InvalidGeometry(format!(
244                    "Distance must be >= 0, got {} at row {}",
245                    d, i
246                )));
247            }
248        }
249
250        // Create GNSS position with heading and distance if available (US4: T115-T116)
251        let mut position = GnssPosition::with_heading_distance(
252            latitude,
253            longitude,
254            timestamp,
255            crs.to_string(),
256            heading,
257            distance,
258        )?;
259        position.metadata = metadata;
260        positions.push(position);
261    }
262
263    Ok(positions)
264}
265
266/// Write projected positions to CSV
267pub fn write_csv(
268    positions: &[ProjectedPosition],
269    writer: &mut impl std::io::Write,
270) -> Result<(), ProjectionError> {
271    use csv::Writer;
272
273    let mut csv_writer = Writer::from_writer(writer);
274
275    // Write header
276    csv_writer.write_record([
277        COL_ORIGINAL_LAT,
278        COL_ORIGINAL_LON,
279        COL_ORIGINAL_TIME,
280        COL_PROJECTED_LAT,
281        COL_PROJECTED_LON,
282        COL_NETELEMENT_ID,
283        COL_MEASURE_METERS,
284        COL_PROJECTION_DISTANCE_METERS,
285        COL_CRS,
286    ])?;
287
288    // Write data rows
289    for pos in positions {
290        csv_writer.write_record(&[
291            pos.original.latitude.to_string(),
292            pos.original.longitude.to_string(),
293            pos.original.timestamp.to_rfc3339(),
294            pos.projected_coords.y().to_string(),
295            pos.projected_coords.x().to_string(),
296            pos.netelement_id.clone(),
297            pos.measure_meters.to_string(),
298            pos.projection_distance_meters.to_string(),
299            pos.crs.clone(),
300        ])?;
301    }
302
303    csv_writer.flush()?;
304    Ok(())
305}
306
307/// Write TrainPath to CSV
308///
309/// Output format: One row per segment with columns:
310/// - netelement_id: ID of the netelement
311/// - probability: Segment probability (0.0 to 1.0)
312/// - start_intrinsic: Entry point on netelement (0.0 to 1.0)
313/// - end_intrinsic: Exit point on netelement (0.0 to 1.0)
314/// - gnss_start_index: First GNSS position index
315/// - gnss_end_index: Last GNSS position index
316///
317/// The overall_probability is written as a comment in the first line.
318///
319/// # Example Output
320///
321/// ```csv
322/// # overall_probability: 0.89
323/// netelement_id,probability,start_intrinsic,end_intrinsic,gnss_start_index,gnss_end_index
324/// NE_A,0.87,0.0,1.0,0,10
325/// NE_B,0.92,0.0,1.0,11,18
326/// ```
327pub fn write_trainpath_csv(
328    train_path: &TrainPath,
329    writer: &mut impl std::io::Write,
330) -> Result<(), ProjectionError> {
331    use csv::Writer;
332
333    // Write overall probability as comment
334    writeln!(
335        writer,
336        "# overall_probability: {}",
337        train_path.overall_probability
338    )?;
339
340    if let Some(calculated_at) = &train_path.calculated_at {
341        writeln!(writer, "# calculated_at: {}", calculated_at.to_rfc3339())?;
342    }
343
344    let mut csv_writer = Writer::from_writer(writer);
345
346    // Write header
347    csv_writer.write_record([
348        COL_NETELEMENT_ID,
349        COL_PROBABILITY,
350        COL_START_INTRINSIC,
351        COL_END_INTRINSIC,
352        COL_GNSS_START_INDEX,
353        COL_GNSS_END_INDEX,
354    ])?;
355
356    // Write data rows
357    for segment in &train_path.segments {
358        csv_writer.write_record(&[
359            segment.netelement_id.clone(),
360            segment.probability.to_string(),
361            segment.start_intrinsic.to_string(),
362            segment.end_intrinsic.to_string(),
363            segment.gnss_start_index.to_string(),
364            segment.gnss_end_index.to_string(),
365        ])?;
366    }
367
368    csv_writer.flush()?;
369    Ok(())
370}
371
372/// Parse TrainPath from CSV
373///
374/// Reads a CSV file in the format produced by write_trainpath_csv.
375/// Expects columns: netelement_id, probability, start_intrinsic, end_intrinsic,
376/// gnss_start_index, gnss_end_index
377///
378/// The overall_probability can be specified in a comment line starting with
379/// `# overall_probability:` or will default to the average of segment probabilities.
380///
381/// # Arguments
382///
383/// * `path` - Path to CSV file
384///
385/// # Returns
386///
387/// A TrainPath struct reconstructed from the CSV data
388pub fn parse_trainpath_csv(path: &str) -> Result<TrainPath, ProjectionError> {
389    // Read the file to extract comment lines and filter them out
390    let file_content = std::fs::read_to_string(path)?;
391    let mut overall_probability: Option<f64> = None;
392    let mut calculated_at: Option<chrono::DateTime<chrono::Utc>> = None;
393    let mut csv_lines = Vec::new();
394
395    // Parse comment lines and collect non-comment lines
396    for line in file_content.lines() {
397        if let Some(comment) = line.strip_prefix('#') {
398            let comment = comment.trim();
399            if let Some(value) = comment.strip_prefix("overall_probability:") {
400                overall_probability = value.trim().parse().ok();
401            } else if let Some(value) = comment.strip_prefix("calculated_at:") {
402                if let Ok(dt) = parse_timestamp_flexible_str(value.trim()) {
403                    calculated_at = Some(dt.with_timezone(&chrono::Utc));
404                }
405            }
406        } else {
407            csv_lines.push(line);
408        }
409    }
410
411    // Write filtered CSV to temporary string for polars
412    // Use thread ID and timestamp to avoid race conditions with parallel tests
413    let filtered_csv = csv_lines.join("\n");
414    let unique_id = format!(
415        "{}_{:?}",
416        std::process::id(),
417        std::time::SystemTime::now()
418            .duration_since(std::time::UNIX_EPOCH)
419            .map(|d| d.as_nanos())
420            .unwrap_or(0)
421    );
422    let temp_file = std::env::temp_dir().join(format!("trainpath_{}.csv", unique_id));
423    std::fs::write(&temp_file, filtered_csv)?;
424
425    // Read CSV using polars
426    let df = CsvReadOptions::default()
427        .with_has_header(true)
428        .try_into_reader_with_file_path(Some(temp_file.clone()))
429        .map_err(|e| {
430            ProjectionError::IoError(std::io::Error::new(
431                std::io::ErrorKind::InvalidData,
432                format!("Failed to read TrainPath CSV: {}", e),
433            ))
434        })?
435        .finish()
436        .map_err(|e| {
437            ProjectionError::IoError(std::io::Error::new(
438                std::io::ErrorKind::InvalidData,
439                format!("Failed to parse TrainPath CSV: {}", e),
440            ))
441        })?;
442
443    // Clean up temp file
444    let _ = std::fs::remove_file(temp_file);
445
446    // Handle empty CSV (only headers)
447    if df.height() == 0 {
448        return TrainPath::new(Vec::new(), 1.0, None, None);
449    }
450
451    // Extract columns and cast to correct types
452    let netelement_id = df
453        .column("netelement_id")
454        .map_err(|e| ProjectionError::GeoJsonError(format!("Missing netelement_id column: {}", e)))?
455        .str()
456        .map_err(|e| ProjectionError::GeoJsonError(format!("netelement_id must be string: {}", e)))?
457        .clone();
458
459    let probability_series = df
460        .column("probability")
461        .map_err(|e| ProjectionError::GeoJsonError(format!("Missing probability column: {}", e)))?
462        .cast(&DataType::Float64)
463        .map_err(|e| ProjectionError::GeoJsonError(format!("probability cast failed: {}", e)))?;
464    let probability = probability_series.f64().map_err(|e| {
465        ProjectionError::GeoJsonError(format!("probability must be numeric: {}", e))
466    })?;
467
468    let start_intrinsic_series = df
469        .column("start_intrinsic")
470        .map_err(|e| {
471            ProjectionError::GeoJsonError(format!("Missing start_intrinsic column: {}", e))
472        })?
473        .cast(&DataType::Float64)
474        .map_err(|e| {
475            ProjectionError::GeoJsonError(format!("start_intrinsic cast failed: {}", e))
476        })?;
477    let start_intrinsic = start_intrinsic_series.f64().map_err(|e| {
478        ProjectionError::GeoJsonError(format!("start_intrinsic must be numeric: {}", e))
479    })?;
480
481    let end_intrinsic_series = df
482        .column("end_intrinsic")
483        .map_err(|e| ProjectionError::GeoJsonError(format!("Missing end_intrinsic column: {}", e)))?
484        .cast(&DataType::Float64)
485        .map_err(|e| ProjectionError::GeoJsonError(format!("end_intrinsic cast failed: {}", e)))?;
486    let end_intrinsic = end_intrinsic_series.f64().map_err(|e| {
487        ProjectionError::GeoJsonError(format!("end_intrinsic must be numeric: {}", e))
488    })?;
489
490    let gnss_start_index_series = df
491        .column("gnss_start_index")
492        .map_err(|e| {
493            ProjectionError::GeoJsonError(format!("Missing gnss_start_index column: {}", e))
494        })?
495        .cast(&DataType::UInt32)
496        .map_err(|e| {
497            ProjectionError::GeoJsonError(format!("gnss_start_index cast failed: {}", e))
498        })?;
499    let gnss_start_index = gnss_start_index_series.u32().map_err(|e| {
500        ProjectionError::GeoJsonError(format!("gnss_start_index must be integer: {}", e))
501    })?;
502
503    let gnss_end_index_series = df
504        .column("gnss_end_index")
505        .map_err(|e| {
506            ProjectionError::GeoJsonError(format!("Missing gnss_end_index column: {}", e))
507        })?
508        .cast(&DataType::UInt32)
509        .map_err(|e| ProjectionError::GeoJsonError(format!("gnss_end_index cast failed: {}", e)))?;
510    let gnss_end_index = gnss_end_index_series.u32().map_err(|e| {
511        ProjectionError::GeoJsonError(format!("gnss_end_index must be integer: {}", e))
512    })?;
513
514    // Build segments
515    let mut segments = Vec::new();
516    let row_count = df.height();
517
518    for i in 0..row_count {
519        let id = netelement_id
520            .get(i)
521            .ok_or_else(|| {
522                ProjectionError::GeoJsonError(format!("Missing netelement_id at row {}", i))
523            })?
524            .to_string();
525
526        let prob = probability.get(i).ok_or_else(|| {
527            ProjectionError::GeoJsonError(format!("Missing probability at row {}", i))
528        })?;
529
530        let start_intr = start_intrinsic.get(i).ok_or_else(|| {
531            ProjectionError::GeoJsonError(format!("Missing start_intrinsic at row {}", i))
532        })?;
533
534        let end_intr = end_intrinsic.get(i).ok_or_else(|| {
535            ProjectionError::GeoJsonError(format!("Missing end_intrinsic at row {}", i))
536        })?;
537
538        let start_idx = gnss_start_index.get(i).ok_or_else(|| {
539            ProjectionError::GeoJsonError(format!("Missing gnss_start_index at row {}", i))
540        })? as usize;
541
542        let end_idx = gnss_end_index.get(i).ok_or_else(|| {
543            ProjectionError::GeoJsonError(format!("Missing gnss_end_index at row {}", i))
544        })? as usize;
545
546        let segment =
547            AssociatedNetElement::new(id, prob, start_intr, end_intr, start_idx, end_idx)?;
548
549        segments.push(segment);
550    }
551
552    // Calculate overall probability if not provided
553    let overall_prob = overall_probability.unwrap_or_else(|| {
554        let sum: f64 = segments.iter().map(|s| s.probability).sum();
555        sum / segments.len() as f64
556    });
557
558    // Create TrainPath
559    TrainPath::new(segments, overall_prob, calculated_at, None)
560}
561
562#[cfg(test)]
563mod tests;
564
565pub mod detections;