diff --git a/src/api/_util/resourcelimit.py b/src/api/_util/resourcelimit.py index 12306eb6a..25ac2144a 100644 --- a/src/api/_util/resourcelimit.py +++ b/src/api/_util/resourcelimit.py @@ -1,7 +1,7 @@ from collections.abc import Iterable, Sequence from datetime import UTC, datetime -from sqlalchemy import delete, func +from sqlalchemy import delete, func, not_ from sqlalchemy.dialects.mysql import insert from sqlalchemy.ext.asyncio import AsyncConnection from sqlmodel import col, select @@ -494,17 +494,25 @@ async def get_current_organization_allocations( *, exclude_branch_ids: Sequence[Identifier] | None = None, ) -> dict[ResourceType, int]: - result = await session.execute( - select(BranchProvisioning).join(Branch).join(Project).where(Project.organization_id == organization_id) + status_column = col(Branch.status) + branch_id_column = col(BranchProvisioning.branch_id) + + stmt = ( + select(BranchProvisioning) + .join(Branch) + .join(Project) + .where( + Project.organization_id == organization_id, + not_(status_column.in_([BranchServiceStatus.STOPPED, BranchServiceStatus.DELETING])), + ) ) - rows = list(result.scalars().all()) if exclude_branch_ids: - excluded = set(exclude_branch_ids) - rows = [row for row in rows if row.branch_id not in excluded] + stmt = stmt.where(not_(branch_id_column.in_(set(exclude_branch_ids)))) + result = await session.execute(stmt) + rows = list(result.scalars().all()) grouped = _group_by_resource_type(rows) - branch_statuses = await _collect_branch_statuses(session, rows) - return _aggregate_group_by_resource_type(grouped, branch_statuses) + return _aggregate_group_by_resource_type(grouped) async def get_current_project_allocations( @@ -513,49 +521,33 @@ async def get_current_project_allocations( *, exclude_branch_ids: Sequence[Identifier] | None = None, ) -> dict[ResourceType, int]: - result = await session.execute(select(BranchProvisioning).join(Branch).where(Branch.project_id == project_id)) - rows = list(result.scalars().all()) + status_column = col(Branch.status) + branch_id_column = col(BranchProvisioning.branch_id) + + stmt = ( + select(BranchProvisioning) + .join(Branch) + .where( + Branch.project_id == project_id, + not_(status_column.in_([BranchServiceStatus.STOPPED, BranchServiceStatus.DELETING])), + ) + ) if exclude_branch_ids: - excluded = set(exclude_branch_ids) - rows = [row for row in rows if row.branch_id not in excluded] + stmt = stmt.where(not_(branch_id_column.in_(set(exclude_branch_ids)))) + result = await session.execute(stmt) + rows = list(result.scalars().all()) grouped = _group_by_resource_type(rows) - branch_statuses = await _collect_branch_statuses(session, rows) - return _aggregate_group_by_resource_type(grouped, branch_statuses) + return _aggregate_group_by_resource_type(grouped) -def _aggregate_group_by_resource_type( - grouped: dict[ResourceType, list[BranchProvisioning]], branch_statuses: dict[Identifier, BranchServiceStatus] -) -> dict[ResourceType, int]: +def _aggregate_group_by_resource_type(grouped: dict[ResourceType, list[BranchProvisioning]]) -> dict[ResourceType, int]: return { - resource_type: sum( - allocation.amount - for allocation in allocations - if (allocation.branch_id is not None) - and ( - branch_statuses.get(allocation.branch_id) - not in {BranchServiceStatus.STOPPED, BranchServiceStatus.DELETING} - ) - ) + resource_type: sum(allocation.amount for allocation in allocations if allocation.branch_id is not None) for resource_type, allocations in grouped.items() } -async def _collect_branch_statuses( - _session: SessionDep, rows: list[BranchProvisioning] -) -> dict[Identifier, BranchServiceStatus]: - branch_ids = {row.branch_id for row in rows if row.branch_id is not None} - if not branch_ids: - return {} - - from ..organization.project import branch as branch_module - - statuses: dict[Identifier, BranchServiceStatus] = {} - for branch_id in branch_ids: - statuses[branch_id] = await branch_module.refresh_branch_status(branch_id) - return statuses - - def _group_by_resource_type(allocations: list[BranchProvisioning]) -> dict[ResourceType, list[BranchProvisioning]]: result: dict[ResourceType, list[BranchProvisioning]] = {} for allocation in allocations: