Thanks for visiting codestin.com
Credit goes to github.com

Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 25 additions & 31 deletions cvat/apps/engine/backup.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,17 @@ def _get_label_mapping(db_labels):


def _write_annotation_guide(
zip_object, annotation_guide, guide_filename, assets_dirname, target_dir=None
):
zip_object: ZipFile,
annotation_guide: Optional[models.AnnotationGuide],
guide_filename: str,
assets_dirname: str,
target_dir: str,
) -> None:
if annotation_guide is not None:
md = annotation_guide.markdown
assets = annotation_guide.assets.all()
assets_dirname = os.path.join(target_dir or "", assets_dirname)
guide_filename = os.path.join(target_dir or "", guide_filename)
assets_dirname = os.path.join(target_dir, assets_dirname)
guide_filename = os.path.join(target_dir, guide_filename)

for db_asset in assets:
md = md.replace(
Expand Down Expand Up @@ -386,7 +390,7 @@ def _write_directory(
)

# Diff pair (this capture lost indentation and +/- markers; GitHub lists the
# removed line before the added one): the first `export_to` signature is the
# pre-change abstract method, the second is its replacement.
# After the change, the abstract contract takes only an output path.
# The task exporter's override further down keeps
# `file: str | ZipFile, target_dir: str = ""`. That accepts every call this
# abstract signature allows, so it widens the contract rather than narrowing it.
@abstractmethod
def export_to(self, file: str | ZipFile, target_dir: str | None = None): ...
def export_to(self, file: str) -> None: ...

@classmethod
def get_object(cls, pk: int) -> models.Project | models.Task:
Expand Down Expand Up @@ -420,7 +424,7 @@ def __init__(self, pk, version=Version.V1, *, lightweight: bool):
self._label_mapping = _get_label_mapping(db_labels)
self._lightweight = lightweight

def _write_annotation_guide(self, zip_object, target_dir=None):
def _write_annotation_guide(self, zip_object: ZipFile, target_dir: str) -> None:
annotation_guide = (
self._db_task.annotation_guide if hasattr(self._db_task, "annotation_guide") else None
)
Expand All @@ -432,10 +436,9 @@ def _write_annotation_guide(self, zip_object, target_dir=None):
target_dir=target_dir,
)

def _write_data(self, zip_object, target_dir=None):
target_data_dir = (
os.path.join(target_dir, self.DATA_DIRNAME) if target_dir else self.DATA_DIRNAME
)
def _write_data(self, zip_object: ZipFile, target_dir: str) -> None:
target_data_dir = os.path.join(target_dir, self.DATA_DIRNAME)

if self._db_data.storage == StorageChoice.LOCAL:
data_dir = self._db_data.get_upload_dirname()
self._write_directory(
Expand Down Expand Up @@ -547,19 +550,18 @@ def _write_data(self, zip_object, target_dir=None):
else:
raise NotImplementedError

# Diff hunk with pre- and post-change lines interleaved (this capture lost
# indentation and +/- markers; GitHub lists removed lines before added ones):
# - Removed: the untyped signature and the conditional multi-line
#   `target_task_dir` expression.
# - Added: the annotated signature and the single unconditional
#   `os.path.join` line.
# For target_dir == "" both forms yield TASK_DIRNAME, because
# os.path.join("", name) == name.
# NOTE(review): the old form also accepted None. os.path.join(None, ...)
# raises TypeError, so any caller still passing None must now pass "".
def _write_task(self, zip_object, target_dir=None):
def _write_task(self, zip_object: ZipFile, target_dir: str) -> None:
# Source: the task directory from get_dirname().
# Destination: <target_dir>/<TASK_DIRNAME> inside the archive.
task_dir = self._db_task.get_dirname()
target_task_dir = (
os.path.join(target_dir, self.TASK_DIRNAME) if target_dir else self.TASK_DIRNAME
)
target_task_dir = os.path.join(target_dir, self.TASK_DIRNAME)

# recursive=False: presumably only top-level entries of task_dir are added
# (see _write_directory, not fully shown here).
self._write_directory(
source_dir=task_dir,
zip_object=zip_object,
target_dir=target_task_dir,
recursive=False,
)

def _write_manifest(self, zip_object, target_dir=None):
def _write_manifest(self, zip_object: ZipFile, target_dir: str) -> None:
def serialize_task():
task_serializer = TaskReadSerializer(self._db_task)
for field in ("url", "owner", "assignee"):
Expand Down Expand Up @@ -674,14 +676,10 @@ def serialize_data():
task["data"] = serialize_data()
task["jobs"] = serialize_jobs()

target_manifest_file = (
os.path.join(target_dir, self.MANIFEST_FILENAME)
if target_dir
else self.MANIFEST_FILENAME
)
target_manifest_file = os.path.join(target_dir, self.MANIFEST_FILENAME)
zip_object.writestr(target_manifest_file, data=JSONRenderer().render(task))

def _write_annotations(self, zip_object: ZipFile, target_dir: Optional[str] = None) -> None:
def _write_annotations(self, zip_object: ZipFile, target_dir: str) -> None:
def serialize_annotations():
db_jobs = self._get_db_jobs()
db_job_ids = (j.id for j in db_jobs)
Expand All @@ -708,22 +706,18 @@ def serialize_shapes():
yield self._prepare_annotations(annotation_data, self._label_mapping)

annotations = serialize_annotations()
target_annotations_file = (
os.path.join(target_dir, self.ANNOTATIONS_FILENAME)
if target_dir
else self.ANNOTATIONS_FILENAME
)
target_annotations_file = os.path.join(target_dir, self.ANNOTATIONS_FILENAME)
with zip_object.open(target_annotations_file, "w") as f:
rapidjson.dump(annotations, f)

# Diff pair (this capture lost indentation and +/- markers): the untyped
# signature is the removed line and the annotated one is its replacement.
# The five body lines are unchanged context.
# Writes every part of a task backup into `zip_obj` under `target_dir`, in
# this order: data, task directory, manifest, annotations, annotation guide.
# NOTE(review): `target_dir` no longer defaults to None. The helpers call
# os.path.join(target_dir, ...) unconditionally, which raises TypeError on
# None, so callers must pass "" for the archive root.
def _export_task(self, zip_obj, target_dir=None):
def _export_task(self, zip_obj: ZipFile, target_dir: str) -> None:
self._write_data(zip_obj, target_dir)
self._write_task(zip_obj, target_dir)
self._write_manifest(zip_obj, target_dir)
self._write_annotations(zip_obj, target_dir)
self._write_annotation_guide(zip_obj, target_dir)

def export_to(self, file: str | ZipFile, target_dir: str | None = None):
def export_to(self, file: str | ZipFile, target_dir: str = "") -> None:
if (
self._db_task.data.storage_method == StorageMethodChoice.FILE_SYSTEM
and self._db_task.data.storage == StorageChoice.SHARE
Expand Down Expand Up @@ -1154,7 +1148,7 @@ def __init__(self, pk, *, lightweight: bool, version: Version = Version.V1):
self._label_mapping = _get_label_mapping(db_labels)
self._lightweight = lightweight

def _write_annotation_guide(self, zip_object, target_dir=None):
def _write_annotation_guide(self, zip_object: ZipFile) -> None:
annotation_guide = (
self._db_project.annotation_guide
if hasattr(self._db_project, "annotation_guide")
Expand All @@ -1165,7 +1159,7 @@ def _write_annotation_guide(self, zip_object, target_dir=None):
annotation_guide,
self.ANNOTATION_GUIDE_FILENAME,
self.ASSETS_DIRNAME,
target_dir=target_dir,
target_dir="",
)

def _write_tasks(self, zip_object):
Expand Down Expand Up @@ -1201,7 +1195,7 @@ def serialize_project():

zip_object.writestr(self.MANIFEST_FILENAME, data=JSONRenderer().render(project))

def export_to(self, file: str, target_dir: str | None = None):
def export_to(self, file: str) -> None:
with ZipFile(file, "w") as output_file:
self._write_annotation_guide(output_file)
self._write_manifest(output_file)
Expand Down
Loading