1+ from typing import Any , Dict , List , Union , cast
2+ from typing_extensions import NotRequired , TypedDict , Literal
3+ from .request import Request , RequestConfig
4+ from .async_request import AsyncRequest , AsyncRequestConfig
5+ from ._config import ClientConfig
6+
7+
8+ class DatasetItemText (TypedDict ):
9+ type : Literal ["text" ]
10+ """
11+ Type of the dataset item: text
12+ """
13+
14+ value : str
15+ """
16+ Value of the dataset item
17+ """
18+
19+
20+ class DatasetItemImage (TypedDict ):
21+ type : Literal ["image" ]
22+ """
23+ Type of the dataset item: image
24+ """
25+
26+ value : str
27+ """
28+ Value of the dataset item
29+ """
30+
31+
32+ class LabelItemText (TypedDict ):
33+ key : NotRequired [str ]
34+ """
35+ Optional key for the label
36+ """
37+
38+ type : Literal ["text" ]
39+ """
40+ Type of the label: text
41+ """
42+
43+ value : str
44+ """
45+ Value of the label
46+ """
47+
48+
49+ class LabelItemImage (TypedDict ):
50+ key : NotRequired [str ]
51+ """
52+ Optional key for the label
53+ """
54+
55+ type : Literal ["image" , "text" ]
56+ """
57+ Type of the label: image or text
58+ """
59+
60+ value : str
61+ """
62+ Value of the label
63+ """
64+
65+
66+ class ClassificationTextParams (TypedDict ):
67+ dataset : List [DatasetItemText ]
68+ """
69+ List of text dataset items to classify
70+ """
71+
72+ labels : List [LabelItemText ]
73+ """
74+ List of text labels for classification
75+ """
76+
77+ multiple_labels : NotRequired [bool ]
78+ """
79+ Whether to allow multiple labels per item
80+ """
81+
82+
83+ class ClassificationImageParams (TypedDict ):
84+ dataset : List [DatasetItemImage ]
85+ """
86+ List of image dataset items to classify
87+ """
88+
89+ labels : List [LabelItemImage ]
90+ """
91+ List of labels for classification
92+ """
93+
94+ multiple_labels : NotRequired [bool ]
95+ """
96+ Whether to allow multiple labels per item
97+ """
98+
99+
100+ class ClassificationResponse (TypedDict ):
101+ predictions : List [Union [str , List [str ]]]
102+ """
103+ Classification predictions - single labels or multiple labels per item
104+ """
105+
106+
107+
108+ class Classification (ClientConfig ):
109+
110+ config : RequestConfig
111+
112+ def __init__ (
113+ self ,
114+ api_key : str ,
115+ api_url : str ,
116+ disable_request_logging : Union [bool , None ] = False ,
117+ ):
118+ super ().__init__ (api_key , api_url , disable_request_logging )
119+ self .config = RequestConfig (
120+ api_url = api_url ,
121+ api_key = api_key ,
122+ disable_request_logging = disable_request_logging ,
123+ )
124+
125+ def text (self , params : ClassificationTextParams ) -> ClassificationResponse :
126+ path = "/classification"
127+ resp = Request (
128+ config = self .config ,
129+ path = path ,
130+ params = cast (Dict [Any , Any ], params ),
131+ verb = "post" ,
132+ ).perform_with_content ()
133+ return resp
134+ def image (self , params : ClassificationImageParams ) -> ClassificationResponse :
135+ path = "/classification"
136+ resp = Request (
137+ config = self .config ,
138+ path = path ,
139+ params = cast (Dict [Any , Any ], params ),
140+ verb = "post" ,
141+ ).perform_with_content ()
142+ return resp
143+
144+
145+
146+ class AsyncClassification (ClientConfig ):
147+ config : AsyncRequestConfig
148+
149+ def __init__ (
150+ self ,
151+ api_key : str ,
152+ api_url : str ,
153+ disable_request_logging : Union [bool , None ] = False ,
154+ ):
155+ super ().__init__ (api_key , api_url , disable_request_logging )
156+ self .config = AsyncRequestConfig (
157+ api_url = api_url ,
158+ api_key = api_key ,
159+ disable_request_logging = disable_request_logging ,
160+ )
161+
162+ async def text (self , params : ClassificationTextParams ) -> ClassificationResponse :
163+ path = "/classification"
164+ resp = await AsyncRequest (
165+ config = self .config ,
166+ path = path ,
167+ params = cast (Dict [Any , Any ], params ),
168+ verb = "post" ,
169+ ).perform_with_content ()
170+ return resp
171+
172+ async def image (self , params : ClassificationImageParams ) -> ClassificationResponse :
173+ path = "/classification"
174+ resp = await AsyncRequest (
175+ config = self .config ,
176+ path = path ,
177+ params = cast (Dict [Any , Any ], params ),
178+ verb = "post" ,
179+ ).perform_with_content ()
180+ return resp
0 commit comments