Use only a single spans_key for both training and prediction, like in spancat

This commit is contained in:
kadarakos 2023-06-01 10:19:22 +00:00
parent 90af16af76
commit 6f750d0da6
2 changed files with 69 additions and 112 deletions

View File

@ -1,4 +1,3 @@
from functools import partial
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, cast from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, cast
from thinc.api import Config, Model, Ops, Optimizer, get_current_ops, set_dropout_rate from thinc.api import Config, Model, Ops, Optimizer, get_current_ops, set_dropout_rate
@ -41,7 +40,6 @@ depth = 4
""" """
DEFAULT_SPAN_FINDER_MODEL = Config().from_str(span_finder_default_config)["model"] DEFAULT_SPAN_FINDER_MODEL = Config().from_str(span_finder_default_config)["model"]
DEFAULT_PREDICTED_KEY = "span_candidates"
@Language.factory( @Language.factory(
@ -50,21 +48,15 @@ DEFAULT_PREDICTED_KEY = "span_candidates"
default_config={ default_config={
"threshold": 0.5, "threshold": 0.5,
"model": DEFAULT_SPAN_FINDER_MODEL, "model": DEFAULT_SPAN_FINDER_MODEL,
"predicted_key": DEFAULT_PREDICTED_KEY, "spans_key": DEFAULT_SPANS_KEY,
"training_key": DEFAULT_SPANS_KEY,
# XXX Doesn't 0 seem bad compared to None instead?
"max_length": None, "max_length": None,
"min_length": None, "min_length": None,
"scorer": { "scorer": {"@scorers": "spacy.span_finder_scorer.v1"},
"@scorers": "spacy.span_finder_scorer.v1",
"predicted_key": DEFAULT_PREDICTED_KEY,
"training_key": DEFAULT_SPANS_KEY,
},
}, },
default_score_weights={ default_score_weights={
f"span_finder_{DEFAULT_PREDICTED_KEY}_f": 1.0, f"span_finder_{DEFAULT_SPANS_KEY}_f": 1.0,
f"span_finder_{DEFAULT_PREDICTED_KEY}_p": 0.0, f"span_finder_{DEFAULT_SPANS_KEY}_p": 0.0,
f"span_finder_{DEFAULT_PREDICTED_KEY}_r": 0.0, f"span_finder_{DEFAULT_SPANS_KEY}_r": 0.0,
}, },
) )
def make_span_finder( def make_span_finder(
@ -75,8 +67,7 @@ def make_span_finder(
threshold: float, threshold: float,
max_length: Optional[int], max_length: Optional[int],
min_length: Optional[int], min_length: Optional[int],
predicted_key: str = DEFAULT_PREDICTED_KEY, spans_key: str,
training_key: str = DEFAULT_SPANS_KEY,
) -> "SpanFinder": ) -> "SpanFinder":
"""Create a SpanFinder component. The component predicts whether a token is """Create a SpanFinder component. The component predicts whether a token is
the start or the end of a potential span. the start or the end of a potential span.
@ -84,10 +75,9 @@ def make_span_finder(
model (Model[List[Doc], Floats2d]): A model instance that model (Model[List[Doc], Floats2d]): A model instance that
is given a list of documents and predicts a probability for each token. is given a list of documents and predicts a probability for each token.
threshold (float): Minimum probability to consider a prediction positive. threshold (float): Minimum probability to consider a prediction positive.
predicted_key (str): Name of the span group the predicted spans are saved spans_key (str): Key of the doc.spans dict to save the spans under. During
to initialization and training, the component will look for spans on the
training_key (str): Name of the span group the training spans are read reference document under the same key.
from
max_length (Optional[int]): Max length of the produced spans, defaults to None meaning unlimited length. max_length (Optional[int]): Max length of the produced spans, defaults to None meaning unlimited length.
min_length (Optional[int]): Min length of the produced spans, defaults to None meaning shortest span is length 1. min_length (Optional[int]): Min length of the produced spans, defaults to None meaning shortest span is length 1.
""" """
@ -99,51 +89,26 @@ def make_span_finder(
scorer=scorer, scorer=scorer,
max_length=max_length, max_length=max_length,
min_length=min_length, min_length=min_length,
predicted_key=predicted_key, spans_key=spans_key,
training_key=training_key,
) )
@registry.scorers("spacy.span_finder_scorer.v1") @registry.scorers("spacy.span_finder_scorer.v1")
def make_span_finder_scorer( def make_span_finder_scorer():
predicted_key: str = DEFAULT_PREDICTED_KEY, return span_finder_score
training_key: str = DEFAULT_SPANS_KEY,
):
return partial(
span_finder_score, predicted_key=predicted_key, training_key=training_key
)
def span_finder_score( def span_finder_score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
examples: Iterable[Example],
*,
predicted_key: str = DEFAULT_PREDICTED_KEY,
training_key: str = DEFAULT_SPANS_KEY,
**kwargs,
) -> Dict[str, Any]:
kwargs = dict(kwargs) kwargs = dict(kwargs)
print(kwargs)
attr_prefix = "span_finder_" attr_prefix = "span_finder_"
kwargs.setdefault("attr", f"{attr_prefix}{predicted_key}") key = kwargs["spans_key"]
kwargs.setdefault("allow_overlap", True) kwargs.setdefault("attr", f"{attr_prefix}{key}")
kwargs.setdefault( kwargs.setdefault(
"getter", lambda doc, key: doc.spans.get(key[len(attr_prefix) :], []) "getter", lambda doc, key: doc.spans.get(key[len(attr_prefix) :], [])
) )
kwargs.setdefault("labeled", False) kwargs.setdefault("has_annotation", lambda doc: key in doc.spans)
kwargs.setdefault("has_annotation", lambda doc: predicted_key in doc.spans) return Scorer.score_spans(examples, **kwargs)
# score_spans can only score spans with the same key in both the reference
# and predicted docs, so temporarily copy the reference spans from the
# reference key to the candidates key in the reference docs, restoring the
# original span groups afterwards
orig_span_groups = []
for eg in examples:
orig_span_groups.append(eg.reference.spans.get(predicted_key))
if training_key in eg.reference.spans:
eg.reference.spans[predicted_key] = eg.reference.spans[training_key]
scores = Scorer.score_spans(examples, **kwargs)
for orig_span_group, eg in zip(orig_span_groups, examples):
if orig_span_group is not None:
eg.reference.spans[predicted_key] = orig_span_group
return scores
class _MaxInt(int): class _MaxInt(int):
@ -179,13 +144,8 @@ class SpanFinder(TrainablePipe):
max_length: Optional[int] = None, max_length: Optional[int] = None,
min_length: Optional[int] = None, min_length: Optional[int] = None,
# XXX I think this is weird and should be just None like in # XXX I think this is weird and should be just None like in
scorer: Optional[Callable] = partial( scorer: Optional[Callable] = span_finder_score,
span_finder_score, spans_key: str = DEFAULT_SPANS_KEY,
predicted_key=DEFAULT_PREDICTED_KEY,
training_key=DEFAULT_SPANS_KEY,
),
predicted_key: str = DEFAULT_PREDICTED_KEY,
training_key: str = DEFAULT_SPANS_KEY,
) -> None: ) -> None:
"""Initialize the span boundary detector. """Initialize the span boundary detector.
model (thinc.api.Model): The Thinc Model powering the pipeline component. model (thinc.api.Model): The Thinc Model powering the pipeline component.
@ -194,8 +154,9 @@ class SpanFinder(TrainablePipe):
threshold (float): Minimum probability to consider a prediction threshold (float): Minimum probability to consider a prediction
positive. positive.
scorer (Optional[Callable]): The scoring method. scorer (Optional[Callable]): The scoring method.
predicted_key (str): Name of the span group the candidate spans are saved to spans_key (str): Key of the doc.spans dict to save the spans under. During
training_key (str): Name of the span group the training spans are read from initialization and training, the component will look for spans on the
reference document under the same key.
max_length (Optional[int]): Max length of the produced spans, defaults to None meaning unlimited length. max_length (Optional[int]): Max length of the produced spans, defaults to None meaning unlimited length.
min_length (Optional[int]): Min length of the produced spans, defaults to None meaning shortest span is length 1. min_length (Optional[int]): Min length of the produced spans, defaults to None meaning shortest span is length 1.
""" """
@ -211,11 +172,11 @@ class SpanFinder(TrainablePipe):
) )
self.min_length = min_length self.min_length = min_length
self.max_length = max_length self.max_length = max_length
self.predicted_key = predicted_key self.spans_key = spans_key
self.training_key = training_key
self.model = model self.model = model
self.name = name self.name = name
self.scorer = scorer self.scorer = scorer
self.cfg = {"spans_key": spans_key}
def predict(self, docs: Iterable[Doc]): def predict(self, docs: Iterable[Doc]):
"""Apply the pipeline's model to a batch of docs, without modifying them. """Apply the pipeline's model to a batch of docs, without modifying them.
@ -232,7 +193,7 @@ class SpanFinder(TrainablePipe):
""" """
offset = 0 offset = 0
for i, doc in enumerate(docs): for i, doc in enumerate(docs):
doc.spans[self.predicted_key] = [] doc.spans[self.spans_key] = []
starts = [] starts = []
ends = [] ends = []
doc_scores = scores[offset : offset + len(doc)] doc_scores = scores[offset : offset + len(doc)]
@ -249,7 +210,7 @@ class SpanFinder(TrainablePipe):
if span_length > self.max_length: if span_length > self.max_length:
break break
elif self.min_length <= span_length: elif self.min_length <= span_length:
doc.spans[self.predicted_key].append(doc[start : end + 1]) doc.spans[self.spans_key].append(doc[start : end + 1])
def update( def update(
self, self,
@ -304,8 +265,8 @@ class SpanFinder(TrainablePipe):
n_tokens = len(eg.predicted) n_tokens = len(eg.predicted)
truth = ops.xp.zeros((n_tokens, 2), dtype="float32") truth = ops.xp.zeros((n_tokens, 2), dtype="float32")
mask = ops.xp.ones((n_tokens, 2), dtype="float32") mask = ops.xp.ones((n_tokens, 2), dtype="float32")
if self.training_key in eg.reference.spans: if self.spans_key in eg.reference.spans:
for span in eg.reference.spans[self.training_key]: for span in eg.reference.spans[self.spans_key]:
ref_start_char, ref_end_char = _char_indices(span) ref_start_char, ref_end_char = _char_indices(span)
pred_span = eg.predicted.char_span( pred_span = eg.predicted.char_span(
ref_start_char, ref_end_char, alignment_mode="expand" ref_start_char, ref_end_char, alignment_mode="expand"
@ -342,8 +303,8 @@ class SpanFinder(TrainablePipe):
start_indices = set() start_indices = set()
end_indices = set() end_indices = set()
if self.training_key in doc.spans: if self.spans_key in doc.spans:
for span in doc.spans[self.training_key]: for span in doc.spans[self.spans_key]:
start_indices.add(span.start) start_indices.add(span.start)
end_indices.add(span.end - 1) end_indices.add(span.end - 1)

View File

@ -4,7 +4,7 @@ from thinc.types import Ragged
from spacy.language import Language from spacy.language import Language
from spacy.lang.en import English from spacy.lang.en import English
from spacy.pipeline.span_finder import DEFAULT_PREDICTED_KEY, span_finder_default_config from spacy.pipeline.span_finder import span_finder_default_config
from spacy.tokens import Doc from spacy.tokens import Doc
from spacy.training import Example from spacy.training import Example
from spacy import util from spacy import util
@ -12,22 +12,22 @@ from spacy.util import registry
from spacy.util import fix_random_seed, make_tempdir from spacy.util import fix_random_seed, make_tempdir
TRAINING_KEY = "pytest" SPANS_KEY = "pytest"
TRAIN_DATA = [ TRAIN_DATA = [
("Who is Shaka Khan?", {"spans": {TRAINING_KEY: [(7, 17)]}}), ("Who is Shaka Khan?", {"spans": {SPANS_KEY: [(7, 17)]}}),
( (
"I like London and Berlin.", "I like London and Berlin.",
{"spans": {TRAINING_KEY: [(7, 13, "LOC"), (18, 24)]}}, {"spans": {SPANS_KEY: [(7, 13), (18, 24)]}},
), ),
] ]
TRAIN_DATA_OVERLAPPING = [ TRAIN_DATA_OVERLAPPING = [
("Who is Shaka Khan?", {"spans": {TRAINING_KEY: [(7, 17)]}}), ("Who is Shaka Khan?", {"spans": {SPANS_KEY: [(7, 17)]}}),
( (
"I like London and Berlin", "I like London and Berlin",
{"spans": {TRAINING_KEY: [(7, 13), (18, 24), (7, 24)]}}, {"spans": {SPANS_KEY: [(7, 13), (18, 24), (7, 24)]}},
), ),
("", {"spans": {TRAINING_KEY: []}}), ("", {"spans": {SPANS_KEY: []}}),
] ]
@ -88,8 +88,8 @@ def test_loss_alignment_example(tokens_predicted, tokens_reference, reference_tr
nlp.vocab, words=tokens_reference, spaces=[False] * len(tokens_reference) nlp.vocab, words=tokens_reference, spaces=[False] * len(tokens_reference)
) )
example = Example(predicted, reference) example = Example(predicted, reference)
example.reference.spans[TRAINING_KEY] = [example.reference.char_span(5, 9)] example.reference.spans[SPANS_KEY] = [example.reference.char_span(5, 9)]
span_finder = nlp.add_pipe("span_finder", config={"training_key": TRAINING_KEY}) span_finder = nlp.add_pipe("span_finder", config={"spans_key": SPANS_KEY})
nlp.initialize() nlp.initialize()
ops = span_finder.model.ops ops = span_finder.model.ops
if predicted.text != reference.text: if predicted.text != reference.text:
@ -107,8 +107,8 @@ def test_span_finder_model():
nlp = Language() nlp = Language()
docs = [nlp("This is an example."), nlp("This is the second example.")] docs = [nlp("This is an example."), nlp("This is the second example.")]
docs[0].spans[TRAINING_KEY] = [docs[0][3:4]] docs[0].spans[SPANS_KEY] = [docs[0][3:4]]
docs[1].spans[TRAINING_KEY] = [docs[1][3:5]] docs[1].spans[SPANS_KEY] = [docs[1][3:5]]
total_tokens = 0 total_tokens = 0
for doc in docs: for doc in docs:
@ -128,15 +128,15 @@ def test_span_finder_component():
nlp = Language() nlp = Language()
docs = [nlp("This is an example."), nlp("This is the second example.")] docs = [nlp("This is an example."), nlp("This is the second example.")]
docs[0].spans[TRAINING_KEY] = [docs[0][3:4]] docs[0].spans[SPANS_KEY] = [docs[0][3:4]]
docs[1].spans[TRAINING_KEY] = [docs[1][3:5]] docs[1].spans[SPANS_KEY] = [docs[1][3:5]]
span_finder = nlp.add_pipe("span_finder", config={"training_key": TRAINING_KEY}) span_finder = nlp.add_pipe("span_finder", config={"spans_key": SPANS_KEY})
nlp.initialize() nlp.initialize()
docs = list(span_finder.pipe(docs)) docs = list(span_finder.pipe(docs))
# TODO: update hard-coded name # TODO: update hard-coded name
assert "span_candidates" in docs[0].spans assert SPANS_KEY in docs[0].spans
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -153,7 +153,7 @@ def test_set_annotations_span_lengths(min_length, max_length, span_count):
config={ config={
"max_length": max_length, "max_length": max_length,
"min_length": min_length, "min_length": min_length,
"training_key": TRAINING_KEY, "spans_key": SPANS_KEY,
}, },
) )
return return
@ -162,7 +162,7 @@ def test_set_annotations_span_lengths(min_length, max_length, span_count):
config={ config={
"max_length": max_length, "max_length": max_length,
"min_length": min_length, "min_length": min_length,
"training_key": TRAINING_KEY, "spans_key": SPANS_KEY,
}, },
) )
nlp.initialize() nlp.initialize()
@ -182,8 +182,8 @@ def test_set_annotations_span_lengths(min_length, max_length, span_count):
] ]
span_finder.set_annotations([doc], scores) span_finder.set_annotations([doc], scores)
assert doc.spans[DEFAULT_PREDICTED_KEY] assert doc.spans[SPANS_KEY]
assert len(doc.spans[DEFAULT_PREDICTED_KEY]) == span_count assert len(doc.spans[SPANS_KEY]) == span_count
# Assert below will fail when max_length is set to 0 # Assert below will fail when max_length is set to 0
if max_length is None: if max_length is None:
@ -193,40 +193,39 @@ def test_set_annotations_span_lengths(min_length, max_length, span_count):
assert all( assert all(
min_length <= len(span) <= max_length min_length <= len(span) <= max_length
for span in doc.spans[DEFAULT_PREDICTED_KEY] for span in doc.spans[SPANS_KEY]
) )
def test_span_finder_suggester(): def test_span_finder_suggester():
nlp = Language() nlp = Language()
docs = [nlp("This is an example."), nlp("This is the second example.")] docs = [nlp("This is an example."), nlp("This is the second example.")]
docs[0].spans[TRAINING_KEY] = [docs[0][3:4]] docs[0].spans[SPANS_KEY] = [docs[0][3:4]]
docs[1].spans[TRAINING_KEY] = [docs[1][3:5]] docs[1].spans[SPANS_KEY] = [docs[1][3:5]]
span_finder = nlp.add_pipe("span_finder", config={"training_key": TRAINING_KEY}) span_finder = nlp.add_pipe("span_finder", config={"spans_key": SPANS_KEY})
nlp.initialize() nlp.initialize()
span_finder.set_annotations(docs, span_finder.predict(docs)) span_finder.set_annotations(docs, span_finder.predict(docs))
suggester = registry.misc.get("spacy.span_finder_suggester.v1")( suggester = registry.misc.get("spacy.span_finder_suggester.v1")(
candidates_key="span_candidates" candidates_key=SPANS_KEY
) )
candidates = suggester(docs) candidates = suggester(docs)
span_length = 0 span_length = 0
for doc in docs: for doc in docs:
span_length += len(doc.spans["span_candidates"]) span_length += len(doc.spans[SPANS_KEY])
assert span_length == len(candidates.dataXd) assert span_length == len(candidates.dataXd)
assert type(candidates) == Ragged assert type(candidates) == Ragged
assert len(candidates.dataXd[0]) == 2 assert len(candidates.dataXd[0]) == 2
# XXX Fails because i think the suggester is not correctly implemented?
def test_overfitting_IO(): def test_overfitting_IO():
# Simple test to try and quickly overfit the span_finder component - ensuring the ML models work correctly # Simple test to try and quickly overfit the span_finder component - ensuring the ML models work correctly
fix_random_seed(0) fix_random_seed(0)
nlp = English() nlp = English()
span_finder = nlp.add_pipe("span_finder", config={"training_key": TRAINING_KEY}) span_finder = nlp.add_pipe("span_finder", config={"spans_key": SPANS_KEY})
train_examples = make_examples(nlp) train_examples = make_examples(nlp)
optimizer = nlp.initialize(get_examples=lambda: train_examples) optimizer = nlp.initialize(get_examples=lambda: train_examples)
assert span_finder.model.get_dim("nO") == 2 assert span_finder.model.get_dim("nO") == 2
@ -239,30 +238,27 @@ def test_overfitting_IO():
# test the trained model # test the trained model
test_text = "I like London and Berlin" test_text = "I like London and Berlin"
doc = nlp(test_text) doc = nlp(test_text)
spans = doc.spans[span_finder.predicted_key] spans = doc.spans[span_finder.spans_key]
assert len(spans) == 2 assert len(spans) == 3
assert len(spans.attrs["scores"]) == 2 assert set([span.text for span in spans]) == {"London", "Berlin", "London and Berlin"}
assert min(spans.attrs["scores"]) > 0.9
assert set([span.text for span in spans]) == {"London", "Berlin"}
# Also test the results are still the same after IO # Also test the results are still the same after IO
with make_tempdir() as tmp_dir: with make_tempdir() as tmp_dir:
nlp.to_disk(tmp_dir) nlp.to_disk(tmp_dir)
nlp2 = util.load_model_from_path(tmp_dir) nlp2 = util.load_model_from_path(tmp_dir)
doc2 = nlp2(test_text) doc2 = nlp2(test_text)
spans2 = doc2.spans[TRAINING_KEY] spans2 = doc2.spans[span_finder.spans_key]
assert len(spans2) == 2 assert len(spans2) == 3
assert len(spans2.attrs["scores"]) == 2 assert set([span.text for span in spans2]) == {"London", "Berlin", "London and Berlin"}
assert min(spans2.attrs["scores"]) > 0.9
assert set([span.text for span in spans2]) == {"London", "Berlin"}
# Test scoring # Test scoring
scores = nlp.evaluate(train_examples) scores = nlp.evaluate(train_examples)
assert f"spans_{TRAINING_KEY}_f" in scores sf = nlp.get_pipe("span_finder")
assert scores[f"spans_{TRAINING_KEY}_p"] == 1.0 print(sf.spans_key)
assert scores[f"spans_{TRAINING_KEY}_r"] == 1.0 assert f"span_finder_{span_finder.spans_key}_f" in scores
assert scores[f"spans_{TRAINING_KEY}_f"] == 1.0 # XXX Its not perfect 1.0 F1 because we want it to overgenerate for now.
assert scores[f"span_finder_{span_finder.spans_key}_f"] == 0.4
# also test that the span_finder works for just a single entity in a sentence # also test that the span_finder works for just a single entity in a sentence
doc = nlp("London") doc = nlp("London")
assert len(doc.spans[span_finder.predicted_key]) == 1 assert len(doc.spans[span_finder.spans_key]) == 1