Update app.py
Browse files
app.py
CHANGED
|
@@ -16,26 +16,45 @@ from safetensors.torch import load_file
|
|
| 16 |
import pandas as pd
|
| 17 |
from sklearn.preprocessing import LabelEncoder
|
| 18 |
|
|
|
|
| 19 |
df = pd.read_csv('./animal_dataset_vi.csv')
|
| 20 |
label_encoder = LabelEncoder()
|
| 21 |
label_encoder.fit(df['answer'].astype(str))
|
| 22 |
num_classes = len(label_encoder.classes_)
|
| 23 |
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
["./animals/animals/bat/bat_52.jpg", "Con vật này có màu sẫm hay nhạt?"]
|
| 30 |
]
|
| 31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
class VQAModel(nn.Module):
|
| 33 |
def __init__(self, num_classes):
|
| 34 |
super(VQAModel, self).__init__()
|
|
|
|
| 35 |
self.image_encoder = nn.Sequential(*list(models.resnet50(weights=None).children())[:-1])
|
| 36 |
self.img_proj = nn.Linear(2048, 512)
|
|
|
|
|
|
|
| 37 |
self.text_encoder = AutoModel.from_pretrained("vinai/phobert-base-v2")
|
| 38 |
self.text_proj = nn.Linear(768, 512)
|
|
|
|
|
|
|
| 39 |
self.classifier = nn.Sequential(
|
| 40 |
nn.LayerNorm(512),
|
| 41 |
nn.Dropout(0.4),
|
|
@@ -48,11 +67,14 @@ class VQAModel(nn.Module):
|
|
| 48 |
def forward(self, images, input_ids, attention_mask):
|
| 49 |
img_features = self.image_encoder(images).flatten(start_dim=1)
|
| 50 |
img_features = self.img_proj(img_features)
|
|
|
|
| 51 |
text_outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
|
| 52 |
text_features = self.text_proj(text_outputs.pooler_output)
|
|
|
|
| 53 |
combined_features = img_features * text_features
|
| 54 |
return self.classifier(combined_features)
|
| 55 |
|
|
|
|
| 56 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 57 |
model = VQAModel(num_classes).to(device)
|
| 58 |
|
|
@@ -61,6 +83,7 @@ if os.path.exists(model_path):
|
|
| 61 |
model.load_state_dict(load_file(model_path))
|
| 62 |
model.eval()
|
| 63 |
|
|
|
|
| 64 |
tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base-v2")
|
| 65 |
transform = transforms.Compose([
|
| 66 |
transforms.Resize((224, 224)),
|
|
@@ -68,6 +91,7 @@ transform = transforms.Compose([
|
|
| 68 |
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
| 69 |
])
|
| 70 |
|
|
|
|
| 71 |
def predict_vqa(image, question):
|
| 72 |
if image is None or question.strip() == "":
|
| 73 |
return "Please provide both an image and a question."
|
|
@@ -78,6 +102,7 @@ def predict_vqa(image, question):
|
|
| 78 |
segmented_question, truncation=True, padding='max_length',
|
| 79 |
max_length=64, return_tensors='pt'
|
| 80 |
)
|
|
|
|
| 81 |
with torch.no_grad():
|
| 82 |
outputs = model(
|
| 83 |
image_tensor,
|
|
@@ -85,20 +110,25 @@ def predict_vqa(image, question):
|
|
| 85 |
encoding['attention_mask'].to(device)
|
| 86 |
)
|
| 87 |
_, predicted_id = torch.max(outputs, 1)
|
|
|
|
| 88 |
answer = label_encoder.inverse_transform([predicted_id.item()])[0]
|
| 89 |
return answer.capitalize()
|
| 90 |
except Exception as e:
|
| 91 |
return f"Error: {str(e)}"
|
| 92 |
|
|
|
|
| 93 |
demo = gr.Interface(
|
| 94 |
fn=predict_vqa,
|
| 95 |
-
inputs=[
|
|
|
|
|
|
|
|
|
|
| 96 |
outputs=gr.Textbox(label="Answer"),
|
| 97 |
examples=examples_list,
|
| 98 |
title="Vi-VQA Animal",
|
| 99 |
-
theme=gr.themes.Default(primary_hue="orange")
|
| 100 |
-
allow_flagging="never"
|
| 101 |
)
|
| 102 |
|
|
|
|
| 103 |
if __name__ == "__main__":
|
| 104 |
-
demo.launch()
|
|
|
|
| 16 |
import pandas as pd
|
| 17 |
from sklearn.preprocessing import LabelEncoder
|
| 18 |
|
| 19 |
+
# 1. LOAD DATASET AND LABEL ENCODER
|
| 20 |
df = pd.read_csv('./animal_dataset_vi.csv')
|
| 21 |
label_encoder = LabelEncoder()
|
| 22 |
label_encoder.fit(df['answer'].astype(str))
|
| 23 |
num_classes = len(label_encoder.classes_)
|
| 24 |
|
| 25 |
+
# 2. PREPARE 3 RANDOM SAMPLES
|
| 26 |
+
custom_questions = [
|
| 27 |
+
"Con vật trong hình là con gì?",
|
| 28 |
+
"Màu sắc chủ đạo của con vật này là gì?",
|
| 29 |
+
"Con này sống ở đâu?"
|
|
|
|
| 30 |
]
|
| 31 |
|
| 32 |
+
# Randomly select 3 images from the dataset
|
| 33 |
+
df_samples = df.sample(n=3)
|
| 34 |
+
|
| 35 |
+
examples_list = []
|
| 36 |
+
for i, (_, row) in enumerate(df_samples.iterrows()):
|
| 37 |
+
# Convert absolute paths from your dataset to relative paths for Hugging Face
|
| 38 |
+
img_path = row['image_path'].replace(
|
| 39 |
+
"animal_dataset/animals/animals",
|
| 40 |
+
"./animals/animals"
|
| 41 |
+
)
|
| 42 |
+
# Pair the random image with a fixed question
|
| 43 |
+
examples_list.append([img_path, custom_questions[i]])
|
| 44 |
+
|
| 45 |
+
# 3. INITIALIZE MODEL ARCHITECTURE
|
| 46 |
class VQAModel(nn.Module):
|
| 47 |
def __init__(self, num_classes):
|
| 48 |
super(VQAModel, self).__init__()
|
| 49 |
+
# Image Feature Extractor (ResNet50)
|
| 50 |
self.image_encoder = nn.Sequential(*list(models.resnet50(weights=None).children())[:-1])
|
| 51 |
self.img_proj = nn.Linear(2048, 512)
|
| 52 |
+
|
| 53 |
+
# Text Feature Extractor (PhoBERT)
|
| 54 |
self.text_encoder = AutoModel.from_pretrained("vinai/phobert-base-v2")
|
| 55 |
self.text_proj = nn.Linear(768, 512)
|
| 56 |
+
|
| 57 |
+
# Classification Head
|
| 58 |
self.classifier = nn.Sequential(
|
| 59 |
nn.LayerNorm(512),
|
| 60 |
nn.Dropout(0.4),
|
|
|
|
| 67 |
def forward(self, images, input_ids, attention_mask):
|
| 68 |
img_features = self.image_encoder(images).flatten(start_dim=1)
|
| 69 |
img_features = self.img_proj(img_features)
|
| 70 |
+
|
| 71 |
text_outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
|
| 72 |
text_features = self.text_proj(text_outputs.pooler_output)
|
| 73 |
+
|
| 74 |
combined_features = img_features * text_features
|
| 75 |
return self.classifier(combined_features)
|
| 76 |
|
| 77 |
+
# Setup device and load weights
|
| 78 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 79 |
model = VQAModel(num_classes).to(device)
|
| 80 |
|
|
|
|
| 83 |
model.load_state_dict(load_file(model_path))
|
| 84 |
model.eval()
|
| 85 |
|
| 86 |
+
# Initialize text tokenizer and image transformations
|
| 87 |
tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base-v2")
|
| 88 |
transform = transforms.Compose([
|
| 89 |
transforms.Resize((224, 224)),
|
|
|
|
| 91 |
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
| 92 |
])
|
| 93 |
|
| 94 |
+
# 4. INFERENCE FUNCTION
|
| 95 |
def predict_vqa(image, question):
|
| 96 |
if image is None or question.strip() == "":
|
| 97 |
return "Please provide both an image and a question."
|
|
|
|
| 102 |
segmented_question, truncation=True, padding='max_length',
|
| 103 |
max_length=64, return_tensors='pt'
|
| 104 |
)
|
| 105 |
+
|
| 106 |
with torch.no_grad():
|
| 107 |
outputs = model(
|
| 108 |
image_tensor,
|
|
|
|
| 110 |
encoding['attention_mask'].to(device)
|
| 111 |
)
|
| 112 |
_, predicted_id = torch.max(outputs, 1)
|
| 113 |
+
|
| 114 |
answer = label_encoder.inverse_transform([predicted_id.item()])[0]
|
| 115 |
return answer.capitalize()
|
| 116 |
except Exception as e:
|
| 117 |
return f"Error: {str(e)}"
|
| 118 |
|
| 119 |
+
# 5. GRADIO INTERFACE
|
| 120 |
demo = gr.Interface(
|
| 121 |
fn=predict_vqa,
|
| 122 |
+
inputs=[
|
| 123 |
+
gr.Image(type="pil", label="Image"),
|
| 124 |
+
gr.Textbox(lines=2, label="Question")
|
| 125 |
+
],
|
| 126 |
outputs=gr.Textbox(label="Answer"),
|
| 127 |
examples=examples_list,
|
| 128 |
title="Vi-VQA Animal",
|
| 129 |
+
theme=gr.themes.Default(primary_hue="orange")
|
|
|
|
| 130 |
)
|
| 131 |
|
| 132 |
+
# Launch the web app
|
| 133 |
if __name__ == "__main__":
|
| 134 |
+
demo.launch()
|