From 805cddc204c506937d8fd3283ed122db53abc5d5 Mon Sep 17 00:00:00 2001 From: JoshuaVulcan <38018017+JoshuaVulcan@users.noreply.github.com> Date: Tue, 10 Feb 2026 16:13:57 -0800 Subject: [PATCH 1/2] feat: add v2 subject tracks and async track methods (ERA-12680) - Add `version` parameter to `ERClient.get_subject_tracks()` supporting both v1 (flat coordinates, default) and v2 (segmented GeoJSON FeatureCollection at /api/v2.0/subject/{id}/tracks/). - Add `_er_url_versioned()` helper to both clients for building versioned API URLs. - Add async `get_subject_tracks()` with identical v1/v2 support. - Add async `get_subject_source_tracks()` for per-source track retrieval. - Fix async `_call()` to handle full URLs (skip `_er_url()` prepend). - v2 forwards extra query params: show_excluded, group_by_flags, max_speed_kmh, max_gap_ms, max_gap_seconds, max_gap_minutes. - Add comprehensive sync and async test suites (22 new tests). Co-authored-by: Cursor --- erclient/client.py | 90 +++++- tests/async_client/test_get_subject_tracks.py | 277 ++++++++++++++++++ tests/sync_client/__init__.py | 0 tests/sync_client/conftest.py | 21 ++ tests/sync_client/test_get_subject_tracks.py | 232 +++++++++++++++ 5 files changed, 618 insertions(+), 2 deletions(-) create mode 100644 tests/async_client/test_get_subject_tracks.py create mode 100644 tests/sync_client/__init__.py create mode 100644 tests/sync_client/conftest.py create mode 100644 tests/sync_client/test_get_subject_tracks.py diff --git a/erclient/client.py b/erclient/client.py index 02bbbc0..16b9bf1 100644 --- a/erclient/client.py +++ b/erclient/client.py @@ -145,6 +145,17 @@ def _token_request(self, payload): def _er_url(self, path): return '/'.join((self.service_root, path)) + def _er_url_versioned(self, path, version='1.0'): + """Build an API URL for a specific API version. + + Replaces the version portion of service_root (e.g. v1.0 -> v2.0). + Falls back to the default service_root when version is '1.0'. + """ + if version == '1.0': + return self._er_url(path) + versioned_root = re.sub(r'/api/v[^/]+', f'/api/v{version}', self.service_root) + return '/'.join((versioned_root, path)) + def _get(self, path, stream=False, max_retries=5, seconds_between_attempts=5, **kwargs): headers = {'User-Agent': self.user_agent} @@ -819,9 +830,19 @@ def get_source_provider(self, provider_key): return None - def get_subject_tracks(self, subject_id='', start=None, end=None): + def get_subject_tracks(self, subject_id='', start=None, end=None, version='1.0', **kwargs): """ Get the latest tracks for the Subject having the given subject_id. + + :param subject_id: The UUID of the subject. + :param start: datetime lower-bound filter (sent as ``since``). + :param end: datetime upper-bound filter (sent as ``until``). + :param version: API version string, either '1.0' (default, legacy flat + coordinates) or '2.0' (segmented GeoJSON FeatureCollection). + :param kwargs: Extra query params forwarded to the v2 endpoint such as + ``show_excluded``, ``group_by_flags``, ``max_speed_kmh``, + ``max_gap_ms``, ``max_gap_seconds``, ``max_gap_minutes``. + :return: Track data (format varies by version). """ p = {} if start is not None and isinstance(start, datetime): @@ -829,6 +850,15 @@ def get_subject_tracks(self, subject_id='', start=None, end=None): if end is not None and isinstance(end, datetime): p['until'] = end.isoformat() + if version == '2.0': + # v2 supports additional filter params + for key in ('show_excluded', 'group_by_flags', 'max_speed_kmh', + 'max_gap_ms', 'max_gap_seconds', 'max_gap_minutes'): + if key in kwargs: + p[key] = kwargs[key] + url = self._er_url_versioned(f'subject/{subject_id}/tracks/', version='2.0') + return self._get(path=url, params=p) + return self._get(path='subject/{0}/tracks'.format(subject_id), params=p) def get_subject_source_tracks(self, subject_id='', src_id='', start=None): @@ -1280,6 +1310,17 @@ async def _token_request(self, payload): def _er_url(self, path): return '/'.join((self.service_root, path)) + def _er_url_versioned(self, path, version='1.0'): + """Build an API URL for a specific API version. + + Replaces the version portion of service_root (e.g. v1.0 -> v2.0). + Falls back to the default service_root when version is '1.0'. + """ + if version == '1.0': + return self._er_url(path) + versioned_root = re.sub(r'/api/v[^/]+', f'/api/v{version}', self.service_root) + return '/'.join((versioned_root, path)) + async def _post_form(self, path, body=None, files=None): try: @@ -1381,6 +1422,50 @@ async def get_source_assignments(self, subject_ids: List[str] = None, source_ids return await self._get(f'subjectsources', params=params) + async def get_subject_tracks(self, subject_id='', start=None, end=None, version='1.0', **kwargs): + """ + Get tracks for the Subject having the given subject_id. + + :param subject_id: The UUID of the subject. + :param start: datetime lower-bound filter (sent as ``since``). + :param end: datetime upper-bound filter (sent as ``until``). + :param version: API version string, either '1.0' (default, legacy flat + coordinates) or '2.0' (segmented GeoJSON FeatureCollection). + :param kwargs: Extra query params forwarded to the v2 endpoint such as + ``show_excluded``, ``group_by_flags``, ``max_speed_kmh``, + ``max_gap_ms``, ``max_gap_seconds``, ``max_gap_minutes``. + :return: Track data (format varies by version). + """ + p = {} + if start is not None and isinstance(start, datetime): + p['since'] = start.isoformat() + if end is not None and isinstance(end, datetime): + p['until'] = end.isoformat() + + if version == '2.0': + for key in ('show_excluded', 'group_by_flags', 'max_speed_kmh', + 'max_gap_ms', 'max_gap_seconds', 'max_gap_minutes'): + if key in kwargs: + p[key] = kwargs[key] + url = self._er_url_versioned(f'subject/{subject_id}/tracks/', version='2.0') + return await self._get(path=url, params=p) + + return await self._get(path=f'subject/{subject_id}/tracks', params=p) + + async def get_subject_source_tracks(self, subject_id='', src_id='', start=None): + """ + Get the latest tracks for the Subject having the given subject_id and a source ID. + + :param subject_id: The subject UUID + :param src_id: The source UUID + :param start: Optional datetime lower-bound filter (sent as ``since``) + :return: Track data + """ + p = {} + if start and isinstance(start, datetime): + p['since'] = start.isoformat() + return await self._get(path=f'subject/{subject_id}/source/{src_id}/tracks', params=p) + async def get_feature_group(self, feature_group_id: str): """ Get a feature group by id @@ -1444,10 +1529,11 @@ async def _call(self, path, payload, method, params=None): 'User-Agent': self.user_agent, **auth_headers } + url = path if path.startswith("http") else self._er_url(path) try: response = await self._http_session.request( method, - self._er_url(path), + url, # payload is automatically encoded as json data json=payload if method in [ "POST", "PUT", "PATCH"] else None, diff --git a/tests/async_client/test_get_subject_tracks.py b/tests/async_client/test_get_subject_tracks.py new file mode 100644 index 0000000..075a168 --- /dev/null +++ b/tests/async_client/test_get_subject_tracks.py @@ -0,0 +1,277 @@ +import httpx +import pytest +import respx +from datetime import datetime, timezone + + +@pytest.fixture +def v1_tracks_response(): + """Typical v1 subject tracks response (flat coordinate pairs).""" + return { + "data": { + "type": "FeatureCollection", + "features": [ + { + "type": "Feature", + "geometry": { + "type": "LineString", + "coordinates": [ + [36.7922, -1.2932], + [36.7921, -1.2931], + [36.7919, -1.2931], + ], + }, + "properties": { + "title": "Test Subject", + }, + } + ], + } + } + + +@pytest.fixture +def v2_tracks_response(): + """Typical v2 segmented tracks response (GeoJSON FeatureCollection).""" + return { + "data": { + "type": "FeatureCollection", + "features": [ + { + "type": "Feature", + "geometry": { + "type": "LineString", + "coordinates": [ + [36.7922, -1.2932], + [36.7921, -1.2931], + ], + }, + "properties": { + "count": 2, + "start_time": "2025-01-01T00:00:00Z", + "end_time": "2025-01-01T01:00:00Z", + }, + }, + { + "type": "Feature", + "geometry": { + "type": "LineString", + "coordinates": [ + [36.7919, -1.2931], + [36.7917, -1.2931], + ], + }, + "properties": { + "count": 2, + "start_time": "2025-01-01T02:00:00Z", + "end_time": "2025-01-01T03:00:00Z", + }, + }, + ], + } + } + + +SUBJECT_ID = "a1b2c3d4-e5f6-7890-abcd-ef1234567890" + + +# ── v1 (default) tracks ────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_get_subject_tracks_v1_default(er_client, v1_tracks_response): + """get_subject_tracks() with default version hits the v1 endpoint.""" + async with respx.mock( + base_url=er_client.service_root, assert_all_called=False + ) as respx_mock: + route = respx_mock.get(f"subject/{SUBJECT_ID}/tracks") + route.return_value = httpx.Response( + httpx.codes.OK, json=v1_tracks_response + ) + + result = await er_client.get_subject_tracks(subject_id=SUBJECT_ID) + + assert route.called + assert result == v1_tracks_response["data"] + await er_client.close() + + +@pytest.mark.asyncio +async def test_get_subject_tracks_v1_with_dates(er_client, v1_tracks_response): + """v1 tracks pass since/until params when start/end are given.""" + async with respx.mock( + base_url=er_client.service_root, assert_all_called=False + ) as respx_mock: + route = respx_mock.get(f"subject/{SUBJECT_ID}/tracks") + route.return_value = httpx.Response( + httpx.codes.OK, json=v1_tracks_response + ) + + start = datetime(2025, 1, 1, tzinfo=timezone.utc) + end = datetime(2025, 1, 2, tzinfo=timezone.utc) + result = await er_client.get_subject_tracks( + subject_id=SUBJECT_ID, start=start, end=end + ) + + assert route.called + request = route.calls[0].request + assert "since" in str(request.url) + assert "until" in str(request.url) + await er_client.close() + + +# ── v2 segmented tracks ────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_get_subject_tracks_v2(er_client, v2_tracks_response): + """get_subject_tracks(version='2.0') hits the v2 segmented endpoint.""" + v2_base = er_client.service_root.replace("v1.0", "v2.0") + async with respx.mock(assert_all_called=False) as respx_mock: + route = respx_mock.get( + f"{v2_base}/subject/{SUBJECT_ID}/tracks/" + ).respond(httpx.codes.OK, json=v2_tracks_response) + + result = await er_client.get_subject_tracks( + subject_id=SUBJECT_ID, version="2.0" + ) + + assert route.called + assert result == v2_tracks_response["data"] + await er_client.close() + + +@pytest.mark.asyncio +async def test_get_subject_tracks_v2_with_dates(er_client, v2_tracks_response): + """v2 tracks pass since/until when start/end are given.""" + v2_base = er_client.service_root.replace("v1.0", "v2.0") + async with respx.mock(assert_all_called=False) as respx_mock: + route = respx_mock.get( + f"{v2_base}/subject/{SUBJECT_ID}/tracks/" + ).respond(httpx.codes.OK, json=v2_tracks_response) + + start = datetime(2025, 1, 1, tzinfo=timezone.utc) + end = datetime(2025, 1, 2, tzinfo=timezone.utc) + result = await er_client.get_subject_tracks( + subject_id=SUBJECT_ID, start=start, end=end, version="2.0" + ) + + assert route.called + request = route.calls[0].request + assert "since" in str(request.url) + assert "until" in str(request.url) + await er_client.close() + + +@pytest.mark.asyncio +async def test_get_subject_tracks_v2_with_extra_params(er_client, v2_tracks_response): + """v2 tracks forward additional filter params (show_excluded, max_speed_kmh, etc.).""" + v2_base = er_client.service_root.replace("v1.0", "v2.0") + async with respx.mock(assert_all_called=False) as respx_mock: + route = respx_mock.get( + f"{v2_base}/subject/{SUBJECT_ID}/tracks/" + ).respond(httpx.codes.OK, json=v2_tracks_response) + + result = await er_client.get_subject_tracks( + subject_id=SUBJECT_ID, + version="2.0", + show_excluded="true", + max_speed_kmh=120.0, + max_gap_minutes=30, + ) + + assert route.called + request = route.calls[0].request + url_str = str(request.url) + assert "show_excluded" in url_str + assert "max_speed_kmh" in url_str + assert "max_gap_minutes" in url_str + await er_client.close() + + +# ── get_subject_source_tracks (async) ───────────────────────────── + +@pytest.mark.asyncio +async def test_get_subject_source_tracks(er_client, v1_tracks_response): + """get_subject_source_tracks() hits the correct endpoint.""" + source_id = "bbbb1111-2222-3333-4444-555566667777" + async with respx.mock( + base_url=er_client.service_root, assert_all_called=False + ) as respx_mock: + route = respx_mock.get( + f"subject/{SUBJECT_ID}/source/{source_id}/tracks" + ) + route.return_value = httpx.Response( + httpx.codes.OK, json=v1_tracks_response + ) + + result = await er_client.get_subject_source_tracks( + subject_id=SUBJECT_ID, src_id=source_id + ) + + assert route.called + assert result == v1_tracks_response["data"] + await er_client.close() + + +@pytest.mark.asyncio +async def test_get_subject_source_tracks_with_since(er_client, v1_tracks_response): + """get_subject_source_tracks() passes since param.""" + source_id = "bbbb1111-2222-3333-4444-555566667777" + async with respx.mock( + base_url=er_client.service_root, assert_all_called=False + ) as respx_mock: + route = respx_mock.get( + f"subject/{SUBJECT_ID}/source/{source_id}/tracks" + ) + route.return_value = httpx.Response( + httpx.codes.OK, json=v1_tracks_response + ) + + start = datetime(2025, 1, 1, tzinfo=timezone.utc) + result = await er_client.get_subject_source_tracks( + subject_id=SUBJECT_ID, src_id=source_id, start=start + ) + + assert route.called + request = route.calls[0].request + assert "since" in str(request.url) + await er_client.close() + + +# ── Error handling ──────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_get_subject_tracks_not_found(er_client, not_found_response): + """get_subject_tracks() raises ERClientNotFound on 404.""" + from erclient import ERClientNotFound + + async with respx.mock( + base_url=er_client.service_root, assert_all_called=False + ) as respx_mock: + route = respx_mock.get(f"subject/{SUBJECT_ID}/tracks") + route.return_value = httpx.Response( + httpx.codes.NOT_FOUND, json=not_found_response + ) + + with pytest.raises(ERClientNotFound): + await er_client.get_subject_tracks(subject_id=SUBJECT_ID) + + await er_client.close() + + +@pytest.mark.asyncio +async def test_get_subject_tracks_v2_not_found(er_client, not_found_response): + """get_subject_tracks(version='2.0') raises ERClientNotFound on 404.""" + from erclient import ERClientNotFound + + v2_base = er_client.service_root.replace("v1.0", "v2.0") + async with respx.mock(assert_all_called=False) as respx_mock: + route = respx_mock.get( + f"{v2_base}/subject/{SUBJECT_ID}/tracks/" + ).respond(httpx.codes.NOT_FOUND, json=not_found_response) + + with pytest.raises(ERClientNotFound): + await er_client.get_subject_tracks( + subject_id=SUBJECT_ID, version="2.0" + ) + + await er_client.close() diff --git a/tests/sync_client/__init__.py b/tests/sync_client/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/sync_client/conftest.py b/tests/sync_client/conftest.py new file mode 100644 index 0000000..6d9d6da --- /dev/null +++ b/tests/sync_client/conftest.py @@ -0,0 +1,21 @@ +import pytest + +from erclient.client import ERClient + + +@pytest.fixture +def er_server_info(): + return { + "service_root": "https://fake-site.erdomain.org/api/v1.0", + "username": "test", + "password": "test", + "token": "1110c87681cd1d12ad07c2d0f57d15d6079ae5d8", + "token_url": "https://fake-auth.erdomain.org/oauth2/token", + "client_id": "das_web_client", + "provider_key": "testintegration", + } + + +@pytest.fixture +def er_client(er_server_info): + return ERClient(**er_server_info) diff --git a/tests/sync_client/test_get_subject_tracks.py b/tests/sync_client/test_get_subject_tracks.py new file mode 100644 index 0000000..5c7d603 --- /dev/null +++ b/tests/sync_client/test_get_subject_tracks.py @@ -0,0 +1,232 @@ +import json +from datetime import datetime, timezone +from unittest.mock import MagicMock, patch + +import pytest + +from erclient.client import ERClient + + +SUBJECT_ID = "a1b2c3d4-e5f6-7890-abcd-ef1234567890" +SOURCE_ID = "bbbb1111-2222-3333-4444-555566667777" + + +@pytest.fixture +def v1_tracks_response(): + return { + "data": { + "type": "FeatureCollection", + "features": [ + { + "type": "Feature", + "geometry": { + "type": "LineString", + "coordinates": [ + [36.7922, -1.2932], + [36.7921, -1.2931], + [36.7919, -1.2931], + ], + }, + "properties": {"title": "Test Subject"}, + } + ], + } + } + + +@pytest.fixture +def v2_tracks_response(): + return { + "data": { + "type": "FeatureCollection", + "features": [ + { + "type": "Feature", + "geometry": { + "type": "LineString", + "coordinates": [ + [36.7922, -1.2932], + [36.7921, -1.2931], + ], + }, + "properties": { + "count": 2, + "start_time": "2025-01-01T00:00:00Z", + "end_time": "2025-01-01T01:00:00Z", + }, + }, + { + "type": "Feature", + "geometry": { + "type": "LineString", + "coordinates": [ + [36.7919, -1.2931], + [36.7917, -1.2931], + ], + }, + "properties": { + "count": 2, + "start_time": "2025-01-01T02:00:00Z", + "end_time": "2025-01-01T03:00:00Z", + }, + }, + ], + } + } + + +def _mock_response(json_body, status_code=200): + """Create a mock requests.Response with the given JSON body.""" + mock_resp = MagicMock() + mock_resp.ok = status_code < 400 + mock_resp.status_code = status_code + mock_resp.text = json.dumps(json_body) + mock_resp.json.return_value = json_body + return mock_resp + + +# ── v1 (default) tracks ────────────────────────────────────────── + +def test_get_subject_tracks_v1_default(er_client, v1_tracks_response): + """get_subject_tracks() with default version hits the v1 endpoint.""" + with patch.object(er_client._http_session, "get") as mock_get: + mock_get.return_value = _mock_response(v1_tracks_response) + + result = er_client.get_subject_tracks(subject_id=SUBJECT_ID) + + assert mock_get.called + call_args = mock_get.call_args + assert f"subject/{SUBJECT_ID}/tracks" in call_args[0][0] + assert "v1.0" in call_args[0][0] + assert result == v1_tracks_response["data"] + + +def test_get_subject_tracks_v1_with_dates(er_client, v1_tracks_response): + """v1 tracks pass since/until params when start/end are given.""" + with patch.object(er_client._http_session, "get") as mock_get: + mock_get.return_value = _mock_response(v1_tracks_response) + + start = datetime(2025, 1, 1, tzinfo=timezone.utc) + end = datetime(2025, 1, 2, tzinfo=timezone.utc) + er_client.get_subject_tracks(subject_id=SUBJECT_ID, start=start, end=end) + + call_kwargs = mock_get.call_args + params = call_kwargs.kwargs.get("params") or call_kwargs[1].get("params", {}) + assert "since" in params + assert "until" in params + + +# ── v2 segmented tracks ────────────────────────────────────────── + +def test_get_subject_tracks_v2(er_client, v2_tracks_response): + """get_subject_tracks(version='2.0') hits the v2 segmented endpoint.""" + with patch.object(er_client._http_session, "get") as mock_get: + mock_get.return_value = _mock_response(v2_tracks_response) + + result = er_client.get_subject_tracks( + subject_id=SUBJECT_ID, version="2.0" + ) + + assert mock_get.called + call_url = mock_get.call_args[0][0] + assert "v2.0" in call_url + assert f"subject/{SUBJECT_ID}/tracks/" in call_url + assert result == v2_tracks_response["data"] + + +def test_get_subject_tracks_v2_with_dates(er_client, v2_tracks_response): + """v2 tracks pass since/until when start/end are given.""" + with patch.object(er_client._http_session, "get") as mock_get: + mock_get.return_value = _mock_response(v2_tracks_response) + + start = datetime(2025, 1, 1, tzinfo=timezone.utc) + end = datetime(2025, 1, 2, tzinfo=timezone.utc) + er_client.get_subject_tracks( + subject_id=SUBJECT_ID, start=start, end=end, version="2.0" + ) + + call_kwargs = mock_get.call_args + params = call_kwargs.kwargs.get("params") or call_kwargs[1].get("params", {}) + assert "since" in params + assert "until" in params + + +def test_get_subject_tracks_v2_extra_params(er_client, v2_tracks_response): + """v2 tracks forward additional filter params.""" + with patch.object(er_client._http_session, "get") as mock_get: + mock_get.return_value = _mock_response(v2_tracks_response) + + er_client.get_subject_tracks( + subject_id=SUBJECT_ID, + version="2.0", + show_excluded="true", + max_speed_kmh=120.0, + max_gap_minutes=30, + ) + + call_kwargs = mock_get.call_args + params = call_kwargs.kwargs.get("params") or call_kwargs[1].get("params", {}) + assert params.get("show_excluded") == "true" + assert params.get("max_speed_kmh") == 120.0 + assert params.get("max_gap_minutes") == 30 + + +def test_get_subject_tracks_v2_ignores_unknown_kwargs(er_client, v2_tracks_response): + """Unknown kwargs are not forwarded as query params.""" + with patch.object(er_client._http_session, "get") as mock_get: + mock_get.return_value = _mock_response(v2_tracks_response) + + er_client.get_subject_tracks( + subject_id=SUBJECT_ID, + version="2.0", + unknown_param="should_not_appear", + ) + + call_kwargs = mock_get.call_args + params = call_kwargs.kwargs.get("params") or call_kwargs[1].get("params", {}) + assert "unknown_param" not in params + + +# ── _er_url_versioned helper ────────────────────────────────────── + +def test_er_url_versioned_v1(er_client): + """_er_url_versioned('path', '1.0') returns the normal v1 URL.""" + url = er_client._er_url_versioned("subject/abc/tracks", version="1.0") + assert url == "https://fake-site.erdomain.org/api/v1.0/subject/abc/tracks" + + +def test_er_url_versioned_v2(er_client): + """_er_url_versioned('path', '2.0') swaps the version in the URL.""" + url = er_client._er_url_versioned("subject/abc/tracks/", version="2.0") + assert url == "https://fake-site.erdomain.org/api/v2.0/subject/abc/tracks/" + + +# ── get_subject_source_tracks ───────────────────────────────────── + +def test_get_subject_source_tracks(er_client, v1_tracks_response): + """get_subject_source_tracks() hits the correct endpoint.""" + with patch.object(er_client._http_session, "get") as mock_get: + mock_get.return_value = _mock_response(v1_tracks_response) + + result = er_client.get_subject_source_tracks( + subject_id=SUBJECT_ID, src_id=SOURCE_ID + ) + + call_url = mock_get.call_args[0][0] + assert f"subject/{SUBJECT_ID}/source/{SOURCE_ID}/tracks" in call_url + assert result == v1_tracks_response["data"] + + +def test_get_subject_source_tracks_with_since(er_client, v1_tracks_response): + """get_subject_source_tracks() passes since param.""" + with patch.object(er_client._http_session, "get") as mock_get: + mock_get.return_value = _mock_response(v1_tracks_response) + + start = datetime(2025, 1, 1, tzinfo=timezone.utc) + er_client.get_subject_source_tracks( + subject_id=SUBJECT_ID, src_id=SOURCE_ID, start=start + ) + + call_kwargs = mock_get.call_args + params = call_kwargs.kwargs.get("params") or call_kwargs[1].get("params", {}) + assert "since" in params From 1642dd5d76fdc13f254bd2bca4881d3c54cbe4f7 Mon Sep 17 00:00:00 2001 From: JoshuaVulcan <38018017+JoshuaVulcan@users.noreply.github.com> Date: Wed, 11 Feb 2026 08:46:18 -0800 Subject: [PATCH 2/2] refactor: replace _er_url_versioned() regex with _api_root() + base_url pattern Replace the brittle `_er_url_versioned()` method (which uses regex substitution on service_root) with the cleaner `_api_root(version)` + `base_url` parameter pattern from PR #23. This adds `_api_root()`, updates `_er_url()` to accept a `base_url` param, and threads `base_url` through `_get`, `_post`, `_patch`, and `_call` on both sync and async clients. Updates get_subject_tracks(version='2.0') to use `base_url=self._api_root('v2.0')` instead of building versioned URLs via regex. Updates tests to use `_api_root()` instead of manual string replacement. Co-authored-by: Cursor --- erclient/client.py | 87 +++++++++---------- tests/async_client/test_get_subject_tracks.py | 8 +- tests/sync_client/test_get_subject_tracks.py | 18 ++-- 3 files changed, 56 insertions(+), 57 deletions(-) diff --git a/erclient/client.py b/erclient/client.py index 16b9bf1..7c18781 100644 --- a/erclient/client.py +++ b/erclient/client.py @@ -142,26 +142,22 @@ def _token_request(self, payload): self.auth_expires = pytz.utc.localize(datetime.min) return False - def _er_url(self, path): - return '/'.join((self.service_root, path)) + def _api_root(self, version='v1.0'): + """Return the full API root URL for the given version (e.g. {base}/api/v1.0).""" + base = re.sub(r'/api(/v[^/]+)?/?$', '', self.service_root.rstrip('/')) + return f"{base}/api/{version}" - def _er_url_versioned(self, path, version='1.0'): - """Build an API URL for a specific API version. + def _er_url(self, path, base_url=None): + if base_url is None: + base_url = self.service_root + return '/'.join((base_url.rstrip('/'), path.lstrip('/'))) - Replaces the version portion of service_root (e.g. v1.0 -> v2.0). - Falls back to the default service_root when version is '1.0'. - """ - if version == '1.0': - return self._er_url(path) - versioned_root = re.sub(r'/api/v[^/]+', f'/api/v{version}', self.service_root) - return '/'.join((versioned_root, path)) - - def _get(self, path, stream=False, max_retries=5, seconds_between_attempts=5, **kwargs): + def _get(self, path, base_url=None, stream=False, max_retries=5, seconds_between_attempts=5, **kwargs): headers = {'User-Agent': self.user_agent} headers.update(self.auth_headers()) if (not path.startswith("http")): - path = self._er_url(path) + path = self._er_url(path, base_url) attempts = 0 while (attempts <= max_retries): @@ -213,7 +209,7 @@ def _get(self, path, stream=False, max_retries=5, seconds_between_attempts=5, ** f"Failed to call ER web service at {response.url} after {attempts} tries. {response.status_code} {response.text}") time.sleep(seconds_between_attempts) - def _call(self, path, payload, method, params=None): + def _call(self, path, payload, method, params=None, base_url=None): headers = {'Content-Type': 'application/json', 'User-Agent': self.user_agent} headers.update(self.auth_headers()) @@ -235,7 +231,8 @@ def time_converter(t): except KeyError: self.logger.error('method must be one of...') else: - response = fn(self._er_url(path), data=body, + url = self._er_url(path, base_url) + response = fn(url, data=body, headers=headers, params=params) if response and response.ok: @@ -277,11 +274,11 @@ def time_converter(t): raise ERClientException( f"Failed to {fn} to ER web service. {message}") - def _post(self, path, payload, params=None): - return self._call(path, payload, "POST", params) + def _post(self, path, payload, params=None, base_url=None): + return self._call(path, payload, "POST", params, base_url=base_url) - def _patch(self, path, payload, params=None): - return self._call(path, payload, "PATCH", params) + def _patch(self, path, payload, params=None, base_url=None): + return self._call(path, payload, "PATCH", params, base_url=base_url) def add_event_to_incident(self, event_id, incident_id): @@ -856,8 +853,11 @@ def get_subject_tracks(self, subject_id='', start=None, end=None, version='1.0', 'max_gap_ms', 'max_gap_seconds', 'max_gap_minutes'): if key in kwargs: p[key] = kwargs[key] - url = self._er_url_versioned(f'subject/{subject_id}/tracks/', version='2.0') - return self._get(path=url, params=p) + return self._get( + path=f'subject/{subject_id}/tracks/', + base_url=self._api_root('v2.0'), + params=p, + ) return self._get(path='subject/{0}/tracks'.format(subject_id), params=p) @@ -1307,19 +1307,15 @@ async def _token_request(self, payload): tz=timezone.utc) + timedelta(seconds=expires_in) return True - def _er_url(self, path): - return '/'.join((self.service_root, path)) + def _api_root(self, version='v1.0'): + """Return the full API root URL for the given version (e.g. {base}/api/v1.0).""" + base = re.sub(r'/api(/v[^/]+)?/?$', '', self.service_root.rstrip('/')) + return f"{base}/api/{version}" - def _er_url_versioned(self, path, version='1.0'): - """Build an API URL for a specific API version. - - Replaces the version portion of service_root (e.g. v1.0 -> v2.0). - Falls back to the default service_root when version is '1.0'. - """ - if version == '1.0': - return self._er_url(path) - versioned_root = re.sub(r'/api/v[^/]+', f'/api/v{version}', self.service_root) - return '/'.join((versioned_root, path)) + def _er_url(self, path, base_url=None): + if base_url is None: + base_url = self.service_root + return '/'.join((base_url.rstrip('/'), path.lstrip('/'))) async def _post_form(self, path, body=None, files=None): @@ -1447,8 +1443,11 @@ async def get_subject_tracks(self, subject_id='', start=None, end=None, version= 'max_gap_ms', 'max_gap_seconds', 'max_gap_minutes'): if key in kwargs: p[key] = kwargs[key] - url = self._er_url_versioned(f'subject/{subject_id}/tracks/', version='2.0') - return await self._get(path=url, params=p) + return await self._get( + path=f'subject/{subject_id}/tracks/', + base_url=self._api_root('v2.0'), + params=p, + ) return await self._get(path=f'subject/{subject_id}/tracks', params=p) @@ -1508,16 +1507,16 @@ async def _get_data(self, endpoint, params, batch_size=0): else: break - async def _get(self, path, params=None): - return await self._call(path=path, payload=None, method="GET", params=params) + async def _get(self, path, params=None, base_url=None): + return await self._call(path=path, payload=None, method="GET", params=params, base_url=base_url) - async def _post(self, path, payload, params=None): - return await self._call(path, payload, "POST", params) + async def _post(self, path, payload, params=None, base_url=None): + return await self._call(path, payload, "POST", params, base_url=base_url) - async def _patch(self, path, payload, params=None): - return await self._call(path, payload, "PATCH", params) + async def _patch(self, path, payload, params=None, base_url=None): + return await self._call(path, payload, "PATCH", params, base_url=base_url) - async def _call(self, path, payload, method, params=None): + async def _call(self, path, payload, method, params=None, base_url=None): try: auth_headers = await self.auth_headers() except httpx.HTTPStatusError as e: @@ -1529,7 +1528,7 @@ async def _call(self, path, payload, method, params=None): 'User-Agent': self.user_agent, **auth_headers } - url = path if path.startswith("http") else self._er_url(path) + url = path if path.startswith('http') else self._er_url(path, base_url) try: response = await self._http_session.request( method, diff --git a/tests/async_client/test_get_subject_tracks.py b/tests/async_client/test_get_subject_tracks.py index 075a168..5e162c9 100644 --- a/tests/async_client/test_get_subject_tracks.py +++ b/tests/async_client/test_get_subject_tracks.py @@ -124,7 +124,7 @@ async def test_get_subject_tracks_v1_with_dates(er_client, v1_tracks_response): @pytest.mark.asyncio async def test_get_subject_tracks_v2(er_client, v2_tracks_response): """get_subject_tracks(version='2.0') hits the v2 segmented endpoint.""" - v2_base = er_client.service_root.replace("v1.0", "v2.0") + v2_base = er_client._api_root("v2.0") async with respx.mock(assert_all_called=False) as respx_mock: route = respx_mock.get( f"{v2_base}/subject/{SUBJECT_ID}/tracks/" @@ -142,7 +142,7 @@ async def test_get_subject_tracks_v2(er_client, v2_tracks_response): @pytest.mark.asyncio async def test_get_subject_tracks_v2_with_dates(er_client, v2_tracks_response): """v2 tracks pass since/until when start/end are given.""" - v2_base = er_client.service_root.replace("v1.0", "v2.0") + v2_base = er_client._api_root("v2.0") async with respx.mock(assert_all_called=False) as respx_mock: route = respx_mock.get( f"{v2_base}/subject/{SUBJECT_ID}/tracks/" @@ -164,7 +164,7 @@ async def test_get_subject_tracks_v2_with_dates(er_client, v2_tracks_response): @pytest.mark.asyncio async def test_get_subject_tracks_v2_with_extra_params(er_client, v2_tracks_response): """v2 tracks forward additional filter params (show_excluded, max_speed_kmh, etc.).""" - v2_base = er_client.service_root.replace("v1.0", "v2.0") + v2_base = er_client._api_root("v2.0") async with respx.mock(assert_all_called=False) as respx_mock: route = respx_mock.get( f"{v2_base}/subject/{SUBJECT_ID}/tracks/" @@ -263,7 +263,7 @@ async def test_get_subject_tracks_v2_not_found(er_client, not_found_response): """get_subject_tracks(version='2.0') raises ERClientNotFound on 404.""" from erclient import ERClientNotFound - v2_base = er_client.service_root.replace("v1.0", "v2.0") + v2_base = er_client._api_root("v2.0") async with respx.mock(assert_all_called=False) as respx_mock: route = respx_mock.get( f"{v2_base}/subject/{SUBJECT_ID}/tracks/" diff --git a/tests/sync_client/test_get_subject_tracks.py b/tests/sync_client/test_get_subject_tracks.py index 5c7d603..4af154f 100644 --- a/tests/sync_client/test_get_subject_tracks.py +++ b/tests/sync_client/test_get_subject_tracks.py @@ -187,18 +187,18 @@ def test_get_subject_tracks_v2_ignores_unknown_kwargs(er_client, v2_tracks_respo assert "unknown_param" not in params -# ── _er_url_versioned helper ────────────────────────────────────── +# ── _api_root helper ────────────────────────────────────────────── -def test_er_url_versioned_v1(er_client): - """_er_url_versioned('path', '1.0') returns the normal v1 URL.""" - url = er_client._er_url_versioned("subject/abc/tracks", version="1.0") - assert url == "https://fake-site.erdomain.org/api/v1.0/subject/abc/tracks" +def test_api_root_v1(er_client): + """_api_root('v1.0') returns the v1 API root URL.""" + url = er_client._api_root("v1.0") + assert url == "https://fake-site.erdomain.org/api/v1.0" -def test_er_url_versioned_v2(er_client): - """_er_url_versioned('path', '2.0') swaps the version in the URL.""" - url = er_client._er_url_versioned("subject/abc/tracks/", version="2.0") - assert url == "https://fake-site.erdomain.org/api/v2.0/subject/abc/tracks/" +def test_api_root_v2(er_client): + """_api_root('v2.0') returns the v2 API root URL.""" + url = er_client._api_root("v2.0") + assert url == "https://fake-site.erdomain.org/api/v2.0" # ── get_subject_source_tracks ─────────────────────────────────────