24 changes: 14 additions & 10 deletions os_computer_use/config.py
@@ -5,15 +5,19 @@
grounding_model = providers.OSAtlasProvider()
# grounding_model = providers.ShowUIProvider()

# vision_model = providers.FireworksProvider("llama3.2")
# vision_model = providers.OpenAIProvider("gpt-4o")
# vision_model = providers.AnthropicProvider("claude-3.5-sonnet")
vision_model = providers.GroqProvider("llama3.2")
# vision_model = providers.MistralProvider("pixtral") # pixtral-large-latest has vision capabilities

# Vision models using LiteLLM:
vision_model = providers.LiteLLMProvider("pixtral") # Mistral
# vision_model = providers.LiteLLMProvider("llama3.2", provider="fireworks") # Fireworks
# vision_model = providers.LiteLLMProvider("gpt-4-vision") # OpenAI
# vision_model = providers.LiteLLMProvider("llama3.2", provider="groq") # Groq
# vision_model = providers.LiteLLMProvider("claude-3-5-sonnet") # Anthropic
# vision_model = providers.LiteLLMProvider("gemini-2.0-flash", provider="gemini") # Gemini

# action_model = providers.FireworksProvider("llama3.3")
# action_model = providers.OpenAIProvider("gpt-4o")
# action_model = providers.AnthropicProvider("claude-3.5-sonnet")
action_model = providers.GroqProvider("llama3.3")
# action_model = providers.MistralProvider("large") # mistral-large-latest for non-vision tasks
# Action models using LiteLLM:
action_model = providers.LiteLLMProvider("large") # Mistral
# action_model = providers.LiteLLMProvider("llama3.3", provider="fireworks") # Fireworks
# action_model = providers.LiteLLMProvider("llama3.3", provider="groq") # Groq
# action_model = providers.LiteLLMProvider("gpt-4") # OpenAI
# action_model = providers.LiteLLMProvider("claude-3-5-sonnet") # Anthropic
# action_model = providers.LiteLLMProvider("gemini-2.0-flash", provider="gemini") # Gemini
147 changes: 80 additions & 67 deletions os_computer_use/llm_provider.py
@@ -4,6 +4,7 @@
import json
import re
import base64
import imghdr


def Message(content, role="assistant"):
@@ -22,6 +23,29 @@ def parse_json(s):
return None


def extract_json_objects(s):
"""Extract all balanced JSON objects from a string."""
objects = []
brace_level = 0
start_index = None
for i, char in enumerate(s):
if char == "{":
if brace_level == 0:
start_index = i
brace_level += 1
elif char == "}":
brace_level -= 1
if brace_level == 0 and start_index is not None:
candidate = s[start_index : i + 1]
try:
obj = json.loads(candidate)
objects.append(obj)
except json.JSONDecodeError:
pass
start_index = None
return objects
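
For example, given a reply that mixes prose with JSON, the scanner recovers each balanced object (illustrative input; note that braces inside string values are not special-cased):

    # Illustrative usage of extract_json_objects on a mixed prose/JSON reply.
    reply = 'Sure. {"name": "click", "parameters": {"x": "120", "y": "48"}} Done.'
    print(extract_json_objects(reply))
    # -> [{'name': 'click', 'parameters': {'x': '120', 'y': '48'}}]
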


class LLMProvider:
"""
The LLM provider is used to make calls to an LLM given a provider and model name, with optional tool use support
@@ -52,6 +76,13 @@ def create_function_schema(self, definitions):
properties[param_name] = {"type": "string", "description": param_desc}
required.append(param_name)

# Add a dummy property if no parameters are provided, because providers like Gemini require a non-empty "properties" object.
if not properties:
properties["noop"] = {
"type": "string",
"description": "Dummy parameter for function with no parameters.",
}

function_def = self.create_function_def(name, details, properties, required)
functions.append(function_def)
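
For a tool declared without parameters, the resulting entry would look roughly like this — a sketch only; the wrapper keys follow the OpenAI-style create_function_def below, and the tool name and description are hypothetical:

    # Roughly what create_function_schema yields for a zero-parameter tool
    # (hypothetical "screenshot" tool; wrapper keys per create_function_def).
    {
        "type": "function",
        "function": {
            "name": "screenshot",
            "description": "Take a screenshot of the current screen.",
            "parameters": {
                "type": "object",
                "properties": {
                    "noop": {
                        "type": "string",
                        "description": "Dummy parameter for function with no parameters.",
                    }
                },
                "required": [],
            },
        },
    }
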

Expand All @@ -68,8 +99,7 @@ def create_tool_call(self, name, parameters):
# Wrap a content block in a text or an image object
def wrap_block(self, block):
if isinstance(block, bytes):
encoded_image = base64.b64encode(block).decode("utf-8")
return self.create_image_block(encoded_image)
return self.create_image_block(block)
else:
return Text(block)

@@ -117,10 +147,17 @@ def create_function_def(self, name, details, properties, required):
},
}

def create_image_block(self, base64_image):
def create_image_block(self, image_data):
# Detect the image type using imghdr.
image_type = imghdr.what(None, image_data)
if image_type is None:
image_type = "png" # fallback if type cannot be detected

# Base64-encode the raw image bytes.
encoded = base64.b64encode(image_data).decode("utf-8")
return {
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{base64_image}"},
"image_url": {"url": f"data:image/{image_type};base64,{encoded}"},
}
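
A quick sanity check of the detection path — a sketch using a stub payload built from the PNG signature. (Note that imghdr is deprecated and removed in Python 3.13, so newer interpreters would need a manual signature check or a third-party library.)

    # Sketch: imghdr recognizes the PNG signature in raw bytes (stub payload).
    import base64
    import imghdr

    png_stub = b"\x89PNG\r\n\x1a\n" + b"\x00" * 16
    assert imghdr.what(None, png_stub) == "png"
    data_url = f"data:image/png;base64,{base64.b64encode(png_stub).decode('utf-8')}"
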

def call(self, messages, functions=None):
@@ -140,18 +177,17 @@
if parse_json(tool_call.function.arguments) is not None
]

# Sometimes, function calls are returned unparsed by the inference provider. This code parses them manually.
# Sometimes, function calls are returned unparsed by the inference provider.
if message.content and not tool_calls:
tool_call_matches = re.search(r"\{.*\}", message.content)
if tool_call_matches:
tool_call = parse_json(tool_call_matches.group(0))
# Some models use "arguments" as the key instead of "parameters"
parameters = tool_call.get("parameters", tool_call.get("arguments"))
if tool_call.get("name") and parameters:
json_objs = extract_json_objects(message.content)
for obj in json_objs:
parameters = obj.get("parameters", obj.get("arguments"))
if obj.get("name") and parameters is not None:
combined_tool_calls.append(
self.create_tool_call(tool_call.get("name"), parameters)
self.create_tool_call(obj.get("name"), parameters)
)
return None, combined_tool_calls
if combined_tool_calls:
return None, combined_tool_calls

return message.content, combined_tool_calls
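
The key fallback accepts either spelling that models tend to emit (illustrative):

    # Illustrative: both "parameters" and "arguments" keys are recovered.
    for obj in (
        {"name": "click", "parameters": {"x": "10"}},
        {"name": "click", "arguments": {"x": "10"}},
    ):
        assert obj.get("parameters", obj.get("arguments")) == {"x": "10"}
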

@@ -160,75 +196,52 @@ def call(self, messages, functions=None):
return message.content


class AnthropicBaseProvider(LLMProvider):
class LiteLLMBaseProvider(OpenAIBaseProvider):
"""Base provider using LiteLLM"""

def create_client(self):
return Anthropic(api_key=self.api_key).messages
from litellm import completion

def create_function_def(self, name, details, properties, required):
return {
"name": name,
"description": details["description"],
"input_schema": {
"type": "object",
"properties": properties,
"required": required,
},
}
import litellm

def create_image_block(self, base64_image):
return {
"type": "image",
"source": {
"type": "base64",
"media_type": "image/png",
"data": base64_image,
},
}

def call(self, messages, functions=None):
tools = self.create_function_schema(functions) if functions else None

# Move all messages with the system role to a system parameter
system = "\n".join(
msg.get("content") for msg in messages if msg.get("role") == "system"
)
messages = [msg for msg in messages if msg.get("role") != "system"]

# Call the Anthropic API
completion = self.completion(
messages, system=system, tools=tools, max_tokens=4096
)
text = "".join(getattr(block, "text", "") for block in completion.content)
# Enable dropping unsupported params globally
litellm.drop_params = True
litellm.modify_params = True
# Enable debug mode for better error messages
# litellm._turn_on_debug()
return completion

# Return response text and tool calls separately
if functions:
tool_calls = [
self.create_tool_call(block.name, block.input)
for block in completion.content
if block.type == "tool_use"
]
return text, tool_calls
def completion(self, messages, **kwargs):
# Drop any keyword arguments that are None (e.g. tools when no functions are given)
filtered_kwargs = {k: v for k, v in kwargs.items() if v is not None}

# Only return response text
else:
return text
# No need to remove tools; pass tools so that function calling works with Claude.

# Wrap content blocks in image or text objects if necessary
new_messages = [self.transform_message(message) for message in messages]

class MistralBaseProvider(OpenAIBaseProvider):
def create_function_def(self, name, details, properties, required):
# If description is wrapped in a dict, extract the inner string
if isinstance(details.get("description"), dict):
details["description"] = details["description"].get("description", "")
return super().create_function_def(name, details, properties, required)
# Call LiteLLM completion
completion_response = self.client(
model=self.model,
messages=new_messages,
api_key=self.api_key,
**filtered_kwargs,
)
return completion_response
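
For reference, the wrapper above boils down to a direct litellm.completion call — a sketch; the model string and key handling are placeholders:

    # Sketch of the equivalent direct call (placeholder model name and key).
    from litellm import completion

    response = completion(
        model="mistral/mistral-large-latest",
        messages=[{"role": "user", "content": "Hello"}],
        api_key="...",  # placeholder; normally taken from the environment
    )
    print(response.choices[0].message.content)
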

# Added method to adjust the final message role for Mistral-based models only
def call(self, messages, functions=None):
if messages and messages[-1].get("role") == "assistant":
if (
"mistral" in self.model.lower()
and messages
and messages[-1].get("role") == "assistant"
):
prefix = messages.pop()["content"]
if messages and messages[-1].get("role") == "user":
messages[-1]["content"] = (
prefix + "\n" + messages[-1].get("content", "")
)
else:
messages.append({"role": "user", "content": prefix})

return super().call(messages, functions)
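
The adjustment folds a trailing assistant message into the preceding user turn, with the assistant text placed first (illustrative):

    # Illustrative before/after for the Mistral-only role adjustment.
    messages = [
        {"role": "user", "content": "Open the browser."},
        {"role": "assistant", "content": "I will click the browser icon."},
    ]
    # After the adjustment, the provider sends a single user message:
    # [{"role": "user",
    #   "content": "I will click the browser icon.\nOpen the browser."}]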