mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-25 05:01:02 +03:00 
			
		
		
		
	Merge pull request #5813 from explosion/chore/tidy-autoformat-types
Tidy up, autoformat, add types
This commit is contained in:
		
						commit
						1346ee06d4
					
				
							
								
								
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							|  | @ -51,6 +51,7 @@ env3.*/ | ||||||
| .denv | .denv | ||||||
| .pypyenv | .pypyenv | ||||||
| .pytest_cache/ | .pytest_cache/ | ||||||
|  | .mypy_cache/ | ||||||
| 
 | 
 | ||||||
| # Distribution / packaging | # Distribution / packaging | ||||||
| env/ | env/ | ||||||
|  |  | ||||||
|  | @ -108,3 +108,8 @@ exclude = | ||||||
| [tool:pytest] | [tool:pytest] | ||||||
| markers = | markers = | ||||||
|     slow |     slow | ||||||
|  | 
 | ||||||
|  | [mypy] | ||||||
|  | ignore_missing_imports = True | ||||||
|  | no_implicit_optional = True | ||||||
|  | plugins = pydantic.mypy, thinc.mypy | ||||||
|  |  | ||||||
|  | @ -1,4 +1,4 @@ | ||||||
| from typing import Optional | from typing import Optional, Any, List, Union | ||||||
| from enum import Enum | from enum import Enum | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| from wasabi import Printer | from wasabi import Printer | ||||||
|  | @ -66,10 +66,9 @@ def convert_cli( | ||||||
|         file_type = file_type.value |         file_type = file_type.value | ||||||
|     input_path = Path(input_path) |     input_path = Path(input_path) | ||||||
|     output_dir = "-" if output_dir == Path("-") else output_dir |     output_dir = "-" if output_dir == Path("-") else output_dir | ||||||
|     cli_args = locals() |  | ||||||
|     silent = output_dir == "-" |     silent = output_dir == "-" | ||||||
|     msg = Printer(no_print=silent) |     msg = Printer(no_print=silent) | ||||||
|     verify_cli_args(msg, **cli_args) |     verify_cli_args(msg, input_path, output_dir, file_type, converter, ner_map) | ||||||
|     converter = _get_converter(msg, converter, input_path) |     converter = _get_converter(msg, converter, input_path) | ||||||
|     convert( |     convert( | ||||||
|         input_path, |         input_path, | ||||||
|  | @ -89,8 +88,8 @@ def convert_cli( | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def convert( | def convert( | ||||||
|     input_path: Path, |     input_path: Union[str, Path], | ||||||
|     output_dir: Path, |     output_dir: Union[str, Path], | ||||||
|     *, |     *, | ||||||
|     file_type: str = "json", |     file_type: str = "json", | ||||||
|     n_sents: int = 1, |     n_sents: int = 1, | ||||||
|  | @ -102,13 +101,12 @@ def convert( | ||||||
|     ner_map: Optional[Path] = None, |     ner_map: Optional[Path] = None, | ||||||
|     lang: Optional[str] = None, |     lang: Optional[str] = None, | ||||||
|     silent: bool = True, |     silent: bool = True, | ||||||
|     msg: Optional[Path] = None, |     msg: Optional[Printer], | ||||||
| ) -> None: | ) -> None: | ||||||
|     if not msg: |     if not msg: | ||||||
|         msg = Printer(no_print=silent) |         msg = Printer(no_print=silent) | ||||||
|     ner_map = srsly.read_json(ner_map) if ner_map is not None else None |     ner_map = srsly.read_json(ner_map) if ner_map is not None else None | ||||||
| 
 |     for input_loc in walk_directory(Path(input_path)): | ||||||
|     for input_loc in walk_directory(input_path): |  | ||||||
|         input_data = input_loc.open("r", encoding="utf-8").read() |         input_data = input_loc.open("r", encoding="utf-8").read() | ||||||
|         # Use converter function to convert data |         # Use converter function to convert data | ||||||
|         func = CONVERTERS[converter] |         func = CONVERTERS[converter] | ||||||
|  | @ -140,14 +138,14 @@ def convert( | ||||||
|             msg.good(f"Generated output file ({len(docs)} documents): {output_file}") |             msg.good(f"Generated output file ({len(docs)} documents): {output_file}") | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def _print_docs_to_stdout(data, output_type): | def _print_docs_to_stdout(data: Any, output_type: str) -> None: | ||||||
|     if output_type == "json": |     if output_type == "json": | ||||||
|         srsly.write_json("-", data) |         srsly.write_json("-", data) | ||||||
|     else: |     else: | ||||||
|         sys.stdout.buffer.write(data) |         sys.stdout.buffer.write(data) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def _write_docs_to_file(data, output_file, output_type): | def _write_docs_to_file(data: Any, output_file: Path, output_type: str) -> None: | ||||||
|     if not output_file.parent.exists(): |     if not output_file.parent.exists(): | ||||||
|         output_file.parent.mkdir(parents=True) |         output_file.parent.mkdir(parents=True) | ||||||
|     if output_type == "json": |     if output_type == "json": | ||||||
|  | @ -157,7 +155,7 @@ def _write_docs_to_file(data, output_file, output_type): | ||||||
|             file_.write(data) |             file_.write(data) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def autodetect_ner_format(input_data: str) -> str: | def autodetect_ner_format(input_data: str) -> Optional[str]: | ||||||
|     # guess format from the first 20 lines |     # guess format from the first 20 lines | ||||||
|     lines = input_data.split("\n")[:20] |     lines = input_data.split("\n")[:20] | ||||||
|     format_guesses = {"ner": 0, "iob": 0} |     format_guesses = {"ner": 0, "iob": 0} | ||||||
|  | @ -176,7 +174,7 @@ def autodetect_ner_format(input_data: str) -> str: | ||||||
|     return None |     return None | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def walk_directory(path): | def walk_directory(path: Path) -> List[Path]: | ||||||
|     if not path.is_dir(): |     if not path.is_dir(): | ||||||
|         return [path] |         return [path] | ||||||
|     paths = [path] |     paths = [path] | ||||||
|  | @ -196,31 +194,25 @@ def walk_directory(path): | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def verify_cli_args( | def verify_cli_args( | ||||||
|     msg, |     msg: Printer, | ||||||
|     input_path, |     input_path: Union[str, Path], | ||||||
|     output_dir, |     output_dir: Union[str, Path], | ||||||
|     file_type, |     file_type: FileTypes, | ||||||
|     n_sents, |     converter: str, | ||||||
|     seg_sents, |     ner_map: Optional[Path], | ||||||
|     model, |  | ||||||
|     morphology, |  | ||||||
|     merge_subtokens, |  | ||||||
|     converter, |  | ||||||
|     ner_map, |  | ||||||
|     lang, |  | ||||||
| ): | ): | ||||||
|     input_path = Path(input_path) |     input_path = Path(input_path) | ||||||
|     if file_type not in FILE_TYPES_STDOUT and output_dir == "-": |     if file_type not in FILE_TYPES_STDOUT and output_dir == "-": | ||||||
|         # TODO: support msgpack via stdout in srsly? |  | ||||||
|         msg.fail( |         msg.fail( | ||||||
|             f"Can't write .{file_type} data to stdout", |             f"Can't write .{file_type} data to stdout. Please specify an output directory.", | ||||||
|             "Please specify an output directory.", |  | ||||||
|             exits=1, |             exits=1, | ||||||
|         ) |         ) | ||||||
|     if not input_path.exists(): |     if not input_path.exists(): | ||||||
|         msg.fail("Input file not found", input_path, exits=1) |         msg.fail("Input file not found", input_path, exits=1) | ||||||
|     if output_dir != "-" and not Path(output_dir).exists(): |     if output_dir != "-" and not Path(output_dir).exists(): | ||||||
|         msg.fail("Output directory not found", output_dir, exits=1) |         msg.fail("Output directory not found", output_dir, exits=1) | ||||||
|  |     if ner_map is not None and not Path(ner_map).exists(): | ||||||
|  |         msg.fail("NER map not found", ner_map, exits=1) | ||||||
|     if input_path.is_dir(): |     if input_path.is_dir(): | ||||||
|         input_locs = walk_directory(input_path) |         input_locs = walk_directory(input_path) | ||||||
|         if len(input_locs) == 0: |         if len(input_locs) == 0: | ||||||
|  | @ -229,10 +221,8 @@ def verify_cli_args( | ||||||
|         if len(file_types) >= 2: |         if len(file_types) >= 2: | ||||||
|             file_types = ",".join(file_types) |             file_types = ",".join(file_types) | ||||||
|             msg.fail("All input files must be same type", file_types, exits=1) |             msg.fail("All input files must be same type", file_types, exits=1) | ||||||
|     converter = _get_converter(msg, converter, input_path) |     if converter != "auto" and converter not in CONVERTERS: | ||||||
|     if converter not in CONVERTERS: |  | ||||||
|         msg.fail(f"Can't find converter for {converter}", exits=1) |         msg.fail(f"Can't find converter for {converter}", exits=1) | ||||||
|     return converter |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def _get_converter(msg, converter, input_path): | def _get_converter(msg, converter, input_path): | ||||||
|  |  | ||||||
|  | @ -82,11 +82,11 @@ def evaluate( | ||||||
|         "NER P": "ents_p", |         "NER P": "ents_p", | ||||||
|         "NER R": "ents_r", |         "NER R": "ents_r", | ||||||
|         "NER F": "ents_f", |         "NER F": "ents_f", | ||||||
|         "Textcat AUC": 'textcat_macro_auc', |         "Textcat AUC": "textcat_macro_auc", | ||||||
|         "Textcat F": 'textcat_macro_f', |         "Textcat F": "textcat_macro_f", | ||||||
|         "Sent P": 'sents_p', |         "Sent P": "sents_p", | ||||||
|         "Sent R": 'sents_r', |         "Sent R": "sents_r", | ||||||
|         "Sent F": 'sents_f', |         "Sent F": "sents_f", | ||||||
|     } |     } | ||||||
|     results = {} |     results = {} | ||||||
|     for metric, key in metrics.items(): |     for metric, key in metrics.items(): | ||||||
|  |  | ||||||
|  | @ -14,7 +14,6 @@ import typer | ||||||
| from ._util import app, Arg, Opt, parse_config_overrides, show_validation_error | from ._util import app, Arg, Opt, parse_config_overrides, show_validation_error | ||||||
| from ._util import import_code | from ._util import import_code | ||||||
| from ..gold import Corpus, Example | from ..gold import Corpus, Example | ||||||
| from ..lookups import Lookups |  | ||||||
| from ..language import Language | from ..language import Language | ||||||
| from .. import util | from .. import util | ||||||
| from ..errors import Errors | from ..errors import Errors | ||||||
|  |  | ||||||
|  | @ -1,12 +1,5 @@ | ||||||
| """ | """Helpers for Python and platform compatibility.""" | ||||||
| Helpers for Python and platform compatibility. To distinguish them from |  | ||||||
| the builtin functions, replacement functions are suffixed with an underscore, |  | ||||||
| e.g. `unicode_`. |  | ||||||
| 
 |  | ||||||
| DOCS: https://spacy.io/api/top-level#compat |  | ||||||
| """ |  | ||||||
| import sys | import sys | ||||||
| 
 |  | ||||||
| from thinc.util import copy_array | from thinc.util import copy_array | ||||||
| 
 | 
 | ||||||
| try: | try: | ||||||
|  | @ -40,21 +33,3 @@ copy_array = copy_array | ||||||
| is_windows = sys.platform.startswith("win") | is_windows = sys.platform.startswith("win") | ||||||
| is_linux = sys.platform.startswith("linux") | is_linux = sys.platform.startswith("linux") | ||||||
| is_osx = sys.platform == "darwin" | is_osx = sys.platform == "darwin" | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| def is_config(windows=None, linux=None, osx=None, **kwargs): |  | ||||||
|     """Check if a specific configuration of Python version and operating system |  | ||||||
|     matches the user's setup. Mostly used to display targeted error messages. |  | ||||||
| 
 |  | ||||||
|     windows (bool): spaCy is executed on Windows. |  | ||||||
|     linux (bool): spaCy is executed on Linux. |  | ||||||
|     osx (bool): spaCy is executed on OS X or macOS. |  | ||||||
|     RETURNS (bool): Whether the configuration matches the user's platform. |  | ||||||
| 
 |  | ||||||
|     DOCS: https://spacy.io/api/top-level#compat.is_config |  | ||||||
|     """ |  | ||||||
|     return ( |  | ||||||
|         windows in (None, is_windows) |  | ||||||
|         and linux in (None, is_linux) |  | ||||||
|         and osx in (None, is_osx) |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|  | @ -4,6 +4,7 @@ spaCy's built in visualization suite for dependencies and named entities. | ||||||
| DOCS: https://spacy.io/api/top-level#displacy | DOCS: https://spacy.io/api/top-level#displacy | ||||||
| USAGE: https://spacy.io/usage/visualizers | USAGE: https://spacy.io/usage/visualizers | ||||||
| """ | """ | ||||||
|  | from typing import Union, Iterable, Optional, Dict, Any, Callable | ||||||
| import warnings | import warnings | ||||||
| 
 | 
 | ||||||
| from .render import DependencyRenderer, EntityRenderer | from .render import DependencyRenderer, EntityRenderer | ||||||
|  | @ -17,11 +18,17 @@ RENDER_WRAPPER = None | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def render( | def render( | ||||||
|     docs, style="dep", page=False, minify=False, jupyter=None, options={}, manual=False |     docs: Union[Iterable[Doc], Doc], | ||||||
| ): |     style: str = "dep", | ||||||
|  |     page: bool = False, | ||||||
|  |     minify: bool = False, | ||||||
|  |     jupyter: Optional[bool] = None, | ||||||
|  |     options: Dict[str, Any] = {}, | ||||||
|  |     manual: bool = False, | ||||||
|  | ) -> str: | ||||||
|     """Render displaCy visualisation. |     """Render displaCy visualisation. | ||||||
| 
 | 
 | ||||||
|     docs (list or Doc): Document(s) to visualise. |     docs (Union[Iterable[Doc], Doc]): Document(s) to visualise. | ||||||
|     style (str): Visualisation style, 'dep' or 'ent'. |     style (str): Visualisation style, 'dep' or 'ent'. | ||||||
|     page (bool): Render markup as full HTML page. |     page (bool): Render markup as full HTML page. | ||||||
|     minify (bool): Minify HTML markup. |     minify (bool): Minify HTML markup. | ||||||
|  | @ -44,8 +51,8 @@ def render( | ||||||
|     docs = [obj if not isinstance(obj, Span) else obj.as_doc() for obj in docs] |     docs = [obj if not isinstance(obj, Span) else obj.as_doc() for obj in docs] | ||||||
|     if not all(isinstance(obj, (Doc, Span, dict)) for obj in docs): |     if not all(isinstance(obj, (Doc, Span, dict)) for obj in docs): | ||||||
|         raise ValueError(Errors.E096) |         raise ValueError(Errors.E096) | ||||||
|     renderer, converter = factories[style] |     renderer_func, converter = factories[style] | ||||||
|     renderer = renderer(options=options) |     renderer = renderer_func(options=options) | ||||||
|     parsed = [converter(doc, options) for doc in docs] if not manual else docs |     parsed = [converter(doc, options) for doc in docs] if not manual else docs | ||||||
|     _html["parsed"] = renderer.render(parsed, page=page, minify=minify).strip() |     _html["parsed"] = renderer.render(parsed, page=page, minify=minify).strip() | ||||||
|     html = _html["parsed"] |     html = _html["parsed"] | ||||||
|  | @ -61,15 +68,15 @@ def render( | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def serve( | def serve( | ||||||
|     docs, |     docs: Union[Iterable[Doc], Doc], | ||||||
|     style="dep", |     style: str = "dep", | ||||||
|     page=True, |     page: bool = True, | ||||||
|     minify=False, |     minify: bool = False, | ||||||
|     options={}, |     options: Dict[str, Any] = {}, | ||||||
|     manual=False, |     manual: bool = False, | ||||||
|     port=5000, |     port: int = 5000, | ||||||
|     host="0.0.0.0", |     host: str = "0.0.0.0", | ||||||
| ): | ) -> None: | ||||||
|     """Serve displaCy visualisation. |     """Serve displaCy visualisation. | ||||||
| 
 | 
 | ||||||
|     docs (list or Doc): Document(s) to visualise. |     docs (list or Doc): Document(s) to visualise. | ||||||
|  | @ -88,7 +95,6 @@ def serve( | ||||||
| 
 | 
 | ||||||
|     if is_in_jupyter(): |     if is_in_jupyter(): | ||||||
|         warnings.warn(Warnings.W011) |         warnings.warn(Warnings.W011) | ||||||
| 
 |  | ||||||
|     render(docs, style=style, page=page, minify=minify, options=options, manual=manual) |     render(docs, style=style, page=page, minify=minify, options=options, manual=manual) | ||||||
|     httpd = simple_server.make_server(host, port, app) |     httpd = simple_server.make_server(host, port, app) | ||||||
|     print(f"\nUsing the '{style}' visualizer") |     print(f"\nUsing the '{style}' visualizer") | ||||||
|  | @ -102,14 +108,13 @@ def serve( | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def app(environ, start_response): | def app(environ, start_response): | ||||||
|     # Headers and status need to be bytes in Python 2, see #1227 |  | ||||||
|     headers = [("Content-type", "text/html; charset=utf-8")] |     headers = [("Content-type", "text/html; charset=utf-8")] | ||||||
|     start_response("200 OK", headers) |     start_response("200 OK", headers) | ||||||
|     res = _html["parsed"].encode(encoding="utf-8") |     res = _html["parsed"].encode(encoding="utf-8") | ||||||
|     return [res] |     return [res] | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def parse_deps(orig_doc, options={}): | def parse_deps(orig_doc: Doc, options: Dict[str, Any] = {}) -> Dict[str, Any]: | ||||||
|     """Generate dependency parse in {'words': [], 'arcs': []} format. |     """Generate dependency parse in {'words': [], 'arcs': []} format. | ||||||
| 
 | 
 | ||||||
|     doc (Doc): Document do parse. |     doc (Doc): Document do parse. | ||||||
|  | @ -152,7 +157,6 @@ def parse_deps(orig_doc, options={}): | ||||||
|         } |         } | ||||||
|         for w in doc |         for w in doc | ||||||
|     ] |     ] | ||||||
| 
 |  | ||||||
|     arcs = [] |     arcs = [] | ||||||
|     for word in doc: |     for word in doc: | ||||||
|         if word.i < word.head.i: |         if word.i < word.head.i: | ||||||
|  | @ -171,7 +175,7 @@ def parse_deps(orig_doc, options={}): | ||||||
|     return {"words": words, "arcs": arcs, "settings": get_doc_settings(orig_doc)} |     return {"words": words, "arcs": arcs, "settings": get_doc_settings(orig_doc)} | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def parse_ents(doc, options={}): | def parse_ents(doc: Doc, options: Dict[str, Any] = {}) -> Dict[str, Any]: | ||||||
|     """Generate named entities in [{start: i, end: i, label: 'label'}] format. |     """Generate named entities in [{start: i, end: i, label: 'label'}] format. | ||||||
| 
 | 
 | ||||||
|     doc (Doc): Document do parse. |     doc (Doc): Document do parse. | ||||||
|  | @ -188,7 +192,7 @@ def parse_ents(doc, options={}): | ||||||
|     return {"text": doc.text, "ents": ents, "title": title, "settings": settings} |     return {"text": doc.text, "ents": ents, "title": title, "settings": settings} | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def set_render_wrapper(func): | def set_render_wrapper(func: Callable[[str], str]) -> None: | ||||||
|     """Set an optional wrapper function that is called around the generated |     """Set an optional wrapper function that is called around the generated | ||||||
|     HTML markup on displacy.render. This can be used to allow integration into |     HTML markup on displacy.render. This can be used to allow integration into | ||||||
|     other platforms, similar to Jupyter Notebooks that require functions to be |     other platforms, similar to Jupyter Notebooks that require functions to be | ||||||
|  | @ -205,7 +209,7 @@ def set_render_wrapper(func): | ||||||
|     RENDER_WRAPPER = func |     RENDER_WRAPPER = func | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def get_doc_settings(doc): | def get_doc_settings(doc: Doc) -> Dict[str, Any]: | ||||||
|     return { |     return { | ||||||
|         "lang": doc.lang_, |         "lang": doc.lang_, | ||||||
|         "direction": doc.vocab.writing_system.get("direction", "ltr"), |         "direction": doc.vocab.writing_system.get("direction", "ltr"), | ||||||
|  |  | ||||||
|  | @ -1,19 +1,36 @@ | ||||||
|  | from typing import Dict, Any, List, Optional, Union | ||||||
| import uuid | import uuid | ||||||
| 
 | 
 | ||||||
| from .templates import ( | from .templates import TPL_DEP_SVG, TPL_DEP_WORDS, TPL_DEP_WORDS_LEMMA, TPL_DEP_ARCS | ||||||
|     TPL_DEP_SVG, |  | ||||||
|     TPL_DEP_WORDS, |  | ||||||
|     TPL_DEP_WORDS_LEMMA, |  | ||||||
|     TPL_DEP_ARCS, |  | ||||||
|     TPL_ENTS, |  | ||||||
| ) |  | ||||||
| from .templates import TPL_ENT, TPL_ENT_RTL, TPL_FIGURE, TPL_TITLE, TPL_PAGE | from .templates import TPL_ENT, TPL_ENT_RTL, TPL_FIGURE, TPL_TITLE, TPL_PAGE | ||||||
|  | from .templates import TPL_ENTS | ||||||
| from ..util import minify_html, escape_html, registry | from ..util import minify_html, escape_html, registry | ||||||
| from ..errors import Errors | from ..errors import Errors | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| DEFAULT_LANG = "en" | DEFAULT_LANG = "en" | ||||||
| DEFAULT_DIR = "ltr" | DEFAULT_DIR = "ltr" | ||||||
|  | DEFAULT_ENTITY_COLOR = "#ddd" | ||||||
|  | DEFAULT_LABEL_COLORS = { | ||||||
|  |     "ORG": "#7aecec", | ||||||
|  |     "PRODUCT": "#bfeeb7", | ||||||
|  |     "GPE": "#feca74", | ||||||
|  |     "LOC": "#ff9561", | ||||||
|  |     "PERSON": "#aa9cfc", | ||||||
|  |     "NORP": "#c887fb", | ||||||
|  |     "FACILITY": "#9cc9cc", | ||||||
|  |     "EVENT": "#ffeb80", | ||||||
|  |     "LAW": "#ff8197", | ||||||
|  |     "LANGUAGE": "#ff8197", | ||||||
|  |     "WORK_OF_ART": "#f0d0ff", | ||||||
|  |     "DATE": "#bfe1d9", | ||||||
|  |     "TIME": "#bfe1d9", | ||||||
|  |     "MONEY": "#e4e7d2", | ||||||
|  |     "QUANTITY": "#e4e7d2", | ||||||
|  |     "ORDINAL": "#e4e7d2", | ||||||
|  |     "CARDINAL": "#e4e7d2", | ||||||
|  |     "PERCENT": "#e4e7d2", | ||||||
|  | } | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class DependencyRenderer: | class DependencyRenderer: | ||||||
|  | @ -21,7 +38,7 @@ class DependencyRenderer: | ||||||
| 
 | 
 | ||||||
|     style = "dep" |     style = "dep" | ||||||
| 
 | 
 | ||||||
|     def __init__(self, options={}): |     def __init__(self, options: Dict[str, Any] = {}) -> None: | ||||||
|         """Initialise dependency renderer. |         """Initialise dependency renderer. | ||||||
| 
 | 
 | ||||||
|         options (dict): Visualiser-specific options (compact, word_spacing, |         options (dict): Visualiser-specific options (compact, word_spacing, | ||||||
|  | @ -41,7 +58,9 @@ class DependencyRenderer: | ||||||
|         self.direction = DEFAULT_DIR |         self.direction = DEFAULT_DIR | ||||||
|         self.lang = DEFAULT_LANG |         self.lang = DEFAULT_LANG | ||||||
| 
 | 
 | ||||||
|     def render(self, parsed, page=False, minify=False): |     def render( | ||||||
|  |         self, parsed: List[Dict[str, Any]], page: bool = False, minify: bool = False | ||||||
|  |     ) -> str: | ||||||
|         """Render complete markup. |         """Render complete markup. | ||||||
| 
 | 
 | ||||||
|         parsed (list): Dependency parses to render. |         parsed (list): Dependency parses to render. | ||||||
|  | @ -72,10 +91,15 @@ class DependencyRenderer: | ||||||
|             return minify_html(markup) |             return minify_html(markup) | ||||||
|         return markup |         return markup | ||||||
| 
 | 
 | ||||||
|     def render_svg(self, render_id, words, arcs): |     def render_svg( | ||||||
|  |         self, | ||||||
|  |         render_id: Union[int, str], | ||||||
|  |         words: List[Dict[str, Any]], | ||||||
|  |         arcs: List[Dict[str, Any]], | ||||||
|  |     ) -> str: | ||||||
|         """Render SVG. |         """Render SVG. | ||||||
| 
 | 
 | ||||||
|         render_id (int): Unique ID, typically index of document. |         render_id (Union[int, str]): Unique ID, typically index of document. | ||||||
|         words (list): Individual words and their tags. |         words (list): Individual words and their tags. | ||||||
|         arcs (list): Individual arcs and their start, end, direction and label. |         arcs (list): Individual arcs and their start, end, direction and label. | ||||||
|         RETURNS (str): Rendered SVG markup. |         RETURNS (str): Rendered SVG markup. | ||||||
|  | @ -86,15 +110,15 @@ class DependencyRenderer: | ||||||
|         self.width = self.offset_x + len(words) * self.distance |         self.width = self.offset_x + len(words) * self.distance | ||||||
|         self.height = self.offset_y + 3 * self.word_spacing |         self.height = self.offset_y + 3 * self.word_spacing | ||||||
|         self.id = render_id |         self.id = render_id | ||||||
|         words = [ |         words_svg = [ | ||||||
|             self.render_word(w["text"], w["tag"], w.get("lemma", None), i) |             self.render_word(w["text"], w["tag"], w.get("lemma", None), i) | ||||||
|             for i, w in enumerate(words) |             for i, w in enumerate(words) | ||||||
|         ] |         ] | ||||||
|         arcs = [ |         arcs_svg = [ | ||||||
|             self.render_arrow(a["label"], a["start"], a["end"], a["dir"], i) |             self.render_arrow(a["label"], a["start"], a["end"], a["dir"], i) | ||||||
|             for i, a in enumerate(arcs) |             for i, a in enumerate(arcs) | ||||||
|         ] |         ] | ||||||
|         content = "".join(words) + "".join(arcs) |         content = "".join(words_svg) + "".join(arcs_svg) | ||||||
|         return TPL_DEP_SVG.format( |         return TPL_DEP_SVG.format( | ||||||
|             id=self.id, |             id=self.id, | ||||||
|             width=self.width, |             width=self.width, | ||||||
|  | @ -107,9 +131,7 @@ class DependencyRenderer: | ||||||
|             lang=self.lang, |             lang=self.lang, | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|     def render_word( |     def render_word(self, text: str, tag: str, lemma: str, i: int) -> str: | ||||||
|         self, text, tag, lemma, i, |  | ||||||
|     ): |  | ||||||
|         """Render individual word. |         """Render individual word. | ||||||
| 
 | 
 | ||||||
|         text (str): Word text. |         text (str): Word text. | ||||||
|  | @ -128,7 +150,9 @@ class DependencyRenderer: | ||||||
|             ) |             ) | ||||||
|         return TPL_DEP_WORDS.format(text=html_text, tag=tag, x=x, y=y) |         return TPL_DEP_WORDS.format(text=html_text, tag=tag, x=x, y=y) | ||||||
| 
 | 
 | ||||||
|     def render_arrow(self, label, start, end, direction, i): |     def render_arrow( | ||||||
|  |         self, label: str, start: int, end: int, direction: str, i: int | ||||||
|  |     ) -> str: | ||||||
|         """Render individual arrow. |         """Render individual arrow. | ||||||
| 
 | 
 | ||||||
|         label (str): Dependency label. |         label (str): Dependency label. | ||||||
|  | @ -172,7 +196,7 @@ class DependencyRenderer: | ||||||
|             arc=arc, |             arc=arc, | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|     def get_arc(self, x_start, y, y_curve, x_end): |     def get_arc(self, x_start: int, y: int, y_curve: int, x_end: int) -> str: | ||||||
|         """Render individual arc. |         """Render individual arc. | ||||||
| 
 | 
 | ||||||
|         x_start (int): X-coordinate of arrow start point. |         x_start (int): X-coordinate of arrow start point. | ||||||
|  | @ -186,7 +210,7 @@ class DependencyRenderer: | ||||||
|             template = "M{x},{y} {x},{c} {e},{c} {e},{y}" |             template = "M{x},{y} {x},{c} {e},{c} {e},{y}" | ||||||
|         return template.format(x=x_start, y=y, c=y_curve, e=x_end) |         return template.format(x=x_start, y=y, c=y_curve, e=x_end) | ||||||
| 
 | 
 | ||||||
|     def get_arrowhead(self, direction, x, y, end): |     def get_arrowhead(self, direction: str, x: int, y: int, end: int) -> str: | ||||||
|         """Render individual arrow head. |         """Render individual arrow head. | ||||||
| 
 | 
 | ||||||
|         direction (str): Arrow direction, 'left' or 'right'. |         direction (str): Arrow direction, 'left' or 'right'. | ||||||
|  | @ -196,24 +220,12 @@ class DependencyRenderer: | ||||||
|         RETURNS (str): Definition of the arrow head path ('d' attribute). |         RETURNS (str): Definition of the arrow head path ('d' attribute). | ||||||
|         """ |         """ | ||||||
|         if direction == "left": |         if direction == "left": | ||||||
|             pos1, pos2, pos3 = (x, x - self.arrow_width + 2, x + self.arrow_width - 2) |             p1, p2, p3 = (x, x - self.arrow_width + 2, x + self.arrow_width - 2) | ||||||
|         else: |         else: | ||||||
|             pos1, pos2, pos3 = ( |             p1, p2, p3 = (end, end + self.arrow_width - 2, end - self.arrow_width + 2) | ||||||
|                 end, |         return f"M{p1},{y + 2} L{p2},{y - self.arrow_width} {p3},{y - self.arrow_width}" | ||||||
|                 end + self.arrow_width - 2, |  | ||||||
|                 end - self.arrow_width + 2, |  | ||||||
|             ) |  | ||||||
|         arrowhead = ( |  | ||||||
|             pos1, |  | ||||||
|             y + 2, |  | ||||||
|             pos2, |  | ||||||
|             y - self.arrow_width, |  | ||||||
|             pos3, |  | ||||||
|             y - self.arrow_width, |  | ||||||
|         ) |  | ||||||
|         return "M{},{} L{},{} {},{}".format(*arrowhead) |  | ||||||
| 
 | 
 | ||||||
|     def get_levels(self, arcs): |     def get_levels(self, arcs: List[Dict[str, Any]]) -> List[int]: | ||||||
|         """Calculate available arc height "levels". |         """Calculate available arc height "levels". | ||||||
|         Used to calculate arrow heights dynamically and without wasting space. |         Used to calculate arrow heights dynamically and without wasting space. | ||||||
| 
 | 
 | ||||||
|  | @ -229,41 +241,21 @@ class EntityRenderer: | ||||||
| 
 | 
 | ||||||
|     style = "ent" |     style = "ent" | ||||||
| 
 | 
 | ||||||
|     def __init__(self, options={}): |     def __init__(self, options: Dict[str, Any] = {}) -> None: | ||||||
|         """Initialise dependency renderer. |         """Initialise dependency renderer. | ||||||
| 
 | 
 | ||||||
|         options (dict): Visualiser-specific options (colors, ents) |         options (dict): Visualiser-specific options (colors, ents) | ||||||
|         """ |         """ | ||||||
|         colors = { |         colors = dict(DEFAULT_LABEL_COLORS) | ||||||
|             "ORG": "#7aecec", |  | ||||||
|             "PRODUCT": "#bfeeb7", |  | ||||||
|             "GPE": "#feca74", |  | ||||||
|             "LOC": "#ff9561", |  | ||||||
|             "PERSON": "#aa9cfc", |  | ||||||
|             "NORP": "#c887fb", |  | ||||||
|             "FACILITY": "#9cc9cc", |  | ||||||
|             "EVENT": "#ffeb80", |  | ||||||
|             "LAW": "#ff8197", |  | ||||||
|             "LANGUAGE": "#ff8197", |  | ||||||
|             "WORK_OF_ART": "#f0d0ff", |  | ||||||
|             "DATE": "#bfe1d9", |  | ||||||
|             "TIME": "#bfe1d9", |  | ||||||
|             "MONEY": "#e4e7d2", |  | ||||||
|             "QUANTITY": "#e4e7d2", |  | ||||||
|             "ORDINAL": "#e4e7d2", |  | ||||||
|             "CARDINAL": "#e4e7d2", |  | ||||||
|             "PERCENT": "#e4e7d2", |  | ||||||
|         } |  | ||||||
|         user_colors = registry.displacy_colors.get_all() |         user_colors = registry.displacy_colors.get_all() | ||||||
|         for user_color in user_colors.values(): |         for user_color in user_colors.values(): | ||||||
|             colors.update(user_color) |             colors.update(user_color) | ||||||
|         colors.update(options.get("colors", {})) |         colors.update(options.get("colors", {})) | ||||||
|         self.default_color = "#ddd" |         self.default_color = DEFAULT_ENTITY_COLOR | ||||||
|         self.colors = colors |         self.colors = colors | ||||||
|         self.ents = options.get("ents", None) |         self.ents = options.get("ents", None) | ||||||
|         self.direction = DEFAULT_DIR |         self.direction = DEFAULT_DIR | ||||||
|         self.lang = DEFAULT_LANG |         self.lang = DEFAULT_LANG | ||||||
| 
 |  | ||||||
|         template = options.get("template") |         template = options.get("template") | ||||||
|         if template: |         if template: | ||||||
|             self.ent_template = template |             self.ent_template = template | ||||||
|  | @ -273,7 +265,9 @@ class EntityRenderer: | ||||||
|             else: |             else: | ||||||
|                 self.ent_template = TPL_ENT |                 self.ent_template = TPL_ENT | ||||||
| 
 | 
 | ||||||
|     def render(self, parsed, page=False, minify=False): |     def render( | ||||||
|  |         self, parsed: List[Dict[str, Any]], page: bool = False, minify: bool = False | ||||||
|  |     ) -> str: | ||||||
|         """Render complete markup. |         """Render complete markup. | ||||||
| 
 | 
 | ||||||
|         parsed (list): Dependency parses to render. |         parsed (list): Dependency parses to render. | ||||||
|  | @ -297,7 +291,9 @@ class EntityRenderer: | ||||||
|             return minify_html(markup) |             return minify_html(markup) | ||||||
|         return markup |         return markup | ||||||
| 
 | 
 | ||||||
|     def render_ents(self, text, spans, title): |     def render_ents( | ||||||
|  |         self, text: str, spans: List[Dict[str, Any]], title: Optional[str] | ||||||
|  |     ) -> str: | ||||||
|         """Render entities in text. |         """Render entities in text. | ||||||
| 
 | 
 | ||||||
|         text (str): Original text. |         text (str): Original text. | ||||||
|  |  | ||||||
|  | @ -483,6 +483,8 @@ class Errors: | ||||||
|     E199 = ("Unable to merge 0-length span at doc[{start}:{end}].") |     E199 = ("Unable to merge 0-length span at doc[{start}:{end}].") | ||||||
| 
 | 
 | ||||||
|     # TODO: fix numbering after merging develop into master |     # TODO: fix numbering after merging develop into master | ||||||
|  |     E953 = ("Mismatched IDs received by the Tok2Vec listener: {id1} vs. {id2}") | ||||||
|  |     E954 = ("The Tok2Vec listener did not receive a valid input.") | ||||||
|     E955 = ("Can't find table '{table}' for language '{lang}' in spacy-lookups-data.") |     E955 = ("Can't find table '{table}' for language '{lang}' in spacy-lookups-data.") | ||||||
|     E956 = ("Can't find component '{name}' in [components] block in the config. " |     E956 = ("Can't find component '{name}' in [components] block in the config. " | ||||||
|             "Available components: {opts}") |             "Available components: {opts}") | ||||||
|  |  | ||||||
|  | @ -15,7 +15,7 @@ class Alignment: | ||||||
|         x2y = _make_ragged(x2y) |         x2y = _make_ragged(x2y) | ||||||
|         y2x = _make_ragged(y2x) |         y2x = _make_ragged(y2x) | ||||||
|         return Alignment(x2y=x2y, y2x=y2x) |         return Alignment(x2y=x2y, y2x=y2x) | ||||||
|      | 
 | ||||||
|     @classmethod |     @classmethod | ||||||
|     def from_strings(cls, A: List[str], B: List[str]) -> "Alignment": |     def from_strings(cls, A: List[str], B: List[str]) -> "Alignment": | ||||||
|         x2y, y2x = tokenizations.get_alignments(A, B) |         x2y, y2x = tokenizations.get_alignments(A, B) | ||||||
|  |  | ||||||
|  | @ -163,7 +163,7 @@ class Language: | ||||||
|         else: |         else: | ||||||
|             if (self.lang and vocab.lang) and (self.lang != vocab.lang): |             if (self.lang and vocab.lang) and (self.lang != vocab.lang): | ||||||
|                 raise ValueError(Errors.E150.format(nlp=self.lang, vocab=vocab.lang)) |                 raise ValueError(Errors.E150.format(nlp=self.lang, vocab=vocab.lang)) | ||||||
|         self.vocab = vocab |         self.vocab: Vocab = vocab | ||||||
|         if self.lang is None: |         if self.lang is None: | ||||||
|             self.lang = self.vocab.lang |             self.lang = self.vocab.lang | ||||||
|         self.pipeline = [] |         self.pipeline = [] | ||||||
|  |  | ||||||
|  | @ -8,7 +8,7 @@ def PrecomputableAffine(nO, nI, nF, nP, dropout=0.1): | ||||||
|         init=init, |         init=init, | ||||||
|         dims={"nO": nO, "nI": nI, "nF": nF, "nP": nP}, |         dims={"nO": nO, "nI": nI, "nF": nF, "nP": nP}, | ||||||
|         params={"W": None, "b": None, "pad": None}, |         params={"W": None, "b": None, "pad": None}, | ||||||
|         attrs={"dropout_rate": dropout} |         attrs={"dropout_rate": dropout}, | ||||||
|     ) |     ) | ||||||
|     return model |     return model | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -154,16 +154,30 @@ def LayerNormalizedMaxout(width, maxout_pieces): | ||||||
| def MultiHashEmbed( | def MultiHashEmbed( | ||||||
|     columns, width, rows, use_subwords, pretrained_vectors, mix, dropout |     columns, width, rows, use_subwords, pretrained_vectors, mix, dropout | ||||||
| ): | ): | ||||||
|     norm = HashEmbed(nO=width, nV=rows, column=columns.index("NORM"), dropout=dropout, seed=6) |     norm = HashEmbed( | ||||||
|  |         nO=width, nV=rows, column=columns.index("NORM"), dropout=dropout, seed=6 | ||||||
|  |     ) | ||||||
|     if use_subwords: |     if use_subwords: | ||||||
|         prefix = HashEmbed( |         prefix = HashEmbed( | ||||||
|             nO=width, nV=rows // 2, column=columns.index("PREFIX"), dropout=dropout, seed=7 |             nO=width, | ||||||
|  |             nV=rows // 2, | ||||||
|  |             column=columns.index("PREFIX"), | ||||||
|  |             dropout=dropout, | ||||||
|  |             seed=7, | ||||||
|         ) |         ) | ||||||
|         suffix = HashEmbed( |         suffix = HashEmbed( | ||||||
|             nO=width, nV=rows // 2, column=columns.index("SUFFIX"), dropout=dropout, seed=8 |             nO=width, | ||||||
|  |             nV=rows // 2, | ||||||
|  |             column=columns.index("SUFFIX"), | ||||||
|  |             dropout=dropout, | ||||||
|  |             seed=8, | ||||||
|         ) |         ) | ||||||
|         shape = HashEmbed( |         shape = HashEmbed( | ||||||
|             nO=width, nV=rows // 2, column=columns.index("SHAPE"), dropout=dropout, seed=9 |             nO=width, | ||||||
|  |             nV=rows // 2, | ||||||
|  |             column=columns.index("SHAPE"), | ||||||
|  |             dropout=dropout, | ||||||
|  |             seed=9, | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|     if pretrained_vectors: |     if pretrained_vectors: | ||||||
|  | @ -192,7 +206,9 @@ def MultiHashEmbed( | ||||||
| 
 | 
 | ||||||
| @registry.architectures.register("spacy.CharacterEmbed.v1") | @registry.architectures.register("spacy.CharacterEmbed.v1") | ||||||
| def CharacterEmbed(columns, width, rows, nM, nC, features, dropout): | def CharacterEmbed(columns, width, rows, nM, nC, features, dropout): | ||||||
|     norm = HashEmbed(nO=width, nV=rows, column=columns.index("NORM"), dropout=dropout, seed=5) |     norm = HashEmbed( | ||||||
|  |         nO=width, nV=rows, column=columns.index("NORM"), dropout=dropout, seed=5 | ||||||
|  |     ) | ||||||
|     chr_embed = _character_embed.CharacterEmbed(nM=nM, nC=nC) |     chr_embed = _character_embed.CharacterEmbed(nM=nM, nC=nC) | ||||||
|     with Model.define_operators({">>": chain, "|": concatenate}): |     with Model.define_operators({">>": chain, "|": concatenate}): | ||||||
|         embed_layer = chr_embed | features >> with_array(norm) |         embed_layer = chr_embed | features >> with_array(norm) | ||||||
|  | @ -263,21 +279,29 @@ def build_Tok2Vec_model( | ||||||
|     cols = [ID, NORM, PREFIX, SUFFIX, SHAPE, ORTH] |     cols = [ID, NORM, PREFIX, SUFFIX, SHAPE, ORTH] | ||||||
|     with Model.define_operators({">>": chain, "|": concatenate, "**": clone}): |     with Model.define_operators({">>": chain, "|": concatenate, "**": clone}): | ||||||
|         norm = HashEmbed( |         norm = HashEmbed( | ||||||
|             nO=width, nV=embed_size, column=cols.index(NORM), dropout=None, |             nO=width, nV=embed_size, column=cols.index(NORM), dropout=None, seed=0 | ||||||
|             seed=0 |  | ||||||
|         ) |         ) | ||||||
|         if subword_features: |         if subword_features: | ||||||
|             prefix = HashEmbed( |             prefix = HashEmbed( | ||||||
|                 nO=width, nV=embed_size // 2, column=cols.index(PREFIX), dropout=None, |                 nO=width, | ||||||
|                 seed=1 |                 nV=embed_size // 2, | ||||||
|  |                 column=cols.index(PREFIX), | ||||||
|  |                 dropout=None, | ||||||
|  |                 seed=1, | ||||||
|             ) |             ) | ||||||
|             suffix = HashEmbed( |             suffix = HashEmbed( | ||||||
|                 nO=width, nV=embed_size // 2, column=cols.index(SUFFIX), dropout=None, |                 nO=width, | ||||||
|                 seed=2 |                 nV=embed_size // 2, | ||||||
|  |                 column=cols.index(SUFFIX), | ||||||
|  |                 dropout=None, | ||||||
|  |                 seed=2, | ||||||
|             ) |             ) | ||||||
|             shape = HashEmbed( |             shape = HashEmbed( | ||||||
|                 nO=width, nV=embed_size // 2, column=cols.index(SHAPE), dropout=None, |                 nO=width, | ||||||
|                 seed=3 |                 nV=embed_size // 2, | ||||||
|  |                 column=cols.index(SHAPE), | ||||||
|  |                 dropout=None, | ||||||
|  |                 seed=3, | ||||||
|             ) |             ) | ||||||
|         else: |         else: | ||||||
|             prefix, suffix, shape = (None, None, None) |             prefix, suffix, shape = (None, None, None) | ||||||
|  | @ -294,11 +318,7 @@ def build_Tok2Vec_model( | ||||||
|                 embed = uniqued( |                 embed = uniqued( | ||||||
|                     (glove | norm | prefix | suffix | shape) |                     (glove | norm | prefix | suffix | shape) | ||||||
|                     >> Maxout( |                     >> Maxout( | ||||||
|                         nO=width, |                         nO=width, nI=width * columns, nP=3, dropout=0.0, normalize=True, | ||||||
|                         nI=width * columns, |  | ||||||
|                         nP=3, |  | ||||||
|                         dropout=0.0, |  | ||||||
|                         normalize=True, |  | ||||||
|                     ), |                     ), | ||||||
|                     column=cols.index(ORTH), |                     column=cols.index(ORTH), | ||||||
|                 ) |                 ) | ||||||
|  | @ -307,11 +327,7 @@ def build_Tok2Vec_model( | ||||||
|                 embed = uniqued( |                 embed = uniqued( | ||||||
|                     (glove | norm) |                     (glove | norm) | ||||||
|                     >> Maxout( |                     >> Maxout( | ||||||
|                         nO=width, |                         nO=width, nI=width * columns, nP=3, dropout=0.0, normalize=True, | ||||||
|                         nI=width * columns, |  | ||||||
|                         nP=3, |  | ||||||
|                         dropout=0.0, |  | ||||||
|                         normalize=True, |  | ||||||
|                     ), |                     ), | ||||||
|                     column=cols.index(ORTH), |                     column=cols.index(ORTH), | ||||||
|                 ) |                 ) | ||||||
|  | @ -320,11 +336,7 @@ def build_Tok2Vec_model( | ||||||
|             embed = uniqued( |             embed = uniqued( | ||||||
|                 concatenate(norm, prefix, suffix, shape) |                 concatenate(norm, prefix, suffix, shape) | ||||||
|                 >> Maxout( |                 >> Maxout( | ||||||
|                     nO=width, |                     nO=width, nI=width * columns, nP=3, dropout=0.0, normalize=True, | ||||||
|                     nI=width * columns, |  | ||||||
|                     nP=3, |  | ||||||
|                     dropout=0.0, |  | ||||||
|                     normalize=True, |  | ||||||
|                 ), |                 ), | ||||||
|                 column=cols.index(ORTH), |                 column=cols.index(ORTH), | ||||||
|             ) |             ) | ||||||
|  | @ -333,11 +345,7 @@ def build_Tok2Vec_model( | ||||||
|                 cols |                 cols | ||||||
|             ) >> with_array(norm) |             ) >> with_array(norm) | ||||||
|             reduce_dimensions = Maxout( |             reduce_dimensions = Maxout( | ||||||
|                 nO=width, |                 nO=width, nI=nM * nC + width, nP=3, dropout=0.0, normalize=True, | ||||||
|                 nI=nM * nC + width, |  | ||||||
|                 nP=3, |  | ||||||
|                 dropout=0.0, |  | ||||||
|                 normalize=True, |  | ||||||
|             ) |             ) | ||||||
|         else: |         else: | ||||||
|             embed = norm |             embed = norm | ||||||
|  |  | ||||||
|  | @ -10,7 +10,6 @@ from .simple_ner import SimpleNER | ||||||
| from .tagger import Tagger | from .tagger import Tagger | ||||||
| from .textcat import TextCategorizer | from .textcat import TextCategorizer | ||||||
| from .tok2vec import Tok2Vec | from .tok2vec import Tok2Vec | ||||||
| from .hooks import SentenceSegmenter, SimilarityHook |  | ||||||
| from .functions import merge_entities, merge_noun_chunks, merge_subtokens | from .functions import merge_entities, merge_noun_chunks, merge_subtokens | ||||||
| 
 | 
 | ||||||
| __all__ = [ | __all__ = [ | ||||||
|  | @ -21,9 +20,7 @@ __all__ = [ | ||||||
|     "Morphologizer", |     "Morphologizer", | ||||||
|     "Pipe", |     "Pipe", | ||||||
|     "SentenceRecognizer", |     "SentenceRecognizer", | ||||||
|     "SentenceSegmenter", |  | ||||||
|     "Sentencizer", |     "Sentencizer", | ||||||
|     "SimilarityHook", |  | ||||||
|     "SimpleNER", |     "SimpleNER", | ||||||
|     "Tagger", |     "Tagger", | ||||||
|     "TextCategorizer", |     "TextCategorizer", | ||||||
|  |  | ||||||
|  | @ -63,7 +63,7 @@ class EntityRuler: | ||||||
|         overwrite_ents: bool = False, |         overwrite_ents: bool = False, | ||||||
|         ent_id_sep: str = DEFAULT_ENT_ID_SEP, |         ent_id_sep: str = DEFAULT_ENT_ID_SEP, | ||||||
|         patterns: Optional[List[PatternType]] = None, |         patterns: Optional[List[PatternType]] = None, | ||||||
|     ): |     ) -> None: | ||||||
|         """Initialize the entity ruler. If patterns are supplied here, they |         """Initialize the entity ruler. If patterns are supplied here, they | ||||||
|         need to be a list of dictionaries with a `"label"` and `"pattern"` |         need to be a list of dictionaries with a `"label"` and `"pattern"` | ||||||
|         key. A pattern can either be a token pattern (list) or a phrase pattern |         key. A pattern can either be a token pattern (list) or a phrase pattern | ||||||
|  | @ -239,7 +239,6 @@ class EntityRuler: | ||||||
|             phrase_pattern_labels = [] |             phrase_pattern_labels = [] | ||||||
|             phrase_pattern_texts = [] |             phrase_pattern_texts = [] | ||||||
|             phrase_pattern_ids = [] |             phrase_pattern_ids = [] | ||||||
| 
 |  | ||||||
|             for entry in patterns: |             for entry in patterns: | ||||||
|                 if isinstance(entry["pattern"], str): |                 if isinstance(entry["pattern"], str): | ||||||
|                     phrase_pattern_labels.append(entry["label"]) |                     phrase_pattern_labels.append(entry["label"]) | ||||||
|  | @ -247,7 +246,6 @@ class EntityRuler: | ||||||
|                     phrase_pattern_ids.append(entry.get("id")) |                     phrase_pattern_ids.append(entry.get("id")) | ||||||
|                 elif isinstance(entry["pattern"], list): |                 elif isinstance(entry["pattern"], list): | ||||||
|                     token_patterns.append(entry) |                     token_patterns.append(entry) | ||||||
| 
 |  | ||||||
|             phrase_patterns = [] |             phrase_patterns = [] | ||||||
|             for label, pattern, ent_id in zip( |             for label, pattern, ent_id in zip( | ||||||
|                 phrase_pattern_labels, |                 phrase_pattern_labels, | ||||||
|  | @ -258,7 +256,6 @@ class EntityRuler: | ||||||
|                 if ent_id: |                 if ent_id: | ||||||
|                     phrase_pattern["id"] = ent_id |                     phrase_pattern["id"] = ent_id | ||||||
|                 phrase_patterns.append(phrase_pattern) |                 phrase_patterns.append(phrase_pattern) | ||||||
| 
 |  | ||||||
|             for entry in token_patterns + phrase_patterns: |             for entry in token_patterns + phrase_patterns: | ||||||
|                 label = entry["label"] |                 label = entry["label"] | ||||||
|                 if "id" in entry: |                 if "id" in entry: | ||||||
|  | @ -266,7 +263,6 @@ class EntityRuler: | ||||||
|                     label = self._create_label(label, entry["id"]) |                     label = self._create_label(label, entry["id"]) | ||||||
|                     key = self.matcher._normalize_key(label) |                     key = self.matcher._normalize_key(label) | ||||||
|                     self._ent_ids[key] = (ent_label, entry["id"]) |                     self._ent_ids[key] = (ent_label, entry["id"]) | ||||||
| 
 |  | ||||||
|                 pattern = entry["pattern"] |                 pattern = entry["pattern"] | ||||||
|                 if isinstance(pattern, Doc): |                 if isinstance(pattern, Doc): | ||||||
|                     self.phrase_patterns[label].append(pattern) |                     self.phrase_patterns[label].append(pattern) | ||||||
|  | @ -323,13 +319,13 @@ class EntityRuler: | ||||||
|         self.clear() |         self.clear() | ||||||
|         if isinstance(cfg, dict): |         if isinstance(cfg, dict): | ||||||
|             self.add_patterns(cfg.get("patterns", cfg)) |             self.add_patterns(cfg.get("patterns", cfg)) | ||||||
|             self.overwrite = cfg.get("overwrite") |             self.overwrite = cfg.get("overwrite", False) | ||||||
|             self.phrase_matcher_attr = cfg.get("phrase_matcher_attr", None) |             self.phrase_matcher_attr = cfg.get("phrase_matcher_attr", None) | ||||||
|             if self.phrase_matcher_attr is not None: |             if self.phrase_matcher_attr is not None: | ||||||
|                 self.phrase_matcher = PhraseMatcher( |                 self.phrase_matcher = PhraseMatcher( | ||||||
|                     self.nlp.vocab, attr=self.phrase_matcher_attr |                     self.nlp.vocab, attr=self.phrase_matcher_attr | ||||||
|                 ) |                 ) | ||||||
|             self.ent_id_sep = cfg.get("ent_id_sep") |             self.ent_id_sep = cfg.get("ent_id_sep", DEFAULT_ENT_ID_SEP) | ||||||
|         else: |         else: | ||||||
|             self.add_patterns(cfg) |             self.add_patterns(cfg) | ||||||
|         return self |         return self | ||||||
|  | @ -375,9 +371,9 @@ class EntityRuler: | ||||||
|             } |             } | ||||||
|             deserializers_cfg = {"cfg": lambda p: cfg.update(srsly.read_json(p))} |             deserializers_cfg = {"cfg": lambda p: cfg.update(srsly.read_json(p))} | ||||||
|             from_disk(path, deserializers_cfg, {}) |             from_disk(path, deserializers_cfg, {}) | ||||||
|             self.overwrite = cfg.get("overwrite") |             self.overwrite = cfg.get("overwrite", False) | ||||||
|             self.phrase_matcher_attr = cfg.get("phrase_matcher_attr") |             self.phrase_matcher_attr = cfg.get("phrase_matcher_attr") | ||||||
|             self.ent_id_sep = cfg.get("ent_id_sep") |             self.ent_id_sep = cfg.get("ent_id_sep", DEFAULT_ENT_ID_SEP) | ||||||
| 
 | 
 | ||||||
|             if self.phrase_matcher_attr is not None: |             if self.phrase_matcher_attr is not None: | ||||||
|                 self.phrase_matcher = PhraseMatcher( |                 self.phrase_matcher = PhraseMatcher( | ||||||
|  |  | ||||||
|  | @ -1,95 +0,0 @@ | ||||||
| from thinc.api import concatenate, reduce_max, reduce_mean, siamese, CauchySimilarity |  | ||||||
| 
 |  | ||||||
| from .pipe import Pipe |  | ||||||
| from ..util import link_vectors_to_models |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| # TODO: do we want to keep these? |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| class SentenceSegmenter: |  | ||||||
|     """A simple spaCy hook, to allow custom sentence boundary detection logic |  | ||||||
|     (that doesn't require the dependency parse). To change the sentence |  | ||||||
|     boundary detection strategy, pass a generator function `strategy` on |  | ||||||
|     initialization, or assign a new strategy to the .strategy attribute. |  | ||||||
|     Sentence detection strategies should be generators that take `Doc` objects |  | ||||||
|     and yield `Span` objects for each sentence. |  | ||||||
|     """ |  | ||||||
| 
 |  | ||||||
|     def __init__(self, vocab, strategy=None): |  | ||||||
|         self.vocab = vocab |  | ||||||
|         if strategy is None or strategy == "on_punct": |  | ||||||
|             strategy = self.split_on_punct |  | ||||||
|         self.strategy = strategy |  | ||||||
| 
 |  | ||||||
|     def __call__(self, doc): |  | ||||||
|         doc.user_hooks["sents"] = self.strategy |  | ||||||
|         return doc |  | ||||||
| 
 |  | ||||||
|     @staticmethod |  | ||||||
|     def split_on_punct(doc): |  | ||||||
|         start = 0 |  | ||||||
|         seen_period = False |  | ||||||
|         for i, token in enumerate(doc): |  | ||||||
|             if seen_period and not token.is_punct: |  | ||||||
|                 yield doc[start : token.i] |  | ||||||
|                 start = token.i |  | ||||||
|                 seen_period = False |  | ||||||
|             elif token.text in [".", "!", "?"]: |  | ||||||
|                 seen_period = True |  | ||||||
|         if start < len(doc): |  | ||||||
|             yield doc[start : len(doc)] |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| class SimilarityHook(Pipe): |  | ||||||
|     """ |  | ||||||
|     Experimental: A pipeline component to install a hook for supervised |  | ||||||
|     similarity into `Doc` objects. |  | ||||||
|     The similarity model can be any object obeying the Thinc `Model` |  | ||||||
|     interface. By default, the model concatenates the elementwise mean and |  | ||||||
|     elementwise max of the two tensors, and compares them using the |  | ||||||
|     Cauchy-like similarity function from Chen (2013): |  | ||||||
| 
 |  | ||||||
|         >>> similarity = 1. / (1. + (W * (vec1-vec2)**2).sum()) |  | ||||||
| 
 |  | ||||||
|     Where W is a vector of dimension weights, initialized to 1. |  | ||||||
|     """ |  | ||||||
| 
 |  | ||||||
|     def __init__(self, vocab, model=True, **cfg): |  | ||||||
|         self.vocab = vocab |  | ||||||
|         self.model = model |  | ||||||
|         self.cfg = dict(cfg) |  | ||||||
| 
 |  | ||||||
|     @classmethod |  | ||||||
|     def Model(cls, length): |  | ||||||
|         return siamese( |  | ||||||
|             concatenate(reduce_max(), reduce_mean()), CauchySimilarity(length * 2) |  | ||||||
|         ) |  | ||||||
| 
 |  | ||||||
|     def __call__(self, doc): |  | ||||||
|         """Install similarity hook""" |  | ||||||
|         doc.user_hooks["similarity"] = self.predict |  | ||||||
|         return doc |  | ||||||
| 
 |  | ||||||
|     def pipe(self, docs, **kwargs): |  | ||||||
|         for doc in docs: |  | ||||||
|             yield self(doc) |  | ||||||
| 
 |  | ||||||
|     def predict(self, doc1, doc2): |  | ||||||
|         return self.model.predict([(doc1, doc2)]) |  | ||||||
| 
 |  | ||||||
|     def update(self, doc1_doc2, golds, sgd=None, drop=0.0): |  | ||||||
|         sims, bp_sims = self.model.begin_update(doc1_doc2) |  | ||||||
| 
 |  | ||||||
|     def begin_training(self, _=tuple(), pipeline=None, sgd=None, **kwargs): |  | ||||||
|         """Allocate model, using nO from the first model in the pipeline. |  | ||||||
| 
 |  | ||||||
|         gold_tuples (iterable): Gold-standard training data. |  | ||||||
|         pipeline (list): The pipeline the model is part of. |  | ||||||
|         """ |  | ||||||
|         if self.model is True: |  | ||||||
|             self.model = self.Model(pipeline[0].model.get_dim("nO")) |  | ||||||
|             link_vectors_to_models(self.vocab) |  | ||||||
|         if sgd is None: |  | ||||||
|             sgd = self.create_optimizer() |  | ||||||
|         return sgd |  | ||||||
|  | @ -1,4 +1,4 @@ | ||||||
| from typing import Iterable, Tuple, Optional, Dict, List, Callable | from typing import Iterable, Tuple, Optional, Dict, List, Callable, Iterator, Any | ||||||
| from thinc.api import get_array_module, Model, Optimizer, set_dropout_rate, Config | from thinc.api import get_array_module, Model, Optimizer, set_dropout_rate, Config | ||||||
| import numpy | import numpy | ||||||
| 
 | 
 | ||||||
|  | @ -97,7 +97,7 @@ class TextCategorizer(Pipe): | ||||||
|     def labels(self, value: Iterable[str]) -> None: |     def labels(self, value: Iterable[str]) -> None: | ||||||
|         self.cfg["labels"] = tuple(value) |         self.cfg["labels"] = tuple(value) | ||||||
| 
 | 
 | ||||||
|     def pipe(self, stream, batch_size=128): |     def pipe(self, stream: Iterator[str], batch_size: int = 128) -> Iterator[Doc]: | ||||||
|         for docs in util.minibatch(stream, size=batch_size): |         for docs in util.minibatch(stream, size=batch_size): | ||||||
|             scores = self.predict(docs) |             scores = self.predict(docs) | ||||||
|             self.set_annotations(docs, scores) |             self.set_annotations(docs, scores) | ||||||
|  | @ -252,8 +252,17 @@ class TextCategorizer(Pipe): | ||||||
|             sgd = self.create_optimizer() |             sgd = self.create_optimizer() | ||||||
|         return sgd |         return sgd | ||||||
| 
 | 
 | ||||||
|     def score(self, examples, positive_label=None, **kwargs): |     def score( | ||||||
|         return Scorer.score_cats(examples, "cats", labels=self.labels, |         self, | ||||||
|  |         examples: Iterable[Example], | ||||||
|  |         positive_label: Optional[str] = None, | ||||||
|  |         **kwargs, | ||||||
|  |     ) -> Dict[str, Any]: | ||||||
|  |         return Scorer.score_cats( | ||||||
|  |             examples, | ||||||
|  |             "cats", | ||||||
|  |             labels=self.labels, | ||||||
|             multi_label=self.model.attrs["multi_label"], |             multi_label=self.model.attrs["multi_label"], | ||||||
|             positive_label=positive_label, **kwargs |             positive_label=positive_label, | ||||||
|  |             **kwargs, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  | @ -6,6 +6,7 @@ from ..gold import Example | ||||||
| from ..tokens import Doc | from ..tokens import Doc | ||||||
| from ..vocab import Vocab | from ..vocab import Vocab | ||||||
| from ..language import Language | from ..language import Language | ||||||
|  | from ..errors import Errors | ||||||
| from ..util import link_vectors_to_models, minibatch | from ..util import link_vectors_to_models, minibatch | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -150,7 +151,7 @@ class Tok2Vec(Pipe): | ||||||
|             self.set_annotations(docs, tokvecs) |             self.set_annotations(docs, tokvecs) | ||||||
|         return losses |         return losses | ||||||
| 
 | 
 | ||||||
|     def get_loss(self, examples, scores): |     def get_loss(self, examples, scores) -> None: | ||||||
|         pass |         pass | ||||||
| 
 | 
 | ||||||
|     def begin_training( |     def begin_training( | ||||||
|  | @ -184,26 +185,26 @@ class Tok2VecListener(Model): | ||||||
|         self._backprop = None |         self._backprop = None | ||||||
| 
 | 
 | ||||||
|     @classmethod |     @classmethod | ||||||
|     def get_batch_id(cls, inputs): |     def get_batch_id(cls, inputs) -> int: | ||||||
|         return sum(sum(token.orth for token in doc) for doc in inputs) |         return sum(sum(token.orth for token in doc) for doc in inputs) | ||||||
| 
 | 
 | ||||||
|     def receive(self, batch_id, outputs, backprop): |     def receive(self, batch_id: int, outputs, backprop) -> None: | ||||||
|         self._batch_id = batch_id |         self._batch_id = batch_id | ||||||
|         self._outputs = outputs |         self._outputs = outputs | ||||||
|         self._backprop = backprop |         self._backprop = backprop | ||||||
| 
 | 
 | ||||||
|     def verify_inputs(self, inputs): |     def verify_inputs(self, inputs) -> bool: | ||||||
|         if self._batch_id is None and self._outputs is None: |         if self._batch_id is None and self._outputs is None: | ||||||
|             raise ValueError("The Tok2Vec listener did not receive valid input.") |             raise ValueError(Errors.E954) | ||||||
|         else: |         else: | ||||||
|             batch_id = self.get_batch_id(inputs) |             batch_id = self.get_batch_id(inputs) | ||||||
|             if batch_id != self._batch_id: |             if batch_id != self._batch_id: | ||||||
|                 raise ValueError(f"Mismatched IDs! {batch_id} vs {self._batch_id}") |                 raise ValueError(Errors.E953.format(id1=batch_id, id2=self._batch_id)) | ||||||
|             else: |             else: | ||||||
|                 return True |                 return True | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def forward(model: Tok2VecListener, inputs, is_train): | def forward(model: Tok2VecListener, inputs, is_train: bool): | ||||||
|     if is_train: |     if is_train: | ||||||
|         model.verify_inputs(inputs) |         model.verify_inputs(inputs) | ||||||
|         return model._outputs, model._backprop |         return model._outputs, model._backprop | ||||||
|  |  | ||||||
|  | @ -1,4 +1,4 @@ | ||||||
| from typing import Dict, List, Union, Optional, Sequence, Any, Callable | from typing import Dict, List, Union, Optional, Sequence, Any, Callable, Type | ||||||
| from enum import Enum | from enum import Enum | ||||||
| from pydantic import BaseModel, Field, ValidationError, validator | from pydantic import BaseModel, Field, ValidationError, validator | ||||||
| from pydantic import StrictStr, StrictInt, StrictFloat, StrictBool | from pydantic import StrictStr, StrictInt, StrictFloat, StrictBool | ||||||
|  | @ -9,12 +9,12 @@ from thinc.api import Optimizer | ||||||
| from .attrs import NAMES | from .attrs import NAMES | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def validate(schema, obj): | def validate(schema: Type[BaseModel], obj: Dict[str, Any]) -> List[str]: | ||||||
|     """Validate data against a given pydantic schema. |     """Validate data against a given pydantic schema. | ||||||
| 
 | 
 | ||||||
|     obj (dict): JSON-serializable data to validate. |     obj (Dict[str, Any]): JSON-serializable data to validate. | ||||||
|     schema (pydantic.BaseModel): The schema to validate against. |     schema (pydantic.BaseModel): The schema to validate against. | ||||||
|     RETURNS (list): A list of error messages, if available. |     RETURNS (List[str]): A list of error messages, if available. | ||||||
|     """ |     """ | ||||||
|     try: |     try: | ||||||
|         schema(**obj) |         schema(**obj) | ||||||
|  | @ -31,7 +31,7 @@ def validate(schema, obj): | ||||||
| # Matcher token patterns | # Matcher token patterns | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def validate_token_pattern(obj): | def validate_token_pattern(obj: list) -> List[str]: | ||||||
|     # Try to convert non-string keys (e.g. {ORTH: "foo"} -> {"ORTH": "foo"}) |     # Try to convert non-string keys (e.g. {ORTH: "foo"} -> {"ORTH": "foo"}) | ||||||
|     get_key = lambda k: NAMES[k] if isinstance(k, int) and k < len(NAMES) else k |     get_key = lambda k: NAMES[k] if isinstance(k, int) and k < len(NAMES) else k | ||||||
|     if isinstance(obj, list): |     if isinstance(obj, list): | ||||||
|  |  | ||||||
|  | @ -15,7 +15,8 @@ def test_doc_add_entities_set_ents_iob(en_vocab): | ||||||
|         "min_action_freq": 30, |         "min_action_freq": 30, | ||||||
|         "update_with_oracle_cut_size": 100, |         "update_with_oracle_cut_size": 100, | ||||||
|     } |     } | ||||||
|     model = registry.make_from_config({"model": DEFAULT_NER_MODEL}, validate=True)["model"] |     cfg = {"model": DEFAULT_NER_MODEL} | ||||||
|  |     model = registry.make_from_config(cfg, validate=True)["model"] | ||||||
|     ner = EntityRecognizer(en_vocab, model, **config) |     ner = EntityRecognizer(en_vocab, model, **config) | ||||||
|     ner.begin_training([]) |     ner.begin_training([]) | ||||||
|     ner(doc) |     ner(doc) | ||||||
|  | @ -37,7 +38,8 @@ def test_ents_reset(en_vocab): | ||||||
|         "min_action_freq": 30, |         "min_action_freq": 30, | ||||||
|         "update_with_oracle_cut_size": 100, |         "update_with_oracle_cut_size": 100, | ||||||
|     } |     } | ||||||
|     model = registry.make_from_config({"model": DEFAULT_NER_MODEL}, validate=True)["model"] |     cfg = {"model": DEFAULT_NER_MODEL} | ||||||
|  |     model = registry.make_from_config(cfg, validate=True)["model"] | ||||||
|     ner = EntityRecognizer(en_vocab, model, **config) |     ner = EntityRecognizer(en_vocab, model, **config) | ||||||
|     ner.begin_training([]) |     ner.begin_training([]) | ||||||
|     ner(doc) |     ner(doc) | ||||||
|  |  | ||||||
|  | @ -14,7 +14,9 @@ def test_en_sbd_single_punct(en_tokenizer, text, punct): | ||||||
|     assert sum(len(sent) for sent in doc.sents) == len(doc) |     assert sum(len(sent) for sent in doc.sents) == len(doc) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @pytest.mark.skip(reason="The step_through API was removed (but should be brought back)") | @pytest.mark.skip( | ||||||
|  |     reason="The step_through API was removed (but should be brought back)" | ||||||
|  | ) | ||||||
| def test_en_sentence_breaks(en_tokenizer, en_parser): | def test_en_sentence_breaks(en_tokenizer, en_parser): | ||||||
|     # fmt: off |     # fmt: off | ||||||
|     text = "This is a sentence . This is another one ." |     text = "This is a sentence . This is another one ." | ||||||
|  |  | ||||||
|  | @ -32,6 +32,25 @@ SENTENCE_TESTS = [ | ||||||
|     ("あれ。これ。", ["あれ。", "これ。"]), |     ("あれ。これ。", ["あれ。", "これ。"]), | ||||||
|     ("「伝染るんです。」という漫画があります。", ["「伝染るんです。」という漫画があります。"]), |     ("「伝染るんです。」という漫画があります。", ["「伝染るんです。」という漫画があります。"]), | ||||||
| ] | ] | ||||||
|  | 
 | ||||||
|  | tokens1 = [ | ||||||
|  |     DetailedToken(surface="委員", tag="名詞-普通名詞-一般", inf="", lemma="委員", reading="イイン", sub_tokens=None), | ||||||
|  |     DetailedToken(surface="会", tag="名詞-普通名詞-一般", inf="", lemma="会", reading="カイ", sub_tokens=None), | ||||||
|  | ] | ||||||
|  | tokens2 = [ | ||||||
|  |     DetailedToken(surface="選挙", tag="名詞-普通名詞-サ変可能", inf="", lemma="選挙", reading="センキョ", sub_tokens=None), | ||||||
|  |     DetailedToken(surface="管理", tag="名詞-普通名詞-サ変可能", inf="", lemma="管理", reading="カンリ", sub_tokens=None), | ||||||
|  |     DetailedToken(surface="委員", tag="名詞-普通名詞-一般", inf="", lemma="委員", reading="イイン", sub_tokens=None), | ||||||
|  |     DetailedToken(surface="会", tag="名詞-普通名詞-一般", inf="", lemma="会", reading="カイ", sub_tokens=None), | ||||||
|  | ] | ||||||
|  | tokens3 = [ | ||||||
|  |     DetailedToken(surface="選挙", tag="名詞-普通名詞-サ変可能", inf="", lemma="選挙", reading="センキョ", sub_tokens=None), | ||||||
|  |     DetailedToken(surface="管理", tag="名詞-普通名詞-サ変可能", inf="", lemma="管理", reading="カンリ", sub_tokens=None), | ||||||
|  |     DetailedToken(surface="委員会", tag="名詞-普通名詞-一般", inf="", lemma="委員会", reading="イインカイ", sub_tokens=None), | ||||||
|  | ] | ||||||
|  | SUB_TOKEN_TESTS = [ | ||||||
|  |     ("選挙管理委員会", [None, None, None, None], [None, None, [tokens1]], [[tokens2, tokens3]]) | ||||||
|  | ] | ||||||
| # fmt: on | # fmt: on | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -92,33 +111,12 @@ def test_ja_tokenizer_split_modes(ja_tokenizer, text, len_a, len_b, len_c): | ||||||
|     assert len(nlp_c(text)) == len_c |     assert len(nlp_c(text)) == len_c | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @pytest.mark.parametrize("text,sub_tokens_list_a,sub_tokens_list_b,sub_tokens_list_c", | @pytest.mark.parametrize( | ||||||
|     [ |     "text,sub_tokens_list_a,sub_tokens_list_b,sub_tokens_list_c", SUB_TOKEN_TESTS, | ||||||
|         ( |  | ||||||
|             "選挙管理委員会", |  | ||||||
|             [None, None, None, None], |  | ||||||
|             [None, None, [ |  | ||||||
|                 [ |  | ||||||
|                     DetailedToken(surface='委員', tag='名詞-普通名詞-一般', inf='', lemma='委員', reading='イイン', sub_tokens=None), |  | ||||||
|                     DetailedToken(surface='会', tag='名詞-普通名詞-一般', inf='', lemma='会', reading='カイ', sub_tokens=None), |  | ||||||
|                 ] |  | ||||||
|             ]], |  | ||||||
|             [[ |  | ||||||
|                 [ |  | ||||||
|                     DetailedToken(surface='選挙', tag='名詞-普通名詞-サ変可能', inf='', lemma='選挙', reading='センキョ', sub_tokens=None), |  | ||||||
|                     DetailedToken(surface='管理', tag='名詞-普通名詞-サ変可能', inf='', lemma='管理', reading='カンリ', sub_tokens=None), |  | ||||||
|                     DetailedToken(surface='委員', tag='名詞-普通名詞-一般', inf='', lemma='委員', reading='イイン', sub_tokens=None), |  | ||||||
|                     DetailedToken(surface='会', tag='名詞-普通名詞-一般', inf='', lemma='会', reading='カイ', sub_tokens=None), |  | ||||||
|                 ], [ |  | ||||||
|                     DetailedToken(surface='選挙', tag='名詞-普通名詞-サ変可能', inf='', lemma='選挙', reading='センキョ', sub_tokens=None), |  | ||||||
|                     DetailedToken(surface='管理', tag='名詞-普通名詞-サ変可能', inf='', lemma='管理', reading='カンリ', sub_tokens=None), |  | ||||||
|                     DetailedToken(surface='委員会', tag='名詞-普通名詞-一般', inf='', lemma='委員会', reading='イインカイ', sub_tokens=None), |  | ||||||
|                 ] |  | ||||||
|             ]] |  | ||||||
|         ), |  | ||||||
|     ] |  | ||||||
| ) | ) | ||||||
| def test_ja_tokenizer_sub_tokens(ja_tokenizer, text, sub_tokens_list_a, sub_tokens_list_b, sub_tokens_list_c): | def test_ja_tokenizer_sub_tokens( | ||||||
|  |     ja_tokenizer, text, sub_tokens_list_a, sub_tokens_list_b, sub_tokens_list_c | ||||||
|  | ): | ||||||
|     nlp_a = Japanese.from_config({"nlp": {"tokenizer": {"split_mode": "A"}}}) |     nlp_a = Japanese.from_config({"nlp": {"tokenizer": {"split_mode": "A"}}}) | ||||||
|     nlp_b = Japanese.from_config({"nlp": {"tokenizer": {"split_mode": "B"}}}) |     nlp_b = Japanese.from_config({"nlp": {"tokenizer": {"split_mode": "B"}}}) | ||||||
|     nlp_c = Japanese.from_config({"nlp": {"tokenizer": {"split_mode": "C"}}}) |     nlp_c = Japanese.from_config({"nlp": {"tokenizer": {"split_mode": "C"}}}) | ||||||
|  | @ -129,16 +127,19 @@ def test_ja_tokenizer_sub_tokens(ja_tokenizer, text, sub_tokens_list_a, sub_toke | ||||||
|     assert nlp_c(text).user_data["sub_tokens"] == sub_tokens_list_c |     assert nlp_c(text).user_data["sub_tokens"] == sub_tokens_list_c | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @pytest.mark.parametrize("text,inflections,reading_forms", | @pytest.mark.parametrize( | ||||||
|  |     "text,inflections,reading_forms", | ||||||
|     [ |     [ | ||||||
|         ( |         ( | ||||||
|             "取ってつけた", |             "取ってつけた", | ||||||
|             ("五段-ラ行,連用形-促音便", "", "下一段-カ行,連用形-一般", "助動詞-タ,終止形-一般"), |             ("五段-ラ行,連用形-促音便", "", "下一段-カ行,連用形-一般", "助動詞-タ,終止形-一般"), | ||||||
|             ("トッ", "テ", "ツケ", "タ"), |             ("トッ", "テ", "ツケ", "タ"), | ||||||
|         ), |         ), | ||||||
|     ] |     ], | ||||||
| ) | ) | ||||||
| def test_ja_tokenizer_inflections_reading_forms(ja_tokenizer, text, inflections, reading_forms): | def test_ja_tokenizer_inflections_reading_forms( | ||||||
|  |     ja_tokenizer, text, inflections, reading_forms | ||||||
|  | ): | ||||||
|     assert ja_tokenizer(text).user_data["inflections"] == inflections |     assert ja_tokenizer(text).user_data["inflections"] == inflections | ||||||
|     assert ja_tokenizer(text).user_data["reading_forms"] == reading_forms |     assert ja_tokenizer(text).user_data["reading_forms"] == reading_forms | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -11,9 +11,8 @@ def test_ne_tokenizer_handlers_long_text(ne_tokenizer): | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||||
|     "text,length", |     "text,length", [("समय जान कति पनि बेर लाग्दैन ।", 7), ("म ठूलो हुँदै थिएँ ।", 5)], | ||||||
|     [("समय जान कति पनि बेर लाग्दैन ।", 7), ("म ठूलो हुँदै थिएँ ।", 5)], |  | ||||||
| ) | ) | ||||||
| def test_ne_tokenizer_handles_cnts(ne_tokenizer, text, length): | def test_ne_tokenizer_handles_cnts(ne_tokenizer, text, length): | ||||||
|     tokens = ne_tokenizer(text) |     tokens = ne_tokenizer(text) | ||||||
|     assert len(tokens) == length |     assert len(tokens) == length | ||||||
|  |  | ||||||
|  | @ -30,10 +30,7 @@ def test_zh_tokenizer_serialize_pkuseg_with_processors(zh_tokenizer_pkuseg): | ||||||
|     nlp = Chinese( |     nlp = Chinese( | ||||||
|         meta={ |         meta={ | ||||||
|             "tokenizer": { |             "tokenizer": { | ||||||
|                 "config": { |                 "config": {"segmenter": "pkuseg", "pkuseg_model": "medicine",} | ||||||
|                     "segmenter": "pkuseg", |  | ||||||
|                     "pkuseg_model": "medicine", |  | ||||||
|                 } |  | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|  | @ -22,7 +22,8 @@ def parser(vocab): | ||||||
|         "min_action_freq": 30, |         "min_action_freq": 30, | ||||||
|         "update_with_oracle_cut_size": 100, |         "update_with_oracle_cut_size": 100, | ||||||
|     } |     } | ||||||
|     model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"] |     cfg = {"model": DEFAULT_PARSER_MODEL} | ||||||
|  |     model = registry.make_from_config(cfg, validate=True)["model"] | ||||||
|     parser = DependencyParser(vocab, model, **config) |     parser = DependencyParser(vocab, model, **config) | ||||||
|     return parser |     return parser | ||||||
| 
 | 
 | ||||||
|  | @ -68,7 +69,8 @@ def test_add_label_deserializes_correctly(): | ||||||
|         "min_action_freq": 30, |         "min_action_freq": 30, | ||||||
|         "update_with_oracle_cut_size": 100, |         "update_with_oracle_cut_size": 100, | ||||||
|     } |     } | ||||||
|     model = registry.make_from_config({"model": DEFAULT_NER_MODEL}, validate=True)["model"] |     cfg = {"model": DEFAULT_NER_MODEL} | ||||||
|  |     model = registry.make_from_config(cfg, validate=True)["model"] | ||||||
|     ner1 = EntityRecognizer(Vocab(), model, **config) |     ner1 = EntityRecognizer(Vocab(), model, **config) | ||||||
|     ner1.add_label("C") |     ner1.add_label("C") | ||||||
|     ner1.add_label("B") |     ner1.add_label("B") | ||||||
|  | @ -86,7 +88,10 @@ def test_add_label_deserializes_correctly(): | ||||||
| 
 | 
 | ||||||
| @pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||||
|     "pipe_cls,n_moves,model_config", |     "pipe_cls,n_moves,model_config", | ||||||
|     [(DependencyParser, 5, DEFAULT_PARSER_MODEL), (EntityRecognizer, 4, DEFAULT_NER_MODEL)], |     [ | ||||||
|  |         (DependencyParser, 5, DEFAULT_PARSER_MODEL), | ||||||
|  |         (EntityRecognizer, 4, DEFAULT_NER_MODEL), | ||||||
|  |     ], | ||||||
| ) | ) | ||||||
| def test_add_label_get_label(pipe_cls, n_moves, model_config): | def test_add_label_get_label(pipe_cls, n_moves, model_config): | ||||||
|     """Test that added labels are returned correctly. This test was added to |     """Test that added labels are returned correctly. This test was added to | ||||||
|  |  | ||||||
|  | @ -126,7 +126,8 @@ def test_get_oracle_actions(): | ||||||
|         "min_action_freq": 0, |         "min_action_freq": 0, | ||||||
|         "update_with_oracle_cut_size": 100, |         "update_with_oracle_cut_size": 100, | ||||||
|     } |     } | ||||||
|     model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"] |     cfg = {"model": DEFAULT_PARSER_MODEL} | ||||||
|  |     model = registry.make_from_config(cfg, validate=True)["model"] | ||||||
|     parser = DependencyParser(doc.vocab, model, **config) |     parser = DependencyParser(doc.vocab, model, **config) | ||||||
|     parser.moves.add_action(0, "") |     parser.moves.add_action(0, "") | ||||||
|     parser.moves.add_action(1, "") |     parser.moves.add_action(1, "") | ||||||
|  |  | ||||||
|  | @ -24,7 +24,8 @@ def arc_eager(vocab): | ||||||
| 
 | 
 | ||||||
| @pytest.fixture | @pytest.fixture | ||||||
| def tok2vec(): | def tok2vec(): | ||||||
|     tok2vec = registry.make_from_config({"model": DEFAULT_TOK2VEC_MODEL}, validate=True)["model"] |     cfg = {"model": DEFAULT_TOK2VEC_MODEL} | ||||||
|  |     tok2vec = registry.make_from_config(cfg, validate=True)["model"] | ||||||
|     tok2vec.initialize() |     tok2vec.initialize() | ||||||
|     return tok2vec |     return tok2vec | ||||||
| 
 | 
 | ||||||
|  | @ -36,13 +37,15 @@ def parser(vocab, arc_eager): | ||||||
|         "min_action_freq": 30, |         "min_action_freq": 30, | ||||||
|         "update_with_oracle_cut_size": 100, |         "update_with_oracle_cut_size": 100, | ||||||
|     } |     } | ||||||
|     model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"] |     cfg = {"model": DEFAULT_PARSER_MODEL} | ||||||
|  |     model = registry.make_from_config(cfg, validate=True)["model"] | ||||||
|     return Parser(vocab, model, moves=arc_eager, **config) |     return Parser(vocab, model, moves=arc_eager, **config) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @pytest.fixture | @pytest.fixture | ||||||
| def model(arc_eager, tok2vec, vocab): | def model(arc_eager, tok2vec, vocab): | ||||||
|     model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"] |     cfg = {"model": DEFAULT_PARSER_MODEL} | ||||||
|  |     model = registry.make_from_config(cfg, validate=True)["model"] | ||||||
|     model.attrs["resize_output"](model, arc_eager.n_moves) |     model.attrs["resize_output"](model, arc_eager.n_moves) | ||||||
|     model.initialize() |     model.initialize() | ||||||
|     return model |     return model | ||||||
|  | @ -68,7 +71,8 @@ def test_build_model(parser, vocab): | ||||||
|         "min_action_freq": 0, |         "min_action_freq": 0, | ||||||
|         "update_with_oracle_cut_size": 100, |         "update_with_oracle_cut_size": 100, | ||||||
|     } |     } | ||||||
|     model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"] |     cfg = {"model": DEFAULT_PARSER_MODEL} | ||||||
|  |     model = registry.make_from_config(cfg, validate=True)["model"] | ||||||
|     parser.model = Parser(vocab, model=model, moves=parser.moves, **config).model |     parser.model = Parser(vocab, model=model, moves=parser.moves, **config).model | ||||||
|     assert parser.model is not None |     assert parser.model is not None | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -33,7 +33,9 @@ def test_parser_root(en_tokenizer): | ||||||
|         assert t.dep != 0, t.text |         assert t.dep != 0, t.text | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @pytest.mark.skip(reason="The step_through API was removed (but should be brought back)") | @pytest.mark.skip( | ||||||
|  |     reason="The step_through API was removed (but should be brought back)" | ||||||
|  | ) | ||||||
| @pytest.mark.parametrize("text", ["Hello"]) | @pytest.mark.parametrize("text", ["Hello"]) | ||||||
| def test_parser_parse_one_word_sentence(en_tokenizer, en_parser, text): | def test_parser_parse_one_word_sentence(en_tokenizer, en_parser, text): | ||||||
|     tokens = en_tokenizer(text) |     tokens = en_tokenizer(text) | ||||||
|  | @ -47,7 +49,9 @@ def test_parser_parse_one_word_sentence(en_tokenizer, en_parser, text): | ||||||
|     assert doc[0].dep != 0 |     assert doc[0].dep != 0 | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @pytest.mark.skip(reason="The step_through API was removed (but should be brought back)") | @pytest.mark.skip( | ||||||
|  |     reason="The step_through API was removed (but should be brought back)" | ||||||
|  | ) | ||||||
| def test_parser_initial(en_tokenizer, en_parser): | def test_parser_initial(en_tokenizer, en_parser): | ||||||
|     text = "I ate the pizza with anchovies." |     text = "I ate the pizza with anchovies." | ||||||
|     # heads = [1, 0, 1, -2, -3, -1, -5] |     # heads = [1, 0, 1, -2, -3, -1, -5] | ||||||
|  | @ -92,7 +96,9 @@ def test_parser_merge_pp(en_tokenizer): | ||||||
|     assert doc[3].text == "occurs" |     assert doc[3].text == "occurs" | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @pytest.mark.skip(reason="The step_through API was removed (but should be brought back)") | @pytest.mark.skip( | ||||||
|  |     reason="The step_through API was removed (but should be brought back)" | ||||||
|  | ) | ||||||
| def test_parser_arc_eager_finalize_state(en_tokenizer, en_parser): | def test_parser_arc_eager_finalize_state(en_tokenizer, en_parser): | ||||||
|     text = "a b c d e" |     text = "a b c d e" | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -21,7 +21,8 @@ def parser(vocab): | ||||||
|         "min_action_freq": 30, |         "min_action_freq": 30, | ||||||
|         "update_with_oracle_cut_size": 100, |         "update_with_oracle_cut_size": 100, | ||||||
|     } |     } | ||||||
|     model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"] |     cfg = {"model": DEFAULT_PARSER_MODEL} | ||||||
|  |     model = registry.make_from_config(cfg, validate=True)["model"] | ||||||
|     parser = DependencyParser(vocab, model, **config) |     parser = DependencyParser(vocab, model, **config) | ||||||
|     parser.cfg["token_vector_width"] = 4 |     parser.cfg["token_vector_width"] = 4 | ||||||
|     parser.cfg["hidden_width"] = 32 |     parser.cfg["hidden_width"] = 32 | ||||||
|  |  | ||||||
|  | @ -28,7 +28,9 @@ def test_parser_sentence_space(en_tokenizer): | ||||||
|     assert len(list(doc.sents)) == 2 |     assert len(list(doc.sents)) == 2 | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @pytest.mark.skip(reason="The step_through API was removed (but should be brought back)") | @pytest.mark.skip( | ||||||
|  |     reason="The step_through API was removed (but should be brought back)" | ||||||
|  | ) | ||||||
| def test_parser_space_attachment_leading(en_tokenizer, en_parser): | def test_parser_space_attachment_leading(en_tokenizer, en_parser): | ||||||
|     text = "\t \n This is a sentence ." |     text = "\t \n This is a sentence ." | ||||||
|     heads = [1, 1, 0, 1, -2, -3] |     heads = [1, 1, 0, 1, -2, -3] | ||||||
|  | @ -44,7 +46,9 @@ def test_parser_space_attachment_leading(en_tokenizer, en_parser): | ||||||
|     assert stepwise.stack == set([2]) |     assert stepwise.stack == set([2]) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @pytest.mark.skip(reason="The step_through API was removed (but should be brought back)") | @pytest.mark.skip( | ||||||
|  |     reason="The step_through API was removed (but should be brought back)" | ||||||
|  | ) | ||||||
| def test_parser_space_attachment_intermediate_trailing(en_tokenizer, en_parser): | def test_parser_space_attachment_intermediate_trailing(en_tokenizer, en_parser): | ||||||
|     text = "This is \t a \t\n \n sentence . \n\n \n" |     text = "This is \t a \t\n \n sentence . \n\n \n" | ||||||
|     heads = [1, 0, -1, 2, -1, -4, -5, -1] |     heads = [1, 0, -1, 2, -1, -4, -5, -1] | ||||||
|  | @ -64,7 +68,9 @@ def test_parser_space_attachment_intermediate_trailing(en_tokenizer, en_parser): | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @pytest.mark.parametrize("text,length", [(["\n"], 1), (["\n", "\t", "\n\n", "\t"], 4)]) | @pytest.mark.parametrize("text,length", [(["\n"], 1), (["\n", "\t", "\n\n", "\t"], 4)]) | ||||||
| @pytest.mark.skip(reason="The step_through API was removed (but should be brought back)") | @pytest.mark.skip( | ||||||
|  |     reason="The step_through API was removed (but should be brought back)" | ||||||
|  | ) | ||||||
| def test_parser_space_attachment_space(en_tokenizer, en_parser, text, length): | def test_parser_space_attachment_space(en_tokenizer, en_parser, text, length): | ||||||
|     doc = Doc(en_parser.vocab, words=text) |     doc = Doc(en_parser.vocab, words=text) | ||||||
|     assert len(doc) == length |     assert len(doc) == length | ||||||
|  |  | ||||||
|  | @ -117,7 +117,9 @@ def test_overfitting_IO(): | ||||||
|         assert cats2["POSITIVE"] + cats2["NEGATIVE"] == pytest.approx(1.0, 0.1) |         assert cats2["POSITIVE"] + cats2["NEGATIVE"] == pytest.approx(1.0, 0.1) | ||||||
| 
 | 
 | ||||||
|     # Test scoring |     # Test scoring | ||||||
|     scores = nlp.evaluate(train_examples, component_cfg={"scorer": {"positive_label": "POSITIVE"}}) |     scores = nlp.evaluate( | ||||||
|  |         train_examples, component_cfg={"scorer": {"positive_label": "POSITIVE"}} | ||||||
|  |     ) | ||||||
|     assert scores["cats_f"] == 1.0 |     assert scores["cats_f"] == 1.0 | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -201,7 +201,8 @@ def test_issue3345(): | ||||||
|         "min_action_freq": 30, |         "min_action_freq": 30, | ||||||
|         "update_with_oracle_cut_size": 100, |         "update_with_oracle_cut_size": 100, | ||||||
|     } |     } | ||||||
|     model = registry.make_from_config({"model": DEFAULT_NER_MODEL}, validate=True)["model"] |     cfg = {"model": DEFAULT_NER_MODEL} | ||||||
|  |     model = registry.make_from_config(cfg, validate=True)["model"] | ||||||
|     ner = EntityRecognizer(doc.vocab, model, **config) |     ner = EntityRecognizer(doc.vocab, model, **config) | ||||||
|     # Add the OUT action. I wouldn't have thought this would be necessary... |     # Add the OUT action. I wouldn't have thought this would be necessary... | ||||||
|     ner.moves.add_action(5, "") |     ner.moves.add_action(5, "") | ||||||
|  |  | ||||||
|  | @ -20,7 +20,8 @@ def parser(en_vocab): | ||||||
|         "min_action_freq": 30, |         "min_action_freq": 30, | ||||||
|         "update_with_oracle_cut_size": 100, |         "update_with_oracle_cut_size": 100, | ||||||
|     } |     } | ||||||
|     model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"] |     cfg = {"model": DEFAULT_PARSER_MODEL} | ||||||
|  |     model = registry.make_from_config(cfg, validate=True)["model"] | ||||||
|     parser = DependencyParser(en_vocab, model, **config) |     parser = DependencyParser(en_vocab, model, **config) | ||||||
|     parser.add_label("nsubj") |     parser.add_label("nsubj") | ||||||
|     return parser |     return parser | ||||||
|  | @ -33,14 +34,16 @@ def blank_parser(en_vocab): | ||||||
|         "min_action_freq": 30, |         "min_action_freq": 30, | ||||||
|         "update_with_oracle_cut_size": 100, |         "update_with_oracle_cut_size": 100, | ||||||
|     } |     } | ||||||
|     model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"] |     cfg = {"model": DEFAULT_PARSER_MODEL} | ||||||
|  |     model = registry.make_from_config(cfg, validate=True)["model"] | ||||||
|     parser = DependencyParser(en_vocab, model, **config) |     parser = DependencyParser(en_vocab, model, **config) | ||||||
|     return parser |     return parser | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @pytest.fixture | @pytest.fixture | ||||||
| def taggers(en_vocab): | def taggers(en_vocab): | ||||||
|     model = registry.make_from_config({"model": DEFAULT_TAGGER_MODEL}, validate=True)["model"] |     cfg = {"model": DEFAULT_TAGGER_MODEL} | ||||||
|  |     model = registry.make_from_config(cfg, validate=True)["model"] | ||||||
|     tagger1 = Tagger(en_vocab, model, set_morphology=True) |     tagger1 = Tagger(en_vocab, model, set_morphology=True) | ||||||
|     tagger2 = Tagger(en_vocab, model, set_morphology=True) |     tagger2 = Tagger(en_vocab, model, set_morphology=True) | ||||||
|     return tagger1, tagger2 |     return tagger1, tagger2 | ||||||
|  | @ -53,7 +56,8 @@ def test_serialize_parser_roundtrip_bytes(en_vocab, Parser): | ||||||
|         "min_action_freq": 0, |         "min_action_freq": 0, | ||||||
|         "update_with_oracle_cut_size": 100, |         "update_with_oracle_cut_size": 100, | ||||||
|     } |     } | ||||||
|     model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"] |     cfg = {"model": DEFAULT_PARSER_MODEL} | ||||||
|  |     model = registry.make_from_config(cfg, validate=True)["model"] | ||||||
|     parser = Parser(en_vocab, model, **config) |     parser = Parser(en_vocab, model, **config) | ||||||
|     new_parser = Parser(en_vocab, model, **config) |     new_parser = Parser(en_vocab, model, **config) | ||||||
|     new_parser = new_parser.from_bytes(parser.to_bytes(exclude=["vocab"])) |     new_parser = new_parser.from_bytes(parser.to_bytes(exclude=["vocab"])) | ||||||
|  | @ -70,7 +74,8 @@ def test_serialize_parser_roundtrip_disk(en_vocab, Parser): | ||||||
|         "min_action_freq": 0, |         "min_action_freq": 0, | ||||||
|         "update_with_oracle_cut_size": 100, |         "update_with_oracle_cut_size": 100, | ||||||
|     } |     } | ||||||
|     model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"] |     cfg = {"model": DEFAULT_PARSER_MODEL} | ||||||
|  |     model = registry.make_from_config(cfg, validate=True)["model"] | ||||||
|     parser = Parser(en_vocab, model, **config) |     parser = Parser(en_vocab, model, **config) | ||||||
|     with make_tempdir() as d: |     with make_tempdir() as d: | ||||||
|         file_path = d / "parser" |         file_path = d / "parser" | ||||||
|  | @ -88,7 +93,6 @@ def test_to_from_bytes(parser, blank_parser): | ||||||
|     assert blank_parser.model is not True |     assert blank_parser.model is not True | ||||||
|     assert blank_parser.moves.n_moves != parser.moves.n_moves |     assert blank_parser.moves.n_moves != parser.moves.n_moves | ||||||
|     bytes_data = parser.to_bytes(exclude=["vocab"]) |     bytes_data = parser.to_bytes(exclude=["vocab"]) | ||||||
| 
 |  | ||||||
|     # the blank parser needs to be resized before we can call from_bytes |     # the blank parser needs to be resized before we can call from_bytes | ||||||
|     blank_parser.model.attrs["resize_output"](blank_parser.model, parser.moves.n_moves) |     blank_parser.model.attrs["resize_output"](blank_parser.model, parser.moves.n_moves) | ||||||
|     blank_parser.from_bytes(bytes_data) |     blank_parser.from_bytes(bytes_data) | ||||||
|  | @ -104,7 +108,8 @@ def test_serialize_tagger_roundtrip_bytes(en_vocab, taggers): | ||||||
|     tagger1_b = tagger1.to_bytes() |     tagger1_b = tagger1.to_bytes() | ||||||
|     tagger1 = tagger1.from_bytes(tagger1_b) |     tagger1 = tagger1.from_bytes(tagger1_b) | ||||||
|     assert tagger1.to_bytes() == tagger1_b |     assert tagger1.to_bytes() == tagger1_b | ||||||
|     model = registry.make_from_config({"model": DEFAULT_TAGGER_MODEL}, validate=True)["model"] |     cfg = {"model": DEFAULT_TAGGER_MODEL} | ||||||
|  |     model = registry.make_from_config(cfg, validate=True)["model"] | ||||||
|     new_tagger1 = Tagger(en_vocab, model).from_bytes(tagger1_b) |     new_tagger1 = Tagger(en_vocab, model).from_bytes(tagger1_b) | ||||||
|     new_tagger1_b = new_tagger1.to_bytes() |     new_tagger1_b = new_tagger1.to_bytes() | ||||||
|     assert len(new_tagger1_b) == len(tagger1_b) |     assert len(new_tagger1_b) == len(tagger1_b) | ||||||
|  | @ -118,7 +123,8 @@ def test_serialize_tagger_roundtrip_disk(en_vocab, taggers): | ||||||
|         file_path2 = d / "tagger2" |         file_path2 = d / "tagger2" | ||||||
|         tagger1.to_disk(file_path1) |         tagger1.to_disk(file_path1) | ||||||
|         tagger2.to_disk(file_path2) |         tagger2.to_disk(file_path2) | ||||||
|         model = registry.make_from_config({"model": DEFAULT_TAGGER_MODEL}, validate=True)["model"] |         cfg = {"model": DEFAULT_TAGGER_MODEL} | ||||||
|  |         model = registry.make_from_config(cfg, validate=True)["model"] | ||||||
|         tagger1_d = Tagger(en_vocab, model, set_morphology=True).from_disk(file_path1) |         tagger1_d = Tagger(en_vocab, model, set_morphology=True).from_disk(file_path1) | ||||||
|         tagger2_d = Tagger(en_vocab, model, set_morphology=True).from_disk(file_path2) |         tagger2_d = Tagger(en_vocab, model, set_morphology=True).from_disk(file_path2) | ||||||
|         assert tagger1_d.to_bytes() == tagger2_d.to_bytes() |         assert tagger1_d.to_bytes() == tagger2_d.to_bytes() | ||||||
|  | @ -126,21 +132,22 @@ def test_serialize_tagger_roundtrip_disk(en_vocab, taggers): | ||||||
| 
 | 
 | ||||||
| def test_serialize_textcat_empty(en_vocab): | def test_serialize_textcat_empty(en_vocab): | ||||||
|     # See issue #1105 |     # See issue #1105 | ||||||
|     model = registry.make_from_config({"model": DEFAULT_TEXTCAT_MODEL}, validate=True)["model"] |     cfg = {"model": DEFAULT_TEXTCAT_MODEL} | ||||||
|     textcat = TextCategorizer( |     model = registry.make_from_config(cfg, validate=True)["model"] | ||||||
|         en_vocab, model, labels=["ENTITY", "ACTION", "MODIFIER"] |     textcat = TextCategorizer(en_vocab, model, labels=["ENTITY", "ACTION", "MODIFIER"]) | ||||||
|     ) |  | ||||||
|     textcat.to_bytes(exclude=["vocab"]) |     textcat.to_bytes(exclude=["vocab"]) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @pytest.mark.parametrize("Parser", test_parsers) | @pytest.mark.parametrize("Parser", test_parsers) | ||||||
| def test_serialize_pipe_exclude(en_vocab, Parser): | def test_serialize_pipe_exclude(en_vocab, Parser): | ||||||
|     model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"] |     cfg = {"model": DEFAULT_PARSER_MODEL} | ||||||
|  |     model = registry.make_from_config(cfg, validate=True)["model"] | ||||||
|     config = { |     config = { | ||||||
|         "learn_tokens": False, |         "learn_tokens": False, | ||||||
|         "min_action_freq": 0, |         "min_action_freq": 0, | ||||||
|         "update_with_oracle_cut_size": 100, |         "update_with_oracle_cut_size": 100, | ||||||
|     } |     } | ||||||
|  | 
 | ||||||
|     def get_new_parser(): |     def get_new_parser(): | ||||||
|         new_parser = Parser(en_vocab, model, **config) |         new_parser = Parser(en_vocab, model, **config) | ||||||
|         return new_parser |         return new_parser | ||||||
|  | @ -160,7 +167,8 @@ def test_serialize_pipe_exclude(en_vocab, Parser): | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def test_serialize_sentencerecognizer(en_vocab): | def test_serialize_sentencerecognizer(en_vocab): | ||||||
|     model = registry.make_from_config({"model": DEFAULT_SENTER_MODEL}, validate=True)["model"] |     cfg = {"model": DEFAULT_SENTER_MODEL} | ||||||
|  |     model = registry.make_from_config(cfg, validate=True)["model"] | ||||||
|     sr = SentenceRecognizer(en_vocab, model) |     sr = SentenceRecognizer(en_vocab, model) | ||||||
|     sr_b = sr.to_bytes() |     sr_b = sr.to_bytes() | ||||||
|     sr_d = SentenceRecognizer(en_vocab, model).from_bytes(sr_b) |     sr_d = SentenceRecognizer(en_vocab, model).from_bytes(sr_b) | ||||||
|  |  | ||||||
|  | @ -466,7 +466,6 @@ def test_iob_to_biluo(): | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def test_roundtrip_docs_to_docbin(doc): | def test_roundtrip_docs_to_docbin(doc): | ||||||
|     nlp = English() |  | ||||||
|     text = doc.text |     text = doc.text | ||||||
|     idx = [t.idx for t in doc] |     idx = [t.idx for t in doc] | ||||||
|     tags = [t.tag_ for t in doc] |     tags = [t.tag_ for t in doc] | ||||||
|  |  | ||||||
|  | @ -7,11 +7,11 @@ from spacy.vocab import Vocab | ||||||
| def test_Example_init_requires_doc_objects(): | def test_Example_init_requires_doc_objects(): | ||||||
|     vocab = Vocab() |     vocab = Vocab() | ||||||
|     with pytest.raises(TypeError): |     with pytest.raises(TypeError): | ||||||
|         example = Example(None, None) |         Example(None, None) | ||||||
|     with pytest.raises(TypeError): |     with pytest.raises(TypeError): | ||||||
|         example = Example(Doc(vocab, words=["hi"]), None) |         Example(Doc(vocab, words=["hi"]), None) | ||||||
|     with pytest.raises(TypeError): |     with pytest.raises(TypeError): | ||||||
|         example = Example(None, Doc(vocab, words=["hi"])) |         Example(None, Doc(vocab, words=["hi"])) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def test_Example_from_dict_basic(): | def test_Example_from_dict_basic(): | ||||||
|  |  | ||||||
|  | @ -105,7 +105,11 @@ def test_tokenization(sented_doc): | ||||||
|     assert scores["token_acc"] == 1.0 |     assert scores["token_acc"] == 1.0 | ||||||
| 
 | 
 | ||||||
|     nlp = English() |     nlp = English() | ||||||
|     example.predicted = Doc(nlp.vocab, words=["One", "sentence.", "Two", "sentences.", "Three", "sentences."], spaces=[True, True, True, True, True, False]) |     example.predicted = Doc( | ||||||
|  |         nlp.vocab, | ||||||
|  |         words=["One", "sentence.", "Two", "sentences.", "Three", "sentences."], | ||||||
|  |         spaces=[True, True, True, True, True, False], | ||||||
|  |     ) | ||||||
|     example.predicted[1].is_sent_start = False |     example.predicted[1].is_sent_start = False | ||||||
|     scores = scorer.score([example]) |     scores = scorer.score([example]) | ||||||
|     assert scores["token_acc"] == approx(0.66666666) |     assert scores["token_acc"] == approx(0.66666666) | ||||||
|  |  | ||||||
							
								
								
									
										193
									
								
								spacy/util.py
									
									
									
									
									
								
							
							
						
						
									
										193
									
								
								spacy/util.py
									
									
									
									
									
								
							|  | @ -1,12 +1,13 @@ | ||||||
| from typing import List, Union, Dict, Any, Optional, Iterable, Callable, Tuple | from typing import List, Union, Dict, Any, Optional, Iterable, Callable, Tuple | ||||||
| from typing import Iterator, TYPE_CHECKING | from typing import Iterator, Type, Pattern, Sequence, TYPE_CHECKING | ||||||
|  | from types import ModuleType | ||||||
| import os | import os | ||||||
| import importlib | import importlib | ||||||
| import importlib.util | import importlib.util | ||||||
| import re | import re | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| import thinc | import thinc | ||||||
| from thinc.api import NumpyOps, get_current_ops, Adam, Config | from thinc.api import NumpyOps, get_current_ops, Adam, Config, Optimizer | ||||||
| import functools | import functools | ||||||
| import itertools | import itertools | ||||||
| import numpy.random | import numpy.random | ||||||
|  | @ -49,6 +50,8 @@ from . import about | ||||||
| if TYPE_CHECKING: | if TYPE_CHECKING: | ||||||
|     # This lets us add type hints for mypy etc. without causing circular imports |     # This lets us add type hints for mypy etc. without causing circular imports | ||||||
|     from .language import Language  # noqa: F401 |     from .language import Language  # noqa: F401 | ||||||
|  |     from .tokens import Doc, Span  # noqa: F401 | ||||||
|  |     from .vocab import Vocab  # noqa: F401 | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| _PRINT_ENV = False | _PRINT_ENV = False | ||||||
|  | @ -102,12 +105,12 @@ class SimpleFrozenDict(dict): | ||||||
|         raise NotImplementedError(self.error) |         raise NotImplementedError(self.error) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def set_env_log(value): | def set_env_log(value: bool) -> None: | ||||||
|     global _PRINT_ENV |     global _PRINT_ENV | ||||||
|     _PRINT_ENV = value |     _PRINT_ENV = value | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def lang_class_is_loaded(lang): | def lang_class_is_loaded(lang: str) -> bool: | ||||||
|     """Check whether a Language class is already loaded. Language classes are |     """Check whether a Language class is already loaded. Language classes are | ||||||
|     loaded lazily, to avoid expensive setup code associated with the language |     loaded lazily, to avoid expensive setup code associated with the language | ||||||
|     data. |     data. | ||||||
|  | @ -118,7 +121,7 @@ def lang_class_is_loaded(lang): | ||||||
|     return lang in registry.languages |     return lang in registry.languages | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def get_lang_class(lang): | def get_lang_class(lang: str) -> "Language": | ||||||
|     """Import and load a Language class. |     """Import and load a Language class. | ||||||
| 
 | 
 | ||||||
|     lang (str): Two-letter language code, e.g. 'en'. |     lang (str): Two-letter language code, e.g. 'en'. | ||||||
|  | @ -136,7 +139,7 @@ def get_lang_class(lang): | ||||||
|     return registry.languages.get(lang) |     return registry.languages.get(lang) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def set_lang_class(name, cls): | def set_lang_class(name: str, cls: Type["Language"]) -> None: | ||||||
|     """Set a custom Language class name that can be loaded via get_lang_class. |     """Set a custom Language class name that can be loaded via get_lang_class. | ||||||
| 
 | 
 | ||||||
|     name (str): Name of Language class. |     name (str): Name of Language class. | ||||||
|  | @ -145,10 +148,10 @@ def set_lang_class(name, cls): | ||||||
|     registry.languages.register(name, func=cls) |     registry.languages.register(name, func=cls) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def ensure_path(path): | def ensure_path(path: Any) -> Any: | ||||||
|     """Ensure string is converted to a Path. |     """Ensure string is converted to a Path. | ||||||
| 
 | 
 | ||||||
|     path: Anything. If string, it's converted to Path. |     path (Any): Anything. If string, it's converted to Path. | ||||||
|     RETURNS: Path or original argument. |     RETURNS: Path or original argument. | ||||||
|     """ |     """ | ||||||
|     if isinstance(path, str): |     if isinstance(path, str): | ||||||
|  | @ -157,7 +160,7 @@ def ensure_path(path): | ||||||
|         return path |         return path | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def load_language_data(path): | def load_language_data(path: Union[str, Path]) -> Union[dict, list]: | ||||||
|     """Load JSON language data using the given path as a base. If the provided |     """Load JSON language data using the given path as a base. If the provided | ||||||
|     path isn't present, will attempt to load a gzipped version before giving up. |     path isn't present, will attempt to load a gzipped version before giving up. | ||||||
| 
 | 
 | ||||||
|  | @ -173,7 +176,12 @@ def load_language_data(path): | ||||||
|     raise ValueError(Errors.E160.format(path=path)) |     raise ValueError(Errors.E160.format(path=path)) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def get_module_path(module): | def get_module_path(module: ModuleType) -> Path: | ||||||
|  |     """Get the path of a Python module. | ||||||
|  | 
 | ||||||
|  |     module (ModuleType): The Python module. | ||||||
|  |     RETURNS (Path): The path. | ||||||
|  |     """ | ||||||
|     if not hasattr(module, "__module__"): |     if not hasattr(module, "__module__"): | ||||||
|         raise ValueError(Errors.E169.format(module=repr(module))) |         raise ValueError(Errors.E169.format(module=repr(module))) | ||||||
|     return Path(sys.modules[module.__module__].__file__).parent |     return Path(sys.modules[module.__module__].__file__).parent | ||||||
|  | @ -183,7 +191,7 @@ def load_model( | ||||||
|     name: Union[str, Path], |     name: Union[str, Path], | ||||||
|     disable: Iterable[str] = tuple(), |     disable: Iterable[str] = tuple(), | ||||||
|     component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(), |     component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(), | ||||||
| ): | ) -> "Language": | ||||||
|     """Load a model from a package or data path. |     """Load a model from a package or data path. | ||||||
| 
 | 
 | ||||||
|     name (str): Package name or model path. |     name (str): Package name or model path. | ||||||
|  | @ -209,7 +217,7 @@ def load_model_from_package( | ||||||
|     name: str, |     name: str, | ||||||
|     disable: Iterable[str] = tuple(), |     disable: Iterable[str] = tuple(), | ||||||
|     component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(), |     component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(), | ||||||
| ): | ) -> "Language": | ||||||
|     """Load a model from an installed package.""" |     """Load a model from an installed package.""" | ||||||
|     cls = importlib.import_module(name) |     cls = importlib.import_module(name) | ||||||
|     return cls.load(disable=disable, component_cfg=component_cfg) |     return cls.load(disable=disable, component_cfg=component_cfg) | ||||||
|  | @ -220,7 +228,7 @@ def load_model_from_path( | ||||||
|     meta: Optional[Dict[str, Any]] = None, |     meta: Optional[Dict[str, Any]] = None, | ||||||
|     disable: Iterable[str] = tuple(), |     disable: Iterable[str] = tuple(), | ||||||
|     component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(), |     component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(), | ||||||
| ): | ) -> "Language": | ||||||
|     """Load a model from a data directory path. Creates Language class with |     """Load a model from a data directory path. Creates Language class with | ||||||
|     pipeline from config.cfg and then calls from_disk() with path.""" |     pipeline from config.cfg and then calls from_disk() with path.""" | ||||||
|     if not model_path.exists(): |     if not model_path.exists(): | ||||||
|  | @ -269,7 +277,7 @@ def load_model_from_init_py( | ||||||
|     init_file: Union[Path, str], |     init_file: Union[Path, str], | ||||||
|     disable: Iterable[str] = tuple(), |     disable: Iterable[str] = tuple(), | ||||||
|     component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(), |     component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(), | ||||||
| ): | ) -> "Language": | ||||||
|     """Helper function to use in the `load()` method of a model package's |     """Helper function to use in the `load()` method of a model package's | ||||||
|     __init__.py. |     __init__.py. | ||||||
| 
 | 
 | ||||||
|  | @ -288,15 +296,15 @@ def load_model_from_init_py( | ||||||
|     ) |     ) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def get_installed_models(): | def get_installed_models() -> List[str]: | ||||||
|     """List all model packages currently installed in the environment. |     """List all model packages currently installed in the environment. | ||||||
| 
 | 
 | ||||||
|     RETURNS (list): The string names of the models. |     RETURNS (List[str]): The string names of the models. | ||||||
|     """ |     """ | ||||||
|     return list(registry.models.get_all().keys()) |     return list(registry.models.get_all().keys()) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def get_package_version(name): | def get_package_version(name: str) -> Optional[str]: | ||||||
|     """Get the version of an installed package. Typically used to get model |     """Get the version of an installed package. Typically used to get model | ||||||
|     package versions. |     package versions. | ||||||
| 
 | 
 | ||||||
|  | @ -309,7 +317,9 @@ def get_package_version(name): | ||||||
|         return None |         return None | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def is_compatible_version(version, constraint, prereleases=True): | def is_compatible_version( | ||||||
|  |     version: str, constraint: str, prereleases: bool = True | ||||||
|  | ) -> Optional[bool]: | ||||||
|     """Check if a version (e.g. "2.0.0") is compatible given a version |     """Check if a version (e.g. "2.0.0") is compatible given a version | ||||||
|     constraint (e.g. ">=1.9.0,<2.2.1"). If the constraint is a specific version, |     constraint (e.g. ">=1.9.0,<2.2.1"). If the constraint is a specific version, | ||||||
|     it's interpreted as =={version}. |     it's interpreted as =={version}. | ||||||
|  | @ -333,7 +343,9 @@ def is_compatible_version(version, constraint, prereleases=True): | ||||||
|     return version in spec |     return version in spec | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def is_unconstrained_version(constraint, prereleases=True): | def is_unconstrained_version( | ||||||
|  |     constraint: str, prereleases: bool = True | ||||||
|  | ) -> Optional[bool]: | ||||||
|     # We have an exact version, this is the ultimate constrained version |     # We have an exact version, this is the ultimate constrained version | ||||||
|     if constraint[0].isdigit(): |     if constraint[0].isdigit(): | ||||||
|         return False |         return False | ||||||
|  | @ -358,7 +370,7 @@ def is_unconstrained_version(constraint, prereleases=True): | ||||||
|     return True |     return True | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def get_model_version_range(spacy_version): | def get_model_version_range(spacy_version: str) -> str: | ||||||
|     """Generate a version range like >=1.2.3,<1.3.0 based on a given spaCy |     """Generate a version range like >=1.2.3,<1.3.0 based on a given spaCy | ||||||
|     version. Models are always compatible across patch versions but not |     version. Models are always compatible across patch versions but not | ||||||
|     across minor or major versions. |     across minor or major versions. | ||||||
|  | @ -367,7 +379,7 @@ def get_model_version_range(spacy_version): | ||||||
|     return f">={spacy_version},<{release[0]}.{release[1] + 1}.0" |     return f">={spacy_version},<{release[0]}.{release[1] + 1}.0" | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def get_base_version(version): | def get_base_version(version: str) -> str: | ||||||
|     """Generate the base version without any prerelease identifiers. |     """Generate the base version without any prerelease identifiers. | ||||||
| 
 | 
 | ||||||
|     version (str): The version, e.g. "3.0.0.dev1". |     version (str): The version, e.g. "3.0.0.dev1". | ||||||
|  | @ -376,11 +388,11 @@ def get_base_version(version): | ||||||
|     return Version(version).base_version |     return Version(version).base_version | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def get_model_meta(path): | def get_model_meta(path: Union[str, Path]) -> Dict[str, Any]: | ||||||
|     """Get model meta.json from a directory path and validate its contents. |     """Get model meta.json from a directory path and validate its contents. | ||||||
| 
 | 
 | ||||||
|     path (str / Path): Path to model directory. |     path (str / Path): Path to model directory. | ||||||
|     RETURNS (dict): The model's meta data. |     RETURNS (Dict[str, Any]): The model's meta data. | ||||||
|     """ |     """ | ||||||
|     model_path = ensure_path(path) |     model_path = ensure_path(path) | ||||||
|     if not model_path.exists(): |     if not model_path.exists(): | ||||||
|  | @ -412,7 +424,7 @@ def get_model_meta(path): | ||||||
|     return meta |     return meta | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def is_package(name): | def is_package(name: str) -> bool: | ||||||
|     """Check if string maps to a package installed via pip. |     """Check if string maps to a package installed via pip. | ||||||
| 
 | 
 | ||||||
|     name (str): Name of package. |     name (str): Name of package. | ||||||
|  | @ -425,7 +437,7 @@ def is_package(name): | ||||||
|         return False |         return False | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def get_package_path(name): | def get_package_path(name: str) -> Path: | ||||||
|     """Get the path to an installed package. |     """Get the path to an installed package. | ||||||
| 
 | 
 | ||||||
|     name (str): Package name. |     name (str): Package name. | ||||||
|  | @ -495,7 +507,7 @@ def working_dir(path: Union[str, Path]) -> None: | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @contextmanager | @contextmanager | ||||||
| def make_tempdir(): | def make_tempdir() -> None: | ||||||
|     """Execute a block in a temporary directory and remove the directory and |     """Execute a block in a temporary directory and remove the directory and | ||||||
|     its contents at the end of the with block. |     its contents at the end of the with block. | ||||||
| 
 | 
 | ||||||
|  | @ -518,7 +530,7 @@ def is_cwd(path: Union[Path, str]) -> bool: | ||||||
|     return str(Path(path).resolve()).lower() == str(Path.cwd().resolve()).lower() |     return str(Path(path).resolve()).lower() == str(Path.cwd().resolve()).lower() | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def is_in_jupyter(): | def is_in_jupyter() -> bool: | ||||||
|     """Check if user is running spaCy from a Jupyter notebook by detecting the |     """Check if user is running spaCy from a Jupyter notebook by detecting the | ||||||
|     IPython kernel. Mainly used for the displaCy visualizer. |     IPython kernel. Mainly used for the displaCy visualizer. | ||||||
|     RETURNS (bool): True if in Jupyter, False if not. |     RETURNS (bool): True if in Jupyter, False if not. | ||||||
|  | @ -548,7 +560,9 @@ def get_object_name(obj: Any) -> str: | ||||||
|     return repr(obj) |     return repr(obj) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def get_cuda_stream(require=False, non_blocking=True): | def get_cuda_stream( | ||||||
|  |     require: bool = False, non_blocking: bool = True | ||||||
|  | ) -> Optional[CudaStream]: | ||||||
|     ops = get_current_ops() |     ops = get_current_ops() | ||||||
|     if CudaStream is None: |     if CudaStream is None: | ||||||
|         return None |         return None | ||||||
|  | @ -567,7 +581,7 @@ def get_async(stream, numpy_array): | ||||||
|         return array |         return array | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def env_opt(name, default=None): | def env_opt(name: str, default: Optional[Any] = None) -> Optional[Any]: | ||||||
|     if type(default) is float: |     if type(default) is float: | ||||||
|         type_convert = float |         type_convert = float | ||||||
|     else: |     else: | ||||||
|  | @ -588,7 +602,7 @@ def env_opt(name, default=None): | ||||||
|         return default |         return default | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def read_regex(path): | def read_regex(path: Union[str, Path]) -> Pattern: | ||||||
|     path = ensure_path(path) |     path = ensure_path(path) | ||||||
|     with path.open(encoding="utf8") as file_: |     with path.open(encoding="utf8") as file_: | ||||||
|         entries = file_.read().split("\n") |         entries = file_.read().split("\n") | ||||||
|  | @ -598,37 +612,40 @@ def read_regex(path): | ||||||
|     return re.compile(expression) |     return re.compile(expression) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def compile_prefix_regex(entries): | def compile_prefix_regex(entries: Iterable[Union[str, Pattern]]) -> Pattern: | ||||||
|     """Compile a sequence of prefix rules into a regex object. |     """Compile a sequence of prefix rules into a regex object. | ||||||
| 
 | 
 | ||||||
|     entries (tuple): The prefix rules, e.g. spacy.lang.punctuation.TOKENIZER_PREFIXES. |     entries (Iterable[Union[str, Pattern]]): The prefix rules, e.g. | ||||||
|     RETURNS (regex object): The regex object. to be used for Tokenizer.prefix_search. |         spacy.lang.punctuation.TOKENIZER_PREFIXES. | ||||||
|  |     RETURNS (Pattern): The regex object. to be used for Tokenizer.prefix_search. | ||||||
|     """ |     """ | ||||||
|     expression = "|".join(["^" + piece for piece in entries if piece.strip()]) |     expression = "|".join(["^" + piece for piece in entries if piece.strip()]) | ||||||
|     return re.compile(expression) |     return re.compile(expression) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def compile_suffix_regex(entries): | def compile_suffix_regex(entries: Iterable[Union[str, Pattern]]) -> Pattern: | ||||||
|     """Compile a sequence of suffix rules into a regex object. |     """Compile a sequence of suffix rules into a regex object. | ||||||
| 
 | 
 | ||||||
|     entries (tuple): The suffix rules, e.g. spacy.lang.punctuation.TOKENIZER_SUFFIXES. |     entries (Iterable[Union[str, Pattern]]): The suffix rules, e.g. | ||||||
|     RETURNS (regex object): The regex object. to be used for Tokenizer.suffix_search. |         spacy.lang.punctuation.TOKENIZER_SUFFIXES. | ||||||
|  |     RETURNS (Pattern): The regex object. to be used for Tokenizer.suffix_search. | ||||||
|     """ |     """ | ||||||
|     expression = "|".join([piece + "$" for piece in entries if piece.strip()]) |     expression = "|".join([piece + "$" for piece in entries if piece.strip()]) | ||||||
|     return re.compile(expression) |     return re.compile(expression) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def compile_infix_regex(entries): | def compile_infix_regex(entries: Iterable[Union[str, Pattern]]) -> Pattern: | ||||||
|     """Compile a sequence of infix rules into a regex object. |     """Compile a sequence of infix rules into a regex object. | ||||||
| 
 | 
 | ||||||
|     entries (tuple): The infix rules, e.g. spacy.lang.punctuation.TOKENIZER_INFIXES. |     entries (Iterable[Union[str, Pattern]]): The infix rules, e.g. | ||||||
|  |         spacy.lang.punctuation.TOKENIZER_INFIXES. | ||||||
|     RETURNS (regex object): The regex object. to be used for Tokenizer.infix_finditer. |     RETURNS (regex object): The regex object. to be used for Tokenizer.infix_finditer. | ||||||
|     """ |     """ | ||||||
|     expression = "|".join([piece for piece in entries if piece.strip()]) |     expression = "|".join([piece for piece in entries if piece.strip()]) | ||||||
|     return re.compile(expression) |     return re.compile(expression) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def add_lookups(default_func, *lookups): | def add_lookups(default_func: Callable[[str], Any], *lookups) -> Callable[[str], Any]: | ||||||
|     """Extend an attribute function with special cases. If a word is in the |     """Extend an attribute function with special cases. If a word is in the | ||||||
|     lookups, the value is returned. Otherwise the previous function is used. |     lookups, the value is returned. Otherwise the previous function is used. | ||||||
| 
 | 
 | ||||||
|  | @ -641,19 +658,23 @@ def add_lookups(default_func, *lookups): | ||||||
|     return functools.partial(_get_attr_unless_lookup, default_func, lookups) |     return functools.partial(_get_attr_unless_lookup, default_func, lookups) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def _get_attr_unless_lookup(default_func, lookups, string): | def _get_attr_unless_lookup( | ||||||
|  |     default_func: Callable[[str], Any], lookups: Dict[str, Any], string: str | ||||||
|  | ) -> Any: | ||||||
|     for lookup in lookups: |     for lookup in lookups: | ||||||
|         if string in lookup: |         if string in lookup: | ||||||
|             return lookup[string] |             return lookup[string] | ||||||
|     return default_func(string) |     return default_func(string) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def update_exc(base_exceptions, *addition_dicts): | def update_exc( | ||||||
|  |     base_exceptions: Dict[str, List[dict]], *addition_dicts | ||||||
|  | ) -> Dict[str, List[dict]]: | ||||||
|     """Update and validate tokenizer exceptions. Will overwrite exceptions. |     """Update and validate tokenizer exceptions. Will overwrite exceptions. | ||||||
| 
 | 
 | ||||||
|     base_exceptions (dict): Base exceptions. |     base_exceptions (Dict[str, List[dict]]): Base exceptions. | ||||||
|     *addition_dicts (dict): Exceptions to add to the base dict, in order. |     *addition_dicts (Dict[str, List[dict]]): Exceptions to add to the base dict, in order. | ||||||
|     RETURNS (dict): Combined tokenizer exceptions. |     RETURNS (Dict[str, List[dict]]): Combined tokenizer exceptions. | ||||||
|     """ |     """ | ||||||
|     exc = dict(base_exceptions) |     exc = dict(base_exceptions) | ||||||
|     for additions in addition_dicts: |     for additions in addition_dicts: | ||||||
|  | @ -668,14 +689,16 @@ def update_exc(base_exceptions, *addition_dicts): | ||||||
|     return exc |     return exc | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def expand_exc(excs, search, replace): | def expand_exc( | ||||||
|  |     excs: Dict[str, List[dict]], search: str, replace: str | ||||||
|  | ) -> Dict[str, List[dict]]: | ||||||
|     """Find string in tokenizer exceptions, duplicate entry and replace string. |     """Find string in tokenizer exceptions, duplicate entry and replace string. | ||||||
|     For example, to add additional versions with typographic apostrophes. |     For example, to add additional versions with typographic apostrophes. | ||||||
| 
 | 
 | ||||||
|     excs (dict): Tokenizer exceptions. |     excs (Dict[str, List[dict]]): Tokenizer exceptions. | ||||||
|     search (str): String to find and replace. |     search (str): String to find and replace. | ||||||
|     replace (str): Replacement. |     replace (str): Replacement. | ||||||
|     RETURNS (dict): Combined tokenizer exceptions. |     RETURNS (Dict[str, List[dict]]): Combined tokenizer exceptions. | ||||||
|     """ |     """ | ||||||
| 
 | 
 | ||||||
|     def _fix_token(token, search, replace): |     def _fix_token(token, search, replace): | ||||||
|  | @ -692,7 +715,9 @@ def expand_exc(excs, search, replace): | ||||||
|     return new_excs |     return new_excs | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def normalize_slice(length, start, stop, step=None): | def normalize_slice( | ||||||
|  |     length: int, start: int, stop: int, step: Optional[int] = None | ||||||
|  | ) -> Tuple[int, int]: | ||||||
|     if not (step is None or step == 1): |     if not (step is None or step == 1): | ||||||
|         raise ValueError(Errors.E057) |         raise ValueError(Errors.E057) | ||||||
|     if start is None: |     if start is None: | ||||||
|  | @ -708,7 +733,9 @@ def normalize_slice(length, start, stop, step=None): | ||||||
|     return start, stop |     return start, stop | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def minibatch(items, size=8): | def minibatch( | ||||||
|  |     items: Iterable[Any], size: Union[Iterator[int], int] = 8 | ||||||
|  | ) -> Iterator[Any]: | ||||||
|     """Iterate over batches of items. `size` may be an iterator, |     """Iterate over batches of items. `size` may be an iterator, | ||||||
|     so that batch-size can vary on each step. |     so that batch-size can vary on each step. | ||||||
|     """ |     """ | ||||||
|  | @ -725,7 +752,12 @@ def minibatch(items, size=8): | ||||||
|         yield list(batch) |         yield list(batch) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def minibatch_by_padded_size(docs, size, buffer=256, discard_oversize=False): | def minibatch_by_padded_size( | ||||||
|  |     docs: Iterator["Doc"], | ||||||
|  |     size: Union[Iterator[int], int], | ||||||
|  |     buffer: int = 256, | ||||||
|  |     discard_oversize: bool = False, | ||||||
|  | ) -> Iterator[Iterator["Doc"]]: | ||||||
|     if isinstance(size, int): |     if isinstance(size, int): | ||||||
|         size_ = itertools.repeat(size) |         size_ = itertools.repeat(size) | ||||||
|     else: |     else: | ||||||
|  | @ -742,7 +774,7 @@ def minibatch_by_padded_size(docs, size, buffer=256, discard_oversize=False): | ||||||
|                 yield subbatch |                 yield subbatch | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def _batch_by_length(seqs, max_words): | def _batch_by_length(seqs: Sequence[Any], max_words: int) -> List[List[Any]]: | ||||||
|     """Given a list of sequences, return a batched list of indices into the |     """Given a list of sequences, return a batched list of indices into the | ||||||
|     list, where the batches are grouped by length, in descending order. |     list, where the batches are grouped by length, in descending order. | ||||||
| 
 | 
 | ||||||
|  | @ -783,14 +815,12 @@ def minibatch_by_words(docs, size, tolerance=0.2, discard_oversize=False): | ||||||
|         size_ = iter(size) |         size_ = iter(size) | ||||||
|     else: |     else: | ||||||
|         size_ = size |         size_ = size | ||||||
| 
 |  | ||||||
|     target_size = next(size_) |     target_size = next(size_) | ||||||
|     tol_size = target_size * tolerance |     tol_size = target_size * tolerance | ||||||
|     batch = [] |     batch = [] | ||||||
|     overflow = [] |     overflow = [] | ||||||
|     batch_size = 0 |     batch_size = 0 | ||||||
|     overflow_size = 0 |     overflow_size = 0 | ||||||
| 
 |  | ||||||
|     for doc in docs: |     for doc in docs: | ||||||
|         if isinstance(doc, Example): |         if isinstance(doc, Example): | ||||||
|             n_words = len(doc.reference) |             n_words = len(doc.reference) | ||||||
|  | @ -803,17 +833,14 @@ def minibatch_by_words(docs, size, tolerance=0.2, discard_oversize=False): | ||||||
|         if n_words > target_size + tol_size: |         if n_words > target_size + tol_size: | ||||||
|             if not discard_oversize: |             if not discard_oversize: | ||||||
|                 yield [doc] |                 yield [doc] | ||||||
| 
 |  | ||||||
|         # add the example to the current batch if there's no overflow yet and it still fits |         # add the example to the current batch if there's no overflow yet and it still fits | ||||||
|         elif overflow_size == 0 and (batch_size + n_words) <= target_size: |         elif overflow_size == 0 and (batch_size + n_words) <= target_size: | ||||||
|             batch.append(doc) |             batch.append(doc) | ||||||
|             batch_size += n_words |             batch_size += n_words | ||||||
| 
 |  | ||||||
|         # add the example to the overflow buffer if it fits in the tolerance margin |         # add the example to the overflow buffer if it fits in the tolerance margin | ||||||
|         elif (batch_size + overflow_size + n_words) <= (target_size + tol_size): |         elif (batch_size + overflow_size + n_words) <= (target_size + tol_size): | ||||||
|             overflow.append(doc) |             overflow.append(doc) | ||||||
|             overflow_size += n_words |             overflow_size += n_words | ||||||
| 
 |  | ||||||
|         # yield the previous batch and start a new one. The new one gets the overflow examples. |         # yield the previous batch and start a new one. The new one gets the overflow examples. | ||||||
|         else: |         else: | ||||||
|             if batch: |             if batch: | ||||||
|  | @ -824,17 +851,14 @@ def minibatch_by_words(docs, size, tolerance=0.2, discard_oversize=False): | ||||||
|             batch_size = overflow_size |             batch_size = overflow_size | ||||||
|             overflow = [] |             overflow = [] | ||||||
|             overflow_size = 0 |             overflow_size = 0 | ||||||
| 
 |  | ||||||
|             # this example still fits |             # this example still fits | ||||||
|             if (batch_size + n_words) <= target_size: |             if (batch_size + n_words) <= target_size: | ||||||
|                 batch.append(doc) |                 batch.append(doc) | ||||||
|                 batch_size += n_words |                 batch_size += n_words | ||||||
| 
 |  | ||||||
|             # this example fits in overflow |             # this example fits in overflow | ||||||
|             elif (batch_size + n_words) <= (target_size + tol_size): |             elif (batch_size + n_words) <= (target_size + tol_size): | ||||||
|                 overflow.append(doc) |                 overflow.append(doc) | ||||||
|                 overflow_size += n_words |                 overflow_size += n_words | ||||||
| 
 |  | ||||||
|             # this example does not fit with the previous overflow: start another new batch |             # this example does not fit with the previous overflow: start another new batch | ||||||
|             else: |             else: | ||||||
|                 if batch: |                 if batch: | ||||||
|  | @ -843,20 +867,19 @@ def minibatch_by_words(docs, size, tolerance=0.2, discard_oversize=False): | ||||||
|                 tol_size = target_size * tolerance |                 tol_size = target_size * tolerance | ||||||
|                 batch = [doc] |                 batch = [doc] | ||||||
|                 batch_size = n_words |                 batch_size = n_words | ||||||
| 
 |  | ||||||
|     batch.extend(overflow) |     batch.extend(overflow) | ||||||
|     if batch: |     if batch: | ||||||
|         yield batch |         yield batch | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def filter_spans(spans): | def filter_spans(spans: Iterable["Span"]) -> List["Span"]: | ||||||
|     """Filter a sequence of spans and remove duplicates or overlaps. Useful for |     """Filter a sequence of spans and remove duplicates or overlaps. Useful for | ||||||
|     creating named entities (where one token can only be part of one entity) or |     creating named entities (where one token can only be part of one entity) or | ||||||
|     when merging spans with `Retokenizer.merge`. When spans overlap, the (first) |     when merging spans with `Retokenizer.merge`. When spans overlap, the (first) | ||||||
|     longest span is preferred over shorter spans. |     longest span is preferred over shorter spans. | ||||||
| 
 | 
 | ||||||
|     spans (iterable): The spans to filter. |     spans (Iterable[Span]): The spans to filter. | ||||||
|     RETURNS (list): The filtered spans. |     RETURNS (List[Span]): The filtered spans. | ||||||
|     """ |     """ | ||||||
|     get_sort_key = lambda span: (span.end - span.start, -span.start) |     get_sort_key = lambda span: (span.end - span.start, -span.start) | ||||||
|     sorted_spans = sorted(spans, key=get_sort_key, reverse=True) |     sorted_spans = sorted(spans, key=get_sort_key, reverse=True) | ||||||
|  | @ -871,15 +894,21 @@ def filter_spans(spans): | ||||||
|     return result |     return result | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def to_bytes(getters, exclude): | def to_bytes(getters: Dict[str, Callable[[], bytes]], exclude: Iterable[str]) -> bytes: | ||||||
|     return srsly.msgpack_dumps(to_dict(getters, exclude)) |     return srsly.msgpack_dumps(to_dict(getters, exclude)) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def from_bytes(bytes_data, setters, exclude): | def from_bytes( | ||||||
|  |     bytes_data: bytes, | ||||||
|  |     setters: Dict[str, Callable[[bytes], Any]], | ||||||
|  |     exclude: Iterable[str], | ||||||
|  | ) -> None: | ||||||
|     return from_dict(srsly.msgpack_loads(bytes_data), setters, exclude) |     return from_dict(srsly.msgpack_loads(bytes_data), setters, exclude) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def to_dict(getters, exclude): | def to_dict( | ||||||
|  |     getters: Dict[str, Callable[[], Any]], exclude: Iterable[str] | ||||||
|  | ) -> Dict[str, Any]: | ||||||
|     serialized = {} |     serialized = {} | ||||||
|     for key, getter in getters.items(): |     for key, getter in getters.items(): | ||||||
|         # Split to support file names like meta.json |         # Split to support file names like meta.json | ||||||
|  | @ -888,7 +917,11 @@ def to_dict(getters, exclude): | ||||||
|     return serialized |     return serialized | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def from_dict(msg, setters, exclude): | def from_dict( | ||||||
|  |     msg: Dict[str, Any], | ||||||
|  |     setters: Dict[str, Callable[[Any], Any]], | ||||||
|  |     exclude: Iterable[str], | ||||||
|  | ) -> Dict[str, Any]: | ||||||
|     for key, setter in setters.items(): |     for key, setter in setters.items(): | ||||||
|         # Split to support file names like meta.json |         # Split to support file names like meta.json | ||||||
|         if key.split(".")[0] not in exclude and key in msg: |         if key.split(".")[0] not in exclude and key in msg: | ||||||
|  | @ -896,7 +929,11 @@ def from_dict(msg, setters, exclude): | ||||||
|     return msg |     return msg | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def to_disk(path, writers, exclude): | def to_disk( | ||||||
|  |     path: Union[str, Path], | ||||||
|  |     writers: Dict[str, Callable[[Path], None]], | ||||||
|  |     exclude: Iterable[str], | ||||||
|  | ) -> Path: | ||||||
|     path = ensure_path(path) |     path = ensure_path(path) | ||||||
|     if not path.exists(): |     if not path.exists(): | ||||||
|         path.mkdir() |         path.mkdir() | ||||||
|  | @ -907,7 +944,11 @@ def to_disk(path, writers, exclude): | ||||||
|     return path |     return path | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def from_disk(path, readers, exclude): | def from_disk( | ||||||
|  |     path: Union[str, Path], | ||||||
|  |     readers: Dict[str, Callable[[Path], None]], | ||||||
|  |     exclude: Iterable[str], | ||||||
|  | ) -> Path: | ||||||
|     path = ensure_path(path) |     path = ensure_path(path) | ||||||
|     for key, reader in readers.items(): |     for key, reader in readers.items(): | ||||||
|         # Split to support file names like meta.json |         # Split to support file names like meta.json | ||||||
|  | @ -916,7 +957,7 @@ def from_disk(path, readers, exclude): | ||||||
|     return path |     return path | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def import_file(name, loc): | def import_file(name: str, loc: Union[str, Path]) -> ModuleType: | ||||||
|     """Import module from a file. Used to load models from a directory. |     """Import module from a file. Used to load models from a directory. | ||||||
| 
 | 
 | ||||||
|     name (str): Name of module to load. |     name (str): Name of module to load. | ||||||
|  | @ -930,7 +971,7 @@ def import_file(name, loc): | ||||||
|     return module |     return module | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def minify_html(html): | def minify_html(html: str) -> str: | ||||||
|     """Perform a template-specific, rudimentary HTML minification for displaCy. |     """Perform a template-specific, rudimentary HTML minification for displaCy. | ||||||
|     Disclaimer: NOT a general-purpose solution, only removes indentation and |     Disclaimer: NOT a general-purpose solution, only removes indentation and | ||||||
|     newlines. |     newlines. | ||||||
|  | @ -941,7 +982,7 @@ def minify_html(html): | ||||||
|     return html.strip().replace("    ", "").replace("\n", "") |     return html.strip().replace("    ", "").replace("\n", "") | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def escape_html(text): | def escape_html(text: str) -> str: | ||||||
|     """Replace <, >, &, " with their HTML encoded representation. Intended to |     """Replace <, >, &, " with their HTML encoded representation. Intended to | ||||||
|     prevent HTML errors in rendered displaCy markup. |     prevent HTML errors in rendered displaCy markup. | ||||||
| 
 | 
 | ||||||
|  | @ -955,7 +996,9 @@ def escape_html(text): | ||||||
|     return text |     return text | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def get_words_and_spaces(words, text): | def get_words_and_spaces( | ||||||
|  |     words: Iterable[str], text: str | ||||||
|  | ) -> Tuple[List[str], List[bool]]: | ||||||
|     if "".join("".join(words).split()) != "".join(text.split()): |     if "".join("".join(words).split()) != "".join(text.split()): | ||||||
|         raise ValueError(Errors.E194.format(text=text, words=words)) |         raise ValueError(Errors.E194.format(text=text, words=words)) | ||||||
|     text_words = [] |     text_words = [] | ||||||
|  | @ -1103,7 +1146,7 @@ class DummyTokenizer: | ||||||
|         return self |         return self | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def link_vectors_to_models(vocab): | def link_vectors_to_models(vocab: "Vocab") -> None: | ||||||
|     vectors = vocab.vectors |     vectors = vocab.vectors | ||||||
|     if vectors.name is None: |     if vectors.name is None: | ||||||
|         vectors.name = VECTORS_KEY |         vectors.name = VECTORS_KEY | ||||||
|  | @ -1119,7 +1162,7 @@ def link_vectors_to_models(vocab): | ||||||
| VECTORS_KEY = "spacy_pretrained_vectors" | VECTORS_KEY = "spacy_pretrained_vectors" | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def create_default_optimizer(): | def create_default_optimizer() -> Optimizer: | ||||||
|     learn_rate = env_opt("learn_rate", 0.001) |     learn_rate = env_opt("learn_rate", 0.001) | ||||||
|     beta1 = env_opt("optimizer_B1", 0.9) |     beta1 = env_opt("optimizer_B1", 0.9) | ||||||
|     beta2 = env_opt("optimizer_B2", 0.999) |     beta2 = env_opt("optimizer_B2", 0.999) | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user