Add doc_cleaner component (#9659)

* Add doc_cleaner component

* Fix types

* Fix loop

* Rephrase method description
Adriane Boyd 2021-11-23 15:33:33 +01:00 committed by GitHub
parent a77f50baa4
commit 9ac6d4991e
4 changed files with 112 additions and 0 deletions


@@ -191,6 +191,7 @@ class Warnings(metaclass=ErrorsWithCodes):
            "lead to errors.")
    W115 = ("Skipping {method}: the floret vector table cannot be modified. "
            "Vectors are calculated from character ngrams.")
    W116 = ("Unable to clean attribute '{attr}'.")


class Errors(metaclass=ErrorsWithCodes):


@@ -1,6 +1,8 @@
from typing import Dict, Any
import srsly
import warnings

from ..errors import Warnings
from ..language import Language
from ..matcher import Matcher
from ..tokens import Doc
@@ -136,3 +138,65 @@ class TokenSplitter:
            "cfg": lambda p: self._set_config(srsly.read_json(p)),
        }
        util.from_disk(path, serializers, [])


@Language.factory(
    "doc_cleaner",
    default_config={"attrs": {"tensor": None, "_.trf_data": None}, "silent": True},
)
def make_doc_cleaner(nlp: Language, name: str, *, attrs: Dict[str, Any], silent: bool):
    return DocCleaner(attrs, silent=silent)


class DocCleaner:
    def __init__(self, attrs: Dict[str, Any], *, silent: bool = True):
        self.cfg: Dict[str, Any] = {"attrs": dict(attrs), "silent": silent}

    def __call__(self, doc: Doc) -> Doc:
        attrs: dict = self.cfg["attrs"]
        silent: bool = self.cfg["silent"]
        for attr, value in attrs.items():
            obj = doc
            parts = attr.split(".")
            skip = False
            for part in parts[:-1]:
                if hasattr(obj, part):
                    obj = getattr(obj, part)
                else:
                    skip = True
                    if not silent:
                        warnings.warn(Warnings.W116.format(attr=attr))
            if not skip:
                if hasattr(obj, parts[-1]):
                    setattr(obj, parts[-1], value)
                else:
                    if not silent:
                        warnings.warn(Warnings.W116.format(attr=attr))
        return doc

    def to_bytes(self, **kwargs):
        serializers = {
            "cfg": lambda: srsly.json_dumps(self.cfg),
        }
        return util.to_bytes(serializers, [])

    def from_bytes(self, data, **kwargs):
        deserializers = {
            "cfg": lambda b: self.cfg.update(srsly.json_loads(b)),
        }
        util.from_bytes(data, deserializers, [])
        return self

    def to_disk(self, path, **kwargs):
        path = util.ensure_path(path)
        serializers = {
            "cfg": lambda p: srsly.write_json(p, self.cfg),
        }
        return util.to_disk(path, serializers, [])

    def from_disk(self, path, **kwargs):
        path = util.ensure_path(path)
        serializers = {
            "cfg": lambda p: self.cfg.update(srsly.read_json(p)),
        }
        util.from_disk(path, serializers, [])
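Each key in `attrs` is a dotted path that `__call__` resolves one segment at a time with `getattr`, so `"_.trf_data"` walks from the `Doc` to its underscore namespace before the final segment is overwritten with `setattr`. A minimal standalone sketch of the same traversal (the `scrub` helper and the `_.cache` extension are illustrative only, not part of this commit):

```python
import spacy
from spacy.tokens import Doc

# Illustrative extension attribute; any registered extension behaves the same way.
Doc.set_extension("cache", default=None)


def scrub(doc: Doc, attr: str, value=None) -> Doc:
    """Resolve a dotted attribute path ("_.cache" -> doc._ -> cache) and reset it."""
    obj = doc
    *parents, leaf = attr.split(".")
    for part in parents:
        obj = getattr(obj, part)  # e.g. doc._ for extension attributes
    setattr(obj, leaf, value)  # overwrite the final attribute
    return doc


nlp = spacy.blank("en")
doc = nlp("text")
doc._.cache = {"big": "payload"}
scrub(doc, "_.cache")
assert doc._.cache is None
```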


@@ -3,6 +3,8 @@ from spacy.pipeline.functions import merge_subtokens
from spacy.language import Language
from spacy.tokens import Span, Doc

from ..doc.test_underscore import clean_underscore  # noqa: F401


@pytest.fixture
def doc(en_vocab):
@@ -74,3 +76,26 @@ def test_token_splitter():
        "i",
    ]
    assert all(len(t.text) <= token_splitter.split_length for t in doc)


@pytest.mark.usefixtures("clean_underscore")
def test_factories_doc_cleaner():
    nlp = Language()
    nlp.add_pipe("doc_cleaner")
    doc = nlp.make_doc("text")
    doc.tensor = [1, 2, 3]
    doc = nlp(doc)
    assert doc.tensor is None

    nlp = Language()
    nlp.add_pipe("doc_cleaner", config={"silent": False})
    with pytest.warns(UserWarning):
        doc = nlp("text")

    Doc.set_extension("test_attr", default=-1)
    nlp = Language()
    nlp.add_pipe("doc_cleaner", config={"attrs": {"_.test_attr": 0}})
    doc = nlp.make_doc("text")
    doc._.test_attr = 100
    doc = nlp(doc)
    assert doc._.test_attr == 0
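Because the component keeps all of its state in `self.cfg`, serialization is just a JSON round trip of that dict. A quick sketch of what that looks like in practice, using a blank `Language` pipeline as in the tests above:

```python
from spacy.language import Language

nlp = Language()
cleaner = nlp.add_pipe("doc_cleaner", config={"attrs": {"tensor": None}, "silent": False})

# to_bytes()/from_bytes() only carry the cfg dict, so a fresh component
# picks up the same attrs and silent settings.
data = cleaner.to_bytes()
other = Language().add_pipe("doc_cleaner")
other.from_bytes(data)
assert other.cfg["attrs"] == {"tensor": None}
assert other.cfg["silent"] is False
```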


@@ -130,3 +130,25 @@ exceed the transformer model max length.
| `min_length`   | The minimum length for a token to be split. Defaults to `25`. ~~int~~ |
| `split_length` | The length of the split tokens. Defaults to `5`. ~~int~~              |
| **RETURNS**    | The modified `Doc` with the split tokens. ~~Doc~~                     |

## doc_cleaner {#doc_cleaner tag="function" new="3.2.1"}
Clean up `Doc` attributes. Intended for use at the end of pipelines with
`tok2vec` or `transformer` pipeline components that store tensors and other
values that can require a lot of memory and frequently aren't needed after the
whole pipeline has run.
> #### Example
>
> ```python
> config = {"attrs": {"tensor": None}}
> nlp.add_pipe("doc_cleaner", config=config)
> doc = nlp("text")
> assert doc.tensor is None
> ```
| Setting | Description |
| ----------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `attrs` | A dict of the `Doc` attributes and the values to set them to. Defaults to `{"tensor": None, "_.trf_data": None}` to clean up after `tok2vec` and `transformer` components. ~~dict~~ |
| `silent` | If `False`, show warnings if attributes aren't found or can't be set. Defaults to `True`. ~~bool~~ |
| **RETURNS** | The modified `Doc` with the modified attributes. ~~Doc~~ |
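
Keys prefixed with `_.` work for custom extension attributes as well, since they are resolved through the `Doc._` namespace. A short example of that usage (the `_.my_cache` extension is purely illustrative and not defined by the component):

```python
import spacy
from spacy.tokens import Doc

# Purely illustrative extension attribute; not defined by the component.
Doc.set_extension("my_cache", default=None)

nlp = spacy.blank("en")
nlp.add_pipe("doc_cleaner", config={"attrs": {"_.my_cache": None}})

doc = nlp.make_doc("text")
doc._.my_cache = {"large": "payload"}  # stands in for a value set by an earlier component
doc = nlp(doc)  # running the pipeline resets the attribute
assert doc._.my_cache is None
```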