import logging
from typing import Callable, List, Optional, Tuple
from concurrent.futures import ThreadPoolExecutor, as_completed

from transformers import AutoTokenizer, PreTrainedTokenizerBase

from .config import LongCepoConfig

logger = logging.getLogger(__name__)
class CBLog(dict):
    """Object for logging the number of LLM calls and tokens used in the pipeline"""

    __allowed_keys__ = {"total_tokens", "completion_tokens", "llm_calls"}

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.update(*args, **kwargs)

    def __setitem__(self, key, value):
        if key not in self.__allowed_keys__:
            raise KeyError(
                f"Key '{key}' not allowed. Allowed keys: {self.__allowed_keys__}"
            )
        if not isinstance(value, int):
            raise TypeError(
                f"Value for '{key}' must be int, got {type(value).__name__}"
            )
        super().__setitem__(key, value)
    def update(self, other=None, **kwargs):
        updates = {}
        if other:
            if isinstance(other, dict):
                updates.update(other)
            else:
                updates.update(dict(other))
        updates.update(kwargs)
        for key, value in updates.items():
            if key not in self.__allowed_keys__:
                raise KeyError(
                    f"Key '{key}' not allowed. Allowed keys: {self.__allowed_keys__}"
                )
            if not isinstance(value, int):
                raise TypeError(
                    f"Value for '{key}' must be int, got {type(value).__name__}"
                )
            # Unlike dict.update, counts accumulate additively across updates
            self[key] = self.get(key, 0) + value
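
# Illustrative usage (a sketch, not part of the pipeline): CBLog accumulates
# counts additively, so repeated updates sum per key rather than overwrite.
#
#   log = CBLog(llm_calls=1, total_tokens=120, completion_tokens=40)
#   log.update(llm_calls=1, total_tokens=80, completion_tokens=30)
#   assert log == {"llm_calls": 2, "total_tokens": 200, "completion_tokens": 70}
#
# Unknown keys or non-int values raise KeyError/TypeError at update time.
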
def concurrent_map(
    gen_function: Callable,
    client,
    model: str,
    context_chunks: List[str],
    query: str,
    system_prompt: str,
    cb_log: CBLog,
    summaries_per_chunk: Optional[List[str]] = None,
    workers: int = 16,
) -> Tuple[List[str], CBLog]:
    """
    Runs `gen_function` concurrently over a list of context chunks.

    Args:
        gen_function (Callable): Function to call with each chunk and associated arguments.
        client: LLM API client.
        model (str): Base model name.
        context_chunks (List[str]): Input context chunks.
        query (str): User query.
        system_prompt (str): System prompt string.
        cb_log (CBLog): Log object for tracking model calls.
        summaries_per_chunk (Optional[List[str]]): Concatenated neighbor summaries for each chunk.
        workers (int): Number of threads to use.

    Returns:
        Tuple[List[str], CBLog]: List of responses (in original order) and updated log object.
    """
    result = [None] * len(context_chunks)

    def wrapped_gen_function(index, *args):
        # Tag each result with its chunk index so responses can be restored to input order
        return index, gen_function(*args)

    with ThreadPoolExecutor(max_workers=workers) as executor:
        future_to_idx = {}
        for idx, chunk in enumerate(context_chunks):
            args = [client, model, chunk, query, system_prompt]
            if summaries_per_chunk is not None:
                args.append(summaries_per_chunk[idx])
            future_to_idx[executor.submit(wrapped_gen_function, idx, *args)] = idx
        for future in as_completed(future_to_idx):
            try:
                index, (response, upd_log) = future.result()
                result[index] = response
                cb_log.update(upd_log)
            except Exception as e:
                logger.error(f"Error processing chunk: {e}")
    return result, cb_log
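
# Illustrative usage (a sketch; `summarize_chunk` is a hypothetical gen_function
# that takes (client, model, chunk, query, system_prompt) and returns a
# (response, CBLog) pair, as helpers built on `get_prompt_response` do):
#
#   responses, cb_log = concurrent_map(
#       summarize_chunk, client, "my-model", chunks, query, system_prompt, CBLog()
#   )
#
# Responses come back in the original chunk order; a failed chunk leaves None
# at its position and is logged rather than aborting the whole map.
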
def get_prompt_response(
    client,
    model: str,
    prompt: str,
    system_prompt: str,
    max_tokens: int,
    temperature: float = 0.7,
    top_p: float = 0.7,
) -> Tuple[str, CBLog]:
    """
    Helper function that sends a prompt to the chat-based LLM API and returns the
    generated response along with usage logging.

    Args:
        client: LLM API client.
        model (str): Base model name.
        prompt (str): The user prompt to send.
        system_prompt (str): System prompt string.
        max_tokens (int): Maximum number of tokens in the response.
        temperature (float): Sampling temperature for randomness (default: 0.7).
        top_p (float): Cumulative probability cutoff for token selection (default: 0.7).

    Returns:
        Tuple[str, CBLog]: The model's response text and a CBLog object tracking token usage.
    """
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
    ]
    logger.debug(f"Sending request with max_tokens={max_tokens}, messages={messages}")
    response = client.chat.completions.create(
        model=model,
        messages=messages,
        max_tokens=max_tokens,
        top_p=top_p,
        temperature=temperature,
        stream=False,
    )
    logger.debug(f"Received response: {response}")
    upd_log = CBLog(
        llm_calls=1,
        total_tokens=response.usage.total_tokens,
        completion_tokens=response.usage.completion_tokens,
    )
    return response.choices[0].message.content, upd_log
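
# Illustrative usage (a sketch, assuming an OpenAI-compatible client object;
# "my-model" is a placeholder model name):
#
#   from openai import OpenAI
#
#   client = OpenAI()
#   answer, upd_log = get_prompt_response(
#       client, "my-model", "Summarize the chunk below...",
#       "You are a helpful assistant.", max_tokens=512,
#   )
#   cb_log.update(upd_log)  # fold per-call usage into the pipeline-level log
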
def loop_until_match(
    function: Callable, pattern_list: Tuple[str, ...], num_attempts: int = 10
):
    """
    Repeatedly calls a function until its output matches one of the given patterns
    or the maximum number of attempts is reached.

    Args:
        function (Callable): Function returning (answer: str, cb_log).
        pattern_list (Tuple[str, ...]): Patterns to match in the answer.
        num_attempts (int): Max number of attempts (default: 10).

    Returns:
        Tuple[str, Any]: The matching answer (or the last attempt if none matched)
            and its corresponding log object.
    """
    for _ in range(num_attempts):
        answer, cb_log = function()
        if any(pattern in answer for pattern in pattern_list):
            break
        logger.info("Wrong output formatting, retrying...")
    return answer, cb_log
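
# Illustrative usage (a sketch; `partial` binds all arguments so `function`
# becomes zero-argument, matching the expected call signature):
#
#   from functools import partial
#
#   answer, upd_log = loop_until_match(
#       partial(get_prompt_response, client, "my-model", prompt, system_prompt, 512),
#       pattern_list=("FINAL ANSWER:",),
#   )
#
# Note: only the log object of the final attempt is returned; usage from
# discarded retries is not accumulated by this helper.
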
def longcepo_init(
    initial_query: str,
) -> Tuple[str, str, PreTrainedTokenizerBase, CBLog, LongCepoConfig]:
    """
    Initializes context, query, tokenizer, logging, and config from an input string.

    Args:
        initial_query (str): Input string containing context and query separated by a delimiter string.

    Returns:
        Tuple[str, str, PreTrainedTokenizerBase, CBLog, LongCepoConfig]:
            Parsed context, query, tokenizer instance, log object, and LongCePO config.
    """
    cb_log = CBLog()
    config = LongCepoConfig()
    context, query = initial_query.split(config.context_query_delimiter)
    tokenizer = AutoTokenizer.from_pretrained(
        config.tokenizer_name, model_max_length=config.max_context_window
    )
    return context.strip(), query.strip(), tokenizer, cb_log, config
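
# Illustrative usage (a sketch; the actual delimiter comes from LongCepoConfig,
# shown here as a hypothetical "<|query|>" for concreteness):
#
#   context, query, tokenizer, cb_log, config = longcepo_init(
#       "...long document text...<|query|>What does section 2 claim?"
#   )
#
# Note: the input must contain the delimiter exactly once; zero or multiple
# occurrences make the tuple unpacking of `split` raise a ValueError.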