File size: 2,907 Bytes
b5350db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import torch
from transformers import AutoTokenizer, BatchEncoding

from mixinhelpers import CXR_Mixin, ECG_Mixin, ECHO_Mixin, Text_Mixin

"""
Preprocessor classes for different modalities and their combinations.
You can combine different mixins to create preprocessors for multi-modal inputs.
Examples below are provided for ECHO+Text, ECG+Text, and CXR+Text.
"""


class BasePreprocessor:
    """Common base for the multi-modal preprocessors in this module.

    Owns the Hugging Face tokenizer (``self.tokenizer``) that the concrete
    preprocessors use to encode text reports.
    """

    def __init__(self, model_name: str = "dmis-lab/biobert-v1.1") -> None:
        """Load the tokenizer identified by ``model_name``.

        Args:
            model_name: Model identifier forwarded unchanged to
                ``AutoTokenizer.from_pretrained``.
        """
        # NOTE(review): super().__init__() is not called here, so an __init__
        # defined on a mixin that follows this class in a subclass's bases
        # would never run. Confirm the mixins need no per-instance setup.
        loaded_tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.tokenizer = loaded_tokenizer


# duo modality preprocessors
class ECHOText_Preprocessor(BasePreprocessor, ECHO_Mixin, Text_Mixin):
    """Preprocess paired echocardiogram + text-report samples.

    The tokenizer comes from ``BasePreprocessor``. ``preprocess_single_echo``
    and ``construct_caption`` are not defined in this module; presumably they
    are supplied by ``ECHO_Mixin`` and ``Text_Mixin`` respectively.
    """

    def __init__(self, model_name: str = "dmis-lab/biobert-v1.1") -> None:
        """Set up the shared tokenizer for ``model_name``."""
        super().__init__(model_name=model_name)

    def preprocess_echo_text(self, echo_path: str, text: str) -> tuple[torch.Tensor, BatchEncoding]:
        """Load one echo study and tokenize its accompanying report.

        Intended for use inside a dataloader. The echo modality key
        (``self.ECHO_KEY``) is forwarded to ``construct_caption``.

        Args:
            echo_path: Path to the echo ``.npy`` file.
            text: Text report for the study.

        Returns:
            A ``(echo, text_inputs)`` pair: the preprocessed echo tensor and
            the tokenizer output for the report.
        """
        # Tuple items are evaluated left to right, so the echo is still
        # loaded before the caption is tokenized.
        return (
            # shape noted by the original author as (C, H, W) — TODO confirm
            self.preprocess_single_echo(echo_path),
            self.construct_caption(
                caption=text,
                tokenizer=self.tokenizer,
                modality=self.ECHO_KEY,
            ),
        )


class ECGText_Preprocessor(BasePreprocessor, ECG_Mixin, Text_Mixin):
    """Preprocess paired ECG + text-report samples.

    The tokenizer comes from ``BasePreprocessor``. ``preprocess_single_ecg``
    and ``construct_caption`` are not defined in this module; presumably they
    are supplied by ``ECG_Mixin`` and ``Text_Mixin`` respectively.
    """

    def __init__(self, model_name: str = "dmis-lab/biobert-v1.1") -> None:
        """Set up the shared tokenizer for ``model_name``."""
        super().__init__(model_name=model_name)

    def preprocess_ecg_text(self, ecg_path: str, text: str) -> tuple[torch.Tensor, BatchEncoding]:
        """Load one ECG recording and tokenize its accompanying report.

        Intended for use inside a dataloader. The ECG modality key
        (``self.ECG_KEY``) is forwarded to ``construct_caption``.

        Args:
            ecg_path: Path to the ECG ``.npy`` file.
            text: Text report for the recording.

        Returns:
            A ``(ecg, text_inputs)`` pair: the preprocessed ECG tensor and
            the tokenizer output for the report.
        """
        # Tuple items are evaluated left to right, so the ECG is still
        # loaded before the caption is tokenized.
        return (
            # shape noted by the original author as (C, L) — TODO confirm
            self.preprocess_single_ecg(ecg_path),
            self.construct_caption(
                caption=text,
                tokenizer=self.tokenizer,
                modality=self.ECG_KEY,
            ),
        )


class CXRText_Preprocessor(BasePreprocessor, CXR_Mixin, Text_Mixin):
    """Preprocess paired chest X-ray + text-report samples.

    The tokenizer comes from ``BasePreprocessor``. ``preprocess_single_cxr``
    and ``construct_caption`` are not defined in this module; presumably they
    are supplied by ``CXR_Mixin`` and ``Text_Mixin`` respectively.
    """

    def __init__(self, model_name: str = "dmis-lab/biobert-v1.1") -> None:
        """Set up the shared tokenizer for ``model_name``."""
        super().__init__(model_name=model_name)

    def preprocess_cxr_text(self, cxr_path: str, text: str) -> tuple[torch.Tensor, BatchEncoding]:
        """Load one chest X-ray image and tokenize its accompanying report.

        Intended for use inside a dataloader. ``self.VISION_KEY`` is
        forwarded to ``construct_caption`` as the modality.

        Args:
            cxr_path: Path to the CXR image file.
            text: Text report for the image.

        Returns:
            A ``(cxr, text_inputs)`` pair: the preprocessed CXR tensor and
            the tokenizer output for the report.
        """
        # NOTE(review): unlike the ECHO/ECG preprocessors, this passes the
        # generic VISION_KEY rather than a CXR-specific key — confirm intended.
        # Tuple items are evaluated left to right, so the image is still
        # loaded before the caption is tokenized.
        return (
            # shape noted by the original author as (C, H, W) — TODO confirm
            self.preprocess_single_cxr(cxr_path),
            self.construct_caption(
                caption=text,
                tokenizer=self.tokenizer,
                modality=self.VISION_KEY,
            ),
        )