---
language: en
license: mit
tags:
- image-classification
- coffee
- hybrid-model
- cnn
- transformer
- resnet
- vision-transformer
datasets:
- USK-Coffee
metrics:
- accuracy
- precision
- recall
- f1
library_name: pytorch
---
# Green Arabica Coffee Bean Classification (R50-V Hybrid Model)
## Model Description
This model is a hybrid deep learning architecture combining **ResNet-50** (CNN-based) and **Vision Transformer** for classifying green arabica coffee beans into four quality classes. The model achieved **91.88% accuracy** on the USK-Coffee dataset, outperforming previous single-model approaches.
**Model Type:** Hybrid CNN-Transformer
**Architecture:** ResNet-50 + Vision Transformer (R50-V)
**Task:** Multi-class Image Classification
**Classes:** 4 (Peaberry, Longberry, Premium, Defect)
## Model Architecture
The hybrid model processes input images through two parallel branches:
1. **CNN Branch (ResNet-50)**: Extracts local visual features such as texture, edges, and patterns
2. **Transformer Branch (Vision Transformer)**: Captures global context and long-range dependencies
Features from both branches are concatenated and passed through a classification head.
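The fusion step can be sketched in isolation. This is a minimal illustration, not the released model: the backbone outputs are replaced by random tensors, and `FusionHead`, `RESNET50_DIM`, and `VIT_B16_DIM` are names introduced here for the example. The widths assume the torchvision `resnet50` and `vit_b_16` backbones with their classifier heads removed (2048 and 768 features respectively); the hidden size of 512 and dropout of 0.5 follow the usage code in this card.

```python
import torch
import torch.nn as nn

# Feature widths of the two backbones once their heads are removed
# (assumption: torchvision resnet50 and vit_b_16).
RESNET50_DIM = 2048  # resnet50().fc.in_features
VIT_B16_DIM = 768    # vit_b_16().heads.head.in_features


class FusionHead(nn.Module):
    """Hypothetical stand-alone head over concatenated CNN and ViT features."""

    def __init__(self, num_classes: int = 4, hidden: int = 512):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(RESNET50_DIM + VIT_B16_DIM, hidden),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden, num_classes),
        )

    def forward(self, local_features, global_features):
        # Concatenate along the feature axis: (B, 2048) + (B, 768) -> (B, 2816).
        combined = torch.cat((local_features, global_features), dim=1)
        return self.fc(combined)


# Random stand-ins for the backbone outputs of a batch of 2 images.
local_features = torch.randn(2, RESNET50_DIM)
global_features = torch.randn(2, VIT_B16_DIM)

head = FusionHead().eval()
with torch.no_grad():
    logits = head(local_features, global_features)

print(logits.shape)  # one row of 4 class logits per image
```

Concatenation keeps both feature sets intact and lets the linear layers learn how to weight local against global cues, at the cost of a wider first layer (2816 inputs).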
## How to Use
### Installation
```bash
pip install torch torchvision pillow huggingface_hub
```
### Code
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision.models import (
    resnet50, vit_b_16,
    ResNet50_Weights, ViT_B_16_Weights
)
from huggingface_hub import hf_hub_download


class HybridModel(nn.Module):
    def __init__(self, num_classes=4):
        super().__init__()
        # CNN branch: ResNet-50 with its classifier replaced by an identity,
        # so it returns pooled features instead of ImageNet logits.
        self.resnet = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
        resnet_in_features = self.resnet.fc.in_features
        self.resnet.fc = nn.Identity()

        # Transformer branch: ViT-B/16 with its head removed the same way.
        self.vit = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
        vit_in_features = self.vit.heads.head.in_features
        self.vit.heads = nn.Identity()

        # Classification head over the concatenated features.
        self.fc = nn.Sequential(
            nn.Linear(resnet_in_features + vit_in_features, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        # x is a pair of tensors, one preprocessed for each branch.
        x_resnet, x_vit = x
        local_features = self.resnet(x_resnet)
        global_features = self.vit(x_vit)
        combined_features = torch.cat((local_features, global_features), dim=1)
        return self.fc(combined_features)


def load_model():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model_path = hf_hub_download(
        repo_id="kelvinandreas/Green_Arabica_Coffee_Bean_Classification_R50-V",
        filename="best_model.pth"
    )
    model = HybridModel(num_classes=4)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()
    return model, device


def preprocess_image(image_path, device):
    image = Image.open(image_path).convert("RGB")
    # Each branch uses the preprocessing bundled with its pretrained weights.
    preprocess_resnet = ResNet50_Weights.IMAGENET1K_V1.transforms()
    preprocess_vit = ViT_B_16_Weights.IMAGENET1K_V1.transforms()
    x_resnet = preprocess_resnet(image).unsqueeze(0).to(device)
    x_vit = preprocess_vit(image).unsqueeze(0).to(device)
    return x_resnet, x_vit


def predict(image_path):
    CLASS_LABELS = ['Defect', 'Longberry', 'Peaberry', 'Premium']
    model, device = load_model()
    x_resnet, x_vit = preprocess_image(image_path, device)
    with torch.no_grad():
        logits = model((x_resnet, x_vit))
        probs = F.softmax(logits, dim=1)
        pred_idx = torch.argmax(probs, dim=1).item()
        confidence = probs[0][pred_idx].item()
    return {
        "prediction": CLASS_LABELS[pred_idx],
        "confidence_percent": round(confidence * 100, 2)
    }


image_path = "your_coffee_bean_image.jpg"
result = predict(image_path)
print(result)
```
## Evaluation Results
| Metric | Score |
|--------|-------|
| **Accuracy** | **91.88%** |
| **Precision** | 92.12% |
| **Recall** | 91.88% |
| **F1-Score** | 91.85% |
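
This card does not state how precision, recall, and F1 were averaged across the four classes. The reported recall equals the reported accuracy, which is what support-weighted averaging always produces, so the sketch below assumes weighted averaging. The function `weighted_metrics` and the toy labels are invented for illustration and are not USK-Coffee results.

```python
from collections import Counter


def weighted_metrics(y_true, y_pred):
    """Return accuracy and support-weighted precision, recall, and F1."""
    support = Counter(y_true)  # number of true samples per class
    total = len(y_true)
    accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / total

    precision = recall = f1 = 0.0
    for label in sorted(support):
        tp = sum(t == label and p == label for t, p in zip(y_true, y_pred))
        predicted = sum(p == label for p in y_pred)
        p_c = tp / predicted if predicted else 0.0
        r_c = tp / support[label]
        f_c = 2 * p_c * r_c / (p_c + r_c) if (p_c + r_c) else 0.0
        # Weight each per-class score by that class's share of the samples.
        weight = support[label] / total
        precision += weight * p_c
        recall += weight * r_c
        f1 += weight * f_c
    return {"accuracy": accuracy, "precision": precision,
            "recall": recall, "f1": f1}


# Toy labels: one Defect bean is misclassified as Premium.
y_true = ["Defect", "Defect", "Longberry", "Longberry",
          "Peaberry", "Peaberry", "Premium", "Premium"]
y_pred = ["Defect", "Premium", "Longberry", "Longberry",
          "Peaberry", "Peaberry", "Premium", "Premium"]

scores = weighted_metrics(y_true, y_pred)
print(scores)
```

In this toy case accuracy and weighted recall are both 0.875, while precision and F1 differ from them, mirroring the pattern in the table above.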