Files
ViGent/models/MuseTalk/musetalk/utils/audio_processor.py
2026-01-16 16:27:30 +08:00

114 lines
4.8 KiB
Python

import math
import os
import librosa
import numpy as np
import torch
from einops import rearrange
from transformers import AutoFeatureExtractor
class AudioProcessor:
    """Convert a wav file into per-video-frame Whisper encoder features.

    Typical use is two steps: `get_audio_feature` loads the waveform and runs
    the feature extractor on 30-second segments, then `get_whisper_chunk`
    runs the Whisper encoder on those segments and slices the result into one
    fixed-size window per video frame.
    """

    # Sample rate every waveform is resampled to when loaded.
    SAMPLE_RATE = 16000
    # Length, in seconds, of each segment handed to the feature extractor.
    SEGMENT_SECONDS = 30
    # Encoder time steps per second of audio used for the frame mapping.
    WHISPER_FPS = 50

    def __init__(self, feature_extractor_path="openai/whisper-tiny/"):
        """Load the feature extractor from `feature_extractor_path`.

        Args:
            feature_extractor_path: Identifier or directory passed straight to
                `AutoFeatureExtractor.from_pretrained`.
        """
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(feature_extractor_path)

    def get_audio_feature(self, wav_path, start_index=0, weight_dtype=None):
        """Load `wav_path` and extract input features for each 30 s segment.

        Args:
            wav_path: Path of the audio file to load.
            start_index: Unused; kept so existing callers keep working.
            weight_dtype: Optional torch dtype each feature tensor is cast to.

        Returns:
            `None` when `wav_path` does not exist. Otherwise a tuple
            `(features, num_samples)` where `features` is a list with one
            `input_features` tensor per segment and `num_samples` is the
            length of the resampled waveform.
        """
        if not os.path.exists(wav_path):
            # NOTE(review): callers that unpack the result must check for None.
            return None

        waveform, sampling_rate = librosa.load(wav_path, sr=self.SAMPLE_RATE)
        # Internal sanity check: the load above was asked for this exact rate.
        assert sampling_rate == 16000

        # Split audio into 30s segments
        segment_length = self.SEGMENT_SECONDS * sampling_rate
        features = []
        for offset in range(0, len(waveform), segment_length):
            audio_feature = self.feature_extractor(
                waveform[offset:offset + segment_length],
                return_tensors="pt",
                sampling_rate=sampling_rate,
            ).input_features
            if weight_dtype is not None:
                audio_feature = audio_feature.to(dtype=weight_dtype)
            features.append(audio_feature)
        return features, len(waveform)

    @staticmethod
    def _zero_frames(reference, num_steps):
        """Return zeros like `reference` but with `num_steps` entries on dim 1."""
        shape = (reference.shape[0], num_steps, *reference.shape[2:])
        return reference.new_zeros(shape)

    def get_whisper_chunk(
        self,
        whisper_input_features,
        device,
        weight_dtype,
        whisper,
        librosa_length,
        fps=25,
        audio_padding_length_left=2,
        audio_padding_length_right=2,
    ):
        """Encode the audio and cut one feature window per video frame.

        Args:
            whisper_input_features: Iterable of feature tensors, one per audio
                segment, as produced by `get_audio_feature`.
            device: Device the features are moved to before encoding.
            weight_dtype: Dtype the features are cast to before encoding.
            whisper: Model exposing `encoder(features, output_hidden_states=True)`
                whose result has a `hidden_states` sequence of tensors.
                Assumes each hidden state is (batch, time, dim) - TODO confirm.
            librosa_length: Number of samples in the 16 kHz waveform.
            fps: Video frame rate; truncated to an integer.
            audio_padding_length_left: Video frames of audio context kept
                before each frame.
            audio_padding_length_right: Video frames of audio context kept
                after each frame.

        Returns:
            Tensor with one row per video frame. Each row is the frame's
            window of encoder steps with the step and hidden-layer axes merged
            (the original comment cites (T, 10, 5, 384) before merging).

        Raises:
            ValueError: If `fps` truncates to zero or a negative number.

        Note:
            When the audio is too short to yield a single frame, or
            `whisper_input_features` is empty, the underlying `torch.cat`
            call receives an empty list and fails.
        """
        fps = int(fps)
        if fps <= 0:
            raise ValueError(f"fps must be a positive number, got {fps}")

        audio_feature_length_per_frame = 2 * (audio_padding_length_left + audio_padding_length_right + 1)

        # Process multiple 30s mel input features
        encoded_segments = []
        for input_feature in whisper_input_features:
            input_feature = input_feature.to(device).to(weight_dtype)
            hidden_states = whisper.encoder(input_feature, output_hidden_states=True).hidden_states
            # Stack all hidden layers on a new axis right after the time axis.
            encoded_segments.append(torch.stack(hidden_states, dim=2))
        whisper_feature = torch.cat(encoded_segments, dim=1)

        # Trim the last segment to remove padding
        duration = librosa_length / self.SAMPLE_RATE
        whisper_idx_multiplier = self.WHISPER_FPS / fps
        num_frames = math.floor(duration * fps)
        actual_length = math.floor(duration * self.WHISPER_FPS)
        whisper_feature = whisper_feature[:, :actual_length, ...]

        # Add padding at start and end. The padding is allocated with an
        # explicit length: slicing the feature itself (as zeros_like on a
        # slice does) yields too little padding for clips shorter than the
        # padding, which shifts the alignment of the first frames.
        padding_nums = math.ceil(whisper_idx_multiplier)
        whisper_feature = torch.cat([
            self._zero_frames(whisper_feature, padding_nums * audio_padding_length_left),
            whisper_feature,
            # Add extra padding to prevent out of bounds
            self._zero_frames(whisper_feature, padding_nums * 3 * audio_padding_length_right),
        ], 1)

        audio_prompts = []
        for frame_index in range(num_frames):
            audio_index = math.floor(frame_index * whisper_idx_multiplier)
            audio_clip = whisper_feature[:, audio_index:audio_index + audio_feature_length_per_frame]
            # Slicing past the end silently returns fewer steps; top the
            # window up with zeros so every frame has the same length.
            missing = audio_feature_length_per_frame - audio_clip.shape[1]
            if missing > 0:
                audio_clip = torch.cat([audio_clip, self._zero_frames(whisper_feature, missing)], dim=1)
            audio_prompts.append(audio_clip)

        audio_prompts = torch.cat(audio_prompts, dim=0)  # T, 10, 5, 384
        # Merge the step and layer axes; same as 'b c h w -> b (c h) w'.
        return audio_prompts.flatten(start_dim=1, end_dim=2)
if __name__ == "__main__":
    # Manual smoke test: extract features for a local wav file and report
    # what came back.
    audio_processor = AudioProcessor()
    wav_path = "./2.wav"
    result = audio_processor.get_audio_feature(wav_path)
    if result is None:
        # get_audio_feature returns None instead of a tuple for a missing file.
        print("Audio file not found:", wav_path)
    else:
        audio_feature, librosa_feature_length = result
        # audio_feature is a list with one tensor per segment, so report the
        # shape of each entry rather than calling .shape on the list.
        print("Audio Feature shape:", [tuple(feature.shape) for feature in audio_feature])
        print("librosa_feature_length:", librosa_feature_length)