Skip to content
This repository was archived by the owner on Oct 25, 2024. It is now read-only.

Commit 81a651d

Browse files
Fix checkmarx security issues (#227)
Signed-off-by: Lv, Liang1 <liang1.lv@intel.com>
1 parent 285ab60 commit 81a651d

File tree

6 files changed

+86
-11
lines changed

6 files changed

+86
-11
lines changed

intel_extension_for_transformers/neural_chat/pipeline/tools/cut_video.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,11 @@ def cut_video(args, outdir):
111111
parser.add_argument("--sr", type=str, default=16000)
112112
parser.add_argument("--out_path", type=str, default="../raw")
113113
args = parser.parse_args()
114-
114+
115+
# Validate that the input path exists before building the output directory
116+
if not os.path.exists(args.path):
117+
raise FileNotFoundError(f"Input path '{args.path}' does not exist.")
118+
115119
outdir = os.path.join(shlex.quote(args.path), shlex.quote(args.out_path))
116120
if not os.path.exists(outdir):
117121
os.mkdir(outdir)

intel_extension_for_transformers/neural_chat/ui/basic_frontend/fastchat/serve/cli.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,24 @@
3535

3636
from fastchat.serve.inference import chat_loop, ChatIO
3737

38+
def is_safe_input(input_text):
    """Return True when ``input_text`` contains only allow-listed characters.

    The allow-list is ASCII letters, digits, whitespace and the punctuation
    characters ``, . ! ?``. Empty strings are rejected.

    Args:
        input_text: The user-supplied text to check.

    Returns:
        bool: True if every character is allow-listed and the text is
        non-empty; False otherwise, including when ``input_text`` is not a
        string.
    """
    # A validator should reject non-string values rather than let re raise
    # TypeError on them.
    if not isinstance(input_text, str):
        return False
    # fullmatch anchors both ends, so explicit ^...$ is unnecessary.
    safe_pattern = r'[a-zA-Z0-9\s,.!?]+'
    return re.fullmatch(safe_pattern, input_text) is not None
3842

3943
class SimpleChatIO(ChatIO):
4044
def prompt_for_input(self, role) -> str:
41-
return input(f"{role}: ").strip()
45+
query = input(f"{role}: ").strip()
46+
# Validate user input
47+
if not query:
48+
print('Input cannot be empty. Please try again.')
49+
return None
50+
51+
# Perform input validation
52+
if not is_safe_input(query):
53+
print('Invalid characters in input. Please use only letters, numbers, and common punctuation.')
54+
return None
55+
return query
4256

4357
def prompt_for_output(self, role: str):
4458
print(f"{role}: ", end="", flush=True)

workflows/chatbot/demo/basic_frontend/fastchat/serve/cli.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,24 @@
1818

1919
from fastchat.serve.inference import chat_loop, ChatIO
2020

21+
def is_safe_input(input_text):
    """Check that ``input_text`` is made up solely of allow-listed characters.

    Allowed characters are ASCII letters, digits, whitespace and ``, . ! ?``.
    An empty string is not considered safe.

    Args:
        input_text: The user-supplied text to check.

    Returns:
        bool: True for a non-empty string of allowed characters; False for
        anything else, including values that are not strings.
    """
    # Non-string values (e.g. None) are rejected instead of raising TypeError
    # from the regex engine.
    if not isinstance(input_text, str):
        return False
    # fullmatch requires the whole string to match, replacing the ^...$ anchors.
    allowed = r'[a-zA-Z0-9\s,.!?]+'
    return re.fullmatch(allowed, input_text) is not None
2125

2226
class SimpleChatIO(ChatIO):
2327
def prompt_for_input(self, role) -> str:
24-
return input(f"{role}: ").strip()
28+
query = input(f"{role}: ").strip()
29+
# Validate user input
30+
if not query:
31+
print('Input cannot be empty. Please try again.')
32+
return None
33+
34+
# Perform input validation
35+
if not is_safe_input(query):
36+
print('Invalid characters in input. Please use only letters, numbers, and common punctuation.')
37+
return None
38+
return query
2539

2640
def prompt_for_output(self, role: str):
2741
print(f"{role}: ", end="", flush=True)

workflows/chatbot/fine_tuning/instruction_tuning_pipeline/finetune_clm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@ def main():
347347
parser = HfArgumentParser(
348348
(ModelArguments, DataArguments, GaudiTrainingArguments, FinetuneArguments)
349349
)
350-
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
350+
if len(sys.argv) == 2 and sys.argv[1].endswith(".json") and os.path.exists(sys.argv[1]):
351351
# If we pass only one argument to the script and it's the path to a json file,
352352
# let's parse it to get our arguments.
353353
model_args, data_args, training_args, finetune_args = parser.parse_json_file(

workflows/chatbot/inference/backend/chat/model_worker.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -244,22 +244,50 @@ async def talkingbot(request: Request):
244244
async def api_get_status(request: Request):
245245
return worker.get_status()
246246

247+
def validate_port(value):
    """Parse and range-check a port number; usable as an argparse ``type``.

    Args:
        value: The raw value to convert (argparse supplies a string).

    Returns:
        int: The port number, guaranteed to be within 1-65535 inclusive.

    Raises:
        argparse.ArgumentTypeError: If ``value`` cannot be converted to an
            integer or lies outside the 1-65535 range.
    """
    # Keep only the conversion inside the try so that the range error below
    # can never be mistaken for a parse failure. TypeError covers values such
    # as None that int() refuses outright.
    try:
        port = int(value)
    except (TypeError, ValueError):
        raise argparse.ArgumentTypeError("Invalid port number. Must be an integer.") from None
    if not 1 <= port <= 65535:
        raise argparse.ArgumentTypeError("Port number must be between 1 and 65535.")
    return port
256+
257+
def validate_device(value, valid_devices=("cpu", "cuda", "mps")):
    """Validate a device name; usable as an argparse ``type``.

    Matching ignores letter case and surrounding whitespace, and the
    canonical spelling from ``valid_devices`` is returned.

    Args:
        value: The raw device name (argparse supplies a string).
        valid_devices: Accepted device names. Defaults to
            ``("cpu", "cuda", "mps")``.

    Returns:
        str: The matching entry from ``valid_devices``.

    Raises:
        argparse.ArgumentTypeError: If ``value`` does not name one of
            ``valid_devices``.
    """
    # Non-string values are left untouched so they simply fail to match below.
    candidate = value.strip().lower() if isinstance(value, str) else value
    for device in valid_devices:
        if candidate == device.lower():
            return device
    raise argparse.ArgumentTypeError(f"Invalid device. Must be one of {', '.join(valid_devices)}.")
263+
264+
def validate_limit_model_concurrency(value):
    """Parse a non-negative integer concurrency limit; usable as an argparse ``type``.

    Args:
        value: The raw value to convert (argparse supplies a string).

    Returns:
        int: The concurrency limit, guaranteed to be >= 0.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not an integer or is
            negative.
    """
    # argparse passes the raw command-line string, so it must be converted
    # before the numeric comparison; comparing a str with 0 raises TypeError.
    try:
        limit = int(value)
    except (TypeError, ValueError):
        raise argparse.ArgumentTypeError(
            "Limit model concurrency must be a non-negative integer."
        ) from None
    if limit < 0:
        raise argparse.ArgumentTypeError("Limit model concurrency must be a non-negative integer.")
    return limit
269+
270+
def validate_stream_interval(value):
    """Parse a strictly positive integer stream interval; usable as an argparse ``type``.

    Args:
        value: The raw value to convert (argparse supplies a string).

    Returns:
        int: The stream interval, guaranteed to be > 0.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not an integer or is
            zero or negative.
    """
    # argparse passes the raw command-line string, so it must be converted
    # before the numeric comparison; comparing a str with 0 raises TypeError.
    try:
        interval = int(value)
    except (TypeError, ValueError):
        raise argparse.ArgumentTypeError("Stream interval must be a positive integer.") from None
    if interval <= 0:
        raise argparse.ArgumentTypeError("Stream interval must be a positive integer.")
    return interval
247275

248276
if __name__ == "__main__":
249277
parser = argparse.ArgumentParser()
250278
parser.add_argument("--host", type=str, default="0.0.0.0")
251-
parser.add_argument("--port", type=int, default=8080)
279+
parser.add_argument("--port", type=validate_port, default=8080)
252280
parser.add_argument("--worker-address", type=str,
253281
default="http://localhost:8080")
254282
parser.add_argument("--controller-address", type=str,
255283
default="http://localhost:80")
256284
parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
257285
parser.add_argument("--model-name", type=str)
258-
parser.add_argument("--device", type=str, choices=["cpu", "cuda", "mps"], default="cuda")
286+
parser.add_argument("--device", type=validate_device, choices=["cpu", "cuda", "mps"], default="cpu")
259287
parser.add_argument("--num-gpus", type=int, default=1)
260288
parser.add_argument("--load-8bit", action="store_true")
261-
parser.add_argument("--limit-model-concurrency", type=int, default=5)
262-
parser.add_argument("--stream-interval", type=int, default=2)
289+
parser.add_argument("--limit-model-concurrency", type=validate_limit_model_concurrency, default=5)
290+
parser.add_argument("--stream-interval", type=validate_stream_interval, default=2)
263291
parser.add_argument("--no-register", action="store_true")
264292
parser.add_argument("--ipex", action="store_true")
265293
parser.add_argument("--itrex", action="store_true")

workflows/chatbot/inference/memory_controller/chat_with_memory.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import os
1+
import os, re
22
from langchain.llms import HuggingFacePipeline
33
from langchain.prompts import PromptTemplate
44
from langchain.memory import ConversationBufferWindowMemory
@@ -45,6 +45,10 @@ def inference(args, query, memory):
4545
print("inference cost {} seconds.".format(end_time - start_time))
4646
return result, memory
4747

48+
def is_safe_input(input_text):
    """Tell whether ``input_text`` consists only of allow-listed characters.

    The permitted set is ASCII letters, digits, whitespace and the
    punctuation ``, . ! ?``. The empty string does not pass.

    Args:
        input_text: The user-supplied text to check.

    Returns:
        bool: True when the value is a non-empty string of permitted
        characters; False otherwise, including for non-string values.
    """
    # Reject non-strings up front; re would otherwise raise TypeError.
    if not isinstance(input_text, str):
        return False
    # fullmatch covers the whole string, so no ^...$ anchors are needed.
    permitted = r'[a-zA-Z0-9\s,.!?]+'
    return re.fullmatch(permitted, input_text) is not None
4852

4953
if __name__ == "__main__":
5054

@@ -63,7 +67,7 @@ def inference(args, query, memory):
6367
"max_length": args.max_length,
6468
"device_map": "auto",
6569
"repetition_penalty": args.penalty,
66-
}
70+
})
6771
if args.memory_type == "buffer_window":
6872
memory = ConversationBufferWindowMemory(memory_key="chat_history", k=3)
6973
elif args.memory_type == "buffer":
@@ -74,8 +78,19 @@ def inference(args, query, memory):
7478

7579
while True:
7680
query = input("Enter input (or 'exit' to quit): ").strip()
77-
if query == 'exit':
81+
if query.lower() == 'exit':
7882
print('exit')
7983
break
84+
85+
# Validate user input
86+
if not query:
87+
print('Input cannot be empty. Please try again.')
88+
continue
89+
90+
# Perform input validation
91+
if not is_safe_input(query):
92+
print('Invalid characters in input. Please use only letters, numbers, and common punctuation.')
93+
continue
94+
8095
result, memory = inference(args, query, memory)
8196
print("Input:" + query + '\nResponse:' + result + '\n')

0 commit comments

Comments
 (0)