Refactor training, with new spacy.train module. Defaults still a little awkward.

This commit is contained in:
Matthew Honnibal 2016-10-09 12:24:24 +02:00
parent 53fbd3dd1c
commit ea23b64cc8
12 changed files with 212 additions and 134 deletions

View File

@ -79,82 +79,23 @@ def _merge_sents(sents):
return [(m_deps, m_brackets)] return [(m_deps, m_brackets)]
def train(Language, gold_tuples, model_dir, tagger_cfg, parser_cfg, entity_cfg, def train(Language, train_data, dev_data, model_dir, tagger_cfg, parser_cfg, entity_cfg,
n_iter=15, seed=0, gold_preproc=False, n_sents=0, corruption_level=0): n_iter=15, seed=0, gold_preproc=False, n_sents=0, corruption_level=0):
dep_model_dir = path.join(model_dir, 'deps')
ner_model_dir = path.join(model_dir, 'ner')
pos_model_dir = path.join(model_dir, 'pos')
if path.exists(dep_model_dir):
shutil.rmtree(dep_model_dir)
if path.exists(ner_model_dir):
shutil.rmtree(ner_model_dir)
if path.exists(pos_model_dir):
shutil.rmtree(pos_model_dir)
os.mkdir(dep_model_dir)
os.mkdir(ner_model_dir)
os.mkdir(pos_model_dir)
if parser_cfg['pseudoprojective']:
# preprocess training data here before ArcEager.get_labels() is called
gold_tuples = PseudoProjectivity.preprocess_training_data(gold_tuples)
parser_cfg['labels'] = ArcEager.get_labels(gold_tuples)
entity_cfg['labels'] = BiluoPushDown.get_labels(gold_tuples)
with (dep_model_dir / 'config.json').open('w') as file_:
json.dump(file_, parser_config)
with (ner_model_dir / 'config.json').open('w') as file_:
json.dump(file_, entity_config)
with (pos_model_dir / 'config.json').open('w') as file_:
json.dump(file_, tagger_config)
if n_sents > 0:
gold_tuples = gold_tuples[:n_sents]
nlp = Language(
data_dir=model_dir,
tagger=Tagger.blank(nlp.vocab, **tagger_cfg),
parser=Parser.blank(nlp.vocab, ArcEager, **parser_cfg),
entity=Parser.blank(nlp.vocab, BiluoPushDown, **entity_cfg))
print("Itn.\tP.Loss\tUAS\tNER F.\tTag %\tToken %") print("Itn.\tP.Loss\tUAS\tNER F.\tTag %\tToken %")
for itn in range(n_iter): format_str = '{:d}\t{:d}\t{uas:.3f}\t{ents_f:.3f}\t{tags_acc:.3f}\t{token_acc:.3f}'
scorer = Scorer() with Language.train(model_dir, train_data,
tagger_cfg, parser_cfg, entity_cfg) as trainer:
loss = 0 loss = 0
for raw_text, sents in gold_tuples: for itn, epoch in enumerate(trainer.epochs(n_iter, augment_data=None)):
if gold_preproc: for doc, gold in epoch:
raw_text = None trainer.update(doc, gold)
else: dev_scores = trainer.evaluate(dev_data)
sents = _merge_sents(sents) print(format_str.format(itn, loss, **dev_scores.scores))
for annot_tuples, ctnt in sents:
if len(annot_tuples[1]) == 1:
continue
score_model(scorer, nlp, raw_text, annot_tuples,
verbose=verbose if itn >= 2 else False)
if raw_text is None:
words = add_noise(annot_tuples[1], corruption_level)
tokens = nlp.tokenizer.tokens_from_list(words)
else:
raw_text = add_noise(raw_text, corruption_level)
tokens = nlp.tokenizer(raw_text)
nlp.tagger(tokens)
gold = GoldParse(tokens, annot_tuples)
if not gold.is_projective:
raise Exception("Non-projective sentence in training: %s" % annot_tuples[1])
loss += nlp.parser.train(tokens, gold)
nlp.entity.train(tokens, gold)
nlp.tagger.train(tokens, gold.tags)
random.shuffle(gold_tuples)
print('%d:\t%d\t%.3f\t%.3f\t%.3f\t%.3f' % (itn, loss, scorer.uas, scorer.ents_f,
scorer.tags_acc,
scorer.token_acc))
print('end training')
nlp.end_training(model_dir)
print('done')
def evaluate(Language, gold_tuples, model_dir, gold_preproc=False, verbose=False, def evaluate(Language, gold_tuples, model_dir, gold_preproc=False, verbose=False,
beam_width=None, cand_preproc=None): beam_width=None, cand_preproc=None):
nlp = Language(data_dir=model_dir) nlp = Language(path=model_dir)
if nlp.lang == 'de': if nlp.lang == 'de':
nlp.vocab.morphology.lemmatizer = lambda string,pos: set([string]) nlp.vocab.morphology.lemmatizer = lambda string,pos: set([string])
if beam_width is not None: if beam_width is not None:
@ -227,9 +168,13 @@ def main(language, train_loc, dev_loc, model_dir, n_sents=0, n_iter=15, out_loc=
lang = spacy.util.get_lang_class(language) lang = spacy.util.get_lang_class(language)
parser_cfg['features'] = lang.Defaults.parser_features
entity_cfg['features'] = lang.Defaults.entity_features
if not eval_only: if not eval_only:
gold_train = list(read_json_file(train_loc)) gold_train = list(read_json_file(train_loc))
train(lang, gold_train, model_dir, tagger_cfg, parser_cfg, entity_cfg, gold_dev = list(read_json_file(dev_loc))
train(lang, gold_train, gold_dev, model_dir, tagger_cfg, parser_cfg, entity_cfg,
n_sents=n_sents, gold_preproc=gold_preproc, corruption_level=corruption_level, n_sents=n_sents, gold_preproc=gold_preproc, corruption_level=corruption_level,
n_iter=n_iter) n_iter=n_iter)
if out_loc: if out_loc:

View File

@ -27,4 +27,3 @@ class English(Language):
tag_map = dict(language_data.TAG_MAP) tag_map = dict(language_data.TAG_MAP)
stop_words = set(language_data.STOP_WORDS) stop_words = set(language_data.STOP_WORDS)

View File

@ -1,3 +1,5 @@
from __future__ import unicode_literals, print_function
import numpy import numpy
import io import io
import json import json
@ -128,7 +130,6 @@ def _min_edit_path(cand_words, gold_words):
def read_json_file(loc, docs_filter=None): def read_json_file(loc, docs_filter=None):
print loc
if path.isdir(loc): if path.isdir(loc):
for filename in os.listdir(loc): for filename in os.listdir(loc):
yield from read_json_file(path.join(loc, filename)) yield from read_json_file(path.join(loc, filename))
@ -199,7 +200,7 @@ def _consume_ent(tags):
cdef class GoldParse: cdef class GoldParse:
def __init__(self, tokens, annot_tuples, brackets=tuple(), make_projective=False): def __init__(self, tokens, annot_tuples, make_projective=False):
self.mem = Pool() self.mem = Pool()
self.loss = 0 self.loss = 0
self.length = len(tokens) self.length = len(tokens)
@ -209,9 +210,6 @@ cdef class GoldParse:
self.c.heads = <int*>self.mem.alloc(len(tokens), sizeof(int)) self.c.heads = <int*>self.mem.alloc(len(tokens), sizeof(int))
self.c.labels = <int*>self.mem.alloc(len(tokens), sizeof(int)) self.c.labels = <int*>self.mem.alloc(len(tokens), sizeof(int))
self.c.ner = <Transition*>self.mem.alloc(len(tokens), sizeof(Transition)) self.c.ner = <Transition*>self.mem.alloc(len(tokens), sizeof(Transition))
self.c.brackets = <int**>self.mem.alloc(len(tokens), sizeof(int*))
for i in range(len(tokens)):
self.c.brackets[i] = <int*>self.mem.alloc(len(tokens), sizeof(int))
self.tags = [None] * len(tokens) self.tags = [None] * len(tokens)
self.heads = [None] * len(tokens) self.heads = [None] * len(tokens)
@ -246,14 +244,6 @@ cdef class GoldParse:
proj_heads,_ = nonproj.PseudoProjectivity.projectivize(self.heads,self.labels) proj_heads,_ = nonproj.PseudoProjectivity.projectivize(self.heads,self.labels)
self.heads = proj_heads self.heads = proj_heads
self.brackets = {}
for (gold_start, gold_end, label_str) in brackets:
start = self.gold_to_cand[gold_start]
end = self.gold_to_cand[gold_end]
if start is not None and end is not None:
self.brackets.setdefault(start, {}).setdefault(end, set())
self.brackets[end][start].add(label_str)
def __len__(self): def __len__(self):
return self.length return self.length

View File

@ -2,6 +2,8 @@ from __future__ import absolute_import
from __future__ import unicode_literals from __future__ import unicode_literals
from warnings import warn from warnings import warn
import pathlib import pathlib
from contextlib import contextmanager
import shutil
try: try:
import ujson as json import ujson as json
@ -15,7 +17,6 @@ except NameError:
basestring = str basestring = str
from .tokenizer import Tokenizer from .tokenizer import Tokenizer
from .vocab import Vocab from .vocab import Vocab
from .syntax.parser import Parser from .syntax.parser import Parser
@ -27,9 +28,12 @@ from .syntax.ner import BiluoPushDown
from .syntax.arc_eager import ArcEager from .syntax.arc_eager import ArcEager
from . import util from . import util
from .lemmatizer import Lemmatizer from .lemmatizer import Lemmatizer
from .train import Trainer
from .attrs import TAG, DEP, ENT_IOB, ENT_TYPE, HEAD, PROB, LANG, IS_STOP from .attrs import TAG, DEP, ENT_IOB, ENT_TYPE, HEAD, PROB, LANG, IS_STOP
from .syntax.parser import get_templates from .syntax.parser import get_templates
from .syntax.nonproj import PseudoProjectivity
class BaseDefaults(object): class BaseDefaults(object):
@ -84,46 +88,62 @@ class BaseDefaults(object):
suffix_search=suffix_search, suffix_search=suffix_search,
infix_finditer=infix_finditer) infix_finditer=infix_finditer)
else: else:
return Tokenizer(vocab, rules=rules, tokenizer = Tokenizer(vocab, rules=rules,
prefix_search=prefix_search, suffix_search=suffix_search, prefix_search=prefix_search, suffix_search=suffix_search,
infix_finditer=infix_finditer) infix_finditer=infix_finditer)
return tokenizer
def Tagger(self, vocab): def Tagger(self, vocab, **cfg):
if self.path: if self.path:
return Tagger.load(self.path / 'pos', vocab) return Tagger.load(self.path / 'pos', vocab)
else: else:
return Tagger.blank(vocab, Tagger.default_templates()) return Tagger.blank(vocab, Tagger.default_templates())
def Parser(self, vocab, blank=False): def Parser(self, vocab, **cfg):
if blank: if self.path and (self.path / 'dep').exists():
return Parser.blank(vocab, ArcEager, return Parser.load(self.path / 'dep', vocab, ArcEager)
features=self.parser_features, labels=self.parser_labels)
elif self.path and (self.path / 'deps').exists():
return Parser.load(self.path / 'deps', vocab, ArcEager)
else: else:
return None if 'features' not in cfg:
cfg['features'] = self.parser_features
if 'labels' not in cfg:
cfg['labels'] = self.parser_labels
return Parser.blank(vocab, ArcEager, **cfg)
def Entity(self, vocab, blank=False): def Entity(self, vocab, **cfg):
if blank: if self.path and (self.path / 'ner').exists():
return Parser.blank(vocab, BiluoPushDown,
features=self.entity_features, labels=self.entity_labels)
elif self.path and (self.path / 'ner').exists():
return Parser.load(self.path / 'ner', vocab, BiluoPushDown) return Parser.load(self.path / 'ner', vocab, BiluoPushDown)
else: else:
return None if 'features' not in cfg:
cfg['features'] = self.entity_features
if 'labels' not in cfg:
cfg['labels'] = self.entity_labels
return Parser.blank(vocab, BiluoPushDown, **cfg)
def Matcher(self, vocab): def Matcher(self, vocab, **cfg):
if self.path: if self.path:
return Matcher.load(self.path, vocab) return Matcher.load(self.path, vocab)
else: else:
return Matcher(vocab) return Matcher(vocab)
def Pipeline(self, nlp): def Pipeline(self, nlp, **cfg):
return [ pipeline = [nlp.tokenizer]
nlp.tokenizer, if nlp.tagger:
nlp.tagger, pipeline.append(nlp.tagger)
nlp.parser, if nlp.parser:
nlp.entity] pipeline.append(nlp.parser)
if nlp.entity:
pipeline.append(nlp.entity)
return pipeline
prefixes = tuple()
suffixes = tuple()
infixes = tuple()
tag_map = {}
tokenizer_exceptions = {}
parser_labels = {0: {'ROOT': True}} parser_labels = {0: {'ROOT': True}}
@ -169,6 +189,58 @@ class Language(object):
Defaults = BaseDefaults Defaults = BaseDefaults
lang = None lang = None
@classmethod
def blank(cls):
return cls(path=False, vocab=False, tokenizer=False, tagger=False,
parser=False, entity=False, matcher=False, serializer=False,
vectors=False, pipeline=False)
@classmethod
@contextmanager
def train(cls, path, gold_tuples, *configs):
if isinstance(path, basestring):
path = pathlib.Path(path)
tagger_cfg, parser_cfg, entity_cfg = configs
dep_model_dir = path / 'dep'
ner_model_dir = path / 'ner'
pos_model_dir = path / 'pos'
if dep_model_dir.exists():
shutil.rmtree(str(dep_model_dir))
if ner_model_dir.exists():
shutil.rmtree(str(ner_model_dir))
if pos_model_dir.exists():
shutil.rmtree(str(pos_model_dir))
dep_model_dir.mkdir()
ner_model_dir.mkdir()
pos_model_dir.mkdir()
if parser_cfg['pseudoprojective']:
# preprocess training data here before ArcEager.get_labels() is called
gold_tuples = PseudoProjectivity.preprocess_training_data(gold_tuples)
parser_cfg['labels'] = ArcEager.get_labels(gold_tuples)
entity_cfg['labels'] = BiluoPushDown.get_labels(gold_tuples)
with (dep_model_dir / 'config.json').open('wb') as file_:
json.dump(parser_cfg, file_)
with (ner_model_dir / 'config.json').open('wb') as file_:
json.dump(entity_cfg, file_)
with (pos_model_dir / 'config.json').open('wb') as file_:
json.dump(tagger_cfg, file_)
self = cls.blank()
self.path = path
self.vocab = self.defaults.Vocab()
self.defaults.parser_labels = parser_cfg['labels']
self.defaults.entity_labels = entity_cfg['labels']
self.tokenizer = self.defaults.Tokenizer(self.vocab)
self.tagger = self.defaults.Tagger(self.vocab, **tagger_cfg)
self.parser = self.defaults.Parser(self.vocab, **parser_cfg)
self.entity = self.defaults.Entity(self.vocab, **entity_cfg)
self.pipeline = self.defaults.Pipeline(self)
yield Trainer(self, gold_tuples)
self.end_training()
def __init__(self, def __init__(self,
path=None, path=None,
vocab=True, vocab=True,
@ -210,13 +282,19 @@ class Language(object):
self.path = path self.path = path
defaults = defaults if defaults is not True else self.get_defaults(self.path) defaults = defaults if defaults is not True else self.get_defaults(self.path)
self.defaults = defaults
self.vocab = vocab if vocab is not True else defaults.Vocab(vectors=vectors) self.vocab = vocab if vocab is not True else defaults.Vocab(vectors=vectors)
self.tokenizer = tokenizer if tokenizer is not True else defaults.Tokenizer(self.vocab) self.tokenizer = tokenizer if tokenizer is not True else defaults.Tokenizer(self.vocab)
self.tagger = tagger if tagger is not True else defaults.Tagger(self.vocab) self.tagger = tagger if tagger is not True else defaults.Tagger(self.vocab)
self.entity = entity if entity is not True else defaults.Entity(self.vocab) self.entity = entity if entity is not True else defaults.Entity(self.vocab)
self.parser = parser if parser is not True else defaults.Parser(self.vocab) self.parser = parser if parser is not True else defaults.Parser(self.vocab)
self.matcher = matcher if matcher is not True else defaults.Matcher(self.vocab) self.matcher = matcher if matcher is not True else defaults.Matcher(self.vocab)
self.pipeline = pipeline(self) if pipeline is not True else defaults.Pipeline(self) if pipeline in (None, False):
self.pipeline = []
elif pipeline is True:
self.pipeline = defaults.Pipeline(self)
else:
self.pipeline = pipeline(self)
def __reduce__(self): def __reduce__(self):
args = ( args = (
@ -276,15 +354,18 @@ class Language(object):
def end_training(self, path=None): def end_training(self, path=None):
if path is None: if path is None:
path = self.path path = self.path
if self.parser: elif isinstance(path, basestring):
self.parser.model.end_training() path = pathlib.Path(path)
self.parser.model.dump(path / 'deps' / 'model')
if self.entity:
self.entity.model.end_training()
self.entity.model.dump(path / 'ner' / 'model')
if self.tagger: if self.tagger:
self.tagger.model.end_training() self.tagger.model.end_training()
self.tagger.model.dump(path / 'pos' / 'model') self.tagger.model.dump(str(path / 'pos' / 'model'))
if self.parser:
self.parser.model.end_training()
self.parser.model.dump(str(path / 'dep' / 'model'))
if self.entity:
self.entity.model.end_training()
self.entity.model.dump(str(path / 'ner' / 'model'))
strings_loc = path / 'vocab' / 'strings.json' strings_loc = path / 'vocab' / 'strings.json'
with strings_loc.open('w', encoding='utf8') as file_: with strings_loc.open('w', encoding='utf8') as file_:
@ -307,7 +388,7 @@ class Language(object):
else: else:
entity_iob_freqs = [] entity_iob_freqs = []
entity_type_freqs = [] entity_type_freqs = []
with (path / 'vocab' / 'serializer.json').open('w') as file_: with (path / 'vocab' / 'serializer.json').open('wb') as file_:
file_.write( file_.write(
json.dumps([ json.dumps([
(TAG, tagger_freqs), (TAG, tagger_freqs),

View File

@ -70,6 +70,15 @@ class Scorer(object):
def ents_f(self): def ents_f(self):
return self.ner.fscore * 100 return self.ner.fscore * 100
@property
def scores(self):
return {
'uas': self.uas, 'las': self.las,
'ents_p': self.ents_p, 'ents_r': self.ents_r, 'ents_f': self.ents_f,
'tags_acc': self.tags_acc,
'token_acc': self.token_acc
}
def score(self, tokens, gold, verbose=False, punct_labels=('p', 'punct')): def score(self, tokens, gold, verbose=False, punct_labels=('p', 'punct')):
assert len(tokens) == len(gold) assert len(tokens) == len(gold)

View File

@ -1,11 +1,11 @@
from libc.stdint cimport int64_t
from cymem.cymem cimport Pool from cymem.cymem cimport Pool
from preshed.maps cimport PreshMap from preshed.maps cimport PreshMap
from murmurhash.mrmr cimport hash64 from murmurhash.mrmr cimport hash64
from .typedefs cimport attr_t
from libc.stdint cimport int64_t from .typedefs cimport attr_t, hash_t
from .typedefs cimport hash_t
cpdef hash_t hash_string(unicode string) except 0 cpdef hash_t hash_string(unicode string) except 0

View File

@ -312,12 +312,6 @@ cdef class ArcEager(TransitionSystem):
# Count frequencies, for use in encoder # Count frequencies, for use in encoder
self.freqs[HEAD][gold.c.heads[i] - i] += 1 self.freqs[HEAD][gold.c.heads[i] - i] += 1
self.freqs[DEP][gold.c.labels[i]] += 1 self.freqs[DEP][gold.c.labels[i]] += 1
for end, brackets in gold.brackets.items():
for start, label_strs in brackets.items():
gold.c.brackets[start][end] = 1
for label_str in label_strs:
# Add the encoded label to the set
gold.brackets[end][start].add(self.strings[label_str])
cdef Transition lookup_transition(self, object name) except *: cdef Transition lookup_transition(self, object name) except *:
if '-' in name: if '-' in name:

View File

@ -83,8 +83,7 @@ cdef class Parser:
with (path / 'config.json').open() as file_: with (path / 'config.json').open() as file_:
cfg = json.load(file_) cfg = json.load(file_)
moves = moves_class(vocab.strings, cfg['labels']) moves = moves_class(vocab.strings, cfg['labels'])
templates = get_templates(cfg['features']) model = ParserModel(cfg['features'])
model = ParserModel(templates)
if (path / 'model').exists(): if (path / 'model').exists():
model.load(str(path / 'model')) model.load(str(path / 'model'))
return cls(vocab, moves, model, **cfg) return cls(vocab, moves, model, **cfg)
@ -96,7 +95,6 @@ cdef class Parser:
model = ParserModel(templates) model = ParserModel(templates)
return cls(vocab, moves, model, **cfg) return cls(vocab, moves, model, **cfg)
def __init__(self, Vocab vocab, transition_system, ParserModel model, **cfg): def __init__(self, Vocab vocab, transition_system, ParserModel model, **cfg):
self.moves = transition_system self.moves = transition_system
self.model = model self.model = model
@ -191,7 +189,7 @@ cdef class Parser:
free(eg.is_valid) free(eg.is_valid)
return 0 return 0
def train(self, Doc tokens, GoldParse gold): def update(self, Doc tokens, GoldParse gold):
self.moves.preprocess_gold(gold) self.moves.preprocess_gold(gold)
cdef StateClass stcls = StateClass.init(tokens.c, tokens.length) cdef StateClass stcls = StateClass.init(tokens.c, tokens.length)
self.moves.initialize_state(stcls.c) self.moves.initialize_state(stcls.c)

View File

@ -154,7 +154,7 @@ cdef class Tagger:
model.load(str(path / 'model')) model.load(str(path / 'model'))
return cls(vocab, model) return cls(vocab, model)
def __init__(self, Vocab vocab, TaggerModel model): def __init__(self, Vocab vocab, TaggerModel model, **cfg):
self.vocab = vocab self.vocab = vocab
self.model = model self.model = model
# TODO: Move this to tag map # TODO: Move this to tag map
@ -208,7 +208,9 @@ cdef class Tagger:
self(doc) self(doc)
yield doc yield doc
def train(self, Doc tokens, object gold_tag_strs): def update(self, Doc tokens, object gold):
if hasattr(gold, 'tags'):
gold_tag_strs = list(gold.tags)
assert len(tokens) == len(gold_tag_strs) assert len(tokens) == len(gold_tag_strs)
for tag in gold_tag_strs: for tag in gold_tag_strs:
if tag != None and tag not in self.tag_names: if tag != None and tag not in self.tag_names:

59
spacy/train.py Normal file
View File

@ -0,0 +1,59 @@
from __future__ import absolute_import
from __future__ import unicode_literals
import random
from .gold import GoldParse
from .scorer import Scorer
class Trainer(object):
def __init__(self, nlp, gold_tuples):
self.nlp = nlp
self.gold_tuples = gold_tuples
def epochs(self, nr_epoch, augment_data=None):
def _epoch():
for raw_text, paragraph_tuples in self.gold_tuples:
if augment_data is not None:
raw_text, paragraph_tuples = augment_data(raw_text, paragraph_tuples)
docs = self.make_docs(raw_text, paragraph_tuples)
golds = self.make_golds(docs, paragraph_tuples)
for doc, gold in zip(docs, golds):
yield doc, gold
for itn in range(nr_epoch):
random.shuffle(self.gold_tuples)
yield _epoch()
def update(self, doc, gold):
for process in self.nlp.pipeline[1:]:
if hasattr(process, 'update'):
process.update(doc, gold)
process(doc)
return doc
def evaluate(self, dev_sents):
scorer = Scorer()
for raw_text, paragraph_tuples in dev_sents:
docs = self.make_docs(raw_text, paragraph_tuples)
golds = self.make_golds(docs, paragraph_tuples)
for doc, gold in zip(docs, golds):
for process in self.nlp.pipeline[1:]:
process(doc)
scorer.score(doc, gold)
return scorer
def make_docs(self, raw_text, paragraph_tuples):
if raw_text is not None:
return [self.nlp.tokenizer(raw_text)]
else:
return [self.nlp.tokenizer.tokens_from_list(sent_tuples[0][1])
for sent_tuples in paragraph_tuples]
def make_golds(self, docs, paragraph_tuples):
if len(docs) == 1:
return [GoldParse(docs[0], sent_tuples[0])
for sent_tuples in paragraph_tuples]
else:
return [GoldParse(doc, sent_tuples[0])
for doc, sent_tuples in zip(docs, paragraph_tuples)]

View File

@ -13,6 +13,7 @@ try:
except NameError: except NameError:
basestring = str basestring = str
LANGUAGES = {} LANGUAGES = {}
_data_path = pathlib.Path(__file__).parent / 'data' _data_path = pathlib.Path(__file__).parent / 'data'

View File

@ -177,7 +177,7 @@ cdef class Vocab:
value = self.strings[value] value = self.strings[value]
if attr == PROB: if attr == PROB:
lex.prob = value lex.prob = value
else: elif value is not None:
Lexeme.set_struct_attr(lex, attr, value) Lexeme.set_struct_attr(lex, attr, value)
if is_oov: if is_oov:
lex.id = 0 lex.id = 0