Set architecture in textcat example

This commit is contained in:
Matthew Honnibal 2019-02-23 11:57:59 +01:00
parent e9dd5943b9
commit 5063d999e5

View File

@ -41,7 +41,9 @@ def main(model=None, output_dir=None, n_iter=20, n_texts=2000):
# add the text classifier to the pipeline if it doesn't exist # add the text classifier to the pipeline if it doesn't exist
# nlp.create_pipe works for built-ins that are registered with spaCy # nlp.create_pipe works for built-ins that are registered with spaCy
if "textcat" not in nlp.pipe_names: if "textcat" not in nlp.pipe_names:
textcat = nlp.create_pipe("textcat") textcat = nlp.create_pipe("textcat", config={
"architecture": "simple_cnn",
"exclusive_classes": True})
nlp.add_pipe(textcat, last=True) nlp.add_pipe(textcat, last=True)
# otherwise, get it, so we can add labels to it # otherwise, get it, so we can add labels to it
else: else:
@ -70,7 +72,7 @@ def main(model=None, output_dir=None, n_iter=20, n_texts=2000):
for i in range(n_iter): for i in range(n_iter):
losses = {} losses = {}
# batch up the examples using spaCy's minibatch # batch up the examples using spaCy's minibatch
batches = minibatch(train_data, size=compounding(4.0, 16.0, 1.001)) batches = minibatch(train_data, size=compounding(4.0, 32.0, 1.001))
for batch in batches: for batch in batches:
texts, annotations = zip(*batch) texts, annotations = zip(*batch)
nlp.update(texts, annotations, sgd=optimizer, drop=0.2, losses=losses) nlp.update(texts, annotations, sgd=optimizer, drop=0.2, losses=losses)