mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-13 10:46:29 +03:00
fix for Issue #4000
This commit is contained in:
parent
dae8a21282
commit
9f8c1e71a2
|
@ -1,6 +1,8 @@
|
|||
# coding: utf-8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import os
|
||||
from os import path
|
||||
import random
|
||||
import datetime
|
||||
from pathlib import Path
|
||||
|
@ -26,7 +28,8 @@ ENTITY_COUNTS = OUTPUT_DIR / "entity_freq.csv"
|
|||
ENTITY_DEFS = OUTPUT_DIR / "entity_defs.csv"
|
||||
ENTITY_DESCR = OUTPUT_DIR / "entity_descriptions.csv"
|
||||
|
||||
KB_FILE = OUTPUT_DIR / "kb_1" / "kb"
|
||||
KB_DIR = OUTPUT_DIR / "kb_1"
|
||||
KB_FILE = "kb"
|
||||
NLP_1_DIR = OUTPUT_DIR / "nlp_1"
|
||||
NLP_2_DIR = OUTPUT_DIR / "nlp_2"
|
||||
|
||||
|
@ -118,7 +121,10 @@ def run_pipeline():
|
|||
print()
|
||||
|
||||
print("STEP 3b: write KB and NLP", now())
|
||||
kb_1.dump(KB_FILE)
|
||||
|
||||
if not path.exists(KB_DIR):
|
||||
os.makedirs(KB_DIR)
|
||||
kb_1.dump(KB_DIR / KB_FILE)
|
||||
nlp_1.to_disk(NLP_1_DIR)
|
||||
print()
|
||||
|
||||
|
@ -127,7 +133,7 @@ def run_pipeline():
|
|||
print("STEP 4: to_read_kb", now())
|
||||
nlp_2 = spacy.load(NLP_1_DIR)
|
||||
kb_2 = KnowledgeBase(vocab=nlp_2.vocab, entity_vector_length=DESC_WIDTH)
|
||||
kb_2.load_bulk(KB_FILE)
|
||||
kb_2.load_bulk(KB_DIR / KB_FILE)
|
||||
print("kb entities:", kb_2.get_size_entities())
|
||||
print("kb aliases:", kb_2.get_size_aliases())
|
||||
print()
|
||||
|
@ -327,7 +333,8 @@ def _measure_acc(data, el_pipe=None, error_analysis=False):
|
|||
# only evaluating on positive examples
|
||||
for gold_kb, value in kb_dict.items():
|
||||
if value:
|
||||
correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb
|
||||
offset = str(start) + "-" + str(end)
|
||||
correct_entries_per_article[offset] = gold_kb
|
||||
|
||||
for ent in doc.ents:
|
||||
ent_label = ent.label_
|
||||
|
@ -385,7 +392,8 @@ def _measure_baselines(data, kb):
|
|||
for gold_kb, value in kb_dict.items():
|
||||
# only evaluating on positive examples
|
||||
if value:
|
||||
correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb
|
||||
offset = str(start) + "-" + str(end)
|
||||
correct_entries_per_article[offset] = gold_kb
|
||||
|
||||
for ent in doc.ents:
|
||||
label = ent.label_
|
||||
|
|
|
@ -278,7 +278,7 @@ cdef class KnowledgeBase:
|
|||
cdef hash_t entity_hash
|
||||
cdef hash_t alias_hash
|
||||
cdef int64_t entry_index
|
||||
cdef float freq
|
||||
cdef float freq, prob
|
||||
cdef int32_t vector_index
|
||||
cdef KBEntryC entry
|
||||
cdef AliasC alias
|
||||
|
@ -373,7 +373,7 @@ cdef class Writer:
|
|||
loc = bytes(loc)
|
||||
cdef bytes bytes_loc = loc.encode('utf8') if type(loc) == unicode else loc
|
||||
self._fp = fopen(<char*>bytes_loc, 'wb')
|
||||
assert self._fp != NULL
|
||||
assert self._fp != NULL, "Could not access %s" % loc
|
||||
fseek(self._fp, 0, 0)
|
||||
|
||||
def close(self):
|
||||
|
|
Loading…
Reference in New Issue
Block a user