|
|
import torch |
|
|
from transformers import AutoTokenizer, BatchEncoding |
|
|
|
|
|
from mixinhelpers import CXR_Mixin, ECG_Mixin, ECHO_Mixin, Text_Mixin |
|
|
|
|
|
""" |
|
|
Preprocessor classes for different modalities and their combinations. |
|
|
You can combine different mixins to create preprocessors for multi-modal inputs. |
|
|
Examples below are provided for ECHO+Text, ECG+Text, and CXR+Text. |
|
|
""" |
|
|
|
|
|
|
|
|
class BasePreprocessor:
    """Shared base for all multi-modal preprocessors.

    Loads the Hugging Face tokenizer that the ``Text_Mixin`` helpers use
    when tokenizing text reports.

    Args:
        model_name: Hugging Face model identifier passed to
            ``AutoTokenizer.from_pretrained``.
    """

    def __init__(self, model_name: str = "dmis-lab/biobert-v1.1") -> None:
        # Continue the MRO chain so that cooperative mixin initializers
        # (if any mixin defines __init__) still run when this class is
        # combined with mixins in the subclasses below.
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
|
|
|
|
|
|
|
|
|
|
class ECHOText_Preprocessor(BasePreprocessor, ECHO_Mixin, Text_Mixin):
    """Preprocessor for paired ECHO recordings and text reports.

    The tokenizer setup is inherited unchanged from ``BasePreprocessor``
    (same signature and default), so no ``__init__`` override is needed.
    """

    def preprocess_echo_text(self, echo_path: str, text: str) -> tuple[torch.Tensor, BatchEncoding]:
        """Preprocess one ECHO/text pair for use in a dataloader.

        Use the string keys on the tokenized output to identify the
        modalities when collating batches.

        Args:
            echo_path: Path to the echo ``.npy`` file.
            text: Text report string.

        Returns:
            Tuple of ``(echo tensor, tokenized text BatchEncoding)``.
        """
        echo = self.preprocess_single_echo(echo_path)
        text_inputs = self.construct_caption(
            caption=text, tokenizer=self.tokenizer, modality=self.ECHO_KEY
        )
        return echo, text_inputs
|
|
|
|
|
|
|
|
class ECGText_Preprocessor(BasePreprocessor, ECG_Mixin, Text_Mixin):
    """Preprocessor for paired ECG recordings and text reports.

    The tokenizer setup is inherited unchanged from ``BasePreprocessor``
    (same signature and default), so no ``__init__`` override is needed.
    """

    def preprocess_ecg_text(self, ecg_path: str, text: str) -> tuple[torch.Tensor, BatchEncoding]:
        """Preprocess one ECG/text pair for use in a dataloader.

        Use the string keys on the tokenized output to identify the
        modalities when collating batches.

        Args:
            ecg_path: Path to the ecg ``.npy`` file.
            text: Text report string.

        Returns:
            Tuple of ``(ecg tensor, tokenized text BatchEncoding)``.
        """
        ecg = self.preprocess_single_ecg(ecg_path)
        text_inputs = self.construct_caption(
            caption=text, tokenizer=self.tokenizer, modality=self.ECG_KEY
        )
        return ecg, text_inputs
|
|
|
|
|
|
|
|
class CXRText_Preprocessor(BasePreprocessor, CXR_Mixin, Text_Mixin):
    """Preprocessor for paired chest X-ray images and text reports.

    The tokenizer setup is inherited unchanged from ``BasePreprocessor``
    (same signature and default), so no ``__init__`` override is needed.
    """

    def preprocess_cxr_text(self, cxr_path: str, text: str) -> tuple[torch.Tensor, BatchEncoding]:
        """Preprocess one CXR/text pair for use in a dataloader.

        Use the string keys on the tokenized output to identify the
        modalities when collating batches.

        Args:
            cxr_path: Path to the cxr image file.
            text: Text report string.

        Returns:
            Tuple of ``(cxr tensor, tokenized text BatchEncoding)``.
        """
        cxr = self.preprocess_single_cxr(cxr_path)
        text_inputs = self.construct_caption(
            caption=text, tokenizer=self.tokenizer, modality=self.VISION_KEY
        )
        return cxr, text_inputs
|
|
|