From 6fada2dab369b822bf26a3cff6de81d6bdc6370c Mon Sep 17 00:00:00 2001 From: Abby Seseri Date: Fri, 13 Jun 2025 17:29:01 -0700 Subject: [PATCH 1/2] feat(auth): Implement JWT authentication and protected endpoints --- Makefile | 6 +- backend/auth/__init__.py | 0 backend/auth/router.py | 56 ++++++++++++++++ backend/auth/security.py | 37 +++++++++++ backend/create_admin.py | 21 ++++++ backend/database.py | 16 +++-- backend/main.py | 96 ++++++++++++++++------------ backend/models.py | 8 ++- backend/pytest.ini | 4 +- backend/requirements.txt | 8 +++ backend/tests/conftest.py | 84 +++++++++++++++++++++--- backend/tests/test_auth.py | 65 +++++++++++++++++++ backend/tests/test_main.py | 15 +++-- backend/tests/test_players_api.py | 44 +++++-------- backend/tests/test_similarity_api.py | 40 ++++++------ 15 files changed, 384 insertions(+), 116 deletions(-) create mode 100644 backend/auth/__init__.py create mode 100644 backend/auth/router.py create mode 100644 backend/auth/security.py create mode 100644 backend/create_admin.py create mode 100644 backend/tests/test_auth.py diff --git a/Makefile b/Makefile index a0fe00a..c4d6a16 100644 --- a/Makefile +++ b/Makefile @@ -37,7 +37,7 @@ logs: ## Follow logs for all services # ============================================================================== # DEVELOPMENT & TESTING # ============================================================================== -setup: up seed-db build-model ## Run this once to setup a new LOCAL environment from scratch +setup: up seed-db build-model create-admin ## Run this once to setup a new LOCAL environment from scratch @echo "$(GREEN)✅ Initial local setup complete! Database is seeded and model is built.$(RESET)" seed-db: ## Run the database seed script on the LOCAL docker DB @@ -48,6 +48,10 @@ build-model: ## Run the ML model training script @echo "$(YELLOW)--> Building similarity model...$(RESET)" @docker compose exec backend python build_similarity_model.py +create-admin: + @echo "$(YELLOW)--> Creating admin user...$(RESET)" + @docker compose exec backend python create_admin.py + test: test-backend test-frontend ## Run all backend and frontend tests test-backend: ## Run backend python tests diff --git a/backend/auth/__init__.py b/backend/auth/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/auth/router.py b/backend/auth/router.py new file mode 100644 index 0000000..c364112 --- /dev/null +++ b/backend/auth/router.py @@ -0,0 +1,56 @@ +# backend/auth/router.py +from fastapi import APIRouter, Depends, HTTPException, status +from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm +from sqlalchemy.orm import Session +from typing import Annotated +from jose import JWTError, jwt +from typing import List + +from . import security +import models +from database import get_db + +# This creates a new "router" that we can include in our main app +router = APIRouter() + +# Define the token endpoint URL and the security scheme +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/token") + +@router.post("/token") +def login_for_access_token(form_data: Annotated[OAuth2PasswordRequestForm, Depends()], db: Session = Depends(get_db)): + """ + Handles the login request. Takes username and password from a form body. + """ + user = db.query(models.User).filter(models.User.username == form_data.username).first() + if not user or not security.verify_password(form_data.password, user.hashed_password): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Incorrect username or password", + headers={"WWW-Authenticate": "Bearer"}, + ) + access_token = security.create_access_token(data={"sub": user.username}) + return {"access_token": access_token, "token_type": "bearer"} + +async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)], db: Session = Depends(get_db)): + """ + This is the dependency that will protect our routes. + It decodes the token, validates it, and fetches the user from the database. + """ + credentials_exception = HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Could not validate credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) + try: + payload = jwt.decode(token, security.SECRET_KEY, algorithms=[security.ALGORITHM]) + username: str = payload.get("sub") + if username is None: + raise credentials_exception + except JWTError: + raise credentials_exception + + user = db.query(models.User).filter(models.User.username == username).first() + if user is None: + raise credentials_exception + return user + diff --git a/backend/auth/security.py b/backend/auth/security.py new file mode 100644 index 0000000..645fcfa --- /dev/null +++ b/backend/auth/security.py @@ -0,0 +1,37 @@ +# backend/auth/security.py +from datetime import datetime, timedelta, timezone +from typing import Optional +from passlib.context import CryptContext +from jose import JWTError, jwt +import os + +# --- Configuration --- +# This is a hardcoded key for development. In a real production environment, +# you would load this from a secure environment variable or a secret manager. +SECRET_KEY = os.getenv("SECRET_KEY", "a_very_secret_key_for_a_portfolio_project") +ALGORITHM = "HS256" +ACCESS_TOKEN_EXPIRE_MINUTES = 30 + +# Setup for password hashing. bcrypt is the standard, secure algorithm. +pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") + +def verify_password(plain_password: str, hashed_password: str) -> bool: + """Verifies a plain password against a hashed one.""" + return pwd_context.verify(plain_password, hashed_password) + +def get_password_hash(password: str) -> str: + """Hashes a plain password.""" + return pwd_context.hash(password) + +# --- JWT Token Logic --- +def create_access_token(data: dict, expires_delta: Optional[timedelta] = None): + """Creates a new JWT access token.""" + to_encode = data.copy() + if expires_delta: + expire = datetime.now(timezone.utc) + expires_delta + else: + expire = datetime.now(timezone.utc) + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) + + to_encode.update({"exp": expire}) + encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) + return encoded_jwt diff --git a/backend/create_admin.py b/backend/create_admin.py new file mode 100644 index 0000000..bdf40bc --- /dev/null +++ b/backend/create_admin.py @@ -0,0 +1,21 @@ +# backend/create_admin.py +from database import SessionLocal, engine +from models import User, Base +from auth.security import get_password_hash + +db = SessionLocal() +# Ensure the 'users' table exists +Base.metadata.create_all(bind=engine) + +# Check if user already exists +if db.query(User).filter(User.username == "admin").first(): + print("Admin user already exists.") +else: + # Create a new user with a hashed password + hashed_password = get_password_hash("your_secret_password") # Choose a strong password + admin_user = User(username="admin", hashed_password=hashed_password) + db.add(admin_user) + db.commit() + print("Admin user created successfully.") + +db.close() diff --git a/backend/database.py b/backend/database.py index 16a1334..84147a4 100644 --- a/backend/database.py +++ b/backend/database.py @@ -32,9 +32,15 @@ def get_database_url(): return "postgresql://admin:password123@db:5432/wnba_db" DATABASE_URL = get_database_url() - -# --- Use the DATABASE_URL variable --- -engine = create_engine(DATABASE_URL) - +connect_args = {"check_same_thread": False} if "sqlite" in DATABASE_URL else {} +engine = create_engine(DATABASE_URL, connect_args=connect_args) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) -Base = declarative_base() \ No newline at end of file +Base = declarative_base() + +def get_db(): + """Dependency to get a database session.""" + db = SessionLocal() + try: + yield db + finally: + db.close() diff --git a/backend/main.py b/backend/main.py index 4cc734e..47f3727 100644 --- a/backend/main.py +++ b/backend/main.py @@ -8,7 +8,7 @@ from fastapi.middleware.cors import CORSMiddleware from sqlalchemy.orm import Session from pydantic import BaseModel, ConfigDict -from typing import List +from typing import List, Annotated import joblib import pandas as pd @@ -17,6 +17,9 @@ # Import your SQLAlchemy models and session management import models import database +from database import get_db +from auth.router import router as auth_router # Import our new auth router +from auth.router import get_current_user # Import our new dependency # --- The Lifespan function now loads the model --- @asynccontextmanager @@ -41,9 +44,12 @@ async def lifespan(app: FastAPI): logger.info("Application shutdown.") app = FastAPI(lifespan=lifespan) +app.include_router(auth_router, prefix="/auth", tags=["Authentication"]) -origins = ["http://localhost:3000", - "https://wnba-frontend-service-776933261932.us-west1.run.app"] +origins = [ + "http://localhost:3000", + "https://wnba-frontend-service-776933261932.us-west1.run.app" +] app.add_middleware( CORSMiddleware, allow_origins=origins, @@ -52,14 +58,6 @@ async def lifespan(app: FastAPI): allow_headers=["*"], ) -# Dependency to get a database session for each request -def get_db(): - db = database.SessionLocal() - try: - yield db - finally: - db.close() - # --- Pydantic Schemas --- class PlayerStatBase(BaseModel): season: str @@ -95,44 +93,24 @@ class Player(PlayerBase): stats: List[PlayerStat] = [] model_config = ConfigDict(from_attributes=True) -# ---- API ENDPOINTS ---- -@app.get("/api") -def read_root(): - return {"message": "WNBA Analytics API is running!"} +class UserOut(BaseModel): + id: int + username: str + model_config = ConfigDict(from_attributes=True) +# ---- PROTECTED API ENDPOINTS ---- # Endpoint to CREATE a new player @app.post("/api/players", response_model=Player) -def create_player(player: PlayerCreate, db: Session = Depends(get_db)): +def create_player(player: PlayerCreate, current_user: Annotated[models.User, Depends(get_current_user)], db: Session = Depends(get_db)): new_player = models.Player(**player.model_dump()) db.add(new_player) db.commit() db.refresh(new_player) return new_player -# Endpoint to READ all players -@app.get("/api/players", response_model=List[Player]) -def get_players(db: Session = Depends(get_db)): - players = db.query(models.Player).all() - return players - -@app.get("/api/players/{player_id}", response_model=Player) -def get_player(player_id: int, db: Session = Depends(get_db)): - player = db.query(models.Player).filter(models.Player.id == player_id).first() - if player is None: raise HTTPException(status_code=404, detail="Player not found") - return player - -# Endpoint to DELETE a player -@app.delete("/api/players/{player_id}") -def delete_player(player_id: int, db: Session = Depends(get_db)): - player_to_delete = db.query(models.Player).filter(models.Player.id == player_id).first() - if player_to_delete is None: raise HTTPException(status_code=404, detail="Player not found") - db.delete(player_to_delete) - db.commit() - return {"message": "Player deleted successfully"} - -# Endpoint to UPDATE a player +# Endpoint to UPDATE a player @app.put("/api/players/{player_id}", response_model=Player) -def update_player(player_id: int, player_update: PlayerCreate, db: Session = Depends(get_db)): +def update_player(player_id: int, player_update: PlayerCreate, current_user: Annotated[models.User, Depends(get_current_user)], db: Session = Depends(get_db)): player_to_update = db.query(models.Player).filter(models.Player.id == player_id).first() if player_to_update is None: raise HTTPException(status_code=404, detail="Player not found") @@ -144,9 +122,17 @@ def update_player(player_id: int, player_update: PlayerCreate, db: Session = Dep db.refresh(player_to_update) return player_to_update +# Endpoint to DELETE a player +@app.delete("/api/players/{player_id}") +def delete_player(player_id: int, current_user: Annotated[models.User, Depends(get_current_user)], db: Session = Depends(get_db)): + player_to_delete = db.query(models.Player).filter(models.Player.id == player_id).first() + if player_to_delete is None: raise HTTPException(status_code=404, detail="Player not found") + db.delete(player_to_delete) + db.commit() + return {"message": "Player deleted successfully"} + @app.post("/api/players/{player_id}/stats", response_model=PlayerStat) -def create_stats_for_player(player_id: int, stat: PlayerStatCreate, - db: Session = Depends(get_db)): +def create_stats_for_player(player_id: int, stat: PlayerStatCreate, current_user: Annotated[models.User, Depends(get_current_user)], db: Session = Depends(get_db)): db_player = db.query(models.Player).filter(models.Player.id == player_id).first() if db_player is None: raise HTTPException(status_code=404, detail="Player not found") db_stat = models.PlayerStat(**stat.model_dump(), player_id=player_id) @@ -155,6 +141,34 @@ def create_stats_for_player(player_id: int, stat: PlayerStatCreate, db.refresh(db_stat) return db_stat +@app.get("/users", response_model=List[UserOut]) +def read_users(current_user: Annotated[models.User, Depends(get_current_user)], db: Session = Depends(get_db)): + """ + Retrieves a list of all users. + This is a protected endpoint that requires authentication. + The `response_model` ensures that only the fields from `UserOut` (id, username) + are returned, protecting the hashed password. + """ + users = db.query(models.User).all() + return users + +# ---- PUBLIC API ENDPOINTS ---- +@app.get("/api") +def read_root(): + return {"message": "WNBA Analytics API is running!"} + +# Endpoint to READ all players +@app.get("/api/players", response_model=List[Player]) +def get_players(db: Session = Depends(get_db)): + players = db.query(models.Player).all() + return players + +@app.get("/api/players/{player_id}", response_model=Player) +def get_player(player_id: int, db: Session = Depends(get_db)): + player = db.query(models.Player).filter(models.Player.id == player_id).first() + if player is None: raise HTTPException(status_code=404, detail="Player not found") + return player + # A Pydantic schema for the similarity response class SimilarPlayer(BaseModel): player_season_id: str diff --git a/backend/models.py b/backend/models.py index c77bc15..41239a0 100644 --- a/backend/models.py +++ b/backend/models.py @@ -32,4 +32,10 @@ class PlayerStat(Base): player_efficiency_rating = Column(Float) player_id = Column(Integer, ForeignKey("players.id")) - player = relationship("Player", back_populates="stats") \ No newline at end of file + player = relationship("Player", back_populates="stats") + +class User(Base): + __tablename__ = "users" + id = Column(Integer, primary_key=True, index=True) + username = Column(String, unique=True, index=True) + hashed_password = Column(String) diff --git a/backend/pytest.ini b/backend/pytest.ini index 9ccee47..cf012a4 100644 --- a/backend/pytest.ini +++ b/backend/pytest.ini @@ -1,3 +1,5 @@ # backend/pytest.ini [pytest] -pythonpath = . \ No newline at end of file +pythonpath = . +# This tells pytest to ignore any DeprecationWarning originating from the passlib library +filterwarnings = ignore::DeprecationWarning:passlib.* \ No newline at end of file diff --git a/backend/requirements.txt b/backend/requirements.txt index 55a39b3..ef585d9 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -1,11 +1,14 @@ annotated-types==0.7.0 anyio==4.9.0 +bcrypt==4.0.1 certifi==2025.4.26 click==8.2.1 dnspython==2.7.0 +ecdsa==0.19.1 email_validator==2.2.0 fastapi==0.115.12 fastapi-cli==0.0.7 +greenlet==3.2.3 h11==0.16.0 httpcore==1.0.9 httptools==0.6.4 @@ -15,6 +18,7 @@ iniconfig==2.1.0 itsdangerous==2.2.0 Jinja2==3.1.6 joblib==1.5.1 +jose==1.0.0 markdown-it-py==3.0.0 MarkupSafe==3.0.2 mdurl==0.1.2 @@ -22,8 +26,10 @@ numpy==2.3.0 orjson==3.10.18 packaging==25.0 pandas==2.3.0 +passlib==1.7.4 pluggy==1.6.0 psycopg2-binary==2.9.10 +pyasn1==0.6.1 pydantic==2.11.5 pydantic-extra-types==2.10.5 pydantic-settings==2.9.1 @@ -32,11 +38,13 @@ Pygments==2.19.1 pytest==8.4.0 python-dateutil==2.9.0.post0 python-dotenv==1.1.0 +python-jose==3.5.0 python-multipart==0.0.20 pytz==2025.2 PyYAML==6.0.2 rich==14.0.0 rich-toolkit==0.14.7 +rsa==4.9.1 scikit-learn==1.7.0 scipy==1.15.3 shellingham==1.5.4 diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index eceb2b7..531ae54 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -1,39 +1,103 @@ # backend/tests/conftest.py import pytest import os -from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker -# Set the environment variable BEFORE other imports. +# This MUST be the first thing to run. It configures the app for testing. os.environ['DATABASE_URL'] = "sqlite:///./test.db" +from fastapi.testclient import TestClient +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + from main import app, get_db from database import Base, engine +from auth.security import get_password_hash +import models TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) @pytest.fixture(scope="function") def db_session(): """ - This fixture creates a clean database with all tables for a single test function, - and then drops all tables after the test is done. + The single fixture to provide a clean database session for each test. + It creates all tables before the test and drops them all after. + This guarantees perfect test isolation. """ + # --- SETUP --- Base.metadata.create_all(bind=engine) - # This is the key: we override the app's dependency to use our - # clean, temporary database for the duration of the test. def override_get_db(): + db = TestingSessionLocal() try: - db = TestingSessionLocal() yield db finally: db.close() + # Override the dependency for the duration of the test app.dependency_overrides[get_db] = override_get_db - # Yield nothing, just perform setup and teardown yield - # Teardown: clean up the override and drop tables + # --- TEARDOWN --- + del app.dependency_overrides[get_db] + Base.metadata.drop_all(bind=engine) + +@pytest.fixture(scope="function") +def test_client(): + """ + Provides a TestClient with a clean database for each test. + Handles startup/shutdown events and DB table creation/destruction. + """ + # --- SETUP --- + Base.metadata.create_all(bind=engine) + + def override_get_db(): + """This function overrides the production database dependency.""" + try: + db = TestingSessionLocal() + yield db + finally: + db.close() + + app.dependency_overrides[get_db] = override_get_db + + with TestClient(app) as client: + yield client + + # --- TEARDOWN --- + # Clean up after the test is done del app.dependency_overrides[get_db] Base.metadata.drop_all(bind=engine) + + +@pytest.fixture(scope="function") +def authenticated_client(test_client): + """ + Provides an authenticated TestClient. It uses the regular test_client + and performs a login to get an auth token. + """ + # We need a direct session to create the user in the database first + db = TestingSessionLocal() + password = "testpassword" + + # Create user if it doesn't exist to avoid IntegrityError + user = db.query(models.User).filter(models.User.username == "testuser").first() + if not user: + user = models.User(username="testuser", hashed_password=get_password_hash(password)) + db.add(user) + db.commit() + + db.close() + + # Log in as the test user to get a token + login_response = test_client.post( + "/auth/token", + data={"username": "testuser", "password": password} + ) + assert login_response.status_code == 200, "Failed to log in during test setup" + token = login_response.json()["access_token"] + + # Set the auth header for all future requests with this client + test_client.headers["Authorization"] = f"Bearer {token}" + + return test_client \ No newline at end of file diff --git a/backend/tests/test_auth.py b/backend/tests/test_auth.py new file mode 100644 index 0000000..906978a --- /dev/null +++ b/backend/tests/test_auth.py @@ -0,0 +1,65 @@ +# backend/tests/test_auth.py +from auth.security import get_password_hash +import models +from database import SessionLocal + +def test_login_for_access_token(test_client): + """Tests if a user can successfully log in with correct credentials.""" + # Setup: Create a user directly in the test database + db = SessionLocal() + password = "correct_password" + user = models.User(username="logintestuser", hashed_password=get_password_hash(password)) + db.add(user) + db.commit() + db.close() + + # Execute & Assert + response = test_client.post( + "/auth/token", + data={"username": "logintestuser", "password": "correct_password"} + ) + assert response.status_code == 200, response.text + assert "access_token" in response.json() + assert response.json()["token_type"] == "bearer" + +def test_login_with_wrong_password(test_client): + """Tests that login fails with an incorrect password.""" + db = SessionLocal() + user = models.User(username="wrongpassuser", hashed_password=get_password_hash("correct_password")) + db.add(user) + db.commit() + db.close() + + response = test_client.post( + "/auth/token", + data={"username": "wrongpassuser", "password": "incorrect_password"} + ) + assert response.status_code == 401 # Unauthorized + +def test_read_users_as_authenticated_user(authenticated_client): + """ + Tests that a logged-in user can successfully fetch the list of users. + """ + # The authenticated_client fixture has already created and logged in a "testuser" + response = authenticated_client.get("/users") + + # Assert success + assert response.status_code == 200 + data = response.json() + + # Assert the response format is correct + assert isinstance(data, list) + assert len(data) > 0 + assert data[0]["username"] == "testuser" + + # CRITICAL: Assert that the hashed password is NOT present + assert "hashed_password" not in data[0] + +def test_read_users_unauthenticated(test_client): + """ + Tests that a non-logged-in user receives a 401 error when trying + to access the protected user list. + """ + # We use the basic, unauthenticated client here + response = test_client.get("/users") + assert response.status_code == 401 # Unauthorized \ No newline at end of file diff --git a/backend/tests/test_main.py b/backend/tests/test_main.py index 7a10005..bdf90b2 100644 --- a/backend/tests/test_main.py +++ b/backend/tests/test_main.py @@ -1,11 +1,12 @@ # backend/tests/test_main.py -from fastapi.testclient import TestClient -from main import app -# This test doesn't need a database, so it doesn't need a fixture. -def test_read_root(): - # It creates its own client. - client = TestClient(app) - response = client.get("/api") +def test_read_root(test_client): + """Tests the public root API endpoint.""" + response = test_client.get("/api") assert response.status_code == 200 assert response.json() == {"message": "WNBA Analytics API is running!"} + +def test_get_players_publicly(test_client): + """Tests that anyone can view the list of players.""" + response = test_client.get("/api/players") + assert response.status_code == 200 \ No newline at end of file diff --git a/backend/tests/test_players_api.py b/backend/tests/test_players_api.py index 3574fbe..b63af6f 100644 --- a/backend/tests/test_players_api.py +++ b/backend/tests/test_players_api.py @@ -1,32 +1,18 @@ # backend/tests/test_players_api.py -from fastapi.testclient import TestClient -from main import app -# Each test asks for the `db_session` fixture to ensure the database is clean. -def test_create_and_get_player(db_session): - client = TestClient(app) # Create a client for this test - create_response = client.post("/api/players", json={"first_name": "Caitlin", "last_name": "Clark", "team": "Indiana Fever"}) - assert create_response.status_code == 200 - player_id = create_response.json()["id"] +def test_create_player_as_authenticated_user(authenticated_client): + """Tests that a logged-in user can create a player.""" + response = authenticated_client.post( + "/api/players", + json={"first_name": "Caitlin", "last_name": "Clark", "team": "Indiana Fever"} + ) + assert response.status_code == 200 + assert response.json()["first_name"] == "Caitlin" - get_response = client.get(f"/api/players/{player_id}") - assert get_response.status_code == 200 - assert get_response.json()["first_name"] == "Caitlin" - -def test_get_nonexistent_player_returns_404(db_session): - client = TestClient(app) - response = client.get("/api/players/99999") - assert response.status_code == 404 - -def test_add_and_get_stats_for_player(db_session): - client = TestClient(app) - player_res = client.post("/api/players", json={"first_name": "Sabrina", "last_name": "Ionescu", "team": "New York Liberty"}) - player_id = player_res.json()["id"] - stats_payload = { - "season": "2024", "points_per_game": 19, "rebounds_per_game": 4, "assists_per_game": 7, - "games_played": 30, "games_started": 30, "field_goal_percentage": 0.45, - "three_point_percentage": 0.35, "steals_per_game": 1.5, "blocks_per_game": 0.8, - "player_efficiency_rating": 22.5 - } - stats_res = client.post(f"/api/players/{player_id}/stats", json=stats_payload) - assert stats_res.status_code == 200 +def test_create_player_unauthenticated(test_client): + """Tests that a non-logged-in user cannot create a player.""" + response = test_client.post( + "/api/players", + json={"first_name": "Caitlin", "last_name": "Clark", "team": "Indiana Fever"} + ) + assert response.status_code == 401 # Unauthorized \ No newline at end of file diff --git a/backend/tests/test_similarity_api.py b/backend/tests/test_similarity_api.py index 0f20afd..88dec94 100644 --- a/backend/tests/test_similarity_api.py +++ b/backend/tests/test_similarity_api.py @@ -3,37 +3,35 @@ from unittest.mock import patch import pandas as pd import numpy as np -import pytest -from main import app +from main import app, get_db +import models +from auth.security import get_password_hash -# This test also asks for `db_session` to ensure the database is clean. def test_get_similar_players(db_session): - mock_player_vectors = pd.DataFrame( - {'pts': [0.9, -0.5, 0.85], 'reb': [0.8, -0.6, 0.9]}, - index=['Player A (2024)', 'Player B (2024)', 'Player C (2024)'] - ) - mock_matrix = np.array([[1.0, 0.1, 0.95], [0.1, 1.0, 0.2], [0.95, 0.2, 1.0]]) + mock_player_vectors = pd.DataFrame(index=['Player A (2024)', 'Player C (2024)']) + mock_matrix = np.array([[1.0, 0.95], [0.95, 1.0]]) - # 1. We set up the mock using 'with patch' with patch('main.joblib.load') as mock_joblib_load: - def side_effect(filename): - if "data" in filename: return mock_player_vectors - if "matrix" in filename: return mock_matrix - mock_joblib_load.side_effect = side_effect + mock_joblib_load.side_effect = [mock_player_vectors, mock_matrix] - # 2. We create the TestClient INSIDE the patch block. - # This is the critical part that forces the app's startup - # event to run while our mock is active. with TestClient(app) as client: - # 3. Setup the DB with a player - player_res = client.post("/api/players", json={"first_name": "Player", "last_name": "A", "team": "Team A"}) + # Create user, log in, get headers + db = client.app.dependency_overrides[get_db]().__next__() + user = models.User(username="testuser", hashed_password=get_password_hash("password")) + db.add(user) + db.commit() + login_res = client.post("/auth/token", data={"username": "testuser", "password": "password"}) + token = login_res.json()["access_token"] + headers = {"Authorization": f"Bearer {token}"} + db.close() + + player_res = client.post("/api/players", headers=headers, json={"first_name": "Player", "last_name": "A", "team": "Team A"}) assert player_res.status_code == 200 player_id = player_res.json()["id"] - # 4. Call the API endpoint + # Call the similarity endpoint response = client.get(f"/api/players/{player_id}/seasons/2024/similar") - # 5. Assert the response assert response.status_code == 200 data = response.json() - assert data[0]["player_season_id"] == "Player C (2024)" + assert data[0]["player_season_id"] == "Player C (2024)" \ No newline at end of file From b43bb18a1784a3ab3fcfffa854bfe7a1f21444bc Mon Sep 17 00:00:00 2001 From: Abby Seseri Date: Tue, 17 Jun 2025 13:43:40 -0700 Subject: [PATCH 2/2] feat(ui): Implement persistent theme and conditional admin controls --- .gitignore | 2 +- Makefile | 5 ++- backend/create_admin.py | 43 ++++++++++++------ frontend/src/AuthContext.js | 45 +++++++++++++++++++ frontend/src/ThemeContext.js | 9 +++- frontend/src/components/App.js | 2 + frontend/src/components/App.test.js | 15 ++++--- frontend/src/components/Layout.js | 17 ++++++- .../src/components/LoginPage/LoginPage.js | 39 ++++++++++++++++ .../components/LoginPage/LoginPage.test.js | 45 +++++++++++++++++++ .../src/components/RosterPage/RosterPage.js | 22 +++++---- frontend/src/index.js | 9 ++-- 12 files changed, 217 insertions(+), 36 deletions(-) create mode 100644 frontend/src/AuthContext.js create mode 100644 frontend/src/components/LoginPage/LoginPage.js create mode 100644 frontend/src/components/LoginPage/LoginPage.test.js diff --git a/.gitignore b/.gitignore index f4f04a2..6581526 100644 --- a/.gitignore +++ b/.gitignore @@ -22,4 +22,4 @@ backend/tests.http *.joblib # Environment Variables -.env +**/.env diff --git a/Makefile b/Makefile index c4d6a16..9132c03 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,8 @@ # Makefile for WNBA Analytics Project +include .env +export + # Color codes for prettier output GREEN := $(shell tput -T xterm setaf 2) YELLOW := $(shell tput -T xterm setaf 3) @@ -50,7 +53,7 @@ build-model: ## Run the ML model training script create-admin: @echo "$(YELLOW)--> Creating admin user...$(RESET)" - @docker compose exec backend python create_admin.py + @docker compose exec -e ADMIN_USER=$(ADMIN_USER) -e ADMIN_PASSWORD=$(ADMIN_PASSWORD) backend python create_admin.py test: test-backend test-frontend ## Run all backend and frontend tests diff --git a/backend/create_admin.py b/backend/create_admin.py index bdf40bc..1c25ed7 100644 --- a/backend/create_admin.py +++ b/backend/create_admin.py @@ -1,21 +1,36 @@ # backend/create_admin.py +import os from database import SessionLocal, engine from models import User, Base from auth.security import get_password_hash +import logging -db = SessionLocal() -# Ensure the 'users' table exists -Base.metadata.create_all(bind=engine) +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) -# Check if user already exists -if db.query(User).filter(User.username == "admin").first(): - print("Admin user already exists.") -else: - # Create a new user with a hashed password - hashed_password = get_password_hash("your_secret_password") # Choose a strong password - admin_user = User(username="admin", hashed_password=hashed_password) - db.add(admin_user) - db.commit() - print("Admin user created successfully.") +ADMIN_USERNAME = os.getenv("ADMIN_USER", "admin") +ADMIN_PASSWORD = os.getenv("ADMIN_PASSWORD") -db.close() +def create_admin_user(): + if not ADMIN_PASSWORD: + logger.error("ADMIN_PASSWORD environment variable not set. Cannot create admin user.") + return + + db = SessionLocal() + Base.metadata.create_all(bind=engine) # Ensure the 'users' table exists + + # Check if user already exists + if db.query(User).filter(User.username == ADMIN_USERNAME).first(): + logger.warning(f"Admin user '{ADMIN_USERNAME}' already exists.") + else: + # Create a new user with a hashed password + hashed_password = get_password_hash(ADMIN_PASSWORD) + admin_user = User(username=ADMIN_USERNAME, hashed_password=hashed_password) + db.add(admin_user) + db.commit() + logger.info(f"Admin user '{ADMIN_USERNAME}' created successfully.") + + db.close() + +if __name__ == "__main__": + create_admin_user() diff --git a/frontend/src/AuthContext.js b/frontend/src/AuthContext.js new file mode 100644 index 0000000..99befda --- /dev/null +++ b/frontend/src/AuthContext.js @@ -0,0 +1,45 @@ +// frontend/src/AuthContext.js +import React, { createContext, useState, useContext, useMemo } from 'react'; + +export const AuthContext = createContext(null); + +export const AuthProvider = ({ children }) => { + const [token, setToken] = useState(localStorage.getItem('authToken')); + + const login = async (username, password) => { + const response = await fetch(`${process.env.REACT_APP_API_BASE_URL}/auth/token`, { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body: new URLSearchParams({ username, password }) + }); + if (!response.ok) { + const errorData = await response.json(); + throw new Error(errorData.detail || 'Failed to login'); + } + const data = await response.json(); + setToken(data.access_token); + localStorage.setItem('authToken', data.access_token); + }; + + const logout = () => { + setToken(null); + localStorage.removeItem('authToken'); + }; + + // useMemo ensures this object is not recreated on every render + const authContextValue = useMemo(() => ({ + token, + login, + logout + }), [token]); + + return ( + + {children} + + ); +}; + +export const useAuth = () => { + return useContext(AuthContext); +}; \ No newline at end of file diff --git a/frontend/src/ThemeContext.js b/frontend/src/ThemeContext.js index b23b78b..3bbca66 100644 --- a/frontend/src/ThemeContext.js +++ b/frontend/src/ThemeContext.js @@ -8,12 +8,17 @@ export const ThemeContext = createContext({ }); export const AppThemeProvider = ({ children }) => { - const [mode, setMode] = useState('dark'); // Default to dark mode + const storedMode = localStorage.getItem('themeMode') || 'dark'; // Default to dark mode + const [mode, setMode] = useState(storedMode); const theme = useMemo(() => (mode === 'light' ? lightTheme : darkTheme), [mode]); const toggleTheme = () => { - setMode((prevMode) => (prevMode === 'light' ? 'dark' : 'light')); + setMode((prevMode) => { + const newMode = prevMode === 'light' ? 'dark' : 'light'; + localStorage.setItem('themeMode', newMode); + return newMode; + }); }; return ( diff --git a/frontend/src/components/App.js b/frontend/src/components/App.js index 2865882..187964e 100644 --- a/frontend/src/components/App.js +++ b/frontend/src/components/App.js @@ -4,6 +4,7 @@ import { Routes, Route } from 'react-router-dom'; import Layout from './Layout'; import RosterPage from './RosterPage/RosterPage'; import PlayerDetailPage from './PlayerDetailPage/PlayerDetailPage'; +import LoginPage from './LoginPage/LoginPage'; import './App.css'; function App() { @@ -12,6 +13,7 @@ function App() { } /> } /> + } /> ); diff --git a/frontend/src/components/App.test.js b/frontend/src/components/App.test.js index 454f8b7..b2f28f2 100644 --- a/frontend/src/components/App.test.js +++ b/frontend/src/components/App.test.js @@ -2,13 +2,18 @@ import { render, screen } from '@testing-library/react'; import { MemoryRouter } from 'react-router-dom'; import App from './App'; +import { AuthProvider } from '../AuthContext'; +import { AppThemeProvider } from '../ThemeContext'; test('renders the RosterPage for the home route', () => { render( - - - + + + + + + + ); - // We expect the main heading from the RosterPage to be present expect(screen.getByRole('heading', { name: /WNBA Player Roster/i })).toBeInTheDocument(); -}); +}); \ No newline at end of file diff --git a/frontend/src/components/Layout.js b/frontend/src/components/Layout.js index 29e98ee..d1f3564 100644 --- a/frontend/src/components/Layout.js +++ b/frontend/src/components/Layout.js @@ -1,14 +1,22 @@ // frontend/src/components/Layout.js import React, { useContext } from 'react'; -import { Box, AppBar, Toolbar, Typography, IconButton, Container } from '@mui/material'; +import { Box, AppBar, Toolbar, Typography, IconButton, Container, Button } from '@mui/material'; import { Brightness4, Brightness7 } from '@mui/icons-material'; import { useTheme } from '@mui/material/styles'; import { ThemeContext } from '../ThemeContext'; -import { Link } from 'react-router-dom'; +import { Link, useNavigate } from 'react-router-dom'; +import { useAuth } from '../AuthContext'; const Layout = ({ children }) => { const theme = useTheme(); const colorMode = useContext(ThemeContext); + const auth = useAuth(); + const navigate = useNavigate(); + + const handleLogout = () => { + auth.logout(); + navigate('/'); + }; return ( @@ -22,6 +30,11 @@ const Layout = ({ children }) => { {theme.palette.mode === 'dark' ? : } + {auth.token ? ( + + ) : ( + + )} diff --git a/frontend/src/components/LoginPage/LoginPage.js b/frontend/src/components/LoginPage/LoginPage.js new file mode 100644 index 0000000..2183938 --- /dev/null +++ b/frontend/src/components/LoginPage/LoginPage.js @@ -0,0 +1,39 @@ +// frontend/src/components/LoginPage/LoginPage.js +import React, { useState } from 'react'; +import { useNavigate } from 'react-router-dom'; +import { useAuth } from '../../AuthContext'; +import { Box, Button, TextField, Typography, Paper } from '@mui/material'; + +function LoginPage() { + const [username, setUsername] = useState(''); + const [password, setPassword] = useState(''); + const [error, setError] = useState(''); + const auth = useAuth(); + const navigate = useNavigate(); + + const handleSubmit = async (event) => { + event.preventDefault(); + setError(''); + try { + await auth.login(username, password); + navigate('/'); // Redirect to homepage on successful login + } catch (err) { + setError('Invalid username or password.'); + } + }; + + return ( + + + Admin Login + + setUsername(e.target.value)} required /> + setPassword(e.target.value)} required /> + {error && {error}} + + + + + ); +} +export default LoginPage; \ No newline at end of file diff --git a/frontend/src/components/LoginPage/LoginPage.test.js b/frontend/src/components/LoginPage/LoginPage.test.js new file mode 100644 index 0000000..61c0c41 --- /dev/null +++ b/frontend/src/components/LoginPage/LoginPage.test.js @@ -0,0 +1,45 @@ +// frontend/src/components/LoginPage/LoginPage.test.js +import React from 'react'; +import { render, screen, fireEvent, waitFor } from '@testing-library/react'; +import { MemoryRouter } from 'react-router-dom'; +import { AuthContext } from '../../AuthContext'; +import LoginPage from './LoginPage'; + +// We create a mock login function using Jest's built-in mocking +const mockLogin = jest.fn(); +const mockNavigate = jest.fn(); + +// Mock the useNavigate hook from react-router-dom +jest.mock('react-router-dom', () => ({ + ...jest.requireActual('react-router-dom'), + useNavigate: () => mockNavigate, +})); + +const renderWithAuthProvider = (component) => { + return render( + + + {component} + + + ); +}; + +test('calls login function and navigates on successful submission', async () => { + // Configure our mock login function to simulate a successful login + mockLogin.mockResolvedValue(true); + + renderWithAuthProvider(); + + fireEvent.change(screen.getByLabelText(/Username/i), { target: { value: 'admin' } }); + fireEvent.change(screen.getByLabelText(/Password/i), { target: { value: 'password123' } }); + fireEvent.click(screen.getByRole('button', { name: /Login/i })); + + // Wait for the login function to be called and check arguments + await waitFor(() => { + expect(mockLogin).toHaveBeenCalledWith('admin', 'password123'); + }); + + // Check that it redirected to the homepage + expect(mockNavigate).toHaveBeenCalledWith('/'); +}); \ No newline at end of file diff --git a/frontend/src/components/RosterPage/RosterPage.js b/frontend/src/components/RosterPage/RosterPage.js index 02d29bb..2be4eae 100644 --- a/frontend/src/components/RosterPage/RosterPage.js +++ b/frontend/src/components/RosterPage/RosterPage.js @@ -2,6 +2,7 @@ import React, { useState, useEffect } from 'react'; import { Link } from 'react-router-dom'; // Import Link import { Box, Button, TextField, Typography, Grid, Card, CardContent, CardActions } from '@mui/material'; +import { useAuth } from '../../AuthContext'; function RosterPage() { const [players, setPlayers] = useState([]); @@ -9,7 +10,7 @@ function RosterPage() { const [lastName, setLastName] = useState(''); const [team, setTeam] = useState(''); const [editingPlayer, setEditingPlayer] = useState(null); // State to track which player is being edited - + const auth = useAuth(); const fetchPlayers = () => { fetch(`${process.env.REACT_APP_API_BASE_URL}/api/players`) @@ -24,6 +25,10 @@ function RosterPage() { // This function handles both creating AND updating const handleFormSubmit = (event) => { event.preventDefault(); + const headers = { + 'Content-Type': 'application/json', + 'Authorization': `Bearer ${auth.token}` + }; // If we are editing, call the update logic if (editingPlayer) { @@ -35,7 +40,7 @@ function RosterPage() { fetch(`${process.env.REACT_APP_API_BASE_URL}/api/players/${editingPlayer.id}`, { method: 'PUT', - headers: { 'Content-Type': 'application/json' }, + headers: headers, body: JSON.stringify(updatedPlayer), }) .then(response => response.json()) @@ -54,7 +59,7 @@ function RosterPage() { fetch(`${process.env.REACT_APP_API_BASE_URL}/api/players`, { method: 'POST', - headers: { 'Content-Type': 'application/json' }, + headers: headers, body: JSON.stringify(newPlayer), }) .then(response => response.json()) @@ -83,7 +88,8 @@ function RosterPage() { }; const handleDelete = (playerId) => { - fetch(`${process.env.REACT_APP_API_BASE_URL}/api/players/${playerId}`, { method: 'DELETE' }) + const headers = { 'Authorization': `Bearer ${auth.token}` }; + fetch(`${process.env.REACT_APP_API_BASE_URL}/api/players/${playerId}`, { method: 'DELETE', headers: headers }) .then(response => { if (response.ok) fetchPlayers(); }) .catch(error => console.error('Error deleting player:', error)); }; @@ -94,7 +100,7 @@ function RosterPage() { {editingPlayer ? 'Edit Player' : 'WNBA Player Roster'} - + {auth.token && setFirstName(e.target.value)} required /> @@ -110,7 +116,7 @@ function RosterPage() { {editingPlayer && } - + } {players.map(player => ( @@ -124,10 +130,10 @@ function RosterPage() { {player.team} - + {auth.token && ( - + )} ))} diff --git a/frontend/src/index.js b/frontend/src/index.js index 3808073..1a7defd 100644 --- a/frontend/src/index.js +++ b/frontend/src/index.js @@ -5,14 +5,17 @@ import App from './components/App'; import reportWebVitals from './reportWebVitals'; import {BrowserRouter} from 'react-router-dom'; import { AppThemeProvider } from './ThemeContext'; +import { AuthProvider } from './AuthContext'; const root = ReactDOM.createRoot(document.getElementById('root')); root.render( - - - + + + + + );