# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Extra utils depending on types that are shared between sync and async modules."""

import inspect
import logging
import sys
import typing
from typing import Any, Callable, Dict, Optional, Union, get_args, get_origin

import pydantic

from . import _common
from . import _mcp_utils
from . import _transformers as t
from . import errors
from . import types
from ._adapters import McpToGenAiToolAdapter


# `types.UnionType` is only importable on Python 3.10+. On older interpreters
# fall back to a private typing alias so that the name `UnionType` is always
# defined for the union-annotation check in `convert_if_exist_pydantic_model`.
if sys.version_info >= (3, 10):
  from types import UnionType
else:
  UnionType = typing._UnionGenericAlias  # type: ignore[attr-defined]

# The `mcp` package is an optional dependency. Static type checkers always see
# the real MCP types; at runtime the names are set to `None` when `mcp` cannot
# be imported, so runtime code must check for `None` before using them.
if typing.TYPE_CHECKING:
  from mcp import ClientSession as McpClientSession
  from mcp.types import Tool as McpTool
else:
  # Placeholder declarations; both names are rebound immediately below, either
  # by the real import or by the `None` fallback.
  McpClientSession: typing.Type = Any
  McpTool: typing.Type = Any
  try:
    from mcp import ClientSession as McpClientSession
    from mcp.types import Tool as McpTool
  except ImportError:
    McpClientSession = None
    McpTool = None

# Value returned by `get_max_remote_calls_afc` when the config does not set
# `automatic_function_calling.maximum_remote_calls`.
_DEFAULT_MAX_REMOTE_CALLS_AFC = 10

# Note: the logger is registered under the 'google_genai.models' name rather
# than under this module's own name.
logger = logging.getLogger('google_genai.models')


def _create_generate_content_config_model(
    config: types.GenerateContentConfigOrDict,
) -> types.GenerateContentConfig:
  """Normalizes a config given as a dict or as a model into a model.

  Args:
    config: Either a plain dict of `GenerateContentConfig` fields or an object
      that is already a config model.

  Returns:
    A `GenerateContentConfig` built from the dict's items, or `config` itself
    (the same object, not a copy) when it is not a dict.
  """
  if not isinstance(config, dict):
    return config
  return types.GenerateContentConfig(**config)


def _get_gcs_uri(
    src: Union[str, types.BatchJobSourceOrDict]
) -> Optional[str]:
  """Returns the first GCS URI found in a batch job source, or None.

  Args:
    src: A URI string, a dict with an optional 'gcs_uri' entry, or a
      `BatchJobSource` with an optional `gcs_uri` attribute.

  Returns:
    `src` itself when it is a string starting with 'gs://'; the first element
    of the non-empty `gcs_uri` sequence for dict / `BatchJobSource` inputs;
    None in every other case.
  """
  if isinstance(src, str):
    return src if src.startswith('gs://') else None
  if isinstance(src, dict):
    dict_uris = src.get('gcs_uri')
    return dict_uris[0] if dict_uris else None
  if isinstance(src, types.BatchJobSource):
    model_uris = src.gcs_uri
    return model_uris[0] if model_uris else None
  return None


def _get_bigquery_uri(
    src: Union[str, types.BatchJobSourceOrDict]
) -> Optional[str]:
  """Returns the BigQuery URI found in a batch job source, or None.

  Args:
    src: A URI string, a dict with an optional 'bigquery_uri' entry, or a
      `BatchJobSource` with an optional `bigquery_uri` attribute.

  Returns:
    `src` itself when it is a string starting with 'bq://'; the truthy
    `bigquery_uri` value for dict / `BatchJobSource` inputs; None otherwise.
  """
  if isinstance(src, str):
    return src if src.startswith('bq://') else None
  if isinstance(src, dict):
    dict_uri = src.get('bigquery_uri')
    return dict_uri if dict_uri else None
  if isinstance(src, types.BatchJobSource):
    model_uri = src.bigquery_uri
    return model_uri if model_uri else None
  return None


def format_destination(
    src: Union[str, types.BatchJobSourceOrDict],
    config: Optional[types.CreateBatchJobConfigOrDict] = None,
) -> types.CreateBatchJobConfig:
  """Fills in a default display name and destination for a batch job config.

  Args:
    src: The batch job source: a URI string, a dict, or a `BatchJobSource`.
    config: Optional batch job config (model or dict). A fresh
      `CreateBatchJobConfig` is used when it is missing or empty.

  Returns:
    The config with `display_name` set to a generated
    'genai_batch_job_<unique name>' when it was empty, and `dest` derived from
    the source when it was empty: '<uri without .jsonl>/dest' for a GCS
    '.jsonl' source, or '<bigquery uri>_dest_<unique name>' for a BigQuery
    source.

  Raises:
    ValueError: If `dest` is empty and the source is neither a GCS URI ending
      in '.jsonl' nor a BigQuery URI.
  """
  config = (
      types._CreateBatchJobParameters(config=config).config
      or types.CreateBatchJobConfig()
  )

  name_suffix = None
  if not config.display_name:
    name_suffix = _common.timestamped_unique_name()
    config.display_name = f'genai_batch_job_{name_suffix}'

  # An explicit destination always wins; nothing to derive.
  if config.dest:
    return config

  gcs_uri = _get_gcs_uri(src)
  bigquery_uri = _get_bigquery_uri(src)

  if gcs_uri and gcs_uri.endswith('.jsonl'):
    # Drop the 6-character '.jsonl' suffix before appending '/dest'.
    config.dest = f'{gcs_uri[:-6]}/dest'
    return config
  if bigquery_uri:
    # Reuse the suffix generated for the display name when there is one.
    name_suffix = name_suffix or _common.timestamped_unique_name()
    config.dest = f'{bigquery_uri}_dest_{name_suffix}'
    return config
  raise ValueError(f'The source {src} is not supported.')


def get_function_map(
    config: Optional[types.GenerateContentConfigOrDict] = None,
    mcp_to_genai_tool_adapters: Optional[
        dict[str, McpToGenAiToolAdapter]
    ] = None,
    is_caller_method_async: bool = False,
) -> dict[str, Union[Callable[..., Any], McpToGenAiToolAdapter]]:
  """Builds the name -> callable / MCP adapter lookup used for function calls.

  Args:
    config: Generate-content config (model or dict) whose `tools` may contain
      Python callables. When it is missing or empty an empty map is returned,
      even if MCP adapters were supplied.
    mcp_to_genai_tool_adapters: Optional map of tool name to MCP adapter to
      merge into the result.
    is_caller_method_async: Whether the caller is an async method. Coroutine
      functions and MCP adapters are only accepted when this is True.

  Returns:
    A dict keyed by `__name__` for every callable tool, plus the MCP adapter
    entries.

  Raises:
    errors.UnsupportedFunctionError: If a coroutine function or any MCP
      adapter is supplied while `is_caller_method_async` is False.
    ValueError: If an MCP tool name already maps to a truthy entry.
  """
  function_map: dict[str, Union[Callable[..., Any], McpToGenAiToolAdapter]] = {}
  if not config:
    return function_map

  tools = _create_generate_content_config_model(config).tools
  for tool in tools or ():
    if not callable(tool):
      continue
    if inspect.iscoroutinefunction(tool) and not is_caller_method_async:
      raise errors.UnsupportedFunctionError(
          f'Function {tool.__name__} is a coroutine function, which is not'
          ' supported for automatic function calling. Please manually'
          f' invoke {tool.__name__} to get the function response.'
      )
    function_map[tool.__name__] = tool

  if not mcp_to_genai_tool_adapters:
    return function_map

  if not is_caller_method_async:
    raise errors.UnsupportedFunctionError(
        'MCP tools are not supported in synchronous methods.'
    )
  # Validate every name first so the map is only updated when there is no
  # clash with a callable tool.
  for tool_name in mcp_to_genai_tool_adapters:
    if function_map.get(tool_name):
      raise ValueError(
          f'Tool {tool_name} is already defined for the request.'
      )
  function_map.update(mcp_to_genai_tool_adapters)
  return function_map


def convert_number_values_for_dict_function_call_args(
    args: _common.StringDict,
) -> _common.StringDict:
  """Returns a copy of `args` with whole-number floats turned into ints.

  Each value is passed through `convert_number_values_for_function_call_args`,
  so nested dicts and lists are converted as well. Keys are kept as they are.
  """
  converted: _common.StringDict = {}
  for name, value in args.items():
    converted[name] = convert_number_values_for_function_call_args(value)
  return converted


def convert_number_values_for_function_call_args(
    args: Union[dict[str, object], list[object], object],
) -> Union[dict[str, object], list[object], object]:
  """Recursively replaces floats that hold a whole number with ints.

  Args:
    args: A single value, or an arbitrarily nested structure of dicts and
      lists.

  Returns:
    A new dict / list mirroring the input with every whole-number float
    (e.g. `2.0`) converted to `int`. Values of any other type, including
    containers other than dict and list, are returned unchanged.
  """
  recurse = convert_number_values_for_function_call_args
  if isinstance(args, dict):
    return {name: recurse(item) for name, item in args.items()}
  if isinstance(args, list):
    return list(map(recurse, args))
  if isinstance(args, float) and args.is_integer():
    return int(args)
  return args


def is_annotation_pydantic_model(annotation: Any) -> bool:
  """Returns True when `annotation` is a class derived from pydantic.BaseModel.

  Anything that `inspect.isclass` does not report as a class yields False.
  """
  if not inspect.isclass(annotation):
    return False
  try:
    return issubclass(annotation, pydantic.BaseModel)
  except TypeError:
    # On Python 3.10 and below, inspect.isclass() is True for parameterized
    # generics such as dict[str, int] (it is False on 3.11 and above), and
    # issubclass() then rejects them with a TypeError.
    return False


def convert_if_exist_pydantic_model(
    value: Any, annotation: Any, param_name: str, func_name: str
) -> Any:
  """Converts a function call argument so it matches a parameter annotation.

  A dict is instantiated into a pydantic model when the annotation is a
  pydantic model class. `list[...]` and `dict[..., ...]` annotations are
  converted element by element. Union annotations are tried member by member.
  Any other value is returned unchanged after an `isinstance` check against
  the annotation (an int is additionally accepted for a `float` annotation).

  Args:
    value: The argument value taken from the function call.
    annotation: The annotation of the target parameter.
    param_name: Parameter name, used only in error messages.
    func_name: Function name, used only in error messages.

  Returns:
    The converted value.

  Raises:
    errors.UnknownFunctionCallArgumentError: If the value cannot be converted
      to, or is not an instance of, the annotation.
  """
  if isinstance(value, dict) and is_annotation_pydantic_model(annotation):
    try:
      return annotation(**value)
    except pydantic.ValidationError as e:
      raise errors.UnknownFunctionCallArgumentError(
          f'Failed to parse parameter {param_name} for function'
          f' {func_name} from function call part because function call argument'
          f' value {value} is not compatible with parameter annotation'
          f' {annotation}, due to error {e}'
      ) from e
  if isinstance(value, list) and get_origin(annotation) == list:
    item_type = get_args(annotation)[0]
    return [
        convert_if_exist_pydantic_model(item, item_type, param_name, func_name)
        for item in value
    ]
  if isinstance(value, dict) and get_origin(annotation) == dict:
    _, value_type = get_args(annotation)
    return {
        k: convert_if_exist_pydantic_model(v, value_type, param_name, func_name)
        for k, v in value.items()
    }
  # example 1: typing.Union[int, float]
  # example 2: int | float equivalent to UnionType[int, float]
  if get_origin(annotation) in (Union, UnionType):
    last_error: Optional[Exception] = None
    for arg in get_args(annotation):
      # NOTE(review): isinstance() raises TypeError when `arg` is a
      # parameterized generic other than list (e.g. dict[str, int]); that case
      # is not handled here - confirm whether such unions need support.
      if (
          (get_args(arg) and get_origin(arg) is list)
          or isinstance(value, arg)
          or (isinstance(value, dict) and is_annotation_pydantic_model(arg))
      ):
        try:
          return convert_if_exist_pydantic_model(
              value, arg, param_name, func_name
          )
        # Do not raise here because another member of the union may still
        # match. The recursive call reports conversion failures as
        # UnknownFunctionCallArgumentError, so that type must be caught too;
        # catching only pydantic.ValidationError let the first failing member
        # abort the whole union.
        except (
            pydantic.ValidationError,
            errors.UnknownFunctionCallArgumentError,
        ) as e:
          last_error = e
    # No member matched. Prefer the most recent member-specific error because
    # it carries the detailed reason; otherwise raise the generic one.
    if isinstance(last_error, errors.UnknownFunctionCallArgumentError):
      raise last_error
    raise errors.UnknownFunctionCallArgumentError(
        f'Failed to parse parameter {param_name} for function'
        f' {func_name} from function call part because function call argument'
        f' value {value} cannot be converted to parameter annotation'
        f' {annotation}.'
    ) from last_error
  # the only exception for value and annotation type to be different is int and
  # float. see convert_number_values_for_function_call_args function for context
  if isinstance(value, int) and annotation is float:
    return value
  if not isinstance(value, annotation):
    raise errors.UnknownFunctionCallArgumentError(
        f'Failed to parse parameter {param_name} for function {func_name} from'
        f' function call part because function call argument value {value} is'
        f' not compatible with parameter annotation {annotation}.'
    )
  return value


def convert_argument_from_function(
    args: _common.StringDict, function: Callable[..., Any]
) -> _common.StringDict:
  """Converts call arguments to the types annotated on `function`.

  Args:
    args: Argument values keyed by parameter name.
    function: The callable whose signature supplies the annotations.

  Returns:
    A dict, in signature order, holding the converted value of every
    parameter of `function` that is present in `args`. Entries of `args` that
    do not name a parameter are left out.
  """
  parameters = inspect.signature(function).parameters
  func_name = function.__name__
  return {
      name: convert_if_exist_pydantic_model(
          args[name], parameter.annotation, name, func_name
      )
      for name, parameter in parameters.items()
      if name in args
  }


def invoke_function_from_dict_args(
    args: _common.StringDict, function_to_invoke: Callable[..., Any]
) -> Any:
  """Converts `args` to the function's annotated types and calls the function.

  Args:
    args: Argument values keyed by parameter name.
    function_to_invoke: The callable to invoke with the converted arguments.

  Returns:
    Whatever `function_to_invoke` returns.

  Raises:
    errors.FunctionInvocationError: If the call itself raises; the original
      exception is attached as `__cause__`. Errors raised while converting
      the arguments happen before the call and propagate unwrapped.
  """
  converted_args = convert_argument_from_function(args, function_to_invoke)
  try:
    return function_to_invoke(**converted_args)
  except Exception as e:
    raise errors.FunctionInvocationError(
        f'Failed to invoke function {function_to_invoke.__name__} with'
        f' converted arguments {converted_args} from model returned function'
        f' call argument {args} because of error {e}'
    ) from e


async def invoke_function_from_dict_args_async(
    args: _common.StringDict, function_to_invoke: Callable[..., Any]
) -> Any:
  """Converts `args` to the function's annotated types and awaits the call.

  Args:
    args: Argument values keyed by parameter name.
    function_to_invoke: The callable to invoke; its result is awaited.

  Returns:
    The awaited result of `function_to_invoke`.

  Raises:
    errors.FunctionInvocationError: If calling or awaiting the function
      raises; the original exception is attached as `__cause__`. Errors raised
      while converting the arguments happen before the call and propagate
      unwrapped.
  """
  converted_args = convert_argument_from_function(args, function_to_invoke)
  try:
    return await function_to_invoke(**converted_args)
  except Exception as e:
    raise errors.FunctionInvocationError(
        f'Failed to invoke function {function_to_invoke.__name__} with'
        f' converted arguments {converted_args} from model returned function'
        f' call argument {args} because of error {e}'
    ) from e


def get_function_response_parts(
    response: types.GenerateContentResponse,
    function_map: dict[str, Union[Callable[..., Any], McpToGenAiToolAdapter]],
) -> list[types.Part]:
  """Invokes the function calls of the first candidate and collects responses.

  Only parts of the first candidate that carry a function call with both a
  name and args are processed.

  Args:
    response: The model response to inspect.
    function_map: Lookup from function name to the callable (or MCP adapter)
      registered for it.

  Returns:
    One function-response part per processed function call, in order. The
    response payload is `{'result': ...}` on success and `{'error': ...}` when
    the invocation fails. The list is empty when the response has no
    candidates or the first candidate has no content parts.

  Raises:
    KeyError: If a function call names a function missing from `function_map`.
  """
  func_response_parts: list[types.Part] = []
  # Truthiness (rather than an `is not None` test) also covers an empty
  # candidates list, which would otherwise fail on the `[0]` lookup.
  if not response.candidates:
    return func_response_parts
  content = response.candidates[0].content
  if not isinstance(content, types.Content) or content.parts is None:
    return func_response_parts

  for part in content.parts:
    function_call = part.function_call
    if not function_call:
      continue
    func_name = function_call.name
    if func_name is None or function_call.args is None:
      continue
    func = function_map[func_name]
    args = convert_number_values_for_dict_function_call_args(
        function_call.args
    )
    func_response: _common.StringDict
    if isinstance(func, McpToGenAiToolAdapter):
      # MCP adapters can only be awaited. Report that explicitly instead of
      # leaving `func_response` unassigned (or stale from an earlier part).
      func_response = {
          'error': 'MCP tools are not supported in synchronous methods.'
      }
    else:
      try:
        func_response = {
            'result': invoke_function_from_dict_args(args, func)
        }
      except Exception as e:  # pylint: disable=broad-except
        # Failures are returned as an error payload rather than raised.
        func_response = {'error': str(e)}
    func_response_parts.append(
        types.Part.from_function_response(
            name=func_name, response=func_response
        )
    )
  return func_response_parts


async def get_function_response_parts_async(
    response: types.GenerateContentResponse,
    function_map: dict[str, Union[Callable[..., Any], McpToGenAiToolAdapter]],
) -> list[types.Part]:
  """Invokes the function calls of the first candidate and collects responses.

  Async counterpart of `get_function_response_parts`: MCP adapters are called
  through `call_tool`, coroutine functions are awaited, and plain callables
  are invoked directly. Only parts of the first candidate that carry a
  function call with both a name and args are processed.

  Args:
    response: The model response to inspect.
    function_map: Lookup from function name to the callable or MCP adapter
      registered for it.

  Returns:
    One function-response part per processed function call, in order. The
    response payload is `{'result': ...}` on success and `{'error': ...}` when
    the invocation fails or the MCP tool response has `isError` set. The list
    is empty when the response has no candidates or the first candidate has no
    content parts.

  Raises:
    KeyError: If a function call names a function missing from `function_map`.
  """
  func_response_parts: list[types.Part] = []
  # Truthiness (rather than an `is not None` test) also covers an empty
  # candidates list, which would otherwise fail on the `[0]` lookup.
  if not response.candidates:
    return func_response_parts
  content = response.candidates[0].content
  if not isinstance(content, types.Content) or content.parts is None:
    return func_response_parts

  for part in content.parts:
    function_call = part.function_call
    if not function_call:
      continue
    func_name = function_call.name
    if func_name is None or function_call.args is None:
      continue
    func = function_map[func_name]
    args = convert_number_values_for_dict_function_call_args(
        function_call.args
    )
    func_response: _common.StringDict
    try:
      if isinstance(func, McpToGenAiToolAdapter):
        mcp_tool_response = await func.call_tool(
            types.FunctionCall(name=func_name, args=args)
        )
        if mcp_tool_response.isError:
          func_response = {'error': mcp_tool_response}
        else:
          func_response = {'result': mcp_tool_response}
      elif inspect.iscoroutinefunction(func):
        func_response = {
            'result': await invoke_function_from_dict_args_async(args, func)
        }
      else:
        func_response = {
            'result': invoke_function_from_dict_args(args, func)
        }
    except Exception as e:  # pylint: disable=broad-except
      # Failures are returned as an error payload rather than raised.
      func_response = {'error': str(e)}
    func_response_parts.append(
        types.Part.from_function_response(
            name=func_name, response=func_response
        )
    )
  return func_response_parts


def should_disable_afc(
    config: Optional[types.GenerateContentConfigOrDict] = None,
) -> bool:
  """Returns whether automatic function calling (AFC) should be disabled.

  Args:
    config: Generate-content config (model or dict), or None.

  Returns:
    True when `automatic_function_calling.maximum_remote_calls` is set to a
    value less than or equal to 0 (a warning is logged), otherwise the value
    of `automatic_function_calling.disable`. False when the config, its
    `automatic_function_calling` section, or the `disable` flag is not set.
  """
  if not config:
    return False
  afc_config = _create_generate_content_config_model(
      config
  ).automatic_function_calling
  # Default to enable AFC if not specified.
  if not afc_config:
    return False

  max_calls = afc_config.maximum_remote_calls
  # A non-positive call budget disables AFC, whatever `disable` says.
  if max_calls is not None and int(max_calls) <= 0:
    logger.warning(
        'max_remote_calls in automatic_function_calling_config'
        f' {max_calls} is'
        ' less than or equal to 0. Disabling automatic function calling.'
        ' Please set max_remote_calls to a positive integer.'
    )
    return True

  if afc_config.disable is None:
    return False

  # Warn about the contradictory combination of `disable=True` with an
  # explicitly provided positive call budget. The `model_fields_set` check
  # excludes the case where max_remote_calls is set to 10 by default.
  explicit_positive_budget = (
      max_calls is not None
      and 'maximum_remote_calls' in afc_config.model_fields_set
      and int(max_calls) > 0
  )
  if afc_config.disable and explicit_positive_budget:
    logger.warning(
        '`automatic_function_calling.disable` is set to `True`. And'
        ' `automatic_function_calling.maximum_remote_calls` is a'
        ' positive number'
        f' {max_calls}.'
        ' Disabling automatic function calling. If you want to enable'
        ' automatic function calling, please set'
        ' `automatic_function_calling.disable` to `False` or leave it unset,'
        ' and set `automatic_function_calling.maximum_remote_calls` to a'
        ' positive integer or leave'
        ' `automatic_function_calling.maximum_remote_calls` unset.'
    )

  return afc_config.disable


def get_max_remote_calls_afc(
    config: Optional[types.GenerateContentConfigOrDict] = None,
) -> int:
  """Returns the maximum number of remote calls for automatic function calling.

  Args:
    config: Generate-content config (model or dict), or None.

  Returns:
    `automatic_function_calling.maximum_remote_calls` as an int, or
    `_DEFAULT_MAX_REMOTE_CALLS_AFC` when the config, its
    `automatic_function_calling` section, or that field is not set.

  Raises:
    ValueError: If `should_disable_afc(config)` reports that automatic
      function calling is disabled.
  """
  # The docstring must be the first statement of the function; it previously
  # followed this check and was therefore a no-op string expression.
  if not config:
    return _DEFAULT_MAX_REMOTE_CALLS_AFC
  if should_disable_afc(config):
    raise ValueError(
        'automatic function calling is not enabled, but SDK is trying to get'
        ' max remote calls.'
    )
  config_model = _create_generate_content_config_model(config)
  if (
      not config_model.automatic_function_calling
      or config_model.automatic_function_calling.maximum_remote_calls is None
  ):
    return _DEFAULT_MAX_REMOTE_CALLS_AFC
  return int(config_model.automatic_function_calling.maximum_remote_calls)


def should_append_afc_history(
    config: Optional[types.GenerateContentConfigOrDict] = None,
) -> bool:
  """Returns whether the automatic function calling history should be kept.

  Args:
    config: Generate-content config (model or dict), or None.

  Returns:
    The negation of `automatic_function_calling.ignore_call_history`; True
    when the config or its `automatic_function_calling` section is not set.
  """
  if not config:
    return True
  afc_config = _create_generate_content_config_model(
      config
  ).automatic_function_calling
  if afc_config:
    return not afc_config.ignore_call_history
  return True


def parse_config_for_mcp_usage(
    config: Optional[types.GenerateContentConfigOrDict] = None,
) -> Optional[types.GenerateContentConfig]:
  """Returns a config copy that carries the MCP usage header when needed.

  Args:
    config: Generate-content config (model or dict), or None.

  Returns:
    None when `config` is missing or empty. Otherwise a copy of the config
    model. When `_mcp_utils.has_mcp_tool_usage` reports MCP usage for the
    tools, `_mcp_utils.set_mcp_usage_header` is applied to the copy's HTTP
    headers, creating `http_options` / `headers` first if they are unset.
  """
  if not config:
    return None
  source_model = _create_generate_content_config_model(config)
  # Copy the model with `tools` cleared (the original author notes that some
  # tools may not be pickleable), then attach the original tools again.
  parsed_model = source_model.model_copy(update={'tools': None})
  parsed_model.tools = source_model.tools

  uses_mcp = source_model.tools and _mcp_utils.has_mcp_tool_usage(
      source_model.tools
  )
  if not uses_mcp:
    return parsed_model

  # NOTE(review): `model_copy` is called without `deep=True`, so an existing
  # `http_options` object appears to be shared with the input config - confirm
  # that updating its headers in place is intended.
  if parsed_model.http_options is None:
    parsed_model.http_options = types.HttpOptions(headers={})
  if parsed_model.http_options.headers is None:
    parsed_model.http_options.headers = {}
  _mcp_utils.set_mcp_usage_header(parsed_model.http_options.headers)
  return parsed_model


async def parse_config_for_mcp_sessions(
    config: Optional[types.GenerateContentConfigOrDict] = None,
) -> tuple[
    Optional[types.GenerateContentConfig],
    dict[str, McpToGenAiToolAdapter],
]:
  """Replaces MCP sessions in the config's tools with GenAI tools.

  Args:
    config: Generate-content config (model or dict), or None.

  Returns:
    A `(parsed_config, adapters)` tuple. `parsed_config` is None when
    `config` is missing or empty; otherwise it is a copy in which every MCP
    client session has been replaced by the GenAI tools of its adapter while
    other tools are kept in place. `adapters` maps each named function
    declaration coming from an MCP session to its adapter, for use by
    automatic function calling.

  Raises:
    ValueError: If two MCP sessions declare a function with the same name.
  """
  adapters_by_name: dict[str, McpToGenAiToolAdapter] = {}
  parsed_config = parse_config_for_mcp_usage(config)
  if not parsed_config:
    return None, adapters_by_name

  # Start from a copy without tools; they are rebuilt below with MCP sessions
  # expanded into the GenAI tools they expose.
  result_config = parsed_config.model_copy(update={'tools': None})
  if not parsed_config.tools:
    return result_config, adapters_by_name

  result_config.tools = []
  for tool in parsed_config.tools:
    is_mcp_session = McpClientSession is not None and isinstance(
        tool, McpClientSession
    )
    if not is_mcp_session:
      result_config.tools.append(tool)
      continue

    adapter = McpToGenAiToolAdapter(tool, await tool.list_tools())
    # Extend the config with the MCP session tools converted to GenAI tools.
    result_config.tools.extend(adapter.tools)
    for genai_tool in adapter.tools:
      for declaration in genai_tool.function_declarations or ():
        declared_name = declaration.name
        if not declared_name:
          continue
        if adapters_by_name.get(declared_name):
          raise ValueError(
              f'Tool {declared_name} is already defined for'
              ' the request.'
          )
        adapters_by_name[declared_name] = adapter

  return result_config, adapters_by_name


def append_chunk_contents(
    contents: Union[types.ContentListUnion, types.ContentListUnionDict],
    chunk: types.GenerateContentResponse,
) -> Union[types.ContentListUnion, types.ContentListUnionDict]:
  """Appends the first candidate's content of `chunk` to the contents list.

  `contents` is normalized with `t.t_contents` before the append. That call
  presumably returns a new list rather than the caller's object (TODO
  confirm), in which case the append is not visible through the argument.
  The updated list is therefore returned; callers should use the return
  value. Earlier versions returned None, so callers that ignore the result
  behave as before.

  Args:
    contents: The conversation contents gathered so far.
    chunk: A streamed response chunk; may be None.

  Returns:
    The normalized contents with the chunk's content appended when the chunk
    has one, or `contents` unchanged when `chunk` is None or has no
    candidates.
  """
  # Truthiness also covers an empty candidates list, which would otherwise
  # fail on the `[0]` lookup.
  if chunk is None or not chunk.candidates:
    return contents
  chunk_content = chunk.candidates[0].content
  contents = t.t_contents(contents)  # type: ignore[assignment]
  if isinstance(contents, list) and chunk_content is not None:
    contents.append(chunk_content)  # type: ignore[arg-type]
  return contents
