mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 01:46:28 +03:00
Remove unused model_dir option
As noted in #845, the `model_dir` argument was not being used. I've removed it for now, although it would be good to have this option restored and working.
This commit is contained in:
parent
724e51ed47
commit
c031c677cc
|
@ -18,7 +18,7 @@ except ImportError:
|
|||
import pickle
|
||||
|
||||
|
||||
def train(model_dir, train_loc, dev_loc, shape, settings):
|
||||
def train(train_loc, dev_loc, shape, settings):
|
||||
train_texts1, train_texts2, train_labels = read_snli(train_loc)
|
||||
dev_texts1, dev_texts2, dev_labels = read_snli(dev_loc)
|
||||
|
||||
|
@ -44,7 +44,7 @@ def train(model_dir, train_loc, dev_loc, shape, settings):
|
|||
batch_size=settings['batch_size'])
|
||||
if not (nlp.path / 'similarity').exists():
|
||||
(nlp.path / 'similarity').mkdir()
|
||||
print("Saving to", model_dir / 'similarity')
|
||||
print("Saving to", nlp.path / 'similarity')
|
||||
weights = model.get_weights()
|
||||
with (nlp.path / 'similarity' / 'model').open('wb') as file_:
|
||||
pickle.dump(weights[1:], file_)
|
||||
|
@ -68,8 +68,8 @@ def evaluate(model_dir, dev_loc):
|
|||
return correct, total
|
||||
|
||||
|
||||
def demo(model_dir):
|
||||
nlp = spacy.load('en', path=model_dir,
|
||||
def demo():
|
||||
nlp = spacy.load('en',
|
||||
create_pipeline=create_similarity_pipeline)
|
||||
doc1 = nlp(u'What were the best crime fiction books in 2016?')
|
||||
doc2 = nlp(
|
||||
|
@ -98,7 +98,6 @@ def read_snli(path):
|
|||
|
||||
@plac.annotations(
|
||||
mode=("Mode to execute", "positional", None, str, ["train", "evaluate", "demo"]),
|
||||
model_dir=("Path to spaCy model directory", "positional", None, Path),
|
||||
train_loc=("Path to training data", "positional", None, Path),
|
||||
dev_loc=("Path to development data", "positional", None, Path),
|
||||
max_length=("Length to truncate sentences", "option", "L", int),
|
||||
|
@ -110,7 +109,7 @@ def read_snli(path):
|
|||
tree_truncate=("Truncate sentences by tree distance", "flag", "T", bool),
|
||||
gru_encode=("Encode sentences with bidirectional GRU", "flag", "E", bool),
|
||||
)
|
||||
def main(mode, model_dir, train_loc, dev_loc,
|
||||
def main(mode, train_loc, dev_loc,
|
||||
tree_truncate=False,
|
||||
gru_encode=False,
|
||||
max_length=100,
|
||||
|
@ -129,12 +128,12 @@ def main(mode, model_dir, train_loc, dev_loc,
|
|||
'gru_encode': gru_encode
|
||||
}
|
||||
if mode == 'train':
|
||||
train(model_dir, train_loc, dev_loc, shape, settings)
|
||||
train(train_loc, dev_loc, shape, settings)
|
||||
elif mode == 'evaluate':
|
||||
correct, total = evaluate(model_dir, dev_loc)
|
||||
correct, total = evaluate(dev_loc)
|
||||
print(correct, '/', total, correct / total)
|
||||
else:
|
||||
demo(model_dir)
|
||||
demo()
|
||||
|
||||
if __name__ == '__main__':
|
||||
plac.call(main)
|
||||
|
|
Loading…
Reference in New Issue
Block a user