Files
ViGent2/models/Qwen3-TTS/qwen_tts/inference/qwen3_tts_model.py
Kevin Wong 4a3dd2b225 更新
2026-01-28 17:22:31 +08:00

878 lines
36 KiB
Python

# coding=utf-8
# Copyright 2026 The Alibaba Qwen team.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import io
import urllib.request
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
from urllib.parse import urlparse
import librosa
import numpy as np
import soundfile as sf
import torch
from transformers import AutoConfig, AutoModel, AutoProcessor
from ..core.models import Qwen3TTSConfig, Qwen3TTSForConditionalGeneration, Qwen3TTSProcessor
# Accepted audio input forms for every public API in this module.
AudioLike = Union[
    str,  # wav path, URL, or base64-encoded audio string
    np.ndarray,  # raw waveform (requires sr; must be passed as a tuple, see below)
    Tuple[np.ndarray, int],  # (waveform, sr)
]
# A value that may be passed either as a scalar or as a list of scalars.
MaybeList = Union[Any, List[Any]]
@dataclass
class VoiceClonePromptItem:
    """
    Container for one sample's voice-clone prompt information that can be fed to the model.
    Fields are aligned with `Qwen3TTSForConditionalGeneration.generate(..., voice_clone_prompt=...)`.
    """
    # Reference speech codec codes; None when x_vector_only_mode is used.
    ref_code: Optional[torch.Tensor]  # (T, Q) or (T,) depending on tokenizer 25Hz/12Hz
    # Speaker embedding extracted from the reference audio.
    ref_spk_embedding: torch.Tensor  # (D,)
    # True -> clone from the speaker embedding only; ref_code/ref_text are ignored.
    x_vector_only_mode: bool
    # True -> in-context-learning mode: condition on ref_text + ref_code.
    icl_mode: bool
    # Transcript of the reference audio; required for ICL mode.
    ref_text: Optional[str] = None
class Qwen3TTSModel:
    """
    A HuggingFace-style wrapper for Qwen3 TTS models (CustomVoice/VoiceDesign/Base) that provides:
    - from_pretrained() initialization via AutoModel/AutoProcessor
    - generation APIs for:
        * CustomVoice: generate_custom_voice()
        * VoiceDesign: generate_voice_design()
        * Base: generate_voice_clone() + create_voice_clone_prompt()
    - consistent output: (wavs: List[np.ndarray], sample_rate: int)

    Notes:
    - This wrapper expects the underlying model class to be `Qwen3TTSForConditionalGeneration`.
    - Language / speaker validation is done via model methods:
      model.get_supported_languages(), model.get_supported_speakers()
    """
def __init__(self, model: Qwen3TTSForConditionalGeneration, processor, generate_defaults: Optional[Dict[str, Any]] = None):
    """
    Wrap a loaded Qwen3 TTS model together with its processor.

    Args:
        model: The underlying conditional-generation TTS model.
        processor: Tokenizer/feature processor paired with the model.
        generate_defaults: Optional default kwargs applied to generate() calls.
    """
    self.model = model
    self.processor = processor
    self.generate_defaults = generate_defaults if generate_defaults else {}
    # Resolve the working device: prefer an explicit `model.device`, then the
    # first parameter's device, and finally CPU for parameter-less models.
    device = getattr(model, "device", None)
    if device is None:
        try:
            device = next(model.parameters()).device
        except StopIteration:
            device = torch.device("cpu")
    self.device = device
@classmethod
def from_pretrained(
    cls,
    pretrained_model_name_or_path: str,
    **kwargs,
) -> "Qwen3TTSModel":
    """
    Load a Qwen3 TTS model and its processor in HuggingFace `from_pretrained` style.

    This method:
    1) Registers the custom config/model/processor classes with the Auto* factories.
    2) Loads the model via AutoModel.from_pretrained(...), forwarding `kwargs` unchanged.
    3) Loads the processor via AutoProcessor.from_pretrained(model_path).
    4) Picks up generation defaults from the model's `generate_config`.

    Args:
        pretrained_model_name_or_path (str):
            HuggingFace repo id or local directory of the model.
        **kwargs:
            Forwarded as-is into `AutoModel.from_pretrained(...)`.
            Typical examples: device_map="cuda:0", dtype=torch.bfloat16, attn_implementation="flash_attention_2".

    Returns:
        Qwen3TTSModel:
            Wrapper instance containing `model`, `processor`, and generation defaults.

    Raises:
        TypeError: If AutoModel resolves to a different model class.
    """
    # Register the custom classes so AutoModel/AutoProcessor can resolve the
    # "qwen3_tts" model_type.
    # NOTE(review): register() runs on every call; transformers may reject a
    # duplicate registration — confirm idempotency across repeated calls.
    AutoConfig.register("qwen3_tts", Qwen3TTSConfig)
    AutoModel.register(Qwen3TTSConfig, Qwen3TTSForConditionalGeneration)
    AutoProcessor.register(Qwen3TTSConfig, Qwen3TTSProcessor)
    model = AutoModel.from_pretrained(pretrained_model_name_or_path, **kwargs)
    if not isinstance(model, Qwen3TTSForConditionalGeneration):
        raise TypeError(
            f"AutoModel returned {type(model)}, expected Qwen3TTSForConditionalGeneration. "
        )
    # presumably `fix_mistral_regex` works around a tokenizer regex issue — TODO confirm
    processor = AutoProcessor.from_pretrained(pretrained_model_name_or_path, fix_mistral_regex=True,)
    # Generation defaults come from the model's generate_config (loaded from
    # generate_config.json when present).
    generate_defaults = model.generate_config
    return cls(model=model, processor=processor, generate_defaults=generate_defaults)
def _supported_languages_set(self) -> Optional[set]:
    """Return the model's supported languages as a lowercase set, or None if unconstrained."""
    getter = getattr(self.model, "get_supported_languages", None)
    if not callable(getter):
        return None
    values = getter()
    if values is None:
        return None
    return {str(item).lower() for item in values}
def _supported_speakers_set(self) -> Optional[set]:
    """Return the model's supported speakers as a lowercase set, or None if unconstrained."""
    getter = getattr(self.model, "get_supported_speakers", None)
    if not callable(getter):
        return None
    values = getter()
    if values is None:
        return None
    return {str(item).lower() for item in values}
def _validate_languages(self, languages: List[str]) -> None:
    """
    Ensure every requested language is supported by the model.

    Args:
        languages (List[str]): Language name per sample (matched case-insensitively).

    Raises:
        ValueError: If any entry is None or not in the supported set.
    """
    supported = self._supported_languages_set()
    if supported is None:
        # The model exposes no language constraints; accept anything.
        return
    bad = [lang for lang in languages
           if lang is None or str(lang).lower() not in supported]
    if bad:
        raise ValueError(f"Unsupported languages: {bad}. Supported: {sorted(supported)}")
def _validate_speakers(self, speakers: List[Optional[str]]) -> None:
    """
    Ensure every requested speaker is supported by the Instruct model.

    Empty entries (None or "") are skipped — they mean "no speaker requested".

    Args:
        speakers (List[Optional[str]]): Speaker name per sample (matched case-insensitively).

    Raises:
        ValueError: If any non-empty entry is not in the supported set.
    """
    supported = self._supported_speakers_set()
    if supported is None:
        # The model exposes no speaker constraints; accept anything.
        return
    bad = [spk for spk in speakers
           if spk is not None and spk != "" and str(spk).lower() not in supported]
    if bad:
        raise ValueError(f"Unsupported speakers: {bad}. Supported: {sorted(supported)}")
def _is_probably_base64(self, s: str) -> bool:
    """Heuristically decide whether `s` is base64-encoded audio rather than a path/URL."""
    # An explicit data-URI prefix is a sure sign.
    if s.startswith("data:audio"):
        return True
    # Long strings without any path separator are very unlikely to be file paths.
    has_no_separator = "/" not in s and "\\" not in s
    return has_no_separator and len(s) > 256
def _is_url(self, s: str) -> bool:
    """Return True if `s` parses as an http(s) URL with a network location."""
    try:
        parsed = urlparse(s)
    except Exception:
        # Anything unparseable is treated as "not a URL".
        return False
    return bool(parsed.netloc) and parsed.scheme in ("http", "https")
def _decode_base64_to_wav_bytes(self, b64: str) -> bytes:
    """Decode a base64 string (optionally a data: URI) into raw audio bytes."""
    is_data_uri = b64.strip().startswith("data:") and "," in b64
    if is_data_uri:
        # Drop the "data:<mime>;base64," header and keep only the payload.
        _, b64 = b64.split(",", 1)
    return base64.b64decode(b64)
def _load_audio_to_np(self, x: str) -> Tuple[np.ndarray, int]:
    """
    Load an audio reference (URL, base64 string, or local path) as a mono
    float32 waveform together with its native sampling rate.
    """
    if self._is_url(x):
        # Fetch the remote bytes and decode them fully in memory.
        with urllib.request.urlopen(x) as resp:
            payload = resp.read()
        with io.BytesIO(payload) as buf:
            audio, sr = sf.read(buf, dtype="float32", always_2d=False)
    elif self._is_probably_base64(x):
        with io.BytesIO(self._decode_base64_to_wav_bytes(x)) as buf:
            audio, sr = sf.read(buf, dtype="float32", always_2d=False)
    else:
        # Local file path; keep the native sampling rate, downmix to mono.
        audio, sr = librosa.load(x, sr=None, mono=True)
    if audio.ndim > 1:
        # soundfile may return (frames, channels); downmix to mono.
        audio = np.mean(audio, axis=-1)
    return audio.astype(np.float32), int(sr)
def _normalize_audio_inputs(self, audios: "Union[AudioLike, List[AudioLike]]") -> List[Tuple[np.ndarray, int]]:
    """
    Normalize audio inputs into a list of (waveform, sr) pairs.

    Supported forms:
        - str: wav path / URL / base64 audio string
        - (np.ndarray, sr): waveform + sampling rate
        - list of the above

    Args:
        audios:
            Audio input(s).

    Returns:
        List[Tuple[np.ndarray, int]]:
            One (float32 mono waveform, original sr) pair per input.

    Raises:
        ValueError: If a bare numpy waveform is provided without sr.
        TypeError: If an input is of an unsupported type.
    """
    items = audios if isinstance(audios, list) else [audios]
    out: List[Tuple[np.ndarray, int]] = []
    for a in items:
        if isinstance(a, str):
            # _load_audio_to_np already returns mono float32.
            out.append(self._load_audio_to_np(a))
        elif isinstance(a, tuple) and len(a) == 2 and isinstance(a[0], np.ndarray):
            wav = a[0].astype(np.float32)
            # Bug fix: the original post-loop tried to assign into the
            # (immutable) result tuple (`a[0] = ...`), raising TypeError for
            # any multi-channel waveform. Downmix before building the tuple.
            if wav.ndim > 1:
                wav = np.mean(wav, axis=-1).astype(np.float32)
            out.append((wav, int(a[1])))
        elif isinstance(a, np.ndarray):
            raise ValueError("For numpy waveform input, pass a tuple (audio, sr).")
        else:
            raise TypeError(f"Unsupported audio input type: {type(a)}")
    return out
def _ensure_list(self, x: MaybeList) -> List[Any]:
    """Wrap a scalar in a single-element list; pass lists through unchanged."""
    if isinstance(x, list):
        return x
    return [x]
def _build_assistant_text(self, text: str) -> str:
    """Wrap the target text in the chat template, leaving an open assistant turn for generation."""
    return (
        "<|im_start|>assistant\n"
        + text
        + "<|im_end|>\n<|im_start|>assistant\n"
    )
def _build_ref_text(self, text: str) -> str:
    """Wrap a reference transcript in a single closed assistant turn."""
    return "<|im_start|>assistant\n" + text + "<|im_end|>\n"
def _build_instruct_text(self, instruct: str) -> str:
    """Wrap an instruction string in a single closed user turn."""
    return "".join(("<|im_start|>user\n", instruct, "<|im_end|>\n"))
def _tokenize_texts(self, texts: List[str]) -> List[torch.Tensor]:
    """Tokenize each text independently and return 2-D (1, seq) id tensors on self.device."""
    batch: List[torch.Tensor] = []
    for item in texts:
        encoded = self.processor(text=item, return_tensors="pt", padding=True)
        ids = encoded["input_ids"].to(self.device)
        if ids.dim() == 1:
            # Guarantee an explicit batch dimension.
            ids = ids.unsqueeze(0)
        batch.append(ids)
    return batch
def _merge_generate_kwargs(
    self,
    do_sample: Optional[bool] = None,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    temperature: Optional[float] = None,
    repetition_penalty: Optional[float] = None,
    subtalker_dosample: Optional[bool] = None,
    subtalker_top_k: Optional[int] = None,
    subtalker_top_p: Optional[float] = None,
    subtalker_temperature: Optional[float] = None,
    max_new_tokens: Optional[int] = None,
    **kwargs,
) -> Dict[str, Any]:
    """
    Merge user-provided generation arguments with defaults from `generate_config.json`.

    Precedence per parameter:
        1. A value explicitly passed by the caller (anything not None).
        2. The value from self.generate_defaults, if present.
        3. A hard-coded fallback.

    Args:
        do_sample, top_k, top_p, temperature, repetition_penalty,
        subtalker_dosample, subtalker_top_k, subtalker_top_p,
        subtalker_temperature, max_new_tokens:
            Common generation parameters.
        **kwargs:
            Other arguments forwarded untouched to model.generate().

    Returns:
        Dict[str, Any]: Final kwargs to pass into model.generate().
    """
    hard_defaults: Dict[str, Any] = {
        "do_sample": True,
        "top_k": 50,
        "top_p": 1.0,
        "temperature": 0.9,
        "repetition_penalty": 1.05,
        "subtalker_dosample": True,
        "subtalker_top_k": 50,
        "subtalker_top_p": 1.0,
        "subtalker_temperature": 0.9,
        "max_new_tokens": 2048,
    }
    user_values = {
        "do_sample": do_sample,
        "top_k": top_k,
        "top_p": top_p,
        "temperature": temperature,
        "repetition_penalty": repetition_penalty,
        "subtalker_dosample": subtalker_dosample,
        "subtalker_top_k": subtalker_top_k,
        "subtalker_top_p": subtalker_top_p,
        "subtalker_temperature": subtalker_temperature,
        "max_new_tokens": max_new_tokens,
    }
    merged = dict(kwargs)
    for name, user_val in user_values.items():
        if user_val is not None:
            merged[name] = user_val
        elif name in self.generate_defaults:
            merged[name] = self.generate_defaults[name]
        else:
            merged[name] = hard_defaults[name]
    return merged
# voice clone model
@torch.inference_mode()
def create_voice_clone_prompt(
    self,
    ref_audio: Union[AudioLike, List[AudioLike]],
    ref_text: Optional[Union[str, List[Optional[str]]]] = None,
    x_vector_only_mode: Union[bool, List[bool]] = False,
) -> List[VoiceClonePromptItem]:
    """
    Build voice-clone prompt items from reference audio (and optionally reference text) using the Base model.

    Modes:
    - x_vector_only_mode=True:
        Only the speaker embedding is used to clone the voice; ref_text/ref_code are ignored.
        This is mutually exclusive with ICL.
    - x_vector_only_mode=False:
        ICL mode is enabled automatically (icl_mode=True). In this case ref_text is required,
        because the model continues/conditions on the reference text + reference speech codes.

    Batch behavior:
    - ref_audio can be a single item or a list.
    - ref_text and x_vector_only_mode can be scalars or lists.
    - If any of them are lists with length > 1, lengths must match.

    Audio input:
    - str: local wav path / URL / base64
    - (np.ndarray, sr): waveform + sampling rate

    Args:
        ref_audio:
            Reference audio(s) used to extract:
            - ref_code via `model.speech_tokenizer.encode(...)`
            - ref_spk_embedding via `model.extract_speaker_embedding(...)` (resampled to the
              speaker encoder's sampling rate first)
        ref_text:
            Reference transcript(s). Required when x_vector_only_mode=False (ICL mode).
        x_vector_only_mode:
            Whether to use the speaker embedding only. If False, ICL mode is used.

    Returns:
        List[VoiceClonePromptItem]:
            Prompt items that can be converted into a `voice_clone_prompt` dict.

    Raises:
        ValueError:
            - If the model is not a Base model.
            - If x_vector_only_mode=False but ref_text is missing.
            - If batch lengths mismatch.
    """
    if self.model.tts_model_type != "base":
        raise ValueError(
            f"model with \ntokenizer_type: {self.model.tokenizer_type}\n"
            f"tts_model_size: {self.model.tts_model_size}\n"
            f"tts_model_type: {self.model.tts_model_type}\n"
            "does not support create_voice_clone_prompt, Please check Model Card or Readme for more details."
        )
    # Broadcast scalar ref_text / x_vector_only_mode across the batch of audios.
    ref_audio_list = self._ensure_list(ref_audio)
    ref_text_list = self._ensure_list(ref_text) if isinstance(ref_text, list) else ([ref_text] * len(ref_audio_list))
    xvec_list = self._ensure_list(x_vector_only_mode) if isinstance(x_vector_only_mode, list) else ([x_vector_only_mode] * len(ref_audio_list))
    if len(ref_text_list) != len(ref_audio_list) or len(xvec_list) != len(ref_audio_list):
        raise ValueError(
            f"Batch size mismatch: ref_audio={len(ref_audio_list)}, ref_text={len(ref_text_list)}, x_vector_only_mode={len(xvec_list)}"
        )
    normalized = self._normalize_audio_inputs(ref_audio_list)
    ref_wavs_for_code: List[np.ndarray] = []
    ref_sr_for_code: List[int] = []
    for wav, sr in normalized:
        ref_wavs_for_code.append(wav)
        ref_sr_for_code.append(sr)
    # Fast path: a single shared sampling rate allows one batched encode call.
    if len(set(ref_sr_for_code)) == 1:
        enc = self.model.speech_tokenizer.encode(ref_wavs_for_code, sr=ref_sr_for_code[0])
        ref_codes = enc.audio_codes
    else:
        # Mixed sampling rates: encode each reference individually.
        ref_codes = []
        for wav, sr in normalized:
            ref_codes.append(self.model.speech_tokenizer.encode(wav, sr=sr).audio_codes[0])
    items: List[VoiceClonePromptItem] = []
    for i, ((wav, sr), code, rtext, xvec_only) in enumerate(zip(normalized, ref_codes, ref_text_list, xvec_list)):
        if not xvec_only:
            # ICL mode needs the transcript of the reference audio.
            if rtext is None or rtext == "":
                raise ValueError(f"ref_text is required when x_vector_only_mode=False (ICL mode). Bad index={i}")
        # The speaker encoder expects its own fixed sampling rate; resample if needed.
        wav_resample = wav
        if sr != self.model.speaker_encoder_sample_rate:
            wav_resample = librosa.resample(y=wav_resample.astype(np.float32),
                                            orig_sr=int(sr),
                                            target_sr=self.model.speaker_encoder_sample_rate)
        spk_emb = self.model.extract_speaker_embedding(audio=wav_resample,
                                                       sr=self.model.speaker_encoder_sample_rate)
        items.append(
            VoiceClonePromptItem(
                ref_code=None if xvec_only else code,
                ref_spk_embedding=spk_emb,
                x_vector_only_mode=bool(xvec_only),
                icl_mode=bool(not xvec_only),
                ref_text=rtext,
            )
        )
    return items
def _prompt_items_to_voice_clone_prompt(self, items: List[VoiceClonePromptItem]) -> Dict[str, Any]:
    """Transpose per-sample prompt items into the dict-of-lists form expected by model.generate()."""
    prompt: Dict[str, Any] = {
        "ref_code": [item.ref_code for item in items],
        "ref_spk_embedding": [item.ref_spk_embedding for item in items],
        "x_vector_only_mode": [item.x_vector_only_mode for item in items],
        "icl_mode": [item.icl_mode for item in items],
    }
    return prompt
# voice clone model
@torch.no_grad()
def generate_voice_clone(
    self,
    text: Union[str, List[str]],
    language: Union[str, List[str]] = None,
    ref_audio: "Optional[Union[AudioLike, List[AudioLike]]]" = None,
    ref_text: Optional[Union[str, List[Optional[str]]]] = None,
    x_vector_only_mode: Union[bool, List[bool]] = False,
    voice_clone_prompt: "Optional[Union[Dict[str, Any], List[VoiceClonePromptItem]]]" = None,
    non_streaming_mode: bool = False,
    **kwargs,
) -> Tuple[List[np.ndarray], int]:
    """
    Voice clone speech using the Base model.

    You can provide either:
    - (ref_audio, ref_text, x_vector_only_mode) and let this method build the prompt, OR
    - a list of `VoiceClonePromptItem` returned by `create_voice_clone_prompt`, OR
    - a ready-made `voice_clone_prompt` dict in the model's native format.

    `ref_audio` supported forms:
    - str: wav path / URL / base64 audio string
    - (np.ndarray, sr): waveform + sampling rate
    - list of the above

    Input flexibility:
    - text/language can be scalar or list.
    - A single prompt is broadcast across a batch of texts; otherwise lengths must match.

    Args:
        text: Text(s) to synthesize.
        language: Language(s) per sample; defaults to "Auto".
        ref_audio: Reference audio(s); required if voice_clone_prompt is not given.
        ref_text: Reference transcript(s), required for ICL mode (x_vector_only_mode=False).
        x_vector_only_mode: If True, only the speaker embedding is used.
        voice_clone_prompt: Prompt items (or a prepared dict) from `create_voice_clone_prompt`.
        non_streaming_mode: When False, simulates streaming text input (no true streaming generation).
        **kwargs: Generation parameters merged by `_merge_generate_kwargs` and forwarded
            to `Qwen3TTSForConditionalGeneration.generate(...)`.

    Returns:
        Tuple[List[np.ndarray], int]: (wavs, sample_rate)

    Raises:
        ValueError: If the model is not a Base model, batch sizes mismatch,
            or required prompt inputs are missing.
    """
    if self.model.tts_model_type != "base":
        raise ValueError(
            f"model with \ntokenizer_type: {self.model.tokenizer_type}\n"
            f"tts_model_size: {self.model.tts_model_size}\n"
            f"tts_model_type: {self.model.tts_model_type}\n"
            "does not support generate_voice_clone, Please check Model Card or Readme for more details."
        )
    texts = self._ensure_list(text)
    languages = self._ensure_list(language) if isinstance(language, list) else ([language] * len(texts) if language is not None else ["Auto"] * len(texts))
    if len(languages) == 1 and len(texts) > 1:
        languages = languages * len(texts)
    if len(texts) != len(languages):
        raise ValueError(f"Batch size mismatch: text={len(texts)}, language={len(languages)}")
    self._validate_languages(languages)

    def _broadcast_items(items):
        # A single prompt item may be reused for every text in the batch.
        if len(items) == 1 and len(texts) > 1:
            items = items * len(texts)
        if len(items) != len(texts):
            raise ValueError(f"Batch size mismatch: prompt={len(items)}, text={len(texts)}")
        return items

    if voice_clone_prompt is None:
        if ref_audio is None:
            raise ValueError("Either `voice_clone_prompt` or `ref_audio` must be provided.")
        prompt_items = _broadcast_items(
            self.create_voice_clone_prompt(ref_audio=ref_audio, ref_text=ref_text, x_vector_only_mode=x_vector_only_mode)
        )
        voice_clone_prompt_dict = self._prompt_items_to_voice_clone_prompt(prompt_items)
        ref_texts_for_ids = [it.ref_text for it in prompt_items]
    elif isinstance(voice_clone_prompt, list):
        prompt_items = _broadcast_items(voice_clone_prompt)
        voice_clone_prompt_dict = self._prompt_items_to_voice_clone_prompt(prompt_items)
        ref_texts_for_ids = [it.ref_text for it in prompt_items]
    else:
        # Already in model-ready dict form; reference texts unknown, so no ref_ids.
        voice_clone_prompt_dict = voice_clone_prompt
        ref_texts_for_ids = None
    input_ids = self._tokenize_texts([self._build_assistant_text(t) for t in texts])
    ref_ids = None
    if ref_texts_for_ids is not None:
        ref_ids = []
        for rt in ref_texts_for_ids:
            if rt is None or rt == "":
                ref_ids.append(None)
            else:
                ref_ids.append(self._tokenize_texts([self._build_ref_text(rt)])[0])
    gen_kwargs = self._merge_generate_kwargs(**kwargs)
    talker_codes_list, _ = self.model.generate(
        input_ids=input_ids,
        ref_ids=ref_ids,
        voice_clone_prompt=voice_clone_prompt_dict,
        languages=languages,
        non_streaming_mode=non_streaming_mode,
        **gen_kwargs,
    )
    # In ICL mode the reference codes are decoded together with the generated
    # codes for continuity, and the reference portion is trimmed off afterwards.
    ref_code_list = voice_clone_prompt_dict.get("ref_code", None)  # hoisted out of both loops
    codes_for_decode = []
    for i, codes in enumerate(talker_codes_list):
        if ref_code_list is not None and ref_code_list[i] is not None:
            codes_for_decode.append(torch.cat([ref_code_list[i].to(codes.device), codes], dim=0))
        else:
            codes_for_decode.append(codes)
    wavs_all, fs = self.model.speech_tokenizer.decode([{"audio_codes": c} for c in codes_for_decode])
    wavs_out: List[np.ndarray] = []
    for i, wav in enumerate(wavs_all):
        if ref_code_list is not None and ref_code_list[i] is not None:
            ref_len = int(ref_code_list[i].shape[0])
            total_len = int(codes_for_decode[i].shape[0])
            # Trim the reference prefix proportionally to its share of the codes.
            cut = int(ref_len / max(total_len, 1) * wav.shape[0])
            wavs_out.append(wav[cut:])
        else:
            wavs_out.append(wav)
    return wavs_out, fs
# voice design model
@torch.no_grad()
def generate_voice_design(
    self,
    text: Union[str, List[str]],
    instruct: Union[str, List[str]],
    language: Union[str, List[str]] = None,
    non_streaming_mode: bool = True,
    **kwargs,
) -> Tuple[List[np.ndarray], int]:
    """
    Generate speech with the VoiceDesign model using natural-language style instructions.

    Args:
        text: Text(s) to synthesize.
        instruct: Instruction(s) describing the desired voice/style; an empty
            string means "no instruction".
        language: Language(s) per sample; defaults to "Auto".
        non_streaming_mode: When False, simulates streaming text input
            (no true streaming generation).
        **kwargs: Generation parameters (do_sample, top_k, top_p, temperature,
            repetition_penalty, subtalker_* knobs for qwen3-tts-tokenizer-v2,
            max_new_tokens, ...) merged by `_merge_generate_kwargs` and
            forwarded to `Qwen3TTSForConditionalGeneration.generate(...)`.

    Returns:
        Tuple[List[np.ndarray], int]: (wavs, sample_rate)

    Raises:
        ValueError: If the model is not a VoiceDesign model or batch sizes mismatch.
    """
    if self.model.tts_model_type != "voice_design":
        raise ValueError(
            f"model with \ntokenizer_type: {self.model.tokenizer_type}\n"
            f"tts_model_size: {self.model.tts_model_size}\n"
            f"tts_model_type: {self.model.tts_model_type}\n"
            "does not support generate_voice_design, Please check Model Card or Readme for more details."
        )
    texts = self._ensure_list(text)
    if isinstance(language, list):
        languages = self._ensure_list(language)
    elif language is not None:
        languages = [language] * len(texts)
    else:
        languages = ["Auto"] * len(texts)
    instructs = self._ensure_list(instruct)
    n = len(texts)
    # Broadcast singleton language/instruct across the batch.
    if len(languages) == 1 and n > 1:
        languages = languages * n
    if len(instructs) == 1 and n > 1:
        instructs = instructs * n
    if not (n == len(languages) == len(instructs)):
        raise ValueError(f"Batch size mismatch: text={len(texts)}, language={len(languages)}, instruct={len(instructs)}")
    self._validate_languages(languages)
    input_ids = self._tokenize_texts([self._build_assistant_text(t) for t in texts])
    # None/"" means: no instruction conditioning for that sample.
    instruct_ids: List[Optional[torch.Tensor]] = [
        None if not ins else self._tokenize_texts([self._build_instruct_text(ins)])[0]
        for ins in instructs
    ]
    gen_kwargs = self._merge_generate_kwargs(**kwargs)
    talker_codes_list, _ = self.model.generate(
        input_ids=input_ids,
        instruct_ids=instruct_ids,
        languages=languages,
        non_streaming_mode=non_streaming_mode,
        **gen_kwargs,
    )
    wavs, fs = self.model.speech_tokenizer.decode([{"audio_codes": codes} for codes in talker_codes_list])
    return wavs, fs
# custom voice model
@torch.no_grad()
def generate_custom_voice(
    self,
    text: Union[str, List[str]],
    speaker: Union[str, List[str]],
    language: Union[str, List[str]] = None,
    instruct: Optional[Union[str, List[str]]] = None,
    non_streaming_mode: bool = True,
    **kwargs,
) -> Tuple[List[np.ndarray], int]:
    """
    Generate speech with the CustomVoice model using a predefined speaker id,
    optionally controlled by instruction text.

    Args:
        text: Text(s) to synthesize.
        speaker: Speaker name(s); validated against
            `model.get_supported_speakers()` (case-insensitive).
        language: Language(s) per sample; defaults to "Auto".
        instruct: Optional instruction(s). None is treated as "no instruction".
            Ignored entirely for the 0.6B ("0b6") model, which does not support
            instruction conditioning.
        non_streaming_mode: When False, simulates streaming text input
            (no true streaming generation).
        **kwargs: Generation parameters (do_sample, top_k, top_p, temperature,
            repetition_penalty, subtalker_* knobs for qwen3-tts-tokenizer-v2,
            max_new_tokens, ...) merged by `_merge_generate_kwargs` and
            forwarded to `Qwen3TTSForConditionalGeneration.generate(...)`.

    Returns:
        Tuple[List[np.ndarray], int]: (wavs, sample_rate)

    Raises:
        ValueError: If any speaker/language is unsupported or batch sizes mismatch.
    """
    if self.model.tts_model_type != "custom_voice":
        raise ValueError(
            f"model with \ntokenizer_type: {self.model.tokenizer_type}\n"
            f"tts_model_size: {self.model.tts_model_size}\n"
            f"tts_model_type: {self.model.tts_model_type}\n"
            "does not support generate_custom_voice, Please check Model Card or Readme for more details."
        )
    texts = self._ensure_list(text)
    languages = self._ensure_list(language) if isinstance(language, list) else ([language] * len(texts) if language is not None else ["Auto"] * len(texts))
    speakers = self._ensure_list(speaker)
    # Bug fix: the original used `tts_model_size in "0b6"`, a substring test
    # that also matched sizes like "0", "b", "b6" or "". Use equality for the
    # intended "0.6B model does not support instruct" check.
    if self.model.tts_model_size == "0b6":
        instruct = None
    instructs = self._ensure_list(instruct) if isinstance(instruct, list) else ([instruct] * len(texts) if instruct is not None else [""] * len(texts))
    # Broadcast singleton language/speaker/instruct across the batch.
    if len(languages) == 1 and len(texts) > 1:
        languages = languages * len(texts)
    if len(speakers) == 1 and len(texts) > 1:
        speakers = speakers * len(texts)
    if len(instructs) == 1 and len(texts) > 1:
        instructs = instructs * len(texts)
    if not (len(texts) == len(languages) == len(speakers) == len(instructs)):
        raise ValueError(
            f"Batch size mismatch: text={len(texts)}, language={len(languages)}, speaker={len(speakers)}, instruct={len(instructs)}"
        )
    self._validate_languages(languages)
    self._validate_speakers(speakers)
    input_ids = self._tokenize_texts([self._build_assistant_text(t) for t in texts])
    # None/"" means: no instruction conditioning for that sample.
    instruct_ids: List[Optional[torch.Tensor]] = []
    for ins in instructs:
        if ins is None or ins == "":
            instruct_ids.append(None)
        else:
            instruct_ids.append(self._tokenize_texts([self._build_instruct_text(ins)])[0])
    gen_kwargs = self._merge_generate_kwargs(**kwargs)
    talker_codes_list, _ = self.model.generate(
        input_ids=input_ids,
        instruct_ids=instruct_ids,
        languages=languages,
        speakers=speakers,
        non_streaming_mode=non_streaming_mode,
        **gen_kwargs,
    )
    wavs, fs = self.model.speech_tokenizer.decode([{"audio_codes": c} for c in talker_codes_list])
    return wavs, fs
def get_supported_speakers(self) -> Optional[List[str]]:
    """
    List the speaker names supported by the current model.

    Convenience wrapper around `model.get_supported_speakers()`.

    Returns:
        Optional[List[str]]:
            Sorted, lowercased speaker names, or None when the model
            does not constrain speakers.
    """
    speakers = self._supported_speakers_set()
    return None if speakers is None else sorted(speakers)
def get_supported_languages(self) -> Optional[List[str]]:
    """
    List the language names supported by the current model.

    Convenience wrapper around `model.get_supported_languages()`.

    Returns:
        Optional[List[str]]:
            Sorted, lowercased language names, or None when the model
            does not constrain languages.
    """
    languages = self._supported_languages_set()
    return None if languages is None else sorted(languages)