|
19 | 19 | import importlib
|
20 | 20 | import json
|
21 | 21 | import logging
|
| 22 | +import os |
22 | 23 | import re
|
23 | 24 | import typing
|
24 | 25 | from typing import (
|
@@ -5803,6 +5804,90 @@ def validate_judge_model_sampling_count(cls, value: Optional[int]) -> Optional[i
|
5803 | 5804 | raise ValueError("judge_model_sampling_count must be between 1 and 32.")
|
5804 | 5805 | return value
|
5805 | 5806 |
|
| 5807 | + @classmethod |
| 5808 | + def load(cls, config_path: str, client: Optional[Any] = None) -> "LLMMetric": |
| 5809 | + """Loads a metric configuration from a YAML or JSON file. |
| 5810 | +
|
| 5811 | + This method allows for the creation of an LLMMetric instance from a |
| 5812 | + local file path or a Google Cloud Storage (GCS) URI. It will |
| 5813 | + automatically |
| 5814 | + detect the file type (.yaml, .yml, or .json) and parse it accordingly. |
| 5815 | +
|
| 5816 | + Args: |
| 5817 | + config_path: The local path or GCS URI (e.g., |
| 5818 | + 'gs://bucket/metric.yaml') to the metric configuration file. |
| 5819 | + client: Optional. The Vertex AI client instance to use for |
| 5820 | + authentication. If not provided, Application Default Credentials |
| 5821 | + (ADC) will be used. |
| 5822 | +
|
| 5823 | + Returns: |
| 5824 | + An instance of LLMMetric configured with the loaded data. |
| 5825 | +
|
| 5826 | + Raises: |
| 5827 | + ValueError: If the file path is invalid or the file content cannot |
| 5828 | + be parsed. |
| 5829 | + ImportError: If a required library like 'PyYAML' or |
| 5830 | + 'google-cloud-storage' is not installed. |
| 5831 | + IOError: If the file cannot be read from the specified path. |
| 5832 | + """ |
| 5833 | + file_extension = os.path.splitext(config_path)[1].lower() |
| 5834 | + if file_extension not in [".yaml", ".yml", ".json"]: |
| 5835 | + raise ValueError( |
| 5836 | + "Unsupported file extension for metric config. Must be .yaml," |
| 5837 | + " .yml, or .json" |
| 5838 | + ) |
| 5839 | + |
| 5840 | + content_str: str |
| 5841 | + if config_path.startswith("gs://"): |
| 5842 | + try: |
| 5843 | + from google.cloud import storage |
| 5844 | + |
| 5845 | + storage_client = storage.Client( |
| 5846 | + credentials=client._api_client._credentials if client else None |
| 5847 | + ) |
| 5848 | + path_without_prefix = config_path[len("gs://") :] |
| 5849 | + bucket_name, blob_path = path_without_prefix.split("/", 1) |
| 5850 | + |
| 5851 | + bucket = storage_client.bucket(bucket_name) |
| 5852 | + blob = bucket.blob(blob_path) |
| 5853 | + content_str = blob.download_as_bytes().decode("utf-8") |
| 5854 | + except ImportError as e: |
| 5855 | + raise ImportError( |
| 5856 | + "Reading from GCS requires the 'google-cloud-storage'" |
| 5857 | + " library. Please install it with 'pip install" |
| 5858 | + " google-cloud-aiplatform[evaluation]'." |
| 5859 | + ) from e |
| 5860 | + except Exception as e: |
| 5861 | + raise IOError(f"Failed to read from GCS path {config_path}: {e}") from e |
| 5862 | + else: |
| 5863 | + try: |
| 5864 | + with open(config_path, "r", encoding="utf-8") as f: |
| 5865 | + content_str = f.read() |
| 5866 | + except FileNotFoundError: |
| 5867 | + raise FileNotFoundError( |
| 5868 | + f"Local configuration file not found at: {config_path}" |
| 5869 | + ) |
| 5870 | + except Exception as e: |
| 5871 | + raise IOError(f"Failed to read local file {config_path}: {e}") from e |
| 5872 | + |
| 5873 | + data: Dict[str, Any] |
| 5874 | + |
| 5875 | + if file_extension in [".yaml", ".yml"]: |
| 5876 | + if yaml is None: |
| 5877 | + raise ImportError( |
| 5878 | + "YAML parsing requires the pyyaml library. Please install" |
| 5879 | + " it with 'pip install" |
| 5880 | + " google-cloud-aiplatform[evaluation]'." |
| 5881 | + ) |
| 5882 | + data = yaml.safe_load(content_str) |
| 5883 | + elif file_extension == ".json": |
| 5884 | + data = json.loads(content_str) |
| 5885 | + |
| 5886 | + if not isinstance(data, dict): |
| 5887 | + raise ValueError("Metric config content did not parse into a dictionary.") |
| 5888 | + |
| 5889 | + return cls.model_validate(data) |
| 5890 | + |
5806 | 5891 |
|
5807 | 5892 | class MetricDict(TypedDict, total=False):
|
5808 | 5893 | """The metric used for evaluation."""
|
|
0 commit comments