issue in training #135

@Ashbajawed

Description

Can anyone here help me with the training code? I'm using litgpt and getting the following error:

AssertionError: Logits size 15456 does not match labels size 4096

Code snapshot:
```python
import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader
from tqdm import tqdm


def collate_fn(batch, text_tokenizer, max_length=512):
    """Collate function to prepare a batch for training with custom input ID generation."""
    inputs, labels = zip(*batch)

    # Use get_input_ids_TT for tokenizing and creating input ids
    input_encodings = [get_input_ids_TT(text, text_tokenizer) for text in inputs]

    # Take the id tensor from each encoding and find the longest input
    input_ids = [encoding[-1] for encoding in input_encodings]
    max_input_length = max(input_id.shape[1] for input_id in input_ids)

    # Right-pad every input sequence to the batch maximum
    padded_input_ids = [
        torch.cat([input_id, torch.zeros(1, max_input_length - input_id.shape[1], dtype=torch.long)], dim=1)
        if input_id.shape[1] < max_input_length else input_id
        for input_id in input_ids
    ]

    # Stack the padded input_ids into a batch tensor
    input_ids_tensor = torch.cat(padded_input_ids, dim=0)

    # Tokenize the labels the same way
    label_encodings = [get_input_ids_TT(text, text_tokenizer) for text in labels]
    label_ids = [encoding[-1] for encoding in label_encodings]
    max_label_length = max(label_id.shape[1] for label_id in label_ids)

    # Right-pad every label sequence to the batch maximum
    padded_label_ids = [
        torch.cat([label_id, torch.zeros(1, max_label_length - label_id.shape[1], dtype=torch.long)], dim=1)
        if label_id.shape[1] < max_label_length else label_id
        for label_id in label_ids
    ]

    # Stack the padded label_ids into a batch tensor
    labels_tensor = torch.cat(padded_label_ids, dim=0)

    # Batch dict with padded inputs and labels
    return {
        'input_ids': input_ids_tensor,
        'labels': labels_tensor,
    }
```
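
Note that max_input_length and max_label_length are computed independently, so input_ids and labels in the same batch can end up with different sequence lengths whenever the inputs and labels tokenize to different numbers of tokens. A tiny standalone sketch of that effect with made-up lengths (no tokenizer or model involved; pad_right is just an inline helper for the sketch):

```python
import torch

# Pretend the tokenizer produced these lengths for one batch of two examples:
# inputs of 7 and 5 tokens, labels of 3 and 4 tokens (made-up numbers).
input_ids = [torch.ones(1, 7, dtype=torch.long), torch.ones(1, 5, dtype=torch.long)]
label_ids = [torch.ones(1, 3, dtype=torch.long), torch.ones(1, 4, dtype=torch.long)]

max_input_length = max(t.shape[1] for t in input_ids)  # 7
max_label_length = max(t.shape[1] for t in label_ids)  # 4

def pad_right(t, length):
    return torch.cat([t, torch.zeros(1, length - t.shape[1], dtype=torch.long)], dim=1)

input_ids_tensor = torch.cat([pad_right(t, max_input_length) for t in input_ids], dim=0)
labels_tensor = torch.cat([pad_right(t, max_label_length) for t in label_ids], dim=0)

print(input_ids_tensor.shape, labels_tensor.shape)  # torch.Size([2, 7]) torch.Size([2, 4])
```

Once those two tensors are flattened later in _forward, 2 * 7 logits get compared against 2 * 4 labels, which is exactly the kind of mismatch the assertion reports.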

Prepare the DataLoaders:

```python
train_loader = DataLoader(train_examples, batch_size=128, shuffle=True,
                          collate_fn=lambda x: collate_fn(x, text_tokenizer))
val_loader = DataLoader(val_examples, batch_size=128, shuffle=True,
                        collate_fn=lambda x: collate_fn(x, text_tokenizer))
```
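
A quick way to see the padded lengths before any model call is to pull one batch from the loader and print both shapes (this only assumes the loaders above have been built):

```python
# Compare the padded sequence lengths of inputs vs labels for one batch.
batch = next(iter(train_loader))
print("input_ids:", batch['input_ids'].shape)  # [batch_size, max_input_length]
print("labels:   ", batch['labels'].shape)     # [batch_size, max_label_length]
```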

```python
def _forward(batch):
    input_ids = batch['input_ids'].to(device)
    input_ids = tuple(input_ids.chunk(8, dim=0))  # Chunking if necessary

    labels = batch['labels'].to(device)
    audio_features = None  # Audio features are None; change this if you use audio features

    optimizer.zero_grad()

    # Pass input_ids and audio_features to the model
    outputs = model(input_ids=input_ids, audio_features=audio_features)

    # Assuming the model output is a tuple (logits, ...)
    logits = outputs[0]

    # If the model returns a list of logits (one per chunk), combine them
    if isinstance(logits, list):
        logits = torch.cat(logits, dim=0)

    # Get the batch size and sequence length
    batch_size = logits.size(0)
    seq_length = logits.size(1)

    # Flatten the batch and sequence dimensions of the logits
    logits = logits.view(-1, logits.size(-1))  # Shape: [batch_size * seq_length, vocab_size]

    # Flatten the labels the same way
    labels = labels.view(-1)  # Shape: [batch_size * seq_length]

    # Ensure the number of logits matches the number of labels
    assert logits.size(0) == labels.size(0), f"Logits size {logits.size(0)} does not match labels size {labels.size(0)}"

    return logits, labels
```
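
For the loss call in train_model below, torch.nn.CrossEntropyLoss expects logits of shape [N, num_classes] and integer targets of shape [N] with the same N, which is what the assert above is checking after flattening. A small self-contained illustration (toy sizes, unrelated to the real vocab):

```python
import torch

criterion = torch.nn.CrossEntropyLoss()
vocab_size = 10

logits = torch.randn(2 * 7, vocab_size)                 # e.g. batch 2, seq length 7, flattened
targets = torch.randint(0, vocab_size, (2 * 7,))
print(criterion(logits, targets))                       # fine: both sides have N = 14

short_targets = torch.randint(0, vocab_size, (2 * 4,))  # e.g. labels padded to length 4 instead
# criterion(logits, short_targets)                      # raises: 14 logits vs 8 targets
```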

```python
def train_model(model, train_loader, val_loader, optimizer, device, epochs=3, save_path="litgpt_model.pth"):
    """Train the model using the specified parameters."""
    model.train()
    criterion = torch.nn.CrossEntropyLoss()

    for epoch in range(epochs):
        print(f"Epoch {epoch + 1}/{epochs}")
        total_loss = 0

        # Training loop
        for batch in tqdm(train_loader, desc="Training"):
            logits, labels = _forward(batch)

            # Make sure the shapes match
            print(f"Logits shape: {logits.shape}, Labels shape: {labels.shape}")

            # Calculate the loss
            loss = criterion(logits, labels)

            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_train_loss = total_loss / len(train_loader)
        print(f"Training Loss: {avg_train_loss:.4f}")

        # Validation loop
        model.eval()
        total_val_loss = 0
        with torch.no_grad():
            for batch in tqdm(val_loader, desc="Validation"):
                logits, labels = _forward(batch)
                val_loss = criterion(logits, labels)
                total_val_loss += val_loss.item()

        avg_val_loss = total_val_loss / len(val_loader)
        print(f"Validation Loss: {avg_val_loss:.4f}")

    # Save model checkpoint
    torch.save(model.state_dict(), save_path)
    print(f"Model saved to {save_path}")
```

Initialize optimizer

```python
optimizer = AdamW(model.parameters(), lr=5e-5)
```

Train the model

```python
train_model(model, train_loader, val_loader, optimizer, device=device,
            epochs=2, save_path="checkpoints/litgpt_trained_model.pth")
```
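
In case it helps narrow things down, here is a minimal sketch of one way to keep the two tensors the same length: pad inputs and labels to a single shared maximum inside the collate. This assumes the labels are meant to line up position-by-position with input_ids (which depends on how get_input_ids_TT builds the sequences); collate_fn_shared_length and pad_to are hypothetical names for the sketch, not litgpt APIs:

```python
import torch

def pad_to(t, length, pad_id=0):
    """Right-pad a [1, L] tensor of token ids out to [1, length]."""
    return torch.cat([t, torch.full((1, length - t.shape[1]), pad_id, dtype=torch.long)], dim=1)

def collate_fn_shared_length(batch, text_tokenizer):
    inputs, labels = zip(*batch)
    input_ids = [get_input_ids_TT(text, text_tokenizer)[-1] for text in inputs]
    label_ids = [get_input_ids_TT(text, text_tokenizer)[-1] for text in labels]

    # One maximum over BOTH inputs and labels, so the padded tensors line up.
    max_len = max(t.shape[1] for t in input_ids + label_ids)

    return {
        'input_ids': torch.cat([pad_to(t, max_len) for t in input_ids], dim=0),
        # -100 is CrossEntropyLoss's default ignore_index, so padded label
        # positions are excluded from the loss.
        'labels': torch.cat([pad_to(t, max_len, pad_id=-100) for t in label_ids], dim=0),
    }
```

Whether this is the right alignment for the task is a separate question; if the model is trained as a causal LM on a single concatenated sequence, the labels are usually a shifted copy of the input ids instead.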
