50 changes: 50 additions & 0 deletions .gitignore
@@ -13,3 +13,53 @@ checkpoints
out
wandb
amlt

# Testing
.pytest_cache/
.coverage
htmlcov/
coverage.xml
*.pyc
.tox/
.mypy_cache/
.hypothesis/

# Claude
.claude/*

# Build artifacts
dist/
*.egg-info/
.eggs/

# IDE
.idea/
*.swp
*.swo
*~
.project
.pydevproject
.settings/

# OS
.DS_Store
Thumbs.db

# Virtual environments
venv/
env/
ENV/
.Python

# Jupyter
.ipynb_checkpoints/
*.ipynb_checkpoints

# Logs
*.log
logs/

# Temporary files
*.tmp
*.temp
.cache/
5,817 changes: 5,817 additions & 0 deletions poetry.lock

Large diffs are not rendered by default.

146 changes: 146 additions & 0 deletions pyproject.toml
@@ -0,0 +1,146 @@
[tool.poetry]
name = "samba-model"
version = "0.1.0"
description = "Samba: Simple Hybrid State Space Models for Efficient Unlimited Context Language Modeling"
authors = ["Your Name <[email protected]>"]
readme = "README.md"
packages = [{include = "lit_gpt"}]

[tool.poetry.dependencies]
python = "^3.9"
torch = ">=2.1.0"
lightning = "2.1.2"
jsonargparse = {extras = ["signatures"], version = "*"}
tokenizers = "*"
sentencepiece = "*"
wandb = "*"
torchmetrics = "*"
tensorboard = "*"
zstandard = "*"
pandas = "*"
pyarrow = "*"
huggingface_hub = "*"
einops = "*"
opt_einsum = "*"
packaging = "*"
azureml-mlflow = "*"
# Optional CUDA dependencies - install manually if CUDA is available
# xformers = {version = "*", optional = true}
# flash-attn = {version = "*", optional = true}
# causal-conv1d = {version = "*", optional = true}
# mamba-ssm = {version = "*", optional = true}
lm-eval = "*"

[tool.poetry.group.dev.dependencies]
pytest = "^7.4.0"
pytest-cov = "^4.1.0"
pytest-mock = "^3.11.0"
pytest-xdist = "^3.3.0"
black = "^23.7.0"
isort = "^5.12.0"
flake8 = "^6.1.0"
mypy = "^1.5.0"
pre-commit = "^3.3.3"

[tool.poetry.scripts]
test = "pytest:main"
tests = "pytest:main"

[tool.pytest.ini_options]
minversion = "7.0"
testpaths = ["tests"]
python_files = ["test_*.py", "*_test.py"]
python_classes = ["Test*"]
python_functions = ["test_*"]
addopts = [
    "-ra",
    "--strict-markers",
    "--cov=lit_gpt",
    "--cov-report=term-missing",
    "--cov-report=html",
    "--cov-report=xml",
    # "--cov-fail-under=80",  # Disabled for initial setup validation
    "-v",
    "--tb=short",
    "--maxfail=1",
]
markers = [
    "unit: Unit tests",
    "integration: Integration tests",
    "slow: Slow tests",
]
filterwarnings = [
    "ignore::DeprecationWarning",
    "ignore::PendingDeprecationWarning",
]

[tool.coverage.run]
source = ["lit_gpt"]
omit = [
    "*/tests/*",
    "*/test_*",
    "*/__pycache__/*",
    "*/site-packages/*",
    "*/distutils/*",
    "*/venv/*",
    "*/.venv/*",
]

[tool.coverage.report]
precision = 2
show_missing = true
skip_covered = false
fail_under = 80

[tool.coverage.html]
directory = "htmlcov"

[tool.coverage.xml]
output = "coverage.xml"

[tool.isort]
profile = "black"
line_length = 120
multi_line_output = 3
include_trailing_comma = true
force_grid_wrap = 0
use_parentheses = true
ensure_newline_before_comments = true

[tool.black]
line-length = 120
target-version = ['py39', 'py310', 'py311']
include = '\.pyi?$'
extend-exclude = '''
/(
    # directories
    \.eggs
    | \.git
    | \.hg
    | \.mypy_cache
    | \.tox
    | \.venv
    | build
    | dist
)/
'''

[tool.mypy]
python_version = "3.9"
warn_return_any = true
warn_unused_configs = true
disallow_untyped_defs = false
disallow_incomplete_defs = false
check_untyped_defs = true
disallow_untyped_decorators = false
no_implicit_optional = true
warn_redundant_casts = true
warn_unused_ignores = true
warn_no_return = true
warn_unreachable = true
strict_optional = true
ignore_missing_imports = true

[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
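
The markers registered under [tool.pytest.ini_options] are applied in test modules and then used to select subsets of the suite from the command line. A minimal illustrative sketch (the file name and test bodies are hypothetical, not part of this diff):

# tests/test_markers_example.py -- hypothetical illustration
import pytest

@pytest.mark.unit
def test_config_defaults():
    assert 1 + 1 == 2

@pytest.mark.integration
def test_end_to_end_setup():
    ...

# Select subsets at the command line, e.g.:
#   pytest -m unit           # only unit tests
#   pytest -m integration    # only integration tests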
Empty file added tests/__init__.py
Empty file.
160 changes: 160 additions & 0 deletions tests/conftest.py
@@ -0,0 +1,160 @@
"""Shared pytest fixtures and configuration for all tests."""

import os
import tempfile
from pathlib import Path
from typing import Generator, Dict, Any
import pytest
import torch
from unittest.mock import MagicMock, Mock


@pytest.fixture
def temp_dir() -> Generator[Path, None, None]:
    """Create a temporary directory for test files."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        yield Path(tmp_dir)
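

# Illustrative usage sketch (not part of this diff): a test that receives the
# temp_dir fixture, writes a file into it, and reads it back; the directory is
# removed automatically when the fixture's context manager exits.
def test_temp_dir_roundtrip(temp_dir):
    path = temp_dir / "example.txt"
    path.write_text("hello")
    assert path.read_text() == "hello"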


@pytest.fixture
def mock_config() -> Dict[str, Any]:
    """Provide a mock configuration dictionary for testing."""
    return {
        "model_name": "test_model",
        "block_size": 128,
        "vocab_size": 50257,
        "n_layer": 4,
        "n_head": 4,
        "n_embd": 128,
        "rotary_percentage": 0.25,
        "parallel_residual": True,
        "bias": False,
        "lm_head_bias": False,
        "n_query_groups": 1,
        "shared_attention_norm": False,
        "norm_eps": 1e-5,
        "intermediate_size": None,
        "condense_ratio": 1,
    }


@pytest.fixture
def mock_model_config(mock_config):
    """Create a mock model configuration object."""
    from lit_gpt.config import Config

    return Config(**mock_config)
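

# Illustrative usage sketch (not part of this diff): the fixture chain yields a
# real Config built from the mock dictionary, so attribute access works as on
# any lit_gpt Config (assuming Config accepts these field names).
def test_mock_model_config_fields(mock_model_config):
    assert mock_model_config.n_layer == 4
    assert mock_model_config.block_size == 128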


@pytest.fixture
def device() -> torch.device:
    """Return the appropriate device for testing (CPU, or CUDA if available)."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


@pytest.fixture
def sample_tensor(device) -> torch.Tensor:
    """Create a sample tensor for testing."""
    return torch.randn(2, 10, 128, device=device)


@pytest.fixture
def mock_tokenizer() -> Mock:
    """Create a mock tokenizer for testing."""
    tokenizer = Mock()
    tokenizer.encode = Mock(return_value=[1, 2, 3, 4, 5])
    tokenizer.decode = Mock(return_value="Hello world")
    tokenizer.eos_id = 2
    tokenizer.bos_id = 1
    return tokenizer
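

# Illustrative usage sketch (not part of this diff): code under test can call
# encode/decode on the mock, and the test can assert on how they were called.
def test_mock_tokenizer_roundtrip(mock_tokenizer):
    ids = mock_tokenizer.encode("Hello world")
    assert ids == [1, 2, 3, 4, 5]
    assert mock_tokenizer.decode(ids) == "Hello world"
    mock_tokenizer.encode.assert_called_once_with("Hello world")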


@pytest.fixture
def mock_checkpoint(temp_dir: Path) -> Path:
    """Create a mock checkpoint file for testing."""
    checkpoint_path = temp_dir / "mock_checkpoint.pt"
    checkpoint_data = {
        "model_state_dict": {"layer1.weight": torch.randn(10, 10)},
        "optimizer_state_dict": {"param_groups": [{"lr": 0.001}]},
        "epoch": 5,
        "global_step": 1000,
    }
    torch.save(checkpoint_data, checkpoint_path)
    return checkpoint_path
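

# Illustrative usage sketch (not part of this diff): loading the checkpoint
# back verifies the saved structure without running any real training.
def test_mock_checkpoint_contents(mock_checkpoint):
    state = torch.load(mock_checkpoint)
    assert state["epoch"] == 5
    assert state["model_state_dict"]["layer1.weight"].shape == (10, 10)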


@pytest.fixture
def sample_text_data() -> str:
    """Provide sample text data for testing."""
    return """The quick brown fox jumps over the lazy dog.
This is a sample text for testing purposes.
Machine learning models need diverse training data."""


@pytest.fixture
def mock_wandb(monkeypatch):
    """Mock wandb for testing without actual logging."""
    mock_wandb_module = MagicMock()
    mock_wandb_module.init = MagicMock()
    mock_wandb_module.log = MagicMock()
    mock_wandb_module.finish = MagicMock()
    # monkeypatch.setattr() requires a "module.attr" target, so replace the
    # whole module in sys.modules instead; monkeypatch undoes this after the test.
    monkeypatch.setitem(sys.modules, "wandb", mock_wandb_module)
    return mock_wandb_module
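

# Illustrative usage sketch (not part of this diff): because the fixture swaps
# the module in sys.modules, a fresh `import wandb` inside the test (or inside
# code under test) resolves to the MagicMock, whose calls can be asserted on.
def test_wandb_logging_is_mocked(mock_wandb):
    import wandb

    wandb.log({"loss": 0.5})
    mock_wandb.log.assert_called_once_with({"loss": 0.5})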


@pytest.fixture
def mock_dataset(temp_dir: Path) -> Path:
    """Create a mock dataset file for testing."""
    dataset_path = temp_dir / "mock_dataset.bin"
    # torch.randint does not support a uint16 dtype, so sample as int32 and
    # cast through numpy before writing the raw binary file.
    data = torch.randint(0, 50257, (1000,), dtype=torch.int32)
    data.numpy().astype("uint16").tofile(dataset_path)
    return dataset_path
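

# Illustrative usage sketch (not part of this diff): the raw uint16 file can be
# read back with numpy; 1000 token ids below 50257 were written above.
def test_mock_dataset_readback(mock_dataset):
    import numpy as np

    tokens = np.fromfile(mock_dataset, dtype=np.uint16)
    assert tokens.shape == (1000,)
    assert int(tokens.max()) < 50257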


@pytest.fixture(autouse=True)
def reset_environment():
    """Reset environment variables before each test."""
    # Store original environment
    original_env = os.environ.copy()

    yield

    # Restore original environment
    os.environ.clear()
    os.environ.update(original_env)
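

# Illustrative behaviour sketch (not part of this diff): since the fixture is
# autouse, a variable set inside one test is removed again before the next.
def test_env_var_does_not_leak():
    os.environ["SAMBA_TEST_FLAG"] = "1"  # hypothetical variable name
    assert os.environ["SAMBA_TEST_FLAG"] == "1"
    # reset_environment restores the original environment after this test.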


@pytest.fixture
def capture_stdout(monkeypatch):
    """Capture stdout for testing print statements."""
    from io import StringIO

    buffer = StringIO()
    monkeypatch.setattr("sys.stdout", buffer)
    return buffer
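

# Illustrative usage sketch (not part of this diff): print() writes to the
# patched sys.stdout, so its output can be asserted from the buffer.
def test_print_is_captured(capture_stdout):
    print("training started")
    assert "training started" in capture_stdout.getvalue()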


# Markers for different test types
def pytest_configure(config):
    """Configure pytest with custom markers."""
    config.addinivalue_line("markers", "unit: Unit tests")
    config.addinivalue_line("markers", "integration: Integration tests")
    config.addinivalue_line("markers", "slow: Slow tests")


# Skip slow tests by default unless --runslow is provided
def pytest_addoption(parser):
    """Add custom command line options."""
    parser.addoption(
        "--runslow", action="store_true", default=False, help="run slow tests"
    )


def pytest_collection_modifyitems(config, items):
    """Modify test collection to handle slow tests."""
    if config.getoption("--runslow"):
        # --runslow given in cli: do not skip slow tests
        return
    skip_slow = pytest.mark.skip(reason="need --runslow option to run")
    for item in items:
        if "slow" in item.keywords:
            item.add_marker(skip_slow)
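
With the two hooks above in place, a test marked slow is skipped by default and runs only when --runslow is passed. A hypothetical illustration (the file name and test body are not part of this diff):

# tests/test_slow_example.py -- hypothetical illustration
import pytest

@pytest.mark.slow
def test_full_training_step():
    ...  # expensive work goes here

# pytest             -> SKIPPED (need --runslow option to run)
# pytest --runslow   -> the test executes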
Empty file added tests/integration/__init__.py
Empty file.