mirror of
https://github.com/explosion/spaCy.git
synced 2025-06-30 18:03:04 +03:00
train_ner should save vocab; add load_ner example
This commit is contained in:
parent
5ad5408242
commit
ad54a929f8
|
@ -10,6 +10,13 @@ from spacy.tagger import Tagger
|
||||||
|
|
||||||
|
|
||||||
def train_ner(nlp, train_data, entity_types):
|
def train_ner(nlp, train_data, entity_types):
|
||||||
|
# Add new words to vocab.
|
||||||
|
for raw_text, _ in train_data:
|
||||||
|
doc = nlp.make_doc(raw_text)
|
||||||
|
for word in doc:
|
||||||
|
_ = nlp.vocab[word.orth]
|
||||||
|
|
||||||
|
# Train NER.
|
||||||
ner = EntityRecognizer(nlp.vocab, entity_types=entity_types)
|
ner = EntityRecognizer(nlp.vocab, entity_types=entity_types)
|
||||||
for itn in range(5):
|
for itn in range(5):
|
||||||
random.shuffle(train_data)
|
random.shuffle(train_data)
|
||||||
|
@ -20,21 +27,30 @@ def train_ner(nlp, train_data, entity_types):
|
||||||
ner.model.end_training()
|
ner.model.end_training()
|
||||||
return ner
|
return ner
|
||||||
|
|
||||||
|
def save_model(ner, model_dir):
    """Persist a trained entity recognizer to *model_dir*.

    Writes the model config (``config.json``), the model weights
    (``model``), and the vocab — lexeme data plus the string store —
    under ``vocab/`` so the NER can be reloaded later.
    """
    model_dir = pathlib.Path(model_dir)
    if not model_dir.exists():
        model_dir.mkdir()
    assert model_dir.is_dir()

    # Model configuration and weights.
    with (model_dir / 'config.json').open('w') as file_:
        json.dump(ner.cfg, file_)
    ner.model.dump(str(model_dir / 'model'))

    # Vocab: lexemes plus the string store.
    vocab_dir = model_dir / 'vocab'
    if not vocab_dir.exists():
        vocab_dir.mkdir()
    ner.vocab.dump(str(vocab_dir / 'lexemes.bin'))
    with (vocab_dir / 'strings.json').open('w', encoding='utf8') as file_:
        ner.vocab.strings.dump(file_)
|
||||||
|
|
||||||
|
|
||||||
def main(model_dir=None):
|
def main(model_dir=None):
|
||||||
if model_dir is not None:
|
|
||||||
model_dir = pathlib.Path(model_dir)
|
|
||||||
if not model_dir.exists():
|
|
||||||
model_dir.mkdir()
|
|
||||||
assert model_dir.is_dir()
|
|
||||||
|
|
||||||
nlp = spacy.load('en', parser=False, entity=False, add_vectors=False)
|
nlp = spacy.load('en', parser=False, entity=False, add_vectors=False)
|
||||||
|
|
||||||
# v1.1.2 onwards
|
# v1.1.2 onwards
|
||||||
if nlp.tagger is None:
|
if nlp.tagger is None:
|
||||||
print('---- WARNING ----')
|
print('---- WARNING ----')
|
||||||
print('Data directory not found')
|
print('Data directory not found')
|
||||||
print('please run: `python -m spacy.en.download –force all` for better performance')
|
print('please run: `python -m spacy.en.download --force all` for better performance')
|
||||||
print('Using feature templates for tagging')
|
print('Using feature templates for tagging')
|
||||||
print('-----------------')
|
print('-----------------')
|
||||||
nlp.tagger = Tagger(nlp.vocab, features=Tagger.feature_templates)
|
nlp.tagger = Tagger(nlp.vocab, features=Tagger.feature_templates)
|
||||||
|
@ -56,16 +72,17 @@ def main(model_dir=None):
|
||||||
nlp.tagger(doc)
|
nlp.tagger(doc)
|
||||||
ner(doc)
|
ner(doc)
|
||||||
for word in doc:
|
for word in doc:
|
||||||
print(word.text, word.tag_, word.ent_type_, word.ent_iob)
|
print(word.text, word.orth, word.lower, word.tag_, word.ent_type_, word.ent_iob)
|
||||||
|
|
||||||
if model_dir is not None:
|
if model_dir is not None:
|
||||||
with (model_dir / 'config.json').open('w') as file_:
|
save_model(ner, model_dir)
|
||||||
json.dump(ner.cfg, file_)
|
|
||||||
ner.model.dump(str(model_dir / 'model'))
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
main('ner')
|
||||||
# Who "" 2
|
# Who "" 2
|
||||||
# is "" 2
|
# is "" 2
|
||||||
# Shaka "" PERSON 3
|
# Shaka "" PERSON 3
|
||||||
|
|
|
@ -69,7 +69,7 @@ def main(output_dir=None):
|
||||||
print(word.text, word.tag_, word.pos_)
|
print(word.text, word.tag_, word.pos_)
|
||||||
if output_dir is not None:
|
if output_dir is not None:
|
||||||
tagger.model.dump(str(output_dir / 'pos' / 'model'))
|
tagger.model.dump(str(output_dir / 'pos' / 'model'))
|
||||||
with (output_dir / 'vocab' / 'strings.json').open('wb') as file_:
|
with (output_dir / 'vocab' / 'strings.json').open('w') as file_:
|
||||||
tagger.vocab.strings.dump(file_)
|
tagger.vocab.strings.dump(file_)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user