Mirror of https://github.com/explosion/spaCy.git
only use a single spans_key like in spancat
Commit 6f750d0da6 (parent 90af16af76)
@@ -1,4 +1,3 @@
-from functools import partial
 from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, cast
 
 from thinc.api import Config, Model, Ops, Optimizer, get_current_ops, set_dropout_rate
@@ -41,7 +40,6 @@ depth = 4
 """
 
 DEFAULT_SPAN_FINDER_MODEL = Config().from_str(span_finder_default_config)["model"]
-DEFAULT_PREDICTED_KEY = "span_candidates"
 
 
 @Language.factory(
@@ -50,21 +48,15 @@ DEFAULT_PREDICTED_KEY = "span_candidates"
     default_config={
         "threshold": 0.5,
         "model": DEFAULT_SPAN_FINDER_MODEL,
-        "predicted_key": DEFAULT_PREDICTED_KEY,
-        "training_key": DEFAULT_SPANS_KEY,
-        # XXX Doesn't 0 seem bad compared to None instead?
+        "spans_key": DEFAULT_SPANS_KEY,
         "max_length": None,
         "min_length": None,
-        "scorer": {
-            "@scorers": "spacy.span_finder_scorer.v1",
-            "predicted_key": DEFAULT_PREDICTED_KEY,
-            "training_key": DEFAULT_SPANS_KEY,
-        },
+        "scorer": {"@scorers": "spacy.span_finder_scorer.v1"},
     },
     default_score_weights={
-        f"span_finder_{DEFAULT_PREDICTED_KEY}_f": 1.0,
-        f"span_finder_{DEFAULT_PREDICTED_KEY}_p": 0.0,
-        f"span_finder_{DEFAULT_PREDICTED_KEY}_r": 0.0,
+        f"span_finder_{DEFAULT_SPANS_KEY}_f": 1.0,
+        f"span_finder_{DEFAULT_SPANS_KEY}_p": 0.0,
+        f"span_finder_{DEFAULT_SPANS_KEY}_r": 0.0,
     },
 )
 def make_span_finder(
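
For orientation, a minimal usage sketch of the component after this change: a single spans_key now controls both where predictions are written and where reference spans are read from during training. The literal key "sc" below is an assumption; DEFAULT_SPANS_KEY is imported from spancat and its value is not shown in this diff.

    import spacy

    nlp = spacy.blank("en")
    # one key instead of separate predicted_key / training_key
    span_finder = nlp.add_pipe("span_finder", config={"spans_key": "sc"})
    # after initialization/training, boundary predictions land in doc.spans["sc"]
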
@@ -75,8 +67,7 @@ def make_span_finder(
     threshold: float,
     max_length: Optional[int],
     min_length: Optional[int],
-    predicted_key: str = DEFAULT_PREDICTED_KEY,
-    training_key: str = DEFAULT_SPANS_KEY,
+    spans_key: str,
 ) -> "SpanFinder":
     """Create a SpanFinder component. The component predicts whether a token is
     the start or the end of a potential span.
@@ -84,10 +75,9 @@ def make_span_finder(
     model (Model[List[Doc], Floats2d]): A model instance that
         is given a list of documents and predicts a probability for each token.
     threshold (float): Minimum probability to consider a prediction positive.
-    predicted_key (str): Name of the span group the predicted spans are saved
-        to
-    training_key (str): Name of the span group the training spans are read
-        from
+    spans_key (str): Key of the doc.spans dict to save the spans under. During
+        initialization and training, the component will look for spans on the
+        reference document under the same key.
     max_length (Optional[int]): Max length of the produced spans, defaults to None meaning unlimited length.
     min_length (Optional[int]): Min length of the produced spans, defaults to None meaining shortest span is length 1.
     """
@@ -99,51 +89,26 @@ def make_span_finder(
         scorer=scorer,
         max_length=max_length,
         min_length=min_length,
-        predicted_key=predicted_key,
-        training_key=training_key,
+        spans_key=spans_key,
     )
 
 
 @registry.scorers("spacy.span_finder_scorer.v1")
-def make_span_finder_scorer(
-    predicted_key: str = DEFAULT_PREDICTED_KEY,
-    training_key: str = DEFAULT_SPANS_KEY,
-):
-    return partial(
-        span_finder_score, predicted_key=predicted_key, training_key=training_key
-    )
+def make_span_finder_scorer():
+    return span_finder_score
 
 
-def span_finder_score(
-    examples: Iterable[Example],
-    *,
-    predicted_key: str = DEFAULT_PREDICTED_KEY,
-    training_key: str = DEFAULT_SPANS_KEY,
-    **kwargs,
-) -> Dict[str, Any]:
+def span_finder_score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
     kwargs = dict(kwargs)
+    print(kwargs)
     attr_prefix = "span_finder_"
-    kwargs.setdefault("attr", f"{attr_prefix}{predicted_key}")
-    kwargs.setdefault("allow_overlap", True)
+    key = kwargs["spans_key"]
+    kwargs.setdefault("attr", f"{attr_prefix}{key}")
     kwargs.setdefault(
         "getter", lambda doc, key: doc.spans.get(key[len(attr_prefix) :], [])
     )
-    kwargs.setdefault("labeled", False)
-    kwargs.setdefault("has_annotation", lambda doc: predicted_key in doc.spans)
-    # score_spans can only score spans with the same key in both the reference
-    # and predicted docs, so temporarily copy the reference spans from the
-    # reference key to the candidates key in the reference docs, restoring the
-    # original span groups afterwards
-    orig_span_groups = []
-    for eg in examples:
-        orig_span_groups.append(eg.reference.spans.get(predicted_key))
-        if training_key in eg.reference.spans:
-            eg.reference.spans[predicted_key] = eg.reference.spans[training_key]
-    scores = Scorer.score_spans(examples, **kwargs)
-    for orig_span_group, eg in zip(orig_span_groups, examples):
-        if orig_span_group is not None:
-            eg.reference.spans[predicted_key] = orig_span_group
-    return scores
+    kwargs.setdefault("has_annotation", lambda doc: key in doc.spans)
+    return Scorer.score_spans(examples, **kwargs)
 
 
 class _MaxInt(int):
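
A hedged sketch of what the simplified scorer now does: it reads the span group key from kwargs["spans_key"] and reports metrics under span_finder_<spans_key>_p/r/f, so the old copy-and-restore workaround between the predicted and training keys is no longer needed. The literal key "sc" below is an assumption.

    # assuming train_examples is a list of spacy.training.Example objects and the
    # component was configured with spans_key="sc"
    scores = span_finder_score(train_examples, spans_key="sc")
    print(scores["span_finder_sc_f"])  # boundary F-score reported under the single key
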
@@ -179,13 +144,8 @@ class SpanFinder(TrainablePipe):
         max_length: Optional[int] = None,
         min_length: Optional[int] = None,
-        # XXX I think this is weird and should be just None like in
-        scorer: Optional[Callable] = partial(
-            span_finder_score,
-            predicted_key=DEFAULT_PREDICTED_KEY,
-            training_key=DEFAULT_SPANS_KEY,
-        ),
-        predicted_key: str = DEFAULT_PREDICTED_KEY,
-        training_key: str = DEFAULT_SPANS_KEY,
+        scorer: Optional[Callable] = span_finder_score,
+        spans_key: str = DEFAULT_SPANS_KEY,
     ) -> None:
         """Initialize the span boundary detector.
         model (thinc.api.Model): The Thinc Model powering the pipeline component.
@@ -194,8 +154,9 @@ class SpanFinder(TrainablePipe):
         threshold (float): Minimum probability to consider a prediction
             positive.
         scorer (Optional[Callable]): The scoring method.
-        predicted_key (str): Name of the span group the candidate spans are saved to
-        training_key (str): Name of the span group the training spans are read from
+        spans_key (str): Key of the doc.spans dict to save the spans under. During
+            initialization and training, the component will look for spans on the
+            reference document under the same key.
         max_length (Optional[int]): Max length of the produced spans, defaults to None meaning unlimited length.
         min_length (Optional[int]): Min length of the produced spans, defaults to None meaining shortest span is length 1.
         """
@@ -211,11 +172,11 @@ class SpanFinder(TrainablePipe):
         )
         self.min_length = min_length
         self.max_length = max_length
-        self.predicted_key = predicted_key
-        self.training_key = training_key
+        self.spans_key = spans_key
         self.model = model
         self.name = name
         self.scorer = scorer
+        self.cfg = {"spans_key": spans_key}
 
     def predict(self, docs: Iterable[Doc]):
         """Apply the pipeline's model to a batch of docs, without modifying them.
@@ -232,7 +193,7 @@ class SpanFinder(TrainablePipe):
         """
         offset = 0
         for i, doc in enumerate(docs):
-            doc.spans[self.predicted_key] = []
+            doc.spans[self.spans_key] = []
             starts = []
             ends = []
             doc_scores = scores[offset : offset + len(doc)]
@@ -249,7 +210,7 @@ class SpanFinder(TrainablePipe):
                 if span_length > self.max_length:
                     break
                 elif self.min_length <= span_length:
-                    doc.spans[self.predicted_key].append(doc[start : end + 1])
+                    doc.spans[self.spans_key].append(doc[start : end + 1])
 
     def update(
         self,
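
The hunk above only shows the tail of the span-construction loop in set_annotations. Purely as an illustration (not this file's actual implementation), the pairing logic amounts to something like the following, where starts and ends are token indices whose start/end probabilities cleared the threshold:

    # illustrative toy only: pair each candidate start with the following candidate
    # ends, keeping spans whose length falls within [min_length, max_length]
    def pair_boundaries(starts, ends, min_length, max_length):
        spans = []
        for start in starts:
            for end in ends:
                span_length = end + 1 - start
                if span_length < 1:
                    continue
                if span_length > max_length:
                    break
                if min_length <= span_length:
                    spans.append((start, end + 1))  # token slice bounds
        return spans
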
@@ -304,8 +265,8 @@ class SpanFinder(TrainablePipe):
             n_tokens = len(eg.predicted)
             truth = ops.xp.zeros((n_tokens, 2), dtype="float32")
             mask = ops.xp.ones((n_tokens, 2), dtype="float32")
-            if self.training_key in eg.reference.spans:
-                for span in eg.reference.spans[self.training_key]:
+            if self.spans_key in eg.reference.spans:
+                for span in eg.reference.spans[self.spans_key]:
                     ref_start_char, ref_end_char = _char_indices(span)
                     pred_span = eg.predicted.char_span(
                         ref_start_char, ref_end_char, alignment_mode="expand"
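
The truth construction above leans on Doc.char_span with alignment_mode="expand", which snaps character offsets from the reference span onto token boundaries of the predicted doc. A small standalone illustration (blank English pipeline assumed):

    import spacy

    nlp = spacy.blank("en")
    doc = nlp("Who is Shaka Khan?")
    # characters 8-16 start inside "Shaka" and end inside "Khan";
    # alignment_mode="expand" widens the span to whole tokens
    span = doc.char_span(8, 16, alignment_mode="expand")
    print(span.text)  # Shaka Khan
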
@@ -342,8 +303,8 @@ class SpanFinder(TrainablePipe):
         start_indices = set()
         end_indices = set()
 
-        if self.training_key in doc.spans:
-            for span in doc.spans[self.training_key]:
+        if self.spans_key in doc.spans:
+            for span in doc.spans[self.spans_key]:
                 start_indices.add(span.start)
                 end_indices.add(span.end - 1)
 
@@ -4,7 +4,7 @@ from thinc.types import Ragged
 
 from spacy.language import Language
 from spacy.lang.en import English
-from spacy.pipeline.span_finder import DEFAULT_PREDICTED_KEY, span_finder_default_config
+from spacy.pipeline.span_finder import span_finder_default_config
 from spacy.tokens import Doc
 from spacy.training import Example
 from spacy import util
@@ -12,22 +12,22 @@ from spacy.util import registry
 from spacy.util import fix_random_seed, make_tempdir
 
 
-TRAINING_KEY = "pytest"
+SPANS_KEY = "pytest"
 TRAIN_DATA = [
-    ("Who is Shaka Khan?", {"spans": {TRAINING_KEY: [(7, 17)]}}),
+    ("Who is Shaka Khan?", {"spans": {SPANS_KEY: [(7, 17)]}}),
     (
         "I like London and Berlin.",
-        {"spans": {TRAINING_KEY: [(7, 13, "LOC"), (18, 24)]}},
+        {"spans": {SPANS_KEY: [(7, 13), (18, 24)]}},
     ),
 ]
 
 TRAIN_DATA_OVERLAPPING = [
-    ("Who is Shaka Khan?", {"spans": {TRAINING_KEY: [(7, 17)]}}),
+    ("Who is Shaka Khan?", {"spans": {SPANS_KEY: [(7, 17)]}}),
     (
         "I like London and Berlin",
-        {"spans": {TRAINING_KEY: [(7, 13), (18, 24), (7, 24)]}},
+        {"spans": {SPANS_KEY: [(7, 13), (18, 24), (7, 24)]}},
     ),
-    ("", {"spans": {TRAINING_KEY: []}}),
+    ("", {"spans": {SPANS_KEY: []}}),
 ]
 
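
The make_examples helper is not part of this diff, but records like the ones above are presumably turned into Example objects along these lines (a hedged sketch; the character offsets 7-17 map "Shaka Khan" into the "pytest" span group):

    # hedged sketch of how one TRAIN_DATA record becomes an Example
    import spacy
    from spacy.training import Example

    nlp = spacy.blank("en")
    doc = nlp.make_doc("Who is Shaka Khan?")
    eg = Example.from_dict(doc, {"spans": {"pytest": [(7, 17)]}})
    print(eg.reference.spans["pytest"][0].text)  # Shaka Khan
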
@@ -88,8 +88,8 @@ def test_loss_alignment_example(tokens_predicted, tokens_reference, reference_tr
         nlp.vocab, words=tokens_reference, spaces=[False] * len(tokens_reference)
     )
     example = Example(predicted, reference)
-    example.reference.spans[TRAINING_KEY] = [example.reference.char_span(5, 9)]
-    span_finder = nlp.add_pipe("span_finder", config={"training_key": TRAINING_KEY})
+    example.reference.spans[SPANS_KEY] = [example.reference.char_span(5, 9)]
+    span_finder = nlp.add_pipe("span_finder", config={"spans_key": SPANS_KEY})
     nlp.initialize()
     ops = span_finder.model.ops
     if predicted.text != reference.text:
@@ -107,8 +107,8 @@ def test_span_finder_model():
     nlp = Language()
 
     docs = [nlp("This is an example."), nlp("This is the second example.")]
-    docs[0].spans[TRAINING_KEY] = [docs[0][3:4]]
-    docs[1].spans[TRAINING_KEY] = [docs[1][3:5]]
+    docs[0].spans[SPANS_KEY] = [docs[0][3:4]]
+    docs[1].spans[SPANS_KEY] = [docs[1][3:5]]
 
     total_tokens = 0
     for doc in docs:
@@ -128,15 +128,15 @@ def test_span_finder_component():
     nlp = Language()
 
     docs = [nlp("This is an example."), nlp("This is the second example.")]
-    docs[0].spans[TRAINING_KEY] = [docs[0][3:4]]
-    docs[1].spans[TRAINING_KEY] = [docs[1][3:5]]
+    docs[0].spans[SPANS_KEY] = [docs[0][3:4]]
+    docs[1].spans[SPANS_KEY] = [docs[1][3:5]]
 
-    span_finder = nlp.add_pipe("span_finder", config={"training_key": TRAINING_KEY})
+    span_finder = nlp.add_pipe("span_finder", config={"spans_key": SPANS_KEY})
     nlp.initialize()
     docs = list(span_finder.pipe(docs))
 
     # TODO: update hard-coded name
-    assert "span_candidates" in docs[0].spans
+    assert SPANS_KEY in docs[0].spans
 
 
 @pytest.mark.parametrize(
@@ -153,7 +153,7 @@ def test_set_annotations_span_lengths(min_length, max_length, span_count):
                 config={
                     "max_length": max_length,
                     "min_length": min_length,
-                    "training_key": TRAINING_KEY,
+                    "spans_key": SPANS_KEY,
                 },
             )
         return
@@ -162,7 +162,7 @@ def test_set_annotations_span_lengths(min_length, max_length, span_count):
         config={
             "max_length": max_length,
             "min_length": min_length,
-            "training_key": TRAINING_KEY,
+            "spans_key": SPANS_KEY,
         },
     )
     nlp.initialize()
@@ -182,8 +182,8 @@ def test_set_annotations_span_lengths(min_length, max_length, span_count):
     ]
     span_finder.set_annotations([doc], scores)
 
-    assert doc.spans[DEFAULT_PREDICTED_KEY]
-    assert len(doc.spans[DEFAULT_PREDICTED_KEY]) == span_count
+    assert doc.spans[SPANS_KEY]
+    assert len(doc.spans[SPANS_KEY]) == span_count
 
     # Assert below will fail when max_length is set to 0
     if max_length is None:
@@ -193,40 +193,39 @@ def test_set_annotations_span_lengths(min_length, max_length, span_count):
 
     assert all(
         min_length <= len(span) <= max_length
-        for span in doc.spans[DEFAULT_PREDICTED_KEY]
+        for span in doc.spans[SPANS_KEY]
     )
 
 
 def test_span_finder_suggester():
     nlp = Language()
     docs = [nlp("This is an example."), nlp("This is the second example.")]
-    docs[0].spans[TRAINING_KEY] = [docs[0][3:4]]
-    docs[1].spans[TRAINING_KEY] = [docs[1][3:5]]
-    span_finder = nlp.add_pipe("span_finder", config={"training_key": TRAINING_KEY})
+    docs[0].spans[SPANS_KEY] = [docs[0][3:4]]
+    docs[1].spans[SPANS_KEY] = [docs[1][3:5]]
+    span_finder = nlp.add_pipe("span_finder", config={"spans_key": SPANS_KEY})
     nlp.initialize()
     span_finder.set_annotations(docs, span_finder.predict(docs))
 
     suggester = registry.misc.get("spacy.span_finder_suggester.v1")(
-        candidates_key="span_candidates"
+        candidates_key=SPANS_KEY
     )
 
     candidates = suggester(docs)
 
     span_length = 0
     for doc in docs:
-        span_length += len(doc.spans["span_candidates"])
+        span_length += len(doc.spans[SPANS_KEY])
 
     assert span_length == len(candidates.dataXd)
     assert type(candidates) == Ragged
     assert len(candidates.dataXd[0]) == 2
 
 
-# XXX Fails because i think the suggester is not correctly implemented?
 def test_overfitting_IO():
     # Simple test to try and quickly overfit the spancat component - ensuring the ML models work correctly
     fix_random_seed(0)
     nlp = English()
-    span_finder = nlp.add_pipe("span_finder", config={"training_key": TRAINING_KEY})
+    span_finder = nlp.add_pipe("span_finder", config={"spans_key": SPANS_KEY})
     train_examples = make_examples(nlp)
     optimizer = nlp.initialize(get_examples=lambda: train_examples)
     assert span_finder.model.get_dim("nO") == 2
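
For context on the suggester test above: the factory registered as "spacy.span_finder_suggester.v1" returns a callable that yields a thinc Ragged of (start, end) token offsets, one block per doc. A hedged sketch of consuming that output:

    # hedged sketch: walk the Ragged suggester output block by block
    offset = 0
    for doc, n in zip(docs, candidates.lengths):
        rows = candidates.dataXd[offset : offset + n]  # each row is a (start, end) token offset pair
        print(doc.text, rows)
        offset += n
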
@@ -239,30 +238,27 @@ def test_overfitting_IO():
     # test the trained model
     test_text = "I like London and Berlin"
     doc = nlp(test_text)
-    spans = doc.spans[span_finder.predicted_key]
-    assert len(spans) == 2
-    assert len(spans.attrs["scores"]) == 2
-    assert min(spans.attrs["scores"]) > 0.9
-    assert set([span.text for span in spans]) == {"London", "Berlin"}
+    spans = doc.spans[span_finder.spans_key]
+    assert len(spans) == 3
+    assert set([span.text for span in spans]) == {"London", "Berlin", "London and Berlin"}
 
     # Also test the results are still the same after IO
     with make_tempdir() as tmp_dir:
         nlp.to_disk(tmp_dir)
         nlp2 = util.load_model_from_path(tmp_dir)
         doc2 = nlp2(test_text)
-        spans2 = doc2.spans[TRAINING_KEY]
-        assert len(spans2) == 2
-        assert len(spans2.attrs["scores"]) == 2
-        assert min(spans2.attrs["scores"]) > 0.9
-        assert set([span.text for span in spans2]) == {"London", "Berlin"}
+        spans2 = doc2.spans[span_finder.spans_key]
+        assert len(spans2) == 3
+        assert set([span.text for span in spans2]) == {"London", "Berlin", "London and Berlin"}
 
     # Test scoring
     scores = nlp.evaluate(train_examples)
-    assert f"spans_{TRAINING_KEY}_f" in scores
-    assert scores[f"spans_{TRAINING_KEY}_p"] == 1.0
-    assert scores[f"spans_{TRAINING_KEY}_r"] == 1.0
-    assert scores[f"spans_{TRAINING_KEY}_f"] == 1.0
+    sf = nlp.get_pipe("span_finder")
+    print(sf.spans_key)
+    assert f"span_finder_{span_finder.spans_key}_f" in scores
+    # XXX Its not perfect 1.0 F1 because we want it to overgenerate for now.
+    assert scores[f"span_finder_{span_finder.spans_key}_f"] == 0.4
 
     # also test that the spancat works for just a single entity in a sentence
     doc = nlp("London")
-    assert len(doc.spans[span_finder.predicted_key]) == 1
+    assert len(doc.spans[span_finder.spans_key]) == 1