mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-03 22:06:37 +03:00
Merge branch 'master' into spacy.io
This commit is contained in:
commit
fc88337cfa
|
@ -7,16 +7,17 @@ Run `wikipedia_pretrain_kb.py`
|
|||
* WikiData: get `latest-all.json.bz2` from https://dumps.wikimedia.org/wikidatawiki/entities/
|
||||
* Wikipedia: get `enwiki-latest-pages-articles-multistream.xml.bz2` from https://dumps.wikimedia.org/enwiki/latest/ (or for any other language)
|
||||
* You can set the filtering parameters for KB construction:
|
||||
* `max_per_alias`: (max) number of candidate entities in the KB per alias/synonym
|
||||
* `min_freq`: threshold of number of times an entity should occur in the corpus to be included in the KB
|
||||
* `min_pair`: threshold of number of times an entity+alias combination should occur in the corpus to be included in the KB
|
||||
* `max_per_alias` (`-a`): (max) number of candidate entities in the KB per alias/synonym
|
||||
* `min_freq` (`-f`): threshold of number of times an entity should occur in the corpus to be included in the KB
|
||||
* `min_pair` (`-c`): threshold of number of times an entity+alias combination should occur in the corpus to be included in the KB
|
||||
* Further parameters to set:
|
||||
* `descriptions_from_wikipedia`: whether to parse descriptions from Wikipedia (`True`) or Wikidata (`False`)
|
||||
* `entity_vector_length`: length of the pre-trained entity description vectors
|
||||
* `lang`: language for which to fetch Wikidata information (as the dump contains all languages)
|
||||
* `descriptions_from_wikipedia` (`-wp`): whether to parse descriptions from Wikipedia (`True`) or Wikidata (`False`)
|
||||
* `entity_vector_length` (`-v`): length of the pre-trained entity description vectors
|
||||
* `lang` (`-la`): language for which to fetch Wikidata information (as the dump contains all languages)
|
||||
|
||||
Quick testing and rerunning:
|
||||
* When trying out the pipeline for a quick test, set `limit_prior`, `limit_train` and/or `limit_wd` to read only parts of the dumps instead of everything.
|
||||
* When trying out the pipeline for a quick test, set `limit_prior` (`-lp`), `limit_train` (`-lt`) and/or `limit_wd` (`-lw`) to read only parts of the dumps instead of everything.
|
||||
* e.g. set `-lt 20000 -lp 2000 -lw 3000 -f 1`
|
||||
* If you only want to (re)run certain parts of the pipeline, just remove the corresponding files and they will be recalculated or reparsed.
|
||||
|
||||
|
||||
|
@ -24,11 +25,13 @@ Quick testing and rerunning:
|
|||
|
||||
Run `wikidata_train_entity_linker.py`
|
||||
* This takes the **KB directory** produced by Step 1, and trains an **Entity Linking model**
|
||||
* Specify the output directory (`-o`) in which the final, trained model will be saved
|
||||
* You can set the learning parameters for the EL training:
|
||||
* `epochs`: number of training iterations
|
||||
* `dropout`: dropout rate
|
||||
* `lr`: learning rate
|
||||
* `l2`: L2 regularization
|
||||
* Specify the number of training and dev testing entities with `train_inst` and `dev_inst` respectively
|
||||
* `epochs` (`-e`): number of training iterations
|
||||
* `dropout` (`-p`): dropout rate
|
||||
* `lr` (`-n`): learning rate
|
||||
* `l2` (`-r`): L2 regularization
|
||||
* Specify the number of training and dev testing articles with `train_articles` (`-t`) and `dev_articles` (`-d`) respectively
|
||||
* If not specified, the full dataset will be processed - this may take a LONG time !
|
||||
* Further parameters to set:
|
||||
* `labels_discard`: NER label types to discard during training
|
||||
* `labels_discard` (`-l`): NER label types to discard during training
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
# coding: utf-8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import logging
|
||||
import random
|
||||
|
||||
from tqdm import tqdm
|
||||
from collections import defaultdict
|
||||
|
||||
|
@ -92,102 +94,81 @@ class BaselineResults(object):
|
|||
self.random.update_metrics(ent_label, true_entity, random_candidate)
|
||||
|
||||
|
||||
def measure_performance(dev_data, kb, el_pipe, baseline=True, context=True):
|
||||
def measure_performance(dev_data, kb, el_pipe, baseline=True, context=True, dev_limit=None):
|
||||
counts = dict()
|
||||
baseline_results = BaselineResults()
|
||||
context_results = EvaluationResults()
|
||||
combo_results = EvaluationResults()
|
||||
|
||||
for doc, gold in tqdm(dev_data, total=dev_limit, leave=False, desc='Processing dev data'):
|
||||
if len(doc) > 0:
|
||||
correct_ents = dict()
|
||||
for entity, kb_dict in gold.links.items():
|
||||
start, end = entity
|
||||
for gold_kb, value in kb_dict.items():
|
||||
if value:
|
||||
# only evaluating on positive examples
|
||||
offset = _offset(start, end)
|
||||
correct_ents[offset] = gold_kb
|
||||
|
||||
if baseline:
|
||||
baseline_accuracies, counts = measure_baselines(dev_data, kb)
|
||||
logger.info("Counts: {}".format({k: v for k, v in sorted(counts.items())}))
|
||||
logger.info(baseline_accuracies.report_performance("random"))
|
||||
logger.info(baseline_accuracies.report_performance("prior"))
|
||||
logger.info(baseline_accuracies.report_performance("oracle"))
|
||||
_add_baseline(baseline_results, counts, doc, correct_ents, kb)
|
||||
|
||||
if context:
|
||||
# using only context
|
||||
el_pipe.cfg["incl_context"] = True
|
||||
el_pipe.cfg["incl_prior"] = False
|
||||
results = get_eval_results(dev_data, el_pipe)
|
||||
logger.info(results.report_metrics("context only"))
|
||||
_add_eval_result(context_results, doc, correct_ents, el_pipe)
|
||||
|
||||
# measuring combined accuracy (prior + context)
|
||||
el_pipe.cfg["incl_context"] = True
|
||||
el_pipe.cfg["incl_prior"] = True
|
||||
results = get_eval_results(dev_data, el_pipe)
|
||||
logger.info(results.report_metrics("context and prior"))
|
||||
_add_eval_result(combo_results, doc, correct_ents, el_pipe)
|
||||
|
||||
if baseline:
|
||||
logger.info("Counts: {}".format({k: v for k, v in sorted(counts.items())}))
|
||||
logger.info(baseline_results.report_performance("random"))
|
||||
logger.info(baseline_results.report_performance("prior"))
|
||||
logger.info(baseline_results.report_performance("oracle"))
|
||||
|
||||
if context:
|
||||
logger.info(context_results.report_metrics("context only"))
|
||||
logger.info(combo_results.report_metrics("context and prior"))
|
||||
|
||||
|
||||
def get_eval_results(data, el_pipe=None):
|
||||
def _add_eval_result(results, doc, correct_ents, el_pipe):
|
||||
"""
|
||||
Evaluate the ent.kb_id_ annotations against the gold standard.
|
||||
Only evaluate entities that overlap between gold and NER, to isolate the performance of the NEL.
|
||||
If the docs in the data require further processing with an entity linker, set el_pipe.
|
||||
"""
|
||||
docs = []
|
||||
golds = []
|
||||
for d, g in tqdm(data, leave=False):
|
||||
if len(d) > 0:
|
||||
golds.append(g)
|
||||
if el_pipe is not None:
|
||||
docs.append(el_pipe(d))
|
||||
else:
|
||||
docs.append(d)
|
||||
|
||||
results = EvaluationResults()
|
||||
for doc, gold in zip(docs, golds):
|
||||
try:
|
||||
correct_entries_per_article = dict()
|
||||
for entity, kb_dict in gold.links.items():
|
||||
start, end = entity
|
||||
for gold_kb, value in kb_dict.items():
|
||||
if value:
|
||||
# only evaluating on positive examples
|
||||
offset = _offset(start, end)
|
||||
correct_entries_per_article[offset] = gold_kb
|
||||
|
||||
doc = el_pipe(doc)
|
||||
for ent in doc.ents:
|
||||
ent_label = ent.label_
|
||||
pred_entity = ent.kb_id_
|
||||
start = ent.start_char
|
||||
end = ent.end_char
|
||||
offset = _offset(start, end)
|
||||
gold_entity = correct_entries_per_article.get(offset, None)
|
||||
gold_entity = correct_ents.get(offset, None)
|
||||
# the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
|
||||
if gold_entity is not None:
|
||||
pred_entity = ent.kb_id_
|
||||
results.update_metrics(ent_label, gold_entity, pred_entity)
|
||||
|
||||
except Exception as e:
|
||||
logging.error("Error assessing accuracy " + str(e))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def measure_baselines(data, kb):
|
||||
def _add_baseline(baseline_results, counts, doc, correct_ents, kb):
|
||||
"""
|
||||
Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound.
|
||||
Only evaluate entities that overlap between gold and NER, to isolate the performance of the NEL.
|
||||
Also return a dictionary of counts by entity label.
|
||||
"""
|
||||
counts_d = dict()
|
||||
|
||||
baseline_results = BaselineResults()
|
||||
|
||||
docs = [d for d, g in data if len(d) > 0]
|
||||
golds = [g for d, g in data if len(d) > 0]
|
||||
|
||||
for doc, gold in zip(docs, golds):
|
||||
correct_entries_per_article = dict()
|
||||
for entity, kb_dict in gold.links.items():
|
||||
start, end = entity
|
||||
for gold_kb, value in kb_dict.items():
|
||||
# only evaluating on positive examples
|
||||
if value:
|
||||
offset = _offset(start, end)
|
||||
correct_entries_per_article[offset] = gold_kb
|
||||
|
||||
for ent in doc.ents:
|
||||
ent_label = ent.label_
|
||||
start = ent.start_char
|
||||
end = ent.end_char
|
||||
offset = _offset(start, end)
|
||||
gold_entity = correct_entries_per_article.get(offset, None)
|
||||
gold_entity = correct_ents.get(offset, None)
|
||||
|
||||
# the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
|
||||
if gold_entity is not None:
|
||||
|
@ -207,8 +188,8 @@ def measure_baselines(data, kb):
|
|||
prior_candidate = candidates[best_index].entity_
|
||||
random_candidate = random.choice(candidates).entity_
|
||||
|
||||
current_count = counts_d.get(ent_label, 0)
|
||||
counts_d[ent_label] = current_count+1
|
||||
current_count = counts.get(ent_label, 0)
|
||||
counts[ent_label] = current_count+1
|
||||
|
||||
baseline_results.update_baselines(
|
||||
gold_entity,
|
||||
|
@ -218,8 +199,6 @@ def measure_baselines(data, kb):
|
|||
oracle_candidate,
|
||||
)
|
||||
|
||||
return baseline_results, counts_d
|
||||
|
||||
|
||||
def _offset(start, end):
|
||||
return "{}_{}".format(start, end)
|
||||
|
|
|
@ -40,7 +40,7 @@ logger = logging.getLogger(__name__)
|
|||
loc_prior_prob=("Location to file with prior probabilities", "option", "p", Path),
|
||||
loc_entity_defs=("Location to file with entity definitions", "option", "d", Path),
|
||||
loc_entity_desc=("Location to file with entity descriptions", "option", "s", Path),
|
||||
descr_from_wp=("Flag for using wp descriptions not wd", "flag", "wp"),
|
||||
descr_from_wp=("Flag for using descriptions from WP instead of WD (default False)", "flag", "wp"),
|
||||
limit_prior=("Threshold to limit lines read from WP for prior probabilities", "option", "lp", int),
|
||||
limit_train=("Threshold to limit lines read from WP for training set", "option", "lt", int),
|
||||
limit_wd=("Threshold to limit lines read from WD", "option", "lw", int),
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
# coding: utf-8
|
||||
"""Script to take a previously created Knowledge Base and train an entity linking
|
||||
"""Script that takes a previously created Knowledge Base and trains an entity linking
|
||||
pipeline. The provided KB directory should hold the kb, the original nlp object and
|
||||
its vocab used to create the KB, and a few auxiliary files such as the entity definitions,
|
||||
as created by the script `wikidata_create_kb`.
|
||||
|
@ -14,6 +14,7 @@ import logging
|
|||
import spacy
|
||||
from pathlib import Path
|
||||
import plac
|
||||
from tqdm import tqdm
|
||||
|
||||
from bin.wiki_entity_linking import wikipedia_processor
|
||||
from bin.wiki_entity_linking import TRAINING_DATA_FILE, KB_MODEL_DIR, KB_FILE, LOG_FORMAT, OUTPUT_MODEL_DIR
|
||||
|
@ -33,8 +34,8 @@ logger = logging.getLogger(__name__)
|
|||
dropout=("Dropout to prevent overfitting (default 0.5)", "option", "p", float),
|
||||
lr=("Learning rate (default 0.005)", "option", "n", float),
|
||||
l2=("L2 regularization", "option", "r", float),
|
||||
train_inst=("# training instances (default 90% of all)", "option", "t", int),
|
||||
dev_inst=("# test instances (default 10% of all)", "option", "d", int),
|
||||
train_articles=("# training articles (default 90% of all)", "option", "t", int),
|
||||
dev_articles=("# dev test articles (default 10% of all)", "option", "d", int),
|
||||
labels_discard=("NER labels to discard (default None)", "option", "l", str),
|
||||
)
|
||||
def main(
|
||||
|
@ -45,10 +46,13 @@ def main(
|
|||
dropout=0.5,
|
||||
lr=0.005,
|
||||
l2=1e-6,
|
||||
train_inst=None,
|
||||
dev_inst=None,
|
||||
train_articles=None,
|
||||
dev_articles=None,
|
||||
labels_discard=None
|
||||
):
|
||||
if not output_dir:
|
||||
logger.warning("No output dir specified so no results will be written, are you sure about this ?")
|
||||
|
||||
logger.info("Creating Entity Linker with Wikipedia and WikiData")
|
||||
|
||||
output_dir = Path(output_dir) if output_dir else dir_kb
|
||||
|
@ -64,44 +68,33 @@ def main(
|
|||
# STEP 1 : load the NLP object
|
||||
logger.info("STEP 1a: Loading model from {}".format(nlp_dir))
|
||||
nlp = spacy.load(nlp_dir)
|
||||
logger.info("STEP 1b: Loading KB from {}".format(kb_path))
|
||||
kb = read_kb(nlp, kb_path)
|
||||
logger.info("Original NLP pipeline has following pipeline components: {}".format(nlp.pipe_names))
|
||||
|
||||
# check that there is a NER component in the pipeline
|
||||
if "ner" not in nlp.pipe_names:
|
||||
raise ValueError("The `nlp` object should have a pretrained `ner` component.")
|
||||
|
||||
# STEP 2: read the training dataset previously created from WP
|
||||
logger.info("STEP 2: Reading training dataset from {}".format(training_path))
|
||||
logger.info("STEP 1b: Loading KB from {}".format(kb_path))
|
||||
kb = read_kb(nlp, kb_path)
|
||||
|
||||
# STEP 2: read the training dataset previously created from WP
|
||||
logger.info("STEP 2: Reading training & dev dataset from {}".format(training_path))
|
||||
train_indices, dev_indices = wikipedia_processor.read_training_indices(training_path)
|
||||
logger.info("Training set has {} articles, limit set to roughly {} articles per epoch"
|
||||
.format(len(train_indices), train_articles if train_articles else "all"))
|
||||
logger.info("Dev set has {} articles, limit set to rougly {} articles for evaluation"
|
||||
.format(len(dev_indices), dev_articles if dev_articles else "all"))
|
||||
if dev_articles:
|
||||
dev_indices = dev_indices[0:dev_articles]
|
||||
|
||||
# STEP 3: create and train an entity linking pipe
|
||||
logger.info("STEP 3: Creating and training an Entity Linking pipe for {} epochs".format(epochs))
|
||||
if labels_discard:
|
||||
labels_discard = [x.strip() for x in labels_discard.split(",")]
|
||||
logger.info("Discarding {} NER types: {}".format(len(labels_discard), labels_discard))
|
||||
else:
|
||||
labels_discard = []
|
||||
|
||||
train_data = wikipedia_processor.read_training(
|
||||
nlp=nlp,
|
||||
entity_file_path=training_path,
|
||||
dev=False,
|
||||
limit=train_inst,
|
||||
kb=kb,
|
||||
labels_discard=labels_discard
|
||||
)
|
||||
|
||||
# for testing, get all pos instances (independently of KB)
|
||||
dev_data = wikipedia_processor.read_training(
|
||||
nlp=nlp,
|
||||
entity_file_path=training_path,
|
||||
dev=True,
|
||||
limit=dev_inst,
|
||||
kb=None,
|
||||
labels_discard=labels_discard
|
||||
)
|
||||
|
||||
# STEP 3: create and train an entity linking pipe
|
||||
logger.info("STEP 3: Creating and training an Entity Linking pipe")
|
||||
|
||||
el_pipe = nlp.create_pipe(
|
||||
name="entity_linker", config={"pretrained_vectors": nlp.vocab.vectors.name,
|
||||
"labels_discard": labels_discard}
|
||||
|
@ -115,23 +108,35 @@ def main(
|
|||
optimizer.learn_rate = lr
|
||||
optimizer.L2 = l2
|
||||
|
||||
logger.info("Training on {} articles".format(len(train_data)))
|
||||
logger.info("Dev testing on {} articles".format(len(dev_data)))
|
||||
|
||||
# baseline performance on dev data
|
||||
logger.info("Dev Baseline Accuracies:")
|
||||
measure_performance(dev_data, kb, el_pipe, baseline=True, context=False)
|
||||
dev_data = wikipedia_processor.read_el_docs_golds(nlp=nlp, entity_file_path=training_path,
|
||||
dev=True, line_ids=dev_indices,
|
||||
kb=kb, labels_discard=labels_discard)
|
||||
|
||||
measure_performance(dev_data, kb, el_pipe, baseline=True, context=False, dev_limit=len(dev_indices))
|
||||
|
||||
for itn in range(epochs):
|
||||
random.shuffle(train_data)
|
||||
random.shuffle(train_indices)
|
||||
losses = {}
|
||||
batches = minibatch(train_data, size=compounding(4.0, 128.0, 1.001))
|
||||
batches = minibatch(train_indices, size=compounding(8.0, 128.0, 1.001))
|
||||
batchnr = 0
|
||||
articles_processed = 0
|
||||
|
||||
with nlp.disable_pipes(*other_pipes):
|
||||
# we either process the whole training file, or just a part each epoch
|
||||
bar_total = len(train_indices)
|
||||
if train_articles:
|
||||
bar_total = train_articles
|
||||
|
||||
with tqdm(total=bar_total, leave=False, desc='Epoch ' + str(itn)) as pbar:
|
||||
for batch in batches:
|
||||
if not train_articles or articles_processed < train_articles:
|
||||
with nlp.disable_pipes("entity_linker"):
|
||||
train_batch = wikipedia_processor.read_el_docs_golds(nlp=nlp, entity_file_path=training_path,
|
||||
dev=False, line_ids=batch,
|
||||
kb=kb, labels_discard=labels_discard)
|
||||
docs, golds = zip(*train_batch)
|
||||
try:
|
||||
docs, golds = zip(*batch)
|
||||
with nlp.disable_pipes(*other_pipes):
|
||||
nlp.update(
|
||||
docs=docs,
|
||||
golds=golds,
|
||||
|
@ -140,55 +145,28 @@ def main(
|
|||
losses=losses,
|
||||
)
|
||||
batchnr += 1
|
||||
articles_processed += len(docs)
|
||||
pbar.update(len(docs))
|
||||
except Exception as e:
|
||||
logger.error("Error updating batch:" + str(e))
|
||||
if batchnr > 0:
|
||||
logging.info("Epoch {}, train loss {}".format(itn, round(losses["entity_linker"] / batchnr, 2)))
|
||||
measure_performance(dev_data, kb, el_pipe, baseline=False, context=True)
|
||||
|
||||
# STEP 4: measure the performance of our trained pipe on an independent dev set
|
||||
logger.info("STEP 4: Final performance measurement of Entity Linking pipe")
|
||||
measure_performance(dev_data, kb, el_pipe)
|
||||
|
||||
# STEP 5: apply the EL pipe on a toy example
|
||||
logger.info("STEP 5: Applying Entity Linking to toy example")
|
||||
run_el_toy_example(nlp=nlp)
|
||||
logging.info("Epoch {} trained on {} articles, train loss {}"
|
||||
.format(itn, articles_processed, round(losses["entity_linker"] / batchnr, 2)))
|
||||
# re-read the dev_data (data is returned as a generator)
|
||||
dev_data = wikipedia_processor.read_el_docs_golds(nlp=nlp, entity_file_path=training_path,
|
||||
dev=True, line_ids=dev_indices,
|
||||
kb=kb, labels_discard=labels_discard)
|
||||
measure_performance(dev_data, kb, el_pipe, baseline=False, context=True, dev_limit=len(dev_indices))
|
||||
|
||||
if output_dir:
|
||||
# STEP 6: write the NLP pipeline (now including an EL model) to file
|
||||
logger.info("STEP 6: Writing trained NLP to {}".format(nlp_output_dir))
|
||||
# STEP 4: write the NLP pipeline (now including an EL model) to file
|
||||
logger.info("Final NLP pipeline has following pipeline components: {}".format(nlp.pipe_names))
|
||||
logger.info("STEP 4: Writing trained NLP to {}".format(nlp_output_dir))
|
||||
nlp.to_disk(nlp_output_dir)
|
||||
|
||||
logger.info("Done!")
|
||||
|
||||
|
||||
def check_kb(kb):
|
||||
for mention in ("Bush", "Douglas Adams", "Homer", "Brazil", "China"):
|
||||
candidates = kb.get_candidates(mention)
|
||||
|
||||
logger.info("generating candidates for " + mention + " :")
|
||||
for c in candidates:
|
||||
logger.info(" ".join[
|
||||
str(c.prior_prob),
|
||||
c.alias_,
|
||||
"-->",
|
||||
c.entity_ + " (freq=" + str(c.entity_freq) + ")"
|
||||
])
|
||||
|
||||
|
||||
def run_el_toy_example(nlp):
|
||||
text = (
|
||||
"In The Hitchhiker's Guide to the Galaxy, written by Douglas Adams, "
|
||||
"Douglas reminds us to always bring our towel, even in China or Brazil. "
|
||||
"The main character in Doug's novel is the man Arthur Dent, "
|
||||
"but Dougledydoug doesn't write about George Washington or Homer Simpson."
|
||||
)
|
||||
doc = nlp(text)
|
||||
logger.info(text)
|
||||
for ent in doc.ents:
|
||||
logger.info(" ".join(["ent", ent.text, ent.label_, ent.kb_id_]))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO, format=LOG_FORMAT)
|
||||
plac.call(main)
|
||||
|
|
|
@ -6,9 +6,6 @@ import bz2
|
|||
import logging
|
||||
import random
|
||||
import json
|
||||
from tqdm import tqdm
|
||||
|
||||
from functools import partial
|
||||
|
||||
from spacy.gold import GoldParse
|
||||
from bin.wiki_entity_linking import wiki_io as io
|
||||
|
@ -454,25 +451,40 @@ def _write_training_entities(outputfile, article_id, clean_text, entities):
|
|||
outputfile.write(line)
|
||||
|
||||
|
||||
def read_training(nlp, entity_file_path, dev, limit, kb, labels_discard=None):
|
||||
""" This method provides training examples that correspond to the entity annotations found by the nlp object.
|
||||
def read_training_indices(entity_file_path):
|
||||
""" This method creates two lists of indices into the training file: one with indices for the
|
||||
training examples, and one for the dev examples."""
|
||||
train_indices = []
|
||||
dev_indices = []
|
||||
|
||||
with entity_file_path.open("r", encoding="utf8") as file:
|
||||
for i, line in enumerate(file):
|
||||
example = json.loads(line)
|
||||
article_id = example["article_id"]
|
||||
clean_text = example["clean_text"]
|
||||
|
||||
if is_valid_article(clean_text):
|
||||
if is_dev(article_id):
|
||||
dev_indices.append(i)
|
||||
else:
|
||||
train_indices.append(i)
|
||||
|
||||
return train_indices, dev_indices
|
||||
|
||||
|
||||
def read_el_docs_golds(nlp, entity_file_path, dev, line_ids, kb, labels_discard=None):
|
||||
""" This method provides training/dev examples that correspond to the entity annotations found by the nlp object.
|
||||
For training, it will include both positive and negative examples by using the candidate generator from the kb.
|
||||
For testing (kb=None), it will include all positive examples only."""
|
||||
if not labels_discard:
|
||||
labels_discard = []
|
||||
|
||||
data = []
|
||||
num_entities = 0
|
||||
get_gold_parse = partial(
|
||||
_get_gold_parse, dev=dev, kb=kb, labels_discard=labels_discard
|
||||
)
|
||||
texts = []
|
||||
entities_list = []
|
||||
|
||||
logger.info(
|
||||
"Reading {} data with limit {}".format("dev" if dev else "train", limit)
|
||||
)
|
||||
with entity_file_path.open("r", encoding="utf8") as file:
|
||||
with tqdm(total=limit, leave=False) as pbar:
|
||||
for i, line in enumerate(file):
|
||||
if i in line_ids:
|
||||
example = json.loads(line)
|
||||
article_id = example["article_id"]
|
||||
clean_text = example["clean_text"]
|
||||
|
@ -481,16 +493,15 @@ def read_training(nlp, entity_file_path, dev, limit, kb, labels_discard=None):
|
|||
if dev != is_dev(article_id) or not is_valid_article(clean_text):
|
||||
continue
|
||||
|
||||
doc = nlp(clean_text)
|
||||
gold = get_gold_parse(doc, entities)
|
||||
texts.append(clean_text)
|
||||
entities_list.append(entities)
|
||||
|
||||
docs = nlp.pipe(texts, batch_size=50)
|
||||
|
||||
for doc, entities in zip(docs, entities_list):
|
||||
gold = _get_gold_parse(doc, entities, dev=dev, kb=kb, labels_discard=labels_discard)
|
||||
if gold and len(gold.links) > 0:
|
||||
data.append((doc, gold))
|
||||
num_entities += len(gold.links)
|
||||
pbar.update(len(gold.links))
|
||||
if limit and num_entities >= limit:
|
||||
break
|
||||
logger.info("Read {} entities in {} articles".format(num_entities, len(data)))
|
||||
return data
|
||||
yield doc, gold
|
||||
|
||||
|
||||
def _get_gold_parse(doc, entities, dev, kb, labels_discard):
|
||||
|
|
|
@ -32,27 +32,24 @@ DESC_WIDTH = 64 # dimension of output entity vectors
|
|||
|
||||
|
||||
@plac.annotations(
|
||||
vocab_path=("Path to the vocab for the kb", "option", "v", Path),
|
||||
model=("Model name, should have pretrained word embeddings", "option", "m", str),
|
||||
model=("Model name, should have pretrained word embeddings", "positional", None, str),
|
||||
output_dir=("Optional output directory", "option", "o", Path),
|
||||
n_iter=("Number of training iterations", "option", "n", int),
|
||||
)
|
||||
def main(vocab_path=None, model=None, output_dir=None, n_iter=50):
|
||||
def main(model=None, output_dir=None, n_iter=50):
|
||||
"""Load the model, create the KB and pretrain the entity encodings.
|
||||
Either an nlp model or a vocab is needed to provide access to pretrained word embeddings.
|
||||
If an output_dir is provided, the KB will be stored there in a file 'kb'.
|
||||
When providing an nlp model, the updated vocab will also be written to a directory in the output_dir."""
|
||||
if model is None and vocab_path is None:
|
||||
raise ValueError("Either the `nlp` model or the `vocab` should be specified.")
|
||||
The updated vocab will also be written to a directory in the output_dir."""
|
||||
|
||||
if model is not None:
|
||||
nlp = spacy.load(model) # load existing spaCy model
|
||||
print("Loaded model '%s'" % model)
|
||||
else:
|
||||
vocab = Vocab().from_disk(vocab_path)
|
||||
# create blank Language class with specified vocab
|
||||
nlp = spacy.blank("en", vocab=vocab)
|
||||
print("Created blank 'en' model with vocab from '%s'" % vocab_path)
|
||||
|
||||
# check the length of the nlp vectors
|
||||
if "vectors" not in nlp.meta or not nlp.vocab.vectors.size:
|
||||
raise ValueError(
|
||||
"The `nlp` object should have access to pretrained word vectors, "
|
||||
" cf. https://spacy.io/usage/models#languages."
|
||||
)
|
||||
|
||||
kb = KnowledgeBase(vocab=nlp.vocab)
|
||||
|
||||
|
@ -103,8 +100,6 @@ def main(vocab_path=None, model=None, output_dir=None, n_iter=50):
|
|||
print()
|
||||
print("Saved KB to", kb_path)
|
||||
|
||||
# only storing the vocab if we weren't already reading it from file
|
||||
if not vocab_path:
|
||||
vocab_path = output_dir / "vocab"
|
||||
kb.vocab.to_disk(vocab_path)
|
||||
print("Saved vocab to", vocab_path)
|
||||
|
|
|
@ -131,7 +131,8 @@ def train_textcat(nlp, n_texts, n_iter=10):
|
|||
train_data = list(zip(train_texts, [{"cats": cats} for cats in train_cats]))
|
||||
|
||||
# get names of other pipes to disable them during training
|
||||
other_pipes = [pipe for pipe in nlp.pipe_names if pipe != "textcat"]
|
||||
pipe_exceptions = ["textcat", "trf_wordpiecer", "trf_tok2vec"]
|
||||
other_pipes = [pipe for pipe in nlp.pipe_names if pipe not in pipe_exceptions]
|
||||
with nlp.disable_pipes(*other_pipes): # only train textcat
|
||||
optimizer = nlp.begin_training()
|
||||
textcat.model.tok2vec.from_bytes(tok2vec_weights)
|
||||
|
|
|
@ -63,7 +63,8 @@ def main(model_name, unlabelled_loc):
|
|||
optimizer.b2 = 0.0
|
||||
|
||||
# get names of other pipes to disable them during training
|
||||
other_pipes = [pipe for pipe in nlp.pipe_names if pipe != "ner"]
|
||||
pipe_exceptions = ["ner", "trf_wordpiecer", "trf_tok2vec"]
|
||||
other_pipes = [pipe for pipe in nlp.pipe_names if pipe not in pipe_exceptions]
|
||||
sizes = compounding(1.0, 4.0, 1.001)
|
||||
with nlp.disable_pipes(*other_pipes):
|
||||
for itn in range(n_iter):
|
||||
|
|
|
@ -113,7 +113,8 @@ def main(kb_path, vocab_path=None, output_dir=None, n_iter=50):
|
|||
TRAIN_DOCS.append((doc, annotation_clean))
|
||||
|
||||
# get names of other pipes to disable them during training
|
||||
other_pipes = [pipe for pipe in nlp.pipe_names if pipe != "entity_linker"]
|
||||
pipe_exceptions = ["entity_linker", "trf_wordpiecer", "trf_tok2vec"]
|
||||
other_pipes = [pipe for pipe in nlp.pipe_names if pipe not in pipe_exceptions]
|
||||
with nlp.disable_pipes(*other_pipes): # only train entity linker
|
||||
# reset and initialize the weights randomly
|
||||
optimizer = nlp.begin_training()
|
||||
|
|
|
@ -124,7 +124,8 @@ def main(model=None, output_dir=None, n_iter=15):
|
|||
for dep in annotations.get("deps", []):
|
||||
parser.add_label(dep)
|
||||
|
||||
other_pipes = [pipe for pipe in nlp.pipe_names if pipe != "parser"]
|
||||
pipe_exceptions = ["parser", "trf_wordpiecer", "trf_tok2vec"]
|
||||
other_pipes = [pipe for pipe in nlp.pipe_names if pipe not in pipe_exceptions]
|
||||
with nlp.disable_pipes(*other_pipes): # only train parser
|
||||
optimizer = nlp.begin_training()
|
||||
for itn in range(n_iter):
|
||||
|
|
|
@ -55,7 +55,8 @@ def main(model=None, output_dir=None, n_iter=100):
|
|||
ner.add_label(ent[2])
|
||||
|
||||
# get names of other pipes to disable them during training
|
||||
other_pipes = [pipe for pipe in nlp.pipe_names if pipe != "ner"]
|
||||
pipe_exceptions = ["ner", "trf_wordpiecer", "trf_tok2vec"]
|
||||
other_pipes = [pipe for pipe in nlp.pipe_names if pipe not in pipe_exceptions]
|
||||
with nlp.disable_pipes(*other_pipes): # only train NER
|
||||
# reset and initialize the weights randomly – but only if we're
|
||||
# training a new model
|
||||
|
|
|
@ -95,7 +95,8 @@ def main(model=None, new_model_name="animal", output_dir=None, n_iter=30):
|
|||
optimizer = nlp.resume_training()
|
||||
move_names = list(ner.move_names)
|
||||
# get names of other pipes to disable them during training
|
||||
other_pipes = [pipe for pipe in nlp.pipe_names if pipe != "ner"]
|
||||
pipe_exceptions = ["ner", "trf_wordpiecer", "trf_tok2vec"]
|
||||
other_pipes = [pipe for pipe in nlp.pipe_names if pipe not in pipe_exceptions]
|
||||
with nlp.disable_pipes(*other_pipes): # only train NER
|
||||
sizes = compounding(1.0, 4.0, 1.001)
|
||||
# batch up the examples using spaCy's minibatch
|
||||
|
|
|
@ -65,7 +65,8 @@ def main(model=None, output_dir=None, n_iter=15):
|
|||
parser.add_label(dep)
|
||||
|
||||
# get names of other pipes to disable them during training
|
||||
other_pipes = [pipe for pipe in nlp.pipe_names if pipe != "parser"]
|
||||
pipe_exceptions = ["parser", "trf_wordpiecer", "trf_tok2vec"]
|
||||
other_pipes = [pipe for pipe in nlp.pipe_names if pipe not in pipe_exceptions]
|
||||
with nlp.disable_pipes(*other_pipes): # only train parser
|
||||
optimizer = nlp.begin_training()
|
||||
for itn in range(n_iter):
|
||||
|
|
|
@ -67,7 +67,8 @@ def main(model=None, output_dir=None, n_iter=20, n_texts=2000, init_tok2vec=None
|
|||
train_data = list(zip(train_texts, [{"cats": cats} for cats in train_cats]))
|
||||
|
||||
# get names of other pipes to disable them during training
|
||||
other_pipes = [pipe for pipe in nlp.pipe_names if pipe != "textcat"]
|
||||
pipe_exceptions = ["textcat", "trf_wordpiecer", "trf_tok2vec"]
|
||||
other_pipes = [pipe for pipe in nlp.pipe_names if pipe not in pipe_exceptions]
|
||||
with nlp.disable_pipes(*other_pipes): # only train textcat
|
||||
optimizer = nlp.begin_training()
|
||||
if init_tok2vec is not None:
|
||||
|
|
|
@ -91,3 +91,4 @@ cdef enum attr_id_t:
|
|||
|
||||
LANG
|
||||
ENT_KB_ID = symbols.ENT_KB_ID
|
||||
ENT_ID = symbols.ENT_ID
|
||||
|
|
|
@ -84,6 +84,7 @@ IDS = {
|
|||
"DEP": DEP,
|
||||
"ENT_IOB": ENT_IOB,
|
||||
"ENT_TYPE": ENT_TYPE,
|
||||
"ENT_ID": ENT_ID,
|
||||
"ENT_KB_ID": ENT_KB_ID,
|
||||
"HEAD": HEAD,
|
||||
"SENT_START": SENT_START,
|
||||
|
|
|
@ -192,6 +192,7 @@ def debug_data(
|
|||
has_low_data_warning = False
|
||||
has_no_neg_warning = False
|
||||
has_ws_ents_error = False
|
||||
has_punct_ents_warning = False
|
||||
|
||||
msg.divider("Named Entity Recognition")
|
||||
msg.info(
|
||||
|
@ -226,10 +227,16 @@ def debug_data(
|
|||
|
||||
if gold_train_data["ws_ents"]:
|
||||
msg.fail(
|
||||
"{} invalid whitespace entity spans".format(gold_train_data["ws_ents"])
|
||||
"{} invalid whitespace entity span(s)".format(gold_train_data["ws_ents"])
|
||||
)
|
||||
has_ws_ents_error = True
|
||||
|
||||
if gold_train_data["punct_ents"]:
|
||||
msg.warn(
|
||||
"{} entity span(s) with punctuation".format(gold_train_data["punct_ents"])
|
||||
)
|
||||
has_punct_ents_warning = True
|
||||
|
||||
for label in new_labels:
|
||||
if label_counts[label] <= NEW_LABEL_THRESHOLD:
|
||||
msg.warn(
|
||||
|
@ -253,6 +260,8 @@ def debug_data(
|
|||
msg.good("Examples without occurrences available for all labels")
|
||||
if not has_ws_ents_error:
|
||||
msg.good("No entities consisting of or starting/ending with whitespace")
|
||||
if not has_punct_ents_warning:
|
||||
msg.good("No entities consisting of or starting/ending with punctuation")
|
||||
|
||||
if has_low_data_warning:
|
||||
msg.text(
|
||||
|
@ -273,6 +282,12 @@ def debug_data(
|
|||
"with whitespace characters are considered invalid."
|
||||
)
|
||||
|
||||
if has_punct_ents_warning:
|
||||
msg.text(
|
||||
"Entity spans consisting of or starting/ending "
|
||||
"with punctuation can not be trained with a noise level > 0."
|
||||
)
|
||||
|
||||
if "textcat" in pipeline:
|
||||
msg.divider("Text Classification")
|
||||
labels = [label for label in gold_train_data["cats"]]
|
||||
|
@ -547,6 +562,7 @@ def _compile_gold(train_docs, pipeline):
|
|||
"words": Counter(),
|
||||
"roots": Counter(),
|
||||
"ws_ents": 0,
|
||||
"punct_ents": 0,
|
||||
"n_words": 0,
|
||||
"n_misaligned_words": 0,
|
||||
"n_sents": 0,
|
||||
|
@ -568,6 +584,10 @@ def _compile_gold(train_docs, pipeline):
|
|||
if label.startswith(("B-", "U-", "L-")) and doc[i].is_space:
|
||||
# "Illegal" whitespace entity
|
||||
data["ws_ents"] += 1
|
||||
if label.startswith(("B-", "U-", "L-")) and doc[i].text in [".", "'", "!", "?", ","]:
|
||||
# punctuation entity: could be replaced by whitespace when training with noise,
|
||||
# so add a warning to alert the user to this unexpected side effect.
|
||||
data["punct_ents"] += 1
|
||||
if label.startswith(("B-", "U-")):
|
||||
combined_label = label.split("-")[1]
|
||||
data["ner"][combined_label] += 1
|
||||
|
|
|
@ -30,6 +30,7 @@ from .. import about
|
|||
raw_text=("Path to jsonl file with unlabelled text documents.", "option", "rt", Path),
|
||||
base_model=("Name of model to update (optional)", "option", "b", str),
|
||||
pipeline=("Comma-separated names of pipeline components", "option", "p", str),
|
||||
replace_components=("Replace components from base model", "flag", "R", bool),
|
||||
vectors=("Model to load vectors from", "option", "v", str),
|
||||
n_iter=("Number of iterations", "option", "n", int),
|
||||
n_early_stopping=("Maximum number of training epochs without dev accuracy improvement", "option", "ne", int),
|
||||
|
@ -60,6 +61,7 @@ def train(
|
|||
raw_text=None,
|
||||
base_model=None,
|
||||
pipeline="tagger,parser,ner",
|
||||
replace_components=False,
|
||||
vectors=None,
|
||||
n_iter=30,
|
||||
n_early_stopping=None,
|
||||
|
@ -142,6 +144,8 @@ def train(
|
|||
# the model and make sure the pipeline matches the pipeline setting. If
|
||||
# training starts from a blank model, intitalize the language class.
|
||||
pipeline = [p.strip() for p in pipeline.split(",")]
|
||||
disabled_pipes = None
|
||||
pipes_added = False
|
||||
msg.text("Training pipeline: {}".format(pipeline))
|
||||
if base_model:
|
||||
msg.text("Starting with base model '{}'".format(base_model))
|
||||
|
@ -152,9 +156,8 @@ def train(
|
|||
"`lang` argument ('{}') ".format(nlp.lang, lang),
|
||||
exits=1,
|
||||
)
|
||||
nlp.disable_pipes([p for p in nlp.pipe_names if p not in pipeline])
|
||||
for pipe in pipeline:
|
||||
if pipe not in nlp.pipe_names:
|
||||
pipe_cfg = {}
|
||||
if pipe == "parser":
|
||||
pipe_cfg = {"learn_tokens": learn_tokens}
|
||||
elif pipe == "textcat":
|
||||
|
@ -163,9 +166,14 @@ def train(
|
|||
"architecture": textcat_arch,
|
||||
"positive_label": textcat_positive_label,
|
||||
}
|
||||
else:
|
||||
pipe_cfg = {}
|
||||
if pipe not in nlp.pipe_names:
|
||||
msg.text("Adding component to base model '{}'".format(pipe))
|
||||
nlp.add_pipe(nlp.create_pipe(pipe, config=pipe_cfg))
|
||||
pipes_added = True
|
||||
elif replace_components:
|
||||
msg.text("Replacing component from base model '{}'".format(pipe))
|
||||
nlp.replace_pipe(pipe, nlp.create_pipe(pipe, config=pipe_cfg))
|
||||
pipes_added = True
|
||||
else:
|
||||
if pipe == "textcat":
|
||||
textcat_cfg = nlp.get_pipe("textcat").cfg
|
||||
|
@ -174,11 +182,6 @@ def train(
|
|||
"architecture": textcat_cfg["architecture"],
|
||||
"positive_label": textcat_cfg["positive_label"],
|
||||
}
|
||||
pipe_cfg = {
|
||||
"exclusive_classes": not textcat_multilabel,
|
||||
"architecture": textcat_arch,
|
||||
"positive_label": textcat_positive_label,
|
||||
}
|
||||
if base_cfg != pipe_cfg:
|
||||
msg.fail(
|
||||
"The base textcat model configuration does"
|
||||
|
@ -188,6 +191,8 @@ def train(
|
|||
),
|
||||
exits=1,
|
||||
)
|
||||
msg.text("Extending component from base model '{}'".format(pipe))
|
||||
disabled_pipes = nlp.disable_pipes([p for p in nlp.pipe_names if p not in pipeline])
|
||||
else:
|
||||
msg.text("Starting with blank model '{}'".format(lang))
|
||||
lang_cls = util.get_lang_class(lang)
|
||||
|
@ -227,7 +232,7 @@ def train(
|
|||
corpus = GoldCorpus(train_path, dev_path, limit=n_examples)
|
||||
n_train_words = corpus.count_train()
|
||||
|
||||
if base_model:
|
||||
if base_model and not pipes_added:
|
||||
# Start with an existing model, use default optimizer
|
||||
optimizer = create_default_optimizer(Model.ops)
|
||||
else:
|
||||
|
@ -243,7 +248,7 @@ def train(
|
|||
|
||||
# Verify textcat config
|
||||
if "textcat" in pipeline:
|
||||
textcat_labels = nlp.get_pipe("textcat").cfg["labels"]
|
||||
textcat_labels = nlp.get_pipe("textcat").cfg.get("labels", [])
|
||||
if textcat_positive_label and textcat_positive_label not in textcat_labels:
|
||||
msg.fail(
|
||||
"The textcat_positive_label (tpl) '{}' does not match any "
|
||||
|
@ -426,11 +431,16 @@ def train(
|
|||
"cpu": cpu_wps,
|
||||
"gpu": gpu_wps,
|
||||
}
|
||||
meta["accuracy"] = scorer.scores
|
||||
meta.setdefault("accuracy", {})
|
||||
for component in nlp.pipe_names:
|
||||
for metric in _get_metrics(component):
|
||||
meta["accuracy"][metric] = scorer.scores[metric]
|
||||
else:
|
||||
meta.setdefault("beam_accuracy", {})
|
||||
meta.setdefault("beam_speed", {})
|
||||
meta["beam_accuracy"][beam_width] = scorer.scores
|
||||
for component in nlp.pipe_names:
|
||||
for metric in _get_metrics(component):
|
||||
meta["beam_accuracy"][metric] = scorer.scores[metric]
|
||||
meta["beam_speed"][beam_width] = {
|
||||
"nwords": nwords,
|
||||
"cpu": cpu_wps,
|
||||
|
@ -486,12 +496,16 @@ def train(
|
|||
)
|
||||
break
|
||||
finally:
|
||||
best_pipes = nlp.pipe_names
|
||||
if disabled_pipes:
|
||||
disabled_pipes.restore()
|
||||
with nlp.use_params(optimizer.averages):
|
||||
final_model_path = output_path / "model-final"
|
||||
nlp.to_disk(final_model_path)
|
||||
final_meta = srsly.read_json(output_path / "model-final" / "meta.json")
|
||||
msg.good("Saved model to output directory", final_model_path)
|
||||
with msg.loading("Creating best model..."):
|
||||
best_model_path = _collate_best_model(meta, output_path, nlp.pipe_names)
|
||||
best_model_path = _collate_best_model(final_meta, output_path, best_pipes)
|
||||
msg.good("Created best model", best_model_path)
|
||||
|
||||
|
||||
|
@ -549,6 +563,7 @@ def _load_pretrained_tok2vec(nlp, loc):
|
|||
|
||||
def _collate_best_model(meta, output_path, components):
|
||||
bests = {}
|
||||
meta.setdefault("accuracy", {})
|
||||
for component in components:
|
||||
bests[component] = _find_best(output_path, component)
|
||||
best_dest = output_path / "model-best"
|
||||
|
@ -580,11 +595,13 @@ def _find_best(experiment_dir, component):
|
|||
|
||||
def _get_metrics(component):
|
||||
if component == "parser":
|
||||
return ("las", "uas", "token_acc")
|
||||
return ("las", "uas", "las_per_type", "token_acc")
|
||||
elif component == "tagger":
|
||||
return ("tags_acc",)
|
||||
elif component == "ner":
|
||||
return ("ents_f", "ents_p", "ents_r")
|
||||
return ("ents_f", "ents_p", "ents_r", "ents_per_type")
|
||||
elif component == "textcat":
|
||||
return ("textcat_score",)
|
||||
return ("token_acc",)
|
||||
|
||||
|
||||
|
|
|
@ -172,7 +172,8 @@ class Errors(object):
|
|||
"and satisfies the correct annotations specified in the GoldParse. "
|
||||
"For example, are all labels added to the model? If you're "
|
||||
"training a named entity recognizer, also make sure that none of "
|
||||
"your annotated entity spans have leading or trailing whitespace. "
|
||||
"your annotated entity spans have leading or trailing whitespace "
|
||||
"or punctuation. "
|
||||
"You can also use the experimental `debug-data` command to "
|
||||
"validate your JSON-formatted training data. For details, run:\n"
|
||||
"python -m spacy debug-data --help")
|
||||
|
|
|
@ -17,6 +17,17 @@ _tamil = r"\u0B80-\u0BFF"
|
|||
|
||||
_telugu = r"\u0C00-\u0C7F"
|
||||
|
||||
# from the final table in: https://en.wikipedia.org/wiki/CJK_Unified_Ideographs
|
||||
_cjk = (
|
||||
r"\u4E00-\u62FF\u6300-\u77FF\u7800-\u8CFF\u8D00-\u9FFF\u3400-\u4DBF"
|
||||
r"\U00020000-\U000215FF\U00021600-\U000230FF\U00023100-\U000245FF"
|
||||
r"\U00024600-\U000260FF\U00026100-\U000275FF\U00027600-\U000290FF"
|
||||
r"\U00029100-\U0002A6DF\U0002A700-\U0002B73F\U0002B740-\U0002B81F"
|
||||
r"\U0002B820-\U0002CEAF\U0002CEB0-\U0002EBEF\u2E80-\u2EFF\u2F00-\u2FDF"
|
||||
r"\u2FF0-\u2FFF\u3000-\u303F\u31C0-\u31EF\u3200-\u32FF\u3300-\u33FF"
|
||||
r"\uF900-\uFAFF\uFE30-\uFE4F\U0001F200-\U0001F2FF\U0002F800-\U0002FA1F"
|
||||
)
|
||||
|
||||
# Latin standard
|
||||
_latin_u_standard = r"A-Z"
|
||||
_latin_l_standard = r"a-z"
|
||||
|
@ -215,6 +226,7 @@ _uncased = (
|
|||
+ _tamil
|
||||
+ _telugu
|
||||
+ _hangul
|
||||
+ _cjk
|
||||
)
|
||||
|
||||
ALPHA = group_chars(LATIN + _russian + _tatar + _greek + _ukrainian + _uncased)
|
||||
|
|
|
@ -3,14 +3,18 @@ from __future__ import unicode_literals
|
|||
|
||||
import re
|
||||
|
||||
from .char_classes import ALPHA_LOWER
|
||||
from ..symbols import ORTH, POS, TAG, LEMMA, SPACE
|
||||
|
||||
|
||||
# URL validation regex courtesy of: https://mathiasbynens.be/demo/url-regex
|
||||
# A few minor mods to this regex to account for use cases represented in test_urls
|
||||
# and https://gist.github.com/dperini/729294 (Diego Perini, MIT License)
|
||||
# A few mods to this regex to account for use cases represented in test_urls
|
||||
URL_PATTERN = (
|
||||
# fmt: off
|
||||
r"^"
|
||||
# protocol identifier (see: https://www.iana.org/assignments/uri-schemes/uri-schemes.xhtml)
|
||||
# protocol identifier (mods: make optional and expand schemes)
|
||||
# (see: https://www.iana.org/assignments/uri-schemes/uri-schemes.xhtml)
|
||||
r"(?:(?:[\w\+\-\.]{2,})://)?"
|
||||
# mailto:user or user:pass authentication
|
||||
r"(?:\S+(?::\S*)?@)?"
|
||||
|
@ -31,18 +35,27 @@ URL_PATTERN = (
|
|||
r"(?:\.(?:1?\d{1,2}|2[0-4]\d|25[0-5])){2}"
|
||||
r"(?:\.(?:[1-9]\d?|1\d\d|2[0-4]\d|25[0-4]))"
|
||||
r"|"
|
||||
# host name
|
||||
r"(?:(?:[a-z0-9\-]*)?[a-z0-9]+)"
|
||||
# domain name
|
||||
r"(?:\.(?:[a-z0-9])(?:[a-z0-9\-])*[a-z0-9])?"
|
||||
# host & domain names
|
||||
# mods: match is case-sensitive, so include [A-Z]
|
||||
"(?:"
|
||||
"(?:"
|
||||
"[A-Za-z0-9\u00a1-\uffff]"
|
||||
"[A-Za-z0-9\u00a1-\uffff_-]{0,62}"
|
||||
")?"
|
||||
"[A-Za-z0-9\u00a1-\uffff]\."
|
||||
")+"
|
||||
# TLD identifier
|
||||
r"(?:\.(?:[a-z]{2,}))"
|
||||
# mods: use ALPHA_LOWER instead of a wider range so that this doesn't match
|
||||
# strings like "lower.Upper", which can be split on "." by infixes in some
|
||||
# languages
|
||||
r"(?:[" + ALPHA_LOWER + "]{2,63})"
|
||||
r")"
|
||||
# port number
|
||||
r"(?::\d{2,5})?"
|
||||
# resource path
|
||||
r"(?:[/?#]\S*)?"
|
||||
r"$"
|
||||
# fmt: on
|
||||
).strip()
|
||||
|
||||
TOKEN_MATCH = re.compile(URL_PATTERN, re.UNICODE).match
|
||||
|
|
|
@ -780,7 +780,7 @@ class Language(object):
|
|||
|
||||
pipes = (
|
||||
[]
|
||||
) # contains functools.partial objects so that easily create multiprocess worker.
|
||||
) # contains functools.partial objects to easily create multiprocess worker.
|
||||
for name, proc in self.pipeline:
|
||||
if name in disable:
|
||||
continue
|
||||
|
@ -837,7 +837,7 @@ class Language(object):
|
|||
texts, raw_texts = itertools.tee(texts)
|
||||
# for sending texts to worker
|
||||
texts_q = [mp.Queue() for _ in range(n_process)]
|
||||
# for receiving byte encoded docs from worker
|
||||
# for receiving byte-encoded docs from worker
|
||||
bytedocs_recv_ch, bytedocs_send_ch = zip(
|
||||
*[mp.Pipe(False) for _ in range(n_process)]
|
||||
)
|
||||
|
@ -847,7 +847,7 @@ class Language(object):
|
|||
# This is necessary to properly handle infinite length of texts.
|
||||
# (In this case, all data cannot be sent to the workers at once)
|
||||
sender = _Sender(batch_texts, texts_q, chunk_size=n_process)
|
||||
# send twice so that make process busy
|
||||
# send twice to make process busy
|
||||
sender.send()
|
||||
sender.send()
|
||||
|
||||
|
@ -859,7 +859,7 @@ class Language(object):
|
|||
proc.start()
|
||||
|
||||
# Cycle channels not to break the order of docs.
|
||||
# The received object is batch of byte encoded docs, so flatten them with chain.from_iterable.
|
||||
# The received object is a batch of byte-encoded docs, so flatten them with chain.from_iterable.
|
||||
byte_docs = chain.from_iterable(recv.recv() for recv in cycle(bytedocs_recv_ch))
|
||||
docs = (Doc(self.vocab).from_bytes(byte_doc) for byte_doc in byte_docs)
|
||||
try:
|
||||
|
|
|
@ -129,20 +129,31 @@ class EntityRuler(object):
|
|||
|
||||
DOCS: https://spacy.io/api/entityruler#labels
|
||||
"""
|
||||
all_labels = set(self.token_patterns.keys())
|
||||
all_labels.update(self.phrase_patterns.keys())
|
||||
keys = set(self.token_patterns.keys())
|
||||
keys.update(self.phrase_patterns.keys())
|
||||
all_labels = set()
|
||||
|
||||
for l in keys:
|
||||
if self.ent_id_sep in l:
|
||||
label, _ = self._split_label(l)
|
||||
all_labels.add(label)
|
||||
else:
|
||||
all_labels.add(l)
|
||||
return tuple(all_labels)
|
||||
|
||||
@property
|
||||
def ent_ids(self):
|
||||
"""All entity ids present in the match patterns `id` properties.
|
||||
"""All entity ids present in the match patterns `id` properties
|
||||
|
||||
RETURNS (set): The string entity ids.
|
||||
|
||||
DOCS: https://spacy.io/api/entityruler#ent_ids
|
||||
"""
|
||||
keys = set(self.token_patterns.keys())
|
||||
keys.update(self.phrase_patterns.keys())
|
||||
all_ent_ids = set()
|
||||
for l in self.labels:
|
||||
|
||||
for l in keys:
|
||||
if self.ent_id_sep in l:
|
||||
_, ent_id = self._split_label(l)
|
||||
all_ent_ids.add(ent_id)
|
||||
|
|
|
@ -1308,7 +1308,7 @@ class EntityLinker(Pipe):
|
|||
for i, doc in enumerate(docs):
|
||||
if len(doc) > 0:
|
||||
# Looping through each sentence and each entity
|
||||
# This may go wrong if there are entities across sentences - because they might not get a KB ID
|
||||
# This may go wrong if there are entities across sentences - which shouldn't happen normally.
|
||||
for sent in doc.sents:
|
||||
sent_doc = sent.as_doc()
|
||||
# currently, the context is the same for each entity in a sentence (should be refined)
|
||||
|
|
|
@ -462,3 +462,4 @@ cdef enum symbol_t:
|
|||
acl
|
||||
|
||||
ENT_KB_ID
|
||||
ENT_ID
|
||||
|
|
|
@ -86,6 +86,7 @@ IDS = {
|
|||
"DEP": DEP,
|
||||
"ENT_IOB": ENT_IOB,
|
||||
"ENT_TYPE": ENT_TYPE,
|
||||
"ENT_ID": ENT_ID,
|
||||
"ENT_KB_ID": ENT_KB_ID,
|
||||
"HEAD": HEAD,
|
||||
"SENT_START": SENT_START,
|
||||
|
|
|
@ -57,7 +57,7 @@ cdef class Parser:
|
|||
subword_features = util.env_opt('subword_features',
|
||||
cfg.get('subword_features', True))
|
||||
conv_depth = util.env_opt('conv_depth', cfg.get('conv_depth', 4))
|
||||
conv_window = util.env_opt('conv_window', cfg.get('conv_depth', 1))
|
||||
conv_window = util.env_opt('conv_window', cfg.get('conv_window', 1))
|
||||
t2v_pieces = util.env_opt('cnn_maxout_pieces', cfg.get('cnn_maxout_pieces', 3))
|
||||
bilstm_depth = util.env_opt('bilstm_depth', cfg.get('bilstm_depth', 0))
|
||||
self_attn_depth = util.env_opt('self_attn_depth', cfg.get('self_attn_depth', 0))
|
||||
|
|
|
@ -296,9 +296,8 @@ WIKI_TESTS = [
|
|||
("cérium(IV)-oxid", ["cérium", "(", "IV", ")", "-oxid"]),
|
||||
]
|
||||
|
||||
TESTCASES = (
|
||||
DEFAULT_TESTS
|
||||
+ DOT_TESTS
|
||||
EXTRA_TESTS = (
|
||||
DOT_TESTS
|
||||
+ QUOTE_TESTS
|
||||
+ NUMBER_TESTS
|
||||
+ HYPHEN_TESTS
|
||||
|
@ -306,8 +305,16 @@ TESTCASES = (
|
|||
+ TYPO_TESTS
|
||||
)
|
||||
|
||||
# normal: default tests + 10% of extra tests
|
||||
TESTS = DEFAULT_TESTS
|
||||
TESTS.extend([x for i, x in enumerate(EXTRA_TESTS) if i % 10 == 0])
|
||||
|
||||
@pytest.mark.parametrize("text,expected_tokens", TESTCASES)
|
||||
# slow: remaining 90% of extra tests
|
||||
SLOW_TESTS = [x for i, x in enumerate(EXTRA_TESTS) if i % 10 != 0]
|
||||
TESTS.extend([pytest.param(x[0], x[1], marks=pytest.mark.slow()) if not isinstance(x[0], tuple) else x for x in SLOW_TESTS])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("text,expected_tokens", TESTS)
|
||||
def test_hu_tokenizer_handles_testcases(hu_tokenizer, text, expected_tokens):
|
||||
tokens = hu_tokenizer(text)
|
||||
token_list = [token.text for token in tokens if not token.is_space]
|
||||
|
|
|
@ -21,6 +21,7 @@ def patterns():
|
|||
{"label": "HELLO", "pattern": [{"ORTH": "HELLO"}]},
|
||||
{"label": "COMPLEX", "pattern": [{"ORTH": "foo", "OP": "*"}]},
|
||||
{"label": "TECH_ORG", "pattern": "Apple", "id": "a1"},
|
||||
{"label": "TECH_ORG", "pattern": "Microsoft", "id": "a2"},
|
||||
]
|
||||
|
||||
|
||||
|
@ -147,3 +148,14 @@ def test_entity_ruler_validate(nlp):
|
|||
# invalid pattern raises error with validate
|
||||
with pytest.raises(MatchPatternError):
|
||||
validated_ruler.add_patterns([invalid_pattern])
|
||||
|
||||
|
||||
def test_entity_ruler_properties(nlp, patterns):
|
||||
ruler = EntityRuler(nlp, patterns=patterns, overwrite_ents=True)
|
||||
assert sorted(ruler.labels) == sorted([
|
||||
"HELLO",
|
||||
"BYE",
|
||||
"COMPLEX",
|
||||
"TECH_ORG"
|
||||
])
|
||||
assert sorted(ruler.ent_ids) == ["a1", "a2"]
|
||||
|
|
36
spacy/tests/regression/test_issue4849.py
Normal file
36
spacy/tests/regression/test_issue4849.py
Normal file
|
@ -0,0 +1,36 @@
|
|||
# coding: utf8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from spacy.lang.en import English
|
||||
from spacy.pipeline import EntityRuler
|
||||
|
||||
|
||||
def test_issue4849():
|
||||
nlp = English()
|
||||
|
||||
ruler = EntityRuler(
|
||||
nlp, patterns=[
|
||||
{"label": "PERSON", "pattern": 'joe biden', "id": 'joe-biden'},
|
||||
{"label": "PERSON", "pattern": 'bernie sanders', "id": 'bernie-sanders'},
|
||||
],
|
||||
phrase_matcher_attr="LOWER"
|
||||
)
|
||||
|
||||
nlp.add_pipe(ruler)
|
||||
|
||||
text = """
|
||||
The left is starting to take aim at Democratic front-runner Joe Biden.
|
||||
Sen. Bernie Sanders joined in her criticism: "There is no 'middle ground' when it comes to climate policy."
|
||||
"""
|
||||
|
||||
# USING 1 PROCESS
|
||||
count_ents = 0
|
||||
for doc in nlp.pipe([text], n_process=1):
|
||||
count_ents += len([ent for ent in doc.ents if ent.ent_id > 0])
|
||||
assert(count_ents == 2)
|
||||
|
||||
# USING 2 PROCESSES
|
||||
count_ents = 0
|
||||
for doc in nlp.pipe([text], n_process=2):
|
||||
count_ents += len([ent for ent in doc.ents if ent.ent_id > 0])
|
||||
assert (count_ents == 2)
|
|
@ -2,7 +2,7 @@
|
|||
from __future__ import unicode_literals
|
||||
|
||||
import pytest
|
||||
from spacy.tokens import Doc
|
||||
from spacy.tokens import Doc, Token
|
||||
from spacy.vocab import Vocab
|
||||
|
||||
|
||||
|
@ -15,6 +15,10 @@ def doc_w_attrs(en_tokenizer):
|
|||
)
|
||||
doc = en_tokenizer("This is a test.")
|
||||
doc._._test_attr = "test"
|
||||
|
||||
Token.set_extension("_test_token", default="t0")
|
||||
doc[1]._._test_token = "t1"
|
||||
|
||||
return doc
|
||||
|
||||
|
||||
|
@ -25,3 +29,7 @@ def test_serialize_ext_attrs_from_bytes(doc_w_attrs):
|
|||
assert doc._._test_attr == "test"
|
||||
assert doc._._test_prop == len(doc.text)
|
||||
assert doc._._test_method("test") == "{}{}".format(len(doc.text), "test")
|
||||
|
||||
assert doc[0]._._test_token == "t0"
|
||||
assert doc[1]._._test_token == "t1"
|
||||
assert doc[2]._._test_token == "t0"
|
||||
|
|
|
@ -20,6 +20,7 @@ URLS_FULL = URLS_BASIC + [
|
|||
# URL SHOULD_MATCH and SHOULD_NOT_MATCH patterns courtesy of https://mathiasbynens.be/demo/url-regex
|
||||
URLS_SHOULD_MATCH = [
|
||||
"http://foo.com/blah_blah",
|
||||
"http://BlahBlah.com/Blah_Blah",
|
||||
"http://foo.com/blah_blah/",
|
||||
"http://www.example.com/wpstyle/?p=364",
|
||||
"https://www.example.com/foo/?bar=baz&inga=42&quux",
|
||||
|
@ -57,14 +58,17 @@ URLS_SHOULD_MATCH = [
|
|||
),
|
||||
"http://foo.com/blah_blah_(wikipedia)",
|
||||
"http://foo.com/blah_blah_(wikipedia)_(again)",
|
||||
pytest.param("http://⌘.ws", marks=pytest.mark.xfail()),
|
||||
pytest.param("http://⌘.ws/", marks=pytest.mark.xfail()),
|
||||
pytest.param("http://☺.damowmow.com/", marks=pytest.mark.xfail()),
|
||||
pytest.param("http://✪df.ws/123", marks=pytest.mark.xfail()),
|
||||
pytest.param("http://➡.ws/䨹", marks=pytest.mark.xfail()),
|
||||
pytest.param("http://مثال.إختبار", marks=pytest.mark.xfail()),
|
||||
pytest.param("http://例子.测试", marks=pytest.mark.xfail()),
|
||||
pytest.param("http://उदाहरण.परीक्षा", marks=pytest.mark.xfail()),
|
||||
"http://www.foo.co.uk",
|
||||
"http://www.foo.co.uk/",
|
||||
"http://www.foo.co.uk/blah/blah",
|
||||
"http://⌘.ws",
|
||||
"http://⌘.ws/",
|
||||
"http://☺.damowmow.com/",
|
||||
"http://✪df.ws/123",
|
||||
"http://➡.ws/䨹",
|
||||
"http://مثال.إختبار",
|
||||
"http://例子.测试",
|
||||
"http://उदाहरण.परीक्षा",
|
||||
]
|
||||
|
||||
URLS_SHOULD_NOT_MATCH = [
|
||||
|
|
|
@ -23,7 +23,7 @@ from ..lexeme cimport Lexeme, EMPTY_LEXEME
|
|||
from ..typedefs cimport attr_t, flags_t
|
||||
from ..attrs cimport ID, ORTH, NORM, LOWER, SHAPE, PREFIX, SUFFIX, CLUSTER
|
||||
from ..attrs cimport LENGTH, POS, LEMMA, TAG, DEP, HEAD, SPACY, ENT_IOB
|
||||
from ..attrs cimport ENT_TYPE, ENT_KB_ID, SENT_START, attr_id_t
|
||||
from ..attrs cimport ENT_TYPE, ENT_ID, ENT_KB_ID, SENT_START, attr_id_t
|
||||
from ..parts_of_speech cimport CCONJ, PUNCT, NOUN, univ_pos_t
|
||||
|
||||
from ..attrs import intify_attrs, IDS
|
||||
|
@ -69,6 +69,8 @@ cdef attr_t get_token_attr(const TokenC* token, attr_id_t feat_name) nogil:
|
|||
return token.ent_iob
|
||||
elif feat_name == ENT_TYPE:
|
||||
return token.ent_type
|
||||
elif feat_name == ENT_ID:
|
||||
return token.ent_id
|
||||
elif feat_name == ENT_KB_ID:
|
||||
return token.ent_kb_id
|
||||
else:
|
||||
|
@ -868,7 +870,7 @@ cdef class Doc:
|
|||
|
||||
DOCS: https://spacy.io/api/doc#to_bytes
|
||||
"""
|
||||
array_head = [LENGTH, SPACY, LEMMA, ENT_IOB, ENT_TYPE] # TODO: ENT_KB_ID ?
|
||||
array_head = [LENGTH, SPACY, LEMMA, ENT_IOB, ENT_TYPE, ENT_ID] # TODO: ENT_KB_ID ?
|
||||
if self.is_tagged:
|
||||
array_head.extend([TAG, POS])
|
||||
# If doc parsed add head and dep attribute
|
||||
|
|
|
@ -212,7 +212,7 @@ cdef class Span:
|
|||
words = [t.text for t in self]
|
||||
spaces = [bool(t.whitespace_) for t in self]
|
||||
cdef Doc doc = Doc(self.doc.vocab, words=words, spaces=spaces)
|
||||
array_head = [LENGTH, SPACY, LEMMA, ENT_IOB, ENT_TYPE, ENT_KB_ID]
|
||||
array_head = [LENGTH, SPACY, LEMMA, ENT_IOB, ENT_TYPE, ENT_ID, ENT_KB_ID]
|
||||
if self.doc.is_tagged:
|
||||
array_head.append(TAG)
|
||||
# If doc parsed add head and dep attribute
|
||||
|
|
|
@ -53,6 +53,8 @@ cdef class Token:
|
|||
return token.ent_iob
|
||||
elif feat_name == ENT_TYPE:
|
||||
return token.ent_type
|
||||
elif feat_name == ENT_ID:
|
||||
return token.ent_id
|
||||
elif feat_name == ENT_KB_ID:
|
||||
return token.ent_kb_id
|
||||
elif feat_name == SENT_START:
|
||||
|
@ -81,6 +83,8 @@ cdef class Token:
|
|||
token.ent_iob = value
|
||||
elif feat_name == ENT_TYPE:
|
||||
token.ent_type = value
|
||||
elif feat_name == ENT_ID:
|
||||
token.ent_id = value
|
||||
elif feat_name == ENT_KB_ID:
|
||||
token.ent_kb_id = value
|
||||
elif feat_name == SENT_START:
|
||||
|
|
|
@ -9,7 +9,7 @@ menu:
|
|||
---
|
||||
|
||||
Compared to using regular expressions on raw text, spaCy's rule-based matcher
|
||||
engines and components not only let you find you the words and phrases you're
|
||||
engines and components not only let you find the words and phrases you're
|
||||
looking for – they also give you access to the tokens within the document and
|
||||
their relationships. This means you can easily access and analyze the
|
||||
surrounding tokens, merge spans into single tokens or add entries to the named
|
||||
|
|
|
@ -229,10 +229,10 @@ For more details on **adding hooks** and **overwriting** the built-in `Doc`,
|
|||
If you're using a GPU, it's much more efficient to keep the word vectors on the
|
||||
device. You can do that by setting the [`Vectors.data`](/api/vectors#attributes)
|
||||
attribute to a `cupy.ndarray` object if you're using spaCy or
|
||||
[Chainer]("https://chainer.org"), or a `torch.Tensor` object if you're using
|
||||
[PyTorch]("http://pytorch.org"). The `data` object just needs to support
|
||||
[Chainer](https://chainer.org), or a `torch.Tensor` object if you're using
|
||||
[PyTorch](http://pytorch.org). The `data` object just needs to support
|
||||
`__iter__` and `__getitem__`, so if you're using another library such as
|
||||
[TensorFlow]("https://www.tensorflow.org"), you could also create a wrapper for
|
||||
[TensorFlow](https://www.tensorflow.org), you could also create a wrapper for
|
||||
your vectors data.
|
||||
|
||||
```python
|
||||
|
|
|
@ -1509,28 +1509,30 @@
|
|||
{
|
||||
"id": "spacy-conll",
|
||||
"title": "spacy_conll",
|
||||
"slogan": "Parse text with spaCy and print the output in CoNLL-U format",
|
||||
"description": "This module allows you to parse a text to CoNLL-U format. You can use it as a command line tool, or embed it in your own scripts.",
|
||||
"slogan": "Parse text with spaCy and gets its output in CoNLL-U format",
|
||||
"description": "This module allows you to parse a text to CoNLL-U format. It contains a pipeline component for spaCy that adds CoNLL-U properties to a Doc and its sentences. It can also be used as a command-line tool.",
|
||||
"code_example": [
|
||||
"from spacy_conll import Spacy2ConllParser",
|
||||
"spacyconll = Spacy2ConllParser()",
|
||||
"import spacy",
|
||||
"from spacy_conll import ConllFormatter",
|
||||
"",
|
||||
"# `parse` returns a generator of the parsed sentences",
|
||||
"for parsed_sent in spacyconll.parse(input_str='I like cookies.\nWhat about you?\nI don't like 'em!'):",
|
||||
" do_something_(parsed_sent)",
|
||||
"",
|
||||
"# `parseprint` prints output to stdout (default) or a file (use `output_file` parameter)",
|
||||
"# This method is called when using the command line",
|
||||
"spacyconll.parseprint(input_str='I like cookies.')"
|
||||
"nlp = spacy.load('en')",
|
||||
"conllformatter = ConllFormatter(nlp)",
|
||||
"nlp.add_pipe(conllformatter, after='parser')",
|
||||
"doc = nlp('I like cookies. Do you?')",
|
||||
"conll = doc._.conll",
|
||||
"print(doc._.conll_str_headers)",
|
||||
"print(doc._.conll_str)"
|
||||
],
|
||||
"code_language": "python",
|
||||
"author": "Bram Vanroy",
|
||||
"author_links": {
|
||||
"github": "BramVanroy",
|
||||
"twitter": "BramVanroy",
|
||||
"website": "https://bramvanroy.be"
|
||||
},
|
||||
"github": "BramVanroy/spacy_conll",
|
||||
"category": ["standalone"]
|
||||
"category": ["standalone", "pipeline"],
|
||||
"tags": ["linguistics", "computational linguistics", "conll"]
|
||||
},
|
||||
{
|
||||
"id": "spacy-langdetect",
|
||||
|
|
Loading…
Reference in New Issue
Block a user