add projection layer to roberta encoder
@parmeet committed Nov 24, 2021
commit db7b6dbe8eda1ec62f8c9601a7c280cf025d9d93
@@ -10,6 +10,7 @@

from .modules import (
    TransformerEncoder,
    ProjectionLayer,
)
import logging
logger = logging.getLogger(__name__)
@@ -25,6 +26,8 @@ class RobertaEncoderConf:
    num_attention_heads: int = 12
    num_encoder_layers: int = 12
    dropout: float = 0.1
    projection_dim: Optional[int] = None
    projection_dropout: Optional[float] = None
    scaling: Optional[float] = None
    normalize_before: bool = False

@@ -40,6 +43,8 @@ def __init__(
    num_attention_heads: int,
    num_encoder_layers: int,
    dropout: float = 0.1,
    projection_dim: Optional[int] = None,
    projection_dropout: Optional[float] = None,
    scaling: Optional[float] = None,
    normalize_before: bool = False,
):
@@ -62,6 +67,10 @@ def __init__(
        return_all_layers=False,
    )

    self.project = None
    if projection_dim is not None:
        self.project = ProjectionLayer(embed_dim=embedding_dim, projection_dim=projection_dim, dropout=projection_dropout)

@classmethod
def from_config(cls, config: RobertaEncoderConf):
    return cls(**asdict(config))
@@ -73,6 +82,10 @@ def forward(self, tokens: Tensor, mask: Optional[Tensor] = None) -> Tensor:
    output = output.transpose(1, 0)
    if mask is not None:
        output = output[mask.to(torch.bool), :]

    if self.project is not None:
        output = self.project(output)

    return output


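For context, a hedged usage sketch of the new config fields. The RobertaEncoder class name, the import path, and the defaults of the remaining RobertaEncoderConf fields are assumptions not shown in this diff; only RobertaEncoderConf, from_config, and the new projection_dim / projection_dropout fields appear in the hunks above:

import torch
# Hypothetical import path and encoder class name (assumptions for illustration).
from torchtext.models.roberta.model import RobertaEncoderConf, RobertaEncoder

# Setting projection_dim makes the encoder append Linear -> LayerNorm -> Dropout
# after the transformer, so the trailing feature dimension becomes projection_dim.
conf = RobertaEncoderConf(projection_dim=256, projection_dropout=0.1)
encoder = RobertaEncoder.from_config(conf)

tokens = torch.randint(0, 100, (2, 16))  # (batch, seq_len); token ids are illustrative
features = encoder(tokens)               # trailing dimension is now 256, not embedding_dim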
@@ -31,6 +31,27 @@ def _make_positions(self, tensor, pad_index: int):
    return torch.cumsum(masked, dim=1) * masked + pad_index


class ProjectionLayer(Module):
    """Projects encoder features to `projection_dim` via Linear -> LayerNorm -> optional Dropout."""

    def __init__(self,
                 embed_dim: int,
                 projection_dim: int,
                 dropout: Optional[float] = None) -> None:
        super().__init__()

        self.projection_layer = nn.Linear(embed_dim, projection_dim)
        self.norm_layer = nn.LayerNorm(projection_dim)
        if dropout is not None:
            self.dropout_layer = nn.Dropout(dropout)
        else:
            self.dropout_layer = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.projection_layer(x)
        x = self.norm_layer(x)
        x = self.dropout_layer(x)
        return x


class ResidualMLP(Module):
    def __init__(
        self,
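To make the effect of the new layer concrete, here is a minimal, self-contained sketch that mirrors the ProjectionLayer composition above (Linear -> LayerNorm -> optional Dropout); the dimensions below are illustrative, not taken from the diff:

import torch
from torch import nn

embedding_dim, projection_dim = 768, 256  # illustrative sizes

# Same composition as ProjectionLayer; the diff falls back to nn.Identity
# when dropout is None.
project = nn.Sequential(
    nn.Linear(embedding_dim, projection_dim),
    nn.LayerNorm(projection_dim),
    nn.Dropout(0.1),
)

encoder_output = torch.randn(2, 16, embedding_dim)  # (batch, seq_len, embedding_dim)
projected = project(encoder_output)
print(projected.shape)  # torch.Size([2, 16, 256])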