Mirror of https://github.com/explosion/spaCy.git, synced 2024-11-14 21:57:15 +03:00
Merge branch 'master' into spacy.io
This commit is contained in:
commit 0b2df3b879
106
.github/contributors/PeterGilles.md
vendored
Normal file
|
@ -0,0 +1,106 @@
|
||||||
|
# spaCy contributor agreement
|
||||||
|
|
||||||
|
This spaCy Contributor Agreement (**"SCA"**) is based on the
|
||||||
|
[Oracle Contributor Agreement](http://www.oracle.com/technetwork/oca-405177.pdf).
|
||||||
|
The SCA applies to any contribution that you make to any product or project
|
||||||
|
managed by us (the **"project"**), and sets out the intellectual property rights
|
||||||
|
you grant to us in the contributed materials. The term **"us"** shall mean
|
||||||
|
[ExplosionAI GmbH](https://explosion.ai/legal). The term
|
||||||
|
**"you"** shall mean the person or entity identified below.
|
||||||
|
|
||||||
|
If you agree to be bound by these terms, fill in the information requested
|
||||||
|
below and include the filled-in version with your first pull request, under the
|
||||||
|
folder [`.github/contributors/`](/.github/contributors/). The name of the file
|
||||||
|
should be your GitHub username, with the extension `.md`. For example, the user
|
||||||
|
example_user would create the file `.github/contributors/example_user.md`.
|
||||||
|
|
||||||
|
Read this agreement carefully before signing. These terms and conditions
|
||||||
|
constitute a binding legal agreement.
|
||||||
|
|
||||||
|
## Contributor Agreement
|
||||||
|
|
||||||
|
1. The term "contribution" or "contributed materials" means any source code,
|
||||||
|
object code, patch, tool, sample, graphic, specification, manual,
|
||||||
|
documentation, or any other material posted or submitted by you to the project.
|
||||||
|
|
||||||
|
2. With respect to any worldwide copyrights, or copyright applications and
|
||||||
|
registrations, in your contribution:
|
||||||
|
|
||||||
|
* you hereby assign to us joint ownership, and to the extent that such
|
||||||
|
assignment is or becomes invalid, ineffective or unenforceable, you hereby
|
||||||
|
grant to us a perpetual, irrevocable, non-exclusive, worldwide, no-charge,
|
||||||
|
royalty-free, unrestricted license to exercise all rights under those
|
||||||
|
copyrights. This includes, at our option, the right to sublicense these same
|
||||||
|
rights to third parties through multiple levels of sublicensees or other
|
||||||
|
licensing arrangements;
|
||||||
|
|
||||||
|
* you agree that each of us can do all things in relation to your
|
||||||
|
contribution as if each of us were the sole owners, and if one of us makes
|
||||||
|
a derivative work of your contribution, the one who makes the derivative
|
||||||
|
work (or has it made) will be the sole owner of that derivative work;
|
||||||
|
|
||||||
|
* you agree that you will not assert any moral rights in your contribution
|
||||||
|
against us, our licensees or transferees;
|
||||||
|
|
||||||
|
* you agree that we may register a copyright in your contribution and
|
||||||
|
exercise all ownership rights associated with it; and
|
||||||
|
|
||||||
|
* you agree that neither of us has any duty to consult with, obtain the
|
||||||
|
consent of, pay or render an accounting to the other for any use or
|
||||||
|
distribution of your contribution.
|
||||||
|
|
||||||
|
3. With respect to any patents you own, or that you can license without payment
|
||||||
|
to any third party, you hereby grant to us a perpetual, irrevocable,
|
||||||
|
non-exclusive, worldwide, no-charge, royalty-free license to:
|
||||||
|
|
||||||
|
* make, have made, use, sell, offer to sell, import, and otherwise transfer
|
||||||
|
your contribution in whole or in part, alone or in combination with or
|
||||||
|
included in any product, work or materials arising out of the project to
|
||||||
|
which your contribution was submitted, and
|
||||||
|
|
||||||
|
* at our option, to sublicense these same rights to third parties through
|
||||||
|
multiple levels of sublicensees or other licensing arrangements.
|
||||||
|
|
||||||
|
4. Except as set out above, you keep all right, title, and interest in your
|
||||||
|
contribution. The rights that you grant to us under these terms are effective
|
||||||
|
on the date you first submitted a contribution to us, even if your submission
|
||||||
|
took place before the date you sign these terms.
|
||||||
|
|
||||||
|
5. You covenant, represent, warrant and agree that:
|
||||||
|
|
||||||
|
* Each contribution that you submit is and shall be an original work of
|
||||||
|
authorship and you can legally grant the rights set out in this SCA;
|
||||||
|
|
||||||
|
* to the best of your knowledge, each contribution will not violate any
|
||||||
|
third party's copyrights, trademarks, patents, or other intellectual
|
||||||
|
property rights; and
|
||||||
|
|
||||||
|
* each contribution shall be in compliance with U.S. export control laws and
|
||||||
|
other applicable export and import laws. You agree to notify us if you
|
||||||
|
become aware of any circumstance which would make any of the foregoing
|
||||||
|
representations inaccurate in any respect. We may publicly disclose your
|
||||||
|
participation in the project, including the fact that you have signed the SCA.
|
||||||
|
|
||||||
|
6. This SCA is governed by the laws of the State of California and applicable
|
||||||
|
U.S. Federal law. Any choice of law rules will not apply.
|
||||||
|
|
||||||
|
7. Please place an “x” on one of the applicable statements below. Please do NOT
|
||||||
|
mark both statements:
|
||||||
|
|
||||||
|
* [X] I am signing on behalf of myself as an individual and no other person
|
||||||
|
or entity, including my employer, has or will have rights with respect to my
|
||||||
|
contributions.
|
||||||
|
|
||||||
|
* [ ] I am signing on behalf of my employer or a legal entity and I have the
|
||||||
|
actual authority to contractually bind that entity.
|
||||||
|
|
||||||
|
## Contributor Details
|
||||||
|
|
||||||
|
| Field | Entry |
|
||||||
|
|------------------------------- | -------------------- |
|
||||||
|
| Name | Peter Gilles |
|
||||||
|
| Company name (if applicable) | |
|
||||||
|
| Title or role (if applicable) | |
|
||||||
|
| Date | 10.10. |
|
||||||
|
| GitHub username | Peter Gilles |
|
||||||
|
| Website (optional) | |
|
|
@ -197,7 +197,7 @@ path to the model data directory.
|
||||||
```python
|
```python
|
||||||
import spacy
|
import spacy
|
||||||
nlp = spacy.load("en_core_web_sm")
|
nlp = spacy.load("en_core_web_sm")
|
||||||
doc = nlp(u"This is a sentence.")
|
doc = nlp("This is a sentence.")
|
||||||
```
|
```
|
||||||
|
|
||||||
You can also `import` a model directly via its full name and then call its
|
You can also `import` a model directly via its full name and then call its
|
||||||
|
@ -208,7 +208,7 @@ import spacy
|
||||||
import en_core_web_sm
|
import en_core_web_sm
|
||||||
|
|
||||||
nlp = en_core_web_sm.load()
|
nlp = en_core_web_sm.load()
|
||||||
doc = nlp(u"This is a sentence.")
|
doc = nlp("This is a sentence.")
|
||||||
```
|
```
|
||||||
|
|
||||||
📖 **For more info and examples, check out the
|
📖 **For more info and examples, check out the
|
||||||
|
|
34
bin/wiki_entity_linking/README.md
Normal file
|
@ -0,0 +1,34 @@
## Entity Linking with Wikipedia and Wikidata

### Step 1: Create a Knowledge Base (KB) and training data

Run `wikipedia_pretrain_kb.py`
* This takes as input the locations of a **Wikipedia and a Wikidata dump**, and produces a **KB directory** + **training file**
  * Wikidata: get `latest-all.json.bz2` from https://dumps.wikimedia.org/wikidatawiki/entities/
  * Wikipedia: get `enwiki-latest-pages-articles-multistream.xml.bz2` from https://dumps.wikimedia.org/enwiki/latest/ (or the equivalent dump for any other language)
* You can set the filtering parameters for KB construction:
  * `max_per_alias`: (max) number of candidate entities in the KB per alias/synonym
  * `min_freq`: threshold of number of times an entity should occur in the corpus to be included in the KB
  * `min_pair`: threshold of number of times an entity+alias combination should occur in the corpus to be included in the KB
* Further parameters to set:
  * `descriptions_from_wikipedia`: whether to parse descriptions from Wikipedia (`True`) or Wikidata (`False`)
  * `entity_vector_length`: length of the pre-trained entity description vectors
  * `lang`: language for which to fetch Wikidata information (as the dump contains all languages)

Quick testing and rerunning:
* When trying out the pipeline for a quick test, set `limit_prior`, `limit_train` and/or `limit_wd` to read only parts of the dumps instead of everything.
* If you only want to (re)run certain parts of the pipeline, just remove the corresponding files and they will be recalculated or reparsed.


### Step 2: Train an Entity Linking model

Run `wikidata_train_entity_linker.py`
* This takes the **KB directory** produced by Step 1, and trains an **Entity Linking model**
* You can set the learning parameters for the EL training:
  * `epochs`: number of training iterations
  * `dropout`: dropout rate
  * `lr`: learning rate
  * `l2`: L2 regularization
* Specify the number of training and dev testing entities with `train_inst` and `dev_inst` respectively
* Further parameters to set:
  * `labels_discard`: NER label types to discard during training
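
Once both steps have run, the resulting KB can be inspected programmatically. The snippet below is an illustrative sketch only (the model name and KB path are placeholders, not files named by these scripts); it relies on the same `KnowledgeBase` calls used by the `read_kb` helper in `kb_creator.py`:

```python
import spacy
from spacy.kb import KnowledgeBase

# Placeholders: load the spaCy model and the KB file that Step 1 produced.
nlp = spacy.load("en_core_web_lg")
kb = KnowledgeBase(vocab=nlp.vocab)
kb.load_bulk("output/kb")

print("entities:", kb.get_size_entities())
print("aliases:", kb.get_size_aliases())

# Candidate generation for a mention string; each candidate carries a KB entity ID.
for candidate in kb.get_candidates("Douglas Adams"):
    print(candidate.entity_)
```
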
@ -6,6 +6,7 @@ OUTPUT_MODEL_DIR = "nlp"
|
||||||
PRIOR_PROB_PATH = "prior_prob.csv"
|
PRIOR_PROB_PATH = "prior_prob.csv"
|
||||||
ENTITY_DEFS_PATH = "entity_defs.csv"
|
ENTITY_DEFS_PATH = "entity_defs.csv"
|
||||||
ENTITY_FREQ_PATH = "entity_freq.csv"
|
ENTITY_FREQ_PATH = "entity_freq.csv"
|
||||||
|
ENTITY_ALIAS_PATH = "entity_alias.csv"
|
||||||
ENTITY_DESCR_PATH = "entity_descriptions.csv"
|
ENTITY_DESCR_PATH = "entity_descriptions.csv"
|
||||||
|
|
||||||
LOG_FORMAT = '%(asctime)s - %(levelname)s - %(name)s - %(message)s'
|
LOG_FORMAT = '%(asctime)s - %(levelname)s - %(name)s - %(message)s'
|
||||||
|
|
|
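
These constants are shared across the pipeline scripts; `LOG_FORMAT` in particular is meant for `logging.basicConfig`. A minimal wiring sketch (illustrative only, assuming the scripts are run from the repository root):

```python
import logging

from bin.wiki_entity_linking import LOG_FORMAT

# Illustrative setup; the actual scripts configure their own logging.
logging.basicConfig(level=logging.INFO, format=LOG_FORMAT)
logging.getLogger(__name__).info("wiki_entity_linking logging configured")
```
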
@ -15,10 +15,11 @@ class Metrics(object):
|
||||||
candidate_is_correct = true_entity == candidate
|
candidate_is_correct = true_entity == candidate
|
||||||
|
|
||||||
# Assume that we have no labeled negatives in the data (i.e. cases where true_entity is "NIL")
|
# Assume that we have no labeled negatives in the data (i.e. cases where true_entity is "NIL")
|
||||||
# Therefore, if candidate_is_correct then we have a true positive and never a true negative
|
# Therefore, if candidate_is_correct then we have a true positive and never a true negative.
|
||||||
self.true_pos += candidate_is_correct
|
self.true_pos += candidate_is_correct
|
||||||
self.false_neg += not candidate_is_correct
|
self.false_neg += not candidate_is_correct
|
||||||
if candidate not in {"", "NIL"}:
|
if candidate and candidate not in {"", "NIL"}:
|
||||||
|
# A wrong prediction (e.g. Q42 != Q3) counts both as a FP as well as a FN.
|
||||||
self.false_pos += not candidate_is_correct
|
self.false_pos += not candidate_is_correct
|
||||||
|
|
||||||
def calculate_precision(self):
|
def calculate_precision(self):
|
||||||
|
@ -33,6 +34,14 @@ class Metrics(object):
|
||||||
else:
|
else:
|
||||||
return self.true_pos / (self.true_pos + self.false_neg)
|
return self.true_pos / (self.true_pos + self.false_neg)
|
||||||
|
|
||||||
|
def calculate_fscore(self):
|
||||||
|
p = self.calculate_precision()
|
||||||
|
r = self.calculate_recall()
|
||||||
|
if p + r == 0:
|
||||||
|
return 0.0
|
||||||
|
else:
|
||||||
|
return 2 * p * r / (p + r)
|
||||||
|
|
||||||
|
|
||||||
class EvaluationResults(object):
|
class EvaluationResults(object):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
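
As a quick sanity check of the arithmetic behind the newly added `calculate_fscore`, here is the same computation on made-up counts (purely illustrative, not taken from the diff):

```python
# Hypothetical counts, for illustration only
true_pos, false_pos, false_neg = 80, 10, 20

precision = true_pos / (true_pos + false_pos)            # 80 / 90  ≈ 0.889
recall = true_pos / (true_pos + false_neg)               # 80 / 100 = 0.800
fscore = 2 * precision * recall / (precision + recall)   # ≈ 0.842
```
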
@ -43,18 +52,20 @@ class EvaluationResults(object):
|
||||||
self.metrics.update_results(true_entity, candidate)
|
self.metrics.update_results(true_entity, candidate)
|
||||||
self.metrics_by_label[ent_label].update_results(true_entity, candidate)
|
self.metrics_by_label[ent_label].update_results(true_entity, candidate)
|
||||||
|
|
||||||
def increment_false_negatives(self):
|
|
||||||
self.metrics.false_neg += 1
|
|
||||||
|
|
||||||
def report_metrics(self, model_name):
|
def report_metrics(self, model_name):
|
||||||
model_str = model_name.title()
|
model_str = model_name.title()
|
||||||
recall = self.metrics.calculate_recall()
|
recall = self.metrics.calculate_recall()
|
||||||
precision = self.metrics.calculate_precision()
|
precision = self.metrics.calculate_precision()
|
||||||
return ("{}: ".format(model_str) +
|
fscore = self.metrics.calculate_fscore()
|
||||||
"Recall = {} | ".format(round(recall, 3)) +
|
return (
|
||||||
"Precision = {} | ".format(round(precision, 3)) +
|
"{}: ".format(model_str)
|
||||||
"Precision by label = {}".format({k: v.calculate_precision()
|
+ "F-score = {} | ".format(round(fscore, 3))
|
||||||
for k, v in self.metrics_by_label.items()}))
|
+ "Recall = {} | ".format(round(recall, 3))
|
||||||
|
+ "Precision = {} | ".format(round(precision, 3))
|
||||||
|
+ "F-score by label = {}".format(
|
||||||
|
{k: v.calculate_fscore() for k, v in sorted(self.metrics_by_label.items())}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class BaselineResults(object):
|
class BaselineResults(object):
|
||||||
|
@ -63,25 +74,32 @@ class BaselineResults(object):
|
||||||
self.prior = EvaluationResults()
|
self.prior = EvaluationResults()
|
||||||
self.oracle = EvaluationResults()
|
self.oracle = EvaluationResults()
|
||||||
|
|
||||||
def report_accuracy(self, model):
|
def report_performance(self, model):
|
||||||
results = getattr(self, model)
|
results = getattr(self, model)
|
||||||
return results.report_metrics(model)
|
return results.report_metrics(model)
|
||||||
|
|
||||||
def update_baselines(self, true_entity, ent_label, random_candidate, prior_candidate, oracle_candidate):
|
def update_baselines(
|
||||||
|
self,
|
||||||
|
true_entity,
|
||||||
|
ent_label,
|
||||||
|
random_candidate,
|
||||||
|
prior_candidate,
|
||||||
|
oracle_candidate,
|
||||||
|
):
|
||||||
self.oracle.update_metrics(ent_label, true_entity, oracle_candidate)
|
self.oracle.update_metrics(ent_label, true_entity, oracle_candidate)
|
||||||
self.prior.update_metrics(ent_label, true_entity, prior_candidate)
|
self.prior.update_metrics(ent_label, true_entity, prior_candidate)
|
||||||
self.random.update_metrics(ent_label, true_entity, random_candidate)
|
self.random.update_metrics(ent_label, true_entity, random_candidate)
|
||||||
|
|
||||||
|
|
||||||
def measure_performance(dev_data, kb, el_pipe):
|
def measure_performance(dev_data, kb, el_pipe, baseline=True, context=True):
|
||||||
baseline_accuracies = measure_baselines(
|
if baseline:
|
||||||
dev_data, kb
|
baseline_accuracies, counts = measure_baselines(dev_data, kb)
|
||||||
)
|
logger.info("Counts: {}".format({k: v for k, v in sorted(counts.items())}))
|
||||||
|
logger.info(baseline_accuracies.report_performance("random"))
|
||||||
logger.info(baseline_accuracies.report_accuracy("random"))
|
logger.info(baseline_accuracies.report_performance("prior"))
|
||||||
logger.info(baseline_accuracies.report_accuracy("prior"))
|
logger.info(baseline_accuracies.report_performance("oracle"))
|
||||||
logger.info(baseline_accuracies.report_accuracy("oracle"))
|
|
||||||
|
|
||||||
|
if context:
|
||||||
# using only context
|
# using only context
|
||||||
el_pipe.cfg["incl_context"] = True
|
el_pipe.cfg["incl_context"] = True
|
||||||
el_pipe.cfg["incl_prior"] = False
|
el_pipe.cfg["incl_prior"] = False
|
||||||
|
@ -96,7 +114,11 @@ def measure_performance(dev_data, kb, el_pipe):
|
||||||
|
|
||||||
|
|
||||||
def get_eval_results(data, el_pipe=None):
|
def get_eval_results(data, el_pipe=None):
|
||||||
# If the docs in the data require further processing with an entity linker, set el_pipe
|
"""
|
||||||
|
Evaluate the ent.kb_id_ annotations against the gold standard.
|
||||||
|
Only evaluate entities that overlap between gold and NER, to isolate the performance of the NEL.
|
||||||
|
If the docs in the data require further processing with an entity linker, set el_pipe.
|
||||||
|
"""
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
docs = []
|
docs = []
|
||||||
|
@ -111,18 +133,15 @@ def get_eval_results(data, el_pipe=None):
|
||||||
|
|
||||||
results = EvaluationResults()
|
results = EvaluationResults()
|
||||||
for doc, gold in zip(docs, golds):
|
for doc, gold in zip(docs, golds):
|
||||||
tagged_entries_per_article = {_offset(ent.start_char, ent.end_char): ent for ent in doc.ents}
|
|
||||||
try:
|
try:
|
||||||
correct_entries_per_article = dict()
|
correct_entries_per_article = dict()
|
||||||
for entity, kb_dict in gold.links.items():
|
for entity, kb_dict in gold.links.items():
|
||||||
start, end = entity
|
start, end = entity
|
||||||
# only evaluating on positive examples
|
|
||||||
for gold_kb, value in kb_dict.items():
|
for gold_kb, value in kb_dict.items():
|
||||||
if value:
|
if value:
|
||||||
|
# only evaluating on positive examples
|
||||||
offset = _offset(start, end)
|
offset = _offset(start, end)
|
||||||
correct_entries_per_article[offset] = gold_kb
|
correct_entries_per_article[offset] = gold_kb
|
||||||
if offset not in tagged_entries_per_article:
|
|
||||||
results.increment_false_negatives()
|
|
||||||
|
|
||||||
for ent in doc.ents:
|
for ent in doc.ents:
|
||||||
ent_label = ent.label_
|
ent_label = ent.label_
|
||||||
|
@ -142,7 +161,11 @@ def get_eval_results(data, el_pipe=None):
|
||||||
|
|
||||||
|
|
||||||
def measure_baselines(data, kb):
|
def measure_baselines(data, kb):
|
||||||
# Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound
|
"""
|
||||||
|
Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound.
|
||||||
|
Only evaluate entities that overlap between gold and NER, to isolate the performance of the NEL.
|
||||||
|
Also return a dictionary of counts by entity label.
|
||||||
|
"""
|
||||||
counts_d = dict()
|
counts_d = dict()
|
||||||
|
|
||||||
baseline_results = BaselineResults()
|
baseline_results = BaselineResults()
|
||||||
|
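
To make the three baselines concrete, here is an illustrative sketch (hypothetical QIDs and prior probabilities, not part of the diff) of how `measure_baselines` picks a random, a prior-probability and an oracle candidate for a single mention:

```python
import random

gold = "Q42"
# (entity ID, prior probability) pairs as returned by candidate generation
candidates = [("Q42", 0.2), ("Q5", 0.7), ("Q123", 0.1)]

random_candidate = random.choice(candidates)[0]                          # random baseline
prior_candidate = max(candidates, key=lambda c: c[1])[0]                 # highest prior: "Q5"
oracle_candidate = next((c[0] for c in candidates if c[0] == gold), "")  # gold if retrievable: "Q42"
```
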
@ -152,7 +175,6 @@ def measure_baselines(data, kb):
|
||||||
|
|
||||||
for doc, gold in zip(docs, golds):
|
for doc, gold in zip(docs, golds):
|
||||||
correct_entries_per_article = dict()
|
correct_entries_per_article = dict()
|
||||||
tagged_entries_per_article = {_offset(ent.start_char, ent.end_char): ent for ent in doc.ents}
|
|
||||||
for entity, kb_dict in gold.links.items():
|
for entity, kb_dict in gold.links.items():
|
||||||
start, end = entity
|
start, end = entity
|
||||||
for gold_kb, value in kb_dict.items():
|
for gold_kb, value in kb_dict.items():
|
||||||
|
@ -160,10 +182,6 @@ def measure_baselines(data, kb):
|
||||||
if value:
|
if value:
|
||||||
offset = _offset(start, end)
|
offset = _offset(start, end)
|
||||||
correct_entries_per_article[offset] = gold_kb
|
correct_entries_per_article[offset] = gold_kb
|
||||||
if offset not in tagged_entries_per_article:
|
|
||||||
baseline_results.random.increment_false_negatives()
|
|
||||||
baseline_results.oracle.increment_false_negatives()
|
|
||||||
baseline_results.prior.increment_false_negatives()
|
|
||||||
|
|
||||||
for ent in doc.ents:
|
for ent in doc.ents:
|
||||||
ent_label = ent.label_
|
ent_label = ent.label_
|
||||||
|
@ -176,7 +194,7 @@ def measure_baselines(data, kb):
|
||||||
if gold_entity is not None:
|
if gold_entity is not None:
|
||||||
candidates = kb.get_candidates(ent.text)
|
candidates = kb.get_candidates(ent.text)
|
||||||
oracle_candidate = ""
|
oracle_candidate = ""
|
||||||
best_candidate = ""
|
prior_candidate = ""
|
||||||
random_candidate = ""
|
random_candidate = ""
|
||||||
if candidates:
|
if candidates:
|
||||||
scores = []
|
scores = []
|
||||||
|
@ -187,13 +205,21 @@ def measure_baselines(data, kb):
|
||||||
oracle_candidate = c.entity_
|
oracle_candidate = c.entity_
|
||||||
|
|
||||||
best_index = scores.index(max(scores))
|
best_index = scores.index(max(scores))
|
||||||
best_candidate = candidates[best_index].entity_
|
prior_candidate = candidates[best_index].entity_
|
||||||
random_candidate = random.choice(candidates).entity_
|
random_candidate = random.choice(candidates).entity_
|
||||||
|
|
||||||
baseline_results.update_baselines(gold_entity, ent_label,
|
current_count = counts_d.get(ent_label, 0)
|
||||||
random_candidate, best_candidate, oracle_candidate)
|
counts_d[ent_label] = current_count+1
|
||||||
|
|
||||||
return baseline_results
|
baseline_results.update_baselines(
|
||||||
|
gold_entity,
|
||||||
|
ent_label,
|
||||||
|
random_candidate,
|
||||||
|
prior_candidate,
|
||||||
|
oracle_candidate,
|
||||||
|
)
|
||||||
|
|
||||||
|
return baseline_results, counts_d
|
||||||
|
|
||||||
|
|
||||||
def _offset(start, end):
|
def _offset(start, end):
|
||||||
|
|
|
@ -1,17 +1,12 @@
|
||||||
# coding: utf-8
|
# coding: utf-8
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
import csv
|
|
||||||
import logging
|
import logging
|
||||||
import spacy
|
|
||||||
import sys
|
|
||||||
|
|
||||||
from spacy.kb import KnowledgeBase
|
from spacy.kb import KnowledgeBase
|
||||||
|
|
||||||
from bin.wiki_entity_linking import wikipedia_processor as wp
|
|
||||||
from bin.wiki_entity_linking.train_descriptions import EntityEncoder
|
from bin.wiki_entity_linking.train_descriptions import EntityEncoder
|
||||||
|
from bin.wiki_entity_linking import wiki_io as io
|
||||||
csv.field_size_limit(sys.maxsize)
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -22,18 +17,24 @@ def create_kb(
|
||||||
max_entities_per_alias,
|
max_entities_per_alias,
|
||||||
min_entity_freq,
|
min_entity_freq,
|
||||||
min_occ,
|
min_occ,
|
||||||
entity_def_input,
|
entity_def_path,
|
||||||
entity_descr_path,
|
entity_descr_path,
|
||||||
count_input,
|
entity_alias_path,
|
||||||
prior_prob_input,
|
entity_freq_path,
|
||||||
|
prior_prob_path,
|
||||||
entity_vector_length,
|
entity_vector_length,
|
||||||
):
|
):
|
||||||
# Create the knowledge base from Wikidata entries
|
# Create the knowledge base from Wikidata entries
|
||||||
kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=entity_vector_length)
|
kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=entity_vector_length)
|
||||||
|
entity_list, filtered_title_to_id = _define_entities(nlp, kb, entity_def_path, entity_descr_path, min_entity_freq, entity_freq_path, entity_vector_length)
|
||||||
|
_define_aliases(kb, entity_alias_path, entity_list, filtered_title_to_id, max_entities_per_alias, min_occ, prior_prob_path)
|
||||||
|
return kb
|
||||||
|
|
||||||
|
|
||||||
|
def _define_entities(nlp, kb, entity_def_path, entity_descr_path, min_entity_freq, entity_freq_path, entity_vector_length):
|
||||||
# read the mappings from file
|
# read the mappings from file
|
||||||
title_to_id = get_entity_to_id(entity_def_input)
|
title_to_id = io.read_title_to_id(entity_def_path)
|
||||||
id_to_descr = get_id_to_description(entity_descr_path)
|
id_to_descr = io.read_id_to_descr(entity_descr_path)
|
||||||
|
|
||||||
# check the length of the nlp vectors
|
# check the length of the nlp vectors
|
||||||
if "vectors" in nlp.meta and nlp.vocab.vectors.size:
|
if "vectors" in nlp.meta and nlp.vocab.vectors.size:
|
||||||
|
@ -45,10 +46,8 @@ def create_kb(
|
||||||
" cf. https://spacy.io/usage/models#languages."
|
" cf. https://spacy.io/usage/models#languages."
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info("Get entity frequencies")
|
|
||||||
entity_frequencies = wp.get_all_frequencies(count_input=count_input)
|
|
||||||
|
|
||||||
logger.info("Filtering entities with fewer than {} mentions".format(min_entity_freq))
|
logger.info("Filtering entities with fewer than {} mentions".format(min_entity_freq))
|
||||||
|
entity_frequencies = io.read_entity_to_count(entity_freq_path)
|
||||||
# filter the entities for the KB by frequency, because there's just too much data (8M entities) otherwise
|
# filter the entities for the KB by frequency, because there's just too much data (8M entities) otherwise
|
||||||
filtered_title_to_id, entity_list, description_list, frequency_list = get_filtered_entities(
|
filtered_title_to_id, entity_list, description_list, frequency_list = get_filtered_entities(
|
||||||
title_to_id,
|
title_to_id,
|
||||||
|
@ -56,36 +55,33 @@ def create_kb(
|
||||||
entity_frequencies,
|
entity_frequencies,
|
||||||
min_entity_freq
|
min_entity_freq
|
||||||
)
|
)
|
||||||
logger.info("Left with {} entities".format(len(description_list)))
|
logger.info("Kept {} entities from the set of {}".format(len(description_list), len(title_to_id.keys())))
|
||||||
|
|
||||||
logger.info("Train entity encoder")
|
logger.info("Training entity encoder")
|
||||||
encoder = EntityEncoder(nlp, input_dim, entity_vector_length)
|
encoder = EntityEncoder(nlp, input_dim, entity_vector_length)
|
||||||
encoder.train(description_list=description_list, to_print=True)
|
encoder.train(description_list=description_list, to_print=True)
|
||||||
|
|
||||||
logger.info("Get entity embeddings:")
|
logger.info("Getting entity embeddings")
|
||||||
embeddings = encoder.apply_encoder(description_list)
|
embeddings = encoder.apply_encoder(description_list)
|
||||||
|
|
||||||
logger.info("Adding {} entities".format(len(entity_list)))
|
logger.info("Adding {} entities".format(len(entity_list)))
|
||||||
kb.set_entities(
|
kb.set_entities(
|
||||||
entity_list=entity_list, freq_list=frequency_list, vector_list=embeddings
|
entity_list=entity_list, freq_list=frequency_list, vector_list=embeddings
|
||||||
)
|
)
|
||||||
|
return entity_list, filtered_title_to_id
|
||||||
|
|
||||||
logger.info("Adding aliases")
|
|
||||||
|
def _define_aliases(kb, entity_alias_path, entity_list, filtered_title_to_id, max_entities_per_alias, min_occ, prior_prob_path):
|
||||||
|
logger.info("Adding aliases from Wikipedia and Wikidata")
|
||||||
_add_aliases(
|
_add_aliases(
|
||||||
kb,
|
kb,
|
||||||
|
entity_list=entity_list,
|
||||||
title_to_id=filtered_title_to_id,
|
title_to_id=filtered_title_to_id,
|
||||||
max_entities_per_alias=max_entities_per_alias,
|
max_entities_per_alias=max_entities_per_alias,
|
||||||
min_occ=min_occ,
|
min_occ=min_occ,
|
||||||
prior_prob_input=prior_prob_input,
|
prior_prob_path=prior_prob_path,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info("KB size: {} entities, {} aliases".format(
|
|
||||||
kb.get_size_entities(),
|
|
||||||
kb.get_size_aliases()))
|
|
||||||
|
|
||||||
logger.info("Done with kb")
|
|
||||||
return kb
|
|
||||||
|
|
||||||
|
|
||||||
def get_filtered_entities(title_to_id, id_to_descr, entity_frequencies,
|
def get_filtered_entities(title_to_id, id_to_descr, entity_frequencies,
|
||||||
min_entity_freq: int = 10):
|
min_entity_freq: int = 10):
|
||||||
|
@ -104,34 +100,13 @@ def get_filtered_entities(title_to_id, id_to_descr, entity_frequencies,
|
||||||
return filtered_title_to_id, entity_list, description_list, frequency_list
|
return filtered_title_to_id, entity_list, description_list, frequency_list
|
||||||
|
|
||||||
|
|
||||||
def get_entity_to_id(entity_def_output):
|
def _add_aliases(kb, entity_list, title_to_id, max_entities_per_alias, min_occ, prior_prob_path):
|
||||||
entity_to_id = dict()
|
|
||||||
with entity_def_output.open("r", encoding="utf8") as csvfile:
|
|
||||||
csvreader = csv.reader(csvfile, delimiter="|")
|
|
||||||
# skip header
|
|
||||||
next(csvreader)
|
|
||||||
for row in csvreader:
|
|
||||||
entity_to_id[row[0]] = row[1]
|
|
||||||
return entity_to_id
|
|
||||||
|
|
||||||
|
|
||||||
def get_id_to_description(entity_descr_path):
|
|
||||||
id_to_desc = dict()
|
|
||||||
with entity_descr_path.open("r", encoding="utf8") as csvfile:
|
|
||||||
csvreader = csv.reader(csvfile, delimiter="|")
|
|
||||||
# skip header
|
|
||||||
next(csvreader)
|
|
||||||
for row in csvreader:
|
|
||||||
id_to_desc[row[0]] = row[1]
|
|
||||||
return id_to_desc
|
|
||||||
|
|
||||||
|
|
||||||
def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_input):
|
|
||||||
wp_titles = title_to_id.keys()
|
wp_titles = title_to_id.keys()
|
||||||
|
|
||||||
# adding aliases with prior probabilities
|
# adding aliases with prior probabilities
|
||||||
# we can read this file sequentially, it's sorted by alias, and then by count
|
# we can read this file sequentially, it's sorted by alias, and then by count
|
||||||
with prior_prob_input.open("r", encoding="utf8") as prior_file:
|
logger.info("Adding WP aliases")
|
||||||
|
with prior_prob_path.open("r", encoding="utf8") as prior_file:
|
||||||
# skip header
|
# skip header
|
||||||
prior_file.readline()
|
prior_file.readline()
|
||||||
line = prior_file.readline()
|
line = prior_file.readline()
|
||||||
|
@ -180,10 +155,7 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in
|
||||||
line = prior_file.readline()
|
line = prior_file.readline()
|
||||||
|
|
||||||
|
|
||||||
def read_nlp_kb(model_dir, kb_file):
|
def read_kb(nlp, kb_file):
|
||||||
nlp = spacy.load(model_dir)
|
|
||||||
kb = KnowledgeBase(vocab=nlp.vocab)
|
kb = KnowledgeBase(vocab=nlp.vocab)
|
||||||
kb.load_bulk(kb_file)
|
kb.load_bulk(kb_file)
|
||||||
logger.info("kb entities: {}".format(kb.get_size_entities()))
|
return kb
|
||||||
logger.info("kb aliases: {}".format(kb.get_size_aliases()))
|
|
||||||
return nlp, kb
|
|
||||||
|
|
|
@ -53,7 +53,7 @@ class EntityEncoder:
|
||||||
|
|
||||||
start = start + batch_size
|
start = start + batch_size
|
||||||
stop = min(stop + batch_size, len(description_list))
|
stop = min(stop + batch_size, len(description_list))
|
||||||
logger.info("encoded: {} entities".format(stop))
|
logger.info("Encoded: {} entities".format(stop))
|
||||||
|
|
||||||
return encodings
|
return encodings
|
||||||
|
|
||||||
|
@ -62,7 +62,7 @@ class EntityEncoder:
|
||||||
if to_print:
|
if to_print:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Trained entity descriptions on {} ".format(processed) +
|
"Trained entity descriptions on {} ".format(processed) +
|
||||||
"(non-unique) entities across {} ".format(self.epochs) +
|
"(non-unique) descriptions across {} ".format(self.epochs) +
|
||||||
"epochs"
|
"epochs"
|
||||||
)
|
)
|
||||||
logger.info("Final loss: {}".format(loss))
|
logger.info("Final loss: {}".format(loss))
|
||||||
|
|
|
@ -1,395 +0,0 @@
|
||||||
# coding: utf-8
|
|
||||||
from __future__ import unicode_literals
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import random
|
|
||||||
import re
|
|
||||||
import bz2
|
|
||||||
import json
|
|
||||||
|
|
||||||
from functools import partial
|
|
||||||
|
|
||||||
from spacy.gold import GoldParse
|
|
||||||
from bin.wiki_entity_linking import kb_creator
|
|
||||||
|
|
||||||
"""
|
|
||||||
Process Wikipedia interlinks to generate a training dataset for the EL algorithm.
|
|
||||||
Gold-standard entities are stored in one file in standoff format (by character offset).
|
|
||||||
"""
|
|
||||||
|
|
||||||
ENTITY_FILE = "gold_entities.csv"
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def create_training_examples_and_descriptions(wikipedia_input,
|
|
||||||
entity_def_input,
|
|
||||||
description_output,
|
|
||||||
training_output,
|
|
||||||
parse_descriptions,
|
|
||||||
limit=None):
|
|
||||||
wp_to_id = kb_creator.get_entity_to_id(entity_def_input)
|
|
||||||
_process_wikipedia_texts(wikipedia_input,
|
|
||||||
wp_to_id,
|
|
||||||
description_output,
|
|
||||||
training_output,
|
|
||||||
parse_descriptions,
|
|
||||||
limit)
|
|
||||||
|
|
||||||
|
|
||||||
def _process_wikipedia_texts(wikipedia_input,
|
|
||||||
wp_to_id,
|
|
||||||
output,
|
|
||||||
training_output,
|
|
||||||
parse_descriptions,
|
|
||||||
limit=None):
|
|
||||||
"""
|
|
||||||
Read the XML wikipedia data to parse out training data:
|
|
||||||
raw text data + positive instances
|
|
||||||
"""
|
|
||||||
title_regex = re.compile(r"(?<=<title>).*(?=</title>)")
|
|
||||||
id_regex = re.compile(r"(?<=<id>)\d*(?=</id>)")
|
|
||||||
|
|
||||||
read_ids = set()
|
|
||||||
|
|
||||||
with output.open("a", encoding="utf8") as descr_file, training_output.open("w", encoding="utf8") as entity_file:
|
|
||||||
if parse_descriptions:
|
|
||||||
_write_training_description(descr_file, "WD_id", "description")
|
|
||||||
with bz2.open(wikipedia_input, mode="rb") as file:
|
|
||||||
article_count = 0
|
|
||||||
article_text = ""
|
|
||||||
article_title = None
|
|
||||||
article_id = None
|
|
||||||
reading_text = False
|
|
||||||
reading_revision = False
|
|
||||||
|
|
||||||
logger.info("Processed {} articles".format(article_count))
|
|
||||||
|
|
||||||
for line in file:
|
|
||||||
clean_line = line.strip().decode("utf-8")
|
|
||||||
|
|
||||||
if clean_line == "<revision>":
|
|
||||||
reading_revision = True
|
|
||||||
elif clean_line == "</revision>":
|
|
||||||
reading_revision = False
|
|
||||||
|
|
||||||
# Start reading new page
|
|
||||||
if clean_line == "<page>":
|
|
||||||
article_text = ""
|
|
||||||
article_title = None
|
|
||||||
article_id = None
|
|
||||||
# finished reading this page
|
|
||||||
elif clean_line == "</page>":
|
|
||||||
if article_id:
|
|
||||||
clean_text, entities = _process_wp_text(
|
|
||||||
article_title,
|
|
||||||
article_text,
|
|
||||||
wp_to_id
|
|
||||||
)
|
|
||||||
if clean_text is not None and entities is not None:
|
|
||||||
_write_training_entities(entity_file,
|
|
||||||
article_id,
|
|
||||||
clean_text,
|
|
||||||
entities)
|
|
||||||
|
|
||||||
if article_title in wp_to_id and parse_descriptions:
|
|
||||||
description = " ".join(clean_text[:1000].split(" ")[:-1])
|
|
||||||
_write_training_description(
|
|
||||||
descr_file,
|
|
||||||
wp_to_id[article_title],
|
|
||||||
description
|
|
||||||
)
|
|
||||||
article_count += 1
|
|
||||||
if article_count % 10000 == 0:
|
|
||||||
logger.info("Processed {} articles".format(article_count))
|
|
||||||
if limit and article_count >= limit:
|
|
||||||
break
|
|
||||||
article_text = ""
|
|
||||||
article_title = None
|
|
||||||
article_id = None
|
|
||||||
reading_text = False
|
|
||||||
reading_revision = False
|
|
||||||
|
|
||||||
# start reading text within a page
|
|
||||||
if "<text" in clean_line:
|
|
||||||
reading_text = True
|
|
||||||
|
|
||||||
if reading_text:
|
|
||||||
article_text += " " + clean_line
|
|
||||||
|
|
||||||
# stop reading text within a page (we assume a new page doesn't start on the same line)
|
|
||||||
if "</text" in clean_line:
|
|
||||||
reading_text = False
|
|
||||||
|
|
||||||
# read the ID of this article (outside the revision portion of the document)
|
|
||||||
if not reading_revision:
|
|
||||||
ids = id_regex.search(clean_line)
|
|
||||||
if ids:
|
|
||||||
article_id = ids[0]
|
|
||||||
if article_id in read_ids:
|
|
||||||
logger.info(
|
|
||||||
"Found duplicate article ID", article_id, clean_line
|
|
||||||
) # This should never happen ...
|
|
||||||
read_ids.add(article_id)
|
|
||||||
|
|
||||||
# read the title of this article (outside the revision portion of the document)
|
|
||||||
if not reading_revision:
|
|
||||||
titles = title_regex.search(clean_line)
|
|
||||||
if titles:
|
|
||||||
article_title = titles[0].strip()
|
|
||||||
logger.info("Finished. Processed {} articles".format(article_count))
|
|
||||||
|
|
||||||
|
|
||||||
text_regex = re.compile(r"(?<=<text xml:space=\"preserve\">).*(?=</text)")
|
|
||||||
info_regex = re.compile(r"{[^{]*?}")
|
|
||||||
htlm_regex = re.compile(r"<!--[^-]*-->")
|
|
||||||
category_regex = re.compile(r"\[\[Category:[^\[]*]]")
|
|
||||||
file_regex = re.compile(r"\[\[File:[^[\]]+]]")
|
|
||||||
ref_regex = re.compile(r"<ref.*?>") # non-greedy
|
|
||||||
ref_2_regex = re.compile(r"</ref.*?>") # non-greedy
|
|
||||||
|
|
||||||
|
|
||||||
def _process_wp_text(article_title, article_text, wp_to_id):
|
|
||||||
# ignore meta Wikipedia pages
|
|
||||||
if (
|
|
||||||
article_title.startswith("Wikipedia:") or
|
|
||||||
article_title.startswith("Kategori:")
|
|
||||||
):
|
|
||||||
return None, None
|
|
||||||
|
|
||||||
# remove the text tags
|
|
||||||
text_search = text_regex.search(article_text)
|
|
||||||
if text_search is None:
|
|
||||||
return None, None
|
|
||||||
text = text_search.group(0)
|
|
||||||
|
|
||||||
# stop processing if this is a redirect page
|
|
||||||
if text.startswith("#REDIRECT"):
|
|
||||||
return None, None
|
|
||||||
|
|
||||||
# get the raw text without markup etc, keeping only interwiki links
|
|
||||||
clean_text, entities = _remove_links(_get_clean_wp_text(text), wp_to_id)
|
|
||||||
return clean_text, entities
|
|
||||||
|
|
||||||
|
|
||||||
def _get_clean_wp_text(article_text):
|
|
||||||
clean_text = article_text.strip()
|
|
||||||
|
|
||||||
# remove bolding & italic markup
|
|
||||||
clean_text = clean_text.replace("'''", "")
|
|
||||||
clean_text = clean_text.replace("''", "")
|
|
||||||
|
|
||||||
# remove nested {{info}} statements by removing the inner/smallest ones first and iterating
|
|
||||||
try_again = True
|
|
||||||
previous_length = len(clean_text)
|
|
||||||
while try_again:
|
|
||||||
clean_text = info_regex.sub(
|
|
||||||
"", clean_text
|
|
||||||
) # non-greedy match excluding a nested {
|
|
||||||
if len(clean_text) < previous_length:
|
|
||||||
try_again = True
|
|
||||||
else:
|
|
||||||
try_again = False
|
|
||||||
previous_length = len(clean_text)
|
|
||||||
|
|
||||||
# remove HTML comments
|
|
||||||
clean_text = htlm_regex.sub("", clean_text)
|
|
||||||
|
|
||||||
# remove Category and File statements
|
|
||||||
clean_text = category_regex.sub("", clean_text)
|
|
||||||
clean_text = file_regex.sub("", clean_text)
|
|
||||||
|
|
||||||
# remove multiple =
|
|
||||||
while "==" in clean_text:
|
|
||||||
clean_text = clean_text.replace("==", "=")
|
|
||||||
|
|
||||||
clean_text = clean_text.replace(". =", ".")
|
|
||||||
clean_text = clean_text.replace(" = ", ". ")
|
|
||||||
clean_text = clean_text.replace("= ", ".")
|
|
||||||
clean_text = clean_text.replace(" =", "")
|
|
||||||
|
|
||||||
# remove refs (non-greedy match)
|
|
||||||
clean_text = ref_regex.sub("", clean_text)
|
|
||||||
clean_text = ref_2_regex.sub("", clean_text)
|
|
||||||
|
|
||||||
# remove additional wikiformatting
|
|
||||||
clean_text = re.sub(r"<blockquote>", "", clean_text)
|
|
||||||
clean_text = re.sub(r"</blockquote>", "", clean_text)
|
|
||||||
|
|
||||||
# change special characters back to normal ones
|
|
||||||
clean_text = clean_text.replace(r"&lt;", "<")
|
|
||||||
clean_text = clean_text.replace(r"&gt;", ">")
|
|
||||||
clean_text = clean_text.replace(r"&quot;", '"')
|
|
||||||
clean_text = clean_text.replace(r"&nbsp;", " ")
|
|
||||||
clean_text = clean_text.replace(r"&amp;", "&")
|
|
||||||
|
|
||||||
# remove multiple spaces
|
|
||||||
while " " in clean_text:
|
|
||||||
clean_text = clean_text.replace(" ", " ")
|
|
||||||
|
|
||||||
return clean_text.strip()
|
|
||||||
|
|
||||||
|
|
||||||
def _remove_links(clean_text, wp_to_id):
|
|
||||||
# read the text char by char to get the right offsets for the interwiki links
|
|
||||||
entities = []
|
|
||||||
final_text = ""
|
|
||||||
open_read = 0
|
|
||||||
reading_text = True
|
|
||||||
reading_entity = False
|
|
||||||
reading_mention = False
|
|
||||||
reading_special_case = False
|
|
||||||
entity_buffer = ""
|
|
||||||
mention_buffer = ""
|
|
||||||
for index, letter in enumerate(clean_text):
|
|
||||||
if letter == "[":
|
|
||||||
open_read += 1
|
|
||||||
elif letter == "]":
|
|
||||||
open_read -= 1
|
|
||||||
elif letter == "|":
|
|
||||||
if reading_text:
|
|
||||||
final_text += letter
|
|
||||||
# switch from reading entity to mention in the [[entity|mention]] pattern
|
|
||||||
elif reading_entity:
|
|
||||||
reading_text = False
|
|
||||||
reading_entity = False
|
|
||||||
reading_mention = True
|
|
||||||
else:
|
|
||||||
reading_special_case = True
|
|
||||||
else:
|
|
||||||
if reading_entity:
|
|
||||||
entity_buffer += letter
|
|
||||||
elif reading_mention:
|
|
||||||
mention_buffer += letter
|
|
||||||
elif reading_text:
|
|
||||||
final_text += letter
|
|
||||||
else:
|
|
||||||
raise ValueError("Not sure at point", clean_text[index - 2: index + 2])
|
|
||||||
|
|
||||||
if open_read > 2:
|
|
||||||
reading_special_case = True
|
|
||||||
|
|
||||||
if open_read == 2 and reading_text:
|
|
||||||
reading_text = False
|
|
||||||
reading_entity = True
|
|
||||||
reading_mention = False
|
|
||||||
|
|
||||||
# we just finished reading an entity
|
|
||||||
if open_read == 0 and not reading_text:
|
|
||||||
if "#" in entity_buffer or entity_buffer.startswith(":"):
|
|
||||||
reading_special_case = True
|
|
||||||
# Ignore cases with nested structures like File: handles etc
|
|
||||||
if not reading_special_case:
|
|
||||||
if not mention_buffer:
|
|
||||||
mention_buffer = entity_buffer
|
|
||||||
start = len(final_text)
|
|
||||||
end = start + len(mention_buffer)
|
|
||||||
qid = wp_to_id.get(entity_buffer, None)
|
|
||||||
if qid:
|
|
||||||
entities.append((mention_buffer, qid, start, end))
|
|
||||||
final_text += mention_buffer
|
|
||||||
|
|
||||||
entity_buffer = ""
|
|
||||||
mention_buffer = ""
|
|
||||||
|
|
||||||
reading_text = True
|
|
||||||
reading_entity = False
|
|
||||||
reading_mention = False
|
|
||||||
reading_special_case = False
|
|
||||||
return final_text, entities
|
|
||||||
|
|
||||||
|
|
||||||
def _write_training_description(outputfile, qid, description):
|
|
||||||
if description is not None:
|
|
||||||
line = str(qid) + "|" + description + "\n"
|
|
||||||
outputfile.write(line)
|
|
||||||
|
|
||||||
|
|
||||||
def _write_training_entities(outputfile, article_id, clean_text, entities):
|
|
||||||
entities_data = [{"alias": ent[0], "entity": ent[1], "start": ent[2], "end": ent[3]} for ent in entities]
|
|
||||||
line = json.dumps(
|
|
||||||
{
|
|
||||||
"article_id": article_id,
|
|
||||||
"clean_text": clean_text,
|
|
||||||
"entities": entities_data
|
|
||||||
},
|
|
||||||
ensure_ascii=False) + "\n"
|
|
||||||
outputfile.write(line)
|
|
||||||
|
|
||||||
|
|
||||||
def read_training(nlp, entity_file_path, dev, limit, kb):
|
|
||||||
""" This method provides training examples that correspond to the entity annotations found by the nlp object.
|
|
||||||
For training, it will include negative training examples by using the candidate generator,
|
|
||||||
and it will only keep positive training examples that can be found by using the candidate generator.
|
|
||||||
For testing, it will include all positive examples only."""
|
|
||||||
|
|
||||||
from tqdm import tqdm
|
|
||||||
data = []
|
|
||||||
num_entities = 0
|
|
||||||
get_gold_parse = partial(_get_gold_parse, dev=dev, kb=kb)
|
|
||||||
|
|
||||||
logger.info("Reading {} data with limit {}".format('dev' if dev else 'train', limit))
|
|
||||||
with entity_file_path.open("r", encoding="utf8") as file:
|
|
||||||
with tqdm(total=limit, leave=False) as pbar:
|
|
||||||
for i, line in enumerate(file):
|
|
||||||
example = json.loads(line)
|
|
||||||
article_id = example["article_id"]
|
|
||||||
clean_text = example["clean_text"]
|
|
||||||
entities = example["entities"]
|
|
||||||
|
|
||||||
if dev != is_dev(article_id) or len(clean_text) >= 30000:
|
|
||||||
continue
|
|
||||||
|
|
||||||
doc = nlp(clean_text)
|
|
||||||
gold = get_gold_parse(doc, entities)
|
|
||||||
if gold and len(gold.links) > 0:
|
|
||||||
data.append((doc, gold))
|
|
||||||
num_entities += len(gold.links)
|
|
||||||
pbar.update(len(gold.links))
|
|
||||||
if limit and num_entities >= limit:
|
|
||||||
break
|
|
||||||
logger.info("Read {} entities in {} articles".format(num_entities, len(data)))
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
def _get_gold_parse(doc, entities, dev, kb):
|
|
||||||
gold_entities = {}
|
|
||||||
tagged_ent_positions = set(
|
|
||||||
[(ent.start_char, ent.end_char) for ent in doc.ents]
|
|
||||||
)
|
|
||||||
|
|
||||||
for entity in entities:
|
|
||||||
entity_id = entity["entity"]
|
|
||||||
alias = entity["alias"]
|
|
||||||
start = entity["start"]
|
|
||||||
end = entity["end"]
|
|
||||||
|
|
||||||
candidates = kb.get_candidates(alias)
|
|
||||||
candidate_ids = [
|
|
||||||
c.entity_ for c in candidates
|
|
||||||
]
|
|
||||||
|
|
||||||
should_add_ent = (
|
|
||||||
dev or
|
|
||||||
(
|
|
||||||
(start, end) in tagged_ent_positions and
|
|
||||||
entity_id in candidate_ids and
|
|
||||||
len(candidates) > 1
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if should_add_ent:
|
|
||||||
value_by_id = {entity_id: 1.0}
|
|
||||||
if not dev:
|
|
||||||
random.shuffle(candidate_ids)
|
|
||||||
value_by_id.update({
|
|
||||||
kb_id: 0.0
|
|
||||||
for kb_id in candidate_ids
|
|
||||||
if kb_id != entity_id
|
|
||||||
})
|
|
||||||
gold_entities[(start, end)] = value_by_id
|
|
||||||
|
|
||||||
return GoldParse(doc, links=gold_entities)
|
|
||||||
|
|
||||||
|
|
||||||
def is_dev(article_id):
|
|
||||||
return article_id.endswith("3")
|
|
127
bin/wiki_entity_linking/wiki_io.py
Normal file
|
@ -0,0 +1,127 @@
|
||||||
|
# coding: utf-8
|
||||||
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import csv
|
||||||
|
|
||||||
|
# min() needed to prevent error on windows, cf https://stackoverflow.com/questions/52404416/
|
||||||
|
csv.field_size_limit(min(sys.maxsize, 2147483646))
|
||||||
|
|
||||||
|
""" This class provides reading/writing methods for temp files """
|
||||||
|
|
||||||
|
|
||||||
|
# Entity definition: WP title -> WD ID #
|
||||||
|
def write_title_to_id(entity_def_output, title_to_id):
|
||||||
|
with entity_def_output.open("w", encoding="utf8") as id_file:
|
||||||
|
id_file.write("WP_title" + "|" + "WD_id" + "\n")
|
||||||
|
for title, qid in title_to_id.items():
|
||||||
|
id_file.write(title + "|" + str(qid) + "\n")
|
||||||
|
|
||||||
|
|
||||||
|
def read_title_to_id(entity_def_output):
|
||||||
|
title_to_id = dict()
|
||||||
|
with entity_def_output.open("r", encoding="utf8") as id_file:
|
||||||
|
csvreader = csv.reader(id_file, delimiter="|")
|
||||||
|
# skip header
|
||||||
|
next(csvreader)
|
||||||
|
for row in csvreader:
|
||||||
|
title_to_id[row[0]] = row[1]
|
||||||
|
return title_to_id
|
||||||
|
|
||||||
|
|
||||||
|
# Entity aliases from WD: WD ID -> WD alias #
|
||||||
|
def write_id_to_alias(entity_alias_path, id_to_alias):
|
||||||
|
with entity_alias_path.open("w", encoding="utf8") as alias_file:
|
||||||
|
alias_file.write("WD_id" + "|" + "alias" + "\n")
|
||||||
|
for qid, alias_list in id_to_alias.items():
|
||||||
|
for alias in alias_list:
|
||||||
|
alias_file.write(str(qid) + "|" + alias + "\n")
|
||||||
|
|
||||||
|
|
||||||
|
def read_id_to_alias(entity_alias_path):
|
||||||
|
id_to_alias = dict()
|
||||||
|
with entity_alias_path.open("r", encoding="utf8") as alias_file:
|
||||||
|
csvreader = csv.reader(alias_file, delimiter="|")
|
||||||
|
# skip header
|
||||||
|
next(csvreader)
|
||||||
|
for row in csvreader:
|
||||||
|
qid = row[0]
|
||||||
|
alias = row[1]
|
||||||
|
alias_list = id_to_alias.get(qid, [])
|
||||||
|
alias_list.append(alias)
|
||||||
|
id_to_alias[qid] = alias_list
|
||||||
|
return id_to_alias
|
||||||
|
|
||||||
|
|
||||||
|
def read_alias_to_id_generator(entity_alias_path):
|
||||||
|
""" Read (aliases, qid) tuples """
|
||||||
|
|
||||||
|
with entity_alias_path.open("r", encoding="utf8") as alias_file:
|
||||||
|
csvreader = csv.reader(alias_file, delimiter="|")
|
||||||
|
# skip header
|
||||||
|
next(csvreader)
|
||||||
|
for row in csvreader:
|
||||||
|
qid = row[0]
|
||||||
|
alias = row[1]
|
||||||
|
yield alias, qid
|
||||||
|
|
||||||
|
|
||||||
|
# Entity descriptions from WD: WD ID -> WD description #
|
||||||
|
def write_id_to_descr(entity_descr_output, id_to_descr):
|
||||||
|
with entity_descr_output.open("w", encoding="utf8") as descr_file:
|
||||||
|
descr_file.write("WD_id" + "|" + "description" + "\n")
|
||||||
|
for qid, descr in id_to_descr.items():
|
||||||
|
descr_file.write(str(qid) + "|" + descr + "\n")
|
||||||
|
|
||||||
|
|
||||||
|
def read_id_to_descr(entity_desc_path):
|
||||||
|
id_to_desc = dict()
|
||||||
|
with entity_desc_path.open("r", encoding="utf8") as descr_file:
|
||||||
|
csvreader = csv.reader(descr_file, delimiter="|")
|
||||||
|
# skip header
|
||||||
|
next(csvreader)
|
||||||
|
for row in csvreader:
|
||||||
|
id_to_desc[row[0]] = row[1]
|
||||||
|
return id_to_desc
|
||||||
|
|
||||||
|
|
||||||
|
# Entity counts from WP: WP title -> count #
|
||||||
|
def write_entity_to_count(prior_prob_input, count_output):
|
||||||
|
# Write entity counts for quick access later
|
||||||
|
entity_to_count = dict()
|
||||||
|
total_count = 0
|
||||||
|
|
||||||
|
with prior_prob_input.open("r", encoding="utf8") as prior_file:
|
||||||
|
# skip header
|
||||||
|
prior_file.readline()
|
||||||
|
line = prior_file.readline()
|
||||||
|
|
||||||
|
while line:
|
||||||
|
splits = line.replace("\n", "").split(sep="|")
|
||||||
|
# alias = splits[0]
|
||||||
|
count = int(splits[1])
|
||||||
|
entity = splits[2]
|
||||||
|
|
||||||
|
current_count = entity_to_count.get(entity, 0)
|
||||||
|
entity_to_count[entity] = current_count + count
|
||||||
|
|
||||||
|
total_count += count
|
||||||
|
|
||||||
|
line = prior_file.readline()
|
||||||
|
|
||||||
|
with count_output.open("w", encoding="utf8") as entity_file:
|
||||||
|
entity_file.write("entity" + "|" + "count" + "\n")
|
||||||
|
for entity, count in entity_to_count.items():
|
||||||
|
entity_file.write(entity + "|" + str(count) + "\n")
|
||||||
|
|
||||||
|
|
||||||
|
def read_entity_to_count(count_input):
|
||||||
|
entity_to_count = dict()
|
||||||
|
with count_input.open("r", encoding="utf8") as csvfile:
|
||||||
|
csvreader = csv.reader(csvfile, delimiter="|")
|
||||||
|
# skip header
|
||||||
|
next(csvreader)
|
||||||
|
for row in csvreader:
|
||||||
|
entity_to_count[row[0]] = int(row[1])
|
||||||
|
|
||||||
|
return entity_to_count
|
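
A minimal round-trip sketch for these pipe-delimited helper files (the output path is a placeholder, and the import assumes the repository root is on `sys.path`):

```python
from pathlib import Path

from bin.wiki_entity_linking import wiki_io as io

entity_def_path = Path("output/entity_defs.csv")  # hypothetical location
io.write_title_to_id(entity_def_path, {"Douglas Adams": "Q42"})

title_to_id = io.read_title_to_id(entity_def_path)
assert title_to_id["Douglas Adams"] == "Q42"
```
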
128
bin/wiki_entity_linking/wiki_namespaces.py
Normal file
|
@ -0,0 +1,128 @@
|
||||||
|
# coding: utf8
|
||||||
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
|
# List of meta pages in Wikidata, should be kept out of the Knowledge base
|
||||||
|
WD_META_ITEMS = [
|
||||||
|
"Q163875",
|
||||||
|
"Q191780",
|
||||||
|
"Q224414",
|
||||||
|
"Q4167836",
|
||||||
|
"Q4167410",
|
||||||
|
"Q4663903",
|
||||||
|
"Q11266439",
|
||||||
|
"Q13406463",
|
||||||
|
"Q15407973",
|
||||||
|
"Q18616576",
|
||||||
|
"Q19887878",
|
||||||
|
"Q22808320",
|
||||||
|
"Q23894233",
|
||||||
|
"Q33120876",
|
||||||
|
"Q42104522",
|
||||||
|
"Q47460393",
|
||||||
|
"Q64875536",
|
||||||
|
"Q66480449",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: add more cases from non-English WP's
|
||||||
|
|
||||||
|
# List of prefixes that refer to Wikipedia "file" pages
|
||||||
|
WP_FILE_NAMESPACE = ["Bestand", "File"]
|
||||||
|
|
||||||
|
# List of prefixes that refer to Wikipedia "category" pages
|
||||||
|
WP_CATEGORY_NAMESPACE = ["Kategori", "Category", "Categorie"]
|
||||||
|
|
||||||
|
# List of prefixes that refer to Wikipedia "meta" pages
|
||||||
|
# these will/should be matched ignoring case
|
||||||
|
WP_META_NAMESPACE = (
|
||||||
|
WP_FILE_NAMESPACE
|
||||||
|
+ WP_CATEGORY_NAMESPACE
|
||||||
|
+ [
|
||||||
|
"b",
|
||||||
|
"betawikiversity",
|
||||||
|
"Book",
|
||||||
|
"c",
|
||||||
|
"Commons",
|
||||||
|
"d",
|
||||||
|
"dbdump",
|
||||||
|
"download",
|
||||||
|
"Draft",
|
||||||
|
"Education",
|
||||||
|
"Foundation",
|
||||||
|
"Gadget",
|
||||||
|
"Gadget definition",
|
||||||
|
"Gebruiker",
|
||||||
|
"gerrit",
|
||||||
|
"Help",
|
||||||
|
"Image",
|
||||||
|
"Incubator",
|
||||||
|
"m",
|
||||||
|
"mail",
|
||||||
|
"mailarchive",
|
||||||
|
"media",
|
||||||
|
"MediaWiki",
|
||||||
|
"MediaWiki talk",
|
||||||
|
"Mediawikiwiki",
|
||||||
|
"MediaZilla",
|
||||||
|
"Meta",
|
||||||
|
"Metawikipedia",
|
||||||
|
"Module",
|
||||||
|
"mw",
|
||||||
|
"n",
|
||||||
|
"nost",
|
||||||
|
"oldwikisource",
|
||||||
|
"otrs",
|
||||||
|
"OTRSwiki",
|
||||||
|
"Overleg gebruiker",
|
||||||
|
"outreach",
|
||||||
|
"outreachwiki",
|
||||||
|
"Portal",
|
||||||
|
"phab",
|
||||||
|
"Phabricator",
|
||||||
|
"Project",
|
||||||
|
"q",
|
||||||
|
"quality",
|
||||||
|
"rev",
|
||||||
|
"s",
|
||||||
|
"spcom",
|
||||||
|
"Special",
|
||||||
|
"species",
|
||||||
|
"Strategy",
|
||||||
|
"sulutil",
|
||||||
|
"svn",
|
||||||
|
"Talk",
|
||||||
|
"Template",
|
||||||
|
"Template talk",
|
||||||
|
"Testwiki",
|
||||||
|
"ticket",
|
||||||
|
"TimedText",
|
||||||
|
"Toollabs",
|
||||||
|
"tools",
|
||||||
|
"tswiki",
|
||||||
|
"User",
|
||||||
|
"User talk",
|
||||||
|
"v",
|
||||||
|
"voy",
|
||||||
|
"w",
|
||||||
|
"Wikibooks",
|
||||||
|
"Wikidata",
|
||||||
|
"wikiHow",
|
||||||
|
"Wikinvest",
|
||||||
|
"wikilivres",
|
||||||
|
"Wikimedia",
|
||||||
|
"Wikinews",
|
||||||
|
"Wikipedia",
|
||||||
|
"Wikipedia talk",
|
||||||
|
"Wikiquote",
|
||||||
|
"Wikisource",
|
||||||
|
"Wikispecies",
|
||||||
|
"Wikitech",
|
||||||
|
"Wikiversity",
|
||||||
|
"Wikivoyage",
|
||||||
|
"wikt",
|
||||||
|
"wiktionary",
|
||||||
|
"wmf",
|
||||||
|
"wmania",
|
||||||
|
"WP",
|
||||||
|
]
|
||||||
|
)
|
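
For illustration, a hypothetical helper (not part of this module) showing how the namespace lists could be used to skip meta pages, matching prefixes case-insensitively as the comment above suggests:

```python
from bin.wiki_entity_linking.wiki_namespaces import WP_META_NAMESPACE

def is_meta_page(title):
    # "Category:Linguistics" -> prefix "Category"; titles without ":" are kept.
    prefix = title.split(":", 1)[0].strip()
    return prefix.lower() in {ns.lower() for ns in WP_META_NAMESPACE}

assert is_meta_page("Category:Linguistics")
assert not is_meta_page("Douglas Adams")
```
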
|
@ -18,11 +18,12 @@ from pathlib import Path
|
||||||
import plac
|
import plac
|
||||||
|
|
||||||
from bin.wiki_entity_linking import wikipedia_processor as wp, wikidata_processor as wd
|
from bin.wiki_entity_linking import wikipedia_processor as wp, wikidata_processor as wd
|
||||||
|
from bin.wiki_entity_linking import wiki_io as io
|
||||||
from bin.wiki_entity_linking import kb_creator
|
from bin.wiki_entity_linking import kb_creator
|
||||||
from bin.wiki_entity_linking import training_set_creator
|
|
||||||
from bin.wiki_entity_linking import TRAINING_DATA_FILE, KB_FILE, ENTITY_DESCR_PATH, KB_MODEL_DIR, LOG_FORMAT
|
from bin.wiki_entity_linking import TRAINING_DATA_FILE, KB_FILE, ENTITY_DESCR_PATH, KB_MODEL_DIR, LOG_FORMAT
|
||||||
from bin.wiki_entity_linking import ENTITY_FREQ_PATH, PRIOR_PROB_PATH, ENTITY_DEFS_PATH
|
from bin.wiki_entity_linking import ENTITY_FREQ_PATH, PRIOR_PROB_PATH, ENTITY_DEFS_PATH, ENTITY_ALIAS_PATH
|
||||||
import spacy
|
import spacy
|
||||||
|
from bin.wiki_entity_linking.kb_creator import read_kb
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -39,9 +40,11 @@ logger = logging.getLogger(__name__)
|
||||||
loc_prior_prob=("Location to file with prior probabilities", "option", "p", Path),
|
loc_prior_prob=("Location to file with prior probabilities", "option", "p", Path),
|
||||||
loc_entity_defs=("Location to file with entity definitions", "option", "d", Path),
|
loc_entity_defs=("Location to file with entity definitions", "option", "d", Path),
|
||||||
loc_entity_desc=("Location to file with entity descriptions", "option", "s", Path),
|
loc_entity_desc=("Location to file with entity descriptions", "option", "s", Path),
|
||||||
descriptions_from_wikipedia=("Flag for using wp descriptions not wd", "flag", "wp"),
|
descr_from_wp=("Flag for using wp descriptions not wd", "flag", "wp"),
|
||||||
limit=("Optional threshold to limit lines read from dumps", "option", "l", int),
|
limit_prior=("Threshold to limit lines read from WP for prior probabilities", "option", "lp", int),
|
||||||
lang=("Optional language for which to get wikidata titles. Defaults to 'en'", "option", "la", str),
|
limit_train=("Threshold to limit lines read from WP for training set", "option", "lt", int),
|
||||||
|
limit_wd=("Threshold to limit lines read from WD", "option", "lw", int),
|
||||||
|
lang=("Optional language for which to get Wikidata titles. Defaults to 'en'", "option", "la", str),
|
||||||
)
|
)
|
||||||
def main(
|
def main(
|
||||||
wd_json,
|
wd_json,
|
||||||
|
@ -54,13 +57,16 @@ def main(
|
||||||
entity_vector_length=64,
|
entity_vector_length=64,
|
||||||
loc_prior_prob=None,
|
loc_prior_prob=None,
|
||||||
loc_entity_defs=None,
|
loc_entity_defs=None,
|
||||||
|
loc_entity_alias=None,
|
||||||
loc_entity_desc=None,
|
loc_entity_desc=None,
|
||||||
descriptions_from_wikipedia=False,
|
descr_from_wp=False,
|
||||||
limit=None,
|
limit_prior=None,
|
||||||
|
limit_train=None,
|
||||||
|
limit_wd=None,
|
||||||
lang="en",
|
lang="en",
|
||||||
):
|
):
|
||||||
|
|
||||||
entity_defs_path = loc_entity_defs if loc_entity_defs else output_dir / ENTITY_DEFS_PATH
|
entity_defs_path = loc_entity_defs if loc_entity_defs else output_dir / ENTITY_DEFS_PATH
|
||||||
|
entity_alias_path = loc_entity_alias if loc_entity_alias else output_dir / ENTITY_ALIAS_PATH
|
||||||
entity_descr_path = loc_entity_desc if loc_entity_desc else output_dir / ENTITY_DESCR_PATH
|
entity_descr_path = loc_entity_desc if loc_entity_desc else output_dir / ENTITY_DESCR_PATH
|
||||||
entity_freq_path = output_dir / ENTITY_FREQ_PATH
|
entity_freq_path = output_dir / ENTITY_FREQ_PATH
|
||||||
prior_prob_path = loc_prior_prob if loc_prior_prob else output_dir / PRIOR_PROB_PATH
|
prior_prob_path = loc_prior_prob if loc_prior_prob else output_dir / PRIOR_PROB_PATH
|
||||||
|
@@ -69,15 +75,12 @@ def main(
|
||||||
|
|
||||||
logger.info("Creating KB with Wikipedia and WikiData")
|
logger.info("Creating KB with Wikipedia and WikiData")
|
||||||
|
|
||||||
if limit is not None:
|
|
||||||
logger.warning("Warning: reading only {} lines of Wikipedia/Wikidata dumps.".format(limit))
|
|
||||||
|
|
||||||
# STEP 0: set up IO
|
# STEP 0: set up IO
|
||||||
if not output_dir.exists():
|
if not output_dir.exists():
|
||||||
output_dir.mkdir(parents=True)
|
output_dir.mkdir(parents=True)
|
||||||
|
|
||||||
# STEP 1: create the NLP object
|
# STEP 1: Load the NLP object
|
||||||
logger.info("STEP 1: Loading model {}".format(model))
|
logger.info("STEP 1: Loading NLP model {}".format(model))
|
||||||
nlp = spacy.load(model)
|
nlp = spacy.load(model)
|
||||||
|
|
||||||
# check the length of the nlp vectors
|
# check the length of the nlp vectors
|
||||||
|
@@ -90,62 +93,83 @@ def main(
|
||||||
# STEP 2: create prior probabilities from WP
|
# STEP 2: create prior probabilities from WP
|
||||||
if not prior_prob_path.exists():
|
if not prior_prob_path.exists():
|
||||||
# It takes about 2h to process 1000M lines of Wikipedia XML dump
|
# It takes about 2h to process 1000M lines of Wikipedia XML dump
|
||||||
logger.info("STEP 2: writing prior probabilities to {}".format(prior_prob_path))
|
logger.info("STEP 2: Writing prior probabilities to {}".format(prior_prob_path))
|
||||||
wp.read_prior_probs(wp_xml, prior_prob_path, limit=limit)
|
if limit_prior is not None:
|
||||||
logger.info("STEP 2: reading prior probabilities from {}".format(prior_prob_path))
|
logger.warning("Warning: reading only {} lines of Wikipedia dump".format(limit_prior))
|
||||||
|
wp.read_prior_probs(wp_xml, prior_prob_path, limit=limit_prior)
|
||||||
|
else:
|
||||||
|
logger.info("STEP 2: Reading prior probabilities from {}".format(prior_prob_path))
|
||||||
|
|
||||||
# STEP 3: deduce entity frequencies from WP (takes only a few minutes)
|
# STEP 3: calculate entity frequencies
|
||||||
logger.info("STEP 3: calculating entity frequencies")
|
if not entity_freq_path.exists():
|
||||||
wp.write_entity_counts(prior_prob_path, entity_freq_path, to_print=False)
|
logger.info("STEP 3: Calculating and writing entity frequencies to {}".format(entity_freq_path))
|
||||||
|
io.write_entity_to_count(prior_prob_path, entity_freq_path)
|
||||||
|
else:
|
||||||
|
logger.info("STEP 3: Reading entity frequencies from {}".format(entity_freq_path))
|
||||||
|
|
||||||
# STEP 4: reading definitions and (possibly) descriptions from WikiData or from file
|
# STEP 4: reading definitions and (possibly) descriptions from WikiData or from file
|
||||||
message = " and descriptions" if not descriptions_from_wikipedia else ""
|
if (not entity_defs_path.exists()) or (not descr_from_wp and not entity_descr_path.exists()):
|
||||||
if (not entity_defs_path.exists()) or (not descriptions_from_wikipedia and not entity_descr_path.exists()):
|
|
||||||
# It takes about 10h to process 55M lines of Wikidata JSON dump
|
# It takes about 10h to process 55M lines of Wikidata JSON dump
|
||||||
logger.info("STEP 4: parsing wikidata for entity definitions" + message)
|
logger.info("STEP 4: Parsing and writing Wikidata entity definitions to {}".format(entity_defs_path))
|
||||||
title_to_id, id_to_descr = wd.read_wikidata_entities_json(
|
if limit_wd is not None:
|
||||||
|
logger.warning("Warning: reading only {} lines of Wikidata dump".format(limit_wd))
|
||||||
|
title_to_id, id_to_descr, id_to_alias = wd.read_wikidata_entities_json(
|
||||||
wd_json,
|
wd_json,
|
||||||
limit,
|
limit_wd,
|
||||||
to_print=False,
|
to_print=False,
|
||||||
lang=lang,
|
lang=lang,
|
||||||
parse_descriptions=(not descriptions_from_wikipedia),
|
parse_descr=(not descr_from_wp),
|
||||||
)
|
)
|
||||||
wd.write_entity_files(entity_defs_path, title_to_id)
|
io.write_title_to_id(entity_defs_path, title_to_id)
|
||||||
if not descriptions_from_wikipedia:
|
|
||||||
wd.write_entity_description_files(entity_descr_path, id_to_descr)
|
|
||||||
logger.info("STEP 4: read entity definitions" + message)
|
|
||||||
|
|
||||||
# STEP 5: Getting gold entities from wikipedia
|
logger.info("STEP 4b: Writing Wikidata entity aliases to {}".format(entity_alias_path))
|
||||||
message = " and descriptions" if descriptions_from_wikipedia else ""
|
io.write_id_to_alias(entity_alias_path, id_to_alias)
|
||||||
if (not training_entities_path.exists()) or (descriptions_from_wikipedia and not entity_descr_path.exists()):
|
|
||||||
logger.info("STEP 5: parsing wikipedia for gold entities" + message)
|
if not descr_from_wp:
|
||||||
training_set_creator.create_training_examples_and_descriptions(
|
logger.info("STEP 4c: Writing Wikidata entity descriptions to {}".format(entity_descr_path))
|
||||||
wp_xml,
|
io.write_id_to_descr(entity_descr_path, id_to_descr)
|
||||||
entity_defs_path,
|
else:
|
||||||
entity_descr_path,
|
logger.info("STEP 4: Reading entity definitions from {}".format(entity_defs_path))
|
||||||
training_entities_path,
|
logger.info("STEP 4b: Reading entity aliases from {}".format(entity_alias_path))
|
||||||
parse_descriptions=descriptions_from_wikipedia,
|
if not descr_from_wp:
|
||||||
limit=limit,
|
logger.info("STEP 4c: Reading entity descriptions from {}".format(entity_descr_path))
|
||||||
)
|
|
||||||
logger.info("STEP 5: read gold entities" + message)
|
# STEP 5: Getting gold entities from Wikipedia
|
||||||
|
if (not training_entities_path.exists()) or (descr_from_wp and not entity_descr_path.exists()):
|
||||||
|
logger.info("STEP 5: Parsing and writing Wikipedia gold entities to {}".format(training_entities_path))
|
||||||
|
if limit_train is not None:
|
||||||
|
logger.warning("Warning: reading only {} lines of Wikipedia dump".format(limit_train))
|
||||||
|
wp.create_training_and_desc(wp_xml, entity_defs_path, entity_descr_path,
|
||||||
|
training_entities_path, descr_from_wp, limit_train)
|
||||||
|
if descr_from_wp:
|
||||||
|
logger.info("STEP 5b: Parsing and writing Wikipedia descriptions to {}".format(entity_descr_path))
|
||||||
|
else:
|
||||||
|
logger.info("STEP 5: Reading gold entities from {}".format(training_entities_path))
|
||||||
|
if descr_from_wp:
|
||||||
|
logger.info("STEP 5b: Reading entity descriptions from {}".format(entity_descr_path))
|
||||||
|
|
||||||
# STEP 6: creating the actual KB
|
# STEP 6: creating the actual KB
|
||||||
# It takes ca. 30 minutes to pretrain the entity embeddings
|
# It takes ca. 30 minutes to pretrain the entity embeddings
|
||||||
logger.info("STEP 6: creating the KB at {}".format(kb_path))
|
if not kb_path.exists():
|
||||||
|
logger.info("STEP 6: Creating the KB at {}".format(kb_path))
|
||||||
kb = kb_creator.create_kb(
|
kb = kb_creator.create_kb(
|
||||||
nlp=nlp,
|
nlp=nlp,
|
||||||
max_entities_per_alias=max_per_alias,
|
max_entities_per_alias=max_per_alias,
|
||||||
min_entity_freq=min_freq,
|
min_entity_freq=min_freq,
|
||||||
min_occ=min_pair,
|
min_occ=min_pair,
|
||||||
entity_def_input=entity_defs_path,
|
entity_def_path=entity_defs_path,
|
||||||
entity_descr_path=entity_descr_path,
|
entity_descr_path=entity_descr_path,
|
||||||
count_input=entity_freq_path,
|
entity_alias_path=entity_alias_path,
|
||||||
prior_prob_input=prior_prob_path,
|
entity_freq_path=entity_freq_path,
|
||||||
|
prior_prob_path=prior_prob_path,
|
||||||
entity_vector_length=entity_vector_length,
|
entity_vector_length=entity_vector_length,
|
||||||
)
|
)
|
||||||
|
|
||||||
kb.dump(kb_path)
|
kb.dump(kb_path)
|
||||||
|
logger.info("kb entities: {}".format(kb.get_size_entities()))
|
||||||
|
logger.info("kb aliases: {}".format(kb.get_size_aliases()))
|
||||||
nlp.to_disk(output_dir / KB_MODEL_DIR)
|
nlp.to_disk(output_dir / KB_MODEL_DIR)
|
||||||
|
else:
|
||||||
|
logger.info("STEP 6: KB already exists at {}".format(kb_path))
|
||||||
|
|
||||||
logger.info("Done!")
|
logger.info("Done!")
|
||||||
|
|
||||||
|
|
|
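Once STEP 6 completes, the dumped KB holds the entity definitions, frequencies, descriptions, aliases and prior probabilities collected in the earlier steps. As a rough orientation only (this is not the actual `kb_creator.create_kb` implementation), assembling such a KB with the spaCy v2.x API looks roughly as follows; the model name, QIDs, counts and probabilities are made up:

```python
import spacy
from spacy.kb import KnowledgeBase

# Any model with word vectors will do; "en_core_web_md" is just an example.
nlp = spacy.load("en_core_web_md")
kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=64)

# Entities come from the Wikidata definitions, frequencies from the Wikipedia
# counts, and the 64-d vectors are pretrained from the entity descriptions.
kb.add_entity(entity="Q42", freq=34, entity_vector=[0.1] * 64)
kb.add_entity(entity="Q5", freq=120, entity_vector=[0.2] * 64)

# Aliases store the prior probabilities estimated from Wikipedia intrawiki links.
kb.add_alias(alias="Douglas Adams", entities=["Q42"], probabilities=[0.9])

kb.dump("kb")  # same kind of on-disk dump the script writes to output_dir / KB_FILE
```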
@@ -1,40 +1,52 @@
|
||||||
# coding: utf-8
|
# coding: utf-8
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
import gzip
|
import bz2
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import datetime
|
|
||||||
|
from bin.wiki_entity_linking.wiki_namespaces import WD_META_ITEMS
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False, lang="en", parse_descriptions=True):
|
def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False, lang="en", parse_descr=True):
|
||||||
# Read the JSON wiki data and parse out the entities. Takes about 7u30 to parse 55M lines.
|
# Read the JSON wiki data and parse out the entities. Takes about 7-10h to parse 55M lines.
|
||||||
# get latest-all.json.bz2 from https://dumps.wikimedia.org/wikidatawiki/entities/
|
# get latest-all.json.bz2 from https://dumps.wikimedia.org/wikidatawiki/entities/
|
||||||
|
|
||||||
site_filter = '{}wiki'.format(lang)
|
site_filter = '{}wiki'.format(lang)
|
||||||
|
|
||||||
# properties filter (currently disabled to get ALL data)
|
# filter: currently defined as OR: one hit suffices to be removed from further processing
|
||||||
prop_filter = dict()
|
exclude_list = WD_META_ITEMS
|
||||||
# prop_filter = {'P31': {'Q5', 'Q15632617'}} # currently defined as OR: one property suffices to be selected
|
|
||||||
|
# punctuation
|
||||||
|
exclude_list.extend(["Q1383557", "Q10617810"])
|
||||||
|
|
||||||
|
# letters etc
|
||||||
|
exclude_list.extend(["Q188725", "Q19776628", "Q3841820", "Q17907810", "Q9788", "Q9398093"])
|
||||||
|
|
||||||
|
neg_prop_filter = {
|
||||||
|
'P31': exclude_list, # instance of
|
||||||
|
'P279': exclude_list # subclass
|
||||||
|
}
|
||||||
|
|
||||||
title_to_id = dict()
|
title_to_id = dict()
|
||||||
id_to_descr = dict()
|
id_to_descr = dict()
|
||||||
|
id_to_alias = dict()
|
||||||
|
|
||||||
# parse appropriate fields - depending on what we need in the KB
|
# parse appropriate fields - depending on what we need in the KB
|
||||||
parse_properties = False
|
parse_properties = False
|
||||||
parse_sitelinks = True
|
parse_sitelinks = True
|
||||||
parse_labels = False
|
parse_labels = False
|
||||||
parse_aliases = False
|
parse_aliases = True
|
||||||
parse_claims = False
|
parse_claims = True
|
||||||
|
|
||||||
with gzip.open(wikidata_file, mode='rb') as file:
|
with bz2.open(wikidata_file, mode='rb') as file:
|
||||||
for cnt, line in enumerate(file):
|
for cnt, line in enumerate(file):
|
||||||
if limit and cnt >= limit:
|
if limit and cnt >= limit:
|
||||||
break
|
break
|
||||||
if cnt % 500000 == 0:
|
if cnt % 500000 == 0 and cnt > 0:
|
||||||
logger.info("processed {} lines of WikiData dump".format(cnt))
|
logger.info("processed {} lines of WikiData JSON dump".format(cnt))
|
||||||
clean_line = line.strip()
|
clean_line = line.strip()
|
||||||
if clean_line.endswith(b","):
|
if clean_line.endswith(b","):
|
||||||
clean_line = clean_line[:-1]
|
clean_line = clean_line[:-1]
|
||||||
|
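For reference while reading the parsing changes in this function: every line of `latest-all.json.bz2` (apart from the surrounding `[` and `]`) is one JSON document describing one entity. A self-contained sketch of the per-line handling, using a heavily trimmed, made-up record in place of the real dump:

```python
import json

lang = "en"
site_filter = "{}wiki".format(lang)

# One (heavily trimmed) line of the Wikidata JSON dump, including the trailing
# comma it carries inside the top-level JSON array.
raw_line = (
    b'{"type": "item", "id": "Q42", '
    b'"sitelinks": {"enwiki": {"site": "enwiki", "title": "Douglas Adams"}}, '
    b'"descriptions": {"en": {"language": "en", "value": "English writer"}}, '
    b'"aliases": {"en": [{"language": "en", "value": "Douglas Noel Adams"}]}},'
)

clean_line = raw_line.strip()
if clean_line.endswith(b","):
    clean_line = clean_line[:-1]
obj = json.loads(clean_line)

title = obj["sitelinks"].get(site_filter, {}).get("title")    # -> "Douglas Adams"
descr = obj["descriptions"].get(lang, {}).get("value")        # -> "English writer"
aliases = [a["value"] for a in obj["aliases"].get(lang, [])]  # -> ["Douglas Noel Adams"]
print(title, descr, aliases)
```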
@@ -43,13 +55,11 @@ def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False, lang=
|
||||||
entry_type = obj["type"]
|
entry_type = obj["type"]
|
||||||
|
|
||||||
if entry_type == "item":
|
if entry_type == "item":
|
||||||
# filtering records on their properties (currently disabled to get ALL data)
|
|
||||||
# keep = False
|
|
||||||
keep = True
|
keep = True
|
||||||
|
|
||||||
claims = obj["claims"]
|
claims = obj["claims"]
|
||||||
if parse_claims:
|
if parse_claims:
|
||||||
for prop, value_set in prop_filter.items():
|
for prop, value_set in neg_prop_filter.items():
|
||||||
claim_property = claims.get(prop, None)
|
claim_property = claims.get(prop, None)
|
||||||
if claim_property:
|
if claim_property:
|
||||||
for cp in claim_property:
|
for cp in claim_property:
|
||||||
|
@@ -61,7 +71,7 @@ def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False, lang=
|
||||||
)
|
)
|
||||||
cp_rank = cp["rank"]
|
cp_rank = cp["rank"]
|
||||||
if cp_rank != "deprecated" and cp_id in value_set:
|
if cp_rank != "deprecated" and cp_id in value_set:
|
||||||
keep = True
|
keep = False
|
||||||
|
|
||||||
if keep:
|
if keep:
|
||||||
unique_id = obj["id"]
|
unique_id = obj["id"]
|
||||||
|
@@ -108,7 +118,7 @@ def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False, lang=
|
||||||
"label (" + lang + "):", lang_label["value"]
|
"label (" + lang + "):", lang_label["value"]
|
||||||
)
|
)
|
||||||
|
|
||||||
if found_link and parse_descriptions:
|
if found_link and parse_descr:
|
||||||
descriptions = obj["descriptions"]
|
descriptions = obj["descriptions"]
|
||||||
if descriptions:
|
if descriptions:
|
||||||
lang_descr = descriptions.get(lang, None)
|
lang_descr = descriptions.get(lang, None)
|
||||||
|
@@ -130,22 +140,15 @@ def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False, lang=
|
||||||
print(
|
print(
|
||||||
"alias (" + lang + "):", item["value"]
|
"alias (" + lang + "):", item["value"]
|
||||||
)
|
)
|
||||||
|
alias_list = id_to_alias.get(unique_id, [])
|
||||||
|
alias_list.append(item["value"])
|
||||||
|
id_to_alias[unique_id] = alias_list
|
||||||
|
|
||||||
if to_print:
|
if to_print:
|
||||||
print()
|
print()
|
||||||
|
|
||||||
return title_to_id, id_to_descr
|
# log final number of lines processed
|
||||||
|
logger.info("Finished. Processed {} lines of WikiData JSON dump".format(cnt))
|
||||||
|
return title_to_id, id_to_descr, id_to_alias
|
||||||
|
|
||||||
|
|
||||||
def write_entity_files(entity_def_output, title_to_id):
|
|
||||||
with entity_def_output.open("w", encoding="utf8") as id_file:
|
|
||||||
id_file.write("WP_title" + "|" + "WD_id" + "\n")
|
|
||||||
for title, qid in title_to_id.items():
|
|
||||||
id_file.write(title + "|" + str(qid) + "\n")
|
|
||||||
|
|
||||||
|
|
||||||
def write_entity_description_files(entity_descr_output, id_to_descr):
|
|
||||||
with entity_descr_output.open("w", encoding="utf8") as descr_file:
|
|
||||||
descr_file.write("WD_id" + "|" + "description" + "\n")
|
|
||||||
for qid, descr in id_to_descr.items():
|
|
||||||
descr_file.write(str(qid) + "|" + descr + "\n")
|
|
||||||
|
|
|
@@ -6,19 +6,19 @@ as created by the script `wikidata_create_kb`.
|
||||||
|
|
||||||
For the Wikipedia dump: get enwiki-latest-pages-articles-multistream.xml.bz2
|
For the Wikipedia dump: get enwiki-latest-pages-articles-multistream.xml.bz2
|
||||||
from https://dumps.wikimedia.org/enwiki/latest/
|
from https://dumps.wikimedia.org/enwiki/latest/
|
||||||
|
|
||||||
"""
|
"""
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
import random
|
import random
|
||||||
import logging
|
import logging
|
||||||
|
import spacy
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import plac
|
import plac
|
||||||
|
|
||||||
from bin.wiki_entity_linking import training_set_creator
|
from bin.wiki_entity_linking import wikipedia_processor
|
||||||
from bin.wiki_entity_linking import TRAINING_DATA_FILE, KB_MODEL_DIR, KB_FILE, LOG_FORMAT, OUTPUT_MODEL_DIR
|
from bin.wiki_entity_linking import TRAINING_DATA_FILE, KB_MODEL_DIR, KB_FILE, LOG_FORMAT, OUTPUT_MODEL_DIR
|
||||||
from bin.wiki_entity_linking.entity_linker_evaluation import measure_performance, measure_baselines
|
from bin.wiki_entity_linking.entity_linker_evaluation import measure_performance
|
||||||
from bin.wiki_entity_linking.kb_creator import read_nlp_kb
|
from bin.wiki_entity_linking.kb_creator import read_kb
|
||||||
|
|
||||||
from spacy.util import minibatch, compounding
|
from spacy.util import minibatch, compounding
|
||||||
|
|
||||||
|
@@ -35,6 +35,7 @@ logger = logging.getLogger(__name__)
|
||||||
l2=("L2 regularization", "option", "r", float),
|
l2=("L2 regularization", "option", "r", float),
|
||||||
train_inst=("# training instances (default 90% of all)", "option", "t", int),
|
train_inst=("# training instances (default 90% of all)", "option", "t", int),
|
||||||
dev_inst=("# test instances (default 10% of all)", "option", "d", int),
|
dev_inst=("# test instances (default 10% of all)", "option", "d", int),
|
||||||
|
labels_discard=("NER labels to discard (default None)", "option", "l", str),
|
||||||
)
|
)
|
||||||
def main(
|
def main(
|
||||||
dir_kb,
|
dir_kb,
|
||||||
|
@@ -46,13 +47,14 @@ def main(
|
||||||
l2=1e-6,
|
l2=1e-6,
|
||||||
train_inst=None,
|
train_inst=None,
|
||||||
dev_inst=None,
|
dev_inst=None,
|
||||||
|
labels_discard=None
|
||||||
):
|
):
|
||||||
logger.info("Creating Entity Linker with Wikipedia and WikiData")
|
logger.info("Creating Entity Linker with Wikipedia and WikiData")
|
||||||
|
|
||||||
output_dir = Path(output_dir) if output_dir else dir_kb
|
output_dir = Path(output_dir) if output_dir else dir_kb
|
||||||
training_path = loc_training if loc_training else output_dir / TRAINING_DATA_FILE
|
training_path = loc_training if loc_training else dir_kb / TRAINING_DATA_FILE
|
||||||
nlp_dir = dir_kb / KB_MODEL_DIR
|
nlp_dir = dir_kb / KB_MODEL_DIR
|
||||||
kb_path = output_dir / KB_FILE
|
kb_path = dir_kb / KB_FILE
|
||||||
nlp_output_dir = output_dir / OUTPUT_MODEL_DIR
|
nlp_output_dir = output_dir / OUTPUT_MODEL_DIR
|
||||||
|
|
||||||
# STEP 0: set up IO
|
# STEP 0: set up IO
|
||||||
|
@@ -60,38 +62,47 @@ def main(
|
||||||
output_dir.mkdir()
|
output_dir.mkdir()
|
||||||
|
|
||||||
# STEP 1 : load the NLP object
|
# STEP 1 : load the NLP object
|
||||||
logger.info("STEP 1: loading model from {}".format(nlp_dir))
|
logger.info("STEP 1a: Loading model from {}".format(nlp_dir))
|
||||||
nlp, kb = read_nlp_kb(nlp_dir, kb_path)
|
nlp = spacy.load(nlp_dir)
|
||||||
|
logger.info("STEP 1b: Loading KB from {}".format(kb_path))
|
||||||
|
kb = read_kb(nlp, kb_path)
|
||||||
|
|
||||||
# check that there is a NER component in the pipeline
|
# check that there is a NER component in the pipeline
|
||||||
if "ner" not in nlp.pipe_names:
|
if "ner" not in nlp.pipe_names:
|
||||||
raise ValueError("The `nlp` object should have a pretrained `ner` component.")
|
raise ValueError("The `nlp` object should have a pretrained `ner` component.")
|
||||||
|
|
||||||
# STEP 2: create a training dataset from WP
|
# STEP 2: read the training dataset previously created from WP
|
||||||
logger.info("STEP 2: reading training dataset from {}".format(training_path))
|
logger.info("STEP 2: Reading training dataset from {}".format(training_path))
|
||||||
|
|
||||||
train_data = training_set_creator.read_training(
|
if labels_discard:
|
||||||
|
labels_discard = [x.strip() for x in labels_discard.split(",")]
|
||||||
|
logger.info("Discarding {} NER types: {}".format(len(labels_discard), labels_discard))
|
||||||
|
|
||||||
|
train_data = wikipedia_processor.read_training(
|
||||||
nlp=nlp,
|
nlp=nlp,
|
||||||
entity_file_path=training_path,
|
entity_file_path=training_path,
|
||||||
dev=False,
|
dev=False,
|
||||||
limit=train_inst,
|
limit=train_inst,
|
||||||
kb=kb,
|
kb=kb,
|
||||||
|
labels_discard=labels_discard
|
||||||
)
|
)
|
||||||
|
|
||||||
# for testing, get all pos instances, whether or not they are in the kb
|
# for testing, get all pos instances (independently of KB)
|
||||||
dev_data = training_set_creator.read_training(
|
dev_data = wikipedia_processor.read_training(
|
||||||
nlp=nlp,
|
nlp=nlp,
|
||||||
entity_file_path=training_path,
|
entity_file_path=training_path,
|
||||||
dev=True,
|
dev=True,
|
||||||
limit=dev_inst,
|
limit=dev_inst,
|
||||||
kb=kb,
|
kb=None,
|
||||||
|
labels_discard=labels_discard
|
||||||
)
|
)
|
||||||
|
|
||||||
# STEP 3: create and train the entity linking pipe
|
# STEP 3: create and train an entity linking pipe
|
||||||
logger.info("STEP 3: training Entity Linking pipe")
|
logger.info("STEP 3: Creating and training an Entity Linking pipe")
|
||||||
|
|
||||||
el_pipe = nlp.create_pipe(
|
el_pipe = nlp.create_pipe(
|
||||||
name="entity_linker", config={"pretrained_vectors": nlp.vocab.vectors.name}
|
name="entity_linker", config={"pretrained_vectors": nlp.vocab.vectors.name,
|
||||||
|
"labels_discard": labels_discard}
|
||||||
)
|
)
|
||||||
el_pipe.set_kb(kb)
|
el_pipe.set_kb(kb)
|
||||||
nlp.add_pipe(el_pipe, last=True)
|
nlp.add_pipe(el_pipe, last=True)
|
||||||
|
@@ -105,14 +116,9 @@ def main(
|
||||||
logger.info("Training on {} articles".format(len(train_data)))
|
logger.info("Training on {} articles".format(len(train_data)))
|
||||||
logger.info("Dev testing on {} articles".format(len(dev_data)))
|
logger.info("Dev testing on {} articles".format(len(dev_data)))
|
||||||
|
|
||||||
dev_baseline_accuracies = measure_baselines(
|
# baseline performance on dev data
|
||||||
dev_data, kb
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info("Dev Baseline Accuracies:")
|
logger.info("Dev Baseline Accuracies:")
|
||||||
logger.info(dev_baseline_accuracies.report_accuracy("random"))
|
measure_performance(dev_data, kb, el_pipe, baseline=True, context=False)
|
||||||
logger.info(dev_baseline_accuracies.report_accuracy("prior"))
|
|
||||||
logger.info(dev_baseline_accuracies.report_accuracy("oracle"))
|
|
||||||
|
|
||||||
for itn in range(epochs):
|
for itn in range(epochs):
|
||||||
random.shuffle(train_data)
|
random.shuffle(train_data)
|
||||||
|
@@ -136,18 +142,18 @@ def main(
|
||||||
logger.error("Error updating batch:" + str(e))
|
logger.error("Error updating batch:" + str(e))
|
||||||
if batchnr > 0:
|
if batchnr > 0:
|
||||||
logging.info("Epoch {}, train loss {}".format(itn, round(losses["entity_linker"] / batchnr, 2)))
|
logging.info("Epoch {}, train loss {}".format(itn, round(losses["entity_linker"] / batchnr, 2)))
|
||||||
measure_performance(dev_data, kb, el_pipe)
|
measure_performance(dev_data, kb, el_pipe, baseline=False, context=True)
|
||||||
|
|
||||||
# STEP 4: measure the performance of our trained pipe on an independent dev set
|
# STEP 4: measure the performance of our trained pipe on an independent dev set
|
||||||
logger.info("STEP 4: performance measurement of Entity Linking pipe")
|
logger.info("STEP 4: Final performance measurement of Entity Linking pipe")
|
||||||
measure_performance(dev_data, kb, el_pipe)
|
measure_performance(dev_data, kb, el_pipe)
|
||||||
|
|
||||||
# STEP 5: apply the EL pipe on a toy example
|
# STEP 5: apply the EL pipe on a toy example
|
||||||
logger.info("STEP 5: applying Entity Linking to toy example")
|
logger.info("STEP 5: Applying Entity Linking to toy example")
|
||||||
run_el_toy_example(nlp=nlp)
|
run_el_toy_example(nlp=nlp)
|
||||||
|
|
||||||
if output_dir:
|
if output_dir:
|
||||||
# STEP 6: write the NLP pipeline (including entity linker) to file
|
# STEP 6: write the NLP pipeline (now including an EL model) to file
|
||||||
logger.info("STEP 6: Writing trained NLP to {}".format(nlp_output_dir))
|
logger.info("STEP 6: Writing trained NLP to {}".format(nlp_output_dir))
|
||||||
nlp.to_disk(nlp_output_dir)
|
nlp.to_disk(nlp_output_dir)
|
||||||
|
|
||||||
|
|
|
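Taken together, the changes to the training script above amount to the following flow. The sketch below re-assembles it from the calls shown in the diff, with illustrative paths, label names and hyperparameters (everything under `output/` is assumed to have been produced by the KB-creation script first); it is a sketch, not a drop-in replacement for the script:

```python
import random
from pathlib import Path

import spacy
from spacy.util import minibatch, compounding

from bin.wiki_entity_linking import KB_FILE, KB_MODEL_DIR, TRAINING_DATA_FILE
from bin.wiki_entity_linking import wikipedia_processor
from bin.wiki_entity_linking.kb_creator import read_kb

dir_kb = Path("output")  # illustrative: output dir of wikidata_pretrain_kb.py

# STEP 1: load the NLP model and the KB created earlier
nlp = spacy.load(dir_kb / KB_MODEL_DIR)
kb = read_kb(nlp, dir_kb / KB_FILE)

# STEP 2: read training examples from the gold-entities JSONL file
train_data = wikipedia_processor.read_training(
    nlp=nlp, entity_file_path=dir_kb / TRAINING_DATA_FILE,
    dev=False, limit=1000, kb=kb, labels_discard=["ORDINAL"],
)

# STEP 3: add and train the entity_linker pipe
el_pipe = nlp.create_pipe(
    "entity_linker",
    config={"pretrained_vectors": nlp.vocab.vectors.name, "labels_discard": ["ORDINAL"]},
)
el_pipe.set_kb(kb)
nlp.add_pipe(el_pipe, last=True)

other_pipes = [pipe for pipe in nlp.pipe_names if pipe != "entity_linker"]
with nlp.disable_pipes(*other_pipes):  # only update the entity linker
    optimizer = nlp.begin_training()
    for itn in range(10):
        random.shuffle(train_data)
        losses = {}
        for batch in minibatch(train_data, size=compounding(8.0, 128.0, 1.001)):
            docs, golds = zip(*batch)  # read_training() yields (doc, gold) tuples
            nlp.update(docs, golds, sgd=optimizer, drop=0.2, losses=losses)
        print(itn, losses.get("entity_linker"))
```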
@@ -3,147 +3,104 @@ from __future__ import unicode_literals
|
||||||
|
|
||||||
import re
|
import re
|
||||||
import bz2
|
import bz2
|
||||||
import csv
|
|
||||||
import datetime
|
|
||||||
import logging
|
import logging
|
||||||
|
import random
|
||||||
|
import json
|
||||||
|
|
||||||
from bin.wiki_entity_linking import LOG_FORMAT
|
from functools import partial
|
||||||
|
|
||||||
|
from spacy.gold import GoldParse
|
||||||
|
from bin.wiki_entity_linking import wiki_io as io
|
||||||
|
from bin.wiki_entity_linking.wiki_namespaces import (
|
||||||
|
WP_META_NAMESPACE,
|
||||||
|
WP_FILE_NAMESPACE,
|
||||||
|
WP_CATEGORY_NAMESPACE,
|
||||||
|
)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Process a Wikipedia dump to calculate entity frequencies and prior probabilities in combination with certain mentions.
|
Process a Wikipedia dump to calculate entity frequencies and prior probabilities in combination with certain mentions.
|
||||||
Write these results to file for downstream KB and training data generation.
|
Write these results to file for downstream KB and training data generation.
|
||||||
|
|
||||||
|
Process Wikipedia interlinks to generate a training dataset for the EL algorithm.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
ENTITY_FILE = "gold_entities.csv"
|
||||||
|
|
||||||
map_alias_to_link = dict()
|
map_alias_to_link = dict()
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
title_regex = re.compile(r"(?<=<title>).*(?=</title>)")
|
||||||
# these will/should be matched ignoring case
|
id_regex = re.compile(r"(?<=<id>)\d*(?=</id>)")
|
||||||
wiki_namespaces = [
|
text_regex = re.compile(r"(?<=<text xml:space=\"preserve\">).*(?=</text)")
|
||||||
"b",
|
info_regex = re.compile(r"{[^{]*?}")
|
||||||
"betawikiversity",
|
html_regex = re.compile(r"<!--[^-]*-->")
|
||||||
"Book",
|
ref_regex = re.compile(r"<ref.*?>") # non-greedy
|
||||||
"c",
|
ref_2_regex = re.compile(r"</ref.*?>") # non-greedy
|
||||||
"Category",
|
|
||||||
"Commons",
|
|
||||||
"d",
|
|
||||||
"dbdump",
|
|
||||||
"download",
|
|
||||||
"Draft",
|
|
||||||
"Education",
|
|
||||||
"Foundation",
|
|
||||||
"Gadget",
|
|
||||||
"Gadget definition",
|
|
||||||
"gerrit",
|
|
||||||
"File",
|
|
||||||
"Help",
|
|
||||||
"Image",
|
|
||||||
"Incubator",
|
|
||||||
"m",
|
|
||||||
"mail",
|
|
||||||
"mailarchive",
|
|
||||||
"media",
|
|
||||||
"MediaWiki",
|
|
||||||
"MediaWiki talk",
|
|
||||||
"Mediawikiwiki",
|
|
||||||
"MediaZilla",
|
|
||||||
"Meta",
|
|
||||||
"Metawikipedia",
|
|
||||||
"Module",
|
|
||||||
"mw",
|
|
||||||
"n",
|
|
||||||
"nost",
|
|
||||||
"oldwikisource",
|
|
||||||
"outreach",
|
|
||||||
"outreachwiki",
|
|
||||||
"otrs",
|
|
||||||
"OTRSwiki",
|
|
||||||
"Portal",
|
|
||||||
"phab",
|
|
||||||
"Phabricator",
|
|
||||||
"Project",
|
|
||||||
"q",
|
|
||||||
"quality",
|
|
||||||
"rev",
|
|
||||||
"s",
|
|
||||||
"spcom",
|
|
||||||
"Special",
|
|
||||||
"species",
|
|
||||||
"Strategy",
|
|
||||||
"sulutil",
|
|
||||||
"svn",
|
|
||||||
"Talk",
|
|
||||||
"Template",
|
|
||||||
"Template talk",
|
|
||||||
"Testwiki",
|
|
||||||
"ticket",
|
|
||||||
"TimedText",
|
|
||||||
"Toollabs",
|
|
||||||
"tools",
|
|
||||||
"tswiki",
|
|
||||||
"User",
|
|
||||||
"User talk",
|
|
||||||
"v",
|
|
||||||
"voy",
|
|
||||||
"w",
|
|
||||||
"Wikibooks",
|
|
||||||
"Wikidata",
|
|
||||||
"wikiHow",
|
|
||||||
"Wikinvest",
|
|
||||||
"wikilivres",
|
|
||||||
"Wikimedia",
|
|
||||||
"Wikinews",
|
|
||||||
"Wikipedia",
|
|
||||||
"Wikipedia talk",
|
|
||||||
"Wikiquote",
|
|
||||||
"Wikisource",
|
|
||||||
"Wikispecies",
|
|
||||||
"Wikitech",
|
|
||||||
"Wikiversity",
|
|
||||||
"Wikivoyage",
|
|
||||||
"wikt",
|
|
||||||
"wiktionary",
|
|
||||||
"wmf",
|
|
||||||
"wmania",
|
|
||||||
"WP",
|
|
||||||
]
|
|
||||||
|
|
||||||
# find the links
|
# find the links
|
||||||
link_regex = re.compile(r"\[\[[^\[\]]*\]\]")
|
link_regex = re.compile(r"\[\[[^\[\]]*\]\]")
|
||||||
|
|
||||||
# match on interwiki links, e.g. `en:` or `:fr:`
|
# match on interwiki links, e.g. `en:` or `:fr:`
|
||||||
ns_regex = r":?" + "[a-z][a-z]" + ":"
|
ns_regex = r":?" + "[a-z][a-z]" + ":"
|
||||||
|
|
||||||
# match on Namespace: optionally preceded by a :
|
# match on Namespace: optionally preceded by a :
|
||||||
for ns in wiki_namespaces:
|
for ns in WP_META_NAMESPACE:
|
||||||
ns_regex += "|" + ":?" + ns + ":"
|
ns_regex += "|" + ":?" + ns + ":"
|
||||||
|
|
||||||
ns_regex = re.compile(ns_regex, re.IGNORECASE)
|
ns_regex = re.compile(ns_regex, re.IGNORECASE)
|
||||||
|
|
||||||
|
files = r""
|
||||||
|
for f in WP_FILE_NAMESPACE:
|
||||||
|
files += "\[\[" + f + ":[^[\]]+]]" + "|"
|
||||||
|
files = files[0 : len(files) - 1]
|
||||||
|
file_regex = re.compile(files)
|
||||||
|
|
||||||
|
cats = r""
|
||||||
|
for c in WP_CATEGORY_NAMESPACE:
|
||||||
|
cats += "\[\[" + c + ":[^\[]*]]" + "|"
|
||||||
|
cats = cats[0 : len(cats) - 1]
|
||||||
|
category_regex = re.compile(cats)
|
||||||
|
|
||||||
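For intuition about the three patterns assembled above, here is a scaled-down, self-contained version that uses a couple of hand-picked namespace prefixes in place of the full lists from `wiki_namespaces.py`; the prefixes and test strings are illustrative:

```python
import re

# Scaled-down stand-ins for WP_META_NAMESPACE / WP_FILE_NAMESPACE / WP_CATEGORY_NAMESPACE
meta_ns = ["Wikipedia", "Help"]
file_ns = ["File", "Image"]
cat_ns = ["Category"]

ns_regex = r":?" + "[a-z][a-z]" + ":"
for ns in meta_ns:
    ns_regex += "|" + ":?" + ns + ":"
ns_regex = re.compile(ns_regex, re.IGNORECASE)

file_regex = re.compile("|".join(r"\[\[" + f + r":[^[\]]+]]" for f in file_ns))
category_regex = re.compile("|".join(r"\[\[" + c + r":[^\[]*]]" for c in cat_ns))

print(bool(ns_regex.match("fr:Paris")))          # True: interwiki prefix
print(bool(ns_regex.match("Help:Contents")))     # True: meta namespace
print(bool(ns_regex.match("Douglas Adams")))     # False: a normal article title
print(category_regex.sub("", "text [[Category:Novels]] end"))  # strips the category link
```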
|
|
||||||
def read_prior_probs(wikipedia_input, prior_prob_output, limit=None):
|
def read_prior_probs(wikipedia_input, prior_prob_output, limit=None):
|
||||||
"""
|
"""
|
||||||
Read the XML wikipedia data and parse out intra-wiki links to estimate prior probabilities.
|
Read the XML wikipedia data and parse out intra-wiki links to estimate prior probabilities.
|
||||||
The full file takes about 2h to parse 1100M lines.
|
The full file takes about 2-3h to parse 1100M lines.
|
||||||
It works relatively fast because it runs line by line, irrelevant of which article the intrawiki is from.
|
It works relatively fast because it runs line by line, irrelevant of which article the intrawiki is from,
|
||||||
|
though dev test articles are excluded in order not to get an artificially strong baseline.
|
||||||
"""
|
"""
|
||||||
|
cnt = 0
|
||||||
|
read_id = False
|
||||||
|
current_article_id = None
|
||||||
with bz2.open(wikipedia_input, mode="rb") as file:
|
with bz2.open(wikipedia_input, mode="rb") as file:
|
||||||
line = file.readline()
|
line = file.readline()
|
||||||
cnt = 0
|
|
||||||
while line and (not limit or cnt < limit):
|
while line and (not limit or cnt < limit):
|
||||||
if cnt % 25000000 == 0:
|
if cnt % 25000000 == 0 and cnt > 0:
|
||||||
logger.info("processed {} lines of Wikipedia XML dump".format(cnt))
|
logger.info("processed {} lines of Wikipedia XML dump".format(cnt))
|
||||||
clean_line = line.strip().decode("utf-8")
|
clean_line = line.strip().decode("utf-8")
|
||||||
|
|
||||||
|
# we attempt at reading the article's ID (but not the revision or contributor ID)
|
||||||
|
if "<revision>" in clean_line or "<contributor>" in clean_line:
|
||||||
|
read_id = False
|
||||||
|
if "<page>" in clean_line:
|
||||||
|
read_id = True
|
||||||
|
|
||||||
|
if read_id:
|
||||||
|
ids = id_regex.search(clean_line)
|
||||||
|
if ids:
|
||||||
|
current_article_id = ids[0]
|
||||||
|
|
||||||
|
# only processing prior probabilities from true training (non-dev) articles
|
||||||
|
if not is_dev(current_article_id):
|
||||||
aliases, entities, normalizations = get_wp_links(clean_line)
|
aliases, entities, normalizations = get_wp_links(clean_line)
|
||||||
for alias, entity, norm in zip(aliases, entities, normalizations):
|
for alias, entity, norm in zip(aliases, entities, normalizations):
|
||||||
_store_alias(alias, entity, normalize_alias=norm, normalize_entity=True)
|
_store_alias(
|
||||||
_store_alias(alias, entity, normalize_alias=norm, normalize_entity=True)
|
alias, entity, normalize_alias=norm, normalize_entity=True
|
||||||
|
)
|
||||||
|
|
||||||
line = file.readline()
|
line = file.readline()
|
||||||
cnt += 1
|
cnt += 1
|
||||||
logger.info("processed {} lines of Wikipedia XML dump".format(cnt))
|
logger.info("processed {} lines of Wikipedia XML dump".format(cnt))
|
||||||
|
logger.info("Finished. processed {} lines of Wikipedia XML dump".format(cnt))
|
||||||
|
|
||||||
# write all aliases and their entities and count occurrences to file
|
# write all aliases and their entities and count occurrences to file
|
||||||
with prior_prob_output.open("w", encoding="utf8") as outputfile:
|
with prior_prob_output.open("w", encoding="utf8") as outputfile:
|
||||||
|
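The prior-probability file written here stores one `alias|count|entity` row per observed link and is consumed later by `kb_creator`. A hedged sketch of how such counts turn into the per-alias prior probabilities stored in the KB (the real implementation lives in `kb_creator.py` and differs in details; the aliases, QIDs and counts below are made up):

```python
from collections import defaultdict

# Toy rows in the "alias|count|entity" format written by read_prior_probs()
rows = [
    ("Douglas Adams", 40, "Q42"),
    ("Douglas Adams", 2, "Q000001"),
    ("Adams", 10, "Q42"),
    ("Adams", 30, "Q000002"),
]

counts_per_alias = defaultdict(dict)
for alias, count, entity in rows:
    counts_per_alias[alias][entity] = counts_per_alias[alias].get(entity, 0) + count

# Normalize the counts into prior probabilities P(entity | alias)
priors = {
    alias: {entity: count / sum(entity_counts.values())
            for entity, count in entity_counts.items()}
    for alias, entity_counts in counts_per_alias.items()
}
print(priors["Douglas Adams"])  # e.g. {'Q42': 0.952..., 'Q000001': 0.047...}
```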
@@ -182,7 +139,7 @@ def get_wp_links(text):
|
||||||
match = match[2:][:-2].replace("_", " ").strip()
|
match = match[2:][:-2].replace("_", " ").strip()
|
||||||
|
|
||||||
if ns_regex.match(match):
|
if ns_regex.match(match):
|
||||||
pass # ignore namespaces at the beginning of the string
|
pass # ignore the entity if it points to a "meta" page
|
||||||
|
|
||||||
# this is a simple [[link]], with the alias the same as the mention
|
# this is a simple [[link]], with the alias the same as the mention
|
||||||
elif "|" not in match:
|
elif "|" not in match:
|
||||||
|
@@ -218,47 +175,382 @@ def _capitalize_first(text):
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def write_entity_counts(prior_prob_input, count_output, to_print=False):
|
def create_training_and_desc(
|
||||||
# Write entity counts for quick access later
|
wp_input, def_input, desc_output, training_output, parse_desc, limit=None
|
||||||
entity_to_count = dict()
|
):
|
||||||
total_count = 0
|
wp_to_id = io.read_title_to_id(def_input)
|
||||||
|
_process_wikipedia_texts(
|
||||||
with prior_prob_input.open("r", encoding="utf8") as prior_file:
|
wp_input, wp_to_id, desc_output, training_output, parse_desc, limit
|
||||||
# skip header
|
)
|
||||||
prior_file.readline()
|
|
||||||
line = prior_file.readline()
|
|
||||||
|
|
||||||
while line:
|
|
||||||
splits = line.replace("\n", "").split(sep="|")
|
|
||||||
# alias = splits[0]
|
|
||||||
count = int(splits[1])
|
|
||||||
entity = splits[2]
|
|
||||||
|
|
||||||
current_count = entity_to_count.get(entity, 0)
|
|
||||||
entity_to_count[entity] = current_count + count
|
|
||||||
|
|
||||||
total_count += count
|
|
||||||
|
|
||||||
line = prior_file.readline()
|
|
||||||
|
|
||||||
with count_output.open("w", encoding="utf8") as entity_file:
|
|
||||||
entity_file.write("entity" + "|" + "count" + "\n")
|
|
||||||
for entity, count in entity_to_count.items():
|
|
||||||
entity_file.write(entity + "|" + str(count) + "\n")
|
|
||||||
|
|
||||||
if to_print:
|
|
||||||
for entity, count in entity_to_count.items():
|
|
||||||
print("Entity count:", entity, count)
|
|
||||||
print("Total count:", total_count)
|
|
||||||
|
|
||||||
|
|
||||||
def get_all_frequencies(count_input):
|
def _process_wikipedia_texts(
|
||||||
entity_to_count = dict()
|
wikipedia_input, wp_to_id, output, training_output, parse_descriptions, limit=None
|
||||||
with count_input.open("r", encoding="utf8") as csvfile:
|
):
|
||||||
csvreader = csv.reader(csvfile, delimiter="|")
|
"""
|
||||||
# skip header
|
Read the XML wikipedia data to parse out training data:
|
||||||
next(csvreader)
|
raw text data + positive instances
|
||||||
for row in csvreader:
|
"""
|
||||||
entity_to_count[row[0]] = int(row[1])
|
|
||||||
|
|
||||||
return entity_to_count
|
read_ids = set()
|
||||||
|
|
||||||
|
with output.open("a", encoding="utf8") as descr_file, training_output.open(
|
||||||
|
"w", encoding="utf8"
|
||||||
|
) as entity_file:
|
||||||
|
if parse_descriptions:
|
||||||
|
_write_training_description(descr_file, "WD_id", "description")
|
||||||
|
with bz2.open(wikipedia_input, mode="rb") as file:
|
||||||
|
article_count = 0
|
||||||
|
article_text = ""
|
||||||
|
article_title = None
|
||||||
|
article_id = None
|
||||||
|
reading_text = False
|
||||||
|
reading_revision = False
|
||||||
|
|
||||||
|
for line in file:
|
||||||
|
clean_line = line.strip().decode("utf-8")
|
||||||
|
|
||||||
|
if clean_line == "<revision>":
|
||||||
|
reading_revision = True
|
||||||
|
elif clean_line == "</revision>":
|
||||||
|
reading_revision = False
|
||||||
|
|
||||||
|
# Start reading new page
|
||||||
|
if clean_line == "<page>":
|
||||||
|
article_text = ""
|
||||||
|
article_title = None
|
||||||
|
article_id = None
|
||||||
|
# finished reading this page
|
||||||
|
elif clean_line == "</page>":
|
||||||
|
if article_id:
|
||||||
|
clean_text, entities = _process_wp_text(
|
||||||
|
article_title, article_text, wp_to_id
|
||||||
|
)
|
||||||
|
if clean_text is not None and entities is not None:
|
||||||
|
_write_training_entities(
|
||||||
|
entity_file, article_id, clean_text, entities
|
||||||
|
)
|
||||||
|
|
||||||
|
if article_title in wp_to_id and parse_descriptions:
|
||||||
|
description = " ".join(
|
||||||
|
clean_text[:1000].split(" ")[:-1]
|
||||||
|
)
|
||||||
|
_write_training_description(
|
||||||
|
descr_file, wp_to_id[article_title], description
|
||||||
|
)
|
||||||
|
article_count += 1
|
||||||
|
if article_count % 10000 == 0 and article_count > 0:
|
||||||
|
logger.info(
|
||||||
|
"Processed {} articles".format(article_count)
|
||||||
|
)
|
||||||
|
if limit and article_count >= limit:
|
||||||
|
break
|
||||||
|
article_text = ""
|
||||||
|
article_title = None
|
||||||
|
article_id = None
|
||||||
|
reading_text = False
|
||||||
|
reading_revision = False
|
||||||
|
|
||||||
|
# start reading text within a page
|
||||||
|
if "<text" in clean_line:
|
||||||
|
reading_text = True
|
||||||
|
|
||||||
|
if reading_text:
|
||||||
|
article_text += " " + clean_line
|
||||||
|
|
||||||
|
# stop reading text within a page (we assume a new page doesn't start on the same line)
|
||||||
|
if "</text" in clean_line:
|
||||||
|
reading_text = False
|
||||||
|
|
||||||
|
# read the ID of this article (outside the revision portion of the document)
|
||||||
|
if not reading_revision:
|
||||||
|
ids = id_regex.search(clean_line)
|
||||||
|
if ids:
|
||||||
|
article_id = ids[0]
|
||||||
|
if article_id in read_ids:
|
||||||
|
logger.info(
|
||||||
|
"Found duplicate article ID", article_id, clean_line
|
||||||
|
) # This should never happen ...
|
||||||
|
read_ids.add(article_id)
|
||||||
|
|
||||||
|
# read the title of this article (outside the revision portion of the document)
|
||||||
|
if not reading_revision:
|
||||||
|
titles = title_regex.search(clean_line)
|
||||||
|
if titles:
|
||||||
|
article_title = titles[0].strip()
|
||||||
|
logger.info("Finished. Processed {} articles".format(article_count))
|
||||||
|
|
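The per-line extraction above leans on the `id_regex` and `title_regex` patterns defined at the top of the module. A quick illustration of what they pull out of typical dump lines (the example lines are made up, but follow the standard `pages-articles` XML shape):

```python
import re

title_regex = re.compile(r"(?<=<title>).*(?=</title>)")
id_regex = re.compile(r"(?<=<id>)\d*(?=</id>)")

print(title_regex.search("    <title>Douglas Adams</title>")[0])  # 'Douglas Adams'
print(id_regex.search("    <id>8091</id>")[0])                    # '8091'
print(id_regex.search("    <ns>0</ns>"))                          # None: not an <id> line
```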
||||||
|
|
||||||
|
def _process_wp_text(article_title, article_text, wp_to_id):
|
||||||
|
# ignore meta Wikipedia pages
|
||||||
|
if ns_regex.match(article_title):
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
# remove the text tags
|
||||||
|
text_search = text_regex.search(article_text)
|
||||||
|
if text_search is None:
|
||||||
|
return None, None
|
||||||
|
text = text_search.group(0)
|
||||||
|
|
||||||
|
# stop processing if this is a redirect page
|
||||||
|
if text.startswith("#REDIRECT"):
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
# get the raw text without markup etc, keeping only interwiki links
|
||||||
|
clean_text, entities = _remove_links(_get_clean_wp_text(text), wp_to_id)
|
||||||
|
return clean_text, entities
|
||||||
|
|
||||||
|
|
||||||
|
def _get_clean_wp_text(article_text):
|
||||||
|
clean_text = article_text.strip()
|
||||||
|
|
||||||
|
# remove bolding & italic markup
|
||||||
|
clean_text = clean_text.replace("'''", "")
|
||||||
|
clean_text = clean_text.replace("''", "")
|
||||||
|
|
||||||
|
# remove nested {{info}} statements by removing the inner/smallest ones first and iterating
|
||||||
|
try_again = True
|
||||||
|
previous_length = len(clean_text)
|
||||||
|
while try_again:
|
||||||
|
clean_text = info_regex.sub(
|
||||||
|
"", clean_text
|
||||||
|
) # non-greedy match excluding a nested {
|
||||||
|
if len(clean_text) < previous_length:
|
||||||
|
try_again = True
|
||||||
|
else:
|
||||||
|
try_again = False
|
||||||
|
previous_length = len(clean_text)
|
||||||
|
|
||||||
|
# remove HTML comments
|
||||||
|
clean_text = html_regex.sub("", clean_text)
|
||||||
|
|
||||||
|
# remove Category and File statements
|
||||||
|
clean_text = category_regex.sub("", clean_text)
|
||||||
|
clean_text = file_regex.sub("", clean_text)
|
||||||
|
|
||||||
|
# remove multiple =
|
||||||
|
while "==" in clean_text:
|
||||||
|
clean_text = clean_text.replace("==", "=")
|
||||||
|
|
||||||
|
clean_text = clean_text.replace(". =", ".")
|
||||||
|
clean_text = clean_text.replace(" = ", ". ")
|
||||||
|
clean_text = clean_text.replace("= ", ".")
|
||||||
|
clean_text = clean_text.replace(" =", "")
|
||||||
|
|
||||||
|
# remove refs (non-greedy match)
|
||||||
|
clean_text = ref_regex.sub("", clean_text)
|
||||||
|
clean_text = ref_2_regex.sub("", clean_text)
|
||||||
|
|
||||||
|
# remove additional wikiformatting
|
||||||
|
clean_text = re.sub(r"<blockquote>", "", clean_text)
|
||||||
|
clean_text = re.sub(r"</blockquote>", "", clean_text)
|
||||||
|
|
||||||
|
# change special characters back to normal ones
|
||||||
|
clean_text = clean_text.replace(r"<", "<")
|
||||||
|
clean_text = clean_text.replace(r">", ">")
|
||||||
|
clean_text = clean_text.replace(r""", '"')
|
||||||
|
clean_text = clean_text.replace(r"&nbsp;", " ")
|
||||||
|
clean_text = clean_text.replace(r"&", "&")
|
||||||
|
|
||||||
|
# remove multiple spaces
|
||||||
|
while " " in clean_text:
|
||||||
|
clean_text = clean_text.replace(" ", " ")
|
||||||
|
|
||||||
|
return clean_text.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def _remove_links(clean_text, wp_to_id):
|
||||||
|
# read the text char by char to get the right offsets for the interwiki links
|
||||||
|
entities = []
|
||||||
|
final_text = ""
|
||||||
|
open_read = 0
|
||||||
|
reading_text = True
|
||||||
|
reading_entity = False
|
||||||
|
reading_mention = False
|
||||||
|
reading_special_case = False
|
||||||
|
entity_buffer = ""
|
||||||
|
mention_buffer = ""
|
||||||
|
for index, letter in enumerate(clean_text):
|
||||||
|
if letter == "[":
|
||||||
|
open_read += 1
|
||||||
|
elif letter == "]":
|
||||||
|
open_read -= 1
|
||||||
|
elif letter == "|":
|
||||||
|
if reading_text:
|
||||||
|
final_text += letter
|
||||||
|
# switch from reading entity to mention in the [[entity|mention]] pattern
|
||||||
|
elif reading_entity:
|
||||||
|
reading_text = False
|
||||||
|
reading_entity = False
|
||||||
|
reading_mention = True
|
||||||
|
else:
|
||||||
|
reading_special_case = True
|
||||||
|
else:
|
||||||
|
if reading_entity:
|
||||||
|
entity_buffer += letter
|
||||||
|
elif reading_mention:
|
||||||
|
mention_buffer += letter
|
||||||
|
elif reading_text:
|
||||||
|
final_text += letter
|
||||||
|
else:
|
||||||
|
raise ValueError("Not sure at point", clean_text[index - 2 : index + 2])
|
||||||
|
|
||||||
|
if open_read > 2:
|
||||||
|
reading_special_case = True
|
||||||
|
|
||||||
|
if open_read == 2 and reading_text:
|
||||||
|
reading_text = False
|
||||||
|
reading_entity = True
|
||||||
|
reading_mention = False
|
||||||
|
|
||||||
|
# we just finished reading an entity
|
||||||
|
if open_read == 0 and not reading_text:
|
||||||
|
if "#" in entity_buffer or entity_buffer.startswith(":"):
|
||||||
|
reading_special_case = True
|
||||||
|
# Ignore cases with nested structures like File: handles etc
|
||||||
|
if not reading_special_case:
|
||||||
|
if not mention_buffer:
|
||||||
|
mention_buffer = entity_buffer
|
||||||
|
start = len(final_text)
|
||||||
|
end = start + len(mention_buffer)
|
||||||
|
qid = wp_to_id.get(entity_buffer, None)
|
||||||
|
if qid:
|
||||||
|
entities.append((mention_buffer, qid, start, end))
|
||||||
|
final_text += mention_buffer
|
||||||
|
|
||||||
|
entity_buffer = ""
|
||||||
|
mention_buffer = ""
|
||||||
|
|
||||||
|
reading_text = True
|
||||||
|
reading_entity = False
|
||||||
|
reading_mention = False
|
||||||
|
reading_special_case = False
|
||||||
|
return final_text, entities
|
||||||
|
|
||||||
|
|
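To make the state machine above concrete, here is what `_remove_links` produces for a small piece of markup, assuming you run from the spaCy repo root so that `bin.wiki_entity_linking` is importable, and assuming a `wp_to_id` mapping that knows the linked page (the QID is used only for illustration):

```python
from bin.wiki_entity_linking.wikipedia_processor import _remove_links

wp_to_id = {"Douglas Adams": "Q42"}  # normally read from the entity definitions file
text = "The book was written by [[Douglas Adams|the author]] in [[1979]]."

final_text, entities = _remove_links(text, wp_to_id)
print(final_text)  # 'The book was written by the author in 1979.'
print(entities)    # [('the author', 'Q42', 24, 34)] -- '1979' has no QID in wp_to_id
```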
||||||
|
def _write_training_description(outputfile, qid, description):
|
||||||
|
if description is not None:
|
||||||
|
line = str(qid) + "|" + description + "\n"
|
||||||
|
outputfile.write(line)
|
||||||
|
|
||||||
|
|
||||||
|
def _write_training_entities(outputfile, article_id, clean_text, entities):
|
||||||
|
entities_data = [
|
||||||
|
{"alias": ent[0], "entity": ent[1], "start": ent[2], "end": ent[3]}
|
||||||
|
for ent in entities
|
||||||
|
]
|
||||||
|
line = (
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"article_id": article_id,
|
||||||
|
"clean_text": clean_text,
|
||||||
|
"entities": entities_data,
|
||||||
|
},
|
||||||
|
ensure_ascii=False,
|
||||||
|
)
|
||||||
|
+ "\n"
|
||||||
|
)
|
||||||
|
outputfile.write(line)
|
||||||
|
|
||||||
|
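Each line that `_write_training_entities` appends to the training file is a standalone JSON document, which `read_training` below parses back one article at a time. A made-up example of one such line, built the same way the function builds it:

```python
import json

line = json.dumps(
    {
        "article_id": "8091",
        "clean_text": "The book was written by the author in 1979.",
        "entities": [{"alias": "the author", "entity": "Q42", "start": 24, "end": 34}],
    },
    ensure_ascii=False,
) + "\n"
print(line, end="")  # one JSONL record: article_id, clean_text, entities
```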
|
||||||
|
def read_training(nlp, entity_file_path, dev, limit, kb, labels_discard=None):
|
||||||
|
""" This method provides training examples that correspond to the entity annotations found by the nlp object.
|
||||||
|
For training, it will include both positive and negative examples by using the candidate generator from the kb.
|
||||||
|
For testing (kb=None), it will include all positive examples only."""
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
if not labels_discard:
|
||||||
|
labels_discard = []
|
||||||
|
|
||||||
|
data = []
|
||||||
|
num_entities = 0
|
||||||
|
get_gold_parse = partial(
|
||||||
|
_get_gold_parse, dev=dev, kb=kb, labels_discard=labels_discard
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Reading {} data with limit {}".format("dev" if dev else "train", limit)
|
||||||
|
)
|
||||||
|
with entity_file_path.open("r", encoding="utf8") as file:
|
||||||
|
with tqdm(total=limit, leave=False) as pbar:
|
||||||
|
for i, line in enumerate(file):
|
||||||
|
example = json.loads(line)
|
||||||
|
article_id = example["article_id"]
|
||||||
|
clean_text = example["clean_text"]
|
||||||
|
entities = example["entities"]
|
||||||
|
|
||||||
|
if dev != is_dev(article_id) or not is_valid_article(clean_text):
|
||||||
|
continue
|
||||||
|
|
||||||
|
doc = nlp(clean_text)
|
||||||
|
gold = get_gold_parse(doc, entities)
|
||||||
|
if gold and len(gold.links) > 0:
|
||||||
|
data.append((doc, gold))
|
||||||
|
num_entities += len(gold.links)
|
||||||
|
pbar.update(len(gold.links))
|
||||||
|
if limit and num_entities >= limit:
|
||||||
|
break
|
||||||
|
logger.info("Read {} entities in {} articles".format(num_entities, len(data)))
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def _get_gold_parse(doc, entities, dev, kb, labels_discard):
|
||||||
|
gold_entities = {}
|
||||||
|
tagged_ent_positions = {
|
||||||
|
(ent.start_char, ent.end_char): ent
|
||||||
|
for ent in doc.ents
|
||||||
|
if ent.label_ not in labels_discard
|
||||||
|
}
|
||||||
|
|
||||||
|
for entity in entities:
|
||||||
|
entity_id = entity["entity"]
|
||||||
|
alias = entity["alias"]
|
||||||
|
start = entity["start"]
|
||||||
|
end = entity["end"]
|
||||||
|
|
||||||
|
candidate_ids = []
|
||||||
|
if kb and not dev:
|
||||||
|
candidates = kb.get_candidates(alias)
|
||||||
|
candidate_ids = [cand.entity_ for cand in candidates]
|
||||||
|
|
||||||
|
tagged_ent = tagged_ent_positions.get((start, end), None)
|
||||||
|
if tagged_ent:
|
||||||
|
# TODO: check that alias == doc.text[start:end]
|
||||||
|
should_add_ent = (dev or entity_id in candidate_ids) and is_valid_sentence(
|
||||||
|
tagged_ent.sent.text
|
||||||
|
)
|
||||||
|
|
||||||
|
if should_add_ent:
|
||||||
|
value_by_id = {entity_id: 1.0}
|
||||||
|
if not dev:
|
||||||
|
random.shuffle(candidate_ids)
|
||||||
|
value_by_id.update(
|
||||||
|
{kb_id: 0.0 for kb_id in candidate_ids if kb_id != entity_id}
|
||||||
|
)
|
||||||
|
gold_entities[(start, end)] = value_by_id
|
||||||
|
|
||||||
|
return GoldParse(doc, links=gold_entities)
|
||||||
|
|
||||||
|
|
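The `GoldParse` objects built here encode entity links as a dict keyed by character offsets, mapping each candidate QID to 1.0 for the gold entity and 0.0 for the negative candidates pulled from the KB. A hedged illustration with made-up QIDs and a blank pipeline:

```python
import spacy
from spacy.gold import GoldParse

nlp = spacy.blank("en")
doc = nlp("The book was written by the author in 1979.")

# 1.0 marks the gold entity, 0.0 the negative candidates suggested by the KB
links = {(24, 34): {"Q42": 1.0, "Q000001": 0.0}}
gold = GoldParse(doc, links=links)
print(gold.links)  # e.g. {(24, 34): {'Q42': 1.0, 'Q000001': 0.0}}
```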
||||||
|
def is_dev(article_id):
|
||||||
|
if not article_id:
|
||||||
|
return False
|
||||||
|
return article_id.endswith("3")
|
||||||
|
|
||||||
|
|
||||||
|
def is_valid_article(doc_text):
|
||||||
|
# custom length cut-off
|
||||||
|
return 10 < len(doc_text) < 30000
|
||||||
|
|
||||||
|
|
||||||
|
def is_valid_sentence(sent_text):
|
||||||
|
if not 10 < len(sent_text) < 3000:
|
||||||
|
# custom length cut-off
|
||||||
|
return False
|
||||||
|
|
||||||
|
if sent_text.strip().startswith("*") or sent_text.strip().startswith("#"):
|
||||||
|
# remove 'enumeration' sentences (occurs often on Wikipedia)
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
|
@@ -7,7 +7,7 @@ dependency tree to find the noun phrase they are referring to – for example:
|
||||||
$9.4 million --> Net income.
|
$9.4 million --> Net income.
|
||||||
|
|
||||||
Compatible with: spaCy v2.0.0+
|
Compatible with: spaCy v2.0.0+
|
||||||
Last tested with: v2.1.0
|
Last tested with: v2.2.1
|
||||||
"""
|
"""
|
||||||
from __future__ import unicode_literals, print_function
|
from __future__ import unicode_literals, print_function
|
||||||
|
|
||||||
|
@@ -38,14 +38,17 @@ def main(model="en_core_web_sm"):
|
||||||
|
|
||||||
def filter_spans(spans):
|
def filter_spans(spans):
|
||||||
# Filter a sequence of spans so they don't contain overlaps
|
# Filter a sequence of spans so they don't contain overlaps
|
||||||
get_sort_key = lambda span: (span.end - span.start, span.start)
|
# For spaCy 2.1.4+: this function is available as spacy.util.filter_spans()
|
||||||
|
get_sort_key = lambda span: (span.end - span.start, -span.start)
|
||||||
sorted_spans = sorted(spans, key=get_sort_key, reverse=True)
|
sorted_spans = sorted(spans, key=get_sort_key, reverse=True)
|
||||||
result = []
|
result = []
|
||||||
seen_tokens = set()
|
seen_tokens = set()
|
||||||
for span in sorted_spans:
|
for span in sorted_spans:
|
||||||
|
# Check for end - 1 here because boundaries are inclusive
|
||||||
if span.start not in seen_tokens and span.end - 1 not in seen_tokens:
|
if span.start not in seen_tokens and span.end - 1 not in seen_tokens:
|
||||||
result.append(span)
|
result.append(span)
|
||||||
seen_tokens.update(range(span.start, span.end))
|
seen_tokens.update(range(span.start, span.end))
|
||||||
|
result = sorted(result, key=lambda span: span.start)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
|
|
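The updated `filter_spans` keeps the longest spans first, drops any span that overlaps one already kept, and returns the survivors sorted by start; from spaCy 2.1.4 the same helper ships as `spacy.util.filter_spans`. A small usage sketch with made-up spans:

```python
import spacy
from spacy.util import filter_spans

nlp = spacy.blank("en")
doc = nlp("The New York City subway is crowded.")

# "City subway", "New York City", "is crowded", "New York"
overlapping = [doc[3:5], doc[1:4], doc[5:7], doc[1:3]]
print([span.text for span in filter_spans(overlapping)])
# ['New York City', 'is crowded'] -- overlaps removed, longest spans win
```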
@@ -91,8 +91,8 @@ def demo(shape):
|
||||||
nlp = spacy.load("en_vectors_web_lg")
|
nlp = spacy.load("en_vectors_web_lg")
|
||||||
nlp.add_pipe(KerasSimilarityShim.load(nlp.path / "similarity", nlp, shape[0]))
|
nlp.add_pipe(KerasSimilarityShim.load(nlp.path / "similarity", nlp, shape[0]))
|
||||||
|
|
||||||
doc1 = nlp(u"The king of France is bald.")
|
doc1 = nlp("The king of France is bald.")
|
||||||
doc2 = nlp(u"France has no king.")
|
doc2 = nlp("France has no king.")
|
||||||
|
|
||||||
print("Sentence 1:", doc1)
|
print("Sentence 1:", doc1)
|
||||||
print("Sentence 2:", doc2)
|
print("Sentence 2:", doc2)
|
||||||
|
|
|
@@ -8,7 +8,7 @@
|
||||||
{
|
{
|
||||||
"tokens": [
|
"tokens": [
|
||||||
{
|
{
|
||||||
"head": 4,
|
"head": 44,
|
||||||
"dep": "prep",
|
"dep": "prep",
|
||||||
"tag": "IN",
|
"tag": "IN",
|
||||||
"orth": "In",
|
"orth": "In",
|
||||||
|
|
|
@@ -11,7 +11,7 @@ numpy>=1.15.0
|
||||||
requests>=2.13.0,<3.0.0
|
requests>=2.13.0,<3.0.0
|
||||||
plac<1.0.0,>=0.9.6
|
plac<1.0.0,>=0.9.6
|
||||||
pathlib==1.0.1; python_version < "3.4"
|
pathlib==1.0.1; python_version < "3.4"
|
||||||
importlib_metadata>=0.23; python_version < "3.8"
|
importlib_metadata>=0.20; python_version < "3.8"
|
||||||
# Optional dependencies
|
# Optional dependencies
|
||||||
jsonschema>=2.6.0,<3.1.0
|
jsonschema>=2.6.0,<3.1.0
|
||||||
# Development dependencies
|
# Development dependencies
|
||||||
|
|
|
@@ -51,7 +51,7 @@ install_requires =
|
||||||
wasabi>=0.2.0,<1.1.0
|
wasabi>=0.2.0,<1.1.0
|
||||||
srsly>=0.1.0,<1.1.0
|
srsly>=0.1.0,<1.1.0
|
||||||
pathlib==1.0.1; python_version < "3.4"
|
pathlib==1.0.1; python_version < "3.4"
|
||||||
importlib_metadata>=0.23; python_version < "3.8"
|
importlib_metadata>=0.20; python_version < "3.8"
|
||||||
|
|
||||||
[options.extras_require]
|
[options.extras_require]
|
||||||
lookups =
|
lookups =
|
||||||
|
|
|
@@ -57,7 +57,8 @@ def convert(
|
||||||
is written to stdout, so you can pipe them forward to a JSON file:
|
is written to stdout, so you can pipe them forward to a JSON file:
|
||||||
$ spacy convert some_file.conllu > some_file.json
|
$ spacy convert some_file.conllu > some_file.json
|
||||||
"""
|
"""
|
||||||
msg = Printer()
|
no_print = (output_dir == "-")
|
||||||
|
msg = Printer(no_print=no_print)
|
||||||
input_path = Path(input_file)
|
input_path = Path(input_file)
|
||||||
if file_type not in FILE_TYPES:
|
if file_type not in FILE_TYPES:
|
||||||
msg.fail(
|
msg.fail(
|
||||||
|
@@ -102,6 +103,7 @@ def convert(
|
||||||
use_morphology=morphology,
|
use_morphology=morphology,
|
||||||
lang=lang,
|
lang=lang,
|
||||||
model=model,
|
model=model,
|
||||||
|
no_print=no_print,
|
||||||
)
|
)
|
||||||
if output_dir != "-":
|
if output_dir != "-":
|
||||||
# Export data to a file
|
# Export data to a file
|
||||||
|
|
|
@@ -9,7 +9,7 @@ from ...tokens.doc import Doc
|
||||||
from ...util import load_model
|
from ...util import load_model
|
||||||
|
|
||||||
|
|
||||||
def conll_ner2json(input_data, n_sents=10, seg_sents=False, model=None, **kwargs):
|
def conll_ner2json(input_data, n_sents=10, seg_sents=False, model=None, no_print=False, **kwargs):
|
||||||
"""
|
"""
|
||||||
Convert files in the CoNLL-2003 NER format and similar
|
Convert files in the CoNLL-2003 NER format and similar
|
||||||
whitespace-separated columns into JSON format for use with train cli.
|
whitespace-separated columns into JSON format for use with train cli.
|
||||||
|
@@ -34,7 +34,7 @@ def conll_ner2json(input_data, n_sents=10, seg_sents=False, model=None, **kwargs
|
||||||
. O
|
. O
|
||||||
|
|
||||||
"""
|
"""
|
||||||
msg = Printer()
|
msg = Printer(no_print=no_print)
|
||||||
doc_delimiter = "-DOCSTART- -X- O O"
|
doc_delimiter = "-DOCSTART- -X- O O"
|
||||||
# check for existing delimiters, which should be preserved
|
# check for existing delimiters, which should be preserved
|
||||||
if "\n\n" in input_data and seg_sents:
|
if "\n\n" in input_data and seg_sents:
|
||||||
|
|
|
@@ -8,7 +8,7 @@ from ...util import minibatch
|
||||||
from .conll_ner2json import n_sents_info
|
from .conll_ner2json import n_sents_info
|
||||||
|
|
||||||
|
|
||||||
def iob2json(input_data, n_sents=10, *args, **kwargs):
|
def iob2json(input_data, n_sents=10, no_print=False, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
Convert IOB files with one sentence per line and tags separated with '|'
|
Convert IOB files with one sentence per line and tags separated with '|'
|
||||||
into JSON format for use with train cli. IOB and IOB2 are accepted.
|
into JSON format for use with train cli. IOB and IOB2 are accepted.
|
||||||
|
@@ -20,7 +20,7 @@ def iob2json(input_data, n_sents=10, *args, **kwargs):
|
||||||
I|PRP|O like|VBP|O London|NNP|I-GPE and|CC|O New|NNP|B-GPE York|NNP|I-GPE City|NNP|I-GPE .|.|O
|
I|PRP|O like|VBP|O London|NNP|I-GPE and|CC|O New|NNP|B-GPE York|NNP|I-GPE City|NNP|I-GPE .|.|O
|
||||||
I|PRP|O like|VBP|O London|NNP|B-GPE and|CC|O New|NNP|B-GPE York|NNP|I-GPE City|NNP|I-GPE .|.|O
|
I|PRP|O like|VBP|O London|NNP|B-GPE and|CC|O New|NNP|B-GPE York|NNP|I-GPE City|NNP|I-GPE .|.|O
|
||||||
"""
|
"""
|
||||||
msg = Printer()
|
msg = Printer(no_print=no_print)
|
||||||
docs = read_iob(input_data.split("\n"))
|
docs = read_iob(input_data.split("\n"))
|
||||||
if n_sents > 0:
|
if n_sents > 0:
|
||||||
n_sents_info(msg, n_sents)
|
n_sents_info(msg, n_sents)
|
||||||
|
|
|
@@ -360,6 +360,16 @@ def debug_data(
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# check for documents with multiple sentences
|
||||||
|
sents_per_doc = gold_train_data["n_sents"] / len(gold_train_data["texts"])
|
||||||
|
if sents_per_doc < 1.1:
|
||||||
|
msg.warn(
|
||||||
|
"The training data contains {:.2f} sentences per "
|
||||||
|
"document. When there are very few documents containing more "
|
||||||
|
"than one sentence, the parser will not learn how to segment "
|
||||||
|
"longer texts into sentences.".format(sents_per_doc)
|
||||||
|
)
|
||||||
|
|
||||||
# profile labels
|
# profile labels
|
||||||
labels_train = [label for label in gold_train_data["deps"]]
|
labels_train = [label for label in gold_train_data["deps"]]
|
||||||
labels_train_unpreprocessed = [
|
labels_train_unpreprocessed = [
|
||||||
|
|
|
@@ -246,7 +246,7 @@ def make_update(model, docs, optimizer, drop=0.0, objective="L2"):
     """Perform an update over a single batch of documents.

     docs (iterable): A batch of `Doc` objects.
-    drop (float): The droput rate.
+    drop (float): The dropout rate.
     optimizer (callable): An optimizer.
     RETURNS loss: A float for the loss.
     """

@@ -80,8 +80,8 @@ class Warnings(object):
             "the v2.x models cannot release the global interpreter lock. "
             "Future versions may introduce a `n_process` argument for "
             "parallel inference via multiprocessing.")
-    W017 = ("Alias '{alias}' already exists in the Knowledge base.")
-    W018 = ("Entity '{entity}' already exists in the Knowledge base.")
+    W017 = ("Alias '{alias}' already exists in the Knowledge Base.")
+    W018 = ("Entity '{entity}' already exists in the Knowledge Base.")
     W019 = ("Changing vectors name from {old} to {new}, to avoid clash with "
             "previously loaded vectors. See Issue #3853.")
     W020 = ("Unnamed vectors. This won't allow multiple vectors models to be "
@@ -95,7 +95,10 @@ class Warnings(object):
             "you can ignore this warning by setting SPACY_WARNING_IGNORE=W022. "
             "If this is surprising, make sure you have the spacy-lookups-data "
             "package installed.")
-    W023 = ("Multiprocessing of Language.pipe is not supported in Python2. 'n_process' will be set to 1.")
+    W023 = ("Multiprocessing of Language.pipe is not supported in Python 2. "
+            "'n_process' will be set to 1.")
+    W024 = ("Entity '{entity}' - Alias '{alias}' combination already exists in "
+            "the Knowledge Base.")


 @add_codes

@@ -408,7 +411,7 @@ class Errors(object):
             "{probabilities_length} respectively.")
     E133 = ("The sum of prior probabilities for alias '{alias}' should not "
             "exceed 1, but found {sum}.")
-    E134 = ("Alias '{alias}' defined for unknown entity '{entity}'.")
+    E134 = ("Entity '{entity}' is not defined in the Knowledge Base.")
     E135 = ("If you meant to replace a built-in component, use `create_pipe`: "
             "`nlp.replace_pipe('{name}', nlp.create_pipe('{name}'))`")
     E136 = ("This additional feature requires the jsonschema library to be "
@@ -420,7 +423,7 @@ class Errors(object):
     E138 = ("Invalid JSONL format for raw text '{text}'. Make sure the input "
             "includes either the `text` or `tokens` key. For more info, see "
             "the docs:\nhttps://spacy.io/api/cli#pretrain-jsonl")
-    E139 = ("Knowledge base for component '{name}' not initialized. Did you "
+    E139 = ("Knowledge Base for component '{name}' not initialized. Did you "
             "forget to call set_kb()?")
     E140 = ("The list of entities, prior probabilities and entity vectors "
             "should be of equal length.")
@@ -498,6 +501,8 @@ class Errors(object):
             "details: https://spacy.io/api/lemmatizer#init")
     E174 = ("Architecture '{name}' not found in registry. Available "
             "names: {names}")
+    E175 = ("Can't remove rule for unknown match pattern ID: {key}")
+    E176 = ("Alias '{alias}' is not defined in the Knowledge Base.")


 @add_codes

@@ -743,7 +743,8 @@ def docs_to_json(docs, id=0):

     docs (iterable / Doc): The Doc object(s) to convert.
     id (int): Id for the JSON.
-    RETURNS (list): The data in spaCy's JSON format.
+    RETURNS (dict): The data in spaCy's JSON format
+                    - each input doc will be treated as a paragraph in the output doc
     """
     if isinstance(docs, Doc):
         docs = [docs]

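
Usage sketch (illustrative, not part of the diff): with the change above, docs_to_json returns a single dict and each input Doc becomes one paragraph. The "paragraphs" key is assumed from spaCy's JSON training format.

    import spacy
    from spacy.gold import docs_to_json

    nlp = spacy.blank("en")
    json_data = docs_to_json([nlp("Hello world."), nlp("Another doc.")], id=0)
    assert isinstance(json_data, dict)
    assert len(json_data["paragraphs"]) == 2  # one paragraph per input Doc
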
69 spacy/kb.pyx
@@ -142,6 +142,7 @@ cdef class KnowledgeBase:

        i = 0
        cdef KBEntryC entry
+        cdef hash_t entity_hash
        while i < nr_entities:
            entity_vector = vector_list[i]
            if len(entity_vector) != self.entity_vector_length:
@@ -161,6 +162,14 @@ cdef class KnowledgeBase:

            i += 1

+    def contains_entity(self, unicode entity):
+        cdef hash_t entity_hash = self.vocab.strings.add(entity)
+        return entity_hash in self._entry_index
+
+    def contains_alias(self, unicode alias):
+        cdef hash_t alias_hash = self.vocab.strings.add(alias)
+        return alias_hash in self._alias_index
+
     def add_alias(self, unicode alias, entities, probabilities):
         """
         For a given alias, add its potential entities and prior probabilies to the KB.
@@ -190,7 +199,7 @@ cdef class KnowledgeBase:
        for entity, prob in zip(entities, probabilities):
            entity_hash = self.vocab.strings[entity]
            if not entity_hash in self._entry_index:
-                raise ValueError(Errors.E134.format(alias=alias, entity=entity))
+                raise ValueError(Errors.E134.format(entity=entity))

            entry_index = <int64_t>self._entry_index.get(entity_hash)
            entry_indices.push_back(int(entry_index))
@@ -201,8 +210,63 @@ cdef class KnowledgeBase:

        return alias_hash

-    def get_candidates(self, unicode alias):
+    def append_alias(self, unicode alias, unicode entity, float prior_prob, ignore_warnings=False):
+        """
+        For an alias already existing in the KB, extend its potential entities with one more.
+        Throw a warning if either the alias or the entity is unknown,
+        or when the combination is already previously recorded.
+        Throw an error if this entity+prior prob would exceed the sum of 1.
+        For efficiency, it's best to use the method `add_alias` as much as possible instead of this one.
+        """
+        # Check if the alias exists in the KB
         cdef hash_t alias_hash = self.vocab.strings[alias]
+        if not alias_hash in self._alias_index:
+            raise ValueError(Errors.E176.format(alias=alias))
+
+        # Check if the entity exists in the KB
+        cdef hash_t entity_hash = self.vocab.strings[entity]
+        if not entity_hash in self._entry_index:
+            raise ValueError(Errors.E134.format(entity=entity))
+        entry_index = <int64_t>self._entry_index.get(entity_hash)
+
+        # Throw an error if the prior probabilities (including the new one) sum up to more than 1
+        alias_index = <int64_t>self._alias_index.get(alias_hash)
+        alias_entry = self._aliases_table[alias_index]
+        current_sum = sum([p for p in alias_entry.probs])
+        new_sum = current_sum + prior_prob
+
+        if new_sum > 1.00001:
+            raise ValueError(Errors.E133.format(alias=alias, sum=new_sum))
+
+        entry_indices = alias_entry.entry_indices
+
+        is_present = False
+        for i in range(entry_indices.size()):
+            if entry_indices[i] == int(entry_index):
+                is_present = True
+
+        if is_present:
+            if not ignore_warnings:
+                user_warning(Warnings.W024.format(entity=entity, alias=alias))
+        else:
+            entry_indices.push_back(int(entry_index))
+            alias_entry.entry_indices = entry_indices
+
+            probs = alias_entry.probs
+            probs.push_back(float(prior_prob))
+            alias_entry.probs = probs
+            self._aliases_table[alias_index] = alias_entry
+
+
+    def get_candidates(self, unicode alias):
+        """
+        Return candidate entities for an alias. Each candidate defines the entity, the original alias,
+        and the prior probability of that alias resolving to that entity.
+        If the alias is not known in the KB, and empty list is returned.
+        """
+        cdef hash_t alias_hash = self.vocab.strings[alias]
+        if not alias_hash in self._alias_index:
+            return []
         alias_index = <int64_t>self._alias_index.get(alias_hash)
         alias_entry = self._aliases_table[alias_index]

@@ -341,7 +405,6 @@ cdef class KnowledgeBase:
        assert nr_entities == self.get_size_entities()

        # STEP 3: load aliases
        cdef int64_t nr_aliases
        reader.read_alias_length(&nr_aliases)
        self._alias_index = PreshMap(nr_aliases+1)

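
Usage sketch (illustrative, not part of the diff) for the KnowledgeBase helpers added above, assuming the existing spaCy 2.x add_entity/add_alias API:

    from spacy.kb import KnowledgeBase
    from spacy.vocab import Vocab

    kb = KnowledgeBase(vocab=Vocab(), entity_vector_length=3)
    kb.add_entity(entity="Q42", freq=12, entity_vector=[1.0, 2.0, 3.0])
    kb.add_entity(entity="Q43", freq=5, entity_vector=[0.0, 1.0, 0.0])
    kb.add_alias(alias="Douglas", entities=["Q42"], probabilities=[0.6])

    assert kb.contains_entity("Q42")
    assert kb.contains_alias("Douglas")
    # extend an existing alias; raises E176/E134 for unknown aliases or entities,
    # warns with W024 if the alias-entity combination is already recorded
    kb.append_alias(alias="Douglas", entity="Q43", prior_prob=0.2)
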
34 spacy/lang/lb/__init__.py (Normal file)
@@ -0,0 +1,34 @@
+# coding: utf8
+from __future__ import unicode_literals
+
+from .tokenizer_exceptions import TOKENIZER_EXCEPTIONS
+from .norm_exceptions import NORM_EXCEPTIONS
+from .lex_attrs import LEX_ATTRS
+from .tag_map import TAG_MAP
+from .stop_words import STOP_WORDS
+
+from ..tokenizer_exceptions import BASE_EXCEPTIONS
+from ..norm_exceptions import BASE_NORMS
+from ...language import Language
+from ...attrs import LANG, NORM
+from ...util import update_exc, add_lookups
+
+
+class LuxembourgishDefaults(Language.Defaults):
+    lex_attr_getters = dict(Language.Defaults.lex_attr_getters)
+    lex_attr_getters.update(LEX_ATTRS)
+    lex_attr_getters[LANG] = lambda text: "lb"
+    lex_attr_getters[NORM] = add_lookups(
+        Language.Defaults.lex_attr_getters[NORM], NORM_EXCEPTIONS, BASE_NORMS
+    )
+    tokenizer_exceptions = update_exc(BASE_EXCEPTIONS, TOKENIZER_EXCEPTIONS)
+    stop_words = STOP_WORDS
+    tag_map = TAG_MAP
+
+
+class Luxembourgish(Language):
+    lang = "lb"
+    Defaults = LuxembourgishDefaults
+
+
+__all__ = ["Luxembourgish"]

18 spacy/lang/lb/examples.py (Normal file)
@@ -0,0 +1,18 @@
+# coding: utf8
+from __future__ import unicode_literals
+
+"""
+Example sentences to test spaCy and its language models.
+
+>>> from spacy.lang.lb.examples import sentences
+>>> docs = nlp.pipe(sentences)
+"""
+
+sentences = [
+    "An der Zäit hunn sech den Nordwand an d’Sonn gestridden, wie vun hinnen zwee wuel méi staark wier, wéi e Wanderer, deen an ee waarme Mantel agepak war, iwwert de Wee koum.",
+    "Si goufen sech eens, dass deejéinege fir de Stäerkste gëlle sollt, deen de Wanderer forcéiere géif, säi Mantel auszedoen.",
+    "Den Nordwand huet mat aller Force geblosen, awer wat e méi geblosen huet, wat de Wanderer sech méi a säi Mantel agewéckelt huet.",
+    "Um Enn huet den Nordwand säi Kampf opginn.",
+    "Dunn huet d’Sonn d’Loft mat hire frëndleche Strale gewiermt, a schonn no kuerzer Zäit huet de Wanderer säi Mantel ausgedoen.",
+    "Do huet den Nordwand missen zouginn, dass d’Sonn vun hinnen zwee de Stäerkste wier.",
+]

44 spacy/lang/lb/lex_attrs.py (Normal file)
@@ -0,0 +1,44 @@
+# coding: utf8
+from __future__ import unicode_literals
+
+from ...attrs import LIKE_NUM
+
+
+_num_words = set(
+    """
+null eent zwee dräi véier fënnef sechs ziwen aacht néng zéng eelef zwielef dräizéng
+véierzéng foffzéng siechzéng siwwenzéng uechtzeng uechzeng nonnzéng nongzéng zwanzeg drësseg véierzeg foffzeg sechzeg siechzeg siwenzeg achtzeg achzeg uechtzeg uechzeg nonnzeg
+honnert dausend millioun milliard billioun billiard trillioun triliard
+""".split()
+)
+
+_ordinal_words = set(
+    """
+éischten zweeten drëtten véierten fënneften sechsten siwenten aachten néngten zéngten eeleften
+zwieleften dräizéngten véierzéngten foffzéngten siechzéngten uechtzéngen uechzéngten nonnzéngten nongzéngten zwanzegsten
+drëssegsten véierzegsten foffzegsten siechzegsten siwenzegsten uechzegsten nonnzegsten
+honnertsten dausendsten milliounsten
+milliardsten billiounsten billiardsten trilliounsten trilliardsten
+""".split()
+)
+
+
+def like_num(text):
+    """
+    check if text resembles a number
+    """
+    text = text.replace(",", "").replace(".", "")
+    if text.isdigit():
+        return True
+    if text.count("/") == 1:
+        num, denom = text.split("/")
+        if num.isdigit() and denom.isdigit():
+            return True
+    if text in _num_words:
+        return True
+    if text in _ordinal_words:
+        return True
+    return False
+
+
+LEX_ATTRS = {LIKE_NUM: like_num}

16 spacy/lang/lb/norm_exceptions.py (Normal file)
@@ -0,0 +1,16 @@
+# coding: utf8
+from __future__ import unicode_literals
+
+# TODO
+# norm execptions: find a possibility to deal with the zillions of spelling
+# variants (vläicht = vlaicht, vleicht, viläicht, viläischt, etc. etc.)
+# here one could include the most common spelling mistakes
+
+_exc = {"datt": "dass", "wgl.": "weg.", "vläicht": "viläicht"}
+
+
+NORM_EXCEPTIONS = {}
+
+for string, norm in _exc.items():
+    NORM_EXCEPTIONS[string] = norm
+    NORM_EXCEPTIONS[string.title()] = norm

214 spacy/lang/lb/stop_words.py (Normal file)
@@ -0,0 +1,214 @@
+# coding: utf8
+from __future__ import unicode_literals
+
+STOP_WORDS = set(
+    """
+a à äis är ärt äert ären all allem alles alleguer als also am an anerefalls ass aus awer
+bei beim bis bis d' dach datt däin där dat de dee den deel deem deen deene déi den deng
+denger dem der dësem di dir do da dann domat dozou drop du duerch duerno e ee em een eent
+ë en ënner ëm ech eis eise eisen eiser eises eisereen esou een eng enger engem entweder et
+eréischt falls fir géint géif gëtt gët geet gi ginn gouf gouff goung hat haten hatt hätt
+hei hu huet hun hunn hiren hien hin hier hir jidderen jiddereen jiddwereen jiddereng
+jiddwerengen jo ins iech iwwer kann kee keen kënne kënnt kéng kéngen kéngem koum kuckt
+mam mat ma mä mech méi mécht meng menger mer mir muss nach nämmlech nämmelech näischt
+nawell nëmme nëmmen net nees nee no nu nom och oder ons onsen onser onsereen onst om op
+ouni säi säin schonn schonns si sid sie se sech seng senge sengem senger selwecht selwer
+sinn sollten souguer sou soss sot 't tëscht u un um virdrun vu vum vun wann war waren was
+wat wëllt weider wéi wéini wéinst wi wollt wou wouhin zanter ze zu zum zwar
+""".split()
+)

28 spacy/lang/lb/tag_map.py (Normal file)
@@ -0,0 +1,28 @@
+# coding: utf8
+from __future__ import unicode_literals
+
+from ...symbols import POS, PUNCT, ADJ, CONJ, NUM, DET, ADV, ADP, X, VERB
+from ...symbols import NOUN, PART, SPACE, AUX
+
+# TODO: tag map is still using POS tags from an internal training set.
+# These POS tags have to be modified to match those from Universal Dependencies
+
+TAG_MAP = {
+    "$": {POS: PUNCT},
+    "ADJ": {POS: ADJ},
+    "AV": {POS: ADV},
+    "APPR": {POS: ADP, "AdpType": "prep"},
+    "APPRART": {POS: ADP, "AdpType": "prep", "PronType": "art"},
+    "D": {POS: DET, "PronType": "art"},
+    "KO": {POS: CONJ},
+    "N": {POS: NOUN},
+    "P": {POS: ADV},
+    "TRUNC": {POS: X, "Hyph": "yes"},
+    "AUX": {POS: AUX},
+    "V": {POS: VERB},
+    "MV": {POS: VERB, "VerbType": "mod"},
+    "PTK": {POS: PART},
+    "INTER": {POS: PART},
+    "NUM": {POS: NUM},
+    "_SP": {POS: SPACE},
+}

69 spacy/lang/lb/tokenizer_exceptions.py (Normal file)
@@ -0,0 +1,69 @@
+# coding: utf8
+from __future__ import unicode_literals
+
+from ...symbols import ORTH, LEMMA, NORM
+from ..punctuation import TOKENIZER_PREFIXES
+
+# TODO
+# tokenize cliticised definite article "d'" as token of its own: d'Kanner > [d'] [Kanner]
+# treat other apostrophes within words as part of the word: [op d'mannst], [fir d'éischt] (= exceptions)
+
+# how to write the tokenisation exeption for the articles d' / D' ? This one is not working.
+_prefixes = [
+    prefix for prefix in TOKENIZER_PREFIXES if prefix not in ["d'", "D'", "d’", "D’"]
+]
+
+
+_exc = {
+    "d'mannst": [
+        {ORTH: "d'", LEMMA: "d'"},
+        {ORTH: "mannst", LEMMA: "mann", NORM: "mann"},
+    ],
+    "d'éischt": [
+        {ORTH: "d'", LEMMA: "d'"},
+        {ORTH: "éischt", LEMMA: "éischt", NORM: "éischt"},
+    ],
+}
+
+# translate / delete what is not necessary
+# what does PRON_LEMMA mean?
+for exc_data in [
+    {ORTH: "wgl.", LEMMA: "wann ech gelift", NORM: "wann ech gelieft"},
+    {ORTH: "M.", LEMMA: "Monsieur", NORM: "Monsieur"},
+    {ORTH: "Mme.", LEMMA: "Madame", NORM: "Madame"},
+    {ORTH: "Dr.", LEMMA: "Dokter", NORM: "Dokter"},
+    {ORTH: "Tel.", LEMMA: "Telefon", NORM: "Telefon"},
+    {ORTH: "asw.", LEMMA: "an sou weider", NORM: "an sou weider"},
+    {ORTH: "etc.", LEMMA: "et cetera", NORM: "et cetera"},
+    {ORTH: "bzw.", LEMMA: "bezéiungsweis", NORM: "bezéiungsweis"},
+    {ORTH: "Jan.", LEMMA: "Januar", NORM: "Januar"},
+]:
+    _exc[exc_data[ORTH]] = [exc_data]
+
+
+# to be extended
+for orth in [
+    "z.B.",
+    "Dipl.",
+    "Dr.",
+    "etc.",
+    "i.e.",
+    "o.k.",
+    "O.K.",
+    "p.a.",
+    "p.s.",
+    "P.S.",
+    "phil.",
+    "q.e.d.",
+    "R.I.P.",
+    "rer.",
+    "sen.",
+    "ë.a.",
+    "U.S.",
+    "U.S.A.",
+]:
+    _exc[orth] = [{ORTH: orth}]
+
+
+TOKENIZER_PREFIXES = _prefixes
+TOKENIZER_EXCEPTIONS = _exc

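
Usage sketch (illustrative, not part of the diff): once the Luxembourgish files above are installed, the new language is available through spaCy's standard loading machinery.

    import spacy

    nlp = spacy.blank("lb")  # resolves to the Luxembourgish class registered above
    doc = nlp("Den Nordwand huet mat aller Force geblosen.")
    print([token.text for token in doc])
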
@@ -1,10 +1,8 @@
 # coding: utf8
 from __future__ import absolute_import, unicode_literals

-import atexit
 import random
 import itertools
-from warnings import warn
 from spacy.util import minibatch
 import weakref
 import functools
@@ -483,7 +481,7 @@ class Language(object):

        docs (iterable): A batch of `Doc` objects.
        golds (iterable): A batch of `GoldParse` objects.
-        drop (float): The droput rate.
+        drop (float): The dropout rate.
        sgd (callable): An optimizer.
        losses (dict): Dictionary to update with the loss, keyed by component.
        component_cfg (dict): Config parameters for specific pipeline
@@ -531,7 +529,7 @@ class Language(object):
        even if you're updating it with a smaller set of examples.

        docs (iterable): A batch of `Doc` objects.
-        drop (float): The droput rate.
+        drop (float): The dropout rate.
        sgd (callable): An optimizer.
        RETURNS (dict): Results from the update.

@@ -753,7 +751,8 @@ class Language(object):
            use. Experimental.
        component_cfg (dict): An optional dictionary with extra keyword
            arguments for specific components.
-        n_process (int): Number of processors to process texts, only supported in Python3. If -1, set `multiprocessing.cpu_count()`.
+        n_process (int): Number of processors to process texts, only supported
+            in Python3. If -1, set `multiprocessing.cpu_count()`.
        YIELDS (Doc): Documents in the order of the original text.

        DOCS: https://spacy.io/api/language#pipe
@@ -1069,9 +1068,10 @@ def _pipe(docs, proc, kwargs):
 def _apply_pipes(make_doc, pipes, reciever, sender):
     """Worker for Language.pipe

-    Args:
-        receiver (multiprocessing.Connection): Pipe to receive text. Usually created by `multiprocessing.Pipe()`
-        sender (multiprocessing.Connection): Pipe to send doc. Usually created by `multiprocessing.Pipe()`
+    receiver (multiprocessing.Connection): Pipe to receive text. Usually
+        created by `multiprocessing.Pipe()`
+    sender (multiprocessing.Connection): Pipe to send doc. Usually created by
+        `multiprocessing.Pipe()`
     """
     while True:
         texts = reciever.get()

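
Usage sketch (illustrative, not part of the diff) for the n_process argument documented above; per W023 it only takes effect on Python 3.

    import spacy

    nlp = spacy.blank("en")
    texts = ["First text.", "Second text.", "Third text."]
    for doc in nlp.pipe(texts, n_process=2):  # n_process=-1 uses multiprocessing.cpu_count()
        print(len(doc))
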
@@ -133,13 +133,15 @@ cdef class Matcher:

        key (unicode): The ID of the match rule.
        """
-        key = self._normalize_key(key)
-        self._patterns.pop(key)
-        self._callbacks.pop(key)
+        norm_key = self._normalize_key(key)
+        if not norm_key in self._patterns:
+            raise ValueError(Errors.E175.format(key=key))
+        self._patterns.pop(norm_key)
+        self._callbacks.pop(norm_key)
        cdef int i = 0
        while i < self.patterns.size():
-            pattern_key = get_pattern_key(self.patterns.at(i))
-            if pattern_key == key:
+            pattern_key = get_ent_id(self.patterns.at(i))
+            if pattern_key == norm_key:
                self.patterns.erase(self.patterns.begin()+i)
            else:
                i += 1
@@ -293,18 +295,6 @@ cdef find_matches(TokenPatternC** patterns, int n, Doc doc, extensions=None,
     return output


-cdef attr_t get_ent_id(const TokenPatternC* pattern) nogil:
-    # There have been a few bugs here.
-    # The code was originally designed to always have pattern[1].attrs.value
-    # be the ent_id when we get to the end of a pattern. However, Issue #2671
-    # showed this wasn't the case when we had a reject-and-continue before a
-    # match.
-    # The patch to #2671 was wrong though, which came up in #3839.
-    while pattern.attrs.attr != ID:
-        pattern += 1
-    return pattern.attrs.value
-
-
 cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& matches,
                             char* cached_py_predicates,
                             Token token, const attr_t* extra_attrs, py_predicates) except *:
@@ -533,6 +523,7 @@ cdef char get_is_match(PatternStateC state,
        if predicate_matches[state.pattern.py_predicates[i]] == -1:
            return 0
    spec = state.pattern
+    if spec.nr_attr > 0:
        for attr in spec.attrs[:spec.nr_attr]:
            if get_token_attr(token, attr.attr) != attr.value:
                return 0
@@ -543,7 +534,11 @@ cdef char get_is_match(PatternStateC state,


 cdef char get_is_final(PatternStateC state) nogil:
-    if state.pattern[1].attrs[0].attr == ID and state.pattern[1].nr_attr == 0:
+    if state.pattern[1].nr_attr == 0 and state.pattern[1].attrs != NULL:
+        id_attr = state.pattern[1].attrs[0]
+        if id_attr.attr != ID:
+            with gil:
+                raise ValueError(Errors.E074.format(attr=ID, bad_attr=id_attr.attr))
         return 1
     else:
         return 0
@@ -558,6 +553,8 @@ cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id, object token_specs)
    cdef int i, index
    for i, (quantifier, spec, extensions, predicates) in enumerate(token_specs):
        pattern[i].quantifier = quantifier
+        # Ensure attrs refers to a null pointer if nr_attr == 0
+        if len(spec) > 0:
            pattern[i].attrs = <AttrValueC*>mem.alloc(len(spec), sizeof(AttrValueC))
        pattern[i].nr_attr = len(spec)
        for j, (attr, value) in enumerate(spec):
@@ -574,6 +571,7 @@ cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id, object token_specs)
        pattern[i].nr_py = len(predicates)
        pattern[i].key = hash64(pattern[i].attrs, pattern[i].nr_attr * sizeof(AttrValueC), 0)
    i = len(token_specs)
+    # Even though here, nr_attr == 0, we're storing the ID value in attrs[0] (bug-prone, thread carefully!)
    pattern[i].attrs = <AttrValueC*>mem.alloc(2, sizeof(AttrValueC))
    pattern[i].attrs[0].attr = ID
    pattern[i].attrs[0].value = entity_id
@@ -583,8 +581,26 @@ cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id, object token_specs)
     return pattern


-cdef attr_t get_pattern_key(const TokenPatternC* pattern) nogil:
-    while pattern.nr_attr != 0 or pattern.nr_extra_attr != 0 or pattern.nr_py != 0:
+cdef attr_t get_ent_id(const TokenPatternC* pattern) nogil:
+    # There have been a few bugs here. We used to have two functions,
+    # get_ent_id and get_pattern_key that tried to do the same thing. These
+    # are now unified to try to solve the "ghost match" problem.
+    # Below is the previous implementation of get_ent_id and the comment on it,
+    # preserved for reference while we figure out whether the heisenbug in the
+    # matcher is resolved.
+    #
+    #
+    # cdef attr_t get_ent_id(const TokenPatternC* pattern) nogil:
+    #     # The code was originally designed to always have pattern[1].attrs.value
+    #     # be the ent_id when we get to the end of a pattern. However, Issue #2671
+    #     # showed this wasn't the case when we had a reject-and-continue before a
+    #     # match.
+    #     # The patch to #2671 was wrong though, which came up in #3839.
+    #     while pattern.attrs.attr != ID:
+    #         pattern += 1
+    #     return pattern.attrs.value
+    while pattern.nr_attr != 0 or pattern.nr_extra_attr != 0 or pattern.nr_py != 0 \
+            or pattern.quantifier != ZERO:
         pattern += 1
     id_attr = pattern[0].attrs[0]
     if id_attr.attr != ID:
@@ -642,7 +658,7 @@ def _get_attr_values(spec, string_store):
            value = string_store.add(value)
        elif isinstance(value, bool):
            value = int(value)
-        elif isinstance(value, dict):
+        elif isinstance(value, (dict, int)):
            continue
        else:
            raise ValueError(Errors.E153.format(vtype=type(value).__name__))

@@ -4,6 +4,7 @@ from cymem.cymem cimport Pool
 from preshed.maps cimport key_t, MapStruct

 from ..attrs cimport attr_id_t
+from ..structs cimport SpanC
 from ..tokens.doc cimport Doc
 from ..vocab cimport Vocab

@@ -18,10 +19,4 @@ cdef class PhraseMatcher:
     cdef Pool mem
     cdef key_t _terminal_hash

-    cdef void find_matches(self, Doc doc, vector[MatchStruct] *matches) nogil
-
-
-cdef struct MatchStruct:
-    key_t match_id
-    int start
-    int end
+    cdef void find_matches(self, Doc doc, vector[SpanC] *matches) nogil

@@ -9,6 +9,7 @@ from preshed.maps cimport map_init, map_set, map_get, map_clear, map_iter
 from ..attrs cimport ORTH, POS, TAG, DEP, LEMMA
 from ..structs cimport TokenC
 from ..tokens.token cimport Token
+from ..typedefs cimport attr_t

 from ._schemas import TOKEN_PATTERN_SCHEMA
 from ..errors import Errors, Warnings, deprecation_warning, user_warning
@@ -102,8 +103,10 @@ cdef class PhraseMatcher:
        cdef vector[MapStruct*] path_nodes
        cdef vector[key_t] path_keys
        cdef key_t key_to_remove
-        for keyword in self._docs[key]:
+        for keyword in sorted(self._docs[key], key=lambda x: len(x), reverse=True):
            current_node = self.c_map
+            path_nodes.clear()
+            path_keys.clear()
            for token in keyword:
                result = map_get(current_node, token)
                if result:
@@ -220,17 +223,17 @@ cdef class PhraseMatcher:
            # if doc is empty or None just return empty list
            return matches

-        cdef vector[MatchStruct] c_matches
+        cdef vector[SpanC] c_matches
        self.find_matches(doc, &c_matches)
        for i in range(c_matches.size()):
-            matches.append((c_matches[i].match_id, c_matches[i].start, c_matches[i].end))
+            matches.append((c_matches[i].label, c_matches[i].start, c_matches[i].end))
        for i, (ent_id, start, end) in enumerate(matches):
            on_match = self._callbacks.get(self.vocab.strings[ent_id])
            if on_match is not None:
                on_match(self, doc, i, matches)
        return matches

-    cdef void find_matches(self, Doc doc, vector[MatchStruct] *matches) nogil:
+    cdef void find_matches(self, Doc doc, vector[SpanC] *matches) nogil:
        cdef MapStruct* current_node = self.c_map
        cdef int start = 0
        cdef int idx = 0
@@ -238,7 +241,7 @@ cdef class PhraseMatcher:
        cdef key_t key
        cdef void* value
        cdef int i = 0
-        cdef MatchStruct ms
+        cdef SpanC ms
        cdef void* result
        while idx < doc.length:
            start = idx
@@ -253,7 +256,7 @@ cdef class PhraseMatcher:
                if result:
                    i = 0
                    while map_iter(<MapStruct*>result, &i, &key, &value):
-                        ms = make_matchstruct(key, start, idy)
+                        ms = make_spanstruct(key, start, idy)
                        matches.push_back(ms)
                    inner_token = Token.get_struct_attr(&doc.c[idy], self.attr)
                    result = map_get(current_node, inner_token)
@@ -268,7 +271,7 @@ cdef class PhraseMatcher:
            if result:
                i = 0
                while map_iter(<MapStruct*>result, &i, &key, &value):
-                    ms = make_matchstruct(key, start, idy)
+                    ms = make_spanstruct(key, start, idy)
                    matches.push_back(ms)
            current_node = self.c_map
            idx += 1
@@ -318,9 +321,9 @@ def unpickle_matcher(vocab, docs, callbacks, attr):
     return matcher


-cdef MatchStruct make_matchstruct(key_t match_id, int start, int end) nogil:
-    cdef MatchStruct ms
-    ms.match_id = match_id
-    ms.start = start
-    ms.end = end
-    return ms
+cdef SpanC make_spanstruct(attr_t label, int start, int end) nogil:
+    cdef SpanC spanc
+    spanc.label = label
+    spanc.start = start
+    spanc.end = end
+    return spanc

@@ -183,7 +183,9 @@ class EntityRuler(object):
        # disable the nlp components after this one in case they hadn't been initialized / deserialised yet
        try:
            current_index = self.nlp.pipe_names.index(self.name)
-            subsequent_pipes = [pipe for pipe in self.nlp.pipe_names[current_index + 1:]]
+            subsequent_pipes = [
+                pipe for pipe in self.nlp.pipe_names[current_index + 1 :]
+            ]
        except ValueError:
            subsequent_pipes = []
        with self.nlp.disable_pipes(*subsequent_pipes):

@@ -1195,23 +1195,26 @@ class EntityLinker(Pipe):
            docs = [docs]
            golds = [golds]

-        context_docs = []
+        sentence_docs = []

        for doc, gold in zip(docs, golds):
            ents_by_offset = dict()
            for ent in doc.ents:
-                ents_by_offset["{}_{}".format(ent.start_char, ent.end_char)] = ent
+                ents_by_offset[(ent.start_char, ent.end_char)] = ent

            for entity, kb_dict in gold.links.items():
                start, end = entity
                mention = doc.text[start:end]
+                # the gold annotations should link to proper entities - if this fails, the dataset is likely corrupt
+                ent = ents_by_offset[(start, end)]
+
                for kb_id, value in kb_dict.items():
                    # Currently only training on the positive instances
                    if value:
-                        context_docs.append(doc)
+                        sentence_docs.append(ent.sent.as_doc())

-        context_encodings, bp_context = self.model.begin_update(context_docs, drop=drop)
-        loss, d_scores = self.get_similarity_loss(scores=context_encodings, golds=golds, docs=None)
+        sentence_encodings, bp_context = self.model.begin_update(sentence_docs, drop=drop)
+        loss, d_scores = self.get_similarity_loss(scores=sentence_encodings, golds=golds, docs=None)
        bp_context(d_scores, sgd=sgd)

        if losses is not None:
@@ -1280,22 +1283,40 @@ class EntityLinker(Pipe):
        if isinstance(docs, Doc):
            docs = [docs]

-        context_encodings = self.model(docs)
-        xp = get_array_module(context_encodings)
-
        for i, doc in enumerate(docs):
            if len(doc) > 0:
-                # currently, the context is the same for each entity in a sentence (should be refined)
-                context_encoding = context_encodings[i]
-                context_enc_t = context_encoding.T
-                norm_1 = xp.linalg.norm(context_enc_t)
-                for ent in doc.ents:
+                # Looping through each sentence and each entity
+                # This may go wrong if there are entities across sentences - because they might not get a KB ID
+                for sent in doc.ents:
+                    sent_doc = sent.as_doc()
+                    # currently, the context is the same for each entity in a sentence (should be refined)
+                    sentence_encoding = self.model([sent_doc])[0]
+                    xp = get_array_module(sentence_encoding)
+                    sentence_encoding_t = sentence_encoding.T
+                    sentence_norm = xp.linalg.norm(sentence_encoding_t)
+
+                    for ent in sent_doc.ents:
                        entity_count += 1
-                        candidates = self.kb.get_candidates(ent.text)
-                        if not candidates:
-                            final_kb_ids.append(self.NIL)  # no prediction possible for this entity
-                            final_tensors.append(context_encoding)
+
+                        if ent.label_ in self.cfg.get("labels_discard", []):
+                            # ignoring this entity - setting to NIL
+                            final_kb_ids.append(self.NIL)
+                            final_tensors.append(sentence_encoding)
+
                        else:
-                            random.shuffle(candidates)
+                            candidates = self.kb.get_candidates(ent.text)
+                            if not candidates:
+                                # no prediction possible for this entity - setting to NIL
+                                final_kb_ids.append(self.NIL)
+                                final_tensors.append(sentence_encoding)
+
+                            elif len(candidates) == 1:
+                                # shortcut for efficiency reasons: take the 1 candidate
+
+                                # TODO: thresholding
+                                final_kb_ids.append(candidates[0].entity_)
+                                final_tensors.append(sentence_encoding)
+
+                            else:
+                                random.shuffle(candidates)
@@ -1308,13 +1329,13 @@ class EntityLinker(Pipe):
                                # add in similarity from the context
                                if self.cfg.get("incl_context", True):
                                    entity_encodings = xp.asarray([c.entity_vector for c in candidates])
-                                    norm_2 = xp.linalg.norm(entity_encodings, axis=1)
+                                    entity_norm = xp.linalg.norm(entity_encodings, axis=1)

                                    if len(entity_encodings) != len(prior_probs):
                                        raise RuntimeError(Errors.E147.format(method="predict", msg="vectors not of equal length"))

                                    # cosine similarity
-                                    sims = xp.dot(entity_encodings, context_enc_t) / (norm_1 * norm_2)
+                                    sims = xp.dot(entity_encodings, sentence_encoding_t) / (sentence_norm * entity_norm)
                                    if sims.shape != prior_probs.shape:
                                        raise ValueError(Errors.E161)
                                    scores = prior_probs + sims - (prior_probs*sims)
@@ -1323,7 +1344,7 @@ class EntityLinker(Pipe):
                                best_index = scores.argmax()
                                best_candidate = candidates[best_index]
                                final_kb_ids.append(best_candidate.entity_)
-                                final_tensors.append(context_encoding)
+                                final_tensors.append(sentence_encoding)

        if not (len(final_tensors) == len(final_kb_ids) == entity_count):
            raise RuntimeError(Errors.E147.format(method="predict", msg="result variables not of equal length"))

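
Standalone sketch (illustrative, not part of the diff) of the score combination used in the EntityLinker hunks above: cosine similarity between the sentence encoding and each candidate's entity vector, blended with the KB prior as prior + sim - prior*sim. All values are made up.

    import numpy as np

    sentence_encoding = np.array([0.2, 0.1, 0.7])      # hypothetical sentence vector
    entity_encodings = np.array([[0.3, 0.0, 0.6],      # hypothetical candidate vectors
                                 [0.9, 0.1, 0.0]])
    prior_probs = np.array([0.6, 0.3])                 # hypothetical KB priors

    sentence_norm = np.linalg.norm(sentence_encoding)
    entity_norm = np.linalg.norm(entity_encodings, axis=1)
    sims = entity_encodings.dot(sentence_encoding) / (sentence_norm * entity_norm)
    scores = prior_probs + sims - (prior_probs * sims)
    best_candidate_index = scores.argmax()
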
@@ -219,7 +219,9 @@ class Scorer(object):
        DOCS: https://spacy.io/api/scorer#score
        """
        if len(doc) != len(gold):
-            gold = GoldParse.from_annot_tuples(doc, zip(*gold.orig_annot))
+            gold = GoldParse.from_annot_tuples(
+                doc, tuple(zip(*gold.orig_annot)) + (gold.cats,)
+            )
        gold_deps = set()
        gold_tags = set()
        gold_ents = set(tags_to_entities([annot[-1] for annot in gold.orig_annot]))

@@ -47,11 +47,14 @@ cdef struct SerializedLexemeC:
     # + sizeof(float) # l2_norm


-cdef struct Entity:
+cdef struct SpanC:
     hash_t id
     int start
     int end
+    int start_char
+    int end_char
     attr_t label
+    attr_t kb_id


 cdef struct TokenC:

@@ -7,7 +7,7 @@ from cpython.exc cimport PyErr_CheckSignals, PyErr_SetFromErrno
 from murmurhash.mrmr cimport hash64

 from ..vocab cimport EMPTY_LEXEME
-from ..structs cimport TokenC, Entity
+from ..structs cimport TokenC, SpanC
 from ..lexeme cimport Lexeme
 from ..symbols cimport punct
 from ..attrs cimport IS_SPACE
@@ -40,7 +40,7 @@ cdef cppclass StateC:
     int* _buffer
     bint* shifted
     TokenC* _sent
-    Entity* _ents
+    SpanC* _ents
     TokenC _empty_token
     RingBufferC _hist
     int length
@@ -56,7 +56,7 @@ cdef cppclass StateC:
        this._stack = <int*>calloc(length + (PADDING * 2), sizeof(int))
        this.shifted = <bint*>calloc(length + (PADDING * 2), sizeof(bint))
        this._sent = <TokenC*>calloc(length + (PADDING * 2), sizeof(TokenC))
-        this._ents = <Entity*>calloc(length + (PADDING * 2), sizeof(Entity))
+        this._ents = <SpanC*>calloc(length + (PADDING * 2), sizeof(SpanC))
        if not (this._buffer and this._stack and this.shifted
                and this._sent and this._ents):
            with gil:
@@ -406,7 +406,7 @@ cdef cppclass StateC:
        memcpy(this._sent, src._sent, this.length * sizeof(TokenC))
        memcpy(this._stack, src._stack, this.length * sizeof(int))
        memcpy(this._buffer, src._buffer, this.length * sizeof(int))
-        memcpy(this._ents, src._ents, this.length * sizeof(Entity))
+        memcpy(this._ents, src._ents, this.length * sizeof(SpanC))
        memcpy(this.shifted, src.shifted, this.length * sizeof(this.shifted[0]))
        this._b_i = src._b_i
        this._s_i = src._s_i

@@ -3,7 +3,7 @@ from libc.string cimport memcpy, memset
 from cymem.cymem cimport Pool
 cimport cython

-from ..structs cimport TokenC, Entity
+from ..structs cimport TokenC, SpanC
 from ..typedefs cimport attr_t

 from ..vocab cimport EMPTY_LEXEME

@@ -135,6 +135,11 @@ def ko_tokenizer():
     return get_lang_class("ko").Defaults.create_tokenizer()


+@pytest.fixture(scope="session")
+def lb_tokenizer():
+    return get_lang_class("lb").Defaults.create_tokenizer()
+
+
 @pytest.fixture(scope="session")
 def lt_tokenizer():
     return get_lang_class("lt").Defaults.create_tokenizer()

@@ -253,3 +253,11 @@ def test_filter_spans(doc):
     assert len(filtered[1]) == 5
     assert filtered[0].start == 1 and filtered[0].end == 4
     assert filtered[1].start == 5 and filtered[1].end == 10
+    # Test filtering overlaps with earlier preference for identical length
+    spans = [doc[1:4], doc[2:5], doc[5:10], doc[7:9], doc[1:4]]
+    filtered = filter_spans(spans)
+    assert len(filtered) == 2
+    assert len(filtered[0]) == 3
+    assert len(filtered[1]) == 5
+    assert filtered[0].start == 1 and filtered[0].end == 4
+    assert filtered[1].start == 5 and filtered[1].end == 10

10 spacy/tests/lang/lb/test_exceptions.py (Normal file)
@@ -0,0 +1,10 @@
+# coding: utf-8
+from __future__ import unicode_literals
+
+import pytest
+
+
+@pytest.mark.parametrize("text", ["z.B.", "Jan."])
+def test_lb_tokenizer_handles_abbr(lb_tokenizer, text):
+    tokens = lb_tokenizer(text)
+    assert len(tokens) == 1

22 spacy/tests/lang/lb/test_prefix_suffix_infix.py (Normal file)
@@ -0,0 +1,22 @@
+# coding: utf-8
+from __future__ import unicode_literals
+
+import pytest
+
+
+@pytest.mark.parametrize("text,length", [("z.B.", 1), ("zb.", 2), ("(z.B.", 2)])
+def test_lb_tokenizer_splits_prefix_interact(lb_tokenizer, text, length):
+    tokens = lb_tokenizer(text)
+    assert len(tokens) == length
+
+
+@pytest.mark.parametrize("text", ["z.B.)"])
+def test_lb_tokenizer_splits_suffix_interact(lb_tokenizer, text):
+    tokens = lb_tokenizer(text)
+    assert len(tokens) == 2
+
+
+@pytest.mark.parametrize("text", ["(z.B.)"])
+def test_lb_tokenizer_splits_even_wrap_interact(lb_tokenizer, text):
+    tokens = lb_tokenizer(text)
+    assert len(tokens) == 3

31 spacy/tests/lang/lb/test_text.py (Normal file)
@@ -0,0 +1,31 @@
+# coding: utf-8
+from __future__ import unicode_literals
+
+import pytest
+
+
+def test_lb_tokenizer_handles_long_text(lb_tokenizer):
+    text = """Den Nordwand an d'Sonn
+
+An der Zäit hunn sech den Nordwand an d’Sonn gestridden, wie vun hinnen zwee wuel méi staark wier, wéi e Wanderer, deen an ee waarme Mantel agepak war, iwwert de Wee koum. Si goufen sech eens, dass deejéinege fir de Stäerkste gëlle sollt, deen de Wanderer forcéiere géif, säi Mantel auszedoen.",
+
+Den Nordwand huet mat aller Force geblosen, awer wat e méi geblosen huet, wat de Wanderer sech méi a säi Mantel agewéckelt huet. Um Enn huet den Nordwand säi Kampf opginn.
+
+Dunn huet d’Sonn d’Loft mat hire frëndleche Strale gewiermt, a schonn no kuerzer Zäit huet de Wanderer säi Mantel ausgedoen.
+
+Do huet den Nordwand missen zouginn, dass d’Sonn vun hinnen zwee de Stäerkste wier."""
+
+    tokens = lb_tokenizer(text)
+    assert len(tokens) == 143
+
+
+@pytest.mark.parametrize(
+    "text,length",
+    [
+        ("»Wat ass mat mir geschitt?«, huet hie geduecht.", 13),
+        ("“Dëst fréi Opstoen”, denkt hien, “mécht ee ganz duercherneen. ", 15),
+    ],
+)
+def test_lb_tokenizer_handles_examples(lb_tokenizer, text, length):
+    tokens = lb_tokenizer(text)
+    assert len(tokens) == length

@@ -3,6 +3,8 @@ from __future__ import unicode_literals

 import pytest
 import re

+from spacy.lang.en import English
 from spacy.matcher import Matcher
 from spacy.tokens import Doc, Span
@@ -143,3 +145,29 @@ def test_matcher_sets_return_correct_tokens(en_vocab):
     matches = matcher(doc)
     texts = [Span(doc, s, e, label=L).text for L, s, e in matches]
     assert texts == ["zero", "one", "two"]
+
+
+def test_matcher_remove():
+    nlp = English()
+    matcher = Matcher(nlp.vocab)
+    text = "This is a test case."
+
+    pattern = [{"ORTH": "test"}, {"OP": "?"}]
+    assert len(matcher) == 0
+    matcher.add("Rule", None, pattern)
+    assert "Rule" in matcher
+
+    # should give two matches
+    results1 = matcher(nlp(text))
+    assert len(results1) == 2
+
+    # removing once should work
+    matcher.remove("Rule")
+
+    # should not return any maches anymore
+    results2 = matcher(nlp(text))
+    assert len(results2) == 0
+
+    # removing again should throw an error
+    with pytest.raises(ValueError):
+        matcher.remove("Rule")

@@ -12,24 +12,25 @@ from spacy.util import get_json_validator, validate_json
 TEST_PATTERNS = [
     # Bad patterns flagged in all cases
     ([{"XX": "foo"}], 1, 1),
-    ([{"LENGTH": "2", "TEXT": 2}, {"LOWER": "test"}], 2, 1),
     ([{"IS_ALPHA": {"==": True}}, {"LIKE_NUM": None}], 2, 1),
     ([{"IS_PUNCT": True, "OP": "$"}], 1, 1),
-    ([{"IS_DIGIT": -1}], 1, 1),
-    ([{"ORTH": -1}], 1, 1),
     ([{"_": "foo"}], 1, 1),
     ('[{"TEXT": "foo"}, {"LOWER": "bar"}]', 1, 1),
     ([1, 2, 3], 3, 1),
     # Bad patterns flagged outside of Matcher
     ([{"_": {"foo": "bar", "baz": {"IN": "foo"}}}], 1, 0),
     # Bad patterns not flagged with minimal checks
+    ([{"LENGTH": "2", "TEXT": 2}, {"LOWER": "test"}], 2, 0),
     ([{"LENGTH": {"IN": [1, 2, "3"]}}, {"POS": {"IN": "VERB"}}], 2, 0),
     ([{"LENGTH": {"VALUE": 5}}], 1, 0),
     ([{"TEXT": {"VALUE": "foo"}}], 1, 0),
+    ([{"IS_DIGIT": -1}], 1, 0),
+    ([{"ORTH": -1}], 1, 0),
     # Good patterns
     ([{"TEXT": "foo"}, {"LOWER": "bar"}], 0, 0),
     ([{"LEMMA": {"IN": ["love", "like"]}}, {"POS": "DET", "OP": "?"}], 0, 0),
     ([{"LIKE_NUM": True, "LENGTH": {">=": 5}}], 0, 0),
+    ([{"LENGTH": 2}], 0, 0),
     ([{"LOWER": {"REGEX": "^X", "NOT_IN": ["XXX", "XY"]}}], 0, 0),
     ([{"NORM": "a"}, {"POS": {"IN": ["NOUN"]}}], 0, 0),
     ([{"_": {"foo": {"NOT_IN": ["bar", "baz"]}, "a": 5, "b": {">": 10}}}], 0, 0),

@@ -226,3 +226,13 @@ def test_phrase_matcher_callback(en_vocab):
     matcher.add("COMPANY", mock, pattern)
     matches = matcher(doc)
     mock.assert_called_once_with(matcher, doc, 0, matches)
+
+
+def test_phrase_matcher_remove_overlapping_patterns(en_vocab):
+    matcher = PhraseMatcher(en_vocab)
+    pattern1 = Doc(en_vocab, words=["this"])
+    pattern2 = Doc(en_vocab, words=["this", "is"])
+    pattern3 = Doc(en_vocab, words=["this", "is", "a"])
+    pattern4 = Doc(en_vocab, words=["this", "is", "a", "word"])
+    matcher.add("THIS", None, pattern1, pattern2, pattern3, pattern4)
+    matcher.remove("THIS")
@@ -103,7 +103,7 @@ def test_oracle_moves_missing_B(en_vocab):
             moves.add_action(move_types.index("L"), label)
             moves.add_action(move_types.index("U"), label)
     moves.preprocess_gold(gold)
-    seq = moves.get_oracle_sequence(doc, gold)
+    moves.get_oracle_sequence(doc, gold)
 
 
 def test_oracle_moves_whitespace(en_vocab):
@@ -131,6 +131,53 @@ def test_candidate_generation(nlp):
     assert_almost_equal(mykb.get_candidates("adam")[0].prior_prob, 0.9)
 
 
+def test_append_alias(nlp):
+    """Test that we can append additional alias-entity pairs"""
+    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
+
+    # adding entities
+    mykb.add_entity(entity="Q1", freq=27, entity_vector=[1])
+    mykb.add_entity(entity="Q2", freq=12, entity_vector=[2])
+    mykb.add_entity(entity="Q3", freq=5, entity_vector=[3])
+
+    # adding aliases
+    mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.4, 0.1])
+    mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9])
+
+    # test the size of the relevant candidates
+    assert len(mykb.get_candidates("douglas")) == 2
+
+    # append an alias
+    mykb.append_alias(alias="douglas", entity="Q1", prior_prob=0.2)
+
+    # test that the number of relevant candidates has been incremented
+    assert len(mykb.get_candidates("douglas")) == 3
+
+    # appending the same alias-entity pair again should not work (it raises a warning)
+    mykb.append_alias(alias="douglas", entity="Q1", prior_prob=0.3)
+
+    # test that the number of relevant candidates remained unchanged
+    assert len(mykb.get_candidates("douglas")) == 3
+
+
+def test_append_invalid_alias(nlp):
+    """Test that appending an alias throws an error if the prior probabilities would exceed 1"""
+    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
+
+    # adding entities
+    mykb.add_entity(entity="Q1", freq=27, entity_vector=[1])
+    mykb.add_entity(entity="Q2", freq=12, entity_vector=[2])
+    mykb.add_entity(entity="Q3", freq=5, entity_vector=[3])
+
+    # adding aliases
+    mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.1])
+    mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9])
+
+    # appending an alias should fail because the prior probabilities for "douglas" would sum to more than 1
+    with pytest.raises(ValueError):
+        mykb.append_alias(alias="douglas", entity="Q1", prior_prob=0.2)
+
+
 def test_preserving_links_asdoc(nlp):
     """Test that Span.as_doc preserves the existing entity links"""
     mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
@@ -430,7 +430,7 @@ def test_issue957(en_tokenizer):
 def test_issue999(train_data):
     """Test that adding entities and resuming training works passably OK.
     There are two issues here:
-    1) We have to read labels. This isn't very nice.
+    1) We have to re-add labels. This isn't very nice.
     2) There's no way to set the learning rate for the weight update, so we
     end up out-of-scale, causing it to learn too fast.
     """
@@ -323,7 +323,7 @@ def test_issue3456():
     nlp = English()
     nlp.add_pipe(nlp.create_pipe("tagger"))
     nlp.begin_training()
-    list(nlp.pipe(['hi', '']))
+    list(nlp.pipe(["hi", ""]))
 
 
 def test_issue3468():
@@ -76,7 +76,6 @@ def test_issue4042_bug2():
            output_dir.mkdir()
        ner1.to_disk(output_dir)
 
-        nlp2 = English(vocab)
        ner2 = EntityRecognizer(vocab)
        ner2.from_disk(output_dir)
        assert len(ner2.labels) == 2
@@ -1,13 +1,8 @@
 # coding: utf8
 from __future__ import unicode_literals
 
-import pytest
-
-import spacy
-
 from spacy.lang.en import English
 from spacy.pipeline import EntityRuler
-from spacy.tokens import Span
 
 
 def test_issue4267():
@@ -6,6 +6,6 @@ from spacy.tokens import DocBin
 
 def test_issue4367():
     """Test that docbin init goes well"""
-    doc_bin_1 = DocBin()
-    doc_bin_2 = DocBin(attrs=["LEMMA"])
-    doc_bin_3 = DocBin(attrs=["LEMMA", "ENT_IOB", "ENT_TYPE"])
+    DocBin()
+    DocBin(attrs=["LEMMA"])
+    DocBin(attrs=["LEMMA", "ENT_IOB", "ENT_TYPE"])
@@ -74,4 +74,4 @@ def test_serialize_doc_bin():
     # Deserialize later, e.g. in a new process
     nlp = spacy.blank("en")
     doc_bin = DocBin().from_bytes(bytes_data)
-    docs = list(doc_bin.get_docs(nlp.vocab))
+    list(doc_bin.get_docs(nlp.vocab))
@@ -48,8 +48,13 @@ URLS_SHOULD_MATCH = [
     "http://a.b--c.de/",  # this is a legit domain name see: https://gist.github.com/dperini/729294 comment on 9/9/2014
     "ssh://login@server.com:12345/repository.git",
     "svn+ssh://user@ssh.yourdomain.com/path",
-    pytest.param("chrome://extensions/?id=mhjfbmdgcfjbbpaeojofohoefgiehjai", marks=pytest.mark.xfail()),
-    pytest.param("chrome-extension://mhjfbmdgcfjbbpaeojofohoefgiehjai", marks=pytest.mark.xfail()),
+    pytest.param(
+        "chrome://extensions/?id=mhjfbmdgcfjbbpaeojofohoefgiehjai",
+        marks=pytest.mark.xfail(),
+    ),
+    pytest.param(
+        "chrome-extension://mhjfbmdgcfjbbpaeojofohoefgiehjai", marks=pytest.mark.xfail()
+    ),
     pytest.param("http://foo.com/blah_blah_(wikipedia)", marks=pytest.mark.xfail()),
     pytest.param(
         "http://foo.com/blah_blah_(wikipedia)_(again)", marks=pytest.mark.xfail()
@@ -51,6 +51,14 @@ def data():
     return numpy.asarray([[0.0, 1.0, 2.0], [3.0, -2.0, 4.0]], dtype="f")
 
 
+@pytest.fixture
+def most_similar_vectors_data():
+    return numpy.asarray(
+        [[0.0, 1.0, 2.0], [1.0, -2.0, 4.0], [1.0, 1.0, -1.0], [2.0, 3.0, 1.0]],
+        dtype="f",
+    )
+
+
 @pytest.fixture
 def resize_data():
     return numpy.asarray([[0.0, 1.0], [2.0, 3.0]], dtype="f")
@@ -127,6 +135,12 @@ def test_set_vector(strings, data):
     assert list(v[strings[0]]) != list(orig[0])
 
 
+def test_vectors_most_similar(most_similar_vectors_data):
+    v = Vectors(data=most_similar_vectors_data)
+    _, best_rows, _ = v.most_similar(v.data, batch_size=2, n=2, sort=True)
+    assert all(row[0] == i for i, row in enumerate(best_rows))
+
+
 @pytest.mark.parametrize("text", ["apple and orange"])
 def test_vectors_token_vector(tokenizer_v, vectors, text):
     doc = tokenizer_v(text)
@@ -284,7 +298,7 @@ def test_vocab_prune_vectors():
     vocab.set_vector("dog", data[1])
     vocab.set_vector("kitten", data[2])
 
-    remap = vocab.prune_vectors(2)
+    remap = vocab.prune_vectors(2, batch_size=2)
     assert list(remap.keys()) == ["kitten"]
     neighbour, similarity = list(remap.values())[0]
     assert neighbour == "cat", remap
@@ -666,7 +666,7 @@ def filter_spans(spans):
     spans (iterable): The spans to filter.
     RETURNS (list): The filtered spans.
     """
-    get_sort_key = lambda span: (span.end - span.start, span.start)
+    get_sort_key = lambda span: (span.end - span.start, -span.start)
     sorted_spans = sorted(spans, key=get_sort_key, reverse=True)
     result = []
     seen_tokens = set()
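The sign flip in the sort key above means that, among overlapping spans of equal length, the span that starts earlier is now preferred. A minimal sketch of that behaviour, using a hypothetical standalone reimplementation over plain `(start, end)` tuples rather than the patched `spacy.util.filter_spans` itself:

```python
def filter_spans_sketch(spans):
    # Longest span first; for equal lengths the earlier start wins (note the minus sign).
    get_sort_key = lambda span: (span[1] - span[0], -span[0])
    sorted_spans = sorted(spans, key=get_sort_key, reverse=True)
    result = []
    seen_tokens = set()
    for start, end in sorted_spans:
        # Keep a span only if its boundaries do not touch already-claimed tokens.
        if start not in seen_tokens and end - 1 not in seen_tokens:
            result.append((start, end))
            seen_tokens.update(range(start, end))
    return sorted(result, key=lambda span: span[0])

print(filter_spans_sketch([(2, 4), (1, 3)]))  # [(1, 3)]
```

With the old key, the later span `(2, 4)` would have been kept instead of `(1, 3)`.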
@@ -336,8 +336,8 @@ cdef class Vectors:
             best_rows[i:i+batch_size] = xp.argpartition(sims, -n, axis=1)[:,-n:]
             scores[i:i+batch_size] = xp.partition(sims, -n, axis=1)[:,-n:]
 
-            if sort:
-                sorted_index = xp.arange(scores.shape[0])[:,None],xp.argsort(scores[i:i+batch_size], axis=1)[:,::-1]
+            if sort and n >= 2:
+                sorted_index = xp.arange(scores.shape[0])[:,None][i:i+batch_size],xp.argsort(scores[i:i+batch_size], axis=1)[:,::-1]
                 scores[i:i+batch_size] = scores[sorted_index]
                 best_rows[i:i+batch_size] = best_rows[sorted_index]
 
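The `n >= 2` guard above skips the re-sorting step when only one neighbour is requested, and the added `[i:i+batch_size]` slice restricts the row index to the current batch instead of the whole matrix. A rough NumPy sketch of the batched top-n selection this code performs (a hypothetical helper for illustration, not the `Vectors` API):

```python
import numpy as np

def batched_top_n(sims, n=2, batch_size=2, sort=True):
    """Pick the n best columns per row of a similarity matrix, in batches."""
    best_rows = np.zeros((sims.shape[0], n), dtype=int)
    scores = np.zeros((sims.shape[0], n), dtype=sims.dtype)
    for i in range(0, sims.shape[0], batch_size):
        batch = sims[i : i + batch_size]
        # Unordered top-n per row, mirroring the argpartition/partition calls above.
        best_rows[i : i + batch_size] = np.argpartition(batch, -n, axis=1)[:, -n:]
        scores[i : i + batch_size] = np.partition(batch, -n, axis=1)[:, -n:]
        if sort and n >= 2:
            # Order the n candidates by score, best first, indexing only this batch's rows.
            order = np.argsort(scores[i : i + batch_size], axis=1)[:, ::-1]
            rows = np.arange(order.shape[0])[:, None]
            scores[i : i + batch_size] = scores[i : i + batch_size][rows, order]
            best_rows[i : i + batch_size] = best_rows[i : i + batch_size][rows, order]
    return scores, best_rows

# Each row of this similarity matrix is most similar to itself, so after
# sorting, column 0 of best_rows is simply 0, 1, 2, 3.
sims = np.array([[1.0, 0.2, 0.1, 0.0],
                 [0.2, 1.0, 0.0, 0.1],
                 [0.1, 0.0, 1.0, 0.3],
                 [0.0, 0.1, 0.3, 1.0]])
scores, best_rows = batched_top_n(sims, n=2, batch_size=2)
assert list(best_rows[:, 0]) == [0, 1, 2, 3]
```

The new `test_vectors_most_similar` above checks the same property: with `sort=True`, column 0 of `best_rows` holds each row's own index.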
@@ -62,7 +62,7 @@ Whether the provided syntactic annotations form a projective dependency tree.
 
 Convert a list of Doc objects into the
 [JSON-serializable format](/api/annotation#json-input) used by the
-[`spacy train`](/api/cli#train) command.
+[`spacy train`](/api/cli#train) command. Each input doc will be treated as a 'paragraph' in the output doc.
 
 > #### Example
 >
@@ -77,7 +77,7 @@ Convert a list of Doc objects into the
 | ----------- | ---------------- | ------------------------------------------ |
 | `docs`      | iterable / `Doc` | The `Doc` object(s) to convert.             |
 | `id`        | int              | ID to assign to the JSON. Defaults to `0`.  |
-| **RETURNS** | list             | The data in spaCy's JSON format.            |
+| **RETURNS** | dict             | The data in spaCy's JSON format.            |
 
 ### gold.align {#align tag="function"}
 
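For the `docs_to_json` change documented above, a small usage sketch (assuming spaCy 2.x, where the function lives in `spacy.gold`, the input docs need sentence boundaries, and the returned dict carries a `"paragraphs"` list as in spaCy's JSON training format):

```python
import spacy
from spacy.gold import docs_to_json

nlp = spacy.blank("en")
# A sentencizer so each Doc has sentence boundaries for the conversion.
nlp.add_pipe(nlp.create_pipe("sentencizer"))
docs = [nlp("I like London."), nlp("Berlin is nice.")]

json_data = docs_to_json(docs, id=0)
assert isinstance(json_data, dict)        # RETURNS is now documented as dict
assert len(json_data["paragraphs"]) == 2  # one 'paragraph' per input Doc
```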
@@ -54,7 +54,7 @@ Lemmatize a string.
 > ```python
 > from spacy.lemmatizer import Lemmatizer
 > from spacy.lookups import Lookups
-> lookups = Loookups()
+> lookups = Lookups()
 > lookups.add_table("lemma_rules", {"noun": [["s", ""]]})
 > lemmatizer = Lemmatizer(lookups)
 > lemmas = lemmatizer("ducks", "NOUN")