Mirror of https://github.com/explosion/spaCy.git (synced 2024-11-10 19:57:17 +03:00)
Add -t2v argument to train_textcat script
This commit is contained in:
Parent: 764359c952
Commit: 4e3ed2ea88
|
@ -24,8 +24,9 @@ from spacy.util import minibatch, compounding
|
||||||
output_dir=("Optional output directory", "option", "o", Path),
|
output_dir=("Optional output directory", "option", "o", Path),
|
||||||
n_texts=("Number of texts to train from", "option", "t", int),
|
n_texts=("Number of texts to train from", "option", "t", int),
|
||||||
n_iter=("Number of training iterations", "option", "n", int),
|
n_iter=("Number of training iterations", "option", "n", int),
|
||||||
|
init_tok2vec=("Pretrained tok2vec weights", "option", "t2v", Path)
|
||||||
)
|
)
|
||||||
def main(model=None, output_dir=None, n_iter=20, n_texts=2000):
|
def main(model=None, output_dir=None, n_iter=20, n_texts=2000, init_tok2vec=None):
|
||||||
if output_dir is not None:
|
if output_dir is not None:
|
||||||
output_dir = Path(output_dir)
|
output_dir = Path(output_dir)
|
||||||
if not output_dir.exists():
|
if not output_dir.exists():
|
||||||
|
@ -67,6 +68,9 @@ def main(model=None, output_dir=None, n_iter=20, n_texts=2000):
|
||||||
other_pipes = [pipe for pipe in nlp.pipe_names if pipe != "textcat"]
|
other_pipes = [pipe for pipe in nlp.pipe_names if pipe != "textcat"]
|
||||||
with nlp.disable_pipes(*other_pipes): # only train textcat
|
with nlp.disable_pipes(*other_pipes): # only train textcat
|
||||||
optimizer = nlp.begin_training()
|
optimizer = nlp.begin_training()
|
||||||
|
if init_tok2vec is not None:
|
||||||
|
with init_tok2vec.open("rb") as file_:
|
||||||
|
textcat.model.tok2vec.from_bytes(file_.read())
|
||||||
print("Training the model...")
|
print("Training the model...")
|
||||||
print("{:^5}\t{:^5}\t{:^5}\t{:^5}".format("LOSS", "P", "R", "F"))
|
print("{:^5}\t{:^5}\t{:^5}\t{:^5}".format("LOSS", "P", "R", "F"))
|
||||||
for i in range(n_iter):
|
for i in range(n_iter):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user