This commit is contained in:
BillionToken 2026-03-24 22:32:05 +08:00 committed by GitHub
commit 6d37db59e6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 113 additions and 14 deletions

View File

@ -41,14 +41,36 @@ def datetime_to_nanoseconds(dt: datetime | None) -> int | None:
class MLflowDataTrace(BaseTraceInstance):
def __init__(self, config: MLflowConfig | DatabricksConfig):
super().__init__(config)
if isinstance(config, DatabricksConfig):
self._setup_databricks(config)
else:
self._setup_mlflow(config)
self._config = config
self._project_url = ""
# Defer actual setup to first use to avoid blocking workflow on initialization
self._initialized = False
# Enable async logging to minimize performance overhead
os.environ["MLFLOW_ENABLE_ASYNC_TRACE_LOGGING"] = "true"
def _ensure_initialized(self):
"""Lazy initialization to prevent blocking workflow when tracing service is unavailable."""
if self._initialized:
return
try:
if isinstance(self._config, DatabricksConfig):
self._setup_databricks(self._config)
else:
self._setup_mlflow(self._config)
self._initialized = True
except Exception as e:
logger.warning("[MLflow] Failed to initialize tracing: %s. Tracing will be disabled.", e)
# Set a fallback project URL even on failure
if isinstance(self._config, DatabricksConfig):
host = self._config.host.rstrip("/")
self._project_url = f"{host}/ml/experiments/{self._config.experiment_id}/traces"
else:
self._project_url = f"{self._config.tracking_uri}/#/experiments/{self._config.experiment_id}/traces"
# Mark as initialized to avoid repeated attempts
self._initialized = True
def _setup_databricks(self, config: DatabricksConfig):
"""Setup connection to Databricks-managed MLflow instances"""
os.environ["DATABRICKS_HOST"] = config.host
@ -87,6 +109,9 @@ class MLflowDataTrace(BaseTraceInstance):
def trace(self, trace_info: BaseTraceInfo):
"""Simple dispatch to trace methods"""
# Initialize on first use
self._ensure_initialized()
try:
if isinstance(trace_info, WorkflowTraceInfo):
self.workflow_trace(trace_info)
@ -104,7 +129,6 @@ class MLflowDataTrace(BaseTraceInstance):
self.generate_name_trace(trace_info)
except Exception:
logger.exception("[MLflow] Trace error")
raise
def workflow_trace(self, trace_info: WorkflowTraceInfo):
"""Create workflow span as root, with node spans as children"""

View File

@ -935,11 +935,26 @@ class TraceQueueManager:
self.app_id = app_id
self.user_id = user_id
self.trace_instance = OpsTraceManager.get_ops_trace_instance(app_id)
# Lazy initialization: don't create trace instance until needed
# This prevents workflow from being blocked if tracing service is unavailable
self._trace_instance = None
self._trace_instance_initialized = False
self.flask_app = current_app._get_current_object() # type: ignore
if trace_manager_timer is None:
self.start_timer()
@property
def trace_instance(self):
"""Lazy initialization of trace instance to avoid blocking workflow on startup."""
if not self._trace_instance_initialized:
self._trace_instance_initialized = True
try:
self._trace_instance = OpsTraceManager.get_ops_trace_instance(self.app_id)
except Exception:
logger.exception("Failed to initialize trace instance for app_id: %s", self.app_id)
self._trace_instance = None
return self._trace_instance
def add_trace_task(self, trace_task: TraceTask):
global trace_manager_timer, trace_manager_queue
try:

View File

@ -230,13 +230,26 @@ class TestDatetimeToNanoseconds:
class TestInit:
def test_mlflow_config_no_auth(self, mock_mlflow):
def test_mlflow_config_no_auth_lazy_init(self, mock_mlflow):
"""Test that MLflow setup is deferred until first use (lazy initialization)."""
config = MLflowConfig(tracking_uri="http://localhost:5000", experiment_id="0")
trace = MLflowDataTrace(config)
# Setup should NOT be called during __init__ (lazy initialization)
mock_mlflow.set_tracking_uri.assert_not_called()
mock_mlflow.set_experiment.assert_not_called()
# Project URL should be empty before initialization
assert trace._project_url == ""
assert os.environ["MLFLOW_ENABLE_ASYNC_TRACE_LOGGING"] == "true"
def test_mlflow_config_initialized_on_first_use(self, mock_mlflow):
"""Test that MLflow setup is called when _ensure_initialized is called."""
config = MLflowConfig(tracking_uri="http://localhost:5000", experiment_id="0")
trace = MLflowDataTrace(config)
# Trigger initialization
trace._ensure_initialized()
mock_mlflow.set_tracking_uri.assert_called_with("http://localhost:5000")
mock_mlflow.set_experiment.assert_called_with(experiment_id="0")
assert trace.get_project_url() == "http://localhost:5000/#/experiments/0/traces"
assert os.environ["MLFLOW_ENABLE_ASYNC_TRACE_LOGGING"] == "true"
def test_mlflow_config_with_auth(self, mock_mlflow):
config = MLflowConfig(
@ -245,7 +258,9 @@ class TestInit:
username="user",
password="pass",
)
MLflowDataTrace(config)
trace = MLflowDataTrace(config)
# Trigger initialization to set auth env vars
trace._ensure_initialized()
assert os.environ["MLFLOW_TRACKING_USERNAME"] == "user"
assert os.environ["MLFLOW_TRACKING_PASSWORD"] == "pass"
@ -257,6 +272,8 @@ class TestInit:
client_secret="csec",
)
trace = MLflowDataTrace(config)
# Trigger initialization
trace._ensure_initialized()
assert os.environ["DATABRICKS_HOST"] == "https://db.com/"
assert os.environ["DATABRICKS_CLIENT_ID"] == "cid"
assert os.environ["DATABRICKS_CLIENT_SECRET"] == "csec"
@ -271,13 +288,31 @@ class TestInit:
personal_access_token="pat",
)
trace = MLflowDataTrace(config)
# Trigger initialization
trace._ensure_initialized()
assert os.environ["DATABRICKS_TOKEN"] == "pat"
assert "db.com/ml/experiments/1/traces" in trace.get_project_url()
def test_databricks_no_creds_raises(self, mock_mlflow):
def test_databricks_no_creds_raises_on_init(self, mock_mlflow):
config = DatabricksConfig(host="https://db.com", experiment_id="1")
trace = MLflowDataTrace(config)
# Error is raised when initializing, not during __init__
with pytest.raises(ValueError, match="Either Databricks token"):
MLflowDataTrace(config)
trace._ensure_initialized()
def test_initialization_error_handled_gracefully(self, mock_mlflow):
"""Test that initialization errors are handled gracefully in trace method."""
config = MLflowConfig(tracking_uri="http://localhost:5000", experiment_id="0")
trace = MLflowDataTrace(config)
# Simulate initialization failure
mock_mlflow.set_tracking_uri.side_effect = ConnectionError("Connection refused")
# Should not raise when calling trace - error is logged and tracing is disabled
trace_info = _make_workflow_trace_info()
# trace() will call _ensure_initialized() which will catch the error
# and set _initialized=True to avoid repeated attempts
trace.trace(trace_info)
# After failed init, the instance should be marked as initialized
assert trace._initialized is True
# ── trace dispatcher ────────────────────────────────────────────────────────

View File

@ -517,16 +517,41 @@ def test_extract_streaming_metrics_invalid_json():
def test_trace_queue_manager_add_and_collect(monkeypatch):
monkeypatch.setattr(
"core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", classmethod(lambda cls, aid: True)
)
# With lazy initialization, get_ops_trace_instance should NOT be called during __init__
mock_get_instance = MagicMock(return_value=True)
monkeypatch.setattr("core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", mock_get_instance)
manager = TraceQueueManager(app_id="app-id", user_id="user")
# get_ops_trace_instance should NOT be called during initialization (lazy init)
mock_get_instance.assert_not_called()
# When trace_instance property is accessed, it should call get_ops_trace_instance
_ = manager.trace_instance
mock_get_instance.assert_called_once_with("app-id")
task = TraceTask(trace_type=TraceTaskName.CONVERSATION_TRACE)
manager.add_trace_task(task)
tasks = manager.collect_tasks()
assert tasks == [task]
def test_trace_queue_manager_lazy_init_error_handling(monkeypatch):
"""Test that TraceQueueManager handles initialization errors gracefully."""
mock_get_instance = MagicMock(side_effect=ConnectionError("Service unavailable"))
monkeypatch.setattr("core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", mock_get_instance)
manager = TraceQueueManager(app_id="app-id", user_id="user")
# get_ops_trace_instance should NOT be called during initialization
mock_get_instance.assert_not_called()
# When trace_instance is accessed, it should handle the error gracefully
instance = manager.trace_instance
mock_get_instance.assert_called_once_with("app-id")
# Should return None when initialization fails
assert instance is None
# Should be marked as initialized even on failure
assert manager._trace_instance_initialized is True
def test_trace_queue_manager_run_invokes_send(monkeypatch):
monkeypatch.setattr(
"core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", classmethod(lambda cls, aid: True)