Add error message and unit tests for setting kb_id on a Span

This commit is contained in:
svlandeg 2019-03-22 12:05:35 +01:00
parent 3c9ac59ea0
commit 5b1cd49222
3 changed files with 24 additions and 4 deletions

View File

@ -371,6 +371,11 @@ class Errors(object):
"with spacy >= 2.1.0. To fix this, reinstall Python and use a wide " "with spacy >= 2.1.0. To fix this, reinstall Python and use a wide "
"unicode build instead. You can also rebuild Python and set the " "unicode build instead. You can also rebuild Python and set the "
"--enable-unicode=ucs4 flag.") "--enable-unicode=ucs4 flag.")
# E131: message for a rejected write to the kb_id of an existing Span.
# Format fields: {start}, {end}, {label} and {kb_id}; they are interpolated
# into the example `Span(...)` call shown at the end of the message.
E131 = ("Cannot write the kb_id of an existing Span object because a Span "
"is a read-only view of the underlying Token objects stored in the Doc. "
"Instead, create a new Span object and specify the `kb_id` keyword argument, "
"for example:\nfrom spacy.tokens import Span\n"
"span = Span(doc, start={start}, end={end}, label='{label}', kb_id='{kb_id}')")
@add_codes @add_codes

View File

@ -172,10 +172,12 @@ def test_span_as_doc(doc):
assert span_doc[0].idx == 0 assert span_doc[0].idx == 0
def test_span_string_label(doc): def test_span_string_label_kb_id(doc):
span = Span(doc, 0, 1, label="hello") span = Span(doc, 0, 1, label="hello", kb_id="Q342")
assert span.label_ == "hello" assert span.label_ == "hello"
assert span.label == doc.vocab.strings["hello"] assert span.label == doc.vocab.strings["hello"]
assert span.kb_id_ == "Q342"
assert span.kb_id == doc.vocab.strings["Q342"]
def test_span_label_readonly(doc): def test_span_label_readonly(doc):
@ -184,6 +186,12 @@ def test_span_label_readonly(doc):
span.label_ = "hello" span.label_ = "hello"
def test_span_kb_id_readonly(doc):
    """Assigning to `kb_id_` on an existing Span must raise NotImplementedError."""
    existing_span = Span(doc, 0, 1)
    with pytest.raises(NotImplementedError):
        existing_span.kb_id_ = "Q342"
def test_span_ents_property(doc): def test_span_ents_property(doc):
"""Test span.ents for the """ """Test span.ents for the """
doc.ents = [ doc.ents = [

View File

@ -111,6 +111,8 @@ cdef class Span:
self.end_char = 0 self.end_char = 0
if isinstance(label, basestring_): if isinstance(label, basestring_):
label = doc.vocab.strings.add(label) label = doc.vocab.strings.add(label)
if isinstance(kb_id, basestring_):
kb_id = doc.vocab.strings.add(kb_id)
if label not in doc.vocab.strings: if label not in doc.vocab.strings:
raise ValueError(Errors.E084.format(label=label)) raise ValueError(Errors.E084.format(label=label))
self.label = label self.label = label
@ -662,9 +664,14 @@ cdef class Span:
def __get__(self): def __get__(self):
return self.doc.vocab.strings[self.kb_id] return self.doc.vocab.strings[self.kb_id]
# TODO: custom error msg like for label_
def __set__(self, unicode kb_id_): def __set__(self, unicode kb_id_):
raise NotImplementedError(TempErrors.T007.format(attr='kb_id_')) if not kb_id_:
kb_id_ = ''
current_label = self.label_
if not current_label:
current_label = ''
raise NotImplementedError(Errors.E131.format(start=self.start, end=self.end,
label=current_label, kb_id=kb_id_))
cdef int _count_words_to_root(const TokenC* token, int sent_length) except -1: cdef int _count_words_to_root(const TokenC* token, int sent_length) except -1: