# Source: ViGent2/models/Qwen3-TTS/finetuning/sft_12hz.py
# Last commit: 4a3dd2b225 by Kevin Wong, "更新" ("Update"), 2026-01-28 17:22:31 +08:00
# Original listing: 162 lines, 6.7 KiB, Python

# coding=utf-8
# Copyright 2026 The Alibaba Qwen team.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import os
import shutil
import torch
from accelerate import Accelerator
from dataset import TTSDataset
from qwen_tts.inference.qwen3_tts_model import Qwen3TTSModel
from safetensors.torch import save_file
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import AutoConfig
# Speaker embedding captured from the first training batch. `train()` fills it
# in, and row 0 of it is written into the saved checkpoint's codec embedding.
target_speaker_embedding: "torch.Tensor | None" = None
def _parse_args():
    """Parse and return the command-line options for fine-tuning."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--init_model_path", type=str, default="Qwen/Qwen3-TTS-12Hz-1.7B-Base")
    parser.add_argument("--output_model_path", type=str, default="output")
    parser.add_argument("--train_jsonl", type=str, required=True)
    parser.add_argument("--batch_size", type=int, default=2)
    parser.add_argument("--lr", type=float, default=2e-5)
    parser.add_argument("--num_epochs", type=int, default=3)
    parser.add_argument("--speaker_name", type=str, default="speaker_test")
    # Number of batches whose gradients are accumulated per optimizer step.
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
    return parser.parse_args()


def _read_jsonl(path):
    """Return the JSON objects stored one per line in *path*, skipping blank lines."""
    # JSON text is UTF-8 by spec; do not depend on the platform default encoding.
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def _compute_loss(model, batch):
    """Run one teacher-forced forward pass and return ``(loss, speaker_embedding)``.

    The loss is the talker's loss on ``codec_0_labels`` plus the sub-talker
    loss from ``forward_sub_talker_finetune``. ``speaker_embedding`` is the
    speaker-encoder output for ``ref_mels``. It is detached, so no gradient
    reaches the speaker encoder.
    """
    input_ids = batch["input_ids"]
    codec_ids = batch["codec_ids"]
    ref_mels = batch["ref_mels"]
    text_embedding_mask = batch["text_embedding_mask"]
    codec_embedding_mask = batch["codec_embedding_mask"]
    attention_mask = batch["attention_mask"]
    codec_0_labels = batch["codec_0_labels"]
    codec_mask = batch["codec_mask"]

    speaker_embedding = model.speaker_encoder(ref_mels.to(model.device).to(model.dtype)).detach()

    # input_ids carries text ids in channel 0 and codec ids in channel 1.
    input_text_ids = input_ids[:, :, 0]
    input_codec_ids = input_ids[:, :, 1]
    input_text_embedding = model.talker.model.text_embedding(input_text_ids) * text_embedding_mask
    input_codec_embedding = model.talker.model.codec_embedding(input_codec_ids) * codec_embedding_mask
    # Sequence position 6 is overwritten with the speaker embedding.
    # NOTE(review): presumably the slot TTSDataset reserves for the speaker; confirm.
    input_codec_embedding[:, 6, :] = speaker_embedding

    input_embeddings = input_text_embedding + input_codec_embedding
    # Codebooks 1..15 of codec_ids are embedded with the code predictor's
    # per-codebook tables (table i-1 for codebook i). They are added only where
    # codec_mask is set.
    code_predictor_embeddings = model.talker.code_predictor.get_input_embeddings()
    codec_position_mask = codec_mask.unsqueeze(-1)
    for i in range(1, 16):
        codec_i_embedding = code_predictor_embeddings[i - 1](codec_ids[:, :, i])
        input_embeddings = input_embeddings + codec_i_embedding * codec_position_mask

    # Shift by one: inputs are positions [0, T-1) and labels are positions [1, T).
    outputs = model.talker(
        inputs_embeds=input_embeddings[:, :-1, :],
        attention_mask=attention_mask[:, :-1],
        labels=codec_0_labels[:, 1:],
        output_hidden_states=True,
    )
    # NOTE(review): the layout of `outputs.hidden_states` is model-specific;
    # [0][-1] is what the sub-talker expects here. Confirm against the talker.
    hidden_states = outputs.hidden_states[0][-1]
    talker_hidden_states = hidden_states[codec_mask[:, 1:]]
    talker_codec_ids = codec_ids[codec_mask]
    _, sub_talker_loss = model.talker.forward_sub_talker_finetune(talker_codec_ids, talker_hidden_states)
    return outputs.loss + sub_talker_loss, speaker_embedding


def _save_checkpoint(accelerator, model, model_path, output_dir, speaker_name, speaker_embedding, speaker_codec_id):
    """Write a custom-voice checkpoint to *output_dir*.

    Steps:
    - Copy the base model directory.
    - Rewrite its ``config.json`` to register *speaker_name* with id *speaker_codec_id*.
    - Save the current weights, excluding the speaker encoder, to ``model.safetensors``.
      Row *speaker_codec_id* of the talker codec embedding is replaced by ``speaker_embedding[0]``.
    """
    shutil.copytree(model_path, output_dir, dirs_exist_ok=True)

    with open(os.path.join(model_path, "config.json"), "r", encoding="utf-8") as f:
        config_dict = json.load(f)
    config_dict["tts_model_type"] = "custom_voice"
    talker_config = config_dict.get("talker_config", {})
    talker_config["spk_id"] = {speaker_name: speaker_codec_id}
    talker_config["spk_is_dialect"] = {speaker_name: False}
    config_dict["talker_config"] = talker_config
    with open(os.path.join(output_dir, "config.json"), "w", encoding="utf-8") as f:
        json.dump(config_dict, f, indent=2, ensure_ascii=False)

    unwrapped_model = accelerator.unwrap_model(model)
    # The speaker encoder is not part of the saved checkpoint. Filter it out
    # before copying so its weights are never moved to CPU.
    state_dict = {
        k: v.detach().to("cpu")
        for k, v in unwrapped_model.state_dict().items()
        if not k.startswith("speaker_encoder")
    }

    key = "talker.model.codec_embedding.weight"
    # Clone before the in-place write. `.to("cpu")` returns the very same tensor
    # when the model already lives on the CPU, so writing in place would corrupt
    # the live parameter that training keeps using.
    weight = state_dict[key].clone()
    weight[speaker_codec_id] = speaker_embedding[0].detach().to(weight.device).to(weight.dtype)
    state_dict[key] = weight

    save_file(state_dict, os.path.join(output_dir, "model.safetensors"))


def train():
    """Fine-tune a Qwen3-TTS base model and save one custom-voice checkpoint per epoch.

    Options come from the command line (see ``_parse_args``). After every epoch
    the main process writes ``<output_model_path>/checkpoint-epoch-<n>``. This is
    a copy of the base model directory whose ``config.json`` registers
    ``--speaker_name`` as a ``custom_voice`` speaker. Its ``model.safetensors``
    holds the fine-tuned weights, with the speaker embedding captured from the
    first training batch stored at the speaker's codec id.

    Raises:
        FileNotFoundError: ``--init_model_path`` is not a local directory.
            Checkpoints are built by copying it, so a Hub id would otherwise
            fail only after a full epoch.
        ValueError: ``--train_jsonl`` contains no samples.
    """
    global target_speaker_embedding

    args = _parse_args()
    model_path = args.init_model_path
    # Row of the talker codec embedding that stores the new speaker. It is also
    # written to config.json as the speaker's spk_id.
    speaker_codec_id = 3000

    if not os.path.isdir(model_path):
        raise FileNotFoundError(
            f"--init_model_path must be a local model directory (got {model_path!r}); "
            "checkpoints are created by copying it. Download the model first, "
            "e.g. with `huggingface-cli download`."
        )

    # Read the data before loading the (large) model so malformed or empty input fails fast.
    train_data = _read_jsonl(args.train_jsonl)
    if not train_data:
        raise ValueError(f"No training samples found in {args.train_jsonl!r}")

    # The tensorboard tracker needs a logging directory; without one, accelerate
    # may reject log_with="tensorboard".
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision="bf16",
        log_with="tensorboard",
        project_dir=args.output_model_path,
    )

    qwen3tts = Qwen3TTSModel.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
    )
    config = AutoConfig.from_pretrained(model_path)

    dataset = TTSDataset(train_data, qwen3tts.processor, config)
    train_dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, collate_fn=dataset.collate_fn)
    optimizer = AdamW(qwen3tts.model.parameters(), lr=args.lr, weight_decay=0.01)

    # NOTE(review): _compute_loss reaches into submodules (speaker_encoder, talker)
    # of the object returned by prepare(). Presumably this only works when it is
    # not wrapped (single process). Confirm before running multi-GPU.
    model, optimizer, train_dataloader = accelerator.prepare(qwen3tts.model, optimizer, train_dataloader)

    model.train()
    for epoch in range(args.num_epochs):
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(model):
                loss, speaker_embedding = _compute_loss(model, batch)
                # Keep the first batch's speaker embedding for the saved checkpoints.
                if target_speaker_embedding is None:
                    target_speaker_embedding = speaker_embedding
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                optimizer.zero_grad()

            if step % 10 == 0:
                accelerator.print(f"Epoch {epoch} | Step {step} | Loss: {loss.item():.4f}")

        if accelerator.is_main_process:
            _save_checkpoint(
                accelerator,
                model,
                model_path,
                os.path.join(args.output_model_path, f"checkpoint-epoch-{epoch}"),
                args.speaker_name,
                target_speaker_embedding,
                speaker_codec_id,
            )
if __name__ == "__main__":
    # Start fine-tuning only when run as a script, never on import.
    train()