This commit is contained in:
Matthew Honnibal 2025-05-21 20:46:41 +02:00
parent a4bbd0ee08
commit 0630d62264
11 changed files with 47 additions and 43 deletions

View File

@ -24,8 +24,6 @@ TagMapType = Dict[str, Dict[Union[int, str], Union[int, str]]]
MorphRulesType = Dict[str, Dict[str, Dict[Union[int, str], Union[int, str]]]] MorphRulesType = Dict[str, Dict[str, Dict[Union[int, str], Union[int, str]]]]
def attribute_ruler_score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]: def attribute_ruler_score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
def morph_key_getter(token, attr): def morph_key_getter(token, attr):
return getattr(token, attr).key return getattr(token, attr).key

View File

@ -41,8 +41,6 @@ subword_features = true
DEFAULT_EDIT_TREE_LEMMATIZER_MODEL = Config().from_str(default_model_config)["model"] DEFAULT_EDIT_TREE_LEMMATIZER_MODEL = Config().from_str(default_model_config)["model"]
class EditTreeLemmatizer(TrainablePipe): class EditTreeLemmatizer(TrainablePipe):
""" """
Lemmatizer that lemmatizes each word using a predicted edit tree. Lemmatizer that lemmatizes each word using a predicted edit tree.

View File

@ -42,8 +42,6 @@ subword_features = true
DEFAULT_NEL_MODEL = Config().from_str(default_model_config)["model"] DEFAULT_NEL_MODEL = Config().from_str(default_model_config)["model"]
def entity_linker_score(examples, **kwargs): def entity_linker_score(examples, **kwargs):
return Scorer.score_links(examples, negative_labels=[EntityLinker.NIL], **kwargs) return Scorer.score_links(examples, negative_labels=[EntityLinker.NIL], **kwargs)

View File

@ -21,8 +21,6 @@ DEFAULT_ENT_ID_SEP = "||"
PatternType = Dict[str, Union[str, List[Dict[str, Any]]]] PatternType = Dict[str, Union[str, List[Dict[str, Any]]]]
def entity_ruler_score(examples, **kwargs): def entity_ruler_score(examples, **kwargs):
return get_ner_prf(examples) return get_ner_prf(examples)

View File

@ -18,8 +18,6 @@ from ..vocab import Vocab
from .pipe import Pipe from .pipe import Pipe
def lemmatizer_score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]: def lemmatizer_score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
return Scorer.score_token_attr(examples, "lemma", **kwargs) return Scorer.score_token_attr(examples, "lemma", **kwargs)

View File

@ -34,10 +34,6 @@ PatternType = Dict[str, Union[str, List[Dict[str, Any]]]]
DEFAULT_SPANS_KEY = "ruler" DEFAULT_SPANS_KEY = "ruler"
def prioritize_new_ents_filter( def prioritize_new_ents_filter(
entities: Iterable[Span], spans: Iterable[Span] entities: Iterable[Span], spans: Iterable[Span]
) -> List[Span]: ) -> List[Span]:

View File

@ -159,10 +159,6 @@ def build_preset_spans_suggester(spans_key: str) -> Suggester:
return partial(preset_spans_suggester, spans_key=spans_key) return partial(preset_spans_suggester, spans_key=spans_key)
def spancat_score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]: def spancat_score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
kwargs = dict(kwargs) kwargs = dict(kwargs)
attr_prefix = "spans_" attr_prefix = "spans_"

View File

@ -26,8 +26,6 @@ subword_features = true
DEFAULT_TOK2VEC_MODEL = Config().from_str(default_model_config)["model"] DEFAULT_TOK2VEC_MODEL = Config().from_str(default_model_config)["model"]
class Tok2Vec(TrainablePipe): class Tok2Vec(TrainablePipe):
"""Apply a "token-to-vector" model and set its outputs in the doc.tensor """Apply a "token-to-vector" model and set its outputs in the doc.tensor
attribute. This is mostly useful to share a single subnetwork between multiple attribute. This is mostly useful to share a single subnetwork between multiple

View File

@ -14,6 +14,7 @@ from spacy import util
original_is_same_func = util.is_same_func original_is_same_func = util.is_same_func
def patched_is_same_func(func1, func2): def patched_is_same_func(func1, func2):
# Handle Cython functions # Handle Cython functions
try: try:
@ -22,17 +23,22 @@ def patched_is_same_func(func1, func2):
# For Cython functions, just compare the string representation # For Cython functions, just compare the string representation
return str(func1) == str(func2) return str(func1) == str(func2)
util.is_same_func = patched_is_same_func util.is_same_func = patched_is_same_func
@pytest.fixture @pytest.fixture
def reference_factory_registrations(): def reference_factory_registrations():
"""Load reference factory registrations from JSON file""" """Load reference factory registrations from JSON file"""
if not REFERENCE_FILE.exists(): if not REFERENCE_FILE.exists():
pytest.fail(f"Reference file {REFERENCE_FILE} not found. Run export_factory_registrations.py first.") pytest.fail(
f"Reference file {REFERENCE_FILE} not found. Run export_factory_registrations.py first."
)
with REFERENCE_FILE.open("r") as f: with REFERENCE_FILE.open("r") as f:
return json.load(f) return json.load(f)
def test_factory_registrations_preserved(reference_factory_registrations): def test_factory_registrations_preserved(reference_factory_registrations):
"""Test that all factory registrations from the reference file are still present.""" """Test that all factory registrations from the reference file are still present."""
# Ensure the registry is populated # Ensure the registry is populated
@ -51,13 +57,17 @@ def test_factory_registrations_preserved(reference_factory_registrations):
module_name = func.__module__ module_name = func.__module__
except (AttributeError, TypeError): except (AttributeError, TypeError):
# For Cython functions, just use a placeholder # For Cython functions, just use a placeholder
module_name = str(func).split()[1].split('.')[0] module_name = str(func).split()[1].split(".")[0]
try: try:
func_name = func.__qualname__ func_name = func.__qualname__
except (AttributeError, TypeError): except (AttributeError, TypeError):
# For Cython functions, use the function's name # For Cython functions, use the function's name
func_name = func.__name__ if hasattr(func, "__name__") else str(func).split()[1].split('.')[-1] func_name = (
func.__name__
if hasattr(func, "__name__")
else str(func).split()[1].split(".")[-1]
)
current_registrations[name] = { current_registrations[name] = {
"name": name, "name": name,
@ -66,11 +76,19 @@ def test_factory_registrations_preserved(reference_factory_registrations):
} }
# Check for missing registrations # Check for missing registrations
missing_registrations = set(reference_factory_registrations.keys()) - set(current_registrations.keys()) missing_registrations = set(reference_factory_registrations.keys()) - set(
assert not missing_registrations, f"Missing factory registrations: {', '.join(sorted(missing_registrations))}" current_registrations.keys()
)
assert (
not missing_registrations
), f"Missing factory registrations: {', '.join(sorted(missing_registrations))}"
# Check for new registrations (not an error, but informative) # Check for new registrations (not an error, but informative)
new_registrations = set(current_registrations.keys()) - set(reference_factory_registrations.keys()) new_registrations = set(current_registrations.keys()) - set(
reference_factory_registrations.keys()
)
if new_registrations: if new_registrations:
# This is not an error, just informative # This is not an error, just informative
print(f"New factory registrations found: {', '.join(sorted(new_registrations))}") print(
f"New factory registrations found: {', '.join(sorted(new_registrations))}"
)

View File

@ -7,6 +7,7 @@ from spacy.util import registry
# Path to the reference registry contents, relative to this file # Path to the reference registry contents, relative to this file
REFERENCE_FILE = Path(__file__).parent / "registry_contents.json" REFERENCE_FILE = Path(__file__).parent / "registry_contents.json"
@pytest.fixture @pytest.fixture
def reference_registry(): def reference_registry():
"""Load reference registry contents from JSON file""" """Load reference registry contents from JSON file"""
@ -16,6 +17,7 @@ def reference_registry():
with REFERENCE_FILE.open("r") as f: with REFERENCE_FILE.open("r") as f:
return json.load(f) return json.load(f)
def test_registry_types(reference_registry): def test_registry_types(reference_registry):
"""Test that all registry types match the reference""" """Test that all registry types match the reference"""
# Get current registry types # Get current registry types
@ -26,6 +28,7 @@ def test_registry_types(reference_registry):
missing_types = expected_registry_types - current_registry_types missing_types = expected_registry_types - current_registry_types
assert not missing_types, f"Missing registry types: {', '.join(missing_types)}" assert not missing_types, f"Missing registry types: {', '.join(missing_types)}"
def test_registry_entries(reference_registry): def test_registry_entries(reference_registry):
"""Test that all registry entries are present""" """Test that all registry entries are present"""
# Check each registry's entries # Check each registry's entries
@ -45,4 +48,6 @@ def test_registry_entries(reference_registry):
# Check for missing entries - these would indicate our new registry population # Check for missing entries - these would indicate our new registry population
# mechanism is missing something # mechanism is missing something
missing_entries = expected_set - current_set missing_entries = expected_set - current_set
assert not missing_entries, f"Registry '{registry_name}' missing entries: {', '.join(missing_entries)}" assert (
not missing_entries
), f"Registry '{registry_name}' missing entries: {', '.join(missing_entries)}"

View File

@ -136,6 +136,7 @@ class registry(thinc.registry):
def ensure_populated(cls) -> None: def ensure_populated(cls) -> None:
"""Ensure the registry is populated with all necessary components.""" """Ensure the registry is populated with all necessary components."""
from .registrations import populate_registry, REGISTRY_POPULATED from .registrations import populate_registry, REGISTRY_POPULATED
if not REGISTRY_POPULATED: if not REGISTRY_POPULATED:
populate_registry() populate_registry()