11from fastapi import FastAPI
22from argparse import ArgumentParser
33import uvicorn
4- from .state import State , ClientReport
5- from fastapi . responses import JSONResponse
6- from fastapi import Request
4+ from .state import State , ClientReport , AccountId
5+ from fastapi import Request , Response , HTTPException
6+ import asyncio
77
88app = FastAPI ()
99state = None
1010
1111
12+ def auth (request : Request ):
13+ try :
14+ client_id = request .cookies .get ("id" )
15+ except :
16+ raise HTTPException (status_code = 401 )
17+ host = request .client .host
18+ _id = create_id (host , client_id )
19+
20+ if not state .is_auth (_id ):
21+ raise HTTPException (status_code = 401 )
22+ return _id
23+
24+
25+ def create_id (host : str , id : str ):
26+ return f"{ host } #{ id } "
27+
28+
1229@app .get ("/client" )
13- async def client (request : Request ):
14- client_data = state .get_client (request .client .host )
30+ async def client (request : Request , response : Response ):
31+ client_id = await state .next_client_id ()
32+ response .set_cookie (key = "id" , value = client_id )
33+ client_data = await state .get_client (create_id (request .client .host , client_id ))
1534 if client_data is None :
16- return JSONResponse (
17- status_code = 403 , content = {"error" : "Clients have been depleted" }
18- )
35+ raise HTTPException (status_code = 403 )
1936 return client_data
2037
2138
22- @app .post ("/ready" )
23- async def ready (request : Request ):
24- state .ready (request .client .host )
39+ @app .post ("/id" )
40+ async def upload_id (request : Request , account_id : AccountId ):
41+ _id = auth (request )
42+ await state .set_account_id (_id , account_id .account_id )
2543
2644
27- @app .get ("/start " )
28- async def start ( ):
29- return { "start" : state .clients_ready , "epoch" : state . start_time }
45+ @app .get ("/sync " )
46+ async def sync ( request : Request ):
47+ return await state .start ( auth ( request ))
3048
3149
3250@app .post ("/upload" )
3351async def upload (request : Request , report : ClientReport ):
34- state .report (request .client .host , report )
52+ _id = auth (request )
53+ await state .report (_id , report )
54+
55+ if not state .all_clients_have_uploaded :
56+ return
57+ state .save_report ()
58+
59+
60+ @app .get ("/health" )
61+ async def health ():
62+ return "OK"
63+
64+
65+ async def setup_state (path : str ):
66+ global state
67+ state = State (path )
68+ await state .init_state ()
69+ return state .scenario .address .split (":" )
3570
3671
3772def main ():
@@ -43,8 +78,5 @@ def main():
4378 args = parser .parse_args ()
4479 config_path : str = args .config
4580
46- state = State (config_path )
47-
48- ip , port = state .scenario .address .split (":" )
49-
81+ ip , port = asyncio .run (setup_state (config_path ))
5082 uvicorn .run ("sam_dispatcher.server:app" , host = ip , port = int (port ))
0 commit comments