-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathFastAPI_garbage.py
More file actions
62 lines (52 loc) · 1.76 KB
/
Copy pathFastAPI_garbage.py
File metadata and controls
62 lines (52 loc) · 1.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import io
import logging

import numpy as np
import tensorflow as tf
import uvicorn
from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from PIL import Image
from tensorflow.keras.preprocessing import image
# -------------------------------
# Model and label setup
# -------------------------------
# Labels reported by the /predict endpoint. They are indexed by the argmax of
# the model output, so the order is assumed to match the model's output
# units -- TODO confirm against the training pipeline.
class_names = [
    "Bio",
    "metal",
    "paper",
    "plastic",
]

# Relative path: resolved against the process working directory at import time.
MODEL_PATH = "garbage_mobilenet_V3_finetuned.h5"

# Loaded once when the module is imported and shared by every request.
model = tf.keras.models.load_model(MODEL_PATH)
# -------------------------------
# FastAPI app setup
# -------------------------------
app = FastAPI()

# Enable CORS so that a browser frontend served from another origin can call
# the API.
#
# Credentials are disabled on purpose: FastAPI's CORS documentation states
# that allow_origins / allow_methods / allow_headers cannot be ["*"] when
# allow_credentials is True. The /predict endpoint only reads the uploaded
# file, so cookies and Authorization headers are not needed. If credentials
# are ever required, list the frontend origin(s) explicitly instead of "*".
app.add_middleware(
    CORSMiddleware,
    # TODO: restrict to the real frontend origin, e.g. ["http://127.0.0.1:5500"].
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# -------------------------------
# Prediction endpoint
# -------------------------------
_logger = logging.getLogger(__name__)


def _load_image_batch(contents):
    """Decode uploaded image bytes into a one-image batch for the model.

    Args:
        contents: Raw bytes of the uploaded file.

    Returns:
        The image converted to RGB, resized to 224x224, divided by 255.0 and
        given a leading batch axis via ``np.expand_dims(..., axis=0)``.

    Raises:
        OSError: If Pillow cannot identify or decode the data.
    """
    img = Image.open(io.BytesIO(contents)).convert("RGB")
    img = img.resize((224, 224))  # MobileNet input size
    # NOTE(review): the 1/255 scaling is assumed to match the preprocessing
    # used when the model was trained -- confirm.
    img_array = image.img_to_array(img) / 255.0
    return np.expand_dims(img_array, axis=0)


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and report the most likely class.

    Args:
        file: The uploaded image (multipart form field ``file``).

    Returns:
        On success, ``{"class": <label>, "confidence": <percent>}`` where the
        label comes from ``class_names`` and the confidence is the maximum
        model output multiplied by 100.
        On failure, a JSON body ``{"error": <message>}`` with HTTP status 400
        when the upload cannot be decoded as an image, or 500 for any other
        failure. The handler never lets an exception propagate.
    """
    try:
        contents = await file.read()

        try:
            batch = _load_image_batch(contents)
        except (OSError, ValueError, Image.DecompressionBombError) as exc:
            # The client sent something that is not a usable image; report it
            # as a client error instead of a 200 response.
            _logger.info("Rejected upload: %s", exc)
            return JSONResponse(status_code=400, content={"error": str(exc)})

        preds = model.predict(batch, verbose=0)
        pred_class = class_names[int(np.argmax(preds))]
        confidence = float(np.max(preds) * 100)
        return {"class": pred_class, "confidence": confidence}
    except Exception as exc:
        # Top-level boundary: keep the {"error": ...} body clients already
        # parse, but record the traceback and signal failure via the status.
        _logger.exception("Prediction failed")
        return JSONResponse(status_code=500, content={"error": str(exc)})
# -------------------------------
# Local development entry point
# -------------------------------
if __name__ == "__main__":
    # Bound to the loopback address, so only this machine can reach the server.
    server_options = {"host": "127.0.0.1", "port": 8000}
    uvicorn.run(app, **server_options)