Move settings to self.cfg, store min/max unset as None

Adriane Boyd 2023-06-05 10:18:36 +02:00
parent ce4d33e726
commit dac12fb684
2 changed files with 32 additions and 40 deletions
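In short, the change moves the component's settings into a single self.cfg dict and keeps an unset min_length/max_length as None instead of swapping in a sentinel value. A minimal sketch of that pattern (a stand-in class for illustration, not the actual SpanFinder; the error message is made up):

from typing import Any, Dict, Optional


class SketchPipe:
    """Illustrative stand-in for a pipe that keeps all settings in self.cfg."""

    def __init__(
        self,
        threshold: float = 0.5,
        spans_key: str = "sc",
        min_length: Optional[int] = None,
        max_length: Optional[int] = None,
    ):
        # Validate without converting None into a sentinel int.
        if (max_length is not None and max_length < 1) or (
            min_length is not None and min_length < 1
        ):
            raise ValueError("min_length and max_length must be >= 1 or None")
        # None stays None, so the stored config mirrors what the caller passed.
        self.cfg: Dict[str, Any] = {
            "threshold": threshold,
            "spans_key": spans_key,
            "min_length": min_length,
            "max_length": max_length,
        }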

View File

@ -117,20 +117,6 @@ def span_finder_score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
     return scores


-class _MaxInt(int):
-    def __le__(self, other):
-        return False
-
-    def __lt__(self, other):
-        return False
-
-    def __ge__(self, other):
-        return True
-
-    def __gt__(self, other):
-        return True
-
-
 def _char_indices(span: Span) -> Tuple[int, int]:
     start = span[0].idx
     end = span[-1].idx + len(span[-1])
@ -163,25 +149,24 @@ class SpanFinder(TrainablePipe):
         initialization and training, the component will look for spans on the
         reference document under the same key.
         max_length (Optional[int]): Max length of the produced spans, defaults to None meaning unlimited length.
-        min_length (Optional[int]): Min length of the produced spans, defaults to None meaining shortest span is length 1.
+        min_length (Optional[int]): Min length of the produced spans, defaults to None meaning shortest span is length 1.
         """
         self.vocab = nlp.vocab
-        self.threshold = threshold
-        if max_length is None:
-            max_length = _MaxInt()
-        if min_length is None:
-            min_length = 1
-        if max_length < 1 or min_length < 1:
+        if (max_length is not None and max_length < 1) or (
+            min_length is not None and min_length < 1
+        ):
             raise ValueError(
                 Errors.E1053.format(min_length=min_length, max_length=max_length)
             )
-        self.min_length = min_length
-        self.max_length = max_length
-        self.spans_key = spans_key
         self.model = model
         self.name = name
         self.scorer = scorer
-        self.cfg = {"spans_key": spans_key}
+        self.cfg: Dict[str, Any] = {
+            "min_length": min_length,
+            "max_length": max_length,
+            "threshold": threshold,
+            "spans_key": spans_key,
+        }

     def predict(self, docs: Iterable[Doc]):
         """Apply the pipeline's model to a batch of docs, without modifying them.
@ -198,24 +183,31 @@ class SpanFinder(TrainablePipe):
"""
offset = 0
for i, doc in enumerate(docs):
doc.spans[self.spans_key] = []
doc.spans[self.cfg["spans_key"]] = []
starts = []
ends = []
doc_scores = scores[offset : offset + len(doc)]
for token, token_score in zip(doc, doc_scores):
if token_score[0] >= self.threshold:
if token_score[0] >= self.cfg["threshold"]:
starts.append(token.i)
if token_score[1] >= self.threshold:
if token_score[1] >= self.cfg["threshold"]:
ends.append(token.i)
print(self.cfg)
for start in starts:
for end in ends:
span_length = end + 1 - start
if span_length > self.max_length:
break
elif self.min_length <= span_length:
doc.spans[self.spans_key].append(doc[start : end + 1])
if span_length < 1:
continue
if (
self.cfg["min_length"] is None
or self.cfg["min_length"] <= span_length
) and (
self.cfg["max_length"] is None
or span_length <= self.cfg["max_length"]
):
doc.spans[self.cfg["spans_key"]].append(doc[start : end + 1])
offset += len(doc)
def update(
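For clarity, the new length check above can be read as a standalone predicate (an illustrative re-implementation, not the component's actual code path): a bound set to None is simply not enforced.

from typing import Optional


def keep_span(
    span_length: int, min_length: Optional[int], max_length: Optional[int]
) -> bool:
    """Return True if a candidate span of this length passes both optional bounds."""
    if span_length < 1:
        return False
    return (min_length is None or min_length <= span_length) and (
        max_length is None or span_length <= max_length
    )


# With no minimum and max_length=3, a 2-token span is kept, a 4-token span is not.
assert keep_span(2, None, 3)
assert not keep_span(4, None, 3)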
@ -271,8 +263,8 @@ class SpanFinder(TrainablePipe):
             n_tokens = len(eg.predicted)
             truth = ops.xp.zeros((n_tokens, 2), dtype="float32")
             mask = ops.xp.ones((n_tokens, 2), dtype="float32")
-            if self.spans_key in eg.reference.spans:
-                for span in eg.reference.spans[self.spans_key]:
+            if self.cfg["spans_key"] in eg.reference.spans:
+                for span in eg.reference.spans[self.cfg["spans_key"]]:
                     ref_start_char, ref_end_char = _char_indices(span)
                     pred_span = eg.predicted.char_span(
                         ref_start_char, ref_end_char, alignment_mode="expand"

View File

@ -209,7 +209,7 @@ def test_overfitting_IO():
     # test the trained model
     test_text = "I like London and Berlin"
     doc = nlp(test_text)
-    spans = doc.spans[span_finder.spans_key]
+    spans = doc.spans[SPANS_KEY]
     assert len(spans) == 3
     assert set([span.text for span in spans]) == {
         "London",
@ -222,7 +222,7 @@ def test_overfitting_IO():
         nlp.to_disk(tmp_dir)
         nlp2 = util.load_model_from_path(tmp_dir)
         doc2 = nlp2(test_text)
-        spans2 = doc2.spans[span_finder.spans_key]
+        spans2 = doc2.spans[SPANS_KEY]
         assert len(spans2) == 3
         assert set([span.text for span in spans2]) == {
             "London",
@ -232,11 +232,11 @@ def test_overfitting_IO():
     # Test scoring
     scores = nlp.evaluate(train_examples)
-    assert f"span_finder_{span_finder.spans_key}_f" in scores
+    assert f"span_finder_{SPANS_KEY}_f" in scores
     # It's not perfect 1.0 F1 because it's designed to overgenerate for now.
-    assert scores[f"span_finder_{span_finder.spans_key}_p"] == 0.75
-    assert scores[f"span_finder_{span_finder.spans_key}_r"] == 1.0
+    assert scores[f"span_finder_{SPANS_KEY}_p"] == 0.75
+    assert scores[f"span_finder_{SPANS_KEY}_r"] == 1.0
     # also test that the spancat works for just a single entity in a sentence
     doc = nlp("London")
-    assert len(doc.spans[span_finder.spans_key]) == 1
+    assert len(doc.spans[SPANS_KEY]) == 1