mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-27 17:54:39 +03:00
Gate parser fine-tuning behind feature flag
This commit is contained in:
parent
dbdd8afc4b
commit
bbace204be
|
@ -59,8 +59,9 @@ from ..structs cimport TokenC
|
||||||
from ..tokens.doc cimport Doc
|
from ..tokens.doc cimport Doc
|
||||||
from ..strings cimport StringStore
|
from ..strings cimport StringStore
|
||||||
from ..gold cimport GoldParse
|
from ..gold cimport GoldParse
|
||||||
from ..attrs cimport TAG, DEP
|
from ..attrs cimport ID, TAG, DEP, ORTH, NORM, PREFIX, SUFFIX, TAG
|
||||||
|
|
||||||
|
USE_FINE_TUNE = True
|
||||||
|
|
||||||
def get_templates(*args, **kwargs):
|
def get_templates(*args, **kwargs):
|
||||||
return []
|
return []
|
||||||
|
@ -237,7 +238,8 @@ cdef class Parser:
|
||||||
token_vector_width = util.env_opt('token_vector_width', token_vector_width)
|
token_vector_width = util.env_opt('token_vector_width', token_vector_width)
|
||||||
hidden_width = util.env_opt('hidden_width', hidden_width)
|
hidden_width = util.env_opt('hidden_width', hidden_width)
|
||||||
parser_maxout_pieces = util.env_opt('parser_maxout_pieces', 2)
|
parser_maxout_pieces = util.env_opt('parser_maxout_pieces', 2)
|
||||||
tensors = fine_tune(Tok2Vec(token_vector_width, 7500, preprocess=doc2feats()))
|
tensors = fine_tune(Tok2Vec(token_vector_width, 7500,
|
||||||
|
preprocess=doc2feats(cols=[ID, NORM, PREFIX, SUFFIX, TAG])))
|
||||||
if parser_maxout_pieces == 1:
|
if parser_maxout_pieces == 1:
|
||||||
lower = PrecomputableAffine(hidden_width if depth >= 1 else nr_class,
|
lower = PrecomputableAffine(hidden_width if depth >= 1 else nr_class,
|
||||||
nF=cls.nr_feature,
|
nF=cls.nr_feature,
|
||||||
|
@ -367,6 +369,7 @@ cdef class Parser:
|
||||||
tokvecses = [tokvecses]
|
tokvecses = [tokvecses]
|
||||||
|
|
||||||
tokvecs = self.model[0].ops.flatten(tokvecses)
|
tokvecs = self.model[0].ops.flatten(tokvecses)
|
||||||
|
if USE_FINE_TUNE:
|
||||||
tokvecs += self.model[0].ops.flatten(self.model[0]((docs, tokvecses)))
|
tokvecs += self.model[0].ops.flatten(self.model[0]((docs, tokvecses)))
|
||||||
|
|
||||||
nr_state = len(docs)
|
nr_state = len(docs)
|
||||||
|
@ -419,6 +422,7 @@ cdef class Parser:
|
||||||
cdef int nr_class = self.moves.n_moves
|
cdef int nr_class = self.moves.n_moves
|
||||||
cdef StateClass stcls, output
|
cdef StateClass stcls, output
|
||||||
tokvecs = self.model[0].ops.flatten(tokvecses)
|
tokvecs = self.model[0].ops.flatten(tokvecses)
|
||||||
|
if USE_FINE_TUNE:
|
||||||
tokvecs += self.model[0].ops.flatten(self.model[0]((docs, tokvecses)))
|
tokvecs += self.model[0].ops.flatten(self.model[0]((docs, tokvecses)))
|
||||||
cuda_stream = get_cuda_stream()
|
cuda_stream = get_cuda_stream()
|
||||||
state2vec, vec2scores = self.get_batch_model(len(docs), tokvecs,
|
state2vec, vec2scores = self.get_batch_model(len(docs), tokvecs,
|
||||||
|
@ -460,6 +464,7 @@ cdef class Parser:
|
||||||
if isinstance(docs, Doc) and isinstance(golds, GoldParse):
|
if isinstance(docs, Doc) and isinstance(golds, GoldParse):
|
||||||
docs = [docs]
|
docs = [docs]
|
||||||
golds = [golds]
|
golds = [golds]
|
||||||
|
if USE_FINE_TUNE:
|
||||||
my_tokvecs, bp_my_tokvecs = self.model[0].begin_update(docs_tokvecs, drop=0.)
|
my_tokvecs, bp_my_tokvecs = self.model[0].begin_update(docs_tokvecs, drop=0.)
|
||||||
my_tokvecs = self.model[0].ops.flatten(my_tokvecs)
|
my_tokvecs = self.model[0].ops.flatten(my_tokvecs)
|
||||||
tokvecs += my_tokvecs
|
tokvecs += my_tokvecs
|
||||||
|
@ -513,6 +518,7 @@ cdef class Parser:
|
||||||
self._make_updates(d_tokvecs,
|
self._make_updates(d_tokvecs,
|
||||||
backprops, sgd, cuda_stream)
|
backprops, sgd, cuda_stream)
|
||||||
d_tokvecs = self.model[0].ops.unflatten(d_tokvecs, [len(d) for d in docs])
|
d_tokvecs = self.model[0].ops.unflatten(d_tokvecs, [len(d) for d in docs])
|
||||||
|
if USE_FINE_TUNE:
|
||||||
bp_my_tokvecs(d_tokvecs, sgd=sgd)
|
bp_my_tokvecs(d_tokvecs, sgd=sgd)
|
||||||
return d_tokvecs
|
return d_tokvecs
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user