[TextGeneration] max token refactor #1217
dbogunowicz
left a comment
The good: much less additional code and complexity than I thought
The bad and ugly: could you add appropriate tests in tests/deepsparse/transformers/pipelines/test_text_generation.py?
d1d7b7a to 735d33d
mgoin
left a comment
Could max_tokens default to sequence_length - prompt_length so we don't risk running out of kv cache context? I'm not sure what actually happens in that case, especially when using the internal kv cache.
@dbogunowicz would like to get your opinion on this
@dsikka this is a very good idea.
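The default suggested above amounts to capping generation at whatever context is left once the prompt has been consumed. A minimal sketch of that arithmetic, assuming sequence_length is the total kv cache capacity in tokens and prompt_length is the number of prompt tokens; the helper name remaining_token_budget is hypothetical and is not part of deepsparse:

```python
from typing import Optional


def remaining_token_budget(
    sequence_length: int,
    prompt_length: int,
    max_tokens: Optional[int] = None,
) -> int:
    """Hypothetical helper: number of tokens that can still be generated.

    Assumes the prompt and the generated tokens share one context of
    `sequence_length` tokens.
    """
    if prompt_length > sequence_length:
        raise ValueError("prompt does not fit in the available context")
    budget = sequence_length - prompt_length
    # No explicit limit: fall back to the remaining context.
    if max_tokens is None:
        return budget
    # Explicit limit: never exceed the remaining context.
    return min(max_tokens, budget)
```

With this shape, a caller who passes no limit gets the remaining context, and an explicit max_tokens larger than the remaining context is clamped rather than overrunning the cache.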
- Remove max_generated_tokens from the constructor and add it to the TextGenerationInput Schema
- Add num_generated_predictions to the TextGenerationInput which if > 1, repeats the input sequence and turns off deterministic prediction. If a sequence is already provided multiple times, the sequence is not repeated.
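To illustrate what moving the limit from the constructor to the input means in practice, here is a rough sketch of a per-call input object. It uses a plain dataclass as a stand-in; the class name TextGenerationInputSketch, the field name max_tokens, and the validation rules are assumptions for illustration and do not reproduce the actual TextGenerationInput schema:

```python
from dataclasses import dataclass
from typing import List, Optional, Union


@dataclass
class TextGenerationInputSketch:
    """Illustrative stand-in for a per-call input schema (not the real class)."""

    sequences: Union[str, List[str]]
    # Per-call generation limit; None means "use whatever default applies".
    max_tokens: Optional[int] = None
    # Number of sequences to generate for each input.
    num_generated_predictions: int = 1

    def __post_init__(self) -> None:
        # Validation rules here are assumptions made for this sketch.
        if self.max_tokens is not None and self.max_tokens < 1:
            raise ValueError("max_tokens must be a positive integer")
        if self.num_generated_predictions < 1:
            raise ValueError("num_generated_predictions must be >= 1")


# The same pipeline object could then serve calls with different limits.
short_call = TextGenerationInputSketch(sequences="Hello", max_tokens=16)
long_call = TextGenerationInputSketch(sequences="Hello", max_tokens=256)
```

Keeping the limit on the input rather than the constructor means one pipeline instance does not have to be rebuilt to change how much text a given call generates.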
55d1c08 to b5de75f
Talking to the MLE team, I think for now we want to keep the defaults as is and update them once we've established best practices.
For this ticket:
https://app.asana.com/0/1201735099598270/1205276886236972/f
Summary:
- Removes the max_tokens argument and makes it part of the pipeline input
- Adds num_generated_predictions to the input as well, which dictates the number of sequences that are generated for a given input. Similar to the Hugging Face implementation, we repeat the input based on the number provided, defaulting to 1. When num_generated_predictions is > 1, the engine's deterministic property is toggled to False
- str, List[str], and List[List[str]]
Testing
num_generated_predictions
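The repeat-and-toggle behavior described in the summary could be exercised along the following lines. This is only a sketch under stated assumptions: expand_sequences is a hypothetical helper, not a deepsparse function, the ordering of the repeated inputs is a choice made here, and the cases of already-repeated sequences and List[List[str]] inputs are deliberately left out:

```python
from typing import List, Tuple, Union


def expand_sequences(
    sequences: Union[str, List[str]],
    num_generated_predictions: int = 1,
) -> Tuple[List[str], bool]:
    """Hypothetical helper: repeat inputs and decide the deterministic flag.

    Returns the expanded list of sequences and whether generation
    should stay deterministic.
    """
    if num_generated_predictions < 1:
        raise ValueError("num_generated_predictions must be >= 1")
    # Normalize a single string to a one-element list.
    if isinstance(sequences, str):
        sequences = [sequences]
    # Repeat each input so every copy can yield its own prediction.
    expanded = [s for s in sequences for _ in range(num_generated_predictions)]
    # Identical inputs only give distinct outputs when sampling is enabled,
    # so determinism is switched off when more than one prediction is asked for.
    deterministic = num_generated_predictions == 1
    return expanded, deterministic


expanded, deterministic = expand_sequences(["a", "b"], num_generated_predictions=2)
```

A test in the same spirit would check that the number of generated outputs equals the number of inputs times num_generated_predictions, and that a value of 1 leaves the input and the deterministic setting unchanged.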