Move settings to self.cfg, store min/max unset as None

This commit is contained in:
Adriane Boyd 2023-06-05 10:18:36 +02:00
parent ce4d33e726
commit dac12fb684
2 changed files with 32 additions and 40 deletions

View File

@ -117,20 +117,6 @@ def span_finder_score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
return scores return scores
class _MaxInt(int):
def __le__(self, other):
return False
def __lt__(self, other):
return False
def __ge__(self, other):
return True
def __gt__(self, other):
return True
def _char_indices(span: Span) -> Tuple[int, int]: def _char_indices(span: Span) -> Tuple[int, int]:
start = span[0].idx start = span[0].idx
end = span[-1].idx + len(span[-1]) end = span[-1].idx + len(span[-1])
@ -163,25 +149,24 @@ class SpanFinder(TrainablePipe):
initialization and training, the component will look for spans on the initialization and training, the component will look for spans on the
reference document under the same key. reference document under the same key.
max_length (Optional[int]): Max length of the produced spans, defaults to None meaning unlimited length. max_length (Optional[int]): Max length of the produced spans, defaults to None meaning unlimited length.
min_length (Optional[int]): Min length of the produced spans, defaults to None meaining shortest span is length 1. min_length (Optional[int]): Min length of the produced spans, defaults to None meaning shortest span is length 1.
""" """
self.vocab = nlp.vocab self.vocab = nlp.vocab
self.threshold = threshold if (max_length is not None and max_length < 1) or (
if max_length is None: min_length is not None and min_length < 1
max_length = _MaxInt() ):
if min_length is None:
min_length = 1
if max_length < 1 or min_length < 1:
raise ValueError( raise ValueError(
Errors.E1053.format(min_length=min_length, max_length=max_length) Errors.E1053.format(min_length=min_length, max_length=max_length)
) )
self.min_length = min_length
self.max_length = max_length
self.spans_key = spans_key
self.model = model self.model = model
self.name = name self.name = name
self.scorer = scorer self.scorer = scorer
self.cfg = {"spans_key": spans_key} self.cfg: Dict[str, Any] = {
"min_length": min_length,
"max_length": max_length,
"threshold": threshold,
"spans_key": spans_key,
}
def predict(self, docs: Iterable[Doc]): def predict(self, docs: Iterable[Doc]):
"""Apply the pipeline's model to a batch of docs, without modifying them. """Apply the pipeline's model to a batch of docs, without modifying them.
@ -198,24 +183,31 @@ class SpanFinder(TrainablePipe):
""" """
offset = 0 offset = 0
for i, doc in enumerate(docs): for i, doc in enumerate(docs):
doc.spans[self.spans_key] = [] doc.spans[self.cfg["spans_key"]] = []
starts = [] starts = []
ends = [] ends = []
doc_scores = scores[offset : offset + len(doc)] doc_scores = scores[offset : offset + len(doc)]
for token, token_score in zip(doc, doc_scores): for token, token_score in zip(doc, doc_scores):
if token_score[0] >= self.threshold: if token_score[0] >= self.cfg["threshold"]:
starts.append(token.i) starts.append(token.i)
if token_score[1] >= self.threshold: if token_score[1] >= self.cfg["threshold"]:
ends.append(token.i) ends.append(token.i)
print(self.cfg)
for start in starts: for start in starts:
for end in ends: for end in ends:
span_length = end + 1 - start span_length = end + 1 - start
if span_length > self.max_length: if span_length < 1:
break continue
elif self.min_length <= span_length: if (
doc.spans[self.spans_key].append(doc[start : end + 1]) self.cfg["min_length"] is None
or self.cfg["min_length"] <= span_length
) and (
self.cfg["max_length"] is None
or span_length <= self.cfg["max_length"]
):
doc.spans[self.cfg["spans_key"]].append(doc[start : end + 1])
offset += len(doc) offset += len(doc)
def update( def update(
@ -271,8 +263,8 @@ class SpanFinder(TrainablePipe):
n_tokens = len(eg.predicted) n_tokens = len(eg.predicted)
truth = ops.xp.zeros((n_tokens, 2), dtype="float32") truth = ops.xp.zeros((n_tokens, 2), dtype="float32")
mask = ops.xp.ones((n_tokens, 2), dtype="float32") mask = ops.xp.ones((n_tokens, 2), dtype="float32")
if self.spans_key in eg.reference.spans: if self.cfg["spans_key"] in eg.reference.spans:
for span in eg.reference.spans[self.spans_key]: for span in eg.reference.spans[self.cfg["spans_key"]]:
ref_start_char, ref_end_char = _char_indices(span) ref_start_char, ref_end_char = _char_indices(span)
pred_span = eg.predicted.char_span( pred_span = eg.predicted.char_span(
ref_start_char, ref_end_char, alignment_mode="expand" ref_start_char, ref_end_char, alignment_mode="expand"

View File

@ -209,7 +209,7 @@ def test_overfitting_IO():
# test the trained model # test the trained model
test_text = "I like London and Berlin" test_text = "I like London and Berlin"
doc = nlp(test_text) doc = nlp(test_text)
spans = doc.spans[span_finder.spans_key] spans = doc.spans[SPANS_KEY]
assert len(spans) == 3 assert len(spans) == 3
assert set([span.text for span in spans]) == { assert set([span.text for span in spans]) == {
"London", "London",
@ -222,7 +222,7 @@ def test_overfitting_IO():
nlp.to_disk(tmp_dir) nlp.to_disk(tmp_dir)
nlp2 = util.load_model_from_path(tmp_dir) nlp2 = util.load_model_from_path(tmp_dir)
doc2 = nlp2(test_text) doc2 = nlp2(test_text)
spans2 = doc2.spans[span_finder.spans_key] spans2 = doc2.spans[SPANS_KEY]
assert len(spans2) == 3 assert len(spans2) == 3
assert set([span.text for span in spans2]) == { assert set([span.text for span in spans2]) == {
"London", "London",
@ -232,11 +232,11 @@ def test_overfitting_IO():
# Test scoring # Test scoring
scores = nlp.evaluate(train_examples) scores = nlp.evaluate(train_examples)
assert f"span_finder_{span_finder.spans_key}_f" in scores assert f"span_finder_{SPANS_KEY}_f" in scores
# It's not perfect 1.0 F1 because it's designed to overgenerate for now. # It's not perfect 1.0 F1 because it's designed to overgenerate for now.
assert scores[f"span_finder_{span_finder.spans_key}_p"] == 0.75 assert scores[f"span_finder_{SPANS_KEY}_p"] == 0.75
assert scores[f"span_finder_{span_finder.spans_key}_r"] == 1.0 assert scores[f"span_finder_{SPANS_KEY}_r"] == 1.0
# also test that the spancat works for just a single entity in a sentence # also test that the spancat works for just a single entity in a sentence
doc = nlp("London") doc = nlp("London")
assert len(doc.spans[span_finder.spans_key]) == 1 assert len(doc.spans[SPANS_KEY]) == 1