Merge pull request #3 from explosion/master

Update
Shen Qin 2022-06-30 23:05:20 +08:00 committed by GitHub
commit 9accdbdbad
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 229 additions and 43 deletions

View File

@@ -103,6 +103,10 @@ cuda114 =
cupy-cuda114>=5.0.0b4,<11.0.0
cuda115 =
cupy-cuda115>=5.0.0b4,<11.0.0
cuda116 =
cupy-cuda116>=5.0.0b4,<11.0.0
cuda117 =
cupy-cuda117>=5.0.0b4,<11.0.0
apple =
thinc-apple-ops>=0.1.0.dev0,<1.0.0
# Language tokenizers with external dependencies
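The new entries follow the existing pattern, so the matching CuPy build is pulled in with the extra, e.g. `pip install spacy[cuda117]` for CUDA 11.7.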

View File

@@ -209,6 +209,9 @@ class Warnings(metaclass=ErrorsWithCodes):
"Only the last span group will be loaded under "
"Doc.spans['{group_name}']. Skipping span group with values: "
"{group_values}")
W121 = ("Attempting to trace non-existent method '{method}' in pipe '{pipe}'")
W122 = ("Couldn't trace method '{method}' in pipe '{pipe}'. This can happen if the pipe class "
"is a Cython extension type.")
class Errors(metaclass=ErrorsWithCodes):

View File

@@ -90,6 +90,10 @@ cdef class Matcher:
'?': Make the pattern optional, by allowing it to match 0 or 1 times.
'+': Require the pattern to match 1 or more times.
'*': Allow the pattern to match zero or more times.
'{n}': Require the pattern to match exactly _n_ times.
'{n,m}': Require the pattern to match at least _n_ but not more than _m_ times.
'{n,}': Require the pattern to match at least _n_ times.
'{,m}': Require the pattern to match at most _m_ times.
The + and * operators return all possible matches (not just the greedy
ones). However, the "greedy" argument can filter the final matches
@@ -1004,8 +1008,29 @@ def _get_operators(spec):
return (ONE,)
elif spec["OP"] in lookup:
return lookup[spec["OP"]]
# Min-max operators: {n}, {n,m}, {n,}, {,m}
elif spec["OP"].startswith("{") and spec["OP"].endswith("}"):
    # {n}   --> exactly n:      ONE * n
    # {n,m} --> min n, max m:   ONE * n + ZERO_ONE * (m - n)
    # {,m}  --> min 0, max m:   ZERO_ONE * m
    # {n,}  --> min n, no max:  ONE * n + ZERO_PLUS
    min_max = spec["OP"][1:-1]
    min_max = min_max if "," in min_max else f"{min_max},{min_max}"
    n, m = min_max.split(",", 1)
    # Valid only if n and m are each decimal or empty (but not both empty),
    # and n <= m when both are given.
    if (
        (n and not n.isdecimal())
        or (m and not m.isdecimal())
        or (not n and not m)
        or (n.isdecimal() and m.isdecimal() and int(n) > int(m))
    ):
        keys = ", ".join(lookup.keys()) + ", {n}, {n,m}, {n,}, {,m} where n and m are integers and n <= m"
        raise ValueError(Errors.E011.format(op=spec["OP"], opts=keys))
    # An empty n defaults to zero.
    head = tuple(ONE for __ in range(int(n or 0)))
    tail = tuple(ZERO_ONE for __ in range(int(m) - int(n or 0))) if m else (ZERO_PLUS,)
    return head + tail
else:
keys = ", ".join(lookup.keys())
keys = ", ".join(lookup.keys()) + ", {n}, {n,m}, {n,}, {,m} where n and m are integers and n <= m "
raise ValueError(Errors.E011.format(op=spec["OP"], opts=keys))
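For illustration, a minimal usage sketch of the new min-max operators (assuming a spaCy build that includes this change):

```python
import spacy
from spacy.matcher import Matcher

nlp = spacy.blank("en")
doc = nlp("very very very good")

matcher = Matcher(nlp.vocab)
# "{2,3}" expands to (ONE, ONE, ZERO_ONE): two required "very" tokens plus
# one optional, followed by "good".
matcher.add("VERY_GOOD", [[{"LOWER": "very", "OP": "{2,3}"}, {"LOWER": "good"}]])

for match_id, start, end in matcher(doc):
    print(doc[start:end].text)
# -> "very very good" and "very very very good"
```

Like `+` and `*`, the expanded operators return all possible matches, so both the two- and three-token runs of "very" are reported.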

View File

@@ -1,9 +1,14 @@
from functools import partial
from typing import Type, Callable, Dict, TYPE_CHECKING, List, Optional, Set
import functools
import inspect
import types
import warnings
from thinc.layers import with_nvtx_range
from thinc.model import Model, wrap_model_recursive
from thinc.util import use_nvtx_range
from ..errors import Warnings
from ..util import registry
if TYPE_CHECKING:
@@ -11,29 +16,106 @@ if TYPE_CHECKING:
from ..language import Language # noqa: F401
DEFAULT_NVTX_ANNOTATABLE_PIPE_METHODS = [
"pipe",
"predict",
"set_annotations",
"update",
"rehearse",
"get_loss",
"initialize",
"begin_update",
"finish_update",
"update",
]
def models_with_nvtx_range(nlp, forward_color: int, backprop_color: int):
pipes = [
pipe
for _, pipe in nlp.components
if hasattr(pipe, "is_trainable") and pipe.is_trainable
]
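# Wrap each model node only once; layers can be shared between pipes.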
seen_models: Set[int] = set()
for pipe in pipes:
for node in pipe.model.walk():
if id(node) in seen_models:
continue
seen_models.add(id(node))
with_nvtx_range(
node, forward_color=forward_color, backprop_color=backprop_color
)
return nlp
@registry.callbacks("spacy.models_with_nvtx_range.v1")
def create_models_with_nvtx_range(
forward_color: int = -1, backprop_color: int = -1
) -> Callable[["Language"], "Language"]:
return functools.partial(
models_with_nvtx_range,
forward_color=forward_color,
backprop_color=backprop_color,
)
def nvtx_range_wrapper_for_pipe_method(self, func, *args, **kwargs):
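# Already-wrapped methods arrive as functools.partial objects (e.g. when
# additional_pipe_functions names a method that is wrapped by default);
# call them directly so we do not open a second, nested NVTX range.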
if isinstance(func, functools.partial):
return func(*args, **kwargs)
else:
with use_nvtx_range(f"{self.name} {func.__name__}"):
return func(*args, **kwargs)
def pipes_with_nvtx_range(
nlp, additional_pipe_functions: Optional[Dict[str, List[str]]]
):
for _, pipe in nlp.components:
if additional_pipe_functions:
extra_funcs = additional_pipe_functions.get(pipe.name, [])
else:
extra_funcs = []
for name in DEFAULT_NVTX_ANNOTATABLE_PIPE_METHODS + extra_funcs:
func = getattr(pipe, name, None)
if func is None:
if name in extra_funcs:
warnings.warn(Warnings.W121.format(method=name, pipe=pipe.name))
continue
wrapped_func = functools.partial(
types.MethodType(nvtx_range_wrapper_for_pipe_method, pipe), func
)
# Try to preserve the original function signature.
try:
wrapped_func.__signature__ = inspect.signature(func) # type: ignore
except Exception:
pass
try:
setattr(
pipe,
name,
wrapped_func,
)
except AttributeError:
warnings.warn(Warnings.W122.format(method=name, pipe=pipe.name))
return nlp
@registry.callbacks("spacy.models_and_pipes_with_nvtx_range.v1")
def create_models_and_pipes_with_nvtx_range(
forward_color: int = -1,
backprop_color: int = -1,
additional_pipe_functions: Optional[Dict[str, List[str]]] = None,
) -> Callable[["Language"], "Language"]:
def inner(nlp):
nlp = models_with_nvtx_range(nlp, forward_color, backprop_color)
nlp = pipes_with_nvtx_range(nlp, additional_pipe_functions)
return nlp
return inner
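A sketch of applying the new `spacy.models_and_pipes_with_nvtx_range.v1` callback by hand (the factory is resolved through the registry; an NVTX-capable GPU setup is assumed, and `custom_method` is a hypothetical method name included to show W121):

```python
import spacy
from spacy.util import registry

nlp = spacy.load("en_core_web_sm")

# Resolve the registered factory, build the callback, and apply it: all
# trainable models and the default pipe methods get wrapped in NVTX ranges.
create_callback = registry.callbacks.get("spacy.models_and_pipes_with_nvtx_range.v1")
callback = create_callback(
    forward_color=0,
    backprop_color=1,
    additional_pipe_functions={"tagger": ["custom_method"]},
)
nlp = callback(nlp)  # emits W121: the tagger has no method "custom_method"
```

In a training config the same callback would normally be attached via a callback hook such as `[nlp.after_pipeline_creation]` rather than called by hand.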

View File

@@ -3,12 +3,13 @@ from typing import Iterable, TypeVar, TYPE_CHECKING
from .compat import Literal
from enum import Enum
from pydantic import BaseModel, Field, ValidationError, validator, create_model
from pydantic import StrictStr, StrictInt, StrictFloat, StrictBool, ConstrainedStr
from pydantic.main import ModelMetaclass
from thinc.api import Optimizer, ConfigValidationError, Model
from thinc.config import Promise
from collections import defaultdict
import inspect
import re
from .attrs import NAMES
from .lookups import Lookups
@@ -198,13 +199,18 @@ class TokenPatternNumber(BaseModel):
return v
class TokenPatternOperatorSimple(str, Enum):
plus: StrictStr = StrictStr("+")
star: StrictStr = StrictStr("*")
question: StrictStr = StrictStr("?")
exclamation: StrictStr = StrictStr("!")
class TokenPatternOperatorMinMax(ConstrainedStr):
regex = re.compile(r"^({\d+}|{\d+,\d*}|{\d*,\d+})$")
TokenPatternOperator = Union[TokenPatternOperatorSimple, TokenPatternOperatorMinMax]
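The `ConstrainedStr` subclass mirrors the validation done in the Matcher itself. As a quick standalone check of what the regex accepts (a sketch, runnable outside spaCy):

```python
import re

# Same pattern as TokenPatternOperatorMinMax.regex above.
min_max = re.compile(r"^({\d+}|{\d+,\d*}|{\d*,\d+})$")

assert min_max.match("{3}") and min_max.match("{2,5}")
assert min_max.match("{2,}") and min_max.match("{,5}")
assert not min_max.match("{,}")      # both bounds missing
assert not min_max.match("{a,3}")    # non-numeric bound
assert not min_max.match("{1, 3}")   # whitespace is rejected
```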
StringValue = Union[TokenPatternString, StrictStr]
NumberValue = Union[TokenPatternNumber, StrictInt, StrictFloat]
UnderscoreValue = Union[

View File

@@ -680,3 +680,38 @@ def test_matcher_ent_iob_key(en_vocab):
assert matches[0] == "Maria"
assert matches[1] == "Maria Esperanza"
assert matches[2] == "Esperanza"
def test_matcher_min_max_operator(en_vocab):
# Exactly n matches {n}
doc = Doc(
en_vocab, words=["foo", "bar", "foo", "foo", "bar",
"foo", "foo", "foo", "bar", "bar"]
)
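# Runs of "foo": length 1 (index 0), length 2 (indices 2-3), length 3 (indices 5-7).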
matcher = Matcher(en_vocab)
pattern = [{"ORTH": "foo", "OP": "{3}"}]
matcher.add("TEST", [pattern])
matches1 = [doc[start:end].text for _, start, end in matcher(doc)]
assert len(matches1) == 1
# At least n matches {n,}
matcher = Matcher(en_vocab)
pattern = [{"ORTH": "foo", "OP": "{2,}"}]
matcher.add("TEST", [pattern])
matches2 = [doc[start:end].text for _, start, end in matcher(doc)]
assert len(matches2) == 4
# At most m matches {,m}
matcher = Matcher(en_vocab)
pattern = [{"ORTH": "foo", "OP": "{,2}"}]
matcher.add("TEST", [pattern])
matches3 = [doc[start:end].text for _, start, end in matcher(doc)]
assert len(matches3) == 9
# At least n and at most m matches {n,m}
matcher = Matcher(en_vocab)
pattern = [{"ORTH": "foo", "OP": "{2,3}"}]
matcher.add("TEST", [pattern])
matches4 = [doc[start:end].text for _, start, end in matcher(doc)]
assert len(matches4) == 4

View File

@@ -699,6 +699,10 @@ def test_matcher_with_alignments_greedy_longest(en_vocab):
("aaaa", "a a a a a?", [0, 1, 2, 3]),
("aaab", "a+ a b", [0, 0, 1, 2]),
("aaab", "a+ a+ b", [0, 0, 1, 2]),
("aaab", "a{2,} b", [0, 0, 0, 1]),
("aaab", "a{,3} b", [0, 0, 0, 1]),
("aaab", "a{2} b", [0, 0, 1]),
("aaab", "a{2,3} b", [0, 0, 0, 1]),
]
for string, pattern_str, result in cases:
matcher = Matcher(en_vocab)
@@ -711,6 +715,8 @@ def test_matcher_with_alignments_greedy_longest(en_vocab):
pattern.append({"ORTH": part[0], "OP": "*"})
elif part.endswith("?"):
pattern.append({"ORTH": part[0], "OP": "?"})
elif part.endswith("}"):
pattern.append({"ORTH": part[0], "OP": part[1:]})
else:
pattern.append({"ORTH": part})
matcher.add("PATTERN", [pattern], greedy="LONGEST")
@@ -722,7 +728,7 @@
assert expected == result, (string, pattern_str, s, e, n_matches)
def test_matcher_with_alignments_non_greedy(en_vocab):
cases = [
(0, "aaab", "a* b", [[0, 1], [0, 0, 1], [0, 0, 0, 1], [1]]),
(1, "baab", "b a* b", [[0, 1, 1, 2]]),
@@ -752,6 +758,10 @@ def test_matcher_with_alignments_nongreedy(en_vocab):
(15, "aaaa", "a a a a a?", [[0, 1, 2, 3]]),
(16, "aaab", "a+ a b", [[0, 1, 2], [0, 0, 1, 2]]),
(17, "aaab", "a+ a+ b", [[0, 1, 2], [0, 0, 1, 2]]),
(18, "aaab", "a{2,} b", [[0, 0, 1], [0, 0, 0, 1]]),
(19, "aaab", "a{3} b", [[0, 0, 0, 1]]),
(20, "aaab", "a{2} b", [[0, 0, 1]]),
(21, "aaab", "a{2,3} b", [[0, 0, 1], [0, 0, 0, 1]]),
]
for case_id, string, pattern_str, results in cases:
matcher = Matcher(en_vocab)
@@ -764,6 +774,8 @@ def test_matcher_with_alignments_nongreedy(en_vocab):
pattern.append({"ORTH": part[0], "OP": "*"})
elif part.endswith("?"):
pattern.append({"ORTH": part[0], "OP": "?"})
elif part.endswith("}"):
pattern.append({"ORTH": part[0], "OP": part[1:]})
else:
pattern.append({"ORTH": part})

View File

@@ -14,6 +14,14 @@ TEST_PATTERNS = [
('[{"TEXT": "foo"}, {"LOWER": "bar"}]', 1, 1),
([{"ENT_IOB": "foo"}], 1, 1),
([1, 2, 3], 3, 1),
([{"TEXT": "foo", "OP": "{,}"}], 1, 1),
([{"TEXT": "foo", "OP": "{,4}4"}], 1, 1),
([{"TEXT": "foo", "OP": "{a,3}"}], 1, 1),
([{"TEXT": "foo", "OP": "{a}"}], 1, 1),
([{"TEXT": "foo", "OP": "{,a}"}], 1, 1),
([{"TEXT": "foo", "OP": "{1,2,3}"}], 1, 1),
([{"TEXT": "foo", "OP": "{1, 3}"}], 1, 1),
([{"TEXT": "foo", "OP": "{-2}"}], 1, 1),
# Bad patterns flagged outside of Matcher
([{"_": {"foo": "bar", "baz": {"IN": "foo"}}}], 2, 0), # prev: (1, 0)
# Bad patterns not flagged with minimal checks
@@ -38,6 +46,7 @@ TEST_PATTERNS = [
([{"SENT_START": True}], 0, 0),
([{"ENT_ID": "STRING"}], 0, 0),
([{"ENT_KB_ID": "STRING"}], 0, 0),
([{"TEXT": "ha", "OP": "{3}"}], 0, 0),
]

View File

@@ -60,12 +60,11 @@ def test_readers():
assert isinstance(extra_corpus, Callable)
@pytest.mark.slow
@pytest.mark.parametrize(
"reader,additional_config",
[
# ("ml_datasets.imdb_sentiment.v1", {"train_limit": 10, "dev_limit": 10}),
("ml_datasets.imdb_sentiment.v1", {"train_limit": 10, "dev_limit": 10}),
("ml_datasets.dbpedia.v1", {"train_limit": 10, "dev_limit": 10}),
("ml_datasets.cmu_movies.v1", {"limit": 10, "freq_cutoff": 200, "split": 0.8}),
],

View File

@@ -59,15 +59,20 @@ matched:
> [
> {"POS": "ADJ", "OP": "*"},
> {"POS": "NOUN", "OP": "+"}
> {"POS": "PROPN", "OP": "{2}"}
> ]
> ```
| OP | Description |
|---------|------------------------------------------------------------------------|
| `!` | Negate the pattern, by requiring it to match exactly 0 times. |
| `?` | Make the pattern optional, by allowing it to match 0 or 1 times. |
| `+` | Require the pattern to match 1 or more times. |
| `*` | Allow the pattern to match 0 or more times. |
| `{n}` | Require the pattern to match exactly _n_ times. |
| `{n,m}` | Require the pattern to match at least _n_ but not more than _m_ times. |
| `{n,}` | Require the pattern to match at least _n_ times. |
| `{,m}` | Require the pattern to match at most _m_ times. |
Token patterns can also map to a **dictionary of properties** instead of a
single value to indicate whether the expected value is a member of a list or how

View File

@@ -375,11 +375,15 @@ scoped quantifiers instead, you can build those behaviors with `on_match`
callbacks.
| OP | Description |
|---------|------------------------------------------------------------------------|
| `!` | Negate the pattern, by requiring it to match exactly 0 times. |
| `?` | Make the pattern optional, by allowing it to match 0 or 1 times. |
| `+` | Require the pattern to match 1 or more times. |
| `*` | Allow the pattern to match zero or more times. |
| `{n}` | Require the pattern to match exactly _n_ times. |
| `{n,m}` | Require the pattern to match at least _n_ but not more than _m_ times. |
| `{n,}` | Require the pattern to match at least _n_ times. |
| `{,m}` | Require the pattern to match at most _m_ times. |
> #### Example
>

View File

@@ -24,6 +24,8 @@ const CUDA = {
'11.3': 'cuda113',
'11.4': 'cuda114',
'11.5': 'cuda115',
'11.6': 'cuda116',
'11.7': 'cuda117',
}
const LANG_EXTRAS = ['ja'] // only for languages with models