mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-30 19:24:07 +03:00
* Add updated unsupervised_train script, from the wsd directory
This commit is contained in:
parent
1d21eebda4
commit
eb3057d806
|
@ -4,16 +4,16 @@ from __future__ import unicode_literals
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from os import path
|
from os import path
|
||||||
import shutil
|
|
||||||
import codecs
|
|
||||||
import random
|
import random
|
||||||
|
import shutil
|
||||||
|
|
||||||
import plac
|
import plac
|
||||||
import cProfile
|
|
||||||
import pstats
|
from spacy.munge.corpus import DocsDB
|
||||||
import re
|
from spacy.munge.read_semcor import read_semcor
|
||||||
|
|
||||||
from spacy.en import English
|
from spacy.en import English
|
||||||
|
from spacy.syntax.util import Config
|
||||||
|
|
||||||
|
|
||||||
def score_model(nlp, semcor_docs):
|
def score_model(nlp, semcor_docs):
|
||||||
|
@ -24,8 +24,11 @@ def score_model(nlp, semcor_docs):
|
||||||
for pnum, para in paras:
|
for pnum, para in paras:
|
||||||
for snum, sent in para:
|
for snum, sent in para:
|
||||||
words = [t.orth for t in sent]
|
words = [t.orth for t in sent]
|
||||||
|
if len(words) < 2:
|
||||||
|
continue
|
||||||
tokens = nlp.tokenizer.tokens_from_list(words)
|
tokens = nlp.tokenizer.tokens_from_list(words)
|
||||||
nlp.tagger(tokens)
|
nlp.tagger(tokens)
|
||||||
|
nlp.parser(tokens)
|
||||||
nlp.senser(tokens)
|
nlp.senser(tokens)
|
||||||
for i, token in enumerate(tokens):
|
for i, token in enumerate(tokens):
|
||||||
if '_' in sent[i].orth:
|
if '_' in sent[i].orth:
|
||||||
|
@ -33,40 +36,44 @@ def score_model(nlp, semcor_docs):
|
||||||
elif sent[i].supersense != 'NO_SENSE':
|
elif sent[i].supersense != 'NO_SENSE':
|
||||||
n_right += token.sense_ == sent[i].supersense
|
n_right += token.sense_ == sent[i].supersense
|
||||||
n_wrong += token.sense_ != sent[i].supersense
|
n_wrong += token.sense_ != sent[i].supersense
|
||||||
return n_multi, n_right, n_wrong
|
return n_right / (n_right + n_wrong)
|
||||||
|
|
||||||
|
|
||||||
def train(Language, model_dir, docs, annotations, report_every=1000, n_docs=1000):
|
def train(Language, model_dir, train_docs, dev_docs,
|
||||||
|
report_every=1000, n_docs=1000, seed=0):
|
||||||
wsd_model_dir = path.join(model_dir, 'wsd')
|
wsd_model_dir = path.join(model_dir, 'wsd')
|
||||||
if path.exists(pos_model_dir):
|
if path.exists(wsd_model_dir):
|
||||||
shutil.rmtree(pos_model_dir)
|
shutil.rmtree(wsd_model_dir)
|
||||||
os.mkdir(wsd_model_dir)
|
os.mkdir(wsd_model_dir)
|
||||||
|
|
||||||
Config.write(wsd_model_dir, 'config', seed=seed)
|
Config.write(wsd_model_dir, 'config', seed=seed)
|
||||||
|
|
||||||
nlp = Language(data_dir=model_dir)
|
nlp = Language(data_dir=model_dir, load_vectors=False)
|
||||||
|
|
||||||
for doc in corpus:
|
loss = 0
|
||||||
tokens = nlp(doc, senser=False)
|
n_tokens = 0
|
||||||
|
for i, doc in enumerate(train_docs):
|
||||||
|
tokens = nlp(doc, parse=True, entity=False)
|
||||||
loss += nlp.senser.train(tokens)
|
loss += nlp.senser.train(tokens)
|
||||||
if i and not i % report_every:
|
n_tokens += len(tokens)
|
||||||
|
if i and i % report_every == 0:
|
||||||
acc = score_model(nlp, dev_docs)
|
acc = score_model(nlp, dev_docs)
|
||||||
print loss, n_right / (n_right + n_wrong)
|
print i, loss / n_tokens, acc
|
||||||
nlp.senser.end_training()
|
nlp.senser.end_training()
|
||||||
nlp.vocab.strings.dump(path.join(model_dir, 'vocab', 'strings.txt'))
|
nlp.vocab.strings.dump(path.join(model_dir, 'vocab', 'strings.txt'))
|
||||||
|
|
||||||
|
|
||||||
@plac.annotations(
|
@plac.annotations(
|
||||||
docs_db_loc=("Location of the documents SQLite database"),
|
train_loc=("Location of the documents SQLite database"),
|
||||||
|
dev_loc=("Location of the SemCor corpus directory"),
|
||||||
model_dir=("Location of the models directory"),
|
model_dir=("Location of the models directory"),
|
||||||
n_docs=("Number of training documents", "option", "n", int),
|
n_docs=("Number of training documents", "option", "n", int),
|
||||||
verbose=("Verbose error reporting", "flag", "v", bool),
|
seed=("Random seed", "option", "s", int),
|
||||||
debug=("Debug mode", "flag", "d", bool),
|
|
||||||
)
|
)
|
||||||
def main(train_loc, dev_loc, model_dir, n_docs=0):
|
def main(train_loc, dev_loc, model_dir, n_docs=1000000, seed=0):
|
||||||
train_docs = DocsDB(train_loc)
|
train_docs = DocsDB(train_loc, limit=n_docs)
|
||||||
dev_docs = read_semcor(dev_loc)
|
dev_docs = read_semcor(dev_loc)
|
||||||
train(English, model_dir, train_docs, dev_docs, report_every=10, n_docs=1000):
|
train(English, model_dir, train_docs, dev_docs, report_every=100, seed=seed)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
Loading…
Reference in New Issue
Block a user