mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-03 20:00:21 +03:00
Move settings to self.cfg, store min/max unset as None
This commit is contained in:
parent
ce4d33e726
commit
dac12fb684
|
@ -117,20 +117,6 @@ def span_finder_score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
|
|||
return scores
|
||||
|
||||
|
||||
class _MaxInt(int):
|
||||
def __le__(self, other):
|
||||
return False
|
||||
|
||||
def __lt__(self, other):
|
||||
return False
|
||||
|
||||
def __ge__(self, other):
|
||||
return True
|
||||
|
||||
def __gt__(self, other):
|
||||
return True
|
||||
|
||||
|
||||
def _char_indices(span: Span) -> Tuple[int, int]:
|
||||
start = span[0].idx
|
||||
end = span[-1].idx + len(span[-1])
|
||||
|
@ -163,25 +149,24 @@ class SpanFinder(TrainablePipe):
|
|||
initialization and training, the component will look for spans on the
|
||||
reference document under the same key.
|
||||
max_length (Optional[int]): Max length of the produced spans, defaults to None meaning unlimited length.
|
||||
min_length (Optional[int]): Min length of the produced spans, defaults to None meaining shortest span is length 1.
|
||||
min_length (Optional[int]): Min length of the produced spans, defaults to None meaning shortest span is length 1.
|
||||
"""
|
||||
self.vocab = nlp.vocab
|
||||
self.threshold = threshold
|
||||
if max_length is None:
|
||||
max_length = _MaxInt()
|
||||
if min_length is None:
|
||||
min_length = 1
|
||||
if max_length < 1 or min_length < 1:
|
||||
if (max_length is not None and max_length < 1) or (
|
||||
min_length is not None and min_length < 1
|
||||
):
|
||||
raise ValueError(
|
||||
Errors.E1053.format(min_length=min_length, max_length=max_length)
|
||||
)
|
||||
self.min_length = min_length
|
||||
self.max_length = max_length
|
||||
self.spans_key = spans_key
|
||||
self.model = model
|
||||
self.name = name
|
||||
self.scorer = scorer
|
||||
self.cfg = {"spans_key": spans_key}
|
||||
self.cfg: Dict[str, Any] = {
|
||||
"min_length": min_length,
|
||||
"max_length": max_length,
|
||||
"threshold": threshold,
|
||||
"spans_key": spans_key,
|
||||
}
|
||||
|
||||
def predict(self, docs: Iterable[Doc]):
|
||||
"""Apply the pipeline's model to a batch of docs, without modifying them.
|
||||
|
@ -198,24 +183,31 @@ class SpanFinder(TrainablePipe):
|
|||
"""
|
||||
offset = 0
|
||||
for i, doc in enumerate(docs):
|
||||
doc.spans[self.spans_key] = []
|
||||
doc.spans[self.cfg["spans_key"]] = []
|
||||
starts = []
|
||||
ends = []
|
||||
doc_scores = scores[offset : offset + len(doc)]
|
||||
|
||||
for token, token_score in zip(doc, doc_scores):
|
||||
if token_score[0] >= self.threshold:
|
||||
if token_score[0] >= self.cfg["threshold"]:
|
||||
starts.append(token.i)
|
||||
if token_score[1] >= self.threshold:
|
||||
if token_score[1] >= self.cfg["threshold"]:
|
||||
ends.append(token.i)
|
||||
|
||||
print(self.cfg)
|
||||
for start in starts:
|
||||
for end in ends:
|
||||
span_length = end + 1 - start
|
||||
if span_length > self.max_length:
|
||||
break
|
||||
elif self.min_length <= span_length:
|
||||
doc.spans[self.spans_key].append(doc[start : end + 1])
|
||||
if span_length < 1:
|
||||
continue
|
||||
if (
|
||||
self.cfg["min_length"] is None
|
||||
or self.cfg["min_length"] <= span_length
|
||||
) and (
|
||||
self.cfg["max_length"] is None
|
||||
or span_length <= self.cfg["max_length"]
|
||||
):
|
||||
doc.spans[self.cfg["spans_key"]].append(doc[start : end + 1])
|
||||
offset += len(doc)
|
||||
|
||||
def update(
|
||||
|
@ -271,8 +263,8 @@ class SpanFinder(TrainablePipe):
|
|||
n_tokens = len(eg.predicted)
|
||||
truth = ops.xp.zeros((n_tokens, 2), dtype="float32")
|
||||
mask = ops.xp.ones((n_tokens, 2), dtype="float32")
|
||||
if self.spans_key in eg.reference.spans:
|
||||
for span in eg.reference.spans[self.spans_key]:
|
||||
if self.cfg["spans_key"] in eg.reference.spans:
|
||||
for span in eg.reference.spans[self.cfg["spans_key"]]:
|
||||
ref_start_char, ref_end_char = _char_indices(span)
|
||||
pred_span = eg.predicted.char_span(
|
||||
ref_start_char, ref_end_char, alignment_mode="expand"
|
||||
|
|
|
@ -209,7 +209,7 @@ def test_overfitting_IO():
|
|||
# test the trained model
|
||||
test_text = "I like London and Berlin"
|
||||
doc = nlp(test_text)
|
||||
spans = doc.spans[span_finder.spans_key]
|
||||
spans = doc.spans[SPANS_KEY]
|
||||
assert len(spans) == 3
|
||||
assert set([span.text for span in spans]) == {
|
||||
"London",
|
||||
|
@ -222,7 +222,7 @@ def test_overfitting_IO():
|
|||
nlp.to_disk(tmp_dir)
|
||||
nlp2 = util.load_model_from_path(tmp_dir)
|
||||
doc2 = nlp2(test_text)
|
||||
spans2 = doc2.spans[span_finder.spans_key]
|
||||
spans2 = doc2.spans[SPANS_KEY]
|
||||
assert len(spans2) == 3
|
||||
assert set([span.text for span in spans2]) == {
|
||||
"London",
|
||||
|
@ -232,11 +232,11 @@ def test_overfitting_IO():
|
|||
|
||||
# Test scoring
|
||||
scores = nlp.evaluate(train_examples)
|
||||
assert f"span_finder_{span_finder.spans_key}_f" in scores
|
||||
assert f"span_finder_{SPANS_KEY}_f" in scores
|
||||
# It's not perfect 1.0 F1 because it's designed to overgenerate for now.
|
||||
assert scores[f"span_finder_{span_finder.spans_key}_p"] == 0.75
|
||||
assert scores[f"span_finder_{span_finder.spans_key}_r"] == 1.0
|
||||
assert scores[f"span_finder_{SPANS_KEY}_p"] == 0.75
|
||||
assert scores[f"span_finder_{SPANS_KEY}_r"] == 1.0
|
||||
|
||||
# also test that the spancat works for just a single entity in a sentence
|
||||
doc = nlp("London")
|
||||
assert len(doc.spans[span_finder.spans_key]) == 1
|
||||
assert len(doc.spans[SPANS_KEY]) == 1
|
||||
|
|
Loading…
Reference in New Issue
Block a user