Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/api/organization/project/branch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,13 @@
from ....settings import get_settings as get_api_settings
from .api_keys import api as api_key_api
from .auth import api as auth_api
from .control_tasks import (
from .tasks import (
_CONTROL_TO_POWER_STATE,
dispatch_control,
dispatch_resize,
get_control_in_progress_status,
)
from .resize_tasks import dispatch_resize
from .tasks import task_api
from .tasks import api as task_api

api = APIRouter(tags=["branch"])

Expand Down
Original file line number Diff line number Diff line change
@@ -1,23 +1,22 @@
"""Branch task list/detail endpoints.

Exposes Celery task state (resize and control) under:
GET .../branches/{branch_id}/tasks
GET .../branches/{branch_id}/tasks/{task_id}
"""

from datetime import datetime
from typing import Any
from typing import Any, Literal
from uuid import UUID

from fastapi import APIRouter, HTTPException
from pydantic import BaseModel

from ...._util import Forbidden, NotFound, Unauthenticated
from ....dependencies import BranchDep, OrganizationDep, ProjectDep
from .control_tasks import perform_control
from .resize_tasks import finalize_resize
from ....._util import Forbidden, NotFound, Unauthenticated
from .....dependencies import BranchDep, OrganizationDep, ProjectDep
from ._control import _CONTROL_TO_POWER_STATE as _CONTROL_TO_POWER_STATE
from ._control import dispatch_control as dispatch_control
from ._control import get_control_in_progress_status as get_control_in_progress_status
from ._control import perform_control
from ._resize import dispatch_resize as dispatch_resize
from ._resize import finalize_resize

task_api = APIRouter(tags=["branch"])
api = APIRouter(tags=["branch"])

TaskType = Literal["control", "resize"]

_CELERY_STATE_TO_STATUS: dict[str, str] = {
"PENDING": "PENDING",
Expand All @@ -38,40 +37,29 @@ class BranchTaskPublic(BaseModel):
date_done: datetime | None


def _build_resize_task_public(task_id: UUID) -> BranchTaskPublic:
result = finalize_resize.AsyncResult(str(task_id))
state = result.state
status = _CELERY_STATE_TO_STATUS.get(state, state)
kwargs: dict = result.kwargs or {}
return BranchTaskPublic(
id=task_id,
task_type="resize",
status=status,
parameters=kwargs.get("effective_parameters", {}),
result=result.result if state == "SUCCESS" else None,
error=str(result.traceback) if state == "FAILURE" and result.traceback else None,
date_done=result.date_done,
)

def _build_task_public(task_id: UUID, task_type: TaskType) -> BranchTaskPublic:
tasks = {
"control": perform_control,
"resize": finalize_resize,
}
result = tasks[task_type].AsyncResult(str(task_id))

def _build_control_task_public(task_id: UUID) -> BranchTaskPublic:
result = perform_control.AsyncResult(str(task_id))
state = result.state
status = _CELERY_STATE_TO_STATUS.get(state, state)
kwargs: dict = result.kwargs or {}
action = kwargs.get("action", "control")
task_type = task_type if task_type != "control" else kwargs["action"]
return BranchTaskPublic(
id=task_id,
task_type=action,
task_type=task_type,
status=status,
parameters={"action": action},
parameters=kwargs.get("effective_parameters", {}),
result=result.result if state == "SUCCESS" else None,
error=str(result.traceback) if state == "FAILURE" and result.traceback else None,
date_done=result.date_done,
)


@task_api.get(
@api.get(
"/",
name="organizations:projects:branch:tasks:list",
response_model=list[BranchTaskPublic],
Expand All @@ -82,15 +70,14 @@ async def list_tasks(
_project: ProjectDep,
branch: BranchDep,
) -> list[BranchTaskPublic]:
tasks = []
if branch.resize_task_id is not None:
tasks.append(_build_resize_task_public(branch.resize_task_id))
if branch.control_task_id is not None:
tasks.append(_build_control_task_public(branch.control_task_id))
return tasks
tasks: list[tuple[UUID | None, TaskType]] = [
(branch.control_task_id, "control"),
(branch.resize_task_id, "resize"),
]
return [_build_task_public(task_id, task_type) for task_id, task_type in tasks if task_id is not None]


@task_api.get(
@api.get(
"/{task_id}",
name="organizations:projects:branch:tasks:detail",
response_model=BranchTaskPublic,
Expand All @@ -103,7 +90,7 @@ async def get_task(
task_id: UUID,
) -> BranchTaskPublic:
if branch.resize_task_id == task_id:
return _build_resize_task_public(task_id)
return _build_task_public(task_id, "resize")
if branch.control_task_id == task_id:
return _build_control_task_public(task_id)
return _build_task_public(task_id, "control")
raise HTTPException(status_code=404, detail="Task not found")
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@
from sqlalchemy.exc import NoResultFound
from ulid import ULID

from .....database import AsyncSessionLocal
from .....deployment import get_autoscaler_vm_identity
from .....deployment.health import query_deployment_status
from .....deployment.kubernetes.neonvm import Phase, PowerState, get_neon_vm, set_virtualmachine_power_state
from .....models.branch import BranchServiceStatus
from .....models.branch import lookup as branch_lookup
from .....worker import app
from ......database import AsyncSessionLocal
from ......deployment import get_autoscaler_vm_identity
from ......deployment.health import query_deployment_status
from ......deployment.kubernetes.neonvm import Phase, PowerState, get_neon_vm, set_virtualmachine_power_state
from ......models.branch import BranchServiceStatus
from ......models.branch import lookup as branch_lookup
from ......worker import app

logger = logging.getLogger(__name__)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@
from celery import chord
from ulid import ULID

from .....database import AsyncSessionLocal
from .....deployment.health import collect_branch_service_health, derive_branch_status_from_services
from .....deployment.resize import resize_cpu_memory, resize_database_pvc, resize_iops, resize_storage_pvc
from .....models.branch import Branch
from .....models.resources import ResourceLimitsPublic
from .....worker import app
from ...._util.resourcelimit import apply_branch_resource_allocation
from ......database import AsyncSessionLocal
from ......deployment.health import collect_branch_service_health, derive_branch_status_from_services
from ......deployment.resize import resize_cpu_memory, resize_database_pvc, resize_iops, resize_storage_pvc
from ......models.branch import Branch
from ......models.resources import ResourceLimitsPublic
from ......worker import app
from ....._util.resourcelimit import apply_branch_resource_allocation

logger = logging.getLogger(__name__)

Expand Down
3 changes: 1 addition & 2 deletions src/worker/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,5 @@ class Settings(BaseSettings):
app.conf.beat_schedule_filename = "/tmp/celerybeat-schedule"

# Register tasks — must be imported after `app` is defined.
from ..api.organization.project.branch import control_tasks as _api_control_tasks # noqa: E402, F401
from ..api.organization.project.branch import resize_tasks as _api_resize_tasks # noqa: E402, F401
from ..api.organization.project.branch import tasks as _api_branch_tasks # noqa: E402, F401
from ..deployment import resize as _deployment_resize # noqa: E402, F401
Loading