mirror of
https://github.com/explosion/spaCy.git
synced 2025-06-29 09:23:12 +03:00
Fix pipeline/.
This commit is contained in:
parent
e7cf6c7d9b
commit
1ac29fd8df
|
@ -46,11 +46,18 @@ cdef struct EditTreeC:
|
||||||
bint is_match_node
|
bint is_match_node
|
||||||
NodeC inner
|
NodeC inner
|
||||||
|
|
||||||
cdef inline EditTreeC edittree_new_match(len_t prefix_len, len_t suffix_len,
|
cdef inline EditTreeC edittree_new_match(
|
||||||
uint32_t prefix_tree, uint32_t suffix_tree):
|
len_t prefix_len,
|
||||||
cdef MatchNodeC match_node = MatchNodeC(prefix_len=prefix_len,
|
len_t suffix_len,
|
||||||
suffix_len=suffix_len, prefix_tree=prefix_tree,
|
uint32_t prefix_tree,
|
||||||
suffix_tree=suffix_tree)
|
uint32_t suffix_tree
|
||||||
|
):
|
||||||
|
cdef MatchNodeC match_node = MatchNodeC(
|
||||||
|
prefix_len=prefix_len,
|
||||||
|
suffix_len=suffix_len,
|
||||||
|
prefix_tree=prefix_tree,
|
||||||
|
suffix_tree=suffix_tree
|
||||||
|
)
|
||||||
cdef NodeC inner = NodeC(match_node=match_node)
|
cdef NodeC inner = NodeC(match_node=match_node)
|
||||||
return EditTreeC(is_match_node=True, inner=inner)
|
return EditTreeC(is_match_node=True, inner=inner)
|
||||||
|
|
||||||
|
|
|
@ -5,8 +5,6 @@ from libc.string cimport memset
|
||||||
from libcpp.pair cimport pair
|
from libcpp.pair cimport pair
|
||||||
from libcpp.vector cimport vector
|
from libcpp.vector cimport vector
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from ...typedefs cimport hash_t
|
from ...typedefs cimport hash_t
|
||||||
|
|
||||||
from ... import util
|
from ... import util
|
||||||
|
@ -25,17 +23,16 @@ cdef LCS find_lcs(str source, str target):
|
||||||
target (str): The second string.
|
target (str): The second string.
|
||||||
RETURNS (LCS): The spans of the longest common subsequences.
|
RETURNS (LCS): The spans of the longest common subsequences.
|
||||||
"""
|
"""
|
||||||
cdef Py_ssize_t source_len = len(source)
|
|
||||||
cdef Py_ssize_t target_len = len(target)
|
cdef Py_ssize_t target_len = len(target)
|
||||||
cdef size_t longest_align = 0;
|
cdef size_t longest_align = 0
|
||||||
cdef int source_idx, target_idx
|
cdef int source_idx, target_idx
|
||||||
cdef LCS lcs
|
cdef LCS lcs
|
||||||
cdef Py_UCS4 source_cp, target_cp
|
cdef Py_UCS4 source_cp, target_cp
|
||||||
|
|
||||||
memset(&lcs, 0, sizeof(lcs))
|
memset(&lcs, 0, sizeof(lcs))
|
||||||
|
|
||||||
cdef vector[size_t] prev_aligns = vector[size_t](target_len);
|
cdef vector[size_t] prev_aligns = vector[size_t](target_len)
|
||||||
cdef vector[size_t] cur_aligns = vector[size_t](target_len);
|
cdef vector[size_t] cur_aligns = vector[size_t](target_len)
|
||||||
|
|
||||||
for (source_idx, source_cp) in enumerate(source):
|
for (source_idx, source_cp) in enumerate(source):
|
||||||
for (target_idx, target_cp) in enumerate(target):
|
for (target_idx, target_cp) in enumerate(target):
|
||||||
|
@ -89,7 +86,7 @@ cdef class EditTrees:
|
||||||
cdef LCS lcs = find_lcs(form, lemma)
|
cdef LCS lcs = find_lcs(form, lemma)
|
||||||
|
|
||||||
cdef EditTreeC tree
|
cdef EditTreeC tree
|
||||||
cdef uint32_t tree_id, prefix_tree, suffix_tree
|
cdef uint32_t prefix_tree, suffix_tree
|
||||||
if lcs_is_empty(lcs):
|
if lcs_is_empty(lcs):
|
||||||
tree = edittree_new_subst(self.strings.add(form), self.strings.add(lemma))
|
tree = edittree_new_subst(self.strings.add(form), self.strings.add(lemma))
|
||||||
else:
|
else:
|
||||||
|
@ -289,6 +286,7 @@ def _tree2dict(tree):
|
||||||
tree = tree["inner"]["subst_node"]
|
tree = tree["inner"]["subst_node"]
|
||||||
return(dict(tree))
|
return(dict(tree))
|
||||||
|
|
||||||
|
|
||||||
def _dict2tree(tree):
|
def _dict2tree(tree):
|
||||||
errors = validate_edit_tree(tree)
|
errors = validate_edit_tree(tree)
|
||||||
if errors:
|
if errors:
|
||||||
|
|
|
@ -1,17 +1,14 @@
|
||||||
# cython: infer_types=True
|
# cython: infer_types=True
|
||||||
# cython: profile=True
|
# cython: profile=True
|
||||||
cimport numpy as np
|
|
||||||
|
|
||||||
import numpy
|
import numpy
|
||||||
|
|
||||||
from cpython.ref cimport Py_XDECREF, PyObject
|
|
||||||
from thinc.extra.search cimport Beam
|
from thinc.extra.search cimport Beam
|
||||||
|
|
||||||
from thinc.extra.search import MaxViolation
|
from thinc.extra.search import MaxViolation
|
||||||
|
|
||||||
from thinc.extra.search cimport MaxViolation
|
from thinc.extra.search cimport MaxViolation
|
||||||
|
|
||||||
from ...typedefs cimport class_t, hash_t
|
from ...typedefs cimport class_t
|
||||||
from .transition_system cimport Transition, TransitionSystem
|
from .transition_system cimport Transition, TransitionSystem
|
||||||
|
|
||||||
from ...errors import Errors
|
from ...errors import Errors
|
||||||
|
@ -146,7 +143,6 @@ def update_beam(TransitionSystem moves, states, golds, model, int width, beam_de
|
||||||
cdef MaxViolation violn
|
cdef MaxViolation violn
|
||||||
pbeam = BeamBatch(moves, states, golds, width=width, density=beam_density)
|
pbeam = BeamBatch(moves, states, golds, width=width, density=beam_density)
|
||||||
gbeam = BeamBatch(moves, states, golds, width=width, density=0.0)
|
gbeam = BeamBatch(moves, states, golds, width=width, density=0.0)
|
||||||
cdef StateClass state
|
|
||||||
beam_maps = []
|
beam_maps = []
|
||||||
backprops = []
|
backprops = []
|
||||||
violns = [MaxViolation() for _ in range(len(states))]
|
violns = [MaxViolation() for _ in range(len(states))]
|
||||||
|
|
|
@ -277,7 +277,6 @@ cdef cppclass StateC:
|
||||||
|
|
||||||
return n
|
return n
|
||||||
|
|
||||||
|
|
||||||
int n_L(int head) nogil const:
|
int n_L(int head) nogil const:
|
||||||
return n_arcs(this._left_arcs, head)
|
return n_arcs(this._left_arcs, head)
|
||||||
|
|
||||||
|
|
|
@ -9,7 +9,7 @@ from ...strings cimport hash_string
|
||||||
from ...structs cimport TokenC
|
from ...structs cimport TokenC
|
||||||
from ...tokens.doc cimport Doc, set_children_from_heads
|
from ...tokens.doc cimport Doc, set_children_from_heads
|
||||||
from ...tokens.token cimport MISSING_DEP
|
from ...tokens.token cimport MISSING_DEP
|
||||||
from ...typedefs cimport attr_t, hash_t
|
from ...typedefs cimport attr_t
|
||||||
|
|
||||||
from ...training import split_bilu_label
|
from ...training import split_bilu_label
|
||||||
|
|
||||||
|
@ -68,8 +68,9 @@ cdef struct GoldParseStateC:
|
||||||
weight_t pop_cost
|
weight_t pop_cost
|
||||||
|
|
||||||
|
|
||||||
cdef GoldParseStateC create_gold_state(Pool mem, const StateC* state,
|
cdef GoldParseStateC create_gold_state(
|
||||||
heads, labels, sent_starts) except *:
|
Pool mem, const StateC* state, heads, labels, sent_starts
|
||||||
|
) except *:
|
||||||
cdef GoldParseStateC gs
|
cdef GoldParseStateC gs
|
||||||
gs.length = len(heads)
|
gs.length = len(heads)
|
||||||
gs.stride = 1
|
gs.stride = 1
|
||||||
|
@ -82,7 +83,7 @@ cdef GoldParseStateC create_gold_state(Pool mem, const StateC* state,
|
||||||
gs.n_kids_in_stack = <int32_t*>mem.alloc(gs.length, sizeof(gs.n_kids_in_stack[0]))
|
gs.n_kids_in_stack = <int32_t*>mem.alloc(gs.length, sizeof(gs.n_kids_in_stack[0]))
|
||||||
|
|
||||||
for i, is_sent_start in enumerate(sent_starts):
|
for i, is_sent_start in enumerate(sent_starts):
|
||||||
if is_sent_start == True:
|
if is_sent_start is True:
|
||||||
gs.state_bits[i] = set_state_flag(
|
gs.state_bits[i] = set_state_flag(
|
||||||
gs.state_bits[i],
|
gs.state_bits[i],
|
||||||
IS_SENT_START,
|
IS_SENT_START,
|
||||||
|
@ -210,6 +211,7 @@ cdef class ArcEagerGold:
|
||||||
def update(self, StateClass stcls):
|
def update(self, StateClass stcls):
|
||||||
update_gold_state(&self.c, stcls.c)
|
update_gold_state(&self.c, stcls.c)
|
||||||
|
|
||||||
|
|
||||||
def _get_aligned_sent_starts(example):
|
def _get_aligned_sent_starts(example):
|
||||||
"""Get list of SENT_START attributes aligned to the predicted tokenization.
|
"""Get list of SENT_START attributes aligned to the predicted tokenization.
|
||||||
If the reference has not sentence starts, return a list of None values.
|
If the reference has not sentence starts, return a list of None values.
|
||||||
|
@ -524,7 +526,6 @@ cdef class Break:
|
||||||
"""
|
"""
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef bint is_valid(const StateC* st, attr_t label) nogil:
|
cdef bint is_valid(const StateC* st, attr_t label) nogil:
|
||||||
cdef int i
|
|
||||||
if st.buffer_length() < 2:
|
if st.buffer_length() < 2:
|
||||||
return False
|
return False
|
||||||
elif st.B(1) != st.B(0) + 1:
|
elif st.B(1) != st.B(0) + 1:
|
||||||
|
@ -556,8 +557,8 @@ cdef class Break:
|
||||||
cost -= 1
|
cost -= 1
|
||||||
if gold.heads[si] == b0:
|
if gold.heads[si] == b0:
|
||||||
cost -= 1
|
cost -= 1
|
||||||
if not is_sent_start(gold, state.B(1)) \
|
if not is_sent_start(gold, state.B(1)) and\
|
||||||
and not is_sent_start_unknown(gold, state.B(1)):
|
not is_sent_start_unknown(gold, state.B(1)):
|
||||||
cost += 1
|
cost += 1
|
||||||
return cost
|
return cost
|
||||||
|
|
||||||
|
@ -803,7 +804,6 @@ cdef class ArcEager(TransitionSystem):
|
||||||
raise TypeError(Errors.E909.format(name="ArcEagerGold"))
|
raise TypeError(Errors.E909.format(name="ArcEagerGold"))
|
||||||
cdef ArcEagerGold gold_ = gold
|
cdef ArcEagerGold gold_ = gold
|
||||||
gold_state = gold_.c
|
gold_state = gold_.c
|
||||||
n_gold = 0
|
|
||||||
if self.c[i].is_valid(stcls.c, self.c[i].label):
|
if self.c[i].is_valid(stcls.c, self.c[i].label):
|
||||||
cost = self.c[i].get_cost(stcls.c, &gold_state, self.c[i].label)
|
cost = self.c[i].get_cost(stcls.c, &gold_state, self.c[i].label)
|
||||||
else:
|
else:
|
||||||
|
@ -875,7 +875,7 @@ cdef class ArcEager(TransitionSystem):
|
||||||
print("Gold")
|
print("Gold")
|
||||||
for token in example.y:
|
for token in example.y:
|
||||||
print(token.i, token.text, token.dep_, token.head.text)
|
print(token.i, token.text, token.dep_, token.head.text)
|
||||||
aligned_heads, aligned_labels = example.get_aligned_parse()
|
aligned_heads, _aligned_labels = example.get_aligned_parse()
|
||||||
print("Aligned heads")
|
print("Aligned heads")
|
||||||
for i, head in enumerate(aligned_heads):
|
for i, head in enumerate(aligned_heads):
|
||||||
print(example.x[i], example.x[head] if head is not None else "__")
|
print(example.x[i], example.x[head] if head is not None else "__")
|
||||||
|
|
|
@ -1,6 +1,3 @@
|
||||||
import os
|
|
||||||
import random
|
|
||||||
|
|
||||||
from cymem.cymem cimport Pool
|
from cymem.cymem cimport Pool
|
||||||
from libc.stdint cimport int32_t
|
from libc.stdint cimport int32_t
|
||||||
|
|
||||||
|
@ -14,7 +11,7 @@ from ...tokens.span import Span
|
||||||
|
|
||||||
from ...attrs cimport IS_SPACE
|
from ...attrs cimport IS_SPACE
|
||||||
from ...lexeme cimport Lexeme
|
from ...lexeme cimport Lexeme
|
||||||
from ...structs cimport SpanC, TokenC
|
from ...structs cimport SpanC
|
||||||
from ...tokens.span cimport Span
|
from ...tokens.span cimport Span
|
||||||
from ...typedefs cimport attr_t, weight_t
|
from ...typedefs cimport attr_t, weight_t
|
||||||
|
|
||||||
|
@ -145,7 +142,6 @@ cdef class BiluoPushDown(TransitionSystem):
|
||||||
for entity_type in kwargs.get('entity_types', []):
|
for entity_type in kwargs.get('entity_types', []):
|
||||||
for action in (BEGIN, IN, LAST, UNIT):
|
for action in (BEGIN, IN, LAST, UNIT):
|
||||||
actions[action][entity_type] = 1
|
actions[action][entity_type] = 1
|
||||||
moves = ('M', 'B', 'I', 'L', 'U')
|
|
||||||
for example in kwargs.get('examples', []):
|
for example in kwargs.get('examples', []):
|
||||||
for token in example.y:
|
for token in example.y:
|
||||||
ent_type = token.ent_type_
|
ent_type = token.ent_type_
|
||||||
|
@ -325,7 +321,6 @@ cdef class BiluoPushDown(TransitionSystem):
|
||||||
raise TypeError(Errors.E909.format(name="BiluoGold"))
|
raise TypeError(Errors.E909.format(name="BiluoGold"))
|
||||||
cdef BiluoGold gold_ = gold
|
cdef BiluoGold gold_ = gold
|
||||||
gold_state = gold_.c
|
gold_state = gold_.c
|
||||||
n_gold = 0
|
|
||||||
if self.c[i].is_valid(stcls.c, self.c[i].label):
|
if self.c[i].is_valid(stcls.c, self.c[i].label):
|
||||||
cost = self.c[i].get_cost(stcls.c, &gold_state, self.c[i].label)
|
cost = self.c[i].get_cost(stcls.c, &gold_state, self.c[i].label)
|
||||||
else:
|
else:
|
||||||
|
@ -486,10 +481,8 @@ cdef class In:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil:
|
cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil:
|
||||||
gold = <GoldNERStateC*>_gold
|
gold = <GoldNERStateC*>_gold
|
||||||
move = IN
|
|
||||||
cdef int next_act = gold.ner[s.B(1)].move if s.B(1) >= 0 else OUT
|
cdef int next_act = gold.ner[s.B(1)].move if s.B(1) >= 0 else OUT
|
||||||
cdef int g_act = gold.ner[s.B(0)].move
|
cdef int g_act = gold.ner[s.B(0)].move
|
||||||
cdef attr_t g_tag = gold.ner[s.B(0)].label
|
|
||||||
cdef bint is_sunk = _entity_is_sunk(s, gold.ner)
|
cdef bint is_sunk = _entity_is_sunk(s, gold.ner)
|
||||||
|
|
||||||
if g_act == MISSING:
|
if g_act == MISSING:
|
||||||
|
@ -549,12 +542,10 @@ cdef class Last:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil:
|
cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil:
|
||||||
gold = <GoldNERStateC*>_gold
|
gold = <GoldNERStateC*>_gold
|
||||||
move = LAST
|
|
||||||
b0 = s.B(0)
|
b0 = s.B(0)
|
||||||
ent_start = s.E(0)
|
ent_start = s.E(0)
|
||||||
|
|
||||||
cdef int g_act = gold.ner[b0].move
|
cdef int g_act = gold.ner[b0].move
|
||||||
cdef attr_t g_tag = gold.ner[b0].label
|
|
||||||
|
|
||||||
cdef int cost = 0
|
cdef int cost = 0
|
||||||
|
|
||||||
|
@ -652,7 +643,6 @@ cdef class Unit:
|
||||||
return cost
|
return cost
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
cdef class Out:
|
cdef class Out:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef bint is_valid(const StateC* st, attr_t label) nogil:
|
cdef bint is_valid(const StateC* st, attr_t label) nogil:
|
||||||
|
@ -675,7 +665,6 @@ cdef class Out:
|
||||||
cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil:
|
cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil:
|
||||||
gold = <GoldNERStateC*>_gold
|
gold = <GoldNERStateC*>_gold
|
||||||
cdef int g_act = gold.ner[s.B(0)].move
|
cdef int g_act = gold.ner[s.B(0)].move
|
||||||
cdef attr_t g_tag = gold.ner[s.B(0)].label
|
|
||||||
cdef weight_t cost = 0
|
cdef weight_t cost = 0
|
||||||
if g_act == MISSING:
|
if g_act == MISSING:
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -125,14 +125,17 @@ def decompose(label):
|
||||||
def is_decorated(label):
|
def is_decorated(label):
|
||||||
return DELIMITER in label
|
return DELIMITER in label
|
||||||
|
|
||||||
|
|
||||||
def count_decorated_labels(gold_data):
|
def count_decorated_labels(gold_data):
|
||||||
freqs = {}
|
freqs = {}
|
||||||
for example in gold_data:
|
for example in gold_data:
|
||||||
proj_heads, deco_deps = projectivize(example.get_aligned("HEAD"),
|
proj_heads, deco_deps = projectivize(example.get_aligned("HEAD"),
|
||||||
example.get_aligned("DEP"))
|
example.get_aligned("DEP"))
|
||||||
# set the label to ROOT for each root dependent
|
# set the label to ROOT for each root dependent
|
||||||
deco_deps = ['ROOT' if head == i else deco_deps[i]
|
deco_deps = [
|
||||||
for i, head in enumerate(proj_heads)]
|
'ROOT' if head == i else deco_deps[i]
|
||||||
|
for i, head in enumerate(proj_heads)
|
||||||
|
]
|
||||||
# count label frequencies
|
# count label frequencies
|
||||||
for label in deco_deps:
|
for label in deco_deps:
|
||||||
if is_decorated(label):
|
if is_decorated(label):
|
||||||
|
@ -160,9 +163,9 @@ def projectivize(heads, labels):
|
||||||
|
|
||||||
|
|
||||||
cdef vector[int] _heads_to_c(heads):
|
cdef vector[int] _heads_to_c(heads):
|
||||||
cdef vector[int] c_heads;
|
cdef vector[int] c_heads
|
||||||
for head in heads:
|
for head in heads:
|
||||||
if head == None:
|
if head is None:
|
||||||
c_heads.push_back(-1)
|
c_heads.push_back(-1)
|
||||||
else:
|
else:
|
||||||
assert head < len(heads)
|
assert head < len(heads)
|
||||||
|
@ -199,6 +202,7 @@ def _decorate(heads, proj_heads, labels):
|
||||||
deco_labels.append(labels[tokenid])
|
deco_labels.append(labels[tokenid])
|
||||||
return deco_labels
|
return deco_labels
|
||||||
|
|
||||||
|
|
||||||
def get_smallest_nonproj_arc_slow(heads):
|
def get_smallest_nonproj_arc_slow(heads):
|
||||||
cdef vector[int] c_heads = _heads_to_c(heads)
|
cdef vector[int] c_heads = _heads_to_c(heads)
|
||||||
return _get_smallest_nonproj_arc(c_heads)
|
return _get_smallest_nonproj_arc(c_heads)
|
||||||
|
|
|
@ -1,6 +1,4 @@
|
||||||
# cython: infer_types=True
|
# cython: infer_types=True
|
||||||
import numpy
|
|
||||||
|
|
||||||
from libcpp.vector cimport vector
|
from libcpp.vector cimport vector
|
||||||
|
|
||||||
from ...tokens.doc cimport Doc
|
from ...tokens.doc cimport Doc
|
||||||
|
|
|
@ -20,11 +20,15 @@ cdef struct Transition:
|
||||||
int (*do)(StateC* state, attr_t label) nogil
|
int (*do)(StateC* state, attr_t label) nogil
|
||||||
|
|
||||||
|
|
||||||
ctypedef weight_t (*get_cost_func_t)(const StateC* state, const void* gold,
|
ctypedef weight_t (*get_cost_func_t)(
|
||||||
attr_tlabel) nogil
|
const StateC* state, const void* gold, attr_tlabel
|
||||||
ctypedef weight_t (*move_cost_func_t)(const StateC* state, const void* gold) nogil
|
) nogil
|
||||||
ctypedef weight_t (*label_cost_func_t)(const StateC* state, const void*
|
ctypedef weight_t (*move_cost_func_t)(
|
||||||
gold, attr_t label) nogil
|
const StateC* state, const void* gold
|
||||||
|
) nogil
|
||||||
|
ctypedef weight_t (*label_cost_func_t)(
|
||||||
|
const StateC* state, const void* gold, attr_t label
|
||||||
|
) nogil
|
||||||
|
|
||||||
ctypedef int (*do_func_t)(StateC* state, attr_t label) nogil
|
ctypedef int (*do_func_t)(StateC* state, attr_t label) nogil
|
||||||
|
|
||||||
|
|
|
@ -8,9 +8,7 @@ from collections import Counter
|
||||||
import srsly
|
import srsly
|
||||||
|
|
||||||
from ...structs cimport TokenC
|
from ...structs cimport TokenC
|
||||||
from ...tokens.doc cimport Doc
|
|
||||||
from ...typedefs cimport attr_t, weight_t
|
from ...typedefs cimport attr_t, weight_t
|
||||||
from . cimport _beam_utils
|
|
||||||
from .stateclass cimport StateClass
|
from .stateclass cimport StateClass
|
||||||
|
|
||||||
from ... import util
|
from ... import util
|
||||||
|
@ -231,7 +229,6 @@ cdef class TransitionSystem:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def to_bytes(self, exclude=tuple()):
|
def to_bytes(self, exclude=tuple()):
|
||||||
transitions = []
|
|
||||||
serializers = {
|
serializers = {
|
||||||
'moves': lambda: srsly.json_dumps(self.labels),
|
'moves': lambda: srsly.json_dumps(self.labels),
|
||||||
'strings': lambda: self.strings.to_bytes(),
|
'strings': lambda: self.strings.to_bytes(),
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
# cython: infer_types=True, profile=True, binding=True
|
# cython: infer_types=True, profile=True, binding=True
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Callable, Iterable, Optional
|
from typing import Callable, Optional
|
||||||
|
|
||||||
from thinc.api import Config, Model
|
from thinc.api import Config, Model
|
||||||
|
|
||||||
|
@ -124,6 +124,7 @@ def make_parser(
|
||||||
scorer=scorer,
|
scorer=scorer,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@Language.factory(
|
@Language.factory(
|
||||||
"beam_parser",
|
"beam_parser",
|
||||||
assigns=["token.dep", "token.head", "token.is_sent_start", "doc.sents"],
|
assigns=["token.dep", "token.head", "token.is_sent_start", "doc.sents"],
|
||||||
|
|
|
@ -2,7 +2,6 @@
|
||||||
from itertools import islice
|
from itertools import islice
|
||||||
from typing import Callable, Dict, Optional, Union
|
from typing import Callable, Dict, Optional, Union
|
||||||
|
|
||||||
import srsly
|
|
||||||
from thinc.api import Config, Model, SequenceCategoricalCrossentropy
|
from thinc.api import Config, Model, SequenceCategoricalCrossentropy
|
||||||
|
|
||||||
from ..morphology cimport Morphology
|
from ..morphology cimport Morphology
|
||||||
|
@ -14,10 +13,8 @@ from ..errors import Errors
|
||||||
from ..language import Language
|
from ..language import Language
|
||||||
from ..parts_of_speech import IDS as POS_IDS
|
from ..parts_of_speech import IDS as POS_IDS
|
||||||
from ..scorer import Scorer
|
from ..scorer import Scorer
|
||||||
from ..symbols import POS
|
|
||||||
from ..training import validate_examples, validate_get_examples
|
from ..training import validate_examples, validate_get_examples
|
||||||
from ..util import registry
|
from ..util import registry
|
||||||
from .pipe import deserialize_config
|
|
||||||
from .tagger import Tagger
|
from .tagger import Tagger
|
||||||
|
|
||||||
# See #9050
|
# See #9050
|
||||||
|
@ -76,8 +73,11 @@ def morphologizer_score(examples, **kwargs):
|
||||||
results = {}
|
results = {}
|
||||||
results.update(Scorer.score_token_attr(examples, "pos", **kwargs))
|
results.update(Scorer.score_token_attr(examples, "pos", **kwargs))
|
||||||
results.update(Scorer.score_token_attr(examples, "morph", getter=morph_key_getter, **kwargs))
|
results.update(Scorer.score_token_attr(examples, "morph", getter=morph_key_getter, **kwargs))
|
||||||
results.update(Scorer.score_token_attr_per_feat(examples,
|
results.update(
|
||||||
"morph", getter=morph_key_getter, **kwargs))
|
Scorer.score_token_attr_per_feat(
|
||||||
|
examples, "morph", getter=morph_key_getter, **kwargs
|
||||||
|
)
|
||||||
|
)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
@ -233,7 +233,6 @@ class Morphologizer(Tagger):
|
||||||
if isinstance(docs, Doc):
|
if isinstance(docs, Doc):
|
||||||
docs = [docs]
|
docs = [docs]
|
||||||
cdef Doc doc
|
cdef Doc doc
|
||||||
cdef Vocab vocab = self.vocab
|
|
||||||
cdef bint overwrite = self.cfg["overwrite"]
|
cdef bint overwrite = self.cfg["overwrite"]
|
||||||
cdef bint extend = self.cfg["extend"]
|
cdef bint extend = self.cfg["extend"]
|
||||||
labels = self.labels
|
labels = self.labels
|
||||||
|
|
|
@ -4,13 +4,10 @@ from typing import Optional
|
||||||
import numpy
|
import numpy
|
||||||
from thinc.api import Config, CosineDistance, Model, set_dropout_rate, to_categorical
|
from thinc.api import Config, CosineDistance, Model, set_dropout_rate, to_categorical
|
||||||
|
|
||||||
from ..tokens.doc cimport Doc
|
from ..attrs import ID
|
||||||
|
|
||||||
from ..attrs import ID, POS
|
|
||||||
from ..errors import Errors
|
from ..errors import Errors
|
||||||
from ..language import Language
|
from ..language import Language
|
||||||
from ..training import validate_examples
|
from ..training import validate_examples
|
||||||
from ._parser_internals import nonproj
|
|
||||||
from .tagger import Tagger
|
from .tagger import Tagger
|
||||||
from .trainable_pipe import TrainablePipe
|
from .trainable_pipe import TrainablePipe
|
||||||
|
|
||||||
|
@ -103,10 +100,9 @@ class MultitaskObjective(Tagger):
|
||||||
cdef int idx = 0
|
cdef int idx = 0
|
||||||
correct = numpy.zeros((scores.shape[0],), dtype="i")
|
correct = numpy.zeros((scores.shape[0],), dtype="i")
|
||||||
guesses = scores.argmax(axis=1)
|
guesses = scores.argmax(axis=1)
|
||||||
docs = [eg.predicted for eg in examples]
|
|
||||||
for i, eg in enumerate(examples):
|
for i, eg in enumerate(examples):
|
||||||
# Handles alignment for tokenization differences
|
# Handles alignment for tokenization differences
|
||||||
doc_annots = eg.get_aligned() # TODO
|
_doc_annots = eg.get_aligned() # TODO
|
||||||
for j in range(len(eg.predicted)):
|
for j in range(len(eg.predicted)):
|
||||||
tok_annots = {key: values[j] for key, values in tok_annots.items()}
|
tok_annots = {key: values[j] for key, values in tok_annots.items()}
|
||||||
label = self.make_label(j, tok_annots)
|
label = self.make_label(j, tok_annots)
|
||||||
|
@ -206,7 +202,6 @@ class ClozeMultitask(TrainablePipe):
|
||||||
losses[self.name] = 0.
|
losses[self.name] = 0.
|
||||||
set_dropout_rate(self.model, drop)
|
set_dropout_rate(self.model, drop)
|
||||||
validate_examples(examples, "ClozeMultitask.rehearse")
|
validate_examples(examples, "ClozeMultitask.rehearse")
|
||||||
docs = [eg.predicted for eg in examples]
|
|
||||||
predictions, bp_predictions = self.model.begin_update()
|
predictions, bp_predictions = self.model.begin_update()
|
||||||
loss, d_predictions = self.get_loss(examples, self.vocab.vectors.data, predictions)
|
loss, d_predictions = self.get_loss(examples, self.vocab.vectors.data, predictions)
|
||||||
bp_predictions(d_predictions)
|
bp_predictions(d_predictions)
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
# cython: infer_types=True, profile=True, binding=True
|
# cython: infer_types=True, profile=True, binding=True
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Callable, Iterable, Optional
|
from typing import Callable, Optional
|
||||||
|
|
||||||
from thinc.api import Config, Model
|
from thinc.api import Config, Model
|
||||||
|
|
||||||
|
@ -10,7 +10,7 @@ from ._parser_internals.ner cimport BiluoPushDown
|
||||||
from .transition_parser cimport Parser
|
from .transition_parser cimport Parser
|
||||||
|
|
||||||
from ..language import Language
|
from ..language import Language
|
||||||
from ..scorer import PRFScore, get_ner_prf
|
from ..scorer import get_ner_prf
|
||||||
from ..training import remove_bilu_prefix
|
from ..training import remove_bilu_prefix
|
||||||
from ..util import registry
|
from ..util import registry
|
||||||
|
|
||||||
|
@ -100,6 +100,7 @@ def make_ner(
|
||||||
scorer=scorer,
|
scorer=scorer,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@Language.factory(
|
@Language.factory(
|
||||||
"beam_ner",
|
"beam_ner",
|
||||||
assigns=["doc.ents", "token.ent_iob", "token.ent_type"],
|
assigns=["doc.ents", "token.ent_iob", "token.ent_type"],
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
# cython: infer_types=True, profile=True, binding=True
|
# cython: infer_types=True, profile=True, binding=True
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Callable, Dict, Iterable, Iterator, Optional, Tuple, Union
|
from typing import Callable, Dict, Iterable, Iterator, Tuple, Union
|
||||||
|
|
||||||
import srsly
|
import srsly
|
||||||
|
|
||||||
|
|
|
@ -7,13 +7,13 @@ from ..tokens.doc cimport Doc
|
||||||
|
|
||||||
from .. import util
|
from .. import util
|
||||||
from ..language import Language
|
from ..language import Language
|
||||||
from ..scorer import Scorer
|
|
||||||
from .pipe import Pipe
|
from .pipe import Pipe
|
||||||
from .senter import senter_score
|
from .senter import senter_score
|
||||||
|
|
||||||
# see #9050
|
# see #9050
|
||||||
BACKWARD_OVERWRITE = False
|
BACKWARD_OVERWRITE = False
|
||||||
|
|
||||||
|
|
||||||
@Language.factory(
|
@Language.factory(
|
||||||
"sentencizer",
|
"sentencizer",
|
||||||
assigns=["token.is_sent_start", "doc.sents"],
|
assigns=["token.is_sent_start", "doc.sents"],
|
||||||
|
@ -36,7 +36,8 @@ class Sentencizer(Pipe):
|
||||||
DOCS: https://spacy.io/api/sentencizer
|
DOCS: https://spacy.io/api/sentencizer
|
||||||
"""
|
"""
|
||||||
|
|
||||||
default_punct_chars = ['!', '.', '?', '։', '؟', '۔', '܀', '܁', '܂', '߹',
|
default_punct_chars = [
|
||||||
|
'!', '.', '?', '։', '؟', '۔', '܀', '܁', '܂', '߹',
|
||||||
'।', '॥', '၊', '။', '።', '፧', '፨', '᙮', '᜵', '᜶', '᠃', '᠉', '᥄',
|
'।', '॥', '၊', '။', '።', '፧', '፨', '᙮', '᜵', '᜶', '᠃', '᠉', '᥄',
|
||||||
'᥅', '᪨', '᪩', '᪪', '᪫', '᭚', '᭛', '᭞', '᭟', '᰻', '᰼', '᱾', '᱿',
|
'᥅', '᪨', '᪩', '᪪', '᪫', '᭚', '᭛', '᭞', '᭟', '᰻', '᰼', '᱾', '᱿',
|
||||||
'‼', '‽', '⁇', '⁈', '⁉', '⸮', '⸼', '꓿', '꘎', '꘏', '꛳', '꛷', '꡶',
|
'‼', '‽', '⁇', '⁈', '⁉', '⸮', '⸼', '꓿', '꘎', '꘏', '꛳', '꛷', '꡶',
|
||||||
|
@ -46,7 +47,8 @@ class Sentencizer(Pipe):
|
||||||
'𑊩', '𑑋', '𑑌', '𑗂', '𑗃', '𑗉', '𑗊', '𑗋', '𑗌', '𑗍', '𑗎', '𑗏', '𑗐',
|
'𑊩', '𑑋', '𑑌', '𑗂', '𑗃', '𑗉', '𑗊', '𑗋', '𑗌', '𑗍', '𑗎', '𑗏', '𑗐',
|
||||||
'𑗑', '𑗒', '𑗓', '𑗔', '𑗕', '𑗖', '𑗗', '𑙁', '𑙂', '𑜼', '𑜽', '𑜾', '𑩂',
|
'𑗑', '𑗒', '𑗓', '𑗔', '𑗕', '𑗖', '𑗗', '𑙁', '𑙂', '𑜼', '𑜽', '𑜾', '𑩂',
|
||||||
'𑩃', '𑪛', '𑪜', '𑱁', '𑱂', '𖩮', '𖩯', '𖫵', '𖬷', '𖬸', '𖭄', '𛲟', '𝪈',
|
'𑩃', '𑪛', '𑪜', '𑱁', '𑱂', '𖩮', '𖩯', '𖫵', '𖬷', '𖬸', '𖭄', '𛲟', '𝪈',
|
||||||
'。', '。']
|
'。', '。'
|
||||||
|
]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -128,7 +130,6 @@ class Sentencizer(Pipe):
|
||||||
if isinstance(docs, Doc):
|
if isinstance(docs, Doc):
|
||||||
docs = [docs]
|
docs = [docs]
|
||||||
cdef Doc doc
|
cdef Doc doc
|
||||||
cdef int idx = 0
|
|
||||||
for i, doc in enumerate(docs):
|
for i, doc in enumerate(docs):
|
||||||
doc_tag_ids = batch_tag_ids[i]
|
doc_tag_ids = batch_tag_ids[i]
|
||||||
for j, tag_id in enumerate(doc_tag_ids):
|
for j, tag_id in enumerate(doc_tag_ids):
|
||||||
|
@ -169,7 +170,6 @@ class Sentencizer(Pipe):
|
||||||
path = path.with_suffix(".json")
|
path = path.with_suffix(".json")
|
||||||
srsly.write_json(path, {"punct_chars": list(self.punct_chars), "overwrite": self.overwrite})
|
srsly.write_json(path, {"punct_chars": list(self.punct_chars), "overwrite": self.overwrite})
|
||||||
|
|
||||||
|
|
||||||
def from_disk(self, path, *, exclude=tuple()):
|
def from_disk(self, path, *, exclude=tuple()):
|
||||||
"""Load the sentencizer from disk.
|
"""Load the sentencizer from disk.
|
||||||
|
|
||||||
|
|
|
@ -2,7 +2,6 @@
|
||||||
from itertools import islice
|
from itertools import islice
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional
|
||||||
|
|
||||||
import srsly
|
|
||||||
from thinc.api import Config, Model, SequenceCategoricalCrossentropy
|
from thinc.api import Config, Model, SequenceCategoricalCrossentropy
|
||||||
|
|
||||||
from ..tokens.doc cimport Doc
|
from ..tokens.doc cimport Doc
|
||||||
|
|
|
@ -1,26 +1,17 @@
|
||||||
# cython: infer_types=True, profile=True, binding=True
|
# cython: infer_types=True, profile=True, binding=True
|
||||||
import warnings
|
|
||||||
from itertools import islice
|
from itertools import islice
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional
|
||||||
|
|
||||||
import numpy
|
import numpy
|
||||||
import srsly
|
|
||||||
from thinc.api import Config, Model, SequenceCategoricalCrossentropy, set_dropout_rate
|
from thinc.api import Config, Model, SequenceCategoricalCrossentropy, set_dropout_rate
|
||||||
from thinc.types import Floats2d
|
|
||||||
|
|
||||||
from ..morphology cimport Morphology
|
|
||||||
from ..tokens.doc cimport Doc
|
from ..tokens.doc cimport Doc
|
||||||
from ..vocab cimport Vocab
|
|
||||||
|
|
||||||
from .. import util
|
from .. import util
|
||||||
from ..attrs import ID, POS
|
|
||||||
from ..errors import Errors, Warnings
|
|
||||||
from ..language import Language
|
from ..language import Language
|
||||||
from ..parts_of_speech import X
|
|
||||||
from ..scorer import Scorer
|
from ..scorer import Scorer
|
||||||
from ..training import validate_examples, validate_get_examples
|
from ..training import validate_examples, validate_get_examples
|
||||||
from ..util import registry
|
from ..util import registry
|
||||||
from .pipe import deserialize_config
|
|
||||||
from .trainable_pipe import TrainablePipe
|
from .trainable_pipe import TrainablePipe
|
||||||
|
|
||||||
# See #9050
|
# See #9050
|
||||||
|
@ -169,7 +160,6 @@ class Tagger(TrainablePipe):
|
||||||
if isinstance(docs, Doc):
|
if isinstance(docs, Doc):
|
||||||
docs = [docs]
|
docs = [docs]
|
||||||
cdef Doc doc
|
cdef Doc doc
|
||||||
cdef Vocab vocab = self.vocab
|
|
||||||
cdef bint overwrite = self.cfg["overwrite"]
|
cdef bint overwrite = self.cfg["overwrite"]
|
||||||
labels = self.labels
|
labels = self.labels
|
||||||
for i, doc in enumerate(docs):
|
for i, doc in enumerate(docs):
|
||||||
|
|
|
@ -13,8 +13,18 @@ cdef class Parser(TrainablePipe):
|
||||||
cdef readonly TransitionSystem moves
|
cdef readonly TransitionSystem moves
|
||||||
cdef public object _multitasks
|
cdef public object _multitasks
|
||||||
|
|
||||||
cdef void _parseC(self, CBlas cblas, StateC** states,
|
cdef void _parseC(
|
||||||
WeightsC weights, SizesC sizes) nogil
|
self,
|
||||||
|
CBlas cblas,
|
||||||
|
StateC** states,
|
||||||
|
WeightsC weights,
|
||||||
|
SizesC sizes
|
||||||
|
) nogil
|
||||||
|
|
||||||
cdef void c_transition_batch(self, StateC** states, const float* scores,
|
cdef void c_transition_batch(
|
||||||
int nr_class, int batch_size) nogil
|
self,
|
||||||
|
StateC** states,
|
||||||
|
const float* scores,
|
||||||
|
int nr_class,
|
||||||
|
int batch_size
|
||||||
|
) nogil
|
||||||
|
|
|
@ -7,17 +7,14 @@ from cymem.cymem cimport Pool
|
||||||
from itertools import islice
|
from itertools import islice
|
||||||
|
|
||||||
from libc.stdlib cimport calloc, free
|
from libc.stdlib cimport calloc, free
|
||||||
from libc.string cimport memcpy, memset
|
from libc.string cimport memset
|
||||||
from libcpp.vector cimport vector
|
from libcpp.vector cimport vector
|
||||||
|
|
||||||
import random
|
import random
|
||||||
|
|
||||||
import srsly
|
import srsly
|
||||||
from thinc.api import CupyOps, NumpyOps, get_ops, set_dropout_rate
|
from thinc.api import CupyOps, NumpyOps, set_dropout_rate
|
||||||
|
|
||||||
from thinc.extra.search cimport Beam
|
|
||||||
|
|
||||||
import warnings
|
|
||||||
|
|
||||||
import numpy
|
import numpy
|
||||||
import numpy.random
|
import numpy.random
|
||||||
|
@ -42,7 +39,7 @@ from .trainable_pipe import TrainablePipe
|
||||||
from ._parser_internals cimport _beam_utils
|
from ._parser_internals cimport _beam_utils
|
||||||
|
|
||||||
from .. import util
|
from .. import util
|
||||||
from ..errors import Errors, Warnings
|
from ..errors import Errors
|
||||||
from ..training import validate_examples, validate_get_examples
|
from ..training import validate_examples, validate_get_examples
|
||||||
from ._parser_internals import _beam_utils
|
from ._parser_internals import _beam_utils
|
||||||
|
|
||||||
|
@ -258,7 +255,6 @@ cdef class Parser(TrainablePipe):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_handler(self.name, self, batch_in_order, e)
|
error_handler(self.name, self, batch_in_order, e)
|
||||||
|
|
||||||
|
|
||||||
def predict(self, docs):
|
def predict(self, docs):
|
||||||
if isinstance(docs, Doc):
|
if isinstance(docs, Doc):
|
||||||
docs = [docs]
|
docs = [docs]
|
||||||
|
@ -300,8 +296,6 @@ cdef class Parser(TrainablePipe):
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
def beam_parse(self, docs, int beam_width, float drop=0., beam_density=0.):
|
def beam_parse(self, docs, int beam_width, float drop=0., beam_density=0.):
|
||||||
cdef Beam beam
|
|
||||||
cdef Doc doc
|
|
||||||
self._ensure_labels_are_added(docs)
|
self._ensure_labels_are_added(docs)
|
||||||
batch = _beam_utils.BeamBatch(
|
batch = _beam_utils.BeamBatch(
|
||||||
self.moves,
|
self.moves,
|
||||||
|
@ -321,16 +315,18 @@ cdef class Parser(TrainablePipe):
|
||||||
del model
|
del model
|
||||||
return list(batch)
|
return list(batch)
|
||||||
|
|
||||||
cdef void _parseC(self, CBlas cblas, StateC** states,
|
cdef void _parseC(
|
||||||
WeightsC weights, SizesC sizes) nogil:
|
self, CBlas cblas, StateC** states, WeightsC weights, SizesC sizes
|
||||||
cdef int i, j
|
) nogil:
|
||||||
|
cdef int i
|
||||||
cdef vector[StateC*] unfinished
|
cdef vector[StateC*] unfinished
|
||||||
cdef ActivationsC activations = alloc_activations(sizes)
|
cdef ActivationsC activations = alloc_activations(sizes)
|
||||||
while sizes.states >= 1:
|
while sizes.states >= 1:
|
||||||
predict_states(cblas, &activations, states, &weights, sizes)
|
predict_states(cblas, &activations, states, &weights, sizes)
|
||||||
# Validate actions, argmax, take action.
|
# Validate actions, argmax, take action.
|
||||||
self.c_transition_batch(states,
|
self.c_transition_batch(
|
||||||
activations.scores, sizes.classes, sizes.states)
|
states, activations.scores, sizes.classes, sizes.states
|
||||||
|
)
|
||||||
for i in range(sizes.states):
|
for i in range(sizes.states):
|
||||||
if not states[i].is_final():
|
if not states[i].is_final():
|
||||||
unfinished.push_back(states[i])
|
unfinished.push_back(states[i])
|
||||||
|
@ -342,7 +338,6 @@ cdef class Parser(TrainablePipe):
|
||||||
|
|
||||||
def set_annotations(self, docs, states_or_beams):
|
def set_annotations(self, docs, states_or_beams):
|
||||||
cdef StateClass state
|
cdef StateClass state
|
||||||
cdef Beam beam
|
|
||||||
cdef Doc doc
|
cdef Doc doc
|
||||||
states = _beam_utils.collect_states(states_or_beams, docs)
|
states = _beam_utils.collect_states(states_or_beams, docs)
|
||||||
for i, (state, doc) in enumerate(zip(states, docs)):
|
for i, (state, doc) in enumerate(zip(states, docs)):
|
||||||
|
@ -359,8 +354,13 @@ cdef class Parser(TrainablePipe):
|
||||||
self.c_transition_batch(&c_states[0], c_scores, scores.shape[1], scores.shape[0])
|
self.c_transition_batch(&c_states[0], c_scores, scores.shape[1], scores.shape[0])
|
||||||
return [state for state in states if not state.c.is_final()]
|
return [state for state in states if not state.c.is_final()]
|
||||||
|
|
||||||
cdef void c_transition_batch(self, StateC** states, const float* scores,
|
cdef void c_transition_batch(
|
||||||
int nr_class, int batch_size) nogil:
|
self,
|
||||||
|
StateC** states,
|
||||||
|
const float* scores,
|
||||||
|
int nr_class,
|
||||||
|
int batch_size
|
||||||
|
) nogil:
|
||||||
# n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc
|
# n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc
|
||||||
with gil:
|
with gil:
|
||||||
assert self.moves.n_moves > 0, Errors.E924.format(name=self.name)
|
assert self.moves.n_moves > 0, Errors.E924.format(name=self.name)
|
||||||
|
@ -380,7 +380,6 @@ cdef class Parser(TrainablePipe):
|
||||||
free(is_valid)
|
free(is_valid)
|
||||||
|
|
||||||
def update(self, examples, *, drop=0., sgd=None, losses=None):
|
def update(self, examples, *, drop=0., sgd=None, losses=None):
|
||||||
cdef StateClass state
|
|
||||||
if losses is None:
|
if losses is None:
|
||||||
losses = {}
|
losses = {}
|
||||||
losses.setdefault(self.name, 0.)
|
losses.setdefault(self.name, 0.)
|
||||||
|
@ -420,7 +419,6 @@ cdef class Parser(TrainablePipe):
|
||||||
return losses
|
return losses
|
||||||
model, backprop_tok2vec = self.model.begin_update([eg.x for eg in examples])
|
model, backprop_tok2vec = self.model.begin_update([eg.x for eg in examples])
|
||||||
|
|
||||||
all_states = list(states)
|
|
||||||
states_golds = list(zip(states, golds))
|
states_golds = list(zip(states, golds))
|
||||||
n_moves = 0
|
n_moves = 0
|
||||||
while states_golds:
|
while states_golds:
|
||||||
|
@ -500,8 +498,16 @@ cdef class Parser(TrainablePipe):
|
||||||
del tutor
|
del tutor
|
||||||
return losses
|
return losses
|
||||||
|
|
||||||
def update_beam(self, examples, *, beam_width,
|
def update_beam(
|
||||||
drop=0., sgd=None, losses=None, beam_density=0.0):
|
self,
|
||||||
|
examples,
|
||||||
|
*,
|
||||||
|
beam_width,
|
||||||
|
drop=0.,
|
||||||
|
sgd=None,
|
||||||
|
losses=None,
|
||||||
|
beam_density=0.0
|
||||||
|
):
|
||||||
states, golds, _ = self.moves.init_gold_batch(examples)
|
states, golds, _ = self.moves.init_gold_batch(examples)
|
||||||
if not states:
|
if not states:
|
||||||
return losses
|
return losses
|
||||||
|
@ -531,8 +537,9 @@ cdef class Parser(TrainablePipe):
|
||||||
|
|
||||||
is_valid = <int*>mem.alloc(self.moves.n_moves, sizeof(int))
|
is_valid = <int*>mem.alloc(self.moves.n_moves, sizeof(int))
|
||||||
costs = <float*>mem.alloc(self.moves.n_moves, sizeof(float))
|
costs = <float*>mem.alloc(self.moves.n_moves, sizeof(float))
|
||||||
cdef np.ndarray d_scores = numpy.zeros((len(states), self.moves.n_moves),
|
cdef np.ndarray d_scores = numpy.zeros(
|
||||||
dtype='f', order='C')
|
(len(states), self.moves.n_moves), dtype='f', order='C'
|
||||||
|
)
|
||||||
c_d_scores = <float*>d_scores.data
|
c_d_scores = <float*>d_scores.data
|
||||||
unseen_classes = self.model.attrs["unseen_classes"]
|
unseen_classes = self.model.attrs["unseen_classes"]
|
||||||
for i, (state, gold) in enumerate(zip(states, golds)):
|
for i, (state, gold) in enumerate(zip(states, golds)):
|
||||||
|
@ -542,8 +549,9 @@ cdef class Parser(TrainablePipe):
|
||||||
for j in range(self.moves.n_moves):
|
for j in range(self.moves.n_moves):
|
||||||
if costs[j] <= 0.0 and j in unseen_classes:
|
if costs[j] <= 0.0 and j in unseen_classes:
|
||||||
unseen_classes.remove(j)
|
unseen_classes.remove(j)
|
||||||
cpu_log_loss(c_d_scores,
|
cpu_log_loss(
|
||||||
costs, is_valid, &scores[i, 0], d_scores.shape[1])
|
c_d_scores, costs, is_valid, &scores[i, 0], d_scores.shape[1]
|
||||||
|
)
|
||||||
c_d_scores += d_scores.shape[1]
|
c_d_scores += d_scores.shape[1]
|
||||||
# Note that we don't normalize this. See comment in update() for why.
|
# Note that we don't normalize this. See comment in update() for why.
|
||||||
if losses is not None:
|
if losses is not None:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user