diff --git a/scripts/generate.py b/scripts/generate.py index 0348a205..1bd77ef8 100644 --- a/scripts/generate.py +++ b/scripts/generate.py @@ -162,6 +162,7 @@ def main(): dtype=model.dtype, block_size=32, cache_memory_fraction=0.1, + max_num_seqs=1, head_dim_v=v_head_dim, layer_types=layer_types, conv_dim=conv_dim, @@ -230,8 +231,9 @@ def main(): ) sampling_info = SamplingBatchInfo.from_reqs([request]) + token_sampling_info = None if sampling_info.is_all_greedy else sampling_info - next_token_id = model.logits_to_tokens(logits, context_lengths, sampling_info) + next_token_id = model.logits_to_tokens(logits, context_lengths, token_sampling_info) token_id = int(next_token_id[0]) is_finished = token_id in eos_token_ids @@ -267,7 +269,7 @@ def main(): state_slot_mapping=state_slot_mapping, ) - next_token_id = model.logits_to_tokens(logits, mx.array([1]), sampling_info) + next_token_id = model.logits_to_tokens(logits, sampling_info=token_sampling_info) token_id = int(next_token_id[0]) is_finished = token_id in eos_token_ids diff --git a/src/parallax/server/cache/linear_cache.py b/src/parallax/server/cache/linear_cache.py index 920c5215..398849fa 100644 --- a/src/parallax/server/cache/linear_cache.py +++ b/src/parallax/server/cache/linear_cache.py @@ -136,22 +136,14 @@ def copy_slot(self, dst_slot_idx: int, src_slot_idx: int): mx.eval(*arrays) def read_states(self, slot_mapping: mx.array) -> Tuple[Optional[mx.array], Optional[mx.array]]: - conv_state_list = [] - linear_state_list = [] - - for slot_idx in slot_mapping: - slot_idx = int(slot_idx) - if self.conv_state_cache is not None: - conv_state_slice = self.conv_state_cache[0, slot_idx] - conv_state_list.append(conv_state_slice[None, :, :]) - - if self.linear_state_cache is not None: - linear_state_slice = self.linear_state_cache[0, slot_idx] - linear_state_list.append(linear_state_slice[None, :, :, :]) - - conv_states = mx.concatenate(conv_state_list, axis=0) if conv_state_list else None - linear_states = mx.concatenate(linear_state_list, axis=0) if linear_state_list else None - + conv_states = ( + self.conv_state_cache[0, slot_mapping] if self.conv_state_cache is not None else None + ) + linear_states = ( + self.linear_state_cache[0, slot_mapping] + if self.linear_state_cache is not None + else None + ) return conv_states, linear_states def write_states( @@ -160,13 +152,11 @@ def write_states( conv_states: Optional[mx.array], linear_states: Optional[mx.array], ): - for i, slot_idx in enumerate(slot_mapping): - slot_idx = int(slot_idx) - if self.conv_state_cache is not None and conv_states is not None: - self.conv_state_cache[0, slot_idx] = conv_states[i] + if self.conv_state_cache is not None and conv_states is not None: + self.conv_state_cache[0, slot_mapping] = conv_states - if self.linear_state_cache is not None and linear_states is not None: - self.linear_state_cache[0, slot_idx] = linear_states[i] + if self.linear_state_cache is not None and linear_states is not None: + self.linear_state_cache[0, slot_mapping] = linear_states def is_packed(self) -> bool: """LinearCache doesn't use packed format."""