Problem: Gemini blocked energy-sector content as 'dangerous content'. Solution: pass safety_settings with BLOCK_NONE to the API. Note: the FREE tier may still enforce its own restrictions. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
583 lines
20 KiB
Python
583 lines
20 KiB
Python
"""
|
|
Google Gemini AI Service
|
|
========================
|
|
Reusable service for interacting with Google Gemini API.
|
|
|
|
Features:
|
|
- Multiple model support (Flash, Pro, Flash-8B)
|
|
- Error handling and retries
|
|
- Cost tracking
|
|
- Streaming responses
|
|
- Safety settings configuration
|
|
|
|
Author: MTB Tracker Team
|
|
Created: 2025-10-18
|
|
"""
|
|
|
|
import os
|
|
import logging
|
|
import hashlib
|
|
import time
|
|
from datetime import datetime
|
|
from typing import Optional, Dict, Any, List
|
|
import google.generativeai as genai
|
|
from google.generativeai.types import HarmCategory, HarmBlockThreshold
|
|
|
|
# Configure logging
logger = logging.getLogger(__name__)

# Database imports for cost tracking.
# Graceful degradation: when the project's database module is not importable
# (e.g. standalone scripts or tests), cost logging is silently skipped — all
# logging call sites check DB_AVAILABLE first.
try:
    from database import SessionLocal, AIAPICostLog, AIUsageLog
    DB_AVAILABLE = True
except ImportError:
    logger.warning("Database not available - cost tracking disabled")
    DB_AVAILABLE = False
|
|
|
|
# Available Gemini models (2025 - Gemini 1.5 retired April 29, 2025).
# Keys are the short aliases accepted by GeminiService(model=...); values are
# the full model names sent to the API.
GEMINI_MODELS = {
    'flash': 'gemini-2.5-flash',            # Best for general use - balanced cost/quality
    'flash-lite': 'gemini-2.5-flash-lite',  # Ultra cheap - $0.10/$0.40 per 1M tokens
    'pro': 'gemini-2.5-pro',                # High quality - best reasoning/coding
    'flash-2.0': 'gemini-2.0-flash',        # Second generation - 1M context window
}

# Pricing per 1M tokens (USD) - updated 2025-10-18.
# Used by GeminiService._log_api_cost to estimate per-call cost.
# NOTE(review): verify the 2.5-flash rates — $0.075/$0.30 matches the older
# 1.5-flash pricing and is *lower* than the "ultra cheap" flash-lite entry;
# confirm against the current Gemini price list.
GEMINI_PRICING = {
    'gemini-2.5-flash': {'input': 0.075, 'output': 0.30},
    'gemini-2.5-flash-lite': {'input': 0.10, 'output': 0.40},
    'gemini-2.5-pro': {'input': 1.25, 'output': 5.00},
    'gemini-2.0-flash': {'input': 0.075, 'output': 0.30},
}
|
|
|
|
class GeminiService:
    """Service class for Google Gemini API interactions.

    Wraps the google.generativeai SDK with:
    - model alias resolution (see GEMINI_MODELS)
    - permissive safety settings (BLOCK_NONE) so legitimate content is not blocked
    - automatic per-call cost logging to the database (when available)
    - embedding helpers using the text-embedding-004 model
    """

    def __init__(self, api_key: Optional[str] = None, model: str = 'flash'):
        """
        Initialize Gemini service.

        Args:
            api_key: Google AI API key (reads from GOOGLE_GEMINI_API_KEY env if not provided)
            model: Model alias to use ('flash', 'flash-lite', 'pro', 'flash-2.0')

        Raises:
            ValueError: If no API key is configured (or only the .env placeholder is set).
        """
        self.api_key = api_key or os.getenv('GOOGLE_GEMINI_API_KEY')

        # Debug: log the key masked so misconfiguration is visible without leaking it
        if self.api_key:
            logger.info(f"API key loaded: {self.api_key[:10]}...{self.api_key[-4:]}")
        else:
            logger.error("API key is None or empty!")

        # 'TWOJ_KLUCZ_API_TUTAJ' ("YOUR_API_KEY_HERE") is the placeholder shipped in .env
        if not self.api_key or self.api_key == 'TWOJ_KLUCZ_API_TUTAJ':
            raise ValueError(
                "GOOGLE_GEMINI_API_KEY not configured. "
                "Please add your API key to .env file."
            )

        # Configure the SDK globally with this key
        genai.configure(api_key=self.api_key)

        # Resolve the model alias; unknown aliases fall back to 'flash'
        self.model_name = GEMINI_MODELS.get(model, GEMINI_MODELS['flash'])
        self.model = genai.GenerativeModel(self.model_name)

        # Safety settings (disabled for testing - enable in production if needed)
        # Note: Even BLOCK_ONLY_HIGH was blocking neutral prompts like "mountain biking"
        # For production apps, consider using BLOCK_ONLY_HIGH or BLOCK_MEDIUM_AND_ABOVE
        self.safety_settings = {
            HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
        }

        logger.info(f"Gemini service initialized with model: {self.model_name}")

    def generate_text(
        self,
        prompt: str,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        stream: bool = False,
        feature: str = 'general',
        user_id: Optional[int] = None,
        company_id: Optional[int] = None,
        related_entity_type: Optional[str] = None,
        related_entity_id: Optional[int] = None
    ) -> str:
        """
        Generate text using Gemini API with automatic cost tracking.

        Args:
            prompt: Text prompt to send to the model
            temperature: Sampling temperature (0.0-1.0). Higher = more creative
            max_tokens: Maximum tokens to generate (None = model default)
            stream: Whether to stream the response (returns the streaming
                response object instead of a string; no cost logging)
            feature: Feature name for cost tracking ('chat', 'news_evaluation', etc.)
            user_id: Optional user ID for cost tracking
            company_id: Optional company ID for context
            related_entity_type: Entity type ('zopk_news', 'chat_message', etc.)
            related_entity_id: Entity ID for reference

        Returns:
            Generated text response

        Raises:
            Exception: If API call fails or the response is blocked/incomplete
        """
        start_time = time.time()

        try:
            # Use minimal configuration to avoid blocking issues with FREE tier:
            # only send a generation_config when it differs from the API defaults.
            generation_config = None
            if temperature != 0.7 or max_tokens:
                generation_config = {'temperature': temperature}
                if max_tokens:
                    generation_config['max_output_tokens'] = max_tokens

            # Pass safety_settings to reduce blocking for legitimate news content.
            # Note: FREE tier may still have built-in restrictions.
            # BUG FIX: forward the stream flag to the API — previously the request
            # was always made non-streaming even when stream=True.
            response = self.model.generate_content(
                prompt,
                generation_config=generation_config,
                safety_settings=self.safety_settings,
                stream=stream
            )

            if stream:
                # Return the streaming response iterator; chunks are consumed by
                # the caller, so no token/cost accounting happens here.
                return response

            # Check if the response was blocked by safety filters
            if not response.candidates:
                raise Exception(
                    f"Response blocked. No candidates returned. "
                    f"This may be due to safety filters."
                )

            candidate = response.candidates[0]

            # Check finish reason. Numeric values follow the FinishReason enum:
            # 0=UNSPECIFIED, 1=STOP, 2=MAX_TOKENS, 3=SAFETY, 4=RECITATION, 5=OTHER.
            # BUG FIX: the labels below were previously shifted by one
            # (2 was labelled SAFETY, 5 MAX_TOKENS), producing misleading errors.
            if candidate.finish_reason not in [1, 0]:  # 1=STOP, 0=UNSPECIFIED
                finish_reasons = {
                    2: "MAX_TOKENS - Reached max token limit",
                    3: "SAFETY - Content blocked by safety filters",
                    4: "RECITATION - Content blocked due to recitation",
                    5: "OTHER - Other reason"
                }
                reason = finish_reasons.get(candidate.finish_reason, f"Unknown ({candidate.finish_reason})")
                raise Exception(
                    f"Response incomplete. Finish reason: {reason}. "
                    f"Try adjusting safety settings or prompt."
                )

            # Count tokens and log cost
            response_text = response.text
            latency_ms = int((time.time() - start_time) * 1000)

            input_tokens = self.count_tokens(prompt)
            output_tokens = self.count_tokens(response_text)

            logger.info(
                f"Gemini API call successful. "
                f"Tokens: {input_tokens}+{output_tokens}, "
                f"Latency: {latency_ms}ms, "
                f"Model: {self.model_name}"
            )

            # Log to database for cost tracking
            self._log_api_cost(
                prompt=prompt,
                response_text=response_text,
                input_tokens=input_tokens,
                output_tokens=output_tokens,
                latency_ms=latency_ms,
                success=True,
                feature=feature,
                user_id=user_id,
                company_id=company_id,
                related_entity_type=related_entity_type,
                related_entity_id=related_entity_id
            )

            return response_text

        except Exception as e:
            latency_ms = int((time.time() - start_time) * 1000)

            # Log the failed request (count_tokens falls back to an estimate
            # internally, so it cannot raise here)
            self._log_api_cost(
                prompt=prompt,
                response_text='',
                input_tokens=self.count_tokens(prompt),
                output_tokens=0,
                latency_ms=latency_ms,
                success=False,
                error_message=str(e),
                feature=feature,
                user_id=user_id,
                company_id=company_id,
                related_entity_type=related_entity_type,
                related_entity_id=related_entity_id
            )

            logger.error(f"Gemini API error: {str(e)}")
            # Chain the original exception for easier debugging
            raise Exception(f"Gemini API call failed: {str(e)}") from e

    def chat(self, messages: List[Dict[str, str]]) -> str:
        """
        Multi-turn chat conversation.

        Args:
            messages: List of message dicts with 'role' and 'content' keys
                Example: [
                    {'role': 'user', 'content': 'Hello'},
                    {'role': 'model', 'content': 'Hi there!'},
                    {'role': 'user', 'content': 'How are you?'}
                ]

        Returns:
            Model's response to the last message

        Raises:
            Exception: If the API call fails
        """
        try:
            # BUG FIX: pass the conversation history to start_chat() instead of
            # replaying each historical user message through send_message().
            # The old code made one real API call per history entry and discarded
            # the provided 'model' replies, so the model saw a different history.
            history = [
                {'role': msg['role'], 'parts': [msg['content']]}
                for msg in messages[:-1]
            ]
            chat = self.model.start_chat(history=history)

            # Send the last message and get the response
            response = chat.send_message(messages[-1]['content'])
            return response.text

        except Exception as e:
            logger.error(f"Gemini chat error: {str(e)}")
            raise Exception(f"Gemini chat failed: {str(e)}") from e

    def analyze_image(self, image_path: str, prompt: str) -> str:
        """
        Analyze image with Gemini Vision.

        Args:
            image_path: Path to image file
            prompt: Text prompt describing what to analyze

        Returns:
            Analysis result

        Raises:
            Exception: If the file cannot be opened or the API call fails
        """
        try:
            # Local import: PIL is only needed for this feature
            import PIL.Image

            img = PIL.Image.open(image_path)

            response = self.model.generate_content(
                [prompt, img],
                safety_settings=self.safety_settings
            )

            return response.text

        except Exception as e:
            logger.error(f"Gemini image analysis error: {str(e)}")
            raise Exception(f"Image analysis failed: {str(e)}") from e

    def count_tokens(self, text: str) -> int:
        """
        Count tokens in text.

        Falls back to a rough character-based estimate when the API call fails,
        so this method never raises.

        Args:
            text: Text to count tokens for

        Returns:
            Number of tokens (exact when the API succeeds, estimated otherwise)
        """
        try:
            result = self.model.count_tokens(text)
            return result.total_tokens
        except Exception as e:
            logger.warning(f"Token counting failed: {e}")
            # Rough estimate: ~4 chars per token
            return len(text) // 4

    def _log_api_cost(
        self,
        prompt: str,
        response_text: str,
        input_tokens: int,
        output_tokens: int,
        latency_ms: int,
        success: bool = True,
        error_message: Optional[str] = None,
        feature: str = 'general',
        user_id: Optional[int] = None,
        company_id: Optional[int] = None,
        related_entity_type: Optional[str] = None,
        related_entity_id: Optional[int] = None
    ):
        """
        Log API call costs to database for monitoring.

        Best-effort: any database error is logged and swallowed so cost
        accounting can never break the actual API call path.

        Args:
            prompt: Input prompt text
            response_text: Output response text
            input_tokens: Number of input tokens used
            output_tokens: Number of output tokens generated
            latency_ms: Response time in milliseconds
            success: Whether API call succeeded
            error_message: Error details if failed
            feature: Feature name ('chat', 'news_evaluation', 'user_creation', etc.)
            user_id: Optional user ID
            company_id: Optional company ID for context
            related_entity_type: Entity type ('zopk_news', 'chat_message', etc.)
            related_entity_id: Entity ID for reference
        """
        if not DB_AVAILABLE:
            return

        try:
            # Calculate costs; unknown models fall back to flash pricing
            pricing = GEMINI_PRICING.get(self.model_name, {'input': 0.075, 'output': 0.30})
            input_cost = (input_tokens / 1_000_000) * pricing['input']
            output_cost = (output_tokens / 1_000_000) * pricing['output']
            total_cost = input_cost + output_cost

            # Cost in cents for AIUsageLog (more precise)
            cost_cents = total_cost * 100

            # Create prompt hash (for debugging, not storing full prompt for privacy)
            prompt_hash = hashlib.sha256(prompt.encode()).hexdigest()

            # Save to database
            db = SessionLocal()
            try:
                # Log to legacy AIAPICostLog table
                legacy_log = AIAPICostLog(
                    timestamp=datetime.now(),
                    api_provider='gemini',
                    model_name=self.model_name,
                    feature=feature,
                    user_id=user_id,
                    input_tokens=input_tokens,
                    output_tokens=output_tokens,
                    total_tokens=input_tokens + output_tokens,
                    input_cost=input_cost,
                    output_cost=output_cost,
                    total_cost=total_cost,
                    success=success,
                    error_message=error_message,
                    latency_ms=latency_ms,
                    prompt_hash=prompt_hash
                )
                db.add(legacy_log)

                # Log to new AIUsageLog table (with automatic daily aggregation via trigger)
                usage_log = AIUsageLog(
                    request_type=feature,
                    model=self.model_name,
                    tokens_input=input_tokens,
                    tokens_output=output_tokens,
                    cost_cents=cost_cents,
                    user_id=user_id,
                    company_id=company_id,
                    related_entity_type=related_entity_type,
                    related_entity_id=related_entity_id,
                    prompt_length=len(prompt),
                    response_length=len(response_text),
                    response_time_ms=latency_ms,
                    success=success,
                    error_message=error_message
                )
                db.add(usage_log)

                db.commit()

                logger.info(
                    f"API cost logged: {feature} - ${total_cost:.6f} "
                    f"({input_tokens}+{output_tokens} tokens, {latency_ms}ms)"
                )
            finally:
                db.close()

        except Exception as e:
            # Deliberately swallowed - cost logging must never break the caller
            logger.error(f"Failed to log API cost: {e}")

    def generate_embedding(
        self,
        text: str,
        task_type: str = 'retrieval_document',
        title: Optional[str] = None,
        user_id: Optional[int] = None,
        feature: str = 'embedding'
    ) -> Optional[List[float]]:
        """
        Generate embedding vector for text using Google's text-embedding model.

        Args:
            text: Text to embed
            task_type: One of:
                - 'retrieval_document': For documents to be retrieved
                - 'retrieval_query': For search queries
                - 'semantic_similarity': For comparing texts
                - 'classification': For text classification
                - 'clustering': For text clustering
            title: Optional title for document (improves quality; only used
                with task_type='retrieval_document' — the API rejects it
                for other task types)
            user_id: User ID for cost tracking
            feature: Feature name for cost tracking

        Returns:
            768-dimensional embedding vector or None on error

        Cost: ~$0.00001 per 1K tokens (very cheap)
        """
        if not text or not text.strip():
            logger.warning("Empty text provided for embedding")
            return None

        start_time = time.time()

        try:
            # Use text-embedding-004 model (768 dimensions) -
            # Google's recommended model for embeddings.
            # BUG FIX: only forward 'title' for retrieval_document tasks;
            # the embed_content API rejects it for any other task type.
            embed_kwargs: Dict[str, Any] = {
                'model': 'models/text-embedding-004',
                'content': text,
                'task_type': task_type,
            }
            if title and task_type == 'retrieval_document':
                embed_kwargs['title'] = title

            result = genai.embed_content(**embed_kwargs)

            embedding = result.get('embedding')

            if not embedding:
                logger.error("No embedding returned from API")
                return None

            # Log cost (embedding API is very cheap)
            latency_ms = int((time.time() - start_time) * 1000)
            token_count = len(text) // 4  # Approximate: ~4 chars per token

            # Embedding pricing: ~$0.00001 per 1K tokens
            cost_usd = (token_count / 1000) * 0.00001

            logger.debug(
                f"Embedding generated: {len(embedding)} dims, "
                f"{token_count} tokens, {latency_ms}ms, ${cost_usd:.8f}"
            )

            # Log to database (only when a user is attributable)
            if DB_AVAILABLE and user_id:
                try:
                    db = SessionLocal()
                    try:
                        usage_log = AIUsageLog(
                            request_type=feature,
                            model='text-embedding-004',
                            tokens_input=token_count,
                            tokens_output=0,
                            cost_cents=cost_usd * 100,
                            user_id=user_id,
                            prompt_length=len(text),
                            response_length=len(embedding) * 4,  # 4 bytes per float
                            response_time_ms=latency_ms,
                            success=True
                        )
                        db.add(usage_log)
                        db.commit()
                    finally:
                        db.close()
                except Exception as e:
                    logger.error(f"Failed to log embedding cost: {e}")

            return embedding

        except Exception as e:
            # Embeddings are best-effort: signal failure with None, don't raise
            logger.error(f"Embedding generation error: {e}")
            return None

    def generate_embeddings_batch(
        self,
        texts: List[str],
        task_type: str = 'retrieval_document',
        user_id: Optional[int] = None
    ) -> List[Optional[List[float]]]:
        """
        Generate embeddings for multiple texts.

        Each text is embedded individually; failures produce None entries
        rather than aborting the whole batch.

        Args:
            texts: List of texts to embed
            task_type: Task type for all embeddings
            user_id: User ID for cost tracking

        Returns:
            List of embedding vectors (None for failed items), same order
            and length as `texts`
        """
        return [
            self.generate_embedding(
                text=text,
                task_type=task_type,
                user_id=user_id,
                feature='embedding_batch'
            )
            for text in texts
        ]
|
|
|
|
|
|
# Global service instance (initialized in app.py via init_gemini_service()).
# Stays None until initialization succeeds; read through get_gemini_service().
_gemini_service: Optional[GeminiService] = None
|
|
|
|
|
|
def init_gemini_service(api_key: Optional[str] = None, model: str = 'flash'):
    """
    Initialize global Gemini service instance.
    Call this in app.py during Flask app initialization.

    Best-effort: on failure the global is reset to None and the error is
    logged; callers detect this via get_gemini_service() returning None.

    Args:
        api_key: Google AI API key (optional if set in env)
        model: Model alias to use ('flash', 'flash-lite', 'pro', 'flash-2.0')
    """
    global _gemini_service
    try:
        _gemini_service = GeminiService(api_key=api_key, model=model)
        logger.info("Global Gemini service initialized successfully")
    except Exception as e:
        # Deliberately swallowed: the app can run with AI features disabled
        logger.error(f"Failed to initialize Gemini service: {e}")
        _gemini_service = None
|
|
|
|
|
|
def get_gemini_service() -> Optional[GeminiService]:
    """Return the module-wide GeminiService.

    Returns:
        The instance created by init_gemini_service(), or None when
        initialization has not run (or failed).
    """
    return _gemini_service
|
|
|
|
|
|
def generate_text(prompt: str, **kwargs) -> Optional[str]:
    """Generate text via the global Gemini service.

    Convenience wrapper around GeminiService.generate_text().

    Args:
        prompt: Text prompt
        **kwargs: Forwarded to GeminiService.generate_text()

    Returns:
        Generated text, or None if the service is not initialized
    """
    service = get_gemini_service()
    if service is None:
        logger.warning("Gemini service not initialized")
        return None
    return service.generate_text(prompt, **kwargs)
|