ChuxiJ committed on
Commit
2d5e35a
·
1 Parent(s): 091cc88

feat: update api

Browse files
Files changed (3) hide show
  1. acestep/api_server.py +323 -215
  2. acestep/local_cache.py +129 -0
  3. pyproject.toml +1 -0
acestep/api_server.py CHANGED
@@ -1,9 +1,12 @@
1
  """FastAPI server for ACE-Step V1.5.
2
 
3
  Endpoints:
4
- - POST /v1/music/generate Create an async music generation job (queued)
5
- - Supports application/json and multipart/form-data (with file upload)
6
- - GET /v1/jobs/{job_id} Poll job status/result (+ queue position/eta when queued)
 
 
 
7
 
8
  NOTE:
9
  - In-memory queue and job store -> run uvicorn with workers=1.
@@ -25,7 +28,7 @@ from contextlib import asynccontextmanager
25
  from dataclasses import dataclass
26
  from pathlib import Path
27
  from threading import Lock
28
- from typing import Any, Dict, Literal, Optional
29
  from uuid import uuid4
30
 
31
  try:
@@ -54,6 +57,46 @@ from acestep.inference import (
54
  from acestep.gradio_ui.events.results_handlers import _build_generation_info
55
 
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  def _parse_description_hints(description: str) -> tuple[Optional[str], bool]:
58
  """
59
  Parse a description string to extract language code and instrumental flag.
@@ -129,7 +172,7 @@ JobStatus = Literal["queued", "running", "succeeded", "failed"]
129
 
130
 
131
  class GenerateMusicRequest(BaseModel):
132
- caption: str = Field(default="", description="Text caption describing the music")
133
  lyrics: str = Field(default="", description="Lyric text")
134
 
135
  # New API semantics:
@@ -147,7 +190,7 @@ class GenerateMusicRequest(BaseModel):
147
  model: Optional[str] = Field(default=None, description="Model name to use (e.g., 'acestep-v15-turbo')")
148
 
149
  bpm: Optional[int] = None
150
- # Accept common client keys via manual parsing (see _build_req_from_mapping).
151
  key_scale: str = ""
152
  time_signature: str = ""
153
  vocal_language: str = "en"
@@ -208,15 +251,8 @@ class GenerateMusicRequest(BaseModel):
208
  allow_population_by_alias = True
209
 
210
 
211
- _LM_DEFAULT_TEMPERATURE = 0.85
212
- _LM_DEFAULT_CFG_SCALE = 2.5
213
- _LM_DEFAULT_TOP_P = 0.9
214
- _DEFAULT_DIT_INSTRUCTION = DEFAULT_DIT_INSTRUCTION
215
- _DEFAULT_LM_INSTRUCTION = DEFAULT_LM_INSTRUCTION
216
-
217
-
218
  class CreateJobResponse(BaseModel):
219
- job_id: str
220
  status: JobStatus
221
  queue_position: int = 0 # 1-based best-effort position when queued
222
 
@@ -267,6 +303,7 @@ class _JobRecord:
267
  finished_at: Optional[float] = None
268
  result: Optional[Dict[str, Any]] = None
269
  error: Optional[str] = None
 
270
 
271
 
272
  class _JobStore:
@@ -281,6 +318,18 @@ class _JobStore:
281
  self._jobs[job_id] = rec
282
  return rec
283
 
 
 
 
 
 
 
 
 
 
 
 
 
284
  def get(self, job_id: str) -> Optional[_JobRecord]:
285
  with self._lock:
286
  return self._jobs.get(job_id)
@@ -391,6 +440,70 @@ def _to_bool(v: Any, default: bool = False) -> bool:
391
  return s in {"1", "true", "yes", "y", "on"}
392
 
393
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
394
  async def _save_upload_to_temp(upload: StarletteUploadFile, *, prefix: str) -> str:
395
  suffix = Path(upload.filename or "").suffix
396
  fd, path = tempfile.mkstemp(prefix=f"{prefix}_", suffix=suffix)
@@ -420,13 +533,13 @@ def create_app() -> FastAPI:
420
  store = _JobStore()
421
 
422
  QUEUE_MAXSIZE = int(os.getenv("ACESTEP_QUEUE_MAXSIZE", "200"))
423
- WORKER_COUNT = int(os.getenv("ACESTEP_QUEUE_WORKERS", "1")) # GPU 建议 1
424
 
425
  INITIAL_AVG_JOB_SECONDS = float(os.getenv("ACESTEP_AVG_JOB_SECONDS", "5.0"))
426
  AVG_WINDOW = int(os.getenv("ACESTEP_AVG_WINDOW", "50"))
427
 
428
  def _path_to_audio_url(path: str) -> str:
429
- """将本地文件路径转换为可下载的相对 URL"""
430
  if not path:
431
  return path
432
  if path.startswith("http://") or path.startswith("https://"):
@@ -525,6 +638,14 @@ def create_app() -> FastAPI:
525
  app.state.temp_audio_dir = os.path.join(tmp_root, "api_audio")
526
  os.makedirs(app.state.temp_audio_dir, exist_ok=True)
527
 
 
 
 
 
 
 
 
 
528
  async def _ensure_initialized() -> None:
529
  h: AceStepHandler = app.state.handler
530
 
@@ -613,6 +734,33 @@ def create_app() -> FastAPI:
613
  except Exception:
614
  pass
615
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
616
  async def _run_one_job(job_id: str, req: GenerateMusicRequest) -> None:
617
  job_store: _JobStore = app.state.job_store
618
  llm: LLMHandler = app.state.llm_handler
@@ -728,10 +876,7 @@ def create_app() -> FastAPI:
728
  # - use_format (LM enhances caption/lyrics)
729
  # - use_cot_caption or use_cot_language (LM enhances metadata)
730
  need_llm = thinking or sample_mode or has_sample_query or use_format or use_cot_caption or use_cot_language
731
-
732
- print(f"[api_server] Request params: req.thinking={req.thinking}, req.sample_mode={req.sample_mode}, req.use_cot_caption={req.use_cot_caption}, req.use_cot_language={req.use_cot_language}, req.use_format={req.use_format}")
733
- print(f"[api_server] Determined: thinking={thinking}, sample_mode={sample_mode}, use_cot_caption={use_cot_caption}, use_cot_language={use_cot_language}, use_format={use_format}, need_llm={need_llm}")
734
-
735
  # Ensure LLM is ready if needed
736
  if need_llm:
737
  _ensure_llm_ready()
@@ -739,7 +884,7 @@ def create_app() -> FastAPI:
739
  raise RuntimeError(f"5Hz LM init failed: {app.state._llm_init_error}")
740
 
741
  # Handle sample mode or description: generate caption/lyrics/metas via LM
742
- caption = req.caption
743
  lyrics = req.lyrics
744
  bpm = req.bpm
745
  key_scale = req.key_scale
@@ -749,26 +894,17 @@ def create_app() -> FastAPI:
749
  if sample_mode or has_sample_query:
750
  if has_sample_query:
751
  # Use create_sample() with description query
752
- print(f"[api_server] Description mode: generating sample from query: {req.sample_query[:100]}")
753
-
754
- # Parse description for language and instrumental hints (aligned with feishu_bot)
755
  parsed_language, parsed_instrumental = _parse_description_hints(req.sample_query)
756
- print(f"[api_server] Parsed from description: language={parsed_language}, instrumental={parsed_instrumental}")
757
-
758
  # Determine vocal_language with priority:
759
- # 1. User-specified vocal_language (if not default "en") - highest priority
760
  # 2. Language parsed from description
761
  # 3. None (no constraint)
762
  if req.vocal_language and req.vocal_language not in ("en", "unknown", ""):
763
- # User explicitly specified a non-default language, use it
764
  sample_language = req.vocal_language
765
- print(f"[api_server] Using user-specified vocal_language: {sample_language}")
766
  else:
767
- # Fall back to language parsed from description
768
  sample_language = parsed_language
769
- if sample_language:
770
- print(f"[api_server] Using language from description: {sample_language}")
771
-
772
  sample_result = create_sample(
773
  llm_handler=llm,
774
  query=req.sample_query,
@@ -790,11 +926,8 @@ def create_app() -> FastAPI:
790
  key_scale = sample_result.keyscale
791
  time_signature = sample_result.timesignature
792
  audio_duration = sample_result.duration
793
-
794
- print(f"[api_server] Sample from description generated: caption_len={len(caption)}, lyrics_len={len(lyrics)}, bpm={bpm}")
795
  else:
796
  # Original sample_mode behavior: random generation
797
- print("[api_server] Sample mode: generating random caption/lyrics via LM")
798
  sample_metadata, sample_status = llm.understand_audio_from_codes(
799
  audio_codes="NO USER INPUT",
800
  temperature=req.lm_temperature,
@@ -815,15 +948,11 @@ def create_app() -> FastAPI:
815
  key_scale = sample_metadata.get("keyscale", "") or os.getenv("ACESTEP_SAMPLE_DEFAULT_KEY", "C Major")
816
  time_signature = sample_metadata.get("timesignature", "") or os.getenv("ACESTEP_SAMPLE_DEFAULT_TIMESIGNATURE", "4/4")
817
  audio_duration = _to_float(sample_metadata.get("duration"), None) or _to_float(os.getenv("ACESTEP_SAMPLE_DEFAULT_DURATION_SECONDS", "120"), 120.0)
818
-
819
- print(f"[api_server] Sample generated: caption_len={len(caption)}, lyrics_len={len(lyrics)}, bpm={bpm}, duration={audio_duration}")
820
-
821
  # Apply format_sample() if use_format is True and caption/lyrics are provided
822
- # Track whether format_sample generated duration (to decide if Phase 1 is needed)
823
  format_has_duration = False
824
-
825
  if req.use_format and (caption or lyrics):
826
- print(f"[api_server] Applying format_sample to enhance input...")
827
  _ensure_llm_ready()
828
  if getattr(app.state, "_llm_init_error", None):
829
  raise RuntimeError(f"5Hz LM init failed (needed for format): {app.state._llm_init_error}")
@@ -865,33 +994,10 @@ def create_app() -> FastAPI:
865
  key_scale = format_result.keyscale
866
  if format_result.timesignature:
867
  time_signature = format_result.timesignature
868
-
869
- print(f"[api_server] Format applied: new caption_len={len(caption)}, lyrics_len={len(lyrics)}, bpm={bpm}, duration={audio_duration}, has_duration={format_has_duration}")
870
- else:
871
- print(f"[api_server] Warning: format_sample failed: {format_result.error}, using original input")
872
-
873
- print(f"[api_server] Before GenerationParams: thinking={thinking}, sample_mode={sample_mode}")
874
- # Parse timesteps string to list of floats if provided
875
- parsed_timesteps = None
876
- if req.timesteps and req.timesteps.strip():
877
- try:
878
- parsed_timesteps = [float(t.strip()) for t in req.timesteps.split(",") if t.strip()]
879
- except ValueError:
880
- print(f"[api_server] Warning: Failed to parse timesteps '{req.timesteps}', using default")
881
- parsed_timesteps = None
882
 
883
- print(f"[api_server] Caption/Lyrics to use: caption_len={len(caption)}, lyrics_len={len(lyrics)}")
 
884
 
885
- # Parse timesteps if provided
886
- parsed_timesteps = None
887
- if req.timesteps:
888
- try:
889
- parsed_timesteps = [float(t.strip()) for t in req.timesteps.split(",") if t.strip()]
890
- print(f"[api_server] Using custom timesteps: {parsed_timesteps}")
891
- except Exception as e:
892
- print(f"[api_server] Warning: Failed to parse timesteps '{req.timesteps}': {e}")
893
- parsed_timesteps = None
894
-
895
  # Determine actual inference steps (timesteps override inference_steps)
896
  actual_inference_steps = len(parsed_timesteps) if parsed_timesteps else req.inference_steps
897
 
@@ -960,19 +1066,6 @@ def create_app() -> FastAPI:
960
  # Check LLM initialization status
961
  llm_is_initialized = getattr(app.state, "_llm_initialized", False)
962
  llm_to_pass = llm if llm_is_initialized else None
963
-
964
- print(f"[api_server] Generating music with unified interface:")
965
- print(f" - thinking={params.thinking}")
966
- print(f" - use_cot_caption={params.use_cot_caption}")
967
- print(f" - use_cot_language={params.use_cot_language}")
968
- print(f" - use_cot_metas={params.use_cot_metas}")
969
- print(f" - batch_size={batch_size}")
970
- print(f" - llm_initialized={llm_is_initialized}")
971
- print(f" - llm_handler={'Available' if llm_to_pass else 'None'}")
972
- if llm_to_pass:
973
- print(f" - LLM will be used for: CoT caption={params.use_cot_caption}, CoT language={params.use_cot_language}, CoT metas={params.use_cot_metas}, thinking={params.thinking}")
974
- else:
975
- print(f" - WARNING: LLM features requested but LLM not initialized!")
976
 
977
  # Generate music using unified interface
978
  result = generate_music(
@@ -983,9 +1076,6 @@ def create_app() -> FastAPI:
983
  save_dir=app.state.temp_audio_dir,
984
  progress=None,
985
  )
986
-
987
- print(f"[api_server] Generation completed. Success={result.success}, Audios={len(result.audios)}")
988
- print(f"[api_server] Time costs keys: {list(result.extra_outputs.get('time_costs', {}).keys())}")
989
 
990
  if not result.success:
991
  raise RuntimeError(f"Music generation failed: {result.error or result.status_message}")
@@ -1080,8 +1170,14 @@ def create_app() -> FastAPI:
1080
  loop = asyncio.get_running_loop()
1081
  result = await loop.run_in_executor(executor, _blocking_generate)
1082
  job_store.mark_succeeded(job_id, result)
 
 
 
1083
  except Exception:
1084
  job_store.mark_failed(job_id, traceback.format_exc())
 
 
 
1085
  finally:
1086
  dt = max(0.0, time.time() - t0)
1087
  async with app.state.stats_lock:
@@ -1131,122 +1227,71 @@ def create_app() -> FastAPI:
1131
  avg = float(getattr(app.state, "avg_job_seconds", INITIAL_AVG_JOB_SECONDS))
1132
  return pos * avg
1133
 
1134
- @app.post("/v1/music/generate", response_model=CreateJobResponse)
1135
  async def create_music_generate_job(request: Request) -> CreateJobResponse:
1136
  content_type = (request.headers.get("content-type") or "").lower()
1137
  temp_files: list[str] = []
1138
 
1139
- def _build_req_from_mapping(mapping: Any, *, reference_audio_path: Optional[str], src_audio_path: Optional[str]) -> GenerateMusicRequest:
1140
- get = getattr(mapping, "get", None)
1141
- if not callable(get):
1142
- raise HTTPException(status_code=400, detail="Invalid request payload")
1143
-
1144
- def _get_any(*keys: str, default: Any = None) -> Any:
1145
- # 1) Top-level keys
1146
- for k in keys:
1147
- v = get(k, None)
1148
- if v is not None:
1149
- return v
1150
-
1151
- # 2) Nested metas/metadata/user_metadata (dict or JSON string)
1152
- nested = (
1153
- get("metas", None)
1154
- or get("meta", None)
1155
- or get("metadata", None)
1156
- or get("user_metadata", None)
1157
- or get("userMetadata", None)
1158
- )
1159
-
1160
- if isinstance(nested, str):
1161
- s = nested.strip()
1162
- if s.startswith("{") and s.endswith("}"):
1163
- try:
1164
- nested = json.loads(s)
1165
- except Exception:
1166
- nested = None
1167
-
1168
- if isinstance(nested, dict):
1169
- g2 = nested.get
1170
- for k in keys:
1171
- v = g2(k, None)
1172
- if v is not None:
1173
- return v
1174
-
1175
- return default
1176
-
1177
- normalized_audio_duration = _to_float(_get_any("audio_duration", "duration", "audioDuration"), None)
1178
- normalized_bpm = _to_int(_get_any("bpm"), None)
1179
- normalized_keyscale = str(_get_any("key_scale", "keyscale", "keyScale", default="") or "")
1180
- normalized_timesig = str(_get_any("time_signature", "timesignature", "timeSignature", default="") or "")
1181
-
1182
- # Accept it as an alias to avoid clients needing to special-case server.
1183
- if normalized_audio_duration is None:
1184
- normalized_audio_duration = _to_float(_get_any("target_duration", "targetDuration"), None)
1185
-
1186
  return GenerateMusicRequest(
1187
- caption=str(get("caption", "") or ""),
1188
- lyrics=str(get("lyrics", "") or ""),
1189
- thinking=_to_bool(get("thinking"), False),
1190
- sample_mode=_to_bool(_get_any("sample_mode", "sampleMode"), False),
1191
- sample_query=str(_get_any("sample_query", "sampleQuery", "description", "desc", default="") or ""),
1192
- use_format=_to_bool(_get_any("use_format", "useFormat", "format"), False),
1193
- model=str(_get_any("model", "dit_model", "ditModel", default="") or "").strip() or None,
1194
- bpm=normalized_bpm,
1195
- key_scale=normalized_keyscale,
1196
- time_signature=normalized_timesig,
1197
- vocal_language=str(_get_any("vocal_language", "vocalLanguage", default="en") or "en"),
1198
- inference_steps=_to_int(_get_any("inference_steps", "inferenceSteps"), 8) or 8,
1199
- guidance_scale=_to_float(_get_any("guidance_scale", "guidanceScale"), 7.0) or 7.0,
1200
- use_random_seed=_to_bool(_get_any("use_random_seed", "useRandomSeed"), True),
1201
- seed=_to_int(get("seed"), -1) or -1,
1202
- reference_audio_path=reference_audio_path,
1203
- src_audio_path=src_audio_path,
1204
- audio_duration=normalized_audio_duration,
1205
- batch_size=_to_int(get("batch_size"), None),
1206
- audio_code_string=str(_get_any("audio_code_string", "audioCodeString", default="") or ""),
1207
- repainting_start=_to_float(get("repainting_start"), 0.0) or 0.0,
1208
- repainting_end=_to_float(get("repainting_end"), None),
1209
- instruction=str(get("instruction", _DEFAULT_DIT_INSTRUCTION) or ""),
1210
- audio_cover_strength=_to_float(_get_any("audio_cover_strength", "audioCoverStrength"), 1.0) or 1.0,
1211
- task_type=str(_get_any("task_type", "taskType", default="text2music") or "text2music"),
1212
- use_adg=_to_bool(get("use_adg"), False),
1213
- cfg_interval_start=_to_float(get("cfg_interval_start"), 0.0) or 0.0,
1214
- cfg_interval_end=_to_float(get("cfg_interval_end"), 1.0) or 1.0,
1215
- infer_method=str(_get_any("infer_method", "inferMethod", default="ode") or "ode"),
1216
- shift=_to_float(_get_any("shift"), 3.0) or 3.0,
1217
- audio_format=str(get("audio_format", "mp3") or "mp3"),
1218
- use_tiled_decode=_to_bool(_get_any("use_tiled_decode", "useTiledDecode"), True),
1219
- lm_model_path=str(get("lm_model_path") or "").strip() or None,
1220
- lm_backend=str(get("lm_backend", "vllm") or "vllm"),
1221
- lm_temperature=_to_float(get("lm_temperature"), _LM_DEFAULT_TEMPERATURE) or _LM_DEFAULT_TEMPERATURE,
1222
- lm_cfg_scale=_to_float(get("lm_cfg_scale"), _LM_DEFAULT_CFG_SCALE) or _LM_DEFAULT_CFG_SCALE,
1223
- lm_top_k=_to_int(get("lm_top_k"), None),
1224
- lm_top_p=_to_float(get("lm_top_p"), _LM_DEFAULT_TOP_P),
1225
- lm_repetition_penalty=_to_float(get("lm_repetition_penalty"), 1.0) or 1.0,
1226
- lm_negative_prompt=str(get("lm_negative_prompt", "NO USER INPUT") or "NO USER INPUT"),
1227
- constrained_decoding=_to_bool(_get_any("constrained_decoding", "constrainedDecoding", "constrained"), True),
1228
- constrained_decoding_debug=_to_bool(_get_any("constrained_decoding_debug", "constrainedDecodingDebug"), False),
1229
- use_cot_caption=_to_bool(_get_any("use_cot_caption", "cot_caption", "cot-caption"), True),
1230
- use_cot_language=_to_bool(_get_any("use_cot_language", "cot_language", "cot-language"), True),
1231
- is_format_caption=_to_bool(_get_any("is_format_caption", "isFormatCaption"), False),
1232
  )
1233
 
1234
- def _first_value(v: Any) -> Any:
1235
- if isinstance(v, list) and v:
1236
- return v[0]
1237
- return v
1238
-
1239
  if content_type.startswith("application/json"):
1240
  body = await request.json()
1241
  if not isinstance(body, dict):
1242
  raise HTTPException(status_code=400, detail="JSON payload must be an object")
1243
- req = _build_req_from_mapping(body, reference_audio_path=None, src_audio_path=None)
1244
 
1245
  elif content_type.endswith("+json"):
1246
  body = await request.json()
1247
  if not isinstance(body, dict):
1248
  raise HTTPException(status_code=400, detail="JSON payload must be an object")
1249
- req = _build_req_from_mapping(body, reference_audio_path=None, src_audio_path=None)
1250
 
1251
  elif content_type.startswith("multipart/form-data"):
1252
  form = await request.form()
@@ -1269,13 +1314,21 @@ def create_app() -> FastAPI:
1269
  else:
1270
  src_audio_path = str(form.get("src_audio_path") or "").strip() or None
1271
 
1272
- req = _build_req_from_mapping(form, reference_audio_path=reference_audio_path, src_audio_path=src_audio_path)
 
 
 
 
1273
 
1274
  elif content_type.startswith("application/x-www-form-urlencoded"):
1275
  form = await request.form()
1276
  reference_audio_path = str(form.get("reference_audio_path") or "").strip() or None
1277
  src_audio_path = str(form.get("src_audio_path") or "").strip() or None
1278
- req = _build_req_from_mapping(form, reference_audio_path=reference_audio_path, src_audio_path=src_audio_path)
 
 
 
 
1279
 
1280
  else:
1281
  raw = await request.body()
@@ -1285,7 +1338,7 @@ def create_app() -> FastAPI:
1285
  try:
1286
  body = json.loads(raw.decode("utf-8"))
1287
  if isinstance(body, dict):
1288
- req = _build_req_from_mapping(body, reference_audio_path=None, src_audio_path=None)
1289
  else:
1290
  raise HTTPException(status_code=400, detail="JSON payload must be an object")
1291
  except HTTPException:
@@ -1298,10 +1351,14 @@ def create_app() -> FastAPI:
1298
  # Best-effort: parse key=value bodies even if Content-Type is missing.
1299
  elif raw_stripped and b"=" in raw:
1300
  parsed = urllib.parse.parse_qs(raw.decode("utf-8"), keep_blank_values=True)
1301
- flat = {k: _first_value(v) for k, v in parsed.items()}
1302
  reference_audio_path = str(flat.get("reference_audio_path") or "").strip() or None
1303
  src_audio_path = str(flat.get("src_audio_path") or "").strip() or None
1304
- req = _build_req_from_mapping(flat, reference_audio_path=reference_audio_path, src_audio_path=src_audio_path)
 
 
 
 
1305
  else:
1306
  raise HTTPException(
1307
  status_code=415,
@@ -1331,7 +1388,7 @@ def create_app() -> FastAPI:
1331
  position = len(app.state.pending_ids)
1332
 
1333
  await q.put((rec.job_id, req))
1334
- return CreateJobResponse(job_id=rec.job_id, status="queued", queue_position=position)
1335
 
1336
  @app.post("/v1/music/random", response_model=CreateJobResponse)
1337
  async def create_random_sample_job(request: Request) -> CreateJobResponse:
@@ -1375,35 +1432,86 @@ def create_app() -> FastAPI:
1375
  position = len(app.state.pending_ids)
1376
 
1377
  await q.put((rec.job_id, req))
1378
- return CreateJobResponse(job_id=rec.job_id, status="queued", queue_position=position)
1379
 
1380
- @app.get("/v1/jobs/{job_id}", response_model=JobResponse)
1381
- async def get_job(job_id: str) -> JobResponse:
1382
- rec = store.get(job_id)
1383
- if rec is None:
1384
- raise HTTPException(status_code=404, detail="Job not found")
1385
 
1386
- pos = 0
1387
- eta = None
1388
- async with app.state.stats_lock:
1389
- avg = float(getattr(app.state, "avg_job_seconds", INITIAL_AVG_JOB_SECONDS))
 
1390
 
1391
- if rec.status == "queued":
1392
- pos = await _queue_position(job_id)
1393
- eta = await _eta_seconds_for_position(pos)
1394
-
1395
- return JobResponse(
1396
- job_id=rec.job_id,
1397
- status=rec.status,
1398
- created_at=rec.created_at,
1399
- started_at=rec.started_at,
1400
- finished_at=rec.finished_at,
1401
- queue_position=pos,
1402
- eta_seconds=eta,
1403
- avg_job_seconds=avg,
1404
- result=JobResult(**rec.result) if rec.result else None,
1405
- error=rec.error,
1406
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1407
 
1408
  @app.get("/health")
1409
  async def health_check():
 
1
  """FastAPI server for ACE-Step V1.5.
2
 
3
  Endpoints:
4
+ - POST /release_task Create music generation task
5
+ - POST /query_result Batch query task results
6
+ - POST /v1/music/random Create random sample task
7
+ - GET /v1/models List available models
8
+ - GET /v1/audio Download audio file
9
+ - GET /health Health check
10
 
11
  NOTE:
12
  - In-memory queue and job store -> run uvicorn with workers=1.
 
28
  from dataclasses import dataclass
29
  from pathlib import Path
30
  from threading import Lock
31
+ from typing import Any, Dict, List, Literal, Optional
32
  from uuid import uuid4
33
 
34
  try:
 
57
  from acestep.gradio_ui.events.results_handlers import _build_generation_info
58
 
59
 
60
+ # =============================================================================
61
+ # Constants
62
+ # =============================================================================
63
+
64
+ RESULT_KEY_PREFIX = "ace_step_v1.5_"
65
+ RESULT_EXPIRE_SECONDS = 7 * 24 * 60 * 60 # 7 days
66
+ TASK_TIMEOUT_SECONDS = 3600 # 1 hour
67
+ STATUS_MAP = {"queued": 0, "running": 0, "succeeded": 1, "failed": 2}
68
+
69
+ LM_DEFAULT_TEMPERATURE = 0.85
70
+ LM_DEFAULT_CFG_SCALE = 2.5
71
+ LM_DEFAULT_TOP_P = 0.9
72
+
73
+ # Parameter aliases for request parsing
74
+ PARAM_ALIASES = {
75
+ "prompt": ["prompt"],
76
+ "sample_mode": ["sample_mode", "sampleMode"],
77
+ "sample_query": ["sample_query", "sampleQuery", "description", "desc"],
78
+ "use_format": ["use_format", "useFormat", "format"],
79
+ "model": ["model", "dit_model", "ditModel"],
80
+ "key_scale": ["key_scale", "keyscale", "keyScale"],
81
+ "time_signature": ["time_signature", "timesignature", "timeSignature"],
82
+ "audio_duration": ["audio_duration", "duration", "audioDuration", "target_duration", "targetDuration"],
83
+ "vocal_language": ["vocal_language", "vocalLanguage"],
84
+ "inference_steps": ["inference_steps", "inferenceSteps"],
85
+ "guidance_scale": ["guidance_scale", "guidanceScale"],
86
+ "use_random_seed": ["use_random_seed", "useRandomSeed"],
87
+ "audio_code_string": ["audio_code_string", "audioCodeString"],
88
+ "audio_cover_strength": ["audio_cover_strength", "audioCoverStrength"],
89
+ "task_type": ["task_type", "taskType"],
90
+ "infer_method": ["infer_method", "inferMethod"],
91
+ "use_tiled_decode": ["use_tiled_decode", "useTiledDecode"],
92
+ "constrained_decoding": ["constrained_decoding", "constrainedDecoding", "constrained"],
93
+ "constrained_decoding_debug": ["constrained_decoding_debug", "constrainedDecodingDebug"],
94
+ "use_cot_caption": ["use_cot_caption", "cot_caption", "cot-caption"],
95
+ "use_cot_language": ["use_cot_language", "cot_language", "cot-language"],
96
+ "is_format_caption": ["is_format_caption", "isFormatCaption"],
97
+ }
98
+
99
+
100
  def _parse_description_hints(description: str) -> tuple[Optional[str], bool]:
101
  """
102
  Parse a description string to extract language code and instrumental flag.
 
172
 
173
 
174
  class GenerateMusicRequest(BaseModel):
175
+ prompt: str = Field(default="", description="Text prompt describing the music")
176
  lyrics: str = Field(default="", description="Lyric text")
177
 
178
  # New API semantics:
 
190
  model: Optional[str] = Field(default=None, description="Model name to use (e.g., 'acestep-v15-turbo')")
191
 
192
  bpm: Optional[int] = None
193
+ # Accept common client keys via manual parsing (see RequestParser).
194
  key_scale: str = ""
195
  time_signature: str = ""
196
  vocal_language: str = "en"
 
251
  allow_population_by_alias = True
252
 
253
 
 
 
 
 
 
 
 
254
  class CreateJobResponse(BaseModel):
255
+ task_id: str
256
  status: JobStatus
257
  queue_position: int = 0 # 1-based best-effort position when queued
258
 
 
303
  finished_at: Optional[float] = None
304
  result: Optional[Dict[str, Any]] = None
305
  error: Optional[str] = None
306
+ env: str = "development"
307
 
308
 
309
  class _JobStore:
 
318
  self._jobs[job_id] = rec
319
  return rec
320
 
321
+ def create_with_id(self, job_id: str, env: str = "development") -> _JobRecord:
322
+ """Create job record with specified ID"""
323
+ rec = _JobRecord(
324
+ job_id=job_id,
325
+ status="queued",
326
+ created_at=time.time(),
327
+ env=env
328
+ )
329
+ with self._lock:
330
+ self._jobs[job_id] = rec
331
+ return rec
332
+
333
  def get(self, job_id: str) -> Optional[_JobRecord]:
334
  with self._lock:
335
  return self._jobs.get(job_id)
 
440
  return s in {"1", "true", "yes", "y", "on"}
441
 
442
 
443
def _map_status(status: str) -> int:
    """Translate a textual job status into its numeric wire code.

    Any status not present in STATUS_MAP is reported as 2 (the "failed"
    code), which is the conservative fallback.
    """
    try:
        return STATUS_MAP[status]
    except KeyError:
        return 2
446
+
447
+
448
+ def _parse_timesteps(s: Optional[str]) -> Optional[List[float]]:
449
+ """Parse comma-separated timesteps string to list of floats."""
450
+ if not s or not s.strip():
451
+ return None
452
+ try:
453
+ return [float(t.strip()) for t in s.split(",") if t.strip()]
454
+ except (ValueError, Exception):
455
+ return None
456
+
457
+
458
class RequestParser:
    """Parse request parameters from multiple sources with alias support.

    Lookup order for a canonical parameter name is: the raw top-level
    mapping, then the JSON-decoded ``param_obj`` field, then the first
    present nested metas mapping (``metas``/``meta``/``metadata``/
    ``user_metadata``/``userMetadata``). Each source is probed with every
    alias registered in PARAM_ALIASES.
    """

    def __init__(self, raw: Optional[dict]):
        # Copy defensively so later mutations by the caller don't leak in.
        self._raw = dict(raw) if raw else {}
        self._param_obj = self._parse_json(self._raw.get("param_obj"))
        self._metas = self._find_metas()

    def _parse_json(self, v) -> dict:
        """Coerce a dict or JSON-object string into a dict; {} on failure.

        BUG FIX: the original returned ``json.loads(v)`` unconditionally, so
        a JSON array/number/string (e.g. ``"[1, 2]"``) leaked through the
        ``-> dict`` contract and later crashed ``get()`` with an
        AttributeError on ``.get``. Only JSON *objects* are accepted now.
        """
        if isinstance(v, dict):
            return v
        if isinstance(v, str) and v.strip():
            try:
                parsed = json.loads(v)
                if isinstance(parsed, dict):
                    return parsed
            except Exception:
                pass
        return {}

    def _find_metas(self) -> dict:
        """Return the first present nested metas mapping, decoded; {} if none."""
        for key in ("metas", "meta", "metadata", "user_metadata", "userMetadata"):
            v = self._raw.get(key)
            if v:
                # NOTE: only the first truthy candidate is tried; if it fails
                # to decode we fall back to {} rather than probing later keys
                # (preserves the original lookup behavior).
                return self._parse_json(v)
        return {}

    def get(self, name: str, default=None):
        """Get parameter by canonical name from all sources (first non-None wins)."""
        aliases = PARAM_ALIASES.get(name, [name])
        for source in (self._raw, self._param_obj, self._metas):
            for alias in aliases:
                v = source.get(alias)
                if v is not None:
                    return v
        return default

    # The typed accessors below intentionally shadow builtin names; this is
    # the established public interface (parser.str(...), parser.int(...)).

    def str(self, name: str, default: str = "") -> str:
        v = self.get(name)
        return str(v) if v is not None else default

    def int(self, name: str, default: Optional[int] = None) -> Optional[int]:
        return _to_int(self.get(name), default)

    def float(self, name: str, default: Optional[float] = None) -> Optional[float]:
        return _to_float(self.get(name), default)

    def bool(self, name: str, default: bool = False) -> bool:
        return _to_bool(self.get(name), default)
505
+
506
+
507
  async def _save_upload_to_temp(upload: StarletteUploadFile, *, prefix: str) -> str:
508
  suffix = Path(upload.filename or "").suffix
509
  fd, path = tempfile.mkstemp(prefix=f"{prefix}_", suffix=suffix)
 
533
  store = _JobStore()
534
 
535
  QUEUE_MAXSIZE = int(os.getenv("ACESTEP_QUEUE_MAXSIZE", "200"))
536
+ WORKER_COUNT = int(os.getenv("ACESTEP_QUEUE_WORKERS", "1")) # Single GPU recommended
537
 
538
  INITIAL_AVG_JOB_SECONDS = float(os.getenv("ACESTEP_AVG_JOB_SECONDS", "5.0"))
539
  AVG_WINDOW = int(os.getenv("ACESTEP_AVG_WINDOW", "50"))
540
 
541
  def _path_to_audio_url(path: str) -> str:
542
+ """Convert local file path to downloadable relative URL"""
543
  if not path:
544
  return path
545
  if path.startswith("http://") or path.startswith("https://"):
 
638
  app.state.temp_audio_dir = os.path.join(tmp_root, "api_audio")
639
  os.makedirs(app.state.temp_audio_dir, exist_ok=True)
640
 
641
+ # Initialize local cache
642
+ try:
643
+ from acestep.local_cache import get_local_cache
644
+ local_cache_dir = os.path.join(cache_root, "local_redis")
645
+ app.state.local_cache = get_local_cache(local_cache_dir)
646
+ except ImportError:
647
+ app.state.local_cache = None
648
+
649
  async def _ensure_initialized() -> None:
650
  h: AceStepHandler = app.state.handler
651
 
 
734
  except Exception:
735
  pass
736
 
737
+ def _update_local_cache(job_id: str, result: Optional[Dict], status: str) -> None:
738
+ """Update local cache with job result"""
739
+ local_cache = getattr(app.state, 'local_cache', None)
740
+ if not local_cache:
741
+ return
742
+
743
+ rec = store.get(job_id)
744
+ env = getattr(rec, 'env', 'development') if rec else 'development'
745
+ create_time = rec.created_at if rec else time.time()
746
+
747
+ status_int = _map_status(status)
748
+
749
+ if status == "succeeded" and result:
750
+ audio_paths = result.get("audio_paths", [])
751
+ if audio_paths:
752
+ result_data = [
753
+ {"file": p, "wave": "", "status": status_int, "create_time": int(create_time), "env": env}
754
+ for p in audio_paths
755
+ ]
756
+ else:
757
+ result_data = [{"file": "", "wave": "", "status": status_int, "create_time": int(create_time), "env": env}]
758
+ else:
759
+ result_data = [{"file": "", "wave": "", "status": status_int, "create_time": int(create_time), "env": env}]
760
+
761
+ result_key = f"{RESULT_KEY_PREFIX}{job_id}"
762
+ local_cache.set(result_key, result_data, ex=RESULT_EXPIRE_SECONDS)
763
+
764
  async def _run_one_job(job_id: str, req: GenerateMusicRequest) -> None:
765
  job_store: _JobStore = app.state.job_store
766
  llm: LLMHandler = app.state.llm_handler
 
876
  # - use_format (LM enhances caption/lyrics)
877
  # - use_cot_caption or use_cot_language (LM enhances metadata)
878
  need_llm = thinking or sample_mode or has_sample_query or use_format or use_cot_caption or use_cot_language
879
+
 
 
 
880
  # Ensure LLM is ready if needed
881
  if need_llm:
882
  _ensure_llm_ready()
 
884
  raise RuntimeError(f"5Hz LM init failed: {app.state._llm_init_error}")
885
 
886
  # Handle sample mode or description: generate caption/lyrics/metas via LM
887
+ caption = req.prompt
888
  lyrics = req.lyrics
889
  bpm = req.bpm
890
  key_scale = req.key_scale
 
894
  if sample_mode or has_sample_query:
895
  if has_sample_query:
896
  # Use create_sample() with description query
 
 
 
897
  parsed_language, parsed_instrumental = _parse_description_hints(req.sample_query)
898
+
 
899
  # Determine vocal_language with priority:
900
+ # 1. User-specified vocal_language (if not default "en")
901
  # 2. Language parsed from description
902
  # 3. None (no constraint)
903
  if req.vocal_language and req.vocal_language not in ("en", "unknown", ""):
 
904
  sample_language = req.vocal_language
 
905
  else:
 
906
  sample_language = parsed_language
907
+
 
 
908
  sample_result = create_sample(
909
  llm_handler=llm,
910
  query=req.sample_query,
 
926
  key_scale = sample_result.keyscale
927
  time_signature = sample_result.timesignature
928
  audio_duration = sample_result.duration
 
 
929
  else:
930
  # Original sample_mode behavior: random generation
 
931
  sample_metadata, sample_status = llm.understand_audio_from_codes(
932
  audio_codes="NO USER INPUT",
933
  temperature=req.lm_temperature,
 
948
  key_scale = sample_metadata.get("keyscale", "") or os.getenv("ACESTEP_SAMPLE_DEFAULT_KEY", "C Major")
949
  time_signature = sample_metadata.get("timesignature", "") or os.getenv("ACESTEP_SAMPLE_DEFAULT_TIMESIGNATURE", "4/4")
950
  audio_duration = _to_float(sample_metadata.get("duration"), None) or _to_float(os.getenv("ACESTEP_SAMPLE_DEFAULT_DURATION_SECONDS", "120"), 120.0)
951
+
 
 
952
  # Apply format_sample() if use_format is True and caption/lyrics are provided
 
953
  format_has_duration = False
954
+
955
  if req.use_format and (caption or lyrics):
 
956
  _ensure_llm_ready()
957
  if getattr(app.state, "_llm_init_error", None):
958
  raise RuntimeError(f"5Hz LM init failed (needed for format): {app.state._llm_init_error}")
 
994
  key_scale = format_result.keyscale
995
  if format_result.timesignature:
996
  time_signature = format_result.timesignature
 
 
 
 
 
 
 
 
 
 
 
 
 
 
997
 
998
+ # Parse timesteps string to list of floats if provided
999
+ parsed_timesteps = _parse_timesteps(req.timesteps)
1000
 
 
 
 
 
 
 
 
 
 
 
1001
  # Determine actual inference steps (timesteps override inference_steps)
1002
  actual_inference_steps = len(parsed_timesteps) if parsed_timesteps else req.inference_steps
1003
 
 
1066
  # Check LLM initialization status
1067
  llm_is_initialized = getattr(app.state, "_llm_initialized", False)
1068
  llm_to_pass = llm if llm_is_initialized else None
 
 
 
 
 
 
 
 
 
 
 
 
 
1069
 
1070
  # Generate music using unified interface
1071
  result = generate_music(
 
1076
  save_dir=app.state.temp_audio_dir,
1077
  progress=None,
1078
  )
 
 
 
1079
 
1080
  if not result.success:
1081
  raise RuntimeError(f"Music generation failed: {result.error or result.status_message}")
 
1170
  loop = asyncio.get_running_loop()
1171
  result = await loop.run_in_executor(executor, _blocking_generate)
1172
  job_store.mark_succeeded(job_id, result)
1173
+
1174
+ # Update local cache
1175
+ _update_local_cache(job_id, result, "succeeded")
1176
  except Exception:
1177
  job_store.mark_failed(job_id, traceback.format_exc())
1178
+
1179
+ # Update local cache
1180
+ _update_local_cache(job_id, None, "failed")
1181
  finally:
1182
  dt = max(0.0, time.time() - t0)
1183
  async with app.state.stats_lock:
 
1227
  avg = float(getattr(app.state, "avg_job_seconds", INITIAL_AVG_JOB_SECONDS))
1228
  return pos * avg
1229
 
1230
+ @app.post("/release_task", response_model=CreateJobResponse)
1231
  async def create_music_generate_job(request: Request) -> CreateJobResponse:
1232
  content_type = (request.headers.get("content-type") or "").lower()
1233
  temp_files: list[str] = []
1234
 
1235
def _build_request(p: RequestParser, **kwargs) -> GenerateMusicRequest:
    """Build GenerateMusicRequest from parsed parameters.

    ``p`` wraps the raw request mapping (JSON body or form fields) and
    coerces each field to the expected type, applying the default when
    the field is missing. Extra ``kwargs`` (e.g. reference_audio_path /
    src_audio_path injected by the multipart and urlencoded handlers)
    are forwarded verbatim to the model constructor.
    """
    return GenerateMusicRequest(
        # Core text inputs
        prompt=p.str("prompt"),
        lyrics=p.str("lyrics"),
        # LM-assisted generation switches
        thinking=p.bool("thinking"),
        sample_mode=p.bool("sample_mode"),
        sample_query=p.str("sample_query"),
        use_format=p.bool("use_format"),
        model=p.str("model") or None,  # empty string -> None (use server default)
        # Musical metadata
        bpm=p.int("bpm"),
        key_scale=p.str("key_scale"),
        time_signature=p.str("time_signature"),
        audio_duration=p.float("audio_duration"),
        vocal_language=p.str("vocal_language", "en"),
        # Diffusion sampling parameters
        inference_steps=p.int("inference_steps", 8),
        guidance_scale=p.float("guidance_scale", 7.0),
        use_random_seed=p.bool("use_random_seed", True),
        seed=p.int("seed", -1),
        batch_size=p.int("batch_size"),
        audio_code_string=p.str("audio_code_string"),
        # Repainting / editing controls
        repainting_start=p.float("repainting_start", 0.0),
        repainting_end=p.float("repainting_end"),
        instruction=p.str("instruction", DEFAULT_DIT_INSTRUCTION),
        audio_cover_strength=p.float("audio_cover_strength", 1.0),
        task_type=p.str("task_type", "text2music"),
        use_adg=p.bool("use_adg"),
        cfg_interval_start=p.float("cfg_interval_start", 0.0),
        cfg_interval_end=p.float("cfg_interval_end", 1.0),
        infer_method=p.str("infer_method", "ode"),
        shift=p.float("shift", 3.0),
        audio_format=p.str("audio_format", "mp3"),
        use_tiled_decode=p.bool("use_tiled_decode", True),
        # 5Hz language-model (LM) parameters
        lm_model_path=p.str("lm_model_path") or None,
        lm_backend=p.str("lm_backend", "vllm"),
        lm_temperature=p.float("lm_temperature", LM_DEFAULT_TEMPERATURE),
        lm_cfg_scale=p.float("lm_cfg_scale", LM_DEFAULT_CFG_SCALE),
        lm_top_k=p.int("lm_top_k"),
        lm_top_p=p.float("lm_top_p", LM_DEFAULT_TOP_P),
        lm_repetition_penalty=p.float("lm_repetition_penalty", 1.0),
        lm_negative_prompt=p.str("lm_negative_prompt", "NO USER INPUT"),
        constrained_decoding=p.bool("constrained_decoding", True),
        constrained_decoding_debug=p.bool("constrained_decoding_debug"),
        use_cot_caption=p.bool("use_cot_caption", True),
        use_cot_language=p.bool("use_cot_language", True),
        is_format_caption=p.bool("is_format_caption"),
        # File-path extras injected by the multipart/urlencoded handlers
        **kwargs,
    )
1283
 
 
 
 
 
 
1284
  if content_type.startswith("application/json"):
1285
  body = await request.json()
1286
  if not isinstance(body, dict):
1287
  raise HTTPException(status_code=400, detail="JSON payload must be an object")
1288
+ req = _build_request(RequestParser(body))
1289
 
1290
  elif content_type.endswith("+json"):
1291
  body = await request.json()
1292
  if not isinstance(body, dict):
1293
  raise HTTPException(status_code=400, detail="JSON payload must be an object")
1294
+ req = _build_request(RequestParser(body))
1295
 
1296
  elif content_type.startswith("multipart/form-data"):
1297
  form = await request.form()
 
1314
  else:
1315
  src_audio_path = str(form.get("src_audio_path") or "").strip() or None
1316
 
1317
+ req = _build_request(
1318
+ RequestParser(dict(form)),
1319
+ reference_audio_path=reference_audio_path,
1320
+ src_audio_path=src_audio_path,
1321
+ )
1322
 
1323
  elif content_type.startswith("application/x-www-form-urlencoded"):
1324
  form = await request.form()
1325
  reference_audio_path = str(form.get("reference_audio_path") or "").strip() or None
1326
  src_audio_path = str(form.get("src_audio_path") or "").strip() or None
1327
+ req = _build_request(
1328
+ RequestParser(dict(form)),
1329
+ reference_audio_path=reference_audio_path,
1330
+ src_audio_path=src_audio_path,
1331
+ )
1332
 
1333
  else:
1334
  raw = await request.body()
 
1338
  try:
1339
  body = json.loads(raw.decode("utf-8"))
1340
  if isinstance(body, dict):
1341
+ req = _build_request(RequestParser(body))
1342
  else:
1343
  raise HTTPException(status_code=400, detail="JSON payload must be an object")
1344
  except HTTPException:
 
1351
  # Best-effort: parse key=value bodies even if Content-Type is missing.
1352
  elif raw_stripped and b"=" in raw:
1353
  parsed = urllib.parse.parse_qs(raw.decode("utf-8"), keep_blank_values=True)
1354
+ flat = {k: (v[0] if isinstance(v, list) and v else v) for k, v in parsed.items()}
1355
  reference_audio_path = str(flat.get("reference_audio_path") or "").strip() or None
1356
  src_audio_path = str(flat.get("src_audio_path") or "").strip() or None
1357
+ req = _build_request(
1358
+ RequestParser(flat),
1359
+ reference_audio_path=reference_audio_path,
1360
+ src_audio_path=src_audio_path,
1361
+ )
1362
  else:
1363
  raise HTTPException(
1364
  status_code=415,
 
1388
  position = len(app.state.pending_ids)
1389
 
1390
  await q.put((rec.job_id, req))
1391
+ return CreateJobResponse(task_id=rec.job_id, status="queued", queue_position=position)
1392
 
1393
  @app.post("/v1/music/random", response_model=CreateJobResponse)
1394
  async def create_random_sample_job(request: Request) -> CreateJobResponse:
 
1432
  position = len(app.state.pending_ids)
1433
 
1434
  await q.put((rec.job_id, req))
1435
+ return CreateJobResponse(task_id=rec.job_id, status="queued", queue_position=position)
1436
 
1437
@app.post("/query_result")
async def query_result(request: Request) -> List[Dict[str, Any]]:
    """Batch-query job results by task id.

    Accepts a JSON or form body with a ``task_id_list`` field (either a
    JSON array or a JSON-encoded string). For each id the local cache is
    consulted first; ids missing from the cache fall back to the
    in-memory job store. Returns one ``{task_id, result, status}`` entry
    per requested id.
    """
    content_type = (request.headers.get("content-type") or "").lower()

    if "json" in content_type:
        body = await request.json()
    else:
        form = await request.form()
        body = {k: v for k, v in form.items()}
    if not isinstance(body, dict):
        # Robustness: a JSON array/scalar body would otherwise crash on .get().
        body = {}

    task_id_list_str = body.get("task_id_list", "[]")

    # Parse task ID list (may arrive as a real list or a JSON-encoded string)
    if isinstance(task_id_list_str, list):
        task_id_list = task_id_list_str
    else:
        try:
            task_id_list = json.loads(task_id_list_str)
        except Exception:
            task_id_list = []
    if not isinstance(task_id_list, list):
        task_id_list = []

    local_cache = getattr(app.state, 'local_cache', None)
    data_list = []
    current_time = time.time()

    for task_id in task_id_list:
        result_key = f"{RESULT_KEY_PREFIX}{task_id}"

        # Read from local cache first
        if local_cache:
            data = local_cache.get(result_key)
            if data:
                try:
                    data_json = json.loads(data)
                except Exception:
                    data_json = []
                if not isinstance(data_json, list):
                    # Robustness: malformed cache payloads must not 500 the batch.
                    data_json = []

                if not data_json:
                    data_list.append({"task_id": task_id, "result": data, "status": 2})
                else:
                    status = data_json[0].get("status")
                    create_time = data_json[0].get("create_time", 0)
                    # A job still pending (status 0) past the timeout is reported failed (2).
                    if status == 0 and (current_time - create_time) > TASK_TIMEOUT_SECONDS:
                        data_list.append({"task_id": task_id, "result": data, "status": 2})
                    else:
                        data_list.append({
                            "task_id": task_id,
                            "result": data,
                            "status": int(status) if status is not None else 1,
                        })
                # NOTE(review): cache miss intentionally falls through to the
                # job-store lookup below -- confirm this matches deployment intent.
                continue

        # Fallback to job_store query
        rec = store.get(task_id)
        if rec:
            env = getattr(rec, 'env', 'development')
            create_time = rec.created_at
            status_int = _map_status(rec.status)

            if rec.result and rec.status == "succeeded":
                audio_paths = rec.result.get("audio_paths", [])
                result_data = [
                    {"file": p, "wave": "", "status": status_int, "create_time": int(create_time), "env": env}
                    for p in audio_paths
                ] if audio_paths else [{"file": "", "wave": "", "status": status_int, "create_time": int(create_time), "env": env}]
            else:
                result_data = [{"file": "", "wave": "", "status": status_int, "create_time": int(create_time), "env": env}]

            data_list.append({
                "task_id": task_id,
                "result": json.dumps(result_data, ensure_ascii=False),
                "status": status_int,
            })
        else:
            # Unknown task id: report "queued/unknown" (0) with an empty result.
            data_list.append({"task_id": task_id, "result": "[]", "status": 0})

    return data_list
1515
 
1516
  @app.get("/health")
1517
  async def health_check():
acestep/local_cache.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Local cache module to replace Redis
2
+
3
+ Uses diskcache as backend, provides Redis-compatible API.
4
+ Supports persistent storage and TTL expiration.
5
+ """
6
+
7
import json
import os
import time
from threading import Lock
from typing import Any, Optional
11
+
12
+ try:
13
+ from diskcache import Cache
14
+ HAS_DISKCACHE = True
15
+ except ImportError:
16
+ HAS_DISKCACHE = False
17
+
18
+
19
class LocalCache:
    """
    Local cache implementation with a Redis-compatible API.

    Uses diskcache as backend, supports persistence and TTL.

    NOTE(review): this is a process-wide singleton -- the cache directory
    is fixed by the first construction and later ``cache_dir`` arguments
    are silently ignored. Confirm that is the intended contract.
    """

    _instance = None
    _lock = Lock()

    def __new__(cls, cache_dir: Optional[str] = None):
        """Singleton (double-checked locking): one shared instance per process."""
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
                    cls._instance._initialized = False
        return cls._instance

    def __init__(self, cache_dir: Optional[str] = None):
        # __init__ runs on every LocalCache(...) call; only the first one
        # is allowed to configure the backend.
        if getattr(self, '_initialized', False):
            return

        if not HAS_DISKCACHE:
            raise ImportError(
                "diskcache not installed. Run: pip install diskcache"
            )

        if cache_dir is None:
            # Default: <repo-root>/.cache/local_redis next to this package.
            cache_dir = os.path.join(
                os.path.dirname(os.path.dirname(__file__)),
                ".cache",
                "local_redis"
            )

        os.makedirs(cache_dir, exist_ok=True)
        self._cache = Cache(cache_dir)
        self._initialized = True

    def set(self, name: str, value: Any, ex: Optional[int] = None) -> bool:
        """
        Set key-value pair

        Args:
            name: Key name
            value: Value (dict/list are auto-serialized to JSON text)
            ex: Expiration time (seconds); None means no expiry

        Returns:
            bool: Success status
        """
        if isinstance(value, (dict, list)):
            value = json.dumps(value, ensure_ascii=False)
        self._cache.set(name, value, expire=ex)
        return True

    def get(self, name: str) -> Optional[Any]:
        """Get value (JSON-serialized values come back as their stored string)."""
        # Annotation widened from str: set() stores non-dict/list values as-is.
        return self._cache.get(name)

    def delete(self, name: str) -> int:
        """Delete key, returns number of deleted items (0 or 1)."""
        return 1 if self._cache.delete(name) else 0

    def exists(self, name: str) -> bool:
        """Check if key exists"""
        return name in self._cache

    def keys(self, pattern: str = "*") -> list:
        """
        Get list of matching keys
        Note: Simplified implementation, only supports prefix ("abc*") and
        full ("*") matching
        """
        if pattern == "*":
            return list(self._cache.iterkeys())

        prefix = pattern.rstrip("*")
        return [k for k in self._cache.iterkeys() if k.startswith(prefix)]

    def expire(self, name: str, seconds: int) -> bool:
        """Set key expiration time without rewriting the stored value."""
        # diskcache.Cache.touch updates the expiry in place, avoiding the
        # read-modify-write (and its race window) of the old get()+set().
        return bool(self._cache.touch(name, expire=seconds))

    def ttl(self, name: str) -> int:
        """
        Get remaining time to live (seconds), Redis-style.

        Returns:
            -2 if the key does not exist, -1 if it exists without an
            expiry, otherwise the remaining whole seconds (>= 0).
        """
        # diskcache exposes the absolute expiry via get(..., expire_time=True),
        # so a real TTL can be computed instead of always returning -1.
        value, expire_at = self._cache.get(name, default=None, expire_time=True)
        if value is None and name not in self._cache:
            return -2
        if expire_at is None:
            return -1
        return max(0, int(expire_at - time.time()))

    def close(self):
        """Close cache connection"""
        if hasattr(self, '_cache'):
            self._cache.close()
118
+
119
+
120
# Lazily initialized global instance
_local_cache: Optional[LocalCache] = None


def get_local_cache(cache_dir: Optional[str] = None) -> LocalCache:
    """Return the process-wide LocalCache, constructing it on first call.

    ``cache_dir`` is only honored by the very first call; afterwards the
    already-built instance is returned unchanged.
    """
    global _local_cache
    if _local_cache is not None:
        return _local_cache
    _local_cache = LocalCache(cache_dir)
    return _local_cache
pyproject.toml CHANGED
@@ -25,6 +25,7 @@ dependencies = [
25
  "einops>=0.8.1",
26
  "accelerate>=1.12.0",
27
  "fastapi>=0.110.0",
 
28
  "uvicorn[standard]>=0.27.0",
29
  "numba>=0.63.1",
30
  "vector-quantize-pytorch>=1.27.15",
 
25
  "einops>=0.8.1",
26
  "accelerate>=1.12.0",
27
  "fastapi>=0.110.0",
28
+ "diskcache",
29
  "uvicorn[standard]>=0.27.0",
30
  "numba>=0.63.1",
31
  "vector-quantize-pytorch>=1.27.15",