mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-10 19:57:17 +03:00
Add support for pretrained tok2vec to ud-train
This commit is contained in:
parent
93be3ad038
commit
681258e29b
|
@ -305,10 +305,28 @@ def initialize_pipeline(nlp, docs, golds, config, device):
|
|||
nlp.tagger.add_label(tag)
|
||||
if torch is not None and device != -1:
|
||||
torch.set_default_tensor_type('torch.cuda.FloatTensor')
|
||||
return nlp.begin_training(
|
||||
optimizer = nlp.begin_training(
|
||||
lambda: golds_to_gold_tuples(docs, golds), device=device,
|
||||
subword_features=config.subword_features, conv_depth=config.conv_depth,
|
||||
bilstm_depth=config.bilstm_depth)
|
||||
if config.pretrained_tok2vec:
|
||||
_load_pretrained_tok2vec(nlp, config.pretrained_tok2vec)
|
||||
return optimizer
|
||||
|
||||
|
||||
def _load_pretrained_tok2vec(nlp, loc):
|
||||
"""Load pre-trained weights for the 'token-to-vector' part of the component
|
||||
models, which is typically a CNN. See 'spacy pretrain'. Experimental.
|
||||
"""
|
||||
with Path(loc).open('rb') as file_:
|
||||
weights_data = file_.read()
|
||||
loaded = []
|
||||
for name, component in nlp.pipeline:
|
||||
if hasattr(component, 'model') and hasattr(component.model, 'tok2vec'):
|
||||
component.tok2vec.from_bytes(weights_data)
|
||||
loaded.append(name)
|
||||
return loaded
|
||||
|
||||
|
||||
|
||||
########################
|
||||
|
@ -318,9 +336,9 @@ def initialize_pipeline(nlp, docs, golds, config, device):
|
|||
class Config(object):
|
||||
def __init__(self, vectors=None, max_doc_length=10, multitask_tag=False,
|
||||
multitask_sent=False, multitask_dep=False, multitask_vectors=None,
|
||||
bilstm_depth=0, nr_epoch=30, min_batch_size=100, max_batch_size=1000,
|
||||
batch_by_words=True, dropout=0.2, conv_depth=4, subword_features=True,
|
||||
vectors_dir=None):
|
||||
bilstm_depth=0, nr_epoch=30, min_batch_size=750, max_batch_size=750,
|
||||
batch_by_words=True, dropout=0.1, conv_depth=4, subword_features=True,
|
||||
vectors_dir=None, pretrained_tok2vec=None):
|
||||
if vectors_dir is not None:
|
||||
if vectors is None:
|
||||
vectors = True
|
||||
|
|
Loading…
Reference in New Issue
Block a user