Make the decomposition+routing token budget configurable (default 2048).

Review notes:
- A `decomposition_max_tokens` key that is present but null (or 0) now falls
  back to 2048 instead of being forwarded as-is; a regression test covers it.
- Context-line indentation was reconstructed from a whitespace-collapsed copy
  of this patch; verify it against the target files before applying.
- The `index` blob hashes are carried over from the original patch and no
  longer match the modified router.py / test file contents.

diff --git a/configs/model_config_test/llmmultiroundrouter.yaml b/configs/model_config_test/llmmultiroundrouter.yaml
index 92a7a0a..f388bee 100644
--- a/configs/model_config_test/llmmultiroundrouter.yaml
+++ b/configs/model_config_test/llmmultiroundrouter.yaml
@@ -24,6 +24,7 @@ data_path:
 base_model: 'meta/llama-3.1-8b-instruct' # Model for decomposition+routing and aggregation
 use_local_llm: false # Set to true to use vLLM for local inference (requires vLLM installed)
 api_endpoint: 'https://integrate.api.nvidia.com/v1' # API endpoint for execution
+decomposition_max_tokens: 2048 # Max tokens for the decomposition+routing step
 
 metric:
   weights:
@@ -32,4 +33,3 @@ metric:
     llm_judge: 0
 
 # No hparam section needed (no KNN parameters)
-
diff --git a/llmrouter/models/llmmultiroundrouter/README.md b/llmrouter/models/llmmultiroundrouter/README.md
index bc1fcf8..4fc7b00 100644
--- a/llmrouter/models/llmmultiroundrouter/README.md
+++ b/llmrouter/models/llmmultiroundrouter/README.md
@@ -56,6 +56,7 @@ Query → LLM Decomposition+Routing → [(Sub-Query 1, Model A), (Sub-Query 2, M
 | `base_model` | str | `"Qwen/Qwen2.5-3B-Instruct"` | Base model for decomposition/aggregation/routing |
 | `use_local_llm` | bool | `false` | Use local vLLM (true) or API (false) |
 | `api_endpoint` | str | - | API endpoint for execution |
+| `decomposition_max_tokens` | int | `2048` | Max tokens used by the decomposition + routing step |
 
 ### LLM Data
 
diff --git a/llmrouter/models/llmmultiroundrouter/router.py b/llmrouter/models/llmmultiroundrouter/router.py
index 9c18839..6701a9f 100644
--- a/llmrouter/models/llmmultiroundrouter/router.py
+++ b/llmrouter/models/llmmultiroundrouter/router.py
@@ -405,6 +405,8 @@ def _decompose_and_route(self, query: str) -> List[tuple]:
             List of (sub_query, model_name) tuples
         """
         decomp_route_prompt = self.DECOMP_ROUTE_PROMPT.format(query=query)
+        # `or`, not a .get() default: a key that is present but null (or 0) must also fall back.
+        decomp_max_tokens = self.cfg.get("decomposition_max_tokens") or 2048
 
         if self.use_local_llm:
             self._initialize_local_llm()
@@ -415,7 +417,7 @@ def _decompose_and_route(self, query: str) -> List[tuple]:
                 tokenize=False,
                 add_generation_prompt=True
             )
-            sampling_params = SamplingParams(temperature=0.0, top_p=0.95, max_tokens=512)
+            sampling_params = SamplingParams(temperature=0.0, top_p=0.95, max_tokens=decomp_max_tokens)
             outputs = self.local_llm.generate([prompt_text], sampling_params)
             decomp_output = outputs[0].outputs[0].text.strip()
         else:
@@ -459,7 +461,7 @@ def _decompose_and_route(self, query: str) -> List[tuple]:
                 if service:
                     request["service"] = service
                 try:
-                    result = call_api(request, max_tokens=512, temperature=0.0)
+                    result = call_api(request, max_tokens=decomp_max_tokens, temperature=0.0)
                     decomp_output = result.get("response", "")
                 except Exception as e:
                     print(f"Error in decomposition+routing: {e}")
diff --git a/tests/test_llmmultiroundrouter_decomposition.py b/tests/test_llmmultiroundrouter_decomposition.py
new file mode 100644
index 0000000..8ee7780
--- /dev/null
+++ b/tests/test_llmmultiroundrouter_decomposition.py
@@ -0,0 +1,83 @@
+"""Tests for the ``decomposition_max_tokens`` setting of ``LLMMultiRoundRouter``."""
+from unittest.mock import patch
+
+import torch.nn as nn
+
+from llmrouter.models.llmmultiroundrouter.router import LLMMultiRoundRouter
+
+CALL_API_TARGET = "llmrouter.models.llmmultiroundrouter.router.call_api"
+
+
+def _build_router(cfg=None):
+    """Build a bare router with hand-set attributes, bypassing ``__init__``."""
+    router = LLMMultiRoundRouter.__new__(LLMMultiRoundRouter)
+    nn.Module.__init__(router)
+    router.cfg = cfg or {}
+    router.use_local_llm = False
+    router.base_model = "deepseek-v4-pro"
+    router.api_endpoint = "https://api.example.com"
+    router.llm_data = {
+        "deepseek-flash": {
+            "model": "deepseek-v4-flash",
+            "api_endpoint": "https://api.example.com",
+            "service": "DeepSeek",
+        },
+        "deepseek-pro": {
+            "model": "deepseek-v4-pro",
+            "api_endpoint": "https://api.example.com",
+            "service": "DeepSeek",
+        },
+    }
+    router.DECOMP_ROUTE_PROMPT = "{query}"
+    return router
+
+
+def _decompose_with_fake_api(router, response):
+    """Run ``_decompose_and_route`` with ``call_api`` replaced by a recorder.
+
+    Returns ``(result, captured)``; ``captured`` holds the ``request``,
+    ``max_tokens`` and ``temperature`` the router passed to ``call_api``.
+    """
+    captured = {}
+
+    def fake_call_api(request, max_tokens, temperature):
+        captured["request"] = request
+        captured["max_tokens"] = max_tokens
+        captured["temperature"] = temperature
+        return {"response": response}
+
+    with patch(CALL_API_TARGET, side_effect=fake_call_api):
+        result = router._decompose_and_route("test query")
+    return result, captured
+
+
+def test_decompose_and_route_uses_large_default_max_tokens():
+    """Without a config override, the API call gets the 2048-token default."""
+    result, captured = _decompose_with_fake_api(
+        _build_router(), "simple task: deepseek-flash"
+    )
+
+    assert captured["max_tokens"] == 2048
+    assert captured["temperature"] == 0.0
+    assert result == [("simple task", "deepseek-v4-flash")]
+
+
+def test_decompose_and_route_allows_token_override_from_config():
+    """``decomposition_max_tokens`` in the config replaces the default."""
+    result, captured = _decompose_with_fake_api(
+        _build_router({"decomposition_max_tokens": 1024}),
+        "complex task: deepseek-pro",
+    )
+
+    assert captured["max_tokens"] == 1024
+    assert result == [("complex task", "deepseek-v4-pro")]
+
+
+def test_decompose_and_route_falls_back_when_config_value_is_null():
+    """An explicit null in the config falls back to the 2048-token default."""
+    _, captured = _decompose_with_fake_api(
+        _build_router({"decomposition_max_tokens": None}),
+        "simple task: deepseek-flash",
+    )
+
+    assert captured["max_tokens"] == 2048