Oviya committed on
Commit
d7101fa
·
1 Parent(s): 6c90094

chatbot update

Browse files
Files changed (3) hide show
  1. chatbot.py +232 -0
  2. pytrade.py +67 -0
  3. requirements.txt +7 -0
chatbot.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import os
3
+ import re
4
+ import json
5
+ import time
6
+ from datetime import datetime
7
+ from typing import List, Dict
8
+
9
+ from flask import Flask, request, jsonify
10
+ from dotenv import load_dotenv
11
+ import requests
12
+
13
+ # ----------------------------
14
+ # Optional providers (OpenAI v1 / Cohere)
15
+ # ----------------------------
16
+ OPENAI_CLIENT = None
17
+ try:
18
+ from openai import OpenAI
19
+ OPENAI_CLIENT = "available"
20
+ except Exception:
21
+ OPENAI_CLIENT = None
22
+
23
+ try:
24
+ import cohere
25
+ except Exception:
26
+ cohere = None
27
+
28
+ load_dotenv()
29
+ app = Flask(__name__)
30
+
31
+ # ----------------------------
32
+ # Config
33
+ # ----------------------------
34
+ LLM_PROVIDER = os.getenv("LLM_PROVIDER", "openai").lower().strip()
35
+ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
36
+ COHERE_API_KEY = os.getenv("COHERE_API_KEY")
37
+ SERPAPI_API_KEY = os.getenv("SERPAPI_API_KEY")
38
+ SEARCH_TOPK = int(os.getenv("SEARCH_TOPK", "5"))
39
+ TIMEZONE = "Asia/Kolkata"
40
+
41
+ if LLM_PROVIDER == "openai" and not OPENAI_API_KEY:
42
+ print("[WARN] OPENAI_API_KEY not set; general answers will fail.")
43
+ if LLM_PROVIDER == "cohere" and not COHERE_API_KEY:
44
+ print("[WARN] COHERE_API_KEY not set; general answers will fail.")
45
+ if not SERPAPI_API_KEY:
46
+ print("[WARN] SERPAPI_API_KEY not set; 'latest' queries will not work.")
47
+
48
+ # Initialize OpenAI client (v1+)
49
+ openai_client = None
50
+ if LLM_PROVIDER == "openai" and OPENAI_CLIENT and OPENAI_API_KEY:
51
+ openai_client = OpenAI(api_key=OPENAI_API_KEY)
52
+
53
+ # ----------------------------
54
+ # Utilities
55
+ # ----------------------------
56
+
57
# Common “latest/live” triggers
LATEST_TRIGGERS = [
    r"\btoday\b", r"\bnow\b", r"\blatest\b", r"\bupdate\b", r"\brecent\b",
    r"\bbreaking\b", r"\blive\b", r"\bthis\s+hour\b", r"\bthis\s+minute\b",
    r"\bcurrent\b", r"\bas of\b", r"\btoday'?s\b", r"\bprice\s+today\b"
]
LATEST_PATTERN = re.compile("|".join(LATEST_TRIGGERS), re.IGNORECASE)

# Simple aliases for finance names/tickers (extend as needed)
ALIASES = {
    "tcs": "Tata Consultancy Services",
    "ril": "Reliance Industries",
    "infy": "Infosys",
    "hdfc bank": "HDFC Bank",
    "icici": "ICICI Bank",
}

def normalize_entities(text: str) -> str:
    """Expand known finance aliases, e.g. 'tcs' -> 'Tata Consultancy Services'.

    Matching is case-insensitive and whole-word. The alias key is passed
    through re.escape() so keys containing regex metacharacters stay
    literal, and the replacement goes through a callable so backslashes
    or group references in the expansion are never interpreted by re.sub.
    """
    result = text
    for alias, full_name in ALIASES.items():
        result = re.sub(
            rf"\b{re.escape(alias)}\b",
            lambda _m, v=full_name: v,  # bind v now; avoids late-binding pitfall
            result,
            flags=re.IGNORECASE,
        )
    return result
79
+
80
def needs_live_context(query: str) -> bool:
    """Heuristic to detect time-sensitive queries.

    Returns True when the query matches LATEST_PATTERN, contains a
    domain-specific live-data phrase, or asks for "price of <entity>"
    without a historical qualifier; False for empty/None queries.
    """
    if not query:
        return False
    q = query.lower()

    if LATEST_PATTERN.search(q):
        return True

    # Domain shortcuts
    domain_triggers = [
        "who won", "match result", "score now", "stock price", "share price",
        "usd inr rate", "exchange rate", "weather", "today's weather",
        "news on", "headline", "earnings today", "ipo today",
        "live price", "current price", "price right now"
    ]
    if any(t in q for t in domain_triggers):
        return True

    # Finance shortcut: “price of <entity>”.
    # BUG FIX: the exclusion pattern now groups its alternatives; the
    # original r"\byesterday|last close|history\b" anchored \b only to the
    # first and last alternatives, so e.g. "prehistory" would not match as
    # intended and "last close" matched inside longer words' boundaries.
    if re.search(r"\bprice of\b", q) and not re.search(
        r"\b(?:yesterday|last close|history)\b", q
    ):
        return True

    return False
104
+
105
def pick_is_news(query: str) -> bool:
    """Treat as news if clear news terms appear (case-insensitive)."""
    lowered = query.lower()
    news_terms = (
        "news", "headline", "breaking", "election", "budget",
        "earthquake", "merger", "acquisition", "ceo resigns",
    )
    for term in news_terms:
        if term in lowered:
            return True
    return False
110
+
111
def serpapi_search(query: str, is_news: bool = False, num: int = SEARCH_TOPK) -> List[Dict[str, str]]:
    """Fetch top search or news results from SerpAPI.

    Returns up to `num` dicts with keys: title, snippet, link, source.
    Returns [] when SERPAPI_API_KEY is not configured. Raises
    requests.HTTPError on a non-2xx response (callers catch broadly).
    """
    if not SERPAPI_API_KEY:
        return []

    # Both modes hit the same endpoint; only the engine differs.
    url = "https://serpapi.com/search.json"
    params = {
        "api_key": SERPAPI_API_KEY,
        "q": query,
        "engine": "google_news" if is_news else "google",
        "num": min(num, 10),
        "hl": "en",
        "gl": "in",
    }

    r = requests.get(url, params=params, timeout=20)
    r.raise_for_status()
    data = r.json()

    results: List[Dict[str, str]] = []
    if is_news:
        for item in (data.get("news_results") or [])[:num]:
            # BUG FIX: "source" may be a dict ({"name": ...}) or a plain
            # string depending on the result; the original code called
            # .get("name") on it unconditionally and raised AttributeError
            # whenever SerpAPI returned a string source.
            source = item.get("source")
            if isinstance(source, dict):
                source = source.get("name")
            results.append({
                "title": item.get("title") or "",
                "snippet": item.get("snippet") or item.get("description") or "",
                "link": item.get("link") or "",
                "source": source or ""
            })
    else:
        for item in (data.get("organic_results") or [])[:num]:
            results.append({
                "title": item.get("title") or "",
                "snippet": item.get("snippet") or "",
                "link": item.get("link") or "",
                "source": item.get("source") or ""
            })
    return results
150
+
151
def build_citation_block(hits: List[Dict[str, str]]) -> str:
    """Compact citations for the LLM and the response."""
    def _field(hit: Dict[str, str], key: str) -> str:
        # Missing/None values collapse to "" before stripping.
        return (hit.get(key) or "").strip()

    entries = [
        f"[{n}] {_field(h, 'title')} — {_field(h, 'source')}\n"
        f"{_field(h, 'snippet')}\n{_field(h, 'link')}"
        for n, h in enumerate(hits, start=1)
    ]
    return "\n\n".join(entries)
161
+
162
+ # ----------------------------
163
+ # LLM Calls
164
+ # ----------------------------
165
+
166
+ BASE_SYSTEM_PROMPT = (
167
+ "You are a helpful and precise assistant. Use simple, neutral English. "
168
+ "When sources are provided, synthesize them, highlight clear facts, and include a short 'Sources' list as [1], [2], etc. "
169
+ "If information is uncertain or evolving, state that clearly."
170
+ )
171
+
172
def call_openai(system_prompt: str, user_prompt: str) -> str:
    """Answer via the OpenAI Python SDK (>= 1.0.0).

    The model defaults to gpt-4o-mini but can be overridden with the
    OPENAI_MODEL environment variable (backward-compatible generalization
    of the previously hard-coded model name). Raises RuntimeError when
    the client was not configured at startup.
    """
    if not openai_client:
        raise RuntimeError("OpenAI is not configured.")
    resp = openai_client.chat.completions.create(
        model=os.getenv("OPENAI_MODEL", "gpt-4o-mini"),
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ],
        temperature=0.2,
        max_tokens=900,
    )
    # message.content can be None (e.g. refusal/tool responses) — coerce to "".
    return (resp.choices[0].message.content or "").strip()
186
+
187
def call_cohere(system_prompt: str, user_prompt: str) -> str:
    """Answer via Cohere chat (adjust model if needed).

    Raises RuntimeError when the SDK or API key is missing.

    BUG FIX: cohere.Client.chat() (the v1 SDK API this code constructs)
    takes `message=` plus `preamble=`, not an OpenAI-style
    `messages=[...]` list — the original call raised TypeError on every
    request.
    """
    if not cohere or not COHERE_API_KEY:
        raise RuntimeError("Cohere is not configured.")
    client = cohere.Client(api_key=COHERE_API_KEY)
    resp = client.chat(
        model="command-r-plus",
        preamble=system_prompt,
        message=user_prompt,
        temperature=0.2,
        max_tokens=900,
    )
    # Response shape differs across SDK versions; probe known fields.
    text = getattr(resp, "text", None) or getattr(resp, "output_text", None)
    if not text and hasattr(resp, "message") and hasattr(resp.message, "content"):
        parts = resp.message.content
        text = "".join(getattr(p, "text", "") for p in parts)
    return (text or "").strip()
206
+
207
def call_llm(system_prompt: str, user_prompt: str) -> str:
    """Dispatch to the provider selected by LLM_PROVIDER."""
    providers = {
        "openai": call_openai,
        "cohere": call_cohere,
    }
    handler = providers.get(LLM_PROVIDER)
    if handler is None:
        raise RuntimeError("Unsupported LLM_PROVIDER")
    return handler(system_prompt, user_prompt)
214
+
215
def compose_live_user_prompt(query: str, hits: List[Dict[str, str]]) -> str:
    """Build the grounded prompt for a time-sensitive query.

    Embeds today's date and a numbered citation block built from the
    search hits.

    BUG FIX: the module declares TIMEZONE = "Asia/Kolkata" but the date
    previously came from a naive datetime.now(), i.e. the server's local
    zone — "today" could be off by a day for the intended IST audience.
    """
    from zoneinfo import ZoneInfo  # local import: only needed here (stdlib, 3.9+)

    citation_block = build_citation_block(hits)
    today = datetime.now(ZoneInfo(TIMEZONE)).strftime("%B %d, %Y")
    return (
        f"User question (time-sensitive): {query}\n"
        f"Date today: {today}\n\n"
        f"You have these top search results. Answer using only what these sources support. "
        f"Be concise and include a 'Sources' section with numbered citations pointing to the links.\n\n"
        f"{citation_block}\n\n"
        f"Now write the answer:"
    )
226
+
227
def compose_general_user_prompt(query: str) -> str:
    """Build the plain (non-grounded) prompt for a general query.

    BUG FIX: uses the module's TIMEZONE constant (previously declared but
    unused) instead of a naive datetime.now(), so the "facts might have
    changed after <date>" hint reflects IST rather than the server zone.
    """
    from zoneinfo import ZoneInfo  # local import: only needed here (stdlib, 3.9+)

    today = datetime.now(ZoneInfo(TIMEZONE)).strftime("%B %d, %Y")
    return (
        f"User question: {query}\n"
        f"(Answer in simple, neutral English. If facts might have changed after {today}, mention that briefly.)"
    )
pytrade.py CHANGED
@@ -15,6 +15,18 @@ import json
15
  import os
16
  import time
17
  import requests
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  app = Flask(__name__)
20
 
@@ -109,6 +121,61 @@ def analyze_all():
109
  except Exception as e:
110
  return jsonify({"error": str(e)}), 500
111
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  if __name__ == "__main__":
113
  # Default to 5000 locally; on Hugging Face Spaces the platform injects PORT.
114
  port = int(os.environ.get("PORT", "5000"))
 
15
  import os
16
  import time
17
  import requests
18
+ from typing import List, Dict
19
+ from chatbot import (
20
+ normalize_entities,
21
+ needs_live_context,
22
+ pick_is_news,
23
+ serpapi_search,
24
+ compose_live_user_prompt,
25
+ compose_general_user_prompt,
26
+ call_llm,
27
+ BASE_SYSTEM_PROMPT,
28
+ SEARCH_TOPK
29
+ )
30
 
31
  app = Flask(__name__)
32
 
 
121
  except Exception as e:
122
  return jsonify({"error": str(e)}), 500
123
 
124
+
125
+
126
@app.route("/chat", methods=["POST"])
def chat():
    """
    POST /chat — answer a user question, optionally grounded in live search.

    Request JSON:
      {"message": "your question"}  or  {"question": "your question"}

    Response JSON:
      {"answer": "...", "live": true/false,
       "sources": [{title, link, source, snippet}]}
    """
    payload = request.get_json(force=True, silent=True) or {}
    question = (payload.get("message") or payload.get("question") or "").strip()
    if not question:
        return jsonify({"error": "message or question is required"}), 400

    # Expand common aliases (e.g. TCS -> Tata Consultancy Services).
    question = normalize_entities(question)

    # Time-sensitive questions get grounded in fresh search results.
    live = needs_live_context(question)
    hits: List[Dict[str, str]] = []
    if live:
        try:
            hits = serpapi_search(question, is_news=pick_is_news(question), num=SEARCH_TOPK)
        except Exception:
            # Search failed — degrade gracefully to a general answer.
            hits = []
            live = False

    try:
        if live and hits:
            answer = call_llm(BASE_SYSTEM_PROMPT, compose_live_user_prompt(question, hits))
            return jsonify({"answer": answer, "live": True, "sources": hits})
        answer = call_llm(BASE_SYSTEM_PROMPT, compose_general_user_prompt(question))
        return jsonify({"answer": answer, "live": False, "sources": []})
    except Exception as e:
        return jsonify({
            "error": "LLM call failed",
            "details": str(e),
            "live": live,
            "sources": hits
        }), 500
178
+
179
  if __name__ == "__main__":
180
  # Default to 5000 locally; on Hugging Face Spaces the platform injects PORT.
181
  port = int(os.environ.get("PORT", "5000"))
requirements.txt CHANGED
@@ -15,3 +15,10 @@ lxml_html_clean
15
  nltk
16
  rapidfuzz
17
  gunicorn
 
 
 
 
 
 
 
 
15
  nltk
16
  rapidfuzz
17
  gunicorn
18
+ torch
19
+ python-dotenv
20
+ openai>=1.0.0
21
+