# Copyright (c) Meta Platforms, Inc. and affiliates
# implement convolution related ops for distributed tensor

from typing import Any

import torch
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
from torch.distributed.tensor._op_schema import (
    OpSchema,
    OutputSharding,
    RuntimeSchemaInfo,
)
from torch.distributed.tensor._ops.single_dim_strategy import (
    _ShardingPlaceholder,
    register_single_dim_strategy,
)
from torch.distributed.tensor._ops.utils import register_prop_rule
from torch.distributed.tensor.placement_types import Partial, Placement, Replicate


aten = torch.ops.aten


@register_prop_rule(aten.convolution.default)
def convolution_rules(op_schema: OpSchema) -> OutputSharding:
    """Propagate sharding for ``aten.convolution.default``.

    The output reuses the input's ``dim_map`` and pending ``sums`` unchanged;
    only its ``TensorMeta`` (shape, contiguous strides, input dtype) is
    recomputed here.

    Args:
        op_schema: schema whose ``args_schema`` unpacks into exactly nine
            entries: input, weight, bias, stride, padding, dilation,
            transposed, output_padding and groups.

    Returns:
        An ``OutputSharding`` wrapping the output ``DTensorSpec``, built on
        the input's mesh.

    Raises:
        AssertionError: if input or weight is not a ``DTensorSpec``, if a
            non-``None`` bias is not a ``DTensorSpec``, if input or weight
            lacks ``tensor_meta``, or if stride/padding/dilation is not a
            list.
    """
    (
        input_spec,
        weight_spec,
        bias_spec,
        stride,
        padding,
        dilation,
        _transposed,
        _output_padding,
        _groups,
    ) = op_schema.args_schema

    if not isinstance(input_spec, DTensorSpec):
        raise AssertionError
    if not isinstance(weight_spec, DTensorSpec):
        raise AssertionError
    # bias_spec can be None (optional parameter in aten.convolution schema)
    if bias_spec is not None and not isinstance(bias_spec, DTensorSpec):
        raise AssertionError
    if input_spec.tensor_meta is None:
        raise AssertionError
    if weight_spec.tensor_meta is None:
        raise AssertionError
    in_shape = input_spec.tensor_meta.shape
    weight_shape = weight_spec.tensor_meta.shape
    if not isinstance(stride, list):
        raise AssertionError(f"stride must be list, got {type(stride)}")
    if not isinstance(padding, list):
        raise AssertionError(f"padding must be list, got {type(padding)}")
    if not isinstance(dilation, list):
        raise AssertionError(f"dilation must be list, got {type(dilation)}")
    # weight_shape might not be torch.Size in all cases (e.g., SymIntArrayRef during tracing)
    # so we don't assert its type, just use it

    # Standard convolution output-size formula, applied per spatial dim.
    # The weight is laid out as (out_channels, in_channels / groups,
    # *kernel_size), so the kernel extent for spatial dim ``i`` lives at
    # ``weight_shape[i + 2]`` (index 1 is the in-channels entry).
    # NOTE: ``_transposed``, ``_output_padding`` and ``_groups`` are not
    # consulted, so this is the formula for the non-transposed case only.
    out_conv_shape = [
        (size + 2 * padding[i] - dilation[i] * (weight_shape[i + 2] - 1) - 1)
        // stride[i]
        + 1
        for i, size in enumerate(in_shape[2:])
    ]
    output_shape = [in_shape[0], weight_shape[0]] + out_conv_shape

    # Contiguous (row-major) strides: the innermost dim has stride 1 and each
    # outer stride is the product of all the dim sizes to its right.
    running = 1
    reversed_strides = [running]
    for dim_size in reversed(output_shape[1:]):
        running = running * dim_size
        reversed_strides.append(running)
    output_stride = tuple(reversed(reversed_strides))

    tensor_meta = TensorMeta(
        torch.Size(output_shape),
        output_stride,
        input_spec.tensor_meta.dtype,
    )
    # The output inherits the input's dim_map and pending sums as-is.
    return OutputSharding(
        DTensorSpec.from_dim_map(
            input_spec.mesh,
            input_spec.dim_map,
            input_spec.sums,
            tensor_meta=tensor_meta,
        )
    )


@register_prop_rule(aten.convolution_backward.default)
def convolution_backward_rules(op_schema: OpSchema) -> OutputSharding:
    """Propagate sharding for ``aten.convolution_backward.default``.

    Produces three output specs, in order:

    * grad_input: the input spec itself, returned unchanged;
    * grad_weight: built from an all ``-1`` dim_map with ``[0]`` as the
      pending sums, carrying the weight's ``tensor_meta``;
    * grad_bias: same construction for a 1-D tensor of shape
      ``bias_shape_opt``, or ``None`` when ``bias_shape_opt`` is ``None``.

    Args:
        op_schema: schema whose ``args_schema`` unpacks into exactly eleven
            entries: grad_output, input, weight, bias_sizes, stride, padding,
            dilation, transposed, output_padding, groups and output_mask.

    Returns:
        An ``OutputSharding`` holding
        ``[grad_input_spec, grad_weight_spec, grad_bias_spec]``.

    Raises:
        AssertionError: if grad_output, input or weight is not a
            ``DTensorSpec``, if a non-``None`` ``bias_shape_opt`` is not a
            list, or if the input spec lacks ``tensor_meta``.
    """
    (
        grad_output_spec,
        input_spec,
        weight_spec,
        bias_shape_opt,
        _stride,
        _padding,
        _dilation,
        _transposed,
        _output_padding,
        _groups,
        _output_mask,
    ) = op_schema.args_schema

    if not isinstance(grad_output_spec, DTensorSpec):
        raise AssertionError
    if not isinstance(input_spec, DTensorSpec):
        raise AssertionError
    if not isinstance(weight_spec, DTensorSpec):
        raise AssertionError
    # bias_shape_opt can be None (optional parameter in aten.convolution_backward schema)
    if bias_shape_opt is not None and not isinstance(bias_shape_opt, list):
        raise AssertionError
    if input_spec.tensor_meta is None:
        raise AssertionError
    weight_tensor_meta = weight_spec.tensor_meta
    mesh = input_spec.mesh

    # Only build bias metadata when a bias shape was supplied; the bias is
    # described as a 1-D tensor with stride (1,) and the input's dtype.
    if bias_shape_opt is not None:
        bias_tensor_meta = TensorMeta(
            torch.Size(bias_shape_opt),
            (1,),
            input_spec.tensor_meta.dtype,
        )
    else:
        bias_tensor_meta = None

    # Size the grad_weight dim_map from the weight's rank rather than
    # hard-coding 4 entries, so 1-D/3-D convolutions get a dim_map of the
    # right length. If the weight carries no tensor_meta, fall back to the
    # input rank (a convolution weight has the same rank as its input).
    if weight_tensor_meta is not None:
        weight_ndim = len(weight_tensor_meta.shape)
    else:
        weight_ndim = len(input_spec.tensor_meta.shape)

    grad_input_spec = input_spec
    # NOTE(review): ``[0]`` is passed as the pending-sum mesh dims, which
    # presumably assumes a 1-D mesh — confirm before using on N-D meshes.
    grad_weight_spec = DTensorSpec.from_dim_map(
        mesh,
        [-1] * weight_ndim,
        [0],
        tensor_meta=weight_tensor_meta,
    )

    if bias_tensor_meta is not None:
        grad_bias_spec = DTensorSpec.from_dim_map(
            mesh,
            [-1],
            [0],
            tensor_meta=bias_tensor_meta,
        )
    else:
        grad_bias_spec = None

    # TODO: the output_mask is not respected here; the spec of an output
    # Tensor should be set to `None` when its output_mask entry is `False`.
    # This also applies to the conv handler in
    # torch/distributed/tensor/_tp_conv.py
    return OutputSharding([grad_input_spec, grad_weight_spec, grad_bias_spec])


# Single-dim strategies for autoparallel optimizer support.
# These coexist with the prop_rules above — strategies take precedence
# in the propagation path, while the prop_rules + custom handlers in
# _tp_conv.py continue to handle runtime dispatch.


@register_single_dim_strategy(
    [aten.convolution.default],
    schema_info=RuntimeSchemaInfo(2),
)
def convolution_single_dim_strategy(
    op: torch._ops.OpOverload,
    args_schema: tuple[Any, ...],
    kwargs_schema: dict[str, Any],
) -> list[list[Placement | _ShardingPlaceholder]]:
    """Return the single-dim sharding rule for ``aten.convolution.default``.

    Exactly one rule is produced, laid out as
    ``[output, input, weight, (bias)]``: output and input both use
    ``_ShardingPlaceholder(0)``, the weight is replicated, and a replicated
    bias entry is appended only when ``args_schema[2]`` (the bias) is not
    ``None``.

    ``op`` and ``kwargs_schema`` are accepted but not used.
    """
    has_bias = args_schema[2] is not None
    placements: list[Placement | _ShardingPlaceholder] = [
        _ShardingPlaceholder(0),  # output
        _ShardingPlaceholder(0),  # input
        Replicate(),  # weight
        # Optional trailing entry for the bias operand.
        *([Replicate()] if has_bias else []),
    ]
    return [placements]


@register_single_dim_strategy(
    [aten.convolution_backward.default],
    schema_info=RuntimeSchemaInfo(3),
)
def convolution_backward_single_dim_strategy(
    op: torch._ops.OpOverload,
    args_schema: tuple[Any, ...],
    kwargs_schema: dict[str, Any],
) -> list[list[Placement | _ShardingPlaceholder | None]]:
    """Return the single-dim sharding rule for ``aten.convolution_backward.default``.

    Exactly one rule is produced, listing output placements followed by
    input placements:

    * outputs ``[grad_input, grad_weight, grad_bias]``: a
      ``_ShardingPlaceholder(0)``, ``Partial("sum")``, and either
      ``Partial("sum")`` or ``None`` depending on whether ``args_schema[3]``
      (the bias sizes) is ``None``;
    * inputs ``[grad_output, input, weight]``: two
      ``_ShardingPlaceholder(0)`` entries and a replicated weight.

    ``op`` and ``kwargs_schema`` are accepted but not used.
    """
    # No bias sizes means there is no grad_bias output to place.
    grad_bias_placement = Partial("sum") if args_schema[3] is not None else None

    output_placements: list[Placement | _ShardingPlaceholder | None] = [
        _ShardingPlaceholder(0),  # grad_input
        Partial("sum"),  # grad_weight
        grad_bias_placement,  # grad_bias
    ]
    input_placements: list[Placement | _ShardingPlaceholder | None] = [
        _ShardingPlaceholder(0),  # grad_output
        _ShardingPlaceholder(0),  # input
        Replicate(),  # weight
    ]
    return [output_placements + input_placements]
