|
19 | 19 | class GraniteioModel:
|
20 | 20 | @staticmethod
|
21 | 21 | def processor_of_block(block: GraniteioModelBlock):
|
22 |
| - model = value_of_expr(block.model) |
| 22 | + from granite_io import make_backend, make_io_processor |
| 23 | + from granite_io.backend.base import Backend |
| 24 | + from granite_io.io import InputOutputProcessor |
| 25 | + |
| 26 | + processor = value_of_expr(block.processor) |
| 27 | + if isinstance(processor, InputOutputProcessor): |
| 28 | + return processor |
| 29 | + model: str = value_of_expr(block.model) |
23 | 30 | backend = value_of_expr(block.backend)
|
24 | 31 | assert isinstance(model, str), f"The model should be a string: {model}"
|
25 |
| - assert isinstance( |
26 |
| - backend, (dict, str) |
27 |
| - ), f"The backend should be a string or a dictionary: {backend}" |
28 | 32 | match backend:
|
29 | 33 | case {"transformers": device}:
|
30 |
| - assert isinstance(backend, dict) |
31 |
| - from granite_io import make_backend |
32 |
| - |
33 | 34 | backend = make_backend(
|
34 | 35 | "transformers",
|
35 | 36 | {
|
36 | 37 | "model_name": model,
|
37 | 38 | "device": device,
|
38 | 39 | },
|
39 | 40 | )
|
40 |
| - case backend_name if isinstance(backend_name, str): |
41 |
| - from granite_io import make_backend |
42 |
| - |
| 41 | + case str(): |
43 | 42 | backend = make_backend(
|
44 |
| - backend_name, |
| 43 | + backend, |
45 | 44 | {
|
46 | 45 | "model_name": model,
|
47 | 46 | },
|
48 | 47 | )
|
| 48 | + case Backend(): |
| 49 | + pass |
49 | 50 | case _:
|
50 | 51 | assert False, f"Unexpected backend: {backend}"
|
51 |
| - if block.processor is None: |
| 52 | + if processor is None: |
52 | 53 | processor_name = model
|
53 | 54 | else:
|
54 |
| - processor_name = value_of_expr(block.processor) |
| 55 | + assert isinstance( |
| 56 | + processor, str |
| 57 | + ), f"The processor should be a string: {processor}" |
| 58 | + processor_name = value_of_expr(processor) |
55 | 59 | assert isinstance(
|
56 | 60 | processor_name, str
|
57 | 61 | ), f"The processor should be a string: {processor_name}"
|
58 |
| - from granite_io import make_io_processor |
59 |
| - |
60 | 62 | io_processor = make_io_processor(processor_name, backend=backend)
|
61 | 63 | return io_processor
|
62 | 64 |
|
@@ -87,7 +89,7 @@ async def async_generate_text(
|
87 | 89 | inputs
|
88 | 90 | )
|
89 | 91 | try: # TODO: update when new version of granite-io is released
|
90 |
| - message = result.next_message.model_dump() |
| 92 | + message = result.next_message.model_dump() # pyright: ignore |
91 | 93 | except AttributeError:
|
92 | 94 | message = result.results[0].next_message.model_dump()
|
93 | 95 | raw_result = result.model_dump()
|
|
0 commit comments