diff --git a/spacy/errors.py b/spacy/errors.py index 14010565b..ba550f492 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -934,6 +934,7 @@ class Errors(metaclass=ErrorsWithCodes): E1041 = ("Expected a string, Doc, or bytes as input, but got: {type}") E1042 = ("Function was called with `{arg1}`={arg1_values} and " "`{arg2}`={arg2_values} but these arguments are conflicting.") + E1043 = ("Misalignment in coref. Head token has no match in training doc.") # Deprecated model shortcuts, only used in errors and warnings diff --git a/spacy/ml/models/coref.py b/spacy/ml/models/coref.py index 7ad493ac0..a69c6673e 100644 --- a/spacy/ml/models/coref.py +++ b/spacy/ml/models/coref.py @@ -1,8 +1,8 @@ -from typing import List, Tuple +from typing import List, Tuple, Callable, cast from thinc.api import Model, chain, get_width from thinc.api import PyTorchWrapper, ArgsKwargs -from thinc.types import Floats2d +from thinc.types import Floats2d, Ints2d from thinc.util import torch, xp2torch, torch2xp from ...tokens import Doc @@ -22,10 +22,8 @@ def build_wl_coref_model( # pairs to keep per mention after rough scoring antecedent_limit: int = 50, antecedent_batch_size: int = 512, -): - # TODO add model return types - nI = None +) -> Model[List[Doc], Tuple[Floats2d, Ints2d]]: with Model.define_operators({">>": chain}): coref_clusterer = Model( @@ -83,7 +81,6 @@ def coref_init(model: Model, X=None, Y=None): def coref_forward(model: Model, X, is_train: bool): return model.layers[0](X, is_train) - def convert_coref_clusterer_inputs(model: Model, X: List[Floats2d], is_train: bool): # The input here is List[Floats2d], one for each doc # just use the first @@ -91,16 +88,17 @@ def convert_coref_clusterer_inputs(model: Model, X: List[Floats2d], is_train: bo X = X[0] word_features = xp2torch(X, requires_grad=is_train) - # TODO fix or remove type annotations - def backprop(args: ArgsKwargs): # -> List[Floats2d]: + def backprop(args: ArgsKwargs) -> List[Floats2d]: # convert to xp and wrap in 
list - gradients = torch2xp(args.args[0]) + gradients = cast(Floats2d, torch2xp(args.args[0])) return [gradients] return ArgsKwargs(args=(word_features,), kwargs={}), backprop -def convert_coref_clusterer_outputs(model: Model, inputs_outputs, is_train: bool): +def convert_coref_clusterer_outputs( + model: Model, inputs_outputs, is_train: bool +) -> Tuple[Tuple[Floats2d, Ints2d], Callable]: _, outputs = inputs_outputs scores, indices = outputs @@ -111,8 +109,8 @@ def convert_coref_clusterer_outputs(model: Model, inputs_outputs, is_train: bool kwargs={"grad_tensors": [dY_t]}, ) - scores_xp = torch2xp(scores) - indices_xp = torch2xp(indices) + scores_xp = cast(Floats2d, torch2xp(scores)) + indices_xp = cast(Ints2d, torch2xp(indices)) return (scores_xp, indices_xp), convert_for_torch_backward diff --git a/spacy/ml/models/coref_util.py b/spacy/ml/models/coref_util.py index a004a69d7..1a6bc6364 100644 --- a/spacy/ml/models/coref_util.py +++ b/spacy/ml/models/coref_util.py @@ -143,16 +143,18 @@ def create_head_span_idxs(ops, doclen: int): def get_clusters_from_doc(doc) -> List[List[Tuple[int, int]]]: - """Given a Doc, convert the cluster spans to simple int tuple lists.""" + """Convert the span clusters in a Doc to simple integer tuple lists. The + ints are char spans, to be tokenization independent. 
+ """ out = [] for key, val in doc.spans.items(): cluster = [] for span in val: - # TODO check that there isn't an off-by-one error here - # cluster.append((span.start, span.end)) - # TODO This conversion should be happening earlier in processing + head_i = span.root.i - cluster.append((head_i, head_i + 1)) + head = doc[head_i] + char_span = (head.idx, head.idx + len(head)) + cluster.append(char_span) # don't want duplicates cluster = list(set(cluster)) diff --git a/spacy/ml/models/span_predictor.py b/spacy/ml/models/span_predictor.py index 4e394ed78..ca76e5a4a 100644 --- a/spacy/ml/models/span_predictor.py +++ b/spacy/ml/models/span_predictor.py @@ -1,4 +1,4 @@ -from typing import List, Tuple +from typing import List, Tuple, cast from thinc.api import Model, chain, tuplify, get_width from thinc.api import PyTorchWrapper, ArgsKwargs @@ -76,15 +76,17 @@ def span_predictor_forward(model: Model, X, is_train: bool): return model.layers[0](X, is_train) def convert_span_predictor_inputs( - model: Model, X: Tuple[List[Floats2d], Tuple[List[Ints1d], List[Ints1d]]], is_train: bool + model: Model, + X: Tuple[List[Floats2d], Tuple[List[Ints1d], List[Ints1d]]], + is_train: bool, ): tok2vec, (sent_ids, head_ids) = X # Normally we should use the input is_train, but for these two it's not relevant # TODO fix the type here, or remove it - def backprop(args: ArgsKwargs): #-> Tuple[List[Floats2d], None]: - gradients = torch2xp(args.args[1]) + def backprop(args: ArgsKwargs) -> Tuple[List[Floats2d], None]: + gradients = cast(Floats2d, torch2xp(args.args[1])) # The sent_ids and head_ids are None because no gradients - return [[gradients], None] + return ([gradients], None) word_features = xp2torch(tok2vec[0], requires_grad=is_train) sent_ids_tensor = xp2torch(sent_ids[0], requires_grad=False) @@ -129,7 +131,6 @@ def predict_span_clusters( def build_get_head_metadata(prefix): - # TODO this name is awful, fix it model = Model( "HeadDataProvider", attrs={"prefix": prefix}, 
forward=head_data_forward ) @@ -175,7 +176,6 @@ class SpanPredictor(torch.nn.Module): raise ValueError("max_distance has to be an even number") # input size = single token size # 64 = probably distance emb size - # TODO check that dist_emb_size use is correct self.ffnn = torch.nn.Sequential( torch.nn.Linear(input_size * 2 + dist_emb_size, hidden_size), torch.nn.ReLU(), @@ -192,7 +192,6 @@ class SpanPredictor(torch.nn.Module): torch.nn.Conv1d(dist_emb_size, conv_channels, kernel_size, 1, 1), torch.nn.Conv1d(conv_channels, 2, kernel_size, 1, 1), ) - # TODO make embeddings size a parameter self.max_distance = max_distance # handle distances between +-(max_distance - 2 / 2) self.emb = torch.nn.Embedding(max_distance, dist_emb_size) @@ -244,9 +243,7 @@ class SpanPredictor(torch.nn.Module): dim=1, ) lengths = same_sent.sum(dim=1) - padding_mask = torch.arange( - 0, lengths.max().item(), device=device - ).unsqueeze(0) + padding_mask = torch.arange(0, lengths.max().item(), device=device).unsqueeze(0) # (n_heads x max_sent_len) padding_mask = padding_mask < lengths.unsqueeze(1) # (n_heads x max_sent_len x input_size * 2 + distance_emb_size) diff --git a/spacy/pipeline/coref.py b/spacy/pipeline/coref.py index 6685b112e..40ff92fcb 100644 --- a/spacy/pipeline/coref.py +++ b/spacy/pipeline/coref.py @@ -95,7 +95,7 @@ def make_coref( class CoreferenceResolver(TrainablePipe): """Pipeline component for coreference resolution. - DOCS: https://spacy.io/api/coref (TODO) + DOCS: https://spacy.io/api/coref """ def __init__( @@ -118,8 +118,10 @@ class CoreferenceResolver(TrainablePipe): are stored in. span_cluster_prefix (str): Prefix for the key in doc.spans to store the coref clusters in. + scorer (Optional[Callable]): The scoring method. Defaults to + Scorer.score_coref_clusters. 
- DOCS: https://spacy.io/api/coref#init (TODO) + DOCS: https://spacy.io/api/coref#init """ self.vocab = vocab self.model = model @@ -133,11 +135,12 @@ class CoreferenceResolver(TrainablePipe): def predict(self, docs: Iterable[Doc]) -> List[MentionClusters]: """Apply the pipeline's model to a batch of docs, without modifying them. + Return the list of predicted clusters. docs (Iterable[Doc]): The documents to predict. - RETURNS: The models prediction for each document. + RETURNS (List[MentionClusters]): The model's prediction for each document. - DOCS: https://spacy.io/api/coref#predict (TODO) + DOCS: https://spacy.io/api/coref#predict """ out = [] for doc in docs: @@ -163,7 +166,7 @@ class CoreferenceResolver(TrainablePipe): docs (Iterable[Doc]): The documents to modify. clusters: The span clusters, produced by CoreferenceResolver.predict. - DOCS: https://spacy.io/api/coref#set_annotations (TODO) + DOCS: https://spacy.io/api/coref#set_annotations """ docs = list(docs) if len(docs) != len(clusters_by_doc): @@ -204,7 +207,7 @@ class CoreferenceResolver(TrainablePipe): Updated using the component name as the key. RETURNS (Dict[str, float]): The updated losses dictionary. - DOCS: https://spacy.io/api/coref#update (TODO) + DOCS: https://spacy.io/api/coref#update """ if losses is None: losses = {} @@ -218,12 +221,17 @@ class CoreferenceResolver(TrainablePipe): total_loss = 0 for eg in examples: - # TODO check this causes no issues (in practice it runs) + if eg.x.text != eg.y.text: + # TODO assign error number + raise ValueError( + """Text, including whitespace, must match between reference and + predicted docs in coref training. 
+ """ + ) preds, backprop = self.model.begin_update([eg.predicted]) score_matrix, mention_idx = preds loss, d_scores = self.get_loss([eg], score_matrix, mention_idx) total_loss += loss - # TODO check shape here backprop((d_scores, mention_idx)) if sgd is not None: @@ -232,7 +240,12 @@ class CoreferenceResolver(TrainablePipe): return losses def rehearse(self, examples, *, sgd=None, losses=None, **config): - raise NotImplementedError + # TODO this should be added later + raise NotImplementedError( + Errors.E931.format( + parent="CoreferenceResolver", method="rehearse", name=self.name + ) + ) def add_label(self, label: str) -> int: """Technically this method should be implemented from TrainablePipe, @@ -257,7 +270,7 @@ class CoreferenceResolver(TrainablePipe): scores: Scores representing the model's predictions. RETURNS (Tuple[float, float]): The loss and the gradient. - DOCS: https://spacy.io/api/coref#get_loss (TODO) + DOCS: https://spacy.io/api/coref#get_loss """ ops = self.model.ops xp = ops.xp @@ -267,12 +280,23 @@ class CoreferenceResolver(TrainablePipe): example = list(examples)[0] cidx = mention_idx - clusters = get_clusters_from_doc(example.reference) + clusters_by_char = get_clusters_from_doc(example.reference) + # convert to token clusters, and give up if necessary + clusters = [] + for cluster in clusters_by_char: + cc = [] + for start_char, end_char in cluster: + span = example.predicted.char_span(start_char, end_char) + if span is None: + # TODO log more details + raise IndexError(Errors.E1043) + cc.append((span.start, span.end)) + clusters.append(cc) + span_idxs = create_head_span_idxs(ops, len(example.predicted)) gscores = create_gold_scores(span_idxs, clusters) - # TODO fix type here. This is bools but asarray2f wants ints. + # Note on type here. This is bools but asarray2f wants ints. 
gscores = ops.asarray2f(gscores) # type: ignore - # top_gscores = xp.take_along_axis(gscores, cidx, axis=1) top_gscores = xp.take_along_axis(gscores, mention_idx, axis=1) # now add the placeholder gold_placeholder = ~top_gscores.any(axis=1).T @@ -304,7 +328,7 @@ class CoreferenceResolver(TrainablePipe): returns a representative sample of gold-standard Example objects. nlp (Language): The current nlp object the component is part of. - DOCS: https://spacy.io/api/coref#initialize (TODO) + DOCS: https://spacy.io/api/coref#initialize """ validate_get_examples(get_examples, "CoreferenceResolver.initialize") diff --git a/spacy/pipeline/entity_linker.py b/spacy/pipeline/entity_linker.py index aa7985a9c..36a291a88 100644 --- a/spacy/pipeline/entity_linker.py +++ b/spacy/pipeline/entity_linker.py @@ -383,7 +383,7 @@ class EntityLinker(TrainablePipe): no prediction. docs (Iterable[Doc]): The documents to predict. - RETURNS (List[str]): The models prediction for each document. + RETURNS (List[str]): The model's prediction for each document. DOCS: https://spacy.io/api/entitylinker#predict """ diff --git a/spacy/pipeline/span_predictor.py b/spacy/pipeline/span_predictor.py index b5f25cd81..ee724ad2e 100644 --- a/spacy/pipeline/span_predictor.py +++ b/spacy/pipeline/span_predictor.py @@ -29,7 +29,7 @@ distance_embedding_size = 64 conv_channels = 4 window_size = 1 max_distance = 128 -prefix = coref_head_clusters +prefix = "coref_head_clusters" [model.tok2vec] @architectures = "spacy.Tok2Vec.v2" @@ -95,6 +95,8 @@ class SpanPredictor(TrainablePipe): """Pipeline component to resolve one-token spans to full spans. Used in coreference resolution. + + DOCS: https://spacy.io/api/span_predictor """ def __init__( @@ -119,6 +121,14 @@ class SpanPredictor(TrainablePipe): } def predict(self, docs: Iterable[Doc]) -> List[MentionClusters]: + """Apply the pipeline's model to a batch of docs, without modifying them. + Return the list of predicted span clusters. 
+ + docs (Iterable[Doc]): The documents to predict. + RETURNS (List[MentionClusters]): The model's prediction for each document. + + DOCS: https://spacy.io/api/span_predictor#predict + """ # for now pretend there's just one doc out = [] @@ -151,6 +161,13 @@ class SpanPredictor(TrainablePipe): return out def set_annotations(self, docs: Iterable[Doc], clusters_by_doc) -> None: + """Modify a batch of Doc objects, using pre-computed scores. + + docs (Iterable[Doc]): The documents to modify. + clusters: The span clusters, produced by SpanPredictor.predict. + + DOCS: https://spacy.io/api/span_predictor#set_annotations + """ for doc, clusters in zip(docs, clusters_by_doc): for ii, cluster in enumerate(clusters): spans = [doc[mm[0] : mm[1]] for mm in cluster] @@ -166,6 +183,15 @@ class SpanPredictor(TrainablePipe): ) -> Dict[str, float]: """Learn from a batch of documents and gold-standard information, updating the pipe's model. Delegates to predict and get_loss. + + examples (Iterable[Example]): A batch of Example objects. + drop (float): The dropout rate. + sgd (thinc.api.Optimizer): The optimizer. + losses (Dict[str, float]): Optional record of the loss during training. + Updated using the component name as the key. + RETURNS (Dict[str, float]): The updated losses dictionary. + + DOCS: https://spacy.io/api/span_predictor#update """ if losses is None: losses = {} @@ -178,6 +204,13 @@ class SpanPredictor(TrainablePipe): total_loss = 0 for eg in examples: + if eg.x.text != eg.y.text: + # TODO assign error number + raise ValueError( + """Text, including whitespace, must match between reference and + predicted docs in span predictor training. + """ + ) span_scores, backprop = self.model.begin_update([eg.predicted]) # FIXME, this only happens once in the first 1000 docs of OntoNotes # and I'm not sure yet why. 
@@ -222,6 +255,15 @@ class SpanPredictor(TrainablePipe): examples: Iterable[Example], span_scores: Floats3d, ): + """Find the loss and gradient of loss for the batch of documents and + their predicted scores. + + examples (Iterable[Example]): The batch of examples. + span_scores (Floats3d): Scores representing the model's predictions. + RETURNS (Tuple[float, float]): The loss and the gradient. + + DOCS: https://spacy.io/api/span_predictor#get_loss + """ ops = self.model.ops # NOTE This is doing fake batching, and should always get a list of one example @@ -231,16 +273,29 @@ class SpanPredictor(TrainablePipe): for eg in examples: starts = [] ends = [] + keeps = [] + sidx = 0 for key, sg in eg.reference.spans.items(): if key.startswith(self.output_prefix): - for mention in sg: - starts.append(mention.start) - ends.append(mention.end) + for ii, mention in enumerate(sg): + sidx += 1 + # convert to span in pred + sch, ech = (mention.start_char, mention.end_char) + span = eg.predicted.char_span(sch, ech) + # TODO add to errors.py + if span is None: + warnings.warn("Could not align gold span in span predictor, skipping") + continue + starts.append(span.start) + ends.append(span.end) + keeps.append(sidx - 1) starts = self.model.ops.xp.asarray(starts) ends = self.model.ops.xp.asarray(ends) - start_scores = span_scores[:, :, 0] - end_scores = span_scores[:, :, 1] + start_scores = span_scores[:, :, 0][keeps] + end_scores = span_scores[:, :, 1][keeps] + + n_classes = start_scores.shape[1] start_probs = ops.softmax(start_scores, axis=1) end_probs = ops.softmax(end_scores, axis=1) @@ -248,7 +303,14 @@ class SpanPredictor(TrainablePipe): end_targets = to_categorical(ends, n_classes) start_grads = start_probs - start_targets end_grads = end_probs - end_targets - grads = ops.xp.stack((start_grads, end_grads), axis=2) + # now return to original shape, with 0s + final_start_grads = ops.alloc2f(*span_scores[:, :, 0].shape) + final_start_grads[keeps] = start_grads + final_end_grads = 
ops.alloc2f(*final_start_grads.shape) + final_end_grads[keeps] = end_grads + # XXX Note this only works with fake batching + grads = ops.xp.stack((final_start_grads, final_end_grads), axis=2) + loss = float((grads**2).sum()) return loss, grads @@ -258,6 +320,15 @@ class SpanPredictor(TrainablePipe): *, nlp: Optional[Language] = None, ) -> None: + """Initialize the pipe for training, using a representative set + of data examples. + + get_examples (Callable[[], Iterable[Example]]): Function that + returns a representative sample of gold-standard Example objects. + nlp (Language): The current nlp object the component is part of. + + DOCS: https://spacy.io/api/span_predictor#initialize + """ validate_get_examples(get_examples, "SpanPredictor.initialize") X = [] @@ -267,6 +338,7 @@ class SpanPredictor(TrainablePipe): if not ex.predicted.spans: # set placeholder for shape inference doc = ex.predicted + # TODO should be able to check if there are some valid docs in the batch assert len(doc) > 2, "Coreference requires at least two tokens" doc.spans[f"{self.input_prefix}_0"] = [doc[0:1], doc[1:2]] X.append(ex.predicted) diff --git a/spacy/tests/pipeline/test_coref.py b/spacy/tests/pipeline/test_coref.py index 53f0b2011..3e297ddcd 100644 --- a/spacy/tests/pipeline/test_coref.py +++ b/spacy/tests/pipeline/test_coref.py @@ -9,6 +9,7 @@ from spacy.ml.models.coref_util import ( DEFAULT_CLUSTER_PREFIX, select_non_crossing_spans, get_sentence_ids, + get_clusters_from_doc, ) from thinc.util import has_torch @@ -35,6 +36,9 @@ TRAIN_DATA = [ # fmt: on +CONFIG = {"model": {"@architectures": "spacy.Coref.v1", "tok2vec_size": 64}} + + @pytest.fixture def nlp(): return English() @@ -60,9 +64,10 @@ def test_not_initialized(nlp): with pytest.raises(ValueError, match="E109"): nlp(text) + @pytest.mark.skipif(not has_torch, reason="Torch not available") def test_initialized(nlp): - nlp.add_pipe("coref") + nlp.add_pipe("coref", config=CONFIG) nlp.initialize() assert nlp.pipe_names == ["coref"] 
text = "She gave me her pen." @@ -74,7 +79,7 @@ def test_initialized(nlp): @pytest.mark.skipif(not has_torch, reason="Torch not available") def test_initialized_short(nlp): - nlp.add_pipe("coref") + nlp.add_pipe("coref", config=CONFIG) nlp.initialize() assert nlp.pipe_names == ["coref"] text = "Hi there" @@ -84,58 +89,47 @@ def test_initialized_short(nlp): @pytest.mark.skipif(not has_torch, reason="Torch not available") def test_coref_serialization(nlp): # Test that the coref component can be serialized - nlp.add_pipe("coref", last=True) + nlp.add_pipe("coref", last=True, config=CONFIG) nlp.initialize() assert nlp.pipe_names == ["coref"] text = "She gave me her pen." doc = nlp(text) - spans_result = doc.spans with make_tempdir() as tmp_dir: nlp.to_disk(tmp_dir) nlp2 = spacy.load(tmp_dir) assert nlp2.pipe_names == ["coref"] doc2 = nlp2(text) - spans_result2 = doc2.spans - print(1, [(k, len(v)) for k, v in spans_result.items()]) - print(2, [(k, len(v)) for k, v in spans_result2.items()]) - # Note: spans do not compare equal because docs are different and docs - # use object identity for equality - for k, v in spans_result.items(): - assert str(spans_result[k]) == str(spans_result2[k]) - # assert spans_result == spans_result2 + + assert get_clusters_from_doc(doc) == get_clusters_from_doc(doc2) @pytest.mark.skipif(not has_torch, reason="Torch not available") def test_overfitting_IO(nlp): - # Simple test to try and quickly overfit the senter - ensuring the ML models work correctly + # Simple test to try and quickly overfit - ensuring the ML models work correctly train_examples = [] for text, annot in TRAIN_DATA: train_examples.append(Example.from_dict(nlp.make_doc(text), annot)) - nlp.add_pipe("coref") + nlp.add_pipe("coref", config=CONFIG) optimizer = nlp.initialize() test_text = TRAIN_DATA[0][0] doc = nlp(test_text) - print("BEFORE", doc.spans) - for i in range(5): + # Needs ~12 epochs to converge + for i in range(15): losses = {} nlp.update(train_examples, 
sgd=optimizer, losses=losses) doc = nlp(test_text) - print(i, doc.spans) - print(losses["coref"]) # < 0.001 # test the trained model doc = nlp(test_text) - print("AFTER", doc.spans) # Also test the results are still the same after IO with make_tempdir() as tmp_dir: nlp.to_disk(tmp_dir) nlp2 = util.load_model_from_path(tmp_dir) doc2 = nlp2(test_text) - print("doc2", doc2.spans) # Make sure that running pipe twice, or comparing to call, always amounts to the same predictions texts = [ @@ -143,14 +137,67 @@ def test_overfitting_IO(nlp): "I noticed many friends around me", "They received it. They received the SMS.", ] - batch_deps_1 = [doc.spans for doc in nlp.pipe(texts)] - print(batch_deps_1) - batch_deps_2 = [doc.spans for doc in nlp.pipe(texts)] - print(batch_deps_2) - no_batch_deps = [doc.spans for doc in [nlp(text) for text in texts]] - print(no_batch_deps) - # assert_equal(batch_deps_1, batch_deps_2) - # assert_equal(batch_deps_1, no_batch_deps) + docs1 = list(nlp.pipe(texts)) + docs2 = list(nlp.pipe(texts)) + docs3 = [nlp(text) for text in texts] + assert get_clusters_from_doc(docs1[0]) == get_clusters_from_doc(docs2[0]) + assert get_clusters_from_doc(docs1[0]) == get_clusters_from_doc(docs3[0]) + + +@pytest.mark.skipif(not has_torch, reason="Torch not available") +def test_tokenization_mismatch(nlp): + train_examples = [] + for text, annot in TRAIN_DATA: + eg = Example.from_dict(nlp.make_doc(text), annot) + ref = eg.reference + char_spans = {} + for key, cluster in ref.spans.items(): + char_spans[key] = [] + for span in cluster: + char_spans[key].append((span[0].idx, span[-1].idx + len(span[-1]))) + with ref.retokenize() as retokenizer: + # merge "many friends" + retokenizer.merge(ref[5:7]) + + # Note this works because it's the same doc and we know the keys + for key, _ in ref.spans.items(): + spans = char_spans[key] + ref.spans[key] = [ref.char_span(*span) for span in spans] + + train_examples.append(eg) + + nlp.add_pipe("coref", config=CONFIG) + optimizer = 
nlp.initialize() + test_text = TRAIN_DATA[0][0] + doc = nlp(test_text) + + for i in range(15): + losses = {} + nlp.update(train_examples, sgd=optimizer, losses=losses) + doc = nlp(test_text) + + # test the trained model + doc = nlp(test_text) + + # Also test the results are still the same after IO + with make_tempdir() as tmp_dir: + nlp.to_disk(tmp_dir) + nlp2 = util.load_model_from_path(tmp_dir) + doc2 = nlp2(test_text) + + # Make sure that running pipe twice, or comparing to call, always amounts to the same predictions + texts = [ + test_text, + "I noticed many friends around me", + "They received it. They received the SMS.", + ] + + # save the docs so they don't get garbage collected + docs1 = list(nlp.pipe(texts)) + docs2 = list(nlp.pipe(texts)) + docs3 = [nlp(text) for text in texts] + assert get_clusters_from_doc(docs1[0]) == get_clusters_from_doc(docs2[0]) + assert get_clusters_from_doc(docs1[0]) == get_clusters_from_doc(docs3[0]) @pytest.mark.skipif(not has_torch, reason="Torch not available") @@ -165,8 +212,26 @@ def test_crossing_spans(): guess = sorted(guess) assert gold == guess + @pytest.mark.skipif(not has_torch, reason="Torch not available") def test_sentence_map(snlp): doc = snlp("I like text. 
This is text.") sm = get_sentence_ids(doc) assert sm == [0, 0, 0, 0, 1, 1, 1, 1] + + +@pytest.mark.skipif(not has_torch, reason="Torch not available") +def test_whitespace_mismatch(nlp): + train_examples = [] + for text, annot in TRAIN_DATA: + eg = Example.from_dict(nlp.make_doc(text), annot) + eg.predicted = nlp.make_doc(" " + text) + train_examples.append(eg) + + nlp.add_pipe("coref", config=CONFIG) + optimizer = nlp.initialize() + test_text = TRAIN_DATA[0][0] + doc = nlp(test_text) + + with pytest.raises(ValueError, match="whitespace"): + nlp.update(train_examples, sgd=optimizer) diff --git a/spacy/tests/pipeline/test_span_predictor.py b/spacy/tests/pipeline/test_span_predictor.py new file mode 100644 index 000000000..8a6c62011 --- /dev/null +++ b/spacy/tests/pipeline/test_span_predictor.py @@ -0,0 +1,227 @@ +import pytest +import spacy + +from spacy import util +from spacy.training import Example +from spacy.lang.en import English +from spacy.tests.util import make_tempdir +from spacy.ml.models.coref_util import ( + DEFAULT_CLUSTER_PREFIX, + select_non_crossing_spans, + get_sentence_ids, + get_clusters_from_doc, +) + +from thinc.util import has_torch + +# fmt: off +TRAIN_DATA = [ + ( + "John Smith picked up the red ball and he threw it away.", + { + "spans": { + f"{DEFAULT_CLUSTER_PREFIX}_1": [ + (0, 10, "MENTION"), # John Smith + (38, 40, "MENTION"), # he + + ], + f"{DEFAULT_CLUSTER_PREFIX}_2": [ + (25, 33, "MENTION"), # red ball + (47, 49, "MENTION"), # it + ], + f"coref_head_clusters_1": [ + (5, 10, "MENTION"), # Smith + (38, 40, "MENTION"), # he + + ], + f"coref_head_clusters_2": [ + (29, 33, "MENTION"), # ball + (47, 49, "MENTION"), # it + ] + } + }, + ), +] +# fmt: on + +CONFIG = {"model": {"@architectures": "spacy.SpanPredictor.v1", "tok2vec_size": 64}} + + +@pytest.fixture +def nlp(): + return English() + + +@pytest.fixture +def snlp(): + en = English() + en.add_pipe("sentencizer") + return en + + +@pytest.mark.skipif(not has_torch, reason="Torch 
not available") +def test_add_pipe(nlp): + nlp.add_pipe("span_predictor") + assert nlp.pipe_names == ["span_predictor"] + + +@pytest.mark.skipif(not has_torch, reason="Torch not available") +def test_not_initialized(nlp): + nlp.add_pipe("span_predictor") + text = "She gave me her pen." + with pytest.raises(ValueError, match="E109"): + nlp(text) + + +@pytest.mark.skipif(not has_torch, reason="Torch not available") +def test_span_predictor_serialization(nlp): + # Test that the span predictor component can be serialized + nlp.add_pipe("span_predictor", last=True, config=CONFIG) + nlp.initialize() + assert nlp.pipe_names == ["span_predictor"] + text = "She gave me her pen." + doc = nlp(text) + + with make_tempdir() as tmp_dir: + nlp.to_disk(tmp_dir) + nlp2 = spacy.load(tmp_dir) + assert nlp2.pipe_names == ["span_predictor"] + doc2 = nlp2(text) + + assert get_clusters_from_doc(doc) == get_clusters_from_doc(doc2) + + +@pytest.mark.skipif(not has_torch, reason="Torch not available") +def test_overfitting_IO(nlp): + # Simple test to try and quickly overfit - ensuring the ML models work correctly + train_examples = [] + for text, annot in TRAIN_DATA: + train_examples.append(Example.from_dict(nlp.make_doc(text), annot)) + + train_examples = [] + for text, annot in TRAIN_DATA: + eg = Example.from_dict(nlp.make_doc(text), annot) + ref = eg.reference + # Finally, copy over the head spans to the pred + pred = eg.predicted + for key, spans in ref.spans.items(): + if key.startswith("coref_head_clusters"): + pred.spans[key] = [pred[span.start : span.end] for span in spans] + + train_examples.append(eg) + nlp.add_pipe("span_predictor", config=CONFIG) + optimizer = nlp.initialize() + test_text = TRAIN_DATA[0][0] + doc = nlp(test_text) + + for i in range(15): + losses = {} + nlp.update(train_examples, sgd=optimizer, losses=losses) + doc = nlp(test_text) + + # test the trained model, using the pred since it has heads + doc = nlp(train_examples[0].predicted) + # XXX This actually tests 
that it can overfit + assert get_clusters_from_doc(doc) == get_clusters_from_doc(train_examples[0].reference) + + # Also test the results are still the same after IO + with make_tempdir() as tmp_dir: + nlp.to_disk(tmp_dir) + nlp2 = util.load_model_from_path(tmp_dir) + doc2 = nlp2(test_text) + + # Make sure that running pipe twice, or comparing to call, always amounts to the same predictions + texts = [ + test_text, + "I noticed many friends around me", + "They received it. They received the SMS.", + ] + # XXX Note these have no predictions because they have no input spans + docs1 = list(nlp.pipe(texts)) + docs2 = list(nlp.pipe(texts)) + docs3 = [nlp(text) for text in texts] + assert get_clusters_from_doc(docs1[0]) == get_clusters_from_doc(docs2[0]) + assert get_clusters_from_doc(docs1[0]) == get_clusters_from_doc(docs3[0]) + + +@pytest.mark.skipif(not has_torch, reason="Torch not available") +def test_tokenization_mismatch(nlp): + train_examples = [] + for text, annot in TRAIN_DATA: + eg = Example.from_dict(nlp.make_doc(text), annot) + ref = eg.reference + char_spans = {} + for key, cluster in ref.spans.items(): + char_spans[key] = [] + for span in cluster: + char_spans[key].append((span.start_char, span.end_char)) + with ref.retokenize() as retokenizer: + # merge "picked up" + retokenizer.merge(ref[2:4]) + + # Note this works because it's the same doc and we know the keys + for key, _ in ref.spans.items(): + spans = char_spans[key] + ref.spans[key] = [ref.char_span(*span) for span in spans] + + # Finally, copy over the head spans to the pred + pred = eg.predicted + for key, val in ref.spans.items(): + if key.startswith("coref_head_clusters"): + spans = char_spans[key] + pred.spans[key] = [pred.char_span(*span) for span in spans] + + train_examples.append(eg) + + nlp.add_pipe("span_predictor", config=CONFIG) + optimizer = nlp.initialize() + test_text = TRAIN_DATA[0][0] + doc = nlp(test_text) + + for i in range(15): + losses = {} + nlp.update(train_examples, 
sgd=optimizer, losses=losses) + doc = nlp(test_text) + + # test the trained model; need to use doc with head spans on it already + test_doc = train_examples[0].predicted + doc = nlp(test_doc) + # XXX This actually tests that it can overfit + assert get_clusters_from_doc(doc) == get_clusters_from_doc(train_examples[0].reference) + + # Also test the results are still the same after IO + with make_tempdir() as tmp_dir: + nlp.to_disk(tmp_dir) + nlp2 = util.load_model_from_path(tmp_dir) + doc2 = nlp2(test_text) + + # Make sure that running pipe twice, or comparing to call, always amounts to the same predictions + texts = [ + test_text, + "I noticed many friends around me", + "They received it. They received the SMS.", + ] + + # save the docs so they don't get garbage collected + docs1 = list(nlp.pipe(texts)) + docs2 = list(nlp.pipe(texts)) + docs3 = [nlp(text) for text in texts] + assert get_clusters_from_doc(docs1[0]) == get_clusters_from_doc(docs2[0]) + assert get_clusters_from_doc(docs1[0]) == get_clusters_from_doc(docs3[0]) + + +@pytest.mark.skipif(not has_torch, reason="Torch not available") +def test_whitespace_mismatch(nlp): + train_examples = [] + for text, annot in TRAIN_DATA: + eg = Example.from_dict(nlp.make_doc(text), annot) + eg.predicted = nlp.make_doc(" " + text) + train_examples.append(eg) + + nlp.add_pipe("span_predictor", config=CONFIG) + optimizer = nlp.initialize() + test_text = TRAIN_DATA[0][0] + doc = nlp(test_text) + + with pytest.raises(ValueError, match="whitespace"): + nlp.update(train_examples, sgd=optimizer) diff --git a/website/docs/api/architectures.md b/website/docs/api/architectures.md index 4e70eee87..e881864a9 100644 --- a/website/docs/api/architectures.md +++ b/website/docs/api/architectures.md @@ -587,8 +587,8 @@ consists of either two or three subnetworks: run once for each batch. - **lower**: Construct a feature-specific vector for each `(token, feature)` pair. This is also run once for each batch. 
Constructing the state - representation is then a matter of summing the component features and - applying the non-linearity. + representation is then a matter of summing the component features and applying + the non-linearity. - **upper** (optional): A feed-forward network that predicts scores from the state representation. If not present, the output from the lower model is used as action scores directly. @@ -628,8 +628,8 @@ same signature, but the `use_upper` argument was `True` by default. > ``` Build a tagger model, using a provided token-to-vector component. The tagger -model adds a linear layer with softmax activation to predict scores given -the token vectors. +model adds a linear layer with softmax activation to predict scores given the +token vectors. | Name | Description | | ----------- | ------------------------------------------------------------------------------------------ | @@ -920,8 +920,8 @@ A function that reads an existing `KnowledgeBase` from file. A function that takes as input a [`KnowledgeBase`](/api/kb) and a [`Span`](/api/span) object denoting a named entity, and returns a list of plausible [`Candidate`](/api/kb/#candidate) objects. The default -`CandidateGenerator` uses the text of a mention to find its potential -aliases in the `KnowledgeBase`. Note that this function is case-dependent. +`CandidateGenerator` uses the text of a mention to find its potential aliases in +the `KnowledgeBase`. Note that this function is case-dependent. ## Coreference Architectures @@ -975,7 +975,11 @@ The `Coref` model architecture is a Thinc `Model`. > [model] > @architectures = "spacy.SpanPredictor.v1" > hidden_size = 1024 -> dist_emb_size = 64 +> distance_embedding_size = 64 +> conv_channels = 4 +> window_size = 1 +> max_distance = 128 +> prefix = "coref_head_clusters" > > [model.tok2vec] > @architectures = "spacy-transformers.TransformerListener.v1" @@ -986,13 +990,14 @@ The `Coref` model architecture is a Thinc `Model`. 
The `SpanPredictor` model architecture is a Thinc `Model`. -| Name | Description | -| ------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `tok2vec` | The [`tok2vec`](#tok2vec) layer of the model. ~~Model~~ | -| `distance_embedding_size` | A representation of the distance between two candidates. ~~int~~ | -| `dropout` | The dropout to use internally. Unlike some Thinc models, this has separate dropout for the internal PyTorch layers. ~~float~~ | -| `hidden_size` | Size of the main internal layers. ~~int~~ | -| `depth` | Depth of the internal network. ~~int~~ | -| `antecedent_limit` | How many candidate antecedents to keep after rough scoring. This has a significant effect on memory usage. Typical values would be 50 to 200, or higher for very long documents. ~~int~~ | -| `antecedent_batch_size` | Internal batch size. ~~int~~ | -| **CREATES** | The model using the architecture. ~~Model[List[Doc], TupleFloats2d]~~ | +| Name | Description | +| ------------------------- | ----------------------------------------------------------------------------------------------------------------------------- | +| `tok2vec` | The [`tok2vec`](#tok2vec) layer of the model. ~~Model~~ | +| `distance_embedding_size` | A representation of the distance between two candidates. ~~int~~ | +| `dropout` | The dropout to use internally. Unlike some Thinc models, this has separate dropout for the internal PyTorch layers. ~~float~~ | +| `hidden_size` | Size of the main internal layers. ~~int~~ | +| `conv_channels` | The number of channels in the internal CNN. ~~int~~ | +| `window_size` | The number of neighboring tokens to consider in the internal CNN. `1` means consider one token on each side. ~~int~~ | +| `max_distance` | The longest possible length of a predicted span. 
~~int~~ | +| `prefix` | The prefix that indicates spans to use for input data. ~~str~~ | +| **CREATES** | The model using the architecture. ~~Model[List[Doc], TupleFloats2d]~~ |