mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
Pass option for pre-trained vectors in parser
This commit is contained in:
parent
8665a77f48
commit
5ff2491f24
|
@ -245,7 +245,7 @@ cdef class Parser:
|
|||
parser_maxout_pieces = util.env_opt('parser_maxout_pieces', 2)
|
||||
embed_size = util.env_opt('embed_size', 4000)
|
||||
tensors = fine_tune(Tok2Vec(token_vector_width, embed_size,
|
||||
preprocess=doc2feats()))
|
||||
pretrained_dims=cfg.get('pretrained_dims')))
|
||||
if parser_maxout_pieces == 1:
|
||||
lower = PrecomputableAffine(hidden_width if depth >= 1 else nr_class,
|
||||
nF=cls.nr_feature,
|
||||
|
@ -391,9 +391,10 @@ cdef class Parser:
|
|||
if isinstance(tokvecses, np.ndarray):
|
||||
tokvecses = [tokvecses]
|
||||
|
||||
tokvecs = self.model[0].ops.flatten(tokvecses)
|
||||
if USE_FINE_TUNE:
|
||||
tokvecs = self.model[0].ops.flatten(self.model[0]((docs, tokvecses)))
|
||||
else:
|
||||
tokvecs = self.model[0].ops.flatten(tokvecses)
|
||||
|
||||
nr_state = len(docs)
|
||||
nr_class = self.moves.n_moves
|
||||
|
@ -451,9 +452,10 @@ cdef class Parser:
|
|||
cdef Doc doc
|
||||
cdef int nr_class = self.moves.n_moves
|
||||
cdef StateClass stcls, output
|
||||
tokvecs = self.model[0].ops.flatten(tokvecses)
|
||||
if USE_FINE_TUNE:
|
||||
tokvecs = self.model[0].ops.flatten(self.model[0]((docs, tokvecses)))
|
||||
else:
|
||||
tokvecs = self.model[0].ops.flatten(tokvecses)
|
||||
cuda_stream = get_cuda_stream()
|
||||
state2vec, vec2scores = self.get_batch_model(len(docs), tokvecs,
|
||||
cuda_stream, 0.0)
|
||||
|
@ -533,6 +535,8 @@ cdef class Parser:
|
|||
if USE_FINE_TUNE:
|
||||
my_tokvecs, bp_my_tokvecs = self.model[0].begin_update(docs_tokvecs, drop=drop)
|
||||
tokvecs = self.model[0].ops.flatten(my_tokvecs)
|
||||
else:
|
||||
tokvecs = self.model[0].ops.flatten(docs_tokvecs[1])
|
||||
|
||||
cuda_stream = get_cuda_stream()
|
||||
|
||||
|
@ -603,11 +607,11 @@ cdef class Parser:
|
|||
docs, tokvecs = docs_tokvecs
|
||||
lengths = [len(d) for d in docs]
|
||||
assert min(lengths) >= 1
|
||||
tokvecs = self.model[0].ops.flatten(tokvecs)
|
||||
if USE_FINE_TUNE:
|
||||
my_tokvecs, bp_my_tokvecs = self.model[0].begin_update(docs_tokvecs, drop=drop)
|
||||
tokvecs += self.model[0].ops.flatten(my_tokvecs)
|
||||
|
||||
else:
|
||||
tokvecs = self.model[0].ops.flatten(tokvecs)
|
||||
states = self.moves.init_batch(docs)
|
||||
for gold in golds:
|
||||
self.moves.preprocess_gold(gold)
|
||||
|
@ -775,6 +779,7 @@ cdef class Parser:
|
|||
for label in labels:
|
||||
self.moves.add_action(action, label)
|
||||
if self.model is True:
|
||||
cfg['pretrained_dims'] = self.vocab.vectors_length
|
||||
self.model, cfg = self.Model(self.moves.n_moves, **cfg)
|
||||
self.cfg.update(cfg)
|
||||
|
||||
|
@ -856,9 +861,11 @@ cdef class Parser:
|
|||
msg = util.from_bytes(bytes_data, deserializers, exclude)
|
||||
if 'model' not in exclude:
|
||||
if self.model is True:
|
||||
self.model, cfg = self.Model(self.moves.n_moves)
|
||||
self.model, cfg = self.Model(self.moves.n_moves,
|
||||
pretrained_dims=self.vocab.vectors_length)
|
||||
else:
|
||||
cfg = {}
|
||||
cfg['pretrained_dims'] = self.vocab.vectors_length
|
||||
if 'tok2vec_model' in msg:
|
||||
self.model[0].from_bytes(msg['tok2vec_model'])
|
||||
if 'lower_model' in msg:
|
||||
|
|
Loading…
Reference in New Issue
Block a user