mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 17:06:29 +03:00
Format
This commit is contained in:
parent
89f5b8abb3
commit
160a855246
|
@ -21,8 +21,8 @@ from .project.clone import project_clone # noqa: F401
|
|||
from .project.assets import project_assets # noqa: F401
|
||||
from .project.run import project_run # noqa: F401
|
||||
from .project.dvc import project_update_dvc # noqa: F401
|
||||
from .project.push import project_push # noqa: F401
|
||||
from .project.pull import project_pull # noqa: F401
|
||||
from .project.push import project_push # noqa: F401
|
||||
from .project.pull import project_pull # noqa: F401
|
||||
|
||||
|
||||
@app.command("link", no_args_is_help=True, deprecated=True, hidden=True)
|
||||
|
|
|
@ -75,7 +75,9 @@ def train(
|
|||
msg.info("Using CPU")
|
||||
msg.info(f"Loading config and nlp from: {config_path}")
|
||||
with show_validation_error(config_path):
|
||||
config = util.load_config(config_path, overrides=config_overrides, interpolate=True)
|
||||
config = util.load_config(
|
||||
config_path, overrides=config_overrides, interpolate=True
|
||||
)
|
||||
if config.get("training", {}).get("seed") is not None:
|
||||
fix_random_seed(config["training"]["seed"])
|
||||
# Use original config here before it's resolved to functions
|
||||
|
@ -208,7 +210,9 @@ def create_evaluation_callback(
|
|||
scores = nlp.evaluate(dev_examples)
|
||||
# Calculate a weighted sum based on score_weights for the main score
|
||||
try:
|
||||
weighted_score = sum(scores.get(s, 0.0) * weights.get(s, 0.0) for s in weights)
|
||||
weighted_score = sum(
|
||||
scores.get(s, 0.0) * weights.get(s, 0.0) for s in weights
|
||||
)
|
||||
except KeyError as e:
|
||||
keys = list(scores.keys())
|
||||
err = Errors.E983.format(dict="score_weights", key=str(e), keys=keys)
|
||||
|
@ -378,7 +382,8 @@ def setup_printer(
|
|||
|
||||
try:
|
||||
scores = [
|
||||
"{0:.2f}".format(float(info["other_scores"].get(col, 0.0))) for col in score_cols
|
||||
"{0:.2f}".format(float(info["other_scores"].get(col, 0.0)))
|
||||
for col in score_cols
|
||||
]
|
||||
except KeyError as e:
|
||||
raise KeyError(
|
||||
|
|
|
@ -774,7 +774,13 @@ class Language:
|
|||
# we have no components to insert before/after, or we're replacing the last component
|
||||
self.add_pipe(factory_name, name=name, config=config, validate=validate)
|
||||
else:
|
||||
self.add_pipe(factory_name, name=name, before=pipe_index, config=config, validate=validate)
|
||||
self.add_pipe(
|
||||
factory_name,
|
||||
name=name,
|
||||
before=pipe_index,
|
||||
config=config,
|
||||
validate=validate,
|
||||
)
|
||||
|
||||
def rename_pipe(self, old_name: str, new_name: str) -> None:
|
||||
"""Rename a pipeline component.
|
||||
|
|
|
@ -30,6 +30,7 @@ def load_kb(kb_path: str) -> Callable[[Vocab], KnowledgeBase]:
|
|||
kb = KnowledgeBase(vocab, entity_vector_length=1)
|
||||
kb.from_disk(kb_path)
|
||||
return kb
|
||||
|
||||
return kb_from_file
|
||||
|
||||
|
||||
|
@ -37,6 +38,7 @@ def load_kb(kb_path: str) -> Callable[[Vocab], KnowledgeBase]:
|
|||
def empty_kb(entity_vector_length: int) -> Callable[[Vocab], KnowledgeBase]:
|
||||
def empty_kb_factory(vocab):
|
||||
return KnowledgeBase(vocab=vocab, entity_vector_length=entity_vector_length)
|
||||
|
||||
return empty_kb_factory
|
||||
|
||||
|
||||
|
|
|
@ -76,6 +76,7 @@ def entity_linker():
|
|||
kb = KnowledgeBase(vocab, entity_vector_length=1)
|
||||
kb.add_entity("test", 0.0, zeros((1, 1), dtype="f"))
|
||||
return kb
|
||||
|
||||
return create_kb
|
||||
|
||||
config = {"kb_loader": {"@assets": "TestIssue5230KB.v1"}}
|
||||
|
|
Loading…
Reference in New Issue
Block a user