From 9a35b24b488db2b78bb0200a4553e661063fff37 Mon Sep 17 00:00:00 2001 From: Lj Miranda Date: Fri, 18 Nov 2022 13:48:18 +0800 Subject: [PATCH] Implement _allow_extra_label to use _n_labels To ensure that spancat / spancat_exclusive cannot be resized after initialization, I inherited the _allow_extra_label() method from spacy/pipeline/trainable_pipe.pyx and used self._n_labels instead of len(self.labels) for checking. I think that changing it locally is a better solution rather than forcing each class that inherits TrainablePipe to use the self._n_labels attribute. Also note that I turned-off black formatting in this block of code because it reads better without the overhang. --- spacy/pipeline/spancat.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/spacy/pipeline/spancat.py b/spacy/pipeline/spancat.py index eb5f0bbb3..6fb26fa41 100644 --- a/spacy/pipeline/spancat.py +++ b/spacy/pipeline/spancat.py @@ -227,6 +227,19 @@ class SpanCategorizer(TrainablePipe): """ return str(self.cfg["spans_key"]) + def _allow_extra_label(self) -> None: + """Raise an error if the component can not add any more labels.""" + nO = None + # fmt: off + if self.model.has_dim("nO"): + nO = self.model.get_dim("nO") + elif self.model.has_ref("output_layer") and self.model.get_ref("output_layer").has_dim("nO"): + nO = self.model.get_ref("output_layer").get_dim("nO") + if nO is not None and nO == self._n_labels: + if not self.is_resizable: + raise ValueError(Errors.E922.format(name=self.name, nO=self.model.get_dim("nO"))) + # fmt: on + def add_label(self, label: str) -> int: """Add a new label to the pipe.