Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 62 additions & 1 deletion internal/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"github.com/open-code-review/open-code-review/internal/stdout"
"github.com/open-code-review/open-code-review/internal/telemetry"
"github.com/open-code-review/open-code-review/internal/tool"
"go.opentelemetry.io/otel/attribute"
)

// AgentWarning is re-exported from llmloop for backwards compatibility with
Expand Down Expand Up @@ -500,6 +501,28 @@ func (a *Agent) executeSubtask(ctx context.Context, d model.Diff) error {
return err
}

// submitFilterTool is the tool definition that constrains the review filter
var submitFilterTool = llm.ToolDef{
Type: "function",
Function: llm.FunctionDef{
Name: "submit_filter_result",
Description: "Submit the list of review comments that are provably incorrect based on the diff",
Parameters: map[string]any{
"type": "object",
"properties": map[string]any{
"comment_ids": map[string]any{
"type": "array",
"description": "IDs of review comments confirmed as incorrect return an empty array when none, e.g. [\"c-0\", \"c-2\"]",
"items": map[string]any{
"type": "string",
},
},
},
"required": []any{"comment_ids"},
},
},
}

// executeReviewFilter runs the REVIEW_FILTER_TASK to remove comments that are
// provably incorrect based solely on the diff. Errors are logged and silently ignored.
func (a *Agent) executeReviewFilter(ctx context.Context, d model.Diff, newPath string) {
Expand Down Expand Up @@ -537,6 +560,7 @@ func (a *Agent) executeReviewFilter(ctx context.Context, d model.Diff, newPath s
resp, err := a.args.LLMClient.CompletionsWithCtx(ctx, llm.ChatRequest{
Model: a.args.Model,
Messages: messages,
Tools: []llm.ToolDef{submitFilterTool},
MaxTokens: a.args.Template.MaxTokens,
})
if err != nil {
Expand All @@ -547,12 +571,23 @@ func (a *Agent) executeReviewFilter(ctx context.Context, d model.Diff, newPath s
rec.SetResponse(resp, time.Since(startTime))
a.runner.RecordUsage(resp.Usage)

indices := parseFilterResponse(resp.Content(), len(comments))
indices := parseFilterToolCalls(resp.ToolCalls(), len(comments))
if indices == nil {
indices = parseFilterResponse(resp.Content(), len(comments))
}
if len(indices) == 0 {
telemetry.Event(ctx, "review_filter.completed",
attribute.String("file.path", newPath),
attribute.Int("total_comments", len(comments)),
attribute.Int("removed", 0))
return
}

a.args.CommentCollector.RemoveByPathAndIndices(newPath, indices)
telemetry.Event(ctx, "review_filter.completed",
attribute.String("file.path", newPath),
attribute.Int("total_comments", len(comments)),
attribute.Int("removed", len(indices)))
fmt.Fprintf(stdout.Writer(), "[ocr] Review filter removed %d comment(s) for %s\n", len(indices), newPath)
}

Expand All @@ -575,6 +610,32 @@ func buildFilterCommentsJSON(comments []model.LlmComment) string {
return string(data)
}

// parseFilterToolCalls extracts comment indices from the LLM's tool call
// response to submit_filter_result. Returns nil if no matching tool call
// is found, allowing fallback to text-based parsing.
func parseFilterToolCalls(calls []llm.ToolCall, total int) map[int]struct{} {
for _, call := range calls {
if call.Function.Name != "submit_filter_result" {
continue
}
var args struct {
CommentIDs []string `json:"comment_ids"`
}
if err := json.Unmarshal([]byte(call.Function.Arguments), &args); err != nil {
continue
}
indices := make(map[int]struct{})
for _, id := range args.CommentIDs {
var idx int
if _, err := fmt.Sscanf(id, "c-%d", &idx); err == nil && idx >= 0 && idx < total {
indices[idx] = struct{}{}
}
}
return indices
}
return nil
}

// parseFilterResponse extracts comment indices from the LLM filter response.
// Returns a set of 0-based indices. Invalid IDs or out-of-range indices are ignored.
func parseFilterResponse(raw string, total int) map[int]struct{} {
Expand Down
90 changes: 90 additions & 0 deletions internal/agent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,96 @@ func TestParseFilterResponse(t *testing.T) {
}
}

func TestParseFilterToolCalls(t *testing.T) {
tests := []struct {
name string
calls []llm.ToolCall
total int
wantSet map[int]struct{}
}{
{
name: "no tool calls",
calls: nil,
total: 5,
wantSet: nil,
},
{
name: "submit_filter_result with IDs",
calls: []llm.ToolCall{{
Function: llm.FunctionCall{
Name: "submit_filter_result",
Arguments: `{"comment_ids": ["c-0", "c-2"]}`,
},
}},
total: 5,
wantSet: map[int]struct{}{0: {}, 2: {}},
},
{
name: "submit_filter_result empty array",
calls: []llm.ToolCall{{
Function: llm.FunctionCall{
Name: "submit_filter_result",
Arguments: `{"comment_ids": []}`,
},
}},
total: 5,
wantSet: map[int]struct{}{},
},
{
name: "ignores non-matching tool names",
calls: []llm.ToolCall{{
Function: llm.FunctionCall{
Name: "other_tool",
Arguments: `{"comment_ids": ["c-0"]}`,
},
}},
total: 5,
wantSet: nil,
},
{
name: "out-of-range indices ignored",
calls: []llm.ToolCall{{
Function: llm.FunctionCall{
Name: "submit_filter_result",
Arguments: `{"comment_ids": ["c-0", "c-10"]}`,
},
}},
total: 5,
wantSet: map[int]struct{}{0: {}},
},
{
name: "invalid JSON arguments returns nil",
calls: []llm.ToolCall{{
Function: llm.FunctionCall{
Name: "submit_filter_result",
Arguments: `not json`,
},
}},
total: 5,
wantSet: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := parseFilterToolCalls(tt.calls, tt.total)
if tt.wantSet == nil {
if got != nil {
t.Errorf("expected nil, got %v", got)
}
return
}
if len(got) != len(tt.wantSet) {
t.Fatalf("len = %d, want %d; got %v", len(got), len(tt.wantSet), got)
}
for idx := range tt.wantSet {
if _, ok := got[idx]; !ok {
t.Errorf("missing index %d in result", idx)
}
}
})
}
}

func TestExtFromPath(t *testing.T) {
a := New(Args{})

Expand Down
12 changes: 1 addition & 11 deletions internal/config/template/prompts/review_filter_task_user.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,4 @@ After confirming that the facts visible in the diff are accurate, determine whet

### Output

Return all incorrect review comment IDs directly, without any explanation. Use JSON array format:

```json
["id-xxx", "id-yyy"]
```

If there are no review comments that can be confirmed as incorrect, return an empty array:

```json
[]
```
Call the `submit_filter_result` tool with the IDs of all review comments that can be confirmed as incorrect, without any explanation.
2 changes: 1 addition & 1 deletion internal/viewer/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ func LoadSession(root, encodedRepo, sessionID string) (*ViewSession, error) {
name, _ := tm["name"].(string)
args, _ := tm["arguments"].(string)
info := ToolCallInfo{Name: name, Arguments: args}
if name == "task_done" {
if name == "task_done" || name == "submit_filter_result" {
info.Ok = true
}
card.ToolCalls = append(card.ToolCalls, info)
Expand Down
Loading