---
language: en
license: mit
tags:
- image-classification
- coffee
- hybrid-model
- cnn
- transformer
- resnet
- vision-transformer
datasets:
- USK-Coffee
metrics:
- accuracy
- precision
- recall
- f1
library_name: pytorch
---

# Green Arabica Coffee Bean Classification (R50-V Hybrid Model)

## Model Description

This model is a hybrid deep learning architecture that combines **ResNet-50** (CNN-based) with a **Vision Transformer** to classify green Arabica coffee beans into four quality classes. It achieved **91.88% accuracy** on the USK-Coffee dataset, outperforming previous single-model approaches.

**Model Type:** Hybrid CNN-Transformer  
**Architecture:** ResNet-50 + Vision Transformer (R50-V)  
**Task:** Multi-class Image Classification  
**Classes:** 4 (Peaberry, Longberry, Premium, Defect)

## Model Architecture

The hybrid model processes input images through two parallel branches:

1. **CNN Branch (ResNet-50)**: Extracts local visual features such as texture, edges, and patterns
2. **Transformer Branch (Vision Transformer)**: Captures global context and long-range dependencies

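Features from both branches (2048 dimensions from ResNet-50 and 768 from ViT-B/16, as in the torchvision models used in the code below) are concatenated into a 2816-dimensional vector and passed through a classification head: a 512-unit linear layer with ReLU and dropout (0.5), followed by a linear layer over the four classes.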
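
The fusion step described above can be sketched in isolation, without downloading any weights. This is a minimal illustration, assuming the feature sizes of torchvision's `resnet50` (2048) and `vit_b_16` (768) once their classification layers are removed. The random tensors stand in for real backbone outputs, and `FusionHead` is a hypothetical name that is not part of the released model.

```python
import torch
import torch.nn as nn

RESNET_DIM = 2048  # resnet50 output size once its fc layer is replaced by Identity
VIT_DIM = 768      # vit_b_16 class-token size once its heads are replaced by Identity


class FusionHead(nn.Module):
    """Hypothetical stand-alone copy of the concatenate-then-classify step."""

    def __init__(self, num_classes=4):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(RESNET_DIM + VIT_DIM, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes),
        )

    def forward(self, local_features, global_features):
        # Join the two feature vectors along the feature axis: (N, 2048 + 768)
        combined = torch.cat((local_features, global_features), dim=1)
        return self.fc(combined)


head = FusionHead().eval()

# Random stand-ins for the backbone outputs of a batch of two images
local_features = torch.randn(2, RESNET_DIM)
global_features = torch.randn(2, VIT_DIM)

with torch.no_grad():
    logits = head(local_features, global_features)

print(logits.shape)  # torch.Size([2, 4])
```

Each row of `logits` holds one score per class; applying softmax along `dim=1` turns them into the class probabilities reported as confidence.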

## How to Use

### Installation

```bash
pip install torch torchvision pillow huggingface_hub
```

### Code

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision.models import (
    resnet50, vit_b_16,
    ResNet50_Weights, ViT_B_16_Weights
)
from huggingface_hub import hf_hub_download

class HybridModel(nn.Module):
    def __init__(self, num_classes=4):
        super().__init__()
        # CNN branch: ResNet-50 with its final fc layer removed, so it outputs features.
        # The ImageNet weights loaded here are overwritten by the checkpoint in load_model().
        self.resnet = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
        resnet_in_features = self.resnet.fc.in_features
        self.resnet.fc = nn.Identity()

        # Transformer branch: ViT-B/16 with its classification head removed
        self.vit = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
        vit_in_features = self.vit.heads.head.in_features
        self.vit.heads = nn.Identity()

        # Classification head applied to the concatenated features
        self.fc = nn.Sequential(
            nn.Linear(resnet_in_features + vit_in_features, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        # x is a pair of tensors, each preprocessed for its own branch
        x_resnet, x_vit = x
        local_features = self.resnet(x_resnet)
        global_features = self.vit(x_vit)
        combined_features = torch.cat((local_features, global_features), dim=1)
        return self.fc(combined_features)

def load_model():
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model_path = hf_hub_download(
        repo_id="kelvinandreas/Green_Arabica_Coffee_Bean_Classification_R50-V",
        filename="best_model.pth"
    )

    model = HybridModel(num_classes=4)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()

    return model, device

def preprocess_image(image_path, device):
    image = Image.open(image_path).convert("RGB")

    preprocess_resnet = ResNet50_Weights.IMAGENET1K_V1.transforms()
    preprocess_vit = ViT_B_16_Weights.IMAGENET1K_V1.transforms()

    x_resnet = preprocess_resnet(image).unsqueeze(0).to(device)
    x_vit = preprocess_vit(image).unsqueeze(0).to(device)

    return x_resnet, x_vit

def predict(image_path):
    CLASS_LABELS = ['Defect', 'Longberry', 'Peaberry', 'Premium']

    model, device = load_model()
    x_resnet, x_vit = preprocess_image(image_path, device)

    with torch.no_grad():
        logits = model((x_resnet, x_vit))

    probs = F.softmax(logits, dim=1)
    pred_idx = torch.argmax(probs, dim=1).item()
    confidence = probs[0][pred_idx].item()

    return {
        "prediction": CLASS_LABELS[pred_idx],
        "confidence_percent": round(confidence * 100, 2)
    }

image_path = "your_coffee_bean_image.jpg"
result = predict(image_path)
print(result)
```

## Evaluation Results

| Metric | Score |
|--------|-------|
| **Accuracy** | **91.88%** |
| **Precision** | 92.12% |
| **Recall** | 91.88% |
| **F1-Score** | 91.85% |
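
The card does not state how precision, recall, and F1 were averaged across the four classes. Recall matching accuracy exactly is what support-weighted averaging produces, so the sketch below assumes that convention. The confusion matrix is made up for illustration and is not the model's actual result, and `weighted_metrics` is a hypothetical helper.

```python
def weighted_metrics(confusion):
    """Return accuracy and support-weighted precision, recall, and F1.

    confusion[i][j] is the number of samples of true class i predicted as class j.
    """
    n_classes = len(confusion)
    total = sum(sum(row) for row in confusion)
    correct = sum(confusion[i][i] for i in range(n_classes))

    precision = recall = f1 = 0.0
    for c in range(n_classes):
        tp = confusion[c][c]
        support = sum(confusion[c])                   # true samples of class c
        predicted = sum(row[c] for row in confusion)  # samples predicted as class c
        p = tp / predicted if predicted else 0.0
        r = tp / support if support else 0.0
        f = 2 * p * r / (p + r) if (p + r) else 0.0
        weight = support / total                      # share of the data in class c
        precision += weight * p
        recall += weight * r
        f1 += weight * f

    return {
        "accuracy": correct / total,
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }


# Illustrative counts only (rows: true class, columns: predicted class)
example = [
    [45, 2, 1, 2],
    [1, 47, 1, 1],
    [0, 2, 46, 2],
    [3, 1, 2, 44],
]
metrics = weighted_metrics(example)
print({name: round(value, 4) for name, value in metrics.items()})
```

With weighted averaging, each class's recall is multiplied by its share of the data, so the sum reduces to correct predictions divided by total samples, which is the accuracy.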