# cython: profile=True
"""
MALT-style dependency parser

Parses greedily: starting from an initial StateClass, it repeatedly scores
the valid transitions of a TransitionSystem with a model and applies the
predicted one until the state is final.
"""
from __future__ import unicode_literals
cimport cython
from cpython.ref cimport PyObject, Py_INCREF, Py_XDECREF
from libc.stdint cimport uint32_t, uint64_t
from libc.string cimport memset, memcpy
# NOTE(review): several imports (e.g. random, shutil, json, path, memcpy,
# hash64) are not used by the live code visible in this file; they may be
# leftovers from the disabled beam-search code below.
import random
import os.path
from os import path
import shutil
import json

from cymem.cymem cimport Pool, Address
from murmurhash.mrmr cimport hash64
from thinc.typedefs cimport weight_t, class_t, feat_t, atom_t, hash_t

# NOTE(review): implicit relative import (Python 2 semantics). If `util` is a
# sibling module of this file, Python 3 semantics would need
# `from .util import Config` -- confirm before changing.
from util import Config

from thinc.api cimport Example

from ..tokens cimport Tokens, TokenC
from ..strings cimport StringStore

from .transition_system import OracleError
from .transition_system cimport TransitionSystem, Transition

from ..gold cimport GoldParse

from . import _parse_features
from ._parse_features cimport CONTEXT_SIZE
from ._parse_features cimport fill_context
from .stateclass cimport StateClass


# Module-level debug switch, changed via set_debug(). It is not read by the
# live code visible in this file.
DEBUG = False
def set_debug(val):
    """Set the module-level DEBUG flag to `val`.

    Uses `global`, so the new value is seen by every user of this module.
    """
    global DEBUG
    DEBUG = val


def get_templates(name):
    """Return the feature templates selected by `name`.

    `name` is the `features` setting of the model config (Parser.__init__
    passes cfg.features). Mapping:

    - 'ner'                  -> _parse_features.ner
    - 'debug'                -> _parse_features.unigrams
    - starts with 'embed'    -> the tuple (words, tags, labels)
    - anything else          -> unigrams + s0_n0 + s1_n0 + s1_s0 + s0_n1
                                + n0_n1 + tree_shape + trigrams, joined
                                with `+` (presumably sequences of
                                templates -- TODO confirm)
    """
    pf = _parse_features
    if name == 'ner':
        return pf.ner
    elif name == 'debug':
        return pf.unigrams
    elif name.startswith('embed'):
        return (pf.words, pf.tags, pf.labels)
    else:
        return (pf.unigrams + pf.s0_n0 + pf.s1_n0 + pf.s1_s0 + pf.s0_n1 + pf.n0_n1 + \
                pf.tree_shape + pf.trigrams)


cdef class Parser:
    """Greedy transition-based dependency parser.

    Holds `cfg` (the Config read from the model directory), `moves` (the
    transition system instance) and `model` (used via predict()/train() on
    an Example). The attribute declarations are not in this file --
    presumably they live in the companion .pxd; confirm.
    """
    def __init__(self, StringStore strings, model_dir, transition_system,
                 get_model=Model):
        """Read the config and build the transition system and the model.

        strings: StringStore passed to `transition_system`.
        model_dir: directory whose 'config' is read with Config.read; it is
            also passed on to `get_model`.
        transition_system: callable invoked as
            transition_system(strings, cfg.labels); the result must provide
            n_moves and the methods used by __call__ and train.
        get_model: factory invoked as
            get_model(moves.n_moves, templates, model_dir), where templates
            come from get_templates(cfg.features).
            NOTE(review): `Model` is not imported in this file; presumably it
            is cimported via the companion .pxd -- confirm.

        Raises AssertionError if model_dir is not an existing directory.
        This is an `assert`, so the check is skipped when assertions are
        disabled.
        """
        # os.path.isdir() is already False for a missing path; the exists()
        # call is redundant but harmless.
        assert os.path.exists(model_dir) and os.path.isdir(model_dir)
        self.cfg = Config.read(model_dir, 'config')
        self.moves = transition_system(strings, self.cfg.labels)
        templates = get_templates(self.cfg.features)
        self.model = get_model(self.moves.n_moves, templates, model_dir)

    def __call__(self, Tokens tokens):
        """Parse `tokens` greedily and write the result back into `tokens`.

        Until the state is final, each step:
        - marks the valid moves,
        - fills the feature context from the current state,
        - lets the model predict a move (eg.c.guess),
        - applies that move.

        The state is then finalized and its `_sent` is stored with
        tokens.set_parse(). Returns None.
        """
        cdef StateClass stcls = StateClass.init(tokens.data, tokens.length)
        self.moves.initialize_state(stcls)

        # One Example is allocated up front and reused for every step.
        cdef Example eg = Example(self.model.n_classes, CONTEXT_SIZE,
                                  self.model.n_feats, self.model.n_feats)
        while not stcls.is_final():
            # Clear the previous step's scores, since `eg` is reused.
            memset(eg.c.scores, 0, eg.c.nr_class * sizeof(weight_t))

            self.moves.set_valid(eg.c.is_valid, stcls)
            fill_context(eg.c.atoms, stcls)

            self.model.predict(eg)

            # Apply the predicted transition with its own label.
            self.moves.c[eg.c.guess].do(stcls, self.moves.c[eg.c.guess].label)
        self.moves.finalize_state(stcls)
        tokens.set_parse(stcls._sent)

    def train(self, Tokens tokens, GoldParse gold):
        """Run one greedy training pass over `tokens` against `gold`.

        The transition system first preprocesses `gold`. Then, until the
        state is final, each step:
        - sets move validity and costs from `gold`,
        - fills the feature context,
        - calls model.train(eg),
        - applies the move the model guessed (eg.c.guess), not a gold move.

        Returns the sum of eg.c.cost over all steps.
        """
        self.moves.preprocess_gold(gold)
        cdef StateClass stcls = StateClass.init(tokens.data, tokens.length)
        self.moves.initialize_state(stcls)
        # One Example is allocated up front and reused for every step.
        cdef Example eg = Example(self.model.n_classes, CONTEXT_SIZE,
                                  self.model.n_feats, self.model.n_feats)
        cdef int cost = 0
        while not stcls.is_final():
            # Clear the previous step's scores, since `eg` is reused.
            memset(eg.c.scores, 0, eg.c.nr_class * sizeof(weight_t))

            self.moves.set_costs(eg.c.is_valid, eg.c.costs, stcls, gold)

            fill_context(eg.c.atoms, stcls)

            self.model.train(eg)

            # Follow the model's own guess rather than a gold move. The gold
            # signal presumably comes from the costs set above -- confirm.
            self.moves.c[eg.c.guess].do(stcls, self.moves.c[eg.c.guess].label)
            cost += eg.c.cost
        return cost


# These are passed as callbacks to thinc.search.Beam
# NOTE(review): everything in the string literal below is disabled beam-search
# code. It is never compiled or executed.
# - It references names not imported in this file (Beam, MaxViolation,
#   count_feats) and reads cfg.beam_width.
# - Several pointer casts appear to have been lost: e.g. `dest = _dest` and
#   `(_state).is_final()` on void* values would need `<StateClass>` casts.
# - The functions taking `self` were presumably Parser methods.
# Restore all of this before re-enabling the code.
"""
cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1:
    dest = _dest
    src = _src
    moves = _moves
    dest.clone(src)
    moves[clas].do(dest, moves[clas].label)


cdef void* _init_state(Pool mem, int length, void* tokens) except NULL:
    cdef StateClass st = StateClass.init(tokens, length)
    st.fast_forward()
    Py_INCREF(st)
    return st


cdef int _check_final_state(void* _state, void* extra_args) except -1:
    return (_state).is_final()


def _cleanup(Beam beam):
    for i in range(beam.width):
        Py_XDECREF(beam._states[i].content)
        Py_XDECREF(beam._parents[i].content)

cdef hash_t _hash_state(void* _state, void* _) except 0:
    return _state

    #state = _state
    #cdef atom_t[10] rep

    #rep[0] = state.stack[0] if state.stack_len >= 1 else 0
    #rep[1] = state.stack[-1] if state.stack_len >= 2 else 0
    #rep[2] = state.stack[-2] if state.stack_len >= 3 else 0
    #rep[3] = state.i
    #rep[4] = state.sent[state.stack[0]].l_kids if state.stack_len >= 1 else 0
    #rep[5] = state.sent[state.stack[0]].r_kids if state.stack_len >= 1 else 0
    #rep[6] = state.sent[state.stack[0]].dep if state.stack_len >= 1 else 0
    #rep[7] = state.sent[state.stack[-1]].dep if state.stack_len >= 2 else 0
    #if get_left(state, get_n0(state), 1) != NULL:
    #    rep[8] = get_left(state, get_n0(state), 1).dep
    #else:
    #    rep[8] = 0
    #rep[9] = state.sent[state.i].l_kids
    #return hash64(rep, sizeof(atom_t) * 10, 0)


cdef int _beam_parse(self, Tokens tokens) except -1:
    cdef Beam beam = Beam(self.moves.n_moves, self.cfg.beam_width)
    words = [w.orth_ for w in tokens]
    beam.initialize(_init_state, tokens.length, tokens.data)
    beam.check_done(_check_final_state, NULL)
    while not beam.is_done:
        self._advance_beam(beam, None, False, words)
    state = beam.at(0)
    self.moves.finalize_state(state)
    tokens.set_parse(state._sent)
    _cleanup(beam)

def _beam_train(self, Tokens tokens, GoldParse gold_parse):
    cdef Beam pred = Beam(self.moves.n_moves, self.cfg.beam_width)
    pred.initialize(_init_state, tokens.length, tokens.data)
    pred.check_done(_check_final_state, NULL)

    cdef Beam gold = Beam(self.moves.n_moves, self.cfg.beam_width)
    gold.initialize(_init_state, tokens.length, tokens.data)
    gold.check_done(_check_final_state, NULL)
    violn = MaxViolation()
    words = [w.orth_ for w in tokens]
    while not pred.is_done and not gold.is_done:
        self._advance_beam(pred, gold_parse, False, words)
        self._advance_beam(gold, gold_parse, True, words)
        violn.check(pred, gold)
    if pred.loss >= 1:
        counts = {clas: {} for clas in range(self.model.n_classes)}
        self._count_feats(counts, tokens, violn.g_hist, 1)
        self._count_feats(counts, tokens, violn.p_hist, -1)
    else:
        counts = {}
    self.model._model.update(counts)
    _cleanup(pred)
    _cleanup(gold)
    return pred.loss

def _advance_beam(self, Beam beam, GoldParse gold, bint follow_gold, words):
    cdef atom_t[CONTEXT_SIZE] context
    cdef int i, j, cost
    cdef bint is_valid
    cdef const Transition* move
    for i in range(beam.size):
        stcls = beam.at(i)
        if not stcls.is_final():
            fill_context(context, stcls)
            self.model.set_scores(beam.scores[i], context)
            self.moves.set_valid(beam.is_valid[i], stcls)
    if gold is not None:
        for i in range(beam.size):
            stcls = beam.at(i)
            if not stcls.is_final():
                self.moves.set_costs(beam.costs[i], stcls, gold)
                if follow_gold:
                    for j in range(self.moves.n_moves):
                        beam.is_valid[i][j] *= beam.costs[i][j] == 0
    beam.advance(_transition_state, _hash_state, self.moves.c)
    beam.check_done(_check_final_state, NULL)

def _count_feats(self, dict counts, Tokens tokens, list hist, int inc):
    cdef atom_t[CONTEXT_SIZE] context
    cdef Pool mem = Pool()
    cdef StateClass stcls = StateClass.init(tokens.data, tokens.length)
    self.moves.initialize_state(stcls)

    cdef class_t clas
    cdef int n_feats
    for clas in hist:
        fill_context(context, stcls)
        feats = self.model._extractor.get_feats(context, &n_feats)
        count_feats(counts[clas], feats, n_feats, inc)
        self.moves.c[clas].do(stcls, self.moves.c[clas].label)
"""