mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 01:48:04 +03:00 
			
		
		
		
	raise error when setting overlapping entities as doc.ents (#2880)
This commit is contained in:
		
							parent
							
								
									071789467e
								
							
						
					
					
						commit
						57f274b693
					
				| 
						 | 
				
			
			@ -250,7 +250,9 @@ class Errors(object):
 | 
			
		|||
    E096 = ("Invalid object passed to displaCy: Can only visualize Doc or "
 | 
			
		||||
             "Span objects, or dicts if set to manual=True.")
 | 
			
		||||
    E097 = ("Can't merge non-disjoint spans. '{token}' is already part of tokens to merge")
 | 
			
		||||
 | 
			
		||||
    E098 = ("Trying to set conflicting doc.ents: '{span1}' and '{span2}'. A token"
 | 
			
		||||
            " can only be part of one entity, so make sure the entities you're "
 | 
			
		||||
            "setting don't overlap.")
 | 
			
		||||
 | 
			
		||||
@add_codes
 | 
			
		||||
class TempErrors(object):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -3,6 +3,7 @@ from __future__ import unicode_literals
 | 
			
		|||
 | 
			
		||||
from ...pipeline import EntityRecognizer
 | 
			
		||||
from ..util import get_doc
 | 
			
		||||
from ...tokens import Span
 | 
			
		||||
 | 
			
		||||
import pytest
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -22,3 +23,13 @@ def test_doc_add_entities_set_ents_iob(en_vocab):
 | 
			
		|||
 | 
			
		||||
    doc.ents = [(doc.vocab.strings['WORD'], 0, 2)]
 | 
			
		||||
    assert [w.ent_iob_ for w in doc] == ['B', 'I', '', '']
 | 
			
		||||
 | 
			
		||||
def test_add_overlapping_entities(en_vocab):
    """Assigning doc.ents with two spans that share a token raises ValueError."""
    words = ["Louisiana", "Office", "of", "Conservation"]
    doc = get_doc(en_vocab, words)

    # Start with a single entity spanning all four tokens.
    doc.ents = [Span(doc, 0, 4, label=391)]

    # This span covers token 0, which already belongs to the entity above,
    # so appending it to the existing entities must be rejected.
    overlapping = Span(doc, 0, 1, label=392)
    with pytest.raises(ValueError):
        doc.ents = list(doc.ents) + [overlapping]
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -17,9 +17,10 @@ def test_issue242(en_tokenizer):
 | 
			
		|||
    matcher.add('FOOD', None, *patterns)
 | 
			
		||||
 | 
			
		||||
    matches = [(ent_type, start, end) for ent_type, start, end in matcher(doc)]
 | 
			
		||||
    doc.ents += tuple(matches)
 | 
			
		||||
    match1, match2 = matches
 | 
			
		||||
    assert match1[1] == 3
 | 
			
		||||
    assert match1[2] == 5
 | 
			
		||||
    assert match2[1] == 4
 | 
			
		||||
    assert match2[2] == 6
 | 
			
		||||
 | 
			
		||||
    doc.ents += tuple([match2])
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -458,6 +458,21 @@ cdef class Doc:
 | 
			
		|||
            #    prediction
 | 
			
		||||
            # 3. Test basic data-driven ORTH gazetteer
 | 
			
		||||
            # 4. Test more nuanced date and currency regex
 | 
			
		||||
 | 
			
		||||
            tokens_in_ents = {}
 | 
			
		||||
            cdef attr_t entity_type
 | 
			
		||||
            cdef int ent_start, ent_end
 | 
			
		||||
            for ent_info in ents:
 | 
			
		||||
                entity_type, ent_start, ent_end = get_entity_info(ent_info)
 | 
			
		||||
                for token_index in range(ent_start, ent_end):
 | 
			
		||||
                    if token_index in tokens_in_ents.keys():
 | 
			
		||||
                        raise ValueError(Errors.E098.format(
 | 
			
		||||
                            span1=(tokens_in_ents[token_index][0],
 | 
			
		||||
                                   tokens_in_ents[token_index][1],
 | 
			
		||||
                                   self.vocab.strings[tokens_in_ents[token_index][2]]),
 | 
			
		||||
                            span2=(ent_start, ent_end, self.vocab.strings[entity_type])))
 | 
			
		||||
                    tokens_in_ents[token_index] = (ent_start, ent_end, entity_type)
 | 
			
		||||
 | 
			
		||||
            cdef int i
 | 
			
		||||
            for i in range(self.length):
 | 
			
		||||
                self.c[i].ent_type = 0
 | 
			
		||||
| 
						 | 
				
			
			@ -465,15 +480,7 @@ cdef class Doc:
 | 
			
		|||
            cdef attr_t ent_type
 | 
			
		||||
            cdef int start, end
 | 
			
		||||
            for ent_info in ents:
 | 
			
		||||
                if isinstance(ent_info, Span):
 | 
			
		||||
                    ent_id = ent_info.ent_id
 | 
			
		||||
                    ent_type = ent_info.label
 | 
			
		||||
                    start = ent_info.start
 | 
			
		||||
                    end = ent_info.end
 | 
			
		||||
                elif len(ent_info) == 3:
 | 
			
		||||
                    ent_type, start, end = ent_info
 | 
			
		||||
                else:
 | 
			
		||||
                    ent_id, ent_type, start, end = ent_info
 | 
			
		||||
                ent_type, start, end = get_entity_info(ent_info)
 | 
			
		||||
                if ent_type is None or ent_type < 0:
 | 
			
		||||
                    # Mark as O
 | 
			
		||||
                    for i in range(start, end):
 | 
			
		||||
| 
						 | 
				
			
			@ -1062,3 +1069,14 @@ def fix_attributes(doc, attributes):
 | 
			
		|||
            attributes[ENT_TYPE] = doc.vocab.strings[attributes['label']]
 | 
			
		||||
    if 'ent_type' in attributes:
 | 
			
		||||
        attributes[ENT_TYPE] = attributes['ent_type']
 | 
			
		||||
 | 
			
		||||
def get_entity_info(ent_info):
    """Reduce an entity description to an ``(ent_type, start, end)`` tuple.

    ent_info: Either a Span (its ``label``, ``start`` and ``end`` are used),
        a 3-item sequence ``(ent_type, start, end)``, or a 4-item sequence
        ``(ent_id, ent_type, start, end)`` whose leading ent_id is dropped.
    RETURNS (tuple): ``(ent_type, start, end)``.
    """
    # Spans are checked first, before any length-based dispatch.
    if isinstance(ent_info, Span):
        return ent_info.label, ent_info.start, ent_info.end
    if len(ent_info) == 3:
        ent_type, start, end = ent_info
        return ent_type, start, end
    # Anything else is unpacked as a 4-item sequence; the ID is not returned.
    _ent_id, ent_type, start, end = ent_info
    return ent_type, start, end
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user