Merge pull request #2019 from explosion/feature/better-gold

Make Levenshtein alignment faster, bug fixes to parser, add UD parsing script
This commit is contained in:
Matthew Honnibal 2018-02-23 04:41:26 +01:00 committed by GitHub
commit dd3ebe4931
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 566 additions and 130 deletions

303
examples/training/conllu.py Normal file
View File

@ -0,0 +1,303 @@
'''Train for CONLL 2017 UD treebank evaluation. Takes .conllu files, writes
.conllu format for development data, allowing the official scorer to be used.
'''
from __future__ import unicode_literals
import plac
import tqdm
import re
import sys
import spacy
import spacy.util
from spacy.tokens import Doc
from spacy.gold import GoldParse, minibatch
from spacy.syntax.nonproj import projectivize
from collections import Counter
from timeit import default_timer as timer
from spacy._align import align
def prevent_bad_sentences(doc):
'''This is an example pipeline component for fixing sentence segmentation
mistakes. The component sets is_sent_start to False, which means the
parser will be prevented from making a sentence boundary there. The
rules here aren't necessarily a good idea.'''
for token in doc[1:]:
if token.nbor(-1).text == ',':
token.is_sent_start = False
elif not token.nbor(-1).whitespace_:
token.is_sent_start = False
elif not token.nbor(-1).is_punct:
token.is_sent_start = False
elif token.nbor(-1).is_left_punct:
token.is_sent_start = False
return doc
def load_model(lang):
'''This shows how to adjust the tokenization rules, to special-case
for ways the CoNLLU tokenization differs. We need to get the tokenizer
accuracy high on the various treebanks in order to do well. If we don't
align on a content word, all dependencies to and from that word will
be marked as incorrect.
'''
English = spacy.util.get_lang_class(lang)
English.Defaults.infixes += ('(?<=[^-\d])[+\-\*^](?=[^-\d])',)
English.Defaults.infixes += ('(?<=[^-])[+\-\*^](?=[^-\d])',)
English.Defaults.infixes += ('(?<=[^-\d])[+\-\*^](?=[^-])',)
English.Defaults.token_match = re.compile(r'=+').match
nlp = English()
nlp.tokenizer.add_special_case('***', [{'ORTH': '***'}])
nlp.tokenizer.add_special_case("):", [{'ORTH': ")"}, {"ORTH": ":"}])
nlp.tokenizer.add_special_case("and/or", [{'ORTH': "and"}, {"ORTH": "/"}, {"ORTH": "or"}])
nlp.tokenizer.add_special_case("non-Microsoft", [{'ORTH': "non-Microsoft"}])
nlp.tokenizer.add_special_case("mis-matches", [{'ORTH': "mis-matches"}])
nlp.tokenizer.add_special_case("X.", [{'ORTH': "X"}, {"ORTH": "."}])
nlp.tokenizer.add_special_case("b/c", [{'ORTH': "b/c"}])
return nlp
def get_token_acc(docs, golds):
'''Quick function to evaluate tokenization accuracy.'''
miss = 0
hit = 0
for doc, gold in zip(docs, golds):
for i in range(len(doc)):
token = doc[i]
align = gold.words[i]
if align == None:
miss += 1
else:
hit += 1
return miss, hit
def golds_to_gold_tuples(docs, golds):
'''Get out the annoying 'tuples' format used by begin_training, given the
GoldParse objects.'''
tuples = []
for doc, gold in zip(docs, golds):
text = doc.text
ids, words, tags, heads, labels, iob = zip(*gold.orig_annot)
sents = [((ids, words, tags, heads, labels, iob), [])]
tuples.append((text, sents))
return tuples
def split_text(text):
return [par.strip().replace('\n', ' ')
for par in text.split('\n\n')]
def read_data(nlp, conllu_file, text_file, raw_text=True, oracle_segments=False,
limit=None):
'''Read the CONLLU format into (Doc, GoldParse) tuples. If raw_text=True,
include Doc objects created using nlp.make_doc and then aligned against
the gold-standard sequences. If oracle_segments=True, include Doc objects
created from the gold-standard segments. At least one must be True.'''
if not raw_text and not oracle_segments:
raise ValueError("At least one of raw_text or oracle_segments must be True")
paragraphs = split_text(text_file.read())
conllu = read_conllu(conllu_file)
# sd is spacy doc; cd is conllu doc
# cs is conllu sent, ct is conllu token
docs = []
golds = []
for doc_id, (text, cd) in enumerate(zip(paragraphs, conllu)):
doc_words = []
doc_tags = []
doc_heads = []
doc_deps = []
doc_ents = []
for cs in cd:
sent_words = []
sent_tags = []
sent_heads = []
sent_deps = []
for id_, word, lemma, pos, tag, morph, head, dep, _1, _2 in cs:
if '.' in id_:
continue
if '-' in id_:
continue
id_ = int(id_)-1
head = int(head)-1 if head != '0' else id_
sent_words.append(word)
sent_tags.append(tag)
sent_heads.append(head)
sent_deps.append('ROOT' if dep == 'root' else dep)
if oracle_segments:
sent_heads, sent_deps = projectivize(sent_heads, sent_deps)
docs.append(Doc(nlp.vocab, words=sent_words))
golds.append(GoldParse(docs[-1], words=sent_words, heads=sent_heads,
tags=sent_tags, deps=sent_deps,
entities=['-']*len(sent_words)))
for head in sent_heads:
doc_heads.append(len(doc_words)+head)
doc_words.extend(sent_words)
doc_tags.extend(sent_tags)
doc_deps.extend(sent_deps)
doc_ents.extend(['-']*len(sent_words))
# Create a GoldParse object for the sentence
doc_heads, doc_deps = projectivize(doc_heads, doc_deps)
if raw_text:
docs.append(nlp.make_doc(text))
golds.append(GoldParse(docs[-1], words=doc_words, tags=doc_tags,
heads=doc_heads, deps=doc_deps,
entities=doc_ents))
if limit and doc_id >= limit:
break
return docs, golds
def refresh_docs(docs):
vocab = docs[0].vocab
return [Doc(vocab, words=[t.text for t in doc],
spaces=[t.whitespace_ for t in doc])
for doc in docs]
def read_conllu(file_):
docs = []
doc = None
sent = []
for line in file_:
if line.startswith('# newdoc'):
if doc:
docs.append(doc)
doc = []
elif line.startswith('#'):
continue
elif not line.strip():
if sent:
if doc is None:
docs.append([sent])
else:
doc.append(sent)
sent = []
else:
sent.append(line.strip().split())
if sent:
if doc is None:
docs.append([sent])
else:
doc.append(sent)
if doc:
docs.append(doc)
return docs
def parse_dev_data(nlp, text_loc, conllu_loc, oracle_segments=False,
joint_sbd=True):
with open(text_loc) as text_file:
with open(conllu_loc) as conllu_file:
docs, golds = read_data(nlp, conllu_file, text_file,
oracle_segments=oracle_segments)
if joint_sbd:
pass
else:
sbd = nlp.create_pipe('sentencizer')
for doc in docs:
doc = sbd(doc)
for sent in doc.sents:
sent[0].is_sent_start = True
for word in sent[1:]:
word.is_sent_start = False
scorer = nlp.evaluate(zip(docs, golds))
return docs, scorer
def print_progress(itn, losses, scorer):
scores = {}
for col in ['dep_loss', 'tag_loss', 'uas', 'tags_acc', 'token_acc',
'ents_p', 'ents_r', 'ents_f', 'cpu_wps', 'gpu_wps']:
scores[col] = 0.0
scores['dep_loss'] = losses.get('parser', 0.0)
scores['ner_loss'] = losses.get('ner', 0.0)
scores['tag_loss'] = losses.get('tagger', 0.0)
scores.update(scorer.scores)
tpl = '\t'.join((
'{:d}',
'{dep_loss:.3f}',
'{ner_loss:.3f}',
'{uas:.3f}',
'{ents_p:.3f}',
'{ents_r:.3f}',
'{ents_f:.3f}',
'{tags_acc:.3f}',
'{token_acc:.3f}',
))
print(tpl.format(itn, **scores))
def print_conllu(docs, file_):
for i, doc in enumerate(docs):
file_.write("# newdoc id = {i}\n".format(i=i))
for j, sent in enumerate(doc.sents):
file_.write("# sent_id = {i}.{j}\n".format(i=i, j=j))
file_.write("# text = {text}\n".format(text=sent.text))
for k, t in enumerate(sent):
if t.head.i == t.i:
head = 0
else:
head = k + (t.head.i - t.i) + 1
fields = [str(k+1), t.text, t.lemma_, t.pos_, t.tag_, '_',
str(head), t.dep_.lower(), '_', '_']
file_.write('\t'.join(fields) + '\n')
file_.write('\n')
def main(spacy_model, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev_loc,
output_loc):
nlp = load_model(spacy_model)
with open(conllu_train_loc) as conllu_file:
with open(text_train_loc) as text_file:
docs, golds = read_data(nlp, conllu_file, text_file,
oracle_segments=True, raw_text=True,
limit=None)
print("Create parser")
nlp.add_pipe(nlp.create_pipe('parser'))
nlp.add_pipe(nlp.create_pipe('tagger'))
for gold in golds:
for tag in gold.tags:
if tag is not None:
nlp.tagger.add_label(tag)
optimizer = nlp.begin_training(lambda: golds_to_gold_tuples(docs, golds))
# Replace labels that didn't make the frequency cutoff
actions = set(nlp.parser.labels)
label_set = set([act.split('-')[1] for act in actions if '-' in act])
for gold in golds:
for i, label in enumerate(gold.labels):
if label is not None and label not in label_set:
gold.labels[i] = label.split('||')[0]
n_train_words = sum(len(doc) for doc in docs)
print(n_train_words)
print("Begin training")
# Batch size starts at 1 and grows, so that we make updates quickly
# at the beginning of training.
batch_sizes = spacy.util.compounding(spacy.util.env_opt('batch_from', 8),
spacy.util.env_opt('batch_to', 8),
spacy.util.env_opt('batch_compound', 1.001))
for i in range(30):
docs = refresh_docs(docs)
batches = minibatch(list(zip(docs, golds)), size=batch_sizes)
with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
losses = {}
for batch in batches:
if not batch:
continue
batch_docs, batch_gold = zip(*batch)
nlp.update(batch_docs, batch_gold, sgd=optimizer,
drop=0.2, losses=losses)
pbar.update(sum(len(doc) for doc in batch_docs))
with nlp.use_params(optimizer.averages):
dev_docs, scorer = parse_dev_data(nlp, text_dev_loc, conllu_dev_loc,
oracle_segments=False, joint_sbd=True)
print_progress(i, losses, scorer)
with open(output_loc, 'w') as file_:
print_conllu(dev_docs, file_)
dev_docs, scorer = parse_dev_data(nlp, text_dev_loc, conllu_dev_loc,
oracle_segments=False, joint_sbd=False)
print_progress(i, losses, scorer)
if __name__ == '__main__':
plac.call(main)

View File

@ -18,6 +18,7 @@ PACKAGES = find_packages()
MOD_NAMES = [
'spacy._align',
'spacy.parts_of_speech',
'spacy.strings',
'spacy.lexeme',

175
spacy/_align.pyx Normal file
View File

@ -0,0 +1,175 @@
# cython: infer_types=True
'''Do Levenshtein alignment, for evaluation of tokenized input.
Random notes:
r i n g
0 1 2 3 4
r 1 0 1 2 3
a 2 1 1 2 3
n 3 2 2 1 2
g 4 3 3 2 1
0,0: (1,1)=min(0+0,1+1,1+1)=0 S
1,0: (2,1)=min(1+1,0+1,2+1)=1 D
2,0: (3,1)=min(2+1,3+1,1+1)=2 D
3,0: (4,1)=min(3+1,4+1,2+1)=3 D
0,1: (1,2)=min(1+1,2+1,0+1)=1 D
1,1: (2,2)=min(0+1,1+1,1+1)=1 S
2,1: (3,2)=min(1+1,1+1,2+1)=2 S or I
3,1: (4,2)=min(2+1,2+1,3+1)=3 S or I
0,2: (1,3)=min(2+1,3+1,1+1)=2 I
1,2: (2,3)=min(1+1,2+1,1+1)=2 S or I
2,2: (3,3)
3,2: (4,3)
At state (i, j) we're asking "How do I transform S[:i+1] to T[:j+1]?"
We know the costs to transition:
S[:i] -> T[:j] (at D[i,j])
S[:i+1] -> T[:j] (at D[i+1,j])
S[:i] -> T[:j+1] (at D[i,j+1])
Further, we now we can tranform:
S[:i+1] -> S[:i] (DEL) for 1,
T[:j+1] -> T[:j] (INS) for 1.
S[i+1] -> T[j+1] (SUB) for 0 or 1
Therefore we have the costs:
SUB: Cost(S[:i]->T[:j]) + Cost(S[i]->S[j])
i.e. D[i, j] + S[i+1] != T[j+1]
INS: Cost(S[:i+1]->T[:j]) + Cost(T[:j+1]->T[:j])
i.e. D[i+1,j] + 1
DEL: Cost(S[:i]->T[:j+1]) + Cost(S[:i+1]->S[:i])
i.e. D[i,j+1] + 1
Source string S has length m, with index i
Target string T has length n, with index j
Output two alignment vectors: i2j (length m) and j2i (length n)
# function LevenshteinDistance(char s[1..m], char t[1..n]):
# for all i and j, d[i,j] will hold the Levenshtein distance between
# the first i characters of s and the first j characters of t
# note that d has (m+1)*(n+1) values
# set each element in d to zero
ring rang
- r i n g
- 0 0 0 0 0
r 0 0 0 0 0
a 0 0 0 0 0
n 0 0 0 0 0
g 0 0 0 0 0
# source prefixes can be transformed into empty string by
# dropping all characters
# d[i, 0] := i
ring rang
- r i n g
- 0 0 0 0 0
r 1 0 0 0 0
a 2 0 0 0 0
n 3 0 0 0 0
g 4 0 0 0 0
# target prefixes can be reached from empty source prefix
# by inserting every character
# d[0, j] := j
- r i n g
- 0 1 2 3 4
r 1 0 0 0 0
a 2 0 0 0 0
n 3 0 0 0 0
g 4 0 0 0 0
'''
import numpy
cimport numpy as np
from .compat import unicode_
from murmurhash.mrmr cimport hash32
def align(S, T):
cdef int m = len(S)
cdef int n = len(T)
cdef np.ndarray matrix = numpy.zeros((m+1, n+1), dtype='int32')
cdef np.ndarray i2j = numpy.zeros((m,), dtype='i')
cdef np.ndarray j2i = numpy.zeros((n,), dtype='i')
cdef np.ndarray S_arr = _convert_sequence(S)
cdef np.ndarray T_arr = _convert_sequence(T)
fill_matrix(<int*>matrix.data,
<const int*>S_arr.data, m, <const int*>T_arr.data, n)
fill_i2j(i2j, matrix)
fill_j2i(j2i, matrix)
return matrix[-1,-1], i2j, j2i, matrix
def _convert_sequence(seq):
if isinstance(seq, numpy.ndarray):
return numpy.ascontiguousarray(seq, dtype='i')
cdef np.ndarray output = numpy.zeros((len(seq),), dtype='i')
cdef bytes item_bytes
for i, item in enumerate(seq):
if isinstance(item, unicode):
item_bytes = item.encode('utf8')
else:
item_bytes = item
output[i] = hash32(<void*><char*>item_bytes, len(item_bytes), 0)
return output
cdef void fill_matrix(int* D,
const int* S, int m, const int* T, int n) nogil:
m1 = m+1
n1 = n+1
for i in range(m1*n1):
D[i] = 0
for i in range(m1):
D[i*n1] = i
for j in range(n1):
D[j] = j
cdef int sub_cost, ins_cost, del_cost
for j in range(n):
for i in range(m):
i_j = i*n1 + j
i1_j1 = (i+1)*n1 + j+1
i1_j = (i+1)*n1 + j
i_j1 = i*n1 + j+1
if S[i] != T[j]:
sub_cost = D[i_j] + 1
else:
sub_cost = D[i_j]
del_cost = D[i_j1] + 1
ins_cost = D[i1_j] + 1
best = min(min(sub_cost, ins_cost), del_cost)
D[i1_j1] = best
cdef void fill_i2j(np.ndarray i2j, np.ndarray D) except *:
j = D.shape[1]-2
cdef int i = D.shape[0]-2
while i >= 0:
while D[i+1, j] < D[i+1, j+1]:
j -= 1
if D[i, j+1] < D[i+1, j+1]:
i2j[i] = -1
else:
i2j[i] = j
j -= 1
i -= 1
cdef void fill_j2i(np.ndarray j2i, np.ndarray D) except *:
i = D.shape[0]-2
cdef int j = D.shape[1]-2
while j >= 0:
while D[i, j+1] < D[i+1, j+1]:
i -= 1
if D[i+1, j] < D[i+1, j+1]:
j2i[j] = -1
else:
j2i[j] = i
i -= 1
j -= 1

View File

@ -7,7 +7,9 @@ import ujson
import random
import cytoolz
import itertools
import numpy
from . import _align
from .syntax import nonproj
from .tokens import Doc
from . import util
@ -59,90 +61,15 @@ def merge_sents(sents):
return [(m_deps, m_brackets)]
def align(cand_words, gold_words):
cost, edit_path = _min_edit_path(cand_words, gold_words)
alignment = []
i_of_gold = 0
for move in edit_path:
if move == 'M':
alignment.append(i_of_gold)
i_of_gold += 1
elif move == 'S':
alignment.append(None)
i_of_gold += 1
elif move == 'D':
alignment.append(None)
elif move == 'I':
i_of_gold += 1
else:
raise Exception(move)
return alignment
punct_re = re.compile(r'\W')
def _min_edit_path(cand_words, gold_words):
cdef:
Pool mem
int i, j, n_cand, n_gold
int* curr_costs
int* prev_costs
# TODO: Fix this --- just do it properly, make the full edit matrix and
# then walk back over it...
# Preprocess inputs
def align(cand_words, gold_words):
cand_words = [punct_re.sub('', w).lower() for w in cand_words]
gold_words = [punct_re.sub('', w).lower() for w in gold_words]
if cand_words == gold_words:
return 0, ''.join(['M' for _ in gold_words])
mem = Pool()
n_cand = len(cand_words)
n_gold = len(gold_words)
# Levenshtein distance, except we need the history, and we may want
# different costs. Mark operations with a string, and score the history
# using _edit_cost.
previous_row = []
prev_costs = <int*>mem.alloc(n_gold + 1, sizeof(int))
curr_costs = <int*>mem.alloc(n_gold + 1, sizeof(int))
for i in range(n_gold + 1):
cell = ''
for j in range(i):
cell += 'I'
previous_row.append('I' * i)
prev_costs[i] = i
for i, cand in enumerate(cand_words):
current_row = ['D' * (i + 1)]
curr_costs[0] = i+1
for j, gold in enumerate(gold_words):
if gold.lower() == cand.lower():
s_cost = prev_costs[j]
i_cost = curr_costs[j] + 1
d_cost = prev_costs[j + 1] + 1
else:
s_cost = prev_costs[j] + 1
i_cost = curr_costs[j] + 1
d_cost = prev_costs[j + 1] + (1 if cand else 0)
if s_cost <= i_cost and s_cost <= d_cost:
best_cost = s_cost
best_hist = previous_row[j] + ('M' if gold == cand else 'S')
elif i_cost <= s_cost and i_cost <= d_cost:
best_cost = i_cost
best_hist = current_row[j] + 'I'
else:
best_cost = d_cost
best_hist = previous_row[j + 1] + 'D'
current_row.append(best_hist)
curr_costs[j+1] = best_cost
previous_row = current_row
for j in range(len(gold_words) + 1):
prev_costs[j] = curr_costs[j]
curr_costs[j] = 0
return prev_costs[n_gold], previous_row[-1]
alignment = numpy.arange(len(cand_words))
return 0, alignment, alignment
cost, i2j, j2i, matrix = _align.align(cand_words, gold_words)
return cost, i2j, j2i
class GoldCorpus(object):
@ -434,8 +361,9 @@ cdef class GoldParse:
self.labels = [None] * len(doc)
self.ner = [None] * len(doc)
self.cand_to_gold = align([t.orth_ for t in doc], words)
self.gold_to_cand = align(words, [t.orth_ for t in doc])
cost, i2j, j2i = align([t.orth_ for t in doc], words)
self.cand_to_gold = [(j if j != -1 else None) for j in i2j]
self.gold_to_cand = [(i if i != -1 else None) for i in j2i]
annot_tuples = (range(len(words)), words, tags, heads, deps, entities)
self.orig_annot = list(zip(*annot_tuples))

View File

@ -1,7 +1,7 @@
# coding: utf8
from __future__ import unicode_literals
from .symbols import POS, NOUN, VERB, ADJ, PUNCT
from .symbols import POS, NOUN, VERB, ADJ, PUNCT, PROPN
from .symbols import VerbForm_inf, VerbForm_none, Number_sing, Degree_pos
@ -27,11 +27,13 @@ class Lemmatizer(object):
univ_pos = 'adj'
elif univ_pos in (PUNCT, 'PUNCT', 'punct'):
univ_pos = 'punct'
elif univ_pos in (PROPN, 'PROPN'):
return [string]
else:
return list(set([string.lower()]))
return [string.lower()]
# See Issue #435 for example of where this logic is requied.
if self.is_base_form(univ_pos, morphology):
return list(set([string.lower()]))
return [string.lower()]
lemmas = lemmatize(string, self.index.get(univ_pos, {}),
self.exc.get(univ_pos, {}),
self.rules.get(univ_pos, []))
@ -88,6 +90,7 @@ class Lemmatizer(object):
def lemmatize(string, index, exceptions, rules):
orig = string
string = string.lower()
forms = []
forms.extend(exceptions.get(string, []))
@ -105,5 +108,5 @@ def lemmatize(string, index, exceptions, rules):
if not forms:
forms.extend(oov_forms)
if not forms:
forms.append(string)
forms.append(orig)
return list(set(forms))

View File

@ -110,7 +110,8 @@ cdef bint _is_gold_root(const GoldParseC* gold, int word) nogil:
cdef class Shift:
@staticmethod
cdef bint is_valid(const StateC* st, attr_t label) nogil:
return st.buffer_length() >= 2 and not st.shifted[st.B(0)] and st.B_(0).sent_start != 1
sent_start = st._sent[st.B_(0).l_edge].sent_start
return st.buffer_length() >= 2 and not st.shifted[st.B(0)] and sent_start != 1
@staticmethod
cdef int transition(StateC* st, attr_t label) nogil:
@ -170,7 +171,8 @@ cdef class Reduce:
cdef class LeftArc:
@staticmethod
cdef bint is_valid(const StateC* st, attr_t label) nogil:
return st.B_(0).sent_start != 1
sent_start = st._sent[st.B_(0).l_edge].sent_start
return sent_start != 1
@staticmethod
cdef int transition(StateC* st, attr_t label) nogil:
@ -205,7 +207,8 @@ cdef class RightArc:
@staticmethod
cdef bint is_valid(const StateC* st, attr_t label) nogil:
# If there's (perhaps partial) parse pre-set, don't allow cycle.
return st.B_(0).sent_start != 1 and st.H(st.S(0)) != st.B(0)
sent_start = st._sent[st.B_(0).l_edge].sent_start
return sent_start != 1 and st.H(st.S(0)) != st.B(0)
@staticmethod
cdef int transition(StateC* st, attr_t label) nogil:
@ -527,7 +530,12 @@ cdef class ArcEager(TransitionSystem):
is_valid[i] = False
costs[i] = 9000
if n_gold < 1:
# Check projectivity --- leading cause
# Check label set --- leading cause
label_set = set([self.strings[self.c[i].label] for i in range(self.n_moves)])
for label_str in gold.labels:
if label_str is not None and label_str not in label_set:
raise ValueError("Cannot get gold parser action: unknown label: %s" % label_str)
# Check projectivity --- other leading cause
if is_nonproj_tree(gold.heads):
raise ValueError(
"Could not find a gold-standard action to supervise the "

View File

@ -555,7 +555,10 @@ cdef class Parser:
for multitask in self._multitasks:
multitask.update(docs, golds, drop=drop, sgd=sgd)
cuda_stream = util.get_cuda_stream()
states, golds, max_steps = self._init_gold_batch(docs, golds)
# Chop sequences into lengths of this many transitions, to make the
# batch uniform length.
cut_gold = numpy.random.choice(range(20, 100))
states, golds, max_steps = self._init_gold_batch(docs, golds, max_length=cut_gold)
(tokvecs, bp_tokvecs), state2vec, vec2scores = self.get_batch_model(docs, cuda_stream,
drop)
todo = [(s, g) for (s, g) in zip(states, golds)
@ -659,7 +662,7 @@ cdef class Parser:
_cleanup(beam)
def _init_gold_batch(self, whole_docs, whole_golds):
def _init_gold_batch(self, whole_docs, whole_golds, min_length=5, max_length=500):
"""Make a square batch, of length equal to the shortest doc. A long
doc will get multiple states. Let's say we have a doc of length 2*N,
where N is the shortest doc. We'll make two states, one representing
@ -668,7 +671,7 @@ cdef class Parser:
StateClass state
Transition action
whole_states = self.moves.init_batch(whole_docs)
max_length = max(5, min(50, min([len(doc) for doc in whole_docs])))
max_length = max(min_length, min(max_length, min([len(doc) for doc in whole_docs])))
max_moves = 0
states = []
golds = []
@ -790,6 +793,11 @@ cdef class Parser:
for doc in docs:
hook(doc)
@property
def labels(self):
class_names = [self.moves.get_class_name(i) for i in range(self.moves.n_moves)]
return class_names
@property
def tok2vec(self):
'''Return the embedding and convolutional layer of the model.'''
@ -825,7 +833,7 @@ cdef class Parser:
if 'model' in cfg:
self.model = cfg['model']
gold_tuples = nonproj.preprocess_training_data(gold_tuples,
label_freq_cutoff=100)
label_freq_cutoff=30)
actions = self.moves.get_actions(gold_parses=gold_tuples)
for action, labels in actions.items():
for label in labels:

View File

@ -1,36 +0,0 @@
# coding: utf-8
"""Find the min-cost alignment between two tokenizations"""
from __future__ import unicode_literals
from ...gold import _min_edit_path as min_edit_path
from ...gold import align
import pytest
@pytest.mark.parametrize('cand,gold,path', [
(["U.S", ".", "policy"], ["U.S.", "policy"], (0, 'MDM')),
(["U.N", ".", "policy"], ["U.S.", "policy"], (1, 'SDM')),
(["The", "cat", "sat", "down"], ["The", "cat", "sat", "down"], (0, 'MMMM')),
(["cat", "sat", "down"], ["The", "cat", "sat", "down"], (1, 'IMMM')),
(["The", "cat", "down"], ["The", "cat", "sat", "down"], (1, 'MMIM')),
(["The", "cat", "sag", "down"], ["The", "cat", "sat", "down"], (1, 'MMSM'))])
def test_gold_lev_align_edit_path(cand, gold, path):
assert min_edit_path(cand, gold) == path
def test_gold_lev_align_edit_path2():
cand = ["your", "stuff"]
gold = ["you", "r", "stuff"]
assert min_edit_path(cand, gold) in [(2, 'ISM'), (2, 'SIM')]
@pytest.mark.parametrize('cand,gold,result', [
(["U.S", ".", "policy"], ["U.S.", "policy"], [0, None, 1]),
(["your", "stuff"], ["you", "r", "stuff"], [None, 2]),
(["i", "like", "2", "guys", " ", "well", "id", "just", "come", "straight", "out"],
["i", "like", "2", "guys", "well", "i", "d", "just", "come", "straight", "out"],
[0, 1, 2, 3, None, 4, None, 7, 8, 9, 10])])
def test_gold_lev_align(cand, gold, result):
assert align(cand, gold) == result

46
spacy/tests/test_align.py Normal file
View File

@ -0,0 +1,46 @@
import pytest
from .._align import align
@pytest.mark.parametrize('string1,string2,cost', [
('hello', 'hell', 1),
('rat', 'cat', 1),
('rat', 'rat', 0),
('rat', 'catsie', 4),
('t', 'catsie', 5),
])
def test_align_costs(string1, string2, cost):
output_cost, i2j, j2i, matrix = align(string1, string2)
assert output_cost == cost
@pytest.mark.parametrize('string1,string2,i2j', [
('hello', 'hell', [0,1,2,3,-1]),
('rat', 'cat', [0,1,2]),
('rat', 'rat', [0,1,2]),
('rat', 'catsie', [0,1,2]),
('t', 'catsie', [2]),
])
def test_align_i2j(string1, string2, i2j):
output_cost, output_i2j, j2i, matrix = align(string1, string2)
assert list(output_i2j) == i2j
@pytest.mark.parametrize('string1,string2,j2i', [
('hello', 'hell', [0,1,2,3]),
('rat', 'cat', [0,1,2]),
('rat', 'rat', [0,1,2]),
('rat', 'catsie', [0,1,2, -1, -1, -1]),
('t', 'catsie', [-1, -1, 0, -1, -1, -1]),
])
def test_align_i2j(string1, string2, j2i):
output_cost, output_i2j, output_j2i, matrix = align(string1, string2)
assert list(output_j2i) == j2i
def test_align_strings():
words1 = ['hello', 'this', 'is', 'test!']
words2 = ['hellothis', 'is', 'test', '!']
cost, i2j, j2i, matrix = align(words1, words2)
assert cost == 4
assert list(i2j) == [0, -1, 1, 2]
assert list(j2i) == [0, 2, 3, -1]