Lora
committed on
Commit
·
5758582
1
Parent(s):
0730e98
add negative sense words and note
Browse files
app.py
CHANGED
|
@@ -121,9 +121,10 @@ Args:
|
|
| 121 |
length: length of the input sentence, used to get the contextualization weights for the last token
|
| 122 |
token: the selected token
|
| 123 |
token_index: the index of the selected token in the input sentence
|
| 124 |
-
|
|
|
|
| 125 |
"""
|
| 126 |
-
def get_token_contextual_weights (contextualization_weights, length, token, token_index,
|
| 127 |
print(">>>>>in get_token_contextual_weights")
|
| 128 |
print(f"Selected {token_index}th token: {token}")
|
| 129 |
|
|
@@ -139,47 +140,54 @@ def get_token_contextual_weights (contextualization_weights, length, token, toke
|
|
| 139 |
senses = torch.squeeze(senses) # (nv, s=1, d)
|
| 140 |
|
| 141 |
# build dataframe
|
| 142 |
-
neg_word_lists = []
|
| 143 |
pos_dfs, neg_dfs = [], []
|
| 144 |
|
| 145 |
for i in range(num_senses):
|
| 146 |
logits = lm_head(senses[i,:]) # (vocab,) [768, 50257] -> [50257]
|
| 147 |
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
| 148 |
|
| 149 |
-
pos_sorted_words = [tokenizer.decode(sorted_indices[j]) for j in range(
|
| 150 |
-
pos_df = pd.DataFrame(pos_sorted_words)
|
| 151 |
pos_dfs.append(pos_df)
|
| 152 |
|
| 153 |
-
neg_sorted_words = [tokenizer.decode(sorted_indices[-j-1]) for j in range(
|
| 154 |
-
neg_df = pd.DataFrame(neg_sorted_words)
|
| 155 |
neg_dfs.append(neg_df)
|
| 156 |
|
| 157 |
sense0words, sense1words, sense2words, sense3words, sense4words, sense5words, \
|
| 158 |
sense6words, sense7words, sense8words, sense9words, sense10words, sense11words, \
|
| 159 |
sense12words, sense13words, sense14words, sense15words = pos_dfs
|
| 160 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
sense0slider, sense1slider, sense2slider, sense3slider, sense4slider, sense5slider, \
|
| 162 |
sense6slider, sense7slider, sense8slider, sense9slider, sense10slider, sense11slider, \
|
| 163 |
sense12slider, sense13slider, sense14slider, sense15slider = token_contextualization_weights_list
|
| 164 |
|
| 165 |
-
return token, token_index,
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
|
|
|
|
|
|
|
|
|
| 169 |
|
| 170 |
"""
|
| 171 |
Wrapper for when the user selects a new token in the tokens dataframe.
|
| 172 |
Converts `evt` (the selected token) to `token` and `token_index` which are used by get_token_contextual_weights.
|
| 173 |
"""
|
| 174 |
-
def new_token_contextual_weights (contextualization_weights, length, evt: gr.SelectData,
|
| 175 |
print(">>>>>in new_token_contextual_weights")
|
| 176 |
token_index = evt.index[1] # selected token is the token_index-th token in the sentence
|
| 177 |
token = evt.value
|
| 178 |
if not token:
|
| 179 |
-
return None, None,
|
| 180 |
-
None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, \
|
| 181 |
-
|
| 182 |
-
|
|
|
|
| 183 |
|
| 184 |
def change_sense0_weight(contextualization_weights, length, token_index, new_weight):
|
| 185 |
contextualization_weights[0, 0, length-1, token_index] = new_weight
|
|
@@ -273,7 +281,7 @@ with gr.Blocks( theme = gr.themes.Base(),
|
|
| 273 |
with gr.Column(scale=1):
|
| 274 |
selected_token = gr.Textbox(label="Current Selected Token", interactive=False)
|
| 275 |
with gr.Column(scale=8):
|
| 276 |
-
gr.Markdown("""
|
| 277 |
Once a token is chosen, you can **use the sliders below to change the weights of any senses** for that token, \
|
| 278 |
and then click "Predict next word" to see updated next-word predictions. \
|
| 279 |
You can change the weights of *multiple senses of multiple tokens;* \
|
|
@@ -314,6 +322,23 @@ with gr.Blocks( theme = gr.themes.Base(),
|
|
| 314 |
sense6words = gr.DataFrame(headers = ["Sense 6"])
|
| 315 |
with gr.Column(scale=0, min_width=120):
|
| 316 |
sense7words = gr.DataFrame(headers = ["Sense 7"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 317 |
with gr.Row():
|
| 318 |
with gr.Column(scale=0, min_width=120):
|
| 319 |
sense8slider= gr.Slider(minimum=0, maximum=1, value=0, step=0.01, label="Sense 8", elem_id="sense8slider", interactive=True)
|
|
@@ -348,7 +373,26 @@ with gr.Blocks( theme = gr.themes.Base(),
|
|
| 348 |
sense14words = gr.DataFrame(headers = ["Sense 14"])
|
| 349 |
with gr.Column(scale=0, min_width=120):
|
| 350 |
sense15words = gr.DataFrame(headers = ["Sense 15"])
|
| 351 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 352 |
# gr.Examples(
|
| 353 |
# examples=[["Messi plays for", top_k, None]],
|
| 354 |
# inputs=[input_sentence, top_k, contextualization_weights],
|
|
@@ -405,6 +449,7 @@ with gr.Blocks( theme = gr.themes.Base(),
|
|
| 405 |
inputs=[contextualization_weights, length, token_index, sense15slider],
|
| 406 |
outputs=[contextualization_weights])
|
| 407 |
|
|
|
|
| 408 |
predict.click(
|
| 409 |
fn=predict_next_word,
|
| 410 |
inputs = [input_sentence, top_k, contextualization_weights],
|
|
@@ -418,6 +463,9 @@ with gr.Blocks( theme = gr.themes.Base(),
|
|
| 418 |
sense0words, sense1words, sense2words, sense3words, sense4words, sense5words, sense6words, sense7words,
|
| 419 |
sense8words, sense9words, sense10words, sense11words, sense12words, sense13words, sense14words, sense15words,
|
| 420 |
|
|
|
|
|
|
|
|
|
|
| 421 |
sense0slider, sense1slider, sense2slider, sense3slider, sense4slider, sense5slider, sense6slider, sense7slider,
|
| 422 |
sense8slider, sense9slider, sense10slider, sense11slider, sense12slider, sense13slider, sense14slider, sense15slider]
|
| 423 |
)
|
|
@@ -438,6 +486,9 @@ with gr.Blocks( theme = gr.themes.Base(),
|
|
| 438 |
sense0words, sense1words, sense2words, sense3words, sense4words, sense5words, sense6words, sense7words,
|
| 439 |
sense8words, sense9words, sense10words, sense11words, sense12words, sense13words, sense14words, sense15words,
|
| 440 |
|
|
|
|
|
|
|
|
|
|
| 441 |
sense0slider, sense1slider, sense2slider, sense3slider, sense4slider, sense5slider, sense6slider, sense7slider,
|
| 442 |
sense8slider, sense9slider, sense10slider, sense11slider, sense12slider, sense13slider, sense14slider, sense15slider]
|
| 443 |
)
|
|
|
|
| 121 |
length: length of the input sentence, used to get the contextualization weights for the last token
|
| 122 |
token: the selected token
|
| 123 |
token_index: the index of the selected token in the input sentence
|
| 124 |
+
pos_count: how many top positive words to display for each sense
|
| 125 |
+
neg_count: how many top negative words to display for each sense
|
| 126 |
"""
|
| 127 |
+
def get_token_contextual_weights (contextualization_weights, length, token, token_index, pos_count = 5, neg_count = 3):
|
| 128 |
print(">>>>>in get_token_contextual_weights")
|
| 129 |
print(f"Selected {token_index}th token: {token}")
|
| 130 |
|
|
|
|
| 140 |
senses = torch.squeeze(senses) # (nv, s=1, d)
|
| 141 |
|
| 142 |
# build dataframe
|
|
|
|
| 143 |
pos_dfs, neg_dfs = [], []
|
| 144 |
|
| 145 |
for i in range(num_senses):
|
| 146 |
logits = lm_head(senses[i,:]) # (vocab,) [768, 50257] -> [50257]
|
| 147 |
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
| 148 |
|
| 149 |
+
pos_sorted_words = [tokenizer.decode(sorted_indices[j]) for j in range(pos_count)]
|
| 150 |
+
pos_df = pd.DataFrame(pos_sorted_words, columns=["Sense {}".format(i)])
|
| 151 |
pos_dfs.append(pos_df)
|
| 152 |
|
| 153 |
+
neg_sorted_words = [tokenizer.decode(sorted_indices[-j-1]) for j in range(neg_count)]
|
| 154 |
+
neg_df = pd.DataFrame(neg_sorted_words, columns=["Top Negative"])
|
| 155 |
neg_dfs.append(neg_df)
|
| 156 |
|
| 157 |
sense0words, sense1words, sense2words, sense3words, sense4words, sense5words, \
|
| 158 |
sense6words, sense7words, sense8words, sense9words, sense10words, sense11words, \
|
| 159 |
sense12words, sense13words, sense14words, sense15words = pos_dfs
|
| 160 |
|
| 161 |
+
sense0negwords, sense1negwords, sense2negwords, sense3negwords, sense4negwords, sense5negwords, \
|
| 162 |
+
sense6negwords, sense7negwords, sense8negwords, sense9negwords, sense10negwords, sense11negwords, \
|
| 163 |
+
sense12negwords, sense13negwords, sense14negwords, sense15negwords = neg_dfs
|
| 164 |
+
|
| 165 |
sense0slider, sense1slider, sense2slider, sense3slider, sense4slider, sense5slider, \
|
| 166 |
sense6slider, sense7slider, sense8slider, sense9slider, sense10slider, sense11slider, \
|
| 167 |
sense12slider, sense13slider, sense14slider, sense15slider = token_contextualization_weights_list
|
| 168 |
|
| 169 |
+
return token, token_index, \
|
| 170 |
+
sense0words, sense1words, sense2words, sense3words, sense4words, sense5words, sense6words, sense7words, \
|
| 171 |
+
sense8words, sense9words, sense10words, sense11words, sense12words, sense13words, sense14words, sense15words, \
|
| 172 |
+
sense0negwords, sense1negwords, sense2negwords, sense3negwords, sense4negwords, sense5negwords, sense6negwords, sense7negwords, \
|
| 173 |
+
sense8negwords, sense9negwords, sense10negwords, sense11negwords, sense12negwords, sense13negwords, sense14negwords, sense15negwords, \
|
| 174 |
+
sense0slider, sense1slider, sense2slider, sense3slider, sense4slider, sense5slider, sense6slider, sense7slider, \
|
| 175 |
+
sense8slider, sense9slider, sense10slider, sense11slider, sense12slider, sense13slider, sense14slider, sense15slider
|
| 176 |
|
| 177 |
"""
|
| 178 |
Wrapper for when the user selects a new token in the tokens dataframe.
|
| 179 |
Converts `evt` (the selected token) to `token` and `token_index` which are used by get_token_contextual_weights.
|
| 180 |
"""
|
| 181 |
+
def new_token_contextual_weights (contextualization_weights, length, evt: gr.SelectData, pos_count = 5, neg_count = 3):
|
| 182 |
print(">>>>>in new_token_contextual_weights")
|
| 183 |
token_index = evt.index[1] # selected token is the token_index-th token in the sentence
|
| 184 |
token = evt.value
|
| 185 |
if not token:
|
| 186 |
+
return None, None, \
|
| 187 |
+
None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, \
|
| 188 |
+
None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, \
|
| 189 |
+
None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
|
| 190 |
+
return get_token_contextual_weights (contextualization_weights, length, token, token_index, pos_count, neg_count)
|
| 191 |
|
| 192 |
def change_sense0_weight(contextualization_weights, length, token_index, new_weight):
|
| 193 |
contextualization_weights[0, 0, length-1, token_index] = new_weight
|
|
|
|
| 281 |
with gr.Column(scale=1):
|
| 282 |
selected_token = gr.Textbox(label="Current Selected Token", interactive=False)
|
| 283 |
with gr.Column(scale=8):
|
| 284 |
+
gr.Markdown("""####
|
| 285 |
Once a token is chosen, you can **use the sliders below to change the weights of any senses** for that token, \
|
| 286 |
and then click "Predict next word" to see updated next-word predictions. \
|
| 287 |
You can change the weights of *multiple senses of multiple tokens;* \
|
|
|
|
| 322 |
sense6words = gr.DataFrame(headers = ["Sense 6"])
|
| 323 |
with gr.Column(scale=0, min_width=120):
|
| 324 |
sense7words = gr.DataFrame(headers = ["Sense 7"])
|
| 325 |
+
with gr.Row():
|
| 326 |
+
with gr.Column(scale=0, min_width=120):
|
| 327 |
+
sense0negwords = gr.DataFrame(headers = ["Top Negative"])
|
| 328 |
+
with gr.Column(scale=0, min_width=120):
|
| 329 |
+
sense1negwords = gr.DataFrame(headers = ["Top Negative"])
|
| 330 |
+
with gr.Column(scale=0, min_width=120):
|
| 331 |
+
sense2negwords = gr.DataFrame(headers = ["Top Negative"])
|
| 332 |
+
with gr.Column(scale=0, min_width=120):
|
| 333 |
+
sense3negwords = gr.DataFrame(headers = ["Top Negative"])
|
| 334 |
+
with gr.Column(scale=0, min_width=120):
|
| 335 |
+
sense4negwords = gr.DataFrame(headers = ["Top Negative"])
|
| 336 |
+
with gr.Column(scale=0, min_width=120):
|
| 337 |
+
sense5negwords = gr.DataFrame(headers = ["Top Negative"])
|
| 338 |
+
with gr.Column(scale=0, min_width=120):
|
| 339 |
+
sense6negwords = gr.DataFrame(headers = ["Top Negative"])
|
| 340 |
+
with gr.Column(scale=0, min_width=120):
|
| 341 |
+
sense7negwords = gr.DataFrame(headers = ["Top Negative"])
|
| 342 |
with gr.Row():
|
| 343 |
with gr.Column(scale=0, min_width=120):
|
| 344 |
sense8slider= gr.Slider(minimum=0, maximum=1, value=0, step=0.01, label="Sense 8", elem_id="sense8slider", interactive=True)
|
|
|
|
| 373 |
sense14words = gr.DataFrame(headers = ["Sense 14"])
|
| 374 |
with gr.Column(scale=0, min_width=120):
|
| 375 |
sense15words = gr.DataFrame(headers = ["Sense 15"])
|
| 376 |
+
with gr.Row():
|
| 377 |
+
with gr.Column(scale=0, min_width=120):
|
| 378 |
+
sense8negwords = gr.DataFrame(headers = ["Top Negative"])
|
| 379 |
+
with gr.Column(scale=0, min_width=120):
|
| 380 |
+
sense9negwords = gr.DataFrame(headers = ["Top Negative"])
|
| 381 |
+
with gr.Column(scale=0, min_width=120):
|
| 382 |
+
sense10negwords = gr.DataFrame(headers = ["Top Negative"])
|
| 383 |
+
with gr.Column(scale=0, min_width=120):
|
| 384 |
+
sense11negwords = gr.DataFrame(headers = ["Top Negative"])
|
| 385 |
+
with gr.Column(scale=0, min_width=120):
|
| 386 |
+
sense12negwords = gr.DataFrame(headers = ["Top Negative"])
|
| 387 |
+
with gr.Column(scale=0, min_width=120):
|
| 388 |
+
sense13negwords = gr.DataFrame(headers = ["Top Negative"])
|
| 389 |
+
with gr.Column(scale=0, min_width=120):
|
| 390 |
+
sense14negwords = gr.DataFrame(headers = ["Top Negative"])
|
| 391 |
+
with gr.Column(scale=0, min_width=120):
|
| 392 |
+
sense15negwords = gr.DataFrame(headers = ["Top Negative"])
|
| 393 |
+
gr.Markdown("""Note: **"Top Negative"** shows words that have the most negative dot products with the sense vector, which can
|
| 394 |
+
exhibit more coherent meaning than those with the most positive dot products.
|
| 395 |
+
To see more representative words of each sense, scroll to the top and use the **"Individual Word Sense Look Up"** tab.""")
|
| 396 |
# gr.Examples(
|
| 397 |
# examples=[["Messi plays for", top_k, None]],
|
| 398 |
# inputs=[input_sentence, top_k, contextualization_weights],
|
|
|
|
| 449 |
inputs=[contextualization_weights, length, token_index, sense15slider],
|
| 450 |
outputs=[contextualization_weights])
|
| 451 |
|
| 452 |
+
|
| 453 |
predict.click(
|
| 454 |
fn=predict_next_word,
|
| 455 |
inputs = [input_sentence, top_k, contextualization_weights],
|
|
|
|
| 463 |
sense0words, sense1words, sense2words, sense3words, sense4words, sense5words, sense6words, sense7words,
|
| 464 |
sense8words, sense9words, sense10words, sense11words, sense12words, sense13words, sense14words, sense15words,
|
| 465 |
|
| 466 |
+
sense0negwords, sense1negwords, sense2negwords, sense3negwords, sense4negwords, sense5negwords, sense6negwords, sense7negwords,
|
| 467 |
+
sense8negwords, sense9negwords, sense10negwords, sense11negwords, sense12negwords, sense13negwords, sense14negwords, sense15negwords,
|
| 468 |
+
|
| 469 |
sense0slider, sense1slider, sense2slider, sense3slider, sense4slider, sense5slider, sense6slider, sense7slider,
|
| 470 |
sense8slider, sense9slider, sense10slider, sense11slider, sense12slider, sense13slider, sense14slider, sense15slider]
|
| 471 |
)
|
|
|
|
| 486 |
sense0words, sense1words, sense2words, sense3words, sense4words, sense5words, sense6words, sense7words,
|
| 487 |
sense8words, sense9words, sense10words, sense11words, sense12words, sense13words, sense14words, sense15words,
|
| 488 |
|
| 489 |
+
sense0negwords, sense1negwords, sense2negwords, sense3negwords, sense4negwords, sense5negwords, sense6negwords, sense7negwords,
|
| 490 |
+
sense8negwords, sense9negwords, sense10negwords, sense11negwords, sense12negwords, sense13negwords, sense14negwords, sense15negwords,
|
| 491 |
+
|
| 492 |
sense0slider, sense1slider, sense2slider, sense3slider, sense4slider, sense5slider, sense6slider, sense7slider,
|
| 493 |
sense8slider, sense9slider, sense10slider, sense11slider, sense12slider, sense13slider, sense14slider, sense15slider]
|
| 494 |
)
|