Refactor the pipeline classes to make them more consistent, and remove the redundant blank() constructor.

This commit is contained in:
Matthew Honnibal 2016-10-16 21:34:57 +02:00
parent 311a985fe0
commit f787cd29fe
11 changed files with 199 additions and 120 deletions

View File

@ -0,0 +1,63 @@
from __future__ import unicode_literals, print_function
import json
import pathlib
import random
import spacy
from spacy.pipeline import EntityRecognizer
from spacy.gold import GoldParse
def train_ner(nlp, train_data, entity_types):
ner = EntityRecognizer(nlp.vocab, entity_types=entity_types)
for itn in range(5):
random.shuffle(train_data)
for raw_text, entity_offsets in train_data:
doc = nlp.make_doc(raw_text)
gold = GoldParse(doc, entities=entity_offsets)
ner.update(doc, gold)
ner.model.end_training()
return ner
def main(model_dir=None):
if model_dir is not None:
model_dir = pathlb.Path(model_dir)
if not model_dir.exists():
model_dir.mkdir()
assert model_dir.isdir()
nlp = spacy.load('en', parser=False, entity=False, vectors=False)
train_data = [
(
'Who is Shaka Khan?',
[(len('Who is '), len('Who is Shaka Khan'), 'PERSON')]
),
(
'I like London and Berlin.',
[(len('I like '), len('I like London'), 'LOC'),
(len('I like London and '), len('I like London and Berlin'), 'LOC')]
)
]
ner = train_ner(nlp, train_data, ['PERSON', 'LOC'])
doc = nlp.make_doc('Who is Shaka Khan?')
nlp.tagger(doc)
ner(doc)
for word in doc:
print(word.text, word.tag_, word.ent_type_, word.ent_iob)
if model_dir is not None:
with (model_dir / 'config.json').open('wb') as file_:
json.dump(ner.cfg, file_)
ner.model.dump(str(model_dir / 'model'))
if __name__ == '__main__':
main()
# Who "" 2
# is "" 2
# Shaka "" PERSON 3
# Khan "" PERSON 1
# ? "" 2

View File

@ -10,11 +10,10 @@ from spacy.tokens import Doc
def train_parser(nlp, train_data, left_labels, right_labels): def train_parser(nlp, train_data, left_labels, right_labels):
parser = DependencyParser.blank( parser = DependencyParser(
nlp.vocab, nlp.vocab,
left_labels=left_labels, left_labels=left_labels,
right_labels=right_labels, right_labels=right_labels)
features=nlp.defaults.parser_features)
for itn in range(1000): for itn in range(1000):
random.shuffle(train_data) random.shuffle(train_data)
loss = 0 loss = 0

View File

@ -53,7 +53,7 @@ def main(output_dir=None):
vocab = Vocab(tag_map=TAG_MAP) vocab = Vocab(tag_map=TAG_MAP)
# The default_templates argument is where features are specified. See # The default_templates argument is where features are specified. See
# spacy/tagger.pyx for the defaults. # spacy/tagger.pyx for the defaults.
tagger = Tagger.blank(vocab, Tagger.default_templates()) tagger = Tagger(vocab)
for i in range(5): for i in range(5):
for words, tags in DATA: for words, tags in DATA:
doc = Doc(vocab, words=words) doc = Doc(vocab, words=words)

View File

@ -24,7 +24,7 @@ def blank(name, vocab=None, tokenizer=None, parser=None, tagger=None, entity=Non
target_name, target_version = util.split_data_name(name) target_name, target_version = util.split_data_name(name)
cls = get_lang_class(target_name) cls = get_lang_class(target_name)
return cls( return cls(
path, path=None,
vectors=vectors, vectors=vectors,
vocab=vocab, vocab=vocab,
tokenizer=tokenizer, tokenizer=tokenizer,

View File

@ -19,7 +19,6 @@ except NameError:
from .tokenizer import Tokenizer from .tokenizer import Tokenizer
from .vocab import Vocab from .vocab import Vocab
from .syntax.parser import Parser
from .tagger import Tagger from .tagger import Tagger
from .matcher import Matcher from .matcher import Matcher
from . import attrs from . import attrs
@ -95,7 +94,9 @@ class BaseDefaults(object):
if self.path: if self.path:
return Tagger.load(self.path / 'pos', vocab) return Tagger.load(self.path / 'pos', vocab)
else: else:
return Tagger.blank(vocab, Tagger.default_templates()) if 'features' not in cfg:
cfg['features'] = self.parser_features
return Tagger(vocab, **cfg)
def Parser(self, vocab, **cfg): def Parser(self, vocab, **cfg):
if self.path and (self.path / 'deps').exists(): if self.path and (self.path / 'deps').exists():
@ -103,7 +104,7 @@ class BaseDefaults(object):
else: else:
if 'features' not in cfg: if 'features' not in cfg:
cfg['features'] = self.parser_features cfg['features'] = self.parser_features
return DependencyParser.blank(vocab, **cfg) return DependencyParser(vocab, **cfg)
def Entity(self, vocab, **cfg): def Entity(self, vocab, **cfg):
if self.path and (self.path / 'ner').exists(): if self.path and (self.path / 'ner').exists():
@ -111,7 +112,7 @@ class BaseDefaults(object):
else: else:
if 'features' not in cfg: if 'features' not in cfg:
cfg['features'] = self.entity_features cfg['features'] = self.entity_features
return EntityRecognizer.blank(vocab, **cfg) return EntityRecognizer(vocab, **cfg)
def Matcher(self, vocab, **cfg): def Matcher(self, vocab, **cfg):
if self.path: if self.path:

View File

@ -1,6 +1,7 @@
from .syntax.parser cimport Parser from .syntax.parser cimport Parser
from .syntax.ner cimport BiluoPushDown from .syntax.ner cimport BiluoPushDown
from .syntax.arc_eager cimport ArcEager from .syntax.arc_eager cimport ArcEager
from .tagger cimport Tagger
cdef class EntityRecognizer(Parser): cdef class EntityRecognizer(Parser):

View File

@ -2,41 +2,23 @@ from .syntax.parser cimport Parser
from .syntax.ner cimport BiluoPushDown from .syntax.ner cimport BiluoPushDown
from .syntax.arc_eager cimport ArcEager from .syntax.arc_eager cimport ArcEager
from .vocab cimport Vocab from .vocab cimport Vocab
from .tagger cimport Tagger from .tagger import Tagger
# TODO: The disorganization here is pretty embarrassing. At least it's only
# internals.
from .syntax.parser import get_templates as get_feature_templates
cdef class EntityRecognizer(Parser): cdef class EntityRecognizer(Parser):
@classmethod TransitionSystem = BiluoPushDown
def load(cls, path, Vocab vocab):
return Parser.load(path, vocab, BiluoPushDown) feature_templates = get_feature_templates('ner')
@classmethod
def blank(cls, Vocab vocab, **cfg):
if 'actions' not in cfg:
cfg['actions'] = {0: {'': True}, 5: {'': True}}
entity_types = cfg.get('entity_types', [''])
for action_type in (1, 2, 3, 4):
cfg['actions'][action_type] = {ent_type: True for ent_type in entity_types}
return Parser.blank(vocab, BiluoPushDown, **cfg)
cdef class DependencyParser(Parser): cdef class DependencyParser(Parser):
@classmethod TransitionSystem = ArcEager
def load(cls, path, Vocab vocab):
return Parser.load(path, vocab, ArcEager)
@classmethod
def blank(cls, Vocab vocab, **cfg):
if 'actions' not in cfg:
cfg['actions'] = {0: {'': True}, 1: {'': True}, 2: {}, 3: {},
4: {'ROOT': True}}
for label in cfg.get('left_labels', []):
cfg['actions'][2][label] = True
for label in cfg.get('right_labels', []):
cfg['actions'][3][label] = True
for label in cfg.get('break_labels', []):
cfg['actions'][4][label] = True
return Parser.blank(vocab, ArcEager, **cfg)
feature_templates = get_feature_templates('basic')
__all__ = [Tagger, DependencyParser, EntityRecognizer] __all__ = [Tagger, DependencyParser, EntityRecognizer]

View File

@ -279,20 +279,36 @@ cdef int _get_root(int word, const GoldParseC* gold) nogil:
cdef class ArcEager(TransitionSystem): cdef class ArcEager(TransitionSystem):
@classmethod @classmethod
def get_labels(cls, gold_parses): def get_actions(cls, **kwargs):
move_labels = {SHIFT: {'': True}, REDUCE: {'': True}, RIGHT: {}, actions = kwargs.get('actions',
LEFT: {}, BREAK: {'ROOT': True}} {
for raw_text, sents in gold_parses: SHIFT: {'': True},
REDUCE: {'': True},
RIGHT: {},
LEFT: {},
BREAK: {'ROOT': True}})
for label in kwargs.get('labels', []):
if label.upper() != 'ROOT':
actions[LEFT][label] = True
actions[RIGHT][label] = True
for label in kwargs.get('left_labels', []):
if label.upper() != 'ROOT':
actions[LEFT][label] = True
for label in kwargs.get('right_labels', []):
if label.upper() != 'ROOT':
actions[RIGHT][label] = True
for raw_text, sents in kwargs.get('gold_parses', []):
for (ids, words, tags, heads, labels, iob), ctnts in sents: for (ids, words, tags, heads, labels, iob), ctnts in sents:
for child, head, label in zip(ids, heads, labels): for child, head, label in zip(ids, heads, labels):
if label.upper() == 'ROOT': if label.upper() == 'ROOT':
label = 'ROOT' label = 'ROOT'
if label != 'ROOT': if label != 'ROOT':
if head < child: if head < child:
move_labels[RIGHT][label] = True actions[RIGHT][label] = True
elif head > child: elif head > child:
move_labels[LEFT][label] = True actions[LEFT][label] = True
return move_labels return actions
property action_types: property action_types:
def __get__(self): def __get__(self):

View File

@ -51,11 +51,21 @@ cdef bint _entity_is_sunk(StateClass st, Transition* golds) nogil:
cdef class BiluoPushDown(TransitionSystem): cdef class BiluoPushDown(TransitionSystem):
@classmethod @classmethod
def get_labels(cls, gold_tuples): def get_actions(cls, **kwargs):
move_labels = {MISSING: {'': True}, BEGIN: {}, IN: {}, LAST: {}, UNIT: {}, actions = kwargs.get('actions',
OUT: {'': True}} {
MISSING: {'': True},
BEGIN: {},
IN: {},
LAST: {},
UNIT: {},
OUT: {'': True}
})
for entity_type in kwargs.get('entity_types', []):
for action in (BEGIN, IN, LAST, UNIT):
actions[action][entity_type] = True
moves = ('M', 'B', 'I', 'L', 'U') moves = ('M', 'B', 'I', 'L', 'U')
for raw_text, sents in gold_tuples: for raw_text, sents in kwargs.get('gold_tuples', []):
for (ids, words, tags, heads, labels, biluo), _ in sents: for (ids, words, tags, heads, labels, biluo), _ in sents:
for i, ner_tag in enumerate(biluo): for i, ner_tag in enumerate(biluo):
if ner_tag != 'O' and ner_tag != '-': if ner_tag != 'O' and ner_tag != '-':
@ -63,8 +73,8 @@ cdef class BiluoPushDown(TransitionSystem):
raise ValueError(ner_tag) raise ValueError(ner_tag)
_, label = ner_tag.split('-') _, label = ner_tag.split('-')
for move_str in ('B', 'I', 'L', 'U'): for move_str in ('B', 'I', 'L', 'U'):
move_labels[moves.index(move_str)][label] = True actions[moves.index(move_str)][label] = True
return move_labels return actions
property action_types: property action_types:
def __get__(self): def __get__(self):

View File

@ -67,10 +67,6 @@ def get_templates(name):
pf.tree_shape + pf.trigrams) pf.tree_shape + pf.trigrams)
def ParserFactory(transition_system):
return lambda strings, dir_: Parser(strings, dir_, transition_system)
cdef class ParserModel(AveragedPerceptron): cdef class ParserModel(AveragedPerceptron):
cdef void set_featuresC(self, ExampleC* eg, const StateC* state) nogil: cdef void set_featuresC(self, ExampleC* eg, const StateC* state) nogil:
fill_context(eg.atoms, state) fill_context(eg.atoms, state)
@ -79,28 +75,26 @@ cdef class ParserModel(AveragedPerceptron):
cdef class Parser: cdef class Parser:
@classmethod @classmethod
def load(cls, path, Vocab vocab, moves_class): def load(cls, path, Vocab vocab, TransitionSystem=None, require=False):
with (path / 'config.json').open() as file_: with (path / 'config.json').open() as file_:
cfg = json.load(file_) cfg = json.load(file_)
self = cls(vocab, TransitionSystem=TransitionSystem, model=None, **cfg)
if (path / 'model').exists():
self.model.load(str(path / 'model'))
elif require:
raise IOError(
"Required file %s/model not found when loading" % str(path))
return self
def __init__(self, Vocab vocab, TransitionSystem=None, ParserModel model=None, **cfg):
if TransitionSystem is None:
TransitionSystem = self.TransitionSystem
actions = TransitionSystem.get_actions(**cfg)
self.moves = TransitionSystem(vocab.strings, actions)
# TODO: Remove this when we no longer need to support old-style models # TODO: Remove this when we no longer need to support old-style models
if isinstance(cfg.get('features'), basestring): if isinstance(cfg.get('features'), basestring):
cfg['features'] = get_templates(cfg['features']) cfg['features'] = get_templates(cfg['features'])
moves = moves_class(vocab.strings, cfg['actions']) self.model = ParserModel(cfg['features'])
model = ParserModel(cfg['features'])
if (path / 'model').exists():
model.load(str(path / 'model'))
return cls(vocab, moves, model, **cfg)
@classmethod
def blank(cls, Vocab vocab, moves_class, **cfg):
moves = moves_class(vocab.strings, cfg.get('actions', {}))
templates = cfg.get('features', tuple())
model = ParserModel(templates)
return cls(vocab, moves, model, **cfg)
def __init__(self, Vocab vocab, transition_system, ParserModel model, **cfg):
self.moves = transition_system
self.model = model
self.cfg = cfg self.cfg = cfg
def __reduce__(self): def __reduce__(self):
@ -192,8 +186,8 @@ cdef class Parser:
free(eg.is_valid) free(eg.is_valid)
return 0 return 0
def update(self, Doc tokens, GoldParse gold): def update(self, Doc tokens, raw_gold):
self.moves.preprocess_gold(gold) cdef GoldParse gold = self.preprocess_gold(raw_gold)
cdef StateClass stcls = StateClass.init(tokens.c, tokens.length) cdef StateClass stcls = StateClass.init(tokens.c, tokens.length)
self.moves.initialize_state(stcls.c) self.moves.initialize_state(stcls.c)
cdef Pool mem = Pool() cdef Pool mem = Pool()
@ -230,6 +224,15 @@ cdef class Parser:
for action in self.moves.action_types: for action in self.moves.action_types:
self.moves.add_action(action, label) self.moves.add_action(action, label)
def preprocess_gold(self, raw_gold):
cdef GoldParse gold
if isinstance(raw_gold, GoldParse):
gold = raw_gold
self.moves.preprocess_gold(raw_gold)
return gold
else:
raise ValueError("Parser.preprocess_gold requires GoldParse-type input.")
cdef class StepwiseState: cdef class StepwiseState:
cdef readonly StateClass stcls cdef readonly StateClass stcls

View File

@ -103,58 +103,30 @@ cdef inline void _fill_from_token(atom_t* context, const TokenC* t) nogil:
cdef class Tagger: cdef class Tagger:
"""A part-of-speech tagger for English""" """A part-of-speech tagger for English"""
@classmethod @classmethod
def default_templates(cls): def load(cls, path, vocab, require=False):
return ( # TODO: Change this to expect config.json when we don't have to
(W_orth,), # support old data.
(P1_lemma, P1_pos),
(P2_lemma, P2_pos),
(N1_orth,),
(N2_orth,),
(W_suffix,),
(W_prefix,),
(P1_pos,),
(P2_pos,),
(P1_pos, P2_pos),
(P1_pos, W_orth),
(P1_suffix,),
(N1_suffix,),
(W_shape,),
(W_cluster,),
(N1_cluster,),
(N2_cluster,),
(P1_cluster,),
(P2_cluster,),
(W_flags,),
(N1_flags,),
(N2_flags,),
(P1_flags,),
(P2_flags,),
)
@classmethod
def blank(cls, vocab, templates):
model = TaggerModel(templates)
return cls(vocab, model)
@classmethod
def load(cls, path, vocab):
path = path if not isinstance(path, basestring) else pathlib.Path(path) path = path if not isinstance(path, basestring) else pathlib.Path(path)
if (path / 'templates.json').exists(): if (path / 'templates.json').exists():
with (path / 'templates.json').open() as file_: with (path / 'templates.json').open() as file_:
templates = json.load(file_) templates = json.load(file_)
elif require:
raise IOError(
"Required file %s/templates.json not found when loading Tagger" % str(path))
else: else:
templates = cls.default_templates() templates = cls.feature_templates
self = cls(vocab, model=None, feature_templates=templates)
model = TaggerModel(templates)
if (path / 'model').exists(): if (path / 'model').exists():
model.load(str(path / 'model')) self.model.load(str(path / 'model'))
return cls(vocab, model) elif require:
raise IOError(
"Required file %s/model not found when loading Tagger" % str(path))
return self
def __init__(self, Vocab vocab, TaggerModel model, **cfg): def __init__(self, Vocab vocab, TaggerModel model=None, **cfg):
if model is None:
model = TaggerModel(cfg.get('features', self.feature_templates))
self.vocab = vocab self.vocab = vocab
self.model = model self.model = model
# TODO: Move this to tag map # TODO: Move this to tag map
@ -242,3 +214,35 @@ cdef class Tagger:
tokens.is_tagged = True tokens.is_tagged = True
tokens._py_tokens = [None] * tokens.length tokens._py_tokens = [None] * tokens.length
return correct return correct
feature_templates = (
(W_orth,),
(P1_lemma, P1_pos),
(P2_lemma, P2_pos),
(N1_orth,),
(N2_orth,),
(W_suffix,),
(W_prefix,),
(P1_pos,),
(P2_pos,),
(P1_pos, P2_pos),
(P1_pos, W_orth),
(P1_suffix,),
(N1_suffix,),
(W_shape,),
(W_cluster,),
(N1_cluster,),
(N2_cluster,),
(P1_cluster,),
(P2_cluster,),
(W_flags,),
(N1_flags,),
(N2_flags,),
(P1_flags,),
(P2_flags,),
)