iamomtiwari commited on
Commit
f81e803
·
verified ·
1 Parent(s): 02a0ca8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -9
app.py CHANGED
@@ -3,15 +3,15 @@ import torch
3
  import torchvision.transforms as transforms
4
  from torchvision.models import resnet50
5
  from PIL import Image
6
-
7
- # Load the model
8
  from huggingface_hub import hf_hub_download
9
 
10
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
  model = resnet50(pretrained=False)
12
  model.fc = torch.nn.Linear(model.fc.in_features, 14) # Adjust for 14 classes
 
13
  model_path = hf_hub_download(repo_id="iamomtiwari/resnet50-crop-disease", filename="resnet50_model_hf.pt")
14
  model.load_state_dict(torch.load(model_path, map_location=device))
 
15
  model.eval()
16
 
17
  # Define image transformations
@@ -21,7 +21,7 @@ transform = transforms.Compose([
21
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
22
  ])
23
 
24
- # Class labels (adjust according to your dataset)
25
  class_labels = [
26
  "Corn___Common_Rust", "Corn___Gray_Leaf_Spot", "Corn___Healthy", "Corn___Northern_Leaf_Blight",
27
  "Rice___Brown_Spot", "Rice___Healthy", "Rice___Leaf_Blast", "Rice___Neck_Blast",
@@ -31,17 +31,20 @@ class_labels = [
31
 
32
  # Prediction function
33
  def predict(image):
34
- image = transform(image).unsqueeze(0).to(device) # Add batch dimension
35
- with torch.no_grad():
36
- outputs = model(image)
37
- _, predicted_class = torch.max(outputs, 1)
38
- return class_labels[predicted_class.item()]
 
 
 
39
 
40
  # Gradio interface
41
  interface = gr.Interface(
42
  fn=predict,
43
  inputs=gr.Image(type="pil"),
44
- outputs=gr.Label(),
45
  title="Crop Disease Classification",
46
  description="Upload an image to classify crop diseases using ResNet-50."
47
  )
 
3
  import torchvision.transforms as transforms
4
  from torchvision.models import resnet50
5
  from PIL import Image
 
 
6
  from huggingface_hub import hf_hub_download
7
 
8
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
  model = resnet50(pretrained=False)
10
  model.fc = torch.nn.Linear(model.fc.in_features, 14) # Adjust for 14 classes
11
+
12
  model_path = hf_hub_download(repo_id="iamomtiwari/resnet50-crop-disease", filename="resnet50_model_hf.pt")
13
  model.load_state_dict(torch.load(model_path, map_location=device))
14
+ model.to(device)
15
  model.eval()
16
 
17
  # Define image transformations
 
21
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
22
  ])
23
 
24
+ # Class labels
25
  class_labels = [
26
  "Corn___Common_Rust", "Corn___Gray_Leaf_Spot", "Corn___Healthy", "Corn___Northern_Leaf_Blight",
27
  "Rice___Brown_Spot", "Rice___Healthy", "Rice___Leaf_Blast", "Rice___Neck_Blast",
 
31
 
32
  # Prediction function
33
  def predict(image):
34
+ try:
35
+ image = transform(image).unsqueeze(0).to(device)
36
+ with torch.no_grad():
37
+ outputs = model(image)
38
+ _, predicted_class = torch.max(outputs, 1)
39
+ return class_labels[predicted_class.item()]
40
+ except Exception as e:
41
+ return f"Error: {str(e)}"
42
 
43
  # Gradio interface
44
  interface = gr.Interface(
45
  fn=predict,
46
  inputs=gr.Image(type="pil"),
47
+ outputs=gr.Label(num_top_classes=3),
48
  title="Crop Disease Classification",
49
  description="Upload an image to classify crop diseases using ResNet-50."
50
  )