Bloom-filter backed Lookup Tables (#4268)

* Improve load_language_data helper

* WIP: Add Lookups implementation

* Start moving lemma data over to JSON

* WIP: move data over for more languages

* Convert more languages

* Fix lemmatizer fixtures in tests

* Finish conversion

* Auto-format JSON files

* Fix test for now

* Make sure tables are stored on instance

* Update docstrings

* Update docstrings and errors

* Update test

* Add Lookups.__len__

* Add serialization methods

* Add Lookups.remove_table

* Use msgpack for serialization to disk

* Fix file exists check

* Try using OrderedDict for everything

* Update .flake8 [ci skip]

* Try fixing serialization

* Update test_lookups.py

* Update test_serialize_vocab_strings.py

* Lookups / Tables now work

This implements the stubs in the Lookups/Table classes. Currently this
is in Cython but with no type declarations, so that could be improved.

* Add lookups to setup.py

* Actually add lookups pyx

The previous commit added the old py file...

* Lookups work-in-progress

* Move from pyx back to py

* Add string based lookups, fix serialization

* Update tests, language/lemmatizer to work with string lookups

There are some outstanding issues here:

- a pickling-related test fails due to the bloom filter
- some custom lemmatizers (fr/nl at least) have issues

More generally, there's a question of how to deal with the case where
you have a string but want to use the lookup table. Currently the table
allows access by string or id, but that's getting pretty awkward.
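
To illustrate the dual access pattern, here is a minimal sketch using the Table API introduced in this commit (string keys are stored as hashes; `hash_string` comes from `spacy.strings`):

    from spacy.lookups import Table
    from spacy.strings import hash_string

    table = Table(data={"dogs": "dog"})
    # The same entry can be reached through the hashed id...
    assert table.get(hash_string("dogs")) == "dog"
    # ...or through the string convenience method, which hashes for you.
    assert table.get_string("dogs") == "dog"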

* Change lemmatizer lookup method to pass (orth, string)

* Fix token lookup

* Fix French lookup

* Fix lt lemmatizer test

* Fix Dutch lemmatizer

* Fix lemmatizer lookup test

This test was using a normal dict instead of a Table, so membership checks
looked for the raw string instead of the hashed integer key and failed.

* Make uk/nl/ru lemmatizer lookup methods consistent

The lemmatizers for these languages each have their own implementation of
the `lookup` method, which accesses a `Lookups` table. The way this method
is called in `token.pyx` has changed, so they are updated to take the same
arguments as `lookup` in `lemmatizer.py` (specifically (orth/id, string)).

Prior to this change the tests weren't failing, but there would probably
have been issues with normal use of a model. More tests should probably be added.

Additionally, the language-specific `lookup` implementations seem like
they might not be needed, since they handle things like lower-casing
that aren't actually language specific.
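
As a sketch of the unified signature (it reuses the `Lemmatizer(lookup=...)` fixture pattern from the tests below, so nothing here is invented API):

    from spacy.lemmatizer import Lemmatizer
    from spacy.lookups import Table
    from spacy.strings import hash_string

    lemmatizer = Lemmatizer(lookup=Table(data={"dogs": "dog"}))
    # Every lookup() now takes (orth, string); a miss falls back to the string.
    assert lemmatizer.lookup(hash_string("dogs"), "dogs") == "dog"
    assert lemmatizer.lookup(hash_string("cats"), "cats") == "cats"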

* Make recently added Greek method compatible

* Remove redundant class/method

Leftovers from a merge not cleaned up adequately.
Authored by Paul O'Leary McCann on 2019-09-13 00:26:11 +09:00; committed by Matthew Honnibal
parent 7d782aa97b
commit 7d8df69158
15 changed files with 126 additions and 53 deletions


@@ -46,9 +46,9 @@ class GreekLemmatizer(object):
         )
         return lemmas
 
-    def lookup(self, string):
-        if string in self.lookup_table:
-            return self.lookup_table[string]
+    def lookup(self, orth, string):
+        if orth in self.lookup_table:
+            return self.lookup_table[orth]
         return string


@@ -52,7 +52,7 @@ class FrenchLemmatizer(object):
         elif univ_pos in (SCONJ, "SCONJ", "sconj"):
             univ_pos = "sconj"
         else:
-            return [self.lookup(string)]
+            return [self.lookup(None, string)]
         # See Issue #435 for example of where this logic is requied.
         if self.is_base_form(univ_pos, morphology):
             return list(set([string.lower()]))
@@ -114,9 +114,9 @@ class FrenchLemmatizer(object):
     def punct(self, string, morphology=None):
         return self(string, "punct", morphology)
 
-    def lookup(self, string):
-        if string in self.lookup_table:
-            return self.lookup_table[string][0]
+    def lookup(self, orth, string):
+        if orth is not None and orth in self.lookup_table:
+            return self.lookup_table[orth][0]
         return string


@@ -62,11 +62,11 @@ class DutchLemmatizer(object):
         # are not lemmatized. They are lowercased, however.
             return [string]
         # if string in self.lemma_index.get(univ_pos)
-        lemma_index = self.index.get(univ_pos, {})
+        lemma_index = self.index.get_string(univ_pos, {})
         # string is already lemma
         if string in lemma_index:
             return [string]
-        exceptions = self.exc.get(univ_pos, {})
+        exceptions = self.exc.get_string(univ_pos, {})
         # string is irregular token contained in exceptions index.
         try:
             lemma = exceptions[string]
@@ -75,12 +75,12 @@ class DutchLemmatizer(object):
             pass
         # string corresponds to key in lookup table
         lookup_table = self.lookup_table
-        looked_up_lemma = lookup_table.get(string)
+        looked_up_lemma = lookup_table.get_string(string)
         if looked_up_lemma and looked_up_lemma in lemma_index:
             return [looked_up_lemma]
 
         forms, is_known = lemmatize(
-            string, lemma_index, exceptions, self.rules.get(univ_pos, [])
+            string, lemma_index, exceptions, self.rules.get_string(univ_pos, [])
         )
 
         # Back-off through remaining return value candidates.
@@ -103,9 +103,12 @@ class DutchLemmatizer(object):
     # Overrides parent method so that a lowercased version of the string is
     # used to search the lookup table. This is necessary because our lookup
     # table consists entirely of lowercase keys.
-    def lookup(self, string):
+    def lookup(self, orth, string):
         string = string.lower()
-        return self.lookup_table.get(string, string)
+        if orth is not None:
+            return self.lookup_table.get(orth, string)
+        else:
+            return self.lookup_table.get_string(string, string)
 
     def noun(self, string, morphology=None):
         return self(string, "noun", morphology)


@@ -115,7 +115,7 @@ class RussianLemmatizer(Lemmatizer):
     def pron(self, string, morphology=None):
         return self(string, "pron", morphology)
 
-    def lookup(self, string):
+    def lookup(self, orth, string):
         analyses = self._morph.parse(string)
         if len(analyses) == 1:
             return analyses[0].normal_form


@@ -112,7 +112,7 @@ class UkrainianLemmatizer(Lemmatizer):
     def pron(self, string, morphology=None):
         return self(string, "pron", morphology)
 
-    def lookup(self, string):
+    def lookup(self, orth, string):
         analyses = self._morph.parse(string)
         if len(analyses) == 1:
             return analyses[0].normal_form


@@ -32,6 +32,7 @@ from .lang.tokenizer_exceptions import TOKEN_MATCH
 from .lang.tag_map import TAG_MAP
 from .lang.lex_attrs import LEX_ATTRS, is_stop
 from .errors import Errors, Warnings, deprecation_warning
+from .strings import hash_string
 from . import util
 from . import about


@@ -93,9 +93,9 @@ class Lemmatizer(object):
     def punct(self, string, morphology=None):
         return self(string, "punct", morphology)
 
-    def lookup(self, string):
-        if string in self.lookup_table:
-            return self.lookup_table[string]
+    def lookup(self, orth, string):
+        if orth in self.lookup_table:
+            return self.lookup_table[orth]
         return string


@@ -1,4 +1,4 @@
-# coding: utf8
+# coding: utf-8
 from __future__ import unicode_literals
 
 import srsly
@@ -6,7 +6,12 @@ from collections import OrderedDict
 
 from .errors import Errors
 from .util import SimpleFrozenDict, ensure_path
+from .strings import hash_string
+from . import util
+import srsly
+
+from preshed.bloom import BloomFilter
 
 
 class Lookups(object):
     """Container for large lookup tables and dictionaries, e.g. lemmatization
@@ -14,10 +19,6 @@ class Lookups(object):
     so they can be accessed before the pipeline components are applied (e.g.
     in the tokenizer and lemmatizer), as well as within the pipeline components
     via doc.vocab.lookups.
-
-    Important note: At the moment, this class only performs a very basic
-    dictionary lookup. We're planning to replace this with a more efficient
-    implementation. See #3971 for details.
     """
 
     def __init__(self):
@@ -54,8 +55,7 @@ class Lookups(object):
         """
         if name in self.tables:
             raise ValueError(Errors.E158.format(name=name))
-        table = Table(name=name)
-        table.update(data)
+        table = Table(name=name, data=data)
         self._tables[name] = table
         return table
@@ -100,10 +100,9 @@ class Lookups(object):
         bytes_data (bytes): The data to load.
         RETURNS (Lookups): The loaded Lookups.
         """
-        self._tables = OrderedDict()
-        msg = srsly.msgpack_loads(bytes_data)
-        for key, value in msg.items():
-            self._tables[key] = Table.from_dict(value)
+        for key, value in srsly.msgpack_loads(bytes_data).items():
+            self._tables[key] = Table(key)
+            self._tables[key].update_raw(value)
         return self
 
     def to_disk(self, path, **kwargs):
@@ -137,8 +136,10 @@
 
 
 class Table(OrderedDict):
-    """A table in the lookups. Subclass of OrderedDict that implements a
+    """A table in the lookups. Subclass of builtin dict that implements a
     slightly more consistent and unified API.
+
+    Includes a Bloom filter to speed up missed lookups.
     """
 
     @classmethod
@@ -153,15 +154,81 @@ class Table(OrderedDict):
         self.update(data)
         return self
 
-    def __init__(self, name=None):
+    def __init__(self, name=None, data=None):
         """Initialize a new table.
 
         name (unicode): Optional table name for reference.
+        data (dict): Initial data, used to hint Bloom Filter.
         RETURNS (Table): The newly created object.
         """
         OrderedDict.__init__(self)
         self.name = name
+        # assume a default size of 1M items
+        size = 1E6
+        if data and len(data) > 0:
+            size = len(data)
+        self.bloom = BloomFilter.from_error_rate(size)
+        if data:
+            self.update(data)
 
     def set(self, key, value):
-        """Set new key/value pair. Same as table[key] = value."""
+        """Set new key/value pair, where key is an integer. Same as
+        table[key] = value.
+        """
         self[key] = value
+
+    def __setitem__(self, key, value):
+        OrderedDict.__setitem__(self, key, value)
+        self.bloom.add(key)
+
+    def set_string(self, key, value):
+        """Set new key/value pair, where key is a string to be hashed.
+        """
+        hkey = hash_string(key)
+        self.set(hkey, value)
+
+    def update(self, data):
+        """Add entries in a dict-like to the table, where keys are strings to
+        be hashed.
+        """
+        for key, val in data.items():
+            self.set_string(key, val)
+
+    def update_raw(self, data):
+        """Add entries in a dict-like to the table, where keys are ints.
+        """
+        for key, val in data.items():
+            self.set(key, val)
+
+    def get(self, key, default=None):
+        return OrderedDict.get(self, key, default)
+
+    def get_string(self, key, default=None):
+        hkey = hash_string(key)
+        return OrderedDict.get(self, hkey, default)
+
+    def __contains__(self, key):
+        # This can give a false positive, so we need to check it after
+        if key not in self.bloom:
+            return False
+        return OrderedDict.__contains__(self, key)
+
+    def contains_string(self, key):
+        hkey = hash_string(key)
+        return self.__contains__(hkey)
+
+    def to_bytes(self):
+        # TODO: serialize bloom too. For now just reconstruct it.
+        return srsly.msgpack_dumps({'name': self.name, 'dict': dict(self.items())})
+
+    def from_bytes(self, data):
+        loaded = srsly.msgpack_loads(data)
+        self.name = loaded['name']
+        for key, val in loaded['dict'].items():
+            self[key] = val
+            self.bloom.add(key)
+        return self
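
Putting the new classes together, a usage sketch of the API defined above (serialization round-trips through msgpack, and the Bloom filter is rebuilt on load; this mirrors the tests below):

    from spacy.lookups import Lookups

    lookups = Lookups()
    table = lookups.add_table("lemma_lookup", {"foo": "bar"})
    assert table.get_string("foo") == "bar"
    # Missed lookups are cheap: the Bloom filter usually answers "no"
    # without touching the underlying dict.
    assert not table.contains_string("baz")
    new_lookups = Lookups().from_bytes(lookups.to_bytes())
    assert new_lookups.get_table("lemma_lookup").get_string("foo") == "bar"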


@@ -273,7 +273,7 @@ cdef class Morphology:
         """
         if token.lemma == 0:
             orth_str = self.strings[token.lex.orth]
-            lemma = self.lemmatizer.lookup(orth_str)
+            lemma = self.lemmatizer.lookup(token.lex.orth, orth_str)
             token.lemma = self.strings.add(lemma)
 
     cdef int assign_tag(self, TokenC* token, tag_str) except -1:


@@ -5,11 +5,13 @@ import pytest
 from spacy.vocab import Vocab
 from spacy.tokens import Doc
 from spacy.lemmatizer import Lemmatizer
+from spacy.lookups import Table
 
 
 @pytest.fixture
 def lemmatizer():
-    return Lemmatizer(lookup={"dogs": "dog", "boxen": "box", "mice": "mouse"})
+    lookup = Table(data={"dogs": "dog", "boxen": "box", "mice": "mouse"})
+    return Lemmatizer(lookup=lookup)
 
 
 @pytest.fixture


@@ -17,4 +17,4 @@ TEST_CASES = [
 
 @pytest.mark.parametrize("tokens,lemmas", TEST_CASES)
 def test_lt_lemmatizer(lt_lemmatizer, tokens, lemmas):
-    assert lemmas == [lt_lemmatizer.lookup(token) for token in tokens]
+    assert lemmas == [lt_lemmatizer.lookup_table.get_string(token, token) for token in tokens]


@@ -133,11 +133,11 @@ def test_nl_lemmatizer_pronoun_lemmas(nl_lemmatizer, text, lemma):
 # Using the lemma lookup table only
 @pytest.mark.parametrize("text,lemma", noun_irreg_lemmatization_cases)
 def test_nl_lemmatizer_lookup_noun(nl_lemmatizer, text, lemma):
-    lemma_pred = nl_lemmatizer.lookup(text)
+    lemma_pred = nl_lemmatizer.lookup(None, text)
     assert lemma_pred in (lemma, text)
 
 
 @pytest.mark.parametrize("text,lemma", verb_irreg_lemmatization_cases)
 def test_nl_lemmatizer_lookup_verb(nl_lemmatizer, text, lemma):
-    lemma_pred = nl_lemmatizer.lookup(text)
+    lemma_pred = nl_lemmatizer.lookup(None, text)
     assert lemma_pred in (lemma, text)


@@ -19,9 +19,9 @@ def test_lookups_api():
     table = lookups.get_table(table_name)
     assert table.name == table_name
     assert len(table) == 2
-    assert table.get("hello") == "world"
-    table.set("a", "b")
-    assert table.get("a") == "b"
+    assert table.get_string("hello") == "world"
+    table.set_string("a", "b")
+    assert table.get_string("a") == "b"
     table = lookups.get_table(table_name)
     assert len(table) == 3
     with pytest.raises(KeyError):
@@ -50,10 +50,10 @@ def test_lookups_to_from_bytes():
     assert "table2" in new_lookups
     table1 = new_lookups.get_table("table1")
     assert len(table1) == 2
-    assert table1.get("foo") == "bar"
+    assert table1.get_string("foo") == "bar"
     table2 = new_lookups.get_table("table2")
     assert len(table2) == 3
-    assert table2.get("b") == 2
+    assert table2.get_string("b") == 2
     assert new_lookups.to_bytes() == lookups_bytes
@@ -72,10 +72,11 @@ def test_lookups_to_from_disk():
     assert "table2" in new_lookups
     table1 = new_lookups.get_table("table1")
     assert len(table1) == 2
-    assert table1.get("foo") == "bar"
+    assert table1.get_string("foo") == "bar"
     table2 = new_lookups.get_table("table2")
     assert len(table2) == 3
-    assert table2.get("b") == 2
+    assert table2.get_string("b") == 2
 
 
 # This fails on Python 3.5
@@ -93,10 +94,9 @@ def test_lookups_to_from_bytes_via_vocab():
     assert table_name in new_vocab.lookups
     table = new_vocab.lookups.get_table(table_name)
     assert len(table) == 2
-    assert table.get("hello") == "world"
+    assert table.get_string("hello") == "world"
     assert new_vocab.to_bytes() == vocab_bytes
 
 
 # This fails on Python 3.5
 @pytest.mark.xfail
 def test_lookups_to_from_disk_via_vocab():
@@ -113,4 +113,4 @@ def test_lookups_to_from_disk_via_vocab():
     assert table_name in new_vocab.lookups
     table = new_vocab.lookups.get_table(table_name)
     assert len(table) == 2
-    assert table.get("hello") == "world"
+    assert table.get_string("hello") == "world"


@@ -335,7 +335,7 @@ cdef class Token:
         """
         def __get__(self):
             if self.c.lemma == 0:
-                lemma_ = self.vocab.morphology.lemmatizer.lookup(self.orth_)
+                lemma_ = self.vocab.morphology.lemmatizer.lookup(self.orth, self.orth_)
                 return self.vocab.strings[lemma_]
             else:
                 return self.c.lemma
@@ -862,7 +862,7 @@ cdef class Token:
         """
         def __get__(self):
             if self.c.lemma == 0:
-                return self.vocab.morphology.lemmatizer.lookup(self.orth_)
+                return self.vocab.morphology.lemmatizer.lookup(self.orth, self.orth_)
             else:
                 return self.vocab.strings[self.c.lemma]


@@ -18,10 +18,10 @@ from .structs cimport SerializedLexemeC
 from .compat import copy_reg, basestring_
 from .errors import Errors
 from .lemmatizer import Lemmatizer
+from .lookups import Lookups
 from .attrs import intify_attrs, NORM
 from .vectors import Vectors
 from ._ml import link_vectors_to_models
-from .lookups import Lookups
 from . import util