mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-10 16:22:29 +03:00
add initialize method for entity_ruler
This commit is contained in:
parent
e3acad6264
commit
251b3eb4e5
|
@ -456,6 +456,8 @@ class Errors:
|
||||||
"issue tracker: http://github.com/explosion/spaCy/issues")
|
"issue tracker: http://github.com/explosion/spaCy/issues")
|
||||||
|
|
||||||
# TODO: fix numbering after merging develop into master
|
# TODO: fix numbering after merging develop into master
|
||||||
|
E900 = ("Patterns for component '{name}' not initialized. This can be fixed "
|
||||||
|
"by calling 'add_patterns' or 'initialize'.")
|
||||||
E092 = ("The sentence-per-line IOB/IOB2 file is not formatted correctly. "
|
E092 = ("The sentence-per-line IOB/IOB2 file is not formatted correctly. "
|
||||||
"Try checking whitespace and delimiters. See "
|
"Try checking whitespace and delimiters. See "
|
||||||
"https://nightly.spacy.io/api/cli#convert")
|
"https://nightly.spacy.io/api/cli#convert")
|
||||||
|
|
|
@ -1,7 +1,8 @@
|
||||||
from typing import Optional, Union, List, Dict, Tuple, Iterable, Any
|
from typing import Optional, Union, List, Dict, Tuple, Iterable, Any, Callable
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import srsly
|
import srsly
|
||||||
|
from spacy.training import Example
|
||||||
|
|
||||||
from ..language import Language
|
from ..language import Language
|
||||||
from ..errors import Errors
|
from ..errors import Errors
|
||||||
|
@ -133,6 +134,7 @@ class EntityRuler:
|
||||||
|
|
||||||
DOCS: https://nightly.spacy.io/api/entityruler#call
|
DOCS: https://nightly.spacy.io/api/entityruler#call
|
||||||
"""
|
"""
|
||||||
|
self._require_patterns()
|
||||||
matches = list(self.matcher(doc)) + list(self.phrase_matcher(doc))
|
matches = list(self.matcher(doc)) + list(self.phrase_matcher(doc))
|
||||||
matches = set(
|
matches = set(
|
||||||
[(m_id, start, end) for m_id, start, end in matches if start != end]
|
[(m_id, start, end) for m_id, start, end in matches if start != end]
|
||||||
|
@ -183,6 +185,27 @@ class EntityRuler:
|
||||||
all_labels.add(l)
|
all_labels.add(l)
|
||||||
return tuple(all_labels)
|
return tuple(all_labels)
|
||||||
|
|
||||||
|
def initialize(
|
||||||
|
self,
|
||||||
|
get_examples: Callable[[], Iterable[Example]],
|
||||||
|
*,
|
||||||
|
nlp: Optional[Language] = None,
|
||||||
|
patterns_path: Optional[Path] = None
|
||||||
|
):
|
||||||
|
"""Initialize the pipe for training.
|
||||||
|
|
||||||
|
get_examples (Callable[[], Iterable[Example]]): Function that
|
||||||
|
returns a representative sample of gold-standard Example objects.
|
||||||
|
nlp (Language): The current nlp object the component is part of.
|
||||||
|
patterns_path: Path to serialized patterns.
|
||||||
|
|
||||||
|
DOCS (TODO): https://nightly.spacy.io/api/entityruler#initialize
|
||||||
|
"""
|
||||||
|
if patterns_path:
|
||||||
|
patterns = srsly.read_jsonl(patterns_path)
|
||||||
|
self.add_patterns(patterns)
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def ent_ids(self) -> Tuple[str, ...]:
|
def ent_ids(self) -> Tuple[str, ...]:
|
||||||
"""All entity ids present in the match patterns `id` properties
|
"""All entity ids present in the match patterns `id` properties
|
||||||
|
@ -292,6 +315,11 @@ class EntityRuler:
|
||||||
self.phrase_patterns = defaultdict(list)
|
self.phrase_patterns = defaultdict(list)
|
||||||
self._ent_ids = defaultdict(dict)
|
self._ent_ids = defaultdict(dict)
|
||||||
|
|
||||||
|
def _require_patterns(self) -> None:
|
||||||
|
"""Raise an error if the component has no patterns."""
|
||||||
|
if not self.patterns or list(self.patterns) == [""]:
|
||||||
|
raise ValueError(Errors.E900.format(name=self.name))
|
||||||
|
|
||||||
def _split_label(self, label: str) -> Tuple[str, str]:
|
def _split_label(self, label: str) -> Tuple[str, str]:
|
||||||
"""Split Entity label into ent_label and ent_id if it contains self.ent_id_sep
|
"""Split Entity label into ent_label and ent_id if it contains self.ent_id_sep
|
||||||
|
|
||||||
|
|
|
@ -49,7 +49,7 @@ def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
|
||||||
nlp.resume_training(sgd=optimizer)
|
nlp.resume_training(sgd=optimizer)
|
||||||
with nlp.select_pipes(disable=[*frozen_components, *resume_components]):
|
with nlp.select_pipes(disable=[*frozen_components, *resume_components]):
|
||||||
nlp.initialize(lambda: train_corpus(nlp), sgd=optimizer)
|
nlp.initialize(lambda: train_corpus(nlp), sgd=optimizer)
|
||||||
logger.info("Initialized pipeline components")
|
logger.info(f"Initialized pipeline components: {nlp.pipe_names}")
|
||||||
return nlp
|
return nlp
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user