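"""Generate DPO preference data from sampled conversation records.

Pipeline (see ``DPOData.run``): load and sample the merged input JSON, build
chat-format prompts per task type, generate multiple candidate traces from a
local OpenAI-compatible endpoint, judge the traces pairwise with a
structured-output evaluator, and save both a full and a direct
(prompt/chosen/rejected) DPO dataset.
"""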
import os
import sys
import json
import random
import concurrent.futures
import openai
from tqdm import tqdm
from prompt import JUDGE_COT_PROMPT, JUDGE_PROMPT, MEMORY_COT_PROMPT, MEMORY_PROMPT, CONTEXT_COT_PROMPT, CONTEXT_PROMPT, CONTEXT_ENHANCE_EVAL_SYS, JUDGE_EVAL_SYS, MEMORY_EVAL_SYS, USR
from utils import OPENAI_API_KEY, OPENAI_BASE_URL, Global_Bio
from openai import OpenAI
from pydantic import BaseModel
from collections import defaultdict
# COT mode
IS_COT = False
# USER NAME SETTING
USER_NAME = "Felix Tao"
# preferred language
preference_language = "English"
class Rate(BaseModel):
comparison: str
detailed_analysis: str
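# Structured-output schema for the pairwise judge. ``comparison`` is expected to be
# "first win", "second win", or "tie" (see compare_traces); ``detailed_analysis``
# holds the judge's free-form reasoning.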
class DPOData:
"""Generates DPO data for training language models.
This class is responsible for creating diverse training data based on user notes,
entities, and configurations. It leverages LLMs to generate questions and answers.
"""
    def __init__(self, input_path, output_dir, preference_language: str):
        """Initialize the DPO data generator.
        Args:
            input_path: Path to the input JSON file.
            output_dir: Directory to save the output JSON files.
            preference_language: Preferred language for generated data.
        """
self.input_path = input_path
self.output_dir = output_dir
# Use the API key and base URL from utils.py
self.model_name = "gpt-4o" # Set your model name here
        if OPENAI_BASE_URL:
            self.client = OpenAI(
                api_key=OPENAI_API_KEY,
                base_url=OPENAI_BASE_URL,
            )
        else:
            self.client = OpenAI(
                api_key=OPENAI_API_KEY,
            )
self.preference_language = preference_language
def load_and_sample_data(self, sample_fraction=0.1):
"""
Load data from a JSON file and sample a fraction of it.
:param sample_fraction: Fraction of data to sample.
:return: Sampled data.
"""
        with open(self.input_path, 'r', encoding='utf-8') as file:
data = json.load(file)
sampled_data = random.sample(data, int(len(data) * sample_fraction))
chat_messages = self.create_chat_data(sampled_data)
return chat_messages
# build messages in chat format
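    # ``preprocess`` below routes each sample by its fields: context-enhancement
    # requests (``enhanced_request``), expert-response judging (``user_feedback``),
    # timestamped / time-QA chats, ``exact_day`` chats, and plain memory QA.
    # Only samples that carry a label and eval prompt can be judged downstream.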
    def create_chat_data(self, data):
def preprocess(sample, is_cot=False):
if sample.get('assistant') is None and sample.get('enhanced_request') is not None:
user_message = f"{USER_NAME}'s request is " + sample['user_request']
infer_prompt = CONTEXT_COT_PROMPT.format(user_name=USER_NAME) if is_cot else CONTEXT_PROMPT.format(user_name=USER_NAME)
messages = [
{"role": "system", "content": infer_prompt},
{"role": "user", "content": user_message},
# {"role": "assistant", "content": sample['enhanced_request'].strip('\n')},
]
return [{"messages": messages,"user":user_message,"label":sample['enhanced_request'].strip('\n'),"eval_prompt":CONTEXT_ENHANCE_EVAL_SYS,"infer_prompt":infer_prompt}]
if sample.get('assistant') is None and sample.get('user_feedback') is not None:
user_message = f"{USER_NAME}'s request is " + sample['user_request'] + "\n" + "The response of expert is " + sample['expert_response']
infer_prompt = JUDGE_COT_PROMPT.format(user_name=USER_NAME) if is_cot else JUDGE_PROMPT.format(user_name=USER_NAME)
messages = [
{"role": "system", "content": infer_prompt},
{"role": "user", "content": user_message},
# {"role": "assistant", "content": sample['user_feedback'].strip('\n')},
]
global_bio = Global_Bio
return [{"messages": messages,"user":user_message,"label":sample['user_feedback'].strip('\n'),"eval_prompt":JUDGE_EVAL_SYS.format(global_bio=global_bio),"infer_prompt":infer_prompt}]
sample['assistant'] = sample['assistant'].strip('\n')
if sample.get('timestamp') is not None and sample.get('is_timeqa', None) is None:
# messages1 = [
# {"role": "system", "content": "You are a helpful assistant.\n\nThe current date is " + sample['timestamp'][:10]},
# {"role": "user", "content": "<|ME|>" + sample['user']},
# {"role": "assistant", "content": sample['assistant']},
# ]
messages2 = [
{"role": "system", "content": ""},
{"role": "user", "content": "<|ME|>" + sample['user']},
{"role": "assistant", "content": sample['assistant']},
]
if 'None' in sample['assistant']:
return []
# return [{"content": tokenizer.apply_chat_template(messages1, tokenize=False)},
# {"content": tokenizer.apply_chat_template(messages2, tokenize=False)}]
return [{"messages": messages2}]
elif sample.get('is_timeqa', None) is not None:
messages = [
{"role": "system", "content": "You are a helpful assistant.\n\nToday’s date is " + sample['timestamp']},
{"role": "user", "content": "<|ME|>" + sample['user']},
{"role": "assistant", "content": sample['assistant']},
]
if 'None' in sample['assistant']:
return []
return {"messages": messages}
elif sample.get('exact_day', None) is not None:
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "<|ME|>" + sample['user']},
{"role": "assistant", "content": sample['assistant']},
]
return [{"messages": messages}]
else:
infer_prompt = MEMORY_COT_PROMPT.format(user_name=USER_NAME) if is_cot else MEMORY_PROMPT.format(user_name=USER_NAME)
messages = [
{"role": "system", "content": infer_prompt},
{"role": "user", "content": sample['user']},
# {"role": "assistant", "content": sample['assistant']},
]
if 'None' in sample['assistant']:
return []
return [{"messages": messages,"user":sample['user'],"label":sample['assistant'],"eval_prompt":MEMORY_EVAL_SYS,"infer_prompt":infer_prompt}]
res_dataset = []
for case in data:
res_dataset.extend(preprocess(case, IS_COT))
# res = Dataset.from_list(res_dataset)
# print(f"**************Dataset contains {res.num_rows} elements.**************")
print(f"**************Dataset contains {len(res_dataset)} elements.**************")
# print(res_dataset[:2])
return res_dataset
def generate_all_traces(self, processed_data):
"""
Generate traces for all processed data.
:param processed_data: Preprocessed data.
:return: All generated traces.
"""
        all_traces = []
        for instance in tqdm(processed_data, desc="Generating traces"):
            # Skip plain chat records without a label/eval prompt; they cannot
            # take part in the pairwise judging below.
            if "label" not in instance or "eval_prompt" not in instance:
                continue
            messages = instance.get("messages", [])
            # generate three candidate traces for this prompt
            traces = self.generate_traces(messages, 3)
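            # Three traces yield the three pairwise comparisons (1v2, 1v3, 2v3)
            # that compare_eval performs for each instance.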
# attach traces to each instance
instance_with_traces = {
"user": instance["user"],
"label": instance["label"],
"traces": traces,
"eval_prompt": instance["eval_prompt"],
"infer_prompt":instance["infer_prompt"]
}
print(instance_with_traces)
all_traces.append(instance_with_traces)
return all_traces
    def generate_traces(self, messages, nums_traces=3):
        """
        Generate traces using an OpenAI-compatible API.
        llama.cpp can run as an HTTP server, so it can be used here as an
        OpenAI-compatible endpoint.
        :param messages: List of messages to send to the API.
        :param nums_traces: The number of traces to generate.
        :return: List of traces.
        """
traces = []
client = OpenAI(base_url="http://127.0.0.1:8080/v1", api_key="key")
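        # Hardcoded local OpenAI-compatible endpoint (e.g. llama.cpp's llama-server,
        # which listens on port 8080 by default); the api_key is a placeholder that
        # such local servers typically ignore.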
for _ in range(nums_traces):
response = client.chat.completions.create(
model="",
messages=messages,
stream=False,
temperature=0.7,
max_tokens=2048,
top_p=1.0 # Adjust top_p as needed
)
traces.append(response.choices[0].message.content)
return traces
def compare_eval(self, instances):
"""
Compare evaluations and determine chosen and rejected responses.
:param instances: Instances with traces.
:return: Instances with chosen and rejected responses.
"""
all_eval_messages = []
# compare traces
for ins in instances:
traces = ins["traces"]
            if len(traces) != 3:
                raise ValueError("Each instance must have exactly 3 traces.")
# build messages
eval_messages = [
[{
"role": "system",
"content": ins["eval_prompt"]
}, {
"role": "user",
"content": USR.format(
user_input=ins["user"],
model_answer_1=traces[0], # trace1 vs trace2
model_answer_2=traces[1],
reference_info=ins["label"]
)
}],
[{
"role": "system",
"content": ins["eval_prompt"]
}, {
"role": "user",
"content": USR.format(
user_input=ins["user"],
model_answer_1=traces[0], # trace1 vs trace3
model_answer_2=traces[2],
reference_info=ins["label"]
)
}],
[{
"role": "system",
"content": ins["eval_prompt"]
}, {
"role": "user",
"content": USR.format(
user_input=ins["user"],
model_answer_1=traces[1], # trace2 vs trace3
model_answer_2=traces[2],
reference_info=ins["label"]
)
}]
]
all_eval_messages.extend(eval_messages)
        # Run all pairwise judgments; each instance contributes exactly three
        # messages, so results are regrouped in blocks of three below.
        eval_results = self.multi_process_request(all_eval_messages, 10, self.process_request_structured, Rate)
# group results
for ins_idx, ins in enumerate(instances):
start_idx = ins_idx * 3
end_idx = start_idx + 3
instance_eval_results = eval_results[start_idx:end_idx]
print(instance_eval_results)
# get rejected responses and chosen responses
tmp_comparisons = []
for result in instance_eval_results:
                if isinstance(result, Rate):
tmp_comparisons.append(result.comparison)
else:
tmp_comparisons.append('tie')
chosen_response, rejected_response, detailed_analysis = self.compare_traces(
traces=ins["traces"],
eval_results=tmp_comparisons
)
# attach results to each instance
ins["chosen_response"] = chosen_response
ins["rejected_response"] = rejected_response
ins["detailed_analysis"] = detailed_analysis
# print the results
print(f"choose_response: {chosen_response}")
print(f"rejected_response: {rejected_response}")
return instances
    def compare_traces(self, traces, eval_results):
"""
Compare three traces to determine the best and worst trace.
:param traces: A list of three traces [trace1, trace2, trace3].
:param eval_results: The results of pairwise comparisons, formatted as [{"comparison": "first win"/"tie"/"second win", "detailed_analysis": "..."}, ...].
:return: chosen_response, rejected_response, detailed_analysis
"""
# initialization
win_loss = defaultdict(lambda: {"wins": 0, "losses": 0, "ties": 0})
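        # Note: the win/loss table is keyed by trace text, so identical traces share
        # one entry; with temperature-0.7 sampling this is assumed to be rare.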
# comparison res
comparisons = [
(0, 1, eval_results[0]), # trace1 vs trace2
(0, 2, eval_results[1]), # trace1 vs trace3
(1, 2, eval_results[2]), # trace2 vs trace3
]
# calculate wins, losses, and ties
for i, j, result in comparisons:
if result == "first win":
win_loss[traces[i]]["wins"] += 1
win_loss[traces[j]]["losses"] += 1
elif result == "second win":
win_loss[traces[j]]["wins"] += 1
win_loss[traces[i]]["losses"] += 1
elif result == "tie":
win_loss[traces[i]]["ties"] += 1
win_loss[traces[j]]["ties"] += 1
else:
raise ValueError(f"Invalid comparison result: {result}")
# calculate win rate for each trace
def calculate_win_rate(trace):
total = win_loss[trace]["wins"] + win_loss[trace]["losses"] + win_loss[trace]["ties"]
if total == 0:
return 0
return win_loss[trace]["wins"] / total
# select chosen_response and rejected_response
        sorted_traces = sorted(traces, key=calculate_win_rate, reverse=True)
chosen_response = sorted_traces[0]
rejected_response = sorted_traces[-1]
# get detailed analysis
detailed_analysis = f"Chosen Response: {chosen_response} (Win Rate: {calculate_win_rate(chosen_response):.2f})\n"
detailed_analysis += f"Rejected Response: {rejected_response} (Win Rate: {calculate_win_rate(rejected_response):.2f})\n"
detailed_analysis += "Comparison Details:\n"
for trace in traces:
detailed_analysis += f"{trace}: Wins={win_loss[trace]['wins']}, Losses={win_loss[trace]['losses']}, Ties={win_loss[trace]['ties']}\n"
return chosen_response, rejected_response, detailed_analysis
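    # Worked example (hypothetical results): with eval_results
    # ["first win", "first win", "second win"], trace1 wins both of its comparisons
    # (win rate 1.0), trace2 loses both (0.0), and trace3 loses to trace1 but beats
    # trace2 (0.5), so trace1 is chosen and trace2 is rejected.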
    def process_request_structured(self, messages, format_class):
        """Send a single request and parse the response into ``format_class``."""
        try:
            model = self.model_name
completion = self.client.beta.chat.completions.parse(
model=model,
messages=messages,
response_format=format_class,
# extra_body={"metadata": {"tags": ["lpmPreferDataGen"]}},
)
message = completion.choices[0].message
if message.parsed:
print(f"model answer:{message.parsed}")
return message.parsed
else:
return message.refusal
except Exception as e:
return f"Error occurred: {str(e)}"
    def multi_process_request(self, all_messages, max_workers, func, structure=None):
        # Guard against an empty batch (ThreadPoolExecutor rejects max_workers=0).
        if not all_messages:
            return []
        with concurrent.futures.ThreadPoolExecutor(max_workers=min(max_workers, len(all_messages))) as executor:
futures = [(i, executor.submit(func, messages, structure)) if structure is not None else (i, executor.submit(func, messages)) for i, messages in enumerate(all_messages)]
results = [None] * len(all_messages)
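            # Results are written back by input index, so ordering matches all_messages;
            # failures are captured as error strings instead of aborting the batch.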
for i, future in tqdm(futures):
try:
result = future.result()
results[i] = result
except Exception as e:
results[i] = f"Raise ERROR: {e} WHEN GENERATE RESPONSE"
return results
    def prepare_dpo_datasets(self, sampled_data):
"""
Prepare full and direct training versions of the DPO dataset.
:param sampled_data: Sampled data from the input JSON file.
:return: Full version and direct training version of the dataset.
"""
full_version = []
direct_training_version = []
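        # full_version keeps every field for inspection; direct_training_version keeps
        # only the (prompt, chosen, rejected) triplet used for DPO training.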
for item in sampled_data:
full_version.append(item)
direct_training_version.append({
'prompt': {"system":item['infer_prompt'],"user":item['user']},
'chosen': item['chosen_response'],
'rejected': item['rejected_response']
})
return full_version, direct_training_version
    def save_datasets(self, output_dir, full_version, direct_training_version):
"""
Save the full and direct training versions of the DPO dataset to JSON files.
:param output_dir: Directory to save the output JSON files.
:param full_version: Full version of the dataset.
:param direct_training_version: Direct training version of the dataset.
"""
os.makedirs(output_dir, exist_ok=True)
with open(os.path.join(output_dir, 'dpo_full.json'), 'w', encoding='utf-8') as file:
json.dump(full_version, file, ensure_ascii=False, indent=4)
with open(os.path.join(output_dir, 'dpo_direct.json'), 'w', encoding='utf-8') as file:
json.dump(direct_training_version, file, ensure_ascii=False, indent=4)
def run(self):
"""
Main function to orchestrate the workflow.
"""
# Load and sample data, combine system prompt for each task.
sampled_data = self.load_and_sample_data()
# Generate traces for all cases
all_traces = self.generate_all_traces(sampled_data)
# Compare eval -> get chosen and rejected responses
compare_res = self.compare_eval(all_traces)
# Prepare DPO datasets
full_version, direct_training_version = self.prepare_dpo_datasets(compare_res)
# Save datasets
        self.save_datasets(self.output_dir, full_version, direct_training_version)
print(f"Sampled data saved to {self.output_dir}")
# Example usage
if __name__ == "__main__":
input_path = 'resources/L2/data/merged.json'
output_dir = 'resources/L2/data/dpo/'
    dpo_data = DPOData(input_path, output_dir, preference_language)
dpo_data.run()