A utility file for running inference with various LLM providers.
run_batch_inference_anthropic(model_name, messages, **kwargs) async
Run batch inference using an Anthropic model asynchronously.
Source code in docprompt/utils/inference.py
async def run_batch_inference_anthropic(
    model_name: str, messages: List[List[OpenAIMessage]], **kwargs
) -> List[str]:
    """Run batch inference using an Anthropic model asynchronously.

    Args:
        model_name: Anthropic model identifier to run each request against.
        messages: One list of messages per conversation to process.
        **kwargs: Extra keyword arguments forwarded to
            ``run_inference_anthropic`` (e.g. ``api_key``, ``max_tokens``).

    Returns:
        The model responses in the SAME order as the input ``messages``.
    """
    retry_decorator = get_anthropic_retry_decorator()

    @retry_decorator
    async def process_message_set(msg_set):
        return await run_inference_anthropic(model_name, msg_set, **kwargs)

    async def _indexed(idx, msg_set):
        # Carry the original position so the result can be re-slotted in
        # input order even though completion order is arbitrary.
        return idx, await process_message_set(msg_set)

    tasks = [_indexed(idx, msg_set) for idx, msg_set in enumerate(messages)]

    # BUG FIX: the original appended results in *completion* order, so the
    # returned list did not line up positionally with the input `messages`.
    # We keep as_completed (for live progress) but write each response back
    # at its original index.
    responses: List[str] = [""] * len(messages)
    for future in tqdm(
        asyncio.as_completed(tasks), total=len(tasks), desc="Processing messages"
    ):
        idx, response = await future
        responses[idx] = response

    return responses
run_inference_anthropic(model_name, messages, **kwargs) async
Run inference using an Anthropic model asynchronously.
Source code in docprompt/utils/inference.py
async def run_inference_anthropic(
    model_name: str, messages: List[OpenAIMessage], **kwargs
) -> str:
    """Run inference using an Anthropic model asynchronously.

    Args:
        model_name: Anthropic model identifier.
        messages: Conversation messages; a leading ``system`` message is
            extracted and passed as Anthropic's top-level ``system`` param.
        **kwargs: ``api_key`` (falls back to ``ANTHROPIC_API_KEY`` env var)
            plus any extra arguments forwarded to ``messages.create``
            (later keys override the defaults, e.g. ``max_tokens``).

    Returns:
        The text of the first content block of the model response.
    """
    from anthropic import AsyncAnthropic

    # An explicit api_key kwarg takes precedence over the environment.
    api_key = kwargs.pop("api_key", os.environ.get("ANTHROPIC_API_KEY"))
    client = AsyncAnthropic(api_key=api_key)

    # Anthropic takes the system prompt as a top-level parameter, not as a
    # message, so peel off a leading system message if present.
    system = None
    if messages and messages[0].role == "system":
        system = messages[0].content
        messages = messages[1:]

    processed_messages = []
    for msg in messages:
        if isinstance(msg.content, list):
            processed_content = []
            for content in msg.content:
                if isinstance(content, OpenAIComplexContent):
                    processed_content.append(content.to_anthropic_message())
                else:
                    # Deliberate best-effort: unrecognized content parts are
                    # skipped rather than raising, preserving prior behavior.
                    # NOTE(review): consider raising ValueError for
                    # unexpected content types instead of dropping them.
                    pass

            dumped = msg.model_dump()
            dumped["content"] = processed_content
            processed_messages.append(dumped)
        else:
            # CONSISTENCY FIX: dump plain messages too, so the payload is a
            # uniform list of dicts rather than a mix of pydantic models
            # and dicts.
            processed_messages.append(msg.model_dump())

    client_kwargs = {
        "model": model_name,
        "max_tokens": 2048,  # default; **kwargs below may override it
        "messages": processed_messages,
        **kwargs,
    }

    if system:
        client_kwargs["system"] = system

    response = await client.messages.create(**client_kwargs)

    # Assumes the first content block is text — TODO confirm for tool-use
    # or other non-text response types.
    return response.content[0].text