Resolve fine-tuning conflict

Matthew Honnibal 2017-09-17 05:30:04 -05:00
commit 43210abacc
23 changed files with 283 additions and 198 deletions

View File

@@ -1 +1,56 @@
+environment:
+  matrix:
+    # For Python versions available on Appveyor, see
+    # http://www.appveyor.com/docs/installed-software#python
+    # The list here is complete (excluding Python 2.6, which
+    # isn't covered by this document) at the time of writing.
+    - PYTHON: "C:\\Python27"
+    #- PYTHON: "C:\\Python33"
+    #- PYTHON: "C:\\Python34"
+    #- PYTHON: "C:\\Python35"
+    #- PYTHON: "C:\\Python27-x64"
+    #- PYTHON: "C:\\Python33-x64"
+    #- DISTUTILS_USE_SDK: "1"
+    #- PYTHON: "C:\\Python34-x64"
+    #- DISTUTILS_USE_SDK: "1"
+    #- PYTHON: "C:\\Python35-x64"
+    - PYTHON: "C:\\Python36-x64"
+install:
+  # We need wheel installed to build wheels
+  - "%PYTHON%\\python.exe -m pip install wheel"
+  - "%PYTHON%\\python.exe -m pip install cython"
+  - "%PYTHON%\\python.exe -m pip install -r requirements.txt"
+  - "%PYTHON%\\python.exe setup.py build_ext --inplace"
+  - "%PYTHON%\\python.exe -m pip install -e ."
 build: off
+test_script:
+  # Put your test command here.
+  # If you don't need to build C extensions on 64-bit Python 3.3 or 3.4,
+  # you can remove "build.cmd" from the front of the command, as it's
+  # only needed to support those cases.
+  # Note that you must use the environment variable %PYTHON% to refer to
+  # the interpreter you're using - Appveyor does not do anything special
+  # to put the Python version you want to use on PATH.
+  - "%PYTHON%\\python.exe -m pytest spacy/"
+after_test:
+  # This step builds your wheels.
+  # Again, you only need build.cmd if you're building C extensions for
+  # 64-bit Python 3.3/3.4. And you need to use %PYTHON% to get the correct
+  # interpreter
+  - "%PYTHON%\\python.exe setup.py bdist_wheel"
+artifacts:
+  # bdist_wheel puts your built wheel in the dist directory
+  - path: dist\*
+#on_success:
+#  You can use this step to upload your artifacts to a public website.
+#  See Appveyor's documentation for more details. Or you can simply
+#  access your wheels from the Appveyor "artifacts" tab for your build.

View File

@@ -13,24 +13,27 @@ Input data:
 https://www.lt.informatik.tu-darmstadt.de/fileadmin/user_upload/Group_LangTech/data/GermEval2014_complete_data.zip
 Developed for: spaCy 1.7.1
-Last tested for: spaCy 1.7.1
+Last tested for: spaCy 2.0.0a13
 '''
 from __future__ import unicode_literals, print_function
 import plac
 from pathlib import Path
 import random
 import json
+from thinc.neural.optimizers import Adam
+from thinc.neural.ops import NumpyOps
+import tqdm
-import spacy.orth as orth_funcs
 from spacy.vocab import Vocab
-from spacy.pipeline import BeamEntityRecognizer
-from spacy.pipeline import EntityRecognizer
+from spacy.pipeline import TokenVectorEncoder, NeuralEntityRecognizer
 from spacy.tokenizer import Tokenizer
 from spacy.tokens import Doc
 from spacy.attrs import *
 from spacy.gold import GoldParse
-from spacy.gold import _iob_to_biluo as iob_to_biluo
+from spacy.gold import iob_to_biluo
+from spacy.gold import minibatch
 from spacy.scorer import Scorer
+import spacy.util
 try:
     unicode
@@ -38,95 +41,40 @@ except NameError:
     unicode = str
+spacy.util.set_env_log(True)
 def init_vocab():
     return Vocab(
         lex_attr_getters={
             LOWER: lambda string: string.lower(),
-            SHAPE: orth_funcs.word_shape,
+            NORM: lambda string: string.lower(),
             PREFIX: lambda string: string[0],
             SUFFIX: lambda string: string[-3:],
-            CLUSTER: lambda string: 0,
-            IS_ALPHA: orth_funcs.is_alpha,
-            IS_ASCII: orth_funcs.is_ascii,
-            IS_DIGIT: lambda string: string.isdigit(),
-            IS_LOWER: orth_funcs.is_lower,
-            IS_PUNCT: orth_funcs.is_punct,
-            IS_SPACE: lambda string: string.isspace(),
-            IS_TITLE: orth_funcs.is_title,
-            IS_UPPER: orth_funcs.is_upper,
-            IS_STOP: lambda string: False,
-            IS_OOV: lambda string: True
         })
-def save_vocab(vocab, path):
-    path = Path(path)
-    if not path.exists():
-        path.mkdir()
-    elif not path.is_dir():
-        raise IOError("Can't save vocab to %s\nNot a directory" % path)
-    with (path / 'strings.json').open('w') as file_:
-        vocab.strings.dump(file_)
-    vocab.dump((path / 'lexemes.bin').as_posix())
-def load_vocab(path):
-    path = Path(path)
-    if not path.exists():
-        raise IOError("Cannot load vocab from %s\nDoes not exist" % path)
-    if not path.is_dir():
-        raise IOError("Cannot load vocab from %s\nNot a directory" % path)
-    return Vocab.load(path)
-def init_ner_model(vocab, features=None):
-    if features is None:
-        features = tuple(EntityRecognizer.feature_templates)
-    return EntityRecognizer(vocab, features=features)
-def save_ner_model(model, path):
-    path = Path(path)
-    if not path.exists():
-        path.mkdir()
-    if not path.is_dir():
-        raise IOError("Can't save model to %s\nNot a directory" % path)
-    model.model.dump((path / 'model').as_posix())
-    with (path / 'config.json').open('w') as file_:
-        data = json.dumps(model.cfg)
-        if not isinstance(data, unicode):
-            data = data.decode('utf8')
-        file_.write(data)
-def load_ner_model(vocab, path):
-    return EntityRecognizer.load(path, vocab)
 class Pipeline(object):
-    @classmethod
-    def load(cls, path):
-        path = Path(path)
-        if not path.exists():
-            raise IOError("Cannot load pipeline from %s\nDoes not exist" % path)
-        if not path.is_dir():
-            raise IOError("Cannot load pipeline from %s\nNot a directory" % path)
-        vocab = load_vocab(path)
-        tokenizer = Tokenizer(vocab, {}, None, None, None)
-        ner_model = load_ner_model(vocab, path / 'ner')
-        return cls(vocab, tokenizer, ner_model)
-    def __init__(self, vocab=None, tokenizer=None, entity=None):
+    def __init__(self, vocab=None, tokenizer=None, tensorizer=None, entity=None):
         if vocab is None:
             vocab = init_vocab()
         if tokenizer is None:
             tokenizer = Tokenizer(vocab, {}, None, None, None)
+        if tensorizer is None:
+            tensorizer = TokenVectorEncoder(vocab)
         if entity is None:
-            entity = init_ner_model(self.vocab)
+            entity = NeuralEntityRecognizer(vocab)
         self.vocab = vocab
         self.tokenizer = tokenizer
+        self.tensorizer = tensorizer
         self.entity = entity
-        self.pipeline = [self.entity]
+        self.pipeline = [tensorizer, self.entity]
+    def begin_training(self):
+        for model in self.pipeline:
+            model.begin_training([])
+        optimizer = Adam(NumpyOps(), 0.001)
+        return optimizer
     def __call__(self, input_):
         doc = self.make_doc(input_)
@@ -147,14 +95,18 @@ class Pipeline(object):
         gold = GoldParse(doc, entities=annotations)
         return gold
-    def update(self, input_, annot):
-        doc = self.make_doc(input_)
-        gold = self.make_gold(input_, annot)
-        for ner in gold.ner:
-            if ner not in (None, '-', 'O'):
-                action, label = ner.split('-', 1)
-                self.entity.add_label(label)
-        return self.entity.update(doc, gold)
+    def update(self, inputs, annots, sgd, losses=None, drop=0.):
+        if losses is None:
+            losses = {}
+        docs = [self.make_doc(input_) for input_ in inputs]
+        golds = [self.make_gold(input_, annot) for input_, annot in
+                 zip(inputs, annots)]
+        tensors, bp_tensors = self.tensorizer.update(docs, golds, drop=drop)
+        d_tensors = self.entity.update((docs, tensors), golds, drop=drop,
+                                       sgd=sgd, losses=losses)
+        bp_tensors(d_tensors, sgd=sgd)
+        return losses
     def evaluate(self, examples):
         scorer = Scorer()
@@ -164,34 +116,38 @@ class Pipeline(object):
             scorer.score(doc, gold)
         return scorer.scores
-    def average_weights(self):
-        self.entity.model.end_training()
-    def save(self, path):
+    def to_disk(self, path):
         path = Path(path)
         if not path.exists():
             path.mkdir()
         elif not path.is_dir():
             raise IOError("Can't save pipeline to %s\nNot a directory" % path)
-        save_vocab(self.vocab, path / 'vocab')
-        save_ner_model(self.entity, path / 'ner')
+        self.vocab.to_disk(path / 'vocab')
+        self.tensorizer.to_disk(path / 'tensorizer')
+        self.entity.to_disk(path / 'ner')
+    def from_disk(self, path):
+        path = Path(path)
+        if not path.exists():
+            raise IOError("Cannot load pipeline from %s\nDoes not exist" % path)
+        if not path.is_dir():
+            raise IOError("Cannot load pipeline from %s\nNot a directory" % path)
+        self.vocab = self.vocab.from_disk(path / 'vocab')
+        self.tensorizer = self.tensorizer.from_disk(path / 'tensorizer')
+        self.entity = self.entity.from_disk(path / 'ner')
-def train(nlp, train_examples, dev_examples, ctx, nr_epoch=5):
-    next_epoch = train_examples
+def train(nlp, train_examples, dev_examples, nr_epoch=5):
+    sgd = nlp.begin_training()
     print("Iter", "Loss", "P", "R", "F")
     for i in range(nr_epoch):
-        this_epoch = next_epoch
-        next_epoch = []
-        loss = 0
-        for input_, annot in this_epoch:
-            loss += nlp.update(input_, annot)
-            if (i+1) < nr_epoch:
-                next_epoch.append((input_, annot))
-        random.shuffle(next_epoch)
+        random.shuffle(train_examples)
+        losses = {}
+        for batch in minibatch(tqdm.tqdm(train_examples, leave=False), size=8):
+            inputs, annots = zip(*batch)
+            nlp.update(list(inputs), list(annots), sgd, losses=losses)
         scores = nlp.evaluate(dev_examples)
-        report_scores(i, loss, scores)
-    nlp.average_weights()
+        report_scores(i, losses['ner'], scores)
     scores = nlp.evaluate(dev_examples)
     report_scores(channels, i+1, loss, scores)
@@ -208,7 +164,8 @@ def read_examples(path):
     with path.open() as file_:
         sents = file_.read().strip().split('\n\n')
         for sent in sents:
-            if not sent.strip():
+            sent = sent.strip()
+            if not sent:
                 continue
             tokens = sent.split('\n')
            while tokens and tokens[0].startswith('#'):
@@ -217,28 +174,39 @@ def read_examples(path):
             iob = []
             for token in tokens:
                 if token.strip():
-                    pieces = token.split()
+                    pieces = token.split('\t')
                     words.append(pieces[1])
                     iob.append(pieces[2])
             yield words, iob_to_biluo(iob)
+def get_labels(examples):
+    labels = set()
+    for words, tags in examples:
+        for tag in tags:
+            if '-' in tag:
+                labels.add(tag.split('-')[1])
+    return sorted(labels)
 @plac.annotations(
     model_dir=("Path to save the model", "positional", None, Path),
     train_loc=("Path to your training data", "positional", None, Path),
     dev_loc=("Path to your development data", "positional", None, Path),
 )
-def main(model_dir=Path('/home/matt/repos/spaCy/spacy/data/de-1.0.0'),
-         train_loc=None, dev_loc=None, nr_epoch=30):
-    train_examples = read_examples(train_loc)
+def main(model_dir, train_loc, dev_loc, nr_epoch=30):
+    print(model_dir, train_loc, dev_loc)
+    train_examples = list(read_examples(train_loc))
     dev_examples = read_examples(dev_loc)
-    nlp = Pipeline.load(model_dir)
-    train(nlp, train_examples, list(dev_examples), ctx, nr_epoch)
-    nlp.save(model_dir)
+    nlp = Pipeline()
+    for label in get_labels(train_examples):
+        nlp.entity.add_label(label)
+        print("Add label", label)
+    train(nlp, train_examples, list(dev_examples), nr_epoch)
+    nlp.to_disk(model_dir)
 if __name__ == '__main__':
-    main()
+    plac.call(main)
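For reference, a self-contained sketch of what the tab-splitting and get_labels() logic above does with toy GermEval-style rows (iob_to_biluo as imported in the example; the data values are made up):

from spacy.gold import iob_to_biluo

sent = "1\tSchwartau\tB-LOC\n2\tist\tO\n3\tschoen\tO"
words, iob = [], []
for token in sent.split('\n'):
    pieces = token.split('\t')      # column 1 = word, column 2 = IOB tag
    words.append(pieces[1])
    iob.append(pieces[2])

tags = iob_to_biluo(iob)            # ['U-LOC', 'O', 'O']
labels = sorted({tag.split('-')[1] for tag in tags if '-' in tag})   # ['LOC']
print(words, tags, labels)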

View File

@@ -25,7 +25,7 @@ For more details, see the documentation:
 * Saving and loading models: https://spacy.io/docs/usage/saving-loading
 Developed for: spaCy 1.7.6
-Last tested for: spaCy 1.7.6
+Last updated for: spaCy 2.0.0a13
 """
 from __future__ import unicode_literals, print_function
@@ -34,55 +34,41 @@ from pathlib import Path
 import random
 import spacy
-from spacy.gold import GoldParse
-from spacy.tagger import Tagger
+from spacy.gold import GoldParse, minibatch
+from spacy.pipeline import NeuralEntityRecognizer
+from spacy.pipeline import TokenVectorEncoder
+def get_gold_parses(tokenizer, train_data):
+    '''Shuffle and create GoldParse objects'''
+    random.shuffle(train_data)
+    for raw_text, entity_offsets in train_data:
+        doc = tokenizer(raw_text)
+        gold = GoldParse(doc, entities=entity_offsets)
+        yield doc, gold
 def train_ner(nlp, train_data, output_dir):
-    # Add new words to vocab
-    for raw_text, _ in train_data:
-        doc = nlp.make_doc(raw_text)
-        for word in doc:
-            _ = nlp.vocab[word.orth]
     random.seed(0)
-    # You may need to change the learning rate. It's generally difficult to
-    # guess what rate you should set, especially when you have limited data.
-    nlp.entity.model.learn_rate = 0.001
-    for itn in range(1000):
-        random.shuffle(train_data)
-        loss = 0.
-        for raw_text, entity_offsets in train_data:
-            gold = GoldParse(doc, entities=entity_offsets)
-            # By default, the GoldParse class assumes that the entities
-            # described by offset are complete, and all other words should
-            # have the tag 'O'. You can tell it to make no assumptions
-            # about the tag of a word by giving it the tag '-'.
-            # However, this allows a trivial solution to the current
-            # learning problem: if words are either 'any tag' or 'ANIMAL',
-            # the model can learn that all words can be tagged 'ANIMAL'.
-            #for i in range(len(gold.ner)):
-                #if not gold.ner[i].endswith('ANIMAL'):
-                #    gold.ner[i] = '-'
-            doc = nlp.make_doc(raw_text)
-            nlp.tagger(doc)
-            # As of 1.9, spaCy's parser now lets you supply a dropout probability
-            # This might help the model generalize better from only a few
-            # examples.
-            loss += nlp.entity.update(doc, gold, drop=0.9)
-        if loss == 0:
-            break
-    # This step averages the model's weights. This may or may not be good for
-    # your situation --- it's empirical.
-    nlp.end_training()
-    if output_dir:
-        if not output_dir.exists():
-            output_dir.mkdir()
-        nlp.save_to_directory(output_dir)
+    optimizer = nlp.begin_training(lambda: [])
+    nlp.meta['name'] = 'en_ent_animal'
+    for itn in range(50):
+        losses = {}
+        for batch in minibatch(get_gold_parses(nlp.make_doc, train_data), size=3):
+            docs, golds = zip(*batch)
+            nlp.update(docs, golds, losses=losses, sgd=optimizer, update_shared=True,
+                       drop=0.35)
+        print(losses)
+    if not output_dir:
+        return
+    elif not output_dir.exists():
+        output_dir.mkdir()
+    nlp.to_disk(output_dir)
 def main(model_name, output_directory=None):
-    print("Loading initial model", model_name)
-    nlp = spacy.load(model_name)
+    print("Creating initial model", model_name)
+    nlp = spacy.blank(model_name)
     if output_directory is not None:
         output_directory = Path(output_directory)
@@ -91,6 +77,11 @@ def main(model_name, output_directory=None):
             "Horses are too tall and they pretend to care about your feelings",
             [(0, 6, 'ANIMAL')],
         ),
+        (
+            "Do they bite?",
+            [],
+        ),
         (
             "horses are too tall and they pretend to care about your feelings",
             [(0, 6, 'ANIMAL')]
@@ -109,18 +100,20 @@ def main(model_name, output_directory=None):
         )
     ]
-    nlp.entity.add_label('ANIMAL')
+    nlp.pipeline.append(TokenVectorEncoder(nlp.vocab))
+    nlp.pipeline.append(NeuralEntityRecognizer(nlp.vocab))
+    nlp.pipeline[-1].add_label('ANIMAL')
     train_ner(nlp, train_data, output_directory)
     # Test that the entity is recognized
-    doc = nlp('Do you like horses?')
+    text = 'Do you like horses?'
     print("Ents in 'Do you like horses?':")
+    doc = nlp(text)
     for ent in doc.ents:
         print(ent.label_, ent.text)
     if output_directory:
         print("Loading from", output_directory)
-        nlp2 = spacy.load('en', path=output_directory)
-        nlp2.entity.add_label('ANIMAL')
+        nlp2 = spacy.load(output_directory)
         doc2 = nlp2('Do you like horses?')
         for ent in doc2.ents:
             print(ent.label_, ent.text)

View File

@ -1,9 +1,9 @@
cython<0.24 cython>=0.24
pathlib pathlib
numpy>=1.7 numpy>=1.7
cymem>=1.30,<1.32 cymem>=1.30,<1.32
preshed>=1.0.0,<2.0.0 preshed>=1.0.0,<2.0.0
thinc>=6.8.0,<6.9.0 thinc>=6.8.1,<6.9.0
murmurhash>=0.28,<0.29 murmurhash>=0.28,<0.29
plac<1.0.0,>=0.9.6 plac<1.0.0,>=0.9.6
six six

View File

@@ -195,7 +195,7 @@ def setup_package():
             'murmurhash>=0.28,<0.29',
             'cymem>=1.30,<1.32',
             'preshed>=1.0.0,<2.0.0',
-            'thinc>=6.8.0,<6.9.0',
+            'thinc>=6.8.1,<6.9.0',
             'plac<1.0.0,>=0.9.6',
             'pip>=9.0.0,<10.0.0',
             'six',

View File

@@ -3,7 +3,7 @@
 # https://github.com/pypa/warehouse/blob/master/warehouse/__about__.py
 __title__ = 'spacy-nightly'
-__version__ = '2.0.0a13'
+__version__ = '2.0.0a14'
 __summary__ = 'Industrial-strength Natural Language Processing (NLP) with Python and Cython'
 __uri__ = 'https://spacy.io'
 __author__ = 'Explosion AI'

View File

@@ -80,6 +80,7 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
     n_train_words = corpus.count_train()
     optimizer = nlp.begin_training(lambda: corpus.train_tuples, device=use_gpu)
+    nlp._optimizer = None
     print("Itn.\tLoss\tUAS\tNER P.\tNER R.\tNER F.\tTag %\tToken %")
     try:

View File

@@ -7,6 +7,7 @@ import re
 import ujson
 import random
 import cytoolz
+import itertools
 from .syntax import nonproj
 from .util import ensure_path
@@ -146,9 +147,13 @@ def minibatch(items, size=8):
     '''Iterate over batches of items. `size` may be an iterator,
     so that batch-size can vary on each step.
     '''
+    if isinstance(size, int):
+        size_ = itertools.repeat(8)
+    else:
+        size_ = size
     items = iter(items)
     while True:
-        batch_size = next(size) #if hasattr(size, '__next__') else size
+        batch_size = next(size_)
         batch = list(cytoolz.take(int(batch_size), items))
         if len(batch) == 0:
             break
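A usage sketch for the patched minibatch() (imported from spacy.gold, as the examples above do). Note that, as written, an integer size falls back to itertools.repeat(8), so passing an iterator of sizes is how the batch size is actually controlled:

import itertools
from spacy.gold import minibatch

items = list(range(20))

# Fixed-size batches: repeat the size you want.
for batch in minibatch(items, size=itertools.repeat(5)):
    print(batch)                    # four batches of 5

# Variable-size batches: 2 items, then 4, then 8 until the data runs out.
sizes = itertools.chain([2, 4], itertools.repeat(8))
for batch in minibatch(items, size=sizes):
    print(batch)                    # batches of 2, 4, 8 and a final 6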

View File

@@ -14,8 +14,8 @@ class Chinese(Language):
         except ImportError:
             raise ImportError("The Chinese tokenizer requires the Jieba library: "
                               "https://github.com/fxsjy/jieba")
-        words = list(jieba.cut(text, cut_all=True))
-        words=[x for x in words if x]
+        words = list(jieba.cut(text, cut_all=False))
+        words = [x for x in words if x]
         return Doc(self.vocab, words=words, spaces=[False]*len(words))
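The cut_all switch is what this fix is about: jieba's "full mode" (cut_all=True) emits overlapping candidate segments that can't line up into a single tokenization, while the default accurate mode yields one non-overlapping segmentation. A quick sketch (requires the jieba package; the sample text is the one from jieba's README):

import jieba

text = "我来到北京清华大学"
print(list(jieba.cut(text, cut_all=True)))    # full mode: overlapping candidate words
print(list(jieba.cut(text, cut_all=False)))   # accurate mode: one segmentation per character span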

View File

@@ -346,15 +346,9 @@ class Language(object):
         """Allocate models, pre-process training data and acquire a trainer and
         optimizer. Used as a contextmanager.
-        gold_tuples (iterable): Gold-standard training data.
+        get_gold_tuples (function): Function returning gold data
         **cfg: Config parameters.
-        YIELDS (tuple): A trainer and an optimizer.
-        EXAMPLE:
-            >>> with nlp.begin_training(gold, use_gpu=True) as (trainer, optimizer):
-            >>>     for epoch in trainer.epochs(gold):
-            >>>         for docs, golds in epoch:
-            >>>             state = nlp.update(docs, golds, sgd=optimizer)
+        returns: An optimizer
         """
         if self.parser:
             self.pipeline.append(NeuralLabeller(self.vocab))
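The new calling convention, as used by the updated examples and spacy/cli/train.py in this commit: begin_training() now returns an optimizer directly rather than acting as a contextmanager. A rough, untested sketch against the 2.0.0a13 nightly API (the label and training data are placeholders):

import random
import spacy
from spacy.gold import GoldParse, minibatch
from spacy.pipeline import TokenVectorEncoder, NeuralEntityRecognizer

nlp = spacy.blank('en')
nlp.pipeline.append(TokenVectorEncoder(nlp.vocab))
nlp.pipeline.append(NeuralEntityRecognizer(nlp.vocab))
nlp.pipeline[-1].add_label('GPE')

train_data = [("I like London.", [(7, 13, 'GPE')])]
optimizer = nlp.begin_training(lambda: [])       # returns an optimizer
for itn in range(5):
    random.shuffle(train_data)
    for batch in minibatch(train_data, size=8):
        docs = [nlp.make_doc(text) for text, _ in batch]
        golds = [GoldParse(doc, entities=ents)
                 for doc, (_, ents) in zip(docs, batch)]
        nlp.update(docs, golds, sgd=optimizer, drop=0.35)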

View File

@@ -38,7 +38,8 @@ class Lemmatizer(object):
         avoid lemmatization entirely.
         """
         morphology = {} if morphology is None else morphology
-        others = [key for key in morphology if key not in (POS, 'number', 'pos', 'verbform')]
+        others = [key for key in morphology
+                  if key not in (POS, 'Number', 'POS', 'VerbForm', 'Tense')]
         true_morph_key = morphology.get('morph', 0)
         if univ_pos == 'noun' and morphology.get('Number') == 'sing':
             return True
@@ -47,7 +48,9 @@ class Lemmatizer(object):
         # This maps 'VBP' to base form -- probably just need 'IS_BASE'
         # morphology
         elif univ_pos == 'verb' and (morphology.get('VerbForm') == 'fin' and \
-                                     morphology.get('Tense') == 'pres'):
+                                     morphology.get('Tense') == 'pres' and \
+                                     morphology.get('Number') is None and \
+                                     not others):
             return True
         elif univ_pos == 'adj' and morphology.get('Degree') == 'pos':
             return True
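A standalone restatement of the tightened base-form rule for verbs (a sketch using plain string keys; the real method also special-cases nouns and adjectives as shown above):

def is_base_verb(morphology=None):
    morphology = {} if morphology is None else morphology
    others = [key for key in morphology
              if key not in ('Number', 'POS', 'VerbForm', 'Tense')]
    return (morphology.get('VerbForm') == 'fin'
            and morphology.get('Tense') == 'pres'
            and morphology.get('Number') is None
            and not others)

print(is_base_verb({'VerbForm': 'fin', 'Tense': 'pres'}))                    # True
print(is_base_verb({'VerbForm': 'fin', 'Tense': 'pres', 'Number': 'sing'}))  # False: 'works' still gets lemmatized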

View File

@ -1,4 +1,4 @@
cpdef enum symbol_t: cdef enum symbol_t:
NIL NIL
IS_ALPHA IS_ALPHA
IS_ASCII IS_ASCII

View File

@ -1,4 +1,6 @@
# coding: utf8 # coding: utf8
#cython: optimize.unpack_method_calls=False
from __future__ import unicode_literals from __future__ import unicode_literals
IDS = { IDS = {
@ -458,4 +460,11 @@ IDS = {
"xcomp": xcomp "xcomp": xcomp
} }
NAMES = [it[0] for it in sorted(IDS.items(), key=lambda it: it[1])] def sort_nums(x):
return x[1]
NAMES = [it[0] for it in sorted(IDS.items(), key=sort_nums)]
# Unfortunate hack here, to work around problem with long cpdef enum
# (which is generating an enormous amount of C++ in Cython 0.24+)
# We keep the enum cdef, and just make sure the names are available to Python
locals().update(IDS)
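The shape of the workaround in plain Python, with a toy IDS table standing in for the real symbol dict: the enum stays cdef (invisible from Python), and the names are re-exposed on the module via locals().update():

IDS = {'NIL': 0, 'IS_ALPHA': 1, 'IS_ASCII': 2}   # toy stand-in for the real table

def sort_nums(x):
    return x[1]

NAMES = [name for name, _ in sorted(IDS.items(), key=sort_nums)]
locals().update(IDS)      # at module level this defines NIL, IS_ALPHA, IS_ASCII

print(NAMES)              # ['NIL', 'IS_ALPHA', 'IS_ASCII']
print(IS_ALPHA)           # 1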

View File

@@ -121,6 +121,8 @@ cdef cppclass StateC:
         for i in range(n):
             if ids[i] >= 0:
                 ids[i] += this.offset
+            else:
+                ids[i] = -1
    int S(int i) nogil const:
        if i >= this._s_i:
@@ -163,9 +165,9 @@ cdef cppclass StateC:
    int E(int i) nogil const:
        if this._e_i <= 0 or this._e_i >= this.length:
-            return 0
+            return -1
        if i < 0 or i >= this._e_i:
-            return 0
+            return -1
        return this._ents[this._e_i - (i+1)].start
    int L(int i, int idx) nogil const:

View File

@@ -161,8 +161,7 @@ cdef class BiluoPushDown(TransitionSystem):
     cdef Transition lookup_transition(self, object name) except *:
         cdef attr_t label
         if name == '-' or name == None:
-            move_str = 'M'
-            label = 0
+            return Transition(clas=0, move=MISSING, label=0, score=0)
         elif name == '!O':
             return Transition(clas=0, move=ISNT, label=0, score=0)
         elif '-' in name:
@@ -220,6 +219,31 @@ cdef class BiluoPushDown(TransitionSystem):
             raise Exception(move)
         return t
+    def add_action(self, int action, label_name):
+        cdef attr_t label_id
+        if not isinstance(label_name, (int, long)):
+            label_id = self.strings.add(label_name)
+        else:
+            label_id = label_name
+        if action == OUT and label_id != 0:
+            return
+        if action == MISSING or action == ISNT:
+            return
+        # Check we're not creating a move we already have, so that this is
+        # idempotent
+        for trans in self.c[:self.n_moves]:
+            if trans.move == action and trans.label == label_id:
+                return 0
+        if self.n_moves >= self._size:
+            self._size *= 2
+            self.c = <Transition*>self.mem.realloc(self.c, self._size * sizeof(self.c[0]))
+        self.c[self.n_moves] = self.init_transition(self.n_moves, action, label_id)
+        assert self.c[self.n_moves].label == label_id
+        self.n_moves += 1
+        return 1
     cdef int initialize_state(self, StateC* st) nogil:
         # This is especially necessary when we use limited training data.
         for i in range(st.length):

View File

@@ -262,8 +262,8 @@ cdef class Parser:
             upper.is_noop = True
         else:
             upper = chain(
-                clone(Maxout(hidden_width), (depth-1)),
-                zero_init(Affine(nr_class, drop_factor=0.0))
+                clone(Maxout(hidden_width), depth-1),
+                zero_init(Affine(nr_class, hidden_width, drop_factor=0.0))
             )
             upper.is_noop = False
         # TODO: This is an unfortunate hack atm!
@@ -395,7 +395,6 @@ cdef class Parser:
             tokvecs = self.model[0].ops.flatten(self.model[0]((docs, tokvecses)))
         else:
             tokvecs = self.model[0].ops.flatten(tokvecses)
-
         nr_state = len(docs)
         nr_class = self.moves.n_moves
         nr_dim = tokvecs.shape[1]
@@ -421,7 +420,7 @@ cdef class Parser:
         cdef int has_hidden = not getattr(vec2scores, 'is_noop', False)
         while not next_step.empty():
             if not has_hidden:
-                for i in cython.parallel.prange(
+                for i in range(
                         next_step.size(), num_threads=6, nogil=True):
                     self._parse_step(next_step[i],
                         feat_weights, nr_class, nr_feat, nr_piece)
@@ -528,7 +527,6 @@ cdef class Parser:
         if losses is not None and self.name not in losses:
             losses[self.name] = 0.
         docs, tokvec_lists = docs_tokvecs
-        tokvecs = self.model[0].ops.flatten(tokvec_lists)
         if isinstance(docs, Doc) and isinstance(golds, GoldParse):
             docs = [docs]
             golds = [golds]
@@ -609,7 +607,7 @@ cdef class Parser:
         assert min(lengths) >= 1
         if USE_FINE_TUNE:
             my_tokvecs, bp_my_tokvecs = self.model[0].begin_update(docs_tokvecs, drop=drop)
-            tokvecs += self.model[0].ops.flatten(my_tokvecs)
+            tokvecs = self.model[0].ops.flatten(my_tokvecs)
         else:
             tokvecs = self.model[0].ops.flatten(tokvecs)
         states = self.moves.init_batch(docs)
@@ -647,6 +645,15 @@ cdef class Parser:
             d_tokvecs = bp_my_tokvecs(d_tokvecs, sgd=sgd)
         return d_tokvecs
+    def _pad_tokvecs(self, tokvecs):
+        # Add a vector for missing values at the start of tokvecs
+        xp = get_array_module(tokvecs)
+        pad = xp.zeros((1, tokvecs.shape[1]), dtype=tokvecs.dtype)
+        return xp.vstack((pad, tokvecs))
+    def _unpad_tokvecs(self, d_tokvecs):
+        return d_tokvecs[1:]
     def _init_gold_batch(self, whole_docs, whole_golds):
         """Make a square batch, of length equal to the shortest doc. A long
         doc will get multiple states. Let's say we have a doc of length 2*N,
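The two new helpers add and strip a leading zero row that stands in for missing token vectors. The same idea in plain numpy (get_array_module in the real code just dispatches between numpy and cupy):

import numpy

def pad_tokvecs(tokvecs):
    # Prepend one zero vector to represent missing values.
    pad = numpy.zeros((1, tokvecs.shape[1]), dtype=tokvecs.dtype)
    return numpy.vstack((pad, tokvecs))

def unpad_tokvecs(d_tokvecs):
    # Drop the gradient row belonging to the padding vector.
    return d_tokvecs[1:]

X = numpy.ones((3, 4), dtype='float32')
padded = pad_tokvecs(X)
assert padded.shape == (4, 4) and not padded[0].any()
assert unpad_tokvecs(padded).shape == X.shape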

View File

@@ -148,7 +148,7 @@ cdef class TransitionSystem:
     def add_action(self, int action, label_name):
         cdef attr_t label_id
-        if not isinstance(label_name, int):
+        if not isinstance(label_name, (int, long)):
             label_id = self.strings.add(label_name)
         else:
             label_id = label_name

View File

@@ -0,0 +1,8 @@
+import pytest
+
+@pytest.mark.models('en')
+def test_issue1305(EN):
+    '''Test lemmatization of English VBZ'''
+    assert EN.vocab.morphology.lemmatizer('works', 'verb') == set(['work'])
+    doc = EN(u'This app works well')
+    assert doc[2].lemma_ == 'work'

View File

@@ -9,11 +9,14 @@ import pytest
 @pytest.mark.models('en')
 def test_issue429(EN):
     def merge_phrases(matcher, doc, i, matches):
         if i != len(matches) - 1:
             return None
         spans = [(ent_id, ent_id, doc[start:end]) for ent_id, start, end in matches]
         for ent_id, label, span in spans:
-            span.merge('NNP' if label else span.root.tag_, span.text, EN.vocab.strings[label])
+            span.merge(
+                tag=('NNP' if label else span.root.tag_),
+                lemma=span.text,
+                label='PERSON')
     doc = EN('a')
     matcher = Matcher(EN.vocab)

View File

@@ -6,6 +6,16 @@ from ...strings import StringStore
 import pytest
+def test_string_hash(stringstore):
+    '''Test that string hashing is stable across platforms'''
+    ss = stringstore
+    assert ss.add('apple') == 8566208034543834098
+    heart = '\U0001f499'
+    print(heart)
+    h = ss.add(heart)
+    assert h == 11841826740069053588
 def test_stringstore_from_api_docs(stringstore):
     apple_hash = stringstore.add('apple')
     assert apple_hash == 8566208034543834098

View File

@ -1,6 +1,7 @@
# coding: utf-8 # coding: utf-8
from __future__ import unicode_literals from __future__ import unicode_literals
import sys
import pytest import pytest
@ -37,9 +38,10 @@ def test_tokenizer_excludes_false_pos_emoticons(tokenizer, text, length):
tokens = tokenizer(text) tokens = tokenizer(text)
assert len(tokens) == length assert len(tokens) == length
@pytest.mark.parametrize('text,length', [('can you still dunk?🍕🍔😵LOL', 8), @pytest.mark.parametrize('text,length', [('can you still dunk?🍕🍔😵LOL', 8),
('i💙you', 3), ('🤘🤘yay!', 4)]) ('i💙you', 3), ('🤘🤘yay!', 4)])
def test_tokenizer_handles_emoji(tokenizer, text, length): def test_tokenizer_handles_emoji(tokenizer, text, length):
tokens = tokenizer(text) # These break on narrow unicode builds, e.g. Windows
assert len(tokens) == length if sys.maxunicode >= 1114111:
tokens = tokenizer(text)
assert len(tokens) == length

View File

@@ -17,6 +17,7 @@ fi
 if [ "${VIA}" == "compile" ]; then
     pip install -r requirements.txt
+    python setup.py build_ext --inplace
     pip install -e .
 fi

View File

@@ -282,7 +282,7 @@ p
         def __call__(self, text):
             words = text.split(' ')
             # All tokens 'own' a subsequent space character in this tokenizer
-            spaces = [True] * len(word)
+            spaces = [True] * len(words)
             return Doc(self.vocab, words=words, spaces=spaces)
 p
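A runnable version of the docs snippet with the len(words) fix applied (the class name and the driver lines at the bottom are illustrative, not part of the docs diff):

from spacy.tokens import Doc
from spacy.vocab import Vocab

class WhitespaceTokenizer(object):
    def __init__(self, vocab):
        self.vocab = vocab

    def __call__(self, text):
        words = text.split(' ')
        # All tokens 'own' a subsequent space character in this tokenizer
        spaces = [True] * len(words)
        return Doc(self.vocab, words=words, spaces=spaces)

tokenizer = WhitespaceTokenizer(Vocab())
doc = tokenizer("What's happened to me? he thought.")
print([token.text for token in doc])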