Merge branch 'master' of ssh://github.com/explosion/spaCy

This commit is contained in:
Matthew Honnibal 2016-11-26 12:36:18 +01:00
commit 296d33a4fc
13 changed files with 130 additions and 157 deletions

View File

@ -100,7 +100,7 @@ def evaluate(Language, gold_tuples, model_dir, gold_preproc=False, verbose=False
nlp.entity(tokens) nlp.entity(tokens)
else: else:
tokens = nlp(raw_text) tokens = nlp(raw_text)
gold = GoldParse(tokens, annot_tuples) gold = GoldParse.from_annot_tuples(tokens, annot_tuples)
scorer.score(tokens, gold, verbose=verbose) scorer.score(tokens, gold, verbose=verbose)
return scorer return scorer

View File

@ -1,3 +1,4 @@
from __future__ import unicode_literals
import plac import plac
import json import json
from os import path from os import path
@ -5,106 +6,25 @@ import shutil
import os import os
import random import random
import io import io
import pathlib
from spacy.syntax.util import Config from spacy.tokens import Doc
from spacy.syntax.nonproj import PseudoProjectivity
from spacy.language import Language
from spacy.gold import GoldParse from spacy.gold import GoldParse
from spacy.tokenizer import Tokenizer
from spacy.vocab import Vocab from spacy.vocab import Vocab
from spacy.tagger import Tagger from spacy.tagger import Tagger
from spacy.syntax.parser import Parser from spacy.pipeline import DependencyParser
from spacy.syntax.arc_eager import ArcEager
from spacy.syntax.parser import get_templates from spacy.syntax.parser import get_templates
from spacy.syntax.arc_eager import ArcEager
from spacy.scorer import Scorer from spacy.scorer import Scorer
import spacy.attrs import spacy.attrs
import io
from spacy.language import Language
from spacy.tagger import W_orth
TAGGER_TEMPLATES = (
(W_orth,),
)
try:
from codecs import open
except ImportError:
pass
class TreebankParser(object):
@staticmethod
def setup_model_dir(model_dir, labels, templates, feat_set='basic', seed=0):
dep_model_dir = path.join(model_dir, 'deps')
pos_model_dir = path.join(model_dir, 'pos')
if path.exists(dep_model_dir):
shutil.rmtree(dep_model_dir)
if path.exists(pos_model_dir):
shutil.rmtree(pos_model_dir)
os.mkdir(dep_model_dir)
os.mkdir(pos_model_dir)
Config.write(dep_model_dir, 'config', features=feat_set, seed=seed,
labels=labels)
@classmethod
def from_dir(cls, tag_map, model_dir):
vocab = Vocab(tag_map=tag_map, get_lex_attr=Language.default_lex_attrs())
vocab.get_lex_attr[spacy.attrs.LANG] = lambda _: 0
tokenizer = Tokenizer(vocab, {}, None, None, None)
tagger = Tagger.blank(vocab, TAGGER_TEMPLATES)
cfg = Config.read(path.join(model_dir, 'deps'), 'config')
parser = Parser.from_dir(path.join(model_dir, 'deps'), vocab.strings, ArcEager)
return cls(vocab, tokenizer, tagger, parser)
def __init__(self, vocab, tokenizer, tagger, parser):
self.vocab = vocab
self.tokenizer = tokenizer
self.tagger = tagger
self.parser = parser
def train(self, words, tags, heads, deps):
tokens = self.tokenizer.tokens_from_list(list(words))
self.tagger.train(tokens, tags)
tokens = self.tokenizer.tokens_from_list(list(words))
ids = range(len(words))
ner = ['O'] * len(words)
gold = GoldParse(tokens, ((ids, words, tags, heads, deps, ner)),
make_projective=False)
self.tagger(tokens)
if gold.is_projective:
try:
self.parser.train(tokens, gold)
except:
for id_, word, head, dep in zip(ids, words, heads, deps):
print(id_, word, head, dep)
raise
def __call__(self, words, tags=None):
tokens = self.tokenizer.tokens_from_list(list(words))
if tags is None:
self.tagger(tokens)
else:
self.tagger.tag_from_strings(tokens, tags)
self.parser(tokens)
return tokens
def end_training(self, data_dir):
self.parser.model.end_training()
self.parser.model.dump(path.join(data_dir, 'deps', 'model'))
self.tagger.model.end_training()
self.tagger.model.dump(path.join(data_dir, 'pos', 'model'))
strings_loc = path.join(data_dir, 'vocab', 'strings.json')
with io.open(strings_loc, 'w', encoding='utf8') as file_:
self.vocab.strings.dump(file_)
self.vocab.dump(path.join(data_dir, 'vocab', 'lexemes.bin'))
def read_conllx(loc): def read_conllx(loc):
with open(loc, 'r', 'utf8') as file_: with io.open(loc, 'r', encoding='utf8') as file_:
text = file_.read() text = file_.read()
for sent in text.strip().split('\n\n'): for sent in text.strip().split('\n\n'):
lines = sent.strip().split('\n') lines = sent.strip().split('\n')
@ -113,24 +33,31 @@ def read_conllx(loc):
lines.pop(0) lines.pop(0)
tokens = [] tokens = []
for line in lines: for line in lines:
id_, word, lemma, pos, tag, morph, head, dep, _1, _2 = line.split() id_, word, lemma, tag, pos, morph, head, dep, _1, _2 = line.split()
if '-' in id_: if '-' in id_:
continue continue
id_ = int(id_) - 1 try:
head = (int(head) - 1) if head != '0' else id_ id_ = int(id_) - 1
dep = 'ROOT' if dep == 'root' else dep head = (int(head) - 1) if head != '0' else id_
tokens.append((id_, word, tag, head, dep, 'O')) dep = 'ROOT' if dep == 'root' else dep
tuples = zip(*tokens) tokens.append((id_, word, tag, head, dep, 'O'))
yield (None, [(tuples, [])]) except:
print(line)
raise
tuples = [list(t) for t in zip(*tokens)]
yield (None, [[tuples, []]])
def score_model(nlp, gold_docs, verbose=False): def score_model(vocab, tagger, parser, gold_docs, verbose=False):
scorer = Scorer() scorer = Scorer()
for _, gold_doc in gold_docs: for _, gold_doc in gold_docs:
for annot_tuples, _ in gold_doc: for (ids, words, tags, heads, deps, entities), _ in gold_doc:
tokens = nlp(list(annot_tuples[1]), tags=list(annot_tuples[2])) doc = Doc(vocab, words=words)
gold = GoldParse(tokens, annot_tuples) tagger(doc)
scorer.score(tokens, gold, verbose=verbose) parser(doc)
PseudoProjectivity.deprojectivize(doc)
gold = GoldParse(doc, tags=tags, heads=heads, deps=deps)
scorer.score(doc, gold, verbose=verbose)
return scorer return scorer
@ -138,22 +65,45 @@ def main(train_loc, dev_loc, model_dir, tag_map_loc):
with open(tag_map_loc) as file_: with open(tag_map_loc) as file_:
tag_map = json.loads(file_.read()) tag_map = json.loads(file_.read())
train_sents = list(read_conllx(train_loc)) train_sents = list(read_conllx(train_loc))
labels = ArcEager.get_labels(train_sents) train_sents = PseudoProjectivity.preprocess_training_data(train_sents)
templates = get_templates('basic')
TreebankParser.setup_model_dir(model_dir, labels, templates) actions = ArcEager.get_actions(gold_parses=train_sents)
features = get_templates('basic')
nlp = TreebankParser.from_dir(tag_map, model_dir) model_dir = pathlib.Path(model_dir)
with (model_dir / 'deps' / 'config.json').open('w') as file_:
json.dump({'pseudoprojective': True, 'labels': actions, 'features': features}, file_)
vocab = Vocab(lex_attr_getters=Language.Defaults.lex_attr_getters, tag_map=tag_map)
# Populate vocab
for _, doc_sents in train_sents:
for (ids, words, tags, heads, deps, ner), _ in doc_sents:
for word in words:
_ = vocab[word]
for dep in deps:
_ = vocab[dep]
for tag in tags:
_ = vocab[tag]
for tag in tags:
assert tag in tag_map, repr(tag)
tagger = Tagger(vocab, tag_map=tag_map)
parser = DependencyParser(vocab, actions=actions, features=features)
for itn in range(15): for itn in range(15):
for _, doc_sents in train_sents: for _, doc_sents in train_sents:
for (ids, words, tags, heads, deps, ner), _ in doc_sents: for (ids, words, tags, heads, deps, ner), _ in doc_sents:
nlp.train(words, tags, heads, deps) doc = Doc(vocab, words=words)
gold = GoldParse(doc, tags=tags, heads=heads, deps=deps)
tagger(doc)
parser.update(doc, gold)
doc = Doc(vocab, words=words)
tagger.update(doc, gold)
random.shuffle(train_sents) random.shuffle(train_sents)
scorer = score_model(nlp, read_conllx(dev_loc)) scorer = score_model(vocab, tagger, parser, read_conllx(dev_loc))
print('%d:\t%.3f\t%.3f' % (itn, scorer.uas, scorer.tags_acc)) print('%d:\t%.3f\t%.3f' % (itn, scorer.uas, scorer.tags_acc))
nlp = Language(vocab=vocab, tagger=tagger, parser=parser)
nlp.end_training(model_dir) nlp.end_training(model_dir)
scorer = score_model(nlp, read_conllx(dev_loc)) scorer = score_model(vocab, tagger, parser, read_conllx(dev_loc))
print('%d:\t%.3f\t%.3f\t%.3f' % (itn, scorer.uas, scorer.las, scorer.tags_acc)) print('%d:\t%.3f\t%.3f\t%.3f' % (itn, scorer.uas, scorer.las, scorer.tags_acc))

View File

@ -116,6 +116,10 @@ def intify_attrs(stringy_attrs, strings_map=None, _do_deprecated=False):
stringy_attrs["TAG"] = stringy_attrs.pop("pos") stringy_attrs["TAG"] = stringy_attrs.pop("pos")
if 'morph' in stringy_attrs: if 'morph' in stringy_attrs:
morphs = stringy_attrs.pop('morph') morphs = stringy_attrs.pop('morph')
if 'number' in stringy_attrs:
stringy_attrs.pop('number')
if 'tenspect' in stringy_attrs:
stringy_attrs.pop('tenspect')
# for name, value in morphs.items(): # for name, value in morphs.items():
# stringy_attrs[name] = value # stringy_attrs[name] = value
for name, value in stringy_attrs.items(): for name, value in stringy_attrs.items():

View File

@ -19,6 +19,7 @@ cdef class GoldParse:
cdef int length cdef int length
cdef readonly int loss cdef readonly int loss
cdef readonly list words
cdef readonly list tags cdef readonly list tags
cdef readonly list heads cdef readonly list heads
cdef readonly list labels cdef readonly list labels

View File

@ -19,6 +19,8 @@ def tags_to_entities(tags):
entities = [] entities = []
start = None start = None
for i, tag in enumerate(tags): for i, tag in enumerate(tags):
if tag is None:
continue
if tag.startswith('O'): if tag.startswith('O'):
# TODO: We shouldn't be getting these malformed inputs. Fix this. # TODO: We shouldn't be getting these malformed inputs. Fix this.
if start is not None: if start is not None:
@ -249,7 +251,7 @@ cdef class GoldParse:
if deps is None: if deps is None:
deps = [None for _ in doc] deps = [None for _ in doc]
if entities is None: if entities is None:
entities = [None for _ in doc] entities = ['-' for _ in doc]
elif len(entities) == 0: elif len(entities) == 0:
entities = ['O' for _ in doc] entities = ['O' for _ in doc]
elif not isinstance(entities[0], basestring): elif not isinstance(entities[0], basestring):
@ -266,6 +268,7 @@ cdef class GoldParse:
self.c.labels = <int*>self.mem.alloc(len(doc), sizeof(int)) self.c.labels = <int*>self.mem.alloc(len(doc), sizeof(int))
self.c.ner = <Transition*>self.mem.alloc(len(doc), sizeof(Transition)) self.c.ner = <Transition*>self.mem.alloc(len(doc), sizeof(Transition))
self.words = [None] * len(doc)
self.tags = [None] * len(doc) self.tags = [None] * len(doc)
self.heads = [None] * len(doc) self.heads = [None] * len(doc)
self.labels = [''] * len(doc) self.labels = [''] * len(doc)
@ -279,6 +282,7 @@ cdef class GoldParse:
for i, gold_i in enumerate(self.cand_to_gold): for i, gold_i in enumerate(self.cand_to_gold):
if doc[i].text.isspace(): if doc[i].text.isspace():
self.words[i] = doc[i].text
self.tags[i] = 'SP' self.tags[i] = 'SP'
self.heads[i] = None self.heads[i] = None
self.labels[i] = None self.labels[i] = None
@ -286,6 +290,7 @@ cdef class GoldParse:
if gold_i is None: if gold_i is None:
pass pass
else: else:
self.words[i] = words[gold_i]
self.tags[i] = tags[gold_i] self.tags[i] = tags[gold_i]
self.heads[i] = self.gold_to_cand[heads[gold_i]] self.heads[i] = self.gold_to_cand[heads[gold_i]]
self.labels[i] = deps[gold_i] self.labels[i] = deps[gold_i]

View File

@ -5,10 +5,7 @@ import pathlib
from contextlib import contextmanager from contextlib import contextmanager
import shutil import shutil
try: import ujson as json
import ujson as json
except ImportError:
import json
try: try:
@ -31,6 +28,8 @@ from .attrs import TAG, DEP, ENT_IOB, ENT_TYPE, HEAD, PROB, LANG, IS_STOP
from .syntax.parser import get_templates from .syntax.parser import get_templates
from .syntax.nonproj import PseudoProjectivity from .syntax.nonproj import PseudoProjectivity
from .pipeline import DependencyParser, EntityRecognizer from .pipeline import DependencyParser, EntityRecognizer
from .syntax.arc_eager import ArcEager
from .syntax.ner import BiluoPushDown
class BaseDefaults(object): class BaseDefaults(object):
@ -96,26 +95,27 @@ class BaseDefaults(object):
return Tagger.load(nlp.path / 'pos', nlp.vocab) return Tagger.load(nlp.path / 'pos', nlp.vocab)
@classmethod @classmethod
def create_parser(cls, nlp=None): def create_parser(cls, nlp=None, **cfg):
if nlp is None: if nlp is None:
return DependencyParser(cls.create_vocab(), features=cls.parser_features) return DependencyParser(cls.create_vocab(), features=cls.parser_features,
**cfg)
elif nlp.path is False: elif nlp.path is False:
return DependencyParser(nlp.vocab, features=cls.parser_features) return DependencyParser(nlp.vocab, features=cls.parser_features, **cfg)
elif nlp.path is None or not (nlp.path / 'deps').exists(): elif nlp.path is None or not (nlp.path / 'deps').exists():
return None return None
else: else:
return DependencyParser.load(nlp.path / 'deps', nlp.vocab) return DependencyParser.load(nlp.path / 'deps', nlp.vocab, **cfg)
@classmethod @classmethod
def create_entity(cls, nlp=None): def create_entity(cls, nlp=None, **cfg):
if nlp is None: if nlp is None:
return EntityRecognizer(cls.create_vocab(), features=cls.entity_features) return EntityRecognizer(cls.create_vocab(), features=cls.entity_features, **cfg)
elif nlp.path is False: elif nlp.path is False:
return EntityRecognizer(nlp.vocab, features=cls.entity_features) return EntityRecognizer(nlp.vocab, features=cls.entity_features, **cfg)
elif nlp.path is None or not (nlp.path / 'ner').exists(): elif nlp.path is None or not (nlp.path / 'ner').exists():
return None return None
else: else:
return EntityRecognizer.load(nlp.path / 'ner', nlp.vocab) return EntityRecognizer.load(nlp.path / 'ner', nlp.vocab, **cfg)
@classmethod @classmethod
def create_matcher(cls, nlp=None): def create_matcher(cls, nlp=None):
@ -216,14 +216,14 @@ class Language(object):
# preprocess training data here before ArcEager.get_labels() is called # preprocess training data here before ArcEager.get_labels() is called
gold_tuples = PseudoProjectivity.preprocess_training_data(gold_tuples) gold_tuples = PseudoProjectivity.preprocess_training_data(gold_tuples)
parser_cfg['labels'] = ArcEager.get_labels(gold_tuples) parser_cfg['actions'] = ArcEager.get_actions(gold_parses=gold_tuples)
entity_cfg['labels'] = BiluoPushDown.get_labels(gold_tuples) entity_cfg['actions'] = BiluoPushDown.get_actions(gold_parses=gold_tuples)
with (dep_model_dir / 'config.json').open('wb') as file_: with (dep_model_dir / 'config.json').open('w') as file_:
json.dump(parser_cfg, file_) json.dump(parser_cfg, file_)
with (ner_model_dir / 'config.json').open('wb') as file_: with (ner_model_dir / 'config.json').open('w') as file_:
json.dump(entity_cfg, file_) json.dump(entity_cfg, file_)
with (pos_model_dir / 'config.json').open('wb') as file_: with (pos_model_dir / 'config.json').open('w') as file_:
json.dump(tagger_cfg, file_) json.dump(tagger_cfg, file_)
self = cls( self = cls(
@ -238,15 +238,12 @@ class Language(object):
vectors=False, vectors=False,
pipeline=False) pipeline=False)
self.defaults.parser_labels = parser_cfg['labels'] self.vocab = self.Defaults.create_vocab(self)
self.defaults.entity_labels = entity_cfg['labels'] self.tokenizer = self.Defaults.create_tokenizer(self)
self.tagger = self.Defaults.create_tagger(self)
self.vocab = self.defaults.Vocab() self.parser = self.Defaults.create_parser(self)
self.tokenizer = self.defaults.Tokenizer(self.vocab) self.entity = self.Defaults.create_entity(self)
self.tagger = self.defaults.Tagger(self.vocab, **tagger_cfg) self.pipeline = self.Defaults.create_pipeline(self)
self.parser = self.defaults.Parser(self.vocab, **parser_cfg)
self.entity = self.defaults.Entity(self.vocab, **entity_cfg)
self.pipeline = self.defaults.Pipeline(self)
yield Trainer(self, gold_tuples) yield Trainer(self, gold_tuples)
self.end_training() self.end_training()
@ -267,7 +264,7 @@ class Language(object):
add_vectors = self.Defaults.add_vectors(self) \ add_vectors = self.Defaults.add_vectors(self) \
if 'add_vectors' not in overrides \ if 'add_vectors' not in overrides \
else overrides['add_vectors'] else overrides['add_vectors']
if add_vectors: if self.vocab and add_vectors:
add_vectors(self.vocab) add_vectors(self.vocab)
self.tokenizer = self.Defaults.create_tokenizer(self) \ self.tokenizer = self.Defaults.create_tokenizer(self) \
if 'tokenizer' not in overrides \ if 'tokenizer' not in overrides \
@ -387,7 +384,7 @@ class Language(object):
else: else:
entity_iob_freqs = [] entity_iob_freqs = []
entity_type_freqs = [] entity_type_freqs = []
with (path / 'vocab' / 'serializer.json').open('wb') as file_: with (path / 'vocab' / 'serializer.json').open('w') as file_:
file_.write( file_.write(
json.dumps([ json.dumps([
(TAG, tagger_freqs), (TAG, tagger_freqs),

View File

@ -87,7 +87,7 @@ class Scorer(object):
gold_ents = set(tags_to_entities([annot[-1] for annot in gold.orig_annot])) gold_ents = set(tags_to_entities([annot[-1] for annot in gold.orig_annot]))
for id_, word, tag, head, dep, ner in gold.orig_annot: for id_, word, tag, head, dep, ner in gold.orig_annot:
gold_tags.add((id_, tag)) gold_tags.add((id_, tag))
if dep.lower() not in punct_labels: if dep is not None and dep.lower() not in punct_labels:
gold_deps.add((id_, head, dep.lower())) gold_deps.add((id_, head, dep.lower()))
cand_deps = set() cand_deps = set()
cand_tags = set() cand_tags = set()

View File

@ -439,7 +439,7 @@ cdef class ArcEager(TransitionSystem):
if move_costs[move] == -1: if move_costs[move] == -1:
move_costs[move] = move_cost_funcs[move](stcls, &gold.c) move_costs[move] = move_cost_funcs[move](stcls, &gold.c)
costs[i] = move_costs[move] + label_cost_funcs[move](stcls, &gold.c, label) costs[i] = move_costs[move] + label_cost_funcs[move](stcls, &gold.c, label)
n_gold += costs[i] == 0 n_gold += costs[i] <= 0
else: else:
is_valid[i] = False is_valid[i] = False
costs[i] = 9000 costs[i] = 9000
@ -456,8 +456,14 @@ cdef class ArcEager(TransitionSystem):
"before training and after parsing. Either pass make_projective=True " "before training and after parsing. Either pass make_projective=True "
"to the GoldParse class, or use PseudoProjectivity.preprocess_training_data") "to the GoldParse class, or use PseudoProjectivity.preprocess_training_data")
else: else:
print(gold.words)
print(gold.heads)
print(gold.labels)
raise ValueError( raise ValueError(
"Could not find a gold-standard action to supervise the dependency " "Could not find a gold-standard action to supervise the dependency "
"parser.\n" "parser.\n"
"The GoldParse was projective.") "The GoldParse was projective.\n"
"The transition system has %d actions.\n"
"State at failure:\n"
"%s" % (self.n_moves, stcls.print_state(gold.words)))
assert n_gold >= 1 assert n_gold >= 1

View File

@ -65,7 +65,7 @@ cdef class BiluoPushDown(TransitionSystem):
for action in (BEGIN, IN, LAST, UNIT): for action in (BEGIN, IN, LAST, UNIT):
actions[action][entity_type] = True actions[action][entity_type] = True
moves = ('M', 'B', 'I', 'L', 'U') moves = ('M', 'B', 'I', 'L', 'U')
for raw_text, sents in kwargs.get('gold_tuples', []): for raw_text, sents in kwargs.get('gold_parses', []):
for (ids, words, tags, heads, labels, biluo), _ in sents: for (ids, words, tags, heads, labels, biluo), _ in sents:
for i, ner_tag in enumerate(biluo): for i, ner_tag in enumerate(biluo):
if ner_tag != 'O' and ner_tag != '-': if ner_tag != 'O' and ner_tag != '-':

View File

@ -1,3 +1,4 @@
from __future__ import unicode_literals
from copy import copy from copy import copy
from ..tokens.doc cimport Doc from ..tokens.doc cimport Doc

View File

@ -76,7 +76,7 @@ cdef class ParserModel(AveragedPerceptron):
cdef class Parser: cdef class Parser:
"""Base class of the DependencyParser and EntityRecognizer.""" """Base class of the DependencyParser and EntityRecognizer."""
@classmethod @classmethod
def load(cls, path, Vocab vocab, TransitionSystem=None, require=False): def load(cls, path, Vocab vocab, TransitionSystem=None, require=False, **cfg):
"""Load the statistical model from the supplied path. """Load the statistical model from the supplied path.
Arguments: Arguments:
@ -92,7 +92,7 @@ cdef class Parser:
with (path / 'config.json').open() as file_: with (path / 'config.json').open() as file_:
cfg = json.load(file_) cfg = json.load(file_)
# TODO: remove this shim when we don't have to support older data # TODO: remove this shim when we don't have to support older data
if 'labels' in cfg: if 'labels' in cfg and 'actions' not in cfg:
cfg['actions'] = cfg.pop('labels') cfg['actions'] = cfg.pop('labels')
self = cls(vocab, TransitionSystem=TransitionSystem, model=None, **cfg) self = cls(vocab, TransitionSystem=TransitionSystem, model=None, **cfg)
if (path / 'model').exists(): if (path / 'model').exists():
@ -266,7 +266,7 @@ cdef class Parser:
loss += eg.costs[eg.guess] loss += eg.costs[eg.guess]
eg.fill_scores(0, eg.nr_class) eg.fill_scores(0, eg.nr_class)
eg.fill_costs(0, eg.nr_class) eg.fill_costs(0, eg.nr_class)
eg.fill_is_valid(0, eg.nr_class) eg.fill_is_valid(1, eg.nr_class)
return loss return loss
def step_through(self, Doc doc): def step_through(self, Doc doc):

View File

@ -14,22 +14,31 @@ class Trainer(object):
self.gold_tuples = gold_tuples self.gold_tuples = gold_tuples
def epochs(self, nr_epoch, augment_data=None, gold_preproc=False): def epochs(self, nr_epoch, augment_data=None, gold_preproc=False):
def _epoch(): cached_golds = {}
for raw_text, paragraph_tuples in self.gold_tuples: def _epoch(indices):
for i in indices:
raw_text, paragraph_tuples = self.gold_tuples[i]
if gold_preproc: if gold_preproc:
raw_text = None raw_text = None
else: else:
paragraph_tuples = merge_sents(paragraph_tuples) paragraph_tuples = merge_sents(paragraph_tuples)
if augment_data is not None: if augment_data is None:
docs = self.make_docs(raw_text, paragraph_tuples)
if i in cached_golds:
golds = cached_golds[i]
else:
golds = self.make_golds(docs, paragraph_tuples)
else:
raw_text, paragraph_tuples = augment_data(raw_text, paragraph_tuples) raw_text, paragraph_tuples = augment_data(raw_text, paragraph_tuples)
docs = self.make_docs(raw_text, paragraph_tuples) docs = self.make_docs(raw_text, paragraph_tuples)
golds = self.make_golds(docs, paragraph_tuples) golds = self.make_golds(docs, paragraph_tuples)
for doc, gold in zip(docs, golds): for doc, gold in zip(docs, golds):
yield doc, gold yield doc, gold
indices = list(range(len(self.gold_tuples)))
for itn in range(nr_epoch): for itn in range(nr_epoch):
random.shuffle(self.gold_tuples) random.shuffle(indices)
yield _epoch() yield _epoch(indices)
def update(self, doc, gold): def update(self, doc, gold):
for process in self.nlp.pipeline: for process in self.nlp.pipeline:
@ -62,8 +71,8 @@ class Trainer(object):
def make_golds(self, docs, paragraph_tuples): def make_golds(self, docs, paragraph_tuples):
if len(docs) == 1: if len(docs) == 1:
return [GoldParse(docs[0], sent_tuples[0]) return [GoldParse.from_annot_tuples(docs[0], sent_tuples[0])
for sent_tuples in paragraph_tuples] for sent_tuples in paragraph_tuples]
else: else:
return [GoldParse(doc, sent_tuples[0]) return [GoldParse.from_annot_tuples(doc, sent_tuples[0])
for doc, sent_tuples in zip(docs, paragraph_tuples)] for doc, sent_tuples in zip(docs, paragraph_tuples)]

View File

@ -36,10 +36,10 @@ p
+code("Example"). +code("Example").
doc = nlp(u'London is a big city in the United Kingdom.') doc = nlp(u'London is a big city in the United Kingdom.')
print(doc[0].text, doc[0].ent_iob, doc[0].ent_type_)) print(doc[0].text, doc[0].ent_iob, doc[0].ent_type_)
# (u'London', 2, u'GPE') # (u'London', 2, u'GPE')
print(doc[1].text, doc[1].ent_iob, doc[1].ent_type_)) print(doc[1].text, doc[1].ent_iob, doc[1].ent_type_)
(u'is', 3, u'')] # (u'is', 3, u'')
+h(2, "setting") Setting entity annotations +h(2, "setting") Setting entity annotations