Skip to content

Commit cfaa491

Browse files
authored
Move csv_utils to public repo (#387)
* Move csv_utils to public repo * test_utils can't be cfg(test)
1 parent 796e48e commit cfaa491

File tree

5 files changed

+374
-0
lines changed

5 files changed

+374
-0
lines changed

ciphercore-base/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ tonic = { version = "0.9.2", default-features = false, features = [
3737
], optional = true }
3838
aes = "0.8.2"
3939
cipher = { version = "0.4.4", features = ["block-padding"] }
40+
csv = "1.1"
4041

4142
[target.'cfg(target_arch = "wasm32")'.dependencies]
4243
wasm-bindgen = { version = "0.2.86", features = ["serde-serialize"] }

ciphercore-base/src/csv/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
pub mod test_utils;
2+
pub mod utils;
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
use crate::errors::Result;
2+
3+
fn assert_rows_eq(rows: Vec<csv::StringRecord>, expected_rows: Vec<Vec<&str>>) {
4+
assert_eq!(rows.len(), expected_rows.len());
5+
for (i, (csv_row, expected_row)) in rows.into_iter().zip(expected_rows.into_iter()).enumerate()
6+
{
7+
assert_eq!(csv_row, expected_row, "row {i}");
8+
}
9+
}
10+
11+
pub fn assert_table_eq(
12+
csv_bytes: Vec<u8>,
13+
expected_headers: Vec<&str>,
14+
expected_records: Vec<Vec<&str>>,
15+
) -> Result<()> {
16+
let mut csv = csv::Reader::from_reader(csv_bytes.as_slice());
17+
assert_eq!(csv.headers()?, expected_headers);
18+
let csv_rows = csv
19+
.records()
20+
.collect::<std::result::Result<Vec<csv::StringRecord>, csv::Error>>()?;
21+
assert_rows_eq(csv_rows, expected_records);
22+
Ok(())
23+
}
24+
25+
pub fn assert_sorted_table_eq(
26+
csv_bytes: Vec<u8>,
27+
expected_headers: Vec<&str>,
28+
expected_records: Vec<Vec<&str>>,
29+
) -> Result<()> {
30+
let mut csv = csv::Reader::from_reader(csv_bytes.as_slice());
31+
assert_eq!(csv.headers()?, expected_headers);
32+
let mut csv_rows = csv
33+
.records()
34+
.collect::<std::result::Result<Vec<csv::StringRecord>, csv::Error>>()?;
35+
csv_rows.sort_by(|r1, r2: &csv::StringRecord| r1[0].cmp(&r2[0]));
36+
assert_rows_eq(csv_rows, expected_records);
37+
Ok(())
38+
}
39+
40+
pub fn assert_table_unordered_eq(
41+
csv_bytes: Vec<u8>,
42+
expected_headers: Vec<&str>,
43+
expected_records: Vec<Vec<&str>>,
44+
) -> Result<()> {
45+
let mut csv = csv::Reader::from_reader(csv_bytes.as_slice());
46+
assert_eq!(csv.headers()?, expected_headers);
47+
let mut csv_rows = csv
48+
.records()
49+
.collect::<std::result::Result<Vec<csv::StringRecord>, csv::Error>>()?;
50+
csv_rows.sort_by_key(|row| {
51+
row.clone()
52+
.into_iter()
53+
.map(|s| s.to_string())
54+
.collect::<Vec<String>>()
55+
});
56+
let mut expected_records = expected_records;
57+
expected_records.sort();
58+
assert_rows_eq(csv_rows, expected_records);
59+
Ok(())
60+
}

ciphercore-base/src/csv/utils.rs

Lines changed: 309 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,309 @@
1+
use crate::data_types::Type;
2+
use crate::errors::Result;
3+
use crate::runtime_error;
4+
use crate::typed_value::TypedValue;
5+
6+
/// A single named column of textual cells, ready to be written as CSV.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Column {
    /// Column name; used as the header cell when the table is written.
    pub name: String,
    /// One entry per row. `None` marks a masked-out cell, which is rendered
    /// as an empty string when the table is assembled.
    pub data: Vec<Option<String>>,
}
10+
11+
pub fn output_string_column(
12+
name: String,
13+
body: TypedValue,
14+
len: TypedValue,
15+
mask: &[u8],
16+
) -> Result<Column> {
17+
let n = match len.t.get_shape().as_slice() {
18+
&[n] => n as usize,
19+
_ => return Err(runtime_error!("len type {:?}", len.t)),
20+
};
21+
let max_len = match body.t.get_shape().as_slice() {
22+
&[rows, max_len] => {
23+
if rows as usize != n {
24+
return Err(runtime_error!("len {:?} vs body {:?}", body.t, len.t));
25+
}
26+
max_len as usize
27+
}
28+
_ => return Err(runtime_error!("body type {:?}", body.t)),
29+
};
30+
if mask.len() != n {
31+
return Err(runtime_error!("mask.len() != n: {} vs {}", mask.len(), n));
32+
}
33+
let len = len.value.to_flattened_array_u16(len.t)?;
34+
let mut data = vec![];
35+
body.value.access_bytes(|bytes| {
36+
for i in 0..n {
37+
data.push(if mask[i] == 1 {
38+
let offset = i * max_len;
39+
let b = bytes[offset..offset + len[i] as usize].to_vec();
40+
Some(String::from_utf8(b).map_err(|e| runtime_error!("UTF8 error: {e}"))?)
41+
} else {
42+
None
43+
});
44+
}
45+
Ok(())
46+
})?;
47+
Ok(Column { name, data })
48+
}
49+
50+
fn map_or_mask<T>(
51+
values: Vec<T>,
52+
mask: &[u8],
53+
f: impl Fn(T) -> String,
54+
) -> Result<Vec<Option<String>>> {
55+
let mut data = vec![];
56+
for (value, &mask) in values.into_iter().zip(mask.iter()) {
57+
data.push(if mask == 0 { None } else { Some(f(value)) });
58+
}
59+
Ok(data)
60+
}
61+
62+
pub fn output_int_column(name: String, value: TypedValue, mask: &[u8]) -> Result<Column> {
63+
validate_shapes(&value, mask)?;
64+
let data = if value.t.get_scalar_type().is_signed() {
65+
let values = value.value.to_flattened_array_i64(value.t)?;
66+
map_or_mask(values, mask, |v| v.to_string())?
67+
} else {
68+
let values = value.value.to_flattened_array_u64(value.t)?;
69+
map_or_mask(values, mask, |v| v.to_string())?
70+
};
71+
Ok(Column { name, data })
72+
}
73+
74+
pub fn output_float_column(
75+
name: String,
76+
value: TypedValue,
77+
fractional_bits: usize,
78+
float_decimal_places: usize,
79+
mask: &[u8],
80+
) -> Result<Column> {
81+
validate_shapes(&value, mask)?;
82+
let data = if value.t.get_scalar_type().is_signed() {
83+
let values = value.value.to_flattened_array_i64(value.t)?;
84+
map_or_mask(values, mask, |v| {
85+
format!(
86+
"{:.prec$}",
87+
v as f64 / (1 << fractional_bits) as f64,
88+
prec = float_decimal_places
89+
)
90+
})?
91+
} else {
92+
let values = value.value.to_flattened_array_u64(value.t)?;
93+
map_or_mask(values, mask, |v| {
94+
format!(
95+
"{:.prec$}",
96+
v as f64 / (1 << fractional_bits) as f64,
97+
prec = float_decimal_places
98+
)
99+
})?
100+
};
101+
Ok(Column { name, data })
102+
}
103+
104+
pub fn output_bool_column(name: String, value: TypedValue, mask: &[u8]) -> Result<Column> {
105+
validate_shapes(&value, mask)?;
106+
let values = value.value.to_flattened_array_u8(value.t)?;
107+
let data = map_or_mask(values, mask, |v| {
108+
if v == 0 {
109+
"false".to_string()
110+
} else {
111+
"true".to_string()
112+
}
113+
})?;
114+
Ok(Column { name, data })
115+
}
116+
117+
/// Checks that `value` is a 1-D array with exactly `mask.len()` elements.
///
/// # Errors
///
/// Returns an error when the shape of `value` does not have exactly one
/// dimension, or when that dimension differs from the mask length.
fn validate_shapes(value: &TypedValue, mask: &[u8]) -> Result<()> {
    let shape = value.t.get_shape();
    let n = if let [dim] = shape.as_slice() {
        *dim as usize
    } else {
        return Err(runtime_error!("value type {:?}", value.t));
    };
    if mask.len() == n {
        Ok(())
    } else {
        Err(runtime_error!("mask.len() != n: {} vs {}", mask.len(), n))
    }
}
127+
128+
pub fn output_table(columns: &Vec<Column>) -> Result<Vec<Vec<String>>> {
129+
if columns.is_empty() {
130+
return Err(runtime_error!("write_rows: empty columns list"));
131+
}
132+
let n = columns[0].data.len();
133+
let mut table = vec![];
134+
for i in 0..n {
135+
let row_iter = columns.iter().map(|column| &column.data[i]);
136+
if row_iter.clone().all(|cell| cell.is_none()) {
137+
// Skip empty rows.
138+
continue;
139+
}
140+
table.push(
141+
row_iter
142+
.map(|cell| match cell {
143+
None => "".to_owned(),
144+
Some(val) => val.clone(),
145+
})
146+
.collect::<Vec<String>>(),
147+
);
148+
}
149+
Ok(table)
150+
}
151+
152+
pub fn write_table(
153+
mut columns: Vec<Column>,
154+
sort_columns: bool,
155+
sort_rows: bool,
156+
) -> Result<Vec<u8>> {
157+
if sort_columns {
158+
columns.sort_by(|c1, c2| c1.name.cmp(&c2.name));
159+
}
160+
let mut table = output_table(&columns)?;
161+
if sort_rows {
162+
table.sort();
163+
}
164+
let header = columns.into_iter().map(|column| column.name).collect();
165+
write_to_csv(header, table)
166+
}
167+
168+
fn write_to_csv(header: Vec<String>, table: Vec<Vec<String>>) -> Result<Vec<u8>> {
169+
let mut wtr = csv::Writer::from_writer(vec![]);
170+
wtr.write_record(header)?;
171+
for row in table {
172+
wtr.write_record(row)?;
173+
}
174+
wtr.into_inner()
175+
.map_err(|err| runtime_error!("Error: {}", err))
176+
}
177+
178+
pub fn unpack_named_tuple(value: TypedValue) -> Result<Vec<(String, TypedValue)>> {
179+
let name_and_types = match value.t.clone() {
180+
Type::NamedTuple(elements) => elements,
181+
t => return Err(runtime_error!("Expected NamedTuple, got {:?}", t)),
182+
};
183+
let values = value.value.to_vector()?;
184+
if name_and_types.len() != values.len() {
185+
return Err(runtime_error!("Inconsistent data"));
186+
}
187+
let mut result = vec![];
188+
for ((name, t), value) in name_and_types.into_iter().zip(values.into_iter()) {
189+
result.push((
190+
name,
191+
TypedValue {
192+
value,
193+
t: t.as_ref().clone(),
194+
name: None,
195+
},
196+
));
197+
}
198+
Ok(result)
199+
}
200+
201+
pub fn unpack_tuple(value: TypedValue) -> Result<Vec<TypedValue>> {
202+
let types = match value.t.clone() {
203+
Type::Tuple(elements) => elements,
204+
t => return Err(runtime_error!("Expected Tuple, got {:?}", t)),
205+
};
206+
let values = value.value.to_vector()?;
207+
if types.len() != values.len() {
208+
return Err(runtime_error!("Inconsistent data"));
209+
}
210+
let mut result = vec![];
211+
for (t, value) in types.into_iter().zip(values.into_iter()) {
212+
result.push(TypedValue {
213+
value,
214+
t: t.as_ref().clone(),
215+
name: None,
216+
});
217+
}
218+
Ok(result)
219+
}
220+
221+
pub fn unpack_pair(value: TypedValue) -> Result<(TypedValue, TypedValue)> {
222+
let values = unpack_tuple(value)?;
223+
match values.as_slice() {
224+
[first, second] => Ok((first.clone(), second.clone())),
225+
_ => Err(runtime_error!("Expected tuple of size 2")),
226+
}
227+
}
228+
229+
/// Unpacks a `(data, mask)` tuple, returning the data element unchanged and
/// the mask flattened into a vector of `u8`.
///
/// # Errors
///
/// Returns an error when `value` is not a two-element tuple or when the mask
/// cannot be flattened into `u8` values.
pub fn extract_data_mask_pair(value: TypedValue) -> Result<(TypedValue, Vec<u8>)> {
    let (data, mask_value) = unpack_pair(value)?;
    mask_value
        .value
        .to_flattened_array_u8(mask_value.t)
        .map(|mask_bits| (data, mask_bits))
}
234+
235+
#[cfg(test)]
mod tests {
    use ndarray::array;

    use super::*;
    use crate::{
        csv::test_utils::assert_table_eq,
        data_types::{BIT, INT64, UINT8},
        typed_value_operations::TypedValueArrayOperations,
    };

    /// One column of every kind, each with a different masked-out row,
    /// written without any sorting.
    #[test]
    fn test_output_csv() -> Result<()> {
        let c1 = output_int_column(
            "d".into(),
            TypedValue::from_ndarray(array![1, 2, 3, 4].into_dyn(), INT64)?,
            &[1, 1, 1, 0],
        )?;
        let c2 = output_float_column(
            "c".into(),
            TypedValue::from_ndarray(array![128, 256, 512, 1024].into_dyn(), INT64)?,
            10,
            3,
            &[0, 1, 1, 1],
        )?;
        let c3 = output_bool_column(
            "b".into(),
            TypedValue::from_ndarray(array![1, 0, 0, 1].into_dyn(), BIT)?,
            &[1, 0, 1, 1],
        )?;
        let c4 = output_string_column(
            "a".into(),
            TypedValue::from_ndarray(
                array![[65, 66, 0], [70, 0, 0], [75, 76, 77], [80, 81, 0]].into_dyn(),
                UINT8,
            )?,
            TypedValue::from_ndarray(array![2, 1, 3, 2].into_dyn(), INT64)?,
            &[1, 1, 1, 1],
        )?;
        assert_table_eq(
            write_table(vec![c1, c2, c3, c4], false, false)?,
            vec!["d", "c", "b", "a"],
            vec![
                vec!["1", "", "true", "AB"],
                vec!["2", "0.250", "", "F"],
                vec!["3", "0.500", "false", "KLM"],
                vec!["", "1.000", "true", "PQ"],
            ],
        )
    }

    /// Columns are ordered by name and rows are sorted as strings, so the
    /// masked (empty) cell sorts first and "100" sorts before "20".
    // Nothing here awaits, so a plain synchronous test is enough.
    #[test]
    fn test_output_sorted_csv() -> Result<()> {
        let c1 = output_int_column(
            "salary".into(),
            TypedValue::from_ndarray(array![1000, 2000, 3000, 4000].into_dyn(), INT64)?,
            &[1, 1, 1, 1],
        )?;
        let c2 = output_int_column(
            "age".into(),
            TypedValue::from_ndarray(array![10, 20, 100, 30].into_dyn(), INT64)?,
            &[1, 1, 1, 0],
        )?;
        assert_table_eq(
            write_table(vec![c1, c2], true, true)?,
            vec!["age", "salary"],
            vec![
                vec!["", "4000"],
                vec!["10", "1000"],
                vec!["100", "3000"],
                vec!["20", "2000"],
            ],
        )
    }
}

ciphercore-base/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -896,6 +896,8 @@ pub mod broadcast;
896896
#[doc(hidden)]
897897
pub mod bytes;
898898
mod constants;
899+
#[doc(hidden)]
900+
pub mod csv;
899901
pub mod custom_ops;
900902
pub mod data_types;
901903
pub mod data_values;

0 commit comments

Comments
 (0)