mirror of
https://github.com/explosion/spaCy.git
synced 2025-03-06 12:25:48 +03:00
Currently, when a new label is introduced to the NER during training, the labels are read in an unexpected order, which invalidates the model.
80 lines
2.6 KiB
Cython
80 lines
2.6 KiB
Cython
from .syntax.parser cimport Parser
|
|
from .syntax.beam_parser cimport BeamParser
|
|
from .syntax.ner cimport BiluoPushDown
|
|
from .syntax.arc_eager cimport ArcEager
|
|
from .tagger import Tagger
|
|
|
|
# TODO: The disorganization here is pretty embarrassing. At least it's only
|
|
# internals.
|
|
from .syntax.parser import get_templates as get_feature_templates
|
|
from .attrs import DEP, ENT_TYPE
|
|
|
|
|
|
cdef class EntityRecognizer(Parser):
    """Annotate named entities on Doc objects.

    Uses the BiluoPushDown transition system and the 'ner' feature templates.
    """
    TransitionSystem = BiluoPushDown

    feature_templates = get_feature_templates('ner')

    def add_label(self, label):
        """Add a new entity label to the model.

        The label is first registered through ``Parser.add_label``. It is then
        added to the vocab's ENT_TYPE serializer frequency table, unless it is
        already listed there, and the cached serializer is discarded.

        label: the entity label, either as a string or as its StringStore key.
        """
        Parser.add_label(self, label)
        if isinstance(label, basestring):
            label = self.vocab.strings[label]
        # Set label into serializer. Super hacky :(
        # Entries in each freqs table are [value, frequency] pairs (see the
        # append below). Membership must therefore be checked against the
        # first element of each pair. A plain `label not in freqs` compares
        # the label with whole pairs, so it is always true and used to append
        # a duplicate entry on every call.
        for attr, freqs in self.vocab.serializer_freqs:
            if attr == ENT_TYPE and not any(entry[0] == label for entry in freqs):
                freqs.append([label, 1])
        # Discard the cached serializer. Presumably Vocab rebuilds it from
        # serializer_freqs on next use (TODO confirm in Vocab).
        self.vocab._serializer = None
|
|
|
|
|
|
cdef class BeamEntityRecognizer(BeamParser):
    """Annotate named entities on Doc objects.

    BeamParser-based variant of the entity recognizer. It uses the
    BiluoPushDown transition system and the 'ner' feature templates.
    """
    TransitionSystem = BiluoPushDown

    feature_templates = get_feature_templates('ner')

    def add_label(self, label):
        """Add a new entity label to the model.

        The label is first registered through the base class's ``add_label``.
        It is then added to the vocab's ENT_TYPE serializer frequency table,
        unless it is already listed there, and the cached serializer is
        discarded.

        label: the entity label, either as a string or as its StringStore key.
        """
        # Delegate to the actual base class. If BeamParser does not override
        # add_label, this resolves to the inherited Parser.add_label.
        BeamParser.add_label(self, label)
        if isinstance(label, basestring):
            label = self.vocab.strings[label]
        # Set label into serializer. Super hacky :(
        # Entries in each freqs table are [value, frequency] pairs (see the
        # append below). Membership must therefore be checked against the
        # first element of each pair; `label not in freqs` is always true.
        for attr, freqs in self.vocab.serializer_freqs:
            if attr == ENT_TYPE and not any(entry[0] == label for entry in freqs):
                freqs.append([label, 1])
        # Discard the cached serializer. Presumably Vocab rebuilds it from
        # serializer_freqs on next use (TODO confirm in Vocab).
        self.vocab._serializer = None
|
|
|
|
|
|
cdef class DependencyParser(Parser):
    """Dependency parser using the ArcEager transition system.

    Uses the 'basic' feature templates.
    """
    TransitionSystem = ArcEager

    feature_templates = get_feature_templates('basic')

    def add_label(self, label):
        """Add a new dependency label to the model.

        The label is first registered through ``Parser.add_label``. It is then
        added to the vocab's DEP serializer frequency table, unless it is
        already listed there, and the cached serializer is discarded.

        label: the dependency label, either as a string or as its StringStore
        key.
        """
        Parser.add_label(self, label)
        if isinstance(label, basestring):
            label = self.vocab.strings[label]
        # Entries in each freqs table are [value, frequency] pairs (see the
        # append below). Membership must therefore be checked against the
        # first element of each pair; `label not in freqs` is always true and
        # used to append a duplicate entry on every call.
        for attr, freqs in self.vocab.serializer_freqs:
            if attr == DEP and not any(entry[0] == label for entry in freqs):
                freqs.append([label, 1])
        # Super hacky :(
        # Discard the cached serializer. Presumably Vocab rebuilds it from
        # serializer_freqs on next use (TODO confirm in Vocab).
        self.vocab._serializer = None
|
|
|
|
|
|
cdef class BeamDependencyParser(BeamParser):
    """BeamParser-based dependency parser using the ArcEager transition system.

    Uses the 'basic' feature templates.
    """
    TransitionSystem = ArcEager

    feature_templates = get_feature_templates('basic')

    def add_label(self, label):
        """Add a new dependency label to the model.

        The label is first registered through the base class's ``add_label``.
        It is then added to the vocab's DEP serializer frequency table, unless
        it is already listed there, and the cached serializer is discarded.

        label: the dependency label, either as a string or as its StringStore
        key.
        """
        # Delegate to the actual base class. If BeamParser does not override
        # add_label, this resolves to the inherited Parser.add_label.
        BeamParser.add_label(self, label)
        if isinstance(label, basestring):
            label = self.vocab.strings[label]
        # Entries in each freqs table are [value, frequency] pairs (see the
        # append below). Membership must therefore be checked against the
        # first element of each pair; `label not in freqs` is always true.
        for attr, freqs in self.vocab.serializer_freqs:
            if attr == DEP and not any(entry[0] == label for entry in freqs):
                freqs.append([label, 1])
        # Super hacky :(
        # Discard the cached serializer. Presumably Vocab rebuilds it from
        # serializer_freqs on next use (TODO confirm in Vocab).
        self.vocab._serializer = None
|
|
|
|
|
|
# Public names exported by `from ... import *`. These must be strings:
# listing the class objects themselves makes star-imports raise TypeError.
__all__ = ['Tagger', 'DependencyParser', 'EntityRecognizer',
           'BeamDependencyParser', 'BeamEntityRecognizer']
|