mirror of
https://github.com/explosion/spaCy.git
synced 2025-03-06 12:25:48 +03:00
is_trainable method
This commit is contained in:
parent
dc06912c76
commit
4e3ace4b8c
|
@ -1091,7 +1091,8 @@ class Language:
|
|||
for name, proc in self.pipeline:
|
||||
if (
|
||||
name not in exclude
|
||||
and hasattr(proc, "model")
|
||||
and hasattr(proc, "is_trainable")
|
||||
and proc.is_trainable()
|
||||
and proc.model not in (True, False, None)
|
||||
):
|
||||
proc.finish_update(sgd)
|
||||
|
@ -1297,7 +1298,9 @@ class Language:
|
|||
for name, pipe in self.pipeline:
|
||||
kwargs = component_cfg.get(name, {})
|
||||
kwargs.setdefault("batch_size", batch_size)
|
||||
if not hasattr(pipe, "pipe"):
|
||||
# non-trainable components may have a pipe() implementation that refers to dummy
|
||||
# predict and set_annotations methods
|
||||
if not hasattr(pipe, "pipe") or not hasattr(pipe, "is_trainable") or not pipe.is_trainable():
|
||||
docs = _pipe(docs, pipe, kwargs)
|
||||
else:
|
||||
docs = pipe.pipe(docs, **kwargs)
|
||||
|
|
|
@ -2,8 +2,9 @@ from typing import Optional, Union, List, Dict, Tuple, Iterable, Any, Callable
|
|||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
import srsly
|
||||
from spacy.training import Example
|
||||
|
||||
from .pipe import Pipe
|
||||
from ..training import Example
|
||||
from ..language import Language
|
||||
from ..errors import Errors
|
||||
from ..util import ensure_path, to_disk, from_disk, SimpleFrozenList
|
||||
|
@ -51,7 +52,7 @@ def make_entity_ruler(
|
|||
)
|
||||
|
||||
|
||||
class EntityRuler:
|
||||
class EntityRuler(Pipe):
|
||||
"""The EntityRuler lets you add spans to the `Doc.ents` using token-based
|
||||
rules or exact phrase matches. It can be combined with the statistical
|
||||
`EntityRecognizer` to boost accuracy, or used on its own to implement a
|
||||
|
@ -134,7 +135,6 @@ class EntityRuler:
|
|||
|
||||
DOCS: https://nightly.spacy.io/api/entityruler#call
|
||||
"""
|
||||
self._require_patterns()
|
||||
matches = list(self.matcher(doc)) + list(self.phrase_matcher(doc))
|
||||
matches = set(
|
||||
[(m_id, start, end) for m_id, start, end in matches if start != end]
|
||||
|
@ -315,11 +315,6 @@ class EntityRuler:
|
|||
self.phrase_patterns = defaultdict(list)
|
||||
self._ent_ids = defaultdict(dict)
|
||||
|
||||
def _require_patterns(self) -> None:
|
||||
"""Raise an error if the component has no patterns."""
|
||||
if not self.patterns or list(self.patterns) == [""]:
|
||||
raise ValueError(Errors.E900.format(name=self.name))
|
||||
|
||||
def _split_label(self, label: str) -> Tuple[str, str]:
|
||||
"""Split Entity label into ent_label and ent_id if it contains self.ent_id_sep
|
||||
|
||||
|
@ -348,6 +343,12 @@ class EntityRuler:
|
|||
validate_examples(examples, "EntityRuler.score")
|
||||
return Scorer.score_spans(examples, "ents", **kwargs)
|
||||
|
||||
def predict(self, docs):
|
||||
pass
|
||||
|
||||
def set_annotations(self, docs, scores):
|
||||
pass
|
||||
|
||||
def from_bytes(
|
||||
self, patterns_bytes: bytes, *, exclude: Iterable[str] = SimpleFrozenList()
|
||||
) -> "EntityRuler":
|
||||
|
|
|
@ -228,6 +228,9 @@ cdef class Pipe:
|
|||
def is_resizable(self):
|
||||
return hasattr(self, "model") and "resize_output" in self.model.attrs
|
||||
|
||||
def is_trainable(self):
|
||||
return hasattr(self, "model") and isinstance(self.model, Model)
|
||||
|
||||
def set_output(self, nO):
|
||||
if self.is_resizable():
|
||||
self.model.attrs["resize_output"](self.model, nO)
|
||||
|
|
|
@ -17,8 +17,12 @@ def console_logger(progress_bar: bool = False):
|
|||
nlp: "Language", stdout: IO = sys.stdout, stderr: IO = sys.stderr
|
||||
) -> Tuple[Callable[[Optional[Dict[str, Any]]], None], Callable[[], None]]:
|
||||
msg = Printer(no_print=True)
|
||||
# we assume here that only components are enabled that should be trained & logged
|
||||
logged_pipes = nlp.pipe_names
|
||||
# ensure that only trainable components are logged
|
||||
logged_pipes = [
|
||||
name
|
||||
for name, proc in nlp.pipeline
|
||||
if hasattr(proc, "is_trainable") and proc.is_trainable()
|
||||
]
|
||||
eval_frequency = nlp.config["training"]["eval_frequency"]
|
||||
score_weights = nlp.config["training"]["score_weights"]
|
||||
score_cols = [col for col, value in score_weights.items() if value is not None]
|
||||
|
@ -43,7 +47,7 @@ def console_logger(progress_bar: bool = False):
|
|||
return
|
||||
losses = [
|
||||
"{0:.2f}".format(float(info["losses"][pipe_name]))
|
||||
for pipe_name in logged_pipes if pipe_name in info["losses"]
|
||||
for pipe_name in logged_pipes
|
||||
]
|
||||
|
||||
scores = []
|
||||
|
|
|
@ -181,7 +181,8 @@ def train_while_improving(
|
|||
for name, proc in nlp.pipeline:
|
||||
if (
|
||||
name not in exclude
|
||||
and hasattr(proc, "model")
|
||||
and hasattr(proc, "is_trainable")
|
||||
and proc.is_trainable()
|
||||
and proc.model not in (True, False, None)
|
||||
):
|
||||
proc.finish_update(optimizer)
|
||||
|
|
Loading…
Reference in New Issue
Block a user