Skip duplicate spans in Doc.retokenize (#4339)

This commit is contained in:
Ines Montani 2019-09-30 12:43:48 +02:00 committed by Matthew Honnibal
parent 71bd040834
commit f7d1736241
2 changed files with 16 additions and 0 deletions

View File

@ -414,3 +414,14 @@ def test_doc_retokenizer_merge_lex_attrs(en_vocab):
assert doc[1].is_stop assert doc[1].is_stop
assert not doc[0].is_stop assert not doc[0].is_stop
assert not doc[1].like_num assert not doc[1].like_num
def test_retokenize_skip_duplicates(en_vocab):
"""Test that the retokenizer automatically skips duplicate spans instead
of complaining about overlaps. See #3687."""
doc = Doc(en_vocab, words=["hello", "world", "!"])
with doc.retokenize() as retokenizer:
retokenizer.merge(doc[0:2])
retokenizer.merge(doc[0:2])
assert len(doc) == 2
assert doc[0].text == "hello world"

View File

@ -35,12 +35,14 @@ cdef class Retokenizer:
cdef list merges cdef list merges
cdef list splits cdef list splits
cdef set tokens_to_merge cdef set tokens_to_merge
cdef list _spans_to_merge
def __init__(self, doc): def __init__(self, doc):
self.doc = doc self.doc = doc
self.merges = [] self.merges = []
self.splits = [] self.splits = []
self.tokens_to_merge = set() self.tokens_to_merge = set()
self._spans_to_merge = [] # keep a record to filter out duplicates
def merge(self, Span span, attrs=SimpleFrozenDict()): def merge(self, Span span, attrs=SimpleFrozenDict()):
"""Mark a span for merging. The attrs will be applied to the resulting """Mark a span for merging. The attrs will be applied to the resulting
@ -51,10 +53,13 @@ cdef class Retokenizer:
DOCS: https://spacy.io/api/doc#retokenizer.merge DOCS: https://spacy.io/api/doc#retokenizer.merge
""" """
if (span.start, span.end) in self._spans_to_merge:
return
for token in span: for token in span:
if token.i in self.tokens_to_merge: if token.i in self.tokens_to_merge:
raise ValueError(Errors.E102.format(token=repr(token))) raise ValueError(Errors.E102.format(token=repr(token)))
self.tokens_to_merge.add(token.i) self.tokens_to_merge.add(token.i)
self._spans_to_merge.append((span.start, span.end))
if "_" in attrs: # Extension attributes if "_" in attrs: # Extension attributes
extensions = attrs["_"] extensions = attrs["_"]
_validate_extensions(extensions) _validate_extensions(extensions)