mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 18:26:30 +03:00
Update NER training example
This commit is contained in:
parent
fe28602f2e
commit
5c30466c95
|
@ -3,66 +3,26 @@ import json
|
|||
import pathlib
|
||||
import random
|
||||
|
||||
import spacy
|
||||
from spacy.pipeline import EntityRecognizer
|
||||
from spacy.gold import GoldParse
|
||||
from spacy.tagger import Tagger
|
||||
import spacy.lang.en
|
||||
from spacy.gold import GoldParse, biluo_tags_from_offsets
|
||||
|
||||
|
||||
try:
|
||||
unicode
|
||||
except:
|
||||
unicode = str
|
||||
|
||||
|
||||
def train_ner(nlp, train_data, entity_types):
|
||||
# Add new words to vocab.
|
||||
for raw_text, _ in train_data:
|
||||
doc = nlp.make_doc(raw_text)
|
||||
for word in doc:
|
||||
_ = nlp.vocab[word.orth]
|
||||
|
||||
# Train NER.
|
||||
ner = EntityRecognizer(nlp.vocab, entity_types=entity_types)
|
||||
for itn in range(5):
|
||||
random.shuffle(train_data)
|
||||
for raw_text, entity_offsets in train_data:
|
||||
doc = nlp.make_doc(raw_text)
|
||||
gold = GoldParse(doc, entities=entity_offsets)
|
||||
ner.update(doc, gold)
|
||||
return ner
|
||||
|
||||
def save_model(ner, model_dir):
|
||||
model_dir = pathlib.Path(model_dir)
|
||||
if not model_dir.exists():
|
||||
model_dir.mkdir()
|
||||
assert model_dir.is_dir()
|
||||
|
||||
with (model_dir / 'config.json').open('wb') as file_:
|
||||
data = json.dumps(ner.cfg)
|
||||
if isinstance(data, unicode):
|
||||
data = data.encode('utf8')
|
||||
file_.write(data)
|
||||
ner.model.dump(str(model_dir / 'model'))
|
||||
if not (model_dir / 'vocab').exists():
|
||||
(model_dir / 'vocab').mkdir()
|
||||
ner.vocab.dump(str(model_dir / 'vocab' / 'lexemes.bin'))
|
||||
with (model_dir / 'vocab' / 'strings.json').open('w', encoding='utf8') as file_:
|
||||
ner.vocab.strings.dump(file_)
|
||||
def reformat_train_data(tokenizer, examples):
|
||||
"""Reformat data to match JSON format"""
|
||||
output = []
|
||||
for i, (text, entity_offsets) in enumerate(examples):
|
||||
doc = tokenizer(text)
|
||||
ner_tags = biluo_tags_from_offsets(tokenizer(text), entity_offsets)
|
||||
words = [w.text for w in doc]
|
||||
tags = ['-'] * len(doc)
|
||||
heads = [0] * len(doc)
|
||||
deps = [''] * len(doc)
|
||||
sentence = (range(len(doc)), words, tags, heads, deps, ner_tags)
|
||||
output.append((text, [(sentence, [])]))
|
||||
return output
|
||||
|
||||
|
||||
def main(model_dir=None):
|
||||
nlp = spacy.load('en', parser=False, entity=False, add_vectors=False)
|
||||
|
||||
# v1.1.2 onwards
|
||||
if nlp.tagger is None:
|
||||
print('---- WARNING ----')
|
||||
print('Data directory not found')
|
||||
print('please run: `python -m spacy.en.download --force all` for better performance')
|
||||
print('Using feature templates for tagging')
|
||||
print('-----------------')
|
||||
nlp.tagger = Tagger(nlp.vocab, features=Tagger.feature_templates)
|
||||
|
||||
train_data = [
|
||||
(
|
||||
'Who is Shaka Khan?',
|
||||
|
@ -74,23 +34,35 @@ def main(model_dir=None):
|
|||
(len('I like London and '), len('I like London and Berlin'), 'LOC')]
|
||||
)
|
||||
]
|
||||
ner = train_ner(nlp, train_data, ['PERSON', 'LOC'])
|
||||
|
||||
doc = nlp.make_doc('Who is Shaka Khan?')
|
||||
nlp.tagger(doc)
|
||||
ner(doc)
|
||||
for word in doc:
|
||||
print(word.text, word.orth, word.lower, word.tag_, word.ent_type_, word.ent_iob)
|
||||
|
||||
if model_dir is not None:
|
||||
save_model(ner, model_dir)
|
||||
|
||||
|
||||
|
||||
|
||||
nlp = spacy.lang.en.English(pipeline=['tensorizer', 'ner'])
|
||||
get_data = lambda: reformat_train_data(nlp.tokenizer, train_data)
|
||||
optimizer = nlp.begin_training(get_data)
|
||||
for itn in range(100):
|
||||
random.shuffle(train_data)
|
||||
losses = {}
|
||||
for raw_text, entity_offsets in train_data:
|
||||
doc = nlp.make_doc(raw_text)
|
||||
gold = GoldParse(doc, entities=entity_offsets)
|
||||
nlp.update(
|
||||
[doc], # Batch of Doc objects
|
||||
[gold], # Batch of GoldParse objects
|
||||
drop=0.5, # Dropout -- make it harder to memorise data
|
||||
sgd=optimizer, # Callable to update weights
|
||||
losses=losses)
|
||||
print(losses)
|
||||
print("Save to", model_dir)
|
||||
nlp.to_disk(model_dir)
|
||||
print("Load from", model_dir)
|
||||
nlp = spacy.lang.en.English(pipeline=['tensorizer', 'ner'])
|
||||
nlp.from_disk(model_dir)
|
||||
for raw_text, _ in train_data:
|
||||
doc = nlp(raw_text)
|
||||
for word in doc:
|
||||
print(word.text, word.ent_type_, word.ent_iob_)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main('ner')
|
||||
import plac
|
||||
plac.call(main)
|
||||
# Who "" 2
|
||||
# is "" 2
|
||||
# Shaka "" PERSON 3
|
||||
|
|
Loading…
Reference in New Issue
Block a user