import os
import json
import pickle

import numpy as np
import pandas as pd
import faiss
import google.generativeai as genai
from langchain_core.runnables import RunnableLambda
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser, JsonOutputParser
from sentence_transformers import SentenceTransformer

SAVE_DIR = "vector_store"
os.makedirs(SAVE_DIR, exist_ok=True)

# Source pages on data.gov.in, one per dataset.
dataset_links = [
    "https://www.data.gov.in/resource/sub-divisional-monthly-rainfall-1901-2017",
    "https://www.data.gov.in/resource/area-weighted-monthly-seasonal-and-annual-rainfall-mm-36-meteorological-subdivisions-1901",
    "https://www.data.gov.in/resource/all-india-area-weighted-monthly-seasonal-and-annual-rainfall-mm-1901-2015",
    "https://www.data.gov.in/resource/one-district-one-product-list-description",
    "https://www.data.gov.in/resource/agriculture-production-different-foodgrains-year-2003-2014-all-india-level",
    "https://www.data.gov.in/resource/monthly-production-central-statistics-food-and-beverages-year-1997-2011",
    "https://www.data.gov.in/resource/annual-wholesale-price-index-agriculture-produce",
]

# Human-readable dataset titles (same order as dataset_links).
dataset_names = [
    "Sub Divisional Monthly Rainfall from 1901 to 2017",
    "Area weighted monthly, seasonal and annual rainfall (in mm) for 36 meteorological subdivisions from 1901-2015",
    "All India area weighted monthly, seasonal and annual rainfall (in mm) from 1901-2015",
    "One District One Product List with Description",
    "Agriculture production of different foodgrains from year 2003 to 2014 at all India level",
    "Monthly production central statistics of food and beverages from year 1997-2011",
    "Annual Wholesale Price Index of Agriculture Produce",
]

# Local CSV filenames under datasets/ (same order as above).
datasets_list = [
    "Sub_Division_IMD_2017.csv",
    "rainfall_area-wt_sd_1901-2015.csv",
    "rainfall_area-wt_India_1901-2015.csv",
    "20250707_ODOP_Products_V31.csv",
    "Production-Department_of_Agriculture_and_Cooperation_1.csv",
    "Production-Central_Statistics_Office.csv",
    "Agri008_1.csv",
]

embeddings = []
metadata_texts = []

# Read the Gemini API key from the environment instead of hardcoding it in source.
genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
gemini_model = genai.GenerativeModel("gemini-2.5-flash")

# Wrap Gemini in a LangChain Runnable so it can be composed with | into a chain.
llm_model = RunnableLambda(
    lambda x: gemini_model.generate_content(x if isinstance(x, str) else str(x)).text
)

metadata_prompt = PromptTemplate(
    input_variables=["table_name", "table_preview"],
    template="""
You are a data analysis assistant. You are given a preview of a dataset in JSON format.

### TABLE NAME:
{table_name}

### TABLE PREVIEW (JSON):
{table_preview}

Your job is to analyze the data and produce **clean, structured metadata** that can later be used for automatic SQL generation and dataset selection.

Follow these rules:
- Always extract the exact column names and preserve their case/spelling.
- Infer a short description and data type for each column (numeric, text, date, etc.).
- Include 2-3 distinct example values for every column.
- Do NOT invent new columns or values.
- Be concise and factual.

Return output in **VALID JSON (no markdown, no explanations)** with these exact keys:
{{
  "dataset_summary": "2-3 sentences describing what the dataset represents.",
  "canonical_schema": [
    {{
      "column_name": "",
      "data_type": "",
      "description": "",
      "example_values": ["", "", ""]
    }}
  ],
  "potential_use_cases": [
    "Short description of a possible analysis or application (2-3 total)"
  ],
  "inferred_domain": "",
  "keywords": ["", "", ""]
}}
""",
)

str_parser = StrOutputParser()
json_parser = JsonOutputParser()

# Prompt -> Gemini -> raw string output.
chain = metadata_prompt | llm_model | str_parser

# Local copy of the sentence-transformers embedding model.
embedding_model = SentenceTransformer('./all-MiniLM-L6-v2')

for name, dataset, link in zip(dataset_names, datasets_list, dataset_links):
    df = pd.read_csv(f"datasets/{dataset}")

    # Preview: the first 6 values of every column, serialized as JSON for the prompt.
    data_dict = {col: df[col].head(6).tolist() for col in df.columns}
    json_data = json.dumps(data_dict, indent=2)

    # Ask Gemini for structured metadata, then parse the response into a dict.
    response = chain.invoke({"table_name": name, "table_preview": json_data})
    json_response = json_parser.invoke(response)
    json_response["table_name"] = name
    json_response["table_source"] = link
    metadata_texts.append(json_response)

    # Embed the raw metadata text; take the first row to get a 1D vector.
    embedding = embedding_model.encode([response])[0]
    embeddings.append(embedding)

embeddings = np.array(embeddings).astype('float32')
dimension = embeddings.shape[1]

# Build a flat L2 FAISS index over the metadata embeddings.
index = faiss.IndexFlatL2(dimension)
index.add(embeddings)

# Save FAISS index
faiss.write_index(index, os.path.join(SAVE_DIR, "metadata_index.faiss"))

# Save dataset names + raw metadata text (for lookup)
with open(os.path.join(SAVE_DIR, "metadata_info.pkl"), "wb") as f:
    pickle.dump({
        "dataset_names": dataset_names,
        "datasets_list": datasets_list,
        "source_list": dataset_links,
        "metadata_texts": metadata_texts,
    }, f)

print("✅ Vector database created and saved successfully.")

# Quick sanity check: retrieve the closest dataset for a sample query.
query = "livestock population trends in agriculture"
query_embedding = embedding_model.encode([query]).astype('float32')
D, I = index.search(query_embedding, k=1)

# Map index back to table name
best_match_index = I[0][0]
print(f"🔍 Closest table: {datasets_list[best_match_index]}")
print(f"📏 Distance: {D[0][0]}")
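
# --- Sketch (not part of the original pipeline): reloading the saved store ---
# A downstream consumer would presumably read the FAISS index and metadata pickle
# back from disk rather than reuse the in-memory objects above. The helper below is
# a minimal illustration of that, assuming the same file names and embedding model;
# the function name `search_vector_store` is made up for this example.
def search_vector_store(user_query: str, k: int = 3, save_dir: str = SAVE_DIR):
    """Load the persisted index + metadata and return the k closest datasets."""
    idx = faiss.read_index(os.path.join(save_dir, "metadata_index.faiss"))
    with open(os.path.join(save_dir, "metadata_info.pkl"), "rb") as fh:
        info = pickle.load(fh)

    # Encode the query with the same model used at build time.
    q_vec = embedding_model.encode([user_query]).astype("float32")
    distances, indices = idx.search(q_vec, k)

    results = []
    for dist, i in zip(distances[0], indices[0]):
        results.append({
            "table_name": info["dataset_names"][i],
            "csv_file": info["datasets_list"][i],
            "source": info["source_list"][i],
            "metadata": info["metadata_texts"][i],
            "distance": float(dist),
        })
    return results

# Example usage:
# matches = search_vector_store("annual rainfall by meteorological subdivision", k=2)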