mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-13 05:07:03 +03:00
Merge the parser refactor into v4
(#10940)
* Try to fix doc.copy * Set dev version * Make vocab always own lexemes * Change version * Add SpanGroups.copy method * Fix set_annotations during Parser.update * Fix dict proxy copy * Upd version * Fix copying SpanGroups * Fix set_annotations in parser.update * Fix parser set_annotations during update * Revert "Fix parser set_annotations during update" This reverts commiteb138c89ed
. * Revert "Fix set_annotations in parser.update" This reverts commitc6df0eafd0
. * Fix set_annotations during parser update * Inc version * Handle final states in get_oracle_sequence * Inc version * Try to fix parser training * Inc version * Fix * Inc version * Fix parser oracle * Inc version * Inc version * Fix transition has_gold * Inc version * Try to use real histories, not oracle * Inc version * Upd parser * Inc version * WIP on rewrite parser * WIP refactor parser * New progress on parser model refactor * Prepare to remove parser_model.pyx * Convert parser from cdef class * Delete spacy.ml.parser_model * Delete _precomputable_affine module * Wire up tb_framework to new parser model * Wire up parser model * Uncython ner.pyx and dep_parser.pyx * Uncython * Work on parser model * Support unseen_classes in parser model * Support unseen classes in parser * Cleaner handling of unseen classes * Work through tests * Keep working through errors * Keep working through errors * Work on parser. 15 tests failing * Xfail beam stuff. 9 failures * More xfail. 7 failures * Xfail. 6 failures * cleanup * formatting * fixes * pass nO through * Fix empty doc in update * Hackishly fix resizing. 3 failures * Fix redundant test. 2 failures * Add reference version * black formatting * Get tests passing with reference implementation * Fix missing prints * Add missing file * Improve indexing on reference implementation * Get non-reference forward func working * Start rigging beam back up * removing redundant tests, cf #8106 * black formatting * temporarily xfailing issue 4314 * make flake8 happy again * mypy fixes * ensure labels are added upon predict * cleanup remnants from merge conflicts * Improve unseen label masking Two changes to speed up masking by ~10%: - Use a bool array rather than an array of float32. - Let the mask indicate whether a label was seen, rather than unseen. The mask is most frequently used to index scores for seen labels. However, since the mask marked unseen labels, this required computing an intermittent flipped mask. * Write moves costs directly into numpy array (#10163) This avoids elementwise indexing and the allocation of an additional array. Gives a ~15% speed improvement when using batch_by_sequence with size 32. * Temporarily disable ner and rehearse tests Until rehearse is implemented again in the refactored parser. * Fix loss serialization issue (#10600) * Fix loss serialization issue Serialization of a model fails with: TypeError: array(738.3855, dtype=float32) is not JSON serializable Fix this using float conversion. * Disable CI steps that require spacy.TransitionBasedParser.v2 After finishing the refactor, TransitionBasedParser.v2 should be provided for backwards compat. * Add back support for beam parsing to the refactored parser (#10633) * Add back support for beam parsing Beam parsing was already implemented as part of the `BeamBatch` class. This change makes its counterpart `GreedyBatch`. Both classes are hooked up in `TransitionModel`, selecting `GreedyBatch` when the beam size is one, or `BeamBatch` otherwise. * Use kwarg for beam width Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> * Avoid implicit default for beam_width and beam_density * Parser.{beam,greedy}_parse: ensure labels are added * Remove 'deprecated' comments Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> * Parser `StateC` optimizations (#10746) * `StateC`: Optimizations Avoid GIL acquisition in `__init__` Increase default buffer capacities on init Reduce C++ exception overhead * Fix typo * Replace `set::count` with `set::find` * Add exception attribute to c'tor * Remove unused import * Use a power-of-two value for initial capacity Use default-insert to init `_heads` and `_unshiftable` * Merge `cdef` variable declarations and assignments * Vectorize `example.get_aligned_parses` (#10789) * `example`: Vectorize `get_aligned_parse` Rename `numpy` import * Convert aligned array to lists before returning * Revert import renaming * Elide slice arguments when selecting the entire range * Tagger/morphologizer alignment performance optimizations (#10798) * `example`: Unwrap `numpy` scalar arrays before passing them to `StringStore.__getitem__` * `AlignmentArray`: Use native list as staging buffer for offset calculation * `example`: Vectorize `get_aligned` * Hoist inner functions out of `get_aligned` * Replace inline `if..else` clause in assignment statement * `AlignmentArray`: Use raw indexing into offset and data `numpy` arrays * `example`: Replace array unique value check with `groupby` * `example`: Correctly exclude tokens with no alignment in `_get_aligned_vectorized` Simplify `_get_aligned_non_vectorized` * `util`: Update `all_equal` docstring * Explicitly use `int32_t*` * Restore C CPU inference in the refactored parser (#10747) * Bring back the C parsing model The C parsing model is used for CPU inference and is still faster for CPU inference than the forward pass of the Thinc model. * Use C sgemm provided by the Ops implementation * Make tb_framework module Cython, merge in C forward implementation * TransitionModel: raise in backprop returned from forward_cpu * Re-enable greedy parse test * Return transition scores when forward_cpu is used * Apply suggestions from code review Import `Model` from `thinc.api` Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> * Use relative imports in tb_framework * Don't assume a default for beam_width * We don't have a direct dependency on BLIS anymore * Rename forwards to _forward_{fallback,greedy_cpu} * Require thinc >=8.1.0,<8.2.0 * tb_framework: clean up imports * Fix return type of _get_seen_mask * Move up _forward_greedy_cpu * Style fixes. * Lower thinc lowerbound to 8.1.0.dev0 * Formatting fix Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com> Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com> * Reimplement parser rehearsal function (#10878) * Reimplement parser rehearsal function Before the parser refactor, rehearsal was driven by a loop in the `rehearse` method itself. For each parsing step, the loops would: 1. Get the predictions of the teacher. 2. Get the predictions and backprop function of the student. 3. Compute the loss and backprop into the student. 4. Move the teacher and student forward with the predictions of the student. In the refactored parser, we cannot perform search stepwise rehearsal anymore, since the model now predicts all parsing steps at once. Therefore, rehearsal is performed in the following steps: 1. Get the predictions of all parsing steps from the student, along with its backprop function. 2. Get the predictions from the teacher, but use the predictions of the student to advance the parser while doing so. 3. Compute the loss and backprop into the student. To support the second step a new method, `advance_with_actions` is added to `GreedyBatch`, which performs the provided parsing steps. * tb_framework: wrap upper_W and upper_b in Linear Thinc's Optimizer cannot handle resizing of existing parameters. Until it does, we work around this by wrapping the weights/biases of the upper layer of the parser model in Linear. When the upper layer is resized, we copy over the existing parameters into a new Linear instance. This does not trigger an error in Optimizer, because it sees the resized layer as a new set of parameters. * Add test for TransitionSystem.apply_actions * Better FIXME marker Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com> * Fixes from Madeesh * Apply suggestions from Sofie Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> * Remove useless assignment Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com> Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> * Rename some identifiers in the parser refactor (#10935) * Rename _parseC to _parse_batch * tb_framework: prefix many auxiliary functions with underscore To clearly state the intent that they are private. * Rename `lower` to `hidden`, `upper` to `output` * Parser slow test fixup We don't have TransitionBasedParser.{v1,v2} until we bring it back as a legacy option. * Remove last vestiges of PrecomputableAffine This does not exist anymore as a separate layer. * ner: re-enable sentence boundary checks * Re-enable test that works now. * test_ner: make loss test more strict again * Remove commented line * Re-enable some more beam parser tests * Remove unused _forward_reference function * Update for CBlas changes in Thinc 8.1.0.dev2 Bump thinc dependency to 8.1.0.dev3. * Remove references to spacy.TransitionBasedParser.{v1,v2} Since they will not be offered starting with spaCy v4. * `tb_framework`: Replace references to `thinc.backends.linalg` with `CBlas` * dont use get_array_module (#11056) (#11293) Co-authored-by: kadarakos <kadar.akos@gmail.com> * Move `thinc.extra.search` to `spacy.pipeline._parser_internals` (#11317) * `search`: Move from `thinc.extra.search` Fix NPE in `Beam.__dealloc__` * `pytest`: Add support for executing Cython tests Move `search` tests from thinc and patch them to run with `pytest` * `mypy` fix * Update comment * `conftest`: Expose `register_cython_tests` * Remove unused import * Move `argmax` impls to new `_parser_utils` Cython module (#11410) * Parser does not have to be a cdef class anymore This also fixes validation of the initialization schema. * Add back spacy.TransitionBasedParser.v2 * Fix a rename that was missed in #10878. So that rehearsal tests pass. * Remove module from setup.py that got added during the merge * Bring back support for `update_with_oracle_cut_size` (#12086) * Bring back support for `update_with_oracle_cut_size` This option was available in the pre-refactor parser, but was never implemented in the refactored parser. This option cuts transition sequences that are longer than `update_with_oracle_cut` size into separate sequences that have at most `update_with_oracle_cut` transitions. The oracle (gold standard) transition sequence is used to determine the cuts and the initial states for the additional sequences. Applying this cut makes the batches more homogeneous in the transition sequence lengths, making forward passes (and as a consequence training) much faster. Training time 1000 steps on de_core_news_lg: - Before this change: 149s - After this change: 68s - Pre-refactor parser: 81s * Fix a rename that was missed in #10878. So that rehearsal tests pass. * Apply suggestions from @shadeMe * Use chained conditional * Test with update_with_oracle_cut_size={0, 1, 5, 100} And fix a git that occurs with a cut size of 1. * Fix up some merge fall out * Update parser distillation for the refactor In the old parser, we'd iterate over the transitions in the distill function and compute the loss/gradients on the go. In the refactored parser, we first let the student model parse the inputs. Then we'll let the teacher compute the transition probabilities of the states in the student's transition sequence. We can then compute the gradients of the student given the teacher. * Add back spacy.TransitionBasedParser.v1 references - Accordion in the architecture docs. - Test in test_parse, but disabled until we have a spacy-legacy release. Co-authored-by: Matthew Honnibal <honnibal+gh@gmail.com> Co-authored-by: svlandeg <svlandeg@github.com> Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com> Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com> Co-authored-by: kadarakos <kadar.akos@gmail.com>
This commit is contained in:
parent
5e297aa20e
commit
a183db3cef
6
setup.py
6
setup.py
|
@ -33,12 +33,10 @@ MOD_NAMES = [
|
||||||
"spacy.kb.candidate",
|
"spacy.kb.candidate",
|
||||||
"spacy.kb.kb",
|
"spacy.kb.kb",
|
||||||
"spacy.kb.kb_in_memory",
|
"spacy.kb.kb_in_memory",
|
||||||
"spacy.ml.parser_model",
|
"spacy.ml.tb_framework",
|
||||||
"spacy.morphology",
|
"spacy.morphology",
|
||||||
"spacy.pipeline.dep_parser",
|
|
||||||
"spacy.pipeline._edit_tree_internals.edit_trees",
|
"spacy.pipeline._edit_tree_internals.edit_trees",
|
||||||
"spacy.pipeline.morphologizer",
|
"spacy.pipeline.morphologizer",
|
||||||
"spacy.pipeline.ner",
|
|
||||||
"spacy.pipeline.pipe",
|
"spacy.pipeline.pipe",
|
||||||
"spacy.pipeline.trainable_pipe",
|
"spacy.pipeline.trainable_pipe",
|
||||||
"spacy.pipeline.sentencizer",
|
"spacy.pipeline.sentencizer",
|
||||||
|
@ -46,6 +44,7 @@ MOD_NAMES = [
|
||||||
"spacy.pipeline.tagger",
|
"spacy.pipeline.tagger",
|
||||||
"spacy.pipeline.transition_parser",
|
"spacy.pipeline.transition_parser",
|
||||||
"spacy.pipeline._parser_internals.arc_eager",
|
"spacy.pipeline._parser_internals.arc_eager",
|
||||||
|
"spacy.pipeline._parser_internals.batch",
|
||||||
"spacy.pipeline._parser_internals.ner",
|
"spacy.pipeline._parser_internals.ner",
|
||||||
"spacy.pipeline._parser_internals.nonproj",
|
"spacy.pipeline._parser_internals.nonproj",
|
||||||
"spacy.pipeline._parser_internals.search",
|
"spacy.pipeline._parser_internals.search",
|
||||||
|
@ -53,6 +52,7 @@ MOD_NAMES = [
|
||||||
"spacy.pipeline._parser_internals.stateclass",
|
"spacy.pipeline._parser_internals.stateclass",
|
||||||
"spacy.pipeline._parser_internals.transition_system",
|
"spacy.pipeline._parser_internals.transition_system",
|
||||||
"spacy.pipeline._parser_internals._beam_utils",
|
"spacy.pipeline._parser_internals._beam_utils",
|
||||||
|
"spacy.pipeline._parser_internals._parser_utils",
|
||||||
"spacy.tokenizer",
|
"spacy.tokenizer",
|
||||||
"spacy.training.align",
|
"spacy.training.align",
|
||||||
"spacy.training.gold_io",
|
"spacy.training.gold_io",
|
||||||
|
|
|
@ -87,12 +87,11 @@ grad_factor = 1.0
|
||||||
factory = "parser"
|
factory = "parser"
|
||||||
|
|
||||||
[components.parser.model]
|
[components.parser.model]
|
||||||
@architectures = "spacy.TransitionBasedParser.v2"
|
@architectures = "spacy.TransitionBasedParser.v3"
|
||||||
state_type = "parser"
|
state_type = "parser"
|
||||||
extra_state_tokens = false
|
extra_state_tokens = false
|
||||||
hidden_width = 128
|
hidden_width = 128
|
||||||
maxout_pieces = 3
|
maxout_pieces = 3
|
||||||
use_upper = false
|
|
||||||
nO = null
|
nO = null
|
||||||
|
|
||||||
[components.parser.model.tok2vec]
|
[components.parser.model.tok2vec]
|
||||||
|
@ -108,12 +107,11 @@ grad_factor = 1.0
|
||||||
factory = "ner"
|
factory = "ner"
|
||||||
|
|
||||||
[components.ner.model]
|
[components.ner.model]
|
||||||
@architectures = "spacy.TransitionBasedParser.v2"
|
@architectures = "spacy.TransitionBasedParser.v3"
|
||||||
state_type = "ner"
|
state_type = "ner"
|
||||||
extra_state_tokens = false
|
extra_state_tokens = false
|
||||||
hidden_width = 64
|
hidden_width = 64
|
||||||
maxout_pieces = 2
|
maxout_pieces = 2
|
||||||
use_upper = false
|
|
||||||
nO = null
|
nO = null
|
||||||
|
|
||||||
[components.ner.model.tok2vec]
|
[components.ner.model.tok2vec]
|
||||||
|
@ -314,12 +312,11 @@ width = ${components.tok2vec.model.encode.width}
|
||||||
factory = "parser"
|
factory = "parser"
|
||||||
|
|
||||||
[components.parser.model]
|
[components.parser.model]
|
||||||
@architectures = "spacy.TransitionBasedParser.v2"
|
@architectures = "spacy.TransitionBasedParser.v3"
|
||||||
state_type = "parser"
|
state_type = "parser"
|
||||||
extra_state_tokens = false
|
extra_state_tokens = false
|
||||||
hidden_width = 128
|
hidden_width = 128
|
||||||
maxout_pieces = 3
|
maxout_pieces = 3
|
||||||
use_upper = true
|
|
||||||
nO = null
|
nO = null
|
||||||
|
|
||||||
[components.parser.model.tok2vec]
|
[components.parser.model.tok2vec]
|
||||||
|
@ -332,12 +329,11 @@ width = ${components.tok2vec.model.encode.width}
|
||||||
factory = "ner"
|
factory = "ner"
|
||||||
|
|
||||||
[components.ner.model]
|
[components.ner.model]
|
||||||
@architectures = "spacy.TransitionBasedParser.v2"
|
@architectures = "spacy.TransitionBasedParser.v3"
|
||||||
state_type = "ner"
|
state_type = "ner"
|
||||||
extra_state_tokens = false
|
extra_state_tokens = false
|
||||||
hidden_width = 64
|
hidden_width = 64
|
||||||
maxout_pieces = 2
|
maxout_pieces = 2
|
||||||
use_upper = true
|
|
||||||
nO = null
|
nO = null
|
||||||
|
|
||||||
[components.ner.model.tok2vec]
|
[components.ner.model.tok2vec]
|
||||||
|
|
|
@ -209,6 +209,8 @@ class Warnings(metaclass=ErrorsWithCodes):
|
||||||
"`enabled` ({enabled}). Be aware that this might affect other components in your pipeline.")
|
"`enabled` ({enabled}). Be aware that this might affect other components in your pipeline.")
|
||||||
W124 = ("{host}:{port} is already in use, using the nearest available port {serve_port} as an alternative.")
|
W124 = ("{host}:{port} is already in use, using the nearest available port {serve_port} as an alternative.")
|
||||||
|
|
||||||
|
W400 = ("`use_upper=False` is ignored, the upper layer is always enabled")
|
||||||
|
|
||||||
|
|
||||||
class Errors(metaclass=ErrorsWithCodes):
|
class Errors(metaclass=ErrorsWithCodes):
|
||||||
E001 = ("No component '{name}' found in pipeline. Available names: {opts}")
|
E001 = ("No component '{name}' found in pipeline. Available names: {opts}")
|
||||||
|
@ -958,6 +960,7 @@ class Errors(metaclass=ErrorsWithCodes):
|
||||||
E4002 = ("Pipe '{name}' requires a teacher pipe for distillation.")
|
E4002 = ("Pipe '{name}' requires a teacher pipe for distillation.")
|
||||||
E4003 = ("Training examples for distillation must have the exact same tokens in the "
|
E4003 = ("Training examples for distillation must have the exact same tokens in the "
|
||||||
"reference and predicted docs.")
|
"reference and predicted docs.")
|
||||||
|
E4004 = ("Backprop is not supported when is_train is not set.")
|
||||||
|
|
||||||
|
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
|
@ -1,164 +0,0 @@
|
||||||
from thinc.api import Model, normal_init
|
|
||||||
|
|
||||||
from ..util import registry
|
|
||||||
|
|
||||||
|
|
||||||
@registry.layers("spacy.PrecomputableAffine.v1")
|
|
||||||
def PrecomputableAffine(nO, nI, nF, nP, dropout=0.1):
|
|
||||||
model = Model(
|
|
||||||
"precomputable_affine",
|
|
||||||
forward,
|
|
||||||
init=init,
|
|
||||||
dims={"nO": nO, "nI": nI, "nF": nF, "nP": nP},
|
|
||||||
params={"W": None, "b": None, "pad": None},
|
|
||||||
attrs={"dropout_rate": dropout},
|
|
||||||
)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def forward(model, X, is_train):
|
|
||||||
nF = model.get_dim("nF")
|
|
||||||
nO = model.get_dim("nO")
|
|
||||||
nP = model.get_dim("nP")
|
|
||||||
nI = model.get_dim("nI")
|
|
||||||
W = model.get_param("W")
|
|
||||||
# Preallocate array for layer output, including padding.
|
|
||||||
Yf = model.ops.alloc2f(X.shape[0] + 1, nF * nO * nP, zeros=False)
|
|
||||||
model.ops.gemm(X, W.reshape((nF * nO * nP, nI)), trans2=True, out=Yf[1:])
|
|
||||||
Yf = Yf.reshape((Yf.shape[0], nF, nO, nP))
|
|
||||||
|
|
||||||
# Set padding. Padding has shape (1, nF, nO, nP). Unfortunately, we cannot
|
|
||||||
# change its shape to (nF, nO, nP) without breaking existing models. So
|
|
||||||
# we'll squeeze the first dimension here.
|
|
||||||
Yf[0] = model.ops.xp.squeeze(model.get_param("pad"), 0)
|
|
||||||
|
|
||||||
def backward(dY_ids):
|
|
||||||
# This backprop is particularly tricky, because we get back a different
|
|
||||||
# thing from what we put out. We put out an array of shape:
|
|
||||||
# (nB, nF, nO, nP), and get back:
|
|
||||||
# (nB, nO, nP) and ids (nB, nF)
|
|
||||||
# The ids tell us the values of nF, so we would have:
|
|
||||||
#
|
|
||||||
# dYf = zeros((nB, nF, nO, nP))
|
|
||||||
# for b in range(nB):
|
|
||||||
# for f in range(nF):
|
|
||||||
# dYf[b, ids[b, f]] += dY[b]
|
|
||||||
#
|
|
||||||
# However, we avoid building that array for efficiency -- and just pass
|
|
||||||
# in the indices.
|
|
||||||
dY, ids = dY_ids
|
|
||||||
assert dY.ndim == 3
|
|
||||||
assert dY.shape[1] == nO, dY.shape
|
|
||||||
assert dY.shape[2] == nP, dY.shape
|
|
||||||
# nB = dY.shape[0]
|
|
||||||
model.inc_grad("pad", _backprop_precomputable_affine_padding(model, dY, ids))
|
|
||||||
Xf = X[ids]
|
|
||||||
Xf = Xf.reshape((Xf.shape[0], nF * nI))
|
|
||||||
|
|
||||||
model.inc_grad("b", dY.sum(axis=0))
|
|
||||||
dY = dY.reshape((dY.shape[0], nO * nP))
|
|
||||||
|
|
||||||
Wopfi = W.transpose((1, 2, 0, 3))
|
|
||||||
Wopfi = Wopfi.reshape((nO * nP, nF * nI))
|
|
||||||
dXf = model.ops.gemm(dY.reshape((dY.shape[0], nO * nP)), Wopfi)
|
|
||||||
|
|
||||||
dWopfi = model.ops.gemm(dY, Xf, trans1=True)
|
|
||||||
dWopfi = dWopfi.reshape((nO, nP, nF, nI))
|
|
||||||
# (o, p, f, i) --> (f, o, p, i)
|
|
||||||
dWopfi = dWopfi.transpose((2, 0, 1, 3))
|
|
||||||
model.inc_grad("W", dWopfi)
|
|
||||||
return dXf.reshape((dXf.shape[0], nF, nI))
|
|
||||||
|
|
||||||
return Yf, backward
|
|
||||||
|
|
||||||
|
|
||||||
def _backprop_precomputable_affine_padding(model, dY, ids):
|
|
||||||
nB = dY.shape[0]
|
|
||||||
nF = model.get_dim("nF")
|
|
||||||
nP = model.get_dim("nP")
|
|
||||||
nO = model.get_dim("nO")
|
|
||||||
# Backprop the "padding", used as a filler for missing values.
|
|
||||||
# Values that are missing are set to -1, and each state vector could
|
|
||||||
# have multiple missing values. The padding has different values for
|
|
||||||
# different missing features. The gradient of the padding vector is:
|
|
||||||
#
|
|
||||||
# for b in range(nB):
|
|
||||||
# for f in range(nF):
|
|
||||||
# if ids[b, f] < 0:
|
|
||||||
# d_pad[f] += dY[b]
|
|
||||||
#
|
|
||||||
# Which can be rewritten as:
|
|
||||||
#
|
|
||||||
# (ids < 0).T @ dY
|
|
||||||
mask = model.ops.asarray(ids < 0, dtype="f")
|
|
||||||
d_pad = model.ops.gemm(mask, dY.reshape(nB, nO * nP), trans1=True)
|
|
||||||
return d_pad.reshape((1, nF, nO, nP))
|
|
||||||
|
|
||||||
|
|
||||||
def init(model, X=None, Y=None):
|
|
||||||
"""This is like the 'layer sequential unit variance', but instead
|
|
||||||
of taking the actual inputs, we randomly generate whitened data.
|
|
||||||
|
|
||||||
Why's this all so complicated? We have a huge number of inputs,
|
|
||||||
and the maxout unit makes guessing the dynamics tricky. Instead
|
|
||||||
we set the maxout weights to values that empirically result in
|
|
||||||
whitened outputs given whitened inputs.
|
|
||||||
"""
|
|
||||||
if model.has_param("W") and model.get_param("W").any():
|
|
||||||
return
|
|
||||||
|
|
||||||
nF = model.get_dim("nF")
|
|
||||||
nO = model.get_dim("nO")
|
|
||||||
nP = model.get_dim("nP")
|
|
||||||
nI = model.get_dim("nI")
|
|
||||||
W = model.ops.alloc4f(nF, nO, nP, nI)
|
|
||||||
b = model.ops.alloc2f(nO, nP)
|
|
||||||
pad = model.ops.alloc4f(1, nF, nO, nP)
|
|
||||||
|
|
||||||
ops = model.ops
|
|
||||||
W = normal_init(ops, W.shape, mean=float(ops.xp.sqrt(1.0 / nF * nI)))
|
|
||||||
pad = normal_init(ops, pad.shape, mean=1.0)
|
|
||||||
model.set_param("W", W)
|
|
||||||
model.set_param("b", b)
|
|
||||||
model.set_param("pad", pad)
|
|
||||||
|
|
||||||
ids = ops.alloc((5000, nF), dtype="f")
|
|
||||||
ids += ops.xp.random.uniform(0, 1000, ids.shape)
|
|
||||||
ids = ops.asarray(ids, dtype="i")
|
|
||||||
tokvecs = ops.alloc((5000, nI), dtype="f")
|
|
||||||
tokvecs += ops.xp.random.normal(loc=0.0, scale=1.0, size=tokvecs.size).reshape(
|
|
||||||
tokvecs.shape
|
|
||||||
)
|
|
||||||
|
|
||||||
def predict(ids, tokvecs):
|
|
||||||
# nS ids. nW tokvecs. Exclude the padding array.
|
|
||||||
hiddens = model.predict(tokvecs[:-1]) # (nW, f, o, p)
|
|
||||||
vectors = model.ops.alloc((ids.shape[0], nO * nP), dtype="f")
|
|
||||||
# need nS vectors
|
|
||||||
hiddens = hiddens.reshape((hiddens.shape[0] * nF, nO * nP))
|
|
||||||
model.ops.scatter_add(vectors, ids.flatten(), hiddens)
|
|
||||||
vectors = vectors.reshape((vectors.shape[0], nO, nP))
|
|
||||||
vectors += b
|
|
||||||
vectors = model.ops.asarray(vectors)
|
|
||||||
if nP >= 2:
|
|
||||||
return model.ops.maxout(vectors)[0]
|
|
||||||
else:
|
|
||||||
return vectors * (vectors >= 0)
|
|
||||||
|
|
||||||
tol_var = 0.01
|
|
||||||
tol_mean = 0.01
|
|
||||||
t_max = 10
|
|
||||||
W = model.get_param("W").copy()
|
|
||||||
b = model.get_param("b").copy()
|
|
||||||
for t_i in range(t_max):
|
|
||||||
acts1 = predict(ids, tokvecs)
|
|
||||||
var = model.ops.xp.var(acts1)
|
|
||||||
mean = model.ops.xp.mean(acts1)
|
|
||||||
if abs(var - 1.0) >= tol_var:
|
|
||||||
W /= model.ops.xp.sqrt(var)
|
|
||||||
model.set_param("W", W)
|
|
||||||
elif abs(mean) >= tol_mean:
|
|
||||||
b -= mean
|
|
||||||
model.set_param("b", b)
|
|
||||||
else:
|
|
||||||
break
|
|
|
@ -1,17 +1,20 @@
|
||||||
from typing import Optional, List, cast
|
from typing import Optional, List, Tuple, Any
|
||||||
from thinc.api import Model, chain, list2array, Linear, zero_init, use_ops
|
|
||||||
from thinc.types import Floats2d
|
from thinc.types import Floats2d
|
||||||
|
from thinc.api import Model
|
||||||
|
import warnings
|
||||||
|
|
||||||
from ...errors import Errors
|
from ...errors import Errors, Warnings
|
||||||
from ...compat import Literal
|
from ...compat import Literal
|
||||||
from ...util import registry
|
from ...util import registry
|
||||||
from .._precomputable_affine import PrecomputableAffine
|
|
||||||
from ..tb_framework import TransitionModel
|
from ..tb_framework import TransitionModel
|
||||||
from ...tokens import Doc
|
from ...tokens.doc import Doc
|
||||||
|
|
||||||
|
TransitionSystem = Any # TODO
|
||||||
|
State = Any # TODO
|
||||||
|
|
||||||
|
|
||||||
@registry.architectures("spacy.TransitionBasedParser.v2")
|
@registry.architectures.register("spacy.TransitionBasedParser.v2")
|
||||||
def build_tb_parser_model(
|
def transition_parser_v2(
|
||||||
tok2vec: Model[List[Doc], List[Floats2d]],
|
tok2vec: Model[List[Doc], List[Floats2d]],
|
||||||
state_type: Literal["parser", "ner"],
|
state_type: Literal["parser", "ner"],
|
||||||
extra_state_tokens: bool,
|
extra_state_tokens: bool,
|
||||||
|
@ -19,6 +22,46 @@ def build_tb_parser_model(
|
||||||
maxout_pieces: int,
|
maxout_pieces: int,
|
||||||
use_upper: bool,
|
use_upper: bool,
|
||||||
nO: Optional[int] = None,
|
nO: Optional[int] = None,
|
||||||
|
) -> Model:
|
||||||
|
if not use_upper:
|
||||||
|
warnings.warn(Warnings.W400)
|
||||||
|
|
||||||
|
return build_tb_parser_model(
|
||||||
|
tok2vec,
|
||||||
|
state_type,
|
||||||
|
extra_state_tokens,
|
||||||
|
hidden_width,
|
||||||
|
maxout_pieces,
|
||||||
|
nO=nO,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@registry.architectures.register("spacy.TransitionBasedParser.v3")
|
||||||
|
def transition_parser_v3(
|
||||||
|
tok2vec: Model[List[Doc], List[Floats2d]],
|
||||||
|
state_type: Literal["parser", "ner"],
|
||||||
|
extra_state_tokens: bool,
|
||||||
|
hidden_width: int,
|
||||||
|
maxout_pieces: int,
|
||||||
|
nO: Optional[int] = None,
|
||||||
|
) -> Model:
|
||||||
|
return build_tb_parser_model(
|
||||||
|
tok2vec,
|
||||||
|
state_type,
|
||||||
|
extra_state_tokens,
|
||||||
|
hidden_width,
|
||||||
|
maxout_pieces,
|
||||||
|
nO=nO,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_tb_parser_model(
|
||||||
|
tok2vec: Model[List[Doc], List[Floats2d]],
|
||||||
|
state_type: Literal["parser", "ner"],
|
||||||
|
extra_state_tokens: bool,
|
||||||
|
hidden_width: int,
|
||||||
|
maxout_pieces: int,
|
||||||
|
nO: Optional[int] = None,
|
||||||
) -> Model:
|
) -> Model:
|
||||||
"""
|
"""
|
||||||
Build a transition-based parser model. Can apply to NER or dependency-parsing.
|
Build a transition-based parser model. Can apply to NER or dependency-parsing.
|
||||||
|
@ -51,14 +94,7 @@ def build_tb_parser_model(
|
||||||
feature sets (for the NER) or 13 (for the parser).
|
feature sets (for the NER) or 13 (for the parser).
|
||||||
hidden_width (int): The width of the hidden layer.
|
hidden_width (int): The width of the hidden layer.
|
||||||
maxout_pieces (int): How many pieces to use in the state prediction layer.
|
maxout_pieces (int): How many pieces to use in the state prediction layer.
|
||||||
Recommended values are 1, 2 or 3. If 1, the maxout non-linearity
|
Recommended values are 1, 2 or 3.
|
||||||
is replaced with a ReLu non-linearity if use_upper=True, and no
|
|
||||||
non-linearity if use_upper=False.
|
|
||||||
use_upper (bool): Whether to use an additional hidden layer after the state
|
|
||||||
vector in order to predict the action scores. It is recommended to set
|
|
||||||
this to False for large pretrained models such as transformers, and True
|
|
||||||
for smaller networks. The upper layer is computed on CPU, which becomes
|
|
||||||
a bottleneck on larger GPU-based models, where it's also less necessary.
|
|
||||||
nO (int or None): The number of actions the model will predict between.
|
nO (int or None): The number of actions the model will predict between.
|
||||||
Usually inferred from data at the beginning of training, or loaded from
|
Usually inferred from data at the beginning of training, or loaded from
|
||||||
disk.
|
disk.
|
||||||
|
@ -69,106 +105,11 @@ def build_tb_parser_model(
|
||||||
nr_feature_tokens = 6 if extra_state_tokens else 3
|
nr_feature_tokens = 6 if extra_state_tokens else 3
|
||||||
else:
|
else:
|
||||||
raise ValueError(Errors.E917.format(value=state_type))
|
raise ValueError(Errors.E917.format(value=state_type))
|
||||||
t2v_width = tok2vec.get_dim("nO") if tok2vec.has_dim("nO") else None
|
return TransitionModel(
|
||||||
tok2vec = chain(
|
tok2vec=tok2vec,
|
||||||
tok2vec,
|
state_tokens=nr_feature_tokens,
|
||||||
list2array(),
|
hidden_width=hidden_width,
|
||||||
Linear(hidden_width, t2v_width),
|
maxout_pieces=maxout_pieces,
|
||||||
|
nO=nO,
|
||||||
|
unseen_classes=set(),
|
||||||
)
|
)
|
||||||
tok2vec.set_dim("nO", hidden_width)
|
|
||||||
lower = _define_lower(
|
|
||||||
nO=hidden_width if use_upper else nO,
|
|
||||||
nF=nr_feature_tokens,
|
|
||||||
nI=tok2vec.get_dim("nO"),
|
|
||||||
nP=maxout_pieces,
|
|
||||||
)
|
|
||||||
upper = None
|
|
||||||
if use_upper:
|
|
||||||
with use_ops("cpu"):
|
|
||||||
# Initialize weights at zero, as it's a classification layer.
|
|
||||||
upper = _define_upper(nO=nO, nI=None)
|
|
||||||
return TransitionModel(tok2vec, lower, upper, resize_output)
|
|
||||||
|
|
||||||
|
|
||||||
def _define_upper(nO, nI):
|
|
||||||
return Linear(nO=nO, nI=nI, init_W=zero_init)
|
|
||||||
|
|
||||||
|
|
||||||
def _define_lower(nO, nF, nI, nP):
|
|
||||||
return PrecomputableAffine(nO=nO, nF=nF, nI=nI, nP=nP)
|
|
||||||
|
|
||||||
|
|
||||||
def resize_output(model, new_nO):
|
|
||||||
if model.attrs["has_upper"]:
|
|
||||||
return _resize_upper(model, new_nO)
|
|
||||||
return _resize_lower(model, new_nO)
|
|
||||||
|
|
||||||
|
|
||||||
def _resize_upper(model, new_nO):
|
|
||||||
upper = model.get_ref("upper")
|
|
||||||
if upper.has_dim("nO") is None:
|
|
||||||
upper.set_dim("nO", new_nO)
|
|
||||||
return model
|
|
||||||
elif new_nO == upper.get_dim("nO"):
|
|
||||||
return model
|
|
||||||
|
|
||||||
smaller = upper
|
|
||||||
nI = smaller.maybe_get_dim("nI")
|
|
||||||
with use_ops("cpu"):
|
|
||||||
larger = _define_upper(nO=new_nO, nI=nI)
|
|
||||||
# it could be that the model is not initialized yet, then skip this bit
|
|
||||||
if smaller.has_param("W"):
|
|
||||||
larger_W = larger.ops.alloc2f(new_nO, nI)
|
|
||||||
larger_b = larger.ops.alloc1f(new_nO)
|
|
||||||
smaller_W = smaller.get_param("W")
|
|
||||||
smaller_b = smaller.get_param("b")
|
|
||||||
# Weights are stored in (nr_out, nr_in) format, so we're basically
|
|
||||||
# just adding rows here.
|
|
||||||
if smaller.has_dim("nO"):
|
|
||||||
old_nO = smaller.get_dim("nO")
|
|
||||||
larger_W[:old_nO] = smaller_W
|
|
||||||
larger_b[:old_nO] = smaller_b
|
|
||||||
for i in range(old_nO, new_nO):
|
|
||||||
model.attrs["unseen_classes"].add(i)
|
|
||||||
|
|
||||||
larger.set_param("W", larger_W)
|
|
||||||
larger.set_param("b", larger_b)
|
|
||||||
model._layers[-1] = larger
|
|
||||||
model.set_ref("upper", larger)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def _resize_lower(model, new_nO):
|
|
||||||
lower = model.get_ref("lower")
|
|
||||||
if lower.has_dim("nO") is None:
|
|
||||||
lower.set_dim("nO", new_nO)
|
|
||||||
return model
|
|
||||||
|
|
||||||
smaller = lower
|
|
||||||
nI = smaller.maybe_get_dim("nI")
|
|
||||||
nF = smaller.maybe_get_dim("nF")
|
|
||||||
nP = smaller.maybe_get_dim("nP")
|
|
||||||
larger = _define_lower(nO=new_nO, nI=nI, nF=nF, nP=nP)
|
|
||||||
# it could be that the model is not initialized yet, then skip this bit
|
|
||||||
if smaller.has_param("W"):
|
|
||||||
larger_W = larger.ops.alloc4f(nF, new_nO, nP, nI)
|
|
||||||
larger_b = larger.ops.alloc2f(new_nO, nP)
|
|
||||||
larger_pad = larger.ops.alloc4f(1, nF, new_nO, nP)
|
|
||||||
smaller_W = smaller.get_param("W")
|
|
||||||
smaller_b = smaller.get_param("b")
|
|
||||||
smaller_pad = smaller.get_param("pad")
|
|
||||||
# Copy the old weights and padding into the new layer
|
|
||||||
if smaller.has_dim("nO"):
|
|
||||||
old_nO = smaller.get_dim("nO")
|
|
||||||
larger_W[:, 0:old_nO, :, :] = smaller_W
|
|
||||||
larger_pad[:, :, 0:old_nO, :] = smaller_pad
|
|
||||||
larger_b[0:old_nO, :] = smaller_b
|
|
||||||
for i in range(old_nO, new_nO):
|
|
||||||
model.attrs["unseen_classes"].add(i)
|
|
||||||
|
|
||||||
larger.set_param("W", larger_W)
|
|
||||||
larger.set_param("b", larger_b)
|
|
||||||
larger.set_param("pad", larger_pad)
|
|
||||||
model._layers[1] = larger
|
|
||||||
model.set_ref("lower", larger)
|
|
||||||
return model
|
|
||||||
|
|
|
@ -1,49 +0,0 @@
|
||||||
from libc.string cimport memset, memcpy
|
|
||||||
from thinc.backends.cblas cimport CBlas
|
|
||||||
from ..typedefs cimport weight_t, hash_t
|
|
||||||
from ..pipeline._parser_internals._state cimport StateC
|
|
||||||
|
|
||||||
|
|
||||||
cdef struct SizesC:
|
|
||||||
int states
|
|
||||||
int classes
|
|
||||||
int hiddens
|
|
||||||
int pieces
|
|
||||||
int feats
|
|
||||||
int embed_width
|
|
||||||
|
|
||||||
|
|
||||||
cdef struct WeightsC:
|
|
||||||
const float* feat_weights
|
|
||||||
const float* feat_bias
|
|
||||||
const float* hidden_bias
|
|
||||||
const float* hidden_weights
|
|
||||||
const float* seen_classes
|
|
||||||
|
|
||||||
|
|
||||||
cdef struct ActivationsC:
|
|
||||||
int* token_ids
|
|
||||||
float* unmaxed
|
|
||||||
float* scores
|
|
||||||
float* hiddens
|
|
||||||
int* is_valid
|
|
||||||
int _curr_size
|
|
||||||
int _max_size
|
|
||||||
|
|
||||||
|
|
||||||
cdef WeightsC get_c_weights(model) except *
|
|
||||||
|
|
||||||
cdef SizesC get_c_sizes(model, int batch_size) except *
|
|
||||||
|
|
||||||
cdef ActivationsC alloc_activations(SizesC n) nogil
|
|
||||||
|
|
||||||
cdef void free_activations(const ActivationsC* A) nogil
|
|
||||||
|
|
||||||
cdef void predict_states(CBlas cblas, ActivationsC* A, StateC** states,
|
|
||||||
const WeightsC* W, SizesC n) nogil
|
|
||||||
|
|
||||||
cdef int arg_max_if_valid(const weight_t* scores, const int* is_valid, int n) nogil
|
|
||||||
|
|
||||||
cdef void cpu_log_loss(float* d_scores,
|
|
||||||
const float* costs, const int* is_valid, const float* scores, int O) nogil
|
|
||||||
|
|
|
@ -1,500 +0,0 @@
|
||||||
# cython: infer_types=True, cdivision=True, boundscheck=False
|
|
||||||
cimport numpy as np
|
|
||||||
from libc.math cimport exp
|
|
||||||
from libc.string cimport memset, memcpy
|
|
||||||
from libc.stdlib cimport calloc, free, realloc
|
|
||||||
from thinc.backends.cblas cimport saxpy, sgemm
|
|
||||||
|
|
||||||
import numpy
|
|
||||||
import numpy.random
|
|
||||||
from thinc.api import Model, CupyOps, NumpyOps, get_ops
|
|
||||||
|
|
||||||
from .. import util
|
|
||||||
from ..errors import Errors
|
|
||||||
from ..typedefs cimport weight_t, class_t, hash_t
|
|
||||||
from ..pipeline._parser_internals.stateclass cimport StateClass
|
|
||||||
|
|
||||||
|
|
||||||
cdef WeightsC get_c_weights(model) except *:
|
|
||||||
cdef WeightsC output
|
|
||||||
cdef precompute_hiddens state2vec = model.state2vec
|
|
||||||
output.feat_weights = state2vec.get_feat_weights()
|
|
||||||
output.feat_bias = <const float*>state2vec.bias.data
|
|
||||||
cdef np.ndarray vec2scores_W
|
|
||||||
cdef np.ndarray vec2scores_b
|
|
||||||
if model.vec2scores is None:
|
|
||||||
output.hidden_weights = NULL
|
|
||||||
output.hidden_bias = NULL
|
|
||||||
else:
|
|
||||||
vec2scores_W = model.vec2scores.get_param("W")
|
|
||||||
vec2scores_b = model.vec2scores.get_param("b")
|
|
||||||
output.hidden_weights = <const float*>vec2scores_W.data
|
|
||||||
output.hidden_bias = <const float*>vec2scores_b.data
|
|
||||||
cdef np.ndarray class_mask = model._class_mask
|
|
||||||
output.seen_classes = <const float*>class_mask.data
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
cdef SizesC get_c_sizes(model, int batch_size) except *:
|
|
||||||
cdef SizesC output
|
|
||||||
output.states = batch_size
|
|
||||||
if model.vec2scores is None:
|
|
||||||
output.classes = model.state2vec.get_dim("nO")
|
|
||||||
else:
|
|
||||||
output.classes = model.vec2scores.get_dim("nO")
|
|
||||||
output.hiddens = model.state2vec.get_dim("nO")
|
|
||||||
output.pieces = model.state2vec.get_dim("nP")
|
|
||||||
output.feats = model.state2vec.get_dim("nF")
|
|
||||||
output.embed_width = model.tokvecs.shape[1]
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
cdef ActivationsC alloc_activations(SizesC n) nogil:
|
|
||||||
cdef ActivationsC A
|
|
||||||
memset(&A, 0, sizeof(A))
|
|
||||||
resize_activations(&A, n)
|
|
||||||
return A
|
|
||||||
|
|
||||||
|
|
||||||
cdef void free_activations(const ActivationsC* A) nogil:
|
|
||||||
free(A.token_ids)
|
|
||||||
free(A.scores)
|
|
||||||
free(A.unmaxed)
|
|
||||||
free(A.hiddens)
|
|
||||||
free(A.is_valid)
|
|
||||||
|
|
||||||
|
|
||||||
cdef void resize_activations(ActivationsC* A, SizesC n) nogil:
|
|
||||||
if n.states <= A._max_size:
|
|
||||||
A._curr_size = n.states
|
|
||||||
return
|
|
||||||
if A._max_size == 0:
|
|
||||||
A.token_ids = <int*>calloc(n.states * n.feats, sizeof(A.token_ids[0]))
|
|
||||||
A.scores = <float*>calloc(n.states * n.classes, sizeof(A.scores[0]))
|
|
||||||
A.unmaxed = <float*>calloc(n.states * n.hiddens * n.pieces, sizeof(A.unmaxed[0]))
|
|
||||||
A.hiddens = <float*>calloc(n.states * n.hiddens, sizeof(A.hiddens[0]))
|
|
||||||
A.is_valid = <int*>calloc(n.states * n.classes, sizeof(A.is_valid[0]))
|
|
||||||
A._max_size = n.states
|
|
||||||
else:
|
|
||||||
A.token_ids = <int*>realloc(A.token_ids,
|
|
||||||
n.states * n.feats * sizeof(A.token_ids[0]))
|
|
||||||
A.scores = <float*>realloc(A.scores,
|
|
||||||
n.states * n.classes * sizeof(A.scores[0]))
|
|
||||||
A.unmaxed = <float*>realloc(A.unmaxed,
|
|
||||||
n.states * n.hiddens * n.pieces * sizeof(A.unmaxed[0]))
|
|
||||||
A.hiddens = <float*>realloc(A.hiddens,
|
|
||||||
n.states * n.hiddens * sizeof(A.hiddens[0]))
|
|
||||||
A.is_valid = <int*>realloc(A.is_valid,
|
|
||||||
n.states * n.classes * sizeof(A.is_valid[0]))
|
|
||||||
A._max_size = n.states
|
|
||||||
A._curr_size = n.states
|
|
||||||
|
|
||||||
|
|
||||||
cdef void predict_states(CBlas cblas, ActivationsC* A, StateC** states,
|
|
||||||
const WeightsC* W, SizesC n) nogil:
|
|
||||||
cdef double one = 1.0
|
|
||||||
resize_activations(A, n)
|
|
||||||
for i in range(n.states):
|
|
||||||
states[i].set_context_tokens(&A.token_ids[i*n.feats], n.feats)
|
|
||||||
memset(A.unmaxed, 0, n.states * n.hiddens * n.pieces * sizeof(float))
|
|
||||||
memset(A.hiddens, 0, n.states * n.hiddens * sizeof(float))
|
|
||||||
sum_state_features(cblas, A.unmaxed,
|
|
||||||
W.feat_weights, A.token_ids, n.states, n.feats, n.hiddens * n.pieces)
|
|
||||||
for i in range(n.states):
|
|
||||||
saxpy(cblas)(n.hiddens * n.pieces, 1., W.feat_bias, 1, &A.unmaxed[i*n.hiddens*n.pieces], 1)
|
|
||||||
for j in range(n.hiddens):
|
|
||||||
index = i * n.hiddens * n.pieces + j * n.pieces
|
|
||||||
which = _arg_max(&A.unmaxed[index], n.pieces)
|
|
||||||
A.hiddens[i*n.hiddens + j] = A.unmaxed[index + which]
|
|
||||||
memset(A.scores, 0, n.states * n.classes * sizeof(float))
|
|
||||||
if W.hidden_weights == NULL:
|
|
||||||
memcpy(A.scores, A.hiddens, n.states * n.classes * sizeof(float))
|
|
||||||
else:
|
|
||||||
# Compute hidden-to-output
|
|
||||||
sgemm(cblas)(False, True, n.states, n.classes, n.hiddens,
|
|
||||||
1.0, <const float *>A.hiddens, n.hiddens,
|
|
||||||
<const float *>W.hidden_weights, n.hiddens,
|
|
||||||
0.0, A.scores, n.classes)
|
|
||||||
# Add bias
|
|
||||||
for i in range(n.states):
|
|
||||||
saxpy(cblas)(n.classes, 1., W.hidden_bias, 1, &A.scores[i*n.classes], 1)
|
|
||||||
# Set unseen classes to minimum value
|
|
||||||
i = 0
|
|
||||||
min_ = A.scores[0]
|
|
||||||
for i in range(1, n.states * n.classes):
|
|
||||||
if A.scores[i] < min_:
|
|
||||||
min_ = A.scores[i]
|
|
||||||
for i in range(n.states):
|
|
||||||
for j in range(n.classes):
|
|
||||||
if not W.seen_classes[j]:
|
|
||||||
A.scores[i*n.classes+j] = min_
|
|
||||||
|
|
||||||
|
|
||||||
cdef void sum_state_features(CBlas cblas, float* output,
|
|
||||||
const float* cached, const int* token_ids, int B, int F, int O) nogil:
|
|
||||||
cdef int idx, b, f, i
|
|
||||||
cdef const float* feature
|
|
||||||
padding = cached
|
|
||||||
cached += F * O
|
|
||||||
cdef int id_stride = F*O
|
|
||||||
cdef float one = 1.
|
|
||||||
for b in range(B):
|
|
||||||
for f in range(F):
|
|
||||||
if token_ids[f] < 0:
|
|
||||||
feature = &padding[f*O]
|
|
||||||
else:
|
|
||||||
idx = token_ids[f] * id_stride + f*O
|
|
||||||
feature = &cached[idx]
|
|
||||||
saxpy(cblas)(O, one, <const float*>feature, 1, &output[b*O], 1)
|
|
||||||
token_ids += F
|
|
||||||
|
|
||||||
|
|
||||||
cdef void cpu_log_loss(float* d_scores,
|
|
||||||
const float* costs, const int* is_valid, const float* scores,
|
|
||||||
int O) nogil:
|
|
||||||
"""Do multi-label log loss"""
|
|
||||||
cdef double max_, gmax, Z, gZ
|
|
||||||
best = arg_max_if_gold(scores, costs, is_valid, O)
|
|
||||||
guess = _arg_max(scores, O)
|
|
||||||
|
|
||||||
if best == -1 or guess == -1:
|
|
||||||
# These shouldn't happen, but if they do, we want to make sure we don't
|
|
||||||
# cause an OOB access.
|
|
||||||
return
|
|
||||||
Z = 1e-10
|
|
||||||
gZ = 1e-10
|
|
||||||
max_ = scores[guess]
|
|
||||||
gmax = scores[best]
|
|
||||||
for i in range(O):
|
|
||||||
Z += exp(scores[i] - max_)
|
|
||||||
if costs[i] <= costs[best]:
|
|
||||||
gZ += exp(scores[i] - gmax)
|
|
||||||
for i in range(O):
|
|
||||||
if costs[i] <= costs[best]:
|
|
||||||
d_scores[i] = (exp(scores[i]-max_) / Z) - (exp(scores[i]-gmax)/gZ)
|
|
||||||
else:
|
|
||||||
d_scores[i] = exp(scores[i]-max_) / Z
|
|
||||||
|
|
||||||
|
|
||||||
cdef int arg_max_if_gold(const weight_t* scores, const weight_t* costs,
|
|
||||||
const int* is_valid, int n) nogil:
|
|
||||||
# Find minimum cost
|
|
||||||
cdef float cost = 1
|
|
||||||
for i in range(n):
|
|
||||||
if is_valid[i] and costs[i] < cost:
|
|
||||||
cost = costs[i]
|
|
||||||
# Now find best-scoring with that cost
|
|
||||||
cdef int best = -1
|
|
||||||
for i in range(n):
|
|
||||||
if costs[i] <= cost and is_valid[i]:
|
|
||||||
if best == -1 or scores[i] > scores[best]:
|
|
||||||
best = i
|
|
||||||
return best
|
|
||||||
|
|
||||||
|
|
||||||
cdef int arg_max_if_valid(const weight_t* scores, const int* is_valid, int n) nogil:
|
|
||||||
cdef int best = -1
|
|
||||||
for i in range(n):
|
|
||||||
if is_valid[i] >= 1:
|
|
||||||
if best == -1 or scores[i] > scores[best]:
|
|
||||||
best = i
|
|
||||||
return best
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ParserStepModel(Model):
|
|
||||||
def __init__(self, docs, layers, *, has_upper, unseen_classes=None, train=True,
|
|
||||||
dropout=0.1):
|
|
||||||
Model.__init__(self, name="parser_step_model", forward=step_forward)
|
|
||||||
self.attrs["has_upper"] = has_upper
|
|
||||||
self.attrs["dropout_rate"] = dropout
|
|
||||||
self.tokvecs, self.bp_tokvecs = layers[0](docs, is_train=train)
|
|
||||||
if layers[1].get_dim("nP") >= 2:
|
|
||||||
activation = "maxout"
|
|
||||||
elif has_upper:
|
|
||||||
activation = None
|
|
||||||
else:
|
|
||||||
activation = "relu"
|
|
||||||
self.state2vec = precompute_hiddens(len(docs), self.tokvecs, layers[1],
|
|
||||||
activation=activation, train=train)
|
|
||||||
if has_upper:
|
|
||||||
self.vec2scores = layers[-1]
|
|
||||||
else:
|
|
||||||
self.vec2scores = None
|
|
||||||
self.cuda_stream = util.get_cuda_stream(non_blocking=True)
|
|
||||||
self.backprops = []
|
|
||||||
self._class_mask = numpy.zeros((self.nO,), dtype='f')
|
|
||||||
self._class_mask.fill(1)
|
|
||||||
if unseen_classes is not None:
|
|
||||||
for class_ in unseen_classes:
|
|
||||||
self._class_mask[class_] = 0.
|
|
||||||
|
|
||||||
def clear_memory(self):
|
|
||||||
del self.tokvecs
|
|
||||||
del self.bp_tokvecs
|
|
||||||
del self.state2vec
|
|
||||||
del self.backprops
|
|
||||||
del self._class_mask
|
|
||||||
|
|
||||||
@property
|
|
||||||
def nO(self):
|
|
||||||
if self.attrs["has_upper"]:
|
|
||||||
return self.vec2scores.get_dim("nO")
|
|
||||||
else:
|
|
||||||
return self.state2vec.get_dim("nO")
|
|
||||||
|
|
||||||
def class_is_unseen(self, class_):
|
|
||||||
return self._class_mask[class_]
|
|
||||||
|
|
||||||
def mark_class_unseen(self, class_):
|
|
||||||
self._class_mask[class_] = 0
|
|
||||||
|
|
||||||
def mark_class_seen(self, class_):
|
|
||||||
self._class_mask[class_] = 1
|
|
||||||
|
|
||||||
def get_token_ids(self, states):
|
|
||||||
cdef StateClass state
|
|
||||||
states = [state for state in states if not state.is_final()]
|
|
||||||
cdef np.ndarray ids = numpy.zeros((len(states), self.state2vec.nF),
|
|
||||||
dtype='i', order='C')
|
|
||||||
ids.fill(-1)
|
|
||||||
c_ids = <int*>ids.data
|
|
||||||
for state in states:
|
|
||||||
state.c.set_context_tokens(c_ids, ids.shape[1])
|
|
||||||
c_ids += ids.shape[1]
|
|
||||||
return ids
|
|
||||||
|
|
||||||
def backprop_step(self, token_ids, d_vector, get_d_tokvecs):
|
|
||||||
if isinstance(self.state2vec.ops, CupyOps) \
|
|
||||||
and not isinstance(token_ids, self.state2vec.ops.xp.ndarray):
|
|
||||||
# Move token_ids and d_vector to GPU, asynchronously
|
|
||||||
self.backprops.append((
|
|
||||||
util.get_async(self.cuda_stream, token_ids),
|
|
||||||
util.get_async(self.cuda_stream, d_vector),
|
|
||||||
get_d_tokvecs
|
|
||||||
))
|
|
||||||
else:
|
|
||||||
self.backprops.append((token_ids, d_vector, get_d_tokvecs))
|
|
||||||
|
|
||||||
|
|
||||||
def finish_steps(self, golds):
|
|
||||||
# Add a padding vector to the d_tokvecs gradient, so that missing
|
|
||||||
# values don't affect the real gradient.
|
|
||||||
d_tokvecs = self.ops.alloc((self.tokvecs.shape[0]+1, self.tokvecs.shape[1]))
|
|
||||||
# Tells CUDA to block, so our async copies complete.
|
|
||||||
if self.cuda_stream is not None:
|
|
||||||
self.cuda_stream.synchronize()
|
|
||||||
for ids, d_vector, bp_vector in self.backprops:
|
|
||||||
d_state_features = bp_vector((d_vector, ids))
|
|
||||||
ids = ids.flatten()
|
|
||||||
d_state_features = d_state_features.reshape(
|
|
||||||
(ids.size, d_state_features.shape[2]))
|
|
||||||
self.ops.scatter_add(d_tokvecs, ids,
|
|
||||||
d_state_features)
|
|
||||||
# Padded -- see update()
|
|
||||||
self.bp_tokvecs(d_tokvecs[:-1])
|
|
||||||
return d_tokvecs
|
|
||||||
|
|
||||||
NUMPY_OPS = NumpyOps()
|
|
||||||
|
|
||||||
def step_forward(model: ParserStepModel, states, is_train):
|
|
||||||
token_ids = model.get_token_ids(states)
|
|
||||||
vector, get_d_tokvecs = model.state2vec(token_ids, is_train)
|
|
||||||
mask = None
|
|
||||||
if model.attrs["has_upper"]:
|
|
||||||
dropout_rate = model.attrs["dropout_rate"]
|
|
||||||
if is_train and dropout_rate > 0:
|
|
||||||
mask = NUMPY_OPS.get_dropout_mask(vector.shape, 0.1)
|
|
||||||
vector *= mask
|
|
||||||
scores, get_d_vector = model.vec2scores(vector, is_train)
|
|
||||||
else:
|
|
||||||
scores = NumpyOps().asarray(vector)
|
|
||||||
get_d_vector = lambda d_scores: d_scores
|
|
||||||
# If the class is unseen, make sure its score is minimum
|
|
||||||
scores[:, model._class_mask == 0] = numpy.nanmin(scores)
|
|
||||||
|
|
||||||
def backprop_parser_step(d_scores):
|
|
||||||
# Zero vectors for unseen classes
|
|
||||||
d_scores *= model._class_mask
|
|
||||||
d_vector = get_d_vector(d_scores)
|
|
||||||
if mask is not None:
|
|
||||||
d_vector *= mask
|
|
||||||
model.backprop_step(token_ids, d_vector, get_d_tokvecs)
|
|
||||||
return None
|
|
||||||
return scores, backprop_parser_step
|
|
||||||
|
|
||||||
|
|
||||||
cdef class precompute_hiddens:
|
|
||||||
"""Allow a model to be "primed" by pre-computing input features in bulk.
|
|
||||||
|
|
||||||
This is used for the parser, where we want to take a batch of documents,
|
|
||||||
and compute vectors for each (token, position) pair. These vectors can then
|
|
||||||
be reused, especially for beam-search.
|
|
||||||
|
|
||||||
Let's say we're using 12 features for each state, e.g. word at start of
|
|
||||||
buffer, three words on stack, their children, etc. In the normal arc-eager
|
|
||||||
system, a document of length N is processed in 2*N states. This means we'll
|
|
||||||
create 2*N*12 feature vectors --- but if we pre-compute, we only need
|
|
||||||
N*12 vector computations. The saving for beam-search is much better:
|
|
||||||
if we have a beam of k, we'll normally make 2*N*12*K computations --
|
|
||||||
so we can save the factor k. This also gives a nice CPU/GPU division:
|
|
||||||
we can do all our hard maths up front, packed into large multiplications,
|
|
||||||
and do the hard-to-program parsing on the CPU.
|
|
||||||
"""
|
|
||||||
cdef readonly int nF, nO, nP
|
|
||||||
cdef bint _is_synchronized
|
|
||||||
cdef public object ops
|
|
||||||
cdef public object numpy_ops
|
|
||||||
cdef public object _cpu_ops
|
|
||||||
cdef np.ndarray _features
|
|
||||||
cdef np.ndarray _cached
|
|
||||||
cdef np.ndarray bias
|
|
||||||
cdef object _cuda_stream
|
|
||||||
cdef object _bp_hiddens
|
|
||||||
cdef object activation
|
|
||||||
|
|
||||||
def __init__(self, batch_size, tokvecs, lower_model, cuda_stream=None,
|
|
||||||
activation="maxout", train=False):
|
|
||||||
gpu_cached, bp_features = lower_model(tokvecs, train)
|
|
||||||
cdef np.ndarray cached
|
|
||||||
if not isinstance(gpu_cached, numpy.ndarray):
|
|
||||||
# Note the passing of cuda_stream here: it lets
|
|
||||||
# cupy make the copy asynchronously.
|
|
||||||
# We then have to block before first use.
|
|
||||||
cached = gpu_cached.get(stream=cuda_stream)
|
|
||||||
else:
|
|
||||||
cached = gpu_cached
|
|
||||||
if not isinstance(lower_model.get_param("b"), numpy.ndarray):
|
|
||||||
self.bias = lower_model.get_param("b").get(stream=cuda_stream)
|
|
||||||
else:
|
|
||||||
self.bias = lower_model.get_param("b")
|
|
||||||
self.nF = cached.shape[1]
|
|
||||||
if lower_model.has_dim("nP"):
|
|
||||||
self.nP = lower_model.get_dim("nP")
|
|
||||||
else:
|
|
||||||
self.nP = 1
|
|
||||||
self.nO = cached.shape[2]
|
|
||||||
self.ops = lower_model.ops
|
|
||||||
self.numpy_ops = NumpyOps()
|
|
||||||
self._cpu_ops = get_ops("cpu") if isinstance(self.ops, CupyOps) else self.ops
|
|
||||||
assert activation in (None, "relu", "maxout")
|
|
||||||
self.activation = activation
|
|
||||||
self._is_synchronized = False
|
|
||||||
self._cuda_stream = cuda_stream
|
|
||||||
self._cached = cached
|
|
||||||
self._bp_hiddens = bp_features
|
|
||||||
|
|
||||||
cdef const float* get_feat_weights(self) except NULL:
|
|
||||||
if not self._is_synchronized and self._cuda_stream is not None:
|
|
||||||
self._cuda_stream.synchronize()
|
|
||||||
self._is_synchronized = True
|
|
||||||
return <float*>self._cached.data
|
|
||||||
|
|
||||||
def has_dim(self, name):
|
|
||||||
if name == "nF":
|
|
||||||
return self.nF if self.nF is not None else True
|
|
||||||
elif name == "nP":
|
|
||||||
return self.nP if self.nP is not None else True
|
|
||||||
elif name == "nO":
|
|
||||||
return self.nO if self.nO is not None else True
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
|
|
||||||
def get_dim(self, name):
|
|
||||||
if name == "nF":
|
|
||||||
return self.nF
|
|
||||||
elif name == "nP":
|
|
||||||
return self.nP
|
|
||||||
elif name == "nO":
|
|
||||||
return self.nO
|
|
||||||
else:
|
|
||||||
raise ValueError(Errors.E1033.format(name=name))
|
|
||||||
|
|
||||||
def set_dim(self, name, value):
|
|
||||||
if name == "nF":
|
|
||||||
self.nF = value
|
|
||||||
elif name == "nP":
|
|
||||||
self.nP = value
|
|
||||||
elif name == "nO":
|
|
||||||
self.nO = value
|
|
||||||
else:
|
|
||||||
raise ValueError(Errors.E1033.format(name=name))
|
|
||||||
|
|
||||||
def __call__(self, X, bint is_train):
|
|
||||||
if is_train:
|
|
||||||
return self.begin_update(X)
|
|
||||||
else:
|
|
||||||
return self.predict(X), lambda X: X
|
|
||||||
|
|
||||||
def predict(self, X):
|
|
||||||
return self.begin_update(X)[0]
|
|
||||||
|
|
||||||
def begin_update(self, token_ids):
|
|
||||||
cdef np.ndarray state_vector = numpy.zeros(
|
|
||||||
(token_ids.shape[0], self.nO, self.nP), dtype='f')
|
|
||||||
# This is tricky, but (assuming GPU available);
|
|
||||||
# - Input to forward on CPU
|
|
||||||
# - Output from forward on CPU
|
|
||||||
# - Input to backward on GPU!
|
|
||||||
# - Output from backward on GPU
|
|
||||||
bp_hiddens = self._bp_hiddens
|
|
||||||
|
|
||||||
cdef CBlas cblas = self._cpu_ops.cblas()
|
|
||||||
|
|
||||||
feat_weights = self.get_feat_weights()
|
|
||||||
cdef int[:, ::1] ids = token_ids
|
|
||||||
sum_state_features(cblas, <float*>state_vector.data,
|
|
||||||
feat_weights, &ids[0,0],
|
|
||||||
token_ids.shape[0], self.nF, self.nO*self.nP)
|
|
||||||
state_vector += self.bias
|
|
||||||
state_vector, bp_nonlinearity = self._nonlinearity(state_vector)
|
|
||||||
|
|
||||||
def backward(d_state_vector_ids):
|
|
||||||
d_state_vector, token_ids = d_state_vector_ids
|
|
||||||
d_state_vector = bp_nonlinearity(d_state_vector)
|
|
||||||
d_tokens = bp_hiddens((d_state_vector, token_ids))
|
|
||||||
return d_tokens
|
|
||||||
return state_vector, backward
|
|
||||||
|
|
||||||
def _nonlinearity(self, state_vector):
|
|
||||||
if self.activation == "maxout":
|
|
||||||
return self._maxout_nonlinearity(state_vector)
|
|
||||||
else:
|
|
||||||
return self._relu_nonlinearity(state_vector)
|
|
||||||
|
|
||||||
def _maxout_nonlinearity(self, state_vector):
|
|
||||||
state_vector, mask = self.numpy_ops.maxout(state_vector)
|
|
||||||
# We're outputting to CPU, but we need this variable on GPU for the
|
|
||||||
# backward pass.
|
|
||||||
mask = self.ops.asarray(mask)
|
|
||||||
|
|
||||||
def backprop_maxout(d_best):
|
|
||||||
return self.ops.backprop_maxout(d_best, mask, self.nP)
|
|
||||||
|
|
||||||
return state_vector, backprop_maxout
|
|
||||||
|
|
||||||
def _relu_nonlinearity(self, state_vector):
|
|
||||||
state_vector = state_vector.reshape((state_vector.shape[0], -1))
|
|
||||||
mask = state_vector >= 0.
|
|
||||||
state_vector *= mask
|
|
||||||
# We're outputting to CPU, but we need this variable on GPU for the
|
|
||||||
# backward pass.
|
|
||||||
mask = self.ops.asarray(mask)
|
|
||||||
|
|
||||||
def backprop_relu(d_best):
|
|
||||||
d_best *= mask
|
|
||||||
return d_best.reshape((d_best.shape + (1,)))
|
|
||||||
|
|
||||||
return state_vector, backprop_relu
|
|
||||||
|
|
||||||
cdef inline int _arg_max(const float* scores, const int n_classes) nogil:
|
|
||||||
if n_classes == 2:
|
|
||||||
return 0 if scores[0] > scores[1] else 1
|
|
||||||
cdef int i
|
|
||||||
cdef int best = 0
|
|
||||||
cdef float mode = scores[0]
|
|
||||||
for i in range(1, n_classes):
|
|
||||||
if scores[i] > mode:
|
|
||||||
mode = scores[i]
|
|
||||||
best = i
|
|
||||||
return best
|
|
28
spacy/ml/tb_framework.pxd
Normal file
28
spacy/ml/tb_framework.pxd
Normal file
|
@ -0,0 +1,28 @@
|
||||||
|
from libc.stdint cimport int8_t
|
||||||
|
|
||||||
|
|
||||||
|
cdef struct SizesC:
|
||||||
|
int states
|
||||||
|
int classes
|
||||||
|
int hiddens
|
||||||
|
int pieces
|
||||||
|
int feats
|
||||||
|
int embed_width
|
||||||
|
int tokens
|
||||||
|
|
||||||
|
|
||||||
|
cdef struct WeightsC:
|
||||||
|
const float* feat_weights
|
||||||
|
const float* feat_bias
|
||||||
|
const float* hidden_bias
|
||||||
|
const float* hidden_weights
|
||||||
|
const int8_t* seen_mask
|
||||||
|
|
||||||
|
|
||||||
|
cdef struct ActivationsC:
|
||||||
|
int* token_ids
|
||||||
|
float* unmaxed
|
||||||
|
float* hiddens
|
||||||
|
int* is_valid
|
||||||
|
int _curr_size
|
||||||
|
int _max_size
|
|
@ -1,50 +0,0 @@
|
||||||
from thinc.api import Model, noop
|
|
||||||
from .parser_model import ParserStepModel
|
|
||||||
from ..util import registry
|
|
||||||
|
|
||||||
|
|
||||||
@registry.layers("spacy.TransitionModel.v1")
|
|
||||||
def TransitionModel(
|
|
||||||
tok2vec, lower, upper, resize_output, dropout=0.2, unseen_classes=set()
|
|
||||||
):
|
|
||||||
"""Set up a stepwise transition-based model"""
|
|
||||||
if upper is None:
|
|
||||||
has_upper = False
|
|
||||||
upper = noop()
|
|
||||||
else:
|
|
||||||
has_upper = True
|
|
||||||
# don't define nO for this object, because we can't dynamically change it
|
|
||||||
return Model(
|
|
||||||
name="parser_model",
|
|
||||||
forward=forward,
|
|
||||||
dims={"nI": tok2vec.maybe_get_dim("nI")},
|
|
||||||
layers=[tok2vec, lower, upper],
|
|
||||||
refs={"tok2vec": tok2vec, "lower": lower, "upper": upper},
|
|
||||||
init=init,
|
|
||||||
attrs={
|
|
||||||
"has_upper": has_upper,
|
|
||||||
"unseen_classes": set(unseen_classes),
|
|
||||||
"resize_output": resize_output,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def forward(model, X, is_train):
|
|
||||||
step_model = ParserStepModel(
|
|
||||||
X,
|
|
||||||
model.layers,
|
|
||||||
unseen_classes=model.attrs["unseen_classes"],
|
|
||||||
train=is_train,
|
|
||||||
has_upper=model.attrs["has_upper"],
|
|
||||||
)
|
|
||||||
|
|
||||||
return step_model, step_model.finish_steps
|
|
||||||
|
|
||||||
|
|
||||||
def init(model, X=None, Y=None):
|
|
||||||
model.get_ref("tok2vec").initialize(X=X)
|
|
||||||
lower = model.get_ref("lower")
|
|
||||||
lower.initialize()
|
|
||||||
if model.attrs["has_upper"]:
|
|
||||||
statevecs = model.ops.alloc2f(2, lower.get_dim("nO"))
|
|
||||||
model.get_ref("upper").initialize(X=statevecs)
|
|
621
spacy/ml/tb_framework.pyx
Normal file
621
spacy/ml/tb_framework.pyx
Normal file
|
@ -0,0 +1,621 @@
|
||||||
|
# cython: infer_types=True, cdivision=True, boundscheck=False
|
||||||
|
from typing import List, Tuple, Any, Optional, TypeVar, cast
|
||||||
|
from libc.string cimport memset, memcpy
|
||||||
|
from libc.stdlib cimport calloc, free, realloc
|
||||||
|
from libcpp.vector cimport vector
|
||||||
|
import numpy
|
||||||
|
cimport numpy as np
|
||||||
|
from thinc.api import Model, normal_init, chain, list2array, Linear
|
||||||
|
from thinc.api import uniform_init, glorot_uniform_init, zero_init
|
||||||
|
from thinc.api import NumpyOps
|
||||||
|
from thinc.backends.cblas cimport CBlas, saxpy, sgemm
|
||||||
|
from thinc.types import Floats1d, Floats2d, Floats3d, Floats4d
|
||||||
|
from thinc.types import Ints1d, Ints2d
|
||||||
|
|
||||||
|
from ..errors import Errors
|
||||||
|
from ..pipeline._parser_internals import _beam_utils
|
||||||
|
from ..pipeline._parser_internals.batch import GreedyBatch
|
||||||
|
from ..pipeline._parser_internals._parser_utils cimport arg_max
|
||||||
|
from ..pipeline._parser_internals.transition_system cimport c_transition_batch, c_apply_actions
|
||||||
|
from ..pipeline._parser_internals.transition_system cimport TransitionSystem
|
||||||
|
from ..pipeline._parser_internals.stateclass cimport StateC, StateClass
|
||||||
|
from ..tokens.doc import Doc
|
||||||
|
from ..util import registry
|
||||||
|
|
||||||
|
|
||||||
|
State = Any # TODO
|
||||||
|
|
||||||
|
|
||||||
|
@registry.layers("spacy.TransitionModel.v2")
|
||||||
|
def TransitionModel(
|
||||||
|
*,
|
||||||
|
tok2vec: Model[List[Doc], List[Floats2d]],
|
||||||
|
beam_width: int = 1,
|
||||||
|
beam_density: float = 0.0,
|
||||||
|
state_tokens: int,
|
||||||
|
hidden_width: int,
|
||||||
|
maxout_pieces: int,
|
||||||
|
nO: Optional[int] = None,
|
||||||
|
unseen_classes=set(),
|
||||||
|
) -> Model[Tuple[List[Doc], TransitionSystem], List[Tuple[State, List[Floats2d]]]]:
|
||||||
|
"""Set up a transition-based parsing model, using a maxout hidden
|
||||||
|
layer and a linear output layer.
|
||||||
|
"""
|
||||||
|
t2v_width = tok2vec.get_dim("nO") if tok2vec.has_dim("nO") else None
|
||||||
|
tok2vec_projected = chain(tok2vec, list2array(), Linear(hidden_width, t2v_width)) # type: ignore
|
||||||
|
tok2vec_projected.set_dim("nO", hidden_width)
|
||||||
|
|
||||||
|
# FIXME: we use `output` as a container for the output layer's
|
||||||
|
# weights and biases. Thinc optimizers cannot handle resizing
|
||||||
|
# of parameters. So, when the parser model is resized, we
|
||||||
|
# construct a new `output` layer, which has a different key in
|
||||||
|
# the optimizer. Once the optimizer supports parameter resizing,
|
||||||
|
# we can replace the `output` layer by `output_W` and `output_b`
|
||||||
|
# parameters in this model.
|
||||||
|
output = Linear(nO=None, nI=hidden_width, init_W=zero_init)
|
||||||
|
|
||||||
|
return Model(
|
||||||
|
name="parser_model",
|
||||||
|
forward=forward,
|
||||||
|
init=init,
|
||||||
|
layers=[tok2vec_projected, output],
|
||||||
|
refs={
|
||||||
|
"tok2vec": tok2vec_projected,
|
||||||
|
"output": output,
|
||||||
|
},
|
||||||
|
params={
|
||||||
|
"hidden_W": None, # Floats2d W for the hidden layer
|
||||||
|
"hidden_b": None, # Floats1d bias for the hidden layer
|
||||||
|
"hidden_pad": None, # Floats1d padding for the hidden layer
|
||||||
|
},
|
||||||
|
dims={
|
||||||
|
"nO": None, # Output size
|
||||||
|
"nP": maxout_pieces,
|
||||||
|
"nH": hidden_width,
|
||||||
|
"nI": tok2vec_projected.maybe_get_dim("nO"),
|
||||||
|
"nF": state_tokens,
|
||||||
|
},
|
||||||
|
attrs={
|
||||||
|
"beam_width": beam_width,
|
||||||
|
"beam_density": beam_density,
|
||||||
|
"unseen_classes": set(unseen_classes),
|
||||||
|
"resize_output": resize_output,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def resize_output(model: Model, new_nO: int) -> Model:
|
||||||
|
old_nO = model.maybe_get_dim("nO")
|
||||||
|
output = model.get_ref("output")
|
||||||
|
if old_nO is None:
|
||||||
|
model.set_dim("nO", new_nO)
|
||||||
|
output.set_dim("nO", new_nO)
|
||||||
|
output.initialize()
|
||||||
|
return model
|
||||||
|
elif new_nO <= old_nO:
|
||||||
|
return model
|
||||||
|
elif output.has_param("W"):
|
||||||
|
nH = model.get_dim("nH")
|
||||||
|
new_output = Linear(nO=new_nO, nI=nH, init_W=zero_init)
|
||||||
|
new_output.initialize()
|
||||||
|
new_W = new_output.get_param("W")
|
||||||
|
new_b = new_output.get_param("b")
|
||||||
|
old_W = output.get_param("W")
|
||||||
|
old_b = output.get_param("b")
|
||||||
|
new_W[:old_nO] = old_W # type: ignore
|
||||||
|
new_b[:old_nO] = old_b # type: ignore
|
||||||
|
for i in range(old_nO, new_nO):
|
||||||
|
model.attrs["unseen_classes"].add(i)
|
||||||
|
model.layers[-1] = new_output
|
||||||
|
model.set_ref("output", new_output)
|
||||||
|
# TODO: Avoid this private intrusion
|
||||||
|
model._dims["nO"] = new_nO
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def init(
|
||||||
|
model,
|
||||||
|
X: Optional[Tuple[List[Doc], TransitionSystem]] = None,
|
||||||
|
Y: Optional[Tuple[List[State], List[Floats2d]]] = None,
|
||||||
|
):
|
||||||
|
if X is not None:
|
||||||
|
docs, moves = X
|
||||||
|
model.get_ref("tok2vec").initialize(X=docs)
|
||||||
|
else:
|
||||||
|
model.get_ref("tok2vec").initialize()
|
||||||
|
inferred_nO = _infer_nO(Y)
|
||||||
|
if inferred_nO is not None:
|
||||||
|
current_nO = model.maybe_get_dim("nO")
|
||||||
|
if current_nO is None or current_nO != inferred_nO:
|
||||||
|
model.attrs["resize_output"](model, inferred_nO)
|
||||||
|
nO = model.get_dim("nO")
|
||||||
|
nP = model.get_dim("nP")
|
||||||
|
nH = model.get_dim("nH")
|
||||||
|
nI = model.get_dim("nI")
|
||||||
|
nF = model.get_dim("nF")
|
||||||
|
ops = model.ops
|
||||||
|
|
||||||
|
Wl = ops.alloc2f(nH * nP, nF * nI)
|
||||||
|
bl = ops.alloc1f(nH * nP)
|
||||||
|
padl = ops.alloc1f(nI)
|
||||||
|
# Wl = zero_init(ops, Wl.shape)
|
||||||
|
Wl = glorot_uniform_init(ops, Wl.shape)
|
||||||
|
padl = uniform_init(ops, padl.shape) # type: ignore
|
||||||
|
# TODO: Experiment with whether better to initialize output_W
|
||||||
|
model.set_param("hidden_W", Wl)
|
||||||
|
model.set_param("hidden_b", bl)
|
||||||
|
model.set_param("hidden_pad", padl)
|
||||||
|
# model = _lsuv_init(model)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
class TransitionModelInputs:
|
||||||
|
"""
|
||||||
|
Input to transition model.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# dataclass annotation is not yet supported in Cython 0.29.x,
|
||||||
|
# so, we'll do something close to it.
|
||||||
|
|
||||||
|
actions: Optional[List[Ints1d]]
|
||||||
|
docs: List[Doc]
|
||||||
|
max_moves: int
|
||||||
|
moves: TransitionSystem
|
||||||
|
states: Optional[List[State]]
|
||||||
|
|
||||||
|
__slots__ = [
|
||||||
|
"actions",
|
||||||
|
"docs",
|
||||||
|
"max_moves",
|
||||||
|
"moves",
|
||||||
|
"states",
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
docs: List[Doc],
|
||||||
|
moves: TransitionSystem,
|
||||||
|
actions: Optional[List[Ints1d]]=None,
|
||||||
|
max_moves: int=0,
|
||||||
|
states: Optional[List[State]]=None):
|
||||||
|
"""
|
||||||
|
actions (Optional[List[Ints1d]]): actions to apply for each Doc.
|
||||||
|
docs (List[Doc]): Docs to predict transition sequences for.
|
||||||
|
max_moves: (int): the maximum number of moves to apply, values less
|
||||||
|
than 1 will apply moves to states until they are final states.
|
||||||
|
moves (TransitionSystem): the transition system to use when predicting
|
||||||
|
the transition sequences.
|
||||||
|
states (Optional[List[States]]): the initial states to predict the
|
||||||
|
transition sequences for. When absent, the initial states are
|
||||||
|
initialized from the provided Docs.
|
||||||
|
"""
|
||||||
|
self.actions = actions
|
||||||
|
self.docs = docs
|
||||||
|
self.moves = moves
|
||||||
|
self.max_moves = max_moves
|
||||||
|
self.states = states
|
||||||
|
|
||||||
|
|
||||||
|
def forward(model, inputs: TransitionModelInputs, is_train: bool):
|
||||||
|
docs = inputs.docs
|
||||||
|
moves = inputs.moves
|
||||||
|
actions = inputs.actions
|
||||||
|
|
||||||
|
beam_width = model.attrs["beam_width"]
|
||||||
|
hidden_pad = model.get_param("hidden_pad")
|
||||||
|
tok2vec = model.get_ref("tok2vec")
|
||||||
|
|
||||||
|
states = moves.init_batch(docs) if inputs.states is None else inputs.states
|
||||||
|
tokvecs, backprop_tok2vec = tok2vec(docs, is_train)
|
||||||
|
tokvecs = model.ops.xp.vstack((tokvecs, hidden_pad))
|
||||||
|
feats, backprop_feats = _forward_precomputable_affine(model, tokvecs, is_train)
|
||||||
|
seen_mask = _get_seen_mask(model)
|
||||||
|
|
||||||
|
if not is_train and beam_width == 1 and isinstance(model.ops, NumpyOps):
|
||||||
|
# Note: max_moves is only used during training, so we don't need to
|
||||||
|
# pass it to the greedy inference path.
|
||||||
|
return _forward_greedy_cpu(model, moves, states, feats, seen_mask, actions=actions)
|
||||||
|
else:
|
||||||
|
return _forward_fallback(model, moves, states, tokvecs, backprop_tok2vec,
|
||||||
|
feats, backprop_feats, seen_mask, is_train, actions=actions,
|
||||||
|
max_moves=inputs.max_moves)
|
||||||
|
|
||||||
|
|
||||||
|
def _forward_greedy_cpu(model: Model, TransitionSystem moves, states: List[StateClass], np.ndarray feats,
|
||||||
|
np.ndarray[np.npy_bool, ndim=1] seen_mask, actions: Optional[List[Ints1d]]=None):
|
||||||
|
cdef vector[StateC*] c_states
|
||||||
|
cdef StateClass state
|
||||||
|
for state in states:
|
||||||
|
if not state.is_final():
|
||||||
|
c_states.push_back(state.c)
|
||||||
|
weights = _get_c_weights(model, <float*>feats.data, seen_mask)
|
||||||
|
# Precomputed features have rows for each token, plus one for padding.
|
||||||
|
cdef int n_tokens = feats.shape[0] - 1
|
||||||
|
sizes = _get_c_sizes(model, c_states.size(), n_tokens)
|
||||||
|
cdef CBlas cblas = model.ops.cblas()
|
||||||
|
scores = _parse_batch(cblas, moves, &c_states[0], weights, sizes, actions=actions)
|
||||||
|
|
||||||
|
def backprop(dY):
|
||||||
|
raise ValueError(Errors.E4004)
|
||||||
|
|
||||||
|
return (states, scores), backprop
|
||||||
|
|
||||||
|
cdef list _parse_batch(CBlas cblas, TransitionSystem moves, StateC** states,
|
||||||
|
WeightsC weights, SizesC sizes, actions: Optional[List[Ints1d]]=None):
|
||||||
|
cdef int i, j
|
||||||
|
cdef vector[StateC *] unfinished
|
||||||
|
cdef ActivationsC activations = _alloc_activations(sizes)
|
||||||
|
cdef np.ndarray step_scores
|
||||||
|
cdef np.ndarray step_actions
|
||||||
|
|
||||||
|
scores = []
|
||||||
|
while sizes.states >= 1:
|
||||||
|
step_scores = numpy.empty((sizes.states, sizes.classes), dtype="f")
|
||||||
|
step_actions = actions[0] if actions is not None else None
|
||||||
|
with nogil:
|
||||||
|
_predict_states(cblas, &activations, <float*>step_scores.data, states, &weights, sizes)
|
||||||
|
if actions is None:
|
||||||
|
# Validate actions, argmax, take action.
|
||||||
|
c_transition_batch(moves, states, <const float*>step_scores.data, sizes.classes,
|
||||||
|
sizes.states)
|
||||||
|
else:
|
||||||
|
c_apply_actions(moves, states, <const int*>step_actions.data, sizes.states)
|
||||||
|
for i in range(sizes.states):
|
||||||
|
if not states[i].is_final():
|
||||||
|
unfinished.push_back(states[i])
|
||||||
|
for i in range(unfinished.size()):
|
||||||
|
states[i] = unfinished[i]
|
||||||
|
sizes.states = unfinished.size()
|
||||||
|
scores.append(step_scores)
|
||||||
|
unfinished.clear()
|
||||||
|
actions = actions[1:] if actions is not None else None
|
||||||
|
_free_activations(&activations)
|
||||||
|
|
||||||
|
return scores
|
||||||
|
|
||||||
|
|
||||||
|
def _forward_fallback(
|
||||||
|
model: Model,
|
||||||
|
moves: TransitionSystem,
|
||||||
|
states: List[StateClass],
|
||||||
|
tokvecs, backprop_tok2vec,
|
||||||
|
feats,
|
||||||
|
backprop_feats,
|
||||||
|
seen_mask,
|
||||||
|
is_train: bool,
|
||||||
|
actions: Optional[List[Ints1d]]=None,
|
||||||
|
max_moves: int=0):
|
||||||
|
nF = model.get_dim("nF")
|
||||||
|
output = model.get_ref("output")
|
||||||
|
hidden_b = model.get_param("hidden_b")
|
||||||
|
nH = model.get_dim("nH")
|
||||||
|
nP = model.get_dim("nP")
|
||||||
|
|
||||||
|
beam_width = model.attrs["beam_width"]
|
||||||
|
beam_density = model.attrs["beam_density"]
|
||||||
|
|
||||||
|
ops = model.ops
|
||||||
|
|
||||||
|
all_ids = []
|
||||||
|
all_which = []
|
||||||
|
all_statevecs = []
|
||||||
|
all_scores = []
|
||||||
|
if beam_width == 1:
|
||||||
|
batch = GreedyBatch(moves, states, None)
|
||||||
|
else:
|
||||||
|
batch = _beam_utils.BeamBatch(
|
||||||
|
moves, states, None, width=beam_width, density=beam_density
|
||||||
|
)
|
||||||
|
arange = ops.xp.arange(nF)
|
||||||
|
n_moves = 0
|
||||||
|
while not batch.is_done:
|
||||||
|
ids = numpy.zeros((len(batch.get_unfinished_states()), nF), dtype="i")
|
||||||
|
for i, state in enumerate(batch.get_unfinished_states()):
|
||||||
|
state.set_context_tokens(ids, i, nF)
|
||||||
|
# Sum the state features, add the bias and apply the activation (maxout)
|
||||||
|
# to create the state vectors.
|
||||||
|
preacts2f = feats[ids, arange].sum(axis=1) # type: ignore
|
||||||
|
preacts2f += hidden_b
|
||||||
|
preacts = ops.reshape3f(preacts2f, preacts2f.shape[0], nH, nP)
|
||||||
|
assert preacts.shape[0] == len(batch.get_unfinished_states()), preacts.shape
|
||||||
|
statevecs, which = ops.maxout(preacts)
|
||||||
|
# We don't use output's backprop, since we want to backprop for
|
||||||
|
# all states at once, rather than a single state.
|
||||||
|
scores = output.predict(statevecs)
|
||||||
|
scores[:, seen_mask] = ops.xp.nanmin(scores)
|
||||||
|
# Transition the states, filtering out any that are finished.
|
||||||
|
cpu_scores = ops.to_numpy(scores)
|
||||||
|
if actions is None:
|
||||||
|
batch.advance(cpu_scores)
|
||||||
|
else:
|
||||||
|
batch.advance_with_actions(actions[0])
|
||||||
|
actions = actions[1:]
|
||||||
|
all_scores.append(scores)
|
||||||
|
if is_train:
|
||||||
|
# Remember intermediate results for the backprop.
|
||||||
|
all_ids.append(ids)
|
||||||
|
all_statevecs.append(statevecs)
|
||||||
|
all_which.append(which)
|
||||||
|
if n_moves >= max_moves >= 1:
|
||||||
|
break
|
||||||
|
n_moves += 1
|
||||||
|
|
||||||
|
def backprop_parser(d_states_d_scores):
|
||||||
|
ids = ops.xp.vstack(all_ids)
|
||||||
|
which = ops.xp.vstack(all_which)
|
||||||
|
statevecs = ops.xp.vstack(all_statevecs)
|
||||||
|
_, d_scores = d_states_d_scores
|
||||||
|
if model.attrs.get("unseen_classes"):
|
||||||
|
# If we have a negative gradient (i.e. the probability should
|
||||||
|
# increase) on any classes we filtered out as unseen, mark
|
||||||
|
# them as seen.
|
||||||
|
for clas in set(model.attrs["unseen_classes"]):
|
||||||
|
if (d_scores[:, clas] < 0).any():
|
||||||
|
model.attrs["unseen_classes"].remove(clas)
|
||||||
|
d_scores *= seen_mask == False
|
||||||
|
# Calculate the gradients for the parameters of the output layer.
|
||||||
|
# The weight gemm is (nS, nO) @ (nS, nH).T
|
||||||
|
output.inc_grad("b", d_scores.sum(axis=0))
|
||||||
|
output.inc_grad("W", ops.gemm(d_scores, statevecs, trans1=True))
|
||||||
|
# Now calculate d_statevecs, by backproping through the output linear layer.
|
||||||
|
# This gemm is (nS, nO) @ (nO, nH)
|
||||||
|
output_W = output.get_param("W")
|
||||||
|
d_statevecs = ops.gemm(d_scores, output_W)
|
||||||
|
# Backprop through the maxout activation
|
||||||
|
d_preacts = ops.backprop_maxout(d_statevecs, which, nP)
|
||||||
|
d_preacts2f = ops.reshape2f(d_preacts, d_preacts.shape[0], nH * nP)
|
||||||
|
model.inc_grad("hidden_b", d_preacts2f.sum(axis=0))
|
||||||
|
# We don't need to backprop the summation, because we pass back the IDs instead
|
||||||
|
d_state_features = backprop_feats((d_preacts2f, ids))
|
||||||
|
d_tokvecs = ops.alloc2f(tokvecs.shape[0], tokvecs.shape[1])
|
||||||
|
ops.scatter_add(d_tokvecs, ids, d_state_features)
|
||||||
|
model.inc_grad("hidden_pad", d_tokvecs[-1])
|
||||||
|
return (backprop_tok2vec(d_tokvecs[:-1]), None)
|
||||||
|
|
||||||
|
return (list(batch), all_scores), backprop_parser
|
||||||
|
|
||||||
|
|
||||||
|
def _get_seen_mask(model: Model) -> numpy.array[bool, 1]:
|
||||||
|
mask = model.ops.xp.zeros(model.get_dim("nO"), dtype="bool")
|
||||||
|
for class_ in model.attrs.get("unseen_classes", set()):
|
||||||
|
mask[class_] = True
|
||||||
|
return mask
|
||||||
|
|
||||||
|
|
||||||
|
def _forward_precomputable_affine(model, X: Floats2d, is_train: bool):
|
||||||
|
W: Floats2d = model.get_param("hidden_W")
|
||||||
|
nF = model.get_dim("nF")
|
||||||
|
nH = model.get_dim("nH")
|
||||||
|
nP = model.get_dim("nP")
|
||||||
|
nI = model.get_dim("nI")
|
||||||
|
# The weights start out (nH * nP, nF * nI). Transpose and reshape to (nF * nH *nP, nI)
|
||||||
|
W3f = model.ops.reshape3f(W, nH * nP, nF, nI)
|
||||||
|
W3f = W3f.transpose((1, 0, 2))
|
||||||
|
W2f = model.ops.reshape2f(W3f, nF * nH * nP, nI)
|
||||||
|
assert X.shape == (X.shape[0], nI), X.shape
|
||||||
|
Yf_ = model.ops.gemm(X, W2f, trans2=True)
|
||||||
|
Yf = model.ops.reshape3f(Yf_, Yf_.shape[0], nF, nH * nP)
|
||||||
|
|
||||||
|
def backward(dY_ids: Tuple[Floats3d, Ints2d]):
|
||||||
|
# This backprop is particularly tricky, because we get back a different
|
||||||
|
# thing from what we put out. We put out an array of shape:
|
||||||
|
# (nB, nF, nH, nP), and get back:
|
||||||
|
# (nB, nH, nP) and ids (nB, nF)
|
||||||
|
# The ids tell us the values of nF, so we would have:
|
||||||
|
#
|
||||||
|
# dYf = zeros((nB, nF, nH, nP))
|
||||||
|
# for b in range(nB):
|
||||||
|
# for f in range(nF):
|
||||||
|
# dYf[b, ids[b, f]] += dY[b]
|
||||||
|
#
|
||||||
|
# However, we avoid building that array for efficiency -- and just pass
|
||||||
|
# in the indices.
|
||||||
|
dY, ids = dY_ids
|
||||||
|
dXf = model.ops.gemm(dY, W)
|
||||||
|
Xf = X[ids].reshape((ids.shape[0], -1))
|
||||||
|
dW = model.ops.gemm(dY, Xf, trans1=True)
|
||||||
|
model.inc_grad("hidden_W", dW)
|
||||||
|
return model.ops.reshape3f(dXf, dXf.shape[0], nF, nI)
|
||||||
|
|
||||||
|
return Yf, backward
|
||||||
|
|
||||||
|
|
||||||
|
def _infer_nO(Y: Optional[Tuple[List[State], List[Floats2d]]]) -> Optional[int]:
|
||||||
|
if Y is None:
|
||||||
|
return None
|
||||||
|
_, scores = Y
|
||||||
|
if len(scores) == 0:
|
||||||
|
return None
|
||||||
|
assert scores[0].shape[0] >= 1
|
||||||
|
assert len(scores[0].shape) == 2
|
||||||
|
return scores[0].shape[1]
|
||||||
|
|
||||||
|
|
||||||
|
def _lsuv_init(model: Model):
|
||||||
|
"""This is like the 'layer sequential unit variance', but instead
|
||||||
|
of taking the actual inputs, we randomly generate whitened data.
|
||||||
|
|
||||||
|
Why's this all so complicated? We have a huge number of inputs,
|
||||||
|
and the maxout unit makes guessing the dynamics tricky. Instead
|
||||||
|
we set the maxout weights to values that empirically result in
|
||||||
|
whitened outputs given whitened inputs.
|
||||||
|
"""
|
||||||
|
W = model.maybe_get_param("hidden_W")
|
||||||
|
if W is not None and W.any():
|
||||||
|
return
|
||||||
|
|
||||||
|
nF = model.get_dim("nF")
|
||||||
|
nH = model.get_dim("nH")
|
||||||
|
nP = model.get_dim("nP")
|
||||||
|
nI = model.get_dim("nI")
|
||||||
|
W = model.ops.alloc4f(nF, nH, nP, nI)
|
||||||
|
b = model.ops.alloc2f(nH, nP)
|
||||||
|
pad = model.ops.alloc4f(1, nF, nH, nP)
|
||||||
|
|
||||||
|
ops = model.ops
|
||||||
|
W = normal_init(ops, W.shape, mean=float(ops.xp.sqrt(1.0 / nF * nI)))
|
||||||
|
pad = normal_init(ops, pad.shape, mean=1.0)
|
||||||
|
model.set_param("W", W)
|
||||||
|
model.set_param("b", b)
|
||||||
|
model.set_param("pad", pad)
|
||||||
|
|
||||||
|
ids = ops.alloc_f((5000, nF), dtype="f")
|
||||||
|
ids += ops.xp.random.uniform(0, 1000, ids.shape)
|
||||||
|
ids = ops.asarray(ids, dtype="i")
|
||||||
|
tokvecs = ops.alloc_f((5000, nI), dtype="f")
|
||||||
|
tokvecs += ops.xp.random.normal(loc=0.0, scale=1.0, size=tokvecs.size).reshape(
|
||||||
|
tokvecs.shape
|
||||||
|
)
|
||||||
|
|
||||||
|
def predict(ids, tokvecs):
|
||||||
|
# nS ids. nW tokvecs. Exclude the padding array.
|
||||||
|
hiddens, _ = _forward_precomputable_affine(model, tokvecs[:-1], False)
|
||||||
|
vectors = model.ops.alloc2f(ids.shape[0], nH * nP)
|
||||||
|
# need nS vectors
|
||||||
|
hiddens = hiddens.reshape((hiddens.shape[0] * nF, nH * nP))
|
||||||
|
model.ops.scatter_add(vectors, ids.flatten(), hiddens)
|
||||||
|
vectors3f = model.ops.reshape3f(vectors, vectors.shape[0], nH, nP)
|
||||||
|
vectors3f += b
|
||||||
|
return model.ops.maxout(vectors3f)[0]
|
||||||
|
|
||||||
|
tol_var = 0.01
|
||||||
|
tol_mean = 0.01
|
||||||
|
t_max = 10
|
||||||
|
W = cast(Floats4d, model.get_param("hidden_W").copy())
|
||||||
|
b = cast(Floats2d, model.get_param("hidden_b").copy())
|
||||||
|
for t_i in range(t_max):
|
||||||
|
acts1 = predict(ids, tokvecs)
|
||||||
|
var = model.ops.xp.var(acts1)
|
||||||
|
mean = model.ops.xp.mean(acts1)
|
||||||
|
if abs(var - 1.0) >= tol_var:
|
||||||
|
W /= model.ops.xp.sqrt(var)
|
||||||
|
model.set_param("hidden_W", W)
|
||||||
|
elif abs(mean) >= tol_mean:
|
||||||
|
b -= mean
|
||||||
|
model.set_param("hidden_b", b)
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
cdef WeightsC _get_c_weights(model, const float* feats, np.ndarray[np.npy_bool, ndim=1] seen_mask) except *:
|
||||||
|
output = model.get_ref("output")
|
||||||
|
cdef np.ndarray hidden_b = model.get_param("hidden_b")
|
||||||
|
cdef np.ndarray output_W = output.get_param("W")
|
||||||
|
cdef np.ndarray output_b = output.get_param("b")
|
||||||
|
|
||||||
|
cdef WeightsC weights
|
||||||
|
weights.feat_weights = feats
|
||||||
|
weights.feat_bias = <const float*>hidden_b.data
|
||||||
|
weights.hidden_weights = <const float *> output_W.data
|
||||||
|
weights.hidden_bias = <const float *> output_b.data
|
||||||
|
weights.seen_mask = <const int8_t*> seen_mask.data
|
||||||
|
|
||||||
|
return weights
|
||||||
|
|
||||||
|
|
||||||
|
cdef SizesC _get_c_sizes(model, int batch_size, int tokens) except *:
|
||||||
|
cdef SizesC sizes
|
||||||
|
sizes.states = batch_size
|
||||||
|
sizes.classes = model.get_dim("nO")
|
||||||
|
sizes.hiddens = model.get_dim("nH")
|
||||||
|
sizes.pieces = model.get_dim("nP")
|
||||||
|
sizes.feats = model.get_dim("nF")
|
||||||
|
sizes.embed_width = model.get_dim("nI")
|
||||||
|
sizes.tokens = tokens
|
||||||
|
return sizes
|
||||||
|
|
||||||
|
|
||||||
|
cdef ActivationsC _alloc_activations(SizesC n) nogil:
|
||||||
|
cdef ActivationsC A
|
||||||
|
memset(&A, 0, sizeof(A))
|
||||||
|
_resize_activations(&A, n)
|
||||||
|
return A
|
||||||
|
|
||||||
|
|
||||||
|
cdef void _free_activations(const ActivationsC* A) nogil:
|
||||||
|
free(A.token_ids)
|
||||||
|
free(A.unmaxed)
|
||||||
|
free(A.hiddens)
|
||||||
|
free(A.is_valid)
|
||||||
|
|
||||||
|
|
||||||
|
cdef void _resize_activations(ActivationsC* A, SizesC n) nogil:
|
||||||
|
if n.states <= A._max_size:
|
||||||
|
A._curr_size = n.states
|
||||||
|
return
|
||||||
|
if A._max_size == 0:
|
||||||
|
A.token_ids = <int*>calloc(n.states * n.feats, sizeof(A.token_ids[0]))
|
||||||
|
A.unmaxed = <float*>calloc(n.states * n.hiddens * n.pieces, sizeof(A.unmaxed[0]))
|
||||||
|
A.hiddens = <float*>calloc(n.states * n.hiddens, sizeof(A.hiddens[0]))
|
||||||
|
A.is_valid = <int*>calloc(n.states * n.classes, sizeof(A.is_valid[0]))
|
||||||
|
A._max_size = n.states
|
||||||
|
else:
|
||||||
|
A.token_ids = <int*>realloc(A.token_ids,
|
||||||
|
n.states * n.feats * sizeof(A.token_ids[0]))
|
||||||
|
A.unmaxed = <float*>realloc(A.unmaxed,
|
||||||
|
n.states * n.hiddens * n.pieces * sizeof(A.unmaxed[0]))
|
||||||
|
A.hiddens = <float*>realloc(A.hiddens,
|
||||||
|
n.states * n.hiddens * sizeof(A.hiddens[0]))
|
||||||
|
A.is_valid = <int*>realloc(A.is_valid,
|
||||||
|
n.states * n.classes * sizeof(A.is_valid[0]))
|
||||||
|
A._max_size = n.states
|
||||||
|
A._curr_size = n.states
|
||||||
|
|
||||||
|
|
||||||
|
cdef void _predict_states(CBlas cblas, ActivationsC* A, float* scores, StateC** states, const WeightsC* W, SizesC n) nogil:
|
||||||
|
_resize_activations(A, n)
|
||||||
|
for i in range(n.states):
|
||||||
|
states[i].set_context_tokens(&A.token_ids[i*n.feats], n.feats)
|
||||||
|
memset(A.unmaxed, 0, n.states * n.hiddens * n.pieces * sizeof(float))
|
||||||
|
_sum_state_features(cblas, A.unmaxed, W.feat_weights, A.token_ids, n)
|
||||||
|
for i in range(n.states):
|
||||||
|
saxpy(cblas)(n.hiddens * n.pieces, 1., W.feat_bias, 1, &A.unmaxed[i*n.hiddens*n.pieces], 1)
|
||||||
|
for j in range(n.hiddens):
|
||||||
|
index = i * n.hiddens * n.pieces + j * n.pieces
|
||||||
|
which = arg_max(&A.unmaxed[index], n.pieces)
|
||||||
|
A.hiddens[i*n.hiddens + j] = A.unmaxed[index + which]
|
||||||
|
if W.hidden_weights == NULL:
|
||||||
|
memcpy(scores, A.hiddens, n.states * n.classes * sizeof(float))
|
||||||
|
else:
|
||||||
|
# Compute hidden-to-output
|
||||||
|
sgemm(cblas)(False, True, n.states, n.classes, n.hiddens,
|
||||||
|
1.0, <const float *>A.hiddens, n.hiddens,
|
||||||
|
<const float *>W.hidden_weights, n.hiddens,
|
||||||
|
0.0, scores, n.classes)
|
||||||
|
# Add bias
|
||||||
|
for i in range(n.states):
|
||||||
|
saxpy(cblas)(n.classes, 1., W.hidden_bias, 1, &scores[i*n.classes], 1)
|
||||||
|
# Set unseen classes to minimum value
|
||||||
|
i = 0
|
||||||
|
min_ = scores[0]
|
||||||
|
for i in range(1, n.states * n.classes):
|
||||||
|
if scores[i] < min_:
|
||||||
|
min_ = scores[i]
|
||||||
|
for i in range(n.states):
|
||||||
|
for j in range(n.classes):
|
||||||
|
if W.seen_mask[j]:
|
||||||
|
scores[i*n.classes+j] = min_
|
||||||
|
|
||||||
|
|
||||||
|
cdef void _sum_state_features(CBlas cblas, float* output,
|
||||||
|
const float* cached, const int* token_ids, SizesC n) nogil:
|
||||||
|
cdef int idx, b, f, i
|
||||||
|
cdef const float* feature
|
||||||
|
cdef int B = n.states
|
||||||
|
cdef int O = n.hiddens * n.pieces
|
||||||
|
cdef int F = n.feats
|
||||||
|
cdef int T = n.tokens
|
||||||
|
padding = cached + (T * F * O)
|
||||||
|
cdef int id_stride = F*O
|
||||||
|
cdef float one = 1.
|
||||||
|
for b in range(B):
|
||||||
|
for f in range(F):
|
||||||
|
if token_ids[f] < 0:
|
||||||
|
feature = &padding[f*O]
|
||||||
|
else:
|
||||||
|
idx = token_ids[f] * id_stride + f*O
|
||||||
|
feature = &cached[idx]
|
||||||
|
saxpy(cblas)(O, one, <const float*>feature, 1, &output[b*O], 1)
|
||||||
|
token_ids += F
|
||||||
|
|
|
@ -7,6 +7,7 @@ from cpython.ref cimport PyObject, Py_XDECREF
|
||||||
from ...typedefs cimport hash_t, class_t
|
from ...typedefs cimport hash_t, class_t
|
||||||
from .transition_system cimport TransitionSystem, Transition
|
from .transition_system cimport TransitionSystem, Transition
|
||||||
from ...errors import Errors
|
from ...errors import Errors
|
||||||
|
from .batch cimport Batch
|
||||||
from .search cimport Beam, MaxViolation
|
from .search cimport Beam, MaxViolation
|
||||||
from .search import MaxViolation
|
from .search import MaxViolation
|
||||||
from .stateclass cimport StateC, StateClass
|
from .stateclass cimport StateC, StateClass
|
||||||
|
@ -26,7 +27,7 @@ cdef int check_final_state(void* _state, void* extra_args) except -1:
|
||||||
return state.is_final()
|
return state.is_final()
|
||||||
|
|
||||||
|
|
||||||
cdef class BeamBatch(object):
|
cdef class BeamBatch(Batch):
|
||||||
cdef public TransitionSystem moves
|
cdef public TransitionSystem moves
|
||||||
cdef public object states
|
cdef public object states
|
||||||
cdef public object docs
|
cdef public object docs
|
||||||
|
|
2
spacy/pipeline/_parser_internals/_parser_utils.pxd
Normal file
2
spacy/pipeline/_parser_internals/_parser_utils.pxd
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
cdef int arg_max(const float* scores, const int n_classes) nogil
|
||||||
|
cdef int arg_max_if_valid(const float* scores, const int* is_valid, int n) nogil
|
22
spacy/pipeline/_parser_internals/_parser_utils.pyx
Normal file
22
spacy/pipeline/_parser_internals/_parser_utils.pyx
Normal file
|
@ -0,0 +1,22 @@
|
||||||
|
# cython: infer_types=True
|
||||||
|
|
||||||
|
cdef inline int arg_max(const float* scores, const int n_classes) nogil:
|
||||||
|
if n_classes == 2:
|
||||||
|
return 0 if scores[0] > scores[1] else 1
|
||||||
|
cdef int i
|
||||||
|
cdef int best = 0
|
||||||
|
cdef float mode = scores[0]
|
||||||
|
for i in range(1, n_classes):
|
||||||
|
if scores[i] > mode:
|
||||||
|
mode = scores[i]
|
||||||
|
best = i
|
||||||
|
return best
|
||||||
|
|
||||||
|
|
||||||
|
cdef inline int arg_max_if_valid(const float* scores, const int* is_valid, int n) nogil:
|
||||||
|
cdef int best = -1
|
||||||
|
for i in range(n):
|
||||||
|
if is_valid[i] >= 1:
|
||||||
|
if best == -1 or scores[i] > scores[best]:
|
||||||
|
best = i
|
||||||
|
return best
|
|
@ -6,7 +6,6 @@ cimport libcpp
|
||||||
from libcpp.unordered_map cimport unordered_map
|
from libcpp.unordered_map cimport unordered_map
|
||||||
from libcpp.vector cimport vector
|
from libcpp.vector cimport vector
|
||||||
from libcpp.set cimport set
|
from libcpp.set cimport set
|
||||||
from cpython.exc cimport PyErr_CheckSignals, PyErr_SetFromErrno
|
|
||||||
from murmurhash.mrmr cimport hash64
|
from murmurhash.mrmr cimport hash64
|
||||||
|
|
||||||
from ...vocab cimport EMPTY_LEXEME
|
from ...vocab cimport EMPTY_LEXEME
|
||||||
|
@ -26,7 +25,7 @@ cdef struct ArcC:
|
||||||
|
|
||||||
|
|
||||||
cdef cppclass StateC:
|
cdef cppclass StateC:
|
||||||
int* _heads
|
vector[int] _heads
|
||||||
const TokenC* _sent
|
const TokenC* _sent
|
||||||
vector[int] _stack
|
vector[int] _stack
|
||||||
vector[int] _rebuffer
|
vector[int] _rebuffer
|
||||||
|
@ -34,31 +33,34 @@ cdef cppclass StateC:
|
||||||
unordered_map[int, vector[ArcC]] _left_arcs
|
unordered_map[int, vector[ArcC]] _left_arcs
|
||||||
unordered_map[int, vector[ArcC]] _right_arcs
|
unordered_map[int, vector[ArcC]] _right_arcs
|
||||||
vector[libcpp.bool] _unshiftable
|
vector[libcpp.bool] _unshiftable
|
||||||
|
vector[int] history
|
||||||
set[int] _sent_starts
|
set[int] _sent_starts
|
||||||
TokenC _empty_token
|
TokenC _empty_token
|
||||||
int length
|
int length
|
||||||
int offset
|
int offset
|
||||||
int _b_i
|
int _b_i
|
||||||
|
|
||||||
__init__(const TokenC* sent, int length) nogil:
|
__init__(const TokenC* sent, int length) nogil except +:
|
||||||
|
this._heads.resize(length, -1)
|
||||||
|
this._unshiftable.resize(length, False)
|
||||||
|
|
||||||
|
# Reserve memory ahead of time to minimize allocations during parsing.
|
||||||
|
# The initial capacity set here ideally reflects the expected average-case/majority usage.
|
||||||
|
cdef int init_capacity = 32
|
||||||
|
this._stack.reserve(init_capacity)
|
||||||
|
this._rebuffer.reserve(init_capacity)
|
||||||
|
this._ents.reserve(init_capacity)
|
||||||
|
this._left_arcs.reserve(init_capacity)
|
||||||
|
this._right_arcs.reserve(init_capacity)
|
||||||
|
this.history.reserve(init_capacity)
|
||||||
|
|
||||||
this._sent = sent
|
this._sent = sent
|
||||||
this._heads = <int*>calloc(length, sizeof(int))
|
|
||||||
if not (this._sent and this._heads):
|
|
||||||
with gil:
|
|
||||||
PyErr_SetFromErrno(MemoryError)
|
|
||||||
PyErr_CheckSignals()
|
|
||||||
this.offset = 0
|
this.offset = 0
|
||||||
this.length = length
|
this.length = length
|
||||||
this._b_i = 0
|
this._b_i = 0
|
||||||
for i in range(length):
|
|
||||||
this._heads[i] = -1
|
|
||||||
this._unshiftable.push_back(0)
|
|
||||||
memset(&this._empty_token, 0, sizeof(TokenC))
|
memset(&this._empty_token, 0, sizeof(TokenC))
|
||||||
this._empty_token.lex = &EMPTY_LEXEME
|
this._empty_token.lex = &EMPTY_LEXEME
|
||||||
|
|
||||||
__dealloc__():
|
|
||||||
free(this._heads)
|
|
||||||
|
|
||||||
void set_context_tokens(int* ids, int n) nogil:
|
void set_context_tokens(int* ids, int n) nogil:
|
||||||
cdef int i, j
|
cdef int i, j
|
||||||
if n == 1:
|
if n == 1:
|
||||||
|
@ -131,19 +133,20 @@ cdef cppclass StateC:
|
||||||
ids[i] = -1
|
ids[i] = -1
|
||||||
|
|
||||||
int S(int i) nogil const:
|
int S(int i) nogil const:
|
||||||
if i >= this._stack.size():
|
cdef int stack_size = this._stack.size()
|
||||||
|
if i >= stack_size or i < 0:
|
||||||
return -1
|
return -1
|
||||||
elif i < 0:
|
else:
|
||||||
return -1
|
return this._stack[stack_size - (i+1)]
|
||||||
return this._stack.at(this._stack.size() - (i+1))
|
|
||||||
|
|
||||||
int B(int i) nogil const:
|
int B(int i) nogil const:
|
||||||
|
cdef int buf_size = this._rebuffer.size()
|
||||||
if i < 0:
|
if i < 0:
|
||||||
return -1
|
return -1
|
||||||
elif i < this._rebuffer.size():
|
elif i < buf_size:
|
||||||
return this._rebuffer.at(this._rebuffer.size() - (i+1))
|
return this._rebuffer[buf_size - (i+1)]
|
||||||
else:
|
else:
|
||||||
b_i = this._b_i + (i - this._rebuffer.size())
|
b_i = this._b_i + (i - buf_size)
|
||||||
if b_i >= this.length:
|
if b_i >= this.length:
|
||||||
return -1
|
return -1
|
||||||
else:
|
else:
|
||||||
|
@ -242,7 +245,7 @@ cdef cppclass StateC:
|
||||||
return 0
|
return 0
|
||||||
elif this._sent[word].sent_start == 1:
|
elif this._sent[word].sent_start == 1:
|
||||||
return 1
|
return 1
|
||||||
elif this._sent_starts.count(word) >= 1:
|
elif this._sent_starts.const_find(word) != this._sent_starts.const_end():
|
||||||
return 1
|
return 1
|
||||||
else:
|
else:
|
||||||
return 0
|
return 0
|
||||||
|
@ -327,7 +330,7 @@ cdef cppclass StateC:
|
||||||
if item >= this._unshiftable.size():
|
if item >= this._unshiftable.size():
|
||||||
return 0
|
return 0
|
||||||
else:
|
else:
|
||||||
return this._unshiftable.at(item)
|
return this._unshiftable[item]
|
||||||
|
|
||||||
void set_reshiftable(int item) nogil:
|
void set_reshiftable(int item) nogil:
|
||||||
if item < this._unshiftable.size():
|
if item < this._unshiftable.size():
|
||||||
|
@ -347,6 +350,9 @@ cdef cppclass StateC:
|
||||||
this._heads[child] = head
|
this._heads[child] = head
|
||||||
|
|
||||||
void map_del_arc(unordered_map[int, vector[ArcC]]* heads_arcs, int h_i, int c_i) nogil:
|
void map_del_arc(unordered_map[int, vector[ArcC]]* heads_arcs, int h_i, int c_i) nogil:
|
||||||
|
cdef vector[ArcC]* arcs
|
||||||
|
cdef ArcC* arc
|
||||||
|
|
||||||
arcs_it = heads_arcs.find(h_i)
|
arcs_it = heads_arcs.find(h_i)
|
||||||
if arcs_it == heads_arcs.end():
|
if arcs_it == heads_arcs.end():
|
||||||
return
|
return
|
||||||
|
@ -355,12 +361,12 @@ cdef cppclass StateC:
|
||||||
if arcs.size() == 0:
|
if arcs.size() == 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
arc = arcs.back()
|
arc = &arcs.back()
|
||||||
if arc.head == h_i and arc.child == c_i:
|
if arc.head == h_i and arc.child == c_i:
|
||||||
arcs.pop_back()
|
arcs.pop_back()
|
||||||
else:
|
else:
|
||||||
for i in range(arcs.size()-1):
|
for i in range(arcs.size()-1):
|
||||||
arc = arcs.at(i)
|
arc = &deref(arcs)[i]
|
||||||
if arc.head == h_i and arc.child == c_i:
|
if arc.head == h_i and arc.child == c_i:
|
||||||
arc.head = -1
|
arc.head = -1
|
||||||
arc.child = -1
|
arc.child = -1
|
||||||
|
@ -400,10 +406,11 @@ cdef cppclass StateC:
|
||||||
this._rebuffer = src._rebuffer
|
this._rebuffer = src._rebuffer
|
||||||
this._sent_starts = src._sent_starts
|
this._sent_starts = src._sent_starts
|
||||||
this._unshiftable = src._unshiftable
|
this._unshiftable = src._unshiftable
|
||||||
memcpy(this._heads, src._heads, this.length * sizeof(this._heads[0]))
|
this._heads = src._heads
|
||||||
this._ents = src._ents
|
this._ents = src._ents
|
||||||
this._left_arcs = src._left_arcs
|
this._left_arcs = src._left_arcs
|
||||||
this._right_arcs = src._right_arcs
|
this._right_arcs = src._right_arcs
|
||||||
this._b_i = src._b_i
|
this._b_i = src._b_i
|
||||||
this.offset = src.offset
|
this.offset = src.offset
|
||||||
this._empty_token = src._empty_token
|
this._empty_token = src._empty_token
|
||||||
|
this.history = src.history
|
||||||
|
|
|
@ -773,6 +773,8 @@ cdef class ArcEager(TransitionSystem):
|
||||||
return list(arcs)
|
return list(arcs)
|
||||||
|
|
||||||
def has_gold(self, Example eg, start=0, end=None):
|
def has_gold(self, Example eg, start=0, end=None):
|
||||||
|
if end is not None and end < 0:
|
||||||
|
end = None
|
||||||
for word in eg.y[start:end]:
|
for word in eg.y[start:end]:
|
||||||
if word.dep != 0:
|
if word.dep != 0:
|
||||||
return True
|
return True
|
||||||
|
@ -858,6 +860,7 @@ cdef class ArcEager(TransitionSystem):
|
||||||
state.print_state()
|
state.print_state()
|
||||||
)))
|
)))
|
||||||
action.do(state.c, action.label)
|
action.do(state.c, action.label)
|
||||||
|
state.c.history.push_back(i)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
failed = False
|
failed = False
|
||||||
|
|
2
spacy/pipeline/_parser_internals/batch.pxd
Normal file
2
spacy/pipeline/_parser_internals/batch.pxd
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
cdef class Batch:
|
||||||
|
pass
|
52
spacy/pipeline/_parser_internals/batch.pyx
Normal file
52
spacy/pipeline/_parser_internals/batch.pyx
Normal file
|
@ -0,0 +1,52 @@
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
TransitionSystem = Any # TODO
|
||||||
|
|
||||||
|
cdef class Batch:
|
||||||
|
def advance(self, scores):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def get_states(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_done(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def get_unfinished_states(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def __getitem__(self, i):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class GreedyBatch(Batch):
|
||||||
|
def __init__(self, moves: TransitionSystem, states, golds):
|
||||||
|
self._moves = moves
|
||||||
|
self._states = states
|
||||||
|
self._next_states = [s for s in states if not s.is_final()]
|
||||||
|
|
||||||
|
def advance(self, scores):
|
||||||
|
self._next_states = self._moves.transition_states(self._next_states, scores)
|
||||||
|
|
||||||
|
def advance_with_actions(self, actions):
|
||||||
|
self._next_states = self._moves.apply_actions(self._next_states, actions)
|
||||||
|
|
||||||
|
def get_states(self):
|
||||||
|
return self._states
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_done(self):
|
||||||
|
return all(s.is_final() for s in self._states)
|
||||||
|
|
||||||
|
def get_unfinished_states(self):
|
||||||
|
return [st for st in self._states if not st.is_final()]
|
||||||
|
|
||||||
|
def __getitem__(self, i):
|
||||||
|
return self._states[i]
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self._states)
|
|
@ -156,7 +156,7 @@ cdef class BiluoPushDown(TransitionSystem):
|
||||||
if token.ent_type:
|
if token.ent_type:
|
||||||
labels.add(token.ent_type_)
|
labels.add(token.ent_type_)
|
||||||
return labels
|
return labels
|
||||||
|
|
||||||
def move_name(self, int move, attr_t label):
|
def move_name(self, int move, attr_t label):
|
||||||
if move == OUT:
|
if move == OUT:
|
||||||
return 'O'
|
return 'O'
|
||||||
|
@ -306,6 +306,8 @@ cdef class BiluoPushDown(TransitionSystem):
|
||||||
for span in eg.y.spans.get(neg_key, []):
|
for span in eg.y.spans.get(neg_key, []):
|
||||||
if span.start >= start and span.end <= end:
|
if span.start >= start and span.end <= end:
|
||||||
return True
|
return True
|
||||||
|
if end is not None and end < 0:
|
||||||
|
end = None
|
||||||
for word in eg.y[start:end]:
|
for word in eg.y[start:end]:
|
||||||
if word.ent_iob != 0:
|
if word.ent_iob != 0:
|
||||||
return True
|
return True
|
||||||
|
@ -646,7 +648,7 @@ cdef class Unit:
|
||||||
cost += 1
|
cost += 1
|
||||||
break
|
break
|
||||||
return cost
|
return cost
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
cdef class Out:
|
cdef class Out:
|
||||||
|
|
|
@ -20,6 +20,10 @@ cdef class StateClass:
|
||||||
if self._borrowed != 1:
|
if self._borrowed != 1:
|
||||||
del self.c
|
del self.c
|
||||||
|
|
||||||
|
@property
|
||||||
|
def history(self):
|
||||||
|
return list(self.c.history)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def stack(self):
|
def stack(self):
|
||||||
return [self.S(i) for i in range(self.c.stack_depth())]
|
return [self.S(i) for i in range(self.c.stack_depth())]
|
||||||
|
@ -176,3 +180,6 @@ cdef class StateClass:
|
||||||
|
|
||||||
def clone(self, StateClass src):
|
def clone(self, StateClass src):
|
||||||
self.c.clone(src.c)
|
self.c.clone(src.c)
|
||||||
|
|
||||||
|
def set_context_tokens(self, int[:, :] output, int row, int n_feats):
|
||||||
|
self.c.set_context_tokens(&output[row, 0], n_feats)
|
||||||
|
|
|
@ -53,3 +53,10 @@ cdef class TransitionSystem:
|
||||||
|
|
||||||
cdef int set_costs(self, int* is_valid, weight_t* costs,
|
cdef int set_costs(self, int* is_valid, weight_t* costs,
|
||||||
const StateC* state, gold) except -1
|
const StateC* state, gold) except -1
|
||||||
|
|
||||||
|
|
||||||
|
cdef void c_apply_actions(TransitionSystem moves, StateC** states, const int* actions,
|
||||||
|
int batch_size) nogil
|
||||||
|
|
||||||
|
cdef void c_transition_batch(TransitionSystem moves, StateC** states, const float* scores,
|
||||||
|
int nr_class, int batch_size) nogil
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
# cython: infer_types=True
|
# cython: infer_types=True
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
from cymem.cymem cimport Pool
|
from cymem.cymem cimport Pool
|
||||||
|
from libc.stdlib cimport calloc, free
|
||||||
|
from libcpp.vector cimport vector
|
||||||
|
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
import srsly
|
import srsly
|
||||||
|
@ -10,6 +12,7 @@ from ...typedefs cimport weight_t, attr_t
|
||||||
from ...tokens.doc cimport Doc
|
from ...tokens.doc cimport Doc
|
||||||
from ...structs cimport TokenC
|
from ...structs cimport TokenC
|
||||||
from .stateclass cimport StateClass
|
from .stateclass cimport StateClass
|
||||||
|
from ._parser_utils cimport arg_max_if_valid
|
||||||
|
|
||||||
from ...errors import Errors
|
from ...errors import Errors
|
||||||
from ... import util
|
from ... import util
|
||||||
|
@ -73,7 +76,18 @@ cdef class TransitionSystem:
|
||||||
offset += len(doc)
|
offset += len(doc)
|
||||||
return states
|
return states
|
||||||
|
|
||||||
|
def follow_history(self, doc, history):
|
||||||
|
cdef int clas
|
||||||
|
cdef StateClass state = StateClass(doc)
|
||||||
|
for clas in history:
|
||||||
|
action = self.c[clas]
|
||||||
|
action.do(state.c, action.label)
|
||||||
|
state.c.history.push_back(clas)
|
||||||
|
return state
|
||||||
|
|
||||||
def get_oracle_sequence(self, Example example, _debug=False):
|
def get_oracle_sequence(self, Example example, _debug=False):
|
||||||
|
if not self.has_gold(example):
|
||||||
|
return []
|
||||||
states, golds, _ = self.init_gold_batch([example])
|
states, golds, _ = self.init_gold_batch([example])
|
||||||
if not states:
|
if not states:
|
||||||
return []
|
return []
|
||||||
|
@ -85,6 +99,8 @@ cdef class TransitionSystem:
|
||||||
return self.get_oracle_sequence_from_state(state, gold)
|
return self.get_oracle_sequence_from_state(state, gold)
|
||||||
|
|
||||||
def get_oracle_sequence_from_state(self, StateClass state, gold, _debug=None):
|
def get_oracle_sequence_from_state(self, StateClass state, gold, _debug=None):
|
||||||
|
if state.is_final():
|
||||||
|
return []
|
||||||
cdef Pool mem = Pool()
|
cdef Pool mem = Pool()
|
||||||
# n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc
|
# n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc
|
||||||
assert self.n_moves > 0
|
assert self.n_moves > 0
|
||||||
|
@ -110,6 +126,7 @@ cdef class TransitionSystem:
|
||||||
"S0 head?", str(state.has_head(state.S(0))),
|
"S0 head?", str(state.has_head(state.S(0))),
|
||||||
)))
|
)))
|
||||||
action.do(state.c, action.label)
|
action.do(state.c, action.label)
|
||||||
|
state.c.history.push_back(i)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
if _debug:
|
if _debug:
|
||||||
|
@ -137,6 +154,28 @@ cdef class TransitionSystem:
|
||||||
raise ValueError(Errors.E170.format(name=name))
|
raise ValueError(Errors.E170.format(name=name))
|
||||||
action = self.lookup_transition(name)
|
action = self.lookup_transition(name)
|
||||||
action.do(state.c, action.label)
|
action.do(state.c, action.label)
|
||||||
|
state.c.history.push_back(action.clas)
|
||||||
|
|
||||||
|
def apply_actions(self, states, const int[::1] actions):
|
||||||
|
assert len(states) == actions.shape[0]
|
||||||
|
cdef StateClass state
|
||||||
|
cdef vector[StateC*] c_states
|
||||||
|
c_states.resize(len(states))
|
||||||
|
cdef int i
|
||||||
|
for (i, state) in enumerate(states):
|
||||||
|
c_states[i] = state.c
|
||||||
|
c_apply_actions(self, &c_states[0], &actions[0], actions.shape[0])
|
||||||
|
return [state for state in states if not state.c.is_final()]
|
||||||
|
|
||||||
|
def transition_states(self, states, float[:, ::1] scores):
|
||||||
|
assert len(states) == scores.shape[0]
|
||||||
|
cdef StateClass state
|
||||||
|
cdef float* c_scores = &scores[0, 0]
|
||||||
|
cdef vector[StateC*] c_states
|
||||||
|
for state in states:
|
||||||
|
c_states.push_back(state.c)
|
||||||
|
c_transition_batch(self, &c_states[0], c_scores, scores.shape[1], scores.shape[0])
|
||||||
|
return [state for state in states if not state.c.is_final()]
|
||||||
|
|
||||||
cdef Transition lookup_transition(self, object name) except *:
|
cdef Transition lookup_transition(self, object name) except *:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
@ -250,3 +289,35 @@ cdef class TransitionSystem:
|
||||||
self.cfg.update(msg['cfg'])
|
self.cfg.update(msg['cfg'])
|
||||||
self.initialize_actions(labels)
|
self.initialize_actions(labels)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
cdef void c_apply_actions(TransitionSystem moves, StateC** states, const int* actions,
|
||||||
|
int batch_size) nogil:
|
||||||
|
cdef int i
|
||||||
|
cdef Transition action
|
||||||
|
cdef StateC* state
|
||||||
|
for i in range(batch_size):
|
||||||
|
state = states[i]
|
||||||
|
action = moves.c[actions[i]]
|
||||||
|
action.do(state, action.label)
|
||||||
|
state.history.push_back(action.clas)
|
||||||
|
|
||||||
|
|
||||||
|
cdef void c_transition_batch(TransitionSystem moves, StateC** states, const float* scores,
|
||||||
|
int nr_class, int batch_size) nogil:
|
||||||
|
is_valid = <int*>calloc(moves.n_moves, sizeof(int))
|
||||||
|
cdef int i, guess
|
||||||
|
cdef Transition action
|
||||||
|
for i in range(batch_size):
|
||||||
|
moves.set_valid(is_valid, states[i])
|
||||||
|
guess = arg_max_if_valid(&scores[i*nr_class], is_valid, nr_class)
|
||||||
|
if guess == -1:
|
||||||
|
# This shouldn't happen, but it's hard to raise an error here,
|
||||||
|
# and we don't want to infinite loop. So, force to end state.
|
||||||
|
states[i].force_final()
|
||||||
|
else:
|
||||||
|
action = moves.c[guess]
|
||||||
|
action.do(states[i], action.label)
|
||||||
|
states[i].history.push_back(guess)
|
||||||
|
free(is_valid)
|
||||||
|
|
||||||
|
|
|
@ -4,8 +4,8 @@ from typing import Optional, Iterable, Callable
|
||||||
from thinc.api import Model, Config
|
from thinc.api import Model, Config
|
||||||
|
|
||||||
from ._parser_internals.transition_system import TransitionSystem
|
from ._parser_internals.transition_system import TransitionSystem
|
||||||
from .transition_parser cimport Parser
|
from .transition_parser import Parser
|
||||||
from ._parser_internals.arc_eager cimport ArcEager
|
from ._parser_internals.arc_eager import ArcEager
|
||||||
|
|
||||||
from .functions import merge_subtokens
|
from .functions import merge_subtokens
|
||||||
from ..language import Language
|
from ..language import Language
|
||||||
|
@ -18,12 +18,11 @@ from ..util import registry
|
||||||
|
|
||||||
default_model_config = """
|
default_model_config = """
|
||||||
[model]
|
[model]
|
||||||
@architectures = "spacy.TransitionBasedParser.v2"
|
@architectures = "spacy.TransitionBasedParser.v3"
|
||||||
state_type = "parser"
|
state_type = "parser"
|
||||||
extra_state_tokens = false
|
extra_state_tokens = false
|
||||||
hidden_width = 64
|
hidden_width = 64
|
||||||
maxout_pieces = 2
|
maxout_pieces = 2
|
||||||
use_upper = true
|
|
||||||
|
|
||||||
[model.tok2vec]
|
[model.tok2vec]
|
||||||
@architectures = "spacy.HashEmbedCNN.v2"
|
@architectures = "spacy.HashEmbedCNN.v2"
|
||||||
|
@ -123,6 +122,7 @@ def make_parser(
|
||||||
scorer=scorer,
|
scorer=scorer,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@Language.factory(
|
@Language.factory(
|
||||||
"beam_parser",
|
"beam_parser",
|
||||||
assigns=["token.dep", "token.head", "token.is_sent_start", "doc.sents"],
|
assigns=["token.dep", "token.head", "token.is_sent_start", "doc.sents"],
|
||||||
|
@ -228,6 +228,7 @@ def parser_score(examples, **kwargs):
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/dependencyparser#score
|
DOCS: https://spacy.io/api/dependencyparser#score
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def has_sents(doc):
|
def has_sents(doc):
|
||||||
return doc.has_annotation("SENT_START")
|
return doc.has_annotation("SENT_START")
|
||||||
|
|
||||||
|
@ -235,8 +236,11 @@ def parser_score(examples, **kwargs):
|
||||||
dep = getattr(token, attr)
|
dep = getattr(token, attr)
|
||||||
dep = token.vocab.strings.as_string(dep).lower()
|
dep = token.vocab.strings.as_string(dep).lower()
|
||||||
return dep
|
return dep
|
||||||
|
|
||||||
results = {}
|
results = {}
|
||||||
results.update(Scorer.score_spans(examples, "sents", has_annotation=has_sents, **kwargs))
|
results.update(
|
||||||
|
Scorer.score_spans(examples, "sents", has_annotation=has_sents, **kwargs)
|
||||||
|
)
|
||||||
kwargs.setdefault("getter", dep_getter)
|
kwargs.setdefault("getter", dep_getter)
|
||||||
kwargs.setdefault("ignore_labels", ("p", "punct"))
|
kwargs.setdefault("ignore_labels", ("p", "punct"))
|
||||||
results.update(Scorer.score_deps(examples, "dep", **kwargs))
|
results.update(Scorer.score_deps(examples, "dep", **kwargs))
|
||||||
|
@ -249,11 +253,12 @@ def make_parser_scorer():
|
||||||
return parser_score
|
return parser_score
|
||||||
|
|
||||||
|
|
||||||
cdef class DependencyParser(Parser):
|
class DependencyParser(Parser):
|
||||||
"""Pipeline component for dependency parsing.
|
"""Pipeline component for dependency parsing.
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/dependencyparser
|
DOCS: https://spacy.io/api/dependencyparser
|
||||||
"""
|
"""
|
||||||
|
|
||||||
TransitionSystem = ArcEager
|
TransitionSystem = ArcEager
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -273,8 +278,7 @@ cdef class DependencyParser(Parser):
|
||||||
incorrect_spans_key=None,
|
incorrect_spans_key=None,
|
||||||
scorer=parser_score,
|
scorer=parser_score,
|
||||||
):
|
):
|
||||||
"""Create a DependencyParser.
|
"""Create a DependencyParser."""
|
||||||
"""
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
vocab,
|
vocab,
|
||||||
model,
|
model,
|
|
@ -4,22 +4,22 @@ from typing import Optional, Iterable, Callable
|
||||||
from thinc.api import Model, Config
|
from thinc.api import Model, Config
|
||||||
|
|
||||||
from ._parser_internals.transition_system import TransitionSystem
|
from ._parser_internals.transition_system import TransitionSystem
|
||||||
from .transition_parser cimport Parser
|
from .transition_parser import Parser
|
||||||
from ._parser_internals.ner cimport BiluoPushDown
|
from ._parser_internals.ner import BiluoPushDown
|
||||||
from ..language import Language
|
from ..language import Language
|
||||||
from ..scorer import get_ner_prf, PRFScore
|
from ..scorer import get_ner_prf, PRFScore
|
||||||
|
from ..training import validate_examples
|
||||||
from ..util import registry
|
from ..util import registry
|
||||||
from ..training import remove_bilu_prefix
|
from ..training import remove_bilu_prefix
|
||||||
|
|
||||||
|
|
||||||
default_model_config = """
|
default_model_config = """
|
||||||
[model]
|
[model]
|
||||||
@architectures = "spacy.TransitionBasedParser.v2"
|
@architectures = "spacy.TransitionBasedParser.v3"
|
||||||
state_type = "ner"
|
state_type = "ner"
|
||||||
extra_state_tokens = false
|
extra_state_tokens = false
|
||||||
hidden_width = 64
|
hidden_width = 64
|
||||||
maxout_pieces = 2
|
maxout_pieces = 2
|
||||||
use_upper = true
|
|
||||||
|
|
||||||
[model.tok2vec]
|
[model.tok2vec]
|
||||||
@architectures = "spacy.HashEmbedCNN.v2"
|
@architectures = "spacy.HashEmbedCNN.v2"
|
||||||
|
@ -44,8 +44,12 @@ DEFAULT_NER_MODEL = Config().from_str(default_model_config)["model"]
|
||||||
"incorrect_spans_key": None,
|
"incorrect_spans_key": None,
|
||||||
"scorer": {"@scorers": "spacy.ner_scorer.v1"},
|
"scorer": {"@scorers": "spacy.ner_scorer.v1"},
|
||||||
},
|
},
|
||||||
default_score_weights={"ents_f": 1.0, "ents_p": 0.0, "ents_r": 0.0, "ents_per_type": None},
|
default_score_weights={
|
||||||
|
"ents_f": 1.0,
|
||||||
|
"ents_p": 0.0,
|
||||||
|
"ents_r": 0.0,
|
||||||
|
"ents_per_type": None,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
def make_ner(
|
def make_ner(
|
||||||
nlp: Language,
|
nlp: Language,
|
||||||
|
@ -98,6 +102,7 @@ def make_ner(
|
||||||
scorer=scorer,
|
scorer=scorer,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@Language.factory(
|
@Language.factory(
|
||||||
"beam_ner",
|
"beam_ner",
|
||||||
assigns=["doc.ents", "token.ent_iob", "token.ent_type"],
|
assigns=["doc.ents", "token.ent_iob", "token.ent_type"],
|
||||||
|
@ -111,7 +116,12 @@ def make_ner(
|
||||||
"incorrect_spans_key": None,
|
"incorrect_spans_key": None,
|
||||||
"scorer": None,
|
"scorer": None,
|
||||||
},
|
},
|
||||||
default_score_weights={"ents_f": 1.0, "ents_p": 0.0, "ents_r": 0.0, "ents_per_type": None},
|
default_score_weights={
|
||||||
|
"ents_f": 1.0,
|
||||||
|
"ents_p": 0.0,
|
||||||
|
"ents_r": 0.0,
|
||||||
|
"ents_per_type": None,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
def make_beam_ner(
|
def make_beam_ner(
|
||||||
nlp: Language,
|
nlp: Language,
|
||||||
|
@ -185,11 +195,12 @@ def make_ner_scorer():
|
||||||
return ner_score
|
return ner_score
|
||||||
|
|
||||||
|
|
||||||
cdef class EntityRecognizer(Parser):
|
class EntityRecognizer(Parser):
|
||||||
"""Pipeline component for named entity recognition.
|
"""Pipeline component for named entity recognition.
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/entityrecognizer
|
DOCS: https://spacy.io/api/entityrecognizer
|
||||||
"""
|
"""
|
||||||
|
|
||||||
TransitionSystem = BiluoPushDown
|
TransitionSystem = BiluoPushDown
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -207,15 +218,14 @@ cdef class EntityRecognizer(Parser):
|
||||||
incorrect_spans_key=None,
|
incorrect_spans_key=None,
|
||||||
scorer=ner_score,
|
scorer=ner_score,
|
||||||
):
|
):
|
||||||
"""Create an EntityRecognizer.
|
"""Create an EntityRecognizer."""
|
||||||
"""
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
vocab,
|
vocab,
|
||||||
model,
|
model,
|
||||||
name,
|
name,
|
||||||
moves,
|
moves,
|
||||||
update_with_oracle_cut_size=update_with_oracle_cut_size,
|
update_with_oracle_cut_size=update_with_oracle_cut_size,
|
||||||
min_action_freq=1, # not relevant for NER
|
min_action_freq=1, # not relevant for NER
|
||||||
learn_tokens=False, # not relevant for NER
|
learn_tokens=False, # not relevant for NER
|
||||||
beam_width=beam_width,
|
beam_width=beam_width,
|
||||||
beam_density=beam_density,
|
beam_density=beam_density,
|
|
@ -1,21 +0,0 @@
|
||||||
from cymem.cymem cimport Pool
|
|
||||||
from thinc.backends.cblas cimport CBlas
|
|
||||||
|
|
||||||
from ..vocab cimport Vocab
|
|
||||||
from .trainable_pipe cimport TrainablePipe
|
|
||||||
from ._parser_internals.transition_system cimport Transition, TransitionSystem
|
|
||||||
from ._parser_internals._state cimport StateC
|
|
||||||
from ..ml.parser_model cimport WeightsC, ActivationsC, SizesC
|
|
||||||
|
|
||||||
|
|
||||||
cdef class Parser(TrainablePipe):
|
|
||||||
cdef public object _rehearsal_model
|
|
||||||
cdef readonly TransitionSystem moves
|
|
||||||
cdef public object _multitasks
|
|
||||||
cdef object _cpu_ops
|
|
||||||
|
|
||||||
cdef void _parseC(self, CBlas cblas, StateC** states,
|
|
||||||
WeightsC weights, SizesC sizes) nogil
|
|
||||||
|
|
||||||
cdef void c_transition_batch(self, StateC** states, const float* scores,
|
|
||||||
int nr_class, int batch_size) nogil
|
|
|
@ -8,26 +8,27 @@ from libcpp.vector cimport vector
|
||||||
from libc.string cimport memset, memcpy
|
from libc.string cimport memset, memcpy
|
||||||
from libc.stdlib cimport calloc, free
|
from libc.stdlib cimport calloc, free
|
||||||
import random
|
import random
|
||||||
|
import contextlib
|
||||||
|
|
||||||
import srsly
|
import srsly
|
||||||
from thinc.api import get_ops, set_dropout_rate, CupyOps, NumpyOps, Optimizer
|
from thinc.api import get_ops, set_dropout_rate, CupyOps, NumpyOps, Optimizer
|
||||||
from thinc.api import chain, softmax_activation, use_ops
|
from thinc.api import chain, softmax_activation, use_ops, get_array_module
|
||||||
from thinc.legacy import LegacySequenceCategoricalCrossentropy
|
from thinc.legacy import LegacySequenceCategoricalCrossentropy
|
||||||
from thinc.types import Floats2d
|
from thinc.types import Floats2d, Ints1d
|
||||||
import numpy.random
|
import numpy.random
|
||||||
import numpy
|
import numpy
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
from ._parser_internals.stateclass cimport StateClass
|
from ..ml.tb_framework import TransitionModelInputs
|
||||||
|
from ._parser_internals.stateclass cimport StateC, StateClass
|
||||||
from ._parser_internals.search cimport Beam
|
from ._parser_internals.search cimport Beam
|
||||||
from ..ml.parser_model cimport alloc_activations, free_activations
|
|
||||||
from ..ml.parser_model cimport predict_states, arg_max_if_valid
|
|
||||||
from ..ml.parser_model cimport WeightsC, ActivationsC, SizesC, cpu_log_loss
|
|
||||||
from ..ml.parser_model cimport get_c_weights, get_c_sizes
|
|
||||||
from ..tokens.doc cimport Doc
|
from ..tokens.doc cimport Doc
|
||||||
from .trainable_pipe import TrainablePipe
|
from .trainable_pipe cimport TrainablePipe
|
||||||
from ._parser_internals cimport _beam_utils
|
from ._parser_internals cimport _beam_utils
|
||||||
from ._parser_internals import _beam_utils
|
from ._parser_internals import _beam_utils
|
||||||
|
from ..vocab cimport Vocab
|
||||||
|
from ._parser_internals.transition_system cimport Transition, TransitionSystem
|
||||||
|
from ..typedefs cimport weight_t
|
||||||
|
|
||||||
from ..training import validate_examples, validate_get_examples
|
from ..training import validate_examples, validate_get_examples
|
||||||
from ..training import validate_distillation_examples
|
from ..training import validate_distillation_examples
|
||||||
|
@ -38,7 +39,7 @@ from .. import util
|
||||||
NUMPY_OPS = NumpyOps()
|
NUMPY_OPS = NumpyOps()
|
||||||
|
|
||||||
|
|
||||||
cdef class Parser(TrainablePipe):
|
class Parser(TrainablePipe):
|
||||||
"""
|
"""
|
||||||
Base class of the DependencyParser and EntityRecognizer.
|
Base class of the DependencyParser and EntityRecognizer.
|
||||||
"""
|
"""
|
||||||
|
@ -138,8 +139,9 @@ cdef class Parser(TrainablePipe):
|
||||||
@property
|
@property
|
||||||
def move_names(self):
|
def move_names(self):
|
||||||
names = []
|
names = []
|
||||||
|
cdef TransitionSystem moves = self.moves
|
||||||
for i in range(self.moves.n_moves):
|
for i in range(self.moves.n_moves):
|
||||||
name = self.moves.move_name(self.moves.c[i].move, self.moves.c[i].label)
|
name = self.moves.move_name(moves.c[i].move, moves.c[i].label)
|
||||||
# Explicitly removing the internal "U-" token used for blocking entities
|
# Explicitly removing the internal "U-" token used for blocking entities
|
||||||
if name != "U-":
|
if name != "U-":
|
||||||
names.append(name)
|
names.append(name)
|
||||||
|
@ -245,15 +247,6 @@ cdef class Parser(TrainablePipe):
|
||||||
|
|
||||||
student_docs = [eg.predicted for eg in examples]
|
student_docs = [eg.predicted for eg in examples]
|
||||||
|
|
||||||
teacher_step_model = teacher_pipe.model.predict([eg.reference for eg in examples])
|
|
||||||
student_step_model, backprop_tok2vec = self.model.begin_update(student_docs)
|
|
||||||
|
|
||||||
# Add softmax activation, so that we can compute student losses
|
|
||||||
# with cross-entropy loss.
|
|
||||||
with use_ops("numpy"):
|
|
||||||
teacher_model = chain(teacher_step_model, softmax_activation())
|
|
||||||
student_model = chain(student_step_model, softmax_activation())
|
|
||||||
|
|
||||||
max_moves = self.cfg["update_with_oracle_cut_size"]
|
max_moves = self.cfg["update_with_oracle_cut_size"]
|
||||||
if max_moves >= 1:
|
if max_moves >= 1:
|
||||||
# Chop sequences into lengths of this many words, to make the
|
# Chop sequences into lengths of this many words, to make the
|
||||||
|
@ -261,51 +254,39 @@ cdef class Parser(TrainablePipe):
|
||||||
# sequence, we use the teacher's predictions as the gold
|
# sequence, we use the teacher's predictions as the gold
|
||||||
# standard.
|
# standard.
|
||||||
max_moves = int(random.uniform(max_moves // 2, max_moves * 2))
|
max_moves = int(random.uniform(max_moves // 2, max_moves * 2))
|
||||||
states = self._init_batch(teacher_step_model, student_docs, max_moves)
|
states = self._init_batch(teacher_pipe, student_docs, max_moves)
|
||||||
else:
|
else:
|
||||||
states = self.moves.init_batch(student_docs)
|
states = self.moves.init_batch(student_docs)
|
||||||
|
|
||||||
loss = 0.0
|
# We distill as follows: 1. we first let the student predict transition
|
||||||
n_moves = 0
|
# sequences (and the corresponding transition probabilities); (2) we
|
||||||
while states:
|
# let the teacher follow the student's predicted transition sequences
|
||||||
# We do distillation as follows: (1) for every state, we compute the
|
# to obtain the teacher's transition probabilities; (3) we compute the
|
||||||
# transition softmax distributions: (2) we backpropagate the error of
|
# gradients of the student's transition distributions relative to the
|
||||||
# the student (compared to the teacher) into the student model; (3)
|
# teacher's distributions.
|
||||||
# for all states, we move to the next state using the student's
|
|
||||||
# predictions.
|
|
||||||
teacher_scores = teacher_model.predict(states)
|
|
||||||
student_scores, backprop = student_model.begin_update(states)
|
|
||||||
state_loss, d_scores = self.get_teacher_student_loss(teacher_scores, student_scores)
|
|
||||||
backprop(d_scores)
|
|
||||||
loss += state_loss
|
|
||||||
self.transition_states(states, student_scores)
|
|
||||||
states = [state for state in states if not state.is_final()]
|
|
||||||
|
|
||||||
# Stop when we reach the maximum number of moves, otherwise we start
|
student_inputs = TransitionModelInputs(docs=student_docs, moves=self.moves,
|
||||||
# to process the remainder of cut sequences again.
|
max_moves=max_moves)
|
||||||
if max_moves >= 1 and n_moves >= max_moves:
|
(student_states, student_scores), backprop_scores = self.model.begin_update(student_inputs)
|
||||||
break
|
actions = states2actions(student_states)
|
||||||
n_moves += 1
|
teacher_inputs = TransitionModelInputs(docs=[eg.reference for eg in examples],
|
||||||
|
moves=self.moves, actions=actions)
|
||||||
|
(_, teacher_scores) = teacher_pipe.model.predict(teacher_inputs)
|
||||||
|
|
||||||
backprop_tok2vec(student_docs)
|
loss, d_scores = self.get_teacher_student_loss(teacher_scores, student_scores)
|
||||||
|
backprop_scores((student_states, d_scores))
|
||||||
|
|
||||||
if sgd is not None:
|
if sgd is not None:
|
||||||
self.finish_update(sgd)
|
self.finish_update(sgd)
|
||||||
|
|
||||||
losses[self.name] += loss
|
losses[self.name] += loss
|
||||||
|
|
||||||
del backprop
|
|
||||||
del backprop_tok2vec
|
|
||||||
teacher_step_model.clear_memory()
|
|
||||||
student_step_model.clear_memory()
|
|
||||||
del teacher_model
|
|
||||||
del student_model
|
|
||||||
|
|
||||||
return losses
|
return losses
|
||||||
|
|
||||||
|
|
||||||
def get_teacher_student_loss(
|
def get_teacher_student_loss(
|
||||||
self, teacher_scores: List[Floats2d], student_scores: List[Floats2d]
|
self, teacher_scores: List[Floats2d], student_scores: List[Floats2d],
|
||||||
|
normalize: bool=False,
|
||||||
) -> Tuple[float, List[Floats2d]]:
|
) -> Tuple[float, List[Floats2d]]:
|
||||||
"""Calculate the loss and its gradient for a batch of student
|
"""Calculate the loss and its gradient for a batch of student
|
||||||
scores, relative to teacher scores.
|
scores, relative to teacher scores.
|
||||||
|
@ -317,10 +298,28 @@ cdef class Parser(TrainablePipe):
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/dependencyparser#get_teacher_student_loss
|
DOCS: https://spacy.io/api/dependencyparser#get_teacher_student_loss
|
||||||
"""
|
"""
|
||||||
loss_func = LegacySequenceCategoricalCrossentropy(normalize=False)
|
|
||||||
d_scores, loss = loss_func(student_scores, teacher_scores)
|
# We can't easily hook up a softmax layer in the parsing model, since
|
||||||
if self.model.ops.xp.isnan(loss):
|
# the get_loss does additional masking. So, we could apply softmax
|
||||||
raise ValueError(Errors.E910.format(name=self.name))
|
# manually here and use Thinc's cross-entropy loss. But it's a bit
|
||||||
|
# suboptimal, since we can have a lot of states that would result in
|
||||||
|
# many kernel launches. Futhermore the parsing model's backprop expects
|
||||||
|
# a XP array, so we'd have to concat the softmaxes anyway. So, like
|
||||||
|
# the get_loss implementation, we'll compute the loss and gradients
|
||||||
|
# ourselves.
|
||||||
|
|
||||||
|
teacher_scores = self.model.ops.softmax(self.model.ops.xp.vstack(teacher_scores),
|
||||||
|
axis=-1, inplace=True)
|
||||||
|
student_scores = self.model.ops.softmax(self.model.ops.xp.vstack(student_scores),
|
||||||
|
axis=-1, inplace=True)
|
||||||
|
|
||||||
|
assert teacher_scores.shape == student_scores.shape
|
||||||
|
|
||||||
|
d_scores = student_scores - teacher_scores
|
||||||
|
if normalize:
|
||||||
|
d_scores /= d_scores.shape[0]
|
||||||
|
loss = (d_scores**2).sum() / d_scores.size
|
||||||
|
|
||||||
return float(loss), d_scores
|
return float(loss), d_scores
|
||||||
|
|
||||||
def init_multitask_objectives(self, get_examples, pipeline, **cfg):
|
def init_multitask_objectives(self, get_examples, pipeline, **cfg):
|
||||||
|
@ -343,9 +342,6 @@ cdef class Parser(TrainablePipe):
|
||||||
|
|
||||||
stream: The sequence of documents to process.
|
stream: The sequence of documents to process.
|
||||||
batch_size (int): Number of documents to accumulate into a working set.
|
batch_size (int): Number of documents to accumulate into a working set.
|
||||||
error_handler (Callable[[str, List[Doc], Exception], Any]): Function that
|
|
||||||
deals with a failing batch of documents. The default function just reraises
|
|
||||||
the exception.
|
|
||||||
|
|
||||||
YIELDS (Doc): Documents, in order.
|
YIELDS (Doc): Documents, in order.
|
||||||
"""
|
"""
|
||||||
|
@ -367,78 +363,29 @@ cdef class Parser(TrainablePipe):
|
||||||
def predict(self, docs):
|
def predict(self, docs):
|
||||||
if isinstance(docs, Doc):
|
if isinstance(docs, Doc):
|
||||||
docs = [docs]
|
docs = [docs]
|
||||||
|
self._ensure_labels_are_added(docs)
|
||||||
if not any(len(doc) for doc in docs):
|
if not any(len(doc) for doc in docs):
|
||||||
result = self.moves.init_batch(docs)
|
result = self.moves.init_batch(docs)
|
||||||
return result
|
return result
|
||||||
if self.cfg["beam_width"] == 1:
|
with _change_attrs(self.model, beam_width=self.cfg["beam_width"], beam_density=self.cfg["beam_density"]):
|
||||||
return self.greedy_parse(docs, drop=0.0)
|
inputs = TransitionModelInputs(docs=docs, moves=self.moves)
|
||||||
else:
|
states_or_beams, _ = self.model.predict(inputs)
|
||||||
return self.beam_parse(
|
return states_or_beams
|
||||||
docs,
|
|
||||||
drop=0.0,
|
|
||||||
beam_width=self.cfg["beam_width"],
|
|
||||||
beam_density=self.cfg["beam_density"]
|
|
||||||
)
|
|
||||||
|
|
||||||
def greedy_parse(self, docs, drop=0.):
|
def greedy_parse(self, docs, drop=0.):
|
||||||
cdef vector[StateC*] states
|
self._resize()
|
||||||
cdef StateClass state
|
|
||||||
cdef CBlas cblas = self._cpu_ops.cblas()
|
|
||||||
self._ensure_labels_are_added(docs)
|
self._ensure_labels_are_added(docs)
|
||||||
set_dropout_rate(self.model, drop)
|
with _change_attrs(self.model, beam_width=1):
|
||||||
batch = self.moves.init_batch(docs)
|
inputs = TransitionModelInputs(docs=docs, moves=self.moves)
|
||||||
model = self.model.predict(docs)
|
states, _ = self.model.predict(inputs)
|
||||||
weights = get_c_weights(model)
|
return states
|
||||||
for state in batch:
|
|
||||||
if not state.is_final():
|
|
||||||
states.push_back(state.c)
|
|
||||||
sizes = get_c_sizes(model, states.size())
|
|
||||||
with nogil:
|
|
||||||
self._parseC(cblas, &states[0], weights, sizes)
|
|
||||||
model.clear_memory()
|
|
||||||
del model
|
|
||||||
return batch
|
|
||||||
|
|
||||||
def beam_parse(self, docs, int beam_width, float drop=0., beam_density=0.):
|
def beam_parse(self, docs, int beam_width, float drop=0., beam_density=0.):
|
||||||
cdef Beam beam
|
|
||||||
cdef Doc doc
|
|
||||||
self._ensure_labels_are_added(docs)
|
self._ensure_labels_are_added(docs)
|
||||||
batch = _beam_utils.BeamBatch(
|
with _change_attrs(self.model, beam_width=self.cfg["beam_width"], beam_density=self.cfg["beam_density"]):
|
||||||
self.moves,
|
inputs = TransitionModelInputs(docs=docs, moves=self.moves)
|
||||||
self.moves.init_batch(docs),
|
beams, _ = self.model.predict(inputs)
|
||||||
None,
|
return beams
|
||||||
beam_width,
|
|
||||||
density=beam_density
|
|
||||||
)
|
|
||||||
model = self.model.predict(docs)
|
|
||||||
while not batch.is_done:
|
|
||||||
states = batch.get_unfinished_states()
|
|
||||||
if not states:
|
|
||||||
break
|
|
||||||
scores = model.predict(states)
|
|
||||||
batch.advance(scores)
|
|
||||||
model.clear_memory()
|
|
||||||
del model
|
|
||||||
return list(batch)
|
|
||||||
|
|
||||||
cdef void _parseC(self, CBlas cblas, StateC** states,
|
|
||||||
WeightsC weights, SizesC sizes) nogil:
|
|
||||||
cdef int i, j
|
|
||||||
cdef vector[StateC*] unfinished
|
|
||||||
cdef ActivationsC activations = alloc_activations(sizes)
|
|
||||||
while sizes.states >= 1:
|
|
||||||
predict_states(cblas, &activations, states, &weights, sizes)
|
|
||||||
# Validate actions, argmax, take action.
|
|
||||||
self.c_transition_batch(states,
|
|
||||||
activations.scores, sizes.classes, sizes.states)
|
|
||||||
for i in range(sizes.states):
|
|
||||||
if not states[i].is_final():
|
|
||||||
unfinished.push_back(states[i])
|
|
||||||
for i in range(unfinished.size()):
|
|
||||||
states[i] = unfinished[i]
|
|
||||||
sizes.states = unfinished.size()
|
|
||||||
unfinished.clear()
|
|
||||||
free_activations(&activations)
|
|
||||||
|
|
||||||
def set_annotations(self, docs, states_or_beams):
|
def set_annotations(self, docs, states_or_beams):
|
||||||
cdef StateClass state
|
cdef StateClass state
|
||||||
|
@ -450,35 +397,6 @@ cdef class Parser(TrainablePipe):
|
||||||
for hook in self.postprocesses:
|
for hook in self.postprocesses:
|
||||||
hook(doc)
|
hook(doc)
|
||||||
|
|
||||||
def transition_states(self, states, float[:, ::1] scores):
|
|
||||||
cdef StateClass state
|
|
||||||
cdef float* c_scores = &scores[0, 0]
|
|
||||||
cdef vector[StateC*] c_states
|
|
||||||
for state in states:
|
|
||||||
c_states.push_back(state.c)
|
|
||||||
self.c_transition_batch(&c_states[0], c_scores, scores.shape[1], scores.shape[0])
|
|
||||||
return [state for state in states if not state.c.is_final()]
|
|
||||||
|
|
||||||
cdef void c_transition_batch(self, StateC** states, const float* scores,
|
|
||||||
int nr_class, int batch_size) nogil:
|
|
||||||
# n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc
|
|
||||||
with gil:
|
|
||||||
assert self.moves.n_moves > 0, Errors.E924.format(name=self.name)
|
|
||||||
is_valid = <int*>calloc(self.moves.n_moves, sizeof(int))
|
|
||||||
cdef int i, guess
|
|
||||||
cdef Transition action
|
|
||||||
for i in range(batch_size):
|
|
||||||
self.moves.set_valid(is_valid, states[i])
|
|
||||||
guess = arg_max_if_valid(&scores[i*nr_class], is_valid, nr_class)
|
|
||||||
if guess == -1:
|
|
||||||
# This shouldn't happen, but it's hard to raise an error here,
|
|
||||||
# and we don't want to infinite loop. So, force to end state.
|
|
||||||
states[i].force_final()
|
|
||||||
else:
|
|
||||||
action = self.moves.c[guess]
|
|
||||||
action.do(states[i], action.label)
|
|
||||||
free(is_valid)
|
|
||||||
|
|
||||||
def update(self, examples, *, drop=0., sgd=None, losses=None):
|
def update(self, examples, *, drop=0., sgd=None, losses=None):
|
||||||
cdef StateClass state
|
cdef StateClass state
|
||||||
if losses is None:
|
if losses is None:
|
||||||
|
@ -490,67 +408,99 @@ cdef class Parser(TrainablePipe):
|
||||||
)
|
)
|
||||||
for multitask in self._multitasks:
|
for multitask in self._multitasks:
|
||||||
multitask.update(examples, drop=drop, sgd=sgd)
|
multitask.update(examples, drop=drop, sgd=sgd)
|
||||||
|
# We need to take care to act on the whole batch, because we might be
|
||||||
|
# getting vectors via a listener.
|
||||||
n_examples = len([eg for eg in examples if self.moves.has_gold(eg)])
|
n_examples = len([eg for eg in examples if self.moves.has_gold(eg)])
|
||||||
if n_examples == 0:
|
if n_examples == 0:
|
||||||
return losses
|
return losses
|
||||||
set_dropout_rate(self.model, drop)
|
set_dropout_rate(self.model, drop)
|
||||||
# The probability we use beam update, instead of falling back to
|
docs = [eg.x for eg in examples if len(eg.x)]
|
||||||
# a greedy update
|
|
||||||
beam_update_prob = self.cfg["beam_update_prob"]
|
|
||||||
if self.cfg['beam_width'] >= 2 and numpy.random.random() < beam_update_prob:
|
|
||||||
return self.update_beam(
|
|
||||||
examples,
|
|
||||||
beam_width=self.cfg["beam_width"],
|
|
||||||
sgd=sgd,
|
|
||||||
losses=losses,
|
|
||||||
beam_density=self.cfg["beam_density"]
|
|
||||||
)
|
|
||||||
max_moves = self.cfg["update_with_oracle_cut_size"]
|
max_moves = self.cfg["update_with_oracle_cut_size"]
|
||||||
if max_moves >= 1:
|
if max_moves >= 1:
|
||||||
# Chop sequences into lengths of this many words, to make the
|
# Chop sequences into lengths of this many words, to make the
|
||||||
# batch uniform length.
|
# batch uniform length.
|
||||||
max_moves = int(random.uniform(max_moves // 2, max_moves * 2))
|
max_moves = int(random.uniform(max(max_moves // 2, 1), max_moves * 2))
|
||||||
states, golds, _ = self._init_gold_batch(
|
init_states, gold_states, _ = self._init_gold_batch(
|
||||||
examples,
|
examples,
|
||||||
max_length=max_moves
|
max_length=max_moves
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
states, golds, _ = self.moves.init_gold_batch(examples)
|
init_states, gold_states, _ = self.moves.init_gold_batch(examples)
|
||||||
if not states:
|
|
||||||
return losses
|
|
||||||
model, backprop_tok2vec = self.model.begin_update([eg.x for eg in examples])
|
|
||||||
|
|
||||||
all_states = list(states)
|
|
||||||
states_golds = list(zip(states, golds))
|
|
||||||
n_moves = 0
|
|
||||||
while states_golds:
|
|
||||||
states, golds = zip(*states_golds)
|
|
||||||
scores, backprop = model.begin_update(states)
|
|
||||||
d_scores = self.get_batch_loss(states, golds, scores, losses)
|
|
||||||
# Note that the gradient isn't normalized by the batch size
|
|
||||||
# here, because our "samples" are really the states...But we
|
|
||||||
# can't normalize by the number of states either, as then we'd
|
|
||||||
# be getting smaller gradients for states in long sequences.
|
|
||||||
backprop(d_scores)
|
|
||||||
# Follow the predicted action
|
|
||||||
self.transition_states(states, scores)
|
|
||||||
states_golds = [(s, g) for (s, g) in zip(states, golds) if not s.is_final()]
|
|
||||||
if max_moves >= 1 and n_moves >= max_moves:
|
|
||||||
break
|
|
||||||
n_moves += 1
|
|
||||||
|
|
||||||
backprop_tok2vec(golds)
|
inputs = TransitionModelInputs(docs=docs, moves=self.moves,
|
||||||
|
max_moves=max_moves, states=[state.copy() for state in init_states])
|
||||||
|
(pred_states, scores), backprop_scores = self.model.begin_update(inputs)
|
||||||
|
if sum(s.shape[0] for s in scores) == 0:
|
||||||
|
return losses
|
||||||
|
d_scores = self.get_loss((gold_states, init_states, pred_states, scores),
|
||||||
|
examples, max_moves)
|
||||||
|
backprop_scores((pred_states, d_scores))
|
||||||
if sgd not in (None, False):
|
if sgd not in (None, False):
|
||||||
self.finish_update(sgd)
|
self.finish_update(sgd)
|
||||||
|
losses[self.name] += float((d_scores**2).sum())
|
||||||
# Ugh, this is annoying. If we're working on GPU, we want to free the
|
# Ugh, this is annoying. If we're working on GPU, we want to free the
|
||||||
# memory ASAP. It seems that Python doesn't necessarily get around to
|
# memory ASAP. It seems that Python doesn't necessarily get around to
|
||||||
# removing these in time if we don't explicitly delete? It's confusing.
|
# removing these in time if we don't explicitly delete? It's confusing.
|
||||||
del backprop
|
del backprop_scores
|
||||||
del backprop_tok2vec
|
|
||||||
model.clear_memory()
|
|
||||||
del model
|
|
||||||
return losses
|
return losses
|
||||||
|
|
||||||
|
def get_loss(self, states_scores, examples, max_moves):
|
||||||
|
gold_states, init_states, pred_states, scores = states_scores
|
||||||
|
scores = self.model.ops.xp.vstack(scores)
|
||||||
|
costs = self._get_costs_from_histories(
|
||||||
|
examples,
|
||||||
|
gold_states,
|
||||||
|
init_states,
|
||||||
|
[list(state.history) for state in pred_states],
|
||||||
|
max_moves
|
||||||
|
)
|
||||||
|
xp = get_array_module(scores)
|
||||||
|
best_costs = costs.min(axis=1, keepdims=True)
|
||||||
|
gscores = scores.copy()
|
||||||
|
min_score = scores.min() - 1000
|
||||||
|
assert costs.shape == scores.shape, (costs.shape, scores.shape)
|
||||||
|
gscores[costs > best_costs] = min_score
|
||||||
|
max_ = scores.max(axis=1, keepdims=True)
|
||||||
|
gmax = gscores.max(axis=1, keepdims=True)
|
||||||
|
exp_scores = xp.exp(scores - max_)
|
||||||
|
exp_gscores = xp.exp(gscores - gmax)
|
||||||
|
Z = exp_scores.sum(axis=1, keepdims=True)
|
||||||
|
gZ = exp_gscores.sum(axis=1, keepdims=True)
|
||||||
|
d_scores = exp_scores / Z
|
||||||
|
d_scores -= (costs <= best_costs) * (exp_gscores / gZ)
|
||||||
|
return d_scores
|
||||||
|
|
||||||
|
def _get_costs_from_histories(self, examples, gold_states, init_states, histories, max_moves):
|
||||||
|
cdef TransitionSystem moves = self.moves
|
||||||
|
cdef StateClass state
|
||||||
|
cdef int clas
|
||||||
|
cdef int nF = self.model.get_dim("nF")
|
||||||
|
cdef int nO = moves.n_moves
|
||||||
|
cdef int nS = sum([len(history) for history in histories])
|
||||||
|
cdef Pool mem = Pool()
|
||||||
|
cdef np.ndarray costs_i
|
||||||
|
is_valid = <int*>mem.alloc(nO, sizeof(int))
|
||||||
|
batch = list(zip(init_states, histories, gold_states))
|
||||||
|
n_moves = 0
|
||||||
|
output = []
|
||||||
|
while batch:
|
||||||
|
costs = numpy.zeros((len(batch), nO), dtype="f")
|
||||||
|
for i, (state, history, gold) in enumerate(batch):
|
||||||
|
costs_i = costs[i]
|
||||||
|
clas = history.pop(0)
|
||||||
|
moves.set_costs(is_valid, <weight_t*>costs_i.data, state.c, gold)
|
||||||
|
action = moves.c[clas]
|
||||||
|
action.do(state.c, action.label)
|
||||||
|
state.c.history.push_back(clas)
|
||||||
|
output.append(costs)
|
||||||
|
batch = [(s, h, g) for s, h, g in batch if len(h) != 0]
|
||||||
|
if n_moves >= max_moves >= 1:
|
||||||
|
break
|
||||||
|
n_moves += 1
|
||||||
|
|
||||||
|
return self.model.ops.xp.vstack(output)
|
||||||
|
|
||||||
def rehearse(self, examples, sgd=None, losses=None, **cfg):
|
def rehearse(self, examples, sgd=None, losses=None, **cfg):
|
||||||
"""Perform a "rehearsal" update, to prevent catastrophic forgetting."""
|
"""Perform a "rehearsal" update, to prevent catastrophic forgetting."""
|
||||||
if losses is None:
|
if losses is None:
|
||||||
|
@ -560,10 +510,9 @@ cdef class Parser(TrainablePipe):
|
||||||
multitask.rehearse(examples, losses=losses, sgd=sgd)
|
multitask.rehearse(examples, losses=losses, sgd=sgd)
|
||||||
if self._rehearsal_model is None:
|
if self._rehearsal_model is None:
|
||||||
return None
|
return None
|
||||||
losses.setdefault(self.name, 0.)
|
losses.setdefault(self.name, 0.0)
|
||||||
validate_examples(examples, "Parser.rehearse")
|
validate_examples(examples, "Parser.rehearse")
|
||||||
docs = [eg.predicted for eg in examples]
|
docs = [eg.predicted for eg in examples]
|
||||||
states = self.moves.init_batch(docs)
|
|
||||||
# This is pretty dirty, but the NER can resize itself in init_batch,
|
# This is pretty dirty, but the NER can resize itself in init_batch,
|
||||||
# if labels are missing. We therefore have to check whether we need to
|
# if labels are missing. We therefore have to check whether we need to
|
||||||
# expand our model output.
|
# expand our model output.
|
||||||
|
@ -571,85 +520,33 @@ cdef class Parser(TrainablePipe):
|
||||||
# Prepare the stepwise model, and get the callback for finishing the batch
|
# Prepare the stepwise model, and get the callback for finishing the batch
|
||||||
set_dropout_rate(self._rehearsal_model, 0.0)
|
set_dropout_rate(self._rehearsal_model, 0.0)
|
||||||
set_dropout_rate(self.model, 0.0)
|
set_dropout_rate(self.model, 0.0)
|
||||||
tutor, _ = self._rehearsal_model.begin_update(docs)
|
student_inputs = TransitionModelInputs(docs=docs, moves=self.moves)
|
||||||
model, backprop_tok2vec = self.model.begin_update(docs)
|
(student_states, student_scores), backprop_scores = self.model.begin_update(student_inputs)
|
||||||
n_scores = 0.
|
actions = states2actions(student_states)
|
||||||
loss = 0.
|
teacher_inputs = TransitionModelInputs(docs=docs, moves=self.moves, actions=actions)
|
||||||
while states:
|
_, teacher_scores = self._rehearsal_model.predict(teacher_inputs)
|
||||||
targets, _ = tutor.begin_update(states)
|
|
||||||
guesses, backprop = model.begin_update(states)
|
loss, d_scores = self.get_teacher_student_loss(teacher_scores, student_scores, normalize=True)
|
||||||
d_scores = (guesses - targets) / targets.shape[0]
|
|
||||||
# If all weights for an output are 0 in the original model, don't
|
teacher_scores = self.model.ops.xp.vstack(teacher_scores)
|
||||||
# supervise that output. This allows us to add classes.
|
student_scores = self.model.ops.xp.vstack(student_scores)
|
||||||
loss += (d_scores**2).sum()
|
assert teacher_scores.shape == student_scores.shape
|
||||||
backprop(d_scores)
|
|
||||||
# Follow the predicted action
|
d_scores = (student_scores - teacher_scores) / teacher_scores.shape[0]
|
||||||
self.transition_states(states, guesses)
|
# If all weights for an output are 0 in the original model, don't
|
||||||
states = [state for state in states if not state.is_final()]
|
# supervise that output. This allows us to add classes.
|
||||||
n_scores += d_scores.size
|
loss = (d_scores**2).sum() / d_scores.size
|
||||||
# Do the backprop
|
backprop_scores((student_states, d_scores))
|
||||||
backprop_tok2vec(docs)
|
|
||||||
if sgd is not None:
|
if sgd is not None:
|
||||||
self.finish_update(sgd)
|
self.finish_update(sgd)
|
||||||
losses[self.name] += loss / n_scores
|
losses[self.name] += loss
|
||||||
del backprop
|
|
||||||
del backprop_tok2vec
|
|
||||||
model.clear_memory()
|
|
||||||
tutor.clear_memory()
|
|
||||||
del model
|
|
||||||
del tutor
|
|
||||||
return losses
|
return losses
|
||||||
|
|
||||||
def update_beam(self, examples, *, beam_width,
|
def update_beam(self, examples, *, beam_width,
|
||||||
drop=0., sgd=None, losses=None, beam_density=0.0):
|
drop=0., sgd=None, losses=None, beam_density=0.0):
|
||||||
states, golds, _ = self.moves.init_gold_batch(examples)
|
raise NotImplementedError
|
||||||
if not states:
|
|
||||||
return losses
|
|
||||||
# Prepare the stepwise model, and get the callback for finishing the batch
|
|
||||||
model, backprop_tok2vec = self.model.begin_update(
|
|
||||||
[eg.predicted for eg in examples])
|
|
||||||
loss = _beam_utils.update_beam(
|
|
||||||
self.moves,
|
|
||||||
states,
|
|
||||||
golds,
|
|
||||||
model,
|
|
||||||
beam_width,
|
|
||||||
beam_density=beam_density,
|
|
||||||
)
|
|
||||||
losses[self.name] += loss
|
|
||||||
backprop_tok2vec(golds)
|
|
||||||
if sgd is not None:
|
|
||||||
self.finish_update(sgd)
|
|
||||||
|
|
||||||
def get_batch_loss(self, states, golds, float[:, ::1] scores, losses):
|
|
||||||
cdef StateClass state
|
|
||||||
cdef Pool mem = Pool()
|
|
||||||
cdef int i
|
|
||||||
|
|
||||||
# n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc
|
|
||||||
assert self.moves.n_moves > 0, Errors.E924.format(name=self.name)
|
|
||||||
|
|
||||||
is_valid = <int*>mem.alloc(self.moves.n_moves, sizeof(int))
|
|
||||||
costs = <float*>mem.alloc(self.moves.n_moves, sizeof(float))
|
|
||||||
cdef np.ndarray d_scores = numpy.zeros((len(states), self.moves.n_moves),
|
|
||||||
dtype='f', order='C')
|
|
||||||
c_d_scores = <float*>d_scores.data
|
|
||||||
unseen_classes = self.model.attrs["unseen_classes"]
|
|
||||||
for i, (state, gold) in enumerate(zip(states, golds)):
|
|
||||||
memset(is_valid, 0, self.moves.n_moves * sizeof(int))
|
|
||||||
memset(costs, 0, self.moves.n_moves * sizeof(float))
|
|
||||||
self.moves.set_costs(is_valid, costs, state.c, gold)
|
|
||||||
for j in range(self.moves.n_moves):
|
|
||||||
if costs[j] <= 0.0 and j in unseen_classes:
|
|
||||||
unseen_classes.remove(j)
|
|
||||||
cpu_log_loss(c_d_scores,
|
|
||||||
costs, is_valid, &scores[i, 0], d_scores.shape[1])
|
|
||||||
c_d_scores += d_scores.shape[1]
|
|
||||||
# Note that we don't normalize this. See comment in update() for why.
|
|
||||||
if losses is not None:
|
|
||||||
losses.setdefault(self.name, 0.)
|
|
||||||
losses[self.name] += (d_scores**2).sum()
|
|
||||||
return d_scores
|
|
||||||
|
|
||||||
def set_output(self, nO):
|
def set_output(self, nO):
|
||||||
self.model.attrs["resize_output"](self.model, nO)
|
self.model.attrs["resize_output"](self.model, nO)
|
||||||
|
@ -688,7 +585,7 @@ cdef class Parser(TrainablePipe):
|
||||||
for example in islice(get_examples(), 10):
|
for example in islice(get_examples(), 10):
|
||||||
doc_sample.append(example.predicted)
|
doc_sample.append(example.predicted)
|
||||||
assert len(doc_sample) > 0, Errors.E923.format(name=self.name)
|
assert len(doc_sample) > 0, Errors.E923.format(name=self.name)
|
||||||
self.model.initialize(doc_sample)
|
self.model.initialize((doc_sample, self.moves))
|
||||||
if nlp is not None:
|
if nlp is not None:
|
||||||
self.init_multitask_objectives(get_examples, nlp.pipeline)
|
self.init_multitask_objectives(get_examples, nlp.pipeline)
|
||||||
|
|
||||||
|
@ -781,26 +678,27 @@ cdef class Parser(TrainablePipe):
|
||||||
|
|
||||||
def _init_gold_batch(self, examples, max_length):
|
def _init_gold_batch(self, examples, max_length):
|
||||||
"""Make a square batch, of length equal to the shortest transition
|
"""Make a square batch, of length equal to the shortest transition
|
||||||
sequence or a cap. A long
|
sequence or a cap. A long doc will get multiple states. Let's say we
|
||||||
doc will get multiple states. Let's say we have a doc of length 2*N,
|
have a doc of length 2*N, where N is the shortest doc. We'll make
|
||||||
where N is the shortest doc. We'll make two states, one representing
|
two states, one representing long_doc[:N], and another representing
|
||||||
long_doc[:N], and another representing long_doc[N:]."""
|
long_doc[N:]."""
|
||||||
cdef:
|
cdef:
|
||||||
StateClass start_state
|
StateClass start_state
|
||||||
StateClass state
|
StateClass state
|
||||||
Transition action
|
Transition action
|
||||||
all_states = self.moves.init_batch([eg.predicted for eg in examples])
|
TransitionSystem moves = self.moves
|
||||||
|
all_states = moves.init_batch([eg.predicted for eg in examples])
|
||||||
states = []
|
states = []
|
||||||
golds = []
|
golds = []
|
||||||
to_cut = []
|
to_cut = []
|
||||||
for state, eg in zip(all_states, examples):
|
for state, eg in zip(all_states, examples):
|
||||||
if self.moves.has_gold(eg) and not state.is_final():
|
if moves.has_gold(eg) and not state.is_final():
|
||||||
gold = self.moves.init_gold(state, eg)
|
gold = moves.init_gold(state, eg)
|
||||||
if len(eg.x) < max_length:
|
if len(eg.x) < max_length:
|
||||||
states.append(state)
|
states.append(state)
|
||||||
golds.append(gold)
|
golds.append(gold)
|
||||||
else:
|
else:
|
||||||
oracle_actions = self.moves.get_oracle_sequence_from_state(
|
oracle_actions = moves.get_oracle_sequence_from_state(
|
||||||
state.copy(), gold)
|
state.copy(), gold)
|
||||||
to_cut.append((eg, state, gold, oracle_actions))
|
to_cut.append((eg, state, gold, oracle_actions))
|
||||||
if not to_cut:
|
if not to_cut:
|
||||||
|
@ -810,13 +708,52 @@ cdef class Parser(TrainablePipe):
|
||||||
for i in range(0, len(oracle_actions), max_length):
|
for i in range(0, len(oracle_actions), max_length):
|
||||||
start_state = state.copy()
|
start_state = state.copy()
|
||||||
for clas in oracle_actions[i:i+max_length]:
|
for clas in oracle_actions[i:i+max_length]:
|
||||||
action = self.moves.c[clas]
|
action = moves.c[clas]
|
||||||
action.do(state.c, action.label)
|
action.do(state.c, action.label)
|
||||||
if state.is_final():
|
if state.is_final():
|
||||||
break
|
break
|
||||||
if self.moves.has_gold(eg, start_state.B(0), state.B(0)):
|
if moves.has_gold(eg, start_state.B(0), state.B(0)):
|
||||||
states.append(start_state)
|
states.append(start_state)
|
||||||
golds.append(gold)
|
golds.append(gold)
|
||||||
if state.is_final():
|
if state.is_final():
|
||||||
break
|
break
|
||||||
return states, golds, max_length
|
return states, golds, max_length
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def _change_attrs(model, **kwargs):
|
||||||
|
"""Temporarily modify a thinc model's attributes."""
|
||||||
|
unset = object()
|
||||||
|
old_attrs = {}
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
old_attrs[key] = model.attrs.get(key, unset)
|
||||||
|
model.attrs[key] = value
|
||||||
|
yield model
|
||||||
|
for key, value in old_attrs.items():
|
||||||
|
if value is unset:
|
||||||
|
model.attrs.pop(key)
|
||||||
|
else:
|
||||||
|
model.attrs[key] = value
|
||||||
|
|
||||||
|
|
||||||
|
def states2actions(states: List[StateClass]) -> List[Ints1d]:
|
||||||
|
cdef int step
|
||||||
|
cdef StateClass state
|
||||||
|
cdef StateC* c_state
|
||||||
|
actions = []
|
||||||
|
while True:
|
||||||
|
step = len(actions)
|
||||||
|
|
||||||
|
step_actions = []
|
||||||
|
for state in states:
|
||||||
|
c_state = state.c
|
||||||
|
if step < c_state.history.size():
|
||||||
|
step_actions.append(c_state.history[step])
|
||||||
|
|
||||||
|
# We are done if we have exhausted all histories.
|
||||||
|
if len(step_actions) == 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
actions.append(numpy.array(step_actions, dtype="i"))
|
||||||
|
|
||||||
|
return actions
|
||||||
|
|
|
@ -13,6 +13,7 @@ from spacy.pipeline._parser_internals.ner import BiluoPushDown
|
||||||
from spacy.training import Example, iob_to_biluo, split_bilu_label
|
from spacy.training import Example, iob_to_biluo, split_bilu_label
|
||||||
from spacy.tokens import Doc, Span
|
from spacy.tokens import Doc, Span
|
||||||
from spacy.vocab import Vocab
|
from spacy.vocab import Vocab
|
||||||
|
from thinc.api import fix_random_seed
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from ..util import make_tempdir
|
from ..util import make_tempdir
|
||||||
|
@ -412,7 +413,7 @@ def test_train_empty():
|
||||||
train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
|
train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
|
||||||
ner = nlp.add_pipe("ner", last=True)
|
ner = nlp.add_pipe("ner", last=True)
|
||||||
ner.add_label("PERSON")
|
ner.add_label("PERSON")
|
||||||
nlp.initialize()
|
nlp.initialize(get_examples=lambda: train_examples)
|
||||||
for itn in range(2):
|
for itn in range(2):
|
||||||
losses = {}
|
losses = {}
|
||||||
batches = util.minibatch(train_examples, size=8)
|
batches = util.minibatch(train_examples, size=8)
|
||||||
|
@ -539,11 +540,11 @@ def test_block_ner():
|
||||||
assert [token.ent_type_ for token in doc] == expected_types
|
assert [token.ent_type_ for token in doc] == expected_types
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("use_upper", [True, False])
|
def test_overfitting_IO():
|
||||||
def test_overfitting_IO(use_upper):
|
fix_random_seed(1)
|
||||||
# Simple test to try and quickly overfit the NER component
|
# Simple test to try and quickly overfit the NER component
|
||||||
nlp = English()
|
nlp = English()
|
||||||
ner = nlp.add_pipe("ner", config={"model": {"use_upper": use_upper}})
|
ner = nlp.add_pipe("ner", config={"model": {}})
|
||||||
train_examples = []
|
train_examples = []
|
||||||
for text, annotations in TRAIN_DATA:
|
for text, annotations in TRAIN_DATA:
|
||||||
train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
|
train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
|
||||||
|
@ -575,7 +576,6 @@ def test_overfitting_IO(use_upper):
|
||||||
assert ents2[0].label_ == "LOC"
|
assert ents2[0].label_ == "LOC"
|
||||||
# Ensure that the predictions are still the same, even after adding a new label
|
# Ensure that the predictions are still the same, even after adding a new label
|
||||||
ner2 = nlp2.get_pipe("ner")
|
ner2 = nlp2.get_pipe("ner")
|
||||||
assert ner2.model.attrs["has_upper"] == use_upper
|
|
||||||
ner2.add_label("RANDOM_NEW_LABEL")
|
ner2.add_label("RANDOM_NEW_LABEL")
|
||||||
doc3 = nlp2(test_text)
|
doc3 = nlp2(test_text)
|
||||||
ents3 = doc3.ents
|
ents3 = doc3.ents
|
||||||
|
|
|
@ -1,13 +1,17 @@
|
||||||
|
import itertools
|
||||||
import pytest
|
import pytest
|
||||||
|
import numpy
|
||||||
from numpy.testing import assert_equal
|
from numpy.testing import assert_equal
|
||||||
from thinc.api import Adam
|
from thinc.api import Adam
|
||||||
|
|
||||||
from spacy import registry, util
|
from spacy import registry, util
|
||||||
from spacy.attrs import DEP, NORM
|
from spacy.attrs import DEP, NORM
|
||||||
from spacy.lang.en import English
|
from spacy.lang.en import English
|
||||||
from spacy.tokens import Doc
|
|
||||||
from spacy.training import Example
|
from spacy.training import Example
|
||||||
|
from spacy.tokens import Doc
|
||||||
from spacy.vocab import Vocab
|
from spacy.vocab import Vocab
|
||||||
|
from spacy import util, registry
|
||||||
|
from thinc.api import fix_random_seed
|
||||||
|
|
||||||
from ...pipeline import DependencyParser
|
from ...pipeline import DependencyParser
|
||||||
from ...pipeline.dep_parser import DEFAULT_PARSER_MODEL
|
from ...pipeline.dep_parser import DEFAULT_PARSER_MODEL
|
||||||
|
@ -59,6 +63,8 @@ PARTIAL_DATA = [
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
PARSERS = ["parser"] # TODO: Test beam_parser when ready
|
||||||
|
|
||||||
eps = 0.1
|
eps = 0.1
|
||||||
|
|
||||||
|
|
||||||
|
@ -171,6 +177,57 @@ def test_parser_parse_one_word_sentence(en_vocab, en_parser, words):
|
||||||
assert doc[0].dep != 0
|
assert doc[0].dep != 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_parser_apply_actions(en_vocab, en_parser):
|
||||||
|
words = ["I", "ate", "pizza"]
|
||||||
|
words2 = ["Eat", "more", "pizza", "!"]
|
||||||
|
doc1 = Doc(en_vocab, words=words)
|
||||||
|
doc2 = Doc(en_vocab, words=words2)
|
||||||
|
docs = [doc1, doc2]
|
||||||
|
|
||||||
|
moves = en_parser.moves
|
||||||
|
moves.add_action(0, "")
|
||||||
|
moves.add_action(1, "")
|
||||||
|
moves.add_action(2, "nsubj")
|
||||||
|
moves.add_action(3, "obj")
|
||||||
|
moves.add_action(2, "amod")
|
||||||
|
|
||||||
|
actions = [
|
||||||
|
numpy.array([0, 0], dtype="i"),
|
||||||
|
numpy.array([2, 0], dtype="i"),
|
||||||
|
numpy.array([0, 4], dtype="i"),
|
||||||
|
numpy.array([3, 3], dtype="i"),
|
||||||
|
numpy.array([1, 1], dtype="i"),
|
||||||
|
numpy.array([1, 1], dtype="i"),
|
||||||
|
numpy.array([0], dtype="i"),
|
||||||
|
numpy.array([1], dtype="i"),
|
||||||
|
]
|
||||||
|
|
||||||
|
states = moves.init_batch(docs)
|
||||||
|
active_states = states
|
||||||
|
|
||||||
|
for step_actions in actions:
|
||||||
|
active_states = moves.apply_actions(active_states, step_actions)
|
||||||
|
|
||||||
|
assert len(active_states) == 0
|
||||||
|
|
||||||
|
for (state, doc) in zip(states, docs):
|
||||||
|
moves.set_annotations(state, doc)
|
||||||
|
|
||||||
|
assert docs[0][0].head.i == 1
|
||||||
|
assert docs[0][0].dep_ == "nsubj"
|
||||||
|
assert docs[0][1].head.i == 1
|
||||||
|
assert docs[0][1].dep_ == "ROOT"
|
||||||
|
assert docs[0][2].head.i == 1
|
||||||
|
assert docs[0][2].dep_ == "obj"
|
||||||
|
|
||||||
|
assert docs[1][0].head.i == 0
|
||||||
|
assert docs[1][0].dep_ == "ROOT"
|
||||||
|
assert docs[1][1].head.i == 2
|
||||||
|
assert docs[1][1].dep_ == "amod"
|
||||||
|
assert docs[1][2].head.i == 0
|
||||||
|
assert docs[1][2].dep_ == "obj"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(
|
@pytest.mark.skip(
|
||||||
reason="The step_through API was removed (but should be brought back)"
|
reason="The step_through API was removed (but should be brought back)"
|
||||||
)
|
)
|
||||||
|
@ -319,7 +376,7 @@ def test_parser_constructor(en_vocab):
|
||||||
DependencyParser(en_vocab, model)
|
DependencyParser(en_vocab, model)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("pipe_name", ["parser", "beam_parser"])
|
@pytest.mark.parametrize("pipe_name", PARSERS)
|
||||||
def test_incomplete_data(pipe_name):
|
def test_incomplete_data(pipe_name):
|
||||||
# Test that the parser works with incomplete information
|
# Test that the parser works with incomplete information
|
||||||
nlp = English()
|
nlp = English()
|
||||||
|
@ -345,11 +402,15 @@ def test_incomplete_data(pipe_name):
|
||||||
assert doc[2].head.i == 1
|
assert doc[2].head.i == 1
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("pipe_name", ["parser", "beam_parser"])
|
@pytest.mark.parametrize(
|
||||||
def test_overfitting_IO(pipe_name):
|
"pipe_name,max_moves", itertools.product(PARSERS, [0, 1, 5, 100])
|
||||||
|
)
|
||||||
|
def test_overfitting_IO(pipe_name, max_moves):
|
||||||
|
fix_random_seed(0)
|
||||||
# Simple test to try and quickly overfit the dependency parser (normal or beam)
|
# Simple test to try and quickly overfit the dependency parser (normal or beam)
|
||||||
nlp = English()
|
nlp = English()
|
||||||
parser = nlp.add_pipe(pipe_name)
|
parser = nlp.add_pipe(pipe_name)
|
||||||
|
parser.cfg["update_with_oracle_cut_size"] = max_moves
|
||||||
train_examples = []
|
train_examples = []
|
||||||
for text, annotations in TRAIN_DATA:
|
for text, annotations in TRAIN_DATA:
|
||||||
train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
|
train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
|
||||||
|
@ -451,10 +512,12 @@ def test_distill():
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"parser_config",
|
"parser_config",
|
||||||
[
|
[
|
||||||
# TransitionBasedParser V1
|
# TODO: re-enable after we have a spacy-legacy release for v4. See
|
||||||
({"@architectures": "spacy.TransitionBasedParser.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "state_type": "parser", "extra_state_tokens": False, "hidden_width": 64, "maxout_pieces": 2, "use_upper": True}),
|
# https://github.com/explosion/spacy-legacy/pull/36
|
||||||
# TransitionBasedParser V2
|
#({"@architectures": "spacy.TransitionBasedParser.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "state_type": "parser", "extra_state_tokens": False, "hidden_width": 64, "maxout_pieces": 2, "use_upper": True}),
|
||||||
({"@architectures": "spacy.TransitionBasedParser.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "state_type": "parser", "extra_state_tokens": False, "hidden_width": 64, "maxout_pieces": 2, "use_upper": True}),
|
({"@architectures": "spacy.TransitionBasedParser.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "state_type": "parser", "extra_state_tokens": False, "hidden_width": 64, "maxout_pieces": 2, "use_upper": True}),
|
||||||
|
({"@architectures": "spacy.TransitionBasedParser.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "state_type": "parser", "extra_state_tokens": False, "hidden_width": 64, "maxout_pieces": 2, "use_upper": False}),
|
||||||
|
({"@architectures": "spacy.TransitionBasedParser.v3", "tok2vec": DEFAULT_TOK2VEC_MODEL, "state_type": "parser", "extra_state_tokens": False, "hidden_width": 64, "maxout_pieces": 2}),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
|
@ -382,7 +382,7 @@ cfg_string_multi = """
|
||||||
factory = "ner"
|
factory = "ner"
|
||||||
|
|
||||||
[components.ner.model]
|
[components.ner.model]
|
||||||
@architectures = "spacy.TransitionBasedParser.v2"
|
@architectures = "spacy.TransitionBasedParser.v3"
|
||||||
|
|
||||||
[components.ner.model.tok2vec]
|
[components.ner.model.tok2vec]
|
||||||
@architectures = "spacy.Tok2VecListener.v1"
|
@architectures = "spacy.Tok2VecListener.v1"
|
||||||
|
|
|
@ -122,33 +122,11 @@ width = ${components.tok2vec.model.width}
|
||||||
|
|
||||||
parser_config_string_upper = """
|
parser_config_string_upper = """
|
||||||
[model]
|
[model]
|
||||||
@architectures = "spacy.TransitionBasedParser.v2"
|
@architectures = "spacy.TransitionBasedParser.v3"
|
||||||
state_type = "parser"
|
state_type = "parser"
|
||||||
extra_state_tokens = false
|
extra_state_tokens = false
|
||||||
hidden_width = 66
|
hidden_width = 66
|
||||||
maxout_pieces = 2
|
maxout_pieces = 2
|
||||||
use_upper = true
|
|
||||||
|
|
||||||
[model.tok2vec]
|
|
||||||
@architectures = "spacy.HashEmbedCNN.v1"
|
|
||||||
pretrained_vectors = null
|
|
||||||
width = 333
|
|
||||||
depth = 4
|
|
||||||
embed_size = 5555
|
|
||||||
window_size = 1
|
|
||||||
maxout_pieces = 7
|
|
||||||
subword_features = false
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
parser_config_string_no_upper = """
|
|
||||||
[model]
|
|
||||||
@architectures = "spacy.TransitionBasedParser.v2"
|
|
||||||
state_type = "parser"
|
|
||||||
extra_state_tokens = false
|
|
||||||
hidden_width = 66
|
|
||||||
maxout_pieces = 2
|
|
||||||
use_upper = false
|
|
||||||
|
|
||||||
[model.tok2vec]
|
[model.tok2vec]
|
||||||
@architectures = "spacy.HashEmbedCNN.v1"
|
@architectures = "spacy.HashEmbedCNN.v1"
|
||||||
|
@ -179,7 +157,6 @@ def my_parser():
|
||||||
extra_state_tokens=True,
|
extra_state_tokens=True,
|
||||||
hidden_width=65,
|
hidden_width=65,
|
||||||
maxout_pieces=5,
|
maxout_pieces=5,
|
||||||
use_upper=True,
|
|
||||||
)
|
)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
@ -285,15 +262,16 @@ def test_serialize_custom_nlp():
|
||||||
nlp.to_disk(d)
|
nlp.to_disk(d)
|
||||||
nlp2 = spacy.load(d)
|
nlp2 = spacy.load(d)
|
||||||
model = nlp2.get_pipe("parser").model
|
model = nlp2.get_pipe("parser").model
|
||||||
model.get_ref("tok2vec")
|
assert model.get_ref("tok2vec") is not None
|
||||||
# check that we have the correct settings, not the default ones
|
assert model.has_param("hidden_W")
|
||||||
assert model.get_ref("upper").get_dim("nI") == 65
|
assert model.has_param("hidden_b")
|
||||||
assert model.get_ref("lower").get_dim("nI") == 65
|
output = model.get_ref("output")
|
||||||
|
assert output is not None
|
||||||
|
assert output.has_param("W")
|
||||||
|
assert output.has_param("b")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize("parser_config_string", [parser_config_string_upper])
|
||||||
"parser_config_string", [parser_config_string_upper, parser_config_string_no_upper]
|
|
||||||
)
|
|
||||||
def test_serialize_parser(parser_config_string):
|
def test_serialize_parser(parser_config_string):
|
||||||
"""Create a non-default parser config to check nlp serializes it correctly"""
|
"""Create a non-default parser config to check nlp serializes it correctly"""
|
||||||
nlp = English()
|
nlp = English()
|
||||||
|
@ -306,11 +284,13 @@ def test_serialize_parser(parser_config_string):
|
||||||
nlp.to_disk(d)
|
nlp.to_disk(d)
|
||||||
nlp2 = spacy.load(d)
|
nlp2 = spacy.load(d)
|
||||||
model = nlp2.get_pipe("parser").model
|
model = nlp2.get_pipe("parser").model
|
||||||
model.get_ref("tok2vec")
|
assert model.get_ref("tok2vec") is not None
|
||||||
# check that we have the correct settings, not the default ones
|
assert model.has_param("hidden_W")
|
||||||
if model.attrs["has_upper"]:
|
assert model.has_param("hidden_b")
|
||||||
assert model.get_ref("upper").get_dim("nI") == 66
|
output = model.get_ref("output")
|
||||||
assert model.get_ref("lower").get_dim("nI") == 66
|
assert output is not None
|
||||||
|
assert output.has_param("b")
|
||||||
|
assert output.has_param("W")
|
||||||
|
|
||||||
|
|
||||||
def test_config_nlp_roundtrip():
|
def test_config_nlp_roundtrip():
|
||||||
|
@ -457,9 +437,7 @@ def test_config_auto_fill_extra_fields():
|
||||||
load_model_from_config(nlp.config)
|
load_model_from_config(nlp.config)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize("parser_config_string", [parser_config_string_upper])
|
||||||
"parser_config_string", [parser_config_string_upper, parser_config_string_no_upper]
|
|
||||||
)
|
|
||||||
def test_config_validate_literal(parser_config_string):
|
def test_config_validate_literal(parser_config_string):
|
||||||
nlp = English()
|
nlp = English()
|
||||||
config = Config().from_str(parser_config_string)
|
config = Config().from_str(parser_config_string)
|
||||||
|
|
|
@ -5,10 +5,8 @@ from pathlib import Path
|
||||||
from spacy.about import __version__ as spacy_version
|
from spacy.about import __version__ as spacy_version
|
||||||
from spacy import util
|
from spacy import util
|
||||||
from spacy import prefer_gpu, require_gpu, require_cpu
|
from spacy import prefer_gpu, require_gpu, require_cpu
|
||||||
from spacy.ml._precomputable_affine import PrecomputableAffine
|
from spacy.util import dot_to_object, SimpleFrozenList, import_file, to_ternary_int
|
||||||
from spacy.ml._precomputable_affine import _backprop_precomputable_affine_padding
|
from spacy.util import find_available_port
|
||||||
from spacy.util import dot_to_object, SimpleFrozenList, import_file
|
|
||||||
from spacy.util import to_ternary_int, find_available_port
|
|
||||||
from thinc.api import Config, Optimizer, ConfigValidationError
|
from thinc.api import Config, Optimizer, ConfigValidationError
|
||||||
from thinc.api import get_current_ops, set_current_ops, NumpyOps, CupyOps, MPSOps
|
from thinc.api import get_current_ops, set_current_ops, NumpyOps, CupyOps, MPSOps
|
||||||
from thinc.compat import has_cupy_gpu, has_torch_mps_gpu
|
from thinc.compat import has_cupy_gpu, has_torch_mps_gpu
|
||||||
|
@ -81,34 +79,6 @@ def test_util_get_package_path(package):
|
||||||
assert isinstance(path, Path)
|
assert isinstance(path, Path)
|
||||||
|
|
||||||
|
|
||||||
def test_PrecomputableAffine(nO=4, nI=5, nF=3, nP=2):
|
|
||||||
model = PrecomputableAffine(nO=nO, nI=nI, nF=nF, nP=nP).initialize()
|
|
||||||
assert model.get_param("W").shape == (nF, nO, nP, nI)
|
|
||||||
tensor = model.ops.alloc((10, nI))
|
|
||||||
Y, get_dX = model.begin_update(tensor)
|
|
||||||
assert Y.shape == (tensor.shape[0] + 1, nF, nO, nP)
|
|
||||||
dY = model.ops.alloc((15, nO, nP))
|
|
||||||
ids = model.ops.alloc((15, nF))
|
|
||||||
ids[1, 2] = -1
|
|
||||||
dY[1] = 1
|
|
||||||
assert not model.has_grad("pad")
|
|
||||||
d_pad = _backprop_precomputable_affine_padding(model, dY, ids)
|
|
||||||
assert d_pad[0, 2, 0, 0] == 1.0
|
|
||||||
ids.fill(0.0)
|
|
||||||
dY.fill(0.0)
|
|
||||||
dY[0] = 0
|
|
||||||
ids[1, 2] = 0
|
|
||||||
ids[1, 1] = -1
|
|
||||||
ids[1, 0] = -1
|
|
||||||
dY[1] = 1
|
|
||||||
ids[2, 0] = -1
|
|
||||||
dY[2] = 5
|
|
||||||
d_pad = _backprop_precomputable_affine_padding(model, dY, ids)
|
|
||||||
assert d_pad[0, 0, 0, 0] == 6
|
|
||||||
assert d_pad[0, 1, 0, 0] == 1
|
|
||||||
assert d_pad[0, 2, 0, 0] == 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_prefer_gpu():
|
def test_prefer_gpu():
|
||||||
current_ops = get_current_ops()
|
current_ops = get_current_ops()
|
||||||
if has_cupy_gpu:
|
if has_cupy_gpu:
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
from collections.abc import Iterable as IterableInstance
|
from collections.abc import Iterable as IterableInstance
|
||||||
import warnings
|
|
||||||
import numpy
|
import numpy
|
||||||
from murmurhash.mrmr cimport hash64
|
from murmurhash.mrmr cimport hash64
|
||||||
|
|
||||||
|
|
|
@ -553,18 +553,17 @@ for a Tok2Vec layer.
|
||||||
|
|
||||||
## Parser & NER architectures {id="parser"}
|
## Parser & NER architectures {id="parser"}
|
||||||
|
|
||||||
### spacy.TransitionBasedParser.v2 {id="TransitionBasedParser",source="spacy/ml/models/parser.py"}
|
### spacy.TransitionBasedParser.v3 {id="TransitionBasedParser",source="spacy/ml/models/parser.py"}
|
||||||
|
|
||||||
> #### Example Config
|
> #### Example Config
|
||||||
>
|
>
|
||||||
> ```ini
|
> ```ini
|
||||||
> [model]
|
> [model]
|
||||||
> @architectures = "spacy.TransitionBasedParser.v2"
|
> @architectures = "spacy.TransitionBasedParser.v3"
|
||||||
> state_type = "ner"
|
> state_type = "ner"
|
||||||
> extra_state_tokens = false
|
> extra_state_tokens = false
|
||||||
> hidden_width = 64
|
> hidden_width = 64
|
||||||
> maxout_pieces = 2
|
> maxout_pieces = 2
|
||||||
> use_upper = true
|
|
||||||
>
|
>
|
||||||
> [model.tok2vec]
|
> [model.tok2vec]
|
||||||
> @architectures = "spacy.HashEmbedCNN.v2"
|
> @architectures = "spacy.HashEmbedCNN.v2"
|
||||||
|
@ -594,23 +593,22 @@ consists of either two or three subnetworks:
|
||||||
state representation. If not present, the output from the lower model is used
|
state representation. If not present, the output from the lower model is used
|
||||||
as action scores directly.
|
as action scores directly.
|
||||||
|
|
||||||
| Name | Description |
|
| Name | Description |
|
||||||
| -------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
| -------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
| `tok2vec` | Subnetwork to map tokens into vector representations. ~~Model[List[Doc], List[Floats2d]]~~ |
|
| `tok2vec` | Subnetwork to map tokens into vector representations. ~~Model[List[Doc], List[Floats2d]]~~ |
|
||||||
| `state_type` | Which task to extract features for. Possible values are "ner" and "parser". ~~str~~ |
|
| `state_type` | Which task to extract features for. Possible values are "ner" and "parser". ~~str~~ |
|
||||||
| `extra_state_tokens` | Whether to use an expanded feature set when extracting the state tokens. Slightly slower, but sometimes improves accuracy slightly. Defaults to `False`. ~~bool~~ |
|
| `extra_state_tokens` | Whether to use an expanded feature set when extracting the state tokens. Slightly slower, but sometimes improves accuracy slightly. Defaults to `False`. ~~bool~~ |
|
||||||
| `hidden_width` | The width of the hidden layer. ~~int~~ |
|
| `hidden_width` | The width of the hidden layer. ~~int~~ |
|
||||||
| `maxout_pieces` | How many pieces to use in the state prediction layer. Recommended values are `1`, `2` or `3`. If `1`, the maxout non-linearity is replaced with a [`Relu`](https://thinc.ai/docs/api-layers#relu) non-linearity if `use_upper` is `True`, and no non-linearity if `False`. ~~int~~ |
|
| `maxout_pieces` | How many pieces to use in the state prediction layer. Recommended values are `1`, `2` or `3`. ~~int~~ |
|
||||||
| `use_upper` | Whether to use an additional hidden layer after the state vector in order to predict the action scores. It is recommended to set this to `False` for large pretrained models such as transformers, and `True` for smaller networks. The upper layer is computed on CPU, which becomes a bottleneck on larger GPU-based models, where it's also less necessary. ~~bool~~ |
|
| `nO` | The number of actions the model will predict between. Usually inferred from data at the beginning of training, or loaded from disk. ~~int~~ |
|
||||||
| `nO` | The number of actions the model will predict between. Usually inferred from data at the beginning of training, or loaded from disk. ~~int~~ |
|
| **CREATES** | The model using the architecture. ~~Model[List[Docs], List[List[Floats2d]]]~~ |
|
||||||
| **CREATES** | The model using the architecture. ~~Model[List[Docs], List[List[Floats2d]]]~~ |
|
|
||||||
|
|
||||||
<Accordion title="spacy.TransitionBasedParser.v1 definition" spaced>
|
<Accordion title="spacy.TransitionBasedParser.v1 definition" spaced>
|
||||||
|
|
||||||
[TransitionBasedParser.v1](/api/legacy#TransitionBasedParser_v1) had the exact
|
[TransitionBasedParser.v1](/api/legacy#TransitionBasedParser_v1) had the exact
|
||||||
same signature, but the `use_upper` argument was `True` by default.
|
same signature, but the `use_upper` argument was `True` by default.
|
||||||
|
|
||||||
</Accordion>
|
</Accordion>
|
||||||
|
|
||||||
## Tagging architectures {id="tagger",source="spacy/ml/models/tagger.py"}
|
## Tagging architectures {id="tagger",source="spacy/ml/models/tagger.py"}
|
||||||
|
|
||||||
|
|
|
@ -361,7 +361,7 @@ Module spacy.language
|
||||||
File /path/to/spacy/language.py (line 64)
|
File /path/to/spacy/language.py (line 64)
|
||||||
ℹ [components.ner.model]
|
ℹ [components.ner.model]
|
||||||
Registry @architectures
|
Registry @architectures
|
||||||
Name spacy.TransitionBasedParser.v1
|
Name spacy.TransitionBasedParser.v3
|
||||||
Module spacy.ml.models.parser
|
Module spacy.ml.models.parser
|
||||||
File /path/to/spacy/ml/models/parser.py (line 11)
|
File /path/to/spacy/ml/models/parser.py (line 11)
|
||||||
ℹ [components.ner.model.tok2vec]
|
ℹ [components.ner.model.tok2vec]
|
||||||
|
@ -371,7 +371,7 @@ Module spacy.ml.models.tok2vec
|
||||||
File /path/to/spacy/ml/models/tok2vec.py (line 16)
|
File /path/to/spacy/ml/models/tok2vec.py (line 16)
|
||||||
ℹ [components.parser.model]
|
ℹ [components.parser.model]
|
||||||
Registry @architectures
|
Registry @architectures
|
||||||
Name spacy.TransitionBasedParser.v1
|
Name spacy.TransitionBasedParser.v3
|
||||||
Module spacy.ml.models.parser
|
Module spacy.ml.models.parser
|
||||||
File /path/to/spacy/ml/models/parser.py (line 11)
|
File /path/to/spacy/ml/models/parser.py (line 11)
|
||||||
ℹ [components.parser.model.tok2vec]
|
ℹ [components.parser.model.tok2vec]
|
||||||
|
@ -696,7 +696,7 @@ scorer = {"@scorers":"spacy.ner_scorer.v1"}
|
||||||
update_with_oracle_cut_size = 100
|
update_with_oracle_cut_size = 100
|
||||||
|
|
||||||
[components.ner.model]
|
[components.ner.model]
|
||||||
@architectures = "spacy.TransitionBasedParser.v2"
|
@architectures = "spacy.TransitionBasedParser.v3"
|
||||||
state_type = "ner"
|
state_type = "ner"
|
||||||
extra_state_tokens = false
|
extra_state_tokens = false
|
||||||
- hidden_width = 64
|
- hidden_width = 64
|
||||||
|
@ -719,7 +719,7 @@ scorer = {"@scorers":"spacy.parser_scorer.v1"}
|
||||||
update_with_oracle_cut_size = 100
|
update_with_oracle_cut_size = 100
|
||||||
|
|
||||||
[components.parser.model]
|
[components.parser.model]
|
||||||
@architectures = "spacy.TransitionBasedParser.v2"
|
@architectures = "spacy.TransitionBasedParser.v3"
|
||||||
state_type = "parser"
|
state_type = "parser"
|
||||||
extra_state_tokens = false
|
extra_state_tokens = false
|
||||||
hidden_width = 128
|
hidden_width = 128
|
||||||
|
|
|
@ -225,7 +225,7 @@ the others, but may not be as accurate, especially if texts are short.
|
||||||
### spacy.TransitionBasedParser.v1 {id="TransitionBasedParser_v1"}
|
### spacy.TransitionBasedParser.v1 {id="TransitionBasedParser_v1"}
|
||||||
|
|
||||||
Identical to
|
Identical to
|
||||||
[`spacy.TransitionBasedParser.v2`](/api/architectures#TransitionBasedParser)
|
[`spacy.TransitionBasedParser.v3`](/api/architectures#TransitionBasedParser)
|
||||||
except the `use_upper` was set to `True` by default.
|
except the `use_upper` was set to `True` by default.
|
||||||
|
|
||||||
## Layers {id="layers"}
|
## Layers {id="layers"}
|
||||||
|
|
|
@ -140,7 +140,7 @@ factory = "tok2vec"
|
||||||
factory = "ner"
|
factory = "ner"
|
||||||
|
|
||||||
[components.ner.model]
|
[components.ner.model]
|
||||||
@architectures = "spacy.TransitionBasedParser.v1"
|
@architectures = "spacy.TransitionBasedParser.v3"
|
||||||
|
|
||||||
[components.ner.model.tok2vec]
|
[components.ner.model.tok2vec]
|
||||||
@architectures = "spacy.Tok2VecListener.v1"
|
@architectures = "spacy.Tok2VecListener.v1"
|
||||||
|
@ -156,7 +156,7 @@ same. This makes them fully independent and doesn't require an upstream
|
||||||
factory = "ner"
|
factory = "ner"
|
||||||
|
|
||||||
[components.ner.model]
|
[components.ner.model]
|
||||||
@architectures = "spacy.TransitionBasedParser.v1"
|
@architectures = "spacy.TransitionBasedParser.v3"
|
||||||
|
|
||||||
[components.ner.model.tok2vec]
|
[components.ner.model.tok2vec]
|
||||||
@architectures = "spacy.Tok2Vec.v2"
|
@architectures = "spacy.Tok2Vec.v2"
|
||||||
|
@ -472,7 +472,7 @@ sneakily delegates to the `Transformer` pipeline component.
|
||||||
factory = "ner"
|
factory = "ner"
|
||||||
|
|
||||||
[nlp.pipeline.ner.model]
|
[nlp.pipeline.ner.model]
|
||||||
@architectures = "spacy.TransitionBasedParser.v1"
|
@architectures = "spacy.TransitionBasedParser.v3"
|
||||||
state_type = "ner"
|
state_type = "ner"
|
||||||
extra_state_tokens = false
|
extra_state_tokens = false
|
||||||
hidden_width = 128
|
hidden_width = 128
|
||||||
|
|
Loading…
Reference in New Issue
Block a user