522H0134-NguyenNhatHuy committed on
Commit
9f08b12
·
verified ·
1 Parent(s): 825683b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -10
app.py CHANGED
@@ -16,26 +16,45 @@ from safetensors.torch import load_file
16
  import pandas as pd
17
  from sklearn.preprocessing import LabelEncoder
18
 
 
19
  df = pd.read_csv('./animal_dataset_vi.csv')
20
  label_encoder = LabelEncoder()
21
  label_encoder.fit(df['answer'].astype(str))
22
  num_classes = len(label_encoder.classes_)
23
 
24
- examples_list = [
25
- ["./animals/animals/zebra/zebra_18.jpg", "Bộ phận cơ thể riêng biệt của sinh vật này là gì?"],
26
- ["./animals/animals/hare/hare_20.jpg", "Con vật này có ở trong tự nhiên không?"],
27
- ["./animals/animals/hedgehog/hedgehog_56.jpg", "Con vật này màu gì?"],
28
- ["./animals/animals/lion/lion_37.jpg", "Động vật môi trường sống tự nhiên của nó không?"],
29
- ["./animals/animals/bat/bat_52.jpg", "Con vật này có màu sẫm hay nhạt?"]
30
  ]
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  class VQAModel(nn.Module):
33
  def __init__(self, num_classes):
34
  super(VQAModel, self).__init__()
 
35
  self.image_encoder = nn.Sequential(*list(models.resnet50(weights=None).children())[:-1])
36
  self.img_proj = nn.Linear(2048, 512)
 
 
37
  self.text_encoder = AutoModel.from_pretrained("vinai/phobert-base-v2")
38
  self.text_proj = nn.Linear(768, 512)
 
 
39
  self.classifier = nn.Sequential(
40
  nn.LayerNorm(512),
41
  nn.Dropout(0.4),
@@ -48,11 +67,14 @@ class VQAModel(nn.Module):
48
  def forward(self, images, input_ids, attention_mask):
49
  img_features = self.image_encoder(images).flatten(start_dim=1)
50
  img_features = self.img_proj(img_features)
 
51
  text_outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
52
  text_features = self.text_proj(text_outputs.pooler_output)
 
53
  combined_features = img_features * text_features
54
  return self.classifier(combined_features)
55
 
 
56
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
57
  model = VQAModel(num_classes).to(device)
58
 
@@ -61,6 +83,7 @@ if os.path.exists(model_path):
61
  model.load_state_dict(load_file(model_path))
62
  model.eval()
63
 
 
64
  tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base-v2")
65
  transform = transforms.Compose([
66
  transforms.Resize((224, 224)),
@@ -68,6 +91,7 @@ transform = transforms.Compose([
68
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
69
  ])
70
 
 
71
  def predict_vqa(image, question):
72
  if image is None or question.strip() == "":
73
  return "Please provide both an image and a question."
@@ -78,6 +102,7 @@ def predict_vqa(image, question):
78
  segmented_question, truncation=True, padding='max_length',
79
  max_length=64, return_tensors='pt'
80
  )
 
81
  with torch.no_grad():
82
  outputs = model(
83
  image_tensor,
@@ -85,20 +110,25 @@ def predict_vqa(image, question):
85
  encoding['attention_mask'].to(device)
86
  )
87
  _, predicted_id = torch.max(outputs, 1)
 
88
  answer = label_encoder.inverse_transform([predicted_id.item()])[0]
89
  return answer.capitalize()
90
  except Exception as e:
91
  return f"Error: {str(e)}"
92
 
 
93
  demo = gr.Interface(
94
  fn=predict_vqa,
95
- inputs=[gr.Image(type="pil", label="Image"), gr.Textbox(lines=2, label="Question")],
 
 
 
96
  outputs=gr.Textbox(label="Answer"),
97
  examples=examples_list,
98
  title="Vi-VQA Animal",
99
- theme=gr.themes.Default(primary_hue="orange"),
100
- allow_flagging="never"
101
  )
102
 
 
103
  if __name__ == "__main__":
104
- demo.launch()
 
16
  import pandas as pd
17
  from sklearn.preprocessing import LabelEncoder
18
 
19
# 1. LOAD DATASET AND LABEL ENCODER
# Read the question/answer table that ships alongside the app.
df = pd.read_csv('./animal_dataset_vi.csv')

# Answers are cast to str before fitting so every label is encoded as text.
# LabelEncoder.fit returns the fitted encoder itself, so it can be chained.
label_encoder = LabelEncoder().fit(df['answer'].astype(str))

# One output class per distinct answer label learned by the encoder.
num_classes = len(label_encoder.classes_)
24
 
25
# 2. PREPARE RANDOM EXAMPLE SAMPLES
# Fixed questions shown as ready-made examples in the UI; one random image
# from the dataset is attached to each of them below.
# NOTE(review): the question strings are kept exactly as found; some of them
# look like they are missing words or spaces (e.g. "sốngđâu") -- confirm
# against the intended Vietnamese text.
custom_questions = [
    "Con vật trong hình con gì?",
    "Màu sắc chủ đạo của con vật này gì?",
    "Con này sốngđâu?",
]

# Draw one random row per question. The sample size follows the question list
# instead of a hard-coded 3, and is capped at the number of rows so that a
# dataset smaller than the question list does not make DataFrame.sample raise
# ValueError.
df_samples = df.sample(n=min(len(custom_questions), len(df)))

# Pair each sampled image with its fixed question. The dataset stores image
# paths containing "animal_dataset/animals/animals"; that part is rewritten to
# the "./animals/animals" layout used by this app.
# NOTE(review): the result is only a relative path if the stored path starts
# with that prefix -- TODO confirm against the CSV contents.
examples_list = [
    [
        image_path.replace(
            "animal_dataset/animals/animals",
            "./animals/animals",
        ),
        question,
    ]
    for image_path, question in zip(df_samples['image_path'], custom_questions)
]
44
+
45
+ # 3. INITIALIZE MODEL ARCHITECTURE
46
  class VQAModel(nn.Module):
47
  def __init__(self, num_classes):
48
  super(VQAModel, self).__init__()
49
+ # Image Feature Extractor (ResNet50)
50
  self.image_encoder = nn.Sequential(*list(models.resnet50(weights=None).children())[:-1])
51
  self.img_proj = nn.Linear(2048, 512)
52
+
53
+ # Text Feature Extractor (PhoBERT)
54
  self.text_encoder = AutoModel.from_pretrained("vinai/phobert-base-v2")
55
  self.text_proj = nn.Linear(768, 512)
56
+
57
+ # Classification Head
58
  self.classifier = nn.Sequential(
59
  nn.LayerNorm(512),
60
  nn.Dropout(0.4),
 
67
  def forward(self, images, input_ids, attention_mask):
68
  img_features = self.image_encoder(images).flatten(start_dim=1)
69
  img_features = self.img_proj(img_features)
70
+
71
  text_outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
72
  text_features = self.text_proj(text_outputs.pooler_output)
73
+
74
  combined_features = img_features * text_features
75
  return self.classifier(combined_features)
76
 
77
# Set up the compute device and build the model; the trained weights are
# loaded into it further down.
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the network with one output per answer class, then move it to the
# selected device (Module.to returns the module itself).
model = VQAModel(num_classes)
model = model.to(device)
80
 
 
83
  model.load_state_dict(load_file(model_path))
84
  model.eval()
85
 
86
+ # Initialize text tokenizer and image transformations
87
  tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base-v2")
88
  transform = transforms.Compose([
89
  transforms.Resize((224, 224)),
 
91
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
92
  ])
93
 
94
+ # 4. INFERENCE FUNCTION
95
  def predict_vqa(image, question):
96
  if image is None or question.strip() == "":
97
  return "Please provide both an image and a question."
 
102
  segmented_question, truncation=True, padding='max_length',
103
  max_length=64, return_tensors='pt'
104
  )
105
+
106
  with torch.no_grad():
107
  outputs = model(
108
  image_tensor,
 
110
  encoding['attention_mask'].to(device)
111
  )
112
  _, predicted_id = torch.max(outputs, 1)
113
+
114
  answer = label_encoder.inverse_transform([predicted_id.item()])[0]
115
  return answer.capitalize()
116
  except Exception as e:
117
  return f"Error: {str(e)}"
118
 
119
# 5. GRADIO INTERFACE
# Input widgets: the picture to ask about (handed to predict_vqa as a PIL
# image) and a two-line text box for the question.
_image_input = gr.Image(type="pil", label="Image")
_question_input = gr.Textbox(lines=2, label="Question")

# Output widget: the predicted answer as plain text.
_answer_output = gr.Textbox(label="Answer")

# Wire the inference function to the widgets and attach the prepared examples.
demo = gr.Interface(
    fn=predict_vqa,
    inputs=[_image_input, _question_input],
    outputs=_answer_output,
    examples=examples_list,
    title="Vi-VQA Animal",
    theme=gr.themes.Default(primary_hue="orange"),
)

# Launch the web app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()