mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 17:24:41 +03:00
Gate parser fine-tuning behind feature flag
This commit is contained in:
parent
dbdd8afc4b
commit
bbace204be
|
@ -59,8 +59,9 @@ from ..structs cimport TokenC
|
|||
from ..tokens.doc cimport Doc
|
||||
from ..strings cimport StringStore
|
||||
from ..gold cimport GoldParse
|
||||
from ..attrs cimport TAG, DEP
|
||||
from ..attrs cimport ID, TAG, DEP, ORTH, NORM, PREFIX, SUFFIX, TAG
|
||||
|
||||
USE_FINE_TUNE = True
|
||||
|
||||
def get_templates(*args, **kwargs):
|
||||
return []
|
||||
|
@ -237,7 +238,8 @@ cdef class Parser:
|
|||
token_vector_width = util.env_opt('token_vector_width', token_vector_width)
|
||||
hidden_width = util.env_opt('hidden_width', hidden_width)
|
||||
parser_maxout_pieces = util.env_opt('parser_maxout_pieces', 2)
|
||||
tensors = fine_tune(Tok2Vec(token_vector_width, 7500, preprocess=doc2feats()))
|
||||
tensors = fine_tune(Tok2Vec(token_vector_width, 7500,
|
||||
preprocess=doc2feats(cols=[ID, NORM, PREFIX, SUFFIX, TAG])))
|
||||
if parser_maxout_pieces == 1:
|
||||
lower = PrecomputableAffine(hidden_width if depth >= 1 else nr_class,
|
||||
nF=cls.nr_feature,
|
||||
|
@ -367,6 +369,7 @@ cdef class Parser:
|
|||
tokvecses = [tokvecses]
|
||||
|
||||
tokvecs = self.model[0].ops.flatten(tokvecses)
|
||||
if USE_FINE_TUNE:
|
||||
tokvecs += self.model[0].ops.flatten(self.model[0]((docs, tokvecses)))
|
||||
|
||||
nr_state = len(docs)
|
||||
|
@ -419,6 +422,7 @@ cdef class Parser:
|
|||
cdef int nr_class = self.moves.n_moves
|
||||
cdef StateClass stcls, output
|
||||
tokvecs = self.model[0].ops.flatten(tokvecses)
|
||||
if USE_FINE_TUNE:
|
||||
tokvecs += self.model[0].ops.flatten(self.model[0]((docs, tokvecses)))
|
||||
cuda_stream = get_cuda_stream()
|
||||
state2vec, vec2scores = self.get_batch_model(len(docs), tokvecs,
|
||||
|
@ -460,6 +464,7 @@ cdef class Parser:
|
|||
if isinstance(docs, Doc) and isinstance(golds, GoldParse):
|
||||
docs = [docs]
|
||||
golds = [golds]
|
||||
if USE_FINE_TUNE:
|
||||
my_tokvecs, bp_my_tokvecs = self.model[0].begin_update(docs_tokvecs, drop=0.)
|
||||
my_tokvecs = self.model[0].ops.flatten(my_tokvecs)
|
||||
tokvecs += my_tokvecs
|
||||
|
@ -513,6 +518,7 @@ cdef class Parser:
|
|||
self._make_updates(d_tokvecs,
|
||||
backprops, sgd, cuda_stream)
|
||||
d_tokvecs = self.model[0].ops.unflatten(d_tokvecs, [len(d) for d in docs])
|
||||
if USE_FINE_TUNE:
|
||||
bp_my_tokvecs(d_tokvecs, sgd=sgd)
|
||||
return d_tokvecs
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user