Merge pull request #6206 from svlandeg/fix/patterns-init

Commit 568e12215d, authored by Ines Montani on 2020-10-06 10:27:23 +02:00, committed via GitHub.
17 changed files with 156 additions and 39 deletions

View File

@@ -1091,10 +1091,11 @@ class Language:
         for name, proc in self.pipeline:
             if (
                 name not in exclude
-                and hasattr(proc, "model")
+                and hasattr(proc, "is_trainable")
+                and proc.is_trainable()
                 and proc.model not in (True, False, None)
             ):
-                proc.model.finish_update(sgd)
+                proc.finish_update(sgd)
         return losses

     def rehearse(

@@ -1297,7 +1298,9 @@ class Language:
         for name, pipe in self.pipeline:
             kwargs = component_cfg.get(name, {})
             kwargs.setdefault("batch_size", batch_size)
-            if not hasattr(pipe, "pipe"):
+            # non-trainable components may have a pipe() implementation that refers to dummy
+            # predict and set_annotations methods
+            if not hasattr(pipe, "pipe") or not hasattr(pipe, "is_trainable") or not pipe.is_trainable():
                 docs = _pipe(docs, pipe, kwargs)
             else:
                 docs = pipe.pipe(docs, **kwargs)

@@ -1407,7 +1410,9 @@ class Language:
             kwargs = component_cfg.get(name, {})
             # Allow component_cfg to overwrite the top-level kwargs.
             kwargs.setdefault("batch_size", batch_size)
-            if hasattr(proc, "pipe"):
+            # non-trainable components may have a pipe() implementation that refers to dummy
+            # predict and set_annotations methods
+            if hasattr(proc, "pipe") and hasattr(proc, "is_trainable") and proc.is_trainable():
                 f = functools.partial(proc.pipe, **kwargs)
             else:
                 # Apply the function, but yield the doc

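The `hasattr(proc, "is_trainable")` guard above is needed because simple function components carry none of the `Pipe` API. A minimal sketch of such a component, assuming the v3 nightly API at the time of this commit (the name `debug_marker` is illustrative):

```python
from spacy.language import Language

@Language.component("debug_marker")
def debug_marker(doc):
    # A stateless function component: it has no model and no is_trainable(),
    # so Language.update() must skip it instead of calling finish_update().
    return doc
```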
View File

@@ -238,7 +238,7 @@ class EntityLinker(Pipe):
             )
             bp_context(d_scores)
             if sgd is not None:
-                self.model.finish_update(sgd)
+                self.finish_update(sgd)
             losses[self.name] += loss
         if set_annotations:
             self.set_annotations(docs, predictions)

View File

@@ -1,8 +1,10 @@
-from typing import Optional, Union, List, Dict, Tuple, Iterable, Any
+from typing import Optional, Union, List, Dict, Tuple, Iterable, Any, Callable, Sequence
 from collections import defaultdict
 from pathlib import Path
 import srsly

+from .pipe import Pipe
+from ..training import Example
 from ..language import Language
 from ..errors import Errors
 from ..util import ensure_path, to_disk, from_disk, SimpleFrozenList
@@ -50,7 +52,7 @@ def make_entity_ruler(
     )


-class EntityRuler:
+class EntityRuler(Pipe):
     """The EntityRuler lets you add spans to the `Doc.ents` using token-based
     rules or exact phrase matches. It can be combined with the statistical
     `EntityRecognizer` to boost accuracy, or used on its own to implement a
@@ -183,6 +185,26 @@ class EntityRuler:
                 all_labels.add(l)
         return tuple(all_labels)

+    def initialize(
+        self,
+        get_examples: Callable[[], Iterable[Example]],
+        *,
+        nlp: Optional[Language] = None,
+        patterns: Optional[Sequence[PatternType]] = None,
+    ):
+        """Initialize the pipe for training.
+
+        get_examples (Callable[[], Iterable[Example]]): Function that
+            returns a representative sample of gold-standard Example objects.
+        nlp (Language): The current nlp object the component is part of.
+        patterns (Optional[Sequence[PatternType]]): The list of patterns.
+
+        DOCS: https://nightly.spacy.io/api/entityruler#initialize
+        """
+        if patterns:
+            self.add_patterns(patterns)
+
     @property
     def ent_ids(self) -> Tuple[str, ...]:
         """All entity ids present in the match patterns `id` properties
@@ -320,6 +342,12 @@ class EntityRuler:
         validate_examples(examples, "EntityRuler.score")
         return Scorer.score_spans(examples, "ents", **kwargs)

+    def predict(self, docs):
+        pass
+
+    def set_annotations(self, docs, scores):
+        pass
+
     def from_bytes(
         self, patterns_bytes: bytes, *, exclude: Iterable[str] = SimpleFrozenList()
     ) -> "EntityRuler":

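A minimal usage sketch of the new `initialize` hook, mirroring the tests added in this commit (the `ORG` pattern is illustrative):

```python
import spacy

nlp = spacy.blank("en")
ruler = nlp.add_pipe("entity_ruler")
# get_examples is part of the shared initialize signature but unused here
ruler.initialize(
    lambda: [],
    nlp=nlp,
    patterns=[{"label": "ORG", "pattern": "spaCy"}],
)
doc = nlp("spaCy is a library")
assert doc.ents[0].label_ == "ORG"
```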
View File

@@ -209,7 +209,7 @@ class ClozeMultitask(Pipe):
         loss, d_predictions = self.get_loss(examples, self.vocab.vectors.data, predictions)
         bp_predictions(d_predictions)
         if sgd is not None:
-            self.model.finish_update(sgd)
+            self.finish_update(sgd)
         if losses is not None:
             losses[self.name] += loss
         return losses

View File

@@ -132,7 +132,7 @@ cdef class Pipe:
         loss, d_scores = self.get_loss(examples, scores)
         bp_scores(d_scores)
         if sgd not in (None, False):
-            self.model.finish_update(sgd)
+            self.finish_update(sgd)
         losses[self.name] += loss
         if set_annotations:
             docs = [eg.predicted for eg in examples]

@@ -228,6 +228,9 @@ cdef class Pipe:
     def is_resizable(self):
         return hasattr(self, "model") and "resize_output" in self.model.attrs

+    def is_trainable(self):
+        return hasattr(self, "model") and isinstance(self.model, Model)
+
     def set_output(self, nO):
         if self.is_resizable():
             self.model.attrs["resize_output"](self.model, nO)

@@ -245,6 +248,17 @@ cdef class Pipe:
         with self.model.use_params(params):
             yield

+    def finish_update(self, sgd):
+        """Update parameters using the current parameter gradients.
+        The Optimizer instance contains the functionality to perform
+        the stochastic gradient descent.
+
+        sgd (thinc.api.Optimizer): The optimizer.
+
+        DOCS: https://nightly.spacy.io/api/pipe#finish_update
+        """
+        self.model.finish_update(sgd)
+
     def score(self, examples, **kwargs):
         """Score a batch of examples.

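A plain-Python paraphrase of the two new base-class defaults; the shipped implementation is the Cython code above, so this is only a sketch:

```python
from thinc.api import Model

class PipeSketch:
    def is_trainable(self):
        # Trainable iff the component carries a real Thinc Model; some
        # components use True/False/None as placeholder "model" flags.
        return hasattr(self, "model") and isinstance(self.model, Model)

    def finish_update(self, sgd):
        # Apply accumulated gradients via the optimizer; subclasses can
        # override this to customise how their parameters are updated.
        self.model.finish_update(sgd)
```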
View File

@@ -203,7 +203,7 @@ class Tagger(Pipe):
         loss, d_tag_scores = self.get_loss(examples, tag_scores)
         bp_tag_scores(d_tag_scores)
         if sgd not in (None, False):
-            self.model.finish_update(sgd)
+            self.finish_update(sgd)
         losses[self.name] += loss
         if set_annotations:

@@ -238,7 +238,7 @@ class Tagger(Pipe):
         target = self._rehearsal_model(examples)
         gradient = guesses - target
         backprop(gradient)
-        self.model.finish_update(sgd)
+        self.finish_update(sgd)
         if losses is not None:
             losses.setdefault(self.name, 0.0)
             losses[self.name] += (gradient**2).sum()

View File

@@ -212,7 +212,7 @@ class TextCategorizer(Pipe):
         loss, d_scores = self.get_loss(examples, scores)
         bp_scores(d_scores)
         if sgd is not None:
-            self.model.finish_update(sgd)
+            self.finish_update(sgd)
         losses[self.name] += loss
         if set_annotations:
             docs = [eg.predicted for eg in examples]

@@ -256,7 +256,7 @@ class TextCategorizer(Pipe):
         gradient = scores - target
         bp_scores(gradient)
         if sgd is not None:
-            self.model.finish_update(sgd)
+            self.finish_update(sgd)
         if losses is not None:
             losses[self.name] += (gradient ** 2).sum()
         return losses

View File

@@ -188,7 +188,7 @@ class Tok2Vec(Pipe):
             accumulate_gradient(one_d_tokvecs)
             d_docs = bp_tokvecs(d_tokvecs)
             if sgd is not None:
-                self.model.finish_update(sgd)
+                self.finish_update(sgd)
             return d_docs

         batch_id = Tok2VecListener.get_batch_id(docs)

View File

@@ -315,7 +315,7 @@ cdef class Parser(Pipe):
         backprop_tok2vec(golds)
         if sgd not in (None, False):
-            self.model.finish_update(sgd)
+            self.finish_update(sgd)
         if set_annotations:
             docs = [eg.predicted for eg in examples]
             self.set_annotations(docs, all_states)

@@ -367,7 +367,7 @@ cdef class Parser(Pipe):
         # Do the backprop
         backprop_tok2vec(docs)
         if sgd is not None:
-            self.model.finish_update(sgd)
+            self.finish_update(sgd)
         losses[self.name] += loss / n_scores
         del backprop
         del backprop_tok2vec

@@ -437,7 +437,9 @@ cdef class Parser(Pipe):
         for name, component in nlp.pipeline:
             if component is self:
                 break
-            if hasattr(component, "pipe"):
+            # non-trainable components may have a pipe() implementation that refers to dummy
+            # predict and set_annotations methods
+            if hasattr(component, "pipe") and hasattr(component, "is_trainable") and component.is_trainable():
                 doc_sample = list(component.pipe(doc_sample, batch_size=8))
             else:
                 doc_sample = [component(doc) for doc in doc_sample]

View File

@@ -119,7 +119,7 @@ def validate_init_settings(
        if types don't match or required values are missing.

    func (Callable): The initialize method of a given component etc.
-    settings (Dict[str, Any]): The settings from the repsective [initialize] block.
+    settings (Dict[str, Any]): The settings from the respective [initialize] block.
    section (str): Initialize section, for error message.
    name (str): Name of the block in the section.
    exclude (Iterable[str]): Parameter names to exclude from schema.

View File

@@ -121,7 +121,7 @@ def test_attributeruler_init_patterns(nlp, pattern_dicts):
     assert doc.has_annotation("LEMMA")
     assert doc.has_annotation("MORPH")
     nlp.remove_pipe("attribute_ruler")
-    # initialize with patterns from asset
+    # initialize with patterns from misc registry
     nlp.config["initialize"]["components"]["attribute_ruler"] = {
         "patterns": {"@misc": "attribute_ruler_patterns"}
     }

View File

@@ -1,4 +1,6 @@
 import pytest
+
+from spacy import registry
 from spacy.tokens import Span
 from spacy.language import Language
 from spacy.pipeline import EntityRuler

@@ -11,6 +13,7 @@ def nlp():


 @pytest.fixture
+@registry.misc("entity_ruler_patterns")
 def patterns():
     return [
         {"label": "HELLO", "pattern": "hello world"},

@@ -42,6 +45,29 @@ def test_entity_ruler_init(nlp, patterns):
     assert doc.ents[1].label_ == "BYE"


+def test_entity_ruler_init_patterns(nlp, patterns):
+    # initialize with patterns
+    ruler = nlp.add_pipe("entity_ruler")
+    assert len(ruler.labels) == 0
+    ruler.initialize(lambda: [], patterns=patterns)
+    assert len(ruler.labels) == 4
+    doc = nlp("hello world bye bye")
+    assert doc.ents[0].label_ == "HELLO"
+    assert doc.ents[1].label_ == "BYE"
+    nlp.remove_pipe("entity_ruler")
+    # initialize with patterns from misc registry
+    nlp.config["initialize"]["components"]["entity_ruler"] = {
+        "patterns": {"@misc": "entity_ruler_patterns"}
+    }
+    ruler = nlp.add_pipe("entity_ruler")
+    assert len(ruler.labels) == 0
+    nlp.initialize()
+    assert len(ruler.labels) == 4
+    doc = nlp("hello world bye bye")
+    assert doc.ents[0].label_ == "HELLO"
+    assert doc.ents[1].label_ == "BYE"
+
+
 def test_entity_ruler_existing(nlp, patterns):
     ruler = nlp.add_pipe("entity_ruler")
     ruler.add_patterns(patterns)

View File

@@ -49,7 +49,7 @@ def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
             nlp.resume_training(sgd=optimizer)
     with nlp.select_pipes(disable=[*frozen_components, *resume_components]):
         nlp.initialize(lambda: train_corpus(nlp), sgd=optimizer)
-    logger.info("Initialized pipeline components")
+    logger.info(f"Initialized pipeline components: {nlp.pipe_names}")
     return nlp

View File

@@ -17,8 +17,12 @@ def console_logger(progress_bar: bool = False):
         nlp: "Language", stdout: IO = sys.stdout, stderr: IO = sys.stderr
     ) -> Tuple[Callable[[Optional[Dict[str, Any]]], None], Callable[[], None]]:
         msg = Printer(no_print=True)
-        # we assume here that only components are enabled that should be trained & logged
-        logged_pipes = nlp.pipe_names
+        # ensure that only trainable components are logged
+        logged_pipes = [
+            name
+            for name, proc in nlp.pipeline
+            if hasattr(proc, "is_trainable") and proc.is_trainable()
+        ]
         eval_frequency = nlp.config["training"]["eval_frequency"]
         score_weights = nlp.config["training"]["score_weights"]
         score_cols = [col for col, value in score_weights.items() if value is not None]

@@ -41,19 +45,10 @@ def console_logger(progress_bar: bool = False):
                 if progress is not None:
                     progress.update(1)
                 return
-            try:
-                losses = [
-                    "{0:.2f}".format(float(info["losses"][pipe_name]))
-                    for pipe_name in logged_pipes
-                ]
-            except KeyError as e:
-                raise KeyError(
-                    Errors.E983.format(
-                        dict="scores (losses)",
-                        key=str(e),
-                        keys=list(info["losses"].keys()),
-                    )
-                ) from None
+            losses = [
+                "{0:.2f}".format(float(info["losses"][pipe_name]))
+                for pipe_name in logged_pipes
+            ]

             scores = []
             for col in score_cols:

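With the filter above, a pipeline that mixes rule-based and trainable components only logs loss columns for the trainable ones, which is also why the KeyError fallback could be dropped. A sketch, assuming the nightly API where `is_trainable` is a method rather than a property:

```python
import spacy

nlp = spacy.blank("en")
nlp.add_pipe("entity_ruler")  # rule-based: no Thinc model attached
nlp.add_pipe("tagger")        # trainable: constructed with a Thinc model
logged_pipes = [
    name
    for name, proc in nlp.pipeline
    if hasattr(proc, "is_trainable") and proc.is_trainable()
]
assert logged_pipes == ["tagger"]
```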
View File

@@ -187,10 +187,11 @@ def train_while_improving(
         for name, proc in nlp.pipeline:
             if (
                 name not in exclude
-                and hasattr(proc, "model")
+                and hasattr(proc, "is_trainable")
+                and proc.is_trainable()
                 and proc.model not in (True, False, None)
             ):
-                proc.model.finish_update(optimizer)
+                proc.finish_update(optimizer)
         optimizer.step_schedules()
         if not (step % eval_frequency):
             if optimizer.averages:

@@ -293,6 +294,7 @@ def update_meta(
     if metric is not None:
         nlp.meta["performance"][metric] = info["other_scores"].get(metric, 0.0)
     for pipe_name in nlp.pipe_names:
-        nlp.meta["performance"][f"{pipe_name}_loss"] = info["losses"][pipe_name]
+        if pipe_name in info["losses"]:
+            nlp.meta["performance"][f"{pipe_name}_loss"] = info["losses"][pipe_name]

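The new `if pipe_name in info["losses"]` guard matters because non-trainable pipes no longer contribute a loss entry. A tiny worked sketch of the guarded behaviour (the data is made up):

```python
info = {"losses": {"tagger": 12.5}}
pipe_names = ["entity_ruler", "tagger"]  # entity_ruler produced no loss
performance = {}
for pipe_name in pipe_names:
    if pipe_name in info["losses"]:
        performance[f"{pipe_name}_loss"] = info["losses"][pipe_name]
assert performance == {"tagger_loss": 12.5}
```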
View File

@@ -74,6 +74,33 @@ be a token pattern (list) or a phrase pattern (string). For example:
 | `ent_id_sep` | Separator used internally for entity IDs. Defaults to `"||"`. ~~str~~                                 |
 | `patterns`   | Optional patterns to load in on initialization. ~~Optional[List[Dict[str, Union[str, List[dict]]]]]~~ |

+## EntityRuler.initialize {#initialize tag="method" new="3"}
+
+Initialize the component with patterns from a file.
+
+> #### Example
+>
+> ```python
+> entity_ruler = nlp.add_pipe("entity_ruler")
+> entity_ruler.initialize(lambda: [], nlp=nlp, patterns=patterns)
+> ```
+>
+> ```ini
+> ### config.cfg
+> [initialize.components.entity_ruler]
+>
+> [initialize.components.entity_ruler.patterns]
+> @readers = "srsly.read_jsonl.v1"
+> path = "corpus/entity_ruler_patterns.jsonl"
+> ```
+
+| Name           | Description                                                                                                                                                           |
+| -------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `get_examples` | Function that returns gold-standard annotations in the form of [`Example`](/api/example) objects. Not used by the `EntityRuler`. ~~Callable[[], Iterable[Example]]~~ |
+| _keyword-only_ |                                                                                                                                                                       |
+| `nlp`          | The current `nlp` object. Defaults to `None`. ~~Optional[Language]~~                                                                                                  |
+| `patterns`     | The list of patterns. Defaults to `None`. ~~Optional[Sequence[Dict[str, Union[str, List[Dict[str, Any]]]]]]~~                                                        |
+
 ## EntityRuler.\_\_len\_\_ {#len tag="method"}

 The number of all patterns added to the entity ruler.

View File

@@ -294,6 +294,24 @@ context, the original parameters are restored.
 | -------- | -------------------------------------------------- |
 | `params` | The parameter values to use in the model. ~~dict~~ |

+## Pipe.finish_update {#finish_update tag="method"}
+
+Update parameters using the current parameter gradients. Defaults to calling
+[`self.model.finish_update`](https://thinc.ai/docs/api-model#finish_update).
+
+> #### Example
+>
+> ```python
+> pipe = nlp.add_pipe("your_custom_pipe")
+> optimizer = nlp.initialize()
+> losses = pipe.update(examples, sgd=None)
+> pipe.finish_update(optimizer)
+> ```
+
+| Name  | Description                           |
+| ----- | ------------------------------------- |
+| `sgd` | An optimizer. ~~Optional[Optimizer]~~ |
+
 ## Pipe.add_label {#add_label tag="method"}

 > #### Example