mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-24 16:24:16 +03:00
Fix train_new_entity_type example
This commit is contained in:
parent
6a4221a6de
commit
2f84626417
|
@ -4,7 +4,6 @@ import random
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import spacy
|
import spacy
|
||||||
from spacy.pipeline import EntityRecognizer
|
|
||||||
from spacy.gold import GoldParse
|
from spacy.gold import GoldParse
|
||||||
from spacy.tagger import Tagger
|
from spacy.tagger import Tagger
|
||||||
|
|
||||||
|
@ -25,10 +24,13 @@ def train_ner(nlp, train_data, output_dir):
|
||||||
loss = nlp.entity.update(doc, gold)
|
loss = nlp.entity.update(doc, gold)
|
||||||
nlp.end_training()
|
nlp.end_training()
|
||||||
if output_dir:
|
if output_dir:
|
||||||
|
if not output_dir.exists():
|
||||||
|
output_dir.mkdir()
|
||||||
nlp.save_to_directory(output_dir)
|
nlp.save_to_directory(output_dir)
|
||||||
|
|
||||||
|
|
||||||
def main(model_name, output_directory=None):
|
def main(model_name, output_directory=None):
|
||||||
|
print("Loading initial model", model_name)
|
||||||
nlp = spacy.load(model_name)
|
nlp = spacy.load(model_name)
|
||||||
if output_directory is not None:
|
if output_directory is not None:
|
||||||
output_directory = Path(output_directory)
|
output_directory = Path(output_directory)
|
||||||
|
@ -52,13 +54,14 @@ def main(model_name, output_directory=None):
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
nlp.entity.add_label('ANIMAL')
|
nlp.entity.add_label('ANIMAL')
|
||||||
ner = train_ner(nlp, train_data, output_directory)
|
train_ner(nlp, train_data, output_directory)
|
||||||
|
|
||||||
# Test that the entity is recognized
|
# Test that the entity is recognized
|
||||||
doc = nlp('Do you like horses?')
|
doc = nlp('Do you like horses?')
|
||||||
for ent in doc.ents:
|
for ent in doc.ents:
|
||||||
print(ent.label_, ent.text)
|
print(ent.label_, ent.text)
|
||||||
if output_directory:
|
if output_directory:
|
||||||
|
print("Loading from", output_directory)
|
||||||
nlp2 = spacy.load('en', path=output_directory)
|
nlp2 = spacy.load('en', path=output_directory)
|
||||||
nlp2.entity.add_label('ANIMAL')
|
nlp2.entity.add_label('ANIMAL')
|
||||||
doc2 = nlp2('Do you like horses?')
|
doc2 = nlp2('Do you like horses?')
|
||||||
|
|
Loading…
Reference in New Issue
Block a user