Mirror of https://github.com/explosion/spaCy.git (synced 2025-07-18 04:02:20 +03:00)

Commit 4d032396b8: Merge branch 'feature/coref' into coref/dimension-inference
@@ -934,6 +934,7 @@ class Errors(metaclass=ErrorsWithCodes):
     E1041 = ("Expected a string, Doc, or bytes as input, but got: {type}")
     E1042 = ("Function was called with `{arg1}`={arg1_values} and "
              "`{arg2}`={arg2_values} but these arguments are conflicting.")
+    E1043 = ("Misalignment in coref. Head token has no match in training doc.")


 # Deprecated model shortcuts, only used in errors and warnings
@@ -1,8 +1,8 @@
-from typing import List, Tuple
+from typing import List, Tuple, Callable, cast

 from thinc.api import Model, chain, get_width
 from thinc.api import PyTorchWrapper, ArgsKwargs
-from thinc.types import Floats2d
+from thinc.types import Floats2d, Ints2d
 from thinc.util import torch, xp2torch, torch2xp

 from ...tokens import Doc

@@ -22,10 +22,8 @@ def build_wl_coref_model(
     # pairs to keep per mention after rough scoring
     antecedent_limit: int = 50,
     antecedent_batch_size: int = 512,
-):
-    # TODO add model return types
-
+) -> Model[List[Doc], Tuple[Floats2d, Ints2d]]:
     nI = None

     with Model.define_operators({">>": chain}):
        coref_clusterer = Model(
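Aside on the new return annotation: the clusterer now advertises a typed output of
(rough scores, antecedent indices). A minimal sketch of what that type means in thinc
terms (the alias name here is illustrative, not from the diff):

    from typing import List, Tuple

    from thinc.api import Model
    from thinc.types import Floats2d, Ints2d
    from spacy.tokens import Doc

    # scores: one row per mention; indices: the candidate antecedents kept
    # after rough scoring (two 2d arrays with aligned rows)
    CorefClusterer = Model[List[Doc], Tuple[Floats2d, Ints2d]]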
@@ -83,7 +81,6 @@ def coref_init(model: Model, X=None, Y=None):
 def coref_forward(model: Model, X, is_train: bool):
     return model.layers[0](X, is_train)


 def convert_coref_clusterer_inputs(model: Model, X: List[Floats2d], is_train: bool):
     # The input here is List[Floats2d], one for each doc
     # just use the first

@@ -91,16 +88,17 @@ def convert_coref_clusterer_inputs(model: Model, X: List[Floats2d], is_train: bo
     X = X[0]
     word_features = xp2torch(X, requires_grad=is_train)

-    # TODO fix or remove type annotations
-    def backprop(args: ArgsKwargs):  # -> List[Floats2d]:
+    def backprop(args: ArgsKwargs) -> List[Floats2d]:
         # convert to xp and wrap in list
-        gradients = torch2xp(args.args[0])
+        gradients = cast(Floats2d, torch2xp(args.args[0]))
         return [gradients]

     return ArgsKwargs(args=(word_features,), kwargs={}), backprop


-def convert_coref_clusterer_outputs(model: Model, inputs_outputs, is_train: bool):
+def convert_coref_clusterer_outputs(
+    model: Model, inputs_outputs, is_train: bool
+) -> Tuple[Tuple[Floats2d, Ints2d], Callable]:
     _, outputs = inputs_outputs
     scores, indices = outputs

@@ -111,8 +109,8 @@ def convert_coref_clusterer_outputs(model: Model, inputs_outputs, is_train: bool
         kwargs={"grad_tensors": [dY_t]},
     )

-    scores_xp = torch2xp(scores)
-    indices_xp = torch2xp(indices)
+    scores_xp = cast(Floats2d, torch2xp(scores))
+    indices_xp = cast(Ints2d, torch2xp(indices))
     return (scores_xp, indices_xp), convert_for_torch_backward
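Aside on the casts: thinc's torch2xp is annotated to return a generic ArrayXd, so the
diff narrows it with typing.cast, which is a no-op at runtime. A minimal sketch
(assumes torch is installed; the array values are placeholders):

    from typing import cast

    import numpy
    from thinc.types import Floats2d
    from thinc.util import torch2xp, xp2torch

    x = numpy.zeros((4, 8), dtype="f")
    t = xp2torch(x)                    # numpy/cupy array -> torch.Tensor
    y = cast(Floats2d, torch2xp(t))    # ArrayXd -> Floats2d, type checker only
    assert y.shape == (4, 8)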
@@ -143,16 +143,18 @@ def create_head_span_idxs(ops, doclen: int):


 def get_clusters_from_doc(doc) -> List[List[Tuple[int, int]]]:
-    """Given a Doc, convert the cluster spans to simple int tuple lists."""
+    """Convert the span clusters in a Doc to simple integer tuple lists. The
+    ints are char spans, to be tokenization independent.
+    """
     out = []
     for key, val in doc.spans.items():
         cluster = []
         for span in val:
-            # TODO check that there isn't an off-by-one error here
-            # cluster.append((span.start, span.end))
-            # TODO This conversion should be happening earlier in processing
             head_i = span.root.i
-            cluster.append((head_i, head_i + 1))
+            head = doc[head_i]
+            char_span = (head.idx, head.idx + len(head))
+            cluster.append(char_span)

         # don't want duplicates
         cluster = list(set(cluster))
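Aside on the switch to char offsets: (token.idx, token.idx + len(token)) survives
retokenization, while (token.i, token.i + 1) does not. A hedged sketch with a blank
pipeline (the sentence is illustrative):

    import spacy

    nlp = spacy.blank("en")
    doc = nlp("John Smith picked up the ball")
    head = doc[1]  # "Smith"
    char_span = (head.idx, head.idx + len(head))
    assert doc.text[char_span[0] : char_span[1]] == "Smith"
    # char offsets resolve back to a Span as long as they hit token boundaries
    assert doc.char_span(*char_span).text == "Smith"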
@@ -1,4 +1,4 @@
-from typing import List, Tuple
+from typing import List, Tuple, cast

 from thinc.api import Model, chain, tuplify, get_width
 from thinc.api import PyTorchWrapper, ArgsKwargs

@@ -76,15 +76,17 @@ def span_predictor_forward(model: Model, X, is_train: bool):
     return model.layers[0](X, is_train)


 def convert_span_predictor_inputs(
-    model: Model, X: Tuple[List[Floats2d], Tuple[List[Ints1d], List[Ints1d]]], is_train: bool
+    model: Model,
+    X: Tuple[List[Floats2d], Tuple[List[Ints1d], List[Ints1d]]],
+    is_train: bool,
 ):
     tok2vec, (sent_ids, head_ids) = X
     # Normally we should use the input is_train, but for these two it's not relevant
     # TODO fix the type here, or remove it
-    def backprop(args: ArgsKwargs):  # -> Tuple[List[Floats2d], None]:
-        gradients = torch2xp(args.args[1])
+    def backprop(args: ArgsKwargs) -> Tuple[List[Floats2d], None]:
+        gradients = cast(Floats2d, torch2xp(args.args[1]))
         # The sent_ids and head_ids are None because no gradients
-        return [[gradients], None]
+        return ([gradients], None)

     word_features = xp2torch(tok2vec[0], requires_grad=is_train)
     sent_ids_tensor = xp2torch(sent_ids[0], requires_grad=False)

@@ -129,7 +131,6 @@ def predict_span_clusters(


 def build_get_head_metadata(prefix):
-    # TODO this name is awful, fix it
     model = Model(
         "HeadDataProvider", attrs={"prefix": prefix}, forward=head_data_forward
     )

@@ -175,7 +176,6 @@ class SpanPredictor(torch.nn.Module):
            raise ValueError("max_distance has to be an even number")
        # input size = single token size
        # 64 = probably distance emb size
-       # TODO check that dist_emb_size use is correct
        self.ffnn = torch.nn.Sequential(
            torch.nn.Linear(input_size * 2 + dist_emb_size, hidden_size),
            torch.nn.ReLU(),

@@ -192,7 +192,6 @@ class SpanPredictor(torch.nn.Module):
            torch.nn.Conv1d(dist_emb_size, conv_channels, kernel_size, 1, 1),
            torch.nn.Conv1d(conv_channels, 2, kernel_size, 1, 1),
        )
-       # TODO make embeddings size a parameter
        self.max_distance = max_distance
        # handle distances between +-(max_distance - 2 / 2)
        self.emb = torch.nn.Embedding(max_distance, dist_emb_size)

@@ -244,9 +243,7 @@ class SpanPredictor(torch.nn.Module):
            dim=1,
        )
        lengths = same_sent.sum(dim=1)
-       padding_mask = torch.arange(
-           0, lengths.max().item(), device=device
-       ).unsqueeze(0)
+       padding_mask = torch.arange(0, lengths.max().item(), device=device).unsqueeze(0)
        # (n_heads x max_sent_len)
        padding_mask = padding_mask < lengths.unsqueeze(1)
        # (n_heads x max_sent_len x input_size * 2 + distance_emb_size)
@@ -95,7 +95,7 @@ def make_coref(
 class CoreferenceResolver(TrainablePipe):
     """Pipeline component for coreference resolution.

-    DOCS: https://spacy.io/api/coref (TODO)
+    DOCS: https://spacy.io/api/coref
     """

     def __init__(

@@ -118,8 +118,10 @@ class CoreferenceResolver(TrainablePipe):
             are stored in.
         span_cluster_prefix (str): Prefix for the key in doc.spans to store the
             coref clusters in.
+        scorer (Optional[Callable]): The scoring method. Defaults to
+            Scorer.score_coref_clusters.

-        DOCS: https://spacy.io/api/coref#init (TODO)
+        DOCS: https://spacy.io/api/coref#init
         """
         self.vocab = vocab
         self.model = model

@@ -133,11 +135,12 @@ class CoreferenceResolver(TrainablePipe):

     def predict(self, docs: Iterable[Doc]) -> List[MentionClusters]:
         """Apply the pipeline's model to a batch of docs, without modifying them.
+        Return the list of predicted clusters.

         docs (Iterable[Doc]): The documents to predict.
-        RETURNS: The models prediction for each document.
+        RETURNS (List[MentionClusters]): The model's prediction for each document.

-        DOCS: https://spacy.io/api/coref#predict (TODO)
+        DOCS: https://spacy.io/api/coref#predict
         """
         out = []
         for doc in docs:

@@ -163,7 +166,7 @@ class CoreferenceResolver(TrainablePipe):
         docs (Iterable[Doc]): The documents to modify.
         clusters: The span clusters, produced by CoreferenceResolver.predict.

-        DOCS: https://spacy.io/api/coref#set_annotations (TODO)
+        DOCS: https://spacy.io/api/coref#set_annotations
         """
         docs = list(docs)
         if len(docs) != len(clusters_by_doc):

@@ -204,7 +207,7 @@ class CoreferenceResolver(TrainablePipe):
             Updated using the component name as the key.
         RETURNS (Dict[str, float]): The updated losses dictionary.

-        DOCS: https://spacy.io/api/coref#update (TODO)
+        DOCS: https://spacy.io/api/coref#update
         """
         if losses is None:
             losses = {}

@@ -218,12 +221,17 @@ class CoreferenceResolver(TrainablePipe):
         total_loss = 0

         for eg in examples:
-            # TODO check this causes no issues (in practice it runs)
+            if eg.x.text != eg.y.text:
+                # TODO assign error number
+                raise ValueError(
+                    """Text, including whitespace, must match between reference and
+                    predicted docs in coref training.
+                    """
+                )
             preds, backprop = self.model.begin_update([eg.predicted])
             score_matrix, mention_idx = preds
             loss, d_scores = self.get_loss([eg], score_matrix, mention_idx)
             total_loss += loss
-            # TODO check shape here
             backprop((d_scores, mention_idx))

         if sgd is not None:
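Aside on the new guard: in a spaCy Example, x is the predicted doc and y the reference
doc, so comparing their .text catches whitespace drift before training. A hedged sketch:

    from spacy.lang.en import English
    from spacy.training import Example

    nlp = English()
    eg = Example(nlp.make_doc(" hi there"), nlp.make_doc("hi there"))
    assert eg.x.text != eg.y.text  # this pair would now raise in update()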
@@ -232,7 +240,12 @@ class CoreferenceResolver(TrainablePipe):
         return losses

     def rehearse(self, examples, *, sgd=None, losses=None, **config):
-        raise NotImplementedError
+        # TODO this should be added later
+        raise NotImplementedError(
+            Errors.E931.format(
+                parent="CoreferenceResolver", method="add_label", name=self.name
+            )
+        )

     def add_label(self, label: str) -> int:
         """Technically this method should be implemented from TrainablePipe,

@@ -257,7 +270,7 @@ class CoreferenceResolver(TrainablePipe):
         scores: Scores representing the model's predictions.
         RETURNS (Tuple[float, float]): The loss and the gradient.

-        DOCS: https://spacy.io/api/coref#get_loss (TODO)
+        DOCS: https://spacy.io/api/coref#get_loss
         """
         ops = self.model.ops
         xp = ops.xp
@@ -267,12 +280,23 @@ class CoreferenceResolver(TrainablePipe):
         example = list(examples)[0]
         cidx = mention_idx

-        clusters = get_clusters_from_doc(example.reference)
+        clusters_by_char = get_clusters_from_doc(example.reference)
+        # convert to token clusters, and give up if necessary
+        clusters = []
+        for cluster in clusters_by_char:
+            cc = []
+            for start_char, end_char in cluster:
+                span = example.predicted.char_span(start_char, end_char)
+                if span is None:
+                    # TODO log more details
+                    raise IndexError(Errors.E1043)
+                cc.append((span.start, span.end))
+            clusters.append(cc)

         span_idxs = create_head_span_idxs(ops, len(example.predicted))
         gscores = create_gold_scores(span_idxs, clusters)
-        # TODO fix type here. This is bools but asarray2f wants ints.
+        # Note on type here. This is bools but asarray2f wants ints.
         gscores = ops.asarray2f(gscores)  # type: ignore
-        # top_gscores = xp.take_along_axis(gscores, cidx, axis=1)
         top_gscores = xp.take_along_axis(gscores, mention_idx, axis=1)
         # now add the placeholder
         gold_placeholder = ~top_gscores.any(axis=1).T
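Aside on the failure mode E1043 guards against: char offsets that do not land on token
boundaries in the predicted doc make Doc.char_span return None. A hedged sketch:

    import spacy

    nlp = spacy.blank("en")
    doc = nlp("a New York pizza")
    assert doc.char_span(2, 10) is not None  # "New York" aligns with tokens
    assert doc.char_span(2, 8) is None       # "New Yo" ends mid-token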
@@ -304,7 +328,7 @@ class CoreferenceResolver(TrainablePipe):
             returns a representative sample of gold-standard Example objects.
         nlp (Language): The current nlp object the component is part of.

-        DOCS: https://spacy.io/api/coref#initialize (TODO)
+        DOCS: https://spacy.io/api/coref#initialize
         """
         validate_get_examples(get_examples, "CoreferenceResolver.initialize")
@@ -383,7 +383,7 @@ class EntityLinker(TrainablePipe):
             no prediction.

         docs (Iterable[Doc]): The documents to predict.
-        RETURNS (List[str]): The models prediction for each document.
+        RETURNS (List[str]): The model's prediction for each document.

         DOCS: https://spacy.io/api/entitylinker#predict
         """
@@ -29,7 +29,7 @@ distance_embedding_size = 64
 conv_channels = 4
 window_size = 1
 max_distance = 128
-prefix = coref_head_clusters
+prefix = "coref_head_clusters"

 [model.tok2vec]
 @architectures = "spacy.Tok2Vec.v2"

@@ -95,6 +95,8 @@ class SpanPredictor(TrainablePipe):
     """Pipeline component to resolve one-token spans to full spans.

     Used in coreference resolution.
+
+    DOCS: https://spacy.io/api/span_predictor
     """

     def __init__(

@@ -119,6 +121,14 @@ class SpanPredictor(TrainablePipe):
         }

     def predict(self, docs: Iterable[Doc]) -> List[MentionClusters]:
+        """Apply the pipeline's model to a batch of docs, without modifying them.
+        Return the list of predicted span clusters.
+
+        docs (Iterable[Doc]): The documents to predict.
+        RETURNS (List[MentionClusters]): The model's prediction for each document.
+
+        DOCS: https://spacy.io/api/span_predictor#predict
+        """
         # for now pretend there's just one doc

         out = []

@@ -151,6 +161,13 @@ class SpanPredictor(TrainablePipe):
         return out

     def set_annotations(self, docs: Iterable[Doc], clusters_by_doc) -> None:
+        """Modify a batch of Doc objects, using pre-computed scores.
+
+        docs (Iterable[Doc]): The documents to modify.
+        clusters: The span clusters, produced by SpanPredictor.predict.
+
+        DOCS: https://spacy.io/api/span_predictor#set_annotations
+        """
         for doc, clusters in zip(docs, clusters_by_doc):
             for ii, cluster in enumerate(clusters):
                 spans = [doc[mm[0] : mm[1]] for mm in cluster]

@@ -166,6 +183,15 @@ class SpanPredictor(TrainablePipe):
     ) -> Dict[str, float]:
         """Learn from a batch of documents and gold-standard information,
         updating the pipe's model. Delegates to predict and get_loss.
+
+        examples (Iterable[Example]): A batch of Example objects.
+        drop (float): The dropout rate.
+        sgd (thinc.api.Optimizer): The optimizer.
+        losses (Dict[str, float]): Optional record of the loss during training.
+            Updated using the component name as the key.
+        RETURNS (Dict[str, float]): The updated losses dictionary.
+
+        DOCS: https://spacy.io/api/span_predictor#update
         """
         if losses is None:
             losses = {}

@@ -178,6 +204,13 @@ class SpanPredictor(TrainablePipe):

         total_loss = 0
         for eg in examples:
+            if eg.x.text != eg.y.text:
+                # TODO assign error number
+                raise ValueError(
+                    """Text, including whitespace, must match between reference and
+                    predicted docs in span predictor training.
+                    """
+                )
             span_scores, backprop = self.model.begin_update([eg.predicted])
             # FIXME, this only happens once in the first 1000 docs of OntoNotes
             # and I'm not sure yet why.

@@ -222,6 +255,15 @@ class SpanPredictor(TrainablePipe):
         examples: Iterable[Example],
         span_scores: Floats3d,
     ):
+        """Find the loss and gradient of loss for the batch of documents and
+        their predicted scores.
+
+        examples (Iterable[Examples]): The batch of examples.
+        scores: Scores representing the model's predictions.
+        RETURNS (Tuple[float, float]): The loss and the gradient.
+
+        DOCS: https://spacy.io/api/span_predictor#get_loss
+        """
         ops = self.model.ops

         # NOTE This is doing fake batching, and should always get a list of one example
@@ -231,16 +273,29 @@ class SpanPredictor(TrainablePipe):
         for eg in examples:
             starts = []
             ends = []
+            keeps = []
+            sidx = 0
             for key, sg in eg.reference.spans.items():
                 if key.startswith(self.output_prefix):
-                    for mention in sg:
-                        starts.append(mention.start)
-                        ends.append(mention.end)
+                    for ii, mention in enumerate(sg):
+                        sidx += 1
+                        # convert to span in pred
+                        sch, ech = (mention.start_char, mention.end_char)
+                        span = eg.predicted.char_span(sch, ech)
+                        # TODO add to errors.py
+                        if span is None:
+                            warnings.warn("Could not align gold span in span predictor, skipping")
+                            continue
+                        starts.append(span.start)
+                        ends.append(span.end)
+                        keeps.append(sidx - 1)

             starts = self.model.ops.xp.asarray(starts)
             ends = self.model.ops.xp.asarray(ends)
-            start_scores = span_scores[:, :, 0]
-            end_scores = span_scores[:, :, 1]
+            start_scores = span_scores[:, :, 0][keeps]
+            end_scores = span_scores[:, :, 1][keeps]

             n_classes = start_scores.shape[1]
             start_probs = ops.softmax(start_scores, axis=1)
             end_probs = ops.softmax(end_scores, axis=1)
@@ -248,7 +303,14 @@ class SpanPredictor(TrainablePipe):
             end_targets = to_categorical(ends, n_classes)
             start_grads = start_probs - start_targets
             end_grads = end_probs - end_targets
-            grads = ops.xp.stack((start_grads, end_grads), axis=2)
+            # now return to original shape, with 0s
+            final_start_grads = ops.alloc2f(*span_scores[:, :, 0].shape)
+            final_start_grads[keeps] = start_grads
+            final_end_grads = ops.alloc2f(*final_start_grads.shape)
+            final_end_grads[keeps] = end_grads
+            # XXX Note this only works with fake batching
+            grads = ops.xp.stack((final_start_grads, final_end_grads), axis=2)

             loss = float((grads**2).sum())
             return loss, grads
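Aside on the gradient scatter: rows for gold spans that could not be aligned are
dropped via keeps, then written back as zeros so the gradient keeps the score
matrix's shape. A NumPy sketch with made-up shapes (not the diff's code):

    import numpy

    span_scores = numpy.random.rand(5, 7).astype("f")  # (n_spans, n_classes)
    keeps = [0, 2, 3]                                   # aligned gold rows only
    grads_kept = span_scores[keeps] * 0.1               # stand-in for real grads

    final_grads = numpy.zeros_like(span_scores)
    final_grads[keeps] = grads_kept                     # rows 1 and 4 stay zero
    assert final_grads[1].sum() == 0 and final_grads[4].sum() == 0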
@@ -258,6 +320,15 @@ class SpanPredictor(TrainablePipe):
         *,
         nlp: Optional[Language] = None,
     ) -> None:
+        """Initialize the pipe for training, using a representative set
+        of data examples.
+
+        get_examples (Callable[[], Iterable[Example]]): Function that
+            returns a representative sample of gold-standard Example objects.
+        nlp (Language): The current nlp object the component is part of.
+
+        DOCS: https://spacy.io/api/span_predictor#initialize
+        """
         validate_get_examples(get_examples, "SpanPredictor.initialize")

         X = []

@@ -267,6 +338,7 @@ class SpanPredictor(TrainablePipe):
             if not ex.predicted.spans:
                 # set placeholder for shape inference
                 doc = ex.predicted
+                # TODO should be able to check if there are some valid docs in the batch
                 assert len(doc) > 2, "Coreference requires at least two tokens"
                 doc.spans[f"{self.input_prefix}_0"] = [doc[0:1], doc[1:2]]
             X.append(ex.predicted)
@@ -9,6 +9,7 @@ from spacy.ml.models.coref_util import (
     DEFAULT_CLUSTER_PREFIX,
     select_non_crossing_spans,
     get_sentence_ids,
+    get_clusters_from_doc,
 )

 from thinc.util import has_torch

@@ -35,6 +36,9 @@ TRAIN_DATA = [
 # fmt: on


+CONFIG = {"model": {"@architectures": "spacy.Coref.v1", "tok2vec_size": 64}}
+
+
 @pytest.fixture
 def nlp():
     return English()

@@ -60,9 +64,10 @@ def test_not_initialized(nlp):
     with pytest.raises(ValueError, match="E109"):
         nlp(text)


 @pytest.mark.skipif(not has_torch, reason="Torch not available")
 def test_initialized(nlp):
-    nlp.add_pipe("coref")
+    nlp.add_pipe("coref", config=CONFIG)
     nlp.initialize()
     assert nlp.pipe_names == ["coref"]
     text = "She gave me her pen."

@@ -74,7 +79,7 @@ def test_initialized(nlp):

 @pytest.mark.skipif(not has_torch, reason="Torch not available")
 def test_initialized_short(nlp):
-    nlp.add_pipe("coref")
+    nlp.add_pipe("coref", config=CONFIG)
     nlp.initialize()
     assert nlp.pipe_names == ["coref"]
     text = "Hi there"
@@ -84,58 +89,47 @@ def test_initialized_short(nlp):
 @pytest.mark.skipif(not has_torch, reason="Torch not available")
 def test_coref_serialization(nlp):
     # Test that the coref component can be serialized
-    nlp.add_pipe("coref", last=True)
+    nlp.add_pipe("coref", last=True, config=CONFIG)
     nlp.initialize()
     assert nlp.pipe_names == ["coref"]
     text = "She gave me her pen."
     doc = nlp(text)
-    spans_result = doc.spans

     with make_tempdir() as tmp_dir:
         nlp.to_disk(tmp_dir)
         nlp2 = spacy.load(tmp_dir)
         assert nlp2.pipe_names == ["coref"]
         doc2 = nlp2(text)
-        spans_result2 = doc2.spans
-        print(1, [(k, len(v)) for k, v in spans_result.items()])
-        print(2, [(k, len(v)) for k, v in spans_result2.items()])
-        # Note: spans do not compare equal because docs are different and docs
-        # use object identity for equality
-        for k, v in spans_result.items():
-            assert str(spans_result[k]) == str(spans_result2[k])
-        # assert spans_result == spans_result2
+
+        assert get_clusters_from_doc(doc) == get_clusters_from_doc(doc2)


 @pytest.mark.skipif(not has_torch, reason="Torch not available")
 def test_overfitting_IO(nlp):
-    # Simple test to try and quickly overfit the senter - ensuring the ML models work correctly
+    # Simple test to try and quickly overfit - ensuring the ML models work correctly
     train_examples = []
     for text, annot in TRAIN_DATA:
         train_examples.append(Example.from_dict(nlp.make_doc(text), annot))

-    nlp.add_pipe("coref")
+    nlp.add_pipe("coref", config=CONFIG)
     optimizer = nlp.initialize()
     test_text = TRAIN_DATA[0][0]
     doc = nlp(test_text)
-    print("BEFORE", doc.spans)

-    for i in range(5):
+    # Needs ~12 epochs to converge
+    for i in range(15):
         losses = {}
         nlp.update(train_examples, sgd=optimizer, losses=losses)
         doc = nlp(test_text)
-        print(i, doc.spans)
-        print(losses["coref"])  # < 0.001

     # test the trained model
     doc = nlp(test_text)
-    print("AFTER", doc.spans)

     # Also test the results are still the same after IO
     with make_tempdir() as tmp_dir:
         nlp.to_disk(tmp_dir)
         nlp2 = util.load_model_from_path(tmp_dir)
         doc2 = nlp2(test_text)
-        print("doc2", doc2.spans)

     # Make sure that running pipe twice, or comparing to call, always amounts to the same predictions
     texts = [
@@ -143,14 +137,67 @@ def test_overfitting_IO(nlp):
         "I noticed many friends around me",
         "They received it. They received the SMS.",
     ]
-    batch_deps_1 = [doc.spans for doc in nlp.pipe(texts)]
-    print(batch_deps_1)
-    batch_deps_2 = [doc.spans for doc in nlp.pipe(texts)]
-    print(batch_deps_2)
-    no_batch_deps = [doc.spans for doc in [nlp(text) for text in texts]]
-    print(no_batch_deps)
-    # assert_equal(batch_deps_1, batch_deps_2)
-    # assert_equal(batch_deps_1, no_batch_deps)
+    docs1 = list(nlp.pipe(texts))
+    docs2 = list(nlp.pipe(texts))
+    docs3 = [nlp(text) for text in texts]
+    assert get_clusters_from_doc(docs1[0]) == get_clusters_from_doc(docs2[0])
+    assert get_clusters_from_doc(docs1[0]) == get_clusters_from_doc(docs3[0])
+
+
+@pytest.mark.skipif(not has_torch, reason="Torch not available")
+def test_tokenization_mismatch(nlp):
+    train_examples = []
+    for text, annot in TRAIN_DATA:
+        eg = Example.from_dict(nlp.make_doc(text), annot)
+        ref = eg.reference
+        char_spans = {}
+        for key, cluster in ref.spans.items():
+            char_spans[key] = []
+            for span in cluster:
+                char_spans[key].append((span[0].idx, span[-1].idx + len(span[-1])))
+        with ref.retokenize() as retokenizer:
+            # merge "many friends"
+            retokenizer.merge(ref[5:7])
+
+        # Note this works because it's the same doc and we know the keys
+        for key, _ in ref.spans.items():
+            spans = char_spans[key]
+            ref.spans[key] = [ref.char_span(*span) for span in spans]
+
+        train_examples.append(eg)
+
+    nlp.add_pipe("coref", config=CONFIG)
+    optimizer = nlp.initialize()
+    test_text = TRAIN_DATA[0][0]
+    doc = nlp(test_text)
+
+    for i in range(15):
+        losses = {}
+        nlp.update(train_examples, sgd=optimizer, losses=losses)
+        doc = nlp(test_text)
+
+    # test the trained model
+    doc = nlp(test_text)
+
+    # Also test the results are still the same after IO
+    with make_tempdir() as tmp_dir:
+        nlp.to_disk(tmp_dir)
+        nlp2 = util.load_model_from_path(tmp_dir)
+        doc2 = nlp2(test_text)
+
+    # Make sure that running pipe twice, or comparing to call, always amounts to the same predictions
+    texts = [
+        test_text,
+        "I noticed many friends around me",
+        "They received it. They received the SMS.",
+    ]
+
+    # save the docs so they don't get garbage collected
+    docs1 = list(nlp.pipe(texts))
+    docs2 = list(nlp.pipe(texts))
+    docs3 = [nlp(text) for text in texts]
+    assert get_clusters_from_doc(docs1[0]) == get_clusters_from_doc(docs2[0])
+    assert get_clusters_from_doc(docs1[0]) == get_clusters_from_doc(docs3[0])


 @pytest.mark.skipif(not has_torch, reason="Torch not available")
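Aside on the retokenize-and-restore trick used in test_tokenization_mismatch: record
char offsets, merge tokens, then rebuild the spans from chars. A hedged sketch (the
sentence is illustrative):

    import spacy

    nlp = spacy.blank("en")
    doc = nlp("They received the SMS")
    offsets = [(doc[3].idx, doc[3].idx + len(doc[3]))]  # "SMS"
    with doc.retokenize() as retokenizer:
        retokenizer.merge(doc[1:3])                     # "received the"
    spans = [doc.char_span(*o) for o in offsets]
    assert spans[0].text == "SMS"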
@@ -165,8 +212,26 @@ def test_crossing_spans():
     guess = sorted(guess)
     assert gold == guess


 @pytest.mark.skipif(not has_torch, reason="Torch not available")
 def test_sentence_map(snlp):
     doc = snlp("I like text. This is text.")
     sm = get_sentence_ids(doc)
     assert sm == [0, 0, 0, 0, 1, 1, 1, 1]
+
+
+@pytest.mark.skipif(not has_torch, reason="Torch not available")
+def test_whitespace_mismatch(nlp):
+    train_examples = []
+    for text, annot in TRAIN_DATA:
+        eg = Example.from_dict(nlp.make_doc(text), annot)
+        eg.predicted = nlp.make_doc(" " + text)
+        train_examples.append(eg)
+
+    nlp.add_pipe("coref", config=CONFIG)
+    optimizer = nlp.initialize()
+    test_text = TRAIN_DATA[0][0]
+    doc = nlp(test_text)
+
+    with pytest.raises(ValueError, match="whitespace"):
+        nlp.update(train_examples, sgd=optimizer)
spacy/tests/pipeline/test_span_predictor.py (new file, 227 lines)
@@ -0,0 +1,227 @@ (all lines added)
import pytest
import spacy

from spacy import util
from spacy.training import Example
from spacy.lang.en import English
from spacy.tests.util import make_tempdir
from spacy.ml.models.coref_util import (
    DEFAULT_CLUSTER_PREFIX,
    select_non_crossing_spans,
    get_sentence_ids,
    get_clusters_from_doc,
)

from thinc.util import has_torch

# fmt: off
TRAIN_DATA = [
    (
        "John Smith picked up the red ball and he threw it away.",
        {
            "spans": {
                f"{DEFAULT_CLUSTER_PREFIX}_1": [
                    (0, 10, "MENTION"),   # John Smith
                    (38, 40, "MENTION"),  # he
                ],
                f"{DEFAULT_CLUSTER_PREFIX}_2": [
                    (25, 33, "MENTION"),  # red ball
                    (47, 49, "MENTION"),  # it
                ],
                f"coref_head_clusters_1": [
                    (5, 10, "MENTION"),   # Smith
                    (38, 40, "MENTION"),  # he
                ],
                f"coref_head_clusters_2": [
                    (29, 33, "MENTION"),  # red ball
                    (47, 49, "MENTION"),  # it
                ]
            }
        },
    ),
]
# fmt: on

CONFIG = {"model": {"@architectures": "spacy.SpanPredictor.v1", "tok2vec_size": 64}}


@pytest.fixture
def nlp():
    return English()


@pytest.fixture
def snlp():
    en = English()
    en.add_pipe("sentencizer")
    return en


@pytest.mark.skipif(not has_torch, reason="Torch not available")
def test_add_pipe(nlp):
    nlp.add_pipe("span_predictor")
    assert nlp.pipe_names == ["span_predictor"]


@pytest.mark.skipif(not has_torch, reason="Torch not available")
def test_not_initialized(nlp):
    nlp.add_pipe("span_predictor")
    text = "She gave me her pen."
    with pytest.raises(ValueError, match="E109"):
        nlp(text)


@pytest.mark.skipif(not has_torch, reason="Torch not available")
def test_span_predictor_serialization(nlp):
    # Test that the span predictor component can be serialized
    nlp.add_pipe("span_predictor", last=True, config=CONFIG)
    nlp.initialize()
    assert nlp.pipe_names == ["span_predictor"]
    text = "She gave me her pen."
    doc = nlp(text)

    with make_tempdir() as tmp_dir:
        nlp.to_disk(tmp_dir)
        nlp2 = spacy.load(tmp_dir)
        assert nlp2.pipe_names == ["span_predictor"]
        doc2 = nlp2(text)

        assert get_clusters_from_doc(doc) == get_clusters_from_doc(doc2)


@pytest.mark.skipif(not has_torch, reason="Torch not available")
def test_overfitting_IO(nlp):
    # Simple test to try and quickly overfit - ensuring the ML models work correctly
    train_examples = []
    for text, annot in TRAIN_DATA:
        train_examples.append(Example.from_dict(nlp.make_doc(text), annot))

    train_examples = []
    for text, annot in TRAIN_DATA:
        eg = Example.from_dict(nlp.make_doc(text), annot)
        ref = eg.reference
        # Finally, copy over the head spans to the pred
        pred = eg.predicted
        for key, spans in ref.spans.items():
            if key.startswith("coref_head_clusters"):
                pred.spans[key] = [pred[span.start : span.end] for span in spans]

        train_examples.append(eg)
    nlp.add_pipe("span_predictor", config=CONFIG)
    optimizer = nlp.initialize()
    test_text = TRAIN_DATA[0][0]
    doc = nlp(test_text)

    for i in range(15):
        losses = {}
        nlp.update(train_examples, sgd=optimizer, losses=losses)
        doc = nlp(test_text)

    # test the trained model, using the pred since it has heads
    doc = nlp(train_examples[0].predicted)
    # XXX This actually tests that it can overfit
    assert get_clusters_from_doc(doc) == get_clusters_from_doc(train_examples[0].reference)

    # Also test the results are still the same after IO
    with make_tempdir() as tmp_dir:
        nlp.to_disk(tmp_dir)
        nlp2 = util.load_model_from_path(tmp_dir)
        doc2 = nlp2(test_text)

    # Make sure that running pipe twice, or comparing to call, always amounts to the same predictions
    texts = [
        test_text,
        "I noticed many friends around me",
        "They received it. They received the SMS.",
    ]
    # XXX Note these have no predictions because they have no input spans
    docs1 = list(nlp.pipe(texts))
    docs2 = list(nlp.pipe(texts))
    docs3 = [nlp(text) for text in texts]
    assert get_clusters_from_doc(docs1[0]) == get_clusters_from_doc(docs2[0])
    assert get_clusters_from_doc(docs1[0]) == get_clusters_from_doc(docs3[0])


@pytest.mark.skipif(not has_torch, reason="Torch not available")
def test_tokenization_mismatch(nlp):
    train_examples = []
    for text, annot in TRAIN_DATA:
        eg = Example.from_dict(nlp.make_doc(text), annot)
        ref = eg.reference
        char_spans = {}
        for key, cluster in ref.spans.items():
            char_spans[key] = []
            for span in cluster:
                char_spans[key].append((span.start_char, span.end_char))
        with ref.retokenize() as retokenizer:
            # merge "picked up"
            retokenizer.merge(ref[2:4])

        # Note this works because it's the same doc and we know the keys
        for key, _ in ref.spans.items():
            spans = char_spans[key]
            ref.spans[key] = [ref.char_span(*span) for span in spans]

        # Finally, copy over the head spans to the pred
        pred = eg.predicted
        for key, val in ref.spans.items():
            if key.startswith("coref_head_clusters"):
                spans = char_spans[key]
                pred.spans[key] = [pred.char_span(*span) for span in spans]

        train_examples.append(eg)

    nlp.add_pipe("span_predictor", config=CONFIG)
    optimizer = nlp.initialize()
    test_text = TRAIN_DATA[0][0]
    doc = nlp(test_text)

    for i in range(15):
        losses = {}
        nlp.update(train_examples, sgd=optimizer, losses=losses)
        doc = nlp(test_text)

    # test the trained model; need to use doc with head spans on it already
    test_doc = train_examples[0].predicted
    doc = nlp(test_doc)
    # XXX This actually tests that it can overfit
    assert get_clusters_from_doc(doc) == get_clusters_from_doc(train_examples[0].reference)

    # Also test the results are still the same after IO
    with make_tempdir() as tmp_dir:
        nlp.to_disk(tmp_dir)
        nlp2 = util.load_model_from_path(tmp_dir)
        doc2 = nlp2(test_text)

    # Make sure that running pipe twice, or comparing to call, always amounts to the same predictions
    texts = [
        test_text,
        "I noticed many friends around me",
        "They received it. They received the SMS.",
    ]

    # save the docs so they don't get garbage collected
    docs1 = list(nlp.pipe(texts))
    docs2 = list(nlp.pipe(texts))
    docs3 = [nlp(text) for text in texts]
    assert get_clusters_from_doc(docs1[0]) == get_clusters_from_doc(docs2[0])
    assert get_clusters_from_doc(docs1[0]) == get_clusters_from_doc(docs3[0])


@pytest.mark.skipif(not has_torch, reason="Torch not available")
def test_whitespace_mismatch(nlp):
    train_examples = []
    for text, annot in TRAIN_DATA:
        eg = Example.from_dict(nlp.make_doc(text), annot)
        eg.predicted = nlp.make_doc(" " + text)
        train_examples.append(eg)

    nlp.add_pipe("span_predictor", config=CONFIG)
    optimizer = nlp.initialize()
    test_text = TRAIN_DATA[0][0]
    doc = nlp(test_text)

    with pytest.raises(ValueError, match="whitespace"):
        nlp.update(train_examples, sgd=optimizer)
@@ -587,8 +587,8 @@ consists of either two or three subnetworks:
   run once for each batch.
 - **lower**: Construct a feature-specific vector for each `(token, feature)`
   pair. This is also run once for each batch. Constructing the state
-  representation is then a matter of summing the component features and
-  applying the non-linearity.
+  representation is then a matter of summing the component features and applying
+  the non-linearity.
 - **upper** (optional): A feed-forward network that predicts scores from the
   state representation. If not present, the output from the lower model is used
   as action scores directly.

@@ -628,8 +628,8 @@ same signature, but the `use_upper` argument was `True` by default.
 > ```

 Build a tagger model, using a provided token-to-vector component. The tagger
-model adds a linear layer with softmax activation to predict scores given
-the token vectors.
+model adds a linear layer with softmax activation to predict scores given the
+token vectors.

 | Name        | Description |
 | ----------- | ----------- |

@@ -920,8 +920,8 @@ A function that reads an existing `KnowledgeBase` from file.
 A function that takes as input a [`KnowledgeBase`](/api/kb) and a
 [`Span`](/api/span) object denoting a named entity, and returns a list of
 plausible [`Candidate`](/api/kb/#candidate) objects. The default
-`CandidateGenerator` uses the text of a mention to find its potential
-aliases in the `KnowledgeBase`. Note that this function is case-dependent.
+`CandidateGenerator` uses the text of a mention to find its potential aliases in
+the `KnowledgeBase`. Note that this function is case-dependent.

 ## Coreference Architectures
@@ -975,7 +975,11 @@ The `Coref` model architecture is a Thinc `Model`.
 > [model]
 > @architectures = "spacy.SpanPredictor.v1"
 > hidden_size = 1024
-> dist_emb_size = 64
+> distance_embedding_size = 64
+> conv_channels = 4
+> window_size = 1
+> max_distance = 128
+> prefix = "coref_head_clusters"
 >
 > [model.tok2vec]
 > @architectures = "spacy-transformers.TransformerListener.v1"
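Aside: a hedged end-to-end sketch of wiring up the documented SpanPredictor defaults
(assumes this experimental coref branch is installed and torch is available; it
mirrors the configs used in the tests above):

    import spacy

    nlp = spacy.blank("en")
    config = {"model": {"@architectures": "spacy.SpanPredictor.v1", "tok2vec_size": 64}}
    nlp.add_pipe("span_predictor", config=config)
    nlp.initialize()
    assert nlp.pipe_names == ["span_predictor"]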
@@ -987,12 +991,13 @@ The `Coref` model architecture is a Thinc `Model`.
 The `SpanPredictor` model architecture is a Thinc `Model`.

 | Name                      | Description |
-| ------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| ------------------------- | ----------------------------------------------------------------------------------------------------------------------------- |
 | `tok2vec`                 | The [`tok2vec`](#tok2vec) layer of the model. ~~Model~~ |
 | `distance_embedding_size` | A representation of the distance between two candidates. ~~int~~ |
 | `dropout`                 | The dropout to use internally. Unlike some Thinc models, this has separate dropout for the internal PyTorch layers. ~~float~~ |
 | `hidden_size`             | Size of the main internal layers. ~~int~~ |
-| `depth`                   | Depth of the internal network. ~~int~~ |
-| `antecedent_limit`        | How many candidate antecedents to keep after rough scoring. This has a significant effect on memory usage. Typical values would be 50 to 200, or higher for very long documents. ~~int~~ |
-| `antecedent_batch_size`   | Internal batch size. ~~int~~ |
+| `conv_channels`           | The number of channels in the internal CNN. ~~int~~ |
+| `window_size`             | The number of neighboring tokens to consider in the internal CNN. `1` means consider one token on each side. ~~int~~ |
+| `max_distance`            | The longest possible length of a predicted span. ~~int~~ |
+| `prefix`                  | The prefix that indicates spans to use for input data. ~~string~~ |
 | **CREATES**               | The model using the architecture. ~~Model[List[Doc], TupleFloats2d]~~ |