File size: 3,237 Bytes
b1406c1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import torch
import torch.nn.functional as F
from pathlib import Path
from PIL import Image
from torchvision import transforms
import logging

logger = logging.getLogger(__name__)

class MultimodalFusion:
    """Computes image-text alignment using BiomedVLP.

    Every method is static: the caller passes in the already-loaded ``model``
    (which must provide ``get_image_embeddings`` and ``get_text_embeddings``),
    the ``tokenizer`` and the target ``device``.
    """

    # Side length, in pixels, that every image is resized to before embedding.
    IMAGE_SIZE = 224
    # Inclusive lower bounds on the [0, 1] similarity score for each label.
    HIGH_THRESHOLD = 0.7
    MEDIUM_THRESHOLD = 0.4

    @staticmethod
    def get_image_transform():
        """Build the preprocessing pipeline applied to an image before embedding.

        Returns:
            A ``transforms.Compose`` that resizes to ``IMAGE_SIZE`` x
            ``IMAGE_SIZE``, converts to a tensor and normalizes each channel
            with the fixed mean/std below.
        """
        return transforms.Compose([
            transforms.Resize((MultimodalFusion.IMAGE_SIZE, MultimodalFusion.IMAGE_SIZE)),
            transforms.ToTensor(),
            # NOTE(review): these are the commonly used ImageNet channel
            # statistics - confirm they match what the model was trained with.
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    @staticmethod
    def get_image_embedding(image_path: Path, model, device: str) -> torch.Tensor:
        """Embed a single image file and L2-normalize the result.

        Args:
            image_path: Location of the image to load.
            model: Object exposing ``get_image_embeddings(tensor)``.
            device: Device the input tensor is moved to before inference.

        Returns:
            The model's embedding, normalized to unit L2 norm along the last
            dimension.

        Raises:
            Whatever ``PIL.Image.open`` or the model raises; nothing is caught
            here.
        """
        # Open inside a context manager so the underlying file handle is
        # released deterministically; ``convert`` returns an independent copy.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        transform = MultimodalFusion.get_image_transform()
        # unsqueeze(0) adds the leading batch dimension for a single image.
        image_tensor = transform(image).unsqueeze(0).to(device)

        with torch.no_grad():
            embedding = model.get_image_embeddings(image_tensor)
        return F.normalize(embedding, p=2, dim=-1)

    @staticmethod
    def get_text_embedding(text: str, model, tokenizer, device: str) -> torch.Tensor:
        """Embed a text string and L2-normalize the result.

        Args:
            text: Text to embed; it is truncated / padded to 256 tokens.
            model: Object exposing ``get_text_embeddings(**tokenizer_output)``.
            tokenizer: Callable producing PyTorch tensors whose result supports
                ``.to(device)``.
            device: Device the tokenized inputs are moved to.

        Returns:
            The model's embedding, normalized to unit L2 norm along the last
            dimension.
        """
        inputs = tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=256,
            padding="max_length",
        ).to(device)
        with torch.no_grad():
            embedding = model.get_text_embeddings(**inputs)
        return F.normalize(embedding, p=2, dim=-1)

    @staticmethod
    def _alignment_label(similarity: float) -> str:
        """Map a [0, 1] similarity score to "HIGH", "MEDIUM" or "LOW"."""
        if similarity >= MultimodalFusion.HIGH_THRESHOLD:
            return "HIGH"
        if similarity >= MultimodalFusion.MEDIUM_THRESHOLD:
            return "MEDIUM"
        return "LOW"

    @staticmethod
    def compute_similarity(image_path: Path, text: str, model, tokenizer, device: str) -> tuple[float, str]:
        """Score how well an image and a text agree.

        The cosine similarity of the two embeddings is shifted from [-1, 1]
        to [0, 1] and labelled: ``>= 0.7`` is "HIGH", ``>= 0.4`` is "MEDIUM",
        anything lower is "LOW".

        Args:
            image_path: Location of the image to load.
            text: Text to compare against the image.
            model: Object exposing both embedding methods.
            tokenizer: Tokenizer forwarded to ``get_text_embedding``.
            device: Device used for inference.

        Returns:
            ``(score, label)`` with the score rounded to three decimals. This
            is a best-effort API: on any failure (unreadable image, model
            error, non-finite similarity) a warning is logged and
            ``(0.5, "UNKNOWN")`` is returned instead of raising.
        """
        try:
            img_emb = MultimodalFusion.get_image_embedding(image_path, model, device)
            txt_emb = MultimodalFusion.get_text_embedding(text, model, tokenizer, device)

            cosine = torch.cosine_similarity(img_emb, txt_emb)
            # NaN/inf would fail every threshold comparison and be reported as
            # "LOW" with a NaN score; route it to the UNKNOWN path instead.
            if not torch.isfinite(cosine).all():
                raise ValueError("similarity is not finite")

            similarity = (float(cosine.item()) + 1) / 2  # Shift to [0,1]
            # The label is derived from the unrounded value; only the
            # reported score is rounded.
            alignment = MultimodalFusion._alignment_label(similarity)
            return round(similarity, 3), alignment
        except Exception as e:
            logger.warning(f"Fusion similarity computation failed: {e}")
            return 0.5, "UNKNOWN"

    @staticmethod
    def get_fused_embedding(image_path: Path, text: str, model, tokenizer, device: str) -> torch.Tensor:
        """Concatenate the normalized image and text embeddings.

        Args:
            image_path: Location of the image to load.
            text: Text to embed.
            model: Object exposing both embedding methods.
            tokenizer: Tokenizer forwarded to ``get_text_embedding``.
            device: Device used for inference.

        Returns:
            The image embedding followed by the text embedding, joined along
            the last dimension.

        Raises:
            Unlike ``compute_similarity``, errors from image loading or the
            model are not caught and propagate to the caller.
        """
        img_emb = MultimodalFusion.get_image_embedding(image_path, model, device)
        txt_emb = MultimodalFusion.get_text_embedding(text, model, tokenizer, device)
        return torch.cat([img_emb, txt_emb], dim=-1)

class FallbackFusion:
    """Keyword-counting alignment heuristic that needs no model."""

    # Substrings searched for in the lower-cased text; each one that occurs
    # counts as a single match, however many times it appears.
    CHEST_KEYWORDS = (
        "chest",
        "lung",
        "cardiac",
        "pleural",
        "pneumo",
        "infiltrate",
        "opacity",
        "nodule",
        "effusion",
    )

    @staticmethod
    def compute_similarity(image_path: Path, text: str) -> tuple[float, str]:
        """Simple keyword-based fallback when BiomedVLP unavailable due to RAM constraints.

        The score starts at 0.3, gains 0.1 per distinct keyword found in
        ``text`` and is capped at 0.9. A score strictly above 0.6 is "HIGH",
        strictly above 0.4 is "MEDIUM", otherwise "LOW".

        Args:
            image_path: Accepted for signature parity with the model-based
                scorer; it is not read.
            text: Report text to scan. ``None`` or an empty string yields
                zero matches.

        Returns:
            ``(score, label)`` where ``score`` is one of 0.3, 0.4, ..., 0.9.
        """
        text_lower = (text or "").lower()
        matches = sum(kw in text_lower for kw in FallbackFusion.CHEST_KEYWORDS)

        # Do the arithmetic in integer tenths. In floating point
        # ``0.3 + 3 * 0.1`` is 0.6000000000000001, which used to slip past the
        # strict ``> 0.6`` check and label a nominal 0.6 score as "HIGH".
        tenths = min(9, 3 + matches)
        score = tenths / 10

        if tenths > 6:
            alignment = "HIGH"
        elif tenths > 4:
            alignment = "MEDIUM"
        else:
            alignment = "LOW"
        return score, alignment