# app.py by iamomtiwari (Hugging Face Hub)
# Commit f81e803 "Update app.py" (verified), 1.77 kB
import gradio as gr
import torch
import torchvision.transforms as transforms
from torchvision.models import resnet50
from PIL import Image
from huggingface_hub import hf_hub_download
# Build the fine-tuned ResNet-50 classifier once, at import time.
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Architecture only: no pretrained weights are requested here because the
# checkpoint loaded below supplies the model's parameters.
model = resnet50(pretrained=False)
# Swap the default classification head for a 14-way one so the module layout
# matches the checkpoint. This count must equal len(class_labels).
model.fc = torch.nn.Linear(model.fc.in_features, 14) # 14 output classes
# Download the checkpoint from the Hugging Face Hub (huggingface_hub reuses its
# local cache on later runs) and get the local file path.
model_path = hf_hub_download(repo_id="iamomtiwari/resnet50-crop-disease", filename="resnet50_model_hf.pt")
# NOTE(review): torch.load unpickles the downloaded file, which on torch < 2.6
# (where weights_only defaults to False) can execute arbitrary code embedded in
# it. The file is consumed as a state dict, so passing weights_only=True
# (torch >= 1.13) should work — confirm the installed torch version first.
# map_location places the loaded tensors on `device`, so a checkpoint saved
# from a GPU still loads on a CPU-only host.
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
# Switch layers such as BatchNorm to inference behaviour.
model.eval()
# Preprocessing pipeline applied to each uploaded image before inference:
#   1. resize to a fixed 224x224 (aspect ratio is not preserved),
#   2. convert the PIL image to a float tensor,
#   3. normalise each channel with the ImageNet mean / std values.
# NOTE(review): presumably mirrors the training-time preprocessing — confirm.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Human-readable class names, indexed by the model's output position:
# class_labels[i] is the name reported for logit i. The strings are shown to
# users verbatim, so they are kept exactly as the checkpoint's author wrote
# them (note the mixed "___" / "__" separators).
# NOTE(review): this order is not alphabetical (Sugarcane comes after Wheat and
# its entries are unsorted). If the training data was loaded with torchvision's
# ImageFolder, which sorts class folders, indices would not line up — confirm
# against the training script.
class_labels = [
    # Corn
    "Corn___Common_Rust",
    "Corn___Gray_Leaf_Spot",
    "Corn___Healthy",
    "Corn___Northern_Leaf_Blight",
    # Rice
    "Rice___Brown_Spot",
    "Rice___Healthy",
    "Rice___Leaf_Blast",
    "Rice___Neck_Blast",
    # Wheat
    "Wheat___Brown_Rust",
    "Wheat___Healthy",
    "Wheat___Yellow_Rust",
    # Sugarcane
    "Sugarcane__Red_Rot",
    "Sugarcane__Healthy",
    "Sugarcane__Bacterial Blight",
]
# Prediction function
def predict(image):
    """Classify an uploaded crop image.

    Args:
        image: The ``PIL.Image.Image`` delivered by ``gr.Image(type="pil")``,
            or ``None`` when the user submits without uploading anything.

    Returns:
        A dict mapping each name in ``class_labels`` to its softmax
        probability. This is the format ``gr.Label`` needs to rank and show
        the top classes; the highest-probability entry is the same class
        the plain argmax would pick. If anything fails, a string starting
        with ``"Error: "`` is returned instead, and the label component
        displays it.
    """
    if image is None:
        return "Error: no image provided"
    try:
        # Normalize() uses 3-channel statistics, so RGBA, greyscale and palette
        # images must be converted to RGB before the transform.
        batch = transform(image.convert("RGB")).unsqueeze(0).to(device)
        with torch.no_grad():
            logits = model(batch)
        # Single-image batch: take row 0 of the (1, num_classes) scores.
        probabilities = torch.softmax(logits, dim=1)[0].tolist()
        return dict(zip(class_labels, probabilities))
    except Exception as e:
        # Top-level UI boundary: report the failure in the output component
        # instead of letting the exception escape to the Gradio worker.
        return f"Error: {e}"
# Gradio interface: a single PIL image in, a gr.Label out. When predict returns
# a {label: confidence} dict, the label shows at most the 3 best classes.
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=3),
    title="Crop Disease Classification",
    description="Upload an image to classify crop diseases using ResNet-50.",
)

# Start the web server only when this file is executed as a script
# (e.g. `python app.py`), so importing the module does not launch the app.
if __name__ == "__main__":
    interface.launch()