# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""[Experimental] Auth Tokens API client."""

import json
import logging
from typing import Any, Dict, List, Optional
from urllib.parse import urlencode
from . import _api_module
from . import _common
from . import _tokens_converters as tokens_converters
from . import types

# Module-level logger, registered under the 'google_genai.tokens' name.
logger = logging.getLogger('google_genai.tokens')


def _get_field_masks(setup: _common.StringDict) -> str:
  """Return field_masks"""
  fields = []
  for k, v in setup.items():
    # 2nd layer, recursively get field masks see TODO(b/418290100)
    if isinstance(v, dict) and v:
      field = [f'{k}.{kk}' for kk in v.keys()]
    else:
      field = [k]  # 1st layer
    fields.extend(field)

  return ','.join(fields)


def _convert_bidi_setup_to_token_setup(
    request_dict: _common.StringDict,
    config: Optional[types.CreateAuthTokenConfigOrDict] = None,
) -> _common.StringDict:
  """Converts bidiGenerateContentSetup."""
  bidi_setup = request_dict.get('bidiGenerateContentSetup')
  if bidi_setup and bidi_setup.get('setup'):
    # Handling mismatch between AuthToken service and
    # BidiGenerateContent service
    request_dict['bidiGenerateContentSetup'] = bidi_setup.get('setup')

    # Convert non null bidiGenerateContentSetup to field_mask
    field_mask = _get_field_masks(request_dict['bidiGenerateContentSetup'])

    if (
        isinstance(config, dict)
        and config.get('lock_additional_fields') is not None
        and not config.get('lock_additional_fields')
    ) or (
        isinstance(config, types.CreateAuthTokenConfig)
        and config.lock_additional_fields is not None
        and not config.lock_additional_fields  # pylint: disable=literal-comparison
    ):
      # Empty list, lock non null fields
      request_dict['fieldMask'] = field_mask
    elif (
        isinstance(config, dict)
        and config.get('lock_additional_fields') is None
    ) or (
        isinstance(config, types.CreateAuthTokenConfig)
        and config.lock_additional_fields is None
    ):
      # None. Global lock. unset fieldMask
      request_dict.pop('fieldMask', None)
    elif request_dict['fieldMask']:
      # Lock non null + additional fields
      additional_fields_list: Optional[List[str]] = request_dict.get(
          'fieldMask'
      )
      generation_config_list = types.GenerationConfig().model_dump().keys()
      if additional_fields_list:
        field_mask_list = []
        for field in additional_fields_list:
          if field in generation_config_list:
            field = f'generationConfig.{field}'
          field_mask_list.append(field)
      else:
        field_mask_list = []
      request_dict['fieldMask'] = (
          field_mask + ',' + ','.join(field_mask_list)
          if field_mask_list
          else field_mask
      )
    else:
      # Lock all fields
      request_dict.pop('fieldMask', None)
  else:
    field_mask = request_dict.get('fieldMask', [])
    field_mask_str = ','.join(field_mask)
    if field_mask:
      request_dict['fieldMask'] = field_mask_str
    else:
      request_dict.pop('fieldMask', None)
  if not request_dict.get('bidiGenerateContentSetup'):
    request_dict.pop('bidiGenerateContentSetup', None)

  return request_dict


class Tokens(_api_module.BaseModule):
  """[Experimental] Synchronous client for the Auth Tokens API.

  Exposes `create`, which requests a new auth token.
  """

  @_common.experimental_warning(
      "The SDK's token creation implementation is experimental, "
      'and may change in future versions.',
  )
  def create(
      self, *, config: Optional[types.CreateAuthTokenConfigOrDict] = None
  ) -> types.AuthToken:
    """[Experimental] Creates an auth token.

    The `live_constrained_parameters` attribute of CreateAuthTokenConfig can
    be used to lock the parameters of the live session so that they can't be
    changed client side. There are two basic modes, depending on whether
    `lock_additional_fields` is set:

    * If `lock_additional_fields` is not passed, the entire
      `live_constrained_parameters` is locked and can't be changed by the
      token's user.
    * If `lock_additional_fields` is set, the non-null fields of
      `live_constrained_parameters` are locked, together with any additional
      fields named in `lock_additional_fields`.

    Args:
      config (CreateAuthTokenConfig): Optional configuration for the request.

    Returns:
      The created `types.AuthToken`.

    Raises:
      ValueError: If the client is a Vertex AI client.

    Usage:

    .. code-block:: python

      # Case 1: If LiveEphemeralParameters is unset, unlock LiveConnectConfig
      # when using the token in Live API sessions. Each session connection can
      # use a different configuration.

      config = types.CreateAuthTokenConfig(
          uses=10,
          expire_time='2025-05-01T00:00:00Z',
      )
      auth_token = client.tokens.create(config=config)

    .. code-block:: python

      # Case 2: If LiveEphemeralParameters is set, lock all fields in
      # LiveConnectConfig when using the token in Live API sessions. For
      # example, changing `output_audio_transcription` in the Live API
      # connection will be ignored by the API.

      auth_token = client.tokens.create(
          config=types.CreateAuthTokenConfig(
              uses=10,
              live_constrained_parameters=types.LiveEphemeralParameters(
                  model='gemini-live-2.5-flash-preview',
                  config=types.LiveConnectConfig(
                      system_instruction='You are an LLM called Gemini.'
                  ),
              ),
          )
      )

    .. code-block:: python

      # Case 3: If LiveEphemeralParameters is set and lockAdditionalFields is
      # empty, lock LiveConnectConfig with set fields (e.g.
      # system_instruction in this example) when using the token in Live API
      # sessions.
      auth_token = client.tokens.create(
          config=types.CreateAuthTokenConfig(
              uses=10,
              live_constrained_parameters=types.LiveEphemeralParameters(
                  config=types.LiveConnectConfig(
                      system_instruction='You are an LLM called Gemini.'
                  ),
              ),
              lock_additional_fields=[],
          )
      )

    .. code-block:: python

      # Case 4: If LiveEphemeralParameters is set and lockAdditionalFields is
      # set, lock LiveConnectConfig with set and additional fields (e.g.
      # system_instruction, temperature in this example) when using the token
      # in Live API sessions.
      auth_token = client.tokens.create(
          config=types.CreateAuthTokenConfig(
              uses=10,
              live_constrained_parameters=types.LiveEphemeralParameters(
                  model='gemini-live-2.5-flash-preview',
                  config=types.LiveConnectConfig(
                      system_instruction='You are an LLM called Gemini.'
                  ),
              ),
              lock_additional_fields=['temperature'],
          )
      )
    """
    # The parameter model is built before the backend check, so config
    # validation runs first.
    parameter_model = types.CreateAuthTokenParameters(config=config)

    if self._api_client.vertexai:
      raise ValueError(
          'This method is only supported in the Gemini Developer client.'
      )

    request_dict = tokens_converters._CreateAuthTokenParameters_to_mldev(
        self._api_client, parameter_model
    )

    url_params: Optional[dict[str, str]] = request_dict.get('_url')
    path = (
        'auth_tokens'.format_map(url_params) if url_params else 'auth_tokens'
    )

    query = request_dict.get('_query')
    if query:
      path = f'{path}?{urlencode(query)}'

    # TODO: remove the hack that pops config.
    request_dict.pop('config', None)

    # Token creation request data need to replace 'setup' with
    # 'bidiGenerateContentSetup'
    if request_dict:
      request_dict = _convert_bidi_setup_to_token_setup(request_dict, config)

    # Per-request HTTP options come from the validated config, if any.
    model_config = parameter_model.config
    http_options: Optional[types.HttpOptions] = (
        None if model_config is None else model_config.http_options
    )

    payload = _common.encode_unserializable_types(
        _common.convert_to_dict(request_dict)
    )

    response = self._api_client.request('post', path, payload, http_options)
    # An empty body is represented as '' rather than being parsed as JSON.
    response_dict = json.loads(response.body) if response.body else ''

    if not self._api_client.vertexai:
      response_dict = tokens_converters._AuthToken_from_mldev(response_dict)

    auth_token = types.AuthToken._from_response(
        response=response_dict, kwargs=parameter_model.model_dump()
    )
    self._api_client._verify_response(auth_token)
    return auth_token


class AsyncTokens(_api_module.BaseModule):
  """[Experimental] Asynchronous client for the Auth Tokens API.

  Exposes an awaitable `create`, which requests a new auth token.
  """

  @_common.experimental_warning(
      "The SDK's token creation implementation is experimental, "
      'and may change in future versions.',
  )
  async def create(
      self, *, config: Optional[types.CreateAuthTokenConfigOrDict] = None
  ) -> types.AuthToken:
    """Creates an auth token asynchronously. Support in v1alpha only.

    Args:
      config (CreateAuthTokenConfig): Optional configuration for the request.

    Returns:
      The created `types.AuthToken`.

    Raises:
      ValueError: If the client is a Vertex AI client.

    Usage:

    .. code-block:: python

      client = genai.Client(
          api_key=API_KEY,
          http_options=types.HttpOptions(api_version='v1alpha'),
      )

      auth_token = await client.aio.tokens.create(
          config=types.CreateAuthTokenConfig(
              uses=10,
              live_constrained_parameters=types.LiveEphemeralParameters(
                  model='gemini-live-2.5-flash-preview',
                  config=types.LiveConnectConfig(
                      system_instruction='You are an LLM called Gemini.'
                  ),
              ),
          )
      )
    """
    # The parameter model is built before the backend check, so config
    # validation runs first.
    parameter_model = types.CreateAuthTokenParameters(config=config)

    if self._api_client.vertexai:
      raise ValueError(
          'This method is only supported in the Gemini Developer client.'
      )

    request_dict = tokens_converters._CreateAuthTokenParameters_to_mldev(
        self._api_client, parameter_model
    )

    url_params: Optional[dict[str, str]] = request_dict.get('_url')
    path = (
        'auth_tokens'.format_map(url_params) if url_params else 'auth_tokens'
    )

    query = request_dict.get('_query')
    if query:
      path = f'{path}?{urlencode(query)}'

    # TODO: remove the hack that pops config.
    request_dict.pop('config', None)

    # Token creation request data need to replace 'setup' with
    # 'bidiGenerateContentSetup'
    request_dict = _convert_bidi_setup_to_token_setup(request_dict, config)

    # Per-request HTTP options come from the validated config, if any.
    model_config = parameter_model.config
    http_options: Optional[types.HttpOptions] = (
        None if model_config is None else model_config.http_options
    )

    payload = _common.encode_unserializable_types(
        _common.convert_to_dict(request_dict)
    )

    response = await self._api_client.async_request(
        'post',
        path,
        payload,
        http_options=http_options,
    )
    # An empty body is represented as '' rather than being parsed as JSON.
    response_dict = json.loads(response.body) if response.body else ''

    if not self._api_client.vertexai:
      response_dict = tokens_converters._AuthToken_from_mldev(response_dict)

    auth_token = types.AuthToken._from_response(
        response=response_dict, kwargs=parameter_model.model_dump()
    )
    self._api_client._verify_response(auth_token)
    return auth_token
