mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 09:57:26 +03:00 
			
		
		
		
	💫 Allow passing of config parameters to specific pipeline components (#3386)
* Add component_cfg kwarg to begin_training * Document component_cfg arg to begin_training * Update docs and auto-format * Support component_cfg across Language * Format * Update docs and docstrings [ci skip] * Fix begin_training
This commit is contained in:
		
							parent
							
								
									8dbf1e9037
								
							
						
					
					
						commit
						98acf5ffe4
					
				| 
						 | 
					@ -106,6 +106,7 @@ class Language(object):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    DOCS: https://spacy.io/api/language
 | 
					    DOCS: https://spacy.io/api/language
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    Defaults = BaseDefaults
 | 
					    Defaults = BaseDefaults
 | 
				
			||||||
    lang = None
 | 
					    lang = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -344,13 +345,15 @@ class Language(object):
 | 
				
			||||||
            raise ValueError(Errors.E001.format(name=name, opts=self.pipe_names))
 | 
					            raise ValueError(Errors.E001.format(name=name, opts=self.pipe_names))
 | 
				
			||||||
        return self.pipeline.pop(self.pipe_names.index(name))
 | 
					        return self.pipeline.pop(self.pipe_names.index(name))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __call__(self, text, disable=[]):
 | 
					    def __call__(self, text, disable=[], component_cfg=None):
 | 
				
			||||||
        """Apply the pipeline to some text. The text can span multiple sentences,
 | 
					        """Apply the pipeline to some text. The text can span multiple sentences,
 | 
				
			||||||
        and can contain arbtrary whitespace. Alignment into the original string
 | 
					        and can contain arbtrary whitespace. Alignment into the original string
 | 
				
			||||||
        is preserved.
 | 
					        is preserved.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        text (unicode): The text to be processed.
 | 
					        text (unicode): The text to be processed.
 | 
				
			||||||
        disable (list): Names of the pipeline components to disable.
 | 
					        disable (list): Names of the pipeline components to disable.
 | 
				
			||||||
 | 
					        component_cfg (dict): An optional dictionary with extra keyword arguments
 | 
				
			||||||
 | 
					            for specific components.
 | 
				
			||||||
        RETURNS (Doc): A container for accessing the annotations.
 | 
					        RETURNS (Doc): A container for accessing the annotations.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        EXAMPLE:
 | 
					        EXAMPLE:
 | 
				
			||||||
| 
						 | 
					@ -363,12 +366,14 @@ class Language(object):
 | 
				
			||||||
                Errors.E088.format(length=len(text), max_length=self.max_length)
 | 
					                Errors.E088.format(length=len(text), max_length=self.max_length)
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
        doc = self.make_doc(text)
 | 
					        doc = self.make_doc(text)
 | 
				
			||||||
 | 
					        if component_cfg is None:
 | 
				
			||||||
 | 
					            component_cfg = {}
 | 
				
			||||||
        for name, proc in self.pipeline:
 | 
					        for name, proc in self.pipeline:
 | 
				
			||||||
            if name in disable:
 | 
					            if name in disable:
 | 
				
			||||||
                continue
 | 
					                continue
 | 
				
			||||||
            if not hasattr(proc, "__call__"):
 | 
					            if not hasattr(proc, "__call__"):
 | 
				
			||||||
                raise ValueError(Errors.E003.format(component=type(proc), name=name))
 | 
					                raise ValueError(Errors.E003.format(component=type(proc), name=name))
 | 
				
			||||||
            doc = proc(doc)
 | 
					            doc = proc(doc, **component_cfg.get(name, {}))
 | 
				
			||||||
            if doc is None:
 | 
					            if doc is None:
 | 
				
			||||||
                raise ValueError(Errors.E005.format(name=name))
 | 
					                raise ValueError(Errors.E005.format(name=name))
 | 
				
			||||||
        return doc
 | 
					        return doc
 | 
				
			||||||
| 
						 | 
					@ -396,7 +401,7 @@ class Language(object):
 | 
				
			||||||
    def make_doc(self, text):
 | 
					    def make_doc(self, text):
 | 
				
			||||||
        return self.tokenizer(text)
 | 
					        return self.tokenizer(text)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def update(self, docs, golds, drop=0.0, sgd=None, losses=None):
 | 
					    def update(self, docs, golds, drop=0.0, sgd=None, losses=None, component_cfg=None):
 | 
				
			||||||
        """Update the models in the pipeline.
 | 
					        """Update the models in the pipeline.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        docs (iterable): A batch of `Doc` objects.
 | 
					        docs (iterable): A batch of `Doc` objects.
 | 
				
			||||||
| 
						 | 
					@ -443,11 +448,15 @@ class Language(object):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        pipes = list(self.pipeline)
 | 
					        pipes = list(self.pipeline)
 | 
				
			||||||
        random.shuffle(pipes)
 | 
					        random.shuffle(pipes)
 | 
				
			||||||
 | 
					        if component_cfg is None:
 | 
				
			||||||
 | 
					            component_cfg = {}
 | 
				
			||||||
        for name, proc in pipes:
 | 
					        for name, proc in pipes:
 | 
				
			||||||
            if not hasattr(proc, "update"):
 | 
					            if not hasattr(proc, "update"):
 | 
				
			||||||
                continue
 | 
					                continue
 | 
				
			||||||
            grads = {}
 | 
					            grads = {}
 | 
				
			||||||
            proc.update(docs, golds, drop=drop, sgd=get_grads, losses=losses)
 | 
					            kwargs = component_cfg.get(name, {})
 | 
				
			||||||
 | 
					            kwargs.setdefault("drop", drop)
 | 
				
			||||||
 | 
					            proc.update(docs, golds, sgd=get_grads, losses=losses, **kwargs)
 | 
				
			||||||
            for key, (W, dW) in grads.items():
 | 
					            for key, (W, dW) in grads.items():
 | 
				
			||||||
                sgd(W, dW, key=key)
 | 
					                sgd(W, dW, key=key)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -517,11 +526,12 @@ class Language(object):
 | 
				
			||||||
        for doc, gold in docs_golds:
 | 
					        for doc, gold in docs_golds:
 | 
				
			||||||
            yield doc, gold
 | 
					            yield doc, gold
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def begin_training(self, get_gold_tuples=None, sgd=None, **cfg):
 | 
					    def begin_training(self, get_gold_tuples=None, sgd=None, component_cfg=None, **cfg):
 | 
				
			||||||
        """Allocate models, pre-process training data and acquire a trainer and
 | 
					        """Allocate models, pre-process training data and acquire a trainer and
 | 
				
			||||||
        optimizer. Used as a contextmanager.
 | 
					        optimizer. Used as a contextmanager.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        get_gold_tuples (function): Function returning gold data
 | 
					        get_gold_tuples (function): Function returning gold data
 | 
				
			||||||
 | 
					        component_cfg (dict): Config parameters for specific components.
 | 
				
			||||||
        **cfg: Config parameters.
 | 
					        **cfg: Config parameters.
 | 
				
			||||||
        RETURNS: An optimizer
 | 
					        RETURNS: An optimizer
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
| 
						 | 
					@ -543,10 +553,17 @@ class Language(object):
 | 
				
			||||||
        if sgd is None:
 | 
					        if sgd is None:
 | 
				
			||||||
            sgd = create_default_optimizer(Model.ops)
 | 
					            sgd = create_default_optimizer(Model.ops)
 | 
				
			||||||
        self._optimizer = sgd
 | 
					        self._optimizer = sgd
 | 
				
			||||||
 | 
					        if component_cfg is None:
 | 
				
			||||||
 | 
					            component_cfg = {}
 | 
				
			||||||
        for name, proc in self.pipeline:
 | 
					        for name, proc in self.pipeline:
 | 
				
			||||||
            if hasattr(proc, "begin_training"):
 | 
					            if hasattr(proc, "begin_training"):
 | 
				
			||||||
 | 
					                kwargs = component_cfg.get(name, {})
 | 
				
			||||||
 | 
					                kwargs.update(cfg)
 | 
				
			||||||
                proc.begin_training(
 | 
					                proc.begin_training(
 | 
				
			||||||
                    get_gold_tuples, pipeline=self.pipeline, sgd=self._optimizer, **cfg
 | 
					                    get_gold_tuples,
 | 
				
			||||||
 | 
					                    pipeline=self.pipeline,
 | 
				
			||||||
 | 
					                    sgd=self._optimizer,
 | 
				
			||||||
 | 
					                    **kwargs
 | 
				
			||||||
                )
 | 
					                )
 | 
				
			||||||
        return self._optimizer
 | 
					        return self._optimizer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -574,20 +591,27 @@ class Language(object):
 | 
				
			||||||
                proc._rehearsal_model = deepcopy(proc.model)
 | 
					                proc._rehearsal_model = deepcopy(proc.model)
 | 
				
			||||||
        return self._optimizer
 | 
					        return self._optimizer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def evaluate(self, docs_golds, verbose=False, batch_size=256):
 | 
					    def evaluate(
 | 
				
			||||||
        scorer = Scorer()
 | 
					        self, docs_golds, verbose=False, batch_size=256, scorer=None, component_cfg=None
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        if scorer is None:
 | 
				
			||||||
 | 
					            scorer = Scorer()
 | 
				
			||||||
        docs, golds = zip(*docs_golds)
 | 
					        docs, golds = zip(*docs_golds)
 | 
				
			||||||
        docs = list(docs)
 | 
					        docs = list(docs)
 | 
				
			||||||
        golds = list(golds)
 | 
					        golds = list(golds)
 | 
				
			||||||
        for name, pipe in self.pipeline:
 | 
					        for name, pipe in self.pipeline:
 | 
				
			||||||
 | 
					            kwargs = component_cfg.get(name, {})
 | 
				
			||||||
 | 
					            kwargs.setdefault("batch_size", batch_size)
 | 
				
			||||||
            if not hasattr(pipe, "pipe"):
 | 
					            if not hasattr(pipe, "pipe"):
 | 
				
			||||||
                docs = (pipe(doc) for doc in docs)
 | 
					                docs = (pipe(doc, **kwargs) for doc in docs)
 | 
				
			||||||
            else:
 | 
					            else:
 | 
				
			||||||
                docs = pipe.pipe(docs, batch_size=batch_size)
 | 
					                docs = pipe.pipe(docs, **kwargs)
 | 
				
			||||||
        for doc, gold in zip(docs, golds):
 | 
					        for doc, gold in zip(docs, golds):
 | 
				
			||||||
            if verbose:
 | 
					            if verbose:
 | 
				
			||||||
                print(doc)
 | 
					                print(doc)
 | 
				
			||||||
            scorer.score(doc, gold, verbose=verbose)
 | 
					            kwargs = component_cfg.get("scorer", {})
 | 
				
			||||||
 | 
					            kwargs.setdefault("verbose", verbose)
 | 
				
			||||||
 | 
					            scorer.score(doc, gold, **kwargs)
 | 
				
			||||||
        return scorer
 | 
					        return scorer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @contextmanager
 | 
					    @contextmanager
 | 
				
			||||||
| 
						 | 
					@ -630,6 +654,7 @@ class Language(object):
 | 
				
			||||||
        batch_size=1000,
 | 
					        batch_size=1000,
 | 
				
			||||||
        disable=[],
 | 
					        disable=[],
 | 
				
			||||||
        cleanup=False,
 | 
					        cleanup=False,
 | 
				
			||||||
 | 
					        component_cfg=None,
 | 
				
			||||||
    ):
 | 
					    ):
 | 
				
			||||||
        """Process texts as a stream, and yield `Doc` objects in order.
 | 
					        """Process texts as a stream, and yield `Doc` objects in order.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -643,6 +668,8 @@ class Language(object):
 | 
				
			||||||
        disable (list): Names of the pipeline components to disable.
 | 
					        disable (list): Names of the pipeline components to disable.
 | 
				
			||||||
        cleanup (bool): If True, unneeded strings are freed,
 | 
					        cleanup (bool): If True, unneeded strings are freed,
 | 
				
			||||||
            to control memory use. Experimental.
 | 
					            to control memory use. Experimental.
 | 
				
			||||||
 | 
					        component_cfg (dict): An optional dictionary with extra keyword arguments
 | 
				
			||||||
 | 
					            for specific components.
 | 
				
			||||||
        YIELDS (Doc): Documents in the order of the original text.
 | 
					        YIELDS (Doc): Documents in the order of the original text.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        EXAMPLE:
 | 
					        EXAMPLE:
 | 
				
			||||||
| 
						 | 
					@ -655,20 +682,30 @@ class Language(object):
 | 
				
			||||||
            texts = (tc[0] for tc in text_context1)
 | 
					            texts = (tc[0] for tc in text_context1)
 | 
				
			||||||
            contexts = (tc[1] for tc in text_context2)
 | 
					            contexts = (tc[1] for tc in text_context2)
 | 
				
			||||||
            docs = self.pipe(
 | 
					            docs = self.pipe(
 | 
				
			||||||
                texts, n_threads=n_threads, batch_size=batch_size, disable=disable
 | 
					                texts,
 | 
				
			||||||
 | 
					                n_threads=n_threads,
 | 
				
			||||||
 | 
					                batch_size=batch_size,
 | 
				
			||||||
 | 
					                disable=disable,
 | 
				
			||||||
 | 
					                component_cfg=component_cfg,
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
            for doc, context in izip(docs, contexts):
 | 
					            for doc, context in izip(docs, contexts):
 | 
				
			||||||
                yield (doc, context)
 | 
					                yield (doc, context)
 | 
				
			||||||
            return
 | 
					            return
 | 
				
			||||||
        docs = (self.make_doc(text) for text in texts)
 | 
					        docs = (self.make_doc(text) for text in texts)
 | 
				
			||||||
 | 
					        if component_cfg is None:
 | 
				
			||||||
 | 
					            component_cfg = {}
 | 
				
			||||||
        for name, proc in self.pipeline:
 | 
					        for name, proc in self.pipeline:
 | 
				
			||||||
            if name in disable:
 | 
					            if name in disable:
 | 
				
			||||||
                continue
 | 
					                continue
 | 
				
			||||||
 | 
					            kwargs = component_cfg.get(name, {})
 | 
				
			||||||
 | 
					            # Allow component_cfg to overwrite the top-level kwargs.
 | 
				
			||||||
 | 
					            kwargs.setdefault("batch_size", batch_size)
 | 
				
			||||||
 | 
					            kwargs.setdefault("n_threads", n_threads)
 | 
				
			||||||
            if hasattr(proc, "pipe"):
 | 
					            if hasattr(proc, "pipe"):
 | 
				
			||||||
                docs = proc.pipe(docs, n_threads=n_threads, batch_size=batch_size)
 | 
					                docs = proc.pipe(docs, **kwargs)
 | 
				
			||||||
            else:
 | 
					            else:
 | 
				
			||||||
                # Apply the function, but yield the doc
 | 
					                # Apply the function, but yield the doc
 | 
				
			||||||
                docs = _pipe(proc, docs)
 | 
					                docs = _pipe(proc, docs, kwargs)
 | 
				
			||||||
        # Track weakrefs of "recent" documents, so that we can see when they
 | 
					        # Track weakrefs of "recent" documents, so that we can see when they
 | 
				
			||||||
        # expire from memory. When they do, we know we don't need old strings.
 | 
					        # expire from memory. When they do, we know we don't need old strings.
 | 
				
			||||||
        # This way, we avoid maintaining an unbounded growth in string entries
 | 
					        # This way, we avoid maintaining an unbounded growth in string entries
 | 
				
			||||||
| 
						 | 
					@ -861,7 +898,7 @@ class DisabledPipes(list):
 | 
				
			||||||
        self[:] = []
 | 
					        self[:] = []
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def _pipe(func, docs):
 | 
					def _pipe(func, docs, kwargs):
 | 
				
			||||||
    for doc in docs:
 | 
					    for doc in docs:
 | 
				
			||||||
        doc = func(doc)
 | 
					        doc = func(doc, **kwargs)
 | 
				
			||||||
        yield doc
 | 
					        yield doc
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -91,13 +91,14 @@ multiprocessing.
 | 
				
			||||||
>     assert doc.is_parsed
 | 
					>     assert doc.is_parsed
 | 
				
			||||||
> ```
 | 
					> ```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| Name         | Type  | Description                                                                                                                                                |
 | 
					| Name                                         | Type  | Description                                                                                                                                                |
 | 
				
			||||||
| ------------ | ----- | ---------------------------------------------------------------------------------------------------------------------------------------------------------- |
 | 
					| -------------------------------------------- | ----- | ---------------------------------------------------------------------------------------------------------------------------------------------------------- |
 | 
				
			||||||
| `texts`      | -     | A sequence of unicode objects.                                                                                                                             |
 | 
					| `texts`                                      | -     | A sequence of unicode objects.                                                                                                                             |
 | 
				
			||||||
| `as_tuples`  | bool  | If set to `True`, inputs should be a sequence of `(text, context)` tuples. Output will then be a sequence of `(doc, context)` tuples. Defaults to `False`. |
 | 
					| `as_tuples`                                  | bool  | If set to `True`, inputs should be a sequence of `(text, context)` tuples. Output will then be a sequence of `(doc, context)` tuples. Defaults to `False`. |
 | 
				
			||||||
| `batch_size` | int   | The number of texts to buffer.                                                                                                                             |
 | 
					| `batch_size`                                 | int   | The number of texts to buffer.                                                                                                                             |
 | 
				
			||||||
| `disable`    | list  | Names of pipeline components to [disable](/usage/processing-pipelines#disabling).                                                                          |
 | 
					| `disable`                                    | list  | Names of pipeline components to [disable](/usage/processing-pipelines#disabling).                                                                          |
 | 
				
			||||||
| **YIELDS**   | `Doc` | Documents in the order of the original text.                                                                                                               |
 | 
					| `component_cfg` <Tag variant="new">2.1</Tag> | dict  | Config parameters for specific pipeline components, keyed by component name.                                                                               |
 | 
				
			||||||
 | 
					| **YIELDS**                                   | `Doc` | Documents in the order of the original text.                                                                                                               |
 | 
				
			||||||
 | 
					
 | 
				
			||||||
## Language.update {#update tag="method"}
 | 
					## Language.update {#update tag="method"}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -112,13 +113,14 @@ Update the models in the pipeline.
 | 
				
			||||||
>     nlp.update([doc], [gold], drop=0.5, sgd=optimizer)
 | 
					>     nlp.update([doc], [gold], drop=0.5, sgd=optimizer)
 | 
				
			||||||
> ```
 | 
					> ```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| Name        | Type     | Description                                                                                                                                                                                                         |
 | 
					| Name                                         | Type     | Description                                                                                                                                                                                                         |
 | 
				
			||||||
| ----------- | -------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
 | 
					| -------------------------------------------- | -------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
 | 
				
			||||||
| `docs`      | iterable | A batch of `Doc` objects or unicode. If unicode, a `Doc` object will be created from the text.                                                                                                                      |
 | 
					| `docs`                                       | iterable | A batch of `Doc` objects or unicode. If unicode, a `Doc` object will be created from the text.                                                                                                                      |
 | 
				
			||||||
| `golds`     | iterable | A batch of `GoldParse` objects or dictionaries. Dictionaries will be used to create [`GoldParse`](/api/goldparse) objects. For the available keys and their usage, see [`GoldParse.__init__`](/api/goldparse#init). |
 | 
					| `golds`                                      | iterable | A batch of `GoldParse` objects or dictionaries. Dictionaries will be used to create [`GoldParse`](/api/goldparse) objects. For the available keys and their usage, see [`GoldParse.__init__`](/api/goldparse#init). |
 | 
				
			||||||
| `drop`      | float    | The dropout rate.                                                                                                                                                                                                   |
 | 
					| `drop`                                       | float    | The dropout rate.                                                                                                                                                                                                   |
 | 
				
			||||||
| `sgd`       | callable | An optimizer.                                                                                                                                                                                                       |
 | 
					| `sgd`                                        | callable | An optimizer.                                                                                                                                                                                                       |
 | 
				
			||||||
| **RETURNS** | dict     | Results from the update.                                                                                                                                                                                            |
 | 
					| `component_cfg` <Tag variant="new">2.1</Tag> | dict     | Config parameters for specific pipeline components, keyed by component name.                                                                                                                                        |
 | 
				
			||||||
 | 
					| **RETURNS**                                  | dict     | Results from the update.                                                                                                                                                                                            |
 | 
				
			||||||
 | 
					
 | 
				
			||||||
## Language.begin_training {#begin_training tag="method"}
 | 
					## Language.begin_training {#begin_training tag="method"}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -130,11 +132,12 @@ Allocate models, pre-process training data and acquire an optimizer.
 | 
				
			||||||
> optimizer = nlp.begin_training(gold_tuples)
 | 
					> optimizer = nlp.begin_training(gold_tuples)
 | 
				
			||||||
> ```
 | 
					> ```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| Name          | Type     | Description                  |
 | 
					| Name                                         | Type     | Description                                                                  |
 | 
				
			||||||
| ------------- | -------- | ---------------------------- |
 | 
					| -------------------------------------------- | -------- | ---------------------------------------------------------------------------- |
 | 
				
			||||||
| `gold_tuples` | iterable | Gold-standard training data. |
 | 
					| `gold_tuples`                                | iterable | Gold-standard training data.                                                 |
 | 
				
			||||||
| `**cfg`       | -        | Config parameters.           |
 | 
					| `component_cfg` <Tag variant="new">2.1</Tag> | dict     | Config parameters for specific pipeline components, keyed by component name. |
 | 
				
			||||||
| **RETURNS**   | callable | An optimizer.                |
 | 
					| `**cfg`                                      | -        | Config parameters (sent to all components).                                  |
 | 
				
			||||||
 | 
					| **RETURNS**                                  | callable | An optimizer.                                                                |
 | 
				
			||||||
 | 
					
 | 
				
			||||||
## Language.use_params {#use_params tag="contextmanager, method"}
 | 
					## Language.use_params {#use_params tag="contextmanager, method"}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user