-
Notifications
You must be signed in to change notification settings - Fork 83
Open
Labels
Label: enhancement — New feature or request
Milestone
Description
Currently, we require the user to do this themselves:
class MyGPT2(nn.Module):
    """GPT-2 with a sequence-classification head, adapted to a dict-based batch.

    Wraps ``GPT2ForSequenceClassification`` so ``forward`` can take the raw
    batch mapping produced by the data pipeline instead of keyword tensors.
    """

    def __init__(self, tokenizer: PreTrainedTokenizer) -> None:
        super().__init__()
        config = GPT2Config.from_pretrained("gpt2")
        # Propagate the tokenizer's pad token so padded batches are scored
        # correctly by the classification head.
        config.pad_token_id = tokenizer.pad_token_id
        config.num_labels = 2  # two-class classification head
        self.hf_model = GPT2ForSequenceClassification.from_pretrained(
            "gpt2", config=config
        )

    def forward(self, data: MutableMapping) -> torch.Tensor:
        # Move inputs to whatever device the model parameters currently live on.
        device = next(self.parameters()).device
        # assumes data carries "input_ids" and "attention_mask" tensors — TODO confirm
        input_ids = data["input_ids"].to(device)
        attn_mask = data["attention_mask"].to(device)
        output_dict = self.hf_model(input_ids=input_ids, attention_mask=attn_mask)
return output_dict.logits

Can we provide a generic wrapper? I suspect the use case has very little variance. Maybe something like:
def huggingface_wrapper(hf_model):
    """Wrap a Hugging Face model so it can be called with a batch dict.

    Args:
        hf_model: Any Hugging Face ``PreTrainedModel`` (or compatible
            ``nn.Module``) whose output object exposes a ``.logits``
            attribute.

    Returns:
        An ``nn.Module`` whose ``forward(data)`` unpacks ``data`` as keyword
        arguments into ``hf_model`` and returns the logits tensor.
    """

    class WrappedModel(nn.Module):
        # Annotation is a string: the correct class name is ``PreTrainedModel``
        # (original had a typo), and quoting keeps it lazy so no import of
        # transformers is required at class-definition time.
        def __init__(self, hf_model: "PreTrainedModel") -> None:
            # Bug fix: nn.Module.__init__ must run before any submodule is
            # assigned, otherwise PyTorch raises
            # "cannot assign module before Module.__init__() call".
            super().__init__()
            self.hf_model = hf_model

        def forward(self, data):
            output_dict = self.hf_model(**data)
            return output_dict.logits

    return WrappedModel(hf_model)
Metadata
Assignees
Labels
Label: enhancement — New feature or request