27 changes: 22 additions & 5 deletions llm/server/server/scripts/start_server.sh
@@ -6,14 +6,12 @@ export PYTHONIOENCODING=utf8
export LC_ALL=C.UTF-8

# PaddlePaddle environment variables
export FLAGS_allocator_strategy=auto_growth
export FLAGS_dynamic_static_unified_comm=0
export FLAGS_use_xqa_optim=1
export FLAGS_gemm_use_half_precision_compute_type=0
export NVIDIA_TF32_OVERRIDE=0

# Model hyperparameters
export MP_NUM=${MP_NUM:-"1"} # Number of GPUs
export MP_NUM=${MP_NUM:-"1"} # number of model parallelism
export MP_NNODE=${MP_NNODE:-"1"} # number of nodes
export CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-"0"} # GPU ids
export MAX_SEQ_LEN=${MAX_SEQ_LEN:-"8192"}
export MAX_DEC_LEN=${MAX_DEC_LEN:-"2048"}
@@ -43,7 +41,26 @@ mkdir -p log
rm -rf console.log log/*
rm -rf /dev/shm/*

echo "start serving ..."
FED_POD_IP=$(hostname -i)
if [ "$MP_NNODE" -gt 1 ]; then
POD_0_IP=$POD_0_IP
HOST_IP=$FED_POD_IP
else
POD_0_IP="127.0.0.1"
HOST_IP="127.0.0.1"
fi

echo "POD_0_IP: $POD_0_IP HOST_IP: $HOST_IP"

if [ "$POD_0_IP" == "$HOST_IP" ]; then
echo "Master node, start serving ..."
else
echo "Slave node, start push mode"
# waiting for master node to start serving ...
sleep ${SERVER_WAITTING_TIME:-"25"}
fi



tritonserver --exit-timeout-secs 100 --cuda-memory-pool-byte-size 0:0 --cuda-memory-pool-byte-size 1:0 \
--cuda-memory-pool-byte-size 2:0 --cuda-memory-pool-byte-size 3:0 --cuda-memory-pool-byte-size 4:0 \
6 changes: 4 additions & 2 deletions llm/server/server/server/data/processor.py
@@ -19,6 +19,7 @@
from paddlenlp.trl.llm_utils import get_eos_token_id
from server.engine.config import Config
from server.utils import data_processor_logger
from paddlenlp.utils.env import USE_FAST_TOKENIZER


class BaseDataProcessor(ABC):
@@ -121,7 +122,8 @@ class DataProcessor(BaseDataProcessor):
def __init__(self):
self.config = Config()
max_length = self.config.get_model_config().get('max_length', 1024)
self.src_length = max_length - self.config.seq_len_limit
self.src_length = self.config.seq_len_limit - max_length


self.decode_status = dict()
self.tokenizer = self._load_tokenizer()
@@ -288,7 +290,7 @@ def _load_tokenizer(self):
return AutoTokenizer.from_pretrained(self.config.model_dir, use_fast=False)
else:
from paddlenlp.transformers import AutoTokenizer
return AutoTokenizer.from_pretrained(self.config.model_dir)
return AutoTokenizer.from_pretrained(self.config.model_dir, use_fast=USE_FAST_TOKENIZER)

def clear_request_status(self, task_id):
"""
21 changes: 21 additions & 0 deletions llm/server/server/server/engine/config.py
@@ -59,6 +59,12 @@ def read_from_env(self):
else:
raise Exception(f"unsupported device type: {self.device}")

# multi-node config
self.nnode = int(env.get("MP_NNODE", "1"))
assert self.mp_num % self.nnode == 0, f"mp_num: {self.mp_num} should be divisible by nnode: {self.nnode}"
self.mp_num_per_node = self.mp_num // self.nnode
self.host_ip = os.getenv("HOST_IP", "127.0.0.1")

# Triton config
self.max_prefill_batch = int(os.getenv("MAX_PREFILL_BATCH", 1))
if self.max_prefill_batch <= 0:
@@ -93,6 +99,7 @@ def read_from_env(self):
self.use_cache_kv_int8 = int(os.getenv("USE_CACHE_KV_INT8", 0))
self.use_cache_kv_int4 = int(os.getenv("USE_CACHE_KV_INT4", 0))


# infer config
self.max_batch_size = int(env.get("BATCH_SIZE", 50))
self.max_seq_len = int(env.get("MAX_SEQ_LEN", 8192))
@@ -168,6 +175,20 @@ def check(self):
f"which means the exported MAX_DEC_LEN should less than "
f"{self.max_seq_len}, but now it's {self.dec_len_limit}."
)
if os.getenv("DISABLE_CAPACITY_CHECKER", "0") == 1:
# max_output_token_num
max_output_token_num = (self.total_block_num - self.max_block_num) * self.block_size + self.enc_dec_block_num * self.block_size
assert max_output_token_num >= self.dec_len_limit, (
f"The available output token number of the service is {max_output_token_num}, "
f"which is less than the setting MAX_DEC_LEN:{self.dec_len_limit}. "
)

# Maximum input length of a single query that the service can handle
max_input_token_num = int(math.floor(self.max_block_num * self.block_size - self.dec_token_num))
assert max_input_token_num >= self.seq_len_limit, (
f"The available input token number of the service is {max_input_token_num}, "
f"which is less than the setting MAX_SEQ_LEN:{self.seq_len_limit}. "
)

def print(self, file=None):
"""
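For context on the new capacity check in check(): it compares the configured limits against the KV-cache block budget. A minimal sketch with invented numbers (the real values come from the service config; treating dec_token_num as enc_dec_block_num * block_size is an assumption of this sketch, not taken from the repository):

block_size = 64          # tokens held by one KV-cache block (invented)
total_block_num = 2000   # blocks owned by the whole service (invented)
max_block_num = 1500     # blocks reserved for prefill / inputs (invented)
enc_dec_block_num = 2    # bridge blocks kept between prefill and decode
dec_token_num = enc_dec_block_num * block_size   # assumption for this sketch

# Output budget: blocks not reserved for prefill, plus the bridge blocks.
max_output_token_num = (total_block_num - max_block_num) * block_size + enc_dec_block_num * block_size   # 32128
# Input budget: prefill blocks minus the tokens set aside for decoding.
max_input_token_num = max_block_num * block_size - dec_token_num                                         # 95872

assert max_output_token_num >= 2048   # e.g. MAX_DEC_LEN=2048 fits
assert max_input_token_num >= 8192    # e.g. MAX_SEQ_LEN=8192 fits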
9 changes: 6 additions & 3 deletions llm/server/server/server/engine/engine.py
@@ -50,7 +50,9 @@ def start(self):
"""
assert not self.is_started, "The engine is already started.!"
start_time = time.time()
self.queue_service = self._start_tasks_queue_service()
# Master node only
if self.cfg.nnode == 1 or self.cfg.host_ip == os.getenv('POD_0_IP', '127.0.0.1'):
self.queue_service = self._start_tasks_queue_service()
self.tasks_queue = TaskQueueManager(mp_num=self.cfg.mp_num, port=self.cfg.infer_port)

self.token_processor.tasks_queue = self.tasks_queue
@@ -258,7 +260,7 @@ def _infer_processes_ready(self):
Returns:
return: True if all ready, False otherwise
"""
if np.sum(self.flag_ready_array) == self.cfg.mp_num:
if np.sum(self.flag_ready_array) == self.cfg.mp_num_per_node:
return True
return False

@@ -378,7 +380,8 @@ def _start_gpu_infer_service(self):
pd_cmd = "python3 -m paddle.distributed.launch "
py_script = os.path.join(current_dir_path, "infer.py")

arguments = (f" --devices {self.cfg.device_ids} {py_script} --model_dir {self.cfg.model_dir}"
arguments = (f" --nnodes {str(self.cfg.nnode)}"
f" --devices {self.cfg.device_ids} {py_script} --model_dir {self.cfg.model_dir}"
f" --max_batch_size {self.cfg.max_batch_size} --max_seq_len {self.cfg.max_seq_len}"
f" --max_dec_len {self.cfg.max_dec_len}"
f" --max_block_num {self.cfg.total_block_num} --block_size {self.cfg.block_size}"
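These engine changes lean on the per-node split introduced in config.py: readiness checks count only local ranks, and one rank per node reads the task queue. A small illustrative sketch (all numbers are invented, and contiguous rank placement per node is assumed):

import numpy as np

mp_num = 16                         # total model-parallel ranks (invented)
nnode = 2                           # participating machines (invented)
assert mp_num % nnode == 0, "mp_num must be divisible by nnode"
mp_num_per_node = mp_num // nnode   # 8 ranks per machine

# _infer_processes_ready(): each node waits only for its local ranks.
flag_ready_array = np.ones(mp_num_per_node, dtype=np.int32)
all_local_ready = int(np.sum(flag_ready_array)) == mp_num_per_node

# infer.py run(): one rank per node polls the task queue for its machine.
queue_readers = [rank for rank in range(mp_num) if rank % mp_num_per_node == 0]
print(all_local_ready, queue_readers)   # True [0, 8]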
100 changes: 43 additions & 57 deletions llm/server/server/server/engine/infer.py
@@ -55,8 +55,11 @@ def __init__(self, args):
self.args.num_layers = self.get_value(self.model_cfg, ["num_hidden_layers", "num_layers"])
self.args.num_attention_heads = self.get_value(self.model_cfg, ["num_attention_heads", "n_head"])
self.args.hidden_size = self.model_cfg["hidden_size"]
if "deepseek" in self.model_cfg["model_type"]:
self.qk_nope_head_dim = int(self.model_cfg["qk_nope_head_dim"])
self.qk_rope_head_dim = int(self.model_cfg["qk_rope_head_dim"])
self.v_head_dim = int(self.model_cfg["v_head_dim"])

self.reduce_dialogue_repetition = int(os.environ.get("REDUCE_DIALOGUE_REPETITION", 0))

self.max_stop_seqs_num = int(os.getenv("MAX_STOP_SEQS_NUM", 5))
self.stop_seqs_max_len = int(os.getenv("STOP_SEQS_MAX_LEN", 8))
@@ -181,15 +184,26 @@ def init_inputs(self):
cache_type = self.args.dtype
else:
cache_type = "uint8"

self.cache_kvs["key_caches_{}".format(i)] = paddle.full(shape=[
self.args.max_block_num, kv_num_head,
self.args.block_size, self.args.hidden_size // self.args.num_attention_heads
], fill_value=0, dtype=cache_type)
self.cache_kvs["value_caches_{}".format(i)] = paddle.full(shape=[
self.args.max_block_num, kv_num_head,
self.args.block_size, self.args.hidden_size // self.args.num_attention_heads
], fill_value=0, dtype=cache_type)

if "deepseek" in self.model_cfg["model_type"]:
self.cache_kvs["key_caches_{}".format(i)] = paddle.full(shape=[
self.args.max_block_num, kv_num_head,
self.args.block_size,
self.qk_nope_head_dim + self.qk_rope_head_dim
], fill_value=0, dtype=cache_type)
self.cache_kvs["value_caches_{}".format(i)] = paddle.full(shape=[
self.args.max_block_num, kv_num_head,
self.args.block_size, self.v_head_dim
], fill_value=0, dtype=cache_type)
else:
self.cache_kvs["key_caches_{}".format(i)] = paddle.full(shape=[
self.args.max_block_num, kv_num_head,
self.args.block_size, self.args.hidden_size // self.args.num_attention_heads
], fill_value=0, dtype=cache_type)
self.cache_kvs["value_caches_{}".format(i)] = paddle.full(shape=[
self.args.max_block_num, kv_num_head,
self.args.block_size, self.args.hidden_size // self.args.num_attention_heads
], fill_value=0, dtype=cache_type)

pre_max_block_num = (self.args.max_seq_len + self.args.block_size - 1) // self.args.block_size + self.args.enc_dec_block_num
self.share_inputs["block_tables"] = paddle.full(
@@ -273,11 +287,11 @@ def init_inputs(self):
fill_value=-1,
dtype="int64")

if self.reduce_dialogue_repetition:
self.share_inputs["first_token_ids"] = paddle.full(
shape=[self.args.max_batch_size, 1], fill_value=-1, dtype="int64")
self.share_inputs["ori_seq_lens_encoder"] = paddle.full(
shape=[self.args.max_batch_size, 1], fill_value=0, dtype="int32")

self.share_inputs["first_token_ids"] = paddle.full(
shape=[self.args.max_batch_size, 1], fill_value=-1, dtype="int64")
self.share_inputs["ori_seq_lens_encoder"] = paddle.full(
shape=[self.args.max_batch_size, 1], fill_value=0, dtype="int32")
# speculate decoding input
if self.is_speculate_decoding:
self.share_inputs["accept_tokens"] = paddle.full(
@@ -324,9 +338,9 @@ def dy_input_preprocess(self, tasks):
self.share_inputs['max_length'][idx:idx + 1] = max_dec_len
self.share_inputs['stop_flags'][idx:idx + 1] = False

if self.reduce_dialogue_repetition:
self.share_inputs['first_token_ids'][idx:idx + 1] = self.share_inputs['input_ids'][idx:idx + 1, :1]
self.share_inputs["ori_seq_lens_encoder"][idx:idx + 1] = length

self.share_inputs['first_token_ids'][idx:idx + 1] = self.share_inputs['input_ids'][idx:idx + 1, :1]
self.share_inputs["ori_seq_lens_encoder"][idx:idx + 1] = length

if "infer_seed" in task:
self.share_inputs['infer_seed'][idx:idx + 1] = task['infer_seed']
@@ -371,9 +385,8 @@ def step_cuda(self, seq_lens_this_time):
self.share_inputs['need_block_len'], self.share_inputs['used_list_len'],
self.share_inputs['free_list'], self.share_inputs['free_list_len'],
self.share_inputs['input_ids'], self.share_inputs['pre_ids'],
self.share_inputs['step_idx'], self.share_inputs['next_tokens'],
self.args.block_size, self.args.enc_dec_block_num, self.args.first_token_id,
speculate_step_token_num)
self.share_inputs['step_idx'], self.share_inputs['next_tokens'], self.share_inputs['first_token_ids'],
self.args.block_size, self.args.enc_dec_block_num, 0)

def initialize_engine_ready_check_flag(self):
"""
@@ -460,21 +473,25 @@ def run(self):
if use_custom_health_checker:
engine_healthy_recorded_time_array[0] = time.time()

if self.rank == 0:
if self.rank % self.config.mp_num_per_node == 0:
if not self.infer_queue.empty():
flag_broadcast_array[0] = 1
if self.config.nnode > 1:
self.infer_queue.read_finish_flag.set(1)
else:
flag_broadcast_array[0] = 1

if self.nranks > 1:
paddle.distributed.barrier()

if flag_broadcast_array[0] == 1:
if flag_broadcast_array[0] == 1 or self.infer_queue.read_finish_flag.get() == 1:
logger.info(f'rank: {self.rank} start to get')
if seq_lens_this_time is not None:
self.share_inputs["seq_lens_this_time"][:real_bsz] = seq_lens_this_time

tasks, read_finish = self.infer_queue.get()
if read_finish:
flag_broadcast_array[0] = 0
self.infer_queue.read_finish_flag.set(0)

req_dicts = []
for req_dict, bsz in tasks:
@@ -542,7 +559,7 @@ def _init_predictor(self):
"""
predictor init
"""
device_id = self.rank % 8
device_id = self.rank % self.config.mp_num_per_node
if use_pir_api():
self.model_file = os.path.join(self.model_dir, f"model.json")
self.param_file = os.path.join(self.model_dir, f"model.pdiparams")
@@ -553,30 +570,10 @@

config.enable_use_gpu(100, device_id)

pir_flag = int(os.environ.get("FLAGS_enable_pir_api", 0))
if pir_flag == 1:
if use_pir_api():
config.enable_new_executor()
config.enable_new_ir()

# distributed config
if self.mp_degree > 1:
trainer_endpoints = fleet.worker_endpoints()
current_endpoint = trainer_endpoints[self.rank]
dist_config = config.dist_config()
dist_config.set_ranks(self.nranks, self.rank)
dist_config.set_endpoints(trainer_endpoints, current_endpoint)
dist_config.enable_dist_model(True)
if self.config.distributed_config_path:
dist_config.set_comm_init_config(self.config.distributed_config_path)
else:
raise Exception("Please set DISTRIBUTED_CONFIG env variable.")
logger.warning(
f"Use default distributed config, please set env DISTRIBUTED_CONFIG"
)
dist_config.set_comm_init_config(
os.path.join(Dir_Path + "/config", "rank_mapping_mp{}.csv".format(self.nranks)))

config.set_dist_config(dist_config)
self.predictor = paddle.inference.create_predictor(config)
self.input_names = self.predictor.get_input_names()
self.seq_lens_handle = self.predictor.get_input_handle('seq_lens_this_time')
Expand All @@ -595,17 +592,6 @@ def share_data(self):
input_tensor = self.predictor.get_input_handle(name)
input_tensor.share_external_data(self.share_inputs[name])

def predict(self, real_bsz):
"""
predict
"""
seq_lens_this_time = copy.deepcopy(
self.share_inputs['seq_lens_this_time'][:real_bsz])
self.seq_lens_handle.share_external_data(seq_lens_this_time)
self.share_inputs['not_need_stop'][0] = True
while self.share_inputs['not_need_stop']:
self.predictor.run()
self.share_inputs["seq_lens_this_time"][:real_bsz] = seq_lens_this_time


def parse_args():
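The main functional change in init_inputs() is that DeepSeek-style models get differently shaped per-layer KV caches. A minimal sketch of the shape selection, mirroring the branch above and using only the config fields it reads (this helper is illustrative, not the repository's code):

def kv_cache_shapes(model_cfg, max_block_num, block_size, kv_num_head,
                    hidden_size, num_attention_heads):
    """Return (key_shape, value_shape) for one layer's block-wise KV cache."""
    if "deepseek" in model_cfg["model_type"]:
        # DeepSeek-style attention: keys concatenate the no-position and RoPE
        # parts, and values use their own head dimension.
        k_dim = int(model_cfg["qk_nope_head_dim"]) + int(model_cfg["qk_rope_head_dim"])
        v_dim = int(model_cfg["v_head_dim"])
    else:
        k_dim = v_dim = hidden_size // num_attention_heads
    key_shape = [max_block_num, kv_num_head, block_size, k_dim]
    value_shape = [max_block_num, kv_num_head, block_size, v_dim]
    return key_shape, value_shape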
8 changes: 6 additions & 2 deletions llm/server/server/server/engine/task_queue_manager.py
@@ -49,8 +49,9 @@ def __init__(self, rank=0, mp_num=8, port=56666):
QueueManager.register('get_barrier1')
QueueManager.register('get_barrier2')
QueueManager.register('get_queue')
QueueManager.register('get_read_finish_flag')

self.client_manager = QueueManager(address=('127.0.0.1', port),
self.client_manager = QueueManager(address=(os.getenv("POD_0_IP","127.0.0.1"), port),
authkey=b'infer_queue'
)
self.client_manager.connect()
@@ -60,6 +61,7 @@ def __init__(self, rank=0, mp_num=8, port=56666):
self.barrier1 = self.client_manager.get_barrier1()
self.barrier2 = self.client_manager.get_barrier2()
self.queue = self.client_manager.get_queue()
self.read_finish_flag = self.client_manager.get_read_finish_flag()
self.mp_num = mp_num
self.rank = rank
self.position = 1 << rank
@@ -155,7 +157,9 @@ def launch_queue_service(port, num_workers):
QueueManager.register('get_barrier2', callable=lambda: barrier2)
q = Queue()
QueueManager.register("get_queue", callable=lambda: q)
m = QueueManager(address=('127.0.0.1', port), authkey=b'infer_queue')
read_finish_flag = Value("i", 0)
QueueManager.register("get_read_finish_flag", callable=lambda: read_finish_flag, proxytype=ValueProxy)
m = QueueManager(address=(os.getenv("POD_0_IP","127.0.0.1"), port), authkey=b'infer_queue')
s = m.get_server()
logger.info("launch queue service success")
s.serve_forever()
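The new get_read_finish_flag endpoint shares an integer flag through the same multiprocessing manager that serves the task queue, so ranks on other nodes can learn that new tasks have been enqueued and reset the flag once they have read them. A minimal, self-contained sketch of that pattern (class name, address, port, and authkey are placeholders; this is not the repository's code):

from multiprocessing.managers import BaseManager, Value, ValueProxy


class FlagManager(BaseManager):
    pass


def serve(address=("0.0.0.0", 56666), authkey=b"infer_queue"):
    """Run on the master node: expose a settable integer to all nodes."""
    flag = Value("i", 0)   # 0: nothing pending, 1: new tasks were enqueued
    FlagManager.register("get_read_finish_flag",
                         callable=lambda: flag, proxytype=ValueProxy)
    FlagManager(address=address, authkey=authkey).get_server().serve_forever()


def connect(address=("127.0.0.1", 56666), authkey=b"infer_queue"):
    """Run on any node: returns a proxy exposing .get() and .set(value)."""
    FlagManager.register("get_read_finish_flag")
    manager = FlagManager(address=address, authkey=authkey)
    manager.connect()
    return manager.get_read_finish_flag()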
4 changes: 3 additions & 1 deletion llm/server/server/server/triton_server.py
@@ -201,7 +201,9 @@ def initialize(self, args):
self.engine.start()
model_server_logger.info("Create engine success")

self._initialize_push_mode()
# Master node only
if self.cfg.nnode == 1 or os.getenv('POD_0_IP',"127.0.0.1") == self.cfg.host_ip:
self._initialize_push_mode()
model_server_logger.info("Init triton server success")


2 changes: 1 addition & 1 deletion llm/server/server/server/triton_server_helper.py
@@ -72,7 +72,7 @@ def check_infer_engine_process():
return:
status: bool, True if process is alive else False
"""
mp_num = int(env_config.mp_num)
mp_num = int(env_config.mp_num_per_node)
for i in range(mp_num):
try:
infer_live_flag_shm = shared_memory.SharedMemory(name=env_config.get_unique_name("shm_flag_infer_{}_live".format(i)))