import io

import numpy as np
import tensorflow as tf
from fastapi import FastAPI, File, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from PIL import Image
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import load_img, img_to_array
|
|
|
# ASGI application instance; the routes and the static-file mount in this
# module are all registered on it.
app = FastAPI(
    title="Clasificador de Vehículos",
)
|
|
|
class CompatDense(Dense):
    """``Dense`` layer that tolerates an extra ``quantization_config`` kwarg.

    This module passes it to ``load_model`` as the replacement for ``'Dense'``
    so that a saved layer config containing a ``quantization_config`` entry
    can still be deserialized by a ``Dense`` that does not accept that
    argument. NOTE(review): presumably the model was saved with a newer Keras
    version than the one installed -- confirm.
    """

    def __init__(self, *args, quantization_config=None, **kwargs):
        # ``quantization_config`` is keyword-only, so it is bound here and never
        # reaches ``**kwargs`` / ``Dense.__init__``. It is intentionally ignored.
        del quantization_config
        super().__init__(*args, **kwargs)
|
|
|
|
|
# Load the trained classifier once, at import time. The path is relative to
# the process's working directory. ``CompatDense`` is substituted for
# ``Dense`` so configs carrying ``quantization_config`` still deserialize.
# ``compile=False`` skips restoring the training setup (optimizer, loss,
# metrics): serving only runs ``predict``, which does not need it, and it is
# one more thing that can fail to deserialize across Keras versions.
model = load_model(
    'modelo/modelo.h5',
    custom_objects={'Dense': CompatDense},
    compile=False,
)

# Class names indexed by model output position (``predict`` maps output i to
# CLASS_NAMES[i]). NOTE(review): must match the class order used in
# training -- confirm.
CLASS_NAMES = ['airplane', 'car', 'ship']
|
|
|
def preprocess_image(image: Image.Image, target_size=(150, 150)) -> np.ndarray:
    """Turn a PIL image into a one-item model input batch.

    Intended to reproduce the preprocessing used when the model was trained:
    the image is converted to RGB if necessary, resized, scaled from
    [0, 255] to [0, 1], and given a leading batch axis.

    Args:
        image: Image to preprocess. Any mode is accepted; non-RGB images
            (e.g. RGBA, grayscale, palette) are converted so the result
            always has three channels.
        target_size: ``(width, height)`` passed to ``Image.resize``. Defaults
            to the 150x150 size this service has always used.

    Returns:
        A float array of shape ``(1, height, width, 3)`` (assuming Keras'
        default channels-last image data format).
    """
    # Without this, an RGBA or grayscale image would yield 4 or 1 channels and
    # the model call would fail on a shape mismatch.
    if image.mode != 'RGB':
        image = image.convert('RGB')
    image = image.resize(target_size)
    img_array = img_to_array(image) / 255.0
    return np.expand_dims(img_array, axis=0)
|
|
|
# Spanish display label for each name in CLASS_NAMES.
_LABELS_ES = {
    "airplane": "Aéreo (Avión)",
    "car": "Terrestre (Coche)",
    "ship": "Marítimo (Barco)",
}


def _decode_image(contents: bytes) -> Image.Image:
    """Decode uploaded bytes into an RGB PIL image; raises if they are unreadable."""
    # ``Image.open`` is lazy; ``convert`` forces the full decode so corrupt or
    # truncated data fails here instead of later in the pipeline.
    return Image.open(io.BytesIO(contents)).convert('RGB')


def _predict_scores(image: Image.Image) -> np.ndarray:
    """Preprocess one image and return the model's output row for it."""
    processed = preprocess_image(image)
    # verbose=0: do not print a Keras progress bar for every request.
    return model.predict(processed, verbose=0)[0]


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and report the predicted class.

    Returns a JSON object with:
      * ``prediccion``: Spanish display label of the highest-scoring class.
      * ``confianza``: that class's model score, rounded to 4 decimals.
      * ``probabilidades``: the score of every class in ``CLASS_NAMES``,
        rounded to 4 decimals.

    If the upload cannot be decoded as an image, responds with HTTP 400 and a
    JSON ``error`` message instead.
    """
    contents = await file.read()

    # Decoding and inference are CPU-bound. Running them in the worker thread
    # pool keeps this async endpoint from blocking the event loop (and every
    # other request) while they execute.
    try:
        image = await run_in_threadpool(_decode_image, contents)
    except Exception:
        # Deliberately broad: Pillow can raise several unrelated exception
        # types for bad input (e.g. UnidentifiedImageError,
        # DecompressionBombError), and all of them mean "invalid upload".
        return JSONResponse(
            content={"error": "No se pudo leer la imagen. Asegúrate de enviar un archivo válido."},
            status_code=400,
        )

    predictions = await run_in_threadpool(_predict_scores, image)

    predicted_idx = int(np.argmax(predictions))
    label = CLASS_NAMES[predicted_idx]
    confidence = float(predictions[predicted_idx])

    return {
        # Fall back to the raw class name if a label has no translation.
        "prediccion": _LABELS_ES.get(label, label),
        "confianza": round(confidence, 4),
        # Built from CLASS_NAMES so the keys always follow the output order.
        "probabilidades": {
            name: round(float(score), 4)
            for name, score in zip(CLASS_NAMES, predictions)
        },
    }
|
|
|
# Serve the files in the "static" directory (resolved against the working
# directory) under the /static URL prefix.
app.mount(path="/static", app=StaticFiles(directory="static"), name="static")
|
|
|
|
|
@app.get("/")
def read_index():
    """Return the front-end entry page for requests to the site root."""
    index_page = "static/index.html"
    return FileResponse(index_page)
|
|
|
if __name__ == "__main__":
    import os

    import uvicorn

    # Bind address and port can be overridden through UVICORN_HOST /
    # UVICORN_PORT (the same variable names the uvicorn CLI uses). The
    # defaults are unchanged: listen on all interfaces, port 8000.
    uvicorn.run(
        app,
        host=os.environ.get("UVICORN_HOST", "0.0.0.0"),
        port=int(os.environ.get("UVICORN_PORT", "8000")),
    )