WIP on hash kernel

Matthew Honnibal 2017-03-14 21:28:43 +01:00
parent 2ac166eacd
commit 755d7d486c
11 changed files with 383 additions and 179 deletions


@@ -56,6 +56,7 @@ MOD_NAMES = [
     'spacy.lexeme',
     'spacy.vocab',
     'spacy.attrs',
+    'spacy._ml',
     'spacy.morphology',
     'spacy.tagger',
     'spacy.pipeline',

spacy/_ml.pxd (new file, +31 lines)

@@ -0,0 +1,31 @@
from thinc.linear.features cimport ConjunctionExtracter
from thinc.typedefs cimport atom_t, weight_t
from thinc.structs cimport FeatureC
from libc.stdint cimport uint32_t
cimport numpy as np
from cymem.cymem cimport Pool


cdef class LinearModel:
    cdef ConjunctionExtracter extracter
    cdef readonly int nr_class
    cdef readonly uint32_t nr_weight
    cdef public weight_t learn_rate
    cdef Pool mem
    cdef weight_t* W
    cdef weight_t* d_W

    cdef void hinge_lossC(self, weight_t* d_scores,
            const weight_t* scores, const weight_t* costs) nogil

    cdef void log_lossC(self, weight_t* d_scores,
            const weight_t* scores, const weight_t* costs) nogil

    cdef void regression_lossC(self, weight_t* d_scores,
            const weight_t* scores, const weight_t* costs) nogil

    cdef void set_scoresC(self, weight_t* scores,
            const FeatureC* features, int nr_feat) nogil

    cdef void set_gradientC(self, const weight_t* d_scores,
            const FeatureC* features, int nr_feat) nogil

spacy/_ml.pyx (new file, +151 lines)

@@ -0,0 +1,151 @@
# cython: infer_types=True
# cython: profile=True
# cython: cdivision=True
from libcpp.vector cimport vector
from libc.stdint cimport uint64_t, uint32_t, int32_t
from libc.string cimport memcpy, memset
cimport libcpp.algorithm
from libc.math cimport exp

from cymem.cymem cimport Pool
from thinc.linalg cimport Vec, VecVec
from murmurhash.mrmr cimport hash64

cimport numpy as np
import numpy

np.import_array()


cdef class LinearModel:
    def __init__(self, int nr_class, templates, weight_t learn_rate=0.001,
                 size=2**18):
        self.extracter = ConjunctionExtracter(templates)
        self.nr_weight = size
        self.nr_class = nr_class
        self.learn_rate = learn_rate
        self.mem = Pool()
        self.W = <weight_t*>self.mem.alloc(self.nr_weight * self.nr_class,
                                           sizeof(weight_t))
        self.d_W = <weight_t*>self.mem.alloc(self.nr_weight * self.nr_class,
                                             sizeof(weight_t))

    cdef void hinge_lossC(self, weight_t* d_scores,
            const weight_t* scores, const weight_t* costs) nogil:
        guess = 0
        best = -1
        for i in range(1, self.nr_class):
            if scores[i] > scores[guess]:
                guess = i
            if costs[i] == 0 and (best == -1 or scores[i] > scores[best]):
                best = i
        if best != -1 and scores[guess] >= scores[best]:
            d_scores[guess] = 1.
            d_scores[best] = -1.

    cdef void log_lossC(self, weight_t* d_scores,
            const weight_t* scores, const weight_t* costs) nogil:
        for i in range(self.nr_class):
            if costs[i] <= 0:
                break
        else:
            return
        cdef double Z = 1e-10
        cdef double gZ = 1e-10
        cdef double max_ = scores[0]
        cdef double g_max = -9000
        for i in range(self.nr_class):
            max_ = max(max_, scores[i])
            if costs[i] <= 0:
                g_max = max(g_max, scores[i])
        for i in range(self.nr_class):
            Z += exp(scores[i]-max_)
            if costs[i] <= 0:
                gZ += exp(scores[i]-g_max)
        for i in range(self.nr_class):
            score = exp(scores[i]-max_)
            if costs[i] >= 1:
                d_scores[i] = score / Z
            else:
                g_score = exp(scores[i]-g_max)
                d_scores[i] = (score / Z) - (g_score / gZ)

    cdef void regression_lossC(self, weight_t* d_scores,
            const weight_t* scores, const weight_t* costs) nogil:
        best = -1
        for i in range(self.nr_class):
            if costs[i] <= 0:
                if best == -1:
                    best = i
                elif scores[i] > scores[best]:
                    best = i
        if best == -1:
            return
        for i in range(self.nr_class):
            if scores[i] < scores[best]:
                d_scores[i] = 0
            elif costs[i] <= 0 and scores[i] == best:
                continue
            else:
                d_scores[i] = scores[i] - -costs[i]

    cdef void set_scoresC(self, weight_t* scores,
            const FeatureC* features, int nr_feat) nogil:
        cdef uint64_t nr_weight = self.nr_weight
        cdef int nr_class = self.nr_class
        cdef vector[uint64_t] indices
        # Collect all feature indices
        cdef uint32_t[2] hashed
        cdef FeatureC feat
        cdef uint64_t hash2
        for feat in features[:nr_feat]:
            if feat.value == 0:
                continue
            memcpy(hashed, &feat.key, sizeof(hashed))
            indices.push_back(hashed[0] % nr_weight)
            indices.push_back(hashed[1] % nr_weight)
        # Sort them, to improve memory access pattern
        libcpp.algorithm.sort(indices.begin(), indices.end())
        for idx in indices:
            W = &self.W[idx * nr_class]
            for clas in range(nr_class):
                scores[clas] += W[clas]

    cdef void set_gradientC(self, const weight_t* d_scores,
            const FeatureC* features, int nr_feat) nogil:
        cdef uint64_t nr_weight = self.nr_weight
        cdef int nr_class = self.nr_class
        cdef vector[uint64_t] indices
        # Collect all feature indices
        cdef uint32_t[2] hashed
        cdef uint64_t hash2
        for feat in features[:nr_feat]:
            if feat.value == 0:
                continue
            memcpy(hashed, &feat.key, sizeof(hashed))
            indices.push_back(hashed[0] % nr_weight)
            indices.push_back(hashed[1] % nr_weight)
        # Sort them, to improve memory access pattern
        libcpp.algorithm.sort(indices.begin(), indices.end())
        for idx in indices:
            W = &self.W[idx * nr_class]
            for clas in range(nr_class):
                if d_scores[clas] < 0:
                    W[clas] -= self.learn_rate * max(-10., d_scores[clas])
                else:
                    W[clas] -= self.learn_rate * min(10., d_scores[clas])

    @property
    def nr_active_feat(self):
        return self.nr_weight

    @property
    def nr_feat(self):
        return self.extracter.nr_templ

    def end_training(self, *args, **kwargs):
        pass

    def dump(self, *args, **kwargs):
        pass
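
The "hash kernel" in the commit title refers to how this LinearModel scores classes: instead of storing one weight row per feature, set_scoresC splits each 64-bit hashed feature key into two 32-bit halves, maps each half into a fixed table of nr_weight rows, and sums the per-class rows it hits; set_gradientC applies updates to the same buckets. A rough NumPy sketch of the bucketing, for reference only (function and variable names are illustrative, not part of the diff):

import numpy as np

def bucket_indices(feat_key, table_size=2**18):
    # Split the 64-bit feature hash into two 32-bit halves, mirroring the
    # memcpy into uint32_t[2] in set_scoresC (little-endian assumed).
    lo = feat_key & 0xFFFFFFFF
    hi = (feat_key >> 32) & 0xFFFFFFFF
    return lo % table_size, hi % table_size

def score(feat_keys, W):
    # Sum the per-class weight rows for every bucket each feature hashes into.
    scores = np.zeros(W.shape[1])
    for key in feat_keys:
        for idx in bucket_indices(key, W.shape[0]):
            scores += W[idx]
    return scores

W = np.zeros((2**18, 10))               # nr_weight x nr_class weight table
print(score([0x9E3779B97F4A7C15], W))   # any 64-bit feature key

Collisions between features are accepted by design; the table size (2**18 here) trades a little accuracy for a fixed memory footprint and removes the need for a feature-to-index map.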


@@ -4,13 +4,13 @@
 # https://github.com/pypa/warehouse/blob/master/warehouse/__about__.py

 __title__ = 'spacy'
-__version__ = '1.6.0'
+__version__ = '1.7.0'
 __summary__ = 'Industrial-strength Natural Language Processing (NLP) with Python and Cython'
 __uri__ = 'https://spacy.io'
 __author__ = 'Matthew Honnibal'
 __email__ = 'matt@explosion.ai'
 __license__ = 'MIT'
 __models__ = {
-    'en': 'en>=1.1.0,<1.2.0',
-    'de': 'de>=1.0.0,<1.1.0',
+    'en': 'en>=1.2.0,<1.3.0',
+    'de': 'de>=1.2.0,<1.3.0',
 }


@@ -304,11 +304,13 @@ cdef cppclass StateC:
         this._break = this._b_i

     void clone(const StateC* src) nogil:
-        memcpy(this._sent, src._sent, this.length * sizeof(TokenC))
-        memcpy(this._stack, src._stack, this.length * sizeof(int))
-        memcpy(this._buffer, src._buffer, this.length * sizeof(int))
-        memcpy(this._ents, src._ents, this.length * sizeof(Entity))
-        memcpy(this.shifted, src.shifted, this.length * sizeof(this.shifted[0]))
+        # This is still quadratic, but make it a bit faster.
+        # Not carefully reviewed for accuracy yet.
+        memcpy(this._sent, src._sent, this.B(1) * sizeof(TokenC))
+        memcpy(this._stack, src._stack, this._s_i * sizeof(int))
+        memcpy(this._buffer, src._buffer, this._b_i * sizeof(int))
+        memcpy(this._ents, src._ents, this._e_i * sizeof(Entity))
+        memcpy(this.shifted, src.shifted, this.B(2) * sizeof(this.shifted[0]))
         this.length = src.length
         this._b_i = src._b_i
         this._s_i = src._s_i


@@ -70,7 +70,7 @@ cdef weight_t push_cost(StateClass stcls, const GoldParseC* gold, int target) no
 cdef weight_t pop_cost(StateClass stcls, const GoldParseC* gold, int target) nogil:
     cdef weight_t cost = 0
     cdef int i, B_i
-    for i in range(stcls.buffer_length()):
+    for i in range(min(30, stcls.buffer_length())):
         B_i = stcls.B(i)
         cost += gold.heads[B_i] == target
         cost += gold.heads[target] == B_i
@@ -268,10 +268,12 @@ cdef class Break:
         cdef int i, j, S_i, B_i
         for i in range(s.stack_depth()):
             S_i = s.S(i)
-            for j in range(s.buffer_length()):
+            for j in range(min(30, s.buffer_length())):
                 B_i = s.B(j)
                 cost += gold.heads[S_i] == B_i
                 cost += gold.heads[B_i] == S_i
+                if cost != 0:
+                    break
         # Check for sentence boundary --- if it's here, we can't have any deps
         # between stack and buffer, so rest of action is irrelevant.
         s0_root = _get_root(s.S(0), gold)
@@ -462,7 +464,7 @@ cdef class ArcEager(TransitionSystem):
         cdef int* labels = gold.c.labels
         cdef int* heads = gold.c.heads
-        n_gold = 0
+        cdef int n_gold = 0
         for i in range(self.n_moves):
             if self.c[i].is_valid(stcls.c, self.c[i].label):
                 is_valid[i] = True


@@ -1,5 +1,6 @@
 from thinc.linear.avgtron cimport AveragedPerceptron
-from thinc.typedefs cimport atom_t
+from thinc.linear.features cimport ConjunctionExtracter
+from thinc.typedefs cimport atom_t, weight_t
 from thinc.structs cimport FeatureC

 from .stateclass cimport StateClass
@@ -8,9 +9,10 @@ from ..vocab cimport Vocab
 from ..tokens.doc cimport Doc
 from ..structs cimport TokenC
 from ._state cimport StateC
+from .._ml cimport LinearModel


-cdef class ParserModel(AveragedPerceptron):
+cdef class ParserModel(LinearModel):
     cdef int set_featuresC(self, atom_t* context, FeatureC* features,
                            const StateC* state) nogil
@@ -20,5 +22,6 @@ cdef class Parser:
     cdef readonly ParserModel model
     cdef readonly TransitionSystem moves
     cdef readonly object cfg
+    cdef public object optimizer

     cdef int parseC(self, TokenC* tokens, int length, int nr_feat, int nr_class) with gil


@@ -1,4 +1,6 @@
 # cython: infer_types=True
+# cython: cdivision=True
+# cython: profile=True
 """
 MALT-style dependency parser
 """
@@ -20,15 +22,22 @@ import shutil
 import json
 import sys
 from .nonproj import PseudoProjectivity
+import numpy
+import random
+cimport numpy as np
+np.import_array()

 from cymem.cymem cimport Pool, Address
-from murmurhash.mrmr cimport hash64
+from murmurhash.mrmr cimport hash64, hash32
 from thinc.typedefs cimport weight_t, class_t, feat_t, atom_t, hash_t
 from thinc.linear.avgtron cimport AveragedPerceptron
 from thinc.linalg cimport VecVec
 from thinc.structs cimport SparseArrayC
 from preshed.maps cimport MapStruct
 from preshed.maps cimport map_get
+from thinc.neural.ops import NumpyOps
+from thinc.neural.optimizers import Adam
+from thinc.neural.optimizers import SGD
 from thinc.structs cimport FeatureC
 from thinc.structs cimport ExampleC
@@ -51,6 +60,7 @@ from ._parse_features cimport CONTEXT_SIZE
 from ._parse_features cimport fill_context
 from .stateclass cimport StateClass
 from ._state cimport StateC
+from .._ml cimport LinearModel


 DEBUG = False
@@ -72,57 +82,65 @@ def get_templates(name):
                 pf.tree_shape + pf.trigrams)


-cdef class ParserModel(AveragedPerceptron):
+#cdef class ParserModel(AveragedPerceptron):
+#    cdef int set_featuresC(self, atom_t* context, FeatureC* features,
+#            const StateC* state) nogil:
+#        fill_context(context, state)
+#        nr_feat = self.extracter.set_features(features, context)
+#        return nr_feat
+#
+#    def update(self, Example eg, itn=0):
+#        '''Does regression on negative cost. Sort of cute?'''
+#        self.time += 1
+#        best = arg_max_if_gold(eg.c.scores, eg.c.costs, eg.c.nr_class)
+#        guess = eg.guess
+#        cdef weight_t loss = 0.0
+#        if guess == best:
+#            return loss
+#        for clas in [guess, best]:
+#            loss += (-eg.c.costs[clas] - eg.c.scores[clas]) ** 2
+#            d_loss = eg.c.scores[clas] - -eg.c.costs[clas]
+#            for feat in eg.c.features[:eg.c.nr_feat]:
+#                self.update_weight_ftrl(feat.key, clas, feat.value * d_loss)
+#        return loss
+#
+#    def update_from_histories(self, TransitionSystem moves, Doc doc, histories, weight_t min_grad=0.0):
+#        cdef Pool mem = Pool()
+#        features = <FeatureC*>mem.alloc(self.nr_feat, sizeof(FeatureC))
+#
+#        cdef StateClass stcls
+#
+#        cdef class_t clas
+#        self.time += 1
+#        cdef atom_t[CONTEXT_SIZE] atoms
+#        histories = [(grad, hist) for grad, hist in histories if abs(grad) >= min_grad and hist]
+#        if not histories:
+#            return None
+#        gradient = [Counter() for _ in range(max([max(h)+1 for _, h in histories]))]
+#        for d_loss, history in histories:
+#            stcls = StateClass.init(doc.c, doc.length)
+#            moves.initialize_state(stcls.c)
+#            for clas in history:
+#                nr_feat = self.set_featuresC(atoms, features, stcls.c)
+#                clas_grad = gradient[clas]
+#                for feat in features[:nr_feat]:
+#                    clas_grad[feat.key] += d_loss * feat.value
+#                moves.c[clas].do(stcls.c, moves.c[clas].label)
+#        cdef feat_t key
+#        cdef weight_t d_feat
+#        for clas, clas_grad in enumerate(gradient):
+#            for key, d_feat in clas_grad.items():
+#                if d_feat != 0:
+#                    self.update_weight_ftrl(key, clas, d_feat)
+#
+cdef class ParserModel(LinearModel):
     cdef int set_featuresC(self, atom_t* context, FeatureC* features,
             const StateC* state) nogil:
         fill_context(context, state)
         nr_feat = self.extracter.set_features(features, context)
         return nr_feat

-    def update(self, Example eg, itn=0):
-        '''Does regression on negative cost. Sort of cute?'''
-        self.time += 1
-        best = arg_max_if_gold(eg.c.scores, eg.c.costs, eg.c.nr_class)
-        guess = eg.guess
-        cdef weight_t loss = 0.0
-        if guess == best:
-            return loss
-        for clas in [guess, best]:
-            loss += (-eg.c.costs[clas] - eg.c.scores[clas]) ** 2
-            d_loss = eg.c.scores[clas] - -eg.c.costs[clas]
-            for feat in eg.c.features[:eg.c.nr_feat]:
-                self.update_weight_ftrl(feat.key, clas, feat.value * d_loss)
-        return loss
-
-    def update_from_histories(self, TransitionSystem moves, Doc doc, histories, weight_t min_grad=0.0):
-        cdef Pool mem = Pool()
-        features = <FeatureC*>mem.alloc(self.nr_feat, sizeof(FeatureC))
-
-        cdef StateClass stcls
-
-        cdef class_t clas
-        self.time += 1
-        cdef atom_t[CONTEXT_SIZE] atoms
-        histories = [(grad, hist) for grad, hist in histories if abs(grad) >= min_grad and hist]
-        if not histories:
-            return None
-        gradient = [Counter() for _ in range(max([max(h)+1 for _, h in histories]))]
-        for d_loss, history in histories:
-            stcls = StateClass.init(doc.c, doc.length)
-            moves.initialize_state(stcls.c)
-            for clas in history:
-                nr_feat = self.set_featuresC(atoms, features, stcls.c)
-                clas_grad = gradient[clas]
-                for feat in features[:nr_feat]:
-                    clas_grad[feat.key] += d_loss * feat.value
-                moves.c[clas].do(stcls.c, moves.c[clas].label)
-        cdef feat_t key
-        cdef weight_t d_feat
-        for clas, clas_grad in enumerate(gradient):
-            for key, d_feat in clas_grad.items():
-                if d_feat != 0:
-                    self.update_weight_ftrl(key, clas, d_feat)
-

 cdef class Parser:
     """Base class of the DependencyParser and EntityRecognizer."""
@@ -174,9 +192,14 @@ cdef class Parser:
             cfg['features'] = get_templates(cfg['features'])
         elif 'features' not in cfg:
             cfg['features'] = self.feature_templates
-        self.model = ParserModel(cfg['features'])
-        self.model.l1_penalty = cfg.get('L1', 1e-8)
-        self.model.learn_rate = cfg.get('learn_rate', 0.001)
+        self.model = ParserModel(self.moves.n_moves, cfg['features'],
+                                 size=2**18,
+                                 learn_rate=cfg.get('learn_rate', 0.001))
+        #self.model.l1_penalty = cfg.get('L1', 1e-8)
+        #self.model.learn_rate = cfg.get('learn_rate', 0.001)
+        self.optimizer = SGD(NumpyOps(), cfg.get('learn_rate', 0.001),
+                             momentum=0.9)

         self.cfg = cfg
@@ -300,27 +323,48 @@ cdef class Parser:
         self.moves.preprocess_gold(gold)
         cdef StateClass stcls = StateClass.init(tokens.c, tokens.length)
         self.moves.initialize_state(stcls.c)
+        cdef int nr_class = self.model.nr_class
         cdef Pool mem = Pool()
-        cdef Example eg = Example(
-                nr_class=self.moves.n_moves,
-                nr_atom=CONTEXT_SIZE,
-                nr_feat=self.model.nr_feat)
+        d_scores = <weight_t*>mem.alloc(nr_class, sizeof(weight_t))
+        scores = <weight_t*>mem.alloc(nr_class, sizeof(weight_t))
+        costs = <weight_t*>mem.alloc(nr_class, sizeof(weight_t))
+        features = <FeatureC*>mem.alloc(self.model.nr_feat, sizeof(FeatureC))
+        is_valid = <int*>mem.alloc(self.moves.n_moves, sizeof(int))
+        cdef atom_t[CONTEXT_SIZE] context
         cdef weight_t loss = 0
         cdef Transition action
+        words = [w.text for w in tokens]
         while not stcls.is_final():
-            eg.c.nr_feat = self.model.set_featuresC(eg.c.atoms, eg.c.features,
-                                                    stcls.c)
-            self.moves.set_costs(eg.c.is_valid, eg.c.costs, stcls, gold)
-            self.model.set_scoresC(eg.c.scores, eg.c.features, eg.c.nr_feat)
-            guess = VecVec.arg_max_if_true(eg.c.scores, eg.c.is_valid, eg.c.nr_class)
-            self.model.update(eg)
+            nr_feat = self.model.set_featuresC(context, features, stcls.c)
+            self.moves.set_costs(is_valid, costs, stcls, gold)
+            self.model.set_scoresC(scores, features, nr_feat)
+
+            guess = VecVec.arg_max_if_true(scores, is_valid, nr_class)
+            best = arg_max_if_gold(scores, costs, nr_class)
+
+            self.model.regression_lossC(d_scores, scores, costs)
+            self.model.set_gradientC(d_scores, features, nr_feat)
             action = self.moves.c[guess]
             action.do(stcls.c, action.label)
-            loss += eg.costs[guess]
-            eg.fill_scores(0, eg.c.nr_class)
-            eg.fill_costs(0, eg.c.nr_class)
-            eg.fill_is_valid(1, eg.c.nr_class)
+            #print(scores[guess], scores[best], d_scores[guess], costs[guess],
+            #    self.moves.move_name(action.move, action.label), stcls.print_state(words))
+
+            loss += scores[guess]
+
+            memset(context, 0, sizeof(context))
+            memset(features, 0, sizeof(features[0]) * nr_feat)
+            memset(scores, 0, sizeof(scores[0]) * nr_class)
+            memset(d_scores, 0, sizeof(d_scores[0]) * nr_class)
+            memset(costs, 0, sizeof(costs[0]) * nr_class)
+            for i in range(nr_class):
+                is_valid[i] = 1
+        #if itn % 100 == 0:
+        #    self.optimizer(self.model.model[0].ravel(),
+        #        self.model.model[1].ravel(), key=1)
         return loss

     def step_through(self, Doc doc):


@@ -1,15 +1,14 @@
-from thinc.linear.avgtron cimport AveragedPerceptron
-from thinc.extra.eg cimport Example
-from thinc.structs cimport ExampleC
-from thinc.linear.features cimport ConjunctionExtracter
-
 from .structs cimport TokenC
 from .vocab cimport Vocab
+from ._ml cimport LinearModel
+from thinc.structs cimport FeatureC
+from thinc.typedefs cimport atom_t


-cdef class TaggerModel:
-    cdef ConjunctionExtracter extracter
-    cdef object model
+cdef class TaggerModel(LinearModel):
+    cdef int set_featuresC(self, FeatureC* features, atom_t* context,
+                           const TokenC* tokens, int i) nogil


 cdef class Tagger:


@@ -16,9 +16,8 @@ from thinc.extra.eg cimport Example
 from thinc.structs cimport ExampleC
 from thinc.linear.avgtron cimport AveragedPerceptron
 from thinc.linalg cimport Vec, VecVec
-from thinc.linear.linear import LinearModel
 from thinc.structs cimport FeatureC
-from thinc.neural.optimizers import Adam
+from thinc.neural.optimizers import Adam, SGD
 from thinc.neural.ops import NumpyOps

 from .typedefs cimport attr_t
@@ -80,69 +79,16 @@ cpdef enum:
     N_CONTEXT_FIELDS


-cdef class TaggerModel:
-    def __init__(self, int nr_tag, templates):
-        self.extracter = ConjunctionExtracter(templates)
-        self.model = LinearModel(nr_tag)
-
-    def begin_update(self, atom_t[:, ::1] contexts, drop=0.):
-        cdef vector[uint64_t]* keys = new vector[uint64_t]()
-        cdef vector[float]* values = new vector[float]()
-        cdef vector[int64_t]* lengths = new vector[int64_t]()
-        features = new vector[FeatureC](self.extracter.nr_templ)
-        features.resize(self.extracter.nr_templ)
-        cdef FeatureC feat
-        cdef int i, j
-        for i in range(contexts.shape[0]):
-            nr_feat = self.extracter.set_features(features.data(), &contexts[i, 0])
-            for j in range(nr_feat):
-                keys.push_back(features.at(j).key)
-                values.push_back(features.at(j).value)
-            lengths.push_back(nr_feat)
-        cdef np.ndarray[uint64_t, ndim=1] py_keys
-        cdef np.ndarray[float, ndim=1] py_values
-        cdef np.ndarray[long, ndim=1] py_lengths
-        py_keys = vector_uint64_2numpy(keys)
-        py_values = vector_float_2numpy(values)
-        py_lengths = vector_long_2numpy(lengths)
-        instance = (py_keys, py_values, py_lengths)
-        del keys
-        del values
-        del lengths
-        del features
-        return self.model.begin_update(instance, drop=drop)
-
-    def end_training(self, *args, **kwargs):
-        pass
-
-    def dump(self, *args, **kwargs):
-        pass
-
-
-cdef np.ndarray[uint64_t, ndim=1] vector_uint64_2numpy(vector[uint64_t]* vec):
-    cdef np.ndarray[uint64_t, ndim=1, mode="c"] arr = np.zeros(vec.size(), dtype='uint64')
-    memcpy(arr.data, vec.data(), sizeof(uint64_t) * vec.size())
-    return arr
-
-
-cdef np.ndarray[long, ndim=1] vector_long_2numpy(vector[int64_t]* vec):
-    cdef np.ndarray[long, ndim=1, mode="c"] arr = np.zeros(vec.size(), dtype='int64')
-    memcpy(arr.data, vec.data(), sizeof(int64_t) * vec.size())
-    return arr
-
-
-cdef np.ndarray[float, ndim=1] vector_float_2numpy(vector[float]* vec):
-    cdef np.ndarray[float, ndim=1, mode="c"] arr = np.zeros(vec.size(), dtype='float32')
-    memcpy(arr.data, vec.data(), sizeof(float) * vec.size())
-    return arr
-
-
-cdef void fill_context(atom_t* context, const TokenC* tokens, int i) nogil:
-    _fill_from_token(&context[P2_orth], &tokens[i-2])
-    _fill_from_token(&context[P1_orth], &tokens[i-1])
-    _fill_from_token(&context[W_orth], &tokens[i])
-    _fill_from_token(&context[N1_orth], &tokens[i+1])
-    _fill_from_token(&context[N2_orth], &tokens[i+2])
+cdef class TaggerModel(LinearModel):
+    cdef int set_featuresC(self, FeatureC* features, atom_t* context,
+                           const TokenC* tokens, int i) nogil:
+        _fill_from_token(&context[P2_orth], &tokens[i-2])
+        _fill_from_token(&context[P1_orth], &tokens[i-1])
+        _fill_from_token(&context[W_orth], &tokens[i])
+        _fill_from_token(&context[N1_orth], &tokens[i+1])
+        _fill_from_token(&context[N2_orth], &tokens[i+2])
+        nr_feat = self.extracter.set_features(features, context)
+        return nr_feat


 cdef inline void _fill_from_token(atom_t* context, const TokenC* t) nogil:
@@ -213,8 +159,10 @@ cdef class Tagger:
             The newly constructed object.
         """
         if model is None:
+            print("Create tagger")
             model = TaggerModel(vocab.morphology.n_tags,
-                                cfg.get('features', self.feature_templates))
+                                cfg.get('features', self.feature_templates),
+                                learn_rate=0.01, size=2**18)
         self.vocab = vocab
         self.model = model
         # TODO: Move this to tag map
@@ -223,7 +171,7 @@ cdef class Tagger:
             self.freqs[TAG][self.vocab.strings[tag]] = 1
         self.freqs[TAG][0] = 1
         self.cfg = cfg
-        self.optimizer = Adam(NumpyOps(), 0.001)
+        self.optimizer = SGD(NumpyOps(), 0.001, momentum=0.9)

     @property
     def tag_names(self):
@@ -250,20 +198,22 @@ cdef class Tagger:
         if tokens.length == 0:
             return 0

-        cdef atom_t[1][N_CONTEXT_FIELDS] c_context
-        memset(c_context, 0, sizeof(c_context))
-        cdef atom_t[:, ::1] context = c_context
-        cdef float[:, ::1] scores
+        cdef atom_t[N_CONTEXT_FIELDS] context
         cdef int nr_class = self.vocab.morphology.n_tags
+        cdef Pool mem = Pool()
+        scores = <weight_t*>mem.alloc(nr_class, sizeof(weight_t))
+        features = <FeatureC*>mem.alloc(self.model.nr_feat, sizeof(FeatureC))
         for i in range(tokens.length):
             if tokens.c[i].pos == 0:
-                fill_context(&context[0, 0], tokens.c, i)
-                scores, _ = self.model.begin_update(context)
-                guess = Vec.arg_max(&scores[0, 0], nr_class)
+                nr_feat = self.model.set_featuresC(features, context, tokens.c, i)
+                self.model.set_scoresC(scores,
+                    features, nr_feat)
+                guess = Vec.arg_max(scores, nr_class)
                 self.vocab.morphology.assign_tag_id(&tokens.c[i], guess)
-                memset(&scores[0, 0], 0, sizeof(float) * scores.size)
+                memset(scores, 0, sizeof(weight_t) * nr_class)
+                memset(features, 0, sizeof(FeatureC) * nr_feat)
+                memset(context, 0, sizeof(N_CONTEXT_FIELDS))

         tokens.is_tagged = True
         tokens._py_tokens = [None] * tokens.length
@@ -295,7 +245,6 @@ cdef class Tagger:
         Returns (int):
             Number of tags correct.
         """
-        cdef int nr_class = self.vocab.morphology.n_tags
         gold_tag_strs = gold.tags
         assert len(tokens) == len(gold_tag_strs)
         for tag in gold_tag_strs:
@@ -303,27 +252,47 @@ cdef class Tagger:
                 msg = ("Unrecognized gold tag: %s. tag_map.json must contain all "
                        "gold tags, to maintain coarse-grained mapping.")
                 raise ValueError(msg % tag)
-        golds = [self.tag_names.index(g) if g is not None else -1 for g in gold_tag_strs]
+        cdef Pool mem = Pool()
+        golds = <int*>mem.alloc(sizeof(int), len(gold_tag_strs))
+        for i, g in enumerate(gold_tag_strs):
+            golds[i] = self.tag_names.index(g) if g is not None else -1
+        cdef atom_t[N_CONTEXT_FIELDS] context
+        cdef int nr_class = self.model.nr_class
+        costs = <weight_t*>mem.alloc(sizeof(weight_t), nr_class)
+        features = <FeatureC*>mem.alloc(sizeof(FeatureC), self.model.nr_feat)
+        scores = <weight_t*>mem.alloc(sizeof(weight_t), nr_class)
+        d_scores = <weight_t*>mem.alloc(sizeof(weight_t), nr_class)
         cdef int correct = 0
-        cdef atom_t[:, ::1] context = np.zeros((1, N_CONTEXT_FIELDS), dtype='uint64')
-        cdef float[:, ::1] scores
         for i in range(tokens.length):
-            fill_context(&context[0, 0], tokens.c, i)
-            scores, finish_update = self.model.begin_update(context)
-            guess = Vec.arg_max(&scores[0, 0], nr_class)
-            self.vocab.morphology.assign_tag_id(&tokens.c[i], guess)
+            nr_feat = self.model.set_featuresC(features, context, tokens.c, i)
+            self.model.set_scoresC(scores,
+                features, nr_feat)
             if golds[i] != -1:
-                scores[0, golds[i]] -= 1
-                finish_update(scores, lambda *args, **kwargs: None)
+                for j in range(nr_class):
+                    costs[j] = 1
+                costs[golds[i]] = 0
+                self.model.log_lossC(d_scores, scores, costs)
+                self.model.set_gradientC(d_scores, features, nr_feat)
+            guess = Vec.arg_max(scores, nr_class)
+            #print(tokens[i].text, golds[i], guess, [features[i].key for i in range(nr_feat)])
+            self.vocab.morphology.assign_tag_id(&tokens.c[i], guess)
-            if (golds[i] in (guess, -1)):
-                correct += 1
             self.freqs[TAG][tokens.c[i].tag] += 1
-        self.optimizer(self.model.model.weights, self.model.model.d_weights,
-                       key=self.model.model.id)
+            correct += costs[guess] == 0
+
+            memset(features, 0, sizeof(FeatureC) * nr_feat)
+            memset(costs, 0, sizeof(weight_t) * nr_class)
+            memset(scores, 0, sizeof(weight_t) * nr_class)
+            memset(d_scores, 0, sizeof(weight_t) * nr_class)
+        #if itn % 10 == 0:
+        #    self.optimizer(self.model.weights.ravel(), self.model.d_weights.ravel(),
+        #        key=1)
         tokens.is_tagged = True
         tokens._py_tokens = [None] * tokens.length
         return correct
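
The reworked Tagger.update builds a 0/1 cost vector over tags (0 for the gold tag, 1 for every other tag) and takes its gradient from LinearModel.log_lossC, which amounts to a softmax over all scores minus a softmax restricted to the zero-cost classes. A small NumPy sketch of that gradient, for reference only (names are illustrative; the 1e-10 smoothing terms in the Cython version are omitted):

import numpy as np

def log_loss_gradient(scores, costs):
    # d_scores = softmax(scores) - softmax over the zero-cost ("gold") classes,
    # matching the shape of LinearModel.log_lossC in spacy/_ml.pyx.
    scores = np.asarray(scores, dtype="float64")
    gold = np.asarray(costs) <= 0
    p_all = np.exp(scores - scores.max())
    p_all /= p_all.sum()
    p_gold = np.where(gold, np.exp(scores - scores[gold].max()), 0.0)
    p_gold /= p_gold.sum()
    return p_all - p_gold

# One gold tag (index 1, cost 0): its gradient entry is negative, the rest positive.
print(log_loss_gradient([2.0, 1.0, 0.5], [1, 0, 1]))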


@@ -14,6 +14,7 @@ class Trainer(object):
         self.nlp = nlp
         self.gold_tuples = gold_tuples
         self.nr_epoch = 0
+        self.nr_itn = 0

     def epochs(self, nr_epoch, augment_data=None, gold_preproc=False):
         cached_golds = {}
@@ -36,6 +37,7 @@ class Trainer(object):
                 golds = self.make_golds(docs, paragraph_tuples)
                 for doc, gold in zip(docs, golds):
                     yield doc, gold
+                    self.nr_itn += 1

         indices = list(range(len(self.gold_tuples)))
         for itn in range(nr_epoch):
@@ -46,7 +48,7 @@ class Trainer(object):
     def update(self, doc, gold):
         for process in self.nlp.pipeline:
             if hasattr(process, 'update'):
-                loss = process.update(doc, gold, itn=self.nr_epoch)
+                loss = process.update(doc, gold, itn=self.nr_itn)
             process(doc)
         return doc