only use a single spans_key like in spancat

This commit is contained in:
kadarakos 2023-06-01 10:19:22 +00:00
parent 90af16af76
commit 6f750d0da6
2 changed files with 69 additions and 112 deletions

View File

@ -1,4 +1,3 @@
from functools import partial
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, cast
from thinc.api import Config, Model, Ops, Optimizer, get_current_ops, set_dropout_rate
@ -41,7 +40,6 @@ depth = 4
"""
DEFAULT_SPAN_FINDER_MODEL = Config().from_str(span_finder_default_config)["model"]
DEFAULT_PREDICTED_KEY = "span_candidates"
@Language.factory(
@ -50,21 +48,15 @@ DEFAULT_PREDICTED_KEY = "span_candidates"
default_config={
"threshold": 0.5,
"model": DEFAULT_SPAN_FINDER_MODEL,
"predicted_key": DEFAULT_PREDICTED_KEY,
"training_key": DEFAULT_SPANS_KEY,
# XXX Doesn't 0 seem bad compared to None instead?
"spans_key": DEFAULT_SPANS_KEY,
"max_length": None,
"min_length": None,
"scorer": {
"@scorers": "spacy.span_finder_scorer.v1",
"predicted_key": DEFAULT_PREDICTED_KEY,
"training_key": DEFAULT_SPANS_KEY,
},
"scorer": {"@scorers": "spacy.span_finder_scorer.v1"},
},
default_score_weights={
f"span_finder_{DEFAULT_PREDICTED_KEY}_f": 1.0,
f"span_finder_{DEFAULT_PREDICTED_KEY}_p": 0.0,
f"span_finder_{DEFAULT_PREDICTED_KEY}_r": 0.0,
f"span_finder_{DEFAULT_SPANS_KEY}_f": 1.0,
f"span_finder_{DEFAULT_SPANS_KEY}_p": 0.0,
f"span_finder_{DEFAULT_SPANS_KEY}_r": 0.0,
},
)
def make_span_finder(
@ -75,8 +67,7 @@ def make_span_finder(
threshold: float,
max_length: Optional[int],
min_length: Optional[int],
predicted_key: str = DEFAULT_PREDICTED_KEY,
training_key: str = DEFAULT_SPANS_KEY,
spans_key: str,
) -> "SpanFinder":
"""Create a SpanFinder component. The component predicts whether a token is
the start or the end of a potential span.
@ -84,10 +75,9 @@ def make_span_finder(
model (Model[List[Doc], Floats2d]): A model instance that
is given a list of documents and predicts a probability for each token.
threshold (float): Minimum probability to consider a prediction positive.
predicted_key (str): Name of the span group the predicted spans are saved
to
training_key (str): Name of the span group the training spans are read
from
spans_key (str): Key of the doc.spans dict to save the spans under. During
initialization and training, the component will look for spans on the
reference document under the same key.
max_length (Optional[int]): Max length of the produced spans, defaults to None meaning unlimited length.
min_length (Optional[int]): Min length of the produced spans, defaults to None meaining shortest span is length 1.
"""
@ -99,51 +89,26 @@ def make_span_finder(
scorer=scorer,
max_length=max_length,
min_length=min_length,
predicted_key=predicted_key,
training_key=training_key,
spans_key=spans_key,
)
@registry.scorers("spacy.span_finder_scorer.v1")
def make_span_finder_scorer(
predicted_key: str = DEFAULT_PREDICTED_KEY,
training_key: str = DEFAULT_SPANS_KEY,
):
return partial(
span_finder_score, predicted_key=predicted_key, training_key=training_key
)
def make_span_finder_scorer():
return span_finder_score
def span_finder_score(
examples: Iterable[Example],
*,
predicted_key: str = DEFAULT_PREDICTED_KEY,
training_key: str = DEFAULT_SPANS_KEY,
**kwargs,
) -> Dict[str, Any]:
def span_finder_score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
kwargs = dict(kwargs)
print(kwargs)
attr_prefix = "span_finder_"
kwargs.setdefault("attr", f"{attr_prefix}{predicted_key}")
kwargs.setdefault("allow_overlap", True)
key = kwargs["spans_key"]
kwargs.setdefault("attr", f"{attr_prefix}{key}")
kwargs.setdefault(
"getter", lambda doc, key: doc.spans.get(key[len(attr_prefix) :], [])
)
kwargs.setdefault("labeled", False)
kwargs.setdefault("has_annotation", lambda doc: predicted_key in doc.spans)
# score_spans can only score spans with the same key in both the reference
# and predicted docs, so temporarily copy the reference spans from the
# reference key to the candidates key in the reference docs, restoring the
# original span groups afterwards
orig_span_groups = []
for eg in examples:
orig_span_groups.append(eg.reference.spans.get(predicted_key))
if training_key in eg.reference.spans:
eg.reference.spans[predicted_key] = eg.reference.spans[training_key]
scores = Scorer.score_spans(examples, **kwargs)
for orig_span_group, eg in zip(orig_span_groups, examples):
if orig_span_group is not None:
eg.reference.spans[predicted_key] = orig_span_group
return scores
kwargs.setdefault("has_annotation", lambda doc: key in doc.spans)
return Scorer.score_spans(examples, **kwargs)
class _MaxInt(int):
@ -179,13 +144,8 @@ class SpanFinder(TrainablePipe):
max_length: Optional[int] = None,
min_length: Optional[int] = None,
# XXX I think this is weird and should be just None like in
scorer: Optional[Callable] = partial(
span_finder_score,
predicted_key=DEFAULT_PREDICTED_KEY,
training_key=DEFAULT_SPANS_KEY,
),
predicted_key: str = DEFAULT_PREDICTED_KEY,
training_key: str = DEFAULT_SPANS_KEY,
scorer: Optional[Callable] = span_finder_score,
spans_key: str = DEFAULT_SPANS_KEY,
) -> None:
"""Initialize the span boundary detector.
model (thinc.api.Model): The Thinc Model powering the pipeline component.
@ -194,8 +154,9 @@ class SpanFinder(TrainablePipe):
threshold (float): Minimum probability to consider a prediction
positive.
scorer (Optional[Callable]): The scoring method.
predicted_key (str): Name of the span group the candidate spans are saved to
training_key (str): Name of the span group the training spans are read from
spans_key (str): Key of the doc.spans dict to save the spans under. During
initialization and training, the component will look for spans on the
reference document under the same key.
max_length (Optional[int]): Max length of the produced spans, defaults to None meaning unlimited length.
min_length (Optional[int]): Min length of the produced spans, defaults to None meaining shortest span is length 1.
"""
@ -211,11 +172,11 @@ class SpanFinder(TrainablePipe):
)
self.min_length = min_length
self.max_length = max_length
self.predicted_key = predicted_key
self.training_key = training_key
self.spans_key = spans_key
self.model = model
self.name = name
self.scorer = scorer
self.cfg = {"spans_key": spans_key}
def predict(self, docs: Iterable[Doc]):
"""Apply the pipeline's model to a batch of docs, without modifying them.
@ -232,7 +193,7 @@ class SpanFinder(TrainablePipe):
"""
offset = 0
for i, doc in enumerate(docs):
doc.spans[self.predicted_key] = []
doc.spans[self.spans_key] = []
starts = []
ends = []
doc_scores = scores[offset : offset + len(doc)]
@ -249,7 +210,7 @@ class SpanFinder(TrainablePipe):
if span_length > self.max_length:
break
elif self.min_length <= span_length:
doc.spans[self.predicted_key].append(doc[start : end + 1])
doc.spans[self.spans_key].append(doc[start : end + 1])
def update(
self,
@ -304,8 +265,8 @@ class SpanFinder(TrainablePipe):
n_tokens = len(eg.predicted)
truth = ops.xp.zeros((n_tokens, 2), dtype="float32")
mask = ops.xp.ones((n_tokens, 2), dtype="float32")
if self.training_key in eg.reference.spans:
for span in eg.reference.spans[self.training_key]:
if self.spans_key in eg.reference.spans:
for span in eg.reference.spans[self.spans_key]:
ref_start_char, ref_end_char = _char_indices(span)
pred_span = eg.predicted.char_span(
ref_start_char, ref_end_char, alignment_mode="expand"
@ -342,8 +303,8 @@ class SpanFinder(TrainablePipe):
start_indices = set()
end_indices = set()
if self.training_key in doc.spans:
for span in doc.spans[self.training_key]:
if self.spans_key in doc.spans:
for span in doc.spans[self.spans_key]:
start_indices.add(span.start)
end_indices.add(span.end - 1)

View File

@ -4,7 +4,7 @@ from thinc.types import Ragged
from spacy.language import Language
from spacy.lang.en import English
from spacy.pipeline.span_finder import DEFAULT_PREDICTED_KEY, span_finder_default_config
from spacy.pipeline.span_finder import span_finder_default_config
from spacy.tokens import Doc
from spacy.training import Example
from spacy import util
@ -12,22 +12,22 @@ from spacy.util import registry
from spacy.util import fix_random_seed, make_tempdir
TRAINING_KEY = "pytest"
SPANS_KEY = "pytest"
TRAIN_DATA = [
("Who is Shaka Khan?", {"spans": {TRAINING_KEY: [(7, 17)]}}),
("Who is Shaka Khan?", {"spans": {SPANS_KEY: [(7, 17)]}}),
(
"I like London and Berlin.",
{"spans": {TRAINING_KEY: [(7, 13, "LOC"), (18, 24)]}},
{"spans": {SPANS_KEY: [(7, 13), (18, 24)]}},
),
]
TRAIN_DATA_OVERLAPPING = [
("Who is Shaka Khan?", {"spans": {TRAINING_KEY: [(7, 17)]}}),
("Who is Shaka Khan?", {"spans": {SPANS_KEY: [(7, 17)]}}),
(
"I like London and Berlin",
{"spans": {TRAINING_KEY: [(7, 13), (18, 24), (7, 24)]}},
{"spans": {SPANS_KEY: [(7, 13), (18, 24), (7, 24)]}},
),
("", {"spans": {TRAINING_KEY: []}}),
("", {"spans": {SPANS_KEY: []}}),
]
@ -88,8 +88,8 @@ def test_loss_alignment_example(tokens_predicted, tokens_reference, reference_tr
nlp.vocab, words=tokens_reference, spaces=[False] * len(tokens_reference)
)
example = Example(predicted, reference)
example.reference.spans[TRAINING_KEY] = [example.reference.char_span(5, 9)]
span_finder = nlp.add_pipe("span_finder", config={"training_key": TRAINING_KEY})
example.reference.spans[SPANS_KEY] = [example.reference.char_span(5, 9)]
span_finder = nlp.add_pipe("span_finder", config={"spans_key": SPANS_KEY})
nlp.initialize()
ops = span_finder.model.ops
if predicted.text != reference.text:
@ -107,8 +107,8 @@ def test_span_finder_model():
nlp = Language()
docs = [nlp("This is an example."), nlp("This is the second example.")]
docs[0].spans[TRAINING_KEY] = [docs[0][3:4]]
docs[1].spans[TRAINING_KEY] = [docs[1][3:5]]
docs[0].spans[SPANS_KEY] = [docs[0][3:4]]
docs[1].spans[SPANS_KEY] = [docs[1][3:5]]
total_tokens = 0
for doc in docs:
@ -128,15 +128,15 @@ def test_span_finder_component():
nlp = Language()
docs = [nlp("This is an example."), nlp("This is the second example.")]
docs[0].spans[TRAINING_KEY] = [docs[0][3:4]]
docs[1].spans[TRAINING_KEY] = [docs[1][3:5]]
docs[0].spans[SPANS_KEY] = [docs[0][3:4]]
docs[1].spans[SPANS_KEY] = [docs[1][3:5]]
span_finder = nlp.add_pipe("span_finder", config={"training_key": TRAINING_KEY})
span_finder = nlp.add_pipe("span_finder", config={"spans_key": SPANS_KEY})
nlp.initialize()
docs = list(span_finder.pipe(docs))
# TODO: update hard-coded name
assert "span_candidates" in docs[0].spans
assert SPANS_KEY in docs[0].spans
@pytest.mark.parametrize(
@ -153,7 +153,7 @@ def test_set_annotations_span_lengths(min_length, max_length, span_count):
config={
"max_length": max_length,
"min_length": min_length,
"training_key": TRAINING_KEY,
"spans_key": SPANS_KEY,
},
)
return
@ -162,7 +162,7 @@ def test_set_annotations_span_lengths(min_length, max_length, span_count):
config={
"max_length": max_length,
"min_length": min_length,
"training_key": TRAINING_KEY,
"spans_key": SPANS_KEY,
},
)
nlp.initialize()
@ -182,8 +182,8 @@ def test_set_annotations_span_lengths(min_length, max_length, span_count):
]
span_finder.set_annotations([doc], scores)
assert doc.spans[DEFAULT_PREDICTED_KEY]
assert len(doc.spans[DEFAULT_PREDICTED_KEY]) == span_count
assert doc.spans[SPANS_KEY]
assert len(doc.spans[SPANS_KEY]) == span_count
# Assert below will fail when max_length is set to 0
if max_length is None:
@ -193,40 +193,39 @@ def test_set_annotations_span_lengths(min_length, max_length, span_count):
assert all(
min_length <= len(span) <= max_length
for span in doc.spans[DEFAULT_PREDICTED_KEY]
for span in doc.spans[SPANS_KEY]
)
def test_span_finder_suggester():
nlp = Language()
docs = [nlp("This is an example."), nlp("This is the second example.")]
docs[0].spans[TRAINING_KEY] = [docs[0][3:4]]
docs[1].spans[TRAINING_KEY] = [docs[1][3:5]]
span_finder = nlp.add_pipe("span_finder", config={"training_key": TRAINING_KEY})
docs[0].spans[SPANS_KEY] = [docs[0][3:4]]
docs[1].spans[SPANS_KEY] = [docs[1][3:5]]
span_finder = nlp.add_pipe("span_finder", config={"spans_key": SPANS_KEY})
nlp.initialize()
span_finder.set_annotations(docs, span_finder.predict(docs))
suggester = registry.misc.get("spacy.span_finder_suggester.v1")(
candidates_key="span_candidates"
candidates_key=SPANS_KEY
)
candidates = suggester(docs)
span_length = 0
for doc in docs:
span_length += len(doc.spans["span_candidates"])
span_length += len(doc.spans[SPANS_KEY])
assert span_length == len(candidates.dataXd)
assert type(candidates) == Ragged
assert len(candidates.dataXd[0]) == 2
# XXX Fails because i think the suggester is not correctly implemented?
def test_overfitting_IO():
# Simple test to try and quickly overfit the spancat component - ensuring the ML models work correctly
fix_random_seed(0)
nlp = English()
span_finder = nlp.add_pipe("span_finder", config={"training_key": TRAINING_KEY})
span_finder = nlp.add_pipe("span_finder", config={"spans_key": SPANS_KEY})
train_examples = make_examples(nlp)
optimizer = nlp.initialize(get_examples=lambda: train_examples)
assert span_finder.model.get_dim("nO") == 2
@ -239,30 +238,27 @@ def test_overfitting_IO():
# test the trained model
test_text = "I like London and Berlin"
doc = nlp(test_text)
spans = doc.spans[span_finder.predicted_key]
assert len(spans) == 2
assert len(spans.attrs["scores"]) == 2
assert min(spans.attrs["scores"]) > 0.9
assert set([span.text for span in spans]) == {"London", "Berlin"}
spans = doc.spans[span_finder.spans_key]
assert len(spans) == 3
assert set([span.text for span in spans]) == {"London", "Berlin", "London and Berlin"}
# Also test the results are still the same after IO
with make_tempdir() as tmp_dir:
nlp.to_disk(tmp_dir)
nlp2 = util.load_model_from_path(tmp_dir)
doc2 = nlp2(test_text)
spans2 = doc2.spans[TRAINING_KEY]
assert len(spans2) == 2
assert len(spans2.attrs["scores"]) == 2
assert min(spans2.attrs["scores"]) > 0.9
assert set([span.text for span in spans2]) == {"London", "Berlin"}
spans2 = doc2.spans[span_finder.spans_key]
assert len(spans2) == 3
assert set([span.text for span in spans2]) == {"London", "Berlin", "London and Berlin"}
# Test scoring
scores = nlp.evaluate(train_examples)
assert f"spans_{TRAINING_KEY}_f" in scores
assert scores[f"spans_{TRAINING_KEY}_p"] == 1.0
assert scores[f"spans_{TRAINING_KEY}_r"] == 1.0
assert scores[f"spans_{TRAINING_KEY}_f"] == 1.0
sf = nlp.get_pipe("span_finder")
print(sf.spans_key)
assert f"span_finder_{span_finder.spans_key}_f" in scores
# XXX Its not perfect 1.0 F1 because we want it to overgenerate for now.
assert scores[f"span_finder_{span_finder.spans_key}_f"] == 0.4
# also test that the spancat works for just a single entity in a sentence
doc = nlp("London")
assert len(doc.spans[span_finder.predicted_key]) == 1
assert len(doc.spans[span_finder.spans_key]) == 1