-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtransformer.py
More file actions
114 lines (97 loc) · 4.95 KB
/
transformer.py
File metadata and controls
114 lines (97 loc) · 4.95 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import numpy as np
from numpy import einsum
import torch
class Encoder:
    """Minimal from-scratch Transformer encoder (post-norm, scaled dot-product).

    Pipeline: tokenize -> embed -> n_layers of (multi-head self-attention,
    position-wise feed-forward), each sub-layer wrapped in a residual
    connection followed by layer normalization.

    NOTE: embeddings are random and no positional encoding is added despite
    the step comments below — this is a toy, untrained implementation.
    """
    def __init__(self, text, vocab_size, d_model, n_heads, d_ff, n_layers, dropout):
        # text: raw input string to encode.
        # vocab_size / dropout: stored but unused by this toy implementation;
        # kept for interface compatibility.
        self.text = text
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_ff = d_ff
        self.n_layers = n_layers
        self.dropout = dropout
        self.__init_params()
    def __init_params(self):
        """Initialize all projection/FFN weights in FP32 with Xavier scaling."""
        d_k = d_v = self.d_model // self.n_heads
        def xavier(*shape, fan_in, fan_out):
            # Glorot init: var = 2 / (fan_in + fan_out).
            return torch.tensor(
                np.random.randn(*shape) * np.sqrt(2 / (fan_in + fan_out)),
                dtype=torch.float32)
        # Per-layer, per-head projections: n_layers x n_heads x d_model x d_k
        self.Wk = xavier(self.n_layers, self.n_heads, self.d_model, d_k,
                         fan_in=self.d_model, fan_out=self.d_model)
        self.Wq = xavier(self.n_layers, self.n_heads, self.d_model, d_k,
                         fan_in=self.d_model, fan_out=self.d_model)
        self.Wv = xavier(self.n_layers, self.n_heads, self.d_model, d_v,
                         fan_in=self.d_model, fan_out=self.d_model)
        # Output projection after head concatenation: d_model x d_model
        self.Wo = xavier(self.n_layers, self.d_model, self.d_model,
                         fan_in=self.d_model, fan_out=self.d_model)
        # Position-wise FFN weights and biases (biases stored as column vectors).
        self.W1 = xavier(self.n_layers, self.d_model, self.d_ff,
                         fan_in=self.d_model, fan_out=self.d_ff)
        self.W2 = xavier(self.n_layers, self.d_ff, self.d_model,
                         fan_in=self.d_ff, fan_out=self.d_model)
        self.b1 = xavier(self.n_layers, self.d_ff, 1,
                         fan_in=self.d_model, fan_out=self.d_ff)
        self.b2 = xavier(self.n_layers, self.d_model, 1,
                         fan_in=self.d_ff, fan_out=self.d_model)
    # encoder steps
    # 1) tokenize text
    # 2) create embedding for each token
    # 2.1) add positional encoding
    # 3) pass through n_layers of encoder
    # 3.1) multi-head attention
    # 3.2) feed forward network
    # 4) return encoded text seq_len x d_model
    def forward(self):
        """Run the full encoder pipeline; returns a seq_len x d_model tensor."""
        print("Forward pass:")
        tokens = self.tokenize(self.text)          # 1) tokenize text
        embeddings = self.embedding(tokens)        # 2) embed each token
        encoded_text = self.encoder(embeddings)    # 3) n_layers of encoding
        return encoded_text
    def tokenize(self, text):
        """Whitespace tokenization; returns a list of token strings."""
        return text.split()
    def embedding(self, tokens):
        """Random len(tokens) x d_model embeddings (no trained table, no
        positional encoding — toy implementation)."""
        return torch.rand(len(tokens), self.d_model)
    def encoder(self, embeddings):
        """Apply each encoder layer in sequence, using that layer's weights."""
        for i in range(self.n_layers):
            print(f" Layer {i} ")
            embeddings = self.encoder_layer(embeddings,
                            self.Wk[i], self.Wq[i], self.Wv[i], self.Wo[i],
                            self.W1[i], self.b1[i], self.W2[i], self.b2[i])
        return embeddings
    def encoder_layer(self, embeddings, Wk, Wq, Wv, Wo, W1, b1, W2, b2):
        """One encoder block: LayerNorm(x + MHA(x)) then LayerNorm(x + FFN(x)).

        BUG FIX: the residual must add the sub-layer INPUT to its output and
        normalize the sum (post-norm Transformer); the original discarded the
        input and computed y + layer_norm(y) instead.
        """
        attn_out = self.multi_head_attention(embeddings, Wk, Wq, Wv, Wo)
        embeddings = self.layer_norm(embeddings + attn_out)
        ffn_out = self.feed_forward_network(embeddings, W1, b1, W2, b2)
        embeddings = self.layer_norm(embeddings + ffn_out)
        return embeddings
    def multi_head_attention(self, embeddings, Wk, Wq, Wv, Wo):
        """Run each head's attention, concatenate, and project through Wo."""
        heads = [self.attention(embeddings, Wk[i], Wq[i], Wv[i])
                 for i in range(self.n_heads)]
        # Concatenate along the feature dim back to d_model, then project.
        return torch.einsum('ik,kl->il', torch.cat(heads, dim=1), Wo)
    # parallelize attention computation
    # A(K,Q,V) = softmax(QK^T / sqrt(d_k))V
    # K = X * Wk
    # Q = X * Wq
    # V = X * Wv
    def attention(self, embeddings, Wk, Wq, Wv):
        """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V."""
        K = torch.einsum('ik,kl->il', embeddings, Wk)   # seq x d_k
        Q = torch.einsum('ik,kl->il', embeddings, Wq)   # seq x d_k
        V = torch.einsum('ik,kl->il', embeddings, Wv)   # seq x d_v
        # BUG FIX: scale by sqrt(d_k) per the formula above, not sqrt(d_model);
        # also avoids torch.sqrt on an integer tensor, which raises at runtime.
        d_k = Wk.shape[-1]
        scores = torch.einsum('ik,jk->ij', Q, K) / (d_k ** 0.5)   # Q @ K^T
        return torch.einsum('ij,jk->ik', torch.softmax(scores, dim=1), V)
    def feed_forward_network(self, embeddings, W1, b1, W2, b2):
        """Position-wise FFN: FFN(X) = ReLU(X W1 + b1^T) W2 + b2^T.

        b1/b2 are stored as column vectors, so transpose to broadcast over rows.
        """
        hidden = torch.relu(torch.einsum('ik,kl->il', embeddings, W1) + b1.T)
        return torch.einsum('ik,kl->il', hidden, W2) + b2.T
    def layer_norm(self, embeddings, eps=1e-5):
        """Normalize each token vector over the feature dim.

        BUG FIX: the original normalized with the mean/std of the WHOLE tensor;
        layer norm is per row. eps guards against a near-zero std.
        """
        mean = embeddings.mean(dim=-1, keepdim=True)
        std = embeddings.std(dim=-1, keepdim=True)
        return (embeddings - mean) / (std + eps)
def main():
    """Demo entry point: run one encoder forward pass on a short string."""
    Encoder("Hello, World!", 100, 512, 8, 2048, 6, 0.1).forward()


# BUG FIX: guard the entry point so importing this module does not
# trigger a forward pass as a side effect.
if __name__ == "__main__":
    main()