2017-10-27 03:58:14 +03:00
|
|
|
|
#!/usr/bin/env python
|
|
|
|
|
# coding: utf8
|
|
|
|
|
"""Example of a spaCy v2.0 pipeline component that requests all countries via
|
|
|
|
|
the REST Countries API, merges country names into one token, assigns entity
|
|
|
|
|
labels and sets attributes on country tokens, e.g. the capital and lat/lng
|
|
|
|
|
coordinates. Can be extended with more details from the API.
|
|
|
|
|
|
|
|
|
|
* REST Countries API: https://restcountries.eu (Mozilla Public License MPL 2.0)
|
2017-11-07 14:00:43 +03:00
|
|
|
|
* Custom pipeline components: https://spacy.io/usage/processing-pipelines#custom-components
|
2017-10-27 03:58:14 +03:00
|
|
|
|
|
2017-11-07 03:22:30 +03:00
|
|
|
|
Compatible with: spaCy v2.0.0+
|
2019-03-16 16:15:49 +03:00
|
|
|
|
Last tested with: v2.1.0
|
2018-03-28 13:46:27 +03:00
|
|
|
|
Prerequisites: pip install requests
|
2017-10-27 03:58:14 +03:00
|
|
|
|
"""
|
2017-10-27 04:55:04 +03:00
|
|
|
|
from __future__ import unicode_literals, print_function
|
2017-10-10 05:26:06 +03:00
|
|
|
|
|
|
|
|
|
import requests
|
2017-10-27 03:58:14 +03:00
|
|
|
|
import plac
|
2017-10-10 05:26:06 +03:00
|
|
|
|
from spacy.lang.en import English
|
|
|
|
|
from spacy.matcher import PhraseMatcher
|
2017-10-11 03:30:40 +03:00
|
|
|
|
from spacy.tokens import Doc, Span, Token
|
2017-10-10 05:26:06 +03:00
|
|
|
|
|
|
|
|
|
|
2017-10-27 03:58:14 +03:00
|
|
|
|
def main():
    """Demonstrate the REST Countries component on a short sample sentence.

    Builds a blank English pipeline, adds the component, processes one text
    and prints the pipeline names, whether the Doc mentions a country, the
    custom attributes of every country token and the resulting entities.
    """
    # Keep things minimal: the bare English Language class, with no
    # statistical model and no pre-defined pipeline components.
    nlp = English()
    component = RESTCountriesComponent(nlp)  # requests the country data once
    nlp.add_pipe(component)

    doc = nlp("Some text about Colombia and the Czech Republic")
    print("Pipeline", nlp.pipe_names)  # the component's name shows up here
    print("Doc has countries", doc._.has_country)

    # Print the country data stored on each country token.
    country_tokens = (tok for tok in doc if tok._.is_country)
    for tok in country_tokens:
        print(
            tok.text,
            tok._.country_capital,
            tok._.country_latlng,
            tok._.country_flag,
        )

    entities = [(ent.text, ent.label_) for ent in doc.ents]
    print("Entities", entities)
|
2017-10-27 03:58:14 +03:00
|
|
|
|
|
2017-10-10 05:26:06 +03:00
|
|
|
|
|
2017-10-27 03:58:14 +03:00
|
|
|
|
class RESTCountriesComponent(object):
    """spaCy v2.0 pipeline component that requests all countries via
    the REST Countries API, merges country names into one token, assigns entity
    labels and sets attributes on country tokens.

    Token attributes written for every matched country: ``is_country``,
    ``country_capital``, ``country_latlng`` and ``country_flag`` (all default
    to False). Doc and Span get a ``has_country`` getter attribute.
    """

    name = "rest_countries"  # component name, will show up in the pipeline

    def __init__(self, nlp, label="GPE",
                 api_url="https://restcountries.eu/rest/v2/all", timeout=10):
        """Initialise the pipeline component. The shared nlp instance is used
        to initialise the matcher with the shared vocab, register the label
        and tokenize the country names into phrase match patterns.

        nlp (Language): The shared nlp object.
        label (unicode): Entity label assigned to matched countries. It is
            added to the string store, so custom labels work as well.
        api_url (unicode): Endpoint returning a JSON list of country objects
            with (at least) "name", "capital", "latlng" and "flag" keys.
        timeout (float or None): Seconds to wait for the API; None waits
            forever.
        RAISES: requests.HTTPError if the API answers with an error status,
            or another requests.RequestException on network errors/timeouts.
        """
        # Make the request once on initialisation and store the data.
        # NOTE(review): restcountries.eu may no longer be served; pass
        # api_url to point at an equivalent endpoint if the default fails.
        response = requests.get(api_url, timeout=timeout)
        response.raise_for_status()  # raise instead of silently continuing
        countries = response.json()

        # Dict keyed by country name for easy lookup. This could also be
        # extended using the alternative and foreign language names provided
        # by the API.
        self.countries = {c["name"]: c for c in countries}
        # strings.add() registers the label, so span.label_ can resolve it
        # even when it is not one of spaCy's built-in labels.
        self.label = nlp.vocab.strings.add(label)

        # Build one Doc pattern per country. make_doc only tokenizes, so any
        # components already added to nlp are not run for every name.
        # Country data is also keyed by the pattern's token texts: the
        # PhraseMatcher matches on those, not on exact whitespace, so
        # span.text can differ from the country name.
        patterns = []
        self._countries_by_tokens = {}
        for country_name, country in self.countries.items():
            pattern = nlp.make_doc(country_name)
            patterns.append(pattern)
            self._countries_by_tokens[tuple(t.text for t in pattern)] = country
        self.matcher = PhraseMatcher(nlp.vocab)
        self.matcher.add("COUNTRIES", None, *patterns)

        # Register attributes on the Token. They are overwritten based on the
        # matches, so only a default value is set, not a getter. Skip names
        # that already exist so a second instance does not re-register them.
        for attr in ("is_country", "country_capital", "country_latlng",
                     "country_flag"):
            if not Token.has_extension(attr):
                Token.set_extension(attr, default=False)

        # Doc and Span attributes use a getter that checks whether one of the
        # contained tokens has is_country == True.
        for cls in (Doc, Span):
            if not cls.has_extension("has_country"):
                cls.set_extension("has_country", getter=self.has_country)

    def __call__(self, doc):
        """Apply the pipeline component on a Doc object and modify it if matches
        are found. Return the Doc, so it can be processed by the next component
        in the pipeline, if available.

        Overlapping matches (e.g. a country name nested in a longer one) are
        resolved in favour of the longest span. Existing entities that overlap
        a country span are dropped; all other existing entities are kept.
        """
        matches = self.matcher(doc)
        candidates = [Span(doc, start, end, label=self.label)
                      for _, start, end in matches]
        spans = self._filter_overlapping(candidates)
        if not spans:
            return doc

        for span in spans:
            # Set custom attributes on each token of the entity. This can be
            # extended with other data returned by the API, like currencies,
            # country code, calling code etc.
            country = self._countries_by_tokens[tuple(t.text for t in span)]
            for token in span:
                token._.set("is_country", True)
                token._.set("country_capital", country["capital"])
                token._.set("country_latlng", country["latlng"])
                token._.set("country_flag", country["flag"])

        # Add the entities without replacing existing ones. Only existing
        # entities that conflict with a country span are removed, because
        # overlapping entities would make the assignment raise.
        covered = set()
        for span in spans:
            covered.update(range(span.start, span.end))
        kept = [ent for ent in doc.ents
                if covered.isdisjoint(range(ent.start, ent.end))]
        doc.ents = kept + spans

        # Merge each country into one token. This is done after setting the
        # entities; otherwise, it would cause mismatched indices!
        if hasattr(doc, "retokenize"):
            # Doc.retokenize (spaCy v2.1+) replaces Span.merge and merges all
            # spans in one batch.
            with doc.retokenize() as retokenizer:
                for span in spans:
                    retokenizer.merge(span)
        else:
            for span in spans:
                span.merge()
        return doc  # don't forget to return the Doc!

    @staticmethod
    def _filter_overlapping(spans):
        """Keep the longest non-overlapping spans, returned in document order.

        spans: Objects with integer ``start``/``end`` token offsets.
        Longer spans win; on equal length, the earlier span wins.
        """
        chosen = []
        taken = set()
        for span in sorted(spans, key=lambda s: (s.start - s.end, s.start)):
            indices = range(span.start, span.end)
            if taken.isdisjoint(indices):
                chosen.append(span)
                taken.update(indices)
        return sorted(chosen, key=lambda s: s.start)

    def has_country(self, tokens):
        """Getter for Doc and Span attributes. Returns True if one of the tokens
        is a country. Since the getter is only called when we access the
        attribute, we can refer to the Token's 'is_country' attribute here,
        which is already set in the processing step."""
        return any(t._.get("is_country") for t in tokens)
|
2017-10-10 05:26:06 +03:00
|
|
|
|
|
|
|
|
|
|
2018-12-02 06:26:26 +03:00
|
|
|
|
if __name__ == "__main__":
    # Script entry point: hand main() to plac, which derives a command-line
    # interface from its signature (main takes no arguments) and calls it.
    plac.call(main)
|
2017-10-10 05:26:06 +03:00
|
|
|
|
|
2017-10-27 03:58:14 +03:00
|
|
|
|
# Expected output:
|
|
|
|
|
# Pipeline ['rest_countries']
|
|
|
|
|
# Doc has countries True
|
|
|
|
|
# Colombia Bogotá [4.0, -72.0] https://restcountries.eu/data/col.svg
|
|
|
|
|
# Czech Republic Prague [49.75, 15.5] https://restcountries.eu/data/cze.svg
|
|
|
|
|
# Entities [('Colombia', 'GPE'), ('Czech Republic', 'GPE')]
|