import hmac

import pinecone
import requests
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModel

from config import config
def search(text: str, k: int = 5, *, project_id: str = "5b18b87", timeout: float = 30.0):
    """Get the k closest articles to the text.

    The text is embedded with ``_get_embeddings`` and sent to the Pinecone
    REST ``/query`` endpoint of ``config.pinecone_index``.

    Args:
        text: Free-text query to embed.
        k: Number of nearest neighbours to request; must be >= 1.
        project_id: Pinecone project id embedded in the index host name
            (previously hard-coded; the default keeps the old host).
        timeout: Seconds ``requests`` waits for the HTTP exchange. Without a
            timeout a stalled connection would block the app indefinitely.

    Returns:
        The decoded JSON body of the Pinecone response.

    Raises:
        ValueError: If ``k`` is less than 1.
        RuntimeError: If Pinecone answers with a status other than 200.
        requests.RequestException: On connection failures or timeouts.
    """
    if k < 1:
        # Reject before running the embedding model; Pinecone would refuse it anyway.
        raise ValueError(f"k must be a positive integer, got {k}")

    embeds = _get_embeddings(text)
    url = (
        f"https://{config.pinecone_index}-{project_id}.svc."
        f"{config.pinecone_env}.pinecone.io/query"
    )
    r = requests.post(
        url,
        headers={
            "Api-Key": config.pinecone_api_key,
            "accept": "application/json",
            "content-type": "application/json",
        },
        json={
            "vector": embeds,
            # camelCase to match the other fields and Pinecone's documented schema.
            "topK": k,
            "includeMetadata": True,
            "includeValues": False,
        },
        timeout=timeout,
    )
    if r.status_code == 200:
        return r.json()
    raise RuntimeError(f"Error: {r.status_code} - {r.text}")
def _get_embeddings(text: str):
    """Embed ``text`` with the session's tokenizer and model by mean pooling.

    Uses ``st.session_state.tokenizer`` and ``st.session_state.model``, which
    must already be loaded.

    The model's first output is averaged over dim 1, presumably the token
    axis of a ``(batch, seq_len, hidden)`` tensor (TODO confirm for the
    configured model). Only positions the tokenizer's ``attention_mask``
    marks as real tokens are counted. For a single string nothing is padded,
    so this equals a plain mean. If a list of strings is passed, pad tokens
    no longer dilute the shorter inputs.

    Returns:
        ``tolist()`` of the pooled tensor after ``squeeze()``. For a single
        input string this is a flat list of floats.
    """
    encoded = st.session_state.tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    # Inference only: no autograd graph needed.
    with torch.no_grad():
        token_states = st.session_state.model(**encoded)[0]

    mask = encoded.get("attention_mask")
    if mask is None:
        pooled = token_states.mean(dim=1)
    else:
        # Broadcast the per-token mask across the hidden dimension.
        weights = mask.unsqueeze(-1).to(token_states.dtype)
        # Clamp the token count to at least 1 so an all-padding row cannot divide by zero.
        pooled = (token_states * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1.0)
    return pooled.squeeze().tolist()
password = st.text_input("Password", type="password")
_expected_password = config.password
# Constant-time comparison so response timing does not reveal partial
# matches. Fail closed: an empty input never authenticates, even if
# config.password were left blank or unset.
authenticated = (
    bool(password)
    and isinstance(_expected_password, str)
    and bool(_expected_password)
    and hmac.compare_digest(password.encode("utf-8"), _expected_password.encode("utf-8"))
)

if authenticated:
    st.title("PubMed Embeddings")
    st.subheader("Search for a PubMed article and get its id.")
    text = st.text_input("Search for a PubMed article", "Epidemiology of COVID-19")
    with st.spinner("Loading Embedding Model..."):
        # Streamlit reruns this whole script on every interaction, so set up
        # the Pinecone client once per session rather than on each rerun.
        if "index" not in st.session_state:
            # pinecone.init() names this parameter `environment`; an `env=`
            # keyword is not one of its parameters.
            pinecone.init(api_key=config.pinecone_api_key, environment=config.pinecone_env)
            st.session_state.index = pinecone.Index(config.pinecone_index)
        if "tokenizer" not in st.session_state:
            st.session_state.tokenizer = AutoTokenizer.from_pretrained(config.model_name)
        if "model" not in st.session_state:
            st.session_state.model = AutoModel.from_pretrained(config.model_name)
    if st.button("Search"):
        if not text.strip():
            st.warning("Please enter a search query.")
        else:
            results = None
            with st.spinner("Searching..."):
                try:
                    results = search(text)
                except Exception as err:  # UI boundary: show the failure instead of crashing the page
                    st.error(f"Search failed: {err}")
            if results is not None:
                matches = results.get("matches", [])
                if not matches:
                    st.write("No matching articles found.")
                for res in matches:
                    st.write(f"{res['id']} - confidence: {res['score']:.2f}")
elif password:
    # Only complain once the user has actually typed something.
    st.write("Password incorrect!")