Mirror of https://github.com/explosion/spaCy.git (synced 2025-02-24 07:30:52 +03:00)
Merge pull request #5062 from svlandeg/bugfix/merge-conflicts
Fix sync between master and develop
This commit is contained in:
commit f39ddda193
@@ -175,12 +175,10 @@ def main(
                 kb=kb,
                 labels_discard=labels_discard,
             )
-            docs, golds = zip(*train_batch)
             try:
                 with nlp.disable_pipes(*other_pipes):
                     nlp.update(
-                        docs=docs,
-                        golds=golds,
+                        examples=train_batch,
                         sgd=optimizer,
                         drop=dropout,
                         losses=losses,
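The only substantive change in this hunk is the switch from the spaCy 2.x `nlp.update(docs, golds, ...)` call to a single `examples` argument that takes the batch directly. A minimal sketch of the two call styles on toy data (the entity-linker specifics such as `kb` and `labels_discard` are omitted; the toy batch, labels and dropout value are illustrative only):

    import spacy

    nlp = spacy.blank("en")
    ner = nlp.create_pipe("ner")
    ner.add_label("PERSON")
    nlp.add_pipe(ner)
    optimizer = nlp.begin_training()
    losses = {}

    # One toy (text, annotations) pair standing in for a real training batch
    train_batch = [("Ada Lovelace wrote programs.", {"entities": [(0, 12, "PERSON")]})]

    # master (spaCy 2.x): the batch is unzipped into parallel docs and golds
    docs, golds = zip(*train_batch)
    nlp.update(docs=docs, golds=golds, sgd=optimizer, drop=0.2, losses=losses)

    # develop (after this merge): the batch is passed as a single `examples` argument
    # nlp.update(examples=train_batch, sgd=optimizer, drop=0.2, losses=losses)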
@@ -28,13 +28,6 @@ def train(
     pipeline: ("Comma-separated names of pipeline components", "option", "p", str) = "tagger,parser,ner",
     vectors: ("Model to load vectors from", "option", "v", str) = None,
     replace_components: ("Replace components from base model", "flag", "R", bool) = False,
-    width: ("Width of CNN layers of Tok2Vec component", "option", "cw", int) = 96,
-    conv_depth: ("Depth of CNN layers of Tok2Vec component", "option", "cd", int) = 4,
-    cnn_window: ("Window size for CNN layers of Tok2Vec component", "option", "cW", int) = 1,
-    cnn_pieces: ("Maxout size for CNN layers of Tok2Vec component. 1 for Mish", "option", "cP", int) = 3,
-    use_chars: ("Whether to use character-based embedding of Tok2Vec component", "flag", "chr", bool) = False,
-    bilstm_depth: ("Depth of BiLSTM layers of Tok2Vec component (requires PyTorch)", "option", "lstm", int) = 0,
-    embed_rows: ("Number of embedding rows of Tok2Vec component", "option", "er", int) = 2000,
     n_iter: ("Number of iterations", "option", "n", int) = 30,
     n_early_stopping: ("Maximum number of training epochs without dev accuracy improvement", "option", "ne", int) = None,
     n_examples: ("Number of examples", "option", "ns", int) = 0,
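The removed parameters are all Tok2Vec hyperparameters that were exposed as CLI flags on master. For readers unfamiliar with the annotation style in this signature: these are plac-style annotations, where each tuple is (help text, argument kind, short flag, type) and the assignment supplies the default. A small, hypothetical standalone sketch of the same pattern (not spaCy code):

    import plac

    def demo(
        # each annotation is ("help text", kind, short flag, type); the assignment is the default
        lang: ("Model language", "positional", None, str),
        n_iter: ("Number of iterations", "option", "n", int) = 30,
    ):
        print(lang, n_iter)

    if __name__ == "__main__":
        plac.call(demo)  # e.g. `python demo.py en -n 10`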
@@ -232,14 +225,7 @@ def train(
     else:
         # Start with a blank model, call begin_training
         cfg = {"device": use_gpu}
-        cfg["conv_depth"] = conv_depth
-        cfg["token_vector_width"] = width
-        cfg["bilstm_depth"] = bilstm_depth
-        cfg["cnn_maxout_pieces"] = cnn_pieces
-        cfg["embed_size"] = embed_rows
-        cfg["conv_window"] = cnn_window
-        cfg["subword_features"] = not use_chars
-        optimizer = nlp.begin_training(lambda: corpus.train_tuples, **cfg)
+        optimizer = nlp.begin_training(lambda: corpus.train_examples, **cfg)
         nlp._optimizer = None

         # Load in pretrained weights
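Here the Tok2Vec cfg entries disappear along with the corresponding CLI flags, and begin_training now receives a callable yielding the corpus's Example objects (train_examples) instead of the old training tuples. A hedged sketch of the pattern, assuming the develop-branch GoldCorpus exposes train_examples as used above:

    from spacy.gold import GoldCorpus  # import path used on this branch at the time

    # train_path, dev_path, use_gpu and nlp come from the surrounding train() code;
    # this is a sketch of the call pattern, not runnable on its own.
    corpus = GoldCorpus(train_path, dev_path)
    cfg = {"device": use_gpu}
    # develop: begin_training takes a callable that returns Example objects
    optimizer = nlp.begin_training(lambda: corpus.train_examples, **cfg)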
@@ -362,11 +348,9 @@ def train(
             for batch in util.minibatch_by_words(train_data, size=batch_sizes):
                 if not batch:
                     continue
-                docs, golds = zip(*batch)
                 try:
                     nlp.update(
-                        docs,
-                        golds,
+                        batch,
                         sgd=optimizer,
                         drop=next(dropout_rates),
                         losses=losses,
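Same API change as in the first hunk: the batch produced by util.minibatch_by_words is now handed to nlp.update as-is rather than unzipped into docs and golds. A condensed view of the updated loop (train_data, batch_sizes, optimizer, dropout_rates and losses are defined earlier in train() and not repeated here):

    for batch in util.minibatch_by_words(train_data, size=batch_sizes):
        if not batch:
            continue
        nlp.update(batch, sgd=optimizer, drop=next(dropout_rates), losses=losses)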
@@ -609,7 +593,7 @@ def _get_metrics(component):
     elif component == "tagger":
         return ("tags_acc",)
     elif component == "ner":
-        return ("ents_f", "ents_p", "ents_r", "enty_per_type")
+        return ("ents_f", "ents_p", "ents_r", "ents_per_type")
     elif component == "sentrec":
         return ("sent_f", "sent_p", "sent_r")
     elif component == "textcat":
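The only change in this last hunk is a typo fix in the NER metric name (enty_per_type becomes ents_per_type). These names correspond to keys in spaCy's Scorer results, so a hypothetical usage sketch (not the CLI's exact code) might look like:

    # Pick out the per-component scores that _get_metrics() names from a Scorer.
    scorer = nlp.evaluate(dev_docs)   # dev_docs: (doc, gold) pairs from the dev set
    metrics = _get_metrics("ner")     # ("ents_f", "ents_p", "ents_r", "ents_per_type")
    ner_scores = {name: scorer.scores.get(name) for name in metrics}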