Refinements to retokenize.split() function (#3282)

* Change retokenize.split() API for heads

* Pass lists as values for attrs in split

* Fix test_doc_split filename

* Add error for mismatched tokens after split

* Raise error if new tokens don't match text

* Fix doc test

* Fix error

* Move deps under attrs

* Fix split tests

* Fix retokenize.split
Matthew Honnibal 2019-02-15 17:32:31 +01:00 committed by GitHub
parent 2dbc61bc26
commit 92b6bd2977
3 changed files with 73 additions and 73 deletions
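
For orientation before the diff, here is a minimal sketch of the call shape this commit settles on, mirroring the updated tests below. It is a hedged example, not part of the diff: the blank English pipeline stands in for the test fixtures, and the attribute values are taken from test_doc_split.

    from spacy.lang.en import English
    from spacy.tokens import Doc

    nlp = English()
    doc = Doc(nlp.vocab, words=["LosAngeles", "start", "."])
    with doc.retokenize() as retokenizer:
        retokenizer.split(
            doc[0],                      # token to split
            ["Los", "Angeles"],          # new orths; must join back to "LosAngeles"
            [(doc[0], 1), doc[1]],       # heads: "Los" -> subtoken 1, "Angeles" -> "start"
            attrs={"tag": ["NNP"] * 2, "lemma": ["Los", "Angeles"]},
        )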


@@ -324,6 +324,8 @@ class Errors(object):
             "labels before training begins. This functionality was available "
             "in previous versions, but had significant bugs that led to poor "
             "performance")
+    E117 = ("The newly split tokens must match the text of the original token. "
+            "New orths: {new}. Old text: {old}.")
 
 @add_codes
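
E117 backs the new orth check in retokenize.split(). A small hedged illustration of when it fires, mirroring test_split_orths_mismatch further down (the blank English vocab is only a stand-in for the test fixture):

    from spacy.lang.en import English
    from spacy.tokens import Doc

    doc = Doc(English().vocab, words=["LosAngeles", "start", "."])
    with doc.retokenize() as retokenizer:
        # "L" + "A" != "LosAngeles", so split() raises ValueError with E117
        retokenizer.split(doc[0], ["L", "A"], [(doc[0], 0), (doc[0], 0)])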


@@ -17,8 +17,16 @@ def test_doc_split(en_vocab):
     assert doc[0].head.text == "start"
     assert doc[1].head.text == "."
     with doc.retokenize() as retokenizer:
-        attrs = {"tag": "NNP", "lemma": "Los Angeles", "ent_type": "GPE"}
-        retokenizer.split(doc[0], ["Los", "Angeles"], [1, 0], attrs=attrs)
+        retokenizer.split(
+            doc[0],
+            ["Los", "Angeles"],
+            [(doc[0], 1), doc[1]],
+            attrs={
+                "tag": ["NNP"]*2,
+                "lemma": ["Los", "Angeles"],
+                "ent_type": ["GPE"]*2
+            },
+        )
     assert len(doc) == 4
     assert doc[0].text == "Los"
     assert doc[0].head.text == "Angeles"
@@ -38,7 +46,8 @@ def test_split_dependencies(en_vocab):
     dep1 = doc.vocab.strings.add("amod")
     dep2 = doc.vocab.strings.add("subject")
     with doc.retokenize() as retokenizer:
-        retokenizer.split(doc[0], ["Los", "Angeles"], [1, 0], [dep1, dep2])
+        retokenizer.split(doc[0], ["Los", "Angeles"],
+                          [(doc[0], 1), doc[1]], attrs={'dep': [dep1, dep2]})
     assert doc[0].dep == dep1
     assert doc[1].dep == dep2
@@ -48,35 +57,12 @@ def test_split_heads_error(en_vocab):
     # Not enough heads
     with pytest.raises(ValueError):
         with doc.retokenize() as retokenizer:
-            retokenizer.split(doc[0], ["Los", "Angeles"], [0])
+            retokenizer.split(doc[0], ["Los", "Angeles"], [doc[1]])
     # Too many heads
     with pytest.raises(ValueError):
         with doc.retokenize() as retokenizer:
-            retokenizer.split(doc[0], ["Los", "Angeles"], [1, 1, 0])
+            retokenizer.split(doc[0], ["Los", "Angeles"], [doc[1], doc[1], doc[1]])
-    # No token head
-    with pytest.raises(ValueError):
-        with doc.retokenize() as retokenizer:
-            retokenizer.split(doc[0], ["Los", "Angeles"], [1, 1])
-    # Several token heads
-    with pytest.raises(ValueError):
-        with doc.retokenize() as retokenizer:
-            retokenizer.split(doc[0], ["Los", "Angeles"], [0, 0])
-
-
-@pytest.mark.xfail
-def test_split_heads_out_of_bounds(en_vocab):
-    """Test that the retokenizer raises an error for out-of-bounds heads. The
-    indices are relative, so head 1 for "Angeles" would be the token following
-    it, which is out-of-bounds. Previously, the retokenizer would accept this
-    and spaCy would then fail later.
-    """
-    doc = Doc(en_vocab, words=["Start", "LosAngeles"])
-    with pytest.raises(ValueError):
-        with doc.retokenize() as retokenizer:
-            retokenizer.split(doc[1], ["Los", "Angeles"], [0, 1])
 
 
 def test_spans_entity_merge_iob():
@@ -87,7 +73,8 @@ def test_spans_entity_merge_iob():
     assert doc[0].ent_iob_ == "B"
     assert doc[1].ent_iob_ == "I"
     with doc.retokenize() as retokenizer:
-        retokenizer.split(doc[0], ["a", "b", "c"], [1, 1, 0])
+        retokenizer.split(doc[0], ["a", "b", "c"],
+                          [(doc[0], 1), (doc[0], 2), doc[1]])
     assert doc[0].ent_iob_ == "B"
     assert doc[1].ent_iob_ == "I"
     assert doc[2].ent_iob_ == "I"
@@ -107,14 +94,15 @@ def test_spans_sentence_update_after_merge(en_vocab):
     init_len = len(sent1)
     init_len2 = len(sent2)
     with doc.retokenize() as retokenizer:
-        retokenizer.split(doc[0], ["Stewart", "Lee"], [1, 0])
-        retokenizer.split(doc[14], ["Joe", "Pasquale"], [1, 0])
+        retokenizer.split(doc[0], ["Stewart", "Lee"], [(doc[0], 1), doc[1]],
+                          attrs={"dep": ["compound", "nsubj"]})
+        retokenizer.split(doc[13], ["Joe", "Pasquale"], [(doc[13], 1), doc[12]],
+                          attrs={"dep": ["compound", "dobj"]})
     sent1, sent2 = list(doc.sents)
     assert len(sent1) == init_len + 1
     assert len(sent2) == init_len2 + 1
 
 
-@pytest.mark.xfail
 def test_split_orths_mismatch(en_vocab):
     """Test that the regular retokenizer.split raises an error if the orths
     don't match the original token text. There might still be a method that
@@ -125,4 +113,4 @@ def test_split_orths_mismatch(en_vocab):
     doc = Doc(en_vocab, words=["LosAngeles", "start", "."])
     with pytest.raises(ValueError):
         with doc.retokenize() as retokenizer:
-            retokenizer.split(doc[0], ["L", "A"], [0, -1])
+            retokenizer.split(doc[0], ["L", "A"], [(doc[0], 0), (doc[0], 0)])
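
Taken together, the test updates show the two API changes named in the commit message: heads are now passed as Token or (Token, subtoken_index) references instead of relative integers, and dependency labels move from a separate deps argument into attrs. A before/after sketch, reusing the names from test_split_dependencies above rather than standalone code:

    # Old call shape (relative integer heads, separate deps list):
    #     retokenizer.split(doc[0], ["Los", "Angeles"], [1, 0], [dep1, dep2])
    # New call shape (token references, deps carried inside attrs):
    retokenizer.split(doc[0], ["Los", "Angeles"],
                      [(doc[0], 1), doc[1]], attrs={"dep": [dep1, dep2]})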


@@ -20,6 +20,7 @@ from ..attrs cimport TAG
 from ..attrs import intify_attrs
 from ..util import SimpleFrozenDict
 from ..errors import Errors
+from ..strings import get_string_id
 
 
 cdef class Retokenizer:
@@ -46,12 +47,20 @@ cdef class Retokenizer:
         attrs = intify_attrs(attrs, strings_map=self.doc.vocab.strings)
         self.merges.append((span, attrs))
 
-    def split(self, Token token, orths, heads, deps=[], attrs=SimpleFrozenDict()):
+    def split(self, Token token, orths, heads, attrs=SimpleFrozenDict()):
         """Mark a Token for splitting, into the specified orths. The attrs
         will be applied to each subtoken.
         """
+        if ''.join(orths) != token.text:
+            raise ValueError(Errors.E117.format(new=''.join(orths), old=token.text))
         attrs = intify_attrs(attrs, strings_map=self.doc.vocab.strings)
-        self.splits.append((token.i, orths, heads, deps, attrs))
+        head_offsets = []
+        for head in heads:
+            if isinstance(head, Token):
+                head_offsets.append((head.idx, 0))
+            else:
+                head_offsets.append((head[0].idx, head[1]))
+        self.splits.append((token.idx, orths, head_offsets, attrs))
 
     def __enter__(self):
         self.merges = []
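
Note that split() only queues the request: it stores the character offset of the token and of each head (plus a subtoken index), so the entry can still be resolved in __exit__ after earlier merges have shifted token indices. A rough illustration for a doc like ["LosAngeles", "start", "."] (offsets worked out by hand, not library output):

    # For retokenizer.split(doc[0], ["Los", "Angeles"], [(doc[0], 1), doc[1]], attrs={})
    # with doc = ["LosAngeles", "start", "."], the queued entry is roughly:
    queued = (0, ["Los", "Angeles"], [(0, 1), (11, 0)], {})
    # 0        -> character offset of "LosAngeles" (token.idx)
    # (0, 1)   -> head is subtoken 1 of the split token itself ("Angeles")
    # (11, 0)  -> head is the token starting at character 11 ("start")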
@@ -67,13 +76,31 @@ cdef class Retokenizer:
             start = span.start
             end = span.end
             _merge(self.doc, start, end, attrs)
-        offset = 0
-        # Iterate in order, to keep the offset simple.
-        for token_index, orths, heads, deps, attrs in sorted(self.splits):
-            _split(self.doc, token_index + offset, orths, heads, deps, attrs)
-            # Adjust for the previous tokens
-            offset += len(orths)-1
+        # Iterate in order, to keep things simple.
+        for start_char, orths, heads, attrs in sorted(self.splits):
+            # Resolve token index
+            token_index = token_by_start(self.doc.c, self.doc.length, start_char)
+            # Check we're still able to find tokens starting at the character offsets
+            # referred to in the splits. If we merged these tokens previously, we
+            # have to raise an error
+            if token_index == -1:
+                raise IndexError(
+                    "Cannot find token to be split. Did it get merged?")
+            head_indices = []
+            for head_char, subtoken in heads:
+                head_index = token_by_start(self.doc.c, self.doc.length, head_char)
+                if head_index == -1:
+                    raise IndexError(
+                        "Cannot find head of token to be split. Did it get merged?")
+                # We want to refer to the token index of the head *after* the
+                # mergery. We need to account for the extra tokens introduced.
+                # e.g., let's say we have [ab, c] and we want a and b to depend
+                # on c. The correct index for c will be 2, not 1.
+                if head_index > token_index:
+                    head_index += len(orths)-1
+                head_indices.append(head_index+subtoken)
+            _split(self.doc, token_index, orths, head_indices, attrs)
 
 
 def _merge(Doc doc, int start, int end, attributes):
     """Retokenize the document, such that the span at
@@ -292,7 +319,7 @@ def _resize_tensor(tensor, ranges):
     return xp.delete(tensor, delete, axis=0)
 
 
-def _split(Doc doc, int token_index, orths, heads, deps, attrs):
+def _split(Doc doc, int token_index, orths, heads, attrs):
     """Retokenize the document, such that the token at
     `doc[token_index]` is split into tokens with the orth 'orths'
     token_index(int): token index of the token to split.
@@ -308,27 +335,14 @@ def _split(Doc doc, int token_index, orths, heads, deps, attrs):
     if(len(heads) != nb_subtokens):
         raise ValueError(Errors.E115)
-    token_head_index = -1
-    for index, head in enumerate(heads):
-        if head == 0:
-            if token_head_index != -1:
-                raise ValueError(Errors.E114)
-            token_head_index = index
-    if token_head_index == -1:
-        raise ValueError(Errors.E113)
-    # First, make the dependencies absolutes, and adjust all possible dependencies before
-    # creating the tokens
+    # First, make the dependencies absolutes
     for i in range(doc.length):
         doc.c[i].head += i
-    # Adjust dependencies
+    # Adjust dependencies, so they refer to post-split indexing
     offset = nb_subtokens - 1
     for i in range(doc.length):
-        head_idx = doc.c[i].head
-        if head_idx == token_index:
-            doc.c[i].head = token_head_index
-        elif head_idx > token_index:
+        if doc.c[i].head > token_index:
             doc.c[i].head += offset
-    new_token_head = doc.c[token_index].head
     # Double doc.c max_length if necessary (until big enough for all new tokens)
     while doc.length + nb_subtokens - 1 >= doc.max_length:
         doc._realloc(doc.length * 2)
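
For the dependency bookkeeping in _split(), recall that spaCy stores each head as an offset relative to the token itself; the hunk above temporarily makes the heads absolute, shifts everything that points past the split token, and the function converts back to relative offsets near the end. A small hedged example with made-up heads:

    # Hypothetical doc ["StewartLee", "is", "funny"]; token 0 is split into two subtokens.
    relative_heads = [1, 0, -1]                       # stored form: head index minus own index
    absolute_heads = [i + h for i, h in enumerate(relative_heads)]   # -> [1, 1, 1]
    token_index, nb_subtokens = 0, 2
    offset = nb_subtokens - 1
    absolute_heads = [h + offset if h > token_index else h
                      for h in absolute_heads]        # -> [2, 2, 2]; "is" now lives at index 2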
@@ -352,12 +366,6 @@ def _split(Doc doc, int token_index, orths, heads, deps, attrs):
             token.spacy = False
         else:
             token.spacy = orig_token.spacy
-        # Apply attrs to each subtoken
-        for attr_name, attr_value in attrs.items():
-            if attr_name == TAG:
-                doc.vocab.morphology.assign_tag(token, attr_value)
-            else:
-                Token.set_struct_attr(token, attr_name, attr_value)
         # Make IOB consistent
         if (orig_token.ent_iob == 3):
             if i == 0:
@@ -367,17 +375,19 @@ def _split(Doc doc, int token_index, orths, heads, deps, attrs):
         else:
             # In all other cases subtokens inherit iob from origToken
             token.ent_iob = orig_token.ent_iob
-    # Use the head of the new token everywhere. This will be partially overwritten later on.
-    token.head = new_token_head
+    # Apply attrs to each subtoken
+    for attr_name, attr_values in attrs.items():
+        for i, attr_value in enumerate(attr_values):
+            token = &doc.c[token_index + i]
+            if attr_name == TAG:
+                doc.vocab.morphology.assign_tag(token, get_string_id(attr_value))
+            else:
+                Token.set_struct_attr(token, attr_name, get_string_id(attr_value))
+    # Assign correct dependencies to the inner token
+    for i, head in enumerate(heads):
+        doc.c[token_index + i].head = head
     # Transform the dependencies into relative ones again
     for i in range(doc.length):
         doc.c[i].head -= i
-    # Assign correct dependencies to the inner token
-    for i, head in enumerate(heads):
-        if head != 0:
-            # the token's head's head is already correct
-            doc.c[token_index + i].head = head
-    for i, dep in enumerate(deps):
-        doc[token_index + i].dep = dep
     # set children from head
     set_children_from_heads(doc.c, doc.length)