Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit d4cd423

Browse files
committed
classification added
1 parent 33fd9b1 commit d4cd423

File tree

4 files changed

+390
-0
lines changed

4 files changed

+390
-0
lines changed

jigsawstack/__init__.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .embedding import Embedding, AsyncEmbedding
1616
from .exceptions import JigsawStackError
1717
from .image_generation import ImageGeneration, AsyncImageGeneration
18+
from .classification import Classification, AsyncClassification
1819

1920

2021
class JigsawStack:
@@ -25,6 +26,7 @@ class JigsawStack:
2526
web: Web
2627
search: Search
2728
prompt_engine: PromptEngine
29+
classification: Classification
2830
api_key: str
2931
api_url: str
3032
disable_request_logging: bool
@@ -118,6 +120,12 @@ def __init__(
118120
disable_request_logging=disable_request_logging,
119121
).image_generation
120122

123+
self.classification = Classification(
124+
api_key=api_key,
125+
api_url=api_url,
126+
disable_request_logging=disable_request_logging,
127+
)
128+
121129

122130

123131
class AsyncJigsawStack:
@@ -229,6 +237,12 @@ def __init__(
229237
disable_request_logging=disable_request_logging,
230238
).image_generation
231239

240+
self.classification = AsyncClassification(
241+
api_key=api_key,
242+
api_url=api_url,
243+
disable_request_logging=disable_request_logging,
244+
)
245+
232246

233247

234248
# Create a global instance of the Web class

jigsawstack/classification.py

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
from typing import Any, Dict, List, Union, cast
2+
from typing_extensions import NotRequired, TypedDict, Literal
3+
from .request import Request, RequestConfig
4+
from .async_request import AsyncRequest, AsyncRequestConfig
5+
from ._config import ClientConfig
6+
7+
8+
class DatasetItemText(TypedDict):
9+
type: Literal["text"]
10+
"""
11+
Type of the dataset item: text
12+
"""
13+
14+
value: str
15+
"""
16+
Value of the dataset item
17+
"""
18+
19+
20+
class DatasetItemImage(TypedDict):
21+
type: Literal["image"]
22+
"""
23+
Type of the dataset item: image
24+
"""
25+
26+
value: str
27+
"""
28+
Value of the dataset item
29+
"""
30+
31+
32+
class LabelItemText(TypedDict):
33+
key: NotRequired[str]
34+
"""
35+
Optional key for the label
36+
"""
37+
38+
type: Literal["text"]
39+
"""
40+
Type of the label: text
41+
"""
42+
43+
value: str
44+
"""
45+
Value of the label
46+
"""
47+
48+
49+
class LabelItemImage(TypedDict):
50+
key: NotRequired[str]
51+
"""
52+
Optional key for the label
53+
"""
54+
55+
type: Literal["image", "text"]
56+
"""
57+
Type of the label: image or text
58+
"""
59+
60+
value: str
61+
"""
62+
Value of the label
63+
"""
64+
65+
66+
class ClassificationTextParams(TypedDict):
67+
dataset: List[DatasetItemText]
68+
"""
69+
List of text dataset items to classify
70+
"""
71+
72+
labels: List[LabelItemText]
73+
"""
74+
List of text labels for classification
75+
"""
76+
77+
multiple_labels: NotRequired[bool]
78+
"""
79+
Whether to allow multiple labels per item
80+
"""
81+
82+
83+
class ClassificationImageParams(TypedDict):
84+
dataset: List[DatasetItemImage]
85+
"""
86+
List of image dataset items to classify
87+
"""
88+
89+
labels: List[LabelItemImage]
90+
"""
91+
List of labels for classification
92+
"""
93+
94+
multiple_labels: NotRequired[bool]
95+
"""
96+
Whether to allow multiple labels per item
97+
"""
98+
99+
100+
class ClassificationResponse(TypedDict):
101+
predictions: List[Union[str, List[str]]]
102+
"""
103+
Classification predictions - single labels or multiple labels per item
104+
"""
105+
106+
107+
108+
class Classification(ClientConfig):
109+
110+
config: RequestConfig
111+
112+
def __init__(
113+
self,
114+
api_key: str,
115+
api_url: str,
116+
disable_request_logging: Union[bool, None] = False,
117+
):
118+
super().__init__(api_key, api_url, disable_request_logging)
119+
self.config = RequestConfig(
120+
api_url=api_url,
121+
api_key=api_key,
122+
disable_request_logging=disable_request_logging,
123+
)
124+
125+
def text(self, params: ClassificationTextParams) -> ClassificationResponse:
126+
path = "/classification"
127+
resp = Request(
128+
config=self.config,
129+
path=path,
130+
params=cast(Dict[Any, Any], params),
131+
verb="post",
132+
).perform_with_content()
133+
return resp
134+
def image(self, params: ClassificationImageParams) -> ClassificationResponse:
135+
path = "/classification"
136+
resp = Request(
137+
config=self.config,
138+
path=path,
139+
params=cast(Dict[Any, Any], params),
140+
verb="post",
141+
).perform_with_content()
142+
return resp
143+
144+
145+
146+
class AsyncClassification(ClientConfig):
147+
config: AsyncRequestConfig
148+
149+
def __init__(
150+
self,
151+
api_key: str,
152+
api_url: str,
153+
disable_request_logging: Union[bool, None] = False,
154+
):
155+
super().__init__(api_key, api_url, disable_request_logging)
156+
self.config = AsyncRequestConfig(
157+
api_url=api_url,
158+
api_key=api_key,
159+
disable_request_logging=disable_request_logging,
160+
)
161+
162+
async def text(self, params: ClassificationTextParams) -> ClassificationResponse:
163+
path = "/classification"
164+
resp = await AsyncRequest(
165+
config=self.config,
166+
path=path,
167+
params=cast(Dict[Any, Any], params),
168+
verb="post",
169+
).perform_with_content()
170+
return resp
171+
172+
async def image(self, params: ClassificationImageParams) -> ClassificationResponse:
173+
path = "/classification"
174+
resp = await AsyncRequest(
175+
config=self.config,
176+
path=path,
177+
params=cast(Dict[Any, Any], params),
178+
verb="post",
179+
).perform_with_content()
180+
return resp

tests/test_classification.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
from unittest.mock import MagicMock
2+
import unittest
3+
from jigsawstack.exceptions import JigsawStackError
4+
from jigsawstack import JigsawStack, AsyncJigsawStack
5+
import asyncio
6+
import logging
7+
8+
import pytest
9+
10+
logging.basicConfig(level=logging.INFO)
11+
logger = logging.getLogger(__name__)
12+
13+
client = JigsawStack()
14+
async_client = AsyncJigsawStack()
15+
16+
17+
class TestClassificationAPI(unittest.TestCase):
18+
def test_classification_text_success_response(self):
19+
params = {
20+
"dataset": [
21+
{"type": "text", "value": "Hello"},
22+
{"type": "text", "value": "World"}
23+
],
24+
"labels": [
25+
{"type": "text", "value": "Greeting"},
26+
{"type": "text", "value": "Object"}
27+
]
28+
}
29+
try:
30+
result = client.classification.text(params)
31+
assert result["success"] == True
32+
except JigsawStackError as e:
33+
pytest.fail(f"Unexpected JigsawStackError: {e}")
34+
35+
def test_classification_text_async_success_response(self):
36+
async def _test():
37+
params = {
38+
"dataset": [
39+
{"type": "text", "value": "Hello"},
40+
{"type": "text", "value": "World"}
41+
],
42+
"labels": [
43+
{"type": "text", "value": "Greeting"},
44+
{"type": "text", "value": "Object"}
45+
]
46+
}
47+
try:
48+
result = await async_client.classification.text(params)
49+
assert result["success"] == True
50+
except JigsawStackError as e:
51+
pytest.fail(f"Unexpected JigsawStackError: {e}")
52+
53+
asyncio.run(_test())
54+
55+
def test_classification_text_with_multiple_labels(self):
56+
params = {
57+
"dataset": [
58+
{"type": "text", "value": "This is a positive and happy message"}
59+
],
60+
"labels": [
61+
{"type": "text", "value": "positive"},
62+
{"type": "text", "value": "negative"},
63+
{"type": "text", "value": "happy"},
64+
{"type": "text", "value": "sad"}
65+
],
66+
"multiple_labels": True
67+
}
68+
try:
69+
result = client.classification.text(params)
70+
assert result["success"] == True
71+
except JigsawStackError as e:
72+
pytest.fail(f"Unexpected JigsawStackError: {e}")
73+
74+
def test_classification_image_success_response(self):
75+
params = {
76+
"dataset": [
77+
{"type": "image", "value": "https://example.com/image1.jpg"},
78+
{"type": "image", "value": "https://example.com/image2.jpg"}
79+
],
80+
"labels": [
81+
{"type": "text", "value": "Cat"},
82+
{"type": "text", "value": "Dog"}
83+
]
84+
}
85+
try:
86+
result = client.classification.image(params)
87+
assert result["success"] == True
88+
except JigsawStackError as e:
89+
pytest.fail(f"Unexpected JigsawStackError: {e}")
90+
91+
def test_classification_image_async_success_response(self):
92+
async def _test():
93+
params = {
94+
"dataset": [
95+
{"type": "image", "value": "https://example.com/image1.jpg"},
96+
{"type": "image", "value": "https://example.com/image2.jpg"}
97+
],
98+
"labels": [
99+
{"type": "text", "value": "Cat"},
100+
{"type": "text", "value": "Dog"}
101+
]
102+
}
103+
try:
104+
result = await async_client.classification.image(params)
105+
assert result["success"] == True
106+
except JigsawStackError as e:
107+
pytest.fail(f"Unexpected JigsawStackError: {e}")
108+
109+
asyncio.run(_test())
110+
111+
def test_classification_image_with_multiple_labels(self):
112+
params = {
113+
"dataset": [
114+
{"type": "image", "value": "https://example.com/pet_image.jpg"}
115+
],
116+
"labels": [
117+
{"type": "text", "value": "cute"},
118+
{"type": "text", "value": "fluffy"},
119+
{"type": "text", "value": "animal"},
120+
{"type": "text", "value": "pet"}
121+
],
122+
"multiple_labels": True
123+
}
124+
try:
125+
result = client.classification.image(params)
126+
assert result["success"] == True
127+
except JigsawStackError as e:
128+
pytest.fail(f"Unexpected JigsawStackError: {e}")
129+

0 commit comments

Comments
 (0)