Merge branch 'master' into spacy.io

Ines Montani 2019-10-19 14:03:29 +02:00
commit 0b2df3b879
68 changed files with 1960 additions and 912 deletions

.github/contributors/PeterGilles.md

@@ -0,0 +1,106 @@
# spaCy contributor agreement
This spaCy Contributor Agreement (**"SCA"**) is based on the
[Oracle Contributor Agreement](http://www.oracle.com/technetwork/oca-405177.pdf).
The SCA applies to any contribution that you make to any product or project
managed by us (the **"project"**), and sets out the intellectual property rights
you grant to us in the contributed materials. The term **"us"** shall mean
[ExplosionAI GmbH](https://explosion.ai/legal). The term
**"you"** shall mean the person or entity identified below.
If you agree to be bound by these terms, fill in the information requested
below and include the filled-in version with your first pull request, under the
folder [`.github/contributors/`](/.github/contributors/). The name of the file
should be your GitHub username, with the extension `.md`. For example, the user
example_user would create the file `.github/contributors/example_user.md`.
Read this agreement carefully before signing. These terms and conditions
constitute a binding legal agreement.
## Contributor Agreement
1. The term "contribution" or "contributed materials" means any source code,
object code, patch, tool, sample, graphic, specification, manual,
documentation, or any other material posted or submitted by you to the project.
2. With respect to any worldwide copyrights, or copyright applications and
registrations, in your contribution:
* you hereby assign to us joint ownership, and to the extent that such
assignment is or becomes invalid, ineffective or unenforceable, you hereby
grant to us a perpetual, irrevocable, non-exclusive, worldwide, no-charge,
royalty-free, unrestricted license to exercise all rights under those
copyrights. This includes, at our option, the right to sublicense these same
rights to third parties through multiple levels of sublicensees or other
licensing arrangements;
* you agree that each of us can do all things in relation to your
contribution as if each of us were the sole owners, and if one of us makes
a derivative work of your contribution, the one who makes the derivative
work (or has it made) will be the sole owner of that derivative work;
* you agree that you will not assert any moral rights in your contribution
against us, our licensees or transferees;
* you agree that we may register a copyright in your contribution and
exercise all ownership rights associated with it; and
* you agree that neither of us has any duty to consult with, obtain the
consent of, pay or render an accounting to the other for any use or
distribution of your contribution.
3. With respect to any patents you own, or that you can license without payment
to any third party, you hereby grant to us a perpetual, irrevocable,
non-exclusive, worldwide, no-charge, royalty-free license to:
* make, have made, use, sell, offer to sell, import, and otherwise transfer
your contribution in whole or in part, alone or in combination with or
included in any product, work or materials arising out of the project to
which your contribution was submitted, and
* at our option, to sublicense these same rights to third parties through
multiple levels of sublicensees or other licensing arrangements.
4. Except as set out above, you keep all right, title, and interest in your
contribution. The rights that you grant to us under these terms are effective
on the date you first submitted a contribution to us, even if your submission
took place before the date you sign these terms.
5. You covenant, represent, warrant and agree that:
* Each contribution that you submit is and shall be an original work of
authorship and you can legally grant the rights set out in this SCA;
* to the best of your knowledge, each contribution will not violate any
third party's copyrights, trademarks, patents, or other intellectual
property rights; and
* each contribution shall be in compliance with U.S. export control laws and
other applicable export and import laws. You agree to notify us if you
become aware of any circumstance which would make any of the foregoing
representations inaccurate in any respect. We may publicly disclose your
participation in the project, including the fact that you have signed the SCA.
6. This SCA is governed by the laws of the State of California and applicable
U.S. Federal law. Any choice of law rules will not apply.
7. Please place an “x” in one of the applicable statements below. Please do NOT
mark both statements:
* [X] I am signing on behalf of myself as an individual and no other person
or entity, including my employer, has or will have rights with respect to my
contributions.
* [ ] I am signing on behalf of my employer or a legal entity and I have the
actual authority to contractually bind that entity.
## Contributor Details
| Field | Entry |
|------------------------------- | -------------------- |
| Name | Peter Gilles |
| Company name (if applicable) | |
| Title or role (if applicable) | |
| Date | 10.10. |
| GitHub username | Peter Gilles |
| Website (optional) | |


@@ -197,7 +197,7 @@ path to the model data directory.
```python
import spacy
nlp = spacy.load("en_core_web_sm")
doc = nlp(u"This is a sentence.")
doc = nlp("This is a sentence.")
```
You can also `import` a model directly via its full name and then call its
@@ -208,7 +208,7 @@ import spacy
import en_core_web_sm
nlp = en_core_web_sm.load()
doc = nlp(u"This is a sentence.")
doc = nlp("This is a sentence.")
```
📖 **For more info and examples, check out the


@@ -0,0 +1,34 @@
## Entity Linking with Wikipedia and Wikidata
### Step 1: Create a Knowledge Base (KB) and training data
Run `wikipedia_pretrain_kb.py`
* This takes as input the locations of a **Wikipedia and a Wikidata dump**, and produces a **KB directory** + **training file**
* WikiData: get `latest-all.json.bz2` from https://dumps.wikimedia.org/wikidatawiki/entities/
* Wikipedia: get `enwiki-latest-pages-articles-multistream.xml.bz2` from https://dumps.wikimedia.org/enwiki/latest/ (or the equivalent dump for any other language)
* You can set the filtering parameters for KB construction:
* `max_per_alias`: maximum number of candidate entities in the KB per alias/synonym
* `min_freq`: minimum number of times an entity must occur in the corpus to be included in the KB
* `min_pair`: minimum number of times an entity+alias combination must occur in the corpus to be included in the KB
* Further parameters to set:
* `descriptions_from_wikipedia`: whether to parse descriptions from Wikipedia (`True`) or Wikidata (`False`)
* `entity_vector_length`: length of the pre-trained entity description vectors
* `lang`: language for which to fetch Wikidata information (as the dump contains all languages)
Quick testing and rerunning:
* When trying out the pipeline for a quick test, set `limit_prior`, `limit_train` and/or `limit_wd` to read only parts of the dumps instead of everything (see the sketch after this list).
* If you only want to (re)run certain parts of the pipeline, just remove the corresponding files and they will be recalculated or reparsed.
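For such a quick test, the Step 1 entry point can also be driven from Python through its plac-based `main()`. A minimal sketch, assuming the module name given above and purely illustrative dump paths and limits (not recommended values):

```python
from pathlib import Path

from bin.wiki_entity_linking import wikipedia_pretrain_kb  # module name as given above

wikipedia_pretrain_kb.main(
    wd_json=Path("latest-all.json.bz2"),  # Wikidata dump
    wp_xml=Path("enwiki-latest-pages-articles-multistream.xml.bz2"),  # Wikipedia dump
    output_dir=Path("kb_output"),
    model="en_core_web_lg",  # a model with pretrained vectors
    limit_prior=10000,  # read only parts of the dumps for a smoke test
    limit_train=10000,
    limit_wd=10000,
)
```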
### Step 2: Train an Entity Linking model
Run `wikidata_train_entity_linker.py`
* This takes the **KB directory** produced by Step 1, and trains an **Entity Linking model**; a usage sketch of the trained pipeline follows below
* You can set the learning parameters for the EL training:
* `epochs`: number of training iterations
* `dropout`: dropout rate
* `lr`: learning rate
* `l2`: L2 regularization
* Specify the number of training and dev testing entities with `train_inst` and `dev_inst` respectively
* Further parameters to set:
* `labels_discard`: NER label types to discard during training
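
Once training finishes, the pipeline written to the output directory loads like any other spaCy model. A minimal usage sketch (the path is illustrative; use whatever output directory you passed to the script):

```python
import spacy

nlp = spacy.load("output/nlp")  # trained pipeline from Step 2 (illustrative path)
doc = nlp("Douglas Adams wrote The Hitchhiker's Guide to the Galaxy.")
for ent in doc.ents:
    # after entity linking, ent.kb_id_ holds the predicted knowledge-base ID (a Wikidata QID)
    print(ent.text, ent.label_, ent.kb_id_)
```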


@@ -6,6 +6,7 @@ OUTPUT_MODEL_DIR = "nlp"
PRIOR_PROB_PATH = "prior_prob.csv"
ENTITY_DEFS_PATH = "entity_defs.csv"
ENTITY_FREQ_PATH = "entity_freq.csv"
ENTITY_ALIAS_PATH = "entity_alias.csv"
ENTITY_DESCR_PATH = "entity_descriptions.csv"
LOG_FORMAT = '%(asctime)s - %(levelname)s - %(name)s - %(message)s'


@@ -15,10 +15,11 @@ class Metrics(object):
candidate_is_correct = true_entity == candidate
# Assume that we have no labeled negatives in the data (i.e. cases where true_entity is "NIL")
# Therefore, if candidate_is_correct then we have a true positive and never a true negative
# Therefore, if candidate_is_correct then we have a true positive and never a true negative.
self.true_pos += candidate_is_correct
self.false_neg += not candidate_is_correct
if candidate not in {"", "NIL"}:
if candidate and candidate not in {"", "NIL"}:
# A wrong prediction (e.g. Q42 != Q3) counts both as a FP as well as a FN.
self.false_pos += not candidate_is_correct
def calculate_precision(self):
@@ -33,6 +34,14 @@ class Metrics(object):
else:
return self.true_pos / (self.true_pos + self.false_neg)
def calculate_fscore(self):
p = self.calculate_precision()
r = self.calculate_recall()
if p + r == 0:
return 0.0
else:
return 2 * p * r / (p + r)
class EvaluationResults(object):
def __init__(self):
@@ -43,18 +52,20 @@ class EvaluationResults(object):
self.metrics.update_results(true_entity, candidate)
self.metrics_by_label[ent_label].update_results(true_entity, candidate)
def increment_false_negatives(self):
self.metrics.false_neg += 1
def report_metrics(self, model_name):
model_str = model_name.title()
recall = self.metrics.calculate_recall()
precision = self.metrics.calculate_precision()
return ("{}: ".format(model_str) +
"Recall = {} | ".format(round(recall, 3)) +
"Precision = {} | ".format(round(precision, 3)) +
"Precision by label = {}".format({k: v.calculate_precision()
for k, v in self.metrics_by_label.items()}))
fscore = self.metrics.calculate_fscore()
return (
"{}: ".format(model_str)
+ "F-score = {} | ".format(round(fscore, 3))
+ "Recall = {} | ".format(round(recall, 3))
+ "Precision = {} | ".format(round(precision, 3))
+ "F-score by label = {}".format(
{k: v.calculate_fscore() for k, v in sorted(self.metrics_by_label.items())}
)
)
class BaselineResults(object):
@@ -63,40 +74,51 @@ class BaselineResults(object):
self.prior = EvaluationResults()
self.oracle = EvaluationResults()
def report_accuracy(self, model):
def report_performance(self, model):
results = getattr(self, model)
return results.report_metrics(model)
def update_baselines(self, true_entity, ent_label, random_candidate, prior_candidate, oracle_candidate):
def update_baselines(
self,
true_entity,
ent_label,
random_candidate,
prior_candidate,
oracle_candidate,
):
self.oracle.update_metrics(ent_label, true_entity, oracle_candidate)
self.prior.update_metrics(ent_label, true_entity, prior_candidate)
self.random.update_metrics(ent_label, true_entity, random_candidate)
def measure_performance(dev_data, kb, el_pipe):
baseline_accuracies = measure_baselines(
dev_data, kb
)
def measure_performance(dev_data, kb, el_pipe, baseline=True, context=True):
if baseline:
baseline_accuracies, counts = measure_baselines(dev_data, kb)
logger.info("Counts: {}".format({k: v for k, v in sorted(counts.items())}))
logger.info(baseline_accuracies.report_performance("random"))
logger.info(baseline_accuracies.report_performance("prior"))
logger.info(baseline_accuracies.report_performance("oracle"))
logger.info(baseline_accuracies.report_accuracy("random"))
logger.info(baseline_accuracies.report_accuracy("prior"))
logger.info(baseline_accuracies.report_accuracy("oracle"))
if context:
# using only context
el_pipe.cfg["incl_context"] = True
el_pipe.cfg["incl_prior"] = False
results = get_eval_results(dev_data, el_pipe)
logger.info(results.report_metrics("context only"))
# using only context
el_pipe.cfg["incl_context"] = True
el_pipe.cfg["incl_prior"] = False
results = get_eval_results(dev_data, el_pipe)
logger.info(results.report_metrics("context only"))
# measuring combined accuracy (prior + context)
el_pipe.cfg["incl_context"] = True
el_pipe.cfg["incl_prior"] = True
results = get_eval_results(dev_data, el_pipe)
logger.info(results.report_metrics("context and prior"))
# measuring combined accuracy (prior + context)
el_pipe.cfg["incl_context"] = True
el_pipe.cfg["incl_prior"] = True
results = get_eval_results(dev_data, el_pipe)
logger.info(results.report_metrics("context and prior"))
def get_eval_results(data, el_pipe=None):
# If the docs in the data require further processing with an entity linker, set el_pipe
"""
Evaluate the ent.kb_id_ annotations against the gold standard.
Only evaluate entities that overlap between gold and NER, to isolate the performance of the NEL.
If the docs in the data require further processing with an entity linker, set el_pipe.
"""
from tqdm import tqdm
docs = []
@@ -111,18 +133,15 @@ def get_eval_results(data, el_pipe=None):
results = EvaluationResults()
for doc, gold in zip(docs, golds):
tagged_entries_per_article = {_offset(ent.start_char, ent.end_char): ent for ent in doc.ents}
try:
correct_entries_per_article = dict()
for entity, kb_dict in gold.links.items():
start, end = entity
# only evaluating on positive examples
for gold_kb, value in kb_dict.items():
if value:
# only evaluating on positive examples
offset = _offset(start, end)
correct_entries_per_article[offset] = gold_kb
if offset not in tagged_entries_per_article:
results.increment_false_negatives()
for ent in doc.ents:
ent_label = ent.label_
@@ -142,7 +161,11 @@ def get_eval_results(data, el_pipe=None):
def measure_baselines(data, kb):
# Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound
"""
Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound.
Only evaluate entities that overlap between gold and NER, to isolate the performance of the NEL.
Also return a dictionary of counts by entity label.
"""
counts_d = dict()
baseline_results = BaselineResults()
@@ -152,7 +175,6 @@ def measure_baselines(data, kb):
for doc, gold in zip(docs, golds):
correct_entries_per_article = dict()
tagged_entries_per_article = {_offset(ent.start_char, ent.end_char): ent for ent in doc.ents}
for entity, kb_dict in gold.links.items():
start, end = entity
for gold_kb, value in kb_dict.items():
@@ -160,10 +182,6 @@ def measure_baselines(data, kb):
if value:
offset = _offset(start, end)
correct_entries_per_article[offset] = gold_kb
if offset not in tagged_entries_per_article:
baseline_results.random.increment_false_negatives()
baseline_results.oracle.increment_false_negatives()
baseline_results.prior.increment_false_negatives()
for ent in doc.ents:
ent_label = ent.label_
@@ -176,7 +194,7 @@ def measure_baselines(data, kb):
if gold_entity is not None:
candidates = kb.get_candidates(ent.text)
oracle_candidate = ""
best_candidate = ""
prior_candidate = ""
random_candidate = ""
if candidates:
scores = []
@@ -187,13 +205,21 @@ def measure_baselines(data, kb):
oracle_candidate = c.entity_
best_index = scores.index(max(scores))
best_candidate = candidates[best_index].entity_
prior_candidate = candidates[best_index].entity_
random_candidate = random.choice(candidates).entity_
baseline_results.update_baselines(gold_entity, ent_label,
random_candidate, best_candidate, oracle_candidate)
current_count = counts_d.get(ent_label, 0)
counts_d[ent_label] = current_count+1
return baseline_results
baseline_results.update_baselines(
gold_entity,
ent_label,
random_candidate,
prior_candidate,
oracle_candidate,
)
return baseline_results, counts_d
def _offset(start, end):
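
The `calculate_fscore` addition above is the standard harmonic mean of precision and recall, guarded against division by zero. A standalone sanity check of that arithmetic (the counts are made up):

```python
def fscore(tp, fp, fn):
    # precision = tp/(tp+fp), recall = tp/(tp+fn), F = 2*P*R/(P+R)
    p = tp / (tp + fp) if tp + fp else 0.0
    r = tp / (tp + fn) if tp + fn else 0.0
    return 2 * p * r / (p + r) if p + r else 0.0

assert round(fscore(tp=8, fp=2, fn=4), 3) == 0.727  # P = 0.8, R = 0.667
```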


@@ -1,17 +1,12 @@
# coding: utf-8
from __future__ import unicode_literals
import csv
import logging
import spacy
import sys
from spacy.kb import KnowledgeBase
from bin.wiki_entity_linking import wikipedia_processor as wp
from bin.wiki_entity_linking.train_descriptions import EntityEncoder
csv.field_size_limit(sys.maxsize)
from bin.wiki_entity_linking import wiki_io as io
logger = logging.getLogger(__name__)
@@ -22,18 +17,24 @@ def create_kb(
max_entities_per_alias,
min_entity_freq,
min_occ,
entity_def_input,
entity_def_path,
entity_descr_path,
count_input,
prior_prob_input,
entity_alias_path,
entity_freq_path,
prior_prob_path,
entity_vector_length,
):
# Create the knowledge base from Wikidata entries
kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=entity_vector_length)
entity_list, filtered_title_to_id = _define_entities(nlp, kb, entity_def_path, entity_descr_path, min_entity_freq, entity_freq_path, entity_vector_length)
_define_aliases(kb, entity_alias_path, entity_list, filtered_title_to_id, max_entities_per_alias, min_occ, prior_prob_path)
return kb
def _define_entities(nlp, kb, entity_def_path, entity_descr_path, min_entity_freq, entity_freq_path, entity_vector_length):
# read the mappings from file
title_to_id = get_entity_to_id(entity_def_input)
id_to_descr = get_id_to_description(entity_descr_path)
title_to_id = io.read_title_to_id(entity_def_path)
id_to_descr = io.read_id_to_descr(entity_descr_path)
# check the length of the nlp vectors
if "vectors" in nlp.meta and nlp.vocab.vectors.size:
@@ -45,10 +46,8 @@ def create_kb(
" cf. https://spacy.io/usage/models#languages."
)
logger.info("Get entity frequencies")
entity_frequencies = wp.get_all_frequencies(count_input=count_input)
logger.info("Filtering entities with fewer than {} mentions".format(min_entity_freq))
entity_frequencies = io.read_entity_to_count(entity_freq_path)
# filter the entities for in the KB by frequency, because there's just too much data (8M entities) otherwise
filtered_title_to_id, entity_list, description_list, frequency_list = get_filtered_entities(
title_to_id,
@@ -56,36 +55,33 @@ def create_kb(
entity_frequencies,
min_entity_freq
)
logger.info("Left with {} entities".format(len(description_list)))
logger.info("Kept {} entities from the set of {}".format(len(description_list), len(title_to_id.keys())))
logger.info("Train entity encoder")
logger.info("Training entity encoder")
encoder = EntityEncoder(nlp, input_dim, entity_vector_length)
encoder.train(description_list=description_list, to_print=True)
logger.info("Get entity embeddings:")
logger.info("Getting entity embeddings")
embeddings = encoder.apply_encoder(description_list)
logger.info("Adding {} entities".format(len(entity_list)))
kb.set_entities(
entity_list=entity_list, freq_list=frequency_list, vector_list=embeddings
)
return entity_list, filtered_title_to_id
logger.info("Adding aliases")
def _define_aliases(kb, entity_alias_path, entity_list, filtered_title_to_id, max_entities_per_alias, min_occ, prior_prob_path):
logger.info("Adding aliases from Wikipedia and Wikidata")
_add_aliases(
kb,
entity_list=entity_list,
title_to_id=filtered_title_to_id,
max_entities_per_alias=max_entities_per_alias,
min_occ=min_occ,
prior_prob_input=prior_prob_input,
prior_prob_path=prior_prob_path,
)
logger.info("KB size: {} entities, {} aliases".format(
kb.get_size_entities(),
kb.get_size_aliases()))
logger.info("Done with kb")
return kb
def get_filtered_entities(title_to_id, id_to_descr, entity_frequencies,
min_entity_freq: int = 10):
@@ -104,34 +100,13 @@ def get_filtered_entities(title_to_id, id_to_descr, entity_frequencies,
return filtered_title_to_id, entity_list, description_list, frequency_list
def get_entity_to_id(entity_def_output):
entity_to_id = dict()
with entity_def_output.open("r", encoding="utf8") as csvfile:
csvreader = csv.reader(csvfile, delimiter="|")
# skip header
next(csvreader)
for row in csvreader:
entity_to_id[row[0]] = row[1]
return entity_to_id
def get_id_to_description(entity_descr_path):
id_to_desc = dict()
with entity_descr_path.open("r", encoding="utf8") as csvfile:
csvreader = csv.reader(csvfile, delimiter="|")
# skip header
next(csvreader)
for row in csvreader:
id_to_desc[row[0]] = row[1]
return id_to_desc
def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_input):
def _add_aliases(kb, entity_list, title_to_id, max_entities_per_alias, min_occ, prior_prob_path):
wp_titles = title_to_id.keys()
# adding aliases with prior probabilities
# we can read this file sequentially, it's sorted by alias, and then by count
with prior_prob_input.open("r", encoding="utf8") as prior_file:
logger.info("Adding WP aliases")
with prior_prob_path.open("r", encoding="utf8") as prior_file:
# skip header
prior_file.readline()
line = prior_file.readline()
@@ -180,10 +155,7 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in
line = prior_file.readline()
def read_nlp_kb(model_dir, kb_file):
nlp = spacy.load(model_dir)
def read_kb(nlp, kb_file):
kb = KnowledgeBase(vocab=nlp.vocab)
kb.load_bulk(kb_file)
logger.info("kb entities: {}".format(kb.get_size_entities()))
logger.info("kb aliases: {}".format(kb.get_size_aliases()))
return nlp, kb
return kb
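
The `KnowledgeBase` calls used above (`set_entities`, `dump`, `load_bulk`, the size getters) belong to spaCy v2.2's KB API. A tiny self-contained sketch of that API with toy data instead of the Wikidata-scale inputs:

```python
import spacy
from spacy.kb import KnowledgeBase

nlp = spacy.blank("en")
kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=3)

# one toy entity with a 3-d description vector, plus one alias with a prior probability
kb.add_entity(entity="Q42", freq=12, entity_vector=[1.0, 2.0, 3.0])
kb.add_alias(alias="Adams", entities=["Q42"], probabilities=[0.9])
print(kb.get_size_entities(), kb.get_size_aliases())  # 1 1

kb.dump("my_kb")  # serialize, then reload into a fresh KB, as read_kb() above does
kb2 = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=3)
kb2.load_bulk("my_kb")
```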


@@ -53,7 +53,7 @@ class EntityEncoder:
start = start + batch_size
stop = min(stop + batch_size, len(description_list))
logger.info("encoded: {} entities".format(stop))
logger.info("Encoded: {} entities".format(stop))
return encodings
@@ -62,7 +62,7 @@ class EntityEncoder:
if to_print:
logger.info(
"Trained entity descriptions on {} ".format(processed) +
"(non-unique) entities across {} ".format(self.epochs) +
"(non-unique) descriptions across {} ".format(self.epochs) +
"epochs"
)
logger.info("Final loss: {}".format(loss))


@@ -1,395 +0,0 @@
# coding: utf-8
from __future__ import unicode_literals
import logging
import random
import re
import bz2
import json
from functools import partial
from spacy.gold import GoldParse
from bin.wiki_entity_linking import kb_creator
"""
Process Wikipedia interlinks to generate a training dataset for the EL algorithm.
Gold-standard entities are stored in one file in standoff format (by character offset).
"""
ENTITY_FILE = "gold_entities.csv"
logger = logging.getLogger(__name__)
def create_training_examples_and_descriptions(wikipedia_input,
entity_def_input,
description_output,
training_output,
parse_descriptions,
limit=None):
wp_to_id = kb_creator.get_entity_to_id(entity_def_input)
_process_wikipedia_texts(wikipedia_input,
wp_to_id,
description_output,
training_output,
parse_descriptions,
limit)
def _process_wikipedia_texts(wikipedia_input,
wp_to_id,
output,
training_output,
parse_descriptions,
limit=None):
"""
Read the XML wikipedia data to parse out training data:
raw text data + positive instances
"""
title_regex = re.compile(r"(?<=<title>).*(?=</title>)")
id_regex = re.compile(r"(?<=<id>)\d*(?=</id>)")
read_ids = set()
with output.open("a", encoding="utf8") as descr_file, training_output.open("w", encoding="utf8") as entity_file:
if parse_descriptions:
_write_training_description(descr_file, "WD_id", "description")
with bz2.open(wikipedia_input, mode="rb") as file:
article_count = 0
article_text = ""
article_title = None
article_id = None
reading_text = False
reading_revision = False
logger.info("Processed {} articles".format(article_count))
for line in file:
clean_line = line.strip().decode("utf-8")
if clean_line == "<revision>":
reading_revision = True
elif clean_line == "</revision>":
reading_revision = False
# Start reading new page
if clean_line == "<page>":
article_text = ""
article_title = None
article_id = None
# finished reading this page
elif clean_line == "</page>":
if article_id:
clean_text, entities = _process_wp_text(
article_title,
article_text,
wp_to_id
)
if clean_text is not None and entities is not None:
_write_training_entities(entity_file,
article_id,
clean_text,
entities)
if article_title in wp_to_id and parse_descriptions:
description = " ".join(clean_text[:1000].split(" ")[:-1])
_write_training_description(
descr_file,
wp_to_id[article_title],
description
)
article_count += 1
if article_count % 10000 == 0:
logger.info("Processed {} articles".format(article_count))
if limit and article_count >= limit:
break
article_text = ""
article_title = None
article_id = None
reading_text = False
reading_revision = False
# start reading text within a page
if "<text" in clean_line:
reading_text = True
if reading_text:
article_text += " " + clean_line
# stop reading text within a page (we assume a new page doesn't start on the same line)
if "</text" in clean_line:
reading_text = False
# read the ID of this article (outside the revision portion of the document)
if not reading_revision:
ids = id_regex.search(clean_line)
if ids:
article_id = ids[0]
if article_id in read_ids:
logger.info(
"Found duplicate article ID", article_id, clean_line
) # This should never happen ...
read_ids.add(article_id)
# read the title of this article (outside the revision portion of the document)
if not reading_revision:
titles = title_regex.search(clean_line)
if titles:
article_title = titles[0].strip()
logger.info("Finished. Processed {} articles".format(article_count))
text_regex = re.compile(r"(?<=<text xml:space=\"preserve\">).*(?=</text)")
info_regex = re.compile(r"{[^{]*?}")
htlm_regex = re.compile(r"&lt;!--[^-]*--&gt;")
category_regex = re.compile(r"\[\[Category:[^\[]*]]")
file_regex = re.compile(r"\[\[File:[^[\]]+]]")
ref_regex = re.compile(r"&lt;ref.*?&gt;") # non-greedy
ref_2_regex = re.compile(r"&lt;/ref.*?&gt;") # non-greedy
def _process_wp_text(article_title, article_text, wp_to_id):
# ignore meta Wikipedia pages
if (
article_title.startswith("Wikipedia:") or
article_title.startswith("Kategori:")
):
return None, None
# remove the text tags
text_search = text_regex.search(article_text)
if text_search is None:
return None, None
text = text_search.group(0)
# stop processing if this is a redirect page
if text.startswith("#REDIRECT"):
return None, None
# get the raw text without markup etc, keeping only interwiki links
clean_text, entities = _remove_links(_get_clean_wp_text(text), wp_to_id)
return clean_text, entities
def _get_clean_wp_text(article_text):
clean_text = article_text.strip()
# remove bolding & italic markup
clean_text = clean_text.replace("'''", "")
clean_text = clean_text.replace("''", "")
# remove nested {{info}} statements by removing the inner/smallest ones first and iterating
try_again = True
previous_length = len(clean_text)
while try_again:
clean_text = info_regex.sub(
"", clean_text
) # non-greedy match excluding a nested {
if len(clean_text) < previous_length:
try_again = True
else:
try_again = False
previous_length = len(clean_text)
# remove HTML comments
clean_text = htlm_regex.sub("", clean_text)
# remove Category and File statements
clean_text = category_regex.sub("", clean_text)
clean_text = file_regex.sub("", clean_text)
# remove multiple =
while "==" in clean_text:
clean_text = clean_text.replace("==", "=")
clean_text = clean_text.replace(". =", ".")
clean_text = clean_text.replace(" = ", ". ")
clean_text = clean_text.replace("= ", ".")
clean_text = clean_text.replace(" =", "")
# remove refs (non-greedy match)
clean_text = ref_regex.sub("", clean_text)
clean_text = ref_2_regex.sub("", clean_text)
# remove additional wikiformatting
clean_text = re.sub(r"&lt;blockquote&gt;", "", clean_text)
clean_text = re.sub(r"&lt;/blockquote&gt;", "", clean_text)
# change special characters back to normal ones
clean_text = clean_text.replace(r"&lt;", "<")
clean_text = clean_text.replace(r"&gt;", ">")
clean_text = clean_text.replace(r"&quot;", '"')
clean_text = clean_text.replace(r"&amp;nbsp;", " ")
clean_text = clean_text.replace(r"&amp;", "&")
# remove multiple spaces
while " " in clean_text:
clean_text = clean_text.replace(" ", " ")
return clean_text.strip()
def _remove_links(clean_text, wp_to_id):
# read the text char by char to get the right offsets for the interwiki links
entities = []
final_text = ""
open_read = 0
reading_text = True
reading_entity = False
reading_mention = False
reading_special_case = False
entity_buffer = ""
mention_buffer = ""
for index, letter in enumerate(clean_text):
if letter == "[":
open_read += 1
elif letter == "]":
open_read -= 1
elif letter == "|":
if reading_text:
final_text += letter
# switch from reading entity to mention in the [[entity|mention]] pattern
elif reading_entity:
reading_text = False
reading_entity = False
reading_mention = True
else:
reading_special_case = True
else:
if reading_entity:
entity_buffer += letter
elif reading_mention:
mention_buffer += letter
elif reading_text:
final_text += letter
else:
raise ValueError("Not sure at point", clean_text[index - 2: index + 2])
if open_read > 2:
reading_special_case = True
if open_read == 2 and reading_text:
reading_text = False
reading_entity = True
reading_mention = False
# we just finished reading an entity
if open_read == 0 and not reading_text:
if "#" in entity_buffer or entity_buffer.startswith(":"):
reading_special_case = True
# Ignore cases with nested structures like File: handles etc
if not reading_special_case:
if not mention_buffer:
mention_buffer = entity_buffer
start = len(final_text)
end = start + len(mention_buffer)
qid = wp_to_id.get(entity_buffer, None)
if qid:
entities.append((mention_buffer, qid, start, end))
final_text += mention_buffer
entity_buffer = ""
mention_buffer = ""
reading_text = True
reading_entity = False
reading_mention = False
reading_special_case = False
return final_text, entities
def _write_training_description(outputfile, qid, description):
if description is not None:
line = str(qid) + "|" + description + "\n"
outputfile.write(line)
def _write_training_entities(outputfile, article_id, clean_text, entities):
entities_data = [{"alias": ent[0], "entity": ent[1], "start": ent[2], "end": ent[3]} for ent in entities]
line = json.dumps(
{
"article_id": article_id,
"clean_text": clean_text,
"entities": entities_data
},
ensure_ascii=False) + "\n"
outputfile.write(line)
def read_training(nlp, entity_file_path, dev, limit, kb):
""" This method provides training examples that correspond to the entity annotations found by the nlp object.
For training, it will include negative training examples by using the candidate generator,
and it will only keep positive training examples that can be found by using the candidate generator.
For testing, it will include all positive examples only."""
from tqdm import tqdm
data = []
num_entities = 0
get_gold_parse = partial(_get_gold_parse, dev=dev, kb=kb)
logger.info("Reading {} data with limit {}".format('dev' if dev else 'train', limit))
with entity_file_path.open("r", encoding="utf8") as file:
with tqdm(total=limit, leave=False) as pbar:
for i, line in enumerate(file):
example = json.loads(line)
article_id = example["article_id"]
clean_text = example["clean_text"]
entities = example["entities"]
if dev != is_dev(article_id) or len(clean_text) >= 30000:
continue
doc = nlp(clean_text)
gold = get_gold_parse(doc, entities)
if gold and len(gold.links) > 0:
data.append((doc, gold))
num_entities += len(gold.links)
pbar.update(len(gold.links))
if limit and num_entities >= limit:
break
logger.info("Read {} entities in {} articles".format(num_entities, len(data)))
return data
def _get_gold_parse(doc, entities, dev, kb):
gold_entities = {}
tagged_ent_positions = set(
[(ent.start_char, ent.end_char) for ent in doc.ents]
)
for entity in entities:
entity_id = entity["entity"]
alias = entity["alias"]
start = entity["start"]
end = entity["end"]
candidates = kb.get_candidates(alias)
candidate_ids = [
c.entity_ for c in candidates
]
should_add_ent = (
dev or
(
(start, end) in tagged_ent_positions and
entity_id in candidate_ids and
len(candidates) > 1
)
)
if should_add_ent:
value_by_id = {entity_id: 1.0}
if not dev:
random.shuffle(candidate_ids)
value_by_id.update({
kb_id: 0.0
for kb_id in candidate_ids
if kb_id != entity_id
})
gold_entities[(start, end)] = value_by_id
return GoldParse(doc, links=gold_entities)
def is_dev(article_id):
return article_id.endswith("3")


@@ -0,0 +1,127 @@
# coding: utf-8
from __future__ import unicode_literals
import sys
import csv
# min() needed to prevent error on windows, cf https://stackoverflow.com/questions/52404416/
csv.field_size_limit(min(sys.maxsize, 2147483646))
""" This class provides reading/writing methods for temp files """
# Entity definition: WP title -> WD ID #
def write_title_to_id(entity_def_output, title_to_id):
with entity_def_output.open("w", encoding="utf8") as id_file:
id_file.write("WP_title" + "|" + "WD_id" + "\n")
for title, qid in title_to_id.items():
id_file.write(title + "|" + str(qid) + "\n")
def read_title_to_id(entity_def_output):
title_to_id = dict()
with entity_def_output.open("r", encoding="utf8") as id_file:
csvreader = csv.reader(id_file, delimiter="|")
# skip header
next(csvreader)
for row in csvreader:
title_to_id[row[0]] = row[1]
return title_to_id
# Entity aliases from WD: WD ID -> WD alias #
def write_id_to_alias(entity_alias_path, id_to_alias):
with entity_alias_path.open("w", encoding="utf8") as alias_file:
alias_file.write("WD_id" + "|" + "alias" + "\n")
for qid, alias_list in id_to_alias.items():
for alias in alias_list:
alias_file.write(str(qid) + "|" + alias + "\n")
def read_id_to_alias(entity_alias_path):
id_to_alias = dict()
with entity_alias_path.open("r", encoding="utf8") as alias_file:
csvreader = csv.reader(alias_file, delimiter="|")
# skip header
next(csvreader)
for row in csvreader:
qid = row[0]
alias = row[1]
alias_list = id_to_alias.get(qid, [])
alias_list.append(alias)
id_to_alias[qid] = alias_list
return id_to_alias
def read_alias_to_id_generator(entity_alias_path):
""" Read (aliases, qid) tuples """
with entity_alias_path.open("r", encoding="utf8") as alias_file:
csvreader = csv.reader(alias_file, delimiter="|")
# skip header
next(csvreader)
for row in csvreader:
qid = row[0]
alias = row[1]
yield alias, qid
# Entity descriptions from WD: WD ID -> WD description #
def write_id_to_descr(entity_descr_output, id_to_descr):
with entity_descr_output.open("w", encoding="utf8") as descr_file:
descr_file.write("WD_id" + "|" + "description" + "\n")
for qid, descr in id_to_descr.items():
descr_file.write(str(qid) + "|" + descr + "\n")
def read_id_to_descr(entity_desc_path):
id_to_desc = dict()
with entity_desc_path.open("r", encoding="utf8") as descr_file:
csvreader = csv.reader(descr_file, delimiter="|")
# skip header
next(csvreader)
for row in csvreader:
id_to_desc[row[0]] = row[1]
return id_to_desc
# Entity counts from WP: WP title -> count #
def write_entity_to_count(prior_prob_input, count_output):
# Write entity counts for quick access later
entity_to_count = dict()
total_count = 0
with prior_prob_input.open("r", encoding="utf8") as prior_file:
# skip header
prior_file.readline()
line = prior_file.readline()
while line:
splits = line.replace("\n", "").split(sep="|")
# alias = splits[0]
count = int(splits[1])
entity = splits[2]
current_count = entity_to_count.get(entity, 0)
entity_to_count[entity] = current_count + count
total_count += count
line = prior_file.readline()
with count_output.open("w", encoding="utf8") as entity_file:
entity_file.write("entity" + "|" + "count" + "\n")
for entity, count in entity_to_count.items():
entity_file.write(entity + "|" + str(count) + "\n")
def read_entity_to_count(count_input):
entity_to_count = dict()
with count_input.open("r", encoding="utf8") as csvfile:
csvreader = csv.reader(csvfile, delimiter="|")
# skip header
next(csvreader)
for row in csvreader:
entity_to_count[row[0]] = int(row[1])
return entity_to_count
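
All of these helpers share one deliberately simple on-disk format: `|`-delimited CSV with a single header line. A minimal round-trip sketch for the title-to-ID pair of functions above (the file location is arbitrary):

```python
from pathlib import Path

from bin.wiki_entity_linking import wiki_io as io

path = Path("entity_defs.csv")  # arbitrary scratch location
io.write_title_to_id(path, {"Douglas Adams": "Q42"})
assert io.read_title_to_id(path) == {"Douglas Adams": "Q42"}
```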


@@ -0,0 +1,128 @@
# coding: utf8
from __future__ import unicode_literals
# List of meta pages in Wikidata, should be kept out of the Knowledge base
WD_META_ITEMS = [
"Q163875",
"Q191780",
"Q224414",
"Q4167836",
"Q4167410",
"Q4663903",
"Q11266439",
"Q13406463",
"Q15407973",
"Q18616576",
"Q19887878",
"Q22808320",
"Q23894233",
"Q33120876",
"Q42104522",
"Q47460393",
"Q64875536",
"Q66480449",
]
# TODO: add more cases from non-English WP's
# List of prefixes that refer to Wikipedia "file" pages
WP_FILE_NAMESPACE = ["Bestand", "File"]
# List of prefixes that refer to Wikipedia "category" pages
WP_CATEGORY_NAMESPACE = ["Kategori", "Category", "Categorie"]
# List of prefixes that refer to Wikipedia "meta" pages
# these will/should be matched ignoring case
WP_META_NAMESPACE = (
WP_FILE_NAMESPACE
+ WP_CATEGORY_NAMESPACE
+ [
"b",
"betawikiversity",
"Book",
"c",
"Commons",
"d",
"dbdump",
"download",
"Draft",
"Education",
"Foundation",
"Gadget",
"Gadget definition",
"Gebruiker",
"gerrit",
"Help",
"Image",
"Incubator",
"m",
"mail",
"mailarchive",
"media",
"MediaWiki",
"MediaWiki talk",
"Mediawikiwiki",
"MediaZilla",
"Meta",
"Metawikipedia",
"Module",
"mw",
"n",
"nost",
"oldwikisource",
"otrs",
"OTRSwiki",
"Overleg gebruiker",
"outreach",
"outreachwiki",
"Portal",
"phab",
"Phabricator",
"Project",
"q",
"quality",
"rev",
"s",
"spcom",
"Special",
"species",
"Strategy",
"sulutil",
"svn",
"Talk",
"Template",
"Template talk",
"Testwiki",
"ticket",
"TimedText",
"Toollabs",
"tools",
"tswiki",
"User",
"User talk",
"v",
"voy",
"w",
"Wikibooks",
"Wikidata",
"wikiHow",
"Wikinvest",
"wikilivres",
"Wikimedia",
"Wikinews",
"Wikipedia",
"Wikipedia talk",
"Wikiquote",
"Wikisource",
"Wikispecies",
"Wikitech",
"Wikiversity",
"Wikivoyage",
"wikt",
"wiktionary",
"wmf",
"wmania",
"WP",
]
)
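
These prefix lists are consumed later in this diff (see `wikipedia_processor`) by folding them into one case-insensitive namespace regex. A reduced sketch of that construction:

```python
import re

# abbreviated stand-in for the full WP_META_NAMESPACE list above
WP_META_NAMESPACE = ["File", "Category", "Help"]

# match interwiki prefixes like `en:` or `:fr:`, or any namespace prefix, ignoring case
ns_regex = re.compile(r":?([a-z][a-z]|" + "|".join(WP_META_NAMESPACE) + "):", re.IGNORECASE)

assert ns_regex.match("Category:Physics")
assert not ns_regex.match("Douglas Adams")
```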


@@ -18,11 +18,12 @@ from pathlib import Path
import plac
from bin.wiki_entity_linking import wikipedia_processor as wp, wikidata_processor as wd
from bin.wiki_entity_linking import wiki_io as io
from bin.wiki_entity_linking import kb_creator
from bin.wiki_entity_linking import training_set_creator
from bin.wiki_entity_linking import TRAINING_DATA_FILE, KB_FILE, ENTITY_DESCR_PATH, KB_MODEL_DIR, LOG_FORMAT
from bin.wiki_entity_linking import ENTITY_FREQ_PATH, PRIOR_PROB_PATH, ENTITY_DEFS_PATH
from bin.wiki_entity_linking import ENTITY_FREQ_PATH, PRIOR_PROB_PATH, ENTITY_DEFS_PATH, ENTITY_ALIAS_PATH
import spacy
from bin.wiki_entity_linking.kb_creator import read_kb
logger = logging.getLogger(__name__)
@@ -39,9 +40,11 @@ logger = logging.getLogger(__name__)
loc_prior_prob=("Location to file with prior probabilities", "option", "p", Path),
loc_entity_defs=("Location to file with entity definitions", "option", "d", Path),
loc_entity_desc=("Location to file with entity descriptions", "option", "s", Path),
descriptions_from_wikipedia=("Flag for using wp descriptions not wd", "flag", "wp"),
limit=("Optional threshold to limit lines read from dumps", "option", "l", int),
lang=("Optional language for which to get wikidata titles. Defaults to 'en'", "option", "la", str),
descr_from_wp=("Flag for using wp descriptions not wd", "flag", "wp"),
limit_prior=("Threshold to limit lines read from WP for prior probabilities", "option", "lp", int),
limit_train=("Threshold to limit lines read from WP for training set", "option", "lt", int),
limit_wd=("Threshold to limit lines read from WD", "option", "lw", int),
lang=("Optional language for which to get Wikidata titles. Defaults to 'en'", "option", "la", str),
)
def main(
wd_json,
@@ -54,13 +57,16 @@ def main(
entity_vector_length=64,
loc_prior_prob=None,
loc_entity_defs=None,
loc_entity_alias=None,
loc_entity_desc=None,
descriptions_from_wikipedia=False,
limit=None,
descr_from_wp=False,
limit_prior=None,
limit_train=None,
limit_wd=None,
lang="en",
):
entity_defs_path = loc_entity_defs if loc_entity_defs else output_dir / ENTITY_DEFS_PATH
entity_alias_path = loc_entity_alias if loc_entity_alias else output_dir / ENTITY_ALIAS_PATH
entity_descr_path = loc_entity_desc if loc_entity_desc else output_dir / ENTITY_DESCR_PATH
entity_freq_path = output_dir / ENTITY_FREQ_PATH
prior_prob_path = loc_prior_prob if loc_prior_prob else output_dir / PRIOR_PROB_PATH
@@ -69,15 +75,12 @@ def main(
logger.info("Creating KB with Wikipedia and WikiData")
if limit is not None:
logger.warning("Warning: reading only {} lines of Wikipedia/Wikidata dumps.".format(limit))
# STEP 0: set up IO
if not output_dir.exists():
output_dir.mkdir(parents=True)
# STEP 1: create the NLP object
logger.info("STEP 1: Loading model {}".format(model))
# STEP 1: Load the NLP object
logger.info("STEP 1: Loading NLP model {}".format(model))
nlp = spacy.load(model)
# check the length of the nlp vectors
@@ -90,62 +93,83 @@ def main(
# STEP 2: create prior probabilities from WP
if not prior_prob_path.exists():
# It takes about 2h to process 1000M lines of Wikipedia XML dump
logger.info("STEP 2: writing prior probabilities to {}".format(prior_prob_path))
wp.read_prior_probs(wp_xml, prior_prob_path, limit=limit)
logger.info("STEP 2: reading prior probabilities from {}".format(prior_prob_path))
logger.info("STEP 2: Writing prior probabilities to {}".format(prior_prob_path))
if limit_prior is not None:
logger.warning("Warning: reading only {} lines of Wikipedia dump".format(limit_prior))
wp.read_prior_probs(wp_xml, prior_prob_path, limit=limit_prior)
else:
logger.info("STEP 2: Reading prior probabilities from {}".format(prior_prob_path))
# STEP 3: deduce entity frequencies from WP (takes only a few minutes)
logger.info("STEP 3: calculating entity frequencies")
wp.write_entity_counts(prior_prob_path, entity_freq_path, to_print=False)
# STEP 3: calculate entity frequencies
if not entity_freq_path.exists():
logger.info("STEP 3: Calculating and writing entity frequencies to {}".format(entity_freq_path))
io.write_entity_to_count(prior_prob_path, entity_freq_path)
else:
logger.info("STEP 3: Reading entity frequencies from {}".format(entity_freq_path))
# STEP 4: reading definitions and (possibly) descriptions from WikiData or from file
message = " and descriptions" if not descriptions_from_wikipedia else ""
if (not entity_defs_path.exists()) or (not descriptions_from_wikipedia and not entity_descr_path.exists()):
if (not entity_defs_path.exists()) or (not descr_from_wp and not entity_descr_path.exists()):
# It takes about 10h to process 55M lines of Wikidata JSON dump
logger.info("STEP 4: parsing wikidata for entity definitions" + message)
title_to_id, id_to_descr = wd.read_wikidata_entities_json(
logger.info("STEP 4: Parsing and writing Wikidata entity definitions to {}".format(entity_defs_path))
if limit_wd is not None:
logger.warning("Warning: reading only {} lines of Wikidata dump".format(limit_wd))
title_to_id, id_to_descr, id_to_alias = wd.read_wikidata_entities_json(
wd_json,
limit,
limit_wd,
to_print=False,
lang=lang,
parse_descriptions=(not descriptions_from_wikipedia),
parse_descr=(not descr_from_wp),
)
wd.write_entity_files(entity_defs_path, title_to_id)
if not descriptions_from_wikipedia:
wd.write_entity_description_files(entity_descr_path, id_to_descr)
logger.info("STEP 4: read entity definitions" + message)
io.write_title_to_id(entity_defs_path, title_to_id)
# STEP 5: Getting gold entities from wikipedia
message = " and descriptions" if descriptions_from_wikipedia else ""
if (not training_entities_path.exists()) or (descriptions_from_wikipedia and not entity_descr_path.exists()):
logger.info("STEP 5: parsing wikipedia for gold entities" + message)
training_set_creator.create_training_examples_and_descriptions(
wp_xml,
entity_defs_path,
entity_descr_path,
training_entities_path,
parse_descriptions=descriptions_from_wikipedia,
limit=limit,
)
logger.info("STEP 5: read gold entities" + message)
logger.info("STEP 4b: Writing Wikidata entity aliases to {}".format(entity_alias_path))
io.write_id_to_alias(entity_alias_path, id_to_alias)
if not descr_from_wp:
logger.info("STEP 4c: Writing Wikidata entity descriptions to {}".format(entity_descr_path))
io.write_id_to_descr(entity_descr_path, id_to_descr)
else:
logger.info("STEP 4: Reading entity definitions from {}".format(entity_defs_path))
logger.info("STEP 4b: Reading entity aliases from {}".format(entity_alias_path))
if not descr_from_wp:
logger.info("STEP 4c: Reading entity descriptions from {}".format(entity_descr_path))
# STEP 5: Getting gold entities from Wikipedia
if (not training_entities_path.exists()) or (descr_from_wp and not entity_descr_path.exists()):
logger.info("STEP 5: Parsing and writing Wikipedia gold entities to {}".format(training_entities_path))
if limit_train is not None:
logger.warning("Warning: reading only {} lines of Wikipedia dump".format(limit_train))
wp.create_training_and_desc(wp_xml, entity_defs_path, entity_descr_path,
training_entities_path, descr_from_wp, limit_train)
if descr_from_wp:
logger.info("STEP 5b: Parsing and writing Wikipedia descriptions to {}".format(entity_descr_path))
else:
logger.info("STEP 5: Reading gold entities from {}".format(training_entities_path))
if descr_from_wp:
logger.info("STEP 5b: Reading entity descriptions from {}".format(entity_descr_path))
# STEP 6: creating the actual KB
# It takes ca. 30 minutes to pretrain the entity embeddings
logger.info("STEP 6: creating the KB at {}".format(kb_path))
kb = kb_creator.create_kb(
nlp=nlp,
max_entities_per_alias=max_per_alias,
min_entity_freq=min_freq,
min_occ=min_pair,
entity_def_input=entity_defs_path,
entity_descr_path=entity_descr_path,
count_input=entity_freq_path,
prior_prob_input=prior_prob_path,
entity_vector_length=entity_vector_length,
)
kb.dump(kb_path)
nlp.to_disk(output_dir / KB_MODEL_DIR)
if not kb_path.exists():
logger.info("STEP 6: Creating the KB at {}".format(kb_path))
kb = kb_creator.create_kb(
nlp=nlp,
max_entities_per_alias=max_per_alias,
min_entity_freq=min_freq,
min_occ=min_pair,
entity_def_path=entity_defs_path,
entity_descr_path=entity_descr_path,
entity_alias_path=entity_alias_path,
entity_freq_path=entity_freq_path,
prior_prob_path=prior_prob_path,
entity_vector_length=entity_vector_length,
)
kb.dump(kb_path)
logger.info("kb entities: {}".format(kb.get_size_entities()))
logger.info("kb aliases: {}".format(kb.get_size_aliases()))
nlp.to_disk(output_dir / KB_MODEL_DIR)
else:
logger.info("STEP 6: KB already exists at {}".format(kb_path))
logger.info("Done!")


@@ -1,40 +1,52 @@
# coding: utf-8
from __future__ import unicode_literals
import gzip
import bz2
import json
import logging
import datetime
from bin.wiki_entity_linking.wiki_namespaces import WD_META_ITEMS
logger = logging.getLogger(__name__)
def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False, lang="en", parse_descriptions=True):
# Read the JSON wiki data and parse out the entities. Takes about 7u30 to parse 55M lines.
def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False, lang="en", parse_descr=True):
# Read the JSON wiki data and parse out the entities. Takes about 7-10h to parse 55M lines.
# get latest-all.json.bz2 from https://dumps.wikimedia.org/wikidatawiki/entities/
site_filter = '{}wiki'.format(lang)
# properties filter (currently disabled to get ALL data)
prop_filter = dict()
# prop_filter = {'P31': {'Q5', 'Q15632617'}} # currently defined as OR: one property suffices to be selected
# filter: currently defined as OR: one hit suffices to be removed from further processing
exclude_list = WD_META_ITEMS
# punctuation
exclude_list.extend(["Q1383557", "Q10617810"])
# letters etc
exclude_list.extend(["Q188725", "Q19776628", "Q3841820", "Q17907810", "Q9788", "Q9398093"])
neg_prop_filter = {
'P31': exclude_list, # instance of
'P279': exclude_list # subclass
}
title_to_id = dict()
id_to_descr = dict()
id_to_alias = dict()
# parse appropriate fields - depending on what we need in the KB
parse_properties = False
parse_sitelinks = True
parse_labels = False
parse_aliases = False
parse_claims = False
parse_aliases = True
parse_claims = True
with gzip.open(wikidata_file, mode='rb') as file:
with bz2.open(wikidata_file, mode='rb') as file:
for cnt, line in enumerate(file):
if limit and cnt >= limit:
break
if cnt % 500000 == 0:
logger.info("processed {} lines of WikiData dump".format(cnt))
if cnt % 500000 == 0 and cnt > 0:
logger.info("processed {} lines of WikiData JSON dump".format(cnt))
clean_line = line.strip()
if clean_line.endswith(b","):
clean_line = clean_line[:-1]
@@ -43,13 +55,11 @@ def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False, lang=
entry_type = obj["type"]
if entry_type == "item":
# filtering records on their properties (currently disabled to get ALL data)
# keep = False
keep = True
claims = obj["claims"]
if parse_claims:
for prop, value_set in prop_filter.items():
for prop, value_set in neg_prop_filter.items():
claim_property = claims.get(prop, None)
if claim_property:
for cp in claim_property:
@@ -61,7 +71,7 @@ def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False, lang=
)
cp_rank = cp["rank"]
if cp_rank != "deprecated" and cp_id in value_set:
keep = True
keep = False
if keep:
unique_id = obj["id"]
@@ -108,7 +118,7 @@ def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False, lang=
"label (" + lang + "):", lang_label["value"]
)
if found_link and parse_descriptions:
if found_link and parse_descr:
descriptions = obj["descriptions"]
if descriptions:
lang_descr = descriptions.get(lang, None)
@@ -130,22 +140,15 @@ def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False, lang=
print(
"alias (" + lang + "):", item["value"]
)
alias_list = id_to_alias.get(unique_id, [])
alias_list.append(item["value"])
id_to_alias[unique_id] = alias_list
if to_print:
print()
return title_to_id, id_to_descr
# log final number of lines processed
logger.info("Finished. Processed {} lines of WikiData JSON dump".format(cnt))
return title_to_id, id_to_descr, id_to_alias
def write_entity_files(entity_def_output, title_to_id):
with entity_def_output.open("w", encoding="utf8") as id_file:
id_file.write("WP_title" + "|" + "WD_id" + "\n")
for title, qid in title_to_id.items():
id_file.write(title + "|" + str(qid) + "\n")
def write_entity_description_files(entity_descr_output, id_to_descr):
with entity_descr_output.open("w", encoding="utf8") as descr_file:
descr_file.write("WD_id" + "|" + "description" + "\n")
for qid, descr in id_to_descr.items():
descr_file.write(str(qid) + "|" + descr + "\n")
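
The Wikidata dump itself is one huge JSON array with one entity record per line, which is why each line is stripped of its trailing comma before `json.loads`. A self-contained sketch of that per-line handling on a toy record:

```python
import json

line = b'{"type": "item", "id": "Q42", "sitelinks": {"enwiki": {"title": "Douglas Adams"}}},\n'

clean_line = line.strip()
if clean_line.endswith(b","):
    clean_line = clean_line[:-1]  # drop the JSON-array separator
obj = json.loads(clean_line)
if obj["type"] == "item":
    # the site filter above is "{lang}wiki", e.g. "enwiki"
    title = obj["sitelinks"].get("enwiki", {}).get("title")
    print(obj["id"], title)  # Q42 Douglas Adams
```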


@@ -6,19 +6,19 @@ as created by the script `wikidata_create_kb`.
For the Wikipedia dump: get enwiki-latest-pages-articles-multistream.xml.bz2
from https://dumps.wikimedia.org/enwiki/latest/
"""
from __future__ import unicode_literals
import random
import logging
import spacy
from pathlib import Path
import plac
from bin.wiki_entity_linking import training_set_creator
from bin.wiki_entity_linking import wikipedia_processor
from bin.wiki_entity_linking import TRAINING_DATA_FILE, KB_MODEL_DIR, KB_FILE, LOG_FORMAT, OUTPUT_MODEL_DIR
from bin.wiki_entity_linking.entity_linker_evaluation import measure_performance, measure_baselines
from bin.wiki_entity_linking.kb_creator import read_nlp_kb
from bin.wiki_entity_linking.entity_linker_evaluation import measure_performance
from bin.wiki_entity_linking.kb_creator import read_kb
from spacy.util import minibatch, compounding
@@ -35,6 +35,7 @@ logger = logging.getLogger(__name__)
l2=("L2 regularization", "option", "r", float),
train_inst=("# training instances (default 90% of all)", "option", "t", int),
dev_inst=("# test instances (default 10% of all)", "option", "d", int),
labels_discard=("NER labels to discard (default None)", "option", "l", str),
)
def main(
dir_kb,
@@ -46,13 +47,14 @@ def main(
l2=1e-6,
train_inst=None,
dev_inst=None,
labels_discard=None
):
logger.info("Creating Entity Linker with Wikipedia and WikiData")
output_dir = Path(output_dir) if output_dir else dir_kb
training_path = loc_training if loc_training else output_dir / TRAINING_DATA_FILE
training_path = loc_training if loc_training else dir_kb / TRAINING_DATA_FILE
nlp_dir = dir_kb / KB_MODEL_DIR
kb_path = output_dir / KB_FILE
kb_path = dir_kb / KB_FILE
nlp_output_dir = output_dir / OUTPUT_MODEL_DIR
# STEP 0: set up IO
@@ -60,38 +62,47 @@ def main(
output_dir.mkdir()
# STEP 1 : load the NLP object
logger.info("STEP 1: loading model from {}".format(nlp_dir))
nlp, kb = read_nlp_kb(nlp_dir, kb_path)
logger.info("STEP 1a: Loading model from {}".format(nlp_dir))
nlp = spacy.load(nlp_dir)
logger.info("STEP 1b: Loading KB from {}".format(kb_path))
kb = read_kb(nlp, kb_path)
# check that there is a NER component in the pipeline
if "ner" not in nlp.pipe_names:
raise ValueError("The `nlp` object should have a pretrained `ner` component.")
# STEP 2: create a training dataset from WP
logger.info("STEP 2: reading training dataset from {}".format(training_path))
# STEP 2: read the training dataset previously created from WP
logger.info("STEP 2: Reading training dataset from {}".format(training_path))
train_data = training_set_creator.read_training(
if labels_discard:
labels_discard = [x.strip() for x in labels_discard.split(",")]
logger.info("Discarding {} NER types: {}".format(len(labels_discard), labels_discard))
train_data = wikipedia_processor.read_training(
nlp=nlp,
entity_file_path=training_path,
dev=False,
limit=train_inst,
kb=kb,
labels_discard=labels_discard
)
# for testing, get all pos instances, whether or not they are in the kb
dev_data = training_set_creator.read_training(
# for testing, get all pos instances (independently of KB)
dev_data = wikipedia_processor.read_training(
nlp=nlp,
entity_file_path=training_path,
dev=True,
limit=dev_inst,
kb=kb,
kb=None,
labels_discard=labels_discard
)
# STEP 3: create and train the entity linking pipe
logger.info("STEP 3: training Entity Linking pipe")
# STEP 3: create and train an entity linking pipe
logger.info("STEP 3: Creating and training an Entity Linking pipe")
el_pipe = nlp.create_pipe(
name="entity_linker", config={"pretrained_vectors": nlp.vocab.vectors.name}
name="entity_linker", config={"pretrained_vectors": nlp.vocab.vectors.name,
"labels_discard": labels_discard}
)
el_pipe.set_kb(kb)
nlp.add_pipe(el_pipe, last=True)
@@ -105,14 +116,9 @@ def main(
logger.info("Training on {} articles".format(len(train_data)))
logger.info("Dev testing on {} articles".format(len(dev_data)))
dev_baseline_accuracies = measure_baselines(
dev_data, kb
)
# baseline performance on dev data
logger.info("Dev Baseline Accuracies:")
logger.info(dev_baseline_accuracies.report_accuracy("random"))
logger.info(dev_baseline_accuracies.report_accuracy("prior"))
logger.info(dev_baseline_accuracies.report_accuracy("oracle"))
measure_performance(dev_data, kb, el_pipe, baseline=True, context=False)
for itn in range(epochs):
random.shuffle(train_data)
@@ -136,18 +142,18 @@ def main(
logger.error("Error updating batch:" + str(e))
if batchnr > 0:
logging.info("Epoch {}, train loss {}".format(itn, round(losses["entity_linker"] / batchnr, 2)))
measure_performance(dev_data, kb, el_pipe)
measure_performance(dev_data, kb, el_pipe, baseline=False, context=True)
# STEP 4: measure the performance of our trained pipe on an independent dev set
logger.info("STEP 4: performance measurement of Entity Linking pipe")
logger.info("STEP 4: Final performance measurement of Entity Linking pipe")
measure_performance(dev_data, kb, el_pipe)
# STEP 5: apply the EL pipe on a toy example
logger.info("STEP 5: applying Entity Linking to toy example")
logger.info("STEP 5: Applying Entity Linking to toy example")
run_el_toy_example(nlp=nlp)
if output_dir:
# STEP 6: write the NLP pipeline (including entity linker) to file
# STEP 6: write the NLP pipeline (now including an EL model) to file
logger.info("STEP 6: Writing trained NLP to {}".format(nlp_output_dir))
nlp.to_disk(nlp_output_dir)
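
The epoch loop above batches the training data with spaCy's `minibatch` and `compounding` utilities, which yield batches of geometrically growing size up to a cap. A standalone sketch of just that batching behavior:

```python
from spacy.util import minibatch, compounding

data = list(range(20))  # stand-in for the (doc, gold) training pairs
for batch in minibatch(data, size=compounding(1.0, 8.0, 2.0)):
    # batch sizes grow 1, 2, 4, 8, then stay capped; the final batch holds the 5 leftovers
    print(len(batch))
```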


@@ -3,147 +3,104 @@ from __future__ import unicode_literals
import re
import bz2
import csv
import datetime
import logging
import random
import json
from bin.wiki_entity_linking import LOG_FORMAT
from functools import partial
from spacy.gold import GoldParse
from bin.wiki_entity_linking import wiki_io as io
from bin.wiki_entity_linking.wiki_namespaces import (
WP_META_NAMESPACE,
WP_FILE_NAMESPACE,
WP_CATEGORY_NAMESPACE,
)
"""
Process a Wikipedia dump to calculate entity frequencies and prior probabilities in combination with certain mentions.
Write these results to file for downstream KB and training data generation.
Process Wikipedia interlinks to generate a training dataset for the EL algorithm.
"""
ENTITY_FILE = "gold_entities.csv"
map_alias_to_link = dict()
logger = logging.getLogger(__name__)
# these will/should be matched ignoring case
wiki_namespaces = [
"b",
"betawikiversity",
"Book",
"c",
"Category",
"Commons",
"d",
"dbdump",
"download",
"Draft",
"Education",
"Foundation",
"Gadget",
"Gadget definition",
"gerrit",
"File",
"Help",
"Image",
"Incubator",
"m",
"mail",
"mailarchive",
"media",
"MediaWiki",
"MediaWiki talk",
"Mediawikiwiki",
"MediaZilla",
"Meta",
"Metawikipedia",
"Module",
"mw",
"n",
"nost",
"oldwikisource",
"outreach",
"outreachwiki",
"otrs",
"OTRSwiki",
"Portal",
"phab",
"Phabricator",
"Project",
"q",
"quality",
"rev",
"s",
"spcom",
"Special",
"species",
"Strategy",
"sulutil",
"svn",
"Talk",
"Template",
"Template talk",
"Testwiki",
"ticket",
"TimedText",
"Toollabs",
"tools",
"tswiki",
"User",
"User talk",
"v",
"voy",
"w",
"Wikibooks",
"Wikidata",
"wikiHow",
"Wikinvest",
"wikilivres",
"Wikimedia",
"Wikinews",
"Wikipedia",
"Wikipedia talk",
"Wikiquote",
"Wikisource",
"Wikispecies",
"Wikitech",
"Wikiversity",
"Wikivoyage",
"wikt",
"wiktionary",
"wmf",
"wmania",
"WP",
]
title_regex = re.compile(r"(?<=<title>).*(?=</title>)")
id_regex = re.compile(r"(?<=<id>)\d*(?=</id>)")
text_regex = re.compile(r"(?<=<text xml:space=\"preserve\">).*(?=</text)")
info_regex = re.compile(r"{[^{]*?}")
html_regex = re.compile(r"&lt;!--[^-]*--&gt;")
ref_regex = re.compile(r"&lt;ref.*?&gt;") # non-greedy
ref_2_regex = re.compile(r"&lt;/ref.*?&gt;") # non-greedy
# find the links
link_regex = re.compile(r"\[\[[^\[\]]*\]\]")
# match on interwiki links, e.g. `en:` or `:fr:`
ns_regex = r":?" + "[a-z][a-z]" + ":"
# match on Namespace: optionally preceded by a :
for ns in wiki_namespaces:
for ns in WP_META_NAMESPACE:
ns_regex += "|" + ":?" + ns + ":"
ns_regex = re.compile(ns_regex, re.IGNORECASE)
files = r""
for f in WP_FILE_NAMESPACE:
files += "\[\[" + f + ":[^[\]]+]]" + "|"
files = files[0 : len(files) - 1]
file_regex = re.compile(files)
cats = r""
for c in WP_CATEGORY_NAMESPACE:
cats += "\[\[" + c + ":[^\[]*]]" + "|"
cats = cats[0 : len(cats) - 1]
category_regex = re.compile(cats)
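For reference, a quick sanity check of what these compiled patterns are meant to match. The sample strings are illustrative and assume the usual members of the namespace lists:

# illustrative only; assumes WP_CATEGORY_NAMESPACE contains "Category"
# and WP_FILE_NAMESPACE contains "File"
assert ns_regex.match("en:Main Page")                         # interwiki prefix
assert category_regex.search("[[Category:Physics]]")          # category statement
assert file_regex.search("[[File:Example.jpg|thumb|200px]]")  # file statement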
def read_prior_probs(wikipedia_input, prior_prob_output, limit=None):
"""
Read the XML wikipedia data and parse out intra-wiki links to estimate prior probabilities.
The full file takes about 2h to parse 1100M lines.
It works relatively fast because it runs line by line, regardless of which article an intra-wiki link comes from.
The full file takes about 2-3h to parse 1100M lines.
It works relatively fast because it runs line by line, regardless of which article an intra-wiki link comes from,
though dev test articles are excluded in order not to get an artificially strong baseline.
"""
cnt = 0
read_id = False
current_article_id = None
with bz2.open(wikipedia_input, mode="rb") as file:
line = file.readline()
cnt = 0
while line and (not limit or cnt < limit):
if cnt % 25000000 == 0:
if cnt % 25000000 == 0 and cnt > 0:
logger.info("processed {} lines of Wikipedia XML dump".format(cnt))
clean_line = line.strip().decode("utf-8")
aliases, entities, normalizations = get_wp_links(clean_line)
for alias, entity, norm in zip(aliases, entities, normalizations):
_store_alias(alias, entity, normalize_alias=norm, normalize_entity=True)
_store_alias(alias, entity, normalize_alias=norm, normalize_entity=True)
# we attempt to read the article's ID (but not the revision or contributor ID)
if "<revision>" in clean_line or "<contributor>" in clean_line:
read_id = False
if "<page>" in clean_line:
read_id = True
if read_id:
ids = id_regex.search(clean_line)
if ids:
current_article_id = ids[0]
# only processing prior probabilities from true training (non-dev) articles
if not is_dev(current_article_id):
aliases, entities, normalizations = get_wp_links(clean_line)
for alias, entity, norm in zip(aliases, entities, normalizations):
_store_alias(
alias, entity, normalize_alias=norm, normalize_entity=True
)
line = file.readline()
cnt += 1
logger.info("processed {} lines of Wikipedia XML dump".format(cnt))
logger.info("Finished. processed {} lines of Wikipedia XML dump".format(cnt))
# write all aliases and their entities and count occurrences to file
with prior_prob_output.open("w", encoding="utf8") as outputfile:
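The rest of the function (truncated in this hunk) writes one alias|count|entity line per pair. Downstream, the prior probability of an entity given an alias follows directly from these counts; a minimal sketch of that estimate, with made-up numbers:

# prior probability = link count of (alias, entity) / total count of the alias
from collections import Counter

link_counts = Counter({("Douglas", "Q42"): 8, ("Douglas", "Q1235"): 2})
alias_totals = Counter()
for (alias, entity), count in link_counts.items():
    alias_totals[alias] += count

priors = {pair: count / alias_totals[pair[0]] for pair, count in link_counts.items()}
assert priors[("Douglas", "Q42")] == 0.8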
@ -182,7 +139,7 @@ def get_wp_links(text):
match = match[2:][:-2].replace("_", " ").strip()
if ns_regex.match(match):
pass # ignore namespaces at the beginning of the string
pass # ignore the entity if it points to a "meta" page
# this is a simple [[link]], with the alias the same as the mention
elif "|" not in match:
@ -218,47 +175,382 @@ def _capitalize_first(text):
return result
def write_entity_counts(prior_prob_input, count_output, to_print=False):
# Write entity counts for quick access later
entity_to_count = dict()
total_count = 0
with prior_prob_input.open("r", encoding="utf8") as prior_file:
# skip header
prior_file.readline()
line = prior_file.readline()
while line:
splits = line.replace("\n", "").split(sep="|")
# alias = splits[0]
count = int(splits[1])
entity = splits[2]
current_count = entity_to_count.get(entity, 0)
entity_to_count[entity] = current_count + count
total_count += count
line = prior_file.readline()
with count_output.open("w", encoding="utf8") as entity_file:
entity_file.write("entity" + "|" + "count" + "\n")
for entity, count in entity_to_count.items():
entity_file.write(entity + "|" + str(count) + "\n")
if to_print:
for entity, count in entity_to_count.items():
print("Entity count:", entity, count)
print("Total count:", total_count)
def create_training_and_desc(
wp_input, def_input, desc_output, training_output, parse_desc, limit=None
):
wp_to_id = io.read_title_to_id(def_input)
_process_wikipedia_texts(
wp_input, wp_to_id, desc_output, training_output, parse_desc, limit
)
def get_all_frequencies(count_input):
entity_to_count = dict()
with count_input.open("r", encoding="utf8") as csvfile:
csvreader = csv.reader(csvfile, delimiter="|")
# skip header
next(csvreader)
for row in csvreader:
entity_to_count[row[0]] = int(row[1])
def _process_wikipedia_texts(
wikipedia_input, wp_to_id, output, training_output, parse_descriptions, limit=None
):
"""
Read the XML wikipedia data to parse out training data:
raw text data + positive instances
"""
return entity_to_count
read_ids = set()
with output.open("a", encoding="utf8") as descr_file, training_output.open(
"w", encoding="utf8"
) as entity_file:
if parse_descriptions:
_write_training_description(descr_file, "WD_id", "description")
with bz2.open(wikipedia_input, mode="rb") as file:
article_count = 0
article_text = ""
article_title = None
article_id = None
reading_text = False
reading_revision = False
for line in file:
clean_line = line.strip().decode("utf-8")
if clean_line == "<revision>":
reading_revision = True
elif clean_line == "</revision>":
reading_revision = False
# Start reading new page
if clean_line == "<page>":
article_text = ""
article_title = None
article_id = None
# finished reading this page
elif clean_line == "</page>":
if article_id:
clean_text, entities = _process_wp_text(
article_title, article_text, wp_to_id
)
if clean_text is not None and entities is not None:
_write_training_entities(
entity_file, article_id, clean_text, entities
)
if article_title in wp_to_id and parse_descriptions:
description = " ".join(
clean_text[:1000].split(" ")[:-1]
)
_write_training_description(
descr_file, wp_to_id[article_title], description
)
article_count += 1
if article_count % 10000 == 0 and article_count > 0:
logger.info(
"Processed {} articles".format(article_count)
)
if limit and article_count >= limit:
break
article_text = ""
article_title = None
article_id = None
reading_text = False
reading_revision = False
# start reading text within a page
if "<text" in clean_line:
reading_text = True
if reading_text:
article_text += " " + clean_line
# stop reading text within a page (we assume a new page doesn't start on the same line)
if "</text" in clean_line:
reading_text = False
# read the ID of this article (outside the revision portion of the document)
if not reading_revision:
ids = id_regex.search(clean_line)
if ids:
article_id = ids[0]
if article_id in read_ids:
logger.info("Found duplicate article ID {} {}".format(article_id, clean_line))  # This should never happen ...
read_ids.add(article_id)
# read the title of this article (outside the revision portion of the document)
if not reading_revision:
titles = title_regex.search(clean_line)
if titles:
article_title = titles[0].strip()
logger.info("Finished. Processed {} articles".format(article_count))
def _process_wp_text(article_title, article_text, wp_to_id):
# ignore meta Wikipedia pages
if ns_regex.match(article_title):
return None, None
# remove the text tags
text_search = text_regex.search(article_text)
if text_search is None:
return None, None
text = text_search.group(0)
# stop processing if this is a redirect page
if text.startswith("#REDIRECT"):
return None, None
# get the raw text without markup etc, keeping only interwiki links
clean_text, entities = _remove_links(_get_clean_wp_text(text), wp_to_id)
return clean_text, entities
def _get_clean_wp_text(article_text):
clean_text = article_text.strip()
# remove bolding & italic markup
clean_text = clean_text.replace("'''", "")
clean_text = clean_text.replace("''", "")
# remove nested {{info}} statements by removing the inner/smallest ones first and iterating
try_again = True
previous_length = len(clean_text)
while try_again:
clean_text = info_regex.sub(
"", clean_text
) # non-greedy match excluding a nested {
if len(clean_text) < previous_length:
try_again = True
else:
try_again = False
previous_length = len(clean_text)
# remove HTML comments
clean_text = html_regex.sub("", clean_text)
# remove Category and File statements
clean_text = category_regex.sub("", clean_text)
clean_text = file_regex.sub("", clean_text)
# remove multiple =
while "==" in clean_text:
clean_text = clean_text.replace("==", "=")
clean_text = clean_text.replace(". =", ".")
clean_text = clean_text.replace(" = ", ". ")
clean_text = clean_text.replace("= ", ".")
clean_text = clean_text.replace(" =", "")
# remove refs (non-greedy match)
clean_text = ref_regex.sub("", clean_text)
clean_text = ref_2_regex.sub("", clean_text)
# remove additional wikiformatting
clean_text = re.sub(r"&lt;blockquote&gt;", "", clean_text)
clean_text = re.sub(r"&lt;/blockquote&gt;", "", clean_text)
# change special characters back to normal ones
clean_text = clean_text.replace(r"&lt;", "<")
clean_text = clean_text.replace(r"&gt;", ">")
clean_text = clean_text.replace(r"&quot;", '"')
clean_text = clean_text.replace(r"&amp;nbsp;", " ")
clean_text = clean_text.replace(r"&amp;", "&")
# remove multiple spaces
while " " in clean_text:
clean_text = clean_text.replace(" ", " ")
return clean_text.strip()
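A rough illustration of the cleaning above (a sketch; real input is full article markup):

raw = "'''Paris''' is the capital of [[France]].&lt;!-- note --&gt;"
# bold markup and the HTML comment are stripped; the interwiki link is kept
# for _remove_links() to process
assert _get_clean_wp_text(raw) == "Paris is the capital of [[France]]."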
def _remove_links(clean_text, wp_to_id):
# read the text char by char to get the right offsets for the interwiki links
entities = []
final_text = ""
open_read = 0
reading_text = True
reading_entity = False
reading_mention = False
reading_special_case = False
entity_buffer = ""
mention_buffer = ""
for index, letter in enumerate(clean_text):
if letter == "[":
open_read += 1
elif letter == "]":
open_read -= 1
elif letter == "|":
if reading_text:
final_text += letter
# switch from reading entity to mention in the [[entity|mention]] pattern
elif reading_entity:
reading_text = False
reading_entity = False
reading_mention = True
else:
reading_special_case = True
else:
if reading_entity:
entity_buffer += letter
elif reading_mention:
mention_buffer += letter
elif reading_text:
final_text += letter
else:
raise ValueError("Not sure at point", clean_text[index - 2 : index + 2])
if open_read > 2:
reading_special_case = True
if open_read == 2 and reading_text:
reading_text = False
reading_entity = True
reading_mention = False
# we just finished reading an entity
if open_read == 0 and not reading_text:
if "#" in entity_buffer or entity_buffer.startswith(":"):
reading_special_case = True
# Ignore cases with nested structures like File: handles etc
if not reading_special_case:
if not mention_buffer:
mention_buffer = entity_buffer
start = len(final_text)
end = start + len(mention_buffer)
qid = wp_to_id.get(entity_buffer, None)
if qid:
entities.append((mention_buffer, qid, start, end))
final_text += mention_buffer
entity_buffer = ""
mention_buffer = ""
reading_text = True
reading_entity = False
reading_mention = False
reading_special_case = False
return final_text, entities
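A small worked example of this state machine (the QID mapping is hypothetical):

wp_to_id = {"Paris": "Q90"}  # hypothetical title-to-QID mapping
text, ents = _remove_links("I love [[Paris]] in spring", wp_to_id)
assert text == "I love Paris in spring"
assert ents == [("Paris", "Q90", 7, 12)]  # (mention, QID, start_char, end_char)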
def _write_training_description(outputfile, qid, description):
if description is not None:
line = str(qid) + "|" + description + "\n"
outputfile.write(line)
def _write_training_entities(outputfile, article_id, clean_text, entities):
entities_data = [
{"alias": ent[0], "entity": ent[1], "start": ent[2], "end": ent[3]}
for ent in entities
]
line = (
json.dumps(
{
"article_id": article_id,
"clean_text": clean_text,
"entities": entities_data,
},
ensure_ascii=False,
)
+ "\n"
)
outputfile.write(line)
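Each call appends one self-contained JSON line; a record for the example above would look roughly like this (values made up):

# {"article_id": "12", "clean_text": "I love Paris in spring",
#  "entities": [{"alias": "Paris", "entity": "Q90", "start": 7, "end": 12}]}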
def read_training(nlp, entity_file_path, dev, limit, kb, labels_discard=None):
""" This method provides training examples that correspond to the entity annotations found by the nlp object.
For training, it will include both positive and negative examples by using the candidate generator from the kb.
For testing (kb=None), it will include all positive examples only."""
from tqdm import tqdm
if not labels_discard:
labels_discard = []
data = []
num_entities = 0
get_gold_parse = partial(
_get_gold_parse, dev=dev, kb=kb, labels_discard=labels_discard
)
logger.info(
"Reading {} data with limit {}".format("dev" if dev else "train", limit)
)
with entity_file_path.open("r", encoding="utf8") as file:
with tqdm(total=limit, leave=False) as pbar:
for i, line in enumerate(file):
example = json.loads(line)
article_id = example["article_id"]
clean_text = example["clean_text"]
entities = example["entities"]
if dev != is_dev(article_id) or not is_valid_article(clean_text):
continue
doc = nlp(clean_text)
gold = get_gold_parse(doc, entities)
if gold and len(gold.links) > 0:
data.append((doc, gold))
num_entities += len(gold.links)
pbar.update(len(gold.links))
if limit and num_entities >= limit:
break
logger.info("Read {} entities in {} articles".format(num_entities, len(data)))
return data
def _get_gold_parse(doc, entities, dev, kb, labels_discard):
gold_entities = {}
tagged_ent_positions = {
(ent.start_char, ent.end_char): ent
for ent in doc.ents
if ent.label_ not in labels_discard
}
for entity in entities:
entity_id = entity["entity"]
alias = entity["alias"]
start = entity["start"]
end = entity["end"]
candidate_ids = []
if kb and not dev:
candidates = kb.get_candidates(alias)
candidate_ids = [cand.entity_ for cand in candidates]
tagged_ent = tagged_ent_positions.get((start, end), None)
if tagged_ent:
# TODO: check that alias == doc.text[start:end]
should_add_ent = (dev or entity_id in candidate_ids) and is_valid_sentence(
tagged_ent.sent.text
)
if should_add_ent:
value_by_id = {entity_id: 1.0}
if not dev:
random.shuffle(candidate_ids)
value_by_id.update(
{kb_id: 0.0 for kb_id in candidate_ids if kb_id != entity_id}
)
gold_entities[(start, end)] = value_by_id
return GoldParse(doc, links=gold_entities)
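The result is a GoldParse whose links map each mention offset to candidate QIDs scored 1.0 (gold) or 0.0 (negative candidate), e.g.:

# hypothetical outcome for one mention with gold entity "Q90" and one
# negative candidate "Q42" drawn from the KB:
# gold.links == {(7, 12): {"Q90": 1.0, "Q42": 0.0}}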
def is_dev(article_id):
if not article_id:
return False
return article_id.endswith("3")
def is_valid_article(doc_text):
# custom length cut-off
return 10 < len(doc_text) < 30000
def is_valid_sentence(sent_text):
if not 10 < len(sent_text) < 3000:
# custom length cut-off
return False
if sent_text.strip().startswith("*") or sent_text.strip().startswith("#"):
# remove 'enumeration' sentences (occurs often on Wikipedia)
return False
return True
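The dev split is deterministic: any article whose ID ends in "3" is held out, i.e. roughly 10% of articles. A quick check of that rate:

assert sum(is_dev(str(i)) for i in range(1000)) == 100  # one in ten IDs ends in "3"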


@ -7,7 +7,7 @@ dependency tree to find the noun phrase they are referring to for example:
$9.4 million --> Net income.
Compatible with: spaCy v2.0.0+
Last tested with: v2.1.0
Last tested with: v2.2.1
"""
from __future__ import unicode_literals, print_function
@ -38,14 +38,17 @@ def main(model="en_core_web_sm"):
def filter_spans(spans):
# Filter a sequence of spans so they don't contain overlaps
get_sort_key = lambda span: (span.end - span.start, span.start)
# For spaCy 2.1.4+: this function is available as spacy.util.filter_spans()
get_sort_key = lambda span: (span.end - span.start, -span.start)
sorted_spans = sorted(spans, key=get_sort_key, reverse=True)
result = []
seen_tokens = set()
for span in sorted_spans:
# Check for end - 1 here because boundaries are inclusive
if span.start not in seen_tokens and span.end - 1 not in seen_tokens:
result.append(span)
seen_tokens.update(range(span.start, span.end))
seen_tokens.update(range(span.start, span.end))
result = sorted(result, key=lambda span: span.start)
return result
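A minimal usage sketch of the fixed behaviour (longest span wins; ties go to the earliest start):

from spacy.tokens import Doc
from spacy.vocab import Vocab

doc = Doc(Vocab(), words=["I", "like", "New", "York", "City", "."])
spans = [doc[2:4], doc[2:5], doc[3:5]]  # overlapping candidates
filtered = filter_spans(spans)
assert len(filtered) == 1
assert filtered[0].start == 2 and filtered[0].end == 5  # the longest span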


@ -91,8 +91,8 @@ def demo(shape):
nlp = spacy.load("en_vectors_web_lg")
nlp.add_pipe(KerasSimilarityShim.load(nlp.path / "similarity", nlp, shape[0]))
doc1 = nlp(u"The king of France is bald.")
doc2 = nlp(u"France has no king.")
doc1 = nlp("The king of France is bald.")
doc2 = nlp("France has no king.")
print("Sentence 1:", doc1)
print("Sentence 2:", doc2)


@ -8,7 +8,7 @@
{
"tokens": [
{
"head": 4,
"head": 44,
"dep": "prep",
"tag": "IN",
"orth": "In",


@ -11,7 +11,7 @@ numpy>=1.15.0
requests>=2.13.0,<3.0.0
plac<1.0.0,>=0.9.6
pathlib==1.0.1; python_version < "3.4"
importlib_metadata>=0.23; python_version < "3.8"
importlib_metadata>=0.20; python_version < "3.8"
# Optional dependencies
jsonschema>=2.6.0,<3.1.0
# Development dependencies


@ -51,7 +51,7 @@ install_requires =
wasabi>=0.2.0,<1.1.0
srsly>=0.1.0,<1.1.0
pathlib==1.0.1; python_version < "3.4"
importlib_metadata>=0.23; python_version < "3.8"
importlib_metadata>=0.20; python_version < "3.8"
[options.extras_require]
lookups =


@ -57,7 +57,8 @@ def convert(
is written to stdout, so you can pipe them forward to a JSON file:
$ spacy convert some_file.conllu > some_file.json
"""
msg = Printer()
no_print = (output_dir == "-")
msg = Printer(no_print=no_print)
input_path = Path(input_file)
if file_type not in FILE_TYPES:
msg.fail(
@ -102,6 +103,7 @@ def convert(
use_morphology=morphology,
lang=lang,
model=model,
no_print=no_print,
)
if output_dir != "-":
# Export data to a file
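With output_dir == "-" the converted data goes to stdout, so status messages have to be suppressed; wasabi's Printer supports this directly (a minimal sketch):

from wasabi import Printer

msg = Printer(no_print=True)
msg.good("Converted 10 documents")  # swallowed, so stdout stays clean for piped JSON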


@ -9,7 +9,7 @@ from ...tokens.doc import Doc
from ...util import load_model
def conll_ner2json(input_data, n_sents=10, seg_sents=False, model=None, **kwargs):
def conll_ner2json(input_data, n_sents=10, seg_sents=False, model=None, no_print=False, **kwargs):
"""
Convert files in the CoNLL-2003 NER format and similar
whitespace-separated columns into JSON format for use with train cli.
@ -34,7 +34,7 @@ def conll_ner2json(input_data, n_sents=10, seg_sents=False, model=None, **kwargs
. O
"""
msg = Printer()
msg = Printer(no_print=no_print)
doc_delimiter = "-DOCSTART- -X- O O"
# check for existing delimiters, which should be preserved
if "\n\n" in input_data and seg_sents:


@ -8,7 +8,7 @@ from ...util import minibatch
from .conll_ner2json import n_sents_info
def iob2json(input_data, n_sents=10, *args, **kwargs):
def iob2json(input_data, n_sents=10, no_print=False, *args, **kwargs):
"""
Convert IOB files with one sentence per line and tags separated with '|'
into JSON format for use with train cli. IOB and IOB2 are accepted.
@ -20,7 +20,7 @@ def iob2json(input_data, n_sents=10, *args, **kwargs):
I|PRP|O like|VBP|O London|NNP|I-GPE and|CC|O New|NNP|B-GPE York|NNP|I-GPE City|NNP|I-GPE .|.|O
I|PRP|O like|VBP|O London|NNP|B-GPE and|CC|O New|NNP|B-GPE York|NNP|I-GPE City|NNP|I-GPE .|.|O
"""
msg = Printer()
msg = Printer(no_print=no_print)
docs = read_iob(input_data.split("\n"))
if n_sents > 0:
n_sents_info(msg, n_sents)


@ -360,6 +360,16 @@ def debug_data(
)
)
# check for documents with multiple sentences
sents_per_doc = gold_train_data["n_sents"] / len(gold_train_data["texts"])
if sents_per_doc < 1.1:
msg.warn(
"The training data contains {:.2f} sentences per "
"document. When there are very few documents containing more "
"than one sentence, the parser will not learn how to segment "
"longer texts into sentences.".format(sents_per_doc)
)
# profile labels
labels_train = [label for label in gold_train_data["deps"]]
labels_train_unpreprocessed = [


@ -246,7 +246,7 @@ def make_update(model, docs, optimizer, drop=0.0, objective="L2"):
"""Perform an update over a single batch of documents.
docs (iterable): A batch of `Doc` objects.
drop (float): The droput rate.
drop (float): The dropout rate.
optimizer (callable): An optimizer.
RETURNS loss: A float for the loss.
"""


@ -80,8 +80,8 @@ class Warnings(object):
"the v2.x models cannot release the global interpreter lock. "
"Future versions may introduce a `n_process` argument for "
"parallel inference via multiprocessing.")
W017 = ("Alias '{alias}' already exists in the Knowledge base.")
W018 = ("Entity '{entity}' already exists in the Knowledge base.")
W017 = ("Alias '{alias}' already exists in the Knowledge Base.")
W018 = ("Entity '{entity}' already exists in the Knowledge Base.")
W019 = ("Changing vectors name from {old} to {new}, to avoid clash with "
"previously loaded vectors. See Issue #3853.")
W020 = ("Unnamed vectors. This won't allow multiple vectors models to be "
@ -95,7 +95,10 @@ class Warnings(object):
"you can ignore this warning by setting SPACY_WARNING_IGNORE=W022. "
"If this is surprising, make sure you have the spacy-lookups-data "
"package installed.")
W023 = ("Multiprocessing of Language.pipe is not supported in Python2. 'n_process' will be set to 1.")
W023 = ("Multiprocessing of Language.pipe is not supported in Python 2. "
"'n_process' will be set to 1.")
W024 = ("Entity '{entity}' - Alias '{alias}' combination already exists in "
"the Knowledge Base.")
@add_codes
@ -408,7 +411,7 @@ class Errors(object):
"{probabilities_length} respectively.")
E133 = ("The sum of prior probabilities for alias '{alias}' should not "
"exceed 1, but found {sum}.")
E134 = ("Alias '{alias}' defined for unknown entity '{entity}'.")
E134 = ("Entity '{entity}' is not defined in the Knowledge Base.")
E135 = ("If you meant to replace a built-in component, use `create_pipe`: "
"`nlp.replace_pipe('{name}', nlp.create_pipe('{name}'))`")
E136 = ("This additional feature requires the jsonschema library to be "
@ -420,7 +423,7 @@ class Errors(object):
E138 = ("Invalid JSONL format for raw text '{text}'. Make sure the input "
"includes either the `text` or `tokens` key. For more info, see "
"the docs:\nhttps://spacy.io/api/cli#pretrain-jsonl")
E139 = ("Knowledge base for component '{name}' not initialized. Did you "
E139 = ("Knowledge Base for component '{name}' not initialized. Did you "
"forget to call set_kb()?")
E140 = ("The list of entities, prior probabilities and entity vectors "
"should be of equal length.")
@ -498,6 +501,8 @@ class Errors(object):
"details: https://spacy.io/api/lemmatizer#init")
E174 = ("Architecture '{name}' not found in registry. Available "
"names: {names}")
E175 = ("Can't remove rule for unknown match pattern ID: {key}")
E176 = ("Alias '{alias}' is not defined in the Knowledge Base.")
@add_codes


@ -743,7 +743,8 @@ def docs_to_json(docs, id=0):
docs (iterable / Doc): The Doc object(s) to convert.
id (int): Id for the JSON.
RETURNS (list): The data in spaCy's JSON format.
RETURNS (dict): The data in spaCy's JSON format
- each input doc will be treated as a paragraph in the output doc
"""
if isinstance(docs, Doc):
docs = [docs]


@ -142,6 +142,7 @@ cdef class KnowledgeBase:
i = 0
cdef KBEntryC entry
cdef hash_t entity_hash
while i < nr_entities:
entity_vector = vector_list[i]
if len(entity_vector) != self.entity_vector_length:
@ -161,6 +162,14 @@ cdef class KnowledgeBase:
i += 1
def contains_entity(self, unicode entity):
cdef hash_t entity_hash = self.vocab.strings.add(entity)
return entity_hash in self._entry_index
def contains_alias(self, unicode alias):
cdef hash_t alias_hash = self.vocab.strings.add(alias)
return alias_hash in self._alias_index
def add_alias(self, unicode alias, entities, probabilities):
"""
For a given alias, add its potential entities and prior probabilies to the KB.
@ -190,7 +199,7 @@ cdef class KnowledgeBase:
for entity, prob in zip(entities, probabilities):
entity_hash = self.vocab.strings[entity]
if not entity_hash in self._entry_index:
raise ValueError(Errors.E134.format(alias=alias, entity=entity))
raise ValueError(Errors.E134.format(entity=entity))
entry_index = <int64_t>self._entry_index.get(entity_hash)
entry_indices.push_back(int(entry_index))
@ -201,8 +210,63 @@ cdef class KnowledgeBase:
return alias_hash
def get_candidates(self, unicode alias):
def append_alias(self, unicode alias, unicode entity, float prior_prob, ignore_warnings=False):
"""
For an alias already existing in the KB, extend its potential entities with one more.
Throw an error if either the alias or the entity is unknown to the KB,
or if adding the new prior probability would make the sum for this alias exceed 1.
Throw a warning if the alias-entity combination was already previously recorded.
For efficiency, it's best to use the method `add_alias` as much as possible instead of this one.
"""
# Check if the alias exists in the KB
cdef hash_t alias_hash = self.vocab.strings[alias]
if not alias_hash in self._alias_index:
raise ValueError(Errors.E176.format(alias=alias))
# Check if the entity exists in the KB
cdef hash_t entity_hash = self.vocab.strings[entity]
if not entity_hash in self._entry_index:
raise ValueError(Errors.E134.format(entity=entity))
entry_index = <int64_t>self._entry_index.get(entity_hash)
# Throw an error if the prior probabilities (including the new one) sum up to more than 1
alias_index = <int64_t>self._alias_index.get(alias_hash)
alias_entry = self._aliases_table[alias_index]
current_sum = sum([p for p in alias_entry.probs])
new_sum = current_sum + prior_prob
if new_sum > 1.00001:
raise ValueError(Errors.E133.format(alias=alias, sum=new_sum))
entry_indices = alias_entry.entry_indices
is_present = False
for i in range(entry_indices.size()):
if entry_indices[i] == int(entry_index):
is_present = True
if is_present:
if not ignore_warnings:
user_warning(Warnings.W024.format(entity=entity, alias=alias))
else:
entry_indices.push_back(int(entry_index))
alias_entry.entry_indices = entry_indices
probs = alias_entry.probs
probs.push_back(float(prior_prob))
alias_entry.probs = probs
self._aliases_table[alias_index] = alias_entry
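A minimal usage sketch of the new method (frequencies, vectors and probabilities are made up):

from spacy.kb import KnowledgeBase
from spacy.vocab import Vocab

kb = KnowledgeBase(Vocab(), entity_vector_length=1)
kb.add_entity(entity="Q1", freq=20, entity_vector=[1.0])
kb.add_entity(entity="Q2", freq=5, entity_vector=[2.0])
kb.add_alias(alias="douglas", entities=["Q1"], probabilities=[0.6])
kb.append_alias(alias="douglas", entity="Q2", prior_prob=0.3)  # 0.6 + 0.3 <= 1
assert len(kb.get_candidates("douglas")) == 2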
def get_candidates(self, unicode alias):
"""
Return candidate entities for an alias. Each candidate defines the entity, the original alias,
and the prior probability of that alias resolving to that entity.
If the alias is not known in the KB, an empty list is returned.
"""
cdef hash_t alias_hash = self.vocab.strings[alias]
if not alias_hash in self._alias_index:
return []
alias_index = <int64_t>self._alias_index.get(alias_hash)
alias_entry = self._aliases_table[alias_index]
@ -341,7 +405,6 @@ cdef class KnowledgeBase:
assert nr_entities == self.get_size_entities()
# STEP 3: load aliases
cdef int64_t nr_aliases
reader.read_alias_length(&nr_aliases)
self._alias_index = PreshMap(nr_aliases+1)

spacy/lang/lb/__init__.py Normal file

@ -0,0 +1,34 @@
# coding: utf8
from __future__ import unicode_literals
from .tokenizer_exceptions import TOKENIZER_EXCEPTIONS
from .norm_exceptions import NORM_EXCEPTIONS
from .lex_attrs import LEX_ATTRS
from .tag_map import TAG_MAP
from .stop_words import STOP_WORDS
from ..tokenizer_exceptions import BASE_EXCEPTIONS
from ..norm_exceptions import BASE_NORMS
from ...language import Language
from ...attrs import LANG, NORM
from ...util import update_exc, add_lookups
class LuxembourgishDefaults(Language.Defaults):
lex_attr_getters = dict(Language.Defaults.lex_attr_getters)
lex_attr_getters.update(LEX_ATTRS)
lex_attr_getters[LANG] = lambda text: "lb"
lex_attr_getters[NORM] = add_lookups(
Language.Defaults.lex_attr_getters[NORM], NORM_EXCEPTIONS, BASE_NORMS
)
tokenizer_exceptions = update_exc(BASE_EXCEPTIONS, TOKENIZER_EXCEPTIONS)
stop_words = STOP_WORDS
tag_map = TAG_MAP
class Luxembourgish(Language):
lang = "lb"
Defaults = LuxembourgishDefaults
__all__ = ["Luxembourgish"]
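Once merged, the new language can be smoke-tested like any other (a sketch):

from spacy.lang.lb import Luxembourgish

nlp = Luxembourgish()
doc = nlp("Ech hunn dräi Bicher.")
assert [t.text for t in doc] == ["Ech", "hunn", "dräi", "Bicher", "."]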

spacy/lang/lb/examples.py Normal file

@ -0,0 +1,18 @@
# coding: utf8
from __future__ import unicode_literals
"""
Example sentences to test spaCy and its language models.
>>> from spacy.lang.lb.examples import sentences
>>> docs = nlp.pipe(sentences)
"""
sentences = [
"An der Zäit hunn sech den Nordwand an dSonn gestridden, wie vun hinnen zwee wuel méi staark wier, wéi e Wanderer, deen an ee waarme Mantel agepak war, iwwert de Wee koum.",
"Si goufen sech eens, dass deejéinege fir de Stäerkste gëlle sollt, deen de Wanderer forcéiere géif, säi Mantel auszedoen.",
"Den Nordwand huet mat aller Force geblosen, awer wat e méi geblosen huet, wat de Wanderer sech méi a säi Mantel agewéckelt huet.",
"Um Enn huet den Nordwand säi Kampf opginn.",
"Dunn huet dSonn dLoft mat hire frëndleche Strale gewiermt, a schonn no kuerzer Zäit huet de Wanderer säi Mantel ausgedoen.",
"Do huet den Nordwand missen zouginn, dass dSonn vun hinnen zwee de Stäerkste wier.",
]


@ -0,0 +1,44 @@
# coding: utf8
from __future__ import unicode_literals
from ...attrs import LIKE_NUM
_num_words = set(
"""
null eent zwee dräi véier fënnef sechs ziwen aacht néng zéng eelef zwielef dräizéng
véierzéng foffzéng siechzéng siwwenzéng uechtzeng uechzeng nonnzéng nongzéng zwanzeg drësseg véierzeg foffzeg sechzeg siechzeg siwenzeg achtzeg achzeg uechtzeg uechzeg nonnzeg
honnert dausend millioun milliard billioun billiard trillioun triliard
""".split()
)
_ordinal_words = set(
"""
éischten zweeten drëtten véierten fënneften sechsten siwenten aachten néngten zéngten eeleften
zwieleften dräizéngten véierzéngten foffzéngten siechzéngten uechtzéngen uechzéngten nonnzéngten nongzéngten zwanzegsten
drëssegsten véierzegsten foffzegsten siechzegsten siwenzegsten uechzegsten nonnzegsten
honnertsten dausendsten milliounsten
milliardsten billiounsten billiardsten trilliounsten trilliardsten
""".split()
)
def like_num(text):
"""
Check if text resembles a number.
"""
text = text.replace(",", "").replace(".", "")
if text.isdigit():
return True
if text.count("/") == 1:
num, denom = text.split("/")
if num.isdigit() and denom.isdigit():
return True
if text in _num_words:
return True
if text in _ordinal_words:
return True
return False
LEX_ATTRS = {LIKE_NUM: like_num}
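For instance (illustrative values):

assert like_num("dräi")      # cardinal
assert like_num("véierten")  # ordinal
assert like_num("1.000")     # separators are stripped before isdigit()
assert not like_num("Mantel")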


@ -0,0 +1,16 @@
# coding: utf8
from __future__ import unicode_literals
# TODO
# norm exceptions: find a possibility to deal with the zillions of spelling
# variants (vläicht = vlaicht, vleicht, viläicht, viläischt, etc. etc.)
# here one could include the most common spelling mistakes
_exc = {"datt": "dass", "wgl.": "weg.", "vläicht": "viläicht"}
NORM_EXCEPTIONS = {}
for string, norm in _exc.items():
NORM_EXCEPTIONS[string] = norm
NORM_EXCEPTIONS[string.title()] = norm

spacy/lang/lb/stop_words.py Normal file

@ -0,0 +1,214 @@
# coding: utf8
from __future__ import unicode_literals
STOP_WORDS = set(
"""
a
à
äis
är
ärt
äert
ären
all
allem
alles
alleguer
als
also
am
an
anerefalls
ass
aus
awer
bei
beim
bis
bis
d'
dach
datt
däin
där
dat
de
dee
den
deel
deem
deen
deene
déi
den
deng
denger
dem
der
dësem
di
dir
do
da
dann
domat
dozou
drop
du
duerch
duerno
e
ee
em
een
eent
ë
en
ënner
ëm
ech
eis
eise
eisen
eiser
eises
eisereen
esou
een
eng
enger
engem
entweder
et
eréischt
falls
fir
géint
géif
gëtt
gët
geet
gi
ginn
gouf
gouff
goung
hat
haten
hatt
hätt
hei
hu
huet
hun
hunn
hiren
hien
hin
hier
hir
jidderen
jiddereen
jiddwereen
jiddereng
jiddwerengen
jo
ins
iech
iwwer
kann
kee
keen
kënne
kënnt
kéng
kéngen
kéngem
koum
kuckt
mam
mat
ma
mech
méi
mécht
meng
menger
mer
mir
muss
nach
nämmlech
nämmelech
näischt
nawell
nëmme
nëmmen
net
nees
nee
no
nu
nom
och
oder
ons
onsen
onser
onsereen
onst
om
op
ouni
säi
säin
schonn
schonns
si
sid
sie
se
sech
seng
senge
sengem
senger
selwecht
selwer
sinn
sollten
souguer
sou
soss
sot
't
tëscht
u
un
um
virdrun
vu
vum
vun
wann
war
waren
was
wat
wëllt
weider
wéi
wéini
wéinst
wi
wollt
wou
wouhin
zanter
ze
zu
zum
zwar
""".split()
)

spacy/lang/lb/tag_map.py Normal file

@ -0,0 +1,28 @@
# coding: utf8
from __future__ import unicode_literals
from ...symbols import POS, PUNCT, ADJ, CONJ, NUM, DET, ADV, ADP, X, VERB
from ...symbols import NOUN, PART, SPACE, AUX
# TODO: tag map is still using POS tags from an internal training set.
# These POS tags have to be modified to match those from Universal Dependencies
TAG_MAP = {
"$": {POS: PUNCT},
"ADJ": {POS: ADJ},
"AV": {POS: ADV},
"APPR": {POS: ADP, "AdpType": "prep"},
"APPRART": {POS: ADP, "AdpType": "prep", "PronType": "art"},
"D": {POS: DET, "PronType": "art"},
"KO": {POS: CONJ},
"N": {POS: NOUN},
"P": {POS: ADV},
"TRUNC": {POS: X, "Hyph": "yes"},
"AUX": {POS: AUX},
"V": {POS: VERB},
"MV": {POS: VERB, "VerbType": "mod"},
"PTK": {POS: PART},
"INTER": {POS: PART},
"NUM": {POS: NUM},
"_SP": {POS: SPACE},
}


@ -0,0 +1,69 @@
# coding: utf8
from __future__ import unicode_literals
from ...symbols import ORTH, LEMMA, NORM
from ..punctuation import TOKENIZER_PREFIXES
# TODO
# tokenize cliticised definite article "d'" as token of its own: d'Kanner > [d'] [Kanner]
# treat other apostrophes within words as part of the word: [op d'mannst], [fir d'éischt] (= exceptions)
# how to write the tokenisation exception for the articles d' / D'? This one is not working.
_prefixes = [
prefix for prefix in TOKENIZER_PREFIXES if prefix not in ["d'", "D'", "d", "D"]
]
_exc = {
"d'mannst": [
{ORTH: "d'", LEMMA: "d'"},
{ORTH: "mannst", LEMMA: "mann", NORM: "mann"},
],
"d'éischt": [
{ORTH: "d'", LEMMA: "d'"},
{ORTH: "éischt", LEMMA: "éischt", NORM: "éischt"},
],
}
# translate / delete what is not necessary
# what does PRON_LEMMA mean?
for exc_data in [
{ORTH: "wgl.", LEMMA: "wann ech gelift", NORM: "wann ech gelieft"},
{ORTH: "M.", LEMMA: "Monsieur", NORM: "Monsieur"},
{ORTH: "Mme.", LEMMA: "Madame", NORM: "Madame"},
{ORTH: "Dr.", LEMMA: "Dokter", NORM: "Dokter"},
{ORTH: "Tel.", LEMMA: "Telefon", NORM: "Telefon"},
{ORTH: "asw.", LEMMA: "an sou weider", NORM: "an sou weider"},
{ORTH: "etc.", LEMMA: "et cetera", NORM: "et cetera"},
{ORTH: "bzw.", LEMMA: "bezéiungsweis", NORM: "bezéiungsweis"},
{ORTH: "Jan.", LEMMA: "Januar", NORM: "Januar"},
]:
_exc[exc_data[ORTH]] = [exc_data]
# to be extended
for orth in [
"z.B.",
"Dipl.",
"Dr.",
"etc.",
"i.e.",
"o.k.",
"O.K.",
"p.a.",
"p.s.",
"P.S.",
"phil.",
"q.e.d.",
"R.I.P.",
"rer.",
"sen.",
"ë.a.",
"U.S.",
"U.S.A.",
]:
_exc[orth] = [{ORTH: orth}]
TOKENIZER_PREFIXES = _prefixes
TOKENIZER_EXCEPTIONS = _exc


@ -1,10 +1,8 @@
# coding: utf8
from __future__ import absolute_import, unicode_literals
import atexit
import random
import itertools
from warnings import warn
from spacy.util import minibatch
import weakref
import functools
@ -483,7 +481,7 @@ class Language(object):
docs (iterable): A batch of `Doc` objects.
golds (iterable): A batch of `GoldParse` objects.
drop (float): The droput rate.
drop (float): The dropout rate.
sgd (callable): An optimizer.
losses (dict): Dictionary to update with the loss, keyed by component.
component_cfg (dict): Config parameters for specific pipeline
@ -531,7 +529,7 @@ class Language(object):
even if you're updating it with a smaller set of examples.
docs (iterable): A batch of `Doc` objects.
drop (float): The droput rate.
drop (float): The dropout rate.
sgd (callable): An optimizer.
RETURNS (dict): Results from the update.
@ -753,7 +751,8 @@ class Language(object):
use. Experimental.
component_cfg (dict): An optional dictionary with extra keyword
arguments for specific components.
n_process (int): Number of processors to process texts, only supported in Python3. If -1, set `multiprocessing.cpu_count()`.
n_process (int): Number of processors to process texts, only supported
in Python3. If -1, set `multiprocessing.cpu_count()`.
YIELDS (Doc): Documents in the order of the original text.
DOCS: https://spacy.io/api/language#pipe
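Usage is unchanged apart from the new argument (a sketch, assuming an existing nlp pipeline on Python 3):

texts = ["This is the first text.", "This is the second text."]
for doc in nlp.pipe(texts, n_process=2):  # n_process=-1 uses all cores
    print(len(doc))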
@ -1069,9 +1068,10 @@ def _pipe(docs, proc, kwargs):
def _apply_pipes(make_doc, pipes, receiver, sender):
"""Worker for Language.pipe
Args:
receiver (multiprocessing.Connection): Pipe to receive text. Usually created by `multiprocessing.Pipe()`
sender (multiprocessing.Connection): Pipe to send doc. Usually created by `multiprocessing.Pipe()`
receiver (multiprocessing.Connection): Pipe to receive text. Usually
created by `multiprocessing.Pipe()`
sender (multiprocessing.Connection): Pipe to send doc. Usually created by
`multiprocessing.Pipe()`
"""
while True:
texts = receiver.get()
@ -1100,7 +1100,7 @@ class _Sender:
q.put(item)
def step(self):
"""Tell sender that comsumed one item.
"""Tell sender that comsumed one item.
Data is sent to the workers after every chunk_size calls."""
self.count += 1


@ -133,13 +133,15 @@ cdef class Matcher:
key (unicode): The ID of the match rule.
"""
key = self._normalize_key(key)
self._patterns.pop(key)
self._callbacks.pop(key)
norm_key = self._normalize_key(key)
if not norm_key in self._patterns:
raise ValueError(Errors.E175.format(key=key))
self._patterns.pop(norm_key)
self._callbacks.pop(norm_key)
cdef int i = 0
while i < self.patterns.size():
pattern_key = get_pattern_key(self.patterns.at(i))
if pattern_key == key:
pattern_key = get_ent_id(self.patterns.at(i))
if pattern_key == norm_key:
self.patterns.erase(self.patterns.begin()+i)
else:
i += 1
@ -293,18 +295,6 @@ cdef find_matches(TokenPatternC** patterns, int n, Doc doc, extensions=None,
return output
cdef attr_t get_ent_id(const TokenPatternC* pattern) nogil:
# There have been a few bugs here.
# The code was originally designed to always have pattern[1].attrs.value
# be the ent_id when we get to the end of a pattern. However, Issue #2671
# showed this wasn't the case when we had a reject-and-continue before a
# match.
# The patch to #2671 was wrong though, which came up in #3839.
while pattern.attrs.attr != ID:
pattern += 1
return pattern.attrs.value
cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& matches,
char* cached_py_predicates,
Token token, const attr_t* extra_attrs, py_predicates) except *:
@ -533,9 +523,10 @@ cdef char get_is_match(PatternStateC state,
if predicate_matches[state.pattern.py_predicates[i]] == -1:
return 0
spec = state.pattern
for attr in spec.attrs[:spec.nr_attr]:
if get_token_attr(token, attr.attr) != attr.value:
return 0
if spec.nr_attr > 0:
for attr in spec.attrs[:spec.nr_attr]:
if get_token_attr(token, attr.attr) != attr.value:
return 0
for i in range(spec.nr_extra_attr):
if spec.extra_attrs[i].value != extra_attrs[spec.extra_attrs[i].index]:
return 0
@ -543,7 +534,11 @@ cdef char get_is_match(PatternStateC state,
cdef char get_is_final(PatternStateC state) nogil:
if state.pattern[1].attrs[0].attr == ID and state.pattern[1].nr_attr == 0:
if state.pattern[1].nr_attr == 0 and state.pattern[1].attrs != NULL:
id_attr = state.pattern[1].attrs[0]
if id_attr.attr != ID:
with gil:
raise ValueError(Errors.E074.format(attr=ID, bad_attr=id_attr.attr))
return 1
else:
return 0
@ -558,7 +553,9 @@ cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id, object token_specs)
cdef int i, index
for i, (quantifier, spec, extensions, predicates) in enumerate(token_specs):
pattern[i].quantifier = quantifier
pattern[i].attrs = <AttrValueC*>mem.alloc(len(spec), sizeof(AttrValueC))
# Ensure attrs refers to a null pointer if nr_attr == 0
if len(spec) > 0:
pattern[i].attrs = <AttrValueC*>mem.alloc(len(spec), sizeof(AttrValueC))
pattern[i].nr_attr = len(spec)
for j, (attr, value) in enumerate(spec):
pattern[i].attrs[j].attr = attr
@ -574,6 +571,7 @@ cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id, object token_specs)
pattern[i].nr_py = len(predicates)
pattern[i].key = hash64(pattern[i].attrs, pattern[i].nr_attr * sizeof(AttrValueC), 0)
i = len(token_specs)
# Even though here, nr_attr == 0, we're storing the ID value in attrs[0] (bug-prone, tread carefully!)
pattern[i].attrs = <AttrValueC*>mem.alloc(2, sizeof(AttrValueC))
pattern[i].attrs[0].attr = ID
pattern[i].attrs[0].value = entity_id
@ -583,8 +581,26 @@ cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id, object token_specs)
return pattern
cdef attr_t get_pattern_key(const TokenPatternC* pattern) nogil:
while pattern.nr_attr != 0 or pattern.nr_extra_attr != 0 or pattern.nr_py != 0:
cdef attr_t get_ent_id(const TokenPatternC* pattern) nogil:
# There have been a few bugs here. We used to have two functions,
# get_ent_id and get_pattern_key that tried to do the same thing. These
# are now unified to try to solve the "ghost match" problem.
# Below is the previous implementation of get_ent_id and the comment on it,
# preserved for reference while we figure out whether the heisenbug in the
# matcher is resolved.
#
#
# cdef attr_t get_ent_id(const TokenPatternC* pattern) nogil:
# # The code was originally designed to always have pattern[1].attrs.value
# # be the ent_id when we get to the end of a pattern. However, Issue #2671
# # showed this wasn't the case when we had a reject-and-continue before a
# # match.
# # The patch to #2671 was wrong though, which came up in #3839.
# while pattern.attrs.attr != ID:
# pattern += 1
# return pattern.attrs.value
while pattern.nr_attr != 0 or pattern.nr_extra_attr != 0 or pattern.nr_py != 0 \
or pattern.quantifier != ZERO:
pattern += 1
id_attr = pattern[0].attrs[0]
if id_attr.attr != ID:
@ -642,7 +658,7 @@ def _get_attr_values(spec, string_store):
value = string_store.add(value)
elif isinstance(value, bool):
value = int(value)
elif isinstance(value, dict):
elif isinstance(value, (dict, int)):
continue
else:
raise ValueError(Errors.E153.format(vtype=type(value).__name__))


@ -4,6 +4,7 @@ from cymem.cymem cimport Pool
from preshed.maps cimport key_t, MapStruct
from ..attrs cimport attr_id_t
from ..structs cimport SpanC
from ..tokens.doc cimport Doc
from ..vocab cimport Vocab
@ -18,10 +19,4 @@ cdef class PhraseMatcher:
cdef Pool mem
cdef key_t _terminal_hash
cdef void find_matches(self, Doc doc, vector[MatchStruct] *matches) nogil
cdef struct MatchStruct:
key_t match_id
int start
int end
cdef void find_matches(self, Doc doc, vector[SpanC] *matches) nogil


@ -9,6 +9,7 @@ from preshed.maps cimport map_init, map_set, map_get, map_clear, map_iter
from ..attrs cimport ORTH, POS, TAG, DEP, LEMMA
from ..structs cimport TokenC
from ..tokens.token cimport Token
from ..typedefs cimport attr_t
from ._schemas import TOKEN_PATTERN_SCHEMA
from ..errors import Errors, Warnings, deprecation_warning, user_warning
@ -102,8 +103,10 @@ cdef class PhraseMatcher:
cdef vector[MapStruct*] path_nodes
cdef vector[key_t] path_keys
cdef key_t key_to_remove
for keyword in self._docs[key]:
for keyword in sorted(self._docs[key], key=lambda x: len(x), reverse=True):
current_node = self.c_map
path_nodes.clear()
path_keys.clear()
for token in keyword:
result = map_get(current_node, token)
if result:
@ -220,17 +223,17 @@ cdef class PhraseMatcher:
# if doc is empty or None just return empty list
return matches
cdef vector[MatchStruct] c_matches
cdef vector[SpanC] c_matches
self.find_matches(doc, &c_matches)
for i in range(c_matches.size()):
matches.append((c_matches[i].match_id, c_matches[i].start, c_matches[i].end))
matches.append((c_matches[i].label, c_matches[i].start, c_matches[i].end))
for i, (ent_id, start, end) in enumerate(matches):
on_match = self._callbacks.get(self.vocab.strings[ent_id])
if on_match is not None:
on_match(self, doc, i, matches)
return matches
cdef void find_matches(self, Doc doc, vector[MatchStruct] *matches) nogil:
cdef void find_matches(self, Doc doc, vector[SpanC] *matches) nogil:
cdef MapStruct* current_node = self.c_map
cdef int start = 0
cdef int idx = 0
@ -238,7 +241,7 @@ cdef class PhraseMatcher:
cdef key_t key
cdef void* value
cdef int i = 0
cdef MatchStruct ms
cdef SpanC ms
cdef void* result
while idx < doc.length:
start = idx
@ -253,7 +256,7 @@ cdef class PhraseMatcher:
if result:
i = 0
while map_iter(<MapStruct*>result, &i, &key, &value):
ms = make_matchstruct(key, start, idy)
ms = make_spanstruct(key, start, idy)
matches.push_back(ms)
inner_token = Token.get_struct_attr(&doc.c[idy], self.attr)
result = map_get(current_node, inner_token)
@ -268,7 +271,7 @@ cdef class PhraseMatcher:
if result:
i = 0
while map_iter(<MapStruct*>result, &i, &key, &value):
ms = make_matchstruct(key, start, idy)
ms = make_spanstruct(key, start, idy)
matches.push_back(ms)
current_node = self.c_map
idx += 1
@ -318,9 +321,9 @@ def unpickle_matcher(vocab, docs, callbacks, attr):
return matcher
cdef MatchStruct make_matchstruct(key_t match_id, int start, int end) nogil:
cdef MatchStruct ms
ms.match_id = match_id
ms.start = start
ms.end = end
return ms
cdef SpanC make_spanstruct(attr_t label, int start, int end) nogil:
cdef SpanC spanc
spanc.label = label
spanc.start = start
spanc.end = end
return spanc


@ -183,7 +183,9 @@ class EntityRuler(object):
# disable the nlp components after this one in case they hadn't been initialized / deserialised yet
try:
current_index = self.nlp.pipe_names.index(self.name)
subsequent_pipes = [pipe for pipe in self.nlp.pipe_names[current_index + 1:]]
subsequent_pipes = [
pipe for pipe in self.nlp.pipe_names[current_index + 1 :]
]
except ValueError:
subsequent_pipes = []
with self.nlp.disable_pipes(*subsequent_pipes):


@ -1195,23 +1195,26 @@ class EntityLinker(Pipe):
docs = [docs]
golds = [golds]
context_docs = []
sentence_docs = []
for doc, gold in zip(docs, golds):
ents_by_offset = dict()
for ent in doc.ents:
ents_by_offset["{}_{}".format(ent.start_char, ent.end_char)] = ent
ents_by_offset[(ent.start_char, ent.end_char)] = ent
for entity, kb_dict in gold.links.items():
start, end = entity
mention = doc.text[start:end]
# the gold annotations should link to proper entities - if this fails, the dataset is likely corrupt
ent = ents_by_offset[(start, end)]
for kb_id, value in kb_dict.items():
# Currently only training on the positive instances
if value:
context_docs.append(doc)
sentence_docs.append(ent.sent.as_doc())
context_encodings, bp_context = self.model.begin_update(context_docs, drop=drop)
loss, d_scores = self.get_similarity_loss(scores=context_encodings, golds=golds, docs=None)
sentence_encodings, bp_context = self.model.begin_update(sentence_docs, drop=drop)
loss, d_scores = self.get_similarity_loss(scores=sentence_encodings, golds=golds, docs=None)
bp_context(d_scores, sgd=sgd)
if losses is not None:
@ -1280,50 +1283,68 @@ class EntityLinker(Pipe):
if isinstance(docs, Doc):
docs = [docs]
context_encodings = self.model(docs)
xp = get_array_module(context_encodings)
for i, doc in enumerate(docs):
if len(doc) > 0:
# currently, the context is the same for each entity in a sentence (should be refined)
context_encoding = context_encodings[i]
context_enc_t = context_encoding.T
norm_1 = xp.linalg.norm(context_enc_t)
for ent in doc.ents:
entity_count += 1
# Looping through each sentence and each entity
# This may go wrong if there are entities across sentences - because they might not get a KB ID
for sent in doc.sents:
sent_doc = sent.as_doc()
# currently, the context is the same for each entity in a sentence (should be refined)
sentence_encoding = self.model([sent_doc])[0]
xp = get_array_module(sentence_encoding)
sentence_encoding_t = sentence_encoding.T
sentence_norm = xp.linalg.norm(sentence_encoding_t)
candidates = self.kb.get_candidates(ent.text)
if not candidates:
final_kb_ids.append(self.NIL) # no prediction possible for this entity
final_tensors.append(context_encoding)
else:
random.shuffle(candidates)
for ent in sent_doc.ents:
entity_count += 1
# this will set all prior probabilities to 0 if they should be excluded from the model
prior_probs = xp.asarray([c.prior_prob for c in candidates])
if not self.cfg.get("incl_prior", True):
prior_probs = xp.asarray([0.0 for c in candidates])
scores = prior_probs
if ent.label_ in self.cfg.get("labels_discard", []):
# ignoring this entity - setting to NIL
final_kb_ids.append(self.NIL)
final_tensors.append(sentence_encoding)
# add in similarity from the context
if self.cfg.get("incl_context", True):
entity_encodings = xp.asarray([c.entity_vector for c in candidates])
norm_2 = xp.linalg.norm(entity_encodings, axis=1)
else:
candidates = self.kb.get_candidates(ent.text)
if not candidates:
# no prediction possible for this entity - setting to NIL
final_kb_ids.append(self.NIL)
final_tensors.append(sentence_encoding)
if len(entity_encodings) != len(prior_probs):
raise RuntimeError(Errors.E147.format(method="predict", msg="vectors not of equal length"))
elif len(candidates) == 1:
# shortcut for efficiency reasons: take the 1 candidate
# cosine similarity
sims = xp.dot(entity_encodings, context_enc_t) / (norm_1 * norm_2)
if sims.shape != prior_probs.shape:
raise ValueError(Errors.E161)
scores = prior_probs + sims - (prior_probs*sims)
# TODO: thresholding
final_kb_ids.append(candidates[0].entity_)
final_tensors.append(sentence_encoding)
# TODO: thresholding
best_index = scores.argmax()
best_candidate = candidates[best_index]
final_kb_ids.append(best_candidate.entity_)
final_tensors.append(context_encoding)
else:
random.shuffle(candidates)
# this will set all prior probabilities to 0 if they should be excluded from the model
prior_probs = xp.asarray([c.prior_prob for c in candidates])
if not self.cfg.get("incl_prior", True):
prior_probs = xp.asarray([0.0 for c in candidates])
scores = prior_probs
# add in similarity from the context
if self.cfg.get("incl_context", True):
entity_encodings = xp.asarray([c.entity_vector for c in candidates])
entity_norm = xp.linalg.norm(entity_encodings, axis=1)
if len(entity_encodings) != len(prior_probs):
raise RuntimeError(Errors.E147.format(method="predict", msg="vectors not of equal length"))
# cosine similarity
sims = xp.dot(entity_encodings, sentence_encoding_t) / (sentence_norm * entity_norm)
if sims.shape != prior_probs.shape:
raise ValueError(Errors.E161)
scores = prior_probs + sims - (prior_probs*sims)
# TODO: thresholding
best_index = scores.argmax()
best_candidate = candidates[best_index]
final_kb_ids.append(best_candidate.entity_)
final_tensors.append(sentence_encoding)
if not (len(final_tensors) == len(final_kb_ids) == entity_count):
raise RuntimeError(Errors.E147.format(method="predict", msg="result variables not of equal length"))
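The combination scores = prior_probs + sims - prior_probs * sims is a probabilistic OR (noisy-OR) of the two signals: a candidate ranks high if either its prior or its context similarity is high. With made-up numbers:

import numpy

prior_probs = numpy.asarray([0.7, 0.2])
sims = numpy.asarray([0.1, 0.9])  # cosine similarity per candidate
scores = prior_probs + sims - prior_probs * sims  # [0.73, 0.92]
assert scores.argmax() == 1  # context outweighs the weaker prior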


@ -219,7 +219,9 @@ class Scorer(object):
DOCS: https://spacy.io/api/scorer#score
"""
if len(doc) != len(gold):
gold = GoldParse.from_annot_tuples(doc, zip(*gold.orig_annot))
gold = GoldParse.from_annot_tuples(
doc, tuple(zip(*gold.orig_annot)) + (gold.cats,)
)
gold_deps = set()
gold_tags = set()
gold_ents = set(tags_to_entities([annot[-1] for annot in gold.orig_annot]))


@ -47,11 +47,14 @@ cdef struct SerializedLexemeC:
# + sizeof(float) # l2_norm
cdef struct Entity:
cdef struct SpanC:
hash_t id
int start
int end
int start_char
int end_char
attr_t label
attr_t kb_id
cdef struct TokenC:


@ -7,7 +7,7 @@ from cpython.exc cimport PyErr_CheckSignals, PyErr_SetFromErrno
from murmurhash.mrmr cimport hash64
from ..vocab cimport EMPTY_LEXEME
from ..structs cimport TokenC, Entity
from ..structs cimport TokenC, SpanC
from ..lexeme cimport Lexeme
from ..symbols cimport punct
from ..attrs cimport IS_SPACE
@ -40,7 +40,7 @@ cdef cppclass StateC:
int* _buffer
bint* shifted
TokenC* _sent
Entity* _ents
SpanC* _ents
TokenC _empty_token
RingBufferC _hist
int length
@ -56,7 +56,7 @@ cdef cppclass StateC:
this._stack = <int*>calloc(length + (PADDING * 2), sizeof(int))
this.shifted = <bint*>calloc(length + (PADDING * 2), sizeof(bint))
this._sent = <TokenC*>calloc(length + (PADDING * 2), sizeof(TokenC))
this._ents = <Entity*>calloc(length + (PADDING * 2), sizeof(Entity))
this._ents = <SpanC*>calloc(length + (PADDING * 2), sizeof(SpanC))
if not (this._buffer and this._stack and this.shifted
and this._sent and this._ents):
with gil:
@ -406,7 +406,7 @@ cdef cppclass StateC:
memcpy(this._sent, src._sent, this.length * sizeof(TokenC))
memcpy(this._stack, src._stack, this.length * sizeof(int))
memcpy(this._buffer, src._buffer, this.length * sizeof(int))
memcpy(this._ents, src._ents, this.length * sizeof(Entity))
memcpy(this._ents, src._ents, this.length * sizeof(SpanC))
memcpy(this.shifted, src.shifted, this.length * sizeof(this.shifted[0]))
this._b_i = src._b_i
this._s_i = src._s_i


@ -3,7 +3,7 @@ from libc.string cimport memcpy, memset
from cymem.cymem cimport Pool
cimport cython
from ..structs cimport TokenC, Entity
from ..structs cimport TokenC, SpanC
from ..typedefs cimport attr_t
from ..vocab cimport EMPTY_LEXEME


@ -135,6 +135,11 @@ def ko_tokenizer():
return get_lang_class("ko").Defaults.create_tokenizer()
@pytest.fixture(scope="session")
def lb_tokenizer():
return get_lang_class("lb").Defaults.create_tokenizer()
@pytest.fixture(scope="session")
def lt_tokenizer():
return get_lang_class("lt").Defaults.create_tokenizer()


@ -253,3 +253,11 @@ def test_filter_spans(doc):
assert len(filtered[1]) == 5
assert filtered[0].start == 1 and filtered[0].end == 4
assert filtered[1].start == 5 and filtered[1].end == 10
# Test filtering overlaps with earlier preference for identical length
spans = [doc[1:4], doc[2:5], doc[5:10], doc[7:9], doc[1:4]]
filtered = filter_spans(spans)
assert len(filtered) == 2
assert len(filtered[0]) == 3
assert len(filtered[1]) == 5
assert filtered[0].start == 1 and filtered[0].end == 4
assert filtered[1].start == 5 and filtered[1].end == 10


@ -0,0 +1,10 @@
# coding: utf-8
from __future__ import unicode_literals
import pytest
@pytest.mark.parametrize("text", ["z.B.", "Jan."])
def test_lb_tokenizer_handles_abbr(lb_tokenizer, text):
tokens = lb_tokenizer(text)
assert len(tokens) == 1


@ -0,0 +1,22 @@
# coding: utf-8
from __future__ import unicode_literals
import pytest
@pytest.mark.parametrize("text,length", [("z.B.", 1), ("zb.", 2), ("(z.B.", 2)])
def test_lb_tokenizer_splits_prefix_interact(lb_tokenizer, text, length):
tokens = lb_tokenizer(text)
assert len(tokens) == length
@pytest.mark.parametrize("text", ["z.B.)"])
def test_lb_tokenizer_splits_suffix_interact(lb_tokenizer, text):
tokens = lb_tokenizer(text)
assert len(tokens) == 2
@pytest.mark.parametrize("text", ["(z.B.)"])
def test_lb_tokenizer_splits_even_wrap_interact(lb_tokenizer, text):
tokens = lb_tokenizer(text)
assert len(tokens) == 3


@ -0,0 +1,31 @@
# coding: utf-8
from __future__ import unicode_literals
import pytest
def test_lb_tokenizer_handles_long_text(lb_tokenizer):
text = """Den Nordwand an d'Sonn
An der Zäit hunn sech den Nordwand an dSonn gestridden, wie vun hinnen zwee wuel méi staark wier, wéi e Wanderer, deen an ee waarme Mantel agepak war, iwwert de Wee koum. Si goufen sech eens, dass deejéinege fir de Stäerkste gëlle sollt, deen de Wanderer forcéiere géif, säi Mantel auszedoen.",
Den Nordwand huet mat aller Force geblosen, awer wat e méi geblosen huet, wat de Wanderer sech méi a säi Mantel agewéckelt huet. Um Enn huet den Nordwand säi Kampf opginn.
Dunn huet dSonn dLoft mat hire frëndleche Strale gewiermt, a schonn no kuerzer Zäit huet de Wanderer säi Mantel ausgedoen.
Do huet den Nordwand missen zouginn, dass dSonn vun hinnen zwee de Stäerkste wier."""
tokens = lb_tokenizer(text)
assert len(tokens) == 143
@pytest.mark.parametrize(
"text,length",
[
("»Wat ass mat mir geschitt?«, huet hie geduecht.", 13),
("“Dëst fréi Opstoen”, denkt hien, “mécht ee ganz duercherneen. ", 15),
],
)
def test_lb_tokenizer_handles_examples(lb_tokenizer, text, length):
tokens = lb_tokenizer(text)
assert len(tokens) == length


@ -3,6 +3,8 @@ from __future__ import unicode_literals
import pytest
import re
from spacy.lang.en import English
from spacy.matcher import Matcher
from spacy.tokens import Doc, Span
@ -143,3 +145,29 @@ def test_matcher_sets_return_correct_tokens(en_vocab):
matches = matcher(doc)
texts = [Span(doc, s, e, label=L).text for L, s, e in matches]
assert texts == ["zero", "one", "two"]
def test_matcher_remove():
nlp = English()
matcher = Matcher(nlp.vocab)
text = "This is a test case."
pattern = [{"ORTH": "test"}, {"OP": "?"}]
assert len(matcher) == 0
matcher.add("Rule", None, pattern)
assert "Rule" in matcher
# should give two matches
results1 = matcher(nlp(text))
assert len(results1) == 2
# removing once should work
matcher.remove("Rule")
# should not return any matches anymore
results2 = matcher(nlp(text))
assert len(results2) == 0
# removing again should throw an error
with pytest.raises(ValueError):
matcher.remove("Rule")


@ -12,24 +12,25 @@ from spacy.util import get_json_validator, validate_json
TEST_PATTERNS = [
# Bad patterns flagged in all cases
([{"XX": "foo"}], 1, 1),
([{"LENGTH": "2", "TEXT": 2}, {"LOWER": "test"}], 2, 1),
([{"IS_ALPHA": {"==": True}}, {"LIKE_NUM": None}], 2, 1),
([{"IS_PUNCT": True, "OP": "$"}], 1, 1),
([{"IS_DIGIT": -1}], 1, 1),
([{"ORTH": -1}], 1, 1),
([{"_": "foo"}], 1, 1),
('[{"TEXT": "foo"}, {"LOWER": "bar"}]', 1, 1),
([1, 2, 3], 3, 1),
# Bad patterns flagged outside of Matcher
([{"_": {"foo": "bar", "baz": {"IN": "foo"}}}], 1, 0),
# Bad patterns not flagged with minimal checks
([{"LENGTH": "2", "TEXT": 2}, {"LOWER": "test"}], 2, 0),
([{"LENGTH": {"IN": [1, 2, "3"]}}, {"POS": {"IN": "VERB"}}], 2, 0),
([{"LENGTH": {"VALUE": 5}}], 1, 0),
([{"TEXT": {"VALUE": "foo"}}], 1, 0),
([{"IS_DIGIT": -1}], 1, 0),
([{"ORTH": -1}], 1, 0),
# Good patterns
([{"TEXT": "foo"}, {"LOWER": "bar"}], 0, 0),
([{"LEMMA": {"IN": ["love", "like"]}}, {"POS": "DET", "OP": "?"}], 0, 0),
([{"LIKE_NUM": True, "LENGTH": {">=": 5}}], 0, 0),
([{"LENGTH": 2}], 0, 0),
([{"LOWER": {"REGEX": "^X", "NOT_IN": ["XXX", "XY"]}}], 0, 0),
([{"NORM": "a"}, {"POS": {"IN": ["NOUN"]}}], 0, 0),
([{"_": {"foo": {"NOT_IN": ["bar", "baz"]}, "a": 5, "b": {">": 10}}}], 0, 0),


@ -226,3 +226,13 @@ def test_phrase_matcher_callback(en_vocab):
matcher.add("COMPANY", mock, pattern)
matches = matcher(doc)
mock.assert_called_once_with(matcher, doc, 0, matches)
def test_phrase_matcher_remove_overlapping_patterns(en_vocab):
matcher = PhraseMatcher(en_vocab)
pattern1 = Doc(en_vocab, words=["this"])
pattern2 = Doc(en_vocab, words=["this", "is"])
pattern3 = Doc(en_vocab, words=["this", "is", "a"])
pattern4 = Doc(en_vocab, words=["this", "is", "a", "word"])
matcher.add("THIS", None, pattern1, pattern2, pattern3, pattern4)
matcher.remove("THIS")

View File

@ -103,7 +103,7 @@ def test_oracle_moves_missing_B(en_vocab):
moves.add_action(move_types.index("L"), label)
moves.add_action(move_types.index("U"), label)
moves.preprocess_gold(gold)
seq = moves.get_oracle_sequence(doc, gold)
moves.get_oracle_sequence(doc, gold)
def test_oracle_moves_whitespace(en_vocab):

View File

@ -131,6 +131,53 @@ def test_candidate_generation(nlp):
assert_almost_equal(mykb.get_candidates("adam")[0].prior_prob, 0.9)
def test_append_alias(nlp):
"""Test that we can append additional alias-entity pairs"""
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
# adding entities
mykb.add_entity(entity="Q1", freq=27, entity_vector=[1])
mykb.add_entity(entity="Q2", freq=12, entity_vector=[2])
mykb.add_entity(entity="Q3", freq=5, entity_vector=[3])
# adding aliases
mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.4, 0.1])
mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9])
# test the size of the relevant candidates
assert len(mykb.get_candidates("douglas")) == 2
# append an alias
mykb.append_alias(alias="douglas", entity="Q1", prior_prob=0.2)
# test that the number of relevant candidates has been incremented
assert len(mykb.get_candidates("douglas")) == 3
# appending the same alias-entity pair again should not work (will raise a warning)
mykb.append_alias(alias="douglas", entity="Q1", prior_prob=0.3)
# test that the number of relevant candidates remained unchanged
assert len(mykb.get_candidates("douglas")) == 3
def test_append_invalid_alias(nlp):
"""Test that append an alias will throw an error if prior probs are exceeding 1"""
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
# adding entities
mykb.add_entity(entity="Q1", freq=27, entity_vector=[1])
mykb.add_entity(entity="Q2", freq=12, entity_vector=[2])
mykb.add_entity(entity="Q3", freq=5, entity_vector=[3])
# adding aliases
mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.1])
mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9])
# append an alias - should fail because the prior probabilities would sum to more than 1
with pytest.raises(ValueError):
mykb.append_alias(alias="douglas", entity="Q1", prior_prob=0.2)
def test_preserving_links_asdoc(nlp):
"""Test that Span.as_doc preserves the existing entity links"""
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)

View File

@ -430,7 +430,7 @@ def test_issue957(en_tokenizer):
def test_issue999(train_data):
"""Test that adding entities and resuming training works passably OK.
There are two issues here:
1) We have to read labels. This isn't very nice.
1) We have to re-add labels. This isn't very nice.
2) There's no way to set the learning rate for the weight update, so we
end up out-of-scale, causing it to learn too fast.
"""

View File

@ -323,7 +323,7 @@ def test_issue3456():
nlp = English()
nlp.add_pipe(nlp.create_pipe("tagger"))
nlp.begin_training()
list(nlp.pipe(['hi', '']))
list(nlp.pipe(["hi", ""]))
def test_issue3468():

View File

@ -76,7 +76,6 @@ def test_issue4042_bug2():
output_dir.mkdir()
ner1.to_disk(output_dir)
nlp2 = English(vocab)
ner2 = EntityRecognizer(vocab)
ner2.from_disk(output_dir)
assert len(ner2.labels) == 2

View File

@ -1,13 +1,8 @@
# coding: utf8
from __future__ import unicode_literals
import pytest
import spacy
from spacy.lang.en import English
from spacy.pipeline import EntityRuler
from spacy.tokens import Span
def test_issue4267():

View File

@ -6,6 +6,6 @@ from spacy.tokens import DocBin
def test_issue4367():
"""Test that docbin init goes well"""
doc_bin_1 = DocBin()
doc_bin_2 = DocBin(attrs=["LEMMA"])
doc_bin_3 = DocBin(attrs=["LEMMA", "ENT_IOB", "ENT_TYPE"])
DocBin()
DocBin(attrs=["LEMMA"])
DocBin(attrs=["LEMMA", "ENT_IOB", "ENT_TYPE"])

View File

@ -74,4 +74,4 @@ def test_serialize_doc_bin():
# Deserialize later, e.g. in a new process
nlp = spacy.blank("en")
doc_bin = DocBin().from_bytes(bytes_data)
docs = list(doc_bin.get_docs(nlp.vocab))
list(doc_bin.get_docs(nlp.vocab))

View File

@ -48,8 +48,13 @@ URLS_SHOULD_MATCH = [
"http://a.b--c.de/", # this is a legit domain name see: https://gist.github.com/dperini/729294 comment on 9/9/2014
"ssh://login@server.com:12345/repository.git",
"svn+ssh://user@ssh.yourdomain.com/path",
pytest.param("chrome://extensions/?id=mhjfbmdgcfjbbpaeojofohoefgiehjai", marks=pytest.mark.xfail()),
pytest.param("chrome-extension://mhjfbmdgcfjbbpaeojofohoefgiehjai", marks=pytest.mark.xfail()),
pytest.param(
"chrome://extensions/?id=mhjfbmdgcfjbbpaeojofohoefgiehjai",
marks=pytest.mark.xfail(),
),
pytest.param(
"chrome-extension://mhjfbmdgcfjbbpaeojofohoefgiehjai", marks=pytest.mark.xfail()
),
pytest.param("http://foo.com/blah_blah_(wikipedia)", marks=pytest.mark.xfail()),
pytest.param(
"http://foo.com/blah_blah_(wikipedia)_(again)", marks=pytest.mark.xfail()

View File

@ -51,6 +51,14 @@ def data():
return numpy.asarray([[0.0, 1.0, 2.0], [3.0, -2.0, 4.0]], dtype="f")
@pytest.fixture
def most_similar_vectors_data():
return numpy.asarray(
[[0.0, 1.0, 2.0], [1.0, -2.0, 4.0], [1.0, 1.0, -1.0], [2.0, 3.0, 1.0]],
dtype="f",
)
@pytest.fixture
def resize_data():
return numpy.asarray([[0.0, 1.0], [2.0, 3.0]], dtype="f")
@ -127,6 +135,12 @@ def test_set_vector(strings, data):
assert list(v[strings[0]]) != list(orig[0])
def test_vectors_most_similar(most_similar_vectors_data):
v = Vectors(data=most_similar_vectors_data)
_, best_rows, _ = v.most_similar(v.data, batch_size=2, n=2, sort=True)
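# a vector's nearest neighbor is itself, so with sort=True column 0 of
# best_rows holds each row's own index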
assert all(row[0] == i for i, row in enumerate(best_rows))
@pytest.mark.parametrize("text", ["apple and orange"])
def test_vectors_token_vector(tokenizer_v, vectors, text):
doc = tokenizer_v(text)
@ -284,7 +298,7 @@ def test_vocab_prune_vectors():
vocab.set_vector("dog", data[1])
vocab.set_vector("kitten", data[2])
remap = vocab.prune_vectors(2)
remap = vocab.prune_vectors(2, batch_size=2)
assert list(remap.keys()) == ["kitten"]
neighbour, similarity = list(remap.values())[0]
assert neighbour == "cat", remap

View File

@ -666,7 +666,7 @@ def filter_spans(spans):
spans (iterable): The spans to filter.
RETURNS (list): The filtered spans.
"""
get_sort_key = lambda span: (span.end - span.start, span.start)
get_sort_key = lambda span: (span.end - span.start, -span.start)
sorted_spans = sorted(spans, key=get_sort_key, reverse=True)
result = []
seen_tokens = set()
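On equal span lengths, the new `-span.start` key makes the earlier span win the tie. A standalone sketch of the effect (assuming `filter_spans` is importable from `spacy.util`, as in recent releases):

```python
# A minimal sketch: with two overlapping spans of the same length,
# the (length, -start) sort key keeps the one that starts first.
import spacy
from spacy.util import filter_spans

nlp = spacy.blank("en")
doc = nlp("a b c d")
spans = [doc[1:3], doc[0:2]]  # same length, overlapping at token 1
assert [span.text for span in filter_spans(spans)] == ["a b"]
```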

View File

@ -336,8 +336,8 @@ cdef class Vectors:
best_rows[i:i+batch_size] = xp.argpartition(sims, -n, axis=1)[:,-n:]
scores[i:i+batch_size] = xp.partition(sims, -n, axis=1)[:,-n:]
if sort:
sorted_index = xp.arange(scores.shape[0])[:,None],xp.argsort(scores[i:i+batch_size], axis=1)[:,::-1]
if sort and n >= 2:
sorted_index = xp.arange(scores.shape[0])[:,None][i:i+batch_size],xp.argsort(scores[i:i+batch_size], axis=1)[:,::-1]
scores[i:i+batch_size] = scores[sorted_index]
best_rows[i:i+batch_size] = best_rows[sorted_index]
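With the batch slice applied to the row offsets as well, the fancy index pairs each batch row with its descending score order. The same top-n-then-sort step in plain NumPy (a sketch; the names here are illustrative, not from the source):

```python
# Illustrative NumPy-only sketch of the argpartition/argsort step patched above.
import numpy as np

sims = np.array([[0.1, 0.9, 0.4], [0.8, 0.2, 0.5]])
n = 2
best_rows = np.argpartition(sims, -n, axis=1)[:, -n:]  # top-n indices, unordered
scores = np.partition(sims, -n, axis=1)[:, -n:]        # top-n scores, unordered
order = np.arange(scores.shape[0])[:, None], np.argsort(scores, axis=1)[:, ::-1]
scores = scores[order]          # now sorted best-first per row
best_rows = best_rows[order]    # reordered to match
```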

View File

@ -62,7 +62,7 @@ Whether the provided syntactic annotations form a projective dependency tree.
Convert a list of Doc objects into the
[JSON-serializable format](/api/annotation#json-input) used by the
[`spacy train`](/api/cli#train) command.
[`spacy train`](/api/cli#train) command. Each input doc will be treated as a 'paragraph' in the output doc.
> #### Example
>
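> A minimal sketch, assuming a loaded `nlp` pipeline and the v2.x
> `spacy.gold` module:
>
> ```python
> from spacy.gold import docs_to_json
>
> doc = nlp("I like London")
> json_data = docs_to_json([doc])
> ```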
@ -77,7 +77,7 @@ Convert a list of Doc objects into the
| ----------- | ---------------- | ------------------------------------------ |
| `docs` | iterable / `Doc` | The `Doc` object(s) to convert. |
| `id` | int | ID to assign to the JSON. Defaults to `0`. |
| **RETURNS** | list | The data in spaCy's JSON format. |
| **RETURNS** | dict | The data in spaCy's JSON format. |
### gold.align {#align tag="function"}

View File

@ -54,7 +54,7 @@ Lemmatize a string.
> ```python
> from spacy.lemmatizer import Lemmatizer
> from spacy.lookups import Lookups
> lookups = Loookups()
> lookups = Lookups()
> lookups.add_table("lemma_rules", {"noun": [["s", ""]]})
> lemmatizer = Lemmatizer(lookups)
> lemmas = lemmatizer("ducks", "NOUN")