diff --git a/README.md b/README.md index 710d8db..9362501 100644 --- a/README.md +++ b/README.md @@ -73,7 +73,7 @@ $ ./app/manage.py runserver }, { "value": "каучук", - "type": "Матерьял" + "type": "Материал" }, { "value": "синий", diff --git a/app/conf/settings/base.py b/app/conf/settings/base.py index 1c94469..27c6c59 100644 --- a/app/conf/settings/base.py +++ b/app/conf/settings/base.py @@ -216,3 +216,4 @@ REST_FRAMEWORK = { # django-cors-headers CORS_ALLOW_ALL_ORIGINS = True +YANDEX_DICT_API_KEY = env.str('YANDEX_DICT') diff --git a/app/search/services/search/bert.py b/app/search/services/search/bert.py new file mode 100644 index 0000000..f8243fd --- /dev/null +++ b/app/search/services/search/bert.py @@ -0,0 +1,25 @@ +from transformers import BertTokenizer, BertModel +import torch +import numpy as np + +from scipy.spatial import distance + +tokenizer = BertTokenizer.from_pretrained("DeepPavlov/rubert-base-cased-sentence") +model = BertModel.from_pretrained("DeepPavlov/rubert-base-cased-sentence") + + +def get_embedding(word): + inputs = tokenizer(word, return_tensors="pt") + outputs = model(**inputs) + word_vect = outputs.pooler_output.detach().numpy() + return word_vect + + +def get_distance(first_word, second_word): + w1 = get_embedding(first_word) + w2 = get_embedding(second_word) + cos_distance = np.round(distance.cosine(w1, w2), 2) + return 1 - cos_distance + + +get_distance("электрогитара", "электрическая гитара") diff --git a/app/search/services/search/main.py b/app/search/services/search/main.py index a0450f6..5e61583 100644 --- a/app/search/services/search/main.py +++ b/app/search/services/search/main.py @@ -21,7 +21,7 @@ def process_search(data: List[dict], limit=5, offset=0) -> List[dict]: qs = qs & apply_qs_search(val) qs = qs.order_by("-score") elif typ == "All": - qs = apply_all_qs_search(qs, val) & qs + qs = apply_all_qs_search(val) & qs elif typ == "Category": qs = apply_qs_category(qs, val) qs = qs.order_by("-score") @@ -35,4 +35,4 @@ def process_search(data: List[dict], limit=5, offset=0) -> List[dict]: qs = qs.filter(unit_characteristics__in=val) else: qs = qs.filter(characteristics__in=val) - return [x.serialize_self() for x in qs.distinct()[offset: offset + limit]] + return [x.serialize_self() for x in qs.distinct()[offset : offset + limit]] diff --git a/app/search/services/search/methods.py b/app/search/services/search/methods.py index e25e74a..e9368cd 100644 --- a/app/search/services/search/methods.py +++ b/app/search/services/search/methods.py @@ -1,11 +1,14 @@ +from functools import cache from typing import List +from django.utils.text import slugify + from search.models import ( Product, ProductCharacteristic, ProductUnitCharacteristic, ) -from search.services.spell_check import pos +from search.services.spell_check import pos, spell_check def _clean_text(text: str) -> List[str]: @@ -13,9 +16,11 @@ def _clean_text(text: str) -> List[str]: text = text.replace(st, " ") text = text.split() functors_pos = {"INTJ", "PRCL", "CONJ", "PREP"} # function words - return [word for word in text if pos(word) not in functors_pos] + text = [word for word in text if pos(word) not in functors_pos] + return [spell_check(x) for x in text] +@cache def process_unit_operation(unit: ProductUnitCharacteristic.objects, operation: str): if operation.startswith("<=") or operation.startswith("=<"): return unit.filter( @@ -41,20 +46,20 @@ def process_unit_operation(unit: ProductUnitCharacteristic.objects, operation: s return unit +@cache def apply_qs_search(text: str): text = _clean_text(text) - products = Product.objects.none() + qs = Product.objects.filter() for word in text: - products = ( - products - | Product.objects.filter(name__unaccent__icontains=word) - | Product.objects.filter(name__unaccent__trigram_similar=word) + qs = qs.filter(name__unaccent__trigram_similar=word) | qs.filter( + name__unaccent__icontains=word ) - products = products.order_by("-score") + products = qs.order_by("-score") return products -def apply_all_qs_search(orig_qs, text: str): +@cache +def apply_all_qs_search(text: str): # words text = _clean_text(text) @@ -105,9 +110,23 @@ def apply_all_qs_search(orig_qs, text: str): ) qs = ( Product.objects.filter(name__icontains=word) + | Product.objects.filter(name__trigram_similar=word) | Product.objects.filter(category__name__icontains=word) | Product.objects.filter(characteristics__in=car) ) + if any( + x in word + for x in "абвгдеёжзийклмнопрстуфхцчшщъыьэюяАБВГДЕЁЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЫЬЭЮЯ" + ): + qs = qs | Product.objects.filter( + name__icontains=word.translate( + str.maketrans( + "абвгдеёжзийклмнопрстуфхцчшщъыьэюяАБВГДЕЁЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЫЬЭЮЯ", + "abvgdeejzijklmnoprstufhzcss_y_euaABVGDEEJZIJKLMNOPRSTUFHZCSS_Y_EUA", + ) + ) + ) + print(qs) prod = prod & qs if u_qs: @@ -116,11 +135,13 @@ def apply_all_qs_search(orig_qs, text: str): return prod +@cache def apply_qs_category(qs, name: str): qs = qs.filter(category__name__icontains=name) return qs +@cache def appy_qs_characteristic(qs, name: str): char = ProductCharacteristic.objects.filter(product__in=qs) char = char.filter(characteristic__value__icontains=name) | char.filter( diff --git a/app/search/services/search/prepare.py b/app/search/services/search/prepare.py index 6e19e80..22f601d 100644 --- a/app/search/services/search/prepare.py +++ b/app/search/services/search/prepare.py @@ -2,11 +2,15 @@ from typing import List, Dict from rest_framework.exceptions import ValidationError -from search.models import Characteristic, ProductCharacteristic, ProductUnitCharacteristic, UnitCharacteristic +from search.models import ( + Characteristic, + ProductCharacteristic, + ProductUnitCharacteristic, + UnitCharacteristic, +) from search.services.hints import get_hints from search.services.search.methods import process_unit_operation -) -from search.services.spell_check import spell_check_ru as spell_check +from search.services.spell_check import spell_check def apply_union(data: List[Dict]) -> List[Dict]: diff --git a/requirements/base.txt b/requirements/base.txt index bab147f..f44dd90 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -13,3 +13,8 @@ celery==5.2.7 pyspellchecker==0.7.0 pymorphy2 + +transformers +torch +scipy +numpy