YAML Metadata Warning:empty or missing yaml metadata in repo card

Check out the documentation for more information.

import math

import torch import torch.nn as nn import torch.nn.functional as F from ufal.chu_liu_edmonds import chu_liu_edmonds from transformers import AutoModel

class GEGLU(nn.Module): def forward(self, x): x, gate = x.chunk(2, dim=-1) return x * F.gelu(gate)

class Classifier(nn.Module): def init(self, hidden_size, vocab_size, dropout): super().init()

    self.transform = nn.Sequential(
        nn.Linear(hidden_size, 2 * 2560),
        GEGLU(),
        nn.LayerNorm(2560, elementwise_affine=False),
        nn.Linear(2560, hidden_size, bias=False),
        nn.Dropout(dropout),
        nn.Linear(hidden_size, vocab_size)
    )
    self.initialize(hidden_size)

def initialize(self, hidden_size):
    std = math.sqrt(2.0 / (5.0 * hidden_size))
    nn.init.trunc_normal_(self.transform[0].weight, mean=0.0, std=std, a=-2 * std, b=2 * std)
    nn.init.trunc_normal_(self.transform[-1].weight, mean=0.0, std=std, a=-2 * std, b=2 * std)
    self.transform[0].bias.data.zero_()
    self.transform[-1].bias.data.zero_()

def forward(self, x):
    return self.transform(x)

class ZeroClassifier(nn.Module): def forward(self, x): output = torch.zeros(x.size(0), x.size(1), 2, device=x.device, dtype=x.dtype) output[:, :, 0] = 1.0 output[:, :, 1] = -1.0 return output

class EdgeClassifier(nn.Module): def init(self, hidden_size, dep_hidden_size, vocab_size, dropout): super().init()

    self.head_dep_transform = nn.Sequential(
        nn.Linear(hidden_size, 2 * 2560),
        GEGLU(),
        nn.LayerNorm(2560, elementwise_affine=False),
        nn.Linear(2560, hidden_size, bias=False),
        nn.Dropout(dropout)
    )
    self.head_root_transform = nn.Sequential(
        nn.Linear(hidden_size, 2 * 2560),
        GEGLU(),
        nn.LayerNorm(2560, elementwise_affine=False),
        nn.Linear(2560, hidden_size, bias=False),
        nn.Dropout(dropout)
    )
    self.head_bilinear = nn.Parameter(torch.zeros(hidden_size, hidden_size))
    self.head_linear_dep = nn.Linear(hidden_size, 1, bias=False)
    self.head_linear_root = nn.Linear(hidden_size, 1, bias=False)

    self.dep_dep_transform = nn.Sequential(
        nn.Linear(hidden_size, 2 * 2560),
        GEGLU(),
        nn.LayerNorm(2560, elementwise_affine=False),
        nn.Linear(2560, dep_hidden_size, bias=False),
        nn.Dropout(dropout)
    )
    self.dep_root_transform = nn.Sequential(
        nn.Linear(hidden_size, 2 * 2560),
        GEGLU(),
        nn.LayerNorm(2560, elementwise_affine=False),
        nn.Linear(2560, dep_hidden_size, bias=False),
        nn.Dropout(dropout)
    )
    self.dep_bilinear = nn.Parameter(torch.zeros(dep_hidden_size, dep_hidden_size, vocab_size))
    self.dep_linear_dep = nn.Linear(dep_hidden_size, vocab_size, bias=False)
    self.dep_linear_root = nn.Linear(dep_hidden_size, vocab_size, bias=False)
    self.dep_bias = nn.Parameter(torch.zeros(vocab_size))

    self.hidden_size = hidden_size
    self.dep_hidden_size = dep_hidden_size

    self.mask_value = float("-inf")
    self.initialize(hidden_size)

def initialize(self, hidden_size):
    std = math.sqrt(2.0 / (5.0 * hidden_size))
    nn.init.trunc_normal_(self.head_dep_transform[0].weight, mean=0.0, std=std, a=-2 * std, b=2 * std)
    nn.init.trunc_normal_(self.head_root_transform[0].weight, mean=0.0, std=std, a=-2 * std, b=2 * std)
    nn.init.trunc_normal_(self.dep_dep_transform[0].weight, mean=0.0, std=std, a=-2 * std, b=2 * std)
    nn.init.trunc_normal_(self.dep_root_transform[0].weight, mean=0.0, std=std, a=-2 * std, b=2 * std)

    nn.init.trunc_normal_(self.head_linear_dep.weight, mean=0.0, std=std, a=-2 * std, b=2 * std)
    nn.init.trunc_normal_(self.head_linear_root.weight, mean=0.0, std=std, a=-2 * std, b=2 * std)
    nn.init.trunc_normal_(self.dep_linear_dep.weight, mean=0.0, std=std, a=-2 * std, b=2 * std)
    nn.init.trunc_normal_(self.dep_linear_root.weight, mean=0.0, std=std, a=-2 * std, b=2 * std)

    self.head_dep_transform[0].bias.data.zero_()
    self.head_root_transform[0].bias.data.zero_()
    self.dep_dep_transform[0].bias.data.zero_()
    self.dep_root_transform[0].bias.data.zero_()

def forward(self, head_x, dep_x, lengths, head_gold=None):
    head_dep = self.head_dep_transform(head_x[:, 1:, :])
    head_root = self.head_root_transform(head_x)
    head_prediction = torch.einsum("bkn,nm,blm->bkl", head_dep, self.head_bilinear,
                                   head_root / math.sqrt(self.hidden_size)) \
                      + self.head_linear_dep(head_dep) + self.head_linear_root(head_root).transpose(1, 2)

    mask = (torch.arange(head_x.size(1), device=head_x.device).unsqueeze(0) >= lengths.unsqueeze(1)).unsqueeze(1)
    mask = mask | (torch.ones(head_x.size(1) - 1, head_x.size(1), dtype=torch.bool, device=head_x.device).tril(
        1) & torch.ones(head_x.size(1) - 1, head_x.size(1), dtype=torch.bool, device=head_x.device).triu(1))
    head_prediction = head_prediction.masked_fill(mask, self.mask_value)

    if head_gold is None:
        head_logp = head_prediction
        head_logp = F.pad(head_logp, (0, 0, 1, 0), value=torch.nan).cpu()
        head_gold = []
        for i, length in enumerate(lengths.tolist()):
            head = self.max_spanning_tree(head_logp[i, :length, :length])
            head = head + ((head_x.size(1) - 1) - len(head)) * [0]
            head_gold.append(torch.tensor(head))
        head_gold = torch.stack(head_gold).to(head_x.device)

    dep_dep = self.dep_dep_transform(dep_x[:, 1:])
    dep_root = dep_x.gather(1, head_gold.unsqueeze(-1).expand(-1, -1, dep_x.size(-1)).clamp(min=0))
    dep_root = self.dep_root_transform(dep_root)
    dep_prediction = torch.einsum("btm,mnl,btn->btl", dep_dep, self.dep_bilinear,
                                  dep_root / math.sqrt(self.dep_hidden_size)) \
                     + self.dep_linear_dep(dep_dep) + self.dep_linear_root(dep_root) + self.dep_bias

    return head_prediction, dep_prediction, head_gold

def max_spanning_tree(self, weight_matrix):
    weight_matrix = weight_matrix.clone()
    weight_matrix[weight_matrix == self.mask_value] = torch.nan
    # weight_matrix[:, 0] = torch.nan

    # we need to make sure that the root is the parent of a single node
    # first, we try to use the default weights, it should work in most cases
    parents, _ = chu_liu_edmonds(weight_matrix.numpy().astype(float))

    assert parents[0] == -1, f"{parents}\n{weight_matrix}"
    parents = parents[1:]

    # check if the root is the parent of a single node
    if parents.count(0) == 1:
        return parents

    # if not, we need to modify the weights and try all possibilities
    # we try to find the node that is the parent of the root
    best_score = float("-inf")
    best_parents = None

    for i in range(len(parents)):
        weight_matrix_mod = weight_matrix.clone()
        weight_matrix_mod[:i + 1, 0] = torch.nan
        weight_matrix_mod[i + 2:, 0] = torch.nan
        parents, score = chu_liu_edmonds(weight_matrix_mod.numpy().astype(float))
        parents = parents[1:]

        if score > best_score:
            best_score = score
            best_parents = parents

    def print_whole_matrix(matrix):
        for i in range(matrix.shape[0]):
            print(" ".join([str(x) for x in matrix[i]]))

    assert best_parents is not None, f"{best_parents}\n{print_whole_matrix(weight_matrix)}"
    return best_parents

class Model(nn.Module): def init(self, args, dataset): super().init() if "hplt" in args.model_path: self.bert = AutoModel.from_pretrained(args.model_path, trust_remote_code=True) else: self.bert = AutoModel.from_pretrained(args.model_path) try: self.n_layers = self.bert.config.num_hidden_layers except AttributeError: self.n_layers = self.bert.config.num_layers args.hidden_size = self.bert.config.hidden_size

    self.dropout = nn.Dropout(args.dropout)
    self.layer_norm = nn.LayerNorm(args.hidden_size, elementwise_affine=False)
    self.upos_layer_score = nn.Parameter(torch.zeros(self.n_layers + 1, dtype=torch.float))
    self.xpos_layer_score = nn.Parameter(torch.zeros(self.n_layers + 2, dtype=torch.float))
    self.feats_layer_score = nn.Parameter(torch.zeros(self.n_layers + 2, dtype=torch.float))
    self.lemma_layer_score = nn.Parameter(torch.zeros(self.n_layers + 2, dtype=torch.float))
    self.head_layer_score = nn.Parameter(torch.zeros(self.n_layers + 2, dtype=torch.float))
    self.dep_layer_score = nn.Parameter(torch.zeros(self.n_layers + 2, dtype=torch.float))

    self.upos_embedding = nn.Embedding(len(dataset.upos_vocab), args.hidden_size)
    std = math.sqrt(2.0 / (5.0 * args.hidden_size))
    nn.init.trunc_normal_(self.upos_embedding.weight, mean=0.0, std=std, a=-2 * std, b=2 * std)

    self.lemma_classifier = nn.ModuleDict({
        cls: Classifier(args.hidden_size, max(len(dataset.lemma_vocab[cls]) - 1, 1), args.dropout) if len(
            dataset.lemma_vocab[cls]) > 2 else ZeroClassifier()
        for cls in dataset.lemma_vocab.keys()
    })
    self.upos_classifier = Classifier(args.hidden_size, max(len(dataset.upos_vocab) - 1, 1), args.dropout) if len(
        dataset.upos_vocab) > 2 else ZeroClassifier()
    self.xpos_classifier = Classifier(args.hidden_size, max(len(dataset.xpos_vocab) - 1, 1), args.dropout) if len(
        dataset.xpos_vocab) > 2 else ZeroClassifier()
    self.feats_classifier = Classifier(args.hidden_size, max(len(dataset.feats_vocab) - 1, 1), args.dropout) if len(
        dataset.feats_vocab) > 2 else ZeroClassifier()
    self.edge_classifier = EdgeClassifier(args.hidden_size, 128, max(len(dataset.arc_dep_vocab) - 1, 1),
                                          args.dropout)
    self.aux_feats_classifiers = nn.ModuleDict({
        cls: Classifier(args.hidden_size, max(len(dataset.feats_classes_vocab[cls]) - 1, 1), args.dropout) if len(
            dataset.feats_classes_vocab[cls]) > 2 else ZeroClassifier()
        for cls in dataset.feats_classes_vocab
    })

def forward(self, x, alignment_mask, subword_lengths, word_lengths, upos_gold=None, head_gold=None):
    padding_mask = torch.arange(x.size(1), device=x.device).unsqueeze(0) < subword_lengths.unsqueeze(1)
    print(padding_mask.device)
    x = self.bert(x, padding_mask, output_hidden_states=True).hidden_states
    x = torch.stack(x, dim=0)
    x = torch.einsum("lbsd,bst->lbtd", x, alignment_mask) / alignment_mask.sum(1).unsqueeze(-1).unsqueeze(0).clamp(
        min=1.0)
    # upos_x = torch.einsum("lbtd, l -> btd", x, torch.softmax(self.upos_layer_score, dim=0))
    upos_x = (x[:, :, 1:-1, :] * torch.softmax(self.upos_layer_score, dim=0).view(-1, 1, 1, 1)).sum(0)
    upos_x = self.dropout(self.layer_norm(upos_x))
    upos_preds = self.upos_classifier(upos_x)

    if upos_gold is None:
        upos_gold = upos_preds.argmax(-1)

    upos_embedding = self.upos_embedding(upos_gold.clamp(min=0))
    upos_embedding = F.pad(upos_embedding, (0, 0, 1, 1), value=0.0)
    x = torch.cat([x, upos_embedding.unsqueeze(0)], dim=0)

    xpos_x = (x[:, :, 1:-1, :] * torch.softmax(self.xpos_layer_score, dim=0).view(-1, 1, 1, 1)).sum(0)
    feats_x = (x[:, :, 1:-1, :] * torch.softmax(self.feats_layer_score, dim=0).view(-1, 1, 1, 1)).sum(0)
    lemma_x = (x[:, :, 1:-1, :] * torch.softmax(self.lemma_layer_score, dim=0).view(-1, 1, 1, 1)).sum(0)
    head_x = (x[:, :, 0:-1, :] * torch.softmax(self.head_layer_score, dim=0).view(-1, 1, 1, 1)).sum(0)
    dep_x = (x[:, :, 0:-1, :] * torch.softmax(self.dep_layer_score, dim=0).view(-1, 1, 1, 1)).sum(0)

    xpos_x = self.dropout(self.layer_norm(xpos_x))
    feats_x = self.dropout(self.layer_norm(feats_x))
    lemma_x = self.dropout(self.layer_norm(lemma_x))
    head_x = self.dropout(self.layer_norm(head_x))
    dep_x = self.dropout(self.layer_norm(dep_x))

    lemma_preds = {
        cls: classifier(lemma_x)
        for cls, classifier in self.lemma_classifier.items()
    }
    xpos_preds = self.xpos_classifier(xpos_x)
    feats_preds = self.feats_classifier(feats_x)
    head_prediction, dep_prediction, head_liu = self.edge_classifier(head_x, dep_x, word_lengths, head_gold)
    aux_feats_prediction = {
        cls: classifier(feats_x)
        for cls, classifier in self.aux_feats_classifiers.items()
    }

    return lemma_preds, upos_preds, xpos_preds, feats_preds, aux_feats_prediction, head_prediction, dep_prediction, head_liu
Downloads last month

-

Downloads are not tracked for this model. How to track
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support

Collection including HPLT/hplt_gpt_bert_base_3_0_ron_Latn-UD