-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
110 lines (93 loc) · 3.01 KB
/
utils.py
File metadata and controls
110 lines (93 loc) · 3.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
from typing import Any, Dict, List, Optional

import torch
from qwen_vl_utils import process_vision_info
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLProcessor

from drivefusion import DriveFusionForConditionalGeneration, DriveFusionProcessor
def generate_drivefusion_output(
    model: DriveFusionForConditionalGeneration,
    processor: DriveFusionProcessor,
    message: List,
    speed: Optional[List] = None,
    gps: Optional[List] = None,
    use_queries: Optional[bool] = None,
    max_new_tokens: int = 4000,
    device: str = "cuda"
) -> Dict[str, Any]:
    """Run one generation pass of a DriveFusion model on a chat-style message.

    The message is rendered with the processor's chat template, its vision
    content is extracted with ``process_vision_info``, and the model generates
    text plus a trajectory and target speeds.

    Args:
        model: Loaded DriveFusion model; switched to eval mode here.
        processor: Processor matching ``model``.
        message: Chat-template message list (also the source of image inputs).
        speed: Optional speed signal forwarded to the processor.
        gps: Optional GPS points forwarded to the processor as ``gps_points``.
        use_queries: Optional flag forwarded to the processor.
        max_new_tokens: Upper bound on newly generated tokens.
        device: Device string the prepared inputs are moved to.

    Returns:
        Dict with keys ``"text"`` (decoded generation, special tokens
        stripped), ``"trajectory"`` and ``"target_speeds"`` (returned as-is
        from ``model.generate``).

    Note:
        Sampling is enabled (``do_sample=True``), so repeated calls with the
        same inputs may produce different text.
    """
    model.eval()

    # Render the conversation into the model's prompt format.
    text_input = processor.apply_chat_template(
        message,
        tokenize=False,
        add_generation_prompt=True,
    )

    # Extract image inputs referenced by the message (video part unused).
    image_inputs, _ = process_vision_info(message)

    # Tokenize text + images together with the DriveFusion-specific signals.
    model_inputs = processor(
        text=[text_input],
        images=image_inputs,
        return_tensors="pt",
        speed=speed,
        gps_points=gps,
        use_queries=use_queries,
    ).to(device)

    # inference_mode avoids autograd bookkeeping during generation.
    with torch.inference_mode():
        generated_ids, trajectory, target_speeds = model.generate(
            **model_inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
        )

    # Drop the prompt tokens so only newly generated ids are decoded.
    trimmed_generated_ids = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(model_inputs.input_ids, generated_ids)
    ]

    output_text = processor.batch_decode(
        trimmed_generated_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )

    return {
        "text": output_text[0],
        "trajectory": trajectory,
        "target_speeds": target_speeds,
    }
def load_drivefusion_model(
    model_id: str,
    dtype: torch.dtype = torch.float16,
    redownload: bool = False,
    device_map: str = "cuda",
):
    """Load a DriveFusion model and its processor from a checkpoint.

    Args:
        model_id: Hub id or local path of the checkpoint.
        dtype: Torch dtype the model weights are loaded in.
        redownload: If True, force re-download of the checkpoint files.
        device_map: Device placement passed to ``from_pretrained``
            (defaults to ``"cuda"``, matching previous hard-coded behavior).

    Returns:
        Tuple ``(model, processor)``.
    """
    # trust_remote_code is required because DriveFusion ships custom code
    # with the checkpoint.
    model = DriveFusionForConditionalGeneration.from_pretrained(
        model_id,
        device_map=device_map,
        torch_dtype=dtype,
        force_download=redownload,
        trust_remote_code=True,
    )
    processor = DriveFusionProcessor.from_pretrained(
        model_id,
        force_download=redownload,
        use_fast=True,
        trust_remote_code=True,
    )
    return model, processor
def load_qwen2_5_vl_model(
    model_id: str,
    dtype: torch.dtype = torch.bfloat16,
    redownload: bool = False,
    device_map: str = "cuda:0",
):
    """Load a stock Qwen2.5-VL model and its processor from a checkpoint.

    Args:
        model_id: Hub id or local path of the checkpoint.
        dtype: Torch dtype the model weights are loaded in.
        redownload: If True, force re-download of the checkpoint files.
        device_map: Device placement passed to ``from_pretrained``
            (defaults to ``"cuda:0"``, matching previous hard-coded behavior).

    Returns:
        Tuple ``(model, processor)``.
    """
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_id,
        device_map=device_map,
        torch_dtype=dtype,
        force_download=redownload,
    )
    processor = Qwen2_5_VLProcessor.from_pretrained(
        model_id,
        force_download=redownload,
        use_fast=True,
    )
    return model, processor