NNLM

toynlp.nnlm.model.NNLM

NNLM(config: ModelConfig)

Bases: Module

PARAMETER DESCRIPTION
config

ModelConfig, the model configuration.

TYPE: ModelConfig
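
This class corresponds to the classic neural probabilistic language model of Bengio et al. (2003). Given the concatenation x of the embeddings of the context_size - 1 context tokens, the output logits are

    y = b + W x + U tanh(d + H x)

where the direct-connection term W x is only added when with_direct_connection is enabled; the symbol names match the attributes defined in __init__ below.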

Source code in toynlp/nnlm/model.py
def __init__(
    self,
    config: ModelConfig,
):
    """
    The Neural Network Language Model (NNLM).

    Args:
        config: ModelConfig, the model configuration.
    """
    super().__init__()
    self.with_direct_connection = config.with_direct_connection
    self.with_dropout = config.with_dropout
    # Embedding layer C: |V| x m
    self.C = torch.nn.Embedding(config.vocab_size, config.embedding_dim)
    # H: hidden-layer weights, (context_size - 1) * m -> h
    self.H = torch.nn.Linear(
        config.embedding_dim * (config.context_size - 1),
        config.hidden_dim,
        bias=False,
    )
    # d: hidden-layer bias
    self.d = torch.nn.Parameter(torch.zeros(config.hidden_dim))
    # U: hidden-to-output weights, h -> |V|
    self.U = torch.nn.Linear(config.hidden_dim, config.vocab_size, bias=False)
    self.activation = torch.nn.Tanh()

    # b: output bias
    self.b = torch.nn.Parameter(torch.zeros(config.vocab_size))
    # W: direct (skip) connection from the concatenated embeddings to the output
    self.W = torch.nn.Linear(
        config.embedding_dim * (config.context_size - 1),
        config.vocab_size,
        bias=False,
    )

    self.dropout = torch.nn.Dropout(config.dropout_rate)
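
A minimal instantiation sketch. The field names are those read from config in __init__ above; the import path for ModelConfig, its keyword-argument constructor, and the concrete values are assumptions for illustration only:

from toynlp.nnlm.config import ModelConfig  # assumed import path
from toynlp.nnlp.model import NNLM if False else None  # placeholder guard, see next line
from toynlp.nnlm.model import NNLM

config = ModelConfig(              # assuming a dataclass-style constructor
    vocab_size=10_000,
    embedding_dim=100,
    context_size=6,                # the model consumes context_size - 1 tokens
    hidden_dim=60,
    dropout_rate=0.1,
    with_direct_connection=False,
    with_dropout=True,
)
model = NNLM(config)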

forward

forward(tokens: Tensor) -> Tensor

Forward pass of the model.

PARAMETER DESCRIPTION
tokens

torch.Tensor, (batch_size, seq_len-1), the input tokens.

TYPE: Tensor

RETURNS DESCRIPTION
Tensor

torch.Tensor, (batch_size, vocab_size), the logits.

Source code in toynlp/nnlm/model.py
def forward(self, tokens: torch.Tensor) -> torch.Tensor:
    """
    Forward pass of the model.

    Args:
        tokens: torch.Tensor, (batch_size, seq_len-1), the input tokens.

    Returns:
        torch.Tensor, (batch_size, vocab_size), the logits.

    """
    # tokens: (batch_size, seq_len-1) -> x: (batch_size, seq_len-1, embedding_dim)
    x = self.C(tokens)
    batch_size = x.shape[0]
    # flatten the context embeddings: (batch_size, embedding_dim * (seq_len-1))
    x = x.reshape(batch_size, -1)
    if self.with_dropout:
        x = self.dropout(x)
    # hidden path: b + U * tanh(d + Hx) -> (batch_size, vocab_size)
    x1 = self.b + self.U(
        self.activation(self.H(x) + self.d)
    )  # no direct connection
    if not self.with_direct_connection:
        x = x1
    else:
        # add the direct (skip) connection W from the embeddings to the output
        x2 = self.W(x)
        x = x1 + x2
    # return logits
    return x
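
A minimal usage sketch of the forward pass, assuming the model and config from the instantiation example above. Note that seq_len must equal context_size so that the flattened embeddings match the input size of H:

import torch

# a dummy batch of 4 contexts, each holding context_size - 1 token ids
tokens = torch.randint(0, config.vocab_size, (4, config.context_size - 1))
logits = model(tokens)                   # (4, vocab_size)
probs = torch.softmax(logits, dim=-1)    # next-token distribution per example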