1515# limitations under the License.
1616#
1717
18- from typing import Optional , Sequence , Tuple , Union
18+ import csv
19+ import logging
20+
21+ from typing import List , Optional , Sequence , Tuple , Union
1922
2023from google .auth import credentials as auth_credentials
2124
25+ from google .cloud import bigquery
26+ from google .cloud import storage
27+
2228from google .cloud .aiplatform import datasets
2329from google .cloud .aiplatform .datasets import _datasources
2430from google .cloud .aiplatform import initializer
@@ -33,6 +39,157 @@ class TabularDataset(datasets._Dataset):
3339 schema .dataset .metadata .tabular ,
3440 )
3541
@property
def column_names(self) -> List[str]:
    """Retrieve the columns for the dataset by extracting it from the Google Cloud Storage or
    Google BigQuery source.

    Returns:
        List[str]
            A list of columns names

    Raises:
        RuntimeError: When no valid source is found.
    """
    metadata = self._gca_resource.metadata
    if metadata is None:
        raise RuntimeError("No metadata found for dataset")

    input_config = metadata.get("inputConfig")
    if input_config is None:
        raise RuntimeError("No inputConfig found for dataset")

    gcs_source = input_config.get("gcsSource")
    bq_source = input_config.get("bigquerySource")

    if gcs_source:
        gcs_source_uris = gcs_source.get("uri")
        if gcs_source_uris:
            # Sort lexicographically and read the header from the first file.
            gcs_source_uris.sort()
            return TabularDataset._retrieve_gcs_source_columns(
                self.project, gcs_source_uris[0]
            )
    elif bq_source:
        bq_table_uri = bq_source.get("uri")
        if bq_table_uri:
            return TabularDataset._retrieve_bq_source_columns(
                self.project, bq_table_uri
            )

    raise RuntimeError("No valid CSV or BigQuery datasource found.")
@staticmethod
def _retrieve_gcs_source_columns(project: str, gcs_csv_file_path: str) -> List[str]:
    """Retrieve the columns from a comma-delimited CSV file stored on Google Cloud Storage

    Example Usage:

        column_names = _retrieve_gcs_source_columns(
            "project_id",
            "gs://example-bucket/path/to/csv_file"
        )

        # column_names = ["column_1", "column_2"]

    Args:
        project (str):
            Required. Project to initiate the Google Cloud Storage client with.
        gcs_csv_file_path (str):
            Required. A full path to a CSV files stored on Google Cloud Storage.
            Must include "gs://" prefix.

    Returns:
        List[str]
            A list of columns names in the CSV file.

    Raises:
        RuntimeError: When the retrieved CSV file is invalid.
    """

    gcs_bucket, gcs_blob = utils.extract_bucket_and_prefix_from_gcs_path(
        gcs_csv_file_path
    )
    client = storage.Client(project=project)
    bucket = client.bucket(gcs_bucket)
    blob = bucket.blob(gcs_blob)

    # Incrementally download the CSV file until the header (first line) is
    # retrieved. Bytes are accumulated and decoded only after the newline is
    # found, so a multi-byte UTF-8 character split across a chunk boundary
    # cannot raise a decode error mid-stream.
    first_new_line_index = -1
    start_index = 0
    increment = 1000
    line_bytes = b""

    try:
        # Suppress noisy INFO-level logs emitted by ranged downloads.
        logger = logging.getLogger("google.resumable_media._helpers")
        logging_warning_filter = utils.LoggingFilter(logging.INFO)
        logger.addFilter(logging_warning_filter)

        while first_new_line_index == -1:
            # `end` is the LAST byte of the range (inclusive) for
            # download_as_bytes, so request increment bytes as
            # [start_index, start_index + increment - 1]; otherwise the
            # boundary byte is downloaded twice and duplicated in the header.
            line_bytes += blob.download_as_bytes(
                start=start_index, end=start_index + increment - 1
            )
            first_new_line_index = line_bytes.find(b"\n")
            start_index += increment

        header_line = line_bytes[:first_new_line_index].decode("utf-8")

        # csv.reader expects an iterable of lines; wrap the single header line.
        csv_reader = csv.reader([header_line], delimiter=",")
    except (ValueError, RuntimeError) as err:
        raise RuntimeError(
            "There was a problem extracting the headers from the CSV file at '{}': {}".format(
                gcs_csv_file_path, err
            )
        ) from err
    finally:
        logger.removeFilter(logging_warning_filter)

    return next(csv_reader)
@staticmethod
def _retrieve_bq_source_columns(project: str, bq_table_uri: str) -> List[str]:
    """Retrieve the columns from a table on Google BigQuery

    Example Usage:

        column_names = _retrieve_bq_source_columns(
            "project_id",
            "bq://project_id.dataset.table"
        )

        # column_names = ["column_1", "column_2"]

    Args:
        project (str):
            Required. Project to initiate the BigQuery client with.
        bq_table_uri (str):
            Required. A URI to a BigQuery table.
            Can include "bq://" prefix but not required.

    Returns:
        List[str]
            A list of columns names in the BigQuery table.
    """

    # Strip the optional "bq://" prefix; the BigQuery client expects a plain
    # "project.dataset.table" identifier.
    prefix = "bq://"
    if bq_table_uri.startswith(prefix):
        bq_table_uri = bq_table_uri[len(prefix):]

    client = bigquery.Client(project=project)
    table = client.get_table(bq_table_uri)
    # Iterate schema fields under a distinct name — the original comprehension
    # shadowed `schema` with its own loop variable.
    return [field.name for field in table.schema]
36193 @classmethod
37194 def create (
38195 cls ,
0 commit comments