Fix pickle for ngram suggester (#12486)

This commit is contained in:
Adriane Boyd 2023-03-31 13:43:51 +02:00
parent b228875600
commit 57ee1212de
2 changed files with 50 additions and 28 deletions

View File

@ -1,5 +1,6 @@
from typing import List, Dict, Callable, Tuple, Optional, Iterable, Any, cast, Union from typing import List, Dict, Callable, Tuple, Optional, Iterable, Any, cast, Union
from dataclasses import dataclass from dataclasses import dataclass
from functools import partial
from thinc.api import Config, Model, get_current_ops, set_dropout_rate, Ops from thinc.api import Config, Model, get_current_ops, set_dropout_rate, Ops
from thinc.api import Optimizer from thinc.api import Optimizer
from thinc.types import Ragged, Ints2d, Floats2d from thinc.types import Ragged, Ints2d, Floats2d
@ -82,13 +83,9 @@ class Suggester(Protocol):
... ...
@registry.misc("spacy.ngram_suggester.v1") def ngram_suggester(
def build_ngram_suggester(sizes: List[int]) -> Suggester: docs: Iterable[Doc], sizes: List[int], *, ops: Optional[Ops] = None
"""Suggest all spans of the given lengths. Spans are returned as a ragged ) -> Ragged:
array of integers. The array has two columns, indicating the start and end
position."""
def ngram_suggester(docs: Iterable[Doc], *, ops: Optional[Ops] = None) -> Ragged:
if ops is None: if ops is None:
ops = get_current_ops() ops = get_current_ops()
spans = [] spans = []
@ -114,7 +111,14 @@ def build_ngram_suggester(sizes: List[int]) -> Suggester:
assert output.dataXd.ndim == 2 assert output.dataXd.ndim == 2
return output return output
return ngram_suggester
@registry.misc("spacy.ngram_suggester.v1")
def build_ngram_suggester(sizes: List[int]) -> Suggester:
    """Create a suggester that proposes every span of the given lengths.

    The suggester returns a ragged array of integers with two columns,
    holding the start and end position of each suggested span.

    sizes (List[int]): The span lengths (in tokens) to suggest.
    RETURNS (Suggester): `ngram_suggester` with `sizes` pre-bound.
    """
    # Bind `sizes` with functools.partial over the module-level function
    # instead of defining a nested closure, so the returned suggester can be
    # pickled (a locally defined function cannot).
    suggester = partial(ngram_suggester, sizes=sizes)
    return suggester
@registry.misc("spacy.ngram_range_suggester.v1") @registry.misc("spacy.ngram_range_suggester.v1")

View File

@ -1,7 +1,7 @@
import pytest import pytest
import numpy import numpy
from numpy.testing import assert_array_equal, assert_almost_equal from numpy.testing import assert_array_equal, assert_almost_equal
from thinc.api import get_current_ops, Ragged from thinc.api import get_current_ops, NumpyOps, Ragged
from spacy import util from spacy import util
from spacy.lang.en import English from spacy.lang.en import English
@ -577,3 +577,21 @@ def test_set_candidates(name):
assert len(docs[0].spans["candidates"]) == 9 assert len(docs[0].spans["candidates"]) == 9
assert docs[0].spans["candidates"][0].text == "Just" assert docs[0].spans["candidates"][0].text == "Just"
assert docs[0].spans["candidates"][4].text == "Just a" assert docs[0].spans["candidates"][4].text == "Just a"
@pytest.mark.parametrize("name", SPANCAT_COMPONENTS)
@pytest.mark.parametrize("n_process", [1, 2])
def test_spancat_multiprocessing(name, n_process):
    """Check that a pipeline with a spancat component runs under nlp.pipe.

    name: Registered name of the spancat component to add to the pipeline.
    n_process: Number of worker processes handed to `nlp.pipe`.

    The multi-process case only runs when the current ops are NumpyOps; for
    any other backend the test body is skipped and only `n_process=1` runs.
    """
    # get_current_ops must be *called*: testing the function object itself
    # against NumpyOps is always False, which silently disabled the
    # n_process=2 case and let the test pass without exercising pickling.
    if isinstance(get_current_ops(), NumpyOps) or n_process < 2:
        nlp = Language()
        nlp.add_pipe(name, config={"spans_key": SPAN_KEY})
        train_examples = make_examples(nlp)
        nlp.initialize(get_examples=lambda: train_examples)
        texts = [
            "Just a sentence.",
            "I like London and Berlin",
            "I like Berlin",
            "I eat ham.",
        ]
        # With n_process > 1 the pipeline is sent to worker processes, so
        # this fails if any component (e.g. the suggester) is not picklable.
        docs = list(nlp.pipe(texts, n_process=n_process))
        assert len(docs) == len(texts)