mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-27 01:34:30 +03:00
Tidy up, autoformat, add types
This commit is contained in:
parent
71242327b2
commit
e92df281ce
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -51,6 +51,7 @@ env3.*/
|
|||
.denv
|
||||
.pypyenv
|
||||
.pytest_cache/
|
||||
.mypy_cache/
|
||||
|
||||
# Distribution / packaging
|
||||
env/
|
||||
|
|
|
@ -108,3 +108,8 @@ exclude =
|
|||
[tool:pytest]
|
||||
markers =
|
||||
slow
|
||||
|
||||
[mypy]
|
||||
ignore_missing_imports = True
|
||||
no_implicit_optional = True
|
||||
plugins = pydantic.mypy, thinc.mypy
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Optional
|
||||
from typing import Optional, Any, List, Union
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from wasabi import Printer
|
||||
|
@ -66,10 +66,9 @@ def convert_cli(
|
|||
file_type = file_type.value
|
||||
input_path = Path(input_path)
|
||||
output_dir = "-" if output_dir == Path("-") else output_dir
|
||||
cli_args = locals()
|
||||
silent = output_dir == "-"
|
||||
msg = Printer(no_print=silent)
|
||||
verify_cli_args(msg, **cli_args)
|
||||
verify_cli_args(msg, input_path, output_dir, file_type, converter, ner_map)
|
||||
converter = _get_converter(msg, converter, input_path)
|
||||
convert(
|
||||
input_path,
|
||||
|
@ -89,8 +88,8 @@ def convert_cli(
|
|||
|
||||
|
||||
def convert(
|
||||
input_path: Path,
|
||||
output_dir: Path,
|
||||
input_path: Union[str, Path],
|
||||
output_dir: Union[str, Path],
|
||||
*,
|
||||
file_type: str = "json",
|
||||
n_sents: int = 1,
|
||||
|
@ -102,13 +101,12 @@ def convert(
|
|||
ner_map: Optional[Path] = None,
|
||||
lang: Optional[str] = None,
|
||||
silent: bool = True,
|
||||
msg: Optional[Path] = None,
|
||||
msg: Optional[Printer],
|
||||
) -> None:
|
||||
if not msg:
|
||||
msg = Printer(no_print=silent)
|
||||
ner_map = srsly.read_json(ner_map) if ner_map is not None else None
|
||||
|
||||
for input_loc in walk_directory(input_path):
|
||||
for input_loc in walk_directory(Path(input_path)):
|
||||
input_data = input_loc.open("r", encoding="utf-8").read()
|
||||
# Use converter function to convert data
|
||||
func = CONVERTERS[converter]
|
||||
|
@ -140,14 +138,14 @@ def convert(
|
|||
msg.good(f"Generated output file ({len(docs)} documents): {output_file}")
|
||||
|
||||
|
||||
def _print_docs_to_stdout(data, output_type):
|
||||
def _print_docs_to_stdout(data: Any, output_type: str) -> None:
|
||||
if output_type == "json":
|
||||
srsly.write_json("-", data)
|
||||
else:
|
||||
sys.stdout.buffer.write(data)
|
||||
|
||||
|
||||
def _write_docs_to_file(data, output_file, output_type):
|
||||
def _write_docs_to_file(data: Any, output_file: Path, output_type: str) -> None:
|
||||
if not output_file.parent.exists():
|
||||
output_file.parent.mkdir(parents=True)
|
||||
if output_type == "json":
|
||||
|
@ -157,7 +155,7 @@ def _write_docs_to_file(data, output_file, output_type):
|
|||
file_.write(data)
|
||||
|
||||
|
||||
def autodetect_ner_format(input_data: str) -> str:
|
||||
def autodetect_ner_format(input_data: str) -> Optional[str]:
|
||||
# guess format from the first 20 lines
|
||||
lines = input_data.split("\n")[:20]
|
||||
format_guesses = {"ner": 0, "iob": 0}
|
||||
|
@ -176,7 +174,7 @@ def autodetect_ner_format(input_data: str) -> str:
|
|||
return None
|
||||
|
||||
|
||||
def walk_directory(path):
|
||||
def walk_directory(path: Path) -> List[Path]:
|
||||
if not path.is_dir():
|
||||
return [path]
|
||||
paths = [path]
|
||||
|
@ -196,31 +194,25 @@ def walk_directory(path):
|
|||
|
||||
|
||||
def verify_cli_args(
|
||||
msg,
|
||||
input_path,
|
||||
output_dir,
|
||||
file_type,
|
||||
n_sents,
|
||||
seg_sents,
|
||||
model,
|
||||
morphology,
|
||||
merge_subtokens,
|
||||
converter,
|
||||
ner_map,
|
||||
lang,
|
||||
msg: Printer,
|
||||
input_path: Union[str, Path],
|
||||
output_dir: Union[str, Path],
|
||||
file_type: FileTypes,
|
||||
converter: str,
|
||||
ner_map: Optional[Path],
|
||||
):
|
||||
input_path = Path(input_path)
|
||||
if file_type not in FILE_TYPES_STDOUT and output_dir == "-":
|
||||
# TODO: support msgpack via stdout in srsly?
|
||||
msg.fail(
|
||||
f"Can't write .{file_type} data to stdout",
|
||||
"Please specify an output directory.",
|
||||
f"Can't write .{file_type} data to stdout. Please specify an output directory.",
|
||||
exits=1,
|
||||
)
|
||||
if not input_path.exists():
|
||||
msg.fail("Input file not found", input_path, exits=1)
|
||||
if output_dir != "-" and not Path(output_dir).exists():
|
||||
msg.fail("Output directory not found", output_dir, exits=1)
|
||||
if ner_map is not None and not Path(ner_map).exists():
|
||||
msg.fail("NER map not found", ner_map, exits=1)
|
||||
if input_path.is_dir():
|
||||
input_locs = walk_directory(input_path)
|
||||
if len(input_locs) == 0:
|
||||
|
@ -229,10 +221,8 @@ def verify_cli_args(
|
|||
if len(file_types) >= 2:
|
||||
file_types = ",".join(file_types)
|
||||
msg.fail("All input files must be same type", file_types, exits=1)
|
||||
converter = _get_converter(msg, converter, input_path)
|
||||
if converter not in CONVERTERS:
|
||||
if converter != "auto" and converter not in CONVERTERS:
|
||||
msg.fail(f"Can't find converter for {converter}", exits=1)
|
||||
return converter
|
||||
|
||||
|
||||
def _get_converter(msg, converter, input_path):
|
||||
|
|
|
@ -82,11 +82,11 @@ def evaluate(
|
|||
"NER P": "ents_p",
|
||||
"NER R": "ents_r",
|
||||
"NER F": "ents_f",
|
||||
"Textcat AUC": 'textcat_macro_auc',
|
||||
"Textcat F": 'textcat_macro_f',
|
||||
"Sent P": 'sents_p',
|
||||
"Sent R": 'sents_r',
|
||||
"Sent F": 'sents_f',
|
||||
"Textcat AUC": "textcat_macro_auc",
|
||||
"Textcat F": "textcat_macro_f",
|
||||
"Sent P": "sents_p",
|
||||
"Sent R": "sents_r",
|
||||
"Sent F": "sents_f",
|
||||
}
|
||||
results = {}
|
||||
for metric, key in metrics.items():
|
||||
|
|
|
@ -14,7 +14,6 @@ import typer
|
|||
from ._util import app, Arg, Opt, parse_config_overrides, show_validation_error
|
||||
from ._util import import_code
|
||||
from ..gold import Corpus, Example
|
||||
from ..lookups import Lookups
|
||||
from ..language import Language
|
||||
from .. import util
|
||||
from ..errors import Errors
|
||||
|
|
|
@ -1,12 +1,5 @@
|
|||
"""
|
||||
Helpers for Python and platform compatibility. To distinguish them from
|
||||
the builtin functions, replacement functions are suffixed with an underscore,
|
||||
e.g. `unicode_`.
|
||||
|
||||
DOCS: https://spacy.io/api/top-level#compat
|
||||
"""
|
||||
"""Helpers for Python and platform compatibility."""
|
||||
import sys
|
||||
|
||||
from thinc.util import copy_array
|
||||
|
||||
try:
|
||||
|
@ -40,21 +33,3 @@ copy_array = copy_array
|
|||
is_windows = sys.platform.startswith("win")
|
||||
is_linux = sys.platform.startswith("linux")
|
||||
is_osx = sys.platform == "darwin"
|
||||
|
||||
|
||||
def is_config(windows=None, linux=None, osx=None, **kwargs):
|
||||
"""Check if a specific configuration of Python version and operating system
|
||||
matches the user's setup. Mostly used to display targeted error messages.
|
||||
|
||||
windows (bool): spaCy is executed on Windows.
|
||||
linux (bool): spaCy is executed on Linux.
|
||||
osx (bool): spaCy is executed on OS X or macOS.
|
||||
RETURNS (bool): Whether the configuration matches the user's platform.
|
||||
|
||||
DOCS: https://spacy.io/api/top-level#compat.is_config
|
||||
"""
|
||||
return (
|
||||
windows in (None, is_windows)
|
||||
and linux in (None, is_linux)
|
||||
and osx in (None, is_osx)
|
||||
)
|
||||
|
|
|
@ -4,6 +4,7 @@ spaCy's built in visualization suite for dependencies and named entities.
|
|||
DOCS: https://spacy.io/api/top-level#displacy
|
||||
USAGE: https://spacy.io/usage/visualizers
|
||||
"""
|
||||
from typing import Union, Iterable, Optional, Dict, Any, Callable
|
||||
import warnings
|
||||
|
||||
from .render import DependencyRenderer, EntityRenderer
|
||||
|
@ -17,11 +18,17 @@ RENDER_WRAPPER = None
|
|||
|
||||
|
||||
def render(
|
||||
docs, style="dep", page=False, minify=False, jupyter=None, options={}, manual=False
|
||||
):
|
||||
docs: Union[Iterable[Doc], Doc],
|
||||
style: str = "dep",
|
||||
page: bool = False,
|
||||
minify: bool = False,
|
||||
jupyter: Optional[bool] = None,
|
||||
options: Dict[str, Any] = {},
|
||||
manual: bool = False,
|
||||
) -> str:
|
||||
"""Render displaCy visualisation.
|
||||
|
||||
docs (list or Doc): Document(s) to visualise.
|
||||
docs (Union[Iterable[Doc], Doc]): Document(s) to visualise.
|
||||
style (str): Visualisation style, 'dep' or 'ent'.
|
||||
page (bool): Render markup as full HTML page.
|
||||
minify (bool): Minify HTML markup.
|
||||
|
@ -44,8 +51,8 @@ def render(
|
|||
docs = [obj if not isinstance(obj, Span) else obj.as_doc() for obj in docs]
|
||||
if not all(isinstance(obj, (Doc, Span, dict)) for obj in docs):
|
||||
raise ValueError(Errors.E096)
|
||||
renderer, converter = factories[style]
|
||||
renderer = renderer(options=options)
|
||||
renderer_func, converter = factories[style]
|
||||
renderer = renderer_func(options=options)
|
||||
parsed = [converter(doc, options) for doc in docs] if not manual else docs
|
||||
_html["parsed"] = renderer.render(parsed, page=page, minify=minify).strip()
|
||||
html = _html["parsed"]
|
||||
|
@ -61,15 +68,15 @@ def render(
|
|||
|
||||
|
||||
def serve(
|
||||
docs,
|
||||
style="dep",
|
||||
page=True,
|
||||
minify=False,
|
||||
options={},
|
||||
manual=False,
|
||||
port=5000,
|
||||
host="0.0.0.0",
|
||||
):
|
||||
docs: Union[Iterable[Doc], Doc],
|
||||
style: str = "dep",
|
||||
page: bool = True,
|
||||
minify: bool = False,
|
||||
options: Dict[str, Any] = {},
|
||||
manual: bool = False,
|
||||
port: int = 5000,
|
||||
host: str = "0.0.0.0",
|
||||
) -> None:
|
||||
"""Serve displaCy visualisation.
|
||||
|
||||
docs (list or Doc): Document(s) to visualise.
|
||||
|
@ -88,7 +95,6 @@ def serve(
|
|||
|
||||
if is_in_jupyter():
|
||||
warnings.warn(Warnings.W011)
|
||||
|
||||
render(docs, style=style, page=page, minify=minify, options=options, manual=manual)
|
||||
httpd = simple_server.make_server(host, port, app)
|
||||
print(f"\nUsing the '{style}' visualizer")
|
||||
|
@ -102,14 +108,13 @@ def serve(
|
|||
|
||||
|
||||
def app(environ, start_response):
|
||||
# Headers and status need to be bytes in Python 2, see #1227
|
||||
headers = [("Content-type", "text/html; charset=utf-8")]
|
||||
start_response("200 OK", headers)
|
||||
res = _html["parsed"].encode(encoding="utf-8")
|
||||
return [res]
|
||||
|
||||
|
||||
def parse_deps(orig_doc, options={}):
|
||||
def parse_deps(orig_doc: Doc, options: Dict[str, Any] = {}) -> Dict[str, Any]:
|
||||
"""Generate dependency parse in {'words': [], 'arcs': []} format.
|
||||
|
||||
doc (Doc): Document do parse.
|
||||
|
@ -152,7 +157,6 @@ def parse_deps(orig_doc, options={}):
|
|||
}
|
||||
for w in doc
|
||||
]
|
||||
|
||||
arcs = []
|
||||
for word in doc:
|
||||
if word.i < word.head.i:
|
||||
|
@ -171,7 +175,7 @@ def parse_deps(orig_doc, options={}):
|
|||
return {"words": words, "arcs": arcs, "settings": get_doc_settings(orig_doc)}
|
||||
|
||||
|
||||
def parse_ents(doc, options={}):
|
||||
def parse_ents(doc: Doc, options: Dict[str, Any] = {}) -> Dict[str, Any]:
|
||||
"""Generate named entities in [{start: i, end: i, label: 'label'}] format.
|
||||
|
||||
doc (Doc): Document do parse.
|
||||
|
@ -188,7 +192,7 @@ def parse_ents(doc, options={}):
|
|||
return {"text": doc.text, "ents": ents, "title": title, "settings": settings}
|
||||
|
||||
|
||||
def set_render_wrapper(func):
|
||||
def set_render_wrapper(func: Callable[[str], str]) -> None:
|
||||
"""Set an optional wrapper function that is called around the generated
|
||||
HTML markup on displacy.render. This can be used to allow integration into
|
||||
other platforms, similar to Jupyter Notebooks that require functions to be
|
||||
|
@ -205,7 +209,7 @@ def set_render_wrapper(func):
|
|||
RENDER_WRAPPER = func
|
||||
|
||||
|
||||
def get_doc_settings(doc):
|
||||
def get_doc_settings(doc: Doc) -> Dict[str, Any]:
|
||||
return {
|
||||
"lang": doc.lang_,
|
||||
"direction": doc.vocab.writing_system.get("direction", "ltr"),
|
||||
|
|
|
@ -1,19 +1,36 @@
|
|||
from typing import Dict, Any, List, Optional, Union
|
||||
import uuid
|
||||
|
||||
from .templates import (
|
||||
TPL_DEP_SVG,
|
||||
TPL_DEP_WORDS,
|
||||
TPL_DEP_WORDS_LEMMA,
|
||||
TPL_DEP_ARCS,
|
||||
TPL_ENTS,
|
||||
)
|
||||
from .templates import TPL_DEP_SVG, TPL_DEP_WORDS, TPL_DEP_WORDS_LEMMA, TPL_DEP_ARCS
|
||||
from .templates import TPL_ENT, TPL_ENT_RTL, TPL_FIGURE, TPL_TITLE, TPL_PAGE
|
||||
from .templates import TPL_ENTS
|
||||
from ..util import minify_html, escape_html, registry
|
||||
from ..errors import Errors
|
||||
|
||||
|
||||
DEFAULT_LANG = "en"
|
||||
DEFAULT_DIR = "ltr"
|
||||
DEFAULT_ENTITY_COLOR = "#ddd"
|
||||
DEFAULT_LABEL_COLORS = {
|
||||
"ORG": "#7aecec",
|
||||
"PRODUCT": "#bfeeb7",
|
||||
"GPE": "#feca74",
|
||||
"LOC": "#ff9561",
|
||||
"PERSON": "#aa9cfc",
|
||||
"NORP": "#c887fb",
|
||||
"FACILITY": "#9cc9cc",
|
||||
"EVENT": "#ffeb80",
|
||||
"LAW": "#ff8197",
|
||||
"LANGUAGE": "#ff8197",
|
||||
"WORK_OF_ART": "#f0d0ff",
|
||||
"DATE": "#bfe1d9",
|
||||
"TIME": "#bfe1d9",
|
||||
"MONEY": "#e4e7d2",
|
||||
"QUANTITY": "#e4e7d2",
|
||||
"ORDINAL": "#e4e7d2",
|
||||
"CARDINAL": "#e4e7d2",
|
||||
"PERCENT": "#e4e7d2",
|
||||
}
|
||||
|
||||
|
||||
class DependencyRenderer:
|
||||
|
@ -21,7 +38,7 @@ class DependencyRenderer:
|
|||
|
||||
style = "dep"
|
||||
|
||||
def __init__(self, options={}):
|
||||
def __init__(self, options: Dict[str, Any] = {}) -> None:
|
||||
"""Initialise dependency renderer.
|
||||
|
||||
options (dict): Visualiser-specific options (compact, word_spacing,
|
||||
|
@ -41,7 +58,9 @@ class DependencyRenderer:
|
|||
self.direction = DEFAULT_DIR
|
||||
self.lang = DEFAULT_LANG
|
||||
|
||||
def render(self, parsed, page=False, minify=False):
|
||||
def render(
|
||||
self, parsed: List[Dict[str, Any]], page: bool = False, minify: bool = False
|
||||
) -> str:
|
||||
"""Render complete markup.
|
||||
|
||||
parsed (list): Dependency parses to render.
|
||||
|
@ -72,10 +91,15 @@ class DependencyRenderer:
|
|||
return minify_html(markup)
|
||||
return markup
|
||||
|
||||
def render_svg(self, render_id, words, arcs):
|
||||
def render_svg(
|
||||
self,
|
||||
render_id: Union[int, str],
|
||||
words: List[Dict[str, Any]],
|
||||
arcs: List[Dict[str, Any]],
|
||||
) -> str:
|
||||
"""Render SVG.
|
||||
|
||||
render_id (int): Unique ID, typically index of document.
|
||||
render_id (Union[int, str]): Unique ID, typically index of document.
|
||||
words (list): Individual words and their tags.
|
||||
arcs (list): Individual arcs and their start, end, direction and label.
|
||||
RETURNS (str): Rendered SVG markup.
|
||||
|
@ -86,15 +110,15 @@ class DependencyRenderer:
|
|||
self.width = self.offset_x + len(words) * self.distance
|
||||
self.height = self.offset_y + 3 * self.word_spacing
|
||||
self.id = render_id
|
||||
words = [
|
||||
words_svg = [
|
||||
self.render_word(w["text"], w["tag"], w.get("lemma", None), i)
|
||||
for i, w in enumerate(words)
|
||||
]
|
||||
arcs = [
|
||||
arcs_svg = [
|
||||
self.render_arrow(a["label"], a["start"], a["end"], a["dir"], i)
|
||||
for i, a in enumerate(arcs)
|
||||
]
|
||||
content = "".join(words) + "".join(arcs)
|
||||
content = "".join(words_svg) + "".join(arcs_svg)
|
||||
return TPL_DEP_SVG.format(
|
||||
id=self.id,
|
||||
width=self.width,
|
||||
|
@ -107,9 +131,7 @@ class DependencyRenderer:
|
|||
lang=self.lang,
|
||||
)
|
||||
|
||||
def render_word(
|
||||
self, text, tag, lemma, i,
|
||||
):
|
||||
def render_word(self, text: str, tag: str, lemma: str, i: int) -> str:
|
||||
"""Render individual word.
|
||||
|
||||
text (str): Word text.
|
||||
|
@ -128,7 +150,9 @@ class DependencyRenderer:
|
|||
)
|
||||
return TPL_DEP_WORDS.format(text=html_text, tag=tag, x=x, y=y)
|
||||
|
||||
def render_arrow(self, label, start, end, direction, i):
|
||||
def render_arrow(
|
||||
self, label: str, start: int, end: int, direction: str, i: int
|
||||
) -> str:
|
||||
"""Render individual arrow.
|
||||
|
||||
label (str): Dependency label.
|
||||
|
@ -172,7 +196,7 @@ class DependencyRenderer:
|
|||
arc=arc,
|
||||
)
|
||||
|
||||
def get_arc(self, x_start, y, y_curve, x_end):
|
||||
def get_arc(self, x_start: int, y: int, y_curve: int, x_end: int) -> str:
|
||||
"""Render individual arc.
|
||||
|
||||
x_start (int): X-coordinate of arrow start point.
|
||||
|
@ -186,7 +210,7 @@ class DependencyRenderer:
|
|||
template = "M{x},{y} {x},{c} {e},{c} {e},{y}"
|
||||
return template.format(x=x_start, y=y, c=y_curve, e=x_end)
|
||||
|
||||
def get_arrowhead(self, direction, x, y, end):
|
||||
def get_arrowhead(self, direction: str, x: int, y: int, end: int) -> str:
|
||||
"""Render individual arrow head.
|
||||
|
||||
direction (str): Arrow direction, 'left' or 'right'.
|
||||
|
@ -196,24 +220,12 @@ class DependencyRenderer:
|
|||
RETURNS (str): Definition of the arrow head path ('d' attribute).
|
||||
"""
|
||||
if direction == "left":
|
||||
pos1, pos2, pos3 = (x, x - self.arrow_width + 2, x + self.arrow_width - 2)
|
||||
p1, p2, p3 = (x, x - self.arrow_width + 2, x + self.arrow_width - 2)
|
||||
else:
|
||||
pos1, pos2, pos3 = (
|
||||
end,
|
||||
end + self.arrow_width - 2,
|
||||
end - self.arrow_width + 2,
|
||||
)
|
||||
arrowhead = (
|
||||
pos1,
|
||||
y + 2,
|
||||
pos2,
|
||||
y - self.arrow_width,
|
||||
pos3,
|
||||
y - self.arrow_width,
|
||||
)
|
||||
return "M{},{} L{},{} {},{}".format(*arrowhead)
|
||||
p1, p2, p3 = (end, end + self.arrow_width - 2, end - self.arrow_width + 2)
|
||||
return f"M{p1},{y + 2} L{p2},{y - self.arrow_width} {p3},{y - self.arrow_width}"
|
||||
|
||||
def get_levels(self, arcs):
|
||||
def get_levels(self, arcs: List[Dict[str, Any]]) -> List[int]:
|
||||
"""Calculate available arc height "levels".
|
||||
Used to calculate arrow heights dynamically and without wasting space.
|
||||
|
||||
|
@ -229,41 +241,21 @@ class EntityRenderer:
|
|||
|
||||
style = "ent"
|
||||
|
||||
def __init__(self, options={}):
|
||||
def __init__(self, options: Dict[str, Any] = {}) -> None:
|
||||
"""Initialise dependency renderer.
|
||||
|
||||
options (dict): Visualiser-specific options (colors, ents)
|
||||
"""
|
||||
colors = {
|
||||
"ORG": "#7aecec",
|
||||
"PRODUCT": "#bfeeb7",
|
||||
"GPE": "#feca74",
|
||||
"LOC": "#ff9561",
|
||||
"PERSON": "#aa9cfc",
|
||||
"NORP": "#c887fb",
|
||||
"FACILITY": "#9cc9cc",
|
||||
"EVENT": "#ffeb80",
|
||||
"LAW": "#ff8197",
|
||||
"LANGUAGE": "#ff8197",
|
||||
"WORK_OF_ART": "#f0d0ff",
|
||||
"DATE": "#bfe1d9",
|
||||
"TIME": "#bfe1d9",
|
||||
"MONEY": "#e4e7d2",
|
||||
"QUANTITY": "#e4e7d2",
|
||||
"ORDINAL": "#e4e7d2",
|
||||
"CARDINAL": "#e4e7d2",
|
||||
"PERCENT": "#e4e7d2",
|
||||
}
|
||||
colors = dict(DEFAULT_LABEL_COLORS)
|
||||
user_colors = registry.displacy_colors.get_all()
|
||||
for user_color in user_colors.values():
|
||||
colors.update(user_color)
|
||||
colors.update(options.get("colors", {}))
|
||||
self.default_color = "#ddd"
|
||||
self.default_color = DEFAULT_ENTITY_COLOR
|
||||
self.colors = colors
|
||||
self.ents = options.get("ents", None)
|
||||
self.direction = DEFAULT_DIR
|
||||
self.lang = DEFAULT_LANG
|
||||
|
||||
template = options.get("template")
|
||||
if template:
|
||||
self.ent_template = template
|
||||
|
@ -273,7 +265,9 @@ class EntityRenderer:
|
|||
else:
|
||||
self.ent_template = TPL_ENT
|
||||
|
||||
def render(self, parsed, page=False, minify=False):
|
||||
def render(
|
||||
self, parsed: List[Dict[str, Any]], page: bool = False, minify: bool = False
|
||||
) -> str:
|
||||
"""Render complete markup.
|
||||
|
||||
parsed (list): Dependency parses to render.
|
||||
|
@ -297,7 +291,9 @@ class EntityRenderer:
|
|||
return minify_html(markup)
|
||||
return markup
|
||||
|
||||
def render_ents(self, text, spans, title):
|
||||
def render_ents(
|
||||
self, text: str, spans: List[Dict[str, Any]], title: Optional[str]
|
||||
) -> str:
|
||||
"""Render entities in text.
|
||||
|
||||
text (str): Original text.
|
||||
|
|
|
@ -483,6 +483,8 @@ class Errors:
|
|||
E199 = ("Unable to merge 0-length span at doc[{start}:{end}].")
|
||||
|
||||
# TODO: fix numbering after merging develop into master
|
||||
E953 = ("Mismatched IDs received by the Tok2Vec listener: {id1} vs. {id2}")
|
||||
E954 = ("The Tok2Vec listener did not receive a valid input.")
|
||||
E955 = ("Can't find table '{table}' for language '{lang}' in spacy-lookups-data.")
|
||||
E956 = ("Can't find component '{name}' in [components] block in the config. "
|
||||
"Available components: {opts}")
|
||||
|
|
|
@ -15,7 +15,7 @@ class Alignment:
|
|||
x2y = _make_ragged(x2y)
|
||||
y2x = _make_ragged(y2x)
|
||||
return Alignment(x2y=x2y, y2x=y2x)
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_strings(cls, A: List[str], B: List[str]) -> "Alignment":
|
||||
x2y, y2x = tokenizations.get_alignments(A, B)
|
||||
|
|
|
@ -163,7 +163,7 @@ class Language:
|
|||
else:
|
||||
if (self.lang and vocab.lang) and (self.lang != vocab.lang):
|
||||
raise ValueError(Errors.E150.format(nlp=self.lang, vocab=vocab.lang))
|
||||
self.vocab = vocab
|
||||
self.vocab: Vocab = vocab
|
||||
if self.lang is None:
|
||||
self.lang = self.vocab.lang
|
||||
self.pipeline = []
|
||||
|
|
|
@ -8,7 +8,7 @@ def PrecomputableAffine(nO, nI, nF, nP, dropout=0.1):
|
|||
init=init,
|
||||
dims={"nO": nO, "nI": nI, "nF": nF, "nP": nP},
|
||||
params={"W": None, "b": None, "pad": None},
|
||||
attrs={"dropout_rate": dropout}
|
||||
attrs={"dropout_rate": dropout},
|
||||
)
|
||||
return model
|
||||
|
||||
|
|
|
@ -154,16 +154,30 @@ def LayerNormalizedMaxout(width, maxout_pieces):
|
|||
def MultiHashEmbed(
|
||||
columns, width, rows, use_subwords, pretrained_vectors, mix, dropout
|
||||
):
|
||||
norm = HashEmbed(nO=width, nV=rows, column=columns.index("NORM"), dropout=dropout, seed=6)
|
||||
norm = HashEmbed(
|
||||
nO=width, nV=rows, column=columns.index("NORM"), dropout=dropout, seed=6
|
||||
)
|
||||
if use_subwords:
|
||||
prefix = HashEmbed(
|
||||
nO=width, nV=rows // 2, column=columns.index("PREFIX"), dropout=dropout, seed=7
|
||||
nO=width,
|
||||
nV=rows // 2,
|
||||
column=columns.index("PREFIX"),
|
||||
dropout=dropout,
|
||||
seed=7,
|
||||
)
|
||||
suffix = HashEmbed(
|
||||
nO=width, nV=rows // 2, column=columns.index("SUFFIX"), dropout=dropout, seed=8
|
||||
nO=width,
|
||||
nV=rows // 2,
|
||||
column=columns.index("SUFFIX"),
|
||||
dropout=dropout,
|
||||
seed=8,
|
||||
)
|
||||
shape = HashEmbed(
|
||||
nO=width, nV=rows // 2, column=columns.index("SHAPE"), dropout=dropout, seed=9
|
||||
nO=width,
|
||||
nV=rows // 2,
|
||||
column=columns.index("SHAPE"),
|
||||
dropout=dropout,
|
||||
seed=9,
|
||||
)
|
||||
|
||||
if pretrained_vectors:
|
||||
|
@ -192,7 +206,9 @@ def MultiHashEmbed(
|
|||
|
||||
@registry.architectures.register("spacy.CharacterEmbed.v1")
|
||||
def CharacterEmbed(columns, width, rows, nM, nC, features, dropout):
|
||||
norm = HashEmbed(nO=width, nV=rows, column=columns.index("NORM"), dropout=dropout, seed=5)
|
||||
norm = HashEmbed(
|
||||
nO=width, nV=rows, column=columns.index("NORM"), dropout=dropout, seed=5
|
||||
)
|
||||
chr_embed = _character_embed.CharacterEmbed(nM=nM, nC=nC)
|
||||
with Model.define_operators({">>": chain, "|": concatenate}):
|
||||
embed_layer = chr_embed | features >> with_array(norm)
|
||||
|
@ -263,21 +279,29 @@ def build_Tok2Vec_model(
|
|||
cols = [ID, NORM, PREFIX, SUFFIX, SHAPE, ORTH]
|
||||
with Model.define_operators({">>": chain, "|": concatenate, "**": clone}):
|
||||
norm = HashEmbed(
|
||||
nO=width, nV=embed_size, column=cols.index(NORM), dropout=None,
|
||||
seed=0
|
||||
nO=width, nV=embed_size, column=cols.index(NORM), dropout=None, seed=0
|
||||
)
|
||||
if subword_features:
|
||||
prefix = HashEmbed(
|
||||
nO=width, nV=embed_size // 2, column=cols.index(PREFIX), dropout=None,
|
||||
seed=1
|
||||
nO=width,
|
||||
nV=embed_size // 2,
|
||||
column=cols.index(PREFIX),
|
||||
dropout=None,
|
||||
seed=1,
|
||||
)
|
||||
suffix = HashEmbed(
|
||||
nO=width, nV=embed_size // 2, column=cols.index(SUFFIX), dropout=None,
|
||||
seed=2
|
||||
nO=width,
|
||||
nV=embed_size // 2,
|
||||
column=cols.index(SUFFIX),
|
||||
dropout=None,
|
||||
seed=2,
|
||||
)
|
||||
shape = HashEmbed(
|
||||
nO=width, nV=embed_size // 2, column=cols.index(SHAPE), dropout=None,
|
||||
seed=3
|
||||
nO=width,
|
||||
nV=embed_size // 2,
|
||||
column=cols.index(SHAPE),
|
||||
dropout=None,
|
||||
seed=3,
|
||||
)
|
||||
else:
|
||||
prefix, suffix, shape = (None, None, None)
|
||||
|
@ -294,11 +318,7 @@ def build_Tok2Vec_model(
|
|||
embed = uniqued(
|
||||
(glove | norm | prefix | suffix | shape)
|
||||
>> Maxout(
|
||||
nO=width,
|
||||
nI=width * columns,
|
||||
nP=3,
|
||||
dropout=0.0,
|
||||
normalize=True,
|
||||
nO=width, nI=width * columns, nP=3, dropout=0.0, normalize=True,
|
||||
),
|
||||
column=cols.index(ORTH),
|
||||
)
|
||||
|
@ -307,11 +327,7 @@ def build_Tok2Vec_model(
|
|||
embed = uniqued(
|
||||
(glove | norm)
|
||||
>> Maxout(
|
||||
nO=width,
|
||||
nI=width * columns,
|
||||
nP=3,
|
||||
dropout=0.0,
|
||||
normalize=True,
|
||||
nO=width, nI=width * columns, nP=3, dropout=0.0, normalize=True,
|
||||
),
|
||||
column=cols.index(ORTH),
|
||||
)
|
||||
|
@ -320,11 +336,7 @@ def build_Tok2Vec_model(
|
|||
embed = uniqued(
|
||||
concatenate(norm, prefix, suffix, shape)
|
||||
>> Maxout(
|
||||
nO=width,
|
||||
nI=width * columns,
|
||||
nP=3,
|
||||
dropout=0.0,
|
||||
normalize=True,
|
||||
nO=width, nI=width * columns, nP=3, dropout=0.0, normalize=True,
|
||||
),
|
||||
column=cols.index(ORTH),
|
||||
)
|
||||
|
@ -333,11 +345,7 @@ def build_Tok2Vec_model(
|
|||
cols
|
||||
) >> with_array(norm)
|
||||
reduce_dimensions = Maxout(
|
||||
nO=width,
|
||||
nI=nM * nC + width,
|
||||
nP=3,
|
||||
dropout=0.0,
|
||||
normalize=True,
|
||||
nO=width, nI=nM * nC + width, nP=3, dropout=0.0, normalize=True,
|
||||
)
|
||||
else:
|
||||
embed = norm
|
||||
|
|
|
@ -10,7 +10,6 @@ from .simple_ner import SimpleNER
|
|||
from .tagger import Tagger
|
||||
from .textcat import TextCategorizer
|
||||
from .tok2vec import Tok2Vec
|
||||
from .hooks import SentenceSegmenter, SimilarityHook
|
||||
from .functions import merge_entities, merge_noun_chunks, merge_subtokens
|
||||
|
||||
__all__ = [
|
||||
|
@ -21,9 +20,7 @@ __all__ = [
|
|||
"Morphologizer",
|
||||
"Pipe",
|
||||
"SentenceRecognizer",
|
||||
"SentenceSegmenter",
|
||||
"Sentencizer",
|
||||
"SimilarityHook",
|
||||
"SimpleNER",
|
||||
"Tagger",
|
||||
"TextCategorizer",
|
||||
|
|
|
@ -63,7 +63,7 @@ class EntityRuler:
|
|||
overwrite_ents: bool = False,
|
||||
ent_id_sep: str = DEFAULT_ENT_ID_SEP,
|
||||
patterns: Optional[List[PatternType]] = None,
|
||||
):
|
||||
) -> None:
|
||||
"""Initialize the entitiy ruler. If patterns are supplied here, they
|
||||
need to be a list of dictionaries with a `"label"` and `"pattern"`
|
||||
key. A pattern can either be a token pattern (list) or a phrase pattern
|
||||
|
@ -239,7 +239,6 @@ class EntityRuler:
|
|||
phrase_pattern_labels = []
|
||||
phrase_pattern_texts = []
|
||||
phrase_pattern_ids = []
|
||||
|
||||
for entry in patterns:
|
||||
if isinstance(entry["pattern"], str):
|
||||
phrase_pattern_labels.append(entry["label"])
|
||||
|
@ -247,7 +246,6 @@ class EntityRuler:
|
|||
phrase_pattern_ids.append(entry.get("id"))
|
||||
elif isinstance(entry["pattern"], list):
|
||||
token_patterns.append(entry)
|
||||
|
||||
phrase_patterns = []
|
||||
for label, pattern, ent_id in zip(
|
||||
phrase_pattern_labels,
|
||||
|
@ -258,7 +256,6 @@ class EntityRuler:
|
|||
if ent_id:
|
||||
phrase_pattern["id"] = ent_id
|
||||
phrase_patterns.append(phrase_pattern)
|
||||
|
||||
for entry in token_patterns + phrase_patterns:
|
||||
label = entry["label"]
|
||||
if "id" in entry:
|
||||
|
@ -266,7 +263,6 @@ class EntityRuler:
|
|||
label = self._create_label(label, entry["id"])
|
||||
key = self.matcher._normalize_key(label)
|
||||
self._ent_ids[key] = (ent_label, entry["id"])
|
||||
|
||||
pattern = entry["pattern"]
|
||||
if isinstance(pattern, Doc):
|
||||
self.phrase_patterns[label].append(pattern)
|
||||
|
@ -323,13 +319,13 @@ class EntityRuler:
|
|||
self.clear()
|
||||
if isinstance(cfg, dict):
|
||||
self.add_patterns(cfg.get("patterns", cfg))
|
||||
self.overwrite = cfg.get("overwrite")
|
||||
self.overwrite = cfg.get("overwrite", False)
|
||||
self.phrase_matcher_attr = cfg.get("phrase_matcher_attr", None)
|
||||
if self.phrase_matcher_attr is not None:
|
||||
self.phrase_matcher = PhraseMatcher(
|
||||
self.nlp.vocab, attr=self.phrase_matcher_attr
|
||||
)
|
||||
self.ent_id_sep = cfg.get("ent_id_sep")
|
||||
self.ent_id_sep = cfg.get("ent_id_sep", DEFAULT_ENT_ID_SEP)
|
||||
else:
|
||||
self.add_patterns(cfg)
|
||||
return self
|
||||
|
@ -375,9 +371,9 @@ class EntityRuler:
|
|||
}
|
||||
deserializers_cfg = {"cfg": lambda p: cfg.update(srsly.read_json(p))}
|
||||
from_disk(path, deserializers_cfg, {})
|
||||
self.overwrite = cfg.get("overwrite")
|
||||
self.overwrite = cfg.get("overwrite", False)
|
||||
self.phrase_matcher_attr = cfg.get("phrase_matcher_attr")
|
||||
self.ent_id_sep = cfg.get("ent_id_sep")
|
||||
self.ent_id_sep = cfg.get("ent_id_sep", DEFAULT_ENT_ID_SEP)
|
||||
|
||||
if self.phrase_matcher_attr is not None:
|
||||
self.phrase_matcher = PhraseMatcher(
|
||||
|
|
|
@ -1,95 +0,0 @@
|
|||
from thinc.api import concatenate, reduce_max, reduce_mean, siamese, CauchySimilarity
|
||||
|
||||
from .pipe import Pipe
|
||||
from ..util import link_vectors_to_models
|
||||
|
||||
|
||||
# TODO: do we want to keep these?
|
||||
|
||||
|
||||
class SentenceSegmenter:
|
||||
"""A simple spaCy hook, to allow custom sentence boundary detection logic
|
||||
(that doesn't require the dependency parse). To change the sentence
|
||||
boundary detection strategy, pass a generator function `strategy` on
|
||||
initialization, or assign a new strategy to the .strategy attribute.
|
||||
Sentence detection strategies should be generators that take `Doc` objects
|
||||
and yield `Span` objects for each sentence.
|
||||
"""
|
||||
|
||||
def __init__(self, vocab, strategy=None):
|
||||
self.vocab = vocab
|
||||
if strategy is None or strategy == "on_punct":
|
||||
strategy = self.split_on_punct
|
||||
self.strategy = strategy
|
||||
|
||||
def __call__(self, doc):
|
||||
doc.user_hooks["sents"] = self.strategy
|
||||
return doc
|
||||
|
||||
@staticmethod
|
||||
def split_on_punct(doc):
|
||||
start = 0
|
||||
seen_period = False
|
||||
for i, token in enumerate(doc):
|
||||
if seen_period and not token.is_punct:
|
||||
yield doc[start : token.i]
|
||||
start = token.i
|
||||
seen_period = False
|
||||
elif token.text in [".", "!", "?"]:
|
||||
seen_period = True
|
||||
if start < len(doc):
|
||||
yield doc[start : len(doc)]
|
||||
|
||||
|
||||
class SimilarityHook(Pipe):
|
||||
"""
|
||||
Experimental: A pipeline component to install a hook for supervised
|
||||
similarity into `Doc` objects.
|
||||
The similarity model can be any object obeying the Thinc `Model`
|
||||
interface. By default, the model concatenates the elementwise mean and
|
||||
elementwise max of the two tensors, and compares them using the
|
||||
Cauchy-like similarity function from Chen (2013):
|
||||
|
||||
>>> similarity = 1. / (1. + (W * (vec1-vec2)**2).sum())
|
||||
|
||||
Where W is a vector of dimension weights, initialized to 1.
|
||||
"""
|
||||
|
||||
def __init__(self, vocab, model=True, **cfg):
|
||||
self.vocab = vocab
|
||||
self.model = model
|
||||
self.cfg = dict(cfg)
|
||||
|
||||
@classmethod
|
||||
def Model(cls, length):
|
||||
return siamese(
|
||||
concatenate(reduce_max(), reduce_mean()), CauchySimilarity(length * 2)
|
||||
)
|
||||
|
||||
def __call__(self, doc):
|
||||
"""Install similarity hook"""
|
||||
doc.user_hooks["similarity"] = self.predict
|
||||
return doc
|
||||
|
||||
def pipe(self, docs, **kwargs):
|
||||
for doc in docs:
|
||||
yield self(doc)
|
||||
|
||||
def predict(self, doc1, doc2):
|
||||
return self.model.predict([(doc1, doc2)])
|
||||
|
||||
def update(self, doc1_doc2, golds, sgd=None, drop=0.0):
|
||||
sims, bp_sims = self.model.begin_update(doc1_doc2)
|
||||
|
||||
def begin_training(self, _=tuple(), pipeline=None, sgd=None, **kwargs):
|
||||
"""Allocate model, using nO from the first model in the pipeline.
|
||||
|
||||
gold_tuples (iterable): Gold-standard training data.
|
||||
pipeline (list): The pipeline the model is part of.
|
||||
"""
|
||||
if self.model is True:
|
||||
self.model = self.Model(pipeline[0].model.get_dim("nO"))
|
||||
link_vectors_to_models(self.vocab)
|
||||
if sgd is None:
|
||||
sgd = self.create_optimizer()
|
||||
return sgd
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Iterable, Tuple, Optional, Dict, List, Callable
|
||||
from typing import Iterable, Tuple, Optional, Dict, List, Callable, Iterator, Any
|
||||
from thinc.api import get_array_module, Model, Optimizer, set_dropout_rate, Config
|
||||
import numpy
|
||||
|
||||
|
@ -97,7 +97,7 @@ class TextCategorizer(Pipe):
|
|||
def labels(self, value: Iterable[str]) -> None:
|
||||
self.cfg["labels"] = tuple(value)
|
||||
|
||||
def pipe(self, stream, batch_size=128):
|
||||
def pipe(self, stream: Iterator[str], batch_size: int = 128) -> Iterator[Doc]:
|
||||
for docs in util.minibatch(stream, size=batch_size):
|
||||
scores = self.predict(docs)
|
||||
self.set_annotations(docs, scores)
|
||||
|
@ -252,8 +252,17 @@ class TextCategorizer(Pipe):
|
|||
sgd = self.create_optimizer()
|
||||
return sgd
|
||||
|
||||
def score(self, examples, positive_label=None, **kwargs):
|
||||
return Scorer.score_cats(examples, "cats", labels=self.labels,
|
||||
def score(
|
||||
self,
|
||||
examples: Iterable[Example],
|
||||
positive_label: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
return Scorer.score_cats(
|
||||
examples,
|
||||
"cats",
|
||||
labels=self.labels,
|
||||
multi_label=self.model.attrs["multi_label"],
|
||||
positive_label=positive_label, **kwargs
|
||||
positive_label=positive_label,
|
||||
**kwargs,
|
||||
)
|
||||
|
|
|
@ -6,6 +6,7 @@ from ..gold import Example
|
|||
from ..tokens import Doc
|
||||
from ..vocab import Vocab
|
||||
from ..language import Language
|
||||
from ..errors import Errors
|
||||
from ..util import link_vectors_to_models, minibatch
|
||||
|
||||
|
||||
|
@ -150,7 +151,7 @@ class Tok2Vec(Pipe):
|
|||
self.set_annotations(docs, tokvecs)
|
||||
return losses
|
||||
|
||||
def get_loss(self, examples, scores):
|
||||
def get_loss(self, examples, scores) -> None:
|
||||
pass
|
||||
|
||||
def begin_training(
|
||||
|
@ -184,26 +185,26 @@ class Tok2VecListener(Model):
|
|||
self._backprop = None
|
||||
|
||||
@classmethod
|
||||
def get_batch_id(cls, inputs):
|
||||
def get_batch_id(cls, inputs) -> int:
|
||||
return sum(sum(token.orth for token in doc) for doc in inputs)
|
||||
|
||||
def receive(self, batch_id, outputs, backprop):
|
||||
def receive(self, batch_id: int, outputs, backprop) -> None:
|
||||
self._batch_id = batch_id
|
||||
self._outputs = outputs
|
||||
self._backprop = backprop
|
||||
|
||||
def verify_inputs(self, inputs):
|
||||
def verify_inputs(self, inputs) -> bool:
|
||||
if self._batch_id is None and self._outputs is None:
|
||||
raise ValueError("The Tok2Vec listener did not receive valid input.")
|
||||
raise ValueError(Errors.E954)
|
||||
else:
|
||||
batch_id = self.get_batch_id(inputs)
|
||||
if batch_id != self._batch_id:
|
||||
raise ValueError(f"Mismatched IDs! {batch_id} vs {self._batch_id}")
|
||||
raise ValueError(Errors.E953.format(id1=batch_id, id2=self._batch_id))
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
def forward(model: Tok2VecListener, inputs, is_train):
|
||||
def forward(model: Tok2VecListener, inputs, is_train: bool):
|
||||
if is_train:
|
||||
model.verify_inputs(inputs)
|
||||
return model._outputs, model._backprop
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Dict, List, Union, Optional, Sequence, Any, Callable
|
||||
from typing import Dict, List, Union, Optional, Sequence, Any, Callable, Type
|
||||
from enum import Enum
|
||||
from pydantic import BaseModel, Field, ValidationError, validator
|
||||
from pydantic import StrictStr, StrictInt, StrictFloat, StrictBool
|
||||
|
@ -9,12 +9,12 @@ from thinc.api import Optimizer
|
|||
from .attrs import NAMES
|
||||
|
||||
|
||||
def validate(schema, obj):
|
||||
def validate(schema: Type[BaseModel], obj: Dict[str, Any]) -> List[str]:
|
||||
"""Validate data against a given pydantic schema.
|
||||
|
||||
obj (dict): JSON-serializable data to validate.
|
||||
obj (Dict[str, Any]): JSON-serializable data to validate.
|
||||
schema (pydantic.BaseModel): The schema to validate against.
|
||||
RETURNS (list): A list of error messages, if available.
|
||||
RETURNS (List[str]): A list of error messages, if available.
|
||||
"""
|
||||
try:
|
||||
schema(**obj)
|
||||
|
@ -31,7 +31,7 @@ def validate(schema, obj):
|
|||
# Matcher token patterns
|
||||
|
||||
|
||||
def validate_token_pattern(obj):
|
||||
def validate_token_pattern(obj: list) -> List[str]:
|
||||
# Try to convert non-string keys (e.g. {ORTH: "foo"} -> {"ORTH": "foo"})
|
||||
get_key = lambda k: NAMES[k] if isinstance(k, int) and k < len(NAMES) else k
|
||||
if isinstance(obj, list):
|
||||
|
|
|
@ -15,7 +15,8 @@ def test_doc_add_entities_set_ents_iob(en_vocab):
|
|||
"min_action_freq": 30,
|
||||
"update_with_oracle_cut_size": 100,
|
||||
}
|
||||
model = registry.make_from_config({"model": DEFAULT_NER_MODEL}, validate=True)["model"]
|
||||
cfg = {"model": DEFAULT_NER_MODEL}
|
||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
||||
ner = EntityRecognizer(en_vocab, model, **config)
|
||||
ner.begin_training([])
|
||||
ner(doc)
|
||||
|
@ -37,7 +38,8 @@ def test_ents_reset(en_vocab):
|
|||
"min_action_freq": 30,
|
||||
"update_with_oracle_cut_size": 100,
|
||||
}
|
||||
model = registry.make_from_config({"model": DEFAULT_NER_MODEL}, validate=True)["model"]
|
||||
cfg = {"model": DEFAULT_NER_MODEL}
|
||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
||||
ner = EntityRecognizer(en_vocab, model, **config)
|
||||
ner.begin_training([])
|
||||
ner(doc)
|
||||
|
|
|
@ -14,7 +14,9 @@ def test_en_sbd_single_punct(en_tokenizer, text, punct):
|
|||
assert sum(len(sent) for sent in doc.sents) == len(doc)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="The step_through API was removed (but should be brought back)")
|
||||
@pytest.mark.skip(
|
||||
reason="The step_through API was removed (but should be brought back)"
|
||||
)
|
||||
def test_en_sentence_breaks(en_tokenizer, en_parser):
|
||||
# fmt: off
|
||||
text = "This is a sentence . This is another one ."
|
||||
|
|
|
@ -32,6 +32,25 @@ SENTENCE_TESTS = [
|
|||
("あれ。これ。", ["あれ。", "これ。"]),
|
||||
("「伝染るんです。」という漫画があります。", ["「伝染るんです。」という漫画があります。"]),
|
||||
]
|
||||
|
||||
tokens1 = [
|
||||
DetailedToken(surface="委員", tag="名詞-普通名詞-一般", inf="", lemma="委員", reading="イイン", sub_tokens=None),
|
||||
DetailedToken(surface="会", tag="名詞-普通名詞-一般", inf="", lemma="会", reading="カイ", sub_tokens=None),
|
||||
]
|
||||
tokens2 = [
|
||||
DetailedToken(surface="選挙", tag="名詞-普通名詞-サ変可能", inf="", lemma="選挙", reading="センキョ", sub_tokens=None),
|
||||
DetailedToken(surface="管理", tag="名詞-普通名詞-サ変可能", inf="", lemma="管理", reading="カンリ", sub_tokens=None),
|
||||
DetailedToken(surface="委員", tag="名詞-普通名詞-一般", inf="", lemma="委員", reading="イイン", sub_tokens=None),
|
||||
DetailedToken(surface="会", tag="名詞-普通名詞-一般", inf="", lemma="会", reading="カイ", sub_tokens=None),
|
||||
]
|
||||
tokens3 = [
|
||||
DetailedToken(surface="選挙", tag="名詞-普通名詞-サ変可能", inf="", lemma="選挙", reading="センキョ", sub_tokens=None),
|
||||
DetailedToken(surface="管理", tag="名詞-普通名詞-サ変可能", inf="", lemma="管理", reading="カンリ", sub_tokens=None),
|
||||
DetailedToken(surface="委員会", tag="名詞-普通名詞-一般", inf="", lemma="委員会", reading="イインカイ", sub_tokens=None),
|
||||
]
|
||||
SUB_TOKEN_TESTS = [
|
||||
("選挙管理委員会", [None, None, None, None], [None, None, [tokens1]], [[tokens2, tokens3]])
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
|
||||
|
@ -92,33 +111,12 @@ def test_ja_tokenizer_split_modes(ja_tokenizer, text, len_a, len_b, len_c):
|
|||
assert len(nlp_c(text)) == len_c
|
||||
|
||||
|
||||
@pytest.mark.parametrize("text,sub_tokens_list_a,sub_tokens_list_b,sub_tokens_list_c",
|
||||
[
|
||||
(
|
||||
"選挙管理委員会",
|
||||
[None, None, None, None],
|
||||
[None, None, [
|
||||
[
|
||||
DetailedToken(surface='委員', tag='名詞-普通名詞-一般', inf='', lemma='委員', reading='イイン', sub_tokens=None),
|
||||
DetailedToken(surface='会', tag='名詞-普通名詞-一般', inf='', lemma='会', reading='カイ', sub_tokens=None),
|
||||
]
|
||||
]],
|
||||
[[
|
||||
[
|
||||
DetailedToken(surface='選挙', tag='名詞-普通名詞-サ変可能', inf='', lemma='選挙', reading='センキョ', sub_tokens=None),
|
||||
DetailedToken(surface='管理', tag='名詞-普通名詞-サ変可能', inf='', lemma='管理', reading='カンリ', sub_tokens=None),
|
||||
DetailedToken(surface='委員', tag='名詞-普通名詞-一般', inf='', lemma='委員', reading='イイン', sub_tokens=None),
|
||||
DetailedToken(surface='会', tag='名詞-普通名詞-一般', inf='', lemma='会', reading='カイ', sub_tokens=None),
|
||||
], [
|
||||
DetailedToken(surface='選挙', tag='名詞-普通名詞-サ変可能', inf='', lemma='選挙', reading='センキョ', sub_tokens=None),
|
||||
DetailedToken(surface='管理', tag='名詞-普通名詞-サ変可能', inf='', lemma='管理', reading='カンリ', sub_tokens=None),
|
||||
DetailedToken(surface='委員会', tag='名詞-普通名詞-一般', inf='', lemma='委員会', reading='イインカイ', sub_tokens=None),
|
||||
]
|
||||
]]
|
||||
),
|
||||
]
|
||||
@pytest.mark.parametrize(
|
||||
"text,sub_tokens_list_a,sub_tokens_list_b,sub_tokens_list_c", SUB_TOKEN_TESTS,
|
||||
)
|
||||
def test_ja_tokenizer_sub_tokens(ja_tokenizer, text, sub_tokens_list_a, sub_tokens_list_b, sub_tokens_list_c):
|
||||
def test_ja_tokenizer_sub_tokens(
|
||||
ja_tokenizer, text, sub_tokens_list_a, sub_tokens_list_b, sub_tokens_list_c
|
||||
):
|
||||
nlp_a = Japanese.from_config({"nlp": {"tokenizer": {"split_mode": "A"}}})
|
||||
nlp_b = Japanese.from_config({"nlp": {"tokenizer": {"split_mode": "B"}}})
|
||||
nlp_c = Japanese.from_config({"nlp": {"tokenizer": {"split_mode": "C"}}})
|
||||
|
@ -129,16 +127,19 @@ def test_ja_tokenizer_sub_tokens(ja_tokenizer, text, sub_tokens_list_a, sub_toke
|
|||
assert nlp_c(text).user_data["sub_tokens"] == sub_tokens_list_c
|
||||
|
||||
|
||||
@pytest.mark.parametrize("text,inflections,reading_forms",
|
||||
@pytest.mark.parametrize(
|
||||
"text,inflections,reading_forms",
|
||||
[
|
||||
(
|
||||
"取ってつけた",
|
||||
("五段-ラ行,連用形-促音便", "", "下一段-カ行,連用形-一般", "助動詞-タ,終止形-一般"),
|
||||
("トッ", "テ", "ツケ", "タ"),
|
||||
),
|
||||
]
|
||||
],
|
||||
)
|
||||
def test_ja_tokenizer_inflections_reading_forms(ja_tokenizer, text, inflections, reading_forms):
|
||||
def test_ja_tokenizer_inflections_reading_forms(
|
||||
ja_tokenizer, text, inflections, reading_forms
|
||||
):
|
||||
assert ja_tokenizer(text).user_data["inflections"] == inflections
|
||||
assert ja_tokenizer(text).user_data["reading_forms"] == reading_forms
|
||||
|
||||
|
|
|
@ -11,9 +11,8 @@ def test_ne_tokenizer_handlers_long_text(ne_tokenizer):
|
|||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"text,length",
|
||||
[("समय जान कति पनि बेर लाग्दैन ।", 7), ("म ठूलो हुँदै थिएँ ।", 5)],
|
||||
"text,length", [("समय जान कति पनि बेर लाग्दैन ।", 7), ("म ठूलो हुँदै थिएँ ।", 5)],
|
||||
)
|
||||
def test_ne_tokenizer_handles_cnts(ne_tokenizer, text, length):
|
||||
tokens = ne_tokenizer(text)
|
||||
assert len(tokens) == length
|
||||
assert len(tokens) == length
|
||||
|
|
|
@ -30,10 +30,7 @@ def test_zh_tokenizer_serialize_pkuseg_with_processors(zh_tokenizer_pkuseg):
|
|||
nlp = Chinese(
|
||||
meta={
|
||||
"tokenizer": {
|
||||
"config": {
|
||||
"segmenter": "pkuseg",
|
||||
"pkuseg_model": "medicine",
|
||||
}
|
||||
"config": {"segmenter": "pkuseg", "pkuseg_model": "medicine",}
|
||||
}
|
||||
}
|
||||
)
|
||||
|
|
|
@ -22,7 +22,8 @@ def parser(vocab):
|
|||
"min_action_freq": 30,
|
||||
"update_with_oracle_cut_size": 100,
|
||||
}
|
||||
model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"]
|
||||
cfg = {"model": DEFAULT_PARSER_MODEL}
|
||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
||||
parser = DependencyParser(vocab, model, **config)
|
||||
return parser
|
||||
|
||||
|
@ -68,7 +69,8 @@ def test_add_label_deserializes_correctly():
|
|||
"min_action_freq": 30,
|
||||
"update_with_oracle_cut_size": 100,
|
||||
}
|
||||
model = registry.make_from_config({"model": DEFAULT_NER_MODEL}, validate=True)["model"]
|
||||
cfg = {"model": DEFAULT_NER_MODEL}
|
||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
||||
ner1 = EntityRecognizer(Vocab(), model, **config)
|
||||
ner1.add_label("C")
|
||||
ner1.add_label("B")
|
||||
|
@ -86,7 +88,10 @@ def test_add_label_deserializes_correctly():
|
|||
|
||||
@pytest.mark.parametrize(
|
||||
"pipe_cls,n_moves,model_config",
|
||||
[(DependencyParser, 5, DEFAULT_PARSER_MODEL), (EntityRecognizer, 4, DEFAULT_NER_MODEL)],
|
||||
[
|
||||
(DependencyParser, 5, DEFAULT_PARSER_MODEL),
|
||||
(EntityRecognizer, 4, DEFAULT_NER_MODEL),
|
||||
],
|
||||
)
|
||||
def test_add_label_get_label(pipe_cls, n_moves, model_config):
|
||||
"""Test that added labels are returned correctly. This test was added to
|
||||
|
|
|
@ -126,7 +126,8 @@ def test_get_oracle_actions():
|
|||
"min_action_freq": 0,
|
||||
"update_with_oracle_cut_size": 100,
|
||||
}
|
||||
model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"]
|
||||
cfg = {"model": DEFAULT_PARSER_MODEL}
|
||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
||||
parser = DependencyParser(doc.vocab, model, **config)
|
||||
parser.moves.add_action(0, "")
|
||||
parser.moves.add_action(1, "")
|
||||
|
|
|
@ -24,7 +24,8 @@ def arc_eager(vocab):
|
|||
|
||||
@pytest.fixture
|
||||
def tok2vec():
|
||||
tok2vec = registry.make_from_config({"model": DEFAULT_TOK2VEC_MODEL}, validate=True)["model"]
|
||||
cfg = {"model": DEFAULT_TOK2VEC_MODEL}
|
||||
tok2vec = registry.make_from_config(cfg, validate=True)["model"]
|
||||
tok2vec.initialize()
|
||||
return tok2vec
|
||||
|
||||
|
@ -36,13 +37,15 @@ def parser(vocab, arc_eager):
|
|||
"min_action_freq": 30,
|
||||
"update_with_oracle_cut_size": 100,
|
||||
}
|
||||
model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"]
|
||||
cfg = {"model": DEFAULT_PARSER_MODEL}
|
||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
||||
return Parser(vocab, model, moves=arc_eager, **config)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model(arc_eager, tok2vec, vocab):
|
||||
model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"]
|
||||
cfg = {"model": DEFAULT_PARSER_MODEL}
|
||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
||||
model.attrs["resize_output"](model, arc_eager.n_moves)
|
||||
model.initialize()
|
||||
return model
|
||||
|
@ -68,7 +71,8 @@ def test_build_model(parser, vocab):
|
|||
"min_action_freq": 0,
|
||||
"update_with_oracle_cut_size": 100,
|
||||
}
|
||||
model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"]
|
||||
cfg = {"model": DEFAULT_PARSER_MODEL}
|
||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
||||
parser.model = Parser(vocab, model=model, moves=parser.moves, **config).model
|
||||
assert parser.model is not None
|
||||
|
||||
|
|
|
@ -33,7 +33,9 @@ def test_parser_root(en_tokenizer):
|
|||
assert t.dep != 0, t.text
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="The step_through API was removed (but should be brought back)")
|
||||
@pytest.mark.skip(
|
||||
reason="The step_through API was removed (but should be brought back)"
|
||||
)
|
||||
@pytest.mark.parametrize("text", ["Hello"])
|
||||
def test_parser_parse_one_word_sentence(en_tokenizer, en_parser, text):
|
||||
tokens = en_tokenizer(text)
|
||||
|
@ -47,7 +49,9 @@ def test_parser_parse_one_word_sentence(en_tokenizer, en_parser, text):
|
|||
assert doc[0].dep != 0
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="The step_through API was removed (but should be brought back)")
|
||||
@pytest.mark.skip(
|
||||
reason="The step_through API was removed (but should be brought back)"
|
||||
)
|
||||
def test_parser_initial(en_tokenizer, en_parser):
|
||||
text = "I ate the pizza with anchovies."
|
||||
# heads = [1, 0, 1, -2, -3, -1, -5]
|
||||
|
@ -92,7 +96,9 @@ def test_parser_merge_pp(en_tokenizer):
|
|||
assert doc[3].text == "occurs"
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="The step_through API was removed (but should be brought back)")
|
||||
@pytest.mark.skip(
|
||||
reason="The step_through API was removed (but should be brought back)"
|
||||
)
|
||||
def test_parser_arc_eager_finalize_state(en_tokenizer, en_parser):
|
||||
text = "a b c d e"
|
||||
|
||||
|
|
|
@ -21,7 +21,8 @@ def parser(vocab):
|
|||
"min_action_freq": 30,
|
||||
"update_with_oracle_cut_size": 100,
|
||||
}
|
||||
model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"]
|
||||
cfg = {"model": DEFAULT_PARSER_MODEL}
|
||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
||||
parser = DependencyParser(vocab, model, **config)
|
||||
parser.cfg["token_vector_width"] = 4
|
||||
parser.cfg["hidden_width"] = 32
|
||||
|
|
|
@ -28,7 +28,9 @@ def test_parser_sentence_space(en_tokenizer):
|
|||
assert len(list(doc.sents)) == 2
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="The step_through API was removed (but should be brought back)")
|
||||
@pytest.mark.skip(
|
||||
reason="The step_through API was removed (but should be brought back)"
|
||||
)
|
||||
def test_parser_space_attachment_leading(en_tokenizer, en_parser):
|
||||
text = "\t \n This is a sentence ."
|
||||
heads = [1, 1, 0, 1, -2, -3]
|
||||
|
@ -44,7 +46,9 @@ def test_parser_space_attachment_leading(en_tokenizer, en_parser):
|
|||
assert stepwise.stack == set([2])
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="The step_through API was removed (but should be brought back)")
|
||||
@pytest.mark.skip(
|
||||
reason="The step_through API was removed (but should be brought back)"
|
||||
)
|
||||
def test_parser_space_attachment_intermediate_trailing(en_tokenizer, en_parser):
|
||||
text = "This is \t a \t\n \n sentence . \n\n \n"
|
||||
heads = [1, 0, -1, 2, -1, -4, -5, -1]
|
||||
|
@ -64,7 +68,9 @@ def test_parser_space_attachment_intermediate_trailing(en_tokenizer, en_parser):
|
|||
|
||||
|
||||
@pytest.mark.parametrize("text,length", [(["\n"], 1), (["\n", "\t", "\n\n", "\t"], 4)])
|
||||
@pytest.mark.skip(reason="The step_through API was removed (but should be brought back)")
|
||||
@pytest.mark.skip(
|
||||
reason="The step_through API was removed (but should be brought back)"
|
||||
)
|
||||
def test_parser_space_attachment_space(en_tokenizer, en_parser, text, length):
|
||||
doc = Doc(en_parser.vocab, words=text)
|
||||
assert len(doc) == length
|
||||
|
|
|
@ -117,7 +117,9 @@ def test_overfitting_IO():
|
|||
assert cats2["POSITIVE"] + cats2["NEGATIVE"] == pytest.approx(1.0, 0.1)
|
||||
|
||||
# Test scoring
|
||||
scores = nlp.evaluate(train_examples, component_cfg={"scorer": {"positive_label": "POSITIVE"}})
|
||||
scores = nlp.evaluate(
|
||||
train_examples, component_cfg={"scorer": {"positive_label": "POSITIVE"}}
|
||||
)
|
||||
assert scores["cats_f"] == 1.0
|
||||
|
||||
|
||||
|
|
|
@ -201,7 +201,8 @@ def test_issue3345():
|
|||
"min_action_freq": 30,
|
||||
"update_with_oracle_cut_size": 100,
|
||||
}
|
||||
model = registry.make_from_config({"model": DEFAULT_NER_MODEL}, validate=True)["model"]
|
||||
cfg = {"model": DEFAULT_NER_MODEL}
|
||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
||||
ner = EntityRecognizer(doc.vocab, model, **config)
|
||||
# Add the OUT action. I wouldn't have thought this would be necessary...
|
||||
ner.moves.add_action(5, "")
|
||||
|
|
|
@ -20,7 +20,8 @@ def parser(en_vocab):
|
|||
"min_action_freq": 30,
|
||||
"update_with_oracle_cut_size": 100,
|
||||
}
|
||||
model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"]
|
||||
cfg = {"model": DEFAULT_PARSER_MODEL}
|
||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
||||
parser = DependencyParser(en_vocab, model, **config)
|
||||
parser.add_label("nsubj")
|
||||
return parser
|
||||
|
@ -33,14 +34,16 @@ def blank_parser(en_vocab):
|
|||
"min_action_freq": 30,
|
||||
"update_with_oracle_cut_size": 100,
|
||||
}
|
||||
model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"]
|
||||
cfg = {"model": DEFAULT_PARSER_MODEL}
|
||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
||||
parser = DependencyParser(en_vocab, model, **config)
|
||||
return parser
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def taggers(en_vocab):
|
||||
model = registry.make_from_config({"model": DEFAULT_TAGGER_MODEL}, validate=True)["model"]
|
||||
cfg = {"model": DEFAULT_TAGGER_MODEL}
|
||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
||||
tagger1 = Tagger(en_vocab, model, set_morphology=True)
|
||||
tagger2 = Tagger(en_vocab, model, set_morphology=True)
|
||||
return tagger1, tagger2
|
||||
|
@ -53,7 +56,8 @@ def test_serialize_parser_roundtrip_bytes(en_vocab, Parser):
|
|||
"min_action_freq": 0,
|
||||
"update_with_oracle_cut_size": 100,
|
||||
}
|
||||
model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"]
|
||||
cfg = {"model": DEFAULT_PARSER_MODEL}
|
||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
||||
parser = Parser(en_vocab, model, **config)
|
||||
new_parser = Parser(en_vocab, model, **config)
|
||||
new_parser = new_parser.from_bytes(parser.to_bytes(exclude=["vocab"]))
|
||||
|
@ -70,7 +74,8 @@ def test_serialize_parser_roundtrip_disk(en_vocab, Parser):
|
|||
"min_action_freq": 0,
|
||||
"update_with_oracle_cut_size": 100,
|
||||
}
|
||||
model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"]
|
||||
cfg = {"model": DEFAULT_PARSER_MODEL}
|
||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
||||
parser = Parser(en_vocab, model, **config)
|
||||
with make_tempdir() as d:
|
||||
file_path = d / "parser"
|
||||
|
@ -88,7 +93,6 @@ def test_to_from_bytes(parser, blank_parser):
|
|||
assert blank_parser.model is not True
|
||||
assert blank_parser.moves.n_moves != parser.moves.n_moves
|
||||
bytes_data = parser.to_bytes(exclude=["vocab"])
|
||||
|
||||
# the blank parser needs to be resized before we can call from_bytes
|
||||
blank_parser.model.attrs["resize_output"](blank_parser.model, parser.moves.n_moves)
|
||||
blank_parser.from_bytes(bytes_data)
|
||||
|
@ -104,7 +108,8 @@ def test_serialize_tagger_roundtrip_bytes(en_vocab, taggers):
|
|||
tagger1_b = tagger1.to_bytes()
|
||||
tagger1 = tagger1.from_bytes(tagger1_b)
|
||||
assert tagger1.to_bytes() == tagger1_b
|
||||
model = registry.make_from_config({"model": DEFAULT_TAGGER_MODEL}, validate=True)["model"]
|
||||
cfg = {"model": DEFAULT_TAGGER_MODEL}
|
||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
||||
new_tagger1 = Tagger(en_vocab, model).from_bytes(tagger1_b)
|
||||
new_tagger1_b = new_tagger1.to_bytes()
|
||||
assert len(new_tagger1_b) == len(tagger1_b)
|
||||
|
@ -118,7 +123,8 @@ def test_serialize_tagger_roundtrip_disk(en_vocab, taggers):
|
|||
file_path2 = d / "tagger2"
|
||||
tagger1.to_disk(file_path1)
|
||||
tagger2.to_disk(file_path2)
|
||||
model = registry.make_from_config({"model": DEFAULT_TAGGER_MODEL}, validate=True)["model"]
|
||||
cfg = {"model": DEFAULT_TAGGER_MODEL}
|
||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
||||
tagger1_d = Tagger(en_vocab, model, set_morphology=True).from_disk(file_path1)
|
||||
tagger2_d = Tagger(en_vocab, model, set_morphology=True).from_disk(file_path2)
|
||||
assert tagger1_d.to_bytes() == tagger2_d.to_bytes()
|
||||
|
@ -126,21 +132,22 @@ def test_serialize_tagger_roundtrip_disk(en_vocab, taggers):
|
|||
|
||||
def test_serialize_textcat_empty(en_vocab):
|
||||
# See issue #1105
|
||||
model = registry.make_from_config({"model": DEFAULT_TEXTCAT_MODEL}, validate=True)["model"]
|
||||
textcat = TextCategorizer(
|
||||
en_vocab, model, labels=["ENTITY", "ACTION", "MODIFIER"]
|
||||
)
|
||||
cfg = {"model": DEFAULT_TEXTCAT_MODEL}
|
||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
||||
textcat = TextCategorizer(en_vocab, model, labels=["ENTITY", "ACTION", "MODIFIER"])
|
||||
textcat.to_bytes(exclude=["vocab"])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("Parser", test_parsers)
|
||||
def test_serialize_pipe_exclude(en_vocab, Parser):
|
||||
model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"]
|
||||
cfg = {"model": DEFAULT_PARSER_MODEL}
|
||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
||||
config = {
|
||||
"learn_tokens": False,
|
||||
"min_action_freq": 0,
|
||||
"update_with_oracle_cut_size": 100,
|
||||
}
|
||||
|
||||
def get_new_parser():
|
||||
new_parser = Parser(en_vocab, model, **config)
|
||||
return new_parser
|
||||
|
@ -160,7 +167,8 @@ def test_serialize_pipe_exclude(en_vocab, Parser):
|
|||
|
||||
|
||||
def test_serialize_sentencerecognizer(en_vocab):
|
||||
model = registry.make_from_config({"model": DEFAULT_SENTER_MODEL}, validate=True)["model"]
|
||||
cfg = {"model": DEFAULT_SENTER_MODEL}
|
||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
||||
sr = SentenceRecognizer(en_vocab, model)
|
||||
sr_b = sr.to_bytes()
|
||||
sr_d = SentenceRecognizer(en_vocab, model).from_bytes(sr_b)
|
||||
|
|
|
@ -466,7 +466,6 @@ def test_iob_to_biluo():
|
|||
|
||||
|
||||
def test_roundtrip_docs_to_docbin(doc):
|
||||
nlp = English()
|
||||
text = doc.text
|
||||
idx = [t.idx for t in doc]
|
||||
tags = [t.tag_ for t in doc]
|
||||
|
|
|
@ -7,11 +7,11 @@ from spacy.vocab import Vocab
|
|||
def test_Example_init_requires_doc_objects():
|
||||
vocab = Vocab()
|
||||
with pytest.raises(TypeError):
|
||||
example = Example(None, None)
|
||||
Example(None, None)
|
||||
with pytest.raises(TypeError):
|
||||
example = Example(Doc(vocab, words=["hi"]), None)
|
||||
Example(Doc(vocab, words=["hi"]), None)
|
||||
with pytest.raises(TypeError):
|
||||
example = Example(None, Doc(vocab, words=["hi"]))
|
||||
Example(None, Doc(vocab, words=["hi"]))
|
||||
|
||||
|
||||
def test_Example_from_dict_basic():
|
||||
|
|
|
@ -105,7 +105,11 @@ def test_tokenization(sented_doc):
|
|||
assert scores["token_acc"] == 1.0
|
||||
|
||||
nlp = English()
|
||||
example.predicted = Doc(nlp.vocab, words=["One", "sentence.", "Two", "sentences.", "Three", "sentences."], spaces=[True, True, True, True, True, False])
|
||||
example.predicted = Doc(
|
||||
nlp.vocab,
|
||||
words=["One", "sentence.", "Two", "sentences.", "Three", "sentences."],
|
||||
spaces=[True, True, True, True, True, False],
|
||||
)
|
||||
example.predicted[1].is_sent_start = False
|
||||
scores = scorer.score([example])
|
||||
assert scores["token_acc"] == approx(0.66666666)
|
||||
|
|
193
spacy/util.py
193
spacy/util.py
|
@ -1,12 +1,13 @@
|
|||
from typing import List, Union, Dict, Any, Optional, Iterable, Callable, Tuple
|
||||
from typing import Iterator, TYPE_CHECKING
|
||||
from typing import Iterator, Type, Pattern, Sequence, TYPE_CHECKING
|
||||
from types import ModuleType
|
||||
import os
|
||||
import importlib
|
||||
import importlib.util
|
||||
import re
|
||||
from pathlib import Path
|
||||
import thinc
|
||||
from thinc.api import NumpyOps, get_current_ops, Adam, Config
|
||||
from thinc.api import NumpyOps, get_current_ops, Adam, Config, Optimizer
|
||||
import functools
|
||||
import itertools
|
||||
import numpy.random
|
||||
|
@ -49,6 +50,8 @@ from . import about
|
|||
if TYPE_CHECKING:
|
||||
# This lets us add type hints for mypy etc. without causing circular imports
|
||||
from .language import Language # noqa: F401
|
||||
from .tokens import Doc, Span # noqa: F401
|
||||
from .vocab import Vocab # noqa: F401
|
||||
|
||||
|
||||
_PRINT_ENV = False
|
||||
|
@ -102,12 +105,12 @@ class SimpleFrozenDict(dict):
|
|||
raise NotImplementedError(self.error)
|
||||
|
||||
|
||||
def set_env_log(value):
|
||||
def set_env_log(value: bool) -> None:
|
||||
global _PRINT_ENV
|
||||
_PRINT_ENV = value
|
||||
|
||||
|
||||
def lang_class_is_loaded(lang):
|
||||
def lang_class_is_loaded(lang: str) -> bool:
|
||||
"""Check whether a Language class is already loaded. Language classes are
|
||||
loaded lazily, to avoid expensive setup code associated with the language
|
||||
data.
|
||||
|
@ -118,7 +121,7 @@ def lang_class_is_loaded(lang):
|
|||
return lang in registry.languages
|
||||
|
||||
|
||||
def get_lang_class(lang):
|
||||
def get_lang_class(lang: str) -> "Language":
|
||||
"""Import and load a Language class.
|
||||
|
||||
lang (str): Two-letter language code, e.g. 'en'.
|
||||
|
@ -136,7 +139,7 @@ def get_lang_class(lang):
|
|||
return registry.languages.get(lang)
|
||||
|
||||
|
||||
def set_lang_class(name, cls):
|
||||
def set_lang_class(name: str, cls: Type["Language"]) -> None:
|
||||
"""Set a custom Language class name that can be loaded via get_lang_class.
|
||||
|
||||
name (str): Name of Language class.
|
||||
|
@ -145,10 +148,10 @@ def set_lang_class(name, cls):
|
|||
registry.languages.register(name, func=cls)
|
||||
|
||||
|
||||
def ensure_path(path):
|
||||
def ensure_path(path: Any) -> Any:
|
||||
"""Ensure string is converted to a Path.
|
||||
|
||||
path: Anything. If string, it's converted to Path.
|
||||
path (Any): Anything. If string, it's converted to Path.
|
||||
RETURNS: Path or original argument.
|
||||
"""
|
||||
if isinstance(path, str):
|
||||
|
@ -157,7 +160,7 @@ def ensure_path(path):
|
|||
return path
|
||||
|
||||
|
||||
def load_language_data(path):
|
||||
def load_language_data(path: Union[str, Path]) -> Union[dict, list]:
|
||||
"""Load JSON language data using the given path as a base. If the provided
|
||||
path isn't present, will attempt to load a gzipped version before giving up.
|
||||
|
||||
|
@ -173,7 +176,12 @@ def load_language_data(path):
|
|||
raise ValueError(Errors.E160.format(path=path))
|
||||
|
||||
|
||||
def get_module_path(module):
|
||||
def get_module_path(module: ModuleType) -> Path:
|
||||
"""Get the path of a Python module.
|
||||
|
||||
module (ModuleType): The Python module.
|
||||
RETURNS (Path): The path.
|
||||
"""
|
||||
if not hasattr(module, "__module__"):
|
||||
raise ValueError(Errors.E169.format(module=repr(module)))
|
||||
return Path(sys.modules[module.__module__].__file__).parent
|
||||
|
@ -183,7 +191,7 @@ def load_model(
|
|||
name: Union[str, Path],
|
||||
disable: Iterable[str] = tuple(),
|
||||
component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(),
|
||||
):
|
||||
) -> "Language":
|
||||
"""Load a model from a package or data path.
|
||||
|
||||
name (str): Package name or model path.
|
||||
|
@ -209,7 +217,7 @@ def load_model_from_package(
|
|||
name: str,
|
||||
disable: Iterable[str] = tuple(),
|
||||
component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(),
|
||||
):
|
||||
) -> "Language":
|
||||
"""Load a model from an installed package."""
|
||||
cls = importlib.import_module(name)
|
||||
return cls.load(disable=disable, component_cfg=component_cfg)
|
||||
|
@ -220,7 +228,7 @@ def load_model_from_path(
|
|||
meta: Optional[Dict[str, Any]] = None,
|
||||
disable: Iterable[str] = tuple(),
|
||||
component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(),
|
||||
):
|
||||
) -> "Language":
|
||||
"""Load a model from a data directory path. Creates Language class with
|
||||
pipeline from config.cfg and then calls from_disk() with path."""
|
||||
if not model_path.exists():
|
||||
|
@ -269,7 +277,7 @@ def load_model_from_init_py(
|
|||
init_file: Union[Path, str],
|
||||
disable: Iterable[str] = tuple(),
|
||||
component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(),
|
||||
):
|
||||
) -> "Language":
|
||||
"""Helper function to use in the `load()` method of a model package's
|
||||
__init__.py.
|
||||
|
||||
|
@ -288,15 +296,15 @@ def load_model_from_init_py(
|
|||
)
|
||||
|
||||
|
||||
def get_installed_models():
|
||||
def get_installed_models() -> List[str]:
|
||||
"""List all model packages currently installed in the environment.
|
||||
|
||||
RETURNS (list): The string names of the models.
|
||||
RETURNS (List[str]): The string names of the models.
|
||||
"""
|
||||
return list(registry.models.get_all().keys())
|
||||
|
||||
|
||||
def get_package_version(name):
|
||||
def get_package_version(name: str) -> Optional[str]:
|
||||
"""Get the version of an installed package. Typically used to get model
|
||||
package versions.
|
||||
|
||||
|
@ -309,7 +317,9 @@ def get_package_version(name):
|
|||
return None
|
||||
|
||||
|
||||
def is_compatible_version(version, constraint, prereleases=True):
|
||||
def is_compatible_version(
|
||||
version: str, constraint: str, prereleases: bool = True
|
||||
) -> Optional[bool]:
|
||||
"""Check if a version (e.g. "2.0.0") is compatible given a version
|
||||
constraint (e.g. ">=1.9.0,<2.2.1"). If the constraint is a specific version,
|
||||
it's interpreted as =={version}.
|
||||
|
@ -333,7 +343,9 @@ def is_compatible_version(version, constraint, prereleases=True):
|
|||
return version in spec
|
||||
|
||||
|
||||
def is_unconstrained_version(constraint, prereleases=True):
|
||||
def is_unconstrained_version(
|
||||
constraint: str, prereleases: bool = True
|
||||
) -> Optional[bool]:
|
||||
# We have an exact version, this is the ultimate constrained version
|
||||
if constraint[0].isdigit():
|
||||
return False
|
||||
|
@ -358,7 +370,7 @@ def is_unconstrained_version(constraint, prereleases=True):
|
|||
return True
|
||||
|
||||
|
||||
def get_model_version_range(spacy_version):
|
||||
def get_model_version_range(spacy_version: str) -> str:
|
||||
"""Generate a version range like >=1.2.3,<1.3.0 based on a given spaCy
|
||||
version. Models are always compatible across patch versions but not
|
||||
across minor or major versions.
|
||||
|
@ -367,7 +379,7 @@ def get_model_version_range(spacy_version):
|
|||
return f">={spacy_version},<{release[0]}.{release[1] + 1}.0"
|
||||
|
||||
|
||||
def get_base_version(version):
|
||||
def get_base_version(version: str) -> str:
|
||||
"""Generate the base version without any prerelease identifiers.
|
||||
|
||||
version (str): The version, e.g. "3.0.0.dev1".
|
||||
|
@ -376,11 +388,11 @@ def get_base_version(version):
|
|||
return Version(version).base_version
|
||||
|
||||
|
||||
def get_model_meta(path):
|
||||
def get_model_meta(path: Union[str, Path]) -> Dict[str, Any]:
|
||||
"""Get model meta.json from a directory path and validate its contents.
|
||||
|
||||
path (str / Path): Path to model directory.
|
||||
RETURNS (dict): The model's meta data.
|
||||
RETURNS (Dict[str, Any]): The model's meta data.
|
||||
"""
|
||||
model_path = ensure_path(path)
|
||||
if not model_path.exists():
|
||||
|
@ -412,7 +424,7 @@ def get_model_meta(path):
|
|||
return meta
|
||||
|
||||
|
||||
def is_package(name):
|
||||
def is_package(name: str) -> bool:
|
||||
"""Check if string maps to a package installed via pip.
|
||||
|
||||
name (str): Name of package.
|
||||
|
@ -425,7 +437,7 @@ def is_package(name):
|
|||
return False
|
||||
|
||||
|
||||
def get_package_path(name):
|
||||
def get_package_path(name: str) -> Path:
|
||||
"""Get the path to an installed package.
|
||||
|
||||
name (str): Package name.
|
||||
|
@ -495,7 +507,7 @@ def working_dir(path: Union[str, Path]) -> None:
|
|||
|
||||
|
||||
@contextmanager
|
||||
def make_tempdir():
|
||||
def make_tempdir() -> None:
|
||||
"""Execute a block in a temporary directory and remove the directory and
|
||||
its contents at the end of the with block.
|
||||
|
||||
|
@ -518,7 +530,7 @@ def is_cwd(path: Union[Path, str]) -> bool:
|
|||
return str(Path(path).resolve()).lower() == str(Path.cwd().resolve()).lower()
|
||||
|
||||
|
||||
def is_in_jupyter():
|
||||
def is_in_jupyter() -> bool:
|
||||
"""Check if user is running spaCy from a Jupyter notebook by detecting the
|
||||
IPython kernel. Mainly used for the displaCy visualizer.
|
||||
RETURNS (bool): True if in Jupyter, False if not.
|
||||
|
@ -548,7 +560,9 @@ def get_object_name(obj: Any) -> str:
|
|||
return repr(obj)
|
||||
|
||||
|
||||
def get_cuda_stream(require=False, non_blocking=True):
|
||||
def get_cuda_stream(
|
||||
require: bool = False, non_blocking: bool = True
|
||||
) -> Optional[CudaStream]:
|
||||
ops = get_current_ops()
|
||||
if CudaStream is None:
|
||||
return None
|
||||
|
@ -567,7 +581,7 @@ def get_async(stream, numpy_array):
|
|||
return array
|
||||
|
||||
|
||||
def env_opt(name, default=None):
|
||||
def env_opt(name: str, default: Optional[Any] = None) -> Optional[Any]:
|
||||
if type(default) is float:
|
||||
type_convert = float
|
||||
else:
|
||||
|
@ -588,7 +602,7 @@ def env_opt(name, default=None):
|
|||
return default
|
||||
|
||||
|
||||
def read_regex(path):
|
||||
def read_regex(path: Union[str, Path]) -> Pattern:
|
||||
path = ensure_path(path)
|
||||
with path.open(encoding="utf8") as file_:
|
||||
entries = file_.read().split("\n")
|
||||
|
@ -598,37 +612,40 @@ def read_regex(path):
|
|||
return re.compile(expression)
|
||||
|
||||
|
||||
def compile_prefix_regex(entries):
|
||||
def compile_prefix_regex(entries: Iterable[Union[str, Pattern]]) -> Pattern:
|
||||
"""Compile a sequence of prefix rules into a regex object.
|
||||
|
||||
entries (tuple): The prefix rules, e.g. spacy.lang.punctuation.TOKENIZER_PREFIXES.
|
||||
RETURNS (regex object): The regex object. to be used for Tokenizer.prefix_search.
|
||||
entries (Iterable[Union[str, Pattern]]): The prefix rules, e.g.
|
||||
spacy.lang.punctuation.TOKENIZER_PREFIXES.
|
||||
RETURNS (Pattern): The regex object. to be used for Tokenizer.prefix_search.
|
||||
"""
|
||||
expression = "|".join(["^" + piece for piece in entries if piece.strip()])
|
||||
return re.compile(expression)
|
||||
|
||||
|
||||
def compile_suffix_regex(entries):
|
||||
def compile_suffix_regex(entries: Iterable[Union[str, Pattern]]) -> Pattern:
|
||||
"""Compile a sequence of suffix rules into a regex object.
|
||||
|
||||
entries (tuple): The suffix rules, e.g. spacy.lang.punctuation.TOKENIZER_SUFFIXES.
|
||||
RETURNS (regex object): The regex object. to be used for Tokenizer.suffix_search.
|
||||
entries (Iterable[Union[str, Pattern]]): The suffix rules, e.g.
|
||||
spacy.lang.punctuation.TOKENIZER_SUFFIXES.
|
||||
RETURNS (Pattern): The regex object. to be used for Tokenizer.suffix_search.
|
||||
"""
|
||||
expression = "|".join([piece + "$" for piece in entries if piece.strip()])
|
||||
return re.compile(expression)
|
||||
|
||||
|
||||
def compile_infix_regex(entries):
|
||||
def compile_infix_regex(entries: Iterable[Union[str, Pattern]]) -> Pattern:
|
||||
"""Compile a sequence of infix rules into a regex object.
|
||||
|
||||
entries (tuple): The infix rules, e.g. spacy.lang.punctuation.TOKENIZER_INFIXES.
|
||||
entries (Iterable[Union[str, Pattern]]): The infix rules, e.g.
|
||||
spacy.lang.punctuation.TOKENIZER_INFIXES.
|
||||
RETURNS (regex object): The regex object. to be used for Tokenizer.infix_finditer.
|
||||
"""
|
||||
expression = "|".join([piece for piece in entries if piece.strip()])
|
||||
return re.compile(expression)
|
||||
|
||||
|
||||
def add_lookups(default_func, *lookups):
|
||||
def add_lookups(default_func: Callable[[str], Any], *lookups) -> Callable[[str], Any]:
|
||||
"""Extend an attribute function with special cases. If a word is in the
|
||||
lookups, the value is returned. Otherwise the previous function is used.
|
||||
|
||||
|
@ -641,19 +658,23 @@ def add_lookups(default_func, *lookups):
|
|||
return functools.partial(_get_attr_unless_lookup, default_func, lookups)
|
||||
|
||||
|
||||
def _get_attr_unless_lookup(default_func, lookups, string):
|
||||
def _get_attr_unless_lookup(
|
||||
default_func: Callable[[str], Any], lookups: Dict[str, Any], string: str
|
||||
) -> Any:
|
||||
for lookup in lookups:
|
||||
if string in lookup:
|
||||
return lookup[string]
|
||||
return default_func(string)
|
||||
|
||||
|
||||
def update_exc(base_exceptions, *addition_dicts):
|
||||
def update_exc(
|
||||
base_exceptions: Dict[str, List[dict]], *addition_dicts
|
||||
) -> Dict[str, List[dict]]:
|
||||
"""Update and validate tokenizer exceptions. Will overwrite exceptions.
|
||||
|
||||
base_exceptions (dict): Base exceptions.
|
||||
*addition_dicts (dict): Exceptions to add to the base dict, in order.
|
||||
RETURNS (dict): Combined tokenizer exceptions.
|
||||
base_exceptions (Dict[str, List[dict]]): Base exceptions.
|
||||
*addition_dicts (Dict[str, List[dict]]): Exceptions to add to the base dict, in order.
|
||||
RETURNS (Dict[str, List[dict]]): Combined tokenizer exceptions.
|
||||
"""
|
||||
exc = dict(base_exceptions)
|
||||
for additions in addition_dicts:
|
||||
|
@ -668,14 +689,16 @@ def update_exc(base_exceptions, *addition_dicts):
|
|||
return exc
|
||||
|
||||
|
||||
def expand_exc(excs, search, replace):
|
||||
def expand_exc(
|
||||
excs: Dict[str, List[dict]], search: str, replace: str
|
||||
) -> Dict[str, List[dict]]:
|
||||
"""Find string in tokenizer exceptions, duplicate entry and replace string.
|
||||
For example, to add additional versions with typographic apostrophes.
|
||||
|
||||
excs (dict): Tokenizer exceptions.
|
||||
excs (Dict[str, List[dict]]): Tokenizer exceptions.
|
||||
search (str): String to find and replace.
|
||||
replace (str): Replacement.
|
||||
RETURNS (dict): Combined tokenizer exceptions.
|
||||
RETURNS (Dict[str, List[dict]]): Combined tokenizer exceptions.
|
||||
"""
|
||||
|
||||
def _fix_token(token, search, replace):
|
||||
|
@ -692,7 +715,9 @@ def expand_exc(excs, search, replace):
|
|||
return new_excs
|
||||
|
||||
|
||||
def normalize_slice(length, start, stop, step=None):
|
||||
def normalize_slice(
|
||||
length: int, start: int, stop: int, step: Optional[int] = None
|
||||
) -> Tuple[int, int]:
|
||||
if not (step is None or step == 1):
|
||||
raise ValueError(Errors.E057)
|
||||
if start is None:
|
||||
|
@ -708,7 +733,9 @@ def normalize_slice(length, start, stop, step=None):
|
|||
return start, stop
|
||||
|
||||
|
||||
def minibatch(items, size=8):
|
||||
def minibatch(
|
||||
items: Iterable[Any], size: Union[Iterator[int], int] = 8
|
||||
) -> Iterator[Any]:
|
||||
"""Iterate over batches of items. `size` may be an iterator,
|
||||
so that batch-size can vary on each step.
|
||||
"""
|
||||
|
@ -725,7 +752,12 @@ def minibatch(items, size=8):
|
|||
yield list(batch)
|
||||
|
||||
|
||||
def minibatch_by_padded_size(docs, size, buffer=256, discard_oversize=False):
|
||||
def minibatch_by_padded_size(
|
||||
docs: Iterator["Doc"],
|
||||
size: Union[Iterator[int], int],
|
||||
buffer: int = 256,
|
||||
discard_oversize: bool = False,
|
||||
) -> Iterator[Iterator["Doc"]]:
|
||||
if isinstance(size, int):
|
||||
size_ = itertools.repeat(size)
|
||||
else:
|
||||
|
@ -742,7 +774,7 @@ def minibatch_by_padded_size(docs, size, buffer=256, discard_oversize=False):
|
|||
yield subbatch
|
||||
|
||||
|
||||
def _batch_by_length(seqs, max_words):
|
||||
def _batch_by_length(seqs: Sequence[Any], max_words: int) -> List[List[Any]]:
|
||||
"""Given a list of sequences, return a batched list of indices into the
|
||||
list, where the batches are grouped by length, in descending order.
|
||||
|
||||
|
@ -783,14 +815,12 @@ def minibatch_by_words(docs, size, tolerance=0.2, discard_oversize=False):
|
|||
size_ = iter(size)
|
||||
else:
|
||||
size_ = size
|
||||
|
||||
target_size = next(size_)
|
||||
tol_size = target_size * tolerance
|
||||
batch = []
|
||||
overflow = []
|
||||
batch_size = 0
|
||||
overflow_size = 0
|
||||
|
||||
for doc in docs:
|
||||
if isinstance(doc, Example):
|
||||
n_words = len(doc.reference)
|
||||
|
@ -803,17 +833,14 @@ def minibatch_by_words(docs, size, tolerance=0.2, discard_oversize=False):
|
|||
if n_words > target_size + tol_size:
|
||||
if not discard_oversize:
|
||||
yield [doc]
|
||||
|
||||
# add the example to the current batch if there's no overflow yet and it still fits
|
||||
elif overflow_size == 0 and (batch_size + n_words) <= target_size:
|
||||
batch.append(doc)
|
||||
batch_size += n_words
|
||||
|
||||
# add the example to the overflow buffer if it fits in the tolerance margin
|
||||
elif (batch_size + overflow_size + n_words) <= (target_size + tol_size):
|
||||
overflow.append(doc)
|
||||
overflow_size += n_words
|
||||
|
||||
# yield the previous batch and start a new one. The new one gets the overflow examples.
|
||||
else:
|
||||
if batch:
|
||||
|
@ -824,17 +851,14 @@ def minibatch_by_words(docs, size, tolerance=0.2, discard_oversize=False):
|
|||
batch_size = overflow_size
|
||||
overflow = []
|
||||
overflow_size = 0
|
||||
|
||||
# this example still fits
|
||||
if (batch_size + n_words) <= target_size:
|
||||
batch.append(doc)
|
||||
batch_size += n_words
|
||||
|
||||
# this example fits in overflow
|
||||
elif (batch_size + n_words) <= (target_size + tol_size):
|
||||
overflow.append(doc)
|
||||
overflow_size += n_words
|
||||
|
||||
# this example does not fit with the previous overflow: start another new batch
|
||||
else:
|
||||
if batch:
|
||||
|
@ -843,20 +867,19 @@ def minibatch_by_words(docs, size, tolerance=0.2, discard_oversize=False):
|
|||
tol_size = target_size * tolerance
|
||||
batch = [doc]
|
||||
batch_size = n_words
|
||||
|
||||
batch.extend(overflow)
|
||||
if batch:
|
||||
yield batch
|
||||
|
||||
|
||||
def filter_spans(spans):
|
||||
def filter_spans(spans: Iterable["Span"]) -> List["Span"]:
|
||||
"""Filter a sequence of spans and remove duplicates or overlaps. Useful for
|
||||
creating named entities (where one token can only be part of one entity) or
|
||||
when merging spans with `Retokenizer.merge`. When spans overlap, the (first)
|
||||
longest span is preferred over shorter spans.
|
||||
|
||||
spans (iterable): The spans to filter.
|
||||
RETURNS (list): The filtered spans.
|
||||
spans (Iterable[Span]): The spans to filter.
|
||||
RETURNS (List[Span]): The filtered spans.
|
||||
"""
|
||||
get_sort_key = lambda span: (span.end - span.start, -span.start)
|
||||
sorted_spans = sorted(spans, key=get_sort_key, reverse=True)
|
||||
|
@ -871,15 +894,21 @@ def filter_spans(spans):
|
|||
return result
|
||||
|
||||
|
||||
def to_bytes(getters, exclude):
|
||||
def to_bytes(getters: Dict[str, Callable[[], bytes]], exclude: Iterable[str]) -> bytes:
|
||||
return srsly.msgpack_dumps(to_dict(getters, exclude))
|
||||
|
||||
|
||||
def from_bytes(bytes_data, setters, exclude):
|
||||
def from_bytes(
|
||||
bytes_data: bytes,
|
||||
setters: Dict[str, Callable[[bytes], Any]],
|
||||
exclude: Iterable[str],
|
||||
) -> None:
|
||||
return from_dict(srsly.msgpack_loads(bytes_data), setters, exclude)
|
||||
|
||||
|
||||
def to_dict(getters, exclude):
|
||||
def to_dict(
|
||||
getters: Dict[str, Callable[[], Any]], exclude: Iterable[str]
|
||||
) -> Dict[str, Any]:
|
||||
serialized = {}
|
||||
for key, getter in getters.items():
|
||||
# Split to support file names like meta.json
|
||||
|
@ -888,7 +917,11 @@ def to_dict(getters, exclude):
|
|||
return serialized
|
||||
|
||||
|
||||
def from_dict(msg, setters, exclude):
|
||||
def from_dict(
|
||||
msg: Dict[str, Any],
|
||||
setters: Dict[str, Callable[[Any], Any]],
|
||||
exclude: Iterable[str],
|
||||
) -> Dict[str, Any]:
|
||||
for key, setter in setters.items():
|
||||
# Split to support file names like meta.json
|
||||
if key.split(".")[0] not in exclude and key in msg:
|
||||
|
@ -896,7 +929,11 @@ def from_dict(msg, setters, exclude):
|
|||
return msg
|
||||
|
||||
|
||||
def to_disk(path, writers, exclude):
|
||||
def to_disk(
|
||||
path: Union[str, Path],
|
||||
writers: Dict[str, Callable[[Path], None]],
|
||||
exclude: Iterable[str],
|
||||
) -> Path:
|
||||
path = ensure_path(path)
|
||||
if not path.exists():
|
||||
path.mkdir()
|
||||
|
@ -907,7 +944,11 @@ def to_disk(path, writers, exclude):
|
|||
return path
|
||||
|
||||
|
||||
def from_disk(path, readers, exclude):
|
||||
def from_disk(
|
||||
path: Union[str, Path],
|
||||
readers: Dict[str, Callable[[Path], None]],
|
||||
exclude: Iterable[str],
|
||||
) -> Path:
|
||||
path = ensure_path(path)
|
||||
for key, reader in readers.items():
|
||||
# Split to support file names like meta.json
|
||||
|
@ -916,7 +957,7 @@ def from_disk(path, readers, exclude):
|
|||
return path
|
||||
|
||||
|
||||
def import_file(name, loc):
|
||||
def import_file(name: str, loc: Union[str, Path]) -> ModuleType:
|
||||
"""Import module from a file. Used to load models from a directory.
|
||||
|
||||
name (str): Name of module to load.
|
||||
|
@ -930,7 +971,7 @@ def import_file(name, loc):
|
|||
return module
|
||||
|
||||
|
||||
def minify_html(html):
|
||||
def minify_html(html: str) -> str:
|
||||
"""Perform a template-specific, rudimentary HTML minification for displaCy.
|
||||
Disclaimer: NOT a general-purpose solution, only removes indentation and
|
||||
newlines.
|
||||
|
@ -941,7 +982,7 @@ def minify_html(html):
|
|||
return html.strip().replace(" ", "").replace("\n", "")
|
||||
|
||||
|
||||
def escape_html(text):
|
||||
def escape_html(text: str) -> str:
|
||||
"""Replace <, >, &, " with their HTML encoded representation. Intended to
|
||||
prevent HTML errors in rendered displaCy markup.
|
||||
|
||||
|
@ -955,7 +996,9 @@ def escape_html(text):
|
|||
return text
|
||||
|
||||
|
||||
def get_words_and_spaces(words, text):
|
||||
def get_words_and_spaces(
|
||||
words: Iterable[str], text: str
|
||||
) -> Tuple[List[str], List[bool]]:
|
||||
if "".join("".join(words).split()) != "".join(text.split()):
|
||||
raise ValueError(Errors.E194.format(text=text, words=words))
|
||||
text_words = []
|
||||
|
@ -1103,7 +1146,7 @@ class DummyTokenizer:
|
|||
return self
|
||||
|
||||
|
||||
def link_vectors_to_models(vocab):
|
||||
def link_vectors_to_models(vocab: "Vocab") -> None:
|
||||
vectors = vocab.vectors
|
||||
if vectors.name is None:
|
||||
vectors.name = VECTORS_KEY
|
||||
|
@ -1119,7 +1162,7 @@ def link_vectors_to_models(vocab):
|
|||
VECTORS_KEY = "spacy_pretrained_vectors"
|
||||
|
||||
|
||||
def create_default_optimizer():
|
||||
def create_default_optimizer() -> Optimizer:
|
||||
learn_rate = env_opt("learn_rate", 0.001)
|
||||
beta1 = env_opt("optimizer_B1", 0.9)
|
||||
beta2 = env_opt("optimizer_B2", 0.999)
|
||||
|
|
Loading…
Reference in New Issue
Block a user