fix for transformers >= 4.35.2
- README.md +2 -2
- modeling_lsg_bart.py +3 -3
README.md CHANGED

@@ -18,7 +18,7 @@ model-index:
 <!-- This model card has been generated automatically according to the information the Trainer had access to. You
 should probably proofread and complete it, then remove this comment. -->
 
-**Transformers >= 4.
+**Transformers >= 4.35.2**\
 **This model relies on a custom modeling file, you need to add trust_remote_code=True**\
 **See [\#13467](https://github.com/huggingface/transformers/pull/13467)**
 
@@ -105,7 +105,7 @@ The following hyperparameters were used during generation:
 
 ### Framework versions
 
-- Transformers 4.
+- Transformers 4.35.2
 - Pytorch 1.12.1
 - Datasets 2.3.2
 - Tokenizers 0.11.6
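Usage note (not part of the commit): since the card requires trust_remote_code=True and now targets Transformers >= 4.35.2, loading looks roughly like the sketch below. The repo id is a placeholder, and AutoModelForSeq2SeqLM is an assumption based on the custom LSG-BART modeling file and the generation hyperparameters in the card.

```python
# Minimal loading sketch (assumes transformers >= 4.35.2).
# "your-namespace/your-lsg-bart-model" is a placeholder, not the real repo id.
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

model_id = "your-namespace/your-lsg-bart-model"  # placeholder repo id

tokenizer = AutoTokenizer.from_pretrained(model_id)
# trust_remote_code=True pulls in the checkpoint's own modeling_lsg_bart.py
model = AutoModelForSeq2SeqLM.from_pretrained(model_id, trust_remote_code=True)

text = "A long document to summarize ..."
inputs = tokenizer(text, return_tensors="pt", truncation=True)
output_ids = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```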
modeling_lsg_bart.py CHANGED

@@ -1,7 +1,7 @@
 from logging import warn
 import torch
 from transformers.models.bart.modeling_bart import *
-from transformers.
+from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
 import torch.nn as nn
 import sys
 
@@ -852,7 +852,7 @@ class LSGBartEncoder(LSGBartPretrainedModel, BartEncoder):
         # expand attention_mask
         if attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            attention_mask =
+            attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
 
         encoder_states = () if output_hidden_states else None
         all_attentions = () if output_attentions else None
@@ -1093,4 +1093,4 @@ try:
         str_to_class(value.split(".")[-1]).register_for_auto_class(key)
 except:
     warn("AutoRegister isn't available, you'll have to manually copy modeling.py after .save_pretrained(...).")
-    warn("Update to transformers >= 4.
+    warn("Update to transformers >= 4.35.2 to fix.")
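For context, the patch switches the encoder's mask expansion to the _prepare_4d_attention_mask helper from transformers.modeling_attn_mask_utils, since the internal helper previously imported from transformers no longer lives at its old path in 4.35. Below is a standalone sketch of what that call produces, assuming transformers >= 4.35.2; it is illustrative only and not part of the commit.

```python
# Standalone sketch of the mask helper the patch switches to
# (assumes transformers >= 4.35.2; not part of the commit itself).
import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask

# [bsz, seq_len] padding mask: 1 = real token, 0 = padding
attention_mask = torch.tensor([[1, 1, 1, 0]])

# Expanded to [bsz, 1, tgt_seq_len, src_seq_len]; padded positions are filled
# with the dtype's minimum value so they contribute nothing after softmax.
expanded = _prepare_4d_attention_mask(attention_mask, torch.float32)
print(expanded.shape)  # torch.Size([1, 1, 4, 4])
```

The new import also brings in _prepare_4d_causal_attention_mask, presumably used for the decoder-side mask elsewhere in the file.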