mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-10 08:12:24 +03:00
Add _n_labels property to SpanCategorizer
Instead of using len(self.labels) in initialize() I am using a private property self._n_labels. This achieves implementation parity and allows me to delete the whole initialize() method for spancat_exclusive (since it's now the same with spancat).
This commit is contained in:
parent
023a1a6c04
commit
bdf2a1d1fe
|
@ -260,6 +260,11 @@ class SpanCategorizer(TrainablePipe):
|
|||
"""
|
||||
return list(self.labels)
|
||||
|
||||
@property
|
||||
def _n_labels(self) -> int:
|
||||
"""RETURNS (int): Number of labels."""
|
||||
return len(self.labels)
|
||||
|
||||
def predict(self, docs: Iterable[Doc]):
|
||||
"""Apply the pipeline's model to a batch of docs, without modifying them.
|
||||
|
||||
|
@ -432,7 +437,7 @@ class SpanCategorizer(TrainablePipe):
|
|||
if subbatch:
|
||||
docs = [eg.x for eg in subbatch]
|
||||
spans = build_ngram_suggester(sizes=[1])(docs)
|
||||
Y = self.model.ops.alloc2f(spans.dataXd.shape[0], len(self.labels))
|
||||
Y = self.model.ops.alloc2f(spans.dataXd.shape[0], self._n_labels)
|
||||
self.model.initialize(X=(docs, spans), Y=Y)
|
||||
else:
|
||||
self.model.initialize()
|
||||
|
|
|
@ -378,57 +378,6 @@ class SpanCategorizerExclusive(TrainablePipe):
|
|||
loss = float((d_scores**2).sum())
|
||||
return loss, d_scores
|
||||
|
||||
def initialize(
|
||||
self,
|
||||
get_examples: Callable[[], Iterable[Example]],
|
||||
*,
|
||||
nlp: Optional[Language] = None,
|
||||
labels: Optional[List[str]] = None,
|
||||
) -> None:
|
||||
"""Initialize the pipe for training, using a representative set
|
||||
of data examples.
|
||||
|
||||
get_examples (Callable[[], Iterable[Example]]): Function that
|
||||
returns a representative sample of gold-standard Example objects.
|
||||
nlp (Optional[Language]): The current nlp object the component is part of.
|
||||
labels (Optional[List[str]]): The labels to add to the component, typically generated by the
|
||||
`init labels` command. If no labels are provided, the get_examples
|
||||
callback is used to extract the labels from the data.
|
||||
|
||||
DOCS: https://spacy.io/api/spancategorizerexclusive#initialize
|
||||
"""
|
||||
subbatch: List[Example] = []
|
||||
if labels is not None:
|
||||
for label in labels:
|
||||
self.add_label(label)
|
||||
for eg in get_examples():
|
||||
if labels is None:
|
||||
for span in eg.reference.spans.get(self.key, []):
|
||||
self.add_label(span.label_)
|
||||
if len(subbatch) < 10:
|
||||
subbatch.append(eg)
|
||||
self._require_labels()
|
||||
if subbatch:
|
||||
docs = [eg.x for eg in subbatch]
|
||||
spans = build_ngram_suggester(sizes=[1])(docs)
|
||||
# + 1 for the "no-label" category
|
||||
Y = self.model.ops.alloc2f(spans.dataXd.shape[0], self._n_labels)
|
||||
self.model.initialize(X=(docs, spans), Y=Y)
|
||||
else:
|
||||
# FIXME: Ideally we want to raise an error to avoid implicitly
|
||||
# raising it when initializing without examples. For now, we'll just
|
||||
# copy over what `spancat` did.
|
||||
self.model.initialize()
|
||||
|
||||
def _validate_categories(self, examples: Iterable[Example]):
|
||||
# TODO
|
||||
pass
|
||||
|
||||
def _get_aligned_spans(self, eg: Example):
|
||||
return eg.get_aligned_spans_y2x(
|
||||
eg.reference.spans.get(self.key, []), allow_overlap=True
|
||||
)
|
||||
|
||||
def _make_span_group(
|
||||
self,
|
||||
doc: Doc,
|
||||
|
|
Loading…
Reference in New Issue
Block a user