diff --git a/spacy/language.py b/spacy/language.py
index 6fb30e46d..44a819132 100644
--- a/spacy/language.py
+++ b/spacy/language.py
@@ -106,6 +106,7 @@ class Language(object):
DOCS: https://spacy.io/api/language
"""
+
Defaults = BaseDefaults
lang = None
@@ -344,13 +345,15 @@ class Language(object):
raise ValueError(Errors.E001.format(name=name, opts=self.pipe_names))
return self.pipeline.pop(self.pipe_names.index(name))
- def __call__(self, text, disable=[]):
+ def __call__(self, text, disable=[], component_cfg=None):
"""Apply the pipeline to some text. The text can span multiple sentences,
and can contain arbtrary whitespace. Alignment into the original string
is preserved.
text (unicode): The text to be processed.
disable (list): Names of the pipeline components to disable.
+ component_cfg (dict): An optional dictionary with extra keyword arguments
+ for specific components.
RETURNS (Doc): A container for accessing the annotations.
EXAMPLE:
@@ -363,12 +366,14 @@ class Language(object):
Errors.E088.format(length=len(text), max_length=self.max_length)
)
doc = self.make_doc(text)
+ if component_cfg is None:
+ component_cfg = {}
for name, proc in self.pipeline:
if name in disable:
continue
if not hasattr(proc, "__call__"):
raise ValueError(Errors.E003.format(component=type(proc), name=name))
- doc = proc(doc)
+ doc = proc(doc, **component_cfg.get(name, {}))
if doc is None:
raise ValueError(Errors.E005.format(name=name))
return doc
@@ -396,7 +401,7 @@ class Language(object):
def make_doc(self, text):
return self.tokenizer(text)
- def update(self, docs, golds, drop=0.0, sgd=None, losses=None):
+ def update(self, docs, golds, drop=0.0, sgd=None, losses=None, component_cfg=None):
"""Update the models in the pipeline.
docs (iterable): A batch of `Doc` objects.
@@ -443,11 +448,15 @@ class Language(object):
pipes = list(self.pipeline)
random.shuffle(pipes)
+ if component_cfg is None:
+ component_cfg = {}
for name, proc in pipes:
if not hasattr(proc, "update"):
continue
grads = {}
- proc.update(docs, golds, drop=drop, sgd=get_grads, losses=losses)
+            kwargs = dict(component_cfg.get(name, {}))
+ kwargs.setdefault("drop", drop)
+ proc.update(docs, golds, sgd=get_grads, losses=losses, **kwargs)
for key, (W, dW) in grads.items():
sgd(W, dW, key=key)
@@ -517,11 +526,12 @@ class Language(object):
for doc, gold in docs_golds:
yield doc, gold
- def begin_training(self, get_gold_tuples=None, sgd=None, **cfg):
+ def begin_training(self, get_gold_tuples=None, sgd=None, component_cfg=None, **cfg):
"""Allocate models, pre-process training data and acquire a trainer and
optimizer. Used as a contextmanager.
get_gold_tuples (function): Function returning gold data
+ component_cfg (dict): Config parameters for specific components.
**cfg: Config parameters.
RETURNS: An optimizer
"""
@@ -543,10 +553,17 @@ class Language(object):
if sgd is None:
sgd = create_default_optimizer(Model.ops)
self._optimizer = sgd
+ if component_cfg is None:
+ component_cfg = {}
for name, proc in self.pipeline:
if hasattr(proc, "begin_training"):
+                kwargs = dict(component_cfg.get(name, {}))
+ kwargs.update(cfg)
proc.begin_training(
- get_gold_tuples, pipeline=self.pipeline, sgd=self._optimizer, **cfg
+ get_gold_tuples,
+ pipeline=self.pipeline,
+ sgd=self._optimizer,
+ **kwargs
)
return self._optimizer
@@ -574,20 +591,27 @@ class Language(object):
proc._rehearsal_model = deepcopy(proc.model)
return self._optimizer
- def evaluate(self, docs_golds, verbose=False, batch_size=256):
- scorer = Scorer()
+ def evaluate(
+        self, docs_golds, verbose=False, batch_size=256, scorer=None, component_cfg=None
+    ):
+ if scorer is None:
+ scorer = Scorer()
docs, golds = zip(*docs_golds)
docs = list(docs)
golds = list(golds)
for name, pipe in self.pipeline:
+            kwargs = dict(component_cfg.get(name, {})) if component_cfg else {}
+            kwargs.setdefault("batch_size", batch_size)
if not hasattr(pipe, "pipe"):
- docs = (pipe(doc) for doc in docs)
+ docs = (pipe(doc, **kwargs) for doc in docs)
else:
- docs = pipe.pipe(docs, batch_size=batch_size)
+ docs = pipe.pipe(docs, **kwargs)
for doc, gold in zip(docs, golds):
if verbose:
print(doc)
- scorer.score(doc, gold, verbose=verbose)
+            kwargs = dict(component_cfg.get("scorer", {})) if component_cfg else {}
+ kwargs.setdefault("verbose", verbose)
+ scorer.score(doc, gold, **kwargs)
return scorer
@contextmanager
@@ -630,6 +654,7 @@ class Language(object):
batch_size=1000,
disable=[],
cleanup=False,
+ component_cfg=None,
):
"""Process texts as a stream, and yield `Doc` objects in order.
@@ -643,6 +668,8 @@ class Language(object):
disable (list): Names of the pipeline components to disable.
cleanup (bool): If True, unneeded strings are freed,
to control memory use. Experimental.
+ component_cfg (dict): An optional dictionary with extra keyword arguments
+ for specific components.
YIELDS (Doc): Documents in the order of the original text.
EXAMPLE:
@@ -655,20 +682,30 @@ class Language(object):
texts = (tc[0] for tc in text_context1)
contexts = (tc[1] for tc in text_context2)
docs = self.pipe(
- texts, n_threads=n_threads, batch_size=batch_size, disable=disable
+ texts,
+ n_threads=n_threads,
+ batch_size=batch_size,
+ disable=disable,
+ component_cfg=component_cfg,
)
for doc, context in izip(docs, contexts):
yield (doc, context)
return
docs = (self.make_doc(text) for text in texts)
+ if component_cfg is None:
+ component_cfg = {}
for name, proc in self.pipeline:
if name in disable:
continue
+            kwargs = dict(component_cfg.get(name, {}))
+ # Allow component_cfg to overwrite the top-level kwargs.
+ kwargs.setdefault("batch_size", batch_size)
+ kwargs.setdefault("n_threads", n_threads)
if hasattr(proc, "pipe"):
- docs = proc.pipe(docs, n_threads=n_threads, batch_size=batch_size)
+ docs = proc.pipe(docs, **kwargs)
else:
# Apply the function, but yield the doc
- docs = _pipe(proc, docs)
+ docs = _pipe(proc, docs, kwargs)
# Track weakrefs of "recent" documents, so that we can see when they
# expire from memory. When they do, we know we don't need old strings.
# This way, we avoid maintaining an unbounded growth in string entries
@@ -861,7 +898,7 @@ class DisabledPipes(list):
self[:] = []
-def _pipe(func, docs):
+def _pipe(func, docs, kwargs):
for doc in docs:
- doc = func(doc)
+ doc = func(doc, **kwargs)
yield doc
diff --git a/website/docs/api/language.md b/website/docs/api/language.md
index 34d14ec01..a8598815b 100644
--- a/website/docs/api/language.md
+++ b/website/docs/api/language.md
@@ -91,13 +91,14 @@ multiprocessing.
> assert doc.is_parsed
> ```
-| Name | Type | Description |
-| ------------ | ----- | ---------------------------------------------------------------------------------------------------------------------------------------------------------- |
-| `texts` | - | A sequence of unicode objects. |
-| `as_tuples` | bool | If set to `True`, inputs should be a sequence of `(text, context)` tuples. Output will then be a sequence of `(doc, context)` tuples. Defaults to `False`. |
-| `batch_size` | int | The number of texts to buffer. |
-| `disable` | list | Names of pipeline components to [disable](/usage/processing-pipelines#disabling). |
-| **YIELDS** | `Doc` | Documents in the order of the original text. |
+| Name | Type | Description |
+| -------------------------------------------- | ----- | ---------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `texts` | - | A sequence of unicode objects. |
+| `as_tuples` | bool | If set to `True`, inputs should be a sequence of `(text, context)` tuples. Output will then be a sequence of `(doc, context)` tuples. Defaults to `False`. |
+| `batch_size` | int | The number of texts to buffer. |
+| `disable` | list | Names of pipeline components to [disable](/usage/processing-pipelines#disabling). |
+| `component_cfg` <Tag variant="new">2.1</Tag> | dict  | Config parameters for specific pipeline components, keyed by component name.                                                                                |
+| **YIELDS** | `Doc` | Documents in the order of the original text. |
## Language.update {#update tag="method"}
@@ -112,13 +113,14 @@ Update the models in the pipeline.
> nlp.update([doc], [gold], drop=0.5, sgd=optimizer)
> ```
-| Name | Type | Description |
-| ----------- | -------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
-| `docs` | iterable | A batch of `Doc` objects or unicode. If unicode, a `Doc` object will be created from the text. |
-| `golds` | iterable | A batch of `GoldParse` objects or dictionaries. Dictionaries will be used to create [`GoldParse`](/api/goldparse) objects. For the available keys and their usage, see [`GoldParse.__init__`](/api/goldparse#init). |
-| `drop` | float | The dropout rate. |
-| `sgd` | callable | An optimizer. |
-| **RETURNS** | dict | Results from the update. |
+| Name | Type | Description |
+| -------------------------------------------- | -------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `docs` | iterable | A batch of `Doc` objects or unicode. If unicode, a `Doc` object will be created from the text. |
+| `golds` | iterable | A batch of `GoldParse` objects or dictionaries. Dictionaries will be used to create [`GoldParse`](/api/goldparse) objects. For the available keys and their usage, see [`GoldParse.__init__`](/api/goldparse#init). |
+| `drop` | float | The dropout rate. |
+| `sgd` | callable | An optimizer. |
+| `component_cfg` <Tag variant="new">2.1</Tag> | dict     | Config parameters for specific pipeline components, keyed by component name.                                                                                                                                          |
+| **RETURNS** | dict | Results from the update. |
## Language.begin_training {#begin_training tag="method"}
@@ -130,11 +132,12 @@ Allocate models, pre-process training data and acquire an optimizer.
> optimizer = nlp.begin_training(gold_tuples)
> ```
-| Name | Type | Description |
-| ------------- | -------- | ---------------------------- |
-| `gold_tuples` | iterable | Gold-standard training data. |
-| `**cfg` | - | Config parameters. |
-| **RETURNS** | callable | An optimizer. |
+| Name | Type | Description |
+| -------------------------------------------- | -------- | ---------------------------------------------------------------------------- |
+| `gold_tuples` | iterable | Gold-standard training data. |
+| `component_cfg` <Tag variant="new">2.1</Tag> | dict     | Config parameters for specific pipeline components, keyed by component name. |
+| `**cfg` | - | Config parameters (sent to all components). |
+| **RETURNS** | callable | An optimizer. |
## Language.use_params {#use_params tag="contextmanager, method"}