mirror of
https://github.com/explosion/spaCy.git
synced 2025-04-24 19:11:58 +03:00
Implement _allow_extra_label to use _n_labels
To ensure that spancat / spancat_exclusive cannot be resized after initialization, I inherited the _allow_extra_label() method from spacy/pipeline/trainable_pipe.pyx and used self._n_labels instead of len(self.labels) for checking. I think that changing it locally is a better solution rather than forcing each class that inherits TrainablePipe to use the self._n_labels attribute. Also note that I turned-off black formatting in this block of code because it reads better without the overhang.
This commit is contained in:
parent
c9036a6d79
commit
9a35b24b48
|
@ -227,6 +227,19 @@ class SpanCategorizer(TrainablePipe):
|
|||
"""
|
||||
return str(self.cfg["spans_key"])
|
||||
|
||||
def _allow_extra_label(self) -> None:
|
||||
"""Raise an error if the component can not add any more labels."""
|
||||
nO = None
|
||||
# fmt: off
|
||||
if self.model.has_dim("nO"):
|
||||
nO = self.model.get_dim("nO")
|
||||
elif self.model.has_ref("output_layer") and self.model.get_ref("output_layer").has_dim("nO"):
|
||||
nO = self.model.get_ref("output_layer").get_dim("nO")
|
||||
if nO is not None and nO == self._n_labels:
|
||||
if not self.is_resizable:
|
||||
raise ValueError(Errors.E922.format(name=self.name, nO=self.model.get_dim("nO")))
|
||||
# fmt: on
|
||||
|
||||
def add_label(self, label: str) -> int:
|
||||
"""Add a new label to the pipe.
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user