mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-11 17:56:30 +03:00
Update sentence recognizer (#5109)
* Update sentence recognizer * rename `sentrec` to `senter` * use `spacy.HashEmbedCNN.v1` by default * update to follow `Tagger` modifications * remove component methods that can be inherited from `Tagger` * add simple initialization and overfitting pipeline tests * Update serialization test for senter
This commit is contained in:
parent
6ac9fc0619
commit
c95ce96c44
|
@ -157,6 +157,8 @@ def train(
|
|||
config_loc = default_dir / "ner_defaults.cfg"
|
||||
elif pipe == "textcat":
|
||||
config_loc = default_dir / "textcat_defaults.cfg"
|
||||
elif pipe == "senter":
|
||||
config_loc = default_dir / "senter_defaults.cfg"
|
||||
else:
|
||||
raise ValueError(f"Component {pipe} currently not supported.")
|
||||
pipe_cfg = util.load_config(config_loc, create_objects=False)
|
||||
|
@ -221,6 +223,8 @@ def train(
|
|||
config_loc = default_dir / "ner_defaults.cfg"
|
||||
elif pipe == "textcat":
|
||||
config_loc = default_dir / "textcat_defaults.cfg"
|
||||
elif pipe == "senter":
|
||||
config_loc = default_dir / "senter_defaults.cfg"
|
||||
else:
|
||||
raise ValueError(f"Component {pipe} currently not supported.")
|
||||
pipe_cfg = util.load_config(config_loc, create_objects=False)
|
||||
|
@ -559,7 +563,7 @@ def _score_for_model(meta):
|
|||
mean_acc.append((acc["ents_p"] + acc["ents_r"] + acc["ents_f"]) / 3)
|
||||
if "textcat" in pipes:
|
||||
mean_acc.append(acc["textcat_score"])
|
||||
if "sentrec" in pipes:
|
||||
if "senter" in pipes:
|
||||
mean_acc.append((acc["sent_p"] + acc["sent_r"] + acc["sent_f"]) / 3)
|
||||
return sum(mean_acc) / len(mean_acc)
|
||||
|
||||
|
@ -638,7 +642,7 @@ def _get_metrics(component):
|
|||
return ("tags_acc",)
|
||||
elif component == "ner":
|
||||
return ("ents_f", "ents_p", "ents_r", "ents_per_type")
|
||||
elif component == "sentrec":
|
||||
elif component == "senter":
|
||||
return ("sent_f", "sent_p", "sent_r")
|
||||
elif component == "textcat":
|
||||
return ("textcat_score",)
|
||||
|
@ -665,9 +669,9 @@ def _configure_training_output(pipeline, use_gpu, has_beam_widths):
|
|||
elif pipe == "textcat":
|
||||
row_head.extend(["Textcat Loss", "Textcat"])
|
||||
output_stats.extend(["textcat_loss", "textcat_score"])
|
||||
elif pipe == "sentrec":
|
||||
row_head.extend(["Sentrec Loss", "Sent P", "Sent R", "Sent F"])
|
||||
output_stats.extend(["sentrec_loss", "sent_p", "sent_r", "sent_f"])
|
||||
elif pipe == "senter":
|
||||
row_head.extend(["Senter Loss", "Sent P", "Sent R", "Sent F"])
|
||||
output_stats.extend(["senter_loss", "sent_p", "sent_r", "sent_f"])
|
||||
row_head.extend(["Token %", "CPU WPS"])
|
||||
output_stats.extend(["token_acc", "cpu_wps"])
|
||||
|
||||
|
@ -693,7 +697,7 @@ def _get_progress(
|
|||
scores["ner_loss"] = losses.get("ner", 0.0)
|
||||
scores["tag_loss"] = losses.get("tagger", 0.0)
|
||||
scores["textcat_loss"] = losses.get("textcat", 0.0)
|
||||
scores["sentrec_loss"] = losses.get("sentrec", 0.0)
|
||||
scores["senter_loss"] = losses.get("senter", 0.0)
|
||||
scores["cpu_wps"] = cpu_wps
|
||||
scores["gpu_wps"] = gpu_wps or 0.0
|
||||
scores.update(dev_scores)
|
||||
|
|
|
@ -190,7 +190,7 @@ class Language(object):
|
|||
default_textcat_config,
|
||||
default_nel_config,
|
||||
default_morphologizer_config,
|
||||
default_sentrec_config,
|
||||
default_senter_config,
|
||||
default_tensorizer_config,
|
||||
default_tok2vec_config,
|
||||
)
|
||||
|
@ -202,7 +202,7 @@ class Language(object):
|
|||
"textcat": default_textcat_config(),
|
||||
"entity_linker": default_nel_config(),
|
||||
"morphologizer": default_morphologizer_config(),
|
||||
"sentrec": default_sentrec_config(),
|
||||
"senter": default_senter_config(),
|
||||
"tensorizer": default_tensorizer_config(),
|
||||
"tok2vec": default_tok2vec_config(),
|
||||
}
|
||||
|
@ -267,8 +267,8 @@ class Language(object):
|
|||
return self.get_pipe("entity_linker")
|
||||
|
||||
@property
|
||||
def sentrec(self):
|
||||
return self.get_pipe("sentrec")
|
||||
def senter(self):
|
||||
return self.get_pipe("senter")
|
||||
|
||||
@property
|
||||
def matcher(self):
|
||||
|
|
|
@ -43,13 +43,13 @@ def default_ner():
|
|||
return util.load_config(loc, create_objects=True)["model"]
|
||||
|
||||
|
||||
def default_sentrec_config():
|
||||
loc = Path(__file__).parent / "sentrec_defaults.cfg"
|
||||
def default_senter_config():
|
||||
loc = Path(__file__).parent / "senter_defaults.cfg"
|
||||
return util.load_config(loc, create_objects=False)
|
||||
|
||||
|
||||
def default_sentrec():
|
||||
loc = Path(__file__).parent / "sentrec_defaults.cfg"
|
||||
def default_senter():
|
||||
loc = Path(__file__).parent / "senter_defaults.cfg"
|
||||
return util.load_config(loc, create_objects=True)["model"]
|
||||
|
||||
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
@architectures = "spacy.Tagger.v1"
|
||||
|
||||
[model.tok2vec]
|
||||
@architectures = "spacy.HashCharEmbedCNN.v1"
|
||||
@architectures = "spacy.HashEmbedCNN.v1"
|
||||
pretrained_vectors = null
|
||||
width = 12
|
||||
depth = 1
|
||||
|
@ -10,5 +10,3 @@ embed_size = 2000
|
|||
window_size = 1
|
||||
maxout_pieces = 2
|
||||
subword_features = true
|
||||
nM = 64
|
||||
nC = 8
|
|
@ -650,7 +650,7 @@ class Tagger(Pipe):
|
|||
return self
|
||||
|
||||
|
||||
@component("sentrec", assigns=["token.is_sent_start"])
|
||||
@component("senter", assigns=["token.is_sent_start"])
|
||||
class SentenceRecognizer(Tagger):
|
||||
"""Pipeline component for sentence segmentation.
|
||||
|
||||
|
@ -670,7 +670,7 @@ class SentenceRecognizer(Tagger):
|
|||
# are 0
|
||||
return tuple(["I", "S"])
|
||||
|
||||
def set_annotations(self, docs, batch_tag_ids, **_):
|
||||
def set_annotations(self, docs, batch_tag_ids):
|
||||
if isinstance(docs, Doc):
|
||||
docs = [docs]
|
||||
cdef Doc doc
|
||||
|
@ -686,24 +686,6 @@ class SentenceRecognizer(Tagger):
|
|||
else:
|
||||
doc.c[j].sent_start = -1
|
||||
|
||||
def update(self, examples, drop=0., sgd=None, losses=None):
|
||||
examples = Example.to_example_objects(examples)
|
||||
if losses is not None and self.name not in losses:
|
||||
losses[self.name] = 0.
|
||||
|
||||
if not any(len(ex.doc) if ex.doc else 0 for ex in examples):
|
||||
# Handle cases where there are no tokens in any docs.
|
||||
return
|
||||
set_dropout_rate(self.model, drop)
|
||||
tag_scores, bp_tag_scores = self.model.begin_update([ex.doc for ex in examples])
|
||||
loss, d_tag_scores = self.get_loss(examples, tag_scores)
|
||||
bp_tag_scores(d_tag_scores)
|
||||
if sgd is not None:
|
||||
self.model.finish_update(sgd)
|
||||
|
||||
if losses is not None:
|
||||
losses[self.name] += loss
|
||||
|
||||
def get_loss(self, examples, scores):
|
||||
scores = self.model.ops.flatten(scores)
|
||||
tag_index = range(len(self.labels))
|
||||
|
@ -732,9 +714,9 @@ class SentenceRecognizer(Tagger):
|
|||
|
||||
def begin_training(self, get_examples=lambda: [], pipeline=None, sgd=None,
|
||||
**kwargs):
|
||||
cdef Vocab vocab = self.vocab
|
||||
self.set_output(len(self.labels))
|
||||
self.model.initialize()
|
||||
link_vectors_to_models(self.vocab)
|
||||
if sgd is None:
|
||||
sgd = self.create_optimizer()
|
||||
return sgd
|
||||
|
@ -742,10 +724,6 @@ class SentenceRecognizer(Tagger):
|
|||
def add_label(self, label, values=None):
|
||||
raise NotImplementedError
|
||||
|
||||
def use_params(self, params):
|
||||
with self.model.use_params(params):
|
||||
yield
|
||||
|
||||
def to_bytes(self, exclude=tuple(), **kwargs):
|
||||
serialize = {}
|
||||
serialize["model"] = self.model.to_bytes
|
||||
|
|
52
spacy/tests/pipeline/test_senter.py
Normal file
52
spacy/tests/pipeline/test_senter.py
Normal file
|
@ -0,0 +1,52 @@
|
|||
import pytest
|
||||
|
||||
from spacy import util
|
||||
from spacy.lang.en import English
|
||||
from spacy.language import Language
|
||||
from spacy.tests.util import make_tempdir
|
||||
|
||||
|
||||
def test_label_types():
|
||||
nlp = Language()
|
||||
nlp.add_pipe(nlp.create_pipe("senter"))
|
||||
with pytest.raises(NotImplementedError):
|
||||
nlp.get_pipe("senter").add_label("A")
|
||||
|
||||
SENT_STARTS = [0] * 14
|
||||
SENT_STARTS[0] = 1
|
||||
SENT_STARTS[5] = 1
|
||||
SENT_STARTS[9] = 1
|
||||
|
||||
TRAIN_DATA = [
|
||||
("I like green eggs. Eat blue ham. I like purple eggs.", {"sent_starts": SENT_STARTS}),
|
||||
("She likes purple eggs. They hate ham. You like yellow eggs.", {"sent_starts": SENT_STARTS}),
|
||||
]
|
||||
|
||||
|
||||
def test_overfitting_IO():
|
||||
# Simple test to try and quickly overfit the senter - ensuring the ML models work correctly
|
||||
nlp = English()
|
||||
senter = nlp.create_pipe("senter")
|
||||
nlp.add_pipe(senter)
|
||||
optimizer = nlp.begin_training()
|
||||
|
||||
for i in range(200):
|
||||
losses = {}
|
||||
nlp.update(TRAIN_DATA, sgd=optimizer, losses=losses)
|
||||
assert losses["senter"] < 0.0001
|
||||
|
||||
# test the trained model
|
||||
test_text = "I like eggs. There is ham. She likes ham."
|
||||
doc = nlp(test_text)
|
||||
gold_sent_starts = [0] * 12
|
||||
gold_sent_starts[0] = 1
|
||||
gold_sent_starts[4] = 1
|
||||
gold_sent_starts[8] = 1
|
||||
assert gold_sent_starts == [int(t.is_sent_start) for t in doc]
|
||||
|
||||
# Also test the results are still the same after IO
|
||||
with make_tempdir() as tmp_dir:
|
||||
nlp.to_disk(tmp_dir)
|
||||
nlp2 = util.load_model_from_path(tmp_dir)
|
||||
doc2 = nlp2(test_text)
|
||||
assert gold_sent_starts == [int(t.is_sent_start) for t in doc2]
|
|
@ -2,7 +2,7 @@ import pytest
|
|||
from spacy.pipeline import Tagger, DependencyParser, EntityRecognizer
|
||||
from spacy.pipeline import Tensorizer, TextCategorizer, SentenceRecognizer
|
||||
from spacy.ml.models.defaults import default_parser, default_tensorizer, default_tagger
|
||||
from spacy.ml.models.defaults import default_textcat, default_sentrec
|
||||
from spacy.ml.models.defaults import default_textcat, default_senter
|
||||
|
||||
from ..util import make_tempdir
|
||||
|
||||
|
@ -146,7 +146,7 @@ def test_serialize_pipe_exclude(en_vocab, Parser):
|
|||
|
||||
|
||||
def test_serialize_sentencerecognizer(en_vocab):
|
||||
sr = SentenceRecognizer(en_vocab, default_sentrec())
|
||||
sr = SentenceRecognizer(en_vocab, default_senter())
|
||||
sr_b = sr.to_bytes()
|
||||
sr_d = SentenceRecognizer(en_vocab, default_sentrec()).from_bytes(sr_b)
|
||||
sr_d = SentenceRecognizer(en_vocab, default_senter()).from_bytes(sr_b)
|
||||
assert sr.to_bytes() == sr_d.to_bytes()
|
||||
|
|
Loading…
Reference in New Issue
Block a user