* Fix merge conflicts

This commit is contained in:
Matthew Honnibal 2015-06-23 17:28:00 +02:00
commit 6dbe182491
24 changed files with 982 additions and 688 deletions

View File

@ -48,7 +48,7 @@ def add_noise(orig, noise_level):
return ''.join(_corrupt(c, noise_level) for c in orig) return ''.join(_corrupt(c, noise_level) for c in orig)
def score_model(scorer, nlp, raw_text, annot_tuples): def score_model(scorer, nlp, raw_text, annot_tuples, verbose=False):
if raw_text is None: if raw_text is None:
tokens = nlp.tokenizer.tokens_from_list(annot_tuples[1]) tokens = nlp.tokenizer.tokens_from_list(annot_tuples[1])
else: else:
@ -57,7 +57,7 @@ def score_model(scorer, nlp, raw_text, annot_tuples):
nlp.entity(tokens) nlp.entity(tokens)
nlp.parser(tokens) nlp.parser(tokens)
gold = GoldParse(tokens, annot_tuples) gold = GoldParse(tokens, annot_tuples)
scorer.score(tokens, gold, verbose=False) scorer.score(tokens, gold, verbose=verbose)
def _merge_sents(sents): def _merge_sents(sents):
@ -78,7 +78,8 @@ def _merge_sents(sents):
def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic', def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic',
seed=0, gold_preproc=False, n_sents=0, corruption_level=0, seed=0, gold_preproc=False, n_sents=0, corruption_level=0,
beam_width=1): beam_width=1, verbose=False,
use_orig_arc_eager=False):
dep_model_dir = path.join(model_dir, 'deps') dep_model_dir = path.join(model_dir, 'deps')
pos_model_dir = path.join(model_dir, 'pos') pos_model_dir = path.join(model_dir, 'pos')
ner_model_dir = path.join(model_dir, 'ner') ner_model_dir = path.join(model_dir, 'ner')
@ -118,7 +119,8 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic',
for annot_tuples, ctnt in sents: for annot_tuples, ctnt in sents:
if len(annot_tuples[1]) == 1: if len(annot_tuples[1]) == 1:
continue continue
score_model(scorer, nlp, raw_text, annot_tuples) score_model(scorer, nlp, raw_text, annot_tuples,
verbose=verbose if itn >= 2 else False)
if raw_text is None: if raw_text is None:
words = add_noise(annot_tuples[1], corruption_level) words = add_noise(annot_tuples[1], corruption_level)
tokens = nlp.tokenizer.tokens_from_list(words) tokens = nlp.tokenizer.tokens_from_list(words)
@ -127,8 +129,12 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic',
tokens = nlp.tokenizer(raw_text) tokens = nlp.tokenizer(raw_text)
nlp.tagger(tokens) nlp.tagger(tokens)
gold = GoldParse(tokens, annot_tuples, make_projective=True) gold = GoldParse(tokens, annot_tuples, make_projective=True)
if not gold.is_projective:
raise Exception(
"Non-projective sentence in training, after we should "
"have enforced projectivity: %s" % annot_tuples
)
loss += nlp.parser.train(tokens, gold) loss += nlp.parser.train(tokens, gold)
nlp.entity.train(tokens, gold) nlp.entity.train(tokens, gold)
nlp.tagger.train(tokens, gold.tags) nlp.tagger.train(tokens, gold.tags)
random.shuffle(gold_tuples) random.shuffle(gold_tuples)
@ -203,20 +209,24 @@ def write_parses(Language, dev_loc, model_dir, out_loc, beam_width=None):
n_iter=("Number of training iterations", "option", "i", int), n_iter=("Number of training iterations", "option", "i", int),
beam_width=("Number of candidates to maintain in the beam", "option", "k", int), beam_width=("Number of candidates to maintain in the beam", "option", "k", int),
verbose=("Verbose error reporting", "flag", "v", bool), verbose=("Verbose error reporting", "flag", "v", bool),
debug=("Debug mode", "flag", "d", bool) debug=("Debug mode", "flag", "d", bool),
use_orig_arc_eager=("Use the original, monotonic arc-eager system", "flag", "m", bool)
) )
def main(train_loc, dev_loc, model_dir, n_sents=0, n_iter=15, out_loc="", verbose=False, def main(train_loc, dev_loc, model_dir, n_sents=0, n_iter=15, out_loc="", verbose=False,
debug=False, corruption_level=0.0, gold_preproc=False, beam_width=1, debug=False, corruption_level=0.0, gold_preproc=False, beam_width=1,
eval_only=False): eval_only=False, use_orig_arc_eager=False):
if use_orig_arc_eager:
English.ParserTransitionSystem = TreeArcEager
if not eval_only: if not eval_only:
gold_train = list(read_json_file(train_loc)) gold_train = list(read_json_file(train_loc))
train(English, gold_train, model_dir, train(English, gold_train, model_dir,
feat_set='basic' if not debug else 'debug', feat_set='basic' if not debug else 'debug',
gold_preproc=gold_preproc, n_sents=n_sents, gold_preproc=gold_preproc, n_sents=n_sents,
corruption_level=corruption_level, n_iter=n_iter, corruption_level=corruption_level, n_iter=n_iter,
beam_width=beam_width) beam_width=beam_width, verbose=verbose,
if out_loc: use_orig_arc_eager=use_orig_arc_eager)
write_parses(English, dev_loc, model_dir, out_loc, beam_width=beam_width) #if out_loc:
# write_parses(English, dev_loc, model_dir, out_loc, beam_width=beam_width)
scorer = evaluate(English, list(read_json_file(dev_loc)), scorer = evaluate(English, list(read_json_file(dev_loc)),
model_dir, gold_preproc=gold_preproc, verbose=verbose, model_dir, gold_preproc=gold_preproc, verbose=verbose,
beam_width=beam_width) beam_width=beam_width)

View File

@ -2,7 +2,7 @@ cython
cymem == 1.11 cymem == 1.11
pathlib pathlib
preshed == 0.37 preshed == 0.37
thinc == 1.76 thinc == 2.0
murmurhash == 0.24 murmurhash == 0.24
unidecode unidecode
numpy numpy

View File

@ -118,7 +118,7 @@ def run_setup(exts):
ext_modules=exts, ext_modules=exts,
license="Dual: Commercial or AGPL", license="Dual: Commercial or AGPL",
install_requires=['numpy', 'murmurhash', 'cymem >= 1.11', 'preshed == 0.37', install_requires=['numpy', 'murmurhash', 'cymem >= 1.11', 'preshed == 0.37',
'thinc == 1.76', "unidecode", 'wget', 'plac', 'six', 'thinc == 2.0', "unidecode", 'wget', 'plac', 'six',
'ujson'], 'ujson'],
setup_requires=["headers_workaround"], setup_requires=["headers_workaround"],
) )
@ -150,10 +150,12 @@ def main(modules, is_pypy):
MOD_NAMES = ['spacy.parts_of_speech', 'spacy.strings', MOD_NAMES = ['spacy.parts_of_speech', 'spacy.strings',
'spacy.lexeme', 'spacy.vocab', 'spacy.tokens', 'spacy.spans', 'spacy.lexeme', 'spacy.vocab', 'spacy.tokens', 'spacy.spans',
'spacy.morphology', 'spacy.morphology',
'spacy.syntax.stateclass',
'spacy._ml', 'spacy.tokenizer', 'spacy.en.attrs', 'spacy._ml', 'spacy.tokenizer', 'spacy.en.attrs',
'spacy.en.pos', 'spacy.syntax.parser', 'spacy.syntax._state', 'spacy.en.pos', 'spacy.syntax.parser',
'spacy.syntax.transition_system', 'spacy.syntax.transition_system',
'spacy.syntax.arc_eager', 'spacy.syntax._parse_features', 'spacy.syntax.arc_eager',
'spacy.syntax._parse_features',
'spacy.gold', 'spacy.orth', 'spacy.gold', 'spacy.orth',
'spacy.syntax.ner'] 'spacy.syntax.ner']

View File

@ -121,7 +121,7 @@ def _min_edit_path(cand_words, gold_words):
return prev_costs[n_gold], previous_row[-1] return prev_costs[n_gold], previous_row[-1]
def read_json_file(loc): def read_json_file(loc, docs_filter=None):
print loc print loc
if path.isdir(loc): if path.isdir(loc):
for filename in os.listdir(loc): for filename in os.listdir(loc):
@ -130,6 +130,8 @@ def read_json_file(loc):
with open(loc) as file_: with open(loc) as file_:
docs = ujson.load(file_) docs = ujson.load(file_)
for doc in docs: for doc in docs:
if docs_filter is not None and not docs_filter(doc):
continue
paragraphs = [] paragraphs = []
for paragraph in doc['paragraphs']: for paragraph in doc['paragraphs']:
sents = [] sents = []
@ -146,6 +148,9 @@ def read_json_file(loc):
tags.append(token['tag']) tags.append(token['tag'])
heads.append(token['head'] + i) heads.append(token['head'] + i)
labels.append(token['dep']) labels.append(token['dep'])
# Ensure ROOT label is case-insensitive
if labels[-1].lower() == 'root':
labels[-1] = 'ROOT'
ner.append(token.get('ner', '-')) ner.append(token.get('ner', '-'))
sents.append(( sents.append((
(ids, words, tags, heads, labels, ner), (ids, words, tags, heads, labels, ner),
@ -240,6 +245,16 @@ cdef class GoldParse:
self.heads[w2] = None self.heads[w2] = None
self.labels[w2] = '' self.labels[w2] = ''
# Check there are no cycles in the dependencies, i.e. we are a tree
for w in range(self.length):
seen = set([w])
head = w
while self.heads[head] != head and self.heads[head] != None:
head = self.heads[head]
if head in seen:
raise Exception("Cycle found: %s" % seen)
seen.add(head)
self.brackets = {} self.brackets = {}
for (gold_start, gold_end, label_str) in brackets: for (gold_start, gold_end, label_str) in brackets:
start = self.gold_to_cand[gold_start] start = self.gold_to_cand[gold_start]

View File

@ -10,11 +10,12 @@ def parse(sent_text, strip_bad_periods=False):
assert sent_text assert sent_text
annot = [] annot = []
words = [] words = []
id_map = {} id_map = {-1: -1}
for i, line in enumerate(sent_text.split('\n')): for i, line in enumerate(sent_text.split('\n')):
word, tag, head, dep = _parse_line(line) word, tag, head, dep = _parse_line(line)
if strip_bad_periods and words and _is_bad_period(words[-1], word): if strip_bad_periods and words and _is_bad_period(words[-1], word):
continue continue
id_map[i] = len(words)
annot.append({ annot.append({
'id': len(words), 'id': len(words),
@ -23,6 +24,8 @@ def parse(sent_text, strip_bad_periods=False):
'head': int(head) - 1, 'head': int(head) - 1,
'dep': dep}) 'dep': dep})
words.append(word) words.append(word)
for entry in annot:
entry['head'] = id_map[entry['head']]
return words, annot return words, annot

View File

@ -113,3 +113,9 @@ class Scorer(object):
set(item[:2] for item in cand_deps), set(item[:2] for item in cand_deps),
set(item[:2] for item in gold_deps), set(item[:2] for item in gold_deps),
) )
if verbose:
gold_words = [item[1] for item in gold.orig_annot]
for w_id, h_id, dep in (cand_deps - gold_deps):
print 'F', gold_words[w_id], dep, gold_words[h_id]
for w_id, h_id, dep in (gold_deps - cand_deps):
print 'M', gold_words[w_id], dep, gold_words[h_id]

View File

@ -61,6 +61,9 @@ cdef class StringStore:
def __get__(self): def __get__(self):
return self.size-1 return self.size-1
def __len__(self):
return self.size
def __getitem__(self, object string_or_id): def __getitem__(self, object string_or_id):
cdef bytes byte_string cdef bytes byte_string
cdef const Utf8Str* utf8str cdef const Utf8Str* utf8str

View File

@ -68,7 +68,7 @@ cdef struct TokenC:
int sense int sense
int head int head
int dep int dep
bint sent_end bint sent_start
uint32_t l_kids uint32_t l_kids
uint32_t r_kids uint32_t r_kids

View File

@ -1,9 +1,9 @@
from thinc.typedefs cimport atom_t from thinc.typedefs cimport atom_t
from ._state cimport State from .stateclass cimport StateClass
cdef int fill_context(atom_t* context, State* state) except -1 cdef int fill_context(atom_t* context, StateClass state) except -1
# Context elements # Context elements
# Ensure each token's attributes are listed: w, p, c, c6, c4. The order # Ensure each token's attributes are listed: w, p, c, c6, c4. The order

View File

@ -12,12 +12,10 @@ from libc.string cimport memset
from itertools import combinations from itertools import combinations
from ..tokens cimport TokenC from ..tokens cimport TokenC
from ._state cimport State
from ._state cimport get_s2, get_s1, get_s0, get_n0, get_n1, get_n2 from .stateclass cimport StateClass
from ._state cimport get_p2, get_p1
from ._state cimport get_e0, get_e1 from cymem.cymem cimport Pool
from ._state cimport has_head, get_left, get_right
from ._state cimport count_left_kids, count_right_kids
cdef inline void fill_token(atom_t* context, const TokenC* token) nogil: cdef inline void fill_token(atom_t* context, const TokenC* token) nogil:
@ -53,56 +51,56 @@ cdef inline void fill_token(atom_t* context, const TokenC* token) nogil:
# the source that are set to 1. # the source that are set to 1.
context[4] = token.lex.cluster & 15 context[4] = token.lex.cluster & 15
context[5] = token.lex.cluster & 63 context[5] = token.lex.cluster & 63
context[6] = token.dep if has_head(token) else 0 context[6] = token.dep if token.head != 0 else 0
context[7] = token.lex.prefix context[7] = token.lex.prefix
context[8] = token.lex.suffix context[8] = token.lex.suffix
context[9] = token.lex.shape context[9] = token.lex.shape
context[10] = token.ent_iob context[10] = token.ent_iob
context[11] = token.ent_type context[11] = token.ent_type
cdef int fill_context(atom_t* ctxt, StateClass st) except -1:
cdef int fill_context(atom_t* context, State* state) except -1:
# Take care to fill every element of context! # Take care to fill every element of context!
# We could memset, but this makes it very easy to have broken features that # We could memset, but this makes it very easy to have broken features that
# make almost no impact on accuracy. If instead they're unset, the impact # make almost no impact on accuracy. If instead they're unset, the impact
# tends to be dramatic, so we get an obvious regression to fix... # tends to be dramatic, so we get an obvious regression to fix...
fill_token(&context[S2w], get_s2(state)) fill_token(&ctxt[S2w], st.S_(2))
fill_token(&context[S1w], get_s1(state)) fill_token(&ctxt[S1w], st.S_(1))
fill_token(&context[S1rw], get_right(state, get_s1(state), 1)) fill_token(&ctxt[S1rw], st.R_(st.S(1), 1))
fill_token(&context[S0lw], get_left(state, get_s0(state), 1)) fill_token(&ctxt[S0lw], st.L_(st.S(0), 1))
fill_token(&context[S0l2w], get_left(state, get_s0(state), 2)) fill_token(&ctxt[S0l2w], st.L_(st.S(0), 2))
fill_token(&context[S0w], get_s0(state)) fill_token(&ctxt[S0w], st.S_(0))
fill_token(&context[S0r2w], get_right(state, get_s0(state), 2)) fill_token(&ctxt[S0r2w], st.R_(st.S(0), 2))
fill_token(&context[S0rw], get_right(state, get_s0(state), 1)) fill_token(&ctxt[S0rw], st.R_(st.S(0), 1))
fill_token(&context[N0lw], get_left(state, get_n0(state), 1)) fill_token(&ctxt[N0lw], st.L_(st.B(0), 1))
fill_token(&context[N0l2w], get_left(state, get_n0(state), 2)) fill_token(&ctxt[N0l2w], st.L_(st.B(0), 2))
fill_token(&context[N0w], get_n0(state)) fill_token(&ctxt[N0w], st.B_(0))
fill_token(&context[N1w], get_n1(state)) fill_token(&ctxt[N1w], st.B_(1))
fill_token(&context[N2w], get_n2(state)) fill_token(&ctxt[N2w], st.B_(2))
fill_token(&context[P1w], get_p1(state)) fill_token(&ctxt[P1w], st.safe_get(st.B(0)-1))
fill_token(&context[P2w], get_p2(state)) fill_token(&ctxt[P2w], st.safe_get(st.B(0)-2))
fill_token(&context[E0w], get_e0(state)) fill_token(&ctxt[E0w], st.E_(0))
fill_token(&context[E1w], get_e1(state)) fill_token(&ctxt[E1w], st.E_(1))
if state.stack_len >= 1:
context[dist] = min(state.i - state.stack[0], 5) if st.stack_depth() >= 1 and not st.eol():
ctxt[dist] = min(st.B(0) - st.E(0), 5)
else: else:
context[dist] = 0 ctxt[dist] = 0
context[N0lv] = min(count_left_kids(get_n0(state)), 5) ctxt[N0lv] = min(st.n_L(st.B(0)), 5)
context[S0lv] = min(count_left_kids(get_s0(state)), 5) ctxt[S0lv] = min(st.n_L(st.S(0)), 5)
context[S0rv] = min(count_right_kids(get_s0(state)), 5) ctxt[S0rv] = min(st.n_R(st.S(0)), 5)
context[S1lv] = min(count_left_kids(get_s1(state)), 5) ctxt[S1lv] = min(st.n_L(st.S(1)), 5)
context[S1rv] = min(count_right_kids(get_s1(state)), 5) ctxt[S1rv] = min(st.n_R(st.S(1)), 5)
context[S0_has_head] = 0 ctxt[S0_has_head] = 0
context[S1_has_head] = 0 ctxt[S1_has_head] = 0
context[S2_has_head] = 0 ctxt[S2_has_head] = 0
if state.stack_len >= 1: if st.stack_depth() >= 1:
context[S0_has_head] = has_head(get_s0(state)) + 1 ctxt[S0_has_head] = st.has_head(st.S(0)) + 1
if state.stack_len >= 2: if st.stack_depth() >= 2:
context[S1_has_head] = has_head(get_s1(state)) + 1 ctxt[S1_has_head] = st.has_head(st.S(1)) + 1
if state.stack_len >= 3: if st.stack_depth() >= 3:
context[S2_has_head] = has_head(get_s2(state)) + 1 ctxt[S2_has_head] = st.has_head(st.S(2)) + 1
ner = ( ner = (
@ -266,6 +264,32 @@ s0_n0 = (
(S0p, S0rp, N0p), (S0p, S0rp, N0p),
(S0p, N0lp, N0W), (S0p, N0lp, N0W),
(S0p, N0lp, N0p), (S0p, N0lp, N0p),
(S0L, N0p),
(S0p, S0rL, N0p),
(S0p, N0lL, N0p),
(S0p, S0rv, N0p),
(S0p, N0lv, N0p),
(S0c6, S0rL, S0r2L, N0p),
(S0p, N0lL, N0l2L, N0p),
)
s1_s0 = (
(S1p, S0p),
(S1p, S0p, S0_has_head),
(S1W, S0p),
(S1W, S0p, S0_has_head),
(S1c, S0p),
(S1c, S0p, S0_has_head),
(S1p, S1rL, S0p),
(S1p, S1rL, S0p, S0_has_head),
(S1p, S0lL, S0p),
(S1p, S0lL, S0p, S0_has_head),
(S1p, S0lL, S0l2L, S0p),
(S1p, S0lL, S0l2L, S0p, S0_has_head),
(S1L, S0L, S0W),
(S1L, S0L, S0p),
(S1p, S1L, S0L, S0p),
) )
@ -277,6 +301,8 @@ s1_n0 = (
(S1W, S1p, N0p), (S1W, S1p, N0p),
(S1p, N0W, N0p), (S1p, N0W, N0p),
(S1c6, S1p, N0c6, N0p), (S1c6, S1p, N0c6, N0p),
(S1L, N0p),
(S1p, S1rL, N0p),
) )
@ -288,6 +314,8 @@ s0_n1 = (
(S0W, S0p, N1p), (S0W, S0p, N1p),
(S0p, N1W, N1p), (S0p, N1W, N1p),
(S0c6, S0p, N1c6, N1p), (S0c6, S0p, N1c6, N1p),
(S0L, N1p),
(S0p, S0rL, N1p),
) )

View File

@ -2,10 +2,16 @@ from cymem.cymem cimport Pool
from thinc.typedefs cimport weight_t from thinc.typedefs cimport weight_t
from .stateclass cimport StateClass
from ._state cimport State
from .transition_system cimport TransitionSystem, Transition from .transition_system cimport TransitionSystem, Transition
from ..gold cimport GoldParseC
cdef class ArcEager(TransitionSystem): cdef class ArcEager(TransitionSystem):
pass pass
cdef int push_cost(StateClass stcls, const GoldParseC* gold, int target) nogil
cdef int arc_cost(StateClass stcls, const GoldParseC* gold, int head, int child) nogil

View File

@ -1,12 +1,8 @@
# cython: profile=True # cython: profile=True
from __future__ import unicode_literals from __future__ import unicode_literals
from ._state cimport State import ctypes
from ._state cimport has_head, get_idx, get_s0, get_n0, get_left, get_right import os
from ._state cimport is_final, at_eol, pop_stack, push_stack, add_dep
from ._state cimport head_in_buffer, children_in_buffer
from ._state cimport head_in_stack, children_in_stack
from ._state cimport count_left_kids
from ..structs cimport TokenC from ..structs cimport TokenC
@ -15,9 +11,16 @@ from .transition_system cimport move_cost_func_t, label_cost_func_t
from ..gold cimport GoldParse from ..gold cimport GoldParse
from ..gold cimport GoldParseC from ..gold cimport GoldParseC
from libc.stdint cimport uint32_t
from libc.string cimport memcpy
from cymem.cymem cimport Pool
from .stateclass cimport StateClass
DEF NON_MONOTONIC = True DEF NON_MONOTONIC = True
DEF USE_BREAK = True DEF USE_BREAK = True
DEF USE_ROOT_ARC_SEGMENT = True
cdef weight_t MIN_SCORE = -90000 cdef weight_t MIN_SCORE = -90000
@ -31,9 +34,6 @@ cdef enum:
BREAK BREAK
CONSTITUENT
ADJUST
N_MOVES N_MOVES
@ -43,417 +43,259 @@ MOVE_NAMES[REDUCE] = 'D'
MOVE_NAMES[LEFT] = 'L' MOVE_NAMES[LEFT] = 'L'
MOVE_NAMES[RIGHT] = 'R' MOVE_NAMES[RIGHT] = 'R'
MOVE_NAMES[BREAK] = 'B' MOVE_NAMES[BREAK] = 'B'
MOVE_NAMES[CONSTITUENT] = 'C'
MOVE_NAMES[ADJUST] = 'A'
# Helper functions for the arc-eager oracle # Helper functions for the arc-eager oracle
cdef int push_cost(const State* st, const GoldParseC* gold, int target) except -1: cdef int push_cost(StateClass stcls, const GoldParseC* gold, int target) nogil:
# When we push a word, we can't make arcs to or from the stack. So, we lose
# any of those arcs.
cdef int cost = 0 cdef int cost = 0
cost += head_in_stack(st, target, gold.heads) cdef int i, S_i
cost += children_in_stack(st, target, gold.heads) for i in range(stcls.stack_depth()):
S_i = stcls.S(i)
if gold.heads[target] == S_i:
cost += 1
if gold.heads[S_i] == target and (NON_MONOTONIC or not stcls.has_head(S_i)):
cost += 1
cost += Break.is_valid(stcls, -1) and Break.move_cost(stcls, gold) == 0
return cost return cost
cdef int pop_cost(const State* st, const GoldParseC* gold, int target) except -1: cdef int pop_cost(StateClass stcls, const GoldParseC* gold, int target) nogil:
cdef int cost = 0 cdef int cost = 0
cost += children_in_buffer(st, target, gold.heads) cdef int i, B_i
cost += head_in_buffer(st, target, gold.heads) for i in range(stcls.buffer_length()):
B_i = stcls.B(i)
cost += gold.heads[B_i] == target
cost += gold.heads[target] == B_i
if gold.heads[B_i] == B_i or gold.heads[B_i] < target:
break
cost += Break.is_valid(stcls, -1) and Break.move_cost(stcls, gold) == 0
return cost return cost
cdef int arc_cost(const GoldParseC* gold, int head, int child, int label) except -1: cdef int arc_cost(StateClass stcls, const GoldParseC* gold, int head, int child) nogil:
if gold.heads[child] != head: if arc_is_gold(gold, head, child):
return 0 return 0
elif gold.labels[child] == -1: elif stcls.H(child) == gold.heads[child]:
return 0
elif gold.labels[child] == label:
return 0
else:
return 1 return 1
# Head in buffer
elif gold.heads[child] >= stcls.B(0) and stcls.B(1) != -1:
return 1
else:
return 0
cdef bint arc_is_gold(const GoldParseC* gold, int head, int child) nogil:
if gold.labels[child] == -1:
return True
elif USE_ROOT_ARC_SEGMENT and _is_gold_root(gold, head) and _is_gold_root(gold, child):
return True
elif gold.heads[child] == head:
return True
else:
return False
cdef bint label_is_gold(const GoldParseC* gold, int head, int child, int label) nogil:
if gold.labels[child] == -1:
return True
elif label == -1:
return True
elif gold.labels[child] == label:
return True
else:
return False
cdef bint _is_gold_root(const GoldParseC* gold, int word) nogil:
return gold.labels[word] == -1 or gold.heads[word] == word
cdef class Shift: cdef class Shift:
@staticmethod @staticmethod
cdef bint is_valid(const State* s, int label) except -1: cdef bint is_valid(StateClass st, int label) nogil:
return not at_eol(s) return st.buffer_length() >= 2 and not st.shifted[st.B(0)] and not st.B_(0).sent_start
@staticmethod @staticmethod
cdef int transition(State* state, int label) except -1: cdef int transition(StateClass st, int label) nogil:
# Set the dep label, in case we need it after we reduce st.push()
if NON_MONOTONIC: st.fast_forward()
state.sent[state.i].dep = label
push_stack(state)
@staticmethod @staticmethod
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1: cdef int cost(StateClass st, const GoldParseC* gold, int label) nogil:
if not Shift.is_valid(s, label): return Shift.move_cost(st, gold) + Shift.label_cost(st, gold, label)
return 9000
return Shift.move_cost(s, gold) + Shift.label_cost(s, gold, label)
@staticmethod @staticmethod
cdef int move_cost(const State* s, const GoldParseC* gold) except -1: cdef inline int move_cost(StateClass s, const GoldParseC* gold) nogil:
cdef int cost = push_cost(s, gold, s.i) return push_cost(s, gold, s.B(0))
# If we can break, and there's no cost to doing so, we should
if Break.is_valid(s, -1) and Break.cost(s, gold, -1) == 0:
cost += 1
return cost
@staticmethod @staticmethod
cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1: cdef inline int label_cost(StateClass s, const GoldParseC* gold, int label) nogil:
return 0 return 0
cdef class Reduce: cdef class Reduce:
@staticmethod @staticmethod
cdef bint is_valid(const State* s, int label) except -1: cdef bint is_valid(StateClass st, int label) nogil:
if NON_MONOTONIC: return st.stack_depth() >= 2
return s.stack_len >= 2 #and not missing_brackets(s)
@staticmethod
cdef int transition(StateClass st, int label) nogil:
if st.has_head(st.S(0)):
st.pop()
else: else:
return s.stack_len >= 2 and has_head(get_s0(s)) st.unshift()
st.fast_forward()
@staticmethod @staticmethod
cdef int transition(State* state, int label) except -1: cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil:
if NON_MONOTONIC and not has_head(get_s0(state)):
add_dep(state, state.stack[-1], state.stack[0], get_s0(state).dep)
pop_stack(state)
@staticmethod
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
if not Reduce.is_valid(s, label):
return 9000
return Reduce.move_cost(s, gold) + Reduce.label_cost(s, gold, label) return Reduce.move_cost(s, gold) + Reduce.label_cost(s, gold, label)
@staticmethod @staticmethod
cdef int move_cost(const State* s, const GoldParseC* gold) except -1: cdef inline int move_cost(StateClass st, const GoldParseC* gold) nogil:
if NON_MONOTONIC: return pop_cost(st, gold, st.S(0))
return pop_cost(s, gold, s.stack[0])
else:
return children_in_buffer(s, s.stack[0], gold.heads)
@staticmethod @staticmethod
cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1: cdef inline int label_cost(StateClass s, const GoldParseC* gold, int label) nogil:
return 0 return 0
cdef class LeftArc: cdef class LeftArc:
@staticmethod @staticmethod
cdef bint is_valid(const State* s, int label) except -1: cdef bint is_valid(StateClass st, int label) nogil:
if NON_MONOTONIC: return not st.B_(0).sent_start
return s.stack_len >= 1 #and not missing_brackets(s)
else:
return s.stack_len >= 1 and not has_head(get_s0(s))
@staticmethod @staticmethod
cdef int transition(State* state, int label) except -1: cdef int transition(StateClass st, int label) nogil:
# Interpret left-arcs from EOL as attachment to root st.add_arc(st.B(0), st.S(0), label)
if at_eol(state): st.pop()
add_dep(state, state.stack[0], state.stack[0], label) st.fast_forward()
else:
add_dep(state, state.i, state.stack[0], label)
pop_stack(state)
@staticmethod @staticmethod
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1: cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil:
if not LeftArc.is_valid(s, label):
return 9000
return LeftArc.move_cost(s, gold) + LeftArc.label_cost(s, gold, label) return LeftArc.move_cost(s, gold) + LeftArc.label_cost(s, gold, label)
@staticmethod @staticmethod
cdef int move_cost(const State* s, const GoldParseC* gold) except -1: cdef inline int move_cost(StateClass s, const GoldParseC* gold) nogil:
if not LeftArc.is_valid(s, -1):
return 9000
cdef int cost = 0 cdef int cost = 0
if gold.heads[s.stack[0]] == s.i: if arc_is_gold(gold, s.B(0), s.S(0)):
return cost return 0
elif at_eol(s): else:
# Are we root? # Account for deps we might lose between S0 and stack
if gold.labels[s.stack[0]] != -1: if not s.has_head(s.S(0)):
# If we're at EOL, prefer to reduce or break over left-arc for i in range(1, s.stack_depth()):
if Reduce.is_valid(s, -1) or Break.is_valid(s, -1): cost += gold.heads[s.S(i)] == s.S(0)
cost += gold.heads[s.stack[0]] != s.stack[0] cost += gold.heads[s.S(0)] == s.S(i)
return cost return pop_cost(s, gold, s.S(0)) + arc_cost(s, gold, s.B(0), s.S(0))
cost += head_in_buffer(s, s.stack[0], gold.heads)
cost += children_in_buffer(s, s.stack[0], gold.heads)
if NON_MONOTONIC and s.stack_len >= 2:
cost += gold.heads[s.stack[0]] == s.stack[-1]
if gold.labels[s.stack[0]] != -1:
cost += gold.heads[s.stack[0]] == s.stack[0]
return cost
@staticmethod @staticmethod
cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1: cdef inline int label_cost(StateClass s, const GoldParseC* gold, int label) nogil:
if label == -1 or gold.labels[s.stack[0]] == -1: return arc_is_gold(gold, s.B(0), s.S(0)) and not label_is_gold(gold, s.B(0), s.S(0), label)
return 0
if gold.heads[s.stack[0]] == s.i and label != gold.labels[s.stack[0]]:
return 1
return 0
cdef class RightArc: cdef class RightArc:
@staticmethod @staticmethod
cdef bint is_valid(const State* s, int label) except -1: cdef bint is_valid(StateClass st, int label) nogil:
return s.stack_len >= 1 and not at_eol(s) return not st.B_(0).sent_start
@staticmethod @staticmethod
cdef int transition(State* state, int label) except -1: cdef int transition(StateClass st, int label) nogil:
add_dep(state, state.stack[0], state.i, label) st.add_arc(st.S(0), st.B(0), label)
push_stack(state) st.push()
st.fast_forward()
@staticmethod @staticmethod
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1: cdef inline int cost(StateClass s, const GoldParseC* gold, int label) nogil:
if not RightArc.is_valid(s, label):
return 9000
return RightArc.move_cost(s, gold) + RightArc.label_cost(s, gold, label) return RightArc.move_cost(s, gold) + RightArc.label_cost(s, gold, label)
@staticmethod @staticmethod
cdef int move_cost(const State* s, const GoldParseC* gold) except -1: cdef inline int move_cost(StateClass s, const GoldParseC* gold) nogil:
return push_cost(s, gold, s.i) - (gold.heads[s.i] == s.stack[0]) if arc_is_gold(gold, s.S(0), s.B(0)):
return 0
elif s.shifted[s.B(0)]:
return push_cost(s, gold, s.B(0))
else:
return push_cost(s, gold, s.B(0)) + arc_cost(s, gold, s.S(0), s.B(0))
@staticmethod @staticmethod
cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1: cdef int label_cost(StateClass s, const GoldParseC* gold, int label) nogil:
return arc_cost(gold, s.stack[0], s.i, label) return arc_is_gold(gold, s.S(0), s.B(0)) and not label_is_gold(gold, s.S(0), s.B(0), label)
#cdef int cost = 0
#if gold.heads[s.i] == s.stack[0]:
# cost += label != -1 and label != gold.labels[s.i]
# return cost
# This indicates missing head
#if gold.labels[s.i] != -1:
# cost += head_in_buffer(s, s.i, gold.heads)
#cost += children_in_stack(s, s.i, gold.heads)
#cost += head_in_stack(s, s.i, gold.heads)
#return cost
cdef class Break: cdef class Break:
@staticmethod @staticmethod
cdef bint is_valid(const State* s, int label) except -1: cdef bint is_valid(StateClass st, int label) nogil:
cdef int i cdef int i
if not USE_BREAK: if not USE_BREAK:
return False return False
elif at_eol(s): elif st.at_break():
return False
elif st.B(0) == 0:
return False
elif st.stack_depth() < 1:
return False
elif (st.S(0) + 1) != st.B(0):
# Must break at the token boundary
return False return False
#elif NON_MONOTONIC:
# return True
else: else:
# In the Break transition paper, they have this constraint that prevents
# Break if stack is disconnected. But, if we're doing non-monotonic parsing,
# we prefer to relax this constraint. This is helpful in parsing whole
# documents, because then we don't get stuck with words on the stack.
seen_headless = False
for i in range(s.stack_len):
if s.sent[s.stack[-i]].head == 0:
if seen_headless:
return False
else:
seen_headless = True
# TODO: Constituency constraints
return True return True
@staticmethod @staticmethod
cdef int transition(State* state, int label) except -1: cdef int transition(StateClass st, int label) nogil:
state.sent[state.i-1].sent_end = True st.set_break(st.B(0))
while state.stack_len != 0: st.fast_forward()
if get_s0(state).head == 0:
get_s0(state).dep = label
state.stack -= 1
state.stack_len -= 1
if not at_eol(state):
push_stack(state)
@staticmethod @staticmethod
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1: cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil:
if not Break.is_valid(s, label): return Break.move_cost(s, gold) + Break.label_cost(s, gold, label)
return 9000
else:
return Break.move_cost(s, gold) + Break.label_cost(s, gold, label)
@staticmethod @staticmethod
cdef int move_cost(const State* s, const GoldParseC* gold) except -1: cdef inline int move_cost(StateClass s, const GoldParseC* gold) nogil:
# When we break, we Reduce all of the words on the stack.
cdef int cost = 0 cdef int cost = 0
# Number of deps between S0...Sn and N0...Nn cdef int i, j, S_i, B_i
for i in range(s.i, s.sent_len): for i in range(s.stack_depth()):
cost += children_in_stack(s, i, gold.heads) S_i = s.S(i)
cost += head_in_stack(s, i, gold.heads) for j in range(s.buffer_length()):
return cost B_i = s.B(j)
cost += gold.heads[S_i] == B_i
cost += gold.heads[B_i] == S_i
# Check for sentence boundary --- if it's here, we can't have any deps
# between stack and buffer, so rest of action is irrelevant.
s0_root = _get_root(s.S(0), gold)
b0_root = _get_root(s.B(0), gold)
if s0_root != b0_root or s0_root == -1 or b0_root == -1:
return cost
else:
return cost + 1
@staticmethod @staticmethod
cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1: cdef inline int label_cost(StateClass s, const GoldParseC* gold, int label) nogil:
return 0 return 0
cdef int _get_root(int word, const GoldParseC* gold) nogil:
cdef class Constituent: while gold.heads[word] != word and gold.labels[word] != -1 and word >= 0:
@staticmethod word = gold.heads[word]
cdef bint is_valid(const State* s, int label) except -1: if gold.labels[word] == -1:
if s.stack_len < 1: return -1
return False else:
return False return word
#else:
# # If all stack elements are popped, can't constituent
# for i in range(s.ctnts.stack_len):
# if not s.ctnts.is_popped[-i]:
# return True
# else:
# return False
@staticmethod
cdef int transition(State* state, int label) except -1:
return False
#cdef Constituent* bracket = new_bracket(state.ctnts)
#bracket.parent = NULL
#bracket.label = self.label
#bracket.head = get_s0(state)
#bracket.length = 0
#attach(bracket, state.ctnts.stack)
# Attach rightward children. They're in the brackets array somewhere
# between here and B0.
#cdef Constituent* node
#cdef const TokenC* node_gov
#for i in range(1, bracket - state.ctnts.stack):
# node = bracket - i
# node_gov = node.head + node.head.head
# if node_gov == bracket.head:
# attach(bracket, node)
@staticmethod
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
if not Constituent.is_valid(s, label):
return 9000
raise Exception("Constituent move should be disabled currently")
# The gold standard is indexed by end, then by start, then a set of labels
#brackets = gold.brackets(get_s0(s).r_edge, {})
#if not brackets:
# return 2 # 2 loss for bad bracket, only 1 for good bracket bad label
# Index the current brackets in the state
#existing = set()
#for i in range(s.ctnt_len):
# if ctnt.end == s.r_edge and ctnt.label == self.label:
# existing.add(ctnt.start)
#cdef int loss = 2
#cdef const TokenC* child
#cdef const TokenC* s0 = get_s0(s)
#cdef int n_left = count_left_kids(s0)
# Iterate over the possible start positions, and check whether we have a
# (start, end, label) match to the gold tree
#for i in range(1, n_left):
# child = get_left(s, s0, i)
# if child.l_edge in brackets and child.l_edge not in existing:
# if self.label in brackets[child.l_edge]
# return 0
# else:
# loss = 1 # If we see the start position, set loss to 1
#return loss
@staticmethod
cdef int move_cost(const State* s, const GoldParseC* gold) except -1:
if not Constituent.is_valid(s, -1):
return 9000
raise Exception("Constituent move should be disabled currently")
@staticmethod
cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1:
return 0
cdef class Adjust:
@staticmethod
cdef bint is_valid(const State* s, int label) except -1:
return False
#if s.ctnts.stack_len < 2:
# return False
#cdef const Constituent* b1 = s.ctnts.stack[-1]
#cdef const Constituent* b0 = s.ctnts.stack[0]
#if (b1.head + b1.head.head) != b0.head:
# return False
#elif b0.head >= b1.head:
# return False
#elif b0 >= b1:
# return False
@staticmethod
cdef int transition(State* state, int label) except -1:
return False
#cdef Constituent* b0 = state.ctnts.stack[0]
#cdef Constituent* b1 = state.ctnts.stack[1]
#assert (b1.head + b1.head.head) == b0.head
#assert b0.head < b1.head
#assert b0 < b1
#attach(b0, b1)
## Pop B1 from stack, but keep B0 on top
#state.ctnts.stack -= 1
#state.ctnts.stack[0] = b0
@staticmethod
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
if not Adjust.is_valid(s, label):
return 9000
raise Exception("Adjust move should be disabled currently")
@staticmethod
cdef int move_cost(const State* s, const GoldParseC* gold) except -1:
if not Adjust.is_valid(s, -1):
return 9000
raise Exception("Adjust move should be disabled currently")
@staticmethod
cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1:
return 0
# The gold standard is indexed by end, then by start, then a set of labels
#gold_starts = gold.brackets(get_s0(s).r_edge, {})
# Case 1: There are 0 brackets ending at this word.
# --> Cost is sunk, but must allow brackets to begin
#if not gold_starts:
# return 0
# Is the top bracket correct?
#gold_labels = gold_starts.get(s.ctnt.start, set())
# TODO: Case where we have a unary rule
# TODO: Case where two brackets end on this word, with top bracket starting
# before
#cdef const TokenC* child
#cdef const TokenC* s0 = get_s0(s)
#cdef int n_left = count_left_kids(s0)
#cdef int i
# Iterate over the possible start positions, and check whether we have a
# (start, end, label) match to the gold tree
#for i in range(1, n_left):
# child = get_left(s, s0, i)
# if child.l_edge in brackets:
# if self.label in brackets[child.l_edge]:
# return 0
# else:
# loss = 1 # If we see the start position, set loss to 1
#return loss
cdef class ArcEager(TransitionSystem): cdef class ArcEager(TransitionSystem):
@classmethod @classmethod
def get_labels(cls, gold_parses): def get_labels(cls, gold_parses):
move_labels = {SHIFT: {'': True}, REDUCE: {'': True}, RIGHT: {}, move_labels = {SHIFT: {'': True}, REDUCE: {'': True}, RIGHT: {'ROOT': True},
LEFT: {'ROOT': True}, BREAK: {'ROOT': True}, LEFT: {'ROOT': True}, BREAK: {'ROOT': True}}
CONSTITUENT: {}, ADJUST: {'': True}}
for raw_text, sents in gold_parses: for raw_text, sents in gold_parses:
for (ids, words, tags, heads, labels, iob), ctnts in sents: for (ids, words, tags, heads, labels, iob), ctnts in sents:
for child, head, label in zip(ids, heads, labels): for child, head, label in zip(ids, heads, labels):
if label.upper() == 'ROOT':
label = 'ROOT'
if label != 'ROOT': if label != 'ROOT':
if head < child: if head < child:
move_labels[RIGHT][label] = True move_labels[RIGHT][label] = True
elif head > child: elif head > child:
move_labels[LEFT][label] = True move_labels[LEFT][label] = True
for start, end, label in ctnts:
move_labels[CONSTITUENT][label] = True
return move_labels return move_labels
cdef int preprocess_gold(self, GoldParse gold) except -1: cdef int preprocess_gold(self, GoldParse gold) except -1:
@ -462,8 +304,11 @@ cdef class ArcEager(TransitionSystem):
gold.c.heads[i] = i gold.c.heads[i] = i
gold.c.labels[i] = -1 gold.c.labels[i] = -1
else: else:
label = gold.labels[i]
if label.upper() == 'ROOT':
label = 'ROOT'
gold.c.heads[i] = gold.heads[i] gold.c.heads[i] = gold.heads[i]
gold.c.labels[i] = self.strings[gold.labels[i]] gold.c.labels[i] = self.strings[label]
for end, brackets in gold.brackets.items(): for end, brackets in gold.brackets.items():
for start, label_strs in brackets.items(): for start, label_strs in brackets.items():
gold.c.brackets[start][end] = 1 gold.c.brackets[start][end] = 1
@ -517,41 +362,43 @@ cdef class ArcEager(TransitionSystem):
t.is_valid = Break.is_valid t.is_valid = Break.is_valid
t.do = Break.transition t.do = Break.transition
t.get_cost = Break.cost t.get_cost = Break.cost
elif move == CONSTITUENT:
t.is_valid = Constituent.is_valid
t.do = Constituent.transition
t.get_cost = Constituent.cost
elif move == ADJUST:
t.is_valid = Adjust.is_valid
t.do = Adjust.transition
t.get_cost = Adjust.cost
else: else:
raise Exception(move) raise Exception(move)
return t return t
cdef int initialize_state(self, State* state) except -1: cdef int initialize_state(self, StateClass st) except -1:
push_stack(state) # Ensure sent_start is set to 0 throughout
for i in range(st.length):
st._sent[i].sent_start = False
st._sent[i].l_edge = i
st._sent[i].r_edge = i
st.fast_forward()
cdef int finalize_state(self, State* state) except -1: cdef int finalize_state(self, StateClass st) except -1:
cdef int root_label = self.strings['ROOT'] cdef int root_label = self.strings['ROOT']
for i in range(state.sent_len): for i in range(st.length):
if state.sent[i].head == 0 and state.sent[i].dep == 0: if st._sent[i].head == 0 and st._sent[i].dep == 0:
state.sent[i].dep = root_label st._sent[i].dep = root_label
# If we're not using the Break transition, we segment via root-labelled
# arcs between the root words.
elif USE_ROOT_ARC_SEGMENT and st._sent[i].dep == root_label:
st._sent[i].head = 0
cdef int set_valid(self, bint* output, const State* state) except -1: cdef int set_valid(self, bint* output, StateClass stcls) except -1:
cdef bint[N_MOVES] is_valid cdef bint[N_MOVES] is_valid
is_valid[SHIFT] = Shift.is_valid(state, -1) is_valid[SHIFT] = Shift.is_valid(stcls, -1)
is_valid[REDUCE] = Reduce.is_valid(state, -1) is_valid[REDUCE] = Reduce.is_valid(stcls, -1)
is_valid[LEFT] = LeftArc.is_valid(state, -1) is_valid[LEFT] = LeftArc.is_valid(stcls, -1)
is_valid[RIGHT] = RightArc.is_valid(state, -1) is_valid[RIGHT] = RightArc.is_valid(stcls, -1)
is_valid[BREAK] = Break.is_valid(state, -1) is_valid[BREAK] = Break.is_valid(stcls, -1)
is_valid[CONSTITUENT] = Constituent.is_valid(state, -1)
is_valid[ADJUST] = Adjust.is_valid(state, -1)
cdef int i cdef int i
n_valid = 0
for i in range(self.n_moves): for i in range(self.n_moves):
output[i] = is_valid[self.c[i].move] output[i] = is_valid[self.c[i].move]
n_valid += output[i]
assert n_valid >= 1
cdef int set_costs(self, int* output, const State* s, GoldParse gold) except -1: cdef int set_costs(self, int* output, StateClass stcls, GoldParse gold) except -1:
cdef int i, move, label cdef int i, move, label
cdef label_cost_func_t[N_MOVES] label_cost_funcs cdef label_cost_func_t[N_MOVES] label_cost_funcs
cdef move_cost_func_t[N_MOVES] move_cost_funcs cdef move_cost_func_t[N_MOVES] move_cost_funcs
@ -563,35 +410,36 @@ cdef class ArcEager(TransitionSystem):
move_cost_funcs[LEFT] = LeftArc.move_cost move_cost_funcs[LEFT] = LeftArc.move_cost
move_cost_funcs[RIGHT] = RightArc.move_cost move_cost_funcs[RIGHT] = RightArc.move_cost
move_cost_funcs[BREAK] = Break.move_cost move_cost_funcs[BREAK] = Break.move_cost
move_cost_funcs[CONSTITUENT] = Constituent.move_cost
move_cost_funcs[ADJUST] = Adjust.move_cost
label_cost_funcs[SHIFT] = Shift.label_cost label_cost_funcs[SHIFT] = Shift.label_cost
label_cost_funcs[REDUCE] = Reduce.label_cost label_cost_funcs[REDUCE] = Reduce.label_cost
label_cost_funcs[LEFT] = LeftArc.label_cost label_cost_funcs[LEFT] = LeftArc.label_cost
label_cost_funcs[RIGHT] = RightArc.label_cost label_cost_funcs[RIGHT] = RightArc.label_cost
label_cost_funcs[BREAK] = Break.label_cost label_cost_funcs[BREAK] = Break.label_cost
label_cost_funcs[CONSTITUENT] = Constituent.label_cost
label_cost_funcs[ADJUST] = Adjust.label_cost
cdef int* labels = gold.c.labels cdef int* labels = gold.c.labels
cdef int* heads = gold.c.heads cdef int* heads = gold.c.heads
for i in range(self.n_moves):
move = self.c[i].move
label = self.c[i].label
if move_costs[move] == -1:
move_costs[move] = move_cost_funcs[move](s, &gold.c)
output[i] = move_costs[move] + label_cost_funcs[move](s, &gold.c, label)
cdef Transition best_valid(self, const weight_t* scores, const State* s) except *: n_gold = 0
for i in range(self.n_moves):
if self.c[i].is_valid(stcls, self.c[i].label):
move = self.c[i].move
label = self.c[i].label
if move_costs[move] == -1:
move_costs[move] = move_cost_funcs[move](stcls, &gold.c)
output[i] = move_costs[move] + label_cost_funcs[move](stcls, &gold.c, label)
n_gold += output[i] == 0
else:
output[i] = 9000
assert n_gold >= 1
cdef Transition best_valid(self, const weight_t* scores, StateClass stcls) except *:
cdef bint[N_MOVES] is_valid cdef bint[N_MOVES] is_valid
is_valid[SHIFT] = Shift.is_valid(s, -1) is_valid[SHIFT] = Shift.is_valid(stcls, -1)
is_valid[REDUCE] = Reduce.is_valid(s, -1) is_valid[REDUCE] = Reduce.is_valid(stcls, -1)
is_valid[LEFT] = LeftArc.is_valid(s, -1) is_valid[LEFT] = LeftArc.is_valid(stcls, -1)
is_valid[RIGHT] = RightArc.is_valid(s, -1) is_valid[RIGHT] = RightArc.is_valid(stcls, -1)
is_valid[BREAK] = Break.is_valid(s, -1) is_valid[BREAK] = Break.is_valid(stcls, -1)
is_valid[CONSTITUENT] = Constituent.is_valid(s, -1)
is_valid[ADJUST] = Adjust.is_valid(s, -1)
cdef Transition best cdef Transition best
cdef weight_t score = MIN_SCORE cdef weight_t score = MIN_SCORE
cdef int i cdef int i
@ -600,15 +448,5 @@ cdef class ArcEager(TransitionSystem):
best = self.c[i] best = self.c[i]
score = scores[i] score = scores[i]
assert best.clas < self.n_moves assert best.clas < self.n_moves
assert score > MIN_SCORE assert score > MIN_SCORE, (stcls.stack_depth(), stcls.buffer_length(), stcls.is_final(), stcls._b_i, stcls.length)
# Label Shift moves with the best Right-Arc label, for non-monotonic
# actions
if best.move == SHIFT:
score = MIN_SCORE
for i in range(self.n_moves):
if self.c[i].move == RIGHT and scores[i] > score:
best.label = self.c[i].label
score = scores[i]
return best return best

View File

@ -1,6 +1,5 @@
from .transition_system cimport TransitionSystem from .transition_system cimport TransitionSystem
from .transition_system cimport Transition from .transition_system cimport Transition
from ._state cimport State
from ..gold cimport GoldParseC from ..gold cimport GoldParseC

View File

@ -1,7 +1,5 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from ._state cimport State
from .transition_system cimport Transition from .transition_system cimport Transition
from .transition_system cimport do_func_t from .transition_system cimport do_func_t
@ -11,6 +9,8 @@ from thinc.typedefs cimport weight_t
from ..gold cimport GoldParseC from ..gold cimport GoldParseC
from ..gold cimport GoldParse from ..gold cimport GoldParse
from .stateclass cimport StateClass
cdef enum: cdef enum:
MISSING MISSING
@ -34,18 +34,14 @@ MOVE_NAMES[OUT] = 'O'
cdef do_func_t[N_MOVES] do_funcs cdef do_func_t[N_MOVES] do_funcs
cdef bint entity_is_open(const State *s) except -1: cdef bint _entity_is_sunk(StateClass st, Transition* golds) nogil:
return s.ents_len >= 1 and s.ent.end == 0 if not st.entity_is_open():
cdef bint _entity_is_sunk(const State *s, Transition* golds) except -1:
if not entity_is_open(s):
return False return False
cdef const Transition* gold = &golds[s.ent.start] cdef const Transition* gold = &golds[st.E(0)]
if gold.move != BEGIN and gold.move != UNIT: if gold.move != BEGIN and gold.move != UNIT:
return True return True
elif gold.label != s.ent.label: elif gold.label != st.E_(0).ent_type:
return True return True
else: else:
return False return False
@ -132,14 +128,14 @@ cdef class BiluoPushDown(TransitionSystem):
raise Exception(move) raise Exception(move)
return t return t
cdef Transition best_valid(self, const weight_t* scores, const State* s) except *: cdef Transition best_valid(self, const weight_t* scores, StateClass stcls) except *:
cdef int best = -1 cdef int best = -1
cdef weight_t score = -90000 cdef weight_t score = -90000
cdef const Transition* m cdef const Transition* m
cdef int i cdef int i
for i in range(self.n_moves): for i in range(self.n_moves):
m = &self.c[i] m = &self.c[i]
if m.is_valid(s, m.label) and scores[i] > score: if m.is_valid(stcls, m.label) and scores[i] > score:
best = i best = i
score = scores[i] score = scores[i]
assert best >= 0 assert best >= 0
@ -147,49 +143,43 @@ cdef class BiluoPushDown(TransitionSystem):
t.score = score t.score = score
return t return t
cdef int set_valid(self, bint* output, const State* s) except -1: cdef int set_valid(self, bint* output, StateClass stcls) except -1:
cdef int i cdef int i
for i in range(self.n_moves): for i in range(self.n_moves):
m = &self.c[i] m = &self.c[i]
output[i] = m.is_valid(s, m.label) output[i] = m.is_valid(stcls, m.label)
cdef class Missing: cdef class Missing:
@staticmethod @staticmethod
cdef bint is_valid(const State* s, int label) except -1: cdef bint is_valid(StateClass st, int label) nogil:
return False return False
@staticmethod @staticmethod
cdef int transition(State* s, int label) except -1: cdef int transition(StateClass s, int label) nogil:
raise NotImplementedError pass
@staticmethod @staticmethod
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1: cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil:
return 9000 return 9000
cdef class Begin: cdef class Begin:
@staticmethod @staticmethod
cdef bint is_valid(const State* s, int label) except -1: cdef bint is_valid(StateClass st, int label) nogil:
return label != 0 and not entity_is_open(s) return label != 0 and not st.entity_is_open()
@staticmethod @staticmethod
cdef int transition(State* s, int label) except -1: cdef int transition(StateClass st, int label) nogil:
s.ent += 1 st.open_ent(label)
s.ents_len += 1 st.set_ent_tag(st.B(0), 3, label)
s.ent.start = s.i st.push()
s.ent.label = label st.pop()
s.ent.end = 0
s.sent[s.i].ent_iob = 3
s.sent[s.i].ent_type = label
s.i += 1
@staticmethod @staticmethod
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1: cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil:
if not Begin.is_valid(s, label): cdef int g_act = gold.ner[s.B(0)].move
return 9000 cdef int g_tag = gold.ner[s.B(0)].label
cdef int g_act = gold.ner[s.i].move
cdef int g_tag = gold.ner[s.i].label
if g_act == MISSING: if g_act == MISSING:
return 0 return 0
@ -203,25 +193,24 @@ cdef class Begin:
# B, Gold U --> False (P) # B, Gold U --> False (P)
return 1 return 1
cdef class In: cdef class In:
@staticmethod @staticmethod
cdef bint is_valid(const State* s, int label) except -1: cdef bint is_valid(StateClass st, int label) nogil:
return entity_is_open(s) and label != 0 and s.ent.label == label return st.entity_is_open() and label != 0 and st.E_(0).ent_type == label
@staticmethod @staticmethod
cdef int transition(State* s, int label) except -1: cdef int transition(StateClass st, int label) nogil:
s.sent[s.i].ent_iob = 1 st.set_ent_tag(st.B(0), 1, label)
s.sent[s.i].ent_type = label st.push()
s.i += 1 st.pop()
@staticmethod @staticmethod
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1: cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil:
if not In.is_valid(s, label):
return 9000
move = IN move = IN
cdef int next_act = gold.ner[s.i+1].move if s.i < s.sent_len else OUT cdef int next_act = gold.ner[s.B(1)].move if s.B(0) < s.length else OUT
cdef int g_act = gold.ner[s.i].move cdef int g_act = gold.ner[s.B(0)].move
cdef int g_tag = gold.ner[s.i].label cdef int g_tag = gold.ner[s.B(0)].label
cdef bint is_sunk = _entity_is_sunk(s, gold.ner) cdef bint is_sunk = _entity_is_sunk(s, gold.ner)
if g_act == MISSING: if g_act == MISSING:
@ -245,27 +234,23 @@ cdef class In:
return 1 return 1
cdef class Last: cdef class Last:
@staticmethod @staticmethod
cdef bint is_valid(const State* s, int label) except -1: cdef bint is_valid(StateClass st, int label) nogil:
return entity_is_open(s) and label != 0 and s.ent.label == label return st.entity_is_open() and label != 0 and st.E_(0).ent_type == label
@staticmethod @staticmethod
cdef int transition(State* s, int label) except -1: cdef int transition(StateClass st, int label) nogil:
s.ent.end = s.i+1 st.close_ent()
s.sent[s.i].ent_iob = 1 st.push()
s.sent[s.i].ent_type = label st.pop()
s.i += 1
@staticmethod @staticmethod
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1: cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil:
if not Last.is_valid(s, label):
return 9000
move = LAST move = LAST
cdef int g_act = gold.ner[s.i].move cdef int g_act = gold.ner[s.B(0)].move
cdef int g_tag = gold.ner[s.i].label cdef int g_tag = gold.ner[s.B(0)].label
if g_act == MISSING: if g_act == MISSING:
return 0 return 0
@ -290,26 +275,21 @@ cdef class Last:
cdef class Unit: cdef class Unit:
@staticmethod @staticmethod
cdef bint is_valid(const State* s, int label) except -1: cdef bint is_valid(StateClass st, int label) nogil:
return label != 0 and not entity_is_open(s) return label != 0 and not st.entity_is_open()
@staticmethod @staticmethod
cdef int transition(State* s, int label) except -1: cdef int transition(StateClass st, int label) nogil:
s.ent += 1 st.open_ent(label)
s.ents_len += 1 st.close_ent()
s.ent.start = s.i st.set_ent_tag(st.B(0), 3, label)
s.ent.label = label st.push()
s.ent.end = s.i+1 st.pop()
s.sent[s.i].ent_iob = 3
s.sent[s.i].ent_type = label
s.i += 1
@staticmethod @staticmethod
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1: cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil:
if not Unit.is_valid(s, label): cdef int g_act = gold.ner[s.B(0)].move
return 9000 cdef int g_tag = gold.ner[s.B(0)].label
cdef int g_act = gold.ner[s.i].move
cdef int g_tag = gold.ner[s.i].label
if g_act == MISSING: if g_act == MISSING:
return 0 return 0
@ -326,22 +306,19 @@ cdef class Unit:
cdef class Out: cdef class Out:
@staticmethod @staticmethod
cdef bint is_valid(const State* s, int label) except -1: cdef bint is_valid(StateClass st, int label) nogil:
return not entity_is_open(s) return not st.entity_is_open()
@staticmethod @staticmethod
cdef int transition(State* s, int label) except -1: cdef int transition(StateClass st, int label) nogil:
s.sent[s.i].ent_iob = 2 st.set_ent_tag(st.B(0), 2, 0)
s.i += 1 st.push()
st.pop()
@staticmethod @staticmethod
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1: cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil:
if not Out.is_valid(s, label): cdef int g_act = gold.ner[s.B(0)].move
return 9000 cdef int g_tag = gold.ner[s.B(0)].label
cdef int g_act = gold.ner[s.i].move
cdef int g_tag = gold.ner[s.i].label
if g_act == MISSING: if g_act == MISSING:
return 0 return 0

View File

@ -5,8 +5,6 @@ from .._ml cimport Model
from .arc_eager cimport TransitionSystem from .arc_eager cimport TransitionSystem
from ..tokens cimport Tokens, TokenC from ..tokens cimport Tokens, TokenC
from ._state cimport State
cdef class Parser: cdef class Parser:

View File

@ -1,9 +1,13 @@
# cython: profile=True # cython: profile=True
# cython: experimental_cpp_class_def=True
""" """
MALT-style dependency parser MALT-style dependency parser
""" """
from __future__ import unicode_literals from __future__ import unicode_literals
cimport cython cimport cython
from cpython.ref cimport PyObject, Py_INCREF, Py_XDECREF
from libc.stdint cimport uint32_t, uint64_t from libc.stdint cimport uint32_t, uint64_t
from libc.string cimport memset, memcpy from libc.string cimport memset, memcpy
import random import random
@ -14,7 +18,7 @@ import json
from cymem.cymem cimport Pool, Address from cymem.cymem cimport Pool, Address
from murmurhash.mrmr cimport hash64 from murmurhash.mrmr cimport hash64
from thinc.typedefs cimport weight_t, class_t, feat_t, atom_t from thinc.typedefs cimport weight_t, class_t, feat_t, atom_t, hash_t
from util import Config from util import Config
@ -31,14 +35,16 @@ from thinc.search cimport MaxViolation
from ..tokens cimport Tokens, TokenC from ..tokens cimport Tokens, TokenC
from ..strings cimport StringStore from ..strings cimport StringStore
from .arc_eager cimport TransitionSystem, Transition
from .transition_system import OracleError
from ._state cimport State, new_state, copy_state, is_final, push_stack from .transition_system import OracleError
from .transition_system cimport TransitionSystem, Transition
from ..gold cimport GoldParse from ..gold cimport GoldParse
from . import _parse_features from . import _parse_features
from ._parse_features cimport fill_context, CONTEXT_SIZE from ._parse_features cimport CONTEXT_SIZE
from ._parse_features cimport fill_context
from .stateclass cimport StateClass
DEBUG = False DEBUG = False
@ -47,20 +53,6 @@ def set_debug(val):
DEBUG = val DEBUG = val
cdef unicode print_state(State* s, list words):
words = list(words) + ['EOL']
top = words[s.stack[0]] + '_%d' % s.sent[s.stack[0]].head
second = words[s.stack[-1]] + '_%d' % s.sent[s.stack[-1]].head
third = words[s.stack[-2]] + '_%d' % s.sent[s.stack[-2]].head
n0 = words[s.i] if s.i < len(words) else 'EOL'
n1 = words[s.i + 1] if s.i+1 < len(words) else 'EOL'
if s.ents_len:
ent = '%s %d-%d' % (s.ent.label, s.ent.start, s.ent.end)
else:
ent = '-'
return ' '.join((ent, str(s.stack_len), third, second, top, '|', n0, n1))
def get_templates(name): def get_templates(name):
pf = _parse_features pf = _parse_features
if name == 'ner': if name == 'ner':
@ -68,7 +60,7 @@ def get_templates(name):
elif name == 'debug': elif name == 'debug':
return pf.unigrams return pf.unigrams
else: else:
return (pf.unigrams + pf.s0_n0 + pf.s1_n0 + pf.s0_n1 + pf.n0_n1 + \ return (pf.unigrams + pf.s0_n0 + pf.s1_n0 + pf.s1_s0 + pf.s0_n1 + pf.n0_n1 + \
pf.tree_shape + pf.trigrams) pf.tree_shape + pf.trigrams)
@ -81,16 +73,14 @@ cdef class Parser:
self.model = Model(self.moves.n_moves, templates, model_dir) self.model = Model(self.moves.n_moves, templates, model_dir)
def __call__(self, Tokens tokens): def __call__(self, Tokens tokens):
if tokens.length == 0: if self.cfg.get('beam_width', 1) < 1:
return 0
if self.cfg.get('beam_width', 1) <= 1:
self._greedy_parse(tokens) self._greedy_parse(tokens)
else: else:
self._beam_parse(tokens) self._beam_parse(tokens)
def train(self, Tokens tokens, GoldParse gold): def train(self, Tokens tokens, GoldParse gold):
self.moves.preprocess_gold(gold) self.moves.preprocess_gold(gold)
if self.cfg.beam_width <= 1: if self.cfg.beam_width < 1:
return self._greedy_train(tokens, gold) return self._greedy_train(tokens, gold)
else: else:
return self._beam_train(tokens, gold) return self._beam_train(tokens, gold)
@ -99,31 +89,36 @@ cdef class Parser:
cdef atom_t[CONTEXT_SIZE] context cdef atom_t[CONTEXT_SIZE] context
cdef int n_feats cdef int n_feats
cdef Pool mem = Pool() cdef Pool mem = Pool()
cdef State* state = new_state(mem, tokens.data, tokens.length) cdef StateClass stcls = StateClass.init(tokens.data, tokens.length)
self.moves.initialize_state(state) self.moves.initialize_state(stcls)
cdef Transition guess cdef Transition guess
while not is_final(state): words = [w.orth_ for w in tokens]
fill_context(context, state) while not stcls.is_final():
fill_context(context, stcls)
scores = self.model.score(context) scores = self.model.score(context)
guess = self.moves.best_valid(scores, state) guess = self.moves.best_valid(scores, stcls)
guess.do(state, guess.label) #print self.moves.move_name(guess.move, guess.label), stcls.print_state(words)
self.moves.finalize_state(state) guess.do(stcls, guess.label)
tokens.set_parse(state.sent) assert stcls._s_i >= 0
self.moves.finalize_state(stcls)
tokens.set_parse(stcls._sent)
cdef int _beam_parse(self, Tokens tokens) except -1: cdef int _beam_parse(self, Tokens tokens) except -1:
cdef Beam beam = Beam(self.moves.n_moves, self.cfg.beam_width) cdef Beam beam = Beam(self.moves.n_moves, self.cfg.beam_width)
words = [w.orth_ for w in tokens]
beam.initialize(_init_state, tokens.length, tokens.data) beam.initialize(_init_state, tokens.length, tokens.data)
beam.check_done(_check_final_state, NULL) beam.check_done(_check_final_state, NULL)
while not beam.is_done: while not beam.is_done:
self._advance_beam(beam, None, False) self._advance_beam(beam, None, False, words)
state = <State*>beam.at(0) state = <StateClass>beam.at(0)
self.moves.finalize_state(state) self.moves.finalize_state(state)
tokens.set_parse(state.sent) tokens.set_parse(state._sent)
_cleanup(beam)
def _greedy_train(self, Tokens tokens, GoldParse gold): def _greedy_train(self, Tokens tokens, GoldParse gold):
cdef Pool mem = Pool() cdef Pool mem = Pool()
cdef State* state = new_state(mem, tokens.data, tokens.length) cdef StateClass stcls = StateClass.init(tokens.data, tokens.length)
self.moves.initialize_state(state) self.moves.initialize_state(stcls)
cdef int cost cdef int cost
cdef const Feature* feats cdef const Feature* feats
@ -132,14 +127,16 @@ cdef class Parser:
cdef Transition best cdef Transition best
cdef atom_t[CONTEXT_SIZE] context cdef atom_t[CONTEXT_SIZE] context
loss = 0 loss = 0
while not is_final(state): words = [w.orth_ for w in tokens]
fill_context(context, state) history = []
while not stcls.is_final():
fill_context(context, stcls)
scores = self.model.score(context) scores = self.model.score(context)
guess = self.moves.best_valid(scores, state) guess = self.moves.best_valid(scores, stcls)
best = self.moves.best_gold(scores, state, gold) best = self.moves.best_gold(scores, stcls, gold)
cost = guess.get_cost(state, &gold.c, guess.label) cost = guess.get_cost(stcls, &gold.c, guess.label)
self.model.update(context, guess.clas, best.clas, cost) self.model.update(context, guess.clas, best.clas, cost)
guess.do(state, guess.label) guess.do(stcls, guess.label)
loss += cost loss += cost
return loss return loss
@ -152,9 +149,10 @@ cdef class Parser:
gold.check_done(_check_final_state, NULL) gold.check_done(_check_final_state, NULL)
violn = MaxViolation() violn = MaxViolation()
words = [w.orth_ for w in tokens]
while not pred.is_done and not gold.is_done: while not pred.is_done and not gold.is_done:
self._advance_beam(pred, gold_parse, False) self._advance_beam(pred, gold_parse, False, words)
self._advance_beam(gold, gold_parse, True) self._advance_beam(gold, gold_parse, True, words)
violn.check(pred, gold) violn.check(pred, gold)
if pred.loss >= 1: if pred.loss >= 1:
counts = {clas: {} for clas in range(self.model.n_classes)} counts = {clas: {} for clas in range(self.model.n_classes)}
@ -163,62 +161,90 @@ cdef class Parser:
else: else:
counts = {} counts = {}
self.model._model.update(counts) self.model._model.update(counts)
_cleanup(pred)
_cleanup(gold)
return pred.loss return pred.loss
def _advance_beam(self, Beam beam, GoldParse gold, bint follow_gold): def _advance_beam(self, Beam beam, GoldParse gold, bint follow_gold, words):
cdef atom_t[CONTEXT_SIZE] context cdef atom_t[CONTEXT_SIZE] context
cdef State* state
cdef int i, j, cost cdef int i, j, cost
cdef bint is_valid cdef bint is_valid
cdef const Transition* move cdef const Transition* move
for i in range(beam.size): for i in range(beam.size):
state = <State*>beam.at(i) stcls = <StateClass>beam.at(i)
if not is_final(state): if not stcls.is_final():
fill_context(context, state) fill_context(context, stcls)
self.model.set_scores(beam.scores[i], context) self.model.set_scores(beam.scores[i], context)
self.moves.set_valid(beam.is_valid[i], state) self.moves.set_valid(beam.is_valid[i], stcls)
if gold is not None: if gold is not None:
for i in range(beam.size): for i in range(beam.size):
state = <State*>beam.at(i) stcls = <StateClass>beam.at(i)
self.moves.set_costs(beam.costs[i], state, gold) if not stcls.is_final():
if follow_gold: self.moves.set_costs(beam.costs[i], stcls, gold)
for j in range(self.moves.n_moves): if follow_gold:
beam.is_valid[i][j] *= beam.costs[i][j] == 0 for j in range(self.moves.n_moves):
beam.advance(_transition_state, <void*>self.moves.c) beam.is_valid[i][j] *= beam.costs[i][j] == 0
state = <State*>beam.at(0) beam.advance(_transition_state, _hash_state, <void*>self.moves.c)
beam.check_done(_check_final_state, NULL) beam.check_done(_check_final_state, NULL)
def _count_feats(self, dict counts, Tokens tokens, list hist, int inc): def _count_feats(self, dict counts, Tokens tokens, list hist, int inc):
cdef atom_t[CONTEXT_SIZE] context cdef atom_t[CONTEXT_SIZE] context
cdef Pool mem = Pool() cdef Pool mem = Pool()
cdef State* state = new_state(mem, tokens.data, tokens.length) cdef StateClass stcls = StateClass.init(tokens.data, tokens.length)
self.moves.initialize_state(state) self.moves.initialize_state(stcls)
cdef class_t clas cdef class_t clas
cdef int n_feats cdef int n_feats
for clas in hist: for clas in hist:
fill_context(context, state) fill_context(context, stcls)
feats = self.model._extractor.get_feats(context, &n_feats) feats = self.model._extractor.get_feats(context, &n_feats)
count_feats(counts[clas], feats, n_feats, inc) count_feats(counts[clas], feats, n_feats, inc)
self.moves.c[clas].do(state, self.moves.c[clas].label) self.moves.c[clas].do(stcls, self.moves.c[clas].label)
# These are passed as callbacks to thinc.search.Beam # These are passed as callbacks to thinc.search.Beam
cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1: cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1:
dest = <State*>_dest dest = <StateClass>_dest
src = <const State*>_src src = <StateClass>_src
moves = <const Transition*>_moves moves = <const Transition*>_moves
copy_state(dest, src) dest.clone(src)
moves[clas].do(dest, moves[clas].label) moves[clas].do(dest, moves[clas].label)
cdef void* _init_state(Pool mem, int length, void* tokens) except NULL: cdef void* _init_state(Pool mem, int length, void* tokens) except NULL:
state = new_state(mem, <const TokenC*>tokens, length) cdef StateClass st = StateClass.init(<const TokenC*>tokens, length)
push_stack(state) st.fast_forward()
return state Py_INCREF(st)
return <void*>st
cdef int _check_final_state(void* state, void* extra_args) except -1: cdef int _check_final_state(void* _state, void* extra_args) except -1:
return is_final(<State*>state) return (<StateClass>_state).is_final()
def _cleanup(Beam beam):
for i in range(beam.width):
Py_XDECREF(<PyObject*>beam._states[i].content)
Py_XDECREF(<PyObject*>beam._parents[i].content)
cdef hash_t _hash_state(void* _state, void* _) except 0:
return <hash_t>_state
#state = <const State*>_state
#cdef atom_t[10] rep
#rep[0] = state.stack[0] if state.stack_len >= 1 else 0
#rep[1] = state.stack[-1] if state.stack_len >= 2 else 0
#rep[2] = state.stack[-2] if state.stack_len >= 3 else 0
#rep[3] = state.i
#rep[4] = state.sent[state.stack[0]].l_kids if state.stack_len >= 1 else 0
#rep[5] = state.sent[state.stack[0]].r_kids if state.stack_len >= 1 else 0
#rep[6] = state.sent[state.stack[0]].dep if state.stack_len >= 1 else 0
#rep[7] = state.sent[state.stack[-1]].dep if state.stack_len >= 2 else 0
#if get_left(state, get_n0(state), 1) != NULL:
# rep[8] = get_left(state, get_n0(state), 1).dep
#else:
# rep[8] = 0
#rep[9] = state.sent[state.i].l_kids
#return hash64(rep, sizeof(atom_t) * 10, 0)

104
spacy/syntax/stateclass.pxd Normal file
View File

@ -0,0 +1,104 @@
from libc.string cimport memcpy, memset
from cymem.cymem cimport Pool
from ..structs cimport TokenC, Entity
from ..vocab cimport EMPTY_LEXEME
cdef class StateClass:
cdef Pool mem
cdef int* _stack
cdef int* _buffer
cdef bint* shifted
cdef TokenC* _sent
cdef Entity* _ents
cdef TokenC _empty_token
cdef int length
cdef int _s_i
cdef int _b_i
cdef int _e_i
cdef int _break
@staticmethod
cdef inline StateClass init(const TokenC* sent, int length):
cdef StateClass self = StateClass(length)
cdef int i
for i in range(length):
self._sent[i] = sent[i]
self._buffer[i] = i
for i in range(length, length + 5):
self._sent[i].lex = &EMPTY_LEXEME
return self
cdef inline int S(self, int i) nogil:
if i >= self._s_i:
return -1
return self._stack[self._s_i - (i+1)]
cdef inline int B(self, int i) nogil:
if (i + self._b_i) >= self.length:
return -1
return self._buffer[self._b_i + i]
cdef int H(self, int i) nogil
cdef int E(self, int i) nogil
cdef int L(self, int i, int idx) nogil
cdef int R(self, int i, int idx) nogil
cdef const TokenC* S_(self, int i) nogil
cdef const TokenC* B_(self, int i) nogil
cdef const TokenC* H_(self, int i) nogil
cdef const TokenC* E_(self, int i) nogil
cdef const TokenC* L_(self, int i, int idx) nogil
cdef const TokenC* R_(self, int i, int idx) nogil
cdef const TokenC* safe_get(self, int i) nogil
cdef bint empty(self) nogil
cdef bint entity_is_open(self) nogil
cdef bint eol(self) nogil
cdef bint at_break(self) nogil
cdef bint is_final(self) nogil
cdef bint has_head(self, int i) nogil
cdef int n_L(self, int i) nogil
cdef int n_R(self, int i) nogil
cdef bint stack_is_connected(self) nogil
cdef int stack_depth(self) nogil
cdef int buffer_length(self) nogil
cdef void push(self) nogil
cdef void pop(self) nogil
cdef void unshift(self) nogil
cdef void add_arc(self, int head, int child, int label) nogil
cdef void del_arc(self, int head, int child) nogil
cdef void open_ent(self, int label) nogil
cdef void close_ent(self) nogil
cdef void set_ent_tag(self, int i, int ent_iob, int ent_type) nogil
cdef void set_break(self, int i) nogil
cdef void clone(self, StateClass src) nogil
cdef void fast_forward(self) nogil

253
spacy/syntax/stateclass.pyx Normal file
View File

@ -0,0 +1,253 @@
from libc.string cimport memcpy, memset
from libc.stdint cimport uint32_t
from ..vocab cimport EMPTY_LEXEME
from ..structs cimport Entity
cdef class StateClass:
def __init__(self, int length):
cdef Pool mem = Pool()
PADDING = 5
self._buffer = <int*>mem.alloc(length + PADDING, sizeof(int))
self._stack = <int*>mem.alloc(length + PADDING, sizeof(int))
self.shifted = <bint*>mem.alloc(length + PADDING, sizeof(bint))
self._sent = <TokenC*>mem.alloc(length + PADDING, sizeof(TokenC))
self._ents = <Entity*>mem.alloc(length + PADDING, sizeof(Entity))
cdef int i
for i in range(length):
self._ents[i].end = -1
for i in range(length, length + PADDING):
self._sent[i].lex = &EMPTY_LEXEME
self.mem = mem
self.length = length
self._break = -1
self._s_i = 0
self._b_i = 0
self._e_i = 0
for i in range(length):
self._buffer[i] = i
self._empty_token.lex = &EMPTY_LEXEME
cdef int H(self, int i) nogil:
if i < 0 or i >= self.length:
return -1
return self._sent[i].head + i
cdef int E(self, int i) nogil:
if self._e_i <= 0 or self._e_i >= self.length:
return 0
if i < 0 or i >= self.length:
return 0
return self._ents[self._e_i-1].start
cdef int L(self, int i, int idx) nogil:
if idx < 1:
return -1
if i < 0 or i >= self.length:
return -1
cdef const TokenC* target = &self._sent[i]
cdef const TokenC* ptr = self._sent
while ptr < target:
# If this head is still to the right of us, we can skip to it
# No token that's between this token and this head could be our
# child.
if (ptr.head >= 1) and (ptr + ptr.head) < target:
ptr += ptr.head
elif ptr + ptr.head == target:
idx -= 1
if idx == 0:
return ptr - self._sent
ptr += 1
else:
ptr += 1
return -1
cdef int R(self, int i, int idx) nogil:
if idx < 1:
return -1
if i < 0 or i >= self.length:
return -1
cdef const TokenC* ptr = self._sent + (self.length - 1)
cdef const TokenC* target = &self._sent[i]
while ptr > target:
# If this head is still to the right of us, we can skip to it
# No token that's between this token and this head could be our
# child.
if (ptr.head < 0) and ((ptr + ptr.head) > target):
ptr += ptr.head
elif ptr + ptr.head == target:
idx -= 1
if idx == 0:
return ptr - self._sent
ptr -= 1
else:
ptr -= 1
return -1
cdef const TokenC* S_(self, int i) nogil:
return self.safe_get(self.S(i))
cdef const TokenC* B_(self, int i) nogil:
return self.safe_get(self.B(i))
cdef const TokenC* H_(self, int i) nogil:
return self.safe_get(self.H(i))
cdef const TokenC* E_(self, int i) nogil:
return self.safe_get(self.E(i))
cdef const TokenC* L_(self, int i, int idx) nogil:
return self.safe_get(self.L(i, idx))
cdef const TokenC* R_(self, int i, int idx) nogil:
return self.safe_get(self.R(i, idx))
cdef const TokenC* safe_get(self, int i) nogil:
if i < 0 or i >= self.length:
return &self._empty_token
else:
return &self._sent[i]
cdef bint empty(self) nogil:
return self._s_i <= 0
cdef bint eol(self) nogil:
return self.buffer_length() == 0
cdef bint at_break(self) nogil:
return self._break != -1
cdef bint is_final(self) nogil:
return self.stack_depth() <= 0 and self._b_i >= self.length
cdef bint has_head(self, int i) nogil:
return self.safe_get(i).head != 0
cdef int n_L(self, int i) nogil:
return self.safe_get(i).l_kids
cdef int n_R(self, int i) nogil:
return self.safe_get(i).r_kids
cdef bint stack_is_connected(self) nogil:
return False
cdef bint entity_is_open(self) nogil:
if self._e_i < 1:
return False
return self._ents[self._e_i-1].end == -1
cdef int stack_depth(self) nogil:
return self._s_i
cdef int buffer_length(self) nogil:
if self._break != -1:
return self._break - self._b_i
else:
return self.length - self._b_i
cdef void push(self) nogil:
if self.B(0) != -1:
self._stack[self._s_i] = self.B(0)
self._s_i += 1
self._b_i += 1
if self._b_i > self._break:
self._break = -1
cdef void pop(self) nogil:
if self._s_i >= 1:
self._s_i -= 1
cdef void unshift(self) nogil:
self._b_i -= 1
self._buffer[self._b_i] = self.S(0)
self._s_i -= 1
self.shifted[self.B(0)] = True
cdef void fast_forward(self) nogil:
while self.buffer_length() == 0 or self.stack_depth() == 0:
if self.buffer_length() == 1 and self.stack_depth() == 0:
self.push()
self.pop()
elif self.buffer_length() == 0 and self.stack_depth() == 1:
self.pop()
elif self.buffer_length() == 0 and self.stack_depth() >= 2:
if self.has_head(self.S(0)):
self.pop()
else:
self.unshift()
elif (self.length - self._b_i) >= 1 and self.stack_depth() == 0:
self.push()
else:
break
cdef void add_arc(self, int head, int child, int label) nogil:
if self.has_head(child):
self.del_arc(self.H(child), child)
cdef int dist = head - child
self._sent[child].head = dist
self._sent[child].dep = label
cdef int i
if child > head:
self._sent[head].r_kids += 1
self._sent[head].r_edge = child
i = 0
while self.has_head(head) and i < self.length:
self._sent[head].r_edge = child
head = self.H(head)
i += 1 # Guard against infinite loops
else:
self._sent[head].l_kids += 1
self._sent[head].l_edge = self._sent[child].l_edge
cdef void del_arc(self, int h_i, int c_i) nogil:
cdef int dist = h_i - c_i
cdef TokenC* h = &self._sent[h_i]
if c_i > h_i:
h.r_kids -= 1
h.r_edge = self.R_(h_i, h.r_kids-1).r_edge if h.r_kids >= 1 else h_i
else:
h.l_kids -= 1
h.l_edge = self.L_(h_i, h.l_kids-1).l_edge if h.l_kids >= 1 else h_i
cdef void open_ent(self, int label) nogil:
self._ents[self._e_i].start = self.B(0)
self._ents[self._e_i].label = label
self._ents[self._e_i].end = -1
self._e_i += 1
cdef void close_ent(self) nogil:
self._ents[self._e_i-1].end = self.B(0)+1
self._sent[self.B(0)].ent_iob = 1
cdef void set_ent_tag(self, int i, int ent_iob, int ent_type) nogil:
if 0 <= i < self.length:
self._sent[i].ent_iob = ent_iob
self._sent[i].ent_type = ent_type
cdef void set_break(self, int _) nogil:
if 0 <= self.B(0) < self.length:
self._sent[self.B(0)].sent_start = True
self._break = self._b_i
cdef void clone(self, StateClass src) nogil:
memcpy(self._sent, src._sent, self.length * sizeof(TokenC))
memcpy(self._stack, src._stack, self.length * sizeof(int))
memcpy(self._buffer, src._buffer, self.length * sizeof(int))
memcpy(self._ents, src._ents, self.length * sizeof(Entity))
self._b_i = src._b_i
self._s_i = src._s_i
self._e_i = src._e_i
self._break = src._break
def print_state(self, words):
words = list(words) + ['_']
top = words[self.S(0)] + '_%d' % self.S_(0).head
second = words[self.S(1)] + '_%d' % self.S_(1).head
third = words[self.S(2)] + '_%d' % self.S_(2).head
n0 = words[self.B(0)]
n1 = words[self.B(1)]
return ' '.join((third, second, top, '|', n0, n1))

View File

@ -2,11 +2,12 @@ from cymem.cymem cimport Pool
from thinc.typedefs cimport weight_t from thinc.typedefs cimport weight_t
from ..structs cimport TokenC from ..structs cimport TokenC
from ._state cimport State
from ..gold cimport GoldParse from ..gold cimport GoldParse
from ..gold cimport GoldParseC from ..gold cimport GoldParseC
from ..strings cimport StringStore from ..strings cimport StringStore
from .stateclass cimport StateClass
cdef struct Transition: cdef struct Transition:
int clas int clas
@ -15,16 +16,16 @@ cdef struct Transition:
weight_t score weight_t score
bint (*is_valid)(const State* state, int label) except -1 bint (*is_valid)(StateClass state, int label) nogil
int (*get_cost)(const State* state, const GoldParseC* gold, int label) except -1 int (*get_cost)(StateClass state, const GoldParseC* gold, int label) nogil
int (*do)(State* state, int label) except -1 int (*do)(StateClass state, int label) nogil
ctypedef int (*get_cost_func_t)(const State* state, const GoldParseC* gold, int label) except -1 ctypedef int (*get_cost_func_t)(StateClass state, const GoldParseC* gold, int label) nogil
ctypedef int (*move_cost_func_t)(const State* state, const GoldParseC* gold) except -1 ctypedef int (*move_cost_func_t)(StateClass state, const GoldParseC* gold) nogil
ctypedef int (*label_cost_func_t)(const State* state, const GoldParseC* gold, int label) except -1 ctypedef int (*label_cost_func_t)(StateClass state, const GoldParseC* gold, int label) nogil
ctypedef int (*do_func_t)(State* state, int label) except -1 ctypedef int (*do_func_t)(StateClass state, int label) nogil
cdef class TransitionSystem: cdef class TransitionSystem:
@ -34,8 +35,8 @@ cdef class TransitionSystem:
cdef bint* _is_valid cdef bint* _is_valid
cdef readonly int n_moves cdef readonly int n_moves
cdef int initialize_state(self, State* state) except -1 cdef int initialize_state(self, StateClass state) except -1
cdef int finalize_state(self, State* state) except -1 cdef int finalize_state(self, StateClass state) except -1
cdef int preprocess_gold(self, GoldParse gold) except -1 cdef int preprocess_gold(self, GoldParse gold) except -1
@ -43,18 +44,11 @@ cdef class TransitionSystem:
cdef Transition init_transition(self, int clas, int move, int label) except * cdef Transition init_transition(self, int clas, int move, int label) except *
cdef int set_valid(self, bint* output, const State* state) except -1 cdef int set_valid(self, bint* output, StateClass state) except -1
cdef int set_costs(self, int* output, const State* state, GoldParse gold) except -1 cdef int set_costs(self, int* output, StateClass state, GoldParse gold) except -1
cdef Transition best_valid(self, const weight_t* scores, const State* state) except * cdef Transition best_valid(self, const weight_t* scores, StateClass stcls) except *
cdef Transition best_gold(self, const weight_t* scores, const State* state, cdef Transition best_gold(self, const weight_t* scores, StateClass state,
GoldParse gold) except * GoldParse gold) except *
#cdef class PyState:
# """Provide a Python class for testing purposes."""
# cdef Pool mem
# cdef TransitionSystem system
# cdef State* _state

View File

@ -1,8 +1,9 @@
from cymem.cymem cimport Pool from cymem.cymem cimport Pool
from ._state cimport State
from ..structs cimport TokenC from ..structs cimport TokenC
from thinc.typedefs cimport weight_t from thinc.typedefs cimport weight_t
from .stateclass cimport StateClass
cdef weight_t MIN_SCORE = -90000 cdef weight_t MIN_SCORE = -90000
@ -27,10 +28,10 @@ cdef class TransitionSystem:
i += 1 i += 1
self.c = moves self.c = moves
cdef int initialize_state(self, State* state) except -1: cdef int initialize_state(self, StateClass state) except -1:
pass pass
cdef int finalize_state(self, State* state) except -1: cdef int finalize_state(self, StateClass state) except -1:
pass pass
cdef int preprocess_gold(self, GoldParse gold) except -1: cdef int preprocess_gold(self, GoldParse gold) except -1:
@ -42,62 +43,30 @@ cdef class TransitionSystem:
cdef Transition init_transition(self, int clas, int move, int label) except *: cdef Transition init_transition(self, int clas, int move, int label) except *:
raise NotImplementedError raise NotImplementedError
cdef Transition best_valid(self, const weight_t* scores, const State* s) except *: cdef Transition best_valid(self, const weight_t* scores, StateClass s) except *:
raise NotImplementedError raise NotImplementedError
cdef int set_valid(self, bint* output, const State* state) except -1: cdef int set_valid(self, bint* output, StateClass state) except -1:
raise NotImplementedError raise NotImplementedError
cdef int set_costs(self, int* output, const State* s, GoldParse gold) except -1: cdef int set_costs(self, int* output, StateClass stcls, GoldParse gold) except -1:
cdef int i cdef int i
for i in range(self.n_moves): for i in range(self.n_moves):
output[i] = self.c[i].get_cost(s, &gold.c, self.c[i].label) if self.c[i].is_valid(stcls, self.c[i].label):
output[i] = self.c[i].get_cost(stcls, &gold.c, self.c[i].label)
else:
output[i] = 9000
cdef Transition best_gold(self, const weight_t* scores, const State* s, cdef Transition best_gold(self, const weight_t* scores, StateClass stcls,
GoldParse gold) except *: GoldParse gold) except *:
cdef Transition best cdef Transition best
cdef weight_t score = MIN_SCORE cdef weight_t score = MIN_SCORE
cdef int i cdef int i
for i in range(self.n_moves): for i in range(self.n_moves):
cost = self.c[i].get_cost(s, &gold.c, self.c[i].label) if self.c[i].is_valid(stcls, self.c[i].label):
if scores[i] > score and cost == 0: cost = self.c[i].get_cost(stcls, &gold.c, self.c[i].label)
best = self.c[i] if scores[i] > score and cost == 0:
score = scores[i] best = self.c[i]
score = scores[i]
assert score > MIN_SCORE assert score > MIN_SCORE
return best return best
#cdef class PyState:
# """Provide a Python class for testing purposes."""
# def __init__(self, GoldParse gold):
# self.mem = Pool()
# self.system = EntityRecognition(labels)
# self._state = init_state(self.mem, tokens, gold.length)
#
# def transition(self, name):
# cdef const Transition* trans = self._transition_by_name(name)
# trans.do(trans, self._state)
#
# def is_valid(self, name):
# cdef const Transition* trans = self._transition_by_name(name)
# return _is_valid(trans.move, trans.label, self._state)
#
# def is_gold(self, name):
# cdef const Transition* trans = self._transition_by_name(name)
# return _get_const(trans, self._state, self._gold)
#
# property ent:
# def __get__(self):
# pass
#
# property n_ents:
# def __get__(self):
# pass
#
# property i:
# def __get__(self):
# pass
#
# property open_entity:
# def __get__(self):
# return entity_is_open(self._s)

View File

@ -1,7 +1,7 @@
from libc.stdint cimport uint32_t from libc.stdint cimport uint32_t
from numpy cimport ndarray from numpy cimport ndarray
cimport numpy cimport numpy as np
from cymem.cymem cimport Pool from cymem.cymem cimport Pool
from thinc.typedefs cimport atom_t from thinc.typedefs cimport atom_t
@ -47,7 +47,7 @@ cdef class Tokens:
cdef int push_back(self, int i, LexemeOrToken lex_or_tok) except -1 cdef int push_back(self, int i, LexemeOrToken lex_or_tok) except -1
cpdef long[:,:] to_array(self, object features) cpdef np.ndarray to_array(self, object features)
cdef int set_parse(self, const TokenC* parsed) except -1 cdef int set_parse(self, const TokenC* parsed) except -1

View File

@ -17,8 +17,11 @@ from .spans import Span
from .structs cimport UniStr from .structs cimport UniStr
from unidecode import unidecode from unidecode import unidecode
# Compiler crashes on memory view coercion without this. Should report bug.
from cython.view cimport array as cvarray
cimport numpy as np
np.import_array()
cimport numpy
import numpy import numpy
cimport cython cimport cython
@ -183,15 +186,12 @@ cdef class Tokens:
""" """
cdef int i cdef int i
cdef Tokens sent = Tokens(self.vocab, self._string[self.data[0].idx:]) cdef Tokens sent = Tokens(self.vocab, self._string[self.data[0].idx:])
start = None start = 0
for i in range(self.length): for i in range(1, self.length):
if start is None: if self.data[i].sent_start:
yield Span(self, start, i)
start = i start = i
if self.data[i].sent_end: yield Span(self, start, self.length)
yield Span(self, start, i+1)
start = None
if start is not None:
yield Span(self, start, self.length)
cdef int push_back(self, int idx, LexemeOrToken lex_or_tok) except -1: cdef int push_back(self, int idx, LexemeOrToken lex_or_tok) except -1:
if self.length == self.max_length: if self.length == self.max_length:
@ -207,7 +207,7 @@ cdef class Tokens:
return idx + t.lex.length return idx + t.lex.length
@cython.boundscheck(False) @cython.boundscheck(False)
cpdef long[:,:] to_array(self, object py_attr_ids): cpdef np.ndarray to_array(self, object py_attr_ids):
"""Given a list of M attribute IDs, export the tokens to a numpy ndarray """Given a list of M attribute IDs, export the tokens to a numpy ndarray
of shape N*M, where N is the length of the sentence. of shape N*M, where N is the length of the sentence.
@ -221,10 +221,10 @@ cdef class Tokens:
""" """
cdef int i, j cdef int i, j
cdef attr_id_t feature cdef attr_id_t feature
cdef numpy.ndarray[long, ndim=2] output cdef np.ndarray[long, ndim=2] output
# Make an array from the attributes --- otherwise our inner loop is Python # Make an array from the attributes --- otherwise our inner loop is Python
# dict iteration. # dict iteration.
cdef numpy.ndarray[long, ndim=1] attr_ids = numpy.asarray(py_attr_ids) cdef np.ndarray[long, ndim=1] attr_ids = numpy.asarray(py_attr_ids)
output = numpy.ndarray(shape=(self.length, len(attr_ids)), dtype=numpy.int) output = numpy.ndarray(shape=(self.length, len(attr_ids)), dtype=numpy.int)
for i in range(self.length): for i in range(self.length):
for j, feature in enumerate(attr_ids): for j, feature in enumerate(attr_ids):
@ -464,7 +464,9 @@ cdef class Token:
property repvec: property repvec:
def __get__(self): def __get__(self):
return numpy.asarray(<float[:self.vocab.repvec_length,]> self.c.lex.repvec) cdef int length = self.vocab.repvec_length
repvec_view = <float[:length,]>self.c.lex.repvec
return numpy.asarray(repvec_view)
property n_lefts: property n_lefts:
def __get__(self): def __get__(self):
@ -546,13 +548,13 @@ cdef class Token:
property left_edge: property left_edge:
def __get__(self): def __get__(self):
return Token.cinit(self.vocab, self._string, return Token.cinit(self.vocab, self._string,
self.c + self.c.l_edge, self.i + self.c.l_edge, (self.c - self.i) + self.c.l_edge, self.c.l_edge,
self.array_len, self._seq) self.array_len, self._seq)
property right_edge: property right_edge:
def __get__(self): def __get__(self):
return Token.cinit(self.vocab, self._string, return Token.cinit(self.vocab, self._string,
self.c + self.c.r_edge, self.i + self.c.r_edge, (self.c - self.i) + self.c.r_edge, self.c.r_edge,
self.array_len, self._seq) self.array_len, self._seq)
property head: property head:

View File

@ -0,0 +1,59 @@
import spacy.munge.read_conll
hongbin_example = """
1 2. 0. LS _ 24 meta _ _ _
2 . . . _ 1 punct _ _ _
3 Wang wang NNP _ 4 compound _ _ _
4 Hongbin hongbin NNP _ 16 nsubj _ _ _
5 , , , _ 4 punct _ _ _
6 the the DT _ 11 det _ _ _
7 " " `` _ 11 punct _ _ _
8 communist communist JJ _ 11 amod _ _ _
9 trail trail NN _ 11 compound _ _ _
10 - - HYPH _ 11 punct _ _ _
11 blazer blazer NN _ 4 appos _ _ _
12 , , , _ 16 punct _ _ _
13 " " '' _ 16 punct _ _ _
14 has have VBZ _ 16 aux _ _ _
15 not not RB _ 16 neg _ _ _
16 turned turn VBN _ 24 ccomp _ _ _
17 into into IN syn=CLR 16 prep _ _ _
18 a a DT _ 19 det _ _ _
19 capitalist capitalist NN _ 17 pobj _ _ _
20 ( ( -LRB- _ 24 punct _ _ _
21 he he PRP _ 24 nsubj _ _ _
22 does do VBZ _ 24 aux _ _ _
23 n't not RB _ 24 neg _ _ _
24 have have VB _ 0 root _ _ _
25 any any DT _ 26 det _ _ _
26 shares share NNS _ 24 dobj _ _ _
27 , , , _ 24 punct _ _ _
28 does do VBZ _ 30 aux _ _ _
29 n't not RB _ 30 neg _ _ _
30 have have VB _ 24 conj _ _ _
31 any any DT _ 32 det _ _ _
32 savings saving NNS _ 30 dobj _ _ _
33 , , , _ 30 punct _ _ _
34 does do VBZ _ 36 aux _ _ _
35 n't not RB _ 36 neg _ _ _
36 have have VB _ 30 conj _ _ _
37 his his PRP$ _ 39 poss _ _ _
38 own own JJ _ 39 amod _ _ _
39 car car NN _ 36 dobj _ _ _
40 , , , _ 36 punct _ _ _
41 and and CC _ 36 cc _ _ _
42 does do VBZ _ 44 aux _ _ _
43 n't not RB _ 44 neg _ _ _
44 have have VB _ 36 conj _ _ _
45 a a DT _ 46 det _ _ _
46 mansion mansion NN _ 44 dobj _ _ _
47 ; ; . _ 24 punct _ _ _
""".strip()
def test_hongbin():
words, annot = spacy.munge.read_conll.parse(hongbin_example, strip_bad_periods=True)
assert words[annot[0]['head']] == 'have'
assert words[annot[1]['head']] == 'Hongbin'

View File

@ -103,10 +103,12 @@ def test_cnts5(en_tokenizer):
tokens = en_tokenizer(text) tokens = en_tokenizer(text)
assert len(tokens) == 11 assert len(tokens) == 11
def test_mr(en_tokenizer): # TODO: This is currently difficult --- infix interferes here.
text = """Mr. Smith""" #def test_mr(en_tokenizer):
tokens = en_tokenizer(text) # text = """Today is Tuesday.Mr."""
assert len(tokens) == 2 # tokens = en_tokenizer(text)
# assert len(tokens) == 5
# assert [w.orth_ for w in tokens] == ['Today', 'is', 'Tuesday', '.', 'Mr.']
def test_cnts6(en_tokenizer): def test_cnts6(en_tokenizer):