1919
2020use std:: collections:: HashMap ;
2121use std:: sync:: Arc ;
22+ use std:: vec;
2223
24+ use datafusion:: arrow:: array:: { Array , StringArray } ;
2325use datafusion:: arrow:: datatypes:: DataType ;
2426use datafusion:: execution:: context:: SessionContext ;
2527use iceberg:: io:: FileIOBuilder ;
26- use iceberg:: spec:: { NestedField , PrimitiveType , Schema , Type } ;
28+ use iceberg:: spec:: { NestedField , PrimitiveType , Schema , StructType , Type } ;
2729use iceberg:: { Catalog , NamespaceIdent , Result , TableCreation } ;
2830use iceberg_catalog_memory:: MemoryCatalog ;
2931use iceberg_datafusion:: IcebergCatalogProvider ;
@@ -39,6 +41,13 @@ fn get_iceberg_catalog() -> MemoryCatalog {
3941 MemoryCatalog :: new ( file_io, Some ( temp_path ( ) ) )
4042}
4143
44+ fn get_struct_type ( ) -> StructType {
45+ StructType :: new ( vec ! [
46+ NestedField :: required( 4 , "s_foo1" , Type :: Primitive ( PrimitiveType :: Int ) ) . into( ) ,
47+ NestedField :: required( 5 , "s_foo2" , Type :: Primitive ( PrimitiveType :: String ) ) . into( ) ,
48+ ] )
49+ }
50+
4251async fn set_test_namespace ( catalog : & MemoryCatalog , namespace : & NamespaceIdent ) -> Result < ( ) > {
4352 let properties = HashMap :: new ( ) ;
4453
@@ -47,14 +56,21 @@ async fn set_test_namespace(catalog: &MemoryCatalog, namespace: &NamespaceIdent)
4756 Ok ( ( ) )
4857}
4958
50- fn set_table_creation ( location : impl ToString , name : impl ToString ) -> Result < TableCreation > {
51- let schema = Schema :: builder ( )
52- . with_schema_id ( 0 )
53- . with_fields ( vec ! [
54- NestedField :: required( 1 , "foo" , Type :: Primitive ( PrimitiveType :: Int ) ) . into( ) ,
55- NestedField :: required( 2 , "bar" , Type :: Primitive ( PrimitiveType :: String ) ) . into( ) ,
56- ] )
57- . build ( ) ?;
59+ fn get_table_creation (
60+ location : impl ToString ,
61+ name : impl ToString ,
62+ schema : Option < Schema > ,
63+ ) -> Result < TableCreation > {
64+ let schema = match schema {
65+ None => Schema :: builder ( )
66+ . with_schema_id ( 0 )
67+ . with_fields ( vec ! [
68+ NestedField :: required( 1 , "foo1" , Type :: Primitive ( PrimitiveType :: Int ) ) . into( ) ,
69+ NestedField :: required( 2 , "foo2" , Type :: Primitive ( PrimitiveType :: String ) ) . into( ) ,
70+ ] )
71+ . build ( ) ?,
72+ Some ( schema) => schema,
73+ } ;
5874
5975 let creation = TableCreation :: builder ( )
6076 . location ( location. to_string ( ) )
@@ -72,7 +88,7 @@ async fn test_provider_get_table_schema() -> Result<()> {
7288 let namespace = NamespaceIdent :: new ( "test_provider_get_table_schema" . to_string ( ) ) ;
7389 set_test_namespace ( & iceberg_catalog, & namespace) . await ?;
7490
75- let creation = set_table_creation ( temp_path ( ) , "my_table" ) ?;
91+ let creation = get_table_creation ( temp_path ( ) , "my_table" , None ) ?;
7692 iceberg_catalog. create_table ( & namespace, creation) . await ?;
7793
7894 let client = Arc :: new ( iceberg_catalog) ;
@@ -87,7 +103,7 @@ async fn test_provider_get_table_schema() -> Result<()> {
87103 let table = schema. table ( "my_table" ) . await . unwrap ( ) . unwrap ( ) ;
88104 let table_schema = table. schema ( ) ;
89105
90- let expected = [ ( "foo " , & DataType :: Int32 ) , ( "bar " , & DataType :: Utf8 ) ] ;
106+ let expected = [ ( "foo1 " , & DataType :: Int32 ) , ( "foo2 " , & DataType :: Utf8 ) ] ;
91107
92108 for ( field, exp) in table_schema. fields ( ) . iter ( ) . zip ( expected. iter ( ) ) {
93109 assert_eq ! ( field. name( ) , exp. 0 ) ;
@@ -104,7 +120,7 @@ async fn test_provider_list_table_names() -> Result<()> {
104120 let namespace = NamespaceIdent :: new ( "test_provider_list_table_names" . to_string ( ) ) ;
105121 set_test_namespace ( & iceberg_catalog, & namespace) . await ?;
106122
107- let creation = set_table_creation ( temp_path ( ) , "my_table" ) ?;
123+ let creation = get_table_creation ( temp_path ( ) , "my_table" , None ) ?;
108124 iceberg_catalog. create_table ( & namespace, creation) . await ?;
109125
110126 let client = Arc :: new ( iceberg_catalog) ;
@@ -130,7 +146,6 @@ async fn test_provider_list_schema_names() -> Result<()> {
130146 let namespace = NamespaceIdent :: new ( "test_provider_list_schema_names" . to_string ( ) ) ;
131147 set_test_namespace ( & iceberg_catalog, & namespace) . await ?;
132148
133- set_table_creation ( "test_provider_list_schema_names" , "my_table" ) ?;
134149 let client = Arc :: new ( iceberg_catalog) ;
135150 let catalog = Arc :: new ( IcebergCatalogProvider :: try_new ( client) . await ?) ;
136151
@@ -147,3 +162,71 @@ async fn test_provider_list_schema_names() -> Result<()> {
147162 . all( |item| result. contains( & item. to_string( ) ) ) ) ;
148163 Ok ( ( ) )
149164}
165+
166+ #[ tokio:: test]
167+ async fn test_table_projection ( ) -> Result < ( ) > {
168+ let iceberg_catalog = get_iceberg_catalog ( ) ;
169+ let namespace = NamespaceIdent :: new ( "ns" . to_string ( ) ) ;
170+ set_test_namespace ( & iceberg_catalog, & namespace) . await ?;
171+
172+ let schema = Schema :: builder ( )
173+ . with_schema_id ( 0 )
174+ . with_fields ( vec ! [
175+ NestedField :: required( 1 , "foo1" , Type :: Primitive ( PrimitiveType :: Int ) ) . into( ) ,
176+ NestedField :: required( 2 , "foo2" , Type :: Primitive ( PrimitiveType :: String ) ) . into( ) ,
177+ NestedField :: optional( 3 , "foo3" , Type :: Struct ( get_struct_type( ) ) ) . into( ) ,
178+ ] )
179+ . build ( ) ?;
180+ let creation = get_table_creation ( temp_path ( ) , "t1" , Some ( schema) ) ?;
181+ iceberg_catalog. create_table ( & namespace, creation) . await ?;
182+
183+ let client = Arc :: new ( iceberg_catalog) ;
184+ let catalog = Arc :: new ( IcebergCatalogProvider :: try_new ( client) . await ?) ;
185+
186+ let ctx = SessionContext :: new ( ) ;
187+ ctx. register_catalog ( "catalog" , catalog) ;
188+ let table_df = ctx. table ( "catalog.ns.t1" ) . await . unwrap ( ) ;
189+
190+ let records = table_df
191+ . clone ( )
192+ . explain ( false , false )
193+ . unwrap ( )
194+ . collect ( )
195+ . await
196+ . unwrap ( ) ;
197+ assert_eq ! ( 1 , records. len( ) ) ;
198+ let record = & records[ 0 ] ;
199+ // the first column is plan_type, the second column plan string.
200+ let s = record
201+ . column ( 1 )
202+ . as_any ( )
203+ . downcast_ref :: < StringArray > ( )
204+ . unwrap ( ) ;
205+ assert_eq ! ( 2 , s. len( ) ) ;
206+ // the first row is logical_plan, the second row is physical_plan
207+ assert_eq ! (
208+ "IcebergTableScan projection:[foo1,foo2,foo3]" ,
209+ s. value( 1 ) . trim( )
210+ ) ;
211+
212+ // datafusion doesn't support query foo3.s_foo1, use foo3 instead
213+ let records = table_df
214+ . select_columns ( & [ "foo1" , "foo3" ] )
215+ . unwrap ( )
216+ . explain ( false , false )
217+ . unwrap ( )
218+ . collect ( )
219+ . await
220+ . unwrap ( ) ;
221+ assert_eq ! ( 1 , records. len( ) ) ;
222+ let record = & records[ 0 ] ;
223+ let s = record
224+ . column ( 1 )
225+ . as_any ( )
226+ . downcast_ref :: < StringArray > ( )
227+ . unwrap ( ) ;
228+ assert_eq ! ( 2 , s. len( ) ) ;
229+ assert_eq ! ( "IcebergTableScan projection:[foo1,foo3]" , s. value( 1 ) . trim( ) ) ;
230+
231+ Ok ( ( ) )
232+ }
0 commit comments