diff --git a/spacy/pipeline/spancat.py b/spacy/pipeline/spancat.py index 1b7a9eecb..eb5f0bbb3 100644 --- a/spacy/pipeline/spancat.py +++ b/spacy/pipeline/spancat.py @@ -260,6 +260,11 @@ class SpanCategorizer(TrainablePipe): """ return list(self.labels) + @property + def _n_labels(self) -> int: + """RETURNS (int): Number of labels.""" + return len(self.labels) + def predict(self, docs: Iterable[Doc]): """Apply the pipeline's model to a batch of docs, without modifying them. @@ -432,7 +437,7 @@ class SpanCategorizer(TrainablePipe): if subbatch: docs = [eg.x for eg in subbatch] spans = build_ngram_suggester(sizes=[1])(docs) - Y = self.model.ops.alloc2f(spans.dataXd.shape[0], len(self.labels)) + Y = self.model.ops.alloc2f(spans.dataXd.shape[0], self._n_labels) self.model.initialize(X=(docs, spans), Y=Y) else: self.model.initialize() diff --git a/spacy/pipeline/spancat_exclusive.py b/spacy/pipeline/spancat_exclusive.py index 8d2748ea5..e4da41677 100644 --- a/spacy/pipeline/spancat_exclusive.py +++ b/spacy/pipeline/spancat_exclusive.py @@ -378,57 +378,6 @@ class SpanCategorizerExclusive(TrainablePipe): loss = float((d_scores**2).sum()) return loss, d_scores - def initialize( - self, - get_examples: Callable[[], Iterable[Example]], - *, - nlp: Optional[Language] = None, - labels: Optional[List[str]] = None, - ) -> None: - """Initialize the pipe for training, using a representative set - of data examples. - - get_examples (Callable[[], Iterable[Example]]): Function that - returns a representative sample of gold-standard Example objects. - nlp (Optional[Language]): The current nlp object the component is part of. - labels (Optional[List[str]]): The labels to add to the component, typically generated by the - `init labels` command. If no labels are provided, the get_examples - callback is used to extract the labels from the data. - - DOCS: https://spacy.io/api/spancategorizerexclusive#initialize - """ - subbatch: List[Example] = [] - if labels is not None: - for label in labels: - self.add_label(label) - for eg in get_examples(): - if labels is None: - for span in eg.reference.spans.get(self.key, []): - self.add_label(span.label_) - if len(subbatch) < 10: - subbatch.append(eg) - self._require_labels() - if subbatch: - docs = [eg.x for eg in subbatch] - spans = build_ngram_suggester(sizes=[1])(docs) - # + 1 for the "no-label" category - Y = self.model.ops.alloc2f(spans.dataXd.shape[0], self._n_labels) - self.model.initialize(X=(docs, spans), Y=Y) - else: - # FIXME: Ideally we want to raise an error to avoid implicitly - # raising it when initializing without examples. For now, we'll just - # copy over what `spancat` did. - self.model.initialize() - - def _validate_categories(self, examples: Iterable[Example]): - # TODO - pass - - def _get_aligned_spans(self, eg: Example): - return eg.get_aligned_spans_y2x( - eg.reference.spans.get(self.key, []), allow_overlap=True - ) - def _make_span_group( self, doc: Doc,