import modal
import logging

app = modal.App("qwen-reranker-vllm")

# Persistent volumes for Hugging Face model weights and the vLLM cache.
hf_cache_vol = modal.Volume.from_name("mcp-datascientist-model-weights-vol")
vllm_cache_vol = modal.Volume.from_name("vllm-cache")

MINUTES = 60  # seconds
vllm_image = (
    modal.Image.debian_slim(python_version="3.12")
    .pip_install(
        "vllm==0.8.5",
        "transformers",
        "torch",
        "fastapi[all]",
        "pydantic",
        "hf_transfer",  # needed for HF_HUB_ENABLE_HF_TRANSFER=1 below
    )
    .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
)
with vllm_image.imports():
    from transformers import AutoTokenizer
    from vllm import LLM, SamplingParams
    from vllm.inputs.data import TokensPrompt
    import torch
    import math
# The GPU type and cache mount paths below are assumptions; adjust them to
# your deployment. The volumes keep model weights warm across cold starts.
@app.cls(
    image=vllm_image,
    gpu="A100",  # assumption: any GPU with enough memory for the 4B model works
    volumes={
        "/root/.cache/huggingface": hf_cache_vol,  # assumed HF cache mount point
        "/root/.cache/vllm": vllm_cache_vol,       # assumed vLLM cache mount point
    },
    scaledown_window=15 * MINUTES,  # assumption: keep containers warm between bursts
)
class Reranker:
    @modal.enter()  # load weights once per container, not once per request
    def load_reranker(self):
        logging.info("loading the reranker model")
        self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Reranker-4B")
        self.tokenizer.padding_side = "left"
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = LLM(
            model="Qwen/Qwen3-Reranker-4B",
            tensor_parallel_size=torch.cuda.device_count(),
            max_model_len=10000,
            enable_prefix_caching=True,
            gpu_memory_utilization=0.8,
        )
        # Suffix that closes the user turn and opens an empty assistant "think"
        # block, so the very next generated token is the yes/no judgment.
        self.suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
        self.suffix_tokens = self.tokenizer.encode(self.suffix, add_special_tokens=False)
        self.max_length = 8192
        self.true_token = self.tokenizer("yes", add_special_tokens=False).input_ids[0]
        self.false_token = self.tokenizer("no", add_special_tokens=False).input_ids[0]
        # Greedy single-token decoding, restricted to the "yes"/"no" token ids.
        self.sampling_params = SamplingParams(
            temperature=0,
            max_tokens=1,
            logprobs=20,
            allowed_token_ids=[self.true_token, self.false_token],
        )
    def format_instruction(self, instruction, query, doc):
        return [
            {"role": "system", "content": "Judge whether the Table will be useful to create a SQL request to answer the Query. Note that the answer can only be \"yes\" or \"no\""},
            {"role": "user", "content": f"<Instruct>: {instruction}\n\n<Query>: {query}\n\n<Document>: {doc}"},
        ]
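
    # Illustrative shape of one formatted pair (hypothetical values, not from
    # the original snippet):
    #   system: Judge whether the Table will be useful to create a SQL request ...
    #   user:   <Instruct>: Decide if the table helps answer the question
    #           <Query>: Total revenue per customer
    #           <Document>: Table orders(order_id, customer_id, amount, created_at)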
    def process_inputs(self, pairs, instruction):
        messages = [self.format_instruction(instruction, query, doc) for query, doc in pairs]
        messages = self.tokenizer.apply_chat_template(
            messages, tokenize=True, add_generation_prompt=False, enable_thinking=False
        )
        # Truncate each prompt, then append the fixed assistant suffix so the
        # model's next token is the yes/no verdict.
        messages = [ele[: self.max_length] + self.suffix_tokens for ele in messages]
        messages = [TokensPrompt(prompt_token_ids=ele) for ele in messages]
        return messages
    def compute_logits(self, messages):
        outputs = self.model.generate(messages, self.sampling_params, use_tqdm=False)
        scores = []
        for output in outputs:
            final_logits = output.outputs[0].logprobs[-1]
            # Fall back to a large negative logprob when a token did not make
            # the top-k logprobs returned by vLLM.
            if self.true_token not in final_logits:
                true_logit = -10
            else:
                true_logit = final_logits[self.true_token].logprob
            if self.false_token not in final_logits:
                false_logit = -10
            else:
                false_logit = final_logits[self.false_token].logprob
            # Softmax over the two candidate tokens: P(yes) / (P(yes) + P(no)).
            true_score = math.exp(true_logit)
            false_score = math.exp(false_logit)
            score = true_score / (true_score + false_score)
            scores.append(score)
        return scores
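
    # Note: with only two outcomes compared, the score above equals
    # sigmoid(true_logit - false_logit). For example, logprobs of -0.1 ("yes")
    # and -2.4 ("no") give exp(-0.1) / (exp(-0.1) + exp(-2.4)) ≈ 0.91.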
    @modal.method()  # exposed so callers can invoke Reranker().rerank.remote(...)
    def rerank(self, query, documents, task):
        # e.g. task = "Given a web search query, retrieve relevant passages that answer the query"
        pairs = [(query, doc) for doc in documents]
        inputs = self.process_inputs(pairs, task)
        scores = self.compute_logits(inputs)
        return [{"score": float(score), "content": doc} for score, doc in zip(scores, documents)]
# Serve the FastAPI app on Modal. The decorators and route path are a minimal
# sketch of the wiring the snippet implies; adjust them to your deployment.
@app.function(image=vllm_image)
@modal.asgi_app()
def fastapi_app():
    from pydantic import BaseModel
    from fastapi import FastAPI
    from fastapi.responses import JSONResponse
    from typing import List

    web_app = FastAPI()
    reranker = Reranker()

    # Response item shape, kept for documentation; the endpoint returns raw JSON.
    class ScoringResult(BaseModel):
        score: float
        content: str

    class RankingRequest(BaseModel):
        task: str
        query: str
        documents: List[str]

    @web_app.post("/predict")  # assumption: the original route path is not shown
    async def predict(payload: RankingRequest):
        logging.info("calling the rank function")
        query = payload.query
        documents = payload.documents
        task = payload.task
        # Dispatch to the GPU class; .remote() runs rerank in the Reranker container.
        output_data = reranker.rerank.remote(query, documents, task)
        return JSONResponse(content=output_data)

    return web_app
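
# --- Usage sketch (not part of the original snippet) ---
# A minimal way to smoke-test the reranker without going through HTTP:
# `modal run <this_file>.py` invokes the local entrypoint below, which calls
# the GPU class directly. The task, query, and documents are illustrative only.
@app.local_entrypoint()
def main():
    task = "Given a user question, decide whether the table helps write the SQL query"
    query = "Total revenue per customer in 2024"
    documents = [
        "Table orders(order_id, customer_id, amount, created_at)",
        "Table employees(employee_id, name, hire_date)",
    ]
    results = Reranker().rerank.remote(query, documents, task)
    for item in sorted(results, key=lambda r: r["score"], reverse=True):
        print(f"{item['score']:.3f}  {item['content']}")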