This commit is contained in:
Matthew Honnibal 2020-08-23 21:15:12 +02:00
parent 89f5b8abb3
commit 160a855246
5 changed files with 20 additions and 6 deletions

View File

@ -21,8 +21,8 @@ from .project.clone import project_clone # noqa: F401
from .project.assets import project_assets # noqa: F401 from .project.assets import project_assets # noqa: F401
from .project.run import project_run # noqa: F401 from .project.run import project_run # noqa: F401
from .project.dvc import project_update_dvc # noqa: F401 from .project.dvc import project_update_dvc # noqa: F401
from .project.push import project_push # noqa: F401 from .project.push import project_push # noqa: F401
from .project.pull import project_pull # noqa: F401 from .project.pull import project_pull # noqa: F401
@app.command("link", no_args_is_help=True, deprecated=True, hidden=True) @app.command("link", no_args_is_help=True, deprecated=True, hidden=True)

View File

@ -75,7 +75,9 @@ def train(
msg.info("Using CPU") msg.info("Using CPU")
msg.info(f"Loading config and nlp from: {config_path}") msg.info(f"Loading config and nlp from: {config_path}")
with show_validation_error(config_path): with show_validation_error(config_path):
config = util.load_config(config_path, overrides=config_overrides, interpolate=True) config = util.load_config(
config_path, overrides=config_overrides, interpolate=True
)
if config.get("training", {}).get("seed") is not None: if config.get("training", {}).get("seed") is not None:
fix_random_seed(config["training"]["seed"]) fix_random_seed(config["training"]["seed"])
# Use original config here before it's resolved to functions # Use original config here before it's resolved to functions
@ -208,7 +210,9 @@ def create_evaluation_callback(
scores = nlp.evaluate(dev_examples) scores = nlp.evaluate(dev_examples)
# Calculate a weighted sum based on score_weights for the main score # Calculate a weighted sum based on score_weights for the main score
try: try:
weighted_score = sum(scores.get(s, 0.0) * weights.get(s, 0.0) for s in weights) weighted_score = sum(
scores.get(s, 0.0) * weights.get(s, 0.0) for s in weights
)
except KeyError as e: except KeyError as e:
keys = list(scores.keys()) keys = list(scores.keys())
err = Errors.E983.format(dict="score_weights", key=str(e), keys=keys) err = Errors.E983.format(dict="score_weights", key=str(e), keys=keys)
@ -378,7 +382,8 @@ def setup_printer(
try: try:
scores = [ scores = [
"{0:.2f}".format(float(info["other_scores"].get(col, 0.0))) for col in score_cols "{0:.2f}".format(float(info["other_scores"].get(col, 0.0)))
for col in score_cols
] ]
except KeyError as e: except KeyError as e:
raise KeyError( raise KeyError(

View File

@ -774,7 +774,13 @@ class Language:
# we have no components to insert before/after, or we're replacing the last component # we have no components to insert before/after, or we're replacing the last component
self.add_pipe(factory_name, name=name, config=config, validate=validate) self.add_pipe(factory_name, name=name, config=config, validate=validate)
else: else:
self.add_pipe(factory_name, name=name, before=pipe_index, config=config, validate=validate) self.add_pipe(
factory_name,
name=name,
before=pipe_index,
config=config,
validate=validate,
)
def rename_pipe(self, old_name: str, new_name: str) -> None: def rename_pipe(self, old_name: str, new_name: str) -> None:
"""Rename a pipeline component. """Rename a pipeline component.

View File

@ -30,6 +30,7 @@ def load_kb(kb_path: str) -> Callable[[Vocab], KnowledgeBase]:
kb = KnowledgeBase(vocab, entity_vector_length=1) kb = KnowledgeBase(vocab, entity_vector_length=1)
kb.from_disk(kb_path) kb.from_disk(kb_path)
return kb return kb
return kb_from_file return kb_from_file
@ -37,6 +38,7 @@ def load_kb(kb_path: str) -> Callable[[Vocab], KnowledgeBase]:
def empty_kb(entity_vector_length: int) -> Callable[[Vocab], KnowledgeBase]: def empty_kb(entity_vector_length: int) -> Callable[[Vocab], KnowledgeBase]:
def empty_kb_factory(vocab): def empty_kb_factory(vocab):
return KnowledgeBase(vocab=vocab, entity_vector_length=entity_vector_length) return KnowledgeBase(vocab=vocab, entity_vector_length=entity_vector_length)
return empty_kb_factory return empty_kb_factory

View File

@ -76,6 +76,7 @@ def entity_linker():
kb = KnowledgeBase(vocab, entity_vector_length=1) kb = KnowledgeBase(vocab, entity_vector_length=1)
kb.add_entity("test", 0.0, zeros((1, 1), dtype="f")) kb.add_entity("test", 0.0, zeros((1, 1), dtype="f"))
return kb return kb
return create_kb return create_kb
config = {"kb_loader": {"@assets": "TestIssue5230KB.v1"}} config = {"kb_loader": {"@assets": "TestIssue5230KB.v1"}}