added save route

This commit is contained in:
Alexander Karpov 2023-05-27 01:03:34 +03:00
parent 71432809ad
commit 60a019fcf7
8 changed files with 415 additions and 98 deletions

View File

@ -1,3 +1,5 @@
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import extend_schema_field
from rest_framework import serializers from rest_framework import serializers
from rest_framework.generics import get_object_or_404 from rest_framework.generics import get_object_or_404
@ -9,6 +11,8 @@
BasePoint, BasePoint,
Region, Region,
Restaurant, Restaurant,
UserRoute,
UserRouteDate,
) )
@ -78,20 +82,17 @@ class Meta:
class InputRoutePointSerializer(serializers.Serializer): class InputRoutePointSerializer(serializers.Serializer):
type = serializers.ChoiceField(choices=["point", "transition"]) type = serializers.ChoiceField(choices=["point", "transition"])
time = serializers.IntegerField(min_value=0, required=True) duration = serializers.IntegerField(min_value=0, required=True)
# point # point
point = serializers.CharField( point = serializers.CharField(
min_length=24, max_length=24, required=False, allow_blank=True, allow_null=True min_length=24, max_length=24, required=False, allow_blank=True, allow_null=True
) )
point_type = serializers.CharField(
required=False, allow_blank=True, allow_null=True
)
# transition # transition
point_from = serializers.CharField(
min_length=24, max_length=24, required=False, allow_blank=True, allow_null=True
)
point_to = serializers.CharField(
min_length=24, max_length=24, required=False, allow_blank=True, allow_null=True
)
distance = serializers.FloatField(min_value=0, required=False, allow_null=True) distance = serializers.FloatField(min_value=0, required=False, allow_null=True)
def validate(self, data): def validate(self, data):
@ -99,24 +100,49 @@ def validate(self, data):
if "point" not in data or not data["point"]: if "point" not in data or not data["point"]:
raise serializers.ValidationError("Point id is required") raise serializers.ValidationError("Point id is required")
get_object_or_404(BasePoint, oid=data["point"]) get_object_or_404(BasePoint, oid=data["point"])
if "distance" not in data or not data["point_type"]:
raise serializers.ValidationError("Point type is required")
else: else:
if "point_to" not in data or not data["point_to"]:
raise serializers.ValidationError("Point to id is required")
get_object_or_404(BasePoint, oid=data["point_to"])
if "point_from" not in data or not data["point_from"]:
raise serializers.ValidationError("Point from id is required")
get_object_or_404(BasePoint, oid=data["point_from"])
if "distance" not in data or not data["distance"]: if "distance" not in data or not data["distance"]:
raise serializers.ValidationError("Distance is required") raise serializers.ValidationError("Distance is required")
return data return data
class InputRouteSerializer(serializers.Serializer): class InputRouteDateSerializer(serializers.Serializer):
date = serializers.DateField()
points = serializers.ListSerializer(child=InputRoutePointSerializer()) points = serializers.ListSerializer(child=InputRoutePointSerializer())
class ResaurantSerializer(serializers.ModelSerializer): class InputRouteSerializer(serializers.Serializer):
dates = serializers.ListSerializer(child=InputRouteDateSerializer())
class ListUserRouteSerializer(serializers.ModelSerializer):
class Meta:
model = UserRoute
fields = ["id", "created"]
class UserRouteDateSerializer(serializers.ModelSerializer):
points = serializers.SerializerMethodField(method_name="get_points")
@extend_schema_field(InputRoutePointSerializer)
def get_points(self, obj):
return [x.get_json() for x in obj.points.all()]
class Meta:
model = UserRouteDate
fields = ["date", "points"]
class UserRouteSerializer(serializers.ModelSerializer):
class Meta:
model = UserRoute
fields = ["created", "dates"]
class RestaurantSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = Restaurant model = Restaurant
exclude = ("phones",) exclude = ("phones",)

View File

@ -4,14 +4,18 @@
BuildRouteApiView, BuildRouteApiView,
ListRegionApiView, ListRegionApiView,
ListCityApiView, ListCityApiView,
SaveRouteSerializer, SaveRouteApiView,
ListUserFavoriteRoutes,
RetrieveRoute,
) )
app_name = "events" app_name = "events"
urlpatterns = [ urlpatterns = [
path("route/build", BuildRouteApiView.as_view(), name="build_route"), path("route/build", BuildRouteApiView.as_view(), name="build_route"),
path("route/save", SaveRouteSerializer.as_view(), name="save_route"), path("route/save", SaveRouteApiView.as_view(), name="save_route"),
path("route/list", ListUserFavoriteRoutes.as_view(), name="list_routes"),
path("route/<int:pk>", RetrieveRoute.as_view(), name="get_route"),
path("data/regions", ListRegionApiView.as_view(), name="regions"), path("data/regions", ListRegionApiView.as_view(), name="regions"),
path("data/cities", ListCityApiView.as_view(), name="cities"), path("data/cities", ListCityApiView.as_view(), name="cities"),
] ]

View File

@ -1,4 +1,10 @@
from rest_framework.generics import GenericAPIView, ListAPIView, get_object_or_404 from rest_framework.generics import (
GenericAPIView,
ListAPIView,
get_object_or_404,
RetrieveAPIView,
)
from rest_framework.exceptions import MethodNotAllowed
from rest_framework.response import Response from rest_framework.response import Response
from drf_spectacular.utils import extend_schema from drf_spectacular.utils import extend_schema
from django.db.models import Count from django.db.models import Count
@ -15,6 +21,8 @@
RouteInputSerializer, RouteInputSerializer,
CitySerializer, CitySerializer,
InputRouteSerializer, InputRouteSerializer,
ListUserRouteSerializer,
UserRouteSerializer,
) )
from passfinder.events.models import ( from passfinder.events.models import (
BasePoint, BasePoint,
@ -23,6 +31,7 @@
UserRoute, UserRoute,
UserRoutePoint, UserRoutePoint,
UserRouteTransaction, UserRouteTransaction,
UserRouteDate,
) )
@ -99,7 +108,7 @@ class ListCityApiView(ListAPIView):
) )
class SaveRouteSerializer(GenericAPIView): class SaveRouteApiView(GenericAPIView):
serializer_class = InputRouteSerializer serializer_class = InputRouteSerializer
def post(self, request, *args, **kwargs): def post(self, request, *args, **kwargs):
@ -107,16 +116,37 @@ def post(self, request, *args, **kwargs):
serializer.is_valid(raise_exception=True) serializer.is_valid(raise_exception=True)
data = serializer.data data = serializer.data
route = UserRoute.objects.create(user=self.request.user) route = UserRoute.objects.create(user=self.request.user)
for point in data["points"]: for date in data["dates"]:
if point["type"] == "point": date_obj = UserRouteDate.objects.create(date=date["date"], route=route)
UserRoutePoint.objects.create( for point in date["points"]:
route=route, point=BasePoint.objects.get(oid=point["point"]) if point["type"] == "point":
) UserRoutePoint.objects.create(
else: date=date_obj,
UserRouteTransaction.objects.create( duration=point["duration"],
route=route, point=BasePoint.objects.get(oid=point["point"]),
point_from=BasePoint.objects.get(oid=point["point_from"]), )
point_to=BasePoint.objects.get(oid=point["point_to"]), else:
) UserRouteTransaction.objects.create(
date=date_obj,
duration=point["duration"],
distance=point["distance"],
)
return Response(data=data) return Response(data=data)
class ListUserFavoriteRoutes(ListAPIView):
serializer_class = ListUserRouteSerializer
def get_queryset(self):
return UserRoute.objects.filter(user=self.request.user)
class RetrieveRoute(RetrieveAPIView):
serializer_class = UserRouteSerializer
def get_object(self):
route = get_object_or_404(UserRoute, pk=self.kwargs["pk"])
if route.user != self.request.user:
raise MethodNotAllowed
return route

View File

@ -0,0 +1,49 @@
# Generated by Django 4.2.1 on 2023-05-26 21:13
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
("events", "0024_basepoint_price"),
]
operations = [
migrations.RemoveField(
model_name="userroute",
name="user",
),
migrations.RemoveField(
model_name="userroutepoint",
name="baseuserroutepoint_ptr",
),
migrations.RemoveField(
model_name="userroutepoint",
name="point",
),
migrations.RemoveField(
model_name="userroutetransaction",
name="baseuserroutepoint_ptr",
),
migrations.RemoveField(
model_name="userroutetransaction",
name="point_from",
),
migrations.RemoveField(
model_name="userroutetransaction",
name="point_to",
),
migrations.DeleteModel(
name="BaseUserRoutePoint",
),
migrations.DeleteModel(
name="UserRoute",
),
migrations.DeleteModel(
name="UserRoutePoint",
),
migrations.DeleteModel(
name="UserRouteTransaction",
),
]

View File

@ -0,0 +1,156 @@
# Generated by Django 4.2.1 on 2023-05-26 21:14
from django.conf import settings
from django.db import migrations, models
import django.db.models.deletion
class Migration(migrations.Migration):
dependencies = [
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
("contenttypes", "0002_remove_content_type_name"),
("events", "0025_remove_userroute_user_and_more"),
]
operations = [
migrations.CreateModel(
name="BaseUserRouteDatePoint",
fields=[
(
"id",
models.BigAutoField(
auto_created=True,
primary_key=True,
serialize=False,
verbose_name="ID",
),
),
("duration", models.IntegerField()),
],
options={
"abstract": False,
"base_manager_name": "objects",
},
),
migrations.CreateModel(
name="UserRoute",
fields=[
(
"id",
models.BigAutoField(
auto_created=True,
primary_key=True,
serialize=False,
verbose_name="ID",
),
),
("created", models.DateTimeField(auto_now_add=True)),
(
"user",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE,
related_name="routes",
to=settings.AUTH_USER_MODEL,
),
),
],
),
migrations.CreateModel(
name="UserRouteTransaction",
fields=[
(
"baseuserroutedatepoint_ptr",
models.OneToOneField(
auto_created=True,
on_delete=django.db.models.deletion.CASCADE,
parent_link=True,
primary_key=True,
serialize=False,
to="events.baseuserroutedatepoint",
),
),
("distance", models.FloatField()),
],
options={
"abstract": False,
"base_manager_name": "objects",
},
bases=("events.baseuserroutedatepoint",),
),
migrations.CreateModel(
name="UserRouteDate",
fields=[
(
"id",
models.BigAutoField(
auto_created=True,
primary_key=True,
serialize=False,
verbose_name="ID",
),
),
("date", models.DateField()),
(
"route",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE,
related_name="dates",
to="events.userroute",
),
),
],
options={
"unique_together": {("date", "route")},
},
),
migrations.AddField(
model_name="baseuserroutedatepoint",
name="date",
field=models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE,
related_name="points",
to="events.userroutedate",
),
),
migrations.AddField(
model_name="baseuserroutedatepoint",
name="polymorphic_ctype",
field=models.ForeignKey(
editable=False,
null=True,
on_delete=django.db.models.deletion.CASCADE,
related_name="polymorphic_%(app_label)s.%(class)s_set+",
to="contenttypes.contenttype",
),
),
migrations.CreateModel(
name="UserRoutePoint",
fields=[
(
"baseuserroutedatepoint_ptr",
models.OneToOneField(
auto_created=True,
on_delete=django.db.models.deletion.CASCADE,
parent_link=True,
primary_key=True,
serialize=False,
to="events.baseuserroutedatepoint",
),
),
("point_type", models.CharField(max_length=250)),
(
"point",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE,
to="events.basepoint",
),
),
],
options={
"abstract": False,
"base_manager_name": "objects",
},
bases=("events.baseuserroutedatepoint",),
),
]

View File

@ -245,23 +245,45 @@ def __str__(self):
return f"{self.user}'s route" return f"{self.user}'s route"
class BaseUserRoutePoint(PolymorphicModel): class UserRouteDate(models.Model):
date = models.DateField()
route = models.ForeignKey( route = models.ForeignKey(
"UserRoute", related_name="points", on_delete=models.CASCADE "UserRoute", related_name="dates", on_delete=models.CASCADE
) )
class Meta:
unique_together = ("date", "route")
class UserRoutePoint(BaseUserRoutePoint):
class BaseUserRouteDatePoint(PolymorphicModel):
date = models.ForeignKey(
"UserRouteDate", related_name="points", on_delete=models.CASCADE
)
duration = models.IntegerField()
class UserRoutePoint(BaseUserRouteDatePoint):
type = "point" type = "point"
point = models.ForeignKey("BasePoint", on_delete=models.CASCADE) point = models.ForeignKey("BasePoint", on_delete=models.CASCADE)
point_type = models.CharField(max_length=250)
def get_json(self):
return {
"type": "point",
"duration": self.duration,
"point": self.point.oid,
"point_name": self.point.title,
"point_type": self.point_type,
}
class UserRouteTransaction(BaseUserRoutePoint): class UserRouteTransaction(BaseUserRouteDatePoint):
type = "transition" type = "transition"
point_from = models.ForeignKey(
"BasePoint", related_name="user_route_point_from", on_delete=models.CASCADE
)
point_to = models.ForeignKey(
"BasePoint", related_name="user_route_point_to", on_delete=models.CASCADE
)
distance = models.FloatField() distance = models.FloatField()
def get_json(self):
return {
"type": "transition",
"duration": self.duration,
"distance": self.distance,
}

View File

@ -2,7 +2,12 @@
from .mapping.mapping import * from .mapping.mapping import *
from .models.models import * from .models.models import *
from passfinder.events.models import Event, Region, Hotel, BasePoint, City, Restaurant from passfinder.events.models import Event, Region, Hotel, BasePoint, City, Restaurant
from passfinder.events.api.serializers import HotelSerializer, EventSerializer, ResaurantSerializer, ObjectRouteSerializer from passfinder.events.api.serializers import (
HotelSerializer,
EventSerializer,
RestaurantSerializer,
ObjectRouteSerializer,
)
from passfinder.recomendations.models import * from passfinder.recomendations.models import *
from random import choice, sample from random import choice, sample
from collections import Counter from collections import Counter
@ -44,12 +49,7 @@ def nearest_attraction(attraction, nearest_n):
def nearest_mus(museum, nearest_n): def nearest_mus(museum, nearest_n):
return get_nearest_( return get_nearest_(
museum, museum, "museum", mus_mapping, rev_mus_mapping, nearest_n, mus_model
"museum",
mus_mapping,
rev_mus_mapping,
nearest_n,
mus_model
) )
@ -94,9 +94,9 @@ def get_nearest_event(event, nearest_n):
return nearest_concert(event, nearest_n) return nearest_concert(event, nearest_n)
if event.type == "movie": if event.type == "movie":
return nearest_movie(event, nearest_n) return nearest_movie(event, nearest_n)
if event.type == 'museum': if event.type == "museum":
return nearest_mus(event, nearest_n) return nearest_mus(event, nearest_n)
if event.type == 'attraction': if event.type == "attraction":
return nearest_attraction(event, nearest_n) return nearest_attraction(event, nearest_n)
@ -259,7 +259,7 @@ def dist_func(event1: Event, event2: Event):
return dist return dist
except: except:
return 1000000 return 1000000
#return (event1.lon - event2.lon) ** 2 + (event1.lat - event2.lat) ** 2 # return (event1.lon - event2.lon) ** 2 + (event1.lat - event2.lat) ** 2
def generate_nearest(): def generate_nearest():
@ -299,7 +299,7 @@ def generate_nearest_restaurants():
nr.save() nr.save()
if i % 100 == 0: if i % 100 == 0:
print(i) print(i)
for i, hotel in enumerate(Hotel.objects.all()): for i, hotel in enumerate(Hotel.objects.all()):
sorted_rests = list(sorted(rests.copy(), key=lambda x: dist_func(x, hotel))) sorted_rests = list(sorted(rests.copy(), key=lambda x: dist_func(x, hotel)))
nr = NearestRestaurantToHotel.objects.create(hotel=hotel) nr = NearestRestaurantToHotel.objects.create(hotel=hotel)
@ -309,7 +309,6 @@ def generate_nearest_restaurants():
print(i) print(i)
def match_points(): def match_points():
regions = list(City.objects.all()) regions = list(City.objects.all())
for i, point in enumerate(Event.objects.all()): for i, point in enumerate(Event.objects.all()):
@ -349,10 +348,12 @@ def calculate_favorite_metric(event: Event, user: User):
if event.type == "movie": if event.type == "movie":
preferred = pref.preffered_movies.all() preferred = pref.preffered_movies.all()
return calculate_mean_metric(preferred, event, cinema_model, rev_cinema_mapping) return calculate_mean_metric(preferred, event, cinema_model, rev_cinema_mapping)
if event.type == 'attraction': if event.type == "attraction":
preferred = pref.prefferred_attractions.all() preferred = pref.prefferred_attractions.all()
return calculate_mean_metric(preferred, event, attracion_model, rev_attraction_mapping) return calculate_mean_metric(
if event.type == 'museum': preferred, event, attracion_model, rev_attraction_mapping
)
if event.type == "museum":
preferred = pref.prefferred_museums.all() preferred = pref.prefferred_museums.all()
return calculate_mean_metric(preferred, event, mus_model, rev_mus_mapping) return calculate_mean_metric(preferred, event, mus_model, rev_mus_mapping)
return 1000000 return 1000000
@ -366,7 +367,7 @@ def get_nearest_favorite(
if candidate not in exclude_events: if candidate not in exclude_events:
first_event = candidate first_event = candidate
break break
if first_event is None: if first_event is None:
result = events[0] result = events[0]
else: else:
@ -408,28 +409,34 @@ def generate_point(point: BasePoint):
"type": "point", "type": "point",
"point": event_data, "point": event_data,
"point_type": "point", "point_type": "point",
"time": timedelta(minutes=90+choice(range(-10, 90, 10))).seconds "time": timedelta(minutes=90 + choice(range(-10, 90, 10))).seconds,
} }
def generate_restaurant(point: BasePoint): def generate_restaurant(point: BasePoint):
rest_data = ObjectRouteSerializer(point).data rest_data = ObjectRouteSerializer(point).data
return { return {
"type": "point", "type": "point",
"point": rest_data, "point": rest_data,
"point_type": "restaurant", "point_type": "restaurant",
"time": timedelta(minutes=90+choice(range(-10, 90, 10))).seconds "time": timedelta(minutes=90 + choice(range(-10, 90, 10))).seconds,
} }
def generate_multiple_tours(user: User, city: City, start_date: datetime.date, end_date: datetime.date): def generate_multiple_tours(
user: User, city: City, start_date: datetime.date, end_date: datetime.date
):
hotels = sample(list(Hotel.objects.filter(city=city)), 5) hotels = sample(list(Hotel.objects.filter(city=city)), 5)
pool = Pool(5) pool = Pool(5)
return pool.map(generate_tour, [(user, start_date, end_date, hotel) for hotel in hotels]) return pool.map(
generate_tour, [(user, start_date, end_date, hotel) for hotel in hotels]
)
def generate_tour(user: User, city: City, start_date: datetime.date, end_date: datetime.date): def generate_tour(
user: User, city: City, start_date: datetime.date, end_date: datetime.date
):
UserPreferences.objects.get_or_create(user=user) UserPreferences.objects.get_or_create(user=user)
hotel = choice(list(Hotel.objects.filter(city=city))) hotel = choice(list(Hotel.objects.filter(city=city)))
current_date = start_date current_date = start_date
@ -438,15 +445,10 @@ def generate_tour(user: User, city: City, start_date: datetime.date, end_date: d
while current_date < end_date: while current_date < end_date:
local_points, local_paths = generate_path(user, points, hotel) local_points, local_paths = generate_path(user, points, hotel)
points.extend(local_points) points.extend(local_points)
paths.append( paths.append({"date": current_date, "paths": local_paths})
{
'date': current_date,
'paths': local_paths
}
)
current_date += timedelta(days=1) current_date += timedelta(days=1)
return paths, points return paths, points
@ -456,55 +458,82 @@ def generate_hotel(hotel: Hotel):
"type": "point", "type": "point",
"point": hotel_data, "point": hotel_data,
"point_type": "hotel", "point_type": "hotel",
"time": timedelta(minutes=90+choice(range(-10, 90, 10))).seconds "time": timedelta(minutes=90 + choice(range(-10, 90, 10))).seconds,
} }
def generate_path(user: User, disallowed_points: Iterable[BasePoint], hotel: Hotel): def generate_path(user: User, disallowed_points: Iterable[BasePoint], hotel: Hotel):
# region_events = Event.objects.filter(region=region) # region_events = Event.objects.filter(region=region)
#candidates = NearestHotel.objects.get(hotel=hotel).nearest_events.all() # candidates = NearestHotel.objects.get(hotel=hotel).nearest_events.all()
allowed_types = ['museum', 'attraction'] allowed_types = ["museum", "attraction"]
start_point = NearestRestaurantToHotel.objects.get(hotel=hotel).restaurants.first() start_point = NearestRestaurantToHotel.objects.get(hotel=hotel).restaurants.first()
candidates = list(filter(lambda x: x.type in allowed_types, map(lambda x: x.event, start_point.nearestrestauranttoevent_set.all()[0:100]))) candidates = list(
filter(
lambda x: x.type in allowed_types,
map(
lambda x: x.event, start_point.nearestrestauranttoevent_set.all()[0:100]
),
)
)
points = [start_point] points = [start_point]
path = [ path = [
generate_hotel(hotel), generate_hotel(hotel),
generate_route(start_point, hotel), generate_route(start_point, hotel),
generate_restaurant(points[-1]) generate_restaurant(points[-1]),
] ]
start_time = datetime.combine(datetime.now(), time(hour=10)) start_time = datetime.combine(datetime.now(), time(hour=10))
how_many_eat = 1 how_many_eat = 1
while start_time.hour < 22 and start_time.day == datetime.now().day: while start_time.hour < 22 and start_time.day == datetime.now().day:
if (start_time.hour > 14 and how_many_eat == 1) or (start_time.hour > 20 and how_many_eat == 2): if (start_time.hour > 14 and how_many_eat == 1) or (
point = NearestRestaurantToEvent.objects.get(event=points[-1]).restaurants.all()[0] start_time.hour > 20 and how_many_eat == 2
):
point = NearestRestaurantToEvent.objects.get(
event=points[-1]
).restaurants.all()[0]
points.append(point) points.append(point)
candidates = list(filter(lambda x: x.type in allowed_types, map(lambda x: x.event, point.nearestrestauranttoevent_set.all()[0:100]))) candidates = list(
filter(
lambda x: x.type in allowed_types,
map(
lambda x: x.event,
point.nearestrestauranttoevent_set.all()[0:100],
),
)
)
if not len(candidates): if not len(candidates):
candidates = list(map(lambda x: x.event, point.nearestrestauranttoevent_set.all()[0:100])) candidates = list(
map(
lambda x: x.event,
point.nearestrestauranttoevent_set.all()[0:100],
)
)
path.append(generate_restaurant(points[-1])) path.append(generate_restaurant(points[-1]))
start_time += timedelta(seconds=path[-1]['time']) start_time += timedelta(seconds=path[-1]["time"])
how_many_eat += 1 how_many_eat += 1
continue continue
if start_time.hour > 17: if start_time.hour > 17:
allowed_types = ['play', 'concert', 'movie'] allowed_types = ["play", "concert", "movie"]
if candidates is None: if candidates is None:
candidates = NearestEvent.objects.get(event=points[-1]).nearest.filter(type__in=allowed_types) candidates = NearestEvent.objects.get(event=points[-1]).nearest.filter(
type__in=allowed_types
)
if not len(candidates): if not len(candidates):
candidates = NearestEvent.objects.get(event=points[-1]).nearest.all() candidates = NearestEvent.objects.get(event=points[-1]).nearest.all()
try: try:
points.append(get_nearest_favorite(candidates, user, points + disallowed_points)) points.append(
get_nearest_favorite(candidates, user, points + disallowed_points)
)
except AttributeError: except AttributeError:
points.append(get_nearest_favorite(candidates, user, points)) points.append(get_nearest_favorite(candidates, user, points))
@ -518,17 +547,21 @@ def generate_path(user: User, disallowed_points: Iterable[BasePoint], hotel: Hot
return points, path return points, path
def calculate_distance(sample1: Event, samples: Iterable[Event], model: AnnoyIndex, rev_mapping): def calculate_distance(
sample1: Event, samples: Iterable[Event], model: AnnoyIndex, rev_mapping
):
metrics = [] metrics = []
for sample in samples: for sample in samples:
metrics.append(model.get_distance(rev_mapping[sample1.oid], rev_mapping[sample.oid])) metrics.append(
model.get_distance(rev_mapping[sample1.oid], rev_mapping[sample.oid])
)
return sum(metrics) / len(metrics) return sum(metrics) / len(metrics)
def get_onboarding_attractions(): def get_onboarding_attractions():
sample_attractions = sample(list(Event.objects.filter(type='attraction')), 200) sample_attractions = sample(list(Event.objects.filter(type="attraction")), 200)
first_attraction = choice(sample_attractions) first_attraction = choice(sample_attractions)
attractions = [first_attraction] attractions = [first_attraction]
@ -537,12 +570,10 @@ def get_onboarding_attractions():
mx_dist = 0 mx_dist = 0
mx_attraction = None mx_attraction = None
for att in sample_attractions: for att in sample_attractions:
if att in attractions: continue if att in attractions:
continue
local_dist = calculate_distance( local_dist = calculate_distance(
att, att, attractions, attracion_model, rev_attraction_mapping
attractions,
attracion_model,
rev_attraction_mapping
) )
if local_dist > mx_dist: if local_dist > mx_dist:
mx_dist = local_dist mx_dist = local_dist
@ -552,4 +583,4 @@ def get_onboarding_attractions():
def get_onboarding_hotels(stars=Iterable[int]): def get_onboarding_hotels(stars=Iterable[int]):
return sample(list(Hotel.objects.filter(stars__in=stars)), 10) return sample(list(Hotel.objects.filter(stars__in=stars)), 10)

View File

@ -3,7 +3,6 @@
from rest_framework.generics import get_object_or_404 from rest_framework.generics import get_object_or_404
from passfinder.events.models import BasePoint from passfinder.events.models import BasePoint
from passfinder.users.clickhouse_models import UserPreferenceClickHouse
from passfinder.users.models import UserPreference from passfinder.users.models import UserPreference
User = get_user_model() User = get_user_model()