From 98acf5ffe408d3ec58fcfba0e0deb742891815d5 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Sun, 10 Mar 2019 23:36:47 +0100 Subject: [PATCH] =?UTF-8?q?=F0=9F=92=AB=20Allow=20passing=20of=20config=20?= =?UTF-8?q?parameters=20to=20specific=20pipeline=20components=20(#3386)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add component_cfg kwarg to begin_training * Document component_cfg arg to begin_training * Update docs and auto-format * Support component_cfg across Language * Format * Update docs and docstrings [ci skip] * Fix begin_training --- spacy/language.py | 69 +++++++++++++++++++++++++++--------- website/docs/api/language.md | 41 +++++++++++---------- 2 files changed, 75 insertions(+), 35 deletions(-) diff --git a/spacy/language.py b/spacy/language.py index 6fb30e46d..44a819132 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -106,6 +106,7 @@ class Language(object): DOCS: https://spacy.io/api/language """ + Defaults = BaseDefaults lang = None @@ -344,13 +345,15 @@ class Language(object): raise ValueError(Errors.E001.format(name=name, opts=self.pipe_names)) return self.pipeline.pop(self.pipe_names.index(name)) - def __call__(self, text, disable=[]): + def __call__(self, text, disable=[], component_cfg=None): """Apply the pipeline to some text. The text can span multiple sentences, and can contain arbtrary whitespace. Alignment into the original string is preserved. text (unicode): The text to be processed. disable (list): Names of the pipeline components to disable. + component_cfg (dict): An optional dictionary with extra keyword arguments + for specific components. RETURNS (Doc): A container for accessing the annotations. 
EXAMPLE: @@ -363,12 +366,14 @@ class Language(object): Errors.E088.format(length=len(text), max_length=self.max_length) ) doc = self.make_doc(text) + if component_cfg is None: + component_cfg = {} for name, proc in self.pipeline: if name in disable: continue if not hasattr(proc, "__call__"): raise ValueError(Errors.E003.format(component=type(proc), name=name)) - doc = proc(doc) + doc = proc(doc, **component_cfg.get(name, {})) if doc is None: raise ValueError(Errors.E005.format(name=name)) return doc @@ -396,7 +401,7 @@ class Language(object): def make_doc(self, text): return self.tokenizer(text) - def update(self, docs, golds, drop=0.0, sgd=None, losses=None): + def update(self, docs, golds, drop=0.0, sgd=None, losses=None, component_cfg=None): """Update the models in the pipeline. docs (iterable): A batch of `Doc` objects. @@ -443,11 +448,15 @@ class Language(object): pipes = list(self.pipeline) random.shuffle(pipes) + if component_cfg is None: + component_cfg = {} for name, proc in pipes: if not hasattr(proc, "update"): continue grads = {} - proc.update(docs, golds, drop=drop, sgd=get_grads, losses=losses) + kwargs = component_cfg.get(name, {}) + kwargs.setdefault("drop", drop) + proc.update(docs, golds, sgd=get_grads, losses=losses, **kwargs) for key, (W, dW) in grads.items(): sgd(W, dW, key=key) @@ -517,11 +526,12 @@ class Language(object): for doc, gold in docs_golds: yield doc, gold - def begin_training(self, get_gold_tuples=None, sgd=None, **cfg): + def begin_training(self, get_gold_tuples=None, sgd=None, component_cfg=None, **cfg): """Allocate models, pre-process training data and acquire a trainer and optimizer. Used as a contextmanager. get_gold_tuples (function): Function returning gold data + component_cfg (dict): Config parameters for specific components. **cfg: Config parameters. 
RETURNS: An optimizer """ @@ -543,10 +553,17 @@ class Language(object): if sgd is None: sgd = create_default_optimizer(Model.ops) self._optimizer = sgd + if component_cfg is None: + component_cfg = {} for name, proc in self.pipeline: if hasattr(proc, "begin_training"): + kwargs = component_cfg.get(name, {}) + kwargs.update(cfg) proc.begin_training( - get_gold_tuples, pipeline=self.pipeline, sgd=self._optimizer, **cfg + get_gold_tuples, + pipeline=self.pipeline, + sgd=self._optimizer, + **kwargs ) return self._optimizer @@ -574,20 +591,29 @@ class Language(object): proc._rehearsal_model = deepcopy(proc.model) return self._optimizer - def evaluate(self, docs_golds, verbose=False, batch_size=256): - scorer = Scorer() + def evaluate( + self, docs_golds, verbose=False, batch_size=256, scorer=None, component_cfg=None + ): + if scorer is None: + scorer = Scorer() + if component_cfg is None: + component_cfg = {} docs, golds = zip(*docs_golds) docs = list(docs) golds = list(golds) for name, pipe in self.pipeline: + kwargs = component_cfg.get(name, {}) + kwargs.setdefault("batch_size", batch_size) if not hasattr(pipe, "pipe"): - docs = (pipe(doc) for doc in docs) + docs = (pipe(doc, **kwargs) for doc in docs) else: - docs = pipe.pipe(docs, batch_size=batch_size) + docs = pipe.pipe(docs, **kwargs) for doc, gold in zip(docs, golds): if verbose: print(doc) - scorer.score(doc, gold, verbose=verbose) + kwargs = component_cfg.get("scorer", {}) + kwargs.setdefault("verbose", verbose) + scorer.score(doc, gold, **kwargs) return scorer @contextmanager @@ -630,6 +654,7 @@ class Language(object): batch_size=1000, disable=[], cleanup=False, + component_cfg=None, ): """Process texts as a stream, and yield `Doc` objects in order. @@ -643,6 +668,8 @@ class Language(object): disable (list): Names of the pipeline components to disable. cleanup (bool): If True, unneeded strings are freed, to control memory use. Experimental. + component_cfg (dict): An optional dictionary with extra keyword arguments + for specific components. 
YIELDS (Doc): Documents in the order of the original text. EXAMPLE: @@ -655,20 +682,30 @@ class Language(object): texts = (tc[0] for tc in text_context1) contexts = (tc[1] for tc in text_context2) docs = self.pipe( - texts, n_threads=n_threads, batch_size=batch_size, disable=disable + texts, + n_threads=n_threads, + batch_size=batch_size, + disable=disable, + component_cfg=component_cfg, ) for doc, context in izip(docs, contexts): yield (doc, context) return docs = (self.make_doc(text) for text in texts) + if component_cfg is None: + component_cfg = {} for name, proc in self.pipeline: if name in disable: continue + kwargs = component_cfg.get(name, {}) + # Allow component_cfg to overwrite the top-level kwargs. + kwargs.setdefault("batch_size", batch_size) + kwargs.setdefault("n_threads", n_threads) if hasattr(proc, "pipe"): - docs = proc.pipe(docs, n_threads=n_threads, batch_size=batch_size) + docs = proc.pipe(docs, **kwargs) else: # Apply the function, but yield the doc - docs = _pipe(proc, docs) + docs = _pipe(proc, docs, kwargs) # Track weakrefs of "recent" documents, so that we can see when they # expire from memory. When they do, we know we don't need old strings. # This way, we avoid maintaining an unbounded growth in string entries @@ -861,7 +898,7 @@ class DisabledPipes(list): self[:] = [] -def _pipe(func, docs): +def _pipe(func, docs, kwargs): for doc in docs: - doc = func(doc) + doc = func(doc, **kwargs) yield doc diff --git a/website/docs/api/language.md b/website/docs/api/language.md index 34d14ec01..a8598815b 100644 --- a/website/docs/api/language.md +++ b/website/docs/api/language.md @@ -91,13 +91,14 @@ multiprocessing. > assert doc.is_parsed > ``` -| Name | Type | Description | -| ------------ | ----- | ---------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `texts` | - | A sequence of unicode objects. 
| -| `as_tuples` | bool | If set to `True`, inputs should be a sequence of `(text, context)` tuples. Output will then be a sequence of `(doc, context)` tuples. Defaults to `False`. | -| `batch_size` | int | The number of texts to buffer. | -| `disable` | list | Names of pipeline components to [disable](/usage/processing-pipelines#disabling). | -| **YIELDS** | `Doc` | Documents in the order of the original text. | +| Name | Type | Description | +| -------------------------------------------- | ----- | ---------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `texts` | - | A sequence of unicode objects. | +| `as_tuples` | bool | If set to `True`, inputs should be a sequence of `(text, context)` tuples. Output will then be a sequence of `(doc, context)` tuples. Defaults to `False`. | +| `batch_size` | int | The number of texts to buffer. | +| `disable` | list | Names of pipeline components to [disable](/usage/processing-pipelines#disabling). | +| `component_cfg` 2.1 | dict | Config parameters for specific pipeline components, keyed by component name. | +| **YIELDS** | `Doc` | Documents in the order of the original text. | ## Language.update {#update tag="method"} @@ -112,13 +113,14 @@ Update the models in the pipeline. > nlp.update([doc], [gold], drop=0.5, sgd=optimizer) > ``` -| Name | Type | Description | -| ----------- | -------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `docs` | iterable | A batch of `Doc` objects or unicode. If unicode, a `Doc` object will be created from the text. | -| `golds` | iterable | A batch of `GoldParse` objects or dictionaries. Dictionaries will be used to create [`GoldParse`](/api/goldparse) objects. 
For the available keys and their usage, see [`GoldParse.__init__`](/api/goldparse#init). | -| `drop` | float | The dropout rate. | -| `sgd` | callable | An optimizer. | -| **RETURNS** | dict | Results from the update. | +| Name | Type | Description | +| -------------------------------------------- | -------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `docs` | iterable | A batch of `Doc` objects or unicode. If unicode, a `Doc` object will be created from the text. | +| `golds` | iterable | A batch of `GoldParse` objects or dictionaries. Dictionaries will be used to create [`GoldParse`](/api/goldparse) objects. For the available keys and their usage, see [`GoldParse.__init__`](/api/goldparse#init). | +| `drop` | float | The dropout rate. | +| `sgd` | callable | An optimizer. | +| `component_cfg` 2.1 | dict | Config parameters for specific pipeline components, keyed by component name. | +| **RETURNS** | dict | Results from the update. | ## Language.begin_training {#begin_training tag="method"} @@ -130,11 +132,12 @@ Allocate models, pre-process training data and acquire an optimizer. > optimizer = nlp.begin_training(gold_tuples) > ``` -| Name | Type | Description | -| ------------- | -------- | ---------------------------- | -| `gold_tuples` | iterable | Gold-standard training data. | -| `**cfg` | - | Config parameters. | -| **RETURNS** | callable | An optimizer. | +| Name | Type | Description | +| -------------------------------------------- | -------- | ---------------------------------------------------------------------------- | +| `gold_tuples` | iterable | Gold-standard training data. | +| `component_cfg` 2.1 | dict | Config parameters for specific pipeline components, keyed by component name. | +| `**cfg` | - | Config parameters (sent to all components). 
| +| **RETURNS** | callable | An optimizer. | ## Language.use_params {#use_params tag="contextmanager, method"}