| import os |
| os.environ["GRADIO_ALLOWED_PATHS"] = os.path.abspath("./") |
|
|
| import warnings |
| warnings.filterwarnings("ignore") |
| import transformers |
| transformers.logging.set_verbosity_error() |
|
|
| import gradio as gr |
| import torch |
| import torch.nn as nn |
| import torchvision.transforms as transforms |
| import torchvision.models as models |
| from transformers import AutoTokenizer, AutoModel |
| from PIL import Image |
| from pyvi import ViTokenizer |
| from safetensors.torch import load_file |
| import pandas as pd |
| from sklearn.preprocessing import LabelEncoder |
|
|
| |
# Load the annotated dataset and derive the answer vocabulary from it.
# The encoder is fitted on the stringified 'answer' column; the number of
# distinct answers is later used as the size of the model's output layer.
df = pd.read_csv('./animal_dataset_vi.csv')
label_encoder = LabelEncoder().fit(df['answer'].astype(str))
num_classes = len(label_encoder.classes_)
|
|
| |
# Questions shown as ready-made examples in the UI. Each one is paired, in
# order, with one randomly sampled dataset row below.
custom_questions = [
    "Con vật trong hình là con gì?",
    "Màu sắc chủ đạo của con vật này là gì?",
    "Con này sống ở đâu?",
]

# Draw one random row per question. The sample size is capped at len(df) so
# start-up does not fail with ValueError when the CSV holds fewer rows than
# there are questions.
df_samples = df.sample(n=min(len(custom_questions), len(df)))

# Build [image_path, question] pairs for the Gradio examples panel, skipping
# any image that is not present on disk.
examples_list = []
for question, (_, row) in zip(custom_questions, df_samples.iterrows()):
    # Rewrite the dataset path prefix to the local ./animals/animals folder.
    img_path = row['image_path'].replace(
        "animal_dataset/animals/animals",
        "./animals/animals",
    )
    if os.path.exists(img_path):
        examples_list.append([img_path, question])
|
|
| |
class VQAModel(nn.Module):
    """Visual question answering classifier over a fixed set of answers.

    An image branch (ResNet-50 with its last child module removed) and a text
    branch (PhoBERT) each produce a feature vector. Both vectors are projected
    to 512 dimensions, fused by element-wise multiplication and passed to an
    MLP head that emits one logit per answer class.
    """

    def __init__(self, num_classes):
        """Create the encoders, the projection layers and the classifier.

        Args:
            num_classes: Number of logits produced by the classifier head.
        """
        super().__init__()

        # Image branch: ResNet-50 built without pretrained weights; the slice
        # drops its last child (the final fully connected layer in
        # torchvision's ResNet), leaving 2048 features per image.
        resnet = models.resnet50(weights=None)
        backbone_layers = list(resnet.children())[:-1]
        self.image_encoder = nn.Sequential(*backbone_layers)
        self.img_proj = nn.Linear(2048, 512)

        # Text branch: PhoBERT encoder followed by a 768 -> 512 projection.
        self.text_encoder = AutoModel.from_pretrained("vinai/phobert-base-v2")
        self.text_proj = nn.Linear(768, 512)

        # Head applied to the fused 512-dimensional feature vector.
        head_layers = [
            nn.LayerNorm(512),
            nn.Dropout(0.4),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(512, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, images, input_ids, attention_mask):
        """Compute answer logits for a batch of image/question pairs.

        Args:
            images: Batch of image tensors fed to the image encoder.
            input_ids: Token ids of the tokenized questions.
            attention_mask: Attention mask matching ``input_ids``.

        Returns:
            The classifier output for the fused image/text features.
        """
        # Flatten everything after the batch dimension before projecting.
        pooled_image = self.image_encoder(images)
        image_vec = self.img_proj(pooled_image.flatten(start_dim=1))

        # The encoder's pooled output is used as the question representation.
        encoded = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
        question_vec = self.text_proj(encoded.pooler_output)

        # Fuse both modalities by element-wise product, then classify.
        return self.classifier(image_vec * question_vec)
|
|
| |
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VQAModel(num_classes).to(device)

# Load the trained weights when the checkpoint is present. Start-up stays
# best-effort when it is missing, but the situation is reported explicitly:
# without the checkpoint the answers shown in the UI are not meaningful.
# print() is used because warnings are globally silenced in this file.
model_path = './vqa_resnet50_phobert.safetensors'
if os.path.exists(model_path):
    model.load_state_dict(load_file(model_path))
else:
    print(
        f"Checkpoint not found at {model_path}; "
        "running without trained weights, predictions will not be meaningful."
    )

# Inference only: switch off dropout whether or not a checkpoint was loaded.
model.eval()
|
|
| |
# Tokenizer from the same PhoBERT checkpoint used by the model's text encoder.
tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base-v2")

# Image preprocessing: resize to 224x224, convert to a tensor, then normalise
# each channel with the mean/std values commonly used for ImageNet-pretrained
# torchvision models.
preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = transforms.Compose(preprocessing_steps)
|
|
| |
def predict_vqa(image, question):
    """Predict the answer to a question about an image.

    Args:
        image: PIL image supplied by the UI, or None when nothing was uploaded.
        question: Question text; may be None, empty or whitespace-only.

    Returns:
        The predicted answer label passed through ``str.capitalize()``; a
        prompt asking for the missing input when the image or the question is
        absent; or an ``"Error: ..."`` string when inference raises.
    """
    # The textbox can deliver None instead of an empty string, so check for
    # that before calling .strip() to avoid an AttributeError outside the try.
    if image is None or not question or not question.strip():
        return "Please provide both an image and a question."

    try:
        # Preprocess the image and add a batch dimension of size one.
        image_tensor = transform(image.convert('RGB')).unsqueeze(0).to(device)

        # Word-segment the question with pyvi before handing it to the tokenizer.
        segmented_question = ViTokenizer.tokenize(question)
        encoding = tokenizer(
            segmented_question,
            truncation=True,
            padding='max_length',
            max_length=64,
            return_tensors='pt',
        )

        with torch.no_grad():
            logits = model(
                image_tensor,
                encoding['input_ids'].to(device),
                encoding['attention_mask'].to(device),
            )

        # Highest-scoring class index for the single item in the batch.
        predicted_id = logits.argmax(dim=1).item()

        answer = label_encoder.inverse_transform([predicted_id])[0]
        return answer.capitalize()
    except Exception as e:
        # UI boundary: report the failure in the answer box instead of raising.
        return f"Error: {str(e)}"
|
|
| |
# Gradio UI: an image upload and a question box feed predict_vqa, whose
# return value is displayed in a single answer textbox.
image_input = gr.Image(type="pil", label="Image")
question_input = gr.Textbox(lines=2, label="Question")
answer_output = gr.Textbox(label="Answer")

demo = gr.Interface(
    fn=predict_vqa,
    inputs=[image_input, question_input],
    outputs=answer_output,
    examples=examples_list,
    # Examples are only pre-filled into the inputs, not run ahead of time.
    cache_examples=False,
    title="Vi-VQA Animal",
    theme=gr.themes.Default(primary_hue="orange"),
)
|
|
| |
if __name__ == "__main__":
    # Let Gradio serve files located under the current working directory,
    # which is where the example image paths point to.
    served_root = os.path.abspath("./")
    demo.launch(allowed_paths=[served_root])