Skip to content

Commit 7483218

Browse files
committed
feat: tool approval flow
1 parent e2c680e commit 7483218

84 files changed

Lines changed: 3194 additions & 21 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

src/Concerns/CallsTools.php

Lines changed: 171 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,16 @@
1515
use Prism\Prism\Streaming\Events\ArtifactEvent;
1616
use Prism\Prism\Streaming\Events\StepFinishEvent;
1717
use Prism\Prism\Streaming\Events\StreamEndEvent;
18+
use Prism\Prism\Streaming\Events\StreamStartEvent;
19+
use Prism\Prism\Streaming\Events\ToolApprovalRequestEvent;
1820
use Prism\Prism\Streaming\Events\ToolResultEvent;
1921
use Prism\Prism\Streaming\StreamState;
22+
use Prism\Prism\Structured\Request as StructuredRequest;
23+
use Prism\Prism\Text\Request as TextRequest;
2024
use Prism\Prism\Tool;
25+
use Prism\Prism\ValueObjects\Messages\AssistantMessage;
26+
use Prism\Prism\ValueObjects\Messages\ToolResultMessage;
27+
use Prism\Prism\ValueObjects\ToolApprovalResponse;
2128
use Prism\Prism\ValueObjects\ToolCall;
2229
use Prism\Prism\ValueObjects\ToolOutput;
2330
use Prism\Prism\ValueObjects\ToolResult;
@@ -51,11 +58,12 @@ protected function callTools(array $tools, array $toolCalls, bool &$hasPendingTo
5158
* @param Tool[] $tools
5259
* @param ToolCall[] $toolCalls
5360
* @param ToolResult[] $toolResults Results are collected into this array by reference
54-
* @return Generator<ToolResultEvent|ArtifactEvent>
61+
* @return Generator<ToolResultEvent|ArtifactEvent|ToolApprovalRequestEvent>
5562
*/
5663
protected function callToolsAndYieldEvents(array $tools, array $toolCalls, string $messageId, array &$toolResults, bool &$hasPendingToolCalls): Generator
5764
{
58-
$serverToolCalls = $this->filterServerExecutedToolCalls($tools, $toolCalls, $hasPendingToolCalls);
65+
$approvalRequiredToolCalls = [];
66+
$serverToolCalls = $this->filterServerExecutedToolCalls($tools, $toolCalls, $hasPendingToolCalls, $approvalRequiredToolCalls);
5967

6068
$groupedToolCalls = $this->groupToolCallsByConcurrency($tools, $serverToolCalls);
6169

@@ -70,16 +78,26 @@ protected function callToolsAndYieldEvents(array $tools, array $toolCalls, strin
7078
yield $event;
7179
}
7280
}
81+
82+
foreach ($approvalRequiredToolCalls as $toolCall) {
83+
yield new ToolApprovalRequestEvent(
84+
id: EventID::generate(),
85+
timestamp: time(),
86+
toolCall: $toolCall,
87+
messageId: $messageId,
88+
);
89+
}
7390
}
7491

7592
/**
76-
* Filter out client-executed tool calls, setting the pending flag if any are found.
93+
* Filter out client-executed and approval-required tool calls, setting the pending flag if any are found.
7794
*
7895
* @param Tool[] $tools
7996
* @param ToolCall[] $toolCalls
97+
* @param ToolCall[] $approvalRequiredToolCalls Collected approval-required tool calls (by reference)
8098
* @return array<int, ToolCall> Server-executed tool calls with original indices preserved
8199
*/
82-
protected function filterServerExecutedToolCalls(array $tools, array $toolCalls, bool &$hasPendingToolCalls): array
100+
protected function filterServerExecutedToolCalls(array $tools, array $toolCalls, bool &$hasPendingToolCalls, array &$approvalRequiredToolCalls = []): array
83101
{
84102
$serverToolCalls = [];
85103

@@ -93,6 +111,13 @@ protected function filterServerExecutedToolCalls(array $tools, array $toolCalls,
93111
continue;
94112
}
95113

114+
if ($tool->needsApproval($toolCall->arguments())) {
115+
$hasPendingToolCalls = true;
116+
$approvalRequiredToolCalls[] = $toolCall;
117+
118+
continue;
119+
}
120+
96121
$serverToolCalls[$index] = $toolCall;
97122
} catch (PrismException) {
98123
// Unknown tool - keep it so error handling works in executeToolCall
@@ -258,6 +283,148 @@ protected function yieldToolCallsFinishEvents(StreamState $state): Generator
258283
);
259284
}
260285

286+
/**
287+
* Resolve pending tool approvals from a previous request (non-streaming).
288+
*
289+
* Scans request messages for a ToolResultMessage with toolApprovalResponses after
290+
* the last AssistantMessage. If found, executes approved tools, creates denial
291+
* results for denied tools, and replaces it with a ToolResultMessage containing
292+
* merged tool results (existing + resolved) and the consumed approval responses.
293+
*/
294+
protected function resolveToolApprovals(StructuredRequest|TextRequest $request): void
295+
{
296+
foreach ($this->resolveToolApprovalsAndYieldEvents($request, EventID::generate()) as $event) {
297+
// Events are discarded for non-streaming handlers
298+
}
299+
}
300+
301+
/**
302+
* @return Generator<StreamStartEvent|ToolResultEvent|ArtifactEvent>
303+
*/
304+
protected function resolveToolApprovalsAndYieldEvents(StructuredRequest|TextRequest $request, string $messageId, ?StreamState $state = null): Generator
305+
{
306+
$messages = $request->messages();
307+
308+
$assistantMessage = null;
309+
$assistantMessageIndex = null;
310+
311+
for ($i = count($messages) - 1; $i >= 0; $i--) {
312+
if ($messages[$i] instanceof AssistantMessage && $messages[$i]->toolCalls !== []) {
313+
$assistantMessage = $messages[$i];
314+
$assistantMessageIndex = $i;
315+
316+
break;
317+
}
318+
}
319+
320+
if (! $assistantMessage instanceof AssistantMessage || $assistantMessageIndex === null) {
321+
return;
322+
}
323+
324+
$toolsByName = collect($request->tools())->keyBy(fn (Tool $tool): string => $tool->name());
325+
$isAnyToolApprovalConfigured = collect($assistantMessage->toolCalls)->contains(
326+
fn (ToolCall $toolCall): bool => $toolsByName->get($toolCall->name)?->hasApprovalConfigured() === true,
327+
);
328+
329+
if (! $isAnyToolApprovalConfigured) {
330+
return;
331+
}
332+
333+
$toolMessage = null;
334+
$toolMessageIndex = null;
335+
$counter = count($messages);
336+
337+
for ($i = $assistantMessageIndex + 1; $i < $counter; $i++) {
338+
if ($messages[$i] instanceof ToolResultMessage) {
339+
$toolMessage = $messages[$i];
340+
$toolMessageIndex = $i;
341+
342+
break;
343+
}
344+
}
345+
346+
if (! $toolMessage instanceof ToolResultMessage) {
347+
$toolMessage = new ToolResultMessage;
348+
$toolMessageIndex = null;
349+
}
350+
351+
$approvalResolvedToolResults = [];
352+
353+
foreach ($assistantMessage->toolCalls as $toolCall) {
354+
$approval = $toolMessage->findByApprovalId($toolCall->id);
355+
356+
if (! $approval instanceof ToolApprovalResponse) {
357+
if (collect($toolMessage->toolResults)->contains(fn (ToolResult $tr): bool => $tr->toolCallId === $toolCall->id)) { // tool already executed
358+
continue;
359+
}
360+
if (! ($toolsByName->get($toolCall->name)?->hasApprovalConfigured() === true)) {
361+
continue;
362+
}
363+
364+
$approval = new ToolApprovalResponse($toolCall->id, false, 'No approval response provided');
365+
}
366+
367+
if ($state instanceof StreamState && $state->shouldEmitStreamStart()) {
368+
yield new StreamStartEvent(
369+
id: EventID::generate(),
370+
timestamp: time(),
371+
model: $request->model(),
372+
provider: $request->provider(),
373+
);
374+
375+
$state->markStreamStarted();
376+
}
377+
378+
if ($approval->approved) {
379+
$result = $this->executeToolCall($request->tools(), $toolCall, $messageId);
380+
381+
$approvalResolvedToolResults[] = $result['toolResult'];
382+
383+
foreach ($result['events'] as $event) {
384+
yield $event;
385+
}
386+
387+
continue;
388+
}
389+
390+
$reason = $approval->reason ?? 'User denied tool execution';
391+
392+
$toolResult = new ToolResult(
393+
toolCallId: $toolCall->id,
394+
toolName: $toolCall->name,
395+
args: $toolCall->arguments(),
396+
result: $reason,
397+
toolCallResultId: $toolCall->resultId,
398+
);
399+
400+
$approvalResolvedToolResults[] = $toolResult;
401+
402+
yield new ToolResultEvent(
403+
id: EventID::generate(),
404+
timestamp: time(),
405+
toolResult: $toolResult,
406+
messageId: $messageId,
407+
success: false,
408+
error: $reason,
409+
);
410+
}
411+
412+
if ($toolMessageIndex !== null) { // remove old tool result message
413+
$updatedMessages = array_values(array_filter(
414+
$messages,
415+
fn (int $index): bool => $index !== $toolMessageIndex,
416+
ARRAY_FILTER_USE_KEY,
417+
));
418+
$request->setMessages($updatedMessages);
419+
}
420+
421+
// Add new tool result message which also contains results of approval resolved tools
422+
$request->addMessage(new ToolResultMessage(
423+
array_merge($toolMessage->toolResults, $approvalResolvedToolResults),
424+
$toolMessage->toolApprovalResponses
425+
));
426+
}
427+
261428
/**
262429
* @param Tool[] $tools
263430
*

src/Enums/StreamEventType.php

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ enum StreamEventType: string
1717
case ToolCallDelta = 'tool_call_delta';
1818
case ProviderToolEvent = 'provider_tool_event';
1919
case ToolResult = 'tool_result';
20+
case ToolApprovalRequest = 'tool_approval_request';
2021
case Citation = 'citation';
2122
case Artifact = 'artifact';
2223
case Error = 'error';
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
namespace Prism\Prism\Events\Broadcasting;
6+
7+
class ToolApprovalRequestBroadcast extends StreamEventBroadcast {}

src/Providers/Anthropic/Handlers/Stream.php

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ public function __construct(protected PendingRequest $client)
5656
*/
5757
public function handle(Request $request): Generator
5858
{
59+
yield from $this->resolveToolApprovalsAndYieldEvents($request, EventID::generate(), $this->state);
60+
5961
$this->state->reset();
6062
$response = $this->sendRequest($request);
6163

src/Providers/Anthropic/Handlers/Text.php

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ public function __construct(protected PendingRequest $client, protected TextRequ
4848

4949
public function handle(): Response
5050
{
51+
$this->resolveToolApprovals($this->request);
52+
5153
$this->sendRequest();
5254

5355
$this->prepareTempResponse();

src/Providers/DeepSeek/Handlers/Stream.php

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ public function __construct(protected PendingRequest $client)
5757
*/
5858
public function handle(Request $request): Generator
5959
{
60+
yield from $this->resolveToolApprovalsAndYieldEvents($request, EventID::generate(), $this->state);
61+
6062
$response = $this->sendRequest($request);
6163

6264
yield from $this->processStream($response, $request);

src/Providers/DeepSeek/Handlers/Text.php

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ public function __construct(protected PendingRequest $client)
4141

4242
public function handle(Request $request): TextResponse
4343
{
44+
$this->resolveToolApprovals($request);
45+
4446
$data = $this->sendRequest($request);
4547

4648
$this->validateResponse($data);

src/Providers/Gemini/Handlers/Stream.php

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ public function __construct(
5858
*/
5959
public function handle(Request $request): Generator
6060
{
61+
yield from $this->resolveToolApprovalsAndYieldEvents($request, EventID::generate(), $this->state);
62+
6163
$this->state->reset();
6264
$this->currentThoughtSignature = null;
6365
$response = $this->sendRequest($request);

src/Providers/Gemini/Handlers/Text.php

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ public function __construct(
4343

4444
public function handle(Request $request): TextResponse
4545
{
46+
$this->resolveToolApprovals($request);
47+
4648
$response = $this->sendRequest($request);
4749

4850
$this->validateResponse($response);

src/Providers/Groq/Handlers/Stream.php

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ public function __construct(protected PendingRequest $client)
5656
*/
5757
public function handle(Request $request): Generator
5858
{
59+
yield from $this->resolveToolApprovalsAndYieldEvents($request, EventID::generate(), $this->state);
60+
5961
$response = $this->sendRequest($request);
6062

6163
yield from $this->processStream($response, $request);

0 commit comments

Comments
 (0)