mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
333b1a308b
* Draft layer for BILUO actions * Fixes to biluo layer * WIP on BILUO layer * Add tests for BILUO layer * Format * Fix transitions * Update test * Link in the simple_ner * Update BILUO tagger * Update __init__ * Import simple_ner * Update test * Import * Add files * Add config * Fix label passing for BILUO and tagger * Fix label handling for simple_ner component * Update simple NER test * Update config * Hack train script * Update BILUO layer * Fix SimpleNER component * Update train_from_config * Add biluo_to_iob helper * Add IOB layer * Add IOBTagger model * Update biluo layer * Update SimpleNER tagger * Update BILUO * Read random seed in train-from-config * Update use of normal_init * Fix normalization of gradient in SimpleNER * Update IOBTagger * Remove print * Tweak masking in BILUO * Add dropout in SimpleNER * Update thinc * Tidy up simple_ner * Fix biluo model * Unhack train-from-config * Update setup.cfg and requirements * Add tb_framework.py for parser model * Try to avoid memory leak in BILUO * Move ParserModel into spacy.ml, avoid need for subclass. * Use updated parser model * Remove incorrect call to model.initializre in PrecomputableAffine * Update parser model * Avoid divide by zero in tagger * Add extra dropout layer in tagger * Refine minibatch_by_words function to avoid oom * Fix parser model after refactor * Try to avoid div-by-zero in SimpleNER * Fix infinite loop in minibatch_by_words * Use SequenceCategoricalCrossentropy in Tagger * Fix parser model when hidden layer * Remove extra dropout from tagger * Add extra nan check in tagger * Fix thinc version * Update tests and imports * Fix test * Update test * Update tests * Fix tests * Fix test Co-authored-by: Ines Montani <ines@ines.io>
110 lines
4.2 KiB
Python
110 lines
4.2 KiB
Python
"""Thinc layer to do simpler transition-based parsing, NER, etc."""
|
|
from typing import List, Tuple, Dict, Optional
|
|
import numpy
|
|
from thinc.api import Ops, Model, with_array, softmax_activation, padded2list
|
|
from thinc.api import to_numpy
|
|
from thinc.types import Padded, Ints1d, Ints3d, Floats2d, Floats3d
|
|
|
|
from ..tokens import Doc
|
|
|
|
|
|
def BILUO() -> Model[Padded, Padded]:
    """Create the "biluo" layer, which masks scores of invalid BILUO actions.

    The layer maps Padded scores to Padded scores of the same shape. Its "nO"
    dim starts unset and is filled in by `init` from sample data. The
    `get_num_actions` helper is attached as an attr so callers can compute
    the action count (4 * n_labels + 1) for a given number of labels.
    """
    layer_dims = {"nO": None}
    layer_attrs = {"get_num_actions": get_num_actions}
    return Model("biluo", forward, init=init, dims=layer_dims, attrs=layer_attrs)
|
|
|
|
|
|
def init(
    model: Model[Padded, Padded],
    X: Optional[Padded] = None,
    Y: Optional[Padded] = None,
) -> None:
    """Infer the layer's "nO" dim (number of actions) from sample data.

    model: The BILUO layer being initialized.
    X: Optional sample input batch; the size of the third axis of X.data
        sets "nO".
    Y: Optional sample output batch; used for "nO" when X is not given.
    RAISES: ValueError if both X and Y are given with different data shapes,
        or if neither is given and "nO" has not been set already.
    """
    if X is not None and Y is not None and X.data.shape != Y.data.shape:
        raise ValueError(
            f"Mismatched shapes for BILUO layer: X has shape {X.data.shape}, "
            f"Y has shape {Y.data.shape}"
        )
    sample = X if X is not None else Y
    if sample is not None:
        model.set_dim("nO", sample.data.shape[2])
    elif not model.has_dim("nO"):
        # NOTE: per the thinc v8 Model API, get_dim() raises on an unset dim
        # rather than returning None, so test with has_dim() (None when the
        # dim is registered but unset) to raise this error instead.
        raise ValueError("Dimension unset for BILUO: nO")
|
|
|
|
|
|
def forward(model: Model[Padded, Padded], Xp: Padded, is_train: bool):
    """Penalize the scores of actions that would break a valid BILUO sequence.

    Walks the padded batch one timestep at a time. For each sequence, the
    previous greedy (argmax) action and whether the current token is the
    sequence's last one select a row of the transition table. That row marks
    which actions are valid now. Invalid actions have a penalty subtracted
    from their score; valid actions are left unchanged.

    model: The BILUO layer; its "nO" dim is expected to equal the number of
        actions, 4 * n_labels + 1.
    Xp: Padded scores of shape (n_timesteps, n_docs, n_actions).
    is_train: Unused; part of the thinc forward signature.
    RETURNS: The penalized scores as a Padded sharing Xp's size_at_t, lengths
        and indices, plus a backprop callback that zeroes the gradient of
        every masked cell.
    """
    n_labels = (model.get_dim("nO") - 1) // 4
    n_timesteps, n_docs, n_actions = Xp.data.shape
    # The validity table has shape (2, n_actions, n_actions). The first index
    # is 1 if the current token is the last of its sequence, the second is the
    # previous action, and the row holds a 0/1 validity flag per candidate
    # action at this step.
    valid_transitions = model.ops.asarray(_get_transition_table(n_labels))
    prev_actions = model.ops.alloc1i(n_docs)
    # Initialize as though the previous action was O (the last action index).
    prev_actions.fill(n_actions - 1)
    Y = model.ops.alloc3f(*Xp.data.shape)
    masks = model.ops.alloc3f(*Y.shape)
    if Xp.data.size:
        # The penalty must exceed the spread of the scores so that every
        # invalid action lands below every valid one whatever the scores'
        # sign. A penalty scaled by max() alone vanishes or flips sign when
        # the max score is <= 0. The added constant keeps a margin even when
        # all scores are equal (e.g. a zero-initialized output layer).
        penalty = (Xp.data.max() - Xp.data.min()) * 10 + 10
    else:
        # Empty batch: max()/min() would raise on a zero-size array.
        penalty = 0.0
    for t in range(n_timesteps):
        # 1 where token t is the last token of its sequence (or past its end).
        is_last = (Xp.lengths < (t + 2)).astype("i")
        masks[t] = valid_transitions[is_last, prev_actions]
        # Don't train the out-of-bounds sequences.
        masks[t, Xp.size_at_t[t]:] = 0
        # Valid actions (mask 1) are unchanged; invalid ones (mask 0) lose `penalty`.
        Y[t] = Xp.data[t] + (masks[t] - 1) * penalty
        prev_actions = Y[t].argmax(axis=-1)

    def backprop_biluo(dY: Padded) -> Padded:
        # The penalty is treated as a constant. Gradients of masked (invalid
        # or out-of-bounds) cells are zeroed so the model isn't trained on
        # them. Note: this scales dY.data in place.
        dY.data *= masks
        return dY

    return Padded(Y, Xp.size_at_t, Xp.lengths, Xp.indices), backprop_biluo
|
|
|
|
|
|
def get_num_actions(n_labels: int) -> int:
    """Return the number of BILUO actions needed for `n_labels` entity labels.

    Every label gets its own Begin, In, Last and Unit action, and one shared
    Out action is added on top, e.g. 3 labels -> 13 actions.
    """
    actions_per_label = 4  # B, I, L, U
    return actions_per_label * n_labels + 1  # plus the single O action
|
|
|
|
|
|
def _get_transition_table(
    n_labels: int, *, _cache: Dict[int, Floats3d] = {}
) -> Floats3d:
    """Build (or fetch from cache) the BILUO transition-validity table.

    Actions are laid out as [B x n_labels, I x n_labels, L x n_labels,
    U x n_labels, O]. The table has shape (2, n_actions, n_actions), and
    table[is_last, prev, action] is 1.0 if `action` may directly follow `prev`.
    Here is_last is 1 when the token being labelled is the last of its
    sequence, and 0 otherwise.

    n_labels: Number of entity labels.
    _cache: Deliberately shared mutable default used to memoize tables keyed
        by action count; callers should not pass it.
    RETURNS: A read-only float32 numpy array, shared by all calls with the
        same n_labels.
    """
    n_actions = get_num_actions(n_labels)
    if n_actions in _cache:
        return _cache[n_actions]
    table = numpy.zeros((2, n_actions, n_actions), dtype="f")
    B_start, B_end = (0, n_labels)
    I_start, I_end = (B_end, B_end + n_labels)
    L_start, L_end = (I_end, I_end + n_labels)
    # U occupies [L_end, L_end + n_labels); O is the single final action.
    U_start = L_end
    # Using ranges allows us to set specific cells, which is necessary to express
    # that only actions of the same label are valid continuations.
    B_range = numpy.arange(B_start, B_end)
    I_range = numpy.arange(I_start, I_end)
    L_range = numpy.arange(L_start, L_end)
    # If this is the last token and the previous action was B or I, only L
    # of that label is valid.
    table[1, B_range, L_range] = 1
    table[1, I_range, L_range] = 1
    # If this isn't the last token and the previous action was B or I, only I or
    # L of that label are valid.
    table[0, B_range, I_range] = 1
    table[0, B_range, L_range] = 1
    table[0, I_range, I_range] = 1
    table[0, I_range, L_range] = 1
    # If this isn't the last token and the previous was L, U or O, any B is valid.
    table[0, L_start:, :B_end] = 1
    # Regardless of whether this is the last token, if the previous action was
    # {L, U, O}, any U and O are valid.
    table[:, L_start:, U_start:] = 1
    # The cached array is handed out by reference on every call; freeze it so
    # an accidental in-place write by one caller can't corrupt later lookups.
    table.setflags(write=False)
    _cache[n_actions] = table
    return table
|