diff --git a/examples/training/train_ner.py b/examples/training/train_ner.py
index bcc087d07..e50e36756 100644
--- a/examples/training/train_ner.py
+++ b/examples/training/train_ner.py
@@ -3,66 +3,26 @@ import json
 import pathlib
 import random
 
-import spacy
-from spacy.pipeline import EntityRecognizer
-from spacy.gold import GoldParse
-from spacy.tagger import Tagger
-
-
-try:
-    unicode
-except:
-    unicode = str
+import spacy.lang.en
+from spacy.gold import GoldParse, biluo_tags_from_offsets
 
 
-def train_ner(nlp, train_data, entity_types):
-    # Add new words to vocab.
-    for raw_text, _ in train_data:
-        doc = nlp.make_doc(raw_text)
-        for word in doc:
-            _ = nlp.vocab[word.orth]
-
-    # Train NER.
-    ner = EntityRecognizer(nlp.vocab, entity_types=entity_types)
-    for itn in range(5):
-        random.shuffle(train_data)
-        for raw_text, entity_offsets in train_data:
-            doc = nlp.make_doc(raw_text)
-            gold = GoldParse(doc, entities=entity_offsets)
-            ner.update(doc, gold)
-    return ner
-
-def save_model(ner, model_dir):
-    model_dir = pathlib.Path(model_dir)
-    if not model_dir.exists():
-        model_dir.mkdir()
-    assert model_dir.is_dir()
-
-    with (model_dir / 'config.json').open('wb') as file_:
-        data = json.dumps(ner.cfg)
-        if isinstance(data, unicode):
-            data = data.encode('utf8')
-        file_.write(data)
-    ner.model.dump(str(model_dir / 'model'))
-    if not (model_dir / 'vocab').exists():
-        (model_dir / 'vocab').mkdir()
-    ner.vocab.dump(str(model_dir / 'vocab' / 'lexemes.bin'))
-    with (model_dir / 'vocab' / 'strings.json').open('w', encoding='utf8') as file_:
-        ner.vocab.strings.dump(file_)
+def reformat_train_data(tokenizer, examples):
+    """Reformat data to match spaCy's JSON training format."""
+    output = []
+    for text, entity_offsets in examples:
+        doc = tokenizer(text)
+        ner_tags = biluo_tags_from_offsets(doc, entity_offsets)
+        words = [w.text for w in doc]
+        tags = ['-'] * len(doc)  # placeholder POS tags
+        heads = [0] * len(doc)  # placeholder heads
+        deps = [''] * len(doc)  # placeholder dependency labels
+        sentence = (range(len(doc)), words, tags, heads, deps, ner_tags)
+        output.append((text, [(sentence, [])]))
+    return output
 
 
 def main(model_dir=None):
-    nlp = spacy.load('en', parser=False, entity=False, add_vectors=False)
-
-    # v1.1.2 onwards
-    if nlp.tagger is None:
-        print('---- WARNING ----')
-        print('Data directory not found')
-        print('please run: `python -m spacy.en.download --force all` for better performance')
-        print('Using feature templates for tagging')
-        print('-----------------')
-        nlp.tagger = Tagger(nlp.vocab, features=Tagger.feature_templates)
-
     train_data = [
         (
             'Who is Shaka Khan?',
@@ -74,23 +34,35 @@ def main(model_dir=None):
             (len('I like London and '), len('I like London and Berlin'), 'LOC')]
        )
    ]
-    ner = train_ner(nlp, train_data, ['PERSON', 'LOC'])
-
-    doc = nlp.make_doc('Who is Shaka Khan?')
-    nlp.tagger(doc)
-    ner(doc)
-    for word in doc:
-        print(word.text, word.orth, word.lower, word.tag_, word.ent_type_, word.ent_iob)
-
-    if model_dir is not None:
-        save_model(ner, model_dir)
-
-
-
-
+    nlp = spacy.lang.en.English(pipeline=['tensorizer', 'ner'])
+    get_data = lambda: reformat_train_data(nlp.tokenizer, train_data)
+    optimizer = nlp.begin_training(get_data)
+    for itn in range(100):
+        random.shuffle(train_data)
+        losses = {}
+        for raw_text, entity_offsets in train_data:
+            doc = nlp.make_doc(raw_text)
+            gold = GoldParse(doc, entities=entity_offsets)
+            nlp.update(
+                [doc],  # Batch of Doc objects
+                [gold],  # Batch of GoldParse objects
+                drop=0.5,  # Dropout -- make it harder to memorise data
+                sgd=optimizer,  # Callable to update weights
+                losses=losses)
+        print(losses)
+    print("Save to", model_dir)
+    nlp.to_disk(model_dir)
+    print("Load from", model_dir)
+    nlp = spacy.lang.en.English(pipeline=['tensorizer', 'ner'])
+    nlp.from_disk(model_dir)
+    for raw_text, _ in train_data:
+        doc = nlp(raw_text)
+        for word in doc:
+            print(word.text, word.ent_type_, word.ent_iob_)
 
 
 if __name__ == '__main__':
-    main('ner')
+    import plac
+    plac.call(main)
 
     # Who "" 2
     # is "" 2
     # Shaka "" PERSON 3
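
The data conversion above hinges on spacy.gold.biluo_tags_from_offsets, which
turns character offsets into one tag per token: B-/I-/L- mark the beginning,
inside and last tokens of a multi-token entity, U- a single-token entity, and
O any token outside an entity. A minimal sketch of what it yields for the
first training example, assuming the spaCy v2 alpha API targeted by this diff:

    import spacy.lang.en
    from spacy.gold import biluo_tags_from_offsets

    nlp = spacy.lang.en.English()  # tokenizer only; no trained pipeline needed
    doc = nlp.tokenizer('Who is Shaka Khan?')
    # characters 7..17 cover 'Shaka Khan'
    print(biluo_tags_from_offsets(doc, [(7, 17, 'PERSON')]))
    # ['O', 'O', 'B-PERSON', 'L-PERSON', 'O']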