Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion configs/model_config_test/llmmultiroundrouter.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ data_path:
base_model: 'meta/llama-3.1-8b-instruct' # Model for decomposition+routing and aggregation
use_local_llm: false # Set to true to use vLLM for local inference (requires vLLM installed)
api_endpoint: 'https://integrate.api.nvidia.com/v1' # API endpoint for execution
decomposition_max_tokens: 2048 # Max tokens for the decomposition+routing step

metric:
weights:
Expand All @@ -32,4 +33,3 @@ metric:
llm_judge: 0

# No hparam section needed (no KNN parameters)

1 change: 1 addition & 0 deletions llmrouter/models/llmmultiroundrouter/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ Query → LLM Decomposition+Routing → [(Sub-Query 1, Model A), (Sub-Query 2, M
| `base_model` | str | `"Qwen/Qwen2.5-3B-Instruct"` | Base model for decomposition/aggregation/routing |
| `use_local_llm` | bool | `false` | Use local vLLM (true) or API (false) |
| `api_endpoint` | str | - | API endpoint for execution |
| `decomposition_max_tokens` | int | `2048` | Max tokens used by the decomposition + routing step |

### LLM Data

Expand Down
5 changes: 3 additions & 2 deletions llmrouter/models/llmmultiroundrouter/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,7 @@ def _decompose_and_route(self, query: str) -> List[tuple]:
List of (sub_query, model_name) tuples
"""
decomp_route_prompt = self.DECOMP_ROUTE_PROMPT.format(query=query)
decomp_max_tokens = self.cfg.get("decomposition_max_tokens", 2048)

if self.use_local_llm:
self._initialize_local_llm()
Expand All @@ -415,7 +416,7 @@ def _decompose_and_route(self, query: str) -> List[tuple]:
tokenize=False,
add_generation_prompt=True
)
sampling_params = SamplingParams(temperature=0.0, top_p=0.95, max_tokens=512)
sampling_params = SamplingParams(temperature=0.0, top_p=0.95, max_tokens=decomp_max_tokens)
outputs = self.local_llm.generate([prompt_text], sampling_params)
decomp_output = outputs[0].outputs[0].text.strip()
else:
Expand Down Expand Up @@ -459,7 +460,7 @@ def _decompose_and_route(self, query: str) -> List[tuple]:
if service:
request["service"] = service
try:
result = call_api(request, max_tokens=512, temperature=0.0)
result = call_api(request, max_tokens=decomp_max_tokens, temperature=0.0)
decomp_output = result.get("response", "")
except Exception as e:
print(f"Error in decomposition+routing: {e}")
Expand Down
66 changes: 66 additions & 0 deletions tests/test_llmmultiroundrouter_decomposition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import torch.nn as nn
from unittest.mock import patch

from llmrouter.models.llmmultiroundrouter.router import LLMMultiRoundRouter


def _build_router(cfg=None):
    """Create a bare ``LLMMultiRoundRouter`` set up for API-mode decomposition tests.

    The instance is allocated with ``__new__`` so the router's own ``__init__``
    is not executed; only ``nn.Module.__init__`` runs before the attributes
    read by the tests are assigned by hand.

    Args:
        cfg: Optional config mapping stored on ``router.cfg``. A falsy value
            (``None`` or an empty mapping) is replaced by a fresh empty dict.

    Returns:
        The partially initialised router instance.
    """
    endpoint = "https://api.example.com"
    # Candidate alias -> concrete model identifier; both share one endpoint/service.
    candidates = {
        "deepseek-flash": "deepseek-v4-flash",
        "deepseek-pro": "deepseek-v4-pro",
    }

    stub = LLMMultiRoundRouter.__new__(LLMMultiRoundRouter)
    nn.Module.__init__(stub)
    stub.cfg = cfg if cfg else {}
    stub.use_local_llm = False
    stub.base_model = "deepseek-v4-pro"
    stub.api_endpoint = endpoint
    stub.llm_data = {
        alias: {
            "model": model_id,
            "api_endpoint": endpoint,
            "service": "DeepSeek",
        }
        for alias, model_id in candidates.items()
    }
    # A pass-through template keeps the formatted prompt equal to the query.
    stub.DECOMP_ROUTE_PROMPT = "{query}"
    return stub


def test_decompose_and_route_uses_large_default_max_tokens():
    """Without a config override, decomposition calls the API with max_tokens=2048."""
    seen = {}

    def fake_call_api(request, max_tokens, temperature):
        # Record every argument so the assertions can inspect the call afterwards.
        seen.update(
            request=request,
            max_tokens=max_tokens,
            temperature=temperature,
        )
        return {"response": "simple task: deepseek-flash"}

    router = _build_router()
    target = "llmrouter.models.llmmultiroundrouter.router.call_api"
    with patch(target, side_effect=fake_call_api):
        routed = router._decompose_and_route("test query")

    assert seen["max_tokens"] == 2048
    assert seen["temperature"] == 0.0
    assert routed == [("simple task", "deepseek-v4-flash")]


def test_decompose_and_route_allows_token_override_from_config():
    """A ``decomposition_max_tokens`` value in the config replaces the default."""
    seen = {}

    def fake_call_api(request, max_tokens, temperature):
        # Only the token budget matters for this test.
        seen["max_tokens"] = max_tokens
        return {"response": "complex task: deepseek-pro"}

    override = {"decomposition_max_tokens": 1024}
    router = _build_router(override)
    target = "llmrouter.models.llmmultiroundrouter.router.call_api"
    with patch(target, side_effect=fake_call_api):
        routed = router._decompose_and_route("test query")

    assert seen["max_tokens"] == 1024
    assert routed == [("complex task", "deepseek-v4-pro")]