Mirror of https://github.com/explosion/spaCy.git
avoid repetitive entities in the output

commit 0aa1083ce8
parent 0d81bce9cc
@@ -379,7 +379,7 @@ def ant_scorer_forward(
 
         scores = pw_prod + pw_sum + mask
 
-        top_scores, top_scores_idx = topk(xp, scores, ant_limit)
+        top_scores, top_scores_idx = topk(xp, scores, min(ant_limit, len(scores)))
         out.append((top_scores, top_scores_idx))
 
         # In the full model these scores can be further refined. In the current
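The clamp above keeps the requested k within the number of available candidates, so topk can no longer be asked for more antecedents than exist. A minimal sketch of the idea, assuming a numpy-style helper (topk_rows is illustrative, not spaCy's internal topk, and it clamps on the candidate axis rather than with len(scores)):

import numpy as np

def topk_rows(xp, scores, k):
    # Take the k best-scoring candidates per row of a 2D score matrix.
    k = min(k, scores.shape[1])  # same idea as min(ant_limit, len(scores))
    idx = xp.argsort(scores, axis=1)[:, ::-1][:, :k]  # indices, descending score
    vals = xp.take_along_axis(scores, idx, axis=1)
    return vals, idx

scores = np.random.rand(4, 3)  # only 3 candidates per mention
vals, idx = topk_rows(np, scores, k=50)  # request far more than exist
assert idx.shape == (4, 3)  # clamped: no out-of-range or duplicated picks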
@@ -109,16 +109,15 @@ def get_predicted_clusters(
 
 def get_sentence_map(doc: Doc):
     """For the given span, return a list of sentence indexes."""
-
-    if doc.is_sentenced:
+    try:
         si = 0
         out = []
         for sent in doc.sents:
-            for tok in sent:
+            for _ in sent:
                 out.append(si)
             si += 1
         return out
-    else:
+    except ValueError:
         # If there are no sents then just return dummy values.
         # Shouldn't happen in general training, but typical in init.
         return [0] * len(doc)
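Some context for the rewrite above: in spaCy v3, Doc.is_sentenced is deprecated, and iterating doc.sents raises a ValueError when no sentence boundaries are set, which is exactly what the new try/except branches on. A small demonstration of the helper's behavior (the sentencizer setup is illustrative):

import spacy

nlp = spacy.blank("en")
doc = nlp("One sentence here. Another one here.")
try:
    list(doc.sents)  # no sentence boundaries set yet
except ValueError:
    print("falls back to [0] * len(doc)")  # the except branch above

nlp.add_pipe("sentencizer")
doc = nlp("One sentence here. Another one here.")
si = 0
sent_map = []
for sent in doc.sents:
    for _ in sent:  # one entry per token, as in get_sentence_map
        sent_map.append(si)
    si += 1
print(sent_map)  # expected: [0, 0, 0, 0, 1, 1, 1, 1]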
@@ -198,8 +197,9 @@ def select_non_crossing_spans(
 
     # sort idxs by order in doc
     selected = sorted(selected, key=lambda idx: (starts[idx], ends[idx]))
-    while len(selected) < limit:
-        selected.append(selected[0])  # this seems a bit weird?
+    # This was causing many repetitive entities in the output - removed for now
+    # while len(selected) < limit:
+    #     selected.append(selected[0])  # this seems a bit weird?
     return selected
 
 
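Commenting out that loop is the fix the commit message refers to: padding the selection back up to limit by re-appending selected[0] put the same span index into the result over and over, which downstream showed up as one entity repeated many times. A toy illustration with made-up indices:

selected = [2, 5]  # suppose only two non-crossing spans were found
limit = 5

padded = list(selected)
while len(padded) < limit:
    padded.append(padded[0])  # the removed padding step

print(padded)  # [2, 5, 2, 2, 2] -> the span at index 2 appears four times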
@@ -1,4 +1,6 @@
 import pytest
 
+import spacy
+
 from spacy import util
 from spacy.training import Example
 from spacy.lang.en import English
@@ -50,8 +52,9 @@ def test_initialized(nlp):
     assert nlp.pipe_names == ["coref"]
     text = "She gave me her pen."
     doc = nlp(text)
-    # TODO: The results of this are weird & non-deterministic
-    print(doc.spans)
+    for k, v in doc.spans.items():
+        # Ensure there are no "She, She, She, She, She, ..." problems
+        assert len(v) <= 15
 
 
 def test_initialized_short(nlp):
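The tightened assertion works because doc.spans maps string keys to SpanGroup objects, one group per predicted cluster, so bounding each group's length catches "She, She, She, ..." style output. A self-contained sketch with a hand-built group (the key name is arbitrary):

import spacy

nlp = spacy.blank("en")
doc = nlp("She gave me her pen.")
doc.spans["demo_cluster"] = [doc[0:1], doc[3:4]]  # "She", "her"

for key, group in doc.spans.items():
    print(key, [span.text for span in group])
    assert len(group) <= 15  # the bound the updated test enforces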
@@ -73,6 +76,28 @@ def test_initialized_2(nlp):
     print(nlp(text).spans)
 
 
+def test_coref_serialization(nlp):
+    # Test that the coref component can be serialized
+    nlp.add_pipe("coref", last=True)
+    nlp.initialize()
+    assert nlp.pipe_names == ["coref"]
+    text = "She gave me her pen."
+    doc = nlp(text)
+    spans_result = doc.spans
+
+    with make_tempdir() as tmp_dir:
+        nlp.to_disk(tmp_dir)
+        nlp2 = spacy.load(tmp_dir)
+        assert nlp2.pipe_names == ["coref"]
+        doc2 = nlp2(text)
+        spans_result2 = doc2.spans
+        print(1, [(k, len(v)) for k, v in spans_result.items()])
+        print(2, [(k, len(v)) for k, v in spans_result2.items()])
+        for k, v in spans_result.items():
+            assert spans_result[k] == spans_result2[k]
+        # assert spans_result == spans_result2
+
+
 def test_overfitting_IO(nlp):
     # Simple test to try and quickly overfit the senter - ensuring the ML models work correctly
     train_examples = []
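The new test follows the usual spaCy round-trip pattern: save with nlp.to_disk, reload with spacy.load, rerun, and compare predictions. A standalone sketch of that pattern (the function name is illustrative; make_tempdir does live in spacy.util):

import spacy
from spacy.util import make_tempdir

def assert_spans_survive_roundtrip(nlp, text):
    # Predict, serialize, reload, predict again, then compare span groups.
    doc = nlp(text)
    with make_tempdir() as tmp_dir:
        nlp.to_disk(tmp_dir)
        nlp2 = spacy.load(tmp_dir)
    doc2 = nlp2(text)
    assert set(doc.spans.keys()) == set(doc2.spans.keys())
    for key in doc.spans:
        before = [(s.start, s.end) for s in doc.spans[key]]
        after = [(s.start, s.end) for s in doc2.spans[key]]
        assert before == after  # per-key comparison, as in the test above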
@@ -90,7 +115,7 @@ def test_overfitting_IO(nlp):
         nlp.update(train_examples, sgd=optimizer, losses=losses)
         doc = nlp(test_text)
         print(i, doc.spans)
     print(losses["coref"])  # < 0.001
 
     # test the trained model
     doc = nlp(test_text)
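For reference, the surrounding test uses spaCy's standard overfitting recipe: build Example objects from annotated dicts, then call nlp.update on the same small batch until the loss collapses. A generic sketch of the recipe (the annotation dict is left empty here; the real keys depend on the component being trained):

import spacy
from spacy.training import Example

nlp = spacy.blank("en")
doc = nlp.make_doc("She gave me her pen.")
# Real tests pass gold annotations in this dict.
train_examples = [Example.from_dict(doc, {})]

optimizer = nlp.initialize()
for i in range(5):
    losses = {}
    nlp.update(train_examples, sgd=optimizer, losses=losses)
    print(i, losses)  # stays empty here, since no trainable pipe was added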