mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 18:26:30 +03:00
* Fix merge conflicts
This commit is contained in:
commit
6dbe182491
|
@ -48,7 +48,7 @@ def add_noise(orig, noise_level):
|
|||
return ''.join(_corrupt(c, noise_level) for c in orig)
|
||||
|
||||
|
||||
def score_model(scorer, nlp, raw_text, annot_tuples):
|
||||
def score_model(scorer, nlp, raw_text, annot_tuples, verbose=False):
|
||||
if raw_text is None:
|
||||
tokens = nlp.tokenizer.tokens_from_list(annot_tuples[1])
|
||||
else:
|
||||
|
@ -57,7 +57,7 @@ def score_model(scorer, nlp, raw_text, annot_tuples):
|
|||
nlp.entity(tokens)
|
||||
nlp.parser(tokens)
|
||||
gold = GoldParse(tokens, annot_tuples)
|
||||
scorer.score(tokens, gold, verbose=False)
|
||||
scorer.score(tokens, gold, verbose=verbose)
|
||||
|
||||
|
||||
def _merge_sents(sents):
|
||||
|
@ -78,7 +78,8 @@ def _merge_sents(sents):
|
|||
|
||||
def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic',
|
||||
seed=0, gold_preproc=False, n_sents=0, corruption_level=0,
|
||||
beam_width=1):
|
||||
beam_width=1, verbose=False,
|
||||
use_orig_arc_eager=False):
|
||||
dep_model_dir = path.join(model_dir, 'deps')
|
||||
pos_model_dir = path.join(model_dir, 'pos')
|
||||
ner_model_dir = path.join(model_dir, 'ner')
|
||||
|
@ -118,7 +119,8 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic',
|
|||
for annot_tuples, ctnt in sents:
|
||||
if len(annot_tuples[1]) == 1:
|
||||
continue
|
||||
score_model(scorer, nlp, raw_text, annot_tuples)
|
||||
score_model(scorer, nlp, raw_text, annot_tuples,
|
||||
verbose=verbose if itn >= 2 else False)
|
||||
if raw_text is None:
|
||||
words = add_noise(annot_tuples[1], corruption_level)
|
||||
tokens = nlp.tokenizer.tokens_from_list(words)
|
||||
|
@ -127,8 +129,12 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic',
|
|||
tokens = nlp.tokenizer(raw_text)
|
||||
nlp.tagger(tokens)
|
||||
gold = GoldParse(tokens, annot_tuples, make_projective=True)
|
||||
if not gold.is_projective:
|
||||
raise Exception(
|
||||
"Non-projective sentence in training, after we should "
|
||||
"have enforced projectivity: %s" % annot_tuples
|
||||
)
|
||||
loss += nlp.parser.train(tokens, gold)
|
||||
|
||||
nlp.entity.train(tokens, gold)
|
||||
nlp.tagger.train(tokens, gold.tags)
|
||||
random.shuffle(gold_tuples)
|
||||
|
@ -203,20 +209,24 @@ def write_parses(Language, dev_loc, model_dir, out_loc, beam_width=None):
|
|||
n_iter=("Number of training iterations", "option", "i", int),
|
||||
beam_width=("Number of candidates to maintain in the beam", "option", "k", int),
|
||||
verbose=("Verbose error reporting", "flag", "v", bool),
|
||||
debug=("Debug mode", "flag", "d", bool)
|
||||
debug=("Debug mode", "flag", "d", bool),
|
||||
use_orig_arc_eager=("Use the original, monotonic arc-eager system", "flag", "m", bool)
|
||||
)
|
||||
def main(train_loc, dev_loc, model_dir, n_sents=0, n_iter=15, out_loc="", verbose=False,
|
||||
debug=False, corruption_level=0.0, gold_preproc=False, beam_width=1,
|
||||
eval_only=False):
|
||||
eval_only=False, use_orig_arc_eager=False):
|
||||
if use_orig_arc_eager:
|
||||
English.ParserTransitionSystem = TreeArcEager
|
||||
if not eval_only:
|
||||
gold_train = list(read_json_file(train_loc))
|
||||
train(English, gold_train, model_dir,
|
||||
feat_set='basic' if not debug else 'debug',
|
||||
gold_preproc=gold_preproc, n_sents=n_sents,
|
||||
corruption_level=corruption_level, n_iter=n_iter,
|
||||
beam_width=beam_width)
|
||||
if out_loc:
|
||||
write_parses(English, dev_loc, model_dir, out_loc, beam_width=beam_width)
|
||||
beam_width=beam_width, verbose=verbose,
|
||||
use_orig_arc_eager=use_orig_arc_eager)
|
||||
#if out_loc:
|
||||
# write_parses(English, dev_loc, model_dir, out_loc, beam_width=beam_width)
|
||||
scorer = evaluate(English, list(read_json_file(dev_loc)),
|
||||
model_dir, gold_preproc=gold_preproc, verbose=verbose,
|
||||
beam_width=beam_width)
|
||||
|
|
|
@ -2,7 +2,7 @@ cython
|
|||
cymem == 1.11
|
||||
pathlib
|
||||
preshed == 0.37
|
||||
thinc == 1.76
|
||||
thinc == 2.0
|
||||
murmurhash == 0.24
|
||||
unidecode
|
||||
numpy
|
||||
|
|
10
setup.py
10
setup.py
|
@ -118,7 +118,7 @@ def run_setup(exts):
|
|||
ext_modules=exts,
|
||||
license="Dual: Commercial or AGPL",
|
||||
install_requires=['numpy', 'murmurhash', 'cymem >= 1.11', 'preshed == 0.37',
|
||||
'thinc == 1.76', "unidecode", 'wget', 'plac', 'six',
|
||||
'thinc == 2.0', "unidecode", 'wget', 'plac', 'six',
|
||||
'ujson'],
|
||||
setup_requires=["headers_workaround"],
|
||||
)
|
||||
|
@ -150,11 +150,13 @@ def main(modules, is_pypy):
|
|||
MOD_NAMES = ['spacy.parts_of_speech', 'spacy.strings',
|
||||
'spacy.lexeme', 'spacy.vocab', 'spacy.tokens', 'spacy.spans',
|
||||
'spacy.morphology',
|
||||
'spacy.syntax.stateclass',
|
||||
'spacy._ml', 'spacy.tokenizer', 'spacy.en.attrs',
|
||||
'spacy.en.pos', 'spacy.syntax.parser', 'spacy.syntax._state',
|
||||
'spacy.en.pos', 'spacy.syntax.parser',
|
||||
'spacy.syntax.transition_system',
|
||||
'spacy.syntax.arc_eager', 'spacy.syntax._parse_features',
|
||||
'spacy.gold', 'spacy.orth',
|
||||
'spacy.syntax.arc_eager',
|
||||
'spacy.syntax._parse_features',
|
||||
'spacy.gold', 'spacy.orth',
|
||||
'spacy.syntax.ner']
|
||||
|
||||
|
||||
|
|
|
@ -121,7 +121,7 @@ def _min_edit_path(cand_words, gold_words):
|
|||
return prev_costs[n_gold], previous_row[-1]
|
||||
|
||||
|
||||
def read_json_file(loc):
|
||||
def read_json_file(loc, docs_filter=None):
|
||||
print loc
|
||||
if path.isdir(loc):
|
||||
for filename in os.listdir(loc):
|
||||
|
@ -130,6 +130,8 @@ def read_json_file(loc):
|
|||
with open(loc) as file_:
|
||||
docs = ujson.load(file_)
|
||||
for doc in docs:
|
||||
if docs_filter is not None and not docs_filter(doc):
|
||||
continue
|
||||
paragraphs = []
|
||||
for paragraph in doc['paragraphs']:
|
||||
sents = []
|
||||
|
@ -146,6 +148,9 @@ def read_json_file(loc):
|
|||
tags.append(token['tag'])
|
||||
heads.append(token['head'] + i)
|
||||
labels.append(token['dep'])
|
||||
# Ensure ROOT label is case-insensitive
|
||||
if labels[-1].lower() == 'root':
|
||||
labels[-1] = 'ROOT'
|
||||
ner.append(token.get('ner', '-'))
|
||||
sents.append((
|
||||
(ids, words, tags, heads, labels, ner),
|
||||
|
@ -240,6 +245,16 @@ cdef class GoldParse:
|
|||
self.heads[w2] = None
|
||||
self.labels[w2] = ''
|
||||
|
||||
# Check there are no cycles in the dependencies, i.e. we are a tree
|
||||
for w in range(self.length):
|
||||
seen = set([w])
|
||||
head = w
|
||||
while self.heads[head] != head and self.heads[head] != None:
|
||||
head = self.heads[head]
|
||||
if head in seen:
|
||||
raise Exception("Cycle found: %s" % seen)
|
||||
seen.add(head)
|
||||
|
||||
self.brackets = {}
|
||||
for (gold_start, gold_end, label_str) in brackets:
|
||||
start = self.gold_to_cand[gold_start]
|
||||
|
|
|
@ -10,11 +10,12 @@ def parse(sent_text, strip_bad_periods=False):
|
|||
assert sent_text
|
||||
annot = []
|
||||
words = []
|
||||
id_map = {}
|
||||
id_map = {-1: -1}
|
||||
for i, line in enumerate(sent_text.split('\n')):
|
||||
word, tag, head, dep = _parse_line(line)
|
||||
if strip_bad_periods and words and _is_bad_period(words[-1], word):
|
||||
continue
|
||||
id_map[i] = len(words)
|
||||
|
||||
annot.append({
|
||||
'id': len(words),
|
||||
|
@ -23,6 +24,8 @@ def parse(sent_text, strip_bad_periods=False):
|
|||
'head': int(head) - 1,
|
||||
'dep': dep})
|
||||
words.append(word)
|
||||
for entry in annot:
|
||||
entry['head'] = id_map[entry['head']]
|
||||
return words, annot
|
||||
|
||||
|
||||
|
|
|
@ -113,3 +113,9 @@ class Scorer(object):
|
|||
set(item[:2] for item in cand_deps),
|
||||
set(item[:2] for item in gold_deps),
|
||||
)
|
||||
if verbose:
|
||||
gold_words = [item[1] for item in gold.orig_annot]
|
||||
for w_id, h_id, dep in (cand_deps - gold_deps):
|
||||
print 'F', gold_words[w_id], dep, gold_words[h_id]
|
||||
for w_id, h_id, dep in (gold_deps - cand_deps):
|
||||
print 'M', gold_words[w_id], dep, gold_words[h_id]
|
||||
|
|
|
@ -61,6 +61,9 @@ cdef class StringStore:
|
|||
def __get__(self):
|
||||
return self.size-1
|
||||
|
||||
def __len__(self):
|
||||
return self.size
|
||||
|
||||
def __getitem__(self, object string_or_id):
|
||||
cdef bytes byte_string
|
||||
cdef const Utf8Str* utf8str
|
||||
|
|
|
@ -68,7 +68,7 @@ cdef struct TokenC:
|
|||
int sense
|
||||
int head
|
||||
int dep
|
||||
bint sent_end
|
||||
bint sent_start
|
||||
|
||||
uint32_t l_kids
|
||||
uint32_t r_kids
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
from thinc.typedefs cimport atom_t
|
||||
|
||||
from ._state cimport State
|
||||
from .stateclass cimport StateClass
|
||||
|
||||
|
||||
cdef int fill_context(atom_t* context, State* state) except -1
|
||||
cdef int fill_context(atom_t* context, StateClass state) except -1
|
||||
# Context elements
|
||||
|
||||
# Ensure each token's attributes are listed: w, p, c, c6, c4. The order
|
||||
|
|
|
@ -12,12 +12,10 @@ from libc.string cimport memset
|
|||
from itertools import combinations
|
||||
|
||||
from ..tokens cimport TokenC
|
||||
from ._state cimport State
|
||||
from ._state cimport get_s2, get_s1, get_s0, get_n0, get_n1, get_n2
|
||||
from ._state cimport get_p2, get_p1
|
||||
from ._state cimport get_e0, get_e1
|
||||
from ._state cimport has_head, get_left, get_right
|
||||
from ._state cimport count_left_kids, count_right_kids
|
||||
|
||||
from .stateclass cimport StateClass
|
||||
|
||||
from cymem.cymem cimport Pool
|
||||
|
||||
|
||||
cdef inline void fill_token(atom_t* context, const TokenC* token) nogil:
|
||||
|
@ -53,56 +51,56 @@ cdef inline void fill_token(atom_t* context, const TokenC* token) nogil:
|
|||
# the source that are set to 1.
|
||||
context[4] = token.lex.cluster & 15
|
||||
context[5] = token.lex.cluster & 63
|
||||
context[6] = token.dep if has_head(token) else 0
|
||||
context[6] = token.dep if token.head != 0 else 0
|
||||
context[7] = token.lex.prefix
|
||||
context[8] = token.lex.suffix
|
||||
context[9] = token.lex.shape
|
||||
context[10] = token.ent_iob
|
||||
context[11] = token.ent_type
|
||||
|
||||
|
||||
cdef int fill_context(atom_t* context, State* state) except -1:
|
||||
cdef int fill_context(atom_t* ctxt, StateClass st) except -1:
|
||||
# Take care to fill every element of context!
|
||||
# We could memset, but this makes it very easy to have broken features that
|
||||
# make almost no impact on accuracy. If instead they're unset, the impact
|
||||
# tends to be dramatic, so we get an obvious regression to fix...
|
||||
fill_token(&context[S2w], get_s2(state))
|
||||
fill_token(&context[S1w], get_s1(state))
|
||||
fill_token(&context[S1rw], get_right(state, get_s1(state), 1))
|
||||
fill_token(&context[S0lw], get_left(state, get_s0(state), 1))
|
||||
fill_token(&context[S0l2w], get_left(state, get_s0(state), 2))
|
||||
fill_token(&context[S0w], get_s0(state))
|
||||
fill_token(&context[S0r2w], get_right(state, get_s0(state), 2))
|
||||
fill_token(&context[S0rw], get_right(state, get_s0(state), 1))
|
||||
fill_token(&context[N0lw], get_left(state, get_n0(state), 1))
|
||||
fill_token(&context[N0l2w], get_left(state, get_n0(state), 2))
|
||||
fill_token(&context[N0w], get_n0(state))
|
||||
fill_token(&context[N1w], get_n1(state))
|
||||
fill_token(&context[N2w], get_n2(state))
|
||||
fill_token(&context[P1w], get_p1(state))
|
||||
fill_token(&context[P2w], get_p2(state))
|
||||
fill_token(&ctxt[S2w], st.S_(2))
|
||||
fill_token(&ctxt[S1w], st.S_(1))
|
||||
fill_token(&ctxt[S1rw], st.R_(st.S(1), 1))
|
||||
fill_token(&ctxt[S0lw], st.L_(st.S(0), 1))
|
||||
fill_token(&ctxt[S0l2w], st.L_(st.S(0), 2))
|
||||
fill_token(&ctxt[S0w], st.S_(0))
|
||||
fill_token(&ctxt[S0r2w], st.R_(st.S(0), 2))
|
||||
fill_token(&ctxt[S0rw], st.R_(st.S(0), 1))
|
||||
fill_token(&ctxt[N0lw], st.L_(st.B(0), 1))
|
||||
fill_token(&ctxt[N0l2w], st.L_(st.B(0), 2))
|
||||
fill_token(&ctxt[N0w], st.B_(0))
|
||||
fill_token(&ctxt[N1w], st.B_(1))
|
||||
fill_token(&ctxt[N2w], st.B_(2))
|
||||
fill_token(&ctxt[P1w], st.safe_get(st.B(0)-1))
|
||||
fill_token(&ctxt[P2w], st.safe_get(st.B(0)-2))
|
||||
|
||||
fill_token(&context[E0w], get_e0(state))
|
||||
fill_token(&context[E1w], get_e1(state))
|
||||
if state.stack_len >= 1:
|
||||
context[dist] = min(state.i - state.stack[0], 5)
|
||||
fill_token(&ctxt[E0w], st.E_(0))
|
||||
fill_token(&ctxt[E1w], st.E_(1))
|
||||
|
||||
if st.stack_depth() >= 1 and not st.eol():
|
||||
ctxt[dist] = min(st.B(0) - st.E(0), 5)
|
||||
else:
|
||||
context[dist] = 0
|
||||
context[N0lv] = min(count_left_kids(get_n0(state)), 5)
|
||||
context[S0lv] = min(count_left_kids(get_s0(state)), 5)
|
||||
context[S0rv] = min(count_right_kids(get_s0(state)), 5)
|
||||
context[S1lv] = min(count_left_kids(get_s1(state)), 5)
|
||||
context[S1rv] = min(count_right_kids(get_s1(state)), 5)
|
||||
ctxt[dist] = 0
|
||||
ctxt[N0lv] = min(st.n_L(st.B(0)), 5)
|
||||
ctxt[S0lv] = min(st.n_L(st.S(0)), 5)
|
||||
ctxt[S0rv] = min(st.n_R(st.S(0)), 5)
|
||||
ctxt[S1lv] = min(st.n_L(st.S(1)), 5)
|
||||
ctxt[S1rv] = min(st.n_R(st.S(1)), 5)
|
||||
|
||||
context[S0_has_head] = 0
|
||||
context[S1_has_head] = 0
|
||||
context[S2_has_head] = 0
|
||||
if state.stack_len >= 1:
|
||||
context[S0_has_head] = has_head(get_s0(state)) + 1
|
||||
if state.stack_len >= 2:
|
||||
context[S1_has_head] = has_head(get_s1(state)) + 1
|
||||
if state.stack_len >= 3:
|
||||
context[S2_has_head] = has_head(get_s2(state)) + 1
|
||||
ctxt[S0_has_head] = 0
|
||||
ctxt[S1_has_head] = 0
|
||||
ctxt[S2_has_head] = 0
|
||||
if st.stack_depth() >= 1:
|
||||
ctxt[S0_has_head] = st.has_head(st.S(0)) + 1
|
||||
if st.stack_depth() >= 2:
|
||||
ctxt[S1_has_head] = st.has_head(st.S(1)) + 1
|
||||
if st.stack_depth() >= 3:
|
||||
ctxt[S2_has_head] = st.has_head(st.S(2)) + 1
|
||||
|
||||
|
||||
ner = (
|
||||
|
@ -266,6 +264,32 @@ s0_n0 = (
|
|||
(S0p, S0rp, N0p),
|
||||
(S0p, N0lp, N0W),
|
||||
(S0p, N0lp, N0p),
|
||||
(S0L, N0p),
|
||||
(S0p, S0rL, N0p),
|
||||
(S0p, N0lL, N0p),
|
||||
(S0p, S0rv, N0p),
|
||||
(S0p, N0lv, N0p),
|
||||
(S0c6, S0rL, S0r2L, N0p),
|
||||
(S0p, N0lL, N0l2L, N0p),
|
||||
)
|
||||
|
||||
|
||||
s1_s0 = (
|
||||
(S1p, S0p),
|
||||
(S1p, S0p, S0_has_head),
|
||||
(S1W, S0p),
|
||||
(S1W, S0p, S0_has_head),
|
||||
(S1c, S0p),
|
||||
(S1c, S0p, S0_has_head),
|
||||
(S1p, S1rL, S0p),
|
||||
(S1p, S1rL, S0p, S0_has_head),
|
||||
(S1p, S0lL, S0p),
|
||||
(S1p, S0lL, S0p, S0_has_head),
|
||||
(S1p, S0lL, S0l2L, S0p),
|
||||
(S1p, S0lL, S0l2L, S0p, S0_has_head),
|
||||
(S1L, S0L, S0W),
|
||||
(S1L, S0L, S0p),
|
||||
(S1p, S1L, S0L, S0p),
|
||||
)
|
||||
|
||||
|
||||
|
@ -277,6 +301,8 @@ s1_n0 = (
|
|||
(S1W, S1p, N0p),
|
||||
(S1p, N0W, N0p),
|
||||
(S1c6, S1p, N0c6, N0p),
|
||||
(S1L, N0p),
|
||||
(S1p, S1rL, N0p),
|
||||
)
|
||||
|
||||
|
||||
|
@ -288,6 +314,8 @@ s0_n1 = (
|
|||
(S0W, S0p, N1p),
|
||||
(S0p, N1W, N1p),
|
||||
(S0c6, S0p, N1c6, N1p),
|
||||
(S0L, N1p),
|
||||
(S0p, S0rL, N1p),
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -2,10 +2,16 @@ from cymem.cymem cimport Pool
|
|||
|
||||
from thinc.typedefs cimport weight_t
|
||||
|
||||
from .stateclass cimport StateClass
|
||||
|
||||
from ._state cimport State
|
||||
from .transition_system cimport TransitionSystem, Transition
|
||||
from ..gold cimport GoldParseC
|
||||
|
||||
|
||||
cdef class ArcEager(TransitionSystem):
|
||||
pass
|
||||
|
||||
|
||||
cdef int push_cost(StateClass stcls, const GoldParseC* gold, int target) nogil
|
||||
cdef int arc_cost(StateClass stcls, const GoldParseC* gold, int head, int child) nogil
|
||||
|
||||
|
|
|
@ -1,12 +1,8 @@
|
|||
# cython: profile=True
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from ._state cimport State
|
||||
from ._state cimport has_head, get_idx, get_s0, get_n0, get_left, get_right
|
||||
from ._state cimport is_final, at_eol, pop_stack, push_stack, add_dep
|
||||
from ._state cimport head_in_buffer, children_in_buffer
|
||||
from ._state cimport head_in_stack, children_in_stack
|
||||
from ._state cimport count_left_kids
|
||||
import ctypes
|
||||
import os
|
||||
|
||||
from ..structs cimport TokenC
|
||||
|
||||
|
@ -15,9 +11,16 @@ from .transition_system cimport move_cost_func_t, label_cost_func_t
|
|||
from ..gold cimport GoldParse
|
||||
from ..gold cimport GoldParseC
|
||||
|
||||
from libc.stdint cimport uint32_t
|
||||
from libc.string cimport memcpy
|
||||
|
||||
from cymem.cymem cimport Pool
|
||||
from .stateclass cimport StateClass
|
||||
|
||||
|
||||
DEF NON_MONOTONIC = True
|
||||
DEF USE_BREAK = True
|
||||
DEF USE_ROOT_ARC_SEGMENT = True
|
||||
|
||||
cdef weight_t MIN_SCORE = -90000
|
||||
|
||||
|
@ -31,9 +34,6 @@ cdef enum:
|
|||
|
||||
BREAK
|
||||
|
||||
CONSTITUENT
|
||||
ADJUST
|
||||
|
||||
N_MOVES
|
||||
|
||||
|
||||
|
@ -43,417 +43,259 @@ MOVE_NAMES[REDUCE] = 'D'
|
|||
MOVE_NAMES[LEFT] = 'L'
|
||||
MOVE_NAMES[RIGHT] = 'R'
|
||||
MOVE_NAMES[BREAK] = 'B'
|
||||
MOVE_NAMES[CONSTITUENT] = 'C'
|
||||
MOVE_NAMES[ADJUST] = 'A'
|
||||
|
||||
|
||||
# Helper functions for the arc-eager oracle
|
||||
|
||||
cdef int push_cost(const State* st, const GoldParseC* gold, int target) except -1:
|
||||
# When we push a word, we can't make arcs to or from the stack. So, we lose
|
||||
# any of those arcs.
|
||||
cdef int push_cost(StateClass stcls, const GoldParseC* gold, int target) nogil:
|
||||
cdef int cost = 0
|
||||
cost += head_in_stack(st, target, gold.heads)
|
||||
cost += children_in_stack(st, target, gold.heads)
|
||||
cdef int i, S_i
|
||||
for i in range(stcls.stack_depth()):
|
||||
S_i = stcls.S(i)
|
||||
if gold.heads[target] == S_i:
|
||||
cost += 1
|
||||
if gold.heads[S_i] == target and (NON_MONOTONIC or not stcls.has_head(S_i)):
|
||||
cost += 1
|
||||
cost += Break.is_valid(stcls, -1) and Break.move_cost(stcls, gold) == 0
|
||||
return cost
|
||||
|
||||
|
||||
cdef int pop_cost(const State* st, const GoldParseC* gold, int target) except -1:
|
||||
cdef int pop_cost(StateClass stcls, const GoldParseC* gold, int target) nogil:
|
||||
cdef int cost = 0
|
||||
cost += children_in_buffer(st, target, gold.heads)
|
||||
cost += head_in_buffer(st, target, gold.heads)
|
||||
cdef int i, B_i
|
||||
for i in range(stcls.buffer_length()):
|
||||
B_i = stcls.B(i)
|
||||
cost += gold.heads[B_i] == target
|
||||
cost += gold.heads[target] == B_i
|
||||
if gold.heads[B_i] == B_i or gold.heads[B_i] < target:
|
||||
break
|
||||
cost += Break.is_valid(stcls, -1) and Break.move_cost(stcls, gold) == 0
|
||||
return cost
|
||||
|
||||
|
||||
cdef int arc_cost(const GoldParseC* gold, int head, int child, int label) except -1:
|
||||
if gold.heads[child] != head:
|
||||
cdef int arc_cost(StateClass stcls, const GoldParseC* gold, int head, int child) nogil:
|
||||
if arc_is_gold(gold, head, child):
|
||||
return 0
|
||||
elif gold.labels[child] == -1:
|
||||
return 0
|
||||
elif gold.labels[child] == label:
|
||||
return 0
|
||||
else:
|
||||
elif stcls.H(child) == gold.heads[child]:
|
||||
return 1
|
||||
# Head in buffer
|
||||
elif gold.heads[child] >= stcls.B(0) and stcls.B(1) != -1:
|
||||
return 1
|
||||
else:
|
||||
return 0
|
||||
|
||||
|
||||
cdef bint arc_is_gold(const GoldParseC* gold, int head, int child) nogil:
    # Decide whether the arc head -> child is accepted as gold.
    # The checks run in a fixed order and the first one that matches wins.
    if gold.labels[child] == -1:
        # A gold label of -1 on the child accepts any head.
        # NOTE(review): -1 presumably marks a missing annotation -- confirm.
        return True
    if USE_ROOT_ARC_SEGMENT and _is_gold_root(gold, head) and _is_gold_root(gold, child):
        # With USE_ROOT_ARC_SEGMENT enabled, an arc between two words that
        # _is_gold_root() both reports as roots is accepted.
        return True
    # Otherwise the arc is gold exactly when the recorded head matches.
    return gold.heads[child] == head
|
||||
|
||||
|
||||
cdef bint label_is_gold(const GoldParseC* gold, int head, int child, int label) nogil:
    # Decide whether `label` is acceptable for the arc coming into `child`.
    # `head` is part of the signature but is not consulted here.
    if gold.labels[child] == -1 or label == -1:
        # A -1 on either side (gold label or proposed label) always passes.
        return True
    # Otherwise the proposed label has to equal the gold label.
    return gold.labels[child] == label
|
||||
|
||||
|
||||
cdef bint _is_gold_root(const GoldParseC* gold, int word) nogil:
    # A word is treated as a gold root when its gold label is -1, or when
    # the gold head recorded for it is the word itself.
    if gold.labels[word] == -1:
        return True
    return gold.heads[word] == word
|
||||
|
||||
|
||||
cdef class Shift:
|
||||
@staticmethod
|
||||
cdef bint is_valid(const State* s, int label) except -1:
|
||||
return not at_eol(s)
|
||||
cdef bint is_valid(StateClass st, int label) nogil:
|
||||
return st.buffer_length() >= 2 and not st.shifted[st.B(0)] and not st.B_(0).sent_start
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(State* state, int label) except -1:
|
||||
# Set the dep label, in case we need it after we reduce
|
||||
if NON_MONOTONIC:
|
||||
state.sent[state.i].dep = label
|
||||
push_stack(state)
|
||||
cdef int transition(StateClass st, int label) nogil:
|
||||
st.push()
|
||||
st.fast_forward()
|
||||
|
||||
@staticmethod
|
||||
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||
if not Shift.is_valid(s, label):
|
||||
return 9000
|
||||
return Shift.move_cost(s, gold) + Shift.label_cost(s, gold, label)
|
||||
cdef int cost(StateClass st, const GoldParseC* gold, int label) nogil:
|
||||
return Shift.move_cost(st, gold) + Shift.label_cost(st, gold, label)
|
||||
|
||||
@staticmethod
|
||||
cdef int move_cost(const State* s, const GoldParseC* gold) except -1:
|
||||
cdef int cost = push_cost(s, gold, s.i)
|
||||
# If we can break, and there's no cost to doing so, we should
|
||||
if Break.is_valid(s, -1) and Break.cost(s, gold, -1) == 0:
|
||||
cost += 1
|
||||
return cost
|
||||
cdef inline int move_cost(StateClass s, const GoldParseC* gold) nogil:
|
||||
return push_cost(s, gold, s.B(0))
|
||||
|
||||
@staticmethod
|
||||
cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||
cdef inline int label_cost(StateClass s, const GoldParseC* gold, int label) nogil:
|
||||
return 0
|
||||
|
||||
|
||||
cdef class Reduce:
|
||||
@staticmethod
|
||||
cdef bint is_valid(const State* s, int label) except -1:
|
||||
if NON_MONOTONIC:
|
||||
return s.stack_len >= 2 #and not missing_brackets(s)
|
||||
cdef bint is_valid(StateClass st, int label) nogil:
|
||||
return st.stack_depth() >= 2
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(StateClass st, int label) nogil:
|
||||
if st.has_head(st.S(0)):
|
||||
st.pop()
|
||||
else:
|
||||
return s.stack_len >= 2 and has_head(get_s0(s))
|
||||
st.unshift()
|
||||
st.fast_forward()
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(State* state, int label) except -1:
|
||||
if NON_MONOTONIC and not has_head(get_s0(state)):
|
||||
add_dep(state, state.stack[-1], state.stack[0], get_s0(state).dep)
|
||||
pop_stack(state)
|
||||
|
||||
@staticmethod
|
||||
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||
if not Reduce.is_valid(s, label):
|
||||
return 9000
|
||||
cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil:
|
||||
return Reduce.move_cost(s, gold) + Reduce.label_cost(s, gold, label)
|
||||
|
||||
@staticmethod
|
||||
cdef int move_cost(const State* s, const GoldParseC* gold) except -1:
|
||||
if NON_MONOTONIC:
|
||||
return pop_cost(s, gold, s.stack[0])
|
||||
else:
|
||||
return children_in_buffer(s, s.stack[0], gold.heads)
|
||||
cdef inline int move_cost(StateClass st, const GoldParseC* gold) nogil:
|
||||
return pop_cost(st, gold, st.S(0))
|
||||
|
||||
@staticmethod
|
||||
cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||
cdef inline int label_cost(StateClass s, const GoldParseC* gold, int label) nogil:
|
||||
return 0
|
||||
|
||||
|
||||
|
||||
cdef class LeftArc:
|
||||
@staticmethod
|
||||
cdef bint is_valid(const State* s, int label) except -1:
|
||||
if NON_MONOTONIC:
|
||||
return s.stack_len >= 1 #and not missing_brackets(s)
|
||||
else:
|
||||
return s.stack_len >= 1 and not has_head(get_s0(s))
|
||||
cdef bint is_valid(StateClass st, int label) nogil:
|
||||
return not st.B_(0).sent_start
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(State* state, int label) except -1:
|
||||
# Interpret left-arcs from EOL as attachment to root
|
||||
if at_eol(state):
|
||||
add_dep(state, state.stack[0], state.stack[0], label)
|
||||
else:
|
||||
add_dep(state, state.i, state.stack[0], label)
|
||||
pop_stack(state)
|
||||
cdef int transition(StateClass st, int label) nogil:
|
||||
st.add_arc(st.B(0), st.S(0), label)
|
||||
st.pop()
|
||||
st.fast_forward()
|
||||
|
||||
@staticmethod
|
||||
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||
if not LeftArc.is_valid(s, label):
|
||||
return 9000
|
||||
cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil:
|
||||
return LeftArc.move_cost(s, gold) + LeftArc.label_cost(s, gold, label)
|
||||
|
||||
@staticmethod
|
||||
cdef int move_cost(const State* s, const GoldParseC* gold) except -1:
|
||||
if not LeftArc.is_valid(s, -1):
|
||||
return 9000
|
||||
cdef inline int move_cost(StateClass s, const GoldParseC* gold) nogil:
|
||||
cdef int cost = 0
|
||||
if gold.heads[s.stack[0]] == s.i:
|
||||
return cost
|
||||
elif at_eol(s):
|
||||
# Are we root?
|
||||
if gold.labels[s.stack[0]] != -1:
|
||||
# If we're at EOL, prefer to reduce or break over left-arc
|
||||
if Reduce.is_valid(s, -1) or Break.is_valid(s, -1):
|
||||
cost += gold.heads[s.stack[0]] != s.stack[0]
|
||||
return cost
|
||||
cost += head_in_buffer(s, s.stack[0], gold.heads)
|
||||
cost += children_in_buffer(s, s.stack[0], gold.heads)
|
||||
if NON_MONOTONIC and s.stack_len >= 2:
|
||||
cost += gold.heads[s.stack[0]] == s.stack[-1]
|
||||
if gold.labels[s.stack[0]] != -1:
|
||||
cost += gold.heads[s.stack[0]] == s.stack[0]
|
||||
return cost
|
||||
if arc_is_gold(gold, s.B(0), s.S(0)):
|
||||
return 0
|
||||
else:
|
||||
# Account for deps we might lose between S0 and stack
|
||||
if not s.has_head(s.S(0)):
|
||||
for i in range(1, s.stack_depth()):
|
||||
cost += gold.heads[s.S(i)] == s.S(0)
|
||||
cost += gold.heads[s.S(0)] == s.S(i)
|
||||
return pop_cost(s, gold, s.S(0)) + arc_cost(s, gold, s.B(0), s.S(0))
|
||||
|
||||
@staticmethod
|
||||
cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||
if label == -1 or gold.labels[s.stack[0]] == -1:
|
||||
return 0
|
||||
if gold.heads[s.stack[0]] == s.i and label != gold.labels[s.stack[0]]:
|
||||
return 1
|
||||
return 0
|
||||
cdef inline int label_cost(StateClass s, const GoldParseC* gold, int label) nogil:
|
||||
return arc_is_gold(gold, s.B(0), s.S(0)) and not label_is_gold(gold, s.B(0), s.S(0), label)
|
||||
|
||||
|
||||
cdef class RightArc:
|
||||
@staticmethod
|
||||
cdef bint is_valid(const State* s, int label) except -1:
|
||||
return s.stack_len >= 1 and not at_eol(s)
|
||||
cdef bint is_valid(StateClass st, int label) nogil:
|
||||
return not st.B_(0).sent_start
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(State* state, int label) except -1:
|
||||
add_dep(state, state.stack[0], state.i, label)
|
||||
push_stack(state)
|
||||
cdef int transition(StateClass st, int label) nogil:
|
||||
st.add_arc(st.S(0), st.B(0), label)
|
||||
st.push()
|
||||
st.fast_forward()
|
||||
|
||||
@staticmethod
|
||||
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||
if not RightArc.is_valid(s, label):
|
||||
return 9000
|
||||
cdef inline int cost(StateClass s, const GoldParseC* gold, int label) nogil:
|
||||
return RightArc.move_cost(s, gold) + RightArc.label_cost(s, gold, label)
|
||||
|
||||
@staticmethod
|
||||
cdef int move_cost(const State* s, const GoldParseC* gold) except -1:
|
||||
return push_cost(s, gold, s.i) - (gold.heads[s.i] == s.stack[0])
|
||||
cdef inline int move_cost(StateClass s, const GoldParseC* gold) nogil:
|
||||
if arc_is_gold(gold, s.S(0), s.B(0)):
|
||||
return 0
|
||||
elif s.shifted[s.B(0)]:
|
||||
return push_cost(s, gold, s.B(0))
|
||||
else:
|
||||
return push_cost(s, gold, s.B(0)) + arc_cost(s, gold, s.S(0), s.B(0))
|
||||
|
||||
@staticmethod
|
||||
cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||
return arc_cost(gold, s.stack[0], s.i, label)
|
||||
#cdef int cost = 0
|
||||
#if gold.heads[s.i] == s.stack[0]:
|
||||
# cost += label != -1 and label != gold.labels[s.i]
|
||||
# return cost
|
||||
# This indicates missing head
|
||||
#if gold.labels[s.i] != -1:
|
||||
# cost += head_in_buffer(s, s.i, gold.heads)
|
||||
#cost += children_in_stack(s, s.i, gold.heads)
|
||||
#cost += head_in_stack(s, s.i, gold.heads)
|
||||
#return cost
|
||||
cdef int label_cost(StateClass s, const GoldParseC* gold, int label) nogil:
|
||||
return arc_is_gold(gold, s.S(0), s.B(0)) and not label_is_gold(gold, s.S(0), s.B(0), label)
|
||||
|
||||
|
||||
cdef class Break:
|
||||
@staticmethod
|
||||
cdef bint is_valid(const State* s, int label) except -1:
|
||||
cdef bint is_valid(StateClass st, int label) nogil:
|
||||
cdef int i
|
||||
if not USE_BREAK:
|
||||
return False
|
||||
elif at_eol(s):
|
||||
elif st.at_break():
|
||||
return False
|
||||
elif st.B(0) == 0:
|
||||
return False
|
||||
elif st.stack_depth() < 1:
|
||||
return False
|
||||
elif (st.S(0) + 1) != st.B(0):
|
||||
# Must break at the token boundary
|
||||
return False
|
||||
#elif NON_MONOTONIC:
|
||||
# return True
|
||||
else:
|
||||
# In the Break transition paper, they have this constraint that prevents
|
||||
# Break if stack is disconnected. But, if we're doing non-monotonic parsing,
|
||||
# we prefer to relax this constraint. This is helpful in parsing whole
|
||||
# documents, because then we don't get stuck with words on the stack.
|
||||
seen_headless = False
|
||||
for i in range(s.stack_len):
|
||||
if s.sent[s.stack[-i]].head == 0:
|
||||
if seen_headless:
|
||||
return False
|
||||
else:
|
||||
seen_headless = True
|
||||
# TODO: Constituency constraints
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(State* state, int label) except -1:
|
||||
state.sent[state.i-1].sent_end = True
|
||||
while state.stack_len != 0:
|
||||
if get_s0(state).head == 0:
|
||||
get_s0(state).dep = label
|
||||
state.stack -= 1
|
||||
state.stack_len -= 1
|
||||
if not at_eol(state):
|
||||
push_stack(state)
|
||||
cdef int transition(StateClass st, int label) nogil:
|
||||
st.set_break(st.B(0))
|
||||
st.fast_forward()
|
||||
|
||||
@staticmethod
|
||||
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||
if not Break.is_valid(s, label):
|
||||
return 9000
|
||||
else:
|
||||
return Break.move_cost(s, gold) + Break.label_cost(s, gold, label)
|
||||
cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil:
|
||||
return Break.move_cost(s, gold) + Break.label_cost(s, gold, label)
|
||||
|
||||
@staticmethod
|
||||
cdef int move_cost(const State* s, const GoldParseC* gold) except -1:
|
||||
# When we break, we Reduce all of the words on the stack.
|
||||
cdef inline int move_cost(StateClass s, const GoldParseC* gold) nogil:
|
||||
cdef int cost = 0
|
||||
# Number of deps between S0...Sn and N0...Nn
|
||||
for i in range(s.i, s.sent_len):
|
||||
cost += children_in_stack(s, i, gold.heads)
|
||||
cost += head_in_stack(s, i, gold.heads)
|
||||
return cost
|
||||
cdef int i, j, S_i, B_i
|
||||
for i in range(s.stack_depth()):
|
||||
S_i = s.S(i)
|
||||
for j in range(s.buffer_length()):
|
||||
B_i = s.B(j)
|
||||
cost += gold.heads[S_i] == B_i
|
||||
cost += gold.heads[B_i] == S_i
|
||||
# Check for sentence boundary --- if it's here, we can't have any deps
|
||||
# between stack and buffer, so rest of action is irrelevant.
|
||||
s0_root = _get_root(s.S(0), gold)
|
||||
b0_root = _get_root(s.B(0), gold)
|
||||
if s0_root != b0_root or s0_root == -1 or b0_root == -1:
|
||||
return cost
|
||||
else:
|
||||
return cost + 1
|
||||
|
||||
@staticmethod
|
||||
cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||
cdef inline int label_cost(StateClass s, const GoldParseC* gold, int label) nogil:
|
||||
return 0
|
||||
|
||||
|
||||
cdef class Constituent:
|
||||
@staticmethod
|
||||
cdef bint is_valid(const State* s, int label) except -1:
|
||||
if s.stack_len < 1:
|
||||
return False
|
||||
return False
|
||||
#else:
|
||||
# # If all stack elements are popped, can't constituent
|
||||
# for i in range(s.ctnts.stack_len):
|
||||
# if not s.ctnts.is_popped[-i]:
|
||||
# return True
|
||||
# else:
|
||||
# return False
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(State* state, int label) except -1:
|
||||
return False
|
||||
#cdef Constituent* bracket = new_bracket(state.ctnts)
|
||||
|
||||
#bracket.parent = NULL
|
||||
#bracket.label = self.label
|
||||
#bracket.head = get_s0(state)
|
||||
#bracket.length = 0
|
||||
|
||||
#attach(bracket, state.ctnts.stack)
|
||||
# Attach rightward children. They're in the brackets array somewhere
|
||||
# between here and B0.
|
||||
#cdef Constituent* node
|
||||
#cdef const TokenC* node_gov
|
||||
#for i in range(1, bracket - state.ctnts.stack):
|
||||
# node = bracket - i
|
||||
# node_gov = node.head + node.head.head
|
||||
# if node_gov == bracket.head:
|
||||
# attach(bracket, node)
|
||||
|
||||
@staticmethod
|
||||
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||
if not Constituent.is_valid(s, label):
|
||||
return 9000
|
||||
raise Exception("Constituent move should be disabled currently")
|
||||
# The gold standard is indexed by end, then by start, then a set of labels
|
||||
#brackets = gold.brackets(get_s0(s).r_edge, {})
|
||||
#if not brackets:
|
||||
# return 2 # 2 loss for bad bracket, only 1 for good bracket bad label
|
||||
# Index the current brackets in the state
|
||||
#existing = set()
|
||||
#for i in range(s.ctnt_len):
|
||||
# if ctnt.end == s.r_edge and ctnt.label == self.label:
|
||||
# existing.add(ctnt.start)
|
||||
#cdef int loss = 2
|
||||
#cdef const TokenC* child
|
||||
#cdef const TokenC* s0 = get_s0(s)
|
||||
#cdef int n_left = count_left_kids(s0)
|
||||
# Iterate over the possible start positions, and check whether we have a
|
||||
# (start, end, label) match to the gold tree
|
||||
#for i in range(1, n_left):
|
||||
# child = get_left(s, s0, i)
|
||||
# if child.l_edge in brackets and child.l_edge not in existing:
|
||||
# if self.label in brackets[child.l_edge]
|
||||
# return 0
|
||||
# else:
|
||||
# loss = 1 # If we see the start position, set loss to 1
|
||||
#return loss
|
||||
|
||||
@staticmethod
|
||||
cdef int move_cost(const State* s, const GoldParseC* gold) except -1:
|
||||
if not Constituent.is_valid(s, -1):
|
||||
return 9000
|
||||
raise Exception("Constituent move should be disabled currently")
|
||||
|
||||
@staticmethod
|
||||
cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||
return 0
|
||||
|
||||
|
||||
|
||||
cdef class Adjust:
|
||||
@staticmethod
|
||||
cdef bint is_valid(const State* s, int label) except -1:
|
||||
return False
|
||||
#if s.ctnts.stack_len < 2:
|
||||
# return False
|
||||
|
||||
#cdef const Constituent* b1 = s.ctnts.stack[-1]
|
||||
#cdef const Constituent* b0 = s.ctnts.stack[0]
|
||||
|
||||
#if (b1.head + b1.head.head) != b0.head:
|
||||
# return False
|
||||
#elif b0.head >= b1.head:
|
||||
# return False
|
||||
#elif b0 >= b1:
|
||||
# return False
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(State* state, int label) except -1:
|
||||
return False
|
||||
#cdef Constituent* b0 = state.ctnts.stack[0]
|
||||
#cdef Constituent* b1 = state.ctnts.stack[1]
|
||||
|
||||
#assert (b1.head + b1.head.head) == b0.head
|
||||
#assert b0.head < b1.head
|
||||
#assert b0 < b1
|
||||
|
||||
#attach(b0, b1)
|
||||
## Pop B1 from stack, but keep B0 on top
|
||||
#state.ctnts.stack -= 1
|
||||
#state.ctnts.stack[0] = b0
|
||||
|
||||
@staticmethod
|
||||
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||
if not Adjust.is_valid(s, label):
|
||||
return 9000
|
||||
raise Exception("Adjust move should be disabled currently")
|
||||
|
||||
@staticmethod
|
||||
cdef int move_cost(const State* s, const GoldParseC* gold) except -1:
|
||||
if not Adjust.is_valid(s, -1):
|
||||
return 9000
|
||||
raise Exception("Adjust move should be disabled currently")
|
||||
|
||||
@staticmethod
|
||||
cdef int label_cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||
return 0
|
||||
# The gold standard is indexed by end, then by start, then a set of labels
|
||||
#gold_starts = gold.brackets(get_s0(s).r_edge, {})
|
||||
# Case 1: There are 0 brackets ending at this word.
|
||||
# --> Cost is sunk, but must allow brackets to begin
|
||||
#if not gold_starts:
|
||||
# return 0
|
||||
# Is the top bracket correct?
|
||||
#gold_labels = gold_starts.get(s.ctnt.start, set())
|
||||
# TODO: Case where we have a unary rule
|
||||
# TODO: Case where two brackets end on this word, with top bracket starting
|
||||
# before
|
||||
|
||||
#cdef const TokenC* child
|
||||
#cdef const TokenC* s0 = get_s0(s)
|
||||
#cdef int n_left = count_left_kids(s0)
|
||||
#cdef int i
|
||||
# Iterate over the possible start positions, and check whether we have a
|
||||
# (start, end, label) match to the gold tree
|
||||
#for i in range(1, n_left):
|
||||
# child = get_left(s, s0, i)
|
||||
# if child.l_edge in brackets:
|
||||
# if self.label in brackets[child.l_edge]:
|
||||
# return 0
|
||||
# else:
|
||||
# loss = 1 # If we see the start position, set loss to 1
|
||||
#return loss
|
||||
|
||||
cdef int _get_root(int word, const GoldParseC* gold) nogil:
|
||||
while gold.heads[word] != word and gold.labels[word] != -1 and word >= 0:
|
||||
word = gold.heads[word]
|
||||
if gold.labels[word] == -1:
|
||||
return -1
|
||||
else:
|
||||
return word
|
||||
|
||||
|
||||
cdef class ArcEager(TransitionSystem):
|
||||
@classmethod
|
||||
def get_labels(cls, gold_parses):
|
||||
move_labels = {SHIFT: {'': True}, REDUCE: {'': True}, RIGHT: {},
|
||||
LEFT: {'ROOT': True}, BREAK: {'ROOT': True},
|
||||
CONSTITUENT: {}, ADJUST: {'': True}}
|
||||
move_labels = {SHIFT: {'': True}, REDUCE: {'': True}, RIGHT: {'ROOT': True},
|
||||
LEFT: {'ROOT': True}, BREAK: {'ROOT': True}}
|
||||
for raw_text, sents in gold_parses:
|
||||
for (ids, words, tags, heads, labels, iob), ctnts in sents:
|
||||
for child, head, label in zip(ids, heads, labels):
|
||||
if label.upper() == 'ROOT':
|
||||
label = 'ROOT'
|
||||
if label != 'ROOT':
|
||||
if head < child:
|
||||
move_labels[RIGHT][label] = True
|
||||
elif head > child:
|
||||
move_labels[LEFT][label] = True
|
||||
for start, end, label in ctnts:
|
||||
move_labels[CONSTITUENT][label] = True
|
||||
return move_labels
|
||||
|
||||
cdef int preprocess_gold(self, GoldParse gold) except -1:
|
||||
|
@ -462,8 +304,11 @@ cdef class ArcEager(TransitionSystem):
|
|||
gold.c.heads[i] = i
|
||||
gold.c.labels[i] = -1
|
||||
else:
|
||||
label = gold.labels[i]
|
||||
if label.upper() == 'ROOT':
|
||||
label = 'ROOT'
|
||||
gold.c.heads[i] = gold.heads[i]
|
||||
gold.c.labels[i] = self.strings[gold.labels[i]]
|
||||
gold.c.labels[i] = self.strings[label]
|
||||
for end, brackets in gold.brackets.items():
|
||||
for start, label_strs in brackets.items():
|
||||
gold.c.brackets[start][end] = 1
|
||||
|
@ -517,41 +362,43 @@ cdef class ArcEager(TransitionSystem):
|
|||
t.is_valid = Break.is_valid
|
||||
t.do = Break.transition
|
||||
t.get_cost = Break.cost
|
||||
elif move == CONSTITUENT:
|
||||
t.is_valid = Constituent.is_valid
|
||||
t.do = Constituent.transition
|
||||
t.get_cost = Constituent.cost
|
||||
elif move == ADJUST:
|
||||
t.is_valid = Adjust.is_valid
|
||||
t.do = Adjust.transition
|
||||
t.get_cost = Adjust.cost
|
||||
else:
|
||||
raise Exception(move)
|
||||
return t
|
||||
|
||||
cdef int initialize_state(self, State* state) except -1:
|
||||
push_stack(state)
|
||||
cdef int initialize_state(self, StateClass st) except -1:
|
||||
# Ensure sent_start is set to 0 throughout
|
||||
for i in range(st.length):
|
||||
st._sent[i].sent_start = False
|
||||
st._sent[i].l_edge = i
|
||||
st._sent[i].r_edge = i
|
||||
st.fast_forward()
|
||||
|
||||
cdef int finalize_state(self, State* state) except -1:
|
||||
cdef int finalize_state(self, StateClass st) except -1:
|
||||
cdef int root_label = self.strings['ROOT']
|
||||
for i in range(state.sent_len):
|
||||
if state.sent[i].head == 0 and state.sent[i].dep == 0:
|
||||
state.sent[i].dep = root_label
|
||||
for i in range(st.length):
|
||||
if st._sent[i].head == 0 and st._sent[i].dep == 0:
|
||||
st._sent[i].dep = root_label
|
||||
# If we're not using the Break transition, we segment via root-labelled
|
||||
# arcs between the root words.
|
||||
elif USE_ROOT_ARC_SEGMENT and st._sent[i].dep == root_label:
|
||||
st._sent[i].head = 0
|
||||
|
||||
cdef int set_valid(self, bint* output, const State* state) except -1:
|
||||
cdef int set_valid(self, bint* output, StateClass stcls) except -1:
|
||||
cdef bint[N_MOVES] is_valid
|
||||
is_valid[SHIFT] = Shift.is_valid(state, -1)
|
||||
is_valid[REDUCE] = Reduce.is_valid(state, -1)
|
||||
is_valid[LEFT] = LeftArc.is_valid(state, -1)
|
||||
is_valid[RIGHT] = RightArc.is_valid(state, -1)
|
||||
is_valid[BREAK] = Break.is_valid(state, -1)
|
||||
is_valid[CONSTITUENT] = Constituent.is_valid(state, -1)
|
||||
is_valid[ADJUST] = Adjust.is_valid(state, -1)
|
||||
is_valid[SHIFT] = Shift.is_valid(stcls, -1)
|
||||
is_valid[REDUCE] = Reduce.is_valid(stcls, -1)
|
||||
is_valid[LEFT] = LeftArc.is_valid(stcls, -1)
|
||||
is_valid[RIGHT] = RightArc.is_valid(stcls, -1)
|
||||
is_valid[BREAK] = Break.is_valid(stcls, -1)
|
||||
cdef int i
|
||||
n_valid = 0
|
||||
for i in range(self.n_moves):
|
||||
output[i] = is_valid[self.c[i].move]
|
||||
n_valid += output[i]
|
||||
assert n_valid >= 1
|
||||
|
||||
cdef int set_costs(self, int* output, const State* s, GoldParse gold) except -1:
|
||||
cdef int set_costs(self, int* output, StateClass stcls, GoldParse gold) except -1:
|
||||
cdef int i, move, label
|
||||
cdef label_cost_func_t[N_MOVES] label_cost_funcs
|
||||
cdef move_cost_func_t[N_MOVES] move_cost_funcs
|
||||
|
@ -563,35 +410,36 @@ cdef class ArcEager(TransitionSystem):
|
|||
move_cost_funcs[LEFT] = LeftArc.move_cost
|
||||
move_cost_funcs[RIGHT] = RightArc.move_cost
|
||||
move_cost_funcs[BREAK] = Break.move_cost
|
||||
move_cost_funcs[CONSTITUENT] = Constituent.move_cost
|
||||
move_cost_funcs[ADJUST] = Adjust.move_cost
|
||||
|
||||
label_cost_funcs[SHIFT] = Shift.label_cost
|
||||
label_cost_funcs[REDUCE] = Reduce.label_cost
|
||||
label_cost_funcs[LEFT] = LeftArc.label_cost
|
||||
label_cost_funcs[RIGHT] = RightArc.label_cost
|
||||
label_cost_funcs[BREAK] = Break.label_cost
|
||||
label_cost_funcs[CONSTITUENT] = Constituent.label_cost
|
||||
label_cost_funcs[ADJUST] = Adjust.label_cost
|
||||
|
||||
cdef int* labels = gold.c.labels
|
||||
cdef int* heads = gold.c.heads
|
||||
for i in range(self.n_moves):
|
||||
move = self.c[i].move
|
||||
label = self.c[i].label
|
||||
if move_costs[move] == -1:
|
||||
move_costs[move] = move_cost_funcs[move](s, &gold.c)
|
||||
output[i] = move_costs[move] + label_cost_funcs[move](s, &gold.c, label)
|
||||
|
||||
cdef Transition best_valid(self, const weight_t* scores, const State* s) except *:
|
||||
n_gold = 0
|
||||
for i in range(self.n_moves):
|
||||
if self.c[i].is_valid(stcls, self.c[i].label):
|
||||
move = self.c[i].move
|
||||
label = self.c[i].label
|
||||
if move_costs[move] == -1:
|
||||
move_costs[move] = move_cost_funcs[move](stcls, &gold.c)
|
||||
output[i] = move_costs[move] + label_cost_funcs[move](stcls, &gold.c, label)
|
||||
n_gold += output[i] == 0
|
||||
else:
|
||||
output[i] = 9000
|
||||
assert n_gold >= 1
|
||||
|
||||
cdef Transition best_valid(self, const weight_t* scores, StateClass stcls) except *:
|
||||
cdef bint[N_MOVES] is_valid
|
||||
is_valid[SHIFT] = Shift.is_valid(s, -1)
|
||||
is_valid[REDUCE] = Reduce.is_valid(s, -1)
|
||||
is_valid[LEFT] = LeftArc.is_valid(s, -1)
|
||||
is_valid[RIGHT] = RightArc.is_valid(s, -1)
|
||||
is_valid[BREAK] = Break.is_valid(s, -1)
|
||||
is_valid[CONSTITUENT] = Constituent.is_valid(s, -1)
|
||||
is_valid[ADJUST] = Adjust.is_valid(s, -1)
|
||||
is_valid[SHIFT] = Shift.is_valid(stcls, -1)
|
||||
is_valid[REDUCE] = Reduce.is_valid(stcls, -1)
|
||||
is_valid[LEFT] = LeftArc.is_valid(stcls, -1)
|
||||
is_valid[RIGHT] = RightArc.is_valid(stcls, -1)
|
||||
is_valid[BREAK] = Break.is_valid(stcls, -1)
|
||||
cdef Transition best
|
||||
cdef weight_t score = MIN_SCORE
|
||||
cdef int i
|
||||
|
@ -600,15 +448,5 @@ cdef class ArcEager(TransitionSystem):
|
|||
best = self.c[i]
|
||||
score = scores[i]
|
||||
assert best.clas < self.n_moves
|
||||
assert score > MIN_SCORE
|
||||
# Label Shift moves with the best Right-Arc label, for non-monotonic
|
||||
# actions
|
||||
if best.move == SHIFT:
|
||||
score = MIN_SCORE
|
||||
for i in range(self.n_moves):
|
||||
if self.c[i].move == RIGHT and scores[i] > score:
|
||||
best.label = self.c[i].label
|
||||
score = scores[i]
|
||||
assert score > MIN_SCORE, (stcls.stack_depth(), stcls.buffer_length(), stcls.is_final(), stcls._b_i, stcls.length)
|
||||
return best
|
||||
|
||||
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
from .transition_system cimport TransitionSystem
|
||||
from .transition_system cimport Transition
|
||||
from ._state cimport State
|
||||
from ..gold cimport GoldParseC
|
||||
|
||||
|
||||
|
|
|
@ -1,7 +1,5 @@
|
|||
from __future__ import unicode_literals
|
||||
|
||||
from ._state cimport State
|
||||
|
||||
from .transition_system cimport Transition
|
||||
from .transition_system cimport do_func_t
|
||||
|
||||
|
@ -11,6 +9,8 @@ from thinc.typedefs cimport weight_t
|
|||
from ..gold cimport GoldParseC
|
||||
from ..gold cimport GoldParse
|
||||
|
||||
from .stateclass cimport StateClass
|
||||
|
||||
|
||||
cdef enum:
|
||||
MISSING
|
||||
|
@ -34,18 +34,14 @@ MOVE_NAMES[OUT] = 'O'
|
|||
cdef do_func_t[N_MOVES] do_funcs
|
||||
|
||||
|
||||
cdef bint entity_is_open(const State *s) except -1:
|
||||
return s.ents_len >= 1 and s.ent.end == 0
|
||||
|
||||
|
||||
cdef bint _entity_is_sunk(const State *s, Transition* golds) except -1:
|
||||
if not entity_is_open(s):
|
||||
cdef bint _entity_is_sunk(StateClass st, Transition* golds) nogil:
|
||||
if not st.entity_is_open():
|
||||
return False
|
||||
|
||||
cdef const Transition* gold = &golds[s.ent.start]
|
||||
cdef const Transition* gold = &golds[st.E(0)]
|
||||
if gold.move != BEGIN and gold.move != UNIT:
|
||||
return True
|
||||
elif gold.label != s.ent.label:
|
||||
elif gold.label != st.E_(0).ent_type:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
@ -132,14 +128,14 @@ cdef class BiluoPushDown(TransitionSystem):
|
|||
raise Exception(move)
|
||||
return t
|
||||
|
||||
cdef Transition best_valid(self, const weight_t* scores, const State* s) except *:
|
||||
cdef Transition best_valid(self, const weight_t* scores, StateClass stcls) except *:
|
||||
cdef int best = -1
|
||||
cdef weight_t score = -90000
|
||||
cdef const Transition* m
|
||||
cdef int i
|
||||
for i in range(self.n_moves):
|
||||
m = &self.c[i]
|
||||
if m.is_valid(s, m.label) and scores[i] > score:
|
||||
if m.is_valid(stcls, m.label) and scores[i] > score:
|
||||
best = i
|
||||
score = scores[i]
|
||||
assert best >= 0
|
||||
|
@ -147,49 +143,43 @@ cdef class BiluoPushDown(TransitionSystem):
|
|||
t.score = score
|
||||
return t
|
||||
|
||||
cdef int set_valid(self, bint* output, const State* s) except -1:
|
||||
cdef int set_valid(self, bint* output, StateClass stcls) except -1:
|
||||
cdef int i
|
||||
for i in range(self.n_moves):
|
||||
m = &self.c[i]
|
||||
output[i] = m.is_valid(s, m.label)
|
||||
output[i] = m.is_valid(stcls, m.label)
|
||||
|
||||
|
||||
cdef class Missing:
|
||||
@staticmethod
|
||||
cdef bint is_valid(const State* s, int label) except -1:
|
||||
cdef bint is_valid(StateClass st, int label) nogil:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(State* s, int label) except -1:
|
||||
raise NotImplementedError
|
||||
cdef int transition(StateClass s, int label) nogil:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||
cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil:
|
||||
return 9000
|
||||
|
||||
|
||||
cdef class Begin:
|
||||
@staticmethod
|
||||
cdef bint is_valid(const State* s, int label) except -1:
|
||||
return label != 0 and not entity_is_open(s)
|
||||
cdef bint is_valid(StateClass st, int label) nogil:
|
||||
return label != 0 and not st.entity_is_open()
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(State* s, int label) except -1:
|
||||
s.ent += 1
|
||||
s.ents_len += 1
|
||||
s.ent.start = s.i
|
||||
s.ent.label = label
|
||||
s.ent.end = 0
|
||||
s.sent[s.i].ent_iob = 3
|
||||
s.sent[s.i].ent_type = label
|
||||
s.i += 1
|
||||
cdef int transition(StateClass st, int label) nogil:
|
||||
st.open_ent(label)
|
||||
st.set_ent_tag(st.B(0), 3, label)
|
||||
st.push()
|
||||
st.pop()
|
||||
|
||||
@staticmethod
|
||||
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||
if not Begin.is_valid(s, label):
|
||||
return 9000
|
||||
cdef int g_act = gold.ner[s.i].move
|
||||
cdef int g_tag = gold.ner[s.i].label
|
||||
cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil:
|
||||
cdef int g_act = gold.ner[s.B(0)].move
|
||||
cdef int g_tag = gold.ner[s.B(0)].label
|
||||
|
||||
if g_act == MISSING:
|
||||
return 0
|
||||
|
@ -203,25 +193,24 @@ cdef class Begin:
|
|||
# B, Gold U --> False (P)
|
||||
return 1
|
||||
|
||||
|
||||
cdef class In:
|
||||
@staticmethod
|
||||
cdef bint is_valid(const State* s, int label) except -1:
|
||||
return entity_is_open(s) and label != 0 and s.ent.label == label
|
||||
cdef bint is_valid(StateClass st, int label) nogil:
|
||||
return st.entity_is_open() and label != 0 and st.E_(0).ent_type == label
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(State* s, int label) except -1:
|
||||
s.sent[s.i].ent_iob = 1
|
||||
s.sent[s.i].ent_type = label
|
||||
s.i += 1
|
||||
cdef int transition(StateClass st, int label) nogil:
|
||||
st.set_ent_tag(st.B(0), 1, label)
|
||||
st.push()
|
||||
st.pop()
|
||||
|
||||
@staticmethod
|
||||
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||
if not In.is_valid(s, label):
|
||||
return 9000
|
||||
cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil:
|
||||
move = IN
|
||||
cdef int next_act = gold.ner[s.i+1].move if s.i < s.sent_len else OUT
|
||||
cdef int g_act = gold.ner[s.i].move
|
||||
cdef int g_tag = gold.ner[s.i].label
|
||||
cdef int next_act = gold.ner[s.B(1)].move if s.B(0) < s.length else OUT
|
||||
cdef int g_act = gold.ner[s.B(0)].move
|
||||
cdef int g_tag = gold.ner[s.B(0)].label
|
||||
cdef bint is_sunk = _entity_is_sunk(s, gold.ner)
|
||||
|
||||
if g_act == MISSING:
|
||||
|
@ -245,27 +234,23 @@ cdef class In:
|
|||
return 1
|
||||
|
||||
|
||||
|
||||
cdef class Last:
|
||||
@staticmethod
|
||||
cdef bint is_valid(const State* s, int label) except -1:
|
||||
return entity_is_open(s) and label != 0 and s.ent.label == label
|
||||
cdef bint is_valid(StateClass st, int label) nogil:
|
||||
return st.entity_is_open() and label != 0 and st.E_(0).ent_type == label
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(State* s, int label) except -1:
|
||||
s.ent.end = s.i+1
|
||||
s.sent[s.i].ent_iob = 1
|
||||
s.sent[s.i].ent_type = label
|
||||
s.i += 1
|
||||
cdef int transition(StateClass st, int label) nogil:
|
||||
st.close_ent()
|
||||
st.push()
|
||||
st.pop()
|
||||
|
||||
@staticmethod
|
||||
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||
if not Last.is_valid(s, label):
|
||||
return 9000
|
||||
cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil:
|
||||
move = LAST
|
||||
|
||||
cdef int g_act = gold.ner[s.i].move
|
||||
cdef int g_tag = gold.ner[s.i].label
|
||||
cdef int g_act = gold.ner[s.B(0)].move
|
||||
cdef int g_tag = gold.ner[s.B(0)].label
|
||||
|
||||
if g_act == MISSING:
|
||||
return 0
|
||||
|
@ -290,26 +275,21 @@ cdef class Last:
|
|||
|
||||
cdef class Unit:
|
||||
@staticmethod
|
||||
cdef bint is_valid(const State* s, int label) except -1:
|
||||
return label != 0 and not entity_is_open(s)
|
||||
cdef bint is_valid(StateClass st, int label) nogil:
|
||||
return label != 0 and not st.entity_is_open()
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(State* s, int label) except -1:
|
||||
s.ent += 1
|
||||
s.ents_len += 1
|
||||
s.ent.start = s.i
|
||||
s.ent.label = label
|
||||
s.ent.end = s.i+1
|
||||
s.sent[s.i].ent_iob = 3
|
||||
s.sent[s.i].ent_type = label
|
||||
s.i += 1
|
||||
cdef int transition(StateClass st, int label) nogil:
|
||||
st.open_ent(label)
|
||||
st.close_ent()
|
||||
st.set_ent_tag(st.B(0), 3, label)
|
||||
st.push()
|
||||
st.pop()
|
||||
|
||||
@staticmethod
|
||||
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||
if not Unit.is_valid(s, label):
|
||||
return 9000
|
||||
cdef int g_act = gold.ner[s.i].move
|
||||
cdef int g_tag = gold.ner[s.i].label
|
||||
cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil:
|
||||
cdef int g_act = gold.ner[s.B(0)].move
|
||||
cdef int g_tag = gold.ner[s.B(0)].label
|
||||
|
||||
if g_act == MISSING:
|
||||
return 0
|
||||
|
@ -326,22 +306,19 @@ cdef class Unit:
|
|||
|
||||
cdef class Out:
|
||||
@staticmethod
|
||||
cdef bint is_valid(const State* s, int label) except -1:
|
||||
return not entity_is_open(s)
|
||||
cdef bint is_valid(StateClass st, int label) nogil:
|
||||
return not st.entity_is_open()
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(State* s, int label) except -1:
|
||||
s.sent[s.i].ent_iob = 2
|
||||
s.i += 1
|
||||
cdef int transition(StateClass st, int label) nogil:
|
||||
st.set_ent_tag(st.B(0), 2, 0)
|
||||
st.push()
|
||||
st.pop()
|
||||
|
||||
@staticmethod
|
||||
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||
if not Out.is_valid(s, label):
|
||||
return 9000
|
||||
|
||||
cdef int g_act = gold.ner[s.i].move
|
||||
cdef int g_tag = gold.ner[s.i].label
|
||||
|
||||
cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil:
|
||||
cdef int g_act = gold.ner[s.B(0)].move
|
||||
cdef int g_tag = gold.ner[s.B(0)].label
|
||||
|
||||
if g_act == MISSING:
|
||||
return 0
|
||||
|
|
|
@ -5,8 +5,6 @@ from .._ml cimport Model
|
|||
from .arc_eager cimport TransitionSystem
|
||||
|
||||
from ..tokens cimport Tokens, TokenC
|
||||
from ._state cimport State
|
||||
|
||||
|
||||
|
||||
cdef class Parser:
|
||||
|
|
|
@ -1,9 +1,13 @@
|
|||
# cython: profile=True
|
||||
# cython: experimental_cpp_class_def=True
|
||||
"""
|
||||
MALT-style dependency parser
|
||||
"""
|
||||
from __future__ import unicode_literals
|
||||
cimport cython
|
||||
|
||||
from cpython.ref cimport PyObject, Py_INCREF, Py_XDECREF
|
||||
|
||||
from libc.stdint cimport uint32_t, uint64_t
|
||||
from libc.string cimport memset, memcpy
|
||||
import random
|
||||
|
@ -14,7 +18,7 @@ import json
|
|||
|
||||
from cymem.cymem cimport Pool, Address
|
||||
from murmurhash.mrmr cimport hash64
|
||||
from thinc.typedefs cimport weight_t, class_t, feat_t, atom_t
|
||||
from thinc.typedefs cimport weight_t, class_t, feat_t, atom_t, hash_t
|
||||
|
||||
|
||||
from util import Config
|
||||
|
@ -31,14 +35,16 @@ from thinc.search cimport MaxViolation
|
|||
from ..tokens cimport Tokens, TokenC
|
||||
from ..strings cimport StringStore
|
||||
|
||||
from .arc_eager cimport TransitionSystem, Transition
|
||||
from .transition_system import OracleError
|
||||
|
||||
from ._state cimport State, new_state, copy_state, is_final, push_stack
|
||||
from .transition_system import OracleError
|
||||
from .transition_system cimport TransitionSystem, Transition
|
||||
|
||||
from ..gold cimport GoldParse
|
||||
|
||||
from . import _parse_features
|
||||
from ._parse_features cimport fill_context, CONTEXT_SIZE
|
||||
from ._parse_features cimport CONTEXT_SIZE
|
||||
from ._parse_features cimport fill_context
|
||||
from .stateclass cimport StateClass
|
||||
|
||||
|
||||
DEBUG = False
|
||||
|
@ -47,20 +53,6 @@ def set_debug(val):
|
|||
DEBUG = val
|
||||
|
||||
|
||||
cdef unicode print_state(State* s, list words):
|
||||
words = list(words) + ['EOL']
|
||||
top = words[s.stack[0]] + '_%d' % s.sent[s.stack[0]].head
|
||||
second = words[s.stack[-1]] + '_%d' % s.sent[s.stack[-1]].head
|
||||
third = words[s.stack[-2]] + '_%d' % s.sent[s.stack[-2]].head
|
||||
n0 = words[s.i] if s.i < len(words) else 'EOL'
|
||||
n1 = words[s.i + 1] if s.i+1 < len(words) else 'EOL'
|
||||
if s.ents_len:
|
||||
ent = '%s %d-%d' % (s.ent.label, s.ent.start, s.ent.end)
|
||||
else:
|
||||
ent = '-'
|
||||
return ' '.join((ent, str(s.stack_len), third, second, top, '|', n0, n1))
|
||||
|
||||
|
||||
def get_templates(name):
|
||||
pf = _parse_features
|
||||
if name == 'ner':
|
||||
|
@ -68,7 +60,7 @@ def get_templates(name):
|
|||
elif name == 'debug':
|
||||
return pf.unigrams
|
||||
else:
|
||||
return (pf.unigrams + pf.s0_n0 + pf.s1_n0 + pf.s0_n1 + pf.n0_n1 + \
|
||||
return (pf.unigrams + pf.s0_n0 + pf.s1_n0 + pf.s1_s0 + pf.s0_n1 + pf.n0_n1 + \
|
||||
pf.tree_shape + pf.trigrams)
|
||||
|
||||
|
||||
|
@ -81,16 +73,14 @@ cdef class Parser:
|
|||
self.model = Model(self.moves.n_moves, templates, model_dir)
|
||||
|
||||
def __call__(self, Tokens tokens):
|
||||
if tokens.length == 0:
|
||||
return 0
|
||||
if self.cfg.get('beam_width', 1) <= 1:
|
||||
if self.cfg.get('beam_width', 1) < 1:
|
||||
self._greedy_parse(tokens)
|
||||
else:
|
||||
self._beam_parse(tokens)
|
||||
|
||||
def train(self, Tokens tokens, GoldParse gold):
|
||||
self.moves.preprocess_gold(gold)
|
||||
if self.cfg.beam_width <= 1:
|
||||
if self.cfg.beam_width < 1:
|
||||
return self._greedy_train(tokens, gold)
|
||||
else:
|
||||
return self._beam_train(tokens, gold)
|
||||
|
@ -99,31 +89,36 @@ cdef class Parser:
|
|||
cdef atom_t[CONTEXT_SIZE] context
|
||||
cdef int n_feats
|
||||
cdef Pool mem = Pool()
|
||||
cdef State* state = new_state(mem, tokens.data, tokens.length)
|
||||
self.moves.initialize_state(state)
|
||||
cdef StateClass stcls = StateClass.init(tokens.data, tokens.length)
|
||||
self.moves.initialize_state(stcls)
|
||||
cdef Transition guess
|
||||
while not is_final(state):
|
||||
fill_context(context, state)
|
||||
words = [w.orth_ for w in tokens]
|
||||
while not stcls.is_final():
|
||||
fill_context(context, stcls)
|
||||
scores = self.model.score(context)
|
||||
guess = self.moves.best_valid(scores, state)
|
||||
guess.do(state, guess.label)
|
||||
self.moves.finalize_state(state)
|
||||
tokens.set_parse(state.sent)
|
||||
guess = self.moves.best_valid(scores, stcls)
|
||||
#print self.moves.move_name(guess.move, guess.label), stcls.print_state(words)
|
||||
guess.do(stcls, guess.label)
|
||||
assert stcls._s_i >= 0
|
||||
self.moves.finalize_state(stcls)
|
||||
tokens.set_parse(stcls._sent)
|
||||
|
||||
cdef int _beam_parse(self, Tokens tokens) except -1:
|
||||
cdef Beam beam = Beam(self.moves.n_moves, self.cfg.beam_width)
|
||||
words = [w.orth_ for w in tokens]
|
||||
beam.initialize(_init_state, tokens.length, tokens.data)
|
||||
beam.check_done(_check_final_state, NULL)
|
||||
while not beam.is_done:
|
||||
self._advance_beam(beam, None, False)
|
||||
state = <State*>beam.at(0)
|
||||
self._advance_beam(beam, None, False, words)
|
||||
state = <StateClass>beam.at(0)
|
||||
self.moves.finalize_state(state)
|
||||
tokens.set_parse(state.sent)
|
||||
tokens.set_parse(state._sent)
|
||||
_cleanup(beam)
|
||||
|
||||
def _greedy_train(self, Tokens tokens, GoldParse gold):
|
||||
cdef Pool mem = Pool()
|
||||
cdef State* state = new_state(mem, tokens.data, tokens.length)
|
||||
self.moves.initialize_state(state)
|
||||
cdef StateClass stcls = StateClass.init(tokens.data, tokens.length)
|
||||
self.moves.initialize_state(stcls)
|
||||
|
||||
cdef int cost
|
||||
cdef const Feature* feats
|
||||
|
@ -132,14 +127,16 @@ cdef class Parser:
|
|||
cdef Transition best
|
||||
cdef atom_t[CONTEXT_SIZE] context
|
||||
loss = 0
|
||||
while not is_final(state):
|
||||
fill_context(context, state)
|
||||
words = [w.orth_ for w in tokens]
|
||||
history = []
|
||||
while not stcls.is_final():
|
||||
fill_context(context, stcls)
|
||||
scores = self.model.score(context)
|
||||
guess = self.moves.best_valid(scores, state)
|
||||
best = self.moves.best_gold(scores, state, gold)
|
||||
cost = guess.get_cost(state, &gold.c, guess.label)
|
||||
guess = self.moves.best_valid(scores, stcls)
|
||||
best = self.moves.best_gold(scores, stcls, gold)
|
||||
cost = guess.get_cost(stcls, &gold.c, guess.label)
|
||||
self.model.update(context, guess.clas, best.clas, cost)
|
||||
guess.do(state, guess.label)
|
||||
guess.do(stcls, guess.label)
|
||||
loss += cost
|
||||
return loss
|
||||
|
||||
|
@ -152,9 +149,10 @@ cdef class Parser:
|
|||
gold.check_done(_check_final_state, NULL)
|
||||
|
||||
violn = MaxViolation()
|
||||
words = [w.orth_ for w in tokens]
|
||||
while not pred.is_done and not gold.is_done:
|
||||
self._advance_beam(pred, gold_parse, False)
|
||||
self._advance_beam(gold, gold_parse, True)
|
||||
self._advance_beam(pred, gold_parse, False, words)
|
||||
self._advance_beam(gold, gold_parse, True, words)
|
||||
violn.check(pred, gold)
|
||||
if pred.loss >= 1:
|
||||
counts = {clas: {} for clas in range(self.model.n_classes)}
|
||||
|
@ -163,62 +161,90 @@ cdef class Parser:
|
|||
else:
|
||||
counts = {}
|
||||
self.model._model.update(counts)
|
||||
_cleanup(pred)
|
||||
_cleanup(gold)
|
||||
return pred.loss
|
||||
|
||||
def _advance_beam(self, Beam beam, GoldParse gold, bint follow_gold):
|
||||
def _advance_beam(self, Beam beam, GoldParse gold, bint follow_gold, words):
|
||||
cdef atom_t[CONTEXT_SIZE] context
|
||||
cdef State* state
|
||||
cdef int i, j, cost
|
||||
cdef bint is_valid
|
||||
cdef const Transition* move
|
||||
for i in range(beam.size):
|
||||
state = <State*>beam.at(i)
|
||||
if not is_final(state):
|
||||
fill_context(context, state)
|
||||
stcls = <StateClass>beam.at(i)
|
||||
if not stcls.is_final():
|
||||
fill_context(context, stcls)
|
||||
self.model.set_scores(beam.scores[i], context)
|
||||
self.moves.set_valid(beam.is_valid[i], state)
|
||||
|
||||
self.moves.set_valid(beam.is_valid[i], stcls)
|
||||
if gold is not None:
|
||||
for i in range(beam.size):
|
||||
state = <State*>beam.at(i)
|
||||
self.moves.set_costs(beam.costs[i], state, gold)
|
||||
if follow_gold:
|
||||
for j in range(self.moves.n_moves):
|
||||
beam.is_valid[i][j] *= beam.costs[i][j] == 0
|
||||
beam.advance(_transition_state, <void*>self.moves.c)
|
||||
state = <State*>beam.at(0)
|
||||
stcls = <StateClass>beam.at(i)
|
||||
if not stcls.is_final():
|
||||
self.moves.set_costs(beam.costs[i], stcls, gold)
|
||||
if follow_gold:
|
||||
for j in range(self.moves.n_moves):
|
||||
beam.is_valid[i][j] *= beam.costs[i][j] == 0
|
||||
beam.advance(_transition_state, _hash_state, <void*>self.moves.c)
|
||||
beam.check_done(_check_final_state, NULL)
|
||||
|
||||
def _count_feats(self, dict counts, Tokens tokens, list hist, int inc):
|
||||
cdef atom_t[CONTEXT_SIZE] context
|
||||
cdef Pool mem = Pool()
|
||||
cdef State* state = new_state(mem, tokens.data, tokens.length)
|
||||
self.moves.initialize_state(state)
|
||||
cdef StateClass stcls = StateClass.init(tokens.data, tokens.length)
|
||||
self.moves.initialize_state(stcls)
|
||||
|
||||
cdef class_t clas
|
||||
cdef int n_feats
|
||||
for clas in hist:
|
||||
fill_context(context, state)
|
||||
fill_context(context, stcls)
|
||||
feats = self.model._extractor.get_feats(context, &n_feats)
|
||||
count_feats(counts[clas], feats, n_feats, inc)
|
||||
self.moves.c[clas].do(state, self.moves.c[clas].label)
|
||||
self.moves.c[clas].do(stcls, self.moves.c[clas].label)
|
||||
|
||||
|
||||
# These are passed as callbacks to thinc.search.Beam
|
||||
|
||||
cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1:
|
||||
dest = <State*>_dest
|
||||
src = <const State*>_src
|
||||
dest = <StateClass>_dest
|
||||
src = <StateClass>_src
|
||||
moves = <const Transition*>_moves
|
||||
copy_state(dest, src)
|
||||
dest.clone(src)
|
||||
moves[clas].do(dest, moves[clas].label)
|
||||
|
||||
|
||||
cdef void* _init_state(Pool mem, int length, void* tokens) except NULL:
|
||||
state = new_state(mem, <const TokenC*>tokens, length)
|
||||
push_stack(state)
|
||||
return state
|
||||
cdef StateClass st = StateClass.init(<const TokenC*>tokens, length)
|
||||
st.fast_forward()
|
||||
Py_INCREF(st)
|
||||
return <void*>st
|
||||
|
||||
|
||||
cdef int _check_final_state(void* state, void* extra_args) except -1:
|
||||
return is_final(<State*>state)
|
||||
cdef int _check_final_state(void* _state, void* extra_args) except -1:
|
||||
return (<StateClass>_state).is_final()
|
||||
|
||||
|
||||
def _cleanup(Beam beam):
    """Drop one Python reference for every state and parent slot of ``beam``.

    Walks all ``beam.width`` slots and calls ``Py_XDECREF`` on the ``content``
    pointer of both ``beam._states`` and ``beam._parents``.  ``Py_XDECREF`` is
    a no-op for NULL pointers, so unused slots are harmless.
    """
    # NOTE(review): presumably balances the Py_INCREF done when a state is
    # created for the beam -- confirm against the state-initialisation callback.
    for slot in range(beam.width):
        Py_XDECREF(<PyObject*>beam._states[slot].content)
        Py_XDECREF(<PyObject*>beam._parents[slot].content)
|
||||
|
||||
cdef hash_t _hash_state(void* _state, void* _) except 0:
    # Hash callback handed to the beam: the key is simply the address of the
    # state object, so two states hash alike only if they are the same object.
    # The second argument is unused.
    cdef hash_t key = <hash_t>_state
    return key

    # Disabled alternative kept for reference: a structural hash built from
    # the stack/buffer configuration. It is written against the old State
    # struct and is not executed.
    #
    #state = <const State*>_state
    #cdef atom_t[10] rep
    #
    #rep[0] = state.stack[0] if state.stack_len >= 1 else 0
    #rep[1] = state.stack[-1] if state.stack_len >= 2 else 0
    #rep[2] = state.stack[-2] if state.stack_len >= 3 else 0
    #rep[3] = state.i
    #rep[4] = state.sent[state.stack[0]].l_kids if state.stack_len >= 1 else 0
    #rep[5] = state.sent[state.stack[0]].r_kids if state.stack_len >= 1 else 0
    #rep[6] = state.sent[state.stack[0]].dep if state.stack_len >= 1 else 0
    #rep[7] = state.sent[state.stack[-1]].dep if state.stack_len >= 2 else 0
    #if get_left(state, get_n0(state), 1) != NULL:
    #    rep[8] = get_left(state, get_n0(state), 1).dep
    #else:
    #    rep[8] = 0
    #rep[9] = state.sent[state.i].l_kids
    #return hash64(rep, sizeof(atom_t) * 10, 0)
|
||||
|
|
104
spacy/syntax/stateclass.pxd
Normal file
104
spacy/syntax/stateclass.pxd
Normal file
|
@ -0,0 +1,104 @@
|
|||
from libc.string cimport memcpy, memset
|
||||
|
||||
from cymem.cymem cimport Pool
|
||||
|
||||
from ..structs cimport TokenC, Entity
|
||||
|
||||
from ..vocab cimport EMPTY_LEXEME
|
||||
|
||||
|
||||
cdef class StateClass:
    # Declarations for the parser state object: a stack and a buffer of token
    # indices over a private copy of the sentence, plus entity bookkeeping.

    # Pool that owns every array below.
    cdef Pool mem
    # Token indices on the stack / in the buffer.
    cdef int* _stack
    cdef int* _buffer
    # Per-token flag, set when a token is moved back from stack to buffer.
    cdef bint* shifted
    # Working copy of the sentence tokens and the entity records.
    cdef TokenC* _sent
    cdef Entity* _ents
    # Placeholder token handed out for out-of-range lookups.
    cdef TokenC _empty_token
    # Sentence length, and the fill levels of stack, buffer and entity list.
    cdef int length
    cdef int _s_i
    cdef int _b_i
    cdef int _e_i
    # Buffer position of a pending sentence break, or -1 when there is none.
    cdef int _break

    @staticmethod
    cdef inline StateClass init(const TokenC* sent, int length):
        # Build a fresh state for ``length`` tokens and copy ``sent`` into it.
        cdef StateClass state = StateClass(length)
        cdef int tok
        # Copy the tokens first, then lay the buffer out in sentence order.
        for tok in range(length):
            state._sent[tok] = sent[tok]
        for tok in range(length):
            state._buffer[tok] = tok
        # The five slots past the end act as padding and get the empty lexeme.
        for tok in range(length, length + 5):
            state._sent[tok].lex = &EMPTY_LEXEME
        return state

    cdef inline int S(self, int i) nogil:
        # Token index ``i`` positions below the top of the stack (0 is the
        # top), or -1 when the stack holds ``i`` items or fewer.
        if i < self._s_i:
            return self._stack[self._s_i - 1 - i]
        return -1

    cdef inline int B(self, int i) nogil:
        # Token index ``i`` positions into the buffer (0 is the front), or -1
        # when that position is at or past the end of the sentence.
        cdef int pos = i + self._b_i
        if pos < self.length:
            return self._buffer[pos]
        return -1

    # --- Index lookups: head, entity start, i-th left/right child ---
    cdef int H(self, int i) nogil
    cdef int E(self, int i) nogil

    cdef int L(self, int i, int idx) nogil
    cdef int R(self, int i, int idx) nogil

    # --- The same lookups, returning a token pointer instead of an index ---
    cdef const TokenC* S_(self, int i) nogil
    cdef const TokenC* B_(self, int i) nogil

    cdef const TokenC* H_(self, int i) nogil
    cdef const TokenC* E_(self, int i) nogil

    cdef const TokenC* L_(self, int i, int idx) nogil
    cdef const TokenC* R_(self, int i, int idx) nogil

    cdef const TokenC* safe_get(self, int i) nogil

    # --- State predicates and counters ---
    cdef bint empty(self) nogil

    cdef bint entity_is_open(self) nogil

    cdef bint eol(self) nogil

    cdef bint at_break(self) nogil

    cdef bint is_final(self) nogil

    cdef bint has_head(self, int i) nogil

    cdef int n_L(self, int i) nogil

    cdef int n_R(self, int i) nogil

    cdef bint stack_is_connected(self) nogil

    cdef int stack_depth(self) nogil

    cdef int buffer_length(self) nogil

    # --- Mutators used by the transition systems ---
    cdef void push(self) nogil

    cdef void pop(self) nogil

    cdef void unshift(self) nogil

    cdef void add_arc(self, int head, int child, int label) nogil

    cdef void del_arc(self, int head, int child) nogil

    cdef void open_ent(self, int label) nogil

    cdef void close_ent(self) nogil

    cdef void set_ent_tag(self, int i, int ent_iob, int ent_type) nogil

    cdef void set_break(self, int i) nogil

    cdef void clone(self, StateClass src) nogil

    cdef void fast_forward(self) nogil
|
253
spacy/syntax/stateclass.pyx
Normal file
253
spacy/syntax/stateclass.pyx
Normal file
|
@ -0,0 +1,253 @@
|
|||
from libc.string cimport memcpy, memset
|
||||
from libc.stdint cimport uint32_t
|
||||
from ..vocab cimport EMPTY_LEXEME
|
||||
from ..structs cimport Entity
|
||||
|
||||
|
||||
cdef class StateClass:
    """Mutable parser state over a private copy of one sentence.

    Holds a stack and a buffer of token indices, the working ``TokenC`` array
    that arcs and entity tags are written into, and a list of entity records.
    All arrays are allocated from the instance's own ``Pool`` with a few slots
    of padding past ``length``.
    """
    def __init__(self, int length):
        """Allocate storage for ``length`` tokens and reset all counters.

        The buffer is initialised to ``0 .. length-1``; the stack and entity
        list start empty and no sentence break is pending.
        """
        cdef Pool mem = Pool()
        PADDING = 5
        self._buffer = <int*>mem.alloc(length + PADDING, sizeof(int))
        self._stack = <int*>mem.alloc(length + PADDING, sizeof(int))
        self.shifted = <bint*>mem.alloc(length + PADDING, sizeof(bint))
        self._sent = <TokenC*>mem.alloc(length + PADDING, sizeof(TokenC))
        self._ents = <Entity*>mem.alloc(length + PADDING, sizeof(Entity))
        cdef int i
        # end == -1 marks an entity record as "not closed".
        for i in range(length):
            self._ents[i].end = -1
        # Padding tokens point at the shared empty lexeme.
        for i in range(length, length + PADDING):
            self._sent[i].lex = &EMPTY_LEXEME
        self.mem = mem
        self.length = length
        self._break = -1
        self._s_i = 0
        self._b_i = 0
        self._e_i = 0
        for i in range(length):
            self._buffer[i] = i
        self._empty_token.lex = &EMPTY_LEXEME

    cdef int H(self, int i) nogil:
        # Absolute index of token i's head, or -1 if i is out of range.
        # Heads are stored as offsets relative to the child, so a token whose
        # offset is 0 (no head assigned) yields i itself.
        if i < 0 or i >= self.length:
            return -1
        return self._sent[i].head + i

    cdef int E(self, int i) nogil:
        # Start index of the most recently opened entity. Returns 0 when no
        # entity has been opened, when the entity count has reached the
        # sentence length, or when i is out of range. Note that i is only
        # range-checked; it does not select which entity is consulted.
        if self._e_i <= 0 or self._e_i >= self.length:
            return 0
        if i < 0 or i >= self.length:
            return 0
        return self._ents[self._e_i-1].start

    cdef int L(self, int i, int idx) nogil:
        # Index of the idx-th left child of token i, counting from the
        # leftmost child (idx == 1). Returns -1 if idx < 1, i is out of range,
        # or token i has fewer than idx left children.
        if idx < 1:
            return -1
        if i < 0 or i >= self.length:
            return -1
        cdef const TokenC* target = &self._sent[i]
        cdef const TokenC* ptr = self._sent

        while ptr < target:
            # If ptr's head lies to the right of ptr but still before the
            # target, jump straight to that head: no token between ptr and
            # its head could be a child of the target (this relies on arcs
            # not crossing).
            if (ptr.head >= 1) and (ptr + ptr.head) < target:
                ptr += ptr.head

            elif ptr + ptr.head == target:
                idx -= 1
                if idx == 0:
                    return ptr - self._sent
                ptr += 1
            else:
                ptr += 1
        return -1

    cdef int R(self, int i, int idx) nogil:
        # Index of the idx-th right child of token i, counting from the
        # rightmost child (idx == 1). Returns -1 if idx < 1, i is out of
        # range, or token i has fewer than idx right children.
        if idx < 1:
            return -1
        if i < 0 or i >= self.length:
            return -1
        cdef const TokenC* ptr = self._sent + (self.length - 1)
        cdef const TokenC* target = &self._sent[i]
        while ptr > target:
            # If ptr's head lies to the left of ptr but still after the
            # target, jump straight to that head: no token between ptr and
            # its head could be a child of the target (this relies on arcs
            # not crossing).
            if (ptr.head < 0) and ((ptr + ptr.head) > target):
                ptr += ptr.head
            elif ptr + ptr.head == target:
                idx -= 1
                if idx == 0:
                    return ptr - self._sent
                ptr -= 1
            else:
                ptr -= 1
        return -1

    # Pointer-returning variants of S/B/H/E/L/R. Out-of-range results are
    # mapped to the empty token by safe_get, so callers never receive NULL.
    cdef const TokenC* S_(self, int i) nogil:
        return self.safe_get(self.S(i))

    cdef const TokenC* B_(self, int i) nogil:
        return self.safe_get(self.B(i))

    cdef const TokenC* H_(self, int i) nogil:
        return self.safe_get(self.H(i))

    cdef const TokenC* E_(self, int i) nogil:
        return self.safe_get(self.E(i))

    cdef const TokenC* L_(self, int i, int idx) nogil:
        return self.safe_get(self.L(i, idx))

    cdef const TokenC* R_(self, int i, int idx) nogil:
        return self.safe_get(self.R(i, idx))

    cdef const TokenC* safe_get(self, int i) nogil:
        # Token i of the working sentence, or the empty token if i is outside
        # [0, length).
        if i < 0 or i >= self.length:
            return &self._empty_token
        else:
            return &self._sent[i]

    cdef bint empty(self) nogil:
        # True when the stack holds no tokens.
        return self._s_i <= 0

    cdef bint eol(self) nogil:
        # True when the buffer (up to a pending break, if any) is exhausted.
        return self.buffer_length() == 0

    cdef bint at_break(self) nogil:
        # True while a sentence break is pending.
        return self._break != -1

    cdef bint is_final(self) nogil:
        # True once the stack is empty and the whole sentence was consumed.
        return self.stack_depth() <= 0 and self._b_i >= self.length

    cdef bint has_head(self, int i) nogil:
        # A head offset of 0 means "no head assigned".
        return self.safe_get(i).head != 0

    cdef int n_L(self, int i) nogil:
        # Number of left children recorded for token i.
        return self.safe_get(i).l_kids

    cdef int n_R(self, int i) nogil:
        # Number of right children recorded for token i.
        return self.safe_get(i).r_kids

    cdef bint stack_is_connected(self) nogil:
        # Placeholder: always reports False.
        return False

    cdef bint entity_is_open(self) nogil:
        # True if the most recent entity has been opened but not closed.
        if self._e_i < 1:
            return False
        return self._ents[self._e_i-1].end == -1

    cdef int stack_depth(self) nogil:
        return self._s_i

    cdef int buffer_length(self) nogil:
        # Tokens left before the pending break if there is one, otherwise
        # tokens left before the end of the sentence.
        if self._break != -1:
            return self._break - self._b_i
        else:
            return self.length - self._b_i

    cdef void push(self) nogil:
        # Move the first buffer token onto the stack (if there is one) and
        # advance the buffer position.
        if self.B(0) != -1:
            self._stack[self._s_i] = self.B(0)
        self._s_i += 1
        self._b_i += 1
        # Once the buffer position has moved past the break, clear it.
        if self._b_i > self._break:
            self._break = -1

    cdef void pop(self) nogil:
        # Discard the top of the stack; a no-op on an empty stack.
        if self._s_i >= 1:
            self._s_i -= 1

    cdef void unshift(self) nogil:
        # Move the top of the stack back to the front of the buffer and
        # remember that this token has been unshifted.
        self._b_i -= 1
        self._buffer[self._b_i] = self.S(0)
        self._s_i -= 1
        self.shifted[self.B(0)] = True

    cdef void fast_forward(self) nogil:
        # Apply the moves that are forced whenever either the stack or the
        # buffer is empty, until a state with a real choice (or no applicable
        # rule) is reached.
        while self.buffer_length() == 0 or self.stack_depth() == 0:
            if self.buffer_length() == 1 and self.stack_depth() == 0:
                self.push()
                self.pop()
            elif self.buffer_length() == 0 and self.stack_depth() == 1:
                self.pop()
            elif self.buffer_length() == 0 and self.stack_depth() >= 2:
                if self.has_head(self.S(0)):
                    self.pop()
                else:
                    self.unshift()
            elif (self.length - self._b_i) >= 1 and self.stack_depth() == 0:
                self.push()
            else:
                break

    cdef void add_arc(self, int head, int child, int label) nogil:
        # Attach ``child`` to ``head`` with dependency ``label``. Any arc the
        # child already has is removed first; child counts and subtree edges
        # of the head are then updated.
        if self.has_head(child):
            self.del_arc(self.H(child), child)

        cdef int dist = head - child
        self._sent[child].head = dist
        self._sent[child].dep = label
        cdef int i
        if child > head:
            self._sent[head].r_kids += 1
            self._sent[head].r_edge = child
            i = 0
            # Propagate the new right edge up through the ancestors.
            while self.has_head(head) and i < self.length:
                self._sent[head].r_edge = child
                head = self.H(head)
                i += 1 # Guard against infinite loops
        else:
            self._sent[head].l_kids += 1
            self._sent[head].l_edge = self._sent[child].l_edge

    cdef void del_arc(self, int h_i, int c_i) nogil:
        # Remove the arc h_i -> c_i from the head's bookkeeping: decrement the
        # child count on the relevant side and, if c_i was the outermost child
        # on that side, pull the head's subtree edge back in.
        #
        # The child's own head offset is not touched here, so L()/R() still
        # see c_i as a child of h_i. They count children from the outside in,
        # starting at 1, which means the child that becomes outermost after
        # the removal is currently number 2. The lookup must therefore happen
        # before the count is decremented.
        cdef TokenC* h = &self._sent[h_i]
        if c_i > h_i:
            if self.R(h_i, 1) == c_i:
                h.r_edge = self.R_(h_i, 2).r_edge if h.r_kids >= 2 else h_i
            h.r_kids -= 1
        else:
            if self.L(h_i, 1) == c_i:
                h.l_edge = self.L_(h_i, 2).l_edge if h.l_kids >= 2 else h_i
            h.l_kids -= 1

    cdef void open_ent(self, int label) nogil:
        # Start a new entity at the first buffer token; end == -1 marks it
        # as still open.
        self._ents[self._e_i].start = self.B(0)
        self._ents[self._e_i].label = label
        self._ents[self._e_i].end = -1
        self._e_i += 1

    cdef void close_ent(self) nogil:
        # Close the most recent entity so that it includes the first buffer
        # token (end is exclusive).
        self._ents[self._e_i-1].end = self.B(0)+1
        self._sent[self.B(0)].ent_iob = 1

    cdef void set_ent_tag(self, int i, int ent_iob, int ent_type) nogil:
        # Write the entity tag of token i; silently ignores out-of-range i.
        if 0 <= i < self.length:
            self._sent[i].ent_iob = ent_iob
            self._sent[i].ent_type = ent_type

    cdef void set_break(self, int _) nogil:
        # Mark the first buffer token as a sentence start and record the
        # current buffer position as the pending break. The argument is
        # unused.
        if 0 <= self.B(0) < self.length:
            self._sent[self.B(0)].sent_start = True
            self._break = self._b_i

    cdef void clone(self, StateClass src) nogil:
        # Overwrite this state with a copy of ``src``. Only ``self.length``
        # entries of each array are copied and ``length`` itself is not
        # updated, so both states are expected to have been created for the
        # same sentence length.
        memcpy(self._sent, src._sent, self.length * sizeof(TokenC))
        memcpy(self._stack, src._stack, self.length * sizeof(int))
        memcpy(self._buffer, src._buffer, self.length * sizeof(int))
        memcpy(self._ents, src._ents, self.length * sizeof(Entity))
        # The unshift markers are part of the state too; without this copy a
        # cloned state forgets which tokens were already moved back.
        memcpy(self.shifted, src.shifted, self.length * sizeof(bint))
        self._b_i = src._b_i
        self._s_i = src._s_i
        self._e_i = src._e_i
        self._break = src._break

    def print_state(self, words):
        """Return a one-line debug view: top three stack words, '|', next two buffer words.

        Stack words are suffixed with their stored (relative) head offset.
        A '_' is appended to ``words`` so that the -1 returned for a missing
        stack or buffer position indexes a placeholder instead of a real word.
        """
        words = list(words) + ['_']
        top = words[self.S(0)] + '_%d' % self.S_(0).head
        second = words[self.S(1)] + '_%d' % self.S_(1).head
        third = words[self.S(2)] + '_%d' % self.S_(2).head
        n0 = words[self.B(0)]
        n1 = words[self.B(1)]
        return ' '.join((third, second, top, '|', n0, n1))
|
|
@ -2,11 +2,12 @@ from cymem.cymem cimport Pool
|
|||
from thinc.typedefs cimport weight_t
|
||||
|
||||
from ..structs cimport TokenC
|
||||
from ._state cimport State
|
||||
from ..gold cimport GoldParse
|
||||
from ..gold cimport GoldParseC
|
||||
from ..strings cimport StringStore
|
||||
|
||||
from .stateclass cimport StateClass
|
||||
|
||||
|
||||
cdef struct Transition:
|
||||
int clas
|
||||
|
@ -15,16 +16,16 @@ cdef struct Transition:
|
|||
|
||||
weight_t score
|
||||
|
||||
bint (*is_valid)(const State* state, int label) except -1
|
||||
int (*get_cost)(const State* state, const GoldParseC* gold, int label) except -1
|
||||
int (*do)(State* state, int label) except -1
|
||||
bint (*is_valid)(StateClass state, int label) nogil
|
||||
int (*get_cost)(StateClass state, const GoldParseC* gold, int label) nogil
|
||||
int (*do)(StateClass state, int label) nogil
|
||||
|
||||
|
||||
ctypedef int (*get_cost_func_t)(const State* state, const GoldParseC* gold, int label) except -1
|
||||
ctypedef int (*move_cost_func_t)(const State* state, const GoldParseC* gold) except -1
|
||||
ctypedef int (*label_cost_func_t)(const State* state, const GoldParseC* gold, int label) except -1
|
||||
ctypedef int (*get_cost_func_t)(StateClass state, const GoldParseC* gold, int label) nogil
|
||||
ctypedef int (*move_cost_func_t)(StateClass state, const GoldParseC* gold) nogil
|
||||
ctypedef int (*label_cost_func_t)(StateClass state, const GoldParseC* gold, int label) nogil
|
||||
|
||||
ctypedef int (*do_func_t)(State* state, int label) except -1
|
||||
ctypedef int (*do_func_t)(StateClass state, int label) nogil
|
||||
|
||||
|
||||
cdef class TransitionSystem:
|
||||
|
@ -34,8 +35,8 @@ cdef class TransitionSystem:
|
|||
cdef bint* _is_valid
|
||||
cdef readonly int n_moves
|
||||
|
||||
cdef int initialize_state(self, State* state) except -1
|
||||
cdef int finalize_state(self, State* state) except -1
|
||||
cdef int initialize_state(self, StateClass state) except -1
|
||||
cdef int finalize_state(self, StateClass state) except -1
|
||||
|
||||
cdef int preprocess_gold(self, GoldParse gold) except -1
|
||||
|
||||
|
@ -43,18 +44,11 @@ cdef class TransitionSystem:
|
|||
|
||||
cdef Transition init_transition(self, int clas, int move, int label) except *
|
||||
|
||||
cdef int set_valid(self, bint* output, const State* state) except -1
|
||||
cdef int set_valid(self, bint* output, StateClass state) except -1
|
||||
|
||||
cdef int set_costs(self, int* output, const State* state, GoldParse gold) except -1
|
||||
cdef int set_costs(self, int* output, StateClass state, GoldParse gold) except -1
|
||||
|
||||
cdef Transition best_valid(self, const weight_t* scores, const State* state) except *
|
||||
cdef Transition best_valid(self, const weight_t* scores, StateClass stcls) except *
|
||||
|
||||
cdef Transition best_gold(self, const weight_t* scores, const State* state,
|
||||
cdef Transition best_gold(self, const weight_t* scores, StateClass state,
|
||||
GoldParse gold) except *
|
||||
|
||||
|
||||
#cdef class PyState:
|
||||
# """Provide a Python class for testing purposes."""
|
||||
# cdef Pool mem
|
||||
# cdef TransitionSystem system
|
||||
# cdef State* _state
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
from cymem.cymem cimport Pool
|
||||
from ._state cimport State
|
||||
from ..structs cimport TokenC
|
||||
from thinc.typedefs cimport weight_t
|
||||
|
||||
from .stateclass cimport StateClass
|
||||
|
||||
|
||||
cdef weight_t MIN_SCORE = -90000
|
||||
|
||||
|
@ -27,10 +28,10 @@ cdef class TransitionSystem:
|
|||
i += 1
|
||||
self.c = moves
|
||||
|
||||
cdef int initialize_state(self, State* state) except -1:
|
||||
cdef int initialize_state(self, StateClass state) except -1:
|
||||
pass
|
||||
|
||||
cdef int finalize_state(self, State* state) except -1:
|
||||
cdef int finalize_state(self, StateClass state) except -1:
|
||||
pass
|
||||
|
||||
cdef int preprocess_gold(self, GoldParse gold) except -1:
|
||||
|
@ -42,62 +43,30 @@ cdef class TransitionSystem:
|
|||
cdef Transition init_transition(self, int clas, int move, int label) except *:
|
||||
raise NotImplementedError
|
||||
|
||||
cdef Transition best_valid(self, const weight_t* scores, const State* s) except *:
|
||||
cdef Transition best_valid(self, const weight_t* scores, StateClass s) except *:
|
||||
raise NotImplementedError
|
||||
|
||||
cdef int set_valid(self, bint* output, const State* state) except -1:
|
||||
cdef int set_valid(self, bint* output, StateClass state) except -1:
|
||||
raise NotImplementedError
|
||||
|
||||
cdef int set_costs(self, int* output, const State* s, GoldParse gold) except -1:
|
||||
cdef int set_costs(self, int* output, StateClass stcls, GoldParse gold) except -1:
|
||||
cdef int i
|
||||
for i in range(self.n_moves):
|
||||
output[i] = self.c[i].get_cost(s, &gold.c, self.c[i].label)
|
||||
if self.c[i].is_valid(stcls, self.c[i].label):
|
||||
output[i] = self.c[i].get_cost(stcls, &gold.c, self.c[i].label)
|
||||
else:
|
||||
output[i] = 9000
|
||||
|
||||
cdef Transition best_gold(self, const weight_t* scores, const State* s,
|
||||
cdef Transition best_gold(self, const weight_t* scores, StateClass stcls,
|
||||
GoldParse gold) except *:
|
||||
cdef Transition best
|
||||
cdef weight_t score = MIN_SCORE
|
||||
cdef int i
|
||||
for i in range(self.n_moves):
|
||||
cost = self.c[i].get_cost(s, &gold.c, self.c[i].label)
|
||||
if scores[i] > score and cost == 0:
|
||||
best = self.c[i]
|
||||
score = scores[i]
|
||||
if self.c[i].is_valid(stcls, self.c[i].label):
|
||||
cost = self.c[i].get_cost(stcls, &gold.c, self.c[i].label)
|
||||
if scores[i] > score and cost == 0:
|
||||
best = self.c[i]
|
||||
score = scores[i]
|
||||
assert score > MIN_SCORE
|
||||
return best
|
||||
|
||||
|
||||
#cdef class PyState:
|
||||
# """Provide a Python class for testing purposes."""
|
||||
# def __init__(self, GoldParse gold):
|
||||
# self.mem = Pool()
|
||||
# self.system = EntityRecognition(labels)
|
||||
# self._state = init_state(self.mem, tokens, gold.length)
|
||||
#
|
||||
# def transition(self, name):
|
||||
# cdef const Transition* trans = self._transition_by_name(name)
|
||||
# trans.do(trans, self._state)
|
||||
#
|
||||
# def is_valid(self, name):
|
||||
# cdef const Transition* trans = self._transition_by_name(name)
|
||||
# return _is_valid(trans.move, trans.label, self._state)
|
||||
#
|
||||
# def is_gold(self, name):
|
||||
# cdef const Transition* trans = self._transition_by_name(name)
|
||||
# return _get_const(trans, self._state, self._gold)
|
||||
#
|
||||
# property ent:
|
||||
# def __get__(self):
|
||||
# pass
|
||||
#
|
||||
# property n_ents:
|
||||
# def __get__(self):
|
||||
# pass
|
||||
#
|
||||
# property i:
|
||||
# def __get__(self):
|
||||
# pass
|
||||
#
|
||||
# property open_entity:
|
||||
# def __get__(self):
|
||||
# return entity_is_open(self._s)
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from libc.stdint cimport uint32_t
|
||||
|
||||
from numpy cimport ndarray
|
||||
cimport numpy
|
||||
cimport numpy as np
|
||||
|
||||
from cymem.cymem cimport Pool
|
||||
from thinc.typedefs cimport atom_t
|
||||
|
@ -47,7 +47,7 @@ cdef class Tokens:
|
|||
|
||||
cdef int push_back(self, int i, LexemeOrToken lex_or_tok) except -1
|
||||
|
||||
cpdef long[:,:] to_array(self, object features)
|
||||
cpdef np.ndarray to_array(self, object features)
|
||||
|
||||
cdef int set_parse(self, const TokenC* parsed) except -1
|
||||
|
||||
|
|
|
@ -17,8 +17,11 @@ from .spans import Span
|
|||
from .structs cimport UniStr
|
||||
|
||||
from unidecode import unidecode
|
||||
# Compiler crashes on memory view coercion without this. Should report bug.
|
||||
from cython.view cimport array as cvarray
|
||||
cimport numpy as np
|
||||
np.import_array()
|
||||
|
||||
cimport numpy
|
||||
import numpy
|
||||
|
||||
cimport cython
|
||||
|
@ -183,15 +186,12 @@ cdef class Tokens:
|
|||
"""
|
||||
cdef int i
|
||||
cdef Tokens sent = Tokens(self.vocab, self._string[self.data[0].idx:])
|
||||
start = None
|
||||
for i in range(self.length):
|
||||
if start is None:
|
||||
start = 0
|
||||
for i in range(1, self.length):
|
||||
if self.data[i].sent_start:
|
||||
yield Span(self, start, i)
|
||||
start = i
|
||||
if self.data[i].sent_end:
|
||||
yield Span(self, start, i+1)
|
||||
start = None
|
||||
if start is not None:
|
||||
yield Span(self, start, self.length)
|
||||
yield Span(self, start, self.length)
|
||||
|
||||
cdef int push_back(self, int idx, LexemeOrToken lex_or_tok) except -1:
|
||||
if self.length == self.max_length:
|
||||
|
@ -207,7 +207,7 @@ cdef class Tokens:
|
|||
return idx + t.lex.length
|
||||
|
||||
@cython.boundscheck(False)
|
||||
cpdef long[:,:] to_array(self, object py_attr_ids):
|
||||
cpdef np.ndarray to_array(self, object py_attr_ids):
|
||||
"""Given a list of M attribute IDs, export the tokens to a numpy ndarray
|
||||
of shape N*M, where N is the length of the sentence.
|
||||
|
||||
|
@ -221,10 +221,10 @@ cdef class Tokens:
|
|||
"""
|
||||
cdef int i, j
|
||||
cdef attr_id_t feature
|
||||
cdef numpy.ndarray[long, ndim=2] output
|
||||
cdef np.ndarray[long, ndim=2] output
|
||||
# Make an array from the attributes --- otherwise our inner loop is Python
|
||||
# dict iteration.
|
||||
cdef numpy.ndarray[long, ndim=1] attr_ids = numpy.asarray(py_attr_ids)
|
||||
cdef np.ndarray[long, ndim=1] attr_ids = numpy.asarray(py_attr_ids)
|
||||
output = numpy.ndarray(shape=(self.length, len(attr_ids)), dtype=numpy.int)
|
||||
for i in range(self.length):
|
||||
for j, feature in enumerate(attr_ids):
|
||||
|
@ -464,7 +464,9 @@ cdef class Token:
|
|||
|
||||
property repvec:
|
||||
def __get__(self):
|
||||
return numpy.asarray(<float[:self.vocab.repvec_length,]> self.c.lex.repvec)
|
||||
cdef int length = self.vocab.repvec_length
|
||||
repvec_view = <float[:length,]>self.c.lex.repvec
|
||||
return numpy.asarray(repvec_view)
|
||||
|
||||
property n_lefts:
|
||||
def __get__(self):
|
||||
|
@ -546,13 +548,13 @@ cdef class Token:
|
|||
property left_edge:
|
||||
def __get__(self):
|
||||
return Token.cinit(self.vocab, self._string,
|
||||
self.c + self.c.l_edge, self.i + self.c.l_edge,
|
||||
(self.c - self.i) + self.c.l_edge, self.c.l_edge,
|
||||
self.array_len, self._seq)
|
||||
|
||||
property right_edge:
|
||||
def __get__(self):
|
||||
return Token.cinit(self.vocab, self._string,
|
||||
self.c + self.c.r_edge, self.i + self.c.r_edge,
|
||||
(self.c - self.i) + self.c.r_edge, self.c.r_edge,
|
||||
self.array_len, self._seq)
|
||||
|
||||
property head:
|
||||
|
|
59
tests/munge/test_bad_periods.py
Normal file
59
tests/munge/test_bad_periods.py
Normal file
|
@ -0,0 +1,59 @@
|
|||
import spacy.munge.read_conll
|
||||
|
||||
hongbin_example = """
|
||||
1 2. 0. LS _ 24 meta _ _ _
|
||||
2 . . . _ 1 punct _ _ _
|
||||
3 Wang wang NNP _ 4 compound _ _ _
|
||||
4 Hongbin hongbin NNP _ 16 nsubj _ _ _
|
||||
5 , , , _ 4 punct _ _ _
|
||||
6 the the DT _ 11 det _ _ _
|
||||
7 " " `` _ 11 punct _ _ _
|
||||
8 communist communist JJ _ 11 amod _ _ _
|
||||
9 trail trail NN _ 11 compound _ _ _
|
||||
10 - - HYPH _ 11 punct _ _ _
|
||||
11 blazer blazer NN _ 4 appos _ _ _
|
||||
12 , , , _ 16 punct _ _ _
|
||||
13 " " '' _ 16 punct _ _ _
|
||||
14 has have VBZ _ 16 aux _ _ _
|
||||
15 not not RB _ 16 neg _ _ _
|
||||
16 turned turn VBN _ 24 ccomp _ _ _
|
||||
17 into into IN syn=CLR 16 prep _ _ _
|
||||
18 a a DT _ 19 det _ _ _
|
||||
19 capitalist capitalist NN _ 17 pobj _ _ _
|
||||
20 ( ( -LRB- _ 24 punct _ _ _
|
||||
21 he he PRP _ 24 nsubj _ _ _
|
||||
22 does do VBZ _ 24 aux _ _ _
|
||||
23 n't not RB _ 24 neg _ _ _
|
||||
24 have have VB _ 0 root _ _ _
|
||||
25 any any DT _ 26 det _ _ _
|
||||
26 shares share NNS _ 24 dobj _ _ _
|
||||
27 , , , _ 24 punct _ _ _
|
||||
28 does do VBZ _ 30 aux _ _ _
|
||||
29 n't not RB _ 30 neg _ _ _
|
||||
30 have have VB _ 24 conj _ _ _
|
||||
31 any any DT _ 32 det _ _ _
|
||||
32 savings saving NNS _ 30 dobj _ _ _
|
||||
33 , , , _ 30 punct _ _ _
|
||||
34 does do VBZ _ 36 aux _ _ _
|
||||
35 n't not RB _ 36 neg _ _ _
|
||||
36 have have VB _ 30 conj _ _ _
|
||||
37 his his PRP$ _ 39 poss _ _ _
|
||||
38 own own JJ _ 39 amod _ _ _
|
||||
39 car car NN _ 36 dobj _ _ _
|
||||
40 , , , _ 36 punct _ _ _
|
||||
41 and and CC _ 36 cc _ _ _
|
||||
42 does do VBZ _ 44 aux _ _ _
|
||||
43 n't not RB _ 44 neg _ _ _
|
||||
44 have have VB _ 36 conj _ _ _
|
||||
45 a a DT _ 46 det _ _ _
|
||||
46 mansion mansion NN _ 44 dobj _ _ _
|
||||
47 ; ; . _ 24 punct _ _ _
|
||||
""".strip()
|
||||
|
||||
|
||||
def test_hongbin():
    """Heads must still resolve to the right words when bad periods are stripped."""
    parse = spacy.munge.read_conll.parse
    words, annot = parse(hongbin_example, strip_bad_periods=True)

    # Head of the first token.
    first_head = words[annot[0]['head']]
    assert first_head == 'have'

    # Head of the second token.
    second_head = words[annot[1]['head']]
    assert second_head == 'Hongbin'
|
||||
|
||||
|
|
@ -103,10 +103,12 @@ def test_cnts5(en_tokenizer):
|
|||
tokens = en_tokenizer(text)
|
||||
assert len(tokens) == 11
|
||||
|
||||
def test_mr(en_tokenizer):
|
||||
text = """Mr. Smith"""
|
||||
tokens = en_tokenizer(text)
|
||||
assert len(tokens) == 2
|
||||
# TODO: This is currently difficult --- infix interferes here.
|
||||
#def test_mr(en_tokenizer):
|
||||
# text = """Today is Tuesday.Mr."""
|
||||
# tokens = en_tokenizer(text)
|
||||
# assert len(tokens) == 5
|
||||
# assert [w.orth_ for w in tokens] == ['Today', 'is', 'Tuesday', '.', 'Mr.']
|
||||
|
||||
|
||||
def test_cnts6(en_tokenizer):
|
||||
|
|
Loading…
Reference in New Issue
Block a user