small fix with torch.finfo
Browse files- modeling_lsg_bart.py +4 -2
modeling_lsg_bart.py
CHANGED
|
@@ -435,7 +435,8 @@ class LSGBartEncoderAttention(BaseSelfAttention):
|
|
| 435 |
keys = keys.sum(dim=-2) / (mask + 1e-6)
|
| 436 |
values = values.sum(dim=-2) / (mask + 1e-6)
|
| 437 |
|
| 438 |
-
mask = (1. - mask.clamp(0, 1))
|
|
|
|
| 439 |
return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)
|
| 440 |
|
| 441 |
def get_sparse_tokens_with_stride(self, keys, values, mask):
|
|
@@ -500,7 +501,8 @@ class LSGBartEncoderAttention(BaseSelfAttention):
|
|
| 500 |
keys /= mask + 1e-8
|
| 501 |
values /= mask + 1e-8
|
| 502 |
|
| 503 |
-
mask = (1. - mask.clamp(0, 1))
|
|
|
|
| 504 |
return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.transpose(-1, -2).reshape(n, h, 1, -1)
|
| 505 |
|
| 506 |
def lsh_round(self, keys, values, mask, output_size):
|
|
|
|
| 435 |
keys = keys.sum(dim=-2) / (mask + 1e-6)
|
| 436 |
values = values.sum(dim=-2) / (mask + 1e-6)
|
| 437 |
|
| 438 |
+
mask = (1. - mask.clamp(0, 1))
|
| 439 |
+
mask *= torch.finfo(mask.dtype).min
|
| 440 |
return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)
|
| 441 |
|
| 442 |
def get_sparse_tokens_with_stride(self, keys, values, mask):
|
|
|
|
| 501 |
keys /= mask + 1e-8
|
| 502 |
values /= mask + 1e-8
|
| 503 |
|
| 504 |
+
mask = (1. - mask.clamp(0, 1))
|
| 505 |
+
mask *= torch.finfo(mask.dtype).min
|
| 506 |
return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.transpose(-1, -2).reshape(n, h, 1, -1)
|
| 507 |
|
| 508 |
def lsh_round(self, keys, values, mask, output_size):
|