Add input length error, to address #1826

This commit is contained in:
Matthew Honnibal 2018-03-29 21:45:26 +02:00
parent cca7e7ad11
commit 23afa6429f

View File

@ -111,7 +111,7 @@ class Language(object):
'merge_entities': lambda nlp, **cfg: merge_entities
}
def __init__(self, vocab=True, make_doc=True, meta={}, **kwargs):
def __init__(self, vocab=True, make_doc=True, max_length=10**6, meta={}, **kwargs):
"""Initialise a Language object.
vocab (Vocab): A `Vocab` object. If `True`, a vocab is created via
@ -126,6 +126,15 @@ class Language(object):
string occurs in both, the component is not loaded.
meta (dict): Custom meta data for the Language class. Is written to by
models to add model meta data.
max_length (int) :
Maximum number of characters in a single text. The current v2 models
may run out of memory on extremely long texts, due to large internal
allocations. You should segment these texts into meaningful units,
e.g. paragraphs, subsections etc, before passing them to spaCy.
Default maximum length is 1,000,000 characters (1mb). As a rule of
thumb, if all pipeline components are enabled, spaCy's default
models currently require roughly 1GB of temporary memory per
100,000 characters in one text.
RETURNS (Language): The newly constructed object.
"""
self._meta = dict(meta)
@ -141,6 +150,7 @@ class Language(object):
make_doc = factory(self, **meta.get('tokenizer', {}))
self.tokenizer = make_doc
self.pipeline = []
self.max_length = max_length
self._optimizer = None
@property
@ -340,6 +350,17 @@ class Language(object):
>>> tokens[0].text, tokens[0].head.tag_
('An', 'NN')
"""
if len(text) >= self.max_length:
msg = (
"Text of length {length} exceeds maximum of {max_length}. "
"The v2 parser and NER models require roughly 1GB of temporary "
"memory per 100,000 characters in the input. This means long "
"texts may cause memory allocation errors. If you're not using "
"the parser or NER, it's probably safe to increase the "
"nlp.max_length limit. The limit is in number of characters, "
"so you can check whether your inputs are too long by checking "
"len(text)".)
raise ValueError(msg.format(length=len(text), max_length=self.max_length))
doc = self.make_doc(text)
for name, proc in self.pipeline:
if name in disable: