add entity_linker to jinja template

This commit is contained in:
svlandeg 2020-09-22 10:40:05 +02:00
parent 135de82a2d
commit 396b33257f
2 changed files with 34 additions and 2 deletions

View File

@ -36,7 +36,7 @@ def init_config_cli(
""" """
Generate a starter config.cfg for training. Based on your requirements Generate a starter config.cfg for training. Based on your requirements
specified via the CLI arguments, this command generates a config with the specified via the CLI arguments, this command generates a config with the
optimal settings for you use case. This includes the choice of architecture, optimal settings for your use case. This includes the choice of architecture,
pretrained weights and related hyperparameters. pretrained weights and related hyperparameters.
DOCS: https://nightly.spacy.io/api/cli#init-config DOCS: https://nightly.spacy.io/api/cli#init-config

View File

@ -93,6 +93,22 @@ grad_factor = 1.0
@layers = "reduce_mean.v1" @layers = "reduce_mean.v1"
{% endif -%} {% endif -%}
{% if "entity_linker" in components -%}
[components.entity_linker]
factory = "entity_linker"
get_candidates = {"@misc":"spacy.CandidateGenerator.v1"}
incl_context = true
incl_prior = true
[components.entity_linker.model]
@architectures = "spacy.EntityLinker.v1"
nO = null
[components.entity_linker.model.tok2vec]
@architectures = "spacy-transformers.TransformerListener.v1"
grad_factor = 1.0
{% endif -%}
{% if "textcat" in components %} {% if "textcat" in components %}
[components.textcat] [components.textcat]
factory = "textcat" factory = "textcat"
@ -191,6 +207,22 @@ nO = null
width = ${components.tok2vec.model.encode.width} width = ${components.tok2vec.model.encode.width}
{% endif %} {% endif %}
{% if "entity_linker" in components -%}
[components.entity_linker]
factory = "entity_linker"
get_candidates = {"@misc":"spacy.CandidateGenerator.v1"}
incl_context = true
incl_prior = true
[components.entity_linker.model]
@architectures = "spacy.EntityLinker.v1"
nO = null
[components.entity_linker.model.tok2vec]
@architectures = "spacy.Tok2VecListener.v1"
width = ${components.tok2vec.model.encode.width}
{% endif %}
{% if "textcat" in components %} {% if "textcat" in components %}
[components.textcat] [components.textcat]
factory = "textcat" factory = "textcat"
@ -216,7 +248,7 @@ ngram_size = 1
{% endif %} {% endif %}
{% for pipe in components %} {% for pipe in components %}
{% if pipe not in ["tagger", "parser", "ner", "textcat"] %} {% if pipe not in ["tagger", "parser", "ner", "textcat", "entity_linker"] %}
{# Other components defined by the user: we just assume they're factories #} {# Other components defined by the user: we just assume they're factories #}
[components.{{ pipe }}] [components.{{ pipe }}]
factory = "{{ pipe }}" factory = "{{ pipe }}"