added ml tasks throttling

This commit is contained in:
Alexander Karpov 2023-09-09 14:13:08 +03:00
parent 8b4c8c3dc7
commit 65ac64e48a
4 changed files with 39 additions and 11 deletions

View File

@ -67,6 +67,7 @@ class ProcessedTextSerializer(serializers.ModelSerializer):
summary = serializers.CharField(source="summery")
file_name = serializers.SerializerMethodField("get_file_name")
@extend_schema_field(serializers.CharField)
def get_file_name(self, obj):
if not obj.file:
return "Text"

View File

@ -41,24 +41,31 @@ class RetrieveEntryApiView(generics.RetrieveAPIView):
lookup_url_kwarg = "uuid"
@extend_schema_view(
get=extend_schema(parameters=[OpenApiParameter(name="type", type=str)])
)
class UpdateTextDescriptionApiView(generics.GenericAPIView):
serializer_class = TextSubmitSerializer
queryset = Text.objects.none()
permission_classes = [permissions.AllowAny]
def get(self, request, *args, **kwargs):
# TODO: add get param for gen
type = self.request.query_params.get("type")
text = get_object_or_404(Text, id=self.kwargs["id"])
run_mth = ["f", "bert"]
if type in run_mth:
run_mth = [type]
if text.description:
e = False
if "bert" not in text.description:
if "bert" not in text.description and "bert" in run_mth:
e = True
text.description["bert"] = {}
re = requests.post(ML_HOST + "bert/describe", json={"data": text.text})
if re.status_code == 200:
text.description["bert"]["text"] = re.json()["text"]
if "f" not in text.description:
if "f" not in text.description and "f" in run_mth:
e = True
text.description["f"] = {}
re = requests.post(ML_HOST + "tfidf/describe", json={"data": text.text})
@ -68,12 +75,17 @@ def get(self, request, *args, **kwargs):
text.save(update_fields=["description"])
else:
text.description = {"bert": {}, "f": {}}
re = requests.post(ML_HOST + "bert/describe", json={"data": text.text})
if re.status_code == 200:
text.description["bert"]["text"] = re.json()["text"]
re = requests.post(ML_HOST + "tfidf/describe", json={"data": text.text})
if re.status_code == 200:
text.description["f"]["text"] = re.json()["text"]
text.description = {}
if "bert" in run_mth:
text.description["bert"] = {}
re = requests.post(ML_HOST + "bert/describe", json={"data": text.text})
if re.status_code == 200:
text.description["bert"]["text"] = re.json()["text"]
if "f" in run_mth:
text.description["f"] = {}
re = requests.post(ML_HOST + "tfidf/describe", json={"data": text.text})
if re.status_code == 200:
text.description["f"]["text"] = re.json()["text"]
text.save(update_fields=["description"])
return Response(data=ProcessedTextSerializer().to_representation(instance=text))

View File

@ -6,6 +6,7 @@
from press_release_nl.processor.models import Entry, Text
from press_release_nl.processor.services import create_highlighted_document
from press_release_nl.utils.celery import get_scheduled_tasks_name
ML_HOST = "http://192.168.107.95:8000/"
# ML_HOST = "https://dev2.akarpov.ru/"
@ -27,12 +28,16 @@ def load_text(pk: int):
@shared_task
def run_ml(pk: int, f=True):
if get_scheduled_tasks_name().count("press_release_nl.processor.tasks.run_ml") >= 2:
run_ml.apply_async(kwargs={"pk": pk}, countdown=10)
return
try:
entry = Entry.objects.get(pk=pk)
except Entry.DoesNotExist:
return
if entry.texts.filter(text__isnull=True).exists():
sleep(10)
run_ml.apply_async(kwargs={"pk": pk}, countdown=10)
return
for text in entry.texts.all():
re_bert = requests.post(ML_HOST + "bert/process", json={"data": text.text})
re_tf = requests.post(ML_HOST + "tfidf/process", json={"data": text.text})

View File

@ -0,0 +1,10 @@
from config.celery_app import app
def get_scheduled_tasks_name() -> [str]:
i = app.control.inspect()
t = i.active()
all_tasks = []
for worker, tasks in t.items():
all_tasks += tasks
return [x["name"] for x in all_tasks]