Get data flowing through pipeline. Needs redesign

Matthew Honnibal 2017-05-16 11:21:59 +02:00
parent 1d7c18e58a
commit 5211645af3
6 changed files with 143 additions and 284 deletions

View File

@@ -135,7 +135,7 @@ def Tok2Vec(width, embed_size, preprocess=None):
             >> Residual(ExtractWindow(nW=1) >> Maxout(width, width*3))
             >> Residual(ExtractWindow(nW=1) >> Maxout(width, width*3))
         )
-    if preprocess is not None:
+    if preprocess not in (False, None):
         tok2vec = preprocess >> tok2vec
     # Work around thinc API limitations :(. TODO: Revise in Thinc 7
     tok2vec.nO = width
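
For reference: with the loosened guard above, preprocess=False now skips chaining a preprocessor entirely, while any model-like object (such as the doc2feats() helper this commit wires in from the pipeline module) is composed in front of the encoder. A minimal usage sketch under those assumptions, not from this commit:

    tok2vec = Tok2Vec(128, 5000, preprocess=False)         # encoder only; caller supplies features
    tok2vec = Tok2Vec(128, 5000, preprocess=doc2feats())   # extract features from Docs first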

View File

@@ -41,8 +41,7 @@ def train(language, output_dir, train_data, dev_data, n_iter, tagger, parser, ne
     gold_train = list(read_gold_json(train_path))
     gold_dev = list(read_gold_json(dev_path)) if dev_path else None

-    train_model(lang, gold_train, gold_dev, output_path, tagger_cfg, parser_cfg,
-                entity_cfg, n_iter)
+    train_model(lang, gold_train, gold_dev, output_path, n_iter)
     if gold_dev:
         scorer = evaluate(lang, gold_dev, output_path)
         print_results(scorer)
@@ -58,24 +57,30 @@ def train_config(config):
             prints("%s not found in config file." % setting, title="Missing setting")


-def train_model(Language, train_data, dev_data, output_path, tagger_cfg, parser_cfg,
-                entity_cfg, n_iter):
+def train_model(Language, train_data, dev_data, output_path, n_iter, **cfg):
     print("Itn.\tN weight\tN feats\tUAS\tNER F.\tTag %\tToken %")
-    with Language.train(output_path, train_data,
-                        pos=tagger_cfg, deps=parser_cfg, ner=entity_cfg) as trainer:
-        for itn, epoch in enumerate(trainer.epochs(n_iter, augment_data=None)):
-            for docs, golds in partition_all(12, epoch):
-                trainer.update(docs, golds)
+    nlp = Language(pipeline=['tensor', 'dependencies', 'entities'])
+    # TODO: Get spaCy using Thinc's trainer and optimizer
+    with nlp.begin_training(train_data, **cfg) as (trainer, optimizer):
+        for itn, epoch in enumerate(trainer.epochs(n_iter)):
+            losses = defaultdict(float)
+            for docs, golds in epoch:
+                grads = {}
+                def get_grads(W, dW, key=None):
+                    grads[key] = (W, dW)
+                for proc in nlp.pipeline:
+                    loss = proc.update(docs, golds, drop=0.0, sgd=get_grads)
+                    losses[proc.name] += loss
+                for key, (W, dW) in grads.items():
+                    optimizer(W, dW, key=key)
             if dev_data:
                 dev_scores = trainer.evaluate(dev_data).scores
             else:
                 defaultdict(float)
-            print_progress(itn, trainer.nlp.parser.model.nr_weight,
-                           trainer.nlp.parser.model.nr_active_feat,
-                           **dev_scores)
+            print_progress(itn, losses['dep'], **dev_scores)


 def evaluate(Language, gold_tuples, output_path):
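
For orientation: the new loop above defers weight updates by passing get_grads as the sgd callback, so each component records its gradients by parameter key and the optimizer is applied once per key after all components have run on the batch. A minimal standalone sketch of that callback pattern, not from this commit (the toy optimizer step is illustrative only):

    import numpy

    grads = {}
    def get_grads(W, dW, key=None):
        # record the gradient instead of updating the weights immediately
        grads[key] = (W, dW)

    def optimizer(W, dW, key=None):
        # stand-in for the Thinc Adam optimizer: a plain gradient step
        W -= 0.001 * dW

    W = numpy.ones((3,))
    get_grads(W, numpy.full(3, 0.5), key='some_param')
    for key, (W, dW) in grads.items():
        optimizer(W, dW, key=key)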

View File

@@ -11,7 +11,8 @@ from .lemmatizer import Lemmatizer
 from .train import Trainer
 from .syntax.parser import get_templates
 from .syntax.nonproj import PseudoProjectivity
-from .pipeline import DependencyParser, EntityRecognizer
+from .pipeline import DependencyParser, NeuralDependencyParser, EntityRecognizer
+from .pipeline import TokenVectorEncoder, NeuralEntityRecognizer
 from .syntax.arc_eager import ArcEager
 from .syntax.ner import BiluoPushDown
 from .compat import json_dumps
@@ -31,111 +32,49 @@ class BaseDefaults(object):
     @classmethod
     def create_vocab(cls, nlp=None):
         lemmatizer = cls.create_lemmatizer(nlp)
-        if nlp is None or nlp.path is None:
-            lex_attr_getters = dict(cls.lex_attr_getters)
-            # This is very messy, but it's the minimal working fix to Issue #639.
-            # This defaults stuff needs to be refactored (again)
-            lex_attr_getters[IS_STOP] = lambda string: string.lower() in cls.stop_words
-            vocab = Vocab(lex_attr_getters=lex_attr_getters, tag_map=cls.tag_map,
-                          lemmatizer=lemmatizer)
-        else:
-            vocab = Vocab.load(nlp.path, lex_attr_getters=cls.lex_attr_getters,
-                               tag_map=cls.tag_map, lemmatizer=lemmatizer)
+        lex_attr_getters = dict(cls.lex_attr_getters)
+        # This is messy, but it's the minimal working fix to Issue #639.
+        lex_attr_getters[IS_STOP] = lambda string: string.lower() in cls.stop_words
+        vocab = Vocab(lex_attr_getters=lex_attr_getters, tag_map=cls.tag_map,
+                      lemmatizer=lemmatizer)
         for tag_str, exc in cls.morph_rules.items():
             for orth_str, attrs in exc.items():
                 vocab.morphology.add_special_case(tag_str, orth_str, attrs)
         return vocab

-    @classmethod
-    def add_vectors(cls, nlp=None):
-        if nlp is None or nlp.path is None:
-            return False
-        else:
-            vec_path = nlp.path / 'vocab' / 'vec.bin'
-            if vec_path.exists():
-                return lambda vocab: vocab.load_vectors_from_bin_loc(vec_path)
-
     @classmethod
     def create_tokenizer(cls, nlp=None):
         rules = cls.tokenizer_exceptions
-        if cls.token_match:
-            token_match = cls.token_match
-        if cls.prefixes:
-            prefix_search = util.compile_prefix_regex(cls.prefixes).search
-        else:
-            prefix_search = None
-        if cls.suffixes:
-            suffix_search = util.compile_suffix_regex(cls.suffixes).search
-        else:
-            suffix_search = None
-        if cls.infixes:
-            infix_finditer = util.compile_infix_regex(cls.infixes).finditer
-        else:
-            infix_finditer = None
+        token_match = cls.token_match
+        prefix_search = util.compile_prefix_regex(cls.prefixes).search \
+                        if cls.prefixes else None
+        suffix_search = util.compile_suffix_regex(cls.suffixes).search \
+                        if cls.suffixes else None
+        infix_finditer = util.compile_infix_regex(cls.infixes).finditer \
+                         if cls.infixes else None
         vocab = nlp.vocab if nlp is not None else cls.create_vocab(nlp)
         return Tokenizer(vocab, rules=rules,
                          prefix_search=prefix_search, suffix_search=suffix_search,
                          infix_finditer=infix_finditer, token_match=token_match)

     @classmethod
-    def create_tagger(cls, nlp=None):
-        if nlp is None:
-            return Tagger(cls.create_vocab(), features=cls.tagger_features)
-        elif nlp.path is False:
-            return Tagger(nlp.vocab, features=cls.tagger_features)
-        elif nlp.path is None or not (nlp.path / 'pos').exists():
-            return None
-        else:
-            return Tagger.load(nlp.path / 'pos', nlp.vocab)
-
-    @classmethod
-    def create_parser(cls, nlp=None, **cfg):
-        if nlp is None:
-            return DependencyParser(cls.create_vocab(), features=cls.parser_features,
-                                    **cfg)
-        elif nlp.path is False:
-            return DependencyParser(nlp.vocab, features=cls.parser_features, **cfg)
-        elif nlp.path is None or not (nlp.path / 'deps').exists():
-            return None
-        else:
-            return DependencyParser.load(nlp.path / 'deps', nlp.vocab, **cfg)
-
-    @classmethod
-    def create_entity(cls, nlp=None, **cfg):
-        if nlp is None:
-            return EntityRecognizer(cls.create_vocab(), features=cls.entity_features, **cfg)
-        elif nlp.path is False:
-            return EntityRecognizer(nlp.vocab, features=cls.entity_features, **cfg)
-        elif nlp.path is None or not (nlp.path / 'ner').exists():
-            return None
-        else:
-            return EntityRecognizer.load(nlp.path / 'ner', nlp.vocab, **cfg)
-
-    @classmethod
-    def create_matcher(cls, nlp=None):
-        if nlp is None:
-            return Matcher(cls.create_vocab())
-        elif nlp.path is False:
-            return Matcher(nlp.vocab)
-        elif nlp.path is None or not (nlp.path / 'vocab').exists():
-            return None
-        else:
-            return Matcher.load(nlp.path / 'vocab', nlp.vocab)
-
-    @classmethod
-    def create_pipeline(self, nlp=None):
+    def create_pipeline(cls, nlp=None):
+        meta = nlp.meta if nlp is not None else {}
+        # Resolve strings, like "cnn", "lstm", etc
         pipeline = []
-        if nlp is None:
-            return []
-        if nlp.tagger:
-            pipeline.append(nlp.tagger)
-        if nlp.parser:
-            pipeline.append(nlp.parser)
-            pipeline.append(PseudoProjectivity.deprojectivize)
-        if nlp.entity:
-            pipeline.append(nlp.entity)
+        for entry in cls.pipeline:
+            factory = cls.Defaults.factories[entry]
+            pipeline.append(factory(self, **meta.get(entry, {})))
         return pipeline

+    factories = {
+        'make_doc': create_tokenizer,
+        'tensor': lambda nlp, **cfg: TokenVectorEncoder(nlp.vocab, **cfg),
+        'tags': lambda nlp, **cfg: Tagger(nlp.vocab, **cfg),
+        'dependencies': lambda nlp, **cfg: NeuralDependencyParser(nlp.vocab, **cfg),
+        'entities': lambda nlp, **cfg: NeuralEntityRecognizer(nlp.vocab, **cfg),
+    }
+
     token_match = TOKEN_MATCH
     prefixes = tuple(TOKENIZER_PREFIXES)
     suffixes = tuple(TOKENIZER_SUFFIXES)
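
The factories dict above lets pipeline components be declared by name and built lazily from the shared vocab. Combined with the Language.__init__ shown in the next hunk, string entries are resolved through Defaults.factories; a minimal sketch, not from this commit:

    nlp = Language(pipeline=['tensor', 'dependencies', 'entities'])
    # each string is replaced by the constructed component, e.g.
    # 'tensor' -> TokenVectorEncoder(nlp.vocab)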
@@ -161,120 +100,30 @@ class Language(object):
     Defaults = BaseDefaults
     lang = None

-    @classmethod
-    def setup_directory(cls, path, **configs):
-        """
-        Initialise a model directory.
-        """
-        for name, config in configs.items():
-            directory = path / name
-            if directory.exists():
-                shutil.rmtree(str(directory))
-            directory.mkdir()
-            with (directory / 'config.json').open('w') as file_:
-                data = json_dumps(config)
-                file_.write(data)
-        if not (path / 'vocab').exists():
-            (path / 'vocab').mkdir()
-
-    @classmethod
-    @contextmanager
-    def train(cls, path, gold_tuples, **configs):
-        parser_cfg = configs.get('deps', {})
-        if parser_cfg.get('pseudoprojective'):
-            # preprocess training data here before ArcEager.get_labels() is called
-            gold_tuples = PseudoProjectivity.preprocess_training_data(gold_tuples)
-
-        for subdir in ('deps', 'ner', 'pos'):
-            if subdir not in configs:
-                configs[subdir] = {}
-        if parser_cfg:
-            configs['deps']['actions'] = ArcEager.get_actions(gold_parses=gold_tuples)
-        if 'ner' in configs:
-            configs['ner']['actions'] = BiluoPushDown.get_actions(gold_parses=gold_tuples)
-
-        cls.setup_directory(path, **configs)
-
-        self = cls(
-            path=path,
-            vocab=False,
-            tokenizer=False,
-            tagger=False,
-            parser=False,
-            entity=False,
-            matcher=False,
-            vectors=False,
-            pipeline=False)
-        self.vocab = self.Defaults.create_vocab(self)
-        self.tokenizer = self.Defaults.create_tokenizer(self)
-        self.tagger = self.Defaults.create_tagger(self)
-        self.parser = self.Defaults.create_parser(self)
-        self.entity = self.Defaults.create_entity(self)
-        self.pipeline = self.Defaults.create_pipeline(self)
-        yield Trainer(self, gold_tuples)
-        self.end_training()
-        self.save_to_directory(path)
-
-    def __init__(self, **overrides):
-        """
-        Create or load the pipeline.
-        Arguments:
-            **overrides: Keyword arguments indicating which defaults to override.
-        Returns:
-            Language: The newly constructed object.
-        """
-        if 'data_dir' in overrides and 'path' not in overrides:
-            raise ValueError("The argument 'data_dir' has been renamed to 'path'")
-        path = util.ensure_path(overrides.get('path', True))
-        if path is True:
-            path = util.get_data_path() / self.lang
-            if not path.exists() and 'path' not in overrides:
-                path = None
-        self.meta = overrides.get('meta', {})
-        self.path = path
-
-        self.vocab = self.Defaults.create_vocab(self) \
-                     if 'vocab' not in overrides \
-                     else overrides['vocab']
-        add_vectors = self.Defaults.add_vectors(self) \
-                      if 'add_vectors' not in overrides \
-                      else overrides['add_vectors']
-        if self.vocab and add_vectors:
-            add_vectors(self.vocab)
-        self.tokenizer = self.Defaults.create_tokenizer(self) \
-                         if 'tokenizer' not in overrides \
-                         else overrides['tokenizer']
-        self.tagger = self.Defaults.create_tagger(self) \
-                      if 'tagger' not in overrides \
-                      else overrides['tagger']
-        self.parser = self.Defaults.create_parser(self) \
-                      if 'parser' not in overrides \
-                      else overrides['parser']
-        self.entity = self.Defaults.create_entity(self) \
-                      if 'entity' not in overrides \
-                      else overrides['entity']
-        self.matcher = self.Defaults.create_matcher(self) \
-                       if 'matcher' not in overrides \
-                       else overrides['matcher']
-
-        if 'make_doc' in overrides:
-            self.make_doc = overrides['make_doc']
-        elif 'create_make_doc' in overrides:
-            self.make_doc = overrides['create_make_doc'](self)
-        elif not hasattr(self, 'make_doc'):
-            self.make_doc = lambda text: self.tokenizer(text)
-        if 'pipeline' in overrides:
-            self.pipeline = overrides['pipeline']
-        elif 'create_pipeline' in overrides:
-            self.pipeline = overrides['create_pipeline'](self)
+    def __init__(self, vocab=True, make_doc=True, pipeline=None, meta={}):
+        self.meta = dict(meta)
+
+        if vocab is True:
+            factory = self.Defaults.create_vocab
+            vocab = factory(self, **meta.get('vocab', {}))
+        self.vocab = vocab
+        if make_doc is True:
+            factory = self.Defaults.create_tokenizer
+            make_doc = factory(self, **meta.get('tokenizer', {}))
+        self.make_doc = make_doc
+        if pipeline is True:
+            self.pipeline = self.Defaults.create_pipeline(self)
+        elif pipeline:
+            self.pipeline = list(pipeline)
+            # Resolve strings, like "cnn", "lstm", etc
+            for i, entry in enumerate(self.pipeline):
+                if entry in self.Defaults.factories:
+                    factory = self.Defaults.factories[entry]
+                    self.pipeline[i] = factory(self, **meta.get(entry, {}))
         else:
-            self.pipeline = [self.tagger, self.parser, self.matcher, self.entity]
+            self.pipeline = []

-    def __call__(self, text, tag=True, parse=True, entity=True):
+    def __call__(self, text, **disabled):
         """
         Apply the pipeline to some text. The text can span multiple sentences,
         and can contain arbtrary whitespace. Alignment into the original string
@@ -294,18 +143,24 @@ class Language(object):
         ('An', 'NN')
         """
         doc = self.make_doc(text)
-        if self.entity and entity:
-            # Add any of the entity labels already set, in case we don't have them.
-            for token in doc:
-                if token.ent_type != 0:
-                    self.entity.add_label(token.ent_type)
-        skip = {self.tagger: not tag, self.parser: not parse, self.entity: not entity}
         for proc in self.pipeline:
-            if proc and not skip.get(proc):
-                proc(doc)
+            name = getattr(proc, 'name', None)
+            if name in disabled and not disabled[named]:
+                continue
+            proc(doc)
         return doc

-    def pipe(self, texts, tag=True, parse=True, entity=True, n_threads=2, batch_size=1000):
+    @contextmanager
+    def begin_training(self, gold_tuples, **cfg):
+        contexts = []
+        for proc in self.pipeline:
+            if hasattr(proc, 'begin_training'):
+                context = proc.begin_training(gold_tuples, pipeline=self.pipeline)
+                contexts.append(context)
+        trainer = Trainer(self, gold_tuples, **cfg)
+        yield trainer, trainer.optimizer
+
+    def pipe(self, texts, n_threads=2, batch_size=1000, **disabled):
         """
         Process texts as a stream, and yield Doc objects in order.
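
Taken together, begin_training and the **disabled keyword handling suggest usage along these lines; a sketch under the assumption that components are filtered by their name attribute as intended above, not from this commit:

    nlp = Language(pipeline=['tensor', 'dependencies', 'entities'])
    with nlp.begin_training(train_data) as (trainer, optimizer):
        for itn, epoch in enumerate(trainer.epochs(n_iter)):
            pass  # update each component per batch, as in the training script above
    doc = nlp(u'A sentence about London.', entity=False)  # skip the component named 'entity'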
@@ -317,55 +172,28 @@ class Language(object):
            parse (bool)
            entity (bool)
         """
-        skip = {self.tagger: not tag, self.parser: not parse, self.entity: not entity}
         stream = (self.make_doc(text) for text in texts)
         for proc in self.pipeline:
-            if proc and not skip.get(proc):
-                if hasattr(proc, 'pipe'):
-                    stream = proc.pipe(stream, n_threads=n_threads, batch_size=batch_size)
-                else:
-                    stream = (proc(item) for item in stream)
+            name = getattr(proc, 'name', None)
+            if name in disabled and not disabled[named]:
+                continue
+            if hasattr(proc, 'pipe'):
+                stream = proc.pipe(stream, n_threads=n_threads, batch_size=batch_size)
+            else:
+                stream = (proc(item) for item in stream)
         for doc in stream:
             yield doc

-    def save_to_directory(self, path):
-        """
-        Save the Vocab, StringStore and pipeline to a directory.
-        Arguments:
-            path (string or pathlib path): Path to save the model.
-        """
-        configs = {
-            'pos': self.tagger.cfg if self.tagger else {},
-            'deps': self.parser.cfg if self.parser else {},
-            'ner': self.entity.cfg if self.entity else {},
-        }
-
-        path = util.ensure_path(path)
-        if not path.exists():
-            path.mkdir()
-        self.setup_directory(path, **configs)
-
-        strings_loc = path / 'vocab' / 'strings.json'
-        with strings_loc.open('w', encoding='utf8') as file_:
-            self.vocab.strings.dump(file_)
-        self.vocab.dump(path / 'vocab' / 'lexemes.bin')
-        # TODO: Word vectors?
-        if self.tagger:
-            self.tagger.model.dump(str(path / 'pos' / 'model'))
-        if self.parser:
-            self.parser.model.dump(str(path / 'deps' / 'model'))
-        if self.entity:
-            self.entity.model.dump(str(path / 'ner' / 'model'))
-
-    def end_training(self, path=None):
-        if self.tagger:
-            self.tagger.model.end_training()
-        if self.parser:
-            self.parser.model.end_training()
-        if self.entity:
-            self.entity.model.end_training()
-        # NB: This is slightly different from before --- we no longer default
-        # to taking nlp.path
-        if path is not None:
-            self.save_to_directory(path)
+    def to_disk(self, path):
+        raise NotImplemented
+
+    def from_disk(self, path):
+        raise NotImplemented
+
+    def to_bytes(self, path):
+        raise NotImplemented
+
+    def from_bytes(self, path):
+        raise NotImplemented

View File

@@ -9,7 +9,8 @@ import numpy
 cimport numpy as np
 from .tokens.doc cimport Doc
-from .syntax.parser cimport Parser
+from .syntax.parser cimport Parser as LinearParser
+from .syntax.nn_parser cimport Parser as NeuralParser
 from .syntax.parser import get_templates as get_feature_templates
 from .syntax.beam_parser cimport BeamParser
 from .syntax.ner cimport BiluoPushDown
@@ -30,13 +31,13 @@ from .attrs import ID, LOWER, PREFIX, SUFFIX, SHAPE, TAG, DEP
 from ._ml import Tok2Vec, flatten, get_col, doc2feats


 class TokenVectorEncoder(object):
     '''Assign position-sensitive vectors to tokens, using a CNN or RNN.'''
+    name = 'tok2vec'

     @classmethod
     def Model(cls, width=128, embed_size=5000, **cfg):
-        return Tok2Vec(width, embed_size, preprocess=False)
+        return Tok2Vec(width, embed_size, preprocess=doc2feats())

     def __init__(self, vocab, model=True, **cfg):
         self.vocab = vocab
@@ -76,10 +77,11 @@ class TokenVectorEncoder(object):
             doc.vocab.morphology.assign_tag_id(&doc.c[j], tag_id)
             idx += 1

-    def update(self, docs_feats, golds, drop=0., sgd=None):
+    def update(self, docs, golds, drop=0., sgd=None):
+        return 0.0
         cdef int i, j, idx
         cdef GoldParse gold
-        docs, feats = docs_feats
+        feats = self.doc2feats(docs)

         scores, finish_update = self.tagger.begin_update(feats, drop=drop)
         tag_index = {tag: i for i, tag in enumerate(docs[0].vocab.morphology.tag_names)}
@@ -95,7 +97,7 @@ class TokenVectorEncoder(object):
             finish_update(d_scores, sgd)


-cdef class EntityRecognizer(Parser):
+cdef class EntityRecognizer(LinearParser):
     """
     Annotate named entities on Doc objects.
     """
@@ -104,7 +106,7 @@ cdef class EntityRecognizer(Parser):
     feature_templates = get_feature_templates('ner')

     def add_label(self, label):
-        Parser.add_label(self, label)
+        LinearParser.add_label(self, label)
         if isinstance(label, basestring):
             label = self.vocab.strings[label]
@@ -118,21 +120,31 @@ cdef class BeamEntityRecognizer(BeamParser):
     feature_templates = get_feature_templates('ner')

     def add_label(self, label):
-        Parser.add_label(self, label)
+        LinearParser.add_label(self, label)
         if isinstance(label, basestring):
             label = self.vocab.strings[label]


-cdef class DependencyParser(Parser):
+cdef class DependencyParser(LinearParser):
     TransitionSystem = ArcEager
     feature_templates = get_feature_templates('basic')

     def add_label(self, label):
-        Parser.add_label(self, label)
+        LinearParser.add_label(self, label)
         if isinstance(label, basestring):
             label = self.vocab.strings[label]


+cdef class NeuralDependencyParser(NeuralParser):
+    name = 'parser'
+    TransitionSystem = ArcEager
+
+
+cdef class NeuralEntityRecognizer(NeuralParser):
+    name = 'entity'
+    TransitionSystem = BiluoPushDown
+
+
 cdef class BeamDependencyParser(BeamParser):
     TransitionSystem = ArcEager
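
The renamed imports keep the old linear-model classes on LinearParser while the new Thinc-based components subclass NeuralParser; their name attributes ('parser', 'entity') are what the **disabled keywords in Language.__call__ and pipe key on, and the 'dependencies'/'entities' factory strings map onto them. A sketch, not from this commit:

    parser = NeuralDependencyParser(nlp.vocab)   # name == 'parser', ArcEager transitions
    ner = NeuralEntityRecognizer(nlp.vocab)      # name == 'entity', BILUO transitions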

View File

@@ -238,11 +238,7 @@ cdef class Parser:
         upper.begin_training(upper.ops.allocate((500, hidden_width)))
         return tok2vec, lower, upper

-    @classmethod
-    def Moves(cls):
-        return TransitionSystem()
-
-    def __init__(self, Vocab vocab, moves=True, model=True, **cfg):
+    def __init__(self, Vocab vocab, model=True, **cfg):
         """
         Create a Parser.
@@ -262,9 +258,13 @@ cdef class Parser:
             Arbitrary configuration parameters. Set to the .cfg attribute
         """
         self.vocab = vocab
-        self.moves = self.Moves(self.vocab) if moves is True else moves
-        self.model = self.Model(self.moves.n_moves) if model is True else model
+        self.moves = self.TransitionSystem(self.vocab.strings, {})
         self.cfg = cfg
+        if 'actions' in self.cfg:
+            for action, labels in self.cfg.get('actions', {}).items():
+                for label in labels:
+                    self.moves.add_action(action, label)
+        self.model = model

     def __reduce__(self):
         return (Parser, (self.vocab, self.moves, self.model, self.cfg), None, None)
@@ -440,6 +440,17 @@ cdef class Parser:
         # order, or the model goes out of synch
         self.cfg.setdefault('extra_labels', []).append(label)

+    def begin_training(self, gold_tuples, **cfg):
+        if 'model' in cfg:
+            self.model = cfg['model']
+        actions = self.moves.get_actions(gold_parses=gold_tuples)
+        for action, labels in actions.items():
+            for label in labels:
+                self.moves.add_action(action, label)
+        if self.model is True:
+            tok2vec = cfg['pipeline'][0].model
+            self.model = self.Model(self.moves.n_moves, tok2vec=tok2vec, **cfg)
+

 class ParserStateError(ValueError):
     def __init__(self, doc):
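
In Language.begin_training each component receives pipeline=self.pipeline, so a parser whose model is still True borrows the first component's model (the token-vector encoder) as its tok2vec input. A sketch of that wiring, not from this commit:

    tensorizer = nlp.pipeline[0]   # TokenVectorEncoder, built by the 'tensor' factory
    parser = nlp.pipeline[1]       # NeuralDependencyParser
    parser.begin_training(gold_tuples, pipeline=nlp.pipeline)
    # inside begin_training: tok2vec = cfg['pipeline'][0].model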

View File

@@ -3,12 +3,14 @@ from __future__ import absolute_import, unicode_literals
 import random
 import tqdm
-from cytoolz import partition_all
 from thinc.neural.optimizers import Adam
 from thinc.neural.ops import NumpyOps, CupyOps

 from .gold import GoldParse, merge_sents
 from .scorer import Scorer
+from .tokens.doc import Doc


 class Trainer(object):
@@ -19,6 +21,7 @@ class Trainer(object):
         self.nlp = nlp
         self.gold_tuples = gold_tuples
         self.nr_epoch = 0
+        self.optimizer = Adam(NumpyOps(), 0.001)

     def epochs(self, nr_epoch, augment_data=None, gold_preproc=False):
         cached_golds = {}
@@ -75,9 +78,9 @@ class Trainer(object):
     def make_docs(self, raw_text, paragraph_tuples):
         if raw_text is not None:
-            return [self.nlp.tokenizer(raw_text)]
+            return [self.nlp.make_doc(raw_text)]
         else:
-            return [self.nlp.tokenizer.tokens_from_list(sent_tuples[0][1])
+            return [Doc(self.nlp.vocab, words=sent_tuples[0][1])
                     for sent_tuples in paragraph_tuples]

     def make_golds(self, docs, paragraph_tuples):
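
make_docs now builds pre-tokenized documents directly from word lists instead of going through tokenizer.tokens_from_list. A minimal sketch, not from this commit:

    from spacy.tokens.doc import Doc
    doc = Doc(nlp.vocab, words=[u'Hello', u'world'])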