mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-18 20:22:25 +03:00
The coref model is able to be loaded
The span predictor component is initialized but not used at all now. Plan is to work on it after the word level clustering part is trainable end-to-end.
This commit is contained in:
parent
35cc2b138f
commit
c4f9c24738
|
@ -4,7 +4,8 @@ import warnings
|
|||
from thinc.api import Model, Linear, Relu, Dropout
|
||||
from thinc.api import chain, noop, Embed, add, tuplify, concatenate
|
||||
from thinc.api import reduce_first, reduce_last, reduce_mean
|
||||
from thinc.types import Floats2d, Floats1d, Ints2d, Ragged
|
||||
from thinc.api import PyTorchWrapper
|
||||
from thinc.types import Floats2d, Floats1d, Ints1d, Ints2d, Ragged
|
||||
from typing import List, Callable, Tuple, Any
|
||||
from ...tokens import Doc
|
||||
from ...util import registry
|
||||
|
@ -456,7 +457,7 @@ from typing import List, Tuple
|
|||
import torch
|
||||
|
||||
# TODO rename this to coref_util
|
||||
import .coref_util_wl as utils
|
||||
from .coref_util_wl import add_dummy
|
||||
|
||||
# TODO rename to plain coref
|
||||
@registry.architectures("spacy.WLCoref.v1")
|
||||
|
@ -478,20 +479,23 @@ def build_wl_coref_model(
|
|||
with Model.define_operators({">>": chain}):
|
||||
# TODO chain tok2vec with these models
|
||||
# TODO fix device - should be automatic
|
||||
device = "gpu:0"
|
||||
device = "cuda:0"
|
||||
coref_scorer = PyTorchWrapper(
|
||||
CorefScorer(
|
||||
device,
|
||||
embedding_size,
|
||||
hidden_size,
|
||||
n_hidden_layers,
|
||||
dropout_rate,
|
||||
dropout,
|
||||
rough_k,
|
||||
a_scoring_batch_size
|
||||
),
|
||||
convert_inputs=convert_coref_scorer_inputs,
|
||||
convert_outputs=convert_coref_scorer_outputs
|
||||
)
|
||||
|
||||
coref_model = tok2vec >> coref_scorer
|
||||
# XXX just ignore this until the coref scorer is integrated
|
||||
span_predictor = PyTorchWrapper(
|
||||
SpanPredictor(
|
||||
# TODO this was hardcoded to 1024, check
|
||||
|
@ -499,12 +503,13 @@ def build_wl_coref_model(
|
|||
sp_embedding_size,
|
||||
device
|
||||
),
|
||||
|
||||
convert_inputs=convert_span_predictor_inputs
|
||||
)
|
||||
# TODO combine models so output is uniform (just one forward pass)
|
||||
# It may be reasonable to have an option to disable span prediction,
|
||||
# and just return words as spans.
|
||||
return coref_scorer
|
||||
return coref_model
|
||||
|
||||
def convert_coref_scorer_inputs(
|
||||
model: Model,
|
||||
|
@ -534,6 +539,17 @@ def convert_coref_scorer_outputs(
|
|||
indices_xp = torch2xp(indices)
|
||||
return (scores_xp, indices_xp), convert_for_torch_backward
|
||||
|
||||
def convert_span_predictor_inputs(
|
||||
model: Model,
|
||||
X: Tuple[Ints1d, Floats2d, Ints1d],
|
||||
is_train: bool
|
||||
):
|
||||
sent_id = xp2torch(X[0], requires_grad=False)
|
||||
word_features = xp2torch(X[1], requires_grad=False)
|
||||
head_ids = xp2torch(X[2], requires_grad=False)
|
||||
argskwargs = ArgsKwargs(args=(sent_id, word_features, head_ids), kwargs={})
|
||||
return argskwargs, lambda dX: []
|
||||
|
||||
# TODO This probably belongs in the component, not the model.
|
||||
def predict_span_clusters(span_predictor: Model,
|
||||
sent_ids: Ints1d,
|
||||
|
@ -747,7 +763,7 @@ class AnaphoricityScorer(torch.nn.Module):
|
|||
|
||||
# [batch_size, n_ants]
|
||||
scores = top_rough_scores_batch + self._ffnn(pair_matrix)
|
||||
scores = utils.add_dummy(scores, eps=True)
|
||||
scores = add_dummy(scores, eps=True)
|
||||
|
||||
return scores
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user