diff --git a/app.py b/app.py index 2b4277b..b13d3a0 100644 --- a/app.py +++ b/app.py @@ -354,6 +354,51 @@ def inject_audit_access(): return dict(is_audit_owner=is_audit_owner()) +@app.context_processor +def inject_company_context(): + """Inject multi-company context into all templates.""" + if not current_user.is_authenticated or not current_user.company_id: + return {} + + from database import UserCompany + from helpers.company_context import get_active_company_id + + db = SessionLocal() + try: + user_companies = db.query(UserCompany).filter_by( + user_id=current_user.id + ).order_by(UserCompany.is_primary.desc(), UserCompany.created_at.asc()).all() + + # Eager-load company objects while session is open + for uc in user_companies: + _ = uc.company.name if uc.company else None + + active_cid = get_active_company_id() + + # Validate active_company_id is still valid for this user + valid_ids = {uc.company_id for uc in user_companies} + if active_cid not in valid_ids: + active_cid = current_user.company_id + session.pop('active_company_id', None) + + active_company = None + for uc in user_companies: + if uc.company_id == active_cid: + active_company = uc.company + break + + return { + 'user_companies': user_companies, + 'active_company_id': active_cid, + 'active_company': active_company, + 'has_multiple_companies': len(user_companies) > 1, + } + except Exception: + return {} + finally: + db.close() + + @app.context_processor def inject_notifications(): """Inject unread notifications count into all templates""" diff --git a/blueprints/auth/routes.py b/blueprints/auth/routes.py index d21c6b8..f3df4d2 100644 --- a/blueprints/auth/routes.py +++ b/blueprints/auth/routes.py @@ -379,6 +379,7 @@ def login(): # No 2FA - login directly login_user(user, remember=remember) + session['active_company_id'] = user.company_id user.last_login = datetime.now() user.login_count = (user.login_count or 0) + 1 _auto_link_person(db, user) @@ -477,6 +478,7 @@ def verify_2fa(): next_page = session.pop('2fa_next', None) login_user(user, remember=remember) + session['active_company_id'] = user.company_id session['2fa_verified'] = True user.last_login = datetime.now() user.login_count = (user.login_count or 0) + 1 diff --git a/blueprints/public/routes.py b/blueprints/public/routes.py index 3be48b1..377b8e1 100644 --- a/blueprints/public/routes.py +++ b/blueprints/public/routes.py @@ -2701,3 +2701,25 @@ def sitemap_xml(): xml_parts.append('') return Response('\n'.join(xml_parts), mimetype='application/xml') + + +@bp.route('/api/switch-company/', methods=['POST']) +@login_required +def switch_company(company_id): + """Switch the active company context for multi-company users.""" + db = SessionLocal() + try: + uc = db.query(UserCompany).filter_by( + user_id=current_user.id, + company_id=company_id + ).first() + + if not uc: + flash('Nie masz uprawnień do tej firmy.', 'error') + else: + session['active_company_id'] = company_id + flash(f'Przełączono na firmę: {uc.company.name}', 'info') + finally: + db.close() + + return redirect(request.referrer or url_for('dashboard')) diff --git a/helpers/__init__.py b/helpers/__init__.py new file mode 100644 index 0000000..5cfb789 --- /dev/null +++ b/helpers/__init__.py @@ -0,0 +1 @@ +# helpers package diff --git a/helpers/company_context.py b/helpers/company_context.py new file mode 100644 index 0000000..363e5cf --- /dev/null +++ b/helpers/company_context.py @@ -0,0 +1,14 @@ +"""Company context helpers for multi-company users.""" + +from flask import session +from flask_login import current_user + + +def get_active_company_id(): + """Return the active company ID from session, falling back to users.company_id.""" + if not current_user.is_authenticated: + return None + active_id = session.get('active_company_id') + if active_id: + return active_id + return current_user.company_id