is_trainable method

svlandeg 2020-10-05 17:43:42 +02:00
parent dc06912c76
commit 4e3ace4b8c
5 changed files with 26 additions and 14 deletions

View File

@@ -1091,7 +1091,8 @@ class Language:
         for name, proc in self.pipeline:
             if (
                 name not in exclude
-                and hasattr(proc, "model")
+                and hasattr(proc, "is_trainable")
+                and proc.is_trainable()
                 and proc.model not in (True, False, None)
             ):
                 proc.finish_update(sgd)
@@ -1297,7 +1298,9 @@ class Language:
         for name, pipe in self.pipeline:
             kwargs = component_cfg.get(name, {})
             kwargs.setdefault("batch_size", batch_size)
-            if not hasattr(pipe, "pipe"):
+            # non-trainable components may have a pipe() implementation that refers to dummy
+            # predict and set_annotations methods
+            if not hasattr(pipe, "pipe") or not hasattr(pipe, "is_trainable") or not pipe.is_trainable():
                 docs = _pipe(docs, pipe, kwargs)
             else:
                 docs = pipe.pipe(docs, **kwargs)
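
Note: the guard in the first hunk changed because `hasattr(proc, "model")` alone cannot distinguish a real statistical model from a placeholder attribute. A minimal, self-contained sketch of the pattern (toy classes, not spaCy's actual ones):

# Hypothetical classes illustrating why the check moved from
# hasattr(proc, "model") to is_trainable(): a rule-based component
# may still carry a placeholder `model` attribute.

class TrainablePipe:
    def __init__(self, model):
        self.model = model  # a real statistical model

    def is_trainable(self):
        return self.model is not None  # stand-in for isinstance(self.model, Model)

    def finish_update(self, sgd):
        print("applying accumulated gradients")


class RuleBasedPipe:
    model = True  # placeholder value, no statistical model behind it

    def is_trainable(self):
        return False


pipeline = [("ner", TrainablePipe(object())), ("ruler", RuleBasedPipe())]
for name, proc in pipeline:
    if hasattr(proc, "is_trainable") and proc.is_trainable():
        proc.finish_update(sgd=None)  # only "ner" reaches this line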

View File

@@ -2,8 +2,9 @@ from typing import Optional, Union, List, Dict, Tuple, Iterable, Any, Callable
 from collections import defaultdict
 from pathlib import Path
 import srsly
-from spacy.training import Example
+from .pipe import Pipe
+from ..training import Example
 from ..language import Language
 from ..errors import Errors
 from ..util import ensure_path, to_disk, from_disk, SimpleFrozenList
@@ -51,7 +52,7 @@ def make_entity_ruler(
 )
-class EntityRuler:
+class EntityRuler(Pipe):
     """The EntityRuler lets you add spans to the `Doc.ents` using token-based
     rules or exact phrase matches. It can be combined with the statistical
     `EntityRecognizer` to boost accuracy, or used on its own to implement a
@@ -134,7 +135,6 @@ class EntityRuler:
         DOCS: https://nightly.spacy.io/api/entityruler#call
         """
-        self._require_patterns()
         matches = list(self.matcher(doc)) + list(self.phrase_matcher(doc))
         matches = set(
             [(m_id, start, end) for m_id, start, end in matches if start != end]
@@ -315,11 +315,6 @@ class EntityRuler:
         self.phrase_patterns = defaultdict(list)
         self._ent_ids = defaultdict(dict)

-    def _require_patterns(self) -> None:
-        """Raise an error if the component has no patterns."""
-        if not self.patterns or list(self.patterns) == [""]:
-            raise ValueError(Errors.E900.format(name=self.name))
-
     def _split_label(self, label: str) -> Tuple[str, str]:
         """Split Entity label into ent_label and ent_id if it contains self.ent_id_sep
@@ -348,6 +343,12 @@ class EntityRuler:
         validate_examples(examples, "EntityRuler.score")
         return Scorer.score_spans(examples, "ents", **kwargs)

+    def predict(self, docs):
+        pass
+
+    def set_annotations(self, docs, scores):
+        pass
+
     def from_bytes(
         self, patterns_bytes: bytes, *, exclude: Iterable[str] = SimpleFrozenList()
     ) -> "EntityRuler":

View File

@@ -228,6 +228,9 @@ cdef class Pipe:
     def is_resizable(self):
         return hasattr(self, "model") and "resize_output" in self.model.attrs

+    def is_trainable(self):
+        return hasattr(self, "model") and isinstance(self.model, Model)
+
     def set_output(self, nO):
         if self.is_resizable():
             self.model.attrs["resize_output"](self.model, nO)
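
Note: `is_trainable()` requires both that a `model` attribute exists and that it is an actual Thinc `Model`, which rules out the `True`/`False`/`None` placeholders. A rough stand-alone illustration (toy classes, real Thinc imports):

from thinc.api import Model, Linear


class WithRealModel:
    def __init__(self):
        self.model = Linear(nO=2, nI=2)  # a genuine thinc Model instance


class WithPlaceholder:
    def __init__(self):
        self.model = None  # placeholder, as rule-based components often have


def is_trainable(component) -> bool:
    return hasattr(component, "model") and isinstance(component.model, Model)


print(is_trainable(WithRealModel()))    # True
print(is_trainable(WithPlaceholder()))  # False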

View File

@@ -17,8 +17,12 @@ def console_logger(progress_bar: bool = False):
         nlp: "Language", stdout: IO = sys.stdout, stderr: IO = sys.stderr
     ) -> Tuple[Callable[[Optional[Dict[str, Any]]], None], Callable[[], None]]:
         msg = Printer(no_print=True)
-        # we assume here that only components are enabled that should be trained & logged
-        logged_pipes = nlp.pipe_names
+        # ensure that only trainable components are logged
+        logged_pipes = [
+            name
+            for name, proc in nlp.pipeline
+            if hasattr(proc, "is_trainable") and proc.is_trainable()
+        ]
         eval_frequency = nlp.config["training"]["eval_frequency"]
         score_weights = nlp.config["training"]["score_weights"]
         score_cols = [col for col, value in score_weights.items() if value is not None]
@@ -43,7 +47,7 @@ def console_logger(progress_bar: bool = False):
             return
         losses = [
             "{0:.2f}".format(float(info["losses"][pipe_name]))
-            for pipe_name in logged_pipes if pipe_name in info["losses"]
+            for pipe_name in logged_pipes
         ]
         scores = []
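
Note: because `logged_pipes` is now restricted to trainable components, every logged name is guaranteed to have an entry in `info["losses"]`, which is why the `if pipe_name in info["losses"]` filter could be dropped. A small self-contained sketch of the filtering and formatting (toy stand-ins for the pipeline):

from types import SimpleNamespace

# toy stand-ins for pipeline components
def component(trainable: bool):
    return SimpleNamespace(is_trainable=lambda: trainable)

pipeline = [
    ("tagger", component(True)),
    ("ruler", component(False)),  # rule-based, produces no losses
    ("ner", component(True)),
]
logged_pipes = [
    name
    for name, proc in pipeline
    if hasattr(proc, "is_trainable") and proc.is_trainable()
]
info = {"losses": {"tagger": 12.3456, "ner": 7.8912}}
losses = ["{0:.2f}".format(float(info["losses"][name])) for name in logged_pipes]
print(logged_pipes)  # ['tagger', 'ner']
print(losses)        # ['12.35', '7.89']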

View File

@@ -181,7 +181,8 @@ def train_while_improving(
     for name, proc in nlp.pipeline:
         if (
             name not in exclude
-            and hasattr(proc, "model")
+            and hasattr(proc, "is_trainable")
+            and proc.is_trainable()
             and proc.model not in (True, False, None)
         ):
             proc.finish_update(optimizer)
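
Note: the training loop uses the same guard as `Language.update` above, because gradients accumulate during backprop and are only applied when `finish_update` runs. A sketch of that two-phase update using Thinc's public API (the model and data here are arbitrary examples):

import numpy
from thinc.api import Adam, Linear

model = Linear(nO=2, nI=2)
model.initialize(
    X=numpy.zeros((1, 2), dtype="f"), Y=numpy.zeros((1, 2), dtype="f")
)
optimizer = Adam(0.001)

X = numpy.ones((4, 2), dtype="f")
Y, backprop = model.begin_update(X)
backprop(numpy.ones_like(Y))    # gradients accumulate on the model
model.finish_update(optimizer)  # weights are only updated here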