Mirror of https://github.com/explosion/spaCy.git, synced 2024-12-26 09:56:28 +03:00
Merge pull request #6827 from explosion/feature/add-labels-implicitly

Commit abb24fdc0f
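In outline: each transition system gains a get_doc_labels method that collects the labels a batch of documents requires, and the base Parser gains _ensure_labels_are_added, which is called from predict() and update() so that unseen labels are registered (and the model output resized) up front, replacing the previous ad-hoc resize inside init_batch. The dependency parser opts out, because it backs rare labels off to 'dep' via a frequency cut-off instead.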
@@ -614,10 +614,22 @@ cdef class ArcEager(TransitionSystem):
         actions[LEFT].setdefault('dep', 0)
         return actions

+    @property
+    def builtin_labels(self):
+        return ["ROOT", "dep"]
+
     @property
     def action_types(self):
         return (SHIFT, REDUCE, LEFT, RIGHT, BREAK)

+    def get_doc_labels(self, doc):
+        """Get the labels required for a given Doc."""
+        labels = set(self.builtin_labels)
+        for token in doc:
+            if token.dep_:
+                labels.add(token.dep_)
+        return labels
+
     def transition(self, StateClass state, action):
         cdef Transition t = self.lookup_transition(action)
         t.do(state.c, t.label)
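A minimal sketch of what ArcEager.get_doc_labels would collect, assuming a Doc constructed with gold dependency annotations (the words, heads, and deps here are purely illustrative):

    from spacy.vocab import Vocab
    from spacy.tokens import Doc

    # Two tokens; "intj" is the only non-builtin dependency label.
    doc = Doc(Vocab(), words=["hello", "world"],
              heads=[1, 1], deps=["intj", "ROOT"])
    # get_doc_labels unions the builtin labels with those on the tokens:
    # {"ROOT", "dep", "intj"}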
@@ -126,6 +126,13 @@ cdef class BiluoPushDown(TransitionSystem):
     def action_types(self):
         return (BEGIN, IN, LAST, UNIT, OUT)

+    def get_doc_labels(self, doc):
+        labels = set()
+        for token in doc:
+            if token.ent_type:
+                labels.add(token.ent_type_)
+        return labels
+
     def move_name(self, int move, attr_t label):
         if move == OUT:
             return 'O'
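The NER counterpart has no builtin labels; it simply gathers the entity types present on the tokens. A rough illustration, assuming a Doc annotated with IOB entity tags:

    from spacy.vocab import Vocab
    from spacy.tokens import Doc

    doc = Doc(Vocab(), words=["Apple", "rocks"], ents=["B-ORG", "O"])
    # get_doc_labels returns {"ORG"} for this doc.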
@@ -277,3 +277,10 @@ cdef class DependencyParser(Parser):
             head_scores.append(score_head_dict)
             label_scores.append(score_label_dict)
         return head_scores, label_scores
+
+    def _ensure_labels_are_added(self, docs):
+        # This gives the parser a chance to add labels it's missing for a batch
+        # of documents. However, this isn't desirable for the dependency parser,
+        # because we instead have a label frequency cut-off and back off rare
+        # labels to 'dep'.
+        pass
@@ -132,6 +132,23 @@ cdef class Parser(TrainablePipe):
             return 1
         return 0

+    def _ensure_labels_are_added(self, docs):
+        """Ensure that all labels for a batch of docs are added."""
+        resized = False
+        labels = set()
+        for doc in docs:
+            labels.update(self.moves.get_doc_labels(doc))
+        for label in labels:
+            for action in self.moves.action_types:
+                added = self.moves.add_action(action, label)
+                if added:
+                    self.vocab.strings.add(label)
+                    resized = True
+        if resized:
+            self._resize()
+            return 1
+        return 0
+
     def _resize(self):
         self.model.attrs["resize_output"](self.model, self.moves.n_moves)
         if self._rehearsal_model not in (True, False, None):
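A minimal usage sketch of the contract, assuming an initialized NER pipe from this branch (the labels and words mirror the tests at the bottom of this diff):

    from spacy.language import Language
    from spacy.tokens import Doc

    nlp = Language()
    ner = nlp.add_pipe("ner")
    ner.add_label("A")
    nlp.initialize()
    doc = Doc(nlp.vocab, words=["hello", "world"], ents=["B-D", "O"])
    # First call: "D" is new, actions are added and the model is resized.
    assert ner._ensure_labels_are_added([doc]) == 1
    # Second call: nothing new, so no resize.
    assert ner._ensure_labels_are_added([doc]) == 0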
@@ -188,9 +205,9 @@ cdef class Parser(TrainablePipe):
     def predict(self, docs):
         if isinstance(docs, Doc):
             docs = [docs]
+        self._ensure_labels_are_added(docs)
         if not any(len(doc) for doc in docs):
             result = self.moves.init_batch(docs)
-            self._resize()
             return result
         if self.cfg["beam_width"] == 1:
             return self.greedy_parse(docs, drop=0.0)
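With labels ensured at the top of predict(), the defensive _resize() calls after init_batch become redundant, which is what the next two hunks remove.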
@@ -207,10 +224,6 @@ cdef class Parser(TrainablePipe):
         cdef StateClass state
         set_dropout_rate(self.model, drop)
         batch = self.moves.init_batch(docs)
-        # This is pretty dirty, but the NER can resize itself in init_batch,
-        # if labels are missing. We therefore have to check whether we need to
-        # expand our model output.
-        self._resize()
         model = self.model.predict(docs)
         weights = get_c_weights(model)
         for state in batch:
@@ -234,10 +247,6 @@ cdef class Parser(TrainablePipe):
             beam_width,
             density=beam_density
         )
-        # This is pretty dirty, but the NER can resize itself in init_batch,
-        # if labels are missing. We therefore have to check whether we need to
-        # expand our model output.
-        self._resize()
         model = self.model.predict(docs)
         while not batch.is_done:
             states = batch.get_unfinished_states()
@@ -314,6 +323,9 @@ cdef class Parser(TrainablePipe):
         losses = {}
         losses.setdefault(self.name, 0.)
         validate_examples(examples, "Parser.update")
+        self._ensure_labels_are_added(
+            [eg.x for eg in examples] + [eg.y for eg in examples]
+        )
         for multitask in self._multitasks:
             multitask.update(examples, drop=drop, sgd=sgd)

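In update(), labels are gathered from both the predicted docs (eg.x) and the reference docs (eg.y), so labels that appear only in the gold annotations are also added before the gradient step.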
@@ -1,6 +1,7 @@
 import pytest
 from thinc.api import Adam, fix_random_seed
 from spacy import registry
+from spacy.language import Language
 from spacy.attrs import NORM
 from spacy.vocab import Vocab
 from spacy.training import Example
@@ -123,3 +124,28 @@ def test_add_label_get_label(pipe_cls, n_moves, model_config):
     assert len(pipe.move_names) == len(labels) * n_moves
     pipe_labels = sorted(list(pipe.labels))
     assert pipe_labels == labels
+
+
+def test_ner_labels_added_implicitly_on_predict():
+    nlp = Language()
+    ner = nlp.add_pipe("ner")
+    for label in ["A", "B", "C"]:
+        ner.add_label(label)
+    nlp.initialize()
+    doc = Doc(nlp.vocab, words=["hello", "world"], ents=["B-D", "O"])
+    ner(doc)
+    assert [t.ent_type_ for t in doc] == ["D", ""]
+    assert "D" in ner.labels
+
+
+def test_ner_labels_added_implicitly_on_update():
+    nlp = Language()
+    ner = nlp.add_pipe("ner")
+    for label in ["A", "B", "C"]:
+        ner.add_label(label)
+    nlp.initialize()
+    doc = Doc(nlp.vocab, words=["hello", "world"], ents=["B-D", "O"])
+    example = Example(nlp.make_doc(doc.text), doc)
+    assert "D" not in ner.labels
+    nlp.update([example])
+    assert "D" in ner.labels