1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84
@MODELS.register_module()
class MaxSigmoidAttnBlock(BaseModule):
    """Max Sigmoid attention block.

    Projects ``x`` with ``project_conv`` and re-weights each head's group of
    output channels by a spatial gate. At every pixel the gate is the
    sigmoid of the maximum, over guide tokens, of the scaled per-head dot
    product between the pixel embedding and the guide embedding. A learnable
    per-head bias is added before the sigmoid, and the result is optionally
    multiplied by a learnable per-head scale.

    Args:
        in_channels (int): Channels of the input feature map ``x``.
        out_channels (int): Channels produced by ``project_conv``; must be
            divisible by ``num_heads``.
        guide_channels (int): Size of the last dimension of ``guide`` (the
            input width of ``guide_fc``).
        embed_channels (int): Embedding width that both ``x`` and ``guide``
            are mapped to before the similarity; must be divisible by
            ``num_heads``.
        kernel_size (int): Kernel size of ``project_conv``. Defaults to 3.
        padding (int): Padding of ``project_conv``. Defaults to 1.
        num_heads (int): Number of attention heads. Defaults to 1.
        use_depthwise (bool): Build ``project_conv`` as a
            ``DepthwiseSeparableConvModule`` instead of a ``ConvModule``.
            ``embed_conv`` is always a ``ConvModule``.
        with_scale (bool): Learn a per-head multiplier for the gate.
            Otherwise the gate is multiplied by the constant ``1.0``.
        conv_cfg: Forwarded to ``embed_conv`` and ``project_conv``.
        norm_cfg: Forwarded to ``embed_conv`` and ``project_conv``.
        init_cfg: Forwarded to ``BaseModule``.
        use_einsum (bool): Compute similarities with ``torch.einsum``. If
            False, an equivalent permute + ``torch.matmul`` path is used.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 guide_channels: int,
                 embed_channels: int,
                 kernel_size: int = 3,
                 padding: int = 1,
                 num_heads: int = 1,
                 use_depthwise: bool = False,
                 with_scale: bool = False,
                 conv_cfg: OptConfigType = None,
                 # NOTE(review): shared default dict; safe only as long as
                 # the conv modules do not mutate it -- confirm.
                 norm_cfg: ConfigType = dict(type='BN',
                                             momentum=0.03,
                                             eps=0.001),
                 init_cfg: OptMultiConfig = None,
                 use_einsum: bool = True) -> None:
        super().__init__(init_cfg=init_cfg)
        if use_depthwise:
            proj_conv_type = DepthwiseSeparableConvModule
        else:
            proj_conv_type = ConvModule

        assert (out_channels % num_heads == 0
                and embed_channels % num_heads == 0), \
            'out_channels and embed_channels should be divisible by num_heads.'
        self.num_heads = num_heads
        self.head_channels = embed_channels // num_heads
        self.use_einsum = use_einsum

        # A 1x1 conv (no activation) is only built when x's channel count
        # differs from the embedding width; otherwise x is used as-is.
        if embed_channels != in_channels:
            self.embed_conv = ConvModule(
                in_channels,
                embed_channels,
                1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=None)
        else:
            self.embed_conv = None
        self.guide_fc = Linear(guide_channels, embed_channels)
        self.bias = nn.Parameter(torch.zeros(num_heads))
        self.scale = (nn.Parameter(torch.ones(1, num_heads, 1, 1))
                      if with_scale else 1.0)

        self.project_conv = proj_conv_type(
            in_channels,
            out_channels,
            kernel_size,
            stride=1,
            padding=padding,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=None)

    def _similarity(self, img_emb: Tensor, guide_emb: Tensor) -> Tensor:
        """Per-head dot products between pixel and guide embeddings.

        Args:
            img_emb (Tensor): Shape ``(B, heads, C, H, W)``.
            guide_emb (Tensor): Shape ``(B, N, heads, C)``.

        Returns:
            Tensor: Shape ``(B, heads, H, W, N)``.
        """
        if self.use_einsum:
            return torch.einsum('bmchw,bnmc->bmhwn', img_emb, guide_emb)

        # Same contraction written as a batched matrix product.
        b, m, c, h, w = img_emb.shape
        n = guide_emb.shape[1]
        pixels = img_emb.permute(0, 1, 3, 4, 2).reshape(b, m, -1, c)
        tokens = guide_emb.permute(0, 2, 3, 1)  # (B, heads, C, N)
        return torch.matmul(pixels, tokens).reshape(b, m, h, w, n)

    def forward(self, x: Tensor, guide: Tensor) -> Tensor:
        """Gate the projected feature map with guide-driven attention.

        Args:
            x (Tensor): Feature map, unpacked as ``(B, C, H, W)``.
            guide (Tensor): Guide features whose last dimension is
                ``guide_channels``; presumably ``(B, N, guide_channels)``
                -- TODO confirm against callers.

        Returns:
            Tensor: ``(B, out_channels, H, W)``, assuming ``project_conv``
            preserves the spatial size (it does with the defaults
            ``kernel_size=3, padding=1``).
        """
        batch, _, height, width = x.shape
        heads, head_dim = self.num_heads, self.head_channels

        guide_emb = self.guide_fc(guide).reshape(batch, -1, heads, head_dim)
        img_emb = x if self.embed_conv is None else self.embed_conv(x)
        img_emb = img_emb.reshape(batch, heads, head_dim, height, width)

        # Keep the strongest guide response per pixel and head.
        logits = self._similarity(img_emb, guide_emb).max(dim=-1).values
        logits = logits / (head_dim**0.5)
        logits = logits + self.bias[None, :, None, None]
        gate = logits.sigmoid() * self.scale

        projected = self.project_conv(x).reshape(batch, heads, -1, height,
                                                 width)
        return (projected * gate.unsqueeze(2)).flatten(1, 2)