💫 Allow passing of config parameters to specific pipeline components (#3386)

* Add component_cfg kwarg to begin_training

* Document component_cfg arg to begin_training

* Update docs and auto-format

* Support component_cfg across Language

* Format

* Update docs and docstrings [ci skip]

* Fix begin_training
This commit is contained in:
Matthew Honnibal 2019-03-10 23:36:47 +01:00 committed by Ines Montani
parent 8dbf1e9037
commit 98acf5ffe4
2 changed files with 75 additions and 35 deletions

View File

@ -106,6 +106,7 @@ class Language(object):
DOCS: https://spacy.io/api/language DOCS: https://spacy.io/api/language
""" """
Defaults = BaseDefaults Defaults = BaseDefaults
lang = None lang = None
@ -344,13 +345,15 @@ class Language(object):
raise ValueError(Errors.E001.format(name=name, opts=self.pipe_names)) raise ValueError(Errors.E001.format(name=name, opts=self.pipe_names))
return self.pipeline.pop(self.pipe_names.index(name)) return self.pipeline.pop(self.pipe_names.index(name))
def __call__(self, text, disable=[]): def __call__(self, text, disable=[], component_cfg=None):
"""Apply the pipeline to some text. The text can span multiple sentences, """Apply the pipeline to some text. The text can span multiple sentences,
and can contain arbitrary whitespace. Alignment into the original string and can contain arbitrary whitespace. Alignment into the original string
is preserved. is preserved.
text (unicode): The text to be processed. text (unicode): The text to be processed.
disable (list): Names of the pipeline components to disable. disable (list): Names of the pipeline components to disable.
component_cfg (dict): An optional dictionary with extra keyword arguments
for specific components.
RETURNS (Doc): A container for accessing the annotations. RETURNS (Doc): A container for accessing the annotations.
EXAMPLE: EXAMPLE:
@ -363,12 +366,14 @@ class Language(object):
Errors.E088.format(length=len(text), max_length=self.max_length) Errors.E088.format(length=len(text), max_length=self.max_length)
) )
doc = self.make_doc(text) doc = self.make_doc(text)
if component_cfg is None:
component_cfg = {}
for name, proc in self.pipeline: for name, proc in self.pipeline:
if name in disable: if name in disable:
continue continue
if not hasattr(proc, "__call__"): if not hasattr(proc, "__call__"):
raise ValueError(Errors.E003.format(component=type(proc), name=name)) raise ValueError(Errors.E003.format(component=type(proc), name=name))
doc = proc(doc) doc = proc(doc, **component_cfg.get(name, {}))
if doc is None: if doc is None:
raise ValueError(Errors.E005.format(name=name)) raise ValueError(Errors.E005.format(name=name))
return doc return doc
@ -396,7 +401,7 @@ class Language(object):
def make_doc(self, text): def make_doc(self, text):
return self.tokenizer(text) return self.tokenizer(text)
def update(self, docs, golds, drop=0.0, sgd=None, losses=None): def update(self, docs, golds, drop=0.0, sgd=None, losses=None, component_cfg=None):
"""Update the models in the pipeline. """Update the models in the pipeline.
docs (iterable): A batch of `Doc` objects. docs (iterable): A batch of `Doc` objects.
@ -443,11 +448,15 @@ class Language(object):
pipes = list(self.pipeline) pipes = list(self.pipeline)
random.shuffle(pipes) random.shuffle(pipes)
if component_cfg is None:
component_cfg = {}
for name, proc in pipes: for name, proc in pipes:
if not hasattr(proc, "update"): if not hasattr(proc, "update"):
continue continue
grads = {} grads = {}
proc.update(docs, golds, drop=drop, sgd=get_grads, losses=losses) kwargs = component_cfg.get(name, {})
kwargs.setdefault("drop", drop)
proc.update(docs, golds, sgd=get_grads, losses=losses, **kwargs)
for key, (W, dW) in grads.items(): for key, (W, dW) in grads.items():
sgd(W, dW, key=key) sgd(W, dW, key=key)
@ -517,11 +526,12 @@ class Language(object):
for doc, gold in docs_golds: for doc, gold in docs_golds:
yield doc, gold yield doc, gold
def begin_training(self, get_gold_tuples=None, sgd=None, **cfg): def begin_training(self, get_gold_tuples=None, sgd=None, component_cfg=None, **cfg):
"""Allocate models, pre-process training data and acquire a trainer and """Allocate models, pre-process training data and acquire a trainer and
optimizer. Used as a contextmanager. optimizer. Used as a contextmanager.
get_gold_tuples (function): Function returning gold data get_gold_tuples (function): Function returning gold data
component_cfg (dict): Config parameters for specific components.
**cfg: Config parameters. **cfg: Config parameters.
RETURNS: An optimizer RETURNS: An optimizer
""" """
@ -543,10 +553,17 @@ class Language(object):
if sgd is None: if sgd is None:
sgd = create_default_optimizer(Model.ops) sgd = create_default_optimizer(Model.ops)
self._optimizer = sgd self._optimizer = sgd
if component_cfg is None:
component_cfg = {}
for name, proc in self.pipeline: for name, proc in self.pipeline:
if hasattr(proc, "begin_training"): if hasattr(proc, "begin_training"):
kwargs = component_cfg.get(name, {})
kwargs.update(cfg)
proc.begin_training( proc.begin_training(
get_gold_tuples, pipeline=self.pipeline, sgd=self._optimizer, **cfg get_gold_tuples,
pipeline=self.pipeline,
sgd=self._optimizer,
**kwargs
) )
return self._optimizer return self._optimizer
@ -574,20 +591,27 @@ class Language(object):
proc._rehearsal_model = deepcopy(proc.model) proc._rehearsal_model = deepcopy(proc.model)
return self._optimizer return self._optimizer
def evaluate(self, docs_golds, verbose=False, batch_size=256): def evaluate(
scorer = Scorer() self, docs_golds, verbose=False, batch_size=256, scorer=None, component_cfg=None
):
if scorer is None:
scorer = Scorer()
docs, golds = zip(*docs_golds) docs, golds = zip(*docs_golds)
docs = list(docs) docs = list(docs)
golds = list(golds) golds = list(golds)
for name, pipe in self.pipeline: for name, pipe in self.pipeline:
kwargs = component_cfg.get(name, {})
kwargs.setdefault("batch_size", batch_size)
if not hasattr(pipe, "pipe"): if not hasattr(pipe, "pipe"):
docs = (pipe(doc) for doc in docs) docs = (pipe(doc, **kwargs) for doc in docs)
else: else:
docs = pipe.pipe(docs, batch_size=batch_size) docs = pipe.pipe(docs, **kwargs)
for doc, gold in zip(docs, golds): for doc, gold in zip(docs, golds):
if verbose: if verbose:
print(doc) print(doc)
scorer.score(doc, gold, verbose=verbose) kwargs = component_cfg.get("scorer", {})
kwargs.setdefault("verbose", verbose)
scorer.score(doc, gold, **kwargs)
return scorer return scorer
@contextmanager @contextmanager
@ -630,6 +654,7 @@ class Language(object):
batch_size=1000, batch_size=1000,
disable=[], disable=[],
cleanup=False, cleanup=False,
component_cfg=None,
): ):
"""Process texts as a stream, and yield `Doc` objects in order. """Process texts as a stream, and yield `Doc` objects in order.
@ -643,6 +668,8 @@ class Language(object):
disable (list): Names of the pipeline components to disable. disable (list): Names of the pipeline components to disable.
cleanup (bool): If True, unneeded strings are freed, cleanup (bool): If True, unneeded strings are freed,
to control memory use. Experimental. to control memory use. Experimental.
component_cfg (dict): An optional dictionary with extra keyword arguments
for specific components.
YIELDS (Doc): Documents in the order of the original text. YIELDS (Doc): Documents in the order of the original text.
EXAMPLE: EXAMPLE:
@ -655,20 +682,30 @@ class Language(object):
texts = (tc[0] for tc in text_context1) texts = (tc[0] for tc in text_context1)
contexts = (tc[1] for tc in text_context2) contexts = (tc[1] for tc in text_context2)
docs = self.pipe( docs = self.pipe(
texts, n_threads=n_threads, batch_size=batch_size, disable=disable texts,
n_threads=n_threads,
batch_size=batch_size,
disable=disable,
component_cfg=component_cfg,
) )
for doc, context in izip(docs, contexts): for doc, context in izip(docs, contexts):
yield (doc, context) yield (doc, context)
return return
docs = (self.make_doc(text) for text in texts) docs = (self.make_doc(text) for text in texts)
if component_cfg is None:
component_cfg = {}
for name, proc in self.pipeline: for name, proc in self.pipeline:
if name in disable: if name in disable:
continue continue
kwargs = component_cfg.get(name, {})
# Allow component_cfg to overwrite the top-level kwargs.
kwargs.setdefault("batch_size", batch_size)
kwargs.setdefault("n_threads", n_threads)
if hasattr(proc, "pipe"): if hasattr(proc, "pipe"):
docs = proc.pipe(docs, n_threads=n_threads, batch_size=batch_size) docs = proc.pipe(docs, **kwargs)
else: else:
# Apply the function, but yield the doc # Apply the function, but yield the doc
docs = _pipe(proc, docs) docs = _pipe(proc, docs, kwargs)
# Track weakrefs of "recent" documents, so that we can see when they # Track weakrefs of "recent" documents, so that we can see when they
# expire from memory. When they do, we know we don't need old strings. # expire from memory. When they do, we know we don't need old strings.
# This way, we avoid maintaining an unbounded growth in string entries # This way, we avoid maintaining an unbounded growth in string entries
@ -861,7 +898,7 @@ class DisabledPipes(list):
self[:] = [] self[:] = []
def _pipe(func, docs): def _pipe(func, docs, kwargs):
for doc in docs: for doc in docs:
doc = func(doc) doc = func(doc, **kwargs)
yield doc yield doc

View File

@ -91,13 +91,14 @@ multiprocessing.
> assert doc.is_parsed > assert doc.is_parsed
> ``` > ```
| Name | Type | Description | | Name | Type | Description |
| ------------ | ----- | ---------------------------------------------------------------------------------------------------------------------------------------------------------- | | -------------------------------------------- | ----- | ---------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `texts` | - | A sequence of unicode objects. | | `texts` | - | A sequence of unicode objects. |
| `as_tuples` | bool | If set to `True`, inputs should be a sequence of `(text, context)` tuples. Output will then be a sequence of `(doc, context)` tuples. Defaults to `False`. | | `as_tuples` | bool | If set to `True`, inputs should be a sequence of `(text, context)` tuples. Output will then be a sequence of `(doc, context)` tuples. Defaults to `False`. |
| `batch_size` | int | The number of texts to buffer. | | `batch_size` | int | The number of texts to buffer. |
| `disable` | list | Names of pipeline components to [disable](/usage/processing-pipelines#disabling). | | `disable` | list | Names of pipeline components to [disable](/usage/processing-pipelines#disabling). |
| **YIELDS** | `Doc` | Documents in the order of the original text. | | `component_cfg` <Tag variant="new">2.1</Tag> | dict | Config parameters for specific pipeline components, keyed by component name. |
| **YIELDS** | `Doc` | Documents in the order of the original text. |
## Language.update {#update tag="method"} ## Language.update {#update tag="method"}
@ -112,13 +113,14 @@ Update the models in the pipeline.
> nlp.update([doc], [gold], drop=0.5, sgd=optimizer) > nlp.update([doc], [gold], drop=0.5, sgd=optimizer)
> ``` > ```
| Name | Type | Description | | Name | Type | Description |
| ----------- | -------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | | -------------------------------------------- | -------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `docs` | iterable | A batch of `Doc` objects or unicode. If unicode, a `Doc` object will be created from the text. | | `docs` | iterable | A batch of `Doc` objects or unicode. If unicode, a `Doc` object will be created from the text. |
| `golds` | iterable | A batch of `GoldParse` objects or dictionaries. Dictionaries will be used to create [`GoldParse`](/api/goldparse) objects. For the available keys and their usage, see [`GoldParse.__init__`](/api/goldparse#init). | | `golds` | iterable | A batch of `GoldParse` objects or dictionaries. Dictionaries will be used to create [`GoldParse`](/api/goldparse) objects. For the available keys and their usage, see [`GoldParse.__init__`](/api/goldparse#init). |
| `drop` | float | The dropout rate. | | `drop` | float | The dropout rate. |
| `sgd` | callable | An optimizer. | | `sgd` | callable | An optimizer. |
| **RETURNS** | dict | Results from the update. | | `component_cfg` <Tag variant="new">2.1</Tag> | dict | Config parameters for specific pipeline components, keyed by component name. |
| **RETURNS** | dict | Results from the update. |
## Language.begin_training {#begin_training tag="method"} ## Language.begin_training {#begin_training tag="method"}
@ -130,11 +132,12 @@ Allocate models, pre-process training data and acquire an optimizer.
> optimizer = nlp.begin_training(gold_tuples) > optimizer = nlp.begin_training(gold_tuples)
> ``` > ```
| Name | Type | Description | | Name | Type | Description |
| ------------- | -------- | ---------------------------- | | -------------------------------------------- | -------- | ---------------------------------------------------------------------------- |
| `gold_tuples` | iterable | Gold-standard training data. | | `gold_tuples` | iterable | Gold-standard training data. |
| `**cfg` | - | Config parameters. | | `component_cfg` <Tag variant="new">2.1</Tag> | dict | Config parameters for specific pipeline components, keyed by component name. |
| **RETURNS** | callable | An optimizer. | | `**cfg` | - | Config parameters (sent to all components). |
| **RETURNS** | callable | An optimizer. |
## Language.use_params {#use_params tag="contextmanager, method"} ## Language.use_params {#use_params tag="contextmanager, method"}