From 251b3eb4e5c688e076f4e761a43ffbab9ea793b9 Mon Sep 17 00:00:00 2001 From: svlandeg Date: Mon, 5 Oct 2020 14:59:13 +0200 Subject: [PATCH] add initialize method for entity_ruler --- spacy/errors.py | 2 ++ spacy/pipeline/entityruler.py | 30 +++++++++++++++++++++++++++++- spacy/training/initialize.py | 2 +- 3 files changed, 32 insertions(+), 2 deletions(-) diff --git a/spacy/errors.py b/spacy/errors.py index 20edf45b5..18abb6bba 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -456,6 +456,8 @@ class Errors: "issue tracker: http://github.com/explosion/spaCy/issues") # TODO: fix numbering after merging develop into master + E900 = ("Patterns for component '{name}' not initialized. This can be fixed " + "by calling 'add_patterns' or 'initialize'.") E092 = ("The sentence-per-line IOB/IOB2 file is not formatted correctly. " "Try checking whitespace and delimiters. See " "https://nightly.spacy.io/api/cli#convert") diff --git a/spacy/pipeline/entityruler.py b/spacy/pipeline/entityruler.py index 9166a69b8..a4bc098fb 100644 --- a/spacy/pipeline/entityruler.py +++ b/spacy/pipeline/entityruler.py @@ -1,7 +1,8 @@ -from typing import Optional, Union, List, Dict, Tuple, Iterable, Any +from typing import Optional, Union, List, Dict, Tuple, Iterable, Any, Callable from collections import defaultdict from pathlib import Path import srsly +from spacy.training import Example from ..language import Language from ..errors import Errors @@ -133,6 +134,7 @@ class EntityRuler: DOCS: https://nightly.spacy.io/api/entityruler#call """ + self._require_patterns() matches = list(self.matcher(doc)) + list(self.phrase_matcher(doc)) matches = set( [(m_id, start, end) for m_id, start, end in matches if start != end] @@ -183,6 +185,27 @@ class EntityRuler: all_labels.add(l) return tuple(all_labels) + def initialize( + self, + get_examples: Callable[[], Iterable[Example]], + *, + nlp: Optional[Language] = None, + patterns_path: Optional[Path] = None + ): + """Initialize the pipe for training. + + get_examples (Callable[[], Iterable[Example]]): Function that + returns a representative sample of gold-standard Example objects. + nlp (Language): The current nlp object the component is part of. + patterns_path: Path to serialized patterns. + + DOCS (TODO): https://nightly.spacy.io/api/entityruler#initialize + """ + if patterns_path: + patterns = srsly.read_jsonl(patterns_path) + self.add_patterns(patterns) + + @property def ent_ids(self) -> Tuple[str, ...]: """All entity ids present in the match patterns `id` properties @@ -292,6 +315,11 @@ class EntityRuler: self.phrase_patterns = defaultdict(list) self._ent_ids = defaultdict(dict) + def _require_patterns(self) -> None: + """Raise an error if the component has no patterns.""" + if not self.patterns or list(self.patterns) == [""]: + raise ValueError(Errors.E900.format(name=self.name)) + def _split_label(self, label: str) -> Tuple[str, str]: """Split Entity label into ent_label and ent_id if it contains self.ent_id_sep diff --git a/spacy/training/initialize.py b/spacy/training/initialize.py index bbdf4f62b..7c84caf95 100644 --- a/spacy/training/initialize.py +++ b/spacy/training/initialize.py @@ -49,7 +49,7 @@ def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language": nlp.resume_training(sgd=optimizer) with nlp.select_pipes(disable=[*frozen_components, *resume_components]): nlp.initialize(lambda: train_corpus(nlp), sgd=optimizer) - logger.info("Initialized pipeline components") + logger.info(f"Initialized pipeline components: {nlp.pipe_names}") return nlp