Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,5 @@ lambda/.coverage
lambda/coverage.xml
*.tsbuildinfo
docs/

.DS_Store
9 changes: 6 additions & 3 deletions lambda/src/environment/service_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def storage_api_router(self):
],
middlewares=[
RequestLoggingMiddleware(),
HawkAuthMiddleware(hawk_service=self.hawk_service),
HawkAuthMiddleware(hawk_service=self.hawk_service, metrics=self.metrics),
WeaveTimestampMiddleware(),
],
exception_handlers=self._storage_exception_handlers,
Expand Down Expand Up @@ -296,7 +296,7 @@ def device_manager(self) -> DeviceManager:

@cached_property
def fxa_token_manager(self) -> FxATokenManager:
return FxATokenManager(table=self.auth_table)
return FxATokenManager(table=self.auth_table, metrics=self.metrics)

@cached_property
def oauth_code_manager(self) -> OAuthCodeManager:
Expand All @@ -316,7 +316,7 @@ def jwt_verifier(self) -> JWTVerifier:

@cached_property
def session_hawk_middleware(self) -> HawkAuthMiddleware:
return HawkAuthMiddleware(token_manager=self.fxa_token_manager)
return HawkAuthMiddleware(token_manager=self.fxa_token_manager, metrics=self.metrics)

@cached_property
def cors_config(self) -> CORSConfig:
Expand Down Expand Up @@ -366,6 +366,7 @@ def auth_api_router(self):
jwt_service=self.jwt_service,
account_manager=self.auth_account_manager,
token_manager=self.fxa_token_manager,
metrics=self.metrics,
),
OAuthDestroyRoute(oauth_code_manager=self.oauth_code_manager),
# Discovery routes
Expand Down Expand Up @@ -412,6 +413,7 @@ def token_api_router(self):
user_manager=self.user_manager,
token_generator=self.token_generator,
retry_after_seconds=self.retry_after_seconds,
metrics=self.metrics,
),
],
middlewares=[WeaveTimestampMiddleware()],
Expand All @@ -427,6 +429,7 @@ def profile_api_router(self):
GetProfileRoute(
jwt_verifier=self.jwt_verifier,
auth_account_manager=self.auth_account_manager,
metrics=self.metrics,
),
],
middlewares=[WeaveTimestampMiddleware()],
Expand Down
19 changes: 14 additions & 5 deletions lambda/src/middlewares/hawk_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from aws_lambda_powertools.event_handler import APIGatewayRestResolver, Response
from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler, NextMiddleware
from aws_lambda_powertools.metrics import Metrics, MetricUnit

from src.services.fxa_token_manager import FxATokenManager
from src.services.hawk_service import HawkService
Expand All @@ -30,6 +31,7 @@ class HawkAuthMiddleware(BaseMiddlewareHandler):
def __init__(
self,
*,
metrics: Metrics,
hawk_service: HawkService | None = None,
token_manager: FxATokenManager | None = None,
):
Expand All @@ -38,22 +40,29 @@ def __init__(
raise ValueError("Either hawk_service or token_manager is required")
self._hawk_service = hawk_service
self._token_manager = token_manager
self._metrics = metrics

def handler(self, app: APIGatewayRestResolver, next_middleware: NextMiddleware) -> Response:
event = app.current_event

headers = event.headers or {}
auth_header = headers.get("Authorization") or headers.get("authorization")
if not auth_header:
self._metrics.add_metric("HawkAuthFailure", MetricUnit.Count, 1)
raise HawkAuthenticationError("Missing or invalid authorization")

method, path, host, port = extract_hawk_request_params(event)

if self._hawk_service:
self._validate_storage_hawk(event, auth_header, method, path, host, port)
else:
self._validate_session_hawk(event, auth_header, method, path, host, port)

try:
if self._hawk_service:
self._validate_storage_hawk(event, auth_header, method, path, host, port)
else:
self._validate_session_hawk(event, auth_header, method, path, host, port)
except HawkAuthenticationError, UidMismatchError:
self._metrics.add_metric("HawkAuthFailure", MetricUnit.Count, 1)
raise

self._metrics.add_metric("HawkAuthSuccess", MetricUnit.Count, 1)
return next_middleware(app)

def _validate_storage_hawk(self, event, auth_header, method, path, host, port):
Expand Down
6 changes: 6 additions & 0 deletions lambda/src/routes/auth/oauth_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import time

from aws_lambda_powertools.event_handler import APIGatewayRestResolver, Response
from aws_lambda_powertools.metrics import Metrics, MetricUnit

from src.services.auth_account_manager import AuthAccountManager
from src.services.fxa_token_manager import FxATokenManager
Expand All @@ -26,11 +27,13 @@ def __init__(
oauth_code_manager: OAuthCodeManager,
jwt_service: JWTService,
account_manager: AuthAccountManager,
metrics: Metrics,
token_manager: FxATokenManager | None = None,
):
self._oauth_code_manager = oauth_code_manager
self._jwt_service = jwt_service
self._account_manager = account_manager
self._metrics = metrics
self._token_manager = token_manager

def bind(self, app: APIGatewayRestResolver):
Expand Down Expand Up @@ -124,6 +127,7 @@ def _handle_authorization_code(self, body: dict) -> Response:
if keys_jwe:
result.keys_jwe = keys_jwe

self._metrics.add_metric("AccessTokensIssued", MetricUnit.Count, 1)
return Response(
status_code=200,
content_type="application/json",
Expand Down Expand Up @@ -180,6 +184,7 @@ def _handle_refresh_token(self, body: dict) -> Response:
refresh_token=new_refresh_token,
auth_at=int(time.time()),
)
self._metrics.add_metric("AccessTokensIssued", MetricUnit.Count, 1)
return Response(
status_code=200,
content_type="application/json",
Expand Down Expand Up @@ -225,6 +230,7 @@ def _handle_fxa_credentials(self, event, body: dict) -> Response:
scope=scope,
auth_at=int(time.time()),
)
self._metrics.add_metric("AccessTokensIssued", MetricUnit.Count, 1)
return Response(
status_code=200,
content_type="application/json",
Expand Down
7 changes: 7 additions & 0 deletions lambda/src/routes/profile/get_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json

from aws_lambda_powertools.event_handler import APIGatewayRestResolver, Response
from aws_lambda_powertools.metrics import Metrics, MetricUnit

from src.services.auth_account_manager import AuthAccountManager
from src.services.jwt_verifier import JWTVerifier
Expand All @@ -18,9 +19,11 @@ def __init__(
self,
jwt_verifier: JWTVerifier,
auth_account_manager: AuthAccountManager,
metrics: Metrics,
):
self._jwt_verifier = jwt_verifier
self._auth_account_manager = auth_account_manager
self._metrics = metrics

def bind(self, app: APIGatewayRestResolver):
@app.get("/v1/profile")
Expand All @@ -32,16 +35,20 @@ def handle(self, event) -> Response:
auth_header = headers.get("authorization", "")

if not auth_header:
self._metrics.add_metric("JWTAuthFailure", MetricUnit.Count, 1)
return self._error(401, 110, "Missing or invalid authorization")

if not auth_header.startswith("Bearer "):
self._metrics.add_metric("JWTAuthFailure", MetricUnit.Count, 1)
return self._error(401, 110, "Missing or invalid authorization")

token = auth_header[len("Bearer ") :]

try:
claims = self._jwt_verifier.validate_token(token)
self._metrics.add_metric("JWTAuthSuccess", MetricUnit.Count, 1)
except InvalidTokenError:
self._metrics.add_metric("JWTAuthFailure", MetricUnit.Count, 1)
return self._error(401, 110, "Invalid or expired token")

# Look up account by fxa_uid (from JWT) or fall back to oidcSub lookup
Expand Down
4 changes: 4 additions & 0 deletions lambda/src/routes/token/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from aws_lambda_powertools import Logger
from aws_lambda_powertools.event_handler import APIGatewayRestResolver, Response
from aws_lambda_powertools.metrics import Metrics, MetricUnit

from src.services.token_generator import TokenGenerator
from src.services.user_manager import UserManager
Expand Down Expand Up @@ -37,11 +38,13 @@ def __init__(
oidc_validator,
user_manager: UserManager,
token_generator: TokenGenerator,
metrics: Metrics,
retry_after_seconds: int = 30,
):
self.oidc_validator = oidc_validator # OIDCValidator or JWTVerifier
self.user_manager = user_manager
self.token_generator = token_generator
self._metrics = metrics
self.retry_after_seconds = retry_after_seconds

def bind(self, app: APIGatewayRestResolver):
Expand Down Expand Up @@ -170,6 +173,7 @@ def handle(self, event) -> Response:
)

result = TokenOutput.model_validate(asdict(token_response))
self._metrics.add_metric("SyncTokensIssued", MetricUnit.Count, 1)
return Response(
status_code=200,
content_type="application/json",
Expand Down
5 changes: 5 additions & 0 deletions lambda/src/services/fxa_token_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import mohawk
import mohawk.exc
from aws_lambda_powertools import Logger
from aws_lambda_powertools.metrics import MetricUnit
from botocore.exceptions import ClientError

from src.services import fxa_crypto
Expand Down Expand Up @@ -37,17 +38,20 @@ def extract_token_id_from_hawk_header(authorization_header: str) -> str | None:
def __init__(
self,
table,
metrics,
session_ttl_seconds: int = 2592000,
keyfetch_ttl_seconds: int = 300,
):
"""Initialize FxATokenManager.

Args:
table: DynamoDB Table resource (same auth table used by AuthAccountManager)
metrics: PowerTools Metrics instance
session_ttl_seconds: Session token TTL (default 30 days)
keyfetch_ttl_seconds: Key-fetch token TTL (default 5 minutes)
"""
self.table = table
self._metrics = metrics
self.session_ttl_seconds = session_ttl_seconds
self.keyfetch_ttl_seconds = keyfetch_ttl_seconds

Expand Down Expand Up @@ -84,6 +88,7 @@ def create_session_token(self, uid: str) -> bytes:
"reqHMACkey": req_hmac_key.hex(),
}
)
self._metrics.add_metric("SessionsCreated", MetricUnit.Count, 1)

return token

Expand Down
2 changes: 2 additions & 0 deletions lambda/tests/routes/auth/test_oauth_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def route(mock_oauth_code_manager, mock_jwt_service, mock_account_manager, mock_
jwt_service=mock_jwt_service,
account_manager=mock_account_manager,
token_manager=mock_token_manager,
metrics=MagicMock(),
)


Expand Down Expand Up @@ -597,6 +598,7 @@ def test_returns_400_when_token_manager_not_configured(
oauth_code_manager=mock_oauth_code_manager,
jwt_service=mock_jwt_service,
account_manager=mock_account_manager,
metrics=MagicMock(),
)
event = APIGatewayProxyEvent(
{
Expand Down
5 changes: 4 additions & 1 deletion lambda/tests/routes/auth/test_route_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,10 @@ def test_oauth_authorization_dispatches(self):

def test_oauth_token_dispatches(self):
route = OAuthTokenRoute(
oauth_code_manager=MagicMock(), jwt_service=MagicMock(), account_manager=MagicMock()
oauth_code_manager=MagicMock(),
jwt_service=MagicMock(),
account_manager=MagicMock(),
metrics=MagicMock(),
)
result = _router(route).handler(
_make_event("POST", "/v1/oauth/token", body="{}"), _make_context()
Expand Down
1 change: 1 addition & 0 deletions lambda/tests/routes/profile/test_get_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def route(mock_jwt_verifier, mock_auth_account_manager):
return GetProfileRoute(
jwt_verifier=mock_jwt_verifier,
auth_account_manager=mock_auth_account_manager,
metrics=MagicMock(),
)


Expand Down
3 changes: 3 additions & 0 deletions lambda/tests/routes/token/test_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def request_token_route(mock_oidc_validator, mock_user_manager, mock_token_gener
oidc_validator=mock_oidc_validator,
user_manager=mock_user_manager,
token_generator=mock_token_generator,
metrics=MagicMock(),
)


Expand Down Expand Up @@ -109,6 +110,7 @@ def test_init_stores_dependencies(
oidc_validator=mock_oidc_validator,
user_manager=mock_user_manager,
token_generator=mock_token_generator,
metrics=MagicMock(),
)
assert route.oidc_validator is mock_oidc_validator
assert route.user_manager is mock_user_manager
Expand Down Expand Up @@ -1128,6 +1130,7 @@ def test_retry_after_header_custom_value(
user_manager=mock_user_manager,
token_generator=mock_token_generator,
retry_after_seconds=60,
metrics=MagicMock(),
)
mock_oidc_validator.validate_token.side_effect = ServiceUnavailableError(
"OIDC provider unreachable"
Expand Down
Loading
Loading