fix for Issue #4000

This commit is contained in:
svlandeg 2019-07-22 13:34:12 +02:00
parent dae8a21282
commit 9f8c1e71a2
2 changed files with 15 additions and 7 deletions

View File

@@ -1,6 +1,8 @@
# coding: utf-8
from __future__ import unicode_literals
import os
from os import path
import random
import datetime
from pathlib import Path
@@ -26,7 +28,8 @@ ENTITY_COUNTS = OUTPUT_DIR / "entity_freq.csv"
ENTITY_DEFS = OUTPUT_DIR / "entity_defs.csv"
ENTITY_DESCR = OUTPUT_DIR / "entity_descriptions.csv"
KB_FILE = OUTPUT_DIR / "kb_1" / "kb"
KB_DIR = OUTPUT_DIR / "kb_1"
KB_FILE = "kb"
NLP_1_DIR = OUTPUT_DIR / "nlp_1"
NLP_2_DIR = OUTPUT_DIR / "nlp_2"
@@ -118,7 +121,10 @@ def run_pipeline():
print()
print("STEP 3b: write KB and NLP", now())
kb_1.dump(KB_FILE)
if not path.exists(KB_DIR):
os.makedirs(KB_DIR)
kb_1.dump(KB_DIR / KB_FILE)
nlp_1.to_disk(NLP_1_DIR)
print()
@@ -127,7 +133,7 @@ def run_pipeline():
print("STEP 4: to_read_kb", now())
nlp_2 = spacy.load(NLP_1_DIR)
kb_2 = KnowledgeBase(vocab=nlp_2.vocab, entity_vector_length=DESC_WIDTH)
kb_2.load_bulk(KB_FILE)
kb_2.load_bulk(KB_DIR / KB_FILE)
print("kb entities:", kb_2.get_size_entities())
print("kb aliases:", kb_2.get_size_aliases())
print()
@@ -327,7 +333,8 @@ def _measure_acc(data, el_pipe=None, error_analysis=False):
# only evaluating on positive examples
for gold_kb, value in kb_dict.items():
if value:
correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb
offset = str(start) + "-" + str(end)
correct_entries_per_article[offset] = gold_kb
for ent in doc.ents:
ent_label = ent.label_
@@ -385,7 +392,8 @@ def _measure_baselines(data, kb):
for gold_kb, value in kb_dict.items():
# only evaluating on positive examples
if value:
correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb
offset = str(start) + "-" + str(end)
correct_entries_per_article[offset] = gold_kb
for ent in doc.ents:
label = ent.label_

View File

@@ -278,7 +278,7 @@ cdef class KnowledgeBase:
cdef hash_t entity_hash
cdef hash_t alias_hash
cdef int64_t entry_index
cdef float freq
cdef float freq, prob
cdef int32_t vector_index
cdef KBEntryC entry
cdef AliasC alias
@@ -373,7 +373,7 @@ cdef class Writer:
loc = bytes(loc)
cdef bytes bytes_loc = loc.encode('utf8') if type(loc) == unicode else loc
self._fp = fopen(<char*>bytes_loc, 'wb')
assert self._fp != NULL
assert self._fp != NULL, "Could not access %s" % loc
fseek(self._fp, 0, 0)
def close(self):