diff --git a/spacy/pipeline/span_predictor.py b/spacy/pipeline/span_predictor.py index 12ea6611c..d21a45edb 100644 --- a/spacy/pipeline/span_predictor.py +++ b/spacy/pipeline/span_predictor.py @@ -96,7 +96,7 @@ class SpanPredictor(TrainablePipe): self.input_prefix = input_prefix self.output_prefix = output_prefix - self.cfg = {} + self.cfg: Dict[str, Any] = {} def predict(self, docs: Iterable[Doc]) -> List[MentionClusters]: # for now pretend there's just one doc @@ -205,7 +205,7 @@ class SpanPredictor(TrainablePipe): ops = self.model.ops # NOTE This is doing fake batching, and should always get a list of one example - assert len(examples) == 1, "Only fake batching is supported." + assert len(list(examples)) == 1, "Only fake batching is supported." # starts and ends are gold starts and ends (Ints1d) # span_scores is a Floats3d. What are the axes? mention x token x start/end for eg in examples: