This commit is contained in:
Matthew Honnibal 2025-05-21 20:46:41 +02:00
parent a4bbd0ee08
commit 0630d62264
11 changed files with 47 additions and 43 deletions

View File

@ -24,8 +24,6 @@ TagMapType = Dict[str, Dict[Union[int, str], Union[int, str]]]
MorphRulesType = Dict[str, Dict[str, Dict[Union[int, str], Union[int, str]]]]
def attribute_ruler_score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
def morph_key_getter(token, attr):
return getattr(token, attr).key

View File

@ -41,8 +41,6 @@ subword_features = true
DEFAULT_EDIT_TREE_LEMMATIZER_MODEL = Config().from_str(default_model_config)["model"]
class EditTreeLemmatizer(TrainablePipe):
"""
Lemmatizer that lemmatizes each word using a predicted edit tree.

View File

@ -42,8 +42,6 @@ subword_features = true
DEFAULT_NEL_MODEL = Config().from_str(default_model_config)["model"]
def entity_linker_score(examples, **kwargs):
return Scorer.score_links(examples, negative_labels=[EntityLinker.NIL], **kwargs)

View File

@ -21,8 +21,6 @@ DEFAULT_ENT_ID_SEP = "||"
PatternType = Dict[str, Union[str, List[Dict[str, Any]]]]
def entity_ruler_score(examples, **kwargs):
return get_ner_prf(examples)

View File

@ -18,8 +18,6 @@ from ..vocab import Vocab
from .pipe import Pipe
def lemmatizer_score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
return Scorer.score_token_attr(examples, "lemma", **kwargs)

View File

@ -34,10 +34,6 @@ PatternType = Dict[str, Union[str, List[Dict[str, Any]]]]
DEFAULT_SPANS_KEY = "ruler"
def prioritize_new_ents_filter(
entities: Iterable[Span], spans: Iterable[Span]
) -> List[Span]:

View File

@ -159,10 +159,6 @@ def build_preset_spans_suggester(spans_key: str) -> Suggester:
return partial(preset_spans_suggester, spans_key=spans_key)
def spancat_score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
kwargs = dict(kwargs)
attr_prefix = "spans_"

View File

@ -26,8 +26,6 @@ subword_features = true
DEFAULT_TOK2VEC_MODEL = Config().from_str(default_model_config)["model"]
class Tok2Vec(TrainablePipe):
"""Apply a "token-to-vector" model and set its outputs in the doc.tensor
attribute. This is mostly useful to share a single subnetwork between multiple

View File

@ -14,6 +14,7 @@ from spacy import util
original_is_same_func = util.is_same_func
def patched_is_same_func(func1, func2):
# Handle Cython functions
try:
@ -22,17 +23,22 @@ def patched_is_same_func(func1, func2):
# For Cython functions, just compare the string representation
return str(func1) == str(func2)
util.is_same_func = patched_is_same_func
@pytest.fixture
def reference_factory_registrations():
"""Load reference factory registrations from JSON file"""
if not REFERENCE_FILE.exists():
pytest.fail(f"Reference file {REFERENCE_FILE} not found. Run export_factory_registrations.py first.")
pytest.fail(
f"Reference file {REFERENCE_FILE} not found. Run export_factory_registrations.py first."
)
with REFERENCE_FILE.open("r") as f:
return json.load(f)
def test_factory_registrations_preserved(reference_factory_registrations):
"""Test that all factory registrations from the reference file are still present."""
# Ensure the registry is populated
@ -51,13 +57,17 @@ def test_factory_registrations_preserved(reference_factory_registrations):
module_name = func.__module__
except (AttributeError, TypeError):
# For Cython functions, just use a placeholder
module_name = str(func).split()[1].split('.')[0]
module_name = str(func).split()[1].split(".")[0]
try:
func_name = func.__qualname__
except (AttributeError, TypeError):
# For Cython functions, use the function's name
func_name = func.__name__ if hasattr(func, "__name__") else str(func).split()[1].split('.')[-1]
func_name = (
func.__name__
if hasattr(func, "__name__")
else str(func).split()[1].split(".")[-1]
)
current_registrations[name] = {
"name": name,
@ -66,11 +76,19 @@ def test_factory_registrations_preserved(reference_factory_registrations):
}
# Check for missing registrations
missing_registrations = set(reference_factory_registrations.keys()) - set(current_registrations.keys())
assert not missing_registrations, f"Missing factory registrations: {', '.join(sorted(missing_registrations))}"
missing_registrations = set(reference_factory_registrations.keys()) - set(
current_registrations.keys()
)
assert (
not missing_registrations
), f"Missing factory registrations: {', '.join(sorted(missing_registrations))}"
# Check for new registrations (not an error, but informative)
new_registrations = set(current_registrations.keys()) - set(reference_factory_registrations.keys())
new_registrations = set(current_registrations.keys()) - set(
reference_factory_registrations.keys()
)
if new_registrations:
# This is not an error, just informative
print(f"New factory registrations found: {', '.join(sorted(new_registrations))}")
print(
f"New factory registrations found: {', '.join(sorted(new_registrations))}"
)

View File

@ -7,6 +7,7 @@ from spacy.util import registry
# Path to the reference registry contents, relative to this file
REFERENCE_FILE = Path(__file__).parent / "registry_contents.json"
@pytest.fixture
def reference_registry():
"""Load reference registry contents from JSON file"""
@ -16,6 +17,7 @@ def reference_registry():
with REFERENCE_FILE.open("r") as f:
return json.load(f)
def test_registry_types(reference_registry):
"""Test that all registry types match the reference"""
# Get current registry types
@ -26,6 +28,7 @@ def test_registry_types(reference_registry):
missing_types = expected_registry_types - current_registry_types
assert not missing_types, f"Missing registry types: {', '.join(missing_types)}"
def test_registry_entries(reference_registry):
"""Test that all registry entries are present"""
# Check each registry's entries
@ -45,4 +48,6 @@ def test_registry_entries(reference_registry):
# Check for missing entries - these would indicate our new registry population
# mechanism is missing something
missing_entries = expected_set - current_set
assert not missing_entries, f"Registry '{registry_name}' missing entries: {', '.join(missing_entries)}"
assert (
not missing_entries
), f"Registry '{registry_name}' missing entries: {', '.join(missing_entries)}"

View File

@ -136,6 +136,7 @@ class registry(thinc.registry):
def ensure_populated(cls) -> None:
"""Ensure the registry is populated with all necessary components."""
from .registrations import populate_registry, REGISTRY_POPULATED
if not REGISTRY_POPULATED:
populate_registry()