Skip to content

Commit dee5147

Browse files
authored
Postgres: Array enum encoding (launchbadge#1511)
* Postgres: Add test for array enum * Allow produces() to override type_info() as per doc * run cargo fmt
1 parent 04109d9 commit dee5147

File tree

3 files changed

+130
-2
lines changed

3 files changed

+130
-2
lines changed

sqlx-core/src/postgres/types/array.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,11 +69,17 @@ where
6969
T: Encode<'q, Postgres> + Type<Postgres>,
7070
{
7171
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull {
72+
let type_info = if self.len() < 1 {
73+
T::type_info()
74+
} else {
75+
self[0].produces().unwrap_or_else(T::type_info)
76+
};
77+
7278
buf.extend(&1_i32.to_be_bytes()); // number of dimensions
7379
buf.extend(&0_i32.to_be_bytes()); // flags
7480

7581
// element type
76-
match T::type_info().0 {
82+
match type_info.0 {
7783
PgType::DeclareWithName(name) => buf.patch_type_by_name(&name),
7884

7985
ty => {

sqlx-core/src/postgres/types/record.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ impl<'a> PgRecordEncoder<'a> {
3838
'a: 'q,
3939
T: Encode<'q, Postgres> + Type<Postgres>,
4040
{
41-
let ty = T::type_info();
41+
let ty = value.produces().unwrap_or_else(T::type_info);
4242

4343
if let PgType::DeclareWithName(name) = ty.0 {
4444
// push a hole for this type ID

tests/postgres/postgres.rs

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1201,6 +1201,128 @@ async fn it_can_copy_out() -> anyhow::Result<()> {
12011201
Ok(())
12021202
}
12031203

1204+
#[sqlx_macros::test]
1205+
async fn it_encodes_custom_array_issue_1504() -> anyhow::Result<()> {
1206+
use sqlx::encode::IsNull;
1207+
use sqlx::postgres::{PgArgumentBuffer, PgTypeInfo};
1208+
use sqlx::{Decode, Encode, Type, ValueRef};
1209+
1210+
#[derive(Debug, PartialEq)]
1211+
enum Value {
1212+
String(String),
1213+
Number(i32),
1214+
Array(Vec<Value>),
1215+
}
1216+
1217+
impl<'r> Decode<'r, Postgres> for Value {
1218+
fn decode(
1219+
value: sqlx::postgres::PgValueRef<'r>,
1220+
) -> std::result::Result<Self, Box<dyn std::error::Error + 'static + Send + Sync>> {
1221+
let typ = value.type_info().into_owned();
1222+
1223+
if typ == PgTypeInfo::with_name("text") {
1224+
let s = <String as Decode<'_, Postgres>>::decode(value)?;
1225+
1226+
Ok(Self::String(s))
1227+
} else if typ == PgTypeInfo::with_name("int4") {
1228+
let n = <i32 as Decode<'_, Postgres>>::decode(value)?;
1229+
1230+
Ok(Self::Number(n))
1231+
} else if typ == PgTypeInfo::with_name("_text") {
1232+
let arr = Vec::<String>::decode(value)?;
1233+
let v = arr.into_iter().map(|s| Value::String(s)).collect();
1234+
1235+
Ok(Self::Array(v))
1236+
} else if typ == PgTypeInfo::with_name("_int4") {
1237+
let arr = Vec::<i32>::decode(value)?;
1238+
let v = arr.into_iter().map(|n| Value::Number(n)).collect();
1239+
1240+
Ok(Self::Array(v))
1241+
} else {
1242+
Err("unknown type".into())
1243+
}
1244+
}
1245+
}
1246+
1247+
impl Encode<'_, Postgres> for Value {
1248+
fn produces(&self) -> Option<PgTypeInfo> {
1249+
match self {
1250+
Self::Array(a) => {
1251+
if a.len() < 1 {
1252+
return Some(PgTypeInfo::with_name("_text"));
1253+
}
1254+
1255+
match a[0] {
1256+
Self::String(_) => Some(PgTypeInfo::with_name("_text")),
1257+
Self::Number(_) => Some(PgTypeInfo::with_name("_int4")),
1258+
Self::Array(_) => None,
1259+
}
1260+
}
1261+
Self::String(_) => Some(PgTypeInfo::with_name("text")),
1262+
Self::Number(_) => Some(PgTypeInfo::with_name("int4")),
1263+
}
1264+
}
1265+
1266+
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull {
1267+
match self {
1268+
Value::String(s) => <String as Encode<'_, Postgres>>::encode_by_ref(s, buf),
1269+
Value::Number(n) => <i32 as Encode<'_, Postgres>>::encode_by_ref(n, buf),
1270+
Value::Array(arr) => arr.encode(buf),
1271+
}
1272+
}
1273+
}
1274+
1275+
impl Type<Postgres> for Value {
1276+
fn type_info() -> PgTypeInfo {
1277+
PgTypeInfo::with_name("unknown")
1278+
}
1279+
1280+
fn compatible(ty: &PgTypeInfo) -> bool {
1281+
[
1282+
PgTypeInfo::with_name("text"),
1283+
PgTypeInfo::with_name("_text"),
1284+
PgTypeInfo::with_name("int4"),
1285+
PgTypeInfo::with_name("_int4"),
1286+
]
1287+
.contains(ty)
1288+
}
1289+
}
1290+
1291+
let mut conn = new::<Postgres>().await?;
1292+
1293+
let (row,): (Value,) = sqlx::query_as("SELECT $1::text[] as Dummy")
1294+
.bind(Value::Array(vec![
1295+
Value::String("Test 0".to_string()),
1296+
Value::String("Test 1".to_string()),
1297+
]))
1298+
.fetch_one(&mut conn)
1299+
.await?;
1300+
1301+
assert_eq!(
1302+
row,
1303+
Value::Array(vec![
1304+
Value::String("Test 0".to_string()),
1305+
Value::String("Test 1".to_string()),
1306+
])
1307+
);
1308+
1309+
let (row,): (Value,) = sqlx::query_as("SELECT $1::int4[] as Dummy")
1310+
.bind(Value::Array(vec![
1311+
Value::Number(3),
1312+
Value::Number(2),
1313+
Value::Number(1),
1314+
]))
1315+
.fetch_one(&mut conn)
1316+
.await?;
1317+
1318+
assert_eq!(
1319+
row,
1320+
Value::Array(vec![Value::Number(3), Value::Number(2), Value::Number(1)])
1321+
);
1322+
1323+
Ok(())
1324+
}
1325+
12041326
#[sqlx_macros::test]
12051327
async fn test_issue_1254() -> anyhow::Result<()> {
12061328
#[derive(sqlx::Type)]

0 commit comments

Comments
 (0)