# RE-USE / models / generator_SEMamba_time_d4.py
# Uploaded by szuweifu ("Upload generator_SEMamba_time_d4.py", commit 4ffb980, verified).
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import torch
import torch.nn as nn
from einops import rearrange
from huggingface_hub import PyTorchModelHubMixin
from .mamba_block2_SEMamba import TFMambaBlock
from .codec_module_time_d4 import DenseEncoder, MagDecoder, PhaseDecoder
class SEMamba(nn.Module, PyTorchModelHubMixin):
    """
    SEMamba model for speech enhancement using Mamba blocks.

    The noisy magnitude and phase spectrograms are stacked as two input
    channels, encoded by a dense encoder, refined by a stack of TF-Mamba
    blocks, and decoded by separate magnitude and phase decoders. The decoded
    magnitude and phase are additionally combined into a (real, imag) pair.
    """

    def __init__(self, cfg):
        """
        Initialize the SEMamba model.

        Args:
        - cfg: Configuration mapping. ``cfg['model_cfg']['num_tfmamba']`` sets
          the number of TF-Mamba blocks (4 when it is ``None``); the whole
          ``cfg`` is also forwarded to every sub-module constructor.
        """
        super().__init__()
        self.cfg = cfg
        num_tfmamba = cfg['model_cfg']['num_tfmamba']
        # Fall back to 4 TF-Mamba blocks when the config leaves the count unset.
        self.num_tscblocks = num_tfmamba if num_tfmamba is not None else 4

        # NOTE: attribute names below are kept unchanged so state_dict keys
        # (and therefore pretrained checkpoints) remain compatible.
        # Dense encoder applied to the stacked [magnitude, phase] input.
        self.dense_encoder = DenseEncoder(cfg)
        # TF-Mamba blocks, applied one after another in forward().
        self.TSMamba = nn.ModuleList(
            [TFMambaBlock(cfg) for _ in range(self.num_tscblocks)]
        )
        # Separate decoders for magnitude and phase.
        self.mask_decoder = MagDecoder(cfg)
        self.phase_decoder = PhaseDecoder(cfg)

    def forward(self, noisy_mag, noisy_pha):
        """
        Forward pass for the SEMamba model.

        Args:
        - noisy_mag (torch.Tensor): Noisy magnitude input tensor [B, F, T].
        - noisy_pha (torch.Tensor): Noisy phase input tensor [B, F, T].

        Returns:
        - denoised_mag (torch.Tensor): Denoised magnitude tensor [B, F, T].
        - denoised_pha (torch.Tensor): Denoised phase tensor [B, F, T].
        - denoised_com (torch.Tensor): Denoised complex tensor [B, F, T, 2],
          last axis = (mag * cos(pha), mag * sin(pha)).
        """
        # [B, F, T] -> [B, 1, T, F] for each input.
        noisy_mag = rearrange(noisy_mag, 'b f t -> b t f').unsqueeze(1)
        noisy_pha = rearrange(noisy_pha, 'b f t -> b t f').unsqueeze(1)
        # Stack magnitude and phase as two channels: [B, 2, T, F].
        x = torch.cat((noisy_mag, noisy_pha), dim=1)

        # Remember the unpadded sizes so the padding can be cropped off later.
        _, _, num_frames, num_freqs = x.shape
        # Zero-pad 2 bins at the end of F and 2 frames at the end of T,
        # giving [B, 2, T + 2, F + 2]. Presumably the encoder/decoder
        # resampling needs this margin to round-trip shapes -- NOTE(review):
        # confirm against codec_module_time_d4.
        # nn.functional.pad keeps x's dtype and device; concatenating a
        # default-dtype (float32) torch.zeros tensor instead would change or
        # reject reduced-precision (float16/bfloat16) inputs.
        x = nn.functional.pad(x, (0, 2, 0, 2))

        # Encode, then refine with the TF-Mamba stack.
        x = self.dense_encoder(x)
        for block in self.TSMamba:
            x = block(x)

        # Decoder outputs are rearranged 'b c t f -> b f t c' and the trailing
        # channel axis is squeezed (assumes each decoder emits a single
        # channel -- TODO confirm).
        denoised_mag = rearrange(self.mask_decoder(x), 'b c t f -> b f t c').squeeze(-1)
        denoised_pha = rearrange(self.phase_decoder(x), 'b c t f -> b f t c').squeeze(-1)
        # Crop away the padded region so outputs match the input [B, F, T].
        denoised_mag = denoised_mag[:, :num_freqs, :num_frames]
        denoised_pha = denoised_pha[:, :num_freqs, :num_frames]

        # Polar -> Cartesian: stack (real, imag) on a new last axis.
        denoised_com = torch.stack(
            (
                denoised_mag * torch.cos(denoised_pha),
                denoised_mag * torch.sin(denoised_pha),
            ),
            dim=-1,
        )
        return denoised_mag, denoised_pha, denoised_com