diff --git a/.gitignore b/.gitignore
index f4f04a2..6581526 100644
--- a/.gitignore
+++ b/.gitignore
@@ -22,4 +22,4 @@ backend/tests.http
*.joblib
# Environment Variables
-.env
+**/.env
diff --git a/Makefile b/Makefile
index a0fe00a..9132c03 100644
--- a/Makefile
+++ b/Makefile
@@ -1,5 +1,8 @@
# Makefile for WNBA Analytics Project
+include .env
+export
+
# Color codes for prettier output
GREEN := $(shell tput -T xterm setaf 2)
YELLOW := $(shell tput -T xterm setaf 3)
@@ -37,7 +40,7 @@ logs: ## Follow logs for all services
# ==============================================================================
# DEVELOPMENT & TESTING
# ==============================================================================
-setup: up seed-db build-model ## Run this once to setup a new LOCAL environment from scratch
+setup: up seed-db build-model create-admin ## Run this once to setup a new LOCAL environment from scratch
@echo "$(GREEN)✅ Initial local setup complete! Database is seeded and model is built.$(RESET)"
seed-db: ## Run the database seed script on the LOCAL docker DB
@@ -48,6 +51,10 @@ build-model: ## Run the ML model training script
@echo "$(YELLOW)--> Building similarity model...$(RESET)"
@docker compose exec backend python build_similarity_model.py
+create-admin:
+ @echo "$(YELLOW)--> Creating admin user...$(RESET)"
+ @docker compose exec -e ADMIN_USER=$(ADMIN_USER) -e ADMIN_PASSWORD=$(ADMIN_PASSWORD) backend python create_admin.py
+
test: test-backend test-frontend ## Run all backend and frontend tests
test-backend: ## Run backend python tests
diff --git a/backend/auth/__init__.py b/backend/auth/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/backend/auth/router.py b/backend/auth/router.py
new file mode 100644
index 0000000..c364112
--- /dev/null
+++ b/backend/auth/router.py
@@ -0,0 +1,56 @@
+# backend/auth/router.py
+from fastapi import APIRouter, Depends, HTTPException, status
+from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
+from sqlalchemy.orm import Session
+from typing import Annotated
+from jose import JWTError, jwt
+from typing import List
+
+from . import security
+import models
+from database import get_db
+
+# This creates a new "router" that we can include in our main app
+router = APIRouter()
+
+# Define the token endpoint URL and the security scheme
+oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/token")
+
+@router.post("/token")
+def login_for_access_token(form_data: Annotated[OAuth2PasswordRequestForm, Depends()], db: Session = Depends(get_db)):
+ """
+ Handles the login request. Takes username and password from a form body.
+ """
+ user = db.query(models.User).filter(models.User.username == form_data.username).first()
+ if not user or not security.verify_password(form_data.password, user.hashed_password):
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail="Incorrect username or password",
+ headers={"WWW-Authenticate": "Bearer"},
+ )
+ access_token = security.create_access_token(data={"sub": user.username})
+ return {"access_token": access_token, "token_type": "bearer"}
+
+async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)], db: Session = Depends(get_db)):
+ """
+ This is the dependency that will protect our routes.
+ It decodes the token, validates it, and fetches the user from the database.
+ """
+ credentials_exception = HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail="Could not validate credentials",
+ headers={"WWW-Authenticate": "Bearer"},
+ )
+ try:
+ payload = jwt.decode(token, security.SECRET_KEY, algorithms=[security.ALGORITHM])
+ username: str = payload.get("sub")
+ if username is None:
+ raise credentials_exception
+ except JWTError:
+ raise credentials_exception
+
+ user = db.query(models.User).filter(models.User.username == username).first()
+ if user is None:
+ raise credentials_exception
+ return user
+
diff --git a/backend/auth/security.py b/backend/auth/security.py
new file mode 100644
index 0000000..645fcfa
--- /dev/null
+++ b/backend/auth/security.py
@@ -0,0 +1,37 @@
+# backend/auth/security.py
+from datetime import datetime, timedelta, timezone
+from typing import Optional
+from passlib.context import CryptContext
+from jose import JWTError, jwt
+import os
+
+# --- Configuration ---
+# This is a hardcoded key for development. In a real production environment,
+# you would load this from a secure environment variable or a secret manager.
+SECRET_KEY = os.getenv("SECRET_KEY", "a_very_secret_key_for_a_portfolio_project")
+ALGORITHM = "HS256"
+ACCESS_TOKEN_EXPIRE_MINUTES = 30
+
+# Setup for password hashing. bcrypt is the standard, secure algorithm.
+pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
+
+def verify_password(plain_password: str, hashed_password: str) -> bool:
+ """Verifies a plain password against a hashed one."""
+ return pwd_context.verify(plain_password, hashed_password)
+
+def get_password_hash(password: str) -> str:
+ """Hashes a plain password."""
+ return pwd_context.hash(password)
+
+# --- JWT Token Logic ---
+def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
+ """Creates a new JWT access token."""
+ to_encode = data.copy()
+ if expires_delta:
+ expire = datetime.now(timezone.utc) + expires_delta
+ else:
+ expire = datetime.now(timezone.utc) + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
+
+ to_encode.update({"exp": expire})
+ encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
+ return encoded_jwt
diff --git a/backend/create_admin.py b/backend/create_admin.py
new file mode 100644
index 0000000..1c25ed7
--- /dev/null
+++ b/backend/create_admin.py
@@ -0,0 +1,36 @@
+# backend/create_admin.py
+import os
+from database import SessionLocal, engine
+from models import User, Base
+from auth.security import get_password_hash
+import logging
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+ADMIN_USERNAME = os.getenv("ADMIN_USER", "admin")
+ADMIN_PASSWORD = os.getenv("ADMIN_PASSWORD")
+
+def create_admin_user():
+ if not ADMIN_PASSWORD:
+ logger.error("ADMIN_PASSWORD environment variable not set. Cannot create admin user.")
+ return
+
+ db = SessionLocal()
+ Base.metadata.create_all(bind=engine) # Ensure the 'users' table exists
+
+ # Check if user already exists
+ if db.query(User).filter(User.username == ADMIN_USERNAME).first():
+ logger.warning(f"Admin user '{ADMIN_USERNAME}' already exists.")
+ else:
+ # Create a new user with a hashed password
+ hashed_password = get_password_hash(ADMIN_PASSWORD)
+ admin_user = User(username=ADMIN_USERNAME, hashed_password=hashed_password)
+ db.add(admin_user)
+ db.commit()
+ logger.info(f"Admin user '{ADMIN_USERNAME}' created successfully.")
+
+ db.close()
+
+if __name__ == "__main__":
+ create_admin_user()
diff --git a/backend/database.py b/backend/database.py
index 16a1334..84147a4 100644
--- a/backend/database.py
+++ b/backend/database.py
@@ -32,9 +32,15 @@ def get_database_url():
return "postgresql://admin:password123@db:5432/wnba_db"
DATABASE_URL = get_database_url()
-
-# --- Use the DATABASE_URL variable ---
-engine = create_engine(DATABASE_URL)
-
+connect_args = {"check_same_thread": False} if "sqlite" in DATABASE_URL else {}
+engine = create_engine(DATABASE_URL, connect_args=connect_args)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
-Base = declarative_base()
\ No newline at end of file
+Base = declarative_base()
+
+def get_db():
+ """Dependency to get a database session."""
+ db = SessionLocal()
+ try:
+ yield db
+ finally:
+ db.close()
diff --git a/backend/main.py b/backend/main.py
index 4cc734e..47f3727 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -8,7 +8,7 @@
from fastapi.middleware.cors import CORSMiddleware
from sqlalchemy.orm import Session
from pydantic import BaseModel, ConfigDict
-from typing import List
+from typing import List, Annotated
import joblib
import pandas as pd
@@ -17,6 +17,9 @@
# Import your SQLAlchemy models and session management
import models
import database
+from database import get_db
+from auth.router import router as auth_router # Import our new auth router
+from auth.router import get_current_user # Import our new dependency
# --- The Lifespan function now loads the model ---
@asynccontextmanager
@@ -41,9 +44,12 @@ async def lifespan(app: FastAPI):
logger.info("Application shutdown.")
app = FastAPI(lifespan=lifespan)
+app.include_router(auth_router, prefix="/auth", tags=["Authentication"])
-origins = ["http://localhost:3000",
- "https://wnba-frontend-service-776933261932.us-west1.run.app"]
+origins = [
+ "http://localhost:3000",
+ "https://wnba-frontend-service-776933261932.us-west1.run.app"
+]
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
@@ -52,14 +58,6 @@ async def lifespan(app: FastAPI):
allow_headers=["*"],
)
-# Dependency to get a database session for each request
-def get_db():
- db = database.SessionLocal()
- try:
- yield db
- finally:
- db.close()
-
# --- Pydantic Schemas ---
class PlayerStatBase(BaseModel):
season: str
@@ -95,44 +93,24 @@ class Player(PlayerBase):
stats: List[PlayerStat] = []
model_config = ConfigDict(from_attributes=True)
-# ---- API ENDPOINTS ----
-@app.get("/api")
-def read_root():
- return {"message": "WNBA Analytics API is running!"}
+class UserOut(BaseModel):
+ id: int
+ username: str
+ model_config = ConfigDict(from_attributes=True)
+# ---- PROTECTED API ENDPOINTS ----
# Endpoint to CREATE a new player
@app.post("/api/players", response_model=Player)
-def create_player(player: PlayerCreate, db: Session = Depends(get_db)):
+def create_player(player: PlayerCreate, current_user: Annotated[models.User, Depends(get_current_user)], db: Session = Depends(get_db)):
new_player = models.Player(**player.model_dump())
db.add(new_player)
db.commit()
db.refresh(new_player)
return new_player
-# Endpoint to READ all players
-@app.get("/api/players", response_model=List[Player])
-def get_players(db: Session = Depends(get_db)):
- players = db.query(models.Player).all()
- return players
-
-@app.get("/api/players/{player_id}", response_model=Player)
-def get_player(player_id: int, db: Session = Depends(get_db)):
- player = db.query(models.Player).filter(models.Player.id == player_id).first()
- if player is None: raise HTTPException(status_code=404, detail="Player not found")
- return player
-
-# Endpoint to DELETE a player
-@app.delete("/api/players/{player_id}")
-def delete_player(player_id: int, db: Session = Depends(get_db)):
- player_to_delete = db.query(models.Player).filter(models.Player.id == player_id).first()
- if player_to_delete is None: raise HTTPException(status_code=404, detail="Player not found")
- db.delete(player_to_delete)
- db.commit()
- return {"message": "Player deleted successfully"}
-
-# Endpoint to UPDATE a player
+# Endpoint to UPDATE a player
@app.put("/api/players/{player_id}", response_model=Player)
-def update_player(player_id: int, player_update: PlayerCreate, db: Session = Depends(get_db)):
+def update_player(player_id: int, player_update: PlayerCreate, current_user: Annotated[models.User, Depends(get_current_user)], db: Session = Depends(get_db)):
player_to_update = db.query(models.Player).filter(models.Player.id == player_id).first()
if player_to_update is None: raise HTTPException(status_code=404, detail="Player not found")
@@ -144,9 +122,17 @@ def update_player(player_id: int, player_update: PlayerCreate, db: Session = Dep
db.refresh(player_to_update)
return player_to_update
+# Endpoint to DELETE a player
+@app.delete("/api/players/{player_id}")
+def delete_player(player_id: int, current_user: Annotated[models.User, Depends(get_current_user)], db: Session = Depends(get_db)):
+ player_to_delete = db.query(models.Player).filter(models.Player.id == player_id).first()
+ if player_to_delete is None: raise HTTPException(status_code=404, detail="Player not found")
+ db.delete(player_to_delete)
+ db.commit()
+ return {"message": "Player deleted successfully"}
+
@app.post("/api/players/{player_id}/stats", response_model=PlayerStat)
-def create_stats_for_player(player_id: int, stat: PlayerStatCreate,
- db: Session = Depends(get_db)):
+def create_stats_for_player(player_id: int, stat: PlayerStatCreate, current_user: Annotated[models.User, Depends(get_current_user)], db: Session = Depends(get_db)):
db_player = db.query(models.Player).filter(models.Player.id == player_id).first()
if db_player is None: raise HTTPException(status_code=404, detail="Player not found")
db_stat = models.PlayerStat(**stat.model_dump(), player_id=player_id)
@@ -155,6 +141,34 @@ def create_stats_for_player(player_id: int, stat: PlayerStatCreate,
db.refresh(db_stat)
return db_stat
+@app.get("/users", response_model=List[UserOut])
+def read_users(current_user: Annotated[models.User, Depends(get_current_user)], db: Session = Depends(get_db)):
+ """
+ Retrieves a list of all users.
+ This is a protected endpoint that requires authentication.
+ The `response_model` ensures that only the fields from `UserOut` (id, username)
+ are returned, protecting the hashed password.
+ """
+ users = db.query(models.User).all()
+ return users
+
+# ---- PUBLIC API ENDPOINTS ----
+@app.get("/api")
+def read_root():
+ return {"message": "WNBA Analytics API is running!"}
+
+# Endpoint to READ all players
+@app.get("/api/players", response_model=List[Player])
+def get_players(db: Session = Depends(get_db)):
+ players = db.query(models.Player).all()
+ return players
+
+@app.get("/api/players/{player_id}", response_model=Player)
+def get_player(player_id: int, db: Session = Depends(get_db)):
+ player = db.query(models.Player).filter(models.Player.id == player_id).first()
+ if player is None: raise HTTPException(status_code=404, detail="Player not found")
+ return player
+
# A Pydantic schema for the similarity response
class SimilarPlayer(BaseModel):
player_season_id: str
diff --git a/backend/models.py b/backend/models.py
index c77bc15..41239a0 100644
--- a/backend/models.py
+++ b/backend/models.py
@@ -32,4 +32,10 @@ class PlayerStat(Base):
player_efficiency_rating = Column(Float)
player_id = Column(Integer, ForeignKey("players.id"))
- player = relationship("Player", back_populates="stats")
\ No newline at end of file
+ player = relationship("Player", back_populates="stats")
+
+class User(Base):
+ __tablename__ = "users"
+ id = Column(Integer, primary_key=True, index=True)
+ username = Column(String, unique=True, index=True)
+ hashed_password = Column(String)
diff --git a/backend/pytest.ini b/backend/pytest.ini
index 9ccee47..cf012a4 100644
--- a/backend/pytest.ini
+++ b/backend/pytest.ini
@@ -1,3 +1,5 @@
# backend/pytest.ini
[pytest]
-pythonpath = .
\ No newline at end of file
+pythonpath = .
+# This tells pytest to ignore any DeprecationWarning originating from the passlib library
+filterwarnings = ignore::DeprecationWarning:passlib.*
\ No newline at end of file
diff --git a/backend/requirements.txt b/backend/requirements.txt
index 55a39b3..ef585d9 100644
--- a/backend/requirements.txt
+++ b/backend/requirements.txt
@@ -1,11 +1,14 @@
annotated-types==0.7.0
anyio==4.9.0
+bcrypt==4.0.1
certifi==2025.4.26
click==8.2.1
dnspython==2.7.0
+ecdsa==0.19.1
email_validator==2.2.0
fastapi==0.115.12
fastapi-cli==0.0.7
+greenlet==3.2.3
h11==0.16.0
httpcore==1.0.9
httptools==0.6.4
@@ -15,6 +18,7 @@ iniconfig==2.1.0
itsdangerous==2.2.0
Jinja2==3.1.6
joblib==1.5.1
+jose==1.0.0
markdown-it-py==3.0.0
MarkupSafe==3.0.2
mdurl==0.1.2
@@ -22,8 +26,10 @@ numpy==2.3.0
orjson==3.10.18
packaging==25.0
pandas==2.3.0
+passlib==1.7.4
pluggy==1.6.0
psycopg2-binary==2.9.10
+pyasn1==0.6.1
pydantic==2.11.5
pydantic-extra-types==2.10.5
pydantic-settings==2.9.1
@@ -32,11 +38,13 @@ Pygments==2.19.1
pytest==8.4.0
python-dateutil==2.9.0.post0
python-dotenv==1.1.0
+python-jose==3.5.0
python-multipart==0.0.20
pytz==2025.2
PyYAML==6.0.2
rich==14.0.0
rich-toolkit==0.14.7
+rsa==4.9.1
scikit-learn==1.7.0
scipy==1.15.3
shellingham==1.5.4
diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py
index eceb2b7..531ae54 100644
--- a/backend/tests/conftest.py
+++ b/backend/tests/conftest.py
@@ -1,39 +1,103 @@
# backend/tests/conftest.py
import pytest
import os
-from sqlalchemy import create_engine
-from sqlalchemy.orm import sessionmaker
-# Set the environment variable BEFORE other imports.
+# This MUST be the first thing to run. It configures the app for testing.
os.environ['DATABASE_URL'] = "sqlite:///./test.db"
+from fastapi.testclient import TestClient
+from sqlalchemy import create_engine
+from sqlalchemy.orm import sessionmaker
+
from main import app, get_db
from database import Base, engine
+from auth.security import get_password_hash
+import models
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
@pytest.fixture(scope="function")
def db_session():
"""
- This fixture creates a clean database with all tables for a single test function,
- and then drops all tables after the test is done.
+ The single fixture to provide a clean database session for each test.
+ It creates all tables before the test and drops them all after.
+ This guarantees perfect test isolation.
"""
+ # --- SETUP ---
Base.metadata.create_all(bind=engine)
- # This is the key: we override the app's dependency to use our
- # clean, temporary database for the duration of the test.
def override_get_db():
+ db = TestingSessionLocal()
try:
- db = TestingSessionLocal()
yield db
finally:
db.close()
+ # Override the dependency for the duration of the test
app.dependency_overrides[get_db] = override_get_db
- # Yield nothing, just perform setup and teardown
yield
- # Teardown: clean up the override and drop tables
+ # --- TEARDOWN ---
+ del app.dependency_overrides[get_db]
+ Base.metadata.drop_all(bind=engine)
+
+@pytest.fixture(scope="function")
+def test_client():
+ """
+ Provides a TestClient with a clean database for each test.
+ Handles startup/shutdown events and DB table creation/destruction.
+ """
+ # --- SETUP ---
+ Base.metadata.create_all(bind=engine)
+
+ def override_get_db():
+ """This function overrides the production database dependency."""
+ try:
+ db = TestingSessionLocal()
+ yield db
+ finally:
+ db.close()
+
+ app.dependency_overrides[get_db] = override_get_db
+
+ with TestClient(app) as client:
+ yield client
+
+ # --- TEARDOWN ---
+ # Clean up after the test is done
del app.dependency_overrides[get_db]
Base.metadata.drop_all(bind=engine)
+
+
+@pytest.fixture(scope="function")
+def authenticated_client(test_client):
+ """
+ Provides an authenticated TestClient. It uses the regular test_client
+ and performs a login to get an auth token.
+ """
+ # We need a direct session to create the user in the database first
+ db = TestingSessionLocal()
+ password = "testpassword"
+
+ # Create user if it doesn't exist to avoid IntegrityError
+ user = db.query(models.User).filter(models.User.username == "testuser").first()
+ if not user:
+ user = models.User(username="testuser", hashed_password=get_password_hash(password))
+ db.add(user)
+ db.commit()
+
+ db.close()
+
+ # Log in as the test user to get a token
+ login_response = test_client.post(
+ "/auth/token",
+ data={"username": "testuser", "password": password}
+ )
+ assert login_response.status_code == 200, "Failed to log in during test setup"
+ token = login_response.json()["access_token"]
+
+ # Set the auth header for all future requests with this client
+ test_client.headers["Authorization"] = f"Bearer {token}"
+
+ return test_client
\ No newline at end of file
diff --git a/backend/tests/test_auth.py b/backend/tests/test_auth.py
new file mode 100644
index 0000000..906978a
--- /dev/null
+++ b/backend/tests/test_auth.py
@@ -0,0 +1,65 @@
+# backend/tests/test_auth.py
+from auth.security import get_password_hash
+import models
+from database import SessionLocal
+
+def test_login_for_access_token(test_client):
+ """Tests if a user can successfully log in with correct credentials."""
+ # Setup: Create a user directly in the test database
+ db = SessionLocal()
+ password = "correct_password"
+ user = models.User(username="logintestuser", hashed_password=get_password_hash(password))
+ db.add(user)
+ db.commit()
+ db.close()
+
+ # Execute & Assert
+ response = test_client.post(
+ "/auth/token",
+ data={"username": "logintestuser", "password": "correct_password"}
+ )
+ assert response.status_code == 200, response.text
+ assert "access_token" in response.json()
+ assert response.json()["token_type"] == "bearer"
+
+def test_login_with_wrong_password(test_client):
+ """Tests that login fails with an incorrect password."""
+ db = SessionLocal()
+ user = models.User(username="wrongpassuser", hashed_password=get_password_hash("correct_password"))
+ db.add(user)
+ db.commit()
+ db.close()
+
+ response = test_client.post(
+ "/auth/token",
+ data={"username": "wrongpassuser", "password": "incorrect_password"}
+ )
+ assert response.status_code == 401 # Unauthorized
+
+def test_read_users_as_authenticated_user(authenticated_client):
+ """
+ Tests that a logged-in user can successfully fetch the list of users.
+ """
+ # The authenticated_client fixture has already created and logged in a "testuser"
+ response = authenticated_client.get("/users")
+
+ # Assert success
+ assert response.status_code == 200
+ data = response.json()
+
+ # Assert the response format is correct
+ assert isinstance(data, list)
+ assert len(data) > 0
+ assert data[0]["username"] == "testuser"
+
+ # CRITICAL: Assert that the hashed password is NOT present
+ assert "hashed_password" not in data[0]
+
+def test_read_users_unauthenticated(test_client):
+ """
+ Tests that a non-logged-in user receives a 401 error when trying
+ to access the protected user list.
+ """
+ # We use the basic, unauthenticated client here
+ response = test_client.get("/users")
+ assert response.status_code == 401 # Unauthorized
\ No newline at end of file
diff --git a/backend/tests/test_main.py b/backend/tests/test_main.py
index 7a10005..bdf90b2 100644
--- a/backend/tests/test_main.py
+++ b/backend/tests/test_main.py
@@ -1,11 +1,12 @@
# backend/tests/test_main.py
-from fastapi.testclient import TestClient
-from main import app
-# This test doesn't need a database, so it doesn't need a fixture.
-def test_read_root():
- # It creates its own client.
- client = TestClient(app)
- response = client.get("/api")
+def test_read_root(test_client):
+ """Tests the public root API endpoint."""
+ response = test_client.get("/api")
assert response.status_code == 200
assert response.json() == {"message": "WNBA Analytics API is running!"}
+
+def test_get_players_publicly(test_client):
+ """Tests that anyone can view the list of players."""
+ response = test_client.get("/api/players")
+ assert response.status_code == 200
\ No newline at end of file
diff --git a/backend/tests/test_players_api.py b/backend/tests/test_players_api.py
index 3574fbe..b63af6f 100644
--- a/backend/tests/test_players_api.py
+++ b/backend/tests/test_players_api.py
@@ -1,32 +1,18 @@
# backend/tests/test_players_api.py
-from fastapi.testclient import TestClient
-from main import app
-# Each test asks for the `db_session` fixture to ensure the database is clean.
-def test_create_and_get_player(db_session):
- client = TestClient(app) # Create a client for this test
- create_response = client.post("/api/players", json={"first_name": "Caitlin", "last_name": "Clark", "team": "Indiana Fever"})
- assert create_response.status_code == 200
- player_id = create_response.json()["id"]
+def test_create_player_as_authenticated_user(authenticated_client):
+ """Tests that a logged-in user can create a player."""
+ response = authenticated_client.post(
+ "/api/players",
+ json={"first_name": "Caitlin", "last_name": "Clark", "team": "Indiana Fever"}
+ )
+ assert response.status_code == 200
+ assert response.json()["first_name"] == "Caitlin"
- get_response = client.get(f"/api/players/{player_id}")
- assert get_response.status_code == 200
- assert get_response.json()["first_name"] == "Caitlin"
-
-def test_get_nonexistent_player_returns_404(db_session):
- client = TestClient(app)
- response = client.get("/api/players/99999")
- assert response.status_code == 404
-
-def test_add_and_get_stats_for_player(db_session):
- client = TestClient(app)
- player_res = client.post("/api/players", json={"first_name": "Sabrina", "last_name": "Ionescu", "team": "New York Liberty"})
- player_id = player_res.json()["id"]
- stats_payload = {
- "season": "2024", "points_per_game": 19, "rebounds_per_game": 4, "assists_per_game": 7,
- "games_played": 30, "games_started": 30, "field_goal_percentage": 0.45,
- "three_point_percentage": 0.35, "steals_per_game": 1.5, "blocks_per_game": 0.8,
- "player_efficiency_rating": 22.5
- }
- stats_res = client.post(f"/api/players/{player_id}/stats", json=stats_payload)
- assert stats_res.status_code == 200
+def test_create_player_unauthenticated(test_client):
+ """Tests that a non-logged-in user cannot create a player."""
+ response = test_client.post(
+ "/api/players",
+ json={"first_name": "Caitlin", "last_name": "Clark", "team": "Indiana Fever"}
+ )
+ assert response.status_code == 401 # Unauthorized
\ No newline at end of file
diff --git a/backend/tests/test_similarity_api.py b/backend/tests/test_similarity_api.py
index 0f20afd..88dec94 100644
--- a/backend/tests/test_similarity_api.py
+++ b/backend/tests/test_similarity_api.py
@@ -3,37 +3,35 @@
from unittest.mock import patch
import pandas as pd
import numpy as np
-import pytest
-from main import app
+from main import app, get_db
+import models
+from auth.security import get_password_hash
-# This test also asks for `db_session` to ensure the database is clean.
def test_get_similar_players(db_session):
- mock_player_vectors = pd.DataFrame(
- {'pts': [0.9, -0.5, 0.85], 'reb': [0.8, -0.6, 0.9]},
- index=['Player A (2024)', 'Player B (2024)', 'Player C (2024)']
- )
- mock_matrix = np.array([[1.0, 0.1, 0.95], [0.1, 1.0, 0.2], [0.95, 0.2, 1.0]])
+ mock_player_vectors = pd.DataFrame(index=['Player A (2024)', 'Player C (2024)'])
+ mock_matrix = np.array([[1.0, 0.95], [0.95, 1.0]])
- # 1. We set up the mock using 'with patch'
with patch('main.joblib.load') as mock_joblib_load:
- def side_effect(filename):
- if "data" in filename: return mock_player_vectors
- if "matrix" in filename: return mock_matrix
- mock_joblib_load.side_effect = side_effect
+ mock_joblib_load.side_effect = [mock_player_vectors, mock_matrix]
- # 2. We create the TestClient INSIDE the patch block.
- # This is the critical part that forces the app's startup
- # event to run while our mock is active.
with TestClient(app) as client:
- # 3. Setup the DB with a player
- player_res = client.post("/api/players", json={"first_name": "Player", "last_name": "A", "team": "Team A"})
+ # Create user, log in, get headers
+ db = client.app.dependency_overrides[get_db]().__next__()
+ user = models.User(username="testuser", hashed_password=get_password_hash("password"))
+ db.add(user)
+ db.commit()
+ login_res = client.post("/auth/token", data={"username": "testuser", "password": "password"})
+ token = login_res.json()["access_token"]
+ headers = {"Authorization": f"Bearer {token}"}
+ db.close()
+
+ player_res = client.post("/api/players", headers=headers, json={"first_name": "Player", "last_name": "A", "team": "Team A"})
assert player_res.status_code == 200
player_id = player_res.json()["id"]
- # 4. Call the API endpoint
+ # Call the similarity endpoint
response = client.get(f"/api/players/{player_id}/seasons/2024/similar")
- # 5. Assert the response
assert response.status_code == 200
data = response.json()
- assert data[0]["player_season_id"] == "Player C (2024)"
+ assert data[0]["player_season_id"] == "Player C (2024)"
\ No newline at end of file
diff --git a/frontend/src/AuthContext.js b/frontend/src/AuthContext.js
new file mode 100644
index 0000000..99befda
--- /dev/null
+++ b/frontend/src/AuthContext.js
@@ -0,0 +1,45 @@
+// frontend/src/AuthContext.js
+import React, { createContext, useState, useContext, useMemo } from 'react';
+
+export const AuthContext = createContext(null);
+
+export const AuthProvider = ({ children }) => {
+ const [token, setToken] = useState(localStorage.getItem('authToken'));
+
+ const login = async (username, password) => {
+ const response = await fetch(`${process.env.REACT_APP_API_BASE_URL}/auth/token`, {
+ method: 'POST',
+ headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
+ body: new URLSearchParams({ username, password })
+ });
+ if (!response.ok) {
+ const errorData = await response.json();
+ throw new Error(errorData.detail || 'Failed to login');
+ }
+ const data = await response.json();
+ setToken(data.access_token);
+ localStorage.setItem('authToken', data.access_token);
+ };
+
+ const logout = () => {
+ setToken(null);
+ localStorage.removeItem('authToken');
+ };
+
+ // useMemo ensures this object is not recreated on every render
+ const authContextValue = useMemo(() => ({
+ token,
+ login,
+ logout
+ }), [token]);
+
+ return (
+
+ {children}
+
+ );
+};
+
+export const useAuth = () => {
+ return useContext(AuthContext);
+};
\ No newline at end of file
diff --git a/frontend/src/ThemeContext.js b/frontend/src/ThemeContext.js
index b23b78b..3bbca66 100644
--- a/frontend/src/ThemeContext.js
+++ b/frontend/src/ThemeContext.js
@@ -8,12 +8,17 @@ export const ThemeContext = createContext({
});
export const AppThemeProvider = ({ children }) => {
- const [mode, setMode] = useState('dark'); // Default to dark mode
+ const storedMode = localStorage.getItem('themeMode') || 'dark'; // Default to dark mode
+ const [mode, setMode] = useState(storedMode);
const theme = useMemo(() => (mode === 'light' ? lightTheme : darkTheme), [mode]);
const toggleTheme = () => {
- setMode((prevMode) => (prevMode === 'light' ? 'dark' : 'light'));
+ setMode((prevMode) => {
+ const newMode = prevMode === 'light' ? 'dark' : 'light';
+ localStorage.setItem('themeMode', newMode);
+ return newMode;
+ });
};
return (
diff --git a/frontend/src/components/App.js b/frontend/src/components/App.js
index 2865882..187964e 100644
--- a/frontend/src/components/App.js
+++ b/frontend/src/components/App.js
@@ -4,6 +4,7 @@ import { Routes, Route } from 'react-router-dom';
import Layout from './Layout';
import RosterPage from './RosterPage/RosterPage';
import PlayerDetailPage from './PlayerDetailPage/PlayerDetailPage';
+import LoginPage from './LoginPage/LoginPage';
import './App.css';
function App() {
@@ -12,6 +13,7 @@ function App() {
} />
} />
+ } />
);
diff --git a/frontend/src/components/App.test.js b/frontend/src/components/App.test.js
index 454f8b7..b2f28f2 100644
--- a/frontend/src/components/App.test.js
+++ b/frontend/src/components/App.test.js
@@ -2,13 +2,18 @@
import { render, screen } from '@testing-library/react';
import { MemoryRouter } from 'react-router-dom';
import App from './App';
+import { AuthProvider } from '../AuthContext';
+import { AppThemeProvider } from '../ThemeContext';
test('renders the RosterPage for the home route', () => {
render(
-
-
-
+
+
+
+
+
+
+
);
- // We expect the main heading from the RosterPage to be present
expect(screen.getByRole('heading', { name: /WNBA Player Roster/i })).toBeInTheDocument();
-});
+});
\ No newline at end of file
diff --git a/frontend/src/components/Layout.js b/frontend/src/components/Layout.js
index 29e98ee..d1f3564 100644
--- a/frontend/src/components/Layout.js
+++ b/frontend/src/components/Layout.js
@@ -1,14 +1,22 @@
// frontend/src/components/Layout.js
import React, { useContext } from 'react';
-import { Box, AppBar, Toolbar, Typography, IconButton, Container } from '@mui/material';
+import { Box, AppBar, Toolbar, Typography, IconButton, Container, Button } from '@mui/material';
import { Brightness4, Brightness7 } from '@mui/icons-material';
import { useTheme } from '@mui/material/styles';
import { ThemeContext } from '../ThemeContext';
-import { Link } from 'react-router-dom';
+import { Link, useNavigate } from 'react-router-dom';
+import { useAuth } from '../AuthContext';
const Layout = ({ children }) => {
const theme = useTheme();
const colorMode = useContext(ThemeContext);
+ const auth = useAuth();
+ const navigate = useNavigate();
+
+ const handleLogout = () => {
+ auth.logout();
+ navigate('/');
+ };
return (
@@ -22,6 +30,11 @@ const Layout = ({ children }) => {
{theme.palette.mode === 'dark' ? : }
+ {auth.token ? (
+
+ ) : (
+
+ )}
diff --git a/frontend/src/components/LoginPage/LoginPage.js b/frontend/src/components/LoginPage/LoginPage.js
new file mode 100644
index 0000000..2183938
--- /dev/null
+++ b/frontend/src/components/LoginPage/LoginPage.js
@@ -0,0 +1,39 @@
+// frontend/src/components/LoginPage/LoginPage.js
+import React, { useState } from 'react';
+import { useNavigate } from 'react-router-dom';
+import { useAuth } from '../../AuthContext';
+import { Box, Button, TextField, Typography, Paper } from '@mui/material';
+
+function LoginPage() {
+ const [username, setUsername] = useState('');
+ const [password, setPassword] = useState('');
+ const [error, setError] = useState('');
+ const auth = useAuth();
+ const navigate = useNavigate();
+
+ const handleSubmit = async (event) => {
+ event.preventDefault();
+ setError('');
+ try {
+ await auth.login(username, password);
+ navigate('/'); // Redirect to homepage on successful login
+ } catch (err) {
+ setError('Invalid username or password.');
+ }
+ };
+
+ return (
+
+
+ Admin Login
+
+ setUsername(e.target.value)} required />
+ setPassword(e.target.value)} required />
+ {error && {error}}
+
+
+
+
+ );
+}
+export default LoginPage;
\ No newline at end of file
diff --git a/frontend/src/components/LoginPage/LoginPage.test.js b/frontend/src/components/LoginPage/LoginPage.test.js
new file mode 100644
index 0000000..61c0c41
--- /dev/null
+++ b/frontend/src/components/LoginPage/LoginPage.test.js
@@ -0,0 +1,45 @@
+// frontend/src/components/LoginPage/LoginPage.test.js
+import React from 'react';
+import { render, screen, fireEvent, waitFor } from '@testing-library/react';
+import { MemoryRouter } from 'react-router-dom';
+import { AuthContext } from '../../AuthContext';
+import LoginPage from './LoginPage';
+
+// We create a mock login function using Jest's built-in mocking
+const mockLogin = jest.fn();
+const mockNavigate = jest.fn();
+
+// Mock the useNavigate hook from react-router-dom
+jest.mock('react-router-dom', () => ({
+ ...jest.requireActual('react-router-dom'),
+ useNavigate: () => mockNavigate,
+}));
+
+const renderWithAuthProvider = (component) => {
+ return render(
+
+
+ {component}
+
+
+ );
+};
+
+test('calls login function and navigates on successful submission', async () => {
+ // Configure our mock login function to simulate a successful login
+ mockLogin.mockResolvedValue(true);
+
+ renderWithAuthProvider();
+
+ fireEvent.change(screen.getByLabelText(/Username/i), { target: { value: 'admin' } });
+ fireEvent.change(screen.getByLabelText(/Password/i), { target: { value: 'password123' } });
+ fireEvent.click(screen.getByRole('button', { name: /Login/i }));
+
+ // Wait for the login function to be called and check arguments
+ await waitFor(() => {
+ expect(mockLogin).toHaveBeenCalledWith('admin', 'password123');
+ });
+
+ // Check that it redirected to the homepage
+ expect(mockNavigate).toHaveBeenCalledWith('/');
+});
\ No newline at end of file
diff --git a/frontend/src/components/RosterPage/RosterPage.js b/frontend/src/components/RosterPage/RosterPage.js
index 02d29bb..2be4eae 100644
--- a/frontend/src/components/RosterPage/RosterPage.js
+++ b/frontend/src/components/RosterPage/RosterPage.js
@@ -2,6 +2,7 @@
import React, { useState, useEffect } from 'react';
import { Link } from 'react-router-dom'; // Import Link
import { Box, Button, TextField, Typography, Grid, Card, CardContent, CardActions } from '@mui/material';
+import { useAuth } from '../../AuthContext';
function RosterPage() {
const [players, setPlayers] = useState([]);
@@ -9,7 +10,7 @@ function RosterPage() {
const [lastName, setLastName] = useState('');
const [team, setTeam] = useState('');
const [editingPlayer, setEditingPlayer] = useState(null); // State to track which player is being edited
-
+ const auth = useAuth();
const fetchPlayers = () => {
fetch(`${process.env.REACT_APP_API_BASE_URL}/api/players`)
@@ -24,6 +25,10 @@ function RosterPage() {
// This function handles both creating AND updating
const handleFormSubmit = (event) => {
event.preventDefault();
+ const headers = {
+ 'Content-Type': 'application/json',
+ 'Authorization': `Bearer ${auth.token}`
+ };
// If we are editing, call the update logic
if (editingPlayer) {
@@ -35,7 +40,7 @@ function RosterPage() {
fetch(`${process.env.REACT_APP_API_BASE_URL}/api/players/${editingPlayer.id}`, {
method: 'PUT',
- headers: { 'Content-Type': 'application/json' },
+ headers: headers,
body: JSON.stringify(updatedPlayer),
})
.then(response => response.json())
@@ -54,7 +59,7 @@ function RosterPage() {
fetch(`${process.env.REACT_APP_API_BASE_URL}/api/players`, {
method: 'POST',
- headers: { 'Content-Type': 'application/json' },
+ headers: headers,
body: JSON.stringify(newPlayer),
})
.then(response => response.json())
@@ -83,7 +88,8 @@ function RosterPage() {
};
const handleDelete = (playerId) => {
- fetch(`${process.env.REACT_APP_API_BASE_URL}/api/players/${playerId}`, { method: 'DELETE' })
+ const headers = { 'Authorization': `Bearer ${auth.token}` };
+ fetch(`${process.env.REACT_APP_API_BASE_URL}/api/players/${playerId}`, { method: 'DELETE', headers: headers })
.then(response => { if (response.ok) fetchPlayers(); })
.catch(error => console.error('Error deleting player:', error));
};
@@ -94,7 +100,7 @@ function RosterPage() {
{editingPlayer ? 'Edit Player' : 'WNBA Player Roster'}
-
+ {auth.token &&
setFirstName(e.target.value)} required />
@@ -110,7 +116,7 @@ function RosterPage() {
{editingPlayer && }
-
+ }
{players.map(player => (
@@ -124,10 +130,10 @@ function RosterPage() {
{player.team}
-
+ {auth.token && (
-
+ )}
))}
diff --git a/frontend/src/index.js b/frontend/src/index.js
index 3808073..1a7defd 100644
--- a/frontend/src/index.js
+++ b/frontend/src/index.js
@@ -5,14 +5,17 @@ import App from './components/App';
import reportWebVitals from './reportWebVitals';
import {BrowserRouter} from 'react-router-dom';
import { AppThemeProvider } from './ThemeContext';
+import { AuthProvider } from './AuthContext';
const root = ReactDOM.createRoot(document.getElementById('root'));
root.render(
-
-
-
+
+
+
+
+
);