Mirror of https://github.com/explosion/spaCy.git (synced 2025-01-26 01:04:34 +03:00)

Merge branch 'develop' of https://github.com/explosion/spaCy into develop
Commit 0994dc50d8
Makefile (4 changed lines)

@@ -5,11 +5,11 @@ dist/spacy.pex : spacy/*.py* spacy/*/*.py*
 	python3.6 -m venv env3.6
 	source env3.6/bin/activate
 	env3.6/bin/pip install wheel
-	env3.6/bin/pip install -r requirements.txt --no-cache-dir --no-binary :all:
+	env3.6/bin/pip install -r requirements.txt --no-cache-dir
 	env3.6/bin/python setup.py build_ext --inplace
 	env3.6/bin/python setup.py sdist
 	env3.6/bin/python setup.py bdist_wheel
-	env3.6/bin/python -m pip install pex
+	env3.6/bin/python -m pip install pex==1.5.3
 	env3.6/bin/pex pytest dist/*.whl -e spacy -o dist/spacy-$(sha).pex
 	cp dist/spacy-$(sha).pex dist/spacy.pex
 	chmod a+rx dist/spacy.pex

@@ -4,7 +4,7 @@ preshed>=2.0.1,<2.1.0
 thinc==7.0.0.dev6
 blis>=0.2.2,<0.3.0
 murmurhash>=0.28.0,<1.1.0
-wasabi>=0.0.8,<1.1.0
+wasabi>=0.0.12,<1.1.0
 srsly>=0.0.5,<1.1.0
 # Third party dependencies
 numpy>=1.15.0

setup.py (12 changed lines)

@@ -13,12 +13,12 @@ from setuptools import Extension, setup, find_packages
 
 
 def is_new_osx():
-    '''Check whether we're on OSX >= 10.10'''
+    """Check whether we're on OSX >= 10.10"""
     name = distutils.util.get_platform()
-    if sys.platform != 'darwin':
+    if sys.platform != "darwin":
         return False
-    elif name.startswith('macosx-10'):
-        minor_version = int(name.split('-')[1].split('.')[1])
+    elif name.startswith("macosx-10"):
+        minor_version = int(name.split("-")[1].split(".")[1])
         if minor_version >= 7:
             return True
         else:

@@ -27,7 +27,6 @@ def is_new_osx():
         return False
 
 
-
 PACKAGE_DATA = {"": ["*.pyx", "*.pxd", "*.txt", "*.tokens"]}
 
 

@@ -84,7 +83,6 @@ if is_new_osx():
     LINK_OPTIONS["other"].append("-nodefaultlibs")
 
 
-
 USE_OPENMP_DEFAULT = "0" if sys.platform != "darwin" else None
 if os.environ.get("USE_OPENMP", USE_OPENMP_DEFAULT) == "1":
     if sys.platform == "darwin":

@@ -232,7 +230,7 @@ def setup_package():
             "regex==2018.01.10",
             "requests>=2.13.0,<3.0.0",
             "jsonschema>=2.6.0,<3.0.0",
-            "wasabi>=0.0.8,<1.1.0",
+            "wasabi>=0.0.12,<1.1.0",
             "srsly>=0.0.5,<1.1.0",
             'pathlib==1.0.1; python_version < "3.4"',
         ],

@@ -271,7 +271,7 @@ def PyTorchBiLSTM(nO, nI, depth, dropout=0.2):
 
 def Tok2Vec(width, embed_size, **kwargs):
     pretrained_vectors = kwargs.get("pretrained_vectors", None)
-    cnn_maxout_pieces = kwargs.get("cnn_maxout_pieces", 2)
+    cnn_maxout_pieces = kwargs.get("cnn_maxout_pieces", 3)
     subword_features = kwargs.get("subword_features", True)
     conv_depth = kwargs.get("conv_depth", 4)
     bilstm_depth = kwargs.get("bilstm_depth", 0)

@@ -1,105 +0,0 @@
-# coding: utf8
-from __future__ import unicode_literals
-
-
-# fmt: off
-
-class Messages(object):
-    M001 = ("Download successful but linking failed")
-    M002 = ("Creating a shortcut link for 'en' didn't work (maybe you "
-            "don't have admin permissions?), but you can still load the "
-            "model via its full package name: nlp = spacy.load('{name}')")
-    M003 = ("Server error ({code})")
-    M004 = ("Couldn't fetch {desc}. Please find a model for your spaCy "
-            "installation (v{version}), and download it manually. For more "
-            "details, see the documentation: https://spacy.io/usage/models")
-    M005 = ("Compatibility error")
-    M006 = ("No compatible models found for v{version} of spaCy.")
-    M007 = ("No compatible model found for '{name}' (spaCy v{version}).")
-    M008 = ("Can't locate model data")
-    M009 = ("The data should be located in {path}")
-    M010 = ("Can't find the spaCy data path to create model symlink")
-    M011 = ("Make sure a directory `/data` exists within your spaCy "
-            "installation and try again. The data directory should be "
-            "located here:")
-    M012 = ("Link '{name}' already exists")
-    M013 = ("To overwrite an existing link, use the --force flag.")
-    M014 = ("Can't overwrite symlink '{name}'")
-    M015 = ("This can happen if your data directory contains a directory or "
-            "file of the same name.")
-    M016 = ("Error: Couldn't link model to '{name}'")
-    M017 = ("Creating a symlink in spacy/data failed. Make sure you have the "
-            "required permissions and try re-running the command as admin, or "
-            "use a virtualenv. You can still import the model as a module and "
-            "call its load() method, or create the symlink manually.")
-    M018 = ("Linking successful")
-    M019 = ("You can now load the model via spacy.load('{name}')")
-    M020 = ("Can't find model meta.json")
-    M021 = ("Couldn't fetch compatibility table.")
-    M022 = ("Can't find spaCy v{version} in compatibility table")
-    M023 = ("Installed models (spaCy v{version})")
-    M024 = ("No models found in your current environment.")
-    M025 = ("Use the following commands to update the model packages:")
-    M026 = ("The following models are not available for spaCy "
-            "v{version}: {models}")
-    M027 = ("You may also want to overwrite the incompatible links using the "
-            "`python -m spacy link` command with `--force`, or remove them "
-            "from the data directory. Data path: {path}")
-    M028 = ("Input file not found")
-    M029 = ("Output directory not found")
-    M030 = ("Unknown format")
-    M031 = ("Can't find converter for {converter}")
-    M032 = ("Generated output file {name}")
-    M033 = ("Created {n_docs} documents")
-    M034 = ("Evaluation data not found")
-    M035 = ("Visualization output directory not found")
-    M036 = ("Generated {n} parses as HTML")
-    M037 = ("Can't find words frequencies file")
-    M038 = ("Sucessfully compiled vocab")
-    M039 = ("{entries} entries, {vectors} vectors")
-    M040 = ("Output directory not found")
-    M041 = ("Loaded meta.json from file")
-    M042 = ("Successfully created package '{name}'")
-    M043 = ("To build the package, run `python setup.py sdist` in this "
-            "directory.")
-    M044 = ("Package directory already exists")
-    M045 = ("Please delete the directory and try again, or use the `--force` "
-            "flag to overwrite existing directories.")
-    M046 = ("Generating meta.json")
-    M047 = ("Enter the package settings for your model. The following "
-            "information will be read from your model data: pipeline, vectors.")
-    M048 = ("No '{key}' setting found in meta.json")
-    M049 = ("This setting is required to build your package.")
-    M050 = ("Training data not found")
-    M051 = ("Development data not found")
-    M052 = ("Not a valid meta.json format")
-    M053 = ("Expected dict but got: {meta_type}")
-    M054 = ("No --lang specified, but tokenization required.")
-    M055 = ("Training pipeline: {pipeline}")
-    M056 = ("Starting with base model '{model}'")
-    M057 = ("Starting with blank model '{model}'")
-    M058 = ("Loading vector from model '{model}'")
-    M059 = ("Can't use multitask objective without '{pipe}' in the pipeline")
-    M060 = ("Counting training words (limit={limit})")
-    M061 = ("\nSaving model...")
-    M062 = ("Output directory is not empty.")
-    M063 = ("Incompatible arguments")
-    M064 = ("The -f and -c arguments are deprecated, and not compatible with "
-            "the -j argument, which should specify the same information. "
-            "Either merge the frequencies and clusters data into the "
-            "JSONL-formatted file (recommended), or use only the -f and -c "
-            "files, without the other lexical attributes.")
-    M065 = ("This can lead to unintended side effects when saving the model. "
-            "Please use an empty directory or a different path instead. If "
-            "the specified output path doesn't exist, the directory will be "
-            "created for you.")
-    M066 = ("Saved model to output directory")
-    M067 = ("Can't find lexical data")
-    M068 = ("Sucessfully compiled vocab and vectors, and saved model")
-    M069 = ("Unknown file type: '{name}'")
-    M070 = ("Supported file types: '{options}'")
-    M071 = ("Loaded pretrained tok2vec for: {components}")
-    M072 = ("Model language ('{model_lang}') doesn't match language specified "
-            "as `lang` argument ('{lang}') ")
-
-# fmt: on

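Note: the deleted class above held the CLI's user-facing strings as shared constants. The rest of this commit inlines each string at its call site, so a lookup such as Messages.M028 becomes a literal passed straight to the printer, for example (taken from the convert() hunks below):

    # before: message text lives in a shared constants class
    msg.fail(Messages.M028, input_path, exits=1)
    # after: the literal string is passed to the printer directly
    msg.fail("Input file not found", input_path, exits=1)
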
@@ -6,10 +6,8 @@ from pathlib import Path
 from wasabi import Printer
 import srsly
 
-from ..compat import path2str
 from .converters import conllu2json, conllubio2json, iob2json, conll_ner2json
 from .converters import ner_jsonl2json
-from ._messages import Messages
 
 
 # Converters are matched by file extension. To add a converter, add a new

@@ -56,18 +54,18 @@ def convert(
     input_path = Path(input_file)
     if file_type not in FILE_TYPES:
         msg.fail(
-            Messages.M069.format(name=file_type),
-            Messages.M070.format(options=", ".join(FILE_TYPES)),
+            "Unknown file type: '{}'".format(file_type),
+            "Supported file types: '{}'".format(", ".join(FILE_TYPES)),
             exits=1,
         )
     if not input_path.exists():
-        msg.fail(Messages.M028, input_path, exits=1)
+        msg.fail("Input file not found", input_path, exits=1)
     if output_dir != "-" and not Path(output_dir).exists():
-        msg.fail(Messages.M029, output_dir, exits=1)
+        msg.fail("Output directory not found", output_dir, exits=1)
     if converter == "auto":
         converter = input_path.suffix[1:]
     if converter not in CONVERTERS:
-        msg.fail(Messages.M030, Messages.M031.format(converter=converter), exits=1)
+        msg.fail("Can't find converter for {}".format(converter), exits=1)
     # Use converter function to convert data
     func = CONVERTERS[converter]
     input_data = input_path.open("r", encoding="utf-8").read()

@@ -80,10 +78,7 @@ def convert(
             srsly.write_json(output_file, data)
         elif file_type == "jsonl":
             srsly.write_jsonl(output_file, data)
-        msg.good(
-            Messages.M032.format(name=path2str(output_file)),
-            Messages.M033.format(n_docs=len(data)),
-        )
+        msg.good("Generated output file ({} documents)".format(len(data)), output_file)
     else:
         # Print to stdout
         if file_type == "json":

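Note: the messages are now passed directly to wasabi's Printer, whose methods take a short title plus an optional longer text, and exits=1 terminates after printing. A minimal sketch, not part of the commit, of how the calls above behave (paths and counts are made up):

    from wasabi import Printer

    msg = Printer()
    msg.good("Generated output file (2 documents)", "/tmp/output.json")
    msg.warn("Output directory is not empty",
             "This can lead to unintended side effects when saving the model.")
    # exits=1 prints the message and then terminates the process
    msg.fail("Input file not found", "/tmp/missing.json", exits=1)
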
@@ -4,12 +4,11 @@ from __future__ import unicode_literals
 import srsly
 
 from ...util import get_lang_class
-from .._messages import Messages
 
 
 def ner_jsonl2json(input_data, lang=None, n_sents=10, use_morphology=False):
     if lang is None:
-        raise ValueError(Messages.M054)
+        raise ValueError("No --lang specified, but tokenization required")
     json_docs = []
     input_tuples = [srsly.json_loads(line) for line in input_data]
     nlp = get_lang_class(lang)()

@@ -12,7 +12,6 @@ from ..gold import GoldCorpus, read_json_object
 from ..util import load_model, get_lang_class
 
 # from .schemas import get_schema, validate_json
-from ._messages import Messages
 
 
 # Minimum number of expected occurences of label in data to train new label

@@ -58,9 +57,9 @@ def debug_data(
 
     # Make sure all files and paths exists if they are needed
     if not train_path.exists():
-        msg.fail(Messages.M050, train_path, exits=1)
+        msg.fail("Training data not found", train_path, exits=1)
     if not dev_path.exists():
-        msg.fail(Messages.M051, dev_path, exits=1)
+        msg.fail("Development data not found", dev_path, exits=1)
 
     # Initialize the model and pipeline
     pipeline = [p.strip() for p in pipeline.split(",")]

@@ -72,10 +71,8 @@ def debug_data(
 
     msg.divider("Data format validation")
     # Load the data in one – might take a while but okay in this case
-    with msg.loading("Loading {}...".format(train_path.parts[-1])):
-        train_data = _load_file(train_path, msg)
-    with msg.loading("Loading {}...".format(dev_path.parts[-1])):
-        dev_data = _load_file(dev_path, msg)
+    train_data = _load_file(train_path, msg)
+    dev_data = _load_file(dev_path, msg)
 
     # Validate data format using the JSON schema
     # TODO: update once the new format is ready

@@ -172,6 +169,7 @@ def debug_data(
         existing_labels = [l for l in labels if l in model_labels]
         has_low_data_warning = False
         has_no_neg_warning = False
+        has_ws_ents_error = False
 
         msg.divider("Named Entity Recognition")
         msg.info(

@@ -201,6 +199,10 @@ def debug_data(
                 "Existing: {}".format(_format_labels(existing_labels)), show=verbose
             )
 
+        if gold_data["ws_ents"]:
+            msg.fail("{} invalid whitespace entity spans".format(gold_data["ws_ents"]))
+            has_ws_ents_error = True
+
         for label in new_labels:
             if label_counts[label] <= NEW_LABEL_THRESHOLD:
                 msg.warn(

@@ -222,6 +224,8 @@ def debug_data(
             msg.good("Good amount of examples for all labels")
         if not has_no_neg_warning:
             msg.good("Examples without occurences available for all labels")
+        if not has_ws_ents_error:
+            msg.good("No entities consisting of or starting/ending with whitespace")
 
         if has_low_data_warning:
             msg.text(

@@ -236,6 +240,11 @@ def debug_data(
                 "type.",
                 show=verbose,
             )
+        if has_ws_ents_error:
+            msg.text(
+                "As of spaCy v2.1.0, entity spans consisting of or starting/ending "
+                "with whitespace characters are considered invalid."
+            )
 
     if "textcat" in pipeline:
         msg.divider("Text Classification")

@@ -321,11 +330,13 @@ def debug_data(
 def _load_file(file_path, msg):
     file_name = file_path.parts[-1]
     if file_path.suffix == ".json":
-        data = srsly.read_json(file_path)
+        with msg.loading("Loading {}...".format(file_name)):
+            data = srsly.read_json(file_path)
         msg.good("Loaded {}".format(file_name))
         return data
     elif file_path.suffix == ".jsonl":
-        data = srsly.read_jsonl(file_path)
+        with msg.loading("Loading {}...".format(file_name)):
+            data = srsly.read_jsonl(file_path)
         msg.good("Loaded {}".format(file_name))
         return data
     msg.fail(

@@ -342,6 +353,7 @@ def _compile_gold(train_docs, pipeline):
         "tags": Counter(),
         "deps": Counter(),
         "words": Counter(),
+        "ws_ents": 0,
         "n_words": 0,
         "texts": set(),
     }

@@ -350,7 +362,10 @@ def _compile_gold(train_docs, pipeline):
         data["n_words"] += len(gold.words)
         data["texts"].add(doc.text)
         if "ner" in pipeline:
-            for label in gold.ner:
+            for i, label in enumerate(gold.ner):
+                if label.startswith(("B-", "U-", "L-")) and doc[i].is_space:
+                    # "Illegal" whitespace entity
+                    data["ws_ents"] += 1
                 if label.startswith(("B-", "U-")):
                     combined_label = label.split("-")[1]
                     data["ner"][combined_label] += 1

@@ -371,18 +386,6 @@ def _format_labels(labels, counts=False):
     return ", ".join(["'{}'".format(l) for l in labels])
 
 
-def _get_ner_counts(data):
-    counter = Counter()
-    for doc, gold in data:
-        for label in gold.ner:
-            if label.startswith(("B-", "U-")):
-                combined_label = label.split("-")[1]
-                counter[combined_label] += 1
-            elif label == "-":
-                counter["-"] += 1
-    return counter
-
-
 def _get_examples_without_label(data, label):
     count = 0
     for doc, gold in data:

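Note: the new ws_ents counter flags BILUO entity tags that fall on whitespace tokens, which this version treats as invalid spans. A standalone sketch of the same check, assuming docs is a hypothetical list of (Doc, GoldParse) pairs with BILUO tags in gold.ner:

    def count_ws_ents(docs):
        ws_ents = 0
        for doc, gold in docs:
            for i, label in enumerate(gold.ner):
                # B-, U- and L- mark the start, a single-token entity or the end
                # of an entity; such a tag on a whitespace token is invalid
                if label.startswith(("B-", "U-", "L-")) and doc[i].is_space:
                    ws_ents += 1
        return ws_ents
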
@@ -8,7 +8,6 @@ import subprocess
 import sys
 from wasabi import Printer
 
-from ._messages import Messages
 from .link import link
 from ..util import get_package_path
 from .. import about

@@ -50,15 +49,24 @@ def download(model, direct=False, *pip_args):
             # Dirty, but since spacy.download and the auto-linking is
             # mostly a convenience wrapper, it's best to show a success
             # message and loading instructions, even if linking fails.
-            msg.warn(Messages.M002.format(name=model_name), Messages.M001)
+            msg.warn(
+                "Download successful but linking failed",
+                "Creating a shortcut link for 'en' didn't work (maybe you "
+                "don't have admin permissions?), but you can still load the "
+                "model via its full package name: "
+                "nlp = spacy.load('{}')".format(model_name),
+            )
 
 
 def get_json(url, desc):
     r = requests.get(url)
     if r.status_code != 200:
         msg.fail(
-            Messages.M003.format(code=r.status_code),
-            Messages.M004.format(desc=desc, version=about.__version__),
+            "Server error ({})".format(r.status_code),
+            "Couldn't fetch {}. Please find a model for your spaCy "
+            "installation (v{}), and download it manually. For more "
+            "details, see the documentation: "
+            "https://spacy.io/usage/models".format(desc, about.__version__),
             exits=1,
         )
     return r.json()

@@ -70,7 +78,7 @@ def get_compatibility():
     comp_table = get_json(about.__compatibility__, "compatibility table")
     comp = comp_table["spacy"]
     if version not in comp:
-        msg.fail(Messages.M005, Messages.M006.format(version=version), exits=1)
+        msg.fail("No compatible models found for v{} of spaCy".format(version), exits=1)
     return comp[version]
 
 

@@ -78,8 +86,8 @@ def get_version(model, comp):
     model = model.rsplit(".dev", 1)[0]
     if model not in comp:
         msg.fail(
-            Messages.M005,
-            Messages.M007.format(name=model, version=about.__version__),
+            "No compatible model found for '{}' "
+            "(spaCy v{}).".format(model, about.__version__),
             exits=1,
         )
     return comp[model][0]

@@ -5,7 +5,6 @@ import plac
 from timeit import default_timer as timer
 from wasabi import Printer
 
-from ._messages import Messages
 from ..gold import GoldCorpus
 from .. import util
 from .. import displacy

@@ -39,9 +38,9 @@ def evaluate(
     data_path = util.ensure_path(data_path)
     displacy_path = util.ensure_path(displacy_path)
     if not data_path.exists():
-        msg.fail(Messages.M034, data_path, exits=1)
+        msg.fail("Evaluation data not found", data_path, exits=1)
     if displacy_path and not displacy_path.exists():
-        msg.fail(Messages.M035, displacy_path, exits=1)
+        msg.fail("Visualization output directory not found", displacy_path, exits=1)
     corpus = GoldCorpus(data_path, data_path)
     nlp = util.load_model(model)
     dev_docs = list(corpus.dev_docs(nlp, gold_preproc=gold_preproc))

@@ -75,7 +74,7 @@ def evaluate(
             deps=render_deps,
             ents=render_ents,
         )
-        msg.good(Messages.M036.format(n=displacy_limit), displacy_path)
+        msg.good("Generated {} parses as HTML".format(displacy_limit), displacy_path)
 
 
 def render_parses(docs, output_path, model_name="", limit=250, deps=True, ents=True):

@@ -90,39 +89,3 @@ def render_parses(docs, output_path, model_name="", limit=250, deps=True, ents=T
             docs[:limit], style="dep", page=True, options={"compact": True}
         )
         file_.write(html)
-
-
-def print_progress(itn, losses, dev_scores, wps=0.0):
-    scores = {}
-    for col in [
-        "dep_loss",
-        "tag_loss",
-        "uas",
-        "tags_acc",
-        "token_acc",
-        "ents_p",
-        "ents_r",
-        "ents_f",
-        "wps",
-    ]:
-        scores[col] = 0.0
-    scores["dep_loss"] = losses.get("parser", 0.0)
-    scores["ner_loss"] = losses.get("ner", 0.0)
-    scores["tag_loss"] = losses.get("tagger", 0.0)
-    scores.update(dev_scores)
-    scores["wps"] = wps
-    tpl = "\t".join(
-        (
-            "{:d}",
-            "{dep_loss:.3f}",
-            "{ner_loss:.3f}",
-            "{uas:.3f}",
-            "{ents_p:.3f}",
-            "{ents_r:.3f}",
-            "{ents_f:.3f}",
-            "{tags_acc:.3f}",
-            "{token_acc:.3f}",
-            "{wps:.1f}",
-        )
-    )
-    print(tpl.format(itn, **scores))

@@ -7,7 +7,6 @@ from pathlib import Path
 from wasabi import Printer
 import srsly
 
-from ._messages import Messages
 from ..compat import path2str, basestring_, unicode_
 from .. import util
 from .. import about

@@ -32,7 +31,7 @@ def info(model=None, markdown=False, silent=False):
         model_path = util.get_data_path() / model
         meta_path = model_path / "meta.json"
         if not meta_path.is_file():
-            msg.fail(Messages.M020, meta_path, exits=1)
+            msg.fail("Can't find model meta.json", meta_path, exits=1)
         meta = srsly.read_json(meta_path)
         if model_path.resolve() != model_path:
             meta["link"] = path2str(model_path)

@@ -14,7 +14,6 @@ import zipfile
 import srsly
 from wasabi import Printer
 
-from ._messages import Messages
 from ..vectors import Vectors
 from ..errors import Errors, Warnings, user_warning
 from ..util import ensure_path, get_lang_class

@@ -58,14 +57,21 @@ def init_model(
         settings.append("-f")
         if clusters_loc:
             settings.append("-c")
-        msg.warn(Messages.M063, Messages.M064)
+        msg.warn(
+            "Incompatible arguments",
+            "The -f and -c arguments are deprecated, and not compatible "
+            "with the -j argument, which should specify the same "
+            "information. Either merge the frequencies and clusters data "
+            "into the JSONL-formatted file (recommended), or use only the "
+            "-f and -c files, without the other lexical attributes.",
+        )
         jsonl_loc = ensure_path(jsonl_loc)
         lex_attrs = srsly.read_jsonl(jsonl_loc)
     else:
         clusters_loc = ensure_path(clusters_loc)
         freqs_loc = ensure_path(freqs_loc)
         if freqs_loc is not None and not freqs_loc.exists():
-            msg.fail(Messages.M037, freqs_loc, exits=1)
+            msg.fail("Can't find words frequencies file", freqs_loc, exits=1)
         lex_attrs = read_attrs_from_deprecated(freqs_loc, clusters_loc)
 
     with msg.loading("Creating model..."):

@@ -75,7 +81,10 @@ def init_model(
         add_vectors(nlp, vectors_loc, prune_vectors)
     vec_added = len(nlp.vocab.vectors)
     lex_added = len(nlp.vocab)
-    msg.good(Messages.M038, Messages.M039.format(entries=lex_added, vectors=vec_added))
+    msg.good(
+        "Sucessfully compiled vocab",
+        "{} entries, {} vectors".format(lex_added, vec_added),
+    )
     if not output_dir.exists():
         output_dir.mkdir()
     nlp.to_disk(output_dir)

@@ -5,7 +5,6 @@ import plac
 from pathlib import Path
 from wasabi import Printer
 
-from ._messages import Messages
 from ..compat import symlink_to, path2str
 from .. import util
 

@@ -28,29 +27,52 @@ def link(origin, link_name, force=False, model_path=None):
     model_path = Path(origin) if model_path is None else Path(model_path)
     if not model_path.exists():
         msg.fail(
-            Messages.M008, Messages.M009.format(path=path2str(model_path)), exits=1
+            "Can't locate model data",
+            "The data should be located in {}".format(path2str(model_path)),
+            exits=1,
         )
     data_path = util.get_data_path()
     if not data_path or not data_path.exists():
         spacy_loc = Path(__file__).parent.parent
-        msg.fail(Messages.M010, Messages.M011.format(path=spacy_loc), exits=1)
+        msg.fail(
+            "Can't find the spaCy data path to create model symlink",
+            "Make sure a directory `/data` exists within your spaCy "
+            "installation and try again. The data directory should be located "
+            "here:".format(path=spacy_loc),
+            exits=1,
+        )
     link_path = util.get_data_path() / link_name
     if link_path.is_symlink() and not force:
-        msg.fail(Messages.M012.format(name=link_name), Messages.M013, exits=1)
+        msg.fail(
+            "Link '{}' already exists".format(link_name),
+            "To overwrite an existing link, use the --force flag",
+            exits=1,
+        )
     elif link_path.is_symlink():  # does a symlink exist?
         # NB: It's important to check for is_symlink here and not for exists,
         # because invalid/outdated symlinks would return False otherwise.
         link_path.unlink()
     elif link_path.exists():  # does it exist otherwise?
         # NB: Check this last because valid symlinks also "exist".
-        msg.fail(Messages.M014.format(name=link_name), Messages.M015, exits=1)
+        msg.fail(
+            "Can't overwrite symlink '{}'".format(link_name),
+            "This can happen if your data directory contains a directory or "
+            "file of the same name.",
+            exits=1,
+        )
     details = "%s --> %s" % (path2str(model_path), path2str(link_path))
     try:
         symlink_to(link_path, model_path)
     except:  # noqa: E722
         # This is quite dirty, but just making sure other errors are caught.
-        msg.fail(Messages.M016.format(name=link_name), Messages.M017)
+        msg.fail(
+            "Couldn't link model to '{}'".format(link_name),
+            "Creating a symlink in spacy/data failed. Make sure you have the "
+            "required permissions and try re-running the command as admin, or "
+            "use a virtualenv. You can still import the model as a module and "
+            "call its load() method, or create the symlink manually.",
+        )
         msg.text(details)
         raise
-    msg.good(Messages.M018, details)
-    msg.text(Messages.M019.format(name=link_name))
+    msg.good("Linking successful", details)
+    msg.text("You can now load the model via spacy.load('{}')".format(link_name))

@@ -7,7 +7,6 @@ from pathlib import Path
 from wasabi import Printer, get_raw_input
 import srsly
 
-from ._messages import Messages
 from ..compat import path2str
 from .. import util
 from .. import about

@@ -33,22 +32,26 @@ def package(input_dir, output_dir, meta_path=None, create_meta=False, force=Fals
     output_path = util.ensure_path(output_dir)
     meta_path = util.ensure_path(meta_path)
     if not input_path or not input_path.exists():
-        msg.fail(Messages.M008, input_path, exits=1)
+        msg.fail("Can't locate model data", input_path, exits=1)
     if not output_path or not output_path.exists():
-        msg.fail(Messages.M040, output_path, exits=1)
+        msg.fail("Output directory not found", output_path, exits=1)
     if meta_path and not meta_path.exists():
-        msg.fail(Messages.M020, meta_path, exits=1)
+        msg.fail("Can't find model meta.json", meta_path, exits=1)
 
     meta_path = meta_path or input_path / "meta.json"
     if meta_path.is_file():
         meta = srsly.read_json(meta_path)
         if not create_meta:  # only print if user doesn't want to overwrite
-            msg.good(Messages.M041, meta_path)
+            msg.good("Loaded meta.json from file", meta_path)
         else:
            meta = generate_meta(input_dir, meta, msg)
     for key in ("lang", "name", "version"):
         if key not in meta or meta[key] == "":
-            msg.fail(Messages.M048.format(key=key), Messages.M049, exits=1)
+            msg.fail(
+                "No '{}' setting found in meta.json".format(key),
+                "This setting is required to build your package.",
+                exits=1,
+            )
     model_name = meta["lang"] + "_" + meta["name"]
     model_name_v = model_name + "-" + meta["version"]
     main_path = output_path / model_name_v

@@ -59,8 +62,10 @@ def package(input_dir, output_dir, meta_path=None, create_meta=False, force=Fals
             shutil.rmtree(path2str(package_path))
         else:
             msg.fail(
-                Messages.M044,
-                Messages.M045.format(path=path2str(package_path)),
+                "Package directory already exists",
+                "Please delete the directory and try again, or use the "
+                "`--force` flag to overwrite existing "
+                "directories.".format(path=path2str(package_path)),
                 exits=1,
             )
     Path.mkdir(package_path, parents=True)

@@ -69,8 +74,8 @@ def package(input_dir, output_dir, meta_path=None, create_meta=False, force=Fals
     create_file(main_path / "setup.py", TEMPLATE_SETUP)
     create_file(main_path / "MANIFEST.in", TEMPLATE_MANIFEST)
     create_file(package_path / "__init__.py", TEMPLATE_INIT)
-    msg.good(Messages.M042.format(name=model_name_v), main_path)
-    msg.text(Messages.M043)
+    msg.good("Successfully created package '{}'".format(model_name_v), main_path)
+    msg.text("To build the package, run `python setup.py sdist` in this directory.")
 
 
 def create_file(file_path, contents):

@@ -98,8 +103,11 @@ def generate_meta(model_path, existing_meta, msg):
         "vectors": len(nlp.vocab.vectors),
         "keys": nlp.vocab.vectors.n_keys,
     }
-    msg.divider(Messages.M046)
-    msg.text(Messages.M047)
+    msg.divider("Generating meta.json")
+    msg.text(
+        "Enter the package settings for your model. The following information "
+        "will be read from your model data: pipeline, vectors."
+    )
     for setting, desc, default in settings:
         response = get_raw_input(desc, default)
         meta[setting] = default if response == "" and default else response

@@ -11,7 +11,6 @@ import srsly
 from wasabi import Printer
 from thinc.rates import slanted_triangular
 
-from ._messages import Messages
 from .._ml import create_default_optimizer
 from ..attrs import PROB, IS_OOV, CLUSTER, LANG
 from ..gold import GoldCorpus

@@ -19,22 +18,6 @@ from .. import util
 from .. import about
 
 
-# Take dropout and batch size as generators of values -- dropout
-# starts high and decays sharply, to force the optimizer to explore.
-# Batch size starts at 1 and grows, so that we make updates quickly
-# at the beginning of training.
-dropout_rates = util.decaying(
-    util.env_opt("dropout_from", 0.2),
-    util.env_opt("dropout_to", 0.2),
-    util.env_opt("dropout_decay", 0.0),
-)
-batch_sizes = util.compounding(
-    util.env_opt("batch_from", 100),
-    util.env_opt("batch_to", 1000),
-    util.env_opt("batch_compound", 1.001),
-)
-
-
 @plac.annotations(
     lang=("Model language", "positional", None, str),
     output_path=("Output directory to store model in", "positional", None, Path),

@@ -108,36 +91,59 @@ def train(
     dev_path = util.ensure_path(dev_path)
     meta_path = util.ensure_path(meta_path)
     if not train_path or not train_path.exists():
-        msg.fail(Messages.M050, train_path, exits=1)
+        msg.fail("Training data not found", train_path, exits=1)
     if not dev_path or not dev_path.exists():
-        msg.fail(Messages.M051, dev_path, exits=1)
+        msg.fail("Development data not found", dev_path, exits=1)
     if meta_path is not None and not meta_path.exists():
-        msg.fail(Messages.M020, meta_path, exits=1)
+        msg.fail("Can't find model meta.json", meta_path, exits=1)
     meta = srsly.read_json(meta_path) if meta_path else {}
-    if not isinstance(meta, dict):
-        msg.fail(Messages.M052, Messages.M053.format(meta_type=type(meta)), exits=1)
     if output_path.exists() and [p for p in output_path.iterdir() if p.is_dir()]:
-        msg.fail(Messages.M062, Messages.M065)
+        msg.warn(
+            "Output directory is not empty",
+            "This can lead to unintended side effects when saving the model. "
+            "Please use an empty directory or a different path instead. If "
+            "the specified output path doesn't exist, the directory will be "
+            "created for you.",
+        )
     if not output_path.exists():
         output_path.mkdir()
 
+    # Take dropout and batch size as generators of values -- dropout
+    # starts high and decays sharply, to force the optimizer to explore.
+    # Batch size starts at 1 and grows, so that we make updates quickly
+    # at the beginning of training.
+    dropout_rates = util.decaying(
+        util.env_opt("dropout_from", 0.2),
+        util.env_opt("dropout_to", 0.2),
+        util.env_opt("dropout_decay", 0.0),
+    )
+    batch_sizes = util.compounding(
+        util.env_opt("batch_from", 100.0),
+        util.env_opt("batch_to", 2000.0),
+        util.env_opt("batch_compound", 1.001),
+    )
+
     # Set up the base model and pipeline. If a base model is specified, load
     # the model and make sure the pipeline matches the pipeline setting. If
     # training starts from a blank model, intitalize the language class.
     pipeline = [p.strip() for p in pipeline.split(",")]
-    msg.text(Messages.M055.format(pipeline=pipeline))
+    msg.text("Training pipeline: {}".format(pipeline))
     if base_model:
-        msg.text(Messages.M056.format(model=base_model))
+        msg.text("Starting with base model '{}'".format(base_model))
         nlp = util.load_model(base_model)
         if nlp.lang != lang:
-            msg.fail(Messages.M072.format(model_lang=nlp.lang, lang=lang), exits=1)
+            msg.fail(
+                "Model language ('{}') doesn't match language specified as "
+                "`lang` argument ('{}') ".format(nlp.lang, lang),
+                exits=1,
+            )
         other_pipes = [pipe for pipe in nlp.pipe_names if pipe not in pipeline]
         nlp.disable_pipes(*other_pipes)
         for pipe in pipeline:
             if pipe not in nlp.pipe_names:
                 nlp.add_pipe(nlp.create_pipe(pipe))
     else:
-        msg.text(Messages.M057.format(model=lang))
+        msg.text("Starting with blank model '{}'".format(lang))
         lang_cls = util.get_lang_class(lang)
         nlp = lang_cls()
     for pipe in pipeline:

@@ -147,7 +153,7 @@ def train(
             nlp.add_pipe(nlp.create_pipe("merge_subtokens"))
 
     if vectors:
-        msg.text(Messages.M058.format(model=vectors))
+        msg.text("Loading vector from model '{}'".format(vectors))
         _load_vectors(nlp, vectors)
 
     # Multitask objectives

@@ -155,13 +161,16 @@ def train(
     for pipe_name, multitasks in multitask_options:
         if multitasks:
             if pipe_name not in pipeline:
-                msg.fail(Messages.M059.format(pipe=pipe_name))
+                msg.fail(
+                    "Can't use multitask objective without '{}' in the "
+                    "pipeline".format(pipe_name)
+                )
             pipe = nlp.get_pipe(pipe_name)
             for objective in multitasks.split(","):
                 pipe.add_multitask_objective(objective)
 
     # Prepare training corpus
-    msg.text(Messages.M060.format(limit=n_examples))
+    msg.text("Counting training words (limit={})".format(n_examples))
     corpus = GoldCorpus(train_path, dev_path, limit=n_examples)
     n_train_words = corpus.count_train()
 

@@ -179,11 +188,19 @@ def train(
     # Load in pre-trained weights
     if init_tok2vec is not None:
         components = _load_pretrained_tok2vec(nlp, init_tok2vec)
-        msg.text(Messages.M071.format(components=components))
+        msg.text("Loaded pretrained tok2vec for: {}".format(components))
 
-    print(
-        "\nItn. Dep Loss NER Loss UAS NER P. NER R. NER F. Tag % Token % CPU WPS GPU WPS"
-    )
+    # fmt: off
+    row_head = ("Itn", "Dep Loss", "NER Loss", "UAS", "NER P", "NER R", "NER F", "Tag %", "Token %", "CPU WPS", "GPU WPS")
+    row_settings = {
+        "widths": (3, 10, 10, 7, 7, 7, 7, 7, 7, 7, 7),
+        "aligns": tuple(["r" for i in row_head]),
+        "spacing": 2
+    }
+    # fmt: on
+    print("")
+    msg.row(row_head, **row_settings)
+    msg.row(["-" * width for width in row_settings["widths"]], **row_settings)
     try:
         for i in range(n_iter):
             train_docs = corpus.train_docs(

@@ -250,15 +267,18 @@ def train(
 
             util.set_env_log(verbose)
 
-            print_progress(i, losses, scorer.scores, cpu_wps=cpu_wps, gpu_wps=gpu_wps)
+            progress = _get_progress(
+                i, losses, scorer.scores, cpu_wps=cpu_wps, gpu_wps=gpu_wps
+            )
+            msg.row(progress, **row_settings)
     finally:
-        with msg.loading(Messages.M061):
-            with nlp.use_params(optimizer.averages):
-                final_model_path = output_path / "model-final"
-                nlp.to_disk(final_model_path)
-        msg.good(Messages.M066, util.path2str(final_model_path))
-        _collate_best_model(meta, output_path, nlp.pipe_names)
+        with nlp.use_params(optimizer.averages):
+            final_model_path = output_path / "model-final"
+            nlp.to_disk(final_model_path)
+            msg.good("Saved model to output directory", final_model_path)
+        with msg.loading("Creating best model..."):
+            best_model_path = _collate_best_model(meta, output_path, nlp.pipe_names)
+        msg.good("Created best model", best_model_path)
 
 
 def _load_vectors(nlp, vectors):

@@ -301,6 +321,7 @@ def _collate_best_model(meta, output_path, components):
         for metric in _get_metrics(component):
             meta["accuracy"][metric] = accs[metric]
     srsly.write_json(best_dest / "meta.json", meta)
+    return best_dest
 
 
 def _find_best(experiment_dir, component):

@@ -326,7 +347,7 @@ def _get_metrics(component):
     return ("token_acc",)
 
 
-def print_progress(itn, losses, dev_scores, cpu_wps=0.0, gpu_wps=0.0):
+def _get_progress(itn, losses, dev_scores, cpu_wps=0.0, gpu_wps=0.0):
     scores = {}
     for col in [
         "dep_loss",

@@ -347,19 +368,16 @@ def print_progress(itn, losses, dev_scores, cpu_wps=0.0, gpu_wps=0.0):
     scores.update(dev_scores)
     scores["cpu_wps"] = cpu_wps
     scores["gpu_wps"] = gpu_wps or 0.0
-    tpl = "".join(
-        (
-            "{:<6d}",
-            "{dep_loss:<10.3f}",
-            "{ner_loss:<10.3f}",
-            "{uas:<8.3f}",
-            "{ents_p:<8.3f}",
-            "{ents_r:<8.3f}",
-            "{ents_f:<8.3f}",
-            "{tags_acc:<8.3f}",
-            "{token_acc:<9.3f}",
-            "{cpu_wps:<9.1f}",
-            "{gpu_wps:.1f}",
-        )
-    )
-    print(tpl.format(itn, **scores))
+    return [
+        itn,
+        "{:.3f}".format(scores["dep_loss"]),
+        "{:.3f}".format(scores["ner_loss"]),
+        "{:.3f}".format(scores["uas"]),
+        "{:.3f}".format(scores["ents_p"]),
+        "{:.3f}".format(scores["ents_r"]),
+        "{:.3f}".format(scores["ents_f"]),
+        "{:.3f}".format(scores["tags_acc"]),
+        "{:.3f}".format(scores["token_acc"]),
+        "{:.0f}".format(scores["cpu_wps"]),
+        "{:.0f}".format(scores["gpu_wps"]),
+    ]

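Note: the fixed-width print() template is replaced by wasabi's msg.row(), which formats one table row per call from per-column widths and alignments. A small sketch, not part of the commit, with made-up values and fewer columns:

    from wasabi import Printer

    msg = Printer()
    row_head = ("Itn", "Dep Loss", "UAS", "Token %")
    row_settings = {"widths": (3, 10, 7, 7), "aligns": ("r", "r", "r", "r"), "spacing": 2}
    msg.row(row_head, **row_settings)
    msg.row(["-" * width for width in row_settings["widths"]], **row_settings)
    msg.row([1, "12.345", "91.200", "99.800"], **row_settings)
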
@@ -8,7 +8,6 @@ import requests
 import srsly
 from wasabi import Printer
 
-from ._messages import Messages
 from ..compat import path2str
 from ..util import get_data_path
 from .. import about

@@ -23,13 +22,17 @@ def validate():
     with msg.loading("Loading compatibility table..."):
         r = requests.get(about.__compatibility__)
         if r.status_code != 200:
-            msg.fail(Messages.M003.format(code=r.status_code), Messages.M021, exits=1)
+            msg.fail(
+                "Server error ({})".format(r.status_code),
+                "Couldn't fetch compatibility table.",
+                exits=1,
+            )
     msg.good("Loaded compatibility table")
     compat = r.json()["spacy"]
     current_compat = compat.get(about.__version__)
     if not current_compat:
         msg.fail(
-            Messages.M022.format(version=about.__version__),
+            "Can't find spaCy v{} in compatibility table".format(about.__version__),
             about.__compatibility__,
             exits=1,
         )

@@ -49,7 +52,7 @@ def validate():
     update_models = [m for m in incompat_models if m in current_compat]
     spacy_dir = Path(__file__).parent.parent
 
-    msg.divider(Messages.M023.format(version=about.__version__))
+    msg.divider("Installed models (spaCy v{})".format(about.__version__))
     msg.info("spaCy installation: {}".format(path2str(spacy_dir)))
 
     if model_links or model_pkgs:

@@ -61,17 +64,24 @@ def validate():
             rows.append(get_model_row(current_compat, name, data, msg, "link"))
         msg.table(rows, header=header)
     else:
-        msg.text(Messages.M024, exits=0)
+        msg.text("No models found in your current environment.", exits=0)
     if update_models:
         msg.divider("Install updates")
+        msg.text("Use the following commands to update the model packages:")
         cmd = "python -m spacy download {}"
         print("\n".join([cmd.format(pkg) for pkg in update_models]) + "\n")
     if na_models:
         msg.text(
-            Messages.M025.format(version=about.__version__, models=", ".join(na_models))
+            "The following models are not available for spaCy "
+            "v{}: {}".format(about.__version__, ", ".join(na_models))
         )
     if incompat_links:
-        msg.text(Messages.M027.format(path=path2str(get_data_path())))
+        msg.text(
+            "You may also want to overwrite the incompatible links using the "
+            "`python -m spacy link` command with `--force`, or remove them "
+            "from the data directory. "
+            "Data path: {path}".format(path=path2str(get_data_path()))
+        )
     if incompat_models or incompat_links:
         sys.exit(1)
 

@@ -346,12 +346,12 @@ def _json_iterate(loc):
     cdef char close_curly = ord('}')
     for i in range(len(py_raw)):
         c = raw[i]
-        if c == backslash:
-            escape = True
-            continue
         if escape:
             escape = False
             continue
+        if c == backslash:
+            escape = True
+            continue
         if c == quote:
             inside_string = not inside_string
             continue
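The reordering above matters for escaped backslashes: the character that follows a backslash has to be consumed before the backslash check runs again, otherwise a `\\` sequence inside a JSON string leaves `escape` set and swallows the closing quote. A minimal pure-Python sketch of the corrected scan (function and variable names here are illustrative, not from the diff):

    def count_string_toggles(raw):
        # Count quote characters that actually open or close a string.
        escape = False
        inside_string = False
        toggles = 0
        for c in raw:
            if escape:            # previous char was a backslash: this char is literal
                escape = False
                continue
            if c == "\\":         # only an unescaped backslash starts a new escape
                escape = True
                continue
            if c == '"':
                inside_string = not inside_string
                toggles += 1
        return toggles

    assert count_string_toggles(r'"a\\"') == 2  # the escaped backslash no longer eats the closing quote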
@@ -58,7 +58,7 @@ cdef struct TokenC:
     attr_t tag
     int idx
     attr_t lemma
-    attr_t sense
+    attr_t norm
     int head
     attr_t dep
@@ -3,7 +3,7 @@ from __future__ import unicode_literals

 import pytest
 from spacy.attrs import ORTH, LENGTH
-from spacy.tokens import Doc
+from spacy.tokens import Doc, Span
 from spacy.vocab import Vocab

 from ..util import get_doc
@@ -154,6 +154,17 @@ def test_span_as_doc(doc):
     assert span.text == span_doc.text.strip()


+def test_span_string_label(doc):
+    span = Span(doc, 0, 1, label='hello')
+    assert span.label_ == 'hello'
+    assert span.label == doc.vocab.strings['hello']
+
+def test_span_string_set_label(doc):
+    span = Span(doc, 0, 1)
+    span.label_ = 'hello'
+    assert span.label_ == 'hello'
+    assert span.label == doc.vocab.strings['hello']
+
 def test_span_ents_property(doc):
     """Test span.ents for the """
     doc.ents = [
spacy/tests/regression/test_issue2754.py (new file, 14 lines)

@@ -0,0 +1,14 @@
+# coding: utf8
+from __future__ import unicode_literals
+
+import pytest
+from spacy.lang.en import English
+
+
+def test_issue2754():
+    """Test that words like 'a' and 'a.m.' don't get exceptional norm values."""
+    nlp = English()
+    a = nlp('a')
+    assert a[0].norm_ == 'a'
+    am = nlp('am')
+    assert am[0].norm_ == 'am'
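If needed, the regression test above can be run on its own from the repository root:

    pytest spacy/tests/regression/test_issue2754.py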
@@ -15,7 +15,7 @@ from ..parts_of_speech cimport univ_pos_t
 from ..util import normalize_slice
 from ..attrs cimport IS_PUNCT, IS_SPACE
 from ..lexeme cimport Lexeme
-from ..compat import is_config
+from ..compat import is_config, basestring_
 from ..errors import Errors, TempErrors, Warnings, user_warning, models_warning
 from .underscore import Underscore, get_ext_args
@@ -42,7 +42,7 @@ cdef class Span:
             raise ValueError(Errors.E046.format(name=name))
         return Underscore.span_extensions.pop(name)

-    def __cinit__(self, Doc doc, int start, int end, attr_t label=0,
+    def __cinit__(self, Doc doc, int start, int end, label=0,
                   vector=None, vector_norm=None):
         """Create a `Span` object from the slice `doc[start : end]`.
@@ -64,6 +64,8 @@ cdef class Span:
             self.end_char = self.doc[end - 1].idx + len(self.doc[end - 1])
         else:
             self.end_char = 0
+        if isinstance(label, basestring_):
+            label = doc.vocab.strings.add(label)
         if label not in doc.vocab.strings:
             raise ValueError(Errors.E084.format(label=label))
         self.label = label
@@ -601,6 +603,8 @@ cdef class Span:
         """RETURNS (unicode): The span's label."""
         def __get__(self):
             return self.doc.vocab.strings[self.label]
+        def __set__(self, unicode label_):
+            self.label = self.doc.vocab.strings.add(label_)


 cdef int _count_words_to_root(const TokenC* token, int sent_length) except -1:
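Together with the `__cinit__` change above, a span can now be labelled with a plain string instead of a pre-interned hash, and the label can be reassigned through `label_`. A short usage sketch, assuming a blank English pipeline (text and labels are purely illustrative):

    from spacy.lang.en import English
    from spacy.tokens import Span

    nlp = English()
    doc = nlp("London is big")
    span = Span(doc, 0, 1, label="GPE")    # string label is interned automatically
    assert span.label_ == "GPE"
    span.label_ = "LOC"                    # the new setter re-interns the string
    assert span.label == doc.vocab.strings["LOC"]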
@@ -34,6 +34,11 @@ cdef class Token:
             return Lexeme.c_check_flag(token.lex, feat_name)
         elif feat_name == LEMMA:
             return token.lemma
+        elif feat_name == NORM:
+            if token.norm == 0:
+                return token.lex.norm
+            else:
+                return token.norm
         elif feat_name == POS:
             return token.pos
         elif feat_name == TAG:
|
||||||
attr_t value) nogil:
|
attr_t value) nogil:
|
||||||
if feat_name == LEMMA:
|
if feat_name == LEMMA:
|
||||||
token.lemma = value
|
token.lemma = value
|
||||||
|
elif feat_name == NORM:
|
||||||
|
token.norm = value
|
||||||
elif feat_name == POS:
|
elif feat_name == POS:
|
||||||
token.pos = <univ_pos_t>value
|
token.pos = <univ_pos_t>value
|
||||||
elif feat_name == TAG:
|
elif feat_name == TAG:
|
||||||
|
|
|
@@ -249,7 +249,10 @@ cdef class Token:
            or norm exceptions.
         """
         def __get__(self):
-            return self.c.lex.norm
+            if self.c.norm == 0:
+                return self.c.lex.norm
+            else:
+                return self.c.norm

     property shape:
         """RETURNS (uint64): ID of the token's shape, a transform of the
@@ -711,7 +714,10 @@ cdef class Token:
            norm exceptions.
         """
         def __get__(self):
-            return self.vocab.strings[self.c.lex.norm]
+            return self.vocab.strings[self.norm]
+
+        def __set__(self, unicode norm_):
+            self.c.norm = self.vocab.strings.add(norm_)

     property shape_:
         """RETURNS (unicode): Transform of the tokens's string, to show
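With the per-token `norm` field, `token.norm_` still falls back to the lexeme norm but can now be overridden on a single token without touching the shared vocabulary entry. A brief sketch, assuming a blank English pipeline:

    from spacy.lang.en import English

    nlp = English()
    doc = nlp("I am here")
    assert doc[1].norm_ == "am"   # no token-level norm set, so the lexeme norm is returned
    doc[1].norm_ = "be"           # writes the token's own norm via the new setter
    assert doc[1].norm_ == "be"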
@@ -15,6 +15,11 @@ import itertools
 import numpy.random
 import srsly

+try:
+    import cupy.random
+except ImportError:
+    cupy = None
+
 from .symbols import ORTH
 from .compat import cupy, CudaStream, path2str, basestring_, unicode_
 from .compat import import_file
@@ -598,6 +603,8 @@ def use_gpu(gpu_id):
 def fix_random_seed(seed=0):
     random.seed(seed)
     numpy.random.seed(seed)
+    if cupy is not None:
+        cupy.random.seed(seed)


 class SimpleFrozenDict(dict):
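The `cupy is not None` guard keeps the helper safe on CPU-only installs, where `cupy` ends up as `None`. Typical usage is unchanged:

    from spacy.util import fix_random_seed

    fix_random_seed(42)  # seeds random and numpy.random; also cupy.random on GPU builds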
@@ -17,7 +17,7 @@ from .structs cimport SerializedLexemeC
 from .compat import copy_reg, basestring_
 from .errors import Errors
 from .lemmatizer import Lemmatizer
-from .attrs import intify_attrs
+from .attrs import intify_attrs, NORM
 from .vectors import Vectors
 from ._ml import link_vectors_to_models
 from . import util
@@ -234,7 +234,10 @@ cdef class Vocab:
                 self.morphology.assign_tag(token, props[TAG])
             for attr_id, value in props.items():
                 Token.set_struct_attr(token, attr_id, value)
-                Lexeme.set_struct_attr(lex, attr_id, value)
+                # NORM is the only one that overlaps between the two
+                # (which is maybe not great?)
+                if attr_id != NORM:
+                    Lexeme.set_struct_attr(lex, attr_id, value)
         return tokens

     @property
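Because NORM is no longer written through to the shared lexeme here, a tokenizer special case can give one surface form a norm without changing every other token backed by the same lexeme. A small sketch; the "gonna" special case is purely illustrative and not part of this diff:

    from spacy.attrs import ORTH, NORM
    from spacy.lang.en import English

    nlp = English()
    nlp.tokenizer.add_special_case(
        "gonna", [{ORTH: "gon", NORM: "going"}, {ORTH: "na", NORM: "to"}]
    )
    doc = nlp("gonna")
    assert [t.norm_ for t in doc] == ["going", "to"]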