Implementing Swin Transformer: Hierarchical Vision Transformer with Shifted Windows

Introduction

The Swin Transformer (Shifted Window Transformer) represents a breakthrough in computer vision, introducing a hierarchical vision transformer architecture that efficiently computes self-attention using shifted windows. Published in 2021, this paper has become foundational for modern vision models.

Paper Reference: Swin Transformer: Hierarchical Vision Transformer using Shifted Windows

In this article, we'll explore the key concepts behind Swin Transformer and provide a clean, educational PyTorch implementation from scratch.


Key Innovations

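Plain transformers are awkward for vision because visual entities vary widely in scale and images contain far more pixels than a sentence contains words. The Swin Transformer addresses these challenges with the following design choices: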

1. Hierarchical Architecture

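Unlike the original Vision Transformer (ViT), which keeps a single resolution of patch tokens through every layer, Swin Transformer builds a hierarchy of feature maps by merging neighboring patches as the network gets deeper, much as CNNs downsample between stages. This allows it to: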

  • Handle images at multiple scales
  • Serve as a general-purpose backbone for dense prediction tasks
  • Efficiently process high-resolution images

2. Shifted Window Mechanism

Instead of computing global self-attention (which is computationally expensive), Swin Transformer:

  • Computes attention within local windows
  • Uses shifted windows across layers to enable cross-window connections
  • Achieves linear computational complexity relative to image size

3. Efficiency

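The windowed attention mechanism reduces the cost of self-attention from quadratic to linear in the number of patches. With n = h × w patches and a fixed window size M, each token attends only to the M² tokens in its own window instead of all n tokens, so the attention term shrinks from O(n²) to O(M²·n), which is linear in n. This is what makes high-resolution inputs practical.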
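To make the scaling concrete, here is a small sketch that evaluates the complexity expressions the Swin paper gives for an h × w grid of patches with C channels: 4hwC² + 2(hw)²C for global multi-head self-attention (MSA) and 4hwC² + 2M²hwC for window attention (W-MSA) with M × M windows. The function names are our own, and, like the paper's expressions, the counts ignore the softmax.

```python
def msa_cost(h, w, C):
    """Global MSA: QKV/output projections (4hwC^2) plus attention over all hw tokens (2(hw)^2 C)."""
    return 4 * h * w * C**2 + 2 * (h * w) ** 2 * C


def wmsa_cost(h, w, C, M):
    """Window MSA: same projections, but each token only attends within its M x M window."""
    return 4 * h * w * C**2 + 2 * M**2 * h * w * C


# Stage-1 sized grids for Swin-T (C = 96, M = 7) at increasing input resolution
for side in (56, 112, 224):
    g = msa_cost(side, side, 96)
    l = wmsa_cost(side, side, 96, 7)
    print(f"{side}x{side} patches: global {g:.3e}, window {l:.3e}, ratio {g / l:.1f}x")
```

Doubling the side length quadruples the number of patches; the window cost grows by exactly that factor, while the global cost grows much faster because of its (hw)² term.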


Architecture Overview

The Swin Transformer consists of several key components:

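Input Image (H × W × 3)
    ↓
Patch Partition + Linear Embedding (4×4 patches → H/4 × W/4 × C)
    ↓
Stage 1: Swin Blocks (alternating regular / shifted window attention)
    ↓
Patch Merging (→ H/8 × W/8 × 2C)
    ↓
Stage 2: Swin Blocks
    ↓
Patch Merging (→ H/16 × W/16 × 4C)
    ↓
Stage 3: Swin Blocks
    ↓
Patch Merging (→ H/32 × W/32 × 8C)
    ↓
Stage 4: Swin Blocks
    ↓
Output Features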

Core Components:

  1. Patch Embedding: Splits the image into non-overlapping patches
  2. Window Partitioning: Divides feature maps into fixed-size windows
  3. Window Attention: Computes self-attention within each window
  4. Shifted Windows: Alternates between regular and shifted window partitions
  5. Patch Merging: Downsamples features between stages (like pooling in CNNs)
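The token grid and channel count entering each stage follow directly from the patch size and the 2× merging between stages. A small sketch (stage_shapes is a hypothetical helper; its defaults match the Swin-T configuration used later in this article):

```python
def stage_shapes(img_size=224, patch_size=4, embed_dim=96, num_stages=4):
    """Return (height, width, channels) of the token grid entering each stage."""
    side = img_size // patch_size  # patch embedding: H/4 x W/4 tokens
    shapes = []
    for i in range(num_stages):
        # each patch merging halves height and width and doubles channels
        shapes.append((side // 2**i, side // 2**i, embed_dim * 2**i))
    return shapes


for i, (h, w, c) in enumerate(stage_shapes(), start=1):
    print(f"Stage {i}: {h}x{w} tokens, {c} channels")
```

For a 224×224 input this gives 56×56, 28×28, 14×14 and 7×7 grids, i.e. strides of 4, 8, 16 and 32, the same feature-map strides a typical CNN backbone exposes to detection and segmentation heads.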

Implementation

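Let's build the Swin Transformer from scratch in PyTorch. This is an educational implementation that favors clarity over completeness: it leaves out the relative position bias, the attention mask for shifted windows, and the stochastic depth used in the paper.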

Step 1: Window Utilities

First, we need helper functions to partition feature maps into windows and reverse the operation:

import torch

def window_partition(x, window_size):
    """
    Partition feature map into non-overlapping windows.

    Args:
        x: (B, H, W, C) tensor
        window_size: Window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    return windows.view(-1, window_size, window_size, C)

def window_reverse(windows, window_size, H, W):
    """
    Reverse window partition back to feature map.

    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size: Window size
        H: Height of feature map
        W: Width of feature map

    Returns:
        x: (B, H, W, C)
    """
    B = windows.shape[0] // ((H // window_size) * (W // window_size))  # integer math, no float rounding
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    return x.view(B, H, W, -1)
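If the view/permute bookkeeping in window_partition is hard to follow, the same partition can be written with Tensor.unfold, which slices non-overlapping windows out of one dimension at a time. This is a sketch with a hypothetical name (window_partition_unfold); like the original, it assumes H and W are divisible by the window size.

```python
import torch


def window_partition_unfold(x, window_size):
    """Same result as window_partition, written with Tensor.unfold.

    x: (B, H, W, C) with H and W divisible by window_size.
    Returns: (num_windows * B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    # Cut windows along H, then along W: (B, H/ws, W/ws, C, ws, ws)
    x = x.unfold(1, window_size, window_size).unfold(2, window_size, window_size)
    # Move channels last: (B, H/ws, W/ws, ws, ws, C)
    x = x.permute(0, 1, 2, 4, 5, 3)
    return x.reshape(-1, window_size, window_size, C)


# A 4x4 single-channel map holding 0..15, split into 2x2 windows
grid = torch.arange(16).view(1, 4, 4, 1)
windows = window_partition_unfold(grid, 2)
print(windows.squeeze(-1))
```

The windows come out in row-major order over the window grid: top-left [[0, 1], [4, 5]], top-right [[2, 3], [6, 7]], then the two bottom windows, which is the ordering window_reverse expects.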

Step 2: Window-Based Multi-Head Self-Attention

The core building block is multi-head self-attention computed independently within each local window:

import torch
import torch.nn as nn

class WindowAttention(nn.Module):
    """
    Window-based multi-head self-attention module.
    """
    def __init__(self, dim, window_size, num_heads):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        # QKV projection
        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, mask=None):
        """
        Args:
            x: Input features (num_windows*B, N, C)
            mask: Attention mask for shifted windows (optional)
        """
        B_, N, C = x.shape

        # Generate Q, K, V
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B_, num_heads, N, head_dim)
        q, k, v = qkv.unbind(0)

        # Scaled dot-product attention
        attn = (q @ k.transpose(-2, -1)) * self.scale

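        if mask is not None:
            # mask has shape (num_windows, N, N) and is shared across images and
            # heads, so group the windows per image before adding it.
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)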

        attn = attn.softmax(dim=-1)

        # Apply attention to values
        out = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        return self.proj(out)
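The WindowAttention above carries no positional information. The paper adds a learnable relative position bias B to the logits, computing SoftMax(QKᵀ/√d + B)V, where the entries of B for an M × M window are looked up from a per-head table of (2M−1)² values, one per possible (dy, dx) offset. A minimal sketch of that lookup, with RelativePositionBias as a hypothetical module name and a truncated-normal initialization as an assumption:

```python
import torch
import torch.nn as nn


class RelativePositionBias(nn.Module):
    """Learnable relative position bias for an M x M window (sketch)."""

    def __init__(self, window_size, num_heads):
        super().__init__()
        M = window_size
        # One learnable scalar per head for each of the (2M-1)^2 relative offsets
        self.table = nn.Parameter(torch.zeros((2 * M - 1) ** 2, num_heads))
        nn.init.trunc_normal_(self.table, std=0.02)

        # (row, col) coordinates of every token in the window: (2, N), N = M*M
        coords = torch.stack(torch.meshgrid(torch.arange(M), torch.arange(M), indexing="ij"))
        coords = coords.flatten(1)
        # Pairwise offsets (2, N, N) in [-(M-1), M-1], shifted to [0, 2M-2]
        rel = coords[:, :, None] - coords[:, None, :]
        rel = rel.permute(1, 2, 0) + (M - 1)
        # Flatten (dy, dx) into a single index into the table: (N, N)
        index = rel[:, :, 0] * (2 * M - 1) + rel[:, :, 1]
        self.register_buffer("index", index, persistent=False)

    def forward(self):
        N = self.index.shape[0]
        bias = self.table[self.index.reshape(-1)].view(N, N, -1)  # (N, N, heads)
        return bias.permute(2, 0, 1).contiguous()  # (heads, N, N)


rpb = RelativePositionBias(window_size=7, num_heads=3)
print(rpb().shape)
```

To use it, WindowAttention could create `self.rel_pos_bias = RelativePositionBias(window_size, num_heads)` in `__init__` and add `self.rel_pos_bias().unsqueeze(0)` to `attn` right after the scaled dot product, before any mask. The table has only (2M−1)² rows per head, yet it covers all M² × M² token pairs because pairs with the same offset share a parameter.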

Step 3: Swin Transformer Block

Each block applies window attention (regular or shifted) followed by an MLP; both sub-layers use pre-LayerNorm and a residual connection:

import torch
import torch.nn as nn

class SwinTransformerBlock(nn.Module):
    """
    Swin Transformer Block with window-based or shifted window attention.
    """
    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution  # (H, W)
        self.window_size = window_size
        self.shift_size = shift_size
        self.num_heads = num_heads

        # Layer normalization and attention
        self.norm1 = nn.LayerNorm(dim)
        self.attn = WindowAttention(dim, window_size, num_heads)

        # MLP (Feed-forward network)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim)
        )

    def forward(self, x):
        """
        Args:
            x: Input features (B, H*W, C)
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "Input feature size doesn't match"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # Cyclic shift for shifted window attention
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # Partition into windows
        x_windows = window_partition(shifted_x, self.window_size)
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)

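        # Window attention. Note: no attention mask is passed, so in shifted
        # blocks, tokens that torch.roll wrapped around from the opposite edge
        # can attend to each other even though they are not spatial neighbors.
        attn_windows = self.attn(x_windows)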

        # Merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)

        # Reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x

        x = x.view(B, H * W, C)

        # Residual connection
        x = shortcut + x

        # MLP with residual
        x = x + self.mlp(self.norm2(x))

        return x
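The block above skips the attention mask that shifted windows need: after torch.roll, a window on the bottom or right edge contains patches that were far apart in the original image. One way to build that mask, following the paper's cyclic-shift-plus-masking idea, is to label the (up to nine) regions created by the shift and block attention between tokens with different labels. This is a sketch; shifted_window_mask and add_window_mask are hypothetical helper names, and the constant -100 stands in for negative infinity (softmax turns it into an effectively zero weight).

```python
import torch


def shifted_window_mask(H, W, window_size, shift_size):
    """Mask for shifted-window attention, shape (num_windows, N, N), N = window_size**2.

    0 where two tokens come from the same region before the cyclic shift,
    -100 otherwise.
    """
    img_mask = torch.zeros(1, H, W, 1)
    slices = (slice(0, -window_size),
              slice(-window_size, -shift_size),
              slice(-shift_size, None))
    region = 0
    for h in slices:
        for w in slices:
            img_mask[:, h, w, :] = region  # label each region produced by the roll
            region += 1

    # Partition the label map into windows: (num_windows, N)
    ws = window_size
    win = img_mask.view(1, H // ws, ws, W // ws, ws, 1).permute(0, 1, 3, 2, 4, 5)
    win = win.reshape(-1, ws * ws)

    # Tokens with different labels must not attend to each other
    diff = win.unsqueeze(1) - win.unsqueeze(2)  # (num_windows, N, N)
    return diff.masked_fill(diff != 0, -100.0).masked_fill(diff == 0, 0.0)


def add_window_mask(attn, mask):
    """Add a (nW, N, N) mask to attention logits of shape (B*nW, heads, N, N)."""
    nW, N, _ = mask.shape
    heads = attn.shape[1]
    attn = attn.view(-1, nW, heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
    return attn.view(-1, heads, N, N)


mask = shifted_window_mask(H=56, W=56, window_size=7, shift_size=3)
print(mask.shape)  # torch.Size([64, 49, 49])
```

Because the mask depends only on the resolution, window size and shift, a block could compute it once in `__init__` (for example with `register_buffer`) when `shift_size > 0` and apply it to the attention logits of each window the way add_window_mask does.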

Step 4: Patch Embedding

Convert the input image into patch embeddings:

import torch.nn as nn

class PatchEmbed(nn.Module):
    """
    Image to Patch Embedding using convolution.
    """
    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = (img_size // patch_size, img_size // patch_size)
        self.num_patches = self.patches_resolution[0] * self.patches_resolution[1]

        self.proj = nn.Conv2d(in_chans, embed_dim,
                             kernel_size=patch_size,
                             stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """
        Args:
            x: Input image (B, C, H, W)
        Returns:
            Patch embeddings (B, num_patches, embed_dim)
        """
        B, C, H, W = x.shape
        x = self.proj(x)  # (B, embed_dim, H/patch_size, W/patch_size)
        x = x.flatten(2).transpose(1, 2)  # (B, num_patches, embed_dim)
        x = self.norm(x)
        return x
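A strided convolution with kernel_size = stride = patch_size is exactly "cut the image into non-overlapping patches, flatten each one, and apply one shared linear layer", which is how the paper describes patch partition followed by linear embedding. The sketch below checks this numerically with F.unfold; the tensor sizes are arbitrary.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
patch_size, in_chans, embed_dim = 4, 3, 96
conv = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
x = torch.randn(2, in_chans, 32, 32)

# Path 1: strided convolution, as in PatchEmbed
out_conv = conv(x).flatten(2)  # (B, embed_dim, num_patches)

# Path 2: extract 4x4 patches explicitly, then apply one shared linear map
patches = F.unfold(x, kernel_size=patch_size, stride=patch_size)  # (B, in_chans*16, num_patches)
weight = conv.weight.view(embed_dim, -1)  # (embed_dim, in_chans*16)
out_linear = weight @ patches + conv.bias.view(1, -1, 1)

print(torch.allclose(out_conv, out_linear, atol=1e-4))
```

The convolution form is simply the more efficient way to write it: it avoids materializing the unfolded patch tensor.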

Step 5: Swin Stage

A stage consists of multiple Swin Transformer blocks with alternating window shifts:

import torch.nn as nn

class SwinStage(nn.Module):
    """
    A Swin Transformer stage with multiple blocks.
    """
    def __init__(self, dim, depth, num_heads, input_resolution, window_size=7):
        super().__init__()
        self.dim = dim
        self.depth = depth

        # Build blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(
                dim=dim,
                input_resolution=input_resolution,
                num_heads=num_heads,
                window_size=window_size,
                shift_size=0 if (i % 2 == 0) else window_size // 2
            )
            for i in range(depth)
        ])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x
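The alternating shift pattern can be written out per block. The official Swin code adds one refinement that SwinStage above does not: when the feature map is no larger than a window, the window shrinks to the map and the shift is disabled, since a single window already sees every token. A sketch of that schedule (block_schedule is a hypothetical helper):

```python
def block_schedule(input_resolution, window_size, depth):
    """(window_size, shift_size) for each block of one stage (sketch).

    Even-indexed blocks use regular windows, odd-indexed blocks shift by
    window_size // 2. If the map fits in one window, shifting is disabled.
    """
    ws = window_size
    shift = ws // 2
    if min(input_resolution) <= ws:
        ws, shift = min(input_resolution), 0
    return [(ws, 0 if i % 2 == 0 else shift) for i in range(depth)]


# Swin-T stages for a 224x224 input with window size 7
for res, depth in [((56, 56), 2), ((28, 28), 2), ((14, 14), 6), ((7, 7), 2)]:
    print(res, block_schedule(res, 7, depth))
```

For Swin-T at 224×224, only the last stage (7×7 tokens, exactly one window) is affected.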

Step 6: Patch Merging (Downsampling)

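Between stages, patch merging concatenates the features of each 2×2 group of neighboring patches (4C channels), normalizes them, and projects them linearly to 2C channels. This halves the spatial resolution and doubles the channel count: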

import torch
import torch.nn as nn

class PatchMerging(nn.Module):
    """
    Patch Merging Layer (downsampling by 2x).
    """
    def __init__(self, input_resolution, dim):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = nn.LayerNorm(4 * dim)

    def forward(self, x):
        """
        Args:
            x: Input features (B, H*W, C)
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "Input size doesn't match"

        x = x.view(B, H, W, C)

        # Gather the four pixels of every 2x2 neighborhood via stride-2 slices
        x0 = x[:, 0::2, 0::2, :]  # Top-left
        x1 = x[:, 1::2, 0::2, :]  # Bottom-left
        x2 = x[:, 0::2, 1::2, :]  # Top-right
        x3 = x[:, 1::2, 1::2, :]  # Bottom-right

        x = torch.cat([x0, x1, x2, x3], dim=-1)  # (B, H/2, W/2, 4*C)
        x = x.view(B, -1, 4 * C)

        x = self.norm(x)
        x = self.reduction(x)  # (B, H/2*W/2, 2*C)

        return x
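The gathering step in PatchMerging is a space-to-depth rearrangement, which PyTorch also provides as F.pixel_unshuffle, only with a different channel order. The sketch below (merge_via_pixel_unshuffle is a hypothetical name) reorders the pixel_unshuffle output so it matches the [x0, x1, x2, x3] concatenation above exactly, before the LayerNorm and linear reduction.

```python
import torch
import torch.nn.functional as F


def merge_via_pixel_unshuffle(x, H, W):
    """Same 4C features per merged token as PatchMerging's concatenation step.

    x: (B, H*W, C) tokens with H and W even. Returns (B, H/2 * W/2, 4C).
    """
    B, L, C = x.shape
    x = x.view(B, H, W, C).permute(0, 3, 1, 2)  # (B, C, H, W)
    x = F.pixel_unshuffle(x, 2)  # (B, 4C, H/2, W/2); channel index = c*4 + i*2 + j
    x = x.reshape(B, C, 2, 2, H // 2, W // 2)  # axes: (B, c, i, j, h, w)
    # PatchMerging orders channels as k*C + c with k = j*2 + i
    # (x0: i=0,j=0; x1: i=1,j=0; x2: i=0,j=1; x3: i=1,j=1)
    x = x.permute(0, 4, 5, 3, 2, 1)  # (B, h, w, j, i, c)
    return x.reshape(B, (H // 2) * (W // 2), 4 * C)


tokens = torch.arange(16.0).view(1, 16, 1)  # a 4x4 map with one channel
print(merge_via_pixel_unshuffle(tokens, 4, 4))
```

On the 4×4 example, the first merged token holds [0, 4, 1, 5]: top-left, bottom-left, top-right and bottom-right of the first 2×2 block. Because the reduction layer is learned, any fixed channel order would train equally well; matching the order here only makes the two formulations directly comparable.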

Step 7: Complete Swin Transformer

Put it all together:

import torch
import torch.nn as nn

class SwinTransformer(nn.Module):
    """
    Swin Transformer for image classification.
    """
    def __init__(
        self,
        img_size=224,
        patch_size=4,
        in_chans=3,
        num_classes=1000,
        embed_dim=96,
        depths=(2, 2, 6, 2),
        num_heads=(3, 6, 12, 24),
        window_size=7
    ):
        super().__init__()

        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))

        # Patch embedding
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim
        )

        patches_resolution = self.patch_embed.patches_resolution

        # Build stages
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = SwinStage(
                dim=int(embed_dim * 2 ** i_layer),
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                input_resolution=(
                    patches_resolution[0] // (2 ** i_layer),
                    patches_resolution[1] // (2 ** i_layer)
                ),
                window_size=window_size
            )
            self.layers.append(layer)

            # Add patch merging between stages (except last)
            if i_layer < self.num_layers - 1:
                self.layers.append(
                    PatchMerging(
                        input_resolution=(
                            patches_resolution[0] // (2 ** i_layer),
                            patches_resolution[1] // (2 ** i_layer)
                        ),
                        dim=int(embed_dim * 2 ** i_layer)
                    )
                )

        # Classification head
        self.norm = nn.LayerNorm(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.num_features, num_classes)

    def forward(self, x):
        """
        Args:
            x: Input images (B, 3, H, W)
        Returns:
            Class logits (B, num_classes)
        """
        x = self.patch_embed(x)

        for layer in self.layers:
            x = layer(x)

        x = self.norm(x)
        x = self.avgpool(x.transpose(1, 2))  # (B, C, 1)
        x = torch.flatten(x, 1)
        x = self.head(x)

        return x

Usage Example

Here's how to use the Swin Transformer:

import torch
from swin_transformer import SwinTransformer  # assumes the classes above are saved in swin_transformer.py

# Create model
model = SwinTransformer(
    img_size=224,
    patch_size=4,
    in_chans=3,
    num_classes=1000,
    embed_dim=96,
    depths=[2, 2, 6, 2],      # Swin-Tiny configuration
    num_heads=[3, 6, 12, 24],
    window_size=7
)

# Random input batch
x = torch.randn(1, 3, 224, 224)

# Forward pass (inference only, so no gradients are needed)
model.eval()
with torch.no_grad():
    output = model(x)
print(f"Output shape: {output.shape}")  # torch.Size([1, 1000])

# Count parameters
num_params = sum(p.numel() for p in model.parameters())
print(f"Number of parameters: {num_params:,}")

Model Variants

The paper introduces several model sizes:

Model     Embed Dim (C)   Depths          Heads             Params
Swin-T    96              [2, 2, 6, 2]    [3, 6, 12, 24]    28M
Swin-S    96              [2, 2, 18, 2]   [3, 6, 12, 24]    50M
Swin-B    128             [2, 2, 18, 2]   [4, 8, 16, 32]    88M
Swin-L    192             [2, 2, 18, 2]   [6, 12, 24, 48]   197M
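One regularity in this table: in every variant, the number of heads grows with the channel count so that each head always works on 32 channels. The sketch below stores the variants as keyword arguments (SWIN_VARIANTS and stage_dims are hypothetical names) and checks that property.

```python
# Variant settings from the table, as keyword arguments for SwinTransformer
SWIN_VARIANTS = {
    "swin_t": dict(embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24)),
    "swin_s": dict(embed_dim=96, depths=(2, 2, 18, 2), num_heads=(3, 6, 12, 24)),
    "swin_b": dict(embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32)),
    "swin_l": dict(embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48)),
}


def stage_dims(cfg):
    """Channels per stage and the resulting per-head dimension."""
    dims = [cfg["embed_dim"] * 2**i for i in range(len(cfg["depths"]))]
    head_dims = [d // h for d, h in zip(dims, cfg["num_heads"])]
    return dims, head_dims


for name, cfg in SWIN_VARIANTS.items():
    dims, head_dims = stage_dims(cfg)
    print(f"{name}: channels {dims}, head dim {head_dims}")
```

With the classes from this article, a variant would then be built as `SwinTransformer(**SWIN_VARIANTS["swin_b"])`.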

Key Takeaways

  1. Local + Global: Swin Transformer efficiently combines local window attention with cross-window connections via shifting.
  2. Hierarchical Design: Multi-scale features make it suitable for dense prediction tasks (detection, segmentation).
  3. Efficiency: Linear complexity w.r.t. image size enables processing of high-resolution images.
  4. Versatility: Can serve as a general-purpose backbone for various vision tasks.

Further Reading

To extend this implementation, consider:

  • Adding relative position bias for better spatial awareness
  • Implementing attention masking for shifted windows properly
  • Adding stochastic depth for regularization
  • Creating training scripts with ImageNet dataset
  • Implementing downstream tasks (object detection with Swin Transformer backbone)
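For the stochastic depth item, here is a minimal sketch. During training it drops a block's residual branch for whole samples and rescales the surviving samples so the expected output is unchanged; at evaluation time it is the identity. DropPath is our name here (libraries such as timm provide a similar layer), and the 0.2 maximum rate below is only illustrative.

```python
import torch
import torch.nn as nn


class DropPath(nn.Module):
    """Stochastic depth: drop the residual branch per sample during training (sketch)."""

    def __init__(self, drop_prob=0.0):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        if self.drop_prob == 0.0 or not self.training:
            return x
        keep_prob = 1.0 - self.drop_prob
        # One Bernoulli draw per sample, broadcast over all remaining dimensions
        shape = (x.shape[0],) + (1,) * (x.dim() - 1)
        mask = x.new_empty(shape).bernoulli_(keep_prob)
        return x * mask / keep_prob


# A common choice: increase the drop rate linearly from 0 at the first block
depths = (2, 2, 6, 2)
rates = torch.linspace(0, 0.2, steps=sum(depths)).tolist()  # 0.2 is illustrative
```

Inside SwinTransformerBlock, the residual updates would become `x = shortcut + self.drop_path(x)` and `x = x + self.drop_path(self.mlp(self.norm2(x)))`, with one rate from `rates` per block.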

Conclusion

The Swin Transformer represents a significant advancement in vision transformers, bridging the gap between CNNs and transformers through its hierarchical, efficient design. This implementation provides a foundation for understanding and experimenting with the architecture.

Whether you're building a classification model, an object detection system, or exploring new vision architectures, the Swin Transformer's design principles offer valuable insights into efficient attention mechanisms for computer vision.