Fixed training examples

Changes:
1. train_ner won't crash if no data directory is not found
2. Fixed train_tagger expected spacy.gold.GoldParse, got list
This commit is contained in:
kendricktan 2016-10-24 16:09:23 +10:00
parent a9289f6261
commit ba8841234a
2 changed files with 25 additions and 12 deletions

View File

@ -6,6 +6,7 @@ import random
import spacy import spacy
from spacy.pipeline import EntityRecognizer from spacy.pipeline import EntityRecognizer
from spacy.gold import GoldParse from spacy.gold import GoldParse
from spacy.tagger import Tagger
def train_ner(nlp, train_data, entity_types): def train_ner(nlp, train_data, entity_types):
@ -29,6 +30,15 @@ def main(model_dir=None):
nlp = spacy.load('en', parser=False, entity=False, add_vectors=False) nlp = spacy.load('en', parser=False, entity=False, add_vectors=False)
# v1.1.2 onwards
if nlp.tagger is None:
print('---- WARNING ----')
print('Data directory not found')
print('please run: `python -m spacy.en.download force all` for better performance')
print('Using feature templates for tagging')
print('-----------------')
nlp.tagger = Tagger(nlp.vocab, features=Tagger.feature_templates)
train_data = [ train_data = [
( (
'Who is Shaka Khan?', 'Who is Shaka Khan?',

View File

@ -10,8 +10,9 @@ from pathlib import Path
from spacy.vocab import Vocab from spacy.vocab import Vocab
from spacy.tagger import Tagger from spacy.tagger import Tagger
from spacy.tokens import Doc from spacy.tokens import Doc
import random from spacy.gold import GoldParse
import random
# You need to define a mapping from your data's part-of-speech tag names to the # You need to define a mapping from your data's part-of-speech tag names to the
# Universal Part-of-Speech tag set, as spaCy includes an enum of these tags. # Universal Part-of-Speech tag set, as spaCy includes an enum of these tags.
@ -23,7 +24,7 @@ TAG_MAP = {
'N': {"pos": "NOUN"}, 'N': {"pos": "NOUN"},
'V': {"pos": "VERB"}, 'V': {"pos": "VERB"},
'J': {"pos": "ADJ"} 'J': {"pos": "ADJ"}
} }
# Usually you'll read this in, of course. Data formats vary. # Usually you'll read this in, of course. Data formats vary.
# Ensure your strings are unicode. # Ensure your strings are unicode.
@ -38,6 +39,7 @@ DATA = [
) )
] ]
def ensure_dir(path): def ensure_dir(path):
if not path.exists(): if not path.exists():
path.mkdir() path.mkdir()
@ -54,13 +56,14 @@ def main(output_dir=None):
# The default_templates argument is where features are specified. See # The default_templates argument is where features are specified. See
# spacy/tagger.pyx for the defaults. # spacy/tagger.pyx for the defaults.
tagger = Tagger(vocab) tagger = Tagger(vocab)
for i in range(5): for i in range(25):
for words, tags in DATA: for words, tags in DATA:
doc = Doc(vocab, words=words) doc = Doc(vocab, words=words)
tagger.update(doc, tags) gold = GoldParse(doc, tags=tags)
tagger.update(doc, gold)
random.shuffle(DATA) random.shuffle(DATA)
tagger.model.end_training() tagger.model.end_training()
doc = Doc(vocab, orths_and_spaces=zip(["I", "like", "blue", "eggs"], [True]*4)) doc = Doc(vocab, orths_and_spaces=zip(["I", "like", "blue", "eggs"], [True] * 4))
tagger(doc) tagger(doc)
for word in doc: for word in doc:
print(word.text, word.tag_, word.pos_) print(word.text, word.tag_, word.pos_)