mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 18:26:30 +03:00
* Work on greedy parser
This commit is contained in:
parent
a432862fde
commit
95ccea03b2
127
spacy/index.pyx
Normal file
127
spacy/index.pyx
Normal file
|
@ -0,0 +1,127 @@
|
|||
"""Create a term-document matrix"""
|
||||
cimport cython
|
||||
|
||||
from libc.string cimport memmove
|
||||
|
||||
from cymem.cymem cimport Address
|
||||
|
||||
from .lexeme cimport Lexeme, get_attr
|
||||
from .tokens cimport TokenC
|
||||
from .typedefs cimport hash_t
|
||||
|
||||
from preshed.maps cimport MapStruct, Cell, map_get, map_set, map_init
|
||||
|
||||
|
||||
cdef class Index:
|
||||
def __init__(self, attr_id_t attr_id):
|
||||
self.attr_id = attr_id
|
||||
self.max_value = 0
|
||||
|
||||
cpdef int count(self, Tokens tokens) except -1:
|
||||
cdef PreshCounter counts = PreshCounter(2 ** 8)
|
||||
cdef attr_id_t attr_id = self.attr_id
|
||||
cdef attr_t term
|
||||
cdef int i
|
||||
for i in range(tokens.length):
|
||||
term = get_attr(tokens.data[i].lex, attr_id)
|
||||
counts.inc(term, 1)
|
||||
if term > self.max_value:
|
||||
self.max_value = term
|
||||
cdef count_t count
|
||||
cdef count_vector_t doc_counts
|
||||
for term, count in counts:
|
||||
doc_counts.push_back(pair[id_t, count_t](term, count))
|
||||
self.counts.push_back(doc_counts)
|
||||
|
||||
|
||||
|
||||
cdef class PosMemory:
|
||||
def __init__(self, tag_names):
|
||||
self.tag_names = tag_names
|
||||
self.nr_tags = len(tag_names)
|
||||
self.mem = Pool()
|
||||
self._counts = PreshCounter()
|
||||
self._pos_counts = PreshCounter()
|
||||
|
||||
def __getitem__(self, ids):
|
||||
cdef id_t[2] ngram
|
||||
ngram[0] = ids[0]
|
||||
ngram[1] = ids[1]
|
||||
cdef hash_t ngram_key = hash64(ngram, 2 * sizeof(id_t), 0)
|
||||
cdef hash_t[2] pos_context
|
||||
pos_context[0] = ngram_key
|
||||
counts = {}
|
||||
cdef id_t i
|
||||
for i, tag in enumerate(self.tag_names):
|
||||
pos_context[1] = <hash_t>i
|
||||
key = hash64(pos_context, sizeof(hash_t) * 2, 0)
|
||||
count = self._pos_counts[key]
|
||||
counts[tag] = count
|
||||
return counts
|
||||
|
||||
@cython.cdivision(True)
|
||||
def iter_ngrams(self, float min_acc=0.99, count_t min_freq=10):
|
||||
cdef Address counts_addr = Address(self.nr_tags, sizeof(count_t))
|
||||
cdef count_t* counts = <count_t*>counts_addr.ptr
|
||||
cdef MapStruct* ngram_counts = self._counts.c_map
|
||||
cdef hash_t ngram_key
|
||||
cdef count_t ngram_freq
|
||||
cdef int best_pos
|
||||
cdef float acc
|
||||
|
||||
cdef int i
|
||||
for i in range(ngram_counts.length):
|
||||
ngram_key = ngram_counts.cells[i].key
|
||||
ngram_freq = <count_t>ngram_counts.cells[i].value
|
||||
if ngram_key != 0 and ngram_freq >= min_freq:
|
||||
best_pos = self.find_best_pos(counts, ngram_key)
|
||||
acc = counts[best_pos] / ngram_freq
|
||||
if acc >= min_acc:
|
||||
yield counts[best_pos], ngram_key, best_pos
|
||||
|
||||
cpdef int count(self, Tokens tokens) except -1:
|
||||
cdef int i
|
||||
cdef TokenC* t
|
||||
for i in range(tokens.length):
|
||||
t = &tokens.data[i]
|
||||
if t.lex.prob != 0 and t.lex.prob >= -14:
|
||||
self.inc(t, 1)
|
||||
|
||||
cdef int inc(self, TokenC* word, count_t inc) except -1:
|
||||
cdef hash_t[2] ngram_pos_context
|
||||
cdef hash_t ngram_key = self._ngram_key(word)
|
||||
ngram_pos_context[0] = ngram_key
|
||||
ngram_pos_context[1] = <hash_t>word.pos
|
||||
ngram_pos_key = hash64(ngram_pos_context, 2 * sizeof(hash_t), 0)
|
||||
self._counts.inc(ngram_key, inc)
|
||||
self._pos_counts.inc(ngram_pos_key, inc)
|
||||
|
||||
cdef int find_best_pos(self, count_t* counts, hash_t ngram_key) except -1:
|
||||
cdef hash_t[2] unhashed_key
|
||||
unhashed_key[0] = ngram_key
|
||||
|
||||
cdef count_t total = 0
|
||||
cdef hash_t key
|
||||
cdef int pos
|
||||
cdef int best
|
||||
cdef int mode = 0
|
||||
for pos in range(self.nr_tags):
|
||||
unhashed_key[1] = <hash_t>pos
|
||||
key = hash64(unhashed_key, sizeof(hash_t) * 2, 0)
|
||||
count = self._pos_counts[key]
|
||||
counts[pos] = count
|
||||
if count >= mode:
|
||||
mode = count
|
||||
best = pos
|
||||
total += count
|
||||
return best
|
||||
|
||||
cdef count_t ngram_count(self, TokenC* word) except -1:
|
||||
cdef hash_t ngram_key = self._ngram_key(word)
|
||||
return self._counts[ngram_key]
|
||||
|
||||
cdef hash_t _ngram_key(self, TokenC* word) except 0:
|
||||
cdef id_t[2] context
|
||||
context[0] = word.lex.sic
|
||||
context[1] = word[-1].lex.sic
|
||||
return hash64(context, sizeof(id_t) * 2, 0)
|
|
@ -44,14 +44,15 @@ cdef class Language:
|
|||
self.pos_tagger = None
|
||||
self.morphologizer = None
|
||||
|
||||
def load(self):
|
||||
def load(self, pos_dir=None):
|
||||
self.lexicon.load(path.join(util.DATA_DIR, self.name, 'lexemes'))
|
||||
self.lexicon.strings.load(path.join(util.DATA_DIR, self.name, 'strings'))
|
||||
if path.exists(path.join(util.DATA_DIR, self.name, 'pos')):
|
||||
self.pos_tagger = Tagger(path.join(util.DATA_DIR, self.name, 'pos'))
|
||||
self.morphologizer = Morphologizer(self.lexicon.strings,
|
||||
path.join(util.DATA_DIR, self.name))
|
||||
self.load_pos_cache(path.join(util.DATA_DIR, self.name, 'pos', 'bigram_cache_2m'))
|
||||
if pos_dir is None:
|
||||
pos_dir = path.join(util.DATA_DIR, self.name, 'pos')
|
||||
if path.exists(pos_dir):
|
||||
self.pos_tagger = Tagger(pos_dir)
|
||||
self.morphologizer = Morphologizer(self.lexicon.strings, pos_dir)
|
||||
#self.load_pos_cache(path.join(util.DATA_DIR, self.name, 'pos', 'bigram_cache_2m'))
|
||||
|
||||
cpdef Tokens tokens_from_list(self, list strings):
|
||||
cdef int length = sum([len(s) for s in strings])
|
||||
|
|
|
@ -5,6 +5,8 @@ import json
|
|||
|
||||
from .lemmatizer import Lemmatizer
|
||||
from .typedefs cimport id_t
|
||||
from . import util
|
||||
|
||||
|
||||
UNIV_TAGS = {
|
||||
'NULL': NO_TAG,
|
||||
|
@ -35,10 +37,10 @@ cdef class Morphologizer:
|
|||
def __init__(self, StringStore strings, data_dir):
|
||||
self.mem = Pool()
|
||||
self.strings = strings
|
||||
cfg = json.load(open(path.join(data_dir, 'pos', 'config.json')))
|
||||
cfg = json.load(open(path.join(data_dir, 'config.json')))
|
||||
tag_map = cfg['tag_map']
|
||||
self.tag_names = cfg['tag_names']
|
||||
self.lemmatizer = Lemmatizer(path.join(data_dir, '..', 'wordnet'))
|
||||
self.lemmatizer = Lemmatizer(path.join(util.DATA_DIR, 'wordnet'))
|
||||
self._cache = PreshMapArray(len(self.tag_names))
|
||||
self.tags = <PosTag*>self.mem.alloc(len(self.tag_names), sizeof(PosTag))
|
||||
for i, tag in enumerate(self.tag_names):
|
||||
|
|
|
@ -1,78 +0,0 @@
|
|||
from cymem.cymem cimport Pool
|
||||
|
||||
|
||||
cdef struct Subtree:
|
||||
int[5] kids
|
||||
int[5] labels
|
||||
int length
|
||||
|
||||
|
||||
cdef struct State:
|
||||
double score
|
||||
size_t i
|
||||
size_t n
|
||||
size_t stack_len
|
||||
size_t top
|
||||
|
||||
size_t* stack
|
||||
|
||||
Subtree* lefts
|
||||
Subtree* rights
|
||||
int* heads
|
||||
int* labels
|
||||
|
||||
|
||||
cdef int add_dep(const State *s, size_t head, size_t child, size_t label) except -1
|
||||
|
||||
cdef size_t pop_stack(State *s) except 0
|
||||
cdef int push_stack(State *s) except -1
|
||||
|
||||
|
||||
cdef inline size_t get_s1(const State *s) nogil:
|
||||
if s.stack_len < 2:
|
||||
return 0
|
||||
return s.stack[s.stack_len - 2]
|
||||
|
||||
|
||||
cdef inline size_t get_l(const State *s, size_t head) nogil:
|
||||
cdef const Subtree* subtree = &s.lefts[head]
|
||||
if subtree.length == 0:
|
||||
return 0
|
||||
return subtree.kids[subtree.length - 1]
|
||||
|
||||
|
||||
cdef inline size_t get_l2(const State *s, size_t head) nogil:
|
||||
cdef const Subtree* subtree = &s.lefts[head]
|
||||
if subtree.length < 2:
|
||||
return 0
|
||||
return subtree.kids[subtree.length - 2]
|
||||
|
||||
|
||||
cdef inline size_t get_r(const State *s, size_t head) nogil:
|
||||
cdef const Subtree* subtree = &s.rights[head]
|
||||
if subtree.length == 0:
|
||||
return 0
|
||||
return subtree.kids[subtree.length - 1]
|
||||
|
||||
|
||||
cdef inline size_t get_r2(const State *s, size_t head) nogil:
|
||||
cdef const Subtree* subtree = &s.rights[head]
|
||||
if subtree.length < 2:
|
||||
return 0
|
||||
return subtree.kids[subtree.length - 2]
|
||||
|
||||
|
||||
cdef inline bint at_eol(const State *s) nogil:
|
||||
return s.i >= (s.n - 1)
|
||||
|
||||
|
||||
cdef inline bint is_final(State *s) nogil:
|
||||
return at_eol(s) and s.stack_len == 0
|
||||
|
||||
|
||||
cdef int has_child_in_buffer(State *s, size_t word, int* gold) except -1
|
||||
cdef int has_head_in_buffer(State *s, size_t word, int* gold) except -1
|
||||
cdef int has_child_in_stack(State *s, size_t word, int* gold) except -1
|
||||
cdef int has_head_in_stack(State *s, size_t word, int* gold) except -1
|
||||
|
||||
cdef State* init_state(Pool mem, const int sent_length) except NULL
|
|
@ -1,97 +0,0 @@
|
|||
# cython: profile=True
|
||||
from libc.string cimport memmove
|
||||
from cymem.cymem cimport Pool
|
||||
|
||||
|
||||
cdef int add_dep(State *s, size_t head, size_t child, size_t label) except -1:
|
||||
s.heads[child] = head
|
||||
s.labels[child] = label
|
||||
cdef Subtree* subtree = &s.lefts[head] if child < head else &s.rights[head]
|
||||
if subtree.length == 5:
|
||||
memmove(subtree.kids, &subtree.kids[1], 4 * sizeof(int))
|
||||
memmove(subtree.labels, &subtree.labels[1], 4 * sizeof(int))
|
||||
subtree.kids[4] = child
|
||||
subtree.labels[4] = label
|
||||
else:
|
||||
subtree.kids[subtree.length - 1] = child
|
||||
subtree.kids[subtree.length - 1] = label
|
||||
subtree.length += 1
|
||||
|
||||
|
||||
cdef size_t pop_stack(State *s) except 0:
|
||||
cdef size_t popped
|
||||
assert s.stack_len >= 1
|
||||
popped = s.top
|
||||
s.top = get_s1(s)
|
||||
s.stack_len -= 1
|
||||
assert s.top <= s.n, s.top
|
||||
assert popped != 0
|
||||
return popped
|
||||
|
||||
|
||||
cdef int push_stack(State *s) except -1:
|
||||
s.top = s.i
|
||||
s.stack[s.stack_len] = s.i
|
||||
s.stack_len += 1
|
||||
assert s.top <= s.n
|
||||
s.i += 1
|
||||
|
||||
|
||||
cdef int has_child_in_buffer(State *s, size_t word, int* gold_heads) except -1:
|
||||
assert word != 0
|
||||
cdef size_t buff_i
|
||||
cdef int n = 0
|
||||
for buff_i in range(s.i, s.n):
|
||||
if s.heads[buff_i] == 0 and gold_heads[buff_i] == word:
|
||||
n += 1
|
||||
return n
|
||||
|
||||
|
||||
cdef int has_head_in_buffer(State *s, size_t word, int* gold_heads) except -1:
|
||||
assert word != 0
|
||||
cdef size_t buff_i
|
||||
for buff_i in range(s.i, s.n):
|
||||
if s.heads[buff_i] == 0 and gold_heads[word] == buff_i:
|
||||
return 1
|
||||
return 0
|
||||
|
||||
|
||||
cdef int has_child_in_stack(State *s, size_t word, int* gold_heads) except -1:
|
||||
assert word != 0
|
||||
cdef size_t i, stack_i
|
||||
cdef int n = 0
|
||||
for i in range(s.stack_len):
|
||||
stack_i = s.stack[i]
|
||||
# Should this be sensitive to whether the word has a head already?
|
||||
if gold_heads[stack_i] == word:
|
||||
n += 1
|
||||
return n
|
||||
|
||||
|
||||
cdef int has_head_in_stack(State *s, size_t word, int* gold_heads) except -1:
|
||||
assert word != 0
|
||||
cdef size_t i, stack_i
|
||||
for i in range(s.stack_len):
|
||||
stack_i = s.stack[i]
|
||||
if gold_heads[word] == stack_i:
|
||||
return 1
|
||||
return 0
|
||||
|
||||
|
||||
DEF PADDING = 5
|
||||
|
||||
|
||||
cdef State* init_state(Pool mem, const int sent_length) except NULL:
|
||||
cdef size_t i
|
||||
cdef State* s = <State*>mem.alloc(1, sizeof(State))
|
||||
s.n = sent_length
|
||||
s.i = 1
|
||||
s.top = 0
|
||||
s.stack_len = 0
|
||||
n = s.n + PADDING
|
||||
s.stack = <size_t*>mem.alloc(n, sizeof(size_t))
|
||||
s.heads = <int*>mem.alloc(n, sizeof(int))
|
||||
s.labels = <int*>mem.alloc(n, sizeof(int))
|
||||
s.lefts = <Subtree*>mem.alloc(n, sizeof(Subtree))
|
||||
s.rights = <Subtree*>mem.alloc(n, sizeof(Subtree))
|
||||
return s
|
|
@ -1,162 +0,0 @@
|
|||
# cython: profile=True
|
||||
from ._state cimport State
|
||||
from ._state cimport at_eol, pop_stack, push_stack, add_dep
|
||||
from ._state cimport has_head_in_buffer, has_child_in_buffer
|
||||
from ._state cimport has_head_in_stack, has_child_in_stack
|
||||
import index.hashes
|
||||
|
||||
|
||||
cdef enum:
|
||||
ERR
|
||||
SHIFT
|
||||
REDUCE
|
||||
LEFT
|
||||
RIGHT
|
||||
N_MOVES
|
||||
|
||||
|
||||
cdef inline bint _can_shift(const State* s) nogil:
|
||||
return not at_eol(s)
|
||||
|
||||
|
||||
cdef inline bint _can_right(const State* s) nogil:
|
||||
return s.stack_len >= 1 and not at_eol(s)
|
||||
|
||||
|
||||
cdef inline bint _can_left(const State* s) nogil:
|
||||
return s.stack_len >= 1 and s.heads[s.top] == 0
|
||||
|
||||
|
||||
cdef inline bint _can_reduce(const State* s) nogil:
|
||||
return s.stack_len >= 2 and s.heads[s.top] != 0
|
||||
|
||||
|
||||
cdef int _shift_cost(const State* s, list gold) except -1:
|
||||
assert not at_eol(s)
|
||||
cost = 0
|
||||
cost += has_head_in_stack(s, s.i, gold)
|
||||
cost += has_child_in_stack(s, s.i, gold)
|
||||
return cost
|
||||
|
||||
|
||||
cdef int _right_cost(const State* s, list gold) except -1:
|
||||
assert s.stack_len >= 1
|
||||
cost = 0
|
||||
if gold[s.i] == s.top:
|
||||
return cost
|
||||
cost += has_head_in_buffer(s, s.i, gold)
|
||||
cost += has_child_in_stack(s, s.i, gold)
|
||||
cost += has_head_in_stack(s, s.i, gold)
|
||||
return cost
|
||||
|
||||
|
||||
cdef int _left_cost(const State* s, list gold) except -1:
|
||||
assert s.stack_len >= 1
|
||||
cost = 0
|
||||
if gold[s.top] == s.i:
|
||||
return cost
|
||||
cost += has_head_in_buffer(s, s.top, gold)
|
||||
cost += has_child_in_buffer(s, s.top, gold)
|
||||
return cost
|
||||
|
||||
|
||||
cdef int _reduce_cost(const State* s, list gold) except -1:
|
||||
assert s.stack_len >= 2
|
||||
cost = 0
|
||||
cost += has_child_in_buffer(s, s.top, gold)
|
||||
return cost
|
||||
|
||||
|
||||
cdef class TransitionSystem:
|
||||
def __init__(self, list left_labels, list right_labels):
|
||||
self.mem = Pool()
|
||||
self.n_moves = 2 + len(left_labels) + len(right_labels)
|
||||
moves = <Transition*>self.mem.alloc(self.n_moves, sizeof(Transition))
|
||||
cdef int i = 0
|
||||
moves[i].move = SHIFT
|
||||
moves[i].label = 0
|
||||
i += 1
|
||||
moves[i].move = REDUCE
|
||||
moves[i].label = 0
|
||||
i += 1
|
||||
cdef int label
|
||||
for label in left_labels:
|
||||
moves[i].move = LEFT
|
||||
moves[i].label = label
|
||||
i += 1
|
||||
for label in right_labels:
|
||||
moves[i].move = RIGHT
|
||||
moves[i].label = label
|
||||
i += 1
|
||||
self._moves = moves
|
||||
|
||||
cdef int transition(self, State *s, const int clas) except -1:
|
||||
cdef const Transition* t = &self._moves[clas]
|
||||
if t.move == SHIFT:
|
||||
push_stack(s)
|
||||
elif t.move == LEFT:
|
||||
add_dep(s, s.i, s.top, t.label)
|
||||
pop_stack(s)
|
||||
elif t.move == RIGHT:
|
||||
add_dep(s, s.top, s.i, t.label)
|
||||
push_stack(s)
|
||||
elif t.move == REDUCE:
|
||||
pop_stack(s)
|
||||
else:
|
||||
raise StandardError(t.move)
|
||||
|
||||
cdef int best_valid(self, const weight_t* scores, const State* s) except -1:
|
||||
cdef bint[N_MOVES] valid
|
||||
valid[SHIFT] = _can_shift(s)
|
||||
valid[LEFT] = _can_left(s)
|
||||
valid[RIGHT] = _can_right(s)
|
||||
valid[REDUCE] = _can_reduce(s)
|
||||
|
||||
cdef int best = -1
|
||||
cdef weight_t score
|
||||
cdef int i
|
||||
for i in range(self.n_moves):
|
||||
if valid[self._moves[i].move] and scores[i] > score:
|
||||
best = i
|
||||
score = scores[i]
|
||||
assert best >= -1, "No valid moves found"
|
||||
return best
|
||||
|
||||
cdef int best_gold(self, const weight_t* scores, const State* s,
|
||||
list gold_heads, list gold_labels) except -1:
|
||||
cdef int[N_MOVES] unl_costs
|
||||
unl_costs[SHIFT] = _shift_cost(s, gold_heads) if _can_shift(s) else -1
|
||||
unl_costs[LEFT] = _left_cost(s, gold_heads) if _can_left(s) else -1
|
||||
unl_costs[RIGHT] = _right_cost(s, gold_heads) if _can_right(s) else -1
|
||||
unl_costs[REDUCE] = _reduce_cost(s, gold_heads) if _can_reduce(s) else -1
|
||||
|
||||
cdef int cost
|
||||
cdef int label_cost
|
||||
cdef int move
|
||||
cdef int label
|
||||
cdef int best = -1
|
||||
cdef weight_t score
|
||||
cdef int i
|
||||
for i in range(self.n_moves):
|
||||
move = self._moves[i].move
|
||||
label = self._moves[i].label
|
||||
if unl_costs[move] == 0:
|
||||
if move == SHIFT or move == REDUCE:
|
||||
label_cost = 0
|
||||
elif move == LEFT:
|
||||
if gold_heads[s.top] == s.i:
|
||||
label_cost = label != gold_labels[s.top]
|
||||
else:
|
||||
label_cost = 0
|
||||
elif move == RIGHT:
|
||||
if gold_heads[s.i] == s.top:
|
||||
label_cost = label != gold_labels[s.i]
|
||||
else:
|
||||
label_cost = 0
|
||||
else:
|
||||
raise StandardError("Unknown Move")
|
||||
if label_cost == 0 and scores[i] > score:
|
||||
best = i
|
||||
score = scores[i]
|
||||
assert best >= -1, "No gold moves found"
|
||||
return best
|
|
@ -1,160 +0,0 @@
|
|||
# cython: profile=True
|
||||
"""
|
||||
MALT-style dependency parser
|
||||
"""
|
||||
cimport cython
|
||||
import random
|
||||
import os.path
|
||||
from os.path import join as pjoin
|
||||
import shutil
|
||||
import json
|
||||
|
||||
from cymem.cymem cimport Pool, Address
|
||||
from thinc.typedefs cimport weight_t, class_t, feat_t, atom_t
|
||||
|
||||
|
||||
from util import Config
|
||||
|
||||
from thinc.features cimport Extractor
|
||||
from thinc.features cimport Feature
|
||||
from thinc.features cimport count_feats
|
||||
|
||||
from thinc.learner cimport LinearModel
|
||||
|
||||
from ..tokens cimport Tokens, TokenC
|
||||
|
||||
from .arc_eager cimport TransitionSystem
|
||||
|
||||
from ._state cimport init_state, State, is_final, get_s1
|
||||
|
||||
VOCAB_SIZE = 1e6
|
||||
TAG_SET_SIZE = 50
|
||||
|
||||
DEF CONTEXT_SIZE = 50
|
||||
|
||||
|
||||
DEBUG = False
|
||||
def set_debug(val):
|
||||
global DEBUG
|
||||
DEBUG = val
|
||||
|
||||
|
||||
cdef str print_state(State* s, list words):
|
||||
top = words[s.top]
|
||||
second = words[get_s1(s)]
|
||||
n0 = words[s.i]
|
||||
n1 = words[s.i + 1]
|
||||
return ' '.join((second, top, '|', n0, n1))
|
||||
|
||||
|
||||
def train(sents, golds, model_dir, n_iter=15, feat_set=u'basic', seed=0):
|
||||
if os.path.exists(model_dir):
|
||||
shutil.rmtree(model_dir)
|
||||
os.mkdir(model_dir)
|
||||
left_labels, right_labels, dfl_labels = get_labels(golds)
|
||||
Config.write(model_dir, 'config', features=feat_set, seed=seed,
|
||||
left_labels=left_labels, right_labels=right_labels)
|
||||
parser = Parser(model_dir)
|
||||
indices = list(range(len(sents)))
|
||||
for n in range(n_iter):
|
||||
for i in indices:
|
||||
parser.train_sent(sents[i], *golds[i])
|
||||
#parser.tagger.train_sent(py_sent) # TODO
|
||||
acc = float(parser.guide.n_corr) / parser.guide.total
|
||||
print(parser.guide.end_train_iter(n) + '\t' +
|
||||
parser.tagger.guide.end_train_iter(n))
|
||||
random.shuffle(indices)
|
||||
parser.guide.end_training()
|
||||
parser.tagger.guide.end_training()
|
||||
parser.guide.dump(pjoin(model_dir, 'model'), freq_thresh=0)
|
||||
parser.tagger.guide.dump(pjoin(model_dir, 'tagger'))
|
||||
return acc
|
||||
|
||||
|
||||
def get_labels(sents):
|
||||
'''Get alphabetically-sorted lists of left, right and disfluency labels that
|
||||
occur in a sample of sentences. Used to determine the set of legal transitions
|
||||
from the training set.
|
||||
|
||||
Args:
|
||||
sentences (list[Input]): A list of Input objects, usually the training set.
|
||||
|
||||
Returns:
|
||||
labels (tuple[list, list, list]): Sorted lists of left, right and disfluency
|
||||
labels.
|
||||
'''
|
||||
left_labels = set()
|
||||
right_labels = set()
|
||||
# TODO
|
||||
return list(sorted(left_labels)), list(sorted(right_labels))
|
||||
|
||||
|
||||
def get_templates(feats_str):
|
||||
'''Interpret feats_str, returning a list of template tuples. Each template
|
||||
is a tuple of numeric indices, referring to positions in the context
|
||||
array. See _parse_features.pyx for examples. The templates are applied by
|
||||
thinc.features.Extractor, which picks out the appointed values and hashes
|
||||
the resulting array, to produce a single feature code.
|
||||
'''
|
||||
return tuple()
|
||||
|
||||
|
||||
cdef class Parser:
|
||||
def __init__(self, model_dir):
|
||||
assert os.path.exists(model_dir) and os.path.isdir(model_dir)
|
||||
self.cfg = Config.read(model_dir, 'config')
|
||||
self.extractor = Extractor(get_templates(self.cfg.features))
|
||||
self.moves = TransitionSystem(self.cfg.left_labels, self.cfg.right_labels)
|
||||
|
||||
self.model = LinearModel(self.moves.n_moves, self.extractor.n_templ)
|
||||
if os.path.exists(pjoin(model_dir, 'model')):
|
||||
self.model.load(pjoin(model_dir, 'model'))
|
||||
|
||||
cpdef int parse(self, Tokens tokens) except -1:
|
||||
cdef:
|
||||
Feature* feats
|
||||
weight_t* scores
|
||||
|
||||
cdef atom_t[CONTEXT_SIZE] context
|
||||
cdef int n_feats
|
||||
cdef Pool mem = Pool()
|
||||
cdef State* state = init_state(mem, tokens.length)
|
||||
while not is_final(state):
|
||||
fill_context(context, state, tokens.data) # TODO
|
||||
feats = self.extractor.get_feats(context, &n_feats)
|
||||
scores = self.model.get_scores(feats, n_feats)
|
||||
|
||||
guess = self.moves.best_valid(scores, state)
|
||||
|
||||
self.moves.transition(state, guess)
|
||||
# TODO output
|
||||
|
||||
def train_sent(self, Tokens tokens, list gold_heads, list gold_labels):
|
||||
cdef:
|
||||
Feature* feats
|
||||
weight_t* scores
|
||||
|
||||
cdef int n_feats
|
||||
cdef atom_t[CONTEXT_SIZE] context
|
||||
cdef Pool mem = Pool()
|
||||
cdef State* state = init_state(mem, tokens.length)
|
||||
|
||||
while not is_final(state):
|
||||
fill_context(context, state, tokens.data) # TODO
|
||||
feats = self.extractor.get_feats(context, &n_feats)
|
||||
scores = self.model.get_scores(feats, n_feats)
|
||||
|
||||
guess = self.moves.best_valid(scores, state)
|
||||
best = self.moves.best_gold(scores, state, gold_heads, gold_labels)
|
||||
|
||||
counts = {guess: {}, best: {}}
|
||||
if guess != best:
|
||||
count_feats(counts[guess], feats, n_feats, -1)
|
||||
count_feats(counts[best], feats, n_feats, 1)
|
||||
self.model.update(counts)
|
||||
|
||||
self.moves.transition(state, guess)
|
||||
|
||||
|
||||
cdef int fill_context(atom_t* context, State* s, TokenC* sent) except -1:
|
||||
pass
|
11420
spacy/syntax/_parse_features.cpp
Normal file
11420
spacy/syntax/_parse_features.cpp
Normal file
File diff suppressed because it is too large
Load Diff
123
spacy/syntax/_parse_features.pxd
Normal file
123
spacy/syntax/_parse_features.pxd
Normal file
|
@ -0,0 +1,123 @@
|
|||
from thinc.typedefs cimport atom_t
|
||||
|
||||
from ._state cimport State
|
||||
|
||||
|
||||
cdef int fill_context(atom_t* context, State* state) except -1
|
||||
# Context elements
|
||||
|
||||
# Ensure each token's attributes are listed: w, p, c, c6, c4. The order
|
||||
# is referenced by incrementing the enum...
|
||||
|
||||
# Tokens are listed in left-to-right order.
|
||||
#cdef size_t* SLOTS = [
|
||||
# S2w, S1w,
|
||||
# S0l0w, S0l2w, S0lw,
|
||||
# S0w,
|
||||
# S0r0w, S0r2w, S0rw,
|
||||
# N0l0w, N0l2w, N0lw,
|
||||
# N0w, N1w, N2w, N3w, 0
|
||||
#]
|
||||
|
||||
# NB: The order of the enum is _NOT_ arbitrary!!
|
||||
cpdef enum:
|
||||
S2w
|
||||
S2p
|
||||
S2c
|
||||
S2c4
|
||||
S2c6
|
||||
S2L
|
||||
|
||||
S1w
|
||||
S1p
|
||||
S1c
|
||||
S1c4
|
||||
S1c6
|
||||
S1L
|
||||
|
||||
S1rw
|
||||
S1rp
|
||||
S1rc
|
||||
S1rc4
|
||||
S1rc6
|
||||
S1rL
|
||||
|
||||
S0lw
|
||||
S0lp
|
||||
S0lc
|
||||
S0lc4
|
||||
S0lc6
|
||||
S0lL
|
||||
|
||||
S0l2w
|
||||
S0l2p
|
||||
S0l2c
|
||||
S0l2c4
|
||||
S0l2c6
|
||||
S0l2L
|
||||
|
||||
S0w
|
||||
S0p
|
||||
S0c
|
||||
S0c4
|
||||
S0c6
|
||||
S0L
|
||||
|
||||
S0r2w
|
||||
S0r2p
|
||||
S0r2c
|
||||
S0r2c4
|
||||
S0r2c6
|
||||
S0r2L
|
||||
|
||||
S0rw
|
||||
S0rp
|
||||
S0rc
|
||||
S0rc4
|
||||
S0rc6
|
||||
S0rL
|
||||
|
||||
N0l2w
|
||||
N0l2p
|
||||
N0l2c
|
||||
N0l2c4
|
||||
N0l2c6
|
||||
N0l2L
|
||||
|
||||
N0lw
|
||||
N0lp
|
||||
N0lc
|
||||
N0lc4
|
||||
N0lc6
|
||||
N0lL
|
||||
|
||||
N0w
|
||||
N0p
|
||||
N0c
|
||||
N0c4
|
||||
N0c6
|
||||
N0L
|
||||
|
||||
N1w
|
||||
N1p
|
||||
N1c
|
||||
N1c4
|
||||
N1c6
|
||||
N1L
|
||||
|
||||
N2w
|
||||
N2p
|
||||
N2c
|
||||
N2c4
|
||||
N2c6
|
||||
N2L
|
||||
|
||||
# Misc features at the end
|
||||
dist
|
||||
N0lv
|
||||
S0lv
|
||||
S0rv
|
||||
S1lv
|
||||
S1rv
|
||||
|
||||
CONTEXT_SIZE
|
212
spacy/syntax/_parse_features.pyx
Normal file
212
spacy/syntax/_parse_features.pyx
Normal file
|
@ -0,0 +1,212 @@
|
|||
# cython: profile=True
|
||||
"""
|
||||
Fill an array, context, with every _atomic_ value our features reference.
|
||||
We then write the _actual features_ as tuples of the atoms. The machinery
|
||||
that translates from the tuples to feature-extractors (which pick the values
|
||||
out of "context") is in features/extractor.pyx
|
||||
|
||||
The atomic feature names are listed in a big enum, so that the feature tuples
|
||||
can refer to them.
|
||||
"""
|
||||
from itertools import combinations
|
||||
|
||||
from ..tokens cimport TokenC
|
||||
from ._state cimport State
|
||||
from ._state cimport get_s2, get_s1, get_s0, get_n0, get_n1, get_n2
|
||||
from ._state cimport get_left, get_right
|
||||
|
||||
|
||||
cdef inline void fill_token(atom_t* context, const TokenC* token) nogil:
|
||||
context[0] = token.lex.sic
|
||||
context[1] = token.pos
|
||||
context[2] = token.lex.cluster
|
||||
# We've read in the string little-endian, so now we can take & (2**n)-1
|
||||
# to get the first n bits of the cluster.
|
||||
# e.g. s = "1110010101"
|
||||
# s = ''.join(reversed(s))
|
||||
# first_4_bits = int(s, 2)
|
||||
# print first_4_bits
|
||||
# 5
|
||||
# print "{0:b}".format(prefix).ljust(4, '0')
|
||||
# 1110
|
||||
# What we're doing here is picking a number where all bits are 1, e.g.
|
||||
# 15 is 1111, 63 is 111111 and doing bitwise AND, so getting all bits in
|
||||
# the source that are set to 1.
|
||||
context[3] = token.lex.cluster & 63
|
||||
context[4] = token.lex.cluster & 15
|
||||
context[5] = token.dep_tag
|
||||
|
||||
|
||||
cdef int fill_context(atom_t* context, State* state) except -1:
|
||||
# This fills in the basic properties of each of our "slot" tokens, e.g.
|
||||
# word on top of the stack, word at the front of the buffer, etc.
|
||||
cdef TokenC* n1 = get_n1(state)
|
||||
fill_token(&context[S2w], get_s2(state))
|
||||
fill_token(&context[S1w], get_s1(state))
|
||||
#fill_token(&context[S1rw], get_right(state, get_s1(state), 0))
|
||||
fill_token(&context[S0lw], get_left(state, get_s0(state), 0))
|
||||
fill_token(&context[S0l2w], get_left(state, get_s0(state), 1))
|
||||
fill_token(&context[S0w], get_s0(state))
|
||||
#fill_token(&context[S0r2w], get_right(state, get_s0(state), 1))
|
||||
fill_token(&context[S0rw], get_right(state, get_s0(state), 0))
|
||||
#fill_token(&context[N0lw], get_left(state, get_n0(state), 0))
|
||||
#fill_token(&context[N0l2w], get_left(state, get_n0(state), 1))
|
||||
fill_token(&context[N0w], get_n0(state))
|
||||
#fill_token(&context[N1w], get_n1(state))
|
||||
#fill_token(&context[N2w], get_n2(state))
|
||||
|
||||
#if state.stack_len >= 1:
|
||||
# context[dist] = state.stack[0] - state.sent
|
||||
#else:
|
||||
# context[dist] = 0
|
||||
#context[N0lv] = 0
|
||||
#context[S0lv] = 0
|
||||
#context[S0rv] = 0
|
||||
#context[S1lv] = 0
|
||||
#context[S1rv] = 0
|
||||
|
||||
|
||||
arc_eager = (
|
||||
(S0w, S0p),
|
||||
(S0w,),
|
||||
(S0p,),
|
||||
(N0w, N0p),
|
||||
(N0w,),
|
||||
(N0p,),
|
||||
(N1w, N1p),
|
||||
(N1w,),
|
||||
(N1p,),
|
||||
(N2w, N2p),
|
||||
(N2w,),
|
||||
(N2p,),
|
||||
|
||||
(S0w, S0p, N0w, N0p),
|
||||
(S0w, S0p, N0w),
|
||||
(S0w, N0w, N0p),
|
||||
(S0w, S0p, N0p),
|
||||
(S0p, N0w, N0p),
|
||||
(S0w, N0w),
|
||||
(S0p, N0p),
|
||||
(N0p, N1p),
|
||||
(N0p, N1p, N2p),
|
||||
(S0p, N0p, N1p),
|
||||
(S1p, S0p, N0p),
|
||||
(S0p, S0lp, N0p),
|
||||
(S0p, S0rp, N0p),
|
||||
(S0p, N0p, N0lp),
|
||||
(dist, S0w),
|
||||
(dist, S0p),
|
||||
(dist, N0w),
|
||||
(dist, N0p),
|
||||
(dist, S0w, N0w),
|
||||
(dist, S0p, N0p),
|
||||
(S0w, S0rv),
|
||||
(S0p, S0rv),
|
||||
(S0w, S0lv),
|
||||
(S0p, S0lv),
|
||||
(N0w, N0lv),
|
||||
(N0p, N0lv),
|
||||
(S1w,),
|
||||
(S1p,),
|
||||
(S0lw,),
|
||||
(S0lp,),
|
||||
(S0rw,),
|
||||
(S0rp,),
|
||||
(N0lw,),
|
||||
(N0lp,),
|
||||
(S2w,),
|
||||
(S2p,),
|
||||
(S0l2w,),
|
||||
(S0l2p,),
|
||||
(S0r2w,),
|
||||
(S0r2p,),
|
||||
(N0l2w,),
|
||||
(N0l2p,),
|
||||
(S0p, S0lp, S0l2p),
|
||||
(S0p, S0rp, S0r2p),
|
||||
(S0p, S1p, S2p),
|
||||
(N0p, N0lp, N0l2p),
|
||||
(S0L,),
|
||||
(S0lL,),
|
||||
(S0rL,),
|
||||
(N0lL,),
|
||||
(S1L,),
|
||||
(S0l2L,),
|
||||
(S0r2L,),
|
||||
(N0l2L,),
|
||||
(S0w, S0rL, S0r2L),
|
||||
(S0p, S0rL, S0r2L),
|
||||
(S0w, S0lL, S0l2L),
|
||||
(S0p, S0lL, S0l2L),
|
||||
(N0w, N0lL, N0l2L),
|
||||
(N0p, N0lL, N0l2L),
|
||||
)
|
||||
|
||||
|
||||
label_sets = (
|
||||
(S0w, S0lL, S0l2L),
|
||||
(S0p, S0rL, S0r2L),
|
||||
(S0p, S0lL, S0l2L),
|
||||
(S0p, S0rL, S0r2L),
|
||||
(N0w, N0lL, N0l2L),
|
||||
(N0p, N0lL, N0l2L),
|
||||
)
|
||||
|
||||
extra_labels = (
|
||||
(S0p, S0lL, S0lp),
|
||||
(S0p, S0lL, S0l2L),
|
||||
(S0p, S0rL, S0rp),
|
||||
(S0p, S0rL, S0r2L),
|
||||
(S0p, S0lL, S0rL),
|
||||
(S1p, S0L, S0rL),
|
||||
(S1p, S0L, S0lL),
|
||||
)
|
||||
|
||||
|
||||
# Koo et al (2008) dependency features, using Brown clusters.
|
||||
clusters = (
|
||||
# Koo et al have (head, child) --- we have S0, N0 for both.
|
||||
(S0c4, N0c4),
|
||||
(S0c6, N0c6),
|
||||
(S0c, N0c),
|
||||
(S0p, N0c4),
|
||||
(S0p, N0c6),
|
||||
(S0p, N0c),
|
||||
(S0c4, N0p),
|
||||
(S0c6, N0p),
|
||||
(S0c, N0p),
|
||||
# Siblings --- right arc
|
||||
(S0c4, S0rc4, N0c4),
|
||||
(S0c6, S0rc6, N0c6),
|
||||
(S0p, S0rc4, N0c4),
|
||||
(S0c4, S0rp, N0c4),
|
||||
(S0c4, S0rc4, N0p),
|
||||
# Siblings --- left arc
|
||||
(S0c4, N0lc4, N0c4),
|
||||
(S0c6, N0c6, N0c6),
|
||||
(S0c4, N0lc4, N0p),
|
||||
(S0c4, N0lp, N0c4),
|
||||
(S0p, N0lc4, N0c4),
|
||||
# Grand-child, right-arc
|
||||
(S1c4, S0c4, N0c4),
|
||||
(S1c6, S0c6, N0c6),
|
||||
(S1p, S0c4, N0c4),
|
||||
(S1c4, S0p, N0c4),
|
||||
(S1c4, S0c4, N0p),
|
||||
# Grand-child, left-arc
|
||||
(S0lc4, S0c4, N0c4),
|
||||
(S0lc6, S0c6, N0c6),
|
||||
(S0lp, S0c4, N0c4),
|
||||
(S0lc4, S0p, N0c4),
|
||||
(S0lc4, S0c4, N0p)
|
||||
)
|
||||
|
||||
|
||||
def pos_bigrams():
|
||||
kernels = [S2w, S1w, S0w, S0lw, S0rw, N0w, N0lw, N1w]
|
||||
bitags = []
|
||||
for t1, t2 in combinations(kernels, 2):
|
||||
feat = (t1 + 1, t2 + 1)
|
||||
bitags.append(feat)
|
||||
print "Adding %d bitags" % len(bitags)
|
||||
return tuple(bitags)
|
8325
spacy/syntax/_state.cpp
Normal file
8325
spacy/syntax/_state.cpp
Normal file
File diff suppressed because it is too large
Load Diff
85
spacy/syntax/_state.pxd
Normal file
85
spacy/syntax/_state.pxd
Normal file
|
@ -0,0 +1,85 @@
|
|||
from libc.stdint cimport uint32_t
|
||||
|
||||
from cymem.cymem cimport Pool
|
||||
|
||||
from ..tokens cimport TokenC
|
||||
|
||||
|
||||
cdef struct State:
|
||||
TokenC* sent
|
||||
int i
|
||||
int sent_len
|
||||
int stack_len
|
||||
|
||||
|
||||
cdef int add_dep(State *s, TokenC* head, TokenC* child, int label) except -1
|
||||
|
||||
|
||||
cdef TokenC* pop_stack(State *s) except NULL
|
||||
cdef int push_stack(State *s) except -1
|
||||
|
||||
|
||||
cdef inline bint has_head(const TokenC* t) nogil:
|
||||
return t.head != 0
|
||||
|
||||
|
||||
cdef inline int get_idx(const State* s, const TokenC* t) nogil:
|
||||
return t - s.sent
|
||||
|
||||
|
||||
cdef inline TokenC* get_n0(const State* s) nogil:
|
||||
return &s.sent[s.i]
|
||||
|
||||
|
||||
cdef inline TokenC* get_n1(const State* s) nogil:
|
||||
if s.i < (s.sent_len - 1):
|
||||
return &s.sent[s.i+1]
|
||||
else:
|
||||
return s.sent - 1
|
||||
|
||||
|
||||
cdef inline TokenC* get_n2(const State* s) nogil:
|
||||
return &s.sent[s.i+2]
|
||||
|
||||
|
||||
cdef inline TokenC* get_s0(const State *s) nogil:
|
||||
return s.stack[0]
|
||||
|
||||
|
||||
cdef inline TokenC* get_s1(const State *s) nogil:
|
||||
# Rely on our padding to ensure we don't go out of bounds here
|
||||
cdef TokenC** s1 = s.stack - 1
|
||||
return s1[0]
|
||||
|
||||
|
||||
cdef inline TokenC* get_s2(const State *s) nogil:
|
||||
# Rely on our padding to ensure we don't go out of bounds here
|
||||
cdef TokenC** s2 = s.stack - 2
|
||||
return s2[0]
|
||||
|
||||
cdef TokenC* get_right(State* s, TokenC* head, int idx) nogil
|
||||
cdef TokenC* get_left(State* s, TokenC* head, int idx) nogil
|
||||
|
||||
cdef inline bint at_eol(const State *s) nogil:
|
||||
return s.i >= s.sent_len
|
||||
|
||||
|
||||
cdef inline bint is_final(const State *s) nogil:
|
||||
return at_eol(s) # The stack will be attached to root anyway
|
||||
|
||||
|
||||
cdef int children_in_buffer(const State *s, const TokenC* target, list gold) except -1
|
||||
cdef int head_in_buffer(const State *s, const TokenC* target, list gold) except -1
|
||||
cdef int children_in_stack(const State *s, const TokenC* target, list gold) except -1
|
||||
cdef int head_in_stack(const State *s, const TokenC*, list gold) except -1
|
||||
|
||||
cdef State* init_state(Pool mem, TokenC* sent, const int sent_length) except NULL
|
||||
|
||||
|
||||
|
||||
cdef inline uint32_t _nth_significant_bit(uint32_t bits, int n) nogil:
|
||||
cdef int i
|
||||
for i in range(32):
|
||||
if bits & (1 << i):
|
||||
return i
|
||||
return 0
|
129
spacy/syntax/_state.pyx
Normal file
129
spacy/syntax/_state.pyx
Normal file
|
@ -0,0 +1,129 @@
|
|||
# cython: profile=True
|
||||
from libc.string cimport memmove
|
||||
from cymem.cymem cimport Pool
|
||||
|
||||
from ..lexeme cimport EMPTY_LEXEME
|
||||
|
||||
|
||||
cdef int add_dep(State *s, TokenC* head, TokenC* child, int label) except -1:
|
||||
child.head = head - child
|
||||
child.dep_tag = label
|
||||
# Keep a bit-vector tracking child dependencies. If a word has a child at
|
||||
# offset i from it, set that bit (tracking left and right separately)
|
||||
if child > head:
|
||||
head.r_kids |= 1 << child.head
|
||||
else:
|
||||
head.l_kids |= 1 << (-child.head)
|
||||
|
||||
|
||||
cdef TokenC* pop_stack(State *s) except NULL:
|
||||
assert s.stack_len >= 1
|
||||
cdef TokenC* top = s.stack[0]
|
||||
s.stack -= 1
|
||||
s.stack_len -= 1
|
||||
return top
|
||||
|
||||
|
||||
cdef int push_stack(State *s) except -1:
|
||||
assert s.i < s.sent_len
|
||||
s.stack += 1
|
||||
s.stack[0] = &s.sent[s.i]
|
||||
s.stack_len += 1
|
||||
s.i += 1
|
||||
|
||||
|
||||
cdef int children_in_buffer(const State *s, const TokenC* target, list gold) except -1:
|
||||
# Golds holds an array of head offsets --- the head of word i is i - golds[i]
|
||||
# Iterate over the tokens of the queue, and check whether their gold head is
|
||||
# our target
|
||||
cdef int i
|
||||
cdef int n = 0
|
||||
cdef TokenC* buff_word
|
||||
cdef TokenC* buff_head
|
||||
cdef int buff_word_head_offset
|
||||
for i in range(s.i, s.sent_len):
|
||||
buff_word = &s.sent[i]
|
||||
buff_word_head_offset = gold[i]
|
||||
buff_head = buff_word + buff_word_head_offset
|
||||
if buff_head == target:
|
||||
n += 1
|
||||
return n
|
||||
|
||||
|
||||
cdef int head_in_buffer(const State *s, const TokenC* target, list gold) except -1:
|
||||
cdef int target_idx = get_idx(s, target)
|
||||
cdef int target_head_idx = target_idx + gold[target_idx]
|
||||
return target_head_idx >= s.i
|
||||
|
||||
|
||||
cdef int children_in_stack(const State *s, const TokenC* target, list gold) except -1:
|
||||
if s.stack_len == 0:
|
||||
return 0
|
||||
cdef int i
|
||||
cdef int n = 0
|
||||
cdef const TokenC* stack_word
|
||||
cdef const TokenC* stack_word_head
|
||||
cdef int stack_word_head_offset
|
||||
for i in range(s.stack_len):
|
||||
stack_word = (s.stack - i)[0]
|
||||
stack_word_head_offset = gold[get_idx(s, stack_word)]
|
||||
stack_word_head = (s.stack + stack_word_head_offset)[0]
|
||||
if stack_word_head == target:
|
||||
n += 1
|
||||
return n
|
||||
|
||||
|
||||
cdef int head_in_stack(const State *s, const TokenC* target, list gold) except -1:
|
||||
if s.stack_len == 0:
|
||||
return 0
|
||||
cdef int head_offset = gold[get_idx(s, target)]
|
||||
cdef const TokenC* target_head = target + head_offset
|
||||
cdef int i
|
||||
for i in range(s.stack_len):
|
||||
if target_head == (s.stack - i)[0]:
|
||||
return 1
|
||||
return 0
|
||||
|
||||
|
||||
cdef TokenC* get_left(State* s, TokenC* head, int idx) nogil:
|
||||
cdef uint32_t kids = head.l_kids
|
||||
if kids == 0:
|
||||
return s.sent - 1
|
||||
cdef int offset = _nth_significant_bit(kids, idx)
|
||||
cdef TokenC* child = head - offset
|
||||
if child >= s.sent:
|
||||
return child
|
||||
else:
|
||||
return s.sent - 1
|
||||
|
||||
|
||||
cdef TokenC* get_right(State* s, TokenC* head, int idx) nogil:
|
||||
cdef uint32_t kids = head.r_kids
|
||||
if kids == 0:
|
||||
return s.sent - 1
|
||||
cdef int offset = _nth_significant_bit(kids, idx)
|
||||
cdef TokenC* child = head + offset
|
||||
if child < (s.sent + s.sent_len):
|
||||
return child
|
||||
else:
|
||||
return s.sent - 1
|
||||
|
||||
|
||||
DEF PADDING = 5
|
||||
|
||||
|
||||
cdef State* init_state(Pool mem, TokenC* sent, const int sent_length) except NULL:
|
||||
cdef int padded_len = sent_length + PADDING + PADDING
|
||||
cdef State* s = <State*>mem.alloc(1, sizeof(State))
|
||||
s.stack = <TokenC**>mem.alloc(padded_len, sizeof(TokenC*))
|
||||
cdef TokenC* eol_token = sent - 1
|
||||
for i in range(PADDING):
|
||||
# sent should be padded, with a suitable sentinel token here
|
||||
s.stack[0] = eol_token
|
||||
s.stack += 1
|
||||
s.stack[0] = eol_token
|
||||
s.sent = sent
|
||||
s.stack_len = 0
|
||||
s.i = 0
|
||||
s.sent_len = sent_length
|
||||
return s
|
10383
spacy/syntax/arc_eager.cpp
Normal file
10383
spacy/syntax/arc_eager.cpp
Normal file
File diff suppressed because it is too large
Load Diff
|
@ -14,6 +14,7 @@ cdef struct Transition:
|
|||
cdef class TransitionSystem:
|
||||
cdef Pool mem
|
||||
cdef readonly int n_moves
|
||||
cdef dict label_ids
|
||||
|
||||
cdef const Transition* _moves
|
||||
|
196
spacy/syntax/arc_eager.pyx
Normal file
196
spacy/syntax/arc_eager.pyx
Normal file
|
@ -0,0 +1,196 @@
|
|||
# cython: profile=True
|
||||
from ._state cimport State
|
||||
from ._state cimport has_head, get_idx, get_s0, get_n0
|
||||
from ._state cimport is_final, at_eol, pop_stack, push_stack, add_dep
|
||||
from ._state cimport head_in_buffer, children_in_buffer
|
||||
from ._state cimport head_in_stack, children_in_stack
|
||||
|
||||
from ..tokens cimport TokenC
|
||||
|
||||
|
||||
cdef enum:
|
||||
SHIFT
|
||||
REDUCE
|
||||
LEFT
|
||||
RIGHT
|
||||
N_MOVES
|
||||
|
||||
|
||||
cdef inline bint _can_shift(const State* s) nogil:
|
||||
return not at_eol(s)
|
||||
|
||||
|
||||
cdef inline bint _can_right(const State* s) nogil:
|
||||
return s.stack_len >= 1 and not at_eol(s)
|
||||
|
||||
|
||||
cdef inline bint _can_left(const State* s) nogil:
|
||||
return s.stack_len >= 1 and not has_head(get_s0(s))
|
||||
|
||||
|
||||
cdef inline bint _can_reduce(const State* s) nogil:
|
||||
return s.stack_len >= 2 and has_head(get_s0(s))
|
||||
|
||||
|
||||
cdef int _shift_cost(const State* s, list gold) except -1:
|
||||
assert not at_eol(s)
|
||||
cost = 0
|
||||
cost += head_in_stack(s, get_n0(s), gold)
|
||||
cost += children_in_stack(s, get_n0(s), gold)
|
||||
return cost
|
||||
|
||||
|
||||
cdef int _right_cost(const State* s, list gold) except -1:
|
||||
assert s.stack_len >= 1
|
||||
cdef int s0_idx = get_idx(s, get_s0(s))
|
||||
cost = 0
|
||||
if _gold_dep(s, get_s0(s), get_n0(s), gold):
|
||||
return cost
|
||||
cost += head_in_buffer(s, get_n0(s), gold)
|
||||
cost += children_in_stack(s, get_n0(s), gold)
|
||||
cost += head_in_stack(s, get_n0(s), gold)
|
||||
return cost
|
||||
|
||||
|
||||
cdef int _left_cost(const State* s, list gold) except -1:
|
||||
assert s.stack_len >= 1
|
||||
cost = 0
|
||||
if _gold_dep(s, get_n0(s), get_s0(s), gold):
|
||||
return cost
|
||||
|
||||
cost += head_in_buffer(s, get_s0(s), gold)
|
||||
cost += children_in_buffer(s, get_s0(s), gold)
|
||||
return cost
|
||||
|
||||
|
||||
cdef int _reduce_cost(const State* s, list gold) except -1:
|
||||
return children_in_buffer(s, get_s0(s), gold)
|
||||
|
||||
|
||||
cdef int _gold_dep(const State* s, const TokenC* head, const TokenC* child,
|
||||
list gold_offsets) except -1:
|
||||
cdef int head_idx = get_idx(s, head)
|
||||
cdef int child_idx = get_idx(s, child)
|
||||
return child_idx + gold_offsets[child_idx] == head_idx
|
||||
|
||||
|
||||
cdef class TransitionSystem:
|
||||
def __init__(self, list left_labels, list right_labels):
|
||||
self.mem = Pool()
|
||||
if 'ROOT' in right_labels:
|
||||
right_labels.pop(right_labels.index('ROOT'))
|
||||
if 'ROOT' in left_labels:
|
||||
left_labels.pop(left_labels.index('ROOT'))
|
||||
self.n_moves = 2 + len(left_labels) + len(right_labels)
|
||||
moves = <Transition*>self.mem.alloc(self.n_moves, sizeof(Transition))
|
||||
cdef int i = 0
|
||||
moves[i].move = SHIFT
|
||||
moves[i].label = 0
|
||||
i += 1
|
||||
moves[i].move = REDUCE
|
||||
moves[i].label = 0
|
||||
i += 1
|
||||
self.label_ids = {'ROOT': 0}
|
||||
cdef int label_id
|
||||
for label_str in left_labels:
|
||||
label_id = self.label_ids.setdefault(label_str, len(self.label_ids))
|
||||
moves[i].move = LEFT
|
||||
moves[i].label = label_id
|
||||
i += 1
|
||||
for label_str in right_labels:
|
||||
label_id = self.label_ids.setdefault(label_str, len(self.label_ids))
|
||||
moves[i].move = RIGHT
|
||||
moves[i].label = label_id
|
||||
i += 1
|
||||
self._moves = moves
|
||||
|
||||
cdef int transition(self, State *s, const int clas) except -1:
|
||||
cdef const Transition* t = &self._moves[clas]
|
||||
if t.move == SHIFT:
|
||||
push_stack(s)
|
||||
elif t.move == LEFT:
|
||||
add_dep(s, get_n0(s), get_s0(s), t.label)
|
||||
pop_stack(s)
|
||||
elif t.move == RIGHT:
|
||||
add_dep(s, get_s0(s), get_n0(s), t.label)
|
||||
push_stack(s)
|
||||
elif t.move == REDUCE:
|
||||
pop_stack(s)
|
||||
else:
|
||||
raise StandardError(t.move)
|
||||
|
||||
cdef int best_valid(self, const weight_t* scores, const State* s) except -1:
|
||||
cdef bint[N_MOVES] valid
|
||||
valid[SHIFT] = _can_shift(s)
|
||||
valid[LEFT] = _can_left(s)
|
||||
valid[RIGHT] = _can_right(s)
|
||||
valid[REDUCE] = _can_reduce(s)
|
||||
|
||||
cdef int best = -1
|
||||
cdef weight_t score = -90000
|
||||
cdef int i
|
||||
for i in range(self.n_moves):
|
||||
if valid[self._moves[i].move] and scores[i] > score:
|
||||
best = i
|
||||
score = scores[i]
|
||||
return best
|
||||
|
||||
cdef int best_gold(self, const weight_t* scores, const State* s,
|
||||
list gold_heads, list label_strings) except -1:
|
||||
gold_labels = [self.label_ids[label_str] for label_str in label_strings]
|
||||
cdef int[N_MOVES] unl_costs
|
||||
unl_costs[SHIFT] = _shift_cost(s, gold_heads) if _can_shift(s) else -1
|
||||
unl_costs[LEFT] = _left_cost(s, gold_heads) if _can_left(s) else -1
|
||||
unl_costs[RIGHT] = _right_cost(s, gold_heads) if _can_right(s) else -1
|
||||
unl_costs[REDUCE] = _reduce_cost(s, gold_heads) if _can_reduce(s) else -1
|
||||
|
||||
cdef int cost
|
||||
cdef int move
|
||||
cdef int label
|
||||
cdef int best = -1
|
||||
cdef weight_t score = -9000
|
||||
cdef int i
|
||||
for i in range(self.n_moves):
|
||||
move = self._moves[i].move
|
||||
label = self._moves[i].label
|
||||
if unl_costs[move] == 0:
|
||||
if move == SHIFT or move == REDUCE:
|
||||
cost = 0
|
||||
elif move == LEFT:
|
||||
if _gold_dep(s, get_n0(s), get_s0(s), gold_heads):
|
||||
cost = label != gold_labels[get_idx(s, get_s0(s))]
|
||||
else:
|
||||
cost = 0
|
||||
elif move == RIGHT:
|
||||
if _gold_dep(s, get_s0(s), get_n0(s), gold_heads):
|
||||
cost = label != gold_labels[s.i]
|
||||
else:
|
||||
cost = 0
|
||||
else:
|
||||
raise StandardError("Unknown Move")
|
||||
if cost == 0 and (best == -1 or scores[i] > score):
|
||||
best = i
|
||||
score = scores[i]
|
||||
|
||||
if best < 0:
|
||||
for i in range(self.n_moves):
|
||||
if self._moves[i].move == LEFT:
|
||||
print self._moves[i].label,
|
||||
print
|
||||
print _gold_dep(s, get_n0(s), get_s0(s), gold_heads)
|
||||
print gold_labels[get_idx(s, get_s0(s))]
|
||||
print unl_costs[LEFT]
|
||||
print "S0:"
|
||||
print "Head:", gold_heads[get_idx(s, get_s0(s))]
|
||||
print "h. in b.", head_in_buffer(s, get_s0(s), gold_heads)
|
||||
print "c. in b.", children_in_buffer(s, get_s0(s), gold_heads)
|
||||
print "h. in s.", head_in_stack(s, get_s0(s), gold_heads)
|
||||
print "c. in s.", children_in_stack(s, get_s0(s), gold_heads)
|
||||
print "N0:"
|
||||
print "Head:", gold_heads[get_idx(s, get_n0(s))]
|
||||
print "h. in b.", head_in_buffer(s, get_n0(s), gold_heads)
|
||||
print "c. in b.", children_in_buffer(s, get_n0(s), gold_heads)
|
||||
print "h. in s.", head_in_stack(s, get_n0(s), gold_heads)
|
||||
print "c. in s.", children_in_stack(s, get_n0(s), gold_heads)
|
||||
raise StandardError
|
||||
return best
|
9585
spacy/syntax/parser.cpp
Normal file
9585
spacy/syntax/parser.cpp
Normal file
File diff suppressed because it is too large
Load Diff
|
@ -6,10 +6,10 @@ from .arc_eager cimport TransitionSystem
|
|||
from ..tokens cimport Tokens, TokenC
|
||||
|
||||
|
||||
cdef class Parser:
|
||||
cdef class GreedyParser:
|
||||
cdef object cfg
|
||||
cdef Extractor extractor
|
||||
cdef LinearModel model
|
||||
cdef readonly LinearModel model
|
||||
cdef TransitionSystem moves
|
||||
|
||||
cpdef int parse(self, Tokens tokens) except -1
|
113
spacy/syntax/parser.pyx
Normal file
113
spacy/syntax/parser.pyx
Normal file
|
@ -0,0 +1,113 @@
|
|||
# cython: profile=True
|
||||
"""
|
||||
MALT-style dependency parser
|
||||
"""
|
||||
from __future__ import unicode_literals
|
||||
cimport cython
|
||||
import random
|
||||
import os.path
|
||||
from os.path import join as pjoin
|
||||
import shutil
|
||||
import json
|
||||
|
||||
from cymem.cymem cimport Pool, Address
|
||||
from thinc.typedefs cimport weight_t, class_t, feat_t, atom_t
|
||||
|
||||
|
||||
from util import Config
|
||||
|
||||
from thinc.features cimport Extractor
|
||||
from thinc.features cimport Feature
|
||||
from thinc.features cimport count_feats
|
||||
|
||||
from thinc.learner cimport LinearModel
|
||||
|
||||
from ..tokens cimport Tokens, TokenC
|
||||
|
||||
from .arc_eager cimport TransitionSystem
|
||||
|
||||
from ._state cimport init_state, State, is_final, get_idx, get_s0, get_s1
|
||||
|
||||
from . import _parse_features
|
||||
from ._parse_features cimport fill_context, CONTEXT_SIZE
|
||||
|
||||
|
||||
DEF CONTEXT_SIZE = 50
|
||||
|
||||
|
||||
DEBUG = False
|
||||
def set_debug(val):
|
||||
global DEBUG
|
||||
DEBUG = val
|
||||
|
||||
|
||||
cdef unicode print_state(State* s, list words):
|
||||
words = list(words) + ['EOL']
|
||||
top = words[get_idx(s, get_s0(s))]
|
||||
second = words[get_idx(s, get_s1(s))]
|
||||
n0 = words[s.i]
|
||||
n1 = words[s.i + 1]
|
||||
return ' '.join((second, top, '|', n0, n1))
|
||||
|
||||
|
||||
def get_templates(name):
|
||||
return _parse_features.arc_eager
|
||||
|
||||
|
||||
cdef class GreedyParser:
|
||||
def __init__(self, model_dir):
|
||||
assert os.path.exists(model_dir) and os.path.isdir(model_dir)
|
||||
self.cfg = Config.read(model_dir, 'config')
|
||||
self.extractor = Extractor(get_templates(self.cfg.features))
|
||||
self.moves = TransitionSystem(self.cfg.left_labels, self.cfg.right_labels)
|
||||
|
||||
self.model = LinearModel(self.moves.n_moves, self.extractor.n_templ)
|
||||
if os.path.exists(pjoin(model_dir, 'model')):
|
||||
self.model.load(pjoin(model_dir, 'model'))
|
||||
|
||||
cpdef int parse(self, Tokens tokens) except -1:
|
||||
cdef:
|
||||
Feature* feats
|
||||
const weight_t* scores
|
||||
|
||||
cdef atom_t[CONTEXT_SIZE] context
|
||||
cdef int n_feats
|
||||
cdef Pool mem = Pool()
|
||||
cdef State* state = init_state(mem, tokens.data, tokens.length)
|
||||
while not is_final(state):
|
||||
fill_context(context, state) # TODO
|
||||
feats = self.extractor.get_feats(context, &n_feats)
|
||||
scores = self.model.get_scores(feats, n_feats)
|
||||
|
||||
guess = self.moves.best_valid(scores, state)
|
||||
|
||||
self.moves.transition(state, guess)
|
||||
# TODO output
|
||||
|
||||
def train_sent(self, Tokens tokens, list gold_heads, list gold_labels):
|
||||
cdef:
|
||||
Feature* feats
|
||||
weight_t* scores
|
||||
|
||||
cdef int n_feats
|
||||
cdef atom_t[CONTEXT_SIZE] context
|
||||
cdef Pool mem = Pool()
|
||||
cdef State* state = init_state(mem, tokens.data, tokens.length)
|
||||
words = [t.string for t in tokens]
|
||||
while not is_final(state):
|
||||
fill_context(context, state)
|
||||
feats = self.extractor.get_feats(context, &n_feats)
|
||||
scores = self.model.get_scores(feats, n_feats)
|
||||
guess = self.moves.best_valid(scores, state)
|
||||
best = self.moves.best_gold(scores, state, gold_heads, gold_labels)
|
||||
counts = {guess: {}, best: {}}
|
||||
if guess != best:
|
||||
count_feats(counts[guess], feats, n_feats, -1)
|
||||
count_feats(counts[best], feats, n_feats, 1)
|
||||
self.model.update(counts)
|
||||
self.moves.transition(state, guess)
|
||||
cdef int i
|
||||
n_corr = 0
|
||||
for i in range(tokens.length):
|
||||
n_corr += state.sent[i].head == gold_heads[i]
|
||||
return n_corr
|
17
spacy/syntax/util.py
Normal file
17
spacy/syntax/util.py
Normal file
|
@ -0,0 +1,17 @@
|
|||
from os import path
|
||||
import json
|
||||
|
||||
class Config(object):
|
||||
def __init__(self, **kwargs):
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
@classmethod
|
||||
def write(cls, model_dir, name, **kwargs):
|
||||
open(path.join(model_dir, '%s.json' % name), 'w').write(json.dumps(kwargs))
|
||||
|
||||
@classmethod
|
||||
def read(cls, model_dir, name):
|
||||
return cls(**json.load(open(path.join(model_dir, '%s.json' % name))))
|
||||
|
||||
|
|
@ -1,3 +1,5 @@
|
|||
from libc.stdint cimport uint32_t
|
||||
|
||||
import numpy as np
|
||||
cimport numpy as np
|
||||
|
||||
|
@ -19,6 +21,10 @@ cdef struct TokenC:
|
|||
int pos
|
||||
int lemma
|
||||
int sense
|
||||
int head
|
||||
int dep_tag
|
||||
uint32_t l_kids
|
||||
uint32_t r_kids
|
||||
|
||||
|
||||
ctypedef const Lexeme* const_Lexeme_ptr
|
||||
|
|
Loading…
Reference in New Issue
Block a user