Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,4 @@ backend/tests.http
*.joblib

# Environment Variables
.env
**/.env
9 changes: 8 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# Makefile for WNBA Analytics Project

include .env
export

# Color codes for prettier output
GREEN := $(shell tput -T xterm setaf 2)
YELLOW := $(shell tput -T xterm setaf 3)
Expand Down Expand Up @@ -37,7 +40,7 @@ logs: ## Follow logs for all services
# ==============================================================================
# DEVELOPMENT & TESTING
# ==============================================================================
setup: up seed-db build-model ## Run this once to setup a new LOCAL environment from scratch
setup: up seed-db build-model create-admin ## Run this once to setup a new LOCAL environment from scratch
@echo "$(GREEN)✅ Initial local setup complete! Database is seeded and model is built.$(RESET)"

seed-db: ## Run the database seed script on the LOCAL docker DB
Expand All @@ -48,6 +51,10 @@ build-model: ## Run the ML model training script
@echo "$(YELLOW)--> Building similarity model...$(RESET)"
@docker compose exec backend python build_similarity_model.py

create-admin:
@echo "$(YELLOW)--> Creating admin user...$(RESET)"
@docker compose exec -e ADMIN_USER=$(ADMIN_USER) -e ADMIN_PASSWORD=$(ADMIN_PASSWORD) backend python create_admin.py

test: test-backend test-frontend ## Run all backend and frontend tests

test-backend: ## Run backend python tests
Expand Down
Empty file added backend/auth/__init__.py
Empty file.
56 changes: 56 additions & 0 deletions backend/auth/router.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# backend/auth/router.py
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from sqlalchemy.orm import Session
from typing import Annotated
from jose import JWTError, jwt
from typing import List

from . import security
import models
from database import get_db

# This creates a new "router" that we can include in our main app
router = APIRouter()

# Define the token endpoint URL and the security scheme
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/token")

@router.post("/token")
def login_for_access_token(form_data: Annotated[OAuth2PasswordRequestForm, Depends()], db: Session = Depends(get_db)):
"""
Handles the login request. Takes username and password from a form body.
"""
user = db.query(models.User).filter(models.User.username == form_data.username).first()
if not user or not security.verify_password(form_data.password, user.hashed_password):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect username or password",
headers={"WWW-Authenticate": "Bearer"},
)
access_token = security.create_access_token(data={"sub": user.username})
return {"access_token": access_token, "token_type": "bearer"}

async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)], db: Session = Depends(get_db)):
"""
This is the dependency that will protect our routes.
It decodes the token, validates it, and fetches the user from the database.
"""
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(token, security.SECRET_KEY, algorithms=[security.ALGORITHM])
username: str = payload.get("sub")
if username is None:
raise credentials_exception
except JWTError:
raise credentials_exception

user = db.query(models.User).filter(models.User.username == username).first()
if user is None:
raise credentials_exception
return user

37 changes: 37 additions & 0 deletions backend/auth/security.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# backend/auth/security.py
from datetime import datetime, timedelta, timezone
from typing import Optional
from passlib.context import CryptContext
from jose import JWTError, jwt
import os

# --- Configuration ---
# This is a hardcoded key for development. In a real production environment,
# you would load this from a secure environment variable or a secret manager.
SECRET_KEY = os.getenv("SECRET_KEY", "a_very_secret_key_for_a_portfolio_project")
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30

# Setup for password hashing. bcrypt is the standard, secure algorithm.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

def verify_password(plain_password: str, hashed_password: str) -> bool:
"""Verifies a plain password against a hashed one."""
return pwd_context.verify(plain_password, hashed_password)

def get_password_hash(password: str) -> str:
"""Hashes a plain password."""
return pwd_context.hash(password)

# --- JWT Token Logic ---
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
"""Creates a new JWT access token."""
to_encode = data.copy()
if expires_delta:
expire = datetime.now(timezone.utc) + expires_delta
else:
expire = datetime.now(timezone.utc) + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)

to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt
36 changes: 36 additions & 0 deletions backend/create_admin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# backend/create_admin.py
import os
from database import SessionLocal, engine
from models import User, Base
from auth.security import get_password_hash
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

ADMIN_USERNAME = os.getenv("ADMIN_USER", "admin")
ADMIN_PASSWORD = os.getenv("ADMIN_PASSWORD")

def create_admin_user():
if not ADMIN_PASSWORD:
logger.error("ADMIN_PASSWORD environment variable not set. Cannot create admin user.")
return

db = SessionLocal()
Base.metadata.create_all(bind=engine) # Ensure the 'users' table exists

# Check if user already exists
if db.query(User).filter(User.username == ADMIN_USERNAME).first():
logger.warning(f"Admin user '{ADMIN_USERNAME}' already exists.")
else:
# Create a new user with a hashed password
hashed_password = get_password_hash(ADMIN_PASSWORD)
admin_user = User(username=ADMIN_USERNAME, hashed_password=hashed_password)
db.add(admin_user)
db.commit()
logger.info(f"Admin user '{ADMIN_USERNAME}' created successfully.")

db.close()

if __name__ == "__main__":
create_admin_user()
16 changes: 11 additions & 5 deletions backend/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,15 @@ def get_database_url():
return "postgresql://admin:password123@db:5432/wnba_db"

DATABASE_URL = get_database_url()

# --- Use the DATABASE_URL variable ---
engine = create_engine(DATABASE_URL)

connect_args = {"check_same_thread": False} if "sqlite" in DATABASE_URL else {}
engine = create_engine(DATABASE_URL, connect_args=connect_args)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()
Base = declarative_base()

def get_db():
"""Dependency to get a database session."""
db = SessionLocal()
try:
yield db
finally:
db.close()
96 changes: 55 additions & 41 deletions backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from fastapi.middleware.cors import CORSMiddleware
from sqlalchemy.orm import Session
from pydantic import BaseModel, ConfigDict
from typing import List
from typing import List, Annotated

import joblib
import pandas as pd
Expand All @@ -17,6 +17,9 @@
# Import your SQLAlchemy models and session management
import models
import database
from database import get_db
from auth.router import router as auth_router # Import our new auth router
from auth.router import get_current_user # Import our new dependency

# --- The Lifespan function now loads the model ---
@asynccontextmanager
Expand All @@ -41,9 +44,12 @@ async def lifespan(app: FastAPI):
logger.info("Application shutdown.")

app = FastAPI(lifespan=lifespan)
app.include_router(auth_router, prefix="/auth", tags=["Authentication"])

origins = ["http://localhost:3000",
"https://wnba-frontend-service-776933261932.us-west1.run.app"]
origins = [
"http://localhost:3000",
"https://wnba-frontend-service-776933261932.us-west1.run.app"
]
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
Expand All @@ -52,14 +58,6 @@ async def lifespan(app: FastAPI):
allow_headers=["*"],
)

# Dependency to get a database session for each request
def get_db():
db = database.SessionLocal()
try:
yield db
finally:
db.close()

# --- Pydantic Schemas ---
class PlayerStatBase(BaseModel):
season: str
Expand Down Expand Up @@ -95,44 +93,24 @@ class Player(PlayerBase):
stats: List[PlayerStat] = []
model_config = ConfigDict(from_attributes=True)

# ---- API ENDPOINTS ----
@app.get("/api")
def read_root():
return {"message": "WNBA Analytics API is running!"}
class UserOut(BaseModel):
id: int
username: str
model_config = ConfigDict(from_attributes=True)

# ---- PROTECTED API ENDPOINTS ----
# Endpoint to CREATE a new player
@app.post("/api/players", response_model=Player)
def create_player(player: PlayerCreate, db: Session = Depends(get_db)):
def create_player(player: PlayerCreate, current_user: Annotated[models.User, Depends(get_current_user)], db: Session = Depends(get_db)):
new_player = models.Player(**player.model_dump())
db.add(new_player)
db.commit()
db.refresh(new_player)
return new_player

# Endpoint to READ all players
@app.get("/api/players", response_model=List[Player])
def get_players(db: Session = Depends(get_db)):
players = db.query(models.Player).all()
return players

@app.get("/api/players/{player_id}", response_model=Player)
def get_player(player_id: int, db: Session = Depends(get_db)):
player = db.query(models.Player).filter(models.Player.id == player_id).first()
if player is None: raise HTTPException(status_code=404, detail="Player not found")
return player

# Endpoint to DELETE a player
@app.delete("/api/players/{player_id}")
def delete_player(player_id: int, db: Session = Depends(get_db)):
player_to_delete = db.query(models.Player).filter(models.Player.id == player_id).first()
if player_to_delete is None: raise HTTPException(status_code=404, detail="Player not found")
db.delete(player_to_delete)
db.commit()
return {"message": "Player deleted successfully"}

# Endpoint to UPDATE a player
# Endpoint to UPDATE a player
@app.put("/api/players/{player_id}", response_model=Player)
def update_player(player_id: int, player_update: PlayerCreate, db: Session = Depends(get_db)):
def update_player(player_id: int, player_update: PlayerCreate, current_user: Annotated[models.User, Depends(get_current_user)], db: Session = Depends(get_db)):
player_to_update = db.query(models.Player).filter(models.Player.id == player_id).first()
if player_to_update is None: raise HTTPException(status_code=404, detail="Player not found")

Expand All @@ -144,9 +122,17 @@ def update_player(player_id: int, player_update: PlayerCreate, db: Session = Dep
db.refresh(player_to_update)
return player_to_update

# Endpoint to DELETE a player
@app.delete("/api/players/{player_id}")
def delete_player(player_id: int, current_user: Annotated[models.User, Depends(get_current_user)], db: Session = Depends(get_db)):
player_to_delete = db.query(models.Player).filter(models.Player.id == player_id).first()
if player_to_delete is None: raise HTTPException(status_code=404, detail="Player not found")
db.delete(player_to_delete)
db.commit()
return {"message": "Player deleted successfully"}

@app.post("/api/players/{player_id}/stats", response_model=PlayerStat)
def create_stats_for_player(player_id: int, stat: PlayerStatCreate,
db: Session = Depends(get_db)):
def create_stats_for_player(player_id: int, stat: PlayerStatCreate, current_user: Annotated[models.User, Depends(get_current_user)], db: Session = Depends(get_db)):
db_player = db.query(models.Player).filter(models.Player.id == player_id).first()
if db_player is None: raise HTTPException(status_code=404, detail="Player not found")
db_stat = models.PlayerStat(**stat.model_dump(), player_id=player_id)
Expand All @@ -155,6 +141,34 @@ def create_stats_for_player(player_id: int, stat: PlayerStatCreate,
db.refresh(db_stat)
return db_stat

@app.get("/users", response_model=List[UserOut])
def read_users(current_user: Annotated[models.User, Depends(get_current_user)], db: Session = Depends(get_db)):
"""
Retrieves a list of all users.
This is a protected endpoint that requires authentication.
The `response_model` ensures that only the fields from `UserOut` (id, username)
are returned, protecting the hashed password.
"""
users = db.query(models.User).all()
return users

# ---- PUBLIC API ENDPOINTS ----
@app.get("/api")
def read_root():
return {"message": "WNBA Analytics API is running!"}

# Endpoint to READ all players
@app.get("/api/players", response_model=List[Player])
def get_players(db: Session = Depends(get_db)):
players = db.query(models.Player).all()
return players

@app.get("/api/players/{player_id}", response_model=Player)
def get_player(player_id: int, db: Session = Depends(get_db)):
player = db.query(models.Player).filter(models.Player.id == player_id).first()
if player is None: raise HTTPException(status_code=404, detail="Player not found")
return player

# A Pydantic schema for the similarity response
class SimilarPlayer(BaseModel):
player_season_id: str
Expand Down
8 changes: 7 additions & 1 deletion backend/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,10 @@ class PlayerStat(Base):
player_efficiency_rating = Column(Float)

player_id = Column(Integer, ForeignKey("players.id"))
player = relationship("Player", back_populates="stats")
player = relationship("Player", back_populates="stats")

class User(Base):
__tablename__ = "users"
id = Column(Integer, primary_key=True, index=True)
username = Column(String, unique=True, index=True)
hashed_password = Column(String)
4 changes: 3 additions & 1 deletion backend/pytest.ini
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# backend/pytest.ini
[pytest]
pythonpath = .
pythonpath = .
# This tells pytest to ignore any DeprecationWarning originating from the passlib library
filterwarnings = ignore::DeprecationWarning:passlib.*
Loading