mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
fix bugs from porting master to develop
This commit is contained in:
parent
192b8d45a1
commit
fc6e34c3a1
|
@ -175,12 +175,10 @@ def main(
|
|||
kb=kb,
|
||||
labels_discard=labels_discard,
|
||||
)
|
||||
docs, golds = zip(*train_batch)
|
||||
try:
|
||||
with nlp.disable_pipes(*other_pipes):
|
||||
nlp.update(
|
||||
docs=docs,
|
||||
golds=golds,
|
||||
examples=train_batch,
|
||||
sgd=optimizer,
|
||||
drop=dropout,
|
||||
losses=losses,
|
||||
|
|
|
@ -28,13 +28,6 @@ def train(
|
|||
pipeline: ("Comma-separated names of pipeline components", "option", "p", str) = "tagger,parser,ner",
|
||||
vectors: ("Model to load vectors from", "option", "v", str) = None,
|
||||
replace_components: ("Replace components from base model", "flag", "R", bool) = False,
|
||||
width: ("Width of CNN layers of Tok2Vec component", "option", "cw", int) = 96,
|
||||
conv_depth: ("Depth of CNN layers of Tok2Vec component", "option", "cd", int) = 4,
|
||||
cnn_window: ("Window size for CNN layers of Tok2Vec component", "option", "cW", int) = 1,
|
||||
cnn_pieces: ("Maxout size for CNN layers of Tok2Vec component. 1 for Mish", "option", "cP", int) = 3,
|
||||
use_chars: ("Whether to use character-based embedding of Tok2Vec component", "flag", "chr", bool) = False,
|
||||
bilstm_depth: ("Depth of BiLSTM layers of Tok2Vec component (requires PyTorch)", "option", "lstm", int) = 0,
|
||||
embed_rows: ("Number of embedding rows of Tok2Vec component", "option", "er", int) = 2000,
|
||||
n_iter: ("Number of iterations", "option", "n", int) = 30,
|
||||
n_early_stopping: ("Maximum number of training epochs without dev accuracy improvement", "option", "ne", int) = None,
|
||||
n_examples: ("Number of examples", "option", "ns", int) = 0,
|
||||
|
@ -232,14 +225,7 @@ def train(
|
|||
else:
|
||||
# Start with a blank model, call begin_training
|
||||
cfg = {"device": use_gpu}
|
||||
cfg["conv_depth"] = conv_depth
|
||||
cfg["token_vector_width"] = width
|
||||
cfg["bilstm_depth"] = bilstm_depth
|
||||
cfg["cnn_maxout_pieces"] = cnn_pieces
|
||||
cfg["embed_size"] = embed_rows
|
||||
cfg["conv_window"] = cnn_window
|
||||
cfg["subword_features"] = not use_chars
|
||||
optimizer = nlp.begin_training(lambda: corpus.train_tuples, **cfg)
|
||||
optimizer = nlp.begin_training(lambda: corpus.train_examples, **cfg)
|
||||
nlp._optimizer = None
|
||||
|
||||
# Load in pretrained weights
|
||||
|
@ -362,11 +348,9 @@ def train(
|
|||
for batch in util.minibatch_by_words(train_data, size=batch_sizes):
|
||||
if not batch:
|
||||
continue
|
||||
docs, golds = zip(*batch)
|
||||
try:
|
||||
nlp.update(
|
||||
docs,
|
||||
golds,
|
||||
batch,
|
||||
sgd=optimizer,
|
||||
drop=next(dropout_rates),
|
||||
losses=losses,
|
||||
|
@ -609,7 +593,7 @@ def _get_metrics(component):
|
|||
elif component == "tagger":
|
||||
return ("tags_acc",)
|
||||
elif component == "ner":
|
||||
return ("ents_f", "ents_p", "ents_r", "enty_per_type")
|
||||
return ("ents_f", "ents_p", "ents_r", "ents_per_type")
|
||||
elif component == "sentrec":
|
||||
return ("sent_f", "sent_p", "sent_r")
|
||||
elif component == "textcat":
|
||||
|
|
Loading…
Reference in New Issue
Block a user