💫 Allow setting of custom attributes during retokenization (closes #3314) (#3324)

<!--- Provide a general summary of your changes in the title. -->

## Description

This PR adds the abilility to override custom extension attributes during merging. This will only work for attributes that are writable, i.e. attributes registered with a default value like `default=False` or attribute that have both a getter *and* a setter implemented.

```python
Token.set_extension('is_musician', default=False)

doc = nlp("I like David Bowie.")
with doc.retokenize() as retokenizer:
    attrs = {"LEMMA": "David Bowie", "_": {"is_musician": True}}
    retokenizer.merge(doc[2:4], attrs=attrs)

assert doc[2].text == "David Bowie"
assert doc[2].lemma_ == "David Bowie"
assert doc[2]._.is_musician
```

### Types of change
enhancement

## Checklist
<!--- Before you submit the PR, go over this checklist and make sure you can
tick off all the boxes. [] -> [x] -->
- [x] I have submitted the spaCy Contributor Agreement.
- [x] I ran the tests, and all new and existing tests passed.
- [x] My changes don't require a change to the documentation, or if they do, I've added all required information.
This commit is contained in:
Ines Montani 2019-02-24 18:38:47 +01:00 committed by GitHub
parent 403b9cd58b
commit df19e2bff6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 202 additions and 13 deletions

View File

@ -327,6 +327,17 @@ class Errors(object):
"performance.") "performance.")
E117 = ("The newly split tokens must match the text of the original token. " E117 = ("The newly split tokens must match the text of the original token. "
"New orths: {new}. Old text: {old}.") "New orths: {new}. Old text: {old}.")
E118 = ("The custom extension attribute '{attr}' is not registered on the "
"Token object so it can't be set during retokenization. To "
"register an attribute, use the Token.set_extension classmethod.")
E119 = ("Can't set custom extension attribute '{attr}' during retokenization "
"because it's not writable. This usually means it was registered "
"with a getter function (and no setter) or as a method extension, "
"so the value is computed dynamically. To overwrite a custom "
"attribute manually, it should be registered with a default value "
"or with a getter AND setter.")
E120 = ("Can't set custom extension attributes during retokenization. "
"Expected dict mapping attribute names to values, but got: {value}")
@add_codes @add_codes

View File

@ -4,7 +4,7 @@ from __future__ import unicode_literals
import pytest import pytest
from spacy.attrs import LEMMA from spacy.attrs import LEMMA
from spacy.vocab import Vocab from spacy.vocab import Vocab
from spacy.tokens import Doc from spacy.tokens import Doc, Token
from ..util import get_doc from ..util import get_doc
@ -259,3 +259,36 @@ def test_doc_retokenize_spans_subtree_size_check(en_tokenizer):
attrs = {"lemma": "none", "ent_type": "none"} attrs = {"lemma": "none", "ent_type": "none"}
retokenizer.merge(doc[0:2], attrs=attrs) retokenizer.merge(doc[0:2], attrs=attrs)
assert len(list(sent1.root.subtree)) == init_len - 1 assert len(list(sent1.root.subtree)) == init_len - 1
def test_doc_retokenize_merge_extension_attrs(en_vocab):
Token.set_extension("a", default=False, force=True)
Token.set_extension("b", default="nothing", force=True)
doc = Doc(en_vocab, words=["hello", "world", "!"])
# Test regular merging
with doc.retokenize() as retokenizer:
attrs = {"lemma": "hello world", "_": {"a": True, "b": "1"}}
retokenizer.merge(doc[0:2], attrs=attrs)
assert doc[0].lemma_ == "hello world"
assert doc[0]._.a == True
assert doc[0]._.b == "1"
# Test bulk merging
doc = Doc(en_vocab, words=["hello", "world", "!", "!"])
with doc.retokenize() as retokenizer:
retokenizer.merge(doc[0:2], attrs={"_": {"a": True, "b": "1"}})
retokenizer.merge(doc[2:4], attrs={"_": {"a": None, "b": "2"}})
assert doc[0]._.a == True
assert doc[0]._.b == "1"
assert doc[1]._.a == None
assert doc[1]._.b == "2"
@pytest.mark.parametrize("underscore_attrs", [{"a": "x"}, {"b": "x"}, {"c": "x"}, [1]])
def test_doc_retokenize_merge_extension_attrs_invalid(en_vocab, underscore_attrs):
Token.set_extension("a", getter=lambda x: x, force=True)
Token.set_extension("b", method=lambda x: x, force=True)
doc = Doc(en_vocab, words=["hello", "world", "!"])
attrs = {"_": underscore_attrs}
with pytest.raises(ValueError):
with doc.retokenize() as retokenizer:
retokenizer.merge(doc[0:2], attrs=attrs)

View File

@ -3,7 +3,7 @@ from __future__ import unicode_literals
import pytest import pytest
from spacy.vocab import Vocab from spacy.vocab import Vocab
from spacy.tokens import Doc from spacy.tokens import Doc, Token
from ..util import get_doc from ..util import get_doc
@ -125,3 +125,43 @@ def test_doc_retokenize_split_orths_mismatch(en_vocab):
with pytest.raises(ValueError): with pytest.raises(ValueError):
with doc.retokenize() as retokenizer: with doc.retokenize() as retokenizer:
retokenizer.split(doc[0], ["L", "A"], [(doc[0], 0), (doc[0], 0)]) retokenizer.split(doc[0], ["L", "A"], [(doc[0], 0), (doc[0], 0)])
def test_doc_retokenize_split_extension_attrs(en_vocab):
Token.set_extension("a", default=False, force=True)
Token.set_extension("b", default="nothing", force=True)
doc = Doc(en_vocab, words=["LosAngeles", "start"])
with doc.retokenize() as retokenizer:
heads = [(doc[0], 1), doc[1]]
underscore = [{"a": True, "b": "1"}, {"b": "2"}]
attrs = {"lemma": ["los", "angeles"], "_": underscore}
retokenizer.split(doc[0], ["Los", "Angeles"], heads, attrs=attrs)
assert doc[0].lemma_ == "los"
assert doc[0]._.a == True
assert doc[0]._.b == "1"
assert doc[1].lemma_ == "angeles"
assert doc[1]._.a == False
assert doc[1]._.b == "2"
@pytest.mark.parametrize(
"underscore_attrs",
[
[{"a": "x"}, {}], # Overwriting getter without setter
[{"b": "x"}, {}], # Overwriting method
[{"c": "x"}, {}], # Overwriting nonexistent attribute
[{"a": "x"}, {"x": "x"}], # Combination
[{"a": "x", "x": "x"}, {"x": "x"}], # Combination
{"x": "x"}, # Not a list of dicts
],
)
def test_doc_retokenize_split_extension_attrs_invalid(en_vocab, underscore_attrs):
Token.set_extension("x", default=False, force=True)
Token.set_extension("a", getter=lambda x: x, force=True)
Token.set_extension("b", method=lambda x: x, force=True)
doc = Doc(en_vocab, words=["LosAngeles", "start"])
attrs = {"_": underscore_attrs}
with pytest.raises(ValueError):
with doc.retokenize() as retokenizer:
heads = [(doc[0], 1), doc[1]]
retokenizer.split(doc[0], ["Los", "Angeles"], heads, attrs=attrs)

View File

@ -36,8 +36,8 @@ def test_issue_1971_2(en_vocab):
def test_issue_1971_3(en_vocab): def test_issue_1971_3(en_vocab):
"""Test that pattern matches correctly for multiple extension attributes.""" """Test that pattern matches correctly for multiple extension attributes."""
Token.set_extension("a", default=1) Token.set_extension("a", default=1, force=True)
Token.set_extension("b", default=2) Token.set_extension("b", default=2, force=True)
doc = Doc(en_vocab, words=["hello", "world"]) doc = Doc(en_vocab, words=["hello", "world"])
matcher = Matcher(en_vocab) matcher = Matcher(en_vocab)
matcher.add("A", None, [{"_": {"a": 1}}]) matcher.add("A", None, [{"_": {"a": 1}}])
@ -51,8 +51,8 @@ def test_issue_1971_4(en_vocab):
"""Test that pattern matches correctly with multiple extension attribute """Test that pattern matches correctly with multiple extension attribute
values on a single token. values on a single token.
""" """
Token.set_extension("ext_a", default="str_a") Token.set_extension("ext_a", default="str_a", force=True)
Token.set_extension("ext_b", default="str_b") Token.set_extension("ext_b", default="str_b", force=True)
matcher = Matcher(en_vocab) matcher = Matcher(en_vocab)
doc = Doc(en_vocab, words=["this", "is", "text"]) doc = Doc(en_vocab, words=["this", "is", "text"])
pattern = [{"_": {"ext_a": "str_a", "ext_b": "str_b"}}] * 3 pattern = [{"_": {"ext_a": "str_a", "ext_b": "str_b"}}] * 3

View File

@ -17,6 +17,8 @@ from .token cimport Token
from ..lexeme cimport Lexeme, EMPTY_LEXEME from ..lexeme cimport Lexeme, EMPTY_LEXEME
from ..structs cimport LexemeC, TokenC from ..structs cimport LexemeC, TokenC
from ..attrs cimport TAG from ..attrs cimport TAG
from .underscore import is_writable_attr
from ..attrs import intify_attrs from ..attrs import intify_attrs
from ..util import SimpleFrozenDict from ..util import SimpleFrozenDict
from ..errors import Errors from ..errors import Errors
@ -43,7 +45,13 @@ cdef class Retokenizer:
if token.i in self.tokens_to_merge: if token.i in self.tokens_to_merge:
raise ValueError(Errors.E102.format(token=repr(token))) raise ValueError(Errors.E102.format(token=repr(token)))
self.tokens_to_merge.add(token.i) self.tokens_to_merge.add(token.i)
if "_" in attrs: # Extension attributes
extensions = attrs["_"]
_validate_extensions(extensions)
attrs = {key: value for key, value in attrs.items() if key != "_"}
attrs = intify_attrs(attrs, strings_map=self.doc.vocab.strings)
attrs["_"] = extensions
else:
attrs = intify_attrs(attrs, strings_map=self.doc.vocab.strings) attrs = intify_attrs(attrs, strings_map=self.doc.vocab.strings)
self.merges.append((span, attrs)) self.merges.append((span, attrs))
@ -53,6 +61,14 @@ cdef class Retokenizer:
""" """
if ''.join(orths) != token.text: if ''.join(orths) != token.text:
raise ValueError(Errors.E117.format(new=''.join(orths), old=token.text)) raise ValueError(Errors.E117.format(new=''.join(orths), old=token.text))
if "_" in attrs: # Extension attributes
extensions = attrs["_"]
for extension in extensions:
_validate_extensions(extension)
attrs = {key: value for key, value in attrs.items() if key != "_"}
attrs = intify_attrs(attrs, strings_map=self.doc.vocab.strings)
attrs["_"] = extensions
else:
attrs = intify_attrs(attrs, strings_map=self.doc.vocab.strings) attrs = intify_attrs(attrs, strings_map=self.doc.vocab.strings)
head_offsets = [] head_offsets = []
for head in heads: for head in heads:
@ -131,7 +147,10 @@ def _merge(Doc doc, int start, int end, attributes):
cdef TokenC* token = &doc.c[start] cdef TokenC* token = &doc.c[start]
token.spacy = doc.c[end-1].spacy token.spacy = doc.c[end-1].spacy
for attr_name, attr_value in attributes.items(): for attr_name, attr_value in attributes.items():
if attr_name == TAG: if attr_name == "_": # Set extension attributes
for ext_attr_key, ext_attr_value in attr_value.items():
doc[start]._.set(ext_attr_key, ext_attr_value)
elif attr_name == TAG:
doc.vocab.morphology.assign_tag(token, attr_value) doc.vocab.morphology.assign_tag(token, attr_value)
else: else:
Token.set_struct_attr(token, attr_name, attr_value) Token.set_struct_attr(token, attr_name, attr_value)
@ -183,6 +202,7 @@ def _merge(Doc doc, int start, int end, attributes):
# Return the merged Python object # Return the merged Python object
return doc[start] return doc[start]
def _bulk_merge(Doc doc, merges): def _bulk_merge(Doc doc, merges):
"""Retokenize the document, such that the spans described in 'merges' """Retokenize the document, such that the spans described in 'merges'
are merged into a single token. This method assumes that the merges are merged into a single token. This method assumes that the merges
@ -213,7 +233,10 @@ def _bulk_merge(Doc doc, merges):
tokens[merge_index] = token tokens[merge_index] = token
# Assign attributes # Assign attributes
for attr_name, attr_value in attributes.items(): for attr_name, attr_value in attributes.items():
if attr_name == TAG: if attr_name == "_": # Set extension attributes
for ext_attr_key, ext_attr_value in attr_value.items():
doc[start]._.set(ext_attr_key, ext_attr_value)
elif attr_name == TAG:
doc.vocab.morphology.assign_tag(token, attr_value) doc.vocab.morphology.assign_tag(token, attr_value)
else: else:
Token.set_struct_attr(token, attr_name, attr_value) Token.set_struct_attr(token, attr_name, attr_value)
@ -379,7 +402,10 @@ def _split(Doc doc, int token_index, orths, heads, attrs):
for attr_name, attr_values in attrs.items(): for attr_name, attr_values in attrs.items():
for i, attr_value in enumerate(attr_values): for i, attr_value in enumerate(attr_values):
token = &doc.c[token_index + i] token = &doc.c[token_index + i]
if attr_name == TAG: if attr_name == "_":
for ext_attr_key, ext_attr_value in attr_value.items():
doc[token_index + i]._.set(ext_attr_key, ext_attr_value)
elif attr_name == TAG:
doc.vocab.morphology.assign_tag(token, get_string_id(attr_value)) doc.vocab.morphology.assign_tag(token, get_string_id(attr_value))
else: else:
Token.set_struct_attr(token, attr_name, get_string_id(attr_value)) Token.set_struct_attr(token, attr_name, get_string_id(attr_value))
@ -391,3 +417,15 @@ def _split(Doc doc, int token_index, orths, heads, attrs):
doc.c[i].head -= i doc.c[i].head -= i
# set children from head # set children from head
set_children_from_heads(doc.c, doc.length) set_children_from_heads(doc.c, doc.length)
def _validate_extensions(extensions):
if not isinstance(extensions, dict):
raise ValueError(Errors.E120.format(value=repr(extensions)))
for key, value in extensions.items():
# Get the extension and make sure it's available and writable
extension = Token.get_extension(key)
if not extension: # Extension attribute doesn't exist
raise ValueError(Errors.E118.format(attr=key))
if not is_writable_attr(extension):
raise ValueError(Errors.E119.format(attr=key))

View File

@ -75,3 +75,18 @@ def get_ext_args(**kwargs):
if method is not None and not hasattr(method, "__call__"): if method is not None and not hasattr(method, "__call__"):
raise ValueError(Errors.E091.format(name="method", value=repr(method))) raise ValueError(Errors.E091.format(name="method", value=repr(method)))
return (default, method, getter, setter) return (default, method, getter, setter)
def is_writable_attr(ext):
"""Check if an extension attribute is writable.
ext (tuple): The (default, getter, setter, method) tuple available via
{Doc,Span,Token}.get_extension.
RETURNS (bool): Whether the attribute is writable.
"""
default, method, getter, setter = ext
# Extension is writable if it has a setter (getter + setter), if it has a
# default value (or, if its default value is none, none of the other values
# should be set).
if setter is not None or default is not None or all(e is None for e in ext):
return True
return False

View File

@ -403,6 +403,8 @@ invalidated, although they may accidentally continue to work.
### Retokenizer.merge {#retokenizer.merge tag="method"} ### Retokenizer.merge {#retokenizer.merge tag="method"}
Mark a span for merging. The `attrs` will be applied to the resulting token. Mark a span for merging. The `attrs` will be applied to the resulting token.
Writable custom extension attributes can be provided as a dictionary mapping
attribute names to values as the `"_"` key.
> #### Example > #### Example
> >
@ -428,7 +430,8 @@ subtoken index. For example, `(doc[3], 1)` will attach the subtoken to the
second subtoken of `doc[3]`. This mechanism allows attaching subtokens to other second subtoken of `doc[3]`. This mechanism allows attaching subtokens to other
newly created subtokens, without having to keep track of the changing token newly created subtokens, without having to keep track of the changing token
indices. If the specified head token will be split within the retokenizer block indices. If the specified head token will be split within the retokenizer block
and no subtoken index is specified, it will default to `0`. and no subtoken index is specified, it will default to `0`. Attributes to set on
subtokens can be provided as a list of values.
> #### Example > #### Example
> >

View File

@ -1083,6 +1083,55 @@ with doc.retokenize() as retokenizer:
</Infobox> </Infobox>
### Overwriting custom extension attributes {#retokenization-extensions}
If you've registered custom
[extension attributes](/usage/processing-pipelines##custom-components-attributes),
you can overwrite them during tokenization by providing a dictionary of
attribute names mapped to new values as the `"_"` key in the `attrs`. For
merging, you need to provide one dictionary of attributes for the resulting
merged token. For splitting, you need to provide a list of dictionaries with
custom attributes, one per split subtoken.
<Infobox title="Important note" variant="warning">
To set extension attributes during retokenization, the attributes need to be
**registered** using the [`Token.set_extension`](/api/token#set_extension)
method and they need to be **writable**. This means that they should either have
a default value that can be overwritten, or a getter _and_ setter. Method
extensions or extensions with only a getter are computed dynamically, so their
values can't be overwritten. For more details, see the
[extension attribute docs](/usage/processing-pipelines/#custom-components-attributes).
</Infobox>
> #### ✏️ Things to try
>
> 1. Add another custom extension maybe `"music_style"`? and overwrite it.
> 2. Change the extension attribute to use only a `getter` function. You should
> see that spaCy raises an error, because the attribute is not writable
> anymore.
> 3. Rewrite the code to split a token with `retokenizer.split`. Remember that
> you need to provide a list of extension attribute values as the `"_"`
> property, one for each split subtoken.
```python
### {executable="true"}
import spacy
from spacy.tokens import Token
# Register a custom token attribute, token._.is_musician
Token.set_extension("is_musician", default=False)
nlp = spacy.load("en_core_web_sm")
doc = nlp("I like David Bowie")
print("Before:", [(token.text, token._.is_musician) for token in doc])
with doc.retokenize() as retokenizer:
retokenizer.merge(doc[2:4], attrs={"_": {"is_musician": True}})
print("After:", [(token.text, token._.is_musician) for token in doc])
```
## Sentence Segmentation {#sbd} ## Sentence Segmentation {#sbd}
A [`Doc`](/api/doc) object's sentences are available via the `Doc.sents` A [`Doc`](/api/doc) object's sentences are available via the `Doc.sents`