Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions lambda/src/environment/service_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ def wrapper(event, context, service_provider=None):


class ServiceProvider:
@cached_property
def user_agent(self) -> str:
return "layertwo-ffsync/1.0"

@cached_property
def aws_region(self): # pragma: nocover
return os.environ.get("AWS_REGION")
Expand Down Expand Up @@ -246,6 +250,7 @@ def oidc_validator(self) -> OIDCValidator:
client_id=self.oidc_client_id,
clock_skew_tolerance=self.clock_skew_tolerance,
cache_ttl_seconds=self.oidc_cache_ttl_seconds,
user_agent=self.user_agent,
)

@cached_property
Expand Down Expand Up @@ -362,6 +367,7 @@ def auth_api_router(self):
OIDCCodeExchangeRoute(
oidc_validator=self.oidc_validator,
account_manager=self.auth_account_manager,
user_agent=self.user_agent,
),
# Device management routes
AccountDeviceRoute(
Expand Down
23 changes: 16 additions & 7 deletions lambda/src/routes/auth/oidc_exchange.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import json

import requests as http_requests
import requests
from aws_lambda_powertools import Logger
from aws_lambda_powertools.event_handler import APIGatewayRestResolver, Response

Expand Down Expand Up @@ -50,9 +50,15 @@ def __init__(
self,
oidc_validator: OIDCValidator,
account_manager: AuthAccountManager,
user_agent: str,
):
self._oidc_validator = oidc_validator
self._account_manager = account_manager
self._user_agent = user_agent

@property
def _default_headers(self) -> dict[str, str]:
return {"User-Agent": self._user_agent}

def bind(self, app: APIGatewayRestResolver):
@app.post("/v1/oidc/exchange")
Expand Down Expand Up @@ -96,7 +102,7 @@ def handle(self, event) -> Response:

# 2. Exchange code for tokens at the provider's token endpoint
try:
token_response = http_requests.post(
token_response = requests.post(
provider_config.token_endpoint,
data={
"grant_type": "authorization_code",
Expand All @@ -105,10 +111,13 @@ def handle(self, event) -> Response:
"redirect_uri": redirect_uri,
"code_verifier": code_verifier,
},
headers={"Content-Type": "application/x-www-form-urlencoded"},
headers={
"Content-Type": "application/x-www-form-urlencoded",
**self._default_headers,
},
timeout=10,
)
except http_requests.exceptions.RequestException:
except requests.exceptions.RequestException:
logger.exception("Token exchange request failed")
return Response(
status_code=502,
Expand Down Expand Up @@ -150,12 +159,12 @@ def handle(self, event) -> Response:

# 4. Fetch userinfo to get email
try:
userinfo_response = http_requests.get(
userinfo_response = requests.get(
provider_config.userinfo_endpoint,
headers={"Authorization": f"Bearer {access_token}"},
headers={"Authorization": f"Bearer {access_token}", **self._default_headers},
timeout=10,
)
except http_requests.exceptions.RequestException:
except requests.exceptions.RequestException:
logger.exception("Userinfo request failed")
return Response(
status_code=502,
Expand Down
19 changes: 12 additions & 7 deletions lambda/src/services/oidc_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
from typing import Optional

import jwt
import requests # type: ignore[import-untyped]
from jwt import PyJWKClient, PyJWKClientError
import requests

from src.shared.exceptions import (
InvalidCredentialsError,
Expand Down Expand Up @@ -35,6 +34,7 @@ def __init__(
self,
provider_url: str,
client_id: str,
user_agent: str,
clock_skew_tolerance: int = 300,
cache_ttl_seconds: int = CACHE_TTL_SECONDS,
):
Expand All @@ -53,7 +53,12 @@ def __init__(
self.cache_ttl_seconds = cache_ttl_seconds
self._provider_config: Optional[OIDCProviderConfig] = None
self._provider_config_timestamp: float = 0
self._jwk_client: Optional[PyJWKClient] = None
self._jwk_client: Optional[jwt.PyJWKClient] = None
self._user_agent = user_agent

@property
def _default_headers(self) -> dict[str, str]:
return {"User-Agent": self._user_agent}

def _is_cache_valid(self) -> bool:
"""Check if the cached provider configuration is still valid."""
Expand Down Expand Up @@ -81,7 +86,7 @@ def discover_provider_config(self) -> OIDCProviderConfig:
well_known_url = f"{self.provider_url}/.well-known/openid-configuration"

try:
response = requests.get(well_known_url, timeout=10)
response = requests.get(well_known_url, timeout=10, headers=self._default_headers)
response.raise_for_status()
config_data = response.json()

Expand All @@ -108,7 +113,7 @@ def discover_provider_config(self) -> OIDCProviderConfig:
except (KeyError, ValueError) as e:
raise ServiceUnavailableError(f"Invalid OIDC provider configuration: {e}")

def _get_jwk_client(self) -> PyJWKClient:
def _get_jwk_client(self) -> jwt.PyJWKClient:
"""
Get or create PyJWKClient for JWKS fetching.

Expand All @@ -122,7 +127,7 @@ def _get_jwk_client(self) -> PyJWKClient:
"""
if self._jwk_client is None:
config = self.discover_provider_config()
self._jwk_client = PyJWKClient(
self._jwk_client = jwt.PyJWKClient(
config.jwks_uri,
cache_keys=True,
lifespan=self.cache_ttl_seconds,
Expand Down Expand Up @@ -162,7 +167,7 @@ def validate_token(self, token: str) -> OIDCTokenClaims:
jwk_client = self._get_jwk_client()
try:
signing_key = jwk_client.get_signing_key_from_jwt(token)
except PyJWKClientError as e:
except jwt.PyJWKClientError as e:
raise InvalidTokenError(f"Failed to get signing key: {e}")

# Decode and validate token
Expand Down
19 changes: 10 additions & 9 deletions lambda/tests/routes/auth/test_oidc_exchange.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def exchange_route(mock_oidc_validator, mock_account_manager):
return OIDCCodeExchangeRoute(
oidc_validator=mock_oidc_validator,
account_manager=mock_account_manager,
user_agent="unit-test",
)


Expand Down Expand Up @@ -107,7 +108,7 @@ def _exchange_event(self, body=None):
}
return _make_event("POST", "/v1/oidc/exchange", body=body or default_body)

@patch("src.routes.auth.oidc_exchange.http_requests")
@patch("src.routes.auth.oidc_exchange.requests")
def test_success_account_exists(
self, mock_requests, exchange_route, mock_oidc_validator, mock_account_manager
):
Expand Down Expand Up @@ -135,7 +136,7 @@ def test_success_account_exists(
assert body["access_token"] == "at-789"
assert body["account_exists"] is True

@patch("src.routes.auth.oidc_exchange.http_requests")
@patch("src.routes.auth.oidc_exchange.requests")
def test_success_account_does_not_exist(
self, mock_requests, exchange_route, mock_oidc_validator, mock_account_manager
):
Expand Down Expand Up @@ -179,7 +180,7 @@ def test_provider_discovery_failure(self, exchange_route, mock_oidc_validator):
response = exchange_route.handle(self._exchange_event())
assert response.status_code == 503

@patch("src.routes.auth.oidc_exchange.http_requests")
@patch("src.routes.auth.oidc_exchange.requests")
def test_token_exchange_network_error(self, mock_requests, exchange_route):
import requests

Expand All @@ -191,7 +192,7 @@ def test_token_exchange_network_error(self, mock_requests, exchange_route):
body = json.loads(response.body)
assert "exchange" in body["message"].lower()

@patch("src.routes.auth.oidc_exchange.http_requests")
@patch("src.routes.auth.oidc_exchange.requests")
def test_token_exchange_provider_rejects(self, mock_requests, exchange_route):
mock_token_resp = MagicMock()
mock_token_resp.ok = False
Expand All @@ -203,7 +204,7 @@ def test_token_exchange_provider_rejects(self, mock_requests, exchange_route):
response = exchange_route.handle(self._exchange_event())
assert response.status_code == 401

@patch("src.routes.auth.oidc_exchange.http_requests")
@patch("src.routes.auth.oidc_exchange.requests")
def test_no_access_token_in_response(self, mock_requests, exchange_route):
mock_token_resp = MagicMock()
mock_token_resp.ok = True
Expand All @@ -216,7 +217,7 @@ def test_no_access_token_in_response(self, mock_requests, exchange_route):
body = json.loads(response.body)
assert "access token" in body["message"].lower()

@patch("src.routes.auth.oidc_exchange.http_requests")
@patch("src.routes.auth.oidc_exchange.requests")
def test_token_validation_failure(self, mock_requests, exchange_route, mock_oidc_validator):
mock_token_resp = MagicMock()
mock_token_resp.ok = True
Expand All @@ -229,7 +230,7 @@ def test_token_validation_failure(self, mock_requests, exchange_route, mock_oidc
response = exchange_route.handle(self._exchange_event())
assert response.status_code == 401

@patch("src.routes.auth.oidc_exchange.http_requests")
@patch("src.routes.auth.oidc_exchange.requests")
def test_userinfo_network_error(self, mock_requests, exchange_route, mock_oidc_validator):
import requests

Expand All @@ -245,7 +246,7 @@ def test_userinfo_network_error(self, mock_requests, exchange_route, mock_oidc_v
response = exchange_route.handle(self._exchange_event())
assert response.status_code == 502

@patch("src.routes.auth.oidc_exchange.http_requests")
@patch("src.routes.auth.oidc_exchange.requests")
def test_userinfo_provider_error(self, mock_requests, exchange_route, mock_oidc_validator):
mock_token_resp = MagicMock()
mock_token_resp.ok = True
Expand All @@ -265,7 +266,7 @@ def test_userinfo_provider_error(self, mock_requests, exchange_route, mock_oidc_
response = exchange_route.handle(self._exchange_event())
assert response.status_code == 502

@patch("src.routes.auth.oidc_exchange.http_requests")
@patch("src.routes.auth.oidc_exchange.requests")
def test_userinfo_missing_email(self, mock_requests, exchange_route, mock_oidc_validator):
mock_token_resp = MagicMock()
mock_token_resp.ok = True
Expand Down
4 changes: 3 additions & 1 deletion lambda/tests/routes/auth/test_route_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,9 @@ def test_oidc_provider_config_dispatches(self):
assert result["statusCode"] == 200

def test_oidc_code_exchange_dispatches(self):
route = OIDCCodeExchangeRoute(oidc_validator=MagicMock(), account_manager=MagicMock())
route = OIDCCodeExchangeRoute(
oidc_validator=MagicMock(), account_manager=MagicMock(), user_agent="unit-test"
)
result = _router(route).handler(
_make_event("POST", "/v1/oidc/exchange", body="{}"), _make_context()
)
Expand Down
37 changes: 24 additions & 13 deletions lambda/tests/services/test_oidc_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def client_id():

@pytest.fixture
def validator(provider_url, client_id):
return OIDCValidator(provider_url, client_id)
return OIDCValidator(provider_url, client_id, user_agent="foobar")


@pytest.fixture
Expand All @@ -51,32 +51,36 @@ class TestOIDCValidatorInit:

def test_init_strips_trailing_slash(self, client_id):
"""Test that trailing slash is stripped from provider URL"""
validator = OIDCValidator("https://auth.example.com/", client_id)
validator = OIDCValidator("https://auth.example.com/", client_id, user_agent="foobar")
assert validator.provider_url == "https://auth.example.com"

def test_init_stores_client_id(self, provider_url, client_id):
"""Test that client_id is stored correctly"""
validator = OIDCValidator(provider_url, client_id)
validator = OIDCValidator(provider_url, client_id, user_agent="foobar")
assert validator.client_id == client_id

def test_init_default_clock_skew_tolerance(self, provider_url, client_id):
"""Test that default clock_skew_tolerance is 300 seconds"""
validator = OIDCValidator(provider_url, client_id)
validator = OIDCValidator(provider_url, client_id, user_agent="foobar")
assert validator.clock_skew_tolerance == 300

def test_init_custom_clock_skew_tolerance(self, provider_url, client_id):
"""Test that custom clock_skew_tolerance is stored correctly"""
validator = OIDCValidator(provider_url, client_id, clock_skew_tolerance=600)
validator = OIDCValidator(
provider_url, client_id, clock_skew_tolerance=600, user_agent="foobar"
)
assert validator.clock_skew_tolerance == 600

def test_init_default_cache_ttl_seconds(self, provider_url, client_id):
"""Test that default cache_ttl_seconds is 3600 seconds"""
validator = OIDCValidator(provider_url, client_id)
validator = OIDCValidator(provider_url, client_id, user_agent="foobar")
assert validator.cache_ttl_seconds == 3600

def test_init_custom_cache_ttl_seconds(self, provider_url, client_id):
"""Test that custom cache_ttl_seconds is stored correctly"""
validator = OIDCValidator(provider_url, client_id, cache_ttl_seconds=7200)
validator = OIDCValidator(
provider_url, client_id, cache_ttl_seconds=7200, user_agent="foobar"
)
assert validator.cache_ttl_seconds == 7200


Expand All @@ -98,6 +102,7 @@ def test_discover_provider_config_success(self, validator, mock_provider_config)
mock_get.assert_called_once_with(
"https://auth.example.com/.well-known/openid-configuration",
timeout=10,
headers={"User-Agent": "foobar"},
)

def test_discover_provider_config_caching(self, validator, mock_provider_config):
Expand Down Expand Up @@ -142,7 +147,9 @@ def test_discover_provider_config_custom_cache_ttl(
self, provider_url, client_id, mock_provider_config
):
"""Test that custom cache TTL is respected"""
validator = OIDCValidator(provider_url, client_id, cache_ttl_seconds=1800)
validator = OIDCValidator(
provider_url, client_id, cache_ttl_seconds=1800, user_agent="foobar"
)

with patch("src.services.oidc_validator.requests.get") as mock_get:
mock_response = MagicMock()
Expand Down Expand Up @@ -599,7 +606,9 @@ def test_validate_token_timestamp_future_exceeds_tolerance(

def test_validate_token_custom_tolerance(self, provider_url, client_id, mock_provider_config):
"""Test timestamp validation with custom tolerance"""
validator = OIDCValidator(provider_url, client_id, clock_skew_tolerance=600)
validator = OIDCValidator(
provider_url, client_id, clock_skew_tolerance=600, user_agent="foobar"
)
current_time = int(datetime.now(timezone.utc).timestamp())
mock_claims = {
"sub": "user123",
Expand Down Expand Up @@ -697,7 +706,7 @@ def test_get_jwk_client_creates_client(self, validator, mock_provider_config):
mock_response.raise_for_status = MagicMock()
mock_get.return_value = mock_response

with patch("src.services.oidc_validator.PyJWKClient") as mock_jwk_class:
with patch("src.services.oidc_validator.jwt.PyJWKClient") as mock_jwk_class:
mock_jwk_instance = MagicMock()
mock_jwk_class.return_value = mock_jwk_instance

Expand All @@ -712,15 +721,17 @@ def test_get_jwk_client_creates_client(self, validator, mock_provider_config):

def test_get_jwk_client_custom_cache_ttl(self, provider_url, client_id, mock_provider_config):
"""Test that _get_jwk_client uses custom cache TTL"""
validator = OIDCValidator(provider_url, client_id, cache_ttl_seconds=7200)
validator = OIDCValidator(
provider_url, client_id, cache_ttl_seconds=7200, user_agent="foobar"
)

with patch("src.services.oidc_validator.requests.get") as mock_get:
mock_response = MagicMock()
mock_response.json.return_value = mock_provider_config
mock_response.raise_for_status = MagicMock()
mock_get.return_value = mock_response

with patch("src.services.oidc_validator.PyJWKClient") as mock_jwk_class:
with patch("src.services.oidc_validator.jwt.PyJWKClient") as mock_jwk_class:
mock_jwk_instance = MagicMock()
mock_jwk_class.return_value = mock_jwk_instance

Expand All @@ -741,7 +752,7 @@ def test_get_jwk_client_caches_client(self, validator, mock_provider_config):
mock_response.raise_for_status = MagicMock()
mock_get.return_value = mock_response

with patch("src.services.oidc_validator.PyJWKClient") as mock_jwk_class:
with patch("src.services.oidc_validator.jwt.PyJWKClient") as mock_jwk_class:
mock_jwk_instance = MagicMock()
mock_jwk_class.return_value = mock_jwk_instance

Expand Down
Loading