From 9f8c1e71a21457aa4110d9cced223665d70017d5 Mon Sep 17 00:00:00 2001 From: svlandeg Date: Mon, 22 Jul 2019 13:34:12 +0200 Subject: [PATCH] fix for Issue #4000 --- examples/pipeline/wikidata_entity_linking.py | 18 +++++++++++++----- spacy/kb.pyx | 4 ++-- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/examples/pipeline/wikidata_entity_linking.py b/examples/pipeline/wikidata_entity_linking.py index 478d35111..32f751cd7 100644 --- a/examples/pipeline/wikidata_entity_linking.py +++ b/examples/pipeline/wikidata_entity_linking.py @@ -1,6 +1,8 @@ # coding: utf-8 from __future__ import unicode_literals +import os +from os import path import random import datetime from pathlib import Path @@ -26,7 +28,8 @@ ENTITY_COUNTS = OUTPUT_DIR / "entity_freq.csv" ENTITY_DEFS = OUTPUT_DIR / "entity_defs.csv" ENTITY_DESCR = OUTPUT_DIR / "entity_descriptions.csv" -KB_FILE = OUTPUT_DIR / "kb_1" / "kb" +KB_DIR = OUTPUT_DIR / "kb_1" +KB_FILE = "kb" NLP_1_DIR = OUTPUT_DIR / "nlp_1" NLP_2_DIR = OUTPUT_DIR / "nlp_2" @@ -118,7 +121,10 @@ def run_pipeline(): print() print("STEP 3b: write KB and NLP", now()) - kb_1.dump(KB_FILE) + + if not path.exists(KB_DIR): + os.makedirs(KB_DIR) + kb_1.dump(KB_DIR / KB_FILE) nlp_1.to_disk(NLP_1_DIR) print() @@ -127,7 +133,7 @@ def run_pipeline(): print("STEP 4: to_read_kb", now()) nlp_2 = spacy.load(NLP_1_DIR) kb_2 = KnowledgeBase(vocab=nlp_2.vocab, entity_vector_length=DESC_WIDTH) - kb_2.load_bulk(KB_FILE) + kb_2.load_bulk(KB_DIR / KB_FILE) print("kb entities:", kb_2.get_size_entities()) print("kb aliases:", kb_2.get_size_aliases()) print() @@ -327,7 +333,8 @@ def _measure_acc(data, el_pipe=None, error_analysis=False): # only evaluating on positive examples for gold_kb, value in kb_dict.items(): if value: - correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb + offset = str(start) + "-" + str(end) + correct_entries_per_article[offset] = gold_kb for ent in doc.ents: ent_label = ent.label_ @@ -385,7 +392,8 @@ def _measure_baselines(data, kb): for gold_kb, value in kb_dict.items(): # only evaluating on positive examples if value: - correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb + offset = str(start) + "-" + str(end) + correct_entries_per_article[offset] = gold_kb for ent in doc.ents: label = ent.label_ diff --git a/spacy/kb.pyx b/spacy/kb.pyx index 9df0e4fc2..0f1c87de8 100644 --- a/spacy/kb.pyx +++ b/spacy/kb.pyx @@ -278,7 +278,7 @@ cdef class KnowledgeBase: cdef hash_t entity_hash cdef hash_t alias_hash cdef int64_t entry_index - cdef float freq + cdef float freq, prob cdef int32_t vector_index cdef KBEntryC entry cdef AliasC alias @@ -373,7 +373,7 @@ cdef class Writer: loc = bytes(loc) cdef bytes bytes_loc = loc.encode('utf8') if type(loc) == unicode else loc self._fp = fopen(bytes_loc, 'wb') - assert self._fp != NULL + assert self._fp != NULL, "Could not access %s" % loc fseek(self._fp, 0, 0) def close(self):