-
Notifications
You must be signed in to change notification settings - Fork 13
Expand file tree
/
Copy pathtransformers_server.py
More file actions
146 lines (118 loc) · 4.74 KB
/
transformers_server.py
File metadata and controls
146 lines (118 loc) · 4.74 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
from flask import Flask, request, jsonify # type: ignore
import logging
import argparse
from transformers import AutoModelForCausalLM
import torch
from PIL import Image
import base64
import io
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = Flask(__name__)
# Global variables, populated once by load_model() before the server starts.
# model: the causal LM loaded via AutoModelForCausalLM.from_pretrained
# text_tokenizer / visual_tokenizer: fetched from the model's own accessors
model = None
text_tokenizer = None
visual_tokenizer = None
def parse_args():
    """Parse the server's command-line options.

    Returns:
        argparse.Namespace with ``port``, ``model``, ``device``,
        ``max_memory`` and ``torch_dtype`` attributes.
    """
    parser = argparse.ArgumentParser(description="Transformers Server")
    parser.add_argument("--port", type=int, default=8000, help="HTTP port to listen on")
    parser.add_argument("--model", type=str, required=True, help="Model name or path")
    parser.add_argument("--device", type=str, default="cuda", help="device_map for from_pretrained")
    parser.add_argument("--max-memory", type=float, default=0.9)
    parser.add_argument(
        "--torch-dtype",
        type=str,
        default="bfloat16",
        # Restrict to the dtypes load_model() knows how to map, so an invalid
        # value fails at argument parsing with a clear message instead of
        # raising a bare KeyError during model loading.
        choices=["float32", "float16", "bfloat16"],
    )
    return parser.parse_args()
def load_model(args):
    """Load the multimodal model and its tokenizers into module globals.

    Args:
        args: parsed CLI namespace providing ``model``, ``torch_dtype``
            and ``device`` attributes.
    """
    global model, text_tokenizer, visual_tokenizer

    logger.info(f"Loading model {args.model}...")
    # Map the CLI dtype string onto the corresponding torch dtype object.
    torch_dtype = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }[args.torch_dtype]
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        torch_dtype=torch_dtype,
        multimodal_max_length=32768,
        trust_remote_code=True,
        device_map=args.device,
    )
    # The model bundles its own tokenizers; grab both for request handling.
    text_tokenizer = model.get_text_tokenizer()
    visual_tokenizer = model.get_visual_tokenizer()
    logger.info("Model loaded successfully!")
@app.route("/health", methods=["GET"])
def health():
    """Liveness probe: always reports the server as healthy."""
    body = jsonify({"status": "healthy"})
    return body, 200
def _decode_data_url(data_url):
    """Decode a base64 data URL ("<prefix>,<b64>") into a PIL image."""
    image_bytes = base64.b64decode(data_url.split(",")[1])
    return Image.open(io.BytesIO(image_bytes))


@app.route("/v1/chat/completions", methods=["POST"])
def chat_completions():
    """Handle a single-turn multimodal chat completion.

    Expects an OpenAI-style JSON body whose first message's ``content`` is a
    list of items:
      * ``{"type": "text", "text": ...}`` -- the prompt
      * ``{"type": "image_url", "image_url": {"url": <data URL>}}`` -- an image
      * ``{"type": "video_frames", "frames": [<data URL>, ...]}`` -- frames

    Returns:
        JSON with ``choices[0].message.content`` on success; a JSON ``error``
        payload with status 400 (bad request) or 500 (processing failure).
    """
    logger.info("Received request to /v1/chat/completions")
    try:
        # request.json is None for a missing/non-JSON body, which would make
        # data.get raise AttributeError and surface as an opaque 500.
        # get_json(silent=True) lets us answer with a clean 400 instead.
        data = request.get_json(silent=True) or {}
        messages = data.get("messages", [])
        if not messages:
            return jsonify({"error": "No messages provided"}), 400
        # NOTE: only the first message is used; conversation history is ignored.
        content = messages[0].get("content", [])
        prompt = None
        images = []
        media_type = "image"  # Default to image
        for item in content:
            if item["type"] == "text":
                prompt = item["text"]
            elif item["type"] == "image_url":
                images.append(_decode_data_url(item["image_url"]["url"]))
            elif item["type"] == "video_frames":
                media_type = "video"
                # Frames are pre-extracted client-side and sent as base64.
                for frame_data in item["frames"]:
                    images.append(_decode_data_url(frame_data))
                logger.info(f"Received {len(images)} pre-extracted video frames")
        if not prompt or not images:
            return jsonify({"error": "Missing prompt or media"}), 400
        # One <image> placeholder per image, prepended to the text prompt.
        if media_type == "video" or len(images) > 1:
            query = "\n".join(["<image>"] * len(images)) + "\n" + prompt
        else:
            query = f"<image>\n{prompt}"
        # max_partition=9 allows finer tiling for single images; video frames
        # are kept whole (1) to bound the multimodal token count.
        prompt, input_ids, pixel_values = model.preprocess_inputs(
            query, images, max_partition=9 if media_type == "image" else 1
        )
        attention_mask = torch.ne(input_ids, text_tokenizer.pad_token_id)
        input_ids = input_ids.unsqueeze(0).to(device=model.device)
        attention_mask = attention_mask.unsqueeze(0).to(device=model.device)
        if pixel_values is not None:
            pixel_values = pixel_values.to(
                dtype=visual_tokenizer.dtype, device=visual_tokenizer.device
            )
            pixel_values = [pixel_values]
        with torch.inference_mode():
            gen_kwargs = dict(
                max_new_tokens=1024,
                do_sample=False,  # greedy decoding for deterministic answers
                top_p=None,
                top_k=None,
                temperature=None,
                repetition_penalty=None,
                eos_token_id=model.generation_config.eos_token_id,
                pad_token_id=text_tokenizer.pad_token_id,
                use_cache=True,
            )
            output_ids = model.generate(
                input_ids,
                pixel_values=pixel_values,
                attention_mask=attention_mask,
                **gen_kwargs,
            )[0]
        output = text_tokenizer.decode(output_ids, skip_special_tokens=True)
        return jsonify({"choices": [{"message": {"content": output}}]})
    except Exception as e:
        # logger.exception records the full traceback; logger.error on an
        # f-string did not, which made failures here hard to debug.
        logger.exception("Error processing request")
        return jsonify({"error": str(e)}), 500
if __name__ == "__main__":
    # Parse CLI flags and load the model once, before accepting traffic.
    args = parse_args()
    load_model(args)
    logger.info(f"Starting server on port {args.port}")
    # Bind to all interfaces so the server is reachable from other hosts.
    app.run(host="0.0.0.0", port=args.port)