Tidy up, autoformat, add types
parent 71242327b2
commit e92df281ce
.gitignore (vendored)

@@ -51,6 +51,7 @@ env3.*/
 .denv
 .pypyenv
 .pytest_cache/
+.mypy_cache/

 # Distribution / packaging
 env/
@@ -108,3 +108,8 @@ exclude =
 [tool:pytest]
 markers =
     slow
+
+[mypy]
+ignore_missing_imports = True
+no_implicit_optional = True
+plugins = pydantic.mypy, thinc.mypy
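The new [mypy] block wires the pydantic and thinc plugins into type checking. As a rough illustration (the model below is invented for this note, it is not part of the commit), the pydantic.mypy plugin is what lets mypy understand the __init__ signature that pydantic generates from field annotations:

    from pydantic import BaseModel

    class DemoConfig(BaseModel):
        # hypothetical schema, used only to illustrate the plugin
        dropout: float = 0.2
        patience: int = 1600

    cfg = DemoConfig(dropout=0.1, patience=800)  # argument types checked by mypy
    # DemoConfig(dropout="high")  # flagged by mypy with the plugin enabled (and rejected by pydantic at runtime)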
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, Any, List, Union
 from enum import Enum
 from pathlib import Path
 from wasabi import Printer

@@ -66,10 +66,9 @@ def convert_cli(
     file_type = file_type.value
     input_path = Path(input_path)
     output_dir = "-" if output_dir == Path("-") else output_dir
-    cli_args = locals()
     silent = output_dir == "-"
     msg = Printer(no_print=silent)
-    verify_cli_args(msg, **cli_args)
+    verify_cli_args(msg, input_path, output_dir, file_type, converter, ner_map)
     converter = _get_converter(msg, converter, input_path)
     convert(
         input_path,

@@ -89,8 +88,8 @@ def convert_cli(


 def convert(
-    input_path: Path,
-    output_dir: Path,
+    input_path: Union[str, Path],
+    output_dir: Union[str, Path],
     *,
     file_type: str = "json",
     n_sents: int = 1,
@@ -102,13 +101,12 @@ def convert(
     ner_map: Optional[Path] = None,
     lang: Optional[str] = None,
     silent: bool = True,
-    msg: Optional[Path] = None,
+    msg: Optional[Printer],
 ) -> None:
     if not msg:
         msg = Printer(no_print=silent)
     ner_map = srsly.read_json(ner_map) if ner_map is not None else None
-
-    for input_loc in walk_directory(input_path):
+    for input_loc in walk_directory(Path(input_path)):
         input_data = input_loc.open("r", encoding="utf-8").read()
         # Use converter function to convert data
         func = CONVERTERS[converter]

@@ -140,14 +138,14 @@ def convert(
         msg.good(f"Generated output file ({len(docs)} documents): {output_file}")


-def _print_docs_to_stdout(data, output_type):
+def _print_docs_to_stdout(data: Any, output_type: str) -> None:
     if output_type == "json":
         srsly.write_json("-", data)
     else:
         sys.stdout.buffer.write(data)


-def _write_docs_to_file(data, output_file, output_type):
+def _write_docs_to_file(data: Any, output_file: Path, output_type: str) -> None:
     if not output_file.parent.exists():
         output_file.parent.mkdir(parents=True)
     if output_type == "json":
@@ -157,7 +155,7 @@ def _write_docs_to_file(data, output_file, output_type):
        file_.write(data)


-def autodetect_ner_format(input_data: str) -> str:
+def autodetect_ner_format(input_data: str) -> Optional[str]:
     # guess format from the first 20 lines
     lines = input_data.split("\n")[:20]
     format_guesses = {"ner": 0, "iob": 0}

@@ -176,7 +174,7 @@ def autodetect_ner_format(input_data: str) -> str:
     return None


-def walk_directory(path):
+def walk_directory(path: Path) -> List[Path]:
     if not path.is_dir():
         return [path]
     paths = [path]

@@ -196,31 +194,25 @@ def walk_directory(path):


 def verify_cli_args(
-    msg,
-    input_path,
-    output_dir,
-    file_type,
-    n_sents,
-    seg_sents,
-    model,
-    morphology,
-    merge_subtokens,
-    converter,
-    ner_map,
-    lang,
+    msg: Printer,
+    input_path: Union[str, Path],
+    output_dir: Union[str, Path],
+    file_type: FileTypes,
+    converter: str,
+    ner_map: Optional[Path],
 ):
     input_path = Path(input_path)
     if file_type not in FILE_TYPES_STDOUT and output_dir == "-":
-        # TODO: support msgpack via stdout in srsly?
         msg.fail(
-            f"Can't write .{file_type} data to stdout",
-            "Please specify an output directory.",
+            f"Can't write .{file_type} data to stdout. Please specify an output directory.",
             exits=1,
         )
     if not input_path.exists():
         msg.fail("Input file not found", input_path, exits=1)
     if output_dir != "-" and not Path(output_dir).exists():
         msg.fail("Output directory not found", output_dir, exits=1)
+    if ner_map is not None and not Path(ner_map).exists():
+        msg.fail("NER map not found", ner_map, exits=1)
     if input_path.is_dir():
         input_locs = walk_directory(input_path)
         if len(input_locs) == 0:

@@ -229,10 +221,8 @@ def verify_cli_args(
         if len(file_types) >= 2:
             file_types = ",".join(file_types)
             msg.fail("All input files must be same type", file_types, exits=1)
-    converter = _get_converter(msg, converter, input_path)
-    if converter not in CONVERTERS:
+    if converter != "auto" and converter not in CONVERTERS:
         msg.fail(f"Can't find converter for {converter}", exits=1)
-    return converter


 def _get_converter(msg, converter, input_path):
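For orientation, a sketch of calling the retyped convert() helper directly. The paths are made up, the import location is assumed to be spacy.cli.convert, and msg is passed explicitly because the signature above declares it without a default:

    from pathlib import Path
    from wasabi import Printer
    from spacy.cli.convert import convert  # assumed import path

    # Plain strings and Path objects both satisfy the new Union[str, Path] annotations.
    convert(
        "corpus/train.conllu",      # hypothetical input file
        Path("corpus/converted"),   # hypothetical output directory
        converter="conllu",
        msg=Printer(no_print=False),
    )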
@@ -82,11 +82,11 @@ def evaluate(
         "NER P": "ents_p",
         "NER R": "ents_r",
         "NER F": "ents_f",
-        "Textcat AUC": 'textcat_macro_auc',
-        "Textcat F": 'textcat_macro_f',
-        "Sent P": 'sents_p',
-        "Sent R": 'sents_r',
-        "Sent F": 'sents_f',
+        "Textcat AUC": "textcat_macro_auc",
+        "Textcat F": "textcat_macro_f",
+        "Sent P": "sents_p",
+        "Sent R": "sents_r",
+        "Sent F": "sents_f",
     }
     results = {}
     for metric, key in metrics.items():
@@ -14,7 +14,6 @@ import typer
 from ._util import app, Arg, Opt, parse_config_overrides, show_validation_error
 from ._util import import_code
 from ..gold import Corpus, Example
-from ..lookups import Lookups
 from ..language import Language
 from .. import util
 from ..errors import Errors
@@ -1,12 +1,5 @@
-"""
-Helpers for Python and platform compatibility. To distinguish them from
-the builtin functions, replacement functions are suffixed with an underscore,
-e.g. `unicode_`.
-
-DOCS: https://spacy.io/api/top-level#compat
-"""
+"""Helpers for Python and platform compatibility."""
 import sys
-
 from thinc.util import copy_array

 try:

@@ -40,21 +33,3 @@ copy_array = copy_array
 is_windows = sys.platform.startswith("win")
 is_linux = sys.platform.startswith("linux")
 is_osx = sys.platform == "darwin"
-
-
-def is_config(windows=None, linux=None, osx=None, **kwargs):
-    """Check if a specific configuration of Python version and operating system
-    matches the user's setup. Mostly used to display targeted error messages.
-
-    windows (bool): spaCy is executed on Windows.
-    linux (bool): spaCy is executed on Linux.
-    osx (bool): spaCy is executed on OS X or macOS.
-    RETURNS (bool): Whether the configuration matches the user's platform.
-
-    DOCS: https://spacy.io/api/top-level#compat.is_config
-    """
-    return (
-        windows in (None, is_windows)
-        and linux in (None, is_linux)
-        and osx in (None, is_osx)
-    )
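With is_config() gone, callers that branched on the platform can use the module-level booleans compat still defines. A minimal migration sketch (the calling code is hypothetical):

    from spacy import compat

    # Before (removed above): if is_config(windows=True): ...
    # After: read the platform flags directly.
    if compat.is_windows:
        line_ending = "\r\n"
    elif compat.is_linux or compat.is_osx:
        line_ending = "\n"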
@@ -4,6 +4,7 @@ spaCy's built in visualization suite for dependencies and named entities.
 DOCS: https://spacy.io/api/top-level#displacy
 USAGE: https://spacy.io/usage/visualizers
 """
+from typing import Union, Iterable, Optional, Dict, Any, Callable
 import warnings

 from .render import DependencyRenderer, EntityRenderer

@@ -17,11 +18,17 @@ RENDER_WRAPPER = None


 def render(
-    docs, style="dep", page=False, minify=False, jupyter=None, options={}, manual=False
-):
+    docs: Union[Iterable[Doc], Doc],
+    style: str = "dep",
+    page: bool = False,
+    minify: bool = False,
+    jupyter: Optional[bool] = None,
+    options: Dict[str, Any] = {},
+    manual: bool = False,
+) -> str:
     """Render displaCy visualisation.

-    docs (list or Doc): Document(s) to visualise.
+    docs (Union[Iterable[Doc], Doc]): Document(s) to visualise.
     style (str): Visualisation style, 'dep' or 'ent'.
     page (bool): Render markup as full HTML page.
     minify (bool): Minify HTML markup.
@@ -44,8 +51,8 @@ def render(
         docs = [obj if not isinstance(obj, Span) else obj.as_doc() for obj in docs]
     if not all(isinstance(obj, (Doc, Span, dict)) for obj in docs):
         raise ValueError(Errors.E096)
-    renderer, converter = factories[style]
-    renderer = renderer(options=options)
+    renderer_func, converter = factories[style]
+    renderer = renderer_func(options=options)
     parsed = [converter(doc, options) for doc in docs] if not manual else docs
     _html["parsed"] = renderer.render(parsed, page=page, minify=minify).strip()
     html = _html["parsed"]

@@ -61,15 +68,15 @@ def render(


 def serve(
-    docs,
-    style="dep",
-    page=True,
-    minify=False,
-    options={},
-    manual=False,
-    port=5000,
-    host="0.0.0.0",
-):
+    docs: Union[Iterable[Doc], Doc],
+    style: str = "dep",
+    page: bool = True,
+    minify: bool = False,
+    options: Dict[str, Any] = {},
+    manual: bool = False,
+    port: int = 5000,
+    host: str = "0.0.0.0",
+) -> None:
     """Serve displaCy visualisation.

     docs (list or Doc): Document(s) to visualise.
|
||||||
|
|
||||||
if is_in_jupyter():
|
if is_in_jupyter():
|
||||||
warnings.warn(Warnings.W011)
|
warnings.warn(Warnings.W011)
|
||||||
|
|
||||||
render(docs, style=style, page=page, minify=minify, options=options, manual=manual)
|
render(docs, style=style, page=page, minify=minify, options=options, manual=manual)
|
||||||
httpd = simple_server.make_server(host, port, app)
|
httpd = simple_server.make_server(host, port, app)
|
||||||
print(f"\nUsing the '{style}' visualizer")
|
print(f"\nUsing the '{style}' visualizer")
|
||||||
|
@ -102,14 +108,13 @@ def serve(
|
||||||
|
|
||||||
|
|
||||||
def app(environ, start_response):
|
def app(environ, start_response):
|
||||||
# Headers and status need to be bytes in Python 2, see #1227
|
|
||||||
headers = [("Content-type", "text/html; charset=utf-8")]
|
headers = [("Content-type", "text/html; charset=utf-8")]
|
||||||
start_response("200 OK", headers)
|
start_response("200 OK", headers)
|
||||||
res = _html["parsed"].encode(encoding="utf-8")
|
res = _html["parsed"].encode(encoding="utf-8")
|
||||||
return [res]
|
return [res]
|
||||||
|
|
||||||
|
|
||||||
def parse_deps(orig_doc, options={}):
|
def parse_deps(orig_doc: Doc, options: Dict[str, Any] = {}) -> Dict[str, Any]:
|
||||||
"""Generate dependency parse in {'words': [], 'arcs': []} format.
|
"""Generate dependency parse in {'words': [], 'arcs': []} format.
|
||||||
|
|
||||||
doc (Doc): Document do parse.
|
doc (Doc): Document do parse.
|
||||||
|
@ -152,7 +157,6 @@ def parse_deps(orig_doc, options={}):
|
||||||
}
|
}
|
||||||
for w in doc
|
for w in doc
|
||||||
]
|
]
|
||||||
|
|
||||||
arcs = []
|
arcs = []
|
||||||
for word in doc:
|
for word in doc:
|
||||||
if word.i < word.head.i:
|
if word.i < word.head.i:
|
||||||
|
@ -171,7 +175,7 @@ def parse_deps(orig_doc, options={}):
|
||||||
return {"words": words, "arcs": arcs, "settings": get_doc_settings(orig_doc)}
|
return {"words": words, "arcs": arcs, "settings": get_doc_settings(orig_doc)}
|
||||||
|
|
||||||
|
|
||||||
def parse_ents(doc, options={}):
|
def parse_ents(doc: Doc, options: Dict[str, Any] = {}) -> Dict[str, Any]:
|
||||||
"""Generate named entities in [{start: i, end: i, label: 'label'}] format.
|
"""Generate named entities in [{start: i, end: i, label: 'label'}] format.
|
||||||
|
|
||||||
doc (Doc): Document do parse.
|
doc (Doc): Document do parse.
|
||||||
|
@ -188,7 +192,7 @@ def parse_ents(doc, options={}):
|
||||||
return {"text": doc.text, "ents": ents, "title": title, "settings": settings}
|
return {"text": doc.text, "ents": ents, "title": title, "settings": settings}
|
||||||
|
|
||||||
|
|
||||||
def set_render_wrapper(func):
|
def set_render_wrapper(func: Callable[[str], str]) -> None:
|
||||||
"""Set an optional wrapper function that is called around the generated
|
"""Set an optional wrapper function that is called around the generated
|
||||||
HTML markup on displacy.render. This can be used to allow integration into
|
HTML markup on displacy.render. This can be used to allow integration into
|
||||||
other platforms, similar to Jupyter Notebooks that require functions to be
|
other platforms, similar to Jupyter Notebooks that require functions to be
|
||||||
|
@ -205,7 +209,7 @@ def set_render_wrapper(func):
|
||||||
RENDER_WRAPPER = func
|
RENDER_WRAPPER = func
|
||||||
|
|
||||||
|
|
||||||
def get_doc_settings(doc):
|
def get_doc_settings(doc: Doc) -> Dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"lang": doc.lang_,
|
"lang": doc.lang_,
|
||||||
"direction": doc.vocab.writing_system.get("direction", "ltr"),
|
"direction": doc.vocab.writing_system.get("direction", "ltr"),
|
||||||
|
|
|
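The retyped render() and serve() keep their call shape. A small usage sketch (the pipeline name and text are placeholders, and loading a packaged model is assumed to work in your environment); the options dict is the same one the renderers in the next file consume, including the "colors" override:

    import spacy
    from spacy import displacy

    nlp = spacy.load("en_core_web_sm")  # placeholder pipeline name
    doc = nlp("Apple is looking at buying a U.K. startup.")
    html = displacy.render(doc, style="ent", options={"colors": {"ORG": "#7aecec"}})
    # displacy.serve(doc, style="dep", port=5000)  # serves via the app() handler above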
@@ -1,19 +1,36 @@
+from typing import Dict, Any, List, Optional, Union
 import uuid

-from .templates import (
-    TPL_DEP_SVG,
-    TPL_DEP_WORDS,
-    TPL_DEP_WORDS_LEMMA,
-    TPL_DEP_ARCS,
-    TPL_ENTS,
-)
+from .templates import TPL_DEP_SVG, TPL_DEP_WORDS, TPL_DEP_WORDS_LEMMA, TPL_DEP_ARCS
 from .templates import TPL_ENT, TPL_ENT_RTL, TPL_FIGURE, TPL_TITLE, TPL_PAGE
+from .templates import TPL_ENTS
 from ..util import minify_html, escape_html, registry
 from ..errors import Errors


 DEFAULT_LANG = "en"
 DEFAULT_DIR = "ltr"
+DEFAULT_ENTITY_COLOR = "#ddd"
+DEFAULT_LABEL_COLORS = {
+    "ORG": "#7aecec",
+    "PRODUCT": "#bfeeb7",
+    "GPE": "#feca74",
+    "LOC": "#ff9561",
+    "PERSON": "#aa9cfc",
+    "NORP": "#c887fb",
+    "FACILITY": "#9cc9cc",
+    "EVENT": "#ffeb80",
+    "LAW": "#ff8197",
+    "LANGUAGE": "#ff8197",
+    "WORK_OF_ART": "#f0d0ff",
+    "DATE": "#bfe1d9",
+    "TIME": "#bfe1d9",
+    "MONEY": "#e4e7d2",
+    "QUANTITY": "#e4e7d2",
+    "ORDINAL": "#e4e7d2",
+    "CARDINAL": "#e4e7d2",
+    "PERCENT": "#e4e7d2",
+}


 class DependencyRenderer:
|
||||||
|
|
||||||
style = "dep"
|
style = "dep"
|
||||||
|
|
||||||
def __init__(self, options={}):
|
def __init__(self, options: Dict[str, Any] = {}) -> None:
|
||||||
"""Initialise dependency renderer.
|
"""Initialise dependency renderer.
|
||||||
|
|
||||||
options (dict): Visualiser-specific options (compact, word_spacing,
|
options (dict): Visualiser-specific options (compact, word_spacing,
|
||||||
|
@ -41,7 +58,9 @@ class DependencyRenderer:
|
||||||
self.direction = DEFAULT_DIR
|
self.direction = DEFAULT_DIR
|
||||||
self.lang = DEFAULT_LANG
|
self.lang = DEFAULT_LANG
|
||||||
|
|
||||||
def render(self, parsed, page=False, minify=False):
|
def render(
|
||||||
|
self, parsed: List[Dict[str, Any]], page: bool = False, minify: bool = False
|
||||||
|
) -> str:
|
||||||
"""Render complete markup.
|
"""Render complete markup.
|
||||||
|
|
||||||
parsed (list): Dependency parses to render.
|
parsed (list): Dependency parses to render.
|
||||||
|
@ -72,10 +91,15 @@ class DependencyRenderer:
|
||||||
return minify_html(markup)
|
return minify_html(markup)
|
||||||
return markup
|
return markup
|
||||||
|
|
||||||
def render_svg(self, render_id, words, arcs):
|
def render_svg(
|
||||||
|
self,
|
||||||
|
render_id: Union[int, str],
|
||||||
|
words: List[Dict[str, Any]],
|
||||||
|
arcs: List[Dict[str, Any]],
|
||||||
|
) -> str:
|
||||||
"""Render SVG.
|
"""Render SVG.
|
||||||
|
|
||||||
render_id (int): Unique ID, typically index of document.
|
render_id (Union[int, str]): Unique ID, typically index of document.
|
||||||
words (list): Individual words and their tags.
|
words (list): Individual words and their tags.
|
||||||
arcs (list): Individual arcs and their start, end, direction and label.
|
arcs (list): Individual arcs and their start, end, direction and label.
|
||||||
RETURNS (str): Rendered SVG markup.
|
RETURNS (str): Rendered SVG markup.
|
||||||
|
@ -86,15 +110,15 @@ class DependencyRenderer:
|
||||||
self.width = self.offset_x + len(words) * self.distance
|
self.width = self.offset_x + len(words) * self.distance
|
||||||
self.height = self.offset_y + 3 * self.word_spacing
|
self.height = self.offset_y + 3 * self.word_spacing
|
||||||
self.id = render_id
|
self.id = render_id
|
||||||
words = [
|
words_svg = [
|
||||||
self.render_word(w["text"], w["tag"], w.get("lemma", None), i)
|
self.render_word(w["text"], w["tag"], w.get("lemma", None), i)
|
||||||
for i, w in enumerate(words)
|
for i, w in enumerate(words)
|
||||||
]
|
]
|
||||||
arcs = [
|
arcs_svg = [
|
||||||
self.render_arrow(a["label"], a["start"], a["end"], a["dir"], i)
|
self.render_arrow(a["label"], a["start"], a["end"], a["dir"], i)
|
||||||
for i, a in enumerate(arcs)
|
for i, a in enumerate(arcs)
|
||||||
]
|
]
|
||||||
content = "".join(words) + "".join(arcs)
|
content = "".join(words_svg) + "".join(arcs_svg)
|
||||||
return TPL_DEP_SVG.format(
|
return TPL_DEP_SVG.format(
|
||||||
id=self.id,
|
id=self.id,
|
||||||
width=self.width,
|
width=self.width,
|
||||||
|
@ -107,9 +131,7 @@ class DependencyRenderer:
|
||||||
lang=self.lang,
|
lang=self.lang,
|
||||||
)
|
)
|
||||||
|
|
||||||
def render_word(
|
def render_word(self, text: str, tag: str, lemma: str, i: int) -> str:
|
||||||
self, text, tag, lemma, i,
|
|
||||||
):
|
|
||||||
"""Render individual word.
|
"""Render individual word.
|
||||||
|
|
||||||
text (str): Word text.
|
text (str): Word text.
|
||||||
|
@ -128,7 +150,9 @@ class DependencyRenderer:
|
||||||
)
|
)
|
||||||
return TPL_DEP_WORDS.format(text=html_text, tag=tag, x=x, y=y)
|
return TPL_DEP_WORDS.format(text=html_text, tag=tag, x=x, y=y)
|
||||||
|
|
||||||
def render_arrow(self, label, start, end, direction, i):
|
def render_arrow(
|
||||||
|
self, label: str, start: int, end: int, direction: str, i: int
|
||||||
|
) -> str:
|
||||||
"""Render individual arrow.
|
"""Render individual arrow.
|
||||||
|
|
||||||
label (str): Dependency label.
|
label (str): Dependency label.
|
||||||
|
@ -172,7 +196,7 @@ class DependencyRenderer:
|
||||||
arc=arc,
|
arc=arc,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_arc(self, x_start, y, y_curve, x_end):
|
def get_arc(self, x_start: int, y: int, y_curve: int, x_end: int) -> str:
|
||||||
"""Render individual arc.
|
"""Render individual arc.
|
||||||
|
|
||||||
x_start (int): X-coordinate of arrow start point.
|
x_start (int): X-coordinate of arrow start point.
|
||||||
|
@ -186,7 +210,7 @@ class DependencyRenderer:
|
||||||
template = "M{x},{y} {x},{c} {e},{c} {e},{y}"
|
template = "M{x},{y} {x},{c} {e},{c} {e},{y}"
|
||||||
return template.format(x=x_start, y=y, c=y_curve, e=x_end)
|
return template.format(x=x_start, y=y, c=y_curve, e=x_end)
|
||||||
|
|
||||||
def get_arrowhead(self, direction, x, y, end):
|
def get_arrowhead(self, direction: str, x: int, y: int, end: int) -> str:
|
||||||
"""Render individual arrow head.
|
"""Render individual arrow head.
|
||||||
|
|
||||||
direction (str): Arrow direction, 'left' or 'right'.
|
direction (str): Arrow direction, 'left' or 'right'.
|
||||||
|
@@ -196,24 +220,12 @@ class DependencyRenderer:
         RETURNS (str): Definition of the arrow head path ('d' attribute).
         """
         if direction == "left":
-            pos1, pos2, pos3 = (x, x - self.arrow_width + 2, x + self.arrow_width - 2)
+            p1, p2, p3 = (x, x - self.arrow_width + 2, x + self.arrow_width - 2)
         else:
-            pos1, pos2, pos3 = (
-                end,
-                end + self.arrow_width - 2,
-                end - self.arrow_width + 2,
-            )
-        arrowhead = (
-            pos1,
-            y + 2,
-            pos2,
-            y - self.arrow_width,
-            pos3,
-            y - self.arrow_width,
-        )
-        return "M{},{} L{},{} {},{}".format(*arrowhead)
+            p1, p2, p3 = (end, end + self.arrow_width - 2, end - self.arrow_width + 2)
+        return f"M{p1},{y + 2} L{p2},{y - self.arrow_width} {p3},{y - self.arrow_width}"

-    def get_levels(self, arcs):
+    def get_levels(self, arcs: List[Dict[str, Any]]) -> List[int]:
         """Calculate available arc height "levels".
         Used to calculate arrow heights dynamically and without wasting space.
@@ -229,41 +241,21 @@ class EntityRenderer:

     style = "ent"

-    def __init__(self, options={}):
+    def __init__(self, options: Dict[str, Any] = {}) -> None:
         """Initialise dependency renderer.

         options (dict): Visualiser-specific options (colors, ents)
         """
-        colors = {
-            "ORG": "#7aecec",
-            "PRODUCT": "#bfeeb7",
-            "GPE": "#feca74",
-            "LOC": "#ff9561",
-            "PERSON": "#aa9cfc",
-            "NORP": "#c887fb",
-            "FACILITY": "#9cc9cc",
-            "EVENT": "#ffeb80",
-            "LAW": "#ff8197",
-            "LANGUAGE": "#ff8197",
-            "WORK_OF_ART": "#f0d0ff",
-            "DATE": "#bfe1d9",
-            "TIME": "#bfe1d9",
-            "MONEY": "#e4e7d2",
-            "QUANTITY": "#e4e7d2",
-            "ORDINAL": "#e4e7d2",
-            "CARDINAL": "#e4e7d2",
-            "PERCENT": "#e4e7d2",
-        }
+        colors = dict(DEFAULT_LABEL_COLORS)
         user_colors = registry.displacy_colors.get_all()
         for user_color in user_colors.values():
             colors.update(user_color)
         colors.update(options.get("colors", {}))
-        self.default_color = "#ddd"
+        self.default_color = DEFAULT_ENTITY_COLOR
         self.colors = colors
         self.ents = options.get("ents", None)
         self.direction = DEFAULT_DIR
         self.lang = DEFAULT_LANG

         template = options.get("template")
         if template:
             self.ent_template = template

@@ -273,7 +265,9 @@ class EntityRenderer:
         else:
             self.ent_template = TPL_ENT

-    def render(self, parsed, page=False, minify=False):
+    def render(
+        self, parsed: List[Dict[str, Any]], page: bool = False, minify: bool = False
+    ) -> str:
         """Render complete markup.

         parsed (list): Dependency parses to render.

@@ -297,7 +291,9 @@ class EntityRenderer:
             return minify_html(markup)
         return markup

-    def render_ents(self, text, spans, title):
+    def render_ents(
+        self, text: str, spans: List[Dict[str, Any]], title: Optional[str]
+    ) -> str:
         """Render entities in text.

         text (str): Original text.
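EntityRenderer.__init__ now starts from DEFAULT_LABEL_COLORS, merges every table returned by registry.displacy_colors.get_all(), and finally applies options["colors"]. A sketch of both override paths; the label and the registry entry name are invented, and the register call assumes the catalogue-style API behind spacy.util.registry:

    from spacy.util import registry

    # Per-call override, picked up by colors.update(options.get("colors", {})) above:
    options = {"ents": ["FOOD"], "colors": {"FOOD": "#ffcc00"}}

    # Global override: any dict registered here is merged in by get_all() at init time.
    registry.displacy_colors.register("demo_colors", func={"FOOD": "#ffcc00"})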
@@ -483,6 +483,8 @@ class Errors:
     E199 = ("Unable to merge 0-length span at doc[{start}:{end}].")

     # TODO: fix numbering after merging develop into master
+    E953 = ("Mismatched IDs received by the Tok2Vec listener: {id1} vs. {id2}")
+    E954 = ("The Tok2Vec listener did not receive a valid input.")
     E955 = ("Can't find table '{table}' for language '{lang}' in spacy-lookups-data.")
     E956 = ("Can't find component '{name}' in [components] block in the config. "
             "Available components: {opts}")
@@ -15,7 +15,7 @@ class Alignment:
         x2y = _make_ragged(x2y)
         y2x = _make_ragged(y2x)
         return Alignment(x2y=x2y, y2x=y2x)

     @classmethod
     def from_strings(cls, A: List[str], B: List[str]) -> "Alignment":
         x2y, y2x = tokenizations.get_alignments(A, B)
@@ -163,7 +163,7 @@ class Language:
         else:
             if (self.lang and vocab.lang) and (self.lang != vocab.lang):
                 raise ValueError(Errors.E150.format(nlp=self.lang, vocab=vocab.lang))
-        self.vocab = vocab
+        self.vocab: Vocab = vocab
         if self.lang is None:
             self.lang = self.vocab.lang
         self.pipeline = []
@@ -8,7 +8,7 @@ def PrecomputableAffine(nO, nI, nF, nP, dropout=0.1):
         init=init,
         dims={"nO": nO, "nI": nI, "nF": nF, "nP": nP},
         params={"W": None, "b": None, "pad": None},
-        attrs={"dropout_rate": dropout}
+        attrs={"dropout_rate": dropout},
     )
     return model

@@ -154,16 +154,30 @@ def LayerNormalizedMaxout(width, maxout_pieces):
 def MultiHashEmbed(
     columns, width, rows, use_subwords, pretrained_vectors, mix, dropout
 ):
-    norm = HashEmbed(nO=width, nV=rows, column=columns.index("NORM"), dropout=dropout, seed=6)
+    norm = HashEmbed(
+        nO=width, nV=rows, column=columns.index("NORM"), dropout=dropout, seed=6
+    )
     if use_subwords:
         prefix = HashEmbed(
-            nO=width, nV=rows // 2, column=columns.index("PREFIX"), dropout=dropout, seed=7
+            nO=width,
+            nV=rows // 2,
+            column=columns.index("PREFIX"),
+            dropout=dropout,
+            seed=7,
         )
         suffix = HashEmbed(
-            nO=width, nV=rows // 2, column=columns.index("SUFFIX"), dropout=dropout, seed=8
+            nO=width,
+            nV=rows // 2,
+            column=columns.index("SUFFIX"),
+            dropout=dropout,
+            seed=8,
         )
         shape = HashEmbed(
-            nO=width, nV=rows // 2, column=columns.index("SHAPE"), dropout=dropout, seed=9
+            nO=width,
+            nV=rows // 2,
+            column=columns.index("SHAPE"),
+            dropout=dropout,
+            seed=9,
         )

     if pretrained_vectors:

@@ -192,7 +206,9 @@ def MultiHashEmbed(

 @registry.architectures.register("spacy.CharacterEmbed.v1")
 def CharacterEmbed(columns, width, rows, nM, nC, features, dropout):
-    norm = HashEmbed(nO=width, nV=rows, column=columns.index("NORM"), dropout=dropout, seed=5)
+    norm = HashEmbed(
+        nO=width, nV=rows, column=columns.index("NORM"), dropout=dropout, seed=5
+    )
     chr_embed = _character_embed.CharacterEmbed(nM=nM, nC=nC)
     with Model.define_operators({">>": chain, "|": concatenate}):
         embed_layer = chr_embed | features >> with_array(norm)
@@ -263,21 +279,29 @@ def build_Tok2Vec_model(
     cols = [ID, NORM, PREFIX, SUFFIX, SHAPE, ORTH]
     with Model.define_operators({">>": chain, "|": concatenate, "**": clone}):
         norm = HashEmbed(
-            nO=width, nV=embed_size, column=cols.index(NORM), dropout=None,
-            seed=0
+            nO=width, nV=embed_size, column=cols.index(NORM), dropout=None, seed=0
         )
         if subword_features:
             prefix = HashEmbed(
-                nO=width, nV=embed_size // 2, column=cols.index(PREFIX), dropout=None,
-                seed=1
+                nO=width,
+                nV=embed_size // 2,
+                column=cols.index(PREFIX),
+                dropout=None,
+                seed=1,
             )
             suffix = HashEmbed(
-                nO=width, nV=embed_size // 2, column=cols.index(SUFFIX), dropout=None,
-                seed=2
+                nO=width,
+                nV=embed_size // 2,
+                column=cols.index(SUFFIX),
+                dropout=None,
+                seed=2,
             )
             shape = HashEmbed(
-                nO=width, nV=embed_size // 2, column=cols.index(SHAPE), dropout=None,
-                seed=3
+                nO=width,
+                nV=embed_size // 2,
+                column=cols.index(SHAPE),
+                dropout=None,
+                seed=3,
             )
         else:
             prefix, suffix, shape = (None, None, None)
@@ -294,11 +318,7 @@ def build_Tok2Vec_model(
             embed = uniqued(
                 (glove | norm | prefix | suffix | shape)
                 >> Maxout(
-                    nO=width,
-                    nI=width * columns,
-                    nP=3,
-                    dropout=0.0,
-                    normalize=True,
+                    nO=width, nI=width * columns, nP=3, dropout=0.0, normalize=True,
                 ),
                 column=cols.index(ORTH),
             )

@@ -307,11 +327,7 @@ def build_Tok2Vec_model(
             embed = uniqued(
                 (glove | norm)
                 >> Maxout(
-                    nO=width,
-                    nI=width * columns,
-                    nP=3,
-                    dropout=0.0,
-                    normalize=True,
+                    nO=width, nI=width * columns, nP=3, dropout=0.0, normalize=True,
                 ),
                 column=cols.index(ORTH),
             )

@@ -320,11 +336,7 @@ def build_Tok2Vec_model(
             embed = uniqued(
                 concatenate(norm, prefix, suffix, shape)
                 >> Maxout(
-                    nO=width,
-                    nI=width * columns,
-                    nP=3,
-                    dropout=0.0,
-                    normalize=True,
+                    nO=width, nI=width * columns, nP=3, dropout=0.0, normalize=True,
                 ),
                 column=cols.index(ORTH),
             )

@@ -333,11 +345,7 @@ def build_Tok2Vec_model(
                 cols
             ) >> with_array(norm)
             reduce_dimensions = Maxout(
-                nO=width,
-                nI=nM * nC + width,
-                nP=3,
-                dropout=0.0,
-                normalize=True,
+                nO=width, nI=nM * nC + width, nP=3, dropout=0.0, normalize=True,
             )
         else:
             embed = norm
@@ -10,7 +10,6 @@ from .simple_ner import SimpleNER
 from .tagger import Tagger
 from .textcat import TextCategorizer
 from .tok2vec import Tok2Vec
-from .hooks import SentenceSegmenter, SimilarityHook
 from .functions import merge_entities, merge_noun_chunks, merge_subtokens

 __all__ = [

@@ -21,9 +20,7 @@ __all__ = [
     "Morphologizer",
     "Pipe",
     "SentenceRecognizer",
-    "SentenceSegmenter",
     "Sentencizer",
-    "SimilarityHook",
     "SimpleNER",
     "Tagger",
     "TextCategorizer",
@@ -63,7 +63,7 @@ class EntityRuler:
         overwrite_ents: bool = False,
         ent_id_sep: str = DEFAULT_ENT_ID_SEP,
         patterns: Optional[List[PatternType]] = None,
-    ):
+    ) -> None:
         """Initialize the entitiy ruler. If patterns are supplied here, they
         need to be a list of dictionaries with a `"label"` and `"pattern"`
         key. A pattern can either be a token pattern (list) or a phrase pattern

@@ -239,7 +239,6 @@ class EntityRuler:
         phrase_pattern_labels = []
         phrase_pattern_texts = []
         phrase_pattern_ids = []
-
         for entry in patterns:
             if isinstance(entry["pattern"], str):
                 phrase_pattern_labels.append(entry["label"])

@@ -247,7 +246,6 @@ class EntityRuler:
                 phrase_pattern_ids.append(entry.get("id"))
             elif isinstance(entry["pattern"], list):
                 token_patterns.append(entry)
-
         phrase_patterns = []
         for label, pattern, ent_id in zip(
             phrase_pattern_labels,

@@ -258,7 +256,6 @@ class EntityRuler:
             if ent_id:
                 phrase_pattern["id"] = ent_id
             phrase_patterns.append(phrase_pattern)
-
         for entry in token_patterns + phrase_patterns:
             label = entry["label"]
             if "id" in entry:

@@ -266,7 +263,6 @@ class EntityRuler:
                 label = self._create_label(label, entry["id"])
                 key = self.matcher._normalize_key(label)
                 self._ent_ids[key] = (ent_label, entry["id"])
-
             pattern = entry["pattern"]
             if isinstance(pattern, Doc):
                 self.phrase_patterns[label].append(pattern)

@@ -323,13 +319,13 @@ class EntityRuler:
         self.clear()
         if isinstance(cfg, dict):
             self.add_patterns(cfg.get("patterns", cfg))
-            self.overwrite = cfg.get("overwrite")
+            self.overwrite = cfg.get("overwrite", False)
             self.phrase_matcher_attr = cfg.get("phrase_matcher_attr", None)
             if self.phrase_matcher_attr is not None:
                 self.phrase_matcher = PhraseMatcher(
                     self.nlp.vocab, attr=self.phrase_matcher_attr
                 )
-            self.ent_id_sep = cfg.get("ent_id_sep")
+            self.ent_id_sep = cfg.get("ent_id_sep", DEFAULT_ENT_ID_SEP)
         else:
             self.add_patterns(cfg)
         return self

@@ -375,9 +371,9 @@ class EntityRuler:
         }
         deserializers_cfg = {"cfg": lambda p: cfg.update(srsly.read_json(p))}
         from_disk(path, deserializers_cfg, {})
-        self.overwrite = cfg.get("overwrite")
+        self.overwrite = cfg.get("overwrite", False)
         self.phrase_matcher_attr = cfg.get("phrase_matcher_attr")
-        self.ent_id_sep = cfg.get("ent_id_sep")
+        self.ent_id_sep = cfg.get("ent_id_sep", DEFAULT_ENT_ID_SEP)

         if self.phrase_matcher_attr is not None:
             self.phrase_matcher = PhraseMatcher(
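from_bytes()/from_disk() now fall back to False and DEFAULT_ENT_ID_SEP when the serialized cfg is missing those keys. A small round-trip sketch (the path and pattern are illustrative):

    from spacy.lang.en import English
    from spacy.pipeline import EntityRuler

    nlp = English()
    ruler = EntityRuler(nlp, overwrite_ents=True)
    ruler.add_patterns([{"label": "ORG", "pattern": "spaCy"}])
    ruler.to_disk("./ruler")  # writes the patterns plus the cfg read back above

    restored = EntityRuler(nlp)
    restored.from_disk("./ruler")  # overwrite/ent_id_sep now default sensibly if absent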
@@ -1,95 +0,0 @@
-from thinc.api import concatenate, reduce_max, reduce_mean, siamese, CauchySimilarity
-
-from .pipe import Pipe
-from ..util import link_vectors_to_models
-
-
-# TODO: do we want to keep these?
-
-
-class SentenceSegmenter:
-    """A simple spaCy hook, to allow custom sentence boundary detection logic
-    (that doesn't require the dependency parse). To change the sentence
-    boundary detection strategy, pass a generator function `strategy` on
-    initialization, or assign a new strategy to the .strategy attribute.
-    Sentence detection strategies should be generators that take `Doc` objects
-    and yield `Span` objects for each sentence.
-    """
-
-    def __init__(self, vocab, strategy=None):
-        self.vocab = vocab
-        if strategy is None or strategy == "on_punct":
-            strategy = self.split_on_punct
-        self.strategy = strategy
-
-    def __call__(self, doc):
-        doc.user_hooks["sents"] = self.strategy
-        return doc
-
-    @staticmethod
-    def split_on_punct(doc):
-        start = 0
-        seen_period = False
-        for i, token in enumerate(doc):
-            if seen_period and not token.is_punct:
-                yield doc[start : token.i]
-                start = token.i
-                seen_period = False
-            elif token.text in [".", "!", "?"]:
-                seen_period = True
-        if start < len(doc):
-            yield doc[start : len(doc)]
-
-
-class SimilarityHook(Pipe):
-    """
-    Experimental: A pipeline component to install a hook for supervised
-    similarity into `Doc` objects.
-    The similarity model can be any object obeying the Thinc `Model`
-    interface. By default, the model concatenates the elementwise mean and
-    elementwise max of the two tensors, and compares them using the
-    Cauchy-like similarity function from Chen (2013):
-
-        >>> similarity = 1. / (1. + (W * (vec1-vec2)**2).sum())
-
-    Where W is a vector of dimension weights, initialized to 1.
-    """
-
-    def __init__(self, vocab, model=True, **cfg):
-        self.vocab = vocab
-        self.model = model
-        self.cfg = dict(cfg)
-
-    @classmethod
-    def Model(cls, length):
-        return siamese(
-            concatenate(reduce_max(), reduce_mean()), CauchySimilarity(length * 2)
-        )
-
-    def __call__(self, doc):
-        """Install similarity hook"""
-        doc.user_hooks["similarity"] = self.predict
-        return doc
-
-    def pipe(self, docs, **kwargs):
-        for doc in docs:
-            yield self(doc)
-
-    def predict(self, doc1, doc2):
-        return self.model.predict([(doc1, doc2)])
-
-    def update(self, doc1_doc2, golds, sgd=None, drop=0.0):
-        sims, bp_sims = self.model.begin_update(doc1_doc2)
-
-    def begin_training(self, _=tuple(), pipeline=None, sgd=None, **kwargs):
-        """Allocate model, using nO from the first model in the pipeline.
-
-        gold_tuples (iterable): Gold-standard training data.
-        pipeline (list): The pipeline the model is part of.
-        """
-        if self.model is True:
-            self.model = self.Model(pipeline[0].model.get_dim("nO"))
-            link_vectors_to_models(self.vocab)
-        if sgd is None:
-            sgd = self.create_optimizer()
-        return sgd
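SentenceSegmenter and SimilarityHook are removed outright (note the TODO at the top of the deleted file). For plain punctuation-based sentence boundaries, the Sentencizer still exported from spacy.pipeline above covers the same ground; a minimal sketch, assuming the component can be constructed and called directly:

    from spacy.lang.en import English
    from spacy.pipeline import Sentencizer

    nlp = English()
    doc = nlp("This is a sentence. This is another one.")
    doc = Sentencizer()(doc)  # marks sentence starts on punctuation boundaries
    print([sent.text for sent in doc.sents])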
@@ -1,4 +1,4 @@
-from typing import Iterable, Tuple, Optional, Dict, List, Callable
+from typing import Iterable, Tuple, Optional, Dict, List, Callable, Iterator, Any
 from thinc.api import get_array_module, Model, Optimizer, set_dropout_rate, Config
 import numpy


@@ -97,7 +97,7 @@ class TextCategorizer(Pipe):
     def labels(self, value: Iterable[str]) -> None:
         self.cfg["labels"] = tuple(value)

-    def pipe(self, stream, batch_size=128):
+    def pipe(self, stream: Iterator[str], batch_size: int = 128) -> Iterator[Doc]:
         for docs in util.minibatch(stream, size=batch_size):
             scores = self.predict(docs)
             self.set_annotations(docs, scores)

@@ -252,8 +252,17 @@ class TextCategorizer(Pipe):
             sgd = self.create_optimizer()
         return sgd

-    def score(self, examples, positive_label=None, **kwargs):
-        return Scorer.score_cats(examples, "cats", labels=self.labels,
+    def score(
+        self,
+        examples: Iterable[Example],
+        positive_label: Optional[str] = None,
+        **kwargs,
+    ) -> Dict[str, Any]:
+        return Scorer.score_cats(
+            examples,
+            "cats",
+            labels=self.labels,
             multi_label=self.model.attrs["multi_label"],
-            positive_label=positive_label, **kwargs
+            positive_label=positive_label,
+            **kwargs,
         )
@@ -6,6 +6,7 @@ from ..gold import Example
 from ..tokens import Doc
 from ..vocab import Vocab
 from ..language import Language
+from ..errors import Errors
 from ..util import link_vectors_to_models, minibatch


@@ -150,7 +151,7 @@ class Tok2Vec(Pipe):
         self.set_annotations(docs, tokvecs)
         return losses

-    def get_loss(self, examples, scores):
+    def get_loss(self, examples, scores) -> None:
         pass

     def begin_training(

@@ -184,26 +185,26 @@ class Tok2VecListener(Model):
         self._backprop = None

     @classmethod
-    def get_batch_id(cls, inputs):
+    def get_batch_id(cls, inputs) -> int:
         return sum(sum(token.orth for token in doc) for doc in inputs)

-    def receive(self, batch_id, outputs, backprop):
+    def receive(self, batch_id: int, outputs, backprop) -> None:
         self._batch_id = batch_id
         self._outputs = outputs
         self._backprop = backprop

-    def verify_inputs(self, inputs):
+    def verify_inputs(self, inputs) -> bool:
         if self._batch_id is None and self._outputs is None:
-            raise ValueError("The Tok2Vec listener did not receive valid input.")
+            raise ValueError(Errors.E954)
         else:
             batch_id = self.get_batch_id(inputs)
             if batch_id != self._batch_id:
-                raise ValueError(f"Mismatched IDs! {batch_id} vs {self._batch_id}")
+                raise ValueError(Errors.E953.format(id1=batch_id, id2=self._batch_id))
             else:
                 return True


-def forward(model: Tok2VecListener, inputs, is_train):
+def forward(model: Tok2VecListener, inputs, is_train: bool):
     if is_train:
         model.verify_inputs(inputs)
         return model._outputs, model._backprop
@@ -1,4 +1,4 @@
-from typing import Dict, List, Union, Optional, Sequence, Any, Callable
+from typing import Dict, List, Union, Optional, Sequence, Any, Callable, Type
 from enum import Enum
 from pydantic import BaseModel, Field, ValidationError, validator
 from pydantic import StrictStr, StrictInt, StrictFloat, StrictBool

@@ -9,12 +9,12 @@ from thinc.api import Optimizer
 from .attrs import NAMES


-def validate(schema, obj):
+def validate(schema: Type[BaseModel], obj: Dict[str, Any]) -> List[str]:
     """Validate data against a given pydantic schema.

-    obj (dict): JSON-serializable data to validate.
+    obj (Dict[str, Any]): JSON-serializable data to validate.
     schema (pydantic.BaseModel): The schema to validate against.
-    RETURNS (list): A list of error messages, if available.
+    RETURNS (List[str]): A list of error messages, if available.
     """
     try:
         schema(**obj)

@@ -31,7 +31,7 @@ def validate(schema, obj):
 # Matcher token patterns


-def validate_token_pattern(obj):
+def validate_token_pattern(obj: list) -> List[str]:
     # Try to convert non-string keys (e.g. {ORTH: "foo"} -> {"ORTH": "foo"})
     get_key = lambda k: NAMES[k] if isinstance(k, int) and k < len(NAMES) else k
     if isinstance(obj, list):
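validate() is now annotated as taking a pydantic model class and returning a list of error strings. A short sketch with a made-up schema (the import path for the helper is assumed to be spacy.schemas):

    from typing import List
    from pydantic import BaseModel, StrictInt
    from spacy.schemas import validate  # assumed import path

    class DemoSchema(BaseModel):
        # hypothetical schema, used only for this example
        name: str
        size: StrictInt

    errors: List[str] = validate(DemoSchema, {"name": "demo", "size": "not-an-int"})
    print(errors)  # non-empty list describing why "size" failed; [] when the data is valid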
@@ -15,7 +15,8 @@ def test_doc_add_entities_set_ents_iob(en_vocab):
         "min_action_freq": 30,
         "update_with_oracle_cut_size": 100,
     }
-    model = registry.make_from_config({"model": DEFAULT_NER_MODEL}, validate=True)["model"]
+    cfg = {"model": DEFAULT_NER_MODEL}
+    model = registry.make_from_config(cfg, validate=True)["model"]
     ner = EntityRecognizer(en_vocab, model, **config)
     ner.begin_training([])
     ner(doc)

@@ -37,7 +38,8 @@ def test_ents_reset(en_vocab):
         "min_action_freq": 30,
         "update_with_oracle_cut_size": 100,
     }
-    model = registry.make_from_config({"model": DEFAULT_NER_MODEL}, validate=True)["model"]
+    cfg = {"model": DEFAULT_NER_MODEL}
+    model = registry.make_from_config(cfg, validate=True)["model"]
     ner = EntityRecognizer(en_vocab, model, **config)
     ner.begin_training([])
     ner(doc)
@@ -14,7 +14,9 @@ def test_en_sbd_single_punct(en_tokenizer, text, punct):
     assert sum(len(sent) for sent in doc.sents) == len(doc)


-@pytest.mark.skip(reason="The step_through API was removed (but should be brought back)")
+@pytest.mark.skip(
+    reason="The step_through API was removed (but should be brought back)"
+)
 def test_en_sentence_breaks(en_tokenizer, en_parser):
     # fmt: off
     text = "This is a sentence . This is another one ."
@@ -32,6 +32,25 @@ SENTENCE_TESTS = [
     ("あれ。これ。", ["あれ。", "これ。"]),
     ("「伝染るんです。」という漫画があります。", ["「伝染るんです。」という漫画があります。"]),
 ]

+tokens1 = [
+    DetailedToken(surface="委員", tag="名詞-普通名詞-一般", inf="", lemma="委員", reading="イイン", sub_tokens=None),
+    DetailedToken(surface="会", tag="名詞-普通名詞-一般", inf="", lemma="会", reading="カイ", sub_tokens=None),
+]
+tokens2 = [
+    DetailedToken(surface="選挙", tag="名詞-普通名詞-サ変可能", inf="", lemma="選挙", reading="センキョ", sub_tokens=None),
+    DetailedToken(surface="管理", tag="名詞-普通名詞-サ変可能", inf="", lemma="管理", reading="カンリ", sub_tokens=None),
+    DetailedToken(surface="委員", tag="名詞-普通名詞-一般", inf="", lemma="委員", reading="イイン", sub_tokens=None),
+    DetailedToken(surface="会", tag="名詞-普通名詞-一般", inf="", lemma="会", reading="カイ", sub_tokens=None),
+]
+tokens3 = [
+    DetailedToken(surface="選挙", tag="名詞-普通名詞-サ変可能", inf="", lemma="選挙", reading="センキョ", sub_tokens=None),
+    DetailedToken(surface="管理", tag="名詞-普通名詞-サ変可能", inf="", lemma="管理", reading="カンリ", sub_tokens=None),
+    DetailedToken(surface="委員会", tag="名詞-普通名詞-一般", inf="", lemma="委員会", reading="イインカイ", sub_tokens=None),
+]
+SUB_TOKEN_TESTS = [
+    ("選挙管理委員会", [None, None, None, None], [None, None, [tokens1]], [[tokens2, tokens3]])
+]
 # fmt: on

@ -92,33 +111,12 @@ def test_ja_tokenizer_split_modes(ja_tokenizer, text, len_a, len_b, len_c):
|
||||||
assert len(nlp_c(text)) == len_c
|
assert len(nlp_c(text)) == len_c
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("text,sub_tokens_list_a,sub_tokens_list_b,sub_tokens_list_c",
|
@pytest.mark.parametrize(
|
||||||
[
|
"text,sub_tokens_list_a,sub_tokens_list_b,sub_tokens_list_c", SUB_TOKEN_TESTS,
|
||||||
(
|
|
||||||
"選挙管理委員会",
|
|
||||||
[None, None, None, None],
|
|
||||||
[None, None, [
|
|
||||||
[
|
|
||||||
DetailedToken(surface='委員', tag='名詞-普通名詞-一般', inf='', lemma='委員', reading='イイン', sub_tokens=None),
|
|
||||||
DetailedToken(surface='会', tag='名詞-普通名詞-一般', inf='', lemma='会', reading='カイ', sub_tokens=None),
|
|
||||||
]
|
|
||||||
]],
|
|
||||||
[[
|
|
||||||
[
|
|
||||||
DetailedToken(surface='選挙', tag='名詞-普通名詞-サ変可能', inf='', lemma='選挙', reading='センキョ', sub_tokens=None),
|
|
||||||
DetailedToken(surface='管理', tag='名詞-普通名詞-サ変可能', inf='', lemma='管理', reading='カンリ', sub_tokens=None),
|
|
||||||
DetailedToken(surface='委員', tag='名詞-普通名詞-一般', inf='', lemma='委員', reading='イイン', sub_tokens=None),
|
|
||||||
DetailedToken(surface='会', tag='名詞-普通名詞-一般', inf='', lemma='会', reading='カイ', sub_tokens=None),
|
|
||||||
], [
|
|
||||||
DetailedToken(surface='選挙', tag='名詞-普通名詞-サ変可能', inf='', lemma='選挙', reading='センキョ', sub_tokens=None),
|
|
||||||
DetailedToken(surface='管理', tag='名詞-普通名詞-サ変可能', inf='', lemma='管理', reading='カンリ', sub_tokens=None),
|
|
||||||
DetailedToken(surface='委員会', tag='名詞-普通名詞-一般', inf='', lemma='委員会', reading='イインカイ', sub_tokens=None),
|
|
||||||
]
|
|
||||||
]]
|
|
||||||
),
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
def test_ja_tokenizer_sub_tokens(ja_tokenizer, text, sub_tokens_list_a, sub_tokens_list_b, sub_tokens_list_c):
|
def test_ja_tokenizer_sub_tokens(
|
||||||
|
ja_tokenizer, text, sub_tokens_list_a, sub_tokens_list_b, sub_tokens_list_c
|
||||||
|
):
|
||||||
nlp_a = Japanese.from_config({"nlp": {"tokenizer": {"split_mode": "A"}}})
|
nlp_a = Japanese.from_config({"nlp": {"tokenizer": {"split_mode": "A"}}})
|
||||||
nlp_b = Japanese.from_config({"nlp": {"tokenizer": {"split_mode": "B"}}})
|
nlp_b = Japanese.from_config({"nlp": {"tokenizer": {"split_mode": "B"}}})
|
||||||
nlp_c = Japanese.from_config({"nlp": {"tokenizer": {"split_mode": "C"}}})
|
nlp_c = Japanese.from_config({"nlp": {"tokenizer": {"split_mode": "C"}}})
|
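The tests above drive the Japanese tokenizer through the config system; the same `from_config` call works outside the test suite, assuming SudachiPy and its dictionary are installed. A short sketch:

# Hedged sketch: build Japanese pipelines with different Sudachi split modes
# (requires the sudachipy/sudachidict_core packages to be installed).
from spacy.lang.ja import Japanese

nlp_a = Japanese.from_config({"nlp": {"tokenizer": {"split_mode": "A"}}})
nlp_c = Japanese.from_config({"nlp": {"tokenizer": {"split_mode": "C"}}})
print(len(nlp_a("選挙管理委員会")), len(nlp_c("選挙管理委員会")))  # mode A splits into more tokens than mode C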

@@ -129,16 +127,19 @@ def test_ja_tokenizer_sub_tokens(ja_tokenizer, text, sub_tokens_list_a, sub_toke
     assert nlp_c(text).user_data["sub_tokens"] == sub_tokens_list_c


-@pytest.mark.parametrize("text,inflections,reading_forms",
+@pytest.mark.parametrize(
+    "text,inflections,reading_forms",
     [
         (
             "取ってつけた",
             ("五段-ラ行,連用形-促音便", "", "下一段-カ行,連用形-一般", "助動詞-タ,終止形-一般"),
             ("トッ", "テ", "ツケ", "タ"),
         ),
-    ]
+    ],
 )
-def test_ja_tokenizer_inflections_reading_forms(ja_tokenizer, text, inflections, reading_forms):
+def test_ja_tokenizer_inflections_reading_forms(
+    ja_tokenizer, text, inflections, reading_forms
+):
     assert ja_tokenizer(text).user_data["inflections"] == inflections
     assert ja_tokenizer(text).user_data["reading_forms"] == reading_forms


@@ -11,9 +11,8 @@ def test_ne_tokenizer_handlers_long_text(ne_tokenizer):


 @pytest.mark.parametrize(
-    "text,length",
-    [("समय जान कति पनि बेर लाग्दैन ।", 7), ("म ठूलो हुँदै थिएँ ।", 5)],
+    "text,length", [("समय जान कति पनि बेर लाग्दैन ।", 7), ("म ठूलो हुँदै थिएँ ।", 5)],
 )
 def test_ne_tokenizer_handles_cnts(ne_tokenizer, text, length):
     tokens = ne_tokenizer(text)
     assert len(tokens) == length

@@ -30,10 +30,7 @@ def test_zh_tokenizer_serialize_pkuseg_with_processors(zh_tokenizer_pkuseg):
     nlp = Chinese(
         meta={
             "tokenizer": {
-                "config": {
-                    "segmenter": "pkuseg",
-                    "pkuseg_model": "medicine",
-                }
+                "config": {"segmenter": "pkuseg", "pkuseg_model": "medicine",}
             }
         }
     )

@@ -22,7 +22,8 @@ def parser(vocab):
         "min_action_freq": 30,
         "update_with_oracle_cut_size": 100,
     }
-    model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"]
+    cfg = {"model": DEFAULT_PARSER_MODEL}
+    model = registry.make_from_config(cfg, validate=True)["model"]
     parser = DependencyParser(vocab, model, **config)
     return parser

@@ -68,7 +69,8 @@ def test_add_label_deserializes_correctly():
         "min_action_freq": 30,
         "update_with_oracle_cut_size": 100,
     }
-    model = registry.make_from_config({"model": DEFAULT_NER_MODEL}, validate=True)["model"]
+    cfg = {"model": DEFAULT_NER_MODEL}
+    model = registry.make_from_config(cfg, validate=True)["model"]
     ner1 = EntityRecognizer(Vocab(), model, **config)
     ner1.add_label("C")
     ner1.add_label("B")

@@ -86,7 +88,10 @@ def test_add_label_deserializes_correctly():

 @pytest.mark.parametrize(
     "pipe_cls,n_moves,model_config",
-    [(DependencyParser, 5, DEFAULT_PARSER_MODEL), (EntityRecognizer, 4, DEFAULT_NER_MODEL)],
+    [
+        (DependencyParser, 5, DEFAULT_PARSER_MODEL),
+        (EntityRecognizer, 4, DEFAULT_NER_MODEL),
+    ],
 )
 def test_add_label_get_label(pipe_cls, n_moves, model_config):
     """Test that added labels are returned correctly. This test was added to

@@ -126,7 +126,8 @@ def test_get_oracle_actions():
         "min_action_freq": 0,
         "update_with_oracle_cut_size": 100,
     }
-    model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"]
+    cfg = {"model": DEFAULT_PARSER_MODEL}
+    model = registry.make_from_config(cfg, validate=True)["model"]
     parser = DependencyParser(doc.vocab, model, **config)
     parser.moves.add_action(0, "")
     parser.moves.add_action(1, "")

@@ -24,7 +24,8 @@ def arc_eager(vocab):

 @pytest.fixture
 def tok2vec():
-    tok2vec = registry.make_from_config({"model": DEFAULT_TOK2VEC_MODEL}, validate=True)["model"]
+    cfg = {"model": DEFAULT_TOK2VEC_MODEL}
+    tok2vec = registry.make_from_config(cfg, validate=True)["model"]
     tok2vec.initialize()
     return tok2vec

@@ -36,13 +37,15 @@ def parser(vocab, arc_eager):
         "min_action_freq": 30,
         "update_with_oracle_cut_size": 100,
     }
-    model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"]
+    cfg = {"model": DEFAULT_PARSER_MODEL}
+    model = registry.make_from_config(cfg, validate=True)["model"]
     return Parser(vocab, model, moves=arc_eager, **config)


 @pytest.fixture
 def model(arc_eager, tok2vec, vocab):
-    model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"]
+    cfg = {"model": DEFAULT_PARSER_MODEL}
+    model = registry.make_from_config(cfg, validate=True)["model"]
     model.attrs["resize_output"](model, arc_eager.n_moves)
     model.initialize()
     return model

@@ -68,7 +71,8 @@ def test_build_model(parser, vocab):
         "min_action_freq": 0,
         "update_with_oracle_cut_size": 100,
     }
-    model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"]
+    cfg = {"model": DEFAULT_PARSER_MODEL}
+    model = registry.make_from_config(cfg, validate=True)["model"]
     parser.model = Parser(vocab, model=model, moves=parser.moves, **config).model
     assert parser.model is not None

@@ -33,7 +33,9 @@ def test_parser_root(en_tokenizer):
         assert t.dep != 0, t.text


-@pytest.mark.skip(reason="The step_through API was removed (but should be brought back)")
+@pytest.mark.skip(
+    reason="The step_through API was removed (but should be brought back)"
+)
 @pytest.mark.parametrize("text", ["Hello"])
 def test_parser_parse_one_word_sentence(en_tokenizer, en_parser, text):
     tokens = en_tokenizer(text)

@@ -47,7 +49,9 @@ def test_parser_parse_one_word_sentence(en_tokenizer, en_parser, text):
     assert doc[0].dep != 0


-@pytest.mark.skip(reason="The step_through API was removed (but should be brought back)")
+@pytest.mark.skip(
+    reason="The step_through API was removed (but should be brought back)"
+)
 def test_parser_initial(en_tokenizer, en_parser):
     text = "I ate the pizza with anchovies."
     # heads = [1, 0, 1, -2, -3, -1, -5]

@@ -92,7 +96,9 @@ def test_parser_merge_pp(en_tokenizer):
     assert doc[3].text == "occurs"


-@pytest.mark.skip(reason="The step_through API was removed (but should be brought back)")
+@pytest.mark.skip(
+    reason="The step_through API was removed (but should be brought back)"
+)
 def test_parser_arc_eager_finalize_state(en_tokenizer, en_parser):
     text = "a b c d e"

@@ -21,7 +21,8 @@ def parser(vocab):
         "min_action_freq": 30,
         "update_with_oracle_cut_size": 100,
     }
-    model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"]
+    cfg = {"model": DEFAULT_PARSER_MODEL}
+    model = registry.make_from_config(cfg, validate=True)["model"]
     parser = DependencyParser(vocab, model, **config)
     parser.cfg["token_vector_width"] = 4
     parser.cfg["hidden_width"] = 32

@@ -28,7 +28,9 @@ def test_parser_sentence_space(en_tokenizer):
     assert len(list(doc.sents)) == 2


-@pytest.mark.skip(reason="The step_through API was removed (but should be brought back)")
+@pytest.mark.skip(
+    reason="The step_through API was removed (but should be brought back)"
+)
 def test_parser_space_attachment_leading(en_tokenizer, en_parser):
     text = "\t \n This is a sentence ."
     heads = [1, 1, 0, 1, -2, -3]

@@ -44,7 +46,9 @@ def test_parser_space_attachment_leading(en_tokenizer, en_parser):
     assert stepwise.stack == set([2])


-@pytest.mark.skip(reason="The step_through API was removed (but should be brought back)")
+@pytest.mark.skip(
+    reason="The step_through API was removed (but should be brought back)"
+)
 def test_parser_space_attachment_intermediate_trailing(en_tokenizer, en_parser):
     text = "This is \t a \t\n \n sentence . \n\n \n"
     heads = [1, 0, -1, 2, -1, -4, -5, -1]

@@ -64,7 +68,9 @@ def test_parser_space_attachment_intermediate_trailing(en_tokenizer, en_parser):


 @pytest.mark.parametrize("text,length", [(["\n"], 1), (["\n", "\t", "\n\n", "\t"], 4)])
-@pytest.mark.skip(reason="The step_through API was removed (but should be brought back)")
+@pytest.mark.skip(
+    reason="The step_through API was removed (but should be brought back)"
+)
 def test_parser_space_attachment_space(en_tokenizer, en_parser, text, length):
     doc = Doc(en_parser.vocab, words=text)
     assert len(doc) == length

@@ -117,7 +117,9 @@ def test_overfitting_IO():
     assert cats2["POSITIVE"] + cats2["NEGATIVE"] == pytest.approx(1.0, 0.1)

     # Test scoring
-    scores = nlp.evaluate(train_examples, component_cfg={"scorer": {"positive_label": "POSITIVE"}})
+    scores = nlp.evaluate(
+        train_examples, component_cfg={"scorer": {"positive_label": "POSITIVE"}}
+    )
     assert scores["cats_f"] == 1.0

@@ -201,7 +201,8 @@ def test_issue3345():
         "min_action_freq": 30,
         "update_with_oracle_cut_size": 100,
     }
-    model = registry.make_from_config({"model": DEFAULT_NER_MODEL}, validate=True)["model"]
+    cfg = {"model": DEFAULT_NER_MODEL}
+    model = registry.make_from_config(cfg, validate=True)["model"]
     ner = EntityRecognizer(doc.vocab, model, **config)
     # Add the OUT action. I wouldn't have thought this would be necessary...
     ner.moves.add_action(5, "")

@@ -20,7 +20,8 @@ def parser(en_vocab):
         "min_action_freq": 30,
         "update_with_oracle_cut_size": 100,
     }
-    model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"]
+    cfg = {"model": DEFAULT_PARSER_MODEL}
+    model = registry.make_from_config(cfg, validate=True)["model"]
     parser = DependencyParser(en_vocab, model, **config)
     parser.add_label("nsubj")
     return parser

@@ -33,14 +34,16 @@ def blank_parser(en_vocab):
         "min_action_freq": 30,
         "update_with_oracle_cut_size": 100,
     }
-    model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"]
+    cfg = {"model": DEFAULT_PARSER_MODEL}
+    model = registry.make_from_config(cfg, validate=True)["model"]
     parser = DependencyParser(en_vocab, model, **config)
     return parser


 @pytest.fixture
 def taggers(en_vocab):
-    model = registry.make_from_config({"model": DEFAULT_TAGGER_MODEL}, validate=True)["model"]
+    cfg = {"model": DEFAULT_TAGGER_MODEL}
+    model = registry.make_from_config(cfg, validate=True)["model"]
     tagger1 = Tagger(en_vocab, model, set_morphology=True)
     tagger2 = Tagger(en_vocab, model, set_morphology=True)
     return tagger1, tagger2

@@ -53,7 +56,8 @@ def test_serialize_parser_roundtrip_bytes(en_vocab, Parser):
         "min_action_freq": 0,
         "update_with_oracle_cut_size": 100,
     }
-    model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"]
+    cfg = {"model": DEFAULT_PARSER_MODEL}
+    model = registry.make_from_config(cfg, validate=True)["model"]
     parser = Parser(en_vocab, model, **config)
     new_parser = Parser(en_vocab, model, **config)
     new_parser = new_parser.from_bytes(parser.to_bytes(exclude=["vocab"]))

@@ -70,7 +74,8 @@ def test_serialize_parser_roundtrip_disk(en_vocab, Parser):
         "min_action_freq": 0,
         "update_with_oracle_cut_size": 100,
     }
-    model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"]
+    cfg = {"model": DEFAULT_PARSER_MODEL}
+    model = registry.make_from_config(cfg, validate=True)["model"]
     parser = Parser(en_vocab, model, **config)
     with make_tempdir() as d:
         file_path = d / "parser"

@@ -88,7 +93,6 @@ def test_to_from_bytes(parser, blank_parser):
     assert blank_parser.model is not True
     assert blank_parser.moves.n_moves != parser.moves.n_moves
     bytes_data = parser.to_bytes(exclude=["vocab"])
-
     # the blank parser needs to be resized before we can call from_bytes
     blank_parser.model.attrs["resize_output"](blank_parser.model, parser.moves.n_moves)
     blank_parser.from_bytes(bytes_data)

@@ -104,7 +108,8 @@ def test_serialize_tagger_roundtrip_bytes(en_vocab, taggers):
     tagger1_b = tagger1.to_bytes()
     tagger1 = tagger1.from_bytes(tagger1_b)
     assert tagger1.to_bytes() == tagger1_b
-    model = registry.make_from_config({"model": DEFAULT_TAGGER_MODEL}, validate=True)["model"]
+    cfg = {"model": DEFAULT_TAGGER_MODEL}
+    model = registry.make_from_config(cfg, validate=True)["model"]
     new_tagger1 = Tagger(en_vocab, model).from_bytes(tagger1_b)
     new_tagger1_b = new_tagger1.to_bytes()
     assert len(new_tagger1_b) == len(tagger1_b)

@@ -118,7 +123,8 @@ def test_serialize_tagger_roundtrip_disk(en_vocab, taggers):
         file_path2 = d / "tagger2"
         tagger1.to_disk(file_path1)
         tagger2.to_disk(file_path2)
-        model = registry.make_from_config({"model": DEFAULT_TAGGER_MODEL}, validate=True)["model"]
+        cfg = {"model": DEFAULT_TAGGER_MODEL}
+        model = registry.make_from_config(cfg, validate=True)["model"]
         tagger1_d = Tagger(en_vocab, model, set_morphology=True).from_disk(file_path1)
         tagger2_d = Tagger(en_vocab, model, set_morphology=True).from_disk(file_path2)
         assert tagger1_d.to_bytes() == tagger2_d.to_bytes()

@@ -126,21 +132,22 @@ def test_serialize_tagger_roundtrip_disk(en_vocab, taggers):

 def test_serialize_textcat_empty(en_vocab):
     # See issue #1105
-    model = registry.make_from_config({"model": DEFAULT_TEXTCAT_MODEL}, validate=True)["model"]
-    textcat = TextCategorizer(
-        en_vocab, model, labels=["ENTITY", "ACTION", "MODIFIER"]
-    )
+    cfg = {"model": DEFAULT_TEXTCAT_MODEL}
+    model = registry.make_from_config(cfg, validate=True)["model"]
+    textcat = TextCategorizer(en_vocab, model, labels=["ENTITY", "ACTION", "MODIFIER"])
     textcat.to_bytes(exclude=["vocab"])


 @pytest.mark.parametrize("Parser", test_parsers)
 def test_serialize_pipe_exclude(en_vocab, Parser):
-    model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"]
+    cfg = {"model": DEFAULT_PARSER_MODEL}
+    model = registry.make_from_config(cfg, validate=True)["model"]
     config = {
         "learn_tokens": False,
         "min_action_freq": 0,
         "update_with_oracle_cut_size": 100,
     }

     def get_new_parser():
         new_parser = Parser(en_vocab, model, **config)
         return new_parser

@@ -160,7 +167,8 @@ def test_serialize_pipe_exclude(en_vocab, Parser):


 def test_serialize_sentencerecognizer(en_vocab):
-    model = registry.make_from_config({"model": DEFAULT_SENTER_MODEL}, validate=True)["model"]
+    cfg = {"model": DEFAULT_SENTER_MODEL}
+    model = registry.make_from_config(cfg, validate=True)["model"]
     sr = SentenceRecognizer(en_vocab, model)
     sr_b = sr.to_bytes()
     sr_d = SentenceRecognizer(en_vocab, model).from_bytes(sr_b)

@@ -466,7 +466,6 @@ def test_iob_to_biluo():


 def test_roundtrip_docs_to_docbin(doc):
-    nlp = English()
     text = doc.text
     idx = [t.idx for t in doc]
     tags = [t.tag_ for t in doc]

@@ -7,11 +7,11 @@ from spacy.vocab import Vocab
 def test_Example_init_requires_doc_objects():
     vocab = Vocab()
     with pytest.raises(TypeError):
-        example = Example(None, None)
+        Example(None, None)
     with pytest.raises(TypeError):
-        example = Example(Doc(vocab, words=["hi"]), None)
+        Example(Doc(vocab, words=["hi"]), None)
     with pytest.raises(TypeError):
-        example = Example(None, Doc(vocab, words=["hi"]))
+        Example(None, Doc(vocab, words=["hi"]))


 def test_Example_from_dict_basic():

@@ -105,7 +105,11 @@ def test_tokenization(sented_doc):
     assert scores["token_acc"] == 1.0

     nlp = English()
-    example.predicted = Doc(nlp.vocab, words=["One", "sentence.", "Two", "sentences.", "Three", "sentences."], spaces=[True, True, True, True, True, False])
+    example.predicted = Doc(
+        nlp.vocab,
+        words=["One", "sentence.", "Two", "sentences.", "Three", "sentences."],
+        spaces=[True, True, True, True, True, False],
+    )
     example.predicted[1].is_sent_start = False
     scores = scorer.score([example])
     assert scores["token_acc"] == approx(0.66666666)

spacy/util.py (193 lines changed)

@@ -1,12 +1,13 @@
 from typing import List, Union, Dict, Any, Optional, Iterable, Callable, Tuple
-from typing import Iterator, TYPE_CHECKING
+from typing import Iterator, Type, Pattern, Sequence, TYPE_CHECKING
+from types import ModuleType
 import os
 import importlib
 import importlib.util
 import re
 from pathlib import Path
 import thinc
-from thinc.api import NumpyOps, get_current_ops, Adam, Config
+from thinc.api import NumpyOps, get_current_ops, Adam, Config, Optimizer
 import functools
 import itertools
 import numpy.random

@@ -49,6 +50,8 @@ from . import about
 if TYPE_CHECKING:
     # This lets us add type hints for mypy etc. without causing circular imports
     from .language import Language  # noqa: F401
+    from .tokens import Doc, Span  # noqa: F401
+    from .vocab import Vocab  # noqa: F401


 _PRINT_ENV = False

@@ -102,12 +105,12 @@ class SimpleFrozenDict(dict):
         raise NotImplementedError(self.error)


-def set_env_log(value):
+def set_env_log(value: bool) -> None:
     global _PRINT_ENV
     _PRINT_ENV = value


-def lang_class_is_loaded(lang):
+def lang_class_is_loaded(lang: str) -> bool:
     """Check whether a Language class is already loaded. Language classes are
     loaded lazily, to avoid expensive setup code associated with the language
     data.

@@ -118,7 +121,7 @@ def lang_class_is_loaded(lang):
     return lang in registry.languages


-def get_lang_class(lang):
+def get_lang_class(lang: str) -> "Language":
     """Import and load a Language class.

     lang (str): Two-letter language code, e.g. 'en'.

@@ -136,7 +139,7 @@ def get_lang_class(lang):
     return registry.languages.get(lang)


-def set_lang_class(name, cls):
+def set_lang_class(name: str, cls: Type["Language"]) -> None:
     """Set a custom Language class name that can be loaded via get_lang_class.

     name (str): Name of Language class.

@@ -145,10 +148,10 @@ def set_lang_class(name, cls):
     registry.languages.register(name, func=cls)


-def ensure_path(path):
+def ensure_path(path: Any) -> Any:
     """Ensure string is converted to a Path.

-    path: Anything. If string, it's converted to Path.
+    path (Any): Anything. If string, it's converted to Path.
     RETURNS: Path or original argument.
     """
     if isinstance(path, str):

@@ -157,7 +160,7 @@ def ensure_path(path):
     return path


-def load_language_data(path):
+def load_language_data(path: Union[str, Path]) -> Union[dict, list]:
     """Load JSON language data using the given path as a base. If the provided
     path isn't present, will attempt to load a gzipped version before giving up.

@@ -173,7 +176,12 @@ def load_language_data(path):
     raise ValueError(Errors.E160.format(path=path))


-def get_module_path(module):
+def get_module_path(module: ModuleType) -> Path:
+    """Get the path of a Python module.
+
+    module (ModuleType): The Python module.
+    RETURNS (Path): The path.
+    """
     if not hasattr(module, "__module__"):
         raise ValueError(Errors.E169.format(module=repr(module)))
     return Path(sys.modules[module.__module__].__file__).parent

@@ -183,7 +191,7 @@ def load_model(
     name: Union[str, Path],
     disable: Iterable[str] = tuple(),
     component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(),
-):
+) -> "Language":
     """Load a model from a package or data path.

     name (str): Package name or model path.

@@ -209,7 +217,7 @@ def load_model_from_package(
     name: str,
     disable: Iterable[str] = tuple(),
     component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(),
-):
+) -> "Language":
     """Load a model from an installed package."""
     cls = importlib.import_module(name)
     return cls.load(disable=disable, component_cfg=component_cfg)

@@ -220,7 +228,7 @@ def load_model_from_path(
     meta: Optional[Dict[str, Any]] = None,
     disable: Iterable[str] = tuple(),
     component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(),
-):
+) -> "Language":
    """Load a model from a data directory path. Creates Language class with
    pipeline from config.cfg and then calls from_disk() with path."""
    if not model_path.exists():

@@ -269,7 +277,7 @@ def load_model_from_init_py(
     init_file: Union[Path, str],
     disable: Iterable[str] = tuple(),
     component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(),
-):
+) -> "Language":
     """Helper function to use in the `load()` method of a model package's
     __init__.py.

@@ -288,15 +296,15 @@
     )


-def get_installed_models():
+def get_installed_models() -> List[str]:
     """List all model packages currently installed in the environment.

-    RETURNS (list): The string names of the models.
+    RETURNS (List[str]): The string names of the models.
     """
     return list(registry.models.get_all().keys())


-def get_package_version(name):
+def get_package_version(name: str) -> Optional[str]:
     """Get the version of an installed package. Typically used to get model
     package versions.

@@ -309,7 +317,9 @@ def get_package_version(name):
     return None


-def is_compatible_version(version, constraint, prereleases=True):
+def is_compatible_version(
+    version: str, constraint: str, prereleases: bool = True
+) -> Optional[bool]:
     """Check if a version (e.g. "2.0.0") is compatible given a version
     constraint (e.g. ">=1.9.0,<2.2.1"). If the constraint is a specific version,
     it's interpreted as =={version}.

@@ -333,7 +343,9 @@ def is_compatible_version(version, constraint, prereleases=True):
     return version in spec


-def is_unconstrained_version(constraint, prereleases=True):
+def is_unconstrained_version(
+    constraint: str, prereleases: bool = True
+) -> Optional[bool]:
     # We have an exact version, this is the ultimate constrained version
     if constraint[0].isdigit():
         return False

@@ -358,7 +370,7 @@ def is_unconstrained_version(constraint, prereleases=True):
     return True


-def get_model_version_range(spacy_version):
+def get_model_version_range(spacy_version: str) -> str:
     """Generate a version range like >=1.2.3,<1.3.0 based on a given spaCy
     version. Models are always compatible across patch versions but not
     across minor or major versions.

@@ -367,7 +379,7 @@ def get_model_version_range(spacy_version):
     return f">={spacy_version},<{release[0]}.{release[1] + 1}.0"


-def get_base_version(version):
+def get_base_version(version: str) -> str:
     """Generate the base version without any prerelease identifiers.

     version (str): The version, e.g. "3.0.0.dev1".

@@ -376,11 +388,11 @@ def get_base_version(version):
     return Version(version).base_version


-def get_model_meta(path):
+def get_model_meta(path: Union[str, Path]) -> Dict[str, Any]:
     """Get model meta.json from a directory path and validate its contents.

     path (str / Path): Path to model directory.
-    RETURNS (dict): The model's meta data.
+    RETURNS (Dict[str, Any]): The model's meta data.
     """
     model_path = ensure_path(path)
     if not model_path.exists():

@@ -412,7 +424,7 @@ def get_model_meta(path):
     return meta


-def is_package(name):
+def is_package(name: str) -> bool:
     """Check if string maps to a package installed via pip.

     name (str): Name of package.

@@ -425,7 +437,7 @@ def is_package(name):
     return False


-def get_package_path(name):
+def get_package_path(name: str) -> Path:
     """Get the path to an installed package.

     name (str): Package name.

@@ -495,7 +507,7 @@ def working_dir(path: Union[str, Path]) -> None:


 @contextmanager
-def make_tempdir():
+def make_tempdir() -> None:
     """Execute a block in a temporary directory and remove the directory and
     its contents at the end of the with block.

@@ -518,7 +530,7 @@ def is_cwd(path: Union[Path, str]) -> bool:
     return str(Path(path).resolve()).lower() == str(Path.cwd().resolve()).lower()


-def is_in_jupyter():
+def is_in_jupyter() -> bool:
     """Check if user is running spaCy from a Jupyter notebook by detecting the
     IPython kernel. Mainly used for the displaCy visualizer.
     RETURNS (bool): True if in Jupyter, False if not.

@@ -548,7 +560,9 @@ def get_object_name(obj: Any) -> str:
     return repr(obj)


-def get_cuda_stream(require=False, non_blocking=True):
+def get_cuda_stream(
+    require: bool = False, non_blocking: bool = True
+) -> Optional[CudaStream]:
     ops = get_current_ops()
     if CudaStream is None:
         return None

@@ -567,7 +581,7 @@ def get_async(stream, numpy_array):
     return array


-def env_opt(name, default=None):
+def env_opt(name: str, default: Optional[Any] = None) -> Optional[Any]:
     if type(default) is float:
         type_convert = float
     else:

@@ -588,7 +602,7 @@ def env_opt(name, default=None):
     return default


-def read_regex(path):
+def read_regex(path: Union[str, Path]) -> Pattern:
     path = ensure_path(path)
     with path.open(encoding="utf8") as file_:
         entries = file_.read().split("\n")

@@ -598,37 +612,40 @@
     return re.compile(expression)


-def compile_prefix_regex(entries):
+def compile_prefix_regex(entries: Iterable[Union[str, Pattern]]) -> Pattern:
     """Compile a sequence of prefix rules into a regex object.

-    entries (tuple): The prefix rules, e.g. spacy.lang.punctuation.TOKENIZER_PREFIXES.
-    RETURNS (regex object): The regex object. to be used for Tokenizer.prefix_search.
+    entries (Iterable[Union[str, Pattern]]): The prefix rules, e.g.
+        spacy.lang.punctuation.TOKENIZER_PREFIXES.
+    RETURNS (Pattern): The regex object. to be used for Tokenizer.prefix_search.
     """
     expression = "|".join(["^" + piece for piece in entries if piece.strip()])
     return re.compile(expression)


-def compile_suffix_regex(entries):
+def compile_suffix_regex(entries: Iterable[Union[str, Pattern]]) -> Pattern:
     """Compile a sequence of suffix rules into a regex object.

-    entries (tuple): The suffix rules, e.g. spacy.lang.punctuation.TOKENIZER_SUFFIXES.
-    RETURNS (regex object): The regex object. to be used for Tokenizer.suffix_search.
+    entries (Iterable[Union[str, Pattern]]): The suffix rules, e.g.
+        spacy.lang.punctuation.TOKENIZER_SUFFIXES.
+    RETURNS (Pattern): The regex object. to be used for Tokenizer.suffix_search.
     """
     expression = "|".join([piece + "$" for piece in entries if piece.strip()])
     return re.compile(expression)


-def compile_infix_regex(entries):
+def compile_infix_regex(entries: Iterable[Union[str, Pattern]]) -> Pattern:
     """Compile a sequence of infix rules into a regex object.

-    entries (tuple): The infix rules, e.g. spacy.lang.punctuation.TOKENIZER_INFIXES.
+    entries (Iterable[Union[str, Pattern]]): The infix rules, e.g.
+        spacy.lang.punctuation.TOKENIZER_INFIXES.
     RETURNS (regex object): The regex object. to be used for Tokenizer.infix_finditer.
     """
     expression = "|".join([piece for piece in entries if piece.strip()])
     return re.compile(expression)


-def add_lookups(default_func, *lookups):
+def add_lookups(default_func: Callable[[str], Any], *lookups) -> Callable[[str], Any]:
     """Extend an attribute function with special cases. If a word is in the
     lookups, the value is returned. Otherwise the previous function is used.

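A brief usage sketch for the now-typed rule compilers above: extend the default suffix rules and plug the compiled pattern back into a tokenizer. The extra `++` rule is purely illustrative, not part of the diff.

# Hedged sketch; the added suffix rule is an example only.
import spacy
from spacy.util import compile_suffix_regex

nlp = spacy.blank("en")
suffixes = list(nlp.Defaults.suffixes) + [r"\+\+"]
suffix_regex = compile_suffix_regex(suffixes)
nlp.tokenizer.suffix_search = suffix_regex.search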

@@ -641,19 +658,23 @@ def add_lookups(default_func, *lookups):
     return functools.partial(_get_attr_unless_lookup, default_func, lookups)


-def _get_attr_unless_lookup(default_func, lookups, string):
+def _get_attr_unless_lookup(
+    default_func: Callable[[str], Any], lookups: Dict[str, Any], string: str
+) -> Any:
     for lookup in lookups:
         if string in lookup:
             return lookup[string]
     return default_func(string)


-def update_exc(base_exceptions, *addition_dicts):
+def update_exc(
+    base_exceptions: Dict[str, List[dict]], *addition_dicts
+) -> Dict[str, List[dict]]:
     """Update and validate tokenizer exceptions. Will overwrite exceptions.

-    base_exceptions (dict): Base exceptions.
-    *addition_dicts (dict): Exceptions to add to the base dict, in order.
-    RETURNS (dict): Combined tokenizer exceptions.
+    base_exceptions (Dict[str, List[dict]]): Base exceptions.
+    *addition_dicts (Dict[str, List[dict]]): Exceptions to add to the base dict, in order.
+    RETURNS (Dict[str, List[dict]]): Combined tokenizer exceptions.
     """
     exc = dict(base_exceptions)
     for additions in addition_dicts:

@@ -668,14 +689,16 @@ def update_exc(base_exceptions, *addition_dicts):
     return exc


-def expand_exc(excs, search, replace):
+def expand_exc(
+    excs: Dict[str, List[dict]], search: str, replace: str
+) -> Dict[str, List[dict]]:
     """Find string in tokenizer exceptions, duplicate entry and replace string.
     For example, to add additional versions with typographic apostrophes.

-    excs (dict): Tokenizer exceptions.
+    excs (Dict[str, List[dict]]): Tokenizer exceptions.
     search (str): String to find and replace.
     replace (str): Replacement.
-    RETURNS (dict): Combined tokenizer exceptions.
+    RETURNS (Dict[str, List[dict]]): Combined tokenizer exceptions.
     """

     def _fix_token(token, search, replace):

@@ -692,7 +715,9 @@ def expand_exc(excs, search, replace):
     return new_excs


-def normalize_slice(length, start, stop, step=None):
+def normalize_slice(
+    length: int, start: int, stop: int, step: Optional[int] = None
+) -> Tuple[int, int]:
     if not (step is None or step == 1):
         raise ValueError(Errors.E057)
     if start is None:

@@ -708,7 +733,9 @@ def normalize_slice(length, start, stop, step=None):
     return start, stop


-def minibatch(items, size=8):
+def minibatch(
+    items: Iterable[Any], size: Union[Iterator[int], int] = 8
+) -> Iterator[Any]:
     """Iterate over batches of items. `size` may be an iterator,
     so that batch-size can vary on each step.
     """
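As the docstring above notes, `size` may be a plain int or an iterator, so the batch size can change between steps. A small sketch using thinc's compounding schedule:

# Hedged sketch: batch a list with a growing batch size.
from thinc.api import compounding
from spacy.util import minibatch

items = list(range(10))
for batch in minibatch(items, size=compounding(1.0, 4.0, 1.5)):
    print(len(batch), batch)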

@@ -725,7 +752,12 @@ def minibatch(items, size=8):
         yield list(batch)


-def minibatch_by_padded_size(docs, size, buffer=256, discard_oversize=False):
+def minibatch_by_padded_size(
+    docs: Iterator["Doc"],
+    size: Union[Iterator[int], int],
+    buffer: int = 256,
+    discard_oversize: bool = False,
+) -> Iterator[Iterator["Doc"]]:
     if isinstance(size, int):
         size_ = itertools.repeat(size)
     else:

@@ -742,7 +774,7 @@ def minibatch_by_padded_size(docs, size, buffer=256, discard_oversize=False):
             yield subbatch


-def _batch_by_length(seqs, max_words):
+def _batch_by_length(seqs: Sequence[Any], max_words: int) -> List[List[Any]]:
     """Given a list of sequences, return a batched list of indices into the
     list, where the batches are grouped by length, in descending order.

@@ -783,14 +815,12 @@ def minibatch_by_words(docs, size, tolerance=0.2, discard_oversize=False):
         size_ = iter(size)
     else:
         size_ = size
-
     target_size = next(size_)
     tol_size = target_size * tolerance
     batch = []
     overflow = []
     batch_size = 0
     overflow_size = 0
-
     for doc in docs:
         if isinstance(doc, Example):
             n_words = len(doc.reference)

@@ -803,17 +833,14 @@ def minibatch_by_words(docs, size, tolerance=0.2, discard_oversize=False):
         if n_words > target_size + tol_size:
             if not discard_oversize:
                 yield [doc]
-
         # add the example to the current batch if there's no overflow yet and it still fits
         elif overflow_size == 0 and (batch_size + n_words) <= target_size:
             batch.append(doc)
             batch_size += n_words
-
         # add the example to the overflow buffer if it fits in the tolerance margin
         elif (batch_size + overflow_size + n_words) <= (target_size + tol_size):
             overflow.append(doc)
             overflow_size += n_words
-
         # yield the previous batch and start a new one. The new one gets the overflow examples.
         else:
             if batch:

@@ -824,17 +851,14 @@ def minibatch_by_words(docs, size, tolerance=0.2, discard_oversize=False):
             batch_size = overflow_size
             overflow = []
             overflow_size = 0
-
             # this example still fits
             if (batch_size + n_words) <= target_size:
                 batch.append(doc)
                 batch_size += n_words
-
             # this example fits in overflow
             elif (batch_size + n_words) <= (target_size + tol_size):
                 overflow.append(doc)
                 overflow_size += n_words
-
             # this example does not fit with the previous overflow: start another new batch
             else:
                 if batch:

@@ -843,20 +867,19 @@ def minibatch_by_words(docs, size, tolerance=0.2, discard_oversize=False):
                     tol_size = target_size * tolerance
                 batch = [doc]
                 batch_size = n_words
-
     batch.extend(overflow)
     if batch:
         yield batch


-def filter_spans(spans):
+def filter_spans(spans: Iterable["Span"]) -> List["Span"]:
     """Filter a sequence of spans and remove duplicates or overlaps. Useful for
     creating named entities (where one token can only be part of one entity) or
     when merging spans with `Retokenizer.merge`. When spans overlap, the (first)
     longest span is preferred over shorter spans.

-    spans (iterable): The spans to filter.
-    RETURNS (list): The filtered spans.
+    spans (Iterable[Span]): The spans to filter.
+    RETURNS (List[Span]): The filtered spans.
     """
     get_sort_key = lambda span: (span.end - span.start, -span.start)
     sorted_spans = sorted(spans, key=get_sort_key, reverse=True)
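`filter_spans` reduces overlapping candidate spans to the longest one, which is useful before assigning `doc.ents` or retokenizing. A quick sketch:

# Hedged sketch: overlapping spans are filtered down to the longest span.
import spacy
from spacy.util import filter_spans

nlp = spacy.blank("en")
doc = nlp("Welcome to the Bank of China")
spans = [doc[3:6], doc[4:6]]  # "Bank of China" and "of China" overlap
assert filter_spans(spans) == [doc[3:6]]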
|
@ -871,15 +894,21 @@ def filter_spans(spans):
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def to_bytes(getters, exclude):
|
def to_bytes(getters: Dict[str, Callable[[], bytes]], exclude: Iterable[str]) -> bytes:
|
||||||
return srsly.msgpack_dumps(to_dict(getters, exclude))
|
return srsly.msgpack_dumps(to_dict(getters, exclude))
|
||||||
|
|
||||||
|
|
||||||
def from_bytes(bytes_data, setters, exclude):
|
def from_bytes(
|
||||||
|
bytes_data: bytes,
|
||||||
|
setters: Dict[str, Callable[[bytes], Any]],
|
||||||
|
exclude: Iterable[str],
|
||||||
|
) -> None:
|
||||||
return from_dict(srsly.msgpack_loads(bytes_data), setters, exclude)
|
return from_dict(srsly.msgpack_loads(bytes_data), setters, exclude)
|
||||||
|
|
||||||
|
|
||||||
def to_dict(getters, exclude):
|
def to_dict(
|
||||||
|
getters: Dict[str, Callable[[], Any]], exclude: Iterable[str]
|
||||||
|
) -> Dict[str, Any]:
|
||||||
serialized = {}
|
serialized = {}
|
||||||
for key, getter in getters.items():
|
for key, getter in getters.items():
|
||||||
# Split to support file names like meta.json
|
# Split to support file names like meta.json
|
||||||
|
@ -888,7 +917,11 @@ def to_dict(getters, exclude):
|
||||||
return serialized
|
return serialized
|
||||||
|
|
||||||
|
|
||||||
def from_dict(msg, setters, exclude):
|
def from_dict(
|
||||||
|
msg: Dict[str, Any],
|
||||||
|
setters: Dict[str, Callable[[Any], Any]],
|
||||||
|
exclude: Iterable[str],
|
||||||
|
) -> Dict[str, Any]:
|
||||||
for key, setter in setters.items():
|
for key, setter in setters.items():
|
||||||
# Split to support file names like meta.json
|
# Split to support file names like meta.json
|
||||||
if key.split(".")[0] not in exclude and key in msg:
|
if key.split(".")[0] not in exclude and key in msg:
|
||||||
|
@ -896,7 +929,11 @@ def from_dict(msg, setters, exclude):
|
||||||
return msg
|
return msg
|
||||||
|
|
||||||
|
|
||||||
def to_disk(path, writers, exclude):
|
def to_disk(
|
||||||
|
path: Union[str, Path],
|
||||||
|
writers: Dict[str, Callable[[Path], None]],
|
||||||
|
exclude: Iterable[str],
|
||||||
|
) -> Path:
|
||||||
path = ensure_path(path)
|
path = ensure_path(path)
|
||||||
if not path.exists():
|
if not path.exists():
|
||||||
path.mkdir()
|
path.mkdir()
|
||||||
|
@ -907,7 +944,11 @@ def to_disk(path, writers, exclude):
|
||||||
return path
|
return path
|
||||||
|
|
||||||
|
|
||||||
def from_disk(path, readers, exclude):
|
def from_disk(
|
||||||
|
path: Union[str, Path],
|
||||||
|
readers: Dict[str, Callable[[Path], None]],
|
||||||
|
exclude: Iterable[str],
|
||||||
|
) -> Path:
|
||||||
path = ensure_path(path)
|
path = ensure_path(path)
|
||||||
for key, reader in readers.items():
|
for key, reader in readers.items():
|
||||||
# Split to support file names like meta.json
|
# Split to support file names like meta.json
|
||||||
|
@ -916,7 +957,7 @@ def from_disk(path, readers, exclude):
|
||||||
return path
|
return path
|
||||||
|
|
||||||
|
|
||||||
def import_file(name, loc):
|
def import_file(name: str, loc: Union[str, Path]) -> ModuleType:
|
||||||
"""Import module from a file. Used to load models from a directory.
|
"""Import module from a file. Used to load models from a directory.
|
||||||
|
|
||||||
name (str): Name of module to load.
|
name (str): Name of module to load.
|
||||||
|
@ -930,7 +971,7 @@ def import_file(name, loc):
|
||||||
return module
|
return module
|
||||||
|
|
||||||
|
|
||||||
def minify_html(html):
|
def minify_html(html: str) -> str:
|
||||||
"""Perform a template-specific, rudimentary HTML minification for displaCy.
|
"""Perform a template-specific, rudimentary HTML minification for displaCy.
|
||||||
Disclaimer: NOT a general-purpose solution, only removes indentation and
|
Disclaimer: NOT a general-purpose solution, only removes indentation and
|
||||||
newlines.
|
newlines.
|
||||||
|
@ -941,7 +982,7 @@ def minify_html(html):
|
||||||
return html.strip().replace(" ", "").replace("\n", "")
|
return html.strip().replace(" ", "").replace("\n", "")
|
||||||
|
|
||||||
|
|
||||||
def escape_html(text):
|
def escape_html(text: str) -> str:
|
||||||
"""Replace <, >, &, " with their HTML encoded representation. Intended to
|
"""Replace <, >, &, " with their HTML encoded representation. Intended to
|
||||||
prevent HTML errors in rendered displaCy markup.
|
prevent HTML errors in rendered displaCy markup.
|
||||||
|
|
||||||
|
@ -955,7 +996,9 @@ def escape_html(text):
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
||||||
def get_words_and_spaces(words, text):
|
def get_words_and_spaces(
|
||||||
|
words: Iterable[str], text: str
|
||||||
|
) -> Tuple[List[str], List[bool]]:
|
||||||
if "".join("".join(words).split()) != "".join(text.split()):
|
if "".join("".join(words).split()) != "".join(text.split()):
|
||||||
raise ValueError(Errors.E194.format(text=text, words=words))
|
raise ValueError(Errors.E194.format(text=text, words=words))
|
||||||
text_words = []
|
text_words = []
|
||||||
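`get_words_and_spaces` aligns a word list against the raw text and recovers the trailing-space flags a `Doc` needs. A quick sketch of the expected behaviour:

# Hedged sketch of the helper's behaviour.
from spacy.util import get_words_and_spaces

words, spaces = get_words_and_spaces(["Hello", "world", "!"], "Hello world!")
# words  -> ["Hello", "world", "!"]
# spaces -> [True, False, False]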

@@ -1103,7 +1146,7 @@ class DummyTokenizer:
         return self


-def link_vectors_to_models(vocab):
+def link_vectors_to_models(vocab: "Vocab") -> None:
     vectors = vocab.vectors
     if vectors.name is None:
         vectors.name = VECTORS_KEY

@@ -1119,7 +1162,7 @@ def link_vectors_to_models(vocab):
 VECTORS_KEY = "spacy_pretrained_vectors"


-def create_default_optimizer():
+def create_default_optimizer() -> Optimizer:
     learn_rate = env_opt("learn_rate", 0.001)
     beta1 = env_opt("optimizer_B1", 0.9)
     beta2 = env_opt("optimizer_B2", 0.999)