# Samarth/generate_metadata.py
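# Builds a FAISS vector store of LLM-generated metadata for a fixed set of
# data.gov.in agriculture/rainfall datasets, for later dataset selection
# and automatic SQL generation.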
import google.generativeai as genai
from langchain_core.runnables import RunnableLambda
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser, JsonOutputParser
import pandas as pd
import json
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import os
import pickle
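# Output directory for the FAISS index and the metadata lookup pickle.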
SAVE_DIR = "vector_store"
os.makedirs(SAVE_DIR, exist_ok=True)
dataset_links = [
"https://www.data.gov.in/resource/sub-divisional-monthly-rainfall-1901-2017",
"https://www.data.gov.in/resource/area-weighted-monthly-seasonal-and-annual-rainfall-mm-36-meteorological-subdivisions-1901",
"https://www.data.gov.in/resource/all-india-area-weighted-monthly-seasonal-and-annual-rainfall-mm-1901-2015",
"https://www.data.gov.in/resource/one-district-one-product-list-description",
"https://www.data.gov.in/resource/agriculture-production-different-foodgrains-year-2003-2014-all-india-level",
"https://www.data.gov.in/resource/monthly-production-central-statistics-food-and-beverages-year-1997-2011",
"https://www.data.gov.in/resource/annual-wholesale-price-index-agriculture-produce"
]
dataset_names = [
"Sub Divisional Monthly Rainfall from 1901 to 2017",
"Area weighted monthly, seasonal and annual rainfall ( in mm) for 36 meteorological subdivisions from 1901-2015",
"All India area weighted monthly, seasonal and annual rainfall (in mm) from 1901-2015",
"One District One Product List with Description",
"Agriculture production of different foodgrains from year 2003 to 2014 at all India level",
"Monthly production central statistics of food and beverages from year 1997-2011",
"Annual Wholesale Price Index of Agriculture Produce"
]
datasets_list = [
"Sub_Division_IMD_2017.csv",
"rainfall_area-wt_sd_1901-2015.csv",
"rainfall_area-wt_India_1901-2015.csv",
"20250707_ODOP_Products_V31.csv",
"Production-Department_of_Agriculture_and_Cooperation_1.csv",
"Production-Central_Statistics_Office.csv",
"Agri008_1.csv"
]
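# The three lists above are index-aligned: position i in dataset_names,
# datasets_list, and dataset_links all describe the same dataset, and the
# zip() loop below relies on that ordering.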
embeddings = []
metadata_texts = []
# Read the Gemini API key from the environment rather than hardcoding it.
genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
gemini_model = genai.GenerativeModel("gemini-2.5-flash")
# Wrap Gemini in a LangChain Runnable. The prompt template emits a
# StringPromptValue, so convert it to plain text before calling Gemini
# (str() on a PromptValue includes the field wrapper, not just the prompt).
llm_model = RunnableLambda(
    lambda x: gemini_model.generate_content(
        x.to_string() if hasattr(x, "to_string") else str(x)
    ).text
)
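# Sanity check (hypothetical standalone usage): the wrapper accepts a plain
# string before being wired into the chain, e.g.
#   print(llm_model.invoke("Reply with the single word: ok"))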
metadata_prompt = PromptTemplate(
input_variables=["table_name", "table_preview"],
template="""
You are a data analysis assistant.
You are given a preview of a dataset in JSON format.
### TABLE NAME:
{table_name}
### TABLE PREVIEW (JSON):
{table_preview}
Your job is to analyze the data and produce **clean, structured metadata** that can later be used
for automatic SQL generation and dataset selection.
Follow these rules:
- Always extract the exact column names and preserve their case/spelling.
- Infer a short description and data type for each column (numeric, text, date, etc.).
- Include 2-3 distinct example values for every column.
- Do NOT invent new columns or values.
- Be concise and factual.
Return output in **VALID JSON (no markdown, no explanations)** with these exact keys:
{{
"dataset_summary": "2-3 sentences describing what the dataset represents.",
"canonical_schema": [
{{
"column_name": "<exact column name>",
"data_type": "<inferred type>",
"description": "<short description>",
"example_values": ["<ex1>", "<ex2>", "<ex3>"]
}}
],
"potential_use_cases": [
"Short description of a possible analysis or application (2-3 total)"
],
"inferred_domain": "<e.g., Agriculture, Health, Weather, etc.>",
"keywords": ["<keyword1>", "<keyword2>", "<keyword3>"]
}}
"""
)
str_parser = StrOutputParser()
json_parser = JsonOutputParser()
chain = metadata_prompt | llm_model | str_parser
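# The chain pipes the filled-in prompt through Gemini and returns the raw
# response string; JSON parsing is applied separately in the loop below so
# the raw text is still available for embedding.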
embedding_model = SentenceTransformer('./all-MiniLM-L6-v2')
for name, dataset, link in zip(dataset_names, datasets_list, dataset_links):
    df = pd.read_csv(f"datasets/{dataset}")
    # Build a small JSON preview (first 6 values per column) for the prompt.
    data_dict = {col: df[col].head(6).tolist() for col in df.columns}
    json_data = json.dumps(data_dict, indent=2)
    # Generate structured metadata for this dataset via Gemini.
    response = chain.invoke({"table_name": name, "table_preview": json_data})
    json_response = json_parser.invoke(response)
    json_response["table_name"] = name
    json_response["table_source"] = link
    metadata_texts.append(json_response)
    # Embed the raw metadata string as a 1D vector for similarity search.
    embedding = embedding_model.encode([response])[0]
    embeddings.append(embedding)
embeddings = np.array(embeddings).astype('float32')
dimension = embeddings.shape[1]
index = faiss.IndexFlatL2(dimension)
index.add(embeddings)
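# IndexFlatL2 performs exact (brute-force) L2 search, which is fine for a
# handful of dataset-level vectors like this.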
# Save FAISS index
faiss.write_index(index, os.path.join(SAVE_DIR, "metadata_index.faiss"))
# Save dataset names + raw metadata text (for lookup)
with open(os.path.join(SAVE_DIR, "metadata_info.pkl"), "wb") as f:
    pickle.dump({
        "dataset_names": dataset_names,
        "datasets_list": datasets_list,
        "source_list": dataset_links,
        "metadata_texts": metadata_texts
    }, f)
print("βœ… Vector database created and saved successfully.")
query = "livestock population trends in agriculture"
query_embedding = embedding_model.encode([query]).astype('float32')
D, I = index.search(query_embedding, k=1)
# Map the FAISS result index back to the corresponding dataset file
best_match_index = I[0][0]
print(f"πŸ” Closest table: {datasets_list[best_match_index]}")
print(f"πŸ“ Distance: {D[0][0]}")