Mirror of https://github.com/explosion/spaCy.git (synced 2025-07-18 12:12:20 +03:00)
Add doc_cleaner component (#9659)
* Add doc_cleaner component
* Fix types
* Fix loop
* Rephrase method description
This commit is contained in:
parent a77f50baa4
commit 9ac6d4991e
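For context, a minimal usage sketch of the new component (not part of the diff; assumes spaCy 3.2.1+ with the trained `en_core_web_trf` pipeline installed):

```python
import spacy

# Sketch: run doc_cleaner at the end of a transformer pipeline so the large
# transformer output is dropped once the whole pipeline has finished.
nlp = spacy.load("en_core_web_trf")
nlp.add_pipe("doc_cleaner")  # defaults: {"tensor": None, "_.trf_data": None}

for doc in nlp.pipe(["First text.", "Second text."]):
    assert doc._.trf_data is None  # transformer output no longer held in memory
```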
spacy/errors.py
@@ -191,6 +191,7 @@ class Warnings(metaclass=ErrorsWithCodes):
             "lead to errors.")
     W115 = ("Skipping {method}: the floret vector table cannot be modified. "
             "Vectors are calculated from character ngrams.")
+    W116 = ("Unable to clean attribute '{attr}'.")
 
 
 class Errors(metaclass=ErrorsWithCodes):
spacy/pipeline/functions.py
@@ -1,6 +1,8 @@
 from typing import Dict, Any
 import srsly
+import warnings
 
+from ..errors import Warnings
 from ..language import Language
 from ..matcher import Matcher
 from ..tokens import Doc
@@ -136,3 +138,65 @@ class TokenSplitter:
             "cfg": lambda p: self._set_config(srsly.read_json(p)),
         }
         util.from_disk(path, serializers, [])
+
+
+@Language.factory(
+    "doc_cleaner",
+    default_config={"attrs": {"tensor": None, "_.trf_data": None}, "silent": True},
+)
+def make_doc_cleaner(nlp: Language, name: str, *, attrs: Dict[str, Any], silent: bool):
+    return DocCleaner(attrs, silent=silent)
+
+
+class DocCleaner:
+    def __init__(self, attrs: Dict[str, Any], *, silent: bool = True):
+        self.cfg: Dict[str, Any] = {"attrs": dict(attrs), "silent": silent}
+
+    def __call__(self, doc: Doc) -> Doc:
+        attrs: dict = self.cfg["attrs"]
+        silent: bool = self.cfg["silent"]
+        for attr, value in attrs.items():
+            obj = doc
+            parts = attr.split(".")
+            skip = False
+            for part in parts[:-1]:
+                if hasattr(obj, part):
+                    obj = getattr(obj, part)
+                else:
+                    skip = True
+                    if not silent:
+                        warnings.warn(Warnings.W116.format(attr=attr))
+            if not skip:
+                if hasattr(obj, parts[-1]):
+                    setattr(obj, parts[-1], value)
+                else:
+                    if not silent:
+                        warnings.warn(Warnings.W116.format(attr=attr))
+        return doc
+
+    def to_bytes(self, **kwargs):
+        serializers = {
+            "cfg": lambda: srsly.json_dumps(self.cfg),
+        }
+        return util.to_bytes(serializers, [])
+
+    def from_bytes(self, data, **kwargs):
+        deserializers = {
+            "cfg": lambda b: self.cfg.update(srsly.json_loads(b)),
+        }
+        util.from_bytes(data, deserializers, [])
+        return self
+
+    def to_disk(self, path, **kwargs):
+        path = util.ensure_path(path)
+        serializers = {
+            "cfg": lambda p: srsly.write_json(p, self.cfg),
+        }
+        return util.to_disk(path, serializers, [])
+
+    def from_disk(self, path, **kwargs):
+        path = util.ensure_path(path)
+        serializers = {
+            "cfg": lambda p: self.cfg.update(srsly.read_json(p)),
+        }
+        util.from_disk(path, serializers, [])
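The dotted-attribute handling in `__call__` above is easier to see in isolation. A hypothetical walk-through (not part of the commit; `my_attr` is an invented extension name used only for illustration):

```python
import spacy
from spacy.tokens import Doc

# Register a throwaway extension purely for demonstration.
Doc.set_extension("my_attr", default="large cached value")
doc = spacy.blank("en")("some text")

# This mirrors what DocCleaner.__call__ does for the key "_.my_attr":
attr, value = "_.my_attr", None
parts = attr.split(".")         # ["_", "my_attr"]
obj = doc                       # start from the Doc itself
for part in parts[:-1]:         # walk the intermediate attributes: doc._
    obj = getattr(obj, part)
setattr(obj, parts[-1], value)  # i.e. doc._.my_attr = None

assert doc._.my_attr is None
```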
spacy/tests/pipeline/test_functions.py
@@ -3,6 +3,8 @@ from spacy.pipeline.functions import merge_subtokens
 from spacy.language import Language
 from spacy.tokens import Span, Doc
 
+from ..doc.test_underscore import clean_underscore  # noqa: F401
+
 
 @pytest.fixture
 def doc(en_vocab):
@@ -74,3 +76,26 @@ def test_token_splitter():
         "i",
     ]
     assert all(len(t.text) <= token_splitter.split_length for t in doc)
+
+
+@pytest.mark.usefixtures("clean_underscore")
+def test_factories_doc_cleaner():
+    nlp = Language()
+    nlp.add_pipe("doc_cleaner")
+    doc = nlp.make_doc("text")
+    doc.tensor = [1, 2, 3]
+    doc = nlp(doc)
+    assert doc.tensor is None
+
+    nlp = Language()
+    nlp.add_pipe("doc_cleaner", config={"silent": False})
+    with pytest.warns(UserWarning):
+        doc = nlp("text")
+
+    Doc.set_extension("test_attr", default=-1)
+    nlp = Language()
+    nlp.add_pipe("doc_cleaner", config={"attrs": {"_.test_attr": 0}})
+    doc = nlp.make_doc("text")
+    doc._.test_attr = 100
+    doc = nlp(doc)
+    assert doc._.test_attr == 0
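The tests above exercise `__call__` but not the serialization methods. A minimal round-trip sketch (an assumption on my part, not part of the commit, using the `DocCleaner` class defined in `spacy/pipeline/functions.py`):

```python
from spacy.pipeline.functions import DocCleaner

# Sketch: the component config survives a to_bytes/from_bytes round trip.
cleaner = DocCleaner({"tensor": None}, silent=False)
data = cleaner.to_bytes()

restored = DocCleaner({}).from_bytes(data)
assert restored.cfg == {"attrs": {"tensor": None}, "silent": False}
```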
website/docs/api/pipeline-functions.md
@@ -130,3 +130,25 @@ exceed the transformer model max length.
 | `min_length`   | The minimum length for a token to be split. Defaults to `25`. ~~int~~ |
 | `split_length` | The length of the split tokens. Defaults to `5`. ~~int~~              |
 | **RETURNS**    | The modified `Doc` with the split tokens. ~~Doc~~                     |
+
+## doc_cleaner {#doc_cleaner tag="function" new="3.2.1"}
+
+Clean up `Doc` attributes. Intended for use at the end of pipelines with
+`tok2vec` or `transformer` pipeline components that store tensors and other
+values that can require a lot of memory and frequently aren't needed after the
+whole pipeline has run.
+
+> #### Example
+>
+> ```python
+> config = {"attrs": {"tensor": None}}
+> nlp.add_pipe("doc_cleaner", config=config)
+> doc = nlp("text")
+> assert doc.tensor is None
+> ```
+
+| Setting     | Description                                                                                                                                                                             |
+| ----------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `attrs`     | A dict of the `Doc` attributes and the values to set them to. Defaults to `{"tensor": None, "_.trf_data": None}` to clean up after `tok2vec` and `transformer` components. ~~dict~~    |
+| `silent`    | If `False`, show warnings if attributes aren't found or can't be set. Defaults to `True`. ~~bool~~                                                                                      |
+| **RETURNS** | The modified `Doc` with the modified attributes. ~~Doc~~                                                                                                                                |