mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-10 19:57:17 +03:00
Improve label management in parser and NER (#2108)
This patch does a few smallish things that tighten up the training workflow a little, and allow memory use during training to be reduced by letting the GoldCorpus stream data properly. Previously, the parser and entity recognizer read and saved labels as lists, with extra labels noted separately. Lists were used becaue ordering is very important, to ensure that the label-to-class mapping is stable. We now manage labels as nested dictionaries, first keyed by the action, and then keyed by the label. Values are frequencies. The trick is, how do we save new labels? We need to make sure we iterate over these in the same order they're added. Otherwise, we'll get different class IDs, and the model's predictions won't make sense. To allow stable sorting, we map the new labels to negative values. If we have two new labels, they'll be noted as having "frequency" -1 and -2. The next new label will then have "frequency" -3. When we sort by (frequency, label), we then get a stable sort. Storing frequencies then allows us to make the next nice improvement. Previously we had to iterate over the whole training set, to pre-process it for the deprojectivisation. This led to storing the whole training set in memory. This was most of the required memory during training. To prevent this, we now store the frequencies as we stream in the data, and deprojectivize as we go. Once we've built the frequencies, we can then apply a frequency cut-off when we decide how many classes to make. Finally, to allow proper data streaming, we also have to have some way of shuffling the iterator. This is awkward if the training files have multiple documents in them. To solve this, the GoldCorpus class now writes the training data to disk in msgpack files, one per document. We can then shuffle the data by shuffling the paths. This is a squash merge, as I made a lot of very small commits. Individual commit messages below. * Simplify label management for TransitionSystem and its subclasses * Fix serialization for new label handling format in parser * Simplify and improve GoldCorpus class. Reduce memory use, write to temp dir * Set actions in transition system * Require thinc 6.11.1.dev4 * Fix error in parser init * Add unicode declaration * Fix unicode declaration * Update textcat test * Try to get model training on less memory * Print json loc for now * Try rapidjson to reduce memory use * Remove rapidjson requirement * Try rapidjson for reduced mem usage * Handle None heads when projectivising * Stream json docs * Fix train script * Handle projectivity in GoldParse * Fix projectivity handling * Add minibatch_by_words util from ud_train * Minibatch by number of words in spacy.cli.train * Move minibatch_by_words util to spacy.util * Fix label handling * More hacking at label management in parser * Fix encoding in msgpack serialization in GoldParse * Adjust batch sizes in parser training * Fix minibatch_by_words * Add merge_subtokens function to pipeline.pyx * Register merge_subtokens factory * Restore use of msgpack tmp directory * Use minibatch-by-words in train * Handle retokenization in scorer * Change back-off approach for missing labels. Use 'dep' label * Update NER for new label management * Set NER tags for over-segmented words * Fix label alignment in gold * Fix label back-off for infrequent labels * Fix int type in labels dict key * Fix int type in labels dict key * Update feature definition for 8 feature set * Update ud-train script for new label stuff * Fix json streamer * Print the line number if conll eval fails * Update children and sentence boundaries after deprojectivisation * Export set_children_from_heads from doc.pxd * Render parses during UD training * Remove print statement * Require thinc 6.11.1.dev6. Try adding wheel as install_requires * Set different dev version, to flush pip cache * Update thinc version * Update GoldCorpus docs * Remove print statements * Fix formatting and links [ci skip]
This commit is contained in:
parent
13c060b90c
commit
bede11b67c
|
@ -3,7 +3,7 @@ pathlib
|
|||
numpy>=1.7
|
||||
cymem>=1.30,<1.32
|
||||
preshed>=1.0.0,<2.0.0
|
||||
thinc>=6.11.1.dev3,<6.12.0
|
||||
thinc>=6.11.1.dev7,<6.12.0
|
||||
murmurhash>=0.28,<0.29
|
||||
cytoolz>=0.9.0,<0.10.0
|
||||
plac<1.0.0,>=0.9.6
|
||||
|
|
3
setup.py
3
setup.py
|
@ -190,7 +190,7 @@ def setup_package():
|
|||
'murmurhash>=0.28,<0.29',
|
||||
'cymem>=1.30,<1.32',
|
||||
'preshed>=1.0.0,<2.0.0',
|
||||
'thinc>=6.11.1.dev3,<6.12.0',
|
||||
'thinc>=6.11.1.dev7,<6.12.0',
|
||||
'plac<1.0.0,>=0.9.6',
|
||||
'pathlib',
|
||||
'ujson>=1.35',
|
||||
|
@ -200,6 +200,7 @@ def setup_package():
|
|||
'ftfy>=4.4.2,<5.0.0',
|
||||
'msgpack-python==0.5.4',
|
||||
'msgpack-numpy==0.4.1'],
|
||||
setup_requires=['wheel'],
|
||||
classifiers=[
|
||||
'Development Status :: 5 - Production/Stable',
|
||||
'Environment :: Console',
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
# https://github.com/pypa/warehouse/blob/master/warehouse/__about__.py
|
||||
|
||||
__title__ = 'spacy'
|
||||
__version__ = '2.1.0.dev3'
|
||||
__version__ = '2.1.0.dev4'
|
||||
__summary__ = 'Industrial-strength Natural Language Processing (NLP) with Python and Cython'
|
||||
__uri__ = 'https://spacy.io'
|
||||
__author__ = 'Explosion AI'
|
||||
|
|
|
@ -168,7 +168,8 @@ def load_conllu(file):
|
|||
if word.parent is None:
|
||||
head = int(word.columns[HEAD])
|
||||
if head > len(ud.words) - sentence_start:
|
||||
raise UDError("HEAD '{}' points outside of the sentence".format(word.columns[HEAD]))
|
||||
raise UDError("Line {}: HEAD '{}' points outside of the sentence".format(
|
||||
linenum, word.columns[HEAD]))
|
||||
if head:
|
||||
parent = ud.words[sentence_start + head - 1]
|
||||
word.parent = "remapping"
|
||||
|
|
|
@ -8,8 +8,8 @@ from thinc.neural._classes.model import Model
|
|||
from timeit import default_timer as timer
|
||||
|
||||
from ..attrs import PROB, IS_OOV, CLUSTER, LANG
|
||||
from ..gold import GoldCorpus, minibatch
|
||||
from ..util import prints
|
||||
from ..gold import GoldCorpus
|
||||
from ..util import prints, minibatch, minibatch_by_words
|
||||
from .. import util
|
||||
from .. import about
|
||||
from .. import displacy
|
||||
|
@ -51,8 +51,6 @@ def train(lang, output_dir, train_data, dev_data, n_iter=30, n_sents=0,
|
|||
train_path = util.ensure_path(train_data)
|
||||
dev_path = util.ensure_path(dev_data)
|
||||
meta_path = util.ensure_path(meta_path)
|
||||
if not output_path.exists():
|
||||
output_path.mkdir()
|
||||
if not train_path.exists():
|
||||
prints(train_path, title="Training data not found", exits=1)
|
||||
if dev_path and not dev_path.exists():
|
||||
|
@ -65,7 +63,14 @@ def train(lang, output_dir, train_data, dev_data, n_iter=30, n_sents=0,
|
|||
title="Not a valid meta.json format", exits=1)
|
||||
meta.setdefault('lang', lang)
|
||||
meta.setdefault('name', 'unnamed')
|
||||
|
||||
if not output_path.exists():
|
||||
output_path.mkdir()
|
||||
|
||||
print("Counting training words (limit=%s" % n_sents)
|
||||
corpus = GoldCorpus(train_path, dev_path, limit=n_sents)
|
||||
n_train_words = corpus.count_train()
|
||||
print(n_train_words)
|
||||
pipeline = ['tagger', 'parser', 'ner']
|
||||
if no_tagger and 'tagger' in pipeline:
|
||||
pipeline.remove('tagger')
|
||||
|
@ -81,13 +86,9 @@ def train(lang, output_dir, train_data, dev_data, n_iter=30, n_sents=0,
|
|||
dropout_rates = util.decaying(util.env_opt('dropout_from', 0.2),
|
||||
util.env_opt('dropout_to', 0.2),
|
||||
util.env_opt('dropout_decay', 0.0))
|
||||
batch_sizes = util.compounding(util.env_opt('batch_from', 1),
|
||||
util.env_opt('batch_to', 16),
|
||||
batch_sizes = util.compounding(util.env_opt('batch_from', 1000),
|
||||
util.env_opt('batch_to', 1000),
|
||||
util.env_opt('batch_compound', 1.001))
|
||||
max_doc_len = util.env_opt('max_doc_len', 5000)
|
||||
corpus = GoldCorpus(train_path, dev_path, limit=n_sents)
|
||||
n_train_words = corpus.count_train()
|
||||
|
||||
lang_class = util.get_lang_class(lang)
|
||||
nlp = lang_class()
|
||||
meta['pipeline'] = pipeline
|
||||
|
@ -105,6 +106,7 @@ def train(lang, output_dir, train_data, dev_data, n_iter=30, n_sents=0,
|
|||
lex.is_oov = False
|
||||
for name in pipeline:
|
||||
nlp.add_pipe(nlp.create_pipe(name), name=name)
|
||||
nlp.add_pipe(nlp.create_pipe('merge_subtokens'))
|
||||
if parser_multitasks:
|
||||
for objective in parser_multitasks.split(','):
|
||||
nlp.parser.add_multitask_objective(objective)
|
||||
|
@ -117,19 +119,19 @@ def train(lang, output_dir, train_data, dev_data, n_iter=30, n_sents=0,
|
|||
print("Itn.\tP.Loss\tN.Loss\tUAS\tNER P.\tNER R.\tNER F.\tTag %\tToken %")
|
||||
try:
|
||||
for i in range(n_iter):
|
||||
train_docs = corpus.train_docs(nlp, projectivize=True, noise_level=0.0,
|
||||
train_docs = corpus.train_docs(nlp, noise_level=0.0,
|
||||
gold_preproc=gold_preproc, max_length=0)
|
||||
words_seen = 0
|
||||
with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
|
||||
losses = {}
|
||||
for batch in minibatch(train_docs, size=batch_sizes):
|
||||
batch = [(d, g) for (d, g) in batch if len(d) < max_doc_len]
|
||||
for batch in minibatch_by_words(train_docs, size=batch_sizes):
|
||||
if not batch:
|
||||
continue
|
||||
docs, golds = zip(*batch)
|
||||
nlp.update(docs, golds, sgd=optimizer,
|
||||
drop=next(dropout_rates), losses=losses)
|
||||
pbar.update(sum(len(doc) for doc in docs))
|
||||
|
||||
words_seen += sum(len(doc) for doc in docs)
|
||||
with nlp.use_params(optimizer.averages):
|
||||
util.set_env_log(False)
|
||||
epoch_model_path = output_path / ('model%d' % i)
|
||||
|
|
|
@ -13,9 +13,10 @@ import spacy
|
|||
import spacy.util
|
||||
from ..tokens import Token, Doc
|
||||
from ..gold import GoldParse
|
||||
from ..util import compounding
|
||||
from ..util import compounding, minibatch_by_words
|
||||
from ..syntax.nonproj import projectivize
|
||||
from ..matcher import Matcher
|
||||
from .. import displacy
|
||||
from collections import defaultdict, Counter
|
||||
from timeit import default_timer as timer
|
||||
|
||||
|
@ -37,30 +38,6 @@ lang.ja.Japanese.Defaults.use_janome = False
|
|||
random.seed(0)
|
||||
numpy.random.seed(0)
|
||||
|
||||
def minibatch_by_words(items, size):
|
||||
random.shuffle(items)
|
||||
if isinstance(size, int):
|
||||
size_ = itertools.repeat(size)
|
||||
else:
|
||||
size_ = size
|
||||
items = iter(items)
|
||||
while True:
|
||||
batch_size = next(size_)
|
||||
batch = []
|
||||
while batch_size >= 0:
|
||||
try:
|
||||
doc, gold = next(items)
|
||||
except StopIteration:
|
||||
if batch:
|
||||
yield batch
|
||||
return
|
||||
batch_size -= len(doc)
|
||||
batch.append((doc, gold))
|
||||
if batch:
|
||||
yield batch
|
||||
else:
|
||||
break
|
||||
|
||||
################
|
||||
# Data reading #
|
||||
################
|
||||
|
@ -199,7 +176,7 @@ def evaluate(nlp, text_loc, gold_loc, sys_loc, limit=None):
|
|||
with sys_loc.open('r', encoding='utf8') as sys_file:
|
||||
sys_ud = conll17_ud_eval.load_conllu(sys_file)
|
||||
scores = conll17_ud_eval.evaluate(gold_ud, sys_ud)
|
||||
return scores
|
||||
return docs, scores
|
||||
|
||||
|
||||
def write_conllu(docs, file_):
|
||||
|
@ -288,19 +265,11 @@ def initialize_pipeline(nlp, docs, golds, config):
|
|||
nlp.parser.add_multitask_objective('tag')
|
||||
if config.multitask_sent:
|
||||
nlp.parser.add_multitask_objective('sent_start')
|
||||
nlp.parser.moves.add_action(2, 'subtok')
|
||||
nlp.add_pipe(nlp.create_pipe('tagger'))
|
||||
for gold in golds:
|
||||
for tag in gold.tags:
|
||||
if tag is not None:
|
||||
nlp.tagger.add_label(tag)
|
||||
# Replace labels that didn't make the frequency cutoff
|
||||
actions = set(nlp.parser.labels)
|
||||
label_set = set([act.split('-')[1] for act in actions if '-' in act])
|
||||
for gold in golds:
|
||||
for i, label in enumerate(gold.labels):
|
||||
if label is not None and label not in label_set:
|
||||
gold.labels[i] = label.split('||')[0]
|
||||
return nlp.begin_training(lambda: golds_to_gold_tuples(docs, golds))
|
||||
|
||||
|
||||
|
@ -372,7 +341,9 @@ def main(ud_dir, parses_dir, config, corpus, limit=0):
|
|||
batch_sizes = compounding(config.batch_size //10, config.batch_size, 1.001)
|
||||
for i in range(config.nr_epoch):
|
||||
docs = [nlp.make_doc(doc.text) for doc in docs]
|
||||
batches = minibatch_by_words(list(zip(docs, golds)), size=batch_sizes)
|
||||
Xs = list(zip(docs, golds))
|
||||
random.shuffle(Xs)
|
||||
batches = minibatch_by_words(Xs, size=batch_sizes)
|
||||
losses = {}
|
||||
n_train_words = sum(len(doc) for doc in docs)
|
||||
with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
|
||||
|
@ -384,8 +355,16 @@ def main(ud_dir, parses_dir, config, corpus, limit=0):
|
|||
|
||||
out_path = parses_dir / corpus / 'epoch-{i}.conllu'.format(i=i)
|
||||
with nlp.use_params(optimizer.averages):
|
||||
scores = evaluate(nlp, paths.dev.text, paths.dev.conllu, out_path)
|
||||
parsed_docs, scores = evaluate(nlp, paths.dev.text, paths.dev.conllu, out_path)
|
||||
print_progress(i, losses, scores)
|
||||
_render_parses(i, parsed_docs[:50])
|
||||
|
||||
|
||||
def _render_parses(i, to_render):
|
||||
to_render[0].user_data['title'] = "Batch %d" % i
|
||||
with Path('/tmp/parses.html').open('w') as file_:
|
||||
html = displacy.render(to_render[:5], style='dep', page=True)
|
||||
file_.write(html)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
305
spacy/gold.pyx
305
spacy/gold.pyx
|
@ -3,18 +3,25 @@
|
|||
from __future__ import unicode_literals, print_function
|
||||
|
||||
import re
|
||||
import ujson
|
||||
import random
|
||||
import cytoolz
|
||||
import itertools
|
||||
import numpy
|
||||
import tempfile
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
import msgpack
|
||||
|
||||
import ujson
|
||||
|
||||
from . import _align
|
||||
from .syntax import nonproj
|
||||
from .tokens import Doc
|
||||
from . import util
|
||||
from .util import minibatch, itershuffle
|
||||
from .compat import json_dumps
|
||||
|
||||
from libc.stdio cimport FILE, fopen, fclose, fread, fwrite, feof, fseek
|
||||
|
||||
def tags_to_entities(tags):
|
||||
entities = []
|
||||
|
@ -85,106 +92,38 @@ def align(cand_words, gold_words):
|
|||
class GoldCorpus(object):
|
||||
"""An annotated corpus, using the JSON file format. Manages
|
||||
annotations for tagging, dependency parsing and NER."""
|
||||
def __init__(self, train_path, dev_path, gold_preproc=True, limit=None):
|
||||
def __init__(self, train, dev, gold_preproc=False, limit=None):
|
||||
"""Create a GoldCorpus.
|
||||
|
||||
train_path (unicode or Path): File or directory of training data.
|
||||
dev_path (unicode or Path): File or directory of development data.
|
||||
RETURNS (GoldCorpus): The newly created object.
|
||||
"""
|
||||
self.train_path = util.ensure_path(train_path)
|
||||
self.dev_path = util.ensure_path(dev_path)
|
||||
self.limit = limit
|
||||
self.train_locs = self.walk_corpus(self.train_path)
|
||||
self.dev_locs = self.walk_corpus(self.dev_path)
|
||||
if isinstance(train, str) or isinstance(train, Path):
|
||||
train = self.read_tuples(self.walk_corpus(train))
|
||||
dev = self.read_tuples(self.walk_corpus(dev))
|
||||
|
||||
@property
|
||||
def train_tuples(self):
|
||||
i = 0
|
||||
for loc in self.train_locs:
|
||||
gold_tuples = read_json_file(loc)
|
||||
for item in gold_tuples:
|
||||
yield item
|
||||
i += len(item[1])
|
||||
if self.limit and i >= self.limit:
|
||||
break
|
||||
# Write temp directory with one doc per file, so we can shuffle
|
||||
# and stream
|
||||
self.tmp_dir = Path(tempfile.mkdtemp())
|
||||
self.write_msgpack(self.tmp_dir / 'train', train)
|
||||
self.write_msgpack(self.tmp_dir / 'dev', dev)
|
||||
|
||||
@property
|
||||
def dev_tuples(self):
|
||||
i = 0
|
||||
for loc in self.dev_locs:
|
||||
gold_tuples = read_json_file(loc)
|
||||
for item in gold_tuples:
|
||||
yield item
|
||||
i += len(item[1])
|
||||
if self.limit and i >= self.limit:
|
||||
break
|
||||
|
||||
def count_train(self):
|
||||
n = 0
|
||||
i = 0
|
||||
for raw_text, paragraph_tuples in self.train_tuples:
|
||||
n += sum([len(s[0][1]) for s in paragraph_tuples])
|
||||
if self.limit and i >= self.limit:
|
||||
break
|
||||
i += len(paragraph_tuples)
|
||||
return n
|
||||
|
||||
def train_docs(self, nlp, gold_preproc=False,
|
||||
projectivize=False, max_length=None,
|
||||
noise_level=0.0):
|
||||
if projectivize:
|
||||
train_tuples = nonproj.preprocess_training_data(
|
||||
self.train_tuples, label_freq_cutoff=30)
|
||||
random.shuffle(self.train_locs)
|
||||
gold_docs = self.iter_gold_docs(nlp, train_tuples, gold_preproc,
|
||||
max_length=max_length,
|
||||
noise_level=noise_level)
|
||||
yield from itershuffle(gold_docs, bufsize=100)
|
||||
|
||||
def dev_docs(self, nlp, gold_preproc=False):
|
||||
gold_docs = self.iter_gold_docs(nlp, self.dev_tuples, gold_preproc)
|
||||
yield from gold_docs
|
||||
|
||||
@classmethod
|
||||
def iter_gold_docs(cls, nlp, tuples, gold_preproc, max_length=None,
|
||||
noise_level=0.0):
|
||||
for raw_text, paragraph_tuples in tuples:
|
||||
if gold_preproc:
|
||||
raw_text = None
|
||||
else:
|
||||
paragraph_tuples = merge_sents(paragraph_tuples)
|
||||
docs = cls._make_docs(nlp, raw_text, paragraph_tuples,
|
||||
gold_preproc, noise_level=noise_level)
|
||||
golds = cls._make_golds(docs, paragraph_tuples)
|
||||
for doc, gold in zip(docs, golds):
|
||||
if (not max_length) or len(doc) < max_length:
|
||||
yield doc, gold
|
||||
|
||||
@classmethod
|
||||
def _make_docs(cls, nlp, raw_text, paragraph_tuples, gold_preproc,
|
||||
noise_level=0.0):
|
||||
if raw_text is not None:
|
||||
raw_text = add_noise(raw_text, noise_level)
|
||||
return [nlp.make_doc(raw_text)]
|
||||
else:
|
||||
return [Doc(nlp.vocab,
|
||||
words=add_noise(sent_tuples[1], noise_level))
|
||||
for (sent_tuples, brackets) in paragraph_tuples]
|
||||
|
||||
@classmethod
|
||||
def _make_golds(cls, docs, paragraph_tuples):
|
||||
assert len(docs) == len(paragraph_tuples)
|
||||
if len(docs) == 1:
|
||||
return [GoldParse.from_annot_tuples(docs[0],
|
||||
paragraph_tuples[0][0])]
|
||||
else:
|
||||
return [GoldParse.from_annot_tuples(doc, sent_tuples)
|
||||
for doc, (sent_tuples, brackets)
|
||||
in zip(docs, paragraph_tuples)]
|
||||
def __del__(self):
|
||||
shutil.rmtree(self.tmp_dir)
|
||||
|
||||
@staticmethod
|
||||
def write_msgpack(directory, doc_tuples):
|
||||
if not directory.exists():
|
||||
directory.mkdir()
|
||||
for i, doc_tuple in enumerate(doc_tuples):
|
||||
with open(directory / '{}.msg'.format(i), 'wb') as file_:
|
||||
msgpack.dump([doc_tuple], file_, use_bin_type=True, encoding='utf8')
|
||||
|
||||
@staticmethod
|
||||
def walk_corpus(path):
|
||||
path = util.ensure_path(path)
|
||||
if not path.is_dir():
|
||||
return [path]
|
||||
paths = [path]
|
||||
|
@ -202,6 +141,101 @@ class GoldCorpus(object):
|
|||
locs.append(path)
|
||||
return locs
|
||||
|
||||
@staticmethod
|
||||
def read_tuples(locs, limit=0):
|
||||
i = 0
|
||||
for loc in locs:
|
||||
loc = util.ensure_path(loc)
|
||||
if loc.parts[-1].endswith('json'):
|
||||
gold_tuples = read_json_file(loc)
|
||||
elif loc.parts[-1].endswith('msg'):
|
||||
with loc.open('rb') as file_:
|
||||
gold_tuples = msgpack.load(file_, encoding='utf8')
|
||||
else:
|
||||
msg = "Cannot read from file: %s. Supported formats: .json, .msg"
|
||||
raise ValueError(msg % loc)
|
||||
for item in gold_tuples:
|
||||
yield item
|
||||
i += len(item[1])
|
||||
if limit and i >= limit:
|
||||
break
|
||||
|
||||
@property
|
||||
def dev_tuples(self):
|
||||
locs = (self.tmp_dir / 'dev').iterdir()
|
||||
yield from self.read_tuples(locs, limit=self.limit)
|
||||
|
||||
@property
|
||||
def train_tuples(self):
|
||||
locs = (self.tmp_dir / 'train').iterdir()
|
||||
yield from self.read_tuples(locs, limit=self.limit)
|
||||
|
||||
def count_train(self):
|
||||
n = 0
|
||||
i = 0
|
||||
for raw_text, paragraph_tuples in self.train_tuples:
|
||||
for sent_tuples, brackets in paragraph_tuples:
|
||||
n += len(sent_tuples[1])
|
||||
if self.limit and i >= self.limit:
|
||||
break
|
||||
i += len(paragraph_tuples)
|
||||
return n
|
||||
|
||||
def train_docs(self, nlp, gold_preproc=False, max_length=None,
|
||||
noise_level=0.0):
|
||||
locs = list((self.tmp_dir / 'train').iterdir())
|
||||
random.shuffle(locs)
|
||||
train_tuples = self.read_tuples(locs, limit=self.limit)
|
||||
gold_docs = self.iter_gold_docs(nlp, train_tuples, gold_preproc,
|
||||
max_length=max_length,
|
||||
noise_level=noise_level,
|
||||
make_projective=True)
|
||||
yield from gold_docs
|
||||
|
||||
def dev_docs(self, nlp, gold_preproc=False):
|
||||
gold_docs = self.iter_gold_docs(nlp, self.dev_tuples,
|
||||
gold_preproc=gold_preproc)
|
||||
yield from gold_docs
|
||||
|
||||
@classmethod
|
||||
def iter_gold_docs(cls, nlp, tuples, gold_preproc, max_length=None,
|
||||
noise_level=0.0, make_projective=False):
|
||||
for raw_text, paragraph_tuples in tuples:
|
||||
if gold_preproc:
|
||||
raw_text = None
|
||||
else:
|
||||
paragraph_tuples = merge_sents(paragraph_tuples)
|
||||
docs = cls._make_docs(nlp, raw_text, paragraph_tuples,
|
||||
gold_preproc, noise_level=noise_level)
|
||||
golds = cls._make_golds(docs, paragraph_tuples, make_projective)
|
||||
for doc, gold in zip(docs, golds):
|
||||
if (not max_length) or len(doc) < max_length:
|
||||
yield doc, gold
|
||||
|
||||
@classmethod
|
||||
def _make_docs(cls, nlp, raw_text, paragraph_tuples, gold_preproc,
|
||||
noise_level=0.0):
|
||||
if raw_text is not None:
|
||||
raw_text = add_noise(raw_text, noise_level)
|
||||
return [nlp.make_doc(raw_text)]
|
||||
else:
|
||||
return [Doc(nlp.vocab,
|
||||
words=add_noise(sent_tuples[1], noise_level))
|
||||
for (sent_tuples, brackets) in paragraph_tuples]
|
||||
|
||||
@classmethod
|
||||
def _make_golds(cls, docs, paragraph_tuples, make_projective):
|
||||
assert len(docs) == len(paragraph_tuples)
|
||||
if len(docs) == 1:
|
||||
return [GoldParse.from_annot_tuples(docs[0],
|
||||
paragraph_tuples[0][0],
|
||||
make_projective=make_projective)]
|
||||
else:
|
||||
return [GoldParse.from_annot_tuples(doc, sent_tuples,
|
||||
make_projective=make_projective)
|
||||
for doc, (sent_tuples, brackets)
|
||||
in zip(docs, paragraph_tuples)]
|
||||
|
||||
|
||||
def add_noise(orig, noise_level):
|
||||
if random.random() >= noise_level:
|
||||
|
@ -233,11 +267,7 @@ def read_json_file(loc, docs_filter=None, limit=None):
|
|||
for filename in loc.iterdir():
|
||||
yield from read_json_file(loc / filename, limit=limit)
|
||||
else:
|
||||
with loc.open('r', encoding='utf8') as file_:
|
||||
docs = ujson.load(file_)
|
||||
if limit is not None:
|
||||
docs = docs[:limit]
|
||||
for doc in docs:
|
||||
for doc in _json_iterate(loc):
|
||||
if docs_filter is not None and not docs_filter(doc):
|
||||
continue
|
||||
paragraphs = []
|
||||
|
@ -267,6 +297,56 @@ def read_json_file(loc, docs_filter=None, limit=None):
|
|||
yield [paragraph.get('raw', None), sents]
|
||||
|
||||
|
||||
def _json_iterate(loc):
|
||||
# We should've made these files jsonl...But since we didn't, parse out
|
||||
# the docs one-by-one to reduce memory usage.
|
||||
# It's okay to read in the whole file -- just don't parse it into JSON.
|
||||
cdef bytes py_raw
|
||||
loc = util.ensure_path(loc)
|
||||
with loc.open('rb') as file_:
|
||||
py_raw = file_.read()
|
||||
raw = <char*>py_raw
|
||||
cdef int square_depth = 0
|
||||
cdef int curly_depth = 0
|
||||
cdef int inside_string = 0
|
||||
cdef int escape = 0
|
||||
cdef int start = -1
|
||||
cdef char c
|
||||
cdef char quote = ord('"')
|
||||
cdef char backslash = ord('\\')
|
||||
cdef char open_square = ord('[')
|
||||
cdef char close_square = ord(']')
|
||||
cdef char open_curly = ord('{')
|
||||
cdef char close_curly = ord('}')
|
||||
for i in range(len(py_raw)):
|
||||
c = raw[i]
|
||||
if c == backslash:
|
||||
escape = True
|
||||
continue
|
||||
if escape:
|
||||
escape = False
|
||||
continue
|
||||
if c == quote:
|
||||
inside_string = not inside_string
|
||||
continue
|
||||
if inside_string:
|
||||
continue
|
||||
if c == open_square:
|
||||
square_depth += 1
|
||||
elif c == close_square:
|
||||
square_depth -= 1
|
||||
elif c == open_curly:
|
||||
if square_depth == 1 and curly_depth == 0:
|
||||
start = i
|
||||
curly_depth += 1
|
||||
elif c == close_curly:
|
||||
curly_depth -= 1
|
||||
if square_depth == 1 and curly_depth == 0:
|
||||
py_str = py_raw[start : i+1].decode('utf8')
|
||||
yield ujson.loads(py_str)
|
||||
start = -1
|
||||
|
||||
|
||||
def iob_to_biluo(tags):
|
||||
out = []
|
||||
curr_label = None
|
||||
|
@ -370,6 +450,10 @@ cdef class GoldParse:
|
|||
self.labels = [None] * len(doc)
|
||||
self.ner = [None] * len(doc)
|
||||
|
||||
# This needs to be done before we align the words
|
||||
if make_projective and heads is not None and deps is not None:
|
||||
heads, deps = nonproj.projectivize(heads, deps)
|
||||
|
||||
# Do many-to-one alignment for misaligned tokens.
|
||||
# If we over-segment, we'll have one gold word that covers a sequence
|
||||
# of predicted words
|
||||
|
@ -396,14 +480,39 @@ cdef class GoldParse:
|
|||
if i in i2j_multi:
|
||||
self.words[i] = words[i2j_multi[i]]
|
||||
self.tags[i] = tags[i2j_multi[i]]
|
||||
is_last = i2j_multi[i] != i2j_multi.get(i+1)
|
||||
is_first = i2j_multi[i] != i2j_multi.get(i-1)
|
||||
# Set next word in multi-token span as head, until last
|
||||
if i2j_multi[i] == i2j_multi.get(i+1):
|
||||
if not is_last:
|
||||
self.heads[i] = i+1
|
||||
self.labels[i] = 'subtok'
|
||||
else:
|
||||
self.heads[i] = self.gold_to_cand[heads[i2j_multi[i]]]
|
||||
self.labels[i] = deps[i2j_multi[i]]
|
||||
# TODO: Set NER!
|
||||
# Now set NER...This is annoying because if we've split
|
||||
# got an entity word split into two, we need to adjust the
|
||||
# BILOU tags. We can't have BB or LL etc.
|
||||
# Case 1: O -- easy.
|
||||
ner_tag = entities[i2j_multi[i]]
|
||||
if ner_tag == 'O':
|
||||
self.ner[i] = 'O'
|
||||
# Case 2: U. This has to become a B I* L sequence.
|
||||
elif ner_tag.startswith('U-'):
|
||||
if is_first:
|
||||
self.ner[i] = ner_tag.replace('U-', 'B-', 1)
|
||||
elif is_last:
|
||||
self.ner[i] = ner_tag.replace('U-', 'L-', 1)
|
||||
else:
|
||||
self.ner[i] = ner_tag.replace('U-', 'I-', 1)
|
||||
# Case 3: L. If not last, change to I.
|
||||
elif ner_tag.startswith('L-'):
|
||||
if is_last:
|
||||
self.ner[i] = ner_tag
|
||||
else:
|
||||
self.ner[i] = ner_tag.replace('L-', 'I-', 1)
|
||||
# Case 4: I. Stays correct
|
||||
elif ner_tag.startswith('I-'):
|
||||
self.ner[i] = ner_tag
|
||||
else:
|
||||
self.words[i] = words[gold_i]
|
||||
self.tags[i] = tags[gold_i]
|
||||
|
@ -418,10 +527,6 @@ cdef class GoldParse:
|
|||
if cycle is not None:
|
||||
raise Exception("Cycle found: %s" % cycle)
|
||||
|
||||
if make_projective:
|
||||
proj_heads, _ = nonproj.projectivize(self.heads, self.labels)
|
||||
self.heads = proj_heads
|
||||
|
||||
def __len__(self):
|
||||
"""Get the number of gold-standard tokens.
|
||||
|
||||
|
|
|
@ -17,7 +17,7 @@ from .vocab import Vocab
|
|||
from .lemmatizer import Lemmatizer
|
||||
from .pipeline import DependencyParser, Tensorizer, Tagger, EntityRecognizer
|
||||
from .pipeline import SimilarityHook, TextCategorizer, SentenceSegmenter
|
||||
from .pipeline import merge_noun_chunks, merge_entities
|
||||
from .pipeline import merge_noun_chunks, merge_entities, merge_subtokens
|
||||
from .compat import json_dumps, izip, basestring_
|
||||
from .gold import GoldParse
|
||||
from .scorer import Scorer
|
||||
|
@ -108,7 +108,8 @@ class Language(object):
|
|||
'sbd': lambda nlp, **cfg: SentenceSegmenter(nlp.vocab, **cfg),
|
||||
'sentencizer': lambda nlp, **cfg: SentenceSegmenter(nlp.vocab, **cfg),
|
||||
'merge_noun_chunks': lambda nlp, **cfg: merge_noun_chunks,
|
||||
'merge_entities': lambda nlp, **cfg: merge_entities
|
||||
'merge_entities': lambda nlp, **cfg: merge_entities,
|
||||
'merge_subtokens': lambda nlp, **cfg: merge_subtokens,
|
||||
}
|
||||
|
||||
def __init__(self, vocab=True, make_doc=True, meta={}, **kwargs):
|
||||
|
|
|
@ -25,6 +25,7 @@ from .morphology cimport Morphology
|
|||
from .vocab cimport Vocab
|
||||
from .syntax import nonproj
|
||||
from .compat import json_dumps
|
||||
from .matcher import Matcher
|
||||
|
||||
from .attrs import POS
|
||||
from .parts_of_speech import X
|
||||
|
@ -97,6 +98,17 @@ def merge_entities(doc):
|
|||
return doc
|
||||
|
||||
|
||||
def merge_subtokens(doc, label='subtok'):
|
||||
merger = Matcher(doc.vocab)
|
||||
merger.add('SUBTOK', None, [{'DEP': label, 'op': '+'}])
|
||||
matches = merger(doc)
|
||||
spans = [doc[start:end+1] for _, start, end in matches]
|
||||
offsets = [(span.start_char, span.end_char) for span in spans]
|
||||
for start_char, end_char in offsets:
|
||||
doc.merge(start_char, end_char)
|
||||
return doc
|
||||
|
||||
|
||||
class Pipe(object):
|
||||
"""This class is not instantiated directly. Components inherit from it, and
|
||||
it defines the interface that components should follow to function as
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
# coding: utf8
|
||||
from __future__ import division, print_function, unicode_literals
|
||||
|
||||
from .gold import tags_to_entities
|
||||
from .gold import tags_to_entities, GoldParse
|
||||
|
||||
|
||||
class PRFScore(object):
|
||||
|
@ -84,6 +84,8 @@ class Scorer(object):
|
|||
}
|
||||
|
||||
def score(self, tokens, gold, verbose=False, punct_labels=('p', 'punct')):
|
||||
if len(tokens) != len(gold):
|
||||
gold = GoldParse.from_annot_tuples(tokens, zip(*gold.orig_annot))
|
||||
assert len(tokens) == len(gold)
|
||||
gold_deps = set()
|
||||
gold_tags = set()
|
||||
|
|
|
@ -108,7 +108,7 @@ cdef cppclass StateC:
|
|||
ids[1] = this.B(1)
|
||||
ids[2] = this.S(0)
|
||||
ids[3] = this.S(1)
|
||||
ids[4] = this.H(this.S(0))
|
||||
ids[4] = this.S(2)
|
||||
ids[5] = this.L(this.B(0), 1)
|
||||
ids[6] = this.L(this.S(0), 1)
|
||||
ids[7] = this.R(this.S(0), 1)
|
||||
|
|
|
@ -6,16 +6,19 @@ from __future__ import unicode_literals
|
|||
|
||||
from cpython.ref cimport Py_INCREF
|
||||
from cymem.cymem cimport Pool
|
||||
from collections import OrderedDict
|
||||
from collections import OrderedDict, defaultdict, Counter
|
||||
from thinc.extra.search cimport Beam
|
||||
import json
|
||||
|
||||
from .stateclass cimport StateClass
|
||||
from ._state cimport StateC
|
||||
from .nonproj import is_nonproj_tree
|
||||
from . import nonproj
|
||||
from .transition_system cimport move_cost_func_t, label_cost_func_t
|
||||
from ..gold cimport GoldParse, GoldParseC
|
||||
from ..structs cimport TokenC
|
||||
|
||||
# Calculate cost as gold/not gold. We don't use scalar value anyway.
|
||||
cdef int BINARY_COSTS = 1
|
||||
|
||||
DEF NON_MONOTONIC = True
|
||||
DEF USE_BREAK = True
|
||||
|
@ -54,6 +57,8 @@ cdef weight_t push_cost(StateClass stcls, const GoldParseC* gold, int target) no
|
|||
cost += 1
|
||||
if gold.heads[S_i] == target and (NON_MONOTONIC or not stcls.has_head(S_i)):
|
||||
cost += 1
|
||||
if BINARY_COSTS and cost >= 1:
|
||||
return cost
|
||||
cost += Break.is_valid(stcls.c, 0) and Break.move_cost(stcls, gold) == 0
|
||||
return cost
|
||||
|
||||
|
@ -67,6 +72,8 @@ cdef weight_t pop_cost(StateClass stcls, const GoldParseC* gold, int target) nog
|
|||
cost += gold.heads[target] == B_i
|
||||
if gold.heads[B_i] == B_i or gold.heads[B_i] < target:
|
||||
break
|
||||
if BINARY_COSTS and cost >= 1:
|
||||
return cost
|
||||
if Break.is_valid(stcls.c, 0) and Break.move_cost(stcls, gold) == 0:
|
||||
cost += 1
|
||||
return cost
|
||||
|
@ -315,39 +322,42 @@ cdef class ArcEager(TransitionSystem):
|
|||
|
||||
@classmethod
|
||||
def get_actions(cls, **kwargs):
|
||||
actions = kwargs.get('actions', OrderedDict((
|
||||
(SHIFT, ['']),
|
||||
(REDUCE, ['']),
|
||||
(RIGHT, []),
|
||||
(LEFT, ['subtok']),
|
||||
(BREAK, ['ROOT']))
|
||||
))
|
||||
seen_actions = set()
|
||||
min_freq = kwargs.get('min_freq', None)
|
||||
actions = defaultdict(lambda: Counter())
|
||||
actions[SHIFT][''] = 1
|
||||
actions[REDUCE][''] = 1
|
||||
for label in kwargs.get('left_labels', []):
|
||||
if label.upper() != 'ROOT':
|
||||
if (LEFT, label) not in seen_actions:
|
||||
actions[LEFT].append(label)
|
||||
seen_actions.add((LEFT, label))
|
||||
actions[LEFT][label] = 1
|
||||
actions[SHIFT][label] = 1
|
||||
for label in kwargs.get('right_labels', []):
|
||||
if label.upper() != 'ROOT':
|
||||
if (RIGHT, label) not in seen_actions:
|
||||
actions[RIGHT].append(label)
|
||||
seen_actions.add((RIGHT, label))
|
||||
|
||||
actions[RIGHT][label] = 1
|
||||
actions[REDUCE][label] = 1
|
||||
for raw_text, sents in kwargs.get('gold_parses', []):
|
||||
for (ids, words, tags, heads, labels, iob), ctnts in sents:
|
||||
heads, labels = nonproj.projectivize(heads, labels)
|
||||
for child, head, label in zip(ids, heads, labels):
|
||||
if label.upper() == 'ROOT':
|
||||
if label.upper() == 'ROOT' :
|
||||
label = 'ROOT'
|
||||
if label != 'ROOT':
|
||||
if head < child:
|
||||
if (RIGHT, label) not in seen_actions:
|
||||
actions[RIGHT].append(label)
|
||||
seen_actions.add((RIGHT, label))
|
||||
elif head > child:
|
||||
if (LEFT, label) not in seen_actions:
|
||||
actions[LEFT].append(label)
|
||||
seen_actions.add((LEFT, label))
|
||||
if head == child:
|
||||
actions[BREAK][label] += 1
|
||||
elif head < child:
|
||||
actions[RIGHT][label] += 1
|
||||
actions[REDUCE][''] += 1
|
||||
elif head > child:
|
||||
actions[LEFT][label] += 1
|
||||
actions[SHIFT][''] += 1
|
||||
if min_freq is not None:
|
||||
for action, label_freqs in actions.items():
|
||||
for label, freq in list(label_freqs.items()):
|
||||
if freq < min_freq:
|
||||
label_freqs.pop(label)
|
||||
# Ensure these actions are present
|
||||
actions[BREAK].setdefault('ROOT', 0)
|
||||
actions[RIGHT].setdefault('subtok', 0)
|
||||
actions[LEFT].setdefault('subtok', 0)
|
||||
# Used for backoff
|
||||
actions[RIGHT].setdefault('dep', 0)
|
||||
actions[LEFT].setdefault('dep', 0)
|
||||
return actions
|
||||
|
||||
property action_types:
|
||||
|
@ -379,18 +389,34 @@ cdef class ArcEager(TransitionSystem):
|
|||
def preprocess_gold(self, GoldParse gold):
|
||||
if not self.has_gold(gold):
|
||||
return None
|
||||
for i in range(gold.length):
|
||||
for i, (head, dep) in enumerate(zip(gold.heads, gold.labels)):
|
||||
# Missing values
|
||||
if gold.heads[i] is None or gold.labels[i] is None:
|
||||
if head is None or dep is None:
|
||||
gold.c.heads[i] = i
|
||||
gold.c.has_dep[i] = False
|
||||
else:
|
||||
label = gold.labels[i]
|
||||
if head > i:
|
||||
action = LEFT
|
||||
elif head < i:
|
||||
action = RIGHT
|
||||
else:
|
||||
action = BREAK
|
||||
if dep not in self.labels[action]:
|
||||
if action == BREAK:
|
||||
dep = 'ROOT'
|
||||
elif nonproj.is_decorated(dep):
|
||||
backoff = nonproj.decompose(dep)[0]
|
||||
if backoff in self.labels[action]:
|
||||
dep = backoff
|
||||
else:
|
||||
dep = 'dep'
|
||||
else:
|
||||
dep = 'dep'
|
||||
gold.c.has_dep[i] = True
|
||||
if label.upper() == 'ROOT':
|
||||
label = 'ROOT'
|
||||
gold.c.heads[i] = gold.heads[i]
|
||||
gold.c.labels[i] = self.strings.add(label)
|
||||
if dep.upper() == 'ROOT':
|
||||
dep = 'ROOT'
|
||||
gold.c.heads[i] = head
|
||||
gold.c.labels[i] = self.strings.add(dep)
|
||||
return gold
|
||||
|
||||
def get_beam_parses(self, Beam beam):
|
||||
|
@ -536,7 +562,7 @@ cdef class ArcEager(TransitionSystem):
|
|||
if label_str is not None and label_str not in label_set:
|
||||
raise ValueError("Cannot get gold parser action: unknown label: %s" % label_str)
|
||||
# Check projectivity --- other leading cause
|
||||
if is_nonproj_tree(gold.heads):
|
||||
if nonproj.is_nonproj_tree(gold.heads):
|
||||
raise ValueError(
|
||||
"Could not find a gold-standard action to supervise the "
|
||||
"dependency parser. Likely cause: the tree is "
|
||||
|
|
|
@ -3,7 +3,7 @@ from __future__ import unicode_literals
|
|||
|
||||
from thinc.typedefs cimport weight_t
|
||||
from thinc.extra.search cimport Beam
|
||||
from collections import OrderedDict
|
||||
from collections import OrderedDict, Counter
|
||||
|
||||
from .stateclass cimport StateClass
|
||||
from ._state cimport StateC
|
||||
|
@ -64,21 +64,18 @@ cdef class BiluoPushDown(TransitionSystem):
|
|||
|
||||
@classmethod
|
||||
def get_actions(cls, **kwargs):
|
||||
actions = kwargs.get('actions', OrderedDict((
|
||||
(MISSING, ['']),
|
||||
(BEGIN, []),
|
||||
(IN, []),
|
||||
(LAST, []),
|
||||
(UNIT, []),
|
||||
(OUT, [''])
|
||||
)))
|
||||
seen_entities = set()
|
||||
actions = {
|
||||
MISSING: Counter(),
|
||||
BEGIN: Counter(),
|
||||
IN: Counter(),
|
||||
LAST: Counter(),
|
||||
UNIT: Counter(),
|
||||
OUT: Counter()
|
||||
}
|
||||
actions[OUT][''] = 1
|
||||
for entity_type in kwargs.get('entity_types', []):
|
||||
if entity_type in seen_entities:
|
||||
continue
|
||||
seen_entities.add(entity_type)
|
||||
for action in (BEGIN, IN, LAST, UNIT):
|
||||
actions[action].append(entity_type)
|
||||
actions[action][entity_type] = 1
|
||||
moves = ('M', 'B', 'I', 'L', 'U')
|
||||
for raw_text, sents in kwargs.get('gold_parses', []):
|
||||
for (ids, words, tags, heads, labels, biluo), _ in sents:
|
||||
|
@ -87,10 +84,8 @@ cdef class BiluoPushDown(TransitionSystem):
|
|||
if ner_tag.count('-') != 1:
|
||||
raise ValueError(ner_tag)
|
||||
_, label = ner_tag.split('-')
|
||||
if label not in seen_entities:
|
||||
seen_entities.add(label)
|
||||
for move_str in ('B', 'I', 'L', 'U'):
|
||||
actions[moves.index(move_str)].append(label)
|
||||
for action in (BEGIN, IN, LAST, UNIT):
|
||||
actions[action][label] += 1
|
||||
return actions
|
||||
|
||||
property action_types:
|
||||
|
@ -213,7 +208,7 @@ cdef class BiluoPushDown(TransitionSystem):
|
|||
raise Exception(move)
|
||||
return t
|
||||
|
||||
def add_action(self, int action, label_name):
|
||||
def add_action(self, int action, label_name, freq=None):
|
||||
cdef attr_t label_id
|
||||
if not isinstance(label_name, (int, long)):
|
||||
label_id = self.strings.add(label_name)
|
||||
|
@ -234,6 +229,12 @@ cdef class BiluoPushDown(TransitionSystem):
|
|||
self.c[self.n_moves] = self.init_transition(self.n_moves, action, label_id)
|
||||
assert self.c[self.n_moves].label == label_id
|
||||
self.n_moves += 1
|
||||
if self.labels.get(action, []):
|
||||
freq = min(0, min(self.labels[action].values()))
|
||||
self.labels[action][label_name] = freq-1
|
||||
else:
|
||||
self.labels[action] = Counter()
|
||||
self.labels[action][label_name] = -1
|
||||
return 1
|
||||
|
||||
cdef int initialize_state(self, StateC* st) nogil:
|
||||
|
|
|
@ -302,7 +302,7 @@ cdef class Parser:
|
|||
"""
|
||||
self.vocab = vocab
|
||||
if moves is True:
|
||||
self.moves = self.TransitionSystem(self.vocab.strings, {})
|
||||
self.moves = self.TransitionSystem(self.vocab.strings)
|
||||
else:
|
||||
self.moves = moves
|
||||
if 'beam_width' not in cfg:
|
||||
|
@ -311,12 +311,7 @@ cdef class Parser:
|
|||
cfg['beam_density'] = util.env_opt('beam_density', 0.0)
|
||||
if 'pretrained_dims' not in cfg:
|
||||
cfg['pretrained_dims'] = self.vocab.vectors.data.shape[1]
|
||||
cfg.setdefault('cnn_maxout_pieces', 3)
|
||||
self.cfg = cfg
|
||||
if 'actions' in self.cfg:
|
||||
for action, labels in self.cfg.get('actions', {}).items():
|
||||
for label in labels:
|
||||
self.moves.add_action(action, label)
|
||||
self.model = model
|
||||
self._multitasks = []
|
||||
|
||||
|
@ -676,7 +671,6 @@ cdef class Parser:
|
|||
for beam in beams:
|
||||
_cleanup(beam)
|
||||
|
||||
|
||||
def _init_gold_batch(self, whole_docs, whole_golds, min_length=5, max_length=500):
|
||||
"""Make a square batch, of length equal to the shortest doc. A long
|
||||
doc will get multiple states. Let's say we have a doc of length 2*N,
|
||||
|
@ -831,9 +825,6 @@ cdef class Parser:
|
|||
for action in self.moves.action_types:
|
||||
added = self.moves.add_action(action, label)
|
||||
if added:
|
||||
# Important that the labels be stored as a list! We need the
|
||||
# order, or the model goes out of synch
|
||||
self.cfg.setdefault('extra_labels', []).append(label)
|
||||
resized = True
|
||||
if self.model not in (True, False, None) and resized:
|
||||
# Weights are stored in (nr_out, nr_in) format, so we're basically
|
||||
|
@ -847,12 +838,10 @@ cdef class Parser:
|
|||
def begin_training(self, gold_tuples, pipeline=None, sgd=None, **cfg):
|
||||
if 'model' in cfg:
|
||||
self.model = cfg['model']
|
||||
gold_tuples = nonproj.preprocess_training_data(gold_tuples,
|
||||
label_freq_cutoff=30)
|
||||
actions = self.moves.get_actions(gold_parses=gold_tuples)
|
||||
for action, labels in actions.items():
|
||||
for label in labels:
|
||||
self.moves.add_action(action, label)
|
||||
cfg.setdefault('min_action_freq', 30)
|
||||
actions = self.moves.get_actions(gold_parses=gold_tuples,
|
||||
min_freq=cfg.get('min_action_freq', 30))
|
||||
self.moves.initialize_actions(actions)
|
||||
cfg.setdefault('token_vector_width', 128)
|
||||
if self.model is True:
|
||||
cfg['pretrained_dims'] = self.vocab.vectors_length
|
||||
|
@ -860,7 +849,7 @@ cdef class Parser:
|
|||
if sgd is None:
|
||||
sgd = self.create_optimizer()
|
||||
self.model[1].begin_training(
|
||||
self.model[1].ops.allocate((5, cfg['token_vector_width'])))
|
||||
self.model[1].ops.allocate((5, cfg['token_vector_width'])))
|
||||
if pipeline is not None:
|
||||
self.init_multitask_objectives(gold_tuples, pipeline, sgd=sgd, **cfg)
|
||||
link_vectors_to_models(self.vocab)
|
||||
|
|
|
@ -9,7 +9,7 @@ from __future__ import unicode_literals
|
|||
|
||||
from copy import copy
|
||||
|
||||
from ..tokens.doc cimport Doc
|
||||
from ..tokens.doc cimport Doc, set_children_from_heads
|
||||
|
||||
|
||||
DELIMITER = '||'
|
||||
|
@ -74,7 +74,21 @@ def decompose(label):
|
|||
|
||||
|
||||
def is_decorated(label):
|
||||
return label.find(DELIMITER) != -1
|
||||
return DELIMITER in label
|
||||
|
||||
def count_decorated_labels(gold_tuples):
|
||||
freqs = {}
|
||||
for raw_text, sents in gold_tuples:
|
||||
for (ids, words, tags, heads, labels, iob), ctnts in sents:
|
||||
proj_heads, deco_labels = projectivize(heads, labels)
|
||||
# set the label to ROOT for each root dependent
|
||||
deco_labels = ['ROOT' if head == i else deco_labels[i]
|
||||
for i, head in enumerate(proj_heads)]
|
||||
# count label frequencies
|
||||
for label in deco_labels:
|
||||
if is_decorated(label):
|
||||
freqs[label] = freqs.get(label, 0) + 1
|
||||
return freqs
|
||||
|
||||
|
||||
def preprocess_training_data(gold_tuples, label_freq_cutoff=30):
|
||||
|
@ -124,8 +138,9 @@ cpdef deprojectivize(Doc doc):
|
|||
if DELIMITER in label:
|
||||
new_label, head_label = label.split(DELIMITER)
|
||||
new_head = _find_new_head(doc[i], head_label)
|
||||
doc[i].head = new_head
|
||||
doc.c[i].head = new_head.i - i
|
||||
doc.c[i].dep = doc.vocab.strings.add(new_label)
|
||||
set_children_from_heads(doc.c, doc.length)
|
||||
return doc
|
||||
|
||||
|
||||
|
|
|
@ -42,6 +42,7 @@ cdef class TransitionSystem:
|
|||
cdef public attr_t root_label
|
||||
cdef public freqs
|
||||
cdef init_state_t init_beam_state
|
||||
cdef public object labels
|
||||
|
||||
cdef int initialize_state(self, StateC* state) nogil
|
||||
cdef int finalize_state(self, StateC* state) nogil
|
||||
|
|
|
@ -5,7 +5,7 @@ from __future__ import unicode_literals
|
|||
from cpython.ref cimport Py_INCREF
|
||||
from cymem.cymem cimport Pool
|
||||
from thinc.typedefs cimport weight_t
|
||||
from collections import OrderedDict
|
||||
from collections import OrderedDict, Counter
|
||||
import ujson
|
||||
|
||||
from ..structs cimport TokenC
|
||||
|
@ -28,7 +28,7 @@ cdef void* _init_state(Pool mem, int length, void* tokens) except NULL:
|
|||
|
||||
|
||||
cdef class TransitionSystem:
|
||||
def __init__(self, StringStore string_table, labels_by_action):
|
||||
def __init__(self, StringStore string_table, labels_by_action=None, min_freq=None):
|
||||
self.mem = Pool()
|
||||
self.strings = string_table
|
||||
self.n_moves = 0
|
||||
|
@ -36,21 +36,14 @@ cdef class TransitionSystem:
|
|||
|
||||
self.c = <Transition*>self.mem.alloc(self._size, sizeof(Transition))
|
||||
|
||||
for action, label_strs in labels_by_action.items():
|
||||
for label_str in label_strs:
|
||||
self.add_action(int(action), label_str)
|
||||
self.labels = {}
|
||||
if labels_by_action:
|
||||
self.initialize_actions(labels_by_action, min_freq=min_freq)
|
||||
self.root_label = self.strings.add('ROOT')
|
||||
self.init_beam_state = _init_state
|
||||
|
||||
def __reduce__(self):
|
||||
labels_by_action = OrderedDict()
|
||||
cdef Transition t
|
||||
for trans in self.c[:self.n_moves]:
|
||||
label_str = self.strings[trans.label]
|
||||
labels_by_action.setdefault(trans.move, []).append(label_str)
|
||||
return (self.__class__,
|
||||
(self.strings, labels_by_action),
|
||||
None, None)
|
||||
return (self.__class__, (self.strings, self.labels), None, None)
|
||||
|
||||
def init_batch(self, docs):
|
||||
cdef StateClass state
|
||||
|
@ -146,6 +139,22 @@ cdef class TransitionSystem:
|
|||
act = self.c[clas]
|
||||
return self.move_name(act.move, act.label)
|
||||
|
||||
def initialize_actions(self, labels_by_action, min_freq=None):
|
||||
self.labels = {}
|
||||
self.n_moves = 0
|
||||
for action, label_freqs in sorted(labels_by_action.items()):
|
||||
action = int(action)
|
||||
# Make sure we take a copy here, and that we get a Counter
|
||||
self.labels[action] = Counter()
|
||||
# Have to be careful here: Sorting must be stable, or our model
|
||||
# won't be read back in correctly.
|
||||
sorted_labels = [(f, L) for L, f in label_freqs.items()]
|
||||
sorted_labels.sort()
|
||||
sorted_labels.reverse()
|
||||
for freq, label_str in sorted_labels:
|
||||
self.add_action(int(action), label_str)
|
||||
self.labels[action][label_str] = freq
|
||||
|
||||
def add_action(self, int action, label_name):
|
||||
cdef attr_t label_id
|
||||
if not isinstance(label_name, int) and \
|
||||
|
@ -164,6 +173,14 @@ cdef class TransitionSystem:
|
|||
self.c[self.n_moves] = self.init_transition(self.n_moves, action, label_id)
|
||||
assert self.c[self.n_moves].label == label_id
|
||||
self.n_moves += 1
|
||||
if self.labels.get(action, []):
|
||||
new_freq = min(self.labels[action].values())
|
||||
else:
|
||||
self.labels[action] = Counter()
|
||||
new_freq = -1
|
||||
if new_freq > 0:
|
||||
new_freq = 0
|
||||
self.labels[action][label_name] = new_freq-1
|
||||
return 1
|
||||
|
||||
def to_disk(self, path, **exclude):
|
||||
|
@ -178,26 +195,18 @@ cdef class TransitionSystem:
|
|||
|
||||
def to_bytes(self, **exclude):
|
||||
transitions = []
|
||||
for trans in self.c[:self.n_moves]:
|
||||
transitions.append({
|
||||
'clas': trans.clas,
|
||||
'move': trans.move,
|
||||
'label': self.strings[trans.label],
|
||||
'name': self.move_name(trans.move, trans.label)
|
||||
})
|
||||
serializers = {
|
||||
'transitions': lambda: json_dumps(transitions),
|
||||
'moves': lambda: json_dumps(self.labels),
|
||||
'strings': lambda: self.strings.to_bytes()
|
||||
}
|
||||
return util.to_bytes(serializers, exclude)
|
||||
|
||||
def from_bytes(self, bytes_data, **exclude):
|
||||
transitions = []
|
||||
labels = {}
|
||||
deserializers = {
|
||||
'transitions': lambda b: transitions.extend(ujson.loads(b)),
|
||||
'moves': lambda b: labels.update(ujson.loads(b)),
|
||||
'strings': lambda b: self.strings.from_bytes(b)
|
||||
}
|
||||
msg = util.from_bytes(bytes_data, deserializers, exclude)
|
||||
for trans in transitions:
|
||||
self.add_action(trans['move'], trans['label'])
|
||||
self.initialize_actions(labels)
|
||||
return self
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
from __future__ import unicode_literals
|
||||
import random
|
||||
import numpy.random
|
||||
|
||||
from ..pipeline import TextCategorizer
|
||||
from ..lang.en import English
|
||||
|
@ -9,6 +10,8 @@ from ..gold import GoldParse
|
|||
|
||||
|
||||
def test_textcat_learns_multilabel():
|
||||
random.seed(0)
|
||||
numpy.random.seed(0)
|
||||
docs = []
|
||||
nlp = English()
|
||||
vocab = nlp.vocab
|
||||
|
@ -22,7 +25,7 @@ def test_textcat_learns_multilabel():
|
|||
for letter in letters:
|
||||
model.add_label(letter)
|
||||
optimizer = model.begin_training()
|
||||
for i in range(20):
|
||||
for i in range(30):
|
||||
losses = {}
|
||||
Ys = [GoldParse(doc, cats=cats) for doc, cats in docs]
|
||||
Xs = [doc for doc, cats in docs]
|
||||
|
|
|
@ -19,6 +19,9 @@ ctypedef fused LexemeOrToken:
|
|||
const_TokenC_ptr
|
||||
|
||||
|
||||
cdef int set_children_from_heads(TokenC* tokens, int length) except -1
|
||||
|
||||
|
||||
cdef int token_by_start(const TokenC* tokens, int length, int start_char) except -2
|
||||
|
||||
|
||||
|
|
|
@ -436,6 +436,29 @@ def decaying(start, stop, decay):
|
|||
nr_upd += 1
|
||||
|
||||
|
||||
def minibatch_by_words(items, size, count_words=len):
|
||||
'''Create minibatches of a given number of words.'''
|
||||
if isinstance(size, int):
|
||||
size_ = itertools.repeat(size)
|
||||
else:
|
||||
size_ = size
|
||||
items = iter(items)
|
||||
while True:
|
||||
batch_size = next(size_)
|
||||
batch = []
|
||||
while batch_size >= 0:
|
||||
try:
|
||||
doc, gold = next(items)
|
||||
except StopIteration:
|
||||
if batch:
|
||||
yield batch
|
||||
return
|
||||
batch_size -= count_words(doc)
|
||||
batch.append((doc, gold))
|
||||
if batch:
|
||||
yield batch
|
||||
|
||||
|
||||
def itershuffle(iterable, bufsize=1000):
|
||||
"""Shuffle an iterator. This works by holding `bufsize` items back
|
||||
and yielding them sometime later. Obviously, this is not unbiased –
|
||||
|
|
|
@ -12,11 +12,24 @@ p Create a #[code GoldCorpus].
|
|||
|
||||
+table(["Name", "Type", "Description"])
|
||||
+row
|
||||
+cell #[code train_path]
|
||||
+cell unicode or #[code Path]
|
||||
+cell File or directory of training data.
|
||||
+cell #[code train]
|
||||
+cell unicode or #[code Path] or iterable
|
||||
+cell
|
||||
| Training data, as a path (file or directory) or iterable. If an
|
||||
| iterable, each item should be a #[code (text, paragraphs)]
|
||||
| tuple, where each paragraph is a tuple
|
||||
| #[code.u-break (sentences, brackets)],and each sentence is a
|
||||
| tuple #[code.u-break (ids, words, tags, heads, ner)]. See the
|
||||
| implementation of
|
||||
| #[+src(gh("spacy", "spacy/gold.pyx")) #[code gold.read_json_file]]
|
||||
| for further details.
|
||||
|
||||
+row
|
||||
+cell #[code dev_path]
|
||||
+cell unicode or #[code Path]
|
||||
+cell File or directory of development data.
|
||||
+cell #[code dev]
|
||||
+cell unicode or #[code Path] or iterable
|
||||
+cell Development data, as a path (file or directory) or iterable.
|
||||
|
||||
+row("foot")
|
||||
+cell returns
|
||||
+cell #[code GoldCorpus]
|
||||
+cell The newly constructed object.
|
||||
|
|
Loading…
Reference in New Issue
Block a user