From c0f26ee1e3ece882b308f93ffda8f0bea787698d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Roberto=20Tom=C3=A1s=20Collins?= Date: Mon, 30 Sep 2024 22:06:01 -0400 Subject: [PATCH] add ollama (no tts yet but text ollama) --- .gitignore | 228 ++++++++++++++++++++++++++++++++++++++++++++++++ PDF2Audio.ipynb | 120 +++++++++++++++++++++---- app.py | 113 +++++++++++++++++++++--- 3 files changed, 429 insertions(+), 32 deletions(-) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..2bd3451 --- /dev/null +++ b/.gitignore @@ -0,0 +1,228 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/latest/usage/project/#working-with-version-control +.pdm.toml +.pdm-python +.pdm-build/ + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +# Windows thumbnail cache files +Thumbs.db +Thumbs.db:encryptable +ehthumbs.db +ehthumbs_vista.db + +# Dump file +*.stackdump + +# Folder config file +[Dd]esktop.ini + +# Recycle Bin used on file shares +$RECYCLE.BIN/ + +# Windows Installer files +*.cab +*.msi +*.msix +*.msm +*.msp + +# Windows shortcuts +*.lnk + +# General +.DS_Store +.AppleDouble +.LSOverride + +# Icon must end with two \r +Icon + +# Thumbnails +._* + +# Files that might appear in the root of a volume +.DocumentRevisions-V100 +.fseventsd +.Spotlight-V100 +.TemporaryItems +.Trashes +.VolumeIcon.icns +.com.apple.timemachine.donotpresent + +# Directories potentially created on remote AFP share +.AppleDB +.AppleDesktop +Network Trash Folder +Temporary Items +.apdisk + +*~ + +# temporary files which can be created if a process still has a handle open of a deleted file +.fuse_hidden* + +# KDE directory preferences +.directory + +# Linux trash folder which might appear on any partition or disk +.Trash-* + +# .nfs files are created when an open file is removed but is still being accessed +.nfs* diff --git a/PDF2Audio.ipynb b/PDF2Audio.ipynb index e5e30dc..bad916e 100644 --- a/PDF2Audio.ipynb +++ b/PDF2Audio.ipynb @@ -43,7 +43,6 @@ }, "outputs": [], "source": [ - "# @title\n", "import concurrent.futures as cf\n", "import glob\n", "import io\n", @@ -64,6 +63,9 @@ "\n", "import re\n", "\n", + "# import sys\n", + "# logger.configure(handlers=[{\"sink\": sys.stdout, \"level\": \"DEBUG\"}])\n", + "\n", "def read_readme():\n", " readme_path = Path(\"README.md\")\n", " if readme_path.exists():\n", @@ -74,7 +76,7 @@ " return content\n", " else:\n", " return \"README.md not found. Please check the repository for more information.\"\n", - " \n", + "\n", "# Define multiple sets of instruction templates\n", "INSTRUCTION_TEMPLATES = {\n", "################# PODCAST ##################\n", @@ -260,6 +262,7 @@ "The summary should have around 256 words.\n", "\"\"\",\n", " },\n", + "\n", "################# PODCAST French ##################\n", "\"podcast (French)\": {\n", " \"intro\": \"\"\"Votre tâche consiste à prendre le texte fourni et à le transformer en un dialogue de podcast vivant, engageant et informatif, dans le style de NPR. 
Le texte d'entrée peut être désorganisé ou non structuré, car il peut provenir de diverses sources telles que des fichiers PDF ou des pages web.\n", @@ -434,7 +437,7 @@ "पॉडकास्ट में लगभग 20,000 शब्द होने चाहिए।\n", "\"\"\",\n", " },\n", - "\n", + " \n", "################# PODCAST Chinese ##################\n", "\"podcast (Chinese)\": {\n", " \"intro\": \"\"\"你的任务是将提供的输入文本转变为一个生动、有趣、信息丰富的播客对话,风格类似NPR。输入文本可能是凌乱的或未结构化的,因为它可能来自PDF或网页等各种来源。\n", @@ -469,7 +472,6 @@ "播客应约有20,000字。\n", "\"\"\",\n", " },\n", - "\n", "}\n", "\n", "# Function to update instruction fields based on template selection\n", @@ -480,7 +482,7 @@ " INSTRUCTION_TEMPLATES[template][\"scratch_pad\"],\n", " INSTRUCTION_TEMPLATES[template][\"prelude\"],\n", " INSTRUCTION_TEMPLATES[template][\"dialog\"]\n", - " )\n", + " )\n", "\n", "import concurrent.futures as cf\n", "import glob\n", @@ -500,6 +502,8 @@ "from pypdf import PdfReader\n", "from tenacity import retry, retry_if_exception_type\n", "\n", + "import requests # Added import for handling HTTP Ollama requests\n", + "\n", "# Define standard values\n", "STANDARD_TEXT_MODELS = [\n", " \"o1-preview-2024-09-12\",\n", @@ -527,6 +531,33 @@ " \"shimmer\",\n", "]\n", "\n", + "# Function to get Ollama models\n", + "def get_ollama_models(api_base=\"http://localhost:11434\", api_path=\"/v1/models\"):\n", + " api = api_base + api_path\n", + " try:\n", + " response = requests.get(api)\n", + " response.raise_for_status()\n", + " models_info = response.json()\n", + " logger.info(f\"Retrieved models from Ollama API: {models_info}\")\n", + " \n", + " if isinstance(models_info, dict) and \"data\" in models_info:\n", + " models = [model['id'] for model in models_info[\"data\"] if \"id\" in model]\n", + " else:\n", + " # Handle unexpected data structures\n", + " logger.warning(\"Unexpected data format in Ollama API response.\")\n", + " models = []\n", + " \n", + " return models\n", + " except requests.exceptions.RequestException as e:\n", + " logger.warning(f\"Could not connect to Ollama API at {api_base}: {e}\")\n", + " return []\n", + " except ValueError as e:\n", + " logger.warning(f\"Failed to parse JSON from Ollama API at {api_base}: {e}\")\n", + " return []\n", + " except KeyError as e:\n", + " logger.warning(f\"Missing expected key in Ollama API response: {e}\")\n", + " return []\n", + "\n", "class DialogueItem(BaseModel):\n", " text: str\n", " speaker: Literal[\"speaker-1\", \"speaker-2\"]\n", @@ -582,11 +613,17 @@ " edited_transcript: str = None,\n", " user_feedback: str = None,\n", " original_text: str = None,\n", - " debug = False,\n", ") -> tuple:\n", - " # Validate API Key\n", - " if not os.getenv(\"OPENAI_API_KEY\") and not openai_api_key:\n", - " raise gr.Error(\"OpenAI API key is required\")\n", + " # Determine if the selected model is an Ollama model\n", + " if text_model.startswith('ollama/'):\n", + " if not api_base:\n", + " api_base = 'http://localhost:11434'\n", + " else:\n", + " # Use OpenAI API\n", + " if not openai_api_key:\n", + " openai_api_key = os.getenv(\"OPENAI_API_KEY\")\n", + " if not openai_api_key:\n", + " raise gr.Error(\"OpenAI API key is required for non-Ollama models\")\n", "\n", " combined_text = original_text or \"\"\n", "\n", @@ -634,9 +671,8 @@ " if edited_transcript_processed.strip()!='' or user_feedback_processed.strip()!='':\n", " user_feedback_processed=\"\"+user_feedback_processed+\"\\n\\n\"+instruction_improve+\"\" \n", " \n", - " if debug:\n", - " logger.info (edited_transcript_processed)\n", - " logger.info (user_feedback_processed)\n", + " 
logger.debug(edited_transcript_processed)\n", + " logger.debug(user_feedback_processed)\n", " \n", " # Generate the dialogue using the LLM\n", " llm_output = generate_dialogue(\n", @@ -665,9 +701,12 @@ " characters += len(line.text)\n", "\n", " for future, transcript_line in futures:\n", - " audio_chunk = future.result()\n", - " audio += audio_chunk\n", - " transcript += transcript_line + \"\\n\\n\"\n", + " try:\n", + " audio_chunk = future.result()\n", + " audio += audio_chunk\n", + " transcript += transcript_line + \"\\n\\n\"\n", + " except Exception as e:\n", + " logger.error(f\"Error generating audio chunk: {e}\")\n", "\n", " logger.info(f\"Generated {characters} characters of audio\")\n", "\n", @@ -698,6 +737,7 @@ " audio_file, transcript, original_text = generate_audio(*args)\n", " return audio_file, transcript, original_text, None # Return None as the error when successful\n", " except Exception as e:\n", + " logger.error(f\"An error occurred during audio generation: {e}\")\n", " # If an error occurs during generation, return None for the outputs and the error message\n", " return None, None, None, str(e)\n", "\n", @@ -706,7 +746,7 @@ " #new_args = list(args)\n", " #new_args[-2] = edited_transcript # Update edited transcript\n", " #new_args[-1] = user_feedback # Update user feedback\n", - " return validate_and_generate_audio(*new_args)\n", + " return validate_and_generate_audio(*args)\n", "\n", "# New function to handle user feedback and regeneration\n", "def process_feedback_and_regenerate(feedback, *args):\n", @@ -794,7 +834,7 @@ " api_base = gr.Textbox(\n", " label=\"Custom API Base\",\n", " placeholder=\"Enter custom API base URL if using a custom/local model...\",\n", - " info=\"If you are using a custom or local model, provide the API base URL here, e.g.: http://localhost:8080/v1 for llama.cpp REST server.\",\n", + " info=\"If you are using a custom or local model, provide the API base URL here, e.g.: http://localhost:11434 for Ollama server or http://localhost:8080/v1 for llama.cpp REST server.\",\n", " )\n", "\n", " with gr.Column(scale=3):\n", @@ -865,6 +905,50 @@ " outputs=[intro_instructions, text_instructions, scratch_pad_instructions, prelude_dialog, podcast_dialog_instructions]\n", " )\n", " \n", + " # Update api_base when text_model changes\n", + " def update_api_base_on_model_change(text_model_value, api_base_value):\n", + " if text_model_value.startswith('ollama/') and not api_base_value:\n", + " return gr.update(value='http://localhost:11434')\n", + " else:\n", + " return gr.update()\n", + "\n", + " text_model.change(\n", + " fn=update_api_base_on_model_change,\n", + " inputs=[text_model, api_base],\n", + " outputs=api_base\n", + " )\n", + "\n", + " # Update text_model choices when api_base changes\n", + " def update_text_models_on_api_base_change(api_base_value):\n", + " standard_models = [\n", + " \"o1-preview-2024-09-12\",\n", + " \"o1-preview\",\n", + " \"gpt-4o-2024-08-06\",\n", + " \"gpt-4o-mini\",\n", + " \"o1-mini-2024-09-12\",\n", + " \"o1-mini\",\n", + " \"chatgpt-4o-latest\",\n", + " \"gpt-4-turbo\",\n", + " \"openai/custom_model\",\n", + " ]\n", + " # Get the Ollama models from the new api_base\n", + " if api_base_value:\n", + " ollama_models = get_ollama_models(api_base_value.rstrip('/'))\n", + " ollama_models = [f'ollama/{model}' for model in ollama_models]\n", + " else:\n", + " ollama_models = []\n", + " # Combine models\n", + " updated_models = standard_models + ollama_models\n", + " updated_models = list(set(updated_models))\n", + " # Return the 
updated dropdown choices\n", + " return gr.update(choices=updated_models)\n", + "\n", + " api_base.change(\n", + " fn=update_text_models_on_api_base_change,\n", + " inputs=[api_base],\n", + " outputs=[text_model]\n", + " )\n", + "\n", " submit_btn.click(\n", " fn=validate_and_generate_audio,\n", " inputs=[\n", @@ -926,7 +1010,7 @@ "\n", "# Launch the Gradio app\n", "if __name__ == \"__main__\":\n", - " demo.launch()" + " demo.launch()\n" ] }, { diff --git a/app.py b/app.py index 5f3aeed..00d1755 100644 --- a/app.py +++ b/app.py @@ -18,6 +18,9 @@ import re +# import sys +# logger.configure(handlers=[{"sink": sys.stdout, "level": "DEBUG"}]) + def read_readme(): readme_path = Path("README.md") if readme_path.exists(): @@ -28,7 +31,7 @@ def read_readme(): return content else: return "README.md not found. Please check the repository for more information." - + # Define multiple sets of instruction templates INSTRUCTION_TEMPLATES = { ################# PODCAST ################## @@ -434,7 +437,7 @@ def update_instructions(template): INSTRUCTION_TEMPLATES[template]["scratch_pad"], INSTRUCTION_TEMPLATES[template]["prelude"], INSTRUCTION_TEMPLATES[template]["dialog"] - ) + ) import concurrent.futures as cf import glob @@ -454,6 +457,8 @@ def update_instructions(template): from pypdf import PdfReader from tenacity import retry, retry_if_exception_type +import requests # Added import for handling HTTP Ollama requests + # Define standard values STANDARD_TEXT_MODELS = [ "o1-preview-2024-09-12", @@ -481,6 +486,33 @@ def update_instructions(template): "shimmer", ] +# Function to get Ollama models +def get_ollama_models(api_base="http://localhost:11434", api_path="/v1/models"): + api = api_base + api_path + try: + response = requests.get(api) + response.raise_for_status() + models_info = response.json() + logger.info(f"Retrieved models from Ollama API: {models_info}") + + if isinstance(models_info, dict) and "data" in models_info: + models = [model['id'] for model in models_info["data"] if "id" in model] + else: + # Handle unexpected data structures + logger.warning("Unexpected data format in Ollama API response.") + models = [] + + return models + except requests.exceptions.RequestException as e: + logger.warning(f"Could not connect to Ollama API at {api_base}: {e}") + return [] + except ValueError as e: + logger.warning(f"Failed to parse JSON from Ollama API at {api_base}: {e}") + return [] + except KeyError as e: + logger.warning(f"Missing expected key in Ollama API response: {e}") + return [] + class DialogueItem(BaseModel): text: str speaker: Literal["speaker-1", "speaker-2"] @@ -536,11 +568,17 @@ def generate_audio( edited_transcript: str = None, user_feedback: str = None, original_text: str = None, - debug = False, ) -> tuple: - # Validate API Key - if not os.getenv("OPENAI_API_KEY") and not openai_api_key: - raise gr.Error("OpenAI API key is required") + # Determine if the selected model is an Ollama model + if text_model.startswith('ollama/'): + if not api_base: + api_base = 'http://localhost:11434' + else: + # Use OpenAI API + if not openai_api_key: + openai_api_key = os.getenv("OPENAI_API_KEY") + if not openai_api_key: + raise gr.Error("OpenAI API key is required for non-Ollama models") combined_text = original_text or "" @@ -588,9 +626,8 @@ def generate_dialogue(text: str, intro_instructions: str, text_instructions: str if edited_transcript_processed.strip()!='' or user_feedback_processed.strip()!='': user_feedback_processed=""+user_feedback_processed+"\n\n"+instruction_improve+"" - if debug: - 
logger.info (edited_transcript_processed) - logger.info (user_feedback_processed) + logger.debug(edited_transcript_processed) + logger.debug(user_feedback_processed) # Generate the dialogue using the LLM llm_output = generate_dialogue( @@ -619,9 +656,12 @@ def generate_dialogue(text: str, intro_instructions: str, text_instructions: str characters += len(line.text) for future, transcript_line in futures: - audio_chunk = future.result() - audio += audio_chunk - transcript += transcript_line + "\n\n" + try: + audio_chunk = future.result() + audio += audio_chunk + transcript += transcript_line + "\n\n" + except Exception as e: + logger.error(f"Error generating audio chunk: {e}") logger.info(f"Generated {characters} characters of audio") @@ -652,6 +692,7 @@ def validate_and_generate_audio(*args): audio_file, transcript, original_text = generate_audio(*args) return audio_file, transcript, original_text, None # Return None as the error when successful except Exception as e: + logger.error(f"An error occurred during audio generation: {e}") # If an error occurs during generation, return None for the outputs and the error message return None, None, None, str(e) @@ -660,7 +701,7 @@ def edit_and_regenerate(edited_transcript, user_feedback, *args): #new_args = list(args) #new_args[-2] = edited_transcript # Update edited transcript #new_args[-1] = user_feedback # Update user feedback - return validate_and_generate_audio(*new_args) + return validate_and_generate_audio(*args) # New function to handle user feedback and regeneration def process_feedback_and_regenerate(feedback, *args): @@ -748,7 +789,7 @@ def process_feedback_and_regenerate(feedback, *args): api_base = gr.Textbox( label="Custom API Base", placeholder="Enter custom API base URL if using a custom/local model...", - info="If you are using a custom or local model, provide the API base URL here, e.g.: http://localhost:8080/v1 for llama.cpp REST server.", + info="If you are using a custom or local model, provide the API base URL here, e.g.: http://localhost:11434 for Ollama server or http://localhost:8080/v1 for llama.cpp REST server.", ) with gr.Column(scale=3): @@ -819,6 +860,50 @@ def update_edit_box(checkbox_value): outputs=[intro_instructions, text_instructions, scratch_pad_instructions, prelude_dialog, podcast_dialog_instructions] ) + # Update api_base when text_model changes + def update_api_base_on_model_change(text_model_value, api_base_value): + if text_model_value.startswith('ollama/') and not api_base_value: + return gr.update(value='http://localhost:11434') + else: + return gr.update() + + text_model.change( + fn=update_api_base_on_model_change, + inputs=[text_model, api_base], + outputs=api_base + ) + + # Update text_model choices when api_base changes + def update_text_models_on_api_base_change(api_base_value): + standard_models = [ + "o1-preview-2024-09-12", + "o1-preview", + "gpt-4o-2024-08-06", + "gpt-4o-mini", + "o1-mini-2024-09-12", + "o1-mini", + "chatgpt-4o-latest", + "gpt-4-turbo", + "openai/custom_model", + ] + # Get the Ollama models from the new api_base + if api_base_value: + ollama_models = get_ollama_models(api_base_value.rstrip('/')) + ollama_models = [f'ollama/{model}' for model in ollama_models] + else: + ollama_models = [] + # Combine models + updated_models = standard_models + ollama_models + updated_models = list(set(updated_models)) + # Return the updated dropdown choices + return gr.update(choices=updated_models) + + api_base.change( + fn=update_text_models_on_api_base_change, + inputs=[api_base], + 
outputs=[text_model] + ) + submit_btn.click( fn=validate_and_generate_audio, inputs=[
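
The Ollama integration in this patch talks to Ollama's OpenAI-compatible REST API: a GET to <api_base>/v1/models lists the locally pulled models, and each returned id is surfaced in the dropdown with an "ollama/" prefix so generate_audio can route requests to the local server instead of requiring an OpenAI key. A minimal standalone sketch of that discovery step, useful for sanity-checking a server before launching the app — it assumes an Ollama instance on the default http://localhost:11434, and list_ollama_choices is a hypothetical helper name, not part of the patch:

    import requests

    def list_ollama_choices(api_base: str = "http://localhost:11434") -> list[str]:
        # Hypothetical helper mirroring get_ollama_models from this patch.
        # Ollama exposes an OpenAI-compatible /v1/models endpoint by default.
        resp = requests.get(f"{api_base.rstrip('/')}/v1/models", timeout=5)
        resp.raise_for_status()
        data = resp.json().get("data", [])
        # Prefix each id with "ollama/" so the app can tell local models
        # apart from the standard OpenAI choices in the same dropdown.
        return [f"ollama/{m['id']}" for m in data if "id" in m]

    if __name__ == "__main__":
        try:
            print(list_ollama_choices())
        except requests.exceptions.RequestException as e:
            print(f"Ollama server not reachable: {e}")

Run against a stock Ollama install, this prints one "ollama/<model-id>" entry per locally pulled model — the same choices the api_base.change handler feeds back into the text_model dropdown.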