트랜스포머 pytorch 코드분석 (Attention is All You Need)

2024. 8. 13. 01:58머신러닝&딥러닝/Transformers

728x90

드디어 대망의... AI 안 하는 사람도 한번쯤은 들어봤다는 "그 논문"이다. 본 포스팅에서는 Attention is All You Need (1) 에서 제안한, 자연어와 컴퓨터비전 등 AI 분야에 막대한 영향을 끼친 트랜스포머 아키텍처를 pytorch로 구현한 코드를 분석하고자 한다. 코드는 아래 영상을 참고하였다. 논문에 대한 이론적인 이해는 어느 정도 되었다고 가정하고 코드를 분석하려고 하니, 논문 내용을 처음 보는 사람은 먼저 읽고 오길 바란다.
 
(1) A. Vaswani, Attention is All You Need, NIPS, 2017
 
https://www.youtube.com/watch?v=U0s0f995w14&t=1724s

 
 

전체 구조

  1. SelfAttention 클래스 정의
  2. TransformerBlock 클래스 정의
  3. Encoder 클래스 정의
  4. DecoderBlock 클래스 정의
  5. Decoder 클래스 정의
  6. Transformer 클래스 정의

1. SelfAttention

이 논문에서 제안한 self attention 구조는 Scaled Dot-Prouct Attention 이라고도 하는데, 수식으로는 아래와 같다.

Q와 K의 관계를 이용하여 attention map을 찍고, embedding dimension(dk)을 이용하여 scaling한 후 V를 곱하는 것이다.
 

class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size/heads
        
        assert (self.head_dim * heads == embed_size), "Embed size needs to be divided by heads"

        self.values = nn.Linear(self.head_dim, self.head_dim, bias = False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias = False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias = False)
        self.fc_out = nn.Linear(heads*self.head_dim, embed_size)

    def forward(self, values, keys, query, mask):
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]
        
        values = values.reshape(N, value_len ,self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)


        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)
        energy = torch.einsum('nqhd,nkhd->nhqk', [queries, keys])
        # queries : (N, query_len, heads, head_dim)
        # keys_shape : (N, key_len, heads, head_dim)
        # energy_shape : (N, heads, query_len, key_len)
        if mask is not None:
            energy = energy.masked_fill(mask==0, float('-1e20'))
        attention = torch.softmax(energy/(self.embed_size**(1/2)), dim=3)

        out = torch.einsum('nhql, nlhd -> nqhd', [attention, values]).reshape(
            N, query_len, self.heads*self.head_dim
        )
        # attention shape : (N, heads, query_len, key_len)
        # values shape : (N, value len, heads, head_dim)

        out = self.fc_out(out)
        return out

 
SelfAttention은 인자로 embed_size 와 heads를 갖는다. heads는 말 그대로, Multi-head attention에 사용되는 Self attention head의 수이다.

위 그림의 h가 곧 인자 heads에 해당하는 것이다. Multi-head attention에서는 각 head의 self attention 연산 결과가 아래와 같이 concat되는데, 이때 head 하나당 head_dim개의 차원을 갖는다고 하면 concat 이후의 차원은 head_dim * heads가 될 것이다. 이것이 embed_size이다.

 
따라서 embed_size는 반드시 head_dim * heads와 같아야 함을 assert문을 통해 선언하였다. 

또한, q, k, v는 self attention 연산 전 linear layer를 거쳐 들어간다. H개의 SelfAttention head가 있을 때, 각 head는 전체 q, k, v를 H조각으로 쪼개어 보게 되고, 쪼갠 조각의 크기는 head_dim이다. 따라서, linear layer는 input과 output 모두 head_dim의 길이를 갖는다.
 
SelfAttention의 forward 함수에 들어가는 q, k, v는 전체 MultiHead에 들어가는 큰 덩어리이다. 예를 들어, 10개의 문장이 input으로 들어온다면, N = 10이 될 것이고 이는 batch size와 유사하게 이해할 수 있다. 또한, 문장의 길이에 해당하는 값이 query_len이다. 각 q, k, v는 reshape 함수를 이용해 (N, L, H, D) 형태의 4차원 텐서로 바뀌고, linear layer를 통과한다.
 
위 코드에서 energy로 정의된 값은, q와 k의 곱이다. 이를 구현하기 위해 torch의 강력한 텐서 곱 메소드인 einsum이 사용되었다.

energy = torch.einsum('nqhd,nkhd->nhqk', [queries, keys])

 
위 식은 q와 k의 4차원 형태에서, head_dim을 축으로 하여 텐서 곱을 수행한다는 것을 나타낸다.
 
mask는 energy의 일부를 마스킹하는 역할을 한다. Mask가 없으면 energy에 마스크가 씌워지지 않은 채 scaling 단계로 넘어간다.
 
Scaling된 energy는 attention map에 해당한다. 여기에 v를 곱하기 위해 또 다시 einsum을 적용하였다.

out = torch.einsum('nhql, nlhd -> nqhd', [attention, values]).reshape(
            N, query_len, self.heads*self.head_dim
        )

einsum 연산의 결과는 nqhd의 4차원 텐서인데, 이는 input 당시 query 텐서의 모양과 같다. 여기서 마지막 두 차원은 h와 d를 합쳐, embed_dim으로 만들기 위해 3차원으로 reshape한다. 이 텐서에 fc layer를 적용하면 SelfAttention 블럭이 완성된다.
 

2. TransformerBlock

TransformerBlock은 SelfAttention 에 FFN을 합치고 Norm을 적용한 것이라 할 수 있다. 따라서 위에서 정의한 SelfAttention을 그대로 사용하면 되기에 비교적 간단하다.

class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super(TransformerBlock, self).__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion*embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion*embed_size, embed_size)
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        attention = self.attention(value, key, query, mask)
        x = self.dropout(self.norm1(attention + query))
        forward = self.feed_forward(x)
        out = self.dropout(self.norm2(forward+x))

 
FFN(Feed Forward Network)은 은닉층 한 개를 가진 네트워크이다. 입력층과 출력층의 크기는 모두 embed_size이며 은닉층의 크기는 forward_expansion이라는 상수로 조절된다. 즉, embed_size의 몇 배가 은닉층의 크기가 될 것인지를 결정한다.
 
norm이 적용되는 이유는 skip connection을 보정하기 위해서이다. TransformerBlock에는 두 번의 skip connection이 적용되는데, attention 연산 결과에 query가 다시 더해져 x가 만들어지는 부분과 FFN 이후 x를 다시 더하는 부분이다. 이때 각각 LayerNorm이 적용된다.
 

3. Encoder

Encoder 블럭은 input으로 들어온 데이터를 임베딩하고, N번의 TransformerBlock을 적용하는 과정을 거친다.

class Encoder(nn.Module):
    def __init__(
            self,
            src_vocab_size,
            embed_size,
            num_layers,
            heads,
            device,
            forward_expansion,
            dropout,
            max_length,
    ):
        super(Encoder, self).__init__()
        self.embed_size = embed_size
        self.device = device
        self.word_embedding = nn.Embedding(src_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList(
            [
                TransformerBlock(
                    embed_size,
                    heads,
                    dropout=dropout,
                    forward_expansion = forward_expansion,
                )
                for _ in range(num_layers)
            ]
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        N, seq_length = x.shape
        positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)

        out = self.dropout(self.word_embedding(x) + self.position_embedding(positions))

        for layer in self.layers:
            out = layer(out, out, out, mask)

        return out

 
Input으로 들어온 자연어를 임베딩하는 데에는 nn.Embedding이 쓰인다. 이는 각 단어에 해당하는 embedding vector를 딕셔너리 형태로 만들어주는 함수이다. embedding vector의 길이는 embed_size이고, 사전에 들어갈 단어의 개수는 src_vocab_size이다. 또한 positional embedding도 nn.Embedding을 사용하여, 들어오는 문장 중 최대 길이(max_length)에 맞추어 위치 임베딩값을 결정하였다. num_layers는 TransformerBlock을 반복할 횟수를 결정한다.
 

positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)

 
positional encoding을 위해, 0, 1, ..., seq_length에 해당하는 숫자 배열을 만들고 N개만큼 expand하여 nn.Embedding을 적용하였다.

 
이는 논문 본문에서 위와 같이 사용했던 sinusoidal positional encoding과는 조금 다른 구현이다.
 

4. DecoderBlock

디코더는 입력 시퀀스를 받아 원하는 형태의 출력 시퀀스를 만드는 역할을 한다.

class DecoderBlock(nn.Module):
    def __init__(self, embed_size, heads, forward_expansion, dropout, device):
        super(DecoderBlock, self).__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm = nn.LayerNorm(embed_size)
        self.transformer_block = TransformerBlock(
            embed_size, heads, dropout, forward_expansion
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, value, key, src_mask, trg_mask):
        attention = self.attention(x, x, x, trg_mask)
        query = self.dropout(self.norm(attention + x))
        out = self.transformer_block(value, key, query, src_mask)
        return out

 
DecoderBlock에는 src_mask와 trg_mask의 두 가지 종류 마스크가 등장한다. src_mask는 의미 없는 패딩 토큰을 마스킹하거나, 토큰에 중요도를 부과하기 위해 사용된다. trg_mask는 디코더가 미래의 단어를 참조하지 못하도록 하는 look-ahead mask 등에 사용된다.
 
DecoderBlock은 타겟 시퀀스 x를 입력받아 trg_mask를 적용한 self attention을 계산하고, skip connection을 더한 값을 새로운 query로 사용한다. 또한 encoder에서의 출력값을 받아서 key와 value를 사용한다. 이렇게 얻은 q, k, v를 TransformerBlock에 통과시켜 DecoderBlock의 최종 output을 얻는다.

5. Decoder

타겟 시퀀스 x를 encoder에서와 마찬가지로 word embedding + positional embedding하여 텐서로 만들고, DecoderBlock을 적용시킬 횟수 num_layers를 이용하여 Decoder를 구성한다.

class Decoder(nn.Module):
    def __init__(
            self,
            trg_vocab_size,
            embed_size,
            num_layers,
            heads,
            forward_expansion,
            dropout,
            device,
            max_length
    ):
        super(Decoder, self).__init__()
        self.device = device
        self.word_embedding = nn.Embedding(trg_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList(
            [
                DecoderBlock(embed_size, heads, forward_expansion, dropout, device) for _ in range(num_layers)
            ]
        )
        self.fc_out = nn.Linear(embed_size, trg_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_out, src_mask, trg_mask):
        N, seq_length = x.shape
        positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)
        x = self.dropout((self.word_embedding(x) + self.position_embedding(positions)))

        for layer in self.layers:
            x = layer(x, enc_out, enc_out, src_mask, trg_mask)

        out = self.fc_out(x)
        return out

6. Transformer

이제, Encoder와 Decoder를 조합하여 전체 트랜스포머 구조를 완성하자.

class Transformer(nn.Module):
    def __init__(
            self,
            src_vocab_size,
            trg_vocab_size,
            src_pad_idx,
            trg_pad_idx,
            embed_size = 256,
            num_layers = 6,
            forward_expansion = 4,
            heads = 8,
            dropout = 0,
            device = 'cuda',
            max_length = 100
    ):
        super(Transformer, self).__init__()

        self.encoder = Encoder(
            src_vocab_size,
            embed_size,
            num_layers,
            heads,
            device,
            forward_expansion,
            dropout,
            max_length
        )
        self.decoder = Decoder(
            src_vocab_size,
            embed_size,
            num_layers,
            heads,
            device,
            forward_expansion,
            dropout,
            max_length
        )

        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device
    
    def make_src_mask(self, src):
        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
        #(N, 1, 1, src_len)
        return src_mask.to(self.device)
    
    def make_trg_mask(self, trg):
        N, trg_len = trg.shape
        trg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(
            N, 1, trg_len, trg_len
        )
        return trg_mask.to(self.device)
    
    def forward(self, src, trg):
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)
        enc_src = self.encoder(src, src_mask)
        out = self.decoder(trg, enc_src, src_mask, trg_mask)
        return out

 
Transformer 클래스에는 mask를 만드는 두 함수 make_src_mask, make_trg_mask 두 개가 추가로 들어간다. src_mask는 패딩 영역을 마스킹하기 위해 쓰이고, trg_mask는 look-ahead mask 기능을 위해 torch.tril을 이용하여 텐서의 오른쪽 위 삼각형에 해당하는 영역은 모두 마스킹하였다.

반응형