mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-27 09:44:36 +03:00
Refactor the pipeline classes to make them more consistent, and remove the redundant blank() constructor.
This commit is contained in:
parent
311a985fe0
commit
f787cd29fe
63
examples/training/train_ner.py
Normal file
63
examples/training/train_ner.py
Normal file
|
@ -0,0 +1,63 @@
|
||||||
|
from __future__ import unicode_literals, print_function
|
||||||
|
import json
|
||||||
|
import pathlib
|
||||||
|
import random
|
||||||
|
|
||||||
|
import spacy
|
||||||
|
from spacy.pipeline import EntityRecognizer
|
||||||
|
from spacy.gold import GoldParse
|
||||||
|
|
||||||
|
|
||||||
|
def train_ner(nlp, train_data, entity_types):
|
||||||
|
ner = EntityRecognizer(nlp.vocab, entity_types=entity_types)
|
||||||
|
for itn in range(5):
|
||||||
|
random.shuffle(train_data)
|
||||||
|
for raw_text, entity_offsets in train_data:
|
||||||
|
doc = nlp.make_doc(raw_text)
|
||||||
|
gold = GoldParse(doc, entities=entity_offsets)
|
||||||
|
ner.update(doc, gold)
|
||||||
|
ner.model.end_training()
|
||||||
|
return ner
|
||||||
|
|
||||||
|
|
||||||
|
def main(model_dir=None):
|
||||||
|
if model_dir is not None:
|
||||||
|
model_dir = pathlb.Path(model_dir)
|
||||||
|
if not model_dir.exists():
|
||||||
|
model_dir.mkdir()
|
||||||
|
assert model_dir.isdir()
|
||||||
|
|
||||||
|
nlp = spacy.load('en', parser=False, entity=False, vectors=False)
|
||||||
|
|
||||||
|
train_data = [
|
||||||
|
(
|
||||||
|
'Who is Shaka Khan?',
|
||||||
|
[(len('Who is '), len('Who is Shaka Khan'), 'PERSON')]
|
||||||
|
),
|
||||||
|
(
|
||||||
|
'I like London and Berlin.',
|
||||||
|
[(len('I like '), len('I like London'), 'LOC'),
|
||||||
|
(len('I like London and '), len('I like London and Berlin'), 'LOC')]
|
||||||
|
)
|
||||||
|
]
|
||||||
|
ner = train_ner(nlp, train_data, ['PERSON', 'LOC'])
|
||||||
|
|
||||||
|
doc = nlp.make_doc('Who is Shaka Khan?')
|
||||||
|
nlp.tagger(doc)
|
||||||
|
ner(doc)
|
||||||
|
for word in doc:
|
||||||
|
print(word.text, word.tag_, word.ent_type_, word.ent_iob)
|
||||||
|
|
||||||
|
if model_dir is not None:
|
||||||
|
with (model_dir / 'config.json').open('wb') as file_:
|
||||||
|
json.dump(ner.cfg, file_)
|
||||||
|
ner.model.dump(str(model_dir / 'model'))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
|
# Who "" 2
|
||||||
|
# is "" 2
|
||||||
|
# Shaka "" PERSON 3
|
||||||
|
# Khan "" PERSON 1
|
||||||
|
# ? "" 2
|
|
@ -10,11 +10,10 @@ from spacy.tokens import Doc
|
||||||
|
|
||||||
|
|
||||||
def train_parser(nlp, train_data, left_labels, right_labels):
|
def train_parser(nlp, train_data, left_labels, right_labels):
|
||||||
parser = DependencyParser.blank(
|
parser = DependencyParser(
|
||||||
nlp.vocab,
|
nlp.vocab,
|
||||||
left_labels=left_labels,
|
left_labels=left_labels,
|
||||||
right_labels=right_labels,
|
right_labels=right_labels)
|
||||||
features=nlp.defaults.parser_features)
|
|
||||||
for itn in range(1000):
|
for itn in range(1000):
|
||||||
random.shuffle(train_data)
|
random.shuffle(train_data)
|
||||||
loss = 0
|
loss = 0
|
||||||
|
|
|
@ -53,7 +53,7 @@ def main(output_dir=None):
|
||||||
vocab = Vocab(tag_map=TAG_MAP)
|
vocab = Vocab(tag_map=TAG_MAP)
|
||||||
# The default_templates argument is where features are specified. See
|
# The default_templates argument is where features are specified. See
|
||||||
# spacy/tagger.pyx for the defaults.
|
# spacy/tagger.pyx for the defaults.
|
||||||
tagger = Tagger.blank(vocab, Tagger.default_templates())
|
tagger = Tagger(vocab)
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
for words, tags in DATA:
|
for words, tags in DATA:
|
||||||
doc = Doc(vocab, words=words)
|
doc = Doc(vocab, words=words)
|
||||||
|
|
|
@ -24,7 +24,7 @@ def blank(name, vocab=None, tokenizer=None, parser=None, tagger=None, entity=Non
|
||||||
target_name, target_version = util.split_data_name(name)
|
target_name, target_version = util.split_data_name(name)
|
||||||
cls = get_lang_class(target_name)
|
cls = get_lang_class(target_name)
|
||||||
return cls(
|
return cls(
|
||||||
path,
|
path=None,
|
||||||
vectors=vectors,
|
vectors=vectors,
|
||||||
vocab=vocab,
|
vocab=vocab,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
|
|
|
@ -19,7 +19,6 @@ except NameError:
|
||||||
|
|
||||||
from .tokenizer import Tokenizer
|
from .tokenizer import Tokenizer
|
||||||
from .vocab import Vocab
|
from .vocab import Vocab
|
||||||
from .syntax.parser import Parser
|
|
||||||
from .tagger import Tagger
|
from .tagger import Tagger
|
||||||
from .matcher import Matcher
|
from .matcher import Matcher
|
||||||
from . import attrs
|
from . import attrs
|
||||||
|
@ -95,7 +94,9 @@ class BaseDefaults(object):
|
||||||
if self.path:
|
if self.path:
|
||||||
return Tagger.load(self.path / 'pos', vocab)
|
return Tagger.load(self.path / 'pos', vocab)
|
||||||
else:
|
else:
|
||||||
return Tagger.blank(vocab, Tagger.default_templates())
|
if 'features' not in cfg:
|
||||||
|
cfg['features'] = self.parser_features
|
||||||
|
return Tagger(vocab, **cfg)
|
||||||
|
|
||||||
def Parser(self, vocab, **cfg):
|
def Parser(self, vocab, **cfg):
|
||||||
if self.path and (self.path / 'deps').exists():
|
if self.path and (self.path / 'deps').exists():
|
||||||
|
@ -103,7 +104,7 @@ class BaseDefaults(object):
|
||||||
else:
|
else:
|
||||||
if 'features' not in cfg:
|
if 'features' not in cfg:
|
||||||
cfg['features'] = self.parser_features
|
cfg['features'] = self.parser_features
|
||||||
return DependencyParser.blank(vocab, **cfg)
|
return DependencyParser(vocab, **cfg)
|
||||||
|
|
||||||
def Entity(self, vocab, **cfg):
|
def Entity(self, vocab, **cfg):
|
||||||
if self.path and (self.path / 'ner').exists():
|
if self.path and (self.path / 'ner').exists():
|
||||||
|
@ -111,7 +112,7 @@ class BaseDefaults(object):
|
||||||
else:
|
else:
|
||||||
if 'features' not in cfg:
|
if 'features' not in cfg:
|
||||||
cfg['features'] = self.entity_features
|
cfg['features'] = self.entity_features
|
||||||
return EntityRecognizer.blank(vocab, **cfg)
|
return EntityRecognizer(vocab, **cfg)
|
||||||
|
|
||||||
def Matcher(self, vocab, **cfg):
|
def Matcher(self, vocab, **cfg):
|
||||||
if self.path:
|
if self.path:
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
from .syntax.parser cimport Parser
|
from .syntax.parser cimport Parser
|
||||||
from .syntax.ner cimport BiluoPushDown
|
from .syntax.ner cimport BiluoPushDown
|
||||||
from .syntax.arc_eager cimport ArcEager
|
from .syntax.arc_eager cimport ArcEager
|
||||||
|
from .tagger cimport Tagger
|
||||||
|
|
||||||
|
|
||||||
cdef class EntityRecognizer(Parser):
|
cdef class EntityRecognizer(Parser):
|
||||||
|
|
|
@ -2,41 +2,23 @@ from .syntax.parser cimport Parser
|
||||||
from .syntax.ner cimport BiluoPushDown
|
from .syntax.ner cimport BiluoPushDown
|
||||||
from .syntax.arc_eager cimport ArcEager
|
from .syntax.arc_eager cimport ArcEager
|
||||||
from .vocab cimport Vocab
|
from .vocab cimport Vocab
|
||||||
from .tagger cimport Tagger
|
from .tagger import Tagger
|
||||||
|
|
||||||
|
# TODO: The disorganization here is pretty embarrassing. At least it's only
|
||||||
|
# internals.
|
||||||
|
from .syntax.parser import get_templates as get_feature_templates
|
||||||
|
|
||||||
|
|
||||||
cdef class EntityRecognizer(Parser):
|
cdef class EntityRecognizer(Parser):
|
||||||
@classmethod
|
TransitionSystem = BiluoPushDown
|
||||||
def load(cls, path, Vocab vocab):
|
|
||||||
return Parser.load(path, vocab, BiluoPushDown)
|
feature_templates = get_feature_templates('ner')
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def blank(cls, Vocab vocab, **cfg):
|
|
||||||
if 'actions' not in cfg:
|
|
||||||
cfg['actions'] = {0: {'': True}, 5: {'': True}}
|
|
||||||
entity_types = cfg.get('entity_types', [''])
|
|
||||||
for action_type in (1, 2, 3, 4):
|
|
||||||
cfg['actions'][action_type] = {ent_type: True for ent_type in entity_types}
|
|
||||||
return Parser.blank(vocab, BiluoPushDown, **cfg)
|
|
||||||
|
|
||||||
|
|
||||||
cdef class DependencyParser(Parser):
|
cdef class DependencyParser(Parser):
|
||||||
@classmethod
|
TransitionSystem = ArcEager
|
||||||
def load(cls, path, Vocab vocab):
|
|
||||||
return Parser.load(path, vocab, ArcEager)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def blank(cls, Vocab vocab, **cfg):
|
|
||||||
if 'actions' not in cfg:
|
|
||||||
cfg['actions'] = {0: {'': True}, 1: {'': True}, 2: {}, 3: {},
|
|
||||||
4: {'ROOT': True}}
|
|
||||||
for label in cfg.get('left_labels', []):
|
|
||||||
cfg['actions'][2][label] = True
|
|
||||||
for label in cfg.get('right_labels', []):
|
|
||||||
cfg['actions'][3][label] = True
|
|
||||||
for label in cfg.get('break_labels', []):
|
|
||||||
cfg['actions'][4][label] = True
|
|
||||||
return Parser.blank(vocab, ArcEager, **cfg)
|
|
||||||
|
|
||||||
|
feature_templates = get_feature_templates('basic')
|
||||||
|
|
||||||
|
|
||||||
__all__ = [Tagger, DependencyParser, EntityRecognizer]
|
__all__ = [Tagger, DependencyParser, EntityRecognizer]
|
||||||
|
|
|
@ -279,20 +279,36 @@ cdef int _get_root(int word, const GoldParseC* gold) nogil:
|
||||||
|
|
||||||
cdef class ArcEager(TransitionSystem):
|
cdef class ArcEager(TransitionSystem):
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_labels(cls, gold_parses):
|
def get_actions(cls, **kwargs):
|
||||||
move_labels = {SHIFT: {'': True}, REDUCE: {'': True}, RIGHT: {},
|
actions = kwargs.get('actions',
|
||||||
LEFT: {}, BREAK: {'ROOT': True}}
|
{
|
||||||
for raw_text, sents in gold_parses:
|
SHIFT: {'': True},
|
||||||
|
REDUCE: {'': True},
|
||||||
|
RIGHT: {},
|
||||||
|
LEFT: {},
|
||||||
|
BREAK: {'ROOT': True}})
|
||||||
|
for label in kwargs.get('labels', []):
|
||||||
|
if label.upper() != 'ROOT':
|
||||||
|
actions[LEFT][label] = True
|
||||||
|
actions[RIGHT][label] = True
|
||||||
|
for label in kwargs.get('left_labels', []):
|
||||||
|
if label.upper() != 'ROOT':
|
||||||
|
actions[LEFT][label] = True
|
||||||
|
for label in kwargs.get('right_labels', []):
|
||||||
|
if label.upper() != 'ROOT':
|
||||||
|
actions[RIGHT][label] = True
|
||||||
|
|
||||||
|
for raw_text, sents in kwargs.get('gold_parses', []):
|
||||||
for (ids, words, tags, heads, labels, iob), ctnts in sents:
|
for (ids, words, tags, heads, labels, iob), ctnts in sents:
|
||||||
for child, head, label in zip(ids, heads, labels):
|
for child, head, label in zip(ids, heads, labels):
|
||||||
if label.upper() == 'ROOT':
|
if label.upper() == 'ROOT':
|
||||||
label = 'ROOT'
|
label = 'ROOT'
|
||||||
if label != 'ROOT':
|
if label != 'ROOT':
|
||||||
if head < child:
|
if head < child:
|
||||||
move_labels[RIGHT][label] = True
|
actions[RIGHT][label] = True
|
||||||
elif head > child:
|
elif head > child:
|
||||||
move_labels[LEFT][label] = True
|
actions[LEFT][label] = True
|
||||||
return move_labels
|
return actions
|
||||||
|
|
||||||
property action_types:
|
property action_types:
|
||||||
def __get__(self):
|
def __get__(self):
|
||||||
|
|
|
@ -51,11 +51,21 @@ cdef bint _entity_is_sunk(StateClass st, Transition* golds) nogil:
|
||||||
|
|
||||||
cdef class BiluoPushDown(TransitionSystem):
|
cdef class BiluoPushDown(TransitionSystem):
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_labels(cls, gold_tuples):
|
def get_actions(cls, **kwargs):
|
||||||
move_labels = {MISSING: {'': True}, BEGIN: {}, IN: {}, LAST: {}, UNIT: {},
|
actions = kwargs.get('actions',
|
||||||
OUT: {'': True}}
|
{
|
||||||
|
MISSING: {'': True},
|
||||||
|
BEGIN: {},
|
||||||
|
IN: {},
|
||||||
|
LAST: {},
|
||||||
|
UNIT: {},
|
||||||
|
OUT: {'': True}
|
||||||
|
})
|
||||||
|
for entity_type in kwargs.get('entity_types', []):
|
||||||
|
for action in (BEGIN, IN, LAST, UNIT):
|
||||||
|
actions[action][entity_type] = True
|
||||||
moves = ('M', 'B', 'I', 'L', 'U')
|
moves = ('M', 'B', 'I', 'L', 'U')
|
||||||
for raw_text, sents in gold_tuples:
|
for raw_text, sents in kwargs.get('gold_tuples', []):
|
||||||
for (ids, words, tags, heads, labels, biluo), _ in sents:
|
for (ids, words, tags, heads, labels, biluo), _ in sents:
|
||||||
for i, ner_tag in enumerate(biluo):
|
for i, ner_tag in enumerate(biluo):
|
||||||
if ner_tag != 'O' and ner_tag != '-':
|
if ner_tag != 'O' and ner_tag != '-':
|
||||||
|
@ -63,8 +73,8 @@ cdef class BiluoPushDown(TransitionSystem):
|
||||||
raise ValueError(ner_tag)
|
raise ValueError(ner_tag)
|
||||||
_, label = ner_tag.split('-')
|
_, label = ner_tag.split('-')
|
||||||
for move_str in ('B', 'I', 'L', 'U'):
|
for move_str in ('B', 'I', 'L', 'U'):
|
||||||
move_labels[moves.index(move_str)][label] = True
|
actions[moves.index(move_str)][label] = True
|
||||||
return move_labels
|
return actions
|
||||||
|
|
||||||
property action_types:
|
property action_types:
|
||||||
def __get__(self):
|
def __get__(self):
|
||||||
|
|
|
@ -67,10 +67,6 @@ def get_templates(name):
|
||||||
pf.tree_shape + pf.trigrams)
|
pf.tree_shape + pf.trigrams)
|
||||||
|
|
||||||
|
|
||||||
def ParserFactory(transition_system):
|
|
||||||
return lambda strings, dir_: Parser(strings, dir_, transition_system)
|
|
||||||
|
|
||||||
|
|
||||||
cdef class ParserModel(AveragedPerceptron):
|
cdef class ParserModel(AveragedPerceptron):
|
||||||
cdef void set_featuresC(self, ExampleC* eg, const StateC* state) nogil:
|
cdef void set_featuresC(self, ExampleC* eg, const StateC* state) nogil:
|
||||||
fill_context(eg.atoms, state)
|
fill_context(eg.atoms, state)
|
||||||
|
@ -79,28 +75,26 @@ cdef class ParserModel(AveragedPerceptron):
|
||||||
|
|
||||||
cdef class Parser:
|
cdef class Parser:
|
||||||
@classmethod
|
@classmethod
|
||||||
def load(cls, path, Vocab vocab, moves_class):
|
def load(cls, path, Vocab vocab, TransitionSystem=None, require=False):
|
||||||
with (path / 'config.json').open() as file_:
|
with (path / 'config.json').open() as file_:
|
||||||
cfg = json.load(file_)
|
cfg = json.load(file_)
|
||||||
|
self = cls(vocab, TransitionSystem=TransitionSystem, model=None, **cfg)
|
||||||
|
if (path / 'model').exists():
|
||||||
|
self.model.load(str(path / 'model'))
|
||||||
|
elif require:
|
||||||
|
raise IOError(
|
||||||
|
"Required file %s/model not found when loading" % str(path))
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __init__(self, Vocab vocab, TransitionSystem=None, ParserModel model=None, **cfg):
|
||||||
|
if TransitionSystem is None:
|
||||||
|
TransitionSystem = self.TransitionSystem
|
||||||
|
actions = TransitionSystem.get_actions(**cfg)
|
||||||
|
self.moves = TransitionSystem(vocab.strings, actions)
|
||||||
# TODO: Remove this when we no longer need to support old-style models
|
# TODO: Remove this when we no longer need to support old-style models
|
||||||
if isinstance(cfg.get('features'), basestring):
|
if isinstance(cfg.get('features'), basestring):
|
||||||
cfg['features'] = get_templates(cfg['features'])
|
cfg['features'] = get_templates(cfg['features'])
|
||||||
moves = moves_class(vocab.strings, cfg['actions'])
|
self.model = ParserModel(cfg['features'])
|
||||||
model = ParserModel(cfg['features'])
|
|
||||||
if (path / 'model').exists():
|
|
||||||
model.load(str(path / 'model'))
|
|
||||||
return cls(vocab, moves, model, **cfg)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def blank(cls, Vocab vocab, moves_class, **cfg):
|
|
||||||
moves = moves_class(vocab.strings, cfg.get('actions', {}))
|
|
||||||
templates = cfg.get('features', tuple())
|
|
||||||
model = ParserModel(templates)
|
|
||||||
return cls(vocab, moves, model, **cfg)
|
|
||||||
|
|
||||||
def __init__(self, Vocab vocab, transition_system, ParserModel model, **cfg):
|
|
||||||
self.moves = transition_system
|
|
||||||
self.model = model
|
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
|
|
||||||
def __reduce__(self):
|
def __reduce__(self):
|
||||||
|
@ -192,8 +186,8 @@ cdef class Parser:
|
||||||
free(eg.is_valid)
|
free(eg.is_valid)
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
def update(self, Doc tokens, GoldParse gold):
|
def update(self, Doc tokens, raw_gold):
|
||||||
self.moves.preprocess_gold(gold)
|
cdef GoldParse gold = self.preprocess_gold(raw_gold)
|
||||||
cdef StateClass stcls = StateClass.init(tokens.c, tokens.length)
|
cdef StateClass stcls = StateClass.init(tokens.c, tokens.length)
|
||||||
self.moves.initialize_state(stcls.c)
|
self.moves.initialize_state(stcls.c)
|
||||||
cdef Pool mem = Pool()
|
cdef Pool mem = Pool()
|
||||||
|
@ -230,6 +224,15 @@ cdef class Parser:
|
||||||
for action in self.moves.action_types:
|
for action in self.moves.action_types:
|
||||||
self.moves.add_action(action, label)
|
self.moves.add_action(action, label)
|
||||||
|
|
||||||
|
def preprocess_gold(self, raw_gold):
|
||||||
|
cdef GoldParse gold
|
||||||
|
if isinstance(raw_gold, GoldParse):
|
||||||
|
gold = raw_gold
|
||||||
|
self.moves.preprocess_gold(raw_gold)
|
||||||
|
return gold
|
||||||
|
else:
|
||||||
|
raise ValueError("Parser.preprocess_gold requires GoldParse-type input.")
|
||||||
|
|
||||||
|
|
||||||
cdef class StepwiseState:
|
cdef class StepwiseState:
|
||||||
cdef readonly StateClass stcls
|
cdef readonly StateClass stcls
|
||||||
|
|
|
@ -103,58 +103,30 @@ cdef inline void _fill_from_token(atom_t* context, const TokenC* t) nogil:
|
||||||
cdef class Tagger:
|
cdef class Tagger:
|
||||||
"""A part-of-speech tagger for English"""
|
"""A part-of-speech tagger for English"""
|
||||||
@classmethod
|
@classmethod
|
||||||
def default_templates(cls):
|
def load(cls, path, vocab, require=False):
|
||||||
return (
|
# TODO: Change this to expect config.json when we don't have to
|
||||||
(W_orth,),
|
# support old data.
|
||||||
(P1_lemma, P1_pos),
|
|
||||||
(P2_lemma, P2_pos),
|
|
||||||
(N1_orth,),
|
|
||||||
(N2_orth,),
|
|
||||||
|
|
||||||
(W_suffix,),
|
|
||||||
(W_prefix,),
|
|
||||||
|
|
||||||
(P1_pos,),
|
|
||||||
(P2_pos,),
|
|
||||||
(P1_pos, P2_pos),
|
|
||||||
(P1_pos, W_orth),
|
|
||||||
(P1_suffix,),
|
|
||||||
(N1_suffix,),
|
|
||||||
|
|
||||||
(W_shape,),
|
|
||||||
(W_cluster,),
|
|
||||||
(N1_cluster,),
|
|
||||||
(N2_cluster,),
|
|
||||||
(P1_cluster,),
|
|
||||||
(P2_cluster,),
|
|
||||||
|
|
||||||
(W_flags,),
|
|
||||||
(N1_flags,),
|
|
||||||
(N2_flags,),
|
|
||||||
(P1_flags,),
|
|
||||||
(P2_flags,),
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def blank(cls, vocab, templates):
|
|
||||||
model = TaggerModel(templates)
|
|
||||||
return cls(vocab, model)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def load(cls, path, vocab):
|
|
||||||
path = path if not isinstance(path, basestring) else pathlib.Path(path)
|
path = path if not isinstance(path, basestring) else pathlib.Path(path)
|
||||||
if (path / 'templates.json').exists():
|
if (path / 'templates.json').exists():
|
||||||
with (path / 'templates.json').open() as file_:
|
with (path / 'templates.json').open() as file_:
|
||||||
templates = json.load(file_)
|
templates = json.load(file_)
|
||||||
|
elif require:
|
||||||
|
raise IOError(
|
||||||
|
"Required file %s/templates.json not found when loading Tagger" % str(path))
|
||||||
else:
|
else:
|
||||||
templates = cls.default_templates()
|
templates = cls.feature_templates
|
||||||
|
self = cls(vocab, model=None, feature_templates=templates)
|
||||||
|
|
||||||
model = TaggerModel(templates)
|
|
||||||
if (path / 'model').exists():
|
if (path / 'model').exists():
|
||||||
model.load(str(path / 'model'))
|
self.model.load(str(path / 'model'))
|
||||||
return cls(vocab, model)
|
elif require:
|
||||||
|
raise IOError(
|
||||||
|
"Required file %s/model not found when loading Tagger" % str(path))
|
||||||
|
return self
|
||||||
|
|
||||||
def __init__(self, Vocab vocab, TaggerModel model, **cfg):
|
def __init__(self, Vocab vocab, TaggerModel model=None, **cfg):
|
||||||
|
if model is None:
|
||||||
|
model = TaggerModel(cfg.get('features', self.feature_templates))
|
||||||
self.vocab = vocab
|
self.vocab = vocab
|
||||||
self.model = model
|
self.model = model
|
||||||
# TODO: Move this to tag map
|
# TODO: Move this to tag map
|
||||||
|
@ -242,3 +214,35 @@ cdef class Tagger:
|
||||||
tokens.is_tagged = True
|
tokens.is_tagged = True
|
||||||
tokens._py_tokens = [None] * tokens.length
|
tokens._py_tokens = [None] * tokens.length
|
||||||
return correct
|
return correct
|
||||||
|
|
||||||
|
|
||||||
|
feature_templates = (
|
||||||
|
(W_orth,),
|
||||||
|
(P1_lemma, P1_pos),
|
||||||
|
(P2_lemma, P2_pos),
|
||||||
|
(N1_orth,),
|
||||||
|
(N2_orth,),
|
||||||
|
|
||||||
|
(W_suffix,),
|
||||||
|
(W_prefix,),
|
||||||
|
|
||||||
|
(P1_pos,),
|
||||||
|
(P2_pos,),
|
||||||
|
(P1_pos, P2_pos),
|
||||||
|
(P1_pos, W_orth),
|
||||||
|
(P1_suffix,),
|
||||||
|
(N1_suffix,),
|
||||||
|
|
||||||
|
(W_shape,),
|
||||||
|
(W_cluster,),
|
||||||
|
(N1_cluster,),
|
||||||
|
(N2_cluster,),
|
||||||
|
(P1_cluster,),
|
||||||
|
(P2_cluster,),
|
||||||
|
|
||||||
|
(W_flags,),
|
||||||
|
(N1_flags,),
|
||||||
|
(N2_flags,),
|
||||||
|
(P1_flags,),
|
||||||
|
(P2_flags,),
|
||||||
|
)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user