diff --git a/erclient/client.py b/erclient/client.py index 49c69d2..251ed10 100644 --- a/erclient/client.py +++ b/erclient/client.py @@ -886,9 +886,19 @@ def get_source_provider(self, provider_key): return None - def get_subject_tracks(self, subject_id='', start=None, end=None): + def get_subject_tracks(self, subject_id='', start=None, end=None, version='1.0', **kwargs): """ Get the latest tracks for the Subject having the given subject_id. + + :param subject_id: The UUID of the subject. + :param start: datetime lower-bound filter (sent as ``since``). + :param end: datetime upper-bound filter (sent as ``until``). + :param version: API version string, either '1.0' (default, legacy flat + coordinates) or '2.0' (segmented GeoJSON FeatureCollection). + :param kwargs: Extra query params forwarded to the v2 endpoint such as + ``show_excluded``, ``group_by_flags``, ``max_speed_kmh``, + ``max_gap_ms``, ``max_gap_seconds``, ``max_gap_minutes``. + :return: Track data (format varies by version). """ p = {} if start is not None and isinstance(start, datetime): @@ -896,6 +906,18 @@ def get_subject_tracks(self, subject_id='', start=None, end=None): if end is not None and isinstance(end, datetime): p['until'] = end.isoformat() + if version == '2.0': + # v2 supports additional filter params + for key in ('show_excluded', 'group_by_flags', 'max_speed_kmh', + 'max_gap_ms', 'max_gap_seconds', 'max_gap_minutes'): + if key in kwargs: + p[key] = kwargs[key] + return self._get( + path=f'subject/{subject_id}/tracks/', + base_url=self._api_root('v2.0'), + params=p, + ) + return self._get(path='subject/{0}/tracks'.format(subject_id), params=p) def get_subject_source_tracks(self, subject_id='', src_id='', start=None): @@ -1601,6 +1623,53 @@ async def get_source_assignments(self, subject_ids: List[str] = None, source_ids return await self._get(f'subjectsources', params=params) + async def get_subject_tracks(self, subject_id='', start=None, end=None, version='1.0', **kwargs): + """ + Get tracks for the Subject having the given subject_id. + + :param subject_id: The UUID of the subject. + :param start: datetime lower-bound filter (sent as ``since``). + :param end: datetime upper-bound filter (sent as ``until``). + :param version: API version string, either '1.0' (default, legacy flat + coordinates) or '2.0' (segmented GeoJSON FeatureCollection). + :param kwargs: Extra query params forwarded to the v2 endpoint such as + ``show_excluded``, ``group_by_flags``, ``max_speed_kmh``, + ``max_gap_ms``, ``max_gap_seconds``, ``max_gap_minutes``. + :return: Track data (format varies by version). + """ + p = {} + if start is not None and isinstance(start, datetime): + p['since'] = start.isoformat() + if end is not None and isinstance(end, datetime): + p['until'] = end.isoformat() + + if version == '2.0': + for key in ('show_excluded', 'group_by_flags', 'max_speed_kmh', + 'max_gap_ms', 'max_gap_seconds', 'max_gap_minutes'): + if key in kwargs: + p[key] = kwargs[key] + return await self._get( + path=f'subject/{subject_id}/tracks/', + base_url=self._api_root('v2.0'), + params=p, + ) + + return await self._get(path=f'subject/{subject_id}/tracks', params=p) + + async def get_subject_source_tracks(self, subject_id='', src_id='', start=None): + """ + Get the latest tracks for the Subject having the given subject_id and a source ID. + + :param subject_id: The subject UUID + :param src_id: The source UUID + :param start: Optional datetime lower-bound filter (sent as ``since``) + :return: Track data + """ + p = {} + if start and isinstance(start, datetime): + p['since'] = start.isoformat() + return await self._get(path=f'subject/{subject_id}/source/{src_id}/tracks', params=p) + async def get_feature_group(self, feature_group_id: str): """ Get a feature group by id @@ -1646,38 +1715,6 @@ async def _get_data(self, endpoint, params, batch_size=0): async def _get(self, path, base_url=None, params=None): return await self._call(path=path, payload=None, method="GET", params=params, base_url=base_url) - async def _delete(self, path): - """Issue DELETE request. Returns True on success; raises ERClient* on error.""" - try: - auth_headers = await self.auth_headers() - except httpx.HTTPStatusError as e: - self._handle_http_status_error(path, "DELETE", e) - headers = {'User-Agent': self.user_agent, **auth_headers} - if not path.startswith('http'): - path = self._er_url(path) - try: - response = await self._http_session.delete(path, headers=headers) - except httpx.RequestError as e: - reason = str(e) - self.logger.error('Request to ER failed', extra=dict(provider_key=self.provider_key, - url=path, - reason=reason)) - raise ERClientException(f'Request to ER failed: {reason}') - if response.is_success: - return True - if response.status_code == 404: - self.logger.error("404 when calling %s", path) - raise ERClientNotFound() - if response.status_code == 403: - try: - reason = response.json().get('status', {}).get('detail', 'unknown reason') - except Exception: - reason = 'unknown reason' - raise ERClientPermissionDenied(reason) - raise ERClientException( - f'Failed to delete: {response.status_code} {response.text}' - ) - async def get_file(self, url): """ Download a file (e.g. attachment URL). Returns the httpx response; body is read into memory. diff --git a/tests/async_client/test_get_subject_tracks.py b/tests/async_client/test_get_subject_tracks.py new file mode 100644 index 0000000..c9f9f3e --- /dev/null +++ b/tests/async_client/test_get_subject_tracks.py @@ -0,0 +1,277 @@ +import httpx +import pytest +import respx +from datetime import datetime, timezone + + +@pytest.fixture +def v1_tracks_response(): + """Typical v1 subject tracks response (flat coordinate pairs).""" + return { + "data": { + "type": "FeatureCollection", + "features": [ + { + "type": "Feature", + "geometry": { + "type": "LineString", + "coordinates": [ + [36.7922, -1.2932], + [36.7921, -1.2931], + [36.7919, -1.2931], + ], + }, + "properties": { + "title": "Test Subject", + }, + } + ], + } + } + + +@pytest.fixture +def v2_tracks_response(): + """Typical v2 segmented tracks response (GeoJSON FeatureCollection).""" + return { + "data": { + "type": "FeatureCollection", + "features": [ + { + "type": "Feature", + "geometry": { + "type": "LineString", + "coordinates": [ + [36.7922, -1.2932], + [36.7921, -1.2931], + ], + }, + "properties": { + "count": 2, + "start_time": "2025-01-01T00:00:00Z", + "end_time": "2025-01-01T01:00:00Z", + }, + }, + { + "type": "Feature", + "geometry": { + "type": "LineString", + "coordinates": [ + [36.7919, -1.2931], + [36.7917, -1.2931], + ], + }, + "properties": { + "count": 2, + "start_time": "2025-01-01T02:00:00Z", + "end_time": "2025-01-01T03:00:00Z", + }, + }, + ], + } + } + + +SUBJECT_ID = "a1b2c3d4-e5f6-7890-abcd-ef1234567890" + + +# ── v1 (default) tracks ────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_get_subject_tracks_v1_default(er_client, v1_tracks_response): + """get_subject_tracks() with default version hits the v1 endpoint.""" + async with respx.mock( + base_url=er_client._api_root("v1.0"), assert_all_called=False + ) as respx_mock: + route = respx_mock.get(f"subject/{SUBJECT_ID}/tracks") + route.return_value = httpx.Response( + httpx.codes.OK, json=v1_tracks_response + ) + + result = await er_client.get_subject_tracks(subject_id=SUBJECT_ID) + + assert route.called + assert result == v1_tracks_response["data"] + await er_client.close() + + +@pytest.mark.asyncio +async def test_get_subject_tracks_v1_with_dates(er_client, v1_tracks_response): + """v1 tracks pass since/until params when start/end are given.""" + async with respx.mock( + base_url=er_client._api_root("v1.0"), assert_all_called=False + ) as respx_mock: + route = respx_mock.get(f"subject/{SUBJECT_ID}/tracks") + route.return_value = httpx.Response( + httpx.codes.OK, json=v1_tracks_response + ) + + start = datetime(2025, 1, 1, tzinfo=timezone.utc) + end = datetime(2025, 1, 2, tzinfo=timezone.utc) + result = await er_client.get_subject_tracks( + subject_id=SUBJECT_ID, start=start, end=end + ) + + assert route.called + request = route.calls[0].request + assert "since" in str(request.url) + assert "until" in str(request.url) + await er_client.close() + + +# ── v2 segmented tracks ────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_get_subject_tracks_v2(er_client, v2_tracks_response): + """get_subject_tracks(version='2.0') hits the v2 segmented endpoint.""" + v2_base = er_client._api_root("v2.0") + async with respx.mock(assert_all_called=False) as respx_mock: + route = respx_mock.get( + f"{v2_base}/subject/{SUBJECT_ID}/tracks/" + ).respond(httpx.codes.OK, json=v2_tracks_response) + + result = await er_client.get_subject_tracks( + subject_id=SUBJECT_ID, version="2.0" + ) + + assert route.called + assert result == v2_tracks_response["data"] + await er_client.close() + + +@pytest.mark.asyncio +async def test_get_subject_tracks_v2_with_dates(er_client, v2_tracks_response): + """v2 tracks pass since/until when start/end are given.""" + v2_base = er_client._api_root("v2.0") + async with respx.mock(assert_all_called=False) as respx_mock: + route = respx_mock.get( + f"{v2_base}/subject/{SUBJECT_ID}/tracks/" + ).respond(httpx.codes.OK, json=v2_tracks_response) + + start = datetime(2025, 1, 1, tzinfo=timezone.utc) + end = datetime(2025, 1, 2, tzinfo=timezone.utc) + result = await er_client.get_subject_tracks( + subject_id=SUBJECT_ID, start=start, end=end, version="2.0" + ) + + assert route.called + request = route.calls[0].request + assert "since" in str(request.url) + assert "until" in str(request.url) + await er_client.close() + + +@pytest.mark.asyncio +async def test_get_subject_tracks_v2_with_extra_params(er_client, v2_tracks_response): + """v2 tracks forward additional filter params (show_excluded, max_speed_kmh, etc.).""" + v2_base = er_client._api_root("v2.0") + async with respx.mock(assert_all_called=False) as respx_mock: + route = respx_mock.get( + f"{v2_base}/subject/{SUBJECT_ID}/tracks/" + ).respond(httpx.codes.OK, json=v2_tracks_response) + + result = await er_client.get_subject_tracks( + subject_id=SUBJECT_ID, + version="2.0", + show_excluded="true", + max_speed_kmh=120.0, + max_gap_minutes=30, + ) + + assert route.called + request = route.calls[0].request + url_str = str(request.url) + assert "show_excluded" in url_str + assert "max_speed_kmh" in url_str + assert "max_gap_minutes" in url_str + await er_client.close() + + +# ── get_subject_source_tracks (async) ───────────────────────────── + +@pytest.mark.asyncio +async def test_get_subject_source_tracks(er_client, v1_tracks_response): + """get_subject_source_tracks() hits the correct endpoint.""" + source_id = "bbbb1111-2222-3333-4444-555566667777" + async with respx.mock( + base_url=er_client._api_root("v1.0"), assert_all_called=False + ) as respx_mock: + route = respx_mock.get( + f"subject/{SUBJECT_ID}/source/{source_id}/tracks" + ) + route.return_value = httpx.Response( + httpx.codes.OK, json=v1_tracks_response + ) + + result = await er_client.get_subject_source_tracks( + subject_id=SUBJECT_ID, src_id=source_id + ) + + assert route.called + assert result == v1_tracks_response["data"] + await er_client.close() + + +@pytest.mark.asyncio +async def test_get_subject_source_tracks_with_since(er_client, v1_tracks_response): + """get_subject_source_tracks() passes since param.""" + source_id = "bbbb1111-2222-3333-4444-555566667777" + async with respx.mock( + base_url=er_client._api_root("v1.0"), assert_all_called=False + ) as respx_mock: + route = respx_mock.get( + f"subject/{SUBJECT_ID}/source/{source_id}/tracks" + ) + route.return_value = httpx.Response( + httpx.codes.OK, json=v1_tracks_response + ) + + start = datetime(2025, 1, 1, tzinfo=timezone.utc) + result = await er_client.get_subject_source_tracks( + subject_id=SUBJECT_ID, src_id=source_id, start=start + ) + + assert route.called + request = route.calls[0].request + assert "since" in str(request.url) + await er_client.close() + + +# ── Error handling ──────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_get_subject_tracks_not_found(er_client, not_found_response): + """get_subject_tracks() raises ERClientNotFound on 404.""" + from erclient import ERClientNotFound + + async with respx.mock( + base_url=er_client._api_root("v1.0"), assert_all_called=False + ) as respx_mock: + route = respx_mock.get(f"subject/{SUBJECT_ID}/tracks") + route.return_value = httpx.Response( + httpx.codes.NOT_FOUND, json=not_found_response + ) + + with pytest.raises(ERClientNotFound): + await er_client.get_subject_tracks(subject_id=SUBJECT_ID) + + await er_client.close() + + +@pytest.mark.asyncio +async def test_get_subject_tracks_v2_not_found(er_client, not_found_response): + """get_subject_tracks(version='2.0') raises ERClientNotFound on 404.""" + from erclient import ERClientNotFound + + v2_base = er_client._api_root("v2.0") + async with respx.mock(assert_all_called=False) as respx_mock: + route = respx_mock.get( + f"{v2_base}/subject/{SUBJECT_ID}/tracks/" + ).respond(httpx.codes.NOT_FOUND, json=not_found_response) + + with pytest.raises(ERClientNotFound): + await er_client.get_subject_tracks( + subject_id=SUBJECT_ID, version="2.0" + ) + + await er_client.close() diff --git a/tests/sync_client/test_get_subject_tracks.py b/tests/sync_client/test_get_subject_tracks.py new file mode 100644 index 0000000..4af154f --- /dev/null +++ b/tests/sync_client/test_get_subject_tracks.py @@ -0,0 +1,232 @@ +import json +from datetime import datetime, timezone +from unittest.mock import MagicMock, patch + +import pytest + +from erclient.client import ERClient + + +SUBJECT_ID = "a1b2c3d4-e5f6-7890-abcd-ef1234567890" +SOURCE_ID = "bbbb1111-2222-3333-4444-555566667777" + + +@pytest.fixture +def v1_tracks_response(): + return { + "data": { + "type": "FeatureCollection", + "features": [ + { + "type": "Feature", + "geometry": { + "type": "LineString", + "coordinates": [ + [36.7922, -1.2932], + [36.7921, -1.2931], + [36.7919, -1.2931], + ], + }, + "properties": {"title": "Test Subject"}, + } + ], + } + } + + +@pytest.fixture +def v2_tracks_response(): + return { + "data": { + "type": "FeatureCollection", + "features": [ + { + "type": "Feature", + "geometry": { + "type": "LineString", + "coordinates": [ + [36.7922, -1.2932], + [36.7921, -1.2931], + ], + }, + "properties": { + "count": 2, + "start_time": "2025-01-01T00:00:00Z", + "end_time": "2025-01-01T01:00:00Z", + }, + }, + { + "type": "Feature", + "geometry": { + "type": "LineString", + "coordinates": [ + [36.7919, -1.2931], + [36.7917, -1.2931], + ], + }, + "properties": { + "count": 2, + "start_time": "2025-01-01T02:00:00Z", + "end_time": "2025-01-01T03:00:00Z", + }, + }, + ], + } + } + + +def _mock_response(json_body, status_code=200): + """Create a mock requests.Response with the given JSON body.""" + mock_resp = MagicMock() + mock_resp.ok = status_code < 400 + mock_resp.status_code = status_code + mock_resp.text = json.dumps(json_body) + mock_resp.json.return_value = json_body + return mock_resp + + +# ── v1 (default) tracks ────────────────────────────────────────── + +def test_get_subject_tracks_v1_default(er_client, v1_tracks_response): + """get_subject_tracks() with default version hits the v1 endpoint.""" + with patch.object(er_client._http_session, "get") as mock_get: + mock_get.return_value = _mock_response(v1_tracks_response) + + result = er_client.get_subject_tracks(subject_id=SUBJECT_ID) + + assert mock_get.called + call_args = mock_get.call_args + assert f"subject/{SUBJECT_ID}/tracks" in call_args[0][0] + assert "v1.0" in call_args[0][0] + assert result == v1_tracks_response["data"] + + +def test_get_subject_tracks_v1_with_dates(er_client, v1_tracks_response): + """v1 tracks pass since/until params when start/end are given.""" + with patch.object(er_client._http_session, "get") as mock_get: + mock_get.return_value = _mock_response(v1_tracks_response) + + start = datetime(2025, 1, 1, tzinfo=timezone.utc) + end = datetime(2025, 1, 2, tzinfo=timezone.utc) + er_client.get_subject_tracks(subject_id=SUBJECT_ID, start=start, end=end) + + call_kwargs = mock_get.call_args + params = call_kwargs.kwargs.get("params") or call_kwargs[1].get("params", {}) + assert "since" in params + assert "until" in params + + +# ── v2 segmented tracks ────────────────────────────────────────── + +def test_get_subject_tracks_v2(er_client, v2_tracks_response): + """get_subject_tracks(version='2.0') hits the v2 segmented endpoint.""" + with patch.object(er_client._http_session, "get") as mock_get: + mock_get.return_value = _mock_response(v2_tracks_response) + + result = er_client.get_subject_tracks( + subject_id=SUBJECT_ID, version="2.0" + ) + + assert mock_get.called + call_url = mock_get.call_args[0][0] + assert "v2.0" in call_url + assert f"subject/{SUBJECT_ID}/tracks/" in call_url + assert result == v2_tracks_response["data"] + + +def test_get_subject_tracks_v2_with_dates(er_client, v2_tracks_response): + """v2 tracks pass since/until when start/end are given.""" + with patch.object(er_client._http_session, "get") as mock_get: + mock_get.return_value = _mock_response(v2_tracks_response) + + start = datetime(2025, 1, 1, tzinfo=timezone.utc) + end = datetime(2025, 1, 2, tzinfo=timezone.utc) + er_client.get_subject_tracks( + subject_id=SUBJECT_ID, start=start, end=end, version="2.0" + ) + + call_kwargs = mock_get.call_args + params = call_kwargs.kwargs.get("params") or call_kwargs[1].get("params", {}) + assert "since" in params + assert "until" in params + + +def test_get_subject_tracks_v2_extra_params(er_client, v2_tracks_response): + """v2 tracks forward additional filter params.""" + with patch.object(er_client._http_session, "get") as mock_get: + mock_get.return_value = _mock_response(v2_tracks_response) + + er_client.get_subject_tracks( + subject_id=SUBJECT_ID, + version="2.0", + show_excluded="true", + max_speed_kmh=120.0, + max_gap_minutes=30, + ) + + call_kwargs = mock_get.call_args + params = call_kwargs.kwargs.get("params") or call_kwargs[1].get("params", {}) + assert params.get("show_excluded") == "true" + assert params.get("max_speed_kmh") == 120.0 + assert params.get("max_gap_minutes") == 30 + + +def test_get_subject_tracks_v2_ignores_unknown_kwargs(er_client, v2_tracks_response): + """Unknown kwargs are not forwarded as query params.""" + with patch.object(er_client._http_session, "get") as mock_get: + mock_get.return_value = _mock_response(v2_tracks_response) + + er_client.get_subject_tracks( + subject_id=SUBJECT_ID, + version="2.0", + unknown_param="should_not_appear", + ) + + call_kwargs = mock_get.call_args + params = call_kwargs.kwargs.get("params") or call_kwargs[1].get("params", {}) + assert "unknown_param" not in params + + +# ── _api_root helper ────────────────────────────────────────────── + +def test_api_root_v1(er_client): + """_api_root('v1.0') returns the v1 API root URL.""" + url = er_client._api_root("v1.0") + assert url == "https://fake-site.erdomain.org/api/v1.0" + + +def test_api_root_v2(er_client): + """_api_root('v2.0') returns the v2 API root URL.""" + url = er_client._api_root("v2.0") + assert url == "https://fake-site.erdomain.org/api/v2.0" + + +# ── get_subject_source_tracks ───────────────────────────────────── + +def test_get_subject_source_tracks(er_client, v1_tracks_response): + """get_subject_source_tracks() hits the correct endpoint.""" + with patch.object(er_client._http_session, "get") as mock_get: + mock_get.return_value = _mock_response(v1_tracks_response) + + result = er_client.get_subject_source_tracks( + subject_id=SUBJECT_ID, src_id=SOURCE_ID + ) + + call_url = mock_get.call_args[0][0] + assert f"subject/{SUBJECT_ID}/source/{SOURCE_ID}/tracks" in call_url + assert result == v1_tracks_response["data"] + + +def test_get_subject_source_tracks_with_since(er_client, v1_tracks_response): + """get_subject_source_tracks() passes since param.""" + with patch.object(er_client._http_session, "get") as mock_get: + mock_get.return_value = _mock_response(v1_tracks_response) + + start = datetime(2025, 1, 1, tzinfo=timezone.utc) + er_client.get_subject_source_tracks( + subject_id=SUBJECT_ID, src_id=SOURCE_ID, start=start + ) + + call_kwargs = mock_get.call_args + params = call_kwargs.kwargs.get("params") or call_kwargs[1].get("params", {}) + assert "since" in params