ProbMED / preprocessor.py
mcintoc's picture
add documentation and model weights (#1)
b5350db verified
import torch
from transformers import AutoTokenizer, BatchEncoding
from mixinhelpers import CXR_Mixin, ECG_Mixin, ECHO_Mixin, Text_Mixin
"""
Preprocessor classes for different modalities and their combinations.
You can combine different mixins to create preprocessors for multi-modal inputs.
Examples below are provided for ECHO+Text, ECG+Text, and CXR+Text.
"""
class BasePreprocessor:
    """Common base for the preprocessors in this module.

    Its only job is to build the tokenizer that text-handling subclasses
    read back as ``self.tokenizer``.
    """

    def __init__(self, model_name: str = "dmis-lab/biobert-v1.1") -> None:
        """Create the tokenizer for ``model_name``.

        Args:
            model_name: Identifier forwarded verbatim to
                ``AutoTokenizer.from_pretrained``.
        """
        # NOTE(review): super().__init__() is not called here, so any __init__
        # defined on mixins listed after this class in a subclass's MRO is
        # skipped. Confirm the mixins need no initialization.
        loaded_tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.tokenizer = loaded_tokenizer
# Two-modality preprocessors: each pairs one signal modality with a text report
class ECHOText_Preprocessor(BasePreprocessor, ECHO_Mixin, Text_Mixin):
    """Preprocessor for paired echocardiogram + text-report samples.

    Echo loading is provided by ``ECHO_Mixin``, caption tokenization by
    ``Text_Mixin``, and the tokenizer itself by ``BasePreprocessor``.
    """

    def __init__(self, model_name: str = "dmis-lab/biobert-v1.1") -> None:
        # Hand tokenizer construction to the next class in the MRO.
        super().__init__(model_name=model_name)

    def preprocess_echo_text(
        self, echo_path: str, text: str
    ) -> tuple[torch.Tensor, BatchEncoding]:
        """Load one echo study and tokenize its paired report.

        Intended as the per-sample step of a dataset/dataloader. The modality
        key ``self.ECHO_KEY`` is forwarded to ``construct_caption``.

        Args:
            echo_path: Path to the echo ``.npy`` file.
            text: Free-text report paired with the echo.

        Returns:
            A 2-tuple ``(echo, text_inputs)``: the output of
            ``preprocess_single_echo`` followed by the output of
            ``construct_caption``.
        """
        # NOTE(review): the echo tensor was previously annotated as (C, H, W);
        # its shape is decided by ECHO_Mixin.preprocess_single_echo. Confirm there.
        return (
            self.preprocess_single_echo(echo_path),
            self.construct_caption(
                caption=text,
                tokenizer=self.tokenizer,
                modality=self.ECHO_KEY,
            ),
        )
class ECGText_Preprocessor(BasePreprocessor, ECG_Mixin, Text_Mixin):
    """Preprocessor for paired ECG + text-report samples.

    ECG loading is provided by ``ECG_Mixin``, caption tokenization by
    ``Text_Mixin``, and the tokenizer itself by ``BasePreprocessor``.
    """

    def __init__(self, model_name: str = "dmis-lab/biobert-v1.1") -> None:
        # Hand tokenizer construction to the next class in the MRO.
        super().__init__(model_name=model_name)

    def preprocess_ecg_text(
        self, ecg_path: str, text: str
    ) -> tuple[torch.Tensor, BatchEncoding]:
        """Load one ECG recording and tokenize its paired report.

        Intended as the per-sample step of a dataset/dataloader. The modality
        key ``self.ECG_KEY`` is forwarded to ``construct_caption``.

        Args:
            ecg_path: Path to the ECG ``.npy`` file.
            text: Free-text report paired with the ECG.

        Returns:
            A 2-tuple ``(ecg, text_inputs)``: the output of
            ``preprocess_single_ecg`` followed by the output of
            ``construct_caption``.
        """
        # NOTE(review): the ECG tensor was previously annotated as (C, L);
        # its shape is decided by ECG_Mixin.preprocess_single_ecg. Confirm there.
        return (
            self.preprocess_single_ecg(ecg_path),
            self.construct_caption(
                caption=text,
                tokenizer=self.tokenizer,
                modality=self.ECG_KEY,
            ),
        )
class CXRText_Preprocessor(BasePreprocessor, CXR_Mixin, Text_Mixin):
    """Preprocessor for paired chest X-ray + text-report samples.

    CXR loading is provided by ``CXR_Mixin``, caption tokenization by
    ``Text_Mixin``, and the tokenizer itself by ``BasePreprocessor``.
    """

    def __init__(self, model_name: str = "dmis-lab/biobert-v1.1") -> None:
        # Hand tokenizer construction to the next class in the MRO.
        super().__init__(model_name=model_name)

    def preprocess_cxr_text(
        self, cxr_path: str, text: str
    ) -> tuple[torch.Tensor, BatchEncoding]:
        """Load one chest X-ray image and tokenize its paired report.

        Intended as the per-sample step of a dataset/dataloader. The modality
        key ``self.VISION_KEY`` is forwarded to ``construct_caption``.

        Args:
            cxr_path: Path to the CXR image file.
            text: Free-text report paired with the image.

        Returns:
            A 2-tuple ``(cxr, text_inputs)``: the output of
            ``preprocess_single_cxr`` followed by the output of
            ``construct_caption``.
        """
        # NOTE(review): the image tensor was previously annotated as (C, H, W);
        # its shape is decided by CXR_Mixin.preprocess_single_cxr. Confirm there.
        # NOTE(review): unlike the ECHO/ECG classes, this uses VISION_KEY rather
        # than a modality-specific key. Presumably intentional; verify in the mixins.
        return (
            self.preprocess_single_cxr(cxr_path),
            self.construct_caption(
                caption=text,
                tokenizer=self.tokenizer,
                modality=self.VISION_KEY,
            ),
        )