import os

# Let Gradio serve files (e.g. the example images) from the app directory.
os.environ["GRADIO_ALLOWED_PATHS"] = os.path.abspath("./")

import warnings

# Silence library warnings so the app logs stay readable.
warnings.filterwarnings("ignore")

import transformers

transformers.logging.set_verbosity_error()

import gradio as gr
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.models as models
from transformers import AutoTokenizer, AutoModel
from PIL import Image
from pyvi import ViTokenizer
from safetensors.torch import load_file
import pandas as pd
from sklearn.preprocessing import LabelEncoder

# 1. LOAD DATASET AND LABEL ENCODER
# NOTE(review): the encoder is refit from the CSV at startup, so the CSV
# presumably has to contain the same answer set used at training time for the
# class indices to line up with the checkpoint's classifier head — confirm.
df = pd.read_csv('./animal_dataset_vi.csv')
label_encoder = LabelEncoder()
label_encoder.fit(df['answer'].astype(str))
num_classes = len(label_encoder.classes_)

# 2. PREPARE UP TO 3 RANDOM SAMPLES (only images actually present on the server)
custom_questions = [
    "Con vật trong hình là con gì?",           # "What animal is in the picture?"
    "Màu sắc chủ đạo của con vật này là gì?",  # "What is this animal's main colour?"
    "Con này sống ở đâu?",                     # "Where does this animal live?"
]


def _collect_examples(frame, questions):
    """Pick random ``[image_path, question]`` pairs for the Gradio examples.

    Rows of ``frame`` are visited in random order. Each ``image_path`` is
    remapped to the directory layout used on the Hugging Face server, and only
    paths that point to an existing file are kept. At most ``len(questions)``
    examples are returned, and the i-th example is paired with ``questions[i]``.

    Args:
        frame: DataFrame with an ``image_path`` column.
        questions: Questions to attach, one per collected example.

    Returns:
        A list of ``[image_path, question]`` lists, possibly empty.
    """
    examples = []
    for raw_path in frame['image_path'].sample(frac=1):
        if len(examples) >= len(questions):
            break  # Collected enough examples.
        if pd.isna(raw_path):
            continue  # A missing path would otherwise crash on .replace().
        # Map the dataset's path prefix onto the server's directory layout.
        img_path = str(raw_path).replace(
            "animal_dataset/animals/animals", "./animals/animals"
        )
        # Only keep images that were actually uploaded to the server.
        if os.path.isfile(img_path):
            examples.append([img_path, questions[len(examples)]])
    return examples


examples_list = _collect_examples(df, custom_questions)

if len(examples_list) == 0:
    print("⚠️ LƯU Ý: Không tìm thấy ảnh nào trong thư mục ./animals/animals. Hãy chắc chắn bạn đã upload ảnh vào đúng thư mục nhé!")
# 3. INITIALIZE MODEL ARCHITECTURE
class VQAModel(nn.Module):
    """Answer-classification VQA model: ResNet-50 image branch + PhoBERT text branch.

    Both modalities are projected to 512-d and fused by element-wise product.
    An MLP head then scores ``num_classes`` candidate answers. The submodule
    attribute names below are the keys of the saved state dict, so they must
    not be renamed.
    """

    def __init__(self, num_classes):
        super().__init__()
        # ResNet-50 without its final fc layer -> pooled (B, 2048, 1, 1) features.
        # weights=None gives a random init; the safetensors checkpoint loaded
        # below is expected to overwrite it.
        self.image_encoder = nn.Sequential(
            *list(models.resnet50(weights=None).children())[:-1]
        )
        self.img_proj = nn.Linear(2048, 512)
        self.text_encoder = AutoModel.from_pretrained("vinai/phobert-base-v2")
        self.text_proj = nn.Linear(768, 512)
        self.classifier = nn.Sequential(
            nn.LayerNorm(512),
            nn.Dropout(0.4),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(512, num_classes),
        )

    def forward(self, images, input_ids, attention_mask):
        """Return answer logits of shape (B, num_classes).

        Args:
            images: Normalized RGB batch, e.g. (B, 3, 224, 224) as produced by
                ``transform`` below.
            input_ids: Token ids from the PhoBERT tokenizer.
            attention_mask: Matching attention mask.
        """
        img_features = self.img_proj(self.image_encoder(images).flatten(start_dim=1))
        text_outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
        text_features = self.text_proj(text_outputs.pooler_output)
        # Multiplicative fusion of the two 512-d embeddings.
        return self.classifier(img_features * text_features)


# Setup device and load weights
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VQAModel(num_classes).to(device)
model_path = './vqa_resnet50_phobert.safetensors'
if os.path.exists(model_path):
    # Load tensors straight onto the target device instead of staging on CPU.
    model.load_state_dict(load_file(model_path, device=str(device)))
else:
    # Without the checkpoint the encoders' projections and the classifier head
    # are randomly initialised; make that visible instead of failing silently.
    print(f"⚠️ WARNING: model weights not found at {model_path}; answers will come from an untrained model.")
model.eval()

# Initialize text tokenizer and image transformations
tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base-v2")
# Fixed 224x224 resize plus the standard ImageNet mean/std normalisation.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


# 4. INFERENCE FUNCTION
def predict_vqa(image, question):
    """Answer a Vietnamese question about an image.

    Args:
        image: PIL image from the Gradio ``Image`` input, or None.
        question: Question text from the Gradio ``Textbox``; may be None/empty.

    Returns:
        The predicted answer with its first letter capitalised. If either input
        is missing, returns a prompt asking for both. If inference raises,
        returns an ``"Error: ..."`` string.
    """
    # `not question` guards against None before calling .strip().
    if image is None or not question or not question.strip():
        return "Please provide both an image and a question."

    try:
        image_tensor = transform(image.convert('RGB')).unsqueeze(0).to(device)
        # Word-segment the question with PyVi before PhoBERT tokenisation
        # (presumably mirroring the training preprocessing — confirm).
        segmented_question = ViTokenizer.tokenize(question)
        encoding = tokenizer(
            segmented_question,
            truncation=True,
            padding='max_length',
            max_length=64,
            return_tensors='pt',
        )
        with torch.no_grad():
            outputs = model(
                image_tensor,
                encoding['input_ids'].to(device),
                encoding['attention_mask'].to(device),
            )
        predicted_id = outputs.argmax(dim=1).item()
        answer = label_encoder.inverse_transform([predicted_id])[0]
        return answer.capitalize()
    except Exception as e:
        # UI boundary: report the error to the user, but keep the traceback in
        # the server logs so failures can be debugged.
        import traceback
        traceback.print_exc()
        return f"Error: {str(e)}"


# 5. GRADIO INTERFACE
demo = gr.Interface(
    fn=predict_vqa,
    inputs=[
        gr.Image(type="pil", label="Image"),
        gr.Textbox(lines=2, label="Question"),
    ],
    outputs=gr.Textbox(label="Answer"),
    # No examples section at all when no sample images were found on disk.
    examples=examples_list or None,
    cache_examples=False,
    title="Vi-VQA Animal",
    theme=gr.themes.Default(primary_hue="orange"),
)

# Launch with explicit paths allowed
if __name__ == "__main__":
    demo.launch(allowed_paths=[os.path.abspath("./")])