Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions auto_tune_vllm/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ class OptimizationConfig:
objective: Union[str, List[str]] = None # Old format: "maximize", "minimize", list
sampler: str = "tpe" # "tpe", "random", "gp", "botorch", "nsga2", "grid"
n_trials: int = 100
n_repeats: int = 1 # Benchmark runs per trial config (same vLLM server)
n_startup_trials: int = 10 # Number of random startup trials
max_concurrent_trials: Optional[int] = (
None # Maximum concurrent trials (required for resource management)
Expand All @@ -138,6 +139,11 @@ def __post_init__(self):
else:
self._apply_default_config()
self._validate_log_metrics()
self._validate_n_repeats()

def _validate_n_repeats(self) -> None:
if self.n_repeats < 1:
raise ValueError(f"n_repeats must be >= 1, got {self.n_repeats}")

def _validate_log_metrics(self) -> None:
"""Normalize and validate log_metrics (independent of objective setup)."""
Expand Down
62 changes: 61 additions & 1 deletion auto_tune_vllm/core/study_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -941,6 +941,10 @@ def _set_log_metric_user_attrs(
names = self.config.optimization.log_metrics
if not names or not result.success or not result.detailed_metrics:
return

n_repeats = self.config.optimization.n_repeats
repeat_runs = result.detailed_metrics.get("repeats", [])

for name in names:
if name not in result.detailed_metrics:
logger.warning(
Expand All @@ -962,7 +966,63 @@ def _set_log_metric_user_attrs(
result.trial_number,
)
continue
trial.set_user_attr(f"metric_{name}", value)

if n_repeats > 1:
repeat_values = self._collect_repeat_metric_values(
name, repeat_runs, result.trial_number
)
if repeat_values is None:
continue
rel_range = 0.0
if value != 0:
rel_range = (max(repeat_values) - min(repeat_values)) / abs(value)
trial.set_user_attr(f"metric_{name}", value)
trial.set_user_attr(f"metric_{name}_rel_range", round(rel_range, 6))
trial.set_user_attr(f"metric_{name}_values", repeat_values)
else:
trial.set_user_attr(f"metric_{name}", value)

if n_repeats > 1:
trial.set_user_attr("n_repeats", n_repeats)

@staticmethod
def _collect_repeat_metric_values(
name: str,
repeat_runs: list,
trial_number: int | None,
) -> list[float] | None:
if not repeat_runs:
logger.warning(
"log_metrics: repeats missing from detailed_metrics for trial %s; "
"skipping user attr for %r",
trial_number,
name,
)
return None

repeat_values: list[float] = []
for run in repeat_runs:
if name not in run:
logger.warning(
"log_metrics: metric %r not found in repeat run %s for trial %s; "
"skipping user attr",
name,
run.get("run"),
trial_number,
)
return None
try:
repeat_values.append(float(run[name]))
except (TypeError, ValueError):
logger.warning(
"log_metrics: cannot coerce repeat metric %r value %r to float "
"for trial %s; skipping user attr",
name,
run[name],
trial_number,
)
return None
return repeat_values

def get_best_baseline_result(self) -> list[float] | None:
"""Get the best baseline result for comparison."""
Expand Down
183 changes: 148 additions & 35 deletions auto_tune_vllm/execution/trial_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import time
from abc import ABC, abstractmethod
from enum import Enum, auto
from typing import Optional
from typing import Any, Optional

try:
import ray
Expand Down Expand Up @@ -387,6 +387,14 @@ def should_cancel():
)
execution_info.mark_vllm_ready()

n_repeats = self._get_n_repeats(trial_config)
repeat_index = 0
repeat_runs: list[dict[str, Any]] = []
if n_repeats > 1:
controller_logger.info(
f"Running {n_repeats} benchmark repeats per trial config"
)

# Main execution loop - concise with extracted state handlers
while True:
poll_count += 1
Expand All @@ -404,7 +412,11 @@ def should_cancel():
# Handle current state
if state == TrialState.WAITING_FOR_VLLM:
result = self._handle_vllm_startup(
trial_config, server_info, vllm_start_time, controller_logger
trial_config,
server_info,
vllm_start_time,
controller_logger,
repeat_index,
)
if result: # vLLM is ready, transition to benchmark
benchmark_process, benchmark_start_time = result
Expand All @@ -430,13 +442,40 @@ def should_cancel():
f"success={result.success}, "
f"objectives={result.objective_values}"
)
execution_info.mark_benchmark_completed()
execution_info.mark_completed(status="success")
repeat_runs.append(
{
"run": repeat_index,
"objective_values": result.objective_values,
"detailed_metrics": dict(result.detailed_metrics),
}
)
repeat_index += 1
if repeat_index >= n_repeats:
execution_info.mark_benchmark_completed()
execution_info.mark_completed(status="success")
final_result = self._build_trial_result_from_repeats(
repeat_runs, trial_config, execution_info
)
controller_logger.info(
f"Returning successful trial result with "
f"{len(final_result.objective_values)} objectives "
f"({n_repeats} repeat(s))"
)
return final_result

controller_logger.info(
f"Returning successful trial result with "
f"{len(result.objective_values)} objectives"
f"Benchmark repeat {repeat_index}/{n_repeats} completed, "
"starting next repeat"
)
benchmark_process, benchmark_start_time = (
self._start_benchmark_run(
trial_config,
server_info,
controller_logger,
repeat_index,
)
)
return result
continue
# If None, benchmark still running - continue polling
controller_logger.debug(
f"Benchmark still running... "
Expand Down Expand Up @@ -708,12 +747,100 @@ def _check_cancellation(

raise KeyboardInterrupt(f"Trial cancelled while {state.name}")

@staticmethod
def _get_n_repeats(trial_config: TrialConfig) -> int:
if trial_config.optimization_config is None:
return 1
return trial_config.optimization_config.n_repeats

def _start_benchmark_run(
self,
trial_config: TrialConfig,
server_info: dict,
logger,
repeat_index: int,
) -> tuple[Any, float]:
n_repeats = self._get_n_repeats(trial_config)
if n_repeats > 1:
context_trial_id = f"{trial_config.trial_id}_repeat_{repeat_index}"
else:
context_trial_id = trial_config.trial_id

logger.info(f"Starting benchmark run {repeat_index + 1}/{n_repeats}")
benchmark_logger = self._get_trial_logger("benchmark")

if hasattr(self.benchmark_provider, "set_logger"):
self.benchmark_provider.set_logger(benchmark_logger)

if hasattr(self.benchmark_provider, "set_trial_context"):
self.benchmark_provider.set_trial_context(
trial_config.study_name, context_trial_id
)

benchmark_process = self.benchmark_provider.start_benchmark(
server_info["url"], trial_config.benchmark_config
)
return benchmark_process, time.time()

@staticmethod
def _average_detailed_metrics(metrics_runs: list[dict[str, Any]]) -> dict[str, Any]:
if len(metrics_runs) == 1:
return dict(metrics_runs[0])

averaged: dict[str, Any] = {}
for key in metrics_runs[0]:
values = [run[key] for run in metrics_runs if key in run]
if values and all(isinstance(value, (int, float)) for value in values):
averaged[key] = sum(values) / len(values)
return averaged

def _build_trial_result_from_repeats(
self,
repeat_runs: list[dict[str, Any]],
trial_config: TrialConfig,
execution_info: ExecutionInfo,
) -> TrialResult:
if len(repeat_runs) == 1:
run = repeat_runs[0]
return TrialResult(
trial_id=trial_config.trial_id,
trial_number=trial_config.trial_number,
trial_type=trial_config.trial_type,
objective_values=run["objective_values"],
detailed_metrics=run["detailed_metrics"],
execution_info=execution_info,
success=True,
)

metrics_runs = [run["detailed_metrics"] for run in repeat_runs]
averaged_metrics = self._average_detailed_metrics(metrics_runs)
averaged_metrics["repeats"] = [
{"run": run["run"], **run["detailed_metrics"]} for run in repeat_runs
]

n_objectives = len(repeat_runs[0]["objective_values"])
mean_objectives = []
for objective_index in range(n_objectives):
values = [run["objective_values"][objective_index] for run in repeat_runs]
mean_objectives.append(sum(values) / len(values))

return TrialResult(
trial_id=trial_config.trial_id,
trial_number=trial_config.trial_number,
trial_type=trial_config.trial_type,
objective_values=mean_objectives,
detailed_metrics=averaged_metrics,
execution_info=execution_info,
success=True,
)

def _handle_vllm_startup(
self,
trial_config: TrialConfig,
server_info: dict,
vllm_start_time: float,
logger,
repeat_index: int = 0,
):
"""Handle vLLM startup state.

Expand Down Expand Up @@ -753,23 +880,9 @@ def _handle_vllm_startup(
max_failures=trial_config.health_check_max_failures,
)

# Setup and start benchmark
logger.info("Starting benchmark run")
benchmark_logger = self._get_trial_logger("benchmark")

if hasattr(self.benchmark_provider, "set_logger"):
self.benchmark_provider.set_logger(benchmark_logger)

if hasattr(self.benchmark_provider, "set_trial_context"):
self.benchmark_provider.set_trial_context(
trial_config.study_name, trial_config.trial_id
)

# Start benchmark as subprocess
benchmark_process = self.benchmark_provider.start_benchmark(
server_info["url"], trial_config.benchmark_config
return self._start_benchmark_run(
trial_config, server_info, logger, repeat_index
)
return benchmark_process, time.time()

except requests.exceptions.RequestException as e:
# Health check failed, log and continue polling
Expand Down Expand Up @@ -1132,7 +1245,8 @@ def _check_health_status(self):
f"vLLM server health check failed: {self._health_check_failure_reason}"
)

def evaluate_metric_expression(self,
def evaluate_metric_expression(
self,
expression: str,
metric_values: dict[str, float],
) -> float:
Expand All @@ -1153,12 +1267,12 @@ def evaluate_metric_expression(self,

# Allowed operators
ALLOWED_OPERATORS = {
ast.Add: operator.add, # +
ast.Sub: operator.sub, # -
ast.Mult: operator.mul, # *
ast.Div: operator.truediv, # /
ast.Pow: operator.pow, # **
ast.USub: operator.neg, # -x (unary)
ast.Add: operator.add, # +
ast.Sub: operator.sub, # -
ast.Mult: operator.mul, # *
ast.Div: operator.truediv, # /
ast.Pow: operator.pow, # **
ast.USub: operator.neg, # -x (unary)
}

def _eval(node: ast.AST) -> float:
Expand All @@ -1184,7 +1298,9 @@ def _eval(node: ast.AST) -> float:
case ast.UnaryOp(op=op, operand=operand):
op_type = type(op)
if op_type not in ALLOWED_OPERATORS:
raise ValueError(f"Unary operator not allowed: {op_type.__name__}")
raise ValueError(
f"Unary operator not allowed: {op_type.__name__}"
)
return ALLOWED_OPERATORS[op_type](_eval(operand))

case _:
Expand All @@ -1195,7 +1311,6 @@ def _eval(node: ast.AST) -> float:
tree = ast.parse(expression, mode="eval")
return _eval(tree.body)


def _extract_objectives(
self, benchmark_result: dict, optimization_config=None
) -> list[float]:
Expand Down Expand Up @@ -1247,9 +1362,7 @@ def _extract_objectives(
metric_values[dict_key] = float_value

# Evaluate the objective expression against the metric values
objective_value = self.evaluate_metric_expression(
obj.metric, metric_values
)
objective_value = self.evaluate_metric_expression(obj.metric, metric_values)
objective_values.append(objective_value)

return objective_values
Expand Down
17 changes: 17 additions & 0 deletions docs/configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,22 @@ Number of optimization trials to run. Each trial tests one parameter combination
- **Production**: 100-500 trials for thorough optimization
- **Multi-objective**: Typically needs 2x more trials than single-objective

#### `n_repeats` (integer, optional)
Number of GuideLLM benchmark runs executed for each trial configuration before reporting results to Optuna. Default: `1` (same behavior as before).

- **Budget**: `n_trials` still counts unique parameter configurations explored by the sampler. Total benchmark runs are approximately `n_trials × n_repeats`.
- **Execution**: One vLLM server start per trial; benchmarks run back-to-back on the same server. If any repeat fails, the whole trial is marked failed.
- **Optuna objective**: Mean of the repeat objective values.
- **Storage**: When `n_repeats > 1`, averaged metrics are stored in `detailed_metrics`, with per-run values under `detailed_metrics.repeats`.

Example:

```yaml
optimization:
n_trials: 50
n_repeats: 3
```

#### `n_startup_trials` (integer, optional)
Number of random trials to run before starting the main sampler algorithm. Only supported by some samplers (TPE, BoTorch). Helps initialize the sampler with diverse data points.

Expand All @@ -186,6 +202,7 @@ Extra benchmark scalars to copy onto each **Optuna trial** as [user attributes](
- **Semantics**: This does **not** change the optimization objective. It only stores additional numbers on the trial record after a successful benchmark.
- **Identifiers**: Each list entry must be a single metric id in the same `<metric>_<percentile>` form as in objective expressions (see **`objectives`** above), e.g. `request_latency_p95`, `output_tokens_per_second_median`. Allowed names are exactly the combined identifiers derived from the base metrics and percentiles documented for objectives.
- **Storage**: For each configured name, the runner writes `trial.set_user_attr("metric_<name>", float_value)` using the value from the trial’s `detailed_metrics`. If a name is missing from `detailed_metrics`, or the value cannot be converted to a float, a warning is logged and that attribute is skipped.
- **Repeats**: When `n_repeats > 1`, each listed metric also gets `metric_<name>_rel_range` (relative range: `(max - min) / abs(mean)`) and `metric_<name>_values` (list of per-run floats). The trial also stores `n_repeats`.
- **Trials**: Applied to **optimization** and **baseline** trials when the run succeeds and detailed metrics are present. Omitted or unset `log_metrics` is treated as an empty list.

Example:
Expand Down
Loading
Loading