from typing import List, Optional

from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
from langchain_core.documents.base import Document
from langchain_core.retrievers import BaseRetriever
from langchain_core.vectorstores import VectorStore
class ClimateQARetriever(BaseRetriever):
    """Retriever that runs a scored similarity search and annotates the hits.

    Results scoring below ``threshold`` are dropped. Each remaining document
    gets its score, its text, and a few fixed fields written into
    ``metadata``.
    """

    # Vector store queried by ``_get_relevant_documents``.
    vectorstore: VectorStore
    # Only type-checked in ``_get_relevant_documents``; it is NOT applied to
    # the search. NOTE(review): confirm whether a source filter was intended.
    sources: list = []
    # Not read by this class. NOTE(review): presumably kept for interface
    # compatibility with other retrievers; confirm.
    reports: list = []
    # Results with a score strictly below this value are discarded.
    threshold: float = 0.01
    # Not read by this class (kept for interface compatibility).
    k_summary: int = 3
    # ``k`` passed to the vector store's similarity search.
    k_total: int = 7
    # Not read by this class (kept for interface compatibility).
    min_size: int = 200
    # Metadata filter forwarded verbatim to the vector store; ``None`` means
    # no filter. ``Optional`` so that an explicit ``filter=None`` validates.
    filter: Optional[dict] = None

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Return the documents for ``query`` whose score passes ``threshold``.

        Args:
            query: Text to search for.
            run_manager: Callback manager supplied by ``BaseRetriever``; unused.

        Returns:
            The documents returned by the vector store (requested with
            ``k=self.k_total`` and ``filter=self.filter``), in store order,
            minus those scoring below ``threshold``. Each kept document has
            ``similarity_score``, ``content``, ``chunk_type`` and
            ``page_number`` set in its ``metadata`` (mutated in place).

        Raises:
            TypeError: If ``sources`` is not a list.
        """
        # Explicit check rather than ``assert`` so it still runs under ``python -O``.
        if not isinstance(self.sources, list):
            raise TypeError(
                f"sources must be a list, got {type(self.sources).__name__}"
            )

        docs_and_scores = self.vectorstore.similarity_search_with_score(
            query=query, k=self.k_total, filter=self.filter
        )

        results = []
        for doc, score in docs_and_scores:
            # Skip results below the threshold. NOTE(review): this assumes a
            # higher score means more similar; for distance-based stores the
            # comparison would be inverted. Confirm for the store in use.
            if score < self.threshold:
                continue
            doc.metadata["similarity_score"] = score
            # Copy the text into metadata, presumably for consumers that read
            # only metadata. Confirm against callers.
            doc.metadata["content"] = doc.page_content
            doc.metadata["chunk_type"] = "text"
            # NOTE(review): hard-coded placeholder; this overwrites any page
            # number the store may have provided.
            doc.metadata["page_number"] = 1
            results.append(doc)
        return results