diff --git a/.claude/skills/incident-triage-lame/SKILL.md b/.claude/skills/incident-triage-lame/SKILL.md new file mode 100644 index 0000000000..452be0dd5e --- /dev/null +++ b/.claude/skills/incident-triage-lame/SKILL.md @@ -0,0 +1,45 @@ +--- +name: incident-triage-lame +description: Triage active PagerDuty incidents by gathering context from all available services one tool call at a time. +allowed-tools: Bash +--- + +# Incident Triage (No Scripting) + +We have active PagerDuty incidents. Build an incident triage report by gathering data from every available service. + +Do NOT use `execute_tool_script`. Call each tool individually, one at a time. + +1. Check PagerDuty for service health and active incidents +2. For each degraded service, gather context from Datadog (metrics + logs), GitHub (recent PRs), Slack (#incidents messages), Jira (related issues), and Confluence (runbooks) +3. Cross-reference the results to identify probable root causes, who's engaged, and what runbooks apply + +Format the final report as markdown matching this structure exactly: + +``` +# Incident Triage Report + +## Service Health + + +## Active Incidents + + +## Degraded Service: +### Metrics + +### Error Logs + +### Recent PRs (Potential Root Causes) + +### Slack #incidents Context + +### Related Jira Issues + +### Runbooks + + +(repeat for each degraded service) +``` + +Include the raw tool output under each heading — do not summarize or rewrite it. diff --git a/.claude/skills/incident-triage/SKILL.md b/.claude/skills/incident-triage/SKILL.md new file mode 100644 index 0000000000..11f97440bc --- /dev/null +++ b/.claude/skills/incident-triage/SKILL.md @@ -0,0 +1,79 @@ +--- +name: incident-triage +description: Triage active PagerDuty incidents by gathering context from all available services using execute_tool_script. +allowed-tools: Bash +--- + +# Incident Triage + +We have active PagerDuty incidents. Use `execute_tool_script` to build an incident triage report by gathering data from every available service in a single scripted call. + +Write a Starlark script that: + +1. Gets the service health list and active incidents from PagerDuty +2. For each service that is NOT "Operational", gathers context in parallel: + - Datadog metrics and error logs for that service + - Recent GitHub PRs (look for potential root cause deploys) + - Slack #incidents messages for team context + - Related Jira issues + - Confluence runbooks +3. Parses the text results to extract key details — incident IDs, error messages, who's involved, what was recently deployed +4. Formats the result as a **markdown report** and returns it as a string + +The script should return a ready-to-display markdown string — NOT a dict. Build the markdown inside the script so no post-processing is needed. Structure it like: + +``` +# Incident Triage Report + +## Service Health + + +## Active Incidents + + +## Degraded Service: +### Metrics + +### Error Logs + +### Recent PRs (Potential Root Causes) + +### Slack #incidents Context + +### Related Jira Issues + +### Runbooks + + +(repeat for each degraded service) +``` + +Use loops over the degraded services and string parsing to cross-reference results. + +Use `parallel()` to fan out tool calls concurrently. `parallel()` takes a list of zero-arg callables (use `lambda`) and returns results in order. Fan out all services at once: + +```python +def gather_context(svc): + results = parallel([ + lambda s=svc: datadog_datadog_query_metrics(query=s), + lambda s=svc: datadog_datadog_search_logs(query=s), + lambda s=svc: github_github_search_prs(query=s), + lambda s=svc: slack_slack_read_messages(channel="incidents"), + lambda s=svc: jira_jira_search_issues(query=s), + lambda s=svc: confluence_confluence_search_pages(query=s), + ]) + return results + +# Fan out ALL services concurrently (nested parallel) +contexts = parallel([lambda s=svc: gather_context(s) for svc in degraded_services]) +``` + +NOTE: Starlark lambdas capture variables by reference. When using `lambda` inside a loop, bind the loop variable via a default argument to avoid the classic closure bug: +```python +# WRONG — all lambdas see the final value of svc +[lambda: query(svc) for svc in services] +# RIGHT — bind svc at definition time +[lambda s=svc: query(s) for svc in services] +``` + +IMPORTANT: The script returns a fully formatted markdown report. After calling execute_tool_script, display the result text verbatim. Do NOT summarize, reformat, or add your own analysis — the script output IS the final answer. diff --git a/cmd/vmcp/app/commands.go b/cmd/vmcp/app/commands.go index 53238c7780..22000b0e11 100644 --- a/cmd/vmcp/app/commands.go +++ b/cmd/vmcp/app/commands.go @@ -27,6 +27,7 @@ import ( "github.com/stacklok/toolhive/pkg/authserver/server/keys" "github.com/stacklok/toolhive/pkg/container/runtime" "github.com/stacklok/toolhive/pkg/groups" + "github.com/stacklok/toolhive/pkg/script" "github.com/stacklok/toolhive/pkg/telemetry" "github.com/stacklok/toolhive/pkg/vmcp" "github.com/stacklok/toolhive/pkg/vmcp/aggregator" @@ -624,6 +625,7 @@ func runServe(cmd *cobra.Command, _ []string) error { Port: port, AuthMiddleware: authMiddleware, AuthzMiddleware: authzMiddleware, + ScriptMiddleware: script.NewMiddleware(), AuthInfoHandler: authInfoHandler, AuthServer: embeddedAuthServer, TelemetryProvider: telemetryProvider, diff --git a/demo/script-middleware/README.md b/demo/script-middleware/README.md new file mode 100644 index 0000000000..5f46b1a689 --- /dev/null +++ b/demo/script-middleware/README.md @@ -0,0 +1,139 @@ +# Script Middleware Demo + +Demonstrates `execute_tool_script` — a Starlark scripting layer that lets agents +orchestrate multiple MCP tool calls in a single atomic operation. + +## What this shows + +An agent connected to a VirtualMCPServer with 8 enterprise tool backends +(Slack, Jira, GitHub, PagerDuty, Datadog, Confluence, Google Drive, Linear) +uses `execute_tool_script` to gather and cross-reference data across services +in one call instead of 8+ sequential round-trips. + +## Setup (local Kind cluster) + +### Prerequisites +- `kind`, `kubectl`, `docker` installed +- ToolHive operator image built locally: `task build-all-images` + +### Deploy + +```bash +# From repo root +./demo/script-middleware/deploy.sh +``` + +This creates a Kind cluster, installs the operator, deploys 8 dummy MCP servers +and a VirtualMCPServer, and sets up port-forwarding on localhost:4483. + +### Connect with Claude Code + +```bash +# In Claude Code settings, add as an MCP server: +# URL: http://localhost:4483/mcp +# Transport: streamable-http + +# Then give Claude this prompt (see below) +``` + +### Teardown + +```bash +kind delete cluster --name script-demo +``` + +## The Prompt + +Give this to Claude (or any MCP-capable agent) after connecting: + +> We have active PagerDuty incidents. Use execute_tool_script to build an +> incident triage report by gathering data from every available service. +> +> Write a script that: +> 1. Gets the service health list and active incidents from PagerDuty +> 2. For each service that is NOT "Operational", gathers context: +> - Datadog metrics and error logs for that service +> - Recent GitHub PRs (look for potential root cause deploys) +> - Slack #incidents messages for team context +> - Related Jira issues +> - Confluence runbooks +> 3. Parses the text results to extract key details (incident IDs, error +> messages, who's involved, what was recently deployed) +> 4. Returns a structured dict mapping each degraded service to its +> full triage context +> +> The script should use loops and string parsing — don't just call each +> tool once, cross-reference the results. + +### What the agent should produce + +A Starlark script that loops over degraded services, calls 5-6 tools per +service, parses the text output to extract names/IDs/timestamps, and returns +a structured dict. Something like: + +```python +services = pagerduty_list_services() +incidents = pagerduty_list_incidents() +report = {} + +for line in services.split("\n"): + if "Degraded" in line or "Critical" in line: + svc = line.split(" — ")[0].strip() + + metrics = datadog_query_metrics(query=svc, timeframe="last_1h") + logs = datadog_search_logs(query="ERROR", service=svc) + prs = github_search_prs(query="merged", repo=svc) + slack = slack_read_messages(channel="#incidents") + jira = jira_search_issues(query=svc, project="ENG") + runbook = confluence_search_pages(query=svc + " runbook") + + # Extract people involved from Slack messages + people = [] + for msg in slack.split("\n"): + if "]" in msg: + who = msg.split("]")[1].split(":")[0].strip() + if who and who not in people: + people.append(who) + + # Find incident IDs for this service + svc_incidents = [] + for inc in incidents.split("\n"): + if svc in inc: + svc_incidents.append(inc.strip()) + + report[svc] = { + "incidents": svc_incidents, + "metrics_summary": metrics, + "recent_errors": logs, + "recent_prs": prs, + "team_engaged": people, + "related_jira": jira, + "runbook": runbook, + } + +return report +``` + +## Why this is interesting + +1. **Loops + conditionals** — the script iterates over degraded services, + not a static list. The agent writes real control flow. + +2. **Cross-referencing** — incident IDs from PagerDuty are matched against + service names. Slack messages are parsed to extract who's engaged. + +3. **8+ tool calls in one round-trip** — without `execute_tool_script`, + the agent needs sequential calls with model inference between each. + The script runs server-side and returns one aggregated result. + +4. **Text parsing** — the script does string splitting and filtering that + would otherwise require the model to process raw text from each tool. + +## Coherent demo story + +The dummy data tells a story: Alice deployed `v2.4.1` which caused the +checkout service to timeout, spiking web-app latency. PagerDuty fired two +incidents (SEV1 checkout, SEV2 web-app). The Slack #incidents channel shows +Alice, Bob, and Carol coordinating. Datadog logs show the exact error chain. +GitHub shows the merged PR that caused it. The script stitches all of this +together into a single triage report. diff --git a/demo/script-middleware/deploy.sh b/demo/script-middleware/deploy.sh new file mode 100755 index 0000000000..62f3df3df3 --- /dev/null +++ b/demo/script-middleware/deploy.sh @@ -0,0 +1,95 @@ +#!/usr/bin/env bash +# Deploy the script middleware demo to a local Kind cluster. +# Prerequisites: kind, kubectl, docker, task (Taskfile) +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd "$SCRIPT_DIR/../.." && pwd)" +CLUSTER_NAME="script-demo" + +echo "=== Script Middleware Demo ===" +echo "" + +# 1. Create Kind cluster (if not exists) +if kind get clusters 2>/dev/null | grep -q "^${CLUSTER_NAME}$"; then + echo "Kind cluster '$CLUSTER_NAME' already exists, reusing." +else + echo "Creating Kind cluster '$CLUSTER_NAME'..." + cat </dev/null || echo "$HOME/.kube/config")" +kubectl config use-context "kind-${CLUSTER_NAME}" + +# 2. Build and load images +echo "" +echo "Building operator and vmcp images..." +cd "$REPO_ROOT" +task build-all-images 2>&1 | tail -5 + +echo "Loading images into Kind cluster..." +kind load docker-image ghcr.io/stacklok/toolhive/operator:latest --name "$CLUSTER_NAME" +kind load docker-image ghcr.io/stacklok/toolhive/vmcp:latest --name "$CLUSTER_NAME" +kind load docker-image ghcr.io/stacklok/toolhive/proxyrunner:latest --name "$CLUSTER_NAME" + +# 3. Install CRDs and operator +echo "" +echo "Installing CRDs..." +kubectl apply -f deploy/charts/operator-crds/files/crds/ 2>&1 | head -5 + +echo "Deploying operator..." +helm upgrade --install thv-operator deploy/charts/operator \ + --namespace toolhive-system --create-namespace \ + --set image.tag=latest \ + --set vmcpImage.tag=latest \ + --set proxyRunnerImage.tag=latest \ + --wait --timeout 120s 2>&1 | tail -3 + +# 4. Deploy demo manifests +echo "" +echo "Deploying demo MCP servers..." +kubectl apply -f "$SCRIPT_DIR/manifests.yaml" + +# 5. Wait for VirtualMCPServer +echo "" +echo "Waiting for VirtualMCPServer to be ready..." +kubectl wait --for=condition=Ready virtualmcpserver/demo-vmcp \ + -n script-demo --timeout=180s 2>&1 || true + +# 6. Patch the NodePort to use 30080 (mapped to host 4483) +echo "" +echo "Configuring NodePort..." +VMCP_SVC=$(kubectl get svc -n script-demo -l app.kubernetes.io/instance=demo-vmcp -o name | head -1) +if [ -n "$VMCP_SVC" ]; then + kubectl patch "$VMCP_SVC" -n script-demo --type='json' \ + -p='[{"op":"replace","path":"/spec/ports/0/nodePort","value":30080}]' 2>/dev/null || true +fi + +echo "" +echo "=== Demo Ready ===" +echo "" +echo "VirtualMCPServer: http://localhost:4483/mcp" +echo "" +echo "Tools available: slack (4), jira (4), confluence (2), github (4)," +echo " pagerduty (3), datadog (3), google-drive (2), linear (2)" +echo " + execute_tool_script" +echo "" +echo "Connect with an MCP client or add to Claude Code settings:" +echo ' { "mcpServers": { "demo": { "url": "http://localhost:4483/mcp" } } }' +echo "" +echo "Teardown: kind delete cluster --name $CLUSTER_NAME" diff --git a/demo/script-middleware/manifests.yaml b/demo/script-middleware/manifests.yaml new file mode 100644 index 0000000000..ea29826627 --- /dev/null +++ b/demo/script-middleware/manifests.yaml @@ -0,0 +1,235 @@ +# Script middleware demo — 8 dummy MCP servers + VirtualMCPServer +# Adapted from infra/flux/apps/anthropic-dogfood/mcp-resources.yaml +--- +apiVersion: v1 +kind: Namespace +metadata: + name: script-demo +--- +apiVersion: toolhive.stacklok.dev/v1alpha1 +kind: MCPGroup +metadata: + name: demo-group + namespace: script-demo +spec: + description: Demo MCP servers for script middleware prototype +--- +apiVersion: toolhive.stacklok.dev/v1alpha1 +kind: MCPServer +metadata: + name: slack + namespace: script-demo +spec: + image: ghcr.io/stackloklabs/dummy-mcp:0.1.3 + transport: streamable-http + proxyPort: 8080 + mcpPort: 8100 + groupRef: demo-group + podTemplateSpec: + spec: + containers: + - name: mcp + env: + - name: TOOLS_CONFIG + value: '{"tools":[{"name":"slack_list_channels","description":"List Slack channels in the workspace","return_value":"#general — Company-wide announcements (1,247 members)\n#engineering — Engineering discussion (342 members)\n#incidents — Active incident coordination (198 members)\n#random — Water cooler (1,102 members)"},{"name":"slack_read_messages","description":"Read recent messages from a channel","input_schema":{"channel":{"type":"string","description":"Channel name or ID"}},"return_value":"[10:32 AM] alice: Deployed v2.4.1 to staging, running smoke tests now\n[10:45 AM] bob: Smoke tests passed, promoting to prod\n[11:02 AM] carol: Prod deploy complete, monitoring dashboards look clean\n[11:15 AM] alice: Seeing elevated p99 on web-app after deploy, investigating\n[11:18 AM] bob: Confirmed — /checkout endpoint latency spiked, rolling back"},{"name":"slack_find_people","description":"Search for people in the workspace","input_schema":{"name":{"type":"string","description":"Person name or username"}},"return_value":"Alice Chen — Senior Engineer, Platform (alice@company.com)\nAlice Wong — Product Manager (awong@company.com)"},{"name":"slack_list_teams","description":"List Slack user groups and teams","return_value":"@engineering — Engineering team (45 members)\n@platform — Platform team (12 members)\n@oncall-primary — Current primary oncall (1 member)\n@product — Product team (18 members)"}]}' + resources: + requests: + cpu: 50m + memory: 64Mi + limits: + cpu: 100m + memory: 128Mi +--- +apiVersion: toolhive.stacklok.dev/v1alpha1 +kind: MCPServer +metadata: + name: jira + namespace: script-demo +spec: + image: ghcr.io/stackloklabs/dummy-mcp:0.1.3 + transport: streamable-http + proxyPort: 8080 + mcpPort: 8100 + groupRef: demo-group + podTemplateSpec: + spec: + containers: + - name: mcp + env: + - name: TOOLS_CONFIG + value: '{"tools":[{"name":"jira_list_projects","description":"List Jira projects and boards","return_value":"ID: ENG | Name: Engineering\nID: PROD | Name: Product\nID: INFRA | Name: Infrastructure\nID: SEC | Name: Security"},{"name":"jira_search_issues","description":"Search issues by project, status, or assignee","input_schema":{"query":{"type":"string","description":"Search query or JQL expression"},"project":{"type":"string","description":"Project key to filter by"}},"return_value":"ENG-142: Upgrade auth middleware [In Progress] @alice — priority: High\nENG-157: Fix rate limiter edge case [Open] @bob — priority: Critical\nENG-163: Add retry logic to webhook handler [In Review] @carol — priority: Medium\nINFRA-89: web-app p99 latency regression [Open] @alice — priority: Critical"},{"name":"jira_create_issue","description":"Create a new Jira issue with summary and description","input_schema":{"project":{"type":"string","description":"Project key (e.g. ENG)"},"summary":{"type":"string","description":"Issue summary"},"description":{"type":"string","description":"Detailed description"}},"return_value":"Created ENG-201: Issue created successfully"},{"name":"jira_find_assignee","description":"Find a team member to assign an issue to","input_schema":{"name":{"type":"string","description":"Person name or username"}},"return_value":"Alice Chen (alice@company.com) — Engineering, 3 active issues\nBob Martinez (bob@company.com) — Engineering, 5 active issues\nCarol Park (carol@company.com) — Platform, 2 active issues"}]}' + resources: + requests: + cpu: 50m + memory: 64Mi + limits: + cpu: 100m + memory: 128Mi +--- +apiVersion: toolhive.stacklok.dev/v1alpha1 +kind: MCPServer +metadata: + name: confluence + namespace: script-demo +spec: + image: ghcr.io/stackloklabs/dummy-mcp:0.1.3 + transport: streamable-http + proxyPort: 8080 + mcpPort: 8100 + groupRef: demo-group + podTemplateSpec: + spec: + containers: + - name: mcp + env: + - name: TOOLS_CONFIG + value: '{"tools":[{"name":"confluence_list_spaces","description":"List Confluence spaces and knowledge bases","return_value":"ENG — Engineering Wiki\nPROD — Product Specs\nOPS — Operations Runbooks\nHR — People & Culture"},{"name":"confluence_search_pages","description":"Search pages and documentation by keyword","input_schema":{"query":{"type":"string","description":"Search keyword or phrase"}},"return_value":"[OPS] Incident Response Playbook — Updated 3 days ago\n[OPS] web-app Runbook: Latency Escalation — Updated 1 week ago\n[ENG] Deployment Checklist — Updated 1 week ago\n[ENG] API Authentication Guide — Updated 2 days ago"}]}' + resources: + requests: + cpu: 50m + memory: 64Mi + limits: + cpu: 100m + memory: 128Mi +--- +apiVersion: toolhive.stacklok.dev/v1alpha1 +kind: MCPServer +metadata: + name: github + namespace: script-demo +spec: + image: ghcr.io/stackloklabs/dummy-mcp:0.1.3 + transport: streamable-http + proxyPort: 8080 + mcpPort: 8100 + groupRef: demo-group + podTemplateSpec: + spec: + containers: + - name: mcp + env: + - name: TOOLS_CONFIG + value: '{"tools":[{"name":"github_list_repos","description":"List repositories in the organization","return_value":"acme/api-gateway — API gateway service (Go) ★ 142\nacme/web-app — Customer-facing web application (TypeScript) ★ 89\nacme/auth-service — Authentication and authorization (Rust) ★ 203\nacme/infra — Infrastructure as code (Terraform) ★ 56"},{"name":"github_search_prs","description":"Search pull requests by status, author, or repo","input_schema":{"query":{"type":"string","description":"Search query"},"repo":{"type":"string","description":"Repository name"}},"return_value":"#847 [open] Add rate limiting to /api/v2 endpoints — @alice (2 approvals, CI passing)\n#842 [open] Fix memory leak in connection pool — @bob (needs review)\n#839 [merged 2h ago] Upgrade TLS certificates — @carol\n#835 [merged 6h ago] Bump checkout-service dependency to v2.4.1 — @alice"},{"name":"github_create_issue","description":"Create a GitHub issue with labels and assignees","input_schema":{"repo":{"type":"string","description":"Repository (org/repo)"},"title":{"type":"string","description":"Issue title"},"body":{"type":"string","description":"Issue body in markdown"},"labels":{"type":"string","description":"Comma-separated labels"}},"return_value":"Created issue #312 in acme/api-gateway"},{"name":"github_search_code","description":"Search code across repositories","input_schema":{"query":{"type":"string","description":"Code search query"},"repo":{"type":"string","description":"Optional repository filter"}},"return_value":"acme/web-app/src/checkout/handler.ts:42 — async function processCheckout(cart: Cart)\nacme/web-app/src/middleware/timeout.ts:15 — const CHECKOUT_TIMEOUT_MS = 5000\nacme/api-gateway/src/routes/checkout.go:28 — func ProxyCheckout(w http.ResponseWriter, r *http.Request)"}]}' + resources: + requests: + cpu: 50m + memory: 64Mi + limits: + cpu: 100m + memory: 128Mi +--- +apiVersion: toolhive.stacklok.dev/v1alpha1 +kind: MCPServer +metadata: + name: pagerduty + namespace: script-demo +spec: + image: ghcr.io/stackloklabs/dummy-mcp:0.1.3 + transport: streamable-http + proxyPort: 8080 + mcpPort: 8100 + groupRef: demo-group + podTemplateSpec: + spec: + containers: + - name: mcp + env: + - name: TOOLS_CONFIG + value: '{"tools":[{"name":"pagerduty_list_services","description":"List monitored services and their status","return_value":"api-gateway — Operational ✓\nauth-service — Operational ✓\nweb-app — Degraded Performance ⚠\ndatabase-primary — Operational ✓\ncheckout-service — Degraded Performance ⚠\ncache-redis — Operational ✓"},{"name":"pagerduty_list_incidents","description":"List active incidents by service and severity","input_schema":{"service":{"type":"string","description":"Service name to filter"},"status":{"type":"string","description":"Incident status (triggered, acknowledged, resolved)"}},"return_value":"INC-4521 [triggered] web-app: Elevated p99 latency (>2s) — SEV2 — 23 min ago\nINC-4522 [triggered] checkout-service: Timeout errors on payment flow — SEV1 — 18 min ago\nINC-4519 [acknowledged] web-app: Increased error rate on /checkout — SEV2 — 45 min ago\nINC-4515 [resolved] api-gateway: Certificate expiry warning — SEV3 — 2h ago"},{"name":"pagerduty_ack_incident","description":"Acknowledge or resolve an incident","input_schema":{"incident_id":{"type":"string","description":"Incident ID (e.g. INC-4521)"},"action":{"type":"string","description":"Action: acknowledge or resolve"}},"return_value":"Acknowledged INC-4521 — assigned to current oncall (alice@company.com)"}]}' + resources: + requests: + cpu: 50m + memory: 64Mi + limits: + cpu: 100m + memory: 128Mi +--- +apiVersion: toolhive.stacklok.dev/v1alpha1 +kind: MCPServer +metadata: + name: datadog + namespace: script-demo +spec: + image: ghcr.io/stackloklabs/dummy-mcp:0.1.3 + transport: streamable-http + proxyPort: 8080 + mcpPort: 8100 + groupRef: demo-group + podTemplateSpec: + spec: + containers: + - name: mcp + env: + - name: TOOLS_CONFIG + value: '{"tools":[{"name":"datadog_list_dashboards","description":"List monitoring dashboards","return_value":"API Gateway Overview — Last viewed 2h ago\nService Health Matrix — Last viewed 30m ago\nDatabase Performance — Last viewed 1h ago\nDeployment Tracker — Last viewed 4h ago"},{"name":"datadog_query_metrics","description":"Query metrics, traces, and service health","input_schema":{"query":{"type":"string","description":"Metric query (e.g. avg:system.cpu.user)"},"timeframe":{"type":"string","description":"Time range (e.g. last_1h, last_24h)"}},"return_value":"avg:http.request.duration{service:web-app} [last_1h]\n p50: 145ms | p95: 1230ms | p99: 2890ms ⚠ (SLO breach)\n Request rate: 8,421 req/min (↓ 15% from baseline)\n Error rate: 4.7% ⚠ (baseline: 0.3%)\n Top error: 504 Gateway Timeout on /checkout — 312 occurrences"},{"name":"datadog_search_logs","description":"Search application logs by service or keyword","input_schema":{"query":{"type":"string","description":"Log search query"},"service":{"type":"string","description":"Service name filter"}},"return_value":"[11:18:42] ERROR web-app: Timeout connecting to checkout-service timeout=5s path=/checkout\n[11:18:38] ERROR checkout-service: upstream dependency timeout service=payment-gateway\n[11:17:55] WARN web-app: Retry exhausted for /checkout attempts=3 last_error=504\n[11:15:02] INFO web-app: Deployment v2.4.1 rollout complete replicas=5/5\n[11:02:33] INFO web-app: Starting deployment v2.4.1 initiated_by=alice"}]}' + resources: + requests: + cpu: 50m + memory: 64Mi + limits: + cpu: 100m + memory: 128Mi +--- +apiVersion: toolhive.stacklok.dev/v1alpha1 +kind: MCPServer +metadata: + name: google-drive + namespace: script-demo +spec: + image: ghcr.io/stackloklabs/dummy-mcp:0.1.3 + transport: streamable-http + proxyPort: 8080 + mcpPort: 8100 + groupRef: demo-group + podTemplateSpec: + spec: + containers: + - name: mcp + env: + - name: TOOLS_CONFIG + value: '{"tools":[{"name":"gdrive_list_files","description":"List files and folders in shared drives","return_value":"Engineering Shared Drive\n Q1 Architecture Review.docx — Modified 3 days ago\n Service Dependencies.xlsx — Modified 1 week ago\n Incident Postmortem Template.docx — Modified 2 weeks ago"},{"name":"gdrive_search","description":"Search for documents, spreadsheets, and presentations","input_schema":{"query":{"type":"string","description":"Search query"}},"return_value":"Q4 Planning Doc.docx — Product Shared Drive — Modified Dec 15\nAPI Migration Tracker.xlsx — Engineering Shared Drive — Modified Jan 10\nCheckout Service Postmortem (Dec).docx — Engineering Shared Drive — Modified Dec 22"}]}' + resources: + requests: + cpu: 50m + memory: 64Mi + limits: + cpu: 100m + memory: 128Mi +--- +apiVersion: toolhive.stacklok.dev/v1alpha1 +kind: MCPServer +metadata: + name: linear + namespace: script-demo +spec: + image: ghcr.io/stackloklabs/dummy-mcp:0.1.3 + transport: streamable-http + proxyPort: 8080 + mcpPort: 8100 + groupRef: demo-group + podTemplateSpec: + spec: + containers: + - name: mcp + env: + - name: TOOLS_CONFIG + value: '{"tools":[{"name":"linear_list_projects","description":"List Linear projects and cycles","return_value":"Q1 Platform Hardening — 67% complete (ends Mar 31)\nAuth Service Rewrite — 45% complete (ends Feb 28)\nAPI v2 Migration — Complete\nQ2 Planning — Not started"},{"name":"linear_search_issues","description":"Search issues by project, assignee, or priority","input_schema":{"query":{"type":"string","description":"Search query"},"project":{"type":"string","description":"Project name filter"}},"return_value":"PLAT-234 [In Progress] Implement connection pooling — P1 @alice\nPLAT-241 [Todo] Add circuit breaker to external calls — P2 @bob\nPLAT-245 [In Review] Update health check endpoints — P1 @carol"}]}' + resources: + requests: + cpu: 50m + memory: 64Mi + limits: + cpu: 100m + memory: 128Mi +--- +apiVersion: toolhive.stacklok.dev/v1alpha1 +kind: VirtualMCPServer +metadata: + name: demo-vmcp + namespace: script-demo +spec: + config: + groupRef: demo-group + incomingAuth: + type: anonymous + serviceType: NodePort diff --git a/go.mod b/go.mod index 7d25faaf41..21e21a09bf 100644 --- a/go.mod +++ b/go.mod @@ -79,6 +79,8 @@ require ( require github.com/getsentry/sentry-go/otel v0.44.1 +require go.starlark.net v0.0.0-20260326113308-fadfc96def35 + require ( cel.dev/expr v0.25.1 // indirect github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c // indirect @@ -298,7 +300,7 @@ require ( gopkg.in/evanphx/json-patch.v4 v4.13.0 // indirect gopkg.in/inf.v0 v0.9.1 // indirect gopkg.in/warnings.v0 v0.1.2 // indirect - k8s.io/apiextensions-apiserver v0.35.0 // indirect + k8s.io/apiextensions-apiserver v0.35.0 k8s.io/klog/v2 v2.130.1 // indirect k8s.io/kube-openapi v0.0.0-20250910181357-589584f1c912 // indirect modernc.org/libc v1.70.0 // indirect diff --git a/go.sum b/go.sum index 1684d252d0..b0635e8aeb 100644 --- a/go.sum +++ b/go.sum @@ -938,6 +938,8 @@ go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09 go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0= go.opentelemetry.io/proto/otlp v1.9.0 h1:l706jCMITVouPOqEnii2fIAuO3IVGBRPV5ICjceRb/A= go.opentelemetry.io/proto/otlp v1.9.0/go.mod h1:xE+Cx5E/eEHw+ISFkwPLwCZefwVjY+pqKg1qcK03+/4= +go.starlark.net v0.0.0-20260326113308-fadfc96def35 h1:VYAqieSOJNxBDX8KJneTAwvdf4J4zRDE2u+UFXtt9h4= +go.starlark.net v0.0.0-20260326113308-fadfc96def35/go.mod h1:Iue6g6iirlfLoVi/DYCi5/x0h/bAOuWF3dULTKpt2Vo= go.step.sm/crypto v0.74.0 h1:/APBEv45yYR4qQFg47HA8w1nesIGcxh44pGyQNw6JRA= go.step.sm/crypto v0.74.0/go.mod h1:UoXqCAJjjRgzPte0Llaqen7O9P7XjPmgjgTHQGkKCDk= go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= diff --git a/pkg/script/acceptance_test.go b/pkg/script/acceptance_test.go new file mode 100644 index 0000000000..4778e21347 --- /dev/null +++ b/pkg/script/acceptance_test.go @@ -0,0 +1,274 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package script + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestAcceptance_ScriptOrchestratesMultipleTools is the motivating example from the RFC: +// fetch PRs by an author, then filter to only those where a specific reviewer commented. +func TestAcceptance_ScriptOrchestratesMultipleTools(t *testing.T) { + t.Parallel() + + // Mock data: two PRs by jerm-dro, only PR 1 has a comment from yrobla + prsData := []map[string]interface{}{ + {"id": 1, "title": "Add script middleware", "author": "jerm-dro"}, + {"id": 2, "title": "Fix linting", "author": "jerm-dro"}, + } + commentsData := map[float64][]map[string]interface{}{ + 1: { + {"author": "yrobla", "body": "lgtm"}, + {"author": "someone-else", "body": "nice"}, + }, + 2: { + {"author": "someone-else", "body": "needs work"}, + }, + } + + backend := mockMCPBackend(map[string]func(map[string]interface{}) string{ + "fetch_prs": func(args map[string]interface{}) string { + // Filter by author if provided + author, _ := args["author"].(string) + var filtered []map[string]interface{} + for _, pr := range prsData { + if author == "" || pr["author"] == author { + filtered = append(filtered, pr) + } + } + b, _ := json.Marshal(filtered) + return string(b) + }, + "fetch_comments": func(args map[string]interface{}) string { + prID, _ := args["pr_id"].(float64) + comments := commentsData[prID] + if comments == nil { + comments = []map[string]interface{}{} + } + b, _ := json.Marshal(comments) + return string(b) + }, + }) + + middleware := NewMiddleware()(backend) + + // The motivating script: find PRs where a specific reviewer commented + script := ` +prs = fetch_prs(author=author_name) +output = [] +for pr in prs: + comments = fetch_comments(pr_id=pr["id"]) + for c in comments: + if c["author"] == reviewer: + output.append(pr) + break +return output +` + + t.Run("script filters PRs by reviewer comments", func(t *testing.T) { + t.Parallel() + + rec := sendJSONRPC(t, middleware, "tools/call", map[string]interface{}{ + "name": ExecuteToolScriptName, + "arguments": map[string]interface{}{ + "script": script, + "data": map[string]interface{}{ + "author_name": "jerm-dro", + "reviewer": "yrobla", + }, + }, + }) + + require.Equal(t, http.StatusOK, rec.Code) + + var resp jsonRPCResponse + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.NotNil(t, resp.Result, "expected a result") + require.Nil(t, resp.Error, "expected no error") + + var resultMap map[string]interface{} + require.NoError(t, json.Unmarshal(*resp.Result, &resultMap)) + + content := resultMap["content"].([]interface{}) + require.NotEmpty(t, content) + + textItem := content[0].(map[string]interface{}) + require.Equal(t, "text", textItem["type"]) + + var prs []map[string]interface{} + require.NoError(t, json.Unmarshal([]byte(textItem["text"].(string)), &prs)) + + // Only PR 1 should be in the result (the one where yrobla commented) + require.Len(t, prs, 1) + assert.Equal(t, float64(1), prs[0]["id"]) + assert.Equal(t, "Add script middleware", prs[0]["title"]) + }) + + t.Run("tools/list includes execute_tool_script with dynamic description", func(t *testing.T) { + t.Parallel() + + rec := sendJSONRPC(t, middleware, "tools/list", map[string]interface{}{}) + + require.Equal(t, http.StatusOK, rec.Code) + + var resp jsonRPCResponse + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + + var resultMap map[string]interface{} + require.NoError(t, json.Unmarshal(*resp.Result, &resultMap)) + + toolsRaw := resultMap["tools"].([]interface{}) + + names := make([]string, 0, len(toolsRaw)) + var scriptToolDesc string + for _, item := range toolsRaw { + tm := item.(map[string]interface{}) + name := tm["name"].(string) + names = append(names, name) + if name == ExecuteToolScriptName { + scriptToolDesc = tm["description"].(string) + } + } + + assert.Contains(t, names, "fetch_prs") + assert.Contains(t, names, "fetch_comments") + assert.Contains(t, names, ExecuteToolScriptName) + + // Dynamic description should mention the available tools + assert.Contains(t, scriptToolDesc, "fetch_prs") + assert.Contains(t, scriptToolDesc, "fetch_comments") + }) + + t.Run("empty script returns null", func(t *testing.T) { + t.Parallel() + + rec := sendJSONRPC(t, middleware, "tools/call", map[string]interface{}{ + "name": ExecuteToolScriptName, + "arguments": map[string]interface{}{ + "script": "x = 1", + }, + }) + + require.Equal(t, http.StatusOK, rec.Code) + + var resp jsonRPCResponse + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.NotNil(t, resp.Result) + + var resultMap map[string]interface{} + require.NoError(t, json.Unmarshal(*resp.Result, &resultMap)) + + content := resultMap["content"].([]interface{}) + textItem := content[0].(map[string]interface{}) + assert.Equal(t, "null", textItem["text"]) + }) + + t.Run("step limit exceeded returns error", func(t *testing.T) { + t.Parallel() + + rec := sendJSONRPC(t, middleware, "tools/call", map[string]interface{}{ + "name": ExecuteToolScriptName, + "arguments": map[string]interface{}{ + "script": "while True:\n pass", + }, + }) + + require.Equal(t, http.StatusOK, rec.Code) + + var resp jsonRPCResponse + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.NotNil(t, resp.Error, "should return error for infinite loop") + }) + + t.Run("data arguments accessible as script globals", func(t *testing.T) { + t.Parallel() + + rec := sendJSONRPC(t, middleware, "tools/call", map[string]interface{}{ + "name": ExecuteToolScriptName, + "arguments": map[string]interface{}{ + "script": "return {\"name\": user_name, \"count\": item_count}", + "data": map[string]interface{}{ + "user_name": "test-user", + "item_count": 42, + }, + }, + }) + + require.Equal(t, http.StatusOK, rec.Code) + + var resp jsonRPCResponse + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.NotNil(t, resp.Result) + + var resultMap map[string]interface{} + require.NoError(t, json.Unmarshal(*resp.Result, &resultMap)) + + content := resultMap["content"].([]interface{}) + textItem := content[0].(map[string]interface{}) + + var result map[string]interface{} + require.NoError(t, json.Unmarshal([]byte(textItem["text"].(string)), &result)) + assert.Equal(t, "test-user", result["name"]) + assert.Equal(t, float64(42), result["count"]) + }) +} + +// TestAcceptance_RawHTTPFlow verifies the complete HTTP flow with a real httptest.Server. +func TestAcceptance_RawHTTPFlow(t *testing.T) { + t.Parallel() + + backend := mockMCPBackend(map[string]func(map[string]interface{}) string{ + "add": func(args map[string]interface{}) string { + a, _ := args["a"].(float64) + b, _ := args["b"].(float64) + result, _ := json.Marshal(a + b) + return string(result) + }, + }) + + handler := NewMiddleware()(backend) + server := httptest.NewServer(handler) + defer server.Close() + + // Send a real HTTP request + body := map[string]interface{}{ + "jsonrpc": "2.0", + "id": 1, + "method": "tools/call", + "params": map[string]interface{}{ + "name": ExecuteToolScriptName, + "arguments": map[string]interface{}{ + "script": "return add(a=x, b=y)", + "data": map[string]interface{}{"x": 10, "y": 32}, + }, + }, + } + bodyBytes, _ := json.Marshal(body) + + resp, err := http.Post(server.URL, "application/json", bytes.NewReader(bodyBytes)) + require.NoError(t, err) + defer func() { + _ = resp.Body.Close() + }() + + require.Equal(t, http.StatusOK, resp.StatusCode) + + var rpcResp jsonRPCResponse + require.NoError(t, json.NewDecoder(resp.Body).Decode(&rpcResp)) + require.NotNil(t, rpcResp.Result) + + var resultMap map[string]interface{} + require.NoError(t, json.Unmarshal(*rpcResp.Result, &resultMap)) + + content := resultMap["content"].([]interface{}) + textItem := content[0].(map[string]interface{}) + assert.Equal(t, "42", textItem["text"]) +} diff --git a/pkg/script/bridge.go b/pkg/script/bridge.go new file mode 100644 index 0000000000..6be602c9c9 --- /dev/null +++ b/pkg/script/bridge.go @@ -0,0 +1,393 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package script + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "math" + "regexp" + "strings" + "sync" + "unicode" + + "go.starlark.net/starlark" +) + +// CallToolResult holds the result of an MCP tool call. +type CallToolResult struct { + Content []ContentItem `json:"content"` + StructuredContent map[string]interface{} `json:"structuredContent,omitempty"` + IsError bool `json:"isError,omitempty"` +} + +// ContentItem represents a single content item in an MCP tool result. +type ContentItem struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` +} + +// ToolCaller is the interface for making MCP tool calls. +type ToolCaller interface { + CallTool(ctx context.Context, toolName string, arguments map[string]interface{}) (*CallToolResult, error) +} + +// ToolInfo describes an MCP tool available to scripts. +type ToolInfo struct { + Name string + Description string +} + +// BuildGlobals creates Starlark globals from MCP tools, a caller, data arguments, and a context. +// Each tool becomes a callable Starlark function. A generic call_tool() builtin is also provided. +func BuildGlobals(ctx context.Context, tools []ToolInfo, caller ToolCaller, data map[string]interface{}) starlark.StringDict { + globals := make(starlark.StringDict) + + // Track sanitized names for collision detection + seen := make(map[string]string) // sanitized → original + + for _, tool := range tools { + sanitized := sanitizeToolName(tool.Name) + if existing, ok := seen[sanitized]; ok { + slog.Warn("tool name collision after sanitization", + "tool1", existing, "tool2", tool.Name, "sanitized", sanitized) + continue + } + seen[sanitized] = tool.Name + globals[sanitized] = makeToolBuiltin(ctx, tool.Name, sanitized, caller) + } + + // Generic call_tool builtin for tools with awkward names + globals["call_tool"] = starlark.NewBuiltin("call_tool", func( + _ *starlark.Thread, _ *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple, + ) (starlark.Value, error) { + if len(args) < 1 { + return nil, fmt.Errorf("call_tool: requires at least 1 positional argument (tool name)") + } + nameVal, ok := args[0].(starlark.String) + if !ok { + return nil, fmt.Errorf("call_tool: first argument must be a string, got %s", args[0].Type()) + } + toolName := string(nameVal) + arguments := kwargsToGoMap(kwargs) + return callToolAndConvert(ctx, caller, toolName, arguments) + }) + + // parallel() builtin — execute a list of callables concurrently + globals["parallel"] = starlark.NewBuiltin("parallel", parallelBuiltin) + + // Inject data arguments as top-level globals + for k, v := range data { + sv, err := goToStarlark(v) + if err != nil { + slog.Warn("failed to convert data argument to Starlark", "key", k, "error", err) + continue + } + globals[k] = sv + } + + return globals +} + +func makeToolBuiltin(ctx context.Context, realName, displayName string, caller ToolCaller) *starlark.Builtin { + return starlark.NewBuiltin(displayName, func( + _ *starlark.Thread, _ *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple, + ) (starlark.Value, error) { + if len(args) > 0 { + return nil, fmt.Errorf("%s: use keyword arguments only (e.g., %s(key=value))", displayName, displayName) + } + arguments := kwargsToGoMap(kwargs) + return callToolAndConvert(ctx, caller, realName, arguments) + }) +} + +func callToolAndConvert( + ctx context.Context, caller ToolCaller, toolName string, arguments map[string]interface{}, +) (starlark.Value, error) { + result, err := caller.CallTool(ctx, toolName, arguments) + if err != nil { + return nil, fmt.Errorf("tool %q call failed: %w", toolName, err) + } + + goVal, err := parseToolResult(result) + if err != nil { + return nil, fmt.Errorf("tool %q returned error: %w", toolName, err) + } + + sv, err := goToStarlark(goVal) + if err != nil { + return nil, fmt.Errorf("tool %q result conversion failed: %w", toolName, err) + } + return sv, nil +} + +func kwargsToGoMap(kwargs []starlark.Tuple) map[string]interface{} { + m := make(map[string]interface{}, len(kwargs)) + for _, kv := range kwargs { + key := string(kv[0].(starlark.String)) + m[key] = starlarkToGo(kv[1]) + } + return m +} + +// parseToolResult converts a CallToolResult into a Go value. +func parseToolResult(result *CallToolResult) (interface{}, error) { + if result.IsError { + msg := "tool execution error" + if len(result.Content) > 0 && result.Content[0].Text != "" { + msg = result.Content[0].Text + } + return nil, fmt.Errorf("%s", msg) + } + + // Prefer structured content, but unwrap the common SDK wrapper + // pattern where the result is {"result": }. + if result.StructuredContent != nil { + if len(result.StructuredContent) == 1 { + if v, ok := result.StructuredContent["result"]; ok { + return v, nil + } + } + return result.StructuredContent, nil + } + + // Fall back to parsing first text content as JSON + if len(result.Content) == 0 { + return nil, nil + } + + if len(result.Content) > 1 { + slog.Debug("tool returned multiple content items, using first text item only", + "count", len(result.Content)) + } + + text := result.Content[0].Text + + var parsed interface{} + if err := json.Unmarshal([]byte(text), &parsed); err != nil { + // Not valid JSON — return as plain string + return text, nil + } + return parsed, nil +} + +// parallelBuiltin executes a list of zero-arg callables concurrently and +// returns a list of results where result[i] corresponds to callable[i]. +func parallelBuiltin( + thread *starlark.Thread, _ *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple, +) (starlark.Value, error) { + var fns *starlark.List + if err := starlark.UnpackPositionalArgs("parallel", args, kwargs, 1, &fns); err != nil { + return nil, err + } + + n := fns.Len() + if n == 0 { + return starlark.NewList(nil), nil + } + + results := make([]starlark.Value, n) + errs := make([]error, n) + + var wg sync.WaitGroup + wg.Add(n) + + for i := range n { + go func(idx int) { + defer wg.Done() + + callable, ok := fns.Index(idx).(starlark.Callable) + if !ok { + errs[idx] = fmt.Errorf("parallel: element %d is not callable (got %s)", + idx, fns.Index(idx).Type()) + return + } + + childThread := &starlark.Thread{ + Name: fmt.Sprintf("%s/parallel-%d", thread.Name, idx), + Print: thread.Print, + } + + result, err := starlark.Call(childThread, callable, nil, nil) + if err != nil { + errs[idx] = err + return + } + results[idx] = result + }(i) + } + + wg.Wait() + + for i, err := range errs { + if err != nil { + return nil, fmt.Errorf("parallel: task %d failed: %w", i, err) + } + } + + return starlark.NewList(results), nil +} + +var nonIdentChar = regexp.MustCompile(`[^a-zA-Z0-9_]`) + +// sanitizeToolName converts an MCP tool name into a valid Starlark identifier. +func sanitizeToolName(name string) string { + s := nonIdentChar.ReplaceAllString(name, "_") + if len(s) > 0 && unicode.IsDigit(rune(s[0])) { + s = "_" + s + } + if s == "" { + s = "_" + } + return s +} + +// GenerateToolDescription creates a dynamic description for the execute_tool_script tool. +func GenerateToolDescription(tools []ToolInfo) string { + var b strings.Builder + b.WriteString("Execute a Starlark script that orchestrates multiple tool calls ") + b.WriteString("and returns an aggregated result. Use 'return' to produce output.\n\n") + b.WriteString("Available tools (callable as functions with keyword arguments):\n") + for _, t := range tools { + sanitized := sanitizeToolName(t.Name) + desc := t.Description + if len(desc) > 80 { + desc = desc[:77] + "..." + } + fmt.Fprintf(&b, " - %s: %s\n", sanitized, desc) + } + b.WriteString("\nTool names with special characters are available with underscores ") + b.WriteString("(e.g., my-tool becomes my_tool). Use call_tool(\"name\", ...) for any tool by its original name.\n\n") + b.WriteString("Built-in: parallel([fn1, fn2, ...]) executes zero-arg callables concurrently ") + b.WriteString("and returns results in order. Use with lambda to fan out tool calls.\n\n") + b.WriteString("Named data arguments passed in the 'data' parameter are available as top-level variables in the script.") + return b.String() +} + +// starlarkToGo converts a Starlark value to a Go value. +func starlarkToGo(v starlark.Value) interface{} { + switch v := v.(type) { + case starlark.NoneType: + return nil + case starlark.Bool: + return bool(v) + case starlark.Int: + if i, ok := v.Int64(); ok { + return i + } + return v.String() + case starlark.Float: + return float64(v) + case starlark.String: + return string(v) + case *starlark.List: + result := make([]interface{}, v.Len()) + for i := 0; i < v.Len(); i++ { + result[i] = starlarkToGo(v.Index(i)) + } + return result + case *starlark.Dict: + result := make(map[string]interface{}) + for _, item := range v.Items() { + key := starlarkToGo(item[0]) + keyStr, ok := key.(string) + if !ok { + keyStr = fmt.Sprintf("%v", key) + } + result[keyStr] = starlarkToGo(item[1]) + } + return result + case starlark.Tuple: + result := make([]interface{}, len(v)) + for i, elem := range v { + result[i] = starlarkToGo(elem) + } + return result + default: + return v.String() + } +} + +// goToStarlark converts a Go value to a Starlark value. +// +//nolint:gocyclo // type switch over Go types is inherently branchy +func goToStarlark(v interface{}) (starlark.Value, error) { + switch v := v.(type) { + case nil: + return starlark.None, nil + case bool: + return starlark.Bool(v), nil + case int: + return starlark.MakeInt(v), nil + case int64: + return starlark.MakeInt64(v), nil + case float64: + return goFloat64ToStarlark(v), nil + case string: + return starlark.String(v), nil + case []interface{}: + return goSliceToStarlark(v) + case map[string]interface{}: + return goMapToStarlark(v) + case json.Number: + return goJSONNumberToStarlark(v) + default: + return nil, fmt.Errorf("unsupported Go type %T for Starlark conversion", v) + } +} + +func goFloat64ToStarlark(v float64) starlark.Value { + if v == math.Trunc(v) && !math.IsInf(v, 0) && !math.IsNaN(v) && math.Abs(v) < (1<<53) { + return starlark.MakeInt64(int64(v)) + } + return starlark.Float(v) +} + +func goSliceToStarlark(v []interface{}) (starlark.Value, error) { + elems := make([]starlark.Value, len(v)) + for i, e := range v { + sv, err := goToStarlark(e) + if err != nil { + return nil, err + } + elems[i] = sv + } + return starlark.NewList(elems), nil +} + +func goMapToStarlark(v map[string]interface{}) (starlark.Value, error) { + d := starlark.NewDict(len(v)) + for k, val := range v { + sv, err := goToStarlark(val) + if err != nil { + return nil, err + } + if err := d.SetKey(starlark.String(k), sv); err != nil { + return nil, err + } + } + return d, nil +} + +func goJSONNumberToStarlark(v json.Number) (starlark.Value, error) { + if i, err := v.Int64(); err == nil { + return starlark.MakeInt64(i), nil + } + if f, err := v.Float64(); err == nil { + return starlark.Float(f), nil + } + return starlark.String(v.String()), nil +} + +// ResultToJSON converts a Starlark value to a JSON string. +func ResultToJSON(v starlark.Value) (string, error) { + goVal := starlarkToGo(v) + b, err := json.Marshal(goVal) + if err != nil { + return "", fmt.Errorf("failed to marshal result to JSON: %w", err) + } + return string(b), nil +} diff --git a/pkg/script/bridge_test.go b/pkg/script/bridge_test.go new file mode 100644 index 0000000000..4968ad0fa3 --- /dev/null +++ b/pkg/script/bridge_test.go @@ -0,0 +1,425 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package script + +import ( + "context" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.starlark.net/starlark" +) + +type mockToolCaller struct { + calls []mockToolCall + results map[string]*CallToolResult + err error +} + +type mockToolCall struct { + Name string + Arguments map[string]interface{} +} + +func (m *mockToolCaller) CallTool(_ context.Context, toolName string, arguments map[string]interface{}) (*CallToolResult, error) { + m.calls = append(m.calls, mockToolCall{Name: toolName, Arguments: arguments}) + if m.err != nil { + return nil, m.err + } + if r, ok := m.results[toolName]; ok { + return r, nil + } + return nil, fmt.Errorf("tool %q not found", toolName) +} + +func TestSanitizeToolName(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + want string + }{ + {"simple", "fetch_prs", "fetch_prs"}, + {"hyphens", "github-fetch-prs", "github_fetch_prs"}, + {"dots", "my.tool.name", "my_tool_name"}, + {"leading digit", "3d-render", "_3d_render"}, + {"special chars", "tool@v2!", "tool_v2_"}, + {"empty string", "", "_"}, + {"already clean", "measure_length", "measure_length"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + assert.Equal(t, tt.want, sanitizeToolName(tt.input)) + }) + } +} + +func TestParseToolResult(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + result *CallToolResult + want interface{} + wantErr string + }{ + { + name: "JSON object in text content", + result: &CallToolResult{ + Content: []ContentItem{{Type: "text", Text: `{"id": 1, "name": "test"}`}}, + }, + want: map[string]interface{}{"id": float64(1), "name": "test"}, + }, + { + name: "JSON array in text content", + result: &CallToolResult{ + Content: []ContentItem{{Type: "text", Text: `[1, 2, 3]`}}, + }, + want: []interface{}{float64(1), float64(2), float64(3)}, + }, + { + name: "JSON number in text content", + result: &CallToolResult{ + Content: []ContentItem{{Type: "text", Text: `42`}}, + }, + want: float64(42), + }, + { + name: "plain string in text content", + result: &CallToolResult{ + Content: []ContentItem{{Type: "text", Text: "hello world"}}, + }, + want: "hello world", + }, + { + name: "structured content preferred over text", + result: &CallToolResult{ + Content: []ContentItem{{Type: "text", Text: `"ignored"`}}, + StructuredContent: map[string]interface{}{"key": "from_structured"}, + }, + want: map[string]interface{}{"key": "from_structured"}, + }, + { + name: "structured content result wrapper unwrapped", + result: &CallToolResult{ + Content: []ContentItem{{Type: "text", Text: "the text"}}, + StructuredContent: map[string]interface{}{"result": "the text"}, + }, + want: "the text", + }, + { + name: "isError returns error", + result: &CallToolResult{ + Content: []ContentItem{{Type: "text", Text: "something went wrong"}}, + IsError: true, + }, + wantErr: "something went wrong", + }, + { + name: "isError with empty content", + result: &CallToolResult{ + IsError: true, + }, + wantErr: "tool execution error", + }, + { + name: "empty content returns nil", + result: &CallToolResult{}, + want: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got, err := parseToolResult(tt.result) + if tt.wantErr != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErr) + return + } + require.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestBuildGlobals_ToolCallFlow(t *testing.T) { + t.Parallel() + + caller := &mockToolCaller{ + results: map[string]*CallToolResult{ + "fetch_prs": { + Content: []ContentItem{{Type: "text", Text: `[{"id": 1, "title": "PR1"}]`}}, + }, + }, + } + + tools := []ToolInfo{{Name: "fetch_prs", Description: "Fetch pull requests"}} + globals := BuildGlobals(context.Background(), tools, caller, nil) + + // Tool should be available as a global + _, ok := globals["fetch_prs"] + assert.True(t, ok, "fetch_prs should be in globals") + + // call_tool should always be present + _, ok = globals["call_tool"] + assert.True(t, ok, "call_tool should be in globals") +} + +func TestBuildGlobals_DataArguments(t *testing.T) { + t.Parallel() + + caller := &mockToolCaller{results: map[string]*CallToolResult{}} + data := map[string]interface{}{ + "author": "jerm-dro", + "count": float64(5), + "tags": []interface{}{"go", "starlark"}, + "metadata": map[string]interface{}{"key": "value"}, + } + + globals := BuildGlobals(context.Background(), nil, caller, data) + + // Verify data arguments are injected + v, ok := globals["author"] + require.True(t, ok) + assert.Equal(t, starlark.String("jerm-dro"), v) + + v, ok = globals["count"] + require.True(t, ok) + assert.Equal(t, starlark.MakeInt(5), v) +} + +func TestBuildGlobals_HyphenatedToolName(t *testing.T) { + t.Parallel() + + caller := &mockToolCaller{ + results: map[string]*CallToolResult{ + "github-fetch-prs": { + Content: []ContentItem{{Type: "text", Text: `"ok"`}}, + }, + }, + } + + tools := []ToolInfo{{Name: "github-fetch-prs", Description: "Fetch PRs"}} + globals := BuildGlobals(context.Background(), tools, caller, nil) + + // Should be available as sanitized name + _, ok := globals["github_fetch_prs"] + assert.True(t, ok, "github_fetch_prs should be in globals") +} + +func TestCallTool_ViaScript(t *testing.T) { + t.Parallel() + + caller := &mockToolCaller{ + results: map[string]*CallToolResult{ + "fetch_data": { + Content: []ContentItem{{Type: "text", Text: `{"value": 42}`}}, + }, + }, + } + + tools := []ToolInfo{{Name: "fetch_data", Description: "Fetch data"}} + globals := BuildGlobals(context.Background(), tools, caller, nil) + + result, err := Execute(`return fetch_data(key="test")`, globals, 0) + require.NoError(t, err) + + got := starlarkToGo(result.Value) + assert.Equal(t, map[string]interface{}{"value": int64(42)}, got) + + // Verify the caller received correct arguments + require.Len(t, caller.calls, 1) + assert.Equal(t, "fetch_data", caller.calls[0].Name) + assert.Equal(t, map[string]interface{}{"key": "test"}, caller.calls[0].Arguments) +} + +func TestCallTool_GenericCallTool(t *testing.T) { + t.Parallel() + + caller := &mockToolCaller{ + results: map[string]*CallToolResult{ + "github-fetch-prs": { + Content: []ContentItem{{Type: "text", Text: `"result"`}}, + }, + }, + } + + tools := []ToolInfo{{Name: "github-fetch-prs", Description: "Fetch PRs"}} + globals := BuildGlobals(context.Background(), tools, caller, nil) + + result, err := Execute(`return call_tool("github-fetch-prs", author="jerm")`, globals, 0) + require.NoError(t, err) + + got := starlarkToGo(result.Value) + assert.Equal(t, "result", got) + + require.Len(t, caller.calls, 1) + assert.Equal(t, "github-fetch-prs", caller.calls[0].Name) +} + +func TestCallTool_ErrorPropagation(t *testing.T) { + t.Parallel() + + caller := &mockToolCaller{ + results: map[string]*CallToolResult{ + "failing_tool": { + Content: []ContentItem{{Type: "text", Text: "access denied"}}, + IsError: true, + }, + }, + } + + tools := []ToolInfo{{Name: "failing_tool", Description: "A tool that fails"}} + globals := BuildGlobals(context.Background(), tools, caller, nil) + + _, err := Execute(`return failing_tool()`, globals, 0) + require.Error(t, err) + assert.Contains(t, err.Error(), "access denied") +} + +func TestStarlarkToGoRoundTrip(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input interface{} + }{ + {"nil", nil}, + {"bool true", true}, + {"bool false", false}, + {"int", int64(42)}, + {"float", 3.14}, + {"string", "hello"}, + {"empty list", []interface{}{}}, + {"int list", []interface{}{int64(1), int64(2), int64(3)}}, + {"map", map[string]interface{}{"a": int64(1), "b": "two"}}, + {"nested", map[string]interface{}{ + "items": []interface{}{ + map[string]interface{}{"id": int64(1)}, + }, + }}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + sv, err := goToStarlark(tt.input) + require.NoError(t, err) + got := starlarkToGo(sv) + assert.Equal(t, tt.input, got) + }) + } +} + +func TestGoToStarlark_JSONFloat64AsInt(t *testing.T) { + t.Parallel() + + // JSON numbers come as float64 — integers should be converted to Starlark int + sv, err := goToStarlark(float64(42)) + require.NoError(t, err) + _, ok := sv.(starlark.Int) + assert.True(t, ok, "float64(42) should become starlark.Int, got %T", sv) +} + +func TestParallel_ViaScript(t *testing.T) { + t.Parallel() + + caller := &mockToolCaller{ + results: map[string]*CallToolResult{ + "tool_a": {Content: []ContentItem{{Type: "text", Text: `"result_a"`}}}, + "tool_b": {Content: []ContentItem{{Type: "text", Text: `"result_b"`}}}, + "tool_c": {Content: []ContentItem{{Type: "text", Text: `"result_c"`}}}, + }, + } + + tools := []ToolInfo{ + {Name: "tool_a", Description: "A"}, + {Name: "tool_b", Description: "B"}, + {Name: "tool_c", Description: "C"}, + } + globals := BuildGlobals(context.Background(), tools, caller, nil) + + script := ` +results = parallel([ + lambda: tool_a(), + lambda: tool_b(), + lambda: tool_c(), +]) +return results +` + result, err := Execute(script, globals, 0) + require.NoError(t, err) + + got := starlarkToGo(result.Value) + gotList, ok := got.([]interface{}) + require.True(t, ok, "expected list, got %T", got) + require.Len(t, gotList, 3) + assert.Equal(t, "result_a", gotList[0]) + assert.Equal(t, "result_b", gotList[1]) + assert.Equal(t, "result_c", gotList[2]) + + // All three tools should have been called + require.Len(t, caller.calls, 3) +} + +func TestParallel_ErrorPropagation(t *testing.T) { + t.Parallel() + + caller := &mockToolCaller{ + results: map[string]*CallToolResult{ + "good_tool": {Content: []ContentItem{{Type: "text", Text: `"ok"`}}}, + }, + } + + tools := []ToolInfo{ + {Name: "good_tool", Description: "works"}, + {Name: "bad_tool", Description: "missing"}, + } + globals := BuildGlobals(context.Background(), tools, caller, nil) + + script := ` +results = parallel([ + lambda: good_tool(), + lambda: bad_tool(), +]) +return results +` + _, err := Execute(script, globals, 0) + require.Error(t, err) + assert.Contains(t, err.Error(), "parallel: task 1 failed") +} + +func TestParallel_EmptyList(t *testing.T) { + t.Parallel() + + globals := BuildGlobals(context.Background(), nil, &mockToolCaller{results: map[string]*CallToolResult{}}, nil) + + result, err := Execute(`return parallel([])`, globals, 0) + require.NoError(t, err) + + got := starlarkToGo(result.Value) + assert.Equal(t, []interface{}{}, got) +} + +func TestGenerateToolDescription(t *testing.T) { + t.Parallel() + + tools := []ToolInfo{ + {Name: "measure_length", Description: "Measure the length of text"}, + {Name: "random_number", Description: "Generate a random number"}, + } + + desc := GenerateToolDescription(tools) + assert.Contains(t, desc, "measure_length") + assert.Contains(t, desc, "random_number") + assert.Contains(t, desc, "Measure the length of text") + assert.Contains(t, desc, "call_tool") +} diff --git a/pkg/script/engine.go b/pkg/script/engine.go new file mode 100644 index 0000000000..916e5baa54 --- /dev/null +++ b/pkg/script/engine.go @@ -0,0 +1,84 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +// Package script provides a Starlark-based script execution engine for orchestrating +// MCP tool calls. It allows agents to send scripts that call multiple tools and +// return aggregated results. +package script + +import ( + "fmt" + "strings" + + "go.starlark.net/starlark" + "go.starlark.net/syntax" +) + +// DefaultStepLimit is the default maximum number of Starlark execution steps. +const DefaultStepLimit uint64 = 100_000 + +// ExecuteResult holds the result of a Starlark script execution. +type ExecuteResult struct { + Value starlark.Value + Logs []string +} + +// Execute runs a Starlark script with the given globals and step limit. +// The script is wrapped in a function so that top-level `return` statements work. +// A stepLimit of 0 uses DefaultStepLimit. +func Execute(script string, globals starlark.StringDict, stepLimit uint64) (*ExecuteResult, error) { + if stepLimit == 0 { + stepLimit = DefaultStepLimit + } + + wrapped := wrapScript(script) + + var logs []string + thread := &starlark.Thread{ + Name: "script-exec", + Print: func(_ *starlark.Thread, msg string) { + logs = append(logs, msg) + }, + } + thread.SetMaxExecutionSteps(stepLimit) + + // Merge globals into the predeclared set so they're available at top level + predeclared := make(starlark.StringDict, len(globals)) + for k, v := range globals { + predeclared[k] = v + } + + resultGlobals, err := starlark.ExecFileOptions( + &syntax.FileOptions{}, + thread, + "script.star", + wrapped, + predeclared, + ) + if err != nil { + return nil, fmt.Errorf("script execution failed: %w", err) + } + + result, ok := resultGlobals["__result__"] + if !ok { + result = starlark.None + } + + return &ExecuteResult{ + Value: result, + Logs: logs, + }, nil +} + +// wrapScript wraps a user script in a function body so top-level return works. +func wrapScript(script string) string { + var b strings.Builder + b.WriteString("def __main__():\n") + for _, line := range strings.Split(script, "\n") { + b.WriteString(" ") + b.WriteString(line) + b.WriteString("\n") + } + b.WriteString("__result__ = __main__()\n") + return b.String() +} diff --git a/pkg/script/engine_test.go b/pkg/script/engine_test.go new file mode 100644 index 0000000000..2057cf4ad2 --- /dev/null +++ b/pkg/script/engine_test.go @@ -0,0 +1,135 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package script + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.starlark.net/starlark" +) + +func TestExecute(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + script string + globals starlark.StringDict + stepLimit uint64 + wantValue interface{} + wantLogs []string + wantErr string + }{ + { + name: "return integer", + script: "return 42", + wantValue: int64(42), + }, + { + name: "return string", + script: `return "hello"`, + wantValue: "hello", + }, + { + name: "return dict", + script: `return {"a": 1, "b": 2}`, + wantValue: map[string]interface{}{"a": int64(1), "b": int64(2)}, + }, + { + name: "return list", + script: `return [1, 2, 3]`, + wantValue: []interface{}{int64(1), int64(2), int64(3)}, + }, + { + name: "call provided global function", + script: "return double(21)", + globals: starlark.StringDict{ + "double": starlark.NewBuiltin("double", func(_ *starlark.Thread, _ *starlark.Builtin, args starlark.Tuple, _ []starlark.Tuple) (starlark.Value, error) { + var x int + if err := starlark.UnpackPositionalArgs("double", args, nil, 1, &x); err != nil { + return nil, err + } + return starlark.MakeInt(x * 2), nil + }), + }, + wantValue: int64(42), + }, + { + name: "syntax error", + script: "return !!!", + wantErr: "script execution failed", + }, + { + name: "runtime error division by zero", + script: "return 1 // 0", + wantErr: "script execution failed", + }, + { + name: "step limit exceeded", + script: "x = 0\nwhile True:\n x += 1", + stepLimit: 1000, + wantErr: "script execution failed", + }, + { + name: "no return yields None", + script: "x = 1", + wantValue: nil, + }, + { + name: "multi-line for loop building list", + script: `result = [] +for i in range(5): + result.append(i * 2) +return result`, + wantValue: []interface{}{int64(0), int64(2), int64(4), int64(6), int64(8)}, + }, + { + name: "print captured in logs", + script: "print(\"hello\")\nprint(\"world\")\nreturn 1", + wantValue: int64(1), + wantLogs: []string{"hello", "world"}, + }, + { + name: "return boolean true", + script: "return True", + wantValue: true, + }, + { + name: "return None explicitly", + script: "return None", + wantValue: nil, + }, + { + name: "return float", + script: "return 3.14", + wantValue: 3.14, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + result, err := Execute(tt.script, tt.globals, tt.stepLimit) + + if tt.wantErr != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErr) + return + } + + require.NoError(t, err) + require.NotNil(t, result) + + got := starlarkToGo(result.Value) + assert.Equal(t, tt.wantValue, got) + + if tt.wantLogs != nil { + assert.Equal(t, tt.wantLogs, result.Logs) + } + }) + } +} diff --git a/pkg/script/middleware.go b/pkg/script/middleware.go new file mode 100644 index 0000000000..100085dad7 --- /dev/null +++ b/pkg/script/middleware.go @@ -0,0 +1,477 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package script + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "log/slog" + "net/http" + "net/http/httptest" + "strings" + + "github.com/stacklok/toolhive/pkg/mcp" + "github.com/stacklok/toolhive/pkg/transport/types" +) + +const ( + // ExecuteToolScriptName is the name of the virtual tool exposed by this middleware. + ExecuteToolScriptName = "execute_tool_script" + // MiddlewareType is the middleware type identifier for registration. + MiddlewareType = "script" +) + +// Middleware implements the types.Middleware interface. +type Middleware struct { + middleware types.MiddlewareFunction +} + +// Handler returns the middleware function. +func (s *Middleware) Handler() types.MiddlewareFunction { + return s.middleware +} + +// Close is a no-op for the script middleware. +func (*Middleware) Close() error { + return nil +} + +// CreateMiddleware is the factory function for registering the script middleware. +func CreateMiddleware(_ *types.MiddlewareConfig, runner types.MiddlewareRunner) error { + mw := &Middleware{middleware: NewMiddleware()} + runner.AddMiddleware(MiddlewareType, mw) + return nil +} + +// NewMiddleware returns a middleware function that intercepts execute_tool_script +// calls and injects the virtual tool into tools/list responses. +func NewMiddleware() func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + next.ServeHTTP(w, r) + return + } + + contentType := r.Header.Get("Content-Type") + if !strings.HasPrefix(contentType, "application/json") { + next.ServeHTTP(w, r) + return + } + + bodyBytes, err := io.ReadAll(r.Body) + if err != nil { + next.ServeHTTP(w, r) + return + } + r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) + + var req jsonRPCRequest + if err := json.Unmarshal(bodyBytes, &req); err != nil { + next.ServeHTTP(w, r) + return + } + + switch { + case req.Method == "tools/call" && isScriptToolCall(&req): + handleScriptExecution(w, r, next, &req) + case req.Method == "tools/list": + handleToolsListInjection(w, r, next) + default: + next.ServeHTTP(w, r) + } + }) + } +} + +type jsonRPCRequest struct { + JSONRPC string `json:"jsonrpc"` + ID json.RawMessage `json:"id"` + Method string `json:"method"` + Params *json.RawMessage `json:"params,omitempty"` +} + +type toolCallParams struct { + Name string `json:"name"` + Arguments map[string]interface{} `json:"arguments,omitempty"` +} + +func isScriptToolCall(req *jsonRPCRequest) bool { + if req.Params == nil { + return false + } + var params toolCallParams + if err := json.Unmarshal(*req.Params, ¶ms); err != nil { + return false + } + return params.Name == ExecuteToolScriptName +} + +func handleScriptExecution(w http.ResponseWriter, r *http.Request, next http.Handler, req *jsonRPCRequest) { + var params toolCallParams + if err := json.Unmarshal(*req.Params, ¶ms); err != nil { + writeJSONRPCError(w, req.ID, -32602, "invalid params") + return + } + + scriptRaw, ok := params.Arguments["script"] + if !ok { + writeJSONRPCError(w, req.ID, -32602, "missing required argument: script") + return + } + script, ok := scriptRaw.(string) + if !ok { + writeJSONRPCError(w, req.ID, -32602, "script argument must be a string") + return + } + + var data map[string]interface{} + if dataRaw, ok := params.Arguments["data"]; ok { + data, _ = dataRaw.(map[string]interface{}) + } + + // Fetch authorized tool list + tools, err := fetchToolList(r, next) + if err != nil { + slog.Error("failed to fetch tool list for script execution", "error", err) + writeJSONRPCError(w, req.ID, -32000, "failed to fetch available tools") + return + } + + // Build caller and globals + caller := &innerToolCaller{next: next, origReq: r} + globals := BuildGlobals(r.Context(), tools, caller, data) + + // Execute script + result, err := Execute(script, globals, 0) + if err != nil { + writeJSONRPCError(w, req.ID, -32000, err.Error()) + return + } + + // Convert result to JSON + resultJSON, err := ResultToJSON(result.Value) + if err != nil { + writeJSONRPCError(w, req.ID, -32000, fmt.Sprintf("failed to serialize result: %v", err)) + return + } + + writeJSONRPCResult(w, req.ID, resultJSON, result.Logs) +} + +func handleToolsListInjection(w http.ResponseWriter, r *http.Request, next http.Handler) { + rec := httptest.NewRecorder() + next.ServeHTTP(rec, r) + + // Copy status and headers + for k, v := range rec.Header() { + w.Header()[k] = v + } + + body := rec.Body.Bytes() + + var resp jsonRPCResponse + if err := json.Unmarshal(body, &resp); err != nil { + w.WriteHeader(rec.Code) + //nolint:errcheck,gosec // best-effort write + w.Write(body) + return + } + + if resp.Result == nil { + w.WriteHeader(rec.Code) + //nolint:errcheck,gosec // best-effort write + w.Write(body) + return + } + + var resultMap map[string]interface{} + if err := json.Unmarshal(*resp.Result, &resultMap); err != nil { + w.WriteHeader(rec.Code) + //nolint:errcheck,gosec // best-effort write + w.Write(body) + return + } + + toolsRaw, ok := resultMap["tools"] + if !ok { + w.WriteHeader(rec.Code) + //nolint:errcheck,gosec // best-effort write + w.Write(body) + return + } + + toolsSlice, ok := toolsRaw.([]interface{}) + if !ok { + w.WriteHeader(rec.Code) + //nolint:errcheck,gosec // best-effort write + w.Write(body) + return + } + + // Extract tool info for dynamic description + var toolInfos []ToolInfo + for _, t := range toolsSlice { + tm, ok := t.(map[string]interface{}) + if !ok { + continue + } + name, _ := tm["name"].(string) + desc, _ := tm["description"].(string) + if name != "" { + toolInfos = append(toolInfos, ToolInfo{Name: name, Description: desc}) + } + } + + // Append the virtual tool + scriptTool := buildScriptToolDefinition(toolInfos) + toolsSlice = append(toolsSlice, scriptTool) + resultMap["tools"] = toolsSlice + + resultBytes, err := json.Marshal(resultMap) + if err != nil { + w.WriteHeader(rec.Code) + //nolint:errcheck,gosec // best-effort write + w.Write(body) + return + } + + resp.Result = (*json.RawMessage)(&resultBytes) + modified, err := json.Marshal(resp) + if err != nil { + w.WriteHeader(rec.Code) + //nolint:errcheck,gosec // best-effort write + w.Write(body) + return + } + + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(modified))) + w.WriteHeader(rec.Code) + //nolint:errcheck,gosec // best-effort write + w.Write(modified) +} + +func buildScriptToolDefinition(tools []ToolInfo) map[string]interface{} { + return map[string]interface{}{ + "name": ExecuteToolScriptName, + "description": GenerateToolDescription(tools), + "inputSchema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "script": map[string]interface{}{ + "type": "string", + "description": "Starlark script body. Use 'return' to produce output.", + }, + "data": map[string]interface{}{ + "type": "object", + "description": "Named data arguments injected as top-level Starlark variables", + "additionalProperties": true, + }, + }, + "required": []string{"script"}, + }, + } +} + +// fetchToolList sends a synthetic tools/list request through the middleware chain. +func fetchToolList(origReq *http.Request, next http.Handler) ([]ToolInfo, error) { + listBody := `{"jsonrpc":"2.0","id":0,"method":"tools/list","params":{}}` + + ctx := context.WithValue(origReq.Context(), mcp.MCPRequestContextKey, nil) + innerReq, err := http.NewRequestWithContext(ctx, http.MethodPost, origReq.URL.String(), strings.NewReader(listBody)) + if err != nil { + return nil, err + } + innerReq.Header = origReq.Header.Clone() + innerReq.Header.Set("Content-Type", "application/json") + innerReq.ContentLength = int64(len(listBody)) + + rec := httptest.NewRecorder() + next.ServeHTTP(rec, innerReq) + + if rec.Code != http.StatusOK { + return nil, fmt.Errorf("tools/list returned status %d", rec.Code) + } + + var resp jsonRPCResponse + if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { + return nil, fmt.Errorf("failed to parse tools/list response: %w", err) + } + + if resp.Result == nil { + return nil, fmt.Errorf("tools/list response has no result") + } + + var resultMap map[string]interface{} + if err := json.Unmarshal(*resp.Result, &resultMap); err != nil { + return nil, fmt.Errorf("failed to parse tools/list result: %w", err) + } + + toolsRaw, ok := resultMap["tools"] + if !ok { + return nil, nil + } + + toolsSlice, ok := toolsRaw.([]interface{}) + if !ok { + return nil, nil + } + + var tools []ToolInfo + for _, t := range toolsSlice { + tm, ok := t.(map[string]interface{}) + if !ok { + continue + } + name, _ := tm["name"].(string) + desc, _ := tm["description"].(string) + if name != "" { + tools = append(tools, ToolInfo{Name: name, Description: desc}) + } + } + + return tools, nil +} + +type jsonRPCResponse struct { + JSONRPC string `json:"jsonrpc"` + ID json.RawMessage `json:"id"` + Result *json.RawMessage `json:"result,omitempty"` + Error *json.RawMessage `json:"error,omitempty"` +} + +func writeJSONRPCError(w http.ResponseWriter, id json.RawMessage, code int, message string) { + resp := map[string]interface{}{ + "jsonrpc": "2.0", + "id": json.RawMessage(id), + "error": map[string]interface{}{ + "code": code, + "message": message, + }, + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + //nolint:errcheck,gosec // best-effort write + json.NewEncoder(w).Encode(resp) +} + +func writeJSONRPCResult(w http.ResponseWriter, id json.RawMessage, resultJSON string, logs []string) { + content := []map[string]interface{}{ + {"type": "text", "text": resultJSON}, + } + if len(logs) > 0 { + content = append(content, map[string]interface{}{ + "type": "text", + "text": "Script logs:\n" + strings.Join(logs, "\n"), + }) + } + + result := map[string]interface{}{ + "content": content, + } + resultBytes, _ := json.Marshal(result) + raw := json.RawMessage(resultBytes) + + resp := jsonRPCResponse{ + JSONRPC: "2.0", + ID: id, + Result: &raw, + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + //nolint:errcheck,gosec // best-effort write + json.NewEncoder(w).Encode(resp) +} + +// innerToolCaller calls tools through the inner middleware chain. +type innerToolCaller struct { + next http.Handler + origReq *http.Request +} + +func (c *innerToolCaller) CallTool( + ctx context.Context, toolName string, arguments map[string]interface{}, +) (*CallToolResult, error) { + params := map[string]interface{}{ + "name": toolName, + "arguments": arguments, + } + body := map[string]interface{}{ + "jsonrpc": "2.0", + "id": 1, + "method": "tools/call", + "params": params, + } + bodyBytes, err := json.Marshal(body) + if err != nil { + return nil, fmt.Errorf("failed to marshal tool call: %w", err) + } + + // Clear parsed MCP request so the parser re-parses + innerCtx := context.WithValue(ctx, mcp.MCPRequestContextKey, nil) + innerReq, err := http.NewRequestWithContext(innerCtx, http.MethodPost, c.origReq.URL.String(), bytes.NewReader(bodyBytes)) + if err != nil { + return nil, err + } + innerReq.Header = c.origReq.Header.Clone() + innerReq.Header.Set("Content-Type", "application/json") + innerReq.ContentLength = int64(len(bodyBytes)) + + rec := httptest.NewRecorder() + c.next.ServeHTTP(rec, innerReq) + + if rec.Code != http.StatusOK { + return nil, fmt.Errorf("tool %q returned HTTP %d", toolName, rec.Code) + } + + var resp jsonRPCResponse + if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { + return nil, fmt.Errorf("failed to parse tool response: %w", err) + } + + if resp.Error != nil { + return nil, fmt.Errorf("tool %q returned JSON-RPC error: %s", toolName, string(*resp.Error)) + } + + if resp.Result == nil { + return &CallToolResult{}, nil + } + + var resultMap map[string]interface{} + if err := json.Unmarshal(*resp.Result, &resultMap); err != nil { + return nil, fmt.Errorf("failed to parse tool result: %w", err) + } + + result := &CallToolResult{} + + // Extract isError + if isErr, ok := resultMap["isError"].(bool); ok { + result.IsError = isErr + } + + // Extract structured content + if sc, ok := resultMap["structuredContent"].(map[string]interface{}); ok { + result.StructuredContent = sc + } + + // Extract content array + if contentRaw, ok := resultMap["content"].([]interface{}); ok { + for _, item := range contentRaw { + itemMap, ok := item.(map[string]interface{}) + if !ok { + continue + } + ci := ContentItem{} + ci.Type, _ = itemMap["type"].(string) + ci.Text, _ = itemMap["text"].(string) + result.Content = append(result.Content, ci) + } + } + + return result, nil +} diff --git a/pkg/script/middleware_test.go b/pkg/script/middleware_test.go new file mode 100644 index 0000000000..2c7ae5373a --- /dev/null +++ b/pkg/script/middleware_test.go @@ -0,0 +1,280 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package script + +import ( + "bytes" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// mockMCPBackend simulates an MCP server that handles tools/list and tools/call. +func mockMCPBackend(tools map[string]func(map[string]interface{}) string) http.Handler { + toolsList := make([]map[string]interface{}, 0, len(tools)) + for name := range tools { + toolsList = append(toolsList, map[string]interface{}{ + "name": name, + "description": "Test tool: " + name, + }) + } + + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + bodyBytes, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, "bad request", http.StatusBadRequest) + return + } + + var req jsonRPCRequest + if err := json.Unmarshal(bodyBytes, &req); err != nil { + http.Error(w, "bad json", http.StatusBadRequest) + return + } + + w.Header().Set("Content-Type", "application/json") + + switch req.Method { + case "tools/list": + result := map[string]interface{}{"tools": toolsList} + resultBytes, _ := json.Marshal(result) + raw := json.RawMessage(resultBytes) + resp := jsonRPCResponse{JSONRPC: "2.0", ID: req.ID, Result: &raw} + //nolint:errcheck + json.NewEncoder(w).Encode(resp) + + case "tools/call": + var params toolCallParams + if req.Params != nil { + //nolint:errcheck + json.Unmarshal(*req.Params, ¶ms) + } + handler, ok := tools[params.Name] + if !ok { + writeJSONRPCError(w, req.ID, -32601, "tool not found: "+params.Name) + return + } + text := handler(params.Arguments) + result := map[string]interface{}{ + "content": []map[string]interface{}{ + {"type": "text", "text": text}, + }, + } + resultBytes, _ := json.Marshal(result) + raw := json.RawMessage(resultBytes) + resp := jsonRPCResponse{JSONRPC: "2.0", ID: req.ID, Result: &raw} + //nolint:errcheck + json.NewEncoder(w).Encode(resp) + + default: + writeJSONRPCError(w, req.ID, -32601, "unknown method") + } + }) +} + +func sendJSONRPC(t *testing.T, handler http.Handler, method string, params interface{}) *httptest.ResponseRecorder { + t.Helper() + body := map[string]interface{}{ + "jsonrpc": "2.0", + "id": 1, + "method": method, + "params": params, + } + bodyBytes, err := json.Marshal(body) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/mcp", bytes.NewReader(bodyBytes)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + return rec +} + +func TestMiddleware_NonScriptPassthrough(t *testing.T) { + t.Parallel() + + var backendCalled bool + backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + backendCalled = true + w.WriteHeader(http.StatusOK) + }) + + middleware := NewMiddleware()(backend) + rec := sendJSONRPC(t, middleware, "tools/call", map[string]interface{}{ + "name": "some_other_tool", + }) + + assert.Equal(t, http.StatusOK, rec.Code) + assert.True(t, backendCalled, "backend should be called for non-script tools") +} + +func TestMiddleware_GETPassthrough(t *testing.T) { + t.Parallel() + + var backendCalled bool + backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + backendCalled = true + w.WriteHeader(http.StatusOK) + }) + + middleware := NewMiddleware()(backend) + req := httptest.NewRequest(http.MethodGet, "/mcp", nil) + rec := httptest.NewRecorder() + middleware.ServeHTTP(rec, req) + + assert.True(t, backendCalled, "GET should pass through") +} + +func TestMiddleware_ToolsListInjection(t *testing.T) { + t.Parallel() + + backend := mockMCPBackend(map[string]func(map[string]interface{}) string{ + "measure_length": func(_ map[string]interface{}) string { return `5` }, + }) + + middleware := NewMiddleware()(backend) + rec := sendJSONRPC(t, middleware, "tools/list", map[string]interface{}{}) + + require.Equal(t, http.StatusOK, rec.Code) + + var resp jsonRPCResponse + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.NotNil(t, resp.Result) + + var resultMap map[string]interface{} + require.NoError(t, json.Unmarshal(*resp.Result, &resultMap)) + + toolsRaw, ok := resultMap["tools"].([]interface{}) + require.True(t, ok) + + // Should have original tool + execute_tool_script + require.Len(t, toolsRaw, 2) + + names := make([]string, 0, len(toolsRaw)) + for _, item := range toolsRaw { + tm := item.(map[string]interface{}) + names = append(names, tm["name"].(string)) + } + assert.Contains(t, names, "measure_length") + assert.Contains(t, names, ExecuteToolScriptName) + + // Check dynamic description mentions measure_length + for _, item := range toolsRaw { + tm := item.(map[string]interface{}) + if tm["name"] == ExecuteToolScriptName { + desc := tm["description"].(string) + assert.Contains(t, desc, "measure_length") + } + } +} + +func TestMiddleware_ScriptExecution(t *testing.T) { + t.Parallel() + + backend := mockMCPBackend(map[string]func(map[string]interface{}) string{ + "get_value": func(_ map[string]interface{}) string { return `42` }, + }) + + middleware := NewMiddleware()(backend) + rec := sendJSONRPC(t, middleware, "tools/call", map[string]interface{}{ + "name": ExecuteToolScriptName, + "arguments": map[string]interface{}{ + "script": `return get_value()`, + }, + }) + + require.Equal(t, http.StatusOK, rec.Code) + + var resp jsonRPCResponse + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.NotNil(t, resp.Result) + require.Nil(t, resp.Error) + + var resultMap map[string]interface{} + require.NoError(t, json.Unmarshal(*resp.Result, &resultMap)) + + content, ok := resultMap["content"].([]interface{}) + require.True(t, ok) + require.NotEmpty(t, content) + + firstItem := content[0].(map[string]interface{}) + assert.Equal(t, "text", firstItem["type"]) + assert.Equal(t, "42", firstItem["text"]) +} + +func TestMiddleware_ScriptWithDataArgs(t *testing.T) { + t.Parallel() + + backend := mockMCPBackend(map[string]func(map[string]interface{}) string{ + "echo": func(args map[string]interface{}) string { + msg, _ := args["msg"].(string) + return `"` + msg + `"` + }, + }) + + middleware := NewMiddleware()(backend) + rec := sendJSONRPC(t, middleware, "tools/call", map[string]interface{}{ + "name": ExecuteToolScriptName, + "arguments": map[string]interface{}{ + "script": `return echo(msg=greeting)`, + "data": map[string]interface{}{"greeting": "hello"}, + }, + }) + + require.Equal(t, http.StatusOK, rec.Code) + + var resp jsonRPCResponse + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.NotNil(t, resp.Result) + + var resultMap map[string]interface{} + require.NoError(t, json.Unmarshal(*resp.Result, &resultMap)) + + content := resultMap["content"].([]interface{}) + firstItem := content[0].(map[string]interface{}) + assert.Equal(t, `"hello"`, firstItem["text"]) +} + +func TestMiddleware_ScriptError(t *testing.T) { + t.Parallel() + + backend := mockMCPBackend(map[string]func(map[string]interface{}) string{}) + + middleware := NewMiddleware()(backend) + rec := sendJSONRPC(t, middleware, "tools/call", map[string]interface{}{ + "name": ExecuteToolScriptName, + "arguments": map[string]interface{}{ + "script": `return !!!`, + }, + }) + + require.Equal(t, http.StatusOK, rec.Code) + + var resp jsonRPCResponse + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.NotNil(t, resp.Error, "should return JSON-RPC error for bad script") +} + +func TestMiddleware_MissingScript(t *testing.T) { + t.Parallel() + + backend := mockMCPBackend(map[string]func(map[string]interface{}) string{}) + + middleware := NewMiddleware()(backend) + rec := sendJSONRPC(t, middleware, "tools/call", map[string]interface{}{ + "name": ExecuteToolScriptName, + "arguments": map[string]interface{}{}, + }) + + require.Equal(t, http.StatusOK, rec.Code) + + var resp jsonRPCResponse + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.NotNil(t, resp.Error, "should error when script argument is missing") +} diff --git a/pkg/vmcp/server/server.go b/pkg/vmcp/server/server.go index 48dac4089e..92078e425e 100644 --- a/pkg/vmcp/server/server.go +++ b/pkg/vmcp/server/server.go @@ -122,6 +122,11 @@ type Config struct { // If nil, no authorization is performed. AuthzMiddleware func(http.Handler) http.Handler + // ScriptMiddleware is the optional Starlark script execution middleware. + // Sits above (outer to) authz so scripts only see/call authorized tools. + // If nil, the execute_tool_script virtual tool is not available. + ScriptMiddleware func(http.Handler) http.Handler + // AuthInfoHandler is the optional handler for /.well-known/oauth-protected-resource endpoint. // Exposes OIDC discovery information about the protected resource. AuthInfoHandler http.Handler @@ -554,10 +559,10 @@ func (s *Server) Handler(_ context.Context) (http.Handler, error) { // MCP endpoint - apply middleware chain (wrapping order, execution happens in reverse): // Code wraps: auth+parser → audit → discovery → annotation-enrichment → - // authz → backend-enrichment → MCP-parsing → telemetry + // authz → script → backend-enrichment → MCP-parsing → telemetry // Execution order: recovery → header-val → auth+parser → audit → - // discovery → annotation-enrichment → authz → backend-enrichment → - // MCP-parsing → telemetry → handler + // discovery → annotation-enrichment → authz → script → + // backend-enrichment → MCP-parsing → telemetry → handler var mcpHandler http.Handler = streamableServer @@ -589,6 +594,14 @@ func (s *Server) Handler(_ context.Context) (http.Handler, error) { slog.Info("authorization middleware enabled for MCP endpoints (post-discovery)") } + // Apply script middleware if configured (runs AFTER authz in execution). + // Wrapping after authz makes it outer: intercepts execute_tool_script calls + // before they reach authz, but inner tool calls from scripts flow through authz. + if s.config.ScriptMiddleware != nil { + mcpHandler = s.config.ScriptMiddleware(mcpHandler) + slog.Info("script middleware enabled for MCP endpoints") + } + // Apply annotation enrichment middleware (runs after discovery, before authz in execution). // Reads tool annotations from discovered capabilities and injects them into the // request context so the authz middleware can make annotation-aware decisions. diff --git a/test/e2e/thv-operator/virtualmcp/virtualmcp_script_test.go b/test/e2e/thv-operator/virtualmcp/virtualmcp_script_test.go new file mode 100644 index 0000000000..ee66bb62f1 --- /dev/null +++ b/test/e2e/thv-operator/virtualmcp/virtualmcp_script_test.go @@ -0,0 +1,184 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package virtualmcp + +import ( + "encoding/json" + "fmt" + "time" + + "github.com/mark3labs/mcp-go/mcp" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + + mcpv1alpha1 "github.com/stacklok/toolhive/cmd/thv-operator/api/v1alpha1" + "github.com/stacklok/toolhive/pkg/script" + vmcpconfig "github.com/stacklok/toolhive/pkg/vmcp/config" + "github.com/stacklok/toolhive/test/e2e/images" +) + +var _ = Describe("VirtualMCPServer Script Middleware", Ordered, func() { + var ( + testNamespace = "default" + mcpGroupName = "test-script-group" + vmcpServerName = "test-vmcp-script" + backendName = "yardstick-script" + timeout = 3 * time.Minute + pollingInterval = 1 * time.Second + vmcpNodePort int32 + ) + + BeforeAll(func() { + By("Creating MCPGroup for script middleware test") + CreateMCPGroupAndWait(ctx, k8sClient, mcpGroupName, testNamespace, + "Test MCP Group for script middleware", timeout, pollingInterval) + + By("Creating yardstick backend MCPServer") + backend := &mcpv1alpha1.MCPServer{ + ObjectMeta: metav1.ObjectMeta{ + Name: backendName, + Namespace: testNamespace, + }, + Spec: mcpv1alpha1.MCPServerSpec{ + GroupRef: mcpGroupName, + Image: images.YardstickServerImage, + Transport: "streamable-http", + ProxyPort: 8080, + McpPort: 8080, + Env: []mcpv1alpha1.EnvVar{ + {Name: "TRANSPORT", Value: "streamable-http"}, + }, + }, + } + Expect(k8sClient.Create(ctx, backend)).To(Succeed()) + + By("Waiting for backend MCPServer to be running") + Eventually(func() error { + server := &mcpv1alpha1.MCPServer{} + if err := k8sClient.Get(ctx, types.NamespacedName{ + Name: backendName, + Namespace: testNamespace, + }, server); err != nil { + return fmt.Errorf("failed to get server: %w", err) + } + if server.Status.Phase != mcpv1alpha1.MCPServerPhaseReady { + return fmt.Errorf("not ready yet, phase: %s", server.Status.Phase) + } + return nil + }, timeout, pollingInterval).Should(Succeed()) + + By("Creating VirtualMCPServer") + vmcpServer := &mcpv1alpha1.VirtualMCPServer{ + ObjectMeta: metav1.ObjectMeta{ + Name: vmcpServerName, + Namespace: testNamespace, + }, + Spec: mcpv1alpha1.VirtualMCPServerSpec{ + Config: vmcpconfig.Config{ + Group: mcpGroupName, + }, + IncomingAuth: &mcpv1alpha1.IncomingAuthConfig{ + Type: "anonymous", + }, + ServiceType: "NodePort", + }, + } + Expect(k8sClient.Create(ctx, vmcpServer)).To(Succeed()) + + By("Waiting for VirtualMCPServer to be ready") + WaitForVirtualMCPServerReady(ctx, k8sClient, vmcpServerName, testNamespace, timeout, pollingInterval) + + By("Getting NodePort for VirtualMCPServer") + vmcpNodePort = GetVMCPNodePort(ctx, k8sClient, vmcpServerName, testNamespace, timeout, pollingInterval) + GinkgoWriter.Printf("VirtualMCPServer accessible at http://localhost:%d\n", vmcpNodePort) + }) + + AfterAll(func() { + By("Cleaning up VirtualMCPServer") + _ = k8sClient.Delete(ctx, &mcpv1alpha1.VirtualMCPServer{ + ObjectMeta: metav1.ObjectMeta{Name: vmcpServerName, Namespace: testNamespace}, + }) + + By("Cleaning up backend MCPServer") + _ = k8sClient.Delete(ctx, &mcpv1alpha1.MCPServer{ + ObjectMeta: metav1.ObjectMeta{Name: backendName, Namespace: testNamespace}, + }) + + By("Cleaning up MCPGroup") + _ = k8sClient.Delete(ctx, &mcpv1alpha1.MCPGroup{ + ObjectMeta: metav1.ObjectMeta{Name: mcpGroupName, Namespace: testNamespace}, + }) + }) + + It("should include execute_tool_script in tools/list with dynamic description", func() { + tools := WaitForExpectedTools(vmcpNodePort, "script-test-client", + func(toolsList []mcp.Tool) error { + return ToolsContainAll(toolsList, script.ExecuteToolScriptName) + }, timeout) + + // Find the script tool and verify its description mentions backend tools + var scriptTool *mcp.Tool + for i := range tools.Tools { + if tools.Tools[i].Name == script.ExecuteToolScriptName { + scriptTool = &tools.Tools[i] + break + } + } + Expect(scriptTool).ToNot(BeNil(), "execute_tool_script should be in tool list") + Expect(scriptTool.Description).To(ContainSubstring("echo"), + "dynamic description should mention yardstick's echo tool") + + GinkgoWriter.Printf("Script tool description:\n%s\n", scriptTool.Description) + }) + + It("should execute a script that calls a backend tool", func() { + mcpClient, err := CreateInitializedMCPClient(vmcpNodePort, "script-exec-client", 30*time.Second) + Expect(err).ToNot(HaveOccurred()) + defer mcpClient.Close() + + callRequest := mcp.CallToolRequest{} + callRequest.Params.Name = script.ExecuteToolScriptName + callRequest.Params.Arguments = map[string]any{ + "script": `result = echo(input=message) +return {"echoed": result}`, + "data": map[string]any{ + "message": "hello from script", + }, + } + + result, err := mcpClient.Client.CallTool(mcpClient.Ctx, callRequest) + Expect(err).ToNot(HaveOccurred(), "Script tool call should succeed") + Expect(result).ToNot(BeNil()) + Expect(result.Content).ToNot(BeEmpty(), "Should have content in response") + + // Parse the result text + textContent, ok := result.Content[0].(mcp.TextContent) + Expect(ok).To(BeTrue(), "First content should be text") + GinkgoWriter.Printf("Script result: %s\n", textContent.Text) + + // The result should be valid JSON containing the echoed value + var resultMap map[string]any + Expect(json.Unmarshal([]byte(textContent.Text), &resultMap)).To(Succeed()) + Expect(resultMap).To(HaveKey("echoed")) + }) + + It("should return error for invalid script", func() { + mcpClient, err := CreateInitializedMCPClient(vmcpNodePort, "script-error-client", 30*time.Second) + Expect(err).ToNot(HaveOccurred()) + defer mcpClient.Close() + + callRequest := mcp.CallToolRequest{} + callRequest.Params.Name = script.ExecuteToolScriptName + callRequest.Params.Arguments = map[string]any{ + "script": "return !!!", + } + + // The call may return an error or an isError result depending on how the + // JSON-RPC error is surfaced through the mcp-go client + _, err = mcpClient.Client.CallTool(mcpClient.Ctx, callRequest) + Expect(err).To(HaveOccurred(), "Invalid script should produce an error") + }) +})