271 changes: 195 additions & 76 deletions src/python/entity_align/model/AlignCNN.py
@@ -18,94 +18,185 @@
import torch.nn as nn
from torch.nn import BCEWithLogitsLoss
from torch.autograd import Variable
import os


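# Moves a tensor to the GPU only when CUDA_VISIBLE_DEVICES is set, so the
# same code also runs unchanged on CPU-only machines.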
def cuda(x):
return x.cuda() if "CUDA_VISIBLE_DEVICES" in os.environ else x


# This model corresponds to AlignCNN in our paper
# First, strings converted to list of character embeddings
# Then, lstm runs over character embeddings
# lstm embeddings at last time stamp matrix multiplied
# Finally, cnn detects features in that matrix and outputs similarity score
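# Roughly, for a pair of strings (a, b):
#   A = LSTM(embed(a)), B = LSTM(embed(b))       # batch x max_string_len x hidden
#   M = mask(A @ B^T)                            # batch x max_string_len x max_string_len
#   score = sum(align_weights * maxpool(CNN(M))) # batch x 1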
class AlignCNN(torch.nn.Module):
def __init__(self, config, vocab):
super(AlignCNN, self).__init__()

self.config = config
self.vocab = vocab

# Character embeddings
self.embedding = nn.Embedding(
vocab.size + 1, config.embedding_dim, padding_idx=0
)

# Sequence encoder of strings (LSTM)
self.rnn = nn.LSTM(
config.embedding_dim,
config.rnn_hidden_size,
1,
bidirectional=config.bidirectional,
batch_first=True,
)

if self.config.bidirectional:
self.num_directions = 2
else:
self.num_directions = 1

# Variables for initial states of LSTM (these are different for train and dev because dev might be of different batch sizes)
self.h0 = Variable(
cuda(
torch.zeros(
self.num_directions, config.batch_size, config.rnn_hidden_size
)
),
requires_grad=False,
)
self.c0 = Variable(
cuda(
torch.zeros(
self.num_directions, config.batch_size, config.rnn_hidden_size
)
),
requires_grad=False,
)
self.h0_dev = Variable(
cuda(
torch.zeros(
self.num_directions, config.dev_batch_size, config.rnn_hidden_size
)
),
requires_grad=False,
)
self.c0_dev = Variable(
cuda(
torch.zeros(
self.num_directions, config.dev_batch_size, config.rnn_hidden_size
)
),
requires_grad=False,
)

# Define the CNN used to score the alignment matrix
pool_output_height = int(np.floor(config.max_string_len / 2.0))
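        # A single 2x2 max-pool halves each spatial dimension, so the learned
        # align_weights cover a pool_output_height x pool_output_height grid.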

# Select # of layers / increasing or decreasing filter size based on config
if config.num_layers == 4:
self.num_layers = 4
self.relu = nn.ReLU()
if config.increasing == True:
convlyr = nn.Conv2d(1, config.filter_count, 3, padding=1, stride=1)
convlyr2 = nn.Conv2d(
config.filter_count, config.filter_count2, 5, padding=2, stride=1
)
convlyr3 = nn.Conv2d(
config.filter_count2, config.filter_count3, 5, padding=2, stride=1
)
convlyr4 = nn.Conv2d(
config.filter_count3, config.filter_count4, 7, padding=3, stride=1
)
else:
convlyr = nn.Conv2d(1, config.filter_count, 7, padding=3, stride=1)
convlyr2 = nn.Conv2d(
config.filter_count, config.filter_count2, 5, padding=2, stride=1
)
convlyr3 = nn.Conv2d(
config.filter_count2, config.filter_count3, 5, padding=2, stride=1
)
convlyr4 = nn.Conv2d(
config.filter_count3, config.filter_count4, 3, padding=1, stride=1
)
self.add_module("cnn2", convlyr2)
self.add_module("cnn3", convlyr3)
self.add_module("cnn4", convlyr4)
self.align_weights = nn.Parameter(
cuda(
torch.randn(
config.filter_count3, pool_output_height, pool_output_height
)
),
requires_grad=True,
)
elif config.num_layers == 3:
self.num_layers = 3
self.relu = nn.ReLU()
if config.increasing == True:
convlyr = nn.Conv2d(1, config.filter_count, 5, padding=2, stride=1)
convlyr2 = nn.Conv2d(
config.filter_count, config.filter_count2, 5, padding=2, stride=1
)
convlyr3 = nn.Conv2d(
config.filter_count2, config.filter_count3, 7, padding=3, stride=1
)
else:
convlyr = nn.Conv2d(1, config.filter_count, 7, padding=3, stride=1)
convlyr2 = nn.Conv2d(
config.filter_count, config.filter_count2, 5, padding=2, stride=1
)
convlyr3 = nn.Conv2d(
config.filter_count2, config.filter_count3, 5, padding=2, stride=1
)
self.add_module("cnn2", convlyr2)
self.add_module("cnn3", convlyr3)
self.align_weights = nn.Parameter(
cuda(
torch.randn(
config.filter_count3, pool_output_height, pool_output_height
)
),
requires_grad=True,
)
elif config.num_layers == 2:
self.num_layers = 2
self.relu = nn.ReLU()
convlyr = nn.Conv2d(1, config.filter_count, 5, padding=2, stride=1)
convlyr2 = nn.Conv2d(
config.filter_count, config.filter_count2, 3, padding=1, stride=1
)
self.add_module("cnn2", convlyr2)
self.align_weights = nn.Parameter(
cuda(
torch.randn(
config.filter_count2, pool_output_height, pool_output_height
)
),
requires_grad=True,
)
else:
self.num_layers = 1
convlyr = nn.Conv2d(1, config.filter_count, 7, padding=3, stride=1)
self.align_weights = nn.Parameter(
cuda(
torch.randn(
config.filter_count, pool_output_height, pool_output_height
)
),
requires_grad=True,
)
self.add_module("cnn", convlyr)
# Define pooling
self.pool = nn.MaxPool2d((2, 2), stride=2)

# Vector of ones (used for loss)
self.ones = Variable(cuda(torch.ones(config.batch_size, 1)))

# Loss
self.loss = BCEWithLogitsLoss()

def compute_loss(self, source, pos, neg, source_len, pos_len, neg_len):
""" Compute the loss (BPR) for a batch of examples
:param source: Entity mentions
:param pos: True aliases of the Mentions
@@ -115,13 +206,14 @@ def compute_loss(self,source,pos,neg, source_len,pos_len,neg_len):
:param neg_len: lengths of negatives
:return:
"""
source_embed, src_mask = self.embed(source, source_len)
pos_embed, pos_mask = self.embed(pos, pos_len)
neg_embed, neg_mask = self.embed(neg, neg_len)
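        # BPR-style ranking loss: BCEWithLogitsLoss against a target of 1 pushes
        # score(source, pos) - score(source, neg) to be positive.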
loss = self.loss(
self.score_pair_train(source_embed, pos_embed, src_mask, pos_mask)
- self.score_pair_train(source_embed, neg_embed, src_mask, neg_mask),
self.ones,
)

return loss

@@ -136,22 +228,22 @@ def print_mm(self, src, tgt, src_len, tgt_len):
"""
source_embed, source_mask = self.embed_dev(src, src_len)
target_embed, target_mask = self.embed_dev(tgt, tgt_len)
return torch.bmm(source_embed, torch.transpose(target_embed, 2, 1))

def score_pair_train(self, src, tgt, src_mask, tgt_mask):
"""
        :param src: Batch_size by Max_String_Length by num_directions*rnn_hidden_size (LSTM-encoded source strings)
        :param tgt: Batch_size by Max_String_Length by num_directions*rnn_hidden_size (LSTM-encoded target strings)
:param src_mask: Batchsize by Max_String_Length, binary mask corresponding to length of underlying str
:param tgt_mask: Batchsize by Max_String_Length, binary mask corresponding to length of underlying str
:return: Batchsize by 1
"""
multpld = torch.bmm(src, torch.transpose(tgt, 2, 1))
src_mask = src_mask.unsqueeze(dim=2)
tgt_mask = tgt_mask.unsqueeze(dim=1)
mat_mask = torch.bmm(src_mask, tgt_mask)
multpld = torch.mul(multpld, mat_mask)
convd = self.cnn(multpld.unsqueeze(1)) # need num channels
if self.num_layers > 1:
convd = self.relu(convd)
convd = self.cnn2(convd)
@@ -162,63 +254,90 @@ def score_pair_train(self,src,tgt, src_mask, tgt_mask):
convd = self.relu(convd)
convd = self.cnn4(convd)
convd_after_pooling = self.pool(convd)
# print(convd_after_pooling.size())
# print(self.align_weights.size())
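        # Weight each pooled feature map by align_weights, then sum over both
        # spatial dimensions and over the channels to get one score per pair.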
output = torch.sum(
self.align_weights.expand_as(convd_after_pooling) * convd_after_pooling,
dim=3,
keepdim=True,
)
output = torch.sum(output, dim=2, keepdim=True)
output = torch.squeeze(output, dim=3)
output = torch.squeeze(output, dim=2)
output = torch.sum(output, dim=1, keepdim=True)

return output

def embed(self, string_mat, string_len):
"""
:param string_mat: Batch_size by max_string_len
        :return: (batch_size by max_string_len by num_directions*rnn_hidden_size LSTM states, batch_size by max_string_len float padding mask)
"""
string_mat = cuda(torch.from_numpy(string_mat))
        mask = Variable((string_mat > 0).float())
embed_token = self.embedding(Variable(string_mat))
final_emb, final_hn_cn = self.rnn(embed_token, (self.h0, self.c0))
return final_emb, mask

def embed_dev(self, string_mat, string_len, print_embed=False, batch_size=None):
"""
:param string_mat: Batch_size by max_string_len
        :return: (batch_size by max_string_len by num_directions*rnn_hidden_size LSTM states, batch_size by max_string_len float padding mask)
"""
string_mat = cuda(torch.from_numpy(string_mat))
        mask = Variable((string_mat > 0).float())
if not batch_size:
this_batch_size = self.config.dev_batch_size
this_h0 = self.h0_dev
this_c0 = self.c0_dev
else:
print("irregular batch size {}".format(batch_size))
this_batch_size = batch_size
this_h0 = Variable(
cuda(
torch.zeros(
self.num_directions, batch_size, self.config.rnn_hidden_size
)
),
requires_grad=False,
)
this_c0 = Variable(
cuda(
torch.zeros(
self.num_directions, batch_size, self.config.rnn_hidden_size
)
),
requires_grad=False,
)
embed_token = self.embedding(Variable(string_mat))
if print_embed == True:
return embed_token
final_emb, final_hn_cn = self.rnn(embed_token, (this_h0, this_c0))
return final_emb, mask

def score_dev_test_batch(
self,
batch_queries,
batch_query_lengths,
batch_targets,
batch_target_lengths,
batch_size,
    ):
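        # Regular dev batches reuse the pre-allocated h0_dev/c0_dev initial states;
        # a ragged final batch allocates fresh initial states of the right size.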
if batch_size == self.config.dev_batch_size:
source_embed, source_mask = self.embed_dev(
batch_queries, batch_query_lengths
)
target_embed, target_mask = self.embed_dev(
batch_targets, batch_target_lengths
)
else:
source_embed, source_mask = self.embed_dev(
batch_queries, batch_query_lengths, batch_size=batch_size
)
target_embed, target_mask = self.embed_dev(
batch_targets, batch_target_lengths, batch_size=batch_size
)
scores = self.score_pair_train(
source_embed, target_embed, source_mask, target_mask
)
return scores
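# Usage sketch (hypothetical; `config` and `vocab` are whatever the project's
# config and vocabulary loaders provide, and the shapes below are inferred from this file):
#   model = AlignCNN(config, vocab)
#   # source/pos/neg: zero-padded numpy int matrices of shape
#   # (config.batch_size, config.max_string_len); the *_len arrays hold true string lengths
#   loss = model.compute_loss(source, pos, neg, source_len, pos_len, neg_len)
#   loss.backward()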