Improve label management in parser and NER (#2108)

This patch does a few smallish things that tighten up the training workflow a little, and allow memory use during training to be reduced by letting the GoldCorpus stream data properly.

Previously, the parser and entity recognizer read and saved labels as lists, with extra labels noted separately. Lists were used becaue ordering is very important, to ensure that the label-to-class mapping is stable.

We now manage labels as nested dictionaries, first keyed by the action, and then keyed by the label. Values are frequencies. The trick is, how do we save new labels? We need to make sure we iterate over these in the same order they're added. Otherwise, we'll get different class IDs, and the model's predictions won't make sense.

To allow stable sorting, we map the new labels to negative values. If we have two new labels, they'll be noted as having "frequency" -1 and -2. The next new label will then have "frequency" -3. When we sort by (frequency, label), we then get a stable sort.

Storing frequencies then allows us to make the next nice improvement. Previously we had to iterate over the whole training set, to pre-process it for the deprojectivisation. This led to storing the whole training set in memory. This was most of the required memory during training.

To prevent this, we now store the frequencies as we stream in the data, and deprojectivize as we go. Once we've built the frequencies, we can then apply a frequency cut-off when we decide how many classes to make.

Finally, to allow proper data streaming, we also have to have some way of shuffling the iterator. This is awkward if the training files have multiple documents in them. To solve this, the GoldCorpus class now writes the training data to disk in msgpack files, one per document. We can then shuffle the data by shuffling the paths.

This is a squash merge, as I made a lot of very small commits. Individual commit messages below.

* Simplify label management for TransitionSystem and its subclasses

* Fix serialization for new label handling format in parser

* Simplify and improve GoldCorpus class. Reduce memory use, write to temp dir

* Set actions in transition system

* Require thinc 6.11.1.dev4

* Fix error in parser init

* Add unicode declaration

* Fix unicode declaration

* Update textcat test

* Try to get model training on less memory

* Print json loc for now

* Try rapidjson to reduce memory use

* Remove rapidjson requirement

* Try rapidjson for reduced mem usage

* Handle None heads when projectivising

* Stream json docs

* Fix train script

* Handle projectivity in GoldParse

* Fix projectivity handling

* Add minibatch_by_words util from ud_train

* Minibatch by number of words in spacy.cli.train

* Move minibatch_by_words util to spacy.util

* Fix label handling

* More hacking at label management in parser

* Fix encoding in msgpack serialization in GoldParse

* Adjust batch sizes in parser training

* Fix minibatch_by_words

* Add merge_subtokens function to pipeline.pyx

* Register merge_subtokens factory

* Restore use of msgpack tmp directory

* Use minibatch-by-words in train

* Handle retokenization in scorer

* Change back-off approach for missing labels. Use 'dep' label

* Update NER for new label management

* Set NER tags for over-segmented words

* Fix label alignment in gold

* Fix label back-off for infrequent labels

* Fix int type in labels dict key

* Fix int type in labels dict key

* Update feature definition for 8 feature set

* Update ud-train script for new label stuff

* Fix json streamer

* Print the line number if conll eval fails

* Update children and sentence boundaries after deprojectivisation

* Export set_children_from_heads from doc.pxd

* Render parses during UD training

* Remove print statement

* Require thinc 6.11.1.dev6. Try adding wheel as install_requires

* Set different dev version, to flush pip cache

* Update thinc version

* Update GoldCorpus docs

* Remove print statements

* Fix formatting and links [ci skip]
This commit is contained in:
Matthew Honnibal 2018-03-19 02:58:08 +01:00 committed by GitHub
parent 13c060b90c
commit bede11b67c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 452 additions and 266 deletions

View File

@ -3,7 +3,7 @@ pathlib
numpy>=1.7
cymem>=1.30,<1.32
preshed>=1.0.0,<2.0.0
thinc>=6.11.1.dev3,<6.12.0
thinc>=6.11.1.dev7,<6.12.0
murmurhash>=0.28,<0.29
cytoolz>=0.9.0,<0.10.0
plac<1.0.0,>=0.9.6

View File

@ -190,7 +190,7 @@ def setup_package():
'murmurhash>=0.28,<0.29',
'cymem>=1.30,<1.32',
'preshed>=1.0.0,<2.0.0',
'thinc>=6.11.1.dev3,<6.12.0',
'thinc>=6.11.1.dev7,<6.12.0',
'plac<1.0.0,>=0.9.6',
'pathlib',
'ujson>=1.35',
@ -200,6 +200,7 @@ def setup_package():
'ftfy>=4.4.2,<5.0.0',
'msgpack-python==0.5.4',
'msgpack-numpy==0.4.1'],
setup_requires=['wheel'],
classifiers=[
'Development Status :: 5 - Production/Stable',
'Environment :: Console',

View File

@ -3,7 +3,7 @@
# https://github.com/pypa/warehouse/blob/master/warehouse/__about__.py
__title__ = 'spacy'
__version__ = '2.1.0.dev3'
__version__ = '2.1.0.dev4'
__summary__ = 'Industrial-strength Natural Language Processing (NLP) with Python and Cython'
__uri__ = 'https://spacy.io'
__author__ = 'Explosion AI'

View File

@ -168,7 +168,8 @@ def load_conllu(file):
if word.parent is None:
head = int(word.columns[HEAD])
if head > len(ud.words) - sentence_start:
raise UDError("HEAD '{}' points outside of the sentence".format(word.columns[HEAD]))
raise UDError("Line {}: HEAD '{}' points outside of the sentence".format(
linenum, word.columns[HEAD]))
if head:
parent = ud.words[sentence_start + head - 1]
word.parent = "remapping"

View File

@ -8,8 +8,8 @@ from thinc.neural._classes.model import Model
from timeit import default_timer as timer
from ..attrs import PROB, IS_OOV, CLUSTER, LANG
from ..gold import GoldCorpus, minibatch
from ..util import prints
from ..gold import GoldCorpus
from ..util import prints, minibatch, minibatch_by_words
from .. import util
from .. import about
from .. import displacy
@ -51,8 +51,6 @@ def train(lang, output_dir, train_data, dev_data, n_iter=30, n_sents=0,
train_path = util.ensure_path(train_data)
dev_path = util.ensure_path(dev_data)
meta_path = util.ensure_path(meta_path)
if not output_path.exists():
output_path.mkdir()
if not train_path.exists():
prints(train_path, title="Training data not found", exits=1)
if dev_path and not dev_path.exists():
@ -65,7 +63,14 @@ def train(lang, output_dir, train_data, dev_data, n_iter=30, n_sents=0,
title="Not a valid meta.json format", exits=1)
meta.setdefault('lang', lang)
meta.setdefault('name', 'unnamed')
if not output_path.exists():
output_path.mkdir()
print("Counting training words (limit=%s" % n_sents)
corpus = GoldCorpus(train_path, dev_path, limit=n_sents)
n_train_words = corpus.count_train()
print(n_train_words)
pipeline = ['tagger', 'parser', 'ner']
if no_tagger and 'tagger' in pipeline:
pipeline.remove('tagger')
@ -81,13 +86,9 @@ def train(lang, output_dir, train_data, dev_data, n_iter=30, n_sents=0,
dropout_rates = util.decaying(util.env_opt('dropout_from', 0.2),
util.env_opt('dropout_to', 0.2),
util.env_opt('dropout_decay', 0.0))
batch_sizes = util.compounding(util.env_opt('batch_from', 1),
util.env_opt('batch_to', 16),
batch_sizes = util.compounding(util.env_opt('batch_from', 1000),
util.env_opt('batch_to', 1000),
util.env_opt('batch_compound', 1.001))
max_doc_len = util.env_opt('max_doc_len', 5000)
corpus = GoldCorpus(train_path, dev_path, limit=n_sents)
n_train_words = corpus.count_train()
lang_class = util.get_lang_class(lang)
nlp = lang_class()
meta['pipeline'] = pipeline
@ -105,6 +106,7 @@ def train(lang, output_dir, train_data, dev_data, n_iter=30, n_sents=0,
lex.is_oov = False
for name in pipeline:
nlp.add_pipe(nlp.create_pipe(name), name=name)
nlp.add_pipe(nlp.create_pipe('merge_subtokens'))
if parser_multitasks:
for objective in parser_multitasks.split(','):
nlp.parser.add_multitask_objective(objective)
@ -117,19 +119,19 @@ def train(lang, output_dir, train_data, dev_data, n_iter=30, n_sents=0,
print("Itn.\tP.Loss\tN.Loss\tUAS\tNER P.\tNER R.\tNER F.\tTag %\tToken %")
try:
for i in range(n_iter):
train_docs = corpus.train_docs(nlp, projectivize=True, noise_level=0.0,
train_docs = corpus.train_docs(nlp, noise_level=0.0,
gold_preproc=gold_preproc, max_length=0)
words_seen = 0
with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
losses = {}
for batch in minibatch(train_docs, size=batch_sizes):
batch = [(d, g) for (d, g) in batch if len(d) < max_doc_len]
for batch in minibatch_by_words(train_docs, size=batch_sizes):
if not batch:
continue
docs, golds = zip(*batch)
nlp.update(docs, golds, sgd=optimizer,
drop=next(dropout_rates), losses=losses)
pbar.update(sum(len(doc) for doc in docs))
words_seen += sum(len(doc) for doc in docs)
with nlp.use_params(optimizer.averages):
util.set_env_log(False)
epoch_model_path = output_path / ('model%d' % i)

View File

@ -13,9 +13,10 @@ import spacy
import spacy.util
from ..tokens import Token, Doc
from ..gold import GoldParse
from ..util import compounding
from ..util import compounding, minibatch_by_words
from ..syntax.nonproj import projectivize
from ..matcher import Matcher
from .. import displacy
from collections import defaultdict, Counter
from timeit import default_timer as timer
@ -37,30 +38,6 @@ lang.ja.Japanese.Defaults.use_janome = False
random.seed(0)
numpy.random.seed(0)
def minibatch_by_words(items, size):
random.shuffle(items)
if isinstance(size, int):
size_ = itertools.repeat(size)
else:
size_ = size
items = iter(items)
while True:
batch_size = next(size_)
batch = []
while batch_size >= 0:
try:
doc, gold = next(items)
except StopIteration:
if batch:
yield batch
return
batch_size -= len(doc)
batch.append((doc, gold))
if batch:
yield batch
else:
break
################
# Data reading #
################
@ -199,7 +176,7 @@ def evaluate(nlp, text_loc, gold_loc, sys_loc, limit=None):
with sys_loc.open('r', encoding='utf8') as sys_file:
sys_ud = conll17_ud_eval.load_conllu(sys_file)
scores = conll17_ud_eval.evaluate(gold_ud, sys_ud)
return scores
return docs, scores
def write_conllu(docs, file_):
@ -288,19 +265,11 @@ def initialize_pipeline(nlp, docs, golds, config):
nlp.parser.add_multitask_objective('tag')
if config.multitask_sent:
nlp.parser.add_multitask_objective('sent_start')
nlp.parser.moves.add_action(2, 'subtok')
nlp.add_pipe(nlp.create_pipe('tagger'))
for gold in golds:
for tag in gold.tags:
if tag is not None:
nlp.tagger.add_label(tag)
# Replace labels that didn't make the frequency cutoff
actions = set(nlp.parser.labels)
label_set = set([act.split('-')[1] for act in actions if '-' in act])
for gold in golds:
for i, label in enumerate(gold.labels):
if label is not None and label not in label_set:
gold.labels[i] = label.split('||')[0]
return nlp.begin_training(lambda: golds_to_gold_tuples(docs, golds))
@ -372,7 +341,9 @@ def main(ud_dir, parses_dir, config, corpus, limit=0):
batch_sizes = compounding(config.batch_size //10, config.batch_size, 1.001)
for i in range(config.nr_epoch):
docs = [nlp.make_doc(doc.text) for doc in docs]
batches = minibatch_by_words(list(zip(docs, golds)), size=batch_sizes)
Xs = list(zip(docs, golds))
random.shuffle(Xs)
batches = minibatch_by_words(Xs, size=batch_sizes)
losses = {}
n_train_words = sum(len(doc) for doc in docs)
with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
@ -384,8 +355,16 @@ def main(ud_dir, parses_dir, config, corpus, limit=0):
out_path = parses_dir / corpus / 'epoch-{i}.conllu'.format(i=i)
with nlp.use_params(optimizer.averages):
scores = evaluate(nlp, paths.dev.text, paths.dev.conllu, out_path)
parsed_docs, scores = evaluate(nlp, paths.dev.text, paths.dev.conllu, out_path)
print_progress(i, losses, scores)
_render_parses(i, parsed_docs[:50])
def _render_parses(i, to_render):
to_render[0].user_data['title'] = "Batch %d" % i
with Path('/tmp/parses.html').open('w') as file_:
html = displacy.render(to_render[:5], style='dep', page=True)
file_.write(html)
if __name__ == '__main__':

View File

@ -3,18 +3,25 @@
from __future__ import unicode_literals, print_function
import re
import ujson
import random
import cytoolz
import itertools
import numpy
import tempfile
import shutil
from pathlib import Path
import msgpack
import ujson
from . import _align
from .syntax import nonproj
from .tokens import Doc
from . import util
from .util import minibatch, itershuffle
from .compat import json_dumps
from libc.stdio cimport FILE, fopen, fclose, fread, fwrite, feof, fseek
def tags_to_entities(tags):
entities = []
@ -85,106 +92,38 @@ def align(cand_words, gold_words):
class GoldCorpus(object):
"""An annotated corpus, using the JSON file format. Manages
annotations for tagging, dependency parsing and NER."""
def __init__(self, train_path, dev_path, gold_preproc=True, limit=None):
def __init__(self, train, dev, gold_preproc=False, limit=None):
"""Create a GoldCorpus.
train_path (unicode or Path): File or directory of training data.
dev_path (unicode or Path): File or directory of development data.
RETURNS (GoldCorpus): The newly created object.
"""
self.train_path = util.ensure_path(train_path)
self.dev_path = util.ensure_path(dev_path)
self.limit = limit
self.train_locs = self.walk_corpus(self.train_path)
self.dev_locs = self.walk_corpus(self.dev_path)
if isinstance(train, str) or isinstance(train, Path):
train = self.read_tuples(self.walk_corpus(train))
dev = self.read_tuples(self.walk_corpus(dev))
@property
def train_tuples(self):
i = 0
for loc in self.train_locs:
gold_tuples = read_json_file(loc)
for item in gold_tuples:
yield item
i += len(item[1])
if self.limit and i >= self.limit:
break
# Write temp directory with one doc per file, so we can shuffle
# and stream
self.tmp_dir = Path(tempfile.mkdtemp())
self.write_msgpack(self.tmp_dir / 'train', train)
self.write_msgpack(self.tmp_dir / 'dev', dev)
@property
def dev_tuples(self):
i = 0
for loc in self.dev_locs:
gold_tuples = read_json_file(loc)
for item in gold_tuples:
yield item
i += len(item[1])
if self.limit and i >= self.limit:
break
def count_train(self):
n = 0
i = 0
for raw_text, paragraph_tuples in self.train_tuples:
n += sum([len(s[0][1]) for s in paragraph_tuples])
if self.limit and i >= self.limit:
break
i += len(paragraph_tuples)
return n
def train_docs(self, nlp, gold_preproc=False,
projectivize=False, max_length=None,
noise_level=0.0):
if projectivize:
train_tuples = nonproj.preprocess_training_data(
self.train_tuples, label_freq_cutoff=30)
random.shuffle(self.train_locs)
gold_docs = self.iter_gold_docs(nlp, train_tuples, gold_preproc,
max_length=max_length,
noise_level=noise_level)
yield from itershuffle(gold_docs, bufsize=100)
def dev_docs(self, nlp, gold_preproc=False):
gold_docs = self.iter_gold_docs(nlp, self.dev_tuples, gold_preproc)
yield from gold_docs
@classmethod
def iter_gold_docs(cls, nlp, tuples, gold_preproc, max_length=None,
noise_level=0.0):
for raw_text, paragraph_tuples in tuples:
if gold_preproc:
raw_text = None
else:
paragraph_tuples = merge_sents(paragraph_tuples)
docs = cls._make_docs(nlp, raw_text, paragraph_tuples,
gold_preproc, noise_level=noise_level)
golds = cls._make_golds(docs, paragraph_tuples)
for doc, gold in zip(docs, golds):
if (not max_length) or len(doc) < max_length:
yield doc, gold
@classmethod
def _make_docs(cls, nlp, raw_text, paragraph_tuples, gold_preproc,
noise_level=0.0):
if raw_text is not None:
raw_text = add_noise(raw_text, noise_level)
return [nlp.make_doc(raw_text)]
else:
return [Doc(nlp.vocab,
words=add_noise(sent_tuples[1], noise_level))
for (sent_tuples, brackets) in paragraph_tuples]
@classmethod
def _make_golds(cls, docs, paragraph_tuples):
assert len(docs) == len(paragraph_tuples)
if len(docs) == 1:
return [GoldParse.from_annot_tuples(docs[0],
paragraph_tuples[0][0])]
else:
return [GoldParse.from_annot_tuples(doc, sent_tuples)
for doc, (sent_tuples, brackets)
in zip(docs, paragraph_tuples)]
def __del__(self):
shutil.rmtree(self.tmp_dir)
@staticmethod
def write_msgpack(directory, doc_tuples):
if not directory.exists():
directory.mkdir()
for i, doc_tuple in enumerate(doc_tuples):
with open(directory / '{}.msg'.format(i), 'wb') as file_:
msgpack.dump([doc_tuple], file_, use_bin_type=True, encoding='utf8')
@staticmethod
def walk_corpus(path):
path = util.ensure_path(path)
if not path.is_dir():
return [path]
paths = [path]
@ -202,6 +141,101 @@ class GoldCorpus(object):
locs.append(path)
return locs
@staticmethod
def read_tuples(locs, limit=0):
i = 0
for loc in locs:
loc = util.ensure_path(loc)
if loc.parts[-1].endswith('json'):
gold_tuples = read_json_file(loc)
elif loc.parts[-1].endswith('msg'):
with loc.open('rb') as file_:
gold_tuples = msgpack.load(file_, encoding='utf8')
else:
msg = "Cannot read from file: %s. Supported formats: .json, .msg"
raise ValueError(msg % loc)
for item in gold_tuples:
yield item
i += len(item[1])
if limit and i >= limit:
break
@property
def dev_tuples(self):
locs = (self.tmp_dir / 'dev').iterdir()
yield from self.read_tuples(locs, limit=self.limit)
@property
def train_tuples(self):
locs = (self.tmp_dir / 'train').iterdir()
yield from self.read_tuples(locs, limit=self.limit)
def count_train(self):
n = 0
i = 0
for raw_text, paragraph_tuples in self.train_tuples:
for sent_tuples, brackets in paragraph_tuples:
n += len(sent_tuples[1])
if self.limit and i >= self.limit:
break
i += len(paragraph_tuples)
return n
def train_docs(self, nlp, gold_preproc=False, max_length=None,
noise_level=0.0):
locs = list((self.tmp_dir / 'train').iterdir())
random.shuffle(locs)
train_tuples = self.read_tuples(locs, limit=self.limit)
gold_docs = self.iter_gold_docs(nlp, train_tuples, gold_preproc,
max_length=max_length,
noise_level=noise_level,
make_projective=True)
yield from gold_docs
def dev_docs(self, nlp, gold_preproc=False):
gold_docs = self.iter_gold_docs(nlp, self.dev_tuples,
gold_preproc=gold_preproc)
yield from gold_docs
@classmethod
def iter_gold_docs(cls, nlp, tuples, gold_preproc, max_length=None,
noise_level=0.0, make_projective=False):
for raw_text, paragraph_tuples in tuples:
if gold_preproc:
raw_text = None
else:
paragraph_tuples = merge_sents(paragraph_tuples)
docs = cls._make_docs(nlp, raw_text, paragraph_tuples,
gold_preproc, noise_level=noise_level)
golds = cls._make_golds(docs, paragraph_tuples, make_projective)
for doc, gold in zip(docs, golds):
if (not max_length) or len(doc) < max_length:
yield doc, gold
@classmethod
def _make_docs(cls, nlp, raw_text, paragraph_tuples, gold_preproc,
noise_level=0.0):
if raw_text is not None:
raw_text = add_noise(raw_text, noise_level)
return [nlp.make_doc(raw_text)]
else:
return [Doc(nlp.vocab,
words=add_noise(sent_tuples[1], noise_level))
for (sent_tuples, brackets) in paragraph_tuples]
@classmethod
def _make_golds(cls, docs, paragraph_tuples, make_projective):
assert len(docs) == len(paragraph_tuples)
if len(docs) == 1:
return [GoldParse.from_annot_tuples(docs[0],
paragraph_tuples[0][0],
make_projective=make_projective)]
else:
return [GoldParse.from_annot_tuples(doc, sent_tuples,
make_projective=make_projective)
for doc, (sent_tuples, brackets)
in zip(docs, paragraph_tuples)]
def add_noise(orig, noise_level):
if random.random() >= noise_level:
@ -233,11 +267,7 @@ def read_json_file(loc, docs_filter=None, limit=None):
for filename in loc.iterdir():
yield from read_json_file(loc / filename, limit=limit)
else:
with loc.open('r', encoding='utf8') as file_:
docs = ujson.load(file_)
if limit is not None:
docs = docs[:limit]
for doc in docs:
for doc in _json_iterate(loc):
if docs_filter is not None and not docs_filter(doc):
continue
paragraphs = []
@ -267,6 +297,56 @@ def read_json_file(loc, docs_filter=None, limit=None):
yield [paragraph.get('raw', None), sents]
def _json_iterate(loc):
# We should've made these files jsonl...But since we didn't, parse out
# the docs one-by-one to reduce memory usage.
# It's okay to read in the whole file -- just don't parse it into JSON.
cdef bytes py_raw
loc = util.ensure_path(loc)
with loc.open('rb') as file_:
py_raw = file_.read()
raw = <char*>py_raw
cdef int square_depth = 0
cdef int curly_depth = 0
cdef int inside_string = 0
cdef int escape = 0
cdef int start = -1
cdef char c
cdef char quote = ord('"')
cdef char backslash = ord('\\')
cdef char open_square = ord('[')
cdef char close_square = ord(']')
cdef char open_curly = ord('{')
cdef char close_curly = ord('}')
for i in range(len(py_raw)):
c = raw[i]
if c == backslash:
escape = True
continue
if escape:
escape = False
continue
if c == quote:
inside_string = not inside_string
continue
if inside_string:
continue
if c == open_square:
square_depth += 1
elif c == close_square:
square_depth -= 1
elif c == open_curly:
if square_depth == 1 and curly_depth == 0:
start = i
curly_depth += 1
elif c == close_curly:
curly_depth -= 1
if square_depth == 1 and curly_depth == 0:
py_str = py_raw[start : i+1].decode('utf8')
yield ujson.loads(py_str)
start = -1
def iob_to_biluo(tags):
out = []
curr_label = None
@ -370,6 +450,10 @@ cdef class GoldParse:
self.labels = [None] * len(doc)
self.ner = [None] * len(doc)
# This needs to be done before we align the words
if make_projective and heads is not None and deps is not None:
heads, deps = nonproj.projectivize(heads, deps)
# Do many-to-one alignment for misaligned tokens.
# If we over-segment, we'll have one gold word that covers a sequence
# of predicted words
@ -396,14 +480,39 @@ cdef class GoldParse:
if i in i2j_multi:
self.words[i] = words[i2j_multi[i]]
self.tags[i] = tags[i2j_multi[i]]
is_last = i2j_multi[i] != i2j_multi.get(i+1)
is_first = i2j_multi[i] != i2j_multi.get(i-1)
# Set next word in multi-token span as head, until last
if i2j_multi[i] == i2j_multi.get(i+1):
if not is_last:
self.heads[i] = i+1
self.labels[i] = 'subtok'
else:
self.heads[i] = self.gold_to_cand[heads[i2j_multi[i]]]
self.labels[i] = deps[i2j_multi[i]]
# TODO: Set NER!
# Now set NER...This is annoying because if we've split
# got an entity word split into two, we need to adjust the
# BILOU tags. We can't have BB or LL etc.
# Case 1: O -- easy.
ner_tag = entities[i2j_multi[i]]
if ner_tag == 'O':
self.ner[i] = 'O'
# Case 2: U. This has to become a B I* L sequence.
elif ner_tag.startswith('U-'):
if is_first:
self.ner[i] = ner_tag.replace('U-', 'B-', 1)
elif is_last:
self.ner[i] = ner_tag.replace('U-', 'L-', 1)
else:
self.ner[i] = ner_tag.replace('U-', 'I-', 1)
# Case 3: L. If not last, change to I.
elif ner_tag.startswith('L-'):
if is_last:
self.ner[i] = ner_tag
else:
self.ner[i] = ner_tag.replace('L-', 'I-', 1)
# Case 4: I. Stays correct
elif ner_tag.startswith('I-'):
self.ner[i] = ner_tag
else:
self.words[i] = words[gold_i]
self.tags[i] = tags[gold_i]
@ -418,10 +527,6 @@ cdef class GoldParse:
if cycle is not None:
raise Exception("Cycle found: %s" % cycle)
if make_projective:
proj_heads, _ = nonproj.projectivize(self.heads, self.labels)
self.heads = proj_heads
def __len__(self):
"""Get the number of gold-standard tokens.

View File

@ -17,7 +17,7 @@ from .vocab import Vocab
from .lemmatizer import Lemmatizer
from .pipeline import DependencyParser, Tensorizer, Tagger, EntityRecognizer
from .pipeline import SimilarityHook, TextCategorizer, SentenceSegmenter
from .pipeline import merge_noun_chunks, merge_entities
from .pipeline import merge_noun_chunks, merge_entities, merge_subtokens
from .compat import json_dumps, izip, basestring_
from .gold import GoldParse
from .scorer import Scorer
@ -108,7 +108,8 @@ class Language(object):
'sbd': lambda nlp, **cfg: SentenceSegmenter(nlp.vocab, **cfg),
'sentencizer': lambda nlp, **cfg: SentenceSegmenter(nlp.vocab, **cfg),
'merge_noun_chunks': lambda nlp, **cfg: merge_noun_chunks,
'merge_entities': lambda nlp, **cfg: merge_entities
'merge_entities': lambda nlp, **cfg: merge_entities,
'merge_subtokens': lambda nlp, **cfg: merge_subtokens,
}
def __init__(self, vocab=True, make_doc=True, meta={}, **kwargs):

View File

@ -25,6 +25,7 @@ from .morphology cimport Morphology
from .vocab cimport Vocab
from .syntax import nonproj
from .compat import json_dumps
from .matcher import Matcher
from .attrs import POS
from .parts_of_speech import X
@ -97,6 +98,17 @@ def merge_entities(doc):
return doc
def merge_subtokens(doc, label='subtok'):
merger = Matcher(doc.vocab)
merger.add('SUBTOK', None, [{'DEP': label, 'op': '+'}])
matches = merger(doc)
spans = [doc[start:end+1] for _, start, end in matches]
offsets = [(span.start_char, span.end_char) for span in spans]
for start_char, end_char in offsets:
doc.merge(start_char, end_char)
return doc
class Pipe(object):
"""This class is not instantiated directly. Components inherit from it, and
it defines the interface that components should follow to function as

View File

@ -1,7 +1,7 @@
# coding: utf8
from __future__ import division, print_function, unicode_literals
from .gold import tags_to_entities
from .gold import tags_to_entities, GoldParse
class PRFScore(object):
@ -84,6 +84,8 @@ class Scorer(object):
}
def score(self, tokens, gold, verbose=False, punct_labels=('p', 'punct')):
if len(tokens) != len(gold):
gold = GoldParse.from_annot_tuples(tokens, zip(*gold.orig_annot))
assert len(tokens) == len(gold)
gold_deps = set()
gold_tags = set()

View File

@ -108,7 +108,7 @@ cdef cppclass StateC:
ids[1] = this.B(1)
ids[2] = this.S(0)
ids[3] = this.S(1)
ids[4] = this.H(this.S(0))
ids[4] = this.S(2)
ids[5] = this.L(this.B(0), 1)
ids[6] = this.L(this.S(0), 1)
ids[7] = this.R(this.S(0), 1)

View File

@ -6,16 +6,19 @@ from __future__ import unicode_literals
from cpython.ref cimport Py_INCREF
from cymem.cymem cimport Pool
from collections import OrderedDict
from collections import OrderedDict, defaultdict, Counter
from thinc.extra.search cimport Beam
import json
from .stateclass cimport StateClass
from ._state cimport StateC
from .nonproj import is_nonproj_tree
from . import nonproj
from .transition_system cimport move_cost_func_t, label_cost_func_t
from ..gold cimport GoldParse, GoldParseC
from ..structs cimport TokenC
# Calculate cost as gold/not gold. We don't use scalar value anyway.
cdef int BINARY_COSTS = 1
DEF NON_MONOTONIC = True
DEF USE_BREAK = True
@ -54,6 +57,8 @@ cdef weight_t push_cost(StateClass stcls, const GoldParseC* gold, int target) no
cost += 1
if gold.heads[S_i] == target and (NON_MONOTONIC or not stcls.has_head(S_i)):
cost += 1
if BINARY_COSTS and cost >= 1:
return cost
cost += Break.is_valid(stcls.c, 0) and Break.move_cost(stcls, gold) == 0
return cost
@ -67,6 +72,8 @@ cdef weight_t pop_cost(StateClass stcls, const GoldParseC* gold, int target) nog
cost += gold.heads[target] == B_i
if gold.heads[B_i] == B_i or gold.heads[B_i] < target:
break
if BINARY_COSTS and cost >= 1:
return cost
if Break.is_valid(stcls.c, 0) and Break.move_cost(stcls, gold) == 0:
cost += 1
return cost
@ -315,39 +322,42 @@ cdef class ArcEager(TransitionSystem):
@classmethod
def get_actions(cls, **kwargs):
actions = kwargs.get('actions', OrderedDict((
(SHIFT, ['']),
(REDUCE, ['']),
(RIGHT, []),
(LEFT, ['subtok']),
(BREAK, ['ROOT']))
))
seen_actions = set()
min_freq = kwargs.get('min_freq', None)
actions = defaultdict(lambda: Counter())
actions[SHIFT][''] = 1
actions[REDUCE][''] = 1
for label in kwargs.get('left_labels', []):
if label.upper() != 'ROOT':
if (LEFT, label) not in seen_actions:
actions[LEFT].append(label)
seen_actions.add((LEFT, label))
actions[LEFT][label] = 1
actions[SHIFT][label] = 1
for label in kwargs.get('right_labels', []):
if label.upper() != 'ROOT':
if (RIGHT, label) not in seen_actions:
actions[RIGHT].append(label)
seen_actions.add((RIGHT, label))
actions[RIGHT][label] = 1
actions[REDUCE][label] = 1
for raw_text, sents in kwargs.get('gold_parses', []):
for (ids, words, tags, heads, labels, iob), ctnts in sents:
heads, labels = nonproj.projectivize(heads, labels)
for child, head, label in zip(ids, heads, labels):
if label.upper() == 'ROOT':
if label.upper() == 'ROOT' :
label = 'ROOT'
if label != 'ROOT':
if head < child:
if (RIGHT, label) not in seen_actions:
actions[RIGHT].append(label)
seen_actions.add((RIGHT, label))
elif head > child:
if (LEFT, label) not in seen_actions:
actions[LEFT].append(label)
seen_actions.add((LEFT, label))
if head == child:
actions[BREAK][label] += 1
elif head < child:
actions[RIGHT][label] += 1
actions[REDUCE][''] += 1
elif head > child:
actions[LEFT][label] += 1
actions[SHIFT][''] += 1
if min_freq is not None:
for action, label_freqs in actions.items():
for label, freq in list(label_freqs.items()):
if freq < min_freq:
label_freqs.pop(label)
# Ensure these actions are present
actions[BREAK].setdefault('ROOT', 0)
actions[RIGHT].setdefault('subtok', 0)
actions[LEFT].setdefault('subtok', 0)
# Used for backoff
actions[RIGHT].setdefault('dep', 0)
actions[LEFT].setdefault('dep', 0)
return actions
property action_types:
@ -379,18 +389,34 @@ cdef class ArcEager(TransitionSystem):
def preprocess_gold(self, GoldParse gold):
if not self.has_gold(gold):
return None
for i in range(gold.length):
for i, (head, dep) in enumerate(zip(gold.heads, gold.labels)):
# Missing values
if gold.heads[i] is None or gold.labels[i] is None:
if head is None or dep is None:
gold.c.heads[i] = i
gold.c.has_dep[i] = False
else:
label = gold.labels[i]
if head > i:
action = LEFT
elif head < i:
action = RIGHT
else:
action = BREAK
if dep not in self.labels[action]:
if action == BREAK:
dep = 'ROOT'
elif nonproj.is_decorated(dep):
backoff = nonproj.decompose(dep)[0]
if backoff in self.labels[action]:
dep = backoff
else:
dep = 'dep'
else:
dep = 'dep'
gold.c.has_dep[i] = True
if label.upper() == 'ROOT':
label = 'ROOT'
gold.c.heads[i] = gold.heads[i]
gold.c.labels[i] = self.strings.add(label)
if dep.upper() == 'ROOT':
dep = 'ROOT'
gold.c.heads[i] = head
gold.c.labels[i] = self.strings.add(dep)
return gold
def get_beam_parses(self, Beam beam):
@ -536,7 +562,7 @@ cdef class ArcEager(TransitionSystem):
if label_str is not None and label_str not in label_set:
raise ValueError("Cannot get gold parser action: unknown label: %s" % label_str)
# Check projectivity --- other leading cause
if is_nonproj_tree(gold.heads):
if nonproj.is_nonproj_tree(gold.heads):
raise ValueError(
"Could not find a gold-standard action to supervise the "
"dependency parser. Likely cause: the tree is "

View File

@ -3,7 +3,7 @@ from __future__ import unicode_literals
from thinc.typedefs cimport weight_t
from thinc.extra.search cimport Beam
from collections import OrderedDict
from collections import OrderedDict, Counter
from .stateclass cimport StateClass
from ._state cimport StateC
@ -64,21 +64,18 @@ cdef class BiluoPushDown(TransitionSystem):
@classmethod
def get_actions(cls, **kwargs):
actions = kwargs.get('actions', OrderedDict((
(MISSING, ['']),
(BEGIN, []),
(IN, []),
(LAST, []),
(UNIT, []),
(OUT, [''])
)))
seen_entities = set()
actions = {
MISSING: Counter(),
BEGIN: Counter(),
IN: Counter(),
LAST: Counter(),
UNIT: Counter(),
OUT: Counter()
}
actions[OUT][''] = 1
for entity_type in kwargs.get('entity_types', []):
if entity_type in seen_entities:
continue
seen_entities.add(entity_type)
for action in (BEGIN, IN, LAST, UNIT):
actions[action].append(entity_type)
actions[action][entity_type] = 1
moves = ('M', 'B', 'I', 'L', 'U')
for raw_text, sents in kwargs.get('gold_parses', []):
for (ids, words, tags, heads, labels, biluo), _ in sents:
@ -87,10 +84,8 @@ cdef class BiluoPushDown(TransitionSystem):
if ner_tag.count('-') != 1:
raise ValueError(ner_tag)
_, label = ner_tag.split('-')
if label not in seen_entities:
seen_entities.add(label)
for move_str in ('B', 'I', 'L', 'U'):
actions[moves.index(move_str)].append(label)
for action in (BEGIN, IN, LAST, UNIT):
actions[action][label] += 1
return actions
property action_types:
@ -213,7 +208,7 @@ cdef class BiluoPushDown(TransitionSystem):
raise Exception(move)
return t
def add_action(self, int action, label_name):
def add_action(self, int action, label_name, freq=None):
cdef attr_t label_id
if not isinstance(label_name, (int, long)):
label_id = self.strings.add(label_name)
@ -234,6 +229,12 @@ cdef class BiluoPushDown(TransitionSystem):
self.c[self.n_moves] = self.init_transition(self.n_moves, action, label_id)
assert self.c[self.n_moves].label == label_id
self.n_moves += 1
if self.labels.get(action, []):
freq = min(0, min(self.labels[action].values()))
self.labels[action][label_name] = freq-1
else:
self.labels[action] = Counter()
self.labels[action][label_name] = -1
return 1
cdef int initialize_state(self, StateC* st) nogil:

View File

@ -302,7 +302,7 @@ cdef class Parser:
"""
self.vocab = vocab
if moves is True:
self.moves = self.TransitionSystem(self.vocab.strings, {})
self.moves = self.TransitionSystem(self.vocab.strings)
else:
self.moves = moves
if 'beam_width' not in cfg:
@ -311,12 +311,7 @@ cdef class Parser:
cfg['beam_density'] = util.env_opt('beam_density', 0.0)
if 'pretrained_dims' not in cfg:
cfg['pretrained_dims'] = self.vocab.vectors.data.shape[1]
cfg.setdefault('cnn_maxout_pieces', 3)
self.cfg = cfg
if 'actions' in self.cfg:
for action, labels in self.cfg.get('actions', {}).items():
for label in labels:
self.moves.add_action(action, label)
self.model = model
self._multitasks = []
@ -676,7 +671,6 @@ cdef class Parser:
for beam in beams:
_cleanup(beam)
def _init_gold_batch(self, whole_docs, whole_golds, min_length=5, max_length=500):
"""Make a square batch, of length equal to the shortest doc. A long
doc will get multiple states. Let's say we have a doc of length 2*N,
@ -831,9 +825,6 @@ cdef class Parser:
for action in self.moves.action_types:
added = self.moves.add_action(action, label)
if added:
# Important that the labels be stored as a list! We need the
# order, or the model goes out of synch
self.cfg.setdefault('extra_labels', []).append(label)
resized = True
if self.model not in (True, False, None) and resized:
# Weights are stored in (nr_out, nr_in) format, so we're basically
@ -847,12 +838,10 @@ cdef class Parser:
def begin_training(self, gold_tuples, pipeline=None, sgd=None, **cfg):
if 'model' in cfg:
self.model = cfg['model']
gold_tuples = nonproj.preprocess_training_data(gold_tuples,
label_freq_cutoff=30)
actions = self.moves.get_actions(gold_parses=gold_tuples)
for action, labels in actions.items():
for label in labels:
self.moves.add_action(action, label)
cfg.setdefault('min_action_freq', 30)
actions = self.moves.get_actions(gold_parses=gold_tuples,
min_freq=cfg.get('min_action_freq', 30))
self.moves.initialize_actions(actions)
cfg.setdefault('token_vector_width', 128)
if self.model is True:
cfg['pretrained_dims'] = self.vocab.vectors_length
@ -860,7 +849,7 @@ cdef class Parser:
if sgd is None:
sgd = self.create_optimizer()
self.model[1].begin_training(
self.model[1].ops.allocate((5, cfg['token_vector_width'])))
self.model[1].ops.allocate((5, cfg['token_vector_width'])))
if pipeline is not None:
self.init_multitask_objectives(gold_tuples, pipeline, sgd=sgd, **cfg)
link_vectors_to_models(self.vocab)

View File

@ -9,7 +9,7 @@ from __future__ import unicode_literals
from copy import copy
from ..tokens.doc cimport Doc
from ..tokens.doc cimport Doc, set_children_from_heads
DELIMITER = '||'
@ -74,7 +74,21 @@ def decompose(label):
def is_decorated(label):
return label.find(DELIMITER) != -1
return DELIMITER in label
def count_decorated_labels(gold_tuples):
freqs = {}
for raw_text, sents in gold_tuples:
for (ids, words, tags, heads, labels, iob), ctnts in sents:
proj_heads, deco_labels = projectivize(heads, labels)
# set the label to ROOT for each root dependent
deco_labels = ['ROOT' if head == i else deco_labels[i]
for i, head in enumerate(proj_heads)]
# count label frequencies
for label in deco_labels:
if is_decorated(label):
freqs[label] = freqs.get(label, 0) + 1
return freqs
def preprocess_training_data(gold_tuples, label_freq_cutoff=30):
@ -124,8 +138,9 @@ cpdef deprojectivize(Doc doc):
if DELIMITER in label:
new_label, head_label = label.split(DELIMITER)
new_head = _find_new_head(doc[i], head_label)
doc[i].head = new_head
doc.c[i].head = new_head.i - i
doc.c[i].dep = doc.vocab.strings.add(new_label)
set_children_from_heads(doc.c, doc.length)
return doc

View File

@ -42,6 +42,7 @@ cdef class TransitionSystem:
cdef public attr_t root_label
cdef public freqs
cdef init_state_t init_beam_state
cdef public object labels
cdef int initialize_state(self, StateC* state) nogil
cdef int finalize_state(self, StateC* state) nogil

View File

@ -5,7 +5,7 @@ from __future__ import unicode_literals
from cpython.ref cimport Py_INCREF
from cymem.cymem cimport Pool
from thinc.typedefs cimport weight_t
from collections import OrderedDict
from collections import OrderedDict, Counter
import ujson
from ..structs cimport TokenC
@ -28,7 +28,7 @@ cdef void* _init_state(Pool mem, int length, void* tokens) except NULL:
cdef class TransitionSystem:
def __init__(self, StringStore string_table, labels_by_action):
def __init__(self, StringStore string_table, labels_by_action=None, min_freq=None):
self.mem = Pool()
self.strings = string_table
self.n_moves = 0
@ -36,21 +36,14 @@ cdef class TransitionSystem:
self.c = <Transition*>self.mem.alloc(self._size, sizeof(Transition))
for action, label_strs in labels_by_action.items():
for label_str in label_strs:
self.add_action(int(action), label_str)
self.labels = {}
if labels_by_action:
self.initialize_actions(labels_by_action, min_freq=min_freq)
self.root_label = self.strings.add('ROOT')
self.init_beam_state = _init_state
def __reduce__(self):
labels_by_action = OrderedDict()
cdef Transition t
for trans in self.c[:self.n_moves]:
label_str = self.strings[trans.label]
labels_by_action.setdefault(trans.move, []).append(label_str)
return (self.__class__,
(self.strings, labels_by_action),
None, None)
return (self.__class__, (self.strings, self.labels), None, None)
def init_batch(self, docs):
cdef StateClass state
@ -146,6 +139,22 @@ cdef class TransitionSystem:
act = self.c[clas]
return self.move_name(act.move, act.label)
def initialize_actions(self, labels_by_action, min_freq=None):
self.labels = {}
self.n_moves = 0
for action, label_freqs in sorted(labels_by_action.items()):
action = int(action)
# Make sure we take a copy here, and that we get a Counter
self.labels[action] = Counter()
# Have to be careful here: Sorting must be stable, or our model
# won't be read back in correctly.
sorted_labels = [(f, L) for L, f in label_freqs.items()]
sorted_labels.sort()
sorted_labels.reverse()
for freq, label_str in sorted_labels:
self.add_action(int(action), label_str)
self.labels[action][label_str] = freq
def add_action(self, int action, label_name):
cdef attr_t label_id
if not isinstance(label_name, int) and \
@ -164,6 +173,14 @@ cdef class TransitionSystem:
self.c[self.n_moves] = self.init_transition(self.n_moves, action, label_id)
assert self.c[self.n_moves].label == label_id
self.n_moves += 1
if self.labels.get(action, []):
new_freq = min(self.labels[action].values())
else:
self.labels[action] = Counter()
new_freq = -1
if new_freq > 0:
new_freq = 0
self.labels[action][label_name] = new_freq-1
return 1
def to_disk(self, path, **exclude):
@ -178,26 +195,18 @@ cdef class TransitionSystem:
def to_bytes(self, **exclude):
transitions = []
for trans in self.c[:self.n_moves]:
transitions.append({
'clas': trans.clas,
'move': trans.move,
'label': self.strings[trans.label],
'name': self.move_name(trans.move, trans.label)
})
serializers = {
'transitions': lambda: json_dumps(transitions),
'moves': lambda: json_dumps(self.labels),
'strings': lambda: self.strings.to_bytes()
}
return util.to_bytes(serializers, exclude)
def from_bytes(self, bytes_data, **exclude):
transitions = []
labels = {}
deserializers = {
'transitions': lambda b: transitions.extend(ujson.loads(b)),
'moves': lambda b: labels.update(ujson.loads(b)),
'strings': lambda b: self.strings.from_bytes(b)
}
msg = util.from_bytes(bytes_data, deserializers, exclude)
for trans in transitions:
self.add_action(trans['move'], trans['label'])
self.initialize_actions(labels)
return self

View File

@ -1,5 +1,6 @@
from __future__ import unicode_literals
import random
import numpy.random
from ..pipeline import TextCategorizer
from ..lang.en import English
@ -9,6 +10,8 @@ from ..gold import GoldParse
def test_textcat_learns_multilabel():
random.seed(0)
numpy.random.seed(0)
docs = []
nlp = English()
vocab = nlp.vocab
@ -22,7 +25,7 @@ def test_textcat_learns_multilabel():
for letter in letters:
model.add_label(letter)
optimizer = model.begin_training()
for i in range(20):
for i in range(30):
losses = {}
Ys = [GoldParse(doc, cats=cats) for doc, cats in docs]
Xs = [doc for doc, cats in docs]

View File

@ -19,6 +19,9 @@ ctypedef fused LexemeOrToken:
const_TokenC_ptr
cdef int set_children_from_heads(TokenC* tokens, int length) except -1
cdef int token_by_start(const TokenC* tokens, int length, int start_char) except -2

View File

@ -436,6 +436,29 @@ def decaying(start, stop, decay):
nr_upd += 1
def minibatch_by_words(items, size, count_words=len):
'''Create minibatches of a given number of words.'''
if isinstance(size, int):
size_ = itertools.repeat(size)
else:
size_ = size
items = iter(items)
while True:
batch_size = next(size_)
batch = []
while batch_size >= 0:
try:
doc, gold = next(items)
except StopIteration:
if batch:
yield batch
return
batch_size -= count_words(doc)
batch.append((doc, gold))
if batch:
yield batch
def itershuffle(iterable, bufsize=1000):
"""Shuffle an iterator. This works by holding `bufsize` items back
and yielding them sometime later. Obviously, this is not unbiased

View File

@ -12,11 +12,24 @@ p Create a #[code GoldCorpus].
+table(["Name", "Type", "Description"])
+row
+cell #[code train_path]
+cell unicode or #[code Path]
+cell File or directory of training data.
+cell #[code train]
+cell unicode or #[code Path] or iterable
+cell
| Training data, as a path (file or directory) or iterable. If an
| iterable, each item should be a #[code (text, paragraphs)]
| tuple, where each paragraph is a tuple
| #[code.u-break (sentences, brackets)],and each sentence is a
| tuple #[code.u-break (ids, words, tags, heads, ner)]. See the
| implementation of
| #[+src(gh("spacy", "spacy/gold.pyx")) #[code gold.read_json_file]]
| for further details.
+row
+cell #[code dev_path]
+cell unicode or #[code Path]
+cell File or directory of development data.
+cell #[code dev]
+cell unicode or #[code Path] or iterable
+cell Development data, as a path (file or directory) or iterable.
+row("foot")
+cell returns
+cell #[code GoldCorpus]
+cell The newly constructed object.