Skip to content

Commit bd2d761

Browse files
authored
Test clients can now use dispatcher (#2)
1 parent 229197b commit bd2d761

File tree

5 files changed

+228
-99
lines changed

5 files changed

+228
-99
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ options:
2121
{
2222
"name": "Example Scenario", // name of scenario
2323
"address": "127.0.0.1:8080", // port to run dispatcher on
24+
"type": "sam", // whether to use sam or denim infrastructure (valid: sam, denim)
2425
"clients": 1, // how many clients to register
2526
/* How many groups of clients that should communicate.
2627
Each group has at least one denim client that communicates with a denim client from another group.

pyproject.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@ name = "sam_dispatcher"
33
version = "0.1.0"
44
authors = [{ name = "SAM Research" }]
55

6-
dependencies = ["fastapi", "uvicorn", "pydantic"]
6+
dependencies = ["fastapi", "uvicorn", "pydantic", "asyncio"]
77

88
[project.scripts]
99
sam-dispatch = "sam_dispatcher.server:main"
10+
11+
[project.optional-dependencies]
12+
test = ["pytest", "pytest-asyncio"]

src/sam_dispatcher/server.py

Lines changed: 51 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,72 @@
11
from fastapi import FastAPI
22
from argparse import ArgumentParser
33
import uvicorn
4-
from .state import State, ClientReport
5-
from fastapi.responses import JSONResponse
6-
from fastapi import Request
4+
from .state import State, ClientReport, AccountId
5+
from fastapi import Request, Response, HTTPException
6+
import asyncio
77

88
app = FastAPI()
99
state = None
1010

1111

12+
def auth(request: Request):
13+
try:
14+
client_id = request.cookies.get("id")
15+
except:
16+
raise HTTPException(status_code=401)
17+
host = request.client.host
18+
_id = create_id(host, client_id)
19+
20+
if not state.is_auth(_id):
21+
raise HTTPException(status_code=401)
22+
return _id
23+
24+
25+
def create_id(host: str, id: str):
26+
return f"{host}#{id}"
27+
28+
1229
@app.get("/client")
13-
async def client(request: Request):
14-
client_data = state.get_client(request.client.host)
30+
async def client(request: Request, response: Response):
31+
client_id = await state.next_client_id()
32+
response.set_cookie(key="id", value=client_id)
33+
client_data = await state.get_client(create_id(request.client.host, client_id))
1534
if client_data is None:
16-
return JSONResponse(
17-
status_code=403, content={"error": "Clients have been depleted"}
18-
)
35+
raise HTTPException(status_code=403)
1936
return client_data
2037

2138

22-
@app.post("/ready")
23-
async def ready(request: Request):
24-
state.ready(request.client.host)
39+
@app.post("/id")
40+
async def upload_id(request: Request, account_id: AccountId):
41+
_id = auth(request)
42+
await state.set_account_id(_id, account_id.account_id)
2543

2644

27-
@app.get("/start")
28-
async def start():
29-
return {"start": state.clients_ready, "epoch": state.start_time}
45+
@app.get("/sync")
46+
async def sync(request: Request):
47+
return await state.start(auth(request))
3048

3149

3250
@app.post("/upload")
3351
async def upload(request: Request, report: ClientReport):
34-
state.report(request.client.host, report)
52+
_id = auth(request)
53+
await state.report(_id, report)
54+
55+
if not state.all_clients_have_uploaded:
56+
return
57+
state.save_report()
58+
59+
60+
@app.get("/health")
61+
async def health():
62+
return "OK"
63+
64+
65+
async def setup_state(path: str):
66+
global state
67+
state = State(path)
68+
await state.init_state()
69+
return state.scenario.address.split(":")
3570

3671

3772
def main():
@@ -43,8 +78,5 @@ def main():
4378
args = parser.parse_args()
4479
config_path: str = args.config
4580

46-
state = State(config_path)
47-
48-
ip, port = state.scenario.address.split(":")
49-
81+
ip, port = asyncio.run(setup_state(config_path))
5082
uvicorn.run("sam_dispatcher.server:app", host=ip, port=int(port))

0 commit comments

Comments
 (0)