From bee23fc9ec0a4d41fe5ac8a303333e92b5076760 Mon Sep 17 00:00:00 2001 From: David Lawrence Date: Wed, 3 Jun 2026 13:59:10 -0400 Subject: [PATCH 1/7] Bug 2044654 - Update the Github ETL process to run repository exports in parallel rather one one after another --- main.py | 261 ++++++++++++++++++++++++++++++--------------- tests/test_main.py | 86 +++++++++++++++ 2 files changed, 262 insertions(+), 85 deletions(-) diff --git a/main.py b/main.py index 2645f69..d7487c4 100755 --- a/main.py +++ b/main.py @@ -5,6 +5,7 @@ and loads it into a BigQuery dataset using chunked processing. """ +import concurrent.futures import logging import os import re @@ -836,6 +837,147 @@ def load_data( ) +_DEFAULT_MAX_WORKERS: int = 8 + + +def _resolve_max_workers(repo_count: int) -> int: + """ + Decide how many repos to process concurrently. + + Defaults to one worker per repo, capped at ``_DEFAULT_MAX_WORKERS`` to avoid + spawning an unbounded number of threads (and GitHub API connections) for large + repo lists. The cap is overridable via the ``GITHUB_ETL_MAX_WORKERS`` env var. + + Args: + repo_count: Number of repositories to process + + Returns: + Worker count, always at least 1. + """ + cap = _DEFAULT_MAX_WORKERS + override = os.environ.get("GITHUB_ETL_MAX_WORKERS") + if override: + try: + parsed = int(override) + if parsed > 0: + cap = parsed + else: + logger.warning( + f"Ignoring non-positive GITHUB_ETL_MAX_WORKERS={override!r}" + ) + except ValueError: + logger.warning(f"Ignoring invalid GITHUB_ETL_MAX_WORKERS={override!r}") + return max(1, min(repo_count, cap)) + + +def process_repo( + repo: str, + github_app_id: str | None, + github_private_key: str | None, + github_api_url: str, + bigquery_client: bigquery.Client, + bigquery_dataset: str, + snapshot_date: str, + use_streaming_insert: bool, +) -> int: + """ + Run the full extract/transform/load pipeline for a single repository. + + This is the unit of work executed per worker thread. It creates its own + ``requests.Session`` so that repos processed concurrently never share or + clobber each other's ``Authorization`` header (installation tokens are + per-repo and the header is rewritten on every refresh). + + Args: + repo: Repository in "owner/repo" format + github_app_id: GitHub App ID, or None to run unauthenticated + github_private_key: RSA private key (PEM), or None to run unauthenticated + github_api_url: GitHub API base URL + bigquery_client: Shared BigQuery client (thread-safe for queries/loads) + bigquery_dataset: BigQuery dataset ID + snapshot_date: Snapshot date string in YYYY-MM-DD format + use_streaming_insert: Whether to use streaming inserts (emulator only) + + Returns: + Number of PRs processed for this repo. + """ + # Each thread gets its own session; requests.Session is not safe to share + # across threads and we rewrite the Authorization header per repo. + session = requests.Session() + session.headers.update( + { + "Accept": "application/vnd.github+json", + "User-Agent": "gh-pr-scraper/1.0 (+https://api.github.com)", + } + ) + + # Build a per-repo token refresh callable. It is called by the generator + # before each page fetch, so every API request (PRs + commits + reviewers + + # comments) uses a valid token. The access_token_cache means this only hits + # the GitHub API when the cached token has <60 seconds remaining. + refresh_auth: Callable[[], None] | None = None + if github_app_id and github_private_key: + + def _refresh() -> None: + try: + app_jwt = generate_github_jwt(github_app_id, github_private_key) + access_token = get_installation_access_token( + app_jwt, repo, github_api_url + ) + except Exception as e: + raise RuntimeError( + f"Failed to obtain GitHub App access token for {repo}: {e}. " + "Check that GITHUB_APP_ID is correct and GITHUB_PRIVATE_KEY " + "is a valid PEM-encoded RSA private key." + ) from e + session.headers["Authorization"] = f"Bearer {access_token}" + + refresh_auth = _refresh + # Set the token immediately so the first generator page is authenticated. + refresh_auth() + + # Delete any existing rows for this (repo, snapshot_date) before loading. + # This makes every run idempotent: if a previous run crashed mid-way and left + # partial data, a rerun will clean up the partial write and reload cleanly. + if snapshot_exists(bigquery_client, bigquery_dataset, repo, snapshot_date): + logger.info( + f"Deleting partial/existing snapshot for {repo} on {snapshot_date} before reload" + ) + delete_existing_snapshot(bigquery_client, bigquery_dataset, repo, snapshot_date) + + processed = 0 + for chunk_count, chunk in enumerate( + extract_pull_requests( + session, + repo, + chunk_size=100, + github_api_url=github_api_url, + refresh_auth=refresh_auth, + ), + start=1, + ): + logger.info(f"[{repo}] Processing chunk {chunk_count} with {len(chunk)} PRs") + + # Transform + transformed_data = transform_data(chunk, repo) + + # Load + load_data( + bigquery_client, + bigquery_dataset, + transformed_data, + snapshot_date, + use_streaming_insert=use_streaming_insert, + ) + + processed += len(chunk) + logger.info( + f"[{repo}] Completed chunk {chunk_count}. PRs processed for repo: {processed}" + ) + + return processed + + def main() -> int: """ Main ETL process with chunked processing. @@ -879,16 +1021,6 @@ def _main() -> int: if not bigquery_dataset: raise SystemExit("Environment variable BIGQUERY_DATASET is required") - # Setup GitHub session; the Authorization header is updated before each repo using - # an installation access token (which may be cached) - session = requests.Session() - session.headers.update( - { - "Accept": "application/vnd.github+json", - "User-Agent": "gh-pr-scraper/1.0 (+https://api.github.com)", - } - ) - github_api_url = os.environ.get("GITHUB_API_URL", "https://api.github.com") if os.environ.get("GITHUB_API_URL"): logger.info(f"Using custom GitHub API URL: {github_api_url}") @@ -921,83 +1053,42 @@ def _main() -> int: failed_repos: list[str] = [] - for repo in github_repos: - try: - # Delete any existing rows for this (repo, snapshot_date) before loading. - # This makes every run idempotent: if a previous run crashed mid-way and left - # partial data, a rerun will clean up the partial write and reload cleanly. - if snapshot_exists(bigquery_client, bigquery_dataset, repo, snapshot_date): - logger.info( - f"Deleting partial/existing snapshot for {repo} on {snapshot_date} before reload" - ) - delete_existing_snapshot( - bigquery_client, bigquery_dataset, repo, snapshot_date - ) - - # Build a per-repo token refresh callable. It is called by the generator - # before each page fetch, so every API request (PRs + commits + reviewers + - # comments) uses a valid token. The access_token_cache means this only hits - # the GitHub API when the cached token has <60 seconds remaining. - refresh_auth: Callable[[], None] | None = None - if github_app_id and github_private_key: - - def _make_refresh( - _repo: str = repo, - ) -> Callable[[], None]: - def _refresh() -> None: - try: - app_jwt = generate_github_jwt( - github_app_id, github_private_key - ) - access_token = get_installation_access_token( - app_jwt, _repo, github_api_url - ) - except Exception as e: - raise RuntimeError( - f"Failed to obtain GitHub App access token for {_repo}: {e}. " - "Check that GITHUB_APP_ID is correct and GITHUB_PRIVATE_KEY " - "is a valid PEM-encoded RSA private key." - ) from e - session.headers["Authorization"] = f"Bearer {access_token}" - - return _refresh - - refresh_auth = _make_refresh() - # Set the token immediately so the first generator page is authenticated. - refresh_auth() + # Each repo is independent (its own session, token, and BigQuery rows keyed by + # target_repository), so they are processed concurrently. The work is I/O-bound + # (GitHub API + BigQuery), so threads — not processes — are the right fit. + max_workers = _resolve_max_workers(len(github_repos)) + logger.info( + f"Processing {len(github_repos)} repo(s) with up to {max_workers} worker(s)" + ) - for chunk_count, chunk in enumerate( - extract_pull_requests( - session, - repo, - chunk_size=100, - github_api_url=github_api_url, - refresh_auth=refresh_auth, - ), - start=1, - ): - logger.info(f"Processing chunk {chunk_count} with {len(chunk)} PRs") - - # Transform - transformed_data = transform_data(chunk, repo) - - # Load - load_data( - bigquery_client, - bigquery_dataset, - transformed_data, - snapshot_date, - use_streaming_insert=bool(emulator_host), - ) + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + future_to_repo = { + executor.submit( + process_repo, + repo, + github_app_id, + github_private_key, + github_api_url, + bigquery_client, + bigquery_dataset, + snapshot_date, + bool(emulator_host), + ): repo + for repo in github_repos + } - total_processed += len(chunk) - logger.info( - f"Completed chunk {chunk_count}. Total PRs processed: {total_processed}" - ) - except (TooManyRetriesError, RuntimeError) as exc: - logger.error(f"Failed to process repo {repo}: {exc}") - failed_repos.append(repo) - continue + for future in concurrent.futures.as_completed(future_to_repo): + repo = future_to_repo[future] + try: + processed = future.result() + except (TooManyRetriesError, RuntimeError) as exc: + logger.error(f"Failed to process repo {repo}: {exc}") + failed_repos.append(repo) + continue + total_processed += processed + logger.info( + f"Finished repo {repo}: {processed} PRs. Total so far: {total_processed}" + ) if failed_repos: logger.error( diff --git a/tests/test_main.py b/tests/test_main.py index 250905e..9c7271e 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,4 +1,5 @@ import os +import threading from unittest.mock import MagicMock, Mock, patch import pytest @@ -537,3 +538,88 @@ def extract_side_effect(*args, **kwargs): assert result == 1 # partial failure assert mock_extract.call_count == 2 # both repos were attempted mock_load.assert_called_once() # only the successful repo loaded data + + +class TestResolveMaxWorkers: + """Tests for _resolve_max_workers worker-count resolution.""" + + def test_defaults_to_repo_count_when_below_cap(self): + with patch.dict(os.environ, {}, clear=True): + assert main._resolve_max_workers(3) == 3 + + def test_caps_at_default_for_large_repo_lists(self): + with patch.dict(os.environ, {}, clear=True): + assert main._resolve_max_workers(100) == main._DEFAULT_MAX_WORKERS + + def test_never_returns_less_than_one(self): + with patch.dict(os.environ, {}, clear=True): + assert main._resolve_max_workers(0) == 1 + + def test_env_override_raises_cap(self): + with patch.dict(os.environ, {"GITHUB_ETL_MAX_WORKERS": "20"}, clear=True): + assert main._resolve_max_workers(15) == 15 + + def test_env_override_still_bounded_by_repo_count(self): + with patch.dict(os.environ, {"GITHUB_ETL_MAX_WORKERS": "20"}, clear=True): + assert main._resolve_max_workers(2) == 2 + + def test_invalid_env_override_falls_back_to_default(self): + with patch.dict(os.environ, {"GITHUB_ETL_MAX_WORKERS": "abc"}, clear=True): + assert main._resolve_max_workers(100) == main._DEFAULT_MAX_WORKERS + + def test_non_positive_env_override_falls_back_to_default(self): + with patch.dict(os.environ, {"GITHUB_ETL_MAX_WORKERS": "0"}, clear=True): + assert main._resolve_max_workers(100) == main._DEFAULT_MAX_WORKERS + + +@patch("main.setup_logging") +@patch("main.bigquery.Client") +@patch("requests.Session") +@patch("main.transform_data") +@patch("main.load_data") +def test_repos_are_processed_concurrently( + mock_load, + mock_transform, + mock_session_class, + mock_bq_client, + mock_setup_logging, +): + """Repos run in parallel: a barrier that all repos must reach proves overlap. + + If processing were sequential, the first repo would block forever at the + barrier (the others never start), so barrier.wait() would time out and the + test would fail with BrokenBarrierError. + """ + repos = "mozilla/firefox,mozilla/gecko-dev,mozilla/addons" + num_repos = len(repos.split(",")) + barrier = threading.Barrier(num_repos, timeout=5) + + def extract_side_effect(*args, **kwargs): + # Every repo's worker must reach the barrier before any may proceed. + barrier.wait() + return iter([[{"number": 1}]]) + + mock_transform.return_value = { + "pull_requests": [{"pull_request_id": 1}], + "commits": [], + "reviewers": [], + "comments": [], + } + + with ( + patch.dict( + os.environ, + { + "GITHUB_REPOS": repos, + "BIGQUERY_PROJECT": "test", + "BIGQUERY_DATASET": "test", + }, + clear=True, + ), + patch("main.extract_pull_requests", side_effect=extract_side_effect), + ): + result = main.main() + + assert result == 0 + assert not barrier.broken # all repos reached the barrier => true concurrency + assert mock_load.call_count == num_repos From 047775fbbcda6cd1fd047514da4dfbe65ff12a5c Mon Sep 17 00:00:00 2001 From: David Lawrence Date: Wed, 3 Jun 2026 16:15:44 -0400 Subject: [PATCH 2/7] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- main.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/main.py b/main.py index d7487c4..e93f988 100755 --- a/main.py +++ b/main.py @@ -885,8 +885,8 @@ def process_repo( This is the unit of work executed per worker thread. It creates its own ``requests.Session`` so that repos processed concurrently never share or - clobber each other's ``Authorization`` header (installation tokens are - per-repo and the header is rewritten on every refresh). + clobber each other's ``Authorization`` header (installation access tokens are + cached per installation, and the header is rewritten on refresh). Args: repo: Repository in "owner/repo" format From 7d1f107ed0215f1a0e6e55763990d2619aa90239 Mon Sep 17 00:00:00 2001 From: David Lawrence Date: Wed, 3 Jun 2026 16:44:51 -0400 Subject: [PATCH 3/7] Copilot review fixes --- main.py | 119 ++++++++++++++++++++++++++++----------------- tests/test_main.py | 57 ++++++++++++++++++++++ 2 files changed, 132 insertions(+), 44 deletions(-) diff --git a/main.py b/main.py index e93f988..61df340 100755 --- a/main.py +++ b/main.py @@ -10,6 +10,7 @@ import os import re import sys +import threading import time from dataclasses import dataclass from datetime import datetime, timedelta, timezone @@ -35,6 +36,9 @@ _MAX_AUTH_RETRIES: int = 2 _REQUEST_TIMEOUT: float = 30.0 +# Default ceiling on concurrent repo workers (overridable via GITHUB_ETL_MAX_WORKERS). +_DEFAULT_MAX_WORKERS: int = 8 + class TooManyRetriesError(Exception): """Raised when all retry attempts for a GitHub API request are exhausted.""" @@ -55,6 +59,12 @@ class AccessToken: access_token_cache: dict[int, AccessToken] = {} repo_installation_cache: dict[str, int] = {} +# Serializes installation-token creation across repo worker threads. Without it, +# all workers miss the cache on startup and each POST /access_tokens for the same +# installation, producing redundant token creations and a burst against GitHub's +# per-installation rate limit. The lock is only held on the slow (cache-miss) path. +_token_lock = threading.Lock() + def generate_github_jwt(app_id: str, private_key_pem: str) -> str: """ @@ -134,52 +144,72 @@ def get_installation_access_token( ) repo_installation_cache[repo] = installation_id - now = datetime.now(timezone.utc) - cached = access_token_cache.get(installation_id) - if cached is not None and cached.expires_at > now + timedelta(seconds=60): + def _cached_token() -> str | None: + cached = access_token_cache.get(installation_id) + if cached is not None and cached.expires_at > datetime.now( + timezone.utc + ) + timedelta(seconds=60): + logger.info( + f"Reusing cached access token for installation {installation_id}, " + f"expires at {cached.expires_at}" + ) + return cached.token + return None + + # Fast path: serve a still-valid cached token without taking the lock. + token = _cached_token() + if token is not None: + return token + + # Slow path: serialize creation so concurrent workers sharing an installation + # don't each POST /access_tokens. Re-check the cache once the lock is held in + # case another thread populated it while we waited. + with _token_lock: + token = _cached_token() + if token is not None: + return token + logger.info( - f"Reusing cached access token for installation {installation_id}, " - f"expires at {cached.expires_at}" + f"Fetching new GitHub App installation access token for installation {installation_id}" ) - return cached.token - - logger.info( - f"Fetching new GitHub App installation access token for installation {installation_id}" - ) - resp = session.post( - f"{github_api_url}/app/installations/{installation_id}/access_tokens", - ) - if ( - resp.status_code == 403 - and int(resp.headers.get("X-RateLimit-Remaining", "1")) == 0 - ): - sleep_for_rate_limit(resp) resp = session.post( f"{github_api_url}/app/installations/{installation_id}/access_tokens", ) - if resp.status_code != 201: - raise RuntimeError( - f"Failed to get installation access token: {resp.status_code}: {resp.text}" - ) + if ( + resp.status_code == 403 + and int(resp.headers.get("X-RateLimit-Remaining", "1")) == 0 + ): + sleep_for_rate_limit(resp) + resp = session.post( + f"{github_api_url}/app/installations/{installation_id}/access_tokens", + ) + if resp.status_code != 201: + raise RuntimeError( + f"Failed to get installation access token: {resp.status_code}: {resp.text}" + ) - try: - data = resp.json() - except requests.exceptions.JSONDecodeError as e: - raise RuntimeError(f"Failed to parse access token response: {e}: {resp.text}") - try: - access_token = AccessToken( - token=data["token"], - expires_at=datetime.fromisoformat(data["expires_at"]), - ) - except KeyError as e: - raise RuntimeError( - f"Unexpected access token response structure, missing key {e}: {resp.text}" - ) - except ValueError as e: - raise RuntimeError(f"Invalid expires_at format in access token response: {e}") - access_token_cache[installation_id] = access_token - logger.info(f"Obtained new access token, expires at {access_token.expires_at}") - return access_token.token + try: + data = resp.json() + except requests.exceptions.JSONDecodeError as e: + raise RuntimeError( + f"Failed to parse access token response: {e}: {resp.text}" + ) + try: + access_token = AccessToken( + token=data["token"], + expires_at=datetime.fromisoformat(data["expires_at"]), + ) + except KeyError as e: + raise RuntimeError( + f"Unexpected access token response structure, missing key {e}: {resp.text}" + ) + except ValueError as e: + raise RuntimeError( + f"Invalid expires_at format in access token response: {e}" + ) + access_token_cache[installation_id] = access_token + logger.info(f"Obtained new access token, expires at {access_token.expires_at}") + return access_token.token def setup_logging() -> None: @@ -837,9 +867,6 @@ def load_data( ) -_DEFAULT_MAX_WORKERS: int = 8 - - def _resolve_max_workers(repo_count: int) -> int: """ Decide how many repos to process concurrently. @@ -1081,7 +1108,11 @@ def _main() -> int: repo = future_to_repo[future] try: processed = future.result() - except (TooManyRetriesError, RuntimeError) as exc: + except Exception as exc: + # Catch broadly so one repo's failure (a TooManyRetriesError, a + # RuntimeError, or a bare Exception from load_data) is recorded as a + # failed repo rather than propagating out of the executor and + # discarding the results of other in-flight repos. logger.error(f"Failed to process repo {repo}: {exc}") failed_repos.append(repo) continue diff --git a/tests/test_main.py b/tests/test_main.py index 9c7271e..e5ca30c 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -540,6 +540,63 @@ def extract_side_effect(*args, **kwargs): mock_load.assert_called_once() # only the successful repo loaded data +@patch("main.setup_logging") +@patch("main.bigquery.Client") +@patch("requests.Session") +@patch("main.extract_pull_requests") +@patch("main.transform_data") +@patch("main.load_data") +def test_bare_exception_on_one_repo_is_isolated( + mock_load, + mock_transform, + mock_extract, + mock_session_class, + mock_bq_client, + mock_setup_logging, +): + """A bare Exception (e.g. from load_data) on one repo must not abort the others. + + The executor catches broadly, so the failing repo is recorded in failed_repos + (overall exit code 1) while the healthy repo still completes its load. + """ + # Fresh iterator per repo (a shared return_value iterator would be exhausted + # by whichever repo consumes it first). + mock_extract.side_effect = lambda *a, **k: iter([[{"number": 1}]]) + + def load_side_effect(client, dataset, transformed, *args, **kwargs): + # Fail only for firefox; gecko-dev should still load successfully. + if transformed["pull_requests"][0].get("repo_marker") == "fail": + raise Exception("BigQuery insert errors for table pull_requests") + + # Tag the transform output per repo so load_side_effect can decide which fails. + def transform_side_effect(chunk, repo): + marker = "fail" if repo == "mozilla/firefox" else "ok" + return { + "pull_requests": [{"pull_request_id": 1, "repo_marker": marker}], + "commits": [], + "reviewers": [], + "comments": [], + } + + mock_transform.side_effect = transform_side_effect + mock_load.side_effect = load_side_effect + + with patch.dict( + os.environ, + { + "GITHUB_REPOS": "mozilla/firefox,mozilla/gecko-dev", + "BIGQUERY_PROJECT": "test", + "BIGQUERY_DATASET": "test", + }, + clear=True, + ): + result = main.main() + + assert result == 1 # partial failure recorded, run did not abort + assert mock_extract.call_count == 2 # both repos were attempted + assert mock_load.call_count == 2 # both repos reached the load step + + class TestResolveMaxWorkers: """Tests for _resolve_max_workers worker-count resolution.""" From 96b9cda35a9e07c7892a0620585ab9a621f8d181 Mon Sep 17 00:00:00 2001 From: David Lawrence Date: Wed, 3 Jun 2026 17:29:17 -0400 Subject: [PATCH 4/7] More Copilot review fixes --- main.py | 43 ++++++++++++++++++++++++++++++++----------- 1 file changed, 32 insertions(+), 11 deletions(-) diff --git a/main.py b/main.py index 61df340..86461f3 100755 --- a/main.py +++ b/main.py @@ -59,11 +59,24 @@ class AccessToken: access_token_cache: dict[int, AccessToken] = {} repo_installation_cache: dict[str, int] = {} -# Serializes installation-token creation across repo worker threads. Without it, -# all workers miss the cache on startup and each POST /access_tokens for the same -# installation, producing redundant token creations and a burst against GitHub's -# per-installation rate limit. The lock is only held on the slow (cache-miss) path. -_token_lock = threading.Lock() +# Serializes installation-token creation per installation. Without it, all workers +# sharing an installation miss the cache on startup and each POST /access_tokens, +# producing redundant token creations and a burst against GitHub's per-installation +# rate limit. Locks are keyed by installation ID so a cache miss (or rate-limit +# sleep) for one installation never blocks token creation for another. A lock is +# only held on the slow (cache-miss) path. The guard protects the map itself. +_token_locks: dict[int, threading.Lock] = {} +_token_locks_guard = threading.Lock() + + +def _lock_for_installation(installation_id: int) -> threading.Lock: + """Return the (lazily created) token-creation lock for an installation.""" + with _token_locks_guard: + lock = _token_locks.get(installation_id) + if lock is None: + lock = threading.Lock() + _token_locks[installation_id] = lock + return lock def generate_github_jwt(app_id: str, private_key_pem: str) -> str: @@ -162,9 +175,11 @@ def _cached_token() -> str | None: return token # Slow path: serialize creation so concurrent workers sharing an installation - # don't each POST /access_tokens. Re-check the cache once the lock is held in - # case another thread populated it while we waited. - with _token_lock: + # don't each POST /access_tokens. The lock is specific to this installation, + # so a rate-limit sleep here never blocks workers on other installations. + # Re-check the cache once the lock is held in case another thread populated it + # while we waited. + with _lock_for_installation(installation_id): token = _cached_token() if token is not None: return token @@ -1069,7 +1084,12 @@ def _main() -> int: github_repos = [] github_repos_str = os.getenv("GITHUB_REPOS") if github_repos_str: - github_repos = [r.strip() for r in github_repos_str.split(",") if r.strip()] + # Deduplicate while preserving order: with concurrent processing, a repo + # listed twice would otherwise have its delete_existing_snapshot() and + # load_data() interleave with its duplicate, corrupting the snapshot. + github_repos = list( + dict.fromkeys(r.strip() for r in github_repos_str.split(",") if r.strip()) + ) else: raise SystemExit( "Environment variable GITHUB_REPOS is required (format: 'owner/repo,owner/repo')" @@ -1112,8 +1132,9 @@ def _main() -> int: # Catch broadly so one repo's failure (a TooManyRetriesError, a # RuntimeError, or a bare Exception from load_data) is recorded as a # failed repo rather than propagating out of the executor and - # discarding the results of other in-flight repos. - logger.error(f"Failed to process repo {repo}: {exc}") + # discarding the results of other in-flight repos. logger.exception + # records the worker thread's traceback for debugging in CI/prod. + logger.exception(f"Failed to process repo {repo}: {exc}") failed_repos.append(repo) continue total_processed += processed From 7a75aea7a5087f63f09a694b931e95c5c984bba3 Mon Sep 17 00:00:00 2001 From: David Lawrence Date: Mon, 15 Jun 2026 09:53:15 -0400 Subject: [PATCH 5/7] Review fixes --- main.py | 143 ++++++++++++++++++++++++++++++++++---------------------- 1 file changed, 88 insertions(+), 55 deletions(-) diff --git a/main.py b/main.py index 86461f3..c38579a 100755 --- a/main.py +++ b/main.py @@ -56,27 +56,70 @@ class AccessToken: expires_at: datetime -access_token_cache: dict[int, AccessToken] = {} repo_installation_cache: dict[str, int] = {} -# Serializes installation-token creation per installation. Without it, all workers -# sharing an installation miss the cache on startup and each POST /access_tokens, -# producing redundant token creations and a burst against GitHub's per-installation -# rate limit. Locks are keyed by installation ID so a cache miss (or rate-limit -# sleep) for one installation never blocks token creation for another. A lock is -# only held on the slow (cache-miss) path. The guard protects the map itself. -_token_locks: dict[int, threading.Lock] = {} -_token_locks_guard = threading.Lock() +class TokenStore: + """ + Thread-safe cache of GitHub App installation access tokens. + + Tokens are cached per installation ID and reused until they are within 60 + seconds of expiry. Token creation is serialized per installation via a lock so + that concurrent workers sharing an installation don't each POST + /access_tokens, which would produce redundant token creations and a burst + against GitHub's per-installation rate limit. Locks are keyed by installation + ID so a cache miss (or rate-limit sleep) for one installation never blocks + token creation for another. The lock is only held on the slow (cache-miss) + path; the guard protects the lock map itself. + """ + + def __init__(self) -> None: + self._tokens: dict[int, AccessToken] = {} + self._locks: dict[int, threading.Lock] = {} + self._locks_guard = threading.Lock() + + def cached_token(self, installation_id: int) -> str | None: + """Return a still-valid cached token for the installation, or None.""" + cached = self._tokens.get(installation_id) + if cached is not None and cached.expires_at > datetime.now( + timezone.utc + ) + timedelta(seconds=60): + logger.info( + f"Reusing cached access token for installation {installation_id}, " + f"expires at {cached.expires_at}" + ) + return cached.token + return None -def _lock_for_installation(installation_id: int) -> threading.Lock: - """Return the (lazily created) token-creation lock for an installation.""" - with _token_locks_guard: - lock = _token_locks.get(installation_id) - if lock is None: - lock = threading.Lock() - _token_locks[installation_id] = lock - return lock + def lock_for(self, installation_id: int) -> threading.Lock: + """Return the (lazily created) token-creation lock for an installation.""" + with self._locks_guard: + lock = self._locks.get(installation_id) + if lock is None: + lock = threading.Lock() + self._locks[installation_id] = lock + return lock + + def store(self, installation_id: int, token: AccessToken) -> None: + """Cache *token* for *installation_id*.""" + self._tokens[installation_id] = token + + +# Module-level token store shared across all worker threads. +token_store = TokenStore() + + +def _is_rate_limited(resp: requests.Response) -> bool: + """ + Return True when *resp* indicates an exhausted primary rate limit. + + GitHub signals a hit primary rate limit with either 403 or 429 plus + ``X-RateLimit-Remaining: 0``. Both status codes are treated identically. + """ + return ( + resp.status_code in (403, 429) + and int(resp.headers.get("X-RateLimit-Remaining", "1")) == 0 + ) def generate_github_jwt(app_id: str, private_key_pem: str) -> str: @@ -138,10 +181,7 @@ def get_installation_access_token( installation_id = repo_installation_cache.get(repo) if installation_id is None: resp = session.get(f"{github_api_url}/repos/{repo}/installation") - if ( - resp.status_code == 403 - and int(resp.headers.get("X-RateLimit-Remaining", "1")) == 0 - ): + if _is_rate_limited(resp): sleep_for_rate_limit(resp) resp = session.get(f"{github_api_url}/repos/{repo}/installation") if resp.status_code != 200: @@ -157,20 +197,8 @@ def get_installation_access_token( ) repo_installation_cache[repo] = installation_id - def _cached_token() -> str | None: - cached = access_token_cache.get(installation_id) - if cached is not None and cached.expires_at > datetime.now( - timezone.utc - ) + timedelta(seconds=60): - logger.info( - f"Reusing cached access token for installation {installation_id}, " - f"expires at {cached.expires_at}" - ) - return cached.token - return None - # Fast path: serve a still-valid cached token without taking the lock. - token = _cached_token() + token = token_store.cached_token(installation_id) if token is not None: return token @@ -179,8 +207,8 @@ def _cached_token() -> str | None: # so a rate-limit sleep here never blocks workers on other installations. # Re-check the cache once the lock is held in case another thread populated it # while we waited. - with _lock_for_installation(installation_id): - token = _cached_token() + with token_store.lock_for(installation_id): + token = token_store.cached_token(installation_id) if token is not None: return token @@ -190,10 +218,7 @@ def _cached_token() -> str | None: resp = session.post( f"{github_api_url}/app/installations/{installation_id}/access_tokens", ) - if ( - resp.status_code == 403 - and int(resp.headers.get("X-RateLimit-Remaining", "1")) == 0 - ): + if _is_rate_limited(resp): sleep_for_rate_limit(resp) resp = session.post( f"{github_api_url}/app/installations/{installation_id}/access_tokens", @@ -222,7 +247,7 @@ def _cached_token() -> str | None: raise RuntimeError( f"Invalid expires_at format in access token response: {e}" ) - access_token_cache[installation_id] = access_token + token_store.store(installation_id, access_token) logger.info(f"Obtained new access token, expires at {access_token.expires_at}") return access_token.token @@ -519,10 +544,7 @@ def github_get( if resp.status_code == 200: return resp - if ( - resp.status_code in (403, 429) - and int(resp.headers.get("X-RateLimit-Remaining", "1")) == 0 - ): + if _is_rate_limited(resp): sleep_for_rate_limit(resp) continue @@ -882,6 +904,24 @@ def load_data( ) +def _build_session() -> requests.Session: + """ + Create a ``requests.Session`` with the default GitHub API headers. + + Each worker thread gets its own session because ``requests.Session`` is not + safe to share across threads and the per-repo ``Authorization`` header is + rewritten on token refresh. + """ + session = requests.Session() + session.headers.update( + { + "Accept": "application/vnd.github+json", + "User-Agent": "gh-pr-scraper/1.0 (+https://api.github.com)", + } + ) + return session + + def _resolve_max_workers(repo_count: int) -> int: """ Decide how many repos to process concurrently. @@ -945,22 +985,16 @@ def process_repo( """ # Each thread gets its own session; requests.Session is not safe to share # across threads and we rewrite the Authorization header per repo. - session = requests.Session() - session.headers.update( - { - "Accept": "application/vnd.github+json", - "User-Agent": "gh-pr-scraper/1.0 (+https://api.github.com)", - } - ) + session = _build_session() # Build a per-repo token refresh callable. It is called by the generator # before each page fetch, so every API request (PRs + commits + reviewers + - # comments) uses a valid token. The access_token_cache means this only hits + # comments) uses a valid token. The token_store cache means this only hits # the GitHub API when the cached token has <60 seconds remaining. refresh_auth: Callable[[], None] | None = None if github_app_id and github_private_key: - def _refresh() -> None: + def refresh_auth() -> None: try: app_jwt = generate_github_jwt(github_app_id, github_private_key) access_token = get_installation_access_token( @@ -974,7 +1008,6 @@ def _refresh() -> None: ) from e session.headers["Authorization"] = f"Bearer {access_token}" - refresh_auth = _refresh # Set the token immediately so the first generator page is authenticated. refresh_auth() From 7a870a5b1a0bab36e18cc3c3633e765dede07457 Mon Sep 17 00:00:00 2001 From: David Lawrence Date: Mon, 15 Jun 2026 10:20:28 -0400 Subject: [PATCH 6/7] Copilot review fixes --- main.py | 217 ++++++++++++++++++++++++++++++++------------------------ 1 file changed, 123 insertions(+), 94 deletions(-) diff --git a/main.py b/main.py index c38579a..ca2fd2d 100755 --- a/main.py +++ b/main.py @@ -70,17 +70,20 @@ class TokenStore: against GitHub's per-installation rate limit. Locks are keyed by installation ID so a cache miss (or rate-limit sleep) for one installation never blocks token creation for another. The lock is only held on the slow (cache-miss) - path; the guard protects the lock map itself. + path. Two short-lived guards protect the shared dicts themselves: one for the + token cache and one for the lock map. """ def __init__(self) -> None: self._tokens: dict[int, AccessToken] = {} + self._tokens_guard = threading.Lock() self._locks: dict[int, threading.Lock] = {} self._locks_guard = threading.Lock() def cached_token(self, installation_id: int) -> str | None: """Return a still-valid cached token for the installation, or None.""" - cached = self._tokens.get(installation_id) + with self._tokens_guard: + cached = self._tokens.get(installation_id) if cached is not None and cached.expires_at > datetime.now( timezone.utc ) + timedelta(seconds=60): @@ -102,7 +105,8 @@ def lock_for(self, installation_id: int) -> threading.Lock: def store(self, installation_id: int, token: AccessToken) -> None: """Cache *token* for *installation_id*.""" - self._tokens[installation_id] = token + with self._tokens_guard: + self._tokens[installation_id] = token # Module-level token store shared across all worker threads. @@ -169,87 +173,92 @@ def get_installation_access_token( Installation access token string """ - session = requests.Session() - session.headers.update( - { - "Authorization": f"Bearer {app_jwt}", - "Accept": "application/vnd.github+json", - "X-GitHub-Api-Version": "2022-11-28", - } - ) + # Use a context manager so the temporary session's connection pool is closed + # on every return path (cache hits included), rather than leaking sockets as + # refresh_auth calls this repeatedly across concurrent repo workers. + with requests.Session() as session: + session.headers.update( + { + "Authorization": f"Bearer {app_jwt}", + "Accept": "application/vnd.github+json", + "X-GitHub-Api-Version": "2022-11-28", + } + ) - installation_id = repo_installation_cache.get(repo) - if installation_id is None: - resp = session.get(f"{github_api_url}/repos/{repo}/installation") - if _is_rate_limited(resp): - sleep_for_rate_limit(resp) + installation_id = repo_installation_cache.get(repo) + if installation_id is None: resp = session.get(f"{github_api_url}/repos/{repo}/installation") - if resp.status_code != 200: - raise RuntimeError( - f"Failed to get GitHub App installation for {repo}: " - f"{resp.status_code}: {resp.text}" - ) - try: - installation_id = resp.json()["id"] - except (requests.exceptions.JSONDecodeError, KeyError) as e: - raise RuntimeError( - f"Failed to parse installation response for {repo}: {e}: {resp.text}" - ) - repo_installation_cache[repo] = installation_id - - # Fast path: serve a still-valid cached token without taking the lock. - token = token_store.cached_token(installation_id) - if token is not None: - return token - - # Slow path: serialize creation so concurrent workers sharing an installation - # don't each POST /access_tokens. The lock is specific to this installation, - # so a rate-limit sleep here never blocks workers on other installations. - # Re-check the cache once the lock is held in case another thread populated it - # while we waited. - with token_store.lock_for(installation_id): + if _is_rate_limited(resp): + sleep_for_rate_limit(resp) + resp = session.get(f"{github_api_url}/repos/{repo}/installation") + if resp.status_code != 200: + raise RuntimeError( + f"Failed to get GitHub App installation for {repo}: " + f"{resp.status_code}: {resp.text}" + ) + try: + installation_id = resp.json()["id"] + except (requests.exceptions.JSONDecodeError, KeyError) as e: + raise RuntimeError( + f"Failed to parse installation response for {repo}: {e}: {resp.text}" + ) + repo_installation_cache[repo] = installation_id + + # Fast path: serve a still-valid cached token without taking the lock. token = token_store.cached_token(installation_id) if token is not None: return token - logger.info( - f"Fetching new GitHub App installation access token for installation {installation_id}" - ) - resp = session.post( - f"{github_api_url}/app/installations/{installation_id}/access_tokens", - ) - if _is_rate_limited(resp): - sleep_for_rate_limit(resp) + # Slow path: serialize creation so concurrent workers sharing an installation + # don't each POST /access_tokens. The lock is specific to this installation, + # so a rate-limit sleep here never blocks workers on other installations. + # Re-check the cache once the lock is held in case another thread populated it + # while we waited. + with token_store.lock_for(installation_id): + token = token_store.cached_token(installation_id) + if token is not None: + return token + + logger.info( + f"Fetching new GitHub App installation access token for installation {installation_id}" + ) resp = session.post( f"{github_api_url}/app/installations/{installation_id}/access_tokens", ) - if resp.status_code != 201: - raise RuntimeError( - f"Failed to get installation access token: {resp.status_code}: {resp.text}" - ) + if _is_rate_limited(resp): + sleep_for_rate_limit(resp) + resp = session.post( + f"{github_api_url}/app/installations/{installation_id}/access_tokens", + ) + if resp.status_code != 201: + raise RuntimeError( + f"Failed to get installation access token: {resp.status_code}: {resp.text}" + ) - try: - data = resp.json() - except requests.exceptions.JSONDecodeError as e: - raise RuntimeError( - f"Failed to parse access token response: {e}: {resp.text}" - ) - try: - access_token = AccessToken( - token=data["token"], - expires_at=datetime.fromisoformat(data["expires_at"]), - ) - except KeyError as e: - raise RuntimeError( - f"Unexpected access token response structure, missing key {e}: {resp.text}" - ) - except ValueError as e: - raise RuntimeError( - f"Invalid expires_at format in access token response: {e}" + try: + data = resp.json() + except requests.exceptions.JSONDecodeError as e: + raise RuntimeError( + f"Failed to parse access token response: {e}: {resp.text}" + ) + try: + access_token = AccessToken( + token=data["token"], + expires_at=datetime.fromisoformat(data["expires_at"]), + ) + except KeyError as e: + raise RuntimeError( + f"Unexpected access token response structure, missing key {e}: {resp.text}" + ) + except ValueError as e: + raise RuntimeError( + f"Invalid expires_at format in access token response: {e}" + ) + token_store.store(installation_id, access_token) + logger.info( + f"Obtained new access token, expires at {access_token.expires_at}" ) - token_store.store(installation_id, access_token) - logger.info(f"Obtained new access token, expires at {access_token.expires_at}") - return access_token.token + return access_token.token def setup_logging() -> None: @@ -952,6 +961,37 @@ def _resolve_max_workers(repo_count: int) -> int: return max(1, min(repo_count, cap)) +def _make_refresh_auth( + session: requests.Session, + repo: str, + github_app_id: str, + github_private_key: str, + github_api_url: str, +) -> Callable[[], None]: + """ + Build a callable that refreshes *session*'s Authorization header for *repo*. + + The returned callable is invoked by the extraction generator before each page + fetch, so every API request (PRs + commits + reviewers + comments) uses a + valid token. The token_store cache means it only hits the GitHub API when the + cached token has <60 seconds remaining. + """ + + def refresh_auth() -> None: + try: + app_jwt = generate_github_jwt(github_app_id, github_private_key) + access_token = get_installation_access_token(app_jwt, repo, github_api_url) + except Exception as e: + raise RuntimeError( + f"Failed to obtain GitHub App access token for {repo}: {e}. " + "Check that GITHUB_APP_ID is correct and GITHUB_PRIVATE_KEY " + "is a valid PEM-encoded RSA private key." + ) from e + session.headers["Authorization"] = f"Bearer {access_token}" + + return refresh_auth + + def process_repo( repo: str, github_app_id: str | None, @@ -987,27 +1027,16 @@ def process_repo( # across threads and we rewrite the Authorization header per repo. session = _build_session() - # Build a per-repo token refresh callable. It is called by the generator - # before each page fetch, so every API request (PRs + commits + reviewers + - # comments) uses a valid token. The token_store cache means this only hits - # the GitHub API when the cached token has <60 seconds remaining. - refresh_auth: Callable[[], None] | None = None - if github_app_id and github_private_key: - - def refresh_auth() -> None: - try: - app_jwt = generate_github_jwt(github_app_id, github_private_key) - access_token = get_installation_access_token( - app_jwt, repo, github_api_url - ) - except Exception as e: - raise RuntimeError( - f"Failed to obtain GitHub App access token for {repo}: {e}. " - "Check that GITHUB_APP_ID is correct and GITHUB_PRIVATE_KEY " - "is a valid PEM-encoded RSA private key." - ) from e - session.headers["Authorization"] = f"Bearer {access_token}" - + # Build a per-repo token refresh callable when running authenticated. Bound + # once here (no None-then-redeclare) so the name has a single, clear type. + refresh_auth = ( + _make_refresh_auth( + session, repo, github_app_id, github_private_key, github_api_url + ) + if github_app_id and github_private_key + else None + ) + if refresh_auth is not None: # Set the token immediately so the first generator page is authenticated. refresh_auth() From aa18cd3a0bb6835fdb2e233b382a6bfaf337d945 Mon Sep 17 00:00:00 2001 From: David Lawrence Date: Wed, 17 Jun 2026 14:35:13 -0400 Subject: [PATCH 7/7] Review fixes from shtrom --- main.py | 110 ++++++++++++++++++++++++-------------- tests/test_rate_limit.py | 86 ++++++++++++++++++++++++++++- tests/test_token_store.py | 84 +++++++++++++++++++++++++++++ 3 files changed, 237 insertions(+), 43 deletions(-) create mode 100644 tests/test_token_store.py diff --git a/main.py b/main.py index ca2fd2d..bc81a5a 100755 --- a/main.py +++ b/main.py @@ -12,6 +12,7 @@ import sys import threading import time +from collections import defaultdict from dataclasses import dataclass from datetime import datetime, timedelta, timezone from typing import Callable, Iterator @@ -77,7 +78,7 @@ class TokenStore: def __init__(self) -> None: self._tokens: dict[int, AccessToken] = {} self._tokens_guard = threading.Lock() - self._locks: dict[int, threading.Lock] = {} + self._locks: defaultdict[int, threading.Lock] = defaultdict(threading.Lock) self._locks_guard = threading.Lock() def cached_token(self, installation_id: int) -> str | None: @@ -97,11 +98,7 @@ def cached_token(self, installation_id: int) -> str | None: def lock_for(self, installation_id: int) -> threading.Lock: """Return the (lazily created) token-creation lock for an installation.""" with self._locks_guard: - lock = self._locks.get(installation_id) - if lock is None: - lock = threading.Lock() - self._locks[installation_id] = lock - return lock + return self._locks[installation_id] def store(self, installation_id: int, token: AccessToken) -> None: """Cache *token* for *installation_id*.""" @@ -115,15 +112,23 @@ def store(self, installation_id: int, token: AccessToken) -> None: def _is_rate_limited(resp: requests.Response) -> bool: """ - Return True when *resp* indicates an exhausted primary rate limit. + Return True when *resp* indicates an exhausted rate limit. - GitHub signals a hit primary rate limit with either 403 or 429 plus - ``X-RateLimit-Remaining: 0``. Both status codes are treated identically. + GitHub uses 403 or 429 for two distinct rate limits, both handled here: + + - **Primary**: signaled by ``X-RateLimit-Remaining: 0`` (with a + ``X-RateLimit-Reset`` epoch telling us when it replenishes). + - **Secondary** (abuse detection): signaled by a ``Retry-After`` header, + and frequently *without* ``X-RateLimit-Remaining: 0``. + + A 403/429 carrying neither signal (e.g. a genuine permission error) is + deliberately not treated as rate-limited. """ - return ( - resp.status_code in (403, 429) - and int(resp.headers.get("X-RateLimit-Remaining", "1")) == 0 - ) + if resp.status_code not in (403, 429): + return False + if int(resp.headers.get("X-RateLimit-Remaining", "1")) == 0: + return True + return "Retry-After" in resp.headers def generate_github_jwt(app_id: str, private_key_pem: str) -> str: @@ -176,26 +181,21 @@ def get_installation_access_token( # Use a context manager so the temporary session's connection pool is closed # on every return path (cache hits included), rather than leaking sockets as # refresh_auth calls this repeatedly across concurrent repo workers. - with requests.Session() as session: + with _build_session() as session: session.headers.update( { "Authorization": f"Bearer {app_jwt}", - "Accept": "application/vnd.github+json", - "X-GitHub-Api-Version": "2022-11-28", } ) installation_id = repo_installation_cache.get(repo) if installation_id is None: - resp = session.get(f"{github_api_url}/repos/{repo}/installation") - if _is_rate_limited(resp): - sleep_for_rate_limit(resp) - resp = session.get(f"{github_api_url}/repos/{repo}/installation") - if resp.status_code != 200: - raise RuntimeError( - f"Failed to get GitHub App installation for {repo}: " - f"{resp.status_code}: {resp.text}" - ) + # No refresh_auth: this request authenticates with the app JWT, so a 401 + # means the JWT itself is bad and refreshing an installation token would + # not help (and would recurse back into this function). + resp = github_request( + session, "GET", f"{github_api_url}/repos/{repo}/installation" + ) try: installation_id = resp.json()["id"] except (requests.exceptions.JSONDecodeError, KeyError) as e: @@ -222,18 +222,12 @@ def get_installation_access_token( logger.info( f"Fetching new GitHub App installation access token for installation {installation_id}" ) - resp = session.post( + resp = github_request( + session, + "POST", f"{github_api_url}/app/installations/{installation_id}/access_tokens", + expected_status=201, ) - if _is_rate_limited(resp): - sleep_for_rate_limit(resp) - resp = session.post( - f"{github_api_url}/app/installations/{installation_id}/access_tokens", - ) - if resp.status_code != 201: - raise RuntimeError( - f"Failed to get installation access token: {resp.status_code}: {resp.text}" - ) try: data = resp.json() @@ -479,11 +473,29 @@ def extract_comments( def sleep_for_rate_limit(resp: requests.Response) -> None: - """Sleep until rate limit resets.""" + """Sleep until the rate limit resets. + + Handles both the primary limit (``X-RateLimit-Remaining: 0`` plus an + ``X-RateLimit-Reset`` epoch) and the secondary/abuse limit (a + ``Retry-After`` header giving a delay in seconds). When both are present + the longer wait wins. + """ + sleep_time = 0 remaining = int(resp.headers.get("X-RateLimit-Remaining", 1)) reset = int(resp.headers.get("X-RateLimit-Reset", 0)) if remaining == 0: - sleep_time = max(0, reset - int(time.time())) + sleep_time = max(sleep_time, reset - int(time.time())) + + retry_after = resp.headers.get("Retry-After") + if retry_after is not None: + try: + sleep_time = max(sleep_time, int(retry_after)) + except ValueError: + # Retry-After may be an HTTP-date; ignore and fall back to reset. + pass + + sleep_time = max(0, sleep_time) + if sleep_time > 0: print( f"Rate limit exceeded. Sleeping for {sleep_time} seconds.", file=sys.stderr ) @@ -496,14 +508,17 @@ def _is_html_error_page(resp: requests.Response) -> bool: return "application/json" not in content_type and resp.status_code >= 400 -def github_get( +def github_request( session: requests.Session, + method: str, url: str, + *, params: dict | None = None, refresh_auth: Callable[[], None] | None = None, + expected_status: int = 200, ) -> requests.Response: """ - Make a GitHub API GET request, retrying on transient errors and expired tokens. + Make a GitHub API request, retrying on transient errors and expired tokens. Retry behaviour: - 403 rate-limit: sleeps until reset, then retries (unbounded, existing behaviour). @@ -533,7 +548,9 @@ def github_get( while True: try: - resp = session.get(url, params=params, timeout=_REQUEST_TIMEOUT) + resp = getattr(session, method.lower())( + url, params=params, timeout=_REQUEST_TIMEOUT + ) except ( requests.exceptions.Timeout, requests.exceptions.ConnectionError, @@ -550,7 +567,7 @@ def github_get( f"GitHub API request failed after retries for {url}: {exc}" ) - if resp.status_code == 200: + if resp.status_code == expected_status: return resp if _is_rate_limited(resp): @@ -593,6 +610,16 @@ def github_get( ) +def github_get( + session: requests.Session, + url: str, + params: dict | None = None, + refresh_auth: Callable[[], None] | None = None, +) -> requests.Response: + """Convenience wrapper around :func:`github_request` for GET requests.""" + return github_request(session, "GET", url, params=params, refresh_auth=refresh_auth) + + def transform_data(raw_data: list[dict], repo: str) -> dict: """ Transform GitHub pull request data into BigQuery-compatible format. @@ -926,6 +953,7 @@ def _build_session() -> requests.Session: { "Accept": "application/vnd.github+json", "User-Agent": "gh-pr-scraper/1.0 (+https://api.github.com)", + "X-GitHub-Api-Version": "2022-11-28", } ) return session diff --git a/tests/test_rate_limit.py b/tests/test_rate_limit.py index 606b1bd..6bc5203 100644 --- a/tests/test_rate_limit.py +++ b/tests/test_rate_limit.py @@ -34,8 +34,8 @@ def test_sleep_for_rate_limit_when_reset_already_passed(mock_sleep, mock_time): main.sleep_for_rate_limit(mock_response) - # Should sleep for 0 seconds (max of 0 and negative value) - mock_sleep.assert_called_once_with(0) + # Reset already passed -> no positive wait, so we skip sleeping entirely. + mock_sleep.assert_not_called() @patch("time.sleep") @@ -63,3 +63,85 @@ def test_sleep_for_rate_limit_with_missing_headers(mock_sleep): # Should not sleep when headers are missing (defaults to remaining=1) mock_sleep.assert_not_called() + + +@patch("time.sleep") +def test_sleep_for_rate_limit_honors_retry_after(mock_sleep): + """Secondary rate limit: sleep for the Retry-After delay.""" + mock_response = Mock() + mock_response.headers = { + "X-RateLimit-Remaining": "5", # primary limit not exhausted + "Retry-After": "30", + } + + main.sleep_for_rate_limit(mock_response) + + mock_sleep.assert_called_once_with(30) + + +@patch("time.time") +@patch("time.sleep") +def test_sleep_for_rate_limit_takes_longer_of_reset_and_retry_after( + mock_sleep, mock_time +): + """When both signals are present, the longer wait wins.""" + mock_time.return_value = 1000 + + mock_response = Mock() + mock_response.headers = { + "X-RateLimit-Remaining": "0", + "X-RateLimit-Reset": "1060", # 60 seconds from now + "Retry-After": "90", + } + + main.sleep_for_rate_limit(mock_response) + + mock_sleep.assert_called_once_with(90) + + +@patch("time.sleep") +def test_sleep_for_rate_limit_ignores_http_date_retry_after(mock_sleep): + """A Retry-After HTTP-date (non-integer) is ignored without raising.""" + mock_response = Mock() + mock_response.headers = { + "X-RateLimit-Remaining": "5", + "Retry-After": "Wed, 21 Oct 2026 07:28:00 GMT", + } + + main.sleep_for_rate_limit(mock_response) + + mock_sleep.assert_not_called() + + +def test_is_rate_limited_primary(): + """403/429 with X-RateLimit-Remaining: 0 is a primary rate limit.""" + for status in (403, 429): + resp = Mock() + resp.status_code = status + resp.headers = {"X-RateLimit-Remaining": "0"} + assert main._is_rate_limited(resp) is True + + +def test_is_rate_limited_secondary(): + """403/429 with a Retry-After header is a secondary rate limit.""" + for status in (403, 429): + resp = Mock() + resp.status_code = status + resp.headers = {"X-RateLimit-Remaining": "5", "Retry-After": "30"} + assert main._is_rate_limited(resp) is True + + +def test_is_rate_limited_plain_forbidden(): + """A 403 with no rate-limit signals is not treated as rate-limited.""" + resp = Mock() + resp.status_code = 403 + resp.headers = {} + assert main._is_rate_limited(resp) is False + + +def test_is_rate_limited_non_rate_limit_status(): + """Non-403/429 statuses are never rate limits.""" + resp = Mock() + resp.status_code = 500 + resp.headers = {"X-RateLimit-Remaining": "0", "Retry-After": "30"} + assert main._is_rate_limited(resp) is False diff --git a/tests/test_token_store.py b/tests/test_token_store.py new file mode 100644 index 0000000..24c59fe --- /dev/null +++ b/tests/test_token_store.py @@ -0,0 +1,84 @@ +""" +Tests for installation-token rate-limit detection and caching. + +Covers: + - _is_rate_limited: treats both 403 and 429 as rate-limit signals, but only + when X-RateLimit-Remaining is 0. + - TokenStore: caches valid tokens, expires tokens within the 60s skew window, + and hands out a stable per-installation lock. +""" + +from datetime import datetime, timedelta, timezone +from unittest.mock import Mock + +import main +from main import AccessToken, TokenStore, _is_rate_limited + + +def _resp(status_code, remaining=None): + resp = Mock() + resp.status_code = status_code + resp.headers = {} if remaining is None else {"X-RateLimit-Remaining": remaining} + return resp + + +def test_is_rate_limited_403_remaining_zero(): + assert _is_rate_limited(_resp(403, "0")) is True + + +def test_is_rate_limited_429_remaining_zero(): + """429 (secondary/abuse rate limit) is handled the same as 403.""" + assert _is_rate_limited(_resp(429, "0")) is True + + +def test_is_rate_limited_403_with_remaining(): + assert _is_rate_limited(_resp(403, "5")) is False + + +def test_is_rate_limited_429_with_remaining(): + assert _is_rate_limited(_resp(429, "12")) is False + + +def test_is_rate_limited_other_status(): + assert _is_rate_limited(_resp(500, "0")) is False + + +def test_is_rate_limited_missing_header_defaults_not_limited(): + """A missing X-RateLimit-Remaining header defaults to 1 (not limited).""" + assert _is_rate_limited(_resp(403)) is False + + +def test_token_store_caches_valid_token(): + store = TokenStore() + expires = datetime.now(timezone.utc) + timedelta(minutes=30) + store.store(123, AccessToken(token="abc", expires_at=expires)) + + assert store.cached_token(123) == "abc" + + +def test_token_store_misses_for_unknown_installation(): + store = TokenStore() + assert store.cached_token(999) is None + + +def test_token_store_expires_token_within_skew_window(): + """Tokens within 60s of expiry are treated as expired.""" + store = TokenStore() + expires = datetime.now(timezone.utc) + timedelta(seconds=30) + store.store(123, AccessToken(token="abc", expires_at=expires)) + + assert store.cached_token(123) is None + + +def test_token_store_lock_is_stable_per_installation(): + store = TokenStore() + lock_a = store.lock_for(1) + lock_b = store.lock_for(1) + lock_c = store.lock_for(2) + + assert lock_a is lock_b + assert lock_a is not lock_c + + +def test_module_level_token_store_exists(): + assert isinstance(main.token_store, TokenStore)