-
Couldn't load subscription status.
- Fork 4
Closed
Description
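This is really a batch implementation: the UDF is called once per record batch, with whole input columns passed as lists: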
# Copyright 2023 RisingWave Labs
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
import inspect
from concurrent.futures import ThreadPoolExecutor
from typing import Iterator, Callable, Optional, Union, List, Dict

import pyarrow as pa
from pyarrow.flight import FlightServerBase, FlightInfo

# comes from Databend
MAX_DECIMAL128_PRECISION = 38
MAX_DECIMAL256_PRECISION = 76
EXTENSION_KEY = b"Extension"
ARROW_EXT_TYPE_VARIANT = b"Variant"

TIMESTAMP_UINT = "us"

logger = logging.getLogger(__name__)


class UserDefinedFunction:
    """
    Base interface for user-defined function.
    """

    # Name under which the function is registered on the server.
    _name: str
    # Arrow schema of the arguments, one field per parameter.
    _input_schema: pa.Schema
    # Arrow schema of the single output column (named "output").
    _result_schema: pa.Schema

    def eval_batch(self, batch: pa.RecordBatch) -> Iterator[pa.RecordBatch]:
        """
        Apply the function on a batch of inputs.
        """
        return iter([])


class ScalarFunction(UserDefinedFunction):
    """
    Base interface for user-defined scalar function.

    A user-defined scalar function maps zero, one, or multiple scalar values
    to a new scalar value.

    This implementation is batch-oriented: the wrapped ``func`` is called once
    per Arrow record batch with one Python list per input column, and must
    return a sequence with exactly one result per input row.
    """

    _func: Callable
    _executor: Optional[ThreadPoolExecutor]

    def __init__(
        self, func, input_types, result_type, name=None, io_threads=None, skip_null=None
    ):
        """
        Args:
            func: the user callable, invoked as ``func(col_1, ..., col_n)`` where
                each ``col_i`` is a list holding one input column of the batch.
            input_types: a type (SQL type string or ``pa.DataType``) or a list of
                them, paired positionally with the parameters of ``func``.
            result_type: SQL type string or ``pa.DataType`` of the output column.
            name: registered function name; defaults to ``func.__name__`` (or the
                class name for callable objects without ``__name__``).
            io_threads: when not None, a ThreadPoolExecutor with this many workers
                is created; ``eval_batch`` does not currently use it.
            skip_null: when true, rows where any argument is NULL are not passed
                to ``func`` and yield NULL directly.

        Raises:
            ValueError: if ``skip_null`` is set but the result type is NOT NULL.
        """
        self._func = func
        # Parameter names come from the function signature; zip pairs them
        # positionally with the declared input types.
        self._input_schema = pa.schema(
            field.with_name(arg_name)
            for arg_name, field in zip(
                inspect.getfullargspec(func)[0],
                [_to_arrow_field(t) for t in _to_list(input_types)],
            )
        )
        self._result_schema = pa.schema(
            [_to_arrow_field(result_type).with_name("output")]
        )
        self._name = name or (
            func.__name__ if hasattr(func, "__name__") else func.__class__.__name__
        )
        self._io_threads = io_threads
        self._executor = (
            ThreadPoolExecutor(max_workers=self._io_threads)
            if self._io_threads is not None
            else None
        )
        self._skip_null = bool(skip_null)

        if skip_null and not self._result_schema.field(0).nullable:
            raise ValueError(
                f"Return type of function {self._name} must be nullable when skip_null is True"
            )

        super().__init__()

    def eval_batch(self, batch: pa.RecordBatch) -> Iterator[pa.RecordBatch]:
        """
        Apply the function to a whole record batch and yield one output batch.

        Raises:
            ValueError: if the function does not return exactly one value per
                row it was given.
        """
        num_rows = batch.num_rows
        # One Python list per input column, converted to the shapes the user
        # function expects (tuple for struct, dict for map, parsed JSON for variant).
        inputs = [
            _input_process_func(_list_field(field))(column.to_pylist())
            for column, field in zip(batch.columns, self._input_schema)
        ]

        if self._skip_null:
            # Only rows whose arguments are all non-NULL reach the user function;
            # every other row gets NULL in the output.
            kept_rows = [
                row
                for row in range(num_rows)
                if all(column[row] is not None for column in inputs)
            ]
            args = [[column[row] for row in kept_rows] for column in inputs]
            results = self._eval_columns(args, len(kept_rows))
            output = [None] * num_rows
            for row, value in zip(kept_rows, results):
                output[row] = value
        else:
            output = self._eval_columns(inputs, num_rows)

        array = pa.array(output, type=self._result_schema.types[0])
        yield pa.RecordBatch.from_arrays([array], schema=self._result_schema)

    def _eval_columns(self, args: List[list], expected_len: int) -> list:
        """Call the user function on column lists and post-process its result column."""
        column = self._func(*args)
        column = _output_process_func(_list_field(self._result_schema.field(0)))(column)
        if column is None:
            raise ValueError(
                f"Function {self._name} returned None instead of a list of results"
            )
        # A batch UDF must produce exactly one value per input row, otherwise the
        # output batch would not line up with the rows the client sent.
        if len(column) != expected_len:
            raise ValueError(
                f"Function {self._name} returned {len(column)} results "
                f"for {expected_len} input rows"
            )
        return column

    def __call__(self, *args):
        return self._func(*args)


def udf(
    input_types: Union[List[Union[str, pa.DataType]], Union[str, pa.DataType]],
    result_type: Union[str, pa.DataType],
    name: Optional[str] = None,
    io_threads: Optional[int] = None,
    skip_null: Optional[bool] = None,
) -> Callable:
    """
    Decorator turning a batch function into a ``ScalarFunction``.

    The decorated function receives one list per input column and must return
    a list with one result per row. ``io_threads`` and ``skip_null`` are passed
    through to ``ScalarFunction``.
    """
    return lambda f: ScalarFunction(
        f, input_types, result_type, name, io_threads=io_threads, skip_null=skip_null
    )


class UDFServer(FlightServerBase):
    """
    A server that provides user-defined functions to clients.

    Example:
    ```
    server = UDFServer(location="0.0.0.0:8815")
    server.add_function(my_udf)
    server.serve()
    ```
    """

    _location: str
    _functions: Dict[str, UserDefinedFunction]

    def __init__(self, location="0.0.0.0:8815", **kwargs):
        super(UDFServer, self).__init__("grpc://" + location, **kwargs)
        self._location = location
        self._functions = {}

    def get_flight_info(self, context, descriptor):
        """Return the result schema of a function."""
        func_name = descriptor.path[0].decode("utf-8")
        if func_name not in self._functions:
            raise ValueError(f"Function {func_name} does not exists")
        udf = self._functions[func_name]
        # return the concatenation of input and output schema
        full_schema = pa.schema(list(udf._input_schema) + list(udf._result_schema))
        # NOTE(review): total_records reports the number of schema fields, not
        # rows; kept as-is since clients may rely on it — confirm intent.
        return FlightInfo(
            schema=full_schema,
            descriptor=descriptor,
            endpoints=[],
            total_records=len(full_schema),
            total_bytes=0,
        )

    def do_exchange(self, context, descriptor, reader, writer):
        """Call a function from the client."""
        func_name = descriptor.path[0].decode("utf-8")
        if func_name not in self._functions:
            raise ValueError(f"Function {func_name} does not exists")
        udf = self._functions[func_name]
        writer.begin(udf._result_schema)
        try:
            for batch in reader:
                for output_batch in udf.eval_batch(batch.data):
                    writer.write_batch(output_batch)
        except Exception as e:
            logger.exception(e)
            # Bare raise keeps the original traceback intact.
            raise

    def add_function(self, udf: UserDefinedFunction):
        """Add a function to the server and log the matching CREATE FUNCTION SQL."""
        name = udf._name
        if name in self._functions:
            raise ValueError("Function already exists: " + name)
        self._functions[name] = udf
        input_types = ", ".join(
            _arrow_field_to_string(field) for field in udf._input_schema
        )
        output_type = _arrow_field_to_string(udf._result_schema[0])
        sql = (
            f"CREATE FUNCTION {name} ({input_types}) "
            f"RETURNS {output_type} LANGUAGE python "
            f"HANDLER = '{name}' ADDRESS = 'http://{self._location}';"
        )
        logger.info(f"added function: {name}, SQL:\n{sql}\n")

    def serve(self):
        """Start the server."""
        logger.info(f"listening on {self._location}")
        super(UDFServer, self).serve()


def _input_process_func(field: pa.Field) -> Callable:
    """
    Return a function to process input value.

    - List=pa.list_(): process each non-null element recursively
    - Tuple=pa.struct(): dict -> tuple
    - Json=pa.large_binary() with Variant metadata: bytes -> Any
    - Map=pa.map_(): list[tuple(k,v)] -> dict
    """
    if pa.types.is_list(field.type):
        func = _input_process_func(field.type.value_field)
        return (
            lambda array: [func(v) if v is not None else None for v in array]
            if array is not None
            else None
        )
    if pa.types.is_struct(field.type):
        funcs = [_input_process_func(f) for f in field.type]
        # the input value of struct type is a dict
        # we convert it into tuple here
        return (
            lambda value: tuple(
                func(v) if v is not None else None
                for v, func in zip(value.values(), funcs)
            )
            if value is not None
            else None
        )
    if pa.types.is_map(field.type):
        funcs = [
            _input_process_func(field.type.key_field),
            _input_process_func(field.type.item_field),
        ]
        # list[tuple[k,v]] -> dict
        return (
            lambda array: dict(
                tuple(func(v) for v, func in zip(item, funcs)) for item in array
            )
            if array is not None
            else None
        )
    if pa.types.is_large_binary(field.type):
        if _field_is_variant(field):
            return lambda v: json.loads(v) if v is not None else None

    return lambda v: v


def _output_process_func(field: pa.Field) -> Callable:
    """
    Return a function to process output value.

    - List=pa.list_(): process each non-null element recursively
    - Tuple=pa.struct(): process each non-null member, keep as tuple
    - Json=pa.large_binary() with Variant metadata: Any -> str
    - Map=pa.map_(): dict -> list[tuple(k,v)]
    """
    if pa.types.is_list(field.type):
        func = _output_process_func(field.type.value_field)
        return (
            lambda array: [func(v) if v is not None else None for v in array]
            if array is not None
            else None
        )
    if pa.types.is_struct(field.type):
        funcs = [_output_process_func(f) for f in field.type]
        return (
            lambda tup: tuple(
                func(v) if v is not None else None for v, func in zip(tup, funcs)
            )
            if tup is not None
            else None
        )
    if pa.types.is_map(field.type):
        funcs = [
            _output_process_func(field.type.key_field),
            _output_process_func(field.type.item_field),
        ]
        # dict -> list[tuple[k,v]]
        return (
            lambda mapping: [
                tuple(func(v) for v, func in zip(item, funcs))
                for item in mapping.items()
            ]
            if mapping is not None
            else None
        )
    if pa.types.is_large_binary(field.type):
        if _field_is_variant(field):
            return lambda v: json.dumps(_ensure_str(v)) if v is not None else None

    return lambda v: v


def _null_func(*args):
    return None


def _list_field(field: pa.Field) -> pa.Field:
    """Wrap ``field`` in an unnamed list field (used to process whole columns)."""
    return pa.field("", pa.list_(field))


def _to_list(x):
    """Return ``x`` unchanged if it is a list, otherwise ``[x]``."""
    if isinstance(x, list):
        return x
    else:
        return [x]


def _ensure_str(x):
    """Recursively decode bytes (also inside lists and dict keys/values) as UTF-8."""
    if isinstance(x, bytes):
        return x.decode("utf-8")
    elif isinstance(x, list):
        return [_ensure_str(v) for v in x]
    elif isinstance(x, dict):
        return {_ensure_str(k): _ensure_str(v) for k, v in x.items()}
    else:
        return x


def _field_is_variant(field: pa.Field) -> bool:
    """Return True if the field carries Databend's Variant extension metadata."""
    if field.metadata is None:
        return False
    if field.metadata.get(EXTENSION_KEY) == ARROW_EXT_TYPE_VARIANT:
        return True
    return False


def _to_arrow_field(t: Union[str, pa.DataType]) -> pa.Field:
    """
    Convert a string or pyarrow.DataType to pyarrow.Field.

    A ``pa.DataType`` becomes a NOT NULL field; strings are parsed as SQL types.
    """
    if isinstance(t, str):
        return _type_str_to_arrow_field(t)
    else:
        return pa.field("", t, False)


def _type_str_to_arrow_field(type_str: str) -> pa.Field:
    """
    Convert a SQL data type to `pyarrow.Field`.

    Top-level types are nullable unless suffixed with ``NOT NULL``.
    """
    type_str = type_str.strip().upper()
    nullable = True
    if type_str.endswith("NULL"):
        type_str = type_str[:-4].strip()
        if type_str.endswith("NOT"):
            type_str = type_str[:-3].strip()
            nullable = False

    return _type_str_to_arrow_field_inner(type_str).with_nullable(nullable)


def _unwrap_type_args(type_str: str, keyword: str) -> str:
    """
    Return the text inside the parentheses that follow ``keyword``.

    Exactly one pair of outer parentheses is removed, so nested types survive,
    e.g. ``"ARRAY(ARRAY(INT))"`` with keyword ``"ARRAY"`` gives ``"ARRAY(INT)"``.
    """
    args = type_str[len(keyword):].strip()
    if args.startswith("(") and args.endswith(")"):
        args = args[1:-1]
    return args.strip()


def _split_type_args(args: str) -> List[str]:
    """Split a type argument list on top-level commas only (nested parens are kept)."""
    parts = []
    depth = 0
    start = 0
    for i, ch in enumerate(args):
        if ch == "(":
            depth += 1
        elif ch == ")":
            depth -= 1
        elif ch == "," and depth == 0:
            parts.append(args[start:i].strip())
            start = i + 1
    parts.append(args[start:].strip())
    return parts


def _type_str_to_arrow_field_inner(type_str: str) -> pa.Field:
    """Parse a SQL type (without a top-level NOT NULL suffix) into a NOT NULL field."""
    type_str = type_str.strip().upper()
    if type_str in ("BOOLEAN", "BOOL"):
        return pa.field("", pa.bool_(), False)
    elif type_str in ("TINYINT", "INT8"):
        return pa.field("", pa.int8(), False)
    elif type_str in ("SMALLINT", "INT16"):
        return pa.field("", pa.int16(), False)
    elif type_str in ("INT", "INTEGER", "INT32"):
        return pa.field("", pa.int32(), False)
    elif type_str in ("BIGINT", "INT64"):
        return pa.field("", pa.int64(), False)
    elif type_str in ("TINYINT UNSIGNED", "UINT8"):
        return pa.field("", pa.uint8(), False)
    elif type_str in ("SMALLINT UNSIGNED", "UINT16"):
        return pa.field("", pa.uint16(), False)
    elif type_str in ("INT UNSIGNED", "INTEGER UNSIGNED", "UINT32"):
        return pa.field("", pa.uint32(), False)
    elif type_str in ("BIGINT UNSIGNED", "UINT64"):
        return pa.field("", pa.uint64(), False)
    elif type_str in ("FLOAT", "FLOAT32"):
        return pa.field("", pa.float32(), False)
    elif type_str in ("FLOAT64", "DOUBLE"):
        return pa.field("", pa.float64(), False)
    elif type_str == "DATE":
        return pa.field("", pa.date32(), False)
    elif type_str in ("DATETIME", "TIMESTAMP"):
        return pa.field("", pa.timestamp(TIMESTAMP_UINT), False)
    elif type_str in ("STRING", "VARCHAR", "CHAR", "CHARACTER", "TEXT"):
        return pa.field("", pa.large_utf8(), False)
    elif type_str == "BINARY":
        # Exact match: `in ("BINARY")` was a substring test on a plain string.
        return pa.field("", pa.large_binary(), False)
    elif type_str in ("VARIANT", "JSON"):
        # In Databend, JSON type is identified by the "EXTENSION" key in the metadata.
        return pa.field(
            "",
            pa.large_binary(),
            nullable=False,
            metadata={EXTENSION_KEY: ARROW_EXT_TYPE_VARIANT},
        )
    elif type_str.startswith("NULLABLE"):
        inner = _unwrap_type_args(type_str, "NULLABLE")
        return _type_str_to_arrow_field_inner(inner).with_nullable(True)
    elif type_str.endswith("NULL"):
        inner = type_str[:-4].strip()
        return _type_str_to_arrow_field_inner(inner).with_nullable(True)
    elif type_str.startswith("DECIMAL"):
        # DECIMAL(precision, scale)
        args = _split_type_args(_unwrap_type_args(type_str, "DECIMAL"))
        if len(args) != 2:
            raise ValueError(f"DECIMAL type requires precision and scale: {type_str}")
        precision = int(args[0])
        scale = int(args[1])
        if precision < 1 or precision > MAX_DECIMAL256_PRECISION:
            raise ValueError(
                f"Decimal precision must be between 1 and {MAX_DECIMAL256_PRECISION}"
            )
        elif scale < 0 or scale > precision:
            raise ValueError(
                f"Decimal scale must be between 0 and precision {precision}"
            )
        # decimal128 holds up to 38 digits inclusive; larger precisions need decimal256.
        if precision <= MAX_DECIMAL128_PRECISION:
            return pa.field("", pa.decimal128(precision, scale), False)
        else:
            return pa.field("", pa.decimal256(precision, scale), False)
    elif type_str.startswith("ARRAY"):
        # ARRAY(INT)
        inner = _unwrap_type_args(type_str, "ARRAY")
        return pa.field("", pa.list_(_type_str_to_arrow_field_inner(inner)), False)
    elif type_str.startswith("MAP"):
        # MAP(STRING, INT); arguments may themselves be parameterized types.
        args = _split_type_args(_unwrap_type_args(type_str, "MAP"))
        if len(args) != 2:
            raise ValueError(f"MAP type requires exactly two arguments: {type_str}")
        key_field = _type_str_to_arrow_field_inner(args[0])
        val_field = _type_str_to_arrow_field_inner(args[1])
        return pa.field("", pa.map_(key_field, val_field), False)
    elif type_str.startswith("TUPLE"):
        # TUPLE(STRING, INT, INT); members may themselves be parameterized types.
        members = _split_type_args(_unwrap_type_args(type_str, "TUPLE"))
        fields = [_type_str_to_arrow_field_inner(member) for member in members]
        return pa.field("", pa.struct(fields), False)
    else:
        raise ValueError(f"Unsupported type: {type_str}")


def _arrow_field_to_string(field: pa.Field) -> str:
    """
    Convert a `pyarrow.Field` to a SQL data type string.

    Top-level fields are nullable by default, so only NOT NULL is spelled out.
    """
    type_str = _field_type_to_string(field)
    return f"{type_str} NOT NULL" if not field.nullable else type_str


def _inner_field_to_string(field: pa.Field) -> str:
    # inner field default is NOT NULL in databend
    type_str = _field_type_to_string(field)
    return f"{type_str} NULL" if field.nullable else type_str


def _field_type_to_string(field: pa.Field) -> str:
    """
    Convert a `pyarrow.DataType` to a SQL data type string.
    """
    t = field.type
    if pa.types.is_boolean(t):
        return "BOOLEAN"
    elif pa.types.is_int8(t):
        return "TINYINT"
    elif pa.types.is_int16(t):
        return "SMALLINT"
    elif pa.types.is_int32(t):
        return "INT"
    elif pa.types.is_int64(t):
        return "BIGINT"
    elif pa.types.is_uint8(t):
        return "TINYINT UNSIGNED"
    elif pa.types.is_uint16(t):
        return "SMALLINT UNSIGNED"
    elif pa.types.is_uint32(t):
        return "INT UNSIGNED"
    elif pa.types.is_uint64(t):
        return "BIGINT UNSIGNED"
    elif pa.types.is_float32(t):
        return "FLOAT"
    elif pa.types.is_float64(t):
        return "DOUBLE"
    elif pa.types.is_decimal(t):
        return f"DECIMAL({t.precision}, {t.scale})"
    elif pa.types.is_date32(t):
        return "DATE"
    elif pa.types.is_timestamp(t):
        return "TIMESTAMP"
    elif pa.types.is_large_unicode(t) or pa.types.is_unicode(t):
        return "VARCHAR"
    elif pa.types.is_large_binary(t) or pa.types.is_binary(t):
        if _field_is_variant(field):
            return "VARIANT"
        else:
            return "BINARY"
    elif pa.types.is_list(t):
        return f"ARRAY({_inner_field_to_string(t.value_field)})"
    elif pa.types.is_map(t):
        return f"MAP({_inner_field_to_string(t.key_field)}, {_inner_field_to_string(t.item_field)})"
    elif pa.types.is_struct(t):
        args_str = ", ".join(_inner_field_to_string(child) for child in t)
        return f"TUPLE({args_str})"
    else:
        raise ValueError(f"Unsupported type: {t}")


# Like we define a function as:
CREATE OR REPLACE FUNCTION ping (STRING) RETURNS STRING LANGUAGE python HANDLER = 'ping' ADDRESS = 'https://your-address';
The UDF server code is:
import logging

from databend_udf import udf

logger = logging.getLogger(__name__)


@udf(
    input_types=["STRING"],
    result_type="STRING",
)
def ping(inputs: list[str]) -> list[str]:
    """
    Batch UDF that answers "pong" for every input row.

    ``inputs`` holds the whole STRING column of one record batch, and one
    result is returned per element. If building the results raises, every
    row gets "error" instead.
    """
    logger.info(f"ping function called with {len(inputs)} inputs")
    try:
        results = []
        for value in inputs:
            logger.info(f"Processing input: {value}")
            reply = "pong"
            logger.info(f"Returning result: {reply}")
            results.append(reply)
    except Exception as exc:
        logger.error(f"Error in ping function: {exc}")
        # Fall back to one "error" per input row so the output length still matches.
        results = ["error"] * len(inputs)
    return results


# Log once at import time to confirm the UDF was defined.
logger.info("Defined UDF: ping")

# Query:
select ping(data) from texts;
Metadata
Metadata
Assignees
Labels
No labels