mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-02 02:43:36 +03:00
Merge branch 'punctparse'
This commit is contained in:
commit
b3f9b199cf
|
@ -9,6 +9,7 @@ import codecs
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
import gzip
|
import gzip
|
||||||
|
import nltk
|
||||||
|
|
||||||
import plac
|
import plac
|
||||||
import cProfile
|
import cProfile
|
||||||
|
@ -22,10 +23,15 @@ from spacy.syntax.parser import GreedyParser
|
||||||
from spacy.syntax.util import Config
|
from spacy.syntax.util import Config
|
||||||
|
|
||||||
|
|
||||||
|
def is_punct_label(label):
|
||||||
|
return label == 'P' or label.lower() == 'punct'
|
||||||
|
|
||||||
|
|
||||||
def read_tokenized_gold(file_):
|
def read_tokenized_gold(file_):
|
||||||
"""Read a standard CoNLL/MALT-style format"""
|
"""Read a standard CoNLL/MALT-style format"""
|
||||||
sents = []
|
sents = []
|
||||||
for sent_str in file_.read().strip().split('\n\n'):
|
for sent_str in file_.read().strip().split('\n\n'):
|
||||||
|
ids = []
|
||||||
words = []
|
words = []
|
||||||
heads = []
|
heads = []
|
||||||
labels = []
|
labels = []
|
||||||
|
@ -35,50 +41,137 @@ def read_tokenized_gold(file_):
|
||||||
words.append(word)
|
words.append(word)
|
||||||
if head_idx == -1:
|
if head_idx == -1:
|
||||||
head_idx = i
|
head_idx = i
|
||||||
|
ids.append(id_)
|
||||||
heads.append(head_idx)
|
heads.append(head_idx)
|
||||||
labels.append(label)
|
labels.append(label)
|
||||||
tags.append(pos_string)
|
tags.append(pos_string)
|
||||||
sents.append((words, heads, labels, tags))
|
sents.append((ids_, words, heads, labels, tags))
|
||||||
return sents
|
return sents
|
||||||
|
|
||||||
|
|
||||||
def read_docparse_gold(file_):
|
def read_docparse_gold(file_):
|
||||||
sents = []
|
paragraphs = []
|
||||||
for sent_str in file_.read().strip().split('\n\n'):
|
for sent_str in file_.read().strip().split('\n\n'):
|
||||||
words = []
|
words = []
|
||||||
heads = []
|
heads = []
|
||||||
labels = []
|
labels = []
|
||||||
tags = []
|
tags = []
|
||||||
|
ids = []
|
||||||
lines = sent_str.strip().split('\n')
|
lines = sent_str.strip().split('\n')
|
||||||
raw_text = lines[0]
|
raw_text = lines[0]
|
||||||
tok_text = lines[1]
|
tok_text = lines[1]
|
||||||
for i, line in enumerate(lines[2:]):
|
for i, line in enumerate(lines[2:]):
|
||||||
word, pos_string, head_idx, label = _parse_line(line)
|
id_, word, pos_string, head_idx, label = _parse_line(line)
|
||||||
|
if label == 'root':
|
||||||
|
label = 'ROOT'
|
||||||
words.append(word)
|
words.append(word)
|
||||||
if head_idx == -1:
|
if head_idx < 0:
|
||||||
head_idx = i
|
head_idx = id_
|
||||||
|
ids.append(id_)
|
||||||
heads.append(head_idx)
|
heads.append(head_idx)
|
||||||
labels.append(label)
|
labels.append(label)
|
||||||
tags.append(pos_string)
|
tags.append(pos_string)
|
||||||
words = tok_text.replace('<SEP>', ' ').replace('<SENT>', ' ').split(' ')
|
tokenized = [sent_str.replace('<SEP>', ' ').split(' ')
|
||||||
sents.append((words, heads, labels, tags))
|
for sent_str in tok_text.split('<SENT>')]
|
||||||
return sents
|
paragraphs.append((raw_text, tokenized, ids, words, tags, heads, labels))
|
||||||
|
return paragraphs
|
||||||
|
|
||||||
|
|
||||||
|
def _map_indices_to_tokens(ids, heads):
|
||||||
|
mapped = []
|
||||||
|
for head in heads:
|
||||||
|
if head not in ids:
|
||||||
|
mapped.append(None)
|
||||||
|
else:
|
||||||
|
mapped.append(ids.index(head))
|
||||||
|
return mapped
|
||||||
|
|
||||||
|
|
||||||
def _parse_line(line):
|
def _parse_line(line):
|
||||||
pieces = line.split()
|
pieces = line.split()
|
||||||
if len(pieces) == 4:
|
if len(pieces) == 4:
|
||||||
return pieces[0], pieces[1], int(pieces[2]) - 1, pieces[3]
|
return 0, pieces[0], pieces[1], int(pieces[2]) - 1, pieces[3]
|
||||||
else:
|
else:
|
||||||
|
id_ = int(pieces[0])
|
||||||
word = pieces[1]
|
word = pieces[1]
|
||||||
pos = pieces[3]
|
pos = pieces[3]
|
||||||
head_idx = int(pieces[6]) - 1
|
head_idx = int(pieces[6])
|
||||||
label = pieces[7]
|
label = pieces[7]
|
||||||
return word, pos, head_idx, label
|
return id_, word, pos, head_idx, label
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
loss = 0
|
||||||
|
def _align_annotations_to_non_gold_tokens(tokens, words, annot):
|
||||||
|
global loss
|
||||||
|
tags = []
|
||||||
|
heads = []
|
||||||
|
labels = []
|
||||||
|
orig_words = list(words)
|
||||||
|
missed = []
|
||||||
|
for token in tokens:
|
||||||
|
while annot and token.idx > annot[0][0]:
|
||||||
|
miss_id, miss_tag, miss_head, miss_label = annot.pop(0)
|
||||||
|
miss_w = words.pop(0)
|
||||||
|
if not is_punct_label(miss_label):
|
||||||
|
missed.append(miss_w)
|
||||||
|
loss += 1
|
||||||
|
if not annot:
|
||||||
|
tags.append(None)
|
||||||
|
heads.append(None)
|
||||||
|
labels.append(None)
|
||||||
|
continue
|
||||||
|
id_, tag, head, label = annot[0]
|
||||||
|
if token.idx == id_:
|
||||||
|
tags.append(tag)
|
||||||
|
heads.append(head)
|
||||||
|
labels.append(label)
|
||||||
|
annot.pop(0)
|
||||||
|
words.pop(0)
|
||||||
|
elif token.idx < id_:
|
||||||
|
tags.append(None)
|
||||||
|
heads.append(None)
|
||||||
|
labels.append(None)
|
||||||
|
else:
|
||||||
|
raise StandardError
|
||||||
|
#if missed:
|
||||||
|
# print orig_words
|
||||||
|
# print missed
|
||||||
|
# for t in tokens:
|
||||||
|
# print t.idx, t.orth_
|
||||||
|
return loss, tags, heads, labels
|
||||||
|
|
||||||
|
|
||||||
|
def iter_data(paragraphs, tokenizer, gold_preproc=False):
|
||||||
|
for raw, tokenized, ids, words, tags, heads, labels in paragraphs:
|
||||||
|
if not gold_preproc:
|
||||||
|
tokens = tokenizer(raw)
|
||||||
|
loss, tags, heads, labels = _align_annotations_to_non_gold_tokens(
|
||||||
|
tokens, list(words),
|
||||||
|
zip(ids, tags, heads, labels))
|
||||||
|
ids = [t.idx for t in tokens]
|
||||||
|
heads = _map_indices_to_tokens(ids, heads)
|
||||||
|
yield tokens, tags, heads, labels
|
||||||
|
else:
|
||||||
|
assert len(words) == len(heads)
|
||||||
|
for words in tokenized:
|
||||||
|
sent_ids = ids[:len(words)]
|
||||||
|
sent_tags = tags[:len(words)]
|
||||||
|
sent_heads = heads[:len(words)]
|
||||||
|
sent_labels = labels[:len(words)]
|
||||||
|
sent_heads = _map_indices_to_tokens(sent_ids, sent_heads)
|
||||||
|
tokens = tokenizer.tokens_from_list(words)
|
||||||
|
yield tokens, sent_tags, sent_heads, sent_labels
|
||||||
|
ids = ids[len(words):]
|
||||||
|
tags = tags[len(words):]
|
||||||
|
heads = heads[len(words):]
|
||||||
|
labels = labels[len(words):]
|
||||||
|
|
||||||
|
|
||||||
def get_labels(sents):
|
def get_labels(sents):
|
||||||
left_labels = set()
|
left_labels = set()
|
||||||
right_labels = set()
|
right_labels = set()
|
||||||
for _, heads, labels, _ in sents:
|
for raw, tokenized, ids, words, tags, heads, labels in sents:
|
||||||
for child, (head, label) in enumerate(zip(heads, labels)):
|
for child, (head, label) in enumerate(zip(heads, labels)):
|
||||||
if head > child:
|
if head > child:
|
||||||
left_labels.add(label)
|
left_labels.add(label)
|
||||||
|
@ -87,7 +180,8 @@ def get_labels(sents):
|
||||||
return list(sorted(left_labels)), list(sorted(right_labels))
|
return list(sorted(left_labels)), list(sorted(right_labels))
|
||||||
|
|
||||||
|
|
||||||
def train(Language, sents, model_dir, n_iter=15, feat_set=u'basic', seed=0):
|
def train(Language, paragraphs, model_dir, n_iter=15, feat_set=u'basic', seed=0,
|
||||||
|
gold_preproc=False):
|
||||||
dep_model_dir = path.join(model_dir, 'deps')
|
dep_model_dir = path.join(model_dir, 'deps')
|
||||||
pos_model_dir = path.join(model_dir, 'pos')
|
pos_model_dir = path.join(model_dir, 'pos')
|
||||||
if path.exists(dep_model_dir):
|
if path.exists(dep_model_dir):
|
||||||
|
@ -99,7 +193,7 @@ def train(Language, sents, model_dir, n_iter=15, feat_set=u'basic', seed=0):
|
||||||
setup_model_dir(sorted(POS_TAGS.keys()), POS_TAGS, POS_TEMPLATES,
|
setup_model_dir(sorted(POS_TAGS.keys()), POS_TAGS, POS_TEMPLATES,
|
||||||
pos_model_dir)
|
pos_model_dir)
|
||||||
|
|
||||||
left_labels, right_labels = get_labels(sents)
|
left_labels, right_labels = get_labels(paragraphs)
|
||||||
Config.write(dep_model_dir, 'config', features=feat_set, seed=seed,
|
Config.write(dep_model_dir, 'config', features=feat_set, seed=seed,
|
||||||
left_labels=left_labels, right_labels=right_labels)
|
left_labels=left_labels, right_labels=right_labels)
|
||||||
|
|
||||||
|
@ -109,58 +203,50 @@ def train(Language, sents, model_dir, n_iter=15, feat_set=u'basic', seed=0):
|
||||||
heads_corr = 0
|
heads_corr = 0
|
||||||
pos_corr = 0
|
pos_corr = 0
|
||||||
n_tokens = 0
|
n_tokens = 0
|
||||||
for words, heads, labels, tags in sents:
|
for tokens, tag_strs, heads, labels in iter_data(paragraphs, nlp.tokenizer,
|
||||||
tags = [nlp.tagger.tag_names.index(tag) for tag in tags]
|
gold_preproc=gold_preproc):
|
||||||
tokens = nlp.tokenizer.tokens_from_list(words)
|
|
||||||
nlp.tagger(tokens)
|
nlp.tagger(tokens)
|
||||||
heads_corr += nlp.parser.train_sent(tokens, heads, labels)
|
heads_corr += nlp.parser.train_sent(tokens, heads, labels, force_gold=False)
|
||||||
pos_corr += nlp.tagger.train(tokens, tags)
|
pos_corr += nlp.tagger.train(tokens, tag_strs)
|
||||||
n_tokens += len(tokens)
|
n_tokens += len(tokens)
|
||||||
acc = float(heads_corr) / n_tokens
|
acc = float(heads_corr) / n_tokens
|
||||||
pos_acc = float(pos_corr) / n_tokens
|
pos_acc = float(pos_corr) / n_tokens
|
||||||
print '%d: ' % itn, '%.3f' % acc, '%.3f' % pos_acc
|
print '%d: ' % itn, '%.3f' % acc, '%.3f' % pos_acc
|
||||||
random.shuffle(sents)
|
random.shuffle(paragraphs)
|
||||||
nlp.parser.model.end_training()
|
nlp.parser.model.end_training()
|
||||||
nlp.tagger.model.end_training()
|
nlp.tagger.model.end_training()
|
||||||
#nlp.parser.model.dump(path.join(dep_model_dir, 'model'), freq_thresh=0)
|
|
||||||
return acc
|
return acc
|
||||||
|
|
||||||
|
|
||||||
def evaluate(Language, dev_loc, model_dir):
|
def evaluate(Language, dev_loc, model_dir, gold_preproc=False):
|
||||||
nlp = Language()
|
nlp = Language()
|
||||||
n_corr = 0
|
n_corr = 0
|
||||||
total = 0
|
total = 0
|
||||||
|
skipped = 0
|
||||||
with codecs.open(dev_loc, 'r', 'utf8') as file_:
|
with codecs.open(dev_loc, 'r', 'utf8') as file_:
|
||||||
sents = read_tokenized_gold(file_)
|
paragraphs = read_docparse_gold(file_)
|
||||||
for words, heads, labels, tags in sents:
|
for tokens, tag_strs, heads, labels in iter_data(paragraphs, nlp.tokenizer,
|
||||||
tokens = nlp.tokenizer.tokens_from_list(words)
|
gold_preproc=gold_preproc):
|
||||||
|
assert len(tokens) == len(labels)
|
||||||
nlp.tagger(tokens)
|
nlp.tagger(tokens)
|
||||||
nlp.parser(tokens)
|
nlp.parser(tokens)
|
||||||
for i, token in enumerate(tokens):
|
for i, token in enumerate(tokens):
|
||||||
#print i, token.string, i + token.head, heads[i], labels[i]
|
if heads[i] is None:
|
||||||
if labels[i] == 'P' or labels[i] == 'punct':
|
skipped += 1
|
||||||
|
continue
|
||||||
|
if is_punct_label(labels[i]):
|
||||||
continue
|
continue
|
||||||
n_corr += token.head.i == heads[i]
|
n_corr += token.head.i == heads[i]
|
||||||
total += 1
|
total += 1
|
||||||
return float(n_corr) / total
|
print loss, skipped, (loss+skipped + total)
|
||||||
|
return float(n_corr) / (total + loss)
|
||||||
|
|
||||||
PROFILE = False
|
|
||||||
|
|
||||||
|
|
||||||
def main(train_loc, dev_loc, model_dir):
|
def main(train_loc, dev_loc, model_dir):
|
||||||
with codecs.open(train_loc, 'r', 'utf8') as file_:
|
with codecs.open(train_loc, 'r', 'utf8') as file_:
|
||||||
train_sents = read_tokenized_gold(file_)
|
train_sents = read_docparse_gold(file_)
|
||||||
if PROFILE:
|
#train(English, train_sents, model_dir, gold_preproc=False)
|
||||||
import cProfile
|
print evaluate(English, dev_loc, model_dir, gold_preproc=False)
|
||||||
import pstats
|
|
||||||
cmd = "train(EN, train_sents, tag_names, model_dir, n_iter=2)"
|
|
||||||
cProfile.runctx(cmd, globals(), locals(), "Profile.prof")
|
|
||||||
s = pstats.Stats("Profile.prof")
|
|
||||||
s.strip_dirs().sort_stats("time").print_stats()
|
|
||||||
else:
|
|
||||||
train(English, train_sents, model_dir)
|
|
||||||
print evaluate(English, dev_loc, model_dir)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
@ -255,19 +255,23 @@ cdef class EnPosTagger:
|
||||||
tokens._tag_strings = self.tag_names
|
tokens._tag_strings = self.tag_names
|
||||||
tokens.is_tagged = True
|
tokens.is_tagged = True
|
||||||
|
|
||||||
def train(self, Tokens tokens, object golds):
|
def train(self, Tokens tokens, object gold_tag_strs):
|
||||||
cdef int i
|
cdef int i
|
||||||
|
cdef int loss
|
||||||
cdef atom_t[N_CONTEXT_FIELDS] context
|
cdef atom_t[N_CONTEXT_FIELDS] context
|
||||||
cdef const weight_t* scores
|
cdef const weight_t* scores
|
||||||
|
golds = [self.tag_names.index(g) if g is not None else -1
|
||||||
|
for g in gold_tag_strs]
|
||||||
correct = 0
|
correct = 0
|
||||||
for i in range(tokens.length):
|
for i in range(tokens.length):
|
||||||
fill_context(context, i, tokens.data)
|
fill_context(context, i, tokens.data)
|
||||||
scores = self.model.score(context)
|
scores = self.model.score(context)
|
||||||
guess = arg_max(scores, self.model.n_classes)
|
guess = arg_max(scores, self.model.n_classes)
|
||||||
self.model.update(context, guess, golds[i], guess != golds[i])
|
loss = guess != golds[i] if golds[i] != -1 else 0
|
||||||
|
self.model.update(context, guess, golds[i], loss)
|
||||||
tokens.data[i].tag = guess
|
tokens.data[i].tag = guess
|
||||||
self.set_morph(i, tokens.data)
|
self.set_morph(i, tokens.data)
|
||||||
correct += guess == golds[i]
|
correct += loss == 0
|
||||||
return correct
|
return correct
|
||||||
|
|
||||||
cdef int set_morph(self, const int i, TokenC* tokens) except -1:
|
cdef int set_morph(self, const int i, TokenC* tokens) except -1:
|
||||||
|
|
|
@ -8,7 +8,9 @@ from ._state cimport head_in_stack, children_in_stack
|
||||||
|
|
||||||
from ..structs cimport TokenC
|
from ..structs cimport TokenC
|
||||||
|
|
||||||
|
|
||||||
DEF NON_MONOTONIC = True
|
DEF NON_MONOTONIC = True
|
||||||
|
DEF USE_BREAK = True
|
||||||
|
|
||||||
|
|
||||||
cdef enum:
|
cdef enum:
|
||||||
|
@ -16,8 +18,12 @@ cdef enum:
|
||||||
REDUCE
|
REDUCE
|
||||||
LEFT
|
LEFT
|
||||||
RIGHT
|
RIGHT
|
||||||
|
BREAK
|
||||||
N_MOVES
|
N_MOVES
|
||||||
|
|
||||||
|
# Break transition from here
|
||||||
|
# http://www.aclweb.org/anthology/P13-1074
|
||||||
|
|
||||||
|
|
||||||
cdef inline bint _can_shift(const State* s) nogil:
|
cdef inline bint _can_shift(const State* s) nogil:
|
||||||
return not at_eol(s)
|
return not at_eol(s)
|
||||||
|
@ -41,6 +47,24 @@ cdef inline bint _can_reduce(const State* s) nogil:
|
||||||
return s.stack_len >= 2 and has_head(get_s0(s))
|
return s.stack_len >= 2 and has_head(get_s0(s))
|
||||||
|
|
||||||
|
|
||||||
|
cdef inline bint _can_break(const State* s) nogil:
|
||||||
|
cdef int i
|
||||||
|
if not USE_BREAK:
|
||||||
|
return False
|
||||||
|
elif at_eol(s):
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
# If stack is disconnected, cannot break
|
||||||
|
seen_headless = False
|
||||||
|
for i in range(s.stack_len):
|
||||||
|
if s.sent[s.stack[-i]].head == 0:
|
||||||
|
if seen_headless:
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
seen_headless = True
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
cdef int _shift_cost(const State* s, const int* gold) except -1:
|
cdef int _shift_cost(const State* s, const int* gold) except -1:
|
||||||
assert not at_eol(s)
|
assert not at_eol(s)
|
||||||
cost = 0
|
cost = 0
|
||||||
|
@ -48,6 +72,9 @@ cdef int _shift_cost(const State* s, const int* gold) except -1:
|
||||||
cost += children_in_stack(s, s.i, gold)
|
cost += children_in_stack(s, s.i, gold)
|
||||||
if NON_MONOTONIC:
|
if NON_MONOTONIC:
|
||||||
cost += gold[s.stack[0]] == s.i
|
cost += gold[s.stack[0]] == s.i
|
||||||
|
# If we can break, and there's no cost to doing so, we should
|
||||||
|
if _can_break(s) and _break_cost(s, gold) == 0:
|
||||||
|
cost += 1
|
||||||
return cost
|
return cost
|
||||||
|
|
||||||
|
|
||||||
|
@ -74,6 +101,7 @@ cdef int _left_cost(const State* s, const int* gold) except -1:
|
||||||
cost += children_in_buffer(s, s.stack[0], gold)
|
cost += children_in_buffer(s, s.stack[0], gold)
|
||||||
if NON_MONOTONIC and s.stack_len >= 2:
|
if NON_MONOTONIC and s.stack_len >= 2:
|
||||||
cost += gold[s.stack[0]] == s.stack[-1]
|
cost += gold[s.stack[0]] == s.stack[-1]
|
||||||
|
cost += gold[s.stack[0]] == s.stack[0]
|
||||||
return cost
|
return cost
|
||||||
|
|
||||||
|
|
||||||
|
@ -85,6 +113,16 @@ cdef int _reduce_cost(const State* s, const int* gold) except -1:
|
||||||
return cost
|
return cost
|
||||||
|
|
||||||
|
|
||||||
|
cdef int _break_cost(const State* s, const int* gold) except -1:
|
||||||
|
# When we break, we Reduce all of the words on the stack.
|
||||||
|
cdef int cost = 0
|
||||||
|
# Number of deps between S0...Sn and N0...Nn
|
||||||
|
for i in range(s.i, s.sent_len):
|
||||||
|
cost += children_in_stack(s, i, gold)
|
||||||
|
cost += head_in_stack(s, i, gold)
|
||||||
|
return cost
|
||||||
|
|
||||||
|
|
||||||
cdef class TransitionSystem:
|
cdef class TransitionSystem:
|
||||||
def __init__(self, list left_labels, list right_labels):
|
def __init__(self, list left_labels, list right_labels):
|
||||||
self.mem = Pool()
|
self.mem = Pool()
|
||||||
|
@ -94,7 +132,7 @@ cdef class TransitionSystem:
|
||||||
right_labels.pop(right_labels.index('ROOT'))
|
right_labels.pop(right_labels.index('ROOT'))
|
||||||
if 'ROOT' in left_labels:
|
if 'ROOT' in left_labels:
|
||||||
left_labels.pop(left_labels.index('ROOT'))
|
left_labels.pop(left_labels.index('ROOT'))
|
||||||
self.n_moves = 2 + len(left_labels) + len(right_labels)
|
self.n_moves = 3 + len(left_labels) + len(right_labels)
|
||||||
moves = <Transition*>self.mem.alloc(self.n_moves, sizeof(Transition))
|
moves = <Transition*>self.mem.alloc(self.n_moves, sizeof(Transition))
|
||||||
cdef int i = 0
|
cdef int i = 0
|
||||||
moves[i].move = SHIFT
|
moves[i].move = SHIFT
|
||||||
|
@ -121,6 +159,10 @@ cdef class TransitionSystem:
|
||||||
moves[i].label = label_id
|
moves[i].label = label_id
|
||||||
moves[i].clas = i
|
moves[i].clas = i
|
||||||
i += 1
|
i += 1
|
||||||
|
moves[i].move = BREAK
|
||||||
|
moves[i].label = 0
|
||||||
|
moves[i].clas = i
|
||||||
|
i += 1
|
||||||
self._moves = moves
|
self._moves = moves
|
||||||
|
|
||||||
cdef int transition(self, State *s, const Transition* t) except -1:
|
cdef int transition(self, State *s, const Transition* t) except -1:
|
||||||
|
@ -136,8 +178,17 @@ cdef class TransitionSystem:
|
||||||
add_dep(s, s.stack[0], s.i, t.label)
|
add_dep(s, s.stack[0], s.i, t.label)
|
||||||
push_stack(s)
|
push_stack(s)
|
||||||
elif t.move == REDUCE:
|
elif t.move == REDUCE:
|
||||||
|
# TODO: Huh? Is this some weirdness from the non-monotonic?
|
||||||
add_dep(s, s.stack[-1], s.stack[0], get_s0(s).dep)
|
add_dep(s, s.stack[-1], s.stack[0], get_s0(s).dep)
|
||||||
pop_stack(s)
|
pop_stack(s)
|
||||||
|
elif t.move == BREAK:
|
||||||
|
while s.stack_len != 0:
|
||||||
|
if get_s0(s).head == 0:
|
||||||
|
get_s0(s).dep = 0
|
||||||
|
s.stack -= 1
|
||||||
|
s.stack_len -= 1
|
||||||
|
if not at_eol(s):
|
||||||
|
push_stack(s)
|
||||||
else:
|
else:
|
||||||
raise Exception(t.move)
|
raise Exception(t.move)
|
||||||
|
|
||||||
|
@ -147,6 +198,7 @@ cdef class TransitionSystem:
|
||||||
valid[LEFT] = _can_left(s)
|
valid[LEFT] = _can_left(s)
|
||||||
valid[RIGHT] = _can_right(s)
|
valid[RIGHT] = _can_right(s)
|
||||||
valid[REDUCE] = _can_reduce(s)
|
valid[REDUCE] = _can_reduce(s)
|
||||||
|
valid[BREAK] = _can_break(s)
|
||||||
|
|
||||||
cdef int best = -1
|
cdef int best = -1
|
||||||
cdef weight_t score = 0
|
cdef weight_t score = 0
|
||||||
|
@ -175,6 +227,7 @@ cdef class TransitionSystem:
|
||||||
unl_costs[LEFT] = _left_cost(s, gold_heads) if _can_left(s) else -1
|
unl_costs[LEFT] = _left_cost(s, gold_heads) if _can_left(s) else -1
|
||||||
unl_costs[RIGHT] = _right_cost(s, gold_heads) if _can_right(s) else -1
|
unl_costs[RIGHT] = _right_cost(s, gold_heads) if _can_right(s) else -1
|
||||||
unl_costs[REDUCE] = _reduce_cost(s, gold_heads) if _can_reduce(s) else -1
|
unl_costs[REDUCE] = _reduce_cost(s, gold_heads) if _can_reduce(s) else -1
|
||||||
|
unl_costs[BREAK] = _break_cost(s, gold_heads) if _can_break(s) else -1
|
||||||
|
|
||||||
guess.cost = unl_costs[guess.move]
|
guess.cost = unl_costs[guess.move]
|
||||||
cdef Transition t
|
cdef Transition t
|
||||||
|
@ -191,10 +244,11 @@ cdef class TransitionSystem:
|
||||||
elif gold_heads[s.i] == s.stack[0]:
|
elif gold_heads[s.i] == s.stack[0]:
|
||||||
target_label = gold_labels[s.i]
|
target_label = gold_labels[s.i]
|
||||||
if guess.move == RIGHT:
|
if guess.move == RIGHT:
|
||||||
|
if unl_costs[guess.move] != 0:
|
||||||
guess.cost += guess.label != target_label
|
guess.cost += guess.label != target_label
|
||||||
for i in range(self.n_moves):
|
for i in range(self.n_moves):
|
||||||
t = self._moves[i]
|
t = self._moves[i]
|
||||||
if t.move == RIGHT and t.label == target_label:
|
if t.label == target_label and unl_costs[t.move] == 0:
|
||||||
return t
|
return t
|
||||||
|
|
||||||
cdef int best = -1
|
cdef int best = -1
|
||||||
|
|
|
@ -41,11 +41,12 @@ def set_debug(val):
|
||||||
|
|
||||||
cdef unicode print_state(State* s, list words):
|
cdef unicode print_state(State* s, list words):
|
||||||
words = list(words) + ['EOL']
|
words = list(words) + ['EOL']
|
||||||
top = words[s.stack[0]]
|
top = words[s.stack[0]] + '_%d' % s.sent[s.stack[0]].head
|
||||||
second = words[s.stack[-1]]
|
second = words[s.stack[-1]] + '_%d' % s.sent[s.stack[-1]].head
|
||||||
|
third = words[s.stack[-2]] + '_%d' % s.sent[s.stack[-2]].head
|
||||||
n0 = words[s.i]
|
n0 = words[s.i]
|
||||||
n1 = words[s.i + 1]
|
n1 = words[s.i + 1]
|
||||||
return ' '.join((second, top, '|', n0, n1))
|
return ' '.join((str(s.stack_len), third, second, top, '|', n0, n1))
|
||||||
|
|
||||||
|
|
||||||
def get_templates(name):
|
def get_templates(name):
|
||||||
|
@ -86,7 +87,8 @@ cdef class GreedyParser:
|
||||||
tokens.is_parsed = True
|
tokens.is_parsed = True
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
def train_sent(self, Tokens tokens, list gold_heads, list gold_labels):
|
def train_sent(self, Tokens tokens, list gold_heads, list gold_labels,
|
||||||
|
force_gold=False):
|
||||||
cdef:
|
cdef:
|
||||||
const Feature* feats
|
const Feature* feats
|
||||||
const weight_t* scores
|
const weight_t* scores
|
||||||
|
@ -100,19 +102,39 @@ cdef class GreedyParser:
|
||||||
cdef int* labels_array = <int*>mem.alloc(tokens.length, sizeof(int))
|
cdef int* labels_array = <int*>mem.alloc(tokens.length, sizeof(int))
|
||||||
cdef int i
|
cdef int i
|
||||||
for i in range(tokens.length):
|
for i in range(tokens.length):
|
||||||
|
if gold_heads[i] is None:
|
||||||
|
heads_array[i] = -1
|
||||||
|
labels_array[i] = -1
|
||||||
|
else:
|
||||||
heads_array[i] = gold_heads[i]
|
heads_array[i] = gold_heads[i]
|
||||||
labels_array[i] = self.moves.label_ids[gold_labels[i]]
|
labels_array[i] = self.moves.label_ids[gold_labels[i]]
|
||||||
|
|
||||||
py_words = [t.orth_ for t in tokens]
|
py_words = [t.orth_ for t in tokens]
|
||||||
|
py_moves = ['S', 'D', 'L', 'R', 'BS', 'BR']
|
||||||
|
history = []
|
||||||
|
#print py_words
|
||||||
cdef State* state = init_state(mem, tokens.data, tokens.length)
|
cdef State* state = init_state(mem, tokens.data, tokens.length)
|
||||||
while not is_final(state):
|
while not is_final(state):
|
||||||
fill_context(context, state)
|
fill_context(context, state)
|
||||||
scores = self.model.score(context)
|
scores = self.model.score(context)
|
||||||
guess = self.moves.best_valid(scores, state)
|
guess = self.moves.best_valid(scores, state)
|
||||||
best = self.moves.best_gold(&guess, scores, state, heads_array, labels_array)
|
best = self.moves.best_gold(&guess, scores, state, heads_array, labels_array)
|
||||||
|
history.append((py_moves[best.move], print_state(state, py_words)))
|
||||||
self.model.update(context, guess.clas, best.clas, guess.cost)
|
self.model.update(context, guess.clas, best.clas, guess.cost)
|
||||||
|
if force_gold:
|
||||||
|
self.moves.transition(state, &best)
|
||||||
|
else:
|
||||||
self.moves.transition(state, &guess)
|
self.moves.transition(state, &guess)
|
||||||
cdef int n_corr = 0
|
cdef int n_corr = 0
|
||||||
for i in range(tokens.length):
|
for i in range(tokens.length):
|
||||||
|
if gold_heads[i] != -1:
|
||||||
n_corr += (i + state.sent[i].head) == gold_heads[i]
|
n_corr += (i + state.sent[i].head) == gold_heads[i]
|
||||||
|
if force_gold and n_corr != tokens.length:
|
||||||
|
print py_words
|
||||||
|
print gold_heads
|
||||||
|
for move, state_str in history:
|
||||||
|
print move, state_str
|
||||||
|
for i in range(tokens.length):
|
||||||
|
print py_words[i], py_words[i + state.sent[i].head], py_words[gold_heads[i]]
|
||||||
|
raise Exception
|
||||||
return n_corr
|
return n_corr
|
||||||
|
|
|
@ -95,7 +95,6 @@ cdef class Tokenizer:
|
||||||
return tokens
|
return tokens
|
||||||
|
|
||||||
cdef int _try_cache(self, int idx, hash_t key, Tokens tokens) except -1:
|
cdef int _try_cache(self, int idx, hash_t key, Tokens tokens) except -1:
|
||||||
#cached = <Cached*>self._specials.get(key)
|
|
||||||
cached = <_Cached*>self._cache.get(key)
|
cached = <_Cached*>self._cache.get(key)
|
||||||
if cached == NULL:
|
if cached == NULL:
|
||||||
return False
|
return False
|
||||||
|
@ -176,7 +175,10 @@ cdef class Tokenizer:
|
||||||
if string.n != 0:
|
if string.n != 0:
|
||||||
cache_hit = self._try_cache(idx, string.key, tokens)
|
cache_hit = self._try_cache(idx, string.key, tokens)
|
||||||
if cache_hit:
|
if cache_hit:
|
||||||
idx = tokens.data[tokens.length - 1].idx + 1
|
# Get last idx
|
||||||
|
idx = tokens.data[tokens.length - 1].idx
|
||||||
|
# Increment by last length
|
||||||
|
idx += tokens.data[tokens.length - 1].lex.length
|
||||||
else:
|
else:
|
||||||
split = self._find_infix(string.chars, string.n)
|
split = self._find_infix(string.chars, string.n)
|
||||||
if split == 0 or split == -1:
|
if split == 0 or split == -1:
|
||||||
|
|
|
@ -8,7 +8,7 @@ from spacy.en import English
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def EN():
|
def EN():
|
||||||
return English()
|
return English().tokenizer
|
||||||
|
|
||||||
def test_single_word(EN):
|
def test_single_word(EN):
|
||||||
tokens = EN(u'hello')
|
tokens = EN(u'hello')
|
||||||
|
|
Loading…
Reference in New Issue
Block a user