Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 23 additions & 11 deletions src/sre_agent/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,33 @@
logging.getLogger("pydantic_ai").setLevel(logging.INFO)


def _parse_time_range_minutes(raw: str) -> int:
"""Parse and validate a time range minutes value.

Args:
raw: Raw time range value from CLI args or environment.

Returns:
The validated time range in minutes.
"""
try:
minutes = int(raw)
except ValueError as exc:
print("time_range_minutes must be an integer.")
raise SystemExit(1) from exc

if minutes <= 0:
print("time_range_minutes must be greater than 0.")
raise SystemExit(1)
return minutes


def _load_request_from_args_or_env() -> tuple[str, str, int]:
"""Load diagnosis inputs from CLI args or environment."""
if len(sys.argv) >= 3:
log_group = sys.argv[1]
service_name = sys.argv[2]
time_range_minutes = int(sys.argv[3]) if len(sys.argv) > 3 else 10
time_range_minutes = _parse_time_range_minutes(sys.argv[3]) if len(sys.argv) > 3 else 10
return log_group, service_name, time_range_minutes

log_group = os.getenv("LOG_GROUP", "").strip()
Expand All @@ -35,16 +56,7 @@ def _load_request_from_args_or_env() -> tuple[str, str, int]:
)
raise SystemExit(1)

raw_time_range = os.getenv("TIME_RANGE_MINUTES", "10").strip()
try:
time_range_minutes = int(raw_time_range)
except ValueError as exc:
print("TIME_RANGE_MINUTES must be an integer.")
raise SystemExit(1) from exc

if time_range_minutes <= 0:
print("TIME_RANGE_MINUTES must be greater than 0.")
raise SystemExit(1)
time_range_minutes = _parse_time_range_minutes(os.getenv("TIME_RANGE_MINUTES", "10").strip())
return log_group, service_name, time_range_minutes


Expand Down
63 changes: 63 additions & 0 deletions tests/test_run_args.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
"""Tests for CLI and env argument parsing in sre_agent.run."""

import pytest

from sre_agent.run import _load_request_from_args_or_env, _parse_time_range_minutes


def test_parse_time_range_minutes_valid() -> None:
"""Valid integer strings are parsed."""
assert _parse_time_range_minutes("15") == 15


def test_parse_time_range_minutes_non_integer_exits(capsys: pytest.CaptureFixture[str]) -> None:
"""Non integer input exits with a friendly message rather than crashing."""
with pytest.raises(SystemExit) as excinfo:
_parse_time_range_minutes("abc")
assert excinfo.value.code == 1
assert "must be an integer" in capsys.readouterr().out


@pytest.mark.parametrize("raw", ["0", "-5"])
def test_parse_time_range_minutes_non_positive_exits(
raw: str, capsys: pytest.CaptureFixture[str]
) -> None:
"""Zero and negative values are rejected."""
with pytest.raises(SystemExit) as excinfo:
_parse_time_range_minutes(raw)
assert excinfo.value.code == 1
assert "greater than 0" in capsys.readouterr().out


def test_cli_args_invalid_time_range_exits(
monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str]
) -> None:
"""Bad CLI third argument exits cleanly instead of raising ValueError."""
monkeypatch.setattr("sys.argv", ["run.py", "log-group", "service", "abc"])
with pytest.raises(SystemExit) as excinfo:
_load_request_from_args_or_env()
assert excinfo.value.code == 1
assert "must be an integer" in capsys.readouterr().out


def test_cli_args_non_positive_time_range_exits(
monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str]
) -> None:
"""A zero or negative CLI third argument is rejected."""
monkeypatch.setattr("sys.argv", ["run.py", "log-group", "service", "0"])
with pytest.raises(SystemExit) as excinfo:
_load_request_from_args_or_env()
assert excinfo.value.code == 1
assert "greater than 0" in capsys.readouterr().out


def test_cli_args_valid() -> None:
"""Valid CLI args parse as expected."""
import sys

original = sys.argv
try:
sys.argv = ["run.py", "log-group", "service", "42"]
assert _load_request_from_args_or_env() == ("log-group", "service", 42)
finally:
sys.argv = original