1
- from typing import Optional , Union
1
+ from typing import Iterable , Optional , Union
2
2
3
3
from .base_types import Flavor , TestCase , TestCases
4
4
from .clients import Client
5
+ from .common import batched
5
6
6
7
from .retry import RegularPeriodRetry , RetryMechanism
7
8
from .submission import Submission , Submissions
@@ -55,6 +56,56 @@ def resolve_client(
55
56
)
56
57
57
58
59
+ def create_submissions (
60
+ client : Optional [Client ] = None ,
61
+ submissions : Optional [Union [Submission , Submissions ]] = None ,
62
+ ) -> Union [Submission , Submissions ]:
63
+ client = resolve_client (client = client , submissions = submissions )
64
+
65
+ if isinstance (submissions , Submission ):
66
+ return client .create_submission (submissions )
67
+
68
+ # TODO: Use result from get_config.
69
+ batch_size = client .EFFECTIVE_SUBMISSION_BATCH_SIZE
70
+ result_submissions = []
71
+ for submission_batch in batched (submissions , batch_size ):
72
+ submissions_list = list (submission_batch )
73
+ if batch_size > 1 :
74
+ result_submissions .extend (client .create_submissions (submissions_list ))
75
+ else :
76
+ result_submissions .append (client .create_submission (submissions_list [0 ]))
77
+
78
+ return result_submissions
79
+
80
+
81
+ def get_submissions (
82
+ * ,
83
+ client : Optional [Client ] = None ,
84
+ submissions : Optional [Union [Submission , Submissions ]] = None ,
85
+ fields : Union [str , Iterable [str ], None ] = None ,
86
+ ) -> Union [Submission , Submissions ]:
87
+ client = resolve_client (client = client , submissions = submissions )
88
+
89
+ if isinstance (submissions , Submission ):
90
+ return client .get_submission (submissions , fields = fields )
91
+
92
+ # TODO: Use result from get_config.
93
+ batch_size = client .EFFECTIVE_SUBMISSION_BATCH_SIZE
94
+ result_submissions = []
95
+ for submission_batch in batched (submissions , batch_size ):
96
+ submissions_list = list (submission_batch )
97
+ if batch_size > 1 :
98
+ result_submissions .extend (
99
+ client .get_submissions (submissions_list , fields = fields )
100
+ )
101
+ else :
102
+ result_submissions .append (
103
+ client .get_submission (submissions_list [0 ], fields = fields )
104
+ )
105
+
106
+ return result_submissions
107
+
108
+
58
109
def wait (
59
110
* ,
60
111
client : Optional [Client ] = None ,
@@ -76,7 +127,7 @@ def wait(
76
127
}
77
128
78
129
while len (submissions_to_check ) > 0 and not retry_mechanism .is_done ():
79
- client . check_submissions ( list (submissions_to_check .values ()))
130
+ get_submissions ( client = client , submissions = list (submissions_to_check .values ()))
80
131
for token in list (submissions_to_check ):
81
132
submission = submissions_to_check [token ]
82
133
if submission .is_done ():
@@ -162,12 +213,12 @@ def _execute(
162
213
163
214
client = resolve_client (client = client , submissions = submissions )
164
215
all_submissions = create_submissions_from_test_cases (submissions , test_cases )
165
- all_submissions = client . submit ( all_submissions )
216
+ all_submissions = create_submissions ( client = client , submissions = all_submissions )
166
217
167
218
if wait_for_result :
168
- all_submissions = wait (client = client , submissions = all_submissions )
169
-
170
- return all_submissions
219
+ return wait (client = client , submissions = all_submissions )
220
+ else :
221
+ return all_submissions
171
222
172
223
173
224
def async_execute (
0 commit comments