mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-10 19:57:17 +03:00
Update NER training example
This commit is contained in:
parent
63adcb8141
commit
ab70f6e18d
|
@ -9,6 +9,12 @@ from spacy.gold import GoldParse
|
|||
from spacy.tagger import Tagger
|
||||
|
||||
|
||||
try:
|
||||
unicode
|
||||
except:
|
||||
unicode = str
|
||||
|
||||
|
||||
def train_ner(nlp, train_data, entity_types):
|
||||
# Add new words to vocab.
|
||||
for raw_text, _ in train_data:
|
||||
|
@ -24,7 +30,6 @@ def train_ner(nlp, train_data, entity_types):
|
|||
doc = nlp.make_doc(raw_text)
|
||||
gold = GoldParse(doc, entities=entity_offsets)
|
||||
ner.update(doc, gold)
|
||||
ner.model.end_training()
|
||||
return ner
|
||||
|
||||
def save_model(ner, model_dir):
|
||||
|
@ -33,8 +38,11 @@ def save_model(ner, model_dir):
|
|||
model_dir.mkdir()
|
||||
assert model_dir.is_dir()
|
||||
|
||||
with (model_dir / 'config.json').open('w') as file_:
|
||||
json.dump(ner.cfg, file_)
|
||||
with (model_dir / 'config.json').open('wb') as file_:
|
||||
data = json.dumps(ner.cfg)
|
||||
if isinstance(data, unicode):
|
||||
data = data.encode('utf8')
|
||||
file_.write(data)
|
||||
ner.model.dump(str(model_dir / 'model'))
|
||||
if not (model_dir / 'vocab').exists():
|
||||
(model_dir / 'vocab').mkdir()
|
||||
|
|
Loading…
Reference in New Issue
Block a user