diff --git a/spacy/ml/tb_framework.pyx b/spacy/ml/tb_framework.pyx
index e5bc99e60..d08ff9c6e 100644
--- a/spacy/ml/tb_framework.pyx
+++ b/spacy/ml/tb_framework.pyx
@@ -171,15 +171,18 @@ class TransitionModelInputs:
         "states",
     ]
 
-    def __init__(self, docs: List[Doc], moves: TransitionSystem,
-            actions: Optional[List[Ints1d]]=None, max_moves: int=0,
+    def __init__(
+        self,
+        docs: List[Doc],
+        moves: TransitionSystem,
+        actions: Optional[List[Ints1d]]=None,
+        max_moves: int=0,
             states: Optional[List[State]]=None):
         """
         actions (Optional[List[Ints1d]]): actions to apply for each Doc.
         docs (List[Doc]): Docs to predict transition sequences for.
-        max_moves: (Optional[int]): the maximum number of moves to apply,
-            values less than 1 will apply moves to states until they are
-            final states.
+        max_moves (int): the maximum number of moves to apply; values less
+            than 1 apply moves until the states are final.
         moves (TransitionSystem): the transition system to use when predicting
             the transition sequences.
         states (Optional[List[States]]): the initial states to predict the
@@ -208,7 +211,7 @@ def forward(model, inputs: TransitionModelInputs, is_train: bool):
     feats, backprop_feats = _forward_precomputable_affine(model, tokvecs, is_train)
     seen_mask = _get_seen_mask(model)
 
-    if beam_width == 1 and not is_train and isinstance(model.ops, NumpyOps):
+    if not is_train and beam_width == 1 and isinstance(model.ops, NumpyOps):
         # Note: max_moves is only used during training, so we don't need to
         # pass it to the greedy inference path.
         return _forward_greedy_cpu(model, moves, states, feats, seen_mask, actions=actions)
@@ -271,8 +274,18 @@ cdef list _parse_batch(CBlas cblas, TransitionSystem moves, StateC** states,
     return scores
 
 
-def _forward_fallback(model: Model, moves: TransitionSystem, states: List[StateClass], tokvecs, backprop_tok2vec, feats, backprop_feats, seen_mask, is_train: bool,
-        actions: Optional[List[Ints1d]]=None, max_moves=0):
+def _forward_fallback(
+    model: Model,
+    moves: TransitionSystem,
+    states: List[StateClass],
+    tokvecs,
+    backprop_tok2vec,
+    feats,
+    backprop_feats,
+    seen_mask,
+    is_train: bool,
+    actions: Optional[List[Ints1d]]=None,
+    max_moves: int=0):
     nF = model.get_dim("nF")
     output = model.get_ref("output")
     hidden_b = model.get_param("hidden_b")
diff --git a/spacy/pipeline/transition_parser.pyx b/spacy/pipeline/transition_parser.pyx
index b0f24cd73..9a6543453 100644
--- a/spacy/pipeline/transition_parser.pyx
+++ b/spacy/pipeline/transition_parser.pyx
@@ -528,10 +528,10 @@ class Parser(TrainablePipe):
 
     def _init_gold_batch(self, examples, max_length):
        """Make a square batch, of length equal to the shortest transition
-        sequence or a cap. A long
-        doc will get multiple states. Let's say we have a doc of length 2*N,
-        where N is the shortest doc. We'll make two states, one representing
-        long_doc[:N], and another representing long_doc[N:]."""
+        sequence or a cap. A long doc will get multiple states. Let's say
+        we have a doc of length 2*N, where N is the length of the shortest
+        doc. We'll make two states, one representing long_doc[:N], and
+        another representing long_doc[N:]."""
         cdef:
             StateClass start_state
             StateClass state
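
Usage note (a minimal sketch, not part of the patch): the reformatted TransitionModelInputs constructor and the reordered greedy-inference condition above would be exercised roughly as follows. The pipeline name "en_core_web_sm", the sample text, and the (output, backprop) unpacking are assumptions for illustration, following Thinc's model(X, is_train) convention.

import spacy

from spacy.ml.tb_framework import TransitionModelInputs

nlp = spacy.load("en_core_web_sm")           # assumed pipeline, for illustration
docs = [nlp.make_doc("She ate the pizza.")]  # unannotated Docs to parse
parser = nlp.get_pipe("parser")

inputs = TransitionModelInputs(
    docs=docs,           # Docs to predict transition sequences for
    moves=parser.moves,  # the TransitionSystem driving the predictions
    max_moves=0,         # values < 1: apply moves until states are final
)

# With is_train=False and beam_width == 1 on NumpyOps, forward() takes the
# greedy CPU path shown in the second hunk above.
states, backprop = parser.model(inputs, is_train=False)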