Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit dbd5b54

Browse files
Merge pull request #53 from JigsawStack/feat/new-classification
Feat/new classification
2 parents 33fd9b1 + f84591d commit dbd5b54

File tree

5 files changed

+359
-1
lines changed

5 files changed

+359
-1
lines changed

.gitignore

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,11 @@ test.py
2121
test_web.py
2222

2323
.eggs/
24-
.conda/
24+
.conda/
25+
26+
main.py
27+
.python-version
28+
pyproject.toml
29+
uv.lock
30+
31+
.ruff_cache/

jigsawstack/__init__.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .embedding import Embedding, AsyncEmbedding
1616
from .exceptions import JigsawStackError
1717
from .image_generation import ImageGeneration, AsyncImageGeneration
18+
from .classification import Classification, AsyncClassification
1819

1920

2021
class JigsawStack:
@@ -25,6 +26,7 @@ class JigsawStack:
2526
web: Web
2627
search: Search
2728
prompt_engine: PromptEngine
29+
classification: Classification
2830
api_key: str
2931
api_url: str
3032
disable_request_logging: bool
@@ -118,6 +120,12 @@ def __init__(
118120
disable_request_logging=disable_request_logging,
119121
).image_generation
120122

123+
self.classification = Classification(
124+
api_key=api_key,
125+
api_url=api_url,
126+
disable_request_logging=disable_request_logging,
127+
)
128+
121129

122130

123131
class AsyncJigsawStack:
@@ -229,6 +237,12 @@ def __init__(
229237
disable_request_logging=disable_request_logging,
230238
).image_generation
231239

240+
self.classification = AsyncClassification(
241+
api_key=api_key,
242+
api_url=api_url,
243+
disable_request_logging=disable_request_logging,
244+
)
245+
232246

233247

234248
# Create a global instance of the Web class

jigsawstack/classification.py

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
from typing import Any, Dict, List, Union, cast
2+
from typing_extensions import NotRequired, TypedDict, Literal
3+
from .request import Request, RequestConfig
4+
from .async_request import AsyncRequest, AsyncRequestConfig
5+
from ._config import ClientConfig
6+
7+
8+
class DatasetItemText(TypedDict):
9+
type: Literal["text"]
10+
"""
11+
Type of the dataset item: text
12+
"""
13+
14+
value: str
15+
"""
16+
Value of the dataset item
17+
"""
18+
19+
20+
class DatasetItemImage(TypedDict):
21+
type: Literal["image"]
22+
"""
23+
Type of the dataset item: image
24+
"""
25+
26+
value: str
27+
"""
28+
Value of the dataset item
29+
"""
30+
31+
32+
class LabelItemText(TypedDict):
33+
key: NotRequired[str]
34+
"""
35+
Optional key for the label
36+
"""
37+
38+
type: Literal["text"]
39+
"""
40+
Type of the label: text
41+
"""
42+
43+
value: str
44+
"""
45+
Value of the label
46+
"""
47+
48+
49+
class LabelItemImage(TypedDict):
50+
key: NotRequired[str]
51+
"""
52+
Optional key for the label
53+
"""
54+
55+
type: Literal["image", "text"]
56+
"""
57+
Type of the label: image or text
58+
"""
59+
60+
value: str
61+
"""
62+
Value of the label
63+
"""
64+
65+
66+
class ClassificationTextParams(TypedDict):
67+
dataset: List[DatasetItemText]
68+
"""
69+
List of text dataset items to classify
70+
"""
71+
72+
labels: List[LabelItemText]
73+
"""
74+
List of text labels for classification
75+
"""
76+
77+
multiple_labels: NotRequired[bool]
78+
"""
79+
Whether to allow multiple labels per item
80+
"""
81+
82+
83+
class ClassificationImageParams(TypedDict):
84+
dataset: List[DatasetItemImage]
85+
"""
86+
List of image dataset items to classify
87+
"""
88+
89+
labels: List[LabelItemImage]
90+
"""
91+
List of labels for classification
92+
"""
93+
94+
multiple_labels: NotRequired[bool]
95+
"""
96+
Whether to allow multiple labels per item
97+
"""
98+
99+
100+
class ClassificationResponse(TypedDict):
101+
predictions: List[Union[str, List[str]]]
102+
"""
103+
Classification predictions - single labels or multiple labels per item
104+
"""
105+
106+
107+
108+
class Classification(ClientConfig):
109+
110+
config: RequestConfig
111+
112+
def __init__(
113+
self,
114+
api_key: str,
115+
api_url: str,
116+
disable_request_logging: Union[bool, None] = False,
117+
):
118+
super().__init__(api_key, api_url, disable_request_logging)
119+
self.config = RequestConfig(
120+
api_url=api_url,
121+
api_key=api_key,
122+
disable_request_logging=disable_request_logging,
123+
)
124+
125+
def text(self, params: ClassificationTextParams) -> ClassificationResponse:
126+
path = "/classification"
127+
resp = Request(
128+
config=self.config,
129+
path=path,
130+
params=cast(Dict[Any, Any], params),
131+
verb="post",
132+
).perform_with_content()
133+
return resp
134+
def image(self, params: ClassificationImageParams) -> ClassificationResponse:
135+
path = "/classification"
136+
resp = Request(
137+
config=self.config,
138+
path=path,
139+
params=cast(Dict[Any, Any], params),
140+
verb="post",
141+
).perform_with_content()
142+
return resp
143+
144+
145+
146+
class AsyncClassification(ClientConfig):
147+
config: AsyncRequestConfig
148+
149+
def __init__(
150+
self,
151+
api_key: str,
152+
api_url: str,
153+
disable_request_logging: Union[bool, None] = False,
154+
):
155+
super().__init__(api_key, api_url, disable_request_logging)
156+
self.config = AsyncRequestConfig(
157+
api_url=api_url,
158+
api_key=api_key,
159+
disable_request_logging=disable_request_logging,
160+
)
161+
162+
async def text(self, params: ClassificationTextParams) -> ClassificationResponse:
163+
path = "/classification"
164+
resp = await AsyncRequest(
165+
config=self.config,
166+
path=path,
167+
params=cast(Dict[Any, Any], params),
168+
verb="post",
169+
).perform_with_content()
170+
return resp
171+
172+
async def image(self, params: ClassificationImageParams) -> ClassificationResponse:
173+
path = "/classification"
174+
resp = await AsyncRequest(
175+
config=self.config,
176+
path=path,
177+
params=cast(Dict[Any, Any], params),
178+
verb="post",
179+
).perform_with_content()
180+
return resp

tests/test_classification.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
from jigsawstack.exceptions import JigsawStackError
2+
from jigsawstack import JigsawStack
3+
4+
import pytest
5+
6+
# flake8: noqa
7+
8+
client = JigsawStack()
9+
10+
11+
@pytest.mark.parametrize("dataset,labels", [
12+
(
13+
[
14+
{"type": "text", "value": "I love programming"},
15+
{"type": "text", "value": "I love reading books"},
16+
{"type": "text", "value": "I love watching movies"},
17+
{"type": "text", "value": "I love playing games"},
18+
],
19+
[
20+
{"type": "text", "value": "programming"},
21+
{"type": "text", "value": "reading"},
22+
{"type": "text", "value": "watching"},
23+
{"type": "text", "value": "playing"},
24+
]
25+
),
26+
(
27+
[
28+
{"type": "text", "value": "This is awesome!"},
29+
{"type": "text", "value": "I hate this product"},
30+
{"type": "text", "value": "It's okay, nothing special"},
31+
],
32+
[
33+
{"type": "text", "value": "positive"},
34+
{"type": "text", "value": "negative"},
35+
{"type": "text", "value": "neutral"},
36+
]
37+
),
38+
(
39+
[
40+
{"type": "text", "value": "The weather is sunny today"},
41+
{"type": "text", "value": "It's raining heavily outside"},
42+
{"type": "text", "value": "Snow is falling gently"},
43+
],
44+
[
45+
{"type": "text", "value": "sunny"},
46+
{"type": "text", "value": "rainy"},
47+
{"type": "text", "value": "snowy"},
48+
]
49+
),
50+
])
51+
def test_classification_text_success_response(dataset, labels) -> None:
52+
params = {
53+
"dataset": dataset,
54+
"labels": labels,
55+
}
56+
try:
57+
result = client.classification.text(params)
58+
print(result)
59+
assert result["success"] == True
60+
except JigsawStackError as e:
61+
print(str(e))
62+
assert e.message == "Failed to parse API response. Please try again."
63+
64+
65+
@pytest.mark.parametrize("dataset,labels", [
66+
(
67+
[
68+
{"type": "image", "value": "https://as2.ftcdn.net/v2/jpg/02/24/11/57/1000_F_224115780_2ssvcCoTfQrx68Qsl5NxtVIDFWKtAgq2.jpg"},
69+
{"type": "image", "value": "https://t3.ftcdn.net/jpg/02/95/44/22/240_F_295442295_OXsXOmLmqBUfZreTnGo9PREuAPSLQhff.jpg"},
70+
{"type": "image", "value": "https://as1.ftcdn.net/v2/jpg/05/54/94/46/1000_F_554944613_okdr3fBwcE9kTOgbLp4BrtVi8zcKFWdP.jpg"},
71+
],
72+
[
73+
{"type": "text", "value": "banana"},
74+
{"type": "image", "value": "https://upload.wikimedia.org/wikipedia/commons/8/8a/Banana-Single.jpg"},
75+
{"type": "text", "value": "kisses"},
76+
]
77+
),
78+
])
79+
def test_classification_image_success_response(dataset, labels) -> None:
80+
params = {
81+
"dataset": dataset,
82+
"labels": labels,
83+
}
84+
try:
85+
result = client.classification.image(params)
86+
print(result)
87+
assert result["success"] == True
88+
except JigsawStackError as e:
89+
print(str(e))
90+
assert e.message == "Failed to parse API response. Please try again."

tests/test_file_store.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
from unittest.mock import MagicMock
2+
import unittest
3+
from jigsawstack.exceptions import JigsawStackError
4+
from jigsawstack import JigsawStack
5+
6+
import pytest
7+
8+
# flake8: noqa
9+
10+
client = JigsawStack()
11+
12+
13+
@pytest.mark.skip(reason="Skipping TestStoreAPI class for now")
14+
class TestStoreAPI(unittest.TestCase):
15+
def test_upload_success_response(self) -> None:
16+
# Sample file content as bytes
17+
file_content = b"This is a test file content"
18+
options = {
19+
"key": "test-file.txt",
20+
"content_type": "text/plain",
21+
"overwrite": True,
22+
"temp_public_url": True
23+
}
24+
try:
25+
result = client.store.upload(file_content, options)
26+
assert result["success"] == True
27+
except JigsawStackError as e:
28+
assert e.message == "Failed to parse API response. Please try again."
29+
30+
def test_get_success_response(self) -> None:
31+
key = "test-file.txt"
32+
try:
33+
result = client.store.get(key)
34+
# For file retrieval, we expect the actual file content
35+
assert result is not None
36+
except JigsawStackError as e:
37+
assert e.message == "Failed to parse API response. Please try again."
38+
39+
def test_delete_success_response(self) -> None:
40+
key = "test-file.txt"
41+
try:
42+
result = client.store.delete(key)
43+
assert result["success"] == True
44+
except JigsawStackError as e:
45+
assert e.message == "Failed to parse API response. Please try again."
46+
47+
def test_upload_without_options_success_response(self) -> None:
48+
# Test upload without optional parameters
49+
file_content = b"This is another test file content"
50+
try:
51+
result = client.store.upload(file_content)
52+
assert result["success"] == True
53+
except JigsawStackError as e:
54+
assert e.message == "Failed to parse API response. Please try again."
55+
56+
def test_upload_with_partial_options_success_response(self) -> None:
57+
# Test upload with partial options
58+
file_content = b"This is a test file with partial options"
59+
options = {
60+
"key": "partial-test-file.txt",
61+
"overwrite": False
62+
}
63+
try:
64+
result = client.store.upload(file_content, options)
65+
assert result["success"] == True
66+
except JigsawStackError as e:
67+
assert e.message == "Failed to parse API response. Please try again."

0 commit comments

Comments
 (0)