Matthew Honnibal 2014-12-24 17:42:00 +11:00
parent 75a6930ad9
commit b8b65903fc
10 changed files with 161 additions and 222 deletions

View File

@@ -23,6 +23,7 @@ class English(object):
self.tokenizer = Tokenizer.from_dir(self.vocab, data_dir)
self.tagger = EnPosTagger(self.vocab.strings, data_dir) if tag else None
self.parser = GreedyParser(path.join(data_dir, 'deps')) if parse else None
self.strings = self.vocab.strings
def __call__(self, text, tag=True, parse=True):
tokens = self.tokenizer.tokenize(text)
@@ -31,3 +32,10 @@ class English(object):
if self.parser and parse:
self.parser.parse(tokens)
return tokens
@property
def tags(self):
if self.tagger is None:
return []
else:
return self.tagger.tag_names
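
For orientation, a hedged usage sketch of the pipeline wired above; the import path, input text, and variable names are assumptions, not part of this diff:

    # Hypothetical use of the English pipeline above (2014-era API).
    from spacy.en import English

    nlp = English()                       # tokenizer + tagger + parser
    tokens = nlp(u'This is a sentence.')  # tag=True, parse=True by default
    print(nlp.tags)                       # the tagger's tag_names, or []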

View File

@@ -4,10 +4,12 @@ from ..typedefs cimport ID as _ID
from ..typedefs cimport SIC as _SIC
from ..typedefs cimport SHAPE as _SHAPE
from ..typedefs cimport DENSE as _DENSE
from ..typedefs cimport CLUSTER as _CLUSTER
from ..typedefs cimport SHAPE as _SHAPE
from ..typedefs cimport PREFIX as _PREFIX
from ..typedefs cimport SUFFIX as _SUFFIX
from ..typedefs cimport LEMMA as _LEMMA
from ..typedefs cimport POS as _POS
# Work around the lack of global cpdef variables
@@ -25,8 +27,10 @@ cpdef enum:
ID = _ID
SIC = _SIC
SHAPE = _DENSE
DENSE = _SHAPE
SHAPE = _SHAPE
DENSE = _DENSE
PREFIX = _PREFIX
SUFFIX = _SUFFIX
CLUSTER = _CLUSTER
LEMMA = _LEMMA
POS = _POS
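
The removed pair (SHAPE = _DENSE, DENSE = _SHAPE) was the bug fixed here: the enum members are the public attribute ids, so the swap made a query for one attribute read the other. A toy Python illustration with invented ids:

    _SHAPE, _DENSE = 5, 6            # pretend internal attribute ids
    SHAPE, DENSE = _DENSE, _SHAPE    # the old, swapped aliasing
    assert SHAPE != _SHAPE           # public id disagrees with internal id,
                                     # so a SHAPE lookup fetched DENSE data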

View File

@@ -1,3 +1,4 @@
# cython: profile=True
from os import path
import json
@@ -228,7 +229,7 @@ cdef class EnPosTagger(Tagger):
cdef TokenC* t = tokens.data
for i in range(tokens.length):
fill_context(context, i, t)
t[i].pos = self.predict(context)
t[i].fine_pos = self.predict(context)
self.set_morph(i, t)
def train(self, Tokens tokens, golds):
@@ -238,13 +239,14 @@ cdef class EnPosTagger(Tagger):
cdef TokenC* t = tokens.data
for i in range(tokens.length):
fill_context(context, i, t)
t[i].pos = self.predict(context, [golds[i]])
t[i].fine_pos = self.predict(context, [golds[i]])
self.set_morph(i, t)
c += t[i].pos == golds[i]
c += t[i].fine_pos == golds[i]
return c
cdef int set_morph(self, const int i, TokenC* tokens) except -1:
cdef const PosTag* tag = &self.tags[tokens[i].pos]
cdef const PosTag* tag = &self.tags[tokens[i].fine_pos]
tokens[i].pos = tag.pos
cached = <_CachedMorph*>self._morph_cache.get(tag.id, tokens[i].lex.sic)
if cached is NULL:
cached = <_CachedMorph*>self.mem.alloc(1, sizeof(_CachedMorph))
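
A plain-Python sketch of the split this file introduces: the model predicts a fine-grained tag id (fine_pos), and set_morph derives the coarse pos plus a morphology record cached per (tag, word) pair. All names below are illustrative, not spaCy's API:

    FINE_TO_COARSE = {'NNS': 'NOUN', 'VBZ': 'VERB'}   # toy tag table
    _morph_cache = {}                                 # keyed by (tag, word)

    def set_morph(fine_tag, word):
        key = (fine_tag, word)
        if key not in _morph_cache:                   # compute once per pair
            _morph_cache[key] = word[:-1] if fine_tag == 'NNS' else word
        return FINE_TO_COARSE[fine_tag], _morph_cache[key]

    print(set_morph('NNS', 'dogs'))                   # ('NOUN', 'dog')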

View File

@@ -40,8 +40,9 @@ cdef struct PosTag:
cdef struct TokenC:
const Lexeme* lex
Morphology morph
univ_tag_t pos
int fine_pos
int idx
int pos
int lemma
int sense
int head

View File

@@ -13,8 +13,7 @@ from itertools import combinations
from ..tokens cimport TokenC
from ._state cimport State
from ._state cimport get_s2, get_s1, get_s0, get_n0, get_n1, get_n2
from ._state cimport has_head, get_left, get_right
from ._state cimport count_left_kids, count_right_kids
from ._state cimport get_left, get_right
cdef inline void fill_token(atom_t* context, const TokenC* token) nogil:
@@ -25,12 +24,10 @@ cdef inline void fill_token(atom_t* context, const TokenC* token) nogil:
context[3] = 0
context[4] = 0
context[5] = 0
context[6] = 0
else:
context[0] = token.lex.sic
context[1] = token.lemma
context[2] = token.pos
context[3] = token.lex.cluster
context[1] = token.pos
context[2] = token.lex.cluster
# We've read in the string little-endian, so now we can take & (2**n)-1
# to get the first n bits of the cluster.
# e.g. s = "1110010101"
@@ -43,9 +40,9 @@ cdef inline void fill_token(atom_t* context, const TokenC* token) nogil:
# What we're doing here is picking a number whose bits are all 1, e.g.
# 15 is 1111, 63 is 111111, and doing a bitwise AND, which keeps only
# the lowest bits of the cluster, i.e. the first n bits of the string.
context[4] = token.lex.cluster & 63
context[5] = token.lex.cluster & 15
context[6] = token.dep_tag if has_head(token) else 0
context[3] = token.lex.cluster & 63
context[4] = token.lex.cluster & 15
context[5] = token.dep_tag
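
A worked example (plain Python) of the masking described in the comment above, assuming as stated there that the cluster bit string is read little-endian:

    s = "1110010101"              # Brown cluster path as a bit string
    cluster = int(s[::-1], 2)     # stored reversed, so low bits come first
    first6 = cluster & 63         # 63 == 0b111111: the first 6 bits of s
    first4 = cluster & 15         # 15 == 0b1111: the first 4 bits of s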
cdef int fill_context(atom_t* context, State* state) except -1:
@@ -69,148 +66,12 @@ cdef int fill_context(atom_t* context, State* state) except -1:
context[dist] = state.stack[0] - state.i
else:
context[dist] = 0
context[N0lv] = max(count_left_kids(get_n0(state)), 5)
context[S0lv] = max(count_left_kids(get_s0(state)), 5)
context[S0rv] = max(count_right_kids(get_s0(state)), 5)
context[S1lv] = max(count_left_kids(get_s1(state)), 5)
context[S1rv] = max(count_right_kids(get_s1(state)), 5)
context[N0lv] = 0
context[S0lv] = 0
context[S0rv] = 0
context[S1lv] = 0
context[S1rv] = 0
context[S0_has_head] = 0
context[S1_has_head] = 0
context[S2_has_head] = 0
if state.stack_len >= 1:
context[S0_has_head] = has_head(get_s0(state)) + 1
if state.stack_len >= 2:
context[S1_has_head] = has_head(get_s1(state)) + 1
if state.stack_len >= 3:
context[S2_has_head] = has_head(get_s2(state))
unigrams = (
(S2W, S2p),
(S2c6, S2p),
(S1W, S1p),
(S1c6, S1p),
(S0W, S0p),
(S0c6, S0p),
(N0W, N0p),
(N0p,),
(N0c,),
(N0c6, N0p),
(N0L,),
(N1W, N1p),
(N1c6, N1p),
(N2W, N2p),
(N2c6, N2p),
(S0r2W, S0r2p),
(S0r2c6, S0r2p),
(S0r2L,),
(S0rW, S0rp),
(S0rc6, S0rp),
(S0rL,),
(S0l2W, S0l2p),
(S0l2c6, S0l2p),
(S0l2L,),
(S0lW, S0lp),
(S0lc6, S0lp),
(S0lL,),
(N0l2W, N0l2p),
(N0l2c6, N0l2p),
(N0l2L,),
(N0lW, N0lp),
(N0lc6, N0lp),
(N0lL,),
)
s0_n0 = (
(S0W, S0p, N0W, N0p),
(S0c, S0p, N0c, N0p),
(S0c6, S0p, N0c6, N0p),
(S0c4, S0p, N0c4, N0p),
(S0p, N0p),
(S0W, N0p),
(S0p, N0W),
(S0W, N0c),
(S0c, N0W),
(S0p, N0c),
(S0c, N0p),
(S0W, S0rp, N0p),
(S0p, S0rp, N0p),
(S0p, N0lp, N0W),
(S0p, N0lp, N0p),
)
s1_n0 = (
(S1p, N0p),
(S1c, N0c),
(S1c, N0p),
(S1p, N0c),
(S1W, S1p, N0p),
(S1p, N0W, N0p),
(S1c6, S1p, N0c6, N0p),
)
s0_n1 = (
(S0p, N1p),
(S0c, N1c),
(S0c, N1p),
(S0p, N1c),
(S0W, S0p, N1p),
(S0p, N1W, N1p),
(S0c6, S0p, N1c6, N1p),
)
n0_n1 = (
(N0W, N0p, N1W, N1p),
(N0W, N0p, N1p),
(N0p, N1W, N1p),
(N0c, N0p, N1c, N1p),
(N0c6, N0p, N1c6, N1p),
(N0c, N1c),
(N0p, N1c),
)
tree_shape = (
(dist,),
(S0p, S0_has_head, S1_has_head, S2_has_head),
(S0p, S0lv, S0rv),
(N0p, N0lv),
)
trigrams = (
(N0p, N1p, N2p),
(S0p, S0lp, S0l2p),
(S0p, S0rp, S0r2p),
(S0p, S1p, S2p),
(S1p, S0p, N0p),
(S0p, S0lp, N0p),
(S0p, N0p, N0lp),
(N0p, N0lp, N0l2p),
(S0W, S0p, S0rL, S0r2L),
(S0p, S0rL, S0r2L),
(S0W, S0p, S0lL, S0l2L),
(S0p, S0lL, S0l2L),
(N0W, N0p, N0lL, N0l2L),
(N0p, N0lL, N0l2L),
)
arc_eager = (
(S0w, S0p),
@@ -225,6 +86,7 @@ arc_eager = (
(N2w, N2p),
(N2w,),
(N2p,),
(S0w, S0p, N0w, N0p),
(S0w, S0p, N0w),
(S0w, N0w, N0p),

View File

@@ -6,7 +6,7 @@ cimport numpy as np
from cymem.cymem cimport Pool
from thinc.typedefs cimport atom_t
from .typedefs cimport flags_t
from .typedefs cimport flags_t, attr_id_t, attr_t
from .structs cimport Morphology, TokenC, Lexeme
from .vocab cimport Vocab
from .strings cimport StringStore
@@ -20,6 +20,13 @@ ctypedef fused LexemeOrToken:
TokenC_ptr
cdef attr_t get_lex_attr(const Lexeme* lex, attr_id_t feat_name) nogil
cdef attr_t get_token_attr(const TokenC* token, attr_id_t feat_name) nogil
cdef inline bint check_flag(const Lexeme* lexeme, attr_id_t flag_id) nogil:
return lexeme.flags & (1 << flag_id)
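
check_flag treats lexeme.flags as a bitfield, one boolean attribute per bit. A toy illustration in plain Python, with invented flag ids:

    IS_ALPHA, IS_DIGIT = 0, 1           # invented flag ids
    flags = 1 << IS_ALPHA               # lexeme with only IS_ALPHA set
    assert flags & (1 << IS_ALPHA)      # flag present
    assert not flags & (1 << IS_DIGIT)  # flag absent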
cdef class Tokens:
cdef Pool mem
cdef Vocab vocab
@@ -36,28 +43,5 @@ cdef class Tokens:
cdef class Token:
cdef readonly StringStore string_store
cdef public int i
cdef public int idx
cdef readonly int pos_id
cdef readonly int dep_id
cdef int lemma
cdef public int head
cdef public int dep_tag
cdef public atom_t id
cdef public atom_t cluster
cdef public atom_t length
cdef public atom_t postype
cdef public atom_t sensetype
cdef public atom_t sic
cdef public atom_t norm
cdef public atom_t shape
cdef public atom_t asciied
cdef public atom_t prefix
cdef public atom_t suffix
cdef public float prob
cdef public flags_t flags
cdef Tokens _seq
cdef readonly int i

View File

@@ -2,16 +2,17 @@
from preshed.maps cimport PreshMap
from preshed.counter cimport PreshCounter
from .lexeme cimport get_attr, EMPTY_LEXEME
from .vocab cimport EMPTY_LEXEME
from .typedefs cimport attr_id_t, attr_t
from .typedefs cimport LEMMA
from .typedefs cimport ID, SIC, DENSE, SHAPE, PREFIX, SUFFIX, LENGTH, CLUSTER, POS_TYPE
from .typedefs cimport POS, LEMMA
cimport cython
import numpy as np
cimport numpy as np
POS = 0
ENTITY = 0
DEF PADDING = 5
@@ -23,6 +24,40 @@ cdef int bounds_check(int i, int length, int padding) except -1:
raise IndexError
cdef attr_t get_token_attr(const TokenC* token, attr_id_t feat_name) nogil:
if feat_name == LEMMA:
return token.lemma
elif feat_name == POS:
return token.pos
else:
return get_lex_attr(token.lex, feat_name)
cdef attr_t get_lex_attr(const Lexeme* lex, attr_id_t feat_name) nogil:
if feat_name < (sizeof(flags_t) * 8):
return check_flag(lex, feat_name)
elif feat_name == ID:
return lex.id
elif feat_name == SIC:
return lex.sic
elif feat_name == DENSE:
return lex.dense
elif feat_name == SHAPE:
return lex.shape
elif feat_name == PREFIX:
return lex.prefix
elif feat_name == SUFFIX:
return lex.suffix
elif feat_name == LENGTH:
return lex.length
elif feat_name == CLUSTER:
return lex.cluster
elif feat_name == POS_TYPE:
return lex.pos_type
else:
return 0
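
The first branch of get_lex_attr encodes a convention: attribute ids below the flag width (sizeof(flags_t) * 8 bits) are bit positions into lex.flags, while larger ids name ordinary fields. A hedged Python version of the same dispatch:

    FLAG_BITS = 64                    # assuming a 64-bit flags_t
    def get_lex_attr(lex, feat_id):
        if feat_id < FLAG_BITS:       # low ids are boolean flag bits
            return (lex['flags'] >> feat_id) & 1
        return lex.get(feat_id, 0)    # everything else is a plain field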
cdef class Tokens:
"""A sequence of references to Lexeme objects.
@@ -52,9 +87,7 @@ cdef class Tokens:
def __getitem__(self, i):
bounds_check(i, self.length, PADDING)
return Token(self.vocab.strings, i, self.data[i].idx, self.data[i].pos,
self.data[i].lemma, self.data[i].head, self.data[i].dep_tag,
self.data[i].lex[0])
return Token(self, i)
def __iter__(self):
for i in range(self.length):
@@ -71,6 +104,7 @@ cdef class Tokens:
t[0] = lex_or_tok[0]
else:
t.lex = lex_or_tok
t.idx = idx
self.length += 1
return idx + t.lex.length
@@ -82,7 +116,7 @@ cdef class Tokens:
output = np.ndarray(shape=(self.length, len(attr_ids)), dtype=int)
for i in range(self.length):
for j, feature in enumerate(attr_ids):
output[i, j] = get_attr(self.data[i].lex, feature)
output[i, j] = get_token_attr(&self.data[i], feature)
return output
def count_by(self, attr_id_t attr_id):
@@ -92,10 +126,7 @@ cdef class Tokens:
cdef PreshCounter counts = PreshCounter(2 ** 8)
for i in range(self.length):
if attr_id == LEMMA:
attr = self.data[i].lemma
else:
attr = get_attr(self.data[i].lex, attr_id)
attr = get_token_attr(&self.data[i], attr_id)
counts.inc(attr, 1)
return dict(counts)
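
A hypothetical use of the two methods above; the variable names are illustrative, the attribute ids come from typedefs:

    arr = tokens.to_array([SIC, POS])   # (n_tokens, 2) matrix of ids
    counts = tokens.count_by(LEMMA)     # {lemma string-id: frequency}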
@@ -117,49 +148,68 @@ cdef class Tokens:
@cython.freelist(64)
cdef class Token:
def __init__(self, StringStore string_store, int i, int idx,
int pos, int lemma, int head, int dep_tag, dict lex):
self.string_store = string_store
self.idx = idx
self.pos_id = pos
def __init__(self, Tokens tokens, int i):
self._seq = tokens
self.i = i
self.head = head
self.dep_id = dep_tag
self.id = lex['id']
self.lemma = lemma
self.cluster = lex['cluster']
self.length = lex['length']
self.postype = lex['pos_type']
self.sensetype = 0
self.sic = lex['sic']
self.norm = lex['dense']
self.shape = lex['shape']
self.suffix = lex['suffix']
self.prefix = lex['prefix']
def __unicode__(self):
cdef const TokenC* t = &self._seq.data[self.i]
cdef int end_idx = t.idx + t.lex.length
if self.i + 1 == self._seq.length:
return self.string
if end_idx == t[1].idx:
return self.string
else:
return self.string + ' '
self.prob = lex['prob']
self.flags = lex['flags']
def __len__(self):
return self._seq.data[self.i].lex.length
property idx:
def __get__(self):
return self._seq.data[self.i].idx
property length:
def __get__(self):
return self._seq.data[self.i].lex.length
property cluster:
def __get__(self):
return self._seq.data[self.i].lex.cluster
property string:
def __get__(self):
if self.sic == 0:
cdef const TokenC* t = &self._seq.data[self.i]
if t.lex.sic == 0:
return ''
cdef bytes utf8string = self.string_store[self.sic]
cdef bytes utf8string = self._seq.vocab.strings[t.lex.sic]
return utf8string.decode('utf8')
property lemma:
def __get__(self):
if self.lemma == 0:
cdef const TokenC* t = &self._seq.data[self.i]
if t.lemma == 0:
return self.string
cdef bytes utf8string = self.string_store[self.lemma]
cdef bytes utf8string = self._seq.vocab.strings[t.lemma]
return utf8string.decode('utf8')
property dep:
def __get__(self):
return self.string_store.dep_tags[self.dep_id]
return self._seq.data[self.i].dep_tag
property pos:
def __get__(self):
return self.pos_id
return self._seq.data[self.i].pos
property fine_pos:
def __get__(self):
return self._seq.data[self.i].fine_pos
property sic:
def __get__(self):
return self._seq.data[self.i].lex.sic
property head:
def __get__(self):
cdef const TokenC* t = &self._seq.data[self.i]
return Token(self._seq, self.i + t.head)
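
The head property shows that heads are stored as relative offsets: token i's head sits at index i + data[i].head, so a head Token can be built with no extra lookup table. A toy sketch:

    offsets = [2, 1, 0, -1]        # invented parse; 0 marks the root
    def head_index(i):
        return i + offsets[i]      # relative offset -> absolute index
    assert head_index(0) == 2      # token 0 attaches to token 2
    assert head_index(2) == 2      # token 2 is its own head (the root)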

View File

@@ -99,6 +99,7 @@ cpdef enum attr_id_t:
CLUSTER
POS_TYPE
LEMMA
POS

View File

@@ -9,6 +9,9 @@ from .typedefs cimport utf8_t, id_t, hash_t
from .strings cimport StringStore
cdef Lexeme EMPTY_LEXEME
cdef union LexemesOrTokens:
const Lexeme* const* lexemes
TokenC* tokens

View File

@@ -1,10 +1,34 @@
from libc.stdio cimport fopen, fclose, fread, fwrite, FILE
from libc.string cimport memset
from os import path
from .lexeme cimport EMPTY_LEXEME
from .lexeme cimport init as lexeme_init
from .strings cimport slice_unicode
from . import orth
memset(&EMPTY_LEXEME, 0, sizeof(Lexeme))
cpdef Lexeme init_lexeme(id_t i, unicode string, hash_t hashed,
StringStore string_store, dict props) except *:
cdef Lexeme lex
lex.id = i
lex.length = len(string)
lex.sic = string_store[string]
lex.cluster = props.get('cluster', 0)
lex.pos_type = props.get('pos_type', 0)
lex.prob = props.get('prob', 0)
lex.prefix = string_store[string[:1]]
lex.suffix = string_store[string[-3:]]
lex.shape = string_store[orth.word_shape(string)]
lex.flags = props.get('flags', 0)
return lex
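
The props dict consumed above comes from get_lex_props (called in the Vocab hunk just below); its keys mirror the .get() calls in init_lexeme. An invented example value:

    props = {
        'cluster': 0b111001,   # Brown cluster id
        'pos_type': 1,         # coarse POS ambiguity class
        'prob': -9.1,          # unigram log probability
        'flags': 1 << 0,       # packed boolean attributes
    }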
cdef class Vocab:
@@ -43,7 +67,7 @@ cdef class Vocab:
mem = self.mem
cdef unicode py_string = string.chars[:string.n]
lex = <Lexeme*>mem.alloc(sizeof(Lexeme), 1)
lex[0] = lexeme_init(self.lexemes.size(), py_string, string.key, self.strings,
lex[0] = init_lexeme(self.lexemes.size(), py_string, string.key, self.strings,
self.get_lex_props(py_string))
if mem is self.mem:
self._map.set(string.key, lex)
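
The method this last hunk patches implements interning: a lexeme is allocated and initialised once, then cached in a map keyed by the string's hash so later lookups reuse the same pointer. A minimal Python analogue, names invented:

    _cache = {}
    def get_lex(string):
        key = hash(string)            # stands in for string.key above
        if key not in _cache:         # allocate and initialise once
            _cache[key] = {'sic': string, 'length': len(string)}
        return _cache[key]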