diff --git a/src/extract/extract.py b/src/extract/extract.py index 28ea3f3..349abad 100644 --- a/src/extract/extract.py +++ b/src/extract/extract.py @@ -22,6 +22,7 @@ from shared.vectorflow_request import VectorflowRequest from services.rabbitmq.rabbit_service import create_connection_params from pika.exceptions import AMQPConnectionError +from shared.utils import update_batch_and_job_status logging.basicConfig(filename='./extract-log.txt', level=logging.INFO) logging.basicConfig(filename='./extract-error-log.txt', level=logging.ERROR) @@ -134,24 +135,7 @@ def remove_from_minio(filename): client = create_minio_client() client.remove_object(os.getenv("MINIO_BUCKET"), filename) -# TODO: refactor into utils -def update_batch_and_job_status(job_id, batch_status, batch_id): - try: - if not job_id and batch_id: - job = safe_db_operation(batch_service.get_batch, batch_id) - job_id = job.job_id - updated_batch_status = safe_db_operation(batch_service.update_batch_status, batch_id, batch_status) - job = safe_db_operation(job_service.update_job_with_batch, job_id, updated_batch_status) - if job.job_status == JobStatus.COMPLETED: - logging.info(f"Job {job_id} completed successfully") - elif job.job_status == JobStatus.PARTIALLY_COMPLETED: - logging.info(f"Job {job_id} partially completed. {job.batches_succeeded} out of {job.total_batches} batches succeeded") - elif job.job_status == JobStatus.FAILED: - logging.info(f"Job {job_id} failed. {job.batches_succeeded} out of {job.total_batches} batches succeeded") - - except Exception as e: - logging.error('Error updating job and batch status: %s', e) - safe_db_operation(job_service.update_job_status, job_id, JobStatus.FAILED) + #################### ## RabbitMQ Logic ## diff --git a/src/shared/utils.py b/src/shared/utils.py index 51b495f..a643fc5 100644 --- a/src/shared/utils.py +++ b/src/shared/utils.py @@ -1,6 +1,14 @@ +import sys +import os +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../'))) import uuid import requests import json +import logging +from services.database.database import get_db, safe_db_operation +import services.database.job_service as job_service +from shared.job_status import JobStatus +import services.database.batch_service as batch_service def generate_uuid_from_tuple(t, namespace_uuid='6ba7b810-9dad-11d1-80b4-00c04fd430c8'): namespace = uuid.UUID(namespace_uuid) @@ -30,4 +38,37 @@ def send_embeddings_to_webhook(embedded_chunks: list[dict], job): json=data ) - return response \ No newline at end of file + return response + + +def update_batch_and_job_status(job_id, batch_status, batch_id): + try: + if not job_id and batch_id: + job = safe_db_operation(batch_service.get_batch, batch_id) + job_id = job.job_id + updated_batch_status = safe_db_operation(batch_service.update_batch_status, batch_id, batch_status) + job = safe_db_operation(job_service.update_job_with_batch, job_id, updated_batch_status) + if job.job_status == JobStatus.COMPLETED: + logging.info(f"Job {job_id} completed successfully") + elif job.job_status == JobStatus.PARTIALLY_COMPLETED: + logging.info(f"Job {job_id} partially completed. {job.batches_succeeded} out of {job.total_batches} batches succeeded") + elif job.job_status == JobStatus.FAILED: + logging.info(f"Job {job_id} failed. {job.batches_succeeded} out of {job.total_batches} batches succeeded") + + except Exception as e: + logging.error('Error updating job and batch status: %s', e) + safe_db_operation(job_service.update_job_status, job_id, JobStatus.FAILED) + + + +def update_batch_status(job_id, batch_status, batch_id, retries = None, bypass_retries=False): + try: + updated_batch_status = safe_db_operation(batch_service.update_batch_status, batch_id, batch_status) + logging.info(f"Status for batch {batch_id} as part of job {job_id} updated to {updated_batch_status}") + if updated_batch_status == BatchStatus.FAILED and (retries == config.MAX_BATCH_RETRIES or bypass_retries): + logging.info(f"Batch {batch_id} failed. Updating job status.") + update_batch_and_job_status(job_id, BatchStatus.FAILED, batch_id) + except Exception as e: + logging.error('Error updating batch status: %s', e) + + diff --git a/src/worker/tests/test_worker.py b/src/worker/tests/test_worker.py index eadb484..17e1057 100644 --- a/src/worker/tests/test_worker.py +++ b/src/worker/tests/test_worker.py @@ -19,7 +19,7 @@ class TestWorker(unittest.TestCase): @patch('services.database.job_service.get_job') @patch('services.database.batch_service.get_batch') @patch('worker.worker.embed_openai_batch') - @patch('worker.worker.update_batch_status') + @patch('shared.utils.update_batch_status') def test_process_batch_success( self, mock_update_batch_and_job_status, @@ -58,7 +58,7 @@ def test_process_batch_success( @patch('services.database.job_service.get_job') @patch('services.database.batch_service.get_batch') @patch('worker.worker.embed_openai_batch') - @patch('worker.worker.update_batch_status') + @patch('shared.utils.update_batch_status') def test_process_batch_success_different_model( self, mock_update_batch_and_job_status, @@ -99,7 +99,7 @@ def test_process_batch_success_different_model( @patch('services.database.job_service.get_job') @patch('services.database.batch_service.get_batch') @patch('worker.worker.embed_openai_batch') - @patch('worker.worker.update_batch_status') + @patch('shared.utils.update_batch_status') def test_process_batch_failure_no_vectors( self, mock_update_batch_and_job_status, @@ -138,7 +138,7 @@ def test_process_batch_failure_no_vectors( @patch('services.database.job_service.get_job') @patch('services.database.batch_service.get_batch') @patch('worker.worker.embed_openai_batch') - @patch('worker.worker.update_batch_status') + @patch('shared.utils.update_batch_status') def test_process_batch_failure_openai( self, mock_update_batch_and_job_status, @@ -179,7 +179,7 @@ def test_process_batch_failure_openai( @patch('services.database.job_service.get_job') @patch('services.database.batch_service.get_batch') @patch('worker.worker.embed_openai_batch') - @patch('worker.worker.update_batch_and_job_status') + @patch('shared.utils.update_batch_and_job_status') def test_process_batch_failure_validate_chunks( self, mock_update_batch_and_job_status, @@ -344,4 +344,4 @@ def test_chunk_sentence_by_characters_too_big(self): self.assertEqual(type(chunks[0]), dict) if __name__ == '__main__': - unittest.main() + unittest.main() \ No newline at end of file diff --git a/src/worker/worker.py b/src/worker/worker.py index 6853ab8..f308f06 100644 --- a/src/worker/worker.py +++ b/src/worker/worker.py @@ -14,6 +14,7 @@ import worker.config as config import services.database.batch_service as batch_service import services.database.job_service as job_service +import shared.utils as utils import tiktoken from pika.exceptions import AMQPConnectionError from shared.chunk_strategy import ChunkStrategy @@ -22,7 +23,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed from services.database.database import get_db, safe_db_operation from shared.job_status import JobStatus -from shared.utils import send_embeddings_to_webhook, generate_uuid_from_tuple +from shared.utils import send_embeddings_to_webhook, generate_uuid_from_tuple from services.rabbitmq.rabbit_service import create_connection_params from worker.vector_uploader import VectorUploader @@ -62,15 +63,15 @@ def process_batch(batch_id, source_data, vector_db_key, embeddings_api_key): upload_to_vector_db(batch_id, embedded_chunks) else: logging.error(f"Failed to get OPEN AI embeddings for batch {batch.id}. Adding batch to retry queue.") - update_batch_status(batch.job_id, BatchStatus.FAILED, batch.id, batch.retries) + utils.update_batch_status(batch.job_id, BatchStatus.FAILED, batch.id, batch.retries) except Exception as e: logging.error('Error embedding batch: %s', e) - update_batch_status(batch.job_id, BatchStatus.FAILED, batch.id) + utils.update_batch_status(batch.job_id, BatchStatus.FAILED, batch.id) else: logging.error('Unsupported embeddings type: %s', embeddings_type.value) - update_batch_status(batch.job_id, BatchStatus.FAILED, batch.id, bypass_retries=True) + utils.update_batch_status(batch.job_id, BatchStatus.FAILED, batch.id, bypass_retries=True) # NOTE: this method will embed mulitple chunks (a list of strings) at once and return a list of lists of floats (a list of embeddings) # NOTE: this assumes that the embedded chunks are returned in the same order the raw chunks were sent @@ -149,7 +150,7 @@ def chunk_data(batch, source_data, job): chunked_data = validate_chunks(chunked_data, job.chunk_validation_url) if not chunked_data: - update_batch_and_job_status(batch.job_id, BatchStatus.FAILED, batch.id) + utils.update_batch_and_job_status(batch.job_id, BatchStatus.FAILED, batch.id) raise Exception("Failed to chunk data") return chunked_data @@ -323,16 +324,6 @@ def create_batches_for_embedding(chunks, max_batch_size): embedding_batches = [chunks[i:i + max_batch_size] for i in range(0, len(chunks), max_batch_size)] return embedding_batches -# TODO: refactor into utils -def update_batch_status(job_id, batch_status, batch_id, retries = None, bypass_retries=False): - try: - updated_batch_status = safe_db_operation(batch_service.update_batch_status, batch_id, batch_status) - logging.info(f"Status for batch {batch_id} as part of job {job_id} updated to {updated_batch_status}") - if updated_batch_status == BatchStatus.FAILED and (retries == config.MAX_BATCH_RETRIES or bypass_retries): - logging.info(f"Batch {batch_id} failed. Updating job status.") - update_batch_and_job_status(job_id, BatchStatus.FAILED, batch_id) - except Exception as e: - logging.error('Error updating batch status: %s', e) def upload_to_vector_db(batch_id, text_embeddings_list): try: @@ -345,31 +336,13 @@ def upload_to_vector_db(batch_id, text_embeddings_list): def process_webhook_response(response, job_id, batch_id): if response and hasattr(response, 'status_code') and response.status_code == 200: - update_batch_and_job_status(job_id, BatchStatus.COMPLETED, batch_id) + utils.update_batch_and_job_status(job_id, BatchStatus.COMPLETED, batch_id) else: logging.error("Error sending embeddings to webhook. Response: %s", response) - update_batch_and_job_status(job_id, BatchStatus.FAILED, batch_id) + utils.update_batch_and_job_status(job_id, BatchStatus.FAILED, batch_id) if response.json() and response.json()['error']: logging.error("Error message: %s", response.json()['error']) -# TODO: refactor into utils -def update_batch_and_job_status(job_id, batch_status, batch_id): - try: - if not job_id and batch_id: - job = safe_db_operation(batch_service.get_batch, batch_id) - job_id = job.job_id - updated_batch_status = safe_db_operation(batch_service.update_batch_status, batch_id, batch_status) - job = safe_db_operation(job_service.update_job_with_batch, job_id, updated_batch_status) - if job.job_status == JobStatus.COMPLETED: - logging.info(f"Job {job_id} completed successfully") - elif job.job_status == JobStatus.PARTIALLY_COMPLETED: - logging.info(f"Job {job_id} partially completed. {job.batches_succeeded} out of {job.total_batches} batches succeeded") - elif job.job_status == JobStatus.FAILED: - logging.info(f"Job {job_id} failed. {job.batches_succeeded} out of {job.total_batches} batches succeeded") - - except Exception as e: - logging.error('Error updating job and batch status: %s', e) - safe_db_operation(job_service.update_job_status, job_id, JobStatus.FAILED) def callback(ch, method, properties, body): try: