"""
Preprocessor classes for different modalities and their combinations.

Combine different mixins to create preprocessors for multi-modal inputs;
examples below are provided for ECHO+Text, ECG+Text, and CXR+Text.
"""
import torch
from transformers import AutoTokenizer, BatchEncoding

from mixinhelpers import CXR_Mixin, ECG_Mixin, ECHO_Mixin, Text_Mixin

class BasePreprocessor:
def __init__(self, model_name: str = "dmis-lab/biobert-v1.1") -> None:
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
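

# Hedged sketch, not part of the original module: per the module docstring, the
# mixins are meant to compose, so a hypothetical triple-modality preprocessor
# can be assembled the same way as the dual-modality classes below. The class
# and method names, argument order, and the modality key passed to
# construct_caption here are assumptions made for illustration.
class CXRECGText_Preprocessor(BasePreprocessor, CXR_Mixin, ECG_Mixin, Text_Mixin):
    def __init__(self, model_name: str = "dmis-lab/biobert-v1.1") -> None:
        super().__init__(model_name=model_name)

    def preprocess_cxr_ecg_text(
        self, cxr_path: str, ecg_path: str, text: str
    ) -> tuple[torch.Tensor, torch.Tensor, BatchEncoding]:
        cxr = self.preprocess_single_cxr(cxr_path)  # (C, H, W)
        ecg = self.preprocess_single_ecg(ecg_path)  # (C, L)
        # Which modality key the caption should carry for a triple input is an
        # assumption; VISION_KEY is used here only as a placeholder.
        text_inputs = self.construct_caption(
            caption=text, tokenizer=self.tokenizer, modality=self.VISION_KEY
        )
        return cxr, ecg, text_inputs
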

# Dual-modality preprocessors
class ECHOText_Preprocessor(BasePreprocessor, ECHO_Mixin, Text_Mixin):
def __init__(self, model_name: str = "dmis-lab/biobert-v1.1") -> None:
super().__init__(model_name=model_name)

    def preprocess_echo_text(self, echo_path: str, text: str) -> tuple[torch.Tensor, BatchEncoding]:
        """Preprocess one ECHO/text pair for use in a dataloader; batches
        collate correctly when the string modality keys are used to identify
        each modality (see the usage sketch at the bottom of this file).

        echo_path: path to the echo .npy file
        text: the text report
        returns: (echo tensor, tokenized text dict)
        """
echo = self.preprocess_single_echo(echo_path) # (C, H, W)
text_inputs = self.construct_caption(
caption=text, tokenizer=self.tokenizer, modality=self.ECHO_KEY
)
return echo, text_inputs


class ECGText_Preprocessor(BasePreprocessor, ECG_Mixin, Text_Mixin):
def __init__(self, model_name: str = "dmis-lab/biobert-v1.1") -> None:
super().__init__(model_name=model_name)

    def preprocess_ecg_text(self, ecg_path: str, text: str) -> tuple[torch.Tensor, BatchEncoding]:
        """Preprocess one ECG/text pair for use in a dataloader; batches
        collate correctly when the string modality keys are used to identify
        each modality.

        ecg_path: path to the ecg .npy file
        text: the text report
        returns: (ecg tensor, tokenized text dict)
        """
ecg = self.preprocess_single_ecg(ecg_path) # (C, L)
text_inputs = self.construct_caption(
caption=text, tokenizer=self.tokenizer, modality=self.ECG_KEY
)
return ecg, text_inputs


class CXRText_Preprocessor(BasePreprocessor, CXR_Mixin, Text_Mixin):
def __init__(self, model_name: str = "dmis-lab/biobert-v1.1") -> None:
super().__init__(model_name=model_name)

    def preprocess_cxr_text(self, cxr_path: str, text: str) -> tuple[torch.Tensor, BatchEncoding]:
        """Preprocess one CXR/text pair for use in a dataloader; batches
        collate correctly when the string modality keys are used to identify
        each modality.

        cxr_path: path to the cxr image file
        text: the text report
        returns: (cxr tensor, tokenized text dict)
        """
cxr = self.preprocess_single_cxr(cxr_path) # (C, H, W)
text_inputs = self.construct_caption(
caption=text, tokenizer=self.tokenizer, modality=self.VISION_KEY
)
return cxr, text_inputs
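

# Hedged usage sketch: how one of these preprocessors might plug into a torch
# Dataset/DataLoader, as the method docstrings suggest. EchoTextDataset, the
# (echo_path, report) sample list, and the file path below are hypothetical.
# The default collate_fn stacks the echo tensors and collates the tokenized
# dict key by key; this assumes construct_caption returns fixed-length
# (e.g. padded) tensors so stacking succeeds.
class EchoTextDataset(torch.utils.data.Dataset):
    def __init__(self, samples: list[tuple[str, str]]) -> None:
        self.samples = samples  # (echo_path, report_text) pairs
        self.preprocessor = ECHOText_Preprocessor()

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, BatchEncoding]:
        echo_path, report = self.samples[idx]
        return self.preprocessor.preprocess_echo_text(echo_path, report)


if __name__ == "__main__":
    # Hypothetical file path and report text, for illustration only.
    dataset = EchoTextDataset([("echo_0001.npy", "Normal LV systolic function.")])
    loader = torch.utils.data.DataLoader(dataset, batch_size=1)
    echo_batch, text_batch = next(iter(loader))  # echo_batch: (B, C, H, W)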