# coding: utf8
from __future__ import unicode_literals

import plac
import math
from tqdm import tqdm
import numpy
from ast import literal_eval
from pathlib import Path
from preshed.counter import PreshCounter
import tarfile
import gzip
import zipfile
import ujson as json
from spacy.lexeme import intify_attrs

from ._messages import Messages
from ..vectors import Vectors
from ..errors import Errors, Warnings, user_warning
from ..util import prints, ensure_path, get_lang_class

try:
    import ftfy
except ImportError:
    ftfy = None


@plac.annotations(
    lang=("model language", "positional", None, str),
    output_dir=("model output directory", "positional", None, Path),
    freqs_loc=("location of word frequencies file", "option", "f", Path),
    jsonl_loc=("location of JSONL-formatted attributes file", "option", "j", Path),
    clusters_loc=("optional: location of Brown clusters data",
                  "option", "c", str),
    vectors_loc=("optional: location of vectors file in Word2Vec format "
                 "(either as .txt or zipped as .zip or .tar.gz)", "option",
                 "v", str),
    prune_vectors=("optional: number of vectors to prune to",
                   "option", "V", int)
)
def init_model(lang, output_dir, freqs_loc=None, clusters_loc=None,
               jsonl_loc=None, vectors_loc=None, prune_vectors=-1):
    """
    Create a new model from raw data, like word frequencies, Brown clusters
    and word vectors.
    """
    if jsonl_loc is not None:
        if freqs_loc is not None or clusters_loc is not None:
            settings = ['-j']
            if freqs_loc:
                settings.append('-f')
            if clusters_loc:
                settings.append('-c')
            prints(' '.join(settings),
                   title=(
                       "The -f and -c arguments are deprecated, and not "
                       "compatible with the -j argument, which should specify "
                       "the same information. Either merge the frequencies "
                       "and clusters data into the JSONL-formatted file "
                       "(recommended), or use only the -f and -c files, "
                       "without the other lexical attributes."))
        jsonl_loc = ensure_path(jsonl_loc)
        lex_attrs = (json.loads(line) for line in jsonl_loc.open())
    else:
        clusters_loc = ensure_path(clusters_loc)
        freqs_loc = ensure_path(freqs_loc)
        if freqs_loc is not None and not freqs_loc.exists():
            prints(freqs_loc, title=Messages.M037, exits=1)
        lex_attrs = read_attrs_from_deprecated(freqs_loc, clusters_loc)

    nlp = create_model(lang, lex_attrs)
    if vectors_loc is not None:
        add_vectors(nlp, vectors_loc, prune_vectors)
    vec_added = len(nlp.vocab.vectors)
    lex_added = len(nlp.vocab)
    prints(Messages.M039.format(entries=lex_added, vectors=vec_added),
           title=Messages.M038)
    if not output_dir.exists():
        output_dir.mkdir()
    nlp.to_disk(output_dir)
    return nlp
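
# Usage sketch (paths are hypothetical; assumes this function is wired up
# as the `init-model` subcommand in spacy/cli):
#
#     python -m spacy init-model en /tmp/en_model -j lex_attrs.jsonl -v vectors.zip
#
# or programmatically:
#
#     nlp = init_model('en', Path('/tmp/en_model'), jsonl_loc=Path('lex_attrs.jsonl'))
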
def open_file(loc):
    """Handle .gz, .tar.gz or unzipped files."""
    loc = ensure_path(loc)
    if tarfile.is_tarfile(str(loc)):
        # Read the first regular file in the archive, mirroring the
        # zip branch below, so callers always get lines back.
        tar_file = tarfile.open(str(loc), 'r:gz')
        members = [m for m in tar_file.getmembers() if m.isfile()]
        file_ = tar_file.extractfile(members[0])
        return (line.decode('utf8') for line in file_)
    elif loc.parts[-1].endswith('gz'):
        return (line.decode('utf8') for line in gzip.open(str(loc), 'r'))
    elif loc.parts[-1].endswith('zip'):
        zip_file = zipfile.ZipFile(str(loc))
        names = zip_file.namelist()
        file_ = zip_file.open(names[0])
        return (line.decode('utf8') for line in file_)
    else:
        return loc.open('r', encoding='utf8')
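
# With the handling above, every branch yields an iterable of unicode
# lines, so callers can iterate uniformly (path is hypothetical):
#
#     for line in open_file('/tmp/vectors.txt.gz'):
#         word = line.split(' ', 1)[0]
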
def read_attrs_from_deprecated(freqs_loc, clusters_loc):
    probs, oov_prob = read_freqs(freqs_loc) if freqs_loc is not None else ({}, -20)
    clusters = read_clusters(clusters_loc) if clusters_loc else {}
    lex_attrs = []
    sorted_probs = sorted(probs.items(), key=lambda item: item[1], reverse=True)
    for i, (word, prob) in tqdm(enumerate(sorted_probs)):
        attrs = {'orth': word, 'id': i, 'prob': prob}
        # Decode as a little-endian string, so that we can do & 15 to get
        # the first 4 bits. See _parse_features.pyx
        if word in clusters:
            attrs['cluster'] = int(clusters[word][::-1], 2)
        else:
            attrs['cluster'] = 0
        lex_attrs.append(attrs)
    return lex_attrs
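
# Worked example of the little-endian trick above: the Brown path '1011'
# reverses to '1101', and int('1101', 2) == 13 == 0b1101, so bit k of the
# integer holds character k of the original path; `13 & 15` therefore
# reads off the first four path bits.
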
def create_model(lang, lex_attrs):
    print("Creating model...")
    lang_class = get_lang_class(lang)
    nlp = lang_class()
    for lexeme in nlp.vocab:
        lexeme.rank = 0
    lex_added = 0
    for attrs in lex_attrs:
        if 'settings' in attrs:
            continue
        lexeme = nlp.vocab[attrs['orth']]
        lexeme.set_attrs(**attrs)
        lexeme.is_oov = False
        lex_added += 1
    oov_prob = min(lex.prob for lex in nlp.vocab)
    nlp.vocab.cfg.update({'oov_prob': oov_prob - 1})
    return nlp
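
# Each entry consumed above is a plain dict of lexeme attributes, e.g. the
# output of read_attrs_from_deprecated() or one line of the -j JSONL file
# (values here are illustrative):
#
#     {"orth": "the", "id": 0, "prob": -3.2, "cluster": 6}
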
def add_vectors(nlp, vectors_loc, prune_vectors):
    vectors_loc = ensure_path(vectors_loc)
    if vectors_loc and vectors_loc.parts[-1].endswith('.npz'):
        # Load a pre-computed numpy array; rows are keyed by lexeme rank.
        nlp.vocab.vectors = Vectors(data=numpy.load(vectors_loc.open('rb')))
        for lex in nlp.vocab:
            if lex.rank:
                nlp.vocab.vectors.add(lex.orth, row=lex.rank)
    else:
        vectors_data, vector_keys = read_vectors(vectors_loc) if vectors_loc else (None, None)
        if vector_keys is not None:
            for word in vector_keys:
                if word not in nlp.vocab:
                    lexeme = nlp.vocab[word]
                    lexeme.is_oov = False
        if vectors_data is not None:
            nlp.vocab.vectors = Vectors(data=vectors_data, keys=vector_keys)
    nlp.vocab.vectors.name = '%s_model.vectors' % nlp.meta['lang']
    nlp.meta['vectors']['name'] = nlp.vocab.vectors.name
    if prune_vectors >= 1:
        nlp.vocab.prune_vectors(prune_vectors)
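
# Sketch (hypothetical path): attach vectors and keep only the 20000 most
# frequent rows; in spaCy 2.x, Vocab.prune_vectors() remaps the pruned
# words to their nearest remaining vector:
#
#     add_vectors(nlp, 'vectors.txt.gz', prune_vectors=20000)
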
def read_vectors(vectors_loc):
    print("Reading vectors from %s" % vectors_loc)
    f = open_file(vectors_loc)
    # The first line gives the shape: number of vectors and dimensionality.
    shape = tuple(int(size) for size in next(f).split())
    vectors_data = numpy.zeros(shape=shape, dtype='f')
    vectors_keys = []
    for i, line in enumerate(tqdm(f)):
        line = line.rstrip()
        pieces = line.rsplit(' ', vectors_data.shape[1] + 1)
        word = pieces.pop(0)
        if len(pieces) != vectors_data.shape[1]:
            raise ValueError(Errors.E094.format(line_num=i, loc=vectors_loc))
        vectors_data[i] = numpy.asarray(pieces, dtype='f')
        vectors_keys.append(word)
    return vectors_data, vectors_keys
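
# Expected Word2Vec text format (values illustrative): a header line with
# "<n_vectors> <n_dims>", then one whitespace-separated entry per word:
#
#     2 3
#     apple 0.1 0.2 0.3
#     pear -0.4 0.5 0.6
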
def read_freqs(freqs_loc, max_length=100, min_doc_freq=5, min_freq=50):
    print("Counting frequencies...")
    counts = PreshCounter()
    total = 0
    with freqs_loc.open() as f:
        for i, line in enumerate(f):
            freq, doc_freq, key = line.rstrip().split('\t', 2)
            freq = int(freq)
            counts.inc(i + 1, freq)
            total += freq
    counts.smooth()
    log_total = math.log(total)
    probs = {}
    with freqs_loc.open() as f:
        for line in tqdm(f):
            freq, doc_freq, key = line.rstrip().split('\t', 2)
            doc_freq = int(doc_freq)
            freq = int(freq)
            if doc_freq >= min_doc_freq and freq >= min_freq and len(key) < max_length:
                # The key is a quoted string literal, so parse it safely.
                word = literal_eval(key)
                smooth_count = counts.smoother(int(freq))
                probs[word] = math.log(smooth_count) - log_total
    oov_prob = math.log(counts.smoother(0)) - log_total
    return probs, oov_prob
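
# Expected frequencies format (values illustrative): tab-separated corpus
# frequency, document frequency, and the word as a Python string literal
# (hence literal_eval above):
#
#     254564	4532	'the'
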
def read_clusters(clusters_loc):
    print("Reading clusters...")
    clusters = {}
    if ftfy is None:
        user_warning(Warnings.W004)
    with clusters_loc.open() as f:
        for line in tqdm(f):
            try:
                cluster, word, freq = line.split()
                if ftfy is not None:
                    word = ftfy.fix_text(word)
            except ValueError:
                continue
            # If the clusterer has only seen the word a few times, its
            # cluster is unreliable.
            if int(freq) >= 3:
                clusters[word] = cluster
            else:
                clusters[word] = '0'
    # Expand clusters with re-casing
    for word, cluster in list(clusters.items()):
        if word.lower() not in clusters:
            clusters[word.lower()] = cluster
        if word.title() not in clusters:
            clusters[word.title()] = cluster
        if word.upper() not in clusters:
            clusters[word.upper()] = cluster
    return clusters
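
# Expected Brown clusters format (values illustrative), as produced by
# e.g. the brown-cluster tool: bit-string path, word, and frequency,
# whitespace-separated:
#
#     10110	apple	64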