Merge Span._ and Span.as_doc methods

commit e0a9b02b67
Author: Matthew Honnibal
Date:   2017-10-09 22:00:15 -05:00

21 changed files with 333 additions and 97 deletions

View File

@@ -21,7 +21,6 @@ import thinc.neural._classes.layernorm
 thinc.neural._classes.layernorm.set_compat_six_eight(False)
 def train_textcat(tokenizer, textcat,
                   train_texts, train_cats, dev_texts, dev_cats,
                   n_iter=20):
@@ -57,13 +56,15 @@ def evaluate(tokenizer, textcat, texts, cats):
     for i, doc in enumerate(textcat.pipe(docs)):
         gold = cats[i]
         for label, score in doc.cats.items():
-            if score >= 0.5 and label in gold:
+            if label not in gold:
+                continue
+            if score >= 0.5 and gold[label] >= 0.5:
                 tp += 1.
-            elif score >= 0.5 and label not in gold:
+            elif score >= 0.5 and gold[label] < 0.5:
                 fp += 1.
-            elif score < 0.5 and label not in gold:
+            elif score < 0.5 and gold[label] < 0.5:
                 tn += 1
-            if score < 0.5 and label in gold:
+            elif score < 0.5 and gold[label] >= 0.5:
                 fn += 1
     precis = tp / (tp + fp)
     recall = tp / (tp + fn)
@@ -80,7 +81,7 @@ def load_data(limit=0):
         train_data = train_data[-limit:]
     texts, labels = zip(*train_data)
-    cats = [(['POSITIVE'] if y else []) for y in labels]
+    cats = [{'POSITIVE': bool(y)} for y in labels]
     split = int(len(train_data) * 0.8)
@@ -97,7 +98,7 @@ def main(model_loc=None):
     textcat = TextCategorizer(tokenizer.vocab, labels=['POSITIVE'])
     print("Load IMDB data")
-    (train_texts, train_cats), (dev_texts, dev_cats) = load_data(limit=1000)
+    (train_texts, train_cats), (dev_texts, dev_cats) = load_data(limit=2000)
     print("Itn.\tLoss\tP\tR\tF")
     progress = '{i:d} {loss:.3f} {textcat_p:.3f} {textcat_r:.3f} {textcat_f:.3f}'
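Note: the evaluation loop above now reads the gold categories as a dict mapping label to truth value (as load_data now produces) and skips labels missing from that dict. A minimal standalone sketch of the same counting logic, for illustration only (count_confusion is a hypothetical helper, not part of the commit):

    def count_confusion(pred_cats, gold_cats):
        # Both arguments map label -> score in [0, 1]; labels absent from
        # gold_cats are treated as missing and skipped entirely.
        tp = fp = tn = fn = 0.0
        for label, score in pred_cats.items():
            if label not in gold_cats:
                continue
            if score >= 0.5 and gold_cats[label] >= 0.5:
                tp += 1.0
            elif score >= 0.5 and gold_cats[label] < 0.5:
                fp += 1.0
            elif score < 0.5 and gold_cats[label] < 0.5:
                tn += 1.0
            elif score < 0.5 and gold_cats[label] >= 0.5:
                fn += 1.0
        return tp, fp, tn, fn

    print(count_confusion({'POSITIVE': 0.9}, {'POSITIVE': 1.0}))  # (1.0, 0.0, 0.0, 0.0)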

View File

@@ -264,7 +264,8 @@ def HistoryFeatures(nr_class, hist_size=8, nr_dim=8):
         return layerize(noop())
     embed_tables = [Embed(nr_dim, nr_class, column=i, name='embed%d')
                     for i in range(hist_size)]
-    embed = concatenate(*embed_tables)
+    embed = chain(concatenate(*embed_tables),
+                  LN(Maxout(hist_size*nr_dim, hist_size*nr_dim)))
     ops = embed.ops
     def add_history_fwd(vectors_hists, drop=0.):
         vectors, hist_ids = vectors_hists
@@ -742,5 +743,3 @@ def concatenate_lists(*layers, **kwargs): # pragma: no cover
         return ys, concatenate_lists_bwd
     model = wrap(concatenate_lists_fwd, concat)
     return model

View File

@@ -3,13 +3,13 @@
 # https://github.com/pypa/warehouse/blob/master/warehouse/__about__.py
 __title__ = 'spacy-nightly'
-__version__ = '2.0.0a16'
+__version__ = '2.0.0a17'
 __summary__ = 'Industrial-strength Natural Language Processing (NLP) with Python and Cython'
 __uri__ = 'https://spacy.io'
 __author__ = 'Explosion AI'
 __email__ = 'contact@explosion.ai'
 __license__ = 'MIT'
-__release__ = True
+__release__ = False
 __docs_models__ = 'https://alpha.spacy.io/usage/models'
 __download_url__ = 'https://github.com/explosion/spacy-models/releases/download'

View File

@@ -4,7 +4,7 @@ from __future__ import unicode_literals
 import plac
 from pathlib import Path
-from .converters import conllu2json, iob2json
+from .converters import conllu2json, iob2json, conll_ner2json
 from ..util import prints
 # Converters are matched by file extension. To add a converter, add a new entry
@@ -12,9 +12,10 @@ from ..util import prints
 # from /converters.
 CONVERTERS = {
-    '.conllu': conllu2json,
-    '.conll': conllu2json,
-    '.iob': iob2json,
+    'conllu': conllu2json,
+    'conll': conllu2json,
+    'ner': conll_ner2json,
+    'iob': iob2json,
 }
@@ -22,9 +23,11 @@ CONVERTERS = {
     input_file=("input file", "positional", None, str),
     output_dir=("output directory for converted file", "positional", None, str),
     n_sents=("Number of sentences per doc", "option", "n", int),
+    converter=("Name of converter (auto, iob, conllu or ner)", "option", "c", str),
     morphology=("Enable appending morphology to tags", "flag", "m", bool)
 )
-def convert(cmd, input_file, output_dir, n_sents=1, morphology=False):
+def convert(cmd, input_file, output_dir, n_sents=1, morphology=False,
+            converter='auto'):
     """
     Convert files into JSON format for use with train command and other
     experiment management functions.
@@ -35,9 +38,11 @@ def convert(cmd, input_file, output_dir, n_sents=1, morphology=False):
         prints(input_path, title="Input file not found", exits=1)
     if not output_path.exists():
         prints(output_path, title="Output directory not found", exits=1)
-    file_ext = input_path.suffix
-    if not file_ext in CONVERTERS:
-        prints("Can't find converter for %s" % input_path.parts[-1],
-               title="Unknown format", exits=1)
-    CONVERTERS[file_ext](input_path, output_path,
-                         n_sents=n_sents, use_morphology=morphology)
+    if converter == 'auto':
+        converter = input_path.suffix[1:]
+    if not converter in CONVERTERS:
+        prints("Can't find converter for %s" % converter,
+               title="Unknown format", exits=1)
+    func = CONVERTERS[converter]
+    func(input_path, output_path,
+         n_sents=n_sents, use_morphology=morphology)
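Note: converters are now looked up by name rather than by file extension, with 'auto' falling back to the input file's suffix. A rough sketch of that resolution logic, assuming the CONVERTERS table above (resolve_converter is a hypothetical helper for illustration):

    from pathlib import Path

    def resolve_converter(input_file, converter='auto'):
        if converter == 'auto':
            # '.conllu' -> 'conllu'; an unknown suffix simply won't match.
            converter = Path(input_file).suffix[1:]
        if converter not in CONVERTERS:
            raise ValueError("Can't find converter for %s" % converter)
        return CONVERTERS[converter]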

View File

@@ -1,2 +1,3 @@
 from .conllu2json import conllu2json
 from .iob2json import iob2json
+from .conll_ner2json import conll_ner2json

View File

@@ -114,15 +114,33 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=10, n_sents=0,
             nlp.to_disk(epoch_model_path)
             nlp_loaded = lang_class(pipeline=pipeline)
             nlp_loaded = nlp_loaded.from_disk(epoch_model_path)
-            scorer = nlp_loaded.evaluate(
-                list(corpus.dev_docs(
-                    nlp_loaded,
-                    gold_preproc=gold_preproc)))
+            dev_docs = list(corpus.dev_docs(
+                nlp_loaded,
+                gold_preproc=gold_preproc))
+            nwords = sum(len(doc_gold[0]) for doc_gold in dev_docs)
+            start_time = timer()
+            scorer = nlp_loaded.evaluate(dev_docs)
+            end_time = timer()
+            if use_gpu < 0:
+                gpu_wps = None
+                cpu_wps = nwords/(end_time-start_time)
+            else:
+                gpu_wps = nwords/(end_time-start_time)
+                with Model.use_device('cpu'):
+                    nlp_loaded = lang_class(pipeline=pipeline)
+                    nlp_loaded = nlp_loaded.from_disk(epoch_model_path)
+                    dev_docs = list(corpus.dev_docs(
+                        nlp_loaded, gold_preproc=gold_preproc))
+                    start_time = timer()
+                    scorer = nlp_loaded.evaluate(dev_docs)
+                    end_time = timer()
+                    cpu_wps = nwords/(end_time-start_time)
             acc_loc =(output_path / ('model%d' % i) / 'accuracy.json')
             with acc_loc.open('w') as file_:
                 file_.write(json_dumps(scorer.scores))
             meta_loc = output_path / ('model%d' % i) / 'meta.json'
             meta['accuracy'] = scorer.scores
+            meta['speed'] = {'nwords': nwords, 'cpu':cpu_wps, 'gpu': gpu_wps}
             meta['lang'] = nlp.lang
             meta['pipeline'] = pipeline
             meta['spacy_version'] = '>=%s' % about.__version__
@@ -132,7 +150,7 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=10, n_sents=0,
             with meta_loc.open('w') as file_:
                 file_.write(json_dumps(meta))
             util.set_env_log(True)
-        print_progress(i, losses, scorer.scores)
+        print_progress(i, losses, scorer.scores, cpu_wps=cpu_wps, gpu_wps=gpu_wps)
     finally:
         print("Saving model...")
         try:
@@ -153,16 +171,17 @@ def _render_parses(i, to_render):
         file_.write(html)
-def print_progress(itn, losses, dev_scores, wps=0.0):
+def print_progress(itn, losses, dev_scores, cpu_wps=0.0, gpu_wps=0.0):
     scores = {}
     for col in ['dep_loss', 'tag_loss', 'uas', 'tags_acc', 'token_acc',
-                'ents_p', 'ents_r', 'ents_f', 'wps']:
+                'ents_p', 'ents_r', 'ents_f', 'cpu_wps', 'gpu_wps']:
         scores[col] = 0.0
     scores['dep_loss'] = losses.get('parser', 0.0)
     scores['ner_loss'] = losses.get('ner', 0.0)
     scores['tag_loss'] = losses.get('tagger', 0.0)
     scores.update(dev_scores)
-    scores['wps'] = wps
+    scores['cpu_wps'] = cpu_wps
+    scores['gpu_wps'] = gpu_wps or 0.0
     tpl = '\t'.join((
         '{:d}',
         '{dep_loss:.3f}',
@@ -173,7 +192,9 @@ def print_progress(itn, losses, dev_scores, wps=0.0):
         '{ents_f:.3f}',
         '{tags_acc:.3f}',
         '{token_acc:.3f}',
-        '{wps:.1f}'))
+        '{cpu_wps:.1f}',
+        '{gpu_wps:.1f}',
+    ))
     print(tpl.format(itn, **scores))
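Note: evaluation during training is now timed so the per-epoch metadata can record words-per-second for CPU and GPU separately. A simplified sketch of the measurement, assuming timer is timeit.default_timer (the import is not shown in this hunk), dev_docs is a list of (doc, gold) pairs, and evaluate_with_wps is a hypothetical helper for illustration:

    from timeit import default_timer as timer

    def evaluate_with_wps(nlp_loaded, dev_docs):
        # Count the tokens in the dev docs, time the evaluation, and derive
        # a words-per-second figure from the elapsed wall-clock time.
        nwords = sum(len(doc_gold[0]) for doc_gold in dev_docs)
        start_time = timer()
        scorer = nlp_loaded.evaluate(dev_docs)
        end_time = timer()
        wps = nwords / (end_time - start_time)
        return scorer, wps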

View File

@@ -387,7 +387,7 @@ cdef class GoldParse:
     def __init__(self, doc, annot_tuples=None, words=None, tags=None, heads=None,
                  deps=None, entities=None, make_projective=False,
-                 cats=tuple()):
+                 cats=None):
         """Create a GoldParse.
         doc (Doc): The document the annotations refer to.
@@ -398,12 +398,15 @@ cdef class GoldParse:
         entities (iterable): A sequence of named entity annotations, either as
             BILUO tag strings, or as `(start_char, end_char, label)` tuples,
             representing the entity positions.
-        cats (iterable): A sequence of labels for text classification. Each
-            label may be a string or an int, or a `(start_char, end_char, label)`
+        cats (dict): Labels for text classification. Each key in the dictionary
+            may be a string or an int, or a `(start_char, end_char, label)`
             tuple, indicating that the label is applied to only part of the
             document (usually a sentence). Unlike entity annotations, label
             annotations can overlap, i.e. a single word can be covered by
-            multiple labelled spans.
+            multiple labelled spans. The TextCategorizer component expects
+            true examples of a label to have the value 1.0, and negative examples
+            of a label to have the value 0.0. Labels not in the dictionary are
+            treated as missing -- the gradient for those labels will be zero.
         RETURNS (GoldParse): The newly constructed object.
         """
         if words is None:
@@ -434,7 +437,7 @@ cdef class GoldParse:
         self.c.sent_start = <int*>self.mem.alloc(len(doc), sizeof(int))
         self.c.ner = <Transition*>self.mem.alloc(len(doc), sizeof(Transition))
-        self.cats = list(cats)
+        self.cats = {} if cats is None else dict(cats)
         self.words = [None] * len(doc)
         self.tags = [None] * len(doc)
         self.heads = [None] * len(doc)
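Note: gold categories are now passed as a dict of label -> value, where 1.0 marks a true example, 0.0 a negative example, and a missing key means the label is unannotated. A small usage sketch, assuming the spacy-nightly API at this commit:

    from spacy.vocab import Vocab
    from spacy.tokens import Doc
    from spacy.gold import GoldParse

    doc = Doc(Vocab(), words=['This', 'movie', 'was', 'great'])
    gold = GoldParse(doc, cats={'POSITIVE': 1.0, 'NEGATIVE': 0.0})
    # Any label missing from the dict (e.g. 'NEUTRAL') is treated as missing,
    # so the TextCategorizer computes no gradient for it.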

View File

@@ -126,7 +126,7 @@ def word_shape(text):
 LEX_ATTRS = {
     attrs.LOWER: lambda string: string.lower(),
     attrs.NORM: lambda string: string.lower(),
-    attrs.PREFIX: lambda string: string[0],
+    attrs.PREFIX: lambda string: string[:3],
     attrs.SUFFIX: lambda string: string[-3:],
     attrs.CLUSTER: lambda string: 0,
     attrs.IS_ALPHA: lambda string: string.isalpha(),
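Note: the default PREFIX lexical attribute now covers the first three characters instead of one, mirroring the three-character SUFFIX. For illustration:

    prefix = lambda string: string[:3]
    suffix = lambda string: string[-3:]
    assert prefix('spacy') == 'spa'
    assert suffix('spacy') == 'acy'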

View File

@@ -158,11 +158,13 @@ class BaseThincComponent(object):
     def to_bytes(self, **exclude):
         """Serialize the pipe to a bytestring."""
-        serialize = OrderedDict((
-            ('cfg', lambda: json_dumps(self.cfg)),
-            ('model', lambda: self.model.to_bytes()),
-            ('vocab', lambda: self.vocab.to_bytes())
-        ))
+        serialize = OrderedDict()
+        serialize['cfg'] = lambda: json_dumps(self.cfg)
+        if self.model in (True, False, None):
+            serialize['model'] = lambda: self.model
+        else:
+            serialize['model'] = self.model.to_bytes
+        serialize['vocab'] = self.vocab.to_bytes
         return util.to_bytes(serialize, exclude)
     def from_bytes(self, bytes_data, **exclude):
@@ -183,11 +185,11 @@ class BaseThincComponent(object):
     def to_disk(self, path, **exclude):
         """Serialize the pipe to disk."""
-        serialize = OrderedDict((
-            ('cfg', lambda p: p.open('w').write(json_dumps(self.cfg))),
-            ('vocab', lambda p: self.vocab.to_disk(p)),
-            ('model', lambda p: p.open('wb').write(self.model.to_bytes())),
-        ))
+        serialize = OrderedDict()
+        serialize['cfg'] = lambda p: p.open('w').write(json_dumps(self.cfg))
+        serialize['vocab'] = lambda p: self.vocab.to_disk(p)
+        if self.model not in (None, True, False):
+            serialize['model'] = lambda p: p.open('wb').write(self.model.to_bytes())
         util.to_disk(path, serialize, exclude)
     def from_disk(self, path, **exclude):
@@ -438,13 +440,16 @@ class NeuralTagger(BaseThincComponent):
             yield
     def to_bytes(self, **exclude):
-        serialize = OrderedDict((
-            ('model', lambda: self.model.to_bytes()),
-            ('vocab', lambda: self.vocab.to_bytes()),
-            ('tag_map', lambda: msgpack.dumps(self.vocab.morphology.tag_map,
-                                              use_bin_type=True,
-                                              encoding='utf8'))
-        ))
+        serialize = OrderedDict()
+        if self.model in (None, True, False):
+            serialize['model'] = lambda: self.model
+        else:
+            serialize['model'] = self.model.to_bytes
+        serialize['vocab'] = self.vocab.to_bytes
+        serialize['tag_map'] = lambda: msgpack.dumps(self.vocab.morphology.tag_map,
+                                                     use_bin_type=True,
+                                                     encoding='utf8')
         return util.to_bytes(serialize, exclude)
     def from_bytes(self, bytes_data, **exclude):
@@ -552,7 +557,6 @@ class NeuralLabeller(NeuralTagger):
             label = self.make_label(i, words, tags, heads, deps, ents)
             if label is not None and label not in self.labels:
                 self.labels[label] = len(self.labels)
-        print(len(self.labels))
         if self.model is True:
             token_vector_width = util.env_opt('token_vector_width')
             self.model = chain(
@@ -721,11 +725,17 @@ class TextCategorizer(BaseThincComponent):
     def get_loss(self, docs, golds, scores):
         truths = numpy.zeros((len(golds), len(self.labels)), dtype='f')
+        not_missing = numpy.ones((len(golds), len(self.labels)), dtype='f')
         for i, gold in enumerate(golds):
             for j, label in enumerate(self.labels):
-                truths[i, j] = label in gold.cats
+                if label in gold.cats:
+                    truths[i, j] = gold.cats[label]
+                else:
+                    not_missing[i, j] = 0.
         truths = self.model.ops.asarray(truths)
+        not_missing = self.model.ops.asarray(not_missing)
         d_scores = (scores-truths) / scores.shape[0]
+        d_scores *= not_missing
         mean_square_error = ((scores-truths)**2).sum(axis=1).mean()
         return mean_square_error, d_scores
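Note: the TextCategorizer loss now builds a not_missing mask so that labels absent from gold.cats contribute zero gradient. A standalone numpy sketch of the masking (illustrative only; the label names and scores are made up):

    import numpy

    labels = ['POSITIVE', 'NEGATIVE']
    gold_cats = [{'POSITIVE': 1.0}]          # 'NEGATIVE' is missing, not negative
    scores = numpy.asarray([[0.9, 0.4]], dtype='f')

    truths = numpy.zeros((len(gold_cats), len(labels)), dtype='f')
    not_missing = numpy.ones((len(gold_cats), len(labels)), dtype='f')
    for i, gold in enumerate(gold_cats):
        for j, label in enumerate(labels):
            if label in gold:
                truths[i, j] = gold[label]
            else:
                not_missing[i, j] = 0.
    d_scores = (scores - truths) / scores.shape[0]
    d_scores *= not_missing                  # no gradient for the missing label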

View File

@@ -61,13 +61,13 @@ cdef struct TokenC:
     attr_t sense
     int head
     attr_t dep
-    bint sent_start
     uint32_t l_kids
     uint32_t r_kids
     uint32_t l_edge
     uint32_t r_edge
+    int sent_start
     int ent_iob
     attr_t ent_type # TODO: Is there a better way to do this? Multiple sources of truth..
     hash_t ent_id

View File

@@ -307,6 +307,8 @@ cdef cppclass StateC:
         this._stack[this._s_i] = this.B(0)
         this._s_i += 1
         this._b_i += 1
+        if this.B_(0).sent_start == 1:
+            this.set_break(this.B(0))
         if this._b_i > this._break:
             this._break = -1
@@ -383,7 +385,7 @@ cdef cppclass StateC:
     void set_break(int i) nogil:
         if 0 <= i < this.length:
-            this._sent[i].sent_start = True
+            this._sent[i].sent_start = 1
             this._break = this._b_i
     void clone(const StateC* src) nogil:

View File

@@ -118,7 +118,7 @@ cdef bint _is_gold_root(const GoldParseC* gold, int word) nogil:
 cdef class Shift:
     @staticmethod
     cdef bint is_valid(const StateC* st, attr_t label) nogil:
-        return st.buffer_length() >= 2 and not st.shifted[st.B(0)] and not st.B_(0).sent_start
+        return st.buffer_length() >= 2 and not st.shifted[st.B(0)] and st.B_(0).sent_start != 1
     @staticmethod
     cdef int transition(StateC* st, attr_t label) nogil:
@@ -178,7 +178,7 @@ cdef class Reduce:
 cdef class LeftArc:
     @staticmethod
     cdef bint is_valid(const StateC* st, attr_t label) nogil:
-        return not st.B_(0).sent_start
+        return st.B_(0).sent_start != 1
     @staticmethod
     cdef int transition(StateC* st, attr_t label) nogil:
@@ -212,7 +212,7 @@ cdef class LeftArc:
 cdef class RightArc:
     @staticmethod
     cdef bint is_valid(const StateC* st, attr_t label) nogil:
-        return not st.B_(0).sent_start
+        return st.B_(0).sent_start != 1
     @staticmethod
     cdef int transition(StateC* st, attr_t label) nogil:
@@ -248,6 +248,10 @@ cdef class Break:
             return False
         elif st.stack_depth() < 1:
             return False
+        elif st.B_(0).l_edge < 0:
+            return False
+        elif st._sent[st.B_(0).l_edge].sent_start < 0:
+            return False
         else:
             return True

View File

@@ -219,30 +219,28 @@ cdef class BiluoPushDown(TransitionSystem):
             raise Exception(move)
         return t
-    #def add_action(self, int action, label_name):
-    #    cdef attr_t label_id
-    #    if not isinstance(label_name, (int, long)):
-    #        label_id = self.strings.add(label_name)
-    #    else:
-    #        label_id = label_name
-    #    if action == OUT and label_id != 0:
-    #        return
-    #    if action == MISSING or action == ISNT:
-    #        return
-    #    # Check we're not creating a move we already have, so that this is
-    #    # idempotent
-    #    for trans in self.c[:self.n_moves]:
-    #        if trans.move == action and trans.label == label_id:
-    #            return 0
-    #    if self.n_moves >= self._size:
-    #        self._size *= 2
-    #        self.c = <Transition*>self.mem.realloc(self.c, self._size * sizeof(self.c[0]))
-    #    self.c[self.n_moves] = self.init_transition(self.n_moves, action, label_id)
-    #    assert self.c[self.n_moves].label == label_id
-    #    self.n_moves += 1
-    #    return 1
+    def add_action(self, int action, label_name):
+        cdef attr_t label_id
+        if not isinstance(label_name, (int, long)):
+            label_id = self.strings.add(label_name)
+        else:
+            label_id = label_name
+        if action == OUT and label_id != 0:
+            return
+        if action == MISSING or action == ISNT:
+            return
+        # Check we're not creating a move we already have, so that this is
+        # idempotent
+        for trans in self.c[:self.n_moves]:
+            if trans.move == action and trans.label == label_id:
+                return 0
+        if self.n_moves >= self._size:
+            self._size *= 2
+            self.c = <Transition*>self.mem.realloc(self.c, self._size * sizeof(self.c[0]))
+        self.c[self.n_moves] = self.init_transition(self.n_moves, action, label_id)
+        assert self.c[self.n_moves].label == label_id
+        self.n_moves += 1
+        return 1
     cdef int initialize_state(self, StateC* st) nogil:
         # This is especially necessary when we use limited training data.

View File

@@ -51,7 +51,7 @@ from .._ml import Tok2Vec, doc2feats, rebatch, fine_tune
 from .._ml import Residual, drop_layer, flatten
 from .._ml import link_vectors_to_models
 from .._ml import HistoryFeatures
-from ..compat import json_dumps
+from ..compat import json_dumps, copy_array
 from . import _parse_features
 from ._parse_features cimport CONTEXT_SIZE
@@ -239,13 +239,13 @@ cdef class Parser:
     """
     @classmethod
     def Model(cls, nr_class, **cfg):
-        depth = util.env_opt('parser_hidden_depth', cfg.get('hidden_depth', 2))
+        depth = util.env_opt('parser_hidden_depth', cfg.get('hidden_depth', 0))
         token_vector_width = util.env_opt('token_vector_width', cfg.get('token_vector_width', 128))
         hidden_width = util.env_opt('hidden_width', cfg.get('hidden_width', 128))
-        parser_maxout_pieces = util.env_opt('parser_maxout_pieces', cfg.get('maxout_pieces', 1))
+        parser_maxout_pieces = util.env_opt('parser_maxout_pieces', cfg.get('maxout_pieces', 3))
         embed_size = util.env_opt('embed_size', cfg.get('embed_size', 7000))
-        hist_size = util.env_opt('history_feats', cfg.get('hist_size', 4))
-        hist_width = util.env_opt('history_width', cfg.get('hist_width', 16))
+        hist_size = util.env_opt('history_feats', cfg.get('hist_size', 0))
+        hist_width = util.env_opt('history_width', cfg.get('hist_width', 0))
         if hist_size >= 1 and depth == 0:
             raise ValueError("Inconsistent hyper-params: "
                              "history_feats >= 1 but parser_hidden_depth==0")
@@ -789,12 +789,22 @@ cdef class Parser:
             return []
     def add_label(self, label):
+        resized = False
         for action in self.moves.action_types:
             added = self.moves.add_action(action, label)
             if added:
                 # Important that the labels be stored as a list! We need the
                 # order, or the model goes out of synch
                 self.cfg.setdefault('extra_labels', []).append(label)
+                resized = True
+        if self.model not in (True, False, None) and resized:
+            # Weights are stored in (nr_out, nr_in) format, so we're basically
+            # just adding rows here.
+            smaller = self.model[-1]._layers[-1]
+            larger = Affine(self.moves.n_moves, smaller.nI)
+            copy_array(larger.W[:smaller.nO], smaller.W)
+            copy_array(larger.b[:smaller.nO], smaller.b)
+            self.model[-1]._layers[-1] = larger
     def begin_training(self, gold_tuples, pipeline=None, **cfg):
         if 'model' in cfg:
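Note: add_label can now grow an already-built model's output layer in place. Because the weights are stored as (nr_out, nr_in), adding a class amounts to appending rows. A plain numpy sketch of the idea (grow_output_layer is a hypothetical helper; the commit itself does this with thinc's Affine and copy_array on the parser's last layer):

    import numpy

    def grow_output_layer(W, b, new_nr_out):
        # Allocate a larger layer and copy the trained rows into it,
        # leaving the new rows zero-initialised.
        nr_out, nr_in = W.shape
        W_new = numpy.zeros((new_nr_out, nr_in), dtype=W.dtype)
        b_new = numpy.zeros((new_nr_out,), dtype=b.dtype)
        W_new[:nr_out] = W
        b_new[:nr_out] = b
        return W_new, b_new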

View File

@@ -0,0 +1,68 @@
+'''Test the ability to add a label to a (potentially trained) parsing model.'''
+from __future__ import unicode_literals
+import pytest
+import numpy.random
+from thinc.neural.optimizers import Adam
+from thinc.neural.ops import NumpyOps
+
+from ...attrs import NORM
+from ...gold import GoldParse
+from ...vocab import Vocab
+from ...tokens import Doc
+from ...pipeline import NeuralDependencyParser
+
+numpy.random.seed(0)
+
+
+@pytest.fixture
+def vocab():
+    return Vocab(lex_attr_getters={NORM: lambda s: s})
+
+
+@pytest.fixture
+def parser(vocab):
+    parser = NeuralDependencyParser(vocab)
+    parser.cfg['token_vector_width'] = 4
+    parser.cfg['hidden_width'] = 6
+    parser.cfg['hist_size'] = 0
+    parser.add_label('left')
+    parser.begin_training([], **parser.cfg)
+    sgd = Adam(NumpyOps(), 0.001)
+    for i in range(30):
+        losses = {}
+        doc = Doc(vocab, words=['a', 'b', 'c', 'd'])
+        gold = GoldParse(doc, heads=[1, 1, 3, 3],
+                         deps=['left', 'ROOT', 'left', 'ROOT'])
+        parser.update([doc], [gold], sgd=sgd, losses=losses)
+    return parser
+
+
+def test_add_label(parser):
+    doc = Doc(parser.vocab, words=['a', 'b', 'c', 'd'])
+    doc = parser(doc)
+    assert doc[0].head.i == 1
+    assert doc[0].dep_ == 'left'
+    assert doc[1].head.i == 1
+    assert doc[2].head.i == 3
+    assert doc[2].head.i == 3
+    parser.add_label('right')
+    doc = Doc(parser.vocab, words=['a', 'b', 'c', 'd'])
+    doc = parser(doc)
+    assert doc[0].head.i == 1
+    assert doc[0].dep_ == 'left'
+    assert doc[1].head.i == 1
+    assert doc[2].head.i == 3
+    assert doc[2].head.i == 3
+    sgd = Adam(NumpyOps(), 0.001)
+    for i in range(10):
+        losses = {}
+        doc = Doc(parser.vocab, words=['a', 'b', 'c', 'd'])
+        gold = GoldParse(doc, heads=[1, 1, 3, 3],
+                         deps=['right', 'ROOT', 'left', 'ROOT'])
+        parser.update([doc], [gold], sgd=sgd, losses=losses)
+    doc = Doc(parser.vocab, words=['a', 'b', 'c', 'd'])
+    doc = parser(doc)
+    assert doc[0].dep_ == 'right'
+    assert doc[2].dep_ == 'left'

View File

@@ -35,7 +35,8 @@ def parser(vocab, arc_eager):
 @pytest.fixture
 def model(arc_eager, tok2vec):
-    return Parser.Model(arc_eager.n_moves, token_vector_width=tok2vec.nO)[0]
+    return Parser.Model(arc_eager.n_moves, token_vector_width=tok2vec.nO,
+                        hist_size=0)[0]
 @pytest.fixture
 def doc(vocab):
@@ -51,7 +52,7 @@ def test_can_init_nn_parser(parser):
 def test_build_model(parser):
-    parser.model = Parser.Model(parser.moves.n_moves)[0]
+    parser.model = Parser.Model(parser.moves.n_moves, hist_size=0)[0]
     assert parser.model is not None

View File

@@ -0,0 +1,73 @@
+'''Test that the parser respects preset sentence boundaries.'''
+from __future__ import unicode_literals
+import pytest
+from thinc.neural.optimizers import Adam
+from thinc.neural.ops import NumpyOps
+
+from ...attrs import NORM
+from ...gold import GoldParse
+from ...vocab import Vocab
+from ...tokens import Doc
+from ...pipeline import NeuralDependencyParser
+
+
+@pytest.fixture
+def vocab():
+    return Vocab(lex_attr_getters={NORM: lambda s: s})
+
+
+@pytest.fixture
+def parser(vocab):
+    parser = NeuralDependencyParser(vocab)
+    parser.cfg['token_vector_width'] = 4
+    parser.cfg['hidden_width'] = 32
+    #parser.add_label('right')
+    parser.add_label('left')
+    parser.begin_training([], **parser.cfg)
+    sgd = Adam(NumpyOps(), 0.001)
+    for i in range(10):
+        losses = {}
+        doc = Doc(vocab, words=['a', 'b', 'c', 'd'])
+        gold = GoldParse(doc, heads=[1, 1, 3, 3],
+                         deps=['left', 'ROOT', 'left', 'ROOT'])
+        parser.update([doc], [gold], sgd=sgd, losses=losses)
+    return parser
+
+
+def test_no_sentences(parser):
+    doc = Doc(parser.vocab, words=['a', 'b', 'c', 'd'])
+    doc = parser(doc)
+    assert len(list(doc.sents)) == 2
+
+
+def test_sents_1(parser):
+    doc = Doc(parser.vocab, words=['a', 'b', 'c', 'd'])
+    doc[2].sent_start = True
+    doc = parser(doc)
+    assert len(list(doc.sents)) >= 2
+    doc = Doc(parser.vocab, words=['a', 'b', 'c', 'd'])
+    doc[1].sent_start = False
+    doc[2].sent_start = True
+    doc[3].sent_start = False
+    doc = parser(doc)
+    assert len(list(doc.sents)) == 2
+
+
+def test_sents_1_2(parser):
+    doc = Doc(parser.vocab, words=['a', 'b', 'c', 'd'])
+    doc[1].sent_start = True
+    doc[2].sent_start = True
+    doc = parser(doc)
+    assert len(list(doc.sents)) == 3
+
+
+def test_sents_1_3(parser):
+    doc = Doc(parser.vocab, words=['a', 'b', 'c', 'd'])
+    doc[1].sent_start = True
+    doc[3].sent_start = True
+    doc = parser(doc)
+    assert len(list(doc.sents)) == 4
+    doc = Doc(parser.vocab, words=['a', 'b', 'c', 'd'])
+    doc[1].sent_start = True
+    doc[2].sent_start = False
+    doc[3].sent_start = True
+    doc = parser(doc)
+    assert len(list(doc.sents)) == 3

View File

@@ -0,0 +1,9 @@
+import spacy
+import spacy.lang.en
+from spacy.pipeline import TextCategorizer
+
+
+def test_bytes_serialize_issue_1105():
+    nlp = spacy.lang.en.English()
+    tokenizer = nlp.tokenizer
+    textcat = TextCategorizer(tokenizer.vocab, labels=['ENTITY', 'ACTION', 'MODIFIER'])
+    textcat_bytes = textcat.to_bytes()

View File

@@ -506,7 +506,7 @@ cdef class Doc:
         cdef int i
         start = 0
         for i in range(1, self.length):
-            if self.c[i].sent_start:
+            if self.c[i].sent_start == 1:
                 yield Span(self, start, i)
                 start = i
         if start != self.length:

View File

@@ -129,6 +129,29 @@ cdef class Span:
     def _(self):
         return Underscore(Underscore.span_extensions, self,
                           start=self.start_char, end=self.end_char)
+
+    def as_doc(self):
+        '''Create a Doc object view of the Span's data.
+        This is mostly useful for C-typed interfaces.
+        '''
+        cdef Doc doc = Doc(self.doc.vocab)
+        doc.length = self.end-self.start
+        doc.c = &self.doc.c[self.start]
+        doc.mem = self.doc.mem
+        doc.is_parsed = self.doc.is_parsed
+        doc.is_tagged = self.doc.is_tagged
+        doc.noun_chunks_iterator = self.doc.noun_chunks_iterator
+        doc.user_hooks = self.doc.user_hooks
+        doc.user_span_hooks = self.doc.user_span_hooks
+        doc.user_token_hooks = self.doc.user_token_hooks
+        doc.vector = self.vector
+        doc.vector_norm = self.vector_norm
+        for key, value in self.doc.cats.items():
+            if hasattr(key, '__len__') and len(key) == 3:
+                cat_start, cat_end, cat_label = key
+                if cat_start == self.start_char and cat_end == self.end_char:
+                    doc.cats[cat_label] = value
+        return doc
+
     def merge(self, *args, **attributes):
         """Retokenize the document, such that the span is merged into a single

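Note: the new Span.as_doc() returns a Doc view over the span's tokens, sharing the parent Doc's memory and copying over any whole-span category labels. A minimal usage sketch, assuming the spacy-nightly API at this commit:

    from spacy.vocab import Vocab
    from spacy.tokens import Doc

    doc = Doc(Vocab(), words=['This', 'is', 'one', 'sentence', '.'])
    span = doc[1:4]
    span_doc = span.as_doc()          # a Doc backed by the same token data
    assert len(span_doc) == len(span)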
View File

@@ -300,13 +300,21 @@ cdef class Token:
         def __get__(self):
             return self.c.sent_start
-        def __set__(self, bint value):
+        def __set__(self, value):
             if self.doc.is_parsed:
                 raise ValueError(
                     'Refusing to write to token.sent_start if its document is parsed, '
                     'because this may cause inconsistent state. '
                     'See https://github.com/spacy-io/spaCy/issues/235 for workarounds.')
-            self.c.sent_start = value
+            if value is None:
+                self.c.sent_start = 0
+            elif value is True:
+                self.c.sent_start = 1
+            elif value is False:
+                self.c.sent_start = -1
+            else:
+                raise ValueError("Invalid value for token.sent_start -- must be one of "
+                                 "None, True, False")
     property lefts:
         def __get__(self):
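Note: sent_start is now a three-valued int rather than a bool: 1 means the token is known to start a sentence, -1 means it is known not to, and 0 means the boundary is unset and the parser is free to decide. The setter above maps Python values onto that encoding. A short sketch, assuming the spacy-nightly API at this commit:

    from spacy.vocab import Vocab
    from spacy.tokens import Doc

    doc = Doc(Vocab(), words=['Hello', 'world', '.', 'Bye'])
    doc[3].sent_start = True      # stored internally as 1
    doc[1].sent_start = False     # stored internally as -1
    # doc[2].sent_start stays 0 (unset), so the parser may place a boundary there or not.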