Tidy up, autoformat, add types

Ines Montani 2020-07-25 15:01:15 +02:00
parent 71242327b2
commit e92df281ce
37 changed files with 423 additions and 454 deletions

.gitignore (vendored): 1 addition
View File

@ -51,6 +51,7 @@ env3.*/
.denv
.pypyenv
.pytest_cache/
.mypy_cache/
# Distribution / packaging
env/

View File

@ -108,3 +108,8 @@ exclude =
[tool:pytest]
markers =
slow
[mypy]
ignore_missing_imports = True
no_implicit_optional = True
plugins = pydantic.mypy, thinc.mypy
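
As a hedged illustration (not part of the commit), this is the kind of signature the new no_implicit_optional = True setting makes mypy reject, and the Optional form it expects instead; the pydantic and thinc plugins only extend type inference for those libraries.

from typing import Optional
from pathlib import Path

def load_text(path: Path = None) -> str:
    # mypy error under no_implicit_optional: a None default requires Optional[Path]
    return path.read_text() if path is not None else ""

def load_text_ok(path: Optional[Path] = None) -> str:
    # accepted: the None default is spelled out in the annotation
    return path.read_text() if path is not None else ""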

View File

@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, Any, List, Union
from enum import Enum
from pathlib import Path
from wasabi import Printer
@ -66,10 +66,9 @@ def convert_cli(
file_type = file_type.value
input_path = Path(input_path)
output_dir = "-" if output_dir == Path("-") else output_dir
cli_args = locals()
silent = output_dir == "-"
msg = Printer(no_print=silent)
verify_cli_args(msg, **cli_args)
verify_cli_args(msg, input_path, output_dir, file_type, converter, ner_map)
converter = _get_converter(msg, converter, input_path)
convert(
input_path,
@ -89,8 +88,8 @@ def convert_cli(
def convert(
input_path: Path,
output_dir: Path,
input_path: Union[str, Path],
output_dir: Union[str, Path],
*,
file_type: str = "json",
n_sents: int = 1,
@ -102,13 +101,12 @@ def convert(
ner_map: Optional[Path] = None,
lang: Optional[str] = None,
silent: bool = True,
msg: Optional[Path] = None,
msg: Optional[Printer],
) -> None:
if not msg:
msg = Printer(no_print=silent)
ner_map = srsly.read_json(ner_map) if ner_map is not None else None
for input_loc in walk_directory(input_path):
for input_loc in walk_directory(Path(input_path)):
input_data = input_loc.open("r", encoding="utf-8").read()
# Use converter function to convert data
func = CONVERTERS[converter]
@ -140,14 +138,14 @@ def convert(
msg.good(f"Generated output file ({len(docs)} documents): {output_file}")
def _print_docs_to_stdout(data, output_type):
def _print_docs_to_stdout(data: Any, output_type: str) -> None:
if output_type == "json":
srsly.write_json("-", data)
else:
sys.stdout.buffer.write(data)
def _write_docs_to_file(data, output_file, output_type):
def _write_docs_to_file(data: Any, output_file: Path, output_type: str) -> None:
if not output_file.parent.exists():
output_file.parent.mkdir(parents=True)
if output_type == "json":
@ -157,7 +155,7 @@ def _write_docs_to_file(data, output_file, output_type):
file_.write(data)
def autodetect_ner_format(input_data: str) -> str:
def autodetect_ner_format(input_data: str) -> Optional[str]:
# guess format from the first 20 lines
lines = input_data.split("\n")[:20]
format_guesses = {"ner": 0, "iob": 0}
@ -176,7 +174,7 @@ def autodetect_ner_format(input_data: str) -> str:
return None
def walk_directory(path):
def walk_directory(path: Path) -> List[Path]:
if not path.is_dir():
return [path]
paths = [path]
@ -196,31 +194,25 @@ def walk_directory(path):
def verify_cli_args(
msg,
input_path,
output_dir,
file_type,
n_sents,
seg_sents,
model,
morphology,
merge_subtokens,
converter,
ner_map,
lang,
msg: Printer,
input_path: Union[str, Path],
output_dir: Union[str, Path],
file_type: FileTypes,
converter: str,
ner_map: Optional[Path],
):
input_path = Path(input_path)
if file_type not in FILE_TYPES_STDOUT and output_dir == "-":
# TODO: support msgpack via stdout in srsly?
msg.fail(
f"Can't write .{file_type} data to stdout",
"Please specify an output directory.",
f"Can't write .{file_type} data to stdout. Please specify an output directory.",
exits=1,
)
if not input_path.exists():
msg.fail("Input file not found", input_path, exits=1)
if output_dir != "-" and not Path(output_dir).exists():
msg.fail("Output directory not found", output_dir, exits=1)
if ner_map is not None and not Path(ner_map).exists():
msg.fail("NER map not found", ner_map, exits=1)
if input_path.is_dir():
input_locs = walk_directory(input_path)
if len(input_locs) == 0:
@ -229,10 +221,8 @@ def verify_cli_args(
if len(file_types) >= 2:
file_types = ",".join(file_types)
msg.fail("All input files must be same type", file_types, exits=1)
converter = _get_converter(msg, converter, input_path)
if converter not in CONVERTERS:
if converter != "auto" and converter not in CONVERTERS:
msg.fail(f"Can't find converter for {converter}", exits=1)
return converter
def _get_converter(msg, converter, input_path):
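
A minimal sketch, assuming nothing beyond what the hunks above show, of the pattern convert() now follows: take Union[str, Path], normalize with Path(), and only create a wasabi.Printer when no msg was passed in. The process function and its body are illustrative, not the real CLI.

from typing import Optional, Union
from pathlib import Path
from wasabi import Printer

def process(input_path: Union[str, Path], *, silent: bool = True,
            msg: Optional[Printer] = None) -> None:
    if not msg:
        # Same fallback convert() uses when called without an explicit Printer
        msg = Printer(no_print=silent)
    input_path = Path(input_path)  # strings and Path objects both accepted
    msg.good(f"Would convert files under {input_path}")

process("corpus/", silent=False)  # hypothetical directory name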

View File

@ -82,11 +82,11 @@ def evaluate(
"NER P": "ents_p",
"NER R": "ents_r",
"NER F": "ents_f",
"Textcat AUC": 'textcat_macro_auc',
"Textcat F": 'textcat_macro_f',
"Sent P": 'sents_p',
"Sent R": 'sents_r',
"Sent F": 'sents_f',
"Textcat AUC": "textcat_macro_auc",
"Textcat F": "textcat_macro_f",
"Sent P": "sents_p",
"Sent R": "sents_r",
"Sent F": "sents_f",
}
results = {}
for metric, key in metrics.items():
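
For context, the normalized mapping is just display names to score keys; below is a toy version of the lookup loop that follows it (the scores dict is invented, and the real loop body is truncated in the hunk above).

scores = {"ents_f": 0.84, "sents_f": 0.91}  # invented evaluation output
metrics = {"NER F": "ents_f", "Sent F": "sents_f", "Textcat F": "textcat_macro_f"}
results = {metric: scores[key] for metric, key in metrics.items() if key in scores}
print(results)  # {'NER F': 0.84, 'Sent F': 0.91}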

View File

@ -14,7 +14,6 @@ import typer
from ._util import app, Arg, Opt, parse_config_overrides, show_validation_error
from ._util import import_code
from ..gold import Corpus, Example
from ..lookups import Lookups
from ..language import Language
from .. import util
from ..errors import Errors

View File

@ -1,12 +1,5 @@
"""
Helpers for Python and platform compatibility. To distinguish them from
the builtin functions, replacement functions are suffixed with an underscore,
e.g. `unicode_`.
DOCS: https://spacy.io/api/top-level#compat
"""
"""Helpers for Python and platform compatibility."""
import sys
from thinc.util import copy_array
try:
@ -40,21 +33,3 @@ copy_array = copy_array
is_windows = sys.platform.startswith("win")
is_linux = sys.platform.startswith("linux")
is_osx = sys.platform == "darwin"
def is_config(windows=None, linux=None, osx=None, **kwargs):
"""Check if a specific configuration of Python version and operating system
matches the user's setup. Mostly used to display targeted error messages.
windows (bool): spaCy is executed on Windows.
linux (bool): spaCy is executed on Linux.
osx (bool): spaCy is executed on OS X or macOS.
RETURNS (bool): Whether the configuration matches the user's platform.
DOCS: https://spacy.io/api/top-level#compat.is_config
"""
return (
windows in (None, is_windows)
and linux in (None, is_linux)
and osx in (None, is_osx)
)
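
With is_config() removed, callers are presumably expected to read the module-level flags directly; a small sketch of that assumed replacement pattern:

from spacy.compat import is_windows, is_linux, is_osx

# Roughly equivalent to the removed is_config(windows=True) check
if is_windows:
    print("running on Windows")
elif is_linux or is_osx:
    print("running on Linux or macOS")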

View File

@ -4,6 +4,7 @@ spaCy's built in visualization suite for dependencies and named entities.
DOCS: https://spacy.io/api/top-level#displacy
USAGE: https://spacy.io/usage/visualizers
"""
from typing import Union, Iterable, Optional, Dict, Any, Callable
import warnings
from .render import DependencyRenderer, EntityRenderer
@ -17,11 +18,17 @@ RENDER_WRAPPER = None
def render(
docs, style="dep", page=False, minify=False, jupyter=None, options={}, manual=False
):
docs: Union[Iterable[Doc], Doc],
style: str = "dep",
page: bool = False,
minify: bool = False,
jupyter: Optional[bool] = None,
options: Dict[str, Any] = {},
manual: bool = False,
) -> str:
"""Render displaCy visualisation.
docs (list or Doc): Document(s) to visualise.
docs (Union[Iterable[Doc], Doc]): Document(s) to visualise.
style (str): Visualisation style, 'dep' or 'ent'.
page (bool): Render markup as full HTML page.
minify (bool): Minify HTML markup.
@ -44,8 +51,8 @@ def render(
docs = [obj if not isinstance(obj, Span) else obj.as_doc() for obj in docs]
if not all(isinstance(obj, (Doc, Span, dict)) for obj in docs):
raise ValueError(Errors.E096)
renderer, converter = factories[style]
renderer = renderer(options=options)
renderer_func, converter = factories[style]
renderer = renderer_func(options=options)
parsed = [converter(doc, options) for doc in docs] if not manual else docs
_html["parsed"] = renderer.render(parsed, page=page, minify=minify).strip()
html = _html["parsed"]
@ -61,15 +68,15 @@ def render(
def serve(
docs,
style="dep",
page=True,
minify=False,
options={},
manual=False,
port=5000,
host="0.0.0.0",
):
docs: Union[Iterable[Doc], Doc],
style: str = "dep",
page: bool = True,
minify: bool = False,
options: Dict[str, Any] = {},
manual: bool = False,
port: int = 5000,
host: str = "0.0.0.0",
) -> None:
"""Serve displaCy visualisation.
docs (list or Doc): Document(s) to visualise.
@ -88,7 +95,6 @@ def serve(
if is_in_jupyter():
warnings.warn(Warnings.W011)
render(docs, style=style, page=page, minify=minify, options=options, manual=manual)
httpd = simple_server.make_server(host, port, app)
print(f"\nUsing the '{style}' visualizer")
@ -102,14 +108,13 @@ def serve(
def app(environ, start_response):
# Headers and status need to be bytes in Python 2, see #1227
headers = [("Content-type", "text/html; charset=utf-8")]
start_response("200 OK", headers)
res = _html["parsed"].encode(encoding="utf-8")
return [res]
def parse_deps(orig_doc, options={}):
def parse_deps(orig_doc: Doc, options: Dict[str, Any] = {}) -> Dict[str, Any]:
"""Generate dependency parse in {'words': [], 'arcs': []} format.
doc (Doc): Document do parse.
@ -152,7 +157,6 @@ def parse_deps(orig_doc, options={}):
}
for w in doc
]
arcs = []
for word in doc:
if word.i < word.head.i:
@ -171,7 +175,7 @@ def parse_deps(orig_doc, options={}):
return {"words": words, "arcs": arcs, "settings": get_doc_settings(orig_doc)}
def parse_ents(doc, options={}):
def parse_ents(doc: Doc, options: Dict[str, Any] = {}) -> Dict[str, Any]:
"""Generate named entities in [{start: i, end: i, label: 'label'}] format.
doc (Doc): Document do parse.
@ -188,7 +192,7 @@ def parse_ents(doc, options={}):
return {"text": doc.text, "ents": ents, "title": title, "settings": settings}
def set_render_wrapper(func):
def set_render_wrapper(func: Callable[[str], str]) -> None:
"""Set an optional wrapper function that is called around the generated
HTML markup on displacy.render. This can be used to allow integration into
other platforms, similar to Jupyter Notebooks that require functions to be
@ -205,7 +209,7 @@ def set_render_wrapper(func):
RENDER_WRAPPER = func
def get_doc_settings(doc):
def get_doc_settings(doc: Doc) -> Dict[str, Any]:
return {
"lang": doc.lang_,
"direction": doc.vocab.writing_system.get("direction", "ltr"),

View File

@ -1,19 +1,36 @@
from typing import Dict, Any, List, Optional, Union
import uuid
from .templates import (
TPL_DEP_SVG,
TPL_DEP_WORDS,
TPL_DEP_WORDS_LEMMA,
TPL_DEP_ARCS,
TPL_ENTS,
)
from .templates import TPL_DEP_SVG, TPL_DEP_WORDS, TPL_DEP_WORDS_LEMMA, TPL_DEP_ARCS
from .templates import TPL_ENT, TPL_ENT_RTL, TPL_FIGURE, TPL_TITLE, TPL_PAGE
from .templates import TPL_ENTS
from ..util import minify_html, escape_html, registry
from ..errors import Errors
DEFAULT_LANG = "en"
DEFAULT_DIR = "ltr"
DEFAULT_ENTITY_COLOR = "#ddd"
DEFAULT_LABEL_COLORS = {
"ORG": "#7aecec",
"PRODUCT": "#bfeeb7",
"GPE": "#feca74",
"LOC": "#ff9561",
"PERSON": "#aa9cfc",
"NORP": "#c887fb",
"FACILITY": "#9cc9cc",
"EVENT": "#ffeb80",
"LAW": "#ff8197",
"LANGUAGE": "#ff8197",
"WORK_OF_ART": "#f0d0ff",
"DATE": "#bfe1d9",
"TIME": "#bfe1d9",
"MONEY": "#e4e7d2",
"QUANTITY": "#e4e7d2",
"ORDINAL": "#e4e7d2",
"CARDINAL": "#e4e7d2",
"PERCENT": "#e4e7d2",
}
class DependencyRenderer:
@ -21,7 +38,7 @@ class DependencyRenderer:
style = "dep"
def __init__(self, options={}):
def __init__(self, options: Dict[str, Any] = {}) -> None:
"""Initialise dependency renderer.
options (dict): Visualiser-specific options (compact, word_spacing,
@ -41,7 +58,9 @@ class DependencyRenderer:
self.direction = DEFAULT_DIR
self.lang = DEFAULT_LANG
def render(self, parsed, page=False, minify=False):
def render(
self, parsed: List[Dict[str, Any]], page: bool = False, minify: bool = False
) -> str:
"""Render complete markup.
parsed (list): Dependency parses to render.
@ -72,10 +91,15 @@ class DependencyRenderer:
return minify_html(markup)
return markup
def render_svg(self, render_id, words, arcs):
def render_svg(
self,
render_id: Union[int, str],
words: List[Dict[str, Any]],
arcs: List[Dict[str, Any]],
) -> str:
"""Render SVG.
render_id (int): Unique ID, typically index of document.
render_id (Union[int, str]): Unique ID, typically index of document.
words (list): Individual words and their tags.
arcs (list): Individual arcs and their start, end, direction and label.
RETURNS (str): Rendered SVG markup.
@ -86,15 +110,15 @@ class DependencyRenderer:
self.width = self.offset_x + len(words) * self.distance
self.height = self.offset_y + 3 * self.word_spacing
self.id = render_id
words = [
words_svg = [
self.render_word(w["text"], w["tag"], w.get("lemma", None), i)
for i, w in enumerate(words)
]
arcs = [
arcs_svg = [
self.render_arrow(a["label"], a["start"], a["end"], a["dir"], i)
for i, a in enumerate(arcs)
]
content = "".join(words) + "".join(arcs)
content = "".join(words_svg) + "".join(arcs_svg)
return TPL_DEP_SVG.format(
id=self.id,
width=self.width,
@ -107,9 +131,7 @@ class DependencyRenderer:
lang=self.lang,
)
def render_word(
self, text, tag, lemma, i,
):
def render_word(self, text: str, tag: str, lemma: str, i: int) -> str:
"""Render individual word.
text (str): Word text.
@ -128,7 +150,9 @@ class DependencyRenderer:
)
return TPL_DEP_WORDS.format(text=html_text, tag=tag, x=x, y=y)
def render_arrow(self, label, start, end, direction, i):
def render_arrow(
self, label: str, start: int, end: int, direction: str, i: int
) -> str:
"""Render individual arrow.
label (str): Dependency label.
@ -172,7 +196,7 @@ class DependencyRenderer:
arc=arc,
)
def get_arc(self, x_start, y, y_curve, x_end):
def get_arc(self, x_start: int, y: int, y_curve: int, x_end: int) -> str:
"""Render individual arc.
x_start (int): X-coordinate of arrow start point.
@ -186,7 +210,7 @@ class DependencyRenderer:
template = "M{x},{y} {x},{c} {e},{c} {e},{y}"
return template.format(x=x_start, y=y, c=y_curve, e=x_end)
def get_arrowhead(self, direction, x, y, end):
def get_arrowhead(self, direction: str, x: int, y: int, end: int) -> str:
"""Render individual arrow head.
direction (str): Arrow direction, 'left' or 'right'.
@ -196,24 +220,12 @@ class DependencyRenderer:
RETURNS (str): Definition of the arrow head path ('d' attribute).
"""
if direction == "left":
pos1, pos2, pos3 = (x, x - self.arrow_width + 2, x + self.arrow_width - 2)
p1, p2, p3 = (x, x - self.arrow_width + 2, x + self.arrow_width - 2)
else:
pos1, pos2, pos3 = (
end,
end + self.arrow_width - 2,
end - self.arrow_width + 2,
)
arrowhead = (
pos1,
y + 2,
pos2,
y - self.arrow_width,
pos3,
y - self.arrow_width,
)
return "M{},{} L{},{} {},{}".format(*arrowhead)
p1, p2, p3 = (end, end + self.arrow_width - 2, end - self.arrow_width + 2)
return f"M{p1},{y + 2} L{p2},{y - self.arrow_width} {p3},{y - self.arrow_width}"
def get_levels(self, arcs):
def get_levels(self, arcs: List[Dict[str, Any]]) -> List[int]:
"""Calculate available arc height "levels".
Used to calculate arrow heights dynamically and without wasting space.
@ -229,41 +241,21 @@ class EntityRenderer:
style = "ent"
def __init__(self, options={}):
def __init__(self, options: Dict[str, Any] = {}) -> None:
"""Initialise dependency renderer.
options (dict): Visualiser-specific options (colors, ents)
"""
colors = {
"ORG": "#7aecec",
"PRODUCT": "#bfeeb7",
"GPE": "#feca74",
"LOC": "#ff9561",
"PERSON": "#aa9cfc",
"NORP": "#c887fb",
"FACILITY": "#9cc9cc",
"EVENT": "#ffeb80",
"LAW": "#ff8197",
"LANGUAGE": "#ff8197",
"WORK_OF_ART": "#f0d0ff",
"DATE": "#bfe1d9",
"TIME": "#bfe1d9",
"MONEY": "#e4e7d2",
"QUANTITY": "#e4e7d2",
"ORDINAL": "#e4e7d2",
"CARDINAL": "#e4e7d2",
"PERCENT": "#e4e7d2",
}
colors = dict(DEFAULT_LABEL_COLORS)
user_colors = registry.displacy_colors.get_all()
for user_color in user_colors.values():
colors.update(user_color)
colors.update(options.get("colors", {}))
self.default_color = "#ddd"
self.default_color = DEFAULT_ENTITY_COLOR
self.colors = colors
self.ents = options.get("ents", None)
self.direction = DEFAULT_DIR
self.lang = DEFAULT_LANG
template = options.get("template")
if template:
self.ent_template = template
@ -273,7 +265,9 @@ class EntityRenderer:
else:
self.ent_template = TPL_ENT
def render(self, parsed, page=False, minify=False):
def render(
self, parsed: List[Dict[str, Any]], page: bool = False, minify: bool = False
) -> str:
"""Render complete markup.
parsed (list): Dependency parses to render.
@ -297,7 +291,9 @@ class EntityRenderer:
return minify_html(markup)
return markup
def render_ents(self, text, spans, title):
def render_ents(
self, text: str, spans: List[Dict[str, Any]], title: Optional[str]
) -> str:
"""Render entities in text.
text (str): Original text.
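
The hoisted DEFAULT_LABEL_COLORS remain overridable per call; a small sketch (invented text and span) of passing custom colors through the options dict that EntityRenderer reads:

from spacy import displacy

ents_input = [{
    "text": "Acme Corp was founded in 2020.",
    "ents": [{"start": 0, "end": 9, "label": "ORG"}],
    "title": None,
}]
options = {"colors": {"ORG": "#ff6961"}}  # overrides the default "#7aecec" for ORG
html = displacy.render(ents_input, style="ent", manual=True, options=options)
print("#ff6961" in html)  # True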

View File

@ -483,6 +483,8 @@ class Errors:
E199 = ("Unable to merge 0-length span at doc[{start}:{end}].")
# TODO: fix numbering after merging develop into master
E953 = ("Mismatched IDs received by the Tok2Vec listener: {id1} vs. {id2}")
E954 = ("The Tok2Vec listener did not receive a valid input.")
E955 = ("Can't find table '{table}' for language '{lang}' in spacy-lookups-data.")
E956 = ("Can't find component '{name}' in [components] block in the config. "
"Available components: {opts}")

View File

@ -163,7 +163,7 @@ class Language:
else:
if (self.lang and vocab.lang) and (self.lang != vocab.lang):
raise ValueError(Errors.E150.format(nlp=self.lang, vocab=vocab.lang))
self.vocab = vocab
self.vocab: Vocab = vocab
if self.lang is None:
self.lang = self.vocab.lang
self.pipeline = []

View File

@ -8,7 +8,7 @@ def PrecomputableAffine(nO, nI, nF, nP, dropout=0.1):
init=init,
dims={"nO": nO, "nI": nI, "nF": nF, "nP": nP},
params={"W": None, "b": None, "pad": None},
attrs={"dropout_rate": dropout}
attrs={"dropout_rate": dropout},
)
return model

View File

@ -154,16 +154,30 @@ def LayerNormalizedMaxout(width, maxout_pieces):
def MultiHashEmbed(
columns, width, rows, use_subwords, pretrained_vectors, mix, dropout
):
norm = HashEmbed(nO=width, nV=rows, column=columns.index("NORM"), dropout=dropout, seed=6)
norm = HashEmbed(
nO=width, nV=rows, column=columns.index("NORM"), dropout=dropout, seed=6
)
if use_subwords:
prefix = HashEmbed(
nO=width, nV=rows // 2, column=columns.index("PREFIX"), dropout=dropout, seed=7
nO=width,
nV=rows // 2,
column=columns.index("PREFIX"),
dropout=dropout,
seed=7,
)
suffix = HashEmbed(
nO=width, nV=rows // 2, column=columns.index("SUFFIX"), dropout=dropout, seed=8
nO=width,
nV=rows // 2,
column=columns.index("SUFFIX"),
dropout=dropout,
seed=8,
)
shape = HashEmbed(
nO=width, nV=rows // 2, column=columns.index("SHAPE"), dropout=dropout, seed=9
nO=width,
nV=rows // 2,
column=columns.index("SHAPE"),
dropout=dropout,
seed=9,
)
if pretrained_vectors:
@ -192,7 +206,9 @@ def MultiHashEmbed(
@registry.architectures.register("spacy.CharacterEmbed.v1")
def CharacterEmbed(columns, width, rows, nM, nC, features, dropout):
norm = HashEmbed(nO=width, nV=rows, column=columns.index("NORM"), dropout=dropout, seed=5)
norm = HashEmbed(
nO=width, nV=rows, column=columns.index("NORM"), dropout=dropout, seed=5
)
chr_embed = _character_embed.CharacterEmbed(nM=nM, nC=nC)
with Model.define_operators({">>": chain, "|": concatenate}):
embed_layer = chr_embed | features >> with_array(norm)
@ -263,21 +279,29 @@ def build_Tok2Vec_model(
cols = [ID, NORM, PREFIX, SUFFIX, SHAPE, ORTH]
with Model.define_operators({">>": chain, "|": concatenate, "**": clone}):
norm = HashEmbed(
nO=width, nV=embed_size, column=cols.index(NORM), dropout=None,
seed=0
nO=width, nV=embed_size, column=cols.index(NORM), dropout=None, seed=0
)
if subword_features:
prefix = HashEmbed(
nO=width, nV=embed_size // 2, column=cols.index(PREFIX), dropout=None,
seed=1
nO=width,
nV=embed_size // 2,
column=cols.index(PREFIX),
dropout=None,
seed=1,
)
suffix = HashEmbed(
nO=width, nV=embed_size // 2, column=cols.index(SUFFIX), dropout=None,
seed=2
nO=width,
nV=embed_size // 2,
column=cols.index(SUFFIX),
dropout=None,
seed=2,
)
shape = HashEmbed(
nO=width, nV=embed_size // 2, column=cols.index(SHAPE), dropout=None,
seed=3
nO=width,
nV=embed_size // 2,
column=cols.index(SHAPE),
dropout=None,
seed=3,
)
else:
prefix, suffix, shape = (None, None, None)
@ -294,11 +318,7 @@ def build_Tok2Vec_model(
embed = uniqued(
(glove | norm | prefix | suffix | shape)
>> Maxout(
nO=width,
nI=width * columns,
nP=3,
dropout=0.0,
normalize=True,
nO=width, nI=width * columns, nP=3, dropout=0.0, normalize=True,
),
column=cols.index(ORTH),
)
@ -307,11 +327,7 @@ def build_Tok2Vec_model(
embed = uniqued(
(glove | norm)
>> Maxout(
nO=width,
nI=width * columns,
nP=3,
dropout=0.0,
normalize=True,
nO=width, nI=width * columns, nP=3, dropout=0.0, normalize=True,
),
column=cols.index(ORTH),
)
@ -320,11 +336,7 @@ def build_Tok2Vec_model(
embed = uniqued(
concatenate(norm, prefix, suffix, shape)
>> Maxout(
nO=width,
nI=width * columns,
nP=3,
dropout=0.0,
normalize=True,
nO=width, nI=width * columns, nP=3, dropout=0.0, normalize=True,
),
column=cols.index(ORTH),
)
@ -333,11 +345,7 @@ def build_Tok2Vec_model(
cols
) >> with_array(norm)
reduce_dimensions = Maxout(
nO=width,
nI=nM * nC + width,
nP=3,
dropout=0.0,
normalize=True,
nO=width, nI=nM * nC + width, nP=3, dropout=0.0, normalize=True,
)
else:
embed = norm
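
All of the reflowed calls above build thinc.api.HashEmbed layers with explicit keyword arguments; here is a standalone, toy-sized sketch of one such embedding on its own, independent of the Tok2Vec wiring (sizes and IDs are arbitrary).

import numpy
from thinc.api import HashEmbed

# Toy table: 64-dim rows, 500 hashed buckets, reading column 0 of the ID array
embed = HashEmbed(nO=64, nV=500, column=0, dropout=None, seed=0)
ids = numpy.asarray([[10], [42], [10]], dtype="uint64")
embed.initialize(X=ids)
vectors = embed.predict(ids)
print(vectors.shape)  # (3, 64)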

View File

@ -10,7 +10,6 @@ from .simple_ner import SimpleNER
from .tagger import Tagger
from .textcat import TextCategorizer
from .tok2vec import Tok2Vec
from .hooks import SentenceSegmenter, SimilarityHook
from .functions import merge_entities, merge_noun_chunks, merge_subtokens
__all__ = [
@ -21,9 +20,7 @@ __all__ = [
"Morphologizer",
"Pipe",
"SentenceRecognizer",
"SentenceSegmenter",
"Sentencizer",
"SimilarityHook",
"SimpleNER",
"Tagger",
"TextCategorizer",

View File

@ -63,7 +63,7 @@ class EntityRuler:
overwrite_ents: bool = False,
ent_id_sep: str = DEFAULT_ENT_ID_SEP,
patterns: Optional[List[PatternType]] = None,
):
) -> None:
"""Initialize the entitiy ruler. If patterns are supplied here, they
need to be a list of dictionaries with a `"label"` and `"pattern"`
key. A pattern can either be a token pattern (list) or a phrase pattern
@ -239,7 +239,6 @@ class EntityRuler:
phrase_pattern_labels = []
phrase_pattern_texts = []
phrase_pattern_ids = []
for entry in patterns:
if isinstance(entry["pattern"], str):
phrase_pattern_labels.append(entry["label"])
@ -247,7 +246,6 @@ class EntityRuler:
phrase_pattern_ids.append(entry.get("id"))
elif isinstance(entry["pattern"], list):
token_patterns.append(entry)
phrase_patterns = []
for label, pattern, ent_id in zip(
phrase_pattern_labels,
@ -258,7 +256,6 @@ class EntityRuler:
if ent_id:
phrase_pattern["id"] = ent_id
phrase_patterns.append(phrase_pattern)
for entry in token_patterns + phrase_patterns:
label = entry["label"]
if "id" in entry:
@ -266,7 +263,6 @@ class EntityRuler:
label = self._create_label(label, entry["id"])
key = self.matcher._normalize_key(label)
self._ent_ids[key] = (ent_label, entry["id"])
pattern = entry["pattern"]
if isinstance(pattern, Doc):
self.phrase_patterns[label].append(pattern)
@ -323,13 +319,13 @@ class EntityRuler:
self.clear()
if isinstance(cfg, dict):
self.add_patterns(cfg.get("patterns", cfg))
self.overwrite = cfg.get("overwrite")
self.overwrite = cfg.get("overwrite", False)
self.phrase_matcher_attr = cfg.get("phrase_matcher_attr", None)
if self.phrase_matcher_attr is not None:
self.phrase_matcher = PhraseMatcher(
self.nlp.vocab, attr=self.phrase_matcher_attr
)
self.ent_id_sep = cfg.get("ent_id_sep")
self.ent_id_sep = cfg.get("ent_id_sep", DEFAULT_ENT_ID_SEP)
else:
self.add_patterns(cfg)
return self
@ -375,9 +371,9 @@ class EntityRuler:
}
deserializers_cfg = {"cfg": lambda p: cfg.update(srsly.read_json(p))}
from_disk(path, deserializers_cfg, {})
self.overwrite = cfg.get("overwrite")
self.overwrite = cfg.get("overwrite", False)
self.phrase_matcher_attr = cfg.get("phrase_matcher_attr")
self.ent_id_sep = cfg.get("ent_id_sep")
self.ent_id_sep = cfg.get("ent_id_sep", DEFAULT_ENT_ID_SEP)
if self.phrase_matcher_attr is not None:
self.phrase_matcher = PhraseMatcher(
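
The change from cfg.get("overwrite") to cfg.get("overwrite", False) matters when a serialized config predates the key; a plain-dict illustration of the difference (the "||" separator value is only a stand-in here):

cfg = {}  # e.g. a config saved before these keys existed

print(cfg.get("overwrite"))          # None, which is falsy but not a real bool
print(cfg.get("overwrite", False))   # False, an explicit default
print(cfg.get("ent_id_sep", "||"))   # stand-in default separator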

View File

@ -1,95 +0,0 @@
from thinc.api import concatenate, reduce_max, reduce_mean, siamese, CauchySimilarity
from .pipe import Pipe
from ..util import link_vectors_to_models
# TODO: do we want to keep these?
class SentenceSegmenter:
"""A simple spaCy hook, to allow custom sentence boundary detection logic
(that doesn't require the dependency parse). To change the sentence
boundary detection strategy, pass a generator function `strategy` on
initialization, or assign a new strategy to the .strategy attribute.
Sentence detection strategies should be generators that take `Doc` objects
and yield `Span` objects for each sentence.
"""
def __init__(self, vocab, strategy=None):
self.vocab = vocab
if strategy is None or strategy == "on_punct":
strategy = self.split_on_punct
self.strategy = strategy
def __call__(self, doc):
doc.user_hooks["sents"] = self.strategy
return doc
@staticmethod
def split_on_punct(doc):
start = 0
seen_period = False
for i, token in enumerate(doc):
if seen_period and not token.is_punct:
yield doc[start : token.i]
start = token.i
seen_period = False
elif token.text in [".", "!", "?"]:
seen_period = True
if start < len(doc):
yield doc[start : len(doc)]
class SimilarityHook(Pipe):
"""
Experimental: A pipeline component to install a hook for supervised
similarity into `Doc` objects.
The similarity model can be any object obeying the Thinc `Model`
interface. By default, the model concatenates the elementwise mean and
elementwise max of the two tensors, and compares them using the
Cauchy-like similarity function from Chen (2013):
>>> similarity = 1. / (1. + (W * (vec1-vec2)**2).sum())
Where W is a vector of dimension weights, initialized to 1.
"""
def __init__(self, vocab, model=True, **cfg):
self.vocab = vocab
self.model = model
self.cfg = dict(cfg)
@classmethod
def Model(cls, length):
return siamese(
concatenate(reduce_max(), reduce_mean()), CauchySimilarity(length * 2)
)
def __call__(self, doc):
"""Install similarity hook"""
doc.user_hooks["similarity"] = self.predict
return doc
def pipe(self, docs, **kwargs):
for doc in docs:
yield self(doc)
def predict(self, doc1, doc2):
return self.model.predict([(doc1, doc2)])
def update(self, doc1_doc2, golds, sgd=None, drop=0.0):
sims, bp_sims = self.model.begin_update(doc1_doc2)
def begin_training(self, _=tuple(), pipeline=None, sgd=None, **kwargs):
"""Allocate model, using nO from the first model in the pipeline.
gold_tuples (iterable): Gold-standard training data.
pipeline (list): The pipeline the model is part of.
"""
if self.model is True:
self.model = self.Model(pipeline[0].model.get_dim("nO"))
link_vectors_to_models(self.vocab)
if sgd is None:
sgd = self.create_optimizer()
return sgd

View File

@ -1,4 +1,4 @@
from typing import Iterable, Tuple, Optional, Dict, List, Callable
from typing import Iterable, Tuple, Optional, Dict, List, Callable, Iterator, Any
from thinc.api import get_array_module, Model, Optimizer, set_dropout_rate, Config
import numpy
@ -97,7 +97,7 @@ class TextCategorizer(Pipe):
def labels(self, value: Iterable[str]) -> None:
self.cfg["labels"] = tuple(value)
def pipe(self, stream, batch_size=128):
def pipe(self, stream: Iterator[str], batch_size: int = 128) -> Iterator[Doc]:
for docs in util.minibatch(stream, size=batch_size):
scores = self.predict(docs)
self.set_annotations(docs, scores)
@ -252,8 +252,17 @@ class TextCategorizer(Pipe):
sgd = self.create_optimizer()
return sgd
def score(self, examples, positive_label=None, **kwargs):
return Scorer.score_cats(examples, "cats", labels=self.labels,
def score(
self,
examples: Iterable[Example],
positive_label: Optional[str] = None,
**kwargs,
) -> Dict[str, Any]:
return Scorer.score_cats(
examples,
"cats",
labels=self.labels,
multi_label=self.model.attrs["multi_label"],
positive_label=positive_label, **kwargs
positive_label=positive_label,
**kwargs,
)

View File

@ -6,6 +6,7 @@ from ..gold import Example
from ..tokens import Doc
from ..vocab import Vocab
from ..language import Language
from ..errors import Errors
from ..util import link_vectors_to_models, minibatch
@ -150,7 +151,7 @@ class Tok2Vec(Pipe):
self.set_annotations(docs, tokvecs)
return losses
def get_loss(self, examples, scores):
def get_loss(self, examples, scores) -> None:
pass
def begin_training(
@ -184,26 +185,26 @@ class Tok2VecListener(Model):
self._backprop = None
@classmethod
def get_batch_id(cls, inputs):
def get_batch_id(cls, inputs) -> int:
return sum(sum(token.orth for token in doc) for doc in inputs)
def receive(self, batch_id, outputs, backprop):
def receive(self, batch_id: int, outputs, backprop) -> None:
self._batch_id = batch_id
self._outputs = outputs
self._backprop = backprop
def verify_inputs(self, inputs):
def verify_inputs(self, inputs) -> bool:
if self._batch_id is None and self._outputs is None:
raise ValueError("The Tok2Vec listener did not receive valid input.")
raise ValueError(Errors.E954)
else:
batch_id = self.get_batch_id(inputs)
if batch_id != self._batch_id:
raise ValueError(f"Mismatched IDs! {batch_id} vs {self._batch_id}")
raise ValueError(Errors.E953.format(id1=batch_id, id2=self._batch_id))
else:
return True
def forward(model: Tok2VecListener, inputs, is_train):
def forward(model: Tok2VecListener, inputs, is_train: bool):
if is_train:
model.verify_inputs(inputs)
return model._outputs, model._backprop

View File

@ -1,4 +1,4 @@
from typing import Dict, List, Union, Optional, Sequence, Any, Callable
from typing import Dict, List, Union, Optional, Sequence, Any, Callable, Type
from enum import Enum
from pydantic import BaseModel, Field, ValidationError, validator
from pydantic import StrictStr, StrictInt, StrictFloat, StrictBool
@ -9,12 +9,12 @@ from thinc.api import Optimizer
from .attrs import NAMES
def validate(schema, obj):
def validate(schema: Type[BaseModel], obj: Dict[str, Any]) -> List[str]:
"""Validate data against a given pydantic schema.
obj (dict): JSON-serializable data to validate.
obj (Dict[str, Any]): JSON-serializable data to validate.
schema (pydantic.BaseModel): The schema to validate against.
RETURNS (list): A list of error messages, if available.
RETURNS (List[str]): A list of error messages, if available.
"""
try:
schema(**obj)
@ -31,7 +31,7 @@ def validate(schema, obj):
# Matcher token patterns
def validate_token_pattern(obj):
def validate_token_pattern(obj: list) -> List[str]:
# Try to convert non-string keys (e.g. {ORTH: "foo"} -> {"ORTH": "foo"})
get_key = lambda k: NAMES[k] if isinstance(k, int) and k < len(NAMES) else k
if isinstance(obj, list):
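
A hedged sketch of the pattern behind validate(schema, obj): instantiate the pydantic model and turn a ValidationError into readable messages. The TrainingConfig schema here is invented, not one of spaCy's.

from typing import Any, Dict, List
from pydantic import BaseModel, StrictInt, StrictStr, ValidationError

class TrainingConfig(BaseModel):  # invented schema for illustration
    name: StrictStr
    max_epochs: StrictInt

def collect_errors(obj: Dict[str, Any]) -> List[str]:
    try:
        TrainingConfig(**obj)
        return []
    except ValidationError as e:
        return [f"{err['loc']}: {err['msg']}" for err in e.errors()]

print(collect_errors({"name": "demo", "max_epochs": "ten"}))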

View File

@ -15,7 +15,8 @@ def test_doc_add_entities_set_ents_iob(en_vocab):
"min_action_freq": 30,
"update_with_oracle_cut_size": 100,
}
model = registry.make_from_config({"model": DEFAULT_NER_MODEL}, validate=True)["model"]
cfg = {"model": DEFAULT_NER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"]
ner = EntityRecognizer(en_vocab, model, **config)
ner.begin_training([])
ner(doc)
@ -37,7 +38,8 @@ def test_ents_reset(en_vocab):
"min_action_freq": 30,
"update_with_oracle_cut_size": 100,
}
model = registry.make_from_config({"model": DEFAULT_NER_MODEL}, validate=True)["model"]
cfg = {"model": DEFAULT_NER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"]
ner = EntityRecognizer(en_vocab, model, **config)
ner.begin_training([])
ner(doc)

View File

@ -14,7 +14,9 @@ def test_en_sbd_single_punct(en_tokenizer, text, punct):
assert sum(len(sent) for sent in doc.sents) == len(doc)
@pytest.mark.skip(reason="The step_through API was removed (but should be brought back)")
@pytest.mark.skip(
reason="The step_through API was removed (but should be brought back)"
)
def test_en_sentence_breaks(en_tokenizer, en_parser):
# fmt: off
text = "This is a sentence . This is another one ."

View File

@ -32,6 +32,25 @@ SENTENCE_TESTS = [
("あれ。これ。", ["あれ。", "これ。"]),
("「伝染るんです。」という漫画があります。", ["「伝染るんです。」という漫画があります。"]),
]
tokens1 = [
DetailedToken(surface="委員", tag="名詞-普通名詞-一般", inf="", lemma="委員", reading="イイン", sub_tokens=None),
DetailedToken(surface="", tag="名詞-普通名詞-一般", inf="", lemma="", reading="カイ", sub_tokens=None),
]
tokens2 = [
DetailedToken(surface="選挙", tag="名詞-普通名詞-サ変可能", inf="", lemma="選挙", reading="センキョ", sub_tokens=None),
DetailedToken(surface="管理", tag="名詞-普通名詞-サ変可能", inf="", lemma="管理", reading="カンリ", sub_tokens=None),
DetailedToken(surface="委員", tag="名詞-普通名詞-一般", inf="", lemma="委員", reading="イイン", sub_tokens=None),
DetailedToken(surface="", tag="名詞-普通名詞-一般", inf="", lemma="", reading="カイ", sub_tokens=None),
]
tokens3 = [
DetailedToken(surface="選挙", tag="名詞-普通名詞-サ変可能", inf="", lemma="選挙", reading="センキョ", sub_tokens=None),
DetailedToken(surface="管理", tag="名詞-普通名詞-サ変可能", inf="", lemma="管理", reading="カンリ", sub_tokens=None),
DetailedToken(surface="委員会", tag="名詞-普通名詞-一般", inf="", lemma="委員会", reading="イインカイ", sub_tokens=None),
]
SUB_TOKEN_TESTS = [
("選挙管理委員会", [None, None, None, None], [None, None, [tokens1]], [[tokens2, tokens3]])
]
# fmt: on
@ -92,33 +111,12 @@ def test_ja_tokenizer_split_modes(ja_tokenizer, text, len_a, len_b, len_c):
assert len(nlp_c(text)) == len_c
@pytest.mark.parametrize("text,sub_tokens_list_a,sub_tokens_list_b,sub_tokens_list_c",
[
(
"選挙管理委員会",
[None, None, None, None],
[None, None, [
[
DetailedToken(surface='委員', tag='名詞-普通名詞-一般', inf='', lemma='委員', reading='イイン', sub_tokens=None),
DetailedToken(surface='会', tag='名詞-普通名詞-一般', inf='', lemma='会', reading='カイ', sub_tokens=None),
]
]],
[[
[
DetailedToken(surface='選挙', tag='名詞-普通名詞-サ変可能', inf='', lemma='選挙', reading='センキョ', sub_tokens=None),
DetailedToken(surface='管理', tag='名詞-普通名詞-サ変可能', inf='', lemma='管理', reading='カンリ', sub_tokens=None),
DetailedToken(surface='委員', tag='名詞-普通名詞-一般', inf='', lemma='委員', reading='イイン', sub_tokens=None),
DetailedToken(surface='会', tag='名詞-普通名詞-一般', inf='', lemma='会', reading='カイ', sub_tokens=None),
], [
DetailedToken(surface='選挙', tag='名詞-普通名詞-サ変可能', inf='', lemma='選挙', reading='センキョ', sub_tokens=None),
DetailedToken(surface='管理', tag='名詞-普通名詞-サ変可能', inf='', lemma='管理', reading='カンリ', sub_tokens=None),
DetailedToken(surface='委員会', tag='名詞-普通名詞-一般', inf='', lemma='委員会', reading='イインカイ', sub_tokens=None),
]
]]
),
]
@pytest.mark.parametrize(
"text,sub_tokens_list_a,sub_tokens_list_b,sub_tokens_list_c", SUB_TOKEN_TESTS,
)
def test_ja_tokenizer_sub_tokens(ja_tokenizer, text, sub_tokens_list_a, sub_tokens_list_b, sub_tokens_list_c):
def test_ja_tokenizer_sub_tokens(
ja_tokenizer, text, sub_tokens_list_a, sub_tokens_list_b, sub_tokens_list_c
):
nlp_a = Japanese.from_config({"nlp": {"tokenizer": {"split_mode": "A"}}})
nlp_b = Japanese.from_config({"nlp": {"tokenizer": {"split_mode": "B"}}})
nlp_c = Japanese.from_config({"nlp": {"tokenizer": {"split_mode": "C"}}})
@ -129,16 +127,19 @@ def test_ja_tokenizer_sub_tokens(ja_tokenizer, text, sub_tokens_list_a, sub_toke
assert nlp_c(text).user_data["sub_tokens"] == sub_tokens_list_c
@pytest.mark.parametrize("text,inflections,reading_forms",
@pytest.mark.parametrize(
"text,inflections,reading_forms",
[
(
"取ってつけた",
("五段-ラ行,連用形-促音便", "", "下一段-カ行,連用形-一般", "助動詞-タ,終止形-一般"),
("トッ", "", "ツケ", ""),
),
]
],
)
def test_ja_tokenizer_inflections_reading_forms(ja_tokenizer, text, inflections, reading_forms):
def test_ja_tokenizer_inflections_reading_forms(
ja_tokenizer, text, inflections, reading_forms
):
assert ja_tokenizer(text).user_data["inflections"] == inflections
assert ja_tokenizer(text).user_data["reading_forms"] == reading_forms

View File

@ -11,8 +11,7 @@ def test_ne_tokenizer_handlers_long_text(ne_tokenizer):
@pytest.mark.parametrize(
"text,length",
[("समय जान कति पनि बेर लाग्दैन ।", 7), ("म ठूलो हुँदै थिएँ ।", 5)],
"text,length", [("समय जान कति पनि बेर लाग्दैन ।", 7), ("म ठूलो हुँदै थिएँ ।", 5)],
)
def test_ne_tokenizer_handles_cnts(ne_tokenizer, text, length):
tokens = ne_tokenizer(text)

View File

@ -30,10 +30,7 @@ def test_zh_tokenizer_serialize_pkuseg_with_processors(zh_tokenizer_pkuseg):
nlp = Chinese(
meta={
"tokenizer": {
"config": {
"segmenter": "pkuseg",
"pkuseg_model": "medicine",
}
"config": {"segmenter": "pkuseg", "pkuseg_model": "medicine",}
}
}
)

View File

@ -22,7 +22,8 @@ def parser(vocab):
"min_action_freq": 30,
"update_with_oracle_cut_size": 100,
}
model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"]
cfg = {"model": DEFAULT_PARSER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"]
parser = DependencyParser(vocab, model, **config)
return parser
@ -68,7 +69,8 @@ def test_add_label_deserializes_correctly():
"min_action_freq": 30,
"update_with_oracle_cut_size": 100,
}
model = registry.make_from_config({"model": DEFAULT_NER_MODEL}, validate=True)["model"]
cfg = {"model": DEFAULT_NER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"]
ner1 = EntityRecognizer(Vocab(), model, **config)
ner1.add_label("C")
ner1.add_label("B")
@ -86,7 +88,10 @@ def test_add_label_deserializes_correctly():
@pytest.mark.parametrize(
"pipe_cls,n_moves,model_config",
[(DependencyParser, 5, DEFAULT_PARSER_MODEL), (EntityRecognizer, 4, DEFAULT_NER_MODEL)],
[
(DependencyParser, 5, DEFAULT_PARSER_MODEL),
(EntityRecognizer, 4, DEFAULT_NER_MODEL),
],
)
def test_add_label_get_label(pipe_cls, n_moves, model_config):
"""Test that added labels are returned correctly. This test was added to

View File

@ -126,7 +126,8 @@ def test_get_oracle_actions():
"min_action_freq": 0,
"update_with_oracle_cut_size": 100,
}
model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"]
cfg = {"model": DEFAULT_PARSER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"]
parser = DependencyParser(doc.vocab, model, **config)
parser.moves.add_action(0, "")
parser.moves.add_action(1, "")

View File

@ -24,7 +24,8 @@ def arc_eager(vocab):
@pytest.fixture
def tok2vec():
tok2vec = registry.make_from_config({"model": DEFAULT_TOK2VEC_MODEL}, validate=True)["model"]
cfg = {"model": DEFAULT_TOK2VEC_MODEL}
tok2vec = registry.make_from_config(cfg, validate=True)["model"]
tok2vec.initialize()
return tok2vec
@ -36,13 +37,15 @@ def parser(vocab, arc_eager):
"min_action_freq": 30,
"update_with_oracle_cut_size": 100,
}
model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"]
cfg = {"model": DEFAULT_PARSER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"]
return Parser(vocab, model, moves=arc_eager, **config)
@pytest.fixture
def model(arc_eager, tok2vec, vocab):
model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"]
cfg = {"model": DEFAULT_PARSER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"]
model.attrs["resize_output"](model, arc_eager.n_moves)
model.initialize()
return model
@ -68,7 +71,8 @@ def test_build_model(parser, vocab):
"min_action_freq": 0,
"update_with_oracle_cut_size": 100,
}
model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"]
cfg = {"model": DEFAULT_PARSER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"]
parser.model = Parser(vocab, model=model, moves=parser.moves, **config).model
assert parser.model is not None

View File

@ -33,7 +33,9 @@ def test_parser_root(en_tokenizer):
assert t.dep != 0, t.text
@pytest.mark.skip(reason="The step_through API was removed (but should be brought back)")
@pytest.mark.skip(
reason="The step_through API was removed (but should be brought back)"
)
@pytest.mark.parametrize("text", ["Hello"])
def test_parser_parse_one_word_sentence(en_tokenizer, en_parser, text):
tokens = en_tokenizer(text)
@ -47,7 +49,9 @@ def test_parser_parse_one_word_sentence(en_tokenizer, en_parser, text):
assert doc[0].dep != 0
@pytest.mark.skip(reason="The step_through API was removed (but should be brought back)")
@pytest.mark.skip(
reason="The step_through API was removed (but should be brought back)"
)
def test_parser_initial(en_tokenizer, en_parser):
text = "I ate the pizza with anchovies."
# heads = [1, 0, 1, -2, -3, -1, -5]
@ -92,7 +96,9 @@ def test_parser_merge_pp(en_tokenizer):
assert doc[3].text == "occurs"
@pytest.mark.skip(reason="The step_through API was removed (but should be brought back)")
@pytest.mark.skip(
reason="The step_through API was removed (but should be brought back)"
)
def test_parser_arc_eager_finalize_state(en_tokenizer, en_parser):
text = "a b c d e"

View File

@ -21,7 +21,8 @@ def parser(vocab):
"min_action_freq": 30,
"update_with_oracle_cut_size": 100,
}
model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"]
cfg = {"model": DEFAULT_PARSER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"]
parser = DependencyParser(vocab, model, **config)
parser.cfg["token_vector_width"] = 4
parser.cfg["hidden_width"] = 32

View File

@ -28,7 +28,9 @@ def test_parser_sentence_space(en_tokenizer):
assert len(list(doc.sents)) == 2
@pytest.mark.skip(reason="The step_through API was removed (but should be brought back)")
@pytest.mark.skip(
reason="The step_through API was removed (but should be brought back)"
)
def test_parser_space_attachment_leading(en_tokenizer, en_parser):
text = "\t \n This is a sentence ."
heads = [1, 1, 0, 1, -2, -3]
@ -44,7 +46,9 @@ def test_parser_space_attachment_leading(en_tokenizer, en_parser):
assert stepwise.stack == set([2])
@pytest.mark.skip(reason="The step_through API was removed (but should be brought back)")
@pytest.mark.skip(
reason="The step_through API was removed (but should be brought back)"
)
def test_parser_space_attachment_intermediate_trailing(en_tokenizer, en_parser):
text = "This is \t a \t\n \n sentence . \n\n \n"
heads = [1, 0, -1, 2, -1, -4, -5, -1]
@ -64,7 +68,9 @@ def test_parser_space_attachment_intermediate_trailing(en_tokenizer, en_parser):
@pytest.mark.parametrize("text,length", [(["\n"], 1), (["\n", "\t", "\n\n", "\t"], 4)])
@pytest.mark.skip(reason="The step_through API was removed (but should be brought back)")
@pytest.mark.skip(
reason="The step_through API was removed (but should be brought back)"
)
def test_parser_space_attachment_space(en_tokenizer, en_parser, text, length):
doc = Doc(en_parser.vocab, words=text)
assert len(doc) == length

View File

@ -117,7 +117,9 @@ def test_overfitting_IO():
assert cats2["POSITIVE"] + cats2["NEGATIVE"] == pytest.approx(1.0, 0.1)
# Test scoring
scores = nlp.evaluate(train_examples, component_cfg={"scorer": {"positive_label": "POSITIVE"}})
scores = nlp.evaluate(
train_examples, component_cfg={"scorer": {"positive_label": "POSITIVE"}}
)
assert scores["cats_f"] == 1.0

View File

@ -201,7 +201,8 @@ def test_issue3345():
"min_action_freq": 30,
"update_with_oracle_cut_size": 100,
}
model = registry.make_from_config({"model": DEFAULT_NER_MODEL}, validate=True)["model"]
cfg = {"model": DEFAULT_NER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"]
ner = EntityRecognizer(doc.vocab, model, **config)
# Add the OUT action. I wouldn't have thought this would be necessary...
ner.moves.add_action(5, "")

View File

@ -20,7 +20,8 @@ def parser(en_vocab):
"min_action_freq": 30,
"update_with_oracle_cut_size": 100,
}
model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"]
cfg = {"model": DEFAULT_PARSER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"]
parser = DependencyParser(en_vocab, model, **config)
parser.add_label("nsubj")
return parser
@ -33,14 +34,16 @@ def blank_parser(en_vocab):
"min_action_freq": 30,
"update_with_oracle_cut_size": 100,
}
model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"]
cfg = {"model": DEFAULT_PARSER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"]
parser = DependencyParser(en_vocab, model, **config)
return parser
@pytest.fixture
def taggers(en_vocab):
model = registry.make_from_config({"model": DEFAULT_TAGGER_MODEL}, validate=True)["model"]
cfg = {"model": DEFAULT_TAGGER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"]
tagger1 = Tagger(en_vocab, model, set_morphology=True)
tagger2 = Tagger(en_vocab, model, set_morphology=True)
return tagger1, tagger2
@ -53,7 +56,8 @@ def test_serialize_parser_roundtrip_bytes(en_vocab, Parser):
"min_action_freq": 0,
"update_with_oracle_cut_size": 100,
}
model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"]
cfg = {"model": DEFAULT_PARSER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"]
parser = Parser(en_vocab, model, **config)
new_parser = Parser(en_vocab, model, **config)
new_parser = new_parser.from_bytes(parser.to_bytes(exclude=["vocab"]))
@ -70,7 +74,8 @@ def test_serialize_parser_roundtrip_disk(en_vocab, Parser):
"min_action_freq": 0,
"update_with_oracle_cut_size": 100,
}
model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"]
cfg = {"model": DEFAULT_PARSER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"]
parser = Parser(en_vocab, model, **config)
with make_tempdir() as d:
file_path = d / "parser"
@ -88,7 +93,6 @@ def test_to_from_bytes(parser, blank_parser):
assert blank_parser.model is not True
assert blank_parser.moves.n_moves != parser.moves.n_moves
bytes_data = parser.to_bytes(exclude=["vocab"])
# the blank parser needs to be resized before we can call from_bytes
blank_parser.model.attrs["resize_output"](blank_parser.model, parser.moves.n_moves)
blank_parser.from_bytes(bytes_data)
@ -104,7 +108,8 @@ def test_serialize_tagger_roundtrip_bytes(en_vocab, taggers):
tagger1_b = tagger1.to_bytes()
tagger1 = tagger1.from_bytes(tagger1_b)
assert tagger1.to_bytes() == tagger1_b
model = registry.make_from_config({"model": DEFAULT_TAGGER_MODEL}, validate=True)["model"]
cfg = {"model": DEFAULT_TAGGER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"]
new_tagger1 = Tagger(en_vocab, model).from_bytes(tagger1_b)
new_tagger1_b = new_tagger1.to_bytes()
assert len(new_tagger1_b) == len(tagger1_b)
@ -118,7 +123,8 @@ def test_serialize_tagger_roundtrip_disk(en_vocab, taggers):
file_path2 = d / "tagger2"
tagger1.to_disk(file_path1)
tagger2.to_disk(file_path2)
model = registry.make_from_config({"model": DEFAULT_TAGGER_MODEL}, validate=True)["model"]
cfg = {"model": DEFAULT_TAGGER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"]
tagger1_d = Tagger(en_vocab, model, set_morphology=True).from_disk(file_path1)
tagger2_d = Tagger(en_vocab, model, set_morphology=True).from_disk(file_path2)
assert tagger1_d.to_bytes() == tagger2_d.to_bytes()
@ -126,21 +132,22 @@ def test_serialize_tagger_roundtrip_disk(en_vocab, taggers):
def test_serialize_textcat_empty(en_vocab):
# See issue #1105
model = registry.make_from_config({"model": DEFAULT_TEXTCAT_MODEL}, validate=True)["model"]
textcat = TextCategorizer(
en_vocab, model, labels=["ENTITY", "ACTION", "MODIFIER"]
)
cfg = {"model": DEFAULT_TEXTCAT_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"]
textcat = TextCategorizer(en_vocab, model, labels=["ENTITY", "ACTION", "MODIFIER"])
textcat.to_bytes(exclude=["vocab"])
@pytest.mark.parametrize("Parser", test_parsers)
def test_serialize_pipe_exclude(en_vocab, Parser):
model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"]
cfg = {"model": DEFAULT_PARSER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"]
config = {
"learn_tokens": False,
"min_action_freq": 0,
"update_with_oracle_cut_size": 100,
}
def get_new_parser():
new_parser = Parser(en_vocab, model, **config)
return new_parser
@ -160,7 +167,8 @@ def test_serialize_pipe_exclude(en_vocab, Parser):
def test_serialize_sentencerecognizer(en_vocab):
model = registry.make_from_config({"model": DEFAULT_SENTER_MODEL}, validate=True)["model"]
cfg = {"model": DEFAULT_SENTER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"]
sr = SentenceRecognizer(en_vocab, model)
sr_b = sr.to_bytes()
sr_d = SentenceRecognizer(en_vocab, model).from_bytes(sr_b)

View File

@ -466,7 +466,6 @@ def test_iob_to_biluo():
def test_roundtrip_docs_to_docbin(doc):
nlp = English()
text = doc.text
idx = [t.idx for t in doc]
tags = [t.tag_ for t in doc]

View File

@ -7,11 +7,11 @@ from spacy.vocab import Vocab
def test_Example_init_requires_doc_objects():
vocab = Vocab()
with pytest.raises(TypeError):
example = Example(None, None)
Example(None, None)
with pytest.raises(TypeError):
example = Example(Doc(vocab, words=["hi"]), None)
Example(Doc(vocab, words=["hi"]), None)
with pytest.raises(TypeError):
example = Example(None, Doc(vocab, words=["hi"]))
Example(None, Doc(vocab, words=["hi"]))
def test_Example_from_dict_basic():

View File

@ -105,7 +105,11 @@ def test_tokenization(sented_doc):
assert scores["token_acc"] == 1.0
nlp = English()
example.predicted = Doc(nlp.vocab, words=["One", "sentence.", "Two", "sentences.", "Three", "sentences."], spaces=[True, True, True, True, True, False])
example.predicted = Doc(
nlp.vocab,
words=["One", "sentence.", "Two", "sentences.", "Three", "sentences."],
spaces=[True, True, True, True, True, False],
)
example.predicted[1].is_sent_start = False
scores = scorer.score([example])
assert scores["token_acc"] == approx(0.66666666)

View File

@ -1,12 +1,13 @@
from typing import List, Union, Dict, Any, Optional, Iterable, Callable, Tuple
from typing import Iterator, TYPE_CHECKING
from typing import Iterator, Type, Pattern, Sequence, TYPE_CHECKING
from types import ModuleType
import os
import importlib
import importlib.util
import re
from pathlib import Path
import thinc
from thinc.api import NumpyOps, get_current_ops, Adam, Config
from thinc.api import NumpyOps, get_current_ops, Adam, Config, Optimizer
import functools
import itertools
import numpy.random
@ -49,6 +50,8 @@ from . import about
if TYPE_CHECKING:
# This lets us add type hints for mypy etc. without causing circular imports
from .language import Language # noqa: F401
from .tokens import Doc, Span # noqa: F401
from .vocab import Vocab # noqa: F401
_PRINT_ENV = False
@ -102,12 +105,12 @@ class SimpleFrozenDict(dict):
raise NotImplementedError(self.error)
def set_env_log(value):
def set_env_log(value: bool) -> None:
global _PRINT_ENV
_PRINT_ENV = value
def lang_class_is_loaded(lang):
def lang_class_is_loaded(lang: str) -> bool:
"""Check whether a Language class is already loaded. Language classes are
loaded lazily, to avoid expensive setup code associated with the language
data.
@ -118,7 +121,7 @@ def lang_class_is_loaded(lang):
return lang in registry.languages
def get_lang_class(lang):
def get_lang_class(lang: str) -> "Language":
"""Import and load a Language class.
lang (str): Two-letter language code, e.g. 'en'.
@ -136,7 +139,7 @@ def get_lang_class(lang):
return registry.languages.get(lang)
def set_lang_class(name, cls):
def set_lang_class(name: str, cls: Type["Language"]) -> None:
"""Set a custom Language class name that can be loaded via get_lang_class.
name (str): Name of Language class.
@ -145,10 +148,10 @@ def set_lang_class(name, cls):
registry.languages.register(name, func=cls)
def ensure_path(path):
def ensure_path(path: Any) -> Any:
"""Ensure string is converted to a Path.
path: Anything. If string, it's converted to Path.
path (Any): Anything. If string, it's converted to Path.
RETURNS: Path or original argument.
"""
if isinstance(path, str):
@ -157,7 +160,7 @@ def ensure_path(path):
return path
def load_language_data(path):
def load_language_data(path: Union[str, Path]) -> Union[dict, list]:
"""Load JSON language data using the given path as a base. If the provided
path isn't present, will attempt to load a gzipped version before giving up.
@ -173,7 +176,12 @@ def load_language_data(path):
raise ValueError(Errors.E160.format(path=path))
def get_module_path(module):
def get_module_path(module: ModuleType) -> Path:
"""Get the path of a Python module.
module (ModuleType): The Python module.
RETURNS (Path): The path.
"""
if not hasattr(module, "__module__"):
raise ValueError(Errors.E169.format(module=repr(module)))
return Path(sys.modules[module.__module__].__file__).parent
@ -183,7 +191,7 @@ def load_model(
name: Union[str, Path],
disable: Iterable[str] = tuple(),
component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(),
):
) -> "Language":
"""Load a model from a package or data path.
name (str): Package name or model path.
@ -209,7 +217,7 @@ def load_model_from_package(
name: str,
disable: Iterable[str] = tuple(),
component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(),
):
) -> "Language":
"""Load a model from an installed package."""
cls = importlib.import_module(name)
return cls.load(disable=disable, component_cfg=component_cfg)
@ -220,7 +228,7 @@ def load_model_from_path(
meta: Optional[Dict[str, Any]] = None,
disable: Iterable[str] = tuple(),
component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(),
):
) -> "Language":
"""Load a model from a data directory path. Creates Language class with
pipeline from config.cfg and then calls from_disk() with path."""
if not model_path.exists():
@ -269,7 +277,7 @@ def load_model_from_init_py(
init_file: Union[Path, str],
disable: Iterable[str] = tuple(),
component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(),
):
) -> "Language":
"""Helper function to use in the `load()` method of a model package's
__init__.py.
@ -288,15 +296,15 @@ def load_model_from_init_py(
)
def get_installed_models():
def get_installed_models() -> List[str]:
"""List all model packages currently installed in the environment.
RETURNS (list): The string names of the models.
RETURNS (List[str]): The string names of the models.
"""
return list(registry.models.get_all().keys())
def get_package_version(name):
def get_package_version(name: str) -> Optional[str]:
"""Get the version of an installed package. Typically used to get model
package versions.
@ -309,7 +317,9 @@ def get_package_version(name):
return None
def is_compatible_version(version, constraint, prereleases=True):
def is_compatible_version(
version: str, constraint: str, prereleases: bool = True
) -> Optional[bool]:
"""Check if a version (e.g. "2.0.0") is compatible given a version
constraint (e.g. ">=1.9.0,<2.2.1"). If the constraint is a specific version,
it's interpreted as =={version}.
@ -333,7 +343,9 @@ def is_compatible_version(version, constraint, prereleases=True):
return version in spec
def is_unconstrained_version(constraint, prereleases=True):
def is_unconstrained_version(
constraint: str, prereleases: bool = True
) -> Optional[bool]:
# We have an exact version, this is the ultimate constrained version
if constraint[0].isdigit():
return False
@ -358,7 +370,7 @@ def is_unconstrained_version(constraint, prereleases=True):
return True
def get_model_version_range(spacy_version):
def get_model_version_range(spacy_version: str) -> str:
"""Generate a version range like >=1.2.3,<1.3.0 based on a given spaCy
version. Models are always compatible across patch versions but not
across minor or major versions.
@ -367,7 +379,7 @@ def get_model_version_range(spacy_version):
return f">={spacy_version},<{release[0]}.{release[1] + 1}.0"
def get_base_version(version):
def get_base_version(version: str) -> str:
"""Generate the base version without any prerelease identifiers.
version (str): The version, e.g. "3.0.0.dev1".
@ -376,11 +388,11 @@ def get_base_version(version):
return Version(version).base_version
def get_model_meta(path):
def get_model_meta(path: Union[str, Path]) -> Dict[str, Any]:
"""Get model meta.json from a directory path and validate its contents.
path (str / Path): Path to model directory.
RETURNS (dict): The model's meta data.
RETURNS (Dict[str, Any]): The model's meta data.
"""
model_path = ensure_path(path)
if not model_path.exists():
@ -412,7 +424,7 @@ def get_model_meta(path):
return meta
def is_package(name):
def is_package(name: str) -> bool:
"""Check if string maps to a package installed via pip.
name (str): Name of package.
@ -425,7 +437,7 @@ def is_package(name):
return False
def get_package_path(name):
def get_package_path(name: str) -> Path:
"""Get the path to an installed package.
name (str): Package name.
@ -495,7 +507,7 @@ def working_dir(path: Union[str, Path]) -> None:
@contextmanager
def make_tempdir():
def make_tempdir() -> None:
"""Execute a block in a temporary directory and remove the directory and
its contents at the end of the with block.
@ -518,7 +530,7 @@ def is_cwd(path: Union[Path, str]) -> bool:
return str(Path(path).resolve()).lower() == str(Path.cwd().resolve()).lower()
def is_in_jupyter():
def is_in_jupyter() -> bool:
"""Check if user is running spaCy from a Jupyter notebook by detecting the
IPython kernel. Mainly used for the displaCy visualizer.
RETURNS (bool): True if in Jupyter, False if not.
@ -548,7 +560,9 @@ def get_object_name(obj: Any) -> str:
return repr(obj)
def get_cuda_stream(require=False, non_blocking=True):
def get_cuda_stream(
require: bool = False, non_blocking: bool = True
) -> Optional[CudaStream]:
ops = get_current_ops()
if CudaStream is None:
return None
@ -567,7 +581,7 @@ def get_async(stream, numpy_array):
return array
def env_opt(name, default=None):
def env_opt(name: str, default: Optional[Any] = None) -> Optional[Any]:
if type(default) is float:
type_convert = float
else:
@ -588,7 +602,7 @@ def env_opt(name, default=None):
return default
def read_regex(path):
def read_regex(path: Union[str, Path]) -> Pattern:
path = ensure_path(path)
with path.open(encoding="utf8") as file_:
entries = file_.read().split("\n")
@ -598,37 +612,40 @@ def read_regex(path):
return re.compile(expression)
def compile_prefix_regex(entries):
def compile_prefix_regex(entries: Iterable[Union[str, Pattern]]) -> Pattern:
"""Compile a sequence of prefix rules into a regex object.
entries (tuple): The prefix rules, e.g. spacy.lang.punctuation.TOKENIZER_PREFIXES.
RETURNS (regex object): The regex object. to be used for Tokenizer.prefix_search.
entries (Iterable[Union[str, Pattern]]): The prefix rules, e.g.
spacy.lang.punctuation.TOKENIZER_PREFIXES.
RETURNS (Pattern): The regex object to be used for Tokenizer.prefix_search.
"""
expression = "|".join(["^" + piece for piece in entries if piece.strip()])
return re.compile(expression)
def compile_suffix_regex(entries):
def compile_suffix_regex(entries: Iterable[Union[str, Pattern]]) -> Pattern:
"""Compile a sequence of suffix rules into a regex object.
entries (tuple): The suffix rules, e.g. spacy.lang.punctuation.TOKENIZER_SUFFIXES.
RETURNS (regex object): The regex object. to be used for Tokenizer.suffix_search.
entries (Iterable[Union[str, Pattern]]): The suffix rules, e.g.
spacy.lang.punctuation.TOKENIZER_SUFFIXES.
RETURNS (Pattern): The regex object to be used for Tokenizer.suffix_search.
"""
expression = "|".join([piece + "$" for piece in entries if piece.strip()])
return re.compile(expression)
def compile_infix_regex(entries):
def compile_infix_regex(entries: Iterable[Union[str, Pattern]]) -> Pattern:
"""Compile a sequence of infix rules into a regex object.
entries (tuple): The infix rules, e.g. spacy.lang.punctuation.TOKENIZER_INFIXES.
entries (Iterable[Union[str, Pattern]]): The infix rules, e.g.
spacy.lang.punctuation.TOKENIZER_INFIXES.
RETURNS (Pattern): The regex object to be used for Tokenizer.infix_finditer.
"""
expression = "|".join([piece for piece in entries if piece.strip()])
return re.compile(expression)
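
A sketch of wiring the three compilers into tokenizer callbacks; the rule lists are illustrative stand-ins, not the real spacy.lang.punctuation definitions:

from spacy.util import compile_prefix_regex, compile_suffix_regex, compile_infix_regex

prefixes = ["§", "%", r"\("]
suffixes = [r"\)", "%", r"\."]
infixes = [r"(?<=[0-9])[+\-\*^](?=[0-9])"]

prefix_search = compile_prefix_regex(prefixes).search
suffix_search = compile_suffix_regex(suffixes).search
infix_finditer = compile_infix_regex(infixes).finditer

assert prefix_search("(hello").group() == "("
assert suffix_search("hello.").group() == "."
assert [m.group() for m in infix_finditer("3+4")] == ["+"]
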
def add_lookups(default_func, *lookups):
def add_lookups(default_func: Callable[[str], Any], *lookups) -> Callable[[str], Any]:
"""Extend an attribute function with special cases. If a word is in the
lookups, the value is returned. Otherwise the previous function is used.
@ -641,19 +658,23 @@ def add_lookups(default_func, *lookups):
return functools.partial(_get_attr_unless_lookup, default_func, lookups)
def _get_attr_unless_lookup(default_func, lookups, string):
def _get_attr_unless_lookup(
default_func: Callable[[str], Any], lookups: Dict[str, Any], string: str
) -> Any:
for lookup in lookups:
if string in lookup:
return lookup[string]
return default_func(string)
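
A small illustration of the lookup chaining implemented above (the attribute function and table are invented for the example):

from spacy.util import add_lookups

default_lower = lambda string: string.lower()
exceptions = {"NYC": "new york city"}
lower_with_exc = add_lookups(default_lower, exceptions)

assert lower_with_exc("NYC") == "new york city"  # special case wins
assert lower_with_exc("Berlin") == "berlin"      # falls back to the default function
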
def update_exc(base_exceptions, *addition_dicts):
def update_exc(
base_exceptions: Dict[str, List[dict]], *addition_dicts
) -> Dict[str, List[dict]]:
"""Update and validate tokenizer exceptions. Will overwrite exceptions.
base_exceptions (dict): Base exceptions.
*addition_dicts (dict): Exceptions to add to the base dict, in order.
RETURNS (dict): Combined tokenizer exceptions.
base_exceptions (Dict[str, List[dict]]): Base exceptions.
*addition_dicts (Dict[str, List[dict]]): Exceptions to add to the base dict, in order.
RETURNS (Dict[str, List[dict]]): Combined tokenizer exceptions.
"""
exc = dict(base_exceptions)
for additions in addition_dicts:
@ -668,14 +689,16 @@ def update_exc(base_exceptions, *addition_dicts):
return exc
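
A sketch of merging exception dicts in this format (entries invented for illustration; real tables key token attributes on spacy.symbols.ORTH, and each token's ORTH pieces must join back to the exception string):

from spacy.symbols import ORTH
from spacy.util import update_exc

base = {"don't": [{ORTH: "do"}, {ORTH: "n't"}]}
additions = {"can't": [{ORTH: "ca"}, {ORTH: "n't"}]}
exc = update_exc(base, additions)
assert "don't" in exc and "can't" in exc
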
def expand_exc(excs, search, replace):
def expand_exc(
excs: Dict[str, List[dict]], search: str, replace: str
) -> Dict[str, List[dict]]:
"""Find string in tokenizer exceptions, duplicate entry and replace string.
For example, to add additional versions with typographic apostrophes.
excs (dict): Tokenizer exceptions.
excs (Dict[str, List[dict]]): Tokenizer exceptions.
search (str): String to find and replace.
replace (str): Replacement.
RETURNS (dict): Combined tokenizer exceptions.
RETURNS (Dict[str, List[dict]]): Combined tokenizer exceptions.
"""
def _fix_token(token, search, replace):
@ -692,7 +715,9 @@ def expand_exc(excs, search, replace):
return new_excs
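
A sketch of adding typographic-apostrophe variants, as the docstring suggests (this assumes the returned dict keeps the original entries alongside the expanded ones):

from spacy.symbols import ORTH
from spacy.util import expand_exc

excs = {"don't": [{ORTH: "do"}, {ORTH: "n't"}]}
expanded = expand_exc(excs, "'", "’")
assert "don't" in expanded and "don’t" in expanded
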
def normalize_slice(length, start, stop, step=None):
def normalize_slice(
length: int, start: int, stop: int, step: Optional[int] = None
) -> Tuple[int, int]:
if not (step is None or step == 1):
raise ValueError(Errors.E057)
if start is None:
@ -708,7 +733,9 @@ def normalize_slice(length, start, stop, step=None):
return start, stop
def minibatch(items, size=8):
def minibatch(
items: Iterable[Any], size: Union[Iterator[int], int] = 8
) -> Iterator[Any]:
"""Iterate over batches of items. `size` may be an iterator,
so that batch-size can vary on each step.
"""
@ -725,7 +752,12 @@ def minibatch(items, size=8):
yield list(batch)
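
A quick sketch of both modes described in the docstring, a fixed size and a size iterator (the growing schedule here is ad hoc; in training it would typically come from a compounding schedule):

import itertools
from spacy.util import minibatch

items = list(range(10))
assert [len(b) for b in minibatch(items, size=4)] == [4, 4, 2]

# with an (ideally endless) size iterator, batches can grow over time
growing = itertools.chain([2, 4], itertools.repeat(8))
assert [len(b) for b in minibatch(items, size=growing)] == [2, 4, 4]
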
def minibatch_by_padded_size(docs, size, buffer=256, discard_oversize=False):
def minibatch_by_padded_size(
docs: Iterator["Doc"],
size: Union[Iterator[int], int],
buffer: int = 256,
discard_oversize: bool = False,
) -> Iterator[Iterator["Doc"]]:
if isinstance(size, int):
size_ = itertools.repeat(size)
else:
@ -742,7 +774,7 @@ def minibatch_by_padded_size(docs, size, buffer=256, discard_oversize=False):
yield subbatch
def _batch_by_length(seqs, max_words):
def _batch_by_length(seqs: Sequence[Any], max_words: int) -> List[List[Any]]:
"""Given a list of sequences, return a batched list of indices into the
list, where the batches are grouped by length, in descending order.
@ -783,14 +815,12 @@ def minibatch_by_words(docs, size, tolerance=0.2, discard_oversize=False):
size_ = iter(size)
else:
size_ = size
target_size = next(size_)
tol_size = target_size * tolerance
batch = []
overflow = []
batch_size = 0
overflow_size = 0
for doc in docs:
if isinstance(doc, Example):
n_words = len(doc.reference)
@ -803,17 +833,14 @@ def minibatch_by_words(docs, size, tolerance=0.2, discard_oversize=False):
if n_words > target_size + tol_size:
if not discard_oversize:
yield [doc]
# add the example to the current batch if there's no overflow yet and it still fits
elif overflow_size == 0 and (batch_size + n_words) <= target_size:
batch.append(doc)
batch_size += n_words
# add the example to the overflow buffer if it fits in the tolerance margin
elif (batch_size + overflow_size + n_words) <= (target_size + tol_size):
overflow.append(doc)
overflow_size += n_words
# yield the previous batch and start a new one. The new one gets the overflow examples.
else:
if batch:
@ -824,17 +851,14 @@ def minibatch_by_words(docs, size, tolerance=0.2, discard_oversize=False):
batch_size = overflow_size
overflow = []
overflow_size = 0
# this example still fits
if (batch_size + n_words) <= target_size:
batch.append(doc)
batch_size += n_words
# this example fits in overflow
elif (batch_size + n_words) <= (target_size + tol_size):
overflow.append(doc)
overflow_size += n_words
# this example does not fit with the previous overflow: start another new batch
else:
if batch:
@ -843,20 +867,19 @@ def minibatch_by_words(docs, size, tolerance=0.2, discard_oversize=False):
tol_size = target_size * tolerance
batch = [doc]
batch_size = n_words
batch.extend(overflow)
if batch:
yield batch
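
A hedged sketch of the tolerance logic traced by the comments above: with a target of 500 words and 20% tolerance, small docs accumulate, near-miss docs go to the overflow buffer, and the buffer seeds the next batch (exact groupings depend on the inputs):

import spacy
from spacy.util import minibatch_by_words

nlp = spacy.blank("en")
docs = [nlp("word " * n) for n in (5, 400, 20, 20, 600)]
for batch in minibatch_by_words(docs, size=500, tolerance=0.2):
    print(len(batch), "docs,", sum(len(doc) for doc in batch), "words")
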
def filter_spans(spans):
def filter_spans(spans: Iterable["Span"]) -> List["Span"]:
"""Filter a sequence of spans and remove duplicates or overlaps. Useful for
creating named entities (where one token can only be part of one entity) or
when merging spans with `Retokenizer.merge`. When spans overlap, the (first)
longest span is preferred over shorter spans.
spans (iterable): The spans to filter.
RETURNS (list): The filtered spans.
spans (Iterable[Span]): The spans to filter.
RETURNS (List[Span]): The filtered spans.
"""
get_sort_key = lambda span: (span.end - span.start, -span.start)
sorted_spans = sorted(spans, key=get_sort_key, reverse=True)
@ -871,15 +894,21 @@ def filter_spans(spans):
return result
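
A small illustration of the overlap filtering described in the docstring:

import spacy
from spacy.util import filter_spans

nlp = spacy.blank("en")
doc = nlp("New York City is big")
spans = [doc[0:2], doc[0:3], doc[3:4]]  # "New York" overlaps "New York City"
assert [span.text for span in filter_spans(spans)] == ["New York City", "is"]
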
def to_bytes(getters, exclude):
def to_bytes(getters: Dict[str, Callable[[], bytes]], exclude: Iterable[str]) -> bytes:
return srsly.msgpack_dumps(to_dict(getters, exclude))
def from_bytes(bytes_data, setters, exclude):
def from_bytes(
bytes_data: bytes,
setters: Dict[str, Callable[[bytes], Any]],
exclude: Iterable[str],
) -> None:
return from_dict(srsly.msgpack_loads(bytes_data), setters, exclude)
def to_dict(getters, exclude):
def to_dict(
getters: Dict[str, Callable[[], Any]], exclude: Iterable[str]
) -> Dict[str, Any]:
serialized = {}
for key, getter in getters.items():
# Split to support file names like meta.json
@ -888,7 +917,11 @@ def to_dict(getters, exclude):
return serialized
def from_dict(msg, setters, exclude):
def from_dict(
msg: Dict[str, Any],
setters: Dict[str, Callable[[Any], Any]],
exclude: Iterable[str],
) -> Dict[str, Any]:
for key, setter in setters.items():
# Split to support file names like meta.json
if key.split(".")[0] not in exclude and key in msg:
@ -896,7 +929,11 @@ def from_dict(msg, setters, exclude):
return msg
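
A rough sketch of the getters/setters pattern these four helpers implement (keys and payloads are invented for illustration):

from spacy.util import to_bytes, from_bytes

received = {}
getters = {"meta.json": lambda: b'{"lang": "en"}', "vocab": lambda: b"vocab-data"}
blob = to_bytes(getters, exclude=["vocab"])  # "vocab" is dropped via exclude

setters = {"meta.json": lambda b: received.setdefault("meta.json", b)}
from_bytes(blob, setters, exclude=[])
assert received == {"meta.json": b'{"lang": "en"}'}
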
def to_disk(path, writers, exclude):
def to_disk(
path: Union[str, Path],
writers: Dict[str, Callable[[Path], None]],
exclude: Iterable[str],
) -> Path:
path = ensure_path(path)
if not path.exists():
path.mkdir()
@ -907,7 +944,11 @@ def to_disk(path, writers, exclude):
return path
def from_disk(path, readers, exclude):
def from_disk(
path: Union[str, Path],
readers: Dict[str, Callable[[Path], None]],
exclude: Iterable[str],
) -> Path:
path = ensure_path(path)
for key, reader in readers.items():
# Split to support file names like meta.json
@ -916,7 +957,7 @@ def from_disk(path, readers, exclude):
return path
def import_file(name, loc):
def import_file(name: str, loc: Union[str, Path]) -> ModuleType:
"""Import module from a file. Used to load models from a directory.
name (str): Name of module to load.
@ -930,7 +971,7 @@ def import_file(name, loc):
return module
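
A usage sketch (module name and path are placeholders, not real files):

from spacy.util import import_file

# e.g. load a custom language or component definition from a standalone file
module = import_file("custom_lang", "/path/to/custom_lang.py")
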
def minify_html(html):
def minify_html(html: str) -> str:
"""Perform a template-specific, rudimentary HTML minification for displaCy.
Disclaimer: NOT a general-purpose solution, only removes indentation and
newlines.
@ -941,7 +982,7 @@ def minify_html(html):
return html.strip().replace(" ", "").replace("\n", "")
def escape_html(text):
def escape_html(text: str) -> str:
"""Replace <, >, &, " with their HTML encoded representation. Intended to
prevent HTML errors in rendered displaCy markup.
@ -955,7 +996,9 @@ def escape_html(text):
return text
def get_words_and_spaces(words, text):
def get_words_and_spaces(
words: Iterable[str], text: str
) -> Tuple[List[str], List[bool]]:
if "".join("".join(words).split()) != "".join(text.split()):
raise ValueError(Errors.E194.format(text=text, words=words))
text_words = []
@ -1103,7 +1146,7 @@ class DummyTokenizer:
return self
def link_vectors_to_models(vocab):
def link_vectors_to_models(vocab: "Vocab") -> None:
vectors = vocab.vectors
if vectors.name is None:
vectors.name = VECTORS_KEY
@ -1119,7 +1162,7 @@ def link_vectors_to_models(vocab):
VECTORS_KEY = "spacy_pretrained_vectors"
def create_default_optimizer():
def create_default_optimizer() -> Optimizer:
learn_rate = env_opt("learn_rate", 0.001)
beta1 = env_opt("optimizer_B1", 0.9)
beta2 = env_opt("optimizer_B2", 0.999)