diff --git a/examples/training/train_ner.py b/examples/training/train_ner.py index 220244b93..bcc087d07 100644 --- a/examples/training/train_ner.py +++ b/examples/training/train_ner.py @@ -8,6 +8,12 @@ from spacy.pipeline import EntityRecognizer from spacy.gold import GoldParse from spacy.tagger import Tagger + +try: + unicode +except: + unicode = str + def train_ner(nlp, train_data, entity_types): # Add new words to vocab. @@ -24,7 +30,6 @@ def train_ner(nlp, train_data, entity_types): doc = nlp.make_doc(raw_text) gold = GoldParse(doc, entities=entity_offsets) ner.update(doc, gold) - ner.model.end_training() return ner def save_model(ner, model_dir): @@ -33,8 +38,11 @@ def save_model(ner, model_dir): model_dir.mkdir() assert model_dir.is_dir() - with (model_dir / 'config.json').open('w') as file_: - json.dump(ner.cfg, file_) + with (model_dir / 'config.json').open('wb') as file_: + data = json.dumps(ner.cfg) + if isinstance(data, unicode): + data = data.encode('utf8') + file_.write(data) ner.model.dump(str(model_dir / 'model')) if not (model_dir / 'vocab').exists(): (model_dir / 'vocab').mkdir()