Merge pull request #6827 from explosion/feature/add-labels-implicitly

This commit is contained in:
Ines Montani 2021-01-27 21:34:58 +11:00 committed by GitHub
commit abb24fdc0f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 73 additions and 9 deletions

View File

@ -614,10 +614,22 @@ cdef class ArcEager(TransitionSystem):
actions[LEFT].setdefault('dep', 0) actions[LEFT].setdefault('dep', 0)
return actions return actions
@property
def builtin_labels(self):
    """Return the labels that are always available: "ROOT" and "dep".

    get_doc_labels() seeds its result with these, so they are included even
    for a Doc that carries no dependency annotation. A fresh list is built
    on every access, so callers may mutate the result without side effects.
    """
    return ["ROOT", "dep"]
@property
def action_types(self):
    """Return the move types of this transition system as a tuple.

    The tuple holds SHIFT, REDUCE, LEFT, RIGHT and BREAK, in that order.
    """
    return (SHIFT, REDUCE, LEFT, RIGHT, BREAK)
def get_doc_labels(self, doc):
    """Get the labels required for a given Doc.

    doc: Iterable of tokens exposing a `dep_` string attribute.
    RETURNS (set): The built-in labels plus every non-empty `dep_` value
        found on the tokens.
    """
    # Start from a copy so the built-in label list is never mutated.
    labels = set(self.builtin_labels)
    # Tokens without a dependency label have an empty `dep_` and are skipped.
    labels.update(token.dep_ for token in doc if token.dep_)
    return labels
def transition(self, StateClass state, action): def transition(self, StateClass state, action):
cdef Transition t = self.lookup_transition(action) cdef Transition t = self.lookup_transition(action)
t.do(state.c, t.label) t.do(state.c, t.label)

View File

@ -126,6 +126,13 @@ cdef class BiluoPushDown(TransitionSystem):
def action_types(self):
    """Return the move types of this transition system as a tuple.

    The tuple holds BEGIN, IN, LAST, UNIT and OUT, in that order.
    """
    return (BEGIN, IN, LAST, UNIT, OUT)
def get_doc_labels(self, doc):
    """Get the labels required for a given Doc.

    doc: Iterable of tokens exposing `ent_type` and `ent_type_` attributes.
    RETURNS (set): The `ent_type_` string of every token whose `ent_type`
        is truthy. Empty when no token carries an entity type.
    """
    # The truthiness check is made on `ent_type` while the string form
    # `ent_type_` is what gets collected.
    return {token.ent_type_ for token in doc if token.ent_type}
def move_name(self, int move, attr_t label): def move_name(self, int move, attr_t label):
if move == OUT: if move == OUT:
return 'O' return 'O'

View File

@ -277,3 +277,10 @@ cdef class DependencyParser(Parser):
head_scores.append(score_head_dict) head_scores.append(score_head_dict)
label_scores.append(score_label_dict) label_scores.append(score_label_dict)
return head_scores, label_scores return head_scores, label_scores
def _ensure_labels_are_added(self, docs):
    """Deliberately add no labels for a batch of documents.

    docs: The batch of documents; not inspected.
    RETURNS (int): Always 0, the value the general parser implementation
        returns when no labels were added.
    """
    # This gives the parser a chance to add labels it's missing for a batch
    # of documents. However, this isn't desirable for the dependency parser,
    # because we instead have a label frequency cut-off and back off rare
    # labels to 'dep'.
    return 0

View File

@ -132,6 +132,23 @@ cdef class Parser(TrainablePipe):
return 1 return 1
return 0 return 0
def _ensure_labels_are_added(self, docs):
    """Ensure that all labels for a batch of docs are added.

    Collects the labels each doc requires via `self.moves.get_doc_labels`,
    then tries to register every (action type, label) pair with
    `self.moves.add_action`. Labels that produced a new action are also
    added to `self.vocab.strings`.

    docs: Iterable of documents to collect labels from.
    RETURNS (int): 1 if at least one action was added (in which case
        `self._resize()` has been called once), otherwise 0.
    """
    labels = set()
    for doc in docs:
        labels.update(self.moves.get_doc_labels(doc))
    resized = False
    # Iterate in sorted order: plain set iteration over strings depends on
    # hash seeding, which would make the order in which new actions are
    # registered vary between runs.
    for label in sorted(labels):
        for action in self.moves.action_types:
            if self.moves.add_action(action, label):
                self.vocab.strings.add(label)
                resized = True
    if not resized:
        return 0
    # Resize once, after all new actions for the batch are registered.
    self._resize()
    return 1
def _resize(self): def _resize(self):
self.model.attrs["resize_output"](self.model, self.moves.n_moves) self.model.attrs["resize_output"](self.model, self.moves.n_moves)
if self._rehearsal_model not in (True, False, None): if self._rehearsal_model not in (True, False, None):
@ -188,9 +205,9 @@ cdef class Parser(TrainablePipe):
def predict(self, docs): def predict(self, docs):
if isinstance(docs, Doc): if isinstance(docs, Doc):
docs = [docs] docs = [docs]
self._ensure_labels_are_added(docs)
if not any(len(doc) for doc in docs): if not any(len(doc) for doc in docs):
result = self.moves.init_batch(docs) result = self.moves.init_batch(docs)
self._resize()
return result return result
if self.cfg["beam_width"] == 1: if self.cfg["beam_width"] == 1:
return self.greedy_parse(docs, drop=0.0) return self.greedy_parse(docs, drop=0.0)
@ -207,10 +224,6 @@ cdef class Parser(TrainablePipe):
cdef StateClass state cdef StateClass state
set_dropout_rate(self.model, drop) set_dropout_rate(self.model, drop)
batch = self.moves.init_batch(docs) batch = self.moves.init_batch(docs)
# This is pretty dirty, but the NER can resize itself in init_batch,
# if labels are missing. We therefore have to check whether we need to
# expand our model output.
self._resize()
model = self.model.predict(docs) model = self.model.predict(docs)
weights = get_c_weights(model) weights = get_c_weights(model)
for state in batch: for state in batch:
@ -234,10 +247,6 @@ cdef class Parser(TrainablePipe):
beam_width, beam_width,
density=beam_density density=beam_density
) )
# This is pretty dirty, but the NER can resize itself in init_batch,
# if labels are missing. We therefore have to check whether we need to
# expand our model output.
self._resize()
model = self.model.predict(docs) model = self.model.predict(docs)
while not batch.is_done: while not batch.is_done:
states = batch.get_unfinished_states() states = batch.get_unfinished_states()
@ -314,6 +323,9 @@ cdef class Parser(TrainablePipe):
losses = {} losses = {}
losses.setdefault(self.name, 0.) losses.setdefault(self.name, 0.)
validate_examples(examples, "Parser.update") validate_examples(examples, "Parser.update")
self._ensure_labels_are_added(
[eg.x for eg in examples] + [eg.y for eg in examples]
)
for multitask in self._multitasks: for multitask in self._multitasks:
multitask.update(examples, drop=drop, sgd=sgd) multitask.update(examples, drop=drop, sgd=sgd)

View File

@ -1,6 +1,7 @@
import pytest import pytest
from thinc.api import Adam, fix_random_seed from thinc.api import Adam, fix_random_seed
from spacy import registry from spacy import registry
from spacy.language import Language
from spacy.attrs import NORM from spacy.attrs import NORM
from spacy.vocab import Vocab from spacy.vocab import Vocab
from spacy.training import Example from spacy.training import Example
@ -123,3 +124,28 @@ def test_add_label_get_label(pipe_cls, n_moves, model_config):
assert len(pipe.move_names) == len(labels) * n_moves assert len(pipe.move_names) == len(labels) * n_moves
pipe_labels = sorted(list(pipe.labels)) pipe_labels = sorted(list(pipe.labels))
assert pipe_labels == labels assert pipe_labels == labels
def test_ner_labels_added_implicitly_on_predict():
    """Predicting on a Doc that carries an unseen entity label adds it."""
    nlp = Language()
    ner = nlp.add_pipe("ner")
    for label in ["A", "B", "C"]:
        ner.add_label(label)
    nlp.initialize()
    # "D" was never passed to add_label; it only appears on the Doc's ents.
    doc = Doc(nlp.vocab, words=["hello", "world"], ents=["B-D", "O"])
    # Precondition, mirroring the update test: the label is not known yet.
    assert "D" not in ner.labels
    ner(doc)
    assert [t.ent_type_ for t in doc] == ["D", ""]
    assert "D" in ner.labels
def test_ner_labels_added_implicitly_on_update():
    """Updating on an Example whose reference has an unseen label adds it."""
    nlp = Language()
    ner = nlp.add_pipe("ner")
    for label in ["A", "B", "C"]:
        ner.add_label(label)
    nlp.initialize()
    # "D" was never passed to add_label; it only appears on the reference Doc.
    doc = Doc(nlp.vocab, words=["hello", "world"], ents=["B-D", "O"])
    # The predicted side is a fresh, unannotated Doc built from the same text.
    example = Example(nlp.make_doc(doc.text), doc)
    assert "D" not in ner.labels
    nlp.update([example])
    assert "D" in ner.labels