Mirror of https://github.com/explosion/spaCy.git (synced 2024-11-11 12:18:04 +03:00)

Commit 0b2df3b879: Merge branch 'master' into spacy.io
106
.github/contributors/PeterGilles.md
vendored
Normal file
|
@ -0,0 +1,106 @@
|
|||
# spaCy contributor agreement
|
||||
|
||||
This spaCy Contributor Agreement (**"SCA"**) is based on the
|
||||
[Oracle Contributor Agreement](http://www.oracle.com/technetwork/oca-405177.pdf).
|
||||
The SCA applies to any contribution that you make to any product or project
|
||||
managed by us (the **"project"**), and sets out the intellectual property rights
|
||||
you grant to us in the contributed materials. The term **"us"** shall mean
|
||||
[ExplosionAI GmbH](https://explosion.ai/legal). The term
|
||||
**"you"** shall mean the person or entity identified below.
|
||||
|
||||
If you agree to be bound by these terms, fill in the information requested
|
||||
below and include the filled-in version with your first pull request, under the
|
||||
folder [`.github/contributors/`](/.github/contributors/). The name of the file
|
||||
should be your GitHub username, with the extension `.md`. For example, the user
|
||||
example_user would create the file `.github/contributors/example_user.md`.
|
||||
|
||||
Read this agreement carefully before signing. These terms and conditions
|
||||
constitute a binding legal agreement.
|
||||
|
||||
## Contributor Agreement
|
||||
|
||||
1. The term "contribution" or "contributed materials" means any source code,
|
||||
object code, patch, tool, sample, graphic, specification, manual,
|
||||
documentation, or any other material posted or submitted by you to the project.
|
||||
|
||||
2. With respect to any worldwide copyrights, or copyright applications and
|
||||
registrations, in your contribution:
|
||||
|
||||
* you hereby assign to us joint ownership, and to the extent that such
|
||||
assignment is or becomes invalid, ineffective or unenforceable, you hereby
|
||||
grant to us a perpetual, irrevocable, non-exclusive, worldwide, no-charge,
|
||||
royalty-free, unrestricted license to exercise all rights under those
|
||||
copyrights. This includes, at our option, the right to sublicense these same
|
||||
rights to third parties through multiple levels of sublicensees or other
|
||||
licensing arrangements;
|
||||
|
||||
* you agree that each of us can do all things in relation to your
|
||||
contribution as if each of us were the sole owners, and if one of us makes
|
||||
a derivative work of your contribution, the one who makes the derivative
|
||||
work (or has it made) will be the sole owner of that derivative work;
|
||||
|
||||
* you agree that you will not assert any moral rights in your contribution
|
||||
against us, our licensees or transferees;
|
||||
|
||||
* you agree that we may register a copyright in your contribution and
|
||||
exercise all ownership rights associated with it; and
|
||||
|
||||
* you agree that neither of us has any duty to consult with, obtain the
|
||||
consent of, pay or render an accounting to the other for any use or
|
||||
distribution of your contribution.
|
||||
|
||||
3. With respect to any patents you own, or that you can license without payment
|
||||
to any third party, you hereby grant to us a perpetual, irrevocable,
|
||||
non-exclusive, worldwide, no-charge, royalty-free license to:
|
||||
|
||||
* make, have made, use, sell, offer to sell, import, and otherwise transfer
|
||||
your contribution in whole or in part, alone or in combination with or
|
||||
included in any product, work or materials arising out of the project to
|
||||
which your contribution was submitted, and
|
||||
|
||||
* at our option, to sublicense these same rights to third parties through
|
||||
multiple levels of sublicensees or other licensing arrangements.
|
||||
|
||||
4. Except as set out above, you keep all right, title, and interest in your
|
||||
contribution. The rights that you grant to us under these terms are effective
|
||||
on the date you first submitted a contribution to us, even if your submission
|
||||
took place before the date you sign these terms.
|
||||
|
||||
5. You covenant, represent, warrant and agree that:
|
||||
|
||||
* Each contribution that you submit is and shall be an original work of
|
||||
authorship and you can legally grant the rights set out in this SCA;
|
||||
|
||||
* to the best of your knowledge, each contribution will not violate any
|
||||
third party's copyrights, trademarks, patents, or other intellectual
|
||||
property rights; and
|
||||
|
||||
* each contribution shall be in compliance with U.S. export control laws and
|
||||
other applicable export and import laws. You agree to notify us if you
|
||||
become aware of any circumstance which would make any of the foregoing
|
||||
representations inaccurate in any respect. We may publicly disclose your
|
||||
participation in the project, including the fact that you have signed the SCA.
|
||||
|
||||
6. This SCA is governed by the laws of the State of California and applicable
|
||||
U.S. Federal law. Any choice of law rules will not apply.
|
||||
|
||||
7. Please place an “x” on one of the applicable statements below. Please do NOT
|
||||
mark both statements:
|
||||
|
||||
* [X] I am signing on behalf of myself as an individual and no other person
|
||||
or entity, including my employer, has or will have rights with respect to my
|
||||
contributions.
|
||||
|
||||
* [ ] I am signing on behalf of my employer or a legal entity and I have the
|
||||
actual authority to contractually bind that entity.
|
||||
|
||||
## Contributor Details
|
||||
|
||||
| Field | Entry |
|
||||
|------------------------------- | -------------------- |
|
||||
| Name | Peter Gilles |
|
||||
| Company name (if applicable) | |
|
||||
| Title or role (if applicable) | |
|
||||
| Date | 10.10. |
|
||||
| GitHub username | Peter Gilles |
|
||||
| Website (optional) | |
|
|
@ -197,7 +197,7 @@ path to the model data directory.
|
|||
```python
|
||||
import spacy
|
||||
nlp = spacy.load("en_core_web_sm")
|
||||
doc = nlp(u"This is a sentence.")
|
||||
doc = nlp("This is a sentence.")
|
||||
```
|
||||
|
||||
You can also `import` a model directly via its full name and then call its
|
||||
|
@ -208,7 +208,7 @@ import spacy
|
|||
import en_core_web_sm
|
||||
|
||||
nlp = en_core_web_sm.load()
|
||||
doc = nlp(u"This is a sentence.")
|
||||
doc = nlp("This is a sentence.")
|
||||
```
|
||||
|
||||
📖 **For more info and examples, check out the
|
||||
|
|
34
bin/wiki_entity_linking/README.md
Normal file
|
@ -0,0 +1,34 @@
## Entity Linking with Wikipedia and Wikidata

### Step 1: Create a Knowledge Base (KB) and training data

Run `wikipedia_pretrain_kb.py`
* This takes as input the locations of a **Wikipedia and a Wikidata dump**, and produces a **KB directory** + **training file** (a sketch for inspecting the resulting KB follows this list)
  * Wikidata: get `latest-all.json.bz2` from https://dumps.wikimedia.org/wikidatawiki/entities/
  * Wikipedia: get `enwiki-latest-pages-articles-multistream.xml.bz2` from https://dumps.wikimedia.org/enwiki/latest/ (or the corresponding dump for any other language)
* You can set the filtering parameters for KB construction:
  * `max_per_alias`: maximum number of candidate entities in the KB per alias/synonym
  * `min_freq`: minimum number of times an entity must occur in the corpus to be included in the KB
  * `min_pair`: minimum number of times an entity+alias combination must occur in the corpus to be included in the KB
* Further parameters to set:
  * `descriptions_from_wikipedia`: whether to parse descriptions from Wikipedia (`True`) or Wikidata (`False`)
  * `entity_vector_length`: length of the pre-trained entity description vectors
  * `lang`: language for which to fetch Wikidata information (as the dump contains all languages)

Quick testing and rerunning:
* When trying out the pipeline for a quick test, set `limit_prior`, `limit_train` and/or `limit_wd` to read only parts of the dumps instead of everything.
* If you only want to (re)run certain parts of the pipeline, just remove the corresponding files and they will be recalculated or reparsed.
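
As a quick sanity check of this step's output, the serialized KB can be loaded back with spaCy's `KnowledgeBase` API, as in the minimal sketch below. The file names inside the output directory (here assumed to be a `kb` file and an `nlp` model folder) are placeholders; adjust them to your own settings.

```python
# Minimal sketch for inspecting a KB produced by Step 1.
# The paths below are assumptions - point them at your own output directory.
import spacy
from spacy.kb import KnowledgeBase

nlp = spacy.load("output_dir/nlp")       # the model stored alongside the KB
kb = KnowledgeBase(vocab=nlp.vocab)
kb.load_bulk("output_dir/kb")            # assumed name of the serialized KB file
print("entities:", kb.get_size_entities())
print("aliases:", kb.get_size_aliases())
for c in kb.get_candidates("Douglas Adams"):
    print(c.entity_, c.prior_prob)       # candidate QID and its prior probability
```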

### Step 2: Train an Entity Linking model

Run `wikidata_train_entity_linker.py`
* This takes the **KB directory** produced by Step 1, and trains an **Entity Linking model** (a usage sketch for the trained pipeline follows this list)
* You can set the learning parameters for the EL training:
  * `epochs`: number of training iterations
  * `dropout`: dropout rate
  * `lr`: learning rate
  * `l2`: L2 regularization
* Specify the number of training and dev-testing entities with `train_inst` and `dev_inst`, respectively
* Further parameters to set:
  * `labels_discard`: NER label types to discard during training
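
Once training has finished, the resulting pipeline can be applied to new text like any other spaCy model. The sketch below assumes the trained pipeline was written to an `nlp` folder inside the output directory; treat the path as a placeholder.

```python
# Minimal sketch for applying a trained Entity Linking pipeline (path is an assumption).
import spacy

nlp = spacy.load("output_dir/nlp")
doc = nlp("Douglas Adams wrote The Hitchhiker's Guide to the Galaxy.")
for ent in doc.ents:
    # kb_id_ holds the predicted Wikidata identifier, or "NIL" if no candidate was linked
    print(ent.text, ent.label_, ent.kb_id_)
```
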
@ -6,6 +6,7 @@ OUTPUT_MODEL_DIR = "nlp"
|
|||
PRIOR_PROB_PATH = "prior_prob.csv"
|
||||
ENTITY_DEFS_PATH = "entity_defs.csv"
|
||||
ENTITY_FREQ_PATH = "entity_freq.csv"
|
||||
ENTITY_ALIAS_PATH = "entity_alias.csv"
|
||||
ENTITY_DESCR_PATH = "entity_descriptions.csv"
|
||||
|
||||
LOG_FORMAT = '%(asctime)s - %(levelname)s - %(name)s - %(message)s'
|
||||
|
|
|
@ -15,10 +15,11 @@ class Metrics(object):
|
|||
candidate_is_correct = true_entity == candidate
|
||||
|
||||
# Assume that we have no labeled negatives in the data (i.e. cases where true_entity is "NIL")
|
||||
# Therefore, if candidate_is_correct then we have a true positive and never a true negative
|
||||
# Therefore, if candidate_is_correct then we have a true positive and never a true negative.
|
||||
self.true_pos += candidate_is_correct
|
||||
self.false_neg += not candidate_is_correct
|
||||
if candidate not in {"", "NIL"}:
|
||||
if candidate and candidate not in {"", "NIL"}:
|
||||
# A wrong prediction (e.g. Q42 != Q3) counts both as a FP as well as a FN.
|
||||
self.false_pos += not candidate_is_correct
|
||||
|
||||
def calculate_precision(self):
|
||||
|
@ -33,6 +34,14 @@ class Metrics(object):
|
|||
else:
|
||||
return self.true_pos / (self.true_pos + self.false_neg)
|
||||
|
||||
def calculate_fscore(self):
|
||||
p = self.calculate_precision()
|
||||
r = self.calculate_recall()
|
||||
if p + r == 0:
|
||||
return 0.0
|
||||
else:
|
||||
return 2 * p * r / (p + r)
|
||||
|
||||
|
||||
class EvaluationResults(object):
|
||||
def __init__(self):
|
||||
|
@ -43,18 +52,20 @@ class EvaluationResults(object):
|
|||
self.metrics.update_results(true_entity, candidate)
|
||||
self.metrics_by_label[ent_label].update_results(true_entity, candidate)
|
||||
|
||||
def increment_false_negatives(self):
|
||||
self.metrics.false_neg += 1
|
||||
|
||||
def report_metrics(self, model_name):
|
||||
model_str = model_name.title()
|
||||
recall = self.metrics.calculate_recall()
|
||||
precision = self.metrics.calculate_precision()
|
||||
return ("{}: ".format(model_str) +
|
||||
"Recall = {} | ".format(round(recall, 3)) +
|
||||
"Precision = {} | ".format(round(precision, 3)) +
|
||||
"Precision by label = {}".format({k: v.calculate_precision()
|
||||
for k, v in self.metrics_by_label.items()}))
|
||||
fscore = self.metrics.calculate_fscore()
|
||||
return (
|
||||
"{}: ".format(model_str)
|
||||
+ "F-score = {} | ".format(round(fscore, 3))
|
||||
+ "Recall = {} | ".format(round(recall, 3))
|
||||
+ "Precision = {} | ".format(round(precision, 3))
|
||||
+ "F-score by label = {}".format(
|
||||
{k: v.calculate_fscore() for k, v in sorted(self.metrics_by_label.items())}
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class BaselineResults(object):
|
||||
|
@ -63,40 +74,51 @@ class BaselineResults(object):
|
|||
self.prior = EvaluationResults()
|
||||
self.oracle = EvaluationResults()
|
||||
|
||||
def report_accuracy(self, model):
|
||||
def report_performance(self, model):
|
||||
results = getattr(self, model)
|
||||
return results.report_metrics(model)
|
||||
|
||||
def update_baselines(self, true_entity, ent_label, random_candidate, prior_candidate, oracle_candidate):
|
||||
def update_baselines(
|
||||
self,
|
||||
true_entity,
|
||||
ent_label,
|
||||
random_candidate,
|
||||
prior_candidate,
|
||||
oracle_candidate,
|
||||
):
|
||||
self.oracle.update_metrics(ent_label, true_entity, oracle_candidate)
|
||||
self.prior.update_metrics(ent_label, true_entity, prior_candidate)
|
||||
self.random.update_metrics(ent_label, true_entity, random_candidate)
|
||||
|
||||
|
||||
def measure_performance(dev_data, kb, el_pipe):
|
||||
baseline_accuracies = measure_baselines(
|
||||
dev_data, kb
|
||||
)
|
||||
def measure_performance(dev_data, kb, el_pipe, baseline=True, context=True):
|
||||
if baseline:
|
||||
baseline_accuracies, counts = measure_baselines(dev_data, kb)
|
||||
logger.info("Counts: {}".format({k: v for k, v in sorted(counts.items())}))
|
||||
logger.info(baseline_accuracies.report_performance("random"))
|
||||
logger.info(baseline_accuracies.report_performance("prior"))
|
||||
logger.info(baseline_accuracies.report_performance("oracle"))
|
||||
|
||||
logger.info(baseline_accuracies.report_accuracy("random"))
|
||||
logger.info(baseline_accuracies.report_accuracy("prior"))
|
||||
logger.info(baseline_accuracies.report_accuracy("oracle"))
|
||||
if context:
|
||||
# using only context
|
||||
el_pipe.cfg["incl_context"] = True
|
||||
el_pipe.cfg["incl_prior"] = False
|
||||
results = get_eval_results(dev_data, el_pipe)
|
||||
logger.info(results.report_metrics("context only"))
|
||||
|
||||
# using only context
|
||||
el_pipe.cfg["incl_context"] = True
|
||||
el_pipe.cfg["incl_prior"] = False
|
||||
results = get_eval_results(dev_data, el_pipe)
|
||||
logger.info(results.report_metrics("context only"))
|
||||
|
||||
# measuring combined accuracy (prior + context)
|
||||
el_pipe.cfg["incl_context"] = True
|
||||
el_pipe.cfg["incl_prior"] = True
|
||||
results = get_eval_results(dev_data, el_pipe)
|
||||
logger.info(results.report_metrics("context and prior"))
|
||||
# measuring combined accuracy (prior + context)
|
||||
el_pipe.cfg["incl_context"] = True
|
||||
el_pipe.cfg["incl_prior"] = True
|
||||
results = get_eval_results(dev_data, el_pipe)
|
||||
logger.info(results.report_metrics("context and prior"))
|
||||
|
||||
|
||||
def get_eval_results(data, el_pipe=None):
|
||||
# If the docs in the data require further processing with an entity linker, set el_pipe
|
||||
"""
|
||||
Evaluate the ent.kb_id_ annotations against the gold standard.
|
||||
Only evaluate entities that overlap between gold and NER, to isolate the performance of the NEL.
|
||||
If the docs in the data require further processing with an entity linker, set el_pipe.
|
||||
"""
|
||||
from tqdm import tqdm
|
||||
|
||||
docs = []
|
||||
|
@ -111,18 +133,15 @@ def get_eval_results(data, el_pipe=None):
|
|||
|
||||
results = EvaluationResults()
|
||||
for doc, gold in zip(docs, golds):
|
||||
tagged_entries_per_article = {_offset(ent.start_char, ent.end_char): ent for ent in doc.ents}
|
||||
try:
|
||||
correct_entries_per_article = dict()
|
||||
for entity, kb_dict in gold.links.items():
|
||||
start, end = entity
|
||||
# only evaluating on positive examples
|
||||
for gold_kb, value in kb_dict.items():
|
||||
if value:
|
||||
# only evaluating on positive examples
|
||||
offset = _offset(start, end)
|
||||
correct_entries_per_article[offset] = gold_kb
|
||||
if offset not in tagged_entries_per_article:
|
||||
results.increment_false_negatives()
|
||||
|
||||
for ent in doc.ents:
|
||||
ent_label = ent.label_
|
||||
|
@ -142,7 +161,11 @@ def get_eval_results(data, el_pipe=None):
|
|||
|
||||
|
||||
def measure_baselines(data, kb):
|
||||
# Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound
|
||||
"""
|
||||
Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound.
|
||||
Only evaluate entities that overlap between gold and NER, to isolate the performance of the NEL.
|
||||
Also return a dictionary of counts by entity label.
|
||||
"""
|
||||
counts_d = dict()
|
||||
|
||||
baseline_results = BaselineResults()
|
||||
|
@ -152,7 +175,6 @@ def measure_baselines(data, kb):
|
|||
|
||||
for doc, gold in zip(docs, golds):
|
||||
correct_entries_per_article = dict()
|
||||
tagged_entries_per_article = {_offset(ent.start_char, ent.end_char): ent for ent in doc.ents}
|
||||
for entity, kb_dict in gold.links.items():
|
||||
start, end = entity
|
||||
for gold_kb, value in kb_dict.items():
|
||||
|
@ -160,10 +182,6 @@ def measure_baselines(data, kb):
|
|||
if value:
|
||||
offset = _offset(start, end)
|
||||
correct_entries_per_article[offset] = gold_kb
|
||||
if offset not in tagged_entries_per_article:
|
||||
baseline_results.random.increment_false_negatives()
|
||||
baseline_results.oracle.increment_false_negatives()
|
||||
baseline_results.prior.increment_false_negatives()
|
||||
|
||||
for ent in doc.ents:
|
||||
ent_label = ent.label_
|
||||
|
@ -176,7 +194,7 @@ def measure_baselines(data, kb):
|
|||
if gold_entity is not None:
|
||||
candidates = kb.get_candidates(ent.text)
|
||||
oracle_candidate = ""
|
||||
best_candidate = ""
|
||||
prior_candidate = ""
|
||||
random_candidate = ""
|
||||
if candidates:
|
||||
scores = []
|
||||
|
@ -187,13 +205,21 @@ def measure_baselines(data, kb):
|
|||
oracle_candidate = c.entity_
|
||||
|
||||
best_index = scores.index(max(scores))
|
||||
best_candidate = candidates[best_index].entity_
|
||||
prior_candidate = candidates[best_index].entity_
|
||||
random_candidate = random.choice(candidates).entity_
|
||||
|
||||
baseline_results.update_baselines(gold_entity, ent_label,
|
||||
random_candidate, best_candidate, oracle_candidate)
|
||||
current_count = counts_d.get(ent_label, 0)
|
||||
counts_d[ent_label] = current_count+1
|
||||
|
||||
return baseline_results
|
||||
baseline_results.update_baselines(
|
||||
gold_entity,
|
||||
ent_label,
|
||||
random_candidate,
|
||||
prior_candidate,
|
||||
oracle_candidate,
|
||||
)
|
||||
|
||||
return baseline_results, counts_d
|
||||
|
||||
|
||||
def _offset(start, end):
|
||||
|
|
|
@ -1,17 +1,12 @@
|
|||
# coding: utf-8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import csv
|
||||
import logging
|
||||
import spacy
|
||||
import sys
|
||||
|
||||
from spacy.kb import KnowledgeBase
|
||||
|
||||
from bin.wiki_entity_linking import wikipedia_processor as wp
|
||||
from bin.wiki_entity_linking.train_descriptions import EntityEncoder
|
||||
|
||||
csv.field_size_limit(sys.maxsize)
|
||||
from bin.wiki_entity_linking import wiki_io as io
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -22,18 +17,24 @@ def create_kb(
|
|||
max_entities_per_alias,
|
||||
min_entity_freq,
|
||||
min_occ,
|
||||
entity_def_input,
|
||||
entity_def_path,
|
||||
entity_descr_path,
|
||||
count_input,
|
||||
prior_prob_input,
|
||||
entity_alias_path,
|
||||
entity_freq_path,
|
||||
prior_prob_path,
|
||||
entity_vector_length,
|
||||
):
|
||||
# Create the knowledge base from Wikidata entries
|
||||
kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=entity_vector_length)
|
||||
entity_list, filtered_title_to_id = _define_entities(nlp, kb, entity_def_path, entity_descr_path, min_entity_freq, entity_freq_path, entity_vector_length)
|
||||
_define_aliases(kb, entity_alias_path, entity_list, filtered_title_to_id, max_entities_per_alias, min_occ, prior_prob_path)
|
||||
return kb
|
||||
|
||||
|
||||
def _define_entities(nlp, kb, entity_def_path, entity_descr_path, min_entity_freq, entity_freq_path, entity_vector_length):
|
||||
# read the mappings from file
|
||||
title_to_id = get_entity_to_id(entity_def_input)
|
||||
id_to_descr = get_id_to_description(entity_descr_path)
|
||||
title_to_id = io.read_title_to_id(entity_def_path)
|
||||
id_to_descr = io.read_id_to_descr(entity_descr_path)
|
||||
|
||||
# check the length of the nlp vectors
|
||||
if "vectors" in nlp.meta and nlp.vocab.vectors.size:
|
||||
|
@ -45,10 +46,8 @@ def create_kb(
|
|||
" cf. https://spacy.io/usage/models#languages."
|
||||
)
|
||||
|
||||
logger.info("Get entity frequencies")
|
||||
entity_frequencies = wp.get_all_frequencies(count_input=count_input)
|
||||
|
||||
logger.info("Filtering entities with fewer than {} mentions".format(min_entity_freq))
|
||||
entity_frequencies = io.read_entity_to_count(entity_freq_path)
|
||||
# filter the entities in the KB by frequency, because there's just too much data (8M entities) otherwise
|
||||
filtered_title_to_id, entity_list, description_list, frequency_list = get_filtered_entities(
|
||||
title_to_id,
|
||||
|
@ -56,36 +55,33 @@ def create_kb(
|
|||
entity_frequencies,
|
||||
min_entity_freq
|
||||
)
|
||||
logger.info("Left with {} entities".format(len(description_list)))
|
||||
logger.info("Kept {} entities from the set of {}".format(len(description_list), len(title_to_id.keys())))
|
||||
|
||||
logger.info("Train entity encoder")
|
||||
logger.info("Training entity encoder")
|
||||
encoder = EntityEncoder(nlp, input_dim, entity_vector_length)
|
||||
encoder.train(description_list=description_list, to_print=True)
|
||||
|
||||
logger.info("Get entity embeddings:")
|
||||
logger.info("Getting entity embeddings")
|
||||
embeddings = encoder.apply_encoder(description_list)
|
||||
|
||||
logger.info("Adding {} entities".format(len(entity_list)))
|
||||
kb.set_entities(
|
||||
entity_list=entity_list, freq_list=frequency_list, vector_list=embeddings
|
||||
)
|
||||
return entity_list, filtered_title_to_id
|
||||
|
||||
logger.info("Adding aliases")
|
||||
|
||||
def _define_aliases(kb, entity_alias_path, entity_list, filtered_title_to_id, max_entities_per_alias, min_occ, prior_prob_path):
|
||||
logger.info("Adding aliases from Wikipedia and Wikidata")
|
||||
_add_aliases(
|
||||
kb,
|
||||
entity_list=entity_list,
|
||||
title_to_id=filtered_title_to_id,
|
||||
max_entities_per_alias=max_entities_per_alias,
|
||||
min_occ=min_occ,
|
||||
prior_prob_input=prior_prob_input,
|
||||
prior_prob_path=prior_prob_path,
|
||||
)
|
||||
|
||||
logger.info("KB size: {} entities, {} aliases".format(
|
||||
kb.get_size_entities(),
|
||||
kb.get_size_aliases()))
|
||||
|
||||
logger.info("Done with kb")
|
||||
return kb
|
||||
|
||||
|
||||
def get_filtered_entities(title_to_id, id_to_descr, entity_frequencies,
|
||||
min_entity_freq: int = 10):
|
||||
|
@ -104,34 +100,13 @@ def get_filtered_entities(title_to_id, id_to_descr, entity_frequencies,
|
|||
return filtered_title_to_id, entity_list, description_list, frequency_list
|
||||
|
||||
|
||||
def get_entity_to_id(entity_def_output):
|
||||
entity_to_id = dict()
|
||||
with entity_def_output.open("r", encoding="utf8") as csvfile:
|
||||
csvreader = csv.reader(csvfile, delimiter="|")
|
||||
# skip header
|
||||
next(csvreader)
|
||||
for row in csvreader:
|
||||
entity_to_id[row[0]] = row[1]
|
||||
return entity_to_id
|
||||
|
||||
|
||||
def get_id_to_description(entity_descr_path):
|
||||
id_to_desc = dict()
|
||||
with entity_descr_path.open("r", encoding="utf8") as csvfile:
|
||||
csvreader = csv.reader(csvfile, delimiter="|")
|
||||
# skip header
|
||||
next(csvreader)
|
||||
for row in csvreader:
|
||||
id_to_desc[row[0]] = row[1]
|
||||
return id_to_desc
|
||||
|
||||
|
||||
def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_input):
|
||||
def _add_aliases(kb, entity_list, title_to_id, max_entities_per_alias, min_occ, prior_prob_path):
|
||||
wp_titles = title_to_id.keys()
|
||||
|
||||
# adding aliases with prior probabilities
|
||||
# we can read this file sequentially, it's sorted by alias, and then by count
|
||||
with prior_prob_input.open("r", encoding="utf8") as prior_file:
|
||||
logger.info("Adding WP aliases")
|
||||
with prior_prob_path.open("r", encoding="utf8") as prior_file:
|
||||
# skip header
|
||||
prior_file.readline()
|
||||
line = prior_file.readline()
|
||||
|
@ -180,10 +155,7 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in
|
|||
line = prior_file.readline()
|
||||
|
||||
|
||||
def read_nlp_kb(model_dir, kb_file):
|
||||
nlp = spacy.load(model_dir)
|
||||
def read_kb(nlp, kb_file):
|
||||
kb = KnowledgeBase(vocab=nlp.vocab)
|
||||
kb.load_bulk(kb_file)
|
||||
logger.info("kb entities: {}".format(kb.get_size_entities()))
|
||||
logger.info("kb aliases: {}".format(kb.get_size_aliases()))
|
||||
return nlp, kb
|
||||
return kb
|
||||
|
|
|
@ -53,7 +53,7 @@ class EntityEncoder:
|
|||
|
||||
start = start + batch_size
|
||||
stop = min(stop + batch_size, len(description_list))
|
||||
logger.info("encoded: {} entities".format(stop))
|
||||
logger.info("Encoded: {} entities".format(stop))
|
||||
|
||||
return encodings
|
||||
|
||||
|
@ -62,7 +62,7 @@ class EntityEncoder:
|
|||
if to_print:
|
||||
logger.info(
|
||||
"Trained entity descriptions on {} ".format(processed) +
|
||||
"(non-unique) entities across {} ".format(self.epochs) +
|
||||
"(non-unique) descriptions across {} ".format(self.epochs) +
|
||||
"epochs"
|
||||
)
|
||||
logger.info("Final loss: {}".format(loss))
|
||||
|
|
|
@ -1,395 +0,0 @@
|
|||
# coding: utf-8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import logging
|
||||
import random
|
||||
import re
|
||||
import bz2
|
||||
import json
|
||||
|
||||
from functools import partial
|
||||
|
||||
from spacy.gold import GoldParse
|
||||
from bin.wiki_entity_linking import kb_creator
|
||||
|
||||
"""
|
||||
Process Wikipedia interlinks to generate a training dataset for the EL algorithm.
|
||||
Gold-standard entities are stored in one file in standoff format (by character offset).
|
||||
"""
|
||||
|
||||
ENTITY_FILE = "gold_entities.csv"
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_training_examples_and_descriptions(wikipedia_input,
|
||||
entity_def_input,
|
||||
description_output,
|
||||
training_output,
|
||||
parse_descriptions,
|
||||
limit=None):
|
||||
wp_to_id = kb_creator.get_entity_to_id(entity_def_input)
|
||||
_process_wikipedia_texts(wikipedia_input,
|
||||
wp_to_id,
|
||||
description_output,
|
||||
training_output,
|
||||
parse_descriptions,
|
||||
limit)
|
||||
|
||||
|
||||
def _process_wikipedia_texts(wikipedia_input,
|
||||
wp_to_id,
|
||||
output,
|
||||
training_output,
|
||||
parse_descriptions,
|
||||
limit=None):
|
||||
"""
|
||||
Read the XML wikipedia data to parse out training data:
|
||||
raw text data + positive instances
|
||||
"""
|
||||
title_regex = re.compile(r"(?<=<title>).*(?=</title>)")
|
||||
id_regex = re.compile(r"(?<=<id>)\d*(?=</id>)")
|
||||
|
||||
read_ids = set()
|
||||
|
||||
with output.open("a", encoding="utf8") as descr_file, training_output.open("w", encoding="utf8") as entity_file:
|
||||
if parse_descriptions:
|
||||
_write_training_description(descr_file, "WD_id", "description")
|
||||
with bz2.open(wikipedia_input, mode="rb") as file:
|
||||
article_count = 0
|
||||
article_text = ""
|
||||
article_title = None
|
||||
article_id = None
|
||||
reading_text = False
|
||||
reading_revision = False
|
||||
|
||||
logger.info("Processed {} articles".format(article_count))
|
||||
|
||||
for line in file:
|
||||
clean_line = line.strip().decode("utf-8")
|
||||
|
||||
if clean_line == "<revision>":
|
||||
reading_revision = True
|
||||
elif clean_line == "</revision>":
|
||||
reading_revision = False
|
||||
|
||||
# Start reading new page
|
||||
if clean_line == "<page>":
|
||||
article_text = ""
|
||||
article_title = None
|
||||
article_id = None
|
||||
# finished reading this page
|
||||
elif clean_line == "</page>":
|
||||
if article_id:
|
||||
clean_text, entities = _process_wp_text(
|
||||
article_title,
|
||||
article_text,
|
||||
wp_to_id
|
||||
)
|
||||
if clean_text is not None and entities is not None:
|
||||
_write_training_entities(entity_file,
|
||||
article_id,
|
||||
clean_text,
|
||||
entities)
|
||||
|
||||
if article_title in wp_to_id and parse_descriptions:
|
||||
description = " ".join(clean_text[:1000].split(" ")[:-1])
|
||||
_write_training_description(
|
||||
descr_file,
|
||||
wp_to_id[article_title],
|
||||
description
|
||||
)
|
||||
article_count += 1
|
||||
if article_count % 10000 == 0:
|
||||
logger.info("Processed {} articles".format(article_count))
|
||||
if limit and article_count >= limit:
|
||||
break
|
||||
article_text = ""
|
||||
article_title = None
|
||||
article_id = None
|
||||
reading_text = False
|
||||
reading_revision = False
|
||||
|
||||
# start reading text within a page
|
||||
if "<text" in clean_line:
|
||||
reading_text = True
|
||||
|
||||
if reading_text:
|
||||
article_text += " " + clean_line
|
||||
|
||||
# stop reading text within a page (we assume a new page doesn't start on the same line)
|
||||
if "</text" in clean_line:
|
||||
reading_text = False
|
||||
|
||||
# read the ID of this article (outside the revision portion of the document)
|
||||
if not reading_revision:
|
||||
ids = id_regex.search(clean_line)
|
||||
if ids:
|
||||
article_id = ids[0]
|
||||
if article_id in read_ids:
|
||||
logger.info(
|
||||
"Found duplicate article ID", article_id, clean_line
|
||||
) # This should never happen ...
|
||||
read_ids.add(article_id)
|
||||
|
||||
# read the title of this article (outside the revision portion of the document)
|
||||
if not reading_revision:
|
||||
titles = title_regex.search(clean_line)
|
||||
if titles:
|
||||
article_title = titles[0].strip()
|
||||
logger.info("Finished. Processed {} articles".format(article_count))
|
||||
|
||||
|
||||
text_regex = re.compile(r"(?<=<text xml:space=\"preserve\">).*(?=</text)")
|
||||
info_regex = re.compile(r"{[^{]*?}")
|
||||
htlm_regex = re.compile(r"<!--[^-]*-->")
|
||||
category_regex = re.compile(r"\[\[Category:[^\[]*]]")
|
||||
file_regex = re.compile(r"\[\[File:[^[\]]+]]")
|
||||
ref_regex = re.compile(r"<ref.*?>") # non-greedy
|
||||
ref_2_regex = re.compile(r"</ref.*?>") # non-greedy
|
||||
|
||||
|
||||
def _process_wp_text(article_title, article_text, wp_to_id):
|
||||
# ignore meta Wikipedia pages
|
||||
if (
|
||||
article_title.startswith("Wikipedia:") or
|
||||
article_title.startswith("Kategori:")
|
||||
):
|
||||
return None, None
|
||||
|
||||
# remove the text tags
|
||||
text_search = text_regex.search(article_text)
|
||||
if text_search is None:
|
||||
return None, None
|
||||
text = text_search.group(0)
|
||||
|
||||
# stop processing if this is a redirect page
|
||||
if text.startswith("#REDIRECT"):
|
||||
return None, None
|
||||
|
||||
# get the raw text without markup etc, keeping only interwiki links
|
||||
clean_text, entities = _remove_links(_get_clean_wp_text(text), wp_to_id)
|
||||
return clean_text, entities
|
||||
|
||||
|
||||
def _get_clean_wp_text(article_text):
|
||||
clean_text = article_text.strip()
|
||||
|
||||
# remove bolding & italic markup
|
||||
clean_text = clean_text.replace("'''", "")
|
||||
clean_text = clean_text.replace("''", "")
|
||||
|
||||
# remove nested {{info}} statements by removing the inner/smallest ones first and iterating
|
||||
try_again = True
|
||||
previous_length = len(clean_text)
|
||||
while try_again:
|
||||
clean_text = info_regex.sub(
|
||||
"", clean_text
|
||||
) # non-greedy match excluding a nested {
|
||||
if len(clean_text) < previous_length:
|
||||
try_again = True
|
||||
else:
|
||||
try_again = False
|
||||
previous_length = len(clean_text)
|
||||
|
||||
# remove HTML comments
|
||||
clean_text = htlm_regex.sub("", clean_text)
|
||||
|
||||
# remove Category and File statements
|
||||
clean_text = category_regex.sub("", clean_text)
|
||||
clean_text = file_regex.sub("", clean_text)
|
||||
|
||||
# remove multiple =
|
||||
while "==" in clean_text:
|
||||
clean_text = clean_text.replace("==", "=")
|
||||
|
||||
clean_text = clean_text.replace(". =", ".")
|
||||
clean_text = clean_text.replace(" = ", ". ")
|
||||
clean_text = clean_text.replace("= ", ".")
|
||||
clean_text = clean_text.replace(" =", "")
|
||||
|
||||
# remove refs (non-greedy match)
|
||||
clean_text = ref_regex.sub("", clean_text)
|
||||
clean_text = ref_2_regex.sub("", clean_text)
|
||||
|
||||
# remove additional wikiformatting
|
||||
clean_text = re.sub(r"<blockquote>", "", clean_text)
|
||||
clean_text = re.sub(r"</blockquote>", "", clean_text)
|
||||
|
||||
# change special characters back to normal ones
|
||||
clean_text = clean_text.replace(r"<", "<")
|
||||
clean_text = clean_text.replace(r">", ">")
|
||||
clean_text = clean_text.replace(r""", '"')
|
||||
clean_text = clean_text.replace(r"&nbsp;", " ")
|
||||
clean_text = clean_text.replace(r"&", "&")
|
||||
|
||||
# remove multiple spaces
|
||||
while " " in clean_text:
|
||||
clean_text = clean_text.replace(" ", " ")
|
||||
|
||||
return clean_text.strip()
|
||||
|
||||
|
||||
def _remove_links(clean_text, wp_to_id):
|
||||
# read the text char by char to get the right offsets for the interwiki links
|
||||
entities = []
|
||||
final_text = ""
|
||||
open_read = 0
|
||||
reading_text = True
|
||||
reading_entity = False
|
||||
reading_mention = False
|
||||
reading_special_case = False
|
||||
entity_buffer = ""
|
||||
mention_buffer = ""
|
||||
for index, letter in enumerate(clean_text):
|
||||
if letter == "[":
|
||||
open_read += 1
|
||||
elif letter == "]":
|
||||
open_read -= 1
|
||||
elif letter == "|":
|
||||
if reading_text:
|
||||
final_text += letter
|
||||
# switch from reading entity to mention in the [[entity|mention]] pattern
|
||||
elif reading_entity:
|
||||
reading_text = False
|
||||
reading_entity = False
|
||||
reading_mention = True
|
||||
else:
|
||||
reading_special_case = True
|
||||
else:
|
||||
if reading_entity:
|
||||
entity_buffer += letter
|
||||
elif reading_mention:
|
||||
mention_buffer += letter
|
||||
elif reading_text:
|
||||
final_text += letter
|
||||
else:
|
||||
raise ValueError("Not sure at point", clean_text[index - 2: index + 2])
|
||||
|
||||
if open_read > 2:
|
||||
reading_special_case = True
|
||||
|
||||
if open_read == 2 and reading_text:
|
||||
reading_text = False
|
||||
reading_entity = True
|
||||
reading_mention = False
|
||||
|
||||
# we just finished reading an entity
|
||||
if open_read == 0 and not reading_text:
|
||||
if "#" in entity_buffer or entity_buffer.startswith(":"):
|
||||
reading_special_case = True
|
||||
# Ignore cases with nested structures like File: handles etc
|
||||
if not reading_special_case:
|
||||
if not mention_buffer:
|
||||
mention_buffer = entity_buffer
|
||||
start = len(final_text)
|
||||
end = start + len(mention_buffer)
|
||||
qid = wp_to_id.get(entity_buffer, None)
|
||||
if qid:
|
||||
entities.append((mention_buffer, qid, start, end))
|
||||
final_text += mention_buffer
|
||||
|
||||
entity_buffer = ""
|
||||
mention_buffer = ""
|
||||
|
||||
reading_text = True
|
||||
reading_entity = False
|
||||
reading_mention = False
|
||||
reading_special_case = False
|
||||
return final_text, entities
|
||||
|
||||
|
||||
def _write_training_description(outputfile, qid, description):
|
||||
if description is not None:
|
||||
line = str(qid) + "|" + description + "\n"
|
||||
outputfile.write(line)
|
||||
|
||||
|
||||
def _write_training_entities(outputfile, article_id, clean_text, entities):
|
||||
entities_data = [{"alias": ent[0], "entity": ent[1], "start": ent[2], "end": ent[3]} for ent in entities]
|
||||
line = json.dumps(
|
||||
{
|
||||
"article_id": article_id,
|
||||
"clean_text": clean_text,
|
||||
"entities": entities_data
|
||||
},
|
||||
ensure_ascii=False) + "\n"
|
||||
outputfile.write(line)
|
||||
|
||||
|
||||
def read_training(nlp, entity_file_path, dev, limit, kb):
|
||||
""" This method provides training examples that correspond to the entity annotations found by the nlp object.
|
||||
For training, it will include negative training examples by using the candidate generator,
|
||||
and it will only keep positive training examples that can be found by using the candidate generator.
|
||||
For testing, it will include all positive examples only."""
|
||||
|
||||
from tqdm import tqdm
|
||||
data = []
|
||||
num_entities = 0
|
||||
get_gold_parse = partial(_get_gold_parse, dev=dev, kb=kb)
|
||||
|
||||
logger.info("Reading {} data with limit {}".format('dev' if dev else 'train', limit))
|
||||
with entity_file_path.open("r", encoding="utf8") as file:
|
||||
with tqdm(total=limit, leave=False) as pbar:
|
||||
for i, line in enumerate(file):
|
||||
example = json.loads(line)
|
||||
article_id = example["article_id"]
|
||||
clean_text = example["clean_text"]
|
||||
entities = example["entities"]
|
||||
|
||||
if dev != is_dev(article_id) or len(clean_text) >= 30000:
|
||||
continue
|
||||
|
||||
doc = nlp(clean_text)
|
||||
gold = get_gold_parse(doc, entities)
|
||||
if gold and len(gold.links) > 0:
|
||||
data.append((doc, gold))
|
||||
num_entities += len(gold.links)
|
||||
pbar.update(len(gold.links))
|
||||
if limit and num_entities >= limit:
|
||||
break
|
||||
logger.info("Read {} entities in {} articles".format(num_entities, len(data)))
|
||||
return data
|
||||
|
||||
|
||||
def _get_gold_parse(doc, entities, dev, kb):
|
||||
gold_entities = {}
|
||||
tagged_ent_positions = set(
|
||||
[(ent.start_char, ent.end_char) for ent in doc.ents]
|
||||
)
|
||||
|
||||
for entity in entities:
|
||||
entity_id = entity["entity"]
|
||||
alias = entity["alias"]
|
||||
start = entity["start"]
|
||||
end = entity["end"]
|
||||
|
||||
candidates = kb.get_candidates(alias)
|
||||
candidate_ids = [
|
||||
c.entity_ for c in candidates
|
||||
]
|
||||
|
||||
should_add_ent = (
|
||||
dev or
|
||||
(
|
||||
(start, end) in tagged_ent_positions and
|
||||
entity_id in candidate_ids and
|
||||
len(candidates) > 1
|
||||
)
|
||||
)
|
||||
|
||||
if should_add_ent:
|
||||
value_by_id = {entity_id: 1.0}
|
||||
if not dev:
|
||||
random.shuffle(candidate_ids)
|
||||
value_by_id.update({
|
||||
kb_id: 0.0
|
||||
for kb_id in candidate_ids
|
||||
if kb_id != entity_id
|
||||
})
|
||||
gold_entities[(start, end)] = value_by_id
|
||||
|
||||
return GoldParse(doc, links=gold_entities)
|
||||
|
||||
|
||||
def is_dev(article_id):
|
||||
return article_id.endswith("3")
|
127
bin/wiki_entity_linking/wiki_io.py
Normal file
|
@ -0,0 +1,127 @@
|
|||
# coding: utf-8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import sys
|
||||
import csv
|
||||
|
||||
# min() needed to prevent error on windows, cf https://stackoverflow.com/questions/52404416/
|
||||
csv.field_size_limit(min(sys.maxsize, 2147483646))
|
||||
|
||||
""" This class provides reading/writing methods for temp files """
|
||||
|
||||
|
||||
# Entity definition: WP title -> WD ID #
|
||||
def write_title_to_id(entity_def_output, title_to_id):
|
||||
with entity_def_output.open("w", encoding="utf8") as id_file:
|
||||
id_file.write("WP_title" + "|" + "WD_id" + "\n")
|
||||
for title, qid in title_to_id.items():
|
||||
id_file.write(title + "|" + str(qid) + "\n")
|
||||
|
||||
|
||||
def read_title_to_id(entity_def_output):
|
||||
title_to_id = dict()
|
||||
with entity_def_output.open("r", encoding="utf8") as id_file:
|
||||
csvreader = csv.reader(id_file, delimiter="|")
|
||||
# skip header
|
||||
next(csvreader)
|
||||
for row in csvreader:
|
||||
title_to_id[row[0]] = row[1]
|
||||
return title_to_id
|
||||
|
||||
|
||||
# Entity aliases from WD: WD ID -> WD alias #
|
||||
def write_id_to_alias(entity_alias_path, id_to_alias):
|
||||
with entity_alias_path.open("w", encoding="utf8") as alias_file:
|
||||
alias_file.write("WD_id" + "|" + "alias" + "\n")
|
||||
for qid, alias_list in id_to_alias.items():
|
||||
for alias in alias_list:
|
||||
alias_file.write(str(qid) + "|" + alias + "\n")
|
||||
|
||||
|
||||
def read_id_to_alias(entity_alias_path):
|
||||
id_to_alias = dict()
|
||||
with entity_alias_path.open("r", encoding="utf8") as alias_file:
|
||||
csvreader = csv.reader(alias_file, delimiter="|")
|
||||
# skip header
|
||||
next(csvreader)
|
||||
for row in csvreader:
|
||||
qid = row[0]
|
||||
alias = row[1]
|
||||
alias_list = id_to_alias.get(qid, [])
|
||||
alias_list.append(alias)
|
||||
id_to_alias[qid] = alias_list
|
||||
return id_to_alias
|
||||
|
||||
|
||||
def read_alias_to_id_generator(entity_alias_path):
|
||||
""" Read (aliases, qid) tuples """
|
||||
|
||||
with entity_alias_path.open("r", encoding="utf8") as alias_file:
|
||||
csvreader = csv.reader(alias_file, delimiter="|")
|
||||
# skip header
|
||||
next(csvreader)
|
||||
for row in csvreader:
|
||||
qid = row[0]
|
||||
alias = row[1]
|
||||
yield alias, qid
|
||||
|
||||
|
||||
# Entity descriptions from WD: WD ID -> WD alias #
|
||||
def write_id_to_descr(entity_descr_output, id_to_descr):
|
||||
with entity_descr_output.open("w", encoding="utf8") as descr_file:
|
||||
descr_file.write("WD_id" + "|" + "description" + "\n")
|
||||
for qid, descr in id_to_descr.items():
|
||||
descr_file.write(str(qid) + "|" + descr + "\n")
|
||||
|
||||
|
||||
def read_id_to_descr(entity_desc_path):
|
||||
id_to_desc = dict()
|
||||
with entity_desc_path.open("r", encoding="utf8") as descr_file:
|
||||
csvreader = csv.reader(descr_file, delimiter="|")
|
||||
# skip header
|
||||
next(csvreader)
|
||||
for row in csvreader:
|
||||
id_to_desc[row[0]] = row[1]
|
||||
return id_to_desc
|
||||
|
||||
|
||||
# Entity counts from WP: WP title -> count #
|
||||
def write_entity_to_count(prior_prob_input, count_output):
|
||||
# Write entity counts for quick access later
|
||||
entity_to_count = dict()
|
||||
total_count = 0
|
||||
|
||||
with prior_prob_input.open("r", encoding="utf8") as prior_file:
|
||||
# skip header
|
||||
prior_file.readline()
|
||||
line = prior_file.readline()
|
||||
|
||||
while line:
|
||||
splits = line.replace("\n", "").split(sep="|")
|
||||
# alias = splits[0]
|
||||
count = int(splits[1])
|
||||
entity = splits[2]
|
||||
|
||||
current_count = entity_to_count.get(entity, 0)
|
||||
entity_to_count[entity] = current_count + count
|
||||
|
||||
total_count += count
|
||||
|
||||
line = prior_file.readline()
|
||||
|
||||
with count_output.open("w", encoding="utf8") as entity_file:
|
||||
entity_file.write("entity" + "|" + "count" + "\n")
|
||||
for entity, count in entity_to_count.items():
|
||||
entity_file.write(entity + "|" + str(count) + "\n")
|
||||
|
||||
|
||||
def read_entity_to_count(count_input):
|
||||
entity_to_count = dict()
|
||||
with count_input.open("r", encoding="utf8") as csvfile:
|
||||
csvreader = csv.reader(csvfile, delimiter="|")
|
||||
# skip header
|
||||
next(csvreader)
|
||||
for row in csvreader:
|
||||
entity_to_count[row[0]] = int(row[1])
|
||||
|
||||
return entity_to_count
|
128
bin/wiki_entity_linking/wiki_namespaces.py
Normal file
|
@ -0,0 +1,128 @@
|
|||
# coding: utf8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
# List of meta items in Wikidata that should be kept out of the knowledge base
|
||||
WD_META_ITEMS = [
|
||||
"Q163875",
|
||||
"Q191780",
|
||||
"Q224414",
|
||||
"Q4167836",
|
||||
"Q4167410",
|
||||
"Q4663903",
|
||||
"Q11266439",
|
||||
"Q13406463",
|
||||
"Q15407973",
|
||||
"Q18616576",
|
||||
"Q19887878",
|
||||
"Q22808320",
|
||||
"Q23894233",
|
||||
"Q33120876",
|
||||
"Q42104522",
|
||||
"Q47460393",
|
||||
"Q64875536",
|
||||
"Q66480449",
|
||||
]
|
||||
|
||||
|
||||
# TODO: add more cases from non-English WP's
|
||||
|
||||
# List of prefixes that refer to Wikipedia "file" pages
|
||||
WP_FILE_NAMESPACE = ["Bestand", "File"]
|
||||
|
||||
# List of prefixes that refer to Wikipedia "category" pages
|
||||
WP_CATEGORY_NAMESPACE = ["Kategori", "Category", "Categorie"]
|
||||
|
||||
# List of prefixes that refer to Wikipedia "meta" pages
|
||||
# these will/should be matched ignoring case
|
||||
WP_META_NAMESPACE = (
|
||||
WP_FILE_NAMESPACE
|
||||
+ WP_CATEGORY_NAMESPACE
|
||||
+ [
|
||||
"b",
|
||||
"betawikiversity",
|
||||
"Book",
|
||||
"c",
|
||||
"Commons",
|
||||
"d",
|
||||
"dbdump",
|
||||
"download",
|
||||
"Draft",
|
||||
"Education",
|
||||
"Foundation",
|
||||
"Gadget",
|
||||
"Gadget definition",
|
||||
"Gebruiker",
|
||||
"gerrit",
|
||||
"Help",
|
||||
"Image",
|
||||
"Incubator",
|
||||
"m",
|
||||
"mail",
|
||||
"mailarchive",
|
||||
"media",
|
||||
"MediaWiki",
|
||||
"MediaWiki talk",
|
||||
"Mediawikiwiki",
|
||||
"MediaZilla",
|
||||
"Meta",
|
||||
"Metawikipedia",
|
||||
"Module",
|
||||
"mw",
|
||||
"n",
|
||||
"nost",
|
||||
"oldwikisource",
|
||||
"otrs",
|
||||
"OTRSwiki",
|
||||
"Overleg gebruiker",
|
||||
"outreach",
|
||||
"outreachwiki",
|
||||
"Portal",
|
||||
"phab",
|
||||
"Phabricator",
|
||||
"Project",
|
||||
"q",
|
||||
"quality",
|
||||
"rev",
|
||||
"s",
|
||||
"spcom",
|
||||
"Special",
|
||||
"species",
|
||||
"Strategy",
|
||||
"sulutil",
|
||||
"svn",
|
||||
"Talk",
|
||||
"Template",
|
||||
"Template talk",
|
||||
"Testwiki",
|
||||
"ticket",
|
||||
"TimedText",
|
||||
"Toollabs",
|
||||
"tools",
|
||||
"tswiki",
|
||||
"User",
|
||||
"User talk",
|
||||
"v",
|
||||
"voy",
|
||||
"w",
|
||||
"Wikibooks",
|
||||
"Wikidata",
|
||||
"wikiHow",
|
||||
"Wikinvest",
|
||||
"wikilivres",
|
||||
"Wikimedia",
|
||||
"Wikinews",
|
||||
"Wikipedia",
|
||||
"Wikipedia talk",
|
||||
"Wikiquote",
|
||||
"Wikisource",
|
||||
"Wikispecies",
|
||||
"Wikitech",
|
||||
"Wikiversity",
|
||||
"Wikivoyage",
|
||||
"wikt",
|
||||
"wiktionary",
|
||||
"wmf",
|
||||
"wmania",
|
||||
"WP",
|
||||
]
|
||||
)
|
|
@ -18,11 +18,12 @@ from pathlib import Path
|
|||
import plac
|
||||
|
||||
from bin.wiki_entity_linking import wikipedia_processor as wp, wikidata_processor as wd
|
||||
from bin.wiki_entity_linking import wiki_io as io
|
||||
from bin.wiki_entity_linking import kb_creator
|
||||
from bin.wiki_entity_linking import training_set_creator
|
||||
from bin.wiki_entity_linking import TRAINING_DATA_FILE, KB_FILE, ENTITY_DESCR_PATH, KB_MODEL_DIR, LOG_FORMAT
|
||||
from bin.wiki_entity_linking import ENTITY_FREQ_PATH, PRIOR_PROB_PATH, ENTITY_DEFS_PATH
|
||||
from bin.wiki_entity_linking import ENTITY_FREQ_PATH, PRIOR_PROB_PATH, ENTITY_DEFS_PATH, ENTITY_ALIAS_PATH
|
||||
import spacy
|
||||
from bin.wiki_entity_linking.kb_creator import read_kb
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -39,9 +40,11 @@ logger = logging.getLogger(__name__)
|
|||
loc_prior_prob=("Location to file with prior probabilities", "option", "p", Path),
|
||||
loc_entity_defs=("Location to file with entity definitions", "option", "d", Path),
|
||||
loc_entity_desc=("Location to file with entity descriptions", "option", "s", Path),
|
||||
descriptions_from_wikipedia=("Flag for using wp descriptions not wd", "flag", "wp"),
|
||||
limit=("Optional threshold to limit lines read from dumps", "option", "l", int),
|
||||
lang=("Optional language for which to get wikidata titles. Defaults to 'en'", "option", "la", str),
|
||||
descr_from_wp=("Flag for using wp descriptions not wd", "flag", "wp"),
|
||||
limit_prior=("Threshold to limit lines read from WP for prior probabilities", "option", "lp", int),
|
||||
limit_train=("Threshold to limit lines read from WP for training set", "option", "lt", int),
|
||||
limit_wd=("Threshold to limit lines read from WD", "option", "lw", int),
|
||||
lang=("Optional language for which to get Wikidata titles. Defaults to 'en'", "option", "la", str),
|
||||
)
|
||||
def main(
|
||||
wd_json,
|
||||
|
@ -54,13 +57,16 @@ def main(
|
|||
entity_vector_length=64,
|
||||
loc_prior_prob=None,
|
||||
loc_entity_defs=None,
|
||||
loc_entity_alias=None,
|
||||
loc_entity_desc=None,
|
||||
descriptions_from_wikipedia=False,
|
||||
limit=None,
|
||||
descr_from_wp=False,
|
||||
limit_prior=None,
|
||||
limit_train=None,
|
||||
limit_wd=None,
|
||||
lang="en",
|
||||
):
|
||||
|
||||
entity_defs_path = loc_entity_defs if loc_entity_defs else output_dir / ENTITY_DEFS_PATH
|
||||
entity_alias_path = loc_entity_alias if loc_entity_alias else output_dir / ENTITY_ALIAS_PATH
|
||||
entity_descr_path = loc_entity_desc if loc_entity_desc else output_dir / ENTITY_DESCR_PATH
|
||||
entity_freq_path = output_dir / ENTITY_FREQ_PATH
|
||||
prior_prob_path = loc_prior_prob if loc_prior_prob else output_dir / PRIOR_PROB_PATH
|
||||
|
@ -69,15 +75,12 @@ def main(
|
|||
|
||||
logger.info("Creating KB with Wikipedia and WikiData")
|
||||
|
||||
if limit is not None:
|
||||
logger.warning("Warning: reading only {} lines of Wikipedia/Wikidata dumps.".format(limit))
|
||||
|
||||
# STEP 0: set up IO
|
||||
if not output_dir.exists():
|
||||
output_dir.mkdir(parents=True)
|
||||
|
||||
# STEP 1: create the NLP object
|
||||
logger.info("STEP 1: Loading model {}".format(model))
|
||||
# STEP 1: Load the NLP object
|
||||
logger.info("STEP 1: Loading NLP model {}".format(model))
|
||||
nlp = spacy.load(model)
|
||||
|
||||
# check the length of the nlp vectors
|
||||
|
@ -90,62 +93,83 @@ def main(
|
|||
# STEP 2: create prior probabilities from WP
|
||||
if not prior_prob_path.exists():
|
||||
# It takes about 2h to process 1000M lines of Wikipedia XML dump
|
||||
logger.info("STEP 2: writing prior probabilities to {}".format(prior_prob_path))
|
||||
wp.read_prior_probs(wp_xml, prior_prob_path, limit=limit)
|
||||
logger.info("STEP 2: reading prior probabilities from {}".format(prior_prob_path))
|
||||
logger.info("STEP 2: Writing prior probabilities to {}".format(prior_prob_path))
|
||||
if limit_prior is not None:
|
||||
logger.warning("Warning: reading only {} lines of Wikipedia dump".format(limit_prior))
|
||||
wp.read_prior_probs(wp_xml, prior_prob_path, limit=limit_prior)
|
||||
else:
|
||||
logger.info("STEP 2: Reading prior probabilities from {}".format(prior_prob_path))
|
||||
|
||||
# STEP 3: deduce entity frequencies from WP (takes only a few minutes)
|
||||
logger.info("STEP 3: calculating entity frequencies")
|
||||
wp.write_entity_counts(prior_prob_path, entity_freq_path, to_print=False)
|
||||
# STEP 3: calculate entity frequencies
|
||||
if not entity_freq_path.exists():
|
||||
logger.info("STEP 3: Calculating and writing entity frequencies to {}".format(entity_freq_path))
|
||||
io.write_entity_to_count(prior_prob_path, entity_freq_path)
|
||||
else:
|
||||
logger.info("STEP 3: Reading entity frequencies from {}".format(entity_freq_path))
|
||||
|
||||
# STEP 4: reading definitions and (possibly) descriptions from WikiData or from file
|
||||
message = " and descriptions" if not descriptions_from_wikipedia else ""
|
||||
if (not entity_defs_path.exists()) or (not descriptions_from_wikipedia and not entity_descr_path.exists()):
|
||||
if (not entity_defs_path.exists()) or (not descr_from_wp and not entity_descr_path.exists()):
|
||||
# It takes about 10h to process 55M lines of Wikidata JSON dump
|
||||
logger.info("STEP 4: parsing wikidata for entity definitions" + message)
|
||||
title_to_id, id_to_descr = wd.read_wikidata_entities_json(
|
||||
logger.info("STEP 4: Parsing and writing Wikidata entity definitions to {}".format(entity_defs_path))
|
||||
if limit_wd is not None:
|
||||
logger.warning("Warning: reading only {} lines of Wikidata dump".format(limit_wd))
|
||||
title_to_id, id_to_descr, id_to_alias = wd.read_wikidata_entities_json(
|
||||
wd_json,
|
||||
limit,
|
||||
limit_wd,
|
||||
to_print=False,
|
||||
lang=lang,
|
||||
parse_descriptions=(not descriptions_from_wikipedia),
|
||||
parse_descr=(not descr_from_wp),
|
||||
)
|
||||
wd.write_entity_files(entity_defs_path, title_to_id)
|
||||
if not descriptions_from_wikipedia:
|
||||
wd.write_entity_description_files(entity_descr_path, id_to_descr)
|
||||
logger.info("STEP 4: read entity definitions" + message)
|
||||
io.write_title_to_id(entity_defs_path, title_to_id)
|
||||
|
||||
# STEP 5: Getting gold entities from wikipedia
|
||||
message = " and descriptions" if descriptions_from_wikipedia else ""
|
||||
if (not training_entities_path.exists()) or (descriptions_from_wikipedia and not entity_descr_path.exists()):
|
||||
logger.info("STEP 5: parsing wikipedia for gold entities" + message)
|
||||
training_set_creator.create_training_examples_and_descriptions(
|
||||
wp_xml,
|
||||
entity_defs_path,
|
||||
entity_descr_path,
|
||||
training_entities_path,
|
||||
parse_descriptions=descriptions_from_wikipedia,
|
||||
limit=limit,
|
||||
)
|
||||
logger.info("STEP 5: read gold entities" + message)
|
||||
logger.info("STEP 4b: Writing Wikidata entity aliases to {}".format(entity_alias_path))
|
||||
io.write_id_to_alias(entity_alias_path, id_to_alias)
|
||||
|
||||
if not descr_from_wp:
|
||||
logger.info("STEP 4c: Writing Wikidata entity descriptions to {}".format(entity_descr_path))
|
||||
io.write_id_to_descr(entity_descr_path, id_to_descr)
|
||||
else:
|
||||
logger.info("STEP 4: Reading entity definitions from {}".format(entity_defs_path))
|
||||
logger.info("STEP 4b: Reading entity aliases from {}".format(entity_alias_path))
|
||||
if not descr_from_wp:
|
||||
logger.info("STEP 4c: Reading entity descriptions from {}".format(entity_descr_path))
|
||||
|
||||
# STEP 5: Getting gold entities from Wikipedia
|
||||
if (not training_entities_path.exists()) or (descr_from_wp and not entity_descr_path.exists()):
|
||||
logger.info("STEP 5: Parsing and writing Wikipedia gold entities to {}".format(training_entities_path))
|
||||
if limit_train is not None:
|
||||
logger.warning("Warning: reading only {} lines of Wikipedia dump".format(limit_train))
|
||||
wp.create_training_and_desc(wp_xml, entity_defs_path, entity_descr_path,
|
||||
training_entities_path, descr_from_wp, limit_train)
|
||||
if descr_from_wp:
|
||||
logger.info("STEP 5b: Parsing and writing Wikipedia descriptions to {}".format(entity_descr_path))
|
||||
else:
|
||||
logger.info("STEP 5: Reading gold entities from {}".format(training_entities_path))
|
||||
if descr_from_wp:
|
||||
logger.info("STEP 5b: Reading entity descriptions from {}".format(entity_descr_path))
|
||||
|
||||
# STEP 6: creating the actual KB
|
||||
# It takes ca. 30 minutes to pretrain the entity embeddings
|
||||
logger.info("STEP 6: creating the KB at {}".format(kb_path))
|
||||
kb = kb_creator.create_kb(
|
||||
nlp=nlp,
|
||||
max_entities_per_alias=max_per_alias,
|
||||
min_entity_freq=min_freq,
|
||||
min_occ=min_pair,
|
||||
entity_def_input=entity_defs_path,
|
||||
entity_descr_path=entity_descr_path,
|
||||
count_input=entity_freq_path,
|
||||
prior_prob_input=prior_prob_path,
|
||||
entity_vector_length=entity_vector_length,
|
||||
)
|
||||
|
||||
kb.dump(kb_path)
|
||||
nlp.to_disk(output_dir / KB_MODEL_DIR)
|
||||
if not kb_path.exists():
|
||||
logger.info("STEP 6: Creating the KB at {}".format(kb_path))
|
||||
kb = kb_creator.create_kb(
|
||||
nlp=nlp,
|
||||
max_entities_per_alias=max_per_alias,
|
||||
min_entity_freq=min_freq,
|
||||
min_occ=min_pair,
|
||||
entity_def_path=entity_defs_path,
|
||||
entity_descr_path=entity_descr_path,
|
||||
entity_alias_path=entity_alias_path,
|
||||
entity_freq_path=entity_freq_path,
|
||||
prior_prob_path=prior_prob_path,
|
||||
entity_vector_length=entity_vector_length,
|
||||
)
|
||||
kb.dump(kb_path)
|
||||
logger.info("kb entities: {}".format(kb.get_size_entities()))
|
||||
logger.info("kb aliases: {}".format(kb.get_size_aliases()))
|
||||
nlp.to_disk(output_dir / KB_MODEL_DIR)
|
||||
else:
|
||||
logger.info("STEP 6: KB already exists at {}".format(kb_path))
|
||||
|
||||
logger.info("Done!")
|
||||
|
||||
|
|
|
@ -1,40 +1,52 @@
|
|||
# coding: utf-8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import gzip
|
||||
import bz2
|
||||
import json
|
||||
import logging
|
||||
import datetime
|
||||
|
||||
from bin.wiki_entity_linking.wiki_namespaces import WD_META_ITEMS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False, lang="en", parse_descriptions=True):
|
||||
# Read the JSON wiki data and parse out the entities. Takes about 7u30 to parse 55M lines.
|
||||
def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False, lang="en", parse_descr=True):
|
||||
# Read the JSON wiki data and parse out the entities. Takes about 7-10h to parse 55M lines.
|
||||
# get latest-all.json.bz2 from https://dumps.wikimedia.org/wikidatawiki/entities/
|
||||
|
||||
site_filter = '{}wiki'.format(lang)
|
||||
|
||||
# properties filter (currently disabled to get ALL data)
|
||||
prop_filter = dict()
|
||||
# prop_filter = {'P31': {'Q5', 'Q15632617'}} # currently defined as OR: one property suffices to be selected
|
||||
# filter: currently defined as OR: one hit suffices to be removed from further processing
|
||||
exclude_list = WD_META_ITEMS
|
||||
|
||||
# punctuation
|
||||
exclude_list.extend(["Q1383557", "Q10617810"])
|
||||
|
||||
# letters etc
|
||||
exclude_list.extend(["Q188725", "Q19776628", "Q3841820", "Q17907810", "Q9788", "Q9398093"])
|
||||
|
||||
neg_prop_filter = {
|
||||
'P31': exclude_list, # instance of
|
||||
'P279': exclude_list # subclass
|
||||
}
|
||||
|
||||
title_to_id = dict()
|
||||
id_to_descr = dict()
|
||||
id_to_alias = dict()
|
||||
|
||||
# parse appropriate fields - depending on what we need in the KB
|
||||
parse_properties = False
|
||||
parse_sitelinks = True
|
||||
parse_labels = False
|
||||
parse_aliases = False
|
||||
parse_claims = False
|
||||
parse_aliases = True
|
||||
parse_claims = True
|
||||
|
||||
with gzip.open(wikidata_file, mode='rb') as file:
|
||||
with bz2.open(wikidata_file, mode='rb') as file:
|
||||
for cnt, line in enumerate(file):
|
||||
if limit and cnt >= limit:
|
||||
break
|
||||
if cnt % 500000 == 0:
|
||||
logger.info("processed {} lines of WikiData dump".format(cnt))
|
||||
if cnt % 500000 == 0 and cnt > 0:
|
||||
logger.info("processed {} lines of WikiData JSON dump".format(cnt))
|
||||
clean_line = line.strip()
|
||||
if clean_line.endswith(b","):
|
||||
clean_line = clean_line[:-1]
|
||||
|
@@ -43,13 +55,11 @@ def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False, lang=
|
|||
entry_type = obj["type"]
|
||||
|
||||
if entry_type == "item":
|
||||
# filtering records on their properties (currently disabled to get ALL data)
|
||||
# keep = False
|
||||
keep = True
|
||||
|
||||
claims = obj["claims"]
|
||||
if parse_claims:
|
||||
for prop, value_set in prop_filter.items():
|
||||
for prop, value_set in neg_prop_filter.items():
|
||||
claim_property = claims.get(prop, None)
|
||||
if claim_property:
|
||||
for cp in claim_property:
|
||||
|
@@ -61,7 +71,7 @@ def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False, lang=
|
|||
)
|
||||
cp_rank = cp["rank"]
|
||||
if cp_rank != "deprecated" and cp_id in value_set:
|
||||
keep = True
|
||||
keep = False
|
||||
|
||||
if keep:
|
||||
unique_id = obj["id"]
|
||||
|
@@ -108,7 +118,7 @@ def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False, lang=
|
|||
"label (" + lang + "):", lang_label["value"]
|
||||
)
|
||||
|
||||
if found_link and parse_descriptions:
|
||||
if found_link and parse_descr:
|
||||
descriptions = obj["descriptions"]
|
||||
if descriptions:
|
||||
lang_descr = descriptions.get(lang, None)
|
||||
|
@@ -130,22 +140,15 @@ def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False, lang=
|
|||
print(
|
||||
"alias (" + lang + "):", item["value"]
|
||||
)
|
||||
alias_list = id_to_alias.get(unique_id, [])
|
||||
alias_list.append(item["value"])
|
||||
id_to_alias[unique_id] = alias_list
|
||||
|
||||
if to_print:
|
||||
print()
|
||||
|
||||
return title_to_id, id_to_descr
|
||||
# log final number of lines processed
|
||||
logger.info("Finished. Processed {} lines of WikiData JSON dump".format(cnt))
|
||||
return title_to_id, id_to_descr, id_to_alias
|
||||
|
||||
|
||||
def write_entity_files(entity_def_output, title_to_id):
|
||||
with entity_def_output.open("w", encoding="utf8") as id_file:
|
||||
id_file.write("WP_title" + "|" + "WD_id" + "\n")
|
||||
for title, qid in title_to_id.items():
|
||||
id_file.write(title + "|" + str(qid) + "\n")
|
||||
|
||||
|
||||
def write_entity_description_files(entity_descr_output, id_to_descr):
|
||||
with entity_descr_output.open("w", encoding="utf8") as descr_file:
|
||||
descr_file.write("WD_id" + "|" + "description" + "\n")
|
||||
for qid, descr in id_to_descr.items():
|
||||
descr_file.write(str(qid) + "|" + descr + "\n")
|
||||
|
|
|
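Note (not part of the diff): write_entity_files() and write_entity_description_files() above emit plain pipe-delimited files with "WP_title|WD_id" and "WD_id|description" headers. The repo's wiki_io helpers read these back in; the function below is only a sketch of the same idea for the definitions file.

import csv
from pathlib import Path

def read_title_to_id(entity_def_path):
    # inverse of write_entity_files(): one "title|QID" pair per line after the header
    title_to_id = {}
    with Path(entity_def_path).open("r", encoding="utf8") as csvfile:
        reader = csv.reader(csvfile, delimiter="|")
        next(reader)  # skip the "WP_title|WD_id" header row
        for row in reader:
            title_to_id[row[0]] = row[1]
    return title_to_id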
@@ -6,19 +6,19 @@ as created by the script `wikidata_create_kb`.
|
|||
|
||||
For the Wikipedia dump: get enwiki-latest-pages-articles-multistream.xml.bz2
|
||||
from https://dumps.wikimedia.org/enwiki/latest/
|
||||
|
||||
"""
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import random
|
||||
import logging
|
||||
import spacy
|
||||
from pathlib import Path
|
||||
import plac
|
||||
|
||||
from bin.wiki_entity_linking import training_set_creator
|
||||
from bin.wiki_entity_linking import wikipedia_processor
|
||||
from bin.wiki_entity_linking import TRAINING_DATA_FILE, KB_MODEL_DIR, KB_FILE, LOG_FORMAT, OUTPUT_MODEL_DIR
|
||||
from bin.wiki_entity_linking.entity_linker_evaluation import measure_performance, measure_baselines
|
||||
from bin.wiki_entity_linking.kb_creator import read_nlp_kb
|
||||
from bin.wiki_entity_linking.entity_linker_evaluation import measure_performance
|
||||
from bin.wiki_entity_linking.kb_creator import read_kb
|
||||
|
||||
from spacy.util import minibatch, compounding
|
||||
|
||||
|
@@ -35,6 +35,7 @@ logger = logging.getLogger(__name__)
|
|||
l2=("L2 regularization", "option", "r", float),
|
||||
train_inst=("# training instances (default 90% of all)", "option", "t", int),
|
||||
dev_inst=("# test instances (default 10% of all)", "option", "d", int),
|
||||
labels_discard=("NER labels to discard (default None)", "option", "l", str),
|
||||
)
|
||||
def main(
|
||||
dir_kb,
|
||||
|
@@ -46,13 +47,14 @@ def main(
|
|||
l2=1e-6,
|
||||
train_inst=None,
|
||||
dev_inst=None,
|
||||
labels_discard=None
|
||||
):
|
||||
logger.info("Creating Entity Linker with Wikipedia and WikiData")
|
||||
|
||||
output_dir = Path(output_dir) if output_dir else dir_kb
|
||||
training_path = loc_training if loc_training else output_dir / TRAINING_DATA_FILE
|
||||
training_path = loc_training if loc_training else dir_kb / TRAINING_DATA_FILE
|
||||
nlp_dir = dir_kb / KB_MODEL_DIR
|
||||
kb_path = output_dir / KB_FILE
|
||||
kb_path = dir_kb / KB_FILE
|
||||
nlp_output_dir = output_dir / OUTPUT_MODEL_DIR
|
||||
|
||||
# STEP 0: set up IO
|
||||
|
@@ -60,38 +62,47 @@ def main(
|
|||
output_dir.mkdir()
|
||||
|
||||
# STEP 1 : load the NLP object
|
||||
logger.info("STEP 1: loading model from {}".format(nlp_dir))
|
||||
nlp, kb = read_nlp_kb(nlp_dir, kb_path)
|
||||
logger.info("STEP 1a: Loading model from {}".format(nlp_dir))
|
||||
nlp = spacy.load(nlp_dir)
|
||||
logger.info("STEP 1b: Loading KB from {}".format(kb_path))
|
||||
kb = read_kb(nlp, kb_path)
|
||||
|
||||
# check that there is a NER component in the pipeline
|
||||
if "ner" not in nlp.pipe_names:
|
||||
raise ValueError("The `nlp` object should have a pretrained `ner` component.")
|
||||
|
||||
# STEP 2: create a training dataset from WP
|
||||
logger.info("STEP 2: reading training dataset from {}".format(training_path))
|
||||
# STEP 2: read the training dataset previously created from WP
|
||||
logger.info("STEP 2: Reading training dataset from {}".format(training_path))
|
||||
|
||||
train_data = training_set_creator.read_training(
|
||||
if labels_discard:
|
||||
labels_discard = [x.strip() for x in labels_discard.split(",")]
|
||||
logger.info("Discarding {} NER types: {}".format(len(labels_discard), labels_discard))
|
||||
|
||||
train_data = wikipedia_processor.read_training(
|
||||
nlp=nlp,
|
||||
entity_file_path=training_path,
|
||||
dev=False,
|
||||
limit=train_inst,
|
||||
kb=kb,
|
||||
labels_discard=labels_discard
|
||||
)
|
||||
|
||||
# for testing, get all pos instances, whether or not they are in the kb
|
||||
dev_data = training_set_creator.read_training(
|
||||
# for testing, get all pos instances (independently of KB)
|
||||
dev_data = wikipedia_processor.read_training(
|
||||
nlp=nlp,
|
||||
entity_file_path=training_path,
|
||||
dev=True,
|
||||
limit=dev_inst,
|
||||
kb=kb,
|
||||
kb=None,
|
||||
labels_discard=labels_discard
|
||||
)
|
||||
|
||||
# STEP 3: create and train the entity linking pipe
|
||||
logger.info("STEP 3: training Entity Linking pipe")
|
||||
# STEP 3: create and train an entity linking pipe
|
||||
logger.info("STEP 3: Creating and training an Entity Linking pipe")
|
||||
|
||||
el_pipe = nlp.create_pipe(
|
||||
name="entity_linker", config={"pretrained_vectors": nlp.vocab.vectors.name}
|
||||
name="entity_linker", config={"pretrained_vectors": nlp.vocab.vectors.name,
|
||||
"labels_discard": labels_discard}
|
||||
)
|
||||
el_pipe.set_kb(kb)
|
||||
nlp.add_pipe(el_pipe, last=True)
|
||||
|
@@ -105,14 +116,9 @@ def main(
|
|||
logger.info("Training on {} articles".format(len(train_data)))
|
||||
logger.info("Dev testing on {} articles".format(len(dev_data)))
|
||||
|
||||
dev_baseline_accuracies = measure_baselines(
|
||||
dev_data, kb
|
||||
)
|
||||
|
||||
# baseline performance on dev data
|
||||
logger.info("Dev Baseline Accuracies:")
|
||||
logger.info(dev_baseline_accuracies.report_accuracy("random"))
|
||||
logger.info(dev_baseline_accuracies.report_accuracy("prior"))
|
||||
logger.info(dev_baseline_accuracies.report_accuracy("oracle"))
|
||||
measure_performance(dev_data, kb, el_pipe, baseline=True, context=False)
|
||||
|
||||
for itn in range(epochs):
|
||||
random.shuffle(train_data)
|
||||
|
@@ -136,18 +142,18 @@ def main(
|
|||
logger.error("Error updating batch:" + str(e))
|
||||
if batchnr > 0:
|
||||
logging.info("Epoch {}, train loss {}".format(itn, round(losses["entity_linker"] / batchnr, 2)))
|
||||
measure_performance(dev_data, kb, el_pipe)
|
||||
measure_performance(dev_data, kb, el_pipe, baseline=False, context=True)
|
||||
|
||||
# STEP 4: measure the performance of our trained pipe on an independent dev set
|
||||
logger.info("STEP 4: performance measurement of Entity Linking pipe")
|
||||
logger.info("STEP 4: Final performance measurement of Entity Linking pipe")
|
||||
measure_performance(dev_data, kb, el_pipe)
|
||||
|
||||
# STEP 5: apply the EL pipe on a toy example
|
||||
logger.info("STEP 5: applying Entity Linking to toy example")
|
||||
logger.info("STEP 5: Applying Entity Linking to toy example")
|
||||
run_el_toy_example(nlp=nlp)
|
||||
|
||||
if output_dir:
|
||||
# STEP 6: write the NLP pipeline (including entity linker) to file
|
||||
# STEP 6: write the NLP pipeline (now including an EL model) to file
|
||||
logger.info("STEP 6: Writing trained NLP to {}".format(nlp_output_dir))
|
||||
nlp.to_disk(nlp_output_dir)
|
||||
|
||||
|
|
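Note (not part of the diff): the training script now threads a labels_discard option through to the entity_linker config (STEP 2 and 3 above). A small sketch of that wiring; the model name and label values are placeholders.

import spacy

nlp = spacy.load("en_core_web_lg")      # any pipeline with a pretrained NER and word vectors
labels_discard = "PERSON,DATE"          # e.g. the value passed on the command line
labels = [x.strip() for x in labels_discard.split(",")]

el_pipe = nlp.create_pipe(
    "entity_linker",
    config={"pretrained_vectors": nlp.vocab.vectors.name, "labels_discard": labels},
)
# el_pipe.set_kb(kb) would attach a previously created KnowledgeBase here
nlp.add_pipe(el_pipe, last=True)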
|
@@ -3,147 +3,104 @@ from __future__ import unicode_literals
|
|||
|
||||
import re
|
||||
import bz2
|
||||
import csv
|
||||
import datetime
|
||||
import logging
|
||||
import random
|
||||
import json
|
||||
|
||||
from bin.wiki_entity_linking import LOG_FORMAT
|
||||
from functools import partial
|
||||
|
||||
from spacy.gold import GoldParse
|
||||
from bin.wiki_entity_linking import wiki_io as io
|
||||
from bin.wiki_entity_linking.wiki_namespaces import (
|
||||
WP_META_NAMESPACE,
|
||||
WP_FILE_NAMESPACE,
|
||||
WP_CATEGORY_NAMESPACE,
|
||||
)
|
||||
|
||||
"""
|
||||
Process a Wikipedia dump to calculate entity frequencies and prior probabilities in combination with certain mentions.
|
||||
Write these results to file for downstream KB and training data generation.
|
||||
|
||||
Process Wikipedia interlinks to generate a training dataset for the EL algorithm.
|
||||
"""
|
||||
|
||||
ENTITY_FILE = "gold_entities.csv"
|
||||
|
||||
map_alias_to_link = dict()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# these will/should be matched ignoring case
|
||||
wiki_namespaces = [
|
||||
"b",
|
||||
"betawikiversity",
|
||||
"Book",
|
||||
"c",
|
||||
"Category",
|
||||
"Commons",
|
||||
"d",
|
||||
"dbdump",
|
||||
"download",
|
||||
"Draft",
|
||||
"Education",
|
||||
"Foundation",
|
||||
"Gadget",
|
||||
"Gadget definition",
|
||||
"gerrit",
|
||||
"File",
|
||||
"Help",
|
||||
"Image",
|
||||
"Incubator",
|
||||
"m",
|
||||
"mail",
|
||||
"mailarchive",
|
||||
"media",
|
||||
"MediaWiki",
|
||||
"MediaWiki talk",
|
||||
"Mediawikiwiki",
|
||||
"MediaZilla",
|
||||
"Meta",
|
||||
"Metawikipedia",
|
||||
"Module",
|
||||
"mw",
|
||||
"n",
|
||||
"nost",
|
||||
"oldwikisource",
|
||||
"outreach",
|
||||
"outreachwiki",
|
||||
"otrs",
|
||||
"OTRSwiki",
|
||||
"Portal",
|
||||
"phab",
|
||||
"Phabricator",
|
||||
"Project",
|
||||
"q",
|
||||
"quality",
|
||||
"rev",
|
||||
"s",
|
||||
"spcom",
|
||||
"Special",
|
||||
"species",
|
||||
"Strategy",
|
||||
"sulutil",
|
||||
"svn",
|
||||
"Talk",
|
||||
"Template",
|
||||
"Template talk",
|
||||
"Testwiki",
|
||||
"ticket",
|
||||
"TimedText",
|
||||
"Toollabs",
|
||||
"tools",
|
||||
"tswiki",
|
||||
"User",
|
||||
"User talk",
|
||||
"v",
|
||||
"voy",
|
||||
"w",
|
||||
"Wikibooks",
|
||||
"Wikidata",
|
||||
"wikiHow",
|
||||
"Wikinvest",
|
||||
"wikilivres",
|
||||
"Wikimedia",
|
||||
"Wikinews",
|
||||
"Wikipedia",
|
||||
"Wikipedia talk",
|
||||
"Wikiquote",
|
||||
"Wikisource",
|
||||
"Wikispecies",
|
||||
"Wikitech",
|
||||
"Wikiversity",
|
||||
"Wikivoyage",
|
||||
"wikt",
|
||||
"wiktionary",
|
||||
"wmf",
|
||||
"wmania",
|
||||
"WP",
|
||||
]
|
||||
title_regex = re.compile(r"(?<=<title>).*(?=</title>)")
|
||||
id_regex = re.compile(r"(?<=<id>)\d*(?=</id>)")
|
||||
text_regex = re.compile(r"(?<=<text xml:space=\"preserve\">).*(?=</text)")
|
||||
info_regex = re.compile(r"{[^{]*?}")
|
||||
html_regex = re.compile(r"<!--[^-]*-->")
|
||||
ref_regex = re.compile(r"<ref.*?>") # non-greedy
|
||||
ref_2_regex = re.compile(r"</ref.*?>") # non-greedy
|
||||
|
||||
# find the links
|
||||
link_regex = re.compile(r"\[\[[^\[\]]*\]\]")
|
||||
|
||||
# match on interwiki links, e.g. `en:` or `:fr:`
|
||||
ns_regex = r":?" + "[a-z][a-z]" + ":"
|
||||
|
||||
# match on Namespace: optionally preceded by a :
|
||||
for ns in wiki_namespaces:
|
||||
for ns in WP_META_NAMESPACE:
|
||||
ns_regex += "|" + ":?" + ns + ":"
|
||||
|
||||
ns_regex = re.compile(ns_regex, re.IGNORECASE)
|
||||
|
||||
files = r""
|
||||
for f in WP_FILE_NAMESPACE:
|
||||
files += "\[\[" + f + ":[^[\]]+]]" + "|"
|
||||
files = files[0 : len(files) - 1]
|
||||
file_regex = re.compile(files)
|
||||
|
||||
cats = r""
|
||||
for c in WP_CATEGORY_NAMESPACE:
|
||||
cats += "\[\[" + c + ":[^\[]*]]" + "|"
|
||||
cats = cats[0 : len(cats) - 1]
|
||||
category_regex = re.compile(cats)
|
||||
|
||||
|
||||
def read_prior_probs(wikipedia_input, prior_prob_output, limit=None):
|
||||
"""
|
||||
Read the XML wikipedia data and parse out intra-wiki links to estimate prior probabilities.
|
||||
The full file takes about 2h to parse 1100M lines.
|
||||
It works relatively fast because it runs line by line, irrelevant of which article the intrawiki is from.
|
||||
The full file takes about 2-3h to parse 1100M lines.
|
||||
It works relatively fast because it runs line by line, irrelevant of which article the intrawiki is from,
|
||||
though dev test articles are excluded in order not to get an artificially strong baseline.
|
||||
"""
|
||||
cnt = 0
|
||||
read_id = False
|
||||
current_article_id = None
|
||||
with bz2.open(wikipedia_input, mode="rb") as file:
|
||||
line = file.readline()
|
||||
cnt = 0
|
||||
while line and (not limit or cnt < limit):
|
||||
if cnt % 25000000 == 0:
|
||||
if cnt % 25000000 == 0 and cnt > 0:
|
||||
logger.info("processed {} lines of Wikipedia XML dump".format(cnt))
|
||||
clean_line = line.strip().decode("utf-8")
|
||||
|
||||
aliases, entities, normalizations = get_wp_links(clean_line)
|
||||
for alias, entity, norm in zip(aliases, entities, normalizations):
|
||||
_store_alias(alias, entity, normalize_alias=norm, normalize_entity=True)
|
||||
_store_alias(alias, entity, normalize_alias=norm, normalize_entity=True)
|
||||
# we attempt at reading the article's ID (but not the revision or contributor ID)
|
||||
if "<revision>" in clean_line or "<contributor>" in clean_line:
|
||||
read_id = False
|
||||
if "<page>" in clean_line:
|
||||
read_id = True
|
||||
|
||||
if read_id:
|
||||
ids = id_regex.search(clean_line)
|
||||
if ids:
|
||||
current_article_id = ids[0]
|
||||
|
||||
# only processing prior probabilities from true training (non-dev) articles
|
||||
if not is_dev(current_article_id):
|
||||
aliases, entities, normalizations = get_wp_links(clean_line)
|
||||
for alias, entity, norm in zip(aliases, entities, normalizations):
|
||||
_store_alias(
|
||||
alias, entity, normalize_alias=norm, normalize_entity=True
|
||||
)
|
||||
|
||||
line = file.readline()
|
||||
cnt += 1
|
||||
logger.info("processed {} lines of Wikipedia XML dump".format(cnt))
|
||||
logger.info("Finished. processed {} lines of Wikipedia XML dump".format(cnt))
|
||||
|
||||
# write all aliases and their entities and count occurrences to file
|
||||
with prior_prob_output.open("w", encoding="utf8") as outputfile:
|
||||
|
@@ -182,7 +139,7 @@ def get_wp_links(text):
|
|||
match = match[2:][:-2].replace("_", " ").strip()
|
||||
|
||||
if ns_regex.match(match):
|
||||
pass # ignore namespaces at the beginning of the string
|
||||
pass # ignore the entity if it points to a "meta" page
|
||||
|
||||
# this is a simple [[link]], with the alias the same as the mention
|
||||
elif "|" not in match:
|
||||
|
@@ -218,47 +175,382 @@ def _capitalize_first(text):
|
|||
return result
|
||||
|
||||
|
||||
def write_entity_counts(prior_prob_input, count_output, to_print=False):
|
||||
# Write entity counts for quick access later
|
||||
entity_to_count = dict()
|
||||
total_count = 0
|
||||
|
||||
with prior_prob_input.open("r", encoding="utf8") as prior_file:
|
||||
# skip header
|
||||
prior_file.readline()
|
||||
line = prior_file.readline()
|
||||
|
||||
while line:
|
||||
splits = line.replace("\n", "").split(sep="|")
|
||||
# alias = splits[0]
|
||||
count = int(splits[1])
|
||||
entity = splits[2]
|
||||
|
||||
current_count = entity_to_count.get(entity, 0)
|
||||
entity_to_count[entity] = current_count + count
|
||||
|
||||
total_count += count
|
||||
|
||||
line = prior_file.readline()
|
||||
|
||||
with count_output.open("w", encoding="utf8") as entity_file:
|
||||
entity_file.write("entity" + "|" + "count" + "\n")
|
||||
for entity, count in entity_to_count.items():
|
||||
entity_file.write(entity + "|" + str(count) + "\n")
|
||||
|
||||
if to_print:
|
||||
for entity, count in entity_to_count.items():
|
||||
print("Entity count:", entity, count)
|
||||
print("Total count:", total_count)
|
||||
def create_training_and_desc(
|
||||
wp_input, def_input, desc_output, training_output, parse_desc, limit=None
|
||||
):
|
||||
wp_to_id = io.read_title_to_id(def_input)
|
||||
_process_wikipedia_texts(
|
||||
wp_input, wp_to_id, desc_output, training_output, parse_desc, limit
|
||||
)
|
||||
|
||||
|
||||
def get_all_frequencies(count_input):
|
||||
entity_to_count = dict()
|
||||
with count_input.open("r", encoding="utf8") as csvfile:
|
||||
csvreader = csv.reader(csvfile, delimiter="|")
|
||||
# skip header
|
||||
next(csvreader)
|
||||
for row in csvreader:
|
||||
entity_to_count[row[0]] = int(row[1])
|
||||
def _process_wikipedia_texts(
|
||||
wikipedia_input, wp_to_id, output, training_output, parse_descriptions, limit=None
|
||||
):
|
||||
"""
|
||||
Read the XML wikipedia data to parse out training data:
|
||||
raw text data + positive instances
|
||||
"""
|
||||
|
||||
return entity_to_count
|
||||
read_ids = set()
|
||||
|
||||
with output.open("a", encoding="utf8") as descr_file, training_output.open(
|
||||
"w", encoding="utf8"
|
||||
) as entity_file:
|
||||
if parse_descriptions:
|
||||
_write_training_description(descr_file, "WD_id", "description")
|
||||
with bz2.open(wikipedia_input, mode="rb") as file:
|
||||
article_count = 0
|
||||
article_text = ""
|
||||
article_title = None
|
||||
article_id = None
|
||||
reading_text = False
|
||||
reading_revision = False
|
||||
|
||||
for line in file:
|
||||
clean_line = line.strip().decode("utf-8")
|
||||
|
||||
if clean_line == "<revision>":
|
||||
reading_revision = True
|
||||
elif clean_line == "</revision>":
|
||||
reading_revision = False
|
||||
|
||||
# Start reading new page
|
||||
if clean_line == "<page>":
|
||||
article_text = ""
|
||||
article_title = None
|
||||
article_id = None
|
||||
# finished reading this page
|
||||
elif clean_line == "</page>":
|
||||
if article_id:
|
||||
clean_text, entities = _process_wp_text(
|
||||
article_title, article_text, wp_to_id
|
||||
)
|
||||
if clean_text is not None and entities is not None:
|
||||
_write_training_entities(
|
||||
entity_file, article_id, clean_text, entities
|
||||
)
|
||||
|
||||
if article_title in wp_to_id and parse_descriptions:
|
||||
description = " ".join(
|
||||
clean_text[:1000].split(" ")[:-1]
|
||||
)
|
||||
_write_training_description(
|
||||
descr_file, wp_to_id[article_title], description
|
||||
)
|
||||
article_count += 1
|
||||
if article_count % 10000 == 0 and article_count > 0:
|
||||
logger.info(
|
||||
"Processed {} articles".format(article_count)
|
||||
)
|
||||
if limit and article_count >= limit:
|
||||
break
|
||||
article_text = ""
|
||||
article_title = None
|
||||
article_id = None
|
||||
reading_text = False
|
||||
reading_revision = False
|
||||
|
||||
# start reading text within a page
|
||||
if "<text" in clean_line:
|
||||
reading_text = True
|
||||
|
||||
if reading_text:
|
||||
article_text += " " + clean_line
|
||||
|
||||
# stop reading text within a page (we assume a new page doesn't start on the same line)
|
||||
if "</text" in clean_line:
|
||||
reading_text = False
|
||||
|
||||
# read the ID of this article (outside the revision portion of the document)
|
||||
if not reading_revision:
|
||||
ids = id_regex.search(clean_line)
|
||||
if ids:
|
||||
article_id = ids[0]
|
||||
if article_id in read_ids:
|
||||
logger.info(
|
||||
"Found duplicate article ID", article_id, clean_line
|
||||
) # This should never happen ...
|
||||
read_ids.add(article_id)
|
||||
|
||||
# read the title of this article (outside the revision portion of the document)
|
||||
if not reading_revision:
|
||||
titles = title_regex.search(clean_line)
|
||||
if titles:
|
||||
article_title = titles[0].strip()
|
||||
logger.info("Finished. Processed {} articles".format(article_count))
|
||||
|
||||
|
||||
def _process_wp_text(article_title, article_text, wp_to_id):
|
||||
# ignore meta Wikipedia pages
|
||||
if ns_regex.match(article_title):
|
||||
return None, None
|
||||
|
||||
# remove the text tags
|
||||
text_search = text_regex.search(article_text)
|
||||
if text_search is None:
|
||||
return None, None
|
||||
text = text_search.group(0)
|
||||
|
||||
# stop processing if this is a redirect page
|
||||
if text.startswith("#REDIRECT"):
|
||||
return None, None
|
||||
|
||||
# get the raw text without markup etc, keeping only interwiki links
|
||||
clean_text, entities = _remove_links(_get_clean_wp_text(text), wp_to_id)
|
||||
return clean_text, entities
|
||||
|
||||
|
||||
def _get_clean_wp_text(article_text):
|
||||
clean_text = article_text.strip()
|
||||
|
||||
# remove bolding & italic markup
|
||||
clean_text = clean_text.replace("'''", "")
|
||||
clean_text = clean_text.replace("''", "")
|
||||
|
||||
# remove nested {{info}} statements by removing the inner/smallest ones first and iterating
|
||||
try_again = True
|
||||
previous_length = len(clean_text)
|
||||
while try_again:
|
||||
clean_text = info_regex.sub(
|
||||
"", clean_text
|
||||
) # non-greedy match excluding a nested {
|
||||
if len(clean_text) < previous_length:
|
||||
try_again = True
|
||||
else:
|
||||
try_again = False
|
||||
previous_length = len(clean_text)
|
||||
|
||||
# remove HTML comments
|
||||
clean_text = html_regex.sub("", clean_text)
|
||||
|
||||
# remove Category and File statements
|
||||
clean_text = category_regex.sub("", clean_text)
|
||||
clean_text = file_regex.sub("", clean_text)
|
||||
|
||||
# remove multiple =
|
||||
while "==" in clean_text:
|
||||
clean_text = clean_text.replace("==", "=")
|
||||
|
||||
clean_text = clean_text.replace(". =", ".")
|
||||
clean_text = clean_text.replace(" = ", ". ")
|
||||
clean_text = clean_text.replace("= ", ".")
|
||||
clean_text = clean_text.replace(" =", "")
|
||||
|
||||
# remove refs (non-greedy match)
|
||||
clean_text = ref_regex.sub("", clean_text)
|
||||
clean_text = ref_2_regex.sub("", clean_text)
|
||||
|
||||
# remove additional wikiformatting
|
||||
clean_text = re.sub(r"<blockquote>", "", clean_text)
|
||||
clean_text = re.sub(r"</blockquote>", "", clean_text)
|
||||
|
||||
# change special characters back to normal ones
|
||||
clean_text = clean_text.replace(r"<", "<")
|
||||
clean_text = clean_text.replace(r">", ">")
|
||||
clean_text = clean_text.replace(r""", '"')
|
||||
clean_text = clean_text.replace(r"&nbsp;", " ")
|
||||
clean_text = clean_text.replace(r"&", "&")
|
||||
|
||||
# remove multiple spaces
|
||||
while " " in clean_text:
|
||||
clean_text = clean_text.replace(" ", " ")
|
||||
|
||||
return clean_text.strip()
|
||||
|
||||
|
||||
def _remove_links(clean_text, wp_to_id):
|
||||
# read the text char by char to get the right offsets for the interwiki links
|
||||
entities = []
|
||||
final_text = ""
|
||||
open_read = 0
|
||||
reading_text = True
|
||||
reading_entity = False
|
||||
reading_mention = False
|
||||
reading_special_case = False
|
||||
entity_buffer = ""
|
||||
mention_buffer = ""
|
||||
for index, letter in enumerate(clean_text):
|
||||
if letter == "[":
|
||||
open_read += 1
|
||||
elif letter == "]":
|
||||
open_read -= 1
|
||||
elif letter == "|":
|
||||
if reading_text:
|
||||
final_text += letter
|
||||
# switch from reading entity to mention in the [[entity|mention]] pattern
|
||||
elif reading_entity:
|
||||
reading_text = False
|
||||
reading_entity = False
|
||||
reading_mention = True
|
||||
else:
|
||||
reading_special_case = True
|
||||
else:
|
||||
if reading_entity:
|
||||
entity_buffer += letter
|
||||
elif reading_mention:
|
||||
mention_buffer += letter
|
||||
elif reading_text:
|
||||
final_text += letter
|
||||
else:
|
||||
raise ValueError("Not sure at point", clean_text[index - 2 : index + 2])
|
||||
|
||||
if open_read > 2:
|
||||
reading_special_case = True
|
||||
|
||||
if open_read == 2 and reading_text:
|
||||
reading_text = False
|
||||
reading_entity = True
|
||||
reading_mention = False
|
||||
|
||||
# we just finished reading an entity
|
||||
if open_read == 0 and not reading_text:
|
||||
if "#" in entity_buffer or entity_buffer.startswith(":"):
|
||||
reading_special_case = True
|
||||
# Ignore cases with nested structures like File: handles etc
|
||||
if not reading_special_case:
|
||||
if not mention_buffer:
|
||||
mention_buffer = entity_buffer
|
||||
start = len(final_text)
|
||||
end = start + len(mention_buffer)
|
||||
qid = wp_to_id.get(entity_buffer, None)
|
||||
if qid:
|
||||
entities.append((mention_buffer, qid, start, end))
|
||||
final_text += mention_buffer
|
||||
|
||||
entity_buffer = ""
|
||||
mention_buffer = ""
|
||||
|
||||
reading_text = True
|
||||
reading_entity = False
|
||||
reading_mention = False
|
||||
reading_special_case = False
|
||||
return final_text, entities
|
||||
|
||||
|
||||
def _write_training_description(outputfile, qid, description):
|
||||
if description is not None:
|
||||
line = str(qid) + "|" + description + "\n"
|
||||
outputfile.write(line)
|
||||
|
||||
|
||||
def _write_training_entities(outputfile, article_id, clean_text, entities):
|
||||
entities_data = [
|
||||
{"alias": ent[0], "entity": ent[1], "start": ent[2], "end": ent[3]}
|
||||
for ent in entities
|
||||
]
|
||||
line = (
|
||||
json.dumps(
|
||||
{
|
||||
"article_id": article_id,
|
||||
"clean_text": clean_text,
|
||||
"entities": entities_data,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
outputfile.write(line)
|
||||
|
||||
|
||||
def read_training(nlp, entity_file_path, dev, limit, kb, labels_discard=None):
|
||||
""" This method provides training examples that correspond to the entity annotations found by the nlp object.
|
||||
For training, it will include both positive and negative examples by using the candidate generator from the kb.
|
||||
For testing (kb=None), it will include all positive examples only."""
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
if not labels_discard:
|
||||
labels_discard = []
|
||||
|
||||
data = []
|
||||
num_entities = 0
|
||||
get_gold_parse = partial(
|
||||
_get_gold_parse, dev=dev, kb=kb, labels_discard=labels_discard
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Reading {} data with limit {}".format("dev" if dev else "train", limit)
|
||||
)
|
||||
with entity_file_path.open("r", encoding="utf8") as file:
|
||||
with tqdm(total=limit, leave=False) as pbar:
|
||||
for i, line in enumerate(file):
|
||||
example = json.loads(line)
|
||||
article_id = example["article_id"]
|
||||
clean_text = example["clean_text"]
|
||||
entities = example["entities"]
|
||||
|
||||
if dev != is_dev(article_id) or not is_valid_article(clean_text):
|
||||
continue
|
||||
|
||||
doc = nlp(clean_text)
|
||||
gold = get_gold_parse(doc, entities)
|
||||
if gold and len(gold.links) > 0:
|
||||
data.append((doc, gold))
|
||||
num_entities += len(gold.links)
|
||||
pbar.update(len(gold.links))
|
||||
if limit and num_entities >= limit:
|
||||
break
|
||||
logger.info("Read {} entities in {} articles".format(num_entities, len(data)))
|
||||
return data
|
||||
|
||||
|
||||
def _get_gold_parse(doc, entities, dev, kb, labels_discard):
|
||||
gold_entities = {}
|
||||
tagged_ent_positions = {
|
||||
(ent.start_char, ent.end_char): ent
|
||||
for ent in doc.ents
|
||||
if ent.label_ not in labels_discard
|
||||
}
|
||||
|
||||
for entity in entities:
|
||||
entity_id = entity["entity"]
|
||||
alias = entity["alias"]
|
||||
start = entity["start"]
|
||||
end = entity["end"]
|
||||
|
||||
candidate_ids = []
|
||||
if kb and not dev:
|
||||
candidates = kb.get_candidates(alias)
|
||||
candidate_ids = [cand.entity_ for cand in candidates]
|
||||
|
||||
tagged_ent = tagged_ent_positions.get((start, end), None)
|
||||
if tagged_ent:
|
||||
# TODO: check that alias == doc.text[start:end]
|
||||
should_add_ent = (dev or entity_id in candidate_ids) and is_valid_sentence(
|
||||
tagged_ent.sent.text
|
||||
)
|
||||
|
||||
if should_add_ent:
|
||||
value_by_id = {entity_id: 1.0}
|
||||
if not dev:
|
||||
random.shuffle(candidate_ids)
|
||||
value_by_id.update(
|
||||
{kb_id: 0.0 for kb_id in candidate_ids if kb_id != entity_id}
|
||||
)
|
||||
gold_entities[(start, end)] = value_by_id
|
||||
|
||||
return GoldParse(doc, links=gold_entities)
|
||||
|
||||
|
||||
def is_dev(article_id):
|
||||
if not article_id:
|
||||
return False
|
||||
return article_id.endswith("3")
|
||||
|
||||
|
||||
def is_valid_article(doc_text):
|
||||
# custom length cut-off
|
||||
return 10 < len(doc_text) < 30000
|
||||
|
||||
|
||||
def is_valid_sentence(sent_text):
|
||||
if not 10 < len(sent_text) < 3000:
|
||||
# custom length cut-off
|
||||
return False
|
||||
|
||||
if sent_text.strip().startswith("*") or sent_text.strip().startswith("#"):
|
||||
# remove 'enumeration' sentences (occurs often on Wikipedia)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
|
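Note (not part of the diff): each line written by _write_training_entities() above is one JSON object per article. The record below uses invented values and only mirrors the schema from the code; read_training() later sends an article to the dev set when is_dev() matches its ID.

import json

record = {
    "article_id": "12",
    "clean_text": "Douglas Adams was an English author.",
    "entities": [{"alias": "Douglas Adams", "entity": "Q42", "start": 0, "end": 13}],
}
print(json.dumps(record, ensure_ascii=False))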
|
@@ -7,7 +7,7 @@ dependency tree to find the noun phrase they are referring to – for example:
|
|||
$9.4 million --> Net income.
|
||||
|
||||
Compatible with: spaCy v2.0.0+
|
||||
Last tested with: v2.1.0
|
||||
Last tested with: v2.2.1
|
||||
"""
|
||||
from __future__ import unicode_literals, print_function
|
||||
|
||||
|
@@ -38,14 +38,17 @@ def main(model="en_core_web_sm"):
|
|||
|
||||
def filter_spans(spans):
|
||||
# Filter a sequence of spans so they don't contain overlaps
|
||||
get_sort_key = lambda span: (span.end - span.start, span.start)
|
||||
# For spaCy 2.1.4+: this function is available as spacy.util.filter_spans()
|
||||
get_sort_key = lambda span: (span.end - span.start, -span.start)
|
||||
sorted_spans = sorted(spans, key=get_sort_key, reverse=True)
|
||||
result = []
|
||||
seen_tokens = set()
|
||||
for span in sorted_spans:
|
||||
# Check for end - 1 here because boundaries are inclusive
|
||||
if span.start not in seen_tokens and span.end - 1 not in seen_tokens:
|
||||
result.append(span)
|
||||
seen_tokens.update(range(span.start, span.end))
|
||||
seen_tokens.update(range(span.start, span.end))
|
||||
result = sorted(result, key=lambda span: span.start)
|
||||
return result
|
||||
|
||||
|
||||
|
|
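Note (not part of the diff): as the new comment above says, spaCy 2.1.4+ ships this helper as spacy.util.filter_spans. A short usage sketch; the model name is illustrative and must be installed.

import spacy
from spacy.util import filter_spans

nlp = spacy.load("en_core_web_sm")
doc = nlp("Net income was $9.4 million compared to the prior year.")
spans = list(doc.ents) + list(doc.noun_chunks)   # these can overlap
print(filter_spans(spans))                       # longest non-overlapping spans, sorted by start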
|
@@ -91,8 +91,8 @@ def demo(shape):
|
|||
nlp = spacy.load("en_vectors_web_lg")
|
||||
nlp.add_pipe(KerasSimilarityShim.load(nlp.path / "similarity", nlp, shape[0]))
|
||||
|
||||
doc1 = nlp(u"The king of France is bald.")
|
||||
doc2 = nlp(u"France has no king.")
|
||||
doc1 = nlp("The king of France is bald.")
|
||||
doc2 = nlp("France has no king.")
|
||||
|
||||
print("Sentence 1:", doc1)
|
||||
print("Sentence 2:", doc2)
|
||||
|
|
|
@@ -8,7 +8,7 @@
|
|||
{
|
||||
"tokens": [
|
||||
{
|
||||
"head": 4,
|
||||
"head": 44,
|
||||
"dep": "prep",
|
||||
"tag": "IN",
|
||||
"orth": "In",
|
||||
|
|
|
@@ -11,7 +11,7 @@ numpy>=1.15.0
|
|||
requests>=2.13.0,<3.0.0
|
||||
plac<1.0.0,>=0.9.6
|
||||
pathlib==1.0.1; python_version < "3.4"
|
||||
importlib_metadata>=0.23; python_version < "3.8"
|
||||
importlib_metadata>=0.20; python_version < "3.8"
|
||||
# Optional dependencies
|
||||
jsonschema>=2.6.0,<3.1.0
|
||||
# Development dependencies
|
||||
|
|
|
@@ -51,7 +51,7 @@ install_requires =
|
|||
wasabi>=0.2.0,<1.1.0
|
||||
srsly>=0.1.0,<1.1.0
|
||||
pathlib==1.0.1; python_version < "3.4"
|
||||
importlib_metadata>=0.23; python_version < "3.8"
|
||||
importlib_metadata>=0.20; python_version < "3.8"
|
||||
|
||||
[options.extras_require]
|
||||
lookups =
|
||||
|
|
|
@@ -57,7 +57,8 @@ def convert(
|
|||
is written to stdout, so you can pipe them forward to a JSON file:
|
||||
$ spacy convert some_file.conllu > some_file.json
|
||||
"""
|
||||
msg = Printer()
|
||||
no_print = (output_dir == "-")
|
||||
msg = Printer(no_print=no_print)
|
||||
input_path = Path(input_file)
|
||||
if file_type not in FILE_TYPES:
|
||||
msg.fail(
|
||||
|
@@ -102,6 +103,7 @@ def convert(
|
|||
use_morphology=morphology,
|
||||
lang=lang,
|
||||
model=model,
|
||||
no_print=no_print,
|
||||
)
|
||||
if output_dir != "-":
|
||||
# Export data to a file
|
||||
|
|
|
@@ -9,7 +9,7 @@ from ...tokens.doc import Doc
|
|||
from ...util import load_model
|
||||
|
||||
|
||||
def conll_ner2json(input_data, n_sents=10, seg_sents=False, model=None, **kwargs):
|
||||
def conll_ner2json(input_data, n_sents=10, seg_sents=False, model=None, no_print=False, **kwargs):
|
||||
"""
|
||||
Convert files in the CoNLL-2003 NER format and similar
|
||||
whitespace-separated columns into JSON format for use with train cli.
|
||||
|
@@ -34,7 +34,7 @@ def conll_ner2json(input_data, n_sents=10, seg_sents=False, model=None, **kwargs
|
|||
. O
|
||||
|
||||
"""
|
||||
msg = Printer()
|
||||
msg = Printer(no_print=no_print)
|
||||
doc_delimiter = "-DOCSTART- -X- O O"
|
||||
# check for existing delimiters, which should be preserved
|
||||
if "\n\n" in input_data and seg_sents:
|
||||
|
|
|
@@ -8,7 +8,7 @@ from ...util import minibatch
|
|||
from .conll_ner2json import n_sents_info
|
||||
|
||||
|
||||
def iob2json(input_data, n_sents=10, *args, **kwargs):
|
||||
def iob2json(input_data, n_sents=10, no_print=False, *args, **kwargs):
|
||||
"""
|
||||
Convert IOB files with one sentence per line and tags separated with '|'
|
||||
into JSON format for use with train cli. IOB and IOB2 are accepted.
|
||||
|
@@ -20,7 +20,7 @@ def iob2json(input_data, n_sents=10, *args, **kwargs):
|
|||
I|PRP|O like|VBP|O London|NNP|I-GPE and|CC|O New|NNP|B-GPE York|NNP|I-GPE City|NNP|I-GPE .|.|O
|
||||
I|PRP|O like|VBP|O London|NNP|B-GPE and|CC|O New|NNP|B-GPE York|NNP|I-GPE City|NNP|I-GPE .|.|O
|
||||
"""
|
||||
msg = Printer()
|
||||
msg = Printer(no_print=no_print)
|
||||
docs = read_iob(input_data.split("\n"))
|
||||
if n_sents > 0:
|
||||
n_sents_info(msg, n_sents)
|
||||
|
|
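Note (not part of the diff): the converters above now take no_print so that, when the converted JSON itself goes to stdout (output_dir == "-"), status messages do not pollute it. A sketch that assumes wasabi's Printer(no_print=...) behaviour:

from wasabi import Printer

no_print = True                      # e.g. when output_dir == "-"
msg = Printer(no_print=no_print)
msg.good("Generated output file")    # suppressed on stdout when no_print is True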
|
@@ -360,6 +360,16 @@ def debug_data(
|
|||
)
|
||||
)
|
||||
|
||||
# check for documents with multiple sentences
|
||||
sents_per_doc = gold_train_data["n_sents"] / len(gold_train_data["texts"])
|
||||
if sents_per_doc < 1.1:
|
||||
msg.warn(
|
||||
"The training data contains {:.2f} sentences per "
|
||||
"document. When there are very few documents containing more "
|
||||
"than one sentence, the parser will not learn how to segment "
|
||||
"longer texts into sentences.".format(sents_per_doc)
|
||||
)
|
||||
|
||||
# profile labels
|
||||
labels_train = [label for label in gold_train_data["deps"]]
|
||||
labels_train_unpreprocessed = [
|
||||
|
|
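Note (not part of the diff): a quick worked example of the new 1.1 sentences-per-document threshold above, with invented counts.

n_sents, n_texts = 1000, 950
sents_per_doc = n_sents / n_texts    # about 1.05
print(sents_per_doc < 1.1)           # True, so the sentence segmentation warning is shown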
|
@@ -246,7 +246,7 @@ def make_update(model, docs, optimizer, drop=0.0, objective="L2"):
|
|||
"""Perform an update over a single batch of documents.
|
||||
|
||||
docs (iterable): A batch of `Doc` objects.
|
||||
drop (float): The droput rate.
|
||||
drop (float): The dropout rate.
|
||||
optimizer (callable): An optimizer.
|
||||
RETURNS loss: A float for the loss.
|
||||
"""
|
||||
|
|
|
@@ -80,8 +80,8 @@ class Warnings(object):
|
|||
"the v2.x models cannot release the global interpreter lock. "
|
||||
"Future versions may introduce a `n_process` argument for "
|
||||
"parallel inference via multiprocessing.")
|
||||
W017 = ("Alias '{alias}' already exists in the Knowledge base.")
|
||||
W018 = ("Entity '{entity}' already exists in the Knowledge base.")
|
||||
W017 = ("Alias '{alias}' already exists in the Knowledge Base.")
|
||||
W018 = ("Entity '{entity}' already exists in the Knowledge Base.")
|
||||
W019 = ("Changing vectors name from {old} to {new}, to avoid clash with "
|
||||
"previously loaded vectors. See Issue #3853.")
|
||||
W020 = ("Unnamed vectors. This won't allow multiple vectors models to be "
|
||||
|
@@ -95,7 +95,10 @@ class Warnings(object):
|
|||
"you can ignore this warning by setting SPACY_WARNING_IGNORE=W022. "
|
||||
"If this is surprising, make sure you have the spacy-lookups-data "
|
||||
"package installed.")
|
||||
W023 = ("Multiprocessing of Language.pipe is not supported in Python2. 'n_process' will be set to 1.")
|
||||
W023 = ("Multiprocessing of Language.pipe is not supported in Python 2. "
|
||||
"'n_process' will be set to 1.")
|
||||
W024 = ("Entity '{entity}' - Alias '{alias}' combination already exists in "
|
||||
"the Knowledge Base.")
|
||||
|
||||
|
||||
@add_codes
|
||||
|
@@ -408,7 +411,7 @@ class Errors(object):
|
|||
"{probabilities_length} respectively.")
|
||||
E133 = ("The sum of prior probabilities for alias '{alias}' should not "
|
||||
"exceed 1, but found {sum}.")
|
||||
E134 = ("Alias '{alias}' defined for unknown entity '{entity}'.")
|
||||
E134 = ("Entity '{entity}' is not defined in the Knowledge Base.")
|
||||
E135 = ("If you meant to replace a built-in component, use `create_pipe`: "
|
||||
"`nlp.replace_pipe('{name}', nlp.create_pipe('{name}'))`")
|
||||
E136 = ("This additional feature requires the jsonschema library to be "
|
||||
|
@@ -420,7 +423,7 @@ class Errors(object):
|
|||
E138 = ("Invalid JSONL format for raw text '{text}'. Make sure the input "
|
||||
"includes either the `text` or `tokens` key. For more info, see "
|
||||
"the docs:\nhttps://spacy.io/api/cli#pretrain-jsonl")
|
||||
E139 = ("Knowledge base for component '{name}' not initialized. Did you "
|
||||
E139 = ("Knowledge Base for component '{name}' not initialized. Did you "
|
||||
"forget to call set_kb()?")
|
||||
E140 = ("The list of entities, prior probabilities and entity vectors "
|
||||
"should be of equal length.")
|
||||
|
@@ -498,6 +501,8 @@ class Errors(object):
|
|||
"details: https://spacy.io/api/lemmatizer#init")
|
||||
E174 = ("Architecture '{name}' not found in registry. Available "
|
||||
"names: {names}")
|
||||
E175 = ("Can't remove rule for unknown match pattern ID: {key}")
|
||||
E176 = ("Alias '{alias}' is not defined in the Knowledge Base.")
|
||||
|
||||
|
||||
@add_codes
|
||||
|
|
|
@@ -743,7 +743,8 @@ def docs_to_json(docs, id=0):
|
|||
|
||||
docs (iterable / Doc): The Doc object(s) to convert.
|
||||
id (int): Id for the JSON.
|
||||
RETURNS (list): The data in spaCy's JSON format.
|
||||
RETURNS (dict): The data in spaCy's JSON format
|
||||
- each input doc will be treated as a paragraph in the output doc
|
||||
"""
|
||||
if isinstance(docs, Doc):
|
||||
docs = [docs]
|
||||
|
|
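Note (not part of the diff): docs_to_json now returns a single dict in spaCy's training JSON format, with every input Doc becoming one paragraph. A sketch; reading the result via a "paragraphs" key is my assumption about that format.

import spacy
from spacy.gold import docs_to_json

nlp = spacy.load("en_core_web_sm")
docs = [nlp("This is one document."), nlp("And this is another one.")]
json_data = docs_to_json(docs, id=0)
print(type(json_data))                  # <class 'dict'>
print(len(json_data["paragraphs"]))     # 2, one paragraph per input doc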
69
spacy/kb.pyx
|
@@ -142,6 +142,7 @@ cdef class KnowledgeBase:
|
|||
|
||||
i = 0
|
||||
cdef KBEntryC entry
|
||||
cdef hash_t entity_hash
|
||||
while i < nr_entities:
|
||||
entity_vector = vector_list[i]
|
||||
if len(entity_vector) != self.entity_vector_length:
|
||||
|
@@ -161,6 +162,14 @@ cdef class KnowledgeBase:
|
|||
|
||||
i += 1
|
||||
|
||||
def contains_entity(self, unicode entity):
|
||||
cdef hash_t entity_hash = self.vocab.strings.add(entity)
|
||||
return entity_hash in self._entry_index
|
||||
|
||||
def contains_alias(self, unicode alias):
|
||||
cdef hash_t alias_hash = self.vocab.strings.add(alias)
|
||||
return alias_hash in self._alias_index
|
||||
|
||||
def add_alias(self, unicode alias, entities, probabilities):
|
||||
"""
|
||||
For a given alias, add its potential entities and prior probabilities to the KB.
|
||||
|
@@ -190,7 +199,7 @@ cdef class KnowledgeBase:
|
|||
for entity, prob in zip(entities, probabilities):
|
||||
entity_hash = self.vocab.strings[entity]
|
||||
if not entity_hash in self._entry_index:
|
||||
raise ValueError(Errors.E134.format(alias=alias, entity=entity))
|
||||
raise ValueError(Errors.E134.format(entity=entity))
|
||||
|
||||
entry_index = <int64_t>self._entry_index.get(entity_hash)
|
||||
entry_indices.push_back(int(entry_index))
|
||||
|
@@ -201,8 +210,63 @@ cdef class KnowledgeBase:
|
|||
|
||||
return alias_hash
|
||||
|
||||
def get_candidates(self, unicode alias):
|
||||
def append_alias(self, unicode alias, unicode entity, float prior_prob, ignore_warnings=False):
|
||||
"""
|
||||
For an alias already existing in the KB, extend its potential entities with one more.
|
||||
Throw a warning if either the alias or the entity is unknown,
|
||||
or when the combination is already previously recorded.
|
||||
Throw an error if this entity+prior prob would exceed the sum of 1.
|
||||
For efficiency, it's best to use the method `add_alias` as much as possible instead of this one.
|
||||
"""
|
||||
# Check if the alias exists in the KB
|
||||
cdef hash_t alias_hash = self.vocab.strings[alias]
|
||||
if not alias_hash in self._alias_index:
|
||||
raise ValueError(Errors.E176.format(alias=alias))
|
||||
|
||||
# Check if the entity exists in the KB
|
||||
cdef hash_t entity_hash = self.vocab.strings[entity]
|
||||
if not entity_hash in self._entry_index:
|
||||
raise ValueError(Errors.E134.format(entity=entity))
|
||||
entry_index = <int64_t>self._entry_index.get(entity_hash)
|
||||
|
||||
# Throw an error if the prior probabilities (including the new one) sum up to more than 1
|
||||
alias_index = <int64_t>self._alias_index.get(alias_hash)
|
||||
alias_entry = self._aliases_table[alias_index]
|
||||
current_sum = sum([p for p in alias_entry.probs])
|
||||
new_sum = current_sum + prior_prob
|
||||
|
||||
if new_sum > 1.00001:
|
||||
raise ValueError(Errors.E133.format(alias=alias, sum=new_sum))
|
||||
|
||||
entry_indices = alias_entry.entry_indices
|
||||
|
||||
is_present = False
|
||||
for i in range(entry_indices.size()):
|
||||
if entry_indices[i] == int(entry_index):
|
||||
is_present = True
|
||||
|
||||
if is_present:
|
||||
if not ignore_warnings:
|
||||
user_warning(Warnings.W024.format(entity=entity, alias=alias))
|
||||
else:
|
||||
entry_indices.push_back(int(entry_index))
|
||||
alias_entry.entry_indices = entry_indices
|
||||
|
||||
probs = alias_entry.probs
|
||||
probs.push_back(float(prior_prob))
|
||||
alias_entry.probs = probs
|
||||
self._aliases_table[alias_index] = alias_entry
|
||||
|
||||
|
||||
def get_candidates(self, unicode alias):
|
||||
"""
|
||||
Return candidate entities for an alias. Each candidate defines the entity, the original alias,
|
||||
and the prior probability of that alias resolving to that entity.
|
||||
If the alias is not known in the KB, an empty list is returned.
|
||||
"""
|
||||
cdef hash_t alias_hash = self.vocab.strings[alias]
|
||||
if not alias_hash in self._alias_index:
|
||||
return []
|
||||
alias_index = <int64_t>self._alias_index.get(alias_hash)
|
||||
alias_entry = self._aliases_table[alias_index]
|
||||
|
||||
|
@@ -341,7 +405,6 @@ cdef class KnowledgeBase:
|
|||
assert nr_entities == self.get_size_entities()
|
||||
|
||||
# STEP 3: load aliases
|
||||
|
||||
cdef int64_t nr_aliases
|
||||
reader.read_alias_length(&nr_aliases)
|
||||
self._alias_index = PreshMap(nr_aliases+1)
|
||||
|
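Note (not part of the diff): a quick sketch of the KnowledgeBase additions above (contains_entity, contains_alias, append_alias); the entity IDs, frequencies and probabilities are invented.

from spacy.kb import KnowledgeBase
from spacy.vocab import Vocab

kb = KnowledgeBase(vocab=Vocab(), entity_vector_length=3)
kb.add_entity(entity="Q42", freq=50, entity_vector=[1.0, 0.0, 0.0])
kb.add_entity(entity="Q3107329", freq=12, entity_vector=[0.0, 1.0, 0.0])
kb.add_alias(alias="Adams", entities=["Q42"], probabilities=[0.8])

print(kb.contains_entity("Q42"))      # True
print(kb.contains_alias("Adams"))     # True

# extend the existing alias with one more candidate; the prior probabilities of
# an alias may not sum to more than 1, otherwise E133 is raised
kb.append_alias(alias="Adams", entity="Q3107329", prior_prob=0.1)
print([c.entity_ for c in kb.get_candidates("Adams")])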
|
34
spacy/lang/lb/__init__.py
Normal file
|
@@ -0,0 +1,34 @@
|
|||
# coding: utf8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from .tokenizer_exceptions import TOKENIZER_EXCEPTIONS
|
||||
from .norm_exceptions import NORM_EXCEPTIONS
|
||||
from .lex_attrs import LEX_ATTRS
|
||||
from .tag_map import TAG_MAP
|
||||
from .stop_words import STOP_WORDS
|
||||
|
||||
from ..tokenizer_exceptions import BASE_EXCEPTIONS
|
||||
from ..norm_exceptions import BASE_NORMS
|
||||
from ...language import Language
|
||||
from ...attrs import LANG, NORM
|
||||
from ...util import update_exc, add_lookups
|
||||
|
||||
|
||||
class LuxembourgishDefaults(Language.Defaults):
|
||||
lex_attr_getters = dict(Language.Defaults.lex_attr_getters)
|
||||
lex_attr_getters.update(LEX_ATTRS)
|
||||
lex_attr_getters[LANG] = lambda text: "lb"
|
||||
lex_attr_getters[NORM] = add_lookups(
|
||||
Language.Defaults.lex_attr_getters[NORM], NORM_EXCEPTIONS, BASE_NORMS
|
||||
)
|
||||
tokenizer_exceptions = update_exc(BASE_EXCEPTIONS, TOKENIZER_EXCEPTIONS)
|
||||
stop_words = STOP_WORDS
|
||||
tag_map = TAG_MAP
|
||||
|
||||
|
||||
class Luxembourgish(Language):
|
||||
lang = "lb"
|
||||
Defaults = LuxembourgishDefaults
|
||||
|
||||
|
||||
__all__ = ["Luxembourgish"]
|
18
spacy/lang/lb/examples.py
Normal file
|
@@ -0,0 +1,18 @@
|
|||
# coding: utf8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
"""
|
||||
Example sentences to test spaCy and its language models.
|
||||
|
||||
>>> from spacy.lang.lb.examples import sentences
|
||||
>>> docs = nlp.pipe(sentences)
|
||||
"""
|
||||
|
||||
sentences = [
|
||||
"An der Zäit hunn sech den Nordwand an d’Sonn gestridden, wie vun hinnen zwee wuel méi staark wier, wéi e Wanderer, deen an ee waarme Mantel agepak war, iwwert de Wee koum.",
|
||||
"Si goufen sech eens, dass deejéinege fir de Stäerkste gëlle sollt, deen de Wanderer forcéiere géif, säi Mantel auszedoen.",
|
||||
"Den Nordwand huet mat aller Force geblosen, awer wat e méi geblosen huet, wat de Wanderer sech méi a säi Mantel agewéckelt huet.",
|
||||
"Um Enn huet den Nordwand säi Kampf opginn.",
|
||||
"Dunn huet d’Sonn d’Loft mat hire frëndleche Strale gewiermt, a schonn no kuerzer Zäit huet de Wanderer säi Mantel ausgedoen.",
|
||||
"Do huet den Nordwand missen zouginn, dass d’Sonn vun hinnen zwee de Stäerkste wier.",
|
||||
]
|
44
spacy/lang/lb/lex_attrs.py
Normal file
|
@@ -0,0 +1,44 @@
|
|||
# coding: utf8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from ...attrs import LIKE_NUM
|
||||
|
||||
|
||||
_num_words = set(
|
||||
"""
|
||||
null eent zwee dräi véier fënnef sechs ziwen aacht néng zéng eelef zwielef dräizéng
|
||||
véierzéng foffzéng siechzéng siwwenzéng uechtzeng uechzeng nonnzéng nongzéng zwanzeg drësseg véierzeg foffzeg sechzeg siechzeg siwenzeg achtzeg achzeg uechtzeg uechzeg nonnzeg
|
||||
honnert dausend millioun milliard billioun billiard trillioun triliard
|
||||
""".split()
|
||||
)
|
||||
|
||||
_ordinal_words = set(
|
||||
"""
|
||||
éischten zweeten drëtten véierten fënneften sechsten siwenten aachten néngten zéngten eeleften
|
||||
zwieleften dräizéngten véierzéngten foffzéngten siechzéngten uechtzéngen uechzéngten nonnzéngten nongzéngten zwanzegsten
|
||||
drëssegsten véierzegsten foffzegsten siechzegsten siwenzegsten uechzegsten nonnzegsten
|
||||
honnertsten dausendsten milliounsten
|
||||
milliardsten billiounsten billiardsten trilliounsten trilliardsten
|
||||
""".split()
|
||||
)
|
||||
|
||||
|
||||
def like_num(text):
|
||||
"""
|
||||
check if text resembles a number
|
||||
"""
|
||||
text = text.replace(",", "").replace(".", "")
|
||||
if text.isdigit():
|
||||
return True
|
||||
if text.count("/") == 1:
|
||||
num, denom = text.split("/")
|
||||
if num.isdigit() and denom.isdigit():
|
||||
return True
|
||||
if text in _num_words:
|
||||
return True
|
||||
if text in _ordinal_words:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
LEX_ATTRS = {LIKE_NUM: like_num}
|
16
spacy/lang/lb/norm_exceptions.py
Normal file
|
@@ -0,0 +1,16 @@
|
|||
# coding: utf8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
# TODO
|
||||
# norm execptions: find a possibility to deal with the zillions of spelling
|
||||
# variants (vläicht = vlaicht, vleicht, viläicht, viläischt, etc. etc.)
|
||||
# here one could include the most common spelling mistakes
|
||||
|
||||
_exc = {"datt": "dass", "wgl.": "weg.", "vläicht": "viläicht"}
|
||||
|
||||
|
||||
NORM_EXCEPTIONS = {}
|
||||
|
||||
for string, norm in _exc.items():
|
||||
NORM_EXCEPTIONS[string] = norm
|
||||
NORM_EXCEPTIONS[string.title()] = norm
|
214
spacy/lang/lb/stop_words.py
Normal file
|
@@ -0,0 +1,214 @@
|
|||
# coding: utf8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
STOP_WORDS = set(
|
||||
"""
|
||||
a
|
||||
à
|
||||
äis
|
||||
är
|
||||
ärt
|
||||
äert
|
||||
ären
|
||||
all
|
||||
allem
|
||||
alles
|
||||
alleguer
|
||||
als
|
||||
also
|
||||
am
|
||||
an
|
||||
anerefalls
|
||||
ass
|
||||
aus
|
||||
awer
|
||||
bei
|
||||
beim
|
||||
bis
|
||||
bis
|
||||
d'
|
||||
dach
|
||||
datt
|
||||
däin
|
||||
där
|
||||
dat
|
||||
de
|
||||
dee
|
||||
den
|
||||
deel
|
||||
deem
|
||||
deen
|
||||
deene
|
||||
déi
|
||||
den
|
||||
deng
|
||||
denger
|
||||
dem
|
||||
der
|
||||
dësem
|
||||
di
|
||||
dir
|
||||
do
|
||||
da
|
||||
dann
|
||||
domat
|
||||
dozou
|
||||
drop
|
||||
du
|
||||
duerch
|
||||
duerno
|
||||
e
|
||||
ee
|
||||
em
|
||||
een
|
||||
eent
|
||||
ë
|
||||
en
|
||||
ënner
|
||||
ëm
|
||||
ech
|
||||
eis
|
||||
eise
|
||||
eisen
|
||||
eiser
|
||||
eises
|
||||
eisereen
|
||||
esou
|
||||
een
|
||||
eng
|
||||
enger
|
||||
engem
|
||||
entweder
|
||||
et
|
||||
eréischt
|
||||
falls
|
||||
fir
|
||||
géint
|
||||
géif
|
||||
gëtt
|
||||
gët
|
||||
geet
|
||||
gi
|
||||
ginn
|
||||
gouf
|
||||
gouff
|
||||
goung
|
||||
hat
|
||||
haten
|
||||
hatt
|
||||
hätt
|
||||
hei
|
||||
hu
|
||||
huet
|
||||
hun
|
||||
hunn
|
||||
hiren
|
||||
hien
|
||||
hin
|
||||
hier
|
||||
hir
|
||||
jidderen
|
||||
jiddereen
|
||||
jiddwereen
|
||||
jiddereng
|
||||
jiddwerengen
|
||||
jo
|
||||
ins
|
||||
iech
|
||||
iwwer
|
||||
kann
|
||||
kee
|
||||
keen
|
||||
kënne
|
||||
kënnt
|
||||
kéng
|
||||
kéngen
|
||||
kéngem
|
||||
koum
|
||||
kuckt
|
||||
mam
|
||||
mat
|
||||
ma
|
||||
mä
|
||||
mech
|
||||
méi
|
||||
mécht
|
||||
meng
|
||||
menger
|
||||
mer
|
||||
mir
|
||||
muss
|
||||
nach
|
||||
nämmlech
|
||||
nämmelech
|
||||
näischt
|
||||
nawell
|
||||
nëmme
|
||||
nëmmen
|
||||
net
|
||||
nees
|
||||
nee
|
||||
no
|
||||
nu
|
||||
nom
|
||||
och
|
||||
oder
|
||||
ons
|
||||
onsen
|
||||
onser
|
||||
onsereen
|
||||
onst
|
||||
om
|
||||
op
|
||||
ouni
|
||||
säi
|
||||
säin
|
||||
schonn
|
||||
schonns
|
||||
si
|
||||
sid
|
||||
sie
|
||||
se
|
||||
sech
|
||||
seng
|
||||
senge
|
||||
sengem
|
||||
senger
|
||||
selwecht
|
||||
selwer
|
||||
sinn
|
||||
sollten
|
||||
souguer
|
||||
sou
|
||||
soss
|
||||
sot
|
||||
't
|
||||
tëscht
|
||||
u
|
||||
un
|
||||
um
|
||||
virdrun
|
||||
vu
|
||||
vum
|
||||
vun
|
||||
wann
|
||||
war
|
||||
waren
|
||||
was
|
||||
wat
|
||||
wëllt
|
||||
weider
|
||||
wéi
|
||||
wéini
|
||||
wéinst
|
||||
wi
|
||||
wollt
|
||||
wou
|
||||
wouhin
|
||||
zanter
|
||||
ze
|
||||
zu
|
||||
zum
|
||||
zwar
|
||||
""".split()
|
||||
)
|
28
spacy/lang/lb/tag_map.py
Normal file
|
@@ -0,0 +1,28 @@
|
|||
# coding: utf8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from ...symbols import POS, PUNCT, ADJ, CONJ, NUM, DET, ADV, ADP, X, VERB
|
||||
from ...symbols import NOUN, PART, SPACE, AUX
|
||||
|
||||
# TODO: tag map is still using POS tags from an internal training set.
|
||||
# These POS tags have to be modified to match those from Universal Dependencies
|
||||
|
||||
TAG_MAP = {
|
||||
"$": {POS: PUNCT},
|
||||
"ADJ": {POS: ADJ},
|
||||
"AV": {POS: ADV},
|
||||
"APPR": {POS: ADP, "AdpType": "prep"},
|
||||
"APPRART": {POS: ADP, "AdpType": "prep", "PronType": "art"},
|
||||
"D": {POS: DET, "PronType": "art"},
|
||||
"KO": {POS: CONJ},
|
||||
"N": {POS: NOUN},
|
||||
"P": {POS: ADV},
|
||||
"TRUNC": {POS: X, "Hyph": "yes"},
|
||||
"AUX": {POS: AUX},
|
||||
"V": {POS: VERB},
|
||||
"MV": {POS: VERB, "VerbType": "mod"},
|
||||
"PTK": {POS: PART},
|
||||
"INTER": {POS: PART},
|
||||
"NUM": {POS: NUM},
|
||||
"_SP": {POS: SPACE},
|
||||
}
|
69
spacy/lang/lb/tokenizer_exceptions.py
Normal file
|
@@ -0,0 +1,69 @@
|
|||
# coding: utf8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from ...symbols import ORTH, LEMMA, NORM
|
||||
from ..punctuation import TOKENIZER_PREFIXES
|
||||
|
||||
# TODO
|
||||
# tokenize cliticised definite article "d'" as token of its own: d'Kanner > [d'] [Kanner]
|
||||
# treat other apostrophes within words as part of the word: [op d'mannst], [fir d'éischt] (= exceptions)
|
||||
|
||||
# how to write the tokenisation exeption for the articles d' / D' ? This one is not working.
|
||||
_prefixes = [
|
||||
prefix for prefix in TOKENIZER_PREFIXES if prefix not in ["d'", "D'", "d’", "D’"]
|
||||
]
|
||||
|
||||
|
||||
_exc = {
|
||||
"d'mannst": [
|
||||
{ORTH: "d'", LEMMA: "d'"},
|
||||
{ORTH: "mannst", LEMMA: "mann", NORM: "mann"},
|
||||
],
|
||||
"d'éischt": [
|
||||
{ORTH: "d'", LEMMA: "d'"},
|
||||
{ORTH: "éischt", LEMMA: "éischt", NORM: "éischt"},
|
||||
],
|
||||
}
|
||||
|
||||
# translate / delete what is not necessary
|
||||
# what does PRON_LEMMA mean?
|
||||
for exc_data in [
|
||||
{ORTH: "wgl.", LEMMA: "wann ech gelift", NORM: "wann ech gelieft"},
|
||||
{ORTH: "M.", LEMMA: "Monsieur", NORM: "Monsieur"},
|
||||
{ORTH: "Mme.", LEMMA: "Madame", NORM: "Madame"},
|
||||
{ORTH: "Dr.", LEMMA: "Dokter", NORM: "Dokter"},
|
||||
{ORTH: "Tel.", LEMMA: "Telefon", NORM: "Telefon"},
|
||||
{ORTH: "asw.", LEMMA: "an sou weider", NORM: "an sou weider"},
|
||||
{ORTH: "etc.", LEMMA: "et cetera", NORM: "et cetera"},
|
||||
{ORTH: "bzw.", LEMMA: "bezéiungsweis", NORM: "bezéiungsweis"},
|
||||
{ORTH: "Jan.", LEMMA: "Januar", NORM: "Januar"},
|
||||
]:
|
||||
_exc[exc_data[ORTH]] = [exc_data]
|
||||
|
||||
|
||||
# to be extended
|
||||
for orth in [
|
||||
"z.B.",
|
||||
"Dipl.",
|
||||
"Dr.",
|
||||
"etc.",
|
||||
"i.e.",
|
||||
"o.k.",
|
||||
"O.K.",
|
||||
"p.a.",
|
||||
"p.s.",
|
||||
"P.S.",
|
||||
"phil.",
|
||||
"q.e.d.",
|
||||
"R.I.P.",
|
||||
"rer.",
|
||||
"sen.",
|
||||
"ë.a.",
|
||||
"U.S.",
|
||||
"U.S.A.",
|
||||
]:
|
||||
_exc[orth] = [{ORTH: orth}]
|
||||
|
||||
|
||||
TOKENIZER_PREFIXES = _prefixes
|
||||
TOKENIZER_EXCEPTIONS = _exc
|
|
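Note (not part of the diff): the files above register Luxembourgish ("lb") as a base language. A minimal smoke test of the blank pipeline; the sample sentence is taken from the new examples.py, and no statistical model ships for lb.

import spacy
from spacy.lang.lb.lex_attrs import like_num

nlp = spacy.blank("lb")
doc = nlp("Den Nordwand huet mat aller Force geblosen.")
print([t.text for t in doc])
print(like_num("véier"), like_num("42"))                  # True True
print({"an", "awer", "mat"} <= nlp.Defaults.stop_words)   # True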
@@ -1,10 +1,8 @@
|
|||
# coding: utf8
|
||||
from __future__ import absolute_import, unicode_literals
|
||||
|
||||
import atexit
|
||||
import random
|
||||
import itertools
|
||||
from warnings import warn
|
||||
from spacy.util import minibatch
|
||||
import weakref
|
||||
import functools
|
||||
|
@@ -483,7 +481,7 @@ class Language(object):
|
|||
|
||||
docs (iterable): A batch of `Doc` objects.
|
||||
golds (iterable): A batch of `GoldParse` objects.
|
||||
drop (float): The droput rate.
|
||||
drop (float): The dropout rate.
|
||||
sgd (callable): An optimizer.
|
||||
losses (dict): Dictionary to update with the loss, keyed by component.
|
||||
component_cfg (dict): Config parameters for specific pipeline
|
||||
|
@@ -531,7 +529,7 @@ class Language(object):
|
|||
even if you're updating it with a smaller set of examples.
|
||||
|
||||
docs (iterable): A batch of `Doc` objects.
|
||||
drop (float): The droput rate.
|
||||
drop (float): The dropout rate.
|
||||
sgd (callable): An optimizer.
|
||||
RETURNS (dict): Results from the update.
|
||||
|
||||
|
@@ -753,7 +751,8 @@ class Language(object):
|
|||
use. Experimental.
|
||||
component_cfg (dict): An optional dictionary with extra keyword
|
||||
arguments for specific components.
|
||||
n_process (int): Number of processors to process texts, only supported in Python3. If -1, set `multiprocessing.cpu_count()`.
|
||||
n_process (int): Number of processors to process texts, only supported
|
||||
in Python3. If -1, set `multiprocessing.cpu_count()`.
|
||||
YIELDS (Doc): Documents in the order of the original text.
|
||||
|
||||
DOCS: https://spacy.io/api/language#pipe
|
||||
|
@ -1069,9 +1068,10 @@ def _pipe(docs, proc, kwargs):
|
|||
def _apply_pipes(make_doc, pipes, reciever, sender):
|
||||
"""Worker for Language.pipe
|
||||
|
||||
Args:
|
||||
receiver (multiprocessing.Connection): Pipe to receive text. Usually created by `multiprocessing.Pipe()`
|
||||
sender (multiprocessing.Connection): Pipe to send doc. Usually created by `multiprocessing.Pipe()`
|
||||
receiver (multiprocessing.Connection): Pipe to receive text. Usually
|
||||
created by `multiprocessing.Pipe()`
|
||||
sender (multiprocessing.Connection): Pipe to send doc. Usually created by
|
||||
`multiprocessing.Pipe()`
|
||||
"""
|
||||
while True:
|
||||
texts = reciever.get()
|
||||
|
@ -1100,7 +1100,7 @@ class _Sender:
|
|||
q.put(item)
|
||||
|
||||
def step(self):
"""Tell the sender that one item was consumed.
|
||||
|
||||
Data is sent to the workers after every chunk_size calls."""
|
||||
self.count += 1
|
||||
|
|
|
@ -133,13 +133,15 @@ cdef class Matcher:
|
|||
|
||||
key (unicode): The ID of the match rule.
|
||||
"""
|
||||
key = self._normalize_key(key)
|
||||
self._patterns.pop(key)
|
||||
self._callbacks.pop(key)
|
||||
norm_key = self._normalize_key(key)
|
||||
if not norm_key in self._patterns:
|
||||
raise ValueError(Errors.E175.format(key=key))
|
||||
self._patterns.pop(norm_key)
|
||||
self._callbacks.pop(norm_key)
|
||||
cdef int i = 0
|
||||
while i < self.patterns.size():
|
||||
pattern_key = get_pattern_key(self.patterns.at(i))
|
||||
if pattern_key == key:
|
||||
pattern_key = get_ent_id(self.patterns.at(i))
|
||||
if pattern_key == norm_key:
|
||||
self.patterns.erase(self.patterns.begin()+i)
|
||||
else:
|
||||
i += 1
|
||||
|
@ -293,18 +295,6 @@ cdef find_matches(TokenPatternC** patterns, int n, Doc doc, extensions=None,
|
|||
return output
|
||||
|
||||
|
||||
cdef attr_t get_ent_id(const TokenPatternC* pattern) nogil:
|
||||
# There have been a few bugs here.
|
||||
# The code was originally designed to always have pattern[1].attrs.value
|
||||
# be the ent_id when we get to the end of a pattern. However, Issue #2671
|
||||
# showed this wasn't the case when we had a reject-and-continue before a
|
||||
# match.
|
||||
# The patch to #2671 was wrong though, which came up in #3839.
|
||||
while pattern.attrs.attr != ID:
|
||||
pattern += 1
|
||||
return pattern.attrs.value
|
||||
|
||||
|
||||
cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& matches,
|
||||
char* cached_py_predicates,
|
||||
Token token, const attr_t* extra_attrs, py_predicates) except *:
|
||||
|
@ -533,9 +523,10 @@ cdef char get_is_match(PatternStateC state,
|
|||
if predicate_matches[state.pattern.py_predicates[i]] == -1:
|
||||
return 0
|
||||
spec = state.pattern
|
||||
for attr in spec.attrs[:spec.nr_attr]:
|
||||
if get_token_attr(token, attr.attr) != attr.value:
|
||||
return 0
|
||||
if spec.nr_attr > 0:
|
||||
for attr in spec.attrs[:spec.nr_attr]:
|
||||
if get_token_attr(token, attr.attr) != attr.value:
|
||||
return 0
|
||||
for i in range(spec.nr_extra_attr):
|
||||
if spec.extra_attrs[i].value != extra_attrs[spec.extra_attrs[i].index]:
|
||||
return 0
|
||||
|
@ -543,7 +534,11 @@ cdef char get_is_match(PatternStateC state,
|
|||
|
||||
|
||||
cdef char get_is_final(PatternStateC state) nogil:
|
||||
if state.pattern[1].attrs[0].attr == ID and state.pattern[1].nr_attr == 0:
|
||||
if state.pattern[1].nr_attr == 0 and state.pattern[1].attrs != NULL:
|
||||
id_attr = state.pattern[1].attrs[0]
|
||||
if id_attr.attr != ID:
|
||||
with gil:
|
||||
raise ValueError(Errors.E074.format(attr=ID, bad_attr=id_attr.attr))
|
||||
return 1
|
||||
else:
|
||||
return 0
|
||||
|
@ -558,7 +553,9 @@ cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id, object token_specs)
|
|||
cdef int i, index
|
||||
for i, (quantifier, spec, extensions, predicates) in enumerate(token_specs):
|
||||
pattern[i].quantifier = quantifier
|
||||
pattern[i].attrs = <AttrValueC*>mem.alloc(len(spec), sizeof(AttrValueC))
|
||||
# Ensure attrs refers to a null pointer if nr_attr == 0
|
||||
if len(spec) > 0:
|
||||
pattern[i].attrs = <AttrValueC*>mem.alloc(len(spec), sizeof(AttrValueC))
|
||||
pattern[i].nr_attr = len(spec)
|
||||
for j, (attr, value) in enumerate(spec):
|
||||
pattern[i].attrs[j].attr = attr
|
||||
|
@ -574,6 +571,7 @@ cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id, object token_specs)
|
|||
pattern[i].nr_py = len(predicates)
|
||||
pattern[i].key = hash64(pattern[i].attrs, pattern[i].nr_attr * sizeof(AttrValueC), 0)
|
||||
i = len(token_specs)
|
||||
# Even though here, nr_attr == 0, we're storing the ID value in attrs[0] (bug-prone, thread carefully!)
|
||||
pattern[i].attrs = <AttrValueC*>mem.alloc(2, sizeof(AttrValueC))
|
||||
pattern[i].attrs[0].attr = ID
|
||||
pattern[i].attrs[0].value = entity_id
|
||||
|
@ -583,8 +581,26 @@ cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id, object token_specs)
|
|||
return pattern
|
||||
|
||||
|
||||
cdef attr_t get_pattern_key(const TokenPatternC* pattern) nogil:
|
||||
while pattern.nr_attr != 0 or pattern.nr_extra_attr != 0 or pattern.nr_py != 0:
|
||||
cdef attr_t get_ent_id(const TokenPatternC* pattern) nogil:
|
||||
# There have been a few bugs here. We used to have two functions,
|
||||
# get_ent_id and get_pattern_key that tried to do the same thing. These
|
||||
# are now unified to try to solve the "ghost match" problem.
|
||||
# Below is the previous implementation of get_ent_id and the comment on it,
|
||||
# preserved for reference while we figure out whether the heisenbug in the
|
||||
# matcher is resolved.
|
||||
#
|
||||
#
|
||||
# cdef attr_t get_ent_id(const TokenPatternC* pattern) nogil:
|
||||
# # The code was originally designed to always have pattern[1].attrs.value
|
||||
# # be the ent_id when we get to the end of a pattern. However, Issue #2671
|
||||
# # showed this wasn't the case when we had a reject-and-continue before a
|
||||
# # match.
|
||||
# # The patch to #2671 was wrong though, which came up in #3839.
|
||||
# while pattern.attrs.attr != ID:
|
||||
# pattern += 1
|
||||
# return pattern.attrs.value
|
||||
while pattern.nr_attr != 0 or pattern.nr_extra_attr != 0 or pattern.nr_py != 0 \
|
||||
or pattern.quantifier != ZERO:
|
||||
pattern += 1
|
||||
id_attr = pattern[0].attrs[0]
|
||||
if id_attr.attr != ID:
|
||||
|
@ -642,7 +658,7 @@ def _get_attr_values(spec, string_store):
|
|||
value = string_store.add(value)
|
||||
elif isinstance(value, bool):
|
||||
value = int(value)
|
||||
elif isinstance(value, dict):
|
||||
elif isinstance(value, (dict, int)):
|
||||
continue
|
||||
else:
|
||||
raise ValueError(Errors.E153.format(vtype=type(value).__name__))
|
||||
|
|
|
@ -4,6 +4,7 @@ from cymem.cymem cimport Pool
|
|||
from preshed.maps cimport key_t, MapStruct
|
||||
|
||||
from ..attrs cimport attr_id_t
|
||||
from ..structs cimport SpanC
|
||||
from ..tokens.doc cimport Doc
|
||||
from ..vocab cimport Vocab
|
||||
|
||||
|
@ -18,10 +19,4 @@ cdef class PhraseMatcher:
|
|||
cdef Pool mem
|
||||
cdef key_t _terminal_hash
|
||||
|
||||
cdef void find_matches(self, Doc doc, vector[MatchStruct] *matches) nogil
|
||||
|
||||
|
||||
cdef struct MatchStruct:
|
||||
key_t match_id
|
||||
int start
|
||||
int end
|
||||
cdef void find_matches(self, Doc doc, vector[SpanC] *matches) nogil
|
||||
|
|
|
@ -9,6 +9,7 @@ from preshed.maps cimport map_init, map_set, map_get, map_clear, map_iter
|
|||
from ..attrs cimport ORTH, POS, TAG, DEP, LEMMA
|
||||
from ..structs cimport TokenC
|
||||
from ..tokens.token cimport Token
|
||||
from ..typedefs cimport attr_t
|
||||
|
||||
from ._schemas import TOKEN_PATTERN_SCHEMA
|
||||
from ..errors import Errors, Warnings, deprecation_warning, user_warning
|
||||
|
@ -102,8 +103,10 @@ cdef class PhraseMatcher:
|
|||
cdef vector[MapStruct*] path_nodes
|
||||
cdef vector[key_t] path_keys
|
||||
cdef key_t key_to_remove
|
||||
for keyword in self._docs[key]:
|
||||
for keyword in sorted(self._docs[key], key=lambda x: len(x), reverse=True):
|
||||
current_node = self.c_map
|
||||
path_nodes.clear()
|
||||
path_keys.clear()
|
||||
for token in keyword:
|
||||
result = map_get(current_node, token)
|
||||
if result:
|
||||
|
@ -220,17 +223,17 @@ cdef class PhraseMatcher:
|
|||
# if doc is empty or None just return empty list
|
||||
return matches
|
||||
|
||||
cdef vector[MatchStruct] c_matches
|
||||
cdef vector[SpanC] c_matches
|
||||
self.find_matches(doc, &c_matches)
|
||||
for i in range(c_matches.size()):
|
||||
matches.append((c_matches[i].match_id, c_matches[i].start, c_matches[i].end))
|
||||
matches.append((c_matches[i].label, c_matches[i].start, c_matches[i].end))
|
||||
for i, (ent_id, start, end) in enumerate(matches):
|
||||
on_match = self._callbacks.get(self.vocab.strings[ent_id])
|
||||
if on_match is not None:
|
||||
on_match(self, doc, i, matches)
|
||||
return matches
|
||||
|
||||
cdef void find_matches(self, Doc doc, vector[MatchStruct] *matches) nogil:
|
||||
cdef void find_matches(self, Doc doc, vector[SpanC] *matches) nogil:
|
||||
cdef MapStruct* current_node = self.c_map
|
||||
cdef int start = 0
|
||||
cdef int idx = 0
|
||||
|
@ -238,7 +241,7 @@ cdef class PhraseMatcher:
|
|||
cdef key_t key
|
||||
cdef void* value
|
||||
cdef int i = 0
|
||||
cdef MatchStruct ms
|
||||
cdef SpanC ms
|
||||
cdef void* result
|
||||
while idx < doc.length:
|
||||
start = idx
|
||||
|
@ -253,7 +256,7 @@ cdef class PhraseMatcher:
|
|||
if result:
|
||||
i = 0
|
||||
while map_iter(<MapStruct*>result, &i, &key, &value):
|
||||
ms = make_matchstruct(key, start, idy)
|
||||
ms = make_spanstruct(key, start, idy)
|
||||
matches.push_back(ms)
|
||||
inner_token = Token.get_struct_attr(&doc.c[idy], self.attr)
|
||||
result = map_get(current_node, inner_token)
|
||||
|
@ -268,7 +271,7 @@ cdef class PhraseMatcher:
|
|||
if result:
|
||||
i = 0
|
||||
while map_iter(<MapStruct*>result, &i, &key, &value):
|
||||
ms = make_matchstruct(key, start, idy)
|
||||
ms = make_spanstruct(key, start, idy)
|
||||
matches.push_back(ms)
|
||||
current_node = self.c_map
|
||||
idx += 1
|
||||
|
@ -318,9 +321,9 @@ def unpickle_matcher(vocab, docs, callbacks, attr):
|
|||
return matcher
|
||||
|
||||
|
||||
cdef MatchStruct make_matchstruct(key_t match_id, int start, int end) nogil:
|
||||
cdef MatchStruct ms
|
||||
ms.match_id = match_id
|
||||
ms.start = start
|
||||
ms.end = end
|
||||
return ms
|
||||
cdef SpanC make_spanstruct(attr_t label, int start, int end) nogil:
|
||||
cdef SpanC spanc
|
||||
spanc.label = label
|
||||
spanc.start = start
|
||||
spanc.end = end
|
||||
return spanc
|
||||
|
|
|
@ -183,7 +183,9 @@ class EntityRuler(object):
|
|||
# disable the nlp components after this one in case they hadn't been initialized / deserialised yet
|
||||
try:
|
||||
current_index = self.nlp.pipe_names.index(self.name)
|
||||
subsequent_pipes = [pipe for pipe in self.nlp.pipe_names[current_index + 1:]]
|
||||
subsequent_pipes = [
|
||||
pipe for pipe in self.nlp.pipe_names[current_index + 1 :]
|
||||
]
|
||||
except ValueError:
|
||||
subsequent_pipes = []
|
||||
with self.nlp.disable_pipes(*subsequent_pipes):
|
||||
|
|
|
@ -1195,23 +1195,26 @@ class EntityLinker(Pipe):
|
|||
docs = [docs]
|
||||
golds = [golds]
|
||||
|
||||
context_docs = []
|
||||
sentence_docs = []
|
||||
|
||||
for doc, gold in zip(docs, golds):
|
||||
ents_by_offset = dict()
|
||||
for ent in doc.ents:
|
||||
ents_by_offset["{}_{}".format(ent.start_char, ent.end_char)] = ent
|
||||
ents_by_offset[(ent.start_char, ent.end_char)] = ent
|
||||
|
||||
for entity, kb_dict in gold.links.items():
|
||||
start, end = entity
|
||||
mention = doc.text[start:end]
|
||||
# the gold annotations should link to proper entities - if this fails, the dataset is likely corrupt
|
||||
ent = ents_by_offset[(start, end)]
|
||||
|
||||
for kb_id, value in kb_dict.items():
|
||||
# Currently only training on the positive instances
|
||||
if value:
|
||||
context_docs.append(doc)
|
||||
sentence_docs.append(ent.sent.as_doc())
|
||||
|
||||
context_encodings, bp_context = self.model.begin_update(context_docs, drop=drop)
|
||||
loss, d_scores = self.get_similarity_loss(scores=context_encodings, golds=golds, docs=None)
|
||||
sentence_encodings, bp_context = self.model.begin_update(sentence_docs, drop=drop)
|
||||
loss, d_scores = self.get_similarity_loss(scores=sentence_encodings, golds=golds, docs=None)
|
||||
bp_context(d_scores, sgd=sgd)
|
||||
|
||||
if losses is not None:
|
||||
|
@ -1280,50 +1283,68 @@ class EntityLinker(Pipe):
|
|||
if isinstance(docs, Doc):
|
||||
docs = [docs]
|
||||
|
||||
context_encodings = self.model(docs)
|
||||
xp = get_array_module(context_encodings)
|
||||
|
||||
for i, doc in enumerate(docs):
|
||||
if len(doc) > 0:
|
||||
# currently, the context is the same for each entity in a sentence (should be refined)
|
||||
context_encoding = context_encodings[i]
|
||||
context_enc_t = context_encoding.T
|
||||
norm_1 = xp.linalg.norm(context_enc_t)
|
||||
for ent in doc.ents:
|
||||
entity_count += 1
|
||||
# Looping through each sentence and each entity
|
||||
# This may go wrong if there are entities across sentences - because they might not get a KB ID
|
||||
for sent in doc.ents:
|
||||
sent_doc = sent.as_doc()
|
||||
# currently, the context is the same for each entity in a sentence (should be refined)
|
||||
sentence_encoding = self.model([sent_doc])[0]
|
||||
xp = get_array_module(sentence_encoding)
|
||||
sentence_encoding_t = sentence_encoding.T
|
||||
sentence_norm = xp.linalg.norm(sentence_encoding_t)
|
||||
|
||||
candidates = self.kb.get_candidates(ent.text)
|
||||
if not candidates:
|
||||
final_kb_ids.append(self.NIL) # no prediction possible for this entity
|
||||
final_tensors.append(context_encoding)
|
||||
else:
|
||||
random.shuffle(candidates)
|
||||
for ent in sent_doc.ents:
|
||||
entity_count += 1
|
||||
|
||||
# this will set all prior probabilities to 0 if they should be excluded from the model
|
||||
prior_probs = xp.asarray([c.prior_prob for c in candidates])
|
||||
if not self.cfg.get("incl_prior", True):
|
||||
prior_probs = xp.asarray([0.0 for c in candidates])
|
||||
scores = prior_probs
|
||||
if ent.label_ in self.cfg.get("labels_discard", []):
|
||||
# ignoring this entity - setting to NIL
|
||||
final_kb_ids.append(self.NIL)
|
||||
final_tensors.append(sentence_encoding)
|
||||
|
||||
# add in similarity from the context
|
||||
if self.cfg.get("incl_context", True):
|
||||
entity_encodings = xp.asarray([c.entity_vector for c in candidates])
|
||||
norm_2 = xp.linalg.norm(entity_encodings, axis=1)
|
||||
else:
|
||||
candidates = self.kb.get_candidates(ent.text)
|
||||
if not candidates:
|
||||
# no prediction possible for this entity - setting to NIL
|
||||
final_kb_ids.append(self.NIL)
|
||||
final_tensors.append(sentence_encoding)
|
||||
|
||||
if len(entity_encodings) != len(prior_probs):
|
||||
raise RuntimeError(Errors.E147.format(method="predict", msg="vectors not of equal length"))
|
||||
elif len(candidates) == 1:
|
||||
# shortcut for efficiency reasons: take the 1 candidate
|
||||
|
||||
# cosine similarity
|
||||
sims = xp.dot(entity_encodings, context_enc_t) / (norm_1 * norm_2)
|
||||
if sims.shape != prior_probs.shape:
|
||||
raise ValueError(Errors.E161)
|
||||
scores = prior_probs + sims - (prior_probs*sims)
|
||||
# TODO: thresholding
|
||||
final_kb_ids.append(candidates[0].entity_)
|
||||
final_tensors.append(sentence_encoding)
|
||||
|
||||
# TODO: thresholding
|
||||
best_index = scores.argmax()
|
||||
best_candidate = candidates[best_index]
|
||||
final_kb_ids.append(best_candidate.entity_)
|
||||
final_tensors.append(context_encoding)
|
||||
else:
|
||||
random.shuffle(candidates)
|
||||
|
||||
# this will set all prior probabilities to 0 if they should be excluded from the model
|
||||
prior_probs = xp.asarray([c.prior_prob for c in candidates])
|
||||
if not self.cfg.get("incl_prior", True):
|
||||
prior_probs = xp.asarray([0.0 for c in candidates])
|
||||
scores = prior_probs
|
||||
|
||||
# add in similarity from the context
|
||||
if self.cfg.get("incl_context", True):
|
||||
entity_encodings = xp.asarray([c.entity_vector for c in candidates])
|
||||
entity_norm = xp.linalg.norm(entity_encodings, axis=1)
|
||||
|
||||
if len(entity_encodings) != len(prior_probs):
|
||||
raise RuntimeError(Errors.E147.format(method="predict", msg="vectors not of equal length"))
|
||||
|
||||
# cosine similarity
|
||||
sims = xp.dot(entity_encodings, sentence_encoding_t) / (sentence_norm * entity_norm)
|
||||
if sims.shape != prior_probs.shape:
|
||||
raise ValueError(Errors.E161)
|
||||
scores = prior_probs + sims - (prior_probs*sims)
|
||||
|
||||
# TODO: thresholding
|
||||
best_index = scores.argmax()
|
||||
best_candidate = candidates[best_index]
|
||||
final_kb_ids.append(best_candidate.entity_)
|
||||
final_tensors.append(sentence_encoding)
|
||||
|
||||
if not (len(final_tensors) == len(final_kb_ids) == entity_count):
|
||||
raise RuntimeError(Errors.E147.format(method="predict", msg="result variables not of equal length"))
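For reference, the candidate score computed above combines the prior probability and the sentence-level cosine similarity as a probabilistic OR: score = prior + sim - prior * sim. A minimal numpy sketch with made-up values (illustration only, not the spaCy API):

import numpy

prior_probs = numpy.asarray([0.7, 0.2, 0.1])  # hypothetical priors for three candidates
sims = numpy.asarray([0.1, 0.9, 0.3])         # hypothetical cosine similarities
scores = prior_probs + sims - (prior_probs * sims)
print(scores)           # approximately [0.73 0.92 0.37]
print(scores.argmax())  # 1 -> the second candidate wins despite its lower prior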
|
||||
|
|
|
@ -219,7 +219,9 @@ class Scorer(object):
|
|||
DOCS: https://spacy.io/api/scorer#score
|
||||
"""
|
||||
if len(doc) != len(gold):
|
||||
gold = GoldParse.from_annot_tuples(doc, zip(*gold.orig_annot))
|
||||
gold = GoldParse.from_annot_tuples(
|
||||
doc, tuple(zip(*gold.orig_annot)) + (gold.cats,)
|
||||
)
|
||||
gold_deps = set()
|
||||
gold_tags = set()
|
||||
gold_ents = set(tags_to_entities([annot[-1] for annot in gold.orig_annot]))
|
||||
|
|
|
@ -47,11 +47,14 @@ cdef struct SerializedLexemeC:
|
|||
# + sizeof(float) # l2_norm
|
||||
|
||||
|
||||
cdef struct Entity:
|
||||
cdef struct SpanC:
|
||||
hash_t id
|
||||
int start
|
||||
int end
|
||||
int start_char
|
||||
int end_char
|
||||
attr_t label
|
||||
attr_t kb_id
|
||||
|
||||
|
||||
cdef struct TokenC:
|
||||
|
|
|
@ -7,7 +7,7 @@ from cpython.exc cimport PyErr_CheckSignals, PyErr_SetFromErrno
|
|||
from murmurhash.mrmr cimport hash64
|
||||
|
||||
from ..vocab cimport EMPTY_LEXEME
|
||||
from ..structs cimport TokenC, Entity
|
||||
from ..structs cimport TokenC, SpanC
|
||||
from ..lexeme cimport Lexeme
|
||||
from ..symbols cimport punct
|
||||
from ..attrs cimport IS_SPACE
|
||||
|
@ -40,7 +40,7 @@ cdef cppclass StateC:
|
|||
int* _buffer
|
||||
bint* shifted
|
||||
TokenC* _sent
|
||||
Entity* _ents
|
||||
SpanC* _ents
|
||||
TokenC _empty_token
|
||||
RingBufferC _hist
|
||||
int length
|
||||
|
@ -56,7 +56,7 @@ cdef cppclass StateC:
|
|||
this._stack = <int*>calloc(length + (PADDING * 2), sizeof(int))
|
||||
this.shifted = <bint*>calloc(length + (PADDING * 2), sizeof(bint))
|
||||
this._sent = <TokenC*>calloc(length + (PADDING * 2), sizeof(TokenC))
|
||||
this._ents = <Entity*>calloc(length + (PADDING * 2), sizeof(Entity))
|
||||
this._ents = <SpanC*>calloc(length + (PADDING * 2), sizeof(SpanC))
|
||||
if not (this._buffer and this._stack and this.shifted
|
||||
and this._sent and this._ents):
|
||||
with gil:
|
||||
|
@ -406,7 +406,7 @@ cdef cppclass StateC:
|
|||
memcpy(this._sent, src._sent, this.length * sizeof(TokenC))
|
||||
memcpy(this._stack, src._stack, this.length * sizeof(int))
|
||||
memcpy(this._buffer, src._buffer, this.length * sizeof(int))
|
||||
memcpy(this._ents, src._ents, this.length * sizeof(Entity))
|
||||
memcpy(this._ents, src._ents, this.length * sizeof(SpanC))
|
||||
memcpy(this.shifted, src.shifted, this.length * sizeof(this.shifted[0]))
|
||||
this._b_i = src._b_i
|
||||
this._s_i = src._s_i
|
||||
|
|
|
@ -3,7 +3,7 @@ from libc.string cimport memcpy, memset
|
|||
from cymem.cymem cimport Pool
|
||||
cimport cython
|
||||
|
||||
from ..structs cimport TokenC, Entity
|
||||
from ..structs cimport TokenC, SpanC
|
||||
from ..typedefs cimport attr_t
|
||||
|
||||
from ..vocab cimport EMPTY_LEXEME
|
||||
|
|
|
@ -135,6 +135,11 @@ def ko_tokenizer():
|
|||
    return get_lang_class("ko").Defaults.create_tokenizer()


@pytest.fixture(scope="session")
def lb_tokenizer():
    return get_lang_class("lb").Defaults.create_tokenizer()


@pytest.fixture(scope="session")
def lt_tokenizer():
    return get_lang_class("lt").Defaults.create_tokenizer()
|
||||
|
|
|
@ -253,3 +253,11 @@ def test_filter_spans(doc):
|
|||
assert len(filtered[1]) == 5
|
||||
assert filtered[0].start == 1 and filtered[0].end == 4
|
||||
assert filtered[1].start == 5 and filtered[1].end == 10
|
||||
# Test filtering overlaps with earlier preference for identical length
|
||||
spans = [doc[1:4], doc[2:5], doc[5:10], doc[7:9], doc[1:4]]
|
||||
filtered = filter_spans(spans)
|
||||
assert len(filtered) == 2
|
||||
assert len(filtered[0]) == 3
|
||||
assert len(filtered[1]) == 5
|
||||
assert filtered[0].start == 1 and filtered[0].end == 4
|
||||
assert filtered[1].start == 5 and filtered[1].end == 10
|
||||
|
|
spacy/tests/lang/lb/test_exceptions.py (new file, 10 lines)
@ -0,0 +1,10 @@
# coding: utf-8
from __future__ import unicode_literals

import pytest


@pytest.mark.parametrize("text", ["z.B.", "Jan."])
def test_lb_tokenizer_handles_abbr(lb_tokenizer, text):
    tokens = lb_tokenizer(text)
    assert len(tokens) == 1
|
spacy/tests/lang/lb/test_prefix_suffix_infix.py (new file, 22 lines)
@ -0,0 +1,22 @@
# coding: utf-8
from __future__ import unicode_literals

import pytest


@pytest.mark.parametrize("text,length", [("z.B.", 1), ("zb.", 2), ("(z.B.", 2)])
def test_lb_tokenizer_splits_prefix_interact(lb_tokenizer, text, length):
    tokens = lb_tokenizer(text)
    assert len(tokens) == length


@pytest.mark.parametrize("text", ["z.B.)"])
def test_lb_tokenizer_splits_suffix_interact(lb_tokenizer, text):
    tokens = lb_tokenizer(text)
    assert len(tokens) == 2


@pytest.mark.parametrize("text", ["(z.B.)"])
def test_lb_tokenizer_splits_even_wrap_interact(lb_tokenizer, text):
    tokens = lb_tokenizer(text)
    assert len(tokens) == 3
|
spacy/tests/lang/lb/test_text.py (new file, 31 lines)
@ -0,0 +1,31 @@
# coding: utf-8
from __future__ import unicode_literals

import pytest


def test_lb_tokenizer_handles_long_text(lb_tokenizer):
    text = """Den Nordwand an d'Sonn

An der Zäit hunn sech den Nordwand an d’Sonn gestridden, wie vun hinnen zwee wuel méi staark wier, wéi e Wanderer, deen an ee waarme Mantel agepak war, iwwert de Wee koum. Si goufen sech eens, dass deejéinege fir de Stäerkste gëlle sollt, deen de Wanderer forcéiere géif, säi Mantel auszedoen.",

Den Nordwand huet mat aller Force geblosen, awer wat e méi geblosen huet, wat de Wanderer sech méi a säi Mantel agewéckelt huet. Um Enn huet den Nordwand säi Kampf opginn.

Dunn huet d’Sonn d’Loft mat hire frëndleche Strale gewiermt, a schonn no kuerzer Zäit huet de Wanderer säi Mantel ausgedoen.

Do huet den Nordwand missen zouginn, dass d’Sonn vun hinnen zwee de Stäerkste wier."""

    tokens = lb_tokenizer(text)
    assert len(tokens) == 143


@pytest.mark.parametrize(
    "text,length",
    [
        ("»Wat ass mat mir geschitt?«, huet hie geduecht.", 13),
        ("“Dëst fréi Opstoen”, denkt hien, “mécht ee ganz duercherneen. ", 15),
    ],
)
def test_lb_tokenizer_handles_examples(lb_tokenizer, text, length):
    tokens = lb_tokenizer(text)
    assert len(tokens) == length
|
|
@ -3,6 +3,8 @@ from __future__ import unicode_literals
|
|||
|
||||
import pytest
|
||||
import re
|
||||
|
||||
from spacy.lang.en import English
|
||||
from spacy.matcher import Matcher
|
||||
from spacy.tokens import Doc, Span
|
||||
|
||||
|
@ -143,3 +145,29 @@ def test_matcher_sets_return_correct_tokens(en_vocab):
|
|||
matches = matcher(doc)
|
||||
texts = [Span(doc, s, e, label=L).text for L, s, e in matches]
|
||||
assert texts == ["zero", "one", "two"]
|
||||
|
||||
|
||||
def test_matcher_remove():
|
||||
nlp = English()
|
||||
matcher = Matcher(nlp.vocab)
|
||||
text = "This is a test case."
|
||||
|
||||
pattern = [{"ORTH": "test"}, {"OP": "?"}]
|
||||
assert len(matcher) == 0
|
||||
matcher.add("Rule", None, pattern)
|
||||
assert "Rule" in matcher
|
||||
|
||||
# should give two matches
|
||||
results1 = matcher(nlp(text))
|
||||
assert len(results1) == 2
|
||||
|
||||
# removing once should work
|
||||
matcher.remove("Rule")
|
||||
|
||||
# should not return any maches anymore
|
||||
results2 = matcher(nlp(text))
|
||||
assert len(results2) == 0
|
||||
|
||||
# removing again should throw an error
|
||||
with pytest.raises(ValueError):
|
||||
matcher.remove("Rule")
|
||||
|
|
|
@ -12,24 +12,25 @@ from spacy.util import get_json_validator, validate_json
|
|||
TEST_PATTERNS = [
|
||||
# Bad patterns flagged in all cases
|
||||
([{"XX": "foo"}], 1, 1),
|
||||
([{"LENGTH": "2", "TEXT": 2}, {"LOWER": "test"}], 2, 1),
|
||||
([{"IS_ALPHA": {"==": True}}, {"LIKE_NUM": None}], 2, 1),
|
||||
([{"IS_PUNCT": True, "OP": "$"}], 1, 1),
|
||||
([{"IS_DIGIT": -1}], 1, 1),
|
||||
([{"ORTH": -1}], 1, 1),
|
||||
([{"_": "foo"}], 1, 1),
|
||||
('[{"TEXT": "foo"}, {"LOWER": "bar"}]', 1, 1),
|
||||
([1, 2, 3], 3, 1),
|
||||
# Bad patterns flagged outside of Matcher
|
||||
([{"_": {"foo": "bar", "baz": {"IN": "foo"}}}], 1, 0),
|
||||
# Bad patterns not flagged with minimal checks
|
||||
([{"LENGTH": "2", "TEXT": 2}, {"LOWER": "test"}], 2, 0),
|
||||
([{"LENGTH": {"IN": [1, 2, "3"]}}, {"POS": {"IN": "VERB"}}], 2, 0),
|
||||
([{"LENGTH": {"VALUE": 5}}], 1, 0),
|
||||
([{"TEXT": {"VALUE": "foo"}}], 1, 0),
|
||||
([{"IS_DIGIT": -1}], 1, 0),
|
||||
([{"ORTH": -1}], 1, 0),
|
||||
# Good patterns
|
||||
([{"TEXT": "foo"}, {"LOWER": "bar"}], 0, 0),
|
||||
([{"LEMMA": {"IN": ["love", "like"]}}, {"POS": "DET", "OP": "?"}], 0, 0),
|
||||
([{"LIKE_NUM": True, "LENGTH": {">=": 5}}], 0, 0),
|
||||
([{"LENGTH": 2}], 0, 0),
|
||||
([{"LOWER": {"REGEX": "^X", "NOT_IN": ["XXX", "XY"]}}], 0, 0),
|
||||
([{"NORM": "a"}, {"POS": {"IN": ["NOUN"]}}], 0, 0),
|
||||
([{"_": {"foo": {"NOT_IN": ["bar", "baz"]}, "a": 5, "b": {">": 10}}}], 0, 0),
|
||||
|
|
|
@ -226,3 +226,13 @@ def test_phrase_matcher_callback(en_vocab):
|
|||
matcher.add("COMPANY", mock, pattern)
|
||||
matches = matcher(doc)
|
||||
mock.assert_called_once_with(matcher, doc, 0, matches)
|
||||
|
||||
|
||||
def test_phrase_matcher_remove_overlapping_patterns(en_vocab):
|
||||
matcher = PhraseMatcher(en_vocab)
|
||||
pattern1 = Doc(en_vocab, words=["this"])
|
||||
pattern2 = Doc(en_vocab, words=["this", "is"])
|
||||
pattern3 = Doc(en_vocab, words=["this", "is", "a"])
|
||||
pattern4 = Doc(en_vocab, words=["this", "is", "a", "word"])
|
||||
matcher.add("THIS", None, pattern1, pattern2, pattern3, pattern4)
|
||||
matcher.remove("THIS")
|
||||
|
|
|
@ -103,7 +103,7 @@ def test_oracle_moves_missing_B(en_vocab):
|
|||
moves.add_action(move_types.index("L"), label)
|
||||
moves.add_action(move_types.index("U"), label)
|
||||
moves.preprocess_gold(gold)
|
||||
seq = moves.get_oracle_sequence(doc, gold)
|
||||
moves.get_oracle_sequence(doc, gold)
|
||||
|
||||
|
||||
def test_oracle_moves_whitespace(en_vocab):
|
||||
|
|
|
@ -131,6 +131,53 @@ def test_candidate_generation(nlp):
|
|||
assert_almost_equal(mykb.get_candidates("adam")[0].prior_prob, 0.9)
|
||||
|
||||
|
||||
def test_append_alias(nlp):
|
||||
"""Test that we can append additional alias-entity pairs"""
|
||||
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
|
||||
|
||||
# adding entities
|
||||
mykb.add_entity(entity="Q1", freq=27, entity_vector=[1])
|
||||
mykb.add_entity(entity="Q2", freq=12, entity_vector=[2])
|
||||
mykb.add_entity(entity="Q3", freq=5, entity_vector=[3])
|
||||
|
||||
# adding aliases
|
||||
mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.4, 0.1])
|
||||
mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9])
|
||||
|
||||
# test the size of the relevant candidates
|
||||
assert len(mykb.get_candidates("douglas")) == 2
|
||||
|
||||
# append an alias
|
||||
mykb.append_alias(alias="douglas", entity="Q1", prior_prob=0.2)
|
||||
|
||||
# test the size of the relevant candidates has been incremented
|
||||
assert len(mykb.get_candidates("douglas")) == 3
|
||||
|
||||
# append the same alias-entity pair again should not work (will throw a warning)
|
||||
mykb.append_alias(alias="douglas", entity="Q1", prior_prob=0.3)
|
||||
|
||||
# test the size of the relevant candidates remained unchanged
|
||||
assert len(mykb.get_candidates("douglas")) == 3
|
||||
|
||||
|
||||
def test_append_invalid_alias(nlp):
|
||||
"""Test that append an alias will throw an error if prior probs are exceeding 1"""
|
||||
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
|
||||
|
||||
# adding entities
|
||||
mykb.add_entity(entity="Q1", freq=27, entity_vector=[1])
|
||||
mykb.add_entity(entity="Q2", freq=12, entity_vector=[2])
|
||||
mykb.add_entity(entity="Q3", freq=5, entity_vector=[3])
|
||||
|
||||
# adding aliases
|
||||
mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.1])
|
||||
mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9])
|
||||
|
||||
# append an alias - should fail because the entities and probabilities vectors are not of equal length
|
||||
with pytest.raises(ValueError):
|
||||
mykb.append_alias(alias="douglas", entity="Q1", prior_prob=0.2)
|
||||
|
||||
|
||||
def test_preserving_links_asdoc(nlp):
|
||||
"""Test that Span.as_doc preserves the existing entity links"""
|
||||
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
|
||||
|
|
|
@ -430,7 +430,7 @@ def test_issue957(en_tokenizer):
|
|||
def test_issue999(train_data):
|
||||
"""Test that adding entities and resuming training works passably OK.
|
||||
There are two issues here:
|
||||
1) We have to read labels. This isn't very nice.
|
||||
1) We have to re-add labels. This isn't very nice.
|
||||
2) There's no way to set the learning rate for the weight update, so we
|
||||
end up out-of-scale, causing it to learn too fast.
|
||||
"""
|
||||
|
|
|
@ -323,7 +323,7 @@ def test_issue3456():
|
|||
nlp = English()
|
||||
nlp.add_pipe(nlp.create_pipe("tagger"))
|
||||
nlp.begin_training()
|
||||
list(nlp.pipe(['hi', '']))
|
||||
list(nlp.pipe(["hi", ""]))
|
||||
|
||||
|
||||
def test_issue3468():
|
||||
|
|
|
@ -76,7 +76,6 @@ def test_issue4042_bug2():
|
|||
output_dir.mkdir()
|
||||
ner1.to_disk(output_dir)
|
||||
|
||||
nlp2 = English(vocab)
|
||||
ner2 = EntityRecognizer(vocab)
|
||||
ner2.from_disk(output_dir)
|
||||
assert len(ner2.labels) == 2
|
||||
|
|
|
@ -1,13 +1,8 @@
|
|||
# coding: utf8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import pytest
|
||||
|
||||
import spacy
|
||||
|
||||
from spacy.lang.en import English
|
||||
from spacy.pipeline import EntityRuler
|
||||
from spacy.tokens import Span
|
||||
|
||||
|
||||
def test_issue4267():
|
||||
|
|
|
@ -6,6 +6,6 @@ from spacy.tokens import DocBin
|
|||
|
||||
def test_issue4367():
|
||||
"""Test that docbin init goes well"""
|
||||
doc_bin_1 = DocBin()
|
||||
doc_bin_2 = DocBin(attrs=["LEMMA"])
|
||||
doc_bin_3 = DocBin(attrs=["LEMMA", "ENT_IOB", "ENT_TYPE"])
|
||||
DocBin()
|
||||
DocBin(attrs=["LEMMA"])
|
||||
DocBin(attrs=["LEMMA", "ENT_IOB", "ENT_TYPE"])
|
||||
|
|
|
@ -74,4 +74,4 @@ def test_serialize_doc_bin():
|
|||
# Deserialize later, e.g. in a new process
|
||||
nlp = spacy.blank("en")
|
||||
doc_bin = DocBin().from_bytes(bytes_data)
|
||||
docs = list(doc_bin.get_docs(nlp.vocab))
|
||||
list(doc_bin.get_docs(nlp.vocab))
|
||||
|
|
|
@ -48,8 +48,13 @@ URLS_SHOULD_MATCH = [
|
|||
"http://a.b--c.de/", # this is a legit domain name see: https://gist.github.com/dperini/729294 comment on 9/9/2014
|
||||
"ssh://login@server.com:12345/repository.git",
|
||||
"svn+ssh://user@ssh.yourdomain.com/path",
|
||||
pytest.param("chrome://extensions/?id=mhjfbmdgcfjbbpaeojofohoefgiehjai", marks=pytest.mark.xfail()),
|
||||
pytest.param("chrome-extension://mhjfbmdgcfjbbpaeojofohoefgiehjai", marks=pytest.mark.xfail()),
|
||||
pytest.param(
|
||||
"chrome://extensions/?id=mhjfbmdgcfjbbpaeojofohoefgiehjai",
|
||||
marks=pytest.mark.xfail(),
|
||||
),
|
||||
pytest.param(
|
||||
"chrome-extension://mhjfbmdgcfjbbpaeojofohoefgiehjai", marks=pytest.mark.xfail()
|
||||
),
|
||||
pytest.param("http://foo.com/blah_blah_(wikipedia)", marks=pytest.mark.xfail()),
|
||||
pytest.param(
|
||||
"http://foo.com/blah_blah_(wikipedia)_(again)", marks=pytest.mark.xfail()
|
||||
|
|
|
@ -51,6 +51,14 @@ def data():
|
|||
return numpy.asarray([[0.0, 1.0, 2.0], [3.0, -2.0, 4.0]], dtype="f")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def most_similar_vectors_data():
|
||||
return numpy.asarray(
|
||||
[[0.0, 1.0, 2.0], [1.0, -2.0, 4.0], [1.0, 1.0, -1.0], [2.0, 3.0, 1.0]],
|
||||
dtype="f",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def resize_data():
|
||||
return numpy.asarray([[0.0, 1.0], [2.0, 3.0]], dtype="f")
|
||||
|
@ -127,6 +135,12 @@ def test_set_vector(strings, data):
|
|||
assert list(v[strings[0]]) != list(orig[0])
|
||||
|
||||
|
||||
def test_vectors_most_similar(most_similar_vectors_data):
|
||||
v = Vectors(data=most_similar_vectors_data)
|
||||
_, best_rows, _ = v.most_similar(v.data, batch_size=2, n=2, sort=True)
|
||||
assert all(row[0] == i for i, row in enumerate(best_rows))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("text", ["apple and orange"])
|
||||
def test_vectors_token_vector(tokenizer_v, vectors, text):
|
||||
doc = tokenizer_v(text)
|
||||
|
@ -284,7 +298,7 @@ def test_vocab_prune_vectors():
|
|||
vocab.set_vector("dog", data[1])
|
||||
vocab.set_vector("kitten", data[2])
|
||||
|
||||
remap = vocab.prune_vectors(2)
|
||||
remap = vocab.prune_vectors(2, batch_size=2)
|
||||
assert list(remap.keys()) == ["kitten"]
|
||||
neighbour, similarity = list(remap.values())[0]
|
||||
assert neighbour == "cat", remap
|
||||
|
|
|
@ -666,7 +666,7 @@ def filter_spans(spans):
|
|||
spans (iterable): The spans to filter.
|
||||
RETURNS (list): The filtered spans.
|
||||
"""
|
||||
get_sort_key = lambda span: (span.end - span.start, span.start)
|
||||
get_sort_key = lambda span: (span.end - span.start, -span.start)
|
||||
sorted_spans = sorted(spans, key=get_sort_key, reverse=True)
|
||||
result = []
|
||||
seen_tokens = set()
|
||||
|
|
|
@ -336,8 +336,8 @@ cdef class Vectors:
|
|||
best_rows[i:i+batch_size] = xp.argpartition(sims, -n, axis=1)[:,-n:]
|
||||
scores[i:i+batch_size] = xp.partition(sims, -n, axis=1)[:,-n:]
|
||||
|
||||
if sort:
|
||||
sorted_index = xp.arange(scores.shape[0])[:,None],xp.argsort(scores[i:i+batch_size], axis=1)[:,::-1]
|
||||
if sort and n >= 2:
|
||||
sorted_index = xp.arange(scores.shape[0])[:,None][i:i+batch_size],xp.argsort(scores[i:i+batch_size], axis=1)[:,::-1]
|
||||
scores[i:i+batch_size] = scores[sorted_index]
|
||||
best_rows[i:i+batch_size] = best_rows[sorted_index]
|
||||
|
||||
|
|
|
@ -62,7 +62,7 @@ Whether the provided syntactic annotations form a projective dependency tree.
|
|||
|
||||
Convert a list of Doc objects into the
|
||||
[JSON-serializable format](/api/annotation#json-input) used by the
|
||||
[`spacy train`](/api/cli#train) command.
|
||||
[`spacy train`](/api/cli#train) command. Each input doc will be treated as a 'paragraph' in the output doc.
|
||||
|
||||
> #### Example
|
||||
>
|
||||
|
@ -77,7 +77,7 @@ Convert a list of Doc objects into the
|
|||
| ----------- | ---------------- | ------------------------------------------ |
|
||||
| `docs` | iterable / `Doc` | The `Doc` object(s) to convert. |
|
||||
| `id` | int | ID to assign to the JSON. Defaults to `0`. |
|
||||
| **RETURNS** | list | The data in spaCy's JSON format. |
|
||||
| **RETURNS** | dict | The data in spaCy's JSON format. |
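A hedged usage sketch for this signature (the sentencizer is added only because docs_to_json iterates over doc.sents; the key names follow from the description above):

import spacy
from spacy.gold import docs_to_json

nlp = spacy.blank("en")
nlp.add_pipe(nlp.create_pipe("sentencizer"))
doc = nlp("This is a sentence. This is another one.")
json_data = docs_to_json([doc], id=0)
# one dict for all input docs; each doc becomes one 'paragraph'
print(json_data["id"], len(json_data["paragraphs"]))  # 0 1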
|
||||
|
||||
### gold.align {#align tag="function"}
|
||||
|
||||
|
|
|
@ -54,7 +54,7 @@ Lemmatize a string.
|
|||
> ```python
|
||||
> from spacy.lemmatizer import Lemmatizer
|
||||
> from spacy.lookups import Lookups
|
||||
> lookups = Loookups()
|
||||
> lookups = Lookups()
|
||||
> lookups.add_table("lemma_rules", {"noun": [["s", ""]]})
|
||||
> lemmatizer = Lemmatizer(lookups)
|
||||
> lemmas = lemmatizer("ducks", "NOUN")
|
||||
|
|