Skip to content

inference

A utility file for running inference with various LLM providers.

run_batch_inference_anthropic(model_name, messages, **kwargs) async

Run batch inference using an Anthropic model asynchronously.

Source code in docprompt/utils/inference.py
async def run_batch_inference_anthropic(
    model_name: str, messages: List[List[OpenAIMessage]], **kwargs
) -> List[str]:
    """Run batch inference using an Anthropic model asynchronously.

    Args:
        model_name: Name of the Anthropic model to invoke.
        messages: One conversation (a list of OpenAIMessage) per batch entry.
        **kwargs: Extra keyword arguments forwarded to ``run_inference_anthropic``.

    Returns:
        The model responses, in the SAME order as the input ``messages``.
    """
    retry_decorator = get_anthropic_retry_decorator()

    @retry_decorator
    async def process_message_set(msg_set):
        return await run_inference_anthropic(model_name, msg_set, **kwargs)

    async def _indexed(idx, msg_set):
        # Tag each result with its input position so that completion order
        # (which asyncio.as_completed does NOT preserve) can be undone below.
        return idx, await process_message_set(msg_set)

    tasks = [_indexed(i, msg_set) for i, msg_set in enumerate(messages)]

    # BUG FIX: the original appended results in *completion* order, so the
    # returned list did not line up with the input message sets. Preallocate
    # and write each response into its original slot instead.
    responses = [None] * len(messages)
    for f in tqdm(
        asyncio.as_completed(tasks), total=len(tasks), desc="Processing messages"
    ):
        idx, response = await f
        responses[idx] = response

    return responses

run_inference_anthropic(model_name, messages, **kwargs) async

Run inference using an Anthropic model asynchronously.

Source code in docprompt/utils/inference.py
async def run_inference_anthropic(
    model_name: str, messages: List[OpenAIMessage], **kwargs
) -> str:
    """Run inference using an Anthropic model asynchronously.

    Args:
        model_name: Name of the Anthropic model to invoke.
        messages: Conversation in OpenAI message format; a leading "system"
            message is lifted out into Anthropic's separate system argument.
        **kwargs: Forwarded to ``client.messages.create``; an ``api_key``
            entry, if present, is popped and used to build the client.

    Returns:
        The text of the first content block of the model response.
    """
    from anthropic import AsyncAnthropic

    api_key = kwargs.pop("api_key", os.environ.get("ANTHROPIC_API_KEY"))
    client = AsyncAnthropic(api_key=api_key)

    # Anthropic takes the system prompt as a separate argument rather than
    # as the first message, so peel it off the conversation if present.
    system = None
    if messages and messages[0].role == "system":
        system = messages[0].content
        messages = messages[1:]

    converted = []
    for message in messages:
        if not isinstance(message.content, list):
            # Plain (string) content passes through untouched.
            converted.append(message)
            continue

        parts = []
        for part in message.content:
            if isinstance(part, OpenAIComplexContent):
                parts.append(part.to_anthropic_message())
            # Anything else is silently skipped — the original had a
            # ValueError here that was deliberately commented out.

        payload = message.model_dump()
        payload["content"] = parts
        converted.append(payload)

    request = {
        "model": model_name,
        "max_tokens": 2048,
        "messages": converted,
    }
    # kwargs are applied last so callers can override the defaults above.
    request.update(kwargs)

    if system:
        request["system"] = system

    result = await client.messages.create(**request)

    return result.content[0].text