Spaces:
Running
Running
| import os | |
| import sys | |
| from pathlib import Path | |
| from typing import Any, ClassVar | |
| from chromadb import EmbeddingFunction | |
class ModelStorage:
    '''Registry of loaded models, shared through class-level attributes.'''

    # Loaded embedding functions keyed by a string id
    # (presumably the embedding model repo name — confirm at call sites).
    EMBED_MODEL: ClassVar[dict[str, EmbeddingFunction]] = dict()
| class UiGradioConfig: | |
| '''Gradio settings for gr.Blocks()''' | |
| css: str | None = ''' | |
| .gradio-container { | |
| width: 70% !important; | |
| margin: 0 auto !important; | |
| } | |
| ''' | |
| if hasattr(sys, 'getandroidapilevel') or 'ANDROID_ROOT' in os.environ: | |
| css = None | |
| theme: str | None = None | |
| fill_height: bool = False | |
| footer_links: list[str] = ['gradio', 'settings'] | |
| delete_cache: tuple[int, int] | None = None | |
| def get_demo_launch_kwargs(cls): | |
| return dict( | |
| css=cls.css, | |
| theme=cls.theme, | |
| footer_links=cls.footer_links, | |
| ) | |
| def get_demo_blocks_kwargs(cls): | |
| return dict( | |
| fill_height=cls.fill_height, | |
| delete_cache=cls.delete_cache, | |
| ) | |
class InferenceConfig:
    '''Inference-time settings for the embedding and chat models.'''

    def __init__(self):
        # Options for embedding calls (presumably forwarded to the embedding
        # model's encode method — confirm at the call site).
        self.encode_kwargs: dict[str, Any] = {
            'batch_size': 300,
            'normalize_embeddings': None,
        }
        self.reasoning_format = 'none'
        # Sampling parameters for text generation.
        self.sampling_kwargs: dict[str, Any] = {
            'temperature': 0.2,
            'top_p': 0.9,
            'top_k': 40,
            'repeat_penalty': 1,
        }
        # -1 presumably means "no explicit limit" for the backend; confirm.
        self.max_tokens = -1
        self.do_sample: bool = False
        self.rag_mode: bool = False
        self.history_len: int = 0
        self.enable_thinking = False
        self.show_thinking: bool = False
        self.resize_size = 256
class TextLoadConfig:
    '''Options for extracting, cleaning and chunking text from documents.'''

    def __init__(self):
        # Chunking options first, then text-cleaning switches.
        partition_options: dict[str, str | int | bool | None] = {
            'chunking_strategy': 'basic',
            'max_characters': 800,
            'new_after_n_chars': 500,
            'overlap': 0,
            'clean': True,
            'bullets': True,
            'extra_whitespace': True,
            'dashes': False,
            'trailing_punctuation': True,
            'lowercase': False,
        }
        self.partition_kwargs: dict[str, str | int | bool | None] = partition_options
        # Accepted document extensions, as one space-separated string.
        self.SUPPORTED_FILE_EXTS: str = '.csv .tsv .docx .md .org .pdf .pptx .xlsx'
        # Selected subtitle language and the languages offered.
        self.subtitle_lang: str = 'ru'
        self.SUBTITLE_LANGS: list[str] = ['ru', 'en']
        self.max_lines_text_view: int = 200
class DbConfig:
    '''Vector database (Chroma) settings.'''

    def __init__(self):
        # HNSW index parameters used when a collection is created.
        hnsw_params = {
            'space': 'cosine',  # distance metric: l2, ip or cosine (default is l2)
            'ef_construction': 200,
        }
        self.create_collection_kwargs: dict[str, Any] = {
            'configuration': {'hnsw': hnsw_params},
        }
        # Retrieval settings. The 'treshold' spelling is a dict key that other
        # code reads, so it is kept as-is.
        self.query_kwargs: dict[str, Any] = {
            'n_results': 2,
            'max_distance_treshold': 0.5,
        }
class PromptConfig:
    '''Prompt texts used when building chat requests.'''

    def __init__(self):
        # No system prompt by default.
        self.system_prompt: str | None = None
        self.user_msg_with_context: str = ''
        # RAG prompt template with {context} and {user_message} placeholders
        # (presumably filled via str.format). English meaning:
        # "Answer the question given the context. Context: ... Question: ..."
        template_lines = (
            'Ответь на вопрос при условии контекста.',
            'Контекст:',
            '{context}',
            'Вопрос:',
            '{user_message}',
        )
        self.context_template: str = '\n'.join(template_lines)
class ModelConfig:
    '''Model storage paths, model selection and loading parameters.

    Instantiating this class creates the LLM and embedding cache directories
    (including any missing parent directories) if they do not exist yet.
    '''

    def __init__(self):
        # Cache directories, overridable via the LLAMA_CACHE / EMBED_MODEL_DIR env vars.
        self.LLM_MODEL_DIR: Path = Path(os.getenv('LLAMA_CACHE', 'llm_models'))
        self.EMBED_MODEL_DIR: Path = Path(os.getenv('EMBED_MODEL_DIR', 'embed_models'))
        # parents=True: the env vars may point at a nested path whose parent
        # directories do not exist yet; without it mkdir raises FileNotFoundError.
        self.LLM_MODEL_DIR.mkdir(parents=True, exist_ok=True)
        self.EMBED_MODEL_DIR.mkdir(parents=True, exist_ok=True)
        # No LLM is selected until the user picks one.
        self.llm_model_repo: str | None = None
        self.llm_model_file: str | None = None
        self.llm_model_mmproj: str | None = None
        self.embed_model_repo: str = os.getenv('EMBED_MODEL_REPO', 'Alibaba-NLP/gte-multilingual-base')
        # Constructor arguments for the embedding model.
        # NOTE(review): trust_remote_code=True allows the selected repo to run its
        # own code on load; only use trusted repos for EMBED_MODEL_REPO.
        self.embed_model_kwargs: dict[str, Any] = dict(
            cache_folder=self.EMBED_MODEL_DIR,
            trust_remote_code=True,
            token=os.getenv('HF_TOKEN'),
            model_kwargs=dict(
                dtype='auto',
                device_map='auto',
            ),
        )
        # LLM loading parameters (n_gpu_layers=-1 presumably offloads all layers; confirm).
        self.llm_model_kwargs: dict[str, Any] = dict(
            n_gpu_layers=-1,
            n_ctx=4096,
        )
class ReposConfig:
    '''Hugging Face repositories offered for model selection.

    ``llm_model_repos`` lists GGUF chat-model repos; ``embed_model_repos``
    lists embedding-model repos.
    '''

    def __init__(self):
        gguf_repos = [
            'bartowski/google_gemma-3-1b-it-GGUF',
            'bartowski/google_gemma-3-4b-it-GGUF',
            'bartowski/Qwen_Qwen3-1.7B-GGUF',
            'bartowski/Qwen_Qwen3-4B-GGUF',
            'bartowski/Qwen_Qwen3-0.6B-GGUF',
        ]
        embedding_repos = [
            'Alibaba-NLP/gte-multilingual-base',
            'sergeyzh/rubert-tiny-turbo',
            'intfloat/multilingual-e5-large',
            'intfloat/multilingual-e5-base',
            'intfloat/multilingual-e5-small',
            'intfloat/multilingual-e5-large-instruct',
            'sentence-transformers/all-mpnet-base-v2',
            'sentence-transformers/paraphrase-multilingual-mpnet-base-v2',
            'ai-forever/ruElectra-medium',
            'ai-forever/sbert_large_nlu_ru',
            'deepvk/USER2-small',
            'BAAI/bge-m3-retromae',
        ]
        self.llm_model_repos: list[str] = gguf_repos
        self.embed_model_repos: list[str] = embedding_repos
class Config:
    '''Top-level application config.

    Holds one instance of every section config and derives flat kwargs dicts
    from them: ``generation_kwargs``, ``load_text_kwargs``,
    ``load_model_kwargs`` and ``view_text_kwargs``.

    The ``get_*`` helpers slice those flat dicts into the argument sets of
    individual calls. The exceptions are ``get_embed_model_kwargs``,
    ``get_encode_kwargs`` and ``get_llm_model_kwargs``, which return the
    section dicts themselves (not copies).
    '''

    # Keys of load_text_kwargs forwarded to document partitioning.
    _PARTITION_KEYS: ClassVar[tuple[str, ...]] = (
        'chunking_strategy', 'max_characters', 'new_after_n_chars', 'overlap', 'clean',
        'bullets', 'extra_whitespace', 'dashes', 'trailing_punctuation', 'lowercase',
    )
    # Keys of load_text_kwargs forwarded to text cleaning.
    _CLEAN_KEYS: ClassVar[tuple[str, ...]] = (
        'bullets', 'extra_whitespace', 'dashes', 'trailing_punctuation', 'lowercase',
    )
    # Keys of load_text_kwargs forwarded to chunking.
    _CHUNKING_KEYS: ClassVar[tuple[str, ...]] = (
        'max_characters', 'new_after_n_chars', 'overlap',
    )
    # Keys of generation_kwargs describing a vector-store query.
    _QUERY_KEYS: ClassVar[tuple[str, ...]] = ('n_results', 'max_distance_treshold')
    # Model-source env vars cleared on every update_env() call, so a source set
    # by a previous call cannot leak into the next one.
    _MODEL_SOURCE_ENV_VARS: ClassVar[tuple[str, ...]] = (
        'LLAMA_ARG_MODEL_URL',
        'LLAMA_ARG_MMPROJ_URL',
        'LLAMA_ARG_HF_REPO',
        'LLAMA_ARG_HF_FILE',
        'LLAMA_ARG_MMPROJ',
    )

    def __init__(self):
        self.Inference: InferenceConfig = InferenceConfig()
        self.TextLoad: TextLoadConfig = TextLoadConfig()
        self.Prompt: PromptConfig = PromptConfig()
        self.Db: DbConfig = DbConfig()
        self.Model: ModelConfig = ModelConfig()
        self.Repos: ReposConfig = ReposConfig()
        # Enabled unless the env var is set to something other than '1'/'true'
        # (case-insensitive).
        self.CHATBOT_RAG_ENABLED = os.getenv('CHATBOT_RAG_ENABLED', '1').lower() in ('1', 'true')
        # Flat snapshot of generation, prompt and retrieval settings.
        self.generation_kwargs: dict[str, Any] = dict(
            do_sample=self.Inference.do_sample,
            temperature=self.Inference.sampling_kwargs['temperature'],
            top_p=self.Inference.sampling_kwargs['top_p'],
            top_k=self.Inference.sampling_kwargs['top_k'],
            repeat_penalty=self.Inference.sampling_kwargs['repeat_penalty'],
            max_tokens=self.Inference.max_tokens,
            history_len=self.Inference.history_len,
            system_prompt=self.Prompt.system_prompt,
            context_template=self.Prompt.context_template,
            enable_thinking=self.Inference.enable_thinking,
            show_thinking=self.Inference.show_thinking,
            n_results=self.Db.query_kwargs['n_results'],
            max_distance_treshold=self.Db.query_kwargs['max_distance_treshold'],
            user_msg_with_context=self.Prompt.user_msg_with_context,
            rag_mode=self.Inference.rag_mode,
        )
        # Partition options plus the subtitle language.
        self.load_text_kwargs: dict[str, Any] = {
            **self._pick(self.TextLoad.partition_kwargs, self._PARTITION_KEYS),
            'subtitle_lang': self.TextLoad.subtitle_lang,
        }
        self.load_model_kwargs: dict[str, Any] = dict(
            llm_model_repo=self.Model.llm_model_repo,
            llm_model_file=self.Model.llm_model_file,
            llm_model_mmproj=self.Model.llm_model_mmproj,
            embed_model_repo=self.Model.embed_model_repo,
            n_gpu_layers=self.Model.llm_model_kwargs['n_gpu_layers'],
            n_ctx=self.Model.llm_model_kwargs['n_ctx'],
        )
        self.view_text_kwargs: dict[str, Any] = dict(
            max_lines_text_view=self.TextLoad.max_lines_text_view,
        )

    @staticmethod
    def _pick(source: dict[str, Any], keys: tuple[str, ...]) -> dict[str, Any]:
        '''Return a new dict with ``keys`` (in that order) taken from ``source``.

        Raises KeyError if any key is missing.
        '''
        return {key: source[key] for key in keys}

    def update_env(self, **kwargs) -> None:
        '''Export model-loading settings as ``LLAMA_ARG_*`` environment variables.

        ``LLAMA_ARG_CTX_SIZE`` and ``LLAMA_ARG_N_GPU_LAYERS`` are always set
        from ``load_model_kwargs``. Each ``kwargs`` item is exported as an
        extra variable (stringified) and overrides those defaults. A ``None``
        value unsets the variable instead of exporting the string ``'None'``.

        All variables in ``_MODEL_SOURCE_ENV_VARS`` are cleared first, so only
        the model source given in this call remains set. When a mmproj file is
        configured, ``LLAMA_ARG_MMPROJ`` is set to its path inside
        ``Model.LLM_MODEL_DIR``; this takes precedence over a
        ``LLAMA_ARG_MMPROJ`` kwarg.
        '''
        mmproj_file = self.load_model_kwargs['llm_model_mmproj']
        env_updates = {
            'LLAMA_ARG_CTX_SIZE': str(self.load_model_kwargs['n_ctx']),
            'LLAMA_ARG_N_GPU_LAYERS': str(self.load_model_kwargs['n_gpu_layers']),
        }
        env_unset = []
        for name, value in kwargs.items():
            if value is None:
                # Unset rather than export the meaningless string 'None'.
                env_updates.pop(name, None)
                env_unset.append(name)
            else:
                env_updates[name] = str(value)
        for name in (*self._MODEL_SOURCE_ENV_VARS, *env_unset):
            os.environ.pop(name, None)
        if mmproj_file:
            env_updates['LLAMA_ARG_MMPROJ'] = str(self.Model.LLM_MODEL_DIR / mmproj_file)
        os.environ.update(env_updates)

    def get_completions_kwargs(self) -> dict[str, Any]:
        '''Return sampling arguments for a chat-completion request.

        ``temperature``, ``top_p`` and ``max_tokens`` are at the top level.
        ``top_k``, ``repeat_penalty``, ``reasoning_format`` and
        ``chat_template_kwargs`` go under ``extra_body`` (presumably for an
        OpenAI-compatible client talking to a llama.cpp-style server; confirm).
        '''
        return dict(
            temperature=self.generation_kwargs['temperature'],
            top_p=self.generation_kwargs['top_p'],
            max_tokens=self.generation_kwargs['max_tokens'],
            extra_body=dict(
                top_k=self.generation_kwargs['top_k'],
                repeat_penalty=self.generation_kwargs['repeat_penalty'],
                reasoning_format=self.Inference.reasoning_format,
                chat_template_kwargs=dict(
                    enable_thinking=self.generation_kwargs['enable_thinking'],
                ),
            ),
        )

    def get_rag_kwargs(self) -> dict[str, Any]:
        '''Return the query settings plus the prompt pieces used for RAG.'''
        return self._pick(
            self.generation_kwargs,
            self._QUERY_KEYS + ('user_msg_with_context', 'context_template'),
        )

    def get_partition_kwargs(self) -> dict[str, Any]:
        '''Return the document-partitioning options (without ``subtitle_lang``).'''
        return self._pick(self.load_text_kwargs, self._PARTITION_KEYS)

    def get_clean_kwargs(self) -> dict[str, Any]:
        '''Return the text-cleaning switches.'''
        return self._pick(self.load_text_kwargs, self._CLEAN_KEYS)

    def get_chunking_kwargs(self) -> dict[str, Any]:
        '''Return the chunk-size / overlap options.'''
        return self._pick(self.load_text_kwargs, self._CHUNKING_KEYS)

    def get_embed_model_kwargs(self) -> dict[str, Any]:
        '''Return ``Model.embed_model_kwargs`` itself; mutations affect the config.'''
        return self.Model.embed_model_kwargs

    def get_encode_kwargs(self) -> dict[str, Any]:
        '''Return ``Inference.encode_kwargs`` itself; mutations affect the config.'''
        return self.Inference.encode_kwargs

    def get_llm_model_kwargs(self) -> dict[str, Any]:
        '''Return ``Model.llm_model_kwargs`` itself; mutations affect the config.'''
        return self.Model.llm_model_kwargs

    def get_query_kwargs(self) -> dict[str, Any]:
        '''Return the vector-store query settings.'''
        return self._pick(self.generation_kwargs, self._QUERY_KEYS)