Merge pull request #6106 from svlandeg/feature/textcat-quickstart

This commit is contained in:
Ines Montani 2020-09-23 10:11:45 +02:00 committed by GitHub
commit 888f936a73
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 88 additions and 2 deletions

View File

@ -36,7 +36,7 @@ def init_config_cli(
""" """
Generate a starter config.cfg for training. Based on your requirements Generate a starter config.cfg for training. Based on your requirements
specified via the CLI arguments, this command generates a config with the specified via the CLI arguments, this command generates a config with the
optimal settings for you use case. This includes the choice of architecture, optimal settings for your use case. This includes the choice of architecture,
pretrained weights and related hyperparameters. pretrained weights and related hyperparameters.
DOCS: https://nightly.spacy.io/api/cli#init-config DOCS: https://nightly.spacy.io/api/cli#init-config

View File

@ -93,6 +93,49 @@ grad_factor = 1.0
@layers = "reduce_mean.v1" @layers = "reduce_mean.v1"
{% endif -%} {% endif -%}
{% if "entity_linker" in components -%}
[components.entity_linker]
factory = "entity_linker"
get_candidates = {"@misc":"spacy.CandidateGenerator.v1"}
incl_context = true
incl_prior = true
[components.entity_linker.model]
@architectures = "spacy.EntityLinker.v1"
nO = null
[components.entity_linker.model.tok2vec]
@architectures = "spacy-transformers.TransformerListener.v1"
grad_factor = 1.0
[components.entity_linker.model.tok2vec.pooling]
@layers = "reduce_mean.v1"
{% endif -%}
{% if "textcat" in components %}
[components.textcat]
factory = "textcat"
{% if optimize == "accuracy" %}
[components.textcat.model]
@architectures = "spacy.TextCatEnsemble.v1"
exclusive_classes = false
width = 64
conv_depth = 2
embed_size = 2000
window_size = 1
ngram_size = 1
nO = null
{% else -%}
[components.textcat.model]
@architectures = "spacy.TextCatBOW.v1"
exclusive_classes = false
ngram_size = 1
no_output_layer = false
{%- endif %}
{%- endif %}
{# NON-TRANSFORMER PIPELINE #} {# NON-TRANSFORMER PIPELINE #}
{% else -%} {% else -%}
@ -167,10 +210,50 @@ nO = null
@architectures = "spacy.Tok2VecListener.v1" @architectures = "spacy.Tok2VecListener.v1"
width = ${components.tok2vec.model.encode.width} width = ${components.tok2vec.model.encode.width}
{% endif %} {% endif %}
{% if "entity_linker" in components -%}
[components.entity_linker]
factory = "entity_linker"
get_candidates = {"@misc":"spacy.CandidateGenerator.v1"}
incl_context = true
incl_prior = true
[components.entity_linker.model]
@architectures = "spacy.EntityLinker.v1"
nO = null
[components.entity_linker.model.tok2vec]
@architectures = "spacy.Tok2VecListener.v1"
width = ${components.tok2vec.model.encode.width}
{% endif %}
{% if "textcat" in components %}
[components.textcat]
factory = "textcat"
{% if optimize == "accuracy" %}
[components.textcat.model]
@architectures = "spacy.TextCatEnsemble.v1"
exclusive_classes = false
width = 64
conv_depth = 2
embed_size = 2000
window_size = 1
ngram_size = 1
nO = null
{% else -%}
[components.textcat.model]
@architectures = "spacy.TextCatBOW.v1"
exclusive_classes = false
ngram_size = 1
no_output_layer = false
{%- endif %}
{%- endif %}
{% endif %} {% endif %}
{% for pipe in components %} {% for pipe in components %}
{% if pipe not in ["tagger", "parser", "ner"] %} {% if pipe not in ["tagger", "parser", "ner", "textcat", "entity_linker"] %}
{# Other components defined by the user: we just assume they're factories #} {# Other components defined by the user: we just assume they're factories #}
[components.{{ pipe }}] [components.{{ pipe }}]
factory = "{{ pipe }}" factory = "{{ pipe }}"
@ -245,3 +328,6 @@ ents_f = {{ (1.0 / components|length)|round(2) }}
ents_p = 0.0 ents_p = 0.0
ents_r = 0.0 ents_r = 0.0
{%- endif -%} {%- endif -%}
{%- if "textcat" in components %}
cats_score = {{ (1.0 / components|length)|round(2) }}
{%- endif -%}