| |
| |
| |
| |
| |
| |
| |
|
|
| import torch |
| import torch.nn as nn |
| from einops import rearrange |
| from huggingface_hub import PyTorchModelHubMixin |
| from .mamba_block2_SEMamba import TFMambaBlock |
| from .codec_module_time_d4 import DenseEncoder, MagDecoder, PhaseDecoder |
|
|
class SEMamba(nn.Module, PyTorchModelHubMixin):
    """
    SEMamba model for speech enhancement using Mamba blocks.

    The noisy magnitude and phase spectrograms are stacked as two channels,
    zero-padded, and passed through a dense encoder. The encoded features are
    refined by a stack of ``TFMambaBlock`` modules, and separate magnitude and
    phase decoders predict the enhanced magnitude and phase.
    """

    # Number of zeros appended to the end of both the time and the frequency
    # axis before encoding; the decoder outputs are cropped back afterwards.
    # NOTE(review): presumably chosen so the encoder/decoder resampling sizes
    # line up — confirm against the codec module before changing it.
    _PAD = 2

    def __init__(self, cfg):
        """
        Initialize the SEMamba model.

        Args:
        - cfg: Configuration mapping. ``cfg['model_cfg']['num_tfmamba']`` sets
          the number of TF-Mamba blocks (4 when it is None). The whole ``cfg``
          is forwarded to the encoder, every Mamba block and both decoders.
        """
        super().__init__()
        self.cfg = cfg

        num_tfmamba = cfg['model_cfg']['num_tfmamba']
        self.num_tscblocks = num_tfmamba if num_tfmamba is not None else 4

        # Attribute names below are part of the state_dict keys; keep them
        # stable so existing checkpoints still load.
        self.dense_encoder = DenseEncoder(cfg)
        self.TSMamba = nn.ModuleList(
            [TFMambaBlock(cfg) for _ in range(self.num_tscblocks)]
        )
        self.mask_decoder = MagDecoder(cfg)
        self.phase_decoder = PhaseDecoder(cfg)

    def forward(self, noisy_mag, noisy_pha):
        """
        Forward pass for the SEMamba model.

        Args:
        - noisy_mag (torch.Tensor): Noisy magnitude input tensor [B, F, T].
        - noisy_pha (torch.Tensor): Noisy phase input tensor [B, F, T].

        Returns:
        - denoised_mag (torch.Tensor): Denoised magnitude tensor [B, F, T].
        - denoised_pha (torch.Tensor): Denoised phase tensor [B, F, T].
        - denoised_com (torch.Tensor): Denoised complex tensor [B, F, T, 2],
          with ``mag * cos(pha)`` and ``mag * sin(pha)`` in the last axis.
        """
        # [B, F, T] -> [B, 1, T, F], then stack magnitude and phase as the
        # two channels of a [B, 2, T, F] tensor.
        noisy_mag = rearrange(noisy_mag, 'b f t -> b t f').unsqueeze(1)
        noisy_pha = rearrange(noisy_pha, 'b f t -> b t f').unsqueeze(1)
        x = torch.cat((noisy_mag, noisy_pha), dim=1)

        # Remember the un-padded sizes so the outputs can be cropped back.
        num_frames, num_freqs = x.shape[-2], x.shape[-1]

        # Append zeros to the end of the frequency axis (last dim) and the
        # time axis (second-to-last dim). Unlike concatenating a torch.zeros
        # buffer, nn.functional.pad keeps x's dtype and device, so half /
        # bfloat16 inputs are not silently promoted to float32.
        x = nn.functional.pad(x, (0, self._PAD, 0, self._PAD))

        x = self.dense_encoder(x)
        for block in self.TSMamba:
            x = block(x)

        # Decoder outputs are [B, C, T', F']; move to [B, F', T', C] and drop
        # the channel axis (assumes each decoder emits a single channel —
        # squeeze(-1) is a no-op otherwise).
        denoised_mag = rearrange(self.mask_decoder(x), 'b c t f -> b f t c').squeeze(-1)
        denoised_pha = rearrange(self.phase_decoder(x), 'b c t f -> b f t c').squeeze(-1)

        # Crop away the padded bins / frames so outputs match the input size.
        denoised_mag = denoised_mag[:, :num_freqs, :num_frames]
        denoised_pha = denoised_pha[:, :num_freqs, :num_frames]

        # Rebuild the complex spectrum as (real, imag) in the last axis.
        denoised_com = torch.stack(
            (denoised_mag * torch.cos(denoised_pha), denoised_mag * torch.sin(denoised_pha)),
            dim=-1,
        )

        return denoised_mag, denoised_pha, denoised_com
|
|