File size: 5,846 Bytes
01d5a5d
 
 
fa21e69
01d5a5d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fa21e69
 
01d5a5d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
from typing import Optional, Dict, Any, List, Tuple
import os
import chromadb
from chromadb.config import Settings
import logging
from lpm_kernel.configs.logging import get_train_process_logger

logger = get_train_process_logger()


def get_embedding_dimension(embedding: List[float]) -> int:
    """
    Get the dimension of an embedding vector
    
    Args:
        embedding: The embedding vector
        
    Returns:
        The dimension of the embedding vector
    """
    return len(embedding)


def detect_embedding_model_dimension(model_name: str) -> Optional[int]:
    """
    Detect the dimension of an embedding model based on its name
    This is a fallback method when we can't get a sample embedding
    
    Args:
        model_name: The name of the embedding model
        
    Returns:
        The dimension of the embedding model, or None if unknown
    """
    # Common embedding model dimensions
    model_dimensions = {
        # OpenAI models
        "text-embedding-ada-002": 1536,
        "text-embedding-3-small": 1536,
        "text-embedding-3-large": 3072,
        
        # Ollama models
        "snowflake-arctic-embed": 768,
        "snowflake-arctic-embed:110m": 768,
        "nomic-embed-text": 768,
        "nomic-embed-text:v1.5": 768,
        "mxbai-embed-large": 1024,
        "mxbai-embed-large:v1": 1024,
    }
    
    # Try to find exact match
    if model_name in model_dimensions:
        return model_dimensions[model_name]
    
    # Try to find partial match
    for model, dimension in model_dimensions.items():
        if model in model_name:
            return dimension
    
    # Default to OpenAI dimension if unknown
    logger.warning(f"Unknown embedding model: {model_name}, defaulting to 1536 dimensions")
    return 1536


def reinitialize_chroma_collections(dimension: int = 1536) -> bool:
    """
    Reinitialize ChromaDB collections with a new dimension
    
    Args:
        dimension: The new dimension for the collections
        
    Returns:
        True if successful, False otherwise
    """
    try:
        chroma_path = os.getenv("CHROMA_PERSIST_DIRECTORY", "./data/chroma_db")
        settings = Settings(anonymized_telemetry=False)
        client = chromadb.PersistentClient(path=chroma_path, settings=settings)
        
        # Delete and recreate document collection
        try:
            # Check if collection exists before attempting to delete
            try:
                client.get_collection(name="documents")
                client.delete_collection(name="documents")
                logger.info("Deleted 'documents' collection")
            except ValueError:
                logger.info("'documents' collection does not exist, will create new")
        except Exception as e:
            logger.error(f"Error deleting 'documents' collection: {str(e)}", exc_info=True)
            return False
        
        # Create document collection with new dimension
        try:
            client.create_collection(
                name="documents",
                metadata={
                    "hnsw:space": "cosine",
                    "dimension": dimension
                }
            )
            logger.info(f"Created 'documents' collection with dimension {dimension}")
        except Exception as e:
            logger.error(f"Error creating 'documents' collection: {str(e)}", exc_info=True)
            return False
        
        # Delete and recreate chunk collection
        try:
            # Check if collection exists before attempting to delete
            try:
                client.get_collection(name="document_chunks")
                client.delete_collection(name="document_chunks")
                logger.info("Deleted 'document_chunks' collection")
            except ValueError:
                logger.info("'document_chunks' collection does not exist, will create new")
        except Exception as e:
            logger.error(f"Error deleting 'document_chunks' collection: {str(e)}", exc_info=True)
            return False
        
        # Create chunk collection with new dimension
        try:
            client.create_collection(
                name="document_chunks",
                metadata={
                    "hnsw:space": "cosine",
                    "dimension": dimension
                }
            )
            logger.info(f"Created 'document_chunks' collection with dimension {dimension}")
        except Exception as e:
            logger.error(f"Error creating 'document_chunks' collection: {str(e)}", exc_info=True)
            return False
        
        # Verify collections were created with correct dimension
        try:
            doc_collection = client.get_collection(name="documents")
            chunk_collection = client.get_collection(name="document_chunks")
            
            doc_dimension = doc_collection.metadata.get("dimension")
            if doc_dimension != dimension:
                logger.error(f"Verification failed: 'documents' collection has incorrect dimension: {doc_dimension} vs {dimension}")
                return False
                
            chunk_dimension = chunk_collection.metadata.get("dimension")
            if chunk_dimension != dimension:
                logger.error(f"Verification failed: 'document_chunks' collection has incorrect dimension: {chunk_dimension} vs {dimension}")
                return False
                
            logger.info(f"Verification successful: Both collections have correct dimension: {dimension}")
        except Exception as e:
            logger.error(f"Error verifying collections: {str(e)}", exc_info=True)
            return False
        
        return True
    except Exception as e:
        logger.error(f"Error reinitializing ChromaDB collections: {str(e)}", exc_info=True)
        return False