# coding=utf-8
# Copyright 2026 The Alibaba Qwen team.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
import os
import shutil

import torch
from accelerate import Accelerator
from dataset import TTSDataset
from qwen_tts.inference.qwen3_tts_model import Qwen3TTSModel
from safetensors.torch import save_file
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import AutoConfig

# Speaker embedding cached from the first training batch; it is written into
# the exported codec embedding table at checkpoint time (see end of train()).
target_speaker_embedding = None
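
# Example launch (illustrative; the script filename and data path are not
# fixed by this file):
#   accelerate launch finetune_tts.py --train_jsonl train.jsonl --speaker_name my_voice
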
def train():
    global target_speaker_embedding

    parser = argparse.ArgumentParser()
    parser.add_argument("--init_model_path", type=str, default="Qwen/Qwen3-TTS-12Hz-1.7B-Base")
    parser.add_argument("--output_model_path", type=str, default="output")
    parser.add_argument("--train_jsonl", type=str, required=True)
    parser.add_argument("--batch_size", type=int, default=2)
    parser.add_argument("--lr", type=float, default=2e-5)
    parser.add_argument("--num_epochs", type=int, default=3)
    parser.add_argument("--speaker_name", type=str, default="speaker_test")
    args = parser.parse_args()
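
    # gradient_accumulation_steps=4 means one optimizer step per 4 micro-batches,
    # i.e. an effective batch size of 4 * batch_size per process (times the
    # number of processes under a distributed launch).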
    accelerator = Accelerator(gradient_accumulation_steps=4, mixed_precision="bf16", log_with="tensorboard")

    MODEL_PATH = args.init_model_path
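
    # flash_attention_2 requires the flash-attn package and a compatible GPU;
    # if that is unavailable, "sdpa" is the usual transformers fallback (an
    # assumption here; check which implementations Qwen3TTSModel accepts).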
    qwen3tts = Qwen3TTSModel.from_pretrained(
        MODEL_PATH,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
    )
    config = AutoConfig.from_pretrained(MODEL_PATH)
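
    # --train_jsonl is JSON Lines: one JSON object per line. The expected
    # fields are defined by TTSDataset in dataset.py, not by this script.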
    with open(args.train_jsonl, encoding="utf-8") as f:
        train_data = [json.loads(line) for line in f]
    dataset = TTSDataset(train_data, qwen3tts.processor, config)
    train_dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, collate_fn=dataset.collate_fn)

    optimizer = AdamW(qwen3tts.model.parameters(), lr=args.lr, weight_decay=0.01)
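
    # prepare() moves everything to the right device(s), wraps the model for
    # distributed training, and applies the bf16 mixed-precision policy
    # configured on the Accelerator.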
    model, optimizer, train_dataloader = accelerator.prepare(
        qwen3tts.model, optimizer, train_dataloader
    )

    num_epochs = args.num_epochs
    model.train()
    for epoch in range(num_epochs):
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(model):
                input_ids = batch['input_ids']
                codec_ids = batch['codec_ids']
                ref_mels = batch['ref_mels']
                text_embedding_mask = batch['text_embedding_mask']
                codec_embedding_mask = batch['codec_embedding_mask']
                attention_mask = batch['attention_mask']
                codec_0_labels = batch['codec_0_labels']
                codec_mask = batch['codec_mask']
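
                # .detach() keeps gradients out of the speaker encoder; the
                # embedding of the first batch is cached as the target voice.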
                speaker_embedding = model.speaker_encoder(ref_mels.to(model.device).to(model.dtype)).detach()
                if target_speaker_embedding is None:
                    target_speaker_embedding = speaker_embedding

                # input_ids stacks two token streams per position: channel 0
                # holds the text ids, channel 1 the codec ids.
                input_text_ids = input_ids[:, :, 0]
                input_codec_ids = input_ids[:, :, 1]

                input_text_embedding = model.talker.model.text_embedding(input_text_ids) * text_embedding_mask
                input_codec_embedding = model.talker.model.codec_embedding(input_codec_ids) * codec_embedding_mask
                # Position 6 of the codec stream is overwritten with the
                # speaker embedding.
                input_codec_embedding[:, 6, :] = speaker_embedding

                input_embeddings = input_text_embedding + input_codec_embedding

                # Add the input embeddings of residual codebooks 1..15
                # (codebook 0 already enters via input_codec_ids above),
                # masked to codec positions only.
                for i in range(1, 16):
                    codec_i_embedding = model.talker.code_predictor.get_input_embeddings()[i - 1](codec_ids[:, :, i])
                    codec_i_embedding = codec_i_embedding * codec_mask.unsqueeze(-1)
                    input_embeddings = input_embeddings + codec_i_embedding

                # Standard next-token shift: feed positions [0, L-1) and
                # predict the labels at positions [1, L).
                outputs = model.talker(
                    inputs_embeds=input_embeddings[:, :-1, :],
                    attention_mask=attention_mask[:, :-1],
                    labels=codec_0_labels[:, 1:],
                    output_hidden_states=True,
                )
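
                # Keep only the hidden states at positions flagged by
                # codec_mask; they condition the sub-talker on the residual
                # codebooks.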
                hidden_states = outputs.hidden_states[0][-1]
                talker_hidden_states = hidden_states[codec_mask[:, 1:]]
                talker_codec_ids = codec_ids[codec_mask]

                sub_talker_logits, sub_talker_loss = model.talker.forward_sub_talker_finetune(talker_codec_ids, talker_hidden_states)

                # Total loss: codebook-0 cross-entropy from the talker plus
                # the sub-talker loss over the residual codebooks.
                loss = outputs.loss + sub_talker_loss

                accelerator.backward(loss)

                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), 1.0)

                # Under accumulate(), the prepared optimizer skips step() on
                # accumulation-only iterations, so stepping every iteration
                # is safe.
                optimizer.step()
                optimizer.zero_grad()

                if step % 10 == 0:
                    accelerator.print(f"Epoch {epoch} | Step {step} | Loss: {loss.item():.4f}")
        # Export a full checkpoint at the end of every epoch (main process
        # only). Note that shutil.copytree needs a local directory, so
        # --init_model_path must point at a downloaded snapshot of the model,
        # not a bare hub id.
        if accelerator.is_main_process:
            output_dir = os.path.join(args.output_model_path, f"checkpoint-epoch-{epoch}")
            shutil.copytree(MODEL_PATH, output_dir, dirs_exist_ok=True)

            input_config_file = os.path.join(MODEL_PATH, "config.json")
            output_config_file = os.path.join(output_dir, "config.json")
            with open(input_config_file, 'r', encoding='utf-8') as f:
                config_dict = json.load(f)
            config_dict["tts_model_type"] = "custom_voice"
            talker_config = config_dict.get("talker_config", {})
            # Register the fine-tuned voice under speaker id 3000; this must
            # match the codec-embedding row overwritten before saving below.
            talker_config["spk_id"] = {
                args.speaker_name: 3000
            }
            talker_config["spk_is_dialect"] = {
                args.speaker_name: False
            }
            config_dict["talker_config"] = talker_config

            with open(output_config_file, 'w', encoding='utf-8') as f:
                json.dump(config_dict, f, indent=2, ensure_ascii=False)

            unwrapped_model = accelerator.unwrap_model(model)
            state_dict = {k: v.detach().to("cpu") for k, v in unwrapped_model.state_dict().items()}

            # The speaker encoder is only used during training to produce the
            # target embedding, so its weights are dropped from the export.
            drop_prefix = "speaker_encoder"
            keys_to_drop = [k for k in state_dict.keys() if k.startswith(drop_prefix)]
            for k in keys_to_drop:
                del state_dict[k]

            # Bake the cached speaker embedding into codec-embedding row 3000,
            # the id registered for args.speaker_name in config.json above.
            weight = state_dict['talker.model.codec_embedding.weight']
            weight[3000] = target_speaker_embedding[0].detach().to(weight.device).to(weight.dtype)
            save_path = os.path.join(output_dir, "model.safetensors")
            save_file(state_dict, save_path)
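            # A checkpoint saved this way should load like the base model,
            # e.g. Qwen3TTSModel.from_pretrained(output_dir, ...); how the new
            # voice is selected at inference time is up to the qwen_tts API.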

if __name__ == "__main__":
    train()