-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmodules.py
More file actions
185 lines (139 loc) · 5.6 KB
/
modules.py
File metadata and controls
185 lines (139 loc) · 5.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
import torch
import torch.nn as nn
from typing import Any
class SEModule(nn.Module):
    """Squeeze-and-Excitation block (channel attention).

    Re-weights feature-map channels with a learned, input-dependent gate:
    global-average-pool (squeeze) -> bottleneck MLP -> sigmoid (excite)
    -> channel-wise scaling of the input.

    Args:
        channels: number of input (and output) channels.
        ratio: reduction ratio of the excitation bottleneck (default 8).
    """

    def __init__(self, channels: int, ratio: int = 8) -> None:
        super().__init__()
        # Squeeze: collapse each channel's spatial map to one value.
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        # max(..., 1) guards against a zero-width hidden layer when
        # channels < ratio (the original built nn.Linear(channels, 0)).
        hidden = max(channels // ratio, 1)
        # Excitation: bottleneck MLP emitting one gate in (0, 1) per channel.
        self.fc = nn.Sequential(
            nn.Linear(channels, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channels),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Scale each channel of ``x`` by its gate; shape is preserved."""
        b, c, _, _ = x.size()
        y = self.avgpool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y
class ASPPModule(nn.Module):
    """Atrous Spatial Pyramid Pooling (DeepLab-style) with SE gating.

    Branches: a 1x1 conv, one 3x3 atrous conv per entry in ``dilations``,
    and a global-average-pooling "image-level" branch. The concatenated
    branch outputs are fused by a final 1x1 conv.

    NOTE(review): a single BatchNorm2d / SEModule instance is shared by
    every branch, so running stats and weights are mixed across branches.
    Kept as-is so existing checkpoints still load — confirm intended.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of each branch and of the fused output.
        dilations: dilation rates for the atrous 3x3 branches.
    """

    def __init__(self, in_channels: int, out_channels: int, dilations: list[int]) -> None:
        super().__init__()
        # One 3x3 atrous conv per requested dilation rate.
        self.atrous_convs = nn.ModuleList(
            nn.Conv2d(
                in_channels, out_channels, kernel_size=3,
                dilation=d, padding="same", bias=False,
            )
            for d in dilations
        )
        # Shared post-conv operations (see NOTE in class docstring).
        self.batch_norm = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.squeeze_excite = SEModule(channels=out_channels)
        self.dropout = nn.Dropout(p=0.5)
        # Image-level branch: true global pooling instead of the original
        # AvgPool2d((16, 16)) / UpsamplingBilinear2d(scale_factor=16) pair,
        # which silently assumed a 16x16 input. The result is identical for
        # 16x16 inputs and now correct for any spatial size; neither module
        # holds parameters, so checkpoints are unaffected.
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        # 1x1 convolution branch (also reused for the pooled branch).
        self.conv1x1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=1, padding="same", bias=False
        )
        # Fuse the (len(dilations) + 2) concatenated branches.
        self.final_conv = nn.Conv2d(
            in_channels=out_channels * (len(dilations) + 2),
            out_channels=out_channels,
            kernel_size=1,
            padding="same",
            bias=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run all ASPP branches on ``x`` and fuse them; spatial size kept."""
        # 1x1 convolution branch.
        x1 = self.conv1x1(x)
        x1 = self.batch_norm(x1)
        x1 = self.dropout(x1)
        x1 = self.relu(x1)
        x1 = self.squeeze_excite(x1)

        # Atrous branches (spatial size preserved via padding="same").
        atrous_outputs = []
        for at_conv in self.atrous_convs:
            at_output = at_conv(x)
            at_output = self.batch_norm(at_output)
            at_output = self.relu(at_output)
            at_output = self.squeeze_excite(at_output)
            atrous_outputs.append(at_output)

        # Image-level branch: global context broadcast back to x's size.
        # align_corners=True matches the original nn.UpsamplingBilinear2d.
        avg_pool = self.avgpool(x)
        avg_pool = self.conv1x1(avg_pool)
        avg_pool = self.batch_norm(avg_pool)
        avg_pool = self.relu(avg_pool)
        avg_pool = nn.functional.interpolate(
            avg_pool, size=x.shape[-2:], mode="bilinear", align_corners=True
        )
        avg_pool = self.squeeze_excite(avg_pool)

        # Concatenate every branch along channels and fuse with 1x1 conv.
        combined_output = torch.cat((x1, *atrous_outputs, avg_pool), dim=1)
        aspp_output = self.final_conv(combined_output)
        aspp_output = self.batch_norm(aspp_output)
        aspp_output = self.relu(aspp_output)
        aspp_output = self.squeeze_excite(aspp_output)
        return aspp_output
class DecoderModule(nn.Module):
    """DeepLabV3+-style decoder: fuses high-level (ASPP) features with
    low-level backbone features.

    The low-level map is projected to 48 channels, concatenated with the
    high-level map, and refined by two 3x3 convolutions.

    NOTE(review): the concat width is hard-coded to 304 = 256 + 48, so
    ``x_high`` must carry 256 channels — confirm against the caller.

    Args:
        in_channels: channels of the low-level feature map.
        out_channels: channels of the decoder output.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        # SE blocks sized for each stage they gate.
        self.squeeze_excite = SEModule(channels=304)
        self.squeeze_excite2 = SEModule(channels=out_channels)
        self.squeeze_excite3 = SEModule(channels=48)
        # Project low-level features to 48 channels.
        self.conv_low = nn.Conv2d(in_channels, 48, kernel_size=1, padding="same", bias=False)
        self.batch_norm = nn.BatchNorm2d(48)
        self.batch_norm2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.5)
        # First 3x3 refinement conv. BUGFIX: emit ``out_channels`` (was a
        # hard-coded 256), matching batch_norm2 / squeeze_excite2 above;
        # the original failed for any out_channels != 256.
        self.final_conv1 = nn.Conv2d(
            in_channels=304, out_channels=out_channels, kernel_size=3,
            padding="same", bias=False,
        )
        # Second 3x3 refinement conv. BUGFIX: consumes ``out_channels``
        # (the output of final_conv1), not the low-level ``in_channels``;
        # the original failed for any in_channels != 256.
        self.final_conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, padding="same", bias=False
        )

    def forward(self, x_high: torch.Tensor, x_low: torch.Tensor) -> torch.Tensor:
        """Fuse ``x_high`` (256 ch) with ``x_low``; both must share spatial size."""
        # Project + normalize the low-level features.
        x_low = self.conv_low(x_low)
        x_low = self.batch_norm(x_low)
        x_low = self.dropout(x_low)
        x_low = self.relu(x_low)
        x_low = self.squeeze_excite3(x_low)
        # Concatenate high- and low-level features along channels.
        x = torch.cat((x_high, x_low), dim=1)
        x = self.dropout(x)
        x = self.squeeze_excite(x)
        # First 3x3 refinement stage.
        x = self.final_conv1(x)
        x = self.batch_norm2(x)
        x = self.relu(x)
        x = self.squeeze_excite2(x)
        # Second 3x3 refinement stage.
        x = self.final_conv2(x)
        x = self.batch_norm2(x)
        x = self.relu(x)
        x = self.squeeze_excite2(x)
        return x