medpmc-clip-l-14_jun24_v1 / inference_example.py
Hyunjae Kim
Initial release of MedPMC-CLIP
7d8fbac
Raw
History Blame Contribute Delete
1.08 kB
import torch
import open_clip
from safetensors.torch import load_file
from PIL import Image
# Zero-shot image classification example: score one image against a small
# set of candidate text labels with a CLIP-style model loaded from a local
# safetensors checkpoint.

model_name = "ViT-L-14"
checkpoint_path = "open_clip_pytorch_model.safetensors"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Build the architecture only (pretrained=None); the weights come from the
# local checkpoint loaded below. The middle return value is not needed here.
model, _, preprocess = open_clip.create_model_and_transforms(
    model_name,
    pretrained=None,
)

# strict=True makes a checkpoint/architecture key mismatch fail loudly
# instead of silently leaving some weights at their initial values.
state_dict = load_file(checkpoint_path)
model.load_state_dict(state_dict, strict=True)
model = model.to(device)
model.eval()

tokenizer = open_clip.get_tokenizer(model_name)

# Close the file handle as soon as the pixel data has been converted;
# unsqueeze(0) adds the leading batch dimension.
with Image.open("example.jpg") as raw_image:
    image = preprocess(raw_image.convert("RGB")).unsqueeze(0).to(device)

texts = tokenizer([
    "chest radiograph",
    "fundus photograph",
    "histopathology image",
]).to(device)

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(texts)

    # Unit-normalise both sides so the dot product is a cosine similarity.
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    similarity = image_features @ text_features.T

    # Cosine similarities lie in [-1, 1]; a softmax over them directly is
    # almost uniform. Scale by the model's learned temperature first
    # (open_clip CLIP models keep its logarithm in `logit_scale`).
    logits = model.logit_scale.exp() * similarity
    probs = logits.softmax(dim=-1)

print(probs)