@@ -110,3 +110,62 @@ def test_list_rows_nullable_scalars_dtypes(
110110 timestamp_type = schema .field ("timestamp_col" ).type
111111 assert timestamp_type .unit == "us"
112112 assert timestamp_type .tz is not None
113+
114+
@pytest.mark.parametrize("do_insert", [True, False])
def test_arrow_extension_types_same_for_storage_and_REST_APIs_894(
    dataset_client, test_table_name, do_insert
):
    """Check that Arrow field metadata is identical for both download paths.

    A table with one column per listed SQL type is created and, when
    ``do_insert`` is true, populated with a single row.  The table is then
    read into Arrow twice: once with the default ``to_arrow()`` arguments and
    once with ``create_bqstorage_client=False``.  The per-field metadata of
    the two results must match, and the DATETIME and GEOGRAPHY columns must
    carry the expected ``ARROW:extension:*`` entries.
    """

    def field_metadata(arrow_table):
        # Map each column name to the metadata attached to its Arrow field.
        fields = (
            arrow_table.field(index) for index in range(arrow_table.num_columns)
        )
        return {field.name: field.metadata for field in fields}

    # column name -> (SQL type used in the DDL, SQL literal used in the insert)
    column_specs = {
        "astring": ("STRING", "'x'"),
        "astring9": ("STRING(9)", "'x'"),
        "abytes": ("BYTES", "b'x'"),
        "abytes9": ("BYTES(9)", "b'x'"),
        "anumeric": ("NUMERIC", "42"),
        "anumeric9": ("NUMERIC(9)", "42"),
        "anumeric92": ("NUMERIC(9,2)", "42"),
        "abignumeric": ("BIGNUMERIC", "42e30"),
        "abignumeric49": ("BIGNUMERIC(37)", "42e30"),
        "abignumeric492": ("BIGNUMERIC(37,2)", "42e30"),
        "abool": ("BOOL", "true"),
        "adate": ("DATE", "'2021-09-06'"),
        "adatetime": ("DATETIME", "'2021-09-06T09:57:26'"),
        "ageography": ("GEOGRAPHY", "ST_GEOGFROMTEXT('point(0 0)')"),
        # INTERVAL is left out because arrow data could not be obtained for
        # it; the entry would have been:
        #   ainterval=('INTERVAL', "make_interval(1, 2, 3, 4, 5, 6)")
        "aint64": ("INT64", "42"),
        "afloat64": ("FLOAT64", "42.0"),
        "astruct": ("STRUCT<v int64>", "struct(42)"),
        "atime": ("TIME", "'1:2:3'"),
        "atimestamp": ("TIMESTAMP", "'2021-09-06T09:57:26'"),
    }

    column_defs = ", ".join(
        f"{name} {sql_type} " for name, (sql_type, _) in column_specs.items()
    )
    create_job = dataset_client.query(
        f"create table {test_table_name} ({column_defs} )"
    )
    create_job.result()

    if do_insert:
        value_list = ", ".join(literal for _, literal in column_specs.values())
        column_list = ", ".join(column_specs)
        insert_job = dataset_client.query(
            f"insert into {test_table_name} ({column_list} ) values ({value_list} )"
        )
        insert_job.result()

    select_all = f"select * from {test_table_name} "

    # First read: default to_arrow() arguments (the "storage API" side of the
    # comparison, going by the test's naming).
    default_table = dataset_client.query(select_all).result().to_arrow()
    storage_api_metadata = field_metadata(default_table)

    # Second read: explicitly opt out of creating a BQ Storage client.
    rest_rows = dataset_client.query(select_all).result()
    rest_table = rest_rows.to_arrow(create_bqstorage_client=False)
    rest_api_metadata = field_metadata(rest_table)

    assert rest_api_metadata == storage_api_metadata

    expected_datetime = {b"ARROW:extension:name": b"google:sqlType:datetime"}
    assert rest_api_metadata["adatetime"] == expected_datetime

    expected_geography = {
        b"ARROW:extension:name": b"google:sqlType:geography",
        b"ARROW:extension:metadata": b'{"encoding": "WKT"}',
    }
    assert rest_api_metadata["ageography"] == expected_geography
0 commit comments