mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
train_ner should save vocab; add load_ner example
This commit is contained in:
parent
5ad5408242
commit
ad54a929f8
|
@ -10,6 +10,13 @@ from spacy.tagger import Tagger
|
|||
|
||||
|
||||
def train_ner(nlp, train_data, entity_types):
|
||||
# Add new words to vocab.
|
||||
for raw_text, _ in train_data:
|
||||
doc = nlp.make_doc(raw_text)
|
||||
for word in doc:
|
||||
_ = nlp.vocab[word.orth]
|
||||
|
||||
# Train NER.
|
||||
ner = EntityRecognizer(nlp.vocab, entity_types=entity_types)
|
||||
for itn in range(5):
|
||||
random.shuffle(train_data)
|
||||
|
@ -20,21 +27,30 @@ def train_ner(nlp, train_data, entity_types):
|
|||
ner.model.end_training()
|
||||
return ner
|
||||
|
||||
def save_model(ner, model_dir):
|
||||
model_dir = pathlib.Path(model_dir)
|
||||
if not model_dir.exists():
|
||||
model_dir.mkdir()
|
||||
assert model_dir.is_dir()
|
||||
|
||||
with (model_dir / 'config.json').open('w') as file_:
|
||||
json.dump(ner.cfg, file_)
|
||||
ner.model.dump(str(model_dir / 'model'))
|
||||
if not (model_dir / 'vocab').exists():
|
||||
(model_dir / 'vocab').mkdir()
|
||||
ner.vocab.dump(str(model_dir / 'vocab' / 'lexemes.bin'))
|
||||
with (model_dir / 'vocab' / 'strings.json').open('w', encoding='utf8') as file_:
|
||||
ner.vocab.strings.dump(file_)
|
||||
|
||||
|
||||
def main(model_dir=None):
|
||||
if model_dir is not None:
|
||||
model_dir = pathlib.Path(model_dir)
|
||||
if not model_dir.exists():
|
||||
model_dir.mkdir()
|
||||
assert model_dir.is_dir()
|
||||
|
||||
nlp = spacy.load('en', parser=False, entity=False, add_vectors=False)
|
||||
|
||||
# v1.1.2 onwards
|
||||
if nlp.tagger is None:
|
||||
print('---- WARNING ----')
|
||||
print('Data directory not found')
|
||||
print('please run: `python -m spacy.en.download –force all` for better performance')
|
||||
print('please run: `python -m spacy.en.download --force all` for better performance')
|
||||
print('Using feature templates for tagging')
|
||||
print('-----------------')
|
||||
nlp.tagger = Tagger(nlp.vocab, features=Tagger.feature_templates)
|
||||
|
@ -56,16 +72,17 @@ def main(model_dir=None):
|
|||
nlp.tagger(doc)
|
||||
ner(doc)
|
||||
for word in doc:
|
||||
print(word.text, word.tag_, word.ent_type_, word.ent_iob)
|
||||
print(word.text, word.orth, word.lower, word.tag_, word.ent_type_, word.ent_iob)
|
||||
|
||||
if model_dir is not None:
|
||||
with (model_dir / 'config.json').open('w') as file_:
|
||||
json.dump(ner.cfg, file_)
|
||||
ner.model.dump(str(model_dir / 'model'))
|
||||
save_model(ner, model_dir)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
main('ner')
|
||||
# Who "" 2
|
||||
# is "" 2
|
||||
# Shaka "" PERSON 3
|
||||
|
|
|
@ -69,7 +69,7 @@ def main(output_dir=None):
|
|||
print(word.text, word.tag_, word.pos_)
|
||||
if output_dir is not None:
|
||||
tagger.model.dump(str(output_dir / 'pos' / 'model'))
|
||||
with (output_dir / 'vocab' / 'strings.json').open('wb') as file_:
|
||||
with (output_dir / 'vocab' / 'strings.json').open('w') as file_:
|
||||
tagger.vocab.strings.dump(file_)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user