Mirror of https://github.com/explosion/spaCy.git (synced 2024-12-24)
## Description

This PR adds the ability to override custom extension attributes during merging. This only works for attributes that are writable, i.e. attributes registered with a default value like `default=False`, or attributes that have both a getter *and* a setter implemented.

```python
Token.set_extension('is_musician', default=False)

doc = nlp("I like David Bowie.")
with doc.retokenize() as retokenizer:
    attrs = {"LEMMA": "David Bowie", "_": {"is_musician": True}}
    retokenizer.merge(doc[2:4], attrs=attrs)

assert doc[2].text == "David Bowie"
assert doc[2].lemma_ == "David Bowie"
assert doc[2]._.is_musician
```

### Types of change

enhancement

## Checklist

- [x] I have submitted the spaCy Contributor Agreement.
- [x] I ran the tests, and all new and existing tests passed.
- [x] My changes don't require a change to the documentation, or if they do, I've added all required information.
This commit is contained in:
parent
403b9cd58b
commit
df19e2bff6
New error messages added to `Errors`:

```diff
@@ -327,6 +327,17 @@ class Errors(object):
             "performance.")
+    E117 = ("The newly split tokens must match the text of the original token. "
+            "New orths: {new}. Old text: {old}.")
+    E118 = ("The custom extension attribute '{attr}' is not registered on the "
+            "Token object so it can't be set during retokenization. To "
+            "register an attribute, use the Token.set_extension classmethod.")
+    E119 = ("Can't set custom extension attribute '{attr}' during retokenization "
+            "because it's not writable. This usually means it was registered "
+            "with a getter function (and no setter) or as a method extension, "
+            "so the value is computed dynamically. To overwrite a custom "
+            "attribute manually, it should be registered with a default value "
+            "or with a getter AND setter.")
+    E120 = ("Can't set custom extension attributes during retokenization. "
+            "Expected dict mapping attribute names to values, but got: {value}")
 
 
 @add_codes
```
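For orientation, here is a minimal sketch (not taken from the diff; the extension name is illustrative) of the situation that surfaces E119. An unregistered attribute name would raise E118 instead, and a `"_"` value that isn't a dict would raise E120:

```python
from spacy.tokens import Doc, Token
from spacy.vocab import Vocab

# A getter-only extension: readable, but not writable
Token.set_extension("computed", getter=lambda token: len(token.text))

doc = Doc(Vocab(), words=["hello", "world"])
try:
    with doc.retokenize() as retokenizer:
        # "computed" has no setter or default value, so overwriting it
        # during retokenization is rejected with a ValueError (E119).
        retokenizer.merge(doc[0:2], attrs={"_": {"computed": 5}})
except ValueError as err:
    print(err)
```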
Retokenizer merge tests:

```diff
@@ -4,7 +4,7 @@ from __future__ import unicode_literals
 import pytest
 from spacy.attrs import LEMMA
 from spacy.vocab import Vocab
-from spacy.tokens import Doc
+from spacy.tokens import Doc, Token
 
 from ..util import get_doc
@@ -259,3 +259,36 @@ def test_doc_retokenize_spans_subtree_size_check(en_tokenizer):
         attrs = {"lemma": "none", "ent_type": "none"}
         retokenizer.merge(doc[0:2], attrs=attrs)
     assert len(list(sent1.root.subtree)) == init_len - 1
+
+
+def test_doc_retokenize_merge_extension_attrs(en_vocab):
+    Token.set_extension("a", default=False, force=True)
+    Token.set_extension("b", default="nothing", force=True)
+    doc = Doc(en_vocab, words=["hello", "world", "!"])
+    # Test regular merging
+    with doc.retokenize() as retokenizer:
+        attrs = {"lemma": "hello world", "_": {"a": True, "b": "1"}}
+        retokenizer.merge(doc[0:2], attrs=attrs)
+    assert doc[0].lemma_ == "hello world"
+    assert doc[0]._.a == True
+    assert doc[0]._.b == "1"
+    # Test bulk merging
+    doc = Doc(en_vocab, words=["hello", "world", "!", "!"])
+    with doc.retokenize() as retokenizer:
+        retokenizer.merge(doc[0:2], attrs={"_": {"a": True, "b": "1"}})
+        retokenizer.merge(doc[2:4], attrs={"_": {"a": None, "b": "2"}})
+    assert doc[0]._.a == True
+    assert doc[0]._.b == "1"
+    assert doc[1]._.a == None
+    assert doc[1]._.b == "2"
+
+
+@pytest.mark.parametrize("underscore_attrs", [{"a": "x"}, {"b": "x"}, {"c": "x"}, [1]])
+def test_doc_retokenize_merge_extension_attrs_invalid(en_vocab, underscore_attrs):
+    Token.set_extension("a", getter=lambda x: x, force=True)
+    Token.set_extension("b", method=lambda x: x, force=True)
+    doc = Doc(en_vocab, words=["hello", "world", "!"])
+    attrs = {"_": underscore_attrs}
+    with pytest.raises(ValueError):
+        with doc.retokenize() as retokenizer:
+            retokenizer.merge(doc[0:2], attrs=attrs)
```
Retokenizer split tests:

```diff
@@ -3,7 +3,7 @@ from __future__ import unicode_literals
 
 import pytest
 from spacy.vocab import Vocab
-from spacy.tokens import Doc
+from spacy.tokens import Doc, Token
 
 from ..util import get_doc
@@ -125,3 +125,43 @@ def test_doc_retokenize_split_orths_mismatch(en_vocab):
     with pytest.raises(ValueError):
         with doc.retokenize() as retokenizer:
             retokenizer.split(doc[0], ["L", "A"], [(doc[0], 0), (doc[0], 0)])
+
+
+def test_doc_retokenize_split_extension_attrs(en_vocab):
+    Token.set_extension("a", default=False, force=True)
+    Token.set_extension("b", default="nothing", force=True)
+    doc = Doc(en_vocab, words=["LosAngeles", "start"])
+    with doc.retokenize() as retokenizer:
+        heads = [(doc[0], 1), doc[1]]
+        underscore = [{"a": True, "b": "1"}, {"b": "2"}]
+        attrs = {"lemma": ["los", "angeles"], "_": underscore}
+        retokenizer.split(doc[0], ["Los", "Angeles"], heads, attrs=attrs)
+    assert doc[0].lemma_ == "los"
+    assert doc[0]._.a == True
+    assert doc[0]._.b == "1"
+    assert doc[1].lemma_ == "angeles"
+    assert doc[1]._.a == False
+    assert doc[1]._.b == "2"
+
+
+@pytest.mark.parametrize(
+    "underscore_attrs",
+    [
+        [{"a": "x"}, {}],  # Overwriting getter without setter
+        [{"b": "x"}, {}],  # Overwriting method
+        [{"c": "x"}, {}],  # Overwriting nonexistent attribute
+        [{"a": "x"}, {"x": "x"}],  # Combination
+        [{"a": "x", "x": "x"}, {"x": "x"}],  # Combination
+        {"x": "x"},  # Not a list of dicts
+    ],
+)
+def test_doc_retokenize_split_extension_attrs_invalid(en_vocab, underscore_attrs):
+    Token.set_extension("x", default=False, force=True)
+    Token.set_extension("a", getter=lambda x: x, force=True)
+    Token.set_extension("b", method=lambda x: x, force=True)
+    doc = Doc(en_vocab, words=["LosAngeles", "start"])
+    attrs = {"_": underscore_attrs}
+    with pytest.raises(ValueError):
+        with doc.retokenize() as retokenizer:
+            heads = [(doc[0], 1), doc[1]]
+            retokenizer.split(doc[0], ["Los", "Angeles"], heads, attrs=attrs)
```
Regression tests for issue #1971 (the extensions are now registered with `force=True`):

```diff
@@ -36,8 +36,8 @@ def test_issue_1971_2(en_vocab):
 
 def test_issue_1971_3(en_vocab):
     """Test that pattern matches correctly for multiple extension attributes."""
-    Token.set_extension("a", default=1)
-    Token.set_extension("b", default=2)
+    Token.set_extension("a", default=1, force=True)
+    Token.set_extension("b", default=2, force=True)
     doc = Doc(en_vocab, words=["hello", "world"])
     matcher = Matcher(en_vocab)
     matcher.add("A", None, [{"_": {"a": 1}}])
@@ -51,8 +51,8 @@ def test_issue_1971_4(en_vocab):
     """Test that pattern matches correctly with multiple extension attribute
     values on a single token.
     """
-    Token.set_extension("ext_a", default="str_a")
-    Token.set_extension("ext_b", default="str_b")
+    Token.set_extension("ext_a", default="str_a", force=True)
+    Token.set_extension("ext_b", default="str_b", force=True)
     matcher = Matcher(en_vocab)
     doc = Doc(en_vocab, words=["this", "is", "text"])
     pattern = [{"_": {"ext_a": "str_a", "ext_b": "str_b"}}] * 3
```
`Retokenizer` implementation (`_retokenize.pyx`):

```diff
@@ -17,6 +17,8 @@ from .token cimport Token
 from ..lexeme cimport Lexeme, EMPTY_LEXEME
 from ..structs cimport LexemeC, TokenC
 from ..attrs cimport TAG
+
+from .underscore import is_writable_attr
 from ..attrs import intify_attrs
 from ..util import SimpleFrozenDict
 from ..errors import Errors
@@ -43,8 +45,14 @@ cdef class Retokenizer:
         if token.i in self.tokens_to_merge:
             raise ValueError(Errors.E102.format(token=repr(token)))
         self.tokens_to_merge.add(token.i)
-        attrs = intify_attrs(attrs, strings_map=self.doc.vocab.strings)
+        if "_" in attrs:  # Extension attributes
+            extensions = attrs["_"]
+            _validate_extensions(extensions)
+            attrs = {key: value for key, value in attrs.items() if key != "_"}
+            attrs = intify_attrs(attrs, strings_map=self.doc.vocab.strings)
+            attrs["_"] = extensions
+        else:
+            attrs = intify_attrs(attrs, strings_map=self.doc.vocab.strings)
         self.merges.append((span, attrs))
 
     def split(self, Token token, orths, heads, attrs=SimpleFrozenDict()):
@@ -53,7 +61,15 @@ cdef class Retokenizer:
         """
         if ''.join(orths) != token.text:
             raise ValueError(Errors.E117.format(new=''.join(orths), old=token.text))
-        attrs = intify_attrs(attrs, strings_map=self.doc.vocab.strings)
+        if "_" in attrs:  # Extension attributes
+            extensions = attrs["_"]
+            for extension in extensions:
+                _validate_extensions(extension)
+            attrs = {key: value for key, value in attrs.items() if key != "_"}
+            attrs = intify_attrs(attrs, strings_map=self.doc.vocab.strings)
+            attrs["_"] = extensions
+        else:
+            attrs = intify_attrs(attrs, strings_map=self.doc.vocab.strings)
         head_offsets = []
         for head in heads:
             if isinstance(head, Token):
```
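Why the `"_"` key is set aside before calling `intify_attrs` may not be obvious from the hunk alone: `intify_attrs` maps regular attribute names to spaCy's internal integer IDs and hashes string values via the string store, and the nested dict of extension attributes isn't a regular token attribute, so it's removed first and re-attached afterwards. A rough sketch of that normalization step, assuming only a bare `Vocab` (not part of the commit):

```python
from spacy.attrs import intify_attrs
from spacy.vocab import Vocab

vocab = Vocab()
attrs = {"lemma": "hello world", "_": {"is_musician": True}}

extensions = attrs.pop("_")                                 # keep extension attrs aside
int_attrs = intify_attrs(attrs, strings_map=vocab.strings)  # names -> attr IDs, values -> hashes
int_attrs["_"] = extensions                                 # re-attach for assignment later

print(int_attrs)
```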
Lower-level merge/split helpers in the same module:

```diff
@@ -131,7 +147,10 @@ def _merge(Doc doc, int start, int end, attributes):
     cdef TokenC* token = &doc.c[start]
     token.spacy = doc.c[end-1].spacy
     for attr_name, attr_value in attributes.items():
-        if attr_name == TAG:
+        if attr_name == "_":  # Set extension attributes
+            for ext_attr_key, ext_attr_value in attr_value.items():
+                doc[start]._.set(ext_attr_key, ext_attr_value)
+        elif attr_name == TAG:
             doc.vocab.morphology.assign_tag(token, attr_value)
         else:
             Token.set_struct_attr(token, attr_name, attr_value)
@@ -183,6 +202,7 @@ def _merge(Doc doc, int start, int end, attributes):
     # Return the merged Python object
     return doc[start]
 
 
 def _bulk_merge(Doc doc, merges):
     """Retokenize the document, such that the spans described in 'merges'
     are merged into a single token. This method assumes that the merges
@@ -213,7 +233,10 @@ def _bulk_merge(Doc doc, merges):
         tokens[merge_index] = token
         # Assign attributes
         for attr_name, attr_value in attributes.items():
-            if attr_name == TAG:
+            if attr_name == "_":  # Set extension attributes
+                for ext_attr_key, ext_attr_value in attr_value.items():
+                    doc[start]._.set(ext_attr_key, ext_attr_value)
+            elif attr_name == TAG:
                 doc.vocab.morphology.assign_tag(token, attr_value)
             else:
                 Token.set_struct_attr(token, attr_name, attr_value)
@@ -379,7 +402,10 @@ def _split(Doc doc, int token_index, orths, heads, attrs):
     for attr_name, attr_values in attrs.items():
         for i, attr_value in enumerate(attr_values):
             token = &doc.c[token_index + i]
-            if attr_name == TAG:
+            if attr_name == "_":
+                for ext_attr_key, ext_attr_value in attr_value.items():
+                    doc[token_index + i]._.set(ext_attr_key, ext_attr_value)
+            elif attr_name == TAG:
                 doc.vocab.morphology.assign_tag(token, get_string_id(attr_value))
             else:
                 Token.set_struct_attr(token, attr_name, get_string_id(attr_value))
@@ -391,3 +417,15 @@ def _split(Doc doc, int token_index, orths, heads, attrs):
         doc.c[i].head -= i
     # set children from head
     set_children_from_heads(doc.c, doc.length)
+
+
+def _validate_extensions(extensions):
+    if not isinstance(extensions, dict):
+        raise ValueError(Errors.E120.format(value=repr(extensions)))
+    for key, value in extensions.items():
+        # Get the extension and make sure it's available and writable
+        extension = Token.get_extension(key)
+        if not extension:  # Extension attribute doesn't exist
+            raise ValueError(Errors.E118.format(attr=key))
+        if not is_writable_attr(extension):
+            raise ValueError(Errors.E119.format(attr=key))
```
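One detail worth noting: the assignment loops above go through the regular underscore API, so writing an extension value during retokenization behaves like a normal `._.` assignment. A tiny sketch (the extension name is illustrative):

```python
from spacy.tokens import Doc, Token
from spacy.vocab import Vocab

Token.set_extension("is_musician", default=False)
doc = Doc(Vocab(), words=["hello", "world"])

doc[0]._.set("is_musician", True)   # equivalent to: doc[0]._.is_musician = True
assert doc[0]._.is_musician is True
```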
Extension attribute helpers (`underscore.py`):

```diff
@@ -75,3 +75,18 @@ def get_ext_args(**kwargs):
     if method is not None and not hasattr(method, "__call__"):
         raise ValueError(Errors.E091.format(name="method", value=repr(method)))
     return (default, method, getter, setter)
+
+
+def is_writable_attr(ext):
+    """Check if an extension attribute is writable.
+    ext (tuple): The (default, method, getter, setter) tuple available via
+        {Doc,Span,Token}.get_extension.
+    RETURNS (bool): Whether the attribute is writable.
+    """
+    default, method, getter, setter = ext
+    # Extension is writable if it has a setter (getter + setter), if it has a
+    # default value (or, if its default value is none, none of the other values
+    # should be set).
+    if setter is not None or default is not None or all(e is None for e in ext):
+        return True
+    return False
```
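A quick sketch of how `is_writable_attr` behaves, assuming the `(default, method, getter, setter)` tuple returned by `Token.get_extension` and that the helper is importable from `spacy.tokens.underscore` once this change is applied:

```python
from spacy.tokens import Token
from spacy.tokens.underscore import is_writable_attr

Token.set_extension("with_default", default=False)
Token.set_extension("getter_only", getter=lambda token: token.text.lower())

assert is_writable_attr(Token.get_extension("with_default")) is True   # has a default -> writable
assert is_writable_attr(Token.get_extension("getter_only")) is False   # getter only -> read-only
```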
API docs for `Retokenizer.merge`:

```diff
@@ -403,6 +403,8 @@ invalidated, although they may accidentally continue to work.
 ### Retokenizer.merge {#retokenizer.merge tag="method"}
 
 Mark a span for merging. The `attrs` will be applied to the resulting token.
+Writable custom extension attributes can be provided as a dictionary mapping
+attribute names to values as the `"_"` key.
 
 > #### Example
 >
```
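For context, a minimal sketch of the documented `"_"` usage for merging (the extension name is illustrative):

```python
from spacy.tokens import Doc, Token
from spacy.vocab import Vocab

Token.set_extension("is_musician", default=False)
doc = Doc(Vocab(), words=["I", "like", "David", "Bowie"])

with doc.retokenize() as retokenizer:
    attrs = {"LEMMA": "David Bowie", "_": {"is_musician": True}}
    retokenizer.merge(doc[2:4], attrs=attrs)

assert doc[2].text == "David Bowie"
assert doc[2]._.is_musician
```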
API docs for `Retokenizer.split`:

```diff
@@ -428,7 +430,8 @@ subtoken index. For example, `(doc[3], 1)` will attach the subtoken to the
 second subtoken of `doc[3]`. This mechanism allows attaching subtokens to other
 newly created subtokens, without having to keep track of the changing token
 indices. If the specified head token will be split within the retokenizer block
-and no subtoken index is specified, it will default to `0`.
+and no subtoken index is specified, it will default to `0`. Attributes to set on
+subtokens can be provided as a list of values.
 
 > #### Example
 >
```
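Similarly, a short sketch of providing per-subtoken extension values to `Retokenizer.split`, mirroring the new split tests (the extension name is illustrative):

```python
from spacy.tokens import Doc, Token
from spacy.vocab import Vocab

Token.set_extension("subtoken_label", default="")
doc = Doc(Vocab(), words=["LosAngeles", "start"])

with doc.retokenize() as retokenizer:
    heads = [(doc[0], 1), doc[1]]
    attrs = {"lemma": ["los", "angeles"],
             "_": [{"subtoken_label": "first"}, {"subtoken_label": "second"}]}
    retokenizer.split(doc[0], ["Los", "Angeles"], heads, attrs=attrs)

assert doc[0]._.subtoken_label == "first"
assert doc[1]._.subtoken_label == "second"
```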
Usage docs on retokenization:

````diff
@@ -1083,6 +1083,55 @@ with doc.retokenize() as retokenizer:
 
 </Infobox>
 
+### Overwriting custom extension attributes {#retokenization-extensions}
+
+If you've registered custom
+[extension attributes](/usage/processing-pipelines#custom-components-attributes),
+you can overwrite them during retokenization by providing a dictionary of
+attribute names mapped to new values as the `"_"` key in the `attrs`. For
+merging, you need to provide one dictionary of attributes for the resulting
+merged token. For splitting, you need to provide a list of dictionaries with
+custom attributes, one per split subtoken.
+
+<Infobox title="Important note" variant="warning">
+
+To set extension attributes during retokenization, the attributes need to be
+**registered** using the [`Token.set_extension`](/api/token#set_extension)
+method and they need to be **writable**. This means that they should either have
+a default value that can be overwritten, or a getter _and_ setter. Method
+extensions or extensions with only a getter are computed dynamically, so their
+values can't be overwritten. For more details, see the
+[extension attribute docs](/usage/processing-pipelines#custom-components-attributes).
+
+</Infobox>
+
+> #### ✏️ Things to try
+>
+> 1. Add another custom extension – maybe `"music_style"`? – and overwrite it.
+> 2. Change the extension attribute to use only a `getter` function. You should
+>    see that spaCy raises an error, because the attribute is not writable
+>    anymore.
+> 3. Rewrite the code to split a token with `retokenizer.split`. Remember that
+>    you need to provide a list of extension attribute values as the `"_"`
+>    property, one for each split subtoken.
+
+```python
+### {executable="true"}
+import spacy
+from spacy.tokens import Token
+
+# Register a custom token attribute, token._.is_musician
+Token.set_extension("is_musician", default=False)
+
+nlp = spacy.load("en_core_web_sm")
+doc = nlp("I like David Bowie")
+print("Before:", [(token.text, token._.is_musician) for token in doc])
+
+with doc.retokenize() as retokenizer:
+    retokenizer.merge(doc[2:4], attrs={"_": {"is_musician": True}})
+print("After:", [(token.text, token._.is_musician) for token in doc])
+```
+
 ## Sentence Segmentation {#sbd}
 
 A [`Doc`](/api/doc) object's sentences are available via the `Doc.sents`
````
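As a follow-up to item 3 in the "Things to try" list above, a rough sketch (not part of the commit) of the split variant, using a blank pipeline and a single made-up token `"DavidBowie"` instead of `en_core_web_sm`:

```python
import spacy
from spacy.tokens import Doc, Token

Token.set_extension("is_musician", default=False)

nlp = spacy.blank("en")
doc = Doc(nlp.vocab, words=["I", "like", "DavidBowie"])

with doc.retokenize() as retokenizer:
    # First subtoken attaches to the second subtoken, which attaches to "like"
    heads = [(doc[2], 1), doc[1]]
    attrs = {"_": [{"is_musician": True}, {"is_musician": True}]}
    retokenizer.split(doc[2], ["David", "Bowie"], heads, attrs=attrs)

print([(token.text, token._.is_musician) for token in doc])
```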