Refactor the pipeline classes to make them more consistent, and remove the redundant blank() constructor.
parent 311a985fe0
commit f787cd29fe
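
In rough terms, the refactor collapses the parallel blank()/load() class constructors into keyword-driven __init__ methods, so building a fresh component and restoring a saved one look alike. A minimal before/after sketch distilled from the hunks below (vocab and path stand for an existing Vocab and a saved-model directory; the call sites are illustrative, not taken verbatim from this commit):

    # before: blank() built an empty model; load() rebuilt one from disk
    ner = EntityRecognizer.blank(vocab, entity_types=['PERSON', 'LOC'])
    ner = EntityRecognizer.load(path, vocab)

    # after: __init__ builds the blank model from **cfg;
    # load() reads config.json and delegates to __init__
    ner = EntityRecognizer(vocab, entity_types=['PERSON', 'LOC'])
    ner = EntityRecognizer.load(path, vocab, require=False)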
examples/training/train_ner.py (new file, 63 lines)
@@ -0,0 +1,63 @@
+from __future__ import unicode_literals, print_function
+import json
+import pathlib
+import random
+
+import spacy
+from spacy.pipeline import EntityRecognizer
+from spacy.gold import GoldParse
+
+
+def train_ner(nlp, train_data, entity_types):
+    ner = EntityRecognizer(nlp.vocab, entity_types=entity_types)
+    for itn in range(5):
+        random.shuffle(train_data)
+        for raw_text, entity_offsets in train_data:
+            doc = nlp.make_doc(raw_text)
+            gold = GoldParse(doc, entities=entity_offsets)
+            ner.update(doc, gold)
+    ner.model.end_training()
+    return ner
+
+
+def main(model_dir=None):
+    if model_dir is not None:
+        model_dir = pathlib.Path(model_dir)
+        if not model_dir.exists():
+            model_dir.mkdir()
+        assert model_dir.is_dir()
+
+    nlp = spacy.load('en', parser=False, entity=False, vectors=False)
+
+    train_data = [
+        (
+            'Who is Shaka Khan?',
+            [(len('Who is '), len('Who is Shaka Khan'), 'PERSON')]
+        ),
+        (
+            'I like London and Berlin.',
+            [(len('I like '), len('I like London'), 'LOC'),
+             (len('I like London and '), len('I like London and Berlin'), 'LOC')]
+        )
+    ]
+    ner = train_ner(nlp, train_data, ['PERSON', 'LOC'])
+
+    doc = nlp.make_doc('Who is Shaka Khan?')
+    nlp.tagger(doc)
+    ner(doc)
+    for word in doc:
+        print(word.text, word.tag_, word.ent_type_, word.ent_iob)
+
+    if model_dir is not None:
+        with (model_dir / 'config.json').open('wb') as file_:
+            json.dump(ner.cfg, file_)
+        ner.model.dump(str(model_dir / 'model'))
+
+
+if __name__ == '__main__':
+    main()
+    # Who "" 2
+    # is "" 2
+    # Shaka "" PERSON 3
+    # Khan "" PERSON 1
+    # ? "" 2
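
Because main() persists both ner.cfg and the model weights, the trained component can plausibly be restored with the load() path this commit introduces; a hedged sketch (model_dir as written by main() above, vocab reuse assumed):

    import pathlib
    from spacy.pipeline import EntityRecognizer

    # load() reads config.json from the directory, rebuilds the component via
    # __init__, then restores weights from the 'model' file if it exists
    ner = EntityRecognizer.load(pathlib.Path(model_dir), nlp.vocab)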

examples/training/train_parser.py
@@ -10,11 +10,10 @@ from spacy.tokens import Doc


 def train_parser(nlp, train_data, left_labels, right_labels):
-    parser = DependencyParser.blank(
-        nlp.vocab,
-        left_labels=left_labels,
-        right_labels=right_labels,
-        features=nlp.defaults.parser_features)
+    parser = DependencyParser(
+        nlp.vocab,
+        left_labels=left_labels,
+        right_labels=right_labels)
     for itn in range(1000):
         random.shuffle(train_data)
         loss = 0
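
With blank() gone, the example passes only the label sets and leaves feature templates to the library defaults instead of the call site; a short sketch of the new construction (label lists illustrative):

    parser = DependencyParser(nlp.vocab,
                              left_labels=['nsubj', 'dobj'],
                              right_labels=['prep', 'pobj'])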

examples/training/train_tagger.py
@@ -53,7 +53,7 @@ def main(output_dir=None):
     vocab = Vocab(tag_map=TAG_MAP)
     # The default_templates argument is where features are specified. See
     # spacy/tagger.pyx for the defaults.
-    tagger = Tagger.blank(vocab, Tagger.default_templates())
+    tagger = Tagger(vocab)
     for i in range(5):
         for words, tags in DATA:
             doc = Doc(vocab, words=words)

spacy/__init__.py
@@ -24,7 +24,7 @@ def blank(name, vocab=None, tokenizer=None, parser=None, tagger=None, entity=Non
     target_name, target_version = util.split_data_name(name)
     cls = get_lang_class(target_name)
     return cls(
-        path,
+        path=None,
         vectors=vectors,
         vocab=vocab,
         tokenizer=tokenizer,

spacy/language.py
@@ -19,7 +19,6 @@ except NameError:

 from .tokenizer import Tokenizer
 from .vocab import Vocab
-from .syntax.parser import Parser
 from .tagger import Tagger
 from .matcher import Matcher
 from . import attrs

@@ -95,7 +94,9 @@ class BaseDefaults(object):
         if self.path:
             return Tagger.load(self.path / 'pos', vocab)
         else:
-            return Tagger.blank(vocab, Tagger.default_templates())
+            if 'features' not in cfg:
+                cfg['features'] = self.parser_features
+            return Tagger(vocab, **cfg)

     def Parser(self, vocab, **cfg):
         if self.path and (self.path / 'deps').exists():

@@ -103,7 +104,7 @@ class BaseDefaults(object):
         else:
             if 'features' not in cfg:
                 cfg['features'] = self.parser_features
-            return DependencyParser.blank(vocab, **cfg)
+            return DependencyParser(vocab, **cfg)

     def Entity(self, vocab, **cfg):
         if self.path and (self.path / 'ner').exists():

@@ -111,7 +112,7 @@ class BaseDefaults(object):
         else:
             if 'features' not in cfg:
                 cfg['features'] = self.entity_features
-            return EntityRecognizer.blank(vocab, **cfg)
+            return EntityRecognizer(vocab, **cfg)

     def Matcher(self, vocab, **cfg):
         if self.path:
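
Taken together, the BaseDefaults hooks now treat "no saved data" as ordinary construction rather than a special blank() path; a condensed sketch of the flow these hunks produce (the defaults handle mirrors nlp.defaults as used in train_parser.py; custom_templates is an assumed name):

    defaults = nlp.defaults
    parser = defaults.Parser(nlp.vocab)     # cfg['features'] falls back to self.parser_features
    parser = defaults.Parser(nlp.vocab, features=custom_templates)   # explicit templates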
spacy/pipeline.pxd
@@ -1,6 +1,7 @@
 from .syntax.parser cimport Parser
 from .syntax.ner cimport BiluoPushDown
 from .syntax.arc_eager cimport ArcEager
+from .tagger cimport Tagger


 cdef class EntityRecognizer(Parser):

spacy/pipeline.pyx
@@ -2,41 +2,23 @@ from .syntax.parser cimport Parser
 from .syntax.ner cimport BiluoPushDown
 from .syntax.arc_eager cimport ArcEager
 from .vocab cimport Vocab
-from .tagger cimport Tagger
+from .tagger import Tagger
+
+# TODO: The disorganization here is pretty embarrassing. At least it's only
+# internals.
+from .syntax.parser import get_templates as get_feature_templates


 cdef class EntityRecognizer(Parser):
-    @classmethod
-    def load(cls, path, Vocab vocab):
-        return Parser.load(path, vocab, BiluoPushDown)
-
-    @classmethod
-    def blank(cls, Vocab vocab, **cfg):
-        if 'actions' not in cfg:
-            cfg['actions'] = {0: {'': True}, 5: {'': True}}
-        entity_types = cfg.get('entity_types', [''])
-        for action_type in (1, 2, 3, 4):
-            cfg['actions'][action_type] = {ent_type: True for ent_type in entity_types}
-        return Parser.blank(vocab, BiluoPushDown, **cfg)
+    TransitionSystem = BiluoPushDown
+
+    feature_templates = get_feature_templates('ner')


 cdef class DependencyParser(Parser):
-    @classmethod
-    def load(cls, path, Vocab vocab):
-        return Parser.load(path, vocab, ArcEager)
-
-    @classmethod
-    def blank(cls, Vocab vocab, **cfg):
-        if 'actions' not in cfg:
-            cfg['actions'] = {0: {'': True}, 1: {'': True}, 2: {}, 3: {},
-                              4: {'ROOT': True}}
-        for label in cfg.get('left_labels', []):
-            cfg['actions'][2][label] = True
-        for label in cfg.get('right_labels', []):
-            cfg['actions'][3][label] = True
-        for label in cfg.get('break_labels', []):
-            cfg['actions'][4][label] = True
-        return Parser.blank(vocab, ArcEager, **cfg)
+    TransitionSystem = ArcEager
+
+    feature_templates = get_feature_templates('basic')


 __all__ = [Tagger, DependencyParser, EntityRecognizer]
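
After this change a transition-based pipeline component is fully described by two class attributes, with Parser.__init__ (further down) doing the shared work; a hypothetical extra component would follow the same pattern (plain-Python sketch, SemanticAnnotator is an invented name):

    from spacy.syntax.parser import get_templates as get_feature_templates

    class SemanticAnnotator(Parser):              # hypothetical component
        TransitionSystem = ArcEager               # reuse an existing transition system
        feature_templates = get_feature_templates('basic')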

spacy/syntax/arc_eager.pyx
@@ -279,20 +279,36 @@ cdef int _get_root(int word, const GoldParseC* gold) nogil:

 cdef class ArcEager(TransitionSystem):
     @classmethod
-    def get_labels(cls, gold_parses):
-        move_labels = {SHIFT: {'': True}, REDUCE: {'': True}, RIGHT: {},
-                       LEFT: {}, BREAK: {'ROOT': True}}
-        for raw_text, sents in gold_parses:
+    def get_actions(cls, **kwargs):
+        actions = kwargs.get('actions',
+            {
+                SHIFT: {'': True},
+                REDUCE: {'': True},
+                RIGHT: {},
+                LEFT: {},
+                BREAK: {'ROOT': True}})
+        for label in kwargs.get('labels', []):
+            if label.upper() != 'ROOT':
+                actions[LEFT][label] = True
+                actions[RIGHT][label] = True
+        for label in kwargs.get('left_labels', []):
+            if label.upper() != 'ROOT':
+                actions[LEFT][label] = True
+        for label in kwargs.get('right_labels', []):
+            if label.upper() != 'ROOT':
+                actions[RIGHT][label] = True
+
+        for raw_text, sents in kwargs.get('gold_parses', []):
             for (ids, words, tags, heads, labels, iob), ctnts in sents:
                 for child, head, label in zip(ids, heads, labels):
                     if label.upper() == 'ROOT':
                         label = 'ROOT'
                     if label != 'ROOT':
                         if head < child:
-                            move_labels[RIGHT][label] = True
+                            actions[RIGHT][label] = True
                         elif head > child:
-                            move_labels[LEFT][label] = True
-        return move_labels
+                            actions[LEFT][label] = True
+        return actions

     property action_types:
         def __get__(self):
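
get_actions merges the default action table with whatever label sources the caller passes, so one entry point serves both blank construction and gold-driven setup; a quick sketch of the merge (label lists illustrative):

    actions = ArcEager.get_actions(left_labels=['nsubj'], right_labels=['dobj'])
    # actions[LEFT] gains 'nsubj' and actions[RIGHT] gains 'dobj';
    # SHIFT, REDUCE and BREAK keep their defaults from the table above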

spacy/syntax/ner.pyx
@@ -51,11 +51,21 @@ cdef bint _entity_is_sunk(StateClass st, Transition* golds) nogil:

 cdef class BiluoPushDown(TransitionSystem):
     @classmethod
-    def get_labels(cls, gold_tuples):
-        move_labels = {MISSING: {'': True}, BEGIN: {}, IN: {}, LAST: {}, UNIT: {},
-                       OUT: {'': True}}
+    def get_actions(cls, **kwargs):
+        actions = kwargs.get('actions',
+            {
+                MISSING: {'': True},
+                BEGIN: {},
+                IN: {},
+                LAST: {},
+                UNIT: {},
+                OUT: {'': True}
+            })
+        for entity_type in kwargs.get('entity_types', []):
+            for action in (BEGIN, IN, LAST, UNIT):
+                actions[action][entity_type] = True
         moves = ('M', 'B', 'I', 'L', 'U')
-        for raw_text, sents in gold_tuples:
+        for raw_text, sents in kwargs.get('gold_tuples', []):
             for (ids, words, tags, heads, labels, biluo), _ in sents:
                 for i, ner_tag in enumerate(biluo):
                     if ner_tag != 'O' and ner_tag != '-':

@@ -63,8 +73,8 @@ cdef class BiluoPushDown(TransitionSystem):
                         raise ValueError(ner_tag)
                     _, label = ner_tag.split('-')
                     for move_str in ('B', 'I', 'L', 'U'):
-                        move_labels[moves.index(move_str)][label] = True
-        return move_labels
+                        actions[moves.index(move_str)][label] = True
+        return actions

     property action_types:
         def __get__(self):
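
This is the hook that lets EntityRecognizer(vocab, entity_types=[...]) work in the train_ner.py example above: each entity type is registered for the BEGIN/IN/LAST/UNIT moves. A small sketch of the resulting table (entity types illustrative):

    actions = BiluoPushDown.get_actions(entity_types=['PERSON', 'LOC'])
    # BEGIN, IN, LAST and UNIT each map 'PERSON' and 'LOC' to True;
    # MISSING and OUT keep their '' defaults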

spacy/syntax/parser.pyx
@@ -67,10 +67,6 @@ def get_templates(name):
         pf.tree_shape + pf.trigrams)


-def ParserFactory(transition_system):
-    return lambda strings, dir_: Parser(strings, dir_, transition_system)
-
-
 cdef class ParserModel(AveragedPerceptron):
     cdef void set_featuresC(self, ExampleC* eg, const StateC* state) nogil:
         fill_context(eg.atoms, state)

@@ -79,28 +75,26 @@ cdef class ParserModel(AveragedPerceptron):

 cdef class Parser:
     @classmethod
-    def load(cls, path, Vocab vocab, moves_class):
+    def load(cls, path, Vocab vocab, TransitionSystem=None, require=False):
         with (path / 'config.json').open() as file_:
             cfg = json.load(file_)
+        self = cls(vocab, TransitionSystem=TransitionSystem, model=None, **cfg)
+        if (path / 'model').exists():
+            self.model.load(str(path / 'model'))
+        elif require:
+            raise IOError(
+                "Required file %s/model not found when loading" % str(path))
+        return self
+
+    def __init__(self, Vocab vocab, TransitionSystem=None, ParserModel model=None, **cfg):
+        if TransitionSystem is None:
+            TransitionSystem = self.TransitionSystem
+        actions = TransitionSystem.get_actions(**cfg)
+        self.moves = TransitionSystem(vocab.strings, actions)
+        # TODO: Remove this when we no longer need to support old-style models
         if isinstance(cfg.get('features'), basestring):
             cfg['features'] = get_templates(cfg['features'])
-        moves = moves_class(vocab.strings, cfg['actions'])
-        model = ParserModel(cfg['features'])
-        if (path / 'model').exists():
-            model.load(str(path / 'model'))
-        return cls(vocab, moves, model, **cfg)
-
-    @classmethod
-    def blank(cls, Vocab vocab, moves_class, **cfg):
-        moves = moves_class(vocab.strings, cfg.get('actions', {}))
-        templates = cfg.get('features', tuple())
-        model = ParserModel(templates)
-        return cls(vocab, moves, model, **cfg)
-
-    def __init__(self, Vocab vocab, transition_system, ParserModel model, **cfg):
-        self.moves = transition_system
-        self.model = model
+        self.model = ParserModel(cfg['features'])
         self.cfg = cfg

     def __reduce__(self):

@@ -192,8 +186,8 @@ cdef class Parser:
         free(eg.is_valid)
         return 0

-    def update(self, Doc tokens, GoldParse gold):
-        self.moves.preprocess_gold(gold)
+    def update(self, Doc tokens, raw_gold):
+        cdef GoldParse gold = self.preprocess_gold(raw_gold)
         cdef StateClass stcls = StateClass.init(tokens.c, tokens.length)
         self.moves.initialize_state(stcls.c)
         cdef Pool mem = Pool()

@@ -230,6 +224,15 @@ cdef class Parser:
         for action in self.moves.action_types:
             self.moves.add_action(action, label)

+    def preprocess_gold(self, raw_gold):
+        cdef GoldParse gold
+        if isinstance(raw_gold, GoldParse):
+            gold = raw_gold
+            self.moves.preprocess_gold(raw_gold)
+            return gold
+        else:
+            raise ValueError("Parser.preprocess_gold requires GoldParse-type input.")
+

 cdef class StepwiseState:
     cdef readonly StateClass stcls
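
Parser.__init__ is where blank() used to live: it asks the class's TransitionSystem for an action table derived from **cfg and builds the model from cfg['features'], while load() merely rehydrates cfg from config.json and defers to it. A hedged usage sketch (path, left and right are assumed variables):

    # blank component: the TransitionSystem derives actions from cfg,
    # and cfg['features'] supplies the model's feature templates
    parser = DependencyParser(vocab, left_labels=left, right_labels=right,
                              features=get_templates('basic'))

    # saved component: config.json feeds __init__, then weights if present;
    # require=True upgrades a missing 'model' file to an IOError
    parser = DependencyParser.load(path, vocab, require=True)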

spacy/tagger.pyx
@@ -103,58 +103,30 @@ cdef inline void _fill_from_token(atom_t* context, const TokenC* t) nogil:
 cdef class Tagger:
     """A part-of-speech tagger for English"""
     @classmethod
-    def default_templates(cls):
-        return (
-            (W_orth,),
-            (P1_lemma, P1_pos),
-            (P2_lemma, P2_pos),
-            (N1_orth,),
-            (N2_orth,),
-
-            (W_suffix,),
-            (W_prefix,),
-
-            (P1_pos,),
-            (P2_pos,),
-            (P1_pos, P2_pos),
-            (P1_pos, W_orth),
-            (P1_suffix,),
-            (N1_suffix,),
-
-            (W_shape,),
-            (W_cluster,),
-            (N1_cluster,),
-            (N2_cluster,),
-            (P1_cluster,),
-            (P2_cluster,),
-
-            (W_flags,),
-            (N1_flags,),
-            (N2_flags,),
-            (P1_flags,),
-            (P2_flags,),
-        )
-
-    @classmethod
-    def blank(cls, vocab, templates):
-        model = TaggerModel(templates)
-        return cls(vocab, model)
-
-    @classmethod
-    def load(cls, path, vocab):
+    def load(cls, path, vocab, require=False):
         # TODO: Change this to expect config.json when we don't have to
         # support old data.
         path = path if not isinstance(path, basestring) else pathlib.Path(path)
         if (path / 'templates.json').exists():
             with (path / 'templates.json').open() as file_:
                 templates = json.load(file_)
+        elif require:
+            raise IOError(
+                "Required file %s/templates.json not found when loading Tagger" % str(path))
         else:
-            templates = cls.default_templates()
+            templates = cls.feature_templates
+        self = cls(vocab, model=None, feature_templates=templates)

-        model = TaggerModel(templates)
         if (path / 'model').exists():
-            model.load(str(path / 'model'))
-            return cls(vocab, model)
+            self.model.load(str(path / 'model'))
+        elif require:
+            raise IOError(
+                "Required file %s/model not found when loading Tagger" % str(path))
+        return self

-    def __init__(self, Vocab vocab, TaggerModel model, **cfg):
+    def __init__(self, Vocab vocab, TaggerModel model=None, **cfg):
+        if model is None:
+            model = TaggerModel(cfg.get('features', self.feature_templates))
         self.vocab = vocab
         self.model = model
         # TODO: Move this to tag map

@@ -242,3 +214,35 @@ cdef class Tagger:
         tokens.is_tagged = True
         tokens._py_tokens = [None] * tokens.length
         return correct
+
+
+    feature_templates = (
+        (W_orth,),
+        (P1_lemma, P1_pos),
+        (P2_lemma, P2_pos),
+        (N1_orth,),
+        (N2_orth,),
+
+        (W_suffix,),
+        (W_prefix,),
+
+        (P1_pos,),
+        (P2_pos,),
+        (P1_pos, P2_pos),
+        (P1_pos, W_orth),
+        (P1_suffix,),
+        (N1_suffix,),
+
+        (W_shape,),
+        (W_cluster,),
+        (N1_cluster,),
+        (N2_cluster,),
+        (P1_cluster,),
+        (P2_cluster,),
+
+        (W_flags,),
+        (N1_flags,),
+        (N2_flags,),
+        (P1_flags,),
+        (P2_flags,),
+    )
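
With the templates promoted to a feature_templates class attribute, the Tagger has a single construction path; a hedged usage sketch (my_templates and path are assumed names):

    tagger = Tagger(vocab)                            # model built from default feature_templates
    tagger = Tagger(vocab, features=my_templates)     # custom templates via cfg['features']
    tagger = Tagger.load(path, vocab, require=False)  # falls back to defaults if files are absent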