Commit 988981a
Parent(s): 5195c5a
config update

Files changed:
- api/__main__.py +1 -11
- app.py +2 -15
- benchmark/__main__.py +1 -10
- config/.env.example +11 -3
- discord_bot/__main__.py +2 -16
- discord_bot/client/client.py +17 -33
- qa_engine/config.py +9 -1
- qa_engine/logger.py +4 -78
- qa_engine/mocks.py +1 -1
- qa_engine/qa_engine.py +54 -82
- requirements.txt +0 -1

api/__main__.py
CHANGED
@@ -6,17 +6,7 @@ from qa_engine import logger, Config, QAEngine
 
 config = Config()
 app = FastAPI()
-qa_engine = QAEngine(
-    llm_model_id=config.question_answering_model_id,
-    embedding_model_id=config.embedding_model_id,
-    index_repo_id=config.index_repo_id,
-    prompt_template=config.prompt_template,
-    use_docs_for_context=config.use_docs_for_context,
-    num_relevant_docs=config.num_relevant_docs,
-    add_sources_to_response=config.add_sources_to_response,
-    use_messages_for_context=config.use_messages_in_context,
-    debug=config.debug
-)
+qa_engine = QAEngine(config=config)
 
 
 @app.get('/')

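The change above is the heart of the commit: every entry point now builds QAEngine from a single Config object instead of unpacking each field at the call site. Below is a minimal sketch of that pattern; EngineConfig and Engine are illustrative stand-ins, not the repo's classes.

from dataclasses import dataclass

@dataclass
class EngineConfig:
    question_answering_model_id: str = 'mock'
    embedding_model_id: str = 'hkunlp/instructor-large'
    debug: bool = True

class Engine:
    def __init__(self, config: EngineConfig):
        # Keep the whole config around; later additions (e.g. generation
        # settings) no longer require touching every caller's signature.
        self.config = config
        self.debug = config.debug

config = EngineConfig()
engine = Engine(config=config)
print(engine.config.question_answering_model_id)
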
app.py
CHANGED
@@ -8,16 +8,7 @@ from discord_bot import DiscordClient
 
 
 config = Config()
-qa_engine = QAEngine(
-    llm_model_id=config.question_answering_model_id,
-    embedding_model_id=config.embedding_model_id,
-    index_repo_id=config.index_repo_id,
-    prompt_template=config.prompt_template,
-    use_docs_for_context=config.use_docs_for_context,
-    add_sources_to_response=config.add_sources_to_response,
-    use_messages_for_context=config.use_messages_in_context,
-    debug=config.debug
-)
+qa_engine = QAEngine(config=config)
 
 
 def gradio_interface():
@@ -41,11 +32,7 @@ def gradio_interface():
 def discord_bot_inference_thread():
     client = DiscordClient(
         qa_engine=qa_engine,
-        channel_ids=config.discord_channel_ids,
-        num_last_messages=config.num_last_messages,
-        use_names_in_context=config.use_names_in_context,
-        enable_commands=config.enable_commands,
-        debug=config.debug
+        config=config
     )
     client.run(config.discord_token)
 

benchmark/__main__.py
CHANGED
@@ -10,16 +10,7 @@ from qa_engine import logger, Config, QAEngine
 QUESTIONS_FILENAME = 'benchmark/questions.json'
 
 config = Config()
-qa_engine = QAEngine(
-    llm_model_id=config.question_answering_model_id,
-    embedding_model_id=config.embedding_model_id,
-    index_repo_id=config.index_repo_id,
-    prompt_template=config.prompt_template,
-    use_docs_for_context=config.use_docs_for_context,
-    add_sources_to_response=config.add_sources_to_response,
-    use_messages_for_context=config.use_messages_in_context,
-    debug=config.debug
-)
+qa_engine = QAEngine(config=config)
 
 
 def main():

config/.env.example
CHANGED
@@ -1,7 +1,7 @@
 # QA engine settings
-QUESTION_ANSWERING_MODEL_ID=
-EMBEDDING_MODEL_ID=
-INDEX_REPO_ID=
+QUESTION_ANSWERING_MODEL_ID=mock
+EMBEDDING_MODEL_ID=hkunlp/instructor-large
+INDEX_REPO_ID=KonradSzafer/index-instructor-large-812-m512-all_repos_above_50_stars
 PROMPT_TEMPLATE_NAME=llama
 USE_DOCS_FOR_CONTEXT=True
 NUM_RELEVANT_DOCS=4
@@ -9,6 +9,14 @@ ADD_SOURCES_TO_RESPONSE=True
 USE_MESSAGES_IN_CONTEXT=True
 DEBUG=True
 
+# Model settings
+MIN_NEW_TOKENS=64
+MAX_NEW_TOKENS=800
+TEMPERATURE=0.6
+TOP_K=50
+TOP_P=0.9
+DO_SAMPLE=True
+
 # Discord settings
 DISCORD_TOKEN=your-bot-token
 NUM_LAST_MESSAGES=1

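The new "Model settings" block is plain environment variables. A minimal sketch of how they could be read and cast, using only the standard library; the variable names come from the diff, while the loading code here is illustrative (the repo's own Config/get_env helper appears further below in qa_engine/config.py).

import os

min_new_tokens = int(os.getenv('MIN_NEW_TOKENS', '64'))
max_new_tokens = int(os.getenv('MAX_NEW_TOKENS', '800'))
temperature = float(os.getenv('TEMPERATURE', '0.6'))
top_k = int(os.getenv('TOP_K', '50'))
top_p = float(os.getenv('TOP_P', '0.9'))
do_sample = os.getenv('DO_SAMPLE', 'True') == 'True'

print(min_new_tokens, max_new_tokens, temperature, top_k, top_p, do_sample)
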
discord_bot/__main__.py
CHANGED
@@ -3,24 +3,10 @@ from discord_bot.client import DiscordClient
 
 
 config = Config()
-qa_engine = QAEngine(
-    llm_model_id=config.question_answering_model_id,
-    embedding_model_id=config.embedding_model_id,
-    index_repo_id=config.index_repo_id,
-    prompt_template=config.prompt_template,
-    use_docs_for_context=config.use_docs_for_context,
-    num_relevant_docs=config.num_relevant_docs,
-    add_sources_to_response=config.add_sources_to_response,
-    use_messages_for_context=config.use_messages_in_context,
-    debug=config.debug
-)
+qa_engine = QAEngine(config=config)
 client = DiscordClient(
     qa_engine=qa_engine,
-    channel_ids=config.discord_channel_ids,
-    num_last_messages=config.num_last_messages,
-    use_names_in_context=config.use_names_in_context,
-    enable_commands=config.enable_commands,
-    debug=config.debug
+    config=config
 )
 
 

discord_bot/client/client.py
CHANGED
@@ -4,56 +4,40 @@ from urllib.parse import quote
 import discord
 from typing import List
 
-from qa_engine import logger, QAEngine
+from qa_engine import logger, Config, QAEngine
 from discord_bot.client.utils import split_text_into_chunks
 
 
 class DiscordClient(discord.Client):
     """
     Discord Client class, used for interacting with a Discord server.
-
-    Args:
-        qa_service_url (str): The URL of the question answering service.
-        num_last_messages (int, optional): The number of previous messages to use as context for generating answers.
-            Defaults to 5.
-        use_names_in_context (bool, optional): Whether to include user names in the message context. Defaults to True.
-        enable_commands (bool, optional): Whether to enable commands for the bot. Defaults to True.
-
-    Attributes:
-        qa_service_url (str): The URL of the question answering service.
-        num_last_messages (int): The number of previous messages to use as context for generating answers.
-        use_names_in_context (bool): Whether to include user names in the message context.
-        enable_commands (bool): Whether to enable commands for the bot.
-        max_message_len (int): The maximum length of a message.
-        system_prompt (str): The system prompt to be used.
-
     """
     def __init__(
         self,
         qa_engine: QAEngine,
-        channel_ids: list[int],
-        num_last_messages: int = 5,
-        use_names_in_context: bool = True,
-        enable_commands: bool = True,
-        debug: bool = False
-    ):
+        config: Config,
+    ):
         logger.info('Initializing Discord client...')
         intents = discord.Intents.all()
         intents.message_content = True
         super().__init__(intents=intents, command_prefix='!')
 
-        assert num_last_messages >= 1, \
-            'The number of last messages in context should be at least 1'
-
         self.qa_engine: QAEngine = qa_engine
-        self.channel_ids: list[int] = DiscordClient._process_channel_ids(
-            channel_ids
-        )
-        self.num_last_messages: int = num_last_messages
-        self.use_names_in_context: bool = use_names_in_context
-        self.enable_commands: bool = enable_commands
+        self.channel_ids: list[int] = DiscordClient._process_channel_ids(
+            config.discord_channel_ids
+        )
+        self.num_last_messages: int = config.num_last_messages
+        self.use_names_in_context: bool = config.use_names_in_context
+        self.enable_commands: bool = config.enable_commands
+        self.debug: bool = config.debug
+        self.min_message_len: int = 1800
         self.max_message_len: int = 2000
 
+        assert all([isinstance(id, int) for id in self.channel_ids]), \
+            'All channel ids should be of type int'
+        assert self.num_last_messages >= 1, \
+            'The number of last messages in context should be at least 1'
+
 
     @staticmethod
     def _process_channel_ids(channel_ids) -> list[int]:
@@ -103,7 +87,7 @@ class DiscordClient(discord.Client):
         chunks = split_text_into_chunks(
             text=answer,
             split_characters=['. ', ', ', '\n'],
-            min_size=self.
+            min_size=self.min_message_len,
             max_size=self.max_message_len
         )
         for chunk in chunks:

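The new min_message_len/max_message_len pair feeds the chunking call above: answers are split on natural boundaries once a chunk is long enough, and never exceed Discord's 2000-character message cap. The sketch below illustrates that idea only; it is not the implementation from discord_bot/client/utils, and chunk_text is a hypothetical name.

def chunk_text(text: str, separators: list[str], min_size: int, max_size: int) -> list[str]:
    # Prefer to cut on a separator once the chunk passes min_size,
    # and force a cut before exceeding max_size.
    chunks, current = [], ''
    for ch in text:
        current += ch
        at_separator = any(current.endswith(sep) for sep in separators)
        if (len(current) >= min_size and at_separator) or len(current) >= max_size:
            chunks.append(current)
            current = ''
    if current:
        chunks.append(current)
    return chunks

chunks = chunk_text('one. two. three. ' * 500, ['. ', ', ', '\n'], 1800, 2000)
print(len(chunks), max(len(c) for c in chunks))
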
qa_engine/config.py
CHANGED
@@ -11,7 +11,7 @@ def get_env(env_name: str, default: Any = None, warn: bool = True) -> str:
     if default is not None:
         if warn:
             logger.warning(
-                f'Environment variable {env_name} not found.
+                f'Environment variable {env_name} not found.' \
                 f'Using the default value: {default}.'
             )
         return default
@@ -34,6 +34,14 @@ class Config:
     use_messages_in_context: bool = eval(get_env('USE_MESSAGES_IN_CONTEXT', 'True'))
     debug: bool = eval(get_env('DEBUG', 'True'))
 
+    # Model config
+    min_new_tokens: int = int(get_env('MIN_NEW_TOKENS', 64))
+    max_new_tokens: int = int(get_env('MAX_NEW_TOKENS', 800))
+    temperature: float = float(get_env('TEMPERATURE', 0.6))
+    top_k: int = int(get_env('TOP_K', 50))
+    top_p: float = float(get_env('TOP_P', 0.95))
+    do_sample: bool = eval(get_env('DO_SAMPLE', 'True'))
+
     # Discord bot config - optional
     discord_token: str = get_env('DISCORD_TOKEN', '-', warn=False)
     discord_channel_ids: list[int] = get_env('DISCORD_CHANNEL_IDS', field(default_factory=list), warn=False)

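The added Config fields follow the file's existing pattern: get_env reads a string from the environment, falls back to a default with a warning, and the cast happens at the field definition. A sketch of a get_env-style helper is shown below; only the signature and the warning branch appear in the diff, so the rest of the body is an assumption.

import os
import logging
from typing import Any

logger = logging.getLogger(__name__)

def get_env(env_name: str, default: Any = None, warn: bool = True) -> str:
    # Body is an assumption sketched from the hunk header and warning branch.
    value = os.getenv(env_name)
    if value is None or value == '':
        if default is not None:
            if warn:
                logger.warning(
                    f'Environment variable {env_name} not found.'
                    f'Using the default value: {default}.'
                )
            return default
        raise ValueError(f'Environment variable {env_name} is required')
    return value

# Casting at the field definition, as the new model config does:
min_new_tokens = int(get_env('MIN_NEW_TOKENS', 64))
print(min_new_tokens)
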
qa_engine/logger.py
CHANGED
@@ -1,88 +1,14 @@
 import logging
-import os
-import io
-import json
-from google.cloud import bigquery
-from google.oauth2 import service_account
-from google.api_core.exceptions import GoogleAPIError
-
-job_config = bigquery.LoadJobConfig(
-    schema=[
-        bigquery.SchemaField("timestamp", "TIMESTAMP", mode="REQUIRED"),
-        bigquery.SchemaField("log_entry", "STRING", mode="REQUIRED"),
-    ],
-    write_disposition="WRITE_APPEND",
-)
-
-
-class BigQueryLoggingHandler(logging.Handler):
-    def __init__(self):
-        super().__init__()
-        try:
-            project_id = os.getenv("BIGQUERY_PROJECT_ID")
-            dataset_id = os.getenv("BIGQUERY_DATASET_ID")
-            table_id = os.getenv("BIGQUERY_TABLE_ID")
-            print(f"project_id: {project_id}")
-            print(f"dataset_id: {dataset_id}")
-            print(f"table_id: {table_id}")
-            service_account_info = json.loads(
-                os.getenv("GOOGLE_SERVICE_ACCOUNT_JSON")
-                .replace('"', "")
-                .replace("'", '"')
-            )
-            print(f"service_account_info: {service_account_info}")
-            print(f"service_account_info type: {type(service_account_info)}")
-            print(f"service_account_info keys: {service_account_info.keys()}")
-            credentials = service_account.Credentials.from_service_account_info(
-                service_account_info
-            )
-            self.client = bigquery.Client(credentials=credentials, project=project_id)
-            self.table_ref = self.client.dataset(dataset_id).table(table_id)
-        except Exception as e:
-            print(f"Error: {e}")
-            self.handleError(e)
-
-    def emit(self, record):
-        try:
-            recordstr = f"{self.format(record)}"
-            body = io.BytesIO(recordstr.encode("utf-8"))
-            job = self.client.load_table_from_file(
-                body, self.table_ref, job_config=job_config
-            )
-            job.result()
-        except GoogleAPIError as e:
-            self.handleError(e)
-        except Exception as e:
-            self.handleError(e)
-
-    def handleError(self, record):
-        """
-        Handle errors associated with logging.
-        This method prevents logging-related exceptions from propagating.
-        Optionally, implement more sophisticated error handling here.
-        """
-        if isinstance(record, logging.LogRecord):
-            super().handleError(record)
-        else:
-            print(f"Logging error: {record}")
 
 
 logger = logging.getLogger(__name__)
 
-
 def setup_logger() -> None:
     """
     Logger setup.
     """
     logger.setLevel(logging.DEBUG)
-    stream_formatter = logging.Formatter(
-
-
-    )
-    stream_handler = logging.StreamHandler()
-    stream_handler.setFormatter(stream_formatter)
-    logger.addHandler(stream_handler)
-
-    bq_handler = BigQueryLoggingHandler()
-    bq_handler.setFormatter(stream_formatter)
-    logger.addHandler(bq_handler)
+    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+    handler = logging.StreamHandler()
+    handler.setFormatter(formatter)
+    logger.addHandler(handler)

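The BigQuery export handler is dropped entirely, leaving a single stream handler. A usage sketch of the simplified setup, grounded directly in the new lines of the diff:

import logging

logger = logging.getLogger('qa_engine')

def setup_logger() -> None:
    # One stream handler with the formatter from the diff; no BigQuery export.
    logger.setLevel(logging.DEBUG)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    handler = logging.StreamHandler()
    handler.setFormatter(formatter)
    logger.addHandler(handler)

setup_logger()
logger.info('logger configured')
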
qa_engine/mocks.py
CHANGED
@@ -10,7 +10,7 @@ class MockLocalBinaryModel(LLM):
     """
 
     model_path: str = None
-    llm: str = '
+    llm: str = 'Warsaw'
 
     def __init__(self):
         super().__init__()

qa_engine/qa_engine.py
CHANGED
@@ -16,7 +16,7 @@ from langchain.embeddings import HuggingFaceEmbeddings, HuggingFaceHubEmbeddings
 from langchain.vectorstores import FAISS
 from sentence_transformers import CrossEncoder
 
-from qa_engine import logger
+from qa_engine import logger, Config
 from qa_engine.response import Response
 from qa_engine.mocks import MockLocalBinaryModel
 
@@ -25,16 +25,16 @@ class LocalBinaryModel(LLM):
     model_id: str = None
     llm: None = None
 
-    def __init__(self,
+    def __init__(self, config: Config):
         super().__init__()
         # pip install llama_cpp_python==0.1.39
         from llama_cpp import Llama
 
-
-
-
-
-        self.llm = Llama(model_path=model_path, n_ctx=4096)
+        self.model_id = config.question_answering_model_id
+        self.model_path = f'qa_engine/{self.model_id}'
+        if not os.path.exists(self.model_path):
+            raise ValueError(f'{self.model_path} does not exist')
+        self.llm = Llama(model_path=self.model_path, n_ctx=4096)
 
     def _call(self, prompt: str, stop: Optional[list[str]] = None) -> str:
         output = self.llm(
@@ -58,13 +58,19 @@ class TransformersPipelineModel(LLM):
     model_id: str = None
     pipeline: str = None
 
-    def __init__(self,
+    def __init__(self, config: Config):
         super().__init__()
-        self.model_id =
-
-
+        self.model_id = config.question_answering_model_id
+        self.min_new_tokens = config.min_new_tokens
+        self.max_new_tokens = config.max_new_tokens
+        self.temperature = config.temperature
+        self.top_k = config.top_k
+        self.top_p = config.top_p
+        self.do_sample = config.do_sample
+
+        tokenizer = AutoTokenizer.from_pretrained(self.model_id)
         model = AutoModelForCausalLM.from_pretrained(
-            model_id,
+            self.model_id,
             torch_dtype=torch.bfloat16,
             trust_remote_code=True,
             load_in_8bit=False,
@@ -79,10 +85,12 @@ class TransformersPipelineModel(LLM):
             device_map='auto',
             eos_token_id=tokenizer.eos_token_id,
             pad_token_id=tokenizer.eos_token_id,
-            min_new_tokens=
-            max_new_tokens=
-            temperature=
-
+            min_new_tokens=self.min_new_tokens,
+            max_new_tokens=self.max_new_tokens,
+            temperature=self.temperature,
+            top_k=self.top_k,
+            top_p=self.top_p,
+            do_sample=self.do_sample,
         )
 
     def _call(self, prompt: str, stop: Optional[list[str]] = None) -> str:
@@ -103,7 +111,7 @@ class APIServedModel(LLM):
     model_url: str = None
     debug: bool = None
 
-    def __init__(self, model_url: str
+    def __init__(self, model_url: str, debug: bool = False):
         super().__init__()
         if model_url[-1] == '/':
             raise ValueError('URL should not end with a slash - "/"')
@@ -132,66 +140,36 @@
         return 'api_model'
 
 
-
 class QAEngine():
     """
     QAEngine class, used for generating answers to questions.
-
-    Args:
-        llm_model_id (str): The ID of the LLM model to be used.
-        embedding_model_id (str): The ID of the embedding model to be used.
-        index_repo_id (str): The ID of the index repository to be used.
-        run_locally (bool, optional): Whether to run the models locally or on the Hugging Face hub. Defaults to True.
-        use_docs_for_context (bool, optional): Whether to use relevant documents as context for generating answers.
-            Defaults to True.
-        use_messages_for_context (bool, optional): Whether to use previous messages as context for generating answers.
-            Defaults to True.
-        debug (bool, optional): Whether to log debug information. Defaults to False.
-
-    Attributes:
-        use_docs_for_context (bool): Whether to use relevant documents as context for generating answers.
-        use_messages_for_context (bool): Whether to use previous messages as context for generating answers.
-        debug (bool): Whether to log debug information.
-        llm_model (Union[LocalBinaryModel, HuggingFacePipeline, HuggingFaceHub]): The LLM model to be used.
-        embedding_model (Union[HuggingFaceInstructEmbeddings, HuggingFaceHubEmbeddings]): The embedding model to be used.
-        prompt_template (PromptTemplate): The prompt template to be used.
-        llm_chain (LLMChain): The LLM chain to be used.
-        knowledge_index (FAISS): The FAISS index to be used.
-
     """
-    def __init__(
-        self,
-        llm_model_id: str,
-        embedding_model_id: str,
-        index_repo_id: str,
-        prompt_template: str,
-        use_docs_for_context: bool = True,
-        num_relevant_docs: int = 3,
-        add_sources_to_response: bool = True,
-        use_messages_for_context: bool = True,
-        first_stage_docs: int = 50,
-        debug: bool = False
-    ):
+    def __init__(self, config: Config):
         super().__init__()
-        self.
-        self.
-        self.
-        self.
-        self.
-        self.
-        self.
+        self.config = config
+        self.question_answering_model_id=config.question_answering_model_id
+        self.embedding_model_id=config.embedding_model_id
+        self.index_repo_id=config.index_repo_id
+        self.prompt_template=config.prompt_template
+        self.use_docs_for_context=config.use_docs_for_context
+        self.num_relevant_docs=config.num_relevant_docs
+        self.add_sources_to_response=config.add_sources_to_response
+        self.use_messages_for_context=config.use_messages_in_context
+        self.debug=config.debug
+
+        self.first_stage_docs: int = 50
 
         prompt = PromptTemplate(
-            template=prompt_template,
+            template=self.prompt_template,
             input_variables=['question', 'context']
         )
-        self.llm_model =
+        self.llm_model = self._get_model()
         self.llm_chain = LLMChain(prompt=prompt, llm=self.llm_model)
 
         if self.use_docs_for_context:
-            logger.info(f'Downloading {index_repo_id}')
+            logger.info(f'Downloading {self.index_repo_id}')
            snapshot_download(
-                repo_id=index_repo_id,
+                repo_id=self.index_repo_id,
                 allow_patterns=['*.faiss', '*.pkl'],
                 repo_type='dataset',
                 local_dir='indexes/run/'
@@ -200,7 +178,7 @@ class QAEngine():
             embed_instruction = 'Represent the Hugging Face library documentation'
             query_instruction = 'Query the most relevant piece of information from the Hugging Face documentation'
             embedding_model = HuggingFaceInstructEmbeddings(
-                model_name=embedding_model_id,
+                model_name=self.embedding_model_id,
                 embed_instruction=embed_instruction,
                 query_instruction=query_instruction
             )
@@ -209,27 +187,22 @@ class QAEngine():
         self.reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2')
 
 
-
-
-        if 'local_models/' in llm_model_id:
+    def _get_model(self):
+        if 'local_models/' in self.question_answering_model_id:
             logger.info('using local binary model')
-            return LocalBinaryModel(
-
-            )
-        elif 'api_models/' in llm_model_id:
+            return LocalBinaryModel(self.config)
+        elif 'api_models/' in self.question_answering_model_id:
             logger.info('using api served model')
             return APIServedModel(
-                model_url=
+                model_url=self.question_answering_model_id.replace('api_models/', ''),
                 debug=self.debug
             )
-        elif
+        elif self.question_answering_model_id == 'mock':
             logger.info('using mock model')
             return MockLocalBinaryModel()
         else:
             logger.info('using transformers pipeline model')
-            return TransformersPipelineModel(
-                model_id=llm_model_id
-            )
+            return TransformersPipelineModel(self.config)
 
 
     @staticmethod
@@ -245,7 +218,8 @@ class QAEngine():
         Preprocess the answer by removing unnecessary sequences and stop sequences.
         '''
         SEQUENCES_TO_REMOVE = [
-            'Factually: ', 'Answer: ', '<<SYS>>', '<</SYS>>', '[INST]', '[/INST]'
+            'Factually: ', 'Answer: ', '<<SYS>>', '<</SYS>>', '[INST]', '[/INST]',
+            '<context>', '<\context>', '<question>', '<\question>',
         ]
         SEQUENCES_TO_STOP = [
             'User:', 'You:', 'Question:'
@@ -296,9 +270,8 @@ class QAEngine():
                )
            ]
            relevant_docs = relevant_docs[:self.num_relevant_docs]
-            context += '\
-            for i, doc in enumerate(relevant_docs):
-                context += f'\n\n<DOCUMENT_{i}>\n {doc.page_content} \n</DOCUMENT_{i}>'
+            context += '\nExtracted documents:\n'
+            context += ''.join([doc.page_content for doc in relevant_docs])
            metadata = [doc.metadata for doc in relevant_docs]
            response.set_sources(sources=[str(m['source']) for m in metadata])
 
@@ -314,7 +287,6 @@ class QAEngine():
         sep = '\n' + '-' * 100
         logger.info(f'question len: {len(question)} {sep}')
         logger.info(f'question: {question} {sep}')
-        logger.info(f'question processed: {question} {sep}')
         logger.info(f'answer len: {len(response.get_answer())} {sep}')
         logger.info(f'answer original: {answer} {sep}')
         logger.info(f'answer postprocessed: {response.get_answer()} {sep}')

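With the refactor, backend selection is concentrated in _get_model: the QUESTION_ANSWERING_MODEL_ID string decides which wrapper is built, which is why the example .env can simply set it to "mock". The sketch below mirrors only the branch conditions visible in the diff; the backend classes here are stand-ins, not the real LLM wrappers.

# Illustrative dispatch on the model id, matching the branches in _get_model.
class LocalBinary:
    def __init__(self, model_id: str):
        self.model_id = model_id

class ApiServed:
    def __init__(self, url: str):
        self.url = url

class Mock:
    pass

class Pipeline:
    def __init__(self, model_id: str):
        self.model_id = model_id

def get_model(model_id: str):
    if 'local_models/' in model_id:
        return LocalBinary(model_id)
    elif 'api_models/' in model_id:
        return ApiServed(model_id.replace('api_models/', ''))
    elif model_id == 'mock':
        return Mock()
    else:
        return Pipeline(model_id)

print(type(get_model('mock')).__name__)                      # Mock
print(type(get_model('api_models/http://localhost:8000')).__name__)  # ApiServed
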
requirements.txt
CHANGED
@@ -26,4 +26,3 @@ InstructorEmbedding==1.0.0
 faiss_cpu==1.7.3
 uvicorn==0.22.0
 pytest==7.3.1
-google-cloud-bigquery==3.17.2