mirror of https://github.com/langgenius/dify.git
Merge a1fb0fd065 into 9065d54f4a
This commit is contained in:
commit
6d37db59e6
|
|
@ -41,14 +41,36 @@ def datetime_to_nanoseconds(dt: datetime | None) -> int | None:
|
||||||
class MLflowDataTrace(BaseTraceInstance):
|
class MLflowDataTrace(BaseTraceInstance):
|
||||||
def __init__(self, config: MLflowConfig | DatabricksConfig):
|
def __init__(self, config: MLflowConfig | DatabricksConfig):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
if isinstance(config, DatabricksConfig):
|
self._config = config
|
||||||
self._setup_databricks(config)
|
self._project_url = ""
|
||||||
else:
|
# Defer actual setup to first use to avoid blocking workflow on initialization
|
||||||
self._setup_mlflow(config)
|
self._initialized = False
|
||||||
|
|
||||||
# Enable async logging to minimize performance overhead
|
# Enable async logging to minimize performance overhead
|
||||||
os.environ["MLFLOW_ENABLE_ASYNC_TRACE_LOGGING"] = "true"
|
os.environ["MLFLOW_ENABLE_ASYNC_TRACE_LOGGING"] = "true"
|
||||||
|
|
||||||
|
def _ensure_initialized(self):
|
||||||
|
"""Lazy initialization to prevent blocking workflow when tracing service is unavailable."""
|
||||||
|
if self._initialized:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
if isinstance(self._config, DatabricksConfig):
|
||||||
|
self._setup_databricks(self._config)
|
||||||
|
else:
|
||||||
|
self._setup_mlflow(self._config)
|
||||||
|
self._initialized = True
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("[MLflow] Failed to initialize tracing: %s. Tracing will be disabled.", e)
|
||||||
|
# Set a fallback project URL even on failure
|
||||||
|
if isinstance(self._config, DatabricksConfig):
|
||||||
|
host = self._config.host.rstrip("/")
|
||||||
|
self._project_url = f"{host}/ml/experiments/{self._config.experiment_id}/traces"
|
||||||
|
else:
|
||||||
|
self._project_url = f"{self._config.tracking_uri}/#/experiments/{self._config.experiment_id}/traces"
|
||||||
|
# Mark as initialized to avoid repeated attempts
|
||||||
|
self._initialized = True
|
||||||
|
|
||||||
def _setup_databricks(self, config: DatabricksConfig):
|
def _setup_databricks(self, config: DatabricksConfig):
|
||||||
"""Setup connection to Databricks-managed MLflow instances"""
|
"""Setup connection to Databricks-managed MLflow instances"""
|
||||||
os.environ["DATABRICKS_HOST"] = config.host
|
os.environ["DATABRICKS_HOST"] = config.host
|
||||||
|
|
@ -87,6 +109,9 @@ class MLflowDataTrace(BaseTraceInstance):
|
||||||
|
|
||||||
def trace(self, trace_info: BaseTraceInfo):
|
def trace(self, trace_info: BaseTraceInfo):
|
||||||
"""Simple dispatch to trace methods"""
|
"""Simple dispatch to trace methods"""
|
||||||
|
# Initialize on first use
|
||||||
|
self._ensure_initialized()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if isinstance(trace_info, WorkflowTraceInfo):
|
if isinstance(trace_info, WorkflowTraceInfo):
|
||||||
self.workflow_trace(trace_info)
|
self.workflow_trace(trace_info)
|
||||||
|
|
@ -104,7 +129,6 @@ class MLflowDataTrace(BaseTraceInstance):
|
||||||
self.generate_name_trace(trace_info)
|
self.generate_name_trace(trace_info)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("[MLflow] Trace error")
|
logger.exception("[MLflow] Trace error")
|
||||||
raise
|
|
||||||
|
|
||||||
def workflow_trace(self, trace_info: WorkflowTraceInfo):
|
def workflow_trace(self, trace_info: WorkflowTraceInfo):
|
||||||
"""Create workflow span as root, with node spans as children"""
|
"""Create workflow span as root, with node spans as children"""
|
||||||
|
|
|
||||||
|
|
@ -935,11 +935,26 @@ class TraceQueueManager:
|
||||||
|
|
||||||
self.app_id = app_id
|
self.app_id = app_id
|
||||||
self.user_id = user_id
|
self.user_id = user_id
|
||||||
self.trace_instance = OpsTraceManager.get_ops_trace_instance(app_id)
|
# Lazy initialization: don't create trace instance until needed
|
||||||
|
# This prevents workflow from being blocked if tracing service is unavailable
|
||||||
|
self._trace_instance = None
|
||||||
|
self._trace_instance_initialized = False
|
||||||
self.flask_app = current_app._get_current_object() # type: ignore
|
self.flask_app = current_app._get_current_object() # type: ignore
|
||||||
if trace_manager_timer is None:
|
if trace_manager_timer is None:
|
||||||
self.start_timer()
|
self.start_timer()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def trace_instance(self):
|
||||||
|
"""Lazy initialization of trace instance to avoid blocking workflow on startup."""
|
||||||
|
if not self._trace_instance_initialized:
|
||||||
|
self._trace_instance_initialized = True
|
||||||
|
try:
|
||||||
|
self._trace_instance = OpsTraceManager.get_ops_trace_instance(self.app_id)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to initialize trace instance for app_id: %s", self.app_id)
|
||||||
|
self._trace_instance = None
|
||||||
|
return self._trace_instance
|
||||||
|
|
||||||
def add_trace_task(self, trace_task: TraceTask):
|
def add_trace_task(self, trace_task: TraceTask):
|
||||||
global trace_manager_timer, trace_manager_queue
|
global trace_manager_timer, trace_manager_queue
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -230,13 +230,26 @@ class TestDatetimeToNanoseconds:
|
||||||
|
|
||||||
|
|
||||||
class TestInit:
|
class TestInit:
|
||||||
def test_mlflow_config_no_auth(self, mock_mlflow):
|
def test_mlflow_config_no_auth_lazy_init(self, mock_mlflow):
|
||||||
|
"""Test that MLflow setup is deferred until first use (lazy initialization)."""
|
||||||
config = MLflowConfig(tracking_uri="http://localhost:5000", experiment_id="0")
|
config = MLflowConfig(tracking_uri="http://localhost:5000", experiment_id="0")
|
||||||
trace = MLflowDataTrace(config)
|
trace = MLflowDataTrace(config)
|
||||||
|
# Setup should NOT be called during __init__ (lazy initialization)
|
||||||
|
mock_mlflow.set_tracking_uri.assert_not_called()
|
||||||
|
mock_mlflow.set_experiment.assert_not_called()
|
||||||
|
# Project URL should be empty before initialization
|
||||||
|
assert trace._project_url == ""
|
||||||
|
assert os.environ["MLFLOW_ENABLE_ASYNC_TRACE_LOGGING"] == "true"
|
||||||
|
|
||||||
|
def test_mlflow_config_initialized_on_first_use(self, mock_mlflow):
|
||||||
|
"""Test that MLflow setup is called when _ensure_initialized is called."""
|
||||||
|
config = MLflowConfig(tracking_uri="http://localhost:5000", experiment_id="0")
|
||||||
|
trace = MLflowDataTrace(config)
|
||||||
|
# Trigger initialization
|
||||||
|
trace._ensure_initialized()
|
||||||
mock_mlflow.set_tracking_uri.assert_called_with("http://localhost:5000")
|
mock_mlflow.set_tracking_uri.assert_called_with("http://localhost:5000")
|
||||||
mock_mlflow.set_experiment.assert_called_with(experiment_id="0")
|
mock_mlflow.set_experiment.assert_called_with(experiment_id="0")
|
||||||
assert trace.get_project_url() == "http://localhost:5000/#/experiments/0/traces"
|
assert trace.get_project_url() == "http://localhost:5000/#/experiments/0/traces"
|
||||||
assert os.environ["MLFLOW_ENABLE_ASYNC_TRACE_LOGGING"] == "true"
|
|
||||||
|
|
||||||
def test_mlflow_config_with_auth(self, mock_mlflow):
|
def test_mlflow_config_with_auth(self, mock_mlflow):
|
||||||
config = MLflowConfig(
|
config = MLflowConfig(
|
||||||
|
|
@ -245,7 +258,9 @@ class TestInit:
|
||||||
username="user",
|
username="user",
|
||||||
password="pass",
|
password="pass",
|
||||||
)
|
)
|
||||||
MLflowDataTrace(config)
|
trace = MLflowDataTrace(config)
|
||||||
|
# Trigger initialization to set auth env vars
|
||||||
|
trace._ensure_initialized()
|
||||||
assert os.environ["MLFLOW_TRACKING_USERNAME"] == "user"
|
assert os.environ["MLFLOW_TRACKING_USERNAME"] == "user"
|
||||||
assert os.environ["MLFLOW_TRACKING_PASSWORD"] == "pass"
|
assert os.environ["MLFLOW_TRACKING_PASSWORD"] == "pass"
|
||||||
|
|
||||||
|
|
@ -257,6 +272,8 @@ class TestInit:
|
||||||
client_secret="csec",
|
client_secret="csec",
|
||||||
)
|
)
|
||||||
trace = MLflowDataTrace(config)
|
trace = MLflowDataTrace(config)
|
||||||
|
# Trigger initialization
|
||||||
|
trace._ensure_initialized()
|
||||||
assert os.environ["DATABRICKS_HOST"] == "https://db.com/"
|
assert os.environ["DATABRICKS_HOST"] == "https://db.com/"
|
||||||
assert os.environ["DATABRICKS_CLIENT_ID"] == "cid"
|
assert os.environ["DATABRICKS_CLIENT_ID"] == "cid"
|
||||||
assert os.environ["DATABRICKS_CLIENT_SECRET"] == "csec"
|
assert os.environ["DATABRICKS_CLIENT_SECRET"] == "csec"
|
||||||
|
|
@ -271,13 +288,31 @@ class TestInit:
|
||||||
personal_access_token="pat",
|
personal_access_token="pat",
|
||||||
)
|
)
|
||||||
trace = MLflowDataTrace(config)
|
trace = MLflowDataTrace(config)
|
||||||
|
# Trigger initialization
|
||||||
|
trace._ensure_initialized()
|
||||||
assert os.environ["DATABRICKS_TOKEN"] == "pat"
|
assert os.environ["DATABRICKS_TOKEN"] == "pat"
|
||||||
assert "db.com/ml/experiments/1/traces" in trace.get_project_url()
|
assert "db.com/ml/experiments/1/traces" in trace.get_project_url()
|
||||||
|
|
||||||
def test_databricks_no_creds_raises(self, mock_mlflow):
|
def test_databricks_no_creds_raises_on_init(self, mock_mlflow):
|
||||||
config = DatabricksConfig(host="https://db.com", experiment_id="1")
|
config = DatabricksConfig(host="https://db.com", experiment_id="1")
|
||||||
|
trace = MLflowDataTrace(config)
|
||||||
|
# Error is raised when initializing, not during __init__
|
||||||
with pytest.raises(ValueError, match="Either Databricks token"):
|
with pytest.raises(ValueError, match="Either Databricks token"):
|
||||||
MLflowDataTrace(config)
|
trace._ensure_initialized()
|
||||||
|
|
||||||
|
def test_initialization_error_handled_gracefully(self, mock_mlflow):
|
||||||
|
"""Test that initialization errors are handled gracefully in trace method."""
|
||||||
|
config = MLflowConfig(tracking_uri="http://localhost:5000", experiment_id="0")
|
||||||
|
trace = MLflowDataTrace(config)
|
||||||
|
# Simulate initialization failure
|
||||||
|
mock_mlflow.set_tracking_uri.side_effect = ConnectionError("Connection refused")
|
||||||
|
# Should not raise when calling trace - error is logged and tracing is disabled
|
||||||
|
trace_info = _make_workflow_trace_info()
|
||||||
|
# trace() will call _ensure_initialized() which will catch the error
|
||||||
|
# and set _initialized=True to avoid repeated attempts
|
||||||
|
trace.trace(trace_info)
|
||||||
|
# After failed init, the instance should be marked as initialized
|
||||||
|
assert trace._initialized is True
|
||||||
|
|
||||||
|
|
||||||
# ── trace dispatcher ────────────────────────────────────────────────────────
|
# ── trace dispatcher ────────────────────────────────────────────────────────
|
||||||
|
|
|
||||||
|
|
@ -517,16 +517,41 @@ def test_extract_streaming_metrics_invalid_json():
|
||||||
|
|
||||||
|
|
||||||
def test_trace_queue_manager_add_and_collect(monkeypatch):
|
def test_trace_queue_manager_add_and_collect(monkeypatch):
|
||||||
monkeypatch.setattr(
|
# With lazy initialization, get_ops_trace_instance should NOT be called during __init__
|
||||||
"core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", classmethod(lambda cls, aid: True)
|
mock_get_instance = MagicMock(return_value=True)
|
||||||
)
|
monkeypatch.setattr("core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", mock_get_instance)
|
||||||
manager = TraceQueueManager(app_id="app-id", user_id="user")
|
manager = TraceQueueManager(app_id="app-id", user_id="user")
|
||||||
|
# get_ops_trace_instance should NOT be called during initialization (lazy init)
|
||||||
|
mock_get_instance.assert_not_called()
|
||||||
|
|
||||||
|
# When trace_instance property is accessed, it should call get_ops_trace_instance
|
||||||
|
_ = manager.trace_instance
|
||||||
|
mock_get_instance.assert_called_once_with("app-id")
|
||||||
|
|
||||||
task = TraceTask(trace_type=TraceTaskName.CONVERSATION_TRACE)
|
task = TraceTask(trace_type=TraceTaskName.CONVERSATION_TRACE)
|
||||||
manager.add_trace_task(task)
|
manager.add_trace_task(task)
|
||||||
tasks = manager.collect_tasks()
|
tasks = manager.collect_tasks()
|
||||||
assert tasks == [task]
|
assert tasks == [task]
|
||||||
|
|
||||||
|
|
||||||
|
def test_trace_queue_manager_lazy_init_error_handling(monkeypatch):
|
||||||
|
"""Test that TraceQueueManager handles initialization errors gracefully."""
|
||||||
|
mock_get_instance = MagicMock(side_effect=ConnectionError("Service unavailable"))
|
||||||
|
monkeypatch.setattr("core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", mock_get_instance)
|
||||||
|
|
||||||
|
manager = TraceQueueManager(app_id="app-id", user_id="user")
|
||||||
|
# get_ops_trace_instance should NOT be called during initialization
|
||||||
|
mock_get_instance.assert_not_called()
|
||||||
|
|
||||||
|
# When trace_instance is accessed, it should handle the error gracefully
|
||||||
|
instance = manager.trace_instance
|
||||||
|
mock_get_instance.assert_called_once_with("app-id")
|
||||||
|
# Should return None when initialization fails
|
||||||
|
assert instance is None
|
||||||
|
# Should be marked as initialized even on failure
|
||||||
|
assert manager._trace_instance_initialized is True
|
||||||
|
|
||||||
|
|
||||||
def test_trace_queue_manager_run_invokes_send(monkeypatch):
|
def test_trace_queue_manager_run_invokes_send(monkeypatch):
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", classmethod(lambda cls, aid: True)
|
"core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", classmethod(lambda cls, aid: True)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue