Merge branch 'fix/coref-alignment' into feature/coref
commit 1b3db149df
@@ -934,6 +934,7 @@ class Errors(metaclass=ErrorsWithCodes):
     E1041 = ("Expected a string, Doc, or bytes as input, but got: {type}")
     E1042 = ("Function was called with `{arg1}`={arg1_values} and "
              "`{arg2}`={arg2_values} but these arguments are conflicting.")
+    E1043 = ("Misalignment in coref. Head token has no match in training doc.")


     # Deprecated model shortcuts, only used in errors and warnings
@@ -143,16 +143,18 @@ def create_head_span_idxs(ops, doclen: int):


 def get_clusters_from_doc(doc) -> List[List[Tuple[int, int]]]:
-    """Given a Doc, convert the cluster spans to simple int tuple lists."""
+    """Convert the span clusters in a Doc to simple integer tuple lists. The
+    ints are char spans, to be tokenization independent.
+    """
     out = []
     for key, val in doc.spans.items():
         cluster = []
         for span in val:
-            # TODO check that there isn't an off-by-one error here
-            # cluster.append((span.start, span.end))
-            # TODO This conversion should be happening earlier in processing
+
             head_i = span.root.i
-            cluster.append((head_i, head_i + 1))
+            head = doc[head_i]
+            char_span = (head.idx, head.idx + len(head))
+            cluster.append(char_span)

         # don't want duplicates
         cluster = list(set(cluster))
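Not part of the diff: a minimal sketch of the head-token-to-char-offset conversion above, using a made-up doc and span key.

import spacy

nlp = spacy.blank("en")
doc = nlp("Alice said she left.")
# pretend "Alice" and "she" were annotated as one coref cluster
doc.spans["coref_clusters_1"] = [doc[0:1], doc[2:3]]

clusters = []
for key, val in doc.spans.items():
    cluster = []
    for span in val:
        head = doc[span.root.i]
        cluster.append((head.idx, head.idx + len(head)))  # char offsets of the head token
    clusters.append(cluster)
print(clusters)  # [[(0, 5), (11, 14)]]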
@@ -221,6 +221,13 @@ class CoreferenceResolver(TrainablePipe):
         total_loss = 0

         for eg in examples:
+            if eg.x.text != eg.y.text:
+                # TODO assign error number
+                raise ValueError(
+                    """Text, including whitespace, must match between reference and
+                    predicted docs in coref training.
+                    """
+                )
             preds, backprop = self.model.begin_update([eg.predicted])
             score_matrix, mention_idx = preds
             loss, d_scores = self.get_loss([eg], score_matrix, mention_idx)
@@ -273,7 +280,19 @@ class CoreferenceResolver(TrainablePipe):
         example = list(examples)[0]
         cidx = mention_idx

-        clusters = get_clusters_from_doc(example.reference)
+        clusters_by_char = get_clusters_from_doc(example.reference)
+        # convert to token clusters, and give up if necessary
+        clusters = []
+        for cluster in clusters_by_char:
+            cc = []
+            for start_char, end_char in cluster:
+                span = example.predicted.char_span(start_char, end_char)
+                if span is None:
+                    # TODO log more details
+                    raise IndexError(Errors.E1043)
+                cc.append((span.start, span.end))
+            clusters.append(cc)
+
         span_idxs = create_head_span_idxs(ops, len(example.predicted))
         gscores = create_gold_scores(span_idxs, clusters)
         # Note on type here. This is bools but asarray2f wants ints.
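Not part of the diff: the char-to-token realignment above in isolation. Doc.char_span returns None when the offsets do not line up with token boundaries, which is the case E1043 covers; the text and offsets here are invented.

import spacy

nlp = spacy.blank("en")
pred = nlp("John Smith laughed.")
char_cluster = [(0, 10), (11, 18)]  # "John Smith", "laughed"

token_cluster = []
for start_char, end_char in char_cluster:
    span = pred.char_span(start_char, end_char)
    if span is None:
        # the pipe raises IndexError(Errors.E1043) here
        raise IndexError("char offsets do not match token boundaries")
    token_cluster.append((span.start, span.end))
print(token_cluster)  # [(0, 2), (2, 3)]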
@@ -204,6 +204,13 @@ class SpanPredictor(TrainablePipe):

         total_loss = 0
         for eg in examples:
+            if eg.x.text != eg.y.text:
+                # TODO assign error number
+                raise ValueError(
+                    """Text, including whitespace, must match between reference and
+                    predicted docs in span predictor training.
+                    """
+                )
             span_scores, backprop = self.model.begin_update([eg.predicted])
             # FIXME, this only happens once in the first 1000 docs of OntoNotes
             # and I'm not sure yet why.
@@ -266,16 +273,29 @@ class SpanPredictor(TrainablePipe):
         for eg in examples:
             starts = []
             ends = []
+            keeps = []
+            sidx = 0
             for key, sg in eg.reference.spans.items():
                 if key.startswith(self.output_prefix):
-                    for mention in sg:
-                        starts.append(mention.start)
-                        ends.append(mention.end)
+                    for ii, mention in enumerate(sg):
+                        sidx += 1
+                        # convert to span in pred
+                        sch, ech = (mention.start_char, mention.end_char)
+                        span = eg.predicted.char_span(sch, ech)
+                        # TODO add to errors.py
+                        if span is None:
+                            warnings.warn("Could not align gold span in span predictor, skipping")
+                            continue
+                        starts.append(span.start)
+                        ends.append(span.end)
+                        keeps.append(sidx - 1)
+
             starts = self.model.ops.xp.asarray(starts)
             ends = self.model.ops.xp.asarray(ends)
-            start_scores = span_scores[:, :, 0]
-            end_scores = span_scores[:, :, 1]
+            start_scores = span_scores[:, :, 0][keeps]
+            end_scores = span_scores[:, :, 1][keeps]
+

             n_classes = start_scores.shape[1]
             start_probs = ops.softmax(start_scores, axis=1)
             end_probs = ops.softmax(end_scores, axis=1)
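Not part of the diff: a small numpy sketch of the "keeps" bookkeeping above. Gold mentions that cannot be aligned to the predicted tokenization are skipped, and only the matching score rows are kept; numpy stands in for the model ops, and the shapes and values are invented.

import numpy as np

span_scores = np.random.rand(4, 10, 2)  # (mention, token, start/end)
aligned = [True, False, True, True]     # mention 1 could not be aligned

keeps = [i for i, ok in enumerate(aligned) if ok]
start_scores = span_scores[:, :, 0][keeps]  # shape (3, 10)
end_scores = span_scores[:, :, 1][keeps]
print(start_scores.shape, end_scores.shape)  # (3, 10) (3, 10)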
@@ -283,7 +303,14 @@ class SpanPredictor(TrainablePipe):
             end_targets = to_categorical(ends, n_classes)
             start_grads = start_probs - start_targets
             end_grads = end_probs - end_targets
             grads = ops.xp.stack((start_grads, end_grads), axis=2)
+            # now return to original shape, with 0s
+            final_start_grads = ops.alloc2f(*span_scores[:, :, 0].shape)
+            final_start_grads[keeps] = start_grads
+            final_end_grads = ops.alloc2f(*final_start_grads.shape)
+            final_end_grads[keeps] = end_grads
+            # XXX Note this only works with fake batching
+            grads = ops.xp.stack((final_start_grads, final_end_grads), axis=2)

             loss = float((grads**2).sum())
             return loss, grads
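Not part of the diff: the gradient re-expansion above sketched with numpy (standing in for thinc's ops.alloc2f, which allocates zeros). Rows dropped during alignment keep a zero gradient; the shapes and values below are invented.

import numpy as np

span_scores = np.zeros((4, 10, 2))  # original (mention, token, start/end)
keeps = [0, 2, 3]                   # mentions that survived alignment
start_grads = np.ones((3, 10))      # gradients computed for the kept rows only

final_start_grads = np.zeros(span_scores[:, :, 0].shape)
final_start_grads[keeps] = start_grads  # scatter back; row 1 stays all zeros
print(final_start_grads.sum(axis=1))    # [10.  0. 10. 10.]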
@@ -311,6 +338,7 @@ class SpanPredictor(TrainablePipe):
             if not ex.predicted.spans:
                 # set placeholder for shape inference
                 doc = ex.predicted
+                # TODO should be able to check if there are some valid docs in the batch
                 assert len(doc) > 2, "Coreference requires at least two tokens"
                 doc.spans[f"{self.input_prefix}_0"] = [doc[0:1], doc[1:2]]
             X.append(ex.predicted)
@@ -9,6 +9,7 @@ from spacy.ml.models.coref_util import (
     DEFAULT_CLUSTER_PREFIX,
     select_non_crossing_spans,
     get_sentence_ids,
+    get_clusters_from_doc,
 )

 from thinc.util import has_torch
@@ -35,6 +36,9 @@ TRAIN_DATA = [
 # fmt: on


+CONFIG = {"model": {"@architectures": "spacy.Coref.v1", "tok2vec_size": 64}}
+
+
 @pytest.fixture
 def nlp():
     return English()
@@ -60,9 +64,10 @@ def test_not_initialized(nlp):
     with pytest.raises(ValueError, match="E109"):
         nlp(text)

+
 @pytest.mark.skipif(not has_torch, reason="Torch not available")
 def test_initialized(nlp):
-    nlp.add_pipe("coref")
+    nlp.add_pipe("coref", config=CONFIG)
     nlp.initialize()
     assert nlp.pipe_names == ["coref"]
     text = "She gave me her pen."
@@ -74,7 +79,7 @@ def test_initialized(nlp):

 @pytest.mark.skipif(not has_torch, reason="Torch not available")
 def test_initialized_short(nlp):
-    nlp.add_pipe("coref")
+    nlp.add_pipe("coref", config=CONFIG)
     nlp.initialize()
     assert nlp.pipe_names == ["coref"]
     text = "Hi there"
@@ -84,58 +89,47 @@ def test_initialized_short(nlp):
 @pytest.mark.skipif(not has_torch, reason="Torch not available")
 def test_coref_serialization(nlp):
     # Test that the coref component can be serialized
-    nlp.add_pipe("coref", last=True)
+    nlp.add_pipe("coref", last=True, config=CONFIG)
     nlp.initialize()
     assert nlp.pipe_names == ["coref"]
     text = "She gave me her pen."
     doc = nlp(text)
-    spans_result = doc.spans

     with make_tempdir() as tmp_dir:
         nlp.to_disk(tmp_dir)
         nlp2 = spacy.load(tmp_dir)
         assert nlp2.pipe_names == ["coref"]
         doc2 = nlp2(text)
-        spans_result2 = doc2.spans
-        print(1, [(k, len(v)) for k, v in spans_result.items()])
-        print(2, [(k, len(v)) for k, v in spans_result2.items()])
-        # Note: spans do not compare equal because docs are different and docs
-        # use object identity for equality
-        for k, v in spans_result.items():
-            assert str(spans_result[k]) == str(spans_result2[k])
-        # assert spans_result == spans_result2
+
+        assert get_clusters_from_doc(doc) == get_clusters_from_doc(doc2)


 @pytest.mark.skipif(not has_torch, reason="Torch not available")
 def test_overfitting_IO(nlp):
-    # Simple test to try and quickly overfit the senter - ensuring the ML models work correctly
+    # Simple test to try and quickly overfit - ensuring the ML models work correctly
     train_examples = []
     for text, annot in TRAIN_DATA:
         train_examples.append(Example.from_dict(nlp.make_doc(text), annot))

-    nlp.add_pipe("coref")
+    nlp.add_pipe("coref", config=CONFIG)
     optimizer = nlp.initialize()
     test_text = TRAIN_DATA[0][0]
     doc = nlp(test_text)
-    print("BEFORE", doc.spans)

-    for i in range(5):
+    # Needs ~12 epochs to converge
+    for i in range(15):
         losses = {}
         nlp.update(train_examples, sgd=optimizer, losses=losses)
         doc = nlp(test_text)
-        print(i, doc.spans)
-        print(losses["coref"]) # < 0.001

     # test the trained model
     doc = nlp(test_text)
-    print("AFTER", doc.spans)

     # Also test the results are still the same after IO
     with make_tempdir() as tmp_dir:
         nlp.to_disk(tmp_dir)
         nlp2 = util.load_model_from_path(tmp_dir)
         doc2 = nlp2(test_text)
-        print("doc2", doc2.spans)

     # Make sure that running pipe twice, or comparing to call, always amounts to the same predictions
     texts = [
@@ -143,14 +137,67 @@ def test_overfitting_IO(nlp):
         "I noticed many friends around me",
         "They received it. They received the SMS.",
     ]
-    batch_deps_1 = [doc.spans for doc in nlp.pipe(texts)]
-    print(batch_deps_1)
-    batch_deps_2 = [doc.spans for doc in nlp.pipe(texts)]
-    print(batch_deps_2)
-    no_batch_deps = [doc.spans for doc in [nlp(text) for text in texts]]
-    print(no_batch_deps)
-    # assert_equal(batch_deps_1, batch_deps_2)
-    # assert_equal(batch_deps_1, no_batch_deps)
+    docs1 = list(nlp.pipe(texts))
+    docs2 = list(nlp.pipe(texts))
+    docs3 = [nlp(text) for text in texts]
+    assert get_clusters_from_doc(docs1[0]) == get_clusters_from_doc(docs2[0])
+    assert get_clusters_from_doc(docs1[0]) == get_clusters_from_doc(docs3[0])
+
+
+@pytest.mark.skipif(not has_torch, reason="Torch not available")
+def test_tokenization_mismatch(nlp):
+    train_examples = []
+    for text, annot in TRAIN_DATA:
+        eg = Example.from_dict(nlp.make_doc(text), annot)
+        ref = eg.reference
+        char_spans = {}
+        for key, cluster in ref.spans.items():
+            char_spans[key] = []
+            for span in cluster:
+                char_spans[key].append((span[0].idx, span[-1].idx + len(span[-1])))
+        with ref.retokenize() as retokenizer:
+            # merge "many friends"
+            retokenizer.merge(ref[5:7])
+
+        # Note this works because it's the same doc and we know the keys
+        for key, _ in ref.spans.items():
+            spans = char_spans[key]
+            ref.spans[key] = [ref.char_span(*span) for span in spans]
+
+        train_examples.append(eg)
+
+    nlp.add_pipe("coref", config=CONFIG)
+    optimizer = nlp.initialize()
+    test_text = TRAIN_DATA[0][0]
+    doc = nlp(test_text)
+
+    for i in range(15):
+        losses = {}
+        nlp.update(train_examples, sgd=optimizer, losses=losses)
+        doc = nlp(test_text)
+
+    # test the trained model
+    doc = nlp(test_text)
+
+    # Also test the results are still the same after IO
+    with make_tempdir() as tmp_dir:
+        nlp.to_disk(tmp_dir)
+        nlp2 = util.load_model_from_path(tmp_dir)
+        doc2 = nlp2(test_text)
+
+    # Make sure that running pipe twice, or comparing to call, always amounts to the same predictions
+    texts = [
+        test_text,
+        "I noticed many friends around me",
+        "They received it. They received the SMS.",
+    ]
+
+    # save the docs so they don't get garbage collected
+    docs1 = list(nlp.pipe(texts))
+    docs2 = list(nlp.pipe(texts))
+    docs3 = [nlp(text) for text in texts]
+    assert get_clusters_from_doc(docs1[0]) == get_clusters_from_doc(docs2[0])
+    assert get_clusters_from_doc(docs1[0]) == get_clusters_from_doc(docs3[0])


 @pytest.mark.skipif(not has_torch, reason="Torch not available")
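Not part of the diff: the alignment trick test_tokenization_mismatch relies on, in isolation. Gold spans are saved as char offsets, the doc is retokenized, and the spans are rebuilt with char_span; the text and span key below are made up.

import spacy

nlp = spacy.blank("en")
doc = nlp("I noticed many friends around me")
saved = [(doc[2].idx, doc[3].idx + len(doc[3]))]  # char offsets of "many friends"

with doc.retokenize() as retokenizer:
    retokenizer.merge(doc[2:4])  # "many friends" becomes a single token

doc.spans["demo"] = [doc.char_span(*span) for span in saved]
print(len(doc), [s.text for s in doc.spans["demo"]])  # 5 ['many friends']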
@@ -165,8 +212,26 @@ def test_crossing_spans():
     guess = sorted(guess)
     assert gold == guess


+@pytest.mark.skipif(not has_torch, reason="Torch not available")
 def test_sentence_map(snlp):
     doc = snlp("I like text. This is text.")
     sm = get_sentence_ids(doc)
     assert sm == [0, 0, 0, 0, 1, 1, 1, 1]
+
+
+@pytest.mark.skipif(not has_torch, reason="Torch not available")
+def test_whitespace_mismatch(nlp):
+    train_examples = []
+    for text, annot in TRAIN_DATA:
+        eg = Example.from_dict(nlp.make_doc(text), annot)
+        eg.predicted = nlp.make_doc(" " + text)
+        train_examples.append(eg)
+
+    nlp.add_pipe("coref", config=CONFIG)
+    optimizer = nlp.initialize()
+    test_text = TRAIN_DATA[0][0]
+    doc = nlp(test_text)
+
+    with pytest.raises(ValueError, match="whitespace"):
+        nlp.update(train_examples, sgd=optimizer)
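Not part of the diff: what the whitespace tests set up. A single leading space makes eg.x.text and eg.y.text differ, which is exactly the condition the new update() checks reject; the sentence is invented.

import spacy
from spacy.training import Example

nlp = spacy.blank("en")
eg = Example(nlp.make_doc(" She gave me her pen."), nlp.make_doc("She gave me her pen."))
print(eg.x.text != eg.y.text)  # True -> coref/span_predictor update() raises ValueError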
spacy/tests/pipeline/test_span_predictor.py (new file, 227 lines added)
@@ -0,0 +1,227 @@
import pytest
import spacy

from spacy import util
from spacy.training import Example
from spacy.lang.en import English
from spacy.tests.util import make_tempdir
from spacy.ml.models.coref_util import (
    DEFAULT_CLUSTER_PREFIX,
    select_non_crossing_spans,
    get_sentence_ids,
    get_clusters_from_doc,
)

from thinc.util import has_torch

# fmt: off
TRAIN_DATA = [
    (
        "John Smith picked up the red ball and he threw it away.",
        {
            "spans": {
                f"{DEFAULT_CLUSTER_PREFIX}_1": [
                    (0, 10, "MENTION"),   # John Smith
                    (38, 40, "MENTION"),  # he

                ],
                f"{DEFAULT_CLUSTER_PREFIX}_2": [
                    (25, 33, "MENTION"),  # red ball
                    (47, 49, "MENTION"),  # it
                ],
                f"coref_head_clusters_1": [
                    (5, 10, "MENTION"),   # Smith
                    (38, 40, "MENTION"),  # he

                ],
                f"coref_head_clusters_2": [
                    (29, 33, "MENTION"),  # red ball
                    (47, 49, "MENTION"),  # it
                ]
            }
        },
    ),
]
# fmt: on

CONFIG = {"model": {"@architectures": "spacy.SpanPredictor.v1", "tok2vec_size": 64}}


@pytest.fixture
def nlp():
    return English()


@pytest.fixture
def snlp():
    en = English()
    en.add_pipe("sentencizer")
    return en


@pytest.mark.skipif(not has_torch, reason="Torch not available")
def test_add_pipe(nlp):
    nlp.add_pipe("span_predictor")
    assert nlp.pipe_names == ["span_predictor"]


@pytest.mark.skipif(not has_torch, reason="Torch not available")
def test_not_initialized(nlp):
    nlp.add_pipe("span_predictor")
    text = "She gave me her pen."
    with pytest.raises(ValueError, match="E109"):
        nlp(text)


@pytest.mark.skipif(not has_torch, reason="Torch not available")
def test_span_predictor_serialization(nlp):
    # Test that the span predictor component can be serialized
    nlp.add_pipe("span_predictor", last=True, config=CONFIG)
    nlp.initialize()
    assert nlp.pipe_names == ["span_predictor"]
    text = "She gave me her pen."
    doc = nlp(text)

    with make_tempdir() as tmp_dir:
        nlp.to_disk(tmp_dir)
        nlp2 = spacy.load(tmp_dir)
        assert nlp2.pipe_names == ["span_predictor"]
        doc2 = nlp2(text)

        assert get_clusters_from_doc(doc) == get_clusters_from_doc(doc2)


@pytest.mark.skipif(not has_torch, reason="Torch not available")
def test_overfitting_IO(nlp):
    # Simple test to try and quickly overfit - ensuring the ML models work correctly
    train_examples = []
    for text, annot in TRAIN_DATA:
        train_examples.append(Example.from_dict(nlp.make_doc(text), annot))

    train_examples = []
    for text, annot in TRAIN_DATA:
        eg = Example.from_dict(nlp.make_doc(text), annot)
        ref = eg.reference
        # Finally, copy over the head spans to the pred
        pred = eg.predicted
        for key, spans in ref.spans.items():
            if key.startswith("coref_head_clusters"):
                pred.spans[key] = [pred[span.start : span.end] for span in spans]

        train_examples.append(eg)
    nlp.add_pipe("span_predictor", config=CONFIG)
    optimizer = nlp.initialize()
    test_text = TRAIN_DATA[0][0]
    doc = nlp(test_text)

    for i in range(15):
        losses = {}
        nlp.update(train_examples, sgd=optimizer, losses=losses)
        doc = nlp(test_text)

    # test the trained model, using the pred since it has heads
    doc = nlp(train_examples[0].predicted)
    # XXX This actually tests that it can overfit
    assert get_clusters_from_doc(doc) == get_clusters_from_doc(train_examples[0].reference)

    # Also test the results are still the same after IO
    with make_tempdir() as tmp_dir:
        nlp.to_disk(tmp_dir)
        nlp2 = util.load_model_from_path(tmp_dir)
        doc2 = nlp2(test_text)

    # Make sure that running pipe twice, or comparing to call, always amounts to the same predictions
    texts = [
        test_text,
        "I noticed many friends around me",
        "They received it. They received the SMS.",
    ]
    # XXX Note these have no predictions because they have no input spans
    docs1 = list(nlp.pipe(texts))
    docs2 = list(nlp.pipe(texts))
    docs3 = [nlp(text) for text in texts]
    assert get_clusters_from_doc(docs1[0]) == get_clusters_from_doc(docs2[0])
    assert get_clusters_from_doc(docs1[0]) == get_clusters_from_doc(docs3[0])


@pytest.mark.skipif(not has_torch, reason="Torch not available")
def test_tokenization_mismatch(nlp):
    train_examples = []
    for text, annot in TRAIN_DATA:
        eg = Example.from_dict(nlp.make_doc(text), annot)
        ref = eg.reference
        char_spans = {}
        for key, cluster in ref.spans.items():
            char_spans[key] = []
            for span in cluster:
                char_spans[key].append((span.start_char, span.end_char))
        with ref.retokenize() as retokenizer:
            # merge "picked up"
            retokenizer.merge(ref[2:4])

        # Note this works because it's the same doc and we know the keys
        for key, _ in ref.spans.items():
            spans = char_spans[key]
            ref.spans[key] = [ref.char_span(*span) for span in spans]

        # Finally, copy over the head spans to the pred
        pred = eg.predicted
        for key, val in ref.spans.items():
            if key.startswith("coref_head_clusters"):
                spans = char_spans[key]
                pred.spans[key] = [pred.char_span(*span) for span in spans]

        train_examples.append(eg)

    nlp.add_pipe("span_predictor", config=CONFIG)
    optimizer = nlp.initialize()
    test_text = TRAIN_DATA[0][0]
    doc = nlp(test_text)

    for i in range(15):
        losses = {}
        nlp.update(train_examples, sgd=optimizer, losses=losses)
        doc = nlp(test_text)

    # test the trained model; need to use doc with head spans on it already
    test_doc = train_examples[0].predicted
    doc = nlp(test_doc)
    # XXX This actually tests that it can overfit
    assert get_clusters_from_doc(doc) == get_clusters_from_doc(train_examples[0].reference)

    # Also test the results are still the same after IO
    with make_tempdir() as tmp_dir:
        nlp.to_disk(tmp_dir)
        nlp2 = util.load_model_from_path(tmp_dir)
        doc2 = nlp2(test_text)

    # Make sure that running pipe twice, or comparing to call, always amounts to the same predictions
    texts = [
        test_text,
        "I noticed many friends around me",
        "They received it. They received the SMS.",
    ]

    # save the docs so they don't get garbage collected
    docs1 = list(nlp.pipe(texts))
    docs2 = list(nlp.pipe(texts))
    docs3 = [nlp(text) for text in texts]
    assert get_clusters_from_doc(docs1[0]) == get_clusters_from_doc(docs2[0])
    assert get_clusters_from_doc(docs1[0]) == get_clusters_from_doc(docs3[0])


@pytest.mark.skipif(not has_torch, reason="Torch not available")
def test_whitespace_mismatch(nlp):
    train_examples = []
    for text, annot in TRAIN_DATA:
        eg = Example.from_dict(nlp.make_doc(text), annot)
        eg.predicted = nlp.make_doc(" " + text)
        train_examples.append(eg)

    nlp.add_pipe("span_predictor", config=CONFIG)
    optimizer = nlp.initialize()
    test_text = TRAIN_DATA[0][0]
    doc = nlp(test_text)

    with pytest.raises(ValueError, match="whitespace"):
        nlp.update(train_examples, sgd=optimizer)
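Not part of the diff: why these tests copy "coref_head_clusters" spans onto the predicted doc. The span predictor reads head-cluster spans under its input prefix and writes full mention spans under its output prefix, so a doc without input spans yields no predictions; the doc below is illustrative and untrained.

import spacy

nlp = spacy.blank("en")
doc = nlp("John Smith picked up the red ball and he threw it away.")
# input spans a trained span_predictor pipe would expand into full mentions
doc.spans["coref_head_clusters_1"] = [doc[1:2], doc[8:9]]    # "Smith", "he"
doc.spans["coref_head_clusters_2"] = [doc[6:7], doc[10:11]]  # "ball", "it"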