Apply suggestions from @shadeMe

This commit is contained in:
Daniël de Kok 2023-01-11 08:33:43 +01:00
parent 74b9ddd03a
commit 672cad7161
2 changed files with 24 additions and 12 deletions

View File

@ -171,15 +171,18 @@ class TransitionModelInputs:
"states",
]
def __init__(self, docs: List[Doc], moves: TransitionSystem,
actions: Optional[List[Ints1d]]=None, max_moves: int=0,
def __init__(
self,
docs: List[Doc],
moves: TransitionSystem,
actions: Optional[List[Ints1d]]=None,
max_moves: int=0,
states: Optional[List[State]]=None):
"""
actions (Optional[List[Ints1d]]): actions to apply for each Doc.
docs (List[Doc]): Docs to predict transition sequences for.
max_moves: (Optional[int]): the maximum number of moves to apply,
values less than 1 will apply moves to states until they are
final states.
max_moves: (int): the maximum number of moves to apply, values less
than 1 will apply moves to states until they are final states.
moves (TransitionSystem): the transition system to use when predicting
the transition sequences.
states (Optional[List[States]]): the initial states to predict the
@ -208,7 +211,7 @@ def forward(model, inputs: TransitionModelInputs, is_train: bool):
feats, backprop_feats = _forward_precomputable_affine(model, tokvecs, is_train)
seen_mask = _get_seen_mask(model)
if beam_width == 1 and not is_train and isinstance(model.ops, NumpyOps):
if not is_train and beam_width == 1 and isinstance(model.ops, NumpyOps):
# Note: max_moves is only used during training, so we don't need to
# pass it to the greedy inference path.
return _forward_greedy_cpu(model, moves, states, feats, seen_mask, actions=actions)
@ -271,8 +274,17 @@ cdef list _parse_batch(CBlas cblas, TransitionSystem moves, StateC** states,
return scores
def _forward_fallback(model: Model, moves: TransitionSystem, states: List[StateClass], tokvecs, backprop_tok2vec, feats, backprop_feats, seen_mask, is_train: bool,
actions: Optional[List[Ints1d]]=None, max_moves=0):
def _forward_fallback(
model: Model,
moves: TransitionSystem,
states: List[StateClass],
tokvecs, backprop_tok2vec,
feats,
backprop_feats,
seen_mask,
is_train: bool,
actions: Optional[List[Ints1d]]=None,
max_moves: int=0):
nF = model.get_dim("nF")
output = model.get_ref("output")
hidden_b = model.get_param("hidden_b")

View File

@ -528,10 +528,10 @@ class Parser(TrainablePipe):
def _init_gold_batch(self, examples, max_length):
"""Make a square batch, of length equal to the shortest transition
sequence or a cap. A long
doc will get multiple states. Let's say we have a doc of length 2*N,
where N is the shortest doc. We'll make two states, one representing
long_doc[:N], and another representing long_doc[N:]."""
sequence or a cap. A long doc will get multiple states. Let's say we
have a doc of length 2*N, where N is the shortest doc. We'll make
two states, one representing long_doc[:N], and another representing
long_doc[N:]."""
cdef:
StateClass start_state
StateClass state