diff --git a/src/attention_model.py b/src/attention_model.py
index 1616651..0446f9b 100644
--- a/src/attention_model.py
+++ b/src/attention_model.py
@@ -22,13 +22,16 @@ def dot_product_attention(
         q (torch.Tensor): The query tensor of shape [batch, heads, out_length, d_k].
         k (torch.Tensor): The key tensor of shape [batch, heads, out_length, d_k].
         v (torch.Tensor): The value-tensor of shape [batch, heads, out_length, d_v].
+        is_causal (bool): Whether to apply a causal mask.
 
     Returns:
         torch.Tensor: The attention values of shape [batch, heads, out_length, d_v]
     """
     # TODO implement multi head attention.
-    # Use i.e. torch.transpose, torch.sqrt, torch.tril, torch.exp, torch.inf
-    # as well as torch.nn.functional.softmax .
+    # Hint: You will likely need torch.transpose, torch.sqrt, torch.tril,
+    # torch.inf, and torch.nn.functional.softmax.
+    # For applying the causal mask, you can either try using torch.exp or torch.masked_fill.
+
     attention_out = None
 
     return attention_out
diff --git a/src/util.py b/src/util.py
index c0f8cda..3bef9bc 100644
--- a/src/util.py
+++ b/src/util.py
@@ -91,7 +91,7 @@ def convert(sequences: torch.Tensor, inv_vocab: dict) -> list:
    """Convert an array of character-integers to a list of letters.
 
     Args:
-        sequences (jnp.ndarray): An integer array, which represents characters.
+        sequences (torch.Tensor): An integer array, which represents characters.
         inv_vocab (dict): The dictonary with the integer to char mapping.
 
     Returns:
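For reference only (not part of the diff), below is a minimal sketch of how the completed `dot_product_attention` could look, following the hints in the updated comment and the shapes given in the docstring. It uses the `torch.masked_fill` route for the causal mask; the `torch.exp` trick mentioned in the hint is an equivalent alternative, and the exercise still expects learners to write their own version.

```python
import math

import torch


def dot_product_attention(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, is_causal: bool = False
) -> torch.Tensor:
    """Scaled dot-product attention with an optional causal mask.

    q, k: [batch, heads, out_length, d_k], v: [batch, heads, out_length, d_v].
    """
    d_k = q.shape[-1]
    # Attention scores, shape [batch, heads, out_length, out_length].
    scores = torch.matmul(q, torch.transpose(k, -2, -1)) / math.sqrt(d_k)
    if is_causal:
        out_length = q.shape[-2]
        # Lower-triangular mask: each query position may only attend to
        # itself and earlier positions.
        causal_mask = torch.tril(
            torch.ones(out_length, out_length, dtype=torch.bool, device=q.device)
        )
        # Setting masked-out scores to -inf makes their softmax weight zero.
        scores = scores.masked_fill(~causal_mask, -torch.inf)
    weights = torch.nn.functional.softmax(scores, dim=-1)
    return torch.matmul(weights, v)
```

Filling the scores with `-torch.inf` before the softmax is what zeroes out the weights on future positions, which is why the hint lists `torch.tril` and `torch.inf` together.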