Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions src/agent/src/Toolbox/AgentProcessor.php
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,8 @@ private function handleToolCallsCallback(Output $output): \Closure

$results = [];
foreach ($toolCalls as $toolCall) {
$result = $this->toolbox->execute($toolCall);
$results[] = new ToolResult($toolCall, $result);
$messages->add(Message::ofToolCall($toolCall, $this->resultConverter->convert($result)));
$results[] = $toolResult = $this->toolbox->execute($toolCall);
$messages->add(Message::ofToolCall($toolCall, $this->resultConverter->convert($toolResult)));
}

$event = new ToolCallsExecuted(...$results);
Expand Down
9 changes: 6 additions & 3 deletions src/agent/src/Toolbox/FaultTolerantToolbox.php
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,19 @@ public function getTools(): array
return $this->innerToolbox->getTools();
}

public function execute(ToolCall $toolCall): mixed
public function execute(ToolCall $toolCall): ToolResult
{
try {
return $this->innerToolbox->execute($toolCall);
} catch (ToolExecutionExceptionInterface $e) {
return $e->getToolCallResult();
return new ToolResult($toolCall, $e->getToolCallResult());
} catch (ToolNotFoundException) {
$names = array_map(fn (Tool $metadata) => $metadata->getName(), $this->getTools());

return \sprintf('Tool "%s" was not found, please use one of these: %s', $toolCall->getName(), implode(', ', $names));
return new ToolResult(
$toolCall,
\sprintf('Tool "%s" was not found, please use one of these: %s', $toolCall->getName(), implode(', ', $names))
);
}
}
}
4 changes: 2 additions & 2 deletions src/agent/src/Toolbox/Toolbox.php
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ public function getTools(): array
return $this->map = $map;
}

public function execute(ToolCall $toolCall): mixed
public function execute(ToolCall $toolCall): ToolResult
{
$metadata = $this->getMetadata($toolCall);
$tool = $this->getExecutable($metadata);
Expand All @@ -93,7 +93,7 @@ public function execute(ToolCall $toolCall): mixed
throw ToolExecutionException::executionFailed($toolCall, $e);
}

return $result;
return new ToolResult($toolCall, $result);
}

private function getMetadata(ToolCall $toolCall): Tool
Expand Down
2 changes: 1 addition & 1 deletion src/agent/src/Toolbox/ToolboxInterface.php
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,5 @@ public function getTools(): array;
* @throws ToolExecutionExceptionInterface if the tool execution fails
* @throws ToolNotFoundException if the tool is not found
*/
public function execute(ToolCall $toolCall): mixed;
public function execute(ToolCall $toolCall): ToolResult;
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
use Symfony\AI\Agent\Input;
use Symfony\AI\Agent\InputProcessor\SystemPromptInputProcessor;
use Symfony\AI\Agent\Toolbox\ToolboxInterface;
use Symfony\AI\Agent\Toolbox\ToolResult;
use Symfony\AI\Fixtures\Tool\ToolNoParams;
use Symfony\AI\Fixtures\Tool\ToolRequiredParams;
use Symfony\AI\Platform\Message\Content\File;
Expand Down Expand Up @@ -72,9 +73,9 @@ public function getTools(): array
return [];
}

public function execute(ToolCall $toolCall): mixed
public function execute(ToolCall $toolCall): ToolResult
{
return null;
return new ToolResult($toolCall, null);
}
},
);
Expand Down Expand Up @@ -110,9 +111,9 @@ public function getTools(): array
];
}

public function execute(ToolCall $toolCall): mixed
public function execute(ToolCall $toolCall): ToolResult
{
return null;
return new ToolResult($toolCall, null);
}
},
$this->getTranslator(),
Expand Down Expand Up @@ -153,9 +154,9 @@ public function getTools(): array
];
}

public function execute(ToolCall $toolCall): mixed
public function execute(ToolCall $toolCall): ToolResult
{
return null;
return new ToolResult($toolCall, null);
}
},
);
Expand Down
19 changes: 13 additions & 6 deletions src/agent/tests/Toolbox/AgentProcessorTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
use Symfony\AI\Agent\Output;
use Symfony\AI\Agent\Toolbox\AgentProcessor;
use Symfony\AI\Agent\Toolbox\ToolboxInterface;
use Symfony\AI\Agent\Toolbox\ToolResult;
use Symfony\AI\Platform\Message\AssistantMessage;
use Symfony\AI\Platform\Message\MessageBag;
use Symfony\AI\Platform\Message\ToolCallMessage;
Expand Down Expand Up @@ -72,12 +73,15 @@ public function testProcessInputWithRegisteredToolsButToolOverride()

public function testProcessOutputWithToolCallResponseKeepingMessages()
{
$toolCall = new ToolCall('id1', 'tool1', ['arg1' => 'value1']);
$toolbox = $this->createMock(ToolboxInterface::class);
$toolbox->expects($this->once())->method('execute')->willReturn('Test response');
$toolbox
->expects($this->once())
->method('execute')
->willReturn(new ToolResult($toolCall, 'Test response'));

$messageBag = new MessageBag();

$result = new ToolCallResult(new ToolCall('id1', 'tool1', ['arg1' => 'value1']));
$result = new ToolCallResult($toolCall);

$agent = $this->createStub(AgentInterface::class);

Expand All @@ -95,12 +99,15 @@ public function testProcessOutputWithToolCallResponseKeepingMessages()

public function testProcessOutputWithToolCallResponseForgettingMessages()
{
$toolCall = new ToolCall('id1', 'tool1', ['arg1' => 'value1']);
$toolbox = $this->createMock(ToolboxInterface::class);
$toolbox->expects($this->once())->method('execute')->willReturn('Test response');
$toolbox
->expects($this->once())
->method('execute')
->willReturn(new ToolResult($toolCall, 'Test response'));

$messageBag = new MessageBag();

$result = new ToolCallResult(new ToolCall('id1', 'tool1', ['arg1' => 'value1']));
$result = new ToolCallResult($toolCall);

$agent = $this->createStub(AgentInterface::class);

Expand Down
9 changes: 5 additions & 4 deletions src/agent/tests/Toolbox/FaultTolerantToolboxTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
use Symfony\AI\Agent\Toolbox\Exception\ToolNotFoundException;
use Symfony\AI\Agent\Toolbox\FaultTolerantToolbox;
use Symfony\AI\Agent\Toolbox\ToolboxInterface;
use Symfony\AI\Agent\Toolbox\ToolResult;
use Symfony\AI\Fixtures\Tool\ToolNoParams;
use Symfony\AI\Fixtures\Tool\ToolRequiredParams;
use Symfony\AI\Platform\Result\ToolCall;
Expand All @@ -37,7 +38,7 @@ public function testFaultyToolExecution()
$toolCall = new ToolCall('987654321', 'tool_foo');
$actual = $faultTolerantToolbox->execute($toolCall);

$this->assertSame($expected, $actual);
$this->assertSame($expected, $actual->getResult());
}

public function testFaultyToolCall()
Expand All @@ -52,7 +53,7 @@ public function testFaultyToolCall()
$toolCall = new ToolCall('123456789', 'tool_xyz');
$actual = $faultTolerantToolbox->execute($toolCall);

$this->assertSame($expected, $actual);
$this->assertSame($expected, $actual->getResult());
}

public function testCustomToolExecutionException()
Expand All @@ -72,7 +73,7 @@ public function getToolCallResult(): array
$toolCall = new ToolCall('123456789', 'tool_xyz');
$actual = $faultTolerantToolbox->execute($toolCall);

$this->assertSame($expected, $actual);
$this->assertSame($expected, $actual->getResult());
}

private function createFaultyToolbox(\Closure $exceptionFactory): ToolboxInterface
Expand All @@ -93,7 +94,7 @@ public function getTools(): array
];
}

public function execute(ToolCall $toolCall): mixed
public function execute(ToolCall $toolCall): ToolResult
{
throw ($this->exceptionFactory)($toolCall);
}
Expand Down
11 changes: 7 additions & 4 deletions src/agent/tests/Toolbox/ToolboxTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
use Symfony\AI\Agent\Toolbox\ToolFactory\ChainFactory;
use Symfony\AI\Agent\Toolbox\ToolFactory\MemoryToolFactory;
use Symfony\AI\Agent\Toolbox\ToolFactory\ReflectionToolFactory;
use Symfony\AI\Agent\Toolbox\ToolResult;
use Symfony\AI\Fixtures\Tool\ToolCustomException;
use Symfony\AI\Fixtures\Tool\ToolDate;
use Symfony\AI\Fixtures\Tool\ToolException;
Expand Down Expand Up @@ -180,9 +181,11 @@ public function testExecuteWithCustomException()
#[DataProvider('executeProvider')]
public function testExecute(string $expected, string $toolName, array $toolPayload = [])
{
$this->assertSame(
$expected,
$this->toolbox->execute(new ToolCall('call_1234', $toolName, $toolPayload)),
$toolCall = new ToolCall('call_1234', $toolName, $toolPayload);

$this->assertEquals(
new ToolResult($toolCall, $expected),
$this->toolbox->execute($toolCall),
);
}

Expand Down Expand Up @@ -244,7 +247,7 @@ public function testToolboxExecutionWithMemoryFactory()
$toolbox = new Toolbox([new ToolNoAttribute1()], $memoryFactory);
$result = $toolbox->execute(new ToolCall('call_1234', 'happy_birthday', ['name' => 'John', 'years' => 30]));

$this->assertSame('Happy Birthday, John! You are 30 years old.', $result);
$this->assertSame('Happy Birthday, John! You are 30 years old.', $result->getResult());
}

public function testToolboxMapWithOverrideViaChain()
Expand Down
5 changes: 2 additions & 3 deletions src/ai-bundle/src/Profiler/DataCollector.php
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
namespace Symfony\AI\AiBundle\Profiler;

use Symfony\AI\Agent\Toolbox\ToolboxInterface;
use Symfony\AI\Platform\Model;
use Symfony\AI\Agent\Toolbox\ToolResult;
use Symfony\AI\Platform\Tool\Tool;
use Symfony\Bundle\FrameworkBundle\DataCollector\AbstractDataCollector;
use Symfony\Component\HttpFoundation\Request;
Expand All @@ -23,7 +23,6 @@
* @author Christopher Hertel <mail@christopher-hertel.de>
*
* @phpstan-import-type PlatformCallData from TraceablePlatform
* @phpstan-import-type ToolCallData from TraceableToolbox
*/
final class DataCollector extends AbstractDataCollector implements LateDataCollectorInterface
{
Expand Down Expand Up @@ -86,7 +85,7 @@ public function getTools(): array
}

/**
* @return ToolCallData[]
* @return ToolResult[]
*/
public function getToolCalls(): array
{
Expand Down
19 changes: 4 additions & 15 deletions src/ai-bundle/src/Profiler/TraceableToolbox.php
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,16 @@
namespace Symfony\AI\AiBundle\Profiler;

use Symfony\AI\Agent\Toolbox\ToolboxInterface;
use Symfony\AI\Agent\Toolbox\ToolResult;
use Symfony\AI\Platform\Result\ToolCall;

/**
* @author Christopher Hertel <mail@christopher-hertel.de>
*
* @phpstan-type ToolCallData array{
* call: ToolCall,
* result: string,
* }
*/
final class TraceableToolbox implements ToolboxInterface
{
/**
* @var ToolCallData[]
* @var ToolResult[]
*/
public array $calls = [];

Expand All @@ -39,15 +35,8 @@ public function getTools(): array
return $this->toolbox->getTools();
}

public function execute(ToolCall $toolCall): mixed
public function execute(ToolCall $toolCall): ToolResult
{
$result = $this->toolbox->execute($toolCall);

$this->calls[] = [
'call' => $toolCall,
'result' => $result,
];

return $result;
return $this->calls[] = $this->toolbox->execute($toolCall);
}
}
10 changes: 5 additions & 5 deletions src/ai-bundle/templates/data_collector.html.twig
Original file line number Diff line number Diff line change
Expand Up @@ -221,25 +221,25 @@

<h3>Tool Calls</h3>
{% if collector.toolCalls|length %}
{% for call in collector.toolCalls %}
{% for toolResult in collector.toolCalls %}
<table class="table">
<thead>
<tr>
<th colspan="2">{{ call.call.name }}</th>
<th colspan="2">{{ toolResult.toolCall.name }}</th>
</tr>
</thead>
<tbody>
<tr>
<th>ID</th>
<td>{{ call.call.id }}</td>
<td>{{ toolResult.toolCall.id }}</td>
</tr>
<tr>
<th>Arguments</th>
<td>{{ dump(call.call.arguments) }}</td>
<td>{{ dump(toolResult.toolCall.arguments) }}</td>
</tr>
<tr>
<th>Result</th>
<td>{{ dump(call.result) }}</td>
<td>{{ dump(toolResult.result) }}</td>
</tr>
</tbody>
</table>
Expand Down
11 changes: 6 additions & 5 deletions src/ai-bundle/tests/Profiler/TraceableToolboxTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

use PHPUnit\Framework\TestCase;
use Symfony\AI\Agent\Toolbox\ToolboxInterface;
use Symfony\AI\Agent\Toolbox\ToolResult;
use Symfony\AI\AiBundle\Profiler\TraceableToolbox;
use Symfony\AI\Platform\Result\ToolCall;
use Symfony\AI\Platform\Tool\ExecutionReference;
Expand Down Expand Up @@ -40,10 +41,10 @@ public function testExecute()

$result = $traceableToolbox->execute($toolCall);

$this->assertSame('tool_result', $result);
$this->assertSame('tool_result', $result->getResult());
$this->assertCount(1, $traceableToolbox->calls);
$this->assertSame($toolCall, $traceableToolbox->calls[0]['call']);
$this->assertSame('tool_result', $traceableToolbox->calls[0]['result']);
$this->assertSame($toolCall, $traceableToolbox->calls[0]->getToolCall());
$this->assertSame('tool_result', $traceableToolbox->calls[0]->getResult());
}

/**
Expand All @@ -62,9 +63,9 @@ public function getTools(): array
return $this->tools;
}

public function execute(ToolCall $toolCall): string
public function execute(ToolCall $toolCall): ToolResult
{
return 'tool_result';
return new ToolResult($toolCall, 'tool_result');
}
};
}
Expand Down