Tidy up, autoformat, add types

Ines Montani 2020-07-25 15:01:15 +02:00
parent 71242327b2
commit e92df281ce
37 changed files with 423 additions and 454 deletions

.gitignore (vendored), 1 line changed
View File

@ -51,6 +51,7 @@ env3.*/
.denv .denv
.pypyenv .pypyenv
.pytest_cache/ .pytest_cache/
.mypy_cache/
# Distribution / packaging # Distribution / packaging
env/ env/

View File

@ -108,3 +108,8 @@ exclude =
[tool:pytest] [tool:pytest]
markers = markers =
slow slow
[mypy]
ignore_missing_imports = True
no_implicit_optional = True
plugins = pydantic.mypy, thinc.mypy
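
The new [mypy] block enables strict optional handling plus the pydantic and thinc mypy plugins. As a quick illustration of what `no_implicit_optional = True` changes, the first signature below (function names are made up) is rejected by mypy, while the explicit `Optional` variant passes:

```python
from pathlib import Path
from typing import Optional


def load_model(path: Path = None) -> None:  # rejected: default None needs Optional[Path]
    ...


def load_model_explicit(path: Optional[Path] = None) -> None:  # accepted
    ...
```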

View File

@ -1,4 +1,4 @@
from typing import Optional from typing import Optional, Any, List, Union
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
from wasabi import Printer from wasabi import Printer
@ -66,10 +66,9 @@ def convert_cli(
file_type = file_type.value file_type = file_type.value
input_path = Path(input_path) input_path = Path(input_path)
output_dir = "-" if output_dir == Path("-") else output_dir output_dir = "-" if output_dir == Path("-") else output_dir
cli_args = locals()
silent = output_dir == "-" silent = output_dir == "-"
msg = Printer(no_print=silent) msg = Printer(no_print=silent)
verify_cli_args(msg, **cli_args) verify_cli_args(msg, input_path, output_dir, file_type, converter, ner_map)
converter = _get_converter(msg, converter, input_path) converter = _get_converter(msg, converter, input_path)
convert( convert(
input_path, input_path,
@ -89,8 +88,8 @@ def convert_cli(
def convert( def convert(
input_path: Path, input_path: Union[str, Path],
output_dir: Path, output_dir: Union[str, Path],
*, *,
file_type: str = "json", file_type: str = "json",
n_sents: int = 1, n_sents: int = 1,
@ -102,13 +101,12 @@ def convert(
ner_map: Optional[Path] = None, ner_map: Optional[Path] = None,
lang: Optional[str] = None, lang: Optional[str] = None,
silent: bool = True, silent: bool = True,
msg: Optional[Path] = None, msg: Optional[Printer],
) -> None: ) -> None:
if not msg: if not msg:
msg = Printer(no_print=silent) msg = Printer(no_print=silent)
ner_map = srsly.read_json(ner_map) if ner_map is not None else None ner_map = srsly.read_json(ner_map) if ner_map is not None else None
for input_loc in walk_directory(Path(input_path)):
for input_loc in walk_directory(input_path):
input_data = input_loc.open("r", encoding="utf-8").read() input_data = input_loc.open("r", encoding="utf-8").read()
# Use converter function to convert data # Use converter function to convert data
func = CONVERTERS[converter] func = CONVERTERS[converter]
@ -140,14 +138,14 @@ def convert(
msg.good(f"Generated output file ({len(docs)} documents): {output_file}") msg.good(f"Generated output file ({len(docs)} documents): {output_file}")
def _print_docs_to_stdout(data, output_type): def _print_docs_to_stdout(data: Any, output_type: str) -> None:
if output_type == "json": if output_type == "json":
srsly.write_json("-", data) srsly.write_json("-", data)
else: else:
sys.stdout.buffer.write(data) sys.stdout.buffer.write(data)
def _write_docs_to_file(data, output_file, output_type): def _write_docs_to_file(data: Any, output_file: Path, output_type: str) -> None:
if not output_file.parent.exists(): if not output_file.parent.exists():
output_file.parent.mkdir(parents=True) output_file.parent.mkdir(parents=True)
if output_type == "json": if output_type == "json":
@ -157,7 +155,7 @@ def _write_docs_to_file(data, output_file, output_type):
file_.write(data) file_.write(data)
def autodetect_ner_format(input_data: str) -> str: def autodetect_ner_format(input_data: str) -> Optional[str]:
# guess format from the first 20 lines # guess format from the first 20 lines
lines = input_data.split("\n")[:20] lines = input_data.split("\n")[:20]
format_guesses = {"ner": 0, "iob": 0} format_guesses = {"ner": 0, "iob": 0}
@ -176,7 +174,7 @@ def autodetect_ner_format(input_data: str) -> str:
return None return None
def walk_directory(path): def walk_directory(path: Path) -> List[Path]:
if not path.is_dir(): if not path.is_dir():
return [path] return [path]
paths = [path] paths = [path]
@ -196,31 +194,25 @@ def walk_directory(path):
def verify_cli_args( def verify_cli_args(
msg, msg: Printer,
input_path, input_path: Union[str, Path],
output_dir, output_dir: Union[str, Path],
file_type, file_type: FileTypes,
n_sents, converter: str,
seg_sents, ner_map: Optional[Path],
model,
morphology,
merge_subtokens,
converter,
ner_map,
lang,
): ):
input_path = Path(input_path) input_path = Path(input_path)
if file_type not in FILE_TYPES_STDOUT and output_dir == "-": if file_type not in FILE_TYPES_STDOUT and output_dir == "-":
# TODO: support msgpack via stdout in srsly?
msg.fail( msg.fail(
f"Can't write .{file_type} data to stdout", f"Can't write .{file_type} data to stdout. Please specify an output directory.",
"Please specify an output directory.",
exits=1, exits=1,
) )
if not input_path.exists(): if not input_path.exists():
msg.fail("Input file not found", input_path, exits=1) msg.fail("Input file not found", input_path, exits=1)
if output_dir != "-" and not Path(output_dir).exists(): if output_dir != "-" and not Path(output_dir).exists():
msg.fail("Output directory not found", output_dir, exits=1) msg.fail("Output directory not found", output_dir, exits=1)
if ner_map is not None and not Path(ner_map).exists():
msg.fail("NER map not found", ner_map, exits=1)
if input_path.is_dir(): if input_path.is_dir():
input_locs = walk_directory(input_path) input_locs = walk_directory(input_path)
if len(input_locs) == 0: if len(input_locs) == 0:
@ -229,10 +221,8 @@ def verify_cli_args(
if len(file_types) >= 2: if len(file_types) >= 2:
file_types = ",".join(file_types) file_types = ",".join(file_types)
msg.fail("All input files must be same type", file_types, exits=1) msg.fail("All input files must be same type", file_types, exits=1)
converter = _get_converter(msg, converter, input_path) if converter != "auto" and converter not in CONVERTERS:
if converter not in CONVERTERS:
msg.fail(f"Can't find converter for {converter}", exits=1) msg.fail(f"Can't find converter for {converter}", exits=1)
return converter
def _get_converter(msg, converter, input_path): def _get_converter(msg, converter, input_path):
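
walk_directory now carries a `Path -> List[Path]` signature. A simplified sketch of what such a helper can look like (the hunk above only shows the first lines of the real body, which does extra bookkeeping such as skipping dot-files):

```python
from pathlib import Path
from typing import List


def walk_directory(path: Path) -> List[Path]:
    # Files are returned as a single-element list; directories are walked
    # recursively, ignoring hidden entries.
    if not path.is_dir():
        return [path]
    locs: List[Path] = []
    for subpath in sorted(path.iterdir()):
        if subpath.name.startswith("."):
            continue
        if subpath.is_dir():
            locs.extend(walk_directory(subpath))
        else:
            locs.append(subpath)
    return locs


print(walk_directory(Path(".")))
```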

View File

@ -82,11 +82,11 @@ def evaluate(
"NER P": "ents_p", "NER P": "ents_p",
"NER R": "ents_r", "NER R": "ents_r",
"NER F": "ents_f", "NER F": "ents_f",
"Textcat AUC": 'textcat_macro_auc', "Textcat AUC": "textcat_macro_auc",
"Textcat F": 'textcat_macro_f', "Textcat F": "textcat_macro_f",
"Sent P": 'sents_p', "Sent P": "sents_p",
"Sent R": 'sents_r', "Sent R": "sents_r",
"Sent F": 'sents_f', "Sent F": "sents_f",
} }
results = {} results = {}
for metric, key in metrics.items(): for metric, key in metrics.items():

View File

@ -14,7 +14,6 @@ import typer
from ._util import app, Arg, Opt, parse_config_overrides, show_validation_error from ._util import app, Arg, Opt, parse_config_overrides, show_validation_error
from ._util import import_code from ._util import import_code
from ..gold import Corpus, Example from ..gold import Corpus, Example
from ..lookups import Lookups
from ..language import Language from ..language import Language
from .. import util from .. import util
from ..errors import Errors from ..errors import Errors

View File

@ -1,12 +1,5 @@
""" """Helpers for Python and platform compatibility."""
Helpers for Python and platform compatibility. To distinguish them from
the builtin functions, replacement functions are suffixed with an underscore,
e.g. `unicode_`.
DOCS: https://spacy.io/api/top-level#compat
"""
import sys import sys
from thinc.util import copy_array from thinc.util import copy_array
try: try:
@ -40,21 +33,3 @@ copy_array = copy_array
is_windows = sys.platform.startswith("win") is_windows = sys.platform.startswith("win")
is_linux = sys.platform.startswith("linux") is_linux = sys.platform.startswith("linux")
is_osx = sys.platform == "darwin" is_osx = sys.platform == "darwin"
def is_config(windows=None, linux=None, osx=None, **kwargs):
"""Check if a specific configuration of Python version and operating system
matches the user's setup. Mostly used to display targeted error messages.
windows (bool): spaCy is executed on Windows.
linux (bool): spaCy is executed on Linux.
osx (bool): spaCy is executed on OS X or macOS.
RETURNS (bool): Whether the configuration matches the user's platform.
DOCS: https://spacy.io/api/top-level#compat.is_config
"""
return (
windows in (None, is_windows)
and linux in (None, is_linux)
and osx in (None, is_osx)
)
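
With is_config() removed, call sites are expected to check the module-level platform flags directly. A minimal sketch of that pattern, using the same expressions compat keeps:

```python
import sys

# The same flags spacy.compat continues to expose after the cleanup.
is_windows = sys.platform.startswith("win")
is_linux = sys.platform.startswith("linux")
is_osx = sys.platform == "darwin"

if is_windows:
    print("running on Windows")
elif is_osx:
    print("running on macOS")
else:
    print("assuming a Linux-like platform")
```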

View File

@ -4,6 +4,7 @@ spaCy's built in visualization suite for dependencies and named entities.
DOCS: https://spacy.io/api/top-level#displacy DOCS: https://spacy.io/api/top-level#displacy
USAGE: https://spacy.io/usage/visualizers USAGE: https://spacy.io/usage/visualizers
""" """
from typing import Union, Iterable, Optional, Dict, Any, Callable
import warnings import warnings
from .render import DependencyRenderer, EntityRenderer from .render import DependencyRenderer, EntityRenderer
@ -17,11 +18,17 @@ RENDER_WRAPPER = None
def render( def render(
docs, style="dep", page=False, minify=False, jupyter=None, options={}, manual=False docs: Union[Iterable[Doc], Doc],
): style: str = "dep",
page: bool = False,
minify: bool = False,
jupyter: Optional[bool] = None,
options: Dict[str, Any] = {},
manual: bool = False,
) -> str:
"""Render displaCy visualisation. """Render displaCy visualisation.
docs (list or Doc): Document(s) to visualise. docs (Union[Iterable[Doc], Doc]): Document(s) to visualise.
style (str): Visualisation style, 'dep' or 'ent'. style (str): Visualisation style, 'dep' or 'ent'.
page (bool): Render markup as full HTML page. page (bool): Render markup as full HTML page.
minify (bool): Minify HTML markup. minify (bool): Minify HTML markup.
@ -44,8 +51,8 @@ def render(
docs = [obj if not isinstance(obj, Span) else obj.as_doc() for obj in docs] docs = [obj if not isinstance(obj, Span) else obj.as_doc() for obj in docs]
if not all(isinstance(obj, (Doc, Span, dict)) for obj in docs): if not all(isinstance(obj, (Doc, Span, dict)) for obj in docs):
raise ValueError(Errors.E096) raise ValueError(Errors.E096)
renderer, converter = factories[style] renderer_func, converter = factories[style]
renderer = renderer(options=options) renderer = renderer_func(options=options)
parsed = [converter(doc, options) for doc in docs] if not manual else docs parsed = [converter(doc, options) for doc in docs] if not manual else docs
_html["parsed"] = renderer.render(parsed, page=page, minify=minify).strip() _html["parsed"] = renderer.render(parsed, page=page, minify=minify).strip()
html = _html["parsed"] html = _html["parsed"]
@ -61,15 +68,15 @@ def render(
def serve( def serve(
docs, docs: Union[Iterable[Doc], Doc],
style="dep", style: str = "dep",
page=True, page: bool = True,
minify=False, minify: bool = False,
options={}, options: Dict[str, Any] = {},
manual=False, manual: bool = False,
port=5000, port: int = 5000,
host="0.0.0.0", host: str = "0.0.0.0",
): ) -> None:
"""Serve displaCy visualisation. """Serve displaCy visualisation.
docs (list or Doc): Document(s) to visualise. docs (list or Doc): Document(s) to visualise.
@ -88,7 +95,6 @@ def serve(
if is_in_jupyter(): if is_in_jupyter():
warnings.warn(Warnings.W011) warnings.warn(Warnings.W011)
render(docs, style=style, page=page, minify=minify, options=options, manual=manual) render(docs, style=style, page=page, minify=minify, options=options, manual=manual)
httpd = simple_server.make_server(host, port, app) httpd = simple_server.make_server(host, port, app)
print(f"\nUsing the '{style}' visualizer") print(f"\nUsing the '{style}' visualizer")
@ -102,14 +108,13 @@ def serve(
def app(environ, start_response): def app(environ, start_response):
# Headers and status need to be bytes in Python 2, see #1227
headers = [("Content-type", "text/html; charset=utf-8")] headers = [("Content-type", "text/html; charset=utf-8")]
start_response("200 OK", headers) start_response("200 OK", headers)
res = _html["parsed"].encode(encoding="utf-8") res = _html["parsed"].encode(encoding="utf-8")
return [res] return [res]
def parse_deps(orig_doc, options={}): def parse_deps(orig_doc: Doc, options: Dict[str, Any] = {}) -> Dict[str, Any]:
"""Generate dependency parse in {'words': [], 'arcs': []} format. """Generate dependency parse in {'words': [], 'arcs': []} format.
doc (Doc): Document do parse. doc (Doc): Document do parse.
@ -152,7 +157,6 @@ def parse_deps(orig_doc, options={}):
} }
for w in doc for w in doc
] ]
arcs = [] arcs = []
for word in doc: for word in doc:
if word.i < word.head.i: if word.i < word.head.i:
@ -171,7 +175,7 @@ def parse_deps(orig_doc, options={}):
return {"words": words, "arcs": arcs, "settings": get_doc_settings(orig_doc)} return {"words": words, "arcs": arcs, "settings": get_doc_settings(orig_doc)}
def parse_ents(doc, options={}): def parse_ents(doc: Doc, options: Dict[str, Any] = {}) -> Dict[str, Any]:
"""Generate named entities in [{start: i, end: i, label: 'label'}] format. """Generate named entities in [{start: i, end: i, label: 'label'}] format.
doc (Doc): Document do parse. doc (Doc): Document do parse.
@ -188,7 +192,7 @@ def parse_ents(doc, options={}):
return {"text": doc.text, "ents": ents, "title": title, "settings": settings} return {"text": doc.text, "ents": ents, "title": title, "settings": settings}
def set_render_wrapper(func): def set_render_wrapper(func: Callable[[str], str]) -> None:
"""Set an optional wrapper function that is called around the generated """Set an optional wrapper function that is called around the generated
HTML markup on displacy.render. This can be used to allow integration into HTML markup on displacy.render. This can be used to allow integration into
other platforms, similar to Jupyter Notebooks that require functions to be other platforms, similar to Jupyter Notebooks that require functions to be
@ -205,7 +209,7 @@ def set_render_wrapper(func):
RENDER_WRAPPER = func RENDER_WRAPPER = func
def get_doc_settings(doc): def get_doc_settings(doc: Doc) -> Dict[str, Any]:
return { return {
"lang": doc.lang_, "lang": doc.lang_,
"direction": doc.vocab.writing_system.get("direction", "ltr"), "direction": doc.vocab.writing_system.get("direction", "ltr"),

View File

@ -1,19 +1,36 @@
from typing import Dict, Any, List, Optional, Union
import uuid import uuid
from .templates import ( from .templates import TPL_DEP_SVG, TPL_DEP_WORDS, TPL_DEP_WORDS_LEMMA, TPL_DEP_ARCS
TPL_DEP_SVG,
TPL_DEP_WORDS,
TPL_DEP_WORDS_LEMMA,
TPL_DEP_ARCS,
TPL_ENTS,
)
from .templates import TPL_ENT, TPL_ENT_RTL, TPL_FIGURE, TPL_TITLE, TPL_PAGE from .templates import TPL_ENT, TPL_ENT_RTL, TPL_FIGURE, TPL_TITLE, TPL_PAGE
from .templates import TPL_ENTS
from ..util import minify_html, escape_html, registry from ..util import minify_html, escape_html, registry
from ..errors import Errors from ..errors import Errors
DEFAULT_LANG = "en" DEFAULT_LANG = "en"
DEFAULT_DIR = "ltr" DEFAULT_DIR = "ltr"
DEFAULT_ENTITY_COLOR = "#ddd"
DEFAULT_LABEL_COLORS = {
"ORG": "#7aecec",
"PRODUCT": "#bfeeb7",
"GPE": "#feca74",
"LOC": "#ff9561",
"PERSON": "#aa9cfc",
"NORP": "#c887fb",
"FACILITY": "#9cc9cc",
"EVENT": "#ffeb80",
"LAW": "#ff8197",
"LANGUAGE": "#ff8197",
"WORK_OF_ART": "#f0d0ff",
"DATE": "#bfe1d9",
"TIME": "#bfe1d9",
"MONEY": "#e4e7d2",
"QUANTITY": "#e4e7d2",
"ORDINAL": "#e4e7d2",
"CARDINAL": "#e4e7d2",
"PERCENT": "#e4e7d2",
}
class DependencyRenderer: class DependencyRenderer:
@ -21,7 +38,7 @@ class DependencyRenderer:
style = "dep" style = "dep"
def __init__(self, options={}): def __init__(self, options: Dict[str, Any] = {}) -> None:
"""Initialise dependency renderer. """Initialise dependency renderer.
options (dict): Visualiser-specific options (compact, word_spacing, options (dict): Visualiser-specific options (compact, word_spacing,
@ -41,7 +58,9 @@ class DependencyRenderer:
self.direction = DEFAULT_DIR self.direction = DEFAULT_DIR
self.lang = DEFAULT_LANG self.lang = DEFAULT_LANG
def render(self, parsed, page=False, minify=False): def render(
self, parsed: List[Dict[str, Any]], page: bool = False, minify: bool = False
) -> str:
"""Render complete markup. """Render complete markup.
parsed (list): Dependency parses to render. parsed (list): Dependency parses to render.
@ -72,10 +91,15 @@ class DependencyRenderer:
return minify_html(markup) return minify_html(markup)
return markup return markup
def render_svg(self, render_id, words, arcs): def render_svg(
self,
render_id: Union[int, str],
words: List[Dict[str, Any]],
arcs: List[Dict[str, Any]],
) -> str:
"""Render SVG. """Render SVG.
render_id (int): Unique ID, typically index of document. render_id (Union[int, str]): Unique ID, typically index of document.
words (list): Individual words and their tags. words (list): Individual words and their tags.
arcs (list): Individual arcs and their start, end, direction and label. arcs (list): Individual arcs and their start, end, direction and label.
RETURNS (str): Rendered SVG markup. RETURNS (str): Rendered SVG markup.
@ -86,15 +110,15 @@ class DependencyRenderer:
self.width = self.offset_x + len(words) * self.distance self.width = self.offset_x + len(words) * self.distance
self.height = self.offset_y + 3 * self.word_spacing self.height = self.offset_y + 3 * self.word_spacing
self.id = render_id self.id = render_id
words = [ words_svg = [
self.render_word(w["text"], w["tag"], w.get("lemma", None), i) self.render_word(w["text"], w["tag"], w.get("lemma", None), i)
for i, w in enumerate(words) for i, w in enumerate(words)
] ]
arcs = [ arcs_svg = [
self.render_arrow(a["label"], a["start"], a["end"], a["dir"], i) self.render_arrow(a["label"], a["start"], a["end"], a["dir"], i)
for i, a in enumerate(arcs) for i, a in enumerate(arcs)
] ]
content = "".join(words) + "".join(arcs) content = "".join(words_svg) + "".join(arcs_svg)
return TPL_DEP_SVG.format( return TPL_DEP_SVG.format(
id=self.id, id=self.id,
width=self.width, width=self.width,
@ -107,9 +131,7 @@ class DependencyRenderer:
lang=self.lang, lang=self.lang,
) )
def render_word( def render_word(self, text: str, tag: str, lemma: str, i: int) -> str:
self, text, tag, lemma, i,
):
"""Render individual word. """Render individual word.
text (str): Word text. text (str): Word text.
@ -128,7 +150,9 @@ class DependencyRenderer:
) )
return TPL_DEP_WORDS.format(text=html_text, tag=tag, x=x, y=y) return TPL_DEP_WORDS.format(text=html_text, tag=tag, x=x, y=y)
def render_arrow(self, label, start, end, direction, i): def render_arrow(
self, label: str, start: int, end: int, direction: str, i: int
) -> str:
"""Render individual arrow. """Render individual arrow.
label (str): Dependency label. label (str): Dependency label.
@ -172,7 +196,7 @@ class DependencyRenderer:
arc=arc, arc=arc,
) )
def get_arc(self, x_start, y, y_curve, x_end): def get_arc(self, x_start: int, y: int, y_curve: int, x_end: int) -> str:
"""Render individual arc. """Render individual arc.
x_start (int): X-coordinate of arrow start point. x_start (int): X-coordinate of arrow start point.
@ -186,7 +210,7 @@ class DependencyRenderer:
template = "M{x},{y} {x},{c} {e},{c} {e},{y}" template = "M{x},{y} {x},{c} {e},{c} {e},{y}"
return template.format(x=x_start, y=y, c=y_curve, e=x_end) return template.format(x=x_start, y=y, c=y_curve, e=x_end)
def get_arrowhead(self, direction, x, y, end): def get_arrowhead(self, direction: str, x: int, y: int, end: int) -> str:
"""Render individual arrow head. """Render individual arrow head.
direction (str): Arrow direction, 'left' or 'right'. direction (str): Arrow direction, 'left' or 'right'.
@ -196,24 +220,12 @@ class DependencyRenderer:
RETURNS (str): Definition of the arrow head path ('d' attribute). RETURNS (str): Definition of the arrow head path ('d' attribute).
""" """
if direction == "left": if direction == "left":
pos1, pos2, pos3 = (x, x - self.arrow_width + 2, x + self.arrow_width - 2) p1, p2, p3 = (x, x - self.arrow_width + 2, x + self.arrow_width - 2)
else: else:
pos1, pos2, pos3 = ( p1, p2, p3 = (end, end + self.arrow_width - 2, end - self.arrow_width + 2)
end, return f"M{p1},{y + 2} L{p2},{y - self.arrow_width} {p3},{y - self.arrow_width}"
end + self.arrow_width - 2,
end - self.arrow_width + 2,
)
arrowhead = (
pos1,
y + 2,
pos2,
y - self.arrow_width,
pos3,
y - self.arrow_width,
)
return "M{},{} L{},{} {},{}".format(*arrowhead)
def get_levels(self, arcs): def get_levels(self, arcs: List[Dict[str, Any]]) -> List[int]:
"""Calculate available arc height "levels". """Calculate available arc height "levels".
Used to calculate arrow heights dynamically and without wasting space. Used to calculate arrow heights dynamically and without wasting space.
@ -229,41 +241,21 @@ class EntityRenderer:
style = "ent" style = "ent"
def __init__(self, options={}): def __init__(self, options: Dict[str, Any] = {}) -> None:
"""Initialise dependency renderer. """Initialise dependency renderer.
options (dict): Visualiser-specific options (colors, ents) options (dict): Visualiser-specific options (colors, ents)
""" """
colors = { colors = dict(DEFAULT_LABEL_COLORS)
"ORG": "#7aecec",
"PRODUCT": "#bfeeb7",
"GPE": "#feca74",
"LOC": "#ff9561",
"PERSON": "#aa9cfc",
"NORP": "#c887fb",
"FACILITY": "#9cc9cc",
"EVENT": "#ffeb80",
"LAW": "#ff8197",
"LANGUAGE": "#ff8197",
"WORK_OF_ART": "#f0d0ff",
"DATE": "#bfe1d9",
"TIME": "#bfe1d9",
"MONEY": "#e4e7d2",
"QUANTITY": "#e4e7d2",
"ORDINAL": "#e4e7d2",
"CARDINAL": "#e4e7d2",
"PERCENT": "#e4e7d2",
}
user_colors = registry.displacy_colors.get_all() user_colors = registry.displacy_colors.get_all()
for user_color in user_colors.values(): for user_color in user_colors.values():
colors.update(user_color) colors.update(user_color)
colors.update(options.get("colors", {})) colors.update(options.get("colors", {}))
self.default_color = "#ddd" self.default_color = DEFAULT_ENTITY_COLOR
self.colors = colors self.colors = colors
self.ents = options.get("ents", None) self.ents = options.get("ents", None)
self.direction = DEFAULT_DIR self.direction = DEFAULT_DIR
self.lang = DEFAULT_LANG self.lang = DEFAULT_LANG
template = options.get("template") template = options.get("template")
if template: if template:
self.ent_template = template self.ent_template = template
@ -273,7 +265,9 @@ class EntityRenderer:
else: else:
self.ent_template = TPL_ENT self.ent_template = TPL_ENT
def render(self, parsed, page=False, minify=False): def render(
self, parsed: List[Dict[str, Any]], page: bool = False, minify: bool = False
) -> str:
"""Render complete markup. """Render complete markup.
parsed (list): Dependency parses to render. parsed (list): Dependency parses to render.
@ -297,7 +291,9 @@ class EntityRenderer:
return minify_html(markup) return minify_html(markup)
return markup return markup
def render_ents(self, text, spans, title): def render_ents(
self, text: str, spans: List[Dict[str, Any]], title: Optional[str]
) -> str:
"""Render entities in text. """Render entities in text.
text (str): Original text. text (str): Original text.
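
The entity colors now live in a module-level DEFAULT_LABEL_COLORS dict (with DEFAULT_ENTITY_COLOR as the fallback) instead of being rebuilt inside EntityRenderer.__init__. Per-call overrides still take precedence, roughly as in this manual-mode sketch (text and override color are arbitrary):

```python
from spacy import displacy

# manual=True renders pre-extracted entities without loading a pipeline.
example = {
    "text": "Apple is opening its first big office in San Francisco.",
    "ents": [
        {"start": 0, "end": 5, "label": "ORG"},
        {"start": 41, "end": 54, "label": "GPE"},
    ],
    "title": None,
}
# Colors passed via options override the module-level defaults for matching
# labels; everything else falls back to DEFAULT_LABEL_COLORS / DEFAULT_ENTITY_COLOR.
options = {"colors": {"ORG": "#aa9cfc"}}
html = displacy.render(example, style="ent", manual=True, jupyter=False, options=options)
```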

View File

@ -483,6 +483,8 @@ class Errors:
E199 = ("Unable to merge 0-length span at doc[{start}:{end}].") E199 = ("Unable to merge 0-length span at doc[{start}:{end}].")
# TODO: fix numbering after merging develop into master # TODO: fix numbering after merging develop into master
E953 = ("Mismatched IDs received by the Tok2Vec listener: {id1} vs. {id2}")
E954 = ("The Tok2Vec listener did not receive a valid input.")
E955 = ("Can't find table '{table}' for language '{lang}' in spacy-lookups-data.") E955 = ("Can't find table '{table}' for language '{lang}' in spacy-lookups-data.")
E956 = ("Can't find component '{name}' in [components] block in the config. " E956 = ("Can't find component '{name}' in [components] block in the config. "
"Available components: {opts}") "Available components: {opts}")

View File

@ -163,7 +163,7 @@ class Language:
else: else:
if (self.lang and vocab.lang) and (self.lang != vocab.lang): if (self.lang and vocab.lang) and (self.lang != vocab.lang):
raise ValueError(Errors.E150.format(nlp=self.lang, vocab=vocab.lang)) raise ValueError(Errors.E150.format(nlp=self.lang, vocab=vocab.lang))
self.vocab = vocab self.vocab: Vocab = vocab
if self.lang is None: if self.lang is None:
self.lang = self.vocab.lang self.lang = self.vocab.lang
self.pipeline = [] self.pipeline = []

View File

@ -8,7 +8,7 @@ def PrecomputableAffine(nO, nI, nF, nP, dropout=0.1):
init=init, init=init,
dims={"nO": nO, "nI": nI, "nF": nF, "nP": nP}, dims={"nO": nO, "nI": nI, "nF": nF, "nP": nP},
params={"W": None, "b": None, "pad": None}, params={"W": None, "b": None, "pad": None},
attrs={"dropout_rate": dropout} attrs={"dropout_rate": dropout},
) )
return model return model

View File

@ -154,16 +154,30 @@ def LayerNormalizedMaxout(width, maxout_pieces):
def MultiHashEmbed( def MultiHashEmbed(
columns, width, rows, use_subwords, pretrained_vectors, mix, dropout columns, width, rows, use_subwords, pretrained_vectors, mix, dropout
): ):
norm = HashEmbed(nO=width, nV=rows, column=columns.index("NORM"), dropout=dropout, seed=6) norm = HashEmbed(
nO=width, nV=rows, column=columns.index("NORM"), dropout=dropout, seed=6
)
if use_subwords: if use_subwords:
prefix = HashEmbed( prefix = HashEmbed(
nO=width, nV=rows // 2, column=columns.index("PREFIX"), dropout=dropout, seed=7 nO=width,
nV=rows // 2,
column=columns.index("PREFIX"),
dropout=dropout,
seed=7,
) )
suffix = HashEmbed( suffix = HashEmbed(
nO=width, nV=rows // 2, column=columns.index("SUFFIX"), dropout=dropout, seed=8 nO=width,
nV=rows // 2,
column=columns.index("SUFFIX"),
dropout=dropout,
seed=8,
) )
shape = HashEmbed( shape = HashEmbed(
nO=width, nV=rows // 2, column=columns.index("SHAPE"), dropout=dropout, seed=9 nO=width,
nV=rows // 2,
column=columns.index("SHAPE"),
dropout=dropout,
seed=9,
) )
if pretrained_vectors: if pretrained_vectors:
@ -192,7 +206,9 @@ def MultiHashEmbed(
@registry.architectures.register("spacy.CharacterEmbed.v1") @registry.architectures.register("spacy.CharacterEmbed.v1")
def CharacterEmbed(columns, width, rows, nM, nC, features, dropout): def CharacterEmbed(columns, width, rows, nM, nC, features, dropout):
norm = HashEmbed(nO=width, nV=rows, column=columns.index("NORM"), dropout=dropout, seed=5) norm = HashEmbed(
nO=width, nV=rows, column=columns.index("NORM"), dropout=dropout, seed=5
)
chr_embed = _character_embed.CharacterEmbed(nM=nM, nC=nC) chr_embed = _character_embed.CharacterEmbed(nM=nM, nC=nC)
with Model.define_operators({">>": chain, "|": concatenate}): with Model.define_operators({">>": chain, "|": concatenate}):
embed_layer = chr_embed | features >> with_array(norm) embed_layer = chr_embed | features >> with_array(norm)
@ -263,21 +279,29 @@ def build_Tok2Vec_model(
cols = [ID, NORM, PREFIX, SUFFIX, SHAPE, ORTH] cols = [ID, NORM, PREFIX, SUFFIX, SHAPE, ORTH]
with Model.define_operators({">>": chain, "|": concatenate, "**": clone}): with Model.define_operators({">>": chain, "|": concatenate, "**": clone}):
norm = HashEmbed( norm = HashEmbed(
nO=width, nV=embed_size, column=cols.index(NORM), dropout=None, nO=width, nV=embed_size, column=cols.index(NORM), dropout=None, seed=0
seed=0
) )
if subword_features: if subword_features:
prefix = HashEmbed( prefix = HashEmbed(
nO=width, nV=embed_size // 2, column=cols.index(PREFIX), dropout=None, nO=width,
seed=1 nV=embed_size // 2,
column=cols.index(PREFIX),
dropout=None,
seed=1,
) )
suffix = HashEmbed( suffix = HashEmbed(
nO=width, nV=embed_size // 2, column=cols.index(SUFFIX), dropout=None, nO=width,
seed=2 nV=embed_size // 2,
column=cols.index(SUFFIX),
dropout=None,
seed=2,
) )
shape = HashEmbed( shape = HashEmbed(
nO=width, nV=embed_size // 2, column=cols.index(SHAPE), dropout=None, nO=width,
seed=3 nV=embed_size // 2,
column=cols.index(SHAPE),
dropout=None,
seed=3,
) )
else: else:
prefix, suffix, shape = (None, None, None) prefix, suffix, shape = (None, None, None)
@ -294,11 +318,7 @@ def build_Tok2Vec_model(
embed = uniqued( embed = uniqued(
(glove | norm | prefix | suffix | shape) (glove | norm | prefix | suffix | shape)
>> Maxout( >> Maxout(
nO=width, nO=width, nI=width * columns, nP=3, dropout=0.0, normalize=True,
nI=width * columns,
nP=3,
dropout=0.0,
normalize=True,
), ),
column=cols.index(ORTH), column=cols.index(ORTH),
) )
@ -307,11 +327,7 @@ def build_Tok2Vec_model(
embed = uniqued( embed = uniqued(
(glove | norm) (glove | norm)
>> Maxout( >> Maxout(
nO=width, nO=width, nI=width * columns, nP=3, dropout=0.0, normalize=True,
nI=width * columns,
nP=3,
dropout=0.0,
normalize=True,
), ),
column=cols.index(ORTH), column=cols.index(ORTH),
) )
@ -320,11 +336,7 @@ def build_Tok2Vec_model(
embed = uniqued( embed = uniqued(
concatenate(norm, prefix, suffix, shape) concatenate(norm, prefix, suffix, shape)
>> Maxout( >> Maxout(
nO=width, nO=width, nI=width * columns, nP=3, dropout=0.0, normalize=True,
nI=width * columns,
nP=3,
dropout=0.0,
normalize=True,
), ),
column=cols.index(ORTH), column=cols.index(ORTH),
) )
@ -333,11 +345,7 @@ def build_Tok2Vec_model(
cols cols
) >> with_array(norm) ) >> with_array(norm)
reduce_dimensions = Maxout( reduce_dimensions = Maxout(
nO=width, nO=width, nI=nM * nC + width, nP=3, dropout=0.0, normalize=True,
nI=nM * nC + width,
nP=3,
dropout=0.0,
normalize=True,
) )
else: else:
embed = norm embed = norm

View File

@ -10,7 +10,6 @@ from .simple_ner import SimpleNER
from .tagger import Tagger from .tagger import Tagger
from .textcat import TextCategorizer from .textcat import TextCategorizer
from .tok2vec import Tok2Vec from .tok2vec import Tok2Vec
from .hooks import SentenceSegmenter, SimilarityHook
from .functions import merge_entities, merge_noun_chunks, merge_subtokens from .functions import merge_entities, merge_noun_chunks, merge_subtokens
__all__ = [ __all__ = [
@ -21,9 +20,7 @@ __all__ = [
"Morphologizer", "Morphologizer",
"Pipe", "Pipe",
"SentenceRecognizer", "SentenceRecognizer",
"SentenceSegmenter",
"Sentencizer", "Sentencizer",
"SimilarityHook",
"SimpleNER", "SimpleNER",
"Tagger", "Tagger",
"TextCategorizer", "TextCategorizer",

View File

@ -63,7 +63,7 @@ class EntityRuler:
overwrite_ents: bool = False, overwrite_ents: bool = False,
ent_id_sep: str = DEFAULT_ENT_ID_SEP, ent_id_sep: str = DEFAULT_ENT_ID_SEP,
patterns: Optional[List[PatternType]] = None, patterns: Optional[List[PatternType]] = None,
): ) -> None:
"""Initialize the entitiy ruler. If patterns are supplied here, they """Initialize the entitiy ruler. If patterns are supplied here, they
need to be a list of dictionaries with a `"label"` and `"pattern"` need to be a list of dictionaries with a `"label"` and `"pattern"`
key. A pattern can either be a token pattern (list) or a phrase pattern key. A pattern can either be a token pattern (list) or a phrase pattern
@ -239,7 +239,6 @@ class EntityRuler:
phrase_pattern_labels = [] phrase_pattern_labels = []
phrase_pattern_texts = [] phrase_pattern_texts = []
phrase_pattern_ids = [] phrase_pattern_ids = []
for entry in patterns: for entry in patterns:
if isinstance(entry["pattern"], str): if isinstance(entry["pattern"], str):
phrase_pattern_labels.append(entry["label"]) phrase_pattern_labels.append(entry["label"])
@ -247,7 +246,6 @@ class EntityRuler:
phrase_pattern_ids.append(entry.get("id")) phrase_pattern_ids.append(entry.get("id"))
elif isinstance(entry["pattern"], list): elif isinstance(entry["pattern"], list):
token_patterns.append(entry) token_patterns.append(entry)
phrase_patterns = [] phrase_patterns = []
for label, pattern, ent_id in zip( for label, pattern, ent_id in zip(
phrase_pattern_labels, phrase_pattern_labels,
@ -258,7 +256,6 @@ class EntityRuler:
if ent_id: if ent_id:
phrase_pattern["id"] = ent_id phrase_pattern["id"] = ent_id
phrase_patterns.append(phrase_pattern) phrase_patterns.append(phrase_pattern)
for entry in token_patterns + phrase_patterns: for entry in token_patterns + phrase_patterns:
label = entry["label"] label = entry["label"]
if "id" in entry: if "id" in entry:
@ -266,7 +263,6 @@ class EntityRuler:
label = self._create_label(label, entry["id"]) label = self._create_label(label, entry["id"])
key = self.matcher._normalize_key(label) key = self.matcher._normalize_key(label)
self._ent_ids[key] = (ent_label, entry["id"]) self._ent_ids[key] = (ent_label, entry["id"])
pattern = entry["pattern"] pattern = entry["pattern"]
if isinstance(pattern, Doc): if isinstance(pattern, Doc):
self.phrase_patterns[label].append(pattern) self.phrase_patterns[label].append(pattern)
@ -323,13 +319,13 @@ class EntityRuler:
self.clear() self.clear()
if isinstance(cfg, dict): if isinstance(cfg, dict):
self.add_patterns(cfg.get("patterns", cfg)) self.add_patterns(cfg.get("patterns", cfg))
self.overwrite = cfg.get("overwrite") self.overwrite = cfg.get("overwrite", False)
self.phrase_matcher_attr = cfg.get("phrase_matcher_attr", None) self.phrase_matcher_attr = cfg.get("phrase_matcher_attr", None)
if self.phrase_matcher_attr is not None: if self.phrase_matcher_attr is not None:
self.phrase_matcher = PhraseMatcher( self.phrase_matcher = PhraseMatcher(
self.nlp.vocab, attr=self.phrase_matcher_attr self.nlp.vocab, attr=self.phrase_matcher_attr
) )
self.ent_id_sep = cfg.get("ent_id_sep") self.ent_id_sep = cfg.get("ent_id_sep", DEFAULT_ENT_ID_SEP)
else: else:
self.add_patterns(cfg) self.add_patterns(cfg)
return self return self
@ -375,9 +371,9 @@ class EntityRuler:
} }
deserializers_cfg = {"cfg": lambda p: cfg.update(srsly.read_json(p))} deserializers_cfg = {"cfg": lambda p: cfg.update(srsly.read_json(p))}
from_disk(path, deserializers_cfg, {}) from_disk(path, deserializers_cfg, {})
self.overwrite = cfg.get("overwrite") self.overwrite = cfg.get("overwrite", False)
self.phrase_matcher_attr = cfg.get("phrase_matcher_attr") self.phrase_matcher_attr = cfg.get("phrase_matcher_attr")
self.ent_id_sep = cfg.get("ent_id_sep") self.ent_id_sep = cfg.get("ent_id_sep", DEFAULT_ENT_ID_SEP)
if self.phrase_matcher_attr is not None: if self.phrase_matcher_attr is not None:
self.phrase_matcher = PhraseMatcher( self.phrase_matcher = PhraseMatcher(
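
The from_bytes/from_disk paths now pass explicit fallbacks to cfg.get(), so a config saved without the newer keys no longer sets the attributes to None. The general pattern, with an assumed separator value for illustration:

```python
DEFAULT_ENT_ID_SEP = "||"  # assumed value, for illustration only

# e.g. a cfg dict saved by an older version that lacks the newer keys
cfg = {"patterns": []}
overwrite = cfg.get("overwrite", False)                 # False rather than None
ent_id_sep = cfg.get("ent_id_sep", DEFAULT_ENT_ID_SEP)  # sensible default
print(overwrite, ent_id_sep)
```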

View File

@ -1,95 +0,0 @@
from thinc.api import concatenate, reduce_max, reduce_mean, siamese, CauchySimilarity
from .pipe import Pipe
from ..util import link_vectors_to_models
# TODO: do we want to keep these?
class SentenceSegmenter:
"""A simple spaCy hook, to allow custom sentence boundary detection logic
(that doesn't require the dependency parse). To change the sentence
boundary detection strategy, pass a generator function `strategy` on
initialization, or assign a new strategy to the .strategy attribute.
Sentence detection strategies should be generators that take `Doc` objects
and yield `Span` objects for each sentence.
"""
def __init__(self, vocab, strategy=None):
self.vocab = vocab
if strategy is None or strategy == "on_punct":
strategy = self.split_on_punct
self.strategy = strategy
def __call__(self, doc):
doc.user_hooks["sents"] = self.strategy
return doc
@staticmethod
def split_on_punct(doc):
start = 0
seen_period = False
for i, token in enumerate(doc):
if seen_period and not token.is_punct:
yield doc[start : token.i]
start = token.i
seen_period = False
elif token.text in [".", "!", "?"]:
seen_period = True
if start < len(doc):
yield doc[start : len(doc)]
class SimilarityHook(Pipe):
"""
Experimental: A pipeline component to install a hook for supervised
similarity into `Doc` objects.
The similarity model can be any object obeying the Thinc `Model`
interface. By default, the model concatenates the elementwise mean and
elementwise max of the two tensors, and compares them using the
Cauchy-like similarity function from Chen (2013):
>>> similarity = 1. / (1. + (W * (vec1-vec2)**2).sum())
Where W is a vector of dimension weights, initialized to 1.
"""
def __init__(self, vocab, model=True, **cfg):
self.vocab = vocab
self.model = model
self.cfg = dict(cfg)
@classmethod
def Model(cls, length):
return siamese(
concatenate(reduce_max(), reduce_mean()), CauchySimilarity(length * 2)
)
def __call__(self, doc):
"""Install similarity hook"""
doc.user_hooks["similarity"] = self.predict
return doc
def pipe(self, docs, **kwargs):
for doc in docs:
yield self(doc)
def predict(self, doc1, doc2):
return self.model.predict([(doc1, doc2)])
def update(self, doc1_doc2, golds, sgd=None, drop=0.0):
sims, bp_sims = self.model.begin_update(doc1_doc2)
def begin_training(self, _=tuple(), pipeline=None, sgd=None, **kwargs):
"""Allocate model, using nO from the first model in the pipeline.
gold_tuples (iterable): Gold-standard training data.
pipeline (list): The pipeline the model is part of.
"""
if self.model is True:
self.model = self.Model(pipeline[0].model.get_dim("nO"))
link_vectors_to_models(self.vocab)
if sgd is None:
sgd = self.create_optimizer()
return sgd

View File

@ -1,4 +1,4 @@
from typing import Iterable, Tuple, Optional, Dict, List, Callable from typing import Iterable, Tuple, Optional, Dict, List, Callable, Iterator, Any
from thinc.api import get_array_module, Model, Optimizer, set_dropout_rate, Config from thinc.api import get_array_module, Model, Optimizer, set_dropout_rate, Config
import numpy import numpy
@ -97,7 +97,7 @@ class TextCategorizer(Pipe):
def labels(self, value: Iterable[str]) -> None: def labels(self, value: Iterable[str]) -> None:
self.cfg["labels"] = tuple(value) self.cfg["labels"] = tuple(value)
def pipe(self, stream, batch_size=128): def pipe(self, stream: Iterator[str], batch_size: int = 128) -> Iterator[Doc]:
for docs in util.minibatch(stream, size=batch_size): for docs in util.minibatch(stream, size=batch_size):
scores = self.predict(docs) scores = self.predict(docs)
self.set_annotations(docs, scores) self.set_annotations(docs, scores)
@ -252,8 +252,17 @@ class TextCategorizer(Pipe):
sgd = self.create_optimizer() sgd = self.create_optimizer()
return sgd return sgd
def score(self, examples, positive_label=None, **kwargs): def score(
return Scorer.score_cats(examples, "cats", labels=self.labels, self,
examples: Iterable[Example],
positive_label: Optional[str] = None,
**kwargs,
) -> Dict[str, Any]:
return Scorer.score_cats(
examples,
"cats",
labels=self.labels,
multi_label=self.model.attrs["multi_label"], multi_label=self.model.attrs["multi_label"],
positive_label=positive_label, **kwargs positive_label=positive_label,
**kwargs,
) )

View File

@ -6,6 +6,7 @@ from ..gold import Example
from ..tokens import Doc from ..tokens import Doc
from ..vocab import Vocab from ..vocab import Vocab
from ..language import Language from ..language import Language
from ..errors import Errors
from ..util import link_vectors_to_models, minibatch from ..util import link_vectors_to_models, minibatch
@ -150,7 +151,7 @@ class Tok2Vec(Pipe):
self.set_annotations(docs, tokvecs) self.set_annotations(docs, tokvecs)
return losses return losses
def get_loss(self, examples, scores): def get_loss(self, examples, scores) -> None:
pass pass
def begin_training( def begin_training(
@ -184,26 +185,26 @@ class Tok2VecListener(Model):
self._backprop = None self._backprop = None
@classmethod @classmethod
def get_batch_id(cls, inputs): def get_batch_id(cls, inputs) -> int:
return sum(sum(token.orth for token in doc) for doc in inputs) return sum(sum(token.orth for token in doc) for doc in inputs)
def receive(self, batch_id, outputs, backprop): def receive(self, batch_id: int, outputs, backprop) -> None:
self._batch_id = batch_id self._batch_id = batch_id
self._outputs = outputs self._outputs = outputs
self._backprop = backprop self._backprop = backprop
def verify_inputs(self, inputs): def verify_inputs(self, inputs) -> bool:
if self._batch_id is None and self._outputs is None: if self._batch_id is None and self._outputs is None:
raise ValueError("The Tok2Vec listener did not receive valid input.") raise ValueError(Errors.E954)
else: else:
batch_id = self.get_batch_id(inputs) batch_id = self.get_batch_id(inputs)
if batch_id != self._batch_id: if batch_id != self._batch_id:
raise ValueError(f"Mismatched IDs! {batch_id} vs {self._batch_id}") raise ValueError(Errors.E953.format(id1=batch_id, id2=self._batch_id))
else: else:
return True return True
def forward(model: Tok2VecListener, inputs, is_train): def forward(model: Tok2VecListener, inputs, is_train: bool):
if is_train: if is_train:
model.verify_inputs(inputs) model.verify_inputs(inputs)
return model._outputs, model._backprop return model._outputs, model._backprop
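
The listener now raises the new E953/E954 codes from spacy.errors instead of hard-coded strings. A toy sketch of the same pattern (the class below is illustrative, not spaCy's actual Errors container):

```python
class ToyErrors:
    # Centralizing messages keeps wording consistent and lets call sites
    # format in the dynamic parts.
    E953 = "Mismatched IDs received by the Tok2Vec listener: {id1} vs. {id2}"
    E954 = "The Tok2Vec listener did not receive a valid input."


def verify_batch(expected_id: int, received_id: int) -> bool:
    if received_id != expected_id:
        raise ValueError(ToyErrors.E953.format(id1=received_id, id2=expected_id))
    return True


print(verify_batch(3, 3))
```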

View File

@ -1,4 +1,4 @@
from typing import Dict, List, Union, Optional, Sequence, Any, Callable from typing import Dict, List, Union, Optional, Sequence, Any, Callable, Type
from enum import Enum from enum import Enum
from pydantic import BaseModel, Field, ValidationError, validator from pydantic import BaseModel, Field, ValidationError, validator
from pydantic import StrictStr, StrictInt, StrictFloat, StrictBool from pydantic import StrictStr, StrictInt, StrictFloat, StrictBool
@ -9,12 +9,12 @@ from thinc.api import Optimizer
from .attrs import NAMES from .attrs import NAMES
def validate(schema, obj): def validate(schema: Type[BaseModel], obj: Dict[str, Any]) -> List[str]:
"""Validate data against a given pydantic schema. """Validate data against a given pydantic schema.
obj (dict): JSON-serializable data to validate. obj (Dict[str, Any]): JSON-serializable data to validate.
schema (pydantic.BaseModel): The schema to validate against. schema (pydantic.BaseModel): The schema to validate against.
RETURNS (list): A list of error messages, if available. RETURNS (List[str]): A list of error messages, if available.
""" """
try: try:
schema(**obj) schema(**obj)
@ -31,7 +31,7 @@ def validate(schema, obj):
# Matcher token patterns # Matcher token patterns
def validate_token_pattern(obj): def validate_token_pattern(obj: list) -> List[str]:
# Try to convert non-string keys (e.g. {ORTH: "foo"} -> {"ORTH": "foo"}) # Try to convert non-string keys (e.g. {ORTH: "foo"} -> {"ORTH": "foo"})
get_key = lambda k: NAMES[k] if isinstance(k, int) and k < len(NAMES) else k get_key = lambda k: NAMES[k] if isinstance(k, int) and k < len(NAMES) else k
if isinstance(obj, list): if isinstance(obj, list):
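
validate() now takes a Type[BaseModel] and a Dict[str, Any] and returns a List[str] of error messages. Roughly the pattern it follows, shown with a made-up schema and helper name:

```python
from typing import Any, Dict, List, Type
from pydantic import BaseModel, StrictInt, StrictStr, ValidationError


class ToyConfig(BaseModel):
    name: StrictStr
    batch_size: StrictInt


def validate_data(schema: Type[BaseModel], obj: Dict[str, Any]) -> List[str]:
    # Same contract as above: a list of readable error strings, empty if valid.
    try:
        schema(**obj)
        return []
    except ValidationError as e:
        return [f"[{' -> '.join(map(str, err['loc']))}] {err['msg']}" for err in e.errors()]


print(validate_data(ToyConfig, {"name": "demo", "batch_size": "32"}))  # StrictInt rejects "32"
```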

View File

@ -15,7 +15,8 @@ def test_doc_add_entities_set_ents_iob(en_vocab):
"min_action_freq": 30, "min_action_freq": 30,
"update_with_oracle_cut_size": 100, "update_with_oracle_cut_size": 100,
} }
model = registry.make_from_config({"model": DEFAULT_NER_MODEL}, validate=True)["model"] cfg = {"model": DEFAULT_NER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"]
ner = EntityRecognizer(en_vocab, model, **config) ner = EntityRecognizer(en_vocab, model, **config)
ner.begin_training([]) ner.begin_training([])
ner(doc) ner(doc)
@ -37,7 +38,8 @@ def test_ents_reset(en_vocab):
"min_action_freq": 30, "min_action_freq": 30,
"update_with_oracle_cut_size": 100, "update_with_oracle_cut_size": 100,
} }
model = registry.make_from_config({"model": DEFAULT_NER_MODEL}, validate=True)["model"] cfg = {"model": DEFAULT_NER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"]
ner = EntityRecognizer(en_vocab, model, **config) ner = EntityRecognizer(en_vocab, model, **config)
ner.begin_training([]) ner.begin_training([])
ner(doc) ner(doc)

View File

@ -14,7 +14,9 @@ def test_en_sbd_single_punct(en_tokenizer, text, punct):
assert sum(len(sent) for sent in doc.sents) == len(doc) assert sum(len(sent) for sent in doc.sents) == len(doc)
@pytest.mark.skip(reason="The step_through API was removed (but should be brought back)") @pytest.mark.skip(
reason="The step_through API was removed (but should be brought back)"
)
def test_en_sentence_breaks(en_tokenizer, en_parser): def test_en_sentence_breaks(en_tokenizer, en_parser):
# fmt: off # fmt: off
text = "This is a sentence . This is another one ." text = "This is a sentence . This is another one ."

View File

@ -32,6 +32,25 @@ SENTENCE_TESTS = [
("あれ。これ。", ["あれ。", "これ。"]), ("あれ。これ。", ["あれ。", "これ。"]),
("「伝染るんです。」という漫画があります。", ["「伝染るんです。」という漫画があります。"]), ("「伝染るんです。」という漫画があります。", ["「伝染るんです。」という漫画があります。"]),
] ]
tokens1 = [
DetailedToken(surface="委員", tag="名詞-普通名詞-一般", inf="", lemma="委員", reading="イイン", sub_tokens=None),
DetailedToken(surface="", tag="名詞-普通名詞-一般", inf="", lemma="", reading="カイ", sub_tokens=None),
]
tokens2 = [
DetailedToken(surface="選挙", tag="名詞-普通名詞-サ変可能", inf="", lemma="選挙", reading="センキョ", sub_tokens=None),
DetailedToken(surface="管理", tag="名詞-普通名詞-サ変可能", inf="", lemma="管理", reading="カンリ", sub_tokens=None),
DetailedToken(surface="委員", tag="名詞-普通名詞-一般", inf="", lemma="委員", reading="イイン", sub_tokens=None),
DetailedToken(surface="", tag="名詞-普通名詞-一般", inf="", lemma="", reading="カイ", sub_tokens=None),
]
tokens3 = [
DetailedToken(surface="選挙", tag="名詞-普通名詞-サ変可能", inf="", lemma="選挙", reading="センキョ", sub_tokens=None),
DetailedToken(surface="管理", tag="名詞-普通名詞-サ変可能", inf="", lemma="管理", reading="カンリ", sub_tokens=None),
DetailedToken(surface="委員会", tag="名詞-普通名詞-一般", inf="", lemma="委員会", reading="イインカイ", sub_tokens=None),
]
SUB_TOKEN_TESTS = [
("選挙管理委員会", [None, None, None, None], [None, None, [tokens1]], [[tokens2, tokens3]])
]
# fmt: on # fmt: on
@ -92,33 +111,12 @@ def test_ja_tokenizer_split_modes(ja_tokenizer, text, len_a, len_b, len_c):
assert len(nlp_c(text)) == len_c assert len(nlp_c(text)) == len_c
@pytest.mark.parametrize("text,sub_tokens_list_a,sub_tokens_list_b,sub_tokens_list_c", @pytest.mark.parametrize(
[ "text,sub_tokens_list_a,sub_tokens_list_b,sub_tokens_list_c", SUB_TOKEN_TESTS,
(
"選挙管理委員会",
[None, None, None, None],
[None, None, [
[
DetailedToken(surface='委員', tag='名詞-普通名詞-一般', inf='', lemma='委員', reading='イイン', sub_tokens=None),
DetailedToken(surface='会', tag='名詞-普通名詞-一般', inf='', lemma='会', reading='カイ', sub_tokens=None),
]
]],
[[
[
DetailedToken(surface='選挙', tag='名詞-普通名詞-サ変可能', inf='', lemma='選挙', reading='センキョ', sub_tokens=None),
DetailedToken(surface='管理', tag='名詞-普通名詞-サ変可能', inf='', lemma='管理', reading='カンリ', sub_tokens=None),
DetailedToken(surface='委員', tag='名詞-普通名詞-一般', inf='', lemma='委員', reading='イイン', sub_tokens=None),
DetailedToken(surface='会', tag='名詞-普通名詞-一般', inf='', lemma='会', reading='カイ', sub_tokens=None),
], [
DetailedToken(surface='選挙', tag='名詞-普通名詞-サ変可能', inf='', lemma='選挙', reading='センキョ', sub_tokens=None),
DetailedToken(surface='管理', tag='名詞-普通名詞-サ変可能', inf='', lemma='管理', reading='カンリ', sub_tokens=None),
DetailedToken(surface='委員会', tag='名詞-普通名詞-一般', inf='', lemma='委員会', reading='イインカイ', sub_tokens=None),
]
]]
),
]
) )
def test_ja_tokenizer_sub_tokens(ja_tokenizer, text, sub_tokens_list_a, sub_tokens_list_b, sub_tokens_list_c): def test_ja_tokenizer_sub_tokens(
ja_tokenizer, text, sub_tokens_list_a, sub_tokens_list_b, sub_tokens_list_c
):
nlp_a = Japanese.from_config({"nlp": {"tokenizer": {"split_mode": "A"}}}) nlp_a = Japanese.from_config({"nlp": {"tokenizer": {"split_mode": "A"}}})
nlp_b = Japanese.from_config({"nlp": {"tokenizer": {"split_mode": "B"}}}) nlp_b = Japanese.from_config({"nlp": {"tokenizer": {"split_mode": "B"}}})
nlp_c = Japanese.from_config({"nlp": {"tokenizer": {"split_mode": "C"}}}) nlp_c = Japanese.from_config({"nlp": {"tokenizer": {"split_mode": "C"}}})
@ -129,16 +127,19 @@ def test_ja_tokenizer_sub_tokens(ja_tokenizer, text, sub_tokens_list_a, sub_toke
assert nlp_c(text).user_data["sub_tokens"] == sub_tokens_list_c assert nlp_c(text).user_data["sub_tokens"] == sub_tokens_list_c
@pytest.mark.parametrize("text,inflections,reading_forms", @pytest.mark.parametrize(
"text,inflections,reading_forms",
[ [
( (
"取ってつけた", "取ってつけた",
("五段-ラ行,連用形-促音便", "", "下一段-カ行,連用形-一般", "助動詞-タ,終止形-一般"), ("五段-ラ行,連用形-促音便", "", "下一段-カ行,連用形-一般", "助動詞-タ,終止形-一般"),
("トッ", "", "ツケ", ""), ("トッ", "", "ツケ", ""),
), ),
] ],
) )
def test_ja_tokenizer_inflections_reading_forms(ja_tokenizer, text, inflections, reading_forms): def test_ja_tokenizer_inflections_reading_forms(
ja_tokenizer, text, inflections, reading_forms
):
assert ja_tokenizer(text).user_data["inflections"] == inflections assert ja_tokenizer(text).user_data["inflections"] == inflections
assert ja_tokenizer(text).user_data["reading_forms"] == reading_forms assert ja_tokenizer(text).user_data["reading_forms"] == reading_forms
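
The bulky DetailedToken fixtures are now module-level constants (tokens1-3, SUB_TOKEN_TESTS) referenced from @pytest.mark.parametrize, which keeps the decorators readable. The same refactoring pattern in miniature (test name and data are made up):

```python
import pytest

# Module-level test data keeps the parametrize decorator short.
LENGTH_CASES = [("選挙管理委員会", 7), ("委員会", 3)]


@pytest.mark.parametrize("text,expected_len", LENGTH_CASES)
def test_text_length(text, expected_len):
    assert len(text) == expected_len
```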

View File

@ -11,8 +11,7 @@ def test_ne_tokenizer_handlers_long_text(ne_tokenizer):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"text,length", "text,length", [("समय जान कति पनि बेर लाग्दैन ।", 7), ("म ठूलो हुँदै थिएँ ।", 5)],
[("समय जान कति पनि बेर लाग्दैन ।", 7), ("म ठूलो हुँदै थिएँ ।", 5)],
) )
def test_ne_tokenizer_handles_cnts(ne_tokenizer, text, length): def test_ne_tokenizer_handles_cnts(ne_tokenizer, text, length):
tokens = ne_tokenizer(text) tokens = ne_tokenizer(text)

View File

@ -30,10 +30,7 @@ def test_zh_tokenizer_serialize_pkuseg_with_processors(zh_tokenizer_pkuseg):
nlp = Chinese( nlp = Chinese(
meta={ meta={
"tokenizer": { "tokenizer": {
"config": { "config": {"segmenter": "pkuseg", "pkuseg_model": "medicine",}
"segmenter": "pkuseg",
"pkuseg_model": "medicine",
}
} }
} }
) )

View File

@ -22,7 +22,8 @@ def parser(vocab):
"min_action_freq": 30, "min_action_freq": 30,
"update_with_oracle_cut_size": 100, "update_with_oracle_cut_size": 100,
} }
model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"] cfg = {"model": DEFAULT_PARSER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"]
parser = DependencyParser(vocab, model, **config) parser = DependencyParser(vocab, model, **config)
return parser return parser
@ -68,7 +69,8 @@ def test_add_label_deserializes_correctly():
"min_action_freq": 30, "min_action_freq": 30,
"update_with_oracle_cut_size": 100, "update_with_oracle_cut_size": 100,
} }
model = registry.make_from_config({"model": DEFAULT_NER_MODEL}, validate=True)["model"] cfg = {"model": DEFAULT_NER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"]
ner1 = EntityRecognizer(Vocab(), model, **config) ner1 = EntityRecognizer(Vocab(), model, **config)
ner1.add_label("C") ner1.add_label("C")
ner1.add_label("B") ner1.add_label("B")
@ -86,7 +88,10 @@ def test_add_label_deserializes_correctly():
@pytest.mark.parametrize( @pytest.mark.parametrize(
"pipe_cls,n_moves,model_config", "pipe_cls,n_moves,model_config",
[(DependencyParser, 5, DEFAULT_PARSER_MODEL), (EntityRecognizer, 4, DEFAULT_NER_MODEL)], [
(DependencyParser, 5, DEFAULT_PARSER_MODEL),
(EntityRecognizer, 4, DEFAULT_NER_MODEL),
],
) )
def test_add_label_get_label(pipe_cls, n_moves, model_config): def test_add_label_get_label(pipe_cls, n_moves, model_config):
"""Test that added labels are returned correctly. This test was added to """Test that added labels are returned correctly. This test was added to

View File

@ -126,7 +126,8 @@ def test_get_oracle_actions():
"min_action_freq": 0, "min_action_freq": 0,
"update_with_oracle_cut_size": 100, "update_with_oracle_cut_size": 100,
} }
model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"] cfg = {"model": DEFAULT_PARSER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"]
parser = DependencyParser(doc.vocab, model, **config) parser = DependencyParser(doc.vocab, model, **config)
parser.moves.add_action(0, "") parser.moves.add_action(0, "")
parser.moves.add_action(1, "") parser.moves.add_action(1, "")

View File

@ -24,7 +24,8 @@ def arc_eager(vocab):
@pytest.fixture @pytest.fixture
def tok2vec(): def tok2vec():
tok2vec = registry.make_from_config({"model": DEFAULT_TOK2VEC_MODEL}, validate=True)["model"] cfg = {"model": DEFAULT_TOK2VEC_MODEL}
tok2vec = registry.make_from_config(cfg, validate=True)["model"]
tok2vec.initialize() tok2vec.initialize()
return tok2vec return tok2vec
@ -36,13 +37,15 @@ def parser(vocab, arc_eager):
"min_action_freq": 30, "min_action_freq": 30,
"update_with_oracle_cut_size": 100, "update_with_oracle_cut_size": 100,
} }
model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"] cfg = {"model": DEFAULT_PARSER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"]
return Parser(vocab, model, moves=arc_eager, **config) return Parser(vocab, model, moves=arc_eager, **config)
@pytest.fixture @pytest.fixture
def model(arc_eager, tok2vec, vocab): def model(arc_eager, tok2vec, vocab):
model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"] cfg = {"model": DEFAULT_PARSER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"]
model.attrs["resize_output"](model, arc_eager.n_moves) model.attrs["resize_output"](model, arc_eager.n_moves)
model.initialize() model.initialize()
return model return model
@ -68,7 +71,8 @@ def test_build_model(parser, vocab):
"min_action_freq": 0, "min_action_freq": 0,
"update_with_oracle_cut_size": 100, "update_with_oracle_cut_size": 100,
} }
model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"] cfg = {"model": DEFAULT_PARSER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"]
parser.model = Parser(vocab, model=model, moves=parser.moves, **config).model parser.model = Parser(vocab, model=model, moves=parser.moves, **config).model
assert parser.model is not None assert parser.model is not None

View File

@ -33,7 +33,9 @@ def test_parser_root(en_tokenizer):
assert t.dep != 0, t.text assert t.dep != 0, t.text
@pytest.mark.skip(reason="The step_through API was removed (but should be brought back)") @pytest.mark.skip(
reason="The step_through API was removed (but should be brought back)"
)
@pytest.mark.parametrize("text", ["Hello"]) @pytest.mark.parametrize("text", ["Hello"])
def test_parser_parse_one_word_sentence(en_tokenizer, en_parser, text): def test_parser_parse_one_word_sentence(en_tokenizer, en_parser, text):
tokens = en_tokenizer(text) tokens = en_tokenizer(text)
@ -47,7 +49,9 @@ def test_parser_parse_one_word_sentence(en_tokenizer, en_parser, text):
assert doc[0].dep != 0 assert doc[0].dep != 0
@pytest.mark.skip(reason="The step_through API was removed (but should be brought back)") @pytest.mark.skip(
reason="The step_through API was removed (but should be brought back)"
)
def test_parser_initial(en_tokenizer, en_parser): def test_parser_initial(en_tokenizer, en_parser):
text = "I ate the pizza with anchovies." text = "I ate the pizza with anchovies."
# heads = [1, 0, 1, -2, -3, -1, -5] # heads = [1, 0, 1, -2, -3, -1, -5]
@ -92,7 +96,9 @@ def test_parser_merge_pp(en_tokenizer):
assert doc[3].text == "occurs" assert doc[3].text == "occurs"
@pytest.mark.skip(reason="The step_through API was removed (but should be brought back)") @pytest.mark.skip(
reason="The step_through API was removed (but should be brought back)"
)
def test_parser_arc_eager_finalize_state(en_tokenizer, en_parser): def test_parser_arc_eager_finalize_state(en_tokenizer, en_parser):
text = "a b c d e" text = "a b c d e"

View File

@ -21,7 +21,8 @@ def parser(vocab):
"min_action_freq": 30, "min_action_freq": 30,
"update_with_oracle_cut_size": 100, "update_with_oracle_cut_size": 100,
} }
model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"] cfg = {"model": DEFAULT_PARSER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"]
parser = DependencyParser(vocab, model, **config) parser = DependencyParser(vocab, model, **config)
parser.cfg["token_vector_width"] = 4 parser.cfg["token_vector_width"] = 4
parser.cfg["hidden_width"] = 32 parser.cfg["hidden_width"] = 32

View File

@ -28,7 +28,9 @@ def test_parser_sentence_space(en_tokenizer):
assert len(list(doc.sents)) == 2 assert len(list(doc.sents)) == 2
@pytest.mark.skip(reason="The step_through API was removed (but should be brought back)") @pytest.mark.skip(
reason="The step_through API was removed (but should be brought back)"
)
def test_parser_space_attachment_leading(en_tokenizer, en_parser): def test_parser_space_attachment_leading(en_tokenizer, en_parser):
text = "\t \n This is a sentence ." text = "\t \n This is a sentence ."
heads = [1, 1, 0, 1, -2, -3] heads = [1, 1, 0, 1, -2, -3]
@ -44,7 +46,9 @@ def test_parser_space_attachment_leading(en_tokenizer, en_parser):
assert stepwise.stack == set([2]) assert stepwise.stack == set([2])
@pytest.mark.skip(reason="The step_through API was removed (but should be brought back)") @pytest.mark.skip(
reason="The step_through API was removed (but should be brought back)"
)
def test_parser_space_attachment_intermediate_trailing(en_tokenizer, en_parser): def test_parser_space_attachment_intermediate_trailing(en_tokenizer, en_parser):
text = "This is \t a \t\n \n sentence . \n\n \n" text = "This is \t a \t\n \n sentence . \n\n \n"
heads = [1, 0, -1, 2, -1, -4, -5, -1] heads = [1, 0, -1, 2, -1, -4, -5, -1]
@ -64,7 +68,9 @@ def test_parser_space_attachment_intermediate_trailing(en_tokenizer, en_parser):
@pytest.mark.parametrize("text,length", [(["\n"], 1), (["\n", "\t", "\n\n", "\t"], 4)]) @pytest.mark.parametrize("text,length", [(["\n"], 1), (["\n", "\t", "\n\n", "\t"], 4)])
@pytest.mark.skip(reason="The step_through API was removed (but should be brought back)") @pytest.mark.skip(
reason="The step_through API was removed (but should be brought back)"
)
def test_parser_space_attachment_space(en_tokenizer, en_parser, text, length): def test_parser_space_attachment_space(en_tokenizer, en_parser, text, length):
doc = Doc(en_parser.vocab, words=text) doc = Doc(en_parser.vocab, words=text)
assert len(doc) == length assert len(doc) == length

View File

@ -117,7 +117,9 @@ def test_overfitting_IO():
assert cats2["POSITIVE"] + cats2["NEGATIVE"] == pytest.approx(1.0, 0.1) assert cats2["POSITIVE"] + cats2["NEGATIVE"] == pytest.approx(1.0, 0.1)
# Test scoring # Test scoring
scores = nlp.evaluate(train_examples, component_cfg={"scorer": {"positive_label": "POSITIVE"}}) scores = nlp.evaluate(
train_examples, component_cfg={"scorer": {"positive_label": "POSITIVE"}}
)
assert scores["cats_f"] == 1.0 assert scores["cats_f"] == 1.0

View File

@ -201,7 +201,8 @@ def test_issue3345():
"min_action_freq": 30, "min_action_freq": 30,
"update_with_oracle_cut_size": 100, "update_with_oracle_cut_size": 100,
} }
model = registry.make_from_config({"model": DEFAULT_NER_MODEL}, validate=True)["model"] cfg = {"model": DEFAULT_NER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"]
ner = EntityRecognizer(doc.vocab, model, **config) ner = EntityRecognizer(doc.vocab, model, **config)
# Add the OUT action. I wouldn't have thought this would be necessary... # Add the OUT action. I wouldn't have thought this would be necessary...
ner.moves.add_action(5, "") ner.moves.add_action(5, "")

View File

@ -20,7 +20,8 @@ def parser(en_vocab):
"min_action_freq": 30, "min_action_freq": 30,
"update_with_oracle_cut_size": 100, "update_with_oracle_cut_size": 100,
} }
model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"] cfg = {"model": DEFAULT_PARSER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"]
parser = DependencyParser(en_vocab, model, **config) parser = DependencyParser(en_vocab, model, **config)
parser.add_label("nsubj") parser.add_label("nsubj")
return parser return parser
@ -33,14 +34,16 @@ def blank_parser(en_vocab):
"min_action_freq": 30, "min_action_freq": 30,
"update_with_oracle_cut_size": 100, "update_with_oracle_cut_size": 100,
} }
model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"] cfg = {"model": DEFAULT_PARSER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"]
parser = DependencyParser(en_vocab, model, **config) parser = DependencyParser(en_vocab, model, **config)
return parser return parser
@pytest.fixture @pytest.fixture
def taggers(en_vocab): def taggers(en_vocab):
model = registry.make_from_config({"model": DEFAULT_TAGGER_MODEL}, validate=True)["model"] cfg = {"model": DEFAULT_TAGGER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"]
tagger1 = Tagger(en_vocab, model, set_morphology=True) tagger1 = Tagger(en_vocab, model, set_morphology=True)
tagger2 = Tagger(en_vocab, model, set_morphology=True) tagger2 = Tagger(en_vocab, model, set_morphology=True)
return tagger1, tagger2 return tagger1, tagger2
@ -53,7 +56,8 @@ def test_serialize_parser_roundtrip_bytes(en_vocab, Parser):
"min_action_freq": 0, "min_action_freq": 0,
"update_with_oracle_cut_size": 100, "update_with_oracle_cut_size": 100,
} }
model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"] cfg = {"model": DEFAULT_PARSER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"]
parser = Parser(en_vocab, model, **config) parser = Parser(en_vocab, model, **config)
new_parser = Parser(en_vocab, model, **config) new_parser = Parser(en_vocab, model, **config)
new_parser = new_parser.from_bytes(parser.to_bytes(exclude=["vocab"])) new_parser = new_parser.from_bytes(parser.to_bytes(exclude=["vocab"]))
@ -70,7 +74,8 @@ def test_serialize_parser_roundtrip_disk(en_vocab, Parser):
"min_action_freq": 0, "min_action_freq": 0,
"update_with_oracle_cut_size": 100, "update_with_oracle_cut_size": 100,
} }
model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"] cfg = {"model": DEFAULT_PARSER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"]
parser = Parser(en_vocab, model, **config) parser = Parser(en_vocab, model, **config)
with make_tempdir() as d: with make_tempdir() as d:
file_path = d / "parser" file_path = d / "parser"
@ -88,7 +93,6 @@ def test_to_from_bytes(parser, blank_parser):
assert blank_parser.model is not True assert blank_parser.model is not True
assert blank_parser.moves.n_moves != parser.moves.n_moves assert blank_parser.moves.n_moves != parser.moves.n_moves
bytes_data = parser.to_bytes(exclude=["vocab"]) bytes_data = parser.to_bytes(exclude=["vocab"])
# the blank parser needs to be resized before we can call from_bytes # the blank parser needs to be resized before we can call from_bytes
blank_parser.model.attrs["resize_output"](blank_parser.model, parser.moves.n_moves) blank_parser.model.attrs["resize_output"](blank_parser.model, parser.moves.n_moves)
blank_parser.from_bytes(bytes_data) blank_parser.from_bytes(bytes_data)
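
The resize step above is easy to miss, so here is a hedged sketch of the full round-trip, reusing the fixture names from this file: a blank parser has no moves yet, and its model's output layer has to be resized to the source parser's number of moves before the weights can be restored.

bytes_data = parser.to_bytes(exclude=["vocab"])
# "resize_output" is an attr on the parser model that grows the output layer
resize_output = blank_parser.model.attrs["resize_output"]
resize_output(blank_parser.model, parser.moves.n_moves)
blank_parser.from_bytes(bytes_data)
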
@ -104,7 +108,8 @@ def test_serialize_tagger_roundtrip_bytes(en_vocab, taggers):
tagger1_b = tagger1.to_bytes() tagger1_b = tagger1.to_bytes()
tagger1 = tagger1.from_bytes(tagger1_b) tagger1 = tagger1.from_bytes(tagger1_b)
assert tagger1.to_bytes() == tagger1_b assert tagger1.to_bytes() == tagger1_b
model = registry.make_from_config({"model": DEFAULT_TAGGER_MODEL}, validate=True)["model"] cfg = {"model": DEFAULT_TAGGER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"]
new_tagger1 = Tagger(en_vocab, model).from_bytes(tagger1_b) new_tagger1 = Tagger(en_vocab, model).from_bytes(tagger1_b)
new_tagger1_b = new_tagger1.to_bytes() new_tagger1_b = new_tagger1.to_bytes()
assert len(new_tagger1_b) == len(tagger1_b) assert len(new_tagger1_b) == len(tagger1_b)
@ -118,7 +123,8 @@ def test_serialize_tagger_roundtrip_disk(en_vocab, taggers):
file_path2 = d / "tagger2" file_path2 = d / "tagger2"
tagger1.to_disk(file_path1) tagger1.to_disk(file_path1)
tagger2.to_disk(file_path2) tagger2.to_disk(file_path2)
model = registry.make_from_config({"model": DEFAULT_TAGGER_MODEL}, validate=True)["model"] cfg = {"model": DEFAULT_TAGGER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"]
tagger1_d = Tagger(en_vocab, model, set_morphology=True).from_disk(file_path1) tagger1_d = Tagger(en_vocab, model, set_morphology=True).from_disk(file_path1)
tagger2_d = Tagger(en_vocab, model, set_morphology=True).from_disk(file_path2) tagger2_d = Tagger(en_vocab, model, set_morphology=True).from_disk(file_path2)
assert tagger1_d.to_bytes() == tagger2_d.to_bytes() assert tagger1_d.to_bytes() == tagger2_d.to_bytes()
@ -126,21 +132,22 @@ def test_serialize_tagger_roundtrip_disk(en_vocab, taggers):
def test_serialize_textcat_empty(en_vocab): def test_serialize_textcat_empty(en_vocab):
# See issue #1105 # See issue #1105
model = registry.make_from_config({"model": DEFAULT_TEXTCAT_MODEL}, validate=True)["model"] cfg = {"model": DEFAULT_TEXTCAT_MODEL}
textcat = TextCategorizer( model = registry.make_from_config(cfg, validate=True)["model"]
en_vocab, model, labels=["ENTITY", "ACTION", "MODIFIER"] textcat = TextCategorizer(en_vocab, model, labels=["ENTITY", "ACTION", "MODIFIER"])
)
textcat.to_bytes(exclude=["vocab"]) textcat.to_bytes(exclude=["vocab"])
@pytest.mark.parametrize("Parser", test_parsers) @pytest.mark.parametrize("Parser", test_parsers)
def test_serialize_pipe_exclude(en_vocab, Parser): def test_serialize_pipe_exclude(en_vocab, Parser):
model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"] cfg = {"model": DEFAULT_PARSER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"]
config = { config = {
"learn_tokens": False, "learn_tokens": False,
"min_action_freq": 0, "min_action_freq": 0,
"update_with_oracle_cut_size": 100, "update_with_oracle_cut_size": 100,
} }
def get_new_parser(): def get_new_parser():
new_parser = Parser(en_vocab, model, **config) new_parser = Parser(en_vocab, model, **config)
return new_parser return new_parser
@ -160,7 +167,8 @@ def test_serialize_pipe_exclude(en_vocab, Parser):
def test_serialize_sentencerecognizer(en_vocab): def test_serialize_sentencerecognizer(en_vocab):
model = registry.make_from_config({"model": DEFAULT_SENTER_MODEL}, validate=True)["model"] cfg = {"model": DEFAULT_SENTER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"]
sr = SentenceRecognizer(en_vocab, model) sr = SentenceRecognizer(en_vocab, model)
sr_b = sr.to_bytes() sr_b = sr.to_bytes()
sr_d = SentenceRecognizer(en_vocab, model).from_bytes(sr_b) sr_d = SentenceRecognizer(en_vocab, model).from_bytes(sr_b)

View File

@ -466,7 +466,6 @@ def test_iob_to_biluo():
def test_roundtrip_docs_to_docbin(doc): def test_roundtrip_docs_to_docbin(doc):
nlp = English()
text = doc.text text = doc.text
idx = [t.idx for t in doc] idx = [t.idx for t in doc]
tags = [t.tag_ for t in doc] tags = [t.tag_ for t in doc]

View File

@ -7,11 +7,11 @@ from spacy.vocab import Vocab
def test_Example_init_requires_doc_objects(): def test_Example_init_requires_doc_objects():
vocab = Vocab() vocab = Vocab()
with pytest.raises(TypeError): with pytest.raises(TypeError):
example = Example(None, None) Example(None, None)
with pytest.raises(TypeError): with pytest.raises(TypeError):
example = Example(Doc(vocab, words=["hi"]), None) Example(Doc(vocab, words=["hi"]), None)
with pytest.raises(TypeError): with pytest.raises(TypeError):
example = Example(None, Doc(vocab, words=["hi"])) Example(None, Doc(vocab, words=["hi"]))
def test_Example_from_dict_basic(): def test_Example_from_dict_basic():
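
The TypeError cases above imply the valid construction; a small sketch, assuming the Doc, Example and Vocab imports at the top of this test file:

# Example wraps a predicted Doc and a reference (gold-standard) Doc
vocab = Vocab()
predicted = Doc(vocab, words=["hi"])
reference = Doc(vocab, words=["hi"])
example = Example(predicted, reference)  # both arguments must be Doc objects
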

View File

@ -105,7 +105,11 @@ def test_tokenization(sented_doc):
assert scores["token_acc"] == 1.0 assert scores["token_acc"] == 1.0
nlp = English() nlp = English()
example.predicted = Doc(nlp.vocab, words=["One", "sentence.", "Two", "sentences.", "Three", "sentences."], spaces=[True, True, True, True, True, False]) example.predicted = Doc(
nlp.vocab,
words=["One", "sentence.", "Two", "sentences.", "Three", "sentences."],
spaces=[True, True, True, True, True, False],
)
example.predicted[1].is_sent_start = False example.predicted[1].is_sent_start = False
scores = scorer.score([example]) scores = scorer.score([example])
assert scores["token_acc"] == approx(0.66666666) assert scores["token_acc"] == approx(0.66666666)

View File

@ -1,12 +1,13 @@
from typing import List, Union, Dict, Any, Optional, Iterable, Callable, Tuple from typing import List, Union, Dict, Any, Optional, Iterable, Callable, Tuple
from typing import Iterator, TYPE_CHECKING from typing import Iterator, Type, Pattern, Sequence, TYPE_CHECKING
from types import ModuleType
import os import os
import importlib import importlib
import importlib.util import importlib.util
import re import re
from pathlib import Path from pathlib import Path
import thinc import thinc
from thinc.api import NumpyOps, get_current_ops, Adam, Config from thinc.api import NumpyOps, get_current_ops, Adam, Config, Optimizer
import functools import functools
import itertools import itertools
import numpy.random import numpy.random
@ -49,6 +50,8 @@ from . import about
if TYPE_CHECKING: if TYPE_CHECKING:
# This lets us add type hints for mypy etc. without causing circular imports # This lets us add type hints for mypy etc. without causing circular imports
from .language import Language # noqa: F401 from .language import Language # noqa: F401
from .tokens import Doc, Span # noqa: F401
from .vocab import Vocab # noqa: F401
_PRINT_ENV = False _PRINT_ENV = False
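
The TYPE_CHECKING block added here is what allows the quoted annotations ("Language", "Doc", "Span", "Vocab") used in the signatures below without importing those classes at runtime; a general sketch of the pattern:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # only evaluated by type checkers, so no circular import at runtime
    from .language import Language  # noqa: F401

def get_lang_class(lang: str) -> "Language":
    # the string annotation is resolved lazily by mypy and IDEs
    ...
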
@ -102,12 +105,12 @@ class SimpleFrozenDict(dict):
raise NotImplementedError(self.error) raise NotImplementedError(self.error)
def set_env_log(value): def set_env_log(value: bool) -> None:
global _PRINT_ENV global _PRINT_ENV
_PRINT_ENV = value _PRINT_ENV = value
def lang_class_is_loaded(lang): def lang_class_is_loaded(lang: str) -> bool:
"""Check whether a Language class is already loaded. Language classes are """Check whether a Language class is already loaded. Language classes are
loaded lazily, to avoid expensive setup code associated with the language loaded lazily, to avoid expensive setup code associated with the language
data. data.
@ -118,7 +121,7 @@ def lang_class_is_loaded(lang):
return lang in registry.languages return lang in registry.languages
def get_lang_class(lang): def get_lang_class(lang: str) -> "Language":
"""Import and load a Language class. """Import and load a Language class.
lang (str): Two-letter language code, e.g. 'en'. lang (str): Two-letter language code, e.g. 'en'.
@ -136,7 +139,7 @@ def get_lang_class(lang):
return registry.languages.get(lang) return registry.languages.get(lang)
def set_lang_class(name, cls): def set_lang_class(name: str, cls: Type["Language"]) -> None:
"""Set a custom Language class name that can be loaded via get_lang_class. """Set a custom Language class name that can be loaded via get_lang_class.
name (str): Name of Language class. name (str): Name of Language class.
@ -145,10 +148,10 @@ def set_lang_class(name, cls):
registry.languages.register(name, func=cls) registry.languages.register(name, func=cls)
def ensure_path(path): def ensure_path(path: Any) -> Any:
"""Ensure string is converted to a Path. """Ensure string is converted to a Path.
path: Anything. If string, it's converted to Path. path (Any): Anything. If string, it's converted to Path.
RETURNS: Path or original argument. RETURNS: Path or original argument.
""" """
if isinstance(path, str): if isinstance(path, str):
@ -157,7 +160,7 @@ def ensure_path(path):
return path return path
def load_language_data(path): def load_language_data(path: Union[str, Path]) -> Union[dict, list]:
"""Load JSON language data using the given path as a base. If the provided """Load JSON language data using the given path as a base. If the provided
path isn't present, will attempt to load a gzipped version before giving up. path isn't present, will attempt to load a gzipped version before giving up.
@ -173,7 +176,12 @@ def load_language_data(path):
raise ValueError(Errors.E160.format(path=path)) raise ValueError(Errors.E160.format(path=path))
def get_module_path(module): def get_module_path(module: ModuleType) -> Path:
"""Get the path of a Python module.
module (ModuleType): The Python module.
RETURNS (Path): The path.
"""
if not hasattr(module, "__module__"): if not hasattr(module, "__module__"):
raise ValueError(Errors.E169.format(module=repr(module))) raise ValueError(Errors.E169.format(module=repr(module)))
return Path(sys.modules[module.__module__].__file__).parent return Path(sys.modules[module.__module__].__file__).parent
@ -183,7 +191,7 @@ def load_model(
name: Union[str, Path], name: Union[str, Path],
disable: Iterable[str] = tuple(), disable: Iterable[str] = tuple(),
component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(), component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(),
): ) -> "Language":
"""Load a model from a package or data path. """Load a model from a package or data path.
name (str): Package name or model path. name (str): Package name or model path.
@ -209,7 +217,7 @@ def load_model_from_package(
name: str, name: str,
disable: Iterable[str] = tuple(), disable: Iterable[str] = tuple(),
component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(), component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(),
): ) -> "Language":
"""Load a model from an installed package.""" """Load a model from an installed package."""
cls = importlib.import_module(name) cls = importlib.import_module(name)
return cls.load(disable=disable, component_cfg=component_cfg) return cls.load(disable=disable, component_cfg=component_cfg)
@ -220,7 +228,7 @@ def load_model_from_path(
meta: Optional[Dict[str, Any]] = None, meta: Optional[Dict[str, Any]] = None,
disable: Iterable[str] = tuple(), disable: Iterable[str] = tuple(),
component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(), component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(),
): ) -> "Language":
"""Load a model from a data directory path. Creates Language class with """Load a model from a data directory path. Creates Language class with
pipeline from config.cfg and then calls from_disk() with path.""" pipeline from config.cfg and then calls from_disk() with path."""
if not model_path.exists(): if not model_path.exists():
@ -269,7 +277,7 @@ def load_model_from_init_py(
init_file: Union[Path, str], init_file: Union[Path, str],
disable: Iterable[str] = tuple(), disable: Iterable[str] = tuple(),
component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(), component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(),
): ) -> "Language":
"""Helper function to use in the `load()` method of a model package's """Helper function to use in the `load()` method of a model package's
__init__.py. __init__.py.
@ -288,15 +296,15 @@ def load_model_from_init_py(
) )
def get_installed_models(): def get_installed_models() -> List[str]:
"""List all model packages currently installed in the environment. """List all model packages currently installed in the environment.
RETURNS (list): The string names of the models. RETURNS (List[str]): The string names of the models.
""" """
return list(registry.models.get_all().keys()) return list(registry.models.get_all().keys())
def get_package_version(name): def get_package_version(name: str) -> Optional[str]:
"""Get the version of an installed package. Typically used to get model """Get the version of an installed package. Typically used to get model
package versions. package versions.
@ -309,7 +317,9 @@ def get_package_version(name):
return None return None
def is_compatible_version(version, constraint, prereleases=True): def is_compatible_version(
version: str, constraint: str, prereleases: bool = True
) -> Optional[bool]:
"""Check if a version (e.g. "2.0.0") is compatible given a version """Check if a version (e.g. "2.0.0") is compatible given a version
constraint (e.g. ">=1.9.0,<2.2.1"). If the constraint is a specific version, constraint (e.g. ">=1.9.0,<2.2.1"). If the constraint is a specific version,
it's interpreted as =={version}. it's interpreted as =={version}.
@ -333,7 +343,9 @@ def is_compatible_version(version, constraint, prereleases=True):
return version in spec return version in spec
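
A usage sketch for the newly typed helper; the expected results follow from the docstring (a bare version constraint is read as ==version) and are assumptions rather than quoted output:

is_compatible_version("2.0.0", ">=1.9.0,<2.2.1")   # True
is_compatible_version("2.3.0", ">=1.9.0,<2.2.1")   # False
is_compatible_version("2.0.0", "2.0.0")            # interpreted as ==2.0.0 -> True
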
def is_unconstrained_version(constraint, prereleases=True): def is_unconstrained_version(
constraint: str, prereleases: bool = True
) -> Optional[bool]:
# We have an exact version, this is the ultimate constrained version # We have an exact version, this is the ultimate constrained version
if constraint[0].isdigit(): if constraint[0].isdigit():
return False return False
@ -358,7 +370,7 @@ def is_unconstrained_version(constraint, prereleases=True):
return True return True
def get_model_version_range(spacy_version): def get_model_version_range(spacy_version: str) -> str:
"""Generate a version range like >=1.2.3,<1.3.0 based on a given spaCy """Generate a version range like >=1.2.3,<1.3.0 based on a given spaCy
version. Models are always compatible across patch versions but not version. Models are always compatible across patch versions but not
across minor or major versions. across minor or major versions.
@ -367,7 +379,7 @@ def get_model_version_range(spacy_version):
return f">={spacy_version},<{release[0]}.{release[1] + 1}.0" return f">={spacy_version},<{release[0]}.{release[1] + 1}.0"
def get_base_version(version): def get_base_version(version: str) -> str:
"""Generate the base version without any prerelease identifiers. """Generate the base version without any prerelease identifiers.
version (str): The version, e.g. "3.0.0.dev1". version (str): The version, e.g. "3.0.0.dev1".
@ -376,11 +388,11 @@ def get_base_version(version):
return Version(version).base_version return Version(version).base_version
def get_model_meta(path): def get_model_meta(path: Union[str, Path]) -> Dict[str, Any]:
"""Get model meta.json from a directory path and validate its contents. """Get model meta.json from a directory path and validate its contents.
path (str / Path): Path to model directory. path (str / Path): Path to model directory.
RETURNS (dict): The model's meta data. RETURNS (Dict[str, Any]): The model's meta data.
""" """
model_path = ensure_path(path) model_path = ensure_path(path)
if not model_path.exists(): if not model_path.exists():
@ -412,7 +424,7 @@ def get_model_meta(path):
return meta return meta
def is_package(name): def is_package(name: str) -> bool:
"""Check if string maps to a package installed via pip. """Check if string maps to a package installed via pip.
name (str): Name of package. name (str): Name of package.
@ -425,7 +437,7 @@ def is_package(name):
return False return False
def get_package_path(name): def get_package_path(name: str) -> Path:
"""Get the path to an installed package. """Get the path to an installed package.
name (str): Package name. name (str): Package name.
@ -495,7 +507,7 @@ def working_dir(path: Union[str, Path]) -> None:
@contextmanager @contextmanager
def make_tempdir(): def make_tempdir() -> None:
"""Execute a block in a temporary directory and remove the directory and """Execute a block in a temporary directory and remove the directory and
its contents at the end of the with block. its contents at the end of the with block.
@ -518,7 +530,7 @@ def is_cwd(path: Union[Path, str]) -> bool:
return str(Path(path).resolve()).lower() == str(Path.cwd().resolve()).lower() return str(Path(path).resolve()).lower() == str(Path.cwd().resolve()).lower()
def is_in_jupyter(): def is_in_jupyter() -> bool:
"""Check if user is running spaCy from a Jupyter notebook by detecting the """Check if user is running spaCy from a Jupyter notebook by detecting the
IPython kernel. Mainly used for the displaCy visualizer. IPython kernel. Mainly used for the displaCy visualizer.
RETURNS (bool): True if in Jupyter, False if not. RETURNS (bool): True if in Jupyter, False if not.
@ -548,7 +560,9 @@ def get_object_name(obj: Any) -> str:
return repr(obj) return repr(obj)
def get_cuda_stream(require=False, non_blocking=True): def get_cuda_stream(
require: bool = False, non_blocking: bool = True
) -> Optional[CudaStream]:
ops = get_current_ops() ops = get_current_ops()
if CudaStream is None: if CudaStream is None:
return None return None
@ -567,7 +581,7 @@ def get_async(stream, numpy_array):
return array return array
def env_opt(name, default=None): def env_opt(name: str, default: Optional[Any] = None) -> Optional[Any]:
if type(default) is float: if type(default) is float:
type_convert = float type_convert = float
else: else:
@ -588,7 +602,7 @@ def env_opt(name, default=None):
return default return default
def read_regex(path): def read_regex(path: Union[str, Path]) -> Pattern:
path = ensure_path(path) path = ensure_path(path)
with path.open(encoding="utf8") as file_: with path.open(encoding="utf8") as file_:
entries = file_.read().split("\n") entries = file_.read().split("\n")
@ -598,37 +612,40 @@ def read_regex(path):
return re.compile(expression) return re.compile(expression)
def compile_prefix_regex(entries): def compile_prefix_regex(entries: Iterable[Union[str, Pattern]]) -> Pattern:
"""Compile a sequence of prefix rules into a regex object. """Compile a sequence of prefix rules into a regex object.
entries (tuple): The prefix rules, e.g. spacy.lang.punctuation.TOKENIZER_PREFIXES. entries (Iterable[Union[str, Pattern]]): The prefix rules, e.g.
RETURNS (regex object): The regex object, to be used for Tokenizer.prefix_search. spacy.lang.punctuation.TOKENIZER_PREFIXES.
RETURNS (Pattern): The regex object, to be used for Tokenizer.prefix_search.
""" """
expression = "|".join(["^" + piece for piece in entries if piece.strip()]) expression = "|".join(["^" + piece for piece in entries if piece.strip()])
return re.compile(expression) return re.compile(expression)
def compile_suffix_regex(entries): def compile_suffix_regex(entries: Iterable[Union[str, Pattern]]) -> Pattern:
"""Compile a sequence of suffix rules into a regex object. """Compile a sequence of suffix rules into a regex object.
entries (tuple): The suffix rules, e.g. spacy.lang.punctuation.TOKENIZER_SUFFIXES. entries (Iterable[Union[str, Pattern]]): The suffix rules, e.g.
RETURNS (regex object): The regex object, to be used for Tokenizer.suffix_search. spacy.lang.punctuation.TOKENIZER_SUFFIXES.
RETURNS (Pattern): The regex object, to be used for Tokenizer.suffix_search.
""" """
expression = "|".join([piece + "$" for piece in entries if piece.strip()]) expression = "|".join([piece + "$" for piece in entries if piece.strip()])
return re.compile(expression) return re.compile(expression)
def compile_infix_regex(entries): def compile_infix_regex(entries: Iterable[Union[str, Pattern]]) -> Pattern:
"""Compile a sequence of infix rules into a regex object. """Compile a sequence of infix rules into a regex object.
entries (tuple): The infix rules, e.g. spacy.lang.punctuation.TOKENIZER_INFIXES. entries (Iterable[Union[str, Pattern]]): The infix rules, e.g.
spacy.lang.punctuation.TOKENIZER_INFIXES.
RETURNS (regex object): The regex object, to be used for Tokenizer.infix_finditer. RETURNS (regex object): The regex object, to be used for Tokenizer.infix_finditer.
""" """
expression = "|".join([piece for piece in entries if piece.strip()]) expression = "|".join([piece for piece in entries if piece.strip()])
return re.compile(expression) return re.compile(expression)
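
The three compiled Pattern objects are what a custom Tokenizer consumes; a minimal sketch using the helpers defined in this module, with the rule lists from spacy.lang.punctuation named in the docstrings (the English import is only there to provide a vocab):

from spacy.lang.en import English
from spacy.lang.punctuation import TOKENIZER_PREFIXES, TOKENIZER_SUFFIXES, TOKENIZER_INFIXES
from spacy.tokenizer import Tokenizer
from spacy.util import compile_prefix_regex, compile_suffix_regex, compile_infix_regex

nlp = English()
prefix_re = compile_prefix_regex(TOKENIZER_PREFIXES)
suffix_re = compile_suffix_regex(TOKENIZER_SUFFIXES)
infix_re = compile_infix_regex(TOKENIZER_INFIXES)
tokenizer = Tokenizer(
    nlp.vocab,
    prefix_search=prefix_re.search,     # Pattern.search, per the docstrings above
    suffix_search=suffix_re.search,
    infix_finditer=infix_re.finditer,
)
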
def add_lookups(default_func, *lookups): def add_lookups(default_func: Callable[[str], Any], *lookups) -> Callable[[str], Any]:
"""Extend an attribute function with special cases. If a word is in the """Extend an attribute function with special cases. If a word is in the
lookups, the value is returned. Otherwise the previous function is used. lookups, the value is returned. Otherwise the previous function is used.
@ -641,19 +658,23 @@ def add_lookups(default_func, *lookups):
return functools.partial(_get_attr_unless_lookup, default_func, lookups) return functools.partial(_get_attr_unless_lookup, default_func, lookups)
def _get_attr_unless_lookup(default_func, lookups, string): def _get_attr_unless_lookup(
default_func: Callable[[str], Any], lookups: Dict[str, Any], string: str
) -> Any:
for lookup in lookups: for lookup in lookups:
if string in lookup: if string in lookup:
return lookup[string] return lookup[string]
return default_func(string) return default_func(string)
def update_exc(base_exceptions, *addition_dicts): def update_exc(
base_exceptions: Dict[str, List[dict]], *addition_dicts
) -> Dict[str, List[dict]]:
"""Update and validate tokenizer exceptions. Will overwrite exceptions. """Update and validate tokenizer exceptions. Will overwrite exceptions.
base_exceptions (dict): Base exceptions. base_exceptions (Dict[str, List[dict]]): Base exceptions.
*addition_dicts (dict): Exceptions to add to the base dict, in order. *addition_dicts (Dict[str, List[dict]]): Exceptions to add to the base dict, in order.
RETURNS (dict): Combined tokenizer exceptions. RETURNS (Dict[str, List[dict]]): Combined tokenizer exceptions.
""" """
exc = dict(base_exceptions) exc = dict(base_exceptions)
for additions in addition_dicts: for additions in addition_dicts:
@ -668,14 +689,16 @@ def update_exc(base_exceptions, *addition_dicts):
return exc return exc
def expand_exc(excs, search, replace): def expand_exc(
excs: Dict[str, List[dict]], search: str, replace: str
) -> Dict[str, List[dict]]:
"""Find string in tokenizer exceptions, duplicate entry and replace string. """Find string in tokenizer exceptions, duplicate entry and replace string.
For example, to add additional versions with typographic apostrophes. For example, to add additional versions with typographic apostrophes.
excs (dict): Tokenizer exceptions. excs (Dict[str, List[dict]]): Tokenizer exceptions.
search (str): String to find and replace. search (str): String to find and replace.
replace (str): Replacement. replace (str): Replacement.
RETURNS (dict): Combined tokenizer exceptions. RETURNS (Dict[str, List[dict]]): Combined tokenizer exceptions.
""" """
def _fix_token(token, search, replace): def _fix_token(token, search, replace):
@ -692,7 +715,9 @@ def expand_exc(excs, search, replace):
return new_excs return new_excs
def normalize_slice(length, start, stop, step=None): def normalize_slice(
length: int, start: int, stop: int, step: Optional[int] = None
) -> Tuple[int, int]:
if not (step is None or step == 1): if not (step is None or step == 1):
raise ValueError(Errors.E057) raise ValueError(Errors.E057)
if start is None: if start is None:
@ -708,7 +733,9 @@ def normalize_slice(length, start, stop, step=None):
return start, stop return start, stop
def minibatch(items, size=8): def minibatch(
items: Iterable[Any], size: Union[Iterator[int], int] = 8
) -> Iterator[Any]:
"""Iterate over batches of items. `size` may be an iterator, """Iterate over batches of items. `size` may be an iterator,
so that batch-size can vary on each step. so that batch-size can vary on each step.
""" """
@ -725,7 +752,12 @@ def minibatch(items, size=8):
yield list(batch) yield list(batch)
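
Because size may be an iterator, the batch size can change on every step; a short sketch using thinc's compounding schedule as one possible size iterator (train_examples is a placeholder name):

from thinc.api import compounding

# batch size starts at 4 and compounds towards 32 on each next() call
for batch in minibatch(train_examples, size=compounding(4.0, 32.0, 1.001)):
    ...
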
def minibatch_by_padded_size(docs, size, buffer=256, discard_oversize=False): def minibatch_by_padded_size(
docs: Iterator["Doc"],
size: Union[Iterator[int], int],
buffer: int = 256,
discard_oversize: bool = False,
) -> Iterator[Iterator["Doc"]]:
if isinstance(size, int): if isinstance(size, int):
size_ = itertools.repeat(size) size_ = itertools.repeat(size)
else: else:
@ -742,7 +774,7 @@ def minibatch_by_padded_size(docs, size, buffer=256, discard_oversize=False):
yield subbatch yield subbatch
def _batch_by_length(seqs, max_words): def _batch_by_length(seqs: Sequence[Any], max_words: int) -> List[List[Any]]:
"""Given a list of sequences, return a batched list of indices into the """Given a list of sequences, return a batched list of indices into the
list, where the batches are grouped by length, in descending order. list, where the batches are grouped by length, in descending order.
@ -783,14 +815,12 @@ def minibatch_by_words(docs, size, tolerance=0.2, discard_oversize=False):
size_ = iter(size) size_ = iter(size)
else: else:
size_ = size size_ = size
target_size = next(size_) target_size = next(size_)
tol_size = target_size * tolerance tol_size = target_size * tolerance
batch = [] batch = []
overflow = [] overflow = []
batch_size = 0 batch_size = 0
overflow_size = 0 overflow_size = 0
for doc in docs: for doc in docs:
if isinstance(doc, Example): if isinstance(doc, Example):
n_words = len(doc.reference) n_words = len(doc.reference)
@ -803,17 +833,14 @@ def minibatch_by_words(docs, size, tolerance=0.2, discard_oversize=False):
if n_words > target_size + tol_size: if n_words > target_size + tol_size:
if not discard_oversize: if not discard_oversize:
yield [doc] yield [doc]
# add the example to the current batch if there's no overflow yet and it still fits # add the example to the current batch if there's no overflow yet and it still fits
elif overflow_size == 0 and (batch_size + n_words) <= target_size: elif overflow_size == 0 and (batch_size + n_words) <= target_size:
batch.append(doc) batch.append(doc)
batch_size += n_words batch_size += n_words
# add the example to the overflow buffer if it fits in the tolerance margin # add the example to the overflow buffer if it fits in the tolerance margin
elif (batch_size + overflow_size + n_words) <= (target_size + tol_size): elif (batch_size + overflow_size + n_words) <= (target_size + tol_size):
overflow.append(doc) overflow.append(doc)
overflow_size += n_words overflow_size += n_words
# yield the previous batch and start a new one. The new one gets the overflow examples. # yield the previous batch and start a new one. The new one gets the overflow examples.
else: else:
if batch: if batch:
@ -824,17 +851,14 @@ def minibatch_by_words(docs, size, tolerance=0.2, discard_oversize=False):
batch_size = overflow_size batch_size = overflow_size
overflow = [] overflow = []
overflow_size = 0 overflow_size = 0
# this example still fits # this example still fits
if (batch_size + n_words) <= target_size: if (batch_size + n_words) <= target_size:
batch.append(doc) batch.append(doc)
batch_size += n_words batch_size += n_words
# this example fits in overflow # this example fits in overflow
elif (batch_size + n_words) <= (target_size + tol_size): elif (batch_size + n_words) <= (target_size + tol_size):
overflow.append(doc) overflow.append(doc)
overflow_size += n_words overflow_size += n_words
# this example does not fit with the previous overflow: start another new batch # this example does not fit with the previous overflow: start another new batch
else: else:
if batch: if batch:
@ -843,20 +867,19 @@ def minibatch_by_words(docs, size, tolerance=0.2, discard_oversize=False):
tol_size = target_size * tolerance tol_size = target_size * tolerance
batch = [doc] batch = [doc]
batch_size = n_words batch_size = n_words
batch.extend(overflow) batch.extend(overflow)
if batch: if batch:
yield batch yield batch
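
To make the tolerance logic concrete: with a target of 1000 words and tolerance 0.2, a batch can fill up to roughly 1200 words before it is yielded, and a single doc longer than that is yielded on its own unless discard_oversize is set. A hedged usage sketch (nlp and train_examples are placeholder names):

for batch in minibatch_by_words(train_examples, size=1000, tolerance=0.2):
    # each batch's reference docs sum to at most ~1200 words
    nlp.update(batch)
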
def filter_spans(spans): def filter_spans(spans: Iterable["Span"]) -> List["Span"]:
"""Filter a sequence of spans and remove duplicates or overlaps. Useful for """Filter a sequence of spans and remove duplicates or overlaps. Useful for
creating named entities (where one token can only be part of one entity) or creating named entities (where one token can only be part of one entity) or
when merging spans with `Retokenizer.merge`. When spans overlap, the (first) when merging spans with `Retokenizer.merge`. When spans overlap, the (first)
longest span is preferred over shorter spans. longest span is preferred over shorter spans.
spans (iterable): The spans to filter. spans (Iterable[Span]): The spans to filter.
RETURNS (list): The filtered spans. RETURNS (List[Span]): The filtered spans.
""" """
get_sort_key = lambda span: (span.end - span.start, -span.start) get_sort_key = lambda span: (span.end - span.start, -span.start)
sorted_spans = sorted(spans, key=get_sort_key, reverse=True) sorted_spans = sorted(spans, key=get_sort_key, reverse=True)
@ -871,15 +894,21 @@ def filter_spans(spans):
return result return result
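
A usage sketch for the typed filter_spans, showing the overlap resolution described in the docstring (the longest span wins, with the earlier start breaking ties); doc is assumed to be any spaCy Doc with at least six tokens:

spans = [doc[1:4], doc[2:5], doc[5:6]]
filtered = filter_spans(spans)
# -> [doc[1:4], doc[5:6]]  (doc[2:5] overlaps the earlier, equally long span)
with doc.retokenize() as retokenizer:
    for span in filtered:
        retokenizer.merge(span)
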
def to_bytes(getters, exclude): def to_bytes(getters: Dict[str, Callable[[], bytes]], exclude: Iterable[str]) -> bytes:
return srsly.msgpack_dumps(to_dict(getters, exclude)) return srsly.msgpack_dumps(to_dict(getters, exclude))
def from_bytes(bytes_data, setters, exclude): def from_bytes(
bytes_data: bytes,
setters: Dict[str, Callable[[bytes], Any]],
exclude: Iterable[str],
) -> None:
return from_dict(srsly.msgpack_loads(bytes_data), setters, exclude) return from_dict(srsly.msgpack_loads(bytes_data), setters, exclude)
def to_dict(getters, exclude): def to_dict(
getters: Dict[str, Callable[[], Any]], exclude: Iterable[str]
) -> Dict[str, Any]:
serialized = {} serialized = {}
for key, getter in getters.items(): for key, getter in getters.items():
# Split to support file names like meta.json # Split to support file names like meta.json
@ -888,7 +917,11 @@ def to_dict(getters, exclude):
return serialized return serialized
def from_dict(msg, setters, exclude): def from_dict(
msg: Dict[str, Any],
setters: Dict[str, Callable[[Any], Any]],
exclude: Iterable[str],
) -> Dict[str, Any]:
for key, setter in setters.items(): for key, setter in setters.items():
# Split to support file names like meta.json # Split to support file names like meta.json
if key.split(".")[0] not in exclude and key in msg: if key.split(".")[0] not in exclude and key in msg:
@ -896,7 +929,11 @@ def from_dict(msg, setters, exclude):
return msg return msg
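
As the new type hints spell out, the getters map names to zero-argument callables returning bytes and the setters map names to one-argument callables consuming bytes. A hedged sketch of how a pipeline component might plug into these helpers (the two wrapper functions are hypothetical names used only for illustration):

import srsly

def pipe_to_bytes(pipe, exclude=tuple()):
    getters = {
        "cfg": lambda: srsly.json_dumps(pipe.cfg).encode("utf8"),
        "model": lambda: pipe.model.to_bytes(),
    }
    return to_bytes(getters, exclude)

def pipe_from_bytes(pipe, bytes_data, exclude=tuple()):
    setters = {
        "cfg": lambda b: pipe.cfg.update(srsly.json_loads(b)),
        "model": lambda b: pipe.model.from_bytes(b),
    }
    return from_bytes(bytes_data, setters, exclude)
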
def to_disk(path, writers, exclude): def to_disk(
path: Union[str, Path],
writers: Dict[str, Callable[[Path], None]],
exclude: Iterable[str],
) -> Path:
path = ensure_path(path) path = ensure_path(path)
if not path.exists(): if not path.exists():
path.mkdir() path.mkdir()
@ -907,7 +944,11 @@ def to_disk(path, writers, exclude):
return path return path
def from_disk(path, readers, exclude): def from_disk(
path: Union[str, Path],
readers: Dict[str, Callable[[Path], None]],
exclude: Iterable[str],
) -> Path:
path = ensure_path(path) path = ensure_path(path)
for key, reader in readers.items(): for key, reader in readers.items():
# Split to support file names like meta.json # Split to support file names like meta.json
@ -916,7 +957,7 @@ def from_disk(path, readers, exclude):
return path return path
def import_file(name, loc): def import_file(name: str, loc: Union[str, Path]) -> ModuleType:
"""Import module from a file. Used to load models from a directory. """Import module from a file. Used to load models from a directory.
name (str): Name of module to load. name (str): Name of module to load.
@ -930,7 +971,7 @@ def import_file(name, loc):
return module return module
def minify_html(html): def minify_html(html: str) -> str:
"""Perform a template-specific, rudimentary HTML minification for displaCy. """Perform a template-specific, rudimentary HTML minification for displaCy.
Disclaimer: NOT a general-purpose solution, only removes indentation and Disclaimer: NOT a general-purpose solution, only removes indentation and
newlines. newlines.
@ -941,7 +982,7 @@ def minify_html(html):
return html.strip().replace(" ", "").replace("\n", "") return html.strip().replace(" ", "").replace("\n", "")
def escape_html(text): def escape_html(text: str) -> str:
"""Replace <, >, &, " with their HTML encoded representation. Intended to """Replace <, >, &, " with their HTML encoded representation. Intended to
prevent HTML errors in rendered displaCy markup. prevent HTML errors in rendered displaCy markup.
@ -955,7 +996,9 @@ def escape_html(text):
return text return text
def get_words_and_spaces(words, text): def get_words_and_spaces(
words: Iterable[str], text: str
) -> Tuple[List[str], List[bool]]:
if "".join("".join(words).split()) != "".join(text.split()): if "".join("".join(words).split()) != "".join(text.split()):
raise ValueError(Errors.E194.format(text=text, words=words)) raise ValueError(Errors.E194.format(text=text, words=words))
text_words = [] text_words = []
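
A usage sketch for the newly typed helper: given pre-tokenized words and the original text, it returns the words plus the trailing-space flags needed to rebuild a Doc (the shown outputs are assumptions based on the function's behaviour, and vocab is a placeholder):

words, spaces = get_words_and_spaces(["Hello", "world", "!"], "Hello world!")
# words  -> ["Hello", "world", "!"]
# spaces -> [True, False, False]
doc = Doc(vocab, words=words, spaces=spaces)
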
@ -1103,7 +1146,7 @@ class DummyTokenizer:
return self return self
def link_vectors_to_models(vocab): def link_vectors_to_models(vocab: "Vocab") -> None:
vectors = vocab.vectors vectors = vocab.vectors
if vectors.name is None: if vectors.name is None:
vectors.name = VECTORS_KEY vectors.name = VECTORS_KEY
@ -1119,7 +1162,7 @@ def link_vectors_to_models(vocab):
VECTORS_KEY = "spacy_pretrained_vectors" VECTORS_KEY = "spacy_pretrained_vectors"
def create_default_optimizer(): def create_default_optimizer() -> Optimizer:
learn_rate = env_opt("learn_rate", 0.001) learn_rate = env_opt("learn_rate", 0.001)
beta1 = env_opt("optimizer_B1", 0.9) beta1 = env_opt("optimizer_B1", 0.9)
beta2 = env_opt("optimizer_B2", 0.999) beta2 = env_opt("optimizer_B2", 0.999)