Merge pull request #679 from savvopoulos/train-ner-update

train_ner should save vocab; add load_ner example
This commit is contained in:
Matthew Honnibal 2016-12-13 07:13:30 +11:00 committed by GitHub
commit c4d9ea1186
2 changed files with 30 additions and 13 deletions

View File

@ -10,6 +10,13 @@ from spacy.tagger import Tagger
def train_ner(nlp, train_data, entity_types):
# Add new words to vocab.
for raw_text, _ in train_data:
doc = nlp.make_doc(raw_text)
for word in doc:
_ = nlp.vocab[word.orth]
# Train NER.
ner = EntityRecognizer(nlp.vocab, entity_types=entity_types)
for itn in range(5):
random.shuffle(train_data)
@ -20,21 +27,30 @@ def train_ner(nlp, train_data, entity_types):
ner.model.end_training()
return ner
def main(model_dir=None):
if model_dir is not None:
def save_model(ner, model_dir):
model_dir = pathlib.Path(model_dir)
if not model_dir.exists():
model_dir.mkdir()
assert model_dir.is_dir()
with (model_dir / 'config.json').open('w') as file_:
json.dump(ner.cfg, file_)
ner.model.dump(str(model_dir / 'model'))
if not (model_dir / 'vocab').exists():
(model_dir / 'vocab').mkdir()
ner.vocab.dump(str(model_dir / 'vocab' / 'lexemes.bin'))
with (model_dir / 'vocab' / 'strings.json').open('w', encoding='utf8') as file_:
ner.vocab.strings.dump(file_)
def main(model_dir=None):
nlp = spacy.load('en', parser=False, entity=False, add_vectors=False)
# v1.1.2 onwards
if nlp.tagger is None:
print('---- WARNING ----')
print('Data directory not found')
print('please run: `python -m spacy.en.download force all` for better performance')
print('please run: `python -m spacy.en.download --force all` for better performance')
print('Using feature templates for tagging')
print('-----------------')
nlp.tagger = Tagger(nlp.vocab, features=Tagger.feature_templates)
@ -56,16 +72,17 @@ def main(model_dir=None):
nlp.tagger(doc)
ner(doc)
for word in doc:
print(word.text, word.tag_, word.ent_type_, word.ent_iob)
print(word.text, word.orth, word.lower, word.tag_, word.ent_type_, word.ent_iob)
if model_dir is not None:
with (model_dir / 'config.json').open('w') as file_:
json.dump(ner.cfg, file_)
ner.model.dump(str(model_dir / 'model'))
save_model(ner, model_dir)
if __name__ == '__main__':
main()
main('ner')
# Who "" 2
# is "" 2
# Shaka "" PERSON 3

View File

@ -69,7 +69,7 @@ def main(output_dir=None):
print(word.text, word.tag_, word.pos_)
if output_dir is not None:
tagger.model.dump(str(output_dir / 'pos' / 'model'))
with (output_dir / 'vocab' / 'strings.json').open('wb') as file_:
with (output_dir / 'vocab' / 'strings.json').open('w') as file_:
tagger.vocab.strings.dump(file_)