mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-10 08:12:24 +03:00
only use a single spans_key like in spancat
This commit is contained in:
parent
90af16af76
commit
6f750d0da6
|
@ -1,4 +1,3 @@
|
||||||
from functools import partial
|
|
||||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, cast
|
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, cast
|
||||||
|
|
||||||
from thinc.api import Config, Model, Ops, Optimizer, get_current_ops, set_dropout_rate
|
from thinc.api import Config, Model, Ops, Optimizer, get_current_ops, set_dropout_rate
|
||||||
|
@ -41,7 +40,6 @@ depth = 4
|
||||||
"""
|
"""
|
||||||
|
|
||||||
DEFAULT_SPAN_FINDER_MODEL = Config().from_str(span_finder_default_config)["model"]
|
DEFAULT_SPAN_FINDER_MODEL = Config().from_str(span_finder_default_config)["model"]
|
||||||
DEFAULT_PREDICTED_KEY = "span_candidates"
|
|
||||||
|
|
||||||
|
|
||||||
@Language.factory(
|
@Language.factory(
|
||||||
|
@ -50,21 +48,15 @@ DEFAULT_PREDICTED_KEY = "span_candidates"
|
||||||
default_config={
|
default_config={
|
||||||
"threshold": 0.5,
|
"threshold": 0.5,
|
||||||
"model": DEFAULT_SPAN_FINDER_MODEL,
|
"model": DEFAULT_SPAN_FINDER_MODEL,
|
||||||
"predicted_key": DEFAULT_PREDICTED_KEY,
|
"spans_key": DEFAULT_SPANS_KEY,
|
||||||
"training_key": DEFAULT_SPANS_KEY,
|
|
||||||
# XXX Doesn't 0 seem bad compared to None instead?
|
|
||||||
"max_length": None,
|
"max_length": None,
|
||||||
"min_length": None,
|
"min_length": None,
|
||||||
"scorer": {
|
"scorer": {"@scorers": "spacy.span_finder_scorer.v1"},
|
||||||
"@scorers": "spacy.span_finder_scorer.v1",
|
|
||||||
"predicted_key": DEFAULT_PREDICTED_KEY,
|
|
||||||
"training_key": DEFAULT_SPANS_KEY,
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
default_score_weights={
|
default_score_weights={
|
||||||
f"span_finder_{DEFAULT_PREDICTED_KEY}_f": 1.0,
|
f"span_finder_{DEFAULT_SPANS_KEY}_f": 1.0,
|
||||||
f"span_finder_{DEFAULT_PREDICTED_KEY}_p": 0.0,
|
f"span_finder_{DEFAULT_SPANS_KEY}_p": 0.0,
|
||||||
f"span_finder_{DEFAULT_PREDICTED_KEY}_r": 0.0,
|
f"span_finder_{DEFAULT_SPANS_KEY}_r": 0.0,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
def make_span_finder(
|
def make_span_finder(
|
||||||
|
@ -75,8 +67,7 @@ def make_span_finder(
|
||||||
threshold: float,
|
threshold: float,
|
||||||
max_length: Optional[int],
|
max_length: Optional[int],
|
||||||
min_length: Optional[int],
|
min_length: Optional[int],
|
||||||
predicted_key: str = DEFAULT_PREDICTED_KEY,
|
spans_key: str,
|
||||||
training_key: str = DEFAULT_SPANS_KEY,
|
|
||||||
) -> "SpanFinder":
|
) -> "SpanFinder":
|
||||||
"""Create a SpanFinder component. The component predicts whether a token is
|
"""Create a SpanFinder component. The component predicts whether a token is
|
||||||
the start or the end of a potential span.
|
the start or the end of a potential span.
|
||||||
|
@ -84,10 +75,9 @@ def make_span_finder(
|
||||||
model (Model[List[Doc], Floats2d]): A model instance that
|
model (Model[List[Doc], Floats2d]): A model instance that
|
||||||
is given a list of documents and predicts a probability for each token.
|
is given a list of documents and predicts a probability for each token.
|
||||||
threshold (float): Minimum probability to consider a prediction positive.
|
threshold (float): Minimum probability to consider a prediction positive.
|
||||||
predicted_key (str): Name of the span group the predicted spans are saved
|
spans_key (str): Key of the doc.spans dict to save the spans under. During
|
||||||
to
|
initialization and training, the component will look for spans on the
|
||||||
training_key (str): Name of the span group the training spans are read
|
reference document under the same key.
|
||||||
from
|
|
||||||
max_length (Optional[int]): Max length of the produced spans, defaults to None meaning unlimited length.
|
max_length (Optional[int]): Max length of the produced spans, defaults to None meaning unlimited length.
|
||||||
min_length (Optional[int]): Min length of the produced spans, defaults to None meaining shortest span is length 1.
|
min_length (Optional[int]): Min length of the produced spans, defaults to None meaining shortest span is length 1.
|
||||||
"""
|
"""
|
||||||
|
@ -99,51 +89,26 @@ def make_span_finder(
|
||||||
scorer=scorer,
|
scorer=scorer,
|
||||||
max_length=max_length,
|
max_length=max_length,
|
||||||
min_length=min_length,
|
min_length=min_length,
|
||||||
predicted_key=predicted_key,
|
spans_key=spans_key,
|
||||||
training_key=training_key,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@registry.scorers("spacy.span_finder_scorer.v1")
|
@registry.scorers("spacy.span_finder_scorer.v1")
|
||||||
def make_span_finder_scorer(
|
def make_span_finder_scorer():
|
||||||
predicted_key: str = DEFAULT_PREDICTED_KEY,
|
return span_finder_score
|
||||||
training_key: str = DEFAULT_SPANS_KEY,
|
|
||||||
):
|
|
||||||
return partial(
|
|
||||||
span_finder_score, predicted_key=predicted_key, training_key=training_key
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def span_finder_score(
|
def span_finder_score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
|
||||||
examples: Iterable[Example],
|
|
||||||
*,
|
|
||||||
predicted_key: str = DEFAULT_PREDICTED_KEY,
|
|
||||||
training_key: str = DEFAULT_SPANS_KEY,
|
|
||||||
**kwargs,
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
kwargs = dict(kwargs)
|
kwargs = dict(kwargs)
|
||||||
|
print(kwargs)
|
||||||
attr_prefix = "span_finder_"
|
attr_prefix = "span_finder_"
|
||||||
kwargs.setdefault("attr", f"{attr_prefix}{predicted_key}")
|
key = kwargs["spans_key"]
|
||||||
kwargs.setdefault("allow_overlap", True)
|
kwargs.setdefault("attr", f"{attr_prefix}{key}")
|
||||||
kwargs.setdefault(
|
kwargs.setdefault(
|
||||||
"getter", lambda doc, key: doc.spans.get(key[len(attr_prefix) :], [])
|
"getter", lambda doc, key: doc.spans.get(key[len(attr_prefix) :], [])
|
||||||
)
|
)
|
||||||
kwargs.setdefault("labeled", False)
|
kwargs.setdefault("has_annotation", lambda doc: key in doc.spans)
|
||||||
kwargs.setdefault("has_annotation", lambda doc: predicted_key in doc.spans)
|
return Scorer.score_spans(examples, **kwargs)
|
||||||
# score_spans can only score spans with the same key in both the reference
|
|
||||||
# and predicted docs, so temporarily copy the reference spans from the
|
|
||||||
# reference key to the candidates key in the reference docs, restoring the
|
|
||||||
# original span groups afterwards
|
|
||||||
orig_span_groups = []
|
|
||||||
for eg in examples:
|
|
||||||
orig_span_groups.append(eg.reference.spans.get(predicted_key))
|
|
||||||
if training_key in eg.reference.spans:
|
|
||||||
eg.reference.spans[predicted_key] = eg.reference.spans[training_key]
|
|
||||||
scores = Scorer.score_spans(examples, **kwargs)
|
|
||||||
for orig_span_group, eg in zip(orig_span_groups, examples):
|
|
||||||
if orig_span_group is not None:
|
|
||||||
eg.reference.spans[predicted_key] = orig_span_group
|
|
||||||
return scores
|
|
||||||
|
|
||||||
|
|
||||||
class _MaxInt(int):
|
class _MaxInt(int):
|
||||||
|
@ -179,13 +144,8 @@ class SpanFinder(TrainablePipe):
|
||||||
max_length: Optional[int] = None,
|
max_length: Optional[int] = None,
|
||||||
min_length: Optional[int] = None,
|
min_length: Optional[int] = None,
|
||||||
# XXX I think this is weird and should be just None like in
|
# XXX I think this is weird and should be just None like in
|
||||||
scorer: Optional[Callable] = partial(
|
scorer: Optional[Callable] = span_finder_score,
|
||||||
span_finder_score,
|
spans_key: str = DEFAULT_SPANS_KEY,
|
||||||
predicted_key=DEFAULT_PREDICTED_KEY,
|
|
||||||
training_key=DEFAULT_SPANS_KEY,
|
|
||||||
),
|
|
||||||
predicted_key: str = DEFAULT_PREDICTED_KEY,
|
|
||||||
training_key: str = DEFAULT_SPANS_KEY,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize the span boundary detector.
|
"""Initialize the span boundary detector.
|
||||||
model (thinc.api.Model): The Thinc Model powering the pipeline component.
|
model (thinc.api.Model): The Thinc Model powering the pipeline component.
|
||||||
|
@ -194,8 +154,9 @@ class SpanFinder(TrainablePipe):
|
||||||
threshold (float): Minimum probability to consider a prediction
|
threshold (float): Minimum probability to consider a prediction
|
||||||
positive.
|
positive.
|
||||||
scorer (Optional[Callable]): The scoring method.
|
scorer (Optional[Callable]): The scoring method.
|
||||||
predicted_key (str): Name of the span group the candidate spans are saved to
|
spans_key (str): Key of the doc.spans dict to save the spans under. During
|
||||||
training_key (str): Name of the span group the training spans are read from
|
initialization and training, the component will look for spans on the
|
||||||
|
reference document under the same key.
|
||||||
max_length (Optional[int]): Max length of the produced spans, defaults to None meaning unlimited length.
|
max_length (Optional[int]): Max length of the produced spans, defaults to None meaning unlimited length.
|
||||||
min_length (Optional[int]): Min length of the produced spans, defaults to None meaining shortest span is length 1.
|
min_length (Optional[int]): Min length of the produced spans, defaults to None meaining shortest span is length 1.
|
||||||
"""
|
"""
|
||||||
|
@ -211,11 +172,11 @@ class SpanFinder(TrainablePipe):
|
||||||
)
|
)
|
||||||
self.min_length = min_length
|
self.min_length = min_length
|
||||||
self.max_length = max_length
|
self.max_length = max_length
|
||||||
self.predicted_key = predicted_key
|
self.spans_key = spans_key
|
||||||
self.training_key = training_key
|
|
||||||
self.model = model
|
self.model = model
|
||||||
self.name = name
|
self.name = name
|
||||||
self.scorer = scorer
|
self.scorer = scorer
|
||||||
|
self.cfg = {"spans_key": spans_key}
|
||||||
|
|
||||||
def predict(self, docs: Iterable[Doc]):
|
def predict(self, docs: Iterable[Doc]):
|
||||||
"""Apply the pipeline's model to a batch of docs, without modifying them.
|
"""Apply the pipeline's model to a batch of docs, without modifying them.
|
||||||
|
@ -232,7 +193,7 @@ class SpanFinder(TrainablePipe):
|
||||||
"""
|
"""
|
||||||
offset = 0
|
offset = 0
|
||||||
for i, doc in enumerate(docs):
|
for i, doc in enumerate(docs):
|
||||||
doc.spans[self.predicted_key] = []
|
doc.spans[self.spans_key] = []
|
||||||
starts = []
|
starts = []
|
||||||
ends = []
|
ends = []
|
||||||
doc_scores = scores[offset : offset + len(doc)]
|
doc_scores = scores[offset : offset + len(doc)]
|
||||||
|
@ -249,7 +210,7 @@ class SpanFinder(TrainablePipe):
|
||||||
if span_length > self.max_length:
|
if span_length > self.max_length:
|
||||||
break
|
break
|
||||||
elif self.min_length <= span_length:
|
elif self.min_length <= span_length:
|
||||||
doc.spans[self.predicted_key].append(doc[start : end + 1])
|
doc.spans[self.spans_key].append(doc[start : end + 1])
|
||||||
|
|
||||||
def update(
|
def update(
|
||||||
self,
|
self,
|
||||||
|
@ -304,8 +265,8 @@ class SpanFinder(TrainablePipe):
|
||||||
n_tokens = len(eg.predicted)
|
n_tokens = len(eg.predicted)
|
||||||
truth = ops.xp.zeros((n_tokens, 2), dtype="float32")
|
truth = ops.xp.zeros((n_tokens, 2), dtype="float32")
|
||||||
mask = ops.xp.ones((n_tokens, 2), dtype="float32")
|
mask = ops.xp.ones((n_tokens, 2), dtype="float32")
|
||||||
if self.training_key in eg.reference.spans:
|
if self.spans_key in eg.reference.spans:
|
||||||
for span in eg.reference.spans[self.training_key]:
|
for span in eg.reference.spans[self.spans_key]:
|
||||||
ref_start_char, ref_end_char = _char_indices(span)
|
ref_start_char, ref_end_char = _char_indices(span)
|
||||||
pred_span = eg.predicted.char_span(
|
pred_span = eg.predicted.char_span(
|
||||||
ref_start_char, ref_end_char, alignment_mode="expand"
|
ref_start_char, ref_end_char, alignment_mode="expand"
|
||||||
|
@ -342,8 +303,8 @@ class SpanFinder(TrainablePipe):
|
||||||
start_indices = set()
|
start_indices = set()
|
||||||
end_indices = set()
|
end_indices = set()
|
||||||
|
|
||||||
if self.training_key in doc.spans:
|
if self.spans_key in doc.spans:
|
||||||
for span in doc.spans[self.training_key]:
|
for span in doc.spans[self.spans_key]:
|
||||||
start_indices.add(span.start)
|
start_indices.add(span.start)
|
||||||
end_indices.add(span.end - 1)
|
end_indices.add(span.end - 1)
|
||||||
|
|
||||||
|
|
|
@ -4,7 +4,7 @@ from thinc.types import Ragged
|
||||||
|
|
||||||
from spacy.language import Language
|
from spacy.language import Language
|
||||||
from spacy.lang.en import English
|
from spacy.lang.en import English
|
||||||
from spacy.pipeline.span_finder import DEFAULT_PREDICTED_KEY, span_finder_default_config
|
from spacy.pipeline.span_finder import span_finder_default_config
|
||||||
from spacy.tokens import Doc
|
from spacy.tokens import Doc
|
||||||
from spacy.training import Example
|
from spacy.training import Example
|
||||||
from spacy import util
|
from spacy import util
|
||||||
|
@ -12,22 +12,22 @@ from spacy.util import registry
|
||||||
from spacy.util import fix_random_seed, make_tempdir
|
from spacy.util import fix_random_seed, make_tempdir
|
||||||
|
|
||||||
|
|
||||||
TRAINING_KEY = "pytest"
|
SPANS_KEY = "pytest"
|
||||||
TRAIN_DATA = [
|
TRAIN_DATA = [
|
||||||
("Who is Shaka Khan?", {"spans": {TRAINING_KEY: [(7, 17)]}}),
|
("Who is Shaka Khan?", {"spans": {SPANS_KEY: [(7, 17)]}}),
|
||||||
(
|
(
|
||||||
"I like London and Berlin.",
|
"I like London and Berlin.",
|
||||||
{"spans": {TRAINING_KEY: [(7, 13, "LOC"), (18, 24)]}},
|
{"spans": {SPANS_KEY: [(7, 13), (18, 24)]}},
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
TRAIN_DATA_OVERLAPPING = [
|
TRAIN_DATA_OVERLAPPING = [
|
||||||
("Who is Shaka Khan?", {"spans": {TRAINING_KEY: [(7, 17)]}}),
|
("Who is Shaka Khan?", {"spans": {SPANS_KEY: [(7, 17)]}}),
|
||||||
(
|
(
|
||||||
"I like London and Berlin",
|
"I like London and Berlin",
|
||||||
{"spans": {TRAINING_KEY: [(7, 13), (18, 24), (7, 24)]}},
|
{"spans": {SPANS_KEY: [(7, 13), (18, 24), (7, 24)]}},
|
||||||
),
|
),
|
||||||
("", {"spans": {TRAINING_KEY: []}}),
|
("", {"spans": {SPANS_KEY: []}}),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -88,8 +88,8 @@ def test_loss_alignment_example(tokens_predicted, tokens_reference, reference_tr
|
||||||
nlp.vocab, words=tokens_reference, spaces=[False] * len(tokens_reference)
|
nlp.vocab, words=tokens_reference, spaces=[False] * len(tokens_reference)
|
||||||
)
|
)
|
||||||
example = Example(predicted, reference)
|
example = Example(predicted, reference)
|
||||||
example.reference.spans[TRAINING_KEY] = [example.reference.char_span(5, 9)]
|
example.reference.spans[SPANS_KEY] = [example.reference.char_span(5, 9)]
|
||||||
span_finder = nlp.add_pipe("span_finder", config={"training_key": TRAINING_KEY})
|
span_finder = nlp.add_pipe("span_finder", config={"spans_key": SPANS_KEY})
|
||||||
nlp.initialize()
|
nlp.initialize()
|
||||||
ops = span_finder.model.ops
|
ops = span_finder.model.ops
|
||||||
if predicted.text != reference.text:
|
if predicted.text != reference.text:
|
||||||
|
@ -107,8 +107,8 @@ def test_span_finder_model():
|
||||||
nlp = Language()
|
nlp = Language()
|
||||||
|
|
||||||
docs = [nlp("This is an example."), nlp("This is the second example.")]
|
docs = [nlp("This is an example."), nlp("This is the second example.")]
|
||||||
docs[0].spans[TRAINING_KEY] = [docs[0][3:4]]
|
docs[0].spans[SPANS_KEY] = [docs[0][3:4]]
|
||||||
docs[1].spans[TRAINING_KEY] = [docs[1][3:5]]
|
docs[1].spans[SPANS_KEY] = [docs[1][3:5]]
|
||||||
|
|
||||||
total_tokens = 0
|
total_tokens = 0
|
||||||
for doc in docs:
|
for doc in docs:
|
||||||
|
@ -128,15 +128,15 @@ def test_span_finder_component():
|
||||||
nlp = Language()
|
nlp = Language()
|
||||||
|
|
||||||
docs = [nlp("This is an example."), nlp("This is the second example.")]
|
docs = [nlp("This is an example."), nlp("This is the second example.")]
|
||||||
docs[0].spans[TRAINING_KEY] = [docs[0][3:4]]
|
docs[0].spans[SPANS_KEY] = [docs[0][3:4]]
|
||||||
docs[1].spans[TRAINING_KEY] = [docs[1][3:5]]
|
docs[1].spans[SPANS_KEY] = [docs[1][3:5]]
|
||||||
|
|
||||||
span_finder = nlp.add_pipe("span_finder", config={"training_key": TRAINING_KEY})
|
span_finder = nlp.add_pipe("span_finder", config={"spans_key": SPANS_KEY})
|
||||||
nlp.initialize()
|
nlp.initialize()
|
||||||
docs = list(span_finder.pipe(docs))
|
docs = list(span_finder.pipe(docs))
|
||||||
|
|
||||||
# TODO: update hard-coded name
|
# TODO: update hard-coded name
|
||||||
assert "span_candidates" in docs[0].spans
|
assert SPANS_KEY in docs[0].spans
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
@ -153,7 +153,7 @@ def test_set_annotations_span_lengths(min_length, max_length, span_count):
|
||||||
config={
|
config={
|
||||||
"max_length": max_length,
|
"max_length": max_length,
|
||||||
"min_length": min_length,
|
"min_length": min_length,
|
||||||
"training_key": TRAINING_KEY,
|
"spans_key": SPANS_KEY,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
@ -162,7 +162,7 @@ def test_set_annotations_span_lengths(min_length, max_length, span_count):
|
||||||
config={
|
config={
|
||||||
"max_length": max_length,
|
"max_length": max_length,
|
||||||
"min_length": min_length,
|
"min_length": min_length,
|
||||||
"training_key": TRAINING_KEY,
|
"spans_key": SPANS_KEY,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
nlp.initialize()
|
nlp.initialize()
|
||||||
|
@ -182,8 +182,8 @@ def test_set_annotations_span_lengths(min_length, max_length, span_count):
|
||||||
]
|
]
|
||||||
span_finder.set_annotations([doc], scores)
|
span_finder.set_annotations([doc], scores)
|
||||||
|
|
||||||
assert doc.spans[DEFAULT_PREDICTED_KEY]
|
assert doc.spans[SPANS_KEY]
|
||||||
assert len(doc.spans[DEFAULT_PREDICTED_KEY]) == span_count
|
assert len(doc.spans[SPANS_KEY]) == span_count
|
||||||
|
|
||||||
# Assert below will fail when max_length is set to 0
|
# Assert below will fail when max_length is set to 0
|
||||||
if max_length is None:
|
if max_length is None:
|
||||||
|
@ -193,40 +193,39 @@ def test_set_annotations_span_lengths(min_length, max_length, span_count):
|
||||||
|
|
||||||
assert all(
|
assert all(
|
||||||
min_length <= len(span) <= max_length
|
min_length <= len(span) <= max_length
|
||||||
for span in doc.spans[DEFAULT_PREDICTED_KEY]
|
for span in doc.spans[SPANS_KEY]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_span_finder_suggester():
|
def test_span_finder_suggester():
|
||||||
nlp = Language()
|
nlp = Language()
|
||||||
docs = [nlp("This is an example."), nlp("This is the second example.")]
|
docs = [nlp("This is an example."), nlp("This is the second example.")]
|
||||||
docs[0].spans[TRAINING_KEY] = [docs[0][3:4]]
|
docs[0].spans[SPANS_KEY] = [docs[0][3:4]]
|
||||||
docs[1].spans[TRAINING_KEY] = [docs[1][3:5]]
|
docs[1].spans[SPANS_KEY] = [docs[1][3:5]]
|
||||||
span_finder = nlp.add_pipe("span_finder", config={"training_key": TRAINING_KEY})
|
span_finder = nlp.add_pipe("span_finder", config={"spans_key": SPANS_KEY})
|
||||||
nlp.initialize()
|
nlp.initialize()
|
||||||
span_finder.set_annotations(docs, span_finder.predict(docs))
|
span_finder.set_annotations(docs, span_finder.predict(docs))
|
||||||
|
|
||||||
suggester = registry.misc.get("spacy.span_finder_suggester.v1")(
|
suggester = registry.misc.get("spacy.span_finder_suggester.v1")(
|
||||||
candidates_key="span_candidates"
|
candidates_key=SPANS_KEY
|
||||||
)
|
)
|
||||||
|
|
||||||
candidates = suggester(docs)
|
candidates = suggester(docs)
|
||||||
|
|
||||||
span_length = 0
|
span_length = 0
|
||||||
for doc in docs:
|
for doc in docs:
|
||||||
span_length += len(doc.spans["span_candidates"])
|
span_length += len(doc.spans[SPANS_KEY])
|
||||||
|
|
||||||
assert span_length == len(candidates.dataXd)
|
assert span_length == len(candidates.dataXd)
|
||||||
assert type(candidates) == Ragged
|
assert type(candidates) == Ragged
|
||||||
assert len(candidates.dataXd[0]) == 2
|
assert len(candidates.dataXd[0]) == 2
|
||||||
|
|
||||||
|
|
||||||
# XXX Fails because i think the suggester is not correctly implemented?
|
|
||||||
def test_overfitting_IO():
|
def test_overfitting_IO():
|
||||||
# Simple test to try and quickly overfit the spancat component - ensuring the ML models work correctly
|
# Simple test to try and quickly overfit the spancat component - ensuring the ML models work correctly
|
||||||
fix_random_seed(0)
|
fix_random_seed(0)
|
||||||
nlp = English()
|
nlp = English()
|
||||||
span_finder = nlp.add_pipe("span_finder", config={"training_key": TRAINING_KEY})
|
span_finder = nlp.add_pipe("span_finder", config={"spans_key": SPANS_KEY})
|
||||||
train_examples = make_examples(nlp)
|
train_examples = make_examples(nlp)
|
||||||
optimizer = nlp.initialize(get_examples=lambda: train_examples)
|
optimizer = nlp.initialize(get_examples=lambda: train_examples)
|
||||||
assert span_finder.model.get_dim("nO") == 2
|
assert span_finder.model.get_dim("nO") == 2
|
||||||
|
@ -239,30 +238,27 @@ def test_overfitting_IO():
|
||||||
# test the trained model
|
# test the trained model
|
||||||
test_text = "I like London and Berlin"
|
test_text = "I like London and Berlin"
|
||||||
doc = nlp(test_text)
|
doc = nlp(test_text)
|
||||||
spans = doc.spans[span_finder.predicted_key]
|
spans = doc.spans[span_finder.spans_key]
|
||||||
assert len(spans) == 2
|
assert len(spans) == 3
|
||||||
assert len(spans.attrs["scores"]) == 2
|
assert set([span.text for span in spans]) == {"London", "Berlin", "London and Berlin"}
|
||||||
assert min(spans.attrs["scores"]) > 0.9
|
|
||||||
assert set([span.text for span in spans]) == {"London", "Berlin"}
|
|
||||||
|
|
||||||
# Also test the results are still the same after IO
|
# Also test the results are still the same after IO
|
||||||
with make_tempdir() as tmp_dir:
|
with make_tempdir() as tmp_dir:
|
||||||
nlp.to_disk(tmp_dir)
|
nlp.to_disk(tmp_dir)
|
||||||
nlp2 = util.load_model_from_path(tmp_dir)
|
nlp2 = util.load_model_from_path(tmp_dir)
|
||||||
doc2 = nlp2(test_text)
|
doc2 = nlp2(test_text)
|
||||||
spans2 = doc2.spans[TRAINING_KEY]
|
spans2 = doc2.spans[span_finder.spans_key]
|
||||||
assert len(spans2) == 2
|
assert len(spans2) == 3
|
||||||
assert len(spans2.attrs["scores"]) == 2
|
assert set([span.text for span in spans2]) == {"London", "Berlin", "London and Berlin"}
|
||||||
assert min(spans2.attrs["scores"]) > 0.9
|
|
||||||
assert set([span.text for span in spans2]) == {"London", "Berlin"}
|
|
||||||
|
|
||||||
# Test scoring
|
# Test scoring
|
||||||
scores = nlp.evaluate(train_examples)
|
scores = nlp.evaluate(train_examples)
|
||||||
assert f"spans_{TRAINING_KEY}_f" in scores
|
sf = nlp.get_pipe("span_finder")
|
||||||
assert scores[f"spans_{TRAINING_KEY}_p"] == 1.0
|
print(sf.spans_key)
|
||||||
assert scores[f"spans_{TRAINING_KEY}_r"] == 1.0
|
assert f"span_finder_{span_finder.spans_key}_f" in scores
|
||||||
assert scores[f"spans_{TRAINING_KEY}_f"] == 1.0
|
# XXX Its not perfect 1.0 F1 because we want it to overgenerate for now.
|
||||||
|
assert scores[f"span_finder_{span_finder.spans_key}_f"] == 0.4
|
||||||
|
|
||||||
# also test that the spancat works for just a single entity in a sentence
|
# also test that the spancat works for just a single entity in a sentence
|
||||||
doc = nlp("London")
|
doc = nlp("London")
|
||||||
assert len(doc.spans[span_finder.predicted_key]) == 1
|
assert len(doc.spans[span_finder.spans_key]) == 1
|
||||||
|
|
Loading…
Reference in New Issue
Block a user