Description
Can anyone here help me with the training code? I'm using litgpt and getting the following error:

`AssertionError: Logits size 15456 does not match labels size 4096`

Code snapshot:
```python
import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
from tqdm import tqdm


def collate_fn(batch, text_tokenizer, max_length=512):
    """Collate function to prepare a batch for training with custom input ID generation."""
    inputs, labels = zip(*batch)

    # Use get_input_ids_TT for tokenizing and creating input ids
    input_encodings = [get_input_ids_TT(text, text_tokenizer) for text in inputs]
    input_ids = [encoding[-1] for encoding in input_encodings]
    max_input_length = max(input_id.shape[1] for input_id in input_ids)  # Longest input in the batch

    # Pad the input sequences so they all match the longest input
    padded_input_ids = [
        torch.cat([input_id, torch.zeros(1, max_input_length - input_id.shape[1], dtype=torch.long)], dim=1)
        if input_id.shape[1] < max_input_length else input_id
        for input_id in input_ids
    ]
    # Stack the padded input_ids into a batch tensor
    input_ids_tensor = torch.cat(padded_input_ids, dim=0)

    # Tokenize labels with the same helper
    label_encodings = [get_input_ids_TT(text, text_tokenizer) for text in labels]
    label_ids = [encoding[-1] for encoding in label_encodings]
    max_label_length = max(label_id.shape[1] for label_id in label_ids)  # Longest label in the batch

    # Pad the label sequences so they all match the longest label
    padded_label_ids = [
        torch.cat([label_id, torch.zeros(1, max_label_length - label_id.shape[1], dtype=torch.long)], dim=1)
        if label_id.shape[1] < max_label_length else label_id
        for label_id in label_ids
    ]
    # Stack the padded label_ids into a batch tensor
    labels_tensor = torch.cat(padded_label_ids, dim=0)

    # Return the padded, batched tensors
    return {
        'input_ids': input_ids_tensor,
        'labels': labels_tensor,
    }

# Prepare DataLoaders
train_loader = DataLoader(train_examples, batch_size=128, shuffle=True,
                          collate_fn=lambda x: collate_fn(x, text_tokenizer))
val_loader = DataLoader(val_examples, batch_size=128, shuffle=True,
                        collate_fn=lambda x: collate_fn(x, text_tokenizer))

def _forward(batch):
    input_ids = batch['input_ids'].to(device)
    input_ids = tuple(input_ids.chunk(8, dim=0))  # Chunking if necessary
    labels = batch['labels'].to(device)
    audio_features = None  # Audio features are None; change if you use audio features

    optimizer.zero_grad()
    # Pass input_ids and audio_features to the model
    outputs = model(input_ids=input_ids, audio_features=audio_features)

    # Assuming the model output is a tuple (logits, ...)
    logits = outputs[0]
    # If the model returns a list of logits, concatenate them
    if isinstance(logits, list):
        logits = torch.cat(logits, dim=0)

    # Get the batch size and sequence length
    batch_size = logits.size(0)
    seq_length = logits.size(1)

    # Flatten the batch and sequence dimensions
    logits = logits.view(-1, logits.size(-1))  # Shape: [batch_size * seq_length, vocab_size]
    labels = labels.view(-1)                   # Shape: [batch_size * seq_length]

    # Ensure the number of logits matches the number of labels
    assert logits.size(0) == labels.size(0), \
        f"Logits size {logits.size(0)} does not match labels size {labels.size(0)}"
    return logits, labels

def train_model(model, train_loader, val_loader, optimizer, device, epochs=3, save_path="litgpt_model.pth"):
    """Train the model using the specified parameters."""
    criterion = torch.nn.CrossEntropyLoss()

    for epoch in range(epochs):
        print(f"Epoch {epoch + 1}/{epochs}")
        model.train()  # Switch back to training mode at the start of each epoch
        total_loss = 0

        # Training loop
        for batch in tqdm(train_loader, desc="Training"):
            logits, labels = _forward(batch)
            # Make sure the shapes match
            print(f"Logits shape: {logits.shape}, Labels shape: {labels.shape}")
            # Calculate the loss
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        avg_train_loss = total_loss / len(train_loader)
        print(f"Training Loss: {avg_train_loss:.4f}")

        # Validation loop
        model.eval()
        total_val_loss = 0
        with torch.no_grad():
            for batch in tqdm(val_loader, desc="Validation"):
                logits, labels = _forward(batch)
                val_loss = criterion(logits, labels)
                total_val_loss += val_loss.item()

        avg_val_loss = total_val_loss / len(val_loader)
        print(f"Validation Loss: {avg_val_loss:.4f}")

    # Save model checkpoint
    torch.save(model.state_dict(), save_path)
    print(f"Model saved to {save_path}")

# Initialize optimizer
optimizer = AdamW(model.parameters(), lr=5e-5)

# Train the model
train_model(model, train_loader, val_loader, optimizer, device=device, epochs=2,
            save_path="checkpoints/litgpt_trained_model.pth")
```
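
For context, here is a minimal sketch of what I have been considering as an alternative collate function, in case the mismatch comes from inputs and labels being padded to different per-batch lengths (that is only my assumption, not a confirmed diagnosis; `get_input_ids_TT` and `text_tokenizer` are my own helpers from the snapshot above):

```python
import torch

def collate_fn_shared(batch, text_tokenizer):
    """Sketch: pad inputs and labels to one shared max length so logits and labels align."""
    inputs, labels = zip(*batch)
    input_ids = [get_input_ids_TT(text, text_tokenizer)[-1] for text in inputs]
    label_ids = [get_input_ids_TT(text, text_tokenizer)[-1] for text in labels]

    # One shared max length for both inputs and labels
    max_len = max(t.shape[1] for t in input_ids + label_ids)

    def pad_to(t, length):
        # Right-pad a (1, seq_len) tensor with zeros up to the shared length
        if t.shape[1] < length:
            pad = torch.zeros(1, length - t.shape[1], dtype=torch.long)
            return torch.cat([t, pad], dim=1)
        return t

    return {
        'input_ids': torch.cat([pad_to(t, max_len) for t in input_ids], dim=0),
        'labels': torch.cat([pad_to(t, max_len) for t in label_ids], dim=0),
    }
```

I am not sure whether this is the right direction, or whether the chunking of `input_ids` in `_forward` is what breaks the logits/labels alignment, so any pointers would be appreciated.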