# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""[Experimental] Text Only Local Tokenizer."""

import logging
from typing import Any, Iterable
from typing import Optional, Union

from sentencepiece import sentencepiece_model_pb2

from . import _common
from . import _local_tokenizer_loader as loader
from . import _transformers as t
from . import types

logger = logging.getLogger("google_genai.local_tokenizer")


class _TextsAccumulator:
  """Accumulates countable texts from `Content` and `Tool` objects.

  This class is responsible for traversing complex `Content` and `Tool`
  objects and extracting all the text content that should be included when
  calculating token counts.

  A key feature of this class is its ability to detect unsupported fields in
  `Content` objects. If a user provides a `Content` object with fields that
  this local tokenizer doesn't recognize (e.g., new fields added in a future
  API update), this class will log a warning.

  The detection mechanism for `Content` objects works by recursively building
  a "counted" version of the input object. This "counted" object only
  contains the data that was successfully processed and added to the text
  list for tokenization. After traversing the input, the original `Content`
  object is compared to the "counted" object. If they don't match, it
  signifies the presence of unsupported fields, and a warning is logged.
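
  Example (a minimal, illustrative sketch of the accumulator in use):

    accumulator = _TextsAccumulator()
    accumulator.add_content(
        types.Content(role="user", parts=[types.Part(text="hello world")])
    )
    list(accumulator.get_texts())  # -> ["hello world"]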
  """

  def __init__(self) -> None:
    self._texts: list[str] = []

  def get_texts(self) -> Iterable[str]:
    return self._texts

  def add_contents(self, contents: Iterable[types.Content]) -> None:
    for content in contents:
      self.add_content(content)

  def add_content(self, content: types.Content) -> None:
    """Adds all countable text from a single `Content` object.

    Raises:
        ValueError: If the content contains file data or inline (non-text)
          data, which local tokenizers cannot count.
    """
    counted_content = types.Content(parts=[], role=content.role)
    if content.parts:
      for part in content.parts:
        assert counted_content.parts is not None
        counted_part = types.Part()
        if part.file_data is not None or part.inline_data is not None:
          raise ValueError(
              "LocalTokenizers do not support non-text content types."
          )
        if part.video_metadata is not None:
          counted_part.video_metadata = part.video_metadata
        if part.function_call is not None:
          self.add_function_call(part.function_call)
          counted_part.function_call = part.function_call
        if part.function_response is not None:
          self.add_function_response(part.function_response)
          counted_part.function_response = part.function_response
        if part.text is not None:
          counted_part.text = part.text
          self._texts.append(part.text)
        counted_content.parts.append(counted_part)

    if content.model_dump(exclude_none=True) != counted_content.model_dump(
        exclude_none=True
    ):
      logger.warning(
          "Content contains fields that are unsupported for token counting."
          f" Counted fields: {counted_content}. Got: {content}."
      )

  def add_function_call(self, function_call: types.FunctionCall) -> None:
    """Processes a function call and adds relevant text to the accumulator.

    Args:
        function_call: The function call to process.
    """
    if function_call.name:
      self._texts.append(function_call.name)
    counted_function_call = types.FunctionCall(name=function_call.name)
    if function_call.args:
      counted_args = self._dict_traverse(function_call.args)
      counted_function_call.args = counted_args

  def add_tool(self, tool: types.Tool) -> types.Tool:
    """Processes a tool and adds relevant text to the accumulator.

    Args:
        tool: The tool to process.

    Returns:
        The new tool object with only countable fields.
    """
    counted_tool = types.Tool(function_declarations=[])
    if tool.function_declarations:
      for function_declaration in tool.function_declarations:
        counted_function_declaration = self._function_declaration_traverse(
            function_declaration
        )
        if counted_tool.function_declarations is None:
          counted_tool.function_declarations = []
        counted_tool.function_declarations.append(counted_function_declaration)

    return counted_tool

  def add_tools(self, tools: Iterable[types.Tool]) -> None:
    for tool in tools:
      self.add_tool(tool)

  def add_function_responses(
      self, function_responses: Iterable[types.FunctionResponse]
  ) -> None:
    for function_response in function_responses:
      self.add_function_response(function_response)

  def add_function_response(
      self, function_response: types.FunctionResponse
  ) -> None:
    counted_function_response = types.FunctionResponse()
    if function_response.name:
      self._texts.append(function_response.name)
      counted_function_response.name = function_response.name
    if function_response.response:
      counted_response = self._dict_traverse(function_response.response)
      counted_function_response.response = counted_response

  def _function_declaration_traverse(
      self, function_declaration: types.FunctionDeclaration
  ) -> types.FunctionDeclaration:
    counted_function_declaration = types.FunctionDeclaration()
    if function_declaration.name:
      self._texts.append(function_declaration.name)
      counted_function_declaration.name = function_declaration.name
    if function_declaration.description:
      self._texts.append(function_declaration.description)
      counted_function_declaration.description = (
          function_declaration.description
      )
    if function_declaration.parameters:
      counted_parameters = self.add_schema(function_declaration.parameters)
      counted_function_declaration.parameters = counted_parameters
    if function_declaration.response:
      counted_response = self.add_schema(function_declaration.response)
      counted_function_declaration.response = counted_response
    return counted_function_declaration

  def add_schema(self, schema: types.Schema) -> types.Schema:
    """Processes a schema and adds relevant text to the accumulator.

    Args:
        schema: The schema to process.

    Returns:
        The new schema object with only countable fields.
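
    Example (illustrative): for `types.Schema(description="City name",
    enum=["SF", "NYC"])`, the strings "City name", "SF", and "NYC" are
    appended to the accumulator.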
    """
    counted_schema = types.Schema()
    if schema.type:
      counted_schema.type = schema.type
    if schema.title:
      counted_schema.title = schema.title
    if schema.default is not None:
      counted_schema.default = schema.default
    if schema.format:
      self._texts.append(schema.format)
      counted_schema.format = schema.format
    if schema.description:
      self._texts.append(schema.description)
      counted_schema.description = schema.description
    if schema.enum:
      self._texts.extend(schema.enum)
      counted_schema.enum = schema.enum
    if schema.required:
      self._texts.extend(schema.required)
      counted_schema.required = schema.required
    if schema.property_ordering:
      counted_schema.property_ordering = schema.property_ordering
    if schema.items:
      counted_schema_items = self.add_schema(schema.items)
      counted_schema.items = counted_schema_items
    if schema.properties:
      d = {}
      for key, value in schema.properties.items():
        self._texts.append(key)
        counted_value = self.add_schema(value)
        d[key] = counted_value
      counted_schema.properties = d
    if schema.example:
      counted_schema_example = self._any_traverse(schema.example)
      counted_schema.example = counted_schema_example
    return counted_schema

  def _dict_traverse(self, d: dict[str, Any]) -> dict[str, Any]:
    """Processes a dict and adds relevant text to the accumulator.

    Args:
        d: The dict to process.

    Returns:
        The new dict object with only countable fields.
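
    Example (illustrative): `{"city": "SF", "units": 2}` adds "city",
    "units", and "SF" to the accumulator; the integer value is returned
    unchanged but not counted.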
    """
    counted_dict = {}
    self._texts.extend(list(d.keys()))
    for key, val in d.items():
      counted_dict[key] = self._any_traverse(val)
    return counted_dict

  def _any_traverse(self, value: Any) -> Any:
    """Processes a value and adds relevant text to the accumulator.

    Args:
        value: The value to process.

    Returns:
        The new value with only countable fields.
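
    Example (illustrative): a string is appended and returned as-is, a list
    such as `["a", 1]` contributes only "a", and numbers or booleans pass
    through uncounted.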
    """
    if isinstance(value, str):
      self._texts.append(value)
      return value
    elif isinstance(value, dict):
      return self._dict_traverse(value)
    elif isinstance(value, list):
      return [self._any_traverse(item) for item in value]
    else:
      return value


def _token_str_to_bytes(
    token: str, type: sentencepiece_model_pb2.ModelProto.SentencePiece.Type
) -> bytes:
  """Converts a SentencePiece piece string to its byte representation."""
  if type == sentencepiece_model_pb2.ModelProto.SentencePiece.Type.BYTE:
    # Byte-fallback pieces look like "<0xXX>" and encode a single raw byte.
    return _parse_hex_byte(token).to_bytes(length=1, byteorder="big")
  else:
    # Regular pieces use "▁" (U+2581) as the SentencePiece whitespace marker.
    return token.replace("▁", " ").encode("utf-8")


def _parse_hex_byte(token: str) -> int:
  """Parses a hex byte string of the form '<0xXX>' and returns the integer value.

  Raises ValueError if the input is malformed or the byte value is invalid.
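
  Example (illustrative): `_parse_hex_byte("<0x41>")` returns 65.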
  """

  if len(token) != 6:
    raise ValueError(f"Invalid byte length: {token}")
  if not token.startswith("<0x") or not token.endswith(">"):
    raise ValueError(f"Invalid byte format: {token}")

  try:
    val = int(token[3:5], 16)  # Parse the two hex digits directly.
  except ValueError as e:
    raise ValueError(f"Invalid hex value: {token}") from e

  if val >= 256:
    raise ValueError(f"Byte value out of range: {token}")

  return val


class LocalTokenizer:
  """[Experimental] Text Only Local Tokenizer.

  This class provides a local tokenizer for text-only token counting.

  LIMITATIONS:
  - Only supports text-based tokenization; multimodal content is not
  supported.
  - Forward compatibility depends on open-source tokenizer models being
  available for future Gemini versions.
  - For token counting of tools and response schemas, the `LocalTokenizer`
  only supports `types.Tool` and `types.Schema` objects. Python functions or
  Pydantic models cannot be passed directly.
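
  Example usage (a minimal sketch; assumes "gemini-2.0-flash" maps to an
  open-source tokenizer model that can be loaded locally):

    tokenizer = LocalTokenizer(model_name="gemini-2.0-flash")
    result = tokenizer.count_tokens("What is your name?")
    print(result.total_tokens)

    compute_result = tokenizer.compute_tokens("Hello world")
    for info in compute_result.tokens_info:
      print(info.token_ids, info.tokens, info.role)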
  """

  def __init__(self, model_name: str):
    """Loads the local SentencePiece tokenizer for the given model name."""
    self._tokenizer_name = loader.get_tokenizer_name(model_name)
    self._model_proto = loader.load_model_proto(self._tokenizer_name)
    self._tokenizer = loader.get_sentencepiece(self._tokenizer_name)

  @_common.experimental_warning(
      "The SDK's local tokenizer implementation is experimental and may change"
      " in the future. It only supports text based tokenization."
  )
  def count_tokens(
      self,
      contents: Union[types.ContentListUnion, types.ContentListUnionDict],
      *,
      config: Optional[types.CountTokensConfigOrDict] = None,
  ) -> types.CountTokensResult:
    """Counts the number of tokens in a given text.

    Args:
      contents: The contents to tokenize.

    Returns:
      A `CountTokensResult` containing the total number of tokens.
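
    Example (a minimal sketch; assumes `tokenizer` is a `LocalTokenizer`
    instance and the tool definition is purely illustrative):

      config = types.CountTokensConfig(
          tools=[
              types.Tool(
                  function_declarations=[
                      types.FunctionDeclaration(
                          name="get_weather",
                          description="Returns the weather for a city.",
                      )
                  ]
              )
          ]
      )
      result = tokenizer.count_tokens("What is the weather?", config=config)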
    """
    processed_contents = t.t_contents(contents)
    text_accumulator = _TextsAccumulator()
    config = types.CountTokensConfig.model_validate(config or {})
    text_accumulator.add_contents(processed_contents)
    if config.tools:
      text_accumulator.add_tools(config.tools)
    if config.generation_config and config.generation_config.response_schema:
      text_accumulator.add_schema(config.generation_config.response_schema)
    if config.system_instruction:
      text_accumulator.add_contents(t.t_contents([config.system_instruction]))
    tokens_list = self._tokenizer.encode(list(text_accumulator.get_texts()))
    return types.CountTokensResult(
        total_tokens=sum(len(tokens) for tokens in tokens_list)
    )

  @_common.experimental_warning(
      "The SDK's local tokenizer implementation is experimental and may change"
      " in the future. It only supports text based tokenization."
  )
  def compute_tokens(
      self,
      contents: Union[types.ContentListUnion, types.ContentListUnionDict],
  ) -> types.ComputeTokensResult:
    """Computes the tokens ids and string pieces in the input."""
    processed_contents = t.t_contents(contents)
    text_accumulator = _TextsAccumulator()
    for content in processed_contents:
      text_accumulator.add_content(content)
    tokens_protos = self._tokenizer.EncodeAsImmutableProto(
        text_accumulator.get_texts()
    )

    roles = []
    for content in processed_contents:
      if content.parts:
        for _ in content.parts:
          roles.append(content.role)

    token_infos = []
    for tokens_proto, role in zip(tokens_protos, roles):
      token_infos.append(
          types.TokensInfo(
              token_ids=[piece.id for piece in tokens_proto.pieces],
              tokens=[
                  _token_str_to_bytes(
                      piece.piece, self._model_proto.pieces[piece.id].type
                  )
                  for piece in tokens_proto.pieces
              ],
              role=role,
          )
      )
    return types.ComputeTokensResult(tokens_info=token_infos)
