	Merge pull request #6206 from svlandeg/fix/patterns-init

Commit: 568e12215d

@@ -1091,10 +1091,11 @@ class Language:
             for name, proc in self.pipeline:
                 if (
                     name not in exclude
-                    and hasattr(proc, "model")
+                    and hasattr(proc, "is_trainable")
+                    and proc.is_trainable()
                     and proc.model not in (True, False, None)
                 ):
-                    proc.model.finish_update(sgd)
+                    proc.finish_update(sgd)
         return losses
 
     def rehearse(
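The update loop above now asks each component whether it is trainable and lets the component apply its own gradients. A minimal sketch of that duck-typed check (not part of this commit; the helper name and signature are invented for illustration):

```python
from thinc.api import Optimizer


def apply_updates(pipeline, sgd: Optimizer, exclude=()):
    # Components that expose is_trainable()/finish_update() get their
    # gradients applied; rule-based components are skipped entirely.
    for name, proc in pipeline:
        if (
            name not in exclude
            and hasattr(proc, "is_trainable")
            and proc.is_trainable()
        ):
            # delegate to the component instead of reaching into proc.model
            proc.finish_update(sgd)
```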
				
			
@@ -1297,7 +1298,9 @@ class Language:
         for name, pipe in self.pipeline:
             kwargs = component_cfg.get(name, {})
             kwargs.setdefault("batch_size", batch_size)
-            if not hasattr(pipe, "pipe"):
+            # non-trainable components may have a pipe() implementation that refers to dummy
+            # predict and set_annotations methods
+            if not hasattr(pipe, "pipe") or not hasattr(pipe, "is_trainable") or not pipe.is_trainable():
                 docs = _pipe(docs, pipe, kwargs)
             else:
                 docs = pipe.pipe(docs, **kwargs)

@@ -1407,7 +1410,9 @@ class Language:
             kwargs = component_cfg.get(name, {})
             # Allow component_cfg to overwrite the top-level kwargs.
             kwargs.setdefault("batch_size", batch_size)
-            if hasattr(proc, "pipe"):
+            # non-trainable components may have a pipe() implementation that refers to dummy
+            # predict and set_annotations methods
+            if hasattr(proc, "pipe") and hasattr(proc, "is_trainable") and proc.is_trainable():
                 f = functools.partial(proc.pipe, **kwargs)
             else:
                 # Apply the function, but yield the doc
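Both hunks above apply the same gate: a component that subclasses `Pipe` now always has a `pipe()` method, but for rule-based components it is built on dummy `predict`/`set_annotations` methods, so those components must still be applied doc by doc. A rough sketch of the routing (illustrative only, not the actual spaCy helper):

```python
def apply_component(docs, proc, **kwargs):
    # Batched path only for components that are actually trainable;
    # everything else is called on each Doc individually.
    if hasattr(proc, "pipe") and hasattr(proc, "is_trainable") and proc.is_trainable():
        yield from proc.pipe(docs, **kwargs)
    else:
        for doc in docs:
            yield proc(doc)
```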
				
			
@@ -238,7 +238,7 @@ class EntityLinker(Pipe):
         )
         bp_context(d_scores)
         if sgd is not None:
-            self.model.finish_update(sgd)
+            self.finish_update(sgd)
         losses[self.name] += loss
         if set_annotations:
             self.set_annotations(docs, predictions)

@@ -1,8 +1,10 @@
-from typing import Optional, Union, List, Dict, Tuple, Iterable, Any
+from typing import Optional, Union, List, Dict, Tuple, Iterable, Any, Callable, Sequence
 from collections import defaultdict
 from pathlib import Path
 import srsly
 
+from .pipe import Pipe
+from ..training import Example
 from ..language import Language
 from ..errors import Errors
 from ..util import ensure_path, to_disk, from_disk, SimpleFrozenList

@@ -50,7 +52,7 @@ def make_entity_ruler(
     )
 
 
-class EntityRuler:
+class EntityRuler(Pipe):
     """The EntityRuler lets you add spans to the `Doc.ents` using token-based
     rules or exact phrase matches. It can be combined with the statistical
     `EntityRecognizer` to boost accuracy, or used on its own to implement a

@@ -183,6 +185,26 @@ class EntityRuler:
                 all_labels.add(l)
         return tuple(all_labels)
 
+    def initialize(
+        self,
+        get_examples: Callable[[], Iterable[Example]],
+        *,
+        nlp: Optional[Language] = None,
+        patterns: Optional[Sequence[PatternType]] = None,
+    ):
+        """Initialize the pipe for training.
+
+        get_examples (Callable[[], Iterable[Example]]): Function that
+            returns a representative sample of gold-standard Example objects.
+        nlp (Language): The current nlp object the component is part of.
+        patterns Optional[Iterable[PatternType]]: The list of patterns.
+
+        DOCS: https://nightly.spacy.io/api/entityruler#initialize
+        """
+        if patterns:
+            self.add_patterns(patterns)
+
+
     @property
     def ent_ids(self) -> Tuple[str, ...]:
         """All entity ids present in the match patterns `id` properties
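With `initialize` in place, patterns can be supplied when the pipeline is initialized rather than only at construction time. A usage sketch (the `BYE` token pattern here is invented for illustration; the real fixture patterns are defined in the test file further below):

```python
import spacy

nlp = spacy.blank("en")
ruler = nlp.add_pipe("entity_ruler")
patterns = [
    {"label": "HELLO", "pattern": "hello world"},
    {"label": "BYE", "pattern": [{"LOWER": "bye"}, {"LOWER": "bye"}]},
]
# patterns are added during initialization instead of at construction time
ruler.initialize(lambda: [], patterns=patterns)
doc = nlp("hello world bye bye")
assert [ent.label_ for ent in doc.ents] == ["HELLO", "BYE"]
```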
				
			
@@ -320,6 +342,12 @@ class EntityRuler:
         validate_examples(examples, "EntityRuler.score")
         return Scorer.score_spans(examples, "ents", **kwargs)
 
+    def predict(self, docs):
+        pass
+
+    def set_annotations(self, docs, scores):
+        pass
+
     def from_bytes(
         self, patterns_bytes: bytes, *, exclude: Iterable[str] = SimpleFrozenList()
     ) -> "EntityRuler":

@@ -209,7 +209,7 @@ class ClozeMultitask(Pipe):
         loss, d_predictions = self.get_loss(examples, self.vocab.vectors.data, predictions)
         bp_predictions(d_predictions)
         if sgd is not None:
-            self.model.finish_update(sgd)
+            self.finish_update(sgd)
         if losses is not None:
             losses[self.name] += loss
         return losses

@@ -132,7 +132,7 @@ cdef class Pipe:
         loss, d_scores = self.get_loss(examples, scores)
         bp_scores(d_scores)
         if sgd not in (None, False):
-            self.model.finish_update(sgd)
+            self.finish_update(sgd)
         losses[self.name] += loss
         if set_annotations:
             docs = [eg.predicted for eg in examples]

@@ -228,6 +228,9 @@ cdef class Pipe:
     def is_resizable(self):
         return hasattr(self, "model") and "resize_output" in self.model.attrs
 
+    def is_trainable(self):
+        return hasattr(self, "model") and isinstance(self.model, Model)
+
     def set_output(self, nO):
         if self.is_resizable():
             self.model.attrs["resize_output"](self.model, nO)
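The default `is_trainable` simply reports whether the component owns a Thinc `Model`. Rule-based components such as the `EntityRuler`, which now subclasses `Pipe` but defines no model, therefore report `False`. A small illustration (not part of the diff):

```python
import spacy

nlp = spacy.blank("en")
tagger = nlp.add_pipe("tagger")
ruler = nlp.add_pipe("entity_ruler")
assert tagger.is_trainable()      # carries a Thinc Model
assert not ruler.is_trainable()   # rule-based, no model attribute
```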
				
			
@@ -245,6 +248,17 @@ cdef class Pipe:
         with self.model.use_params(params):
             yield
 
+    def finish_update(self, sgd):
+        """Update parameters using the current parameter gradients.
+        The Optimizer instance contains the functionality to perform
+        the stochastic gradient descent.
+
+        sgd (thinc.api.Optimizer): The optimizer.
+
+        DOCS: https://nightly.spacy.io/api/pipe#finish_update
+        """
+        self.model.finish_update(sgd)
+
     def score(self, examples, **kwargs):
         """Score a batch of examples.
 
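`finish_update` gives every trainable component a single place where the optimizer is applied; the hunks that follow switch the built-in components from `self.model.finish_update(sgd)` to `self.finish_update(sgd)`. A generic sketch of the resulting update step (a hypothetical helper, for illustration only):

```python
def update_step(component, examples, sgd=None, losses=None):
    # forward + backward through the component's model
    scores, backprop = component.model.begin_update([eg.predicted for eg in examples])
    loss, d_scores = component.get_loss(examples, scores)
    backprop(d_scores)
    if sgd is not None:
        # the component applies the gradients itself
        component.finish_update(sgd)
    if losses is not None:
        losses[component.name] = losses.get(component.name, 0.0) + loss
    return losses
```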
				
			
@@ -203,7 +203,7 @@ class Tagger(Pipe):
         loss, d_tag_scores = self.get_loss(examples, tag_scores)
         bp_tag_scores(d_tag_scores)
         if sgd not in (None, False):
-            self.model.finish_update(sgd)
+            self.finish_update(sgd)
 
         losses[self.name] += loss
         if set_annotations:

@@ -238,7 +238,7 @@ class Tagger(Pipe):
         target = self._rehearsal_model(examples)
         gradient = guesses - target
         backprop(gradient)
-        self.model.finish_update(sgd)
+        self.finish_update(sgd)
         if losses is not None:
             losses.setdefault(self.name, 0.0)
             losses[self.name] += (gradient**2).sum()

@@ -212,7 +212,7 @@ class TextCategorizer(Pipe):
         loss, d_scores = self.get_loss(examples, scores)
         bp_scores(d_scores)
         if sgd is not None:
-            self.model.finish_update(sgd)
+            self.finish_update(sgd)
         losses[self.name] += loss
         if set_annotations:
             docs = [eg.predicted for eg in examples]

@@ -256,7 +256,7 @@ class TextCategorizer(Pipe):
         gradient = scores - target
         bp_scores(gradient)
         if sgd is not None:
-            self.model.finish_update(sgd)
+            self.finish_update(sgd)
         if losses is not None:
             losses[self.name] += (gradient ** 2).sum()
         return losses

@@ -188,7 +188,7 @@ class Tok2Vec(Pipe):
             accumulate_gradient(one_d_tokvecs)
             d_docs = bp_tokvecs(d_tokvecs)
             if sgd is not None:
-                self.model.finish_update(sgd)
+                self.finish_update(sgd)
             return d_docs
 
         batch_id = Tok2VecListener.get_batch_id(docs)

@@ -315,7 +315,7 @@ cdef class Parser(Pipe):
 
         backprop_tok2vec(golds)
         if sgd not in (None, False):
-            self.model.finish_update(sgd)
+            self.finish_update(sgd)
         if set_annotations:
             docs = [eg.predicted for eg in examples]
             self.set_annotations(docs, all_states)

@@ -367,7 +367,7 @@ cdef class Parser(Pipe):
         # Do the backprop
         backprop_tok2vec(docs)
         if sgd is not None:
-            self.model.finish_update(sgd)
+            self.finish_update(sgd)
         losses[self.name] += loss / n_scores
         del backprop
         del backprop_tok2vec

@@ -437,7 +437,9 @@ cdef class Parser(Pipe):
             for name, component in nlp.pipeline:
                 if component is self:
                     break
-                if hasattr(component, "pipe"):
+                # non-trainable components may have a pipe() implementation that refers to dummy
+                # predict and set_annotations methods
+                if hasattr(component, "pipe") and hasattr(component, "is_trainable") and component.is_trainable():
                     doc_sample = list(component.pipe(doc_sample, batch_size=8))
                 else:
                     doc_sample = [component(doc) for doc in doc_sample]

@@ -119,7 +119,7 @@ def validate_init_settings(
     if types don't match or required values are missing.
 
     func (Callable): The initialize method of a given component etc.
-    settings (Dict[str, Any]): The settings from the repsective [initialize] block.
+    settings (Dict[str, Any]): The settings from the respective [initialize] block.
     section (str): Initialize section, for error message.
     name (str): Name of the block in the section.
     exclude (Iterable[str]): Parameter names to exclude from schema.

@@ -121,7 +121,7 @@ def test_attributeruler_init_patterns(nlp, pattern_dicts):
     assert doc.has_annotation("LEMMA")
     assert doc.has_annotation("MORPH")
     nlp.remove_pipe("attribute_ruler")
-    # initialize with patterns from asset
+    # initialize with patterns from misc registry
     nlp.config["initialize"]["components"]["attribute_ruler"] = {
         "patterns": {"@misc": "attribute_ruler_patterns"}
     }

@@ -1,4 +1,6 @@
 import pytest
+
+from spacy import registry
 from spacy.tokens import Span
 from spacy.language import Language
 from spacy.pipeline import EntityRuler

@@ -11,6 +13,7 @@ def nlp():
 
 
 @pytest.fixture
+@registry.misc("entity_ruler_patterns")
 def patterns():
     return [
         {"label": "HELLO", "pattern": "hello world"},

@@ -42,6 +45,29 @@ def test_entity_ruler_init(nlp, patterns):
     assert doc.ents[1].label_ == "BYE"
 
 
+def test_entity_ruler_init_patterns(nlp, patterns):
+    # initialize with patterns
+    ruler = nlp.add_pipe("entity_ruler")
+    assert len(ruler.labels) == 0
+    ruler.initialize(lambda: [], patterns=patterns)
+    assert len(ruler.labels) == 4
+    doc = nlp("hello world bye bye")
+    assert doc.ents[0].label_ == "HELLO"
+    assert doc.ents[1].label_ == "BYE"
+    nlp.remove_pipe("entity_ruler")
+    # initialize with patterns from misc registry
+    nlp.config["initialize"]["components"]["entity_ruler"] = {
+        "patterns": {"@misc": "entity_ruler_patterns"}
+    }
+    ruler = nlp.add_pipe("entity_ruler")
+    assert len(ruler.labels) == 0
+    nlp.initialize()
+    assert len(ruler.labels) == 4
+    doc = nlp("hello world bye bye")
+    assert doc.ents[0].label_ == "HELLO"
+    assert doc.ents[1].label_ == "BYE"
+
+
 def test_entity_ruler_existing(nlp, patterns):
     ruler = nlp.add_pipe("entity_ruler")
     ruler.add_patterns(patterns)

@@ -49,7 +49,7 @@ def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
             nlp.resume_training(sgd=optimizer)
     with nlp.select_pipes(disable=[*frozen_components, *resume_components]):
         nlp.initialize(lambda: train_corpus(nlp), sgd=optimizer)
-        logger.info("Initialized pipeline components")
+        logger.info(f"Initialized pipeline components: {nlp.pipe_names}")
     return nlp
 
 

@@ -17,8 +17,12 @@ def console_logger(progress_bar: bool = False):
         nlp: "Language", stdout: IO = sys.stdout, stderr: IO = sys.stderr
     ) -> Tuple[Callable[[Optional[Dict[str, Any]]], None], Callable[[], None]]:
         msg = Printer(no_print=True)
-        # we assume here that only components are enabled that should be trained & logged
-        logged_pipes = nlp.pipe_names
+        # ensure that only trainable components are logged
+        logged_pipes = [
+            name
+            for name, proc in nlp.pipeline
+            if hasattr(proc, "is_trainable") and proc.is_trainable()
+        ]
         eval_frequency = nlp.config["training"]["eval_frequency"]
         score_weights = nlp.config["training"]["score_weights"]
         score_cols = [col for col, value in score_weights.items() if value is not None]

@@ -41,19 +45,10 @@ def console_logger(progress_bar: bool = False):
                 if progress is not None:
                     progress.update(1)
                 return
-            try:
-                losses = [
-                    "{0:.2f}".format(float(info["losses"][pipe_name]))
-                    for pipe_name in logged_pipes
-                ]
-            except KeyError as e:
-                raise KeyError(
-                    Errors.E983.format(
-                        dict="scores (losses)",
-                        key=str(e),
-                        keys=list(info["losses"].keys()),
-                    )
-                ) from None
+            losses = [
+                "{0:.2f}".format(float(info["losses"][pipe_name]))
+                for pipe_name in logged_pipes
+            ]
 
             scores = []
             for col in score_cols:
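Because `logged_pipes` is now restricted to trainable components, every logged name is guaranteed to have an entry in `info["losses"]`, which is why the `KeyError` handling above could be dropped. A small illustration with made-up values:

```python
info = {"losses": {"tagger": 12.3456, "parser": 7.8901}}
logged_pipes = ["tagger", "parser"]  # entity_ruler etc. already filtered out
losses = ["{0:.2f}".format(float(info["losses"][name])) for name in logged_pipes]
assert losses == ["12.35", "7.89"]
```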
				
			
@@ -187,10 +187,11 @@ def train_while_improving(
         for name, proc in nlp.pipeline:
             if (
                 name not in exclude
-                and hasattr(proc, "model")
+                and hasattr(proc, "is_trainable")
+                and proc.is_trainable()
                 and proc.model not in (True, False, None)
             ):
-                proc.model.finish_update(optimizer)
+                proc.finish_update(optimizer)
         optimizer.step_schedules()
         if not (step % eval_frequency):
             if optimizer.averages:

@@ -293,7 +294,8 @@ def update_meta(
         if metric is not None:
             nlp.meta["performance"][metric] = info["other_scores"].get(metric, 0.0)
     for pipe_name in nlp.pipe_names:
-        nlp.meta["performance"][f"{pipe_name}_loss"] = info["losses"][pipe_name]
+        if pipe_name in info["losses"]:
+            nlp.meta["performance"][f"{pipe_name}_loss"] = info["losses"][pipe_name]
 
 
 def create_before_to_disk_callback(

@@ -74,6 +74,33 @@ be a token pattern (list) or a phrase pattern (string). For example:
 | `ent_id_sep` | Separator used internally for entity IDs. Defaults to `"||"`. ~~str~~                                 |
 | `patterns`   | Optional patterns to load in on initialization. ~~Optional[List[Dict[str, Union[str, List[dict]]]]]~~ |
 
+## EntityRuler.initialize {#initialize tag="method" new="3"}
+
+Initialize the component with patterns from a file.
+
+> #### Example
+>
+> ```python
+> entity_ruler = nlp.add_pipe("entity_ruler")
+> entity_ruler.initialize(lambda: [], nlp=nlp, patterns=patterns)
+> ```
+>
+> ```ini
+> ### config.cfg
+> [initialize.components.entity_ruler]
+>
+> [initialize.components.entity_ruler.patterns]
+> @readers = "srsly.read_jsonl.v1"
+> path = "corpus/entity_ruler_patterns.jsonl"
+> ```
+
+| Name           | Description                                                                                                                                                           |
+| -------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `get_examples` | Function that returns gold-standard annotations in the form of [`Example`](/api/example) objects. Not used by the `EntityRuler`. ~~Callable[[], Iterable[Example]]~~ |
+| _keyword-only_ |                                                                                                                                                                       |
+| `nlp`          | The current `nlp` object. Defaults to `None`. ~~Optional[Language]~~                                                                                                 |
+| `patterns`     | The list of patterns. Defaults to `None`. ~~Optional[Sequence[Dict[str, Union[str, List[Dict[str, Any]]]]]]~~                                                        |
+
 ## EntityRuler.\_\_len\_\_ {#len tag="method"}
 
 The number of all patterns added to the entity ruler.

@@ -256,6 +283,6 @@ Get all patterns that were added to the entity ruler.
 | Name              | Description                                                                                                            |
 | ----------------- | ---------------------------------------------------------------------------------------------------------------------- |
 | `matcher`         | The underlying matcher used to process token patterns. ~~Matcher~~                                                     |
-| `phrase_matcher`  | The underlying phrase matcher used to process phrase patterns. ~~PhraseMatcher~~                                      |
+| `phrase_matcher`  | The underlying phrase matcher used to process phrase patterns. ~~PhraseMatcher~~                                       |
 | `token_patterns`  | The token patterns present in the entity ruler, keyed by label. ~~Dict[str, List[Dict[str, Union[str, List[dict]]]]~~  |
 | `phrase_patterns` | The phrase patterns present in the entity ruler, keyed by label. ~~Dict[str, List[Doc]]~~                              |

@@ -294,6 +294,24 @@ context, the original parameters are restored.
 | -------- | -------------------------------------------------- |
 | `params` | The parameter values to use in the model. ~~dict~~ |
 
+## Pipe.finish_update {#finish_update tag="method"}
+
+Update parameters using the current parameter gradients. Defaults to calling
+[`self.model.finish_update`](https://thinc.ai/docs/api-model#finish_update).
+
+> #### Example
+>
+> ```python
+> pipe = nlp.add_pipe("your_custom_pipe")
+> optimizer = nlp.initialize()
+> losses = pipe.update(examples, sgd=None)
+> pipe.finish_update(sgd)
+> ```
+
+| Name  | Description                           |
+| ----- | ------------------------------------- |
+| `sgd` | An optimizer. ~~Optional[Optimizer]~~ |
+
 ## Pipe.add_label {#add_label tag="method"}
 
 > #### Example