| | from torch import nn |
| | import torch |
| | from method.MambaCSSM import MambaCSSM |
| |
|
class MambaCSSMUnet(nn.Module):
    """U-Net style network with a ``MambaCSSM`` bottleneck over an input pair.

    ``forward(t1, t2)`` concatenates the two inputs along the channel axis.
    The first convolution expects 6 channels in total, presumably two
    3-channel images (confirm against the caller). The network then:

    * runs a 4-level conv/max-pool encoder (16 -> 32 -> 64 -> 128 channels);
    * splits the 128-channel bottleneck into two 64-channel halves and mixes
      them with ``MambaCSSM``;
    * decodes back to the input resolution with max-unpooling, skip
      connections from the encoder, and stride-1 transposed convolutions.

    Each bottleneck half is flattened to ``(B, 64, h * w)``, with
    ``h = H // 16`` and ``w = W // 16``, and fed to ``MambaCSSM`` built with
    ``d_model=256``. The input must therefore satisfy
    ``(H // 16) * (W // 16) == 256`` (e.g. 256x256). Other sizes raise
    ``ValueError`` in ``forward``.

    Args:
        output_classes: Number of output channels of the final 1x1 conv.
    """

    def __init__(self, output_classes=2):
        super().__init__()

        # Feature size of each bottleneck token handed to MambaCSSM. The
        # flattened spatial grid is used as the feature axis, so forward()
        # requires h * w to equal this value.
        self._mamba_d_model = 256

        # ---- Encoder ------------------------------------------------------
        # Each conv block keeps H and W. Each 2x2 max-pool halves them and
        # returns indices so the decoder can unpool to the same positions.
        self.conv_block_1 = self._encoder_block(6, 16)
        self.mp_block_1 = nn.MaxPool2d(2, 2, return_indices=True)

        self.conv_block_2 = self._encoder_block(16, 32)
        self.mp_block_2 = nn.MaxPool2d(2, 2, return_indices=True)

        self.conv_block_3 = self._encoder_block(32, 64)
        self.mp_block_3 = nn.MaxPool2d(2, 2, return_indices=True)

        self.conv_block_4 = self._encoder_block(64, 128)
        self.mp_block_4 = nn.MaxPool2d(2, 2, return_indices=True)

        # ---- Bottleneck ---------------------------------------------------
        self.mamba = MambaCSSM(
            num_layers=4, d_model=self._mamba_d_model, d_conv=4, d_state=16
        )

        # ---- Decoder ------------------------------------------------------
        # Per level: unpool, concatenate the encoder skip (doubling the
        # channel count), fuse with a 3x3 conv, then refine with two
        # stride-1 transposed convs.
        self.mpu_block_4 = nn.MaxUnpool2d(2, 2)
        self.conv_4 = self._fuse_block(256, 128)
        self.deconv_4_block = self._decoder_block(128, 64, 64)

        self.mpu_block_3 = nn.MaxUnpool2d(2, 2)
        self.conv_3 = self._fuse_block(128, 64)
        self.deconv_3_block = self._decoder_block(64, 32, 32)

        self.mpu_block_2 = nn.MaxUnpool2d(2, 2)
        self.conv_2 = self._fuse_block(64, 32)
        self.deconv_2_block = self._decoder_block(32, 16, 16)

        self.mpu_block_1 = nn.MaxUnpool2d(2, 2)
        self.conv_1 = self._fuse_block(32, 16)
        self.deconv_1_block = self._decoder_block(16, 8, 6)

        # 1x1 projection to class scores. No activation follows it.
        self.conv_final = nn.Conv2d(6, output_classes, 1, 1)

    def forward(self, t1: torch.Tensor, t2: torch.Tensor) -> torch.Tensor:
        """Run the network on the input pair.

        Args:
            t1: First input, concatenated with ``t2`` along dim 1.
            t2: Second input. The channel counts of ``t1`` and ``t2`` must
                sum to 6, and their batch and spatial sizes must match.

        Returns:
            Tensor of shape ``(B, output_classes, H, W)`` holding the raw
            output of the final 1x1 convolution (no softmax/sigmoid).

        Raises:
            ValueError: If ``(H // 16) * (W // 16)`` does not equal the
                ``d_model`` (256) that ``MambaCSSM`` was built with.
        """
        t = torch.cat([t1, t2], dim=1)

        # Encoder. x* are kept as skip connections, i* as unpooling indices.
        x1 = self.conv_block_1(t)
        f1, i1 = self.mp_block_1(x1)
        x2 = self.conv_block_2(f1)
        f2, i2 = self.mp_block_2(x2)
        x3 = self.conv_block_3(f2)
        f3, i3 = self.mp_block_3(x3)
        x4 = self.conv_block_4(f3)
        f4, i4 = self.mp_block_4(x4)

        # Bottleneck. Each channel of each half becomes one token whose
        # features are the flattened h*w grid. Shapes are taken from the
        # tensor, not hard-coded, so a wrong input size fails loudly here.
        # Without this, a wrong size could be folded into the batch dim.
        b, c, h, w = f4.shape
        if h * w != self._mamba_d_model:
            raise ValueError(
                f"MambaCSSMUnet requires (H // 16) * (W // 16) == "
                f"{self._mamba_d_model} to match MambaCSSM's d_model "
                f"(e.g. a 256x256 input); got a {h}x{w} bottleneck from a "
                f"{t.shape[-2]}x{t.shape[-1]} input"
            )
        half = c // 2
        # NOTE(review): conv_block_1 already mixes all channels of t1 and t2,
        # so these channel halves are not per-input features despite the
        # *_t1 / *_t2 names. Confirm this split is the intended design.
        f4_t1 = f4[:, :half].reshape(b, half, h * w)
        f4_t2 = f4[:, half:].reshape(b, half, h * w)
        f5_t1, f5_t2 = self.mamba(f4_t1, f4_t2)
        # reshape (not view): the contiguity of MambaCSSM's outputs is not
        # guaranteed from here.
        f5_t1 = f5_t1.reshape(b, half, h, w)
        f5_t2 = f5_t2.reshape(b, half, h, w)
        f5 = torch.cat([f5_t1, f5_t2], dim=1)

        # Decoder. output_size pins every unpool to its skip tensor's exact
        # size. Max-pool floors odd sizes, and without this the unpooled map
        # could be one pixel short and fail the skip concatenation.
        f6 = self.mpu_block_4(f5, i4, output_size=x4.size())
        # Level 4 concatenates (skip, unpooled); the lower levels use
        # (unpooled, skip). The order is kept as-is because trained weights
        # depend on it.
        f7 = self.conv_4(torch.cat((x4, f6), dim=1))
        f8 = self.deconv_4_block(f7)

        f9 = self.mpu_block_3(f8, i3, output_size=x3.size())
        f10 = self.conv_3(torch.cat((f9, x3), dim=1))
        f11 = self.deconv_3_block(f10)

        f12 = self.mpu_block_2(f11, i2, output_size=x2.size())
        f13 = self.conv_2(torch.cat((f12, x2), dim=1))
        f14 = self.deconv_2_block(f13)

        f15 = self.mpu_block_1(f14, i1, output_size=x1.size())
        f16 = self.conv_1(torch.cat((f15, x1), dim=1))
        f17 = self.deconv_1_block(f16)
        f18 = self.conv_final(f17)

        return f18

    @staticmethod
    def _encoder_block(in_channels, out_channels):
        """Build two (3x3 conv -> BatchNorm -> ReLU) stages; H, W preserved."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, 3, 1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    @staticmethod
    def _fuse_block(in_channels, out_channels):
        """Build the 3x3 conv + ReLU that fuses an unpooled map with its skip."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, padding=1),
            nn.ReLU(),
        )

    @staticmethod
    def _decoder_block(in_channels, mid_channels, out_channels):
        """Build two stride-1 3x3 transposed convs with ReLU; H, W preserved."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, mid_channels, 3, 1, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(mid_channels, out_channels, 3, 1, padding=1),
            nn.ReLU(),
        )