Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions scripts/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ def main():
dtype=model.dtype,
block_size=32,
cache_memory_fraction=0.1,
max_num_seqs=1,
head_dim_v=v_head_dim,
layer_types=layer_types,
conv_dim=conv_dim,
Expand Down Expand Up @@ -230,8 +231,9 @@ def main():
)

sampling_info = SamplingBatchInfo.from_reqs([request])
token_sampling_info = None if sampling_info.is_all_greedy else sampling_info

next_token_id = model.logits_to_tokens(logits, context_lengths, sampling_info)
next_token_id = model.logits_to_tokens(logits, context_lengths, token_sampling_info)

token_id = int(next_token_id[0])
is_finished = token_id in eos_token_ids
Expand Down Expand Up @@ -267,7 +269,7 @@ def main():
state_slot_mapping=state_slot_mapping,
)

next_token_id = model.logits_to_tokens(logits, mx.array([1]), sampling_info)
next_token_id = model.logits_to_tokens(logits, sampling_info=token_sampling_info)

token_id = int(next_token_id[0])
is_finished = token_id in eos_token_ids
Expand Down
34 changes: 12 additions & 22 deletions src/parallax/server/cache/linear_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,22 +136,14 @@ def copy_slot(self, dst_slot_idx: int, src_slot_idx: int):
mx.eval(*arrays)

def read_states(self, slot_mapping: mx.array) -> Tuple[Optional[mx.array], Optional[mx.array]]:
conv_state_list = []
linear_state_list = []

for slot_idx in slot_mapping:
slot_idx = int(slot_idx)
if self.conv_state_cache is not None:
conv_state_slice = self.conv_state_cache[0, slot_idx]
conv_state_list.append(conv_state_slice[None, :, :])

if self.linear_state_cache is not None:
linear_state_slice = self.linear_state_cache[0, slot_idx]
linear_state_list.append(linear_state_slice[None, :, :, :])

conv_states = mx.concatenate(conv_state_list, axis=0) if conv_state_list else None
linear_states = mx.concatenate(linear_state_list, axis=0) if linear_state_list else None

conv_states = (
self.conv_state_cache[0, slot_mapping] if self.conv_state_cache is not None else None
)
linear_states = (
self.linear_state_cache[0, slot_mapping]
if self.linear_state_cache is not None
else None
)
return conv_states, linear_states

def write_states(
Expand All @@ -160,13 +152,11 @@ def write_states(
conv_states: Optional[mx.array],
linear_states: Optional[mx.array],
):
for i, slot_idx in enumerate(slot_mapping):
slot_idx = int(slot_idx)
if self.conv_state_cache is not None and conv_states is not None:
self.conv_state_cache[0, slot_idx] = conv_states[i]
if self.conv_state_cache is not None and conv_states is not None:
self.conv_state_cache[0, slot_mapping] = conv_states

if self.linear_state_cache is not None and linear_states is not None:
self.linear_state_cache[0, slot_idx] = linear_states[i]
if self.linear_state_cache is not None and linear_states is not None:
self.linear_state_cache[0, slot_mapping] = linear_states

def is_packed(self) -> bool:
"""LinearCache doesn't use packed format."""
Expand Down
Loading