import numpy as np
import pydicom
import torch
import torch.nn as nn
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    PreTrainedModel,
)

from .configuration import MRIBrainSequenceBERTConfig


class MRIBrainSequenceBERT(PreTrainedModel):
    """Text classifier that predicts an MRI sequence label from metadata text.

    A pretrained ``google/mobilebert-uncased`` sequence-classification model is
    used as the backbone. Its own dropout and classifier layers are replaced
    with ``nn.Identity`` so that the value it returns under ``"logits"`` is the
    input that would have gone to its classifier; this class then applies its
    own dropout and linear head with ``config.num_classes`` outputs.

    Attributes set in ``__init__``:
        metadata_elements: ordered DICOM keywords used to build the input text.
        label2index / index2label: mapping between label names and class ids.
    """

    config_class = MRIBrainSequenceBERTConfig

    def __init__(self, config):
        """Build the backbone, classification head, tokenizer and label maps.

        Args:
            config: configuration object; ``dropout``, ``num_classes`` and
                ``max_len`` are read from it.
        """
        super().__init__(config)
        self.llm = AutoModelForSequenceClassification.from_pretrained(
            "google/mobilebert-uncased"
        )
        # Read the feature width before the backbone's classifier is replaced.
        self.dim_feats = self.llm.classifier.in_features
        self.dropout = nn.Dropout(p=config.dropout)
        self.classifier = nn.Linear(self.dim_feats, config.num_classes)
        # Neutralise the backbone's own head so it passes features through.
        self.llm.dropout = nn.Identity()
        self.llm.classifier = nn.Identity()
        self.tokenizer = AutoTokenizer.from_pretrained("google/mobilebert-uncased")
        self.max_len = config.max_len

        # Order matters: create_string_from_dicom emits elements in this order.
        self.metadata_elements = [
            "SeriesDescription",
            "ImageType",
            "Manufacturer",
            "ManufacturerModelName",
            "ContrastBolusAgent",
            "ScanningSequence",
            "SequenceVariant",
            "ScanOptions",
            "MRAcquisitionType",
            "SequenceName",
            "AngioFlag",
            "SliceThickness",
            "RepetitionTime",
            "EchoTime",
            "InversionTime",
            "NumberOfAverages",
            "ImagingFrequency",
            "ImagedNucleus",
            "EchoNumbers",
            "SpacingBetweenSlices",
            "NumberOfPhaseEncodingSteps",
            "EchoTrainLength",
            "PercentSampling",
            "PercentPhaseFieldOfView",
            "PixelBandwidth",
            "ContrastBolusVolume",
            "ContrastBolusTotalDose",
            "AcquisitionMatrix",
            "InPlanePhaseEncodingDirection",
            "FlipAngle",
            "VariableFlipAngleFlag",
            "SAR",
            "dBdt",
            "SeriesNumber",
            "AcquisitionNumber",
            "PhotometricInterpretation",
            "PixelSpacing",
            "ImagesInAcquisition",
            "SmallestImagePixelValue",
            "LargestImagePixelValue",
        ]

        self.label2index = {
            "t1": 0,  # T1 precontrast
            "t1c": 1,  # T1 postcontrast
            "t2": 2,  # T2
            "flair": 3,  # T2-FLAIR
            "dwi": 4,  # DWI trace
            "adc": 5,  # ADC map
            "dti": 6,  # DTI
            "swi": 7,  # SWI
            "swi_mip": 8,  # SWI MinIP
            "phase": 9,  # SWI phase images
            "mag": 10,  # SWI mag images
            "gre": 11,  # T2* GRE
            "perf": 12,  # Perfusion-related images
            "pd": 13,  # Proton density
            "loc": 14,  # Localizers
            "other": 15,  # Other, NOS
        }
        self.index2label = {v: k for k, v in self.label2index.items()}

    def forward(
        self, x: str, device: str | torch.device = "cpu", apply_softmax: bool = True
    ) -> torch.Tensor:
        """Tokenize ``x`` and return class scores.

        Args:
            x: text to classify (passed straight to the tokenizer), e.g. the
                output of ``create_string_from_dicom``.
            device: device the tokenized tensors are moved to before the
                backbone is called. The caller is responsible for the model
                itself being on the same device.
            apply_softmax: when True, softmax is applied over dim 1 so the
                result holds probabilities; otherwise raw logits are returned.

        Returns:
            The output of the classification head, softmaxed if requested.
        """
        encoded = self.tokenizer(
            x,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_len,
        )
        model_inputs = {name: tensor.to(device) for name, tensor in encoded.items()}
        # The backbone's classifier is nn.Identity, so "logits" holds features.
        features = self.llm(**model_inputs)["logits"]
        logits = self.classifier(self.dropout(features))
        if apply_softmax:
            logits = torch.softmax(logits, dim=1)
        return logits

    def create_string_from_dicom(
        self,
        ds: pydicom.Dataset | dict,
        exclude_elements: list[str] | None = None,
    ) -> str:
        """Build the model's input text from DICOM metadata.

        Every keyword in ``self.metadata_elements`` that is present in ``ds``,
        is not excluded, and whose value is neither ``None`` nor renders as
        ``"nan"`` contributes a ``"<keyword> <value>"`` fragment. Fragments are
        joined with ``" | "`` and the characters ``[``, ``]``, ``,`` and ``'``
        are then removed from the whole string.

        Args:
            ds: metadata source supporting ``in`` and ``[]`` by keyword.
            exclude_elements: keywords to leave out of the text, for when
                specific elements should not be used for prediction.
                ``None`` (the default) excludes nothing.

        Returns:
            The formatted metadata string (empty if nothing qualified).
        """
        # NOTE(review): for a pydicom.Dataset, ds[name] appears to yield a
        # DataElement rather than its bare value, so the rendered text may carry
        # tag/VR details -- confirm this matches what the model was trained on.
        excluded = () if exclude_elements is None else exclude_elements
        fragments = []
        for keyword in self.metadata_elements:
            # Only include elements which are present and not excluded.
            if keyword not in ds or keyword in excluded:
                continue
            value = ds[keyword]
            if value is not None and str(value) != "nan":
                fragments.append(f"{keyword} {value}")
        text = " | ".join(fragments)
        # Strip list/quote punctuation left over from multi-valued elements.
        return text.translate(str.maketrans("", "", "[],'"))

    @staticmethod
    def determine_plane_from_dicom(ds: pydicom.Dataset) -> str:
        """Classify the imaging plane from ``ImageOrientationPatient``.

        The slice normal is the cross product of the first three and last three
        orientation values; the plane is chosen from its largest absolute
        component.

        Args:
            ds: object exposing a six-value ``ImageOrientationPatient``.

        Returns:
            ``"SAG"`` when the first component dominates, ``"COR"`` when the
            second dominates, otherwise ``"AX"``.
        """
        iop = np.asarray(ds.ImageOrientationPatient, dtype=float)
        # Calculate the direction cosine for the normal vector of the plane
        normal_vector = np.cross(iop[:3], iop[3:])
        # Determine the plane based on the largest component of the normal vector
        abs_x, abs_y, abs_z = np.abs(normal_vector)
        if abs_x > abs_y and abs_x > abs_z:
            return "SAG"
        if abs_y > abs_x and abs_y > abs_z:
            return "COR"
        if abs_x == abs_y and abs_x > abs_z:
            # Exact x/y tie with a smaller z: the plane cannot be axial, so do
            # not fall through to "AX"; "SAG" is chosen as the first candidate.
            return "SAG"
        return "AX"