Attention
Attention can be used in the form of self-attention to identify important positional or channel features. It can also be used to incorporate demographic data into the model.
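As an illustration, either module can be dropped into an ordinary convolutional classifier as a self-attention layer. The sketch below is hypothetical (the SmallAttentionNet class and its layer sizes are not part of NiftyTorch) and assumes pam follows the constructor and call signature documented under Positional Attention below.

```python
import torch
import torch.nn as nn
from niftytorch.attention.attention import pam

class SmallAttentionNet(nn.Module):
    # Hypothetical classifier that inserts positional self-attention
    # after a single convolutional block.
    def __init__(self, num_classes = 2):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size = 3, padding = 1),
            nn.ReLU(inplace = True),
        )
        self.attention = pam(in_shape = 64, reduction = 8)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(64, num_classes)

    def forward(self, x):
        x = self.features(x)
        x, _ = self.attention(x)      # keep the attended features, drop the attention map
        x = self.pool(x).flatten(1)
        return self.classifier(x)

model = SmallAttentionNet()
logits = model(torch.rand(8, 1, 32, 32))
```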
Modules
Positional Attention
The code below demonstrates how to use the Positional Attention Module (PAM).
The constructor for the Positional Attention Module, which attends to location-specific features by aggregating context.
Parameters for Constructor:
- in_shape (int, required): the number of channels in the input tensor for the PAM module
- reduction (int, default = 8): the factor by which the channel dimension is reduced before computing the attention
- query_conv_kernel (int, default = 1): the kernel size of the convolution applied to compute the query features
- key_conv_kernel (int, default = 1): the kernel size of the convolution applied to compute the key features
- value_conv_kernel (int, default = 1): the kernel size of the convolution applied to compute the value features
Usage:
import torch
from niftytorch.attention.attention import pam
PAM = pam(in_shape = 512, reduction = 8, query_conv_kernel = 3, key_conv_kernel = 3, value_conv_kernel = 3)
t = torch.rand(64, 512, 32, 32)
out, attention = PAM(t)
print(out.shape)
>>> torch.Size([64, 512, 32, 32])
print(attention.shape)
>>> torch.Size([64, 1024, 1024])
PAM returns two tensors: the first is the output of positional attention and the second is the attention map over spatial positions.
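For intuition, the following is a minimal sketch of how a position attention block of this kind (in the style of DANet) is typically implemented. It is an illustrative reimplementation, not the NiftyTorch source, and PositionAttentionSketch is a hypothetical name; kernel sizes are fixed to 1 for brevity.

```python
import torch
import torch.nn as nn

class PositionAttentionSketch(nn.Module):
    # Illustrative DANet-style position attention; not the NiftyTorch implementation.
    def __init__(self, in_channels, reduction = 8):
        super().__init__()
        self.query = nn.Conv2d(in_channels, in_channels // reduction, kernel_size = 1)
        self.key = nn.Conv2d(in_channels, in_channels // reduction, kernel_size = 1)
        self.value = nn.Conv2d(in_channels, in_channels, kernel_size = 1)
        self.gamma = nn.Parameter(torch.zeros(1))  # learnable weight on the residual branch

    def forward(self, x):
        b, c, h, w = x.shape
        q = self.query(x).view(b, -1, h * w).permute(0, 2, 1)   # B x HW x C/r
        k = self.key(x).view(b, -1, h * w)                      # B x C/r x HW
        attention = torch.softmax(torch.bmm(q, k), dim = -1)    # B x HW x HW
        v = self.value(x).view(b, -1, h * w)                    # B x C x HW
        out = torch.bmm(v, attention.permute(0, 2, 1)).view(b, c, h, w)
        return self.gamma * out + x, attention
```

With a 64 x 512 x 32 x 32 input this produces the same shapes as in the usage example above: a 64 x 512 x 32 x 32 output and a 64 x 1024 x 1024 attention map (32 * 32 = 1024 spatial positions).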
Channel Attention
The code below demonstrates how to use the Channel Attention Module (CAM).
The constructor for the Channel Attention Module, which attends to channel-specific features.
Parameters for Constructor:
- in_shape (int, required): the number of channels in the input tensor for the CAM module
Usage:
import torch
from niftytorch.attention.attention import cam
CAM = cam(in_shape = 512)
t = torch.rand(64, 512, 32, 32)
out, attention = CAM(t)
print(out.shape)
>>> torch.Size([64, 512, 32, 32])
print(attention.shape)
>>> torch.Size([64, 512, 512])
CAM returns two tensors: the first is the output of channel-wise attention and the second is the attention map over channels.
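Analogously, the sketch below shows how a DANet-style channel attention block is typically implemented; it is illustrative rather than the NiftyTorch source, and ChannelAttentionSketch is a hypothetical name.

```python
import torch
import torch.nn as nn

class ChannelAttentionSketch(nn.Module):
    # Illustrative DANet-style channel attention; not the NiftyTorch implementation.
    def __init__(self, in_channels):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1))  # learnable weight on the residual branch

    def forward(self, x):
        b, c, h, w = x.shape
        q = x.view(b, c, -1)                                  # B x C x HW
        k = x.view(b, c, -1).permute(0, 2, 1)                 # B x HW x C
        energy = torch.bmm(q, k)                              # B x C x C channel affinities
        # subtract from the row-wise max before the softmax, as in DANet, for stability
        energy = torch.max(energy, dim = -1, keepdim = True)[0].expand_as(energy) - energy
        attention = torch.softmax(energy, dim = -1)           # B x C x C
        v = x.view(b, c, -1)                                  # B x C x HW
        out = torch.bmm(attention, v).view(b, c, h, w)
        return self.gamma * out + x, attention
```

With a 64 x 512 x 32 x 32 input this yields a 64 x 512 x 32 x 32 output and a 64 x 512 x 512 channel attention map, matching the usage example above.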