diff --git a/.gitignore b/.gitignore index ab97d9b6b..087163761 100644 --- a/.gitignore +++ b/.gitignore @@ -51,6 +51,7 @@ env3.*/ .denv .pypyenv .pytest_cache/ +.mypy_cache/ # Distribution / packaging env/ diff --git a/setup.cfg b/setup.cfg index dbe2c25fd..0ff26ab77 100644 --- a/setup.cfg +++ b/setup.cfg @@ -108,3 +108,8 @@ exclude = [tool:pytest] markers = slow + +[mypy] +ignore_missing_imports = True +no_implicit_optional = True +plugins = pydantic.mypy, thinc.mypy diff --git a/spacy/cli/convert.py b/spacy/cli/convert.py index 059871edd..864051240 100644 --- a/spacy/cli/convert.py +++ b/spacy/cli/convert.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Any, List, Union from enum import Enum from pathlib import Path from wasabi import Printer @@ -66,10 +66,9 @@ def convert_cli( file_type = file_type.value input_path = Path(input_path) output_dir = "-" if output_dir == Path("-") else output_dir - cli_args = locals() silent = output_dir == "-" msg = Printer(no_print=silent) - verify_cli_args(msg, **cli_args) + verify_cli_args(msg, input_path, output_dir, file_type, converter, ner_map) converter = _get_converter(msg, converter, input_path) convert( input_path, @@ -89,8 +88,8 @@ def convert_cli( def convert( - input_path: Path, - output_dir: Path, + input_path: Union[str, Path], + output_dir: Union[str, Path], *, file_type: str = "json", n_sents: int = 1, @@ -102,13 +101,12 @@ def convert( ner_map: Optional[Path] = None, lang: Optional[str] = None, silent: bool = True, - msg: Optional[Path] = None, + msg: Optional[Printer], ) -> None: if not msg: msg = Printer(no_print=silent) ner_map = srsly.read_json(ner_map) if ner_map is not None else None - - for input_loc in walk_directory(input_path): + for input_loc in walk_directory(Path(input_path)): input_data = input_loc.open("r", encoding="utf-8").read() # Use converter function to convert data func = CONVERTERS[converter] @@ -140,14 +138,14 @@ def convert( msg.good(f"Generated output file ({len(docs)} documents): {output_file}") -def _print_docs_to_stdout(data, output_type): +def _print_docs_to_stdout(data: Any, output_type: str) -> None: if output_type == "json": srsly.write_json("-", data) else: sys.stdout.buffer.write(data) -def _write_docs_to_file(data, output_file, output_type): +def _write_docs_to_file(data: Any, output_file: Path, output_type: str) -> None: if not output_file.parent.exists(): output_file.parent.mkdir(parents=True) if output_type == "json": @@ -157,7 +155,7 @@ def _write_docs_to_file(data, output_file, output_type): file_.write(data) -def autodetect_ner_format(input_data: str) -> str: +def autodetect_ner_format(input_data: str) -> Optional[str]: # guess format from the first 20 lines lines = input_data.split("\n")[:20] format_guesses = {"ner": 0, "iob": 0} @@ -176,7 +174,7 @@ def autodetect_ner_format(input_data: str) -> str: return None -def walk_directory(path): +def walk_directory(path: Path) -> List[Path]: if not path.is_dir(): return [path] paths = [path] @@ -196,31 +194,25 @@ def walk_directory(path): def verify_cli_args( - msg, - input_path, - output_dir, - file_type, - n_sents, - seg_sents, - model, - morphology, - merge_subtokens, - converter, - ner_map, - lang, + msg: Printer, + input_path: Union[str, Path], + output_dir: Union[str, Path], + file_type: FileTypes, + converter: str, + ner_map: Optional[Path], ): input_path = Path(input_path) if file_type not in FILE_TYPES_STDOUT and output_dir == "-": - # TODO: support msgpack via stdout in srsly? 
msg.fail( - f"Can't write .{file_type} data to stdout", - "Please specify an output directory.", + f"Can't write .{file_type} data to stdout. Please specify an output directory.", exits=1, ) if not input_path.exists(): msg.fail("Input file not found", input_path, exits=1) if output_dir != "-" and not Path(output_dir).exists(): msg.fail("Output directory not found", output_dir, exits=1) + if ner_map is not None and not Path(ner_map).exists(): + msg.fail("NER map not found", ner_map, exits=1) if input_path.is_dir(): input_locs = walk_directory(input_path) if len(input_locs) == 0: @@ -229,10 +221,8 @@ def verify_cli_args( if len(file_types) >= 2: file_types = ",".join(file_types) msg.fail("All input files must be same type", file_types, exits=1) - converter = _get_converter(msg, converter, input_path) - if converter not in CONVERTERS: + if converter != "auto" and converter not in CONVERTERS: msg.fail(f"Can't find converter for {converter}", exits=1) - return converter def _get_converter(msg, converter, input_path): diff --git a/spacy/cli/evaluate.py b/spacy/cli/evaluate.py index 66b22b131..5cdbee065 100644 --- a/spacy/cli/evaluate.py +++ b/spacy/cli/evaluate.py @@ -82,11 +82,11 @@ def evaluate( "NER P": "ents_p", "NER R": "ents_r", "NER F": "ents_f", - "Textcat AUC": 'textcat_macro_auc', - "Textcat F": 'textcat_macro_f', - "Sent P": 'sents_p', - "Sent R": 'sents_r', - "Sent F": 'sents_f', + "Textcat AUC": "textcat_macro_auc", + "Textcat F": "textcat_macro_f", + "Sent P": "sents_p", + "Sent R": "sents_r", + "Sent F": "sents_f", } results = {} for metric, key in metrics.items(): diff --git a/spacy/cli/train.py b/spacy/cli/train.py index 3e34176b4..21fd0eb72 100644 --- a/spacy/cli/train.py +++ b/spacy/cli/train.py @@ -14,7 +14,6 @@ import typer from ._util import app, Arg, Opt, parse_config_overrides, show_validation_error from ._util import import_code from ..gold import Corpus, Example -from ..lookups import Lookups from ..language import Language from .. import util from ..errors import Errors diff --git a/spacy/compat.py b/spacy/compat.py index d8377633f..2d51ff0ae 100644 --- a/spacy/compat.py +++ b/spacy/compat.py @@ -1,12 +1,5 @@ -""" -Helpers for Python and platform compatibility. To distinguish them from -the builtin functions, replacement functions are suffixed with an underscore, -e.g. `unicode_`. - -DOCS: https://spacy.io/api/top-level#compat -""" +"""Helpers for Python and platform compatibility.""" import sys - from thinc.util import copy_array try: @@ -40,21 +33,3 @@ copy_array = copy_array is_windows = sys.platform.startswith("win") is_linux = sys.platform.startswith("linux") is_osx = sys.platform == "darwin" - - -def is_config(windows=None, linux=None, osx=None, **kwargs): - """Check if a specific configuration of Python version and operating system - matches the user's setup. Mostly used to display targeted error messages. - - windows (bool): spaCy is executed on Windows. - linux (bool): spaCy is executed on Linux. - osx (bool): spaCy is executed on OS X or macOS. - RETURNS (bool): Whether the configuration matches the user's platform. - - DOCS: https://spacy.io/api/top-level#compat.is_config - """ - return ( - windows in (None, is_windows) - and linux in (None, is_linux) - and osx in (None, is_osx) - ) diff --git a/spacy/displacy/__init__.py b/spacy/displacy/__init__.py index 2c377a043..3f885f09f 100644 --- a/spacy/displacy/__init__.py +++ b/spacy/displacy/__init__.py @@ -4,6 +4,7 @@ spaCy's built in visualization suite for dependencies and named entities. 
DOCS: https://spacy.io/api/top-level#displacy USAGE: https://spacy.io/usage/visualizers """ +from typing import Union, Iterable, Optional, Dict, Any, Callable import warnings from .render import DependencyRenderer, EntityRenderer @@ -17,11 +18,17 @@ RENDER_WRAPPER = None def render( - docs, style="dep", page=False, minify=False, jupyter=None, options={}, manual=False -): + docs: Union[Iterable[Doc], Doc], + style: str = "dep", + page: bool = False, + minify: bool = False, + jupyter: Optional[bool] = None, + options: Dict[str, Any] = {}, + manual: bool = False, +) -> str: """Render displaCy visualisation. - docs (list or Doc): Document(s) to visualise. + docs (Union[Iterable[Doc], Doc]): Document(s) to visualise. style (str): Visualisation style, 'dep' or 'ent'. page (bool): Render markup as full HTML page. minify (bool): Minify HTML markup. @@ -44,8 +51,8 @@ def render( docs = [obj if not isinstance(obj, Span) else obj.as_doc() for obj in docs] if not all(isinstance(obj, (Doc, Span, dict)) for obj in docs): raise ValueError(Errors.E096) - renderer, converter = factories[style] - renderer = renderer(options=options) + renderer_func, converter = factories[style] + renderer = renderer_func(options=options) parsed = [converter(doc, options) for doc in docs] if not manual else docs _html["parsed"] = renderer.render(parsed, page=page, minify=minify).strip() html = _html["parsed"] @@ -61,15 +68,15 @@ def render( def serve( - docs, - style="dep", - page=True, - minify=False, - options={}, - manual=False, - port=5000, - host="0.0.0.0", -): + docs: Union[Iterable[Doc], Doc], + style: str = "dep", + page: bool = True, + minify: bool = False, + options: Dict[str, Any] = {}, + manual: bool = False, + port: int = 5000, + host: str = "0.0.0.0", +) -> None: """Serve displaCy visualisation. docs (list or Doc): Document(s) to visualise. @@ -88,7 +95,6 @@ def serve( if is_in_jupyter(): warnings.warn(Warnings.W011) - render(docs, style=style, page=page, minify=minify, options=options, manual=manual) httpd = simple_server.make_server(host, port, app) print(f"\nUsing the '{style}' visualizer") @@ -102,14 +108,13 @@ def serve( def app(environ, start_response): - # Headers and status need to be bytes in Python 2, see #1227 headers = [("Content-type", "text/html; charset=utf-8")] start_response("200 OK", headers) res = _html["parsed"].encode(encoding="utf-8") return [res] -def parse_deps(orig_doc, options={}): +def parse_deps(orig_doc: Doc, options: Dict[str, Any] = {}) -> Dict[str, Any]: """Generate dependency parse in {'words': [], 'arcs': []} format. doc (Doc): Document do parse. @@ -152,7 +157,6 @@ def parse_deps(orig_doc, options={}): } for w in doc ] - arcs = [] for word in doc: if word.i < word.head.i: @@ -171,7 +175,7 @@ def parse_deps(orig_doc, options={}): return {"words": words, "arcs": arcs, "settings": get_doc_settings(orig_doc)} -def parse_ents(doc, options={}): +def parse_ents(doc: Doc, options: Dict[str, Any] = {}) -> Dict[str, Any]: """Generate named entities in [{start: i, end: i, label: 'label'}] format. doc (Doc): Document do parse. @@ -188,7 +192,7 @@ def parse_ents(doc, options={}): return {"text": doc.text, "ents": ents, "title": title, "settings": settings} -def set_render_wrapper(func): +def set_render_wrapper(func: Callable[[str], str]) -> None: """Set an optional wrapper function that is called around the generated HTML markup on displacy.render. 
This can be used to allow integration into other platforms, similar to Jupyter Notebooks that require functions to be @@ -205,7 +209,7 @@ def set_render_wrapper(func): RENDER_WRAPPER = func -def get_doc_settings(doc): +def get_doc_settings(doc: Doc) -> Dict[str, Any]: return { "lang": doc.lang_, "direction": doc.vocab.writing_system.get("direction", "ltr"), diff --git a/spacy/displacy/render.py b/spacy/displacy/render.py index fcf4ccaa6..69f6df8f0 100644 --- a/spacy/displacy/render.py +++ b/spacy/displacy/render.py @@ -1,19 +1,36 @@ +from typing import Dict, Any, List, Optional, Union import uuid -from .templates import ( - TPL_DEP_SVG, - TPL_DEP_WORDS, - TPL_DEP_WORDS_LEMMA, - TPL_DEP_ARCS, - TPL_ENTS, -) +from .templates import TPL_DEP_SVG, TPL_DEP_WORDS, TPL_DEP_WORDS_LEMMA, TPL_DEP_ARCS from .templates import TPL_ENT, TPL_ENT_RTL, TPL_FIGURE, TPL_TITLE, TPL_PAGE +from .templates import TPL_ENTS from ..util import minify_html, escape_html, registry from ..errors import Errors DEFAULT_LANG = "en" DEFAULT_DIR = "ltr" +DEFAULT_ENTITY_COLOR = "#ddd" +DEFAULT_LABEL_COLORS = { + "ORG": "#7aecec", + "PRODUCT": "#bfeeb7", + "GPE": "#feca74", + "LOC": "#ff9561", + "PERSON": "#aa9cfc", + "NORP": "#c887fb", + "FACILITY": "#9cc9cc", + "EVENT": "#ffeb80", + "LAW": "#ff8197", + "LANGUAGE": "#ff8197", + "WORK_OF_ART": "#f0d0ff", + "DATE": "#bfe1d9", + "TIME": "#bfe1d9", + "MONEY": "#e4e7d2", + "QUANTITY": "#e4e7d2", + "ORDINAL": "#e4e7d2", + "CARDINAL": "#e4e7d2", + "PERCENT": "#e4e7d2", +} class DependencyRenderer: @@ -21,7 +38,7 @@ class DependencyRenderer: style = "dep" - def __init__(self, options={}): + def __init__(self, options: Dict[str, Any] = {}) -> None: """Initialise dependency renderer. options (dict): Visualiser-specific options (compact, word_spacing, @@ -41,7 +58,9 @@ class DependencyRenderer: self.direction = DEFAULT_DIR self.lang = DEFAULT_LANG - def render(self, parsed, page=False, minify=False): + def render( + self, parsed: List[Dict[str, Any]], page: bool = False, minify: bool = False + ) -> str: """Render complete markup. parsed (list): Dependency parses to render. @@ -72,10 +91,15 @@ class DependencyRenderer: return minify_html(markup) return markup - def render_svg(self, render_id, words, arcs): + def render_svg( + self, + render_id: Union[int, str], + words: List[Dict[str, Any]], + arcs: List[Dict[str, Any]], + ) -> str: """Render SVG. - render_id (int): Unique ID, typically index of document. + render_id (Union[int, str]): Unique ID, typically index of document. words (list): Individual words and their tags. arcs (list): Individual arcs and their start, end, direction and label. RETURNS (str): Rendered SVG markup. @@ -86,15 +110,15 @@ class DependencyRenderer: self.width = self.offset_x + len(words) * self.distance self.height = self.offset_y + 3 * self.word_spacing self.id = render_id - words = [ + words_svg = [ self.render_word(w["text"], w["tag"], w.get("lemma", None), i) for i, w in enumerate(words) ] - arcs = [ + arcs_svg = [ self.render_arrow(a["label"], a["start"], a["end"], a["dir"], i) for i, a in enumerate(arcs) ] - content = "".join(words) + "".join(arcs) + content = "".join(words_svg) + "".join(arcs_svg) return TPL_DEP_SVG.format( id=self.id, width=self.width, @@ -107,9 +131,7 @@ class DependencyRenderer: lang=self.lang, ) - def render_word( - self, text, tag, lemma, i, - ): + def render_word(self, text: str, tag: str, lemma: str, i: int) -> str: """Render individual word. text (str): Word text. 
@@ -128,7 +150,9 @@ class DependencyRenderer: ) return TPL_DEP_WORDS.format(text=html_text, tag=tag, x=x, y=y) - def render_arrow(self, label, start, end, direction, i): + def render_arrow( + self, label: str, start: int, end: int, direction: str, i: int + ) -> str: """Render individual arrow. label (str): Dependency label. @@ -172,7 +196,7 @@ class DependencyRenderer: arc=arc, ) - def get_arc(self, x_start, y, y_curve, x_end): + def get_arc(self, x_start: int, y: int, y_curve: int, x_end: int) -> str: """Render individual arc. x_start (int): X-coordinate of arrow start point. @@ -186,7 +210,7 @@ class DependencyRenderer: template = "M{x},{y} {x},{c} {e},{c} {e},{y}" return template.format(x=x_start, y=y, c=y_curve, e=x_end) - def get_arrowhead(self, direction, x, y, end): + def get_arrowhead(self, direction: str, x: int, y: int, end: int) -> str: """Render individual arrow head. direction (str): Arrow direction, 'left' or 'right'. @@ -196,24 +220,12 @@ class DependencyRenderer: RETURNS (str): Definition of the arrow head path ('d' attribute). """ if direction == "left": - pos1, pos2, pos3 = (x, x - self.arrow_width + 2, x + self.arrow_width - 2) + p1, p2, p3 = (x, x - self.arrow_width + 2, x + self.arrow_width - 2) else: - pos1, pos2, pos3 = ( - end, - end + self.arrow_width - 2, - end - self.arrow_width + 2, - ) - arrowhead = ( - pos1, - y + 2, - pos2, - y - self.arrow_width, - pos3, - y - self.arrow_width, - ) - return "M{},{} L{},{} {},{}".format(*arrowhead) + p1, p2, p3 = (end, end + self.arrow_width - 2, end - self.arrow_width + 2) + return f"M{p1},{y + 2} L{p2},{y - self.arrow_width} {p3},{y - self.arrow_width}" - def get_levels(self, arcs): + def get_levels(self, arcs: List[Dict[str, Any]]) -> List[int]: """Calculate available arc height "levels". Used to calculate arrow heights dynamically and without wasting space. @@ -229,41 +241,21 @@ class EntityRenderer: style = "ent" - def __init__(self, options={}): + def __init__(self, options: Dict[str, Any] = {}) -> None: """Initialise dependency renderer. options (dict): Visualiser-specific options (colors, ents) """ - colors = { - "ORG": "#7aecec", - "PRODUCT": "#bfeeb7", - "GPE": "#feca74", - "LOC": "#ff9561", - "PERSON": "#aa9cfc", - "NORP": "#c887fb", - "FACILITY": "#9cc9cc", - "EVENT": "#ffeb80", - "LAW": "#ff8197", - "LANGUAGE": "#ff8197", - "WORK_OF_ART": "#f0d0ff", - "DATE": "#bfe1d9", - "TIME": "#bfe1d9", - "MONEY": "#e4e7d2", - "QUANTITY": "#e4e7d2", - "ORDINAL": "#e4e7d2", - "CARDINAL": "#e4e7d2", - "PERCENT": "#e4e7d2", - } + colors = dict(DEFAULT_LABEL_COLORS) user_colors = registry.displacy_colors.get_all() for user_color in user_colors.values(): colors.update(user_color) colors.update(options.get("colors", {})) - self.default_color = "#ddd" + self.default_color = DEFAULT_ENTITY_COLOR self.colors = colors self.ents = options.get("ents", None) self.direction = DEFAULT_DIR self.lang = DEFAULT_LANG - template = options.get("template") if template: self.ent_template = template @@ -273,7 +265,9 @@ class EntityRenderer: else: self.ent_template = TPL_ENT - def render(self, parsed, page=False, minify=False): + def render( + self, parsed: List[Dict[str, Any]], page: bool = False, minify: bool = False + ) -> str: """Render complete markup. parsed (list): Dependency parses to render. 
@@ -297,7 +291,9 @@ class EntityRenderer: return minify_html(markup) return markup - def render_ents(self, text, spans, title): + def render_ents( + self, text: str, spans: List[Dict[str, Any]], title: Optional[str] + ) -> str: """Render entities in text. text (str): Original text. diff --git a/spacy/errors.py b/spacy/errors.py index fb50f913d..df6f82757 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -483,6 +483,8 @@ class Errors: E199 = ("Unable to merge 0-length span at doc[{start}:{end}].") # TODO: fix numbering after merging develop into master + E953 = ("Mismatched IDs received by the Tok2Vec listener: {id1} vs. {id2}") + E954 = ("The Tok2Vec listener did not receive a valid input.") E955 = ("Can't find table '{table}' for language '{lang}' in spacy-lookups-data.") E956 = ("Can't find component '{name}' in [components] block in the config. " "Available components: {opts}") diff --git a/spacy/gold/align.py b/spacy/gold/align.py index 0dd48d4cf..af70ee5b7 100644 --- a/spacy/gold/align.py +++ b/spacy/gold/align.py @@ -15,7 +15,7 @@ class Alignment: x2y = _make_ragged(x2y) y2x = _make_ragged(y2x) return Alignment(x2y=x2y, y2x=y2x) - + @classmethod def from_strings(cls, A: List[str], B: List[str]) -> "Alignment": x2y, y2x = tokenizations.get_alignments(A, B) diff --git a/spacy/language.py b/spacy/language.py index 1fec45a09..93d239356 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -163,7 +163,7 @@ class Language: else: if (self.lang and vocab.lang) and (self.lang != vocab.lang): raise ValueError(Errors.E150.format(nlp=self.lang, vocab=vocab.lang)) - self.vocab = vocab + self.vocab: Vocab = vocab if self.lang is None: self.lang = self.vocab.lang self.pipeline = [] diff --git a/spacy/ml/_precomputable_affine.py b/spacy/ml/_precomputable_affine.py index a3e2633e9..f5e5cd8ad 100644 --- a/spacy/ml/_precomputable_affine.py +++ b/spacy/ml/_precomputable_affine.py @@ -8,7 +8,7 @@ def PrecomputableAffine(nO, nI, nF, nP, dropout=0.1): init=init, dims={"nO": nO, "nI": nI, "nF": nF, "nP": nP}, params={"W": None, "b": None, "pad": None}, - attrs={"dropout_rate": dropout} + attrs={"dropout_rate": dropout}, ) return model diff --git a/spacy/ml/models/tok2vec.py b/spacy/ml/models/tok2vec.py index f5249ae24..1766fa80e 100644 --- a/spacy/ml/models/tok2vec.py +++ b/spacy/ml/models/tok2vec.py @@ -154,16 +154,30 @@ def LayerNormalizedMaxout(width, maxout_pieces): def MultiHashEmbed( columns, width, rows, use_subwords, pretrained_vectors, mix, dropout ): - norm = HashEmbed(nO=width, nV=rows, column=columns.index("NORM"), dropout=dropout, seed=6) + norm = HashEmbed( + nO=width, nV=rows, column=columns.index("NORM"), dropout=dropout, seed=6 + ) if use_subwords: prefix = HashEmbed( - nO=width, nV=rows // 2, column=columns.index("PREFIX"), dropout=dropout, seed=7 + nO=width, + nV=rows // 2, + column=columns.index("PREFIX"), + dropout=dropout, + seed=7, ) suffix = HashEmbed( - nO=width, nV=rows // 2, column=columns.index("SUFFIX"), dropout=dropout, seed=8 + nO=width, + nV=rows // 2, + column=columns.index("SUFFIX"), + dropout=dropout, + seed=8, ) shape = HashEmbed( - nO=width, nV=rows // 2, column=columns.index("SHAPE"), dropout=dropout, seed=9 + nO=width, + nV=rows // 2, + column=columns.index("SHAPE"), + dropout=dropout, + seed=9, ) if pretrained_vectors: @@ -192,7 +206,9 @@ def MultiHashEmbed( @registry.architectures.register("spacy.CharacterEmbed.v1") def CharacterEmbed(columns, width, rows, nM, nC, features, dropout): - norm = HashEmbed(nO=width, nV=rows, column=columns.index("NORM"), 
dropout=dropout, seed=5) + norm = HashEmbed( + nO=width, nV=rows, column=columns.index("NORM"), dropout=dropout, seed=5 + ) chr_embed = _character_embed.CharacterEmbed(nM=nM, nC=nC) with Model.define_operators({">>": chain, "|": concatenate}): embed_layer = chr_embed | features >> with_array(norm) @@ -263,21 +279,29 @@ def build_Tok2Vec_model( cols = [ID, NORM, PREFIX, SUFFIX, SHAPE, ORTH] with Model.define_operators({">>": chain, "|": concatenate, "**": clone}): norm = HashEmbed( - nO=width, nV=embed_size, column=cols.index(NORM), dropout=None, - seed=0 + nO=width, nV=embed_size, column=cols.index(NORM), dropout=None, seed=0 ) if subword_features: prefix = HashEmbed( - nO=width, nV=embed_size // 2, column=cols.index(PREFIX), dropout=None, - seed=1 + nO=width, + nV=embed_size // 2, + column=cols.index(PREFIX), + dropout=None, + seed=1, ) suffix = HashEmbed( - nO=width, nV=embed_size // 2, column=cols.index(SUFFIX), dropout=None, - seed=2 + nO=width, + nV=embed_size // 2, + column=cols.index(SUFFIX), + dropout=None, + seed=2, ) shape = HashEmbed( - nO=width, nV=embed_size // 2, column=cols.index(SHAPE), dropout=None, - seed=3 + nO=width, + nV=embed_size // 2, + column=cols.index(SHAPE), + dropout=None, + seed=3, ) else: prefix, suffix, shape = (None, None, None) @@ -294,11 +318,7 @@ def build_Tok2Vec_model( embed = uniqued( (glove | norm | prefix | suffix | shape) >> Maxout( - nO=width, - nI=width * columns, - nP=3, - dropout=0.0, - normalize=True, + nO=width, nI=width * columns, nP=3, dropout=0.0, normalize=True, ), column=cols.index(ORTH), ) @@ -307,11 +327,7 @@ def build_Tok2Vec_model( embed = uniqued( (glove | norm) >> Maxout( - nO=width, - nI=width * columns, - nP=3, - dropout=0.0, - normalize=True, + nO=width, nI=width * columns, nP=3, dropout=0.0, normalize=True, ), column=cols.index(ORTH), ) @@ -320,11 +336,7 @@ def build_Tok2Vec_model( embed = uniqued( concatenate(norm, prefix, suffix, shape) >> Maxout( - nO=width, - nI=width * columns, - nP=3, - dropout=0.0, - normalize=True, + nO=width, nI=width * columns, nP=3, dropout=0.0, normalize=True, ), column=cols.index(ORTH), ) @@ -333,11 +345,7 @@ def build_Tok2Vec_model( cols ) >> with_array(norm) reduce_dimensions = Maxout( - nO=width, - nI=nM * nC + width, - nP=3, - dropout=0.0, - normalize=True, + nO=width, nI=nM * nC + width, nP=3, dropout=0.0, normalize=True, ) else: embed = norm diff --git a/spacy/pipeline/__init__.py b/spacy/pipeline/__init__.py index 5075121bc..f8accd14f 100644 --- a/spacy/pipeline/__init__.py +++ b/spacy/pipeline/__init__.py @@ -10,7 +10,6 @@ from .simple_ner import SimpleNER from .tagger import Tagger from .textcat import TextCategorizer from .tok2vec import Tok2Vec -from .hooks import SentenceSegmenter, SimilarityHook from .functions import merge_entities, merge_noun_chunks, merge_subtokens __all__ = [ @@ -21,9 +20,7 @@ __all__ = [ "Morphologizer", "Pipe", "SentenceRecognizer", - "SentenceSegmenter", "Sentencizer", - "SimilarityHook", "SimpleNER", "Tagger", "TextCategorizer", diff --git a/spacy/pipeline/entityruler.py b/spacy/pipeline/entityruler.py index 07863b8e9..869968136 100644 --- a/spacy/pipeline/entityruler.py +++ b/spacy/pipeline/entityruler.py @@ -63,7 +63,7 @@ class EntityRuler: overwrite_ents: bool = False, ent_id_sep: str = DEFAULT_ENT_ID_SEP, patterns: Optional[List[PatternType]] = None, - ): + ) -> None: """Initialize the entitiy ruler. If patterns are supplied here, they need to be a list of dictionaries with a `"label"` and `"pattern"` key. 
A pattern can either be a token pattern (list) or a phrase pattern @@ -239,7 +239,6 @@ class EntityRuler: phrase_pattern_labels = [] phrase_pattern_texts = [] phrase_pattern_ids = [] - for entry in patterns: if isinstance(entry["pattern"], str): phrase_pattern_labels.append(entry["label"]) @@ -247,7 +246,6 @@ class EntityRuler: phrase_pattern_ids.append(entry.get("id")) elif isinstance(entry["pattern"], list): token_patterns.append(entry) - phrase_patterns = [] for label, pattern, ent_id in zip( phrase_pattern_labels, @@ -258,7 +256,6 @@ class EntityRuler: if ent_id: phrase_pattern["id"] = ent_id phrase_patterns.append(phrase_pattern) - for entry in token_patterns + phrase_patterns: label = entry["label"] if "id" in entry: @@ -266,7 +263,6 @@ class EntityRuler: label = self._create_label(label, entry["id"]) key = self.matcher._normalize_key(label) self._ent_ids[key] = (ent_label, entry["id"]) - pattern = entry["pattern"] if isinstance(pattern, Doc): self.phrase_patterns[label].append(pattern) @@ -323,13 +319,13 @@ class EntityRuler: self.clear() if isinstance(cfg, dict): self.add_patterns(cfg.get("patterns", cfg)) - self.overwrite = cfg.get("overwrite") + self.overwrite = cfg.get("overwrite", False) self.phrase_matcher_attr = cfg.get("phrase_matcher_attr", None) if self.phrase_matcher_attr is not None: self.phrase_matcher = PhraseMatcher( self.nlp.vocab, attr=self.phrase_matcher_attr ) - self.ent_id_sep = cfg.get("ent_id_sep") + self.ent_id_sep = cfg.get("ent_id_sep", DEFAULT_ENT_ID_SEP) else: self.add_patterns(cfg) return self @@ -375,9 +371,9 @@ class EntityRuler: } deserializers_cfg = {"cfg": lambda p: cfg.update(srsly.read_json(p))} from_disk(path, deserializers_cfg, {}) - self.overwrite = cfg.get("overwrite") + self.overwrite = cfg.get("overwrite", False) self.phrase_matcher_attr = cfg.get("phrase_matcher_attr") - self.ent_id_sep = cfg.get("ent_id_sep") + self.ent_id_sep = cfg.get("ent_id_sep", DEFAULT_ENT_ID_SEP) if self.phrase_matcher_attr is not None: self.phrase_matcher = PhraseMatcher( diff --git a/spacy/pipeline/hooks.py b/spacy/pipeline/hooks.py deleted file mode 100644 index 60654f6b7..000000000 --- a/spacy/pipeline/hooks.py +++ /dev/null @@ -1,95 +0,0 @@ -from thinc.api import concatenate, reduce_max, reduce_mean, siamese, CauchySimilarity - -from .pipe import Pipe -from ..util import link_vectors_to_models - - -# TODO: do we want to keep these? - - -class SentenceSegmenter: - """A simple spaCy hook, to allow custom sentence boundary detection logic - (that doesn't require the dependency parse). To change the sentence - boundary detection strategy, pass a generator function `strategy` on - initialization, or assign a new strategy to the .strategy attribute. - Sentence detection strategies should be generators that take `Doc` objects - and yield `Span` objects for each sentence. 
- """ - - def __init__(self, vocab, strategy=None): - self.vocab = vocab - if strategy is None or strategy == "on_punct": - strategy = self.split_on_punct - self.strategy = strategy - - def __call__(self, doc): - doc.user_hooks["sents"] = self.strategy - return doc - - @staticmethod - def split_on_punct(doc): - start = 0 - seen_period = False - for i, token in enumerate(doc): - if seen_period and not token.is_punct: - yield doc[start : token.i] - start = token.i - seen_period = False - elif token.text in [".", "!", "?"]: - seen_period = True - if start < len(doc): - yield doc[start : len(doc)] - - -class SimilarityHook(Pipe): - """ - Experimental: A pipeline component to install a hook for supervised - similarity into `Doc` objects. - The similarity model can be any object obeying the Thinc `Model` - interface. By default, the model concatenates the elementwise mean and - elementwise max of the two tensors, and compares them using the - Cauchy-like similarity function from Chen (2013): - - >>> similarity = 1. / (1. + (W * (vec1-vec2)**2).sum()) - - Where W is a vector of dimension weights, initialized to 1. - """ - - def __init__(self, vocab, model=True, **cfg): - self.vocab = vocab - self.model = model - self.cfg = dict(cfg) - - @classmethod - def Model(cls, length): - return siamese( - concatenate(reduce_max(), reduce_mean()), CauchySimilarity(length * 2) - ) - - def __call__(self, doc): - """Install similarity hook""" - doc.user_hooks["similarity"] = self.predict - return doc - - def pipe(self, docs, **kwargs): - for doc in docs: - yield self(doc) - - def predict(self, doc1, doc2): - return self.model.predict([(doc1, doc2)]) - - def update(self, doc1_doc2, golds, sgd=None, drop=0.0): - sims, bp_sims = self.model.begin_update(doc1_doc2) - - def begin_training(self, _=tuple(), pipeline=None, sgd=None, **kwargs): - """Allocate model, using nO from the first model in the pipeline. - - gold_tuples (iterable): Gold-standard training data. - pipeline (list): The pipeline the model is part of. 
- """ - if self.model is True: - self.model = self.Model(pipeline[0].model.get_dim("nO")) - link_vectors_to_models(self.vocab) - if sgd is None: - sgd = self.create_optimizer() - return sgd diff --git a/spacy/pipeline/textcat.py b/spacy/pipeline/textcat.py index bc68bb806..3c0808342 100644 --- a/spacy/pipeline/textcat.py +++ b/spacy/pipeline/textcat.py @@ -1,4 +1,4 @@ -from typing import Iterable, Tuple, Optional, Dict, List, Callable +from typing import Iterable, Tuple, Optional, Dict, List, Callable, Iterator, Any from thinc.api import get_array_module, Model, Optimizer, set_dropout_rate, Config import numpy @@ -97,7 +97,7 @@ class TextCategorizer(Pipe): def labels(self, value: Iterable[str]) -> None: self.cfg["labels"] = tuple(value) - def pipe(self, stream, batch_size=128): + def pipe(self, stream: Iterator[str], batch_size: int = 128) -> Iterator[Doc]: for docs in util.minibatch(stream, size=batch_size): scores = self.predict(docs) self.set_annotations(docs, scores) @@ -252,8 +252,17 @@ class TextCategorizer(Pipe): sgd = self.create_optimizer() return sgd - def score(self, examples, positive_label=None, **kwargs): - return Scorer.score_cats(examples, "cats", labels=self.labels, + def score( + self, + examples: Iterable[Example], + positive_label: Optional[str] = None, + **kwargs, + ) -> Dict[str, Any]: + return Scorer.score_cats( + examples, + "cats", + labels=self.labels, multi_label=self.model.attrs["multi_label"], - positive_label=positive_label, **kwargs + positive_label=positive_label, + **kwargs, ) diff --git a/spacy/pipeline/tok2vec.py b/spacy/pipeline/tok2vec.py index 0322ef26c..51a8b6a16 100644 --- a/spacy/pipeline/tok2vec.py +++ b/spacy/pipeline/tok2vec.py @@ -6,6 +6,7 @@ from ..gold import Example from ..tokens import Doc from ..vocab import Vocab from ..language import Language +from ..errors import Errors from ..util import link_vectors_to_models, minibatch @@ -150,7 +151,7 @@ class Tok2Vec(Pipe): self.set_annotations(docs, tokvecs) return losses - def get_loss(self, examples, scores): + def get_loss(self, examples, scores) -> None: pass def begin_training( @@ -184,26 +185,26 @@ class Tok2VecListener(Model): self._backprop = None @classmethod - def get_batch_id(cls, inputs): + def get_batch_id(cls, inputs) -> int: return sum(sum(token.orth for token in doc) for doc in inputs) - def receive(self, batch_id, outputs, backprop): + def receive(self, batch_id: int, outputs, backprop) -> None: self._batch_id = batch_id self._outputs = outputs self._backprop = backprop - def verify_inputs(self, inputs): + def verify_inputs(self, inputs) -> bool: if self._batch_id is None and self._outputs is None: - raise ValueError("The Tok2Vec listener did not receive valid input.") + raise ValueError(Errors.E954) else: batch_id = self.get_batch_id(inputs) if batch_id != self._batch_id: - raise ValueError(f"Mismatched IDs! 
{batch_id} vs {self._batch_id}") + raise ValueError(Errors.E953.format(id1=batch_id, id2=self._batch_id)) else: return True -def forward(model: Tok2VecListener, inputs, is_train): +def forward(model: Tok2VecListener, inputs, is_train: bool): if is_train: model.verify_inputs(inputs) return model._outputs, model._backprop diff --git a/spacy/schemas.py b/spacy/schemas.py index c6bdd6e9c..478755cf8 100644 --- a/spacy/schemas.py +++ b/spacy/schemas.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Union, Optional, Sequence, Any, Callable +from typing import Dict, List, Union, Optional, Sequence, Any, Callable, Type from enum import Enum from pydantic import BaseModel, Field, ValidationError, validator from pydantic import StrictStr, StrictInt, StrictFloat, StrictBool @@ -9,12 +9,12 @@ from thinc.api import Optimizer from .attrs import NAMES -def validate(schema, obj): +def validate(schema: Type[BaseModel], obj: Dict[str, Any]) -> List[str]: """Validate data against a given pydantic schema. - obj (dict): JSON-serializable data to validate. + obj (Dict[str, Any]): JSON-serializable data to validate. schema (pydantic.BaseModel): The schema to validate against. - RETURNS (list): A list of error messages, if available. + RETURNS (List[str]): A list of error messages, if available. """ try: schema(**obj) @@ -31,7 +31,7 @@ def validate(schema, obj): # Matcher token patterns -def validate_token_pattern(obj): +def validate_token_pattern(obj: list) -> List[str]: # Try to convert non-string keys (e.g. {ORTH: "foo"} -> {"ORTH": "foo"}) get_key = lambda k: NAMES[k] if isinstance(k, int) and k < len(NAMES) else k if isinstance(obj, list): diff --git a/spacy/tests/doc/test_add_entities.py b/spacy/tests/doc/test_add_entities.py index c4167f878..5fb5f0914 100644 --- a/spacy/tests/doc/test_add_entities.py +++ b/spacy/tests/doc/test_add_entities.py @@ -15,7 +15,8 @@ def test_doc_add_entities_set_ents_iob(en_vocab): "min_action_freq": 30, "update_with_oracle_cut_size": 100, } - model = registry.make_from_config({"model": DEFAULT_NER_MODEL}, validate=True)["model"] + cfg = {"model": DEFAULT_NER_MODEL} + model = registry.make_from_config(cfg, validate=True)["model"] ner = EntityRecognizer(en_vocab, model, **config) ner.begin_training([]) ner(doc) @@ -37,7 +38,8 @@ def test_ents_reset(en_vocab): "min_action_freq": 30, "update_with_oracle_cut_size": 100, } - model = registry.make_from_config({"model": DEFAULT_NER_MODEL}, validate=True)["model"] + cfg = {"model": DEFAULT_NER_MODEL} + model = registry.make_from_config(cfg, validate=True)["model"] ner = EntityRecognizer(en_vocab, model, **config) ner.begin_training([]) ner(doc) diff --git a/spacy/tests/lang/en/test_sbd.py b/spacy/tests/lang/en/test_sbd.py index 7c2e2e0bd..38c8d94d8 100644 --- a/spacy/tests/lang/en/test_sbd.py +++ b/spacy/tests/lang/en/test_sbd.py @@ -14,7 +14,9 @@ def test_en_sbd_single_punct(en_tokenizer, text, punct): assert sum(len(sent) for sent in doc.sents) == len(doc) -@pytest.mark.skip(reason="The step_through API was removed (but should be brought back)") +@pytest.mark.skip( + reason="The step_through API was removed (but should be brought back)" +) def test_en_sentence_breaks(en_tokenizer, en_parser): # fmt: off text = "This is a sentence . This is another one ." 
diff --git a/spacy/tests/lang/ja/test_tokenizer.py b/spacy/tests/lang/ja/test_tokenizer.py index 8f22cb24a..e52741b70 100644 --- a/spacy/tests/lang/ja/test_tokenizer.py +++ b/spacy/tests/lang/ja/test_tokenizer.py @@ -32,6 +32,25 @@ SENTENCE_TESTS = [ ("あれ。これ。", ["あれ。", "これ。"]), ("「伝染るんです。」という漫画があります。", ["「伝染るんです。」という漫画があります。"]), ] + +tokens1 = [ + DetailedToken(surface="委員", tag="名詞-普通名詞-一般", inf="", lemma="委員", reading="イイン", sub_tokens=None), + DetailedToken(surface="会", tag="名詞-普通名詞-一般", inf="", lemma="会", reading="カイ", sub_tokens=None), +] +tokens2 = [ + DetailedToken(surface="選挙", tag="名詞-普通名詞-サ変可能", inf="", lemma="選挙", reading="センキョ", sub_tokens=None), + DetailedToken(surface="管理", tag="名詞-普通名詞-サ変可能", inf="", lemma="管理", reading="カンリ", sub_tokens=None), + DetailedToken(surface="委員", tag="名詞-普通名詞-一般", inf="", lemma="委員", reading="イイン", sub_tokens=None), + DetailedToken(surface="会", tag="名詞-普通名詞-一般", inf="", lemma="会", reading="カイ", sub_tokens=None), +] +tokens3 = [ + DetailedToken(surface="選挙", tag="名詞-普通名詞-サ変可能", inf="", lemma="選挙", reading="センキョ", sub_tokens=None), + DetailedToken(surface="管理", tag="名詞-普通名詞-サ変可能", inf="", lemma="管理", reading="カンリ", sub_tokens=None), + DetailedToken(surface="委員会", tag="名詞-普通名詞-一般", inf="", lemma="委員会", reading="イインカイ", sub_tokens=None), +] +SUB_TOKEN_TESTS = [ + ("選挙管理委員会", [None, None, None, None], [None, None, [tokens1]], [[tokens2, tokens3]]) +] # fmt: on @@ -92,33 +111,12 @@ def test_ja_tokenizer_split_modes(ja_tokenizer, text, len_a, len_b, len_c): assert len(nlp_c(text)) == len_c -@pytest.mark.parametrize("text,sub_tokens_list_a,sub_tokens_list_b,sub_tokens_list_c", - [ - ( - "選挙管理委員会", - [None, None, None, None], - [None, None, [ - [ - DetailedToken(surface='委員', tag='名詞-普通名詞-一般', inf='', lemma='委員', reading='イイン', sub_tokens=None), - DetailedToken(surface='会', tag='名詞-普通名詞-一般', inf='', lemma='会', reading='カイ', sub_tokens=None), - ] - ]], - [[ - [ - DetailedToken(surface='選挙', tag='名詞-普通名詞-サ変可能', inf='', lemma='選挙', reading='センキョ', sub_tokens=None), - DetailedToken(surface='管理', tag='名詞-普通名詞-サ変可能', inf='', lemma='管理', reading='カンリ', sub_tokens=None), - DetailedToken(surface='委員', tag='名詞-普通名詞-一般', inf='', lemma='委員', reading='イイン', sub_tokens=None), - DetailedToken(surface='会', tag='名詞-普通名詞-一般', inf='', lemma='会', reading='カイ', sub_tokens=None), - ], [ - DetailedToken(surface='選挙', tag='名詞-普通名詞-サ変可能', inf='', lemma='選挙', reading='センキョ', sub_tokens=None), - DetailedToken(surface='管理', tag='名詞-普通名詞-サ変可能', inf='', lemma='管理', reading='カンリ', sub_tokens=None), - DetailedToken(surface='委員会', tag='名詞-普通名詞-一般', inf='', lemma='委員会', reading='イインカイ', sub_tokens=None), - ] - ]] - ), - ] +@pytest.mark.parametrize( + "text,sub_tokens_list_a,sub_tokens_list_b,sub_tokens_list_c", SUB_TOKEN_TESTS, ) -def test_ja_tokenizer_sub_tokens(ja_tokenizer, text, sub_tokens_list_a, sub_tokens_list_b, sub_tokens_list_c): +def test_ja_tokenizer_sub_tokens( + ja_tokenizer, text, sub_tokens_list_a, sub_tokens_list_b, sub_tokens_list_c +): nlp_a = Japanese.from_config({"nlp": {"tokenizer": {"split_mode": "A"}}}) nlp_b = Japanese.from_config({"nlp": {"tokenizer": {"split_mode": "B"}}}) nlp_c = Japanese.from_config({"nlp": {"tokenizer": {"split_mode": "C"}}}) @@ -129,16 +127,19 @@ def test_ja_tokenizer_sub_tokens(ja_tokenizer, text, sub_tokens_list_a, sub_toke assert nlp_c(text).user_data["sub_tokens"] == sub_tokens_list_c -@pytest.mark.parametrize("text,inflections,reading_forms", +@pytest.mark.parametrize( + "text,inflections,reading_forms", [ ( "取ってつけた", ("五段-ラ行,連用形-促音便", 
"", "下一段-カ行,連用形-一般", "助動詞-タ,終止形-一般"), ("トッ", "テ", "ツケ", "タ"), ), - ] + ], ) -def test_ja_tokenizer_inflections_reading_forms(ja_tokenizer, text, inflections, reading_forms): +def test_ja_tokenizer_inflections_reading_forms( + ja_tokenizer, text, inflections, reading_forms +): assert ja_tokenizer(text).user_data["inflections"] == inflections assert ja_tokenizer(text).user_data["reading_forms"] == reading_forms diff --git a/spacy/tests/lang/ne/test_text.py b/spacy/tests/lang/ne/test_text.py index 926a7de04..794f8fbdc 100644 --- a/spacy/tests/lang/ne/test_text.py +++ b/spacy/tests/lang/ne/test_text.py @@ -11,9 +11,8 @@ def test_ne_tokenizer_handlers_long_text(ne_tokenizer): @pytest.mark.parametrize( - "text,length", - [("समय जान कति पनि बेर लाग्दैन ।", 7), ("म ठूलो हुँदै थिएँ ।", 5)], + "text,length", [("समय जान कति पनि बेर लाग्दैन ।", 7), ("म ठूलो हुँदै थिएँ ।", 5)], ) def test_ne_tokenizer_handles_cnts(ne_tokenizer, text, length): tokens = ne_tokenizer(text) - assert len(tokens) == length \ No newline at end of file + assert len(tokens) == length diff --git a/spacy/tests/lang/zh/test_serialize.py b/spacy/tests/lang/zh/test_serialize.py index 544c4a7bc..015f92785 100644 --- a/spacy/tests/lang/zh/test_serialize.py +++ b/spacy/tests/lang/zh/test_serialize.py @@ -30,10 +30,7 @@ def test_zh_tokenizer_serialize_pkuseg_with_processors(zh_tokenizer_pkuseg): nlp = Chinese( meta={ "tokenizer": { - "config": { - "segmenter": "pkuseg", - "pkuseg_model": "medicine", - } + "config": {"segmenter": "pkuseg", "pkuseg_model": "medicine",} } } ) diff --git a/spacy/tests/parser/test_add_label.py b/spacy/tests/parser/test_add_label.py index a1e7dd388..88dfabdc8 100644 --- a/spacy/tests/parser/test_add_label.py +++ b/spacy/tests/parser/test_add_label.py @@ -22,7 +22,8 @@ def parser(vocab): "min_action_freq": 30, "update_with_oracle_cut_size": 100, } - model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"] + cfg = {"model": DEFAULT_PARSER_MODEL} + model = registry.make_from_config(cfg, validate=True)["model"] parser = DependencyParser(vocab, model, **config) return parser @@ -68,7 +69,8 @@ def test_add_label_deserializes_correctly(): "min_action_freq": 30, "update_with_oracle_cut_size": 100, } - model = registry.make_from_config({"model": DEFAULT_NER_MODEL}, validate=True)["model"] + cfg = {"model": DEFAULT_NER_MODEL} + model = registry.make_from_config(cfg, validate=True)["model"] ner1 = EntityRecognizer(Vocab(), model, **config) ner1.add_label("C") ner1.add_label("B") @@ -86,7 +88,10 @@ def test_add_label_deserializes_correctly(): @pytest.mark.parametrize( "pipe_cls,n_moves,model_config", - [(DependencyParser, 5, DEFAULT_PARSER_MODEL), (EntityRecognizer, 4, DEFAULT_NER_MODEL)], + [ + (DependencyParser, 5, DEFAULT_PARSER_MODEL), + (EntityRecognizer, 4, DEFAULT_NER_MODEL), + ], ) def test_add_label_get_label(pipe_cls, n_moves, model_config): """Test that added labels are returned correctly. 
This test was added to diff --git a/spacy/tests/parser/test_arc_eager_oracle.py b/spacy/tests/parser/test_arc_eager_oracle.py index 9781f71ed..77e142215 100644 --- a/spacy/tests/parser/test_arc_eager_oracle.py +++ b/spacy/tests/parser/test_arc_eager_oracle.py @@ -126,7 +126,8 @@ def test_get_oracle_actions(): "min_action_freq": 0, "update_with_oracle_cut_size": 100, } - model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"] + cfg = {"model": DEFAULT_PARSER_MODEL} + model = registry.make_from_config(cfg, validate=True)["model"] parser = DependencyParser(doc.vocab, model, **config) parser.moves.add_action(0, "") parser.moves.add_action(1, "") diff --git a/spacy/tests/parser/test_neural_parser.py b/spacy/tests/parser/test_neural_parser.py index a53a0f37a..feae52f7f 100644 --- a/spacy/tests/parser/test_neural_parser.py +++ b/spacy/tests/parser/test_neural_parser.py @@ -24,7 +24,8 @@ def arc_eager(vocab): @pytest.fixture def tok2vec(): - tok2vec = registry.make_from_config({"model": DEFAULT_TOK2VEC_MODEL}, validate=True)["model"] + cfg = {"model": DEFAULT_TOK2VEC_MODEL} + tok2vec = registry.make_from_config(cfg, validate=True)["model"] tok2vec.initialize() return tok2vec @@ -36,13 +37,15 @@ def parser(vocab, arc_eager): "min_action_freq": 30, "update_with_oracle_cut_size": 100, } - model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"] + cfg = {"model": DEFAULT_PARSER_MODEL} + model = registry.make_from_config(cfg, validate=True)["model"] return Parser(vocab, model, moves=arc_eager, **config) @pytest.fixture def model(arc_eager, tok2vec, vocab): - model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"] + cfg = {"model": DEFAULT_PARSER_MODEL} + model = registry.make_from_config(cfg, validate=True)["model"] model.attrs["resize_output"](model, arc_eager.n_moves) model.initialize() return model @@ -68,7 +71,8 @@ def test_build_model(parser, vocab): "min_action_freq": 0, "update_with_oracle_cut_size": 100, } - model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"] + cfg = {"model": DEFAULT_PARSER_MODEL} + model = registry.make_from_config(cfg, validate=True)["model"] parser.model = Parser(vocab, model=model, moves=parser.moves, **config).model assert parser.model is not None diff --git a/spacy/tests/parser/test_parse.py b/spacy/tests/parser/test_parse.py index 3f3fabbb8..45ae09702 100644 --- a/spacy/tests/parser/test_parse.py +++ b/spacy/tests/parser/test_parse.py @@ -33,7 +33,9 @@ def test_parser_root(en_tokenizer): assert t.dep != 0, t.text -@pytest.mark.skip(reason="The step_through API was removed (but should be brought back)") +@pytest.mark.skip( + reason="The step_through API was removed (but should be brought back)" +) @pytest.mark.parametrize("text", ["Hello"]) def test_parser_parse_one_word_sentence(en_tokenizer, en_parser, text): tokens = en_tokenizer(text) @@ -47,7 +49,9 @@ def test_parser_parse_one_word_sentence(en_tokenizer, en_parser, text): assert doc[0].dep != 0 -@pytest.mark.skip(reason="The step_through API was removed (but should be brought back)") +@pytest.mark.skip( + reason="The step_through API was removed (but should be brought back)" +) def test_parser_initial(en_tokenizer, en_parser): text = "I ate the pizza with anchovies." 
# heads = [1, 0, 1, -2, -3, -1, -5] @@ -92,7 +96,9 @@ def test_parser_merge_pp(en_tokenizer): assert doc[3].text == "occurs" -@pytest.mark.skip(reason="The step_through API was removed (but should be brought back)") +@pytest.mark.skip( + reason="The step_through API was removed (but should be brought back)" +) def test_parser_arc_eager_finalize_state(en_tokenizer, en_parser): text = "a b c d e" diff --git a/spacy/tests/parser/test_preset_sbd.py b/spacy/tests/parser/test_preset_sbd.py index 747203c2f..939181419 100644 --- a/spacy/tests/parser/test_preset_sbd.py +++ b/spacy/tests/parser/test_preset_sbd.py @@ -21,7 +21,8 @@ def parser(vocab): "min_action_freq": 30, "update_with_oracle_cut_size": 100, } - model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"] + cfg = {"model": DEFAULT_PARSER_MODEL} + model = registry.make_from_config(cfg, validate=True)["model"] parser = DependencyParser(vocab, model, **config) parser.cfg["token_vector_width"] = 4 parser.cfg["hidden_width"] = 32 diff --git a/spacy/tests/parser/test_space_attachment.py b/spacy/tests/parser/test_space_attachment.py index db25a25c0..3a0a6b943 100644 --- a/spacy/tests/parser/test_space_attachment.py +++ b/spacy/tests/parser/test_space_attachment.py @@ -28,7 +28,9 @@ def test_parser_sentence_space(en_tokenizer): assert len(list(doc.sents)) == 2 -@pytest.mark.skip(reason="The step_through API was removed (but should be brought back)") +@pytest.mark.skip( + reason="The step_through API was removed (but should be brought back)" +) def test_parser_space_attachment_leading(en_tokenizer, en_parser): text = "\t \n This is a sentence ." heads = [1, 1, 0, 1, -2, -3] @@ -44,7 +46,9 @@ def test_parser_space_attachment_leading(en_tokenizer, en_parser): assert stepwise.stack == set([2]) -@pytest.mark.skip(reason="The step_through API was removed (but should be brought back)") +@pytest.mark.skip( + reason="The step_through API was removed (but should be brought back)" +) def test_parser_space_attachment_intermediate_trailing(en_tokenizer, en_parser): text = "This is \t a \t\n \n sentence . 
\n\n \n" heads = [1, 0, -1, 2, -1, -4, -5, -1] @@ -64,7 +68,9 @@ def test_parser_space_attachment_intermediate_trailing(en_tokenizer, en_parser): @pytest.mark.parametrize("text,length", [(["\n"], 1), (["\n", "\t", "\n\n", "\t"], 4)]) -@pytest.mark.skip(reason="The step_through API was removed (but should be brought back)") +@pytest.mark.skip( + reason="The step_through API was removed (but should be brought back)" +) def test_parser_space_attachment_space(en_tokenizer, en_parser, text, length): doc = Doc(en_parser.vocab, words=text) assert len(doc) == length diff --git a/spacy/tests/pipeline/test_textcat.py b/spacy/tests/pipeline/test_textcat.py index ff284873d..5e8dab0bd 100644 --- a/spacy/tests/pipeline/test_textcat.py +++ b/spacy/tests/pipeline/test_textcat.py @@ -117,7 +117,9 @@ def test_overfitting_IO(): assert cats2["POSITIVE"] + cats2["NEGATIVE"] == pytest.approx(1.0, 0.1) # Test scoring - scores = nlp.evaluate(train_examples, component_cfg={"scorer": {"positive_label": "POSITIVE"}}) + scores = nlp.evaluate( + train_examples, component_cfg={"scorer": {"positive_label": "POSITIVE"}} + ) assert scores["cats_f"] == 1.0 diff --git a/spacy/tests/regression/test_issue3001-3500.py b/spacy/tests/regression/test_issue3001-3500.py index 6b4a9ad1d..e93c27a59 100644 --- a/spacy/tests/regression/test_issue3001-3500.py +++ b/spacy/tests/regression/test_issue3001-3500.py @@ -201,7 +201,8 @@ def test_issue3345(): "min_action_freq": 30, "update_with_oracle_cut_size": 100, } - model = registry.make_from_config({"model": DEFAULT_NER_MODEL}, validate=True)["model"] + cfg = {"model": DEFAULT_NER_MODEL} + model = registry.make_from_config(cfg, validate=True)["model"] ner = EntityRecognizer(doc.vocab, model, **config) # Add the OUT action. I wouldn't have thought this would be necessary... 
ner.moves.add_action(5, "") diff --git a/spacy/tests/serialize/test_serialize_pipeline.py b/spacy/tests/serialize/test_serialize_pipeline.py index 14a4579be..17d5a3a1e 100644 --- a/spacy/tests/serialize/test_serialize_pipeline.py +++ b/spacy/tests/serialize/test_serialize_pipeline.py @@ -20,7 +20,8 @@ def parser(en_vocab): "min_action_freq": 30, "update_with_oracle_cut_size": 100, } - model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"] + cfg = {"model": DEFAULT_PARSER_MODEL} + model = registry.make_from_config(cfg, validate=True)["model"] parser = DependencyParser(en_vocab, model, **config) parser.add_label("nsubj") return parser @@ -33,14 +34,16 @@ def blank_parser(en_vocab): "min_action_freq": 30, "update_with_oracle_cut_size": 100, } - model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"] + cfg = {"model": DEFAULT_PARSER_MODEL} + model = registry.make_from_config(cfg, validate=True)["model"] parser = DependencyParser(en_vocab, model, **config) return parser @pytest.fixture def taggers(en_vocab): - model = registry.make_from_config({"model": DEFAULT_TAGGER_MODEL}, validate=True)["model"] + cfg = {"model": DEFAULT_TAGGER_MODEL} + model = registry.make_from_config(cfg, validate=True)["model"] tagger1 = Tagger(en_vocab, model, set_morphology=True) tagger2 = Tagger(en_vocab, model, set_morphology=True) return tagger1, tagger2 @@ -53,7 +56,8 @@ def test_serialize_parser_roundtrip_bytes(en_vocab, Parser): "min_action_freq": 0, "update_with_oracle_cut_size": 100, } - model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"] + cfg = {"model": DEFAULT_PARSER_MODEL} + model = registry.make_from_config(cfg, validate=True)["model"] parser = Parser(en_vocab, model, **config) new_parser = Parser(en_vocab, model, **config) new_parser = new_parser.from_bytes(parser.to_bytes(exclude=["vocab"])) @@ -70,7 +74,8 @@ def test_serialize_parser_roundtrip_disk(en_vocab, Parser): "min_action_freq": 0, "update_with_oracle_cut_size": 100, } - model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"] + cfg = {"model": DEFAULT_PARSER_MODEL} + model = registry.make_from_config(cfg, validate=True)["model"] parser = Parser(en_vocab, model, **config) with make_tempdir() as d: file_path = d / "parser" @@ -88,7 +93,6 @@ def test_to_from_bytes(parser, blank_parser): assert blank_parser.model is not True assert blank_parser.moves.n_moves != parser.moves.n_moves bytes_data = parser.to_bytes(exclude=["vocab"]) - # the blank parser needs to be resized before we can call from_bytes blank_parser.model.attrs["resize_output"](blank_parser.model, parser.moves.n_moves) blank_parser.from_bytes(bytes_data) @@ -104,7 +108,8 @@ def test_serialize_tagger_roundtrip_bytes(en_vocab, taggers): tagger1_b = tagger1.to_bytes() tagger1 = tagger1.from_bytes(tagger1_b) assert tagger1.to_bytes() == tagger1_b - model = registry.make_from_config({"model": DEFAULT_TAGGER_MODEL}, validate=True)["model"] + cfg = {"model": DEFAULT_TAGGER_MODEL} + model = registry.make_from_config(cfg, validate=True)["model"] new_tagger1 = Tagger(en_vocab, model).from_bytes(tagger1_b) new_tagger1_b = new_tagger1.to_bytes() assert len(new_tagger1_b) == len(tagger1_b) @@ -118,7 +123,8 @@ def test_serialize_tagger_roundtrip_disk(en_vocab, taggers): file_path2 = d / "tagger2" tagger1.to_disk(file_path1) tagger2.to_disk(file_path2) - model = registry.make_from_config({"model": DEFAULT_TAGGER_MODEL}, validate=True)["model"] + cfg = 
{"model": DEFAULT_TAGGER_MODEL} + model = registry.make_from_config(cfg, validate=True)["model"] tagger1_d = Tagger(en_vocab, model, set_morphology=True).from_disk(file_path1) tagger2_d = Tagger(en_vocab, model, set_morphology=True).from_disk(file_path2) assert tagger1_d.to_bytes() == tagger2_d.to_bytes() @@ -126,21 +132,22 @@ def test_serialize_tagger_roundtrip_disk(en_vocab, taggers): def test_serialize_textcat_empty(en_vocab): # See issue #1105 - model = registry.make_from_config({"model": DEFAULT_TEXTCAT_MODEL}, validate=True)["model"] - textcat = TextCategorizer( - en_vocab, model, labels=["ENTITY", "ACTION", "MODIFIER"] - ) + cfg = {"model": DEFAULT_TEXTCAT_MODEL} + model = registry.make_from_config(cfg, validate=True)["model"] + textcat = TextCategorizer(en_vocab, model, labels=["ENTITY", "ACTION", "MODIFIER"]) textcat.to_bytes(exclude=["vocab"]) @pytest.mark.parametrize("Parser", test_parsers) def test_serialize_pipe_exclude(en_vocab, Parser): - model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"] + cfg = {"model": DEFAULT_PARSER_MODEL} + model = registry.make_from_config(cfg, validate=True)["model"] config = { "learn_tokens": False, "min_action_freq": 0, "update_with_oracle_cut_size": 100, } + def get_new_parser(): new_parser = Parser(en_vocab, model, **config) return new_parser @@ -160,7 +167,8 @@ def test_serialize_pipe_exclude(en_vocab, Parser): def test_serialize_sentencerecognizer(en_vocab): - model = registry.make_from_config({"model": DEFAULT_SENTER_MODEL}, validate=True)["model"] + cfg = {"model": DEFAULT_SENTER_MODEL} + model = registry.make_from_config(cfg, validate=True)["model"] sr = SentenceRecognizer(en_vocab, model) sr_b = sr.to_bytes() sr_d = SentenceRecognizer(en_vocab, model).from_bytes(sr_b) diff --git a/spacy/tests/test_gold.py b/spacy/tests/test_gold.py index b03765857..c44daf630 100644 --- a/spacy/tests/test_gold.py +++ b/spacy/tests/test_gold.py @@ -466,7 +466,6 @@ def test_iob_to_biluo(): def test_roundtrip_docs_to_docbin(doc): - nlp = English() text = doc.text idx = [t.idx for t in doc] tags = [t.tag_ for t in doc] diff --git a/spacy/tests/test_new_example.py b/spacy/tests/test_new_example.py index f858b0759..886a24a8e 100644 --- a/spacy/tests/test_new_example.py +++ b/spacy/tests/test_new_example.py @@ -7,11 +7,11 @@ from spacy.vocab import Vocab def test_Example_init_requires_doc_objects(): vocab = Vocab() with pytest.raises(TypeError): - example = Example(None, None) + Example(None, None) with pytest.raises(TypeError): - example = Example(Doc(vocab, words=["hi"]), None) + Example(Doc(vocab, words=["hi"]), None) with pytest.raises(TypeError): - example = Example(None, Doc(vocab, words=["hi"])) + Example(None, Doc(vocab, words=["hi"])) def test_Example_from_dict_basic(): diff --git a/spacy/tests/test_scorer.py b/spacy/tests/test_scorer.py index fea263df5..422b4e328 100644 --- a/spacy/tests/test_scorer.py +++ b/spacy/tests/test_scorer.py @@ -105,7 +105,11 @@ def test_tokenization(sented_doc): assert scores["token_acc"] == 1.0 nlp = English() - example.predicted = Doc(nlp.vocab, words=["One", "sentence.", "Two", "sentences.", "Three", "sentences."], spaces=[True, True, True, True, True, False]) + example.predicted = Doc( + nlp.vocab, + words=["One", "sentence.", "Two", "sentences.", "Three", "sentences."], + spaces=[True, True, True, True, True, False], + ) example.predicted[1].is_sent_start = False scores = scorer.score([example]) assert scores["token_acc"] == approx(0.66666666) diff --git a/spacy/util.py 
b/spacy/util.py index 682d45bc9..c98ce2354 100644 --- a/spacy/util.py +++ b/spacy/util.py @@ -1,12 +1,13 @@ from typing import List, Union, Dict, Any, Optional, Iterable, Callable, Tuple -from typing import Iterator, TYPE_CHECKING +from typing import Iterator, Type, Pattern, Sequence, TYPE_CHECKING +from types import ModuleType import os import importlib import importlib.util import re from pathlib import Path import thinc -from thinc.api import NumpyOps, get_current_ops, Adam, Config +from thinc.api import NumpyOps, get_current_ops, Adam, Config, Optimizer import functools import itertools import numpy.random @@ -49,6 +50,8 @@ from . import about if TYPE_CHECKING: # This lets us add type hints for mypy etc. without causing circular imports from .language import Language # noqa: F401 + from .tokens import Doc, Span # noqa: F401 + from .vocab import Vocab # noqa: F401 _PRINT_ENV = False @@ -102,12 +105,12 @@ class SimpleFrozenDict(dict): raise NotImplementedError(self.error) -def set_env_log(value): +def set_env_log(value: bool) -> None: global _PRINT_ENV _PRINT_ENV = value -def lang_class_is_loaded(lang): +def lang_class_is_loaded(lang: str) -> bool: """Check whether a Language class is already loaded. Language classes are loaded lazily, to avoid expensive setup code associated with the language data. @@ -118,7 +121,7 @@ def lang_class_is_loaded(lang): return lang in registry.languages -def get_lang_class(lang): +def get_lang_class(lang: str) -> "Language": """Import and load a Language class. lang (str): Two-letter language code, e.g. 'en'. @@ -136,7 +139,7 @@ def get_lang_class(lang): return registry.languages.get(lang) -def set_lang_class(name, cls): +def set_lang_class(name: str, cls: Type["Language"]) -> None: """Set a custom Language class name that can be loaded via get_lang_class. name (str): Name of Language class. @@ -145,10 +148,10 @@ def set_lang_class(name, cls): registry.languages.register(name, func=cls) -def ensure_path(path): +def ensure_path(path: Any) -> Any: """Ensure string is converted to a Path. - path: Anything. If string, it's converted to Path. + path (Any): Anything. If string, it's converted to Path. RETURNS: Path or original argument. """ if isinstance(path, str): @@ -157,7 +160,7 @@ def ensure_path(path): return path -def load_language_data(path): +def load_language_data(path: Union[str, Path]) -> Union[dict, list]: """Load JSON language data using the given path as a base. If the provided path isn't present, will attempt to load a gzipped version before giving up. @@ -173,7 +176,12 @@ def load_language_data(path): raise ValueError(Errors.E160.format(path=path)) -def get_module_path(module): +def get_module_path(module: ModuleType) -> Path: + """Get the path of a Python module. + + module (ModuleType): The Python module. + RETURNS (Path): The path. + """ if not hasattr(module, "__module__"): raise ValueError(Errors.E169.format(module=repr(module))) return Path(sys.modules[module.__module__].__file__).parent @@ -183,7 +191,7 @@ def load_model( name: Union[str, Path], disable: Iterable[str] = tuple(), component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(), -): +) -> "Language": """Load a model from a package or data path. name (str): Package name or model path. 
@@ -209,7 +217,7 @@ def load_model_from_package( name: str, disable: Iterable[str] = tuple(), component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(), -): +) -> "Language": """Load a model from an installed package.""" cls = importlib.import_module(name) return cls.load(disable=disable, component_cfg=component_cfg) @@ -220,7 +228,7 @@ def load_model_from_path( meta: Optional[Dict[str, Any]] = None, disable: Iterable[str] = tuple(), component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(), -): +) -> "Language": """Load a model from a data directory path. Creates Language class with pipeline from config.cfg and then calls from_disk() with path.""" if not model_path.exists(): @@ -269,7 +277,7 @@ def load_model_from_init_py( init_file: Union[Path, str], disable: Iterable[str] = tuple(), component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(), -): +) -> "Language": """Helper function to use in the `load()` method of a model package's __init__.py. @@ -288,15 +296,15 @@ def load_model_from_init_py( ) -def get_installed_models(): +def get_installed_models() -> List[str]: """List all model packages currently installed in the environment. - RETURNS (list): The string names of the models. + RETURNS (List[str]): The string names of the models. """ return list(registry.models.get_all().keys()) -def get_package_version(name): +def get_package_version(name: str) -> Optional[str]: """Get the version of an installed package. Typically used to get model package versions. @@ -309,7 +317,9 @@ def get_package_version(name): return None -def is_compatible_version(version, constraint, prereleases=True): +def is_compatible_version( + version: str, constraint: str, prereleases: bool = True +) -> Optional[bool]: """Check if a version (e.g. "2.0.0") is compatible given a version constraint (e.g. ">=1.9.0,<2.2.1"). If the constraint is a specific version, it's interpreted as =={version}. @@ -333,7 +343,9 @@ def is_compatible_version(version, constraint, prereleases=True): return version in spec -def is_unconstrained_version(constraint, prereleases=True): +def is_unconstrained_version( + constraint: str, prereleases: bool = True +) -> Optional[bool]: # We have an exact version, this is the ultimate constrained version if constraint[0].isdigit(): return False @@ -358,7 +370,7 @@ def is_unconstrained_version(constraint, prereleases=True): return True -def get_model_version_range(spacy_version): +def get_model_version_range(spacy_version: str) -> str: """Generate a version range like >=1.2.3,<1.3.0 based on a given spaCy version. Models are always compatible across patch versions but not across minor or major versions. @@ -367,7 +379,7 @@ def get_model_version_range(spacy_version): return f">={spacy_version},<{release[0]}.{release[1] + 1}.0" -def get_base_version(version): +def get_base_version(version: str) -> str: """Generate the base version without any prerelease identifiers. version (str): The version, e.g. "3.0.0.dev1". @@ -376,11 +388,11 @@ def get_base_version(version): return Version(version).base_version -def get_model_meta(path): +def get_model_meta(path: Union[str, Path]) -> Dict[str, Any]: """Get model meta.json from a directory path and validate its contents. path (str / Path): Path to model directory. - RETURNS (dict): The model's meta data. + RETURNS (Dict[str, Any]): The model's meta data. 
""" model_path = ensure_path(path) if not model_path.exists(): @@ -412,7 +424,7 @@ def get_model_meta(path): return meta -def is_package(name): +def is_package(name: str) -> bool: """Check if string maps to a package installed via pip. name (str): Name of package. @@ -425,7 +437,7 @@ def is_package(name): return False -def get_package_path(name): +def get_package_path(name: str) -> Path: """Get the path to an installed package. name (str): Package name. @@ -495,7 +507,7 @@ def working_dir(path: Union[str, Path]) -> None: @contextmanager -def make_tempdir(): +def make_tempdir() -> None: """Execute a block in a temporary directory and remove the directory and its contents at the end of the with block. @@ -518,7 +530,7 @@ def is_cwd(path: Union[Path, str]) -> bool: return str(Path(path).resolve()).lower() == str(Path.cwd().resolve()).lower() -def is_in_jupyter(): +def is_in_jupyter() -> bool: """Check if user is running spaCy from a Jupyter notebook by detecting the IPython kernel. Mainly used for the displaCy visualizer. RETURNS (bool): True if in Jupyter, False if not. @@ -548,7 +560,9 @@ def get_object_name(obj: Any) -> str: return repr(obj) -def get_cuda_stream(require=False, non_blocking=True): +def get_cuda_stream( + require: bool = False, non_blocking: bool = True +) -> Optional[CudaStream]: ops = get_current_ops() if CudaStream is None: return None @@ -567,7 +581,7 @@ def get_async(stream, numpy_array): return array -def env_opt(name, default=None): +def env_opt(name: str, default: Optional[Any] = None) -> Optional[Any]: if type(default) is float: type_convert = float else: @@ -588,7 +602,7 @@ def env_opt(name, default=None): return default -def read_regex(path): +def read_regex(path: Union[str, Path]) -> Pattern: path = ensure_path(path) with path.open(encoding="utf8") as file_: entries = file_.read().split("\n") @@ -598,37 +612,40 @@ def read_regex(path): return re.compile(expression) -def compile_prefix_regex(entries): +def compile_prefix_regex(entries: Iterable[Union[str, Pattern]]) -> Pattern: """Compile a sequence of prefix rules into a regex object. - entries (tuple): The prefix rules, e.g. spacy.lang.punctuation.TOKENIZER_PREFIXES. - RETURNS (regex object): The regex object. to be used for Tokenizer.prefix_search. + entries (Iterable[Union[str, Pattern]]): The prefix rules, e.g. + spacy.lang.punctuation.TOKENIZER_PREFIXES. + RETURNS (Pattern): The regex object. to be used for Tokenizer.prefix_search. """ expression = "|".join(["^" + piece for piece in entries if piece.strip()]) return re.compile(expression) -def compile_suffix_regex(entries): +def compile_suffix_regex(entries: Iterable[Union[str, Pattern]]) -> Pattern: """Compile a sequence of suffix rules into a regex object. - entries (tuple): The suffix rules, e.g. spacy.lang.punctuation.TOKENIZER_SUFFIXES. - RETURNS (regex object): The regex object. to be used for Tokenizer.suffix_search. + entries (Iterable[Union[str, Pattern]]): The suffix rules, e.g. + spacy.lang.punctuation.TOKENIZER_SUFFIXES. + RETURNS (Pattern): The regex object. to be used for Tokenizer.suffix_search. """ expression = "|".join([piece + "$" for piece in entries if piece.strip()]) return re.compile(expression) -def compile_infix_regex(entries): +def compile_infix_regex(entries: Iterable[Union[str, Pattern]]) -> Pattern: """Compile a sequence of infix rules into a regex object. - entries (tuple): The infix rules, e.g. spacy.lang.punctuation.TOKENIZER_INFIXES. + entries (Iterable[Union[str, Pattern]]): The infix rules, e.g. 
+ spacy.lang.punctuation.TOKENIZER_INFIXES. RETURNS (regex object): The regex object. to be used for Tokenizer.infix_finditer. """ expression = "|".join([piece for piece in entries if piece.strip()]) return re.compile(expression) -def add_lookups(default_func, *lookups): +def add_lookups(default_func: Callable[[str], Any], *lookups) -> Callable[[str], Any]: """Extend an attribute function with special cases. If a word is in the lookups, the value is returned. Otherwise the previous function is used. @@ -641,19 +658,23 @@ def add_lookups(default_func, *lookups): return functools.partial(_get_attr_unless_lookup, default_func, lookups) -def _get_attr_unless_lookup(default_func, lookups, string): +def _get_attr_unless_lookup( + default_func: Callable[[str], Any], lookups: Dict[str, Any], string: str +) -> Any: for lookup in lookups: if string in lookup: return lookup[string] return default_func(string) -def update_exc(base_exceptions, *addition_dicts): +def update_exc( + base_exceptions: Dict[str, List[dict]], *addition_dicts +) -> Dict[str, List[dict]]: """Update and validate tokenizer exceptions. Will overwrite exceptions. - base_exceptions (dict): Base exceptions. - *addition_dicts (dict): Exceptions to add to the base dict, in order. - RETURNS (dict): Combined tokenizer exceptions. + base_exceptions (Dict[str, List[dict]]): Base exceptions. + *addition_dicts (Dict[str, List[dict]]): Exceptions to add to the base dict, in order. + RETURNS (Dict[str, List[dict]]): Combined tokenizer exceptions. """ exc = dict(base_exceptions) for additions in addition_dicts: @@ -668,14 +689,16 @@ def update_exc(base_exceptions, *addition_dicts): return exc -def expand_exc(excs, search, replace): +def expand_exc( + excs: Dict[str, List[dict]], search: str, replace: str +) -> Dict[str, List[dict]]: """Find string in tokenizer exceptions, duplicate entry and replace string. For example, to add additional versions with typographic apostrophes. - excs (dict): Tokenizer exceptions. + excs (Dict[str, List[dict]]): Tokenizer exceptions. search (str): String to find and replace. replace (str): Replacement. - RETURNS (dict): Combined tokenizer exceptions. + RETURNS (Dict[str, List[dict]]): Combined tokenizer exceptions. """ def _fix_token(token, search, replace): @@ -692,7 +715,9 @@ def expand_exc(excs, search, replace): return new_excs -def normalize_slice(length, start, stop, step=None): +def normalize_slice( + length: int, start: int, stop: int, step: Optional[int] = None +) -> Tuple[int, int]: if not (step is None or step == 1): raise ValueError(Errors.E057) if start is None: @@ -708,7 +733,9 @@ def normalize_slice(length, start, stop, step=None): return start, stop -def minibatch(items, size=8): +def minibatch( + items: Iterable[Any], size: Union[Iterator[int], int] = 8 +) -> Iterator[Any]: """Iterate over batches of items. `size` may be an iterator, so that batch-size can vary on each step. 
""" @@ -725,7 +752,12 @@ def minibatch(items, size=8): yield list(batch) -def minibatch_by_padded_size(docs, size, buffer=256, discard_oversize=False): +def minibatch_by_padded_size( + docs: Iterator["Doc"], + size: Union[Iterator[int], int], + buffer: int = 256, + discard_oversize: bool = False, +) -> Iterator[Iterator["Doc"]]: if isinstance(size, int): size_ = itertools.repeat(size) else: @@ -742,7 +774,7 @@ def minibatch_by_padded_size(docs, size, buffer=256, discard_oversize=False): yield subbatch -def _batch_by_length(seqs, max_words): +def _batch_by_length(seqs: Sequence[Any], max_words: int) -> List[List[Any]]: """Given a list of sequences, return a batched list of indices into the list, where the batches are grouped by length, in descending order. @@ -783,14 +815,12 @@ def minibatch_by_words(docs, size, tolerance=0.2, discard_oversize=False): size_ = iter(size) else: size_ = size - target_size = next(size_) tol_size = target_size * tolerance batch = [] overflow = [] batch_size = 0 overflow_size = 0 - for doc in docs: if isinstance(doc, Example): n_words = len(doc.reference) @@ -803,17 +833,14 @@ def minibatch_by_words(docs, size, tolerance=0.2, discard_oversize=False): if n_words > target_size + tol_size: if not discard_oversize: yield [doc] - # add the example to the current batch if there's no overflow yet and it still fits elif overflow_size == 0 and (batch_size + n_words) <= target_size: batch.append(doc) batch_size += n_words - # add the example to the overflow buffer if it fits in the tolerance margin elif (batch_size + overflow_size + n_words) <= (target_size + tol_size): overflow.append(doc) overflow_size += n_words - # yield the previous batch and start a new one. The new one gets the overflow examples. else: if batch: @@ -824,17 +851,14 @@ def minibatch_by_words(docs, size, tolerance=0.2, discard_oversize=False): batch_size = overflow_size overflow = [] overflow_size = 0 - # this example still fits if (batch_size + n_words) <= target_size: batch.append(doc) batch_size += n_words - # this example fits in overflow elif (batch_size + n_words) <= (target_size + tol_size): overflow.append(doc) overflow_size += n_words - # this example does not fit with the previous overflow: start another new batch else: if batch: @@ -843,20 +867,19 @@ def minibatch_by_words(docs, size, tolerance=0.2, discard_oversize=False): tol_size = target_size * tolerance batch = [doc] batch_size = n_words - batch.extend(overflow) if batch: yield batch -def filter_spans(spans): +def filter_spans(spans: Iterable["Span"]) -> List["Span"]: """Filter a sequence of spans and remove duplicates or overlaps. Useful for creating named entities (where one token can only be part of one entity) or when merging spans with `Retokenizer.merge`. When spans overlap, the (first) longest span is preferred over shorter spans. - spans (iterable): The spans to filter. - RETURNS (list): The filtered spans. + spans (Iterable[Span]): The spans to filter. + RETURNS (List[Span]): The filtered spans. 
""" get_sort_key = lambda span: (span.end - span.start, -span.start) sorted_spans = sorted(spans, key=get_sort_key, reverse=True) @@ -871,15 +894,21 @@ def filter_spans(spans): return result -def to_bytes(getters, exclude): +def to_bytes(getters: Dict[str, Callable[[], bytes]], exclude: Iterable[str]) -> bytes: return srsly.msgpack_dumps(to_dict(getters, exclude)) -def from_bytes(bytes_data, setters, exclude): +def from_bytes( + bytes_data: bytes, + setters: Dict[str, Callable[[bytes], Any]], + exclude: Iterable[str], +) -> None: return from_dict(srsly.msgpack_loads(bytes_data), setters, exclude) -def to_dict(getters, exclude): +def to_dict( + getters: Dict[str, Callable[[], Any]], exclude: Iterable[str] +) -> Dict[str, Any]: serialized = {} for key, getter in getters.items(): # Split to support file names like meta.json @@ -888,7 +917,11 @@ def to_dict(getters, exclude): return serialized -def from_dict(msg, setters, exclude): +def from_dict( + msg: Dict[str, Any], + setters: Dict[str, Callable[[Any], Any]], + exclude: Iterable[str], +) -> Dict[str, Any]: for key, setter in setters.items(): # Split to support file names like meta.json if key.split(".")[0] not in exclude and key in msg: @@ -896,7 +929,11 @@ def from_dict(msg, setters, exclude): return msg -def to_disk(path, writers, exclude): +def to_disk( + path: Union[str, Path], + writers: Dict[str, Callable[[Path], None]], + exclude: Iterable[str], +) -> Path: path = ensure_path(path) if not path.exists(): path.mkdir() @@ -907,7 +944,11 @@ def to_disk(path, writers, exclude): return path -def from_disk(path, readers, exclude): +def from_disk( + path: Union[str, Path], + readers: Dict[str, Callable[[Path], None]], + exclude: Iterable[str], +) -> Path: path = ensure_path(path) for key, reader in readers.items(): # Split to support file names like meta.json @@ -916,7 +957,7 @@ def from_disk(path, readers, exclude): return path -def import_file(name, loc): +def import_file(name: str, loc: Union[str, Path]) -> ModuleType: """Import module from a file. Used to load models from a directory. name (str): Name of module to load. @@ -930,7 +971,7 @@ def import_file(name, loc): return module -def minify_html(html): +def minify_html(html: str) -> str: """Perform a template-specific, rudimentary HTML minification for displaCy. Disclaimer: NOT a general-purpose solution, only removes indentation and newlines. @@ -941,7 +982,7 @@ def minify_html(html): return html.strip().replace(" ", "").replace("\n", "") -def escape_html(text): +def escape_html(text: str) -> str: """Replace <, >, &, " with their HTML encoded representation. Intended to prevent HTML errors in rendered displaCy markup. @@ -955,7 +996,9 @@ def escape_html(text): return text -def get_words_and_spaces(words, text): +def get_words_and_spaces( + words: Iterable[str], text: str +) -> Tuple[List[str], List[bool]]: if "".join("".join(words).split()) != "".join(text.split()): raise ValueError(Errors.E194.format(text=text, words=words)) text_words = [] @@ -1103,7 +1146,7 @@ class DummyTokenizer: return self -def link_vectors_to_models(vocab): +def link_vectors_to_models(vocab: "Vocab") -> None: vectors = vocab.vectors if vectors.name is None: vectors.name = VECTORS_KEY @@ -1119,7 +1162,7 @@ def link_vectors_to_models(vocab): VECTORS_KEY = "spacy_pretrained_vectors" -def create_default_optimizer(): +def create_default_optimizer() -> Optimizer: learn_rate = env_opt("learn_rate", 0.001) beta1 = env_opt("optimizer_B1", 0.9) beta2 = env_opt("optimizer_B2", 0.999)