mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-22 22:20:08 +03:00
133 lines
4.0 KiB
Cython
133 lines
4.0 KiB
Cython
# cython: infer_types=True, binding=True
|
|
import importlib
|
|
import sys
|
|
from collections import defaultdict
|
|
from typing import Callable, Optional
|
|
|
|
from thinc.api import Config, Model
|
|
|
|
from ._parser_internals.transition_system import TransitionSystem
|
|
|
|
from ._parser_internals.ner cimport BiluoPushDown
|
|
from .transition_parser cimport Parser
|
|
|
|
from ..language import Language
|
|
from ..scorer import get_ner_prf
|
|
from ..training import remove_bilu_prefix
|
|
from ..util import registry
|
|
|
|
# Default model settings for the NER component, written as config text.
# NOTE: this string is parsed when the module is imported (see below), so
# its contents are runtime data; edit it with the same care as code.
default_model_config = """
[model]
@architectures = "spacy.TransitionBasedParser.v2"
state_type = "ner"
extra_state_tokens = false
hidden_width = 64
maxout_pieces = 2
use_upper = true

[model.tok2vec]
@architectures = "spacy.HashEmbedCNN.v2"
pretrained_vectors = null
width = 96
depth = 4
embed_size = 2000
window_size = 1
maxout_pieces = 3
subword_features = true
"""
# Parse the text with thinc's Config and keep only its "model" section.
DEFAULT_NER_MODEL = Config().from_str(default_model_config)["model"]
|
|
|
|
|
|
def ner_score(examples, **kwargs):
    """Score named-entity predictions for a batch of examples.

    Thin wrapper: both ``examples`` and any keyword arguments are handed
    unchanged to ``get_ner_prf``, and its result is returned as-is.
    """
    prf_scores = get_ner_prf(examples, **kwargs)
    return prf_scores
|
|
|
|
|
|
def make_ner_scorer():
    """Return the scoring callable used for NER (``ner_score``)."""
    scorer = ner_score
    return scorer
|
|
|
|
|
|
cdef class EntityRecognizer(Parser):
    """Pipeline component for named entity recognition.

    Subclasses the transition-based ``Parser`` and uses ``BiluoPushDown``
    as its transition system.

    DOCS: https://spacy.io/api/entityrecognizer
    """
    TransitionSystem = BiluoPushDown

    def __init__(
        self,
        vocab,
        model,
        name="ner",
        moves=None,
        *,
        update_with_oracle_cut_size=100,
        beam_width=1,
        beam_density=0.0,
        beam_update_prob=0.0,
        multitasks=tuple(),
        incorrect_spans_key=None,
        scorer=ner_score,
    ):
        """Create an EntityRecognizer.

        Every argument is forwarded unchanged to ``Parser.__init__``. The two
        parser settings that are not exposed here (``min_action_freq`` and
        ``learn_tokens``) are pinned to fixed values.
        """
        super().__init__(
            vocab,
            model,
            name,
            moves,
            # Settings fixed for this component.
            min_action_freq=1,  # not relevant for NER
            learn_tokens=False,  # not relevant for NER
            # Settings passed through from the caller.
            update_with_oracle_cut_size=update_with_oracle_cut_size,
            beam_width=beam_width,
            beam_density=beam_density,
            beam_update_prob=beam_update_prob,
            multitasks=multitasks,
            incorrect_spans_key=incorrect_spans_key,
            scorer=scorer,
        )

    def add_multitask_objective(self, mt_component):
        """Append ``mt_component`` to the multi-task objectives. Experimental."""
        self._multitasks.append(mt_component)

    def init_multitask_objectives(self, get_examples, nlp=None, **cfg):
        """Initialize the multi-task objective components. Experimental and internal.

        Each registered component has the output dimension ("nO") of its model
        set to the number of entity labels before being initialized with
        ``get_examples`` and ``nlp``. Extra keyword arguments in ``cfg`` are
        accepted but not used by this method.
        """
        # TODO: transfer self.model.get_ref("tok2vec") to the multitask's model ?
        for labeller in self._multitasks:
            mt_model = labeller.model
            mt_model.set_dim("nO", len(self.labels))
            if mt_model.has_ref("output_layer"):
                output_layer = mt_model.get_ref("output_layer")
                output_layer.set_dim("nO", len(self.labels))
            labeller.initialize(get_examples, nlp=nlp)

    @property
    def labels(self):
        """Sorted tuple of the distinct entity labels known to the component."""
        # The labels are derived from the names of the available moves, e.g.
        # B-PERSON, I-PERSON, L-PERSON, U-PERSON; move names that start with
        # any other character are ignored.
        entity_prefixes = ("B", "I", "L", "U")
        unique_labels = {
            remove_bilu_prefix(move_name)
            for move_name in self.move_names
            if move_name[0] in entity_prefixes
        }
        return tuple(sorted(unique_labels))

    def scored_ents(self, beams):
        """Return one score dictionary per beam/doc that was processed.

        Each dictionary maps a (start, end, label) tuple to the sum of the
        scores of all parses in that beam which contain the entity. The
        dictionaries are returned in a list, in the same order as ``beams``.
        """
        per_beam_scores = []
        for beam in beams:
            totals = defaultdict(float)
            for parse_score, parse_ents in self.moves.get_beam_parses(beam):
                for start, end, label in parse_ents:
                    ent_key = (start, end, label)
                    totals[ent_key] += parse_score
            per_beam_scores.append(totals)
        return per_beam_scores
|
|
|
|
|
|
# Setup backwards compatibility hook for factories
|
|
def __getattr__(name):
    """Module-level attribute hook that resolves the legacy factory names.

    ``make_ner`` and ``make_beam_ner`` are looked up on
    ``spacy.pipeline.factories``, which is imported only when one of them is
    requested. Any other name raises ``AttributeError``.
    """
    if name in ("make_ner", "make_beam_ner"):
        factories = importlib.import_module("spacy.pipeline.factories")
        return getattr(factories, name)
    raise AttributeError(f"module {__name__} has no attribute {name}")
|