@@ -1201,6 +1201,128 @@ async fn it_can_copy_out() -> anyhow::Result<()> {
1201
1201
Ok ( ( ) )
1202
1202
}
1203
1203
1204
+ #[ sqlx_macros:: test]
1205
+ async fn it_encodes_custom_array_issue_1504 ( ) -> anyhow:: Result < ( ) > {
1206
+ use sqlx:: encode:: IsNull ;
1207
+ use sqlx:: postgres:: { PgArgumentBuffer , PgTypeInfo } ;
1208
+ use sqlx:: { Decode , Encode , Type , ValueRef } ;
1209
+
1210
+ #[ derive( Debug , PartialEq ) ]
1211
+ enum Value {
1212
+ String ( String ) ,
1213
+ Number ( i32 ) ,
1214
+ Array ( Vec < Value > ) ,
1215
+ }
1216
+
1217
+ impl < ' r > Decode < ' r , Postgres > for Value {
1218
+ fn decode (
1219
+ value : sqlx:: postgres:: PgValueRef < ' r > ,
1220
+ ) -> std:: result:: Result < Self , Box < dyn std:: error:: Error + ' static + Send + Sync > > {
1221
+ let typ = value. type_info ( ) . into_owned ( ) ;
1222
+
1223
+ if typ == PgTypeInfo :: with_name ( "text" ) {
1224
+ let s = <String as Decode < ' _ , Postgres > >:: decode ( value) ?;
1225
+
1226
+ Ok ( Self :: String ( s) )
1227
+ } else if typ == PgTypeInfo :: with_name ( "int4" ) {
1228
+ let n = <i32 as Decode < ' _ , Postgres > >:: decode ( value) ?;
1229
+
1230
+ Ok ( Self :: Number ( n) )
1231
+ } else if typ == PgTypeInfo :: with_name ( "_text" ) {
1232
+ let arr = Vec :: < String > :: decode ( value) ?;
1233
+ let v = arr. into_iter ( ) . map ( |s| Value :: String ( s) ) . collect ( ) ;
1234
+
1235
+ Ok ( Self :: Array ( v) )
1236
+ } else if typ == PgTypeInfo :: with_name ( "_int4" ) {
1237
+ let arr = Vec :: < i32 > :: decode ( value) ?;
1238
+ let v = arr. into_iter ( ) . map ( |n| Value :: Number ( n) ) . collect ( ) ;
1239
+
1240
+ Ok ( Self :: Array ( v) )
1241
+ } else {
1242
+ Err ( "unknown type" . into ( ) )
1243
+ }
1244
+ }
1245
+ }
1246
+
1247
+ impl Encode < ' _ , Postgres > for Value {
1248
+ fn produces ( & self ) -> Option < PgTypeInfo > {
1249
+ match self {
1250
+ Self :: Array ( a) => {
1251
+ if a. len ( ) < 1 {
1252
+ return Some ( PgTypeInfo :: with_name ( "_text" ) ) ;
1253
+ }
1254
+
1255
+ match a[ 0 ] {
1256
+ Self :: String ( _) => Some ( PgTypeInfo :: with_name ( "_text" ) ) ,
1257
+ Self :: Number ( _) => Some ( PgTypeInfo :: with_name ( "_int4" ) ) ,
1258
+ Self :: Array ( _) => None ,
1259
+ }
1260
+ }
1261
+ Self :: String ( _) => Some ( PgTypeInfo :: with_name ( "text" ) ) ,
1262
+ Self :: Number ( _) => Some ( PgTypeInfo :: with_name ( "int4" ) ) ,
1263
+ }
1264
+ }
1265
+
1266
+ fn encode_by_ref ( & self , buf : & mut PgArgumentBuffer ) -> IsNull {
1267
+ match self {
1268
+ Value :: String ( s) => <String as Encode < ' _ , Postgres > >:: encode_by_ref ( s, buf) ,
1269
+ Value :: Number ( n) => <i32 as Encode < ' _ , Postgres > >:: encode_by_ref ( n, buf) ,
1270
+ Value :: Array ( arr) => arr. encode ( buf) ,
1271
+ }
1272
+ }
1273
+ }
1274
+
1275
+ impl Type < Postgres > for Value {
1276
+ fn type_info ( ) -> PgTypeInfo {
1277
+ PgTypeInfo :: with_name ( "unknown" )
1278
+ }
1279
+
1280
+ fn compatible ( ty : & PgTypeInfo ) -> bool {
1281
+ [
1282
+ PgTypeInfo :: with_name ( "text" ) ,
1283
+ PgTypeInfo :: with_name ( "_text" ) ,
1284
+ PgTypeInfo :: with_name ( "int4" ) ,
1285
+ PgTypeInfo :: with_name ( "_int4" ) ,
1286
+ ]
1287
+ . contains ( ty)
1288
+ }
1289
+ }
1290
+
1291
+ let mut conn = new :: < Postgres > ( ) . await ?;
1292
+
1293
+ let ( row, ) : ( Value , ) = sqlx:: query_as ( "SELECT $1::text[] as Dummy" )
1294
+ . bind ( Value :: Array ( vec ! [
1295
+ Value :: String ( "Test 0" . to_string( ) ) ,
1296
+ Value :: String ( "Test 1" . to_string( ) ) ,
1297
+ ] ) )
1298
+ . fetch_one ( & mut conn)
1299
+ . await ?;
1300
+
1301
+ assert_eq ! (
1302
+ row,
1303
+ Value :: Array ( vec![
1304
+ Value :: String ( "Test 0" . to_string( ) ) ,
1305
+ Value :: String ( "Test 1" . to_string( ) ) ,
1306
+ ] )
1307
+ ) ;
1308
+
1309
+ let ( row, ) : ( Value , ) = sqlx:: query_as ( "SELECT $1::int4[] as Dummy" )
1310
+ . bind ( Value :: Array ( vec ! [
1311
+ Value :: Number ( 3 ) ,
1312
+ Value :: Number ( 2 ) ,
1313
+ Value :: Number ( 1 ) ,
1314
+ ] ) )
1315
+ . fetch_one ( & mut conn)
1316
+ . await ?;
1317
+
1318
+ assert_eq ! (
1319
+ row,
1320
+ Value :: Array ( vec![ Value :: Number ( 3 ) , Value :: Number ( 2 ) , Value :: Number ( 1 ) ] )
1321
+ ) ;
1322
+
1323
+ Ok ( ( ) )
1324
+ }
1325
+
1204
1326
#[ sqlx_macros:: test]
1205
1327
async fn test_issue_1254 ( ) -> anyhow:: Result < ( ) > {
1206
1328
#[ derive( sqlx:: Type ) ]
0 commit comments