Fix pipeline/.

This commit is contained in:
Raphael Mitsch 2023-07-03 14:50:11 +02:00
parent e7cf6c7d9b
commit 1ac29fd8df
21 changed files with 143 additions and 148 deletions

View File

@ -46,11 +46,18 @@ cdef struct EditTreeC:
bint is_match_node
NodeC inner
cdef inline EditTreeC edittree_new_match(len_t prefix_len, len_t suffix_len,
uint32_t prefix_tree, uint32_t suffix_tree):
cdef MatchNodeC match_node = MatchNodeC(prefix_len=prefix_len,
suffix_len=suffix_len, prefix_tree=prefix_tree,
suffix_tree=suffix_tree)
cdef inline EditTreeC edittree_new_match(
len_t prefix_len,
len_t suffix_len,
uint32_t prefix_tree,
uint32_t suffix_tree
):
cdef MatchNodeC match_node = MatchNodeC(
prefix_len=prefix_len,
suffix_len=suffix_len,
prefix_tree=prefix_tree,
suffix_tree=suffix_tree
)
cdef NodeC inner = NodeC(match_node=match_node)
return EditTreeC(is_match_node=True, inner=inner)

View File

@ -5,8 +5,6 @@ from libc.string cimport memset
from libcpp.pair cimport pair
from libcpp.vector cimport vector
from pathlib import Path
from ...typedefs cimport hash_t
from ... import util
@ -25,17 +23,16 @@ cdef LCS find_lcs(str source, str target):
target (str): The second string.
RETURNS (LCS): The spans of the longest common subsequences.
"""
cdef Py_ssize_t source_len = len(source)
cdef Py_ssize_t target_len = len(target)
cdef size_t longest_align = 0;
cdef size_t longest_align = 0
cdef int source_idx, target_idx
cdef LCS lcs
cdef Py_UCS4 source_cp, target_cp
memset(&lcs, 0, sizeof(lcs))
cdef vector[size_t] prev_aligns = vector[size_t](target_len);
cdef vector[size_t] cur_aligns = vector[size_t](target_len);
cdef vector[size_t] prev_aligns = vector[size_t](target_len)
cdef vector[size_t] cur_aligns = vector[size_t](target_len)
for (source_idx, source_cp) in enumerate(source):
for (target_idx, target_cp) in enumerate(target):
@ -89,7 +86,7 @@ cdef class EditTrees:
cdef LCS lcs = find_lcs(form, lemma)
cdef EditTreeC tree
cdef uint32_t tree_id, prefix_tree, suffix_tree
cdef uint32_t prefix_tree, suffix_tree
if lcs_is_empty(lcs):
tree = edittree_new_subst(self.strings.add(form), self.strings.add(lemma))
else:
@ -108,7 +105,7 @@ cdef class EditTrees:
return self._tree_id(tree)
cdef uint32_t _tree_id(self, EditTreeC tree):
# If this tree has been constructed before, return its identifier.
# If this tree has been constructed before, return its identifier.
cdef hash_t hash = edittree_hash(tree)
cdef unordered_map[hash_t, uint32_t].iterator iter = self.map.find(hash)
if iter != self.map.end():
@ -289,6 +286,7 @@ def _tree2dict(tree):
tree = tree["inner"]["subst_node"]
return(dict(tree))
def _dict2tree(tree):
errors = validate_edit_tree(tree)
if errors:

View File

@ -1,17 +1,14 @@
# cython: infer_types=True
# cython: profile=True
cimport numpy as np
import numpy
from cpython.ref cimport Py_XDECREF, PyObject
from thinc.extra.search cimport Beam
from thinc.extra.search import MaxViolation
from thinc.extra.search cimport MaxViolation
from ...typedefs cimport class_t, hash_t
from ...typedefs cimport class_t
from .transition_system cimport Transition, TransitionSystem
from ...errors import Errors
@ -146,7 +143,6 @@ def update_beam(TransitionSystem moves, states, golds, model, int width, beam_de
cdef MaxViolation violn
pbeam = BeamBatch(moves, states, golds, width=width, density=beam_density)
gbeam = BeamBatch(moves, states, golds, width=width, density=0.0)
cdef StateClass state
beam_maps = []
backprops = []
violns = [MaxViolation() for _ in range(len(states))]

View File

@ -277,7 +277,6 @@ cdef cppclass StateC:
return n
int n_L(int head) nogil const:
return n_arcs(this._left_arcs, head)

View File

@ -9,7 +9,7 @@ from ...strings cimport hash_string
from ...structs cimport TokenC
from ...tokens.doc cimport Doc, set_children_from_heads
from ...tokens.token cimport MISSING_DEP
from ...typedefs cimport attr_t, hash_t
from ...typedefs cimport attr_t
from ...training import split_bilu_label
@ -68,8 +68,9 @@ cdef struct GoldParseStateC:
weight_t pop_cost
cdef GoldParseStateC create_gold_state(Pool mem, const StateC* state,
heads, labels, sent_starts) except *:
cdef GoldParseStateC create_gold_state(
Pool mem, const StateC* state, heads, labels, sent_starts
) except *:
cdef GoldParseStateC gs
gs.length = len(heads)
gs.stride = 1
@ -82,7 +83,7 @@ cdef GoldParseStateC create_gold_state(Pool mem, const StateC* state,
gs.n_kids_in_stack = <int32_t*>mem.alloc(gs.length, sizeof(gs.n_kids_in_stack[0]))
for i, is_sent_start in enumerate(sent_starts):
if is_sent_start == True:
if is_sent_start is True:
gs.state_bits[i] = set_state_flag(
gs.state_bits[i],
IS_SENT_START,
@ -210,6 +211,7 @@ cdef class ArcEagerGold:
def update(self, StateClass stcls):
update_gold_state(&self.c, stcls.c)
def _get_aligned_sent_starts(example):
"""Get list of SENT_START attributes aligned to the predicted tokenization.
If the reference has not sentence starts, return a list of None values.
@ -524,7 +526,6 @@ cdef class Break:
"""
@staticmethod
cdef bint is_valid(const StateC* st, attr_t label) nogil:
cdef int i
if st.buffer_length() < 2:
return False
elif st.B(1) != st.B(0) + 1:
@ -556,8 +557,8 @@ cdef class Break:
cost -= 1
if gold.heads[si] == b0:
cost -= 1
if not is_sent_start(gold, state.B(1)) \
and not is_sent_start_unknown(gold, state.B(1)):
if not is_sent_start(gold, state.B(1)) and\
not is_sent_start_unknown(gold, state.B(1)):
cost += 1
return cost
@ -803,7 +804,6 @@ cdef class ArcEager(TransitionSystem):
raise TypeError(Errors.E909.format(name="ArcEagerGold"))
cdef ArcEagerGold gold_ = gold
gold_state = gold_.c
n_gold = 0
if self.c[i].is_valid(stcls.c, self.c[i].label):
cost = self.c[i].get_cost(stcls.c, &gold_state, self.c[i].label)
else:
@ -875,7 +875,7 @@ cdef class ArcEager(TransitionSystem):
print("Gold")
for token in example.y:
print(token.i, token.text, token.dep_, token.head.text)
aligned_heads, aligned_labels = example.get_aligned_parse()
aligned_heads, _aligned_labels = example.get_aligned_parse()
print("Aligned heads")
for i, head in enumerate(aligned_heads):
print(example.x[i], example.x[head] if head is not None else "__")

View File

@ -1,6 +1,3 @@
import os
import random
from cymem.cymem cimport Pool
from libc.stdint cimport int32_t
@ -14,7 +11,7 @@ from ...tokens.span import Span
from ...attrs cimport IS_SPACE
from ...lexeme cimport Lexeme
from ...structs cimport SpanC, TokenC
from ...structs cimport SpanC
from ...tokens.span cimport Span
from ...typedefs cimport attr_t, weight_t
@ -141,11 +138,10 @@ cdef class BiluoPushDown(TransitionSystem):
OUT: Counter()
}
actions[OUT][''] = 1 # Represents a token predicted to be outside of any entity
actions[UNIT][''] = 1 # Represents a token prohibited to be in an entity
actions[UNIT][''] = 1 # Represents a token prohibited to be in an entity
for entity_type in kwargs.get('entity_types', []):
for action in (BEGIN, IN, LAST, UNIT):
actions[action][entity_type] = 1
moves = ('M', 'B', 'I', 'L', 'U')
for example in kwargs.get('examples', []):
for token in example.y:
ent_type = token.ent_type_
@ -164,7 +160,7 @@ cdef class BiluoPushDown(TransitionSystem):
if token.ent_type:
labels.add(token.ent_type_)
return labels
def move_name(self, int move, attr_t label):
if move == OUT:
return 'O'
@ -325,7 +321,6 @@ cdef class BiluoPushDown(TransitionSystem):
raise TypeError(Errors.E909.format(name="BiluoGold"))
cdef BiluoGold gold_ = gold
gold_state = gold_.c
n_gold = 0
if self.c[i].is_valid(stcls.c, self.c[i].label):
cost = self.c[i].get_cost(stcls.c, &gold_state, self.c[i].label)
else:
@ -486,10 +481,8 @@ cdef class In:
@staticmethod
cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil:
gold = <GoldNERStateC*>_gold
move = IN
cdef int next_act = gold.ner[s.B(1)].move if s.B(1) >= 0 else OUT
cdef int g_act = gold.ner[s.B(0)].move
cdef attr_t g_tag = gold.ner[s.B(0)].label
cdef bint is_sunk = _entity_is_sunk(s, gold.ner)
if g_act == MISSING:
@ -549,12 +542,10 @@ cdef class Last:
@staticmethod
cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil:
gold = <GoldNERStateC*>_gold
move = LAST
b0 = s.B(0)
ent_start = s.E(0)
cdef int g_act = gold.ner[b0].move
cdef attr_t g_tag = gold.ner[b0].label
cdef int cost = 0
@ -650,7 +641,6 @@ cdef class Unit:
cost += 1
break
return cost
cdef class Out:
@ -675,7 +665,6 @@ cdef class Out:
cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil:
gold = <GoldNERStateC*>_gold
cdef int g_act = gold.ner[s.B(0)].move
cdef attr_t g_tag = gold.ner[s.B(0)].label
cdef weight_t cost = 0
if g_act == MISSING:
pass

View File

@ -125,14 +125,17 @@ def decompose(label):
def is_decorated(label):
return DELIMITER in label
def count_decorated_labels(gold_data):
freqs = {}
for example in gold_data:
proj_heads, deco_deps = projectivize(example.get_aligned("HEAD"),
example.get_aligned("DEP"))
# set the label to ROOT for each root dependent
deco_deps = ['ROOT' if head == i else deco_deps[i]
for i, head in enumerate(proj_heads)]
deco_deps = [
'ROOT' if head == i else deco_deps[i]
for i, head in enumerate(proj_heads)
]
# count label frequencies
for label in deco_deps:
if is_decorated(label):
@ -160,9 +163,9 @@ def projectivize(heads, labels):
cdef vector[int] _heads_to_c(heads):
cdef vector[int] c_heads;
cdef vector[int] c_heads
for head in heads:
if head == None:
if head is None:
c_heads.push_back(-1)
else:
assert head < len(heads)
@ -199,6 +202,7 @@ def _decorate(heads, proj_heads, labels):
deco_labels.append(labels[tokenid])
return deco_labels
def get_smallest_nonproj_arc_slow(heads):
cdef vector[int] c_heads = _heads_to_c(heads)
return _get_smallest_nonproj_arc(c_heads)

View File

@ -1,6 +1,4 @@
# cython: infer_types=True
import numpy
from libcpp.vector cimport vector
from ...tokens.doc cimport Doc
@ -38,11 +36,11 @@ cdef class StateClass:
cdef vector[ArcC] arcs
self.c.get_arcs(&arcs)
return list(arcs)
#py_arcs = []
#for arc in arcs:
# if arc.head != -1 and arc.child != -1:
# py_arcs.append((arc.head, arc.child, arc.label))
#return arcs
# py_arcs = []
# for arc in arcs:
# if arc.head != -1 and arc.child != -1:
# py_arcs.append((arc.head, arc.child, arc.label))
# return arcs
def add_arc(self, int head, int child, int label):
self.c.add_arc(head, child, label)
@ -52,10 +50,10 @@ cdef class StateClass:
def H(self, int child):
return self.c.H(child)
def L(self, int head, int idx):
return self.c.L(head, idx)
def R(self, int head, int idx):
return self.c.R(head, idx)
@ -98,7 +96,7 @@ cdef class StateClass:
def H(self, int i):
return self.c.H(i)
def E(self, int i):
return self.c.E(i)
@ -116,7 +114,7 @@ cdef class StateClass:
def H_(self, int i):
return self.doc[self.c.H(i)]
def E_(self, int i):
return self.doc[self.c.E(i)]
@ -125,7 +123,7 @@ cdef class StateClass:
def R_(self, int i, int idx):
return self.doc[self.c.R(i, idx)]
def empty(self):
return self.c.empty()
@ -134,7 +132,7 @@ cdef class StateClass:
def at_break(self):
return False
#return self.c.at_break()
# return self.c.at_break()
def has_head(self, int i):
return self.c.has_head(i)

View File

@ -20,11 +20,15 @@ cdef struct Transition:
int (*do)(StateC* state, attr_t label) nogil
ctypedef weight_t (*get_cost_func_t)(const StateC* state, const void* gold,
attr_tlabel) nogil
ctypedef weight_t (*move_cost_func_t)(const StateC* state, const void* gold) nogil
ctypedef weight_t (*label_cost_func_t)(const StateC* state, const void*
gold, attr_t label) nogil
ctypedef weight_t (*get_cost_func_t)(
const StateC* state, const void* gold, attr_tlabel
) nogil
ctypedef weight_t (*move_cost_func_t)(
const StateC* state, const void* gold
) nogil
ctypedef weight_t (*label_cost_func_t)(
const StateC* state, const void* gold, attr_t label
) nogil
ctypedef int (*do_func_t)(StateC* state, attr_t label) nogil

View File

@ -8,9 +8,7 @@ from collections import Counter
import srsly
from ...structs cimport TokenC
from ...tokens.doc cimport Doc
from ...typedefs cimport attr_t, weight_t
from . cimport _beam_utils
from .stateclass cimport StateClass
from ... import util
@ -231,7 +229,6 @@ cdef class TransitionSystem:
return self
def to_bytes(self, exclude=tuple()):
transitions = []
serializers = {
'moves': lambda: srsly.json_dumps(self.labels),
'strings': lambda: self.strings.to_bytes(),

View File

@ -1,6 +1,6 @@
# cython: infer_types=True, profile=True, binding=True
from collections import defaultdict
from typing import Callable, Iterable, Optional
from typing import Callable, Optional
from thinc.api import Config, Model
@ -124,6 +124,7 @@ def make_parser(
scorer=scorer,
)
@Language.factory(
"beam_parser",
assigns=["token.dep", "token.head", "token.is_sent_start", "doc.sents"],

View File

@ -2,7 +2,6 @@
from itertools import islice
from typing import Callable, Dict, Optional, Union
import srsly
from thinc.api import Config, Model, SequenceCategoricalCrossentropy
from ..morphology cimport Morphology
@ -14,10 +13,8 @@ from ..errors import Errors
from ..language import Language
from ..parts_of_speech import IDS as POS_IDS
from ..scorer import Scorer
from ..symbols import POS
from ..training import validate_examples, validate_get_examples
from ..util import registry
from .pipe import deserialize_config
from .tagger import Tagger
# See #9050
@ -76,8 +73,11 @@ def morphologizer_score(examples, **kwargs):
results = {}
results.update(Scorer.score_token_attr(examples, "pos", **kwargs))
results.update(Scorer.score_token_attr(examples, "morph", getter=morph_key_getter, **kwargs))
results.update(Scorer.score_token_attr_per_feat(examples,
"morph", getter=morph_key_getter, **kwargs))
results.update(
Scorer.score_token_attr_per_feat(
examples, "morph", getter=morph_key_getter, **kwargs
)
)
return results
@ -233,7 +233,6 @@ class Morphologizer(Tagger):
if isinstance(docs, Doc):
docs = [docs]
cdef Doc doc
cdef Vocab vocab = self.vocab
cdef bint overwrite = self.cfg["overwrite"]
cdef bint extend = self.cfg["extend"]
labels = self.labels

View File

@ -4,13 +4,10 @@ from typing import Optional
import numpy
from thinc.api import Config, CosineDistance, Model, set_dropout_rate, to_categorical
from ..tokens.doc cimport Doc
from ..attrs import ID, POS
from ..attrs import ID
from ..errors import Errors
from ..language import Language
from ..training import validate_examples
from ._parser_internals import nonproj
from .tagger import Tagger
from .trainable_pipe import TrainablePipe
@ -103,10 +100,9 @@ class MultitaskObjective(Tagger):
cdef int idx = 0
correct = numpy.zeros((scores.shape[0],), dtype="i")
guesses = scores.argmax(axis=1)
docs = [eg.predicted for eg in examples]
for i, eg in enumerate(examples):
# Handles alignment for tokenization differences
doc_annots = eg.get_aligned() # TODO
_doc_annots = eg.get_aligned() # TODO
for j in range(len(eg.predicted)):
tok_annots = {key: values[j] for key, values in tok_annots.items()}
label = self.make_label(j, tok_annots)
@ -206,7 +202,6 @@ class ClozeMultitask(TrainablePipe):
losses[self.name] = 0.
set_dropout_rate(self.model, drop)
validate_examples(examples, "ClozeMultitask.rehearse")
docs = [eg.predicted for eg in examples]
predictions, bp_predictions = self.model.begin_update()
loss, d_predictions = self.get_loss(examples, self.vocab.vectors.data, predictions)
bp_predictions(d_predictions)

View File

@ -1,6 +1,6 @@
# cython: infer_types=True, profile=True, binding=True
from collections import defaultdict
from typing import Callable, Iterable, Optional
from typing import Callable, Optional
from thinc.api import Config, Model
@ -10,7 +10,7 @@ from ._parser_internals.ner cimport BiluoPushDown
from .transition_parser cimport Parser
from ..language import Language
from ..scorer import PRFScore, get_ner_prf
from ..scorer import get_ner_prf
from ..training import remove_bilu_prefix
from ..util import registry
@ -100,6 +100,7 @@ def make_ner(
scorer=scorer,
)
@Language.factory(
"beam_ner",
assigns=["doc.ents", "token.ent_iob", "token.ent_type"],

View File

@ -1,6 +1,6 @@
# cython: infer_types=True, profile=True, binding=True
import warnings
from typing import Callable, Dict, Iterable, Iterator, Optional, Tuple, Union
from typing import Callable, Dict, Iterable, Iterator, Tuple, Union
import srsly
@ -40,7 +40,7 @@ cdef class Pipe:
"""
raise NotImplementedError(Errors.E931.format(parent="Pipe", method="__call__", name=self.name))
def pipe(self, stream: Iterable[Doc], *, batch_size: int=128) -> Iterator[Doc]:
def pipe(self, stream: Iterable[Doc], *, batch_size: int = 128) -> Iterator[Doc]:
"""Apply the pipe to a stream of documents. This usually happens under
the hood when the nlp object is called on a text and all components are
applied to the Doc.
@ -59,7 +59,7 @@ cdef class Pipe:
except Exception as e:
error_handler(self.name, self, [doc], e)
def initialize(self, get_examples: Callable[[], Iterable[Example]], *, nlp: Language=None):
def initialize(self, get_examples: Callable[[], Iterable[Example]], *, nlp: Language = None):
"""Initialize the pipe. For non-trainable components, this method
is optional. For trainable components, which should inherit
from the subclass TrainablePipe, the provided data examples

View File

@ -7,13 +7,13 @@ from ..tokens.doc cimport Doc
from .. import util
from ..language import Language
from ..scorer import Scorer
from .pipe import Pipe
from .senter import senter_score
# see #9050
BACKWARD_OVERWRITE = False
@Language.factory(
"sentencizer",
assigns=["token.is_sent_start", "doc.sents"],
@ -36,17 +36,19 @@ class Sentencizer(Pipe):
DOCS: https://spacy.io/api/sentencizer
"""
default_punct_chars = ['!', '.', '?', '։', '؟', '۔', '܀', '܁', '܂', '߹',
'', '', '', '', '', '', '', '', '', '', '', '', '',
'', '', '', '', '', '', '', '', '', '', '', '', '᱿',
'', '', '', '', '', '', '', '', '', '', '', '', '',
'', '', '', '', '', '', '', '', '', '', '', '', '',
'', '', '', '', '', '𐩖', '𐩗', '𑁇', '𑁈', '𑂾', '𑂿', '𑃀',
'𑃁', '𑅁', '𑅂', '𑅃', '𑇅', '𑇆', '𑇍', '𑇞', '𑇟', '𑈸', '𑈹', '𑈻', '𑈼',
'𑊩', '𑑋', '𑑌', '𑗂', '𑗃', '𑗉', '𑗊', '𑗋', '𑗌', '𑗍', '𑗎', '𑗏', '𑗐',
'𑗑', '𑗒', '𑗓', '𑗔', '𑗕', '𑗖', '𑗗', '𑙁', '𑙂', '𑜼', '𑜽', '𑜾', '𑩂',
'𑩃', '𑪛', '𑪜', '𑱁', '𑱂', '𖩮', '𖩯', '𖫵', '𖬷', '𖬸', '𖭄', '𛲟', '𝪈',
'', '']
default_punct_chars = [
'!', '.', '?', '։', '؟', '۔', '܀', '܁', '܂', '߹',
'', '', '', '', '', '', '', '', '', '', '', '', '',
'', '', '', '', '', '', '', '', '', '', '', '', '᱿',
'', '', '', '', '', '', '', '', '', '', '', '', '',
'', '', '', '', '', '', '', '', '', '', '', '', '',
'', '', '', '', '', '𐩖', '𐩗', '𑁇', '𑁈', '𑂾', '𑂿', '𑃀',
'𑃁', '𑅁', '𑅂', '𑅃', '𑇅', '𑇆', '𑇍', '𑇞', '𑇟', '𑈸', '𑈹', '𑈻', '𑈼',
'𑊩', '𑑋', '𑑌', '𑗂', '𑗃', '𑗉', '𑗊', '𑗋', '𑗌', '𑗍', '𑗎', '𑗏', '𑗐',
'𑗑', '𑗒', '𑗓', '𑗔', '𑗕', '𑗖', '𑗗', '𑙁', '𑙂', '𑜼', '𑜽', '𑜾', '𑩂',
'𑩃', '𑪛', '𑪜', '𑱁', '𑱂', '𖩮', '𖩯', '𖫵', '𖬷', '𖬸', '𖭄', '𛲟', '𝪈',
'', ''
]
def __init__(
self,
@ -128,7 +130,6 @@ class Sentencizer(Pipe):
if isinstance(docs, Doc):
docs = [docs]
cdef Doc doc
cdef int idx = 0
for i, doc in enumerate(docs):
doc_tag_ids = batch_tag_ids[i]
for j, tag_id in enumerate(doc_tag_ids):
@ -169,7 +170,6 @@ class Sentencizer(Pipe):
path = path.with_suffix(".json")
srsly.write_json(path, {"punct_chars": list(self.punct_chars), "overwrite": self.overwrite})
def from_disk(self, path, *, exclude=tuple()):
"""Load the sentencizer from disk.

View File

@ -2,7 +2,6 @@
from itertools import islice
from typing import Callable, Optional
import srsly
from thinc.api import Config, Model, SequenceCategoricalCrossentropy
from ..tokens.doc cimport Doc

View File

@ -1,26 +1,17 @@
# cython: infer_types=True, profile=True, binding=True
import warnings
from itertools import islice
from typing import Callable, Optional
import numpy
import srsly
from thinc.api import Config, Model, SequenceCategoricalCrossentropy, set_dropout_rate
from thinc.types import Floats2d
from ..morphology cimport Morphology
from ..tokens.doc cimport Doc
from ..vocab cimport Vocab
from .. import util
from ..attrs import ID, POS
from ..errors import Errors, Warnings
from ..language import Language
from ..parts_of_speech import X
from ..scorer import Scorer
from ..training import validate_examples, validate_get_examples
from ..util import registry
from .pipe import deserialize_config
from .trainable_pipe import TrainablePipe
# See #9050
@ -169,7 +160,6 @@ class Tagger(TrainablePipe):
if isinstance(docs, Doc):
docs = [docs]
cdef Doc doc
cdef Vocab vocab = self.vocab
cdef bint overwrite = self.cfg["overwrite"]
labels = self.labels
for i, doc in enumerate(docs):

View File

@ -55,7 +55,7 @@ cdef class TrainablePipe(Pipe):
except Exception as e:
error_handler(self.name, self, [doc], e)
def pipe(self, stream: Iterable[Doc], *, batch_size: int=128) -> Iterator[Doc]:
def pipe(self, stream: Iterable[Doc], *, batch_size: int = 128) -> Iterator[Doc]:
"""Apply the pipe to a stream of documents. This usually happens under
the hood when the nlp object is called on a text and all components are
applied to the Doc.
@ -102,9 +102,9 @@ cdef class TrainablePipe(Pipe):
def update(self,
examples: Iterable["Example"],
*,
drop: float=0.0,
sgd: Optimizer=None,
losses: Optional[Dict[str, float]]=None) -> Dict[str, float]:
drop: float = 0.0,
sgd: Optimizer = None,
losses: Optional[Dict[str, float]] = None) -> Dict[str, float]:
"""Learn from a batch of documents and gold-standard information,
updating the pipe's model. Delegates to predict and get_loss.
@ -138,8 +138,8 @@ cdef class TrainablePipe(Pipe):
def rehearse(self,
examples: Iterable[Example],
*,
sgd: Optimizer=None,
losses: Dict[str, float]=None,
sgd: Optimizer = None,
losses: Dict[str, float] = None,
**config) -> Dict[str, float]:
"""Perform a "rehearsal" update from a batch of data. Rehearsal updates
teach the current model to make predictions similar to an initial model,
@ -177,7 +177,7 @@ cdef class TrainablePipe(Pipe):
"""
return util.create_default_optimizer()
def initialize(self, get_examples: Callable[[], Iterable[Example]], *, nlp: Language=None):
def initialize(self, get_examples: Callable[[], Iterable[Example]], *, nlp: Language = None):
"""Initialize the pipe for training, using data examples if available.
This method needs to be implemented by each TrainablePipe component,
ensuring the internal model (if available) is initialized properly

View File

@ -13,8 +13,18 @@ cdef class Parser(TrainablePipe):
cdef readonly TransitionSystem moves
cdef public object _multitasks
cdef void _parseC(self, CBlas cblas, StateC** states,
WeightsC weights, SizesC sizes) nogil
cdef void _parseC(
self,
CBlas cblas,
StateC** states,
WeightsC weights,
SizesC sizes
) nogil
cdef void c_transition_batch(self, StateC** states, const float* scores,
int nr_class, int batch_size) nogil
cdef void c_transition_batch(
self,
StateC** states,
const float* scores,
int nr_class,
int batch_size
) nogil

View File

@ -7,17 +7,14 @@ from cymem.cymem cimport Pool
from itertools import islice
from libc.stdlib cimport calloc, free
from libc.string cimport memcpy, memset
from libc.string cimport memset
from libcpp.vector cimport vector
import random
import srsly
from thinc.api import CupyOps, NumpyOps, get_ops, set_dropout_rate
from thinc.api import CupyOps, NumpyOps, set_dropout_rate
from thinc.extra.search cimport Beam
import warnings
import numpy
import numpy.random
@ -42,7 +39,7 @@ from .trainable_pipe import TrainablePipe
from ._parser_internals cimport _beam_utils
from .. import util
from ..errors import Errors, Warnings
from ..errors import Errors
from ..training import validate_examples, validate_get_examples
from ._parser_internals import _beam_utils
@ -258,7 +255,6 @@ cdef class Parser(TrainablePipe):
except Exception as e:
error_handler(self.name, self, batch_in_order, e)
def predict(self, docs):
if isinstance(docs, Doc):
docs = [docs]
@ -300,8 +296,6 @@ cdef class Parser(TrainablePipe):
return batch
def beam_parse(self, docs, int beam_width, float drop=0., beam_density=0.):
cdef Beam beam
cdef Doc doc
self._ensure_labels_are_added(docs)
batch = _beam_utils.BeamBatch(
self.moves,
@ -321,16 +315,18 @@ cdef class Parser(TrainablePipe):
del model
return list(batch)
cdef void _parseC(self, CBlas cblas, StateC** states,
WeightsC weights, SizesC sizes) nogil:
cdef int i, j
cdef void _parseC(
self, CBlas cblas, StateC** states, WeightsC weights, SizesC sizes
) nogil:
cdef int i
cdef vector[StateC*] unfinished
cdef ActivationsC activations = alloc_activations(sizes)
while sizes.states >= 1:
predict_states(cblas, &activations, states, &weights, sizes)
# Validate actions, argmax, take action.
self.c_transition_batch(states,
activations.scores, sizes.classes, sizes.states)
self.c_transition_batch(
states, activations.scores, sizes.classes, sizes.states
)
for i in range(sizes.states):
if not states[i].is_final():
unfinished.push_back(states[i])
@ -342,7 +338,6 @@ cdef class Parser(TrainablePipe):
def set_annotations(self, docs, states_or_beams):
cdef StateClass state
cdef Beam beam
cdef Doc doc
states = _beam_utils.collect_states(states_or_beams, docs)
for i, (state, doc) in enumerate(zip(states, docs)):
@ -359,8 +354,13 @@ cdef class Parser(TrainablePipe):
self.c_transition_batch(&c_states[0], c_scores, scores.shape[1], scores.shape[0])
return [state for state in states if not state.c.is_final()]
cdef void c_transition_batch(self, StateC** states, const float* scores,
int nr_class, int batch_size) nogil:
cdef void c_transition_batch(
self,
StateC** states,
const float* scores,
int nr_class,
int batch_size
) nogil:
# n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc
with gil:
assert self.moves.n_moves > 0, Errors.E924.format(name=self.name)
@ -380,7 +380,6 @@ cdef class Parser(TrainablePipe):
free(is_valid)
def update(self, examples, *, drop=0., sgd=None, losses=None):
cdef StateClass state
if losses is None:
losses = {}
losses.setdefault(self.name, 0.)
@ -419,8 +418,7 @@ cdef class Parser(TrainablePipe):
if not states:
return losses
model, backprop_tok2vec = self.model.begin_update([eg.x for eg in examples])
all_states = list(states)
states_golds = list(zip(states, golds))
n_moves = 0
while states_golds:
@ -500,8 +498,16 @@ cdef class Parser(TrainablePipe):
del tutor
return losses
def update_beam(self, examples, *, beam_width,
drop=0., sgd=None, losses=None, beam_density=0.0):
def update_beam(
self,
examples,
*,
beam_width,
drop=0.,
sgd=None,
losses=None,
beam_density=0.0
):
states, golds, _ = self.moves.init_gold_batch(examples)
if not states:
return losses
@ -531,8 +537,9 @@ cdef class Parser(TrainablePipe):
is_valid = <int*>mem.alloc(self.moves.n_moves, sizeof(int))
costs = <float*>mem.alloc(self.moves.n_moves, sizeof(float))
cdef np.ndarray d_scores = numpy.zeros((len(states), self.moves.n_moves),
dtype='f', order='C')
cdef np.ndarray d_scores = numpy.zeros(
(len(states), self.moves.n_moves), dtype='f', order='C'
)
c_d_scores = <float*>d_scores.data
unseen_classes = self.model.attrs["unseen_classes"]
for i, (state, gold) in enumerate(zip(states, golds)):
@ -542,8 +549,9 @@ cdef class Parser(TrainablePipe):
for j in range(self.moves.n_moves):
if costs[j] <= 0.0 and j in unseen_classes:
unseen_classes.remove(j)
cpu_log_loss(c_d_scores,
costs, is_valid, &scores[i, 0], d_scores.shape[1])
cpu_log_loss(
c_d_scores, costs, is_valid, &scores[i, 0], d_scores.shape[1]
)
c_d_scores += d_scores.shape[1]
# Note that we don't normalize this. See comment in update() for why.
if losses is not None: