Mirror of https://github.com/explosion/spaCy.git (synced 2024-12-26 09:56:28 +03:00)
Update NER training example
parent fe28602f2e
commit 5c30466c95
@@ -3,66 +3,26 @@ import json
 import pathlib
 import random
 
-import spacy
-from spacy.pipeline import EntityRecognizer
-from spacy.gold import GoldParse
-from spacy.tagger import Tagger
+import spacy.lang.en
+from spacy.gold import GoldParse, biluo_tags_from_offsets
 
 
-try:
-    unicode
-except:
-    unicode = str
-
-
-def train_ner(nlp, train_data, entity_types):
-    # Add new words to vocab.
-    for raw_text, _ in train_data:
-        doc = nlp.make_doc(raw_text)
-        for word in doc:
-            _ = nlp.vocab[word.orth]
-
-    # Train NER.
-    ner = EntityRecognizer(nlp.vocab, entity_types=entity_types)
-    for itn in range(5):
-        random.shuffle(train_data)
-        for raw_text, entity_offsets in train_data:
-            doc = nlp.make_doc(raw_text)
-            gold = GoldParse(doc, entities=entity_offsets)
-            ner.update(doc, gold)
-    return ner
-
-def save_model(ner, model_dir):
-    model_dir = pathlib.Path(model_dir)
-    if not model_dir.exists():
-        model_dir.mkdir()
-    assert model_dir.is_dir()
-
-    with (model_dir / 'config.json').open('wb') as file_:
-        data = json.dumps(ner.cfg)
-        if isinstance(data, unicode):
-            data = data.encode('utf8')
-        file_.write(data)
-    ner.model.dump(str(model_dir / 'model'))
-    if not (model_dir / 'vocab').exists():
-        (model_dir / 'vocab').mkdir()
-    ner.vocab.dump(str(model_dir / 'vocab' / 'lexemes.bin'))
-    with (model_dir / 'vocab' / 'strings.json').open('w', encoding='utf8') as file_:
-        ner.vocab.strings.dump(file_)
+def reformat_train_data(tokenizer, examples):
+    """Reformat data to match JSON format"""
+    output = []
+    for i, (text, entity_offsets) in enumerate(examples):
+        doc = tokenizer(text)
+        ner_tags = biluo_tags_from_offsets(tokenizer(text), entity_offsets)
+        words = [w.text for w in doc]
+        tags = ['-'] * len(doc)
+        heads = [0] * len(doc)
+        deps = [''] * len(doc)
+        sentence = (range(len(doc)), words, tags, heads, deps, ner_tags)
+        output.append((text, [(sentence, [])]))
+    return output
 
 
 def main(model_dir=None):
-    nlp = spacy.load('en', parser=False, entity=False, add_vectors=False)
-
-    # v1.1.2 onwards
-    if nlp.tagger is None:
-        print('---- WARNING ----')
-        print('Data directory not found')
-        print('please run: `python -m spacy.en.download --force all` for better performance')
-        print('Using feature templates for tagging')
-        print('-----------------')
-        nlp.tagger = Tagger(nlp.vocab, features=Tagger.feature_templates)
-
     train_data = [
         (
             'Who is Shaka Khan?',
@@ -74,23 +34,35 @@ def main(model_dir=None):
              (len('I like London and '), len('I like London and Berlin'), 'LOC')]
         )
     ]
-    ner = train_ner(nlp, train_data, ['PERSON', 'LOC'])
-
-    doc = nlp.make_doc('Who is Shaka Khan?')
-    nlp.tagger(doc)
-    ner(doc)
-    for word in doc:
-        print(word.text, word.orth, word.lower, word.tag_, word.ent_type_, word.ent_iob)
-
-    if model_dir is not None:
-        save_model(ner, model_dir)
+    nlp = spacy.lang.en.English(pipeline=['tensorizer', 'ner'])
+    get_data = lambda: reformat_train_data(nlp.tokenizer, train_data)
+    optimizer = nlp.begin_training(get_data)
+    for itn in range(100):
+        random.shuffle(train_data)
+        losses = {}
+        for raw_text, entity_offsets in train_data:
+            doc = nlp.make_doc(raw_text)
+            gold = GoldParse(doc, entities=entity_offsets)
+            nlp.update(
+                [doc],  # Batch of Doc objects
+                [gold],  # Batch of GoldParse objects
+                drop=0.5,  # Dropout -- make it harder to memorise data
+                sgd=optimizer,  # Callable to update weights
+                losses=losses)
+        print(losses)
+    print("Save to", model_dir)
+    nlp.to_disk(model_dir)
+    print("Load from", model_dir)
+    nlp = spacy.lang.en.English(pipeline=['tensorizer', 'ner'])
+    nlp.from_disk(model_dir)
+    for raw_text, _ in train_data:
+        doc = nlp(raw_text)
+        for word in doc:
+            print(word.text, word.ent_type_, word.ent_iob_)
 
 if __name__ == '__main__':
-    main('ner')
+    import plac
+    plac.call(main)
     # Who "" 2
     # is "" 2
     # Shaka "" PERSON 3
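For context on the conversion step the new example introduces: biluo_tags_from_offsets turns character-offset entity annotations into one BILUO tag per token. A minimal sketch of what it returns for the first training sentence, assuming the spacy.gold module layout used at the time of this commit (later releases moved the function to spacy.training and renamed it offsets_to_biluo_tags):

    import spacy.lang.en
    from spacy.gold import biluo_tags_from_offsets

    nlp = spacy.lang.en.English()
    doc = nlp.make_doc('Who is Shaka Khan?')
    # 'Shaka Khan' occupies characters 7-17 of the text
    tags = biluo_tags_from_offsets(doc, [(7, 17, 'PERSON')])
    print(tags)  # ['O', 'O', 'B-PERSON', 'L-PERSON', 'O']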
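The sentence tuples that reformat_train_data builds mirror spaCy's JSON training format: (ids, words, tags, heads, deps, ner_tags), wrapped as (text, [(sentence, [])]) and paired with an empty list. A hypothetical illustration of one reformatted example, assuming the default tokenization of the first training sentence:

    text = 'Who is Shaka Khan?'
    sentence = (
        range(5),                                 # token ids
        ['Who', 'is', 'Shaka', 'Khan', '?'],      # words
        ['-'] * 5,                                # tags (placeholder)
        [0] * 5,                                  # heads (placeholder)
        [''] * 5,                                 # deps (placeholder)
        ['O', 'O', 'B-PERSON', 'L-PERSON', 'O'],  # BILUO NER tags
    )
    example = (text, [(sentence, [])])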
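The updated entry point also delegates argument parsing to plac, which maps command-line arguments onto main's parameters, so the script would be invoked along the lines of (directory name hypothetical):

    python train_ner.py ner-model

Since model_dir defaults to None but the new code calls nlp.to_disk(model_dir) unconditionally, a directory argument should always be supplied.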