diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py index b75a44e04..c34ecd934 100644 --- a/pilot/server/webserver.py +++ b/pilot/server/webserver.py @@ -27,8 +27,6 @@ from pilot.conversation import ( from pilot.utils import ( build_logger, server_error_msg, - violates_moderation, - moderation_msg ) from pilot.server.gradio_css import code_highlight_css @@ -129,14 +127,9 @@ def add_text(state, text, request: gr.Request): if len(text) <= 0: state.skip_next = True return (state, state.to_gradio_chatbot(), "") + (no_change_btn,) * 5 - if args.moderate: - flagged = violates_moderation(text) - if flagged: - state.skip_next = True - return (state, state.to_gradio_chatbot(), moderation_msg) + ( - no_change_btn,) * 5 - text = text[:4000] # Hard cut-off + """ Default support 4000 tokens, if tokens too lang, we will cut off """ + text = text[:4000] state.append_message(state.roles[0], text) state.append_message(state.roles[1], None) state.skip_next = False @@ -439,9 +432,7 @@ if __name__ == "__main__": "--model-list-mode", type=str, default="once", choices=["once", "reload"] ) parser.add_argument("--share", default=False, action="store_true") - parser.add_argument( - "--moderate", action="store_true", help="Enable content moderation" - ) + args = parser.parse_args() logger.info(f"args: {args}") diff --git a/pilot/utils.py b/pilot/utils.py index b2505eba1..0179d12c2 100644 --- a/pilot/utils.py +++ b/pilot/utils.py @@ -14,7 +14,6 @@ import requests from pilot.configs.model_config import LOGDIR server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" -moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." handler = None @@ -125,27 +124,6 @@ def disable_torch_init(): setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) -def violates_moderation(text): - """ - Check whether the text violates OpenAI moderation API. - """ - url = "https://api.openai.com/v1/moderations" - headers = {"Content-Type": "application/json", - "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]} - text = text.replace("\n", "") - data = "{" + '"input": ' + f'"{text}"' + "}" - data = data.encode("utf-8") - try: - ret = requests.post(url, headers=headers, data=data, timeout=5) - flagged = ret.json()["results"][0]["flagged"] - except requests.exceptions.RequestException as e: - flagged = False - except KeyError as e: - flagged = False - - return flagged - - def pretty_print_semaphore(semaphore): if semaphore is None: return "None"