mirror of
https://github.com/magnum-opus-nn-cp/backend.git
synced 2024-11-10 18:06:33 +03:00
added ml tasks throttling
This commit is contained in:
parent
8b4c8c3dc7
commit
65ac64e48a
|
@ -67,6 +67,7 @@ class ProcessedTextSerializer(serializers.ModelSerializer):
|
|||
summary = serializers.CharField(source="summery")
|
||||
file_name = serializers.SerializerMethodField("get_file_name")
|
||||
|
||||
@extend_schema_field(serializers.CharField)
|
||||
def get_file_name(self, obj):
|
||||
if not obj.file:
|
||||
return "Text"
|
||||
|
|
|
@ -41,24 +41,31 @@ class RetrieveEntryApiView(generics.RetrieveAPIView):
|
|||
lookup_url_kwarg = "uuid"
|
||||
|
||||
|
||||
@extend_schema_view(
|
||||
get=extend_schema(parameters=[OpenApiParameter(name="type", type=str)])
|
||||
)
|
||||
class UpdateTextDescriptionApiView(generics.GenericAPIView):
|
||||
serializer_class = TextSubmitSerializer
|
||||
queryset = Text.objects.none()
|
||||
permission_classes = [permissions.AllowAny]
|
||||
|
||||
def get(self, request, *args, **kwargs):
|
||||
# TODO: add get param for gen
|
||||
type = self.request.query_params.get("type")
|
||||
text = get_object_or_404(Text, id=self.kwargs["id"])
|
||||
run_mth = ["f", "bert"]
|
||||
if type in run_mth:
|
||||
run_mth = [type]
|
||||
|
||||
if text.description:
|
||||
e = False
|
||||
if "bert" not in text.description:
|
||||
if "bert" not in text.description and "bert" in run_mth:
|
||||
e = True
|
||||
text.description["bert"] = {}
|
||||
re = requests.post(ML_HOST + "bert/describe", json={"data": text.text})
|
||||
if re.status_code == 200:
|
||||
text.description["bert"]["text"] = re.json()["text"]
|
||||
|
||||
if "f" not in text.description:
|
||||
if "f" not in text.description and "f" in run_mth:
|
||||
e = True
|
||||
text.description["f"] = {}
|
||||
re = requests.post(ML_HOST + "tfidf/describe", json={"data": text.text})
|
||||
|
@ -68,12 +75,17 @@ def get(self, request, *args, **kwargs):
|
|||
text.save(update_fields=["description"])
|
||||
|
||||
else:
|
||||
text.description = {"bert": {}, "f": {}}
|
||||
re = requests.post(ML_HOST + "bert/describe", json={"data": text.text})
|
||||
if re.status_code == 200:
|
||||
text.description["bert"]["text"] = re.json()["text"]
|
||||
re = requests.post(ML_HOST + "tfidf/describe", json={"data": text.text})
|
||||
if re.status_code == 200:
|
||||
text.description["f"]["text"] = re.json()["text"]
|
||||
text.description = {}
|
||||
if "bert" in run_mth:
|
||||
text.description["bert"] = {}
|
||||
re = requests.post(ML_HOST + "bert/describe", json={"data": text.text})
|
||||
if re.status_code == 200:
|
||||
text.description["bert"]["text"] = re.json()["text"]
|
||||
|
||||
if "f" in run_mth:
|
||||
text.description["f"] = {}
|
||||
re = requests.post(ML_HOST + "tfidf/describe", json={"data": text.text})
|
||||
if re.status_code == 200:
|
||||
text.description["f"]["text"] = re.json()["text"]
|
||||
text.save(update_fields=["description"])
|
||||
return Response(data=ProcessedTextSerializer().to_representation(instance=text))
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
|
||||
from press_release_nl.processor.models import Entry, Text
|
||||
from press_release_nl.processor.services import create_highlighted_document
|
||||
from press_release_nl.utils.celery import get_scheduled_tasks_name
|
||||
|
||||
ML_HOST = "http://192.168.107.95:8000/"
|
||||
# ML_HOST = "https://dev2.akarpov.ru/"
|
||||
|
@ -27,12 +28,16 @@ def load_text(pk: int):
|
|||
|
||||
@shared_task
|
||||
def run_ml(pk: int, f=True):
|
||||
if get_scheduled_tasks_name().count("press_release_nl.processor.tasks.run_ml") >= 2:
|
||||
run_ml.apply_async(kwargs={"pk": pk}, countdown=10)
|
||||
return
|
||||
try:
|
||||
entry = Entry.objects.get(pk=pk)
|
||||
except Entry.DoesNotExist:
|
||||
return
|
||||
if entry.texts.filter(text__isnull=True).exists():
|
||||
sleep(10)
|
||||
run_ml.apply_async(kwargs={"pk": pk}, countdown=10)
|
||||
return
|
||||
for text in entry.texts.all():
|
||||
re_bert = requests.post(ML_HOST + "bert/process", json={"data": text.text})
|
||||
re_tf = requests.post(ML_HOST + "tfidf/process", json={"data": text.text})
|
||||
|
|
10
press_release_nl/utils/celery.py
Normal file
10
press_release_nl/utils/celery.py
Normal file
|
@ -0,0 +1,10 @@
|
|||
from config.celery_app import app
|
||||
|
||||
|
||||
def get_scheduled_tasks_name() -> [str]:
|
||||
i = app.control.inspect()
|
||||
t = i.active()
|
||||
all_tasks = []
|
||||
for worker, tasks in t.items():
|
||||
all_tasks += tasks
|
||||
return [x["name"] for x in all_tasks]
|
Loading…
Reference in New Issue
Block a user