I’m working on a service that streams LLM responses, and I want to make it compatible with batch processing.
Previously I was using the TextIteratorStreamer object to handle the streaming, but it is incompatible with batching (it raises `ValueError("TextStreamer only supports batch size 1")`).
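Roughly what my current setup looks like (the model name is just a placeholder; any causal LM behaves the same way):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer(["First prompt", "Second prompt"], return_tensors="pt", padding=True)
streamer = TextIteratorStreamer(tokenizer)

# With a batch of 2 this fails inside generate():
# ValueError: TextStreamer only supports batch size 1
model.generate(**inputs, streamer=streamer, max_new_tokens=20)
```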
Are there any plans to make this feature compatible with batching, or is there another tool I could use instead?
Would greatly appreciate any advice - cheers!
You can try this code; I just modified TextIteratorStreamer a little bit:
```python
from typing import List, Optional

import torch
from transformers import AutoTokenizer, TextIteratorStreamer


class BatchTextIteratorStreamer(TextIteratorStreamer):
    """A TextIteratorStreamer that keeps a separate decode state per batch element."""

    def __init__(
        self,
        batch_size: int,
        tokenizer: "AutoTokenizer",
        skip_prompt: bool = False,
        timeout: Optional[float] = None,
        **decode_kwargs,
    ):
        super().__init__(tokenizer, skip_prompt, timeout, **decode_kwargs)
        self.batch_size = batch_size
        # One token cache and print offset per sequence in the batch.
        self.token_cache = [[] for _ in range(batch_size)]
        self.print_len = [0 for _ in range(batch_size)]
        self.generate_exception = None

    def put(self, value):
        # generate() passes a 1D tensor per step; reshape it to (batch_size, -1).
        if len(value.shape) != 2:
            value = torch.reshape(value, (self.batch_size, value.shape[0] // self.batch_size))

        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return

        printable_texts = []
        for idx in range(self.batch_size):
            self.token_cache[idx].extend(value[idx].tolist())
            text = self.tokenizer.decode(self.token_cache[idx], **self.decode_kwargs)

            # After a newline, flush this sequence's cache.
            if text.endswith("\n"):
                printable_text = text[self.print_len[idx]:]
                self.token_cache[idx] = []
                self.print_len[idx] = 0
            # If the last token is a CJK character, emit the characters right away.
            elif len(text) > 0 and self._is_chinese_char(ord(text[-1])):
                printable_text = text[self.print_len[idx]:]
                self.print_len[idx] += len(printable_text)
            # Otherwise, only emit up to the last space to avoid streaming
            # partially decoded words.
            else:
                printable_text = text[self.print_len[idx]: text.rfind(" ") + 1]
                self.print_len[idx] += len(printable_text)
            printable_texts.append(printable_text)

        self.on_finalized_text(printable_texts)

    def end(self):
        # Flush whatever remains in each sequence's cache.
        printable_texts = []
        for idx in range(self.batch_size):
            if len(self.token_cache[idx]) > 0:
                text = self.tokenizer.decode(self.token_cache[idx], **self.decode_kwargs)
                printable_text = text[self.print_len[idx]:]
                self.token_cache[idx] = []
                self.print_len[idx] = 0
            else:
                printable_text = ""
            printable_texts.append(printable_text)

        self.next_tokens_are_prompt = True
        self.on_finalized_text(printable_texts, stream_end=True)

    def on_finalized_text(self, texts: List[str], stream_end: bool = False):
        # Each queue item is a list with one new text chunk per batch element.
        self.text_queue.put(texts, timeout=self.timeout)
        if stream_end:
            self.text_queue.put(self.stop_signal, timeout=self.timeout)
```
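For reference, here's a minimal usage sketch (the model name and prompts are placeholders): run generate() in a background thread and consume the per-batch chunks from the streamer in the main one.

```python
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"  # pad on the left for decoder-only generation
model = AutoModelForCausalLM.from_pretrained("gpt2")

prompts = ["First prompt", "Second prompt"]
inputs = tokenizer(prompts, return_tensors="pt", padding=True)

streamer = BatchTextIteratorStreamer(
    batch_size=len(prompts), tokenizer=tokenizer, skip_prompt=True
)

# generate() blocks, so run it in a thread and stream from the main one.
thread = Thread(
    target=model.generate,
    kwargs=dict(**inputs, streamer=streamer, max_new_tokens=50),
)
thread.start()

outputs = ["" for _ in prompts]
for chunks in streamer:  # each item is a list: one new text chunk per sequence
    for i, chunk in enumerate(chunks):
        outputs[i] += chunk
thread.join()
print(outputs)
```

Note that each item pulled from the streamer is now a list of strings rather than a single string, so downstream code has to fan the chunks out to the right stream.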
Thanks a lot. It worked for me!!
Pardon my curiosity, but in the spirit of knowledge sharing, could you elaborate a bit on why you want a streamer to support batches?