Merge branch 'feature/coref' into coref/dimension-inference

Paul O'Leary McCann 2022-07-11 19:18:46 +09:00
commit 4d032396b8
10 changed files with 486 additions and 95 deletions

View File

@@ -934,6 +934,7 @@ class Errors(metaclass=ErrorsWithCodes):
     E1041 = ("Expected a string, Doc, or bytes as input, but got: {type}")
     E1042 = ("Function was called with `{arg1}`={arg1_values} and "
              "`{arg2}`={arg2_values} but these arguments are conflicting.")
+    E1043 = ("Misalignment in coref. Head token has no match in training doc.")

     # Deprecated model shortcuts, only used in errors and warnings

View File

@@ -1,8 +1,8 @@
-from typing import List, Tuple
+from typing import List, Tuple, Callable, cast
 from thinc.api import Model, chain, get_width
 from thinc.api import PyTorchWrapper, ArgsKwargs
-from thinc.types import Floats2d
+from thinc.types import Floats2d, Ints2d
 from thinc.util import torch, xp2torch, torch2xp

 from ...tokens import Doc

@@ -22,10 +22,8 @@ def build_wl_coref_model(
     # pairs to keep per mention after rough scoring
     antecedent_limit: int = 50,
     antecedent_batch_size: int = 512,
-):
-    # TODO add model return types
     nI = None
+) -> Model[List[Doc], Tuple[Floats2d, Ints2d]]:
     with Model.define_operators({">>": chain}):
         coref_clusterer = Model(

@@ -83,7 +81,6 @@ def coref_init(model: Model, X=None, Y=None):

 def coref_forward(model: Model, X, is_train: bool):
     return model.layers[0](X, is_train)


 def convert_coref_clusterer_inputs(model: Model, X: List[Floats2d], is_train: bool):
     # The input here is List[Floats2d], one for each doc
     # just use the first

@@ -91,16 +88,17 @@ def convert_coref_clusterer_inputs(model: Model, X: List[Floats2d], is_train: bo
     X = X[0]
     word_features = xp2torch(X, requires_grad=is_train)

-    # TODO fix or remove type annotations
-    def backprop(args: ArgsKwargs):  # -> List[Floats2d]:
+    def backprop(args: ArgsKwargs) -> List[Floats2d]:
         # convert to xp and wrap in list
-        gradients = torch2xp(args.args[0])
+        gradients = cast(Floats2d, torch2xp(args.args[0]))
         return [gradients]

     return ArgsKwargs(args=(word_features,), kwargs={}), backprop


-def convert_coref_clusterer_outputs(model: Model, inputs_outputs, is_train: bool):
+def convert_coref_clusterer_outputs(
+    model: Model, inputs_outputs, is_train: bool
+) -> Tuple[Tuple[Floats2d, Ints2d], Callable]:
     _, outputs = inputs_outputs
     scores, indices = outputs

@@ -111,8 +109,8 @@ def convert_coref_clusterer_outputs(model: Model, inputs_outputs, is_train: bool
         kwargs={"grad_tensors": [dY_t]},
     )

-    scores_xp = torch2xp(scores)
-    indices_xp = torch2xp(indices)
+    scores_xp = cast(Floats2d, torch2xp(scores))
+    indices_xp = cast(Ints2d, torch2xp(indices))
     return (scores_xp, indices_xp), convert_for_torch_backward
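
The converters above follow Thinc's standard PyTorch wrapping pattern: inputs move to torch with xp2torch, outputs come back with torch2xp, and each converter returns a backprop callback. A minimal, self-contained sketch of that pattern, assuming a toy torch.nn.Linear module (names here are illustrative, not part of this commit):

    import numpy
    import torch
    from typing import List, cast
    from thinc.api import Model, PyTorchWrapper, ArgsKwargs
    from thinc.types import Floats2d
    from thinc.util import xp2torch, torch2xp

    def convert_inputs(model: Model, X: List[Floats2d], is_train: bool):
        # Fake batching: unwrap the single doc's features and move them to torch.
        features = xp2torch(X[0], requires_grad=is_train)

        def backprop(args: ArgsKwargs) -> List[Floats2d]:
            # Move the torch gradient back to the array library and re-wrap in a list.
            return [cast(Floats2d, torch2xp(args.args[0]))]

        return ArgsKwargs(args=(features,), kwargs={}), backprop

    wrapped = PyTorchWrapper(torch.nn.Linear(64, 64), convert_inputs=convert_inputs)
    Y, backprop = wrapped([numpy.zeros((4, 64), dtype="f")], is_train=False)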

View File

@@ -143,16 +143,18 @@ def create_head_span_idxs(ops, doclen: int):

 def get_clusters_from_doc(doc) -> List[List[Tuple[int, int]]]:
-    """Given a Doc, convert the cluster spans to simple int tuple lists."""
+    """Convert the span clusters in a Doc to simple integer tuple lists. The
+    ints are char spans, to be tokenization independent.
+    """
     out = []
     for key, val in doc.spans.items():
         cluster = []
         for span in val:
-            # TODO check that there isn't an off-by-one error here
-            # cluster.append((span.start, span.end))
-            # TODO This conversion should be happening earlier in processing
             head_i = span.root.i
-            cluster.append((head_i, head_i + 1))
+            head = doc[head_i]
+            char_span = (head.idx, head.idx + len(head))
+            cluster.append(char_span)

         # don't want duplicates
         cluster = list(set(cluster))
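
Character offsets survive retokenization, which is the point of this change; a quick illustration (hypothetical snippet, not from the commit):

    import spacy

    nlp = spacy.blank("en")
    doc = nlp("John Smith picked up the red ball and he threw it away.")
    # (0, 10) identifies "John Smith" by character offsets, independent of
    # how the text happens to be tokenized.
    assert doc.char_span(0, 10).text == "John Smith"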

View File

@@ -1,4 +1,4 @@
-from typing import List, Tuple
+from typing import List, Tuple, cast
 from thinc.api import Model, chain, tuplify, get_width
 from thinc.api import PyTorchWrapper, ArgsKwargs

@@ -76,15 +76,17 @@ def span_predictor_forward(model: Model, X, is_train: bool):
     return model.layers[0](X, is_train)


 def convert_span_predictor_inputs(
-    model: Model, X: Tuple[List[Floats2d], Tuple[List[Ints1d], List[Ints1d]]], is_train: bool
+    model: Model,
+    X: Tuple[List[Floats2d], Tuple[List[Ints1d], List[Ints1d]]],
+    is_train: bool,
 ):
     tok2vec, (sent_ids, head_ids) = X
     # Normally we should use the input is_train, but for these two it's not relevant

-    # TODO fix the type here, or remove it
-    def backprop(args: ArgsKwargs):  # -> Tuple[List[Floats2d], None]:
-        gradients = torch2xp(args.args[1])
+    def backprop(args: ArgsKwargs) -> Tuple[List[Floats2d], None]:
+        gradients = cast(Floats2d, torch2xp(args.args[1]))
         # The sent_ids and head_ids are None because no gradients
-        return [[gradients], None]
+        return ([gradients], None)

     word_features = xp2torch(tok2vec[0], requires_grad=is_train)
     sent_ids_tensor = xp2torch(sent_ids[0], requires_grad=False)

@@ -129,7 +131,6 @@ def predict_span_clusters(

 def build_get_head_metadata(prefix):
-    # TODO this name is awful, fix it
     model = Model(
         "HeadDataProvider", attrs={"prefix": prefix}, forward=head_data_forward
     )

@@ -175,7 +176,6 @@ class SpanPredictor(torch.nn.Module):
             raise ValueError("max_distance has to be an even number")
         # input size = single token size
         # 64 = probably distance emb size
-        # TODO check that dist_emb_size use is correct
         self.ffnn = torch.nn.Sequential(
             torch.nn.Linear(input_size * 2 + dist_emb_size, hidden_size),
             torch.nn.ReLU(),

@@ -192,7 +192,6 @@ class SpanPredictor(torch.nn.Module):
             torch.nn.Conv1d(dist_emb_size, conv_channels, kernel_size, 1, 1),
             torch.nn.Conv1d(conv_channels, 2, kernel_size, 1, 1),
         )
-        # TODO make embeddings size a parameter
         self.max_distance = max_distance
         # handle distances between +-(max_distance - 2 / 2)
         self.emb = torch.nn.Embedding(max_distance, dist_emb_size)

@@ -244,9 +243,7 @@ class SpanPredictor(torch.nn.Module):
             dim=1,
         )
         lengths = same_sent.sum(dim=1)
-        padding_mask = torch.arange(
-            0, lengths.max().item(), device=device
-        ).unsqueeze(0)
+        padding_mask = torch.arange(0, lengths.max().item(), device=device).unsqueeze(0)
         # (n_heads x max_sent_len)
         padding_mask = padding_mask < lengths.unsqueeze(1)
         # (n_heads x max_sent_len x input_size * 2 + distance_emb_size)
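
The collapsed `padding_mask` expression is the usual length-mask broadcast; as a standalone sketch with assumed toy lengths:

    import torch

    lengths = torch.tensor([2, 4, 1])
    # Compare column indices against each row's length to get a boolean mask:
    # position j is valid in row i iff j < lengths[i].
    mask = torch.arange(0, int(lengths.max().item())).unsqueeze(0) < lengths.unsqueeze(1)
    # tensor([[ True,  True, False, False],
    #         [ True,  True,  True,  True],
    #         [ True, False, False, False]])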

View File

@@ -95,7 +95,7 @@ def make_coref(
 class CoreferenceResolver(TrainablePipe):
     """Pipeline component for coreference resolution.

-    DOCS: https://spacy.io/api/coref (TODO)
+    DOCS: https://spacy.io/api/coref
     """

     def __init__(

@@ -118,8 +118,10 @@ class CoreferenceResolver(TrainablePipe):
             are stored in.
         span_cluster_prefix (str): Prefix for the key in doc.spans to store the
             coref clusters in.
+        scorer (Optional[Callable]): The scoring method. Defaults to
+            Scorer.score_coref_clusters.

-        DOCS: https://spacy.io/api/coref#init (TODO)
+        DOCS: https://spacy.io/api/coref#init
         """
         self.vocab = vocab
         self.model = model

@@ -133,11 +135,12 @@ class CoreferenceResolver(TrainablePipe):
     def predict(self, docs: Iterable[Doc]) -> List[MentionClusters]:
         """Apply the pipeline's model to a batch of docs, without modifying them.
+        Return the list of predicted clusters.

         docs (Iterable[Doc]): The documents to predict.
-        RETURNS: The models prediction for each document.
+        RETURNS (List[MentionClusters]): The model's prediction for each document.

-        DOCS: https://spacy.io/api/coref#predict (TODO)
+        DOCS: https://spacy.io/api/coref#predict
         """
         out = []
         for doc in docs:

@@ -163,7 +166,7 @@ class CoreferenceResolver(TrainablePipe):
         docs (Iterable[Doc]): The documents to modify.
         clusters: The span clusters, produced by CoreferenceResolver.predict.

-        DOCS: https://spacy.io/api/coref#set_annotations (TODO)
+        DOCS: https://spacy.io/api/coref#set_annotations
         """
         docs = list(docs)
         if len(docs) != len(clusters_by_doc):

@@ -204,7 +207,7 @@ class CoreferenceResolver(TrainablePipe):
             Updated using the component name as the key.
         RETURNS (Dict[str, float]): The updated losses dictionary.

-        DOCS: https://spacy.io/api/coref#update (TODO)
+        DOCS: https://spacy.io/api/coref#update
         """
         if losses is None:
             losses = {}

@@ -218,12 +221,17 @@ class CoreferenceResolver(TrainablePipe):
         total_loss = 0

         for eg in examples:
-            # TODO check this causes no issues (in practice it runs)
+            if eg.x.text != eg.y.text:
+                # TODO assign error number
+                raise ValueError(
+                    """Text, including whitespace, must match between reference and
+                    predicted docs in coref training.
+                    """
+                )
             preds, backprop = self.model.begin_update([eg.predicted])
             score_matrix, mention_idx = preds
             loss, d_scores = self.get_loss([eg], score_matrix, mention_idx)
             total_loss += loss
-            # TODO check shape here
             backprop((d_scores, mention_idx))

         if sgd is not None:

@@ -232,7 +240,12 @@ class CoreferenceResolver(TrainablePipe):
         return losses

     def rehearse(self, examples, *, sgd=None, losses=None, **config):
-        raise NotImplementedError
+        # TODO this should be added later
+        raise NotImplementedError(
+            Errors.E931.format(
+                parent="CoreferenceResolver", method="add_label", name=self.name
+            )
+        )

     def add_label(self, label: str) -> int:
         """Technically this method should be implemented from TrainablePipe,

@@ -257,7 +270,7 @@ class CoreferenceResolver(TrainablePipe):
         scores: Scores representing the model's predictions.
         RETURNS (Tuple[float, float]): The loss and the gradient.

-        DOCS: https://spacy.io/api/coref#get_loss (TODO)
+        DOCS: https://spacy.io/api/coref#get_loss
         """
         ops = self.model.ops
         xp = ops.xp

@@ -267,12 +280,23 @@ class CoreferenceResolver(TrainablePipe):
         example = list(examples)[0]
         cidx = mention_idx

-        clusters = get_clusters_from_doc(example.reference)
+        clusters_by_char = get_clusters_from_doc(example.reference)
+        # convert to token clusters, and give up if necessary
+        clusters = []
+        for cluster in clusters_by_char:
+            cc = []
+            for start_char, end_char in cluster:
+                span = example.predicted.char_span(start_char, end_char)
+                if span is None:
+                    # TODO log more details
+                    raise IndexError(Errors.E1043)
+                cc.append((span.start, span.end))
+            clusters.append(cc)

         span_idxs = create_head_span_idxs(ops, len(example.predicted))
         gscores = create_gold_scores(span_idxs, clusters)
-        # TODO fix type here. This is bools but asarray2f wants ints.
+        # Note on type here. This is bools but asarray2f wants ints.
         gscores = ops.asarray2f(gscores)  # type: ignore
-        # top_gscores = xp.take_along_axis(gscores, cidx, axis=1)
         top_gscores = xp.take_along_axis(gscores, mention_idx, axis=1)
         # now add the placeholder
         gold_placeholder = ~top_gscores.any(axis=1).T
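
Doc.char_span returns None when the character offsets don't land on token boundaries, which is exactly the misalignment that E1043 reports. A quick illustration (hypothetical snippet, not from the commit):

    import spacy

    nlp = spacy.blank("en")
    doc = nlp("John Smith spoke.")
    assert doc.char_span(0, 10).text == "John Smith"
    # Offsets that cut a token in half cannot be aligned to this tokenization:
    assert doc.char_span(0, 7) is None  # ends inside "Smith"
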
@@ -304,7 +328,7 @@ class CoreferenceResolver(TrainablePipe):
             returns a representative sample of gold-standard Example objects.
         nlp (Language): The current nlp object the component is part of.

-        DOCS: https://spacy.io/api/coref#initialize (TODO)
+        DOCS: https://spacy.io/api/coref#initialize
         """
         validate_get_examples(get_examples, "CoreferenceResolver.initialize")

View File

@@ -383,7 +383,7 @@ class EntityLinker(TrainablePipe):
             no prediction.

         docs (Iterable[Doc]): The documents to predict.
-        RETURNS (List[str]): The models prediction for each document.
+        RETURNS (List[str]): The model's prediction for each document.

         DOCS: https://spacy.io/api/entitylinker#predict
         """

View File

@@ -29,7 +29,7 @@ distance_embedding_size = 64
 conv_channels = 4
 window_size = 1
 max_distance = 128
-prefix = coref_head_clusters
+prefix = "coref_head_clusters"

 [model.tok2vec]
 @architectures = "spacy.Tok2Vec.v2"

@@ -95,6 +95,8 @@ class SpanPredictor(TrainablePipe):
     """Pipeline component to resolve one-token spans to full spans.

     Used in coreference resolution.
+
+    DOCS: https://spacy.io/api/span_predictor
     """

     def __init__(

@@ -119,6 +121,14 @@ class SpanPredictor(TrainablePipe):
         }

     def predict(self, docs: Iterable[Doc]) -> List[MentionClusters]:
+        """Apply the pipeline's model to a batch of docs, without modifying them.
+        Return the list of predicted span clusters.
+
+        docs (Iterable[Doc]): The documents to predict.
+        RETURNS (List[MentionClusters]): The model's prediction for each document.
+
+        DOCS: https://spacy.io/api/span_predictor#predict
+        """
         # for now pretend there's just one doc
         out = []

@@ -151,6 +161,13 @@ class SpanPredictor(TrainablePipe):
         return out

     def set_annotations(self, docs: Iterable[Doc], clusters_by_doc) -> None:
+        """Modify a batch of Doc objects, using pre-computed scores.
+
+        docs (Iterable[Doc]): The documents to modify.
+        clusters: The span clusters, produced by SpanPredictor.predict.
+
+        DOCS: https://spacy.io/api/span_predictor#set_annotations
+        """
         for doc, clusters in zip(docs, clusters_by_doc):
             for ii, cluster in enumerate(clusters):
                 spans = [doc[mm[0] : mm[1]] for mm in cluster]

@@ -166,6 +183,15 @@ class SpanPredictor(TrainablePipe):
     ) -> Dict[str, float]:
         """Learn from a batch of documents and gold-standard information,
         updating the pipe's model. Delegates to predict and get_loss.
+
+        examples (Iterable[Example]): A batch of Example objects.
+        drop (float): The dropout rate.
+        sgd (thinc.api.Optimizer): The optimizer.
+        losses (Dict[str, float]): Optional record of the loss during training.
+            Updated using the component name as the key.
+        RETURNS (Dict[str, float]): The updated losses dictionary.
+
+        DOCS: https://spacy.io/api/span_predictor#update
         """
         if losses is None:
             losses = {}

@@ -178,6 +204,13 @@ class SpanPredictor(TrainablePipe):
         total_loss = 0

         for eg in examples:
+            if eg.x.text != eg.y.text:
+                # TODO assign error number
+                raise ValueError(
+                    """Text, including whitespace, must match between reference and
+                    predicted docs in span predictor training.
+                    """
+                )
             span_scores, backprop = self.model.begin_update([eg.predicted])
             # FIXME, this only happens once in the first 1000 docs of OntoNotes
             # and I'm not sure yet why.

@@ -222,6 +255,15 @@ class SpanPredictor(TrainablePipe):
         examples: Iterable[Example],
         span_scores: Floats3d,
     ):
+        """Find the loss and gradient of loss for the batch of documents and
+        their predicted scores.
+
+        examples (Iterable[Example]): The batch of examples.
+        scores: Scores representing the model's predictions.
+        RETURNS (Tuple[float, float]): The loss and the gradient.
+
+        DOCS: https://spacy.io/api/span_predictor#get_loss
+        """
         ops = self.model.ops

         # NOTE This is doing fake batching, and should always get a list of one example

@@ -231,16 +273,29 @@ class SpanPredictor(TrainablePipe):
         for eg in examples:
             starts = []
             ends = []
+            keeps = []
+            sidx = 0
             for key, sg in eg.reference.spans.items():
                 if key.startswith(self.output_prefix):
-                    for mention in sg:
-                        starts.append(mention.start)
-                        ends.append(mention.end)
+                    for ii, mention in enumerate(sg):
+                        sidx += 1
+                        # convert to span in pred
+                        sch, ech = (mention.start_char, mention.end_char)
+                        span = eg.predicted.char_span(sch, ech)
+                        # TODO add to errors.py
+                        if span is None:
+                            warnings.warn("Could not align gold span in span predictor, skipping")
+                            continue
+                        starts.append(span.start)
+                        ends.append(span.end)
+                        keeps.append(sidx - 1)

             starts = self.model.ops.xp.asarray(starts)
             ends = self.model.ops.xp.asarray(ends)
-            start_scores = span_scores[:, :, 0]
-            end_scores = span_scores[:, :, 1]
+            start_scores = span_scores[:, :, 0][keeps]
+            end_scores = span_scores[:, :, 1][keeps]
             n_classes = start_scores.shape[1]
             start_probs = ops.softmax(start_scores, axis=1)
             end_probs = ops.softmax(end_scores, axis=1)

@@ -248,7 +303,14 @@ class SpanPredictor(TrainablePipe):
             end_targets = to_categorical(ends, n_classes)
             start_grads = start_probs - start_targets
             end_grads = end_probs - end_targets
-            grads = ops.xp.stack((start_grads, end_grads), axis=2)
+            # now return to original shape, with 0s
+            final_start_grads = ops.alloc2f(*span_scores[:, :, 0].shape)
+            final_start_grads[keeps] = start_grads
+            final_end_grads = ops.alloc2f(*final_start_grads.shape)
+            final_end_grads[keeps] = end_grads
+            # XXX Note this only works with fake batching
+            grads = ops.xp.stack((final_start_grads, final_end_grads), axis=2)
             loss = float((grads**2).sum())
             return loss, grads
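
Because misaligned gold spans are skipped, gradients exist only for the kept rows; they are scattered back into a zero matrix of the full output shape so the skipped rows contribute nothing. A small numpy sketch of the scatter step (shapes and values illustrative):

    import numpy as np

    n_heads, n_classes = 5, 4
    keeps = [0, 2, 3]  # rows whose gold spans aligned with the predicted doc
    kept_grads = np.full((len(keeps), n_classes), 0.25, dtype="f")
    final_grads = np.zeros((n_heads, n_classes), dtype="f")
    final_grads[keeps] = kept_grads  # unaligned rows stay all-zero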
@@ -258,6 +320,15 @@ class SpanPredictor(TrainablePipe):
         *,
         nlp: Optional[Language] = None,
     ) -> None:
+        """Initialize the pipe for training, using a representative set
+        of data examples.
+
+        get_examples (Callable[[], Iterable[Example]]): Function that
+            returns a representative sample of gold-standard Example objects.
+        nlp (Language): The current nlp object the component is part of.
+
+        DOCS: https://spacy.io/api/span_predictor#initialize
+        """
         validate_get_examples(get_examples, "SpanPredictor.initialize")

         X = []

@@ -267,6 +338,7 @@ class SpanPredictor(TrainablePipe):
             if not ex.predicted.spans:
                 # set placeholder for shape inference
                 doc = ex.predicted
+                # TODO should be able to check if there are some valid docs in the batch
                 assert len(doc) > 2, "Coreference requires at least two tokens"
                 doc.spans[f"{self.input_prefix}_0"] = [doc[0:1], doc[1:2]]
             X.append(ex.predicted)

View File

@@ -9,6 +9,7 @@ from spacy.ml.models.coref_util import (
     DEFAULT_CLUSTER_PREFIX,
     select_non_crossing_spans,
     get_sentence_ids,
+    get_clusters_from_doc,
 )

 from thinc.util import has_torch

@@ -35,6 +36,9 @@ TRAIN_DATA = [
 # fmt: on

+CONFIG = {"model": {"@architectures": "spacy.Coref.v1", "tok2vec_size": 64}}
+
+
 @pytest.fixture
 def nlp():
     return English()

@@ -60,9 +64,10 @@ def test_not_initialized(nlp):
     with pytest.raises(ValueError, match="E109"):
         nlp(text)


 @pytest.mark.skipif(not has_torch, reason="Torch not available")
 def test_initialized(nlp):
-    nlp.add_pipe("coref")
+    nlp.add_pipe("coref", config=CONFIG)
     nlp.initialize()
     assert nlp.pipe_names == ["coref"]
     text = "She gave me her pen."

@@ -74,7 +79,7 @@ def test_initialized(nlp):
 @pytest.mark.skipif(not has_torch, reason="Torch not available")
 def test_initialized_short(nlp):
-    nlp.add_pipe("coref")
+    nlp.add_pipe("coref", config=CONFIG)
     nlp.initialize()
     assert nlp.pipe_names == ["coref"]
     text = "Hi there"

@@ -84,58 +89,47 @@ def test_initialized_short(nlp):
 @pytest.mark.skipif(not has_torch, reason="Torch not available")
 def test_coref_serialization(nlp):
     # Test that the coref component can be serialized
-    nlp.add_pipe("coref", last=True)
+    nlp.add_pipe("coref", last=True, config=CONFIG)
     nlp.initialize()
     assert nlp.pipe_names == ["coref"]
     text = "She gave me her pen."
     doc = nlp(text)
-    spans_result = doc.spans

     with make_tempdir() as tmp_dir:
         nlp.to_disk(tmp_dir)
         nlp2 = spacy.load(tmp_dir)
         assert nlp2.pipe_names == ["coref"]
         doc2 = nlp2(text)
-        spans_result2 = doc2.spans
-        print(1, [(k, len(v)) for k, v in spans_result.items()])
-        print(2, [(k, len(v)) for k, v in spans_result2.items()])
-        # Note: spans do not compare equal because docs are different and docs
-        # use object identity for equality
-        for k, v in spans_result.items():
-            assert str(spans_result[k]) == str(spans_result2[k])
-        # assert spans_result == spans_result2
+        assert get_clusters_from_doc(doc) == get_clusters_from_doc(doc2)


 @pytest.mark.skipif(not has_torch, reason="Torch not available")
 def test_overfitting_IO(nlp):
-    # Simple test to try and quickly overfit the senter - ensuring the ML models work correctly
+    # Simple test to try and quickly overfit - ensuring the ML models work correctly
     train_examples = []
     for text, annot in TRAIN_DATA:
         train_examples.append(Example.from_dict(nlp.make_doc(text), annot))

-    nlp.add_pipe("coref")
+    nlp.add_pipe("coref", config=CONFIG)
     optimizer = nlp.initialize()
     test_text = TRAIN_DATA[0][0]
     doc = nlp(test_text)
-    print("BEFORE", doc.spans)

-    for i in range(5):
+    # Needs ~12 epochs to converge
+    for i in range(15):
         losses = {}
         nlp.update(train_examples, sgd=optimizer, losses=losses)
         doc = nlp(test_text)
-        print(i, doc.spans)
-        print(losses["coref"])  # < 0.001

     # test the trained model
     doc = nlp(test_text)
-    print("AFTER", doc.spans)

     # Also test the results are still the same after IO
     with make_tempdir() as tmp_dir:
         nlp.to_disk(tmp_dir)
         nlp2 = util.load_model_from_path(tmp_dir)
         doc2 = nlp2(test_text)
-        print("doc2", doc2.spans)

     # Make sure that running pipe twice, or comparing to call, always amounts to the same predictions
     texts = [

@@ -143,14 +137,67 @@ def test_overfitting_IO(nlp):
         "I noticed many friends around me",
         "They received it. They received the SMS.",
     ]
-    batch_deps_1 = [doc.spans for doc in nlp.pipe(texts)]
-    print(batch_deps_1)
-    batch_deps_2 = [doc.spans for doc in nlp.pipe(texts)]
-    print(batch_deps_2)
-    no_batch_deps = [doc.spans for doc in [nlp(text) for text in texts]]
-    print(no_batch_deps)
-    # assert_equal(batch_deps_1, batch_deps_2)
-    # assert_equal(batch_deps_1, no_batch_deps)
+    docs1 = list(nlp.pipe(texts))
+    docs2 = list(nlp.pipe(texts))
+    docs3 = [nlp(text) for text in texts]
+    assert get_clusters_from_doc(docs1[0]) == get_clusters_from_doc(docs2[0])
+    assert get_clusters_from_doc(docs1[0]) == get_clusters_from_doc(docs3[0])
+
+
+@pytest.mark.skipif(not has_torch, reason="Torch not available")
+def test_tokenization_mismatch(nlp):
+    train_examples = []
+    for text, annot in TRAIN_DATA:
+        eg = Example.from_dict(nlp.make_doc(text), annot)
+        ref = eg.reference
+        char_spans = {}
+        for key, cluster in ref.spans.items():
+            char_spans[key] = []
+            for span in cluster:
+                char_spans[key].append((span[0].idx, span[-1].idx + len(span[-1])))
+        with ref.retokenize() as retokenizer:
+            # merge "many friends"
+            retokenizer.merge(ref[5:7])
+
+        # Note this works because it's the same doc and we know the keys
+        for key, _ in ref.spans.items():
+            spans = char_spans[key]
+            ref.spans[key] = [ref.char_span(*span) for span in spans]
+
+        train_examples.append(eg)
+
+    nlp.add_pipe("coref", config=CONFIG)
+    optimizer = nlp.initialize()
+    test_text = TRAIN_DATA[0][0]
+    doc = nlp(test_text)
+
+    for i in range(15):
+        losses = {}
+        nlp.update(train_examples, sgd=optimizer, losses=losses)
+        doc = nlp(test_text)
+
+    # test the trained model
+    doc = nlp(test_text)
+
+    # Also test the results are still the same after IO
+    with make_tempdir() as tmp_dir:
+        nlp.to_disk(tmp_dir)
+        nlp2 = util.load_model_from_path(tmp_dir)
+        doc2 = nlp2(test_text)
+
+    # Make sure that running pipe twice, or comparing to call, always amounts to the same predictions
+    texts = [
+        test_text,
+        "I noticed many friends around me",
+        "They received it. They received the SMS.",
+    ]
+    # save the docs so they don't get garbage collected
+    docs1 = list(nlp.pipe(texts))
+    docs2 = list(nlp.pipe(texts))
+    docs3 = [nlp(text) for text in texts]
+    assert get_clusters_from_doc(docs1[0]) == get_clusters_from_doc(docs2[0])
+    assert get_clusters_from_doc(docs1[0]) == get_clusters_from_doc(docs3[0])


 @pytest.mark.skipif(not has_torch, reason="Torch not available")

@@ -165,8 +212,26 @@ def test_crossing_spans():
     guess = sorted(guess)
     assert gold == guess


 @pytest.mark.skipif(not has_torch, reason="Torch not available")
 def test_sentence_map(snlp):
     doc = snlp("I like text. This is text.")
     sm = get_sentence_ids(doc)
     assert sm == [0, 0, 0, 0, 1, 1, 1, 1]
+
+
+@pytest.mark.skipif(not has_torch, reason="Torch not available")
+def test_whitespace_mismatch(nlp):
+    train_examples = []
+    for text, annot in TRAIN_DATA:
+        eg = Example.from_dict(nlp.make_doc(text), annot)
+        eg.predicted = nlp.make_doc(" " + text)
+        train_examples.append(eg)
+
+    nlp.add_pipe("coref", config=CONFIG)
+    optimizer = nlp.initialize()
+    test_text = TRAIN_DATA[0][0]
+    doc = nlp(test_text)
+
+    with pytest.raises(ValueError, match="whitespace"):
+        nlp.update(train_examples, sgd=optimizer)

View File

@@ -0,0 +1,227 @@
import pytest
import spacy

from spacy import util
from spacy.training import Example
from spacy.lang.en import English
from spacy.tests.util import make_tempdir
from spacy.ml.models.coref_util import (
    DEFAULT_CLUSTER_PREFIX,
    select_non_crossing_spans,
    get_sentence_ids,
    get_clusters_from_doc,
)

from thinc.util import has_torch

# fmt: off
TRAIN_DATA = [
    (
        "John Smith picked up the red ball and he threw it away.",
        {
            "spans": {
                f"{DEFAULT_CLUSTER_PREFIX}_1": [
                    (0, 10, "MENTION"),   # John Smith
                    (38, 40, "MENTION"),  # he
                ],
                f"{DEFAULT_CLUSTER_PREFIX}_2": [
                    (25, 33, "MENTION"),  # red ball
                    (47, 49, "MENTION"),  # it
                ],
                "coref_head_clusters_1": [
                    (5, 10, "MENTION"),   # Smith
                    (38, 40, "MENTION"),  # he
                ],
                "coref_head_clusters_2": [
                    (29, 33, "MENTION"),  # ball
                    (47, 49, "MENTION"),  # it
                ],
            }
        },
    ),
]
# fmt: on

CONFIG = {"model": {"@architectures": "spacy.SpanPredictor.v1", "tok2vec_size": 64}}


@pytest.fixture
def nlp():
    return English()


@pytest.fixture
def snlp():
    en = English()
    en.add_pipe("sentencizer")
    return en


@pytest.mark.skipif(not has_torch, reason="Torch not available")
def test_add_pipe(nlp):
    nlp.add_pipe("span_predictor")
    assert nlp.pipe_names == ["span_predictor"]


@pytest.mark.skipif(not has_torch, reason="Torch not available")
def test_not_initialized(nlp):
    nlp.add_pipe("span_predictor")
    text = "She gave me her pen."
    with pytest.raises(ValueError, match="E109"):
        nlp(text)


@pytest.mark.skipif(not has_torch, reason="Torch not available")
def test_span_predictor_serialization(nlp):
    # Test that the span predictor component can be serialized
    nlp.add_pipe("span_predictor", last=True, config=CONFIG)
    nlp.initialize()
    assert nlp.pipe_names == ["span_predictor"]
    text = "She gave me her pen."
    doc = nlp(text)

    with make_tempdir() as tmp_dir:
        nlp.to_disk(tmp_dir)
        nlp2 = spacy.load(tmp_dir)
        assert nlp2.pipe_names == ["span_predictor"]
        doc2 = nlp2(text)
        assert get_clusters_from_doc(doc) == get_clusters_from_doc(doc2)


@pytest.mark.skipif(not has_torch, reason="Torch not available")
def test_overfitting_IO(nlp):
    # Simple test to try and quickly overfit - ensuring the ML models work correctly
    train_examples = []
    for text, annot in TRAIN_DATA:
        eg = Example.from_dict(nlp.make_doc(text), annot)
        ref = eg.reference
        # Finally, copy over the head spans to the pred
        pred = eg.predicted
        for key, spans in ref.spans.items():
            if key.startswith("coref_head_clusters"):
                pred.spans[key] = [pred[span.start : span.end] for span in spans]
        train_examples.append(eg)

    nlp.add_pipe("span_predictor", config=CONFIG)
    optimizer = nlp.initialize()
    test_text = TRAIN_DATA[0][0]
    doc = nlp(test_text)

    for i in range(15):
        losses = {}
        nlp.update(train_examples, sgd=optimizer, losses=losses)
        doc = nlp(test_text)

    # test the trained model, using the pred since it has heads
    doc = nlp(train_examples[0].predicted)
    # XXX This actually tests that it can overfit
    assert get_clusters_from_doc(doc) == get_clusters_from_doc(train_examples[0].reference)

    # Also test the results are still the same after IO
    with make_tempdir() as tmp_dir:
        nlp.to_disk(tmp_dir)
        nlp2 = util.load_model_from_path(tmp_dir)
        doc2 = nlp2(test_text)

    # Make sure that running pipe twice, or comparing to call, always amounts to the same predictions
    texts = [
        test_text,
        "I noticed many friends around me",
        "They received it. They received the SMS.",
    ]
    # XXX Note these have no predictions because they have no input spans
    docs1 = list(nlp.pipe(texts))
    docs2 = list(nlp.pipe(texts))
    docs3 = [nlp(text) for text in texts]
    assert get_clusters_from_doc(docs1[0]) == get_clusters_from_doc(docs2[0])
    assert get_clusters_from_doc(docs1[0]) == get_clusters_from_doc(docs3[0])


@pytest.mark.skipif(not has_torch, reason="Torch not available")
def test_tokenization_mismatch(nlp):
    train_examples = []
    for text, annot in TRAIN_DATA:
        eg = Example.from_dict(nlp.make_doc(text), annot)
        ref = eg.reference
        char_spans = {}
        for key, cluster in ref.spans.items():
            char_spans[key] = []
            for span in cluster:
                char_spans[key].append((span.start_char, span.end_char))
        with ref.retokenize() as retokenizer:
            # merge "picked up"
            retokenizer.merge(ref[2:4])

        # Note this works because it's the same doc and we know the keys
        for key, _ in ref.spans.items():
            spans = char_spans[key]
            ref.spans[key] = [ref.char_span(*span) for span in spans]

        # Finally, copy over the head spans to the pred
        pred = eg.predicted
        for key, val in ref.spans.items():
            if key.startswith("coref_head_clusters"):
                spans = char_spans[key]
                pred.spans[key] = [pred.char_span(*span) for span in spans]

        train_examples.append(eg)

    nlp.add_pipe("span_predictor", config=CONFIG)
    optimizer = nlp.initialize()
    test_text = TRAIN_DATA[0][0]
    doc = nlp(test_text)

    for i in range(15):
        losses = {}
        nlp.update(train_examples, sgd=optimizer, losses=losses)
        doc = nlp(test_text)

    # test the trained model; need to use doc with head spans on it already
    test_doc = train_examples[0].predicted
    doc = nlp(test_doc)
    # XXX This actually tests that it can overfit
    assert get_clusters_from_doc(doc) == get_clusters_from_doc(train_examples[0].reference)

    # Also test the results are still the same after IO
    with make_tempdir() as tmp_dir:
        nlp.to_disk(tmp_dir)
        nlp2 = util.load_model_from_path(tmp_dir)
        doc2 = nlp2(test_text)

    # Make sure that running pipe twice, or comparing to call, always amounts to the same predictions
    texts = [
        test_text,
        "I noticed many friends around me",
        "They received it. They received the SMS.",
    ]
    # save the docs so they don't get garbage collected
    docs1 = list(nlp.pipe(texts))
    docs2 = list(nlp.pipe(texts))
    docs3 = [nlp(text) for text in texts]
    assert get_clusters_from_doc(docs1[0]) == get_clusters_from_doc(docs2[0])
    assert get_clusters_from_doc(docs1[0]) == get_clusters_from_doc(docs3[0])


@pytest.mark.skipif(not has_torch, reason="Torch not available")
def test_whitespace_mismatch(nlp):
    train_examples = []
    for text, annot in TRAIN_DATA:
        eg = Example.from_dict(nlp.make_doc(text), annot)
        eg.predicted = nlp.make_doc(" " + text)
        train_examples.append(eg)

    nlp.add_pipe("span_predictor", config=CONFIG)
    optimizer = nlp.initialize()
    test_text = TRAIN_DATA[0][0]
    doc = nlp(test_text)

    with pytest.raises(ValueError, match="whitespace"):
        nlp.update(train_examples, sgd=optimizer)

View File

@@ -587,8 +587,8 @@ consists of either two or three subnetworks:
   run once for each batch.
 - **lower**: Construct a feature-specific vector for each `(token, feature)`
   pair. This is also run once for each batch. Constructing the state
-  representation is then a matter of summing the component features and
-  applying the non-linearity.
+  representation is then a matter of summing the component features and applying
+  the non-linearity.
 - **upper** (optional): A feed-forward network that predicts scores from the
   state representation. If not present, the output from the lower model is used
   as action scores directly.

@@ -628,8 +628,8 @@ same signature, but the `use_upper` argument was `True` by default.
 > ```

 Build a tagger model, using a provided token-to-vector component. The tagger
-model adds a linear layer with softmax activation to predict scores given
-the token vectors.
+model adds a linear layer with softmax activation to predict scores given the
+token vectors.

 | Name        | Description |
 | ----------- | ------------------------------------------------------------------------------------------ |

@@ -920,8 +920,8 @@ A function that reads an existing `KnowledgeBase` from file.
 A function that takes as input a [`KnowledgeBase`](/api/kb) and a
 [`Span`](/api/span) object denoting a named entity, and returns a list of
 plausible [`Candidate`](/api/kb/#candidate) objects. The default
-`CandidateGenerator` uses the text of a mention to find its potential
-aliases in the `KnowledgeBase`. Note that this function is case-dependent.
+`CandidateGenerator` uses the text of a mention to find its potential aliases in
+the `KnowledgeBase`. Note that this function is case-dependent.

 ## Coreference Architectures

@@ -975,7 +975,11 @@ The `Coref` model architecture is a Thinc `Model`.
 > [model]
 > @architectures = "spacy.SpanPredictor.v1"
 > hidden_size = 1024
-> dist_emb_size = 64
+> distance_embedding_size = 64
+> conv_channels = 4
+> window_size = 1
+> max_distance = 128
+> prefix = "coref_head_clusters"
 >
 > [model.tok2vec]
 > @architectures = "spacy-transformers.TransformerListener.v1"

@@ -986,13 +990,14 @@ The `Coref` model architecture is a Thinc `Model`.
 The `SpanPredictor` model architecture is a Thinc `Model`.

 | Name                      | Description |
-| ------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| ------------------------- | ------------------------------------------------------------------------------------------------------------------------------ |
 | `tok2vec`                 | The [`tok2vec`](#tok2vec) layer of the model. ~~Model~~ |
 | `distance_embedding_size` | A representation of the distance between two candidates. ~~int~~ |
 | `dropout`                 | The dropout to use internally. Unlike some Thinc models, this has separate dropout for the internal PyTorch layers. ~~float~~ |
 | `hidden_size`             | Size of the main internal layers. ~~int~~ |
-| `depth`                   | Depth of the internal network. ~~int~~ |
-| `antecedent_limit`        | How many candidate antecedents to keep after rough scoring. This has a significant effect on memory usage. Typical values would be 50 to 200, or higher for very long documents. ~~int~~ |
-| `antecedent_batch_size`   | Internal batch size. ~~int~~ |
+| `conv_channels`           | The number of channels in the internal CNN. ~~int~~ |
+| `window_size`             | The number of neighboring tokens to consider in the internal CNN. `1` means consider one token on each side. ~~int~~ |
+| `max_distance`            | The longest possible length of a predicted span. ~~int~~ |
+| `prefix`                  | The prefix that indicates spans to use for input data. ~~string~~ |
 | **CREATES**               | The model using the architecture. ~~Model[List[Doc], TupleFloats2d]~~ |