mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-19 05:54:11 +03:00
Merge branch 'master' into spacy.io
This commit is contained in:
commit
fc88337cfa
|
@ -7,16 +7,17 @@ Run `wikipedia_pretrain_kb.py`
|
||||||
* WikiData: get `latest-all.json.bz2` from https://dumps.wikimedia.org/wikidatawiki/entities/
|
* WikiData: get `latest-all.json.bz2` from https://dumps.wikimedia.org/wikidatawiki/entities/
|
||||||
* Wikipedia: get `enwiki-latest-pages-articles-multistream.xml.bz2` from https://dumps.wikimedia.org/enwiki/latest/ (or for any other language)
|
* Wikipedia: get `enwiki-latest-pages-articles-multistream.xml.bz2` from https://dumps.wikimedia.org/enwiki/latest/ (or for any other language)
|
||||||
* You can set the filtering parameters for KB construction:
|
* You can set the filtering parameters for KB construction:
|
||||||
* `max_per_alias`: (max) number of candidate entities in the KB per alias/synonym
|
* `max_per_alias` (`-a`): (max) number of candidate entities in the KB per alias/synonym
|
||||||
* `min_freq`: threshold of number of times an entity should occur in the corpus to be included in the KB
|
* `min_freq` (`-f`): threshold of number of times an entity should occur in the corpus to be included in the KB
|
||||||
* `min_pair`: threshold of number of times an entity+alias combination should occur in the corpus to be included in the KB
|
* `min_pair` (`-c`): threshold of number of times an entity+alias combination should occur in the corpus to be included in the KB
|
||||||
* Further parameters to set:
|
* Further parameters to set:
|
||||||
* `descriptions_from_wikipedia`: whether to parse descriptions from Wikipedia (`True`) or Wikidata (`False`)
|
* `descriptions_from_wikipedia` (`-wp`): whether to parse descriptions from Wikipedia (`True`) or Wikidata (`False`)
|
||||||
* `entity_vector_length`: length of the pre-trained entity description vectors
|
* `entity_vector_length` (`-v`): length of the pre-trained entity description vectors
|
||||||
* `lang`: language for which to fetch Wikidata information (as the dump contains all languages)
|
* `lang` (`-la`): language for which to fetch Wikidata information (as the dump contains all languages)
|
||||||
|
|
||||||
Quick testing and rerunning:
|
Quick testing and rerunning:
|
||||||
* When trying out the pipeline for a quick test, set `limit_prior`, `limit_train` and/or `limit_wd` to read only parts of the dumps instead of everything.
|
* When trying out the pipeline for a quick test, set `limit_prior` (`-lp`), `limit_train` (`-lt`) and/or `limit_wd` (`-lw`) to read only parts of the dumps instead of everything.
|
||||||
|
* e.g. set `-lt 20000 -lp 2000 -lw 3000 -f 1`
|
||||||
* If you only want to (re)run certain parts of the pipeline, just remove the corresponding files and they will be recalculated or reparsed.
|
* If you only want to (re)run certain parts of the pipeline, just remove the corresponding files and they will be recalculated or reparsed.
|
||||||
|
|
||||||
|
|
||||||
|
@ -24,11 +25,13 @@ Quick testing and rerunning:
|
||||||
|
|
||||||
Run `wikidata_train_entity_linker.py`
|
Run `wikidata_train_entity_linker.py`
|
||||||
* This takes the **KB directory** produced by Step 1, and trains an **Entity Linking model**
|
* This takes the **KB directory** produced by Step 1, and trains an **Entity Linking model**
|
||||||
|
* Specify the output directory (`-o`) in which the final, trained model will be saved
|
||||||
* You can set the learning parameters for the EL training:
|
* You can set the learning parameters for the EL training:
|
||||||
* `epochs`: number of training iterations
|
* `epochs` (`-e`): number of training iterations
|
||||||
* `dropout`: dropout rate
|
* `dropout` (`-p`): dropout rate
|
||||||
* `lr`: learning rate
|
* `lr` (`-n`): learning rate
|
||||||
* `l2`: L2 regularization
|
* `l2` (`-r`): L2 regularization
|
||||||
* Specify the number of training and dev testing entities with `train_inst` and `dev_inst` respectively
|
* Specify the number of training and dev testing articles with `train_articles` (`-t`) and `dev_articles` (`-d`) respectively
|
||||||
|
* If not specified, the full dataset will be processed - this may take a LONG time !
|
||||||
* Further parameters to set:
|
* Further parameters to set:
|
||||||
* `labels_discard`: NER label types to discard during training
|
* `labels_discard` (`-l`): NER label types to discard during training
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
|
# coding: utf-8
|
||||||
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
|
@ -92,133 +94,110 @@ class BaselineResults(object):
|
||||||
self.random.update_metrics(ent_label, true_entity, random_candidate)
|
self.random.update_metrics(ent_label, true_entity, random_candidate)
|
||||||
|
|
||||||
|
|
||||||
def measure_performance(dev_data, kb, el_pipe, baseline=True, context=True):
|
def measure_performance(dev_data, kb, el_pipe, baseline=True, context=True, dev_limit=None):
|
||||||
if baseline:
|
counts = dict()
|
||||||
baseline_accuracies, counts = measure_baselines(dev_data, kb)
|
baseline_results = BaselineResults()
|
||||||
logger.info("Counts: {}".format({k: v for k, v in sorted(counts.items())}))
|
context_results = EvaluationResults()
|
||||||
logger.info(baseline_accuracies.report_performance("random"))
|
combo_results = EvaluationResults()
|
||||||
logger.info(baseline_accuracies.report_performance("prior"))
|
|
||||||
logger.info(baseline_accuracies.report_performance("oracle"))
|
|
||||||
|
|
||||||
if context:
|
for doc, gold in tqdm(dev_data, total=dev_limit, leave=False, desc='Processing dev data'):
|
||||||
# using only context
|
if len(doc) > 0:
|
||||||
el_pipe.cfg["incl_context"] = True
|
correct_ents = dict()
|
||||||
el_pipe.cfg["incl_prior"] = False
|
|
||||||
results = get_eval_results(dev_data, el_pipe)
|
|
||||||
logger.info(results.report_metrics("context only"))
|
|
||||||
|
|
||||||
# measuring combined accuracy (prior + context)
|
|
||||||
el_pipe.cfg["incl_context"] = True
|
|
||||||
el_pipe.cfg["incl_prior"] = True
|
|
||||||
results = get_eval_results(dev_data, el_pipe)
|
|
||||||
logger.info(results.report_metrics("context and prior"))
|
|
||||||
|
|
||||||
|
|
||||||
def get_eval_results(data, el_pipe=None):
|
|
||||||
"""
|
|
||||||
Evaluate the ent.kb_id_ annotations against the gold standard.
|
|
||||||
Only evaluate entities that overlap between gold and NER, to isolate the performance of the NEL.
|
|
||||||
If the docs in the data require further processing with an entity linker, set el_pipe.
|
|
||||||
"""
|
|
||||||
docs = []
|
|
||||||
golds = []
|
|
||||||
for d, g in tqdm(data, leave=False):
|
|
||||||
if len(d) > 0:
|
|
||||||
golds.append(g)
|
|
||||||
if el_pipe is not None:
|
|
||||||
docs.append(el_pipe(d))
|
|
||||||
else:
|
|
||||||
docs.append(d)
|
|
||||||
|
|
||||||
results = EvaluationResults()
|
|
||||||
for doc, gold in zip(docs, golds):
|
|
||||||
try:
|
|
||||||
correct_entries_per_article = dict()
|
|
||||||
for entity, kb_dict in gold.links.items():
|
for entity, kb_dict in gold.links.items():
|
||||||
start, end = entity
|
start, end = entity
|
||||||
for gold_kb, value in kb_dict.items():
|
for gold_kb, value in kb_dict.items():
|
||||||
if value:
|
if value:
|
||||||
# only evaluating on positive examples
|
# only evaluating on positive examples
|
||||||
offset = _offset(start, end)
|
offset = _offset(start, end)
|
||||||
correct_entries_per_article[offset] = gold_kb
|
correct_ents[offset] = gold_kb
|
||||||
|
|
||||||
for ent in doc.ents:
|
if baseline:
|
||||||
ent_label = ent.label_
|
_add_baseline(baseline_results, counts, doc, correct_ents, kb)
|
||||||
pred_entity = ent.kb_id_
|
|
||||||
start = ent.start_char
|
|
||||||
end = ent.end_char
|
|
||||||
offset = _offset(start, end)
|
|
||||||
gold_entity = correct_entries_per_article.get(offset, None)
|
|
||||||
# the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
|
|
||||||
if gold_entity is not None:
|
|
||||||
results.update_metrics(ent_label, gold_entity, pred_entity)
|
|
||||||
|
|
||||||
except Exception as e:
|
if context:
|
||||||
logging.error("Error assessing accuracy " + str(e))
|
# using only context
|
||||||
|
el_pipe.cfg["incl_context"] = True
|
||||||
|
el_pipe.cfg["incl_prior"] = False
|
||||||
|
_add_eval_result(context_results, doc, correct_ents, el_pipe)
|
||||||
|
|
||||||
return results
|
# measuring combined accuracy (prior + context)
|
||||||
|
el_pipe.cfg["incl_context"] = True
|
||||||
|
el_pipe.cfg["incl_prior"] = True
|
||||||
|
_add_eval_result(combo_results, doc, correct_ents, el_pipe)
|
||||||
|
|
||||||
|
if baseline:
|
||||||
|
logger.info("Counts: {}".format({k: v for k, v in sorted(counts.items())}))
|
||||||
|
logger.info(baseline_results.report_performance("random"))
|
||||||
|
logger.info(baseline_results.report_performance("prior"))
|
||||||
|
logger.info(baseline_results.report_performance("oracle"))
|
||||||
|
|
||||||
|
if context:
|
||||||
|
logger.info(context_results.report_metrics("context only"))
|
||||||
|
logger.info(combo_results.report_metrics("context and prior"))
|
||||||
|
|
||||||
|
|
||||||
def measure_baselines(data, kb):
|
def _add_eval_result(results, doc, correct_ents, el_pipe):
|
||||||
"""
|
"""
|
||||||
Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound.
|
Evaluate the ent.kb_id_ annotations against the gold standard.
|
||||||
Only evaluate entities that overlap between gold and NER, to isolate the performance of the NEL.
|
Only evaluate entities that overlap between gold and NER, to isolate the performance of the NEL.
|
||||||
Also return a dictionary of counts by entity label.
|
|
||||||
"""
|
"""
|
||||||
counts_d = dict()
|
try:
|
||||||
|
doc = el_pipe(doc)
|
||||||
baseline_results = BaselineResults()
|
|
||||||
|
|
||||||
docs = [d for d, g in data if len(d) > 0]
|
|
||||||
golds = [g for d, g in data if len(d) > 0]
|
|
||||||
|
|
||||||
for doc, gold in zip(docs, golds):
|
|
||||||
correct_entries_per_article = dict()
|
|
||||||
for entity, kb_dict in gold.links.items():
|
|
||||||
start, end = entity
|
|
||||||
for gold_kb, value in kb_dict.items():
|
|
||||||
# only evaluating on positive examples
|
|
||||||
if value:
|
|
||||||
offset = _offset(start, end)
|
|
||||||
correct_entries_per_article[offset] = gold_kb
|
|
||||||
|
|
||||||
for ent in doc.ents:
|
for ent in doc.ents:
|
||||||
ent_label = ent.label_
|
ent_label = ent.label_
|
||||||
start = ent.start_char
|
start = ent.start_char
|
||||||
end = ent.end_char
|
end = ent.end_char
|
||||||
offset = _offset(start, end)
|
offset = _offset(start, end)
|
||||||
gold_entity = correct_entries_per_article.get(offset, None)
|
gold_entity = correct_ents.get(offset, None)
|
||||||
|
|
||||||
# the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
|
# the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
|
||||||
if gold_entity is not None:
|
if gold_entity is not None:
|
||||||
candidates = kb.get_candidates(ent.text)
|
pred_entity = ent.kb_id_
|
||||||
oracle_candidate = ""
|
results.update_metrics(ent_label, gold_entity, pred_entity)
|
||||||
prior_candidate = ""
|
|
||||||
random_candidate = ""
|
|
||||||
if candidates:
|
|
||||||
scores = []
|
|
||||||
|
|
||||||
for c in candidates:
|
except Exception as e:
|
||||||
scores.append(c.prior_prob)
|
logging.error("Error assessing accuracy " + str(e))
|
||||||
if c.entity_ == gold_entity:
|
|
||||||
oracle_candidate = c.entity_
|
|
||||||
|
|
||||||
best_index = scores.index(max(scores))
|
|
||||||
prior_candidate = candidates[best_index].entity_
|
|
||||||
random_candidate = random.choice(candidates).entity_
|
|
||||||
|
|
||||||
current_count = counts_d.get(ent_label, 0)
|
def _add_baseline(baseline_results, counts, doc, correct_ents, kb):
|
||||||
counts_d[ent_label] = current_count+1
|
"""
|
||||||
|
Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound.
|
||||||
|
Only evaluate entities that overlap between gold and NER, to isolate the performance of the NEL.
|
||||||
|
"""
|
||||||
|
for ent in doc.ents:
|
||||||
|
ent_label = ent.label_
|
||||||
|
start = ent.start_char
|
||||||
|
end = ent.end_char
|
||||||
|
offset = _offset(start, end)
|
||||||
|
gold_entity = correct_ents.get(offset, None)
|
||||||
|
|
||||||
baseline_results.update_baselines(
|
# the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
|
||||||
gold_entity,
|
if gold_entity is not None:
|
||||||
ent_label,
|
candidates = kb.get_candidates(ent.text)
|
||||||
random_candidate,
|
oracle_candidate = ""
|
||||||
prior_candidate,
|
prior_candidate = ""
|
||||||
oracle_candidate,
|
random_candidate = ""
|
||||||
)
|
if candidates:
|
||||||
|
scores = []
|
||||||
|
|
||||||
return baseline_results, counts_d
|
for c in candidates:
|
||||||
|
scores.append(c.prior_prob)
|
||||||
|
if c.entity_ == gold_entity:
|
||||||
|
oracle_candidate = c.entity_
|
||||||
|
|
||||||
|
best_index = scores.index(max(scores))
|
||||||
|
prior_candidate = candidates[best_index].entity_
|
||||||
|
random_candidate = random.choice(candidates).entity_
|
||||||
|
|
||||||
|
current_count = counts.get(ent_label, 0)
|
||||||
|
counts[ent_label] = current_count+1
|
||||||
|
|
||||||
|
baseline_results.update_baselines(
|
||||||
|
gold_entity,
|
||||||
|
ent_label,
|
||||||
|
random_candidate,
|
||||||
|
prior_candidate,
|
||||||
|
oracle_candidate,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _offset(start, end):
|
def _offset(start, end):
|
||||||
|
|
|
@ -40,7 +40,7 @@ logger = logging.getLogger(__name__)
|
||||||
loc_prior_prob=("Location to file with prior probabilities", "option", "p", Path),
|
loc_prior_prob=("Location to file with prior probabilities", "option", "p", Path),
|
||||||
loc_entity_defs=("Location to file with entity definitions", "option", "d", Path),
|
loc_entity_defs=("Location to file with entity definitions", "option", "d", Path),
|
||||||
loc_entity_desc=("Location to file with entity descriptions", "option", "s", Path),
|
loc_entity_desc=("Location to file with entity descriptions", "option", "s", Path),
|
||||||
descr_from_wp=("Flag for using wp descriptions not wd", "flag", "wp"),
|
descr_from_wp=("Flag for using descriptions from WP instead of WD (default False)", "flag", "wp"),
|
||||||
limit_prior=("Threshold to limit lines read from WP for prior probabilities", "option", "lp", int),
|
limit_prior=("Threshold to limit lines read from WP for prior probabilities", "option", "lp", int),
|
||||||
limit_train=("Threshold to limit lines read from WP for training set", "option", "lt", int),
|
limit_train=("Threshold to limit lines read from WP for training set", "option", "lt", int),
|
||||||
limit_wd=("Threshold to limit lines read from WD", "option", "lw", int),
|
limit_wd=("Threshold to limit lines read from WD", "option", "lw", int),
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
# coding: utf-8
|
# coding: utf-8
|
||||||
"""Script to take a previously created Knowledge Base and train an entity linking
|
"""Script that takes a previously created Knowledge Base and trains an entity linking
|
||||||
pipeline. The provided KB directory should hold the kb, the original nlp object and
|
pipeline. The provided KB directory should hold the kb, the original nlp object and
|
||||||
its vocab used to create the KB, and a few auxiliary files such as the entity definitions,
|
its vocab used to create the KB, and a few auxiliary files such as the entity definitions,
|
||||||
as created by the script `wikidata_create_kb`.
|
as created by the script `wikidata_create_kb`.
|
||||||
|
@ -14,6 +14,7 @@ import logging
|
||||||
import spacy
|
import spacy
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import plac
|
import plac
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
from bin.wiki_entity_linking import wikipedia_processor
|
from bin.wiki_entity_linking import wikipedia_processor
|
||||||
from bin.wiki_entity_linking import TRAINING_DATA_FILE, KB_MODEL_DIR, KB_FILE, LOG_FORMAT, OUTPUT_MODEL_DIR
|
from bin.wiki_entity_linking import TRAINING_DATA_FILE, KB_MODEL_DIR, KB_FILE, LOG_FORMAT, OUTPUT_MODEL_DIR
|
||||||
|
@ -33,8 +34,8 @@ logger = logging.getLogger(__name__)
|
||||||
dropout=("Dropout to prevent overfitting (default 0.5)", "option", "p", float),
|
dropout=("Dropout to prevent overfitting (default 0.5)", "option", "p", float),
|
||||||
lr=("Learning rate (default 0.005)", "option", "n", float),
|
lr=("Learning rate (default 0.005)", "option", "n", float),
|
||||||
l2=("L2 regularization", "option", "r", float),
|
l2=("L2 regularization", "option", "r", float),
|
||||||
train_inst=("# training instances (default 90% of all)", "option", "t", int),
|
train_articles=("# training articles (default 90% of all)", "option", "t", int),
|
||||||
dev_inst=("# test instances (default 10% of all)", "option", "d", int),
|
dev_articles=("# dev test articles (default 10% of all)", "option", "d", int),
|
||||||
labels_discard=("NER labels to discard (default None)", "option", "l", str),
|
labels_discard=("NER labels to discard (default None)", "option", "l", str),
|
||||||
)
|
)
|
||||||
def main(
|
def main(
|
||||||
|
@ -45,10 +46,13 @@ def main(
|
||||||
dropout=0.5,
|
dropout=0.5,
|
||||||
lr=0.005,
|
lr=0.005,
|
||||||
l2=1e-6,
|
l2=1e-6,
|
||||||
train_inst=None,
|
train_articles=None,
|
||||||
dev_inst=None,
|
dev_articles=None,
|
||||||
labels_discard=None
|
labels_discard=None
|
||||||
):
|
):
|
||||||
|
if not output_dir:
|
||||||
|
logger.warning("No output dir specified so no results will be written, are you sure about this ?")
|
||||||
|
|
||||||
logger.info("Creating Entity Linker with Wikipedia and WikiData")
|
logger.info("Creating Entity Linker with Wikipedia and WikiData")
|
||||||
|
|
||||||
output_dir = Path(output_dir) if output_dir else dir_kb
|
output_dir = Path(output_dir) if output_dir else dir_kb
|
||||||
|
@ -64,44 +68,33 @@ def main(
|
||||||
# STEP 1 : load the NLP object
|
# STEP 1 : load the NLP object
|
||||||
logger.info("STEP 1a: Loading model from {}".format(nlp_dir))
|
logger.info("STEP 1a: Loading model from {}".format(nlp_dir))
|
||||||
nlp = spacy.load(nlp_dir)
|
nlp = spacy.load(nlp_dir)
|
||||||
logger.info("STEP 1b: Loading KB from {}".format(kb_path))
|
logger.info("Original NLP pipeline has following pipeline components: {}".format(nlp.pipe_names))
|
||||||
kb = read_kb(nlp, kb_path)
|
|
||||||
|
|
||||||
# check that there is a NER component in the pipeline
|
# check that there is a NER component in the pipeline
|
||||||
if "ner" not in nlp.pipe_names:
|
if "ner" not in nlp.pipe_names:
|
||||||
raise ValueError("The `nlp` object should have a pretrained `ner` component.")
|
raise ValueError("The `nlp` object should have a pretrained `ner` component.")
|
||||||
|
|
||||||
# STEP 2: read the training dataset previously created from WP
|
logger.info("STEP 1b: Loading KB from {}".format(kb_path))
|
||||||
logger.info("STEP 2: Reading training dataset from {}".format(training_path))
|
kb = read_kb(nlp, kb_path)
|
||||||
|
|
||||||
|
# STEP 2: read the training dataset previously created from WP
|
||||||
|
logger.info("STEP 2: Reading training & dev dataset from {}".format(training_path))
|
||||||
|
train_indices, dev_indices = wikipedia_processor.read_training_indices(training_path)
|
||||||
|
logger.info("Training set has {} articles, limit set to roughly {} articles per epoch"
|
||||||
|
.format(len(train_indices), train_articles if train_articles else "all"))
|
||||||
|
logger.info("Dev set has {} articles, limit set to rougly {} articles for evaluation"
|
||||||
|
.format(len(dev_indices), dev_articles if dev_articles else "all"))
|
||||||
|
if dev_articles:
|
||||||
|
dev_indices = dev_indices[0:dev_articles]
|
||||||
|
|
||||||
|
# STEP 3: create and train an entity linking pipe
|
||||||
|
logger.info("STEP 3: Creating and training an Entity Linking pipe for {} epochs".format(epochs))
|
||||||
if labels_discard:
|
if labels_discard:
|
||||||
labels_discard = [x.strip() for x in labels_discard.split(",")]
|
labels_discard = [x.strip() for x in labels_discard.split(",")]
|
||||||
logger.info("Discarding {} NER types: {}".format(len(labels_discard), labels_discard))
|
logger.info("Discarding {} NER types: {}".format(len(labels_discard), labels_discard))
|
||||||
else:
|
else:
|
||||||
labels_discard = []
|
labels_discard = []
|
||||||
|
|
||||||
train_data = wikipedia_processor.read_training(
|
|
||||||
nlp=nlp,
|
|
||||||
entity_file_path=training_path,
|
|
||||||
dev=False,
|
|
||||||
limit=train_inst,
|
|
||||||
kb=kb,
|
|
||||||
labels_discard=labels_discard
|
|
||||||
)
|
|
||||||
|
|
||||||
# for testing, get all pos instances (independently of KB)
|
|
||||||
dev_data = wikipedia_processor.read_training(
|
|
||||||
nlp=nlp,
|
|
||||||
entity_file_path=training_path,
|
|
||||||
dev=True,
|
|
||||||
limit=dev_inst,
|
|
||||||
kb=None,
|
|
||||||
labels_discard=labels_discard
|
|
||||||
)
|
|
||||||
|
|
||||||
# STEP 3: create and train an entity linking pipe
|
|
||||||
logger.info("STEP 3: Creating and training an Entity Linking pipe")
|
|
||||||
|
|
||||||
el_pipe = nlp.create_pipe(
|
el_pipe = nlp.create_pipe(
|
||||||
name="entity_linker", config={"pretrained_vectors": nlp.vocab.vectors.name,
|
name="entity_linker", config={"pretrained_vectors": nlp.vocab.vectors.name,
|
||||||
"labels_discard": labels_discard}
|
"labels_discard": labels_discard}
|
||||||
|
@ -115,80 +108,65 @@ def main(
|
||||||
optimizer.learn_rate = lr
|
optimizer.learn_rate = lr
|
||||||
optimizer.L2 = l2
|
optimizer.L2 = l2
|
||||||
|
|
||||||
logger.info("Training on {} articles".format(len(train_data)))
|
|
||||||
logger.info("Dev testing on {} articles".format(len(dev_data)))
|
|
||||||
|
|
||||||
# baseline performance on dev data
|
|
||||||
logger.info("Dev Baseline Accuracies:")
|
logger.info("Dev Baseline Accuracies:")
|
||||||
measure_performance(dev_data, kb, el_pipe, baseline=True, context=False)
|
dev_data = wikipedia_processor.read_el_docs_golds(nlp=nlp, entity_file_path=training_path,
|
||||||
|
dev=True, line_ids=dev_indices,
|
||||||
|
kb=kb, labels_discard=labels_discard)
|
||||||
|
|
||||||
|
measure_performance(dev_data, kb, el_pipe, baseline=True, context=False, dev_limit=len(dev_indices))
|
||||||
|
|
||||||
for itn in range(epochs):
|
for itn in range(epochs):
|
||||||
random.shuffle(train_data)
|
random.shuffle(train_indices)
|
||||||
losses = {}
|
losses = {}
|
||||||
batches = minibatch(train_data, size=compounding(4.0, 128.0, 1.001))
|
batches = minibatch(train_indices, size=compounding(8.0, 128.0, 1.001))
|
||||||
batchnr = 0
|
batchnr = 0
|
||||||
|
articles_processed = 0
|
||||||
|
|
||||||
with nlp.disable_pipes(*other_pipes):
|
# we either process the whole training file, or just a part each epoch
|
||||||
|
bar_total = len(train_indices)
|
||||||
|
if train_articles:
|
||||||
|
bar_total = train_articles
|
||||||
|
|
||||||
|
with tqdm(total=bar_total, leave=False, desc='Epoch ' + str(itn)) as pbar:
|
||||||
for batch in batches:
|
for batch in batches:
|
||||||
try:
|
if not train_articles or articles_processed < train_articles:
|
||||||
docs, golds = zip(*batch)
|
with nlp.disable_pipes("entity_linker"):
|
||||||
nlp.update(
|
train_batch = wikipedia_processor.read_el_docs_golds(nlp=nlp, entity_file_path=training_path,
|
||||||
docs=docs,
|
dev=False, line_ids=batch,
|
||||||
golds=golds,
|
kb=kb, labels_discard=labels_discard)
|
||||||
sgd=optimizer,
|
docs, golds = zip(*train_batch)
|
||||||
drop=dropout,
|
try:
|
||||||
losses=losses,
|
with nlp.disable_pipes(*other_pipes):
|
||||||
)
|
nlp.update(
|
||||||
batchnr += 1
|
docs=docs,
|
||||||
except Exception as e:
|
golds=golds,
|
||||||
logger.error("Error updating batch:" + str(e))
|
sgd=optimizer,
|
||||||
|
drop=dropout,
|
||||||
|
losses=losses,
|
||||||
|
)
|
||||||
|
batchnr += 1
|
||||||
|
articles_processed += len(docs)
|
||||||
|
pbar.update(len(docs))
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error updating batch:" + str(e))
|
||||||
if batchnr > 0:
|
if batchnr > 0:
|
||||||
logging.info("Epoch {}, train loss {}".format(itn, round(losses["entity_linker"] / batchnr, 2)))
|
logging.info("Epoch {} trained on {} articles, train loss {}"
|
||||||
measure_performance(dev_data, kb, el_pipe, baseline=False, context=True)
|
.format(itn, articles_processed, round(losses["entity_linker"] / batchnr, 2)))
|
||||||
|
# re-read the dev_data (data is returned as a generator)
|
||||||
# STEP 4: measure the performance of our trained pipe on an independent dev set
|
dev_data = wikipedia_processor.read_el_docs_golds(nlp=nlp, entity_file_path=training_path,
|
||||||
logger.info("STEP 4: Final performance measurement of Entity Linking pipe")
|
dev=True, line_ids=dev_indices,
|
||||||
measure_performance(dev_data, kb, el_pipe)
|
kb=kb, labels_discard=labels_discard)
|
||||||
|
measure_performance(dev_data, kb, el_pipe, baseline=False, context=True, dev_limit=len(dev_indices))
|
||||||
# STEP 5: apply the EL pipe on a toy example
|
|
||||||
logger.info("STEP 5: Applying Entity Linking to toy example")
|
|
||||||
run_el_toy_example(nlp=nlp)
|
|
||||||
|
|
||||||
if output_dir:
|
if output_dir:
|
||||||
# STEP 6: write the NLP pipeline (now including an EL model) to file
|
# STEP 4: write the NLP pipeline (now including an EL model) to file
|
||||||
logger.info("STEP 6: Writing trained NLP to {}".format(nlp_output_dir))
|
logger.info("Final NLP pipeline has following pipeline components: {}".format(nlp.pipe_names))
|
||||||
|
logger.info("STEP 4: Writing trained NLP to {}".format(nlp_output_dir))
|
||||||
nlp.to_disk(nlp_output_dir)
|
nlp.to_disk(nlp_output_dir)
|
||||||
|
|
||||||
logger.info("Done!")
|
logger.info("Done!")
|
||||||
|
|
||||||
|
|
||||||
def check_kb(kb):
|
|
||||||
for mention in ("Bush", "Douglas Adams", "Homer", "Brazil", "China"):
|
|
||||||
candidates = kb.get_candidates(mention)
|
|
||||||
|
|
||||||
logger.info("generating candidates for " + mention + " :")
|
|
||||||
for c in candidates:
|
|
||||||
logger.info(" ".join[
|
|
||||||
str(c.prior_prob),
|
|
||||||
c.alias_,
|
|
||||||
"-->",
|
|
||||||
c.entity_ + " (freq=" + str(c.entity_freq) + ")"
|
|
||||||
])
|
|
||||||
|
|
||||||
|
|
||||||
def run_el_toy_example(nlp):
|
|
||||||
text = (
|
|
||||||
"In The Hitchhiker's Guide to the Galaxy, written by Douglas Adams, "
|
|
||||||
"Douglas reminds us to always bring our towel, even in China or Brazil. "
|
|
||||||
"The main character in Doug's novel is the man Arthur Dent, "
|
|
||||||
"but Dougledydoug doesn't write about George Washington or Homer Simpson."
|
|
||||||
)
|
|
||||||
doc = nlp(text)
|
|
||||||
logger.info(text)
|
|
||||||
for ent in doc.ents:
|
|
||||||
logger.info(" ".join(["ent", ent.text, ent.label_, ent.kb_id_]))
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
logging.basicConfig(level=logging.INFO, format=LOG_FORMAT)
|
logging.basicConfig(level=logging.INFO, format=LOG_FORMAT)
|
||||||
plac.call(main)
|
plac.call(main)
|
||||||
|
|
|
@ -6,9 +6,6 @@ import bz2
|
||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
import json
|
import json
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
from functools import partial
|
|
||||||
|
|
||||||
from spacy.gold import GoldParse
|
from spacy.gold import GoldParse
|
||||||
from bin.wiki_entity_linking import wiki_io as io
|
from bin.wiki_entity_linking import wiki_io as io
|
||||||
|
@ -454,25 +451,40 @@ def _write_training_entities(outputfile, article_id, clean_text, entities):
|
||||||
outputfile.write(line)
|
outputfile.write(line)
|
||||||
|
|
||||||
|
|
||||||
def read_training(nlp, entity_file_path, dev, limit, kb, labels_discard=None):
|
def read_training_indices(entity_file_path):
|
||||||
""" This method provides training examples that correspond to the entity annotations found by the nlp object.
|
""" This method creates two lists of indices into the training file: one with indices for the
|
||||||
|
training examples, and one for the dev examples."""
|
||||||
|
train_indices = []
|
||||||
|
dev_indices = []
|
||||||
|
|
||||||
|
with entity_file_path.open("r", encoding="utf8") as file:
|
||||||
|
for i, line in enumerate(file):
|
||||||
|
example = json.loads(line)
|
||||||
|
article_id = example["article_id"]
|
||||||
|
clean_text = example["clean_text"]
|
||||||
|
|
||||||
|
if is_valid_article(clean_text):
|
||||||
|
if is_dev(article_id):
|
||||||
|
dev_indices.append(i)
|
||||||
|
else:
|
||||||
|
train_indices.append(i)
|
||||||
|
|
||||||
|
return train_indices, dev_indices
|
||||||
|
|
||||||
|
|
||||||
|
def read_el_docs_golds(nlp, entity_file_path, dev, line_ids, kb, labels_discard=None):
|
||||||
|
""" This method provides training/dev examples that correspond to the entity annotations found by the nlp object.
|
||||||
For training, it will include both positive and negative examples by using the candidate generator from the kb.
|
For training, it will include both positive and negative examples by using the candidate generator from the kb.
|
||||||
For testing (kb=None), it will include all positive examples only."""
|
For testing (kb=None), it will include all positive examples only."""
|
||||||
if not labels_discard:
|
if not labels_discard:
|
||||||
labels_discard = []
|
labels_discard = []
|
||||||
|
|
||||||
data = []
|
texts = []
|
||||||
num_entities = 0
|
entities_list = []
|
||||||
get_gold_parse = partial(
|
|
||||||
_get_gold_parse, dev=dev, kb=kb, labels_discard=labels_discard
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"Reading {} data with limit {}".format("dev" if dev else "train", limit)
|
|
||||||
)
|
|
||||||
with entity_file_path.open("r", encoding="utf8") as file:
|
with entity_file_path.open("r", encoding="utf8") as file:
|
||||||
with tqdm(total=limit, leave=False) as pbar:
|
for i, line in enumerate(file):
|
||||||
for i, line in enumerate(file):
|
if i in line_ids:
|
||||||
example = json.loads(line)
|
example = json.loads(line)
|
||||||
article_id = example["article_id"]
|
article_id = example["article_id"]
|
||||||
clean_text = example["clean_text"]
|
clean_text = example["clean_text"]
|
||||||
|
@ -481,16 +493,15 @@ def read_training(nlp, entity_file_path, dev, limit, kb, labels_discard=None):
|
||||||
if dev != is_dev(article_id) or not is_valid_article(clean_text):
|
if dev != is_dev(article_id) or not is_valid_article(clean_text):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
doc = nlp(clean_text)
|
texts.append(clean_text)
|
||||||
gold = get_gold_parse(doc, entities)
|
entities_list.append(entities)
|
||||||
if gold and len(gold.links) > 0:
|
|
||||||
data.append((doc, gold))
|
docs = nlp.pipe(texts, batch_size=50)
|
||||||
num_entities += len(gold.links)
|
|
||||||
pbar.update(len(gold.links))
|
for doc, entities in zip(docs, entities_list):
|
||||||
if limit and num_entities >= limit:
|
gold = _get_gold_parse(doc, entities, dev=dev, kb=kb, labels_discard=labels_discard)
|
||||||
break
|
if gold and len(gold.links) > 0:
|
||||||
logger.info("Read {} entities in {} articles".format(num_entities, len(data)))
|
yield doc, gold
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
def _get_gold_parse(doc, entities, dev, kb, labels_discard):
|
def _get_gold_parse(doc, entities, dev, kb, labels_discard):
|
||||||
|
|
|
@ -32,27 +32,24 @@ DESC_WIDTH = 64 # dimension of output entity vectors
|
||||||
|
|
||||||
|
|
||||||
@plac.annotations(
|
@plac.annotations(
|
||||||
vocab_path=("Path to the vocab for the kb", "option", "v", Path),
|
model=("Model name, should have pretrained word embeddings", "positional", None, str),
|
||||||
model=("Model name, should have pretrained word embeddings", "option", "m", str),
|
|
||||||
output_dir=("Optional output directory", "option", "o", Path),
|
output_dir=("Optional output directory", "option", "o", Path),
|
||||||
n_iter=("Number of training iterations", "option", "n", int),
|
n_iter=("Number of training iterations", "option", "n", int),
|
||||||
)
|
)
|
||||||
def main(vocab_path=None, model=None, output_dir=None, n_iter=50):
|
def main(model=None, output_dir=None, n_iter=50):
|
||||||
"""Load the model, create the KB and pretrain the entity encodings.
|
"""Load the model, create the KB and pretrain the entity encodings.
|
||||||
Either an nlp model or a vocab is needed to provide access to pretrained word embeddings.
|
|
||||||
If an output_dir is provided, the KB will be stored there in a file 'kb'.
|
If an output_dir is provided, the KB will be stored there in a file 'kb'.
|
||||||
When providing an nlp model, the updated vocab will also be written to a directory in the output_dir."""
|
The updated vocab will also be written to a directory in the output_dir."""
|
||||||
if model is None and vocab_path is None:
|
|
||||||
raise ValueError("Either the `nlp` model or the `vocab` should be specified.")
|
|
||||||
|
|
||||||
if model is not None:
|
nlp = spacy.load(model) # load existing spaCy model
|
||||||
nlp = spacy.load(model) # load existing spaCy model
|
print("Loaded model '%s'" % model)
|
||||||
print("Loaded model '%s'" % model)
|
|
||||||
else:
|
# check the length of the nlp vectors
|
||||||
vocab = Vocab().from_disk(vocab_path)
|
if "vectors" not in nlp.meta or not nlp.vocab.vectors.size:
|
||||||
# create blank Language class with specified vocab
|
raise ValueError(
|
||||||
nlp = spacy.blank("en", vocab=vocab)
|
"The `nlp` object should have access to pretrained word vectors, "
|
||||||
print("Created blank 'en' model with vocab from '%s'" % vocab_path)
|
" cf. https://spacy.io/usage/models#languages."
|
||||||
|
)
|
||||||
|
|
||||||
kb = KnowledgeBase(vocab=nlp.vocab)
|
kb = KnowledgeBase(vocab=nlp.vocab)
|
||||||
|
|
||||||
|
@ -103,11 +100,9 @@ def main(vocab_path=None, model=None, output_dir=None, n_iter=50):
|
||||||
print()
|
print()
|
||||||
print("Saved KB to", kb_path)
|
print("Saved KB to", kb_path)
|
||||||
|
|
||||||
# only storing the vocab if we weren't already reading it from file
|
vocab_path = output_dir / "vocab"
|
||||||
if not vocab_path:
|
kb.vocab.to_disk(vocab_path)
|
||||||
vocab_path = output_dir / "vocab"
|
print("Saved vocab to", vocab_path)
|
||||||
kb.vocab.to_disk(vocab_path)
|
|
||||||
print("Saved vocab to", vocab_path)
|
|
||||||
|
|
||||||
print()
|
print()
|
||||||
|
|
||||||
|
|
|
@ -131,7 +131,8 @@ def train_textcat(nlp, n_texts, n_iter=10):
|
||||||
train_data = list(zip(train_texts, [{"cats": cats} for cats in train_cats]))
|
train_data = list(zip(train_texts, [{"cats": cats} for cats in train_cats]))
|
||||||
|
|
||||||
# get names of other pipes to disable them during training
|
# get names of other pipes to disable them during training
|
||||||
other_pipes = [pipe for pipe in nlp.pipe_names if pipe != "textcat"]
|
pipe_exceptions = ["textcat", "trf_wordpiecer", "trf_tok2vec"]
|
||||||
|
other_pipes = [pipe for pipe in nlp.pipe_names if pipe not in pipe_exceptions]
|
||||||
with nlp.disable_pipes(*other_pipes): # only train textcat
|
with nlp.disable_pipes(*other_pipes): # only train textcat
|
||||||
optimizer = nlp.begin_training()
|
optimizer = nlp.begin_training()
|
||||||
textcat.model.tok2vec.from_bytes(tok2vec_weights)
|
textcat.model.tok2vec.from_bytes(tok2vec_weights)
|
||||||
|
|
|
@ -63,7 +63,8 @@ def main(model_name, unlabelled_loc):
|
||||||
optimizer.b2 = 0.0
|
optimizer.b2 = 0.0
|
||||||
|
|
||||||
# get names of other pipes to disable them during training
|
# get names of other pipes to disable them during training
|
||||||
other_pipes = [pipe for pipe in nlp.pipe_names if pipe != "ner"]
|
pipe_exceptions = ["ner", "trf_wordpiecer", "trf_tok2vec"]
|
||||||
|
other_pipes = [pipe for pipe in nlp.pipe_names if pipe not in pipe_exceptions]
|
||||||
sizes = compounding(1.0, 4.0, 1.001)
|
sizes = compounding(1.0, 4.0, 1.001)
|
||||||
with nlp.disable_pipes(*other_pipes):
|
with nlp.disable_pipes(*other_pipes):
|
||||||
for itn in range(n_iter):
|
for itn in range(n_iter):
|
||||||
|
|
|
@ -113,7 +113,8 @@ def main(kb_path, vocab_path=None, output_dir=None, n_iter=50):
|
||||||
TRAIN_DOCS.append((doc, annotation_clean))
|
TRAIN_DOCS.append((doc, annotation_clean))
|
||||||
|
|
||||||
# get names of other pipes to disable them during training
|
# get names of other pipes to disable them during training
|
||||||
other_pipes = [pipe for pipe in nlp.pipe_names if pipe != "entity_linker"]
|
pipe_exceptions = ["entity_linker", "trf_wordpiecer", "trf_tok2vec"]
|
||||||
|
other_pipes = [pipe for pipe in nlp.pipe_names if pipe not in pipe_exceptions]
|
||||||
with nlp.disable_pipes(*other_pipes): # only train entity linker
|
with nlp.disable_pipes(*other_pipes): # only train entity linker
|
||||||
# reset and initialize the weights randomly
|
# reset and initialize the weights randomly
|
||||||
optimizer = nlp.begin_training()
|
optimizer = nlp.begin_training()
|
||||||
|
|
|
@ -124,7 +124,8 @@ def main(model=None, output_dir=None, n_iter=15):
|
||||||
for dep in annotations.get("deps", []):
|
for dep in annotations.get("deps", []):
|
||||||
parser.add_label(dep)
|
parser.add_label(dep)
|
||||||
|
|
||||||
other_pipes = [pipe for pipe in nlp.pipe_names if pipe != "parser"]
|
pipe_exceptions = ["parser", "trf_wordpiecer", "trf_tok2vec"]
|
||||||
|
other_pipes = [pipe for pipe in nlp.pipe_names if pipe not in pipe_exceptions]
|
||||||
with nlp.disable_pipes(*other_pipes): # only train parser
|
with nlp.disable_pipes(*other_pipes): # only train parser
|
||||||
optimizer = nlp.begin_training()
|
optimizer = nlp.begin_training()
|
||||||
for itn in range(n_iter):
|
for itn in range(n_iter):
|
||||||
|
|
|
@ -55,7 +55,8 @@ def main(model=None, output_dir=None, n_iter=100):
|
||||||
ner.add_label(ent[2])
|
ner.add_label(ent[2])
|
||||||
|
|
||||||
# get names of other pipes to disable them during training
|
# get names of other pipes to disable them during training
|
||||||
other_pipes = [pipe for pipe in nlp.pipe_names if pipe != "ner"]
|
pipe_exceptions = ["ner", "trf_wordpiecer", "trf_tok2vec"]
|
||||||
|
other_pipes = [pipe for pipe in nlp.pipe_names if pipe not in pipe_exceptions]
|
||||||
with nlp.disable_pipes(*other_pipes): # only train NER
|
with nlp.disable_pipes(*other_pipes): # only train NER
|
||||||
# reset and initialize the weights randomly – but only if we're
|
# reset and initialize the weights randomly – but only if we're
|
||||||
# training a new model
|
# training a new model
|
||||||
|
|
|
@ -95,7 +95,8 @@ def main(model=None, new_model_name="animal", output_dir=None, n_iter=30):
|
||||||
optimizer = nlp.resume_training()
|
optimizer = nlp.resume_training()
|
||||||
move_names = list(ner.move_names)
|
move_names = list(ner.move_names)
|
||||||
# get names of other pipes to disable them during training
|
# get names of other pipes to disable them during training
|
||||||
other_pipes = [pipe for pipe in nlp.pipe_names if pipe != "ner"]
|
pipe_exceptions = ["ner", "trf_wordpiecer", "trf_tok2vec"]
|
||||||
|
other_pipes = [pipe for pipe in nlp.pipe_names if pipe not in pipe_exceptions]
|
||||||
with nlp.disable_pipes(*other_pipes): # only train NER
|
with nlp.disable_pipes(*other_pipes): # only train NER
|
||||||
sizes = compounding(1.0, 4.0, 1.001)
|
sizes = compounding(1.0, 4.0, 1.001)
|
||||||
# batch up the examples using spaCy's minibatch
|
# batch up the examples using spaCy's minibatch
|
||||||
|
|
|
@ -65,7 +65,8 @@ def main(model=None, output_dir=None, n_iter=15):
|
||||||
parser.add_label(dep)
|
parser.add_label(dep)
|
||||||
|
|
||||||
# get names of other pipes to disable them during training
|
# get names of other pipes to disable them during training
|
||||||
other_pipes = [pipe for pipe in nlp.pipe_names if pipe != "parser"]
|
pipe_exceptions = ["parser", "trf_wordpiecer", "trf_tok2vec"]
|
||||||
|
other_pipes = [pipe for pipe in nlp.pipe_names if pipe not in pipe_exceptions]
|
||||||
with nlp.disable_pipes(*other_pipes): # only train parser
|
with nlp.disable_pipes(*other_pipes): # only train parser
|
||||||
optimizer = nlp.begin_training()
|
optimizer = nlp.begin_training()
|
||||||
for itn in range(n_iter):
|
for itn in range(n_iter):
|
||||||
|
|
|
@ -67,7 +67,8 @@ def main(model=None, output_dir=None, n_iter=20, n_texts=2000, init_tok2vec=None
|
||||||
train_data = list(zip(train_texts, [{"cats": cats} for cats in train_cats]))
|
train_data = list(zip(train_texts, [{"cats": cats} for cats in train_cats]))
|
||||||
|
|
||||||
# get names of other pipes to disable them during training
|
# get names of other pipes to disable them during training
|
||||||
other_pipes = [pipe for pipe in nlp.pipe_names if pipe != "textcat"]
|
pipe_exceptions = ["textcat", "trf_wordpiecer", "trf_tok2vec"]
|
||||||
|
other_pipes = [pipe for pipe in nlp.pipe_names if pipe not in pipe_exceptions]
|
||||||
with nlp.disable_pipes(*other_pipes): # only train textcat
|
with nlp.disable_pipes(*other_pipes): # only train textcat
|
||||||
optimizer = nlp.begin_training()
|
optimizer = nlp.begin_training()
|
||||||
if init_tok2vec is not None:
|
if init_tok2vec is not None:
|
||||||
|
|
|
@ -91,3 +91,4 @@ cdef enum attr_id_t:
|
||||||
|
|
||||||
LANG
|
LANG
|
||||||
ENT_KB_ID = symbols.ENT_KB_ID
|
ENT_KB_ID = symbols.ENT_KB_ID
|
||||||
|
ENT_ID = symbols.ENT_ID
|
||||||
|
|
|
@ -84,6 +84,7 @@ IDS = {
|
||||||
"DEP": DEP,
|
"DEP": DEP,
|
||||||
"ENT_IOB": ENT_IOB,
|
"ENT_IOB": ENT_IOB,
|
||||||
"ENT_TYPE": ENT_TYPE,
|
"ENT_TYPE": ENT_TYPE,
|
||||||
|
"ENT_ID": ENT_ID,
|
||||||
"ENT_KB_ID": ENT_KB_ID,
|
"ENT_KB_ID": ENT_KB_ID,
|
||||||
"HEAD": HEAD,
|
"HEAD": HEAD,
|
||||||
"SENT_START": SENT_START,
|
"SENT_START": SENT_START,
|
||||||
|
|
|
@ -192,6 +192,7 @@ def debug_data(
|
||||||
has_low_data_warning = False
|
has_low_data_warning = False
|
||||||
has_no_neg_warning = False
|
has_no_neg_warning = False
|
||||||
has_ws_ents_error = False
|
has_ws_ents_error = False
|
||||||
|
has_punct_ents_warning = False
|
||||||
|
|
||||||
msg.divider("Named Entity Recognition")
|
msg.divider("Named Entity Recognition")
|
||||||
msg.info(
|
msg.info(
|
||||||
|
@ -226,10 +227,16 @@ def debug_data(
|
||||||
|
|
||||||
if gold_train_data["ws_ents"]:
|
if gold_train_data["ws_ents"]:
|
||||||
msg.fail(
|
msg.fail(
|
||||||
"{} invalid whitespace entity spans".format(gold_train_data["ws_ents"])
|
"{} invalid whitespace entity span(s)".format(gold_train_data["ws_ents"])
|
||||||
)
|
)
|
||||||
has_ws_ents_error = True
|
has_ws_ents_error = True
|
||||||
|
|
||||||
|
if gold_train_data["punct_ents"]:
|
||||||
|
msg.warn(
|
||||||
|
"{} entity span(s) with punctuation".format(gold_train_data["punct_ents"])
|
||||||
|
)
|
||||||
|
has_punct_ents_warning = True
|
||||||
|
|
||||||
for label in new_labels:
|
for label in new_labels:
|
||||||
if label_counts[label] <= NEW_LABEL_THRESHOLD:
|
if label_counts[label] <= NEW_LABEL_THRESHOLD:
|
||||||
msg.warn(
|
msg.warn(
|
||||||
|
@ -253,6 +260,8 @@ def debug_data(
|
||||||
msg.good("Examples without occurrences available for all labels")
|
msg.good("Examples without occurrences available for all labels")
|
||||||
if not has_ws_ents_error:
|
if not has_ws_ents_error:
|
||||||
msg.good("No entities consisting of or starting/ending with whitespace")
|
msg.good("No entities consisting of or starting/ending with whitespace")
|
||||||
|
if not has_punct_ents_warning:
|
||||||
|
msg.good("No entities consisting of or starting/ending with punctuation")
|
||||||
|
|
||||||
if has_low_data_warning:
|
if has_low_data_warning:
|
||||||
msg.text(
|
msg.text(
|
||||||
|
@ -273,6 +282,12 @@ def debug_data(
|
||||||
"with whitespace characters are considered invalid."
|
"with whitespace characters are considered invalid."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if has_punct_ents_warning:
|
||||||
|
msg.text(
|
||||||
|
"Entity spans consisting of or starting/ending "
|
||||||
|
"with punctuation can not be trained with a noise level > 0."
|
||||||
|
)
|
||||||
|
|
||||||
if "textcat" in pipeline:
|
if "textcat" in pipeline:
|
||||||
msg.divider("Text Classification")
|
msg.divider("Text Classification")
|
||||||
labels = [label for label in gold_train_data["cats"]]
|
labels = [label for label in gold_train_data["cats"]]
|
||||||
|
@ -547,6 +562,7 @@ def _compile_gold(train_docs, pipeline):
|
||||||
"words": Counter(),
|
"words": Counter(),
|
||||||
"roots": Counter(),
|
"roots": Counter(),
|
||||||
"ws_ents": 0,
|
"ws_ents": 0,
|
||||||
|
"punct_ents": 0,
|
||||||
"n_words": 0,
|
"n_words": 0,
|
||||||
"n_misaligned_words": 0,
|
"n_misaligned_words": 0,
|
||||||
"n_sents": 0,
|
"n_sents": 0,
|
||||||
|
@ -568,6 +584,10 @@ def _compile_gold(train_docs, pipeline):
|
||||||
if label.startswith(("B-", "U-", "L-")) and doc[i].is_space:
|
if label.startswith(("B-", "U-", "L-")) and doc[i].is_space:
|
||||||
# "Illegal" whitespace entity
|
# "Illegal" whitespace entity
|
||||||
data["ws_ents"] += 1
|
data["ws_ents"] += 1
|
||||||
|
if label.startswith(("B-", "U-", "L-")) and doc[i].text in [".", "'", "!", "?", ","]:
|
||||||
|
# punctuation entity: could be replaced by whitespace when training with noise,
|
||||||
|
# so add a warning to alert the user to this unexpected side effect.
|
||||||
|
data["punct_ents"] += 1
|
||||||
if label.startswith(("B-", "U-")):
|
if label.startswith(("B-", "U-")):
|
||||||
combined_label = label.split("-")[1]
|
combined_label = label.split("-")[1]
|
||||||
data["ner"][combined_label] += 1
|
data["ner"][combined_label] += 1
|
||||||
|
|
|
@ -30,6 +30,7 @@ from .. import about
|
||||||
raw_text=("Path to jsonl file with unlabelled text documents.", "option", "rt", Path),
|
raw_text=("Path to jsonl file with unlabelled text documents.", "option", "rt", Path),
|
||||||
base_model=("Name of model to update (optional)", "option", "b", str),
|
base_model=("Name of model to update (optional)", "option", "b", str),
|
||||||
pipeline=("Comma-separated names of pipeline components", "option", "p", str),
|
pipeline=("Comma-separated names of pipeline components", "option", "p", str),
|
||||||
|
replace_components=("Replace components from base model", "flag", "R", bool),
|
||||||
vectors=("Model to load vectors from", "option", "v", str),
|
vectors=("Model to load vectors from", "option", "v", str),
|
||||||
n_iter=("Number of iterations", "option", "n", int),
|
n_iter=("Number of iterations", "option", "n", int),
|
||||||
n_early_stopping=("Maximum number of training epochs without dev accuracy improvement", "option", "ne", int),
|
n_early_stopping=("Maximum number of training epochs without dev accuracy improvement", "option", "ne", int),
|
||||||
|
@ -60,6 +61,7 @@ def train(
|
||||||
raw_text=None,
|
raw_text=None,
|
||||||
base_model=None,
|
base_model=None,
|
||||||
pipeline="tagger,parser,ner",
|
pipeline="tagger,parser,ner",
|
||||||
|
replace_components=False,
|
||||||
vectors=None,
|
vectors=None,
|
||||||
n_iter=30,
|
n_iter=30,
|
||||||
n_early_stopping=None,
|
n_early_stopping=None,
|
||||||
|
@ -142,6 +144,8 @@ def train(
|
||||||
# the model and make sure the pipeline matches the pipeline setting. If
|
# the model and make sure the pipeline matches the pipeline setting. If
|
||||||
# training starts from a blank model, intitalize the language class.
|
# training starts from a blank model, intitalize the language class.
|
||||||
pipeline = [p.strip() for p in pipeline.split(",")]
|
pipeline = [p.strip() for p in pipeline.split(",")]
|
||||||
|
disabled_pipes = None
|
||||||
|
pipes_added = False
|
||||||
msg.text("Training pipeline: {}".format(pipeline))
|
msg.text("Training pipeline: {}".format(pipeline))
|
||||||
if base_model:
|
if base_model:
|
||||||
msg.text("Starting with base model '{}'".format(base_model))
|
msg.text("Starting with base model '{}'".format(base_model))
|
||||||
|
@ -152,20 +156,24 @@ def train(
|
||||||
"`lang` argument ('{}') ".format(nlp.lang, lang),
|
"`lang` argument ('{}') ".format(nlp.lang, lang),
|
||||||
exits=1,
|
exits=1,
|
||||||
)
|
)
|
||||||
nlp.disable_pipes([p for p in nlp.pipe_names if p not in pipeline])
|
|
||||||
for pipe in pipeline:
|
for pipe in pipeline:
|
||||||
|
pipe_cfg = {}
|
||||||
|
if pipe == "parser":
|
||||||
|
pipe_cfg = {"learn_tokens": learn_tokens}
|
||||||
|
elif pipe == "textcat":
|
||||||
|
pipe_cfg = {
|
||||||
|
"exclusive_classes": not textcat_multilabel,
|
||||||
|
"architecture": textcat_arch,
|
||||||
|
"positive_label": textcat_positive_label,
|
||||||
|
}
|
||||||
if pipe not in nlp.pipe_names:
|
if pipe not in nlp.pipe_names:
|
||||||
if pipe == "parser":
|
msg.text("Adding component to base model '{}'".format(pipe))
|
||||||
pipe_cfg = {"learn_tokens": learn_tokens}
|
|
||||||
elif pipe == "textcat":
|
|
||||||
pipe_cfg = {
|
|
||||||
"exclusive_classes": not textcat_multilabel,
|
|
||||||
"architecture": textcat_arch,
|
|
||||||
"positive_label": textcat_positive_label,
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
pipe_cfg = {}
|
|
||||||
nlp.add_pipe(nlp.create_pipe(pipe, config=pipe_cfg))
|
nlp.add_pipe(nlp.create_pipe(pipe, config=pipe_cfg))
|
||||||
|
pipes_added = True
|
||||||
|
elif replace_components:
|
||||||
|
msg.text("Replacing component from base model '{}'".format(pipe))
|
||||||
|
nlp.replace_pipe(pipe, nlp.create_pipe(pipe, config=pipe_cfg))
|
||||||
|
pipes_added = True
|
||||||
else:
|
else:
|
||||||
if pipe == "textcat":
|
if pipe == "textcat":
|
||||||
textcat_cfg = nlp.get_pipe("textcat").cfg
|
textcat_cfg = nlp.get_pipe("textcat").cfg
|
||||||
|
@ -174,11 +182,6 @@ def train(
|
||||||
"architecture": textcat_cfg["architecture"],
|
"architecture": textcat_cfg["architecture"],
|
||||||
"positive_label": textcat_cfg["positive_label"],
|
"positive_label": textcat_cfg["positive_label"],
|
||||||
}
|
}
|
||||||
pipe_cfg = {
|
|
||||||
"exclusive_classes": not textcat_multilabel,
|
|
||||||
"architecture": textcat_arch,
|
|
||||||
"positive_label": textcat_positive_label,
|
|
||||||
}
|
|
||||||
if base_cfg != pipe_cfg:
|
if base_cfg != pipe_cfg:
|
||||||
msg.fail(
|
msg.fail(
|
||||||
"The base textcat model configuration does"
|
"The base textcat model configuration does"
|
||||||
|
@ -188,6 +191,8 @@ def train(
|
||||||
),
|
),
|
||||||
exits=1,
|
exits=1,
|
||||||
)
|
)
|
||||||
|
msg.text("Extending component from base model '{}'".format(pipe))
|
||||||
|
disabled_pipes = nlp.disable_pipes([p for p in nlp.pipe_names if p not in pipeline])
|
||||||
else:
|
else:
|
||||||
msg.text("Starting with blank model '{}'".format(lang))
|
msg.text("Starting with blank model '{}'".format(lang))
|
||||||
lang_cls = util.get_lang_class(lang)
|
lang_cls = util.get_lang_class(lang)
|
||||||
|
@ -227,7 +232,7 @@ def train(
|
||||||
corpus = GoldCorpus(train_path, dev_path, limit=n_examples)
|
corpus = GoldCorpus(train_path, dev_path, limit=n_examples)
|
||||||
n_train_words = corpus.count_train()
|
n_train_words = corpus.count_train()
|
||||||
|
|
||||||
if base_model:
|
if base_model and not pipes_added:
|
||||||
# Start with an existing model, use default optimizer
|
# Start with an existing model, use default optimizer
|
||||||
optimizer = create_default_optimizer(Model.ops)
|
optimizer = create_default_optimizer(Model.ops)
|
||||||
else:
|
else:
|
||||||
|
@ -243,7 +248,7 @@ def train(
|
||||||
|
|
||||||
# Verify textcat config
|
# Verify textcat config
|
||||||
if "textcat" in pipeline:
|
if "textcat" in pipeline:
|
||||||
textcat_labels = nlp.get_pipe("textcat").cfg["labels"]
|
textcat_labels = nlp.get_pipe("textcat").cfg.get("labels", [])
|
||||||
if textcat_positive_label and textcat_positive_label not in textcat_labels:
|
if textcat_positive_label and textcat_positive_label not in textcat_labels:
|
||||||
msg.fail(
|
msg.fail(
|
||||||
"The textcat_positive_label (tpl) '{}' does not match any "
|
"The textcat_positive_label (tpl) '{}' does not match any "
|
||||||
|
@ -426,11 +431,16 @@ def train(
|
||||||
"cpu": cpu_wps,
|
"cpu": cpu_wps,
|
||||||
"gpu": gpu_wps,
|
"gpu": gpu_wps,
|
||||||
}
|
}
|
||||||
meta["accuracy"] = scorer.scores
|
meta.setdefault("accuracy", {})
|
||||||
|
for component in nlp.pipe_names:
|
||||||
|
for metric in _get_metrics(component):
|
||||||
|
meta["accuracy"][metric] = scorer.scores[metric]
|
||||||
else:
|
else:
|
||||||
meta.setdefault("beam_accuracy", {})
|
meta.setdefault("beam_accuracy", {})
|
||||||
meta.setdefault("beam_speed", {})
|
meta.setdefault("beam_speed", {})
|
||||||
meta["beam_accuracy"][beam_width] = scorer.scores
|
for component in nlp.pipe_names:
|
||||||
|
for metric in _get_metrics(component):
|
||||||
|
meta["beam_accuracy"][metric] = scorer.scores[metric]
|
||||||
meta["beam_speed"][beam_width] = {
|
meta["beam_speed"][beam_width] = {
|
||||||
"nwords": nwords,
|
"nwords": nwords,
|
||||||
"cpu": cpu_wps,
|
"cpu": cpu_wps,
|
||||||
|
@ -486,12 +496,16 @@ def train(
|
||||||
)
|
)
|
||||||
break
|
break
|
||||||
finally:
|
finally:
|
||||||
|
best_pipes = nlp.pipe_names
|
||||||
|
if disabled_pipes:
|
||||||
|
disabled_pipes.restore()
|
||||||
with nlp.use_params(optimizer.averages):
|
with nlp.use_params(optimizer.averages):
|
||||||
final_model_path = output_path / "model-final"
|
final_model_path = output_path / "model-final"
|
||||||
nlp.to_disk(final_model_path)
|
nlp.to_disk(final_model_path)
|
||||||
|
final_meta = srsly.read_json(output_path / "model-final" / "meta.json")
|
||||||
msg.good("Saved model to output directory", final_model_path)
|
msg.good("Saved model to output directory", final_model_path)
|
||||||
with msg.loading("Creating best model..."):
|
with msg.loading("Creating best model..."):
|
||||||
best_model_path = _collate_best_model(meta, output_path, nlp.pipe_names)
|
best_model_path = _collate_best_model(final_meta, output_path, best_pipes)
|
||||||
msg.good("Created best model", best_model_path)
|
msg.good("Created best model", best_model_path)
|
||||||
|
|
||||||
|
|
||||||
|
@ -549,6 +563,7 @@ def _load_pretrained_tok2vec(nlp, loc):
|
||||||
|
|
||||||
def _collate_best_model(meta, output_path, components):
|
def _collate_best_model(meta, output_path, components):
|
||||||
bests = {}
|
bests = {}
|
||||||
|
meta.setdefault("accuracy", {})
|
||||||
for component in components:
|
for component in components:
|
||||||
bests[component] = _find_best(output_path, component)
|
bests[component] = _find_best(output_path, component)
|
||||||
best_dest = output_path / "model-best"
|
best_dest = output_path / "model-best"
|
||||||
|
@ -580,11 +595,13 @@ def _find_best(experiment_dir, component):
|
||||||
|
|
||||||
def _get_metrics(component):
|
def _get_metrics(component):
|
||||||
if component == "parser":
|
if component == "parser":
|
||||||
return ("las", "uas", "token_acc")
|
return ("las", "uas", "las_per_type", "token_acc")
|
||||||
elif component == "tagger":
|
elif component == "tagger":
|
||||||
return ("tags_acc",)
|
return ("tags_acc",)
|
||||||
elif component == "ner":
|
elif component == "ner":
|
||||||
return ("ents_f", "ents_p", "ents_r")
|
return ("ents_f", "ents_p", "ents_r", "ents_per_type")
|
||||||
|
elif component == "textcat":
|
||||||
|
return ("textcat_score",)
|
||||||
return ("token_acc",)
|
return ("token_acc",)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -172,7 +172,8 @@ class Errors(object):
|
||||||
"and satisfies the correct annotations specified in the GoldParse. "
|
"and satisfies the correct annotations specified in the GoldParse. "
|
||||||
"For example, are all labels added to the model? If you're "
|
"For example, are all labels added to the model? If you're "
|
||||||
"training a named entity recognizer, also make sure that none of "
|
"training a named entity recognizer, also make sure that none of "
|
||||||
"your annotated entity spans have leading or trailing whitespace. "
|
"your annotated entity spans have leading or trailing whitespace "
|
||||||
|
"or punctuation. "
|
||||||
"You can also use the experimental `debug-data` command to "
|
"You can also use the experimental `debug-data` command to "
|
||||||
"validate your JSON-formatted training data. For details, run:\n"
|
"validate your JSON-formatted training data. For details, run:\n"
|
||||||
"python -m spacy debug-data --help")
|
"python -m spacy debug-data --help")
|
||||||
|
|
|
@ -17,6 +17,17 @@ _tamil = r"\u0B80-\u0BFF"
|
||||||
|
|
||||||
_telugu = r"\u0C00-\u0C7F"
|
_telugu = r"\u0C00-\u0C7F"
|
||||||
|
|
||||||
|
# from the final table in: https://en.wikipedia.org/wiki/CJK_Unified_Ideographs
|
||||||
|
_cjk = (
|
||||||
|
r"\u4E00-\u62FF\u6300-\u77FF\u7800-\u8CFF\u8D00-\u9FFF\u3400-\u4DBF"
|
||||||
|
r"\U00020000-\U000215FF\U00021600-\U000230FF\U00023100-\U000245FF"
|
||||||
|
r"\U00024600-\U000260FF\U00026100-\U000275FF\U00027600-\U000290FF"
|
||||||
|
r"\U00029100-\U0002A6DF\U0002A700-\U0002B73F\U0002B740-\U0002B81F"
|
||||||
|
r"\U0002B820-\U0002CEAF\U0002CEB0-\U0002EBEF\u2E80-\u2EFF\u2F00-\u2FDF"
|
||||||
|
r"\u2FF0-\u2FFF\u3000-\u303F\u31C0-\u31EF\u3200-\u32FF\u3300-\u33FF"
|
||||||
|
r"\uF900-\uFAFF\uFE30-\uFE4F\U0001F200-\U0001F2FF\U0002F800-\U0002FA1F"
|
||||||
|
)
|
||||||
|
|
||||||
# Latin standard
|
# Latin standard
|
||||||
_latin_u_standard = r"A-Z"
|
_latin_u_standard = r"A-Z"
|
||||||
_latin_l_standard = r"a-z"
|
_latin_l_standard = r"a-z"
|
||||||
|
@ -215,6 +226,7 @@ _uncased = (
|
||||||
+ _tamil
|
+ _tamil
|
||||||
+ _telugu
|
+ _telugu
|
||||||
+ _hangul
|
+ _hangul
|
||||||
|
+ _cjk
|
||||||
)
|
)
|
||||||
|
|
||||||
ALPHA = group_chars(LATIN + _russian + _tatar + _greek + _ukrainian + _uncased)
|
ALPHA = group_chars(LATIN + _russian + _tatar + _greek + _ukrainian + _uncased)
|
||||||
|
|
|
@ -3,14 +3,18 @@ from __future__ import unicode_literals
|
||||||
|
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
from .char_classes import ALPHA_LOWER
|
||||||
from ..symbols import ORTH, POS, TAG, LEMMA, SPACE
|
from ..symbols import ORTH, POS, TAG, LEMMA, SPACE
|
||||||
|
|
||||||
|
|
||||||
# URL validation regex courtesy of: https://mathiasbynens.be/demo/url-regex
|
# URL validation regex courtesy of: https://mathiasbynens.be/demo/url-regex
|
||||||
# A few minor mods to this regex to account for use cases represented in test_urls
|
# and https://gist.github.com/dperini/729294 (Diego Perini, MIT License)
|
||||||
|
# A few mods to this regex to account for use cases represented in test_urls
|
||||||
URL_PATTERN = (
|
URL_PATTERN = (
|
||||||
|
# fmt: off
|
||||||
r"^"
|
r"^"
|
||||||
# protocol identifier (see: https://www.iana.org/assignments/uri-schemes/uri-schemes.xhtml)
|
# protocol identifier (mods: make optional and expand schemes)
|
||||||
|
# (see: https://www.iana.org/assignments/uri-schemes/uri-schemes.xhtml)
|
||||||
r"(?:(?:[\w\+\-\.]{2,})://)?"
|
r"(?:(?:[\w\+\-\.]{2,})://)?"
|
||||||
# mailto:user or user:pass authentication
|
# mailto:user or user:pass authentication
|
||||||
r"(?:\S+(?::\S*)?@)?"
|
r"(?:\S+(?::\S*)?@)?"
|
||||||
|
@ -31,18 +35,27 @@ URL_PATTERN = (
|
||||||
r"(?:\.(?:1?\d{1,2}|2[0-4]\d|25[0-5])){2}"
|
r"(?:\.(?:1?\d{1,2}|2[0-4]\d|25[0-5])){2}"
|
||||||
r"(?:\.(?:[1-9]\d?|1\d\d|2[0-4]\d|25[0-4]))"
|
r"(?:\.(?:[1-9]\d?|1\d\d|2[0-4]\d|25[0-4]))"
|
||||||
r"|"
|
r"|"
|
||||||
# host name
|
# host & domain names
|
||||||
r"(?:(?:[a-z0-9\-]*)?[a-z0-9]+)"
|
# mods: match is case-sensitive, so include [A-Z]
|
||||||
# domain name
|
"(?:"
|
||||||
r"(?:\.(?:[a-z0-9])(?:[a-z0-9\-])*[a-z0-9])?"
|
"(?:"
|
||||||
|
"[A-Za-z0-9\u00a1-\uffff]"
|
||||||
|
"[A-Za-z0-9\u00a1-\uffff_-]{0,62}"
|
||||||
|
")?"
|
||||||
|
"[A-Za-z0-9\u00a1-\uffff]\."
|
||||||
|
")+"
|
||||||
# TLD identifier
|
# TLD identifier
|
||||||
r"(?:\.(?:[a-z]{2,}))"
|
# mods: use ALPHA_LOWER instead of a wider range so that this doesn't match
|
||||||
|
# strings like "lower.Upper", which can be split on "." by infixes in some
|
||||||
|
# languages
|
||||||
|
r"(?:[" + ALPHA_LOWER + "]{2,63})"
|
||||||
r")"
|
r")"
|
||||||
# port number
|
# port number
|
||||||
r"(?::\d{2,5})?"
|
r"(?::\d{2,5})?"
|
||||||
# resource path
|
# resource path
|
||||||
r"(?:[/?#]\S*)?"
|
r"(?:[/?#]\S*)?"
|
||||||
r"$"
|
r"$"
|
||||||
|
# fmt: on
|
||||||
).strip()
|
).strip()
|
||||||
|
|
||||||
TOKEN_MATCH = re.compile(URL_PATTERN, re.UNICODE).match
|
TOKEN_MATCH = re.compile(URL_PATTERN, re.UNICODE).match
|
||||||
|
|
|
@ -780,7 +780,7 @@ class Language(object):
|
||||||
|
|
||||||
pipes = (
|
pipes = (
|
||||||
[]
|
[]
|
||||||
) # contains functools.partial objects so that easily create multiprocess worker.
|
) # contains functools.partial objects to easily create multiprocess worker.
|
||||||
for name, proc in self.pipeline:
|
for name, proc in self.pipeline:
|
||||||
if name in disable:
|
if name in disable:
|
||||||
continue
|
continue
|
||||||
|
@ -837,7 +837,7 @@ class Language(object):
|
||||||
texts, raw_texts = itertools.tee(texts)
|
texts, raw_texts = itertools.tee(texts)
|
||||||
# for sending texts to worker
|
# for sending texts to worker
|
||||||
texts_q = [mp.Queue() for _ in range(n_process)]
|
texts_q = [mp.Queue() for _ in range(n_process)]
|
||||||
# for receiving byte encoded docs from worker
|
# for receiving byte-encoded docs from worker
|
||||||
bytedocs_recv_ch, bytedocs_send_ch = zip(
|
bytedocs_recv_ch, bytedocs_send_ch = zip(
|
||||||
*[mp.Pipe(False) for _ in range(n_process)]
|
*[mp.Pipe(False) for _ in range(n_process)]
|
||||||
)
|
)
|
||||||
|
@ -847,7 +847,7 @@ class Language(object):
|
||||||
# This is necessary to properly handle infinite length of texts.
|
# This is necessary to properly handle infinite length of texts.
|
||||||
# (In this case, all data cannot be sent to the workers at once)
|
# (In this case, all data cannot be sent to the workers at once)
|
||||||
sender = _Sender(batch_texts, texts_q, chunk_size=n_process)
|
sender = _Sender(batch_texts, texts_q, chunk_size=n_process)
|
||||||
# send twice so that make process busy
|
# send twice to make process busy
|
||||||
sender.send()
|
sender.send()
|
||||||
sender.send()
|
sender.send()
|
||||||
|
|
||||||
|
@ -859,7 +859,7 @@ class Language(object):
|
||||||
proc.start()
|
proc.start()
|
||||||
|
|
||||||
# Cycle channels not to break the order of docs.
|
# Cycle channels not to break the order of docs.
|
||||||
# The received object is batch of byte encoded docs, so flatten them with chain.from_iterable.
|
# The received object is a batch of byte-encoded docs, so flatten them with chain.from_iterable.
|
||||||
byte_docs = chain.from_iterable(recv.recv() for recv in cycle(bytedocs_recv_ch))
|
byte_docs = chain.from_iterable(recv.recv() for recv in cycle(bytedocs_recv_ch))
|
||||||
docs = (Doc(self.vocab).from_bytes(byte_doc) for byte_doc in byte_docs)
|
docs = (Doc(self.vocab).from_bytes(byte_doc) for byte_doc in byte_docs)
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -129,20 +129,31 @@ class EntityRuler(object):
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/entityruler#labels
|
DOCS: https://spacy.io/api/entityruler#labels
|
||||||
"""
|
"""
|
||||||
all_labels = set(self.token_patterns.keys())
|
keys = set(self.token_patterns.keys())
|
||||||
all_labels.update(self.phrase_patterns.keys())
|
keys.update(self.phrase_patterns.keys())
|
||||||
|
all_labels = set()
|
||||||
|
|
||||||
|
for l in keys:
|
||||||
|
if self.ent_id_sep in l:
|
||||||
|
label, _ = self._split_label(l)
|
||||||
|
all_labels.add(label)
|
||||||
|
else:
|
||||||
|
all_labels.add(l)
|
||||||
return tuple(all_labels)
|
return tuple(all_labels)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def ent_ids(self):
|
def ent_ids(self):
|
||||||
"""All entity ids present in the match patterns `id` properties.
|
"""All entity ids present in the match patterns `id` properties
|
||||||
|
|
||||||
RETURNS (set): The string entity ids.
|
RETURNS (set): The string entity ids.
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/entityruler#ent_ids
|
DOCS: https://spacy.io/api/entityruler#ent_ids
|
||||||
"""
|
"""
|
||||||
|
keys = set(self.token_patterns.keys())
|
||||||
|
keys.update(self.phrase_patterns.keys())
|
||||||
all_ent_ids = set()
|
all_ent_ids = set()
|
||||||
for l in self.labels:
|
|
||||||
|
for l in keys:
|
||||||
if self.ent_id_sep in l:
|
if self.ent_id_sep in l:
|
||||||
_, ent_id = self._split_label(l)
|
_, ent_id = self._split_label(l)
|
||||||
all_ent_ids.add(ent_id)
|
all_ent_ids.add(ent_id)
|
||||||
|
|
|
@ -1308,7 +1308,7 @@ class EntityLinker(Pipe):
|
||||||
for i, doc in enumerate(docs):
|
for i, doc in enumerate(docs):
|
||||||
if len(doc) > 0:
|
if len(doc) > 0:
|
||||||
# Looping through each sentence and each entity
|
# Looping through each sentence and each entity
|
||||||
# This may go wrong if there are entities across sentences - because they might not get a KB ID
|
# This may go wrong if there are entities across sentences - which shouldn't happen normally.
|
||||||
for sent in doc.sents:
|
for sent in doc.sents:
|
||||||
sent_doc = sent.as_doc()
|
sent_doc = sent.as_doc()
|
||||||
# currently, the context is the same for each entity in a sentence (should be refined)
|
# currently, the context is the same for each entity in a sentence (should be refined)
|
||||||
|
|
|
@ -462,3 +462,4 @@ cdef enum symbol_t:
|
||||||
acl
|
acl
|
||||||
|
|
||||||
ENT_KB_ID
|
ENT_KB_ID
|
||||||
|
ENT_ID
|
||||||
|
|
|
@ -86,6 +86,7 @@ IDS = {
|
||||||
"DEP": DEP,
|
"DEP": DEP,
|
||||||
"ENT_IOB": ENT_IOB,
|
"ENT_IOB": ENT_IOB,
|
||||||
"ENT_TYPE": ENT_TYPE,
|
"ENT_TYPE": ENT_TYPE,
|
||||||
|
"ENT_ID": ENT_ID,
|
||||||
"ENT_KB_ID": ENT_KB_ID,
|
"ENT_KB_ID": ENT_KB_ID,
|
||||||
"HEAD": HEAD,
|
"HEAD": HEAD,
|
||||||
"SENT_START": SENT_START,
|
"SENT_START": SENT_START,
|
||||||
|
|
|
@ -57,7 +57,7 @@ cdef class Parser:
|
||||||
subword_features = util.env_opt('subword_features',
|
subword_features = util.env_opt('subword_features',
|
||||||
cfg.get('subword_features', True))
|
cfg.get('subword_features', True))
|
||||||
conv_depth = util.env_opt('conv_depth', cfg.get('conv_depth', 4))
|
conv_depth = util.env_opt('conv_depth', cfg.get('conv_depth', 4))
|
||||||
conv_window = util.env_opt('conv_window', cfg.get('conv_depth', 1))
|
conv_window = util.env_opt('conv_window', cfg.get('conv_window', 1))
|
||||||
t2v_pieces = util.env_opt('cnn_maxout_pieces', cfg.get('cnn_maxout_pieces', 3))
|
t2v_pieces = util.env_opt('cnn_maxout_pieces', cfg.get('cnn_maxout_pieces', 3))
|
||||||
bilstm_depth = util.env_opt('bilstm_depth', cfg.get('bilstm_depth', 0))
|
bilstm_depth = util.env_opt('bilstm_depth', cfg.get('bilstm_depth', 0))
|
||||||
self_attn_depth = util.env_opt('self_attn_depth', cfg.get('self_attn_depth', 0))
|
self_attn_depth = util.env_opt('self_attn_depth', cfg.get('self_attn_depth', 0))
|
||||||
|
|
|
@ -296,9 +296,8 @@ WIKI_TESTS = [
|
||||||
("cérium(IV)-oxid", ["cérium", "(", "IV", ")", "-oxid"]),
|
("cérium(IV)-oxid", ["cérium", "(", "IV", ")", "-oxid"]),
|
||||||
]
|
]
|
||||||
|
|
||||||
TESTCASES = (
|
EXTRA_TESTS = (
|
||||||
DEFAULT_TESTS
|
DOT_TESTS
|
||||||
+ DOT_TESTS
|
|
||||||
+ QUOTE_TESTS
|
+ QUOTE_TESTS
|
||||||
+ NUMBER_TESTS
|
+ NUMBER_TESTS
|
||||||
+ HYPHEN_TESTS
|
+ HYPHEN_TESTS
|
||||||
|
@ -306,8 +305,16 @@ TESTCASES = (
|
||||||
+ TYPO_TESTS
|
+ TYPO_TESTS
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# normal: default tests + 10% of extra tests
|
||||||
|
TESTS = DEFAULT_TESTS
|
||||||
|
TESTS.extend([x for i, x in enumerate(EXTRA_TESTS) if i % 10 == 0])
|
||||||
|
|
||||||
@pytest.mark.parametrize("text,expected_tokens", TESTCASES)
|
# slow: remaining 90% of extra tests
|
||||||
|
SLOW_TESTS = [x for i, x in enumerate(EXTRA_TESTS) if i % 10 != 0]
|
||||||
|
TESTS.extend([pytest.param(x[0], x[1], marks=pytest.mark.slow()) if not isinstance(x[0], tuple) else x for x in SLOW_TESTS])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("text,expected_tokens", TESTS)
|
||||||
def test_hu_tokenizer_handles_testcases(hu_tokenizer, text, expected_tokens):
|
def test_hu_tokenizer_handles_testcases(hu_tokenizer, text, expected_tokens):
|
||||||
tokens = hu_tokenizer(text)
|
tokens = hu_tokenizer(text)
|
||||||
token_list = [token.text for token in tokens if not token.is_space]
|
token_list = [token.text for token in tokens if not token.is_space]
|
||||||
|
|
|
@ -44,15 +44,15 @@ TYPOS_IN_PUNC_TESTS = [
|
||||||
|
|
||||||
LONG_TEXTS_TESTS = [
|
LONG_TEXTS_TESTS = [
|
||||||
(
|
(
|
||||||
"Иң борынгы кешеләр суыклар һәм салкын кышлар булмый торган җылы"
|
"Иң борынгы кешеләр суыклар һәм салкын кышлар булмый торган җылы "
|
||||||
"якларда яшәгәннәр, шуңа күрә аларга кием кирәк булмаган.Йөз"
|
"якларда яшәгәннәр, шуңа күрә аларга кием кирәк булмаган.Йөз "
|
||||||
"меңнәрчә еллар үткән, борынгы кешеләр акрынлап Европа һәм Азиянең"
|
"меңнәрчә еллар үткән, борынгы кешеләр акрынлап Европа һәм Азиянең "
|
||||||
"салкын илләрендә дә яши башлаганнар. Алар кырыс һәм салкын"
|
"салкын илләрендә дә яши башлаганнар. Алар кырыс һәм салкын "
|
||||||
"кышлардан саклану өчен кием-салым уйлап тапканнар - итәк.",
|
"кышлардан саклану өчен кием-салым уйлап тапканнар - итәк.",
|
||||||
"Иң борынгы кешеләр суыклар һәм салкын кышлар булмый торган җылы"
|
"Иң борынгы кешеләр суыклар һәм салкын кышлар булмый торган җылы "
|
||||||
"якларда яшәгәннәр , шуңа күрә аларга кием кирәк булмаган . Йөз"
|
"якларда яшәгәннәр , шуңа күрә аларга кием кирәк булмаган . Йөз "
|
||||||
"меңнәрчә еллар үткән , борынгы кешеләр акрынлап Европа һәм Азиянең"
|
"меңнәрчә еллар үткән , борынгы кешеләр акрынлап Европа һәм Азиянең "
|
||||||
"салкын илләрендә дә яши башлаганнар . Алар кырыс һәм салкын"
|
"салкын илләрендә дә яши башлаганнар . Алар кырыс һәм салкын "
|
||||||
"кышлардан саклану өчен кием-салым уйлап тапканнар - итәк .".split(),
|
"кышлардан саклану өчен кием-салым уйлап тапканнар - итәк .".split(),
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
|
@ -21,6 +21,7 @@ def patterns():
|
||||||
{"label": "HELLO", "pattern": [{"ORTH": "HELLO"}]},
|
{"label": "HELLO", "pattern": [{"ORTH": "HELLO"}]},
|
||||||
{"label": "COMPLEX", "pattern": [{"ORTH": "foo", "OP": "*"}]},
|
{"label": "COMPLEX", "pattern": [{"ORTH": "foo", "OP": "*"}]},
|
||||||
{"label": "TECH_ORG", "pattern": "Apple", "id": "a1"},
|
{"label": "TECH_ORG", "pattern": "Apple", "id": "a1"},
|
||||||
|
{"label": "TECH_ORG", "pattern": "Microsoft", "id": "a2"},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -147,3 +148,14 @@ def test_entity_ruler_validate(nlp):
|
||||||
# invalid pattern raises error with validate
|
# invalid pattern raises error with validate
|
||||||
with pytest.raises(MatchPatternError):
|
with pytest.raises(MatchPatternError):
|
||||||
validated_ruler.add_patterns([invalid_pattern])
|
validated_ruler.add_patterns([invalid_pattern])
|
||||||
|
|
||||||
|
|
||||||
|
def test_entity_ruler_properties(nlp, patterns):
|
||||||
|
ruler = EntityRuler(nlp, patterns=patterns, overwrite_ents=True)
|
||||||
|
assert sorted(ruler.labels) == sorted([
|
||||||
|
"HELLO",
|
||||||
|
"BYE",
|
||||||
|
"COMPLEX",
|
||||||
|
"TECH_ORG"
|
||||||
|
])
|
||||||
|
assert sorted(ruler.ent_ids) == ["a1", "a2"]
|
||||||
|
|
36
spacy/tests/regression/test_issue4849.py
Normal file
36
spacy/tests/regression/test_issue4849.py
Normal file
|
@ -0,0 +1,36 @@
|
||||||
|
# coding: utf8
|
||||||
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
|
from spacy.lang.en import English
|
||||||
|
from spacy.pipeline import EntityRuler
|
||||||
|
|
||||||
|
|
||||||
|
def test_issue4849():
|
||||||
|
nlp = English()
|
||||||
|
|
||||||
|
ruler = EntityRuler(
|
||||||
|
nlp, patterns=[
|
||||||
|
{"label": "PERSON", "pattern": 'joe biden', "id": 'joe-biden'},
|
||||||
|
{"label": "PERSON", "pattern": 'bernie sanders', "id": 'bernie-sanders'},
|
||||||
|
],
|
||||||
|
phrase_matcher_attr="LOWER"
|
||||||
|
)
|
||||||
|
|
||||||
|
nlp.add_pipe(ruler)
|
||||||
|
|
||||||
|
text = """
|
||||||
|
The left is starting to take aim at Democratic front-runner Joe Biden.
|
||||||
|
Sen. Bernie Sanders joined in her criticism: "There is no 'middle ground' when it comes to climate policy."
|
||||||
|
"""
|
||||||
|
|
||||||
|
# USING 1 PROCESS
|
||||||
|
count_ents = 0
|
||||||
|
for doc in nlp.pipe([text], n_process=1):
|
||||||
|
count_ents += len([ent for ent in doc.ents if ent.ent_id > 0])
|
||||||
|
assert(count_ents == 2)
|
||||||
|
|
||||||
|
# USING 2 PROCESSES
|
||||||
|
count_ents = 0
|
||||||
|
for doc in nlp.pipe([text], n_process=2):
|
||||||
|
count_ents += len([ent for ent in doc.ents if ent.ent_id > 0])
|
||||||
|
assert (count_ents == 2)
|
|
@ -2,7 +2,7 @@
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from spacy.tokens import Doc
|
from spacy.tokens import Doc, Token
|
||||||
from spacy.vocab import Vocab
|
from spacy.vocab import Vocab
|
||||||
|
|
||||||
|
|
||||||
|
@ -15,6 +15,10 @@ def doc_w_attrs(en_tokenizer):
|
||||||
)
|
)
|
||||||
doc = en_tokenizer("This is a test.")
|
doc = en_tokenizer("This is a test.")
|
||||||
doc._._test_attr = "test"
|
doc._._test_attr = "test"
|
||||||
|
|
||||||
|
Token.set_extension("_test_token", default="t0")
|
||||||
|
doc[1]._._test_token = "t1"
|
||||||
|
|
||||||
return doc
|
return doc
|
||||||
|
|
||||||
|
|
||||||
|
@ -25,3 +29,7 @@ def test_serialize_ext_attrs_from_bytes(doc_w_attrs):
|
||||||
assert doc._._test_attr == "test"
|
assert doc._._test_attr == "test"
|
||||||
assert doc._._test_prop == len(doc.text)
|
assert doc._._test_prop == len(doc.text)
|
||||||
assert doc._._test_method("test") == "{}{}".format(len(doc.text), "test")
|
assert doc._._test_method("test") == "{}{}".format(len(doc.text), "test")
|
||||||
|
|
||||||
|
assert doc[0]._._test_token == "t0"
|
||||||
|
assert doc[1]._._test_token == "t1"
|
||||||
|
assert doc[2]._._test_token == "t0"
|
||||||
|
|
|
@ -20,6 +20,7 @@ URLS_FULL = URLS_BASIC + [
|
||||||
# URL SHOULD_MATCH and SHOULD_NOT_MATCH patterns courtesy of https://mathiasbynens.be/demo/url-regex
|
# URL SHOULD_MATCH and SHOULD_NOT_MATCH patterns courtesy of https://mathiasbynens.be/demo/url-regex
|
||||||
URLS_SHOULD_MATCH = [
|
URLS_SHOULD_MATCH = [
|
||||||
"http://foo.com/blah_blah",
|
"http://foo.com/blah_blah",
|
||||||
|
"http://BlahBlah.com/Blah_Blah",
|
||||||
"http://foo.com/blah_blah/",
|
"http://foo.com/blah_blah/",
|
||||||
"http://www.example.com/wpstyle/?p=364",
|
"http://www.example.com/wpstyle/?p=364",
|
||||||
"https://www.example.com/foo/?bar=baz&inga=42&quux",
|
"https://www.example.com/foo/?bar=baz&inga=42&quux",
|
||||||
|
@ -57,14 +58,17 @@ URLS_SHOULD_MATCH = [
|
||||||
),
|
),
|
||||||
"http://foo.com/blah_blah_(wikipedia)",
|
"http://foo.com/blah_blah_(wikipedia)",
|
||||||
"http://foo.com/blah_blah_(wikipedia)_(again)",
|
"http://foo.com/blah_blah_(wikipedia)_(again)",
|
||||||
pytest.param("http://⌘.ws", marks=pytest.mark.xfail()),
|
"http://www.foo.co.uk",
|
||||||
pytest.param("http://⌘.ws/", marks=pytest.mark.xfail()),
|
"http://www.foo.co.uk/",
|
||||||
pytest.param("http://☺.damowmow.com/", marks=pytest.mark.xfail()),
|
"http://www.foo.co.uk/blah/blah",
|
||||||
pytest.param("http://✪df.ws/123", marks=pytest.mark.xfail()),
|
"http://⌘.ws",
|
||||||
pytest.param("http://➡.ws/䨹", marks=pytest.mark.xfail()),
|
"http://⌘.ws/",
|
||||||
pytest.param("http://مثال.إختبار", marks=pytest.mark.xfail()),
|
"http://☺.damowmow.com/",
|
||||||
pytest.param("http://例子.测试", marks=pytest.mark.xfail()),
|
"http://✪df.ws/123",
|
||||||
pytest.param("http://उदाहरण.परीक्षा", marks=pytest.mark.xfail()),
|
"http://➡.ws/䨹",
|
||||||
|
"http://مثال.إختبار",
|
||||||
|
"http://例子.测试",
|
||||||
|
"http://उदाहरण.परीक्षा",
|
||||||
]
|
]
|
||||||
|
|
||||||
URLS_SHOULD_NOT_MATCH = [
|
URLS_SHOULD_NOT_MATCH = [
|
||||||
|
|
|
@ -23,7 +23,7 @@ from ..lexeme cimport Lexeme, EMPTY_LEXEME
|
||||||
from ..typedefs cimport attr_t, flags_t
|
from ..typedefs cimport attr_t, flags_t
|
||||||
from ..attrs cimport ID, ORTH, NORM, LOWER, SHAPE, PREFIX, SUFFIX, CLUSTER
|
from ..attrs cimport ID, ORTH, NORM, LOWER, SHAPE, PREFIX, SUFFIX, CLUSTER
|
||||||
from ..attrs cimport LENGTH, POS, LEMMA, TAG, DEP, HEAD, SPACY, ENT_IOB
|
from ..attrs cimport LENGTH, POS, LEMMA, TAG, DEP, HEAD, SPACY, ENT_IOB
|
||||||
from ..attrs cimport ENT_TYPE, ENT_KB_ID, SENT_START, attr_id_t
|
from ..attrs cimport ENT_TYPE, ENT_ID, ENT_KB_ID, SENT_START, attr_id_t
|
||||||
from ..parts_of_speech cimport CCONJ, PUNCT, NOUN, univ_pos_t
|
from ..parts_of_speech cimport CCONJ, PUNCT, NOUN, univ_pos_t
|
||||||
|
|
||||||
from ..attrs import intify_attrs, IDS
|
from ..attrs import intify_attrs, IDS
|
||||||
|
@ -69,6 +69,8 @@ cdef attr_t get_token_attr(const TokenC* token, attr_id_t feat_name) nogil:
|
||||||
return token.ent_iob
|
return token.ent_iob
|
||||||
elif feat_name == ENT_TYPE:
|
elif feat_name == ENT_TYPE:
|
||||||
return token.ent_type
|
return token.ent_type
|
||||||
|
elif feat_name == ENT_ID:
|
||||||
|
return token.ent_id
|
||||||
elif feat_name == ENT_KB_ID:
|
elif feat_name == ENT_KB_ID:
|
||||||
return token.ent_kb_id
|
return token.ent_kb_id
|
||||||
else:
|
else:
|
||||||
|
@ -868,7 +870,7 @@ cdef class Doc:
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/doc#to_bytes
|
DOCS: https://spacy.io/api/doc#to_bytes
|
||||||
"""
|
"""
|
||||||
array_head = [LENGTH, SPACY, LEMMA, ENT_IOB, ENT_TYPE] # TODO: ENT_KB_ID ?
|
array_head = [LENGTH, SPACY, LEMMA, ENT_IOB, ENT_TYPE, ENT_ID] # TODO: ENT_KB_ID ?
|
||||||
if self.is_tagged:
|
if self.is_tagged:
|
||||||
array_head.extend([TAG, POS])
|
array_head.extend([TAG, POS])
|
||||||
# If doc parsed add head and dep attribute
|
# If doc parsed add head and dep attribute
|
||||||
|
|
|
@ -212,7 +212,7 @@ cdef class Span:
|
||||||
words = [t.text for t in self]
|
words = [t.text for t in self]
|
||||||
spaces = [bool(t.whitespace_) for t in self]
|
spaces = [bool(t.whitespace_) for t in self]
|
||||||
cdef Doc doc = Doc(self.doc.vocab, words=words, spaces=spaces)
|
cdef Doc doc = Doc(self.doc.vocab, words=words, spaces=spaces)
|
||||||
array_head = [LENGTH, SPACY, LEMMA, ENT_IOB, ENT_TYPE, ENT_KB_ID]
|
array_head = [LENGTH, SPACY, LEMMA, ENT_IOB, ENT_TYPE, ENT_ID, ENT_KB_ID]
|
||||||
if self.doc.is_tagged:
|
if self.doc.is_tagged:
|
||||||
array_head.append(TAG)
|
array_head.append(TAG)
|
||||||
# If doc parsed add head and dep attribute
|
# If doc parsed add head and dep attribute
|
||||||
|
|
|
@ -53,6 +53,8 @@ cdef class Token:
|
||||||
return token.ent_iob
|
return token.ent_iob
|
||||||
elif feat_name == ENT_TYPE:
|
elif feat_name == ENT_TYPE:
|
||||||
return token.ent_type
|
return token.ent_type
|
||||||
|
elif feat_name == ENT_ID:
|
||||||
|
return token.ent_id
|
||||||
elif feat_name == ENT_KB_ID:
|
elif feat_name == ENT_KB_ID:
|
||||||
return token.ent_kb_id
|
return token.ent_kb_id
|
||||||
elif feat_name == SENT_START:
|
elif feat_name == SENT_START:
|
||||||
|
@ -81,6 +83,8 @@ cdef class Token:
|
||||||
token.ent_iob = value
|
token.ent_iob = value
|
||||||
elif feat_name == ENT_TYPE:
|
elif feat_name == ENT_TYPE:
|
||||||
token.ent_type = value
|
token.ent_type = value
|
||||||
|
elif feat_name == ENT_ID:
|
||||||
|
token.ent_id = value
|
||||||
elif feat_name == ENT_KB_ID:
|
elif feat_name == ENT_KB_ID:
|
||||||
token.ent_kb_id = value
|
token.ent_kb_id = value
|
||||||
elif feat_name == SENT_START:
|
elif feat_name == SENT_START:
|
||||||
|
|
|
@ -9,7 +9,7 @@ menu:
|
||||||
---
|
---
|
||||||
|
|
||||||
Compared to using regular expressions on raw text, spaCy's rule-based matcher
|
Compared to using regular expressions on raw text, spaCy's rule-based matcher
|
||||||
engines and components not only let you find you the words and phrases you're
|
engines and components not only let you find the words and phrases you're
|
||||||
looking for – they also give you access to the tokens within the document and
|
looking for – they also give you access to the tokens within the document and
|
||||||
their relationships. This means you can easily access and analyze the
|
their relationships. This means you can easily access and analyze the
|
||||||
surrounding tokens, merge spans into single tokens or add entries to the named
|
surrounding tokens, merge spans into single tokens or add entries to the named
|
||||||
|
|
|
@ -229,10 +229,10 @@ For more details on **adding hooks** and **overwriting** the built-in `Doc`,
|
||||||
If you're using a GPU, it's much more efficient to keep the word vectors on the
|
If you're using a GPU, it's much more efficient to keep the word vectors on the
|
||||||
device. You can do that by setting the [`Vectors.data`](/api/vectors#attributes)
|
device. You can do that by setting the [`Vectors.data`](/api/vectors#attributes)
|
||||||
attribute to a `cupy.ndarray` object if you're using spaCy or
|
attribute to a `cupy.ndarray` object if you're using spaCy or
|
||||||
[Chainer]("https://chainer.org"), or a `torch.Tensor` object if you're using
|
[Chainer](https://chainer.org), or a `torch.Tensor` object if you're using
|
||||||
[PyTorch]("http://pytorch.org"). The `data` object just needs to support
|
[PyTorch](http://pytorch.org). The `data` object just needs to support
|
||||||
`__iter__` and `__getitem__`, so if you're using another library such as
|
`__iter__` and `__getitem__`, so if you're using another library such as
|
||||||
[TensorFlow]("https://www.tensorflow.org"), you could also create a wrapper for
|
[TensorFlow](https://www.tensorflow.org), you could also create a wrapper for
|
||||||
your vectors data.
|
your vectors data.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
|
|
|
@ -1509,28 +1509,30 @@
|
||||||
{
|
{
|
||||||
"id": "spacy-conll",
|
"id": "spacy-conll",
|
||||||
"title": "spacy_conll",
|
"title": "spacy_conll",
|
||||||
"slogan": "Parse text with spaCy and print the output in CoNLL-U format",
|
"slogan": "Parse text with spaCy and gets its output in CoNLL-U format",
|
||||||
"description": "This module allows you to parse a text to CoNLL-U format. You can use it as a command line tool, or embed it in your own scripts.",
|
"description": "This module allows you to parse a text to CoNLL-U format. It contains a pipeline component for spaCy that adds CoNLL-U properties to a Doc and its sentences. It can also be used as a command-line tool.",
|
||||||
"code_example": [
|
"code_example": [
|
||||||
"from spacy_conll import Spacy2ConllParser",
|
"import spacy",
|
||||||
"spacyconll = Spacy2ConllParser()",
|
"from spacy_conll import ConllFormatter",
|
||||||
"",
|
"",
|
||||||
"# `parse` returns a generator of the parsed sentences",
|
"nlp = spacy.load('en')",
|
||||||
"for parsed_sent in spacyconll.parse(input_str='I like cookies.\nWhat about you?\nI don't like 'em!'):",
|
"conllformatter = ConllFormatter(nlp)",
|
||||||
" do_something_(parsed_sent)",
|
"nlp.add_pipe(conllformatter, after='parser')",
|
||||||
"",
|
"doc = nlp('I like cookies. Do you?')",
|
||||||
"# `parseprint` prints output to stdout (default) or a file (use `output_file` parameter)",
|
"conll = doc._.conll",
|
||||||
"# This method is called when using the command line",
|
"print(doc._.conll_str_headers)",
|
||||||
"spacyconll.parseprint(input_str='I like cookies.')"
|
"print(doc._.conll_str)"
|
||||||
],
|
],
|
||||||
"code_language": "python",
|
"code_language": "python",
|
||||||
"author": "Bram Vanroy",
|
"author": "Bram Vanroy",
|
||||||
"author_links": {
|
"author_links": {
|
||||||
"github": "BramVanroy",
|
"github": "BramVanroy",
|
||||||
|
"twitter": "BramVanroy",
|
||||||
"website": "https://bramvanroy.be"
|
"website": "https://bramvanroy.be"
|
||||||
},
|
},
|
||||||
"github": "BramVanroy/spacy_conll",
|
"github": "BramVanroy/spacy_conll",
|
||||||
"category": ["standalone"]
|
"category": ["standalone", "pipeline"],
|
||||||
|
"tags": ["linguistics", "computational linguistics", "conll"]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": "spacy-langdetect",
|
"id": "spacy-langdetect",
|
||||||
|
|
Loading…
Reference in New Issue
Block a user