import html
import json

import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.colors as clrs
import requests
import pandas as pd
import torch
import spaces


def _default_device():
    """Return "cuda" when torch can see a CUDA device, otherwise "cpu"."""
    return "cuda" if torch.cuda.is_available() else "cpu"


# Function to get tokens given text
@spaces.GPU
def get_tokens(tokenizer, text):
    """Tokenize *text* without adding special tokens.

    Args:
        tokenizer: A tokenizer exposing ``encode`` and ``convert_ids_to_tokens``.
        text: The string to tokenize.

    Returns:
        A ``(tokens, token_ids)`` pair. ``token_ids`` is the tensor returned by
        ``tokenizer.encode(..., return_tensors="pt")``, moved to the default
        device. ``tokens`` holds the token strings for its first row
        (``token_ids[0]``).
    """
    token_ids = tokenizer.encode(
        text, return_tensors="pt", add_special_tokens=False
    ).to(_default_device())
    tokens = tokenizer.convert_ids_to_tokens(token_ids[0])
    return tokens, token_ids


# Function to apply chat template to prompt
@spaces.GPU
def decorate_prompt(tokenizer, prompt):
    """Render *prompt* through the tokenizer's chat template and return its ids.

    The chat is a user turn holding *prompt* followed by an empty assistant
    turn. It is rendered to text with ``apply_chat_template(tokenize=False)``
    and then encoded with ``add_special_tokens=False``. That flag is presumably
    set because the template text already carries its own markers (confirm
    for the tokenizer in use).

    Returns:
        The id tensor from ``tokenizer.encode(..., return_tensors="pt")``,
        moved to the default device.
    """
    chat = [
        {"role": "user", "content": prompt},
        {"role": "assistant", "content": ""},
    ]
    text = tokenizer.apply_chat_template(chat, tokenize=False)
    return tokenizer.encode(
        text, return_tensors="pt", add_special_tokens=False
    ).to(_default_device())


# Function to get response to prompt
def get_response(model_pipe, prompt):
    """Call *model_pipe* on *prompt* and return the first result's ``generated_text``."""
    return model_pipe(prompt)[0]["generated_text"]


# Function to highlight tokens based on given values
def plot_tokens_with_highlights(tokens, values, concept, cmap_name='Oranges', vmin=None, vmax=None):
    """Build an HTML snippet that shades each token by its score.

    Args:
        tokens: Token strings, one per score.
        values: A torch tensor with one score per token. It may live on any
            device and have any floating dtype.
        concept: The concept name shown in the header line.
        cmap_name: Name of a Matplotlib colormap.
        vmin, vmax: Colour-scale bounds. Each defaults to the min/max of
            *values* (0.0/1.0 when *values* is empty).

    Returns:
        An HTML string. Each token is wrapped in a ``<span>`` whose
        background colour comes from the colormap.

    Raises:
        ValueError: if ``len(tokens) != len(values)``.
    """
    if len(tokens) != len(values):
        raise ValueError("The number of tokens and values must be the same.")

    # .numpy() fails on CUDA tensors and numpy has no bfloat16, so move to CPU
    # and cast to float32 first.
    scores = values.detach().cpu().float().numpy()

    # Use plain Python floats for the bounds. The size guard keeps an empty
    # input from crashing on min()/max() of a zero-size array.
    if vmin is None:
        vmin = float(scores.min()) if scores.size else 0.0
    if vmax is None:
        vmax = float(scores.max()) if scores.size else 1.0
    norm = clrs.Normalize(vmin=vmin, vmax=vmax)
    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap remains.
    cmap = plt.get_cmap(cmap_name)

    parts = [
        "<div>",
        f"<p>How much information about the concept '{html.escape(str(concept))}' is carried in each token:</p>",
        "<p>",
    ]
    for token, value in zip(tokens, scores):
        rgba = cmap(norm(value))
        hex_color = "#" + "".join(f"{int(channel * 255):02x}" for channel in rgba[:3])
        # Escape the token: tokenizer output such as "<s>" would otherwise be
        # parsed as markup and vanish from the rendering.
        parts.append(
            f'<span style="background-color: {hex_color};">{html.escape(str(token))}</span> '
        )
    parts.append("</p></div>")
    return "".join(parts)


# Function to get concepts dictionary
def get_concepts_dictionary(dictionary_url, *, timeout=60):
    """Download a JSON-lines file and map each ``concept_id`` to its concept name.

    Each non-empty line is parsed as a JSON object. Entries without a
    ``concept_id``, or with a missing or empty ``concept``, are skipped. Names
    are passed through ``str.capitalize``.

    Args:
        dictionary_url: URL of the JSON-lines file.
        timeout: Seconds forwarded to ``requests.get`` so a stalled server
            cannot block the caller indefinitely.

    Returns:
        A dict ``{concept_id: capitalized concept name}``.

    Raises:
        requests.HTTPError: if the server responds with an error status.
    """
    data_dict = {}
    # A streamed response holds its connection open until closed, so use
    # `with` to release it.
    with requests.get(dictionary_url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        for line in response.iter_lines(decode_unicode=True):
            if not line:
                continue
            obj = json.loads(line)
            concept_id = obj.get("concept_id")
            concept = obj.get("concept")
            # `is not None` so a legitimate id of 0 is not discarded.
            if concept_id is not None and concept:
                data_dict[concept_id] = concept.capitalize()
    return data_dict


# Function to get matching concepts
def select_concepts(all_concepts, desired_concept):
    """Find concepts whose name contains *desired_concept* (case-insensitive).

    Args:
        all_concepts: Mapping of concept id to concept name, e.g. the result
            of ``get_concepts_dictionary``.
        desired_concept: Substring to search for.

    Returns:
        A ``(ids, df)`` pair. ``ids`` is a tensor of matching ids in the
        mapping's iteration order; it is an empty ``torch.long`` tensor when
        nothing matches. ``df`` is a DataFrame with columns "Concept ID" and
        "Concept Name".
    """
    needle = desired_concept.lower()
    matches = [
        (concept_id, name)
        for concept_id, name in all_concepts.items()
        if needle in name.lower()
    ]
    concept_df = pd.DataFrame(matches, columns=["Concept ID", "Concept Name"])
    concept_ids = [concept_id for concept_id, _ in matches]
    # torch.tensor([]) would be float32; keep the id dtype integral when empty.
    ids = torch.tensor(concept_ids) if concept_ids else torch.empty(0, dtype=torch.long)
    return ids, concept_df