mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-13 18:56:36 +03:00
Merge branch 'develop' into nightly.spacy.io
This commit is contained in:
commit
797ca6f3dd
|
@ -5,7 +5,7 @@
|
||||||
# data is passed in sentence-by-sentence via some prior preprocessing.
|
# data is passed in sentence-by-sentence via some prior preprocessing.
|
||||||
gold_preproc = false
|
gold_preproc = false
|
||||||
# Limitations on training document length or number of examples.
|
# Limitations on training document length or number of examples.
|
||||||
max_length = 5000
|
max_length = 3000
|
||||||
limit = 0
|
limit = 0
|
||||||
# Data augmentation
|
# Data augmentation
|
||||||
orth_variant_level = 0.0
|
orth_variant_level = 0.0
|
||||||
|
@ -17,20 +17,20 @@ max_steps = 0
|
||||||
eval_frequency = 1000
|
eval_frequency = 1000
|
||||||
# Other settings
|
# Other settings
|
||||||
seed = 0
|
seed = 0
|
||||||
accumulate_gradient = 2
|
accumulate_gradient = 1
|
||||||
use_pytorch_for_gpu_memory = false
|
use_pytorch_for_gpu_memory = false
|
||||||
# Control how scores are printed and checkpoints are evaluated.
|
# Control how scores are printed and checkpoints are evaluated.
|
||||||
scores = ["speed", "ents_p", "ents_r", "ents_f"]
|
scores = ["speed", "ents_p", "ents_r", "ents_f"]
|
||||||
score_weights = {"ents_f": 1.0}
|
score_weights = {"ents_f": 1.0}
|
||||||
# These settings are invalid for the transformer models.
|
# These settings are invalid for the transformer models.
|
||||||
init_tok2vec = null
|
init_tok2vec = null
|
||||||
discard_oversize = true
|
discard_oversize = false
|
||||||
omit_extra_lookups = false
|
omit_extra_lookups = false
|
||||||
batch_by_words = true
|
batch_by = "words"
|
||||||
|
|
||||||
[training.batch_size]
|
[training.batch_size]
|
||||||
@schedules = "compounding.v1"
|
@schedules = "compounding.v1"
|
||||||
start = 1000
|
start = 100
|
||||||
stop = 1000
|
stop = 1000
|
||||||
compound = 1.001
|
compound = 1.001
|
||||||
|
|
||||||
|
@ -45,12 +45,6 @@ use_averages = true
|
||||||
eps = 1e-8
|
eps = 1e-8
|
||||||
learn_rate = 0.001
|
learn_rate = 0.001
|
||||||
|
|
||||||
#[training.optimizer.learn_rate]
|
|
||||||
#@schedules = "warmup_linear.v1"
|
|
||||||
#warmup_steps = 1000
|
|
||||||
#total_steps = 50000
|
|
||||||
#initial_rate = 0.003
|
|
||||||
|
|
||||||
[nlp]
|
[nlp]
|
||||||
lang = "en"
|
lang = "en"
|
||||||
vectors = null
|
vectors = null
|
||||||
|
@ -74,6 +68,6 @@ width = 96
|
||||||
depth = 4
|
depth = 4
|
||||||
window_size = 1
|
window_size = 1
|
||||||
embed_size = 2000
|
embed_size = 2000
|
||||||
maxout_pieces = 1
|
maxout_pieces = 3
|
||||||
subword_features = true
|
subword_features = true
|
||||||
dropout = ${training:dropout}
|
dropout = ${training:dropout}
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
# fmt: off
|
# fmt: off
|
||||||
__title__ = "spacy-nightly"
|
__title__ = "spacy-nightly"
|
||||||
__version__ = "3.0.0a2"
|
__version__ = "3.0.0a3"
|
||||||
__release__ = True
|
__release__ = True
|
||||||
__download_url__ = "https://github.com/explosion/spacy-models/releases/download"
|
__download_url__ = "https://github.com/explosion/spacy-models/releases/download"
|
||||||
__compatibility__ = "https://raw.githubusercontent.com/explosion/spacy-models/master/compatibility.json"
|
__compatibility__ = "https://raw.githubusercontent.com/explosion/spacy-models/master/compatibility.json"
|
||||||
|
|
|
@ -120,8 +120,12 @@ def convert(
|
||||||
no_print=silent,
|
no_print=silent,
|
||||||
ner_map=ner_map,
|
ner_map=ner_map,
|
||||||
)
|
)
|
||||||
|
if file_type == "json":
|
||||||
|
data = [docs_to_json(docs)]
|
||||||
|
else:
|
||||||
|
data = DocBin(docs=docs, store_user_data=True).to_bytes()
|
||||||
if output_dir == "-":
|
if output_dir == "-":
|
||||||
_print_docs_to_stdout(docs, file_type)
|
_print_docs_to_stdout(data, file_type)
|
||||||
else:
|
else:
|
||||||
if input_loc != input_path:
|
if input_loc != input_path:
|
||||||
subpath = input_loc.relative_to(input_path)
|
subpath = input_loc.relative_to(input_path)
|
||||||
|
@ -129,24 +133,23 @@ def convert(
|
||||||
else:
|
else:
|
||||||
output_file = Path(output_dir) / input_loc.parts[-1]
|
output_file = Path(output_dir) / input_loc.parts[-1]
|
||||||
output_file = output_file.with_suffix(f".{file_type}")
|
output_file = output_file.with_suffix(f".{file_type}")
|
||||||
_write_docs_to_file(docs, output_file, file_type)
|
_write_docs_to_file(data, output_file, file_type)
|
||||||
msg.good(f"Generated output file ({len(docs)} documents): {output_file}")
|
msg.good(f"Generated output file ({len(docs)} documents): {output_file}")
|
||||||
|
|
||||||
|
|
||||||
def _print_docs_to_stdout(docs, output_type):
|
def _print_docs_to_stdout(data, output_type):
|
||||||
if output_type == "json":
|
if output_type == "json":
|
||||||
srsly.write_json("-", [docs_to_json(docs)])
|
srsly.write_json("-", data)
|
||||||
else:
|
else:
|
||||||
sys.stdout.buffer.write(DocBin(docs=docs, store_user_data=True).to_bytes())
|
sys.stdout.buffer.write(data)
|
||||||
|
|
||||||
|
|
||||||
def _write_docs_to_file(docs, output_file, output_type):
|
def _write_docs_to_file(data, output_file, output_type):
|
||||||
if not output_file.parent.exists():
|
if not output_file.parent.exists():
|
||||||
output_file.parent.mkdir(parents=True)
|
output_file.parent.mkdir(parents=True)
|
||||||
if output_type == "json":
|
if output_type == "json":
|
||||||
srsly.write_json(output_file, [docs_to_json(docs)])
|
srsly.write_json(output_file, data)
|
||||||
else:
|
else:
|
||||||
data = DocBin(docs=docs, store_user_data=True).to_bytes()
|
|
||||||
with output_file.open("wb") as file_:
|
with output_file.open("wb") as file_:
|
||||||
file_.write(data)
|
file_.write(data)
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
from typing import Optional
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from wasabi import msg
|
from wasabi import msg
|
||||||
import subprocess
|
import subprocess
|
||||||
|
@ -24,22 +25,18 @@ DIRS = [
|
||||||
@project_cli.command("clone")
|
@project_cli.command("clone")
|
||||||
def project_clone_cli(
|
def project_clone_cli(
|
||||||
# fmt: off
|
# fmt: off
|
||||||
name: str = Arg(..., help="The name of the template to fetch"),
|
name: str = Arg(..., help="The name of the template to clone"),
|
||||||
dest: Path = Arg(Path.cwd(), help="Where to download and work. Defaults to current working directory.", exists=False),
|
dest: Optional[Path] = Arg(None, help="Where to clone the project. Defaults to current working directory", exists=False),
|
||||||
repo: str = Opt(about.__projects__, "--repo", "-r", help="The repository to look in."),
|
repo: str = Opt(about.__projects__, "--repo", "-r", help="The repository to clone from"),
|
||||||
# fmt: on
|
# fmt: on
|
||||||
):
|
):
|
||||||
"""Clone a project template from a repository. Calls into "git" and will
|
"""Clone a project template from a repository. Calls into "git" and will
|
||||||
only download the files from the given subdirectory. The GitHub repo
|
only download the files from the given subdirectory. The GitHub repo
|
||||||
defaults to the official spaCy template repo, but can be customized
|
defaults to the official spaCy template repo, but can be customized
|
||||||
(including using a private repo). Setting the --git flag will also
|
(including using a private repo).
|
||||||
initialize the project directory as a Git repo. If the project is intended
|
|
||||||
to be a Git repo, it should be initialized with Git first, before
|
|
||||||
initializing DVC (Data Version Control). This allows DVC to integrate with
|
|
||||||
Git.
|
|
||||||
"""
|
"""
|
||||||
if dest == Path.cwd():
|
if dest is None:
|
||||||
dest = dest / name
|
dest = Path.cwd() / name
|
||||||
project_clone(name, dest, repo=repo)
|
project_clone(name, dest, repo=repo)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -30,7 +30,7 @@ def project_update_dvc_cli(
|
||||||
"""Auto-generate Data Version Control (DVC) config. A DVC
|
"""Auto-generate Data Version Control (DVC) config. A DVC
|
||||||
project can only define one pipeline, so you need to specify one workflow
|
project can only define one pipeline, so you need to specify one workflow
|
||||||
defined in the project.yml. If no workflow is specified, the first defined
|
defined in the project.yml. If no workflow is specified, the first defined
|
||||||
workflow is used. The DVC config will only be updated if
|
workflow is used. The DVC config will only be updated if the project.yml changed.
|
||||||
"""
|
"""
|
||||||
project_update_dvc(project_dir, workflow, verbose=verbose, force=force)
|
project_update_dvc(project_dir, workflow, verbose=verbose, force=force)
|
||||||
|
|
||||||
|
|
|
@ -20,14 +20,14 @@ def project_run_cli(
|
||||||
subcommand: str = Arg(None, help=f"Name of command defined in the {PROJECT_FILE}"),
|
subcommand: str = Arg(None, help=f"Name of command defined in the {PROJECT_FILE}"),
|
||||||
project_dir: Path = Arg(Path.cwd(), help="Location of project directory. Defaults to current working directory.", exists=True, file_okay=False),
|
project_dir: Path = Arg(Path.cwd(), help="Location of project directory. Defaults to current working directory.", exists=True, file_okay=False),
|
||||||
force: bool = Opt(False, "--force", "-F", help="Force re-running steps, even if nothing changed"),
|
force: bool = Opt(False, "--force", "-F", help="Force re-running steps, even if nothing changed"),
|
||||||
dry: bool = Opt(False, "--dry", "-D", help="Perform a dry run and don't execute commands"),
|
dry: bool = Opt(False, "--dry", "-D", help="Perform a dry run and don't execute scripts"),
|
||||||
show_help: bool = Opt(False, "--help", help="Show help message and available subcommands")
|
show_help: bool = Opt(False, "--help", help="Show help message and available subcommands")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
):
|
):
|
||||||
"""Run a named script or workflow defined in the project.yml. If a workflow
|
"""Run a named command or workflow defined in the project.yml. If a workflow
|
||||||
name is specified, all commands in the workflow are run, in order. If
|
name is specified, all commands in the workflow are run, in order. If
|
||||||
commands define inputs and/or outputs, they will only be re-run if state
|
commands define dependencies and/or outputs, they will only be re-run if
|
||||||
has changed.
|
state has changed.
|
||||||
"""
|
"""
|
||||||
if show_help or not subcommand:
|
if show_help or not subcommand:
|
||||||
print_run_help(project_dir, subcommand)
|
print_run_help(project_dir, subcommand)
|
||||||
|
|
|
@ -121,14 +121,14 @@ class ConfigSchema(BaseModel):
|
||||||
@app.command("train")
|
@app.command("train")
|
||||||
def train_cli(
|
def train_cli(
|
||||||
# fmt: off
|
# fmt: off
|
||||||
train_path: Path = Arg(..., help="Location of JSON-formatted training data", exists=True),
|
train_path: Path = Arg(..., help="Location of training data", exists=True),
|
||||||
dev_path: Path = Arg(..., help="Location of JSON-formatted development data", exists=True),
|
dev_path: Path = Arg(..., help="Location of development data", exists=True),
|
||||||
config_path: Path = Arg(..., help="Path to config file", exists=True),
|
config_path: Path = Arg(..., help="Path to config file", exists=True),
|
||||||
output_path: Optional[Path] = Opt(None, "--output", "--output-path", "-o", help="Output directory to store model in"),
|
output_path: Optional[Path] = Opt(None, "--output", "--output-path", "-o", help="Output directory to store model in"),
|
||||||
code_path: Optional[Path] = Opt(None, "--code-path", "-c", help="Path to Python file with additional code (registered functions) to be imported"),
|
code_path: Optional[Path] = Opt(None, "--code-path", "-c", help="Path to Python file with additional code (registered functions) to be imported"),
|
||||||
init_tok2vec: Optional[Path] = Opt(None, "--init-tok2vec", "-t2v", help="Path to pretrained weights for the tok2vec components. See 'spacy pretrain'. Experimental."),
|
init_tok2vec: Optional[Path] = Opt(None, "--init-tok2vec", "-t2v", help="Path to pretrained weights for the tok2vec components. See 'spacy pretrain'. Experimental."),
|
||||||
raw_text: Optional[Path] = Opt(None, "--raw-text", "-rt", help="Path to jsonl file with unlabelled text documents."),
|
raw_text: Optional[Path] = Opt(None, "--raw-text", "-rt", help="Path to jsonl file with unlabelled text documents."),
|
||||||
verbose: bool = Opt(False, "--verbose", "-VV", help="Display more information for debugging purposes"),
|
verbose: bool = Opt(False, "--verbose", "-V", "-VV", help="Display more information for debugging purposes"),
|
||||||
use_gpu: int = Opt(-1, "--use-gpu", "-g", help="Use GPU"),
|
use_gpu: int = Opt(-1, "--use-gpu", "-g", help="Use GPU"),
|
||||||
tag_map_path: Optional[Path] = Opt(None, "--tag-map-path", "-tm", help="Location of JSON-formatted tag map"),
|
tag_map_path: Optional[Path] = Opt(None, "--tag-map-path", "-tm", help="Location of JSON-formatted tag map"),
|
||||||
omit_extra_lookups: bool = Opt(False, "--omit-extra-lookups", "-OEL", help="Don't include extra lookups in model"),
|
omit_extra_lookups: bool = Opt(False, "--omit-extra-lookups", "-OEL", help="Don't include extra lookups in model"),
|
||||||
|
@ -203,8 +203,10 @@ def train(
|
||||||
msg.info(f"Initializing the nlp pipeline: {nlp.pipe_names}")
|
msg.info(f"Initializing the nlp pipeline: {nlp.pipe_names}")
|
||||||
train_examples = list(
|
train_examples = list(
|
||||||
corpus.train_dataset(
|
corpus.train_dataset(
|
||||||
nlp, shuffle=False, gold_preproc=training["gold_preproc"],
|
nlp,
|
||||||
max_length=training["max_length"]
|
shuffle=False,
|
||||||
|
gold_preproc=training["gold_preproc"],
|
||||||
|
max_length=training["max_length"],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
nlp.begin_training(lambda: train_examples)
|
nlp.begin_training(lambda: train_examples)
|
||||||
|
@ -303,21 +305,26 @@ def create_train_batches(nlp, corpus, cfg):
|
||||||
)
|
)
|
||||||
|
|
||||||
epoch = 0
|
epoch = 0
|
||||||
|
batch_strategy = cfg.get("batch_by", "sequences")
|
||||||
while True:
|
while True:
|
||||||
if len(train_examples) == 0:
|
if len(train_examples) == 0:
|
||||||
raise ValueError(Errors.E988)
|
raise ValueError(Errors.E988)
|
||||||
epoch += 1
|
epoch += 1
|
||||||
if cfg.get("batch_by_words", True):
|
if batch_strategy == "padded":
|
||||||
|
batches = util.minibatch_by_padded_size(
|
||||||
|
train_examples,
|
||||||
|
size=cfg["batch_size"],
|
||||||
|
buffer=256,
|
||||||
|
discard_oversize=cfg["discard_oversize"],
|
||||||
|
)
|
||||||
|
elif batch_strategy == "words":
|
||||||
batches = util.minibatch_by_words(
|
batches = util.minibatch_by_words(
|
||||||
train_examples,
|
train_examples,
|
||||||
size=cfg["batch_size"],
|
size=cfg["batch_size"],
|
||||||
discard_oversize=cfg["discard_oversize"],
|
discard_oversize=cfg["discard_oversize"],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
batches = util.minibatch(
|
batches = util.minibatch(train_examples, size=cfg["batch_size"])
|
||||||
train_examples,
|
|
||||||
size=cfg["batch_size"],
|
|
||||||
)
|
|
||||||
|
|
||||||
# make sure the minibatch_by_words result is not empty, or we'll have an infinite training loop
|
# make sure the minibatch_by_words result is not empty, or we'll have an infinite training loop
|
||||||
try:
|
try:
|
||||||
|
@ -430,7 +437,9 @@ def train_while_improving(
|
||||||
|
|
||||||
if raw_text:
|
if raw_text:
|
||||||
random.shuffle(raw_text)
|
random.shuffle(raw_text)
|
||||||
raw_examples = [Example.from_dict(nlp.make_doc(rt["text"]), {}) for rt in raw_text]
|
raw_examples = [
|
||||||
|
Example.from_dict(nlp.make_doc(rt["text"]), {}) for rt in raw_text
|
||||||
|
]
|
||||||
raw_batches = util.minibatch(raw_examples, size=8)
|
raw_batches = util.minibatch(raw_examples, size=8)
|
||||||
|
|
||||||
for step, (epoch, batch) in enumerate(train_data):
|
for step, (epoch, batch) in enumerate(train_data):
|
||||||
|
|
|
@ -69,6 +69,9 @@ class Warnings(object):
|
||||||
W027 = ("Found a large training file of {size} bytes. Note that it may "
|
W027 = ("Found a large training file of {size} bytes. Note that it may "
|
||||||
"be more efficient to split your training data into multiple "
|
"be more efficient to split your training data into multiple "
|
||||||
"smaller JSON files instead.")
|
"smaller JSON files instead.")
|
||||||
|
W028 = ("Doc.from_array was called with a vector of type '{type}', "
|
||||||
|
"but is expecting one of type 'uint64' instead. This may result "
|
||||||
|
"in problems with the vocab further on in the pipeline.")
|
||||||
W030 = ("Some entities could not be aligned in the text \"{text}\" with "
|
W030 = ("Some entities could not be aligned in the text \"{text}\" with "
|
||||||
"entities \"{entities}\". Use "
|
"entities \"{entities}\". Use "
|
||||||
"`spacy.gold.biluo_tags_from_offsets(nlp.make_doc(text), entities)`"
|
"`spacy.gold.biluo_tags_from_offsets(nlp.make_doc(text), entities)`"
|
||||||
|
|
|
@ -36,6 +36,9 @@ cdef class Example:
|
||||||
self.y = reference
|
self.y = reference
|
||||||
self._alignment = alignment
|
self._alignment = alignment
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.predicted)
|
||||||
|
|
||||||
property predicted:
|
property predicted:
|
||||||
def __get__(self):
|
def __get__(self):
|
||||||
return self.x
|
return self.x
|
||||||
|
@ -326,8 +329,8 @@ def _fix_legacy_dict_data(example_dict):
|
||||||
for key, value in old_token_dict.items():
|
for key, value in old_token_dict.items():
|
||||||
if key in ("text", "ids", "brackets"):
|
if key in ("text", "ids", "brackets"):
|
||||||
pass
|
pass
|
||||||
elif key in remapping:
|
elif key.lower() in remapping:
|
||||||
token_dict[remapping[key]] = value
|
token_dict[remapping[key.lower()]] = value
|
||||||
else:
|
else:
|
||||||
raise KeyError(Errors.E983.format(key=key, dict="token_annotation", keys=remapping.keys()))
|
raise KeyError(Errors.E983.format(key=key, dict="token_annotation", keys=remapping.keys()))
|
||||||
text = example_dict.get("text", example_dict.get("raw"))
|
text = example_dict.get("text", example_dict.get("raw"))
|
||||||
|
|
|
@ -513,20 +513,23 @@ class Language(object):
|
||||||
):
|
):
|
||||||
"""Update the models in the pipeline.
|
"""Update the models in the pipeline.
|
||||||
|
|
||||||
examples (iterable): A batch of `Example` objects.
|
examples (Iterable[Example]): A batch of examples
|
||||||
dummy: Should not be set - serves to catch backwards-incompatible scripts.
|
dummy: Should not be set - serves to catch backwards-incompatible scripts.
|
||||||
drop (float): The dropout rate.
|
drop (float): The dropout rate.
|
||||||
sgd (callable): An optimizer.
|
sgd (Optimizer): An optimizer.
|
||||||
losses (dict): Dictionary to update with the loss, keyed by component.
|
losses (Dict[str, float]): Dictionary to update with the loss, keyed by component.
|
||||||
component_cfg (dict): Config parameters for specific pipeline
|
component_cfg (Dict[str, Dict]): Config parameters for specific pipeline
|
||||||
components, keyed by component name.
|
components, keyed by component name.
|
||||||
|
RETURNS (Dict[str, float]): The updated losses dictionary
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/language#update
|
DOCS: https://spacy.io/api/language#update
|
||||||
"""
|
"""
|
||||||
if dummy is not None:
|
if dummy is not None:
|
||||||
raise ValueError(Errors.E989)
|
raise ValueError(Errors.E989)
|
||||||
|
if losses is None:
|
||||||
|
losses = {}
|
||||||
if len(examples) == 0:
|
if len(examples) == 0:
|
||||||
return
|
return losses
|
||||||
if not isinstance(examples, Iterable):
|
if not isinstance(examples, Iterable):
|
||||||
raise TypeError(Errors.E978.format(name="language", method="update", types=type(examples)))
|
raise TypeError(Errors.E978.format(name="language", method="update", types=type(examples)))
|
||||||
wrong_types = set([type(eg) for eg in examples if not isinstance(eg, Example)])
|
wrong_types = set([type(eg) for eg in examples if not isinstance(eg, Example)])
|
||||||
|
@ -552,6 +555,7 @@ class Language(object):
|
||||||
for name, proc in self.pipeline:
|
for name, proc in self.pipeline:
|
||||||
if hasattr(proc, "model"):
|
if hasattr(proc, "model"):
|
||||||
proc.model.finish_update(sgd)
|
proc.model.finish_update(sgd)
|
||||||
|
return losses
|
||||||
|
|
||||||
def rehearse(self, examples, sgd=None, losses=None, config=None):
|
def rehearse(self, examples, sgd=None, losses=None, config=None):
|
||||||
"""Make a "rehearsal" update to the models in the pipeline, to prevent
|
"""Make a "rehearsal" update to the models in the pipeline, to prevent
|
||||||
|
@ -757,18 +761,17 @@ class Language(object):
|
||||||
):
|
):
|
||||||
"""Process texts as a stream, and yield `Doc` objects in order.
|
"""Process texts as a stream, and yield `Doc` objects in order.
|
||||||
|
|
||||||
texts (iterator): A sequence of texts to process.
|
texts (Iterable[str]): A sequence of texts to process.
|
||||||
as_tuples (bool): If set to True, inputs should be a sequence of
|
as_tuples (bool): If set to True, inputs should be a sequence of
|
||||||
(text, context) tuples. Output will then be a sequence of
|
(text, context) tuples. Output will then be a sequence of
|
||||||
(doc, context) tuples. Defaults to False.
|
(doc, context) tuples. Defaults to False.
|
||||||
batch_size (int): The number of texts to buffer.
|
batch_size (int): The number of texts to buffer.
|
||||||
disable (list): Names of the pipeline components to disable.
|
disable (List[str]): Names of the pipeline components to disable.
|
||||||
cleanup (bool): If True, unneeded strings are freed to control memory
|
cleanup (bool): If True, unneeded strings are freed to control memory
|
||||||
use. Experimental.
|
use. Experimental.
|
||||||
component_cfg (dict): An optional dictionary with extra keyword
|
component_cfg (Dict[str, Dict]): An optional dictionary with extra keyword
|
||||||
arguments for specific components.
|
arguments for specific components.
|
||||||
n_process (int): Number of processors to process texts, only supported
|
n_process (int): Number of processors to process texts. If -1, set `multiprocessing.cpu_count()`.
|
||||||
in Python3. If -1, set `multiprocessing.cpu_count()`.
|
|
||||||
YIELDS (Doc): Documents in the order of the original text.
|
YIELDS (Doc): Documents in the order of the original text.
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/language#pipe
|
DOCS: https://spacy.io/api/language#pipe
|
||||||
|
|
|
@ -87,16 +87,16 @@ def build_text_classifier(
|
||||||
cols = [ORTH, LOWER, PREFIX, SUFFIX, SHAPE, ID]
|
cols = [ORTH, LOWER, PREFIX, SUFFIX, SHAPE, ID]
|
||||||
with Model.define_operators({">>": chain, "|": concatenate, "**": clone}):
|
with Model.define_operators({">>": chain, "|": concatenate, "**": clone}):
|
||||||
lower = HashEmbed(
|
lower = HashEmbed(
|
||||||
nO=width, nV=embed_size, column=cols.index(LOWER), dropout=dropout
|
nO=width, nV=embed_size, column=cols.index(LOWER), dropout=dropout, seed=10
|
||||||
)
|
)
|
||||||
prefix = HashEmbed(
|
prefix = HashEmbed(
|
||||||
nO=width // 2, nV=embed_size, column=cols.index(PREFIX), dropout=dropout
|
nO=width // 2, nV=embed_size, column=cols.index(PREFIX), dropout=dropout, seed=11
|
||||||
)
|
)
|
||||||
suffix = HashEmbed(
|
suffix = HashEmbed(
|
||||||
nO=width // 2, nV=embed_size, column=cols.index(SUFFIX), dropout=dropout
|
nO=width // 2, nV=embed_size, column=cols.index(SUFFIX), dropout=dropout, seed=12
|
||||||
)
|
)
|
||||||
shape = HashEmbed(
|
shape = HashEmbed(
|
||||||
nO=width // 2, nV=embed_size, column=cols.index(SHAPE), dropout=dropout
|
nO=width // 2, nV=embed_size, column=cols.index(SHAPE), dropout=dropout, seed=13
|
||||||
)
|
)
|
||||||
|
|
||||||
width_nI = sum(layer.get_dim("nO") for layer in [lower, prefix, suffix, shape])
|
width_nI = sum(layer.get_dim("nO") for layer in [lower, prefix, suffix, shape])
|
||||||
|
|
|
@ -154,16 +154,16 @@ def LayerNormalizedMaxout(width, maxout_pieces):
|
||||||
def MultiHashEmbed(
|
def MultiHashEmbed(
|
||||||
columns, width, rows, use_subwords, pretrained_vectors, mix, dropout
|
columns, width, rows, use_subwords, pretrained_vectors, mix, dropout
|
||||||
):
|
):
|
||||||
norm = HashEmbed(nO=width, nV=rows, column=columns.index("NORM"), dropout=dropout)
|
norm = HashEmbed(nO=width, nV=rows, column=columns.index("NORM"), dropout=dropout, seed=6)
|
||||||
if use_subwords:
|
if use_subwords:
|
||||||
prefix = HashEmbed(
|
prefix = HashEmbed(
|
||||||
nO=width, nV=rows // 2, column=columns.index("PREFIX"), dropout=dropout
|
nO=width, nV=rows // 2, column=columns.index("PREFIX"), dropout=dropout, seed=7
|
||||||
)
|
)
|
||||||
suffix = HashEmbed(
|
suffix = HashEmbed(
|
||||||
nO=width, nV=rows // 2, column=columns.index("SUFFIX"), dropout=dropout
|
nO=width, nV=rows // 2, column=columns.index("SUFFIX"), dropout=dropout, seed=8
|
||||||
)
|
)
|
||||||
shape = HashEmbed(
|
shape = HashEmbed(
|
||||||
nO=width, nV=rows // 2, column=columns.index("SHAPE"), dropout=dropout
|
nO=width, nV=rows // 2, column=columns.index("SHAPE"), dropout=dropout, seed=9
|
||||||
)
|
)
|
||||||
|
|
||||||
if pretrained_vectors:
|
if pretrained_vectors:
|
||||||
|
@ -192,7 +192,7 @@ def MultiHashEmbed(
|
||||||
|
|
||||||
@registry.architectures.register("spacy.CharacterEmbed.v1")
|
@registry.architectures.register("spacy.CharacterEmbed.v1")
|
||||||
def CharacterEmbed(columns, width, rows, nM, nC, features, dropout):
|
def CharacterEmbed(columns, width, rows, nM, nC, features, dropout):
|
||||||
norm = HashEmbed(nO=width, nV=rows, column=columns.index("NORM"), dropout=dropout)
|
norm = HashEmbed(nO=width, nV=rows, column=columns.index("NORM"), dropout=dropout, seed=5)
|
||||||
chr_embed = _character_embed.CharacterEmbed(nM=nM, nC=nC)
|
chr_embed = _character_embed.CharacterEmbed(nM=nM, nC=nC)
|
||||||
with Model.define_operators({">>": chain, "|": concatenate}):
|
with Model.define_operators({">>": chain, "|": concatenate}):
|
||||||
embed_layer = chr_embed | features >> with_array(norm)
|
embed_layer = chr_embed | features >> with_array(norm)
|
||||||
|
|
|
@ -58,12 +58,8 @@ class Pipe(object):
|
||||||
Both __call__ and pipe should delegate to the `predict()`
|
Both __call__ and pipe should delegate to the `predict()`
|
||||||
and `set_annotations()` methods.
|
and `set_annotations()` methods.
|
||||||
"""
|
"""
|
||||||
predictions = self.predict([doc])
|
scores = self.predict([doc])
|
||||||
if isinstance(predictions, tuple) and len(predictions) == 2:
|
self.set_annotations([doc], scores)
|
||||||
scores, tensors = predictions
|
|
||||||
self.set_annotations([doc], scores, tensors=tensors)
|
|
||||||
else:
|
|
||||||
self.set_annotations([doc], predictions)
|
|
||||||
return doc
|
return doc
|
||||||
|
|
||||||
def pipe(self, stream, batch_size=128):
|
def pipe(self, stream, batch_size=128):
|
||||||
|
@ -73,12 +69,8 @@ class Pipe(object):
|
||||||
and `set_annotations()` methods.
|
and `set_annotations()` methods.
|
||||||
"""
|
"""
|
||||||
for docs in util.minibatch(stream, size=batch_size):
|
for docs in util.minibatch(stream, size=batch_size):
|
||||||
predictions = self.predict(docs)
|
scores = self.predict(docs)
|
||||||
if isinstance(predictions, tuple) and len(tuple) == 2:
|
self.set_annotations(docs, scores)
|
||||||
scores, tensors = predictions
|
|
||||||
self.set_annotations(docs, scores, tensors=tensors)
|
|
||||||
else:
|
|
||||||
self.set_annotations(docs, predictions)
|
|
||||||
yield from docs
|
yield from docs
|
||||||
|
|
||||||
def predict(self, docs):
|
def predict(self, docs):
|
||||||
|
@ -87,7 +79,7 @@ class Pipe(object):
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def set_annotations(self, docs, scores, tensors=None):
|
def set_annotations(self, docs, scores):
|
||||||
"""Modify a batch of documents, using pre-computed scores."""
|
"""Modify a batch of documents, using pre-computed scores."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@ -281,9 +273,10 @@ class Tagger(Pipe):
|
||||||
idx += 1
|
idx += 1
|
||||||
doc.is_tagged = True
|
doc.is_tagged = True
|
||||||
|
|
||||||
def update(self, examples, drop=0., sgd=None, losses=None, set_annotations=False):
|
def update(self, examples, *, drop=0., sgd=None, losses=None, set_annotations=False):
|
||||||
if losses is not None and self.name not in losses:
|
if losses is None:
|
||||||
losses[self.name] = 0.
|
losses = {}
|
||||||
|
losses.setdefault(self.name, 0.0)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not any(len(eg.predicted) if eg.predicted else 0 for eg in examples):
|
if not any(len(eg.predicted) if eg.predicted else 0 for eg in examples):
|
||||||
|
@ -303,11 +296,11 @@ class Tagger(Pipe):
|
||||||
if sgd not in (None, False):
|
if sgd not in (None, False):
|
||||||
self.model.finish_update(sgd)
|
self.model.finish_update(sgd)
|
||||||
|
|
||||||
if losses is not None:
|
|
||||||
losses[self.name] += loss
|
losses[self.name] += loss
|
||||||
if set_annotations:
|
if set_annotations:
|
||||||
docs = [eg.predicted for eg in examples]
|
docs = [eg.predicted for eg in examples]
|
||||||
self.set_annotations(docs, self._scores2guesses(tag_scores))
|
self.set_annotations(docs, self._scores2guesses(tag_scores))
|
||||||
|
return losses
|
||||||
|
|
||||||
def rehearse(self, examples, drop=0., sgd=None, losses=None):
|
def rehearse(self, examples, drop=0., sgd=None, losses=None):
|
||||||
"""Perform a 'rehearsal' update, where we try to match the output of
|
"""Perform a 'rehearsal' update, where we try to match the output of
|
||||||
|
@ -635,7 +628,7 @@ class MultitaskObjective(Tagger):
|
||||||
def labels(self, value):
|
def labels(self, value):
|
||||||
self.cfg["labels"] = value
|
self.cfg["labels"] = value
|
||||||
|
|
||||||
def set_annotations(self, docs, dep_ids, tensors=None):
|
def set_annotations(self, docs, dep_ids):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def begin_training(self, get_examples=lambda: [], pipeline=None,
|
def begin_training(self, get_examples=lambda: [], pipeline=None,
|
||||||
|
@ -732,7 +725,7 @@ class ClozeMultitask(Pipe):
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
self.distance = CosineDistance(ignore_zeros=True, normalize=False) # TODO: in config
|
self.distance = CosineDistance(ignore_zeros=True, normalize=False) # TODO: in config
|
||||||
|
|
||||||
def set_annotations(self, docs, dep_ids, tensors=None):
|
def set_annotations(self, docs, dep_ids):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def begin_training(self, get_examples=lambda: [], pipeline=None,
|
def begin_training(self, get_examples=lambda: [], pipeline=None,
|
||||||
|
@ -761,7 +754,7 @@ class ClozeMultitask(Pipe):
|
||||||
loss = self.distance.get_loss(prediction, target)
|
loss = self.distance.get_loss(prediction, target)
|
||||||
return loss, gradient
|
return loss, gradient
|
||||||
|
|
||||||
def update(self, examples, drop=0., set_annotations=False, sgd=None, losses=None):
|
def update(self, examples, *, drop=0., set_annotations=False, sgd=None, losses=None):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def rehearse(self, examples, drop=0., sgd=None, losses=None):
|
def rehearse(self, examples, drop=0., sgd=None, losses=None):
|
||||||
|
@ -809,8 +802,8 @@ class TextCategorizer(Pipe):
|
||||||
|
|
||||||
def pipe(self, stream, batch_size=128):
|
def pipe(self, stream, batch_size=128):
|
||||||
for docs in util.minibatch(stream, size=batch_size):
|
for docs in util.minibatch(stream, size=batch_size):
|
||||||
scores, tensors = self.predict(docs)
|
scores = self.predict(docs)
|
||||||
self.set_annotations(docs, scores, tensors=tensors)
|
self.set_annotations(docs, scores)
|
||||||
yield from docs
|
yield from docs
|
||||||
|
|
||||||
def predict(self, docs):
|
def predict(self, docs):
|
||||||
|
@ -820,22 +813,25 @@ class TextCategorizer(Pipe):
|
||||||
# Handle cases where there are no tokens in any docs.
|
# Handle cases where there are no tokens in any docs.
|
||||||
xp = get_array_module(tensors)
|
xp = get_array_module(tensors)
|
||||||
scores = xp.zeros((len(docs), len(self.labels)))
|
scores = xp.zeros((len(docs), len(self.labels)))
|
||||||
return scores, tensors
|
return scores
|
||||||
|
|
||||||
scores = self.model.predict(docs)
|
scores = self.model.predict(docs)
|
||||||
scores = self.model.ops.asarray(scores)
|
scores = self.model.ops.asarray(scores)
|
||||||
return scores, tensors
|
return scores
|
||||||
|
|
||||||
def set_annotations(self, docs, scores, tensors=None):
|
def set_annotations(self, docs, scores):
|
||||||
for i, doc in enumerate(docs):
|
for i, doc in enumerate(docs):
|
||||||
for j, label in enumerate(self.labels):
|
for j, label in enumerate(self.labels):
|
||||||
doc.cats[label] = float(scores[i, j])
|
doc.cats[label] = float(scores[i, j])
|
||||||
|
|
||||||
def update(self, examples, state=None, drop=0., set_annotations=False, sgd=None, losses=None):
|
def update(self, examples, *, drop=0., set_annotations=False, sgd=None, losses=None):
|
||||||
|
if losses is None:
|
||||||
|
losses = {}
|
||||||
|
losses.setdefault(self.name, 0.0)
|
||||||
try:
|
try:
|
||||||
if not any(len(eg.predicted) if eg.predicted else 0 for eg in examples):
|
if not any(len(eg.predicted) if eg.predicted else 0 for eg in examples):
|
||||||
# Handle cases where there are no tokens in any docs.
|
# Handle cases where there are no tokens in any docs.
|
||||||
return
|
return losses
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
types = set([type(eg) for eg in examples])
|
types = set([type(eg) for eg in examples])
|
||||||
raise TypeError(Errors.E978.format(name="TextCategorizer", method="update", types=types))
|
raise TypeError(Errors.E978.format(name="TextCategorizer", method="update", types=types))
|
||||||
|
@ -847,12 +843,11 @@ class TextCategorizer(Pipe):
|
||||||
bp_scores(d_scores)
|
bp_scores(d_scores)
|
||||||
if sgd is not None:
|
if sgd is not None:
|
||||||
self.model.finish_update(sgd)
|
self.model.finish_update(sgd)
|
||||||
if losses is not None:
|
|
||||||
losses.setdefault(self.name, 0.0)
|
|
||||||
losses[self.name] += loss
|
losses[self.name] += loss
|
||||||
if set_annotations:
|
if set_annotations:
|
||||||
docs = [eg.predicted for eg in examples]
|
docs = [eg.predicted for eg in examples]
|
||||||
self.set_annotations(docs, scores=scores)
|
self.set_annotations(docs, scores=scores)
|
||||||
|
return losses
|
||||||
|
|
||||||
def rehearse(self, examples, drop=0., sgd=None, losses=None):
|
def rehearse(self, examples, drop=0., sgd=None, losses=None):
|
||||||
if self._rehearsal_model is None:
|
if self._rehearsal_model is None:
|
||||||
|
@ -1076,12 +1071,13 @@ class EntityLinker(Pipe):
|
||||||
sgd = self.create_optimizer()
|
sgd = self.create_optimizer()
|
||||||
return sgd
|
return sgd
|
||||||
|
|
||||||
def update(self, examples, state=None, set_annotations=False, drop=0.0, sgd=None, losses=None):
|
def update(self, examples, *, set_annotations=False, drop=0.0, sgd=None, losses=None):
|
||||||
self.require_kb()
|
self.require_kb()
|
||||||
if losses is not None:
|
if losses is None:
|
||||||
|
losses = {}
|
||||||
losses.setdefault(self.name, 0.0)
|
losses.setdefault(self.name, 0.0)
|
||||||
if not examples:
|
if not examples:
|
||||||
return 0
|
return losses
|
||||||
sentence_docs = []
|
sentence_docs = []
|
||||||
try:
|
try:
|
||||||
docs = [eg.predicted for eg in examples]
|
docs = [eg.predicted for eg in examples]
|
||||||
|
@ -1124,20 +1120,19 @@ class EntityLinker(Pipe):
|
||||||
return 0.0
|
return 0.0
|
||||||
sentence_encodings, bp_context = self.model.begin_update(sentence_docs)
|
sentence_encodings, bp_context = self.model.begin_update(sentence_docs)
|
||||||
loss, d_scores = self.get_similarity_loss(
|
loss, d_scores = self.get_similarity_loss(
|
||||||
scores=sentence_encodings,
|
sentence_encodings=sentence_encodings,
|
||||||
examples=examples
|
examples=examples
|
||||||
)
|
)
|
||||||
bp_context(d_scores)
|
bp_context(d_scores)
|
||||||
if sgd is not None:
|
if sgd is not None:
|
||||||
self.model.finish_update(sgd)
|
self.model.finish_update(sgd)
|
||||||
|
|
||||||
if losses is not None:
|
|
||||||
losses[self.name] += loss
|
losses[self.name] += loss
|
||||||
if set_annotations:
|
if set_annotations:
|
||||||
self.set_annotations(docs, predictions)
|
self.set_annotations(docs, predictions)
|
||||||
return loss
|
return losses
|
||||||
|
|
||||||
def get_similarity_loss(self, examples, scores):
|
def get_similarity_loss(self, examples, sentence_encodings):
|
||||||
entity_encodings = []
|
entity_encodings = []
|
||||||
for eg in examples:
|
for eg in examples:
|
||||||
kb_ids = eg.get_aligned("ENT_KB_ID", as_string=True)
|
kb_ids = eg.get_aligned("ENT_KB_ID", as_string=True)
|
||||||
|
@ -1149,41 +1144,23 @@ class EntityLinker(Pipe):
|
||||||
|
|
||||||
entity_encodings = self.model.ops.asarray(entity_encodings, dtype="float32")
|
entity_encodings = self.model.ops.asarray(entity_encodings, dtype="float32")
|
||||||
|
|
||||||
if scores.shape != entity_encodings.shape:
|
if sentence_encodings.shape != entity_encodings.shape:
|
||||||
raise RuntimeError(Errors.E147.format(method="get_similarity_loss", msg="gold entities do not match up"))
|
raise RuntimeError(Errors.E147.format(method="get_similarity_loss", msg="gold entities do not match up"))
|
||||||
|
|
||||||
gradients = self.distance.get_grad(scores, entity_encodings)
|
gradients = self.distance.get_grad(sentence_encodings, entity_encodings)
|
||||||
loss = self.distance.get_loss(scores, entity_encodings)
|
loss = self.distance.get_loss(sentence_encodings, entity_encodings)
|
||||||
loss = loss / len(entity_encodings)
|
loss = loss / len(entity_encodings)
|
||||||
return loss, gradients
|
return loss, gradients
|
||||||
|
|
||||||
def get_loss(self, examples, scores):
|
|
||||||
cats = []
|
|
||||||
for eg in examples:
|
|
||||||
kb_ids = eg.get_aligned("ENT_KB_ID", as_string=True)
|
|
||||||
for ent in eg.predicted.ents:
|
|
||||||
kb_id = kb_ids[ent.start]
|
|
||||||
if kb_id:
|
|
||||||
cats.append([1.0])
|
|
||||||
|
|
||||||
cats = self.model.ops.asarray(cats, dtype="float32")
|
|
||||||
if len(scores) != len(cats):
|
|
||||||
raise RuntimeError(Errors.E147.format(method="get_loss", msg="gold entities do not match up"))
|
|
||||||
|
|
||||||
d_scores = (scores - cats)
|
|
||||||
loss = (d_scores ** 2).sum()
|
|
||||||
loss = loss / len(cats)
|
|
||||||
return loss, d_scores
|
|
||||||
|
|
||||||
def __call__(self, doc):
|
def __call__(self, doc):
|
||||||
kb_ids, tensors = self.predict([doc])
|
kb_ids = self.predict([doc])
|
||||||
self.set_annotations([doc], kb_ids, tensors=tensors)
|
self.set_annotations([doc], kb_ids)
|
||||||
return doc
|
return doc
|
||||||
|
|
||||||
def pipe(self, stream, batch_size=128):
|
def pipe(self, stream, batch_size=128):
|
||||||
for docs in util.minibatch(stream, size=batch_size):
|
for docs in util.minibatch(stream, size=batch_size):
|
||||||
kb_ids, tensors = self.predict(docs)
|
kb_ids = self.predict(docs)
|
||||||
self.set_annotations(docs, kb_ids, tensors=tensors)
|
self.set_annotations(docs, kb_ids)
|
||||||
yield from docs
|
yield from docs
|
||||||
|
|
||||||
def predict(self, docs):
|
def predict(self, docs):
|
||||||
|
@ -1191,10 +1168,9 @@ class EntityLinker(Pipe):
|
||||||
self.require_kb()
|
self.require_kb()
|
||||||
entity_count = 0
|
entity_count = 0
|
||||||
final_kb_ids = []
|
final_kb_ids = []
|
||||||
final_tensors = []
|
|
||||||
|
|
||||||
if not docs:
|
if not docs:
|
||||||
return final_kb_ids, final_tensors
|
return final_kb_ids
|
||||||
|
|
||||||
if isinstance(docs, Doc):
|
if isinstance(docs, Doc):
|
||||||
docs = [docs]
|
docs = [docs]
|
||||||
|
@ -1228,21 +1204,18 @@ class EntityLinker(Pipe):
|
||||||
if to_discard and ent.label_ in to_discard:
|
if to_discard and ent.label_ in to_discard:
|
||||||
# ignoring this entity - setting to NIL
|
# ignoring this entity - setting to NIL
|
||||||
final_kb_ids.append(self.NIL)
|
final_kb_ids.append(self.NIL)
|
||||||
final_tensors.append(sentence_encoding)
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
candidates = self.kb.get_candidates(ent.text)
|
candidates = self.kb.get_candidates(ent.text)
|
||||||
if not candidates:
|
if not candidates:
|
||||||
# no prediction possible for this entity - setting to NIL
|
# no prediction possible for this entity - setting to NIL
|
||||||
final_kb_ids.append(self.NIL)
|
final_kb_ids.append(self.NIL)
|
||||||
final_tensors.append(sentence_encoding)
|
|
||||||
|
|
||||||
elif len(candidates) == 1:
|
elif len(candidates) == 1:
|
||||||
# shortcut for efficiency reasons: take the 1 candidate
|
# shortcut for efficiency reasons: take the 1 candidate
|
||||||
|
|
||||||
# TODO: thresholding
|
# TODO: thresholding
|
||||||
final_kb_ids.append(candidates[0].entity_)
|
final_kb_ids.append(candidates[0].entity_)
|
||||||
final_tensors.append(sentence_encoding)
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
random.shuffle(candidates)
|
random.shuffle(candidates)
|
||||||
|
@ -1271,14 +1244,13 @@ class EntityLinker(Pipe):
|
||||||
best_index = scores.argmax().item()
|
best_index = scores.argmax().item()
|
||||||
best_candidate = candidates[best_index]
|
best_candidate = candidates[best_index]
|
||||||
final_kb_ids.append(best_candidate.entity_)
|
final_kb_ids.append(best_candidate.entity_)
|
||||||
final_tensors.append(sentence_encoding)
|
|
||||||
|
|
||||||
if not (len(final_tensors) == len(final_kb_ids) == entity_count):
|
if not (len(final_kb_ids) == entity_count):
|
||||||
raise RuntimeError(Errors.E147.format(method="predict", msg="result variables not of equal length"))
|
raise RuntimeError(Errors.E147.format(method="predict", msg="result variables not of equal length"))
|
||||||
|
|
||||||
return final_kb_ids, final_tensors
|
return final_kb_ids
|
||||||
|
|
||||||
def set_annotations(self, docs, kb_ids, tensors=None):
|
def set_annotations(self, docs, kb_ids):
|
||||||
count_ents = len([ent for doc in docs for ent in doc.ents])
|
count_ents = len([ent for doc in docs for ent in doc.ents])
|
||||||
if count_ents != len(kb_ids):
|
if count_ents != len(kb_ids):
|
||||||
raise ValueError(Errors.E148.format(ents=count_ents, ids=len(kb_ids)))
|
raise ValueError(Errors.E148.format(ents=count_ents, ids=len(kb_ids)))
|
||||||
|
@ -1394,10 +1366,6 @@ class Sentencizer(Pipe):
|
||||||
def pipe(self, stream, batch_size=128):
|
def pipe(self, stream, batch_size=128):
|
||||||
for docs in util.minibatch(stream, size=batch_size):
|
for docs in util.minibatch(stream, size=batch_size):
|
||||||
predictions = self.predict(docs)
|
predictions = self.predict(docs)
|
||||||
if isinstance(predictions, tuple) and len(tuple) == 2:
|
|
||||||
scores, tensors = predictions
|
|
||||||
self.set_annotations(docs, scores, tensors=tensors)
|
|
||||||
else:
|
|
||||||
self.set_annotations(docs, predictions)
|
self.set_annotations(docs, predictions)
|
||||||
yield from docs
|
yield from docs
|
||||||
|
|
||||||
|
@ -1429,7 +1397,7 @@ class Sentencizer(Pipe):
|
||||||
guesses.append(doc_guesses)
|
guesses.append(doc_guesses)
|
||||||
return guesses
|
return guesses
|
||||||
|
|
||||||
def set_annotations(self, docs, batch_tag_ids, tensors=None):
|
def set_annotations(self, docs, batch_tag_ids):
|
||||||
if isinstance(docs, Doc):
|
if isinstance(docs, Doc):
|
||||||
docs = [docs]
|
docs = [docs]
|
||||||
cdef Doc doc
|
cdef Doc doc
|
||||||
|
|
|
@ -57,7 +57,7 @@ class SimpleNER(Pipe):
|
||||||
scores = self.model.predict(docs)
|
scores = self.model.predict(docs)
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
def set_annotations(self, docs: List[Doc], scores: List[Floats2d], tensors=None):
|
def set_annotations(self, docs: List[Doc], scores: List[Floats2d]):
|
||||||
"""Set entities on a batch of documents from a batch of scores."""
|
"""Set entities on a batch of documents from a batch of scores."""
|
||||||
tag_names = self.get_tag_names()
|
tag_names = self.get_tag_names()
|
||||||
for i, doc in enumerate(docs):
|
for i, doc in enumerate(docs):
|
||||||
|
@ -67,9 +67,12 @@ class SimpleNER(Pipe):
|
||||||
tags = iob_to_biluo(tags)
|
tags = iob_to_biluo(tags)
|
||||||
doc.ents = spans_from_biluo_tags(doc, tags)
|
doc.ents = spans_from_biluo_tags(doc, tags)
|
||||||
|
|
||||||
def update(self, examples, set_annotations=False, drop=0.0, sgd=None, losses=None):
|
def update(self, examples, *, set_annotations=False, drop=0.0, sgd=None, losses=None):
|
||||||
|
if losses is None:
|
||||||
|
losses = {}
|
||||||
|
losses.setdefault("ner", 0.0)
|
||||||
if not any(_has_ner(eg) for eg in examples):
|
if not any(_has_ner(eg) for eg in examples):
|
||||||
return 0
|
return losses
|
||||||
docs = [eg.predicted for eg in examples]
|
docs = [eg.predicted for eg in examples]
|
||||||
set_dropout_rate(self.model, drop)
|
set_dropout_rate(self.model, drop)
|
||||||
scores, bp_scores = self.model.begin_update(docs)
|
scores, bp_scores = self.model.begin_update(docs)
|
||||||
|
@ -79,10 +82,8 @@ class SimpleNER(Pipe):
|
||||||
self.set_annotations(docs, scores)
|
self.set_annotations(docs, scores)
|
||||||
if sgd is not None:
|
if sgd is not None:
|
||||||
self.model.finish_update(sgd)
|
self.model.finish_update(sgd)
|
||||||
if losses is not None:
|
|
||||||
losses.setdefault("ner", 0.0)
|
|
||||||
losses["ner"] += loss
|
losses["ner"] += loss
|
||||||
return loss
|
return losses
|
||||||
|
|
||||||
def get_loss(self, examples, scores):
|
def get_loss(self, examples, scores):
|
||||||
loss = 0
|
loss = 0
|
||||||
|
|
|
@ -83,12 +83,14 @@ class Tok2Vec(Pipe):
|
||||||
assert tokvecs.shape[0] == len(doc)
|
assert tokvecs.shape[0] == len(doc)
|
||||||
doc.tensor = tokvecs
|
doc.tensor = tokvecs
|
||||||
|
|
||||||
def update(self, examples, drop=0.0, sgd=None, losses=None, set_annotations=False):
|
def update(self, examples, *, drop=0.0, sgd=None, losses=None, set_annotations=False):
|
||||||
"""Update the model.
|
"""Update the model.
|
||||||
examples (iterable): A batch of examples
|
examples (Iterable[Example]): A batch of examples
|
||||||
drop (float): The droput rate.
|
drop (float): The droput rate.
|
||||||
sgd (callable): An optimizer.
|
sgd (Optimizer): An optimizer.
|
||||||
RETURNS (dict): Results from the update.
|
losses (Dict[str, float]): Dictionary to update with the loss, keyed by component.
|
||||||
|
set_annotations (bool): whether or not to update the examples with the predictions
|
||||||
|
RETURNS (Dict[str, float]): The updated losses dictionary
|
||||||
"""
|
"""
|
||||||
if losses is None:
|
if losses is None:
|
||||||
losses = {}
|
losses = {}
|
||||||
|
@ -124,6 +126,7 @@ class Tok2Vec(Pipe):
|
||||||
self.listeners[-1].receive(batch_id, tokvecs, backprop)
|
self.listeners[-1].receive(batch_id, tokvecs, backprop)
|
||||||
if set_annotations:
|
if set_annotations:
|
||||||
self.set_annotations(docs, tokvecs)
|
self.set_annotations(docs, tokvecs)
|
||||||
|
return losses
|
||||||
|
|
||||||
def get_loss(self, docs, golds, scores):
|
def get_loss(self, docs, golds, scores):
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -153,7 +153,7 @@ cdef class Parser:
|
||||||
doc (Doc): The document to be processed.
|
doc (Doc): The document to be processed.
|
||||||
"""
|
"""
|
||||||
states = self.predict([doc])
|
states = self.predict([doc])
|
||||||
self.set_annotations([doc], states, tensors=None)
|
self.set_annotations([doc], states)
|
||||||
return doc
|
return doc
|
||||||
|
|
||||||
def pipe(self, docs, int batch_size=256):
|
def pipe(self, docs, int batch_size=256):
|
||||||
|
@ -170,7 +170,7 @@ cdef class Parser:
|
||||||
for subbatch in util.minibatch(by_length, size=max(batch_size//4, 2)):
|
for subbatch in util.minibatch(by_length, size=max(batch_size//4, 2)):
|
||||||
subbatch = list(subbatch)
|
subbatch = list(subbatch)
|
||||||
parse_states = self.predict(subbatch)
|
parse_states = self.predict(subbatch)
|
||||||
self.set_annotations(subbatch, parse_states, tensors=None)
|
self.set_annotations(subbatch, parse_states)
|
||||||
yield from batch_in_order
|
yield from batch_in_order
|
||||||
|
|
||||||
def predict(self, docs):
|
def predict(self, docs):
|
||||||
|
@ -222,7 +222,7 @@ cdef class Parser:
|
||||||
unfinished.clear()
|
unfinished.clear()
|
||||||
free_activations(&activations)
|
free_activations(&activations)
|
||||||
|
|
||||||
def set_annotations(self, docs, states, tensors=None):
|
def set_annotations(self, docs, states):
|
||||||
cdef StateClass state
|
cdef StateClass state
|
||||||
cdef Doc doc
|
cdef Doc doc
|
||||||
for i, (state, doc) in enumerate(zip(states, docs)):
|
for i, (state, doc) in enumerate(zip(states, docs)):
|
||||||
|
@ -263,7 +263,7 @@ cdef class Parser:
|
||||||
states[i].push_hist(guess)
|
states[i].push_hist(guess)
|
||||||
free(is_valid)
|
free(is_valid)
|
||||||
|
|
||||||
def update(self, examples, drop=0., set_annotations=False, sgd=None, losses=None):
|
def update(self, examples, *, drop=0., set_annotations=False, sgd=None, losses=None):
|
||||||
cdef StateClass state
|
cdef StateClass state
|
||||||
if losses is None:
|
if losses is None:
|
||||||
losses = {}
|
losses = {}
|
||||||
|
|
|
@ -302,7 +302,7 @@ def test_multiple_predictions():
|
||||||
def predict(self, docs):
|
def predict(self, docs):
|
||||||
return ([1, 2, 3], [4, 5, 6])
|
return ([1, 2, 3], [4, 5, 6])
|
||||||
|
|
||||||
def set_annotations(self, docs, scores, tensors=None):
|
def set_annotations(self, docs, scores):
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
nlp = Language()
|
nlp = Language()
|
||||||
|
|
31
spacy/tests/regression/test_issue5551.py
Normal file
31
spacy/tests/regression/test_issue5551.py
Normal file
|
@ -0,0 +1,31 @@
|
||||||
|
from spacy.lang.en import English
|
||||||
|
from spacy.util import fix_random_seed
|
||||||
|
|
||||||
|
|
||||||
|
def test_issue5551():
|
||||||
|
"""Test that after fixing the random seed, the results of the pipeline are truly identical"""
|
||||||
|
component = "textcat"
|
||||||
|
pipe_cfg = {"exclusive_classes": False}
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for i in range(3):
|
||||||
|
fix_random_seed(0)
|
||||||
|
nlp = English()
|
||||||
|
example = (
|
||||||
|
"Once hot, form ping-pong-ball-sized balls of the mixture, each weighing roughly 25 g.",
|
||||||
|
{"cats": {"Labe1": 1.0, "Label2": 0.0, "Label3": 0.0}},
|
||||||
|
)
|
||||||
|
nlp.add_pipe(nlp.create_pipe(component, config=pipe_cfg), last=True)
|
||||||
|
pipe = nlp.get_pipe(component)
|
||||||
|
for label in set(example[1]["cats"]):
|
||||||
|
pipe.add_label(label)
|
||||||
|
nlp.begin_training(component_cfg={component: pipe_cfg})
|
||||||
|
|
||||||
|
# Store the result of each iteration
|
||||||
|
result = pipe.model.predict([nlp.make_doc(example[0])])
|
||||||
|
results.append(list(result[0]))
|
||||||
|
|
||||||
|
# All results should be the same because of the fixed seed
|
||||||
|
assert len(results) == 3
|
||||||
|
assert results[0] == results[1]
|
||||||
|
assert results[0] == results[2]
|
|
@ -1,3 +1,4 @@
|
||||||
|
import numpy
|
||||||
from spacy.errors import AlignmentError
|
from spacy.errors import AlignmentError
|
||||||
from spacy.gold import biluo_tags_from_offsets, offsets_from_biluo_tags
|
from spacy.gold import biluo_tags_from_offsets, offsets_from_biluo_tags
|
||||||
from spacy.gold import spans_from_biluo_tags, iob_to_biluo
|
from spacy.gold import spans_from_biluo_tags, iob_to_biluo
|
||||||
|
@ -154,6 +155,27 @@ def test_gold_biluo_misalign(en_vocab):
|
||||||
assert tags == ["O", "O", "O", "-", "-", "-"]
|
assert tags == ["O", "O", "O", "-", "-", "-"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_example_constructor(en_vocab):
|
||||||
|
words = ["I", "like", "stuff"]
|
||||||
|
tags = ["NOUN", "VERB", "NOUN"]
|
||||||
|
tag_ids = [en_vocab.strings.add(tag) for tag in tags]
|
||||||
|
predicted = Doc(en_vocab, words=words)
|
||||||
|
reference = Doc(en_vocab, words=words)
|
||||||
|
reference = reference.from_array("TAG", numpy.array(tag_ids, dtype="uint64"))
|
||||||
|
example = Example(predicted, reference)
|
||||||
|
tags = example.get_aligned("TAG", as_string=True)
|
||||||
|
assert tags == ["NOUN", "VERB", "NOUN"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_example_from_dict_tags(en_vocab):
|
||||||
|
words = ["I", "like", "stuff"]
|
||||||
|
tags = ["NOUN", "VERB", "NOUN"]
|
||||||
|
predicted = Doc(en_vocab, words=words)
|
||||||
|
example = Example.from_dict(predicted, {"TAGS": tags})
|
||||||
|
tags = example.get_aligned("TAG", as_string=True)
|
||||||
|
assert tags == ["NOUN", "VERB", "NOUN"]
|
||||||
|
|
||||||
|
|
||||||
def test_example_from_dict_no_ner(en_vocab):
|
def test_example_from_dict_no_ner(en_vocab):
|
||||||
words = ["a", "b", "c", "d"]
|
words = ["a", "b", "c", "d"]
|
||||||
spaces = [True, True, False, True]
|
spaces = [True, True, False, True]
|
||||||
|
|
156
spacy/tests/test_models.py
Normal file
156
spacy/tests/test_models.py
Normal file
|
@ -0,0 +1,156 @@
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from thinc.api import fix_random_seed, Adam, set_dropout_rate
|
||||||
|
from numpy.testing import assert_array_equal
|
||||||
|
import numpy
|
||||||
|
|
||||||
|
from spacy.ml.models import build_Tok2Vec_model
|
||||||
|
from spacy.ml.models import build_text_classifier, build_simple_cnn_text_classifier
|
||||||
|
from spacy.lang.en import English
|
||||||
|
from spacy.lang.en.examples import sentences as EN_SENTENCES
|
||||||
|
|
||||||
|
|
||||||
|
def get_all_params(model):
|
||||||
|
params = []
|
||||||
|
for node in model.walk():
|
||||||
|
for name in node.param_names:
|
||||||
|
params.append(node.get_param(name).ravel())
|
||||||
|
return node.ops.xp.concatenate(params)
|
||||||
|
|
||||||
|
|
||||||
|
def get_docs():
|
||||||
|
nlp = English()
|
||||||
|
return list(nlp.pipe(EN_SENTENCES + [" ".join(EN_SENTENCES)]))
|
||||||
|
|
||||||
|
|
||||||
|
def get_gradient(model, Y):
|
||||||
|
if isinstance(Y, model.ops.xp.ndarray):
|
||||||
|
dY = model.ops.alloc(Y.shape, dtype=Y.dtype)
|
||||||
|
dY += model.ops.xp.random.uniform(-1.0, 1.0, Y.shape)
|
||||||
|
return dY
|
||||||
|
elif isinstance(Y, List):
|
||||||
|
return [get_gradient(model, y) for y in Y]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Could not compare type {type(Y)}")
|
||||||
|
|
||||||
|
|
||||||
|
def default_tok2vec():
|
||||||
|
return build_Tok2Vec_model(**TOK2VEC_KWARGS)
|
||||||
|
|
||||||
|
|
||||||
|
TOK2VEC_KWARGS = {
|
||||||
|
"width": 96,
|
||||||
|
"embed_size": 2000,
|
||||||
|
"subword_features": True,
|
||||||
|
"char_embed": False,
|
||||||
|
"conv_depth": 4,
|
||||||
|
"bilstm_depth": 0,
|
||||||
|
"maxout_pieces": 4,
|
||||||
|
"window_size": 1,
|
||||||
|
"dropout": 0.1,
|
||||||
|
"nM": 0,
|
||||||
|
"nC": 0,
|
||||||
|
"pretrained_vectors": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
TEXTCAT_KWARGS = {
|
||||||
|
"width": 64,
|
||||||
|
"embed_size": 2000,
|
||||||
|
"pretrained_vectors": None,
|
||||||
|
"exclusive_classes": False,
|
||||||
|
"ngram_size": 1,
|
||||||
|
"window_size": 1,
|
||||||
|
"conv_depth": 2,
|
||||||
|
"dropout": None,
|
||||||
|
"nO": 7
|
||||||
|
}
|
||||||
|
|
||||||
|
TEXTCAT_CNN_KWARGS = {
|
||||||
|
"tok2vec": default_tok2vec(),
|
||||||
|
"exclusive_classes": False,
|
||||||
|
"nO": 13,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"seed,model_func,kwargs",
|
||||||
|
[
|
||||||
|
(0, build_Tok2Vec_model, TOK2VEC_KWARGS),
|
||||||
|
(0, build_text_classifier, TEXTCAT_KWARGS),
|
||||||
|
(0, build_simple_cnn_text_classifier, TEXTCAT_CNN_KWARGS),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_models_initialize_consistently(seed, model_func, kwargs):
|
||||||
|
fix_random_seed(seed)
|
||||||
|
model1 = model_func(**kwargs)
|
||||||
|
model1.initialize()
|
||||||
|
fix_random_seed(seed)
|
||||||
|
model2 = model_func(**kwargs)
|
||||||
|
model2.initialize()
|
||||||
|
params1 = get_all_params(model1)
|
||||||
|
params2 = get_all_params(model2)
|
||||||
|
assert_array_equal(params1, params2)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"seed,model_func,kwargs,get_X",
|
||||||
|
[
|
||||||
|
(0, build_Tok2Vec_model, TOK2VEC_KWARGS, get_docs),
|
||||||
|
(0, build_text_classifier, TEXTCAT_KWARGS, get_docs),
|
||||||
|
(0, build_simple_cnn_text_classifier, TEXTCAT_CNN_KWARGS, get_docs),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_models_predict_consistently(seed, model_func, kwargs, get_X):
|
||||||
|
fix_random_seed(seed)
|
||||||
|
model1 = model_func(**kwargs).initialize()
|
||||||
|
Y1 = model1.predict(get_X())
|
||||||
|
fix_random_seed(seed)
|
||||||
|
model2 = model_func(**kwargs).initialize()
|
||||||
|
Y2 = model2.predict(get_X())
|
||||||
|
|
||||||
|
if model1.has_ref("tok2vec"):
|
||||||
|
tok2vec1 = model1.get_ref("tok2vec").predict(get_X())
|
||||||
|
tok2vec2 = model2.get_ref("tok2vec").predict(get_X())
|
||||||
|
for i in range(len(tok2vec1)):
|
||||||
|
for j in range(len(tok2vec1[i])):
|
||||||
|
assert_array_equal(numpy.asarray(tok2vec1[i][j]), numpy.asarray(tok2vec2[i][j]))
|
||||||
|
|
||||||
|
if isinstance(Y1, numpy.ndarray):
|
||||||
|
assert_array_equal(Y1, Y2)
|
||||||
|
elif isinstance(Y1, List):
|
||||||
|
assert len(Y1) == len(Y2)
|
||||||
|
for y1, y2 in zip(Y1, Y2):
|
||||||
|
assert_array_equal(y1, y2)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Could not compare type {type(Y1)}")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"seed,dropout,model_func,kwargs,get_X",
|
||||||
|
[
|
||||||
|
(0, 0.2, build_Tok2Vec_model, TOK2VEC_KWARGS, get_docs),
|
||||||
|
(0, 0.2, build_text_classifier, TEXTCAT_KWARGS, get_docs),
|
||||||
|
(0, 0.2, build_simple_cnn_text_classifier, TEXTCAT_CNN_KWARGS, get_docs),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_models_update_consistently(seed, dropout, model_func, kwargs, get_X):
|
||||||
|
def get_updated_model():
|
||||||
|
fix_random_seed(seed)
|
||||||
|
optimizer = Adam(0.001)
|
||||||
|
model = model_func(**kwargs).initialize()
|
||||||
|
initial_params = get_all_params(model)
|
||||||
|
set_dropout_rate(model, dropout)
|
||||||
|
for _ in range(5):
|
||||||
|
Y, get_dX = model.begin_update(get_X())
|
||||||
|
dY = get_gradient(model, Y)
|
||||||
|
_ = get_dX(dY)
|
||||||
|
model.finish_update(optimizer)
|
||||||
|
updated_params = get_all_params(model)
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
assert_array_equal(initial_params, updated_params)
|
||||||
|
return model
|
||||||
|
|
||||||
|
model1 = get_updated_model()
|
||||||
|
model2 = get_updated_model()
|
||||||
|
assert_array_equal(get_all_params(model1), get_all_params(model2))
|
|
@ -803,7 +803,7 @@ cdef class Doc:
|
||||||
attrs = [(IDS[id_.upper()] if hasattr(id_, "upper") else id_)
|
attrs = [(IDS[id_.upper()] if hasattr(id_, "upper") else id_)
|
||||||
for id_ in attrs]
|
for id_ in attrs]
|
||||||
if array.dtype != numpy.uint64:
|
if array.dtype != numpy.uint64:
|
||||||
warnings.warn(Warnings.W101.format(type=array.dtype))
|
warnings.warn(Warnings.W028.format(type=array.dtype))
|
||||||
|
|
||||||
if SENT_START in attrs and HEAD in attrs:
|
if SENT_START in attrs and HEAD in attrs:
|
||||||
raise ValueError(Errors.E032)
|
raise ValueError(Errors.E032)
|
||||||
|
|
|
@ -741,6 +741,50 @@ def minibatch(items, size=8):
|
||||||
yield list(batch)
|
yield list(batch)
|
||||||
|
|
||||||
|
|
||||||
|
def minibatch_by_padded_size(docs, size, buffer=256, discard_oversize=False):
|
||||||
|
if isinstance(size, int):
|
||||||
|
size_ = itertools.repeat(size)
|
||||||
|
else:
|
||||||
|
size_ = size
|
||||||
|
for outer_batch in minibatch(docs, buffer):
|
||||||
|
outer_batch = list(outer_batch)
|
||||||
|
target_size = next(size_)
|
||||||
|
for indices in _batch_by_length(outer_batch, target_size):
|
||||||
|
subbatch = [outer_batch[i] for i in indices]
|
||||||
|
padded_size = max(len(seq) for seq in subbatch) * len(subbatch)
|
||||||
|
if discard_oversize and padded_size >= target_size:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
yield subbatch
|
||||||
|
|
||||||
|
|
||||||
|
def _batch_by_length(seqs, max_words):
|
||||||
|
"""Given a list of sequences, return a batched list of indices into the
|
||||||
|
list, where the batches are grouped by length, in descending order.
|
||||||
|
|
||||||
|
Batches may be at most max_words in size, defined as max sequence length * size.
|
||||||
|
"""
|
||||||
|
# Use negative index so we can get sort by position ascending.
|
||||||
|
lengths_indices = [(len(seq), i) for i, seq in enumerate(seqs)]
|
||||||
|
lengths_indices.sort()
|
||||||
|
batches = []
|
||||||
|
batch = []
|
||||||
|
for length, i in lengths_indices:
|
||||||
|
if not batch:
|
||||||
|
batch.append(i)
|
||||||
|
elif length * (len(batch) + 1) <= max_words:
|
||||||
|
batch.append(i)
|
||||||
|
else:
|
||||||
|
batches.append(batch)
|
||||||
|
batch = [i]
|
||||||
|
if batch:
|
||||||
|
batches.append(batch)
|
||||||
|
# Check lengths match
|
||||||
|
assert sum(len(b) for b in batches) == len(seqs)
|
||||||
|
batches = [list(sorted(batch)) for batch in batches]
|
||||||
|
batches.reverse()
|
||||||
|
return batches
|
||||||
|
|
||||||
def minibatch_by_words(docs, size, tolerance=0.2, discard_oversize=False):
|
def minibatch_by_words(docs, size, tolerance=0.2, discard_oversize=False):
|
||||||
"""Create minibatches of roughly a given number of words. If any examples
|
"""Create minibatches of roughly a given number of words. If any examples
|
||||||
are longer than the specified batch length, they will appear in a batch by
|
are longer than the specified batch length, they will appear in a batch by
|
||||||
|
@ -787,6 +831,7 @@ def minibatch_by_words(docs, size, tolerance=0.2, discard_oversize=False):
|
||||||
|
|
||||||
# yield the previous batch and start a new one. The new one gets the overflow examples.
|
# yield the previous batch and start a new one. The new one gets the overflow examples.
|
||||||
else:
|
else:
|
||||||
|
if batch:
|
||||||
yield batch
|
yield batch
|
||||||
target_size = next(size_)
|
target_size = next(size_)
|
||||||
tol_size = target_size * tolerance
|
tol_size = target_size * tolerance
|
||||||
|
@ -807,15 +852,15 @@ def minibatch_by_words(docs, size, tolerance=0.2, discard_oversize=False):
|
||||||
|
|
||||||
# this example does not fit with the previous overflow: start another new batch
|
# this example does not fit with the previous overflow: start another new batch
|
||||||
else:
|
else:
|
||||||
|
if batch:
|
||||||
yield batch
|
yield batch
|
||||||
target_size = next(size_)
|
target_size = next(size_)
|
||||||
tol_size = target_size * tolerance
|
tol_size = target_size * tolerance
|
||||||
batch = [doc]
|
batch = [doc]
|
||||||
batch_size = n_words
|
batch_size = n_words
|
||||||
|
|
||||||
# yield the final batch
|
|
||||||
if batch:
|
|
||||||
batch.extend(overflow)
|
batch.extend(overflow)
|
||||||
|
if batch:
|
||||||
yield batch
|
yield batch
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user