Fix non-projective label filtering

This commit is contained in:
Matthew Honnibal 2018-03-27 13:41:33 +02:00
commit f57bfbccdc
4 changed files with 397 additions and 27 deletions

371
spacy/cli/ud_train.py Normal file
View File

@ -0,0 +1,371 @@
'''Train for CONLL 2017 UD treebank evaluation. Takes .conllu files, writes
.conllu format for development data, allowing the official scorer to be used.
'''
from __future__ import unicode_literals
import plac
import tqdm
from pathlib import Path
import re
import sys
import json
import spacy
import spacy.util
from ..tokens import Token, Doc
from ..gold import GoldParse
from ..util import compounding, minibatch_by_words
from ..syntax.nonproj import projectivize
from ..matcher import Matcher
from .. import displacy
from collections import defaultdict, Counter
from timeit import default_timer as timer
import itertools
import random
import numpy.random
import cytoolz
from . import conll17_ud_eval
from .. import lang
from .. import lang
from ..lang import zh
from ..lang import ja
################
# Data reading #
################
space_re = re.compile('\s+')
def split_text(text):
return [space_re.sub(' ', par.strip()) for par in text.split('\n\n')]
def read_data(nlp, conllu_file, text_file, raw_text=True, oracle_segments=False,
max_doc_length=None, limit=None):
'''Read the CONLLU format into (Doc, GoldParse) tuples. If raw_text=True,
include Doc objects created using nlp.make_doc and then aligned against
the gold-standard sequences. If oracle_segments=True, include Doc objects
created from the gold-standard segments. At least one must be True.'''
if not raw_text and not oracle_segments:
raise ValueError("At least one of raw_text or oracle_segments must be True")
paragraphs = split_text(text_file.read())
conllu = read_conllu(conllu_file)
# sd is spacy doc; cd is conllu doc
# cs is conllu sent, ct is conllu token
docs = []
golds = []
for doc_id, (text, cd) in enumerate(zip(paragraphs, conllu)):
sent_annots = []
for cs in cd:
sent = defaultdict(list)
for id_, word, lemma, pos, tag, morph, head, dep, _, space_after in cs:
if '.' in id_:
continue
if '-' in id_:
continue
id_ = int(id_)-1
head = int(head)-1 if head != '0' else id_
sent['words'].append(word)
sent['tags'].append(tag)
sent['heads'].append(head)
sent['deps'].append('ROOT' if dep == 'root' else dep)
sent['spaces'].append(space_after == '_')
sent['entities'] = ['-'] * len(sent['words'])
sent['heads'], sent['deps'] = projectivize(sent['heads'],
sent['deps'])
if oracle_segments:
docs.append(Doc(nlp.vocab, words=sent['words'], spaces=sent['spaces']))
golds.append(GoldParse(docs[-1], **sent))
sent_annots.append(sent)
if raw_text and max_doc_length and len(sent_annots) >= max_doc_length:
doc, gold = _make_gold(nlp, None, sent_annots)
sent_annots = []
docs.append(doc)
golds.append(gold)
if limit and len(docs) >= limit:
return docs, golds
if raw_text and sent_annots:
doc, gold = _make_gold(nlp, None, sent_annots)
docs.append(doc)
golds.append(gold)
if limit and len(docs) >= limit:
return docs, golds
return docs, golds
def read_conllu(file_):
docs = []
sent = []
doc = []
for line in file_:
if line.startswith('# newdoc'):
if doc:
docs.append(doc)
doc = []
elif line.startswith('#'):
continue
elif not line.strip():
if sent:
doc.append(sent)
sent = []
else:
sent.append(list(line.strip().split('\t')))
if len(sent[-1]) != 10:
print(repr(line))
raise ValueError
if sent:
doc.append(sent)
if doc:
docs.append(doc)
return docs
def _make_gold(nlp, text, sent_annots):
# Flatten the conll annotations, and adjust the head indices
flat = defaultdict(list)
for sent in sent_annots:
flat['heads'].extend(len(flat['words'])+head for head in sent['heads'])
for field in ['words', 'tags', 'deps', 'entities', 'spaces']:
flat[field].extend(sent[field])
# Construct text if necessary
assert len(flat['words']) == len(flat['spaces'])
if text is None:
text = ''.join(word+' '*space for word, space in zip(flat['words'], flat['spaces']))
doc = nlp.make_doc(text)
flat.pop('spaces')
gold = GoldParse(doc, **flat)
return doc, gold
#############################
# Data transforms for spaCy #
#############################
def golds_to_gold_tuples(docs, golds):
'''Get out the annoying 'tuples' format used by begin_training, given the
GoldParse objects.'''
tuples = []
for doc, gold in zip(docs, golds):
text = doc.text
ids, words, tags, heads, labels, iob = zip(*gold.orig_annot)
sents = [((ids, words, tags, heads, labels, iob), [])]
tuples.append((text, sents))
return tuples
##############
# Evaluation #
##############
def evaluate(nlp, text_loc, gold_loc, sys_loc, limit=None):
with text_loc.open('r', encoding='utf8') as text_file:
texts = split_text(text_file.read())
docs = list(nlp.pipe(texts))
with sys_loc.open('w', encoding='utf8') as out_file:
write_conllu(docs, out_file)
with gold_loc.open('r', encoding='utf8') as gold_file:
gold_ud = conll17_ud_eval.load_conllu(gold_file)
with sys_loc.open('r', encoding='utf8') as sys_file:
sys_ud = conll17_ud_eval.load_conllu(sys_file)
scores = conll17_ud_eval.evaluate(gold_ud, sys_ud)
return docs, scores
def write_conllu(docs, file_):
merger = Matcher(docs[0].vocab)
merger.add('SUBTOK', None, [{'DEP': 'subtok', 'op': '+'}])
for i, doc in enumerate(docs):
matches = merger(doc)
spans = [doc[start:end+1] for _, start, end in matches]
offsets = [(span.start_char, span.end_char) for span in spans]
for start_char, end_char in offsets:
doc.merge(start_char, end_char)
file_.write("# newdoc id = {i}\n".format(i=i))
for j, sent in enumerate(doc.sents):
file_.write("# sent_id = {i}.{j}\n".format(i=i, j=j))
file_.write("# text = {text}\n".format(text=sent.text))
for k, token in enumerate(sent):
file_.write(token._.get_conllu_lines(k) + '\n')
file_.write('\n')
def print_progress(itn, losses, ud_scores):
fields = {
'dep_loss': losses.get('parser', 0.0),
'tag_loss': losses.get('tagger', 0.0),
'words': ud_scores['Words'].f1 * 100,
'sents': ud_scores['Sentences'].f1 * 100,
'tags': ud_scores['XPOS'].f1 * 100,
'uas': ud_scores['UAS'].f1 * 100,
'las': ud_scores['LAS'].f1 * 100,
}
header = ['Epoch', 'Loss', 'LAS', 'UAS', 'TAG', 'SENT', 'WORD']
if itn == 0:
print('\t'.join(header))
tpl = '\t'.join((
'{:d}',
'{dep_loss:.1f}',
'{las:.1f}',
'{uas:.1f}',
'{tags:.1f}',
'{sents:.1f}',
'{words:.1f}',
))
print(tpl.format(itn, **fields))
#def get_sent_conllu(sent, sent_id):
# lines = ["# sent_id = {sent_id}".format(sent_id=sent_id)]
def get_token_conllu(token, i):
if token._.begins_fused:
n = 1
while token.nbor(n)._.inside_fused:
n += 1
id_ = '%d-%d' % (i, i+n)
lines = [id_, token.text, '_', '_', '_', '_', '_', '_', '_', '_']
else:
lines = []
if token.head.i == token.i:
head = 0
else:
head = i + (token.head.i - token.i) + 1
fields = [str(i+1), token.text, token.lemma_, token.pos_, token.tag_, '_',
str(head), token.dep_.lower(), '_', '_']
lines.append('\t'.join(fields))
return '\n'.join(lines)
Token.set_extension('get_conllu_lines', method=get_token_conllu)
Token.set_extension('begins_fused', default=False)
Token.set_extension('inside_fused', default=False)
##################
# Initialization #
##################
def load_nlp(corpus, config):
lang = corpus.split('_')[0]
nlp = spacy.blank(lang)
if config.vectors:
nlp.vocab.from_disk(Path(config.vectors) / 'vocab')
return nlp
def initialize_pipeline(nlp, docs, golds, config, device):
nlp.add_pipe(nlp.create_pipe('parser'))
if config.multitask_tag:
nlp.parser.add_multitask_objective('tag')
if config.multitask_sent:
nlp.parser.add_multitask_objective('sent_start')
nlp.add_pipe(nlp.create_pipe('tagger'))
for gold in golds:
for tag in gold.tags:
if tag is not None:
nlp.tagger.add_label(tag)
return nlp.begin_training(lambda: golds_to_gold_tuples(docs, golds), device=device)
########################
# Command line helpers #
########################
class Config(object):
def __init__(self, vectors=None, max_doc_length=10, multitask_tag=True,
multitask_sent=True, nr_epoch=30, batch_size=1000, dropout=0.2):
for key, value in locals().items():
setattr(self, key, value)
@classmethod
def load(cls, loc):
with Path(loc).open('r', encoding='utf8') as file_:
cfg = json.load(file_)
return cls(**cfg)
class Dataset(object):
def __init__(self, path, section):
self.path = path
self.section = section
self.conllu = None
self.text = None
for file_path in self.path.iterdir():
name = file_path.parts[-1]
if section in name and name.endswith('conllu'):
self.conllu = file_path
elif section in name and name.endswith('txt'):
self.text = file_path
if self.conllu is None:
msg = "Could not find .txt file in {path} for {section}"
raise IOError(msg.format(section=section, path=path))
if self.text is None:
msg = "Could not find .txt file in {path} for {section}"
self.lang = self.conllu.parts[-1].split('-')[0].split('_')[0]
class TreebankPaths(object):
def __init__(self, ud_path, treebank, **cfg):
self.train = Dataset(ud_path / treebank, 'train')
self.dev = Dataset(ud_path / treebank, 'dev')
self.lang = self.train.lang
@plac.annotations(
ud_dir=("Path to Universal Dependencies corpus", "positional", None, Path),
corpus=("UD corpus to train and evaluate on, e.g. en, es_ancora, etc",
"positional", None, str),
parses_dir=("Directory to write the development parses", "positional", None, Path),
config=("Path to json formatted config file", "positional"),
limit=("Size limit", "option", "n", int),
use_gpu=("Use GPU", "option", "g", int)
)
def main(ud_dir, parses_dir, config, corpus, limit=0, use_gpu=-1):
spacy.util.fix_random_seed()
lang.zh.Chinese.Defaults.use_jieba = False
lang.ja.Japanese.Defaults.use_janome = False
config = Config.load(config)
paths = TreebankPaths(ud_dir, corpus)
if not (parses_dir / corpus).exists():
(parses_dir / corpus).mkdir()
print("Train and evaluate", corpus, "using lang", paths.lang)
nlp = load_nlp(paths.lang, config)
docs, golds = read_data(nlp, paths.train.conllu.open(), paths.train.text.open(),
max_doc_length=config.max_doc_length, limit=limit)
optimizer = initialize_pipeline(nlp, docs, golds, config, use_gpu)
batch_sizes = compounding(config.batch_size//10, config.batch_size, 1.001)
for i in range(config.nr_epoch):
docs = [nlp.make_doc(doc.text) for doc in docs]
Xs = list(zip(docs, golds))
random.shuffle(Xs)
batches = minibatch_by_words(Xs, size=batch_sizes)
losses = {}
n_train_words = sum(len(doc) for doc in docs)
with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
for batch in batches:
batch_docs, batch_gold = zip(*batch)
pbar.update(sum(len(doc) for doc in batch_docs))
nlp.update(batch_docs, batch_gold, sgd=optimizer,
drop=config.dropout, losses=losses)
out_path = parses_dir / corpus / 'epoch-{i}.conllu'.format(i=i)
with nlp.use_params(optimizer.averages):
parsed_docs, scores = evaluate(nlp, paths.dev.text, paths.dev.conllu, out_path)
print_progress(i, losses, scores)
_render_parses(i, parsed_docs[:50])
def _render_parses(i, to_render):
to_render[0].user_data['title'] = "Batch %d" % i
with Path('/tmp/parses.html').open('w') as file_:
html = displacy.render(to_render[:5], style='dep', page=True)
file_.write(html)
if __name__ == '__main__':
plac.call(main)

View File

@ -462,7 +462,7 @@ class Language(object):
self._optimizer = sgd
for name, proc in self.pipeline:
if hasattr(proc, 'begin_training'):
proc.begin_training(get_gold_tuples(),
proc.begin_training(get_gold_tuples,
pipeline=self.pipeline,
sgd=self._optimizer,
**cfg)

View File

@ -172,7 +172,7 @@ class Pipe(object):
return create_default_optimizer(self.model.ops,
**self.cfg.get('optimizer', {}))
def begin_training(self, gold_tuples=tuple(), pipeline=None, sgd=None,
def begin_training(self, get_gold_tuples=lambda: [], pipeline=None, sgd=None,
**kwargs):
"""Initialize the pipe for training, using data exampes if available.
If no model has been initialized yet, the model is added."""
@ -374,7 +374,7 @@ class Tensorizer(Pipe):
loss = (d_scores**2).sum()
return loss, d_scores
def begin_training(self, gold_tuples=tuple(), pipeline=None, sgd=None,
def begin_training(self, gold_tuples=lambda: [], pipeline=None, sgd=None,
**kwargs):
"""Allocate models, pre-process training data and acquire an
optimizer.
@ -498,11 +498,11 @@ class Tagger(Pipe):
d_scores = self.model.ops.unflatten(d_scores, [len(d) for d in docs])
return float(loss), d_scores
def begin_training(self, gold_tuples=tuple(), pipeline=None, sgd=None,
def begin_training(self, get_gold_tuples=lambda: [], pipeline=None, sgd=None,
**kwargs):
orig_tag_map = dict(self.vocab.morphology.tag_map)
new_tag_map = OrderedDict()
for raw_text, annots_brackets in gold_tuples:
for raw_text, annots_brackets in get_gold_tuples():
for annots, brackets in annots_brackets:
ids, words, tags, heads, deps, ents = annots
for tag in tags:
@ -673,9 +673,9 @@ class MultitaskObjective(Tagger):
def set_annotations(self, docs, dep_ids, tensors=None):
pass
def begin_training(self, gold_tuples=tuple(), pipeline=None, tok2vec=None,
def begin_training(self, get_gold_tuples=lambda: [], pipeline=None, tok2vec=None,
sgd=None, **kwargs):
gold_tuples = nonproj.preprocess_training_data(gold_tuples)
gold_tuples = nonproj.preprocess_training_data(get_gold_tuples())
for raw_text, annots_brackets in gold_tuples:
for annots, brackets in annots_brackets:
ids, words, tags, heads, deps, ents = annots
@ -898,7 +898,7 @@ class TextCategorizer(Pipe):
self.labels.append(label)
return 1
def begin_training(self, gold_tuples=tuple(), pipeline=None, sgd=None):
def begin_training(self, get_gold_tuples=lambda: [], pipeline=None, sgd=None):
if pipeline and getattr(pipeline[0], 'name', None) == 'tensorizer':
token_vector_width = pipeline[0].model.nO
else:
@ -925,10 +925,10 @@ cdef class DependencyParser(Parser):
labeller = MultitaskObjective(self.vocab, target=target)
self._multitasks.append(labeller)
def init_multitask_objectives(self, gold_tuples, pipeline, sgd=None, **cfg):
def init_multitask_objectives(self, get_gold_tuples, pipeline, sgd=None, **cfg):
for labeller in self._multitasks:
tok2vec = self.model[0]
labeller.begin_training(gold_tuples, pipeline=pipeline,
labeller.begin_training(get_gold_tuples, pipeline=pipeline,
tok2vec=tok2vec, sgd=sgd)
def __reduce__(self):
@ -946,10 +946,10 @@ cdef class EntityRecognizer(Parser):
labeller = MultitaskObjective(self.vocab, target=target)
self._multitasks.append(labeller)
def init_multitask_objectives(self, gold_tuples, pipeline, sgd=None, **cfg):
def init_multitask_objectives(self, get_gold_tuples, pipeline, sgd=None, **cfg):
for labeller in self._multitasks:
tok2vec = self.model[0]
labeller.begin_training(gold_tuples, pipeline=pipeline,
labeller.begin_training(get_gold_tuples, pipeline=pipeline,
tok2vec=tok2vec)
def __reduce__(self):

View File

@ -164,16 +164,17 @@ cdef void sum_state_features(float* output,
cdef const float* feature
padding = cached
cached += F * O
cdef int id_stride = F*O
cdef float one = 1.
for b in range(B):
for f in range(F):
if token_ids[f] < 0:
feature = &padding[f*O]
else:
idx = token_ids[f] * F * O + f*O
idx = token_ids[f] * id_stride + f*O
feature = &cached[idx]
for i in range(O):
output[i] += feature[i]
output += O
openblas.simple_axpy(&output[b*O], O,
feature, one)
token_ids += F
@ -726,7 +727,7 @@ cdef class Parser:
lower, stream, drop=0.0)
return (tokvecs, bp_tokvecs), state2vec, upper
nr_feature = 13
nr_feature = 8
def get_token_ids(self, states):
cdef StateClass state
@ -821,15 +822,13 @@ cdef class Parser:
copy_array(larger.b[:smaller.nO], smaller.b)
self.model[-1]._layers[-1] = larger
def begin_training(self, gold_tuples, pipeline=None, sgd=None, **cfg):
def begin_training(self, get_gold_tuples, pipeline=None, sgd=None, **cfg):
if 'model' in cfg:
self.model = cfg['model']
gold_tuples = nonproj.preprocess_training_data(gold_tuples,
label_freq_cutoff=100)
actions = self.moves.get_actions(gold_parses=gold_tuples)
for action, labels in actions.items():
for label in labels:
self.moves.add_action(action, label)
cfg.setdefault('min_action_freq', 30)
actions = self.moves.get_actions(gold_parses=get_gold_tuples(),
min_freq=cfg.get('min_action_freq', 30))
self.moves.initialize_actions(actions)
cfg.setdefault('token_vector_width', 128)
if self.model is True:
cfg['pretrained_dims'] = self.vocab.vectors_length
@ -837,9 +836,9 @@ cdef class Parser:
if sgd is None:
sgd = self.create_optimizer()
self.model[1].begin_training(
self.model[1].ops.allocate((5, cfg['token_vector_width'])))
self.model[1].ops.allocate((5, cfg['token_vector_width'])))
if pipeline is not None:
self.init_multitask_objectives(gold_tuples, pipeline, sgd=sgd, **cfg)
self.init_multitask_objectives(get_gold_tuples, pipeline, sgd=sgd, **cfg)
link_vectors_to_models(self.vocab)
else:
if sgd is None:
@ -853,7 +852,7 @@ cdef class Parser:
# Defined in subclasses, to avoid circular import
raise NotImplementedError
def init_multitask_objectives(self, gold_tuples, pipeline, **cfg):
def init_multitask_objectives(self, get_gold_tuples, pipeline, **cfg):
'''Setup models for secondary objectives, to benefit from multi-task
learning. This method is intended to be overridden by subclasses.