is_trainable method

svlandeg 2020-10-05 17:43:42 +02:00
parent dc06912c76
commit 4e3ace4b8c
5 changed files with 26 additions and 14 deletions

View File

@@ -1091,7 +1091,8 @@ class Language:
         for name, proc in self.pipeline:
             if (
                 name not in exclude
-                and hasattr(proc, "model")
+                and hasattr(proc, "is_trainable")
+                and proc.is_trainable()
                 and proc.model not in (True, False, None)
             ):
                 proc.finish_update(sgd)
@@ -1297,7 +1298,9 @@ class Language:
         for name, pipe in self.pipeline:
             kwargs = component_cfg.get(name, {})
             kwargs.setdefault("batch_size", batch_size)
-            if not hasattr(pipe, "pipe"):
+            # non-trainable components may have a pipe() implementation that refers to dummy
+            # predict and set_annotations methods
+            if not hasattr(pipe, "pipe") or not hasattr(pipe, "is_trainable") or not pipe.is_trainable():
                 docs = _pipe(docs, pipe, kwargs)
             else:
                 docs = pipe.pipe(docs, **kwargs)
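The two guards above rely on duck typing rather than a base-class check. A minimal sketch of that behaviour (assuming the v3 nightly API at this commit, where `is_trainable` is still a plain method; the component name `noop_marker` is made up for illustration): a stateless function component has no `is_trainable` attribute at all, so the `hasattr` test keeps it out of the trainable code paths.

from spacy.lang.en import English
from spacy.language import Language

@Language.component("noop_marker")  # hypothetical component, for illustration only
def noop_marker(doc):
    return doc

nlp = English()
nlp.add_pipe("noop_marker")
for name, proc in nlp.pipeline:
    # same duck-typing guard as in finish_update() and pipe() above
    print(name, hasattr(proc, "is_trainable") and proc.is_trainable())
# noop_marker False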

View File

@@ -2,8 +2,9 @@ from typing import Optional, Union, List, Dict, Tuple, Iterable, Any, Callable
 from collections import defaultdict
 from pathlib import Path
 import srsly
-from spacy.training import Example
+from .pipe import Pipe
+from ..training import Example
 from ..language import Language
 from ..errors import Errors
 from ..util import ensure_path, to_disk, from_disk, SimpleFrozenList
@@ -51,7 +52,7 @@ def make_entity_ruler(
     )


-class EntityRuler:
+class EntityRuler(Pipe):
     """The EntityRuler lets you add spans to the `Doc.ents` using token-based
     rules or exact phrase matches. It can be combined with the statistical
     `EntityRecognizer` to boost accuracy, or used on its own to implement a
@@ -134,7 +135,6 @@ class EntityRuler:
         DOCS: https://nightly.spacy.io/api/entityruler#call
         """
-        self._require_patterns()
         matches = list(self.matcher(doc)) + list(self.phrase_matcher(doc))
         matches = set(
             [(m_id, start, end) for m_id, start, end in matches if start != end]
@@ -315,11 +315,6 @@
         self.phrase_patterns = defaultdict(list)
         self._ent_ids = defaultdict(dict)

-    def _require_patterns(self) -> None:
-        """Raise an error if the component has no patterns."""
-        if not self.patterns or list(self.patterns) == [""]:
-            raise ValueError(Errors.E900.format(name=self.name))
-
     def _split_label(self, label: str) -> Tuple[str, str]:
         """Split Entity label into ent_label and ent_id if it contains self.ent_id_sep
@@ -348,6 +343,12 @@
         validate_examples(examples, "EntityRuler.score")
         return Scorer.score_spans(examples, "ents", **kwargs)

+    def predict(self, docs):
+        pass
+
+    def set_annotations(self, docs, scores):
+        pass
+
     def from_bytes(
         self, patterns_bytes: bytes, *, exclude: Iterable[str] = SimpleFrozenList()
     ) -> "EntityRuler":

View File

@@ -228,6 +228,9 @@ cdef class Pipe:
     def is_resizable(self):
         return hasattr(self, "model") and "resize_output" in self.model.attrs

+    def is_trainable(self):
+        return hasattr(self, "model") and isinstance(self.model, Model)
+
     def set_output(self, nO):
         if self.is_resizable():
             self.model.attrs["resize_output"](self.model, nO)
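A short sketch of what the helper reports for typical components (again assuming the method form at this commit; later releases expose `is_trainable` as a property): a statistical pipe owns a Thinc `Model`, a rule-based one does not, and components that never subclass `Pipe` simply lack the attribute.

from spacy.lang.en import English

nlp = English()
tagger = nlp.add_pipe("tagger")       # statistical: self.model is a Thinc Model
ruler = nlp.add_pipe("entity_ruler")  # rule-based: no real Model attached

print(tagger.is_trainable())  # True
print(ruler.is_trainable())   # False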

View File

@@ -17,8 +17,12 @@ def console_logger(progress_bar: bool = False):
         nlp: "Language", stdout: IO = sys.stdout, stderr: IO = sys.stderr
     ) -> Tuple[Callable[[Optional[Dict[str, Any]]], None], Callable[[], None]]:
         msg = Printer(no_print=True)
-        # we assume here that only components are enabled that should be trained & logged
-        logged_pipes = nlp.pipe_names
+        # ensure that only trainable components are logged
+        logged_pipes = [
+            name
+            for name, proc in nlp.pipeline
+            if hasattr(proc, "is_trainable") and proc.is_trainable()
+        ]
         eval_frequency = nlp.config["training"]["eval_frequency"]
         score_weights = nlp.config["training"]["score_weights"]
         score_cols = [col for col, value in score_weights.items() if value is not None]
@@ -43,7 +47,7 @@ def console_logger(progress_bar: bool = False):
                 return
             losses = [
                 "{0:.2f}".format(float(info["losses"][pipe_name]))
-                for pipe_name in logged_pipes if pipe_name in info["losses"]
+                for pipe_name in logged_pipes
            ]
            scores = []
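A hedged sketch of the effect on the console logger (hypothetical two-component pipeline, assuming the v3 nightly API at this commit): with a rule-based `entity_ruler` next to a trainable `tagger`, only the tagger survives the filter, so the loss table only gets columns for components that actually report a loss.

from spacy.lang.en import English

nlp = English()
nlp.add_pipe("entity_ruler")  # rule-based, is_trainable() -> False
nlp.add_pipe("tagger")        # trainable, is_trainable() -> True

# same comprehension as in console_logger above
logged_pipes = [
    name
    for name, proc in nlp.pipeline
    if hasattr(proc, "is_trainable") and proc.is_trainable()
]
print(logged_pipes)  # ['tagger']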
scores = []

View File

@@ -181,7 +181,8 @@ def train_while_improving(
         for name, proc in nlp.pipeline:
             if (
                 name not in exclude
-                and hasattr(proc, "model")
+                and hasattr(proc, "is_trainable")
+                and proc.is_trainable()
                 and proc.model not in (True, False, None)
             ):
                 proc.finish_update(optimizer)