The coref model can now be loaded.
The span predictor component is initialized but not used yet. The plan is to work on it after the word-level clustering part is trainable end-to-end.
parent 35cc2b138f
commit c4f9c24738
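
For orientation, here is a minimal sketch (not spaCy's actual code) of the wiring this commit sets up: a PyTorch scoring module wrapped with thinc's PyTorchWrapper and chained behind a tok2vec layer so the whole pipeline is one thinc Model. TinyScorer and make_coref_model are hypothetical stand-ins for CorefScorer and build_wl_coref_model.

# Hypothetical illustration of the wrap-and-chain pattern used in this commit.
import torch
from thinc.api import Model, PyTorchWrapper, chain


class TinyScorer(torch.nn.Module):
    # Stand-in for CorefScorer: scores each incoming word representation.
    def __init__(self, width: int):
        super().__init__()
        self.linear = torch.nn.Linear(width, 1)

    def forward(self, X):
        return self.linear(X)


def make_coref_model(tok2vec: Model, width: int) -> Model:
    # Wrap the torch module so thinc can drive it; the real code also passes
    # convert_inputs/convert_outputs callbacks to translate between thinc
    # arrays and torch tensors.
    scorer = PyTorchWrapper(TinyScorer(width))
    with Model.define_operators({">>": chain}):
        # ">>" is chain: tok2vec feeds its output into the wrapped scorer,
        # mirroring `coref_model = tok2vec >> coref_scorer` in the diff below.
        return tok2vec >> scorer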
@@ -4,7 +4,8 @@ import warnings
 from thinc.api import Model, Linear, Relu, Dropout
 from thinc.api import chain, noop, Embed, add, tuplify, concatenate
 from thinc.api import reduce_first, reduce_last, reduce_mean
-from thinc.types import Floats2d, Floats1d, Ints2d, Ragged
+from thinc.api import PyTorchWrapper
+from thinc.types import Floats2d, Floats1d, Ints1d, Ints2d, Ragged
 from typing import List, Callable, Tuple, Any
 from ...tokens import Doc
 from ...util import registry
@@ -456,7 +457,7 @@ from typing import List, Tuple
 import torch

 # TODO rename this to coref_util
-import .coref_util_wl as utils
+from .coref_util_wl import add_dummy

 # TODO rename to plain coref
 @registry.architectures("spacy.WLCoref.v1")
@@ -478,20 +479,23 @@ def build_wl_coref_model(
     with Model.define_operators({">>": chain}):
         # TODO chain tok2vec with these models
         # TODO fix device - should be automatic
-        device = "gpu:0"
+        device = "cuda:0"
         coref_scorer = PyTorchWrapper(
             CorefScorer(
                 device,
                 embedding_size,
                 hidden_size,
                 n_hidden_layers,
-                dropout_rate,
+                dropout,
                 rough_k,
                 a_scoring_batch_size
             ),
             convert_inputs=convert_coref_scorer_inputs,
             convert_outputs=convert_coref_scorer_outputs
         )
+
+        coref_model = tok2vec >> coref_scorer
+        # XXX just ignore this until the coref scorer is integrated
         span_predictor = PyTorchWrapper(
             SpanPredictor(
                 # TODO this was hardcoded to 1024, check
@@ -499,12 +503,13 @@ def build_wl_coref_model(
                 sp_embedding_size,
                 device
             ),
+
             convert_inputs=convert_span_predictor_inputs
         )
     # TODO combine models so output is uniform (just one forward pass)
     # It may be reasonable to have an option to disable span prediction,
     # and just return words as spans.
-    return coref_scorer
+    return coref_model

 def convert_coref_scorer_inputs(
     model: Model,
@@ -534,6 +539,17 @@ def convert_coref_scorer_outputs(
     indices_xp = torch2xp(indices)
     return (scores_xp, indices_xp), convert_for_torch_backward

+def convert_span_predictor_inputs(
+    model: Model,
+    X: Tuple[Ints1d, Floats2d, Ints1d],
+    is_train: bool
+):
+    sent_id = xp2torch(X[0], requires_grad=False)
+    word_features = xp2torch(X[1], requires_grad=False)
+    head_ids = xp2torch(X[2], requires_grad=False)
+    argskwargs = ArgsKwargs(args=(sent_id, word_features, head_ids), kwargs={})
+    return argskwargs, lambda dX: []
+
 # TODO This probably belongs in the component, not the model.
 def predict_span_clusters(span_predictor: Model,
                           sent_ids: Ints1d,
@@ -747,7 +763,7 @@ class AnaphoricityScorer(torch.nn.Module):

         # [batch_size, n_ants]
         scores = top_rough_scores_batch + self._ffnn(pair_matrix)
-        scores = utils.add_dummy(scores, eps=True)
+        scores = add_dummy(scores, eps=True)

         return scores

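
A note on the conversion callback added above: a thinc convert_inputs function receives the wrapped Model, the layer's input, and an is_train flag, and must return an ArgsKwargs of torch tensors plus a callback that maps gradients back to thinc. A generic sketch of that shape (the name convert_inputs_sketch is made up for illustration):

from thinc.api import ArgsKwargs, Model, xp2torch


def convert_inputs_sketch(model: Model, X, is_train: bool):
    # Move each numpy/cupy array onto torch; these inputs carry no gradient.
    tensors = [xp2torch(x, requires_grad=False) for x in X]
    # args become the positional arguments of the torch module's forward();
    # the second return value is the backprop callback, which returns an
    # empty list here, mirroring `lambda dX: []` in the diff.
    return ArgsKwargs(args=tuple(tensors), kwargs={}), lambda dX: []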