Thanks for visiting codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions src/instructlab/cli/data/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,14 @@ def generate(
# determine student model arch from train section of config and pick system prompt to
# pass to SDG appropriately
system_prompt, legacy_pretraining_format = None, None
if not student_model_id:

student_id = (
student_model_id
if student_model_id
else ctx.obj.config.general.student_model_id
)

if not student_id:
student_model_path = pathlib.Path(ctx.obj.config.train.model_path)
student_model_arch = get_model_arch(student_model_path)
system_prompt = get_sysprompt(student_model_arch)
Expand All @@ -297,12 +304,10 @@ def generate(
)
else:
try:
student_model_config = resolve_model_id(
student_model_id, ctx.obj.config.models
)
student_model_config = resolve_model_id(student_id, ctx.obj.config.models)
if not student_model_config:
raise ValueError(
f"Student model with ID '{student_model_id}' not found in the configuration."
f"Student model with ID '{student_id}' not found in the configuration."
)
except ValueError as ve:
click.secho(f"failed to locate student model by ID: {ve}", fg="red")
Expand Down
4 changes: 4 additions & 0 deletions src/instructlab/cli/model/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,10 @@ def train(
"Please run using either the 'accelerated' or 'full' pipelines."
)

ctx.params["model_id"] = (
model_id if model_id else ctx.obj.config.general.student_model_id
)

# we can use train_args locally to run lower fidelity training
if is_high_fidelity(device=device) and pipeline == "accelerated":
train_args, torch_args = map_train_to_library(ctx, ctx.params)
Expand Down
12 changes: 12 additions & 0 deletions src/instructlab/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,18 @@ class _general(BaseModel):
default=False,
description="Use legacy IBM Granite chat template (default uses 3.0 Instruct template)",
)
# Global student model id
student_model_id: str | None = Field(
Copy link
Member

@RobotSail RobotSail Apr 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Try using typing.NotRequired here, we should see whether this type can allow us to avoid explicitly include this field within the config

Suggested change
student_model_id: str | None = Field(
student_model_id: typing.NotRequired[str] = Field(

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Specifically today all config fields will be rendered in the default on-disk config which is super confusing. I want to see if using this type will prevent this field from being listed there entirely until it's necessary.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think that's compatible with pydantic.

The error "`NotRequired[] can be only used in a TypedDict definition`" is occurring because `typing.NotRequired` is being used in a `pydantic` model, which is not allowed. `NotRequired` is specifically designed for use in `TypedDict` definitions, not in `pydantic` models.

In your code, the problematic line is:


teacher_model_id: typing.NotRequired[str] = Field(
    default_factory=lambda: None,
    description="ID of the teacher model to be used for data generation.",
    exclude=True,
)


---

### Why This is Happening

`pydantic` models use `Optional` or `default_factory` to define optional fields. The `NotRequired` type hint is not compatible with `pydantic` models because it is meant for `TypedDict` definitions.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jaideepr97 I see, thank you for checking this. The way you have done it is perfect

default_factory=lambda: None,
description="ID of the student model to be used for training.",
exclude=True,
)
# Global teacher model id
teacher_model_id: str | None = Field(
default_factory=lambda: None,
description="ID of the teacher model to be used for data generation.",
exclude=True,
)

@field_validator("log_level")
def validate_log_level(cls, v):
Expand Down
10 changes: 9 additions & 1 deletion tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,19 @@
_CFG_FILE_NAME = "test-serve-config.yaml"


def setup_test_models_config(models_list, dest=""):
def setup_test_models_config(
    models_list, dest="", global_student_id=None, global_teacher_id=None
):
    """Write a models config YAML for tests and return its base file name.

    Starts from the library's default config, installs *models_list* under
    ``cfg.models``, and optionally sets the global student/teacher model IDs
    on the ``general`` section before dumping the config to disk.

    Args:
        models_list: model entries to assign to ``cfg.models``.
        dest: directory to write the config file into (defaults to the
            current working directory).
        global_student_id: optional value for ``general.student_model_id``.
        global_teacher_id: optional value for ``general.teacher_model_id``.

    Returns:
        The config file's base name (not the full path) — callers join it
        with their own destination directory.
    """
    _cfg_name = "test-models-config.yaml"
    cfg = configuration.get_default_config()

    cfg.models = models_list

    # Only override the global IDs when explicitly requested so the values
    # from get_default_config() are otherwise left untouched.
    if global_student_id:
        cfg.general.student_model_id = global_student_id
    if global_teacher_id:
        cfg.general.teacher_model_id = global_teacher_id

    # Join with pathlib instead of f"{dest}/{_cfg_name}": with the default
    # dest="" the old string form produced "/test-models-config.yaml"
    # (filesystem root) rather than a file in the current directory.
    with (pathlib.Path(dest) / _cfg_name).open("w", encoding="utf-8") as f:
        yaml.dump(cfg.model_dump(), f)
    return _cfg_name
Expand Down
35 changes: 35 additions & 0 deletions tests/test_lab_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,38 @@ def test_generate_with_teacher_model_id(

assert result.exit_code == 0, f"Command failed with output: {result.output}"
assert mock_gen_data.call_count == 1


@patch("instructlab.cli.data.generate.gen_data")
@patch("instructlab.cli.data.data.storage_dirs_exist", return_value=True)
def test_generate_with_global_teacher_model_id(
    _storage_dirs_exist, mock_gen_data, cli_runner: CliRunner, tmp_path: Path
):
    """`ilab data generate` succeeds when the teacher model is supplied only
    via the global config (``general.teacher_model_id``), with no
    ``--teacher-model-id`` flag on the command line.
    """
    # NOTE: @patch decorators are applied bottom-up, so the bottom patch
    # (storage_dirs_exist) binds to the FIRST mock parameter. The original
    # signature named that parameter `mock_gen_data`, so the call-count
    # assertion below was checking the storage_dirs_exist mock instead of
    # gen_data. The parameters are now bound in the correct order.
    fname = common.setup_test_models_config(
        models_list=[
            model_info(
                id="teacher_model",
                path="teacher/model/path",
                family="llama",
                system_prompt="system prompt",
            )
        ],
        dest=tmp_path,
        global_teacher_id="teacher_model",
    )

    result = cli_runner.invoke(
        lab.ilab,
        [
            f"--config={tmp_path}/{fname}",
            "data",
            "generate",
            "--pipeline",
            "simple",
            "--output-dir",
            str(tmp_path),
        ],
    )

    assert result.exit_code == 0, f"Command failed with output: {result.output}"
    assert mock_gen_data.call_count == 1
6 changes: 6 additions & 0 deletions tests/testdata/default_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,12 @@ general:
# Log level for logging.
# Default: INFO
log_level: INFO
# ID of the student model to be used for training.
# Default: None
student_model_id:
# ID of the teacher model to be used for data generation.
# Default: None
teacher_model_id:
# Use legacy IBM Granite chat template (default uses 3.0 Instruct template)
# Default: False
use_legacy_tmpl: false
Expand Down
Loading