Refactor seen token detection

This commit is contained in:
Adriane Boyd 2020-09-22 14:42:51 +02:00
parent 535842e483
commit b1a7d6c528
2 changed files with 4 additions and 24 deletions

View File

@ -690,8 +690,8 @@ class Errors:
"in more than one span in entities, blocked, missing or outside.") "in more than one span in entities, blocked, missing or outside.")
E1011 = ("Unsupported default '{default}' in doc.set_ents. Available " E1011 = ("Unsupported default '{default}' in doc.set_ents. Available "
"options: {modes}") "options: {modes}")
E1012 = ("Spans provided to doc.set_ents must be provided as a list of " E1012 = ("Entity spans and blocked/missing/outside spans should be "
"`Span` objects.") "provided to doc.set_ents as lists of `Span` objects.")
E1013 = ("Unable to set entity for span with empty label. Entity spans are " E1013 = ("Unable to set entity for span with empty label. Entity spans are "
"required to have a label. To set entity information as missing " "required to have a label. To set entity information as missing "
"or blocked, use the keyword arguments with doc.set_ents.") "or blocked, use the keyword arguments with doc.set_ents.")

View File

@ -8,6 +8,7 @@ from libc.stdint cimport int32_t, uint64_t
import copy import copy
from collections import Counter from collections import Counter
from enum import Enum from enum import Enum
import itertools
import numpy import numpy
import srsly import srsly
from thinc.api import get_array_module from thinc.api import get_array_module
@ -742,28 +743,7 @@ cdef class Doc:
# Find all tokens covered by spans and check that none are overlapping # Find all tokens covered by spans and check that none are overlapping
seen_tokens = set() seen_tokens = set()
for span in entities: for span in itertools.chain.from_iterable([entities, blocked, missing, outside]):
if not isinstance(span, Span):
raise ValueError(Errors.E1012.format(span=span))
for i in range(span.start, span.end):
if i in seen_tokens:
raise ValueError(Errors.E1010.format(i=i))
seen_tokens.add(i)
for span in blocked:
if not isinstance(span, Span):
raise ValueError(Errors.E1012.format(span=span))
for i in range(span.start, span.end):
if i in seen_tokens:
raise ValueError(Errors.E1010.format(i=i))
seen_tokens.add(i)
for span in missing:
if not isinstance(span, Span):
raise ValueError(Errors.E1012.format(span=span))
for i in range(span.start, span.end):
if i in seen_tokens:
raise ValueError(Errors.E1010.format(i=i))
seen_tokens.add(i)
for span in outside:
if not isinstance(span, Span): if not isinstance(span, Span):
raise ValueError(Errors.E1012.format(span=span)) raise ValueError(Errors.E1012.format(span=span))
for i in range(span.start, span.end): for i in range(span.start, span.end):